# Trainer Sanity Checks and Debugging Tools

**File Location:** `notebooks/01_lightning_fundamentals/02_trainer_sanity_and_debug.ipynb`

## Introduction

This notebook covers PyTorch Lightning's powerful debugging and development tools. Learn to use `fast_dev_run`, `overfit_batches`, `limit_train_batches`, and other Trainer flags that make development faster and debugging easier.

## Fast Development Run

### fast_dev_run - Quick Smoke Test

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

class DebugModel(pl.LightningModule):
    """Minimal two-layer MLP classifier used as a test subject for Trainer flags.

    Args:
        input_size: Number of features per input sample.
        hidden_size: Width of the hidden layer.
        output_size: Number of classes, i.e. logits produced per sample.
    """

    def __init__(self, input_size=20, hidden_size=64, output_size=10):
        super().__init__()
        # Stores the three constructor arguments in self.hparams.
        self.save_hyperparameters()

        layers = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits for a batch of feature vectors."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Cross-entropy loss on one (features, labels) batch; also logs the batch size."""
        features, labels = batch
        loss = F.cross_entropy(self(features), labels)

        self.log('train_loss', loss)
        # A short final batch shows up as a smaller value here.
        self.log('batch_size', float(features.shape[0]))
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation cross-entropy for one batch."""
        features, labels = batch
        loss = F.cross_entropy(self(features), labels)
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters with a fixed learning rate of 1e-3."""
        optimizer = Adam(self.parameters(), lr=1e-3)
        return optimizer

# Quick data setup
def create_debug_data(num_samples=1000, input_size=20, num_classes=10, batch_size=32, seed=42):
    """Build a shuffled DataLoader of random features and random integer labels.

    Args:
        num_samples: Number of (x, y) pairs in the dataset.
        input_size: Number of features per sample (x has shape (num_samples, input_size)).
        num_classes: Labels are drawn uniformly from [0, num_classes).
        batch_size: Batch size of the returned DataLoader.
        seed: Seed for a private RNG. The same seed always yields the same
            tensors and the same initial shuffle order.

    Returns:
        A DataLoader over a TensorDataset(x, y) with shuffle=True.
    """
    # Use a private generator instead of torch.manual_seed so that building the
    # data does not silently reset the global RNG (which would also re-seed
    # model initialisation and anything else drawn afterwards).
    generator = torch.Generator().manual_seed(seed)
    x = torch.randn(num_samples, input_size, generator=generator)
    y = torch.randint(0, num_classes, (num_samples,), generator=generator)
    dataset = TensorDataset(x, y)
    # Reuse the generator for shuffling so batch order is reproducible too.
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, generator=generator)

# Both loaders come from the same helper; only the sample count differs.
train_loader = create_debug_data(num_samples=1000, input_size=20, num_classes=10, batch_size=32)
val_loader = create_debug_data(num_samples=200, input_size=20, num_classes=10, batch_size=32)

# fast_dev_run=True pushes a single batch through the training and validation
# loops (trainer.fit does not run the test loop) to smoke-test the pipeline.
print("=== Fast Dev Run ===")
model = DebugModel()

trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=False,
    fast_dev_run=True,  # one train batch and one val batch, then stop
)

print("Running fast development check (1 batch only)...")
trainer.fit(model, train_loader, val_loader)
print("✓ Fast dev run completed - no errors in the pipeline!")

# An integer fast_dev_run runs that many batches instead of just one.
trainer_5_batches = pl.Trainer(
    logger=False,
    enable_checkpointing=False,
    fast_dev_run=5,
)

print("\nRunning with 5 batches...")
trainer_5_batches.fit(model, train_loader, val_loader)
print("✓ 5-batch dev run completed!")
```

### overfit_batches - Check Model Capacity

```python
print("=== Overfit Batches Test ===")

# Sanity check: a model with enough capacity (and a working optimizer) should be
# able to memorize a couple of batches and drive the training loss toward zero.
model = DebugModel()

trainer_overfit = pl.Trainer(
    overfit_batches=2,  # Use only 2 batches for train AND val
    max_epochs=50,      # More epochs to see overfitting
    logger=False,
    enable_checkpointing=False,
    log_every_n_steps=1  # Log every step to see progress
)

