In [1]:
import sys
import os
# Put the parent directory, then the current directory, at the end of the
# import search path so local helper modules (presumably utils_io, imported
# in the next cell) can be found.
sys.path.extend(os.path.abspath(rel) for rel in ("..", "."))

In [None]:
import utils_io
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import glob
import yaml
import json
import torch


from tqdm.notebook import tqdm
# --- Autoreload ---
# Re-import edited .py modules (e.g. utils_io) before each cell runs, so
# changes are picked up without restarting the kernel. Mode 2 reloads all
# modules, not only those registered via %aimport.
%load_ext autoreload
%autoreload 2


# --- Matplotlib aesthetics ---
# Publication-style defaults applied to every figure in this notebook.
# Each key appears exactly once: a repeated key in a dict literal silently
# overrides the earlier entry and hides which value is actually in effect.
plt.rcParams.update({
    # Resolution (on-screen and saved figures)
    "figure.dpi": 300,
    "savefig.dpi": 300,
    # Figure canvas
    "figure.figsize": [6, 5],
    "figure.facecolor": "white",
    "figure.edgecolor": "black",
    # Typography
    "font.family": "serif",
    "axes.titlesize": 18,
    "axes.labelsize": 16,
    "legend.fontsize": 12,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    # Axes frame
    "axes.edgecolor": "black",
    "axes.linewidth": 1,
    # Grid styling (takes effect wherever a grid is drawn, e.g. ax.grid(True))
    "grid.alpha": 0.5,
    "grid.color": "#c0c0c0",
    "grid.linestyle": "--",
    "grid.linewidth": 0.8,
})

In [3]:
# Sweep output directory; searched recursively for per-run config.yaml files.
ROOT_DIR = "./results/sgd_hyperparam_sweep_v1"

In [4]:
def get_experiment_index(root_dir):
    """Index every experiment run found below ``root_dir``.

    Searches ``root_dir`` recursively for ``config.yaml`` files. Each file is
    parsed with ``yaml.safe_load``, and a summary record is built from it.
    Each record holds the parsed config and the path of ``metrics.pt`` in the
    same directory (not checked for existence here). It also holds a set of
    commonly filtered fields, with defaults for keys the config omits.

    Args:
        root_dir: Directory to search.

    Returns:
        list[dict]: One record per config that parsed into a mapping, in
        sorted file-path order so the index is reproducible. Configs that
        fail to read/parse, or that are not a mapping (e.g. an empty file),
        are reported and skipped.
    """
    # glob yields files in filesystem order; sort so re-runs index identically.
    files = sorted(glob.glob(os.path.join(root_dir, '**', 'config.yaml'), recursive=True))
    print(f"Indexing {len(files)} experiments...")

    configs = []
    for f in files:
        try:
            with open(f, 'r') as stream:
                cfg = yaml.safe_load(stream)
        except Exception as exc:
            # Best-effort indexing: report the reason instead of failing silently.
            print(f"Failed to load config from {f} ({exc!r}), skipping...")
            continue

        # An empty YAML file loads as None; non-mapping YAML has none of the
        # keys read below.
        if not isinstance(cfg, dict):
            print(f"Config at {f} is not a mapping, skipping...")
            continue

        configs.append({
            'config': cfg,
            'metrics_path': os.path.join(os.path.dirname(f), 'metrics.pt'),
            'mode': cfg.get('training_mode', 'unknown'),
            'reg_type': cfg.get('reg_type'),
            'seed': cfg.get('seed'),
            'curvature': cfg.get('curvature_type', 'fisher'),
            'accumulate': cfg.get('accumulate', False),
            'data_relationship': cfg.get('data_relationship', 'independent'),
            'ignore_gradient': cfg.get('ignore_gradient', False),
            'spectral_override': cfg.get('spectral_override', False),
            'alpha': cfg.get('alpha', 1.0),
            'projection': cfg.get('projection', False),
            'reg_frac': cfg.get('reg_frac', 1.0),
            'replay_frac': cfg.get('replay_frac', 0.0),
        })
    return configs

In [7]:
# Metrics logged as one vector per step rather than a scalar (presumably one
# component per top eigendirection, judging by the names).
_VECTOR_METRICS = ('alignments_with_top_eigs', 'projections_along_top_eigs')

# Zero-padding length for vector metrics when no logged vector exists to
# infer it from (the value that used to be hard-coded).
_DEFAULT_PAD_DIM = 10


def _metric_bucket(metric_root):
    """Return which per-task log ('history', 'landscape', 'performance') holds a scalar metric."""
    if any(k in metric_root for k in ('kappa', 'cos_sim', 'accuracy')):
        return 'history'
    if any(k in metric_root for k in ('sharpness', 'grad_norm', 'energy_frac')):
        return 'landscape'
    return 'performance'


