# Transfer Learning: Comprehensive Analysis with Extended Training

**Comprehensive multi-seed experiment with extended epochs and multiple accuracy thresholds**

## Key Features:
- ✅ Extended training: 30K epochs for baseline (to capture grokking)
- ✅ Multiple accuracy thresholds: 90%, 95%, 99%, **99.9%**
- ✅ Grokking moment detection
- ✅ Google Drive persistence
- ✅ Statistical analysis across N seeds
- ✅ Publication-quality visualizations

## Research Question:
Does transfer from a grokked addition model enable rapid generalization on subtraction,
bypassing the typical 10K+ epoch "grokking delay"?

**Expected Runtime:** ~6-8 hours on GPU (5 seeds × 2 conditions with extended epochs)

## Configuration

In [None]:
# ========== EXPERIMENT CONFIGURATION ==========

# Random seeds for statistical robustness: one transfer run and one baseline run per seed.
SEEDS = [42, 123, 456, 789, 1024]
# Derived from SEEDS so the two can never disagree: the experiment loops iterate over
# SEEDS, while summaries and statistics report "x/NUM_SEEDS".
NUM_SEEDS = len(SEEDS)

# Training epochs per condition
TRANSFER_EPOCHS = 10000   # Transfer converges quickly but we track to 10K for complete picture
BASELINE_EPOCHS = 30000   # Baseline needs time to grok (typically 5K-20K epochs)

# Test-accuracy thresholds to track (the first epoch each one is reached is recorded)
THRESHOLDS = [0.90, 0.95, 0.99, 0.999]  # 90%, 95%, 99%, 99.9%

# Source checkpoint (grokked addition model), relative to the cloned repository root
CHECKPOINT_PATH = 'saved_runs/wd_10-1_mod_addition_loss_curve.pth'

# Save checkpoints every N epochs (for analysis).
# NOTE(review): no intermediate model snapshots are written to disk; only each run's
# final results (including its final weights) are saved.
SAVE_CHECKPOINT_EVERY = 1000

print("="*80)
print("COMPREHENSIVE TRANSFER LEARNING EXPERIMENT")
print("="*80)
print(f"\nSeeds: {SEEDS} (N={NUM_SEEDS})")
print(f"Transfer epochs: {TRANSFER_EPOCHS:,}")
print(f"Baseline epochs: {BASELINE_EPOCHS:,}")
print(f"Accuracy thresholds: {[f'{t:.1%}' for t in THRESHOLDS]}")
print(f"\nEstimated runtime: ~6-8 hours on GPU")
print(f"Total experiments: {NUM_SEEDS * 2} ({NUM_SEEDS} seeds Ã— 2 conditions)")

## Setup: Mount Google Drive

In [None]:
from google.colab import drive
import os
from datetime import datetime
import time

# Mount Google Drive so results persist beyond the Colab session.
drive.mount('/content/drive')

# Each run writes into its own timestamped directory under DRIVE_BASE.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
DRIVE_BASE = '/content/drive/MyDrive/grokking_transfer_experiments'
EXPERIMENT_DIR = f'{DRIVE_BASE}/comprehensive_run_{timestamp}'

os.makedirs(EXPERIMENT_DIR, exist_ok=True)
for subdir in ('figures', 'checkpoints', 'results'):
    os.makedirs(f'{EXPERIMENT_DIR}/{subdir}', exist_ok=True)

print(f"âœ“ Google Drive mounted")
print(f"âœ“ Experiment directory: {EXPERIMENT_DIR}")

# Persist the run configuration alongside the results (also embedded in later outputs).
import json
config_dict = dict(
    num_seeds=NUM_SEEDS,
    seeds=SEEDS,
    transfer_epochs=TRANSFER_EPOCHS,
    baseline_epochs=BASELINE_EPOCHS,
    thresholds=THRESHOLDS,
    checkpoint_path=CHECKPOINT_PATH,
    timestamp=timestamp,
)
config_path = f'{EXPERIMENT_DIR}/config.json'
with open(config_path, 'w') as config_file:
    json.dump(config_dict, config_file, indent=2)
print(f"âœ“ Saved configuration")

## Setup: Clone Repo and Install Dependencies

In [None]:
# Clone repository
if not os.path.exists('progress-measures-paper-extension'):
    !git clone https://github.com/Junekhunter/progress-measures-paper-extension.git
