In [31]:
import warnings

import pandas as pd
import numpy as np
import plotly.graph_objs as go
from plotly.subplots import make_subplots
import ipywidgets as widgets
from IPython.display import display, clear_output

# 1. Load 'cardinal_test_1.csv' (control-signal, target and angle columns used below).
DATA_DIR = './Control_Comparison_PRO'
N_MONTE_CARLO_RUNS = 8  # run files cardinal_test_1_1.csv ... cardinal_test_1_<N>.csv

df_main = pd.read_csv(f'{DATA_DIR}/cardinal_test_1.csv', comment='#')

# 2. Load 'Q_calculated_integrated' from every Monte Carlo run file.
integrated_files = [
    f'{DATA_DIR}/cardinal_test_1_{i}.csv' for i in range(1, N_MONTE_CARLO_RUNS + 1)
]
df_integrated_list = [pd.read_csv(f, comment='#') for f in integrated_files]

# The runs are combined sample-by-sample (stacked, then mean/std over axis 0),
# so they must share one time base. Verify that rather than silently assuming it:
# differing lengths would otherwise fail obscurely while stacking, and shifted
# timestamps would yield a plausible-looking but wrong mean.
time_integrated = df_integrated_list[0]['time']
for run_path, df_run in zip(integrated_files, df_integrated_list):
    if len(df_run) != len(time_integrated):
        raise ValueError(
            f"{run_path} has {len(df_run)} rows, expected {len(time_integrated)} "
            f"(as in {integrated_files[0]}); runs cannot be averaged sample-by-sample."
        )
    if 'time' in df_run and not np.allclose(
        df_run['time'].to_numpy(), time_integrated.to_numpy()
    ):
        warnings.warn(
            f"{run_path}: 'time' column differs from {integrated_files[0]}; "
            "the Monte Carlo mean/std assume aligned samples."
        )

# Shape (n_runs, n_samples): one row per run.
Q_integrated_all = np.stack(
    [df_run['Q_calculated_integrated'].to_numpy() for df_run in df_integrated_list]
)

# Per-sample statistics across runs.
# NOTE(review): np.std defaults to the population estimate (ddof=0); with only a
# handful of runs, ddof=1 (sample std) would give a slightly wider band. Kept
# as-is so existing plots are unchanged.
Q_integrated_mean = Q_integrated_all.mean(axis=0)
Q_integrated_std = Q_integrated_all.std(axis=0)

# 3. Descriptions shown in the "Feature Information" panel. Keys match the
#    checkbox labels. The Monte Carlo run count comes from the loaded file
#    list, so this text cannot drift from the data.
feature_descriptions = {
    'Q_calculated': 'Control signal calculated by the controller during the experiment with correct parameters (pole mass 15g, pole length 5cm).',
    'Q_calculated_large_parameters': 'Control signal from MPC assuming incorrect parameters (pole mass 150g, pole length 80cm).',
    'Q_calculated_gru_adaptive_2': 'Control signal from a GRU network trained to compute control signals based on the past trajectory with current parameters.',
    'Q_calculated_dense': 'Control signal from a dense neural network trained on the same dataset.',
    'Q_calculated_gru_memoryless': 'Control signal from a GRU network trained to imitate an MPC with random parameter realization at each point.',
    'Q_calculated_integrated_all': f'All {len(integrated_files)} traces of Q_calculated_integrated from Monte Carlo simulations.',
    'Q_calculated_integrated_mean': 'Mean of Q_calculated_integrated from Monte Carlo simulations.',
    # The band drawn for this option is mean ± 3·std taken across the runs,
    # i.e. the spread of individual runs, not an interval for the mean itself.
    'Q_calculated_integrated_confidence': 'Mean of Q_calculated_integrated with ±3σ confidence intervals from Monte Carlo simulations.',
}

# 4. Selection widgets.
# Columns of df_main that are plotted directly as control signals.
main_features = [
    'Q_calculated',
    'Q_calculated_large_parameters',
    'Q_calculated_gru_adaptive_2',
    'Q_calculated_dense',
    'Q_calculated_gru_memoryless',
]

