# Transfer Learning: Grokked Addition → Subtraction (Multi-Seed)

This notebook tests whether a grokked modular addition model can transfer to accelerate learning on modular subtraction.

**Enhancements:**
- ✅ Google Drive integration for persistent storage
- ✅ Multiple seeds for statistical robustness
- ✅ Aggregated results with mean/std/confidence intervals
- ✅ Comprehensive visualizations

**Experiment Plan:**
1. Load a fully grokked addition model (mod 113)
2. Run transfer learning experiments across N seeds
3. Run baseline experiments across N seeds
4. Aggregate and compare results statistically

## Step 0: Mount Google Drive

In [None]:
from google.colab import drive
import os
from datetime import datetime

# Mount Google Drive so experiment outputs are written to persistent storage.
drive.mount('/content/drive')

# Each run writes into its own timestamped directory under DRIVE_BASE, so
# separate runs do not mix their outputs.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
DRIVE_BASE = '/content/drive/MyDrive/grokking_transfer_experiments'
EXPERIMENT_DIR = f'{DRIVE_BASE}/run_{timestamp}'

os.makedirs(EXPERIMENT_DIR, exist_ok=True)
print("✓ Google Drive mounted")
print(f"✓ Experiment directory: {EXPERIMENT_DIR}")

# Subdirectories the later cells write into:
#   figures/     - saved plots
#   checkpoints/ - per-seed training results
#   results/     - aggregated statistics and curves
SUBDIRS = ('figures', 'checkpoints', 'results')
for subdir in SUBDIRS:
    os.makedirs(f'{EXPERIMENT_DIR}/{subdir}', exist_ok=True)
print(f"✓ Created subdirectories: {', '.join(SUBDIRS)}")

## Step 1: Setup - Clone Repository and Install Dependencies

In [None]:
import subprocess

REPO_NAME = 'progress-measures-paper-extension'
REPO_URL = 'https://github.com/Junekhunter/progress-measures-paper-extension.git'

# The chdir below means a re-run of this cell starts *inside* the checkout,
# where the existence check would fail and a second, nested copy would be
# cloned and entered. Skip clone + chdir when we are already in the repo.
if os.path.basename(os.getcwd()) != REPO_NAME:
    # Clone the repository if not already cloned
    if not os.path.exists(REPO_NAME):
        # check=True: a failed clone raises here instead of surfacing later
        # as a confusing chdir / import error.
        subprocess.run(['git', 'clone', REPO_URL], check=True)
    # Change to repo directory
    os.chdir(REPO_NAME)
print(f"Working directory: {os.getcwd()}")

In [None]:
# Install packages that may not be preinstalled on Colab (IPython shell magic;
# -q keeps pip quiet). NOTE(review): assumed to be needed by the cloned repo's
# modules imported in the next cell — confirm against the repo.
!pip install -q einops wandb

In [None]:
# Import necessary modules
import sys
# A relative '.' on sys.path is resolved against the current working directory
# at import time; after the clone cell's chdir that is the repo checkout, so the
# repo's own `transformers` and `helpers` modules are found first. This line
# must stay before the repo imports below.
# NOTE(review): if a different package named `transformers` was already
# imported earlier in this session, Python reuses the cached module from
# sys.modules and the repo import fails — restart the runtime in that case.
sys.path.insert(0, '.')

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass, replace
import random
from pathlib import Path
from tqdm import tqdm
import json
from collections import defaultdict

# Import from the repo
from transformers import Transformer, Config, gen_train_test, full_loss
import helpers