os.chdir('progress-measures-paper-extension')
print(f"Working directory: {os.getcwd()}")

# Install dependencies
!pip install -q einops

# Imports
import sys
sys.path.insert(0, '.')

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import replace
import random
from pathlib import Path
from tqdm import tqdm
from collections import defaultdict

from transformers import Transformer, Config, gen_train_test, full_loss
import helpers

# Report library versions and the compute device that will be used.
print("\nâœ“ All imports successful!")
print(f"PyTorch: {torch.__version__}")
has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if has_cuda:
    print(f"Device: {torch.cuda.get_device_name(0)}")

## Load Grokked Addition Model

In [None]:
# Load the grokked addition checkpoint onto the CPU first; the model is moved to
# addition_config.device after it is constructed.
print(f"Loading checkpoint: {CHECKPOINT_PATH}")
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if this checkpoint
# contains arbitrary pickled objects, loading may fail there — confirm against the file.
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')

# Heuristic sanity check: a grokked model should finish with a near-zero test loss.
if 'test_losses' in checkpoint:
    final_test_loss = checkpoint['test_losses'][-1]
    if final_test_loss < 0.01:
        print(f"âœ“ Model is FULLY GROKKED (test loss: {final_test_loss:.6f})")
    else:
        print(f"âš  Warning: Model may not be fully grokked (test loss: {final_test_loss:.6f})")
else:
    print("Warning: checkpoint has no 'test_losses'; cannot verify that the model is grokked")

# Shape-determining hyperparameters (p, d_model, ...) must match the checkpoint, otherwise
# load_state_dict below raises. The training fields presumably mirror the source run.
addition_config = Config(
    lr=1e-3,
    weight_decay=1.0,
    p=113,
    d_model=128,
    fn_name='add',
    frac_train=0.3,
    num_epochs=50000,
    seed=0,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
)

# Build the model, then overwrite its random initialization with the grokked weights.
grokked_addition_model = Transformer(addition_config, use_cache=False)
grokked_addition_model.to(addition_config.device)

if 'model' in checkpoint:
    state_dict = checkpoint['model']
elif 'state_dicts' in checkpoint:
    # A history of snapshots; take the last one (presumably the final, grokked weights).
    state_dict = checkpoint['state_dicts'][-1]
else:
    # Fail loudly: continuing would make the "transfer" condition start from random
    # weights, i.e. silently turn it into a second baseline.
    raise KeyError(
        f"Checkpoint {CHECKPOINT_PATH!r} has neither a 'model' nor a 'state_dicts' entry; "
        f"available keys: {list(checkpoint.keys())}"
    )
grokked_addition_model.load_state_dict(state_dict)

print("âœ“ Grokked addition model loaded successfully!")

## Training Function with Multiple Thresholds