# Views derived from the Monte Carlo 'Q_calculated_integrated' runs. These are
# not df_main columns; update_plot handles each one explicitly.
integrated_options = [
    'Q_calculated_integrated_all',
    'Q_calculated_integrated_mean',
    'Q_calculated_integrated_confidence',
]

# One initially-unchecked checkbox per option. Its description doubles as the
# feature key that update_plot reads back.
all_features = [*main_features, *integrated_options]
feature_checkboxes = [
    widgets.Checkbox(value=False, description=option_name)
    for option_name in all_features
]
checkboxes = widgets.VBox(children=feature_checkboxes)

# Separate output areas: one for the figure, one for the feature descriptions.
plot_output, info_output = widgets.Output(), widgets.Output()

# 5. Plot/description update logic.

def _add_control_signal_traces(fig, selected_features):
    """Add the selected control signals to row 1 of ``fig``, clipped to [-1, 1].

    Plain features are read from ``df_main``. The three
    ``Q_calculated_integrated_*`` options are drawn from the Monte Carlo
    arrays ``Q_integrated_all`` / ``Q_integrated_mean`` / ``Q_integrated_std``.
    """
    group = dict(legendgroup='group1', legendgrouptitle_text='Control Signals')
    # The 'mean' and 'confidence' options both include the Monte Carlo mean.
    # Draw it only once, so selecting both does not duplicate the trace and
    # its legend entry.
    mean_drawn = False

    for feature in selected_features:
        if feature in main_features:
            fig.add_trace(
                go.Scatter(
                    x=df_main['time'], y=np.clip(df_main[feature], -1, 1),
                    name=feature, **group,
                ),
                row=1, col=1,
            )
        elif feature == 'Q_calculated_integrated_all':
            # One trace per loaded run; the count follows the data rather than
            # a hard-coded number.
            for run_no, run in enumerate(Q_integrated_all, start=1):
                fig.add_trace(
                    go.Scatter(
                        x=time_integrated, y=np.clip(run, -1, 1),
                        name=f'Trace {run_no}', opacity=0.5, **group,
                    ),
                    row=1, col=1,
                )
        elif feature in ('Q_calculated_integrated_mean', 'Q_calculated_integrated_confidence'):
            if not mean_drawn:
                fig.add_trace(
                    go.Scatter(
                        x=time_integrated, y=np.clip(Q_integrated_mean, -1, 1),
                        name='Integrated Mean', line=dict(color='black'), **group,
                    ),
                    row=1, col=1,
                )
                mean_drawn = True
            if feature == 'Q_calculated_integrated_confidence':
                # Shaded band between mean - 3σ and mean + 3σ. The polygon runs
                # along the upper edge forward in time and the lower edge
                # backward; fill='toself' then closes it.
                t = np.asarray(time_integrated)
                y_upper = np.clip(Q_integrated_mean + 3 * Q_integrated_std, -1, 1)
                y_lower = np.clip(Q_integrated_mean - 3 * Q_integrated_std, -1, 1)
                fig.add_trace(
                    go.Scatter(
                        x=np.concatenate([t, t[::-1]]),
                        y=np.concatenate([y_upper, y_lower[::-1]]),
                        fill='toself',
                        fillcolor='rgba(128, 128, 128, 0.2)',
                        line=dict(color='rgba(255,255,255,0)'),
                        hoverinfo="skip",
                        showlegend=True,
                        name='±3σ Confidence Interval',
                        **group,
                    ),
                    row=1, col=1,
                )

    fig.update_yaxes(title_text='Control Signal', row=1, col=1)


