In [None]:
import ast
import os
import re
from typing import List, Optional

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

In [None]:
# Directory that holds the per-experiment result CSV files.
# NOTE(review): this is an absolute path on one machine; adjust it when running elsewhere.
results_dir = '/Users/tania/Documents/GitHub/critical-learning-period-effect/artifacts/final_results/MNIST_variations'

# One (display name, CSV file name) pair per experiment.
results = [
    (
        'Bottleneck model - rotated MNIST 2.5 degrees',
        '5_0_rotationMNIST_source2.5_target0_results.csv',
    ),
    (
        'Bottleneck model - rotated MNIST 5 degrees',
        '5_1_rotationMNIST_source5_target0_results.csv',
    ),
    (
        'Bottleneck model - rotated MNIST 10 degrees',
        '5_2_rotationMNIST_source10_target0_results.csv',
    ),
    (
        'Bottleneck model - rotated MNIST 40 degrees',
        '5_3_rotationMNIST_source40_target0_results.csv',
    ),
]

# Index into `results` of the experiment loaded into `results_df`.
# `i` and `experiment_name` are also read by plot_accuracy_curves for its title.
i = 2
experiment_name = results[i][0]
results_df = pd.read_csv(os.path.join(results_dir, results[i][1]))

# Not referenced by the cells shown in this notebook.
# Presumably the number of repeated runs per configuration; TODO confirm.
runs = 3


