In [None]:
import ast
import os
import re
from typing import List, Optional

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go


In [None]:
# Directory holding the per-experiment CSV exports. Set the MNIST_RESULTS_DIR environment
# variable to point elsewhere instead of editing this machine-specific absolute path.
results_dir = os.environ.get(
    'MNIST_RESULTS_DIR',
    '/Users/tania/Documents/GitHub/critical-learning-period-effect/artifacts/final_results/MNIST_variations',
)
# os.listdir order is arbitrary; sort so the printed inventory is stable between runs.
csv_files = sorted(f for f in os.listdir(results_dir) if f.endswith('.csv'))
print(csv_files)

# (display label, CSV file name) for each experiment; one is selected by index in the next cell.
# NOTE(review): entries 2-4 share the label 'Noisy MNIST Colour (source=1 target=0.999)', although
# entry 2's file name mentions 'botleneck' and entry 4's mentions 'utils.models.mlp.BasicClassifierModule'
# — confirm the intended labels.
results = [
    ('Base model - Multi-channel MNIST', '0_MNIST_variations_multi-channel-example_results.csv'),
    ('MNIST Colour (source=1 target=0.999)', '1_colourMNIST_source1_target0.999_target0_oct_results.csv'),
    ('Noisy MNIST Colour (source=1 target=0.999)', '2_colourMNIST_source1_target0.999_target0_botleneck_results.csv'),
    ('Noisy MNIST Colour (source=1 target=0.999)', '3_colourMNISTnoisy_source1_target0.999_eval0_colourNoiseSTD0.07_results.csv'),
    ('Noisy MNIST Colour (source=1 target=0.999)', "4_colourMNISTnoisy_source1_target0.999_eval0_colourNoiseSTD0.07__class 'utils.models.mlp.BasicClassifierModule'__results.csv")
]

# Surface missing result files here, instead of as a FileNotFoundError in the next cell.
missing_results = [fname for _, fname in results if fname not in csv_files]
if missing_results:
    print(f"Warning: result files not found in {results_dir}: {missing_results}")

In [None]:
# Select which experiment from `results` to analyse. `i` and `experiment_name` are also
# read as notebook globals by the plotting helper's figure titles.
i = 4
experiment_name, fp = results[i]
results_df = pd.read_csv(os.path.join(results_dir, fp))
# NOTE(review): `runs` is not referenced by any visible cell — confirm it is still needed.
runs = 1

print(fp)

In [None]:
# Distinct initialisation labels present in the selected results file (order of first appearance).
pd.unique(results_df['initialisation'])