In [None]:
def train_with_comprehensive_tracking(model, config, num_epochs, condition_name, seed_label,
                                      thresholds=None, window_size=100, grokking_drop=1.0):
    """
    Train a model and track test-accuracy thresholds plus a simple grokking heuristic.

    Each epoch takes one optimizer step on ``full_loss`` over the training split, then
    evaluates loss and accuracy on the whole test split.

    Args:
        model: Transformer model (moved to ``config.device`` and put in train mode).
        config: Config object. ``lr``, ``weight_decay``, ``device`` and ``fn`` are read
            here; it is also passed to ``gen_train_test`` and ``full_loss``.
        num_epochs: Number of training epochs (must be >= 1).
        condition_name: 'transfer' or 'baseline' (progress-bar text and result tag).
        seed_label: Label for progress bar.
        thresholds: Test-accuracy thresholds to track. Defaults to the global THRESHOLDS.
        window_size: Number of most recent test losses inspected by the grokking detector.
        grokking_drop: Minimum drop in mean test loss (earlier part of the window minus
            the newest 10 epochs) that is reported as grokking.

    Returns:
        Dictionary with per-epoch loss/accuracy curves, the first (0-based) epoch at
        which each threshold was reached (None if never), the grokking epoch (or None),
        final metrics, wall-clock training time, and a detached CPU copy of the final
        ``state_dict``.

    Raises:
        ValueError: if ``num_epochs < 1`` or ``window_size`` is not larger than 10.
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be at least 1, got {num_epochs}")
    if thresholds is None:
        thresholds = THRESHOLDS
    recent_len = 10  # newest epochs compared against the rest of the grokking window
    if window_size <= recent_len:
        raise ValueError(f"window_size must be larger than {recent_len}, got {window_size}")
    top_threshold = max(thresholds)  # gets an extra log line when first reached

    model.to(config.device)
    model.train()

    optimizer = optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay, betas=(0.9, 0.98))
    # Learning-rate factor ramps linearly from 0 to 1 over the first 10 scheduler steps.
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda step: min(step/10, 1))

    train_data, test_data = gen_train_test(config)

    # The test split never changes, so its input and label tensors are built once here
    # instead of every epoch (building the labels calls config.fn once per test example).
    test_tensor = torch.tensor(test_data).to(config.device)
    test_labels = torch.tensor([config.fn(i, j) for i, j, _ in test_data]).to(config.device)

    train_losses = []
    test_losses = []
    test_accuracies = []

    # First epoch (0-based) at which each threshold is reached; None until then.
    threshold_epochs = {t: None for t in thresholds}
    grokking_epoch = None

    start_time = time.time()
    pbar = tqdm(range(num_epochs), desc=f"{condition_name.capitalize()} {seed_label}")

    for epoch in pbar:
        # One optimization step on the training split.
        train_loss = full_loss(config, model, train_data)
        train_loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        # Evaluate on the full test split after the step.
        with torch.no_grad():
            test_loss = full_loss(config, model, test_data)
            # Predictions come from the logits at the last sequence position.
            predictions = model(test_tensor)[:, -1].argmax(dim=-1)
            test_accuracy = (predictions == test_labels).float().mean().item()

        train_loss_value = train_loss.item()
        test_loss_value = test_loss.item()
        train_losses.append(train_loss_value)
        test_losses.append(test_loss_value)
        test_accuracies.append(test_accuracy)

        for threshold in thresholds:
            if threshold_epochs[threshold] is None and test_accuracy >= threshold:
                threshold_epochs[threshold] = epoch
                if threshold == top_threshold:
                    print(f"\nðŸŽ¯ Reached {threshold:.1%} accuracy at epoch {epoch}!")

        # Grokking heuristic: the mean test loss over the newest `recent_len` epochs
        # (including this one) is more than `grokking_drop` below the mean of the
        # earlier part of the `window_size`-epoch window.
        if grokking_epoch is None and len(test_losses) >= window_size:
            earlier_avg = np.mean(test_losses[-window_size:-recent_len])
            newest_avg = np.mean(test_losses[-recent_len:])
            drop = earlier_avg - newest_avg
            if drop > grokking_drop:
                grokking_epoch = epoch
                print(f"\nâš¡ Grokking detected at epoch {epoch}! (test loss dropped by {drop:.2f})")

        if epoch % 100 == 0:
            pbar.set_postfix({
                'acc': f'{test_accuracy:.4f}',
                'train_loss': f'{train_loss_value:.4f}',
                'test_loss': f'{test_loss_value:.4f}'
            })

    training_time = time.time() - start_time

    return {
        'train_losses': train_losses,
        'test_losses': test_losses,
        'test_accuracies': test_accuracies,
        'threshold_epochs': threshold_epochs,
        'grokking_epoch': grokking_epoch,
        'final_test_accuracy': test_accuracies[-1],
        'final_train_loss': train_losses[-1],
        'final_test_loss': test_losses[-1],
        'training_time': training_time,
        # Detached CPU copy, so stored results don't pin device memory after the
        # model object is deleted.
        'model_state': {k: v.detach().cpu().clone() for k, v in model.state_dict().items()},
        'seed': config.seed,
        'condition': condition_name,
        'num_epochs': num_epochs
    }

print("âœ“ Training function defined!")

## Run Comprehensive Multi-Seed Experiments

In [None]:
def _seed_everything(seed):
    """Seed Python's `random`, NumPy and PyTorch so model initialization depends only on `seed`."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # also seeds every CUDA device