In [None]:
def plot_accuracy_curves(
    results_df: pd.DataFrame,
    target_eval: str = "MNIST_hard_target",
    source_keywords: List[str] = ["pretraining"],
    accuracy=True,
    title_prefix: Optional[str] = None,
):
    """
    Plot learning curves (mean ± SEM) for different initialisations on a chosen evaluation set.

    Up to two figures are shown: one for the initialisations whose name contains none of
    ``source_keywords`` and one for those that do (the source/pretraining group).
    A group without any initialisation is skipped.

    Args:
        results_df: DataFrame with the columns 'initialisation', 'epoch', 'run' and
            'source', plus the metric columns derived from ``target_eval``.
        target_eval: Evaluation set prefix, e.g. "MNIST_hard_target".
        source_keywords: Case-insensitive substrings used to identify pretraining/source
            initialisations. The default list is only read, never modified.
        accuracy: If True, plot ``{target_eval}_acc`` (with
            ``initial_{target_eval}_logits_acc`` as the epoch-0 value); otherwise plot the
            corresponding ``_loss`` columns.
        title_prefix: Leading part of the figure title. When None (the default), the title
            is built from the notebook-level variables ``i`` and ``experiment_name``, which
            must then be defined before calling. Pass a string when plotting a DataFrame
            other than the one those variables describe.

    Raises:
        ValueError: If either required metric column is missing from ``results_df``.
    """
    metric_suffix = "acc" if accuracy else "loss"
    label = "Accuracy" if accuracy else "Loss"
    acc_col = f"{target_eval}_{metric_suffix}"
    init_acc_col = f"initial_{target_eval}_logits_{metric_suffix}"

    if acc_col not in results_df.columns or init_acc_col not in results_df.columns:
        raise ValueError(f"Required columns not found for {target_eval}: {acc_col}, {init_acc_col}")

    # Split initialisations into two groups
    all_inits = results_df['initialisation'].unique()
    pre_training = [init for init in all_inits if any(k.lower() in init.lower() for k in source_keywords)]
    target_models = [init for init in all_inits if init not in pre_training]

    def make_plot(sub_df, inits_to_plot, title_suffix):
        # Rows belonging to the requested initialisations (filtered once, reused below).
        selected = sub_df[sub_df['initialisation'].isin(inits_to_plot)]

        # Mean and SEM of the metric per (initialisation, epoch).
        agg_acc = (
            selected
            .groupby(['initialisation', 'epoch'])
            .agg(acc_mean=(acc_col, 'mean'), acc_sem=(acc_col, 'sem'))
            .reset_index()
        )

        # Initial value: one value per run (first entry), then mean/SEM per initialisation,
        # inserted as the epoch-0 point of each curve.
        per_run_init = (
            selected
            .groupby(['initialisation', 'run'])[init_acc_col]
            .first()
            .reset_index()
        )
        initial_rows = (
            per_run_init
            .groupby('initialisation')[init_acc_col]
            .agg(acc_mean='mean', acc_sem='sem')
            .reset_index()
            .assign(epoch=0)
        )

        # Combine and sort so each curve runs from epoch 0 onwards.
        agg_acc = pd.concat([initial_rows, agg_acc], ignore_index=True)
        agg_acc = agg_acc.sort_values(['initialisation', 'epoch']).reset_index(drop=True)

        # One Viridis colour per initialisation.
        fig = go.Figure()
        colors = px.colors.sample_colorscale('Viridis', np.linspace(0, 1, len(inits_to_plot)))
        color_map = dict(zip(inits_to_plot, colors))

        for init in inits_to_plot:
            df_init = agg_acc[agg_acc['initialisation'] == init]
            color = color_map[init]

            # Mean curve; the legend entry reports the last point of the curve.
            fig.add_trace(go.Scatter(
                x=df_init['epoch'],
                y=df_init['acc_mean'],
                mode='lines',
                name=f"{init}<br>(final {df_init['acc_mean'].iloc[-1]:.3f} ± {df_init['acc_sem'].iloc[-1]:.3f})",
                line=dict(color=color, width=2)
            ))

            # SEM shading: closed polygon going forward along mean+SEM and back along mean-SEM.
            fig.add_trace(go.Scatter(
                x=df_init['epoch'].tolist() + df_init['epoch'][::-1].tolist(),
                y=(df_init['acc_mean'] + df_init['acc_sem']).tolist() +
                  (df_init['acc_mean'] - df_init['acc_sem'])[::-1].tolist(),
                fill='toself',
                # 'rgb(r, g, b)' -> 'rgba(r, g, b,0.15)'
                fillcolor=color.replace('rgb', 'rgba').replace(')', ',0.15)'),
                line=dict(color='rgba(255,255,255,0)'),
                hoverinfo="skip",
                showlegend=False,
            ))

            # Individual runs (one faint line per 'source' value), hidden from the legend.
            for source_file, df_run in sub_df[sub_df['initialisation'] == init].groupby('source'):
                fig.add_trace(go.Scatter(
                    x=df_run['epoch'],
                    y=df_run[acc_col],
                    mode='lines',
                    line=dict(color=color, width=1),
                    opacity=0.25,
                    name=f'{init} - {source_file}',
                    showlegend=False
                ))

        if title_prefix is None:
            # Backward-compatible default: describe the experiment selected at notebook
            # level. `i` and `experiment_name` are globals, not arguments of this function.
            prefix = f"{i}: Performance on {experiment_name}"
        else:
            prefix = title_prefix

        fig.update_layout(
            title=f"{prefix} {title_suffix} (mean ± SEM)",
            xaxis_title="Epoch",
            yaxis_title=f'{label} on {target_eval}',
            template="plotly_white",
            legend_title="Initialisation",
        )
        fig.show()

    # Plot nonsource initialisations
    if target_models:
        make_plot(results_df, target_models, "")

    # Plot source/pretraining initialisations
    if pre_training:
        make_plot(results_df, pre_training, "")

In [None]:
# Display the column names available in the loaded results frame.
results_df.columns


In [None]:
accuracy = True  # True: plot accuracy columns, False: plot loss columns

# Same plotting call for each evaluation-set prefix, in this order.
for target_eval in ("MNIST_test_target", "MNIST_hard_target", "MNIST_hard_source"):
    plot_accuracy_curves(results_df, target_eval=target_eval, accuracy=accuracy)


