# Lab 2.1.4: Mixed Precision Training - SOLUTIONS

This notebook contains complete solutions for the Mixed Precision Training exercises.

---

In [None]:
import torch
import torch.nn as nn
from torch.amp import autocast, GradScaler  # Updated for PyTorch 2.0+
from torch.utils.data import DataLoader
import time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

---

## Exercise Solution: Adaptive Precision Training with Fallback

This implementation starts in FP16 and permanently falls back to FP32 if too many optimizer steps are skipped because of gradient overflow (inf/NaN) within a sliding window of recent steps.

In [None]:
class AdaptivePrecisionTrainer:
    """
    Trainer that automatically falls back from FP16 to FP32 if unstable.
    
    This is useful when:
    - Training with very small learning rates
    - Using architectures prone to gradient issues
    - Working with unusual loss functions
    
    Args:
        model: Neural network model
        criterion: Loss function
        max_skip_ratio: Maximum ratio of skipped steps before fallback (default: 0.1)
        window_size: Number of steps to consider for skip ratio (default: 100)
    """
    
    def __init__(
        self, 
        model: nn.Module, 
        criterion: nn.Module,
        max_skip_ratio: float = 0.1,
        window_size: int = 100
    ):
        self.model = model
        self.criterion = criterion
        self.max_skip_ratio = max_skip_ratio
        self.window_size = window_size
        
        # Start with FP16
        self.use_fp16 = True
        self.scaler = GradScaler()
        
        # Tracking
        self.skip_history = []
        self.fallback_triggered = False
        self.total_steps = 0
        self.total_skips = 0
    
    def _check_and_update_precision(self):
        """
        Check if we should fall back to FP32.
        """
        if not self.use_fp16 or self.fallback_triggered:
            return
        
        # Only check after enough history
        if len(self.skip_history) < self.window_size:
            return
        
        # Calculate skip ratio over recent window
        recent_skips = sum(self.skip_history[-self.window_size:])
        skip_ratio = recent_skips / self.window_size
        
        if skip_ratio > self.max_skip_ratio:
            print(f"\n⚠️  Fallback triggered! Skip ratio: {skip_ratio:.2%} > {self.max_skip_ratio:.2%}")
            print(f"    Switching from FP16 to FP32...")
            self.use_fp16 = False
            self.fallback_triggered = True
            # Disable scaler
            self.scaler = GradScaler(enabled=False)
    
    def train_step(self, inputs, targets, optimizer):
        """
        Perform a single training step with adaptive precision.
        
        Args:
            inputs: Input batch
            targets: Target labels
            optimizer: Optimizer instance
            
        Returns:
            Tuple of (loss_value, was_skipped)
        """
        self.model.train()
        optimizer.zero_grad()
        
        # Record scale before step
        old_scale = self.scaler.get_scale() if self.use_fp16 else 1.0
        
        # Forward pass with PyTorch 2.0+ API
        if self.use_fp16:
            with autocast(device_type='cuda', dtype=torch.float16):
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)
            
            # Backward with scaling
            self.scaler.scale(loss).backward()
            self.scaler.step(optimizer)
            self.scaler.update()
        else:
            # FP32 fallback
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)
            loss.backward()
            optimizer.step()
        
        # Check if step was skipped (scale decreased)
        new_scale = self.scaler.get_scale() if self.use_fp16 else 1.0
        was_skipped = new_scale < old_scale
        
        # Update tracking
        self.skip_history.append(1 if was_skipped else 0)
        self.total_steps += 1
        if was_skipped:
            self.total_skips += 1
        
        # Check if we should fall back
        self._check_and_update_precision()
        
        return loss.item(), was_skipped
    
    def train_epoch(self, dataloader, optimizer):
        """
        Train for one epoch with adaptive precision.
        
        Returns:
            Dict with training metrics
        """
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        epoch_skips = 0
        
        start_time = time.time()
        
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            
            loss, skipped = self.train_step(inputs, targets, optimizer)
            
            running_loss += loss
            if skipped:
                epoch_skips += 1
            
            # Get predictions for accuracy
            with torch.no_grad():
                if self.use_fp16:
                    with autocast(device_type='cuda', dtype=torch.float16):
                        outputs = self.model(inputs)
                else:
                    outputs = self.model(inputs)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        
        return {
            'loss': running_loss / len(dataloader),
            'accuracy': 100. * correct / total,
            'time': time.time() - start_time,
            'skipped_steps': epoch_skips,
            'precision': 'fp16' if self.use_fp16 else 'fp32',
            'fallback_triggered': self.fallback_triggered,
        }
    
    def get_stats(self):
        """Get training statistics."""
        return {
            'total_steps': self.total_steps,
            'total_skips': self.total_skips,
            'skip_ratio': self.total_skips / max(1, self.total_steps),
            'current_precision': 'fp16' if self.use_fp16 else 'fp32',
            'fallback_triggered': self.fallback_triggered,
        }

In [None]:
# Demo: run AdaptivePrecisionTrainer on CIFAR-10 with a small MLP
import torchvision
import torchvision.transforms as transforms

# MLP over flattened 3x32x32 images: 3072 -> 512 -> 10 classes
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3 * 32 * 32, 512),
    nn.ReLU(),
    nn.Linear(512, 10),
).to(device)

# Tensor conversion, then per-channel normalization with mean 0.5 / std 0.5
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)

# Loss, optimizer and the adaptive-precision wrapper (skip window of 50 steps)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
trainer = AdaptivePrecisionTrainer(model, criterion, max_skip_ratio=0.1, window_size=50)

print("=== Adaptive Precision Training ===")
for epoch in range(3):
    metrics = trainer.train_epoch(trainloader, optimizer)
    print(f"Epoch {epoch+1}: " + ", ".join((
        f"Loss={metrics['loss']:.4f}",
        f"Acc={metrics['accuracy']:.2f}%",
        f"Precision={metrics['precision']}",
        f"Skipped={metrics['skipped_steps']}",
    )))

# Summarize the whole run
stats = trainer.get_stats()
print("\nFinal Stats:")
print("\n".join((
    f"  Total steps: {stats['total_steps']}",
    f"  Total skips: {stats['total_skips']}",
    f"  Skip ratio: {stats['skip_ratio']:.2%}",
    f"  Final precision: {stats['current_precision']}",
    f"  Fallback triggered: {stats['fallback_triggered']}",
)))

---

## Alternative: Dynamic Loss Scaling

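Instead of abandoning FP16 entirely, we can make the loss-scaling policy smarter: back off more aggressively after repeated overflows and grow the scale more cautiously.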

In [None]:
class SmartGradScaler:
    """
    Enhanced gradient scaler with better overflow handling.

    A hand-rolled, educational alternative to ``torch.amp.GradScaler``.
    Intended call sequence per training step::

        scaler.scale_loss(loss).backward()
        found_inf = scaler.unscale_grads(optimizer)
        scaler.step(optimizer, found_inf)

    Features:
    - More aggressive backoff on repeated overflows: the k-th consecutive
      overflow multiplies the scale by
      ``backoff_factor ** min(k, max_backoff_exponent)``
    - Gentler growth (factor 1.5 instead of PyTorch's 2.0) to avoid oscillation
    - Statistics tracking

    Args:
        init_scale: Starting loss scale.
        growth_factor: Multiplier applied after ``growth_interval``
            consecutive overflow-free steps.
        backoff_factor: Base multiplier applied on overflow.
        growth_interval: Overflow-free steps required before growing.
        max_scale: Upper clamp for the scale.
        min_scale: Lower clamp for the scale.
        max_backoff_exponent: Cap on the exponent applied to
            ``backoff_factor`` for consecutive overflows (default 3).
    """

    def __init__(
        self,
        init_scale: float = 65536.0,
        growth_factor: float = 1.5,  # Slower growth than PyTorch's default 2.0
        backoff_factor: float = 0.25,  # More aggressive backoff than PyTorch's 0.5
        growth_interval: int = 1000,  # NOTE: shorter than PyTorch's default of 2000
        max_scale: float = 2**24,
        min_scale: float = 1.0,
        max_backoff_exponent: int = 3,
    ):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.max_scale = max_scale
        self.min_scale = min_scale
        self.max_backoff_exponent = max_backoff_exponent

        self.steps_since_growth = 0  # overflow-free steps since last growth/backoff
        self.consecutive_overflows = 0
        self.total_overflows = 0
        self.total_steps = 0

    def scale_loss(self, loss):
        """Return ``loss`` multiplied by the current scale, for the backward pass."""
        return loss * self.scale

    def unscale_grads(self, optimizer):
        """
        Divide every gradient in ``optimizer`` by the current scale in place.

        Any gradient that is non-finite (inf or NaN) after unscaling is
        zeroed, so a later optimizer step cannot use it.

        Returns:
            True if any non-finite gradient was found, else False.
        """
        found_inf = False

        for group in optimizer.param_groups:
            for param in group['params']:
                grad = param.grad
                if grad is None:
                    continue
                grad.div_(self.scale)
                # One reduction covers both inf and NaN
                if not torch.isfinite(grad).all():
                    found_inf = True
                    grad.zero_()

        return found_inf

    def step(self, optimizer, found_inf: bool):
        """
        Update optimizer and adjust scale.

        Args:
            optimizer: The optimizer (not touched when ``found_inf`` is True)
            found_inf: Whether overflow was detected

        Returns:
            True if step was taken, False if skipped
        """
        self.total_steps += 1

        if found_inf:
            self.total_overflows += 1
            self.consecutive_overflows += 1

            # Back off harder the longer the overflow streak, up to the cap
            exponent = min(self.consecutive_overflows, self.max_backoff_exponent)
            self.scale = max(self.scale * self.backoff_factor ** exponent, self.min_scale)
            self.steps_since_growth = 0
            return False

        self.consecutive_overflows = 0
        optimizer.step()

        # Grow only after a full interval of overflow-free steps
        self.steps_since_growth += 1
        if self.steps_since_growth >= self.growth_interval:
            self.scale = min(self.scale * self.growth_factor, self.max_scale)
            self.steps_since_growth = 0

        return True

    def get_stats(self):
        """Return the current scale and overflow counters."""
        return {
            'current_scale': self.scale,
            'total_steps': self.total_steps,
            'total_overflows': self.total_overflows,
            'overflow_ratio': self.total_overflows / max(1, self.total_steps),
        }

# Example usage: exercise the scaler's bookkeeping without real training
print("SmartGradScaler example:")
scaler = SmartGradScaler()
print(f"Initial scale: {scaler.scale}")

# A single dummy optimizer, created once and reused for every step. Its only
# parameter has no gradient, so optimizer.step() does nothing to it.
dummy_optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.01)

# Simulate 10 steps, with an overflow on step index 3
for i in range(10):
    found_inf = (i == 3)
    scaler.step(dummy_optimizer, found_inf)

stats = scaler.get_stats()
print(f"After 10 steps: scale={stats['current_scale']:.2f}, overflows={stats['total_overflows']}")

---

## Best Practices Summary

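1. **Start with BF16** - Same exponent range as FP32 (with less mantissa precision), so no loss scaling is needed; requires hardware support (e.g., NVIDIA Ampere or newer)
2. **Use FP16 with GradScaler** - Loss scaling is required for FP16 stability; without it, small gradients underflow to zero
3. **Monitor skip ratio** - If more than ~5% of steps are skipped, consider a lower learning rate, gradient clipping, or switching to BF16
4. **Gradient clipping helps** - With a GradScaler, call `scaler.unscale_(optimizer)` before clipping so the threshold applies to the true (unscaled) gradients
5. **Test accuracy** - Verify that final accuracy matches an FP32 baseline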

In [None]:
# Cleanup: free unreachable Python objects first, then return the now-unused
# cached CUDA blocks to the driver. Objects still bound to notebook variables
# (e.g. model, trainer, trainloader) stay alive; `del` them first to free them.
import gc

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
print("Cleanup complete!")