# Vision Transformer Training on MNIST
## Sparse Autoencoder Interpretability - Phase 1

This notebook trains a 4-layer ViT on MNIST with activation capture for Phase 2 SAE training.

**Key outputs:**
- Trained model weights saved in `./checkpoints/best_model.pt`
- A trained model whose Layer 3 post-MLP activations can be captured for SAE training in Phase 2
- Memory profiling throughout training

## Setup & Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import time
from pathlib import Path
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

# Record the runtime environment in the notebook output.
cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    # Device index 0 is the GPU being described.
    gpu_index = 0
    print(f"CUDA device: {torch.cuda.get_device_name(gpu_index)}")
    print(f"CUDA capability: {torch.cuda.get_device_capability(gpu_index)}")

## Import Custom Modules

In [None]:
# Import from local modules (ensure they're in same directory or in path)
from config import Config, ModelConfig, TrainingConfig
from model import ViT
from data import get_dataloaders, inspect_batch
from utils import (
    set_seed, gpu_memory_report, save_checkpoint, load_checkpoint,
    cleanup_old_checkpoints, get_device, count_parameters
)

print("✓ All modules imported successfully")

## Configuration

In [None]:
# Instantiate the configuration with its default values.
config = Config()

# Optional: Override settings here
# config.training.num_epochs = 20
# config.training.batch_size = 64

model_cfg = config.model
training_cfg = config.training

# Summarise the model settings; square dimensions are shown as "NxN".
print("Model Config:")
for label, attr in (("Image size", "image_size"), ("Patch size", "patch_size")):
    side = getattr(model_cfg, attr)
    print(f"  {label}: {side}x{side}")
for label, attr in (
    ("Num patches", "num_patches"),
    ("Hidden dim", "hidden_dim"),
    ("Num layers", "num_layers"),
    ("Num heads", "num_heads"),
):
    print(f"  {label}: {getattr(model_cfg, attr)}")

# Summarise the training settings.
print("\nTraining Config:")
for label, attr in (
    ("Batch size", "batch_size"),
    ("Learning rate", "learning_rate"),
    ("Num epochs", "num_epochs"),
    ("Mixed precision", "use_mixed_precision"),
    ("Warmup epochs", "warmup_epochs"),
):
    print(f"  {label}: {getattr(training_cfg, attr)}")

## Setup: Device, Seed, Model

In [None]:
# Fix the seed before anything random (such as model construction) happens.
set_seed(config.training.seed)
device = get_device()

# Build the ViT and place it on the selected device.
print("\nCreating ViT model...")
model = ViT(config.model).to(device)

# count_parameters returns a (total, trainable) pair.
param_counts = count_parameters(model)
total_params, trainable_params = param_counts

gpu_memory_report("After model creation")

## Data Loading

In [None]:
# Build the train and validation loaders from the shared config.
print("Loading MNIST data...")
loaders = get_dataloaders(config)
train_loader, val_loader = loaders

# Show what a training batch looks like before any training starts.
print("\nInspecting first batch...")
inspect_batch(train_loader, config)

gpu_memory_report("After data loading")

## Sanity Check: Overfit to 10 Samples

In [None]:
# Sanity check: a correctly wired model should be able to memorise 10 samples.
print("Running sanity check: overfitting to 10 samples...\n")

# Throwaway model and optimizer; the real run builds its own model later.
model_test = ViT(config.model).to(device)
optimizer_test = optim.Adam(model_test.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Use the first 10 samples of a single training batch.
images_test, labels_test = next(iter(train_loader))
images_test = images_test[:10].to(device)
labels_test = labels_test[:10].to(device)

# Train for 50 iterations on the same 10 samples.
losses = []
for i in range(50):
    optimizer_test.zero_grad()
    logits, _ = model_test(images_test)
    loss = criterion(logits, labels_test)
    loss.backward()
    optimizer_test.step()
    loss_value = loss.item()
    losses.append(loss_value)

    if (i + 1) % 10 == 0:
        acc = (logits.argmax(dim=1) == labels_test).float().mean().item()
        print(f"Iter {i+1:3d}: loss={loss_value:.4f}, acc={acc:.4f}")

# Report success only if the loss actually went down. A NaN loss compares
# False here, so a diverged run is reported as a failure too.
if losses[-1] < losses[0]:
    print(f"\n✓ Overfitting successful! Loss decreased from {losses[0]:.4f} to {losses[-1]:.4f}")
    print("  Model architecture is working correctly.")
else:
    print(f"\n✗ Overfitting check FAILED: loss went from {losses[0]:.4f} to {losses[-1]:.4f}")
    print("  Check the model, loss and data pipeline before starting the full run.")

# Clean up
del model_test, optimizer_test
torch.cuda.empty_cache()

## Training Functions

In [None]:
def create_optimizer(model, config):
    """Build the Adam optimizer for ``model``.

    The learning rate and weight decay are read from ``config.training``
    and applied to every parameter returned by ``model.parameters()``.
    """
    training_cfg = config.training
    return optim.Adam(
        model.parameters(),
        lr=training_cfg.learning_rate,
        weight_decay=training_cfg.weight_decay,
    )


def create_lr_scheduler(optimizer, config, total_steps):
    """Build the cosine scheduler and compute the warmup length in steps.

    Returns a ``(scheduler, warmup_steps)`` pair:

    * ``scheduler`` is a ``CosineAnnealingWarmRestarts`` whose cycle length
      ``T_0`` equals ``config.training.num_epochs`` (no cycle growth,
      floor of ``1e-6``). It performs no warmup itself.
    * ``warmup_steps`` is ``warmup_epochs`` expressed as a share of
      ``total_steps``, truncated to an int.
    """
    training_cfg = config.training
    epochs = training_cfg.num_epochs

    # Fraction of the run spent warming up, converted to a step count.
    warmup_steps = int(training_cfg.warmup_epochs * total_steps / epochs)

    cosine = CosineAnnealingWarmRestarts(optimizer, T_0=epochs, T_mult=1, eta_min=1e-6)
    return cosine, warmup_steps


def warmup_lr(optimizer, step, warmup_steps, base_lr):
    """Set a linearly ramped learning rate while still inside the warmup window.

    While ``step < warmup_steps`` every param group's ``lr`` is set to
    ``base_lr * step / warmup_steps``. Outside the window the optimizer is
    left untouched.
    """
    if not step < warmup_steps:
        return

    ramped_lr = base_lr * (step / warmup_steps)
    for group in optimizer.param_groups:
        group['lr'] = ramped_lr


def train_one_epoch(model, train_loader, optimizer, criterion, device, config, epoch, scaler=None):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Module called as ``model(images)``; must return a
            ``(logits, extra)`` pair, of which only ``logits`` is used.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(logits, labels)``. The epoch
            average assumes it returns the mean over the batch (the default
            reduction of ``nn.CrossEntropyLoss``).
        device: Device the batches are moved to.
        config: Reads ``config.training.use_mixed_precision`` and
            ``config.training.max_norm_grad_clip``.
        epoch: Zero-based epoch index, used only for the progress-bar label.
        scaler: Optional gradient scaler. The fp16 autocast path is taken only
            when mixed precision is enabled AND a scaler is supplied.

    Returns:
        ``(avg_loss, avg_accuracy)`` as Python floats, averaged per sample so
        a smaller final batch does not skew the epoch metrics.

    Raises:
        ZeroDivisionError: If ``train_loader`` yields no samples.
    """
    model.train()
    training_cfg = config.training
    use_amp = training_cfg.use_mixed_precision and scaler is not None

    loss_sum = 0.0
    correct = 0
    seen = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]", leave=False)

    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()

        # Forward and backward pass
        if use_amp:
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                logits, _ = model(images)
                loss = criterion(logits, labels)
            scaler.scale(loss).backward()
            # Unscale first so clipping below sees the true gradient norm.
            scaler.unscale_(optimizer)
        else:
            logits, _ = model(images)
            loss = criterion(logits, labels)
            loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), training_cfg.max_norm_grad_clip)

        # Optimizer step
        if use_amp:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()

        # Accumulate per-sample totals rather than per-batch means.
        batch_size = labels.size(0)
        batch_loss = loss.item()
        batch_correct = (logits.argmax(dim=1) == labels).sum().item()

        loss_sum += batch_loss * batch_size
        correct += batch_correct
        seen += batch_size

        pbar.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'acc': f'{batch_correct / batch_size:.4f}',
        })

    avg_loss = loss_sum / seen
    avg_accuracy = correct / seen

    return avg_loss, avg_accuracy


@torch.no_grad()
def validate(model, val_loader, criterion, device, config, epoch):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: Module called as ``model(images)``; must return a
            ``(logits, extra)`` pair, of which only ``logits`` is used.
        val_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``. The average
            assumes it returns the mean over the batch (the default reduction
            of ``nn.CrossEntropyLoss``).
        device: Device the batches are moved to (``torch.device`` or string).
        config: Reads ``config.training.use_mixed_precision``.
        epoch: Zero-based epoch index, used only for the progress-bar label.

    Returns:
        ``(avg_loss, avg_accuracy)`` as Python floats, averaged per sample so
        a smaller final batch does not skew the result.

    Raises:
        ZeroDivisionError: If ``val_loader`` yields no samples.
    """
    model.eval()

    # Enter fp16 CUDA autocast only when the run is actually on CUDA, which
    # matches the training path and skips it on CPU-only runs.
    use_amp = config.training.use_mixed_precision and torch.device(device).type == 'cuda'

    loss_sum = 0.0
    correct = 0
    seen = 0

    pbar = tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]", leave=False)

    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        if use_amp:
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                logits, _ = model(images)
                loss = criterion(logits, labels)
        else:
            logits, _ = model(images)
            loss = criterion(logits, labels)

        # Accumulate per-sample totals rather than per-batch means.
        batch_size = labels.size(0)
        batch_loss = loss.item()
        batch_correct = (logits.argmax(dim=1) == labels).sum().item()

        loss_sum += batch_loss * batch_size
        correct += batch_correct
        seen += batch_size

        pbar.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'acc': f'{batch_correct / batch_size:.4f}',
        })

    avg_loss = loss_sum / seen
    avg_accuracy = correct / seen

    # Drop cached allocator blocks left over from the validation pass.
    torch.cuda.empty_cache()

    return avg_loss, avg_accuracy

print("✓ Training functions defined")

## Setup Optimizer & Scheduler

In [None]:
# Re-instantiate the model so the real run starts from newly initialised weights.
model = ViT(config.model).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = create_optimizer(model, config)

# Total number of optimizer steps across the whole run.
steps_per_epoch = len(train_loader)
total_steps = steps_per_epoch * config.training.num_epochs
scheduler, warmup_steps = create_lr_scheduler(optimizer, config, total_steps)

# A gradient scaler is only created for CUDA runs with mixed precision
# requested; otherwise it stays None and training runs in full precision.
amp_enabled = config.training.use_mixed_precision and device.type == 'cuda'
scaler = torch.cuda.amp.GradScaler() if amp_enabled else None
if amp_enabled:
    print("✓ Mixed precision training enabled")

print(f"✓ Optimizer: Adam (lr={config.training.learning_rate}, weight_decay={config.training.weight_decay})")
print(f"✓ Scheduler: Cosine annealing (warmup_steps={warmup_steps})")

## Training Loop

In [None]:
print("\n" + "="*60)
print("TRAINING START")
print("="*60 + "\n")

gpu_memory_report("Initial")

best_val_accuracy = 0.0
training_history = {
    'train_loss': [],
    'train_accuracy': [],
    'val_loss': [],
    'val_accuracy': [],
}

num_epochs = config.training.num_epochs
warmup_epochs = config.training.warmup_epochs
base_lr = config.training.learning_rate

# Linear warmup, applied per epoch: warmup epoch e (0-based) trains at
# base_lr * (e + 1) / (warmup_epochs + 1). The first epoch therefore never
# runs at lr=0 and the ramp stays below base_lr. When warmup_epochs == 0
# this call leaves the optimizer untouched.
warmup_lr(optimizer, 1, warmup_epochs + 1, base_lr)

training_start = time.time()

for epoch in range(num_epochs):
    epoch_start = time.time()

    # Training
    train_loss, train_acc = train_one_epoch(
        model, train_loader, optimizer, criterion, device, config, epoch, scaler
    )

    # Set the learning rate for the NEXT epoch. The scheduler must step first:
    # scheduler.step() rewrites every param group's lr, so a warmup value
    # written before it would be discarded and warmup would never take effect.
    scheduler.step()
    if epoch + 1 < warmup_epochs:
        warmup_lr(optimizer, epoch + 2, warmup_epochs + 1, base_lr)

    # Validation
    validated = (epoch + 1) % config.training.validate_frequency == 0
    if validated:
        val_loss, val_acc = validate(model, val_loader, criterion, device, config, epoch)
    else:
        # Placeholders keep the history lists aligned with the epoch index.
        val_loss, val_acc = 0.0, 0.0

    # Track metrics
    training_history['train_loss'].append(train_loss)
    training_history['train_accuracy'].append(train_acc)
    training_history['val_loss'].append(val_loss)
    training_history['val_accuracy'].append(val_acc)

    epoch_time = time.time() - epoch_start

    # Decide on "best" before building the metrics so both checkpoints saved
    # this epoch carry the updated best accuracy.
    is_best = validated and val_acc > best_val_accuracy
    if is_best:
        best_val_accuracy = val_acc
    metrics = {
        'best_val_accuracy': best_val_accuracy,
        'history': training_history,
    }

    # Print summary
    print(f"\nEpoch {epoch+1}/{num_epochs} ({epoch_time:.1f}s)")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    if validated:
        # Use the explicit flag rather than `val_acc > 0`, so a real
        # validation result of exactly 0.0 is still reported.
        print(f"  Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.4f}")

    # Save best checkpoint
    if is_best:
        save_checkpoint(
            model, optimizer, epoch, metrics, config,
            config.training.checkpoint_dir, best=True
        )
        print(f"  ✓ New best checkpoint saved!")

    # Save regular checkpoint
    save_checkpoint(
        model, optimizer, epoch, metrics, config,
        config.training.checkpoint_dir, best=False
    )

    # Cleanup old checkpoints
    cleanup_old_checkpoints(
        config.training.checkpoint_dir,
        keep_best=True,
        keep_last_n=config.training.keep_last_n
    )

    gpu_memory_report(f"End of Epoch {epoch+1}")

total_time = time.time() - training_start
print("\n" + "="*60)
print("TRAINING COMPLETE")
print("="*60)
print(f"Total time: {total_time/60:.1f} minutes")
print(f"Best validation accuracy: {best_val_accuracy:.4f}")

gpu_memory_report("Final")

## Plot Training History

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Epochs on which validation was skipped are stored in the history as 0.0
# placeholders. Plot only the epochs that were actually validated so the
# curves do not dip to zero on skipped epochs.
validate_every = config.training.validate_frequency
val_epochs = [
    i for i in range(len(training_history['val_loss']))
    if (i + 1) % validate_every == 0
]

# (axis, history key suffix, legend suffix, y-axis label, title) per panel.
panels = (
    (axes[0], 'loss', 'Loss', 'Loss', 'Training & Validation Loss'),
    (axes[1], 'accuracy', 'Acc', 'Accuracy', 'Training & Validation Accuracy'),
)
for ax, key, legend_suffix, ylabel, title in panels:
    train_values = training_history[f'train_{key}']
    val_values = training_history[f'val_{key}']
    ax.plot(train_values, label=f'Train {legend_suffix}', marker='o', markersize=3)
    ax.plot(
        val_epochs,
        [val_values[i] for i in val_epochs],
        label=f'Val {legend_suffix}',
        marker='s',
        markersize=3,
    )
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
# Save before show().
plt.savefig('./training_curves.png', dpi=100, bbox_inches='tight')
plt.show()

print("✓ Training curves saved to training_curves.png")

## Checkpoint Verification

In [None]:
# List the checkpoint files left on disk by training, with their sizes.
checkpoint_dir = Path(config.training.checkpoint_dir)
if not checkpoint_dir.exists():
    print("No checkpoints directory found")
else:
    checkpoints = list(checkpoint_dir.glob('*.pt'))
    print(f"Checkpoints saved: {len(checkpoints)}")
    # Sorted so the listing order is stable between runs.
    for ckpt in sorted(checkpoints):
        size_bytes = ckpt.stat().st_size
        size_mb = size_bytes / 1e6
        print(f"  {ckpt.name}: {size_mb:.1f} MB")

## Load Best Model & Test Inference

In [None]:
# Restore the best checkpoint and push one validation batch through it.
best_checkpoint_path = checkpoint_dir / 'best_model.pt'

if not best_checkpoint_path.exists():
    print("No best model checkpoint found")
else:
    print(f"Loading best model from {best_checkpoint_path}...")
    model_loaded = ViT(config.model).to(device)
    metadata = load_checkpoint(best_checkpoint_path, model_loaded, device=device)

    # Smoke test: a single batch taken from the validation loader.
    model_loaded.eval()
    first_batch = next(iter(val_loader))
    images_test, labels_test = first_batch
    images_test = images_test.to(device)
    labels_test = labels_test.to(device)

    with torch.no_grad():
        logits, _ = model_loaded(images_test)
        preds = logits.argmax(dim=1)
        accuracy = (preds == labels_test).float().mean()

    # NOTE(review): assumes load_checkpoint returns a mapping with 'epoch'
    # and 'best_val_accuracy' keys — confirm against utils.load_checkpoint.
    print(f"\n✓ Best model loaded and tested")
    print(f"  Test accuracy on batch: {accuracy:.4f}")
    print(f"  Best epoch: {metadata['epoch']}")
    print(f"  Best validation accuracy: {metadata['best_val_accuracy']:.4f}")

## Next Steps: Activation Capture for SAE Training

The trained model is ready for Phase 2. To extract Layer 3 post-MLP activations:

```python
# Load model with activation capture
model_capture = ViT(config.model, capture_layer=3).to(device)
model_capture.load_state_dict(torch.load('checkpoints/best_model.pt')['model_state_dict'])
model_capture.eval()

# Extract activations from a batch
with torch.no_grad():
    logits, activations = model_capture(images)  # activations shape: (batch, 17, 192)
```

**Key info for ViT Prisma integration:**
- Checkpoint path: `./checkpoints/best_model.pt`
- Activation shape: `(batch_size, num_patches+1, hidden_dim)` = `(batch, 17, 192)`
- Layer 3 post-MLP output is ready to train SAE on
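- Validation accuracy: see the `best_val_accuracy` value printed at the end of the training loop (Markdown cells do not evaluate f-string placeholders such as `{best_val_accuracy:.1%}`)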