In [None]:
import os
import re
import ast

import pandas as pd
import numpy as np

import plotly.express as px
import plotly.graph_objects as go


In [None]:
# Directory holding the experiment result CSVs.
# NOTE(review): machine-specific absolute path — adjust before running elsewhere.
results_dir = '/Users/tania/Documents/GitHub/critical-learning-period-effect/artifacts/final_results/MNIST_variations'

# os.listdir returns entries in arbitrary order; sort them so that a positional
# index into csv_files refers to the same file on every run and every machine.
csv_files = sorted(f for f in os.listdir(results_dir) if f.endswith('.csv'))
print(csv_files)

In [None]:
# Experiment selection: position of the CSV in csv_files, plus the run count
# and experiment name that are interpolated into the figure title.
j = 0
runs = 3
experiment_name = "MNIST Colour - Base model"

# Resolve the chosen file name once, load it, then echo which file was used.
selected_csv = csv_files[j]
results_df = pd.read_csv(os.path.join(results_dir, selected_csv))

print(selected_csv)

# Figure number shown in the plot title: the first character of the file name.
# NOTE(review): presumably file names start with a single-digit figure index — confirm.
i = selected_csv[0]

In [None]:
# Build agg_acc: mean and SEM of validation accuracy per (initialisation, epoch),
# with one extra epoch-0 row per initialisation taken from the initial scores.

# Leave out the initialisation that is not part of this figure.
excluded_init = "initialised random trained on noisy"
subset_df = results_df.loc[results_df["initialisation"] != excluded_init]

# Per-epoch statistics across all rows sharing an (initialisation, epoch) pair.
epoch_groups = subset_df.groupby(['initialisation', 'epoch'])
agg_acc = epoch_groups.agg(
    valid_acc_mean=('valid_acc', 'mean'),
    valid_acc_sem=('valid_acc', 'sem'),
).reset_index()

# One initial score per (initialisation, run): the first value found in each group.
run_groups = subset_df.groupby(['initialisation', 'run'])
per_run_init_scores = run_groups['initial_valid_acc'].first().reset_index()

# Collapse the per-run initial scores into a single row per initialisation and
# tag it as epoch 0 so it lines up with the per-epoch statistics.
# NOTE(review): assumes the logged epochs do not already contain an epoch 0 — confirm.
init_score_groups = per_run_init_scores.groupby('initialisation')['initial_valid_acc']
initial_rows = init_score_groups.agg(valid_acc_mean='mean', valid_acc_sem='sem').reset_index()
initial_rows = initial_rows.assign(epoch=0)

# Prepend the epoch-0 rows, then order by initialisation and epoch so that
# epoch 0 comes first within each initialisation.
agg_acc = (
    pd.concat([initial_rows, agg_acc], ignore_index=True)
    .sort_values(['initialisation', 'epoch'])
    .reset_index(drop=True)
)

In [None]:
# Plot, for every initialisation in agg_acc: the mean validation-accuracy curve,
# a shaded mean ± SEM band, and the faint individual curves from results_df.
fig = go.Figure()

# One colour per initialisation, evenly spaced along the Viridis scale.
inits = agg_acc['initialisation'].unique()
colors = px.colors.sample_colorscale('Viridis', np.linspace(0, 1, len(inits)))
color_map = dict(zip(inits, colors))

for init in inits:
    df_init = agg_acc[agg_acc['initialisation'] == init]
    color = color_map[init]

    # Last-row statistics for the legend label. They are formatted numerically:
    # slicing str(value) cuts the exponent off values rendered in scientific
    # notation (1.2345678e-05 would be shown as "1.2345").
    final_mean = df_init['valid_acc_mean'].iloc[-1]
    final_sem = df_init['valid_acc_sem'].iloc[-1]

    # Mean curve — the only trace per initialisation that appears in the legend.
    fig.add_trace(go.Scatter(
        x=df_init['epoch'],
        y=df_init['valid_acc_mean'],
        mode='lines',
        name=f"{init} <br> (final value of {final_mean:.4f} ± {final_sem:.4f})",
        line=dict(color=color, width=2),
        showlegend=True,
    ))

    # Shaded SEM band: trace the upper bound left-to-right, then the lower bound
    # right-to-left, so that fill='toself' closes the polygon.
    upper = df_init['valid_acc_mean'] + df_init['valid_acc_sem']
    lower = df_init['valid_acc_mean'] - df_init['valid_acc_sem']
    fig.add_trace(go.Scatter(
        x=df_init['epoch'].tolist() + df_init['epoch'][::-1].tolist(),
        y=upper.tolist() + lower[::-1].tolist(),
        fill='toself',
        # Relies on the sampled colour being an 'rgb(...)' string; adds an alpha channel.
        fillcolor=color.replace('rgb', 'rgba').replace(')', ',0.15)'),
        line=dict(color='rgba(255,255,255,0)'),
        hoverinfo="skip",
        showlegend=False,
    ))

    # Individual curves, one per source file. Boolean indexing keeps only the
    # matching rows; DataFrame.where would keep every row and blank the rest to NaN.
    runs_for_init = results_df[results_df['initialisation'] == init]
    for source_file, df_run in runs_for_init.groupby('source'):
        fig.add_trace(go.Scatter(
            x=df_run['epoch'],
            y=df_run['valid_acc'],
            mode='lines',
            line=dict(color=color, width=1),
            opacity=0.25,  # See-through for individual runs
            name=f'{init} - run {source_file}',
            showlegend=False,  # legend stays clean
        ))

fig.update_layout(
    title=f'Fig {i}: {experiment_name} (mean ± SEM of {runs} runs)',
    xaxis_title='Epoch',
    yaxis_title='Test Accuracy',
    template='plotly_white'
)

fig.show()