# Transfer Learning: Grokked Addition → Subtraction

This notebook tests whether initializing from a grokked modular addition model accelerates learning on modular subtraction, compared with training from random initialization.

**Experiment Plan:**
1. Load a fully grokked addition model (mod 113)
2. Fine-tune it on the subtraction task (a - b mod 113)
3. Compare against random initialization baseline

**Key Metrics:**
- Epochs to reach 90% test accuracy
- Test accuracy curves over training
- Training loss curves

## Setup: Clone Repository and Install Dependencies

In [None]:
# Clone the repository if not already cloned
import os
if not os.path.exists('progress-measures-paper-extension'):
    !git clone https://github.com/Junekhunter/progress-measures-paper-extension.git
    
# Change to repo directory
os.chdir('progress-measures-paper-extension')
print(f"Working directory: {os.getcwd()}")

In [None]:
# Install any missing dependencies (Colab has most already).
# `-q` keeps pip's output short.
# NOTE(review): wandb is not imported anywhere in the visible notebook cells —
# presumably a module in the repo needs it; confirm before removing.
!pip install -q einops wandb

In [None]:
# Import necessary modules
import sys
sys.path.insert(0, '.')

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass, replace
import random
from pathlib import Path
from tqdm import tqdm

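# Import from the repo. NOTE: `transformers` here is the repo's local module,
# found first because of the sys.path.insert above — not the HuggingFace
# package of the same name.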
from transformers import Transformer, Config, gen_train_test, full_loss
import helpers

# Report the runtime environment once the imports have succeeded.
print("All imports successful!")
print(f"PyTorch version: {torch.__version__}")

cuda_ready = torch.cuda.is_available()
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    # Device 0 is the one named here.
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## Step 1: Inspect and Load Grokked Addition Model

In [None]:
# Load the saved addition training run and report what it contains.
checkpoint_path = 'saved_runs/wd_10-1_mod_addition_loss_curve.pth'
print(f"Loading checkpoint from {checkpoint_path}...")
# NOTE: torch.load unpickles the file, which can execute arbitrary code. Only
# load checkpoints from a source you trust (here: the repository cloned above).
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True; if this
# checkpoint stores non-tensor Python objects the call may need
# weights_only=False — confirm against the installed version.
checkpoint = torch.load(checkpoint_path, map_location='cpu')

print(f"\nCheckpoint keys: {list(checkpoint.keys())}")
print("\n" + "="*80)

# Analyze the checkpoint structure
if 'config' in checkpoint:
    print(f"\nConfig: {checkpoint['config']}")

# Everything below reads BOTH curves, so require both up front instead of
# checking one key and then indexing the other.
has_loss_curves = 'train_losses' in checkpoint and 'test_losses' in checkpoint

# A final test loss below this value is reported as "fully grokked".
GROKKED_TEST_LOSS_THRESHOLD = 0.01

if has_loss_curves:
    train_losses = checkpoint['train_losses']
    test_losses = checkpoint['test_losses']

    print(f"\nTotal training epochs: {len(test_losses)}")
    print(f"\nFinal 10 epochs:")
    for i in range(max(0, len(test_losses)-10), len(test_losses)):
        print(f"  Epoch {i}: train_loss={train_losses[i]:.6f}, test_loss={test_losses[i]:.6f}")

    # len() rather than truthiness: the container type is not known here, and
    # array-like objects do not support a plain truth test.
    if len(test_losses) > 0:
        final_test_loss = test_losses[-1]
        if final_test_loss < GROKKED_TEST_LOSS_THRESHOLD:
            print(f"\n✓ Model is FULLY GROKKED (final test loss: {final_test_loss:.6f})")
        else:
            print(f"\n✗ Model NOT fully grokked (final test loss: {final_test_loss:.6f})")
    else:
        print("\nNote: checkpoint loss curves are empty; cannot judge whether the model grokked.")

    # Plot training curves: raw loss on the left, log10(loss) on the right.
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss', alpha=0.7)
    plt.plot(test_losses, label='Test Loss', alpha=0.7)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Grokked Addition Model: Training Curves')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 2, 2)
    # The 1e-10 offset keeps log10 finite if a recorded loss is exactly zero.
    plt.plot(np.log10(np.array(train_losses)+1e-10), label='Log Train Loss', alpha=0.7)
    plt.plot(np.log10(np.array(test_losses)+1e-10), label='Log Test Loss', alpha=0.7)
    plt.xlabel('Epoch')
    plt.ylabel('Log10(Loss)')
    plt.title('Grokked Addition Model: Log Training Curves')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
else:
    print("\nNote: checkpoint has no 'train_losses'/'test_losses' curves; skipping loss inspection.")

In [None]:
# Create config for the addition model.
# NOTE(review): these values are written out by hand rather than read from
# checkpoint['config']; confirm they match the run that produced the checkpoint.
addition_config = Config(
    lr=1e-3,
    weight_decay=1.0,
    p=113,
    d_model=128,
    fn_name='add',
    frac_train=0.3,
    num_epochs=50000,
    seed=0,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
)

# Create model and load grokked weights
grokked_addition_model = Transformer(addition_config, use_cache=False)
grokked_addition_model.to(addition_config.device)

# Load the trained weights. The checkpoint may hold a single state dict under
# 'model' or a list of snapshots under 'state_dicts'.
if 'model' in checkpoint:
    grokked_addition_model.load_state_dict(checkpoint['model'])
    print("✓ Loaded model from 'model' key")
elif checkpoint.get('state_dicts'):
    # If there are multiple checkpoints, use the last one
    grokked_addition_model.load_state_dict(checkpoint['state_dicts'][-1])
    print(f"✓ Loaded model from 'state_dicts' (checkpoint {len(checkpoint['state_dicts'])-1})")
else:
    # Fail loudly: continuing would run the "transfer" experiment from a
    # randomly initialised model while reporting that the load succeeded.
    raise KeyError(
        "✗ Could not find model weights in checkpoint! "
        f"Expected a 'model' key or a non-empty 'state_dicts' list; found keys: {list(checkpoint.keys())}"
    )

print("\n✓ Grokked addition model loaded successfully!")

## Step 2: Verify Addition Model Performance

In [None]:
# Spot-check the loaded model on random addition problems.
grokked_addition_model.eval()

# Read the modulus from the config instead of repeating the literal 113, so
# this check stays correct if the config above is changed.
modulus = addition_config.p

# Generate test data
test_samples = 20
print("Testing grokked addition model on random examples:\n")
print(f"Input (a, b) | Ground Truth (a+b mod {modulus}) | Model Prediction")
print("-" * 60)

correct = 0
with torch.no_grad():
    for _ in range(test_samples):
        a = np.random.randint(0, modulus)
        b = np.random.randint(0, modulus)
        ground_truth = (a + b) % modulus

        # Prepare input: one sequence of three tokens. The third token was the
        # literal 113 (== p) in the original cell.
        # NOTE(review): presumably the separator token is always p — confirm
        # against gen_train_test.
        input_tensor = torch.tensor([[a, b, modulus]]).to(addition_config.device)
        # Only the logits at the final position are used as the model's answer.
        logits = grokked_addition_model(input_tensor)[0, -1]
        prediction = logits.argmax().item()

        is_correct = prediction == ground_truth
        correct += int(is_correct)

        symbol = "✓" if is_correct else "✗"
        print(f"{symbol} ({a:3d}, {b:3d}) | {ground_truth:3d} | {prediction:3d}")

accuracy = correct / test_samples * 100
print(f"\nAccuracy: {accuracy:.1f}% ({correct}/{test_samples})")

## Step 3: Define Training Function for Subtraction

In [None]:
def train_subtraction_model(model, config, num_epochs=5000, save_every=100, verbose=True):
    """
    Train a model full-batch on the task selected by ``config`` and record metrics.

    Each epoch takes one optimizer step on the loss over the whole training
    split, then evaluates loss and accuracy on the whole test split. Metrics
    are recorded every epoch.

    Args:
        model: Transformer model to train (moved to ``config.device`` and
            updated in place)
        config: Config object with fn_name='subtract'; supplies lr,
            weight_decay, device and fn, and is passed to gen_train_test and
            full_loss
        num_epochs: Number of training epochs
        save_every: How often (in epochs) the progress-bar statistics are
            refreshed; it does not affect which metrics are recorded
        verbose: Whether to print progress

    Returns:
        Dictionary with:
            train_losses: per-epoch training loss, computed before that
                epoch's optimizer step
            test_losses: per-epoch test loss, computed after the step
            test_accuracies: per-epoch test accuracy, computed after the step
            epochs_to_90_percent: zero-based index of the first epoch whose
                test accuracy is >= 0.90, or None if never reached (0 is a
                valid value, so compare against None rather than truthiness)
            final_test_accuracy: last entry of test_accuracies, or None when
                num_epochs is 0
            model_state: CPU copy of the model's state dict at the end of
                training, independent of the live model
    """
    # Set up training
    model.to(config.device)
    model.train()

    optimizer = optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay, betas=(0.9, 0.98))
    # Linear learning-rate warm-up over the first 10 scheduler steps, then constant.
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda step: min(step/10, 1))

    # Generate train/test split
    train_data, test_data = gen_train_test(config)

    # The test set is fixed for the whole run, so build its input tensor and
    # labels once instead of re-running a Python loop over it every epoch.
    test_tensor = torch.tensor(test_data).to(config.device)
    labels = torch.tensor([config.fn(i, j) for i, j, _ in test_data]).to(config.device)

    # Tracking metrics
    train_losses = []
    test_losses = []
    test_accuracies = []
    epochs_to_90_percent = None

    if verbose:
        print(f"Training on {len(train_data)} examples, testing on {len(test_data)} examples")
        pbar = tqdm(range(num_epochs), desc="Training")
    else:
        pbar = range(num_epochs)

    for epoch in pbar:
        # Training step (full batch: one optimizer step per epoch)
        train_loss = full_loss(config, model, train_data)
        train_loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        # Evaluation
        with torch.no_grad():
            test_loss = full_loss(config, model, test_data)

            # Test accuracy: the prediction is the argmax of the logits at the
            # final sequence position.
            logits = model(test_tensor)[:, -1]
            predictions = logits.argmax(dim=-1)
            test_accuracy = (predictions == labels).float().mean().item()

        train_losses.append(train_loss.item())
        test_losses.append(test_loss.item())
        test_accuracies.append(test_accuracy)

        # Record the first epoch at which we reached 90% accuracy
        if epochs_to_90_percent is None and test_accuracy >= 0.90:
            epochs_to_90_percent = epoch
            if verbose:
                print(f"\n✓ Reached 90% test accuracy at epoch {epoch}")

        # Update progress bar
        if verbose and epoch % save_every == 0:
            pbar.set_postfix({
                'train_loss': f'{train_loss.item():.4f}',
                'test_loss': f'{test_loss.item():.4f}',
                'test_acc': f'{test_accuracy:.3f}'
            })

    # state_dict() returns references to the live parameters; copy them so the
    # snapshot does not change if the model is trained further, and keep it on
    # the CPU so saved results load on machines without a GPU.
    model_state = {
        name: tensor.detach().cpu().clone()
        for name, tensor in model.state_dict().items()
    }

    return {
        'train_losses': train_losses,
        'test_losses': test_losses,
        'test_accuracies': test_accuracies,
        'epochs_to_90_percent': epochs_to_90_percent,
        # No epochs were run when num_epochs == 0, so there is no final value.
        'final_test_accuracy': test_accuracies[-1] if test_accuracies else None,
        'model_state': model_state
    }

print("Training function defined!")

## Step 4: Transfer Learning Experiment (Grokked Addition → Subtraction)

In [None]:
# Derive the subtraction config from the addition one: same hyperparameters,
# only the task name and the seed differ.
# The seed is changed so the run uses a different train/test split
# (NOTE(review): presumably gen_train_test reads config.seed — confirm).
subtraction_config = replace(addition_config, fn_name='subtract', seed=42)

print("Starting Transfer Learning Experiment (Addition → Subtraction)")
print("=" * 80)

# Start from a fresh Transformer and overwrite its weights with the grokked
# addition weights, leaving grokked_addition_model itself untouched.
transfer_model = Transformer(subtraction_config, use_cache=False)
grokked_weights = grokked_addition_model.state_dict()
transfer_model.load_state_dict(grokked_weights)
transfer_model.to(subtraction_config.device)

print("✓ Transfer model created with grokked addition weights")

# Fine-tune on subtraction with the same epoch budget the baseline gets.
transfer_results = train_subtraction_model(
    transfer_model, subtraction_config, num_epochs=5000, save_every=100, verbose=True
)

print(f"\n{'=' * 80}")
print("Transfer Learning Results:")
for line in (
    f"  Final test accuracy: {transfer_results['final_test_accuracy']:.4f}",
    f"  Epochs to 90% accuracy: {transfer_results['epochs_to_90_percent']}",
):
    print(line)

## Step 5: Baseline Experiment (Random Initialization)

In [None]:
print("Starting Baseline Experiment (Random Initialization)")
print("="*80)

# Seed torch's RNG before building the baseline so its random initialisation —
# and therefore the baseline curve the transfer run is compared against — is
# the same every time this cell runs.
# NOTE(review): this assumes Transformer draws its initial weights from torch's
# global RNG; confirm against the repo's model definition.
torch.manual_seed(subtraction_config.seed)

# Create a fresh model with random initialization
baseline_model = Transformer(subtraction_config, use_cache=False)
baseline_model.to(subtraction_config.device)

print("✓ Baseline model created with random initialization")

# Train on subtraction with the same config and epoch budget as the transfer run
baseline_results = train_subtraction_model(
    baseline_model,
    subtraction_config,
    num_epochs=5000,
    save_every=100,
    verbose=True
)

print("\n" + "="*80)
print("Baseline Results:")
print(f"  Final test accuracy: {baseline_results['final_test_accuracy']:.4f}")
print(f"  Epochs to 90% accuracy: {baseline_results['epochs_to_90_percent']}")

## Step 6: Compare Results

In [None]:
# Summary comparison of the two runs.
print("\n" + "="*80)
print("EXPERIMENT SUMMARY")
print("="*80)

for heading, results in (
    ("\n1. Transfer Learning (Grokked Addition → Subtraction):", transfer_results),
    ("\n2. Baseline (Random Initialization):", baseline_results),
):
    print(heading)
    print(f"   - Final test accuracy: {results['final_test_accuracy']:.4f}")
    print(f"   - Epochs to 90% accuracy: {results['epochs_to_90_percent']}")
    print(f"   - Final train loss: {results['train_losses'][-1]:.6f}")
    print(f"   - Final test loss: {results['test_losses'][-1]:.6f}")

# Calculate speedup.
# epochs_to_90_percent is a zero-based epoch index, so 0 is a legitimate value
# ("already at 90% after the first epoch"). Compare against None explicitly;
# a truthiness test would report such a run as never having reached 90%.
transfer_epochs = transfer_results['epochs_to_90_percent']
baseline_epochs = baseline_results['epochs_to_90_percent']

if transfer_epochs is None or baseline_epochs is None:
    print("\n3. Note: One or both models did not reach 90% accuracy within the training budget.")
else:
    improvement = baseline_epochs - transfer_epochs
    print("\n3. Transfer Learning Benefits:")
    # Both ratios below divide by an epoch index that may be 0.
    if transfer_epochs > 0:
        speedup = baseline_epochs / transfer_epochs
        print(f"   - Speedup: {speedup:.2f}x faster to reach 90% accuracy")
    else:
        print("   - Speedup: not defined (transfer model reached 90% accuracy at epoch 0)")
    if baseline_epochs > 0:
        print(f"   - Saved {improvement} epochs ({improvement/baseline_epochs*100:.1f}% reduction)")
    else:
        print(f"   - Saved {improvement} epochs")

## Step 7: Visualize Results

In [None]:
# Create comprehensive visualization: a 2x3 grid comparing the two runs.
fig, axes = plt.subplots(2, 3, figsize=(18, 10))

# (legend label, results dict) for each run, in plotting order.
_RUNS = (
    ('Transfer Learning', transfer_results),
    ('Random Init', baseline_results),
)


def _log10_curve(values):
    """Return log10 of a loss curve; the 1e-10 offset keeps exact zeros finite."""
    return np.log10(np.array(values) + 1e-10)


def _draw_curves(ax, metric, transform=None, max_epochs=None, show_target=False):
    """Plot one metric for both runs on ``ax``, optionally with the 90% target line."""
    for label, results in _RUNS:
        values = results[metric][:max_epochs]  # max_epochs=None keeps the full curve
        if transform is not None:
            values = transform(values)
        ax.plot(values, label=label, alpha=0.7, linewidth=2)
    if show_target:
        ax.axhline(y=0.9, color='r', linestyle='--', alpha=0.5, label='90% Target')


def _finish_axis(ax, ylabel, title):
    """Apply the axis labels, title, legend and grid shared by every panel."""
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)


# Plot 1: Training Loss
_draw_curves(axes[0, 0], 'train_losses')
_finish_axis(axes[0, 0], 'Training Loss', 'Training Loss Over Time')

# Plot 2: Test Loss
_draw_curves(axes[0, 1], 'test_losses')
_finish_axis(axes[0, 1], 'Test Loss', 'Test Loss Over Time')

# Plot 3: Test Accuracy
_draw_curves(axes[0, 2], 'test_accuracies', show_target=True)
_finish_axis(axes[0, 2], 'Test Accuracy', 'Test Accuracy Over Time')

# Plot 4: Log Training Loss
_draw_curves(axes[1, 0], 'train_losses', transform=_log10_curve)
_finish_axis(axes[1, 0], 'Log10(Training Loss)', 'Log Training Loss Over Time')

# Plot 5: Log Test Loss
_draw_curves(axes[1, 1], 'test_losses', transform=_log10_curve)
_finish_axis(axes[1, 1], 'Log10(Test Loss)', 'Log Test Loss Over Time')

# Plot 6: First 1000 epochs (zoomed)
zoom_epochs = 1000
_draw_curves(axes[1, 2], 'test_accuracies', max_epochs=zoom_epochs, show_target=True)

# Mark the 90% achievement points.
# The epoch index is zero-based, so 0 is a valid value: test against None
# rather than truthiness, otherwise a run that hits 90% at epoch 0 gets no marker.
for results, marker_color in ((transfer_results, 'blue'), (baseline_results, 'orange')):
    reached_at = results['epochs_to_90_percent']
    if reached_at is not None and reached_at < zoom_epochs:
        axes[1, 2].axvline(x=reached_at, color=marker_color, linestyle=':', alpha=0.5)
        axes[1, 2].text(reached_at, 0.85, f"{reached_at}",
                        rotation=90, verticalalignment='bottom', color=marker_color, fontsize=9)

_finish_axis(axes[1, 2], 'Test Accuracy', f'Test Accuracy (First {zoom_epochs} Epochs)')

plt.tight_layout()
plt.savefig('transfer_learning_results.png', dpi=150, bbox_inches='tight')
print("✓ Saved visualization to transfer_learning_results.png")
plt.show()

## Step 8: Save Results

In [None]:
# Save experiment results.
# The recorded hyperparameters are read from the objects the runs actually
# used, so the saved metadata cannot drift from the experiment if a value is
# edited in an earlier cell.
results_dict = {
    'transfer_learning': transfer_results,
    'baseline': baseline_results,
    'experiment_config': {
        # One training loss is recorded per epoch, so the curve length is the
        # number of epochs that were run.
        'num_epochs': len(transfer_results['train_losses']),
        'p': subtraction_config.p,
        'frac_train': subtraction_config.frac_train,
        'lr': subtraction_config.lr,
        'weight_decay': subtraction_config.weight_decay,
        'seed': subtraction_config.seed,
        'fn_name': subtraction_config.fn_name,
        'source_checkpoint': checkpoint_path
    }
}

torch.save(results_dict, 'transfer_learning_experiment_results.pth')
print("✓ Saved results to transfer_learning_experiment_results.pth")

# Also save as numpy for easier analysis (curves only; no model weights)
np.savez('transfer_learning_experiment_results.npz',
         transfer_train_losses=np.array(transfer_results['train_losses']),
         transfer_test_losses=np.array(transfer_results['test_losses']),
         transfer_test_accuracies=np.array(transfer_results['test_accuracies']),
         baseline_train_losses=np.array(baseline_results['train_losses']),
         baseline_test_losses=np.array(baseline_results['test_losses']),
         baseline_test_accuracies=np.array(baseline_results['test_accuracies']))
print("✓ Saved results to transfer_learning_experiment_results.npz")

print("\n" + "="*80)
print("EXPERIMENT COMPLETE!")
print("="*80)

## Conclusion

This notebook investigated whether initializing from a grokked modular addition model accelerates learning on modular subtraction.

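**Key Findings:**
- The transfer learning approach (grokked addition → subtraction) was compared against random initialization, using the same configuration and epoch budget for both runs; the measured numbers are printed in the Step 6 summary
- Metrics tracked: epochs to 90% accuracy, test accuracy curves, and training loss curves
- Results are visualized (`transfer_learning_results.png`) and saved (`transfer_learning_experiment_results.pth` and `.npz`) for further analysis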

**Next Steps:**
- Try different hyperparameters (learning rate, weight decay)
- Test transfer to other operations (multiplication, division)
- Analyze which model components transfer most effectively
- Investigate if partially grokked models also show transfer benefits