In [1]:
#!/usr/bin/env python3
"""
plot_mse_comparison_experiment_dashboard.py

This script scans the local "analysis_cache" folder for experiments (runs) that contain a
'run_data.csv' file (produced by your analysis script). It then launches an interactive
dashboard (using Dash) that lets you select an experiment and filter the data via checkboxes for:

  - Analysis Type: Trained, Trained Shuffled, and Random.
  - Belief Type: Normalized and Unnormalized.
  - Layer Selection: Individual layers (with the highest index shown as "All Layers (Concat)").

For trained models the x axis shows the checkpoint number; for random controls the data are 
averaged over seeds and displayed as a horizontal line with a shaded region representing the SEM.

Usage:
    python plot_mse_comparison_experiment_dashboard.py

Dependencies:
    - dash
    - pandas
    - plotly
"""

import os
from pathlib import Path
from datetime import datetime
import math

import dash
from dash import dcc, html, Input, Output
import plotly.graph_objects as go
import pandas as pd
import plotly.express as px

# ------------------------------
# 1. Scan for Experiments (Runs)
# ------------------------------

BASE_DIR = Path("analysis_cache")
experiment_csvs = list(BASE_DIR.glob("**/run_data.csv"))

# Expected structure: analysis_cache/<sweep_id>/<run_id>/run_data.csv.
# Files nested less deeply cannot be attributed to a sweep/run and are skipped.
# Each dropdown entry is labelled by sweep/run and carries the CSV path as its value;
# entries are ordered alphabetically by label.
experiment_options = sorted(
    (
        {
            'label': f"Sweep {csv_path.parts[-3]} | Run {csv_path.parts[-2]}",
            'value': str(csv_path),
        }
        for csv_path in experiment_csvs
        if len(csv_path.parts) >= 4
    ),
    key=lambda option: option['label'],
)

# The first experiment (alphabetically) is pre-selected, if any exist.
default_experiment = experiment_options[0]['value'] if experiment_options else None

# ------------------------------
# 2. Define Checkbox Options
# ------------------------------

# Choices for the "Analysis Type" checklist; values are matched against the
# ``random_or_trained`` column of run_data.csv.
analysis_type_options = [
    {'label': label, 'value': value}
    for label, value in (
        ('Trained', 'trained'),
        ('Trained Shuffled', 'trained_shuffled'),
        ('Random', 'random'),
    )
]

# Choices for the "Belief Type" checklist; values are matched against the
# ``norm_type`` column of run_data.csv.
belief_type_options = [
    {'label': label, 'value': value}
    for label, value in (
        ('Normalized', 'normalized'),
        ('Unnormalized', 'unnormalized'),
    )
]

# To initialize the layer checklist, load the default experiment CSV and collect
# its distinct layer indices. Only ``layer_index`` is needed here; the callbacks
# compute x-axis values themselves from the CSV they read.
layers_default = []
if default_experiment is not None:
    try:
        df_default = pd.read_csv(default_experiment)
        layers_default = sorted(df_default['layer_index'].unique())
    except Exception as e:  # start-up boundary: a bad CSV leaves the checklist empty
        print("Error reading default experiment CSV:", e)
        layers_default = []

# The highest layer index is displayed as "All Layers (Concat)"; every other
# index as "Layer N".
max_layer = max(layers_default) if layers_default else None
layer_options = [
    {'label': f"Layer {int(l)}" if l != max_layer else "All Layers (Concat)", 'value': l}
    for l in layers_default
]

# ------------------------------
# 3. Build the Dash App Layout
# ------------------------------

def _checklist_panel(title, checklist_id, options, value, spaced=True):
    """Return a bordered panel with a heading above an inline ``dcc.Checklist``.

    ``spaced`` adds the right margin separating this panel from the next one
    in its row; the last panel of a row passes ``False``.
    """
    panel_style = {'padding': '10px', 'border': '1px solid #ccc'}
    if spaced:
        panel_style['margin-right'] = '20px'
    return html.Div(
        [
            html.H3(title),
            dcc.Checklist(
                id=checklist_id,
                options=options,
                value=value,
                labelStyle={'display': 'inline-block', 'margin-right': '15px'},
            ),
        ],
        style=panel_style,
    )


app = dash.Dash(__name__)
app.title = "MSE Comparison Dashboard"

# Experiment selector (one entry per discovered run_data.csv).
_experiment_picker = html.Div(
    [
        html.Label("Select Experiment:"),
        dcc.Dropdown(
            id="experiment-dropdown",
            options=experiment_options,
            value=default_experiment,
            clearable=False,
            style={"width": "80%", "margin": "auto"},
        ),
    ],
    style={'margin-bottom': '20px'},
)