def _print_run_summary(condition_label, results, seed, show_grokking=False):
    """Print final accuracy, per-threshold epochs, optional grokking epoch and time for one run."""
    print(f"\nâœ“ {condition_label} Results (seed {seed}):")
    print(f"  Final accuracy: {results['final_test_accuracy']:.4f}")
    for threshold in THRESHOLDS:
        reached_at = results['threshold_epochs'][threshold]
        if reached_at is not None:
            print(f"  {threshold:.1%} at epoch: {reached_at}")
        else:
            print(f"  {threshold:.1%}: NOT REACHED")
    if show_grokking and results['grokking_epoch'] is not None:
        print(f"  âš¡ Grokking at epoch: {results['grokking_epoch']}")
    print(f"  Training time: {results['training_time']/60:.1f} min")


all_transfer_results = []
all_baseline_results = []

experiment_start = time.time()

print("\n" + "="*80)
print(f"STARTING COMPREHENSIVE EXPERIMENTS")
print(f"Total experiments: {NUM_SEEDS * 2}")
print(f"Estimated time: 6-8 hours")
print("="*80)

for i, seed in enumerate(SEEDS):
    print(f"\n{'='*80}")
    print(f"SEED {i+1}/{NUM_SEEDS}: {seed}")
    print(f"{'='*80}")

    # Same hyperparameters as the addition model; only the task and the seed change.
    subtraction_config = replace(
        addition_config,
        fn_name='subtract',
        seed=seed
    )

    # ===== TRANSFER LEARNING =====
    print(f"\nðŸ”„ TRANSFER LEARNING (seed {seed})")
    print("-" * 80)

    _seed_everything(seed)
    transfer_model = Transformer(subtraction_config, use_cache=False)
    # Start from the grokked addition weights rather than the random initialization.
    transfer_model.load_state_dict(grokked_addition_model.state_dict())
    transfer_model.to(subtraction_config.device)

    transfer_results = train_with_comprehensive_tracking(
        transfer_model,
        subtraction_config,
        TRANSFER_EPOCHS,
        'transfer',
        f"seed {seed}"
    )
    all_transfer_results.append(transfer_results)
    _print_run_summary('Transfer', transfer_results, seed)
    torch.save(transfer_results, f"{EXPERIMENT_DIR}/checkpoints/transfer_seed{seed}_comprehensive.pth")

    # ===== BASELINE =====
    print(f"\nðŸŽ² BASELINE (seed {seed})")
    print("-" * 80)

    # Re-seed so the baseline's random initialization depends only on `seed`, not on
    # how many random numbers the transfer run happened to consume.
    _seed_everything(seed)
    baseline_model = Transformer(subtraction_config, use_cache=False)
    baseline_model.to(subtraction_config.device)

    baseline_results = train_with_comprehensive_tracking(
        baseline_model,
        subtraction_config,
        BASELINE_EPOCHS,
        'baseline',
        f"seed {seed}"
    )
    all_baseline_results.append(baseline_results)
    _print_run_summary('Baseline', baseline_results, seed, show_grokking=True)
    torch.save(baseline_results, f"{EXPERIMENT_DIR}/checkpoints/baseline_seed{seed}_comprehensive.pth")

    # Free GPU memory before the next seed
    del transfer_model, baseline_model
    torch.cuda.empty_cache()

    # Progress update (each seed runs one transfer and one baseline experiment)
    experiments_done = (i + 1) * 2
    total_experiments = NUM_SEEDS * 2
    elapsed = time.time() - experiment_start
    estimated_total = elapsed / experiments_done * total_experiments
    remaining = estimated_total - elapsed

    print(f"\nðŸ“Š Progress: {experiments_done}/{total_experiments} experiments complete")
    print(f"   Elapsed: {elapsed/3600:.1f}h | Remaining: ~{remaining/3600:.1f}h")

total_time = time.time() - experiment_start
print(f"\n{'='*80}")
print("âœ… ALL EXPERIMENTS COMPLETE!")
print(f"Total time: {total_time/3600:.2f} hours")
print(f"{'='*80}")

## Aggregate Results and Statistical Analysis

In [None]:
from scipy import stats

