In [7]:
import pandas as pd
import numpy as np
import plotly.graph_objs as go
from plotly.subplots import make_subplots
import ipywidgets as widgets
from IPython.display import display, clear_output

# 1. Read the main experiment log; lines starting with '#' are skipped as comments.
df_main = pd.read_csv('./cardinal_test_1.csv', comment='#')

# Control-signal columns of 'cardinal_test_1.csv' that can be plotted.
main_features = [
    'Q_calculated',
    'Q_calculated_large_parameters',
    'Q_calculated_gru_adaptive_2',
    'Q_calculated_gru_adaptive_fixed_Q',
    'Q_calculated_dense',
    'Q_calculated_gru_memoryless',
]

# Limit every control-signal column to the interval [-1, 1].
df_main[main_features] = df_main[main_features].clip(lower=-1, upper=1)

# 2. Read the eight runs 'cardinal_test_1_1.csv' ... 'cardinal_test_1_8.csv'.
integrated_files = [f'./cardinal_test_1_{i}.csv' for i in range(1, 9)]
df_integrated_list = [
    pd.read_csv(csv_path, comment='#') for csv_path in integrated_files
]

# The time axis is taken from the first run only.
# NOTE(review): assumes every run shares this time column -- not verified here.
time_integrated = df_integrated_list[0]['time']

# Stack 'Q_calculated_integrated' of every run (one row per run) and limit
# the values to [-1, 1].
Q_integrated_all = np.clip(
    np.array([run['Q_calculated_integrated'].values for run in df_integrated_list]),
    -1,
    1,
)

# Per-time-step mean and standard deviation across the runs.
Q_integrated_mean = np.mean(Q_integrated_all, axis=0)
Q_integrated_std = np.std(Q_integrated_all, axis=0)

# 3. Define feature descriptions (improved style)
# Maps each selectable feature name (the checkbox descriptions) to the text
# shown for it in the "Feature Information" output area. `update_plot` looks
# entries up with `.get(...)`, so a missing key falls back to
# 'No description available.' instead of raising.
feature_descriptions = {
    # Columns read from 'cardinal_test_1.csv'.
    'Q_calculated': 'Control signal calculated by the controller during the experiment with correct parameters (pole mass 15g, pole length 5cm).',
    'Q_calculated_large_parameters': 'Control signal from MPC assuming incorrect parameters (pole mass 150g, pole length 80cm).',
    'Q_calculated_gru_adaptive_2': 'Control signal from a GRU network trained to compute control signals based on the past trajectory with current parameters.',
    'Q_calculated_dense': 'Control signal from a dense neural network trained on the same dataset.',
    'Q_calculated_gru_memoryless': 'Control signal from a GRU network trained to imitate an MPC with random parameter realization at each point.',
    # Views derived from 'Q_calculated_integrated' of the eight extra runs.
    'Q_calculated_integrated_all': 'All 8 traces of Q_calculated_integrated from Monte Carlo simulations.',
    'Q_calculated_integrated_mean': 'Mean of Q_calculated_integrated from Monte Carlo simulations.',
    'Q_calculated_integrated_confidence': 'Mean of Q_calculated_integrated with ±3σ confidence intervals from Monte Carlo simulations.',
    # NOTE(review): much terser than the other entries -- consider expanding.
    'Q_calculated_gru_adaptive_fixed_Q': 'Providing right past Q',
}

# 4. Build the feature-selection widgets.
# Views derived from the 'Q_calculated_integrated' runs.
integrated_options = [
    'Q_calculated_integrated_all',
    'Q_calculated_integrated_mean',
    'Q_calculated_integrated_confidence',
]

# Everything the user can pick: main columns first, then the derived views.
all_features = [*main_features, *integrated_options]

# One unchecked checkbox per feature for "what to plot" ...
feature_checkboxes = [
    widgets.Checkbox(value=False, description=name) for name in all_features
]

# ... and a second, independent set for "what to use as reference".
reference_feature_checkboxes = [
    widgets.Checkbox(value=False, description=name) for name in all_features
]

# Column headings.
feature_checkboxes_label = widgets.HTML(value="<h3>Select Features to Plot:</h3>")
reference_checkboxes_label = widgets.HTML(value="<h3>Select Reference Feature:</h3>")

# Each column is its heading followed by its checkboxes; the two columns sit
# side by side.
feature_checkboxes_box = widgets.VBox(
    children=[feature_checkboxes_label, *feature_checkboxes]
)
reference_checkboxes_box = widgets.VBox(
    children=[reference_checkboxes_label, *reference_feature_checkboxes]
)
checkboxes_box = widgets.HBox(
    children=[feature_checkboxes_box, reference_checkboxes_box]
)