# Environment report: confirms imports worked and whether a GPU is usable.
print("All imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## Step 2: Configuration

In [None]:
# Experiment configuration
# One transfer run and one baseline run are trained per seed. NUM_SEEDS is
# derived from SEEDS so the two values can never disagree.
SEEDS = [42, 123, 456, 789, 1024]
NUM_SEEDS = len(SEEDS)  # Number of random seeds to run
if len(set(SEEDS)) != NUM_SEEDS:
    # Per-seed checkpoint files are named after the seed, so a duplicate would
    # overwrite an earlier run's file and be double-counted in the statistics.
    raise ValueError(f"SEEDS must be unique, got {SEEDS}")

NUM_EPOCHS = 5000  # Training epochs per experiment
CHECKPOINT_PATH = 'saved_runs/wd_10-1_mod_addition_loss_curve.pth'

print("Experiment Configuration:")
print(f"  - Number of seeds: {NUM_SEEDS}")
print(f"  - Seeds: {SEEDS}")
print(f"  - Epochs per run: {NUM_EPOCHS}")
print(f"  - Source checkpoint: {CHECKPOINT_PATH}")
print(f"  - Results will be saved to: {EXPERIMENT_DIR}")

## Step 3: Load Grokked Addition Model

In [None]:
# Load checkpoint and inspect
print(f"Loading checkpoint from {CHECKPOINT_PATH}...")
# map_location='cpu' lets the checkpoint load on machines without a GPU; the
# model is moved to the target device after construction below.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')

print(f"\nCheckpoint keys: {list(checkpoint.keys())}")

# Analyze the checkpoint: this notebook treats a final test loss below 0.01
# as "fully grokked".
if 'test_losses' in checkpoint:
    test_losses = checkpoint['test_losses']
    train_losses = checkpoint['train_losses']

    print(f"Total training epochs: {len(test_losses)}")

    final_test_loss = test_losses[-1]
    if final_test_loss < 0.01:
        print(f"✓ Model is FULLY GROKKED (final test loss: {final_test_loss:.6f})")
    else:
        print(f"✗ Model NOT fully grokked (final test loss: {final_test_loss:.6f})")

# Config for the addition model. p and d_model must match the checkpoint's
# architecture for load_state_dict to succeed; the later subtraction runs
# reuse every other field via dataclasses.replace.
# NOTE(review): lr / weight_decay / frac_train / num_epochs are assumed to
# mirror the source run — confirm against how the checkpoint was produced.
addition_config = Config(
    lr=1e-3,
    weight_decay=1.0,
    p=113,
    d_model=128,
    fn_name='add',
    frac_train=0.3,
    num_epochs=50000,
    seed=0,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
)

# Create model and load grokked weights
grokked_addition_model = Transformer(addition_config, use_cache=False)
grokked_addition_model.to(addition_config.device)

# Load the trained weights: either a single 'model' state dict, or the last
# entry of a 'state_dicts' history.
if 'model' in checkpoint:
    grokked_addition_model.load_state_dict(checkpoint['model'])
    print("✓ Loaded model from 'model' key")
elif 'state_dicts' in checkpoint:
    grokked_addition_model.load_state_dict(checkpoint['state_dicts'][-1])
    print(f"✓ Loaded model from 'state_dicts' (checkpoint {len(checkpoint['state_dicts'])-1})")
else:
    # Continuing would make every "transfer" run start from a randomly
    # initialised model, silently invalidating the whole experiment.
    raise KeyError(
        f"Could not find model weights in checkpoint {CHECKPOINT_PATH!r}: "
        f"expected a 'model' or 'state_dicts' key, found {list(checkpoint.keys())}"
    )

print("\n✓ Grokked addition model loaded successfully!")

## Step 4: Training Function

In [None]:
def train_subtraction_model(model, config, num_epochs=5000, verbose=True, seed_label=""):
    """
    Train a model on the subtraction task with full-batch AdamW.

    Each epoch is a single optimizer step on the whole training split, followed
    by an evaluation (loss and accuracy) on the whole test split.

    Args:
        model: Transformer model to train; moved to ``config.device`` in place.
        config: Config object with fn_name='subtract'; supplies lr,
            weight_decay, device and (via gen_train_test) the train/test split.
        num_epochs: Number of training epochs (optimizer steps); must be >= 1.
        verbose: Whether to show a tqdm progress bar.
        seed_label: Label for the progress bar.

    Returns:
        Dictionary with:
          - ``train_losses`` / ``test_losses`` / ``test_accuracies``: one entry per epoch;
          - ``epochs_to_{90,95,99}_percent``: first epoch index whose test
            accuracy reached the threshold, or ``None`` if it never did;
          - ``final_test_accuracy`` / ``final_train_loss`` / ``final_test_loss``;
          - ``model_state``: detached CPU copy of the trained weights;
          - ``seed``: ``config.seed``.

    Raises:
        ValueError: if ``num_epochs`` < 1 (there would be no final metrics).
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

    model.to(config.device)
    model.train()

    optimizer = optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay, betas=(0.9, 0.98))
    # Linear learning-rate warm-up over the first 10 steps, constant afterwards.
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda step: min(step / 10, 1))

    train_data, test_data = gen_train_test(config)

    # The test inputs and labels never change during training, so build the
    # tensors once instead of re-materialising them every epoch.
    test_tensor = torch.tensor(test_data).to(config.device)
    test_labels = torch.tensor([config.fn(i, j) for i, j, _ in test_data]).to(config.device)

    train_losses = []
    test_losses = []
    test_accuracies = []
    # First epoch at which test accuracy reached each threshold (None = not yet).
    milestones = {0.90: None, 0.95: None, 0.99: None}

    pbar = tqdm(range(num_epochs), desc=f"Training {seed_label}", disable=not verbose)

    for epoch in pbar:
        # Training step (full batch)
        train_loss = full_loss(config, model, train_data)
        train_loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        # Evaluation
        with torch.no_grad():
            test_loss = full_loss(config, model, test_data)
            # Prediction is read from the final sequence position.
            predictions = model(test_tensor)[:, -1].argmax(dim=-1)
            test_accuracy = (predictions == test_labels).float().mean().item()

        train_losses.append(train_loss.item())
        test_losses.append(test_loss.item())
        test_accuracies.append(test_accuracy)

        # Assigning to existing keys while iterating is safe (size is unchanged).
        for threshold, reached_at in milestones.items():
            if reached_at is None and test_accuracy >= threshold:
                milestones[threshold] = epoch

        # Update progress bar
        if epoch % 100 == 0:
            pbar.set_postfix({
                'train_loss': f'{train_loss.item():.4f}',
                'test_acc': f'{test_accuracy:.3f}'
            })

    return {
        'train_losses': train_losses,
        'test_losses': test_losses,
        'test_accuracies': test_accuracies,
        'epochs_to_90_percent': milestones[0.90],
        'epochs_to_95_percent': milestones[0.95],
        'epochs_to_99_percent': milestones[0.99],
        'final_test_accuracy': test_accuracies[-1],
        'final_train_loss': train_losses[-1],
        'final_test_loss': test_losses[-1],
        # state_dict() returns references to the live (possibly GPU) tensors;
        # keeping those would pin device memory after the model is deleted.
        # Store an independent CPU snapshot instead.
        'model_state': {name: tensor.detach().cpu().clone()
                        for name, tensor in model.state_dict().items()},
        'seed': config.seed
    }

print("✓ Training function defined!")

## Step 5: Run Multi-Seed Experiments

In [None]:
# Storage for all results
all_transfer_results = []
all_baseline_results = []


def _seed_everything(seed):
    """Seed Python's `random`, NumPy and PyTorch (all devices) with *seed*."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


print("="*80)
print(f"RUNNING {NUM_SEEDS} EXPERIMENTS")
print("="*80)

for i, seed in enumerate(SEEDS):
    print(f"\n{'='*80}")
    print(f"EXPERIMENT {i+1}/{NUM_SEEDS} - SEED {seed}")
    print(f"{'='*80}\n")

    # Same hyperparameters as the source model; only the task and seed change.
    subtraction_config = replace(
        addition_config,
        fn_name='subtract',
        seed=seed
    )

    # Putting the seed in the Config alone does not seed torch's global RNG,
    # which Transformer initialisation presumably draws from. Seed explicitly
    # before building each model so every run is reproducible, and so the
    # baseline's initialisation does not depend on how much RNG state the
    # transfer run consumed.

    # ========== Transfer Learning ==========
    print(f"\n🔄 Transfer Learning (Seed {seed})")
    print("-" * 80)

    _seed_everything(seed)
    transfer_model = Transformer(subtraction_config, use_cache=False)
    # Start from the grokked addition weights.
    transfer_model.load_state_dict(grokked_addition_model.state_dict())
    transfer_model.to(subtraction_config.device)

    transfer_results = train_subtraction_model(
        transfer_model,
        subtraction_config,
        num_epochs=NUM_EPOCHS,
        verbose=True,
        seed_label=f"Transfer (seed {seed})"
    )
    all_transfer_results.append(transfer_results)

    print(f"✓ Transfer - Seed {seed}:")
    print(f"  Final acc: {transfer_results['final_test_accuracy']:.4f}")
    print(f"  90% at epoch: {transfer_results['epochs_to_90_percent']}")

    # Save checkpoint
    torch.save(transfer_results, f"{EXPERIMENT_DIR}/checkpoints/transfer_seed{seed}.pth")

    # ========== Baseline ==========
    print(f"\n🎲 Baseline (Seed {seed})")
    print("-" * 80)

    _seed_everything(seed)
    baseline_model = Transformer(subtraction_config, use_cache=False)
    baseline_model.to(subtraction_config.device)

    baseline_results = train_subtraction_model(
        baseline_model,
        subtraction_config,
        num_epochs=NUM_EPOCHS,
        verbose=True,
        seed_label=f"Baseline (seed {seed})"
    )
    all_baseline_results.append(baseline_results)

    print(f"✓ Baseline - Seed {seed}:")
    print(f"  Final acc: {baseline_results['final_test_accuracy']:.4f}")
    print(f"  90% at epoch: {baseline_results['epochs_to_90_percent']}")

    # Save checkpoint
    torch.save(baseline_results, f"{EXPERIMENT_DIR}/checkpoints/baseline_seed{seed}.pth")

    # Free GPU memory
    del transfer_model, baseline_model
    torch.cuda.empty_cache()

print(f"\n{'='*80}")
print("ALL EXPERIMENTS COMPLETE!")
print(f"{'='*80}")

## Step 6: Aggregate Results

In [None]:
def calculate_stats(values):
    """Calculate mean, sample std and a 95% t-distribution confidence interval.

    ``None`` entries (e.g. runs that never reached an accuracy milestone) are
    dropped before anything is computed.

    Args:
        values: iterable of numbers, possibly containing ``None``.

    Returns:
        ``(mean, std, ci_low, ci_high)`` as plain floats, or
        ``(None, None, None, None)`` when no non-None values remain. With a
        single value, or a sample with zero spread, std is 0 and the interval
        collapses to ``(mean, mean)``.
    """
    from scipy import stats

    values = [v for v in values if v is not None]
    if not values:
        return None, None, None, None

    mean = float(np.mean(values))
    if len(values) == 1:
        return mean, 0.0, mean, mean

    std = float(np.std(values, ddof=1))
    sem = float(stats.sem(values))
    if sem == 0.0:
        # scipy rejects scale=0 and would return (nan, nan), which then leaks
        # into the saved JSON; a zero-spread sample's interval is the point itself.
        return mean, std, mean, mean

    ci_low, ci_high = stats.t.interval(0.95, len(values) - 1, loc=mean, scale=sem)
    return mean, std, float(ci_low), float(ci_high)

# Aggregate metrics. A milestone epoch is None for a run that never reached
# that accuracy; calculate_stats drops those, so how many runs contributed is
# reported next to each statistic.
transfer_90_epochs = [r['epochs_to_90_percent'] for r in all_transfer_results]
baseline_90_epochs = [r['epochs_to_90_percent'] for r in all_baseline_results]

transfer_95_epochs = [r['epochs_to_95_percent'] for r in all_transfer_results]
baseline_95_epochs = [r['epochs_to_95_percent'] for r in all_baseline_results]

transfer_final_acc = [r['final_test_accuracy'] for r in all_transfer_results]
baseline_final_acc = [r['final_test_accuracy'] for r in all_baseline_results]


def _fmt(value, spec):
    """Format *value* with *spec*, or return 'n/a' when it is None."""
    return 'n/a' if value is None else format(value, spec)


def _reached(values):
    """Describe how many runs have a non-None milestone, e.g. '4/5 runs'."""
    return f"{sum(v is not None for v in values)}/{len(values)} runs"


# Calculate statistics
print("\n" + "="*80)
print("AGGREGATED RESULTS ACROSS ALL SEEDS")
print("="*80)

print("\n📊 Epochs to 90% Accuracy:")
t_mean, t_std, t_ci_low, t_ci_high = calculate_stats(transfer_90_epochs)
b_mean, b_std, b_ci_low, b_ci_high = calculate_stats(baseline_90_epochs)

print(f"  Transfer:  {_fmt(t_mean, '.1f')} ± {_fmt(t_std, '.1f')} epochs "
      f"(95% CI: [{_fmt(t_ci_low, '.1f')}, {_fmt(t_ci_high, '.1f')}]; {_reached(transfer_90_epochs)})")
print(f"  Baseline:  {_fmt(b_mean, '.1f')} ± {_fmt(b_std, '.1f')} epochs "
      f"(95% CI: [{_fmt(b_ci_low, '.1f')}, {_fmt(b_ci_high, '.1f')}]; {_reached(baseline_90_epochs)})")

# Compare the two arms only when both have a mean. Test against None (not
# truthiness) so a mean of 0 epochs is not mistaken for "missing"; the ratio
# itself is undefined when the transfer mean is 0. Both names are always
# defined so later cells and the JSON can rely on them.
speedup = None
improvement = None
if t_mean is not None and b_mean is not None:
    improvement = b_mean - t_mean
    if t_mean > 0:
        speedup = b_mean / t_mean
        print(f"\n  🚀 Speedup: {speedup:.2f}x faster")
    else:
        print("\n  🚀 Speedup: undefined (transfer mean is 0 epochs)")
    reduction = f" ({improvement/b_mean*100:.1f}% reduction)" if b_mean > 0 else ""
    print(f"  📉 Saved: {improvement:.1f} epochs{reduction}")

print("\n📊 Epochs to 95% Accuracy:")
t_mean_95, t_std_95, _, _ = calculate_stats(transfer_95_epochs)
b_mean_95, b_std_95, _, _ = calculate_stats(baseline_95_epochs)

print(f"  Transfer:  {_fmt(t_mean_95, '.1f')} ± {_fmt(t_std_95, '.1f')} epochs ({_reached(transfer_95_epochs)})")
print(f"  Baseline:  {_fmt(b_mean_95, '.1f')} ± {_fmt(b_std_95, '.1f')} epochs ({_reached(baseline_95_epochs)})")

print("\n📊 Final Test Accuracy:")
t_acc_mean, t_acc_std, _, _ = calculate_stats(transfer_final_acc)
b_acc_mean, b_acc_std, _, _ = calculate_stats(baseline_final_acc)

print(f"  Transfer:  {_fmt(t_acc_mean, '.4f')} ± {_fmt(t_acc_std, '.4f')}")
print(f"  Baseline:  {_fmt(b_acc_mean, '.4f')} ± {_fmt(b_acc_std, '.4f')}")

# Save aggregated statistics ('values' keep the raw per-seed numbers, with
# None -> null for runs that never reached the milestone).
stats_dict = {
    'num_seeds': NUM_SEEDS,
    'seeds': SEEDS,
    'transfer': {
        'epochs_to_90': {'mean': t_mean, 'std': t_std, 'ci': [t_ci_low, t_ci_high], 'values': transfer_90_epochs},
        'epochs_to_95': {'mean': t_mean_95, 'std': t_std_95, 'values': transfer_95_epochs},
        'final_acc': {'mean': t_acc_mean, 'std': t_acc_std, 'values': transfer_final_acc}
    },
    'baseline': {
        'epochs_to_90': {'mean': b_mean, 'std': b_std, 'ci': [b_ci_low, b_ci_high], 'values': baseline_90_epochs},
        'epochs_to_95': {'mean': b_mean_95, 'std': b_std_95, 'values': baseline_95_epochs},
        'final_acc': {'mean': b_acc_mean, 'std': b_acc_std, 'values': baseline_final_acc}
    },
    'speedup': speedup,
    'improvement_epochs': improvement
}

# Save to JSON
with open(f'{EXPERIMENT_DIR}/results/aggregated_stats.json', 'w') as f:
    json.dump(stats_dict, f, indent=2)

print(f"\n✓ Saved aggregated statistics to {EXPERIMENT_DIR}/results/aggregated_stats.json")

## Step 7: Visualize Results

In [None]:
# Prepare data for plotting.
# Every run records one value per epoch, so stacking the runs yields
# (num_runs, num_epochs) arrays, assuming all runs used the same epoch count.
# The mean/std are taken across runs (axis 0); np.std here uses its default
# ddof=0 (population std).
max_epochs = NUM_EPOCHS
epochs = np.arange(NUM_EPOCHS)

transfer_acc_curves, transfer_loss_curves = (
    np.array([run[metric] for run in all_transfer_results])
    for metric in ('test_accuracies', 'test_losses')
)
baseline_acc_curves, baseline_loss_curves = (
    np.array([run[metric] for run in all_baseline_results])
    for metric in ('test_accuracies', 'test_losses')
)

transfer_acc_mean, transfer_acc_std = transfer_acc_curves.mean(axis=0), transfer_acc_curves.std(axis=0)
baseline_acc_mean, baseline_acc_std = baseline_acc_curves.mean(axis=0), baseline_acc_curves.std(axis=0)
transfer_loss_mean, transfer_loss_std = transfer_loss_curves.mean(axis=0), transfer_loss_curves.std(axis=0)
baseline_loss_mean, baseline_loss_std = baseline_loss_curves.mean(axis=0), baseline_loss_curves.std(axis=0)

In [None]:
# Create comprehensive visualization: a 2x3 grid with
#   [0,0] mean test accuracy ± std      [0,1] mean log10 test loss ± std
#   [0,2] every individual run          [1,0] zoom on the first `zoom` epochs
#   [1,1] box plot of epochs-to-90%     [1,2] text summary


def _fmt(value, spec):
    """Format *value* with *spec*, or return 'n/a' when it is None."""
    return 'n/a' if value is None else format(value, spec)


def _plot_mean_band(ax, x, mean, lower, upper, label, color):
    """Draw *mean* as a line and shade the region between *lower* and *upper*."""
    ax.plot(x, mean, label=label, color=color, linewidth=2)
    ax.fill_between(x, lower, upper, alpha=0.3, color=color)


def _log_loss_band(curves, mean, std, eps=1e-10):
    """Return log10 of (mean, mean - std, mean + std) for a loss band.

    mean - std goes negative when the loss is skewed across seeds, and log10
    of a negative number is NaN, which silently removes the band. The lower
    edge is therefore clipped to the smallest loss actually observed at that
    epoch (``curves`` is runs x epochs).
    """
    lower = np.maximum(mean - std, curves.min(axis=0))
    return np.log10(mean + eps), np.log10(lower + eps), np.log10(mean + std + eps)


fig, axes = plt.subplots(2, 3, figsize=(20, 12))

# Plot 1: Test Accuracy (mean ± std across seeds)
ax = axes[0, 0]
_plot_mean_band(ax, epochs, transfer_acc_mean,
                transfer_acc_mean - transfer_acc_std, transfer_acc_mean + transfer_acc_std,
                'Transfer Learning', 'blue')
_plot_mean_band(ax, epochs, baseline_acc_mean,
                baseline_acc_mean - baseline_acc_std, baseline_acc_mean + baseline_acc_std,
                'Random Init', 'orange')
ax.axhline(y=0.9, color='red', linestyle='--', alpha=0.5, label='90% Target')
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Test Accuracy', fontsize=12)
ax.set_title(f'Test Accuracy Over Time (N={NUM_SEEDS} seeds)', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

# Plot 2: Test Loss (log scale)
ax = axes[0, 1]
t_log_mean, t_log_low, t_log_high = _log_loss_band(transfer_loss_curves, transfer_loss_mean, transfer_loss_std)
b_log_mean, b_log_low, b_log_high = _log_loss_band(baseline_loss_curves, baseline_loss_mean, baseline_loss_std)
_plot_mean_band(ax, epochs, t_log_mean, t_log_low, t_log_high, 'Transfer Learning', 'blue')
_plot_mean_band(ax, epochs, b_log_mean, b_log_low, b_log_high, 'Random Init', 'orange')
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Log10(Test Loss)', fontsize=12)
ax.set_title(f'Test Loss Over Time (N={NUM_SEEDS} seeds)', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

# Plot 3: Individual runs (Test Accuracy)
ax = axes[0, 2]
for results in all_transfer_results:
    ax.plot(results['test_accuracies'], alpha=0.4, color='blue', linewidth=1)
for results in all_baseline_results:
    ax.plot(results['test_accuracies'], alpha=0.4, color='orange', linewidth=1)
# Empty plots only create legend entries for the two colours.
ax.plot([], [], color='blue', label='Transfer', linewidth=2)
ax.plot([], [], color='orange', label='Baseline', linewidth=2)
ax.axhline(y=0.9, color='red', linestyle='--', alpha=0.5)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Test Accuracy', fontsize=12)
ax.set_title('Individual Runs (All Seeds)', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

# Plot 4: Zoomed (first `zoom` epochs)
zoom = 1000
ax = axes[1, 0]
_plot_mean_band(ax, epochs[:zoom], transfer_acc_mean[:zoom],
                transfer_acc_mean[:zoom] - transfer_acc_std[:zoom],
                transfer_acc_mean[:zoom] + transfer_acc_std[:zoom],
                'Transfer', 'blue')
_plot_mean_band(ax, epochs[:zoom], baseline_acc_mean[:zoom],
                baseline_acc_mean[:zoom] - baseline_acc_std[:zoom],
                baseline_acc_mean[:zoom] + baseline_acc_std[:zoom],
                'Baseline', 'orange')
ax.axhline(y=0.9, color='red', linestyle='--', alpha=0.5)

# Mark mean 90% epochs (compare with None so a mean of 0 is still drawn)
if t_mean is not None and t_mean < zoom:
    ax.axvline(x=t_mean, color='blue', linestyle=':', alpha=0.7, linewidth=2)
if b_mean is not None and b_mean < zoom:
    ax.axvline(x=b_mean, color='orange', linestyle=':', alpha=0.7, linewidth=2)

ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Test Accuracy', fontsize=12)
ax.set_title(f'Test Accuracy (First {zoom} Epochs)', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

# Plot 5: Box plot of epochs to 90%. boxplot cannot handle None (runs that
# never reached 90%), so only runs that did are plotted; the tick labels say
# how many. Tick labels are set separately because boxplot's `labels=`
# keyword was renamed `tick_labels` in Matplotlib 3.9.
transfer_reached_90 = [e for e in transfer_90_epochs if e is not None]
baseline_reached_90 = [e for e in baseline_90_epochs if e is not None]
data_to_plot = [transfer_reached_90, baseline_reached_90]
ax = axes[1, 1]
bp = ax.boxplot(data_to_plot, patch_artist=True, showmeans=True)
ax.set_xticklabels([
    f'Transfer\n(n={len(transfer_reached_90)}/{len(transfer_90_epochs)})',
    f'Baseline\n(n={len(baseline_reached_90)}/{len(baseline_90_epochs)})',
])
bp['boxes'][0].set_facecolor('lightblue')
bp['boxes'][1].set_facecolor('lightsalmon')
ax.set_ylabel('Epochs to 90% Accuracy', fontsize=12)
ax.set_title('Distribution of Epochs to 90%', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

# Plot 6: Summary statistics table. Speedup / saved epochs are recomputed
# here from the means so this cell does not depend on them having been
# defined earlier; they are 'n/a' when either arm has no mean.
both_means = t_mean is not None and b_mean is not None
summary_speedup = b_mean / t_mean if both_means and t_mean > 0 else None
summary_saved = b_mean - t_mean if both_means else None

ax = axes[1, 2]
ax.axis('off')
summary_text = f"""
SUMMARY STATISTICS (N={NUM_SEEDS} seeds)

Epochs to 90% Accuracy:
  Transfer:  {_fmt(t_mean, '.1f')} ± {_fmt(t_std, '.1f')}
  Baseline:  {_fmt(b_mean, '.1f')} ± {_fmt(b_std, '.1f')}

  Speedup:   {_fmt(summary_speedup, '.2f')}x
  Saved:     {_fmt(summary_saved, '.1f')} epochs

Final Test Accuracy:
  Transfer:  {_fmt(t_acc_mean, '.4f')} ± {_fmt(t_acc_std, '.4f')}
  Baseline:  {_fmt(b_acc_mean, '.4f')} ± {_fmt(b_acc_std, '.4f')}

Seeds: {SEEDS}
"""
ax.text(0.1, 0.5, summary_text, fontsize=11, family='monospace',
        verticalalignment='center', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
ax.set_title('Summary', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.savefig(f'{EXPERIMENT_DIR}/figures/multiseed_results.png', dpi=200, bbox_inches='tight')
print(f"✓ Saved figure to {EXPERIMENT_DIR}/figures/multiseed_results.png")
plt.show()

## Step 8: Save All Results to Google Drive

In [None]:
# Save complete results: run config, every per-seed result dict (including
# curves and model_state) and the aggregated statistics.
complete_results = {
    'config': {
        'num_seeds': NUM_SEEDS,
        'seeds': SEEDS,
        'num_epochs': NUM_EPOCHS,
        'checkpoint_path': CHECKPOINT_PATH,
        'timestamp': timestamp
    },
    'transfer_results': all_transfer_results,
    'baseline_results': all_baseline_results,
    'statistics': stats_dict
}

# Save as PyTorch checkpoint
torch.save(complete_results, f'{EXPERIMENT_DIR}/results/complete_results.pth')
print(f"✓ Saved complete results to {EXPERIMENT_DIR}/results/complete_results.pth")

# Save curves as numpy so they can be analysed without torch.
np.savez(f'{EXPERIMENT_DIR}/results/curves.npz',
         transfer_acc_curves=transfer_acc_curves,
         baseline_acc_curves=baseline_acc_curves,
         transfer_loss_curves=transfer_loss_curves,
         baseline_loss_curves=baseline_loss_curves,
         seeds=np.array(SEEDS))
print(f"✓ Saved curves to {EXPERIMENT_DIR}/results/curves.npz")

print("\n" + "="*80)
print("✅ EXPERIMENT COMPLETE - ALL RESULTS SAVED TO GOOGLE DRIVE")
print("="*80)
print(f"\nResults location: {EXPERIMENT_DIR}")
print("\nDirectory structure:")
print(f"  {EXPERIMENT_DIR}/")
print("    ├── figures/")
print("    │   └── multiseed_results.png")
print("    ├── checkpoints/")
print(f"    │   ├── transfer_seed*.pth ({NUM_SEEDS} files)")
print(f"    │   └── baseline_seed*.pth ({NUM_SEEDS} files)")
print("    └── results/")
print("        ├── complete_results.pth")
print("        ├── curves.npz")
print("        └── aggregated_stats.json")

## Conclusion

This multi-seed experiment provides statistical evidence for whether grokked modular addition models can transfer to accelerate subtraction learning.

**Key Features:**
- ✅ Multiple seeds for reproducibility
- ✅ Statistical analysis (mean, std, confidence intervals)
- ✅ Comprehensive visualizations
- ✅ Persistent storage in Google Drive

**Next Steps:**
- Analyze variance across seeds
- Test different source checkpoints
- Try other target operations
- Investigate which model components transfer most effectively