# Checkpointing and Early Stopping

**File Location:** `notebooks/03_callbacks_and_checkpointing/06_checkpoint_earlystop.ipynb`

## Introduction

This notebook covers PyTorch Lightning's checkpointing and early stopping mechanisms. Learn to save model states, resume training, implement intelligent stopping criteria, and manage model versioning for robust ML workflows.

## Model Checkpointing Fundamentals

### Basic Checkpointing Setup

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader, TensorDataset
import tempfile
import os
from pathlib import Path

# Simple model for checkpointing demo
class CheckpointModel(pl.LightningModule):
    """Small MLP classifier used by the checkpointing and early-stopping demos.

    The module logs ``train_loss``/``train_acc`` and ``val_loss``/``val_acc``;
    these are the metric names the callbacks in this notebook monitor.

    Args:
        input_size: Number of input features.
        hidden_size: Width of the single hidden layer.
        num_classes: Number of output classes.
        learning_rate: Learning rate handed to Adam.
    """

    def __init__(self, input_size=20, hidden_size=64, num_classes=5, learning_rate=1e-3):
        super().__init__()
        # Makes the constructor arguments available as ``self.hparams``
        self.save_hyperparameters()

        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_size, num_classes)
        )

        # Metrics for monitoring (separate instances so train/val state never mixes)
        from torchmetrics import Accuracy
        self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        """Return the class logits for a batch of feature rows."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one training batch and log metrics."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        preds = torch.argmax(logits, dim=1)
        self.train_acc(preds, y)

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', self.train_acc, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one validation batch and log metrics."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        preds = torch.argmax(logits, dim=1)
        self.val_acc(preds, y)

        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_acc', self.val_acc, on_epoch=True, prog_bar=True)

        return loss

    def configure_optimizers(self):
        """Return Adam plus a ReduceLROnPlateau scheduler driven by ``val_loss``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        # Add scheduler for more interesting training dynamics.
        # ``verbose`` is intentionally not passed: it is deprecated in
        # PyTorch's LR schedulers (since 2.2).
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=3
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss',
                'frequency': 1,
                'interval': 'epoch'
            }
        }

# Create synthetic data
def create_training_data(num_samples=5000, input_size=20, num_classes=5, seed=42):
    """Generate a reproducible synthetic classification dataset.

    Features are standard-normal. Each label is obtained by projecting the
    features onto one random weight vector and cutting the observed range of
    that projection into ``num_classes`` equal-width bins, so the labels are
    a deterministic function of the features (no label noise is added).

    Args:
        num_samples: Number of rows to generate.
        input_size: Number of features per row.
        num_classes: Number of label bins (expected to be >= 1).
        seed: Value passed to ``torch.manual_seed``. Note that this reseeds
            the *global* torch RNG, which also affects later random calls.

    Returns:
        Tuple ``(X, y)`` where ``X`` is a float tensor of shape
        ``(num_samples, input_size)`` and ``y`` is an int64 tensor of shape
        ``(num_samples,)`` with values in ``[0, num_classes - 1]``.
    """
    torch.manual_seed(seed)
    X = torch.randn(num_samples, input_size)
    # Random linear projection of the features; binned below into classes
    weights = torch.randn(input_size)
    logits = X @ weights
    span = logits.max() - logits.min()
    if num_classes <= 1 or span == 0:
        # Degenerate case (one class, or all projections equal): the bin
        # width would be zero/undefined, so every sample gets class 0.
        return X, torch.zeros(num_samples, dtype=torch.long)
    y = torch.div(logits - logits.min(), span / (num_classes - 1), rounding_mode='floor').long()
    # Guard against floating-point spill-over at the upper edge
    y = torch.clamp(y, 0, num_classes - 1)
    return X, y

# Wrap the synthetic tensors in a dataset
X, y = create_training_data()
dataset = TensorDataset(X, y)

# 80/20 train/validation split
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, lengths=[train_size, val_size]
)

# Same batch size for both loaders; only the training loader shuffles
train_loader, val_loader = (
    DataLoader(split, batch_size=64, shuffle=reshuffle)
    for split, reshuffle in ((train_dataset, True), (val_dataset, False))
)

