# Continue Training to Step 2000

**Purpose**: Load step 1000 checkpoints and continue training to step 2000

**Runtime**: ~15 minutes on an A100 GPU

**Prerequisites**: 
- Upload the existing `checkpoint_experiments_results.zip` to `/content/` in Colab
- Section 2 extracts it to provide the step 1000 checkpoints

**Output**: 3 additional checkpoints (step 2000 for each experiment)

---

## Setup Instructions

1. **Runtime**: Set to GPU (Runtime → Change runtime type → GPU → A100)
2. **Upload**: Upload `checkpoint_experiments_results.zip` to /content/
3. **Execute cells in order**
4. **Download**: All checkpoints (the existing ones plus the new step 2000 ones) and the result files are zipped at the end

## 1. Environment Setup

In [None]:
# Install dependencies (IPython shell magic; runs pip in a subshell).
# NOTE(review): scikit-learn is installed but never imported in this notebook — confirm it is still needed.
!pip install -q torch torchvision tqdm scikit-learn

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import numpy as np
from tqdm import tqdm
import json
import time
from pathlib import Path
from typing import Dict, List

# Check GPU: use CUDA when PyTorch can see a device, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

if device.type == 'cuda':
    # Name and total memory (decimal gigabytes) of GPU 0.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 2. Extract Existing Checkpoints

In [None]:
# Extract the uploaded zip file
!unzip -q /content/checkpoint_experiments_results.zip -d /content/

print("✓ Checkpoints extracted")
print("\nExisting checkpoints:")
!ls -lh /content/checkpoints/transformer_deep_mnist/
!ls -lh /content/checkpoints/cnn_deep_mnist/
!ls -lh /content/checkpoints/mlp_narrow_mnist/

## 3. Architecture Definitions

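(Same architectures as in the original training run. Module names and layer sizes must match exactly, or `load_state_dict` will fail when the step 1000 checkpoints are loaded.)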

In [None]:
# MLP Architecture
class SimpleMLP(nn.Module):
    """Fully connected classifier: flatten -> [Linear -> ReLU] per hidden width -> Linear.

    All layers live in one ``nn.Sequential`` stored as ``self.network``. That
    layout defines the ``state_dict`` keys, so it must not change if saved
    checkpoints are to load.
    """

    def __init__(self, input_dim: int, hidden_dims: List[int], num_classes: int):
        super().__init__()
        widths = [input_dim, *hidden_dims]
        blocks = []
        # One Linear+ReLU pair per consecutive (fan_in, fan_out) width pair.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            blocks += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Final projection from the last width (or input_dim if no hidden layers).
        blocks.append(nn.Linear(widths[-1], num_classes))
        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        # Collapse every non-batch dimension into a single feature vector.
        batch_size = x.size(0)
        return self.network(x.view(batch_size, -1))

# CNN Architecture
class SimpleCNN(nn.Module):
    """Stack of [Conv3x3(padding=1) -> ReLU -> MaxPool2d(2)] blocks followed by a Linear head.

    Args:
        in_channels: channels of the input images.
        num_classes: number of output logits.
        conv_channels: output channels of each conv block, in order.
        input_size: height/width of the square input images. Defaults to 28
            (MNIST), which with three blocks gives the original 3x3 feature map.

    Raises:
        ValueError: if the pooling stages shrink the feature map to zero size.
    """

    def __init__(self, in_channels: int, num_classes: int, conv_channels: List[int],
                 input_size: int = 28):
        super().__init__()
        layers = []
        prev_channels = in_channels
        spatial = input_size
        for channels in conv_channels:
            layers.extend([
                nn.Conv2d(prev_channels, channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2)
            ])
            prev_channels = channels
            # The 3x3 conv with padding=1 keeps the size; MaxPool2d(2) floors it by half.
            spatial //= 2
        if spatial < 1:
            raise ValueError(
                f"{len(conv_channels)} pooling stages reduce a {input_size}x{input_size} "
                "input to nothing; use fewer conv blocks or a larger input_size"
            )
        self.conv_layers = nn.Sequential(*layers)
        # Derived from the real depth and input size. The former hard-coded
        # `conv_channels[-1] * 3 * 3` only held for exactly three blocks on
        # 28x28 inputs.
        self.flat_size = prev_channels * spatial * spatial
        self.fc = nn.Linear(self.flat_size, num_classes)

    def forward(self, x):
        x = self.conv_layers(x)
        # Flatten (batch, C, H, W) feature maps for the linear head.
        x = x.view(x.size(0), -1)
        return self.fc(x)

# Transformer Architecture
class SimpleTransformer(nn.Module):
    """Transformer-encoder classifier that treats each flattened input as a token sequence.

    The ``input_dim`` features are split into ``seq_len`` equal chunks (tokens).
    Each token is projected to ``d_model`` and a learned positional embedding is
    added. The sequence then passes through ``num_layers`` encoder layers, and
    the token outputs are mean-pooled before the final Linear head.
    """

    def __init__(self, input_dim: int, d_model: int, nhead: int,
                 num_layers: int, num_classes: int, seq_len: int = 16):
        super().__init__()
        self.seq_len = seq_len
        features_per_token = input_dim // seq_len
        self.input_proj = nn.Linear(features_per_token, d_model)
        # Learned positional embedding, one row per token position.
        self.pos_encoder = nn.Parameter(torch.randn(seq_len, d_model))
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=4 * d_model,
                batch_first=True,
            ),
            num_layers,
        )
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        # (batch, ...) -> (batch, seq_len, input_dim // seq_len)
        tokens = x.view(x.size(0), self.seq_len, -1)
        hidden = self.input_proj(tokens) + self.pos_encoder.unsqueeze(0)
        # Average over the sequence axis (dim 1, since batch_first=True).
        pooled = self.transformer(hidden).mean(dim=1)
        return self.fc(pooled)

def create_model(arch_name: str, input_dim: int, num_classes: int):
    """Instantiate the network registered under ``arch_name``.

    Args:
        arch_name: one of ``'mlp_narrow'``, ``'cnn_deep'`` or ``'transformer_deep'``.
        input_dim: flattened input size (used by the MLP and Transformer).
        num_classes: number of output logits.

    Raises:
        ValueError: if ``arch_name`` is not a known architecture.
    """
    # Lazily-built constructors so only the requested model is instantiated.
    builders = {
        'mlp_narrow': lambda: SimpleMLP(input_dim, [32, 32, 32, 32], num_classes),
        'cnn_deep': lambda: SimpleCNN(1, num_classes, [32, 64, 128]),
        'transformer_deep': lambda: SimpleTransformer(
            input_dim, d_model=128, nhead=4, num_layers=4, num_classes=num_classes
        ),
    }
    for name, build in builders.items():
        if arch_name == name:
            return build()
    raise ValueError(f"Unknown architecture: {arch_name}")


print("✓ Architectures defined")

## 4. Load MNIST Data

In [None]:
def get_mnist_loaders(batch_size=64):
    """Return MNIST train/test ``DataLoader``s plus the flattened input size and class count.

    Both splits are stored under ``/content/data`` and downloaded on first use.
    They are converted to tensors and normalised with mean 0.1307 / std 0.3081.
    Only the training loader shuffles.

    Returns:
        ``(train_loader, test_loader, 784, 10)``.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    def load_split(is_train):
        return torchvision.datasets.MNIST(
            root='/content/data', train=is_train, download=True, transform=preprocess
        )

    train_split = load_split(True)
    test_split = load_split(False)

    return (
        DataLoader(train_split, batch_size=batch_size, shuffle=True),
        DataLoader(test_split, batch_size=batch_size, shuffle=False),
        28 * 28,
        10,
    )


print("Loading MNIST dataset...")
train_loader, test_loader, input_dim, num_classes = get_mnist_loaders(batch_size=64)
print(
    f"✓ Data loaded: {len(train_loader.dataset)} train, "
    f"{len(test_loader.dataset)} test"
)

## 5. Continue Training Function

**Note**: The step counter is restored from the step 1000 checkpoint. Training continues until step 2000, where the new checkpoint is saved.

In [None]:
def _global_grad_norm(model):
    """Return the combined L2 norm of all parameter gradients as a Python float.

    Parameters whose ``.grad`` is None are skipped.
    """
    return sum(
        p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None
    ) ** 0.5


def _evaluate_accuracy(model, loader, device):
    """Return top-1 accuracy of ``model`` over ``loader`` (eval mode, no gradients).

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("test_loader yielded no samples; cannot compute accuracy")
    return correct / total


def _train_until(model, optimizer, criterion, train_loader, device,
                 start_step, target_step, desc, log_every):
    """Apply optimizer updates, counting from ``start_step``, until ``target_step`` is reached.

    Returns ``(measurements, last_loss)``. ``last_loss`` is None when no update ran.

    Raises:
        ValueError: if ``train_loader`` yields no batches (otherwise this would loop forever).
    """
    measurements = []
    last_loss = None
    step = start_step
    model.train()
    # The context manager closes the progress bar even if a step raises.
    with tqdm(total=target_step - start_step, desc=desc) as pbar:
        while step < target_step:
            saw_batch = False
            for inputs, labels in train_loader:
                if step >= target_step:
                    break
                saw_batch = True
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()
                loss = criterion(model(inputs), labels)
                loss.backward()

                record = step % log_every == 0
                # Only compute the norm when it is recorded: every .item() call
                # synchronises with the device.
                grad_norm = _global_grad_norm(model) if record else None

                optimizer.step()
                last_loss = loss.item()

                if record:
                    measurements.append({
                        'step': step,
                        'loss': last_loss,
                        'grad_norm': grad_norm
                    })

                step += 1
                pbar.update(1)
                pbar.set_postfix({'loss': f'{last_loss:.4f}'})
            if not saw_batch:
                raise ValueError("train_loader yielded no batches")
    return measurements, last_loss


def continue_training_to_2000(checkpoint_path, model, train_loader, test_loader,
                               device, experiment_name, target_step=2000, log_every=5):
    """Resume training from a saved checkpoint and save a new one at ``target_step``.

    The checkpoint must be a dict with ``'step'``, ``'loss'``,
    ``'model_state_dict'`` and ``'optimizer_state_dict'``. ``'step'`` is treated
    as the number of optimizer updates already applied, so exactly
    ``target_step - step`` more updates are run. The saved checkpoint is
    labelled ``target_step``.

    Args:
        checkpoint_path: checkpoint to resume from (e.g. the step 1000 file).
        model: freshly constructed model whose architecture matches the checkpoint.
        train_loader: iterable of ``(inputs, labels)`` batches; re-iterated as needed.
        test_loader: iterable of ``(inputs, labels)`` batches for the final accuracy.
        device: torch device to train and evaluate on.
        experiment_name: progress-bar label, also stored in the new checkpoint.
        target_step: step count at which to stop and save (default 2000).
        log_every: record loss and gradient norm every ``log_every`` steps.

    Returns:
        dict with ``experiment_name``, ``measurements`` (list of
        ``{'step', 'loss', 'grad_norm'}``), ``final_accuracy`` (test accuracy
        in [0, 1]) and ``checkpoint_path`` (str path of the saved checkpoint).

    Raises:
        ValueError: if the checkpoint is already past ``target_step`` or a loader is empty.
    """
    print(f"\nLoading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)

    model = model.to(device)
    model.load_state_dict(checkpoint['model_state_dict'])

    criterion = nn.CrossEntropyLoss()
    # This lr is overwritten: load_state_dict restores the saved param groups.
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    start_step = checkpoint['step']
    print(f"✓ Loaded checkpoint from step {start_step}")
    print(f"✓ Starting loss: {checkpoint['loss']:.4f}")
    if start_step > target_step:
        raise ValueError(
            f"Checkpoint is already at step {start_step}, past target_step={target_step}"
        )

    # Run exactly target_step - start_step updates (1000 for 1000 -> 2000).
    # Looping while `step <= target_step` would perform one extra update.
    print(f"\nContinuing training from step {start_step} to {target_step}...")
    measurements, last_loss = _train_until(
        model, optimizer, criterion, train_loader, device,
        start_step, target_step, experiment_name, log_every,
    )
    if last_loss is None:
        # start_step == target_step: nothing was trained, so keep the checkpoint's loss.
        last_loss = checkpoint['loss']

    checkpoint_dir = Path(checkpoint_path).parent
    new_checkpoint_path = checkpoint_dir / f'checkpoint_step_{target_step:05d}.pt'

    torch.save({
        'step': target_step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': last_loss,
        'experiment_name': experiment_name
    }, new_checkpoint_path)

    print(f"\n✓ Saved checkpoint: {new_checkpoint_path.name}")

    final_accuracy = _evaluate_accuracy(model, test_loader, device)
    print(f"✓ Final accuracy at step {target_step}: {final_accuracy:.4f}")

    return {
        'experiment_name': experiment_name,
        'measurements': measurements,
        'final_accuracy': final_accuracy,
        'checkpoint_path': str(new_checkpoint_path)
    }

print("✓ Continue training function ready")

## 6. Run All Continuations

In [None]:
# Experiment configurations: checkpoint sub-directory name -> create_model() architecture key.
EXPERIMENTS = {
    'transformer_deep_mnist': 'transformer_deep',
    'cnn_deep_mnist': 'cnn_deep',
    'mlp_narrow_mnist': 'mlp_narrow'
}

# Fail fast: check every step 1000 checkpoint exists before any training starts.
# Otherwise a missing file would only surface after earlier experiments had trained.
step_1000_paths = {
    name: Path(f'/content/checkpoints/{name}/checkpoint_step_01000.pt')
    for name in EXPERIMENTS
}
missing_checkpoints = [str(p) for p in step_1000_paths.values() if not p.is_file()]
if missing_checkpoints:
    raise FileNotFoundError(f"Missing step 1000 checkpoint(s): {missing_checkpoints}")

all_results = []
start_time = time.time()

for exp_name, arch_name in EXPERIMENTS.items():
    print("\n" + "="*70)
    print(f"CONTINUING: {exp_name}")
    print("="*70)
    
    # Create model
    model = create_model(arch_name, input_dim, num_classes)
    print(f"Model: {arch_name}")
    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    # Path to step 1000 checkpoint
    checkpoint_path = step_1000_paths[exp_name]
    
    # Continue training
    exp_start = time.time()
    result = continue_training_to_2000(
        checkpoint_path=checkpoint_path,
        model=model,
        train_loader=train_loader,
        test_loader=test_loader,
        device=device,
        experiment_name=exp_name
    )
    exp_elapsed = time.time() - exp_start
    
    result['elapsed_time'] = exp_elapsed
    all_results.append(result)
    
    print(f"\n✓ Completed in {exp_elapsed/60:.1f} minutes")

total_time = time.time() - start_time

# Summary
print("\n" + "="*70)
print("ALL CONTINUATIONS COMPLETE")
print("="*70)
print(f"Total time: {total_time/60:.1f} minutes")
print(f"New checkpoints created: {len(all_results)}")
print("\nFinal accuracies at step 2000:")
for result in all_results:
    print(f"  {result['experiment_name']}: {result['final_accuracy']:.4f}")

# Save summary
summary = {
    'total_time': total_time,
    'num_experiments': len(all_results),
    'experiments': all_results
}
summary_path = Path('/content/results/continuation_summary.json')
# /content/results only exists if the uploaded zip contained it. Create it so
# the write cannot fail after all of the training above has already run.
summary_path.parent.mkdir(parents=True, exist_ok=True)
with open(summary_path, 'w') as f:
    json.dump(summary, f, indent=2)

print("\n✓ Summary saved: /content/results/continuation_summary.json")

## 7. Verify New Checkpoints

In [None]:
import os

print("\nVERIFYING NEW CHECKPOINTS:")
print("="*70)

# List each experiment's checkpoint files with their sizes (MB = 1e6 bytes).
for exp_name in EXPERIMENTS:
    print(f"\n{exp_name}:")
    exp_dir = Path(f'/content/checkpoints/{exp_name}')
    for ckpt in sorted(exp_dir.glob('*.pt')):
        print(f"  ✓ {ckpt.name}: {ckpt.stat().st_size / 1e6:.1f}MB")

# Whole-tree count: 3 experiments x 3 checkpoints (steps 100, 1000, 2000).
expected_total = 9
found_total = len(list(Path('/content/checkpoints').rglob('*.pt')))
print(f"\nTotal checkpoints: {found_total} (expected: {expected_total})")

if found_total != expected_total:
    print(f"\n⚠️ WARNING: Expected {expected_total} checkpoints, found {found_total}")
else:
    print("\n✅ SUCCESS: All 9 checkpoints are now present!")

## 8. Package Complete Results for Download

In [None]:
# Create new zip with ALL checkpoints (including the 3 new step 2000 ones)
!zip -r /content/checkpoint_experiments_COMPLETE.zip /content/checkpoints /content/results

print("\n" + "="*70)
print("READY FOR DOWNLOAD")
print("="*70)
print("\nFile: checkpoint_experiments_COMPLETE.zip")
print("\nContents:")
print("  - 9 checkpoint files (.pt) - NOW COMPLETE!")
print("    • 3 experiments × 3 checkpoints each")
print("    • Steps: 100, 1000, 2000")
print("  - 4 result files (.json)")
print("\nDownload instructions:")
print("  1. Click the folder icon on the left")
print("  2. Right-click 'checkpoint_experiments_COMPLETE.zip'")
print("  3. Select 'Download'")
print("\nOr run the next cell to download directly:")

In [None]:
from google.colab import files

# Colab-only helper: send the packaged archive to the browser as a file download.
files.download(
    '/content/checkpoint_experiments_COMPLETE.zip'
)

## Next Steps

After downloading the complete results:

1. **Upload the complete zip file** to the Claude Code environment
2. **Extract to the experiments directory**
3. **Run Modified Phase 2 Analysis**:
   ```bash
   cd experiments/mechanistic_interpretability
   python3 modified_phase2_analysis.py
   ```

The analysis will now have all 9 checkpoints and can:
- Compare early (step 100) vs mid (step 1000) vs late (step 2000) features
- Test hypothesis: Are early features qualitatively different from late?
- Generate complete visualizations and final report

**Total runtime for this continuation**: ~15 minutes (vs 90 minutes for full re-run)