# In [13]:
#!/usr/bin/env python3
"""
plot_mse_comparison_experiment_dashboard.py

This script scans the local "analysis_cache" folder for experiments (runs) that contain a
'run_data.csv' file (produced by your analysis script). It then launches an interactive
dashboard (using Dash) that lets you select an experiment and filter the data via checkboxes for:

  - Analysis Type: Trained, Trained Shuffled, and Random.
  - Belief Type: Normalized and Unnormalized.
  - Layer Selection: Individual layers (with the highest index shown as "All Layers (Concat)").

For trained models the x axis shows the checkpoint number; for random controls the data are 
averaged over seeds and displayed as a horizontal line with a shaded region representing the SEM.

Usage:
    python plot_mse_comparison_experiment_dashboard.py

Dependencies:
    - dash
    - pandas
    - plotly
"""

import os
from pathlib import Path
from datetime import datetime
import math

import dash
from dash import dcc, html, Input, Output
import plotly.graph_objects as go
import pandas as pd
import plotly.express as px

# ------------------------------
# 1. Scan for Experiments (Runs)
# ------------------------------

BASE_DIR = Path("analysis_cache")

# Collect every run_data.csv below the cache directory. A usable hit looks like
# analysis_cache/<sweep_id>/<run_id>/run_data.csv, i.e. it has at least four
# path components; anything shallower is ignored.
experiment_csvs = list(BASE_DIR.glob("**/run_data.csv"))
experiment_options = [
    {
        'label': f"Sweep {found.parts[-3]} | Run {found.parts[-2]}",
        'value': str(found),
    }
    for found in experiment_csvs
    if len(found.parts) >= 4
]
experiment_options.sort(key=lambda option: option['label'])

# The first experiment (by label) is pre-selected; None when nothing was found.
if experiment_options:
    default_experiment = experiment_options[0]['value']
else:
    default_experiment = None

# ------------------------------
# 2. Define Checkbox Options
# ------------------------------

# Each (label, value) pair becomes one checkbox entry: the label is the text
# shown to the user, the value is what the checklist reports when ticked.
analysis_type_options = [
    {'label': text, 'value': key}
    for text, key in (
        ('Trained', 'trained'),
        ('Trained Shuffled', 'trained_shuffled'),
        ('Random', 'random'),
    )
]

belief_type_options = [
    {'label': text, 'value': key}
    for text, key in (
        ('Normalized', 'normalized'),
        ('Unnormalized', 'unnormalized'),
    )
]

# To initialize the layer checklist, read the layer column of the default
# experiment CSV. Only 'layer_index' is needed to enumerate the layers, so no
# other column is loaded or required here.
layers_default = []
if default_experiment is not None:
    try:
        df_default = pd.read_csv(default_experiment, usecols=['layer_index'])
        # Missing layer values cannot be labelled with int(), so drop them;
        # tolist() yields plain Python numbers for the checklist values.
        layers_default = sorted(df_default['layer_index'].dropna().unique().tolist())
    except Exception as e:
        # Best effort: start with an empty checklist rather than failing at import.
        print("Error reading default experiment CSV:", e)
        layers_default = []

# The highest layer index is presented as the concatenation of all layers.
max_layer = max(layers_default) if layers_default else None
layer_options = [
    {
        'label': "All Layers (Concat)" if layer == max_layer else f"Layer {int(layer)}",
        'value': layer,
    }
    for layer in layers_default
]

# ------------------------------
# 3. Build the Dash App Layout
# ------------------------------

app = dash.Dash(__name__)
app.title = "MSE Comparison Dashboard"


def _checklist_panel(heading, checklist_id, options, value, panel_style):
    """Build one bordered filter panel: an H3 heading above an inline checklist."""
    checklist = dcc.Checklist(
        id=checklist_id,
        options=options,
        value=value,
        labelStyle={'display': 'inline-block', 'margin-right': '15px'},
    )
    return html.Div([html.H3(heading), checklist], style=panel_style)