# Row of filter checklists; everything is selected initially.
_filter_row = html.Div(
    [
        _checklist_panel(
            "Analysis Type", 'analysis-type-checklist',
            analysis_type_options, ['trained', 'trained_shuffled', 'random'],
        ),
        _checklist_panel(
            "Belief Type", 'belief-type-checklist',
            belief_type_options, ['normalized', 'unnormalized'],
        ),
        _checklist_panel(
            "Layer Selection", 'layer-checklist',
            layer_options, layers_default, spaced=False,
        ),
    ],
    style={'display': 'flex', 'justifyContent': 'center', 'margin-bottom': '20px'},
)

app.layout = html.Div(
    style={'fontFamily': 'Arial, sans-serif', 'margin': '20px'},
    children=[
        html.H1("MSE Comparison Dashboard", style={'textAlign': 'center'}),
        _experiment_picker,
        _filter_row,
        dcc.Graph(id='mse-graph'),
    ],
)

# ------------------------------
# 4. Callbacks
# ------------------------------

@app.callback(
    Output('layer-checklist', 'options'),
    Output('layer-checklist', 'value'),
    Input('experiment-dropdown', 'value')
)
def update_layer_options(selected_experiment):
    """Rebuild the layer checklist whenever a different experiment is chosen.

    Reads the selected ``run_data.csv`` and lists its distinct ``layer_index``
    values in ascending order, labelling the highest one
    "All Layers (Concat)". Every layer is pre-selected.

    Returns:
        ``(options, value)`` for the ``layer-checklist`` component. Both are
        empty lists when nothing is selected, the file has no layers, or it
        cannot be read/labelled.
    """
    if not selected_experiment:
        return [], []
    try:
        df_exp = pd.read_csv(selected_experiment)
        layers_exp = sorted(df_exp['layer_index'].unique())
        if not layers_exp:
            return [], []
        max_layer_exp = max(layers_exp)
        opts = [
            {'label': f"Layer {int(l)}" if l != max_layer_exp else "All Layers (Concat)", 'value': l}
            for l in layers_exp
        ]
        return opts, layers_exp
    except Exception as e:  # callback boundary: report and show an empty checklist
        print("Error updating layer options:", e)
        return [], []

def _compute_x_value(row):
    """Return the x-axis value for one ``run_data.csv`` row.

    Rows whose ``random_or_trained`` is ``'random'`` are keyed by ``seed``;
    all other rows by ``checkpoint`` converted to float, or ``None`` when the
    checkpoint is missing or not numeric.
    """
    if row['random_or_trained'] == 'random':
        return row['seed']
    try:
        return float(row['checkpoint'])
    except (KeyError, TypeError, ValueError):
        return None


def _to_rgba(color, alpha):
    """Return *color* as an ``rgba(r,g,b,alpha)`` string.

    Understands ``#RRGGBB`` hex strings and ``rgb(r, g, b)`` strings (the form
    produced by ``plotly.colors.sample_colorscale``; components may be floats).
    Anything else, e.g. a named colour such as ``'black'``, maps to black.
    """
    try:
        if color.startswith('#') and len(color) == 7:
            r, g, b = (int(color[i:i + 2], 16) for i in (1, 3, 5))
        elif color.startswith('rgb(') and color.endswith(')'):
            r, g, b = (int(round(float(c))) for c in color[4:-1].split(','))
        else:
            raise ValueError(color)
    except (AttributeError, ValueError):
        r, g, b = 0, 0, 0
    return f"rgba({r},{g},{b},{alpha})"


def _layer_color_map(all_layers):
    """Map every layer index in *all_layers* (non-empty) to a trace colour.

    The highest index (shown as "All Layers (Concat)") is black; the other
    layers are spread evenly over plotly's ``Oryel`` colorscale, a lone
    remaining layer taking the middle of the scale.
    """
    concat_layer = max(all_layers)
    per_layer = [layer for layer in all_layers if layer != concat_layer]
    color_map = {concat_layer: 'black'}
    for position, layer in enumerate(per_layer):
        frac = 0.5 if len(per_layer) == 1 else position / (len(per_layer) - 1)
        color_map[layer] = px.colors.sample_colorscale("Oryel", frac)[0]
    return color_map


def _layer_label(layer_index, concat_layer):
    """Return the display label for *layer_index* ("Layer N" or the concat label)."""
    return f"Layer {int(layer_index)}" if layer_index != concat_layer else "All Layers (Concat)"