print("Training on only 2 batches for 50 epochs...")
print("Model should overfit and reach very low loss")
trainer_overfit.fit(model, train_loader, val_loader)

# Check if model actually overfitted. Read callback_metrics, not logged_metrics:
# with logger=False no logger receives the metrics, so trainer.logged_metrics
# stays empty and the lookup would always return the inf fallback.
final_loss = float(trainer_overfit.callback_metrics.get('train_loss', float('inf')))
print(f"Final training loss: {final_loss:.6f}")
if final_loss < 0.1:
    print("✓ Model successfully overfitted - has sufficient capacity")
else:
    print("⚠ Model may have insufficient capacity or learning rate issues")
```

### Limiting Data - Control Training Size

```python
print("=== Limiting Training Data ===")

# Fractional limits: a float keeps that share of each dataloader's batches.
model = DebugModel()

trainer_limited = pl.Trainer(
    max_epochs=3,
    limit_train_batches=0.3,  # 30% of the training batches per epoch
    limit_val_batches=0.5,    # 50% of the validation batches
    enable_checkpointing=False,
    logger=False,
)

print("Training with 30% of train data and 50% of val data...")
trainer_limited.fit(model, train_loader, val_loader)
print("✓ Limited data training completed")

# Absolute limits: an int is an exact batch count. The same model instance,
# with its already-updated weights, is reused here.
trainer_exact = pl.Trainer(
    max_epochs=2,
    limit_train_batches=10,   # exactly 10 training batches per epoch
    limit_val_batches=5,      # exactly 5 validation batches
    enable_checkpointing=False,
    logger=False,
)

print("\nTraining with exactly 10 train batches and 5 val batches...")
trainer_exact.fit(model, train_loader, val_loader)
print("✓ Exact batch count training completed")
```

## Advanced Debugging Features

### Gradient and Loss Debugging

```python
print("=== Gradient and Loss Debugging ===")