# Experiment picker placed directly under the page title.
_experiment_picker = html.Div(
    [
        html.Label("Select Experiment:"),
        dcc.Dropdown(
            id="experiment-dropdown",
            options=experiment_options,
            value=default_experiment,
            clearable=False,
            style={"width": "80%", "margin": "auto"},
        ),
    ],
    style={'margin-bottom': '20px'},
)

# Three filter panels in one flex row; only the last panel has no right margin.
_filter_row = html.Div(
    [
        _checklist_panel(
            "Analysis Type",
            'analysis-type-checklist',
            analysis_type_options,
            ['trained', 'trained_shuffled', 'random'],
            {'padding': '10px', 'border': '1px solid #ccc', 'margin-right': '20px'},
        ),
        _checklist_panel(
            "Belief Type",
            'belief-type-checklist',
            belief_type_options,
            ['normalized', 'unnormalized'],
            {'padding': '10px', 'border': '1px solid #ccc', 'margin-right': '20px'},
        ),
        _checklist_panel(
            "Layer Selection",
            'layer-checklist',
            layer_options,
            layers_default,
            {'padding': '10px', 'border': '1px solid #ccc'},
        ),
    ],
    style={'display': 'flex', 'justifyContent': 'center', 'margin-bottom': '20px'},
)

# Page structure: title, experiment picker, filter row, then the graph.
app.layout = html.Div(
    style={'fontFamily': 'Arial, sans-serif', 'margin': '20px'},
    children=[
        html.H1("MSE Comparison Dashboard", style={'textAlign': 'center'}),
        _experiment_picker,
        _filter_row,
        dcc.Graph(id='mse-graph'),
    ],
)

# ------------------------------
# 4. Callbacks
# ------------------------------

@app.callback(
    Output('layer-checklist', 'options'),
    Output('layer-checklist', 'value'),
    Input('experiment-dropdown', 'value')
)
def update_layer_options(selected_experiment):
    """Rebuild the layer checklist for the experiment chosen in the dropdown.

    Args:
        selected_experiment: Value of the experiment dropdown, used as the path
            of the CSV to read. A falsy value means nothing is selected.

    Returns:
        A ``(options, value)`` pair for the layer checklist. ``options`` has one
        entry per distinct ``layer_index`` in the CSV, with the highest index
        labelled "All Layers (Concat)"; ``value`` lists every layer so that all
        boxes start checked. Both are empty lists when nothing is selected, the
        CSV cannot be read, or it contains no layer values.
    """
    if not selected_experiment:
        return [], []
    try:
        # Only the layer column is needed to enumerate layers, so nothing else
        # is loaded and no other column has to exist.
        layer_column = pd.read_csv(selected_experiment, usecols=['layer_index'])['layer_index']
        # Missing layer values cannot be labelled with int(), so drop them;
        # tolist() yields plain Python numbers for the checklist values.
        layers_exp = sorted(layer_column.dropna().unique().tolist())
        if not layers_exp:
            return [], []
        max_layer_exp = max(layers_exp)
        opts = [
            {
                'label': "All Layers (Concat)" if layer == max_layer_exp else f"Layer {int(layer)}",
                'value': layer,
            }
            for layer in layers_exp
        ]
        return opts, layers_exp
    except Exception as e:
        # Best effort: an unreadable or malformed CSV clears the checklist.
        print("Error updating layer options:", e)
        return [], []

def _row_x_value(row):
    """Return the x coordinate for one CSV row.

    Rows whose 'random_or_trained' is 'random' use their seed; every other row
    uses its checkpoint converted to float, or None when the checkpoint is
    missing or not numeric.
    """
    if row['random_or_trained'] == 'random':
        return row['seed']
    try:
        return float(row['checkpoint'])
    except (KeyError, TypeError, ValueError):
        return None