# Keep the reference checkboxes mutually exclusive (zero or one ticked).
def on_reference_checkbox_change(change):
    """Untick every other reference checkbox when one of them gets ticked.

    Args:
        change: Change notification mapping; the keys read are 'type',
            'name', 'new' and 'owner'.

    Notifications that are not a 'value' change, or whose new value is falsy
    (a box being unticked), are ignored.
    """
    is_value_change = change['type'] == 'change' and change['name'] == 'value'
    if not (is_value_change and change['new']):
        return

    for checkbox in reference_feature_checkboxes:
        if checkbox is not change['owner']:
            checkbox.value = False

# Register the exclusivity handler on every reference checkbox.
for cb in reference_feature_checkboxes:
    cb.observe(on_reference_checkbox_change, names='value')

# Range slider spanning the full time axis of the main log; it starts with
# the whole range selected.
time_min = df_main['time'].min()
time_max = df_main['time'].max()
time_slider = widgets.FloatRangeSlider(
    value=[time_min, time_max],
    min=time_min,
    max=time_max,
    step=0.1,
    description='Time Range:',
    continuous_update=False,
    layout=widgets.Layout(width='800px'),
)

# Output areas: one for the figure, one for the feature descriptions.
plot_output = widgets.Output()
info_output = widgets.Output()

# Look up the x/y data of one selectable feature.
def get_feature_values(feature, df_main_filtered, time_integrated_filtered, Q_integrated_all_filtered, Q_integrated_mean_filtered, Q_integrated_std_filtered):
    """Return ``(x_values, y_values)`` for ``feature``.

    Args:
        feature: Name of the feature to fetch.
        df_main_filtered: Frame providing the 'time' column and the columns
            listed in ``main_features``.
        time_integrated_filtered: Time axis of the integrated runs; its
            ``.values`` are returned as x.
        Q_integrated_all_filtered: All integrated traces, returned unchanged
            for 'Q_calculated_integrated_all'.
        Q_integrated_mean_filtered: Mean of the integrated traces.
        Q_integrated_std_filtered: Standard deviation of the integrated traces.

    Returns:
        * a name in ``main_features``: the frame's time values and that column;
        * 'Q_calculated_integrated_all': integrated time and all traces;
        * 'Q_calculated_integrated_mean': integrated time and the mean;
        * 'Q_calculated_integrated_confidence': integrated time and the tuple
          ``(mean, mean + 3 * std, mean - 3 * std)``;
        * anything else: ``(None, None)``.
    """
    if feature in main_features:
        column = df_main_filtered[feature].values
        return df_main_filtered['time'].values, column

    integrated_views = (
        'Q_calculated_integrated_all',
        'Q_calculated_integrated_mean',
        'Q_calculated_integrated_confidence',
    )
    if feature not in integrated_views:
        return None, None

    if feature == 'Q_calculated_integrated_all':
        return time_integrated_filtered.values, Q_integrated_all_filtered

    if feature == 'Q_calculated_integrated_mean':
        return time_integrated_filtered.values, Q_integrated_mean_filtered

    # Remaining case: mean together with its +/- 3 sigma band.
    half_width = 3 * Q_integrated_std_filtered
    band = (
        Q_integrated_mean_filtered,
        Q_integrated_mean_filtered + half_width,
        Q_integrated_mean_filtered - half_width,
    )
    return time_integrated_filtered.values, band

# 5. Define the update function
# Selectable features whose data is not a single 1-D series (a stack of traces
# or a (mean, upper, lower) tuple) and therefore cannot be passed to np.interp.
_ENSEMBLE_FEATURES = (
    'Q_calculated_integrated_all',
    'Q_calculated_integrated_confidence',
)


def _resolve_reference(reference_feature, filtered):
    """Return the reference series as ``(x, y)`` 1-D arrays, or ``None``.

    An ensemble feature chosen as reference is represented by the ensemble
    mean, because np.interp needs a single 1-D series. ``None`` is returned
    when no reference is selected, the feature is unknown, or the reference
    has no samples in the selected time window (np.interp rejects an empty
    sample grid).
    """
    if reference_feature is None:
        return None
    if reference_feature in _ENSEMBLE_FEATURES:
        reference_feature = 'Q_calculated_integrated_mean'
    x_ref, y_ref = get_feature_values(reference_feature, *filtered)
    if x_ref is None or len(x_ref) == 0:
        return None
    return x_ref, y_ref