In [None]:
# Values used to build the "MNIST_hard_<value>" evaluation prefix.
# Presumably rotation angles in degrees; TODO confirm against the CSV columns.
degrees = [320, 350, 355, 357.5]
test_deg = degrees[-1]  # the last entry is the one plotted

# Reuses the `accuracy` flag set in the previous plotting cell.
target_eval = f"MNIST_hard_{test_deg}"
plot_accuracy_curves(
    results_df,
    target_eval=target_eval,
    accuracy=accuracy,
)


In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from typing import List, Dict

def plot_accuracy_curves_multi(
    results_dict: Dict[str, pd.DataFrame],
    target_eval: str = "MNIST_hard_target",
    source_keywords: List[str] = ["pretraining"],
    accuracy: bool = True
):
    """
    Show mean ± SEM learning curves per initialisation for several experiments.

    For every experiment in ``results_dict`` up to two figures are shown: one for the
    initialisations whose name contains none of ``source_keywords`` ("Target-only") and
    one for those that do ("Pretraining"). Experiments whose DataFrame lacks the
    required metric columns are reported on stdout and skipped.

    Args:
        results_dict: Mapping of experiment name to its results DataFrame.
        target_eval: Evaluation set prefix, e.g. "MNIST_hard_target".
        source_keywords: Case-insensitive substrings identifying pretraining/source
            initialisations.
        accuracy: True to plot the ``_acc`` columns, False to plot the ``_loss`` columns.
    """
    metric_suffix, label = ("acc", "Accuracy") if accuracy else ("loss", "Loss")
    acc_col = f"{target_eval}_{metric_suffix}"
    init_acc_col = f"initial_{target_eval}_logits_{metric_suffix}"

    def matches_source_keyword(init_name):
        # True when any keyword occurs in the initialisation name, ignoring case.
        return any(keyword.lower() in init_name.lower() for keyword in source_keywords)

    def make_plot(sub_df, inits_to_plot, experiment_name, title_suffix):
        # Rows of the initialisations shown in this figure.
        selected = sub_df[sub_df['initialisation'].isin(inits_to_plot)]

        # Mean/SEM of the metric for each (initialisation, epoch).
        per_epoch = (
            selected
            .groupby(['initialisation', 'epoch'])
            .agg(acc_mean=(acc_col, 'mean'), acc_sem=(acc_col, 'sem'))
            .reset_index()
        )

        # Epoch-0 point: first initial value of each run, then mean/SEM per initialisation.
        first_per_run = (
            selected
            .groupby(['initialisation', 'run'])[init_acc_col]
            .first()
            .reset_index()
        )
        epoch_zero = (
            first_per_run
            .groupby('initialisation')[init_acc_col]
            .agg(acc_mean='mean', acc_sem='sem')
            .reset_index()
            .assign(epoch=0)
        )

        # Full curves, ordered by initialisation and epoch.
        curves = (
            pd.concat([epoch_zero, per_epoch], ignore_index=True)
            .sort_values(['initialisation', 'epoch'])
            .reset_index(drop=True)
        )

        fig = go.Figure()
        palette = px.colors.sample_colorscale('Viridis', np.linspace(0, 1, len(inits_to_plot)))

        for init, color in zip(inits_to_plot, palette):
            curve = curves[curves['initialisation'] == init]
            final_mean = curve['acc_mean'].iloc[-1]
            final_sem = curve['acc_sem'].iloc[-1]

            # Mean curve, labelled with the value at the last point.
            fig.add_trace(go.Scatter(
                x=curve['epoch'],
                y=curve['acc_mean'],
                mode='lines',
                name=f"{init}<br>(final {final_mean:.3f} ± {final_sem:.3f})",
                line=dict(color=color, width=2)
            ))

            # SEM band: forward along the upper edge, back along the lower edge.
            epochs = curve['epoch'].tolist()
            upper_edge = (curve['acc_mean'] + curve['acc_sem']).tolist()
            lower_edge = (curve['acc_mean'] - curve['acc_sem']).tolist()
            band_color = color.replace('rgb', 'rgba').replace(')', ',0.15)')
            fig.add_trace(go.Scatter(
                x=epochs + epochs[::-1],
                y=upper_edge + lower_edge[::-1],
                fill='toself',
                fillcolor=band_color,
                line=dict(color='rgba(255,255,255,0)'),
                hoverinfo="skip",
                showlegend=False,
            ))

            # Faint line for each individual run, grouped by the 'source' column.
            runs_of_init = sub_df[sub_df['initialisation'] == init]
            for source_file, df_run in runs_of_init.groupby('source'):
                fig.add_trace(go.Scatter(
                    x=df_run['epoch'],
                    y=df_run[acc_col],
                    mode='lines',
                    line=dict(color=color, width=1),
                    opacity=0.25,
                    name=f'{init} - {source_file}',
                    showlegend=False
                ))

        fig.update_layout(
            title=f"{experiment_name}: {target_eval} ({title_suffix})",
            xaxis_title="Epoch",
            yaxis_title=f'{label}',
            template="plotly_white",
            legend_title="Initialisation",
        )
        fig.show()

    # === Main loop over experiments ===
    for experiment_name, df in results_dict.items():
        has_required_columns = acc_col in df.columns and init_acc_col in df.columns
        if not has_required_columns:
            print(f"⚠️ Skipping {experiment_name} (missing {acc_col} or {init_acc_col})")
            continue

        # Partition the initialisations in one pass, keeping their order of appearance.
        pre_training, target_models = [], []
        for init in df['initialisation'].unique():
            if matches_source_keyword(init):
                pre_training.append(init)
            else:
                target_models.append(init)

        if target_models:
            make_plot(df, target_models, experiment_name, "Target-only")
        if pre_training:
            make_plot(df, pre_training, experiment_name, "Pretraining")