def _layer_color_map(all_layers, concat_layer):
    """Map each layer index to a plot color.

    The concat layer (the highest index) is black; the remaining layers are
    sampled from the "Oryel" colorscale in proportion to their index.
    """
    regular_layers = [layer for layer in all_layers if layer != concat_layer]
    top_regular = max(regular_layers) if regular_layers else 1
    color_map = {}
    for layer in all_layers:
        if layer == concat_layer:
            color_map[layer] = 'black'
            continue
        # sample_colorscale interprets an int as "number of samples", so the
        # position must always be passed as a float (0 would yield no colors).
        frac = float(layer) / float(top_regular) if top_regular else 0.0
        # Keep the sample position inside the colorscale's [0, 1] domain.
        frac = min(max(frac, 0.0), 1.0)
        color_map[layer] = px.colors.sample_colorscale("Oryel", frac)[0]
    return color_map


def _with_alpha(color, alpha):
    """Return *color* as an ``rgba(r,g,b,alpha)`` string.

    Accepts "#RRGGBB" and "rgb(r, g, b)" inputs; any other value (for example
    a named color) falls back to black.
    """
    text = str(color).strip()
    channels = None
    try:
        if text.startswith('#') and len(text) == 7:
            channels = [int(text[i:i + 2], 16) for i in (1, 3, 5)]
        elif text.lower().startswith('rgb(') and text.endswith(')'):
            channels = [int(round(float(part))) for part in text[4:-1].split(',')]
    except ValueError:
        channels = None
    if channels is None or len(channels) != 3:
        channels = [0, 0, 0]
    r, g, b = channels
    return f"rgba({r},{g},{b},{alpha})"