print("✓ Basic setup completed")
```

### Standard Checkpointing Configuration

```python
# Scratch directory that receives every checkpoint written by this notebook
checkpoint_dir = Path(tempfile.mkdtemp(prefix="lightning_checkpoints_"))
print(f"Checkpoint directory: {checkpoint_dir}")

# Keep a single file: the model with the lowest validation loss so far
basic_checkpoint = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=1,  # Keep only the best model
    dirpath=checkpoint_dir / "basic",
    filename="best-model-{epoch:02d}-{val_loss:.2f}",
    verbose=True,
)

# Fresh model and logger for this run
model = CheckpointModel()
logger = TensorBoardLogger(checkpoint_dir / "logs", name="basic_checkpoint")

trainer = pl.Trainer(
    callbacks=[basic_checkpoint],
    logger=logger,
    max_epochs=10,
    enable_progress_bar=True,
)

print("Training with basic checkpointing...")
trainer.fit(model, train_loader, val_loader)

# Report what ended up on disk
checkpoint_files = [*(checkpoint_dir / "basic").glob("*.ckpt")]
print(f"✓ Training completed. Saved checkpoints: {len(checkpoint_files)}")
for ckpt in checkpoint_files:
    print(f"  {ckpt.name}")
```

### Advanced Checkpointing Strategies

```python
# Multiple checkpoint callbacks for different purposes
callbacks = [
    # Best model based on validation loss
    ModelCheckpoint(
        dirpath=checkpoint_dir / "best_loss",
        filename="best-loss-{epoch:02d}-{val_loss:.3f}",
        monitor="val_loss",
        mode="min",
        save_top_k=1,
        verbose=True
    ),

    # Best model based on validation accuracy
    ModelCheckpoint(
        dirpath=checkpoint_dir / "best_acc",
        filename="best-acc-{epoch:02d}-{val_acc:.3f}",
        monitor="val_acc",
        mode="max",
        save_top_k=1,
        verbose=True
    ),

    # Keep top 3 models based on validation loss
    ModelCheckpoint(
        dirpath=checkpoint_dir / "top3",
        filename="top3-{epoch:02d}-{val_loss:.3f}",
        monitor="val_loss",
        mode="min",
        save_top_k=3,
        verbose=False
    ),

    # Regular epoch checkpoints (every 5 epochs)
    ModelCheckpoint(
        dirpath=checkpoint_dir / "periodic",
        filename="epoch-{epoch:02d}",
        every_n_epochs=5,
        save_top_k=-1,  # Keep all
        verbose=False
    ),

    # Last checkpoint (always keep the most recent)
    ModelCheckpoint(
        dirpath=checkpoint_dir / "last",
        filename="last-model",
        save_last=True,
        verbose=False
    )
]

# Train with multiple checkpointing strategies
model = CheckpointModel(learning_rate=2e-3)  # Slightly different config
logger = TensorBoardLogger(checkpoint_dir / "logs", name="advanced_checkpoint")

trainer = pl.Trainer(
    max_epochs=15,
    callbacks=callbacks,
    logger=logger,
    enable_progress_bar=False  # Reduce output for demo
)

print("Training with advanced checkpointing strategies...")
trainer.fit(model, train_loader, val_loader)

# Analyze saved checkpoints
print("\n=== Checkpoint Analysis ===")
for callback in callbacks:
    if hasattr(callback, 'dirpath'):
        checkpoint_files = list(Path(callback.dirpath).glob("*.ckpt"))
        print(f"{Path(callback.dirpath).name}: {len(checkpoint_files)} checkpoints")

        # A checkpoint file has no top-level 'val_loss' entry; the monitored
        # score of each retained file lives on the callback that wrote it
        # (``best_k_models`` maps file path -> score). Index it by file name.
        recorded_scores = {
            Path(saved_path).name: score
            for saved_path, score in getattr(callback, 'best_k_models', {}).items()
        }

        for ckpt in sorted(checkpoint_files):
            # Load checkpoint to inspect metadata
            checkpoint = torch.load(ckpt, map_location='cpu')
            epoch = checkpoint.get('epoch', 'unknown')
            if ckpt.name in recorded_scores:
                # Label with the metric this callback actually monitors
                score_text = f"{callback.monitor}: {float(recorded_scores[ckpt.name]):.4f}"
            else:
                # Unmonitored callbacks (periodic / last) record no score
                score_text = "no monitored score"
            print(f"  {ckpt.name} (epoch: {epoch}, {score_text})")

