# Modified Phase 2: Checkpoint Experiments

**Goal**: Run 3 experiments with checkpoint saving at steps [100, 1000, 2000]

**Experiments**:
1. transformer_deep_mnist
2. cnn_deep_mnist
3. mlp_narrow_mnist

**Runtime**: ~90 minutes on GPU A100

**Output**: 9 checkpoints (~75MB total) + experiment results

---

## Setup Instructions

1. **Runtime**: Set to GPU (Runtime → Change runtime type → GPU → A100)
2. **Execute cells in order**
3. **Results saved to**: `/content/checkpoints/` and `/content/results/`
4. **Download**: At the end, zip the checkpoints and results and download the archive

## 1. Environment Setup

In [None]:
# Install dependencies
!pip install -q torch torchvision tqdm scikit-learn

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as transforms
import numpy as np
from tqdm import tqdm
import json
import time
from pathlib import Path
from typing import Dict, List

# Select the compute device: CUDA when it is available, CPU otherwise
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    # Report which GPU was allocated and how much memory it has
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {gpu_props.total_memory / 1e9:.2f} GB")

## 2. Checkpoint Plan

In [None]:
# Checkpoint configuration
def _mnist_experiment(arch):
    """Return the plan entry for training architecture `arch` on MNIST."""
    return {
        'arch': arch,
        'dataset': 'mnist',
        'checkpoint_steps': [100, 1000, 2000],
        'total_steps': 2000,
    }


# Experiment name -> settings. Every experiment shares the same schedule;
# only the architecture differs.
CHECKPOINT_PLAN = {
    'transformer_deep_mnist': _mnist_experiment('transformer_deep'),
    'cnn_deep_mnist': _mnist_experiment('cnn_deep'),
    'mlp_narrow_mnist': _mnist_experiment('mlp_narrow'),
}

# Create output directories. parents=True lets this cell succeed even when
# the /content root does not exist yet.
for output_dir in ('/content/checkpoints', '/content/results'):
    Path(output_dir).mkdir(parents=True, exist_ok=True)

# Derive the counts from the plan so the printed summary cannot drift from it.
checkpoints_per_experiment = [
    len(plan_entry['checkpoint_steps']) for plan_entry in CHECKPOINT_PLAN.values()
]
distinct_counts = sorted(set(checkpoints_per_experiment))

print("Checkpoint plan loaded:")
print(f"Total experiments: {len(CHECKPOINT_PLAN)}")
print(f"Checkpoints per experiment: {', '.join(str(count) for count in distinct_counts)}")
print(f"Total checkpoints: {sum(checkpoints_per_experiment)}")

## 3. Architecture Definitions

In [None]:
# MLP Architecture
class SimpleMLP(nn.Module):
    """Fully connected classifier.

    One Linear + ReLU pair is created per entry of `hidden_dims`, followed by
    a final Linear layer that produces `num_classes` outputs. Inputs are
    flattened to (batch, -1) before entering the stack.
    """

    def __init__(self, input_dim: int, hidden_dims: List[int], num_classes: int):
        super().__init__()
        widths = [input_dim, *hidden_dims]
        modules = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(nn.ReLU())
        # Output head; widths[-1] is input_dim when there are no hidden layers
        modules.append(nn.Linear(widths[-1], num_classes))
        self.network = nn.Sequential(*modules)

    def forward(self, x):
        flat = x.view(x.shape[0], -1)
        return self.network(flat)

