# Topic 15: Production PyTorch Best Practices

## Learning Objectives

By the end of this notebook, you will:
- Master production-ready project structure and organization
- Learn model checkpointing strategies for fault tolerance
- Understand quantization techniques (int8, int4, GPTQ, AWQ)
- Export models to ONNX for deployment
- Compare TorchScript vs ONNX vs torch.compile
- Master debugging techniques and common pitfalls
- Implement monitoring and logging for production systems

---

## 1. The Big Picture: Production vs Research Code

### Research Code vs Production Code

**Research/Prototype**:
- Single notebook or script
- Hardcoded paths and hyperparameters
- No error handling
- Ad-hoc logging
- Manual checkpointing
- Works on your machine

**Production**:
- Modular, testable codebase
- Configuration files for all settings
- Comprehensive error handling
- Structured logging and monitoring
- Automatic checkpointing and recovery
- Works anywhere, at scale

### Production Requirements

**Reliability**:
- ✅ Fault tolerance (recover from crashes)
- ✅ Reproducibility (same inputs → same outputs)
- ✅ Monitoring (know when things go wrong)
- ✅ Testing (catch bugs before deployment)

**Performance**:
- ✅ Efficient inference (low latency, high throughput)
- ✅ Memory optimization (quantization, pruning)
- ✅ Batching and caching strategies

**Maintainability**:
- ✅ Clean code structure
- ✅ Documentation
- ✅ Version control
- ✅ Easy updates and rollbacks

### What We'll Cover

1. **Project Structure**: How to organize production PyTorch projects
2. **Checkpointing**: Save/load strategies for training and inference
3. **Quantization**: Compress models for efficient deployment
4. **Export & Deployment**: ONNX, TorchScript, serving strategies
5. **Debugging**: Common issues and how to fix them
6. **Monitoring**: Track model performance in production

---

In [None]:
# Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np
import json
import time
from pathlib import Path
from typing import Dict, Any, Optional
import warnings

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

---

## 2. Production Project Structure

### Recommended Directory Layout

```
my_project/
├── configs/
│   ├── base_config.yaml
│   ├── train_config.yaml
│   └── inference_config.yaml
├── src/
│   ├── __init__.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── transformer.py
│   │   └── attention.py
│   ├── data/
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   └── preprocessing.py
│   ├── training/
│   │   ├── __init__.py
│   │   ├── trainer.py
│   │   └── callbacks.py
│   └── utils/
│       ├── __init__.py
│       ├── logging.py
│       └── metrics.py
├── checkpoints/
├── logs/
├── tests/
│   ├── test_model.py
│   └── test_data.py
├── scripts/
│   ├── train.py
│   └── inference.py
├── requirements.txt
├── setup.py
└── README.md
```

### Why This Structure?

- **Modularity**: Easy to find and modify components
- **Testability**: Clear separation allows unit testing
- **Reusability**: Import modules across scripts
- **Scalability**: Add new features without breaking existing code
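
To make the testability point concrete, here is a minimal sketch of what `tests/test_model.py` might contain. `TinyClassifier` is a hypothetical stand-in for a model under `src/models/`, and the two checks (output shape, deterministic eval-mode output) are assumptions about what is worth testing, not a prescribed suite.

```python
import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for a model defined under src/models/."""

    def __init__(self, in_features: int = 16, num_classes: int = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, 32),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(32, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def test_output_shape():
    model = TinyClassifier()
    out = model(torch.randn(8, 16))
    assert out.shape == (8, 4)


def test_eval_mode_is_deterministic():
    # Dropout is disabled in eval mode, so repeated calls must agree.
    model = TinyClassifier().eval()
    x = torch.randn(8, 16)
    with torch.no_grad():
        assert torch.equal(model(x), model(x))
```

With this layout, running `pytest tests/` from the project root would collect and run both functions.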

In [None]:
# Example: Configuration Management

class Config:
    """Central configuration class for training and inference"""
    
    def __init__(self, config_dict: Optional[Dict[str, Any]] = None):
        # Model architecture
        self.d_model = 512
        self.num_heads = 8
        self.num_layers = 6
        self.d_ff = 2048
        self.dropout = 0.1
        
        # Training
        self.batch_size = 32
        self.learning_rate = 1e-4
        self.num_epochs = 10
        self.gradient_clip = 1.0
        self.warmup_steps = 1000
        
        # Optimization
        self.use_amp = True
        self.use_compile = True
        self.gradient_accumulation_steps = 1
        
        # Checkpointing
        self.checkpoint_dir = "./checkpoints"
        self.save_every_n_steps = 1000
        self.keep_last_n_checkpoints = 3
        
        # Logging
        self.log_dir = "./logs"
        self.log_every_n_steps = 100
        
        # Random seed for reproducibility
        self.seed = 42
        
        # Override with provided config
        if config_dict:
            self.__dict__.update(config_dict)
    
    def save(self, path: str):
        """Save configuration to JSON"""
        with open(path, 'w') as f:
            json.dump(self.__dict__, f, indent=2)
    
    @classmethod
    def load(cls, path: str):
        """Load configuration from JSON"""
        with open(path, 'r') as f:
            config_dict = json.load(f)
        return cls(config_dict)
    
    def __repr__(self):
        return f"Config({json.dumps(self.__dict__, indent=2)})"


# Demo configuration
config = Config()
print("Default Configuration:")
print("="*60)
for key, value in config.__dict__.items():
    print(f"  {key}: {value}")

print("\n💡 Configuration management benefits:")
print("   - All settings in one place")
print("   - Easy to save/load experiments")
print("   - Version control friendly")
print("   - Override specific values easily")
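
The `seed` field in the configuration only helps if something applies it. A minimal sketch of a seeding helper follows; `set_seed` is a hypothetical name, and it assumes that seeding Python, NumPy, and PyTorch is enough for your pipeline. Fully deterministic GPU runs may need further settings that are not shown here.

```python
import random

import numpy as np
import torch


def set_seed(seed: int) -> None:
    """Seed the Python, NumPy, and PyTorch random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # also seeds the CUDA generators


set_seed(42)
first = torch.randn(3)
set_seed(42)
second = torch.randn(3)
print("Identical draws after reseeding:", torch.equal(first, second))
```

Calling `set_seed(config.seed)` once at the start of a training script is the usual place for it.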

---

## 3. Model Checkpointing

### Why Checkpoint?

**Problems without checkpointing**:
- Training crashes → lose all progress
- Can't resume training
- Can't compare models across epochs
- No rollback if model degrades

### What to Save

A complete checkpoint should include:
1. **Model state**: `model.state_dict()`
2. **Optimizer state**: `optimizer.state_dict()`
3. **Scheduler state**: `scheduler.state_dict()` (if using)
4. **Training step/epoch**: For resuming
5. **Config**: To reproduce exact setup
6. **Random states**: For full reproducibility
7. **Metrics**: Best loss, accuracy, etc.
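
One detail the list above does not cover is what happens if the process dies while a checkpoint is being written. A common pattern, sketched below under the assumption that the temporary file and the final file live on the same filesystem, is to write to a temporary file and rename it into place. `atomic_torch_save` is a hypothetical helper name, not a PyTorch API.

```python
import os
import tempfile
from pathlib import Path

import torch


def atomic_torch_save(obj, path) -> None:
    """Write to a temp file in the target directory, then rename it into place."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp_name = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            torch.save(obj, f)
            f.flush()
            os.fsync(f.fileno())  # push the bytes to disk before renaming
        os.replace(tmp_name, path)  # atomic when both paths share a filesystem
    except BaseException:
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise


ckpt_dir = Path(tempfile.mkdtemp())
ckpt_path = ckpt_dir / "checkpoint_step_100.pt"
atomic_torch_save({"step": 100, "weights": torch.ones(2)}, ckpt_path)
restored = torch.load(ckpt_path)
print("Restored step:", restored["step"])
```

A reader of the checkpoint directory then sees either the previous complete file or the new complete file, never a half-written one.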

In [None]:
class CheckpointManager:
    """Manages model checkpoints with best model tracking"""
    
    def __init__(
        self,
        checkpoint_dir: str,
        keep_last_n: int = 3,
        keep_best: bool = True
    ):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.keep_last_n = keep_last_n
        self.keep_best = keep_best
        self.best_metric = float('inf')
    
    def save_checkpoint(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        step: int,
        epoch: int,
        config: Config,
        metrics: Dict[str, float],
        scheduler: Optional[Any] = None
    ) -> str:
        """Save a complete checkpoint"""
        
        checkpoint = {
            'step': step,
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'metrics': metrics,
            'config': config.__dict__,
            'pytorch_version': torch.__version__,
        }
        
        # Add scheduler if provided
        if scheduler is not None:
            checkpoint['scheduler_state_dict'] = scheduler.state_dict()
        
        # Add random states for reproducibility
        checkpoint['random_state'] = {
            'torch': torch.get_rng_state(),
            'numpy': np.random.get_state(),
        }
        if torch.cuda.is_available():
            checkpoint['random_state']['cuda'] = torch.cuda.get_rng_state()
        
        # Save checkpoint
        checkpoint_path = self.checkpoint_dir / f"checkpoint_step_{step}.pt"
        torch.save(checkpoint, checkpoint_path)
        
        # Save as best if metric improved
        current_metric = metrics.get('loss', float('inf'))
        if self.keep_best and current_metric < self.best_metric:
            self.best_metric = current_metric
            best_path = self.checkpoint_dir / "best_checkpoint.pt"
            torch.save(checkpoint, best_path)
            print(f"💚 New best checkpoint! Loss: {current_metric:.4f}")
        
        # Clean up old checkpoints
        self._cleanup_old_checkpoints()
        
        return str(checkpoint_path)
    
    def load_checkpoint(
        self,
        checkpoint_path: str,
        model: nn.Module,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[Any] = None,
        restore_random_state: bool = True
    ) -> Dict[str, Any]:
        """Load checkpoint and restore training state"""
        
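        # weights_only=False is required here because the checkpoint stores
        # non-tensor objects (NumPy RNG state); only load files you trust.
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)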
        
        # Load model
        model.load_state_dict(checkpoint['model_state_dict'])
        
        # Load optimizer if provided
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        
        # Load scheduler if provided
        if scheduler is not None and 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        
        # Restore random states for reproducibility
        if restore_random_state and 'random_state' in checkpoint:
            torch.set_rng_state(checkpoint['random_state']['torch'])
            np.random.set_state(checkpoint['random_state']['numpy'])
            if torch.cuda.is_available() and 'cuda' in checkpoint['random_state']:
                torch.cuda.set_rng_state(checkpoint['random_state']['cuda'])
        
        print(f"✅ Loaded checkpoint from step {checkpoint['step']}, epoch {checkpoint['epoch']}")
        print(f"   Metrics: {checkpoint['metrics']}")
        
        return checkpoint
    
    def _cleanup_old_checkpoints(self):
        """Keep only last N checkpoints"""
        checkpoints = sorted(
            self.checkpoint_dir.glob("checkpoint_step_*.pt"),
            key=lambda p: int(p.stem.split('_')[-1])
        )
        
        # Remove old checkpoints
        for checkpoint in checkpoints[:-self.keep_last_n]:
            checkpoint.unlink()
    
    def list_checkpoints(self) -> list:
        """List all available checkpoints"""
        checkpoints = list(self.checkpoint_dir.glob("*.pt"))
        return sorted(checkpoints)


# Demo checkpointing
print("Checkpoint Management Demo")
print("="*70)

# Create dummy model and optimizer
model = nn.Linear(10, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
config = Config()

# Create checkpoint manager
ckpt_manager = CheckpointManager("/tmp/pytorch_checkpoints", keep_last_n=3)

# Simulate training and saving
print("\nSimulating training...")
for step in range(0, 5000, 1000):
    metrics = {'loss': 1.0 / (step + 1), 'accuracy': step / 5000}
    
    path = ckpt_manager.save_checkpoint(
        model, optimizer, step, step // 1000, config, metrics
    )
    print(f"Saved: {path}")

# List all checkpoints
print("\nAvailable checkpoints:")
for ckpt in ckpt_manager.list_checkpoints():
    print(f"  {ckpt.name}")

print("\n💡 Checkpointing best practices:")
print("   - Save frequently (every N steps)")
print("   - Keep best checkpoint separate")
print("   - Clean up old checkpoints to save disk")
print("   - Include all state for perfect resumption")
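
A full training checkpoint is more than an inference service needs. The sketch below shows one way to save and load weights only; the `nn.Linear` models and the file name `model_weights.pt` are placeholders chosen for illustration.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn

train_model = nn.Linear(10, 2)  # stand-in for a trained model
weights_path = Path(tempfile.mkdtemp()) / "model_weights.pt"

# Save only the weights: no optimizer, scheduler, or RNG state.
torch.save(train_model.state_dict(), weights_path)

# Rebuild the architecture, then load tensors only. weights_only=True
# refuses to unpickle arbitrary Python objects from the file.
serve_model = nn.Linear(10, 2)
state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
serve_model.load_state_dict(state_dict)
serve_model.eval()  # disable dropout / use running batch-norm statistics

with torch.inference_mode():
    x = torch.randn(4, 10)
    same = torch.allclose(serve_model(x), train_model(x))
print(f"Outputs match: {same}")
```

Keeping the inference artifact separate from training checkpoints also makes it smaller and safer to ship.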

---

## 4. Model Quantization

### Why Quantize?

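**float32 model**:
- 7B parameters × 4 bytes = 28 GB for the weights alone
- Slow inference, since every weight must be read from memory
- Expensive to serve

**int8 quantized**:
- 7B parameters × 1 byte = 7 GB (4x smaller)
- Often 2-4x faster inference, depending on hardware support for int8
- Much cheaper deployment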
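
The arithmetic above generalizes to other dtypes. The helper below is a small sketch that counts weight storage only; it ignores activations, caches, and the scale and zero-point metadata that quantized formats add, and it uses decimal gigabytes (1 GB = 1e9 bytes) to match the figures above.

```python
BYTES_PER_PARAM = {"float32": 4.0, "bfloat16": 2.0, "int8": 1.0, "int4": 0.5}


def weight_memory_gb(num_params: float, dtype: str) -> float:
    """Weight storage in decimal gigabytes (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


for dtype in BYTES_PER_PARAM:
    print(f"{dtype:>8}: {weight_memory_gb(7e9, dtype):.1f} GB")
```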

### Quantization Types

1. **Post-Training Quantization (PTQ)**:
   - Quantize after training
   - No retraining needed
   - Slight quality loss

2. **Quantization-Aware Training (QAT)**:
   - Train with quantization in mind
   - Better quality
   - More compute intensive

3. **Weight-Only Quantization**:
   - Quantize weights, keep activations in float
   - Good for memory-bound models
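
To show what "quantize weights" means numerically, here is a minimal sketch of per-tensor symmetric int8 quantization written by hand. It is an illustration of the idea, not what PyTorch's quantization backends do internally; practical schemes usually add per-channel scales and calibration data.

```python
import torch


def quantize_symmetric_int8(w: torch.Tensor):
    """Per-tensor symmetric quantization: w ≈ scale * q with q in [-127, 127]."""
    scale = w.abs().max().clamp(min=1e-12) / 127.0
    q = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
    return q, scale


def dequantize(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Map int8 codes back to approximate float32 values."""
    return q.to(torch.float32) * scale


torch.manual_seed(0)
w = torch.randn(256, 256)
q, scale = quantize_symmetric_int8(w)
w_hat = dequantize(q, scale)

max_err = (w - w_hat).abs().max().item()
print(f"scale={scale.item():.6f}, max abs error={max_err:.6f}")
print(f"storage: {w.element_size()} bytes/elem -> {q.element_size()} byte/elem")
```

Because values are rounded to the nearest multiple of `scale`, the reconstruction error per weight is bounded by about half of `scale`.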

### Popular Quantization Methods (2025)

- **int8**: Standard, widely supported
- **int4**: Extreme compression, usually weight-only (GPTQ, AWQ)
- **bfloat16**: Not quantization, but halves memory relative to float32
- **GGUF**: Model file format used by llama.cpp, popular for CPU and local inference
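
The bfloat16 entry is the easiest of these to try. The sketch below casts a small stand-in model and compares parameter storage; `param_bytes` is a hypothetical helper, and the example says nothing about accuracy, which you would need to check on your own model.

```python
import torch
import torch.nn as nn


def param_bytes(model: nn.Module) -> int:
    """Total bytes used by the model's parameters."""
    return sum(p.numel() * p.element_size() for p in model.parameters())


model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))
fp32_bytes = param_bytes(model)  # measure before casting

# nn.Module.to() casts the parameters in place and returns the same module.
model_bf16 = model.to(torch.bfloat16)
bf16_bytes = param_bytes(model_bf16)

print(f"float32:  {fp32_bytes / 1024**2:.2f} MB")
print(f"bfloat16: {bf16_bytes / 1024**2:.2f} MB")
```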

In [None]:
# Dynamic Quantization (simplest approach)

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(512, 512)
        self.linear2 = nn.Linear(512, 512)
        self.linear3 = nn.Linear(512, 10)
    
    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        return self.linear3(x)


print("Dynamic Quantization Demo")
print("="*70)

# Create model
model_fp32 = SimpleModel()

# Get model size
import io  # used to serialize the model in memory

def get_model_size(model):
    """Get serialized model size in MB.

    Dynamically quantized Linear layers keep their weights in packed params,
    which model.parameters() does not report, so measure the saved state_dict.
    """
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.getbuffer().nbytes / 1024**2

fp32_size = get_model_size(model_fp32)

# Dynamic quantization
model_int8 = torch.quantization.quantize_dynamic(
    model_fp32,
    {nn.Linear},  # Quantize Linear layers
    dtype=torch.qint8
)

int8_size = get_model_size(model_int8)

print(f"Model sizes:")
print(f"  float32: {fp32_size:.2f} MB")
print(f"  int8: {int8_size:.2f} MB")
print(f"  Compression ratio: {fp32_size / int8_size:.2f}x")

# Test inference
x = torch.randn(32, 512)

# float32 inference
with torch.no_grad():
    start = time.time()
    for _ in range(100):
        _ = model_fp32(x)
    fp32_time = (time.time() - start) / 100

# int8 inference
with torch.no_grad():
    start = time.time()
    for _ in range(100):
        _ = model_int8(x)
    int8_time = (time.time() - start) / 100

print(f"\nInference time:")
print(f"  float32: {fp32_time*1000:.2f} ms")
print(f"  int8: {int8_time*1000:.2f} ms")
print(f"  Speedup: {fp32_time / int8_time:.2f}x")

# Check accuracy
with torch.no_grad():
    out_fp32 = model_fp32(x)
    out_int8 = model_int8(x)
    diff = (out_fp32 - out_int8).abs().mean().item()

print(f"\nAccuracy:")
print(f"  Mean absolute difference: {diff:.6f}")
print(f"  Relative error: {diff / out_fp32.abs().mean().item() * 100:.2f}%")

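print("\n💡 Dynamic quantization benefits:")
print("   - Up to ~4x smaller Linear weights (int8 vs float32)")
print("   - Often faster CPU inference (gains depend on hardware and batch size)")
print("   - Usually small accuracy loss (verify on your own data)")
print("   - No retraining needed")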

---

## 5. Model Export: ONNX

### Why Export to ONNX?

**ONNX (Open Neural Network Exchange)**:
- ✅ Platform-independent format
- ✅ Optimized inference engines (ONNX Runtime)
- ✅ Hardware acceleration (TensorRT, OpenVINO)
- ✅ Cross-framework compatibility

**Use cases**:
- Mobile deployment (iOS, Android)
- Edge devices (Raspberry Pi, Jetson)
- Web browsers (ONNX Runtime Web, the successor to ONNX.js)
- Production servers (ONNX Runtime)

### Export Process

1. Provide an example input (shapes are fixed at export time unless you declare dynamic axes)
2. Export using `torch.onnx.export()`
3. Verify exported model
4. Optimize with ONNX tools

In [None]:
import io

print("ONNX Export Demo")
print("="*70)

# Create a simple model
model = SimpleModel()
model.eval()

# Define dummy input (must match expected input shape)
dummy_input = torch.randn(1, 512)

# Export to ONNX
onnx_path = "/tmp/model.onnx"

print("Exporting to ONNX...")
torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    export_params=True,
    opset_version=14,  # ONNX opset version
    do_constant_folding=True,  # Optimization
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},  # Variable batch size
        'output': {0: 'batch_size'}
    }
)

print(f"✅ Model exported to {onnx_path}")

# Verify ONNX model
try:
    import onnx
    import onnxruntime as ort
    
    # Load and check
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    print("\n✅ ONNX model is valid")
    
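    # Run inference with ONNX Runtime (name the execution provider explicitly;
    # some onnxruntime builds require it)
    ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])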
    
    # Prepare input
    ort_inputs = {ort_session.get_inputs()[0].name: dummy_input.numpy()}
    
    # Run inference
    ort_outputs = ort_session.run(None, ort_inputs)
    
    # Compare with PyTorch
    with torch.no_grad():
        torch_output = model(dummy_input).numpy()
    
    diff = np.abs(torch_output - ort_outputs[0]).max()
    print(f"\nPyTorch vs ONNX difference: {diff:.6f}")
    
    # Benchmark
    print("\nBenchmarking...")
    
    # PyTorch
    with torch.no_grad():
        start = time.time()
        for _ in range(1000):
            _ = model(dummy_input)
        torch_time = (time.time() - start) / 1000
    
    # ONNX Runtime
    start = time.time()
    for _ in range(1000):
        _ = ort_session.run(None, ort_inputs)
    onnx_time = (time.time() - start) / 1000
    
    print(f"  PyTorch: {torch_time*1000:.2f} ms")
    print(f"  ONNX Runtime: {onnx_time*1000:.2f} ms")
    print(f"  Speedup: {torch_time / onnx_time:.2f}x")
    
except ImportError:
    print("\n⚠️ onnx and onnxruntime not installed")
    print("   Install with: pip install onnx onnxruntime")

print("\n💡 ONNX benefits:")
print("   - Cross-platform deployment")
print("   - Optimized inference engines")
print("   - Hardware acceleration support")
print("   - Production-ready tooling")

---

## 6. Debugging PyTorch Models

### Common Issues and Solutions

#### 1. NaN/Inf in Training

**Causes**:
- Learning rate too high
- Gradient explosion
- Numerical instability (log(0), division by zero)
- Mixed precision overflow or underflow (float16)

**Solutions**:
- Use gradient clipping
- Lower learning rate
- Check for inf/nan after each operation
- Use `torch.autograd.detect_anomaly()` while debugging (it slows training, so turn it off afterwards)
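
Checking for inf/nan after each operation can be automated with forward hooks. The sketch below is one possible approach: `add_nonfinite_hooks` is a hypothetical helper that records the name of every leaf module whose output contains NaN or Inf, and the fault is injected by hand purely for the demo.

```python
import torch
import torch.nn as nn


def add_nonfinite_hooks(model: nn.Module, report: list):
    """Record the name of every leaf module whose output has NaN or Inf."""
    handles = []
    for name, module in model.named_modules():
        if len(list(module.children())) > 0:
            continue  # skip containers, check leaf modules only

        def hook(mod, inputs, output, name=name):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                report.append(name)

        handles.append(module.register_forward_hook(hook))
    return handles


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
with torch.no_grad():
    model[2].weight.fill_(float("inf"))  # inject a fault for the demo

bad_modules = []
handles = add_nonfinite_hooks(model, bad_modules)
model(torch.ones(1, 4))
for h in handles:
    h.remove()  # always remove hooks once debugging is done

print("Modules with non-finite output:", bad_modules)
```

The first name in the list is the earliest layer in the forward pass that produced a bad value, which is usually where to start looking.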

In [None]:
def debug_nan_inf(model, x, target, optimizer):
    """Helper to debug NaN/Inf issues"""
    
    # Enable anomaly detection
    with torch.autograd.detect_anomaly():
        # Forward pass
        output = model(x)
        loss = F.mse_loss(output, target)
        
        # Check for NaN/Inf
        if torch.isnan(loss) or torch.isinf(loss):
            print("❌ NaN/Inf detected in loss!")
            print(f"   Output stats: min={output.min():.4f}, max={output.max():.4f}")
            print(f"   Target stats: min={target.min():.4f}, max={target.max():.4f}")
            return False
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        
        # Check gradients
        for name, param in model.named_parameters():
            if param.grad is not None:
                if torch.isnan(param.grad).any():
                    print(f"❌ NaN in gradients of {name}")
                    return False
                if torch.isinf(param.grad).any():
                    print(f"❌ Inf in gradients of {name}")
                    return False
                
                grad_norm = param.grad.norm()
                if grad_norm > 100:
                    print(f"⚠️ Large gradient in {name}: {grad_norm:.2f}")
        
        # Apply gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
    
    return True


# Demo debugging
print("Debugging NaN/Inf Issues")
print("="*70)

model = SimpleModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Normal case
x = torch.randn(32, 512)
target = torch.randn(32, 10)

print("Testing normal data...")
success = debug_nan_inf(model, x, target, optimizer)
print(f"Result: {'✅ OK' if success else '❌ Failed'}")

# Pathological case (very large values)
x_bad = torch.randn(32, 512) * 1e10  # Extremely large values
target_bad = torch.randn(32, 10)

print("\nTesting pathological data (very large values)...")
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    success = debug_nan_inf(model, x_bad, target_bad, optimizer)
    print(f"Result: {'✅ OK' if success else '❌ Failed'}")

print("\n💡 Debugging tips:")
print("   - Always use gradient clipping")
print("   - Check data ranges before training")
print("   - Use torch.autograd.detect_anomaly() to find source")
print("   - Log gradient norms to monitor stability")

#### 2. Memory Issues

**Out of Memory (OOM) errors**:

**Causes**:
- Batch size too large
- Model too large
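- Accumulating tensors that still carry the autograd graph (e.g. `total_loss += loss` instead of `total_loss += loss.item()`)
- Memory leaks from references that keep old tensors alive (missing `.detach()`)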

**Solutions**:
- Reduce batch size
- Use gradient checkpointing
- Use mixed precision (AMP)
- Release cached blocks with `torch.cuda.empty_cache()` (this does not free tensors that are still referenced)
- Profile memory usage
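
Gradient checkpointing is the least obvious item in this list, so here is a small sketch using `torch.utils.checkpoint.checkpoint`. The layer sizes are arbitrary, and the example only verifies that gradients are unchanged; the memory saving itself shows up on larger models than this one.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(64, 256), nn.ReLU(), nn.Linear(256, 64))
head = nn.Linear(64, 1)
x = torch.randn(16, 64)

# Baseline: activations inside `block` are stored for the backward pass.
head(block(x)).sum().backward()
baseline_grad = block[0].weight.grad.clone()
block.zero_grad()
head.zero_grad()

# Checkpointed: activations inside `block` are dropped after the forward pass
# and recomputed during backward, trading extra compute for lower memory.
hidden = checkpoint(block, x, use_reentrant=False)
head(hidden).sum().backward()
checkpointed_grad = block[0].weight.grad

print("Gradients match:", torch.allclose(baseline_grad, checkpointed_grad))
```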

In [None]:
def profile_memory_usage(model, batch_sizes=(8, 16, 32, 64)):
    """Profile memory usage for different batch sizes"""
    
    if not torch.cuda.is_available():
        print("⚠️ CUDA not available, skipping memory profiling")
        return
    
    print("Memory Usage Profile")
    print("="*70)
    print(f"{'Batch Size':>12} {'Peak Memory (MB)':>20} {'Success':>10}")
    print("="*70)
    
    model = model.to('cuda')
    
    for batch_size in batch_sizes:
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()
        
        try:
            x = torch.randn(batch_size, 512, device='cuda')
            target = torch.randn(batch_size, 10, device='cuda')
            
            output = model(x)
            loss = F.mse_loss(output, target)
            loss.backward()
            
            peak_mem = torch.cuda.max_memory_allocated() / 1024**2
            print(f"{batch_size:>12} {peak_mem:>19.2f} {'✅':>10}")
            
        except RuntimeError as e:
            if "out of memory" in str(e):
                print(f"{batch_size:>12} {'OOM':>19} {'❌':>10}")
            else:
                raise
    
    print("\n💡 Memory optimization tips:")
    print("   - Use mixed precision (AMP) to reduce activation memory, often by close to 2x")
    print("   - Enable gradient checkpointing for large models")
    print("   - Use gradient accumulation instead of large batches")
    print("   - Clear cache between runs: torch.cuda.empty_cache()")


# Demo memory profiling
profile_memory_usage(SimpleModel())
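
The tip about gradient accumulation can be checked numerically. The sketch below, using an arbitrary small linear model, shows that four micro-batches of 8 produce the same gradient as one batch of 32, provided each micro-batch loss is divided by the number of accumulation steps and the micro-batches are equally sized.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(8, 1)
x, y = torch.randn(32, 8), torch.randn(32, 1)

# Reference: one backward pass over the full batch of 32.
F.mse_loss(model(x), y).backward()
full_batch_grad = model.weight.grad.clone()
model.zero_grad()

# Accumulation: 4 micro-batches of 8. Dividing each loss by the number of
# micro-batches makes the summed gradients equal the full-batch mean.
accumulation_steps = 4
for x_micro, y_micro in zip(x.chunk(accumulation_steps), y.chunk(accumulation_steps)):
    loss = F.mse_loss(model(x_micro), y_micro) / accumulation_steps
    loss.backward()  # gradients add up in .grad across calls
# In a real loop, optimizer.step() and optimizer.zero_grad() would go here.

accumulated_grad = model.weight.grad
print("Gradients match:", torch.allclose(full_batch_grad, accumulated_grad, atol=1e-6))
```

Only one micro-batch of activations is alive at a time, which is where the memory saving comes from.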

---

## 7. Production Monitoring

### What to Monitor

**During Training**:
- Loss curves (train and validation)
- Learning rate
- Gradient norms
- Weight norms
- GPU utilization
- Memory usage

**During Inference**:
- Latency (p50, p95, p99)
- Throughput (requests/sec)
- Error rates
- Model confidence scores
- Input/output distributions
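
For the latency figures listed above, a minimal sketch of computing p50/p95/p99 from recorded request times follows. It uses the nearest-rank definition of a percentile, which is one of several conventions, and `fake_predict` is a hypothetical stand-in for a real model call.

```python
import math
import time


def percentile(sorted_values, pct):
    """Nearest-rank percentile of an already sorted, non-empty list."""
    rank = max(1, math.ceil(pct * len(sorted_values) / 100))
    return sorted_values[rank - 1]


def summarize_latencies(latencies_ms):
    """Return p50, p95, and p99 of a list of latencies."""
    ordered = sorted(latencies_ms)
    return {f"p{p}": percentile(ordered, p) for p in (50, 95, 99)}


def fake_predict(batch):
    """Hypothetical stand-in for a model call."""
    return [v * 2 for v in batch]


latencies_ms = []
for _ in range(200):
    start = time.perf_counter()
    fake_predict([1.0, 2.0, 3.0])
    latencies_ms.append((time.perf_counter() - start) * 1000)

summary = summarize_latencies(latencies_ms)
print({name: f"{value:.4f} ms" for name, value in summary.items()})
```

Tail percentiles matter more than the mean in serving, because a small share of slow requests is what users notice.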

In [None]:
class TrainingMonitor:
    """Monitor training metrics and detect issues"""
    
    def __init__(self, window_size: int = 100):
        self.window_size = window_size
        self.metrics = {
            'loss': [],
            'learning_rate': [],
            'grad_norm': [],
            'weight_norm': []
        }
    
    def log_step(
        self,
        loss: float,
        learning_rate: float,
        model: nn.Module,
        step: int
    ):
        """Log metrics for current step"""
        
        # Compute gradient norm
        grad_norm = 0.0
        for param in model.parameters():
            if param.grad is not None:
                grad_norm += param.grad.norm().item() ** 2
        grad_norm = grad_norm ** 0.5
        
        # Compute weight norm (global L2 norm, consistent with grad_norm above)
        weight_norm = sum(p.norm().item() ** 2 for p in model.parameters()) ** 0.5
        
        # Store metrics
        self.metrics['loss'].append(loss)
        self.metrics['learning_rate'].append(learning_rate)
        self.metrics['grad_norm'].append(grad_norm)
        self.metrics['weight_norm'].append(weight_norm)
        
        # Keep only recent window
        for key in self.metrics:
            self.metrics[key] = self.metrics[key][-self.window_size:]
        
        # Detect anomalies
        self._check_anomalies(step)
    
    def _check_anomalies(self, step: int):
        """Check for training anomalies"""
        
        if len(self.metrics['loss']) < 10:
            return  # Not enough data
        
        recent_loss = self.metrics['loss'][-10:]
        
        # Check for NaN
        if any(np.isnan(recent_loss)):
            print(f"\n❌ Step {step}: NaN detected in loss!")
        
        # Check for loss explosion
        if recent_loss[-1] > 10 * np.median(recent_loss[:-1]):
            print(f"\n⚠️ Step {step}: Loss exploded!")
            print(f"   Current: {recent_loss[-1]:.4f}")
            print(f"   Median: {np.median(recent_loss[:-1]):.4f}")
        
        # Check for vanishing gradients
        recent_grad = self.metrics['grad_norm'][-10:]
        if np.mean(recent_grad) < 1e-6:
            print(f"\n⚠️ Step {step}: Vanishing gradients detected!")
            print(f"   Mean gradient norm: {np.mean(recent_grad):.2e}")
        
        # Check for exploding gradients
        if recent_grad[-1] > 100:
            print(f"\n⚠️ Step {step}: Exploding gradients!")
            print(f"   Gradient norm: {recent_grad[-1]:.2f}")
    
    def plot_metrics(self):
        """Plot training metrics"""
        
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        
        # Loss
        axes[0, 0].plot(self.metrics['loss'], linewidth=2)
        axes[0, 0].set_xlabel('Step', fontsize=12)
        axes[0, 0].set_ylabel('Loss', fontsize=12)
        axes[0, 0].set_title('Training Loss', fontsize=14)
        axes[0, 0].grid(True, alpha=0.3)
        
        # Learning rate
        axes[0, 1].plot(self.metrics['learning_rate'], linewidth=2, color='orange')
        axes[0, 1].set_xlabel('Step', fontsize=12)
        axes[0, 1].set_ylabel('Learning Rate', fontsize=12)
        axes[0, 1].set_title('Learning Rate Schedule', fontsize=14)
        axes[0, 1].grid(True, alpha=0.3)
        
        # Gradient norm
        axes[1, 0].plot(self.metrics['grad_norm'], linewidth=2, color='green')
        axes[1, 0].set_xlabel('Step', fontsize=12)
        axes[1, 0].set_ylabel('Gradient Norm', fontsize=12)
        axes[1, 0].set_title('Gradient Norm', fontsize=14)
        axes[1, 0].grid(True, alpha=0.3)
        
        # Weight norm
        axes[1, 1].plot(self.metrics['weight_norm'], linewidth=2, color='purple')
        axes[1, 1].set_xlabel('Step', fontsize=12)
        axes[1, 1].set_ylabel('Weight Norm', fontsize=12)
        axes[1, 1].set_title('Weight Norm', fontsize=14)
        axes[1, 1].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()


# Demo monitoring
print("Training Monitoring Demo")
print("="*70)

model = SimpleModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
monitor = TrainingMonitor(window_size=100)

print("\nSimulating training...")
for step in range(200):
    # Dummy training step
    x = torch.randn(32, 512)
    target = torch.randn(32, 10)
    
    output = model(x)
    loss = F.mse_loss(output, target)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    # Log metrics
    monitor.log_step(
        loss.item(),
        optimizer.param_groups[0]['lr'],
        model,
        step
    )
    
    if (step + 1) % 50 == 0:
        print(f"Step {step+1}: Loss={loss.item():.4f}")

print("\n✅ Training complete!")
monitor.plot_metrics()

print("\n💡 Monitoring best practices:")
print("   - Track multiple metrics, not just loss")
print("   - Set up anomaly detection for early warning")
print("   - Visualize trends to spot issues")
print("   - Log to persistent storage (TensorBoard, W&B)")
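
As a dependency-free illustration of logging to persistent storage, the sketch below appends one JSON object per line to a file. `JsonlLogger` is a hypothetical minimal class, not a replacement for TensorBoard or W&B, and the metric values are made up for the demo.

```python
import json
import tempfile
import time
from pathlib import Path


class JsonlLogger:
    """Append one JSON object per line (hypothetical minimal logger)."""

    def __init__(self, path):
        self.path = Path(path)
        self.path.parent.mkdir(parents=True, exist_ok=True)

    def log(self, step: int, **metrics: float) -> None:
        record = {"step": step, "time": time.time(), **metrics}
        with self.path.open("a") as f:
            f.write(json.dumps(record) + "\n")


log_path = Path(tempfile.mkdtemp()) / "logs" / "train.jsonl"
logger = JsonlLogger(log_path)
for step in range(3):
    logger.log(step, loss=1.0 / (step + 1), grad_norm=0.5)

# Each line parses on its own, so a crash mid-run loses at most one record.
records = [json.loads(line) for line in log_path.read_text().splitlines()]
print(f"Read back {len(records)} records; last loss = {records[-1]['loss']:.4f}")
```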

---

## Mini Exercises

### Exercise 1: Complete Checkpoint System

Implement a training loop that:
1. Saves checkpoints every 100 steps
2. Keeps only last 3 checkpoints
3. Can resume from checkpoint
4. Saves best model based on validation loss

In [None]:
# Your code here


### Exercise 2: Quantization Comparison

Compare dynamic quantization vs static quantization for a model.
Measure: model size, inference speed, and accuracy.

In [None]:
# Your code here


### Exercise 3: Production Inference API

Create a simple inference API that:
1. Loads a checkpoint
2. Handles batching
3. Logs latency
4. Handles errors gracefully

In [None]:
# Your code here


---

## Key Takeaways

1. **Structure matters**: Organize code for maintainability and scalability
2. **Configuration over hardcoding**: Use config files for all hyperparameters
3. **Checkpoint everything**: Model, optimizer, scheduler, random states
4. **Quantize for deployment**: int8 gives up to 4x smaller weights and often 2-3x faster CPU inference, usually with small quality loss
5. **Export to ONNX**: Cross-platform deployment and optimized inference
6. **Monitor extensively**: Track metrics to catch issues early
7. **Debug systematically**: Use profiling and anomaly detection tools

## Production Checklist

**Before deploying**:
- ✅ All hyperparameters in config files
- ✅ Comprehensive checkpointing
- ✅ Quantized model for efficiency
- ✅ ONNX export tested
- ✅ Error handling for all edge cases
- ✅ Monitoring and logging in place
- ✅ Unit tests for critical components
- ✅ Documentation complete

## Modern LLM Production Practices (2025)

**Model Serving**:
- vLLM: High-throughput LLM serving
- TensorRT-LLM: NVIDIA optimized inference
- Hugging Face TGI: Production-ready serving

**Quantization**:
- GPTQ: 4-bit quantization for LLMs
- AWQ: Activation-aware weight quantization
- GGUF: llama.cpp model format, popular for CPU and local inference

**Monitoring**:
- Weights & Biases: Experiment tracking
- TensorBoard: Visualization
- Prometheus + Grafana: Production metrics

---

## Further Reading

- [Saving and Loading Models for Inference (PyTorch recipe)](https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_models_for_inference.html)
- [Quantization Documentation](https://pytorch.org/docs/stable/quantization.html)
- [ONNX Documentation](https://onnx.ai/)
- [TorchServe](https://pytorch.org/serve/)
- [Model Optimization Guide](https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html)
- [Production ML Best Practices](https://developers.google.com/machine-learning/guides/rules-of-ml)