def _reference_on_grid(x_values, reference):
    """Interpolate the reference onto ``x_values``; ``None`` without a reference."""
    if reference is None:
        return None
    x_ref, y_ref = reference
    # NOTE(review): np.interp expects x_ref to be increasing -- assumed for the
    # logged time columns, not checked here.
    return np.interp(x_values, x_ref, y_ref)


def _relative_series(y_values, baseline, label):
    """Return ``(y_to_plot, trace_name)``, subtracting ``baseline`` if present.

    With a baseline the plotted series is the difference to it and the trace
    name carries the sum of squared differences (SSD).
    """
    if baseline is None:
        return y_values, label
    y_diff = y_values - baseline
    return y_diff, f"{label} (SSD={np.sum(y_diff ** 2):.2f})"


def _add_control_trace(fig, x_values, y_values, name, **scatter_kwargs):
    """Add one scatter trace to the control-signal subplot (row 1)."""
    fig.add_trace(
        go.Scatter(
            x=x_values, y=y_values, name=name,
            legendgroup='group1', legendgrouptitle_text='Control Signals',
            **scatter_kwargs
        ),
        row=1, col=1
    )


def _build_figure(selected_features, reference_feature):
    """Build the three-row figure for the current slider window.

    Row 1 shows the selected control signals (as differences to the reference
    when one is active), row 2 the target position in cm and the target
    equilibrium on separate y-axes, row 3 the angle in degrees.
    """
    # Restrict all data to the time window chosen on the slider.
    t_min, t_max = time_slider.value

    mask_main = (df_main['time'] >= t_min) & (df_main['time'] <= t_max)
    df_main_filtered = df_main.loc[mask_main]

    mask_integrated = (time_integrated >= t_min) & (time_integrated <= t_max)
    integrated_rows = mask_integrated.to_numpy()
    filtered = (
        df_main_filtered,
        time_integrated.loc[mask_integrated],
        Q_integrated_all[:, integrated_rows],
        Q_integrated_mean[integrated_rows],
        Q_integrated_std[integrated_rows],
    )

    reference = _resolve_reference(reference_feature, filtered)
    if reference_feature is not None and reference is None:
        print(
            f"Reference feature '{reference_feature}' has no samples in the "
            "selected time range; plotting raw values."
        )

    # Create subplots with secondary y-axis for subplot 2
    fig = make_subplots(
        rows=3, cols=1, shared_xaxes=True,
        row_heights=[0.5, 0.25, 0.25],
        vertical_spacing=0.05,
        specs=[[{}],
               [{"secondary_y": True}],
               [{}]],
        subplot_titles=("Control Signals", "Target Position & Equilibrium", "Angle (degrees)")
    )

    # Main plot (control signals)
    for feature in selected_features:
        x_values, y_values = get_feature_values(feature, *filtered)
        if x_values is None:
            continue  # unknown feature name: nothing to draw

        baseline = _reference_on_grid(x_values, reference)

        if feature == 'Q_calculated_integrated_all':
            # One semi-transparent trace per run, however many were loaded.
            for index, y_trace in enumerate(y_values, start=1):
                y_plot, trace_name = _relative_series(y_trace, baseline, f'Trace {index}')
                _add_control_trace(fig, x_values, y_plot, trace_name, opacity=0.5)

        elif feature == 'Q_calculated_integrated_mean':
            y_plot, trace_name = _relative_series(y_values, baseline, 'Integrated Mean')
            _add_control_trace(fig, x_values, y_plot, trace_name, line=dict(color='black'))

        elif feature == 'Q_calculated_integrated_confidence':
            y_mean, y_upper, y_lower = y_values
            # The SSD in the label is computed from the mean only; the band
            # edges are merely shifted by the same baseline.
            y_mean, trace_name = _relative_series(y_mean, baseline, 'Integrated Mean')
            if baseline is not None:
                y_upper = y_upper - baseline
                y_lower = y_lower - baseline

            _add_control_trace(fig, x_values, y_mean, trace_name, line=dict(color='black'))
            # Closed polygon: upper edge left-to-right, lower edge back.
            _add_control_trace(
                fig,
                np.concatenate([x_values, x_values[::-1]]),
                np.concatenate([y_upper, y_lower[::-1]]),
                '±3σ Confidence Interval',
                fill='toself',
                fillcolor='rgba(128, 128, 128, 0.2)',
                line=dict(color='rgba(255,255,255,0)'),
                hoverinfo="skip",
                showlegend=True
            )

        else:
            # Plain column of the main log.
            y_plot, trace_name = _relative_series(y_values, baseline, feature)
            _add_control_trace(fig, x_values, y_plot, trace_name)

    fig.update_yaxes(title_text='Control Signal', row=1, col=1)

    # Subplot 2: 'target_position' (converted to cm) and 'target_equilibrium', separate y-axes
    # Left y-axis: 'target_position' (converted to cm)
    fig.add_trace(
        go.Scatter(
            x=df_main_filtered['time'],
            y=df_main_filtered['target_position'] * 100,  # Convert meters to centimeters
            name='Target Position (cm)',
            marker_color='blue',
            legendgroup='group2', legendgrouptitle_text='Target Position & Equilibrium'
        ),
        row=2, col=1, secondary_y=False
    )
    # Right y-axis: 'target_equilibrium'
    fig.add_trace(
        go.Scatter(
            x=df_main_filtered['time'],
            y=df_main_filtered['target_equilibrium'],
            name='Target Equilibrium',
            marker_color='red',
            legendgroup='group2', legendgrouptitle_text='Target Position & Equilibrium'
        ),
        row=2, col=1, secondary_y=True
    )

    fig.update_yaxes(title_text='Target Position (cm)', row=2, col=1, secondary_y=False)
    fig.update_yaxes(title_text='Target Equilibrium', row=2, col=1, secondary_y=True)

    # Subplot 3: 'angle' in degrees
    angle_degrees = np.degrees(df_main_filtered['angle'])
    fig.add_trace(
        go.Scatter(
            x=df_main_filtered['time'], y=angle_degrees, name='Angle (deg)', marker_color='green',
            legendgroup='group3', legendgrouptitle_text='Angle (degrees)'
        ),
        row=3, col=1
    )
    fig.update_yaxes(title_text='Angle (degrees)', row=3, col=1)

    # Update layout
    fig.update_layout(
        height=900,
        xaxis3=dict(title='Time'),
        hovermode='x unified',
        showlegend=True,
        legend_traceorder="grouped",
        legend_tracegroupgap=50,
        legend=dict(
            x=1.02,
            y=1,
            xanchor='left',
            yanchor='top',
            font=dict(size=10)
        )
    )
    return fig