# CNN Architecture
class SimpleCNN(nn.Module):
    """Convolutional classifier for square images.

    Each entry of `conv_channels` adds a Conv2d(kernel_size=3, padding=1) +
    ReLU + MaxPool2d(2) stage. The convolution keeps the spatial size and the
    pooling halves it (rounding down), so the flattened feature size is
    derived from the number of stages instead of being fixed.

    Args:
        in_channels: Number of channels in the input images.
        num_classes: Number of output logits.
        conv_channels: Output channels of each conv stage, in order.
        input_size: Side length of the (square) input images. Defaults to 28.

    Raises:
        ValueError: If the pooling stages would shrink the feature map below
            1x1 for the given `input_size`.
    """

    def __init__(self, in_channels: int, num_classes: int, conv_channels: List[int],
                 input_size: int = 28):
        super().__init__()
        layers = []
        prev_channels = in_channels
        for channels in conv_channels:
            layers.extend([
                nn.Conv2d(prev_channels, channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2)
            ])
            prev_channels = channels
        self.conv_layers = nn.Sequential(*layers)

        # Every pooling stage floors the side length in half,
        # e.g. three stages on 28x28: 28 -> 14 -> 7 -> 3.
        spatial = input_size // (2 ** len(conv_channels))
        if spatial < 1:
            raise ValueError(
                f"{len(conv_channels)} pooling stages reduce a "
                f"{input_size}x{input_size} input below 1x1"
            )
        # prev_channels is the last stage's width (or in_channels with no stages)
        self.flat_size = prev_channels * spatial * spatial
        self.fc = nn.Linear(self.flat_size, num_classes)

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

# Transformer Architecture
class SimpleTransformer(nn.Module):
    """Transformer-encoder classifier over a flattened input split into tokens.

    The input is reshaped to (batch, seq_len, input_dim // seq_len), each
    token is projected to `d_model`, a learned positional parameter is added,
    the sequence is encoded, mean-pooled over the token axis and classified.

    Args:
        input_dim: Number of input features per sample after flattening.
        d_model: Width of the token embeddings.
        nhead: Number of attention heads per encoder layer.
        num_layers: Number of encoder layers.
        num_classes: Number of output logits.
        seq_len: Number of tokens each sample is split into.

    Raises:
        ValueError: If `seq_len` is not positive or does not divide
            `input_dim` evenly.
    """

    def __init__(self, input_dim: int, d_model: int, nhead: int, 
                 num_layers: int, num_classes: int, seq_len: int = 16):
        super().__init__()
        # Fail here rather than later in forward(): with a remainder the
        # projection would be built for a truncated token width and the
        # reshape in forward() would not match it.
        if seq_len <= 0 or input_dim % seq_len != 0:
            raise ValueError(
                f"input_dim ({input_dim}) must be divisible by a positive "
                f"seq_len ({seq_len})"
            )
        self.seq_len = seq_len
        self.input_proj = nn.Linear(input_dim // seq_len, d_model)
        # Learned positional parameter, one d_model vector per token position
        self.pos_encoder = nn.Parameter(torch.randn(seq_len, d_model))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_model*4, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        batch_size = x.size(0)
        # Split each flattened sample into seq_len equally sized tokens
        x = x.view(batch_size, self.seq_len, -1)
        x = self.input_proj(x)
        x = x + self.pos_encoder.unsqueeze(0)
        x = self.transformer(x)
        # Mean-pool over the token axis before classifying
        x = x.mean(dim=1)
        return self.fc(x)

def create_model(arch_name: str, input_dim: int, num_classes: int):
    """Instantiate the network registered under `arch_name`.

    Recognised names are 'mlp_narrow', 'cnn_deep' and 'transformer_deep';
    any other name raises ValueError.
    """
    if arch_name == 'mlp_narrow':
        hidden_widths = [32, 32, 32, 32]
        return SimpleMLP(input_dim, hidden_widths, num_classes)

    if arch_name == 'cnn_deep':
        mnist_channels = 1  # MNIST has 1 channel
        return SimpleCNN(mnist_channels, num_classes, [32, 64, 128])

    if arch_name == 'transformer_deep':
        return SimpleTransformer(
            input_dim,
            d_model=128,
            nhead=4,
            num_layers=4,
            num_classes=num_classes,
        )

    raise ValueError(f"Unknown architecture: {arch_name}")

print("✓ Architectures defined")

## 4. Data Loading

In [None]:
def get_mnist_loaders(batch_size=64):
    """Build DataLoaders for the MNIST train and test splits.

    Both splits are stored under /content/data (downloaded if needed),
    converted to tensors and normalized. Only the training loader shuffles.

    Returns:
        (train_loader, test_loader, input_dim, num_classes), where input_dim
        is 28*28 and num_classes is 10.
    """
    norm_mean, norm_std = (0.1307,), (0.3081,)
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std)
    ])

    def load_split(is_train):
        # Same root and preprocessing for both splits; only `train` differs
        return torchvision.datasets.MNIST(
            root='/content/data', train=is_train, download=True, transform=preprocess
        )

    train_split = load_split(True)
    test_split = load_split(False)

    loaders = (
        DataLoader(train_split, batch_size=batch_size, shuffle=True),
        DataLoader(test_split, batch_size=batch_size, shuffle=False),
    )
    input_dim, num_classes = 28*28, 10
    return (*loaders, input_dim, num_classes)