@app.callback(
    Output('mse-graph', 'figure'),
    Input('experiment-dropdown', 'value'),
    Input('analysis-type-checklist', 'value'),
    Input('belief-type-checklist', 'value'),
    Input('layer-checklist', 'value')
)
def update_graph(selected_experiment, selected_analysis_types, selected_belief_types, selected_layers):
    """Build the MSE comparison figure for the current selections.

    Trained and trained-shuffled rows are drawn as lines over checkpoints.
    Random-control rows are pooled per (belief type, layer) into a horizontal
    mean line with a shaded +/-SEM band spanning the plotted checkpoint range.
    The y axis is logarithmic.

    Args:
        selected_experiment: Path of the ``run_data.csv`` to plot.
        selected_analysis_types: ``random_or_trained`` values to include.
        selected_belief_types: ``norm_type`` values to include.
        selected_layers: ``layer_index`` values to include.

    Returns:
        A ``go.Figure``; an empty one when any selection is empty or the CSV
        cannot be read or has no rows.
    """
    if not (selected_experiment and selected_analysis_types
            and selected_belief_types and selected_layers):
        return go.Figure()
    try:
        df_exp = pd.read_csv(selected_experiment)
    except Exception as e:  # callback boundary: report and show an empty plot
        print("Error reading CSV:", e)
        return go.Figure()
    if df_exp.empty:
        # Nothing to plot, and the row-wise apply below is ill-defined on an empty frame.
        return go.Figure()

    df_exp['x_value'] = df_exp.apply(_compute_x_value, axis=1)

    filtered = df_exp[
        df_exp['random_or_trained'].isin(selected_analysis_types)
        & df_exp['norm_type'].isin(selected_belief_types)
        & df_exp['layer_index'].isin(selected_layers)
    ].sort_values('x_value')

    # Random-control bands span the checkpoint range when checkpoint data is
    # shown; with only random controls selected they span the seed range.
    non_random = filtered[filtered['random_or_trained'] != 'random']
    span = filtered if non_random.empty else non_random
    x_lo, x_hi = span['x_value'].min(), span['x_value'].max()

    # Colours come from all layers in the file (not just the filtered ones)
    # so a layer keeps its colour while filters are toggled.
    all_layers = sorted(df_exp['layer_index'].unique())
    if not all_layers:
        return go.Figure()
    concat_layer = max(all_layers)
    color_map = _layer_color_map(all_layers)

    line_styles = {'trained': 'solid', 'trained_shuffled': 'dash', 'random': 'dot'}

    fig = go.Figure()
    grouped = filtered.groupby(['random_or_trained', 'norm_type', 'layer_index'])
    for (analysis_type, norm_type, layer_index), group in grouped:
        layer_label = _layer_label(layer_index, concat_layer)
        trace_label = f"{analysis_type.title()} | {norm_type.title()} | {layer_label}"
        trace_color = color_map.get(layer_index, "black")

        if analysis_type == 'random':
            mse = group['MSE']
            avg_MSE = mse.mean()
            n_seeds = mse.count()
            # std() of a single value is NaN; treat one seed as zero spread
            # instead of letting NaN leak into the band and the hover text.
            sem_MSE = mse.std() / math.sqrt(n_seeds) if n_seeds > 1 else 0.0
            # Shaded (avg +/- SEM) band across the x range.
            fig.add_trace(go.Scatter(
                x=[x_lo, x_hi, x_hi, x_lo],
                y=[avg_MSE - sem_MSE, avg_MSE - sem_MSE, avg_MSE + sem_MSE, avg_MSE + sem_MSE],
                fill="toself",
                fillcolor=_to_rgba(trace_color, 0.2),
                line=dict(color="rgba(255,255,255,0)"),
                showlegend=False,
                hoverinfo='skip'
            ))
            # Horizontal mean line.
            fig.add_trace(go.Scatter(
                x=[x_lo, x_hi],
                y=[avg_MSE, avg_MSE],
                mode='lines',
                name=trace_label,
                line=dict(color=trace_color, dash=line_styles[analysis_type], width=3),
                hovertemplate=(
                    f"Analysis Type: {analysis_type}<br>" +
                    f"Belief Type: {norm_type}<br>" +
                    f"{layer_label}<br>" +
                    f"Avg MSE: {avg_MSE:.5f}<br>" +
                    f"SEM: {sem_MSE:.5f}<extra></extra>"
                )
            ))
        else:
            group = group.sort_values('x_value')
            fig.add_trace(go.Scatter(
                x=group['x_value'],
                y=group['MSE'],
                mode='lines+markers',
                name=trace_label,
                line=dict(color=trace_color, dash=line_styles[analysis_type]),
                marker=dict(symbol='circle', size=5),
                hovertemplate=(
                    f"Analysis Type: {analysis_type}<br>" +
                    f"Belief Type: {norm_type}<br>" +
                    f"{layer_label}<br>" +
                    "X: %{x}<br>" +
                    "MSE: %{y:.5f}<extra></extra>"
                )
            ))

    fig.update_layout(
        title="MSE vs. Checkpoint/Seed (Unified Comparison)",
        xaxis_title="Checkpoint Number (for Trained) or Seed (Random controls averaged)",
        yaxis_title="MSE (log scale)",
        yaxis_type="log",
        legend_title="Trace Groups",
        margin={'l': 50, 'r': 50, 't': 50, 'b': 50}
    )
    return fig