def calculate_stats(values):
    """
    Summarize per-seed values, ignoring ``None`` entries (threshold not reached).

    Returns:
        Tuple ``(mean, std, ci_low, ci_high, n)``: sample mean, sample standard
        deviation (ddof=1; 0.0 for a single value), two-sided 95% Student-t confidence
        interval for the mean, and the number of non-None values. Returns
        ``(None, None, None, None, 0)`` when no value is present.
    """
    present = [v for v in values if v is not None]
    n = len(present)
    if n == 0:
        return None, None, None, None, 0

    mean = float(np.mean(present))
    if n == 1:
        # A single observation has no estimable spread.
        return mean, 0.0, mean, mean, 1

    std = float(np.std(present, ddof=1))
    sem = std / np.sqrt(n)
    if sem == 0:
        # All values identical: scipy's t.interval needs scale > 0 and would give NaN.
        return mean, std, mean, mean, n

    ci_low, ci_high = stats.t.interval(0.95, n - 1, loc=mean, scale=sem)
    return mean, std, float(ci_low), float(ci_high), n

# Aggregate statistics for each threshold
print("\n" + "="*80)
print("COMPREHENSIVE STATISTICAL ANALYSIS")
print("="*80)


def _sample_std(values):
    """Sample standard deviation (ddof=1, matching the threshold statistics); 0.0 if n < 2."""
    return float(np.std(values, ddof=1)) if len(values) > 1 else 0.0


stats_dict = {
    'config': config_dict,
    'thresholds': {}
}

for threshold in THRESHOLDS:
    print(f"\nðŸ“Š {threshold:.1%} Accuracy Threshold:")
    print("-" * 80)

    # First (0-based) epoch each seed reached the threshold; None means never reached.
    transfer_epochs = [r['threshold_epochs'][threshold] for r in all_transfer_results]
    baseline_epochs = [r['threshold_epochs'][threshold] for r in all_baseline_results]

    # Statistics cover only the seeds that reached the threshold.
    t_mean, t_std, t_ci_low, t_ci_high, t_n = calculate_stats(transfer_epochs)
    b_mean, b_std, b_ci_low, b_ci_high, b_n = calculate_stats(baseline_epochs)

    if t_mean is not None:
        print(f"  Transfer:  {t_mean:.1f} Â± {t_std:.1f} epochs (N={t_n}/{NUM_SEEDS} reached)")
        print(f"             95% CI: [{t_ci_low:.1f}, {t_ci_high:.1f}]")
    else:
        print(f"  Transfer:  NOT REACHED by any seed")

    if b_mean is not None:
        print(f"  Baseline:  {b_mean:.1f} Â± {b_std:.1f} epochs (N={b_n}/{NUM_SEEDS} reached)")
        print(f"             95% CI: [{b_ci_low:.1f}, {b_ci_high:.1f}]")
    else:
        print(f"  Baseline:  NOT REACHED by any seed")

    # Speedup = ratio of mean epochs-to-threshold, defined only when both conditions have a mean.
    speedup = None
    if t_mean is not None and b_mean is not None:
        if t_mean > 0:
            speedup = b_mean / t_mean
            print(f"\n  ðŸš€ Speedup: {speedup:.2f}x faster")
        else:
            # Every counted transfer seed hit the threshold at epoch 0: ratio undefined.
            print("\n  Speedup: undefined (transfer mean is 0 epochs)")
        improvement = b_mean - t_mean
        if b_mean > 0:
            print(f"  ðŸ“‰ Saved: {improvement:.1f} epochs ({improvement/b_mean*100:.1f}% reduction)")
        if t_n < NUM_SEEDS or b_n < NUM_SEEDS:
            # Runs that never reached the threshold are dropped (right-censored), which
            # biases the means toward faster runs.
            print(f"  Note: only seeds that reached {threshold:.1%} are included "
                  f"(transfer {t_n}/{NUM_SEEDS}, baseline {b_n}/{NUM_SEEDS}); speedup may be biased.")

    stats_dict['thresholds'][threshold] = {
        'transfer': {'mean': t_mean, 'std': t_std, 'ci': [t_ci_low, t_ci_high], 'n_reached': t_n, 'values': transfer_epochs},
        'baseline': {'mean': b_mean, 'std': b_std, 'ci': [b_ci_low, b_ci_high], 'n_reached': b_n, 'values': baseline_epochs},
        'speedup': speedup
    }

# Final accuracies (sample std, consistent with the threshold statistics above)
print(f"\nðŸ“Š Final Test Accuracy:")
print("-" * 80)
transfer_final = [r['final_test_accuracy'] for r in all_transfer_results]
baseline_final = [r['final_test_accuracy'] for r in all_baseline_results]
transfer_final_std = _sample_std(transfer_final)
baseline_final_std = _sample_std(baseline_final)

