# Lab 2.1.6: Checkpointing System - SOLUTIONS

This notebook contains complete solutions and extended implementations for the Checkpointing System exercises.

---

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

import os
import json
import shutil
from pathlib import Path
from datetime import datetime
from typing import Dict, Any, Optional, List
from dataclasses import dataclass, field

# Report the runtime environment so the notebook output records what it ran on.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# `device` is a module-level name that the later cells read directly.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

---

## Setup: Model and Data

In [None]:
class SimpleNet(nn.Module):
    """Small two-stage CNN with an MLP head, sized for CIFAR-10 images.

    Each convolutional stage is Conv(3x3, padding=1) -> BatchNorm -> ReLU ->
    2x2 max-pool. The head's first linear layer expects ``64 * 8 * 8``
    features per sample, which is what two 2x poolings of a 32x32 input give.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        def conv_stage(in_channels, out_channels):
            # One stage: convolution, normalisation, activation, 2x down-sampling.
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
            ]

        # The stages are unpacked into a single Sequential so the sub-module
        # indices (0-7), and therefore the state_dict keys, stay flat.
        self.features = nn.Sequential(*conv_stage(3, 32), *conv_stage(32, 64))

        hidden_units = 256
        head_layers = [
            nn.Linear(64 * 8 * 8, hidden_units),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(hidden_units, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for a batch of images."""
        feature_maps = self.features(x)
        # Collapse everything except the batch dimension.
        flattened = feature_maps.view(feature_maps.size(0), -1)
        return self.classifier(flattened)


# Data
# Convert images to tensors, then normalise each of the three channels with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# CIFAR-10 train and test splits, stored under ./data (download=True is passed
# for both, and both use the same transform).
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)

# Only the training loader shuffles; both use batches of 64 and 2 workers.
# `trainloader` and `testloader` are read by the later training cells.
trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testloader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

print(f"Train: {len(trainset)} samples, {len(trainloader)} batches")
print(f"Test: {len(testset)} samples, {len(testloader)} batches")

---

## Solution 1: Complete Checkpoint Functions

Production-ready save and load functions with all necessary state.

In [None]:
def save_checkpoint(
    path: str,
    model: nn.Module,
    optimizer: optim.Optimizer,
    epoch: int,
    train_loss: float,
    val_loss: float,
    val_accuracy: float,
    scheduler: Optional[Any] = None,
    best_val_loss: Optional[float] = None,
    best_val_accuracy: Optional[float] = None,
    config: Optional[Dict] = None,
    rng_state: bool = True,
) -> Path:
    """
    Save complete training checkpoint with atomic write.

    The checkpoint is first written to ``<name>.tmp`` next to the target and
    then moved over the target in a single replace, so an interrupted save
    never leaves a truncated file at ``path``.

    Args:
        path: File path to save checkpoint (parent directories are created)
        model: Model to save
        optimizer: Optimizer with state to save
        epoch: Current epoch number
        train_loss: Training loss this epoch
        val_loss: Validation loss this epoch
        val_accuracy: Validation accuracy this epoch
        scheduler: Learning rate scheduler (optional)
        best_val_loss: Best validation loss so far
        best_val_accuracy: Best validation accuracy so far
        config: Training configuration dict (optional)
        rng_state: Whether to save RNG states for exact reproducibility

    Returns:
        The checkpoint location as a ``Path``.
    """
    checkpoint = {
        # Model and optimizer state
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),

        # Training metrics
        'epoch': epoch,
        'train_loss': train_loss,
        'val_loss': val_loss,
        'val_accuracy': val_accuracy,

        # Best metrics
        'best_val_loss': best_val_loss,
        'best_val_accuracy': best_val_accuracy,

        # Metadata
        'timestamp': datetime.now().isoformat(),
        # torch.__version__ is a str subclass (TorchVersion); store a plain
        # str so the file holds only primitive types, which the restricted
        # (weights_only) unpickler may otherwise refuse to load.
        'pytorch_version': str(torch.__version__),
    }

    # Scheduler state
    if scheduler is not None:
        checkpoint['scheduler_state_dict'] = scheduler.state_dict()

    # Configuration
    if config is not None:
        checkpoint['config'] = config

    # RNG states for exact reproducibility
    if rng_state:
        checkpoint['rng_state'] = {
            'python': None,  # random.getstate() if needed
            'numpy': None,   # np.random.get_state() if needed
            'torch': torch.get_rng_state(),
            'cuda': torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
        }

    # Atomic save: write to temp file, then replace the target.
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    # Append ".tmp" instead of swapping the suffix so that "a.pth" and "a.pt"
    # in the same directory do not share a temp file.
    temp_path = path.with_name(path.name + '.tmp')

    try:
        torch.save(checkpoint, temp_path)
        # replace() overwrites an existing target on every platform, whereas
        # rename() raises on Windows when the destination already exists.
        temp_path.replace(path)
    finally:
        # Only still present if the save or the replace failed.
        if temp_path.exists():
            temp_path.unlink()

    return path