# ------------------------------
# 5. Run the App
# ------------------------------

if __name__ == '__main__':
    # Dash 3 removed ``run_server`` in favour of ``run``; only fall back to
    # ``run_server`` on old Dash releases that do not provide ``run``.
    if hasattr(app, 'run'):
        app.run(debug=True)
    else:
        app.run_server(debug=True)


In [1]:
#!/usr/bin/env python3
"""
plot_mse_comparison_normalized_dashboard.py

This script scans the local "analysis_cache" folder for experiments (runs)
that contain a 'run_data.csv' file. It then launches an interactive dashboard
(using Dash) that lets you select an experiment and filter the data via checkboxes for:

  - Analysis Type: Trained, Trained Shuffled, and Random.
  - Belief Type: Normalized and Unnormalized.
  - Layer Selection: Individual layers (with the highest index shown as "All Layers (Concat)").

Additionally, two dropdowns allow you to choose:
  - Which control baseline to normalize against ("Trained Shuffled" or "Random").
  - Which normalization method to apply: None (raw MSE), Ratio (plotted as
    1 - trained/baseline), Percent Difference, or Z-Score.

For trained models the x axis shows the checkpoint number; for random controls the
data are averaged over seeds (or used in aggregate) and displayed accordingly.

Usage:
    python plot_mse_comparison_normalized_dashboard.py

Dependencies:
    - dash
    - pandas
    - plotly
"""

import os
from pathlib import Path
import math
from datetime import datetime

import dash
from dash import dcc, html, Input, Output
import plotly.graph_objects as go
import pandas as pd
import plotly.express as px

# ------------------------------
# 1. Scan for Experiments (Runs)
# ------------------------------

BASE_DIR = Path("analysis_cache")
experiment_csvs = list(BASE_DIR.glob("**/run_data.csv"))

# Expected structure: analysis_cache/<sweep_id>/<run_id>/run_data.csv.
# Files nested less deeply cannot be attributed to a sweep/run and are skipped.
# Each dropdown entry is labelled by sweep/run and carries the CSV path as its value;
# entries are ordered alphabetically by label.
experiment_options = sorted(
    (
        {
            'label': f"Sweep {csv_path.parts[-3]} | Run {csv_path.parts[-2]}",
            'value': str(csv_path),
        }
        for csv_path in experiment_csvs
        if len(csv_path.parts) >= 4
    ),
    key=lambda option: option['label'],
)

# The first experiment (alphabetically) is pre-selected, if any exist.
default_experiment = experiment_options[0]['value'] if experiment_options else None

# ------------------------------
# 2. Define Filter Options
# ------------------------------

# Choices for the "Analysis Type" checklist; values are matched against the
# ``random_or_trained`` column of run_data.csv.
analysis_type_options = [
    {'label': label, 'value': value}
    for label, value in (
        ('Trained', 'trained'),
        ('Trained Shuffled', 'trained_shuffled'),
        ('Random', 'random'),
    )
]

# Choices for the "Belief Type" checklist; values are matched against the
# ``norm_type`` column of run_data.csv.
belief_type_options = [
    {'label': label, 'value': value}
    for label, value in (
        ('Normalized', 'normalized'),
        ('Unnormalized', 'unnormalized'),
    )
]

# Load the default experiment CSV to initialize the layer checklist. Only
# ``layer_index`` is needed here; the callbacks compute x-axis values
# themselves from the CSV they read.
layers_default = []
if default_experiment is not None:
    try:
        df_default = pd.read_csv(default_experiment)
        layers_default = sorted(df_default['layer_index'].unique())
    except Exception as e:  # start-up boundary: a bad CSV leaves the checklist empty
        print("Error reading default experiment CSV:", e)
        layers_default = []