print(f"  Transfer: {np.mean(transfer_final):.4f} Â± {transfer_final_std:.4f}")
print(f"  Baseline: {np.mean(baseline_final):.4f} Â± {baseline_final_std:.4f}")

stats_dict['final_accuracy'] = {
    'transfer': {'mean': np.mean(transfer_final), 'std': transfer_final_std, 'values': transfer_final},
    'baseline': {'mean': np.mean(baseline_final), 'std': baseline_final_std, 'values': baseline_final}
}

# Grokking epochs for baseline
print(f"\nâš¡ Grokking Detection (Baseline only):")
print("-" * 80)
grokking_epochs = [r['grokking_epoch'] for r in all_baseline_results if r['grokking_epoch'] is not None]
if grokking_epochs:
    print(f"  Detected in {len(grokking_epochs)}/{NUM_SEEDS} baseline runs")
    print(f"  Mean grokking epoch: {np.mean(grokking_epochs):.1f} Â± {_sample_std(grokking_epochs):.1f}")
else:
    print(f"  No grokking moments detected (gradual learning)")

stats_dict['grokking'] = {
    'n_detected': len(grokking_epochs),
    'epochs': grokking_epochs
}

# Save statistics (float threshold keys become strings in JSON)
with open(f'{EXPERIMENT_DIR}/results/comprehensive_stats.json', 'w') as f:
    json.dump(stats_dict, f, indent=2)

print(f"\nâœ“ Saved comprehensive statistics")

## Visualization: Publication-Quality Figures

In [None]:
# Stack per-seed test-accuracy curves for plotting. Every transfer run has the same
# length (TRANSFER_EPOCHS), so those curves stack directly.
transfer_curves = np.array([run['test_accuracies'] for run in all_transfer_results])
transfer_mean, transfer_std = transfer_curves.mean(axis=0), transfer_curves.std(axis=0)

# Baseline curves are right-padded with their last value up to the LONGEST BASELINE run
# (a no-op when all used BASELINE_EPOCHS) so they stack into one 2-D array.
baseline_curves_full = [run['test_accuracies'] for run in all_baseline_results]
max_baseline_len = max(len(curve) for curve in baseline_curves_full)
baseline_curves_padded = np.array([
    list(curve) + [curve[-1]] * (max_baseline_len - len(curve))
    for curve in baseline_curves_full
])

# Population std (ddof=0) across seeds, used only for the shaded bands.
baseline_mean = baseline_curves_padded.mean(axis=0)
baseline_std = baseline_curves_padded.std(axis=0)

print("âœ“ Prepared data for visualization")

In [None]:
# Create comprehensive figure
fig = plt.figure(figsize=(20, 16))
gs = fig.add_gridspec(4, 3, hspace=0.3, wspace=0.3)

# Plot 1: Transfer Learning Progress
ax1 = fig.add_subplot(gs[0, :])
epochs_t = np.arange(len(transfer_mean))
ax1.plot(epochs_t, transfer_mean, 'b-', linewidth=2.5, label='Transfer Learning')
ax1.fill_between(epochs_t, transfer_mean - transfer_std, transfer_mean + transfer_std,
                 alpha=0.3, color='blue')
for threshold in THRESHOLDS:
    ax1.axhline(y=threshold, color='red', linestyle='--', alpha=0.4, linewidth=1)
    ax1.text(len(epochs_t)*0.02, threshold + 0.01, f'{threshold:.1%}', fontsize=9, color='red')
ax1.set_xlabel('Epoch', fontsize=13)
ax1.set_ylabel('Test Accuracy', fontsize=13)
ax1.set_title(f'Transfer Learning: Test Accuracy (N={NUM_SEEDS} seeds)', fontsize=15, fontweight='bold')
ax1.legend(fontsize=11)
ax1.grid(True, alpha=0.3)
ax1.set_ylim([0, 1.05])

# Plot 2: Baseline Progress
ax2 = fig.add_subplot(gs[1, :])
epochs_b = np.arange(len(baseline_mean))
ax2.plot(epochs_b, baseline_mean, 'orange', linewidth=2.5, label='Baseline (Random Init)')
ax2.fill_between(epochs_b, baseline_mean - baseline_std, baseline_mean + baseline_std,
                 alpha=0.3, color='orange')
