# Custom Callbacks: SWA and EMA

**File Location:** `notebooks/03_callbacks_and_checkpointing/07_custom_callbacks_swa_ema.ipynb`

## Introduction

This notebook explores advanced training techniques through custom callbacks, focusing on Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA). You will learn to implement weight-averaging techniques that deliver ensemble-like generalization gains while still serving a single model, so there is no additional inference cost.

## Custom Callback Fundamentals

### Building Custom Callbacks

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from torch.utils.data import DataLoader, TensorDataset
import copy
from collections import defaultdict
import numpy as np

class TrainingMonitorCallback(Callback):
    """Demonstrate the callback lifecycle while collecting simple training statistics.

    Statistics are accumulated in ``training_stats`` (a ``defaultdict(list)``):

    * ``'grad_norm'`` - global L2 norm of all parameter gradients, sampled
      once every ``GRAD_LOG_INTERVAL`` training batches.
    * ``'val_loss'`` - the ``'val_loss'`` entry of ``trainer.callback_metrics``
      after each validation epoch (sanity-check runs are skipped).
    """

    # Sample the gradient norm once every this many training batches.
    GRAD_LOG_INTERVAL = 50

    def __init__(self):
        self.training_stats = defaultdict(list)

    def on_train_start(self, trainer, pl_module):
        print("🚀 Training started!")

    def on_train_epoch_start(self, trainer, pl_module):
        print(f"📊 Starting epoch {trainer.current_epoch}")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Record the global gradient L2 norm on every sampled batch."""
        if batch_idx % self.GRAD_LOG_INTERVAL != 0:
            return

        # Global norm = sqrt(sum of squared per-parameter norms); parameters
        # without a gradient are ignored.
        squared_norms = [
            p.grad.detach().norm(2).item() ** 2
            for p in pl_module.parameters()
            if p.grad is not None
        ]
        self.training_stats['grad_norm'].append(sum(squared_norms) ** 0.5)

    def on_validation_epoch_end(self, trainer, pl_module):
        """Record the validation loss, if one was actually logged."""
        # The pre-training sanity check is not a real validation epoch.
        if getattr(trainer, 'sanity_checking', False):
            return

        val_loss = trainer.callback_metrics.get('val_loss')
        if val_loss is None:
            # Do not invent a 0.0 loss when the model logged nothing.
            return

        if isinstance(val_loss, torch.Tensor):
            val_loss = val_loss.item()
        self.training_stats['val_loss'].append(val_loss)

    def on_train_end(self, trainer, pl_module):
        """Print a short summary; tolerate runs without validation or sampled gradients."""
        print("🏁 Training completed!")

        # ``.get`` avoids creating empty series as a side effect of reading.
        val_losses = self.training_stats.get('val_loss')
        grad_norms = self.training_stats.get('grad_norm')

        if val_losses:
            print(f"Final validation loss: {val_losses[-1]:.4f}")
        else:
            print("Final validation loss: n/a (no validation epoch was recorded)")

        if grad_norms:
            print(f"Average gradient norm: {np.mean(grad_norms):.6f}")
        else:
            print("Average gradient norm: n/a (no gradients were sampled)")

# Simple model for callback demonstrations
class CallbackDemoModel(pl.LightningModule):
    """Small MLP classifier used to exercise the callbacks in this notebook.

    Architecture: ``input_size -> hidden_size -> hidden_size // 2 -> num_classes``
    with ReLU activations and dropout (p=0.2) after the first hidden layer.
    Logs ``train_loss``/``train_acc`` and ``val_loss``/``val_acc``.
    """

    def __init__(self, input_size=20, hidden_size=128, num_classes=5, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()

        from torchmetrics import Accuracy

        bottleneck = hidden_size // 2
        layers = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_size, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, num_classes),
        ]
        self.model = nn.Sequential(*layers)

        # Separate metric objects so train and validation states never mix.
        self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        """Return the raw class logits for a batch of inputs."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss and update/log the training accuracy."""
        features, targets = batch
        logits = self(features)
        loss = F.cross_entropy(logits, targets)

        self.train_acc(logits.argmax(dim=1), targets)

        self.log('train_loss', loss, on_step=True, on_epoch=True)
        self.log('train_acc', self.train_acc, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the cross-entropy loss and update/log the validation accuracy."""
        features, targets = batch
        logits = self(features)
        loss = F.cross_entropy(logits, targets)

        self.val_acc(logits.argmax(dim=1), targets)

        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_acc', self.val_acc, on_epoch=True, prog_bar=True)

        return loss

    def configure_optimizers(self):
        """Adam over all parameters with the learning rate stored in ``hparams``."""
        lr = self.hparams.learning_rate
        return torch.optim.Adam(self.parameters(), lr=lr)

# Create synthetic data
def create_demo_data(num_samples=3000, input_size=20, num_classes=5, seed=42):
    """Create a reproducible synthetic classification dataset.

    Features are standard-normal; each sample gets a scalar score from a random
    linear projection, and the score range is cut into ``num_classes``
    equal-width bins that serve as labels.

    Args:
        num_samples: Number of samples to generate (>= 1).
        input_size: Number of features per sample (>= 1).
        num_classes: Number of label bins (>= 1).
        seed: Seed passed to ``torch.manual_seed``. This seeds the *global*
            torch RNG on purpose, so later unseeded torch randomness in the
            notebook stays reproducible as well.

    Returns:
        Tuple ``(X, y)``: ``X`` is a float tensor of shape
        ``(num_samples, input_size)`` and ``y`` an int64 tensor of shape
        ``(num_samples,)`` with values in ``[0, num_classes - 1]``.

    Raises:
        ValueError: If any size argument is smaller than 1.
    """
    if num_samples < 1 or input_size < 1 or num_classes < 1:
        raise ValueError("num_samples, input_size and num_classes must all be >= 1")

    torch.manual_seed(seed)
    X = torch.randn(num_samples, input_size)
    weights = torch.randn(input_size)
    scores = X @ weights

    lowest = scores.min()
    span = scores.max() - lowest

    if num_classes == 1 or span == 0:
        # Degenerate case (one class, or all scores identical): a single bin.
        y = torch.zeros(num_samples, dtype=torch.long)
    else:
        # Use num_classes bins across the range. Dividing by (num_classes - 1)
        # instead would give the top class only the single maximum sample.
        bin_width = span / num_classes
        y = torch.div(scores - lowest, bin_width, rounding_mode='floor').long()
        # The maximum score lands exactly on the upper edge; fold it into the last bin.
        y = torch.clamp(y, 0, num_classes - 1)

    return X, y

# Build the dataset once; every demo below reuses the same loaders.
X, y = create_demo_data()
dataset = TensorDataset(X, y)

# 80/20 train/validation split.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    lengths=[train_size, val_size],
)

# Only the training loader shuffles; validation order stays fixed.
loader_kwargs = dict(batch_size=32)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print("✓ Setup completed")
```

## Stochastic Weight Averaging (SWA)

### SWA Implementation

```python
class SWACallback(Callback):
    """Stochastic Weight Averaging (SWA) callback.

    Once training reaches ``swa_epoch_start`` the callback copies the module,
    moves the first optimizer's learning rate(s) toward ``swa_lrs`` and, at the
    end of every epoch from then on, folds the current parameters into an
    equal-weight running average stored in the copy.

    Only parameters are averaged; buffers of the copy keep the values they had
    when the copy was taken.
    """

    def __init__(self, swa_epoch_start=10, swa_lrs=1e-4, annealing_epochs=5):
        """
        Args:
            swa_epoch_start: Epoch to start SWA
            swa_lrs: SWA learning rate, or a list/tuple with one learning rate
                per optimizer param group
            annealing_epochs: Number of epochs over which the learning rate is
                moved linearly from its current value to the SWA value
                (values <= 1 switch to the SWA rate immediately)
        """
        self.swa_epoch_start = swa_epoch_start
        self.swa_lrs = swa_lrs
        self.annealing_epochs = annealing_epochs
        self.swa_model = None
        self.swa_n = 0

        # Annealing state, filled in when the SWA phase actually begins.
        self._swa_first_epoch = None
        self._base_lrs = []
        self._target_lrs = []

    def on_train_start(self, trainer, pl_module):
        # Start every fit from a clean slate so a reused callback instance
        # does not keep averaging into a model from a previous run.
        self.swa_model = None
        self.swa_n = 0
        self._swa_first_epoch = None
        self._base_lrs = []
        self._target_lrs = []

    def on_train_epoch_start(self, trainer, pl_module):
        epoch = trainer.current_epoch
        if epoch < self.swa_epoch_start:
            return

        optimizer = trainer.optimizers[0]

        # ``>=`` plus the None check (rather than ``==``) also starts SWA when
        # training begins at an epoch that is already past ``swa_epoch_start``.
        if self.swa_model is None:
            print(f"🔄 Starting SWA at epoch {epoch}")
            # NOTE(review): deepcopy follows every reference the module holds;
            # if the attached trainer is reachable from it this copy may be
            # expensive - confirm against the Lightning version in use.
            self.swa_model = copy.deepcopy(pl_module)
            self.swa_n = 0
            self._swa_first_epoch = epoch
            self._base_lrs = [group['lr'] for group in optimizer.param_groups]
            self._target_lrs = self._resolve_target_lrs(len(optimizer.param_groups))

        self._apply_swa_lr(optimizer, epoch - self._swa_first_epoch)

    def _resolve_target_lrs(self, num_groups):
        """Expand ``swa_lrs`` into one target learning rate per param group."""
        if isinstance(self.swa_lrs, (list, tuple)):
            targets = list(self.swa_lrs)
            if len(targets) != num_groups:
                raise ValueError(
                    f"swa_lrs has {len(targets)} entries but the optimizer has "
                    f"{num_groups} param groups"
                )
            return targets
        return [self.swa_lrs] * num_groups

    def _apply_swa_lr(self, optimizer, epochs_into_swa):
        """Linearly move each param group's LR toward its SWA target."""
        span = max(1, int(self.annealing_epochs or 0))
        if epochs_into_swa >= span:
            # Annealing finished; leave the learning rate alone from here on.
            return

        fraction = (epochs_into_swa + 1) / span
        for group, base_lr, target_lr in zip(
            optimizer.param_groups, self._base_lrs, self._target_lrs
        ):
            if fraction >= 1.0:
                # Assign the target exactly instead of relying on float arithmetic.
                group['lr'] = target_lr
            else:
                group['lr'] = base_lr + fraction * (target_lr - base_lr)

    def on_train_epoch_end(self, trainer, pl_module):
        if trainer.current_epoch >= self.swa_epoch_start:
            # Update SWA model
            self._update_swa_model(pl_module)
            self.swa_n += 1

            if trainer.current_epoch % 5 == 0:
                print(f"SWA: Averaged {self.swa_n} models")

    def _update_swa_model(self, model):
        """Fold ``model``'s current parameters into the running average."""
        if self.swa_model is None:
            return

        # Running average: swa_param = (swa_param * n + current_param) / (n + 1)
        alpha = 1.0 / (self.swa_n + 1)

        with torch.no_grad():
            for swa_param, current_param in zip(
                self.swa_model.parameters(), model.parameters()
            ):
                source = current_param.detach().to(swa_param.device)
                if self.swa_n == 0:
                    # First snapshot: the average is just the current weights.
                    swa_param.copy_(source)
                else:
                    # In-place: swa += alpha * (current - swa)
                    swa_param.lerp_(source, alpha)

    def on_train_end(self, trainer, pl_module):
        if self.swa_model is not None:
            print("✅ SWA completed. Use get_swa_model() to access averaged model.")

    def get_swa_model(self):
        """Return the averaged model, or ``None`` if the SWA phase never started."""
        return self.swa_model

# Advanced SWA with batch normalization update
class AdvancedSWACallback(SWACallback):
    """SWA callback that also re-estimates batch-normalization statistics.

    Averaged weights invalidate the BN running statistics copied with the
    model, so they are recomputed from training batches every
    ``update_bn_every`` epochs of the SWA phase and once more when training
    ends, after the final averaging step.
    """

    def __init__(self, *args, update_bn_every=5, bn_max_batches=50, **kwargs):
        """
        Args:
            update_bn_every: Refresh BN statistics on epochs divisible by this
                value (0 or None disables the periodic refresh)
            bn_max_batches: Maximum number of training batches used per
                refresh (None uses the whole loader)
            *args, **kwargs: Forwarded to ``SWACallback``
        """
        super().__init__(*args, **kwargs)
        self.update_bn_every = update_bn_every
        self.bn_max_batches = bn_max_batches

    def on_train_epoch_end(self, trainer, pl_module):
        super().on_train_epoch_end(trainer, pl_module)

        # Update BN statistics every few epochs
        epoch = trainer.current_epoch
        if (
            epoch >= self.swa_epoch_start
            and self.update_bn_every  # guards the modulo against 0 / None
            and epoch % self.update_bn_every == 0
        ):
            self._update_bn_stats(trainer.train_dataloader, pl_module.device)

    def on_train_end(self, trainer, pl_module):
        # The last averaging step may not coincide with a periodic refresh;
        # make sure the returned model has statistics matching its weights.
        self._update_bn_stats(trainer.train_dataloader, pl_module.device)
        super().on_train_end(trainer, pl_module)

    def _update_bn_stats(self, dataloader, device):
        """Recompute BN running statistics of the SWA model from training data."""
        if self.swa_model is None or dataloader is None:
            return

        from itertools import islice
        from torch.optim.swa_utils import update_bn

        print("🔧 Updating BN statistics for SWA model...")

        def tensor_inputs():
            # Only use first few batches for BN update; batches that are not
            # (input, target, ...) sequences with a tensor input are skipped.
            for batch in islice(dataloader, self.bn_max_batches):
                if (
                    isinstance(batch, (list, tuple))
                    and len(batch) >= 2
                    and isinstance(batch[0], torch.Tensor)
                ):
                    yield batch[0]

        # update_bn resets the running statistics, accumulates fresh ones
        # without gradients and restores the model's previous train/eval mode;
        # it returns without a forward pass if the model has no BN layers.
        update_bn(tensor_inputs(), self.swa_model, device=device)

        print("✅ BN statistics updated")

# Test SWA callback
print("=== Testing SWA Callback ===")

swa_callback = AdvancedSWACallback(
    swa_epoch_start=8,
    swa_lrs=5e-4,
    annealing_epochs=3,
    update_bn_every=3
)

model = CallbackDemoModel(learning_rate=1e-3)
monitor_callback = TrainingMonitorCallback()

trainer = pl.Trainer(
    max_epochs=15,
    callbacks=[swa_callback, monitor_callback],
    enable_checkpointing=False,
    logger=False,
    enable_progress_bar=False
)

print("Training with SWA...")
trainer.fit(model, train_loader, val_loader)

# Get SWA model and compare performance
swa_model = swa_callback.get_swa_model()
if swa_model is not None:
    print("\n=== Comparing Regular vs SWA Model ===")

    # The model defines validation_step but no test_step, so evaluate with
    # trainer.validate (trainer.test would be rejected). validate returns one
    # metrics dict per dataloader; read 'val_acc' from that result.
    regular_metrics = trainer.validate(model, val_loader, verbose=False)[0]
    regular_acc = regular_metrics.get('val_acc', 0)

    swa_metrics = trainer.validate(swa_model, val_loader, verbose=False)[0]
    swa_acc = swa_metrics.get('val_acc', 0)

    print(f"Regular model accuracy: {regular_acc:.4f}")
    print(f"SWA model accuracy: {swa_acc:.4f}")
    print(f"Improvement: {swa_acc - regular_acc:.4f}")

print("✓ SWA demo completed")
```

## Exponential Moving Average (EMA)

### EMA Implementation

```python
class EMACallback(Callback):
    """Exponential Moving Average (EMA) callback.

    Keeps a gradient-free copy of the module whose parameters follow
    ``ema = decay * ema + (1 - decay) * current``. Updates run every
    ``update_every`` training batches once ``start_epoch`` is reached, and the
    EMA copy is evaluated on the first validation dataloader before each
    validation epoch (logged as ``ema_val_loss`` / ``ema_val_acc``).

    Only parameters are averaged; buffers keep the values from the copy.
    """

    def __init__(self, decay=0.999, update_every=1, start_epoch=0):
        """
        Args:
            decay: EMA decay rate (higher = more averaging)
            update_every: Update EMA every N steps (must be >= 1)
            start_epoch: Epoch to start EMA

        Raises:
            ValueError: If ``update_every`` is smaller than 1.
        """
        if update_every < 1:
            raise ValueError("update_every must be >= 1")
        self.decay = decay
        self.update_every = update_every
        self.start_epoch = start_epoch
        self.ema_model = None
        self.step_count = 0

    def on_train_start(self, trainer, pl_module):
        # Initialize EMA model
        # NOTE(review): deepcopy follows every reference the module holds; if
        # the attached trainer is reachable from it this copy may be
        # expensive - confirm against the Lightning version in use.
        self.ema_model = copy.deepcopy(pl_module)
        # Disable gradients for EMA model
        for param in self.ema_model.parameters():
            param.requires_grad = False
        self.step_count = 0
        print(f"🔄 EMA initialized with decay={self.decay}")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if trainer.current_epoch < self.start_epoch:
            return

        if self.step_count == 0:
            # The copy was taken at train start. Re-sync it on the first
            # eligible step so the average does not start from (and stay
            # biased toward) weights from before ``start_epoch``.
            self._sync_ema_model(pl_module)

        self.step_count += 1
        if self.step_count % self.update_every == 0:
            self._update_ema_model(pl_module)

    def _sync_ema_model(self, model):
        """Overwrite the EMA parameters with ``model``'s current parameters."""
        if self.ema_model is None:
            return

        with torch.no_grad():
            for ema_param, current_param in zip(
                self.ema_model.parameters(), model.parameters()
            ):
                ema_param.copy_(current_param.detach())

    def _update_ema_model(self, model):
        """Blend ``model``'s current parameters into the EMA parameters in place."""
        if self.ema_model is None:
            return

        with torch.no_grad():
            for ema_param, current_param in zip(
                self.ema_model.parameters(), model.parameters()
            ):
                source = current_param.detach().to(ema_param.device)
                ema_param.mul_(self.decay).add_(source, alpha=1 - self.decay)

    def on_validation_epoch_start(self, trainer, pl_module):
        # Log EMA model performance (not during the pre-training sanity check)
        if getattr(trainer, 'sanity_checking', False):
            return
        if self.ema_model is not None and trainer.current_epoch >= self.start_epoch:
            self._evaluate_ema_model(trainer, pl_module)

    @staticmethod
    def _first_val_dataloader(trainer):
        """Return the first validation dataloader, or ``None`` if there is none.

        ``trainer.val_dataloaders`` is accepted as a single loader, a
        list/tuple of loaders or a dict of loaders.
        """
        loaders = trainer.val_dataloaders
        if loaders is None:
            return None
        if isinstance(loaders, dict):
            return next(iter(loaders.values()), None)
        if isinstance(loaders, (list, tuple)):
            return loaders[0] if loaders else None
        return loaders

    def _evaluate_ema_model(self, trainer, pl_module):
        """Evaluate EMA model and log metrics"""
        dataloader = self._first_val_dataloader(trainer)
        if dataloader is None:
            return

        self.ema_model.eval()
        loss_sum = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in dataloader:
                if not (isinstance(batch, (list, tuple)) and len(batch) >= 2):
                    continue
                # Index instead of unpacking so batches with extra items work.
                x = batch[0].to(pl_module.device)
                y = batch[1].to(pl_module.device)

                logits = self.ema_model(x)
                # Sum per-sample losses so a smaller last batch is weighted correctly.
                loss_sum += F.cross_entropy(logits, y, reduction='sum').item()

                preds = torch.argmax(logits, dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)

        if total == 0:
            # Nothing usable was evaluated; avoid dividing by zero.
            return

        avg_loss = loss_sum / total
        accuracy = correct / total

        # Log EMA metrics
        pl_module.log('ema_val_loss', avg_loss, on_epoch=True)
        pl_module.log('ema_val_acc', accuracy, on_epoch=True)

    def get_ema_model(self):
        """Return the EMA model, or ``None`` before training has started."""
        return self.ema_model

class AdaptiveEMACallback(EMACallback):
    """EMA whose decay moves from ``initial_decay`` to ``min_decay`` over training.

    At the start of every training epoch the decay is recomputed from
    ``progress = current_epoch / max_epochs`` using a ``'linear'`` or
    ``'cosine'`` schedule. Any other ``decay_schedule`` value leaves the decay
    at ``initial_decay`` (clamped to the configured bounds).
    """

    def __init__(self, initial_decay=0.999, min_decay=0.99, decay_schedule='linear', **kwargs):
        """
        Args:
            initial_decay: Decay used at the beginning of training
            min_decay: Decay reached at the end of training
            decay_schedule: ``'linear'`` or ``'cosine'``
            **kwargs: Forwarded to ``EMACallback`` (``update_every``, ``start_epoch``)
        """
        super().__init__(decay=initial_decay, **kwargs)
        self.initial_decay = initial_decay
        self.min_decay = min_decay
        self.decay_schedule = decay_schedule

    def on_train_epoch_start(self, trainer, pl_module):
        # Adapt decay rate based on training progress
        max_epochs = trainer.max_epochs
        if max_epochs is None or max_epochs <= 0:
            # No finite epoch budget (e.g. step-limited training): progress is
            # undefined, so keep the current decay.
            return

        # Bound progress so the cosine branch cannot swing back past the end.
        progress = min(1.0, trainer.current_epoch / max_epochs)
        decay_range = self.initial_decay - self.min_decay

        if self.decay_schedule == 'linear':
            self.decay = self.initial_decay - progress * decay_range
        elif self.decay_schedule == 'cosine':
            self.decay = float(
                self.min_decay + 0.5 * decay_range * (1 + np.cos(np.pi * progress))
            )

        # Clamp decay to the configured bounds, whichever of the two is larger.
        lower, upper = sorted((self.min_decay, self.initial_decay))
        self.decay = max(lower, min(upper, self.decay))

# Test EMA callback
print("\n=== Testing EMA Callback ===")

# Cosine-scheduled decay, updated every second step from epoch 2 onward.
ema_settings = dict(
    initial_decay=0.999,
    min_decay=0.99,
    decay_schedule='cosine',
    update_every=2,
    start_epoch=2,
)
ema_callback = AdaptiveEMACallback(**ema_settings)

model = CallbackDemoModel(learning_rate=1e-3)
monitor_callback = TrainingMonitorCallback()

ema_demo_callbacks = [ema_callback, monitor_callback]
trainer = pl.Trainer(
    max_epochs=12,
    callbacks=ema_demo_callbacks,
    enable_checkpointing=False,
    logger=False,
    enable_progress_bar=False,
)

print("Training with EMA...")
trainer.fit(model, train_loader, val_loader)

# Keep a handle on the averaged model for later inspection.
ema_model = ema_callback.get_ema_model()
print("✓ EMA training completed")
```

## Combined SWA and EMA

### Hybrid Approach

```python
class HybridSWAEMACallback(Callback):
    """Combines SWA and EMA for robust model averaging.

    * EMA: a gradient-free copy updated after every training batch once
      ``warmup_epochs`` is reached.
    * SWA: a second copy, created when ``swa_start_epoch`` is reached, that
      holds the equal-weight average of the end-of-epoch parameters; the first
      optimizer's learning rate is set to ``swa_lr`` at that point.

    After each validation epoch both copies are evaluated on the first
    validation dataloader. Accuracies are kept in ``ema_metrics`` and
    ``swa_metrics``, and the live model's ``val_acc`` in ``original_metrics``.
    Only parameters are averaged, not buffers.
    """

    def __init__(
        self,
        ema_decay=0.999,
        swa_start_epoch=10,
        swa_lr=1e-4,
        warmup_epochs=3
    ):
        """
        Args:
            ema_decay: EMA decay rate (higher = more averaging)
            swa_start_epoch: Epoch at which the SWA phase begins
            swa_lr: Learning rate applied to every param group during SWA
            warmup_epochs: Number of initial epochs without EMA updates
        """
        self.ema_decay = ema_decay
        self.swa_start_epoch = swa_start_epoch
        self.swa_lr = swa_lr
        self.warmup_epochs = warmup_epochs

        # Models
        self.ema_model = None
        self.swa_model = None
        self.swa_n = 0
        self._ema_synced = False  # True once the EMA copy tracks post-warmup weights

        # Tracking
        self.ema_metrics = []
        self.swa_metrics = []
        self.original_metrics = []

    def on_train_start(self, trainer, pl_module):
        # Initialize EMA model
        # NOTE(review): deepcopy follows every reference the module holds; if
        # the attached trainer is reachable from it this copy may be
        # expensive - confirm against the Lightning version in use.
        self.ema_model = copy.deepcopy(pl_module)
        for param in self.ema_model.parameters():
            param.requires_grad = False
        self._ema_synced = False

        # A reused callback must not keep averaging into a previous run's model.
        self.swa_model = None
        self.swa_n = 0
        print("🔄 Hybrid SWA+EMA initialized")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Update EMA every step after warmup
        if trainer.current_epoch < self.warmup_epochs:
            return

        if not self._ema_synced:
            # The copy dates from train start. Re-sync it on the first
            # post-warmup step so the average is not biased toward the
            # pre-warmup weights.
            self._copy_parameters(self.ema_model, pl_module)
            self._ema_synced = True

        self._update_ema(pl_module)

    def on_train_epoch_start(self, trainer, pl_module):
        # Start SWA at specified epoch. ``>=`` plus the None check also covers
        # training that begins at an epoch already past ``swa_start_epoch``.
        if self.swa_model is not None or trainer.current_epoch < self.swa_start_epoch:
            return

        print(f"🔄 Starting SWA phase at epoch {trainer.current_epoch}")
        self.swa_model = copy.deepcopy(pl_module)
        self.swa_n = 0

        # Reduce learning rate for SWA
        for param_group in trainer.optimizers[0].param_groups:
            param_group['lr'] = self.swa_lr

    def on_train_epoch_end(self, trainer, pl_module):
        # Update SWA if in SWA phase
        if trainer.current_epoch >= self.swa_start_epoch and self.swa_model is not None:
            self._update_swa(pl_module)
            self.swa_n += 1

    @staticmethod
    def _copy_parameters(target, source):
        """Overwrite ``target``'s parameters with ``source``'s, in place."""
        if target is None:
            return
        with torch.no_grad():
            for target_param, source_param in zip(target.parameters(), source.parameters()):
                target_param.copy_(source_param.detach())

    def _update_ema(self, model):
        """Blend ``model``'s current parameters into the EMA copy in place."""
        if self.ema_model is None:
            return

        with torch.no_grad():
            for ema_param, current_param in zip(self.ema_model.parameters(), model.parameters()):
                source = current_param.detach().to(ema_param.device)
                ema_param.mul_(self.ema_decay).add_(source, alpha=1 - self.ema_decay)

    def _update_swa(self, model):
        """Fold ``model``'s current parameters into the SWA running average."""
        if self.swa_model is None:
            return

        alpha = 1.0 / (self.swa_n + 1)

        with torch.no_grad():
            for swa_param, current_param in zip(self.swa_model.parameters(), model.parameters()):
                source = current_param.detach().to(swa_param.device)
                if self.swa_n == 0:
                    # First snapshot: the average is just the current weights.
                    swa_param.copy_(source)
                else:
                    # In-place: swa += alpha * (current - swa)
                    swa_param.lerp_(source, alpha)

    @staticmethod
    def _first_val_dataloader(trainer):
        """Return the first validation dataloader, or ``None`` if there is none.

        ``trainer.val_dataloaders`` is accepted as a single loader, a
        list/tuple of loaders or a dict of loaders.
        """
        loaders = trainer.val_dataloaders
        if loaders is None:
            return None
        if isinstance(loaders, dict):
            return next(iter(loaders.values()), None)
        if isinstance(loaders, (list, tuple)):
            return loaders[0] if loaders else None
        return loaders

    def on_validation_epoch_end(self, trainer, pl_module):
        """Evaluate both EMA and SWA models"""
        # The pre-training sanity check is not a real validation epoch.
        if getattr(trainer, 'sanity_checking', False):
            return

        dataloader = self._first_val_dataloader(trainer)
        if dataloader is None:
            return

        val_acc = trainer.callback_metrics.get('val_acc')
        if isinstance(val_acc, torch.Tensor):
            val_acc = val_acc.item()
        if val_acc is not None:
            self.original_metrics.append(val_acc)

        device = pl_module.device

        # Evaluate EMA model (only once it has started tracking post-warmup weights)
        if self.ema_model is not None and self._ema_synced:
            result = self._evaluate_model(self.ema_model, dataloader, device)
            if result is not None:
                ema_loss, ema_acc = result
                pl_module.log('ema_val_loss', ema_loss)
                pl_module.log('ema_val_acc', ema_acc)
                self.ema_metrics.append(ema_acc)

        # Evaluate SWA model. Until the first averaging step the copy is only
        # a snapshot of the weights at SWA start, so it is skipped.
        if self.swa_model is not None and self.swa_n > 0:
            result = self._evaluate_model(self.swa_model, dataloader, device)
            if result is not None:
                swa_loss, swa_acc = result
                pl_module.log('swa_val_loss', swa_loss)
                pl_module.log('swa_val_acc', swa_acc)
                self.swa_metrics.append(swa_acc)

        # Log comparison metrics
        if self.ema_metrics and self.swa_metrics and val_acc is not None:
            pl_module.log('ema_improvement', self.ema_metrics[-1] - val_acc)
            pl_module.log('swa_improvement', self.swa_metrics[-1] - val_acc)

    def _evaluate_model(self, model, dataloader, device):
        """Evaluate a model on validation data.

        Returns:
            ``(avg_loss, accuracy)``, both averaged per sample, or ``None`` if
            the loader produced no usable ``(input, target, ...)`` batch.
        """
        model.eval()
        loss_sum = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in dataloader:
                if not (isinstance(batch, (list, tuple)) and len(batch) >= 2):
                    continue
                # Index instead of unpacking so batches with extra items work.
                x, y = batch[0].to(device), batch[1].to(device)

                logits = model(x)
                # Sum per-sample losses so a smaller last batch is weighted correctly.
                loss_sum += F.cross_entropy(logits, y, reduction='sum').item()

                preds = torch.argmax(logits, dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)

        if total == 0:
            return None

        return loss_sum / total, correct / total

    def get_best_model(self):
        """Return the best performing model (EMA, SWA, or original).

        The *latest* recorded accuracy of each candidate is compared, because
        the stored models are in their final state; an earlier maximum would
        describe weights that no longer exist. Ties favour SWA, then EMA.

        Returns:
            ``(model, kind)`` with ``kind`` in ``{"ema", "swa", "original"}``.
            For ``"original"`` the model is ``None``: the callback does not
            hold the live module, so the caller should use its own reference.
        """
        candidates = []
        if self.swa_metrics and self.swa_model is not None:
            candidates.append((self.swa_metrics[-1], self.swa_model, "swa"))
        if self.ema_metrics and self.ema_model is not None:
            candidates.append((self.ema_metrics[-1], self.ema_model, "ema"))

        if not candidates:
            return None, "original"

        if self.original_metrics:
            candidates.append((self.original_metrics[-1], None, "original"))

        best = candidates[0]
        for candidate in candidates[1:]:
            # Strict comparison keeps the earlier (preferred) candidate on ties.
            if candidate[0] > best[0]:
                best = candidate

        return best[1], best[2]

# Test hybrid approach
print("\n=== Testing Hybrid SWA+EMA ===")

hybrid_callback = HybridSWAEMACallback(
    ema_decay=0.999,
    swa_start_epoch=8,
    swa_lr=5e-4,
    warmup_epochs=2
)

model = CallbackDemoModel(learning_rate=1e-3)

trainer = pl.Trainer(
    max_epochs=15,
    callbacks=[hybrid_callback],
    enable_checkpointing=False,
    logger=False,
    enable_progress_bar=False
)

print("Training with hybrid SWA+EMA...")
trainer.fit(model, train_loader, val_loader)

# Compare all models
print("\n=== Final Model Comparison ===")
best_model, best_type = hybrid_callback.get_best_model()

# Read the training-time metrics before the extra validation pass below, so
# the summary reflects what was measured during fit.
best_ema_acc = max(hybrid_callback.ema_metrics) if hybrid_callback.ema_metrics else 0
best_swa_acc = max(hybrid_callback.swa_metrics) if hybrid_callback.swa_metrics else 0

# Evaluate original model. The model defines validation_step but no test_step,
# so use trainer.validate (trainer.test would be rejected); it returns one
# metrics dict per dataloader.
original_metrics = trainer.validate(model, val_loader, verbose=False)[0]
original_acc = original_metrics.get('val_acc', 0)

print(f"Original model accuracy: {original_acc:.4f}")
print(f"Best EMA accuracy: {best_ema_acc:.4f}")
print(f"Best SWA accuracy: {best_swa_acc:.4f}")
print(f"Best overall model: {best_type}")

print("✓ Hybrid approach demo completed")
```

## Production-Ready Callback Suite

### Complete Custom Callback System

```python
class ProductionCallbackSuite:
    """Factory that assembles the notebook's callbacks from a config mapping."""

    @staticmethod
    def get_callbacks(config):
        """Build the callback list described by ``config``.

        Args:
            config: Mapping with optional keys ``use_ema``, ``ema_decay``,
                ``ema_update_every``, ``ema_start_epoch``, ``use_swa``,
                ``swa_start_epoch``, ``swa_lr``, ``swa_annealing`` and
                ``swa_bn_update_every``. Missing keys use the defaults below.

        Returns:
            List of callbacks: EMA (if enabled), SWA (if enabled), then the
            training monitor, which is always included.
        """
        option = config.get
        averaging = []

        # Model averaging: both techniques are enabled unless switched off.
        if option('use_ema', True):
            ema = EMACallback(
                decay=option('ema_decay', 0.999),
                update_every=option('ema_update_every', 1),
                start_epoch=option('ema_start_epoch', 0),
            )
            averaging.append(ema)

        if option('use_swa', True):
            swa = AdvancedSWACallback(
                swa_epoch_start=option('swa_start_epoch', 10),
                swa_lrs=option('swa_lr', 1e-4),
                annealing_epochs=option('swa_annealing', 5),
                update_bn_every=option('swa_bn_update_every', 3),
            )
            averaging.append(swa)

        # Monitoring goes last.
        return [*averaging, TrainingMonitorCallback()]

# Configuration for production training
production_config = dict(
    # EMA settings
    use_ema=True,
    ema_decay=0.9995,
    ema_update_every=1,
    ema_start_epoch=5,
    # SWA settings
    use_swa=True,
    swa_start_epoch=15,
    swa_lr=1e-4,
    swa_annealing=3,
    swa_bn_update_every=5,
)

# Get production callbacks
production_callbacks = ProductionCallbackSuite.get_callbacks(production_config)

print("=== Production Training with Full Callback Suite ===")

# Larger model than in the earlier demos (hidden_size=256).
model = CallbackDemoModel(learning_rate=1e-3, hidden_size=256)

trainer_options = dict(
    max_epochs=25,
    callbacks=production_callbacks,
    enable_checkpointing=False,
    logger=False,
    enable_progress_bar=True,
)
trainer = pl.Trainer(**trainer_options)

trainer.fit(model, train_loader, val_loader)

print("✓ Production training completed with full callback suite")
```

## Summary

This notebook covered advanced training techniques through custom callbacks:

1. **Custom Callback Basics**: Understanding callback lifecycle and building monitoring callbacks
2. **Stochastic Weight Averaging (SWA)**: Averaging model weights from multiple epochs for better generalization
3. **Exponential Moving Average (EMA)**: Maintaining running averages of model parameters during training
4. **Advanced Implementations**: Adaptive decay rates, batch normalization updates, and hybrid approaches
5. **Production Integration**: Complete callback suites for real-world training pipelines

Key benefits of model averaging:
- **Better Generalization**: Averaged models often perform better than individual checkpoints
- **Stability**: Reduced variance in model predictions
- **No Inference Cost**: Single model at inference time with ensemble-like benefits
- **Robustness**: Less sensitive to learning rate and optimization noise

SWA vs EMA comparison:
- **SWA**: Averages models from different epochs, requires learning rate scheduling
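- **EMA**: Maintains a running average throughout training, updated every step (or every N steps); it needs no learning rate changes, at the cost of one extra pass over the weights per update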
- **Hybrid**: Combines both approaches for maximum robustness

Best practices:
- Start model averaging after initial training phase (warmup)
- Use lower learning rates during SWA phase
- Update batch normalization statistics for SWA models
- Monitor both original and averaged model performance
- Choose appropriate decay rates for EMA (0.995-0.9999 typically)

Production considerations:
- Save both original and averaged models
- Implement proper callback configuration management
- Monitor training dynamics and gradient norms
- Use callbacks for automated hyperparameter adjustment
- Integrate with experiment tracking and logging systems

Next notebook: We'll explore mixed precision training and performance optimization.