# The highest layer index is displayed as "All Layers (Concat)"; every other
# index as "Layer N".
max_layer = max(layers_default) if layers_default else None
layer_options = [
    {'label': f"Layer {int(l)}" if l != max_layer else "All Layers (Concat)", 'value': l}
    for l in layers_default
]

# ------------------------------
# 3. Build the Dash App Layout
# ------------------------------

def _checklist_panel(title, checklist_id, options, value, spaced=True):
    """Return a bordered panel with a heading above an inline ``dcc.Checklist``.

    ``spaced`` adds the right margin separating this panel from the next one
    in its row; the last panel of a row passes ``False``.
    """
    panel_style = {'padding': '10px', 'border': '1px solid #ccc'}
    if spaced:
        panel_style['margin-right'] = '20px'
    return html.Div(
        [
            html.H3(title),
            dcc.Checklist(
                id=checklist_id,
                options=options,
                value=value,
                labelStyle={'display': 'inline-block', 'margin-right': '15px'},
            ),
        ],
        style=panel_style,
    )


def _dropdown_panel(title, dropdown_id, options, value, spaced=True):
    """Return a bordered panel with a heading above a non-clearable 200px ``dcc.Dropdown``.

    ``spaced`` has the same meaning as for the checklist panels.
    """
    panel_style = {'padding': '10px', 'border': '1px solid #ccc'}
    if spaced:
        panel_style['margin-right'] = '20px'
    return html.Div(
        [
            html.H3(title),
            dcc.Dropdown(
                id=dropdown_id,
                options=options,
                value=value,
                clearable=False,
                style={"width": "200px"},
            ),
        ],
        style=panel_style,
    )


app = dash.Dash(__name__)
app.title = "MSE Normalization Dashboard"

# Experiment selector (one entry per discovered run_data.csv).
_experiment_picker = html.Div(
    [
        html.Label("Select Experiment:"),
        dcc.Dropdown(
            id="experiment-dropdown",
            options=experiment_options,
            value=default_experiment,
            clearable=False,
            style={"width": "80%", "margin": "auto"},
        ),
    ],
    style={'margin-bottom': '20px'},
)

# Row of filter checklists; everything is selected initially.
_filter_row = html.Div(
    [
        _checklist_panel(
            "Analysis Type", 'analysis-type-checklist',
            analysis_type_options, ['trained', 'trained_shuffled', 'random'],
        ),
        _checklist_panel(
            "Belief Type", 'belief-type-checklist',
            belief_type_options, ['normalized', 'unnormalized'],
        ),
        _checklist_panel(
            "Layer Selection", 'layer-checklist',
            layer_options, layers_default, spaced=False,
        ),
    ],
    style={'display': 'flex', 'justifyContent': 'center', 'margin-bottom': '20px'},
)

# --- Controls for normalization: baseline choice and method ---
_normalization_row = html.Div(
    [
        _dropdown_panel(
            "Shuffle Baseline",
            "baseline-type-dropdown",
            [
                {'label': 'Trained Shuffled', 'value': 'trained_shuffled'},
                {'label': 'Random', 'value': 'random'},
            ],
            'trained_shuffled',
        ),
        _dropdown_panel(
            "Normalization Method",
            "normalization-method-dropdown",
            [
                {'label': 'None (Raw MSE)', 'value': 'none'},
                {'label': 'Ratio', 'value': 'ratio'},
                {'label': 'Percent Difference', 'value': 'percent'},
                {'label': 'Z-Score', 'value': 'zscore'},
            ],
            'none',
            spaced=False,
        ),
    ],
    style={'display': 'flex', 'justifyContent': 'center', 'margin-bottom': '20px'},
)

app.layout = html.Div(
    style={'fontFamily': 'Arial, sans-serif', 'margin': '20px'},
    children=[
        html.H1("MSE Normalization Dashboard", style={'textAlign': 'center'}),
        _experiment_picker,
        _filter_row,
        _normalization_row,
        dcc.Graph(id='mse-graph'),
    ],
)

# ------------------------------
# 4. Callbacks
# ------------------------------