def _add_state_traces(fig):
    """Add the target traces (row 2, twin y-axes) and the angle trace (row 3) from ``df_main``."""
    # Row 2, left axis: target position scaled x100 (metres -> centimetres).
    fig.add_trace(
        go.Scatter(
            x=df_main['time'],
            y=df_main['target_position'] * 100,
            name='Target Position (cm)',
            marker_color='blue',
            legendgroup='group2', legendgrouptitle_text='Target Position & Equilibrium',
        ),
        row=2, col=1, secondary_y=False,
    )
    # Row 2, right axis: target equilibrium on its own scale.
    fig.add_trace(
        go.Scatter(
            x=df_main['time'],
            y=df_main['target_equilibrium'],
            name='Target Equilibrium',
            marker_color='red',
            legendgroup='group2', legendgrouptitle_text='Target Position & Equilibrium',
        ),
        row=2, col=1, secondary_y=True,
    )
    fig.update_yaxes(title_text='Target Position (cm)', row=2, col=1, secondary_y=False)
    fig.update_yaxes(title_text='Target Equilibrium', row=2, col=1, secondary_y=True)

    # Row 3: the 'angle' column converted from radians to degrees.
    fig.add_trace(
        go.Scatter(
            x=df_main['time'], y=np.degrees(df_main['angle']), name='Angle (deg)',
            marker_color='green',
            legendgroup='group3', legendgrouptitle_text='Angle (degrees)',
        ),
        row=3, col=1,
    )
    fig.update_yaxes(title_text='Angle (degrees)', row=3, col=1)


def _build_figure(selected_features):
    """Return the 3-row figure (control signals, targets, angle) for the selection."""
    fig = make_subplots(
        rows=3, cols=1, shared_xaxes=True,
        row_heights=[0.5, 0.25, 0.25],
        vertical_spacing=0.05,
        # Only row 2 needs a secondary y-axis (target equilibrium).
        specs=[[{}], [{"secondary_y": True}], [{}]],
        subplot_titles=("Control Signals", "Target Position & Equilibrium", "Angle (degrees)"),
    )
    _add_control_signal_traces(fig, selected_features)
    _add_state_traces(fig)

    # Address the bottom x-axis by row, not by plotly's internal name
    # ('xaxis3'), which shifts if the subplot layout changes.
    fig.update_xaxes(title_text='Time', row=3, col=1)
    fig.update_layout(
        height=900,
        hovermode='x unified',
        showlegend=True,
        legend_traceorder="grouped",
        legend_tracegroupgap=50,
        legend=dict(x=1.02, y=1, xanchor='left', yanchor='top', font=dict(size=10)),
    )
    return fig


def update_plot(change):
    """Redraw the figure and the description panel for the checked features.

    Registered as the ``observe`` callback of every feature checkbox. It is
    also called once with ``change=None`` for the initial render. ``change``
    is not inspected: the selection is re-read from ``feature_checkboxes``.
    """
    selected_features = [cb.description for cb in feature_checkboxes if cb.value]

    # Build inside the Output context so any error is shown in the plot area.
    with plot_output:
        clear_output(wait=True)
        if selected_features:
            _build_figure(selected_features).show()
        else:
            print("No features selected. Please select at least one feature to plot.")

    # Always refresh the descriptions. Previously an empty selection returned
    # early and left descriptions of just-unchecked features on screen.
    with info_output:
        clear_output(wait=True)
        for feature in selected_features:
            description = feature_descriptions.get(feature, 'No description available.')
            display(widgets.HTML(f"<b>{feature}</b>: {description}<br>"))


# 6. Re-render whenever any checkbox value changes.
for checkbox in feature_checkboxes:
    checkbox.observe(update_plot, names='value')

# 7. Lay out the UI in order: selection header, checkboxes, plot area,
#    description header, description area.
display(
    widgets.HTML("<h2>Select Features to Plot:</h2>"),
    checkboxes,
    plot_output,
    widgets.HTML("<h2>Feature Information:</h2>"),
    info_output,
)

# Render the initial (nothing selected) state once.
update_plot(None)


HTML(value='<h2>Select Features to Plot:</h2>')

VBox(children=(Checkbox(value=False, description='Q_calculated'), Checkbox(value=False, description='Q_calcula…

Output()

HTML(value='<h2>Feature Information:</h2>')

Output()