In [None]:
def plot_accuracy_curves(
    results_df: pd.DataFrame,
    target_eval: str = "MNIST_hard_target",
    source_keywords: Optional[List[str]] = None,
    accuracy=True,
    experiment_name: Optional[str] = None,
    experiment_index=None,
):
    """
    Plot learning curves (mean ± SEM) for different initialisations on a chosen evaluation set.

    Produces up to two figures: one for non-source (target) initialisations and one for
    source/pretraining initialisations. For each initialisation a figure shows the mean curve,
    a ±SEM band and faint per-run curves. The epoch-0 point comes from the
    ``initial_{target_eval}_logits_{acc|loss}`` column, taken once per run.

    Args:
        results_df: DataFrame containing run histories and initial evaluation columns
            ('initialisation', 'epoch', 'run', 'source' plus the metric columns).
        target_eval: Evaluation set prefix, e.g. "MNIST_hard_target".
        source_keywords: Case-insensitive keywords used to identify pretraining/source
            initialisations. Defaults to ["pretraining"].
        accuracy: If True plot ``{target_eval}_acc``; otherwise plot ``{target_eval}_loss``.
        experiment_name: Name shown in the figure titles. Defaults to the notebook-level
            ``experiment_name`` global if defined, else "".
        experiment_index: Index prefixed to the figure titles. Defaults to the notebook-level
            ``i`` global if defined, else omitted.

    Raises:
        ValueError: if the metric column or the initial-metric column is missing.
    """
    if source_keywords is None:
        source_keywords = ["pretraining"]
    # Titles used to read the notebook globals `experiment_name` and `i` implicitly; keep that
    # as the fallback so existing calls render the same, but allow passing them explicitly.
    if experiment_name is None:
        experiment_name = globals().get("experiment_name", "")
    if experiment_index is None:
        experiment_index = globals().get("i")
    title_prefix = "" if experiment_index is None else f"{experiment_index}: "

    metric = "acc" if accuracy else "loss"
    label = "Accuracy" if accuracy else "Loss"
    acc_col = f"{target_eval}_{metric}"
    init_acc_col = f"initial_{target_eval}_logits_{metric}"

    if acc_col not in results_df.columns or init_acc_col not in results_df.columns:
        raise ValueError(f"Required columns not found for {target_eval}: {acc_col}, {init_acc_col}")

    # Split initialisations into source/pretraining vs. target groups (case-insensitive match;
    # str() guards against non-string labels such as NaN).
    keywords = [k.lower() for k in source_keywords]
    all_inits = results_df['initialisation'].unique()
    pre_training = [init for init in all_inits if any(k in str(init).lower() for k in keywords)]
    target_models = [init for init in all_inits if init not in pre_training]

    def make_plot(inits_to_plot, group_label):
        """Draw one figure with mean curve, SEM band and per-run curves for each initialisation."""
        sub_df = results_df[results_df['initialisation'].isin(inits_to_plot)]

        # Mean and SEM across runs at every recorded epoch.
        agg_acc = (
            sub_df
            .groupby(['initialisation', 'epoch'])
            .agg(acc_mean=(acc_col, 'mean'), acc_sem=(acc_col, 'sem'))
            .reset_index()
        )

        # Initial value: one per run, then mean/SEM across runs, placed at epoch 0.
        # NOTE(review): assumes logged training epochs start at 1; if they start at 0 each curve
        # gets two epoch-0 points — confirm against the training loop.
        per_run_init = (
            sub_df
            .groupby(['initialisation', 'run'])[init_acc_col]
            .first()
            .reset_index()
        )
        initial_rows = (
            per_run_init
            .groupby('initialisation')[init_acc_col]
            .agg(acc_mean='mean', acc_sem='sem')
            .reset_index()
            .assign(epoch=0)
        )

        agg_acc = (
            pd.concat([initial_rows, agg_acc], ignore_index=True)
            .sort_values(['initialisation', 'epoch'])
            .reset_index(drop=True)
        )

        fig = go.Figure()
        # sample_colorscale yields 'rgb(r, g, b)' strings; the band fill below rewrites them to rgba.
        colors = px.colors.sample_colorscale('Viridis', np.linspace(0, 1, len(inits_to_plot)))
        color_map = dict(zip(inits_to_plot, colors))

        for init in inits_to_plot:
            df_init = agg_acc[agg_acc['initialisation'] == init]
            color = color_map[init]

            # SEM is NaN when an initialisation has a single run: draw a zero-width band and
            # report "n/a" in the legend instead of "nan".
            sem = df_init['acc_sem'].fillna(0.0)
            final_mean = df_init['acc_mean'].iloc[-1]
            final_sem = df_init['acc_sem'].iloc[-1]
            final_sem_txt = "n/a" if pd.isna(final_sem) else f"{final_sem:.3f}"

            # Mean curve
            fig.add_trace(go.Scatter(
                x=df_init['epoch'],
                y=df_init['acc_mean'],
                mode='lines',
                name=f"{init}<br>(final {final_mean:.3f} ± {final_sem_txt})",
                line=dict(color=color, width=2),
            ))

            # SEM shading: closed polygon running along the upper edge and back along the lower edge.
            upper = (df_init['acc_mean'] + sem).tolist()
            lower = (df_init['acc_mean'] - sem)[::-1].tolist()
            fig.add_trace(go.Scatter(
                x=df_init['epoch'].tolist() + df_init['epoch'][::-1].tolist(),
                y=upper + lower,
                fill='toself',
                fillcolor=color.replace('rgb', 'rgba').replace(')', ',0.15)'),
                line=dict(color='rgba(255,255,255,0)'),
                hoverinfo="skip",
                showlegend=False,
            ))

            # Individual runs, one faint line per 'source' value.
            # NOTE(review): runs are keyed by 'source' here but by 'run' for the initial value
            # above — confirm both identify the same run.
            for source_file, df_run in sub_df[sub_df['initialisation'] == init].groupby('source'):
                # Sort so the line is drawn in epoch order even if CSV rows are not.
                df_run = df_run.sort_values('epoch')
                fig.add_trace(go.Scatter(
                    x=df_run['epoch'],
                    y=df_run[acc_col],
                    mode='lines',
                    line=dict(color=color, width=1),
                    opacity=0.25,
                    name=f'{init} - {source_file}',
                    showlegend=False
                ))

        fig.update_layout(
            title=f"{title_prefix}Performance on {experiment_name} {group_label} (mean ± SEM)",
            xaxis_title="Epoch",
            yaxis_title=f'{label} on {target_eval}',
            template="plotly_white",
            legend_title="Initialisation",
        )
        fig.show()

    # Plot non-source initialisations
    if target_models:
        make_plot(target_models, "[target initialisations]")

    # Plot source/pretraining initialisations
    if pre_training:
        make_plot(pre_training, "[source/pretraining initialisations]")