# Update the available layer options when the experiment is changed.
@app.callback(
    Output('layer-checklist', 'options'),
    Output('layer-checklist', 'value'),
    Input('experiment-dropdown', 'value')
)
def update_layer_options(selected_experiment):
    """Rebuild the layer checklist whenever a different experiment is chosen.

    Reads the selected ``run_data.csv`` and lists its distinct ``layer_index``
    values in ascending order, labelling the highest one
    "All Layers (Concat)". Every layer is pre-selected.

    Returns:
        ``(options, value)`` for the ``layer-checklist`` component. Both are
        empty lists when nothing is selected, the file has no layers, or it
        cannot be read/labelled.
    """
    if not selected_experiment:
        return [], []
    try:
        df_exp = pd.read_csv(selected_experiment)
        layers_exp = sorted(df_exp['layer_index'].unique())
        if not layers_exp:
            return [], []
        max_layer_exp = max(layers_exp)
        opts = [
            {'label': f"Layer {int(l)}" if l != max_layer_exp else "All Layers (Concat)", 'value': l}
            for l in layers_exp
        ]
        return opts, layers_exp
    except Exception as e:  # callback boundary: report and show an empty checklist
        print("Error updating layer options:", e)
        return [], []

def _compute_x_value(row):
    """Return the x-axis value for one ``run_data.csv`` row.

    Rows whose ``random_or_trained`` is ``'random'`` are keyed by ``seed``;
    all other rows by ``checkpoint`` converted to float, or ``None`` when the
    checkpoint is missing or not numeric.
    """
    if row['random_or_trained'] == 'random':
        return row['seed']
    try:
        return float(row['checkpoint'])
    except (KeyError, TypeError, ValueError):
        return None


def _to_rgba(color, alpha):
    """Return *color* as an ``rgba(r,g,b,alpha)`` string.

    Understands ``#RRGGBB`` hex strings and ``rgb(r, g, b)`` strings (the form
    produced by ``plotly.colors.sample_colorscale``; components may be floats).
    Anything else, e.g. a named colour such as ``'black'``, maps to black.
    """
    try:
        if color.startswith('#') and len(color) == 7:
            r, g, b = (int(color[i:i + 2], 16) for i in (1, 3, 5))
        elif color.startswith('rgb(') and color.endswith(')'):
            r, g, b = (int(round(float(c))) for c in color[4:-1].split(','))
        else:
            raise ValueError(color)
    except (AttributeError, ValueError):
        r, g, b = 0, 0, 0
    return f"rgba({r},{g},{b},{alpha})"


def _layer_color_map(all_layers):
    """Map every layer index in *all_layers* (non-empty) to a trace colour.

    The highest index (shown as "All Layers (Concat)") is black; the other
    layers are spread evenly over plotly's ``Oryel`` colorscale, a lone
    remaining layer taking the middle of the scale.
    """
    concat_layer = max(all_layers)
    per_layer = [layer for layer in all_layers if layer != concat_layer]
    color_map = {concat_layer: 'black'}
    for position, layer in enumerate(per_layer):
        frac = 0.5 if len(per_layer) == 1 else position / (len(per_layer) - 1)
        color_map[layer] = px.colors.sample_colorscale("Oryel", frac)[0]
    return color_map


def _layer_label(layer_index, concat_layer):
    """Return the display label for *layer_index* ("Layer N" or the concat label)."""
    return f"Layer {int(layer_index)}" if layer_index != concat_layer else "All Layers (Concat)"


# Line width / opacity used to tell the two belief types apart.
_BELIEF_TYPE_STYLES = {
    'normalized': {'opacity': 1.0, 'width': 3},
    'unnormalized': {'opacity': 0.6, 'width': 2},
}