print("✓ Advanced checkpointing completed")
```

## Early Stopping Implementation

### Basic Early Stopping

```python
# Simple early stopping callback
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.001,  # Minimum change to qualify as an improvement
    patience=5,       # Number of epochs to wait for improvement
    verbose=True,
    mode="min"
)

# Model that might benefit from early stopping
model = CheckpointModel(learning_rate=5e-4)  # Lower LR for more stable training

# Combined checkpointing and early stopping
combined_callbacks = [
    ModelCheckpoint(
        dirpath=checkpoint_dir / "early_stop",
        filename="early-stop-{epoch:02d}-{val_loss:.3f}",
        monitor="val_loss",
        mode="min",
        save_top_k=1,
        verbose=True
    ),
    early_stop_callback
]

logger = TensorBoardLogger(checkpoint_dir / "logs", name="early_stopping")

trainer = pl.Trainer(
    max_epochs=50,  # High max epochs, but early stopping will prevent overfitting
    callbacks=combined_callbacks,
    logger=logger,
    enable_progress_bar=True
)

print("Training with early stopping...")
trainer.fit(model, train_loader, val_loader)

print(f"Training stopped at epoch: {trainer.current_epoch}")
# ``stopped_epoch`` is an epoch index, not a reason; it stays 0 when the
# callback never requested a stop, so spell out which case occurred.
if early_stop_callback.stopped_epoch > 0:
    print(f"Reason: early stopping triggered at epoch {early_stop_callback.stopped_epoch}")
else:
    print("Reason: max_epochs reached (early stopping was not triggered)")
print("✓ Early stopping demo completed")
```

### Advanced Early Stopping Strategies

```python
# Two independent stopping criteria plus one best-model checkpoint
advanced_callbacks = [
    # Criterion 1: validation loss plateaus
    EarlyStopping(
        monitor="val_loss",
        mode="min",
        patience=7,
        min_delta=0.001,
        check_finite=True,  # Stop if metric becomes NaN/Inf
        verbose=True,
    ),

    # Criterion 2: validation accuracy plateaus (more patient)
    EarlyStopping(
        monitor="val_acc",
        mode="max",
        patience=10,
        min_delta=0.005,
        check_finite=True,
        verbose=True,
    ),

    # Keep the checkpoint with the lowest validation loss
    ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        save_top_k=1,
        dirpath=checkpoint_dir / "advanced_early_stop",
        filename="best-{epoch:02d}-{val_loss:.3f}-{val_acc:.3f}",
        verbose=True,
    ),
]

