# Lab 2.1.6: Checkpointing System - Never Lose Your Progress

**Module:** 2.1 - Deep Learning with PyTorch  
**Time:** 2 hours  
**Difficulty:** ⭐⭐ (Intermediate)

---

## Learning Objectives

By the end of this notebook, you will:
- [ ] Save and load complete training state (model + optimizer + scheduler)
- [ ] Implement best model tracking
- [ ] Create early stopping logic
- [ ] Test interrupt and resume functionality
- [ ] Handle optimizer state correctly

---

## Prerequisites

- Completed: Tasks 6.1-6.5
- Knowledge of: Training loops, file I/O

---

## Real-World Context

Training large models can take days or weeks. You NEED checkpointing to:
- **Survive crashes**: Power outage? Network issue? No problem.
- **Resume training**: Continue from where you left off
- **Track best models**: Save only the models that improved
- **Early stopping**: Stop when validation stops improving
- **Experiment management**: Compare different runs

Every major ML framework (PyTorch Lightning, Hugging Face, etc.) includes robust checkpointing!

---

## ELI5: What is Checkpointing?

> **Imagine you're playing a video game...** 🎮
>
> - **Save game**: Store your progress (level, items, score)
> - **Load game**: Continue from your last save
> - **Autosave**: Game saves automatically at checkpoints
> - **Best save**: Keep a save from your highest score
>
> **Neural network checkpointing is exactly the same!**
> - **Model weights** = your character's equipment
> - **Optimizer state** = your skill tree progress
> - **Epoch number** = current level
> - **Best accuracy** = high score

---

## Part 1: Environment Setup

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

import os
import json
from pathlib import Path
from datetime import datetime
from typing import Dict, Any, Optional
from dataclasses import dataclass, asdict

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

In [None]:
# Simple model for demonstration
class SimpleNet(nn.Module):
    """
    Simple CNN for CIFAR-10 classification.
    
    Architecture:
        - 2 convolutional blocks with max pooling
        - 2 fully connected layers
        - ~1.07M parameters (about 1.05M of them in the first fully connected layer)
    
    Args:
        num_classes: Number of output classes (default: 10 for CIFAR-10)
    """
    
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(64 * 8 * 8, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        )
    
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

# Data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testloader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

print(f"Model parameters: {sum(p.numel() for p in SimpleNet().parameters()):,}")

---

## Part 2: Basic Checkpointing

The simplest form of checkpointing: save and load model weights.

In [None]:
# Create checkpoint directory
checkpoint_dir = Path('./checkpoints')
checkpoint_dir.mkdir(exist_ok=True)

# Create and initialize model
model = SimpleNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

print("=== Before Training ===")
print(f"First conv weight sum: {model.features[0].weight.sum().item():.4f}")

In [None]:
# Train for a few batches
criterion = nn.CrossEntropyLoss()
model.train()

for i, (inputs, labels) in enumerate(trainloader):
    if i >= 10:
        break
    inputs, labels = inputs.to(device), labels.to(device)
    
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

print("=== After Training ===")
print(f"First conv weight sum: {model.features[0].weight.sum().item():.4f}")

In [None]:
# Basic save: Just model weights
# This is the SIMPLEST form - only saves model parameters

# Save
torch.save(model.state_dict(), checkpoint_dir / 'model_weights_only.pth')
print("Saved model weights")

# Load into a new model
model_loaded = SimpleNet().to(device)
model_loaded.load_state_dict(torch.load(checkpoint_dir / 'model_weights_only.pth'))

print(f"\nLoaded model weight sum: {model_loaded.features[0].weight.sum().item():.4f}")
print(f"Original model weight sum: {model.features[0].weight.sum().item():.4f}")
print(f"Weights match: {torch.allclose(model.features[0].weight, model_loaded.features[0].weight)}")

### Why Just Weights Isn't Enough

Saving only the weights is enough for inference, but to resume training you also lose:
- **Optimizer state**: Momentum buffers and adaptive learning-rate statistics (Adam's first and second moment estimates, m and v)
- **Training progress**: Epoch number, step count
- **Scheduler state**: Position in the learning rate schedule
- **Best metrics**: The best validation score seen so far
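For Adam, the lost optimizer state is concrete. After the first `step()`, each parameter gets `exp_avg` (m), `exp_avg_sq` (v) and a `step` counter inside `optimizer.state_dict()['state']`. The sketch below uses a toy `nn.Linear` instead of `SimpleNet` so it runs on its own; the exact set of keys may vary with the PyTorch version and options such as `amsgrad`.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(4, 2)  # toy stand-in for SimpleNet
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# One training step makes Adam allocate its per-parameter buffers
x, y = torch.randn(8, 4), torch.randn(8, 2)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()

opt_state = optimizer.state_dict()
# 'state' maps parameter index -> buffers; 'param_groups' holds lr, betas, eps, ...
for idx, buffers in opt_state['state'].items():
    print(idx, sorted(buffers.keys()))
print("lr in param_groups:", opt_state['param_groups'][0]['lr'])

# A freshly built optimizer (all you have if only weights were saved) starts empty,
# so its next update behaves like step 1 of a brand-new run
fresh = optim.Adam(model.parameters(), lr=1e-3)
print("fresh optimizer buffers:", len(fresh.state_dict()['state']))
```

Because Adam's bias correction depends on `step`, restarting with an empty optimizer typically causes a burst of large updates right after resuming.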

---

## Part 3: Complete Training State

Save everything needed to resume training where it stopped: model, optimizer, scheduler, and training progress.

In [None]:
def save_checkpoint(
    model: nn.Module,
    optimizer: optim.Optimizer,
    epoch: int,
    loss: float,
    accuracy: float,
    path: Path,
    scheduler: Optional[Any] = None,
    extra_info: Optional[Dict] = None,
):
    """
    Save complete training checkpoint.
    
    This saves everything needed to resume training exactly
    where we left off.
    """
    checkpoint = {
        # Model state
        'model_state_dict': model.state_dict(),
        
        # Optimizer state (CRUCIAL for Adam, SGD with momentum, etc.)
        'optimizer_state_dict': optimizer.state_dict(),
        
        # Training progress
        'epoch': epoch,
        'loss': loss,
        'accuracy': accuracy,
        
        # Metadata
        'timestamp': datetime.now().isoformat(),
        'pytorch_version': torch.__version__,
    }
    
    # Optional: scheduler state
    if scheduler is not None:
        checkpoint['scheduler_state_dict'] = scheduler.state_dict()
    
    # Optional: extra info
    if extra_info:
        checkpoint['extra_info'] = extra_info
    
    torch.save(checkpoint, path)
    print(f"Checkpoint saved: {path}")


def load_checkpoint(
    path: Path,
    model: nn.Module,
    optimizer: Optional[optim.Optimizer] = None,
    scheduler: Optional[Any] = None,
    device: torch.device = None,
) -> Dict[str, Any]:
    """
    Load training checkpoint and restore state.
    
    Returns:
        Dictionary with training info (epoch, loss, accuracy, etc.)
    """
    # Load checkpoint
    checkpoint = torch.load(path, map_location=device)
    
    # Restore model
    model.load_state_dict(checkpoint['model_state_dict'])
    
    # Restore optimizer (if provided)
    if optimizer is not None and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    # Restore scheduler (if provided)
    if scheduler is not None and 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    
    print(f"Checkpoint loaded: {path}")
    print(f"  Epoch: {checkpoint['epoch']}")
    print(f"  Loss: {checkpoint['loss']:.4f}")
    print(f"  Accuracy: {checkpoint['accuracy']:.2f}%")
    
    return checkpoint

In [None]:
# Test complete checkpointing
model = SimpleNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# Simulate some training
for epoch in range(3):
    for i, (inputs, labels) in enumerate(trainloader):
        if i >= 5:
            break
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
    scheduler.step()

# Save checkpoint
save_checkpoint(
    model=model,
    optimizer=optimizer,
    epoch=3,
    loss=loss.item(),
    accuracy=75.5,  # Simulated
    path=checkpoint_dir / 'complete_checkpoint.pth',
    scheduler=scheduler,
    extra_info={'batch_size': 64, 'learning_rate': 0.001}
)

In [None]:
# Load checkpoint into fresh model
new_model = SimpleNet().to(device)
new_optimizer = optim.Adam(new_model.parameters(), lr=0.001)
new_scheduler = optim.lr_scheduler.StepLR(new_optimizer, step_size=5, gamma=0.1)

checkpoint_info = load_checkpoint(
    checkpoint_dir / 'complete_checkpoint.pth',
    new_model,
    new_optimizer,
    new_scheduler,
    device
)

# Verify optimizer state was restored
print(f"\nOptimizer state keys: {list(new_optimizer.state_dict()['state'].keys())}")
print(f"Scheduler last_epoch: {new_scheduler.last_epoch}")

---

## Part 4: Best Model Tracking & Early Stopping

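Always save the latest checkpoint, keep a separate copy whenever the validation metric improves, and stop training once it has not improved for `patience` consecutive epochs.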
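The early-stopping rule on its own is only a few lines: remember the best value, reset a counter on improvement, and stop once the counter reaches `patience`. Here is a minimal sketch of just that logic, assuming a hypothetical `EarlyStopping` class (not part of PyTorch) and made-up validation losses:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class EarlyStopping:
    """Patience-based early stopping (illustrative sketch)."""
    patience: int = 3
    mode: str = 'min'             # 'min' for loss, 'max' for accuracy
    best: Optional[float] = None  # best metric seen so far
    counter: int = 0              # epochs since the last improvement

    def step(self, metric: float) -> bool:
        """Record one epoch's metric and return True if training should stop."""
        if self.best is None:
            improved = True
        elif self.mode == 'min':
            improved = metric < self.best
        else:
            improved = metric > self.best

        if improved:
            self.best = metric
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience


# Simulated validation losses: improves once, then plateaus
stopper = EarlyStopping(patience=3, mode='min')
val_losses = [1.00, 0.80, 0.85, 0.82, 0.81, 0.70]
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if stopper.step(loss):
        stopped_at = epoch
        break

print(f"Stopped after epoch {stopped_at}, best loss {stopper.best}")
```

Training stops at epoch 4, so the later 0.70 is never seen. That is the trade-off `patience` controls: a larger value costs more compute but gives a slow plateau more chances to break.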

In [None]:
@dataclass
class CheckpointManager:
    """
    Manages checkpoints with best model tracking and early stopping.
    
    Features:
    - Save best model based on validation metric
    - Keep last N checkpoints
    - Early stopping when metric doesn't improve
    """
    
    checkpoint_dir: Path
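    best_metric: Optional[float] = None  # None -> +inf for 'min', -inf for 'max'
    best_epoch: int = 0
    patience: int = 5
    patience_counter: int = 0
    mode: str = 'min'  # 'min' for loss, 'max' for accuracy
    max_checkpoints: int = 3
    
    def __post_init__(self):
        if self.mode not in ('min', 'max'):
            raise ValueError(f"mode must be 'min' or 'max', got {self.mode!r}")
        self.checkpoint_dir = Path(self.checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.checkpoint_history = []
        
        # Start from the worst possible value for the chosen mode,
        # unless a best_metric was passed in explicitly (e.g. when resuming)
        if self.best_metric is None:
            self.best_metric = float('inf') if self.mode == 'min' else float('-inf')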
    
    def _is_better(self, metric: float) -> bool:
        """Check if metric is better than best."""
        if self.mode == 'min':
            return metric < self.best_metric
        else:
            return metric > self.best_metric
    
    def save_checkpoint(
        self,
        model: nn.Module,
        optimizer: optim.Optimizer,
        epoch: int,
        metric: float,
        scheduler: Optional[Any] = None,
    ) -> bool:
        """
        Save the latest and per-epoch checkpoints; also save the best checkpoint if the metric improved.
        
        Returns:
            True if this was the best model
        """
        is_best = self._is_better(metric)
        
        # Always save latest (best_metric includes this epoch's result)
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
            'metric': metric,
            'best_metric': metric if is_best else self.best_metric,
        }
        if scheduler is not None:
            checkpoint['scheduler_state_dict'] = scheduler.state_dict()
        
        # Save latest checkpoint
        latest_path = self.checkpoint_dir / 'checkpoint_latest.pth'
        torch.save(checkpoint, latest_path)
        
        # Save epoch checkpoint
        epoch_path = self.checkpoint_dir / f'checkpoint_epoch_{epoch:04d}.pth'
        torch.save(checkpoint, epoch_path)
        self.checkpoint_history.append(epoch_path)
        
        # Remove old checkpoints
        while len(self.checkpoint_history) > self.max_checkpoints:
            old_path = self.checkpoint_history.pop(0)
            if old_path.exists():
                old_path.unlink()
        
        # Update best if improved
        if is_best:
            self.best_metric = metric
            self.best_epoch = epoch
            self.patience_counter = 0
            
            best_path = self.checkpoint_dir / 'checkpoint_best.pth'
            torch.save(checkpoint, best_path)
            print(f"New best! Epoch {epoch + 1}: {metric:.4f}")
        else:
            self.patience_counter += 1
        
        return is_best
    
    def should_stop(self) -> bool:
        """Check if training should stop (early stopping)."""
        return self.patience_counter >= self.patience
    
    def load_best(self, model: nn.Module, optimizer: Optional[optim.Optimizer] = None):
        """Load the best checkpoint."""
        best_path = self.checkpoint_dir / 'checkpoint_best.pth'
        if not best_path.exists():
            raise FileNotFoundError(f"Best checkpoint not found: {best_path}")
        
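        # Load onto whichever device the model is on (avoids GPU -> CPU load failures)
        checkpoint = torch.load(best_path, map_location=next(model.parameters()).device)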
        model.load_state_dict(checkpoint['model_state_dict'])
        if optimizer:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        
        print(f"Loaded best checkpoint from epoch {checkpoint['epoch']}")
        return checkpoint
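`checkpoint_history` in the class above lives only in memory. After a restart, the manager forgets the older `checkpoint_epoch_*.pth` files, so they are never deleted. One way around this is to derive the rotation list from disk each time. The sketch below assumes a hypothetical `rotate_checkpoints` helper, which is not part of the class, and relies on the zero-padded epoch numbers the class already uses:

```python
import tempfile
from pathlib import Path


def rotate_checkpoints(checkpoint_dir, max_checkpoints: int,
                       pattern: str = 'checkpoint_epoch_*.pth') -> list:
    """Keep only the newest `max_checkpoints` files matching `pattern`.

    Relies on zero-padded epoch numbers (epoch_0007, epoch_0012, ...) so that
    lexicographic order matches numeric order. Returns the names that were kept.
    """
    files = sorted(Path(checkpoint_dir).glob(pattern))
    stale = files[:-max_checkpoints] if max_checkpoints > 0 else files
    for path in stale:
        path.unlink()
    return [p.name for p in files[len(stale):]]


# Demo: six epoch checkpoints left over from a previous (interrupted) run
with tempfile.TemporaryDirectory() as tmp:
    tmp_dir = Path(tmp)
    for epoch in range(6):
        (tmp_dir / f'checkpoint_epoch_{epoch:04d}.pth').touch()
    (tmp_dir / 'checkpoint_best.pth').touch()  # doesn't match the pattern, never rotated

    kept = rotate_checkpoints(tmp_dir, max_checkpoints=3)
    remaining = sorted(p.name for p in tmp_dir.iterdir())

print(kept)
print(remaining)
```

Calling a helper like this after each epoch save would replace the in-memory `while len(self.checkpoint_history) > ...` loop, at the cost of one directory listing per epoch.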

In [None]:
# Training with checkpoint manager
def evaluate(model, dataloader, criterion):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    return total_loss / len(dataloader), 100 * correct / total


# Setup
model = SimpleNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)
criterion = nn.CrossEntropyLoss()

# Checkpoint manager tracking validation loss
ckpt_manager = CheckpointManager(
    checkpoint_dir=checkpoint_dir / 'experiment_1',
    patience=5,
    mode='min',  # Lower loss is better
    max_checkpoints=3,
)

# Training loop
print("Starting training with checkpointing...\n")
NUM_EPOCHS = 15

for epoch in range(NUM_EPOCHS):
    # Train
    model.train()
    train_loss = 0
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(trainloader)
    
    # Evaluate
    val_loss, val_acc = evaluate(model, testloader, criterion)
    
    # Update scheduler
    scheduler.step(val_loss)
    
    # Save checkpoint
    is_best = ckpt_manager.save_checkpoint(
        model, optimizer, epoch, val_loss, scheduler
    )
    
    status = " (best)" if is_best else ""
    print(f"Epoch {epoch+1:2d} | Train Loss: {train_loss:.4f} | "
          f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%{status}")
    
    # Early stopping
    if ckpt_manager.should_stop():
        print(f"\nEarly stopping at epoch {epoch+1}!")
        break

print(f"\nBest model: Epoch {ckpt_manager.best_epoch + 1} with loss {ckpt_manager.best_metric:.4f}")

---

## Part 5: Resuming Training

Test that we can resume training from a checkpoint.

In [None]:
# Simulate resuming from checkpoint
print("=== Resuming Training ===")

# Create fresh model, optimizer, and scheduler with the same settings as the original run
resumed_model = SimpleNet().to(device)
resumed_optimizer = optim.Adam(resumed_model.parameters(), lr=0.001)
resumed_scheduler = optim.lr_scheduler.ReduceLROnPlateau(resumed_optimizer, patience=2)

# Load latest checkpoint onto the current device
latest_path = checkpoint_dir / 'experiment_1' / 'checkpoint_latest.pth'
checkpoint = torch.load(latest_path, map_location=device)

resumed_model.load_state_dict(checkpoint['model_state_dict'])
resumed_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if 'scheduler_state_dict' in checkpoint:
    resumed_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
start_epoch = checkpoint['epoch'] + 1  # 'epoch' is the last completed epoch (0-indexed)

print(f"Resumed after epoch {checkpoint['epoch'] + 1}")
print(f"Continuing with epoch {start_epoch + 1}")

# Verify optimizer state
print(f"\nOptimizer state restored: {len(resumed_optimizer.state) > 0}")

# Continue training for a few more (shortened) epochs: 10 batches each
for epoch in range(start_epoch, start_epoch + 3):
    resumed_model.train()
    for i, (inputs, labels) in enumerate(trainloader):
        if i >= 10:
            break
        inputs, labels = inputs.to(device), labels.to(device)
        resumed_optimizer.zero_grad()
        loss = criterion(resumed_model(inputs), labels)
        loss.backward()
        resumed_optimizer.step()
    
    val_loss, val_acc = evaluate(resumed_model, testloader, criterion)
    resumed_scheduler.step(val_loss)  # keep the LR schedule in sync with the resumed run
    print(f"Epoch {epoch+1} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

print("\nTraining resumed successfully!")
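The cell above confirms that training *continues* after a reload, but not that it continues *identically*. A stricter test compares an interrupted run against an uninterrupted one. The sketch below trains a toy model for 6 steps straight, and separately for 3 steps, saves to an in-memory buffer, rebuilds everything from a different seed, loads, and finishes the last 3 steps. The helpers `make_run` and `train_on` are hypothetical. Exact equality is expected here on CPU with fixed batches. With a shuffling `DataLoader` or nondeterministic GPU kernels, runs can drift slightly, because RNG state is not part of these checkpoints.

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim


def make_run(seed: int = 0):
    """Build a toy model, optimizer and scheduler (stand-ins for SimpleNet & co.)."""
    torch.manual_seed(seed)
    model = nn.Linear(4, 1)
    optimizer = optim.Adam(model.parameters(), lr=0.1)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)
    return model, optimizer, scheduler


def train_on(model, optimizer, scheduler, batches):
    """One optimizer step and one scheduler step per batch."""
    for x, y in batches:
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
        scheduler.step()


torch.manual_seed(123)
batches = [(torch.randn(16, 4), torch.randn(16, 1)) for _ in range(6)]

# Reference: six uninterrupted steps
ref_model, ref_opt, ref_sched = make_run()
train_on(ref_model, ref_opt, ref_sched, batches)

# Interrupted run: three steps, checkpoint to an in-memory buffer, then "crash"
model, opt, sched = make_run()
train_on(model, opt, sched, batches[:3])
buffer = io.BytesIO()
torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
            'scheduler_state_dict': sched.state_dict()}, buffer)
del model, opt, sched

# Resume with deliberately different initial weights, then restore everything
model, opt, sched = make_run(seed=999)
buffer.seek(0)
ckpt = torch.load(buffer, weights_only=True)
model.load_state_dict(ckpt['model_state_dict'])
opt.load_state_dict(ckpt['optimizer_state_dict'])
sched.load_state_dict(ckpt['scheduler_state_dict'])
train_on(model, opt, sched, batches[3:])

match = all(torch.equal(a, b) for a, b in
            zip(ref_model.state_dict().values(), model.state_dict().values()))
print(f"Resumed run matches uninterrupted run: {match}")
```

Dropping any of the three `load_state_dict` calls (try removing the optimizer one) should make `match` come out `False`.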

---

## Part 6: Production Checkpoint Manager

A complete, production-ready checkpoint manager.
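One of its features is the atomic save: write to a temporary file and then move it into place, so a crash mid-write cannot leave a truncated `latest.pth` or `best.pth`. A more defensive version of the same idea is sketched here. It also flushes the data to disk before the swap and uses `os.replace`, which overwrites an existing target on every platform. The name `atomic_torch_save` is hypothetical.

```python
import os
import tempfile
from pathlib import Path

import torch


def atomic_torch_save(obj, path) -> None:
    """Save `obj` so that `path` always holds either the old or the new file.

    Writes to a temp file in the *same* directory (so the final swap stays on
    one filesystem), flushes it to disk, then swaps it into place with
    os.replace, which is atomic on POSIX filesystems.
    """
    path = Path(path)
    fd, tmp_name = tempfile.mkstemp(dir=path.parent, prefix=path.name + '.', suffix='.tmp')
    try:
        with os.fdopen(fd, 'wb') as f:
            torch.save(obj, f)
            f.flush()
            os.fsync(f.fileno())  # make sure the bytes hit the disk before the swap
        os.replace(tmp_name, path)
    except BaseException:
        # Clean up the partial temp file, then re-raise the original error
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / 'latest.pth'
    atomic_torch_save({'epoch': 1}, target)
    atomic_torch_save({'epoch': 2}, target)  # overwrites the existing file
    loaded = torch.load(target, weights_only=True)
    leftovers = [p.name for p in Path(tmp).iterdir() if p.suffix == '.tmp']

print(loaded['epoch'], leftovers)
```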

In [None]:
class ProductionCheckpointManager:
    """
    Production-ready checkpoint manager.
    
    Features:
    - Best model tracking
    - Early stopping
    - Checkpoint rotation
    - Training log persistence
    - Atomic saves (prevents corruption)
    """
    
    def __init__(
        self,
        checkpoint_dir: str,
        model: nn.Module,
        optimizer: optim.Optimizer,
        scheduler: Optional[Any] = None,
        mode: str = 'min',
        patience: int = 10,
        max_checkpoints: int = 5,
    ):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        
        self.mode = mode
        self.patience = patience
        self.max_checkpoints = max_checkpoints
        
        # State
        self.best_metric = float('inf') if mode == 'min' else float('-inf')
        self.best_epoch = 0
        self.patience_counter = 0
        self.history = []
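        # Rebuild the rotation list from disk so old files are still rotated after a restart
        # (zero-padded names sort in epoch order)
        self.checkpoint_files = sorted(self.checkpoint_dir.glob('epoch_*.pth'))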
        
        # Load existing history if resuming
        self._load_history()
    
    def _load_history(self):
        """Load training history if exists."""
        history_path = self.checkpoint_dir / 'training_history.json'
        if history_path.exists():
            with open(history_path, 'r') as f:
                self.history = json.load(f)
            print(f"Loaded {len(self.history)} epochs of history")
    
    def _save_history(self):
        """Save training history."""
        history_path = self.checkpoint_dir / 'training_history.json'
        with open(history_path, 'w') as f:
            json.dump(self.history, f, indent=2)
    
    def _is_better(self, metric: float) -> bool:
        if self.mode == 'min':
            return metric < self.best_metric
        return metric > self.best_metric
    
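    def _save_checkpoint(self, path: Path, checkpoint: dict):
        """Atomic save to prevent corruption."""
        # Save to a temp file in the same directory first
        temp_path = path.with_suffix('.tmp')
        torch.save(checkpoint, temp_path)
        # Path.replace overwrites an existing target (Path.rename raises on Windows
        # if the target exists); the swap is atomic on POSIX within one filesystem
        temp_path.replace(path)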
    
    def step(
        self,
        epoch: int,
        train_loss: float,
        val_loss: float,
        val_metric: float,
        extra_metrics: Optional[Dict] = None,
    ) -> Dict[str, Any]:
        """
        Call at end of each epoch.
        
        Returns:
            Dict with 'is_best', 'should_stop' flags
        """
        is_best = self._is_better(val_metric)
        
        # Update best/patience state *before* building the checkpoint,
        # so a resumed run restores the correct values
        if is_best:
            self.best_metric = val_metric
            self.best_epoch = epoch
            self.patience_counter = 0
        else:
            self.patience_counter += 1
        
        # Build checkpoint
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'val_metric': val_metric,
            'best_metric': self.best_metric,
            'best_epoch': self.best_epoch,
            'patience_counter': self.patience_counter,
        }
        if self.scheduler is not None:
            checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
        
        # Save latest
        self._save_checkpoint(
            self.checkpoint_dir / 'latest.pth', checkpoint
        )
        
        # Save periodic checkpoint
        epoch_path = self.checkpoint_dir / f'epoch_{epoch:04d}.pth'
        self._save_checkpoint(epoch_path, checkpoint)
        self.checkpoint_files.append(epoch_path)
        
        # Remove old checkpoints
        while len(self.checkpoint_files) > self.max_checkpoints:
            old_path = self.checkpoint_files.pop(0)
            if old_path.exists():
                old_path.unlink()
        
        # Save best model
        if is_best:
            self._save_checkpoint(
                self.checkpoint_dir / 'best.pth', checkpoint
            )
        
        # Update history
        self.history.append({
            'epoch': epoch,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'val_metric': val_metric,
            'is_best': is_best,
            **(extra_metrics or {})
        })
        self._save_history()
        
        return {
            'is_best': is_best,
            'should_stop': self.patience_counter >= self.patience,
            'best_metric': self.best_metric,
            'best_epoch': self.best_epoch,
            'patience_counter': self.patience_counter,
        }
    
    def load_best(self):
        """Load the best model."""
        path = self.checkpoint_dir / 'best.pth'
        return self._load_checkpoint(path)
    
    def load_latest(self):
        """Load the latest checkpoint for resuming."""
        path = self.checkpoint_dir / 'latest.pth'
        return self._load_checkpoint(path)
    
    def _load_checkpoint(self, path: Path) -> dict:
        """Load checkpoint and restore state."""
        # Map tensors onto the device the model currently lives on
        checkpoint = torch.load(path, map_location=next(self.model.parameters()).device)
        
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        
        if self.scheduler is not None and 'scheduler_state_dict' in checkpoint:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        
        self.best_metric = checkpoint.get('best_metric', self.best_metric)
        self.best_epoch = checkpoint.get('best_epoch', self.best_epoch)
        self.patience_counter = checkpoint.get('patience_counter', self.patience_counter)
        
        return checkpoint

In [None]:
# Example usage
model = SimpleNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

ckpt = ProductionCheckpointManager(
    checkpoint_dir='./checkpoints/production_run',
    model=model,
    optimizer=optimizer,
    mode='max',  # Track accuracy (higher is better)
    patience=5,
)

print("Production checkpoint manager ready!")
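The cell above only constructs the manager. In a standalone training script, a common way to use a `latest.pth` file is to check for it at startup, so that simply rerunning the script after a crash resumes it. The following self-contained sketch shows that pattern with a toy model; `run_training` and its `stop_after` crash simulator are hypothetical. With `ProductionCheckpointManager`, the `if ckpt_path.exists()` branch would correspond to calling `load_latest()` and reading `epoch` from the returned dict.

```python
import tempfile
from pathlib import Path
from typing import List, Optional

import torch
import torch.nn as nn
import torch.optim as optim


def run_training(ckpt_path: Path, total_epochs: int,
                 stop_after: Optional[int] = None) -> List[int]:
    """Toy training script that resumes from `ckpt_path` if it exists.

    `stop_after` simulates a crash after that many epochs in this call.
    Returns the epoch indices executed by this call.
    """
    torch.manual_seed(0)
    model = nn.Linear(4, 1)                        # stand-in for SimpleNet
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    x, y = torch.randn(32, 4), torch.randn(32, 1)  # stand-in for a DataLoader

    start_epoch = 0
    if ckpt_path.exists():
        ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
        model.load_state_dict(ckpt['model_state_dict'])
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        start_epoch = ckpt['epoch'] + 1            # 'epoch' = last completed epoch

    ran = []
    for epoch in range(start_epoch, total_epochs):
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
        torch.save({'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()}, ckpt_path)
        ran.append(epoch)
        if stop_after is not None and len(ran) >= stop_after:
            break                                  # simulated interruption
    return ran


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / 'latest.pth'
    first = run_training(path, total_epochs=5, stop_after=2)  # "crashes" after 2 epochs
    second = run_training(path, total_epochs=5)               # rerun picks up at epoch 2
    third = run_training(path, total_epochs=5)                # already finished: no-op

print(first, second, third)
```

Storing the last *completed* epoch and resuming at `epoch + 1` avoids both repeating and skipping an epoch.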

---

## Checkpoint

You've learned:
- ✅ Basic model weight saving/loading
- ✅ Complete training state checkpointing
- ✅ Best model tracking
- ✅ Early stopping implementation
- ✅ Production-ready checkpoint management

---

## Common Mistakes

### Mistake 1: Not saving optimizer state
```python
# ❌ Wrong - loses momentum, adaptive LR state
torch.save(model.state_dict(), 'checkpoint.pth')

# ✅ Right - save complete state
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pth')
```

### Mistake 2: Forgetting map_location when loading
```python
# ❌ Wrong - raises a RuntimeError if saved on GPU and loaded on a machine without CUDA
checkpoint = torch.load('checkpoint.pth')

# ✅ Right - specify device
checkpoint = torch.load('checkpoint.pth', map_location=device)
```

---

## Further Reading

- [PyTorch Saving and Loading](https://pytorch.org/tutorials/beginner/saving_loading_models.html)
- [PyTorch Lightning Checkpointing](https://lightning.ai/docs/pytorch/stable/common/checkpointing.html)

In [None]:
# Cleanup
import shutil
import gc

# Remove checkpoint directories
if checkpoint_dir.exists():
    shutil.rmtree(checkpoint_dir)
    print("Removed checkpoint directory")

prod_dir = Path('./checkpoints/production_run')
if prod_dir.exists():
    shutil.rmtree(prod_dir)

torch.cuda.empty_cache()
gc.collect()

print("Cleanup complete!")