print("✓ Data loading ready")

## 5. Training with Checkpoints

In [None]:
def _global_grad_norm(model):
    """Return the L2 norm over all gradients currently stored on `model`."""
    squared_sum = sum(p.grad.norm().item() ** 2 for p in model.parameters()
                      if p.grad is not None)
    return squared_sum ** 0.5


def _save_checkpoint(checkpoint_dir, step, model, optimizer, loss_value, experiment_name):
    """Write model/optimizer state for `step` to checkpoint_step_<step>.pt."""
    checkpoint_path = checkpoint_dir / f'checkpoint_step_{step:05d}.pt'
    torch.save({
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss_value,
        'experiment_name': experiment_name
    }, checkpoint_path)
    print(f"\n✓ Saved checkpoint: {checkpoint_path.name}")


def _evaluate_accuracy(model, data_loader, device):
    """Return the fraction of samples in `data_loader` that `model` classifies correctly."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    return correct / total


def train_with_checkpoints(model, train_loader, test_loader, device, 
                           num_steps, checkpoint_steps, checkpoint_dir, 
                           experiment_name):
    """Train model and save checkpoints at specified steps.

    The model is trained with Adam (lr=0.001) and cross-entropy loss for
    `num_steps` optimizer updates, cycling through `train_loader` as often as
    needed, and is then evaluated once on `test_loader`.

    A checkpoint labelled N holds the state after exactly N completed
    updates, so N == num_steps captures the final model and N == 0 captures
    the untrained model (its stored 'loss' is None). Steps outside
    0..num_steps can never be reached and are reported with a warning.

    Args:
        model: Network to train; it is moved to `device`.
        train_loader: Iterable of (inputs, labels) training batches.
        test_loader: Iterable of (inputs, labels) evaluation batches.
        device: Device on which training and evaluation run.
        num_steps: Total number of optimizer updates to perform.
        checkpoint_steps: Completed-update counts at which to save.
        checkpoint_dir: Existing directory (Path) receiving the .pt files.
        experiment_name: Label used in logs, checkpoints and the result.

    Returns:
        Dict with 'experiment_name', 'measurements', 'final_accuracy',
        'total_steps' and 'checkpoint_steps'. 'measurements' has one entry
        for every fifth update, keyed by the 0-based update index, holding
        the loss and gradient norm computed before that update was applied.

    Raises:
        ValueError: If `train_loader` yields no batches while updates remain.
        ZeroDivisionError: If `test_loader` yields no samples.
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    checkpoint_steps_set = set(checkpoint_steps)
    measurements = []
    step = 0  # number of optimizer updates completed so far

    print(f"\nTraining {experiment_name}...")
    print(f"Target steps: {num_steps}")
    print(f"Checkpoint steps: {checkpoint_steps}")

    unreachable = sorted(s for s in checkpoint_steps_set if s < 0 or s > num_steps)
    if unreachable:
        print(f"⚠ Checkpoint steps outside 0..{num_steps} will not be saved: {unreachable}")

    # Step 0 means "before any update": snapshot the initial weights.
    if 0 in checkpoint_steps_set:
        _save_checkpoint(checkpoint_dir, 0, model, optimizer, None, experiment_name)

    model.train()
    # The context manager closes the progress bar even if training raises.
    with tqdm(total=num_steps, desc=experiment_name) as pbar:
        while step < num_steps:
            steps_at_epoch_start = step
            for inputs, labels in train_loader:
                if step >= num_steps:
                    break

                inputs, labels = inputs.to(device), labels.to(device)

                # Training step
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                loss_value = loss.item()

                # Gradient norm
                grad_norm = _global_grad_norm(model)

                optimizer.step()
                update_index = step  # 0-based index of the update just applied
                step += 1

                # Save checkpoint once exactly `step` updates are complete.
                # Comparing before the increment would skip step == num_steps
                # and store every checkpoint one update late.
                if step in checkpoint_steps_set:
                    _save_checkpoint(checkpoint_dir, step, model, optimizer,
                                     loss_value, experiment_name)

                # Record measurement every 5 steps
                if update_index % 5 == 0:
                    measurements.append({
                        'step': update_index,
                        'loss': loss_value,
                        'grad_norm': grad_norm
                    })

                pbar.update(1)
                pbar.set_postfix({'loss': f'{loss_value:.4f}'})

            # An epoch without progress means the loader is empty; stop
            # instead of looping forever.
            if step == steps_at_epoch_start:
                raise ValueError(
                    f"train_loader yielded no batches; stopped at step {step} of {num_steps}"
                )

    # Final evaluation
    final_accuracy = _evaluate_accuracy(model, test_loader, device)
    print(f"\n✓ Final accuracy: {final_accuracy:.4f}")

    return {
        'experiment_name': experiment_name,
        'measurements': measurements,
        'final_accuracy': final_accuracy,
        'total_steps': num_steps,
        'checkpoint_steps': checkpoint_steps
    }