for threshold in THRESHOLDS:
    ax2.axhline(y=threshold, color='red', linestyle='--', alpha=0.4, linewidth=1)
    ax2.text(len(epochs_b)*0.02, threshold + 0.01, f'{threshold:.1%}', fontsize=9, color='red')

# Mark grokking moments
for r in all_baseline_results:
    if r['grokking_epoch'] is not None:
        ax2.axvline(x=r['grokking_epoch'], color='green', linestyle=':', alpha=0.5, linewidth=1.5)

ax2.set_xlabel('Epoch', fontsize=13)
ax2.set_ylabel('Test Accuracy', fontsize=13)
ax2.set_title(f'Baseline (Random Init): Test Accuracy (N={NUM_SEEDS} seeds)', fontsize=15, fontweight='bold')
ax2.legend(fontsize=11)
ax2.grid(True, alpha=0.3)
ax2.set_ylim([0, 1.05])

# Plot 3: Box plots for each threshold
ax3 = fig.add_subplot(gs[2, 0])
box_data = []
labels = []
positions = []
pos = 0
for i, threshold in enumerate(THRESHOLDS):
    t_epochs = [e for e in stats_dict['thresholds'][threshold]['transfer']['values'] if e is not None]
    b_epochs = [e for e in stats_dict['thresholds'][threshold]['baseline']['values'] if e is not None]
    
    if len(t_epochs) > 0:
        box_data.append(t_epochs)
        labels.append(f"{threshold:.1%}\nTransfer")
        positions.append(pos)
        pos += 1
    
    if len(b_epochs) > 0:
        box_data.append(b_epochs)
        labels.append(f"{threshold:.1%}\nBaseline")
        positions.append(pos)
        pos += 1
    
    pos += 0.5  # Gap between thresholds

if len(box_data) > 0:
    bp = ax3.boxplot(box_data, positions=positions, labels=labels, patch_artist=True, showmeans=True)
    # Color boxes
    for i, box in enumerate(bp['boxes']):
        if 'Transfer' in labels[i]:
            box.set_facecolor('lightblue')
        else:
            box.set_facecolor('lightsalmon')

ax3.set_ylabel('Epochs to Threshold', fontsize=12)
ax3.set_title('Epochs to Reach Thresholds', fontsize=14, fontweight='bold')
ax3.grid(True, alpha=0.3, axis='y')
plt.setp(ax3.xaxis.get_majorticklabels(), rotation=45, ha='right', fontsize=9)

# Plot 4: Speedup factors
ax4 = fig.add_subplot(gs[2, 1])
speedups = []
speedup_labels = []
for threshold in THRESHOLDS:
    speedup = stats_dict['thresholds'][threshold]['speedup']
    if speedup is not None:
        speedups.append(speedup)
        speedup_labels.append(f"{threshold:.1%}")

if len(speedups) > 0:
    bars = ax4.bar(range(len(speedups)), speedups, color=['green', 'darkgreen', 'blue', 'darkblue'][:len(speedups)])
    ax4.set_xticks(range(len(speedups)))
    ax4.set_xticklabels(speedup_labels, fontsize=11)
    ax4.set_ylabel('Speedup Factor', fontsize=12)
    ax4.set_title('Transfer Learning Speedup', fontsize=14, fontweight='bold')
    ax4.grid(True, alpha=0.3, axis='y')
    ax4.axhline(y=1, color='red', linestyle='--', linewidth=1, alpha=0.5)
    
    # Add value labels on bars
    for i, (bar, val) in enumerate(zip(bars, speedups)):
        ax4.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                f'{val:.1f}x', ha='center', fontsize=10, fontweight='bold')

# Plot 5: Individual runs overlay (zoomed to 5000 epochs)
ax5 = fig.add_subplot(gs[2, 2])
zoom = min(5000, len(transfer_mean))
for i, r in enumerate(all_transfer_results):
    ax5.plot(r['test_accuracies'][:zoom], alpha=0.3, color='blue', linewidth=1)
for i, r in enumerate(all_baseline_results):
    ax5.plot(r['test_accuracies'][:zoom], alpha=0.3, color='orange', linewidth=1)
ax5.plot([], [], color='blue', label='Transfer', linewidth=2)
ax5.plot([], [], color='orange', label='Baseline', linewidth=2)
for threshold in THRESHOLDS:
    ax5.axhline(y=threshold, color='red', linestyle='--', alpha=0.3, linewidth=0.8)