def _raw_mse_figure(df_exp, filtered, selected_analysis_types):
    """Plot raw MSE for the selected analysis types on a log y axis.

    Trained / trained-shuffled rows become lines over checkpoints; random
    controls are pooled per (belief type, layer) into a horizontal mean line
    with a shaded +/-SEM band. Colours are derived from all layers of
    *df_exp* so they stay stable while filters change.
    """
    filtered = filtered[filtered['random_or_trained'].isin(selected_analysis_types or [])]
    filtered = filtered.sort_values('x_value')

    # Random-control bands span the checkpoint range when checkpoint data is
    # shown; with only random controls selected they span the seed range.
    non_random = filtered[filtered['random_or_trained'] != 'random']
    span = filtered if non_random.empty else non_random
    x_lo, x_hi = span['x_value'].min(), span['x_value'].max()

    all_layers = sorted(df_exp['layer_index'].unique())
    if not all_layers:
        return go.Figure()
    concat_layer = max(all_layers)
    color_map = _layer_color_map(all_layers)

    line_styles = {'trained': 'solid', 'trained_shuffled': 'dash', 'random': 'dot'}

    fig = go.Figure()
    grouped = filtered.groupby(['random_or_trained', 'norm_type', 'layer_index'])
    for (analysis_type, norm_type, layer_index), group in grouped:
        layer_label = _layer_label(layer_index, concat_layer)
        trace_label = f"{analysis_type.title()} | {norm_type.title()} | {layer_label}"
        trace_color = color_map.get(layer_index, "black")
        style = _BELIEF_TYPE_STYLES[norm_type]
        line = dict(color=trace_color, dash=line_styles[analysis_type], width=style['width'])

        if analysis_type == 'random':
            mse = group['MSE']
            avg_MSE = mse.mean()
            n_seeds = mse.count()
            # std() of a single value is NaN; treat one seed as zero spread
            # instead of letting NaN leak into the band and the hover text.
            sem_MSE = mse.std() / math.sqrt(n_seeds) if n_seeds > 1 else 0.0
            fig.add_trace(go.Scatter(
                x=[x_lo, x_hi, x_hi, x_lo],
                y=[avg_MSE - sem_MSE, avg_MSE - sem_MSE, avg_MSE + sem_MSE, avg_MSE + sem_MSE],
                fill="toself",
                fillcolor=_to_rgba(trace_color, 0.2),
                line=dict(color="rgba(255,255,255,0)"),
                showlegend=False,
                hoverinfo='skip'
            ))
            fig.add_trace(go.Scatter(
                x=[x_lo, x_hi],
                y=[avg_MSE, avg_MSE],
                mode='lines',
                name=trace_label,
                line=line,
                opacity=style['opacity'],
                hovertemplate=(
                    f"Analysis Type: {analysis_type}<br>" +
                    f"Belief Type: {norm_type}<br>" +
                    f"{layer_label}<br>" +
                    f"Avg MSE: {avg_MSE:.5f}<br>" +
                    f"SEM: {sem_MSE:.5f}<extra></extra>"
                )
            ))
        else:
            group = group.sort_values('x_value')
            fig.add_trace(go.Scatter(
                x=group['x_value'],
                y=group['MSE'],
                mode='lines+markers',
                name=trace_label,
                line=line,
                opacity=style['opacity'],
                marker=dict(symbol='circle', size=5),
                hovertemplate=(
                    f"Analysis Type: {analysis_type}<br>" +
                    f"Belief Type: {norm_type}<br>" +
                    f"{layer_label}<br>" +
                    "X: %{x}<br>" +
                    "MSE: %{y:.5f}<extra></extra>"
                )
            ))
    fig.update_layout(
        title="MSE vs. Checkpoint/Seed (Raw Values)",
        xaxis_title="Checkpoint (for Trained) or Seed (for Random)",
        yaxis_title="MSE (log scale)",
        yaxis_type="log",
        legend_title="Trace Groups",
        margin={'l': 50, 'r': 50, 't': 50, 'b': 50}
    )
    return fig