@app.callback(
    Output('mse-graph', 'figure'),
    Input('experiment-dropdown', 'value'),
    Input('analysis-type-checklist', 'value'),
    Input('belief-type-checklist', 'value'),
    Input('layer-checklist', 'value')
)
def update_graph(selected_experiment, selected_analysis_types, selected_belief_types, selected_layers):
    """Build the MSE comparison figure for the current selections.

    Rows of the selected CSV are filtered by analysis type
    ('random_or_trained'), belief type ('norm_type') and layer ('layer_index'),
    then grouped by those three columns. Non-random groups are drawn as one
    line with markers over the checkpoint; random groups are collapsed into a
    horizontal line at the mean MSE with a shaded mean +/- SEM band spanning
    the x range.

    Args:
        selected_experiment: Value of the experiment dropdown, used as the path
            of the CSV to read.
        selected_analysis_types: Checked analysis-type values.
        selected_belief_types: Checked belief-type values.
        selected_layers: Checked layer values.

    Returns:
        A ``plotly.graph_objects.Figure``. It is empty when any selection is
        empty, the CSV cannot be read, or the CSV has no rows.
    """
    if not selected_experiment or not selected_analysis_types or not selected_belief_types or not selected_layers:
        return go.Figure()
    try:
        df_exp = pd.read_csv(selected_experiment)
    except Exception as e:
        print("Error reading CSV:", e)
        return go.Figure()
    if df_exp.empty:
        # A row-wise apply on an empty frame does not produce a column to assign.
        return go.Figure()

    # Compute x_value: for random controls, use seed; else, use checkpoint (as float).
    df_exp['x_value'] = df_exp.apply(_row_x_value, axis=1)

    # Filter based on selections.
    filtered = df_exp[
        df_exp['random_or_trained'].isin(selected_analysis_types) &
        df_exp['norm_type'].isin(selected_belief_types) &
        df_exp['layer_index'].isin(selected_layers)
    ]
    filtered = filtered.sort_values('x_value')

    # Determine x-axis range from non-random data; fall back to whatever is
    # left (the seeds) when only random controls are selected.
    non_random = filtered[filtered['random_or_trained'] != 'random']
    range_source = non_random if not non_random.empty else filtered
    x_range = (range_source['x_value'].min(), range_source['x_value'].max())

    # For color mapping, use all layers from the full dataset (df_exp), so that
    # the mapping stays constant regardless of which layers are checked.
    all_layers = sorted(df_exp['layer_index'].dropna().unique())
    max_layer_val = max(all_layers) if all_layers else None
    color_map = _layer_color_map(all_layers, max_layer_val)

    # Define line dash styles for each analysis type.
    line_styles = {
        'trained': 'solid',
        'trained_shuffled': 'dash',
        'random': 'dot'
    }

    fig = go.Figure()
    # Group by analysis type, belief type, and layer_index.
    grouped = filtered.groupby(['random_or_trained', 'norm_type', 'layer_index'])
    for (analysis_type, norm_type, layer_index), group in grouped:
        layer_label = f"Layer {int(layer_index)}" if layer_index != max_layer_val else "All Layers (Concat)"
        trace_label = f"{analysis_type.title()} | {norm_type.title()} | {layer_label}"
        trace_color = color_map.get(layer_index, "black")

        if analysis_type == 'random':
            # Aggregate over seeds.
            avg_MSE = group['MSE'].mean()
            # Count only the values that actually enter the mean/std.
            n = int(group['MSE'].count())
            # The sample standard deviation is undefined for fewer than two
            # values; show a zero-width band instead of propagating NaN.
            sem_MSE = group['MSE'].std() / math.sqrt(n) if n > 1 else 0.0
            # Plot a shaded region for (avg ± SEM) across the x_range.
            x_shaded = [x_range[0], x_range[1], x_range[1], x_range[0]]
            y_shaded = [avg_MSE - sem_MSE, avg_MSE - sem_MSE, avg_MSE + sem_MSE, avg_MSE + sem_MSE]
            fig.add_trace(go.Scatter(
                x=x_shaded,
                y=y_shaded,
                fill="toself",
                # Same hue as the mean line, mostly transparent.
                fillcolor=_with_alpha(trace_color, 0.2),
                line=dict(color="rgba(255,255,255,0)"),
                showlegend=False,
                hoverinfo='skip'
            ))
            # Plot the horizontal mean line.
            fig.add_trace(go.Scatter(
                x=[x_range[0], x_range[1]],
                y=[avg_MSE, avg_MSE],
                mode='lines',
                name=trace_label,
                line=dict(color=trace_color, dash=line_styles[analysis_type], width=3),
                hovertemplate=(
                    f"Analysis Type: {analysis_type}<br>" +
                    f"Belief Type: {norm_type}<br>" +
                    f"{layer_label}<br>" +
                    f"Avg MSE: {avg_MSE:.5f}<br>" +
                    f"SEM: {sem_MSE:.5f}<extra></extra>"
                )
            ))
        else:
            group = group.sort_values('x_value')
            fig.add_trace(go.Scatter(
                x=group['x_value'],
                y=group['MSE'],
                mode='lines+markers',
                name=trace_label,
                line=dict(color=trace_color, dash=line_styles[analysis_type]),
                marker=dict(symbol='circle', size=5),  # Reduced marker size.
                hovertemplate=(
                    f"Analysis Type: {analysis_type}<br>" +
                    f"Belief Type: {norm_type}<br>" +
                    f"{layer_label}<br>" +
                    "X: %{x}<br>" +
                    "MSE: %{y:.5f}<extra></extra>"
                )
            ))

    fig.update_layout(
        title="MSE vs. Checkpoint/Seed (Unified Comparison)",
        xaxis_title="Checkpoint Number (for Trained) or Seed (Random controls averaged)",
        yaxis_title="MSE (log scale)",
        yaxis_type="log",
        legend_title="Trace Groups",
        margin={'l': 50, 'r': 50, 't': 50, 'b': 50}
    )
    return fig

# ------------------------------
# 5. Run the App
# ------------------------------

if __name__ == '__main__':
    # `run_server` is the legacy entry point and is not available in newer Dash
    # releases; use `run` when the installed Dash provides it and fall back to
    # `run_server` only on older versions that lack `run`.
    start_server = app.run if hasattr(app, 'run') else app.run_server
    start_server(debug=True)