In [None]:
# Load every experiment listed in `results`, keyed by its display name.
# Iterating over `results` (instead of indexing entries 0..3 by hand) keeps this
# mapping in sync when experiments are added to or removed from the list.
results_dict = {
    name: pd.read_csv(os.path.join(results_dir, file_name))
    for name, file_name in results
}

plot_accuracy_curves_multi(results_dict, target_eval="MNIST_hard_target")

In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from typing import Dict, List

def plot_combined_nonpretraining(
    results_dict: Dict[str, pd.DataFrame],
    target_eval: str = "MNIST_hard_target",
    source_keywords: List[str] = ["pretraining"],
    accuracy: bool = True
):
    """
    Combine all non-pretraining runs from multiple experiments into a single plot.

    Each (experiment, initialisation) pair gets its own mean curve, coloured by experiment.
    Initialisations whose name contains "random" are drawn dashed and semi-transparent;
    all others are solid. Two figures are shown: the curves without a legend, followed
    by a separate figure that carries only the legend entries.

    Args:
        results_dict: Mapping of experiment name to its results DataFrame. The frames
            are not modified.
        target_eval: Evaluation set prefix, e.g. "MNIST_hard_target".
        source_keywords: Case-insensitive substrings identifying pretraining/source
            initialisations, which are excluded from the plot.
        accuracy: True to plot the ``_acc`` columns, False to plot the ``_loss`` columns.

    Returns:
        None. A warning is printed and nothing is plotted when ``results_dict`` is empty
        or when no non-pretraining initialisation is found.
    """
    # Choose column names
    if accuracy:
        acc_col = f"{target_eval}_acc"
        init_acc_col = f"initial_{target_eval}_logits_acc"
        label = "Accuracy"
    else:
        acc_col = f"{target_eval}_loss"
        init_acc_col = f"initial_{target_eval}_logits_loss"
        label = "Loss"

    # pd.concat raises on an empty sequence, so handle "nothing to plot" explicitly.
    if not results_dict:
        print("⚠️ No experiments provided.")
        return

    # Concatenate all runs, tagging each row with its experiment name.
    # `assign` returns a new frame, so the caller's DataFrames stay untouched.
    big_df = pd.concat(
        [df.assign(experiment=exp_name) for exp_name, df in results_dict.items()],
        ignore_index=True,
    )

    # Identify pretraining vs non-pretraining initialisations
    all_inits = big_df['initialisation'].unique()
    pre_training = [init for init in all_inits if any(k.lower() in init.lower() for k in source_keywords)]
    target_models = [init for init in all_inits if init not in pre_training]

    if not target_models:
        print("⚠️ No non-pretraining runs found.")
        return

    target_df = big_df[big_df['initialisation'].isin(target_models)]

    # Aggregate over runs to get mean ± SEM per (experiment, initialisation, epoch)
    agg_df = (
        target_df
        .groupby(['experiment', 'initialisation', 'epoch'])
        .agg(acc_mean=(acc_col, 'mean'), acc_sem=(acc_col, 'sem'))
        .reset_index()
    )

    # Initial rows: first initial value of each run, averaged and placed at epoch 0
    per_run_init = (
        target_df
        .groupby(['experiment', 'initialisation', 'run'])[init_acc_col]
        .first()
        .reset_index()
    )
    initial_rows = (
        per_run_init
        .groupby(['experiment', 'initialisation'])[init_acc_col]
        .agg(acc_mean='mean', acc_sem='sem')
        .reset_index()
        .assign(epoch=0)
    )

    agg_df = pd.concat([initial_rows, agg_df], ignore_index=True)
    agg_df = agg_df.sort_values(['experiment', 'initialisation', 'epoch']).reset_index(drop=True)

    # One Viridis colour per experiment
    unique_experiments = agg_df['experiment'].unique()
    colors = px.colors.sample_colorscale('Viridis', np.linspace(0, 1, len(unique_experiments)))
    color_map = dict(zip(unique_experiments, colors))

    # Line styles: dashed for random initialisation, solid otherwise
    def get_line_style(init_name, base_color):
        if "random" in init_name.lower():
            return dict(
                dash='dash',
                width=2,  # same width as the solid lines
                # 'rgb(r, g, b)' -> 'rgba(r, g, b,0.6)', i.e. alpha 0.6
                color=base_color.replace('rgb', 'rgba').replace(')', ',0.6)')
            )
        return dict(
            dash='solid',
            width=2,
            color=base_color
        )

    fig = go.Figure()         # the curves, shown without a legend
    legend_fig = go.Figure()  # carries only the legend entries

    for (exp_name, init), df_sub in agg_df.groupby(['experiment', 'initialisation']):
        trace_name = f"{exp_name} — {init}"
        line_style = get_line_style(init, color_map[exp_name])

        fig.add_trace(go.Scatter(
            x=df_sub['epoch'],
            y=df_sub['acc_mean'],
            mode='lines',
            name=trace_name,
            line=line_style
        ))

        # Point-less dummy trace: contributes a legend entry but draws nothing.
        legend_fig.add_trace(go.Scatter(
            x=[None],
            y=[None],
            mode='lines',
            name=trace_name,
            line=line_style
        ))

    fig.update_layout(
        title="4. Combined Experiment Results — models pretrained on different degrees of rotation",
        xaxis_title="Epoch",
        # Interpolate the metric name ("Accuracy"/"Loss"), not the literal word "label".
        yaxis_title=f'{label} on {target_eval}',
        template="plotly_white",
        legend_title="Experiment — Init",
        showlegend=False,
    )
    fig.show()

    legend_fig.update_layout(
        title="Legend",
        template="plotly_white"
    )
    legend_fig.show()

In [None]:
# Overlay the non-pretraining curves of every experiment in `results_dict` on one figure.
plot_combined_nonpretraining(results_dict, target_eval="MNIST_hard_target")