ax5.set_xlabel('Epoch', fontsize=12)
ax5.set_ylabel('Test Accuracy', fontsize=12)
ax5.set_title(f'Individual Runs (First {zoom} Epochs)', fontsize=14, fontweight='bold')
ax5.legend(fontsize=11)
ax5.grid(True, alpha=0.3)

# Plot 6: Summary table
ax6 = fig.add_subplot(gs[3, :])
ax6.axis('off')
summary_lines = [
    "COMPREHENSIVE EXPERIMENT SUMMARY",
    "="*70,
    f"Seeds: {SEEDS} (N={NUM_SEEDS})",
    f"Transfer epochs: {TRANSFER_EPOCHS:,} | Baseline epochs: {BASELINE_EPOCHS:,}",
    "",
    "Epochs to Thresholds (mean Â± std):"
]

for threshold in THRESHOLDS:
    t_stats = stats_dict['thresholds'][threshold]['transfer']
    b_stats = stats_dict['thresholds'][threshold]['baseline']
    speedup = stats_dict['thresholds'][threshold]['speedup']
    
    summary_lines.append(f"  {threshold:.1%}:")
    if t_stats['mean'] is not None:
        summary_lines.append(f"    Transfer: {t_stats['mean']:.0f} Â± {t_stats['std']:.0f} ({t_stats['n_reached']}/{NUM_SEEDS} seeds)")
    else:
        summary_lines.append(f"    Transfer: NOT REACHED")
    
    if b_stats['mean'] is not None:
        summary_lines.append(f"    Baseline: {b_stats['mean']:.0f} Â± {b_stats['std']:.0f} ({b_stats['n_reached']}/{NUM_SEEDS} seeds)")
    else:
        summary_lines.append(f"    Baseline: NOT REACHED")
    
    if speedup is not None:
        summary_lines.append(f"    Speedup: {speedup:.1f}x")

summary_lines.extend([
    "",
    f"Final Test Accuracy:",
    f"  Transfer: {stats_dict['final_accuracy']['transfer']['mean']:.4f} Â± {stats_dict['final_accuracy']['transfer']['std']:.4f}",
    f"  Baseline: {stats_dict['final_accuracy']['baseline']['mean']:.4f} Â± {stats_dict['final_accuracy']['baseline']['std']:.4f}"
])

summary_text = "\n".join(summary_lines)
ax6.text(0.5, 0.5, summary_text, fontsize=10, family='monospace',
         verticalalignment='center', horizontalalignment='center',
         bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))

plt.savefig(f'{EXPERIMENT_DIR}/figures/comprehensive_results.png', dpi=200, bbox_inches='tight')
print(f"\nâœ“ Saved comprehensive figure")
plt.show()

## Save Complete Results

In [None]:
# Bundle the configuration, raw per-run results and aggregate statistics into one file.
complete_results = dict(
    config=config_dict,
    transfer_results=all_transfer_results,
    baseline_results=all_baseline_results,
    statistics=stats_dict,
    total_time_hours=total_time / 3600,
)
results_dir = f'{EXPERIMENT_DIR}/results'

torch.save(complete_results, f'{results_dir}/complete_results.pth')
print(f"âœ“ Saved complete results (.pth)")

# Curves as plain NumPy arrays, loadable without torch
np.savez(f'{results_dir}/curves.npz',
         transfer_curves=transfer_curves,
         baseline_curves_padded=baseline_curves_padded,
         seeds=np.array(SEEDS),
         thresholds=np.array(THRESHOLDS))
print(f"âœ“ Saved curves (.npz)")

banner = "="*80
print("\n" + banner)
print("âœ… COMPREHENSIVE EXPERIMENT COMPLETE!")
print(banner)
print(f"\nResults saved to: {EXPERIMENT_DIR}")
print(f"Total runtime: {total_time/3600:.2f} hours")
print(f"\nKey files:")
for key_file_line in (
    "  - figures/comprehensive_results.png (publication-quality figure)",
    "  - results/comprehensive_stats.json (summary statistics)",
    "  - results/complete_results.pth (full experimental data)",
    f"  - checkpoints/*.pth ({NUM_SEEDS * 2} model checkpoints)",
):
    print(key_file_line)