# Higher-capacity variant that overfits more readily
class OverfittingModel(CheckpointModel):
    """``CheckpointModel`` with its network swapped for a deeper, wider MLP.

    Accepts the same keyword arguments as ``CheckpointModel`` and forwards
    them unchanged; only ``self.model`` is replaced after the parent setup.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        in_features = self.hparams.input_size
        out_features = self.hparams.num_classes
        # Two hidden layers and lighter dropout than the parent network
        layers = [
            nn.Linear(in_features, 128),
            nn.ReLU(),
            nn.Dropout(0.1),  # Less dropout = more overfitting potential
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, out_features),
        ]
        self.model = nn.Sequential(*layers)

model = OverfittingModel(learning_rate=1e-3)
logger = TensorBoardLogger(checkpoint_dir / "logs", name="advanced_early_stopping")

trainer = pl.Trainer(
    max_epochs=100,
    callbacks=advanced_callbacks,
    logger=logger,
    enable_progress_bar=False
)

print("Training with advanced early stopping...")
trainer.fit(model, train_loader, val_loader)

print(f"Final training epoch: {trainer.current_epoch}")
for callback in advanced_callbacks:
    if isinstance(callback, EarlyStopping):
        # ``stopped_epoch`` stays 0 for a criterion that never fired (e.g.
        # because the other criterion ended training first), so report that
        # case explicitly instead of printing "stopped at epoch 0".
        if callback.stopped_epoch > 0:
            print(f"Early stopping ({callback.monitor}): stopped at epoch {callback.stopped_epoch}")
        else:
            print(f"Early stopping ({callback.monitor}): not triggered")

print("✓ Advanced early stopping demo completed")
```

## Model Loading and Resuming Training

### Loading from Checkpoints

```python
# Function to find the best checkpoint
def _recorded_val_loss(checkpoint, ckpt_name):
    """Return the validation loss recorded for a loaded checkpoint, or ``inf``.

    Looks first for a top-level ``'val_loss'`` entry (hand-written
    checkpoints), then in the saved callback state: an entry whose
    ``'monitor'`` is ``'val_loss'`` and whose ``'best_k_models'`` mapping
    contains a path with the same file name as ``ckpt_name``.
    """
    if 'val_loss' in checkpoint:
        return float(checkpoint['val_loss'])

    # NOTE(review): relies on the layout Lightning's ModelCheckpoint uses for
    # its saved state ('monitor', 'best_k_models') - confirm for your version.
    for state in (checkpoint.get('callbacks') or {}).values():
        if not isinstance(state, dict) or state.get('monitor') != 'val_loss':
            continue
        for saved_path, score in (state.get('best_k_models') or {}).items():
            # Compare file names only: the directory may have been moved
            if Path(saved_path).name == ckpt_name:
                return float(score)
    return float('inf')


def find_best_checkpoint(checkpoint_dir):
    """Find the checkpoint with the lowest recorded validation loss.

    Args:
        checkpoint_dir: Directory (``str`` or ``Path``) searched recursively
            for ``*.ckpt`` files.

    Returns:
        Path of the best checkpoint, or ``None`` when no file under
        ``checkpoint_dir`` has a recorded (finite) validation loss.
        Unreadable files are reported and skipped.
    """
    # Sorted so that ties are resolved the same way on every run
    checkpoint_files = sorted(Path(checkpoint_dir).glob("**/*.ckpt"))
    if not checkpoint_files:
        return None

    best_loss = float('inf')
    best_checkpoint = None

    for ckpt_path in checkpoint_files:
        try:
            checkpoint = torch.load(ckpt_path, map_location='cpu')
            val_loss = _recorded_val_loss(checkpoint, ckpt_path.name)
        except Exception as exc:
            # Best-effort scan: one corrupt file must not abort the search,
            # but say so instead of hiding it.
            print(f"Skipping unreadable checkpoint {ckpt_path}: {exc}")
            continue
        if val_loss < best_loss:
            best_loss = val_loss
            best_checkpoint = ckpt_path

    return best_checkpoint

# Find and load best checkpoint.
# Search only the "early_stop" run: its checkpoints were written by a plain
# CheckpointModel. The checkpoint root also holds OverfittingModel
# checkpoints (different layer shapes), which
# CheckpointModel.load_from_checkpoint cannot load.
best_checkpoint_path = find_best_checkpoint(checkpoint_dir / "early_stop")
print(f"Best checkpoint: {best_checkpoint_path}")

if best_checkpoint_path:
    # Method 1: Load model from checkpoint
    loaded_model = CheckpointModel.load_from_checkpoint(best_checkpoint_path)
    print(f"✓ Loaded model from checkpoint")
    print(f"Model hyperparameters: {loaded_model.hparams}")

    # Method 2: Resume training from checkpoint
    print("\nResuming training from checkpoint...")

    # Create new trainer for resumed training
    resume_callbacks = [
        ModelCheckpoint(
            dirpath=checkpoint_dir / "resumed",
            filename="resumed-{epoch:02d}-{val_loss:.3f}",
            monitor="val_loss",
            mode="min",
            save_top_k=3,
            verbose=True
        ),
        EarlyStopping(
            monitor="val_loss",
            patience=5,
            verbose=True,
            mode="min"
        )
    ]

    # NOTE(review): max_epochs is an absolute epoch count, not "extra"
    # epochs - if the checkpoint was written at or after epoch 25 there is
    # presumably nothing left for fit() to run. Verify against the run above.
    resume_trainer = pl.Trainer(
        max_epochs=25,  # Continue training up to this epoch
        callbacks=resume_callbacks,
        logger=TensorBoardLogger(checkpoint_dir / "logs", name="resumed_training"),
        enable_progress_bar=False
    )

    # Resume from checkpoint
    resume_trainer.fit(
        loaded_model,
        train_loader,
        val_loader,
        ckpt_path=best_checkpoint_path
    )

    print("✓ Resumed training completed")
```

### Checkpoint Inspection and Analysis

```python
def analyze_checkpoint(checkpoint_path):
    """Print a summary of one checkpoint file.

    Reports epoch, global step, recorded validation loss/accuracy,
    hyperparameters, and the number of model / optimizer / scheduler state
    entries. Missing fields are printed as ``N/A``.

    Args:
        checkpoint_path: Path (``str`` or ``Path``) of the ``.ckpt`` file.
    """
    checkpoint_path = Path(checkpoint_path)
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    def recorded_metric(metric_name):
        """Return the score recorded for this file under ``metric_name``."""
        # Hand-written checkpoints may store the metric at the top level
        if metric_name in checkpoint:
            return checkpoint[metric_name]
        # NOTE(review): relies on the layout Lightning's ModelCheckpoint uses
        # for its saved state ('monitor', 'best_k_models') - confirm for your
        # version. Only a score recorded for *this* file name is reported.
        for state in (checkpoint.get('callbacks') or {}).values():
            if not isinstance(state, dict) or state.get('monitor') != metric_name:
                continue
            for saved_path, score in (state.get('best_k_models') or {}).items():
                if Path(saved_path).name == checkpoint_path.name:
                    return float(score)
        return 'N/A'

    print(f"=== Checkpoint Analysis: {checkpoint_path.name} ===")
    print(f"Epoch: {checkpoint.get('epoch', 'N/A')}")
    print(f"Global step: {checkpoint.get('global_step', 'N/A')}")
    print(f"Validation loss: {recorded_metric('val_loss')}")
    print(f"Validation accuracy: {recorded_metric('val_acc')}")

    # Hyperparameters
    if 'hyper_parameters' in checkpoint:
        print("Hyperparameters:")
        for key, value in checkpoint['hyper_parameters'].items():
            print(f"  {key}: {value}")

    # Model state
    state_dict = checkpoint.get('state_dict', {})
    print(f"Model parameters: {len(state_dict)} tensors")

    # Optimizer state
    if 'optimizer_states' in checkpoint:
        print(f"Optimizer states: {len(checkpoint['optimizer_states'])}")

    # Learning rate schedulers
    if 'lr_schedulers' in checkpoint:
        print(f"LR schedulers: {len(checkpoint['lr_schedulers'])}")

    print()

# Inspect at most three checkpoints found anywhere under the checkpoint root
checkpoint_files = [*checkpoint_dir.glob("**/*.ckpt")][:3]  # Analyze first 3
for ckpt_path in checkpoint_files:
    try:
        analyze_checkpoint(ckpt_path)
    except Exception as err:
        # Keep going: one unreadable file should not end the inspection
        print(f"Error analyzing {ckpt_path}: {err}")
```

## Custom Checkpoint Callbacks

### Custom Checkpoint Strategy

```python
from pytorch_lightning.callbacks import Callback

class CustomCheckpointCallback(Callback):
    """Checkpoint callback that saves periodically and on ``val_loss`` improvement.

    Args:
        save_dir: Directory for checkpoints and metadata; created if missing.
        save_every_n_epochs: Save a "periodic" checkpoint every N epochs.
        save_on_improvement: Also save whenever ``val_loss`` beats the best
            value seen so far, together with a JSON metadata file.
    """

    def __init__(self, save_dir, save_every_n_epochs=5, save_on_improvement=True):
        super().__init__()
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.save_every_n_epochs = save_every_n_epochs
        self.save_on_improvement = save_on_improvement
        self.best_val_loss = float('inf')

    def on_validation_epoch_end(self, trainer, pl_module):
        """Save checkpoints after a validation epoch, according to the configured rules."""
        # The pre-training sanity check also runs validation; its metrics come
        # from an untrained model, so never checkpoint from it.
        if trainer.sanity_checking:
            return

        current_epoch = trainer.current_epoch

        # Current validation metrics as plain floats (defaults when not logged)
        val_loss = float(trainer.callback_metrics.get('val_loss', float('inf')))
        val_acc = float(trainer.callback_metrics.get('val_acc', 0.0))

        # Save every N epochs
        if (current_epoch + 1) % self.save_every_n_epochs == 0:
            checkpoint_path = self.save_dir / f"periodic_epoch_{current_epoch:03d}.ckpt"
            trainer.save_checkpoint(checkpoint_path)
            print(f"Saved periodic checkpoint: {checkpoint_path.name}")

        # Save on improvement
        if self.save_on_improvement and val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            checkpoint_path = self.save_dir / f"best_model_epoch_{current_epoch:03d}_loss_{val_loss:.4f}.ckpt"
            trainer.save_checkpoint(checkpoint_path)
            print(f"Saved improvement checkpoint: {checkpoint_path.name}")

            # Also save model metadata
            metadata = {
                'epoch': current_epoch,
                'val_loss': val_loss,
                'val_acc': val_acc,
                'hyperparameters': dict(pl_module.hparams),
                'model_class': pl_module.__class__.__name__
            }

            import json
            with open(self.save_dir / f"metadata_epoch_{current_epoch:03d}.json", 'w', encoding="utf-8") as f:
                json.dump(metadata, f, indent=2)

class ConditionalEarlyStopping(Callback):
    """Early stopping on ``val_loss`` that cannot trigger before ``min_epochs``.

    Epochs before ``min_epochs`` are ignored entirely: they neither update
    the best score nor count toward ``patience``.

    Args:
        patience: Consecutive non-improving epochs tolerated before stopping.
        min_epochs: Epoch index below which the check is skipped.
        improvement_threshold: Amount by which ``val_loss`` must drop below
            the best score to count as an improvement.
    """

    def __init__(self, patience=10, min_epochs=5, improvement_threshold=0.001):
        super().__init__()
        self.patience = patience
        self.min_epochs = min_epochs
        self.improvement_threshold = improvement_threshold
        self.wait_count = 0
        self.best_score = float('inf')

    def on_validation_epoch_end(self, trainer, pl_module):
        """Update the patience counter and request a stop when it runs out."""
        # The pre-training sanity check also runs validation; it must not
        # count as an epoch (matters when min_epochs is 0).
        if trainer.sanity_checking:
            return

        current_epoch = trainer.current_epoch
        # Plain float so ``best_score`` never holds on to a tensor
        val_loss = float(trainer.callback_metrics.get('val_loss', float('inf')))

        # Don't stop before minimum epochs
        if current_epoch < self.min_epochs:
            return

        # Check for improvement
        if val_loss < self.best_score - self.improvement_threshold:
            self.best_score = val_loss
            self.wait_count = 0
        else:
            self.wait_count += 1

        # Stop if no improvement for too long
        if self.wait_count >= self.patience:
            print(f"Early stopping at epoch {current_epoch} (waited {self.wait_count} epochs)")
            trainer.should_stop = True

# Wire both custom callbacks into one training run
custom_callbacks = [
    CustomCheckpointCallback(
        checkpoint_dir / "custom",
        save_every_n_epochs=3,
        save_on_improvement=True,
    ),
    ConditionalEarlyStopping(
        min_epochs=10,
        patience=8,
        improvement_threshold=0.002,
    ),
]

model = CheckpointModel(learning_rate=1e-3)
logger = TensorBoardLogger(checkpoint_dir / "logs", name="custom_callbacks")

trainer = pl.Trainer(
    callbacks=custom_callbacks,
    logger=logger,
    max_epochs=30,
    enable_progress_bar=False,
)

print("Training with custom callbacks...")
trainer.fit(model, train_loader, val_loader)
print("✓ Custom callbacks demo completed")

# Count what the checkpoint callback wrote, by file type
custom_checkpoints = [*(checkpoint_dir / "custom").glob("*")]
print(f"Custom checkpoints saved: {sum(1 for f in custom_checkpoints if f.suffix == '.ckpt')}")
print(f"Metadata files saved: {sum(1 for f in custom_checkpoints if f.suffix == '.json')}")
```

## Best Practices and Production Tips

### Checkpoint Management Strategy

```python
def create_production_checkpoint_setup(checkpoint_base_dir, experiment_name):
    """Build the callback list for a production-style training run.

    Creates ``<checkpoint_base_dir>/<experiment_name>`` if needed and returns,
    in order: a best-model checkpoint, a top-3 checkpoint, a periodic
    weights-only backup, a last-checkpoint saver, and an early-stopping
    callback. All monitored callbacks watch ``val_loss``.
    """
    experiment_dir = Path(checkpoint_base_dir) / experiment_name
    experiment_dir.mkdir(parents=True, exist_ok=True)

    # Single best model, intended for inference
    best_model = ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        save_top_k=1,
        dirpath=experiment_dir / "best",
        filename="best-model-{epoch:02d}-{val_loss:.4f}",
        save_weights_only=False,  # Save full model state
        save_last=False,
        verbose=True,
    )

    # Three best models, e.g. for ensembling
    top3_models = ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        save_top_k=3,
        dirpath=experiment_dir / "top3",
        filename="top3-{epoch:02d}-{val_loss:.4f}-{val_acc:.4f}",
        save_weights_only=False,
        verbose=False,
    )

    # Unconditional backup every 10 epochs; all of them are kept
    periodic_backup = ModelCheckpoint(
        every_n_epochs=10,
        save_top_k=-1,
        dirpath=experiment_dir / "periodic",
        filename="backup-epoch-{epoch:02d}",
        save_weights_only=True,  # Save space for backups
        verbose=False,
    )

    # Most recent state, for resuming an interrupted run
    recovery = ModelCheckpoint(
        save_last=True,
        dirpath=experiment_dir / "recovery",
        filename="last-checkpoint",
        verbose=False,
    )

    # Patient stopping rule with a very small improvement threshold
    early_stopping = EarlyStopping(
        monitor="val_loss",
        mode="min",
        patience=15,
        min_delta=0.0001,
        check_finite=True,
        strict=True,  # Crash if monitored metric is not found
        verbose=True,
    )

    return [best_model, top3_models, periodic_backup, recovery, early_stopping]

# Assemble the production callbacks for one named experiment
production_callbacks = create_production_checkpoint_setup(
    checkpoint_dir,
    "production_run_v1",
)

# Wider hidden layer than the earlier demos
model = CheckpointModel(
    learning_rate=1e-3,
    hidden_size=128,
    input_size=20,
    num_classes=5,
)

logger = TensorBoardLogger(checkpoint_dir / "logs", name="production", version="v1")

trainer = pl.Trainer(
    callbacks=production_callbacks,
    logger=logger,
    max_epochs=50,
    check_val_every_n_epoch=1,
    log_every_n_steps=10,
    enable_progress_bar=True,
)

print("Running production training setup...")
trainer.fit(model, train_loader, val_loader)
print("✓ Production training completed")

# Per-strategy count of the checkpoint files left on disk
print("\n=== Production Checkpoint Summary ===")
prod_dir = checkpoint_dir / "production_run_v1"
for subdir in ("best", "top3", "periodic", "recovery"):
    ckpt_files = [*(prod_dir / subdir).glob("*.ckpt")]
    print(f"{subdir}: {len(ckpt_files)} checkpoints")
```

## Summary

This notebook covered comprehensive checkpointing and early stopping strategies:

1. **Basic Checkpointing**: Save best models based on validation metrics
2. **Advanced Strategies**: Multiple checkpoint types (best, top-k, periodic, last)
3. **Early Stopping**: Prevent overfitting with intelligent stopping criteria
4. **Model Loading**: Resume training and load models from checkpoints
5. **Custom Callbacks**: Build domain-specific checkpoint and stopping logic
6. **Production Setup**: Robust checkpoint management for real-world deployment

Key checkpoint strategies:
- **Best Model**: Save the single best performing model for inference
- **Top-K Models**: Keep multiple good models for ensembling
- **Periodic Backups**: Regular saves to prevent data loss
- **Last Checkpoint**: Always keep most recent state for resuming training
- **Recovery Points**: Strategic saves at important training milestones

Early stopping best practices:
- Monitor appropriate metrics (usually validation loss)
- Set reasonable patience (5-15 epochs typically)
- Use minimum delta to avoid stopping on noise
- Consider multiple stopping criteria for robust training
- Always combine with checkpointing to save best model

Production considerations:
- Organize checkpoints by experiment and version
- Save metadata alongside checkpoints
- Implement checkpoint cleanup to manage disk space
- Use descriptive filenames with key metrics
- Plan for training resumption and model recovery

Next notebook: We'll explore custom callbacks including SWA and EMA.