def _normalized_mse_figure(df_exp, filtered, baseline_type, normalization_method):
    """Plot trained-model MSE normalized against a control baseline.

    The baseline rows (``baseline_type``) are summarised by mean and std:
    per checkpoint for ``'trained_shuffled'``, pooled over all seeds for
    ``'random'``. Each trained row is then normalized by its matching
    baseline using *normalization_method* (``'ratio'`` -> 1 - MSE/baseline,
    ``'percent'``, ``'zscore'``). Results with a zero or missing denominator
    become NaN and show as gaps.

    Returns an empty figure when there is no trained or baseline data, the
    baseline type is unknown, or no trained row finds a matching baseline.
    """
    trained_data = filtered[filtered['random_or_trained'] == 'trained']
    if trained_data.empty:
        return go.Figure()
    baseline_data = filtered[filtered['random_or_trained'] == baseline_type]
    if baseline_data.empty:
        return go.Figure()

    if baseline_type == 'trained_shuffled':
        # Match checkpoint by checkpoint. Summarising first keeps several
        # shuffled rows at one checkpoint from duplicating the trained row in
        # the merge, and gives the z-score a spread to work with.
        keys, how = ['x_value', 'layer_index', 'norm_type'], 'inner'
    elif baseline_type == 'random':
        # Random controls are pooled over all seeds per layer / belief type.
        keys, how = ['layer_index', 'norm_type'], 'left'
    else:
        return go.Figure()

    baseline_stats = (
        baseline_data.groupby(keys)['MSE']
        .agg(['mean', 'std'])
        .reset_index()
        .rename(columns={'mean': 'MSE_baseline', 'std': 'std_baseline'})
    )
    merged = pd.merge(trained_data, baseline_stats, on=keys, how=how)
    if merged.empty:
        return go.Figure()

    mse = merged['MSE']
    baseline = merged['MSE_baseline']
    spread = merged['std_baseline']
    if normalization_method == 'ratio':
        merged['normalized_MSE'] = (1 - mse / baseline).where(baseline != 0)
        y_label = "1 - (Trained MSE / Baseline MSE)"
    elif normalization_method == 'percent':
        merged['normalized_MSE'] = (100 * (mse - baseline) / baseline).where(baseline != 0)
        y_label = "Percent Difference (%)"
    elif normalization_method == 'zscore':
        # std is NaN when a baseline group has a single value; no z-score then.
        merged['normalized_MSE'] = ((mse - baseline) / spread).where(spread.notna() & (spread != 0))
        y_label = "Z-Score (Trained relative to Baseline)"
    else:
        merged['normalized_MSE'] = mse  # Fallback; not reachable from the dropdown.
        y_label = "MSE"

    merged = merged.sort_values('x_value')

    all_layers = sorted(df_exp['layer_index'].unique())
    if not all_layers:
        return go.Figure()
    concat_layer = max(all_layers)
    color_map = _layer_color_map(all_layers)

    fig = go.Figure()
    for (norm_type, layer_index), group in merged.groupby(['norm_type', 'layer_index']):
        layer_label = _layer_label(layer_index, concat_layer)
        trace_label = f"Trained | {norm_type.title()} | {layer_label} (Normalized by {baseline_type})"
        style = _BELIEF_TYPE_STYLES[norm_type]
        fig.add_trace(go.Scatter(
            x=group['x_value'],
            y=group['normalized_MSE'],
            mode='lines+markers',
            name=trace_label,
            line=dict(color=color_map.get(layer_index, "black"), dash='solid', width=style['width']),
            opacity=style['opacity'],
            marker=dict(symbol='circle', size=5),
            hovertemplate="Checkpoint: %{x}<br>Normalized MSE: %{y:.5f}<extra></extra>"
        ))
    fig.update_layout(
        title="Normalized MSE vs. Checkpoint (Trained Models)",
        xaxis_title="Checkpoint (for Trained models)",
        yaxis_title=y_label,
        margin={'l': 50, 'r': 50, 't': 50, 'b': 50}
    )
    return fig


# Main callback for updating the graph.
@app.callback(
    Output('mse-graph', 'figure'),
    Input('experiment-dropdown', 'value'),
    Input('analysis-type-checklist', 'value'),
    Input('belief-type-checklist', 'value'),
    Input('layer-checklist', 'value'),
    Input('baseline-type-dropdown', 'value'),
    Input('normalization-method-dropdown', 'value')
)
def update_graph(selected_experiment, selected_analysis_types, selected_belief_types,
                 selected_layers, baseline_type, normalization_method):
    """Build the dashboard figure for the current selections.

    With ``normalization_method == 'none'`` raw MSE is plotted for the chosen
    analysis types; otherwise only trained runs are plotted, normalized
    against *baseline_type* (the analysis-type checklist is then ignored).

    Returns:
        A ``go.Figure``; an empty one when the experiment, belief types or
        layers are unselected, or the CSV cannot be read or has no rows.
    """
    if not selected_experiment or not selected_belief_types or not selected_layers:
        return go.Figure()
    try:
        df_exp = pd.read_csv(selected_experiment)
    except Exception as e:  # callback boundary: report and show an empty plot
        print("Error reading CSV:", e)
        return go.Figure()
    if df_exp.empty:
        # Nothing to plot, and the row-wise apply below is ill-defined on an empty frame.
        return go.Figure()

    df_exp['x_value'] = df_exp.apply(_compute_x_value, axis=1)

    filtered = df_exp[
        df_exp['norm_type'].isin(selected_belief_types)
        & df_exp['layer_index'].isin(selected_layers)
    ]

    if normalization_method == 'none':
        return _raw_mse_figure(df_exp, filtered, selected_analysis_types)
    return _normalized_mse_figure(df_exp, filtered, baseline_type, normalization_method)

# ------------------------------
# 5. Run the App
# ------------------------------

if __name__ == '__main__':
    # Dash 3 removed ``run_server`` in favour of ``run``; only fall back to
    # ``run_server`` on old Dash releases that do not provide ``run``.
    if hasattr(app, 'run'):
        app.run(debug=True)
    else:
        app.run_server(debug=True)