def _infer_vector_dim(results, num_tasks, metric_root, default):
    """Return the length of the first logged vector of `metric_root`, or `default` if none."""
    for t in range(num_tasks):
        if t not in results:
            continue
        history = results[t].get(metric_root, [])
        if len(history) > 0:
            try:
                return len(history[0])
            except TypeError:  # scalar entries: nothing to infer from
                return default
    return default


def _history_means(entry, metric_root):
    """Average one history entry (a list of per-sample dicts) into (stale, fresh) scalars."""
    if not entry:
        # Nothing logged at this step; previously crashed on entry[0].
        return np.nan, np.nan
    if metric_root in entry[0]:
        stale = fresh = [s.get(metric_root, np.nan) for s in entry]
    else:
        # Metrics logged in two variants: '_acc' feeds stale, '_res' feeds fresh.
        stale = [s.get(f"{metric_root}_acc", np.nan) for s in entry]
        fresh = [s.get(f"{metric_root}_res", np.nan) for s in entry]
    if metric_root == 'kappa_loss':
        # Compare magnitudes only; NaNs pass through unchanged.
        stale = [abs(x) if not np.isnan(x) else np.nan for x in stale]
        fresh = [abs(x) if not np.isnan(x) else np.nan for x in fresh]
    return np.mean(stale), np.mean(fresh)


def _performance_value(entry, metric_root):
    """Read `metric_root` from a performance entry, flat ('a_b') or nested (entry['a']['b'])."""
    if metric_root in entry:
        return entry.get(metric_root)
    if '_' in metric_root:
        top_key, sub_key = metric_root.split('_', 1)
        nested = entry.get(top_key)
        # A missing or non-dict top-level value means the metric is absent.
        if isinstance(nested, dict):
            return nested.get(sub_key, np.nan)
    return np.nan


def extract_trajectory_series(results, config, metric_root, pad_dim=None):
    """Flatten one metric from a run's per-task logs into a global-step series.

    Args:
        results: Mapping from task index (int) to that task's log dict. Task
            indices absent from `results` are skipped.
        config: Run config. Reads ``num_steps`` and
            ``environment_args.num_tasks``. Landscape metrics also read
            ``log_every_n_steps`` and ``landscape_interval`` (each defaults
            to 1).
        metric_root: Metric name. Vector metrics (see ``_VECTOR_METRICS``)
            are read from ``task_data[metric_root]``. Other names are routed
            by substring to the 'history', 'landscape' or 'performance' log.
        pad_dim: Length of the zero vectors used to pad task 0 when a vector
            metric was not logged for it. ``None`` (default) infers it from
            the first logged vector, falling back to 10.

    Returns:
        (steps, stale, fresh): ``steps`` is a numpy array of global steps
        (within-task step + task_idx * num_steps). ``stale`` and ``fresh``
        are object arrays of the matching values. They differ only for
        history metrics logged with ``_acc``/``_res`` variants; otherwise
        both hold the same value.

        The within-task step convention depends on the source. History and
        vector metrics use 1-based log position. Performance metrics use the
        logged ``'step'``. Landscape metrics use 0-based position times the
        landscape stride.
    """
    num_steps = config['num_steps']
    num_tasks = config['environment_args']['num_tasks']
    steps_axis, vals_stale, vals_fresh = [], [], []

    is_vector = metric_root in _VECTOR_METRICS
    bucket = None if is_vector else _metric_bucket(metric_root)
    if is_vector and pad_dim is None:
        pad_dim = _infer_vector_dim(results, num_tasks, metric_root, _DEFAULT_PAD_DIM)
    # Landscape metrics are logged every (log_every_n_steps * landscape_interval) steps.
    landscape_stride = config.get('log_every_n_steps', 1) * config.get('landscape_interval', 1)

    for t in range(num_tasks):
        if t not in results:
            continue
        task_data = results[t]
        offset = t * num_steps

        if is_vector:
            history = task_data.get(metric_root, [])
            if t == 0 and not history:
                # Task 0 has no vector log: pad with zeros (assumed to be one
                # per step) so the trajectory still starts at step 1.
                for i in range(num_steps):
                    steps_axis.append(i + 1 + offset)
                    vals_stale.append([0.0] * pad_dim)
                    vals_fresh.append([0.0] * pad_dim)
            else:
                for i, step_values in enumerate(history):
                    steps_axis.append(i + 1 + offset)
                    vals_stale.append(step_values)
                    vals_fresh.append(step_values)
            continue

        for i, entry in enumerate(task_data.get(bucket, [])):
            if bucket == 'performance':
                step = entry.get('step', 0) + offset
                v_s = v_f = _performance_value(entry, metric_root)
            elif bucket == 'landscape':
                step = i * landscape_stride + offset
                v_s = v_f = entry.get(metric_root, np.nan)
            else:  # history
                step = i + 1 + offset
                v_s, v_f = _history_means(entry, metric_root)

            steps_axis.append(step)
            vals_stale.append(v_s)
            vals_fresh.append(v_f)

    return np.array(steps_axis), np.array(vals_stale, dtype=object), np.array(vals_fresh, dtype=object)