class DebugAdvancedModel(pl.LightningModule):
    """20 -> 64 -> 10 MLP that logs gradient norms and prediction diversity for debugging."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def on_after_backward(self):
        """Log per-parameter L2 gradient norms on every 10th training step.

        This runs right after loss.backward(), so param.grad holds the current
        step's gradients. Inside training_step, backward has not run yet, so
        param.grad would still hold the previous step's gradients (or be None
        on the very first step). The norms are measured before any Trainer
        gradient clipping is applied.
        """
        if self.global_step % 10 != 0:
            return
        for name, param in self.named_parameters():
            if param.grad is not None:
                grad_norm = param.grad.detach().norm(2)
                self.log(f'grad_norm/{name}', grad_norm, on_step=True, on_epoch=False)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)

        # Number of distinct classes predicted in this batch. A value of 1
        # means every sample got the same prediction (possible class collapse).
        preds = torch.argmax(y_hat, dim=1)
        num_unique_preds = torch.unique(preds).numel()

        self.log('val_loss', loss)
        self.log('unique_predictions', float(num_unique_preds))

        return loss

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)

model = DebugAdvancedModel()

# Clip gradients on the Trainer. Gradient *norms* are logged by the
# LightningModule itself: the Trainer's `track_grad_norm` argument was removed
# in Lightning 2.0 (passing it raises a TypeError there). With logger=False,
# logged values are only visible through trainer.callback_metrics.
trainer_debug = pl.Trainer(
    max_epochs=2,
    limit_train_batches=20,
    limit_val_batches=5,
    gradient_clip_val=1.0,      # Clip gradients to prevent explosion
    log_every_n_steps=5,
    logger=False,
    enable_checkpointing=False
)

print("Training with gradient tracking and clipping...")
trainer_debug.fit(model, train_loader, val_loader)
print("✓ Gradient debugging completed")
```

### Detecting Anomalies

```python
print("=== Anomaly Detection ===")

# Trainer(detect_anomaly=True) (below) turns on torch.autograd anomaly mode for
# the duration of fit() and restores the previous setting afterwards. A manual
# global torch.autograd.set_detect_anomaly(True)/(False) pair is redundant and
# stays switched on if anything between the two calls raises.

class PotentiallyBuggyModel(pl.LightningModule):
    """Linear classifier that deliberately produces a NaN loss on its 15th training step."""

    def __init__(self):
        super().__init__()
        self.model = nn.Linear(20, 10)
        self.step_count = 0  # number of training_step calls so far

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)

        # Introduce potential numerical issues for demonstration
        self.step_count += 1
        if self.step_count == 15:  # Introduce NaN at step 15
            print("⚠ Introducing numerical instability for demo...")
            # Keep the NaN attached to the autograd graph. A fresh
            # torch.tensor(nan) has no grad_fn, so backward() would fail with a
            # "does not require grad" error instead of tripping anomaly
            # detection. Here MulBackward0 produces NaN gradients, which
            # anomaly mode reports.
            loss = loss * float('nan')

        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)

model = PotentiallyBuggyModel()

trainer_anomaly = pl.Trainer(
    max_epochs=1,
    limit_train_batches=20,
    detect_anomaly=True,       # This will catch anomalies
    logger=False,
    enable_checkpointing=False
)

print("Testing anomaly detection (will catch NaN)...")
try:
    trainer_anomaly.fit(model, train_loader)
except RuntimeError as e:
    # Anomaly mode raises RuntimeError naming the backward function that returned NaN.
    print(f"✓ Anomaly detected: {type(e).__name__}")
    print("This is expected - anomaly detection caught the NaN!")
else:
    print("⚠ No anomaly was raised - check that detect_anomaly is enabled")
```

### Profiling Performance

```python
print("=== Performance Profiling ===")

model = DebugModel()

# Enable profiler to identify bottlenecks
from pytorch_lightning.profilers import SimpleProfiler, PyTorchProfiler

# SimpleProfiler records wall-clock time per Trainer action. With `filename`
# set, the report is written under `dirpath` instead of printed. Lightning
# prefixes the stage name, so a trainer.fit() call produces
# "fit-simple_profile.txt".
simple_profiler = SimpleProfiler(dirpath=".", filename="simple_profile")

trainer_profile = pl.Trainer(
    max_epochs=1,
    limit_train_batches=10,
    limit_val_batches=3,
    profiler=simple_profiler,
    logger=False,
    enable_checkpointing=False
)

print("Running with simple profiler...")
trainer_profile.fit(model, train_loader, val_loader)
print("✓ Profiling completed - check fit-simple_profile.txt for timing info")

# Advanced PyTorch profiler (more detailed). Extra keyword arguments such as
# `activities` are passed through to torch.profiler. CUDA activity is only
# requested when a GPU is available.
pytorch_profiler = PyTorchProfiler(
    dirpath=".",
    filename="pytorch_profile",
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ] if torch.cuda.is_available() else [torch.profiler.ProfilerActivity.CPU]
)

trainer_advanced_profile = pl.Trainer(
    max_epochs=1,
    limit_train_batches=5,
    profiler=pytorch_profiler,
    logger=False,
    enable_checkpointing=False
)

print("Running with PyTorch profiler...")
trainer_advanced_profile.fit(model, train_loader, val_loader)
print("✓ Advanced profiling completed")
```

## Development Workflow Best Practices

### Complete Debug Workflow

```python
print("=== Complete Development Workflow ===")

def debug_model_pipeline(model, train_loader, val_loader):
    """Run staged sanity checks on a model before committing to full training.

    The same ``model`` instance is trained in place by every stage:

        1. ``fast_dev_run`` smoke test (one train batch and one val batch).
        2. Overfit a single batch for 20 epochs and report the final loss.
        3. Short 3-epoch run on 10 train / 3 val batches.

    Args:
        model: LightningModule under test. It must log ``train_loss`` for
            stage 2 to report a real loss.
        train_loader: Training dataloader.
        val_loader: Validation dataloader.

    Returns:
        False if the smoke test raised an exception, otherwise True. A high
        loss in the overfit stage only prints a warning and does not change
        the return value.
    """
    print("Step 1: Fast dev run - Check basic functionality")
    trainer = pl.Trainer(fast_dev_run=True, logger=False, enable_checkpointing=False)
    try:
        trainer.fit(model, train_loader, val_loader)
    except Exception as e:  # report any pipeline failure instead of crashing the notebook
        print(f"✗ Pipeline error: {e}")
        return False
    print("✓ Pipeline working")

    print("\nStep 2: Overfit test - Check model capacity")
    trainer = pl.Trainer(
        overfit_batches=1,
        max_epochs=20,
        logger=False,
        enable_checkpointing=False,
        enable_progress_bar=False
    )
    trainer.fit(model, train_loader, val_loader)
    # callback_metrics is populated whether or not a logger is attached. With
    # logger=False, trainer.logged_metrics stays empty, so reading it here
    # would always yield the inf fallback.
    final_loss = float(trainer.callback_metrics.get('train_loss', float('inf')))

    if final_loss < 0.1:
        print(f"✓ Model can overfit (final loss: {final_loss:.4f})")
    else:
        print(f"⚠ Model may have issues (final loss: {final_loss:.4f})")

    print("\nStep 3: Limited data run - Check training stability")
    trainer = pl.Trainer(
        limit_train_batches=10,
        limit_val_batches=3,
        max_epochs=3,
        logger=False,
        enable_checkpointing=False,
        enable_progress_bar=False
    )
    trainer.fit(model, train_loader, val_loader)
    print("✓ Limited training completed")

    print("\nStep 4: Full training ready!")
    return True

# Run the whole staged workflow on a freshly constructed model.
model = DebugModel()
success = debug_model_pipeline(model, train_loader, val_loader)

print(
    "\n🎉 Model passed all debugging checks - ready for full training!"
    if success
    else "\n❌ Model needs fixes before full training"
)
```

### Debugging Checklist

```python
print("=== Debugging Checklist ===")

# Trainer flag / tool -> the question it helps answer during development.
# Gradient-norm tracking is listed as the `grad_norm` utility because the
# Trainer's `track_grad_norm` argument was removed in Lightning 2.0.
checklist = {
    "fast_dev_run": "Quick smoke test - does the pipeline work?",
    "overfit_batches": "Can the model learn? (overfit small data)",
    "limit_train_batches": "Does training work with limited data?",
    "gradient_clip_val": "Are gradients exploding?",
    "grad_norm": "Monitor gradient health (log from on_before_optimizer_step)",
    "detect_anomaly": "Catch NaN/Inf values early",
    "profiler": "Identify performance bottlenecks",
    "log_every_n_steps": "Monitor training frequently during debug",
}

print("Development debugging workflow:")
for i, (flag, description) in enumerate(checklist.items(), 1):
    print(f"{i}. {flag}: {description}")

print("\nCommon debugging trainer configurations:")
print("""
# Quick development check:
trainer = pl.Trainer(fast_dev_run=True)