In [None]:
import pandas as pd
import plotly.graph_objects as go

def plot_single_initialisation_runs(
    results_df: pd.DataFrame,
    init_name: str,
    target_eval: str = "MNIST_hard_target",
    accuracy=True
):
    """
    Plot all runs for a single initialisation on a given evaluation metric.

    Each run (one per distinct 'source' value) is drawn as its own opaque line. Each line
    starts from the run's initial value (``initial_{target_eval}_logits_{acc|loss}``),
    placed at epoch 0. No mean curve or shading.

    Args:
        results_df: DataFrame with 'initialisation', 'source' and 'epoch' columns plus the
            metric columns derived from ``target_eval``.
        init_name: Exact value of the 'initialisation' column to plot.
        target_eval: Evaluation set prefix, e.g. "MNIST_hard_target".
        accuracy: If True plot ``{target_eval}_acc``; otherwise plot ``{target_eval}_loss``.

    Raises:
        ValueError: if a required metric column is missing, or no rows match ``init_name``.
    """
    metric = "acc" if accuracy else "loss"
    label = "Accuracy" if accuracy else "Loss"
    acc_col = f"{target_eval}_{metric}"
    init_acc_col = f"initial_{target_eval}_logits_{metric}"

    if acc_col not in results_df.columns or init_acc_col not in results_df.columns:
        raise ValueError(f"Missing required columns for {target_eval}: {acc_col}, {init_acc_col}")

    df_init = results_df[results_df['initialisation'] == init_name]
    if df_init.empty:
        raise ValueError(f"No runs found for initialisation '{init_name}'")

    fig = go.Figure()

    # Plot each run line, prefixed with its initial point at epoch 0
    for source_file, df_run in df_init.groupby('source'):
        # Sort so the line is drawn in epoch order even if CSV rows are not.
        df_run = df_run.sort_values('epoch')
        # NOTE(review): the initial value is read from the run's first row; presumably it is
        # constant within a run — confirm.
        init_val = df_run[init_acc_col].iloc[0]

        fig.add_trace(go.Scatter(
            x=[0] + df_run['epoch'].tolist(),
            y=[init_val] + df_run[acc_col].tolist(),
            mode='lines',
            name=str(source_file),
            line=dict(width=2),
        ))

    fig.update_layout(
        title=f"All runs for '{init_name}' on {target_eval}",
        xaxis_title='Epoch',
        yaxis_title=label,
        template='plotly_white',
        legend_title='Run',
    )

    fig.show()


In [None]:
accuracy = True  # Plot accuracy (True) or loss (False)

# Same learning-curve figures for each evaluation split.
for eval_name in ("MNIST_hard_target", "MNIST_test_target", "MNIST_hard_eval", "MNIST_hard_gray"):
    plot_accuracy_curves(results_df, target_eval=eval_name, accuracy=accuracy)

In [None]:
# NOTE(review): this cell repeats the previous cell exactly (same splits, same `accuracy`);
# consider deleting it or flipping `accuracy` to plot loss instead.
eval_sets = ["MNIST_hard_target", "MNIST_test_target", "MNIST_hard_eval", "MNIST_hard_gray"]
for target in eval_sets:
    plot_accuracy_curves(results_df, target_eval=target, accuracy=accuracy)

In [None]:
# Learning curves with the function's default metric setting, for every evaluation split.
for split in ["MNIST_hard_target", "MNIST_test_target", "MNIST_hard_eval", "MNIST_hard_gray"]:
    plot_accuracy_curves(results_df, target_eval=split)

In [None]:
accuracy = False
metric = "MNIST_test_target"

# One figure per initialisation showing every individual run on `metric` (loss, since accuracy=False).
for init_label in (
    "target with random init",
    "pretraining on source from random init",
    "target with source pre-trained model init",
):
    plot_single_initialisation_runs(results_df, init_label, target_eval=metric, accuracy=accuracy)