In [18]:
def collect_results_to_df(filtered_index, metrics_to_collect):
    """Load each run's metrics and flatten them into one long-format DataFrame.

    Args:
        filtered_index: Iterable of index records, each with at least
            ``'metrics_path'`` and ``'config'`` (as produced by
            ``get_experiment_index``). Runs whose metrics file does not exist
            are skipped silently. Runs whose file fails to load are reported
            and skipped.
        metrics_to_collect: Metric names passed to
            ``extract_trajectory_series``; each becomes one column.

    Returns:
        pd.DataFrame with two rows (``approx_type`` 'stale' and 'fresh') per
        (run, global step). The steps are the union over all collected
        metrics. Each row carries the run's hyperparameter metadata,
        ``step``, ``task_idx`` and one column per metric. A metric's cell is
        NaN where it was not logged at that step. The frame is empty if no
        run yields any data.
    """
    all_rows = []

    for exp in filtered_index:
        metrics_path = exp['metrics_path']
        if not os.path.exists(metrics_path):
            continue
        try:
            # SECURITY: weights_only=False makes torch.load unpickle arbitrary
            # objects, which can execute code. Only load metrics files
            # produced by your own runs.
            results_data = torch.load(metrics_path, map_location='cpu', weights_only=False)
        except Exception as exc:
            # Report instead of silently skipping; unlike a bare `except:`,
            # this does not trap KeyboardInterrupt/SystemExit.
            print(f"Failed to load metrics from {metrics_path} ({exc!r}), skipping...")
            continue

        cfg = exp['config']
        num_steps = cfg['num_steps']
        # Tolerate configs without an optimizer section instead of raising KeyError.
        optimizer_cfg = cfg.get('optimizer') or {}

        # Run-level hyperparameters copied onto every row of this run.
        metadata = {
            'seed': cfg.get('seed'),
            'lr': optimizer_cfg.get('lr'),
            'alpha': cfg.get('alpha', 1.0),
            'batch_size': cfg.get('batch_size'),
            'momentum': optimizer_cfg.get('momentum'),
            'mode': cfg.get('training_mode'),
            'reg_type': cfg.get('reg_type'),
            'curvature': cfg.get('curvature_type', 'fisher'),
            'strategy': 'accumulate' if cfg.get('accumulate') else 'refresh',
            'ignore_gradient': cfg.get('ignore_gradient', False),
            'spectral_override': cfg.get('spectral_override', False),
            # Run directory name, for tracing a row back to its files.
            'run_slug': os.path.basename(os.path.dirname(metrics_path)),
        }

        # 1. Extract each metric as {global_step: (stale, fresh)}. Keying by
        #    step lets metrics logged at different steps share one step axis.
        extracted_lookup = {}
        all_unique_steps = set()
        for m_root in metrics_to_collect:
            steps, v_s, v_f = extract_trajectory_series(results_data, cfg, m_root)
            if len(steps) > 0:
                extracted_lookup[m_root] = dict(zip(steps, zip(v_s, v_f)))
                all_unique_steps.update(steps)

        if not all_unique_steps:
            continue

        # 2. Task index per global step, from the 'task' field of the
        #    performance logs (falling back to the log's own task key).
        step_to_task_id = {}
        for t_idx, task_data in results_data.items():
            if not isinstance(t_idx, int):
                continue
            offset = t_idx * num_steps
            for entry in task_data.get('performance', []):
                step_to_task_id[entry.get('step', 0) + offset] = entry.get('task', t_idx)

        # 3. One row per (step, approx_type); missing metrics become NaN.
        for step in sorted(all_unique_steps):
            # NOTE(review): steps without a performance entry fall back to
            # step // num_steps. History/vector steps are 1-based within a
            # task, so the last such step of task t (== (t+1) * num_steps)
            # maps to t+1. Verify if per-task breakdowns look shifted.
            current_task_id = step_to_task_id.get(step, int(step // num_steps))
            for approx in ('stale', 'fresh'):
                pick = 0 if approx == 'stale' else 1
                row = {**metadata, 'step': step, 'approx_type': approx, 'task_idx': current_task_id}
                for m_root in metrics_to_collect:
                    vals = extracted_lookup.get(m_root, {}).get(step)
                    row[m_root] = vals[pick] if vals is not None else np.nan
                all_rows.append(row)

    return pd.DataFrame(all_rows)

In [26]:
# Build the experiment index once; later cells filter and aggregate from it.
full_index = get_experiment_index(root_dir=ROOT_DIR)
print("Total experiments found: {}".format(len(full_index)))

170
Indexing 170 experiments...
Total experiments found: 170


In [28]:
# Metrics pulled from every run. 'accuracy' is routed to the per-step history
# log; the test_* metrics come from the performance log.
sweep_metrics = ['test_avg_past', 'test_current', 'test_avg_total', 'accuracy']

# Long-format frame: one row per (run, step, approx_type).
df_full = collect_results_to_df(full_index, sweep_metrics)
assert not df_full.empty, "No metrics collected - check ROOT_DIR and metrics.pt files"

# 1. Hyperparameters that identify one configuration. Runs that share these
#    values and differ only in seed are averaged together.
# NOTE(review): mode/curvature/strategy are not part of the key; if they vary
# within this sweep, distinct setups get pooled into one row.
config_cols = ['reg_type', 'alpha', 'momentum', 'lr', 'batch_size']
metrics_to_show = ['test_avg_total', 'test_avg_past', 'test_current']

# 2. Work on a copy so df_full keeps its original NaNs (plain assignment would
#    alias it and the fillna below would mutate it). Missing config values
#    become 'NA' so groupby, which drops NaN keys by default, keeps those runs.
df_clean = df_full.copy()
df_clean[config_cols] = df_clean[config_cols].fillna('NA')

# 3. Final state per (config, seed): the latest step at which at least one
#    displayed metric was logged. Taking the max over *all* steps can land on
#    a step that only has 'accuracy' (history log), whose test_* are NaN.
#    idxmax returns the first row at that step; the test_* metrics hold the
#    same value in the 'stale' and 'fresh' rows.
# NOTE(review): runs whose seed is missing (NaN) are dropped by this groupby.
df_logged = df_clean.dropna(subset=metrics_to_show, how='all')
last_step_indices = df_logged.groupby(config_cols + ['seed'])['step'].idxmax()
df_final_per_seed = df_logged.loc[last_step_indices]

# 4. Mean and std of those final states across seeds, per configuration
#    (std is NaN when a configuration has a single seed).
sweep_summary = df_final_per_seed.groupby(config_cols)[metrics_to_show].agg(['mean', 'std'])

# Flatten the MultiIndex columns, e.g. ('test_current', 'mean') -> 'test_current_mean'.
sweep_summary.columns = [f"{m}_{s}" for m, s in sweep_summary.columns]
sweep_summary = sweep_summary.reset_index()

# 5. Rank by mean performance on past tasks (how well forgetting is avoided).
top_configs = sweep_summary.sort_values('test_avg_past_mean', ascending=False)

print(f"Aggregated {len(sweep_summary)} hyperparameter combinations.")
display(top_configs.head(20))

Aggregated 54 hyperparameter combinations.


Unnamed: 0,reg_type,alpha,momentum,lr,batch_size,test_avg_total_mean,test_avg_total_std,test_avg_past_mean,test_avg_past_std,test_current_mean,test_current_std
17,taylor-full,0.05,0.9,0.0008,32,0.888889,0.021999,0.89625,0.005303,0.83,0.155563
16,taylor-full,0.05,0.9,0.0008,16,0.884444,0.021999,0.8925,0.003536,0.82,0.169706
25,taylor-full,0.1,0.9,0.001,32,0.88,0.028284,0.8875,0.010607,0.82,0.169706
18,taylor-full,0.05,0.9,0.002,16,0.883333,0.020428,0.885,0.010607,0.87,0.098995
24,taylor-full,0.1,0.9,0.001,16,0.875556,0.021999,0.88375,0.005303,0.81,0.155563
19,taylor-full,0.05,0.9,0.002,32,0.883333,0.017285,0.88375,0.008839,0.88,0.084853
2,taylor-block,0.1,0.9,0.0008,32,0.846667,,0.865,,0.7,
15,taylor-full,0.01,0.9,0.003,32,0.866667,0.0,0.86,0.007071,0.92,0.056569
1,taylor-block,0.05,0.9,0.0008,32,0.844444,,0.8575,,0.74,
20,taylor-full,0.05,0.9,0.003,16,0.852222,0.001571,0.84875,0.012374,0.88,0.113137