print("✓ Training function ready")

## 6. Run All Experiments

In [None]:
# Load data once
print("Loading MNIST dataset...")
train_loader, test_loader, input_dim, num_classes = get_mnist_loaders(batch_size=64)
print(f"✓ Data loaded: {len(train_loader.dataset)} train, {len(test_loader.dataset)} test\n")

# Run all experiments
all_results = []
checkpoints_on_disk = 0
start_time = time.time()

for exp_name, config in CHECKPOINT_PLAN.items():
    print("\n" + "="*70)
    print(f"EXPERIMENT: {exp_name}")
    print("="*70)

    # Create checkpoint directory
    checkpoint_dir = Path(f'/content/checkpoints/{exp_name}')
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Create model
    model = create_model(config['arch'], input_dim, num_classes)
    print(f"Model: {config['arch']}")
    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Train
    exp_start = time.time()
    result = train_with_checkpoints(
        model=model,
        train_loader=train_loader,
        test_loader=test_loader,
        device=device,
        num_steps=config['total_steps'],
        checkpoint_steps=config['checkpoint_steps'],
        checkpoint_dir=checkpoint_dir,
        experiment_name=exp_name
    )
    exp_elapsed = time.time() - exp_start

    result['elapsed_time'] = exp_elapsed
    all_results.append(result)

    # Count the planned checkpoint files that exist on disk instead of
    # assuming a fixed number per experiment. Files left over from an
    # earlier run in the same directory are counted too.
    checkpoints_on_disk += sum(
        (checkpoint_dir / f'checkpoint_step_{step:05d}.pt').exists()
        for step in config['checkpoint_steps']
    )

    # Save result
    result_file = Path(f'/content/results/{exp_name}_result.json')
    with open(result_file, 'w', encoding='utf-8') as f:
        json.dump(result, f, indent=2)

    print(f"\n✓ Completed in {exp_elapsed/60:.1f} minutes")
    print(f"✓ Result saved: {result_file}")

total_time = time.time() - start_time