def update_plot(change):
    """Redraw the figure and the feature descriptions from the widget state.

    Used as the observer callback of the checkboxes and the time slider, and
    called once directly for the initial draw.

    Args:
        change: Change notification passed by the widget; not inspected, the
            current widget values are read instead.

    If no feature is ticked, a hint is printed in the plot area and the
    description area is emptied.
    """
    selected_features = [cb.description for cb in feature_checkboxes if cb.value]
    selected_reference_features = [cb.description for cb in reference_feature_checkboxes if cb.value]
    reference_feature = selected_reference_features[0] if selected_reference_features else None

    with plot_output:
        clear_output(wait=True)
        if selected_features:
            _build_figure(selected_features, reference_feature).show()
        else:
            print("No features selected. Please select at least one feature to plot.")

    with info_output:
        # wait=True defers the clear until new output arrives, so it is only
        # used when descriptions follow; otherwise clear immediately so stale
        # descriptions do not linger after everything is unticked.
        clear_output(wait=bool(selected_features))
        for feature in selected_features:
            description = feature_descriptions.get(feature, 'No description available.')
            display(widgets.HTML(f"<b>{feature}</b>: {description}<br>"))

# 6. Redraw whenever a feature checkbox, a reference checkbox or the time
# slider changes its value.
for cb in feature_checkboxes + reference_feature_checkboxes:
    cb.observe(update_plot, names='value')
time_slider.observe(update_plot, names='value')

# 7. Show the UI top to bottom: checkboxes, time range, figure, descriptions.
display(
    checkboxes_box,
    widgets.HTML("<h2>Select Time Range:</h2>"),
    time_slider,
    plot_output,
    widgets.HTML("<h2>Feature Information:</h2>"),
    info_output,
)

# Draw once for the initial widget state.
update_plot(None)


HBox(children=(VBox(children=(HTML(value='<h3>Select Features to Plot:</h3>'), Checkbox(value=False, descripti…

HTML(value='<h2>Select Time Range:</h2>')

FloatRangeSlider(value=(0.01, 12.08), continuous_update=False, description='Time Range:', layout=Layout(width=…

Output()

HTML(value='<h2>Feature Information:</h2>')

Output()