# Overfit test:
trainer = pl.Trainer(overfit_batches=5, max_epochs=50)

# Limited data debugging:
trainer = pl.Trainer(limit_train_batches=0.1, limit_val_batches=0.2, max_epochs=3)

# Full debugging mode:
trainer = pl.Trainer(
    limit_train_batches=50,
    gradient_clip_val=1.0,
    detect_anomaly=True,
    profiler="simple"
)

# Gradient-norm tracking (Trainer(track_grad_norm=...) was removed in Lightning 2.0).
# Add this hook to your LightningModule instead:
from pytorch_lightning.utilities import grad_norm

def on_before_optimizer_step(self, optimizer):
    self.log_dict(grad_norm(self, norm_type=2))
""")
```

## Summary

This notebook covered essential debugging and development tools in PyTorch Lightning:

1. **fast_dev_run**: Quick smoke test with minimal batches to verify pipeline integrity
2. **overfit_batches**: Test model capacity by overfitting to small data subset
3. **limit_*_batches**: Control training data size for faster iteration during development
4. **Gradient debugging**: Track gradient norms and clip gradients to prevent instability
5. **Anomaly detection**: Catch NaN/Inf values early in development
6. **Profiling**: Identify performance bottlenecks in your training pipeline

Key development workflow:
1. Start with `fast_dev_run=True` to catch basic errors
2. Use `overfit_batches` to verify model can learn
3. Test with `limit_train_batches` for quick iterations
4. Add gradient monitoring for numerical stability
5. Profile performance before scaling up

These tools dramatically speed up development by catching issues early and enabling rapid iteration on model architectures and hyperparameters.

Next notebook: We'll explore LightningCLI for config-driven experiments.