# Summary
print("\n" + "="*70)
print("ALL EXPERIMENTS COMPLETE")
print("="*70)
print(f"Total time: {total_time/60:.1f} minutes ({total_time/3600:.1f} hours)")
print(f"Experiments: {len(all_results)}")
print(f"Checkpoints saved: {checkpoints_on_disk}")
print("\nAccuracies:")
for result in all_results:
    print(f"  {result['experiment_name']}: {result['final_accuracy']:.4f}")

# Save summary
summary = {
    'total_time': total_time,
    'num_experiments': len(all_results),
    'experiments': all_results
}
with open('/content/results/summary.json', 'w', encoding='utf-8') as f:
    json.dump(summary, f, indent=2)

print("\n✓ Summary saved: /content/results/summary.json")

## 7. Verify Checkpoints

In [None]:
import os

print("\nVERIFYING CHECKPOINTS:")
print("="*70)

# Planned checkpoint files that were not found on disk
missing_checkpoints = []

for exp_name, config in CHECKPOINT_PLAN.items():
    checkpoint_dir = Path(f'/content/checkpoints/{exp_name}')
    print(f"\n{exp_name}:")
    total_size = 0
    for step in config['checkpoint_steps']:
        checkpoint_path = checkpoint_dir / f'checkpoint_step_{step:05d}.pt'
        if checkpoint_path.exists():
            size_mb = checkpoint_path.stat().st_size / 1e6
            total_size += size_mb
            print(f"  ✓ Step {step}: {size_mb:.1f}MB")
        else:
            missing_checkpoints.append(checkpoint_path)
            print(f"  ✗ Step {step}: MISSING")
    print(f"  Total: {total_size:.1f}MB")

# Total checkpoint size. Walk the tree once so the size and the count
# describe the same set of files.
checkpoint_files = list(Path('/content/checkpoints').rglob('*.pt'))
all_checkpoints_size = sum(f.stat().st_size for f in checkpoint_files) / 1e6
print(f"\nTotal checkpoints size: {all_checkpoints_size:.1f}MB")
print(f"Checkpoint count: {len(checkpoint_files)}")

# Make an incomplete run obvious without scrolling through every experiment
if missing_checkpoints:
    print(f"\n✗ {len(missing_checkpoints)} planned checkpoint(s) missing")

## 8. Package Results for Download

In [None]:
# Create zip file
!zip -r /content/checkpoint_experiments_results.zip /content/checkpoints /content/results

# Describe the archive and how to fetch it, one printed line per entry
download_banner = (
    "\n" + "="*70,
    "READY FOR DOWNLOAD",
    "="*70,
    "\nFile: checkpoint_experiments_results.zip",
    "\nContents:",
    "  - 9 checkpoint files (.pt)",
    "  - 3 experiment result files (.json)",
    "  - 1 summary file (summary.json)",
    "\nDownload instructions:",
    "  1. Click the folder icon on the left",
    "  2. Right-click 'checkpoint_experiments_results.zip'",
    "  3. Select 'Download'",
    "\nOr run the next cell to download directly:",
)
for banner_line in download_banner:
    print(banner_line)

In [None]:
# Download the results archive directly; needs the google.colab module,
# so this cell only works inside a Colab runtime.
from google.colab import files
files.download('/content/checkpoint_experiments_results.zip')

## Next Steps

After downloading the results:

1. **Extract the zip file** to your local machine
2. **Upload to Claude Code environment**:
   - Place checkpoints in: `experiments/checkpoints/`
   - Place results in: `experiments/new/results/phase1_checkpoints/`
3. **Run Modified Phase 2 Analysis**:
   ```bash
   cd experiments/mechanistic_interpretability
   python3 modified_phase2_analysis.py
   ```

The analysis will:
- Load the 9 checkpoints
- Extract features (attention/filters/activations)
- Compute similarity matrices
- Test the hypothesis that features learned early in training are qualitatively different from those learned late
- Generate visualizations and final report