def load_checkpoint(
    path: str,
    model: nn.Module,
    optimizer: Optional[optim.Optimizer] = None,
    scheduler: Optional[Any] = None,
    device: Optional[torch.device] = None,
    strict: bool = True,
    restore_rng: bool = False,
) -> Dict[str, Any]:
    """
    Load training checkpoint and restore state.

    Args:
        path: Path to checkpoint file
        model: Model to load weights into
        optimizer: Optimizer to restore state (optional)
        scheduler: Scheduler to restore state (optional)
        device: Device to map tensors to
        strict: Whether to strictly enforce state_dict keys match
        restore_rng: Whether to restore RNG states

    Returns:
        Checkpoint dictionary with all saved info

    Raises:
        KeyError: If the file lacks 'model_state_dict', 'epoch', 'val_loss'
            or 'val_accuracy'.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    checkpoint = torch.load(path, map_location=device)

    # Restore model
    model.load_state_dict(checkpoint['model_state_dict'], strict=strict)

    # Restore optimizer
    if optimizer is not None and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Move optimizer state to correct device
        # NOTE(review): Optimizer.load_state_dict appears to cast state to the
        # parameters' device already, so this pass may be redundant - confirm.
        if device is not None:
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(device)

    # Restore scheduler
    if scheduler is not None and 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # Restore RNG states
    if restore_rng and 'rng_state' in checkpoint:
        rng = checkpoint['rng_state']
        # map_location relocates every tensor in the file, RNG states
        # included, but the RNG setters require CPU byte tensors. Bring them
        # back to the CPU before handing them over.
        if rng.get('torch') is not None:
            torch.set_rng_state(rng['torch'].cpu())
        if rng.get('cuda') is not None and torch.cuda.is_available():
            # Restore only as many generators as this machine has devices, so
            # a checkpoint from a larger machine still loads.
            cuda_states = [s.cpu() for s in rng['cuda']][:torch.cuda.device_count()]
            torch.cuda.set_rng_state_all(cuda_states)

    print(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
    print(f"  Val Loss: {checkpoint['val_loss']:.4f}")
    print(f"  Val Accuracy: {checkpoint['val_accuracy']:.2f}%")

    return checkpoint

In [None]:
# Test save/load functions
# `checkpoint_dir` is reused by later cells (multi-model example, cleanup).
checkpoint_dir = Path('./checkpoints_solution')
checkpoint_dir.mkdir(exist_ok=True)

model = SimpleNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# Simulate training: 3 short epochs of at most 10 batches each, just enough to
# give the optimizer and scheduler non-trivial state to save.
criterion = nn.CrossEntropyLoss()
for epoch in range(3):
    for i, (inputs, labels) in enumerate(trainloader):
        if i >= 10:
            break
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
    scheduler.step()

# Save (the metric values below are placeholders, not measured results)
save_checkpoint(
    path=checkpoint_dir / 'test_checkpoint.pth',
    model=model,
    optimizer=optimizer,
    epoch=3,
    train_loss=0.5,
    val_loss=0.6,
    val_accuracy=75.5,
    scheduler=scheduler,
    best_val_loss=0.55,
    best_val_accuracy=77.0,
    config={'lr': 0.001, 'batch_size': 64},
)
print("Checkpoint saved!")

# Load into fresh model (new optimizer and scheduler too, as after a restart)
new_model = SimpleNet().to(device)
new_optimizer = optim.Adam(new_model.parameters(), lr=0.001)
new_scheduler = optim.lr_scheduler.StepLR(new_optimizer, step_size=5, gamma=0.1)

ckpt = load_checkpoint(
    checkpoint_dir / 'test_checkpoint.pth',
    new_model,
    new_optimizer,
    new_scheduler,
    device,
)

# The config dict passed to save_checkpoint is returned with the checkpoint.
print(f"\nConfig: {ckpt.get('config')}")

---

## Solution 2: Enhanced Checkpoint Manager

Complete implementation with all features.

In [None]:
@dataclass
class CheckpointConfig:
    """Configuration for checkpoint manager.

    Raises:
        ValueError: If ``mode`` is neither ``'min'`` nor ``'max'``.
    """
    checkpoint_dir: str  # Directory receiving all checkpoints and the manager state
    mode: str = 'min'  # 'min' for loss, 'max' for accuracy
    patience: int = 10  # Non-improving epochs before should_stop is reported
    min_delta: float = 0.0  # Minimum improvement to be considered better
    max_checkpoints: int = 5  # Number of periodic epoch_*.pth files to keep
    save_every: int = 1  # Save every N epochs (values <= 0 disable periodic saves)
    verbose: bool = True  # Print progress messages

    def __post_init__(self):
        # Any value other than 'min' used to be treated as 'max' silently, so
        # a typo would track the wrong metric without any warning.
        if self.mode not in ('min', 'max'):
            raise ValueError(f"mode must be 'min' or 'max', got {self.mode!r}")


class EnhancedCheckpointManager:
    """
    Enhanced checkpoint manager with comprehensive features.

    Features:
    - Best model tracking (by loss or accuracy)
    - Early stopping with configurable patience and min_delta
    - Checkpoint rotation (keep last N checkpoints)
    - Training history persistence
    - Resume capability
    - Atomic saves to prevent corruption

    Files written inside ``config.checkpoint_dir``:
    ``latest.pth`` (every step), ``epoch_NNNN.pth`` (periodic, rotated),
    ``best.pth`` (on improvement) and ``manager_state.json`` (bookkeeping).
    """

    def __init__(self, config: CheckpointConfig):
        self.config = config
        self.checkpoint_dir = Path(config.checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # State
        self.best_metric = float('inf') if config.mode == 'min' else float('-inf')
        self.best_epoch = 0
        self.patience_counter = 0
        self.history: List[Dict] = []
        self.checkpoint_files: List[Path] = []

        # Load existing state
        self._load_state()

    def _load_state(self):
        """Load existing training state if resuming."""
        state_path = self.checkpoint_dir / 'manager_state.json'
        if state_path.exists():
            with open(state_path, 'r') as f:
                state = json.load(f)
            self.best_metric = state['best_metric']
            self.best_epoch = state['best_epoch']
            self.patience_counter = state['patience_counter']
            self.history = state.get('history', [])
            self.checkpoint_files = [Path(p) for p in state.get('checkpoint_files', [])]
            if self.config.verbose:
                print(f"Resumed checkpoint manager: best={self.best_metric:.4f} at epoch {self.best_epoch}")

    def _save_state(self):
        """Save manager state (temp file + replace, so it is never half-written)."""
        state = {
            'best_metric': self.best_metric,
            'best_epoch': self.best_epoch,
            'patience_counter': self.patience_counter,
            'history': self.history,
            'checkpoint_files': [str(p) for p in self.checkpoint_files],
        }
        state_path = self.checkpoint_dir / 'manager_state.json'
        temp_path = state_path.with_name(state_path.name + '.tmp')
        try:
            with open(temp_path, 'w') as f:
                json.dump(state, f, indent=2)
            temp_path.replace(state_path)
        finally:
            # Only still present if writing or replacing failed.
            if temp_path.exists():
                temp_path.unlink()

    def _is_better(self, metric: float) -> bool:
        """Check if metric is better than best (with min_delta threshold)."""
        # bool() keeps the result JSON-serialisable for non-builtin float types.
        if self.config.mode == 'min':
            return bool(metric < (self.best_metric - self.config.min_delta))
        return bool(metric > (self.best_metric + self.config.min_delta))

    def _atomic_save(self, path: Path, data: dict):
        """Save with atomic write to prevent corruption."""
        # Append ".tmp" rather than swapping the suffix, so files that differ
        # only by suffix never share a temp file.
        temp_path = path.with_name(path.name + '.tmp')
        try:
            torch.save(data, temp_path)
            # replace() overwrites an existing target on every platform;
            # rename() raises on Windows when the destination exists.
            temp_path.replace(path)
        finally:
            if temp_path.exists():
                temp_path.unlink()

    def step(
        self,
        epoch: int,
        model: nn.Module,
        optimizer: optim.Optimizer,
        train_loss: float,
        val_loss: float,
        val_accuracy: float,
        scheduler: Optional[Any] = None,
        extra_metrics: Optional[Dict] = None,
    ) -> Dict[str, Any]:
        """
        Process end of epoch: save checkpoints and check early stopping.

        Args:
            epoch: Current epoch number
            model: Model to checkpoint
            optimizer: Optimizer to checkpoint
            train_loss: Training loss this epoch
            val_loss: Validation loss this epoch
            val_accuracy: Validation accuracy this epoch
            scheduler: Learning rate scheduler (optional)
            extra_metrics: Additional metrics to log (optional)

        Returns:
            Dict with status: is_best, should_stop, etc.
        """
        # Determine tracked metric
        tracked_metric = val_loss if self.config.mode == 'min' else val_accuracy
        is_best = self._is_better(tracked_metric)

        # Update the best/patience bookkeeping BEFORE building the checkpoint
        # so that the saved 'best_metric'/'best_epoch' describe this epoch when
        # it is the new best (previously best.pth recorded the old best).
        if is_best:
            self.best_metric = tracked_metric
            self.best_epoch = epoch
            self.patience_counter = 0
        else:
            self.patience_counter += 1

        # Build checkpoint data
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'val_accuracy': val_accuracy,
            'best_metric': self.best_metric,
            'best_epoch': self.best_epoch,
            'timestamp': datetime.now().isoformat(),
        }
        if scheduler is not None:
            checkpoint['scheduler_state_dict'] = scheduler.state_dict()
        if extra_metrics:
            checkpoint['extra_metrics'] = extra_metrics

        # Save latest checkpoint (always)
        latest_path = self.checkpoint_dir / 'latest.pth'
        self._atomic_save(latest_path, checkpoint)

        # Save periodic checkpoint (guard against save_every <= 0)
        if self.config.save_every > 0 and epoch % self.config.save_every == 0:
            epoch_path = self.checkpoint_dir / f'epoch_{epoch:04d}.pth'
            self._atomic_save(epoch_path, checkpoint)

            # An epoch can be saved twice (e.g. after resuming). Track it once,
            # as the newest entry, so rotation cannot delete the fresh file.
            if epoch_path in self.checkpoint_files:
                self.checkpoint_files.remove(epoch_path)
            self.checkpoint_files.append(epoch_path)

            # Remove old checkpoints
            while len(self.checkpoint_files) > self.config.max_checkpoints:
                old_path = self.checkpoint_files.pop(0)
                if old_path.exists():
                    old_path.unlink()

        # Handle best model
        if is_best:
            best_path = self.checkpoint_dir / 'best.pth'
            self._atomic_save(best_path, checkpoint)

            if self.config.verbose:
                print(f"  -> New best model! {self.config.mode}={tracked_metric:.4f}")

        # Update history
        history_entry = {
            'epoch': epoch,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'val_accuracy': val_accuracy,
            'is_best': is_best,
            'lr': optimizer.param_groups[0]['lr'],
        }
        if extra_metrics:
            history_entry.update(extra_metrics)
        self.history.append(history_entry)

        # Persist state
        self._save_state()

        return {
            'is_best': is_best,
            'should_stop': self.patience_counter >= self.config.patience,
            'best_metric': self.best_metric,
            'best_epoch': self.best_epoch,
            'patience_counter': self.patience_counter,
            'patience_remaining': self.config.patience - self.patience_counter,
        }

    def load_best(
        self,
        model: nn.Module,
        optimizer: Optional[optim.Optimizer] = None,
        scheduler: Optional[Any] = None,
        device: Optional[torch.device] = None,
    ) -> Dict:
        """Load the best checkpoint."""
        return self._load(self.checkpoint_dir / 'best.pth', model, optimizer, scheduler, device)

    def load_latest(
        self,
        model: nn.Module,
        optimizer: Optional[optim.Optimizer] = None,
        scheduler: Optional[Any] = None,
        device: Optional[torch.device] = None,
    ) -> Dict:
        """Load the latest checkpoint for resuming training."""
        return self._load(self.checkpoint_dir / 'latest.pth', model, optimizer, scheduler, device)

    def _load(
        self,
        path: Path,
        model: nn.Module,
        optimizer: Optional[optim.Optimizer],
        scheduler: Optional[Any],
        device: Optional[torch.device],
    ) -> Dict:
        """Load a checkpoint file.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        if not path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {path}")

        checkpoint = torch.load(path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])

        # Explicit None checks: "was one supplied", not "is it truthy".
        if optimizer is not None and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if scheduler is not None and 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        if self.config.verbose:
            print(f"Loaded checkpoint: epoch {checkpoint['epoch']}, "
                  f"val_loss={checkpoint['val_loss']:.4f}")

        return checkpoint

    def get_history_df(self):
        """Get training history as pandas DataFrame (if pandas available)."""
        try:
            import pandas as pd
            return pd.DataFrame(self.history)
        except ImportError:
            # Without pandas, fall back to the raw list of per-epoch dicts.
            return self.history

    def plot_history(self):
        """Plot training history."""
        try:
            import matplotlib.pyplot as plt

            fig, axes = plt.subplots(1, 2, figsize=(12, 4))

            epochs = [h['epoch'] for h in self.history]

            # Loss
            axes[0].plot(epochs, [h['train_loss'] for h in self.history], label='Train')
            axes[0].plot(epochs, [h['val_loss'] for h in self.history], label='Val')
            axes[0].axvline(self.best_epoch, color='g', linestyle='--', label='Best')
            axes[0].set_xlabel('Epoch')
            axes[0].set_ylabel('Loss')
            axes[0].legend()
            axes[0].set_title('Training Loss')

            # Accuracy
            axes[1].plot(epochs, [h['val_accuracy'] for h in self.history], label='Val')
            axes[1].axvline(self.best_epoch, color='g', linestyle='--', label='Best')
            axes[1].set_xlabel('Epoch')
            axes[1].set_ylabel('Accuracy (%)')
            # The labels above are only shown once a legend is requested.
            axes[1].legend()
            axes[1].set_title('Validation Accuracy')

            plt.tight_layout()
            plt.show()
        except ImportError:
            print("matplotlib not available for plotting")

In [None]:
# Test the enhanced checkpoint manager
def evaluate(model, dataloader, criterion, device=None):
    """Evaluate model on dataloader.

    Args:
        model: Module to evaluate; it is switched to eval mode and left there.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar tensor.
        device: Device the batches are moved to. When omitted, the device of
            the model's first parameter is used, so the function no longer
            depends on a notebook-level ``device`` variable. A model without
            parameters leaves the batches where they are.

    Returns:
        Tuple ``(mean_batch_loss, accuracy_percent)``: the unweighted mean of
        the per-batch losses, and the percentage of correct predictions.

    Raises:
        ValueError: If the dataloader yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else None

    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            if device is not None:
                inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            # Predicted class = index of the largest logit.
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            num_batches += 1

    if num_batches == 0:
        # Previously surfaced as an unexplained ZeroDivisionError.
        raise ValueError("evaluate() received a dataloader with no batches")

    return total_loss / num_batches, 100 * correct / total


# Setup
model = SimpleNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5)
criterion = nn.CrossEntropyLoss()

# Start from an empty run directory. The manager reloads manager_state.json
# when it finds one, so leftovers from an earlier or interrupted run would
# make this freshly initialised model compete against a stale best metric and
# patience counter.
enhanced_run_dir = Path('./checkpoints_solution/enhanced_run')
if enhanced_run_dir.exists():
    shutil.rmtree(enhanced_run_dir)

# Create checkpoint manager
ckpt_config = CheckpointConfig(
    checkpoint_dir=str(enhanced_run_dir),
    mode='min',  # Track validation loss
    patience=5,
    min_delta=0.001,  # Require at least 0.001 improvement
    max_checkpoints=3,
)
ckpt_manager = EnhancedCheckpointManager(ckpt_config)

# Training loop
print("=== Training with Enhanced Checkpoint Manager ===")
NUM_EPOCHS = 10

for epoch in range(NUM_EPOCHS):
    # Train
    model.train()
    train_loss = 0
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(trainloader)

    # Evaluate
    val_loss, val_acc = evaluate(model, testloader, criterion)
    # ReduceLROnPlateau is driven by the monitored metric.
    scheduler.step(val_loss)

    # Checkpoint (epochs are stored 0-indexed)
    status = ckpt_manager.step(
        epoch=epoch,
        model=model,
        optimizer=optimizer,
        train_loss=train_loss,
        val_loss=val_loss,
        val_accuracy=val_acc,
        scheduler=scheduler,
    )

    best_marker = " *" if status['is_best'] else ""
    print(f"Epoch {epoch+1:2d} | Train: {train_loss:.4f} | "
          f"Val: {val_loss:.4f} | Acc: {val_acc:.2f}% | "
          f"Patience: {status['patience_remaining']}{best_marker}")

    if status['should_stop']:
        print(f"\nEarly stopping triggered at epoch {epoch+1}!")
        break

# best_epoch is 0-indexed; print it 1-indexed to match the progress lines above.
print(f"\nBest model: epoch {ckpt_manager.best_epoch + 1} with val_loss={ckpt_manager.best_metric:.4f}")

---

## Solution 3: Training Resume Demo

Demonstrate complete training resume capability.

In [None]:
def resumable_training(
    model: nn.Module,
    optimizer: optim.Optimizer,
    trainloader: DataLoader,
    testloader: DataLoader,
    criterion: nn.Module,
    num_epochs: int,
    checkpoint_dir: str,
    scheduler: Optional[Any] = None,
    resume: bool = True,
):
    """
    Training loop with full resume capability.

    A checkpoint is written after every completed epoch, so an interrupted
    run restarts at the first epoch that was not finished. Model, optimizer,
    scheduler, best validation loss and the per-epoch history are restored;
    RNG state is not saved, so data order after a resume can differ from an
    uninterrupted run.

    Relies on the notebook-level ``device`` and ``evaluate`` names.

    Args:
        model: Model to train (already on the target device).
        optimizer: Optimizer bound to ``model``'s parameters.
        trainloader: Loader for the training batches.
        testloader: Loader passed to ``evaluate`` after every epoch.
        criterion: Loss function.
        num_epochs: Total number of epochs, counting those already done.
        checkpoint_dir: Directory holding ``latest.pth`` and ``best.pth``.
        scheduler: Optional LR scheduler; ReduceLROnPlateau is stepped with
            the validation loss, any other scheduler without arguments.
        resume: Whether to continue from ``latest.pth`` when it exists.

    Returns:
        List of per-epoch metric dicts (including epochs restored on resume).
    """
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    def _atomic_torch_save(obj, target: Path) -> None:
        # Write next to the target, then swap it in, so a crash mid-write
        # cannot corrupt the only checkpoint we could resume from.
        temp_target = target.with_name(target.name + '.tmp')
        try:
            torch.save(obj, temp_target)
            temp_target.replace(target)
        finally:
            if temp_target.exists():
                temp_target.unlink()

    start_epoch = 0
    best_val_loss = float('inf')
    history = []

    # Try to resume
    latest_path = checkpoint_dir / 'latest.pth'
    if resume and latest_path.exists():
        print("Resuming from checkpoint...")
        checkpoint = torch.load(latest_path, map_location=device)

        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if scheduler is not None and 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        # The stored epoch finished completely, so continue with the next one.
        start_epoch = checkpoint['epoch'] + 1
        best_val_loss = checkpoint.get('best_val_loss', float('inf'))
        history = checkpoint.get('history', [])

        print(f"Resumed from epoch {checkpoint['epoch']}, "
              f"best_val_loss={best_val_loss:.4f}")

    # Training loop
    for epoch in range(start_epoch, num_epochs):
        # Train
        model.train()
        train_loss = 0
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(trainloader)

        # Evaluate
        val_loss, val_acc = evaluate(model, testloader, criterion)

        if scheduler is not None:
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(val_loss)
            else:
                scheduler.step()

        # Track best
        is_best = val_loss < best_val_loss
        if is_best:
            best_val_loss = val_loss

        # Save history
        history.append({
            'epoch': epoch,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'val_accuracy': val_acc,
        })

        # Save checkpoint
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'val_accuracy': val_acc,
            'best_val_loss': best_val_loss,
            'history': history,
        }
        if scheduler is not None:
            checkpoint['scheduler_state_dict'] = scheduler.state_dict()

        _atomic_torch_save(checkpoint, latest_path)

        if is_best:
            _atomic_torch_save(checkpoint, checkpoint_dir / 'best.pth')

        status = " (best)" if is_best else ""
        print(f"Epoch {epoch+1:2d}/{num_epochs} | "
              f"Train: {train_loss:.4f} | Val: {val_loss:.4f} | "
              f"Acc: {val_acc:.2f}%{status}")

    return history

In [None]:
# Demo: Train for 3 epochs, then "interrupt" and resume
# Uses `criterion` and `evaluate` defined in earlier cells.
print("=== First Training Run (3 epochs) ===")
model = SimpleNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Clear previous checkpoints so the first run really starts from scratch
resume_dir = Path('./checkpoints_solution/resume_demo')
if resume_dir.exists():
    shutil.rmtree(resume_dir)

# First run: resume=False, stops after 3 epochs and leaves latest.pth behind.
history1 = resumable_training(
    model, optimizer, trainloader, testloader, criterion,
    num_epochs=3,
    checkpoint_dir=resume_dir,
    resume=False,
)

print("\n=== Simulating interrupt... ===")
print("Creating new model and optimizer (simulating restart)...")

# Create fresh model/optimizer (simulating program restart)
model = SimpleNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

print("\n=== Resuming Training (3 more epochs) ===")
# num_epochs is the overall target, so 6 here means epochs 4-6 are run.
history2 = resumable_training(
    model, optimizer, trainloader, testloader, criterion,
    num_epochs=6,  # Total epochs we want
    checkpoint_dir=resume_dir,
    resume=True,  # This will load the checkpoint
)

# history2 also contains the entries restored from the checkpoint.
print(f"\nTotal epochs trained: {len(history2)}")

---

## Solution 4: Multi-Model Checkpoint

Save multiple models (e.g., for GANs, ensemble, or distillation).

In [None]:
def save_multi_model_checkpoint(
    path: str,
    models: Dict[str, nn.Module],
    optimizers: Dict[str, optim.Optimizer],
    epoch: int,
    metrics: Dict[str, float],
    schedulers: Optional[Dict[str, Any]] = None,
):
    """
    Save checkpoint with multiple models (for GANs, ensembles, etc.).

    The file is written atomically (temp file, then replace) and missing
    parent directories are created, matching ``save_checkpoint``.

    Args:
        path: Checkpoint file path
        models: Dict of model_name -> model
        optimizers: Dict of model_name -> optimizer
        epoch: Current epoch
        metrics: Training metrics
        schedulers: Dict of model_name -> scheduler (optional)
    """
    checkpoint = {
        'epoch': epoch,
        'metrics': metrics,
        'models': {name: model.state_dict() for name, model in models.items()},
        'optimizers': {name: opt.state_dict() for name, opt in optimizers.items()},
    }

    # The 'schedulers' key is only present when schedulers were supplied.
    if schedulers:
        checkpoint['schedulers'] = {
            name: sched.state_dict() for name, sched in schedulers.items()
        }

    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    temp_path = path.with_name(path.name + '.tmp')
    try:
        torch.save(checkpoint, temp_path)
        # replace() overwrites an existing checkpoint on every platform.
        temp_path.replace(path)
    finally:
        # Only still present if the save or the replace failed.
        if temp_path.exists():
            temp_path.unlink()

    print(f"Saved multi-model checkpoint: {list(models.keys())}")


def load_multi_model_checkpoint(
    path: str,
    models: Dict[str, nn.Module],
    optimizers: Optional[Dict[str, optim.Optimizer]] = None,
    schedulers: Optional[Dict[str, Any]] = None,
    device: Optional[torch.device] = None,
) -> Dict:
    """
    Restore several models (and optionally their optimizers/schedulers) from one file.

    Objects are matched to saved state by name; a name with no saved
    counterpart is skipped and keeps its current state.

    Args:
        path: Checkpoint file path
        models: Dict of model_name -> model to load weights into
        optimizers: Dict of model_name -> optimizer (optional)
        schedulers: Dict of model_name -> scheduler (optional)
        device: Device passed to ``torch.load`` as ``map_location``

    Returns:
        The loaded checkpoint dictionary.
    """
    saved = torch.load(path, map_location=device)

    def _restore_matching(targets, saved_states):
        # Load state only into the objects whose name was saved.
        for key, target in targets.items():
            if key not in saved_states:
                continue
            target.load_state_dict(saved_states[key])

    if models:
        _restore_matching(models, saved['models'])

    if optimizers:
        _restore_matching(optimizers, saved.get('optimizers', {}))

    if schedulers and 'schedulers' in saved:
        _restore_matching(schedulers, saved['schedulers'])

    print(f"Loaded multi-model checkpoint from epoch {saved['epoch']}")
    return saved


# Example: Save/load two models
# `checkpoint_dir` comes from the Solution 1 test cell.
model_a = SimpleNet().to(device)
model_b = SimpleNet().to(device)

# One optimizer per model, registered under the same name as its model.
opt_a = optim.Adam(model_a.parameters(), lr=0.001)
opt_b = optim.Adam(model_b.parameters(), lr=0.001)

# Save (epoch and metric values are placeholders; no training happens here)
save_multi_model_checkpoint(
    checkpoint_dir / 'multi_model.pth',
    models={'model_a': model_a, 'model_b': model_b},
    optimizers={'model_a': opt_a, 'model_b': opt_b},
    epoch=5,
    metrics={'loss_a': 0.5, 'loss_b': 0.6},
)

# Load into fresh models
new_a = SimpleNet().to(device)
new_b = SimpleNet().to(device)

# Only the model weights are restored: no optimizers or schedulers are passed.
load_multi_model_checkpoint(
    checkpoint_dir / 'multi_model.pth',
    models={'model_a': new_a, 'model_b': new_b},
    device=device,
)

---

## Summary: Checkpointing Best Practices

| Practice | Why It Matters |
|----------|----------------|
| Save optimizer state | Preserves momentum, Adam's m/v buffers |
| Save scheduler state | Maintains LR schedule position |
| Use map_location | Handles GPU/CPU portability |
| Atomic saves | Prevents corruption from crashes |
| Track best model | Always have best performing checkpoint |
| Save training history | Enables analysis and visualization |
| Checkpoint rotation | Manages disk space |
| Early stopping | Prevents overfitting, saves time |

In [None]:
# Cleanup
# Removes every checkpoint directory this notebook created and frees memory.
import gc

# `checkpoint_dir` and `resume_dir` are defined in earlier cells, so this cell
# needs those cells to have run.
if checkpoint_dir.exists():
    shutil.rmtree(checkpoint_dir)
if resume_dir.exists():
    shutil.rmtree(resume_dir)

# Clean up production run if exists
prod_run = Path('./checkpoints_solution/enhanced_run')
if prod_run.exists():
    shutil.rmtree(prod_run)

# Remove main solution dir
# This is the same path as `checkpoint_dir` and the parent of the two run
# directories above, so the exists() checks make repeated removal harmless.
sol_dir = Path('./checkpoints_solution')
if sol_dir.exists():
    shutil.rmtree(sol_dir)

# Ask PyTorch to release its cached GPU memory, then run Python's collector.
torch.cuda.empty_cache()
gc.collect()

print("Cleanup complete!")