# ⚡ Chapter 5: Mixed Precision Training Mastery

## 🧮 Theoretical Foundations of Mixed Precision Computing

### Understanding Numerical Precision in Deep Learning

Mixed precision training is one of the most effective optimizations in deep learning: on GPUs with Tensor Cores it typically delivers a 1.5-2x speedup and can reduce memory usage by up to 50%. This chapter covers the theory and practical implementation of FP16 and BF16 training, along with the emerging FP8 formats.

### Precision Format Analysis

#### **FP32 (Single Precision)**
```
Sign | Exponent (8 bits) | Mantissa (23 bits)
 1   |    8 bits         |    23 bits
```
- **Range**: ±3.4 × 10^38
- **Precision**: ~7 decimal digits
- **Memory**: 4 bytes per parameter

#### **FP16 (Half Precision)**
```
Sign | Exponent (5 bits) | Mantissa (10 bits)
 1   |    5 bits         |    10 bits
```
- **Range**: ±65,504 (limited!)
- **Precision**: ~3 decimal digits
- **Memory**: 2 bytes per parameter
- **Challenge**: Gradient underflow (the smallest positive FP16 value is about 6 × 10^-8; smaller gradients round to zero)

#### **BF16 (Brain Float 16)**
```
Sign | Exponent (8 bits) | Mantissa (7 bits)
 1   |    8 bits         |    7 bits
```
- **Range**: Same as FP32 (±3.4 × 10^38)
- **Precision**: ~2-3 decimal digits
- **Memory**: 2 bytes per parameter
- **Advantage**: Loss scaling is usually unnecessary, because the exponent range matches FP32

#### **FP8 (8-bit Floating Point)**
```
E4M3: Sign | Exponent (4 bits) | Mantissa (3 bits)
E5M2: Sign | Exponent (5 bits) | Mantissa (2 bits)
```
- **Range**: E4M3 ±448, E5M2 ±57,344
- **Memory**: 1 byte per parameter
- **Status**: Experimental, H100+ support
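
The ranges listed above follow from the bit layouts. As a rough sketch, the largest finite value and the machine epsilon can be derived from the exponent and mantissa widths alone. The `format_limits` helper and its `ieee_like` flag are illustrative names, not part of any library. E4M3 is handled as a special case on the assumption that it gives up infinities and reserves only the all-ones mantissa of the top exponent for NaN:

```python
def format_limits(exp_bits: int, man_bits: int, ieee_like: bool = True):
    """Return (max_finite, min_normal, epsilon) for a binary float layout.

    ieee_like=True reserves the all-ones exponent for inf/NaN
    (FP32, FP16, BF16, E5M2). ieee_like=False models E4M3, where the top
    exponent code still holds normal numbers (assumption stated above).
    """
    bias = 2 ** (exp_bits - 1) - 1
    if ieee_like:
        max_exp = (2 ** exp_bits - 2) - bias         # largest normal exponent
        max_mantissa = 2.0 - 2.0 ** -man_bits        # 1.11...1 in binary
    else:
        max_exp = (2 ** exp_bits - 1) - bias         # top exponent code is usable
        max_mantissa = 2.0 - 2.0 ** -(man_bits - 1)  # all-ones mantissa is NaN
    max_finite = max_mantissa * 2.0 ** max_exp
    min_normal = 2.0 ** (1 - bias)
    epsilon = 2.0 ** -man_bits                       # gap between 1.0 and the next value
    return max_finite, min_normal, epsilon


LAYOUTS = {
    "FP32": (8, 23, True),
    "FP16": (5, 10, True),
    "BF16": (8, 7, True),
    "E5M2": (5, 2, True),
    "E4M3": (4, 3, False),
}

for name, (exp_bits, man_bits, ieee_like) in LAYOUTS.items():
    max_finite, min_normal, epsilon = format_limits(exp_bits, man_bits, ieee_like)
    print(f"{name}: max={max_finite:.4g}  min_normal={min_normal:.4g}  eps={epsilon:.4g}")
```

The table makes the trade-off visible: BF16 keeps the FP32 exponent range at the cost of a much larger epsilon, while FP16 does the opposite.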

### Mathematical Framework for Gradient Scaling

**Loss Scaling Formula:**
```
scaled_loss = loss × scale_factor
scaled_gradients = ∇(scaled_loss) = scale_factor × ∇(loss)
true_gradients = scaled_gradients / scale_factor
```
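
The effect of the formula above can be checked without a GPU. The sketch below uses the standard library's half-precision `struct` format (`"e"`) to round values to FP16; the helper name `to_fp16` and the example gradient of `1e-8` are illustrative choices, not values taken from a real training run:

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


true_grad = 1e-8          # smaller than the smallest positive FP16 value
scale_factor = 65536.0    # 2**16

unscaled = to_fp16(true_grad)                # stored directly in FP16
scaled = to_fp16(true_grad * scale_factor)   # stored after loss scaling
recovered = scaled / scale_factor            # unscaled in higher precision

print(f"without scaling: {unscaled}")
print(f"with scaling:    {scaled:.6e}")
print(f"recovered:       {recovered:.6e}")
```

Without scaling the gradient is lost entirely; with scaling it survives the FP16 round trip with a small relative error.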

**Dynamic Scaling Algorithm:**
```
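if has_inf_or_nan(gradients):
    scale_factor = scale_factor × backoff_factor   # backoff_factor < 1, e.g. 0.5
    consecutive_successful_steps = 0
    skip_update()
else:
    apply_gradients(true_gradients)
    consecutive_successful_steps += 1
    if consecutive_successful_steps >= growth_interval:
        scale_factor = scale_factor × growth_factor   # growth_factor > 1, e.g. 2.0
        consecutive_successful_steps = 0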
```
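
The scaling policy can be written as a small state machine. This is a minimal pure-Python sketch, assuming a backoff factor below 1 and a growth factor above 1 (the same defaults passed to `GradScaler` later in this chapter); the class name `DynamicLossScaler` is hypothetical and it is not a substitute for the PyTorch implementation:

```python
from dataclasses import dataclass


@dataclass
class DynamicLossScaler:
    """Tracks the loss scale; does not touch tensors or optimizers."""
    scale: float = 2.0 ** 16
    growth_factor: float = 2.0
    backoff_factor: float = 0.5
    growth_interval: int = 2000
    good_steps: int = 0

    def update(self, found_inf: bool) -> bool:
        """Adjust the scale; return True if the optimizer step should run."""
        if found_inf:
            self.scale *= self.backoff_factor   # shrink after an overflow
            self.good_steps = 0
            return False
        self.good_steps += 1
        if self.good_steps >= self.growth_interval:
            self.scale *= self.growth_factor    # grow after a clean streak
            self.good_steps = 0
        return True


scaler = DynamicLossScaler(growth_interval=3)
history = []
for found_inf in [True, False, False, False]:
    scaler.update(found_inf)
    history.append(scaler.scale)
print(history)
```

One overflow halves the scale, and three clean steps in a row restore it.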

### Tensor Core Optimization

Modern GPUs (V100, A100, H100) include specialized **Tensor Cores** that provide massive acceleration for mixed-precision operations:

- **V100**: 125 TFLOPS (FP16)
- **A100**: 312 TFLOPS (FP16 and BF16, dense); 624 TFLOPS only with 2:4 structured sparsity
- **H100 (SXM)**: 989 TFLOPS (BF16, dense), 1,979 TFLOPS (FP8, dense)

**Tensor Core Requirements:**
1. Matrix dimensions should be multiples of 8 (FP16/BF16) or 16 (FP8) for best efficiency
2. Contiguous memory layout
3. Operations that map to Tensor Core kernels (GEMM, Conv2D, etc.)
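
In practice the first requirement is usually met by padding dimensions such as the vocabulary size or hidden width. A minimal sketch, assuming the multiples listed above (`round_up` and `is_aligned` are illustrative helper names, and 50257 is just an example size):

```python
def round_up(n: int, multiple: int = 8) -> int:
    """Smallest multiple of `multiple` that is >= n."""
    return ((n + multiple - 1) // multiple) * multiple


def is_aligned(shape, multiple: int = 8) -> bool:
    """True if every dimension in `shape` is a multiple of `multiple`."""
    return all(dim % multiple == 0 for dim in shape)


vocab_size = 50257
print(round_up(vocab_size, 8))        # padded size for FP16/BF16
print(round_up(vocab_size, 16))       # padded size for FP8
print(is_aligned((768, 3072)))        # both dimensions are multiples of 8
print(is_aligned((765, 3072)))        # 765 is not a multiple of 8
```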

---

## 🔬 Hands-On Implementation

In [None]:
# Core dependencies for mixed precision training implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Optional, Tuple, Any, Union
from dataclasses import dataclass
import time
import json
import gc
from collections import defaultdict
import warnings
import struct
from enum import Enum

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Configure matplotlib for better visualization
plt.style.use('default')
sns.set_palette("husl")

print("⚡ Mixed Precision Training Mastery Environment Ready!")
print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    device_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {device_props.name}")
    print(f"CUDA Compute Capability: {device_props.major}.{device_props.minor}")
    print(f"Tensor Core Support: {'✅' if device_props.major >= 7 else '❌'}")
    print(f"BF16 Support: {'✅' if device_props.major >= 8 else '❌'}")
    
    # Check for specific precision support
    print(f"\nPrecision Format Support:")
    print(f"  • FP32: ✅ (Always available)")
    print(f"  • FP16: {'✅' if torch.cuda.is_available() else '❌'}")
    print(f"  • BF16: {'✅' if torch.cuda.is_bf16_supported() else '❌'}")
    
    # Memory info
    print(f"\nGPU Memory: {device_props.total_memory / 1e9:.1f} GB")
else:
    print("🔸 CUDA not available - using CPU for demonstrations")

## 🧪 Precision Format Analysis and Comparison

### Understanding Numerical Behavior Across Formats

This section implements comprehensive analysis tools to understand how different precision formats behave with real neural network operations, gradients, and numerical stability considerations.
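
Before running the full analyzer, the core behavior can be previewed with the standard library alone. The sketch below rounds a few values to FP16 via `struct` and emulates BF16 by keeping the top 16 bits of the FP32 encoding with round-to-nearest-even. Both helper names are illustrative, and the BF16 emulation is an approximation that ignores NaN handling:

```python
import math
import struct


def round_fp16(x: float) -> float:
    """Round to FP16; values beyond the FP16 range become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def round_bf16(x: float) -> float:
    """Emulate BF16 by rounding the FP32 bit pattern to its top 16 bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)   # round-to-nearest-even
    bits = ((bits + rounding_bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


for value in [1e-8, 0.1, 1e5]:
    print(f"{value:>8g}  fp16={round_fp16(value)!r:<22}  bf16={round_bf16(value)!r}")
```

The three probes cover the cases the analyzer measures at scale: a tiny value that FP16 flushes to zero, a mid-range value where FP16 is more accurate than BF16, and a large value that overflows FP16 but not BF16.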

In [None]:
class PrecisionFormat(Enum):
    """Enumeration of supported precision formats."""
    FP32 = "float32"
    FP16 = "float16" 
    BF16 = "bfloat16"
    FP64 = "float64"  # For reference

@dataclass
class PrecisionAnalysisConfig:
    """Configuration for precision analysis experiments."""
    formats: Optional[List[PrecisionFormat]] = None
    test_ranges: Optional[List[Tuple[float, float]]] = None
    num_samples: int = 10000
    gradient_scaling: bool = True
    tensor_core_aligned: bool = True
    
    def __post_init__(self):
        if self.formats is None:
            self.formats = [PrecisionFormat.FP32, PrecisionFormat.FP16]
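            # Check is_available() first: is_bf16_supported() can raise on
            # CPU-only hosts in some PyTorch versions
            if torch.cuda.is_available() and torch.cuda.is_bf16_supported():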
                self.formats.append(PrecisionFormat.BF16)
        
        if self.test_ranges is None:
            self.test_ranges = [
                (1e-8, 1e-6),   # Very small gradients
                (1e-6, 1e-4),   # Small gradients  
                (1e-4, 1e-2),   # Normal gradients
                (1e-2, 1.0),    # Large gradients
                (1.0, 100.0),   # Very large values
            ]

class PrecisionAnalyzer:
    """Comprehensive analyzer for mixed precision training behavior."""
    
    def __init__(self, config: PrecisionAnalysisConfig):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        # Results storage
        self.analysis_results = {}
        
    def analyze_numerical_precision(self) -> Dict[str, Any]:
        """Analyze numerical precision characteristics of different formats."""
        
        results = {}
        
        for precision_format in self.config.formats:
            format_results = {
                'representable_range': self._analyze_representable_range(precision_format),
                'precision_loss': self._analyze_precision_loss(precision_format),
                'gradient_underflow': self._analyze_gradient_underflow(precision_format),
                'overflow_behavior': self._analyze_overflow_behavior(precision_format),
                'tensor_core_efficiency': self._analyze_tensor_core_efficiency(precision_format)
            }
            
            results[precision_format.value] = format_results
        
        self.analysis_results['numerical_precision'] = results
        return results
    
    def _analyze_representable_range(self, precision_format: PrecisionFormat) -> Dict[str, float]:
        """Analyze the representable range of a precision format."""
        
        dtype = getattr(torch, precision_format.value)
        
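        # Create test tensor. The probe covers 1e-10 to 1e10 only, so formats with
        # a wider range (FP32, BF16, FP64) report the probe limits, not their true limits.
        test_values = torch.logspace(-10, 10, 1000, device=self.device, dtype=torch.float32)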
        
        # Convert to target precision and back
        converted = test_values.to(dtype).to(torch.float32)
        
        # Find representable range
        is_finite = torch.isfinite(converted)
        is_nonzero = (converted != 0.0)
        
        valid_values = converted[is_finite & is_nonzero]
        
        if len(valid_values) > 0:
            min_representable = valid_values.min().item()
            max_representable = valid_values.max().item()
        else:
            min_representable = float('nan')
            max_representable = float('nan')
        
        # Machine epsilon: gap between 1.0 and the next representable value
        eps = torch.finfo(dtype).eps
        
        return {
            'min_representable': min_representable,
            'max_representable': max_representable,
            'machine_epsilon': float(eps),
            'dynamic_range_db': 20 * np.log10(max_representable / min_representable) if min_representable > 0 else float('inf')
        }
    
    def _analyze_precision_loss(self, precision_format: PrecisionFormat) -> Dict[str, float]:
        """Analyze precision loss through format conversion."""
        
        dtype = getattr(torch, precision_format.value)
        
        # Generate test data across different ranges
        precision_errors = []
        
        for min_val, max_val in self.config.test_ranges:
            # Generate random values in range (sample count comes from the config)
            original = torch.rand(self.config.num_samples, device=self.device) * (max_val - min_val) + min_val
            
            # Convert to target precision and back
            converted = original.to(dtype).to(torch.float32)
            
            # Calculate relative error
            valid_mask = (original != 0) & torch.isfinite(converted)
            if valid_mask.sum() > 0:
                relative_error = torch.abs((converted - original) / original)[valid_mask]
                precision_errors.extend(relative_error.cpu().numpy())
        
        if precision_errors:
            precision_errors = np.array(precision_errors)
            return {
                'mean_relative_error': float(np.mean(precision_errors)),
                'max_relative_error': float(np.max(precision_errors)),
                'std_relative_error': float(np.std(precision_errors)),
                'p95_relative_error': float(np.percentile(precision_errors, 95)),
                'p99_relative_error': float(np.percentile(precision_errors, 99))
            }
        else:
            return {'error': 'No valid precision measurements'}
    
    def _analyze_gradient_underflow(self, precision_format: PrecisionFormat) -> Dict[str, Any]:
        """Analyze gradient underflow behavior."""
        
        dtype = getattr(torch, precision_format.value)
        
        # Simulate typical gradient magnitudes
        gradient_magnitudes = torch.logspace(-10, -1, 1000, device=self.device)
        
        # Convert gradients to target precision
        converted_gradients = gradient_magnitudes.to(dtype)
        
        # Analyze underflow
        underflow_mask = (converted_gradients == 0.0) & (gradient_magnitudes != 0.0)
        finite_mask = torch.isfinite(converted_gradients) & (converted_gradients != 0.0)
        
        underflow_threshold = None
        if underflow_mask.sum() > 0:
            # Find the largest gradient that underflows
            underflow_gradients = gradient_magnitudes[underflow_mask]
            if len(underflow_gradients) > 0:
                underflow_threshold = underflow_gradients.max().item()
        
        return {
            'underflow_rate': float(underflow_mask.sum() / len(gradient_magnitudes)),
            'finite_gradient_rate': float(finite_mask.sum() / len(gradient_magnitudes)),
            'underflow_threshold': underflow_threshold,
            'recommended_loss_scaling': self._calculate_recommended_scaling(precision_format, underflow_threshold)
        }
    
    def _analyze_overflow_behavior(self, precision_format: PrecisionFormat) -> Dict[str, Any]:
        """Analyze overflow behavior with different value ranges."""
        
        dtype = getattr(torch, precision_format.value)
        
        # Test large values
        large_values = torch.logspace(1, 6, 1000, device=self.device)
        converted_values = large_values.to(dtype)
        
        # Analyze overflow
        overflow_mask = ~torch.isfinite(converted_values)
        
        overflow_threshold = None
        if overflow_mask.sum() > 0:
            # Find the smallest value that overflows
            overflow_values = large_values[overflow_mask]
            if len(overflow_values) > 0:
                overflow_threshold = overflow_values.min().item()
        
        return {
            'overflow_rate': float(overflow_mask.sum() / len(large_values)),
            'overflow_threshold': overflow_threshold,
            'max_safe_value': float(large_values[~overflow_mask].max()) if (~overflow_mask).sum() > 0 else None
        }
    
    def _analyze_tensor_core_efficiency(self, precision_format: PrecisionFormat) -> Dict[str, Any]:
        """Analyze Tensor Core efficiency for different precision formats."""
        
        if not torch.cuda.is_available():
            return {'error': 'CUDA not available for Tensor Core analysis'}
        
        dtype = getattr(torch, precision_format.value)
        
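        # Test matrix sizes. Every dimension below is a multiple of 8, so all runs
        # count as Tensor Core aligned and the aligned-vs-unaligned speedup stays None.
        test_sizes = [
            (768, 768),    # Aligned (768 = 96 * 8)
            (768, 3072),   # Aligned, 4x feed-forward expansion
            (1024, 4096),  # Aligned
            (2048, 8192),  # Large aligned
        ]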
        
        performance_results = []
        
        for m, n in test_sizes:
            # Create test matrices
            if self.config.tensor_core_aligned:
                # Ensure alignment for optimal Tensor Core usage
                m = ((m + 7) // 8) * 8
                n = ((n + 7) // 8) * 8
            
            a = torch.randn(m, n, device=self.device, dtype=dtype)
            b = torch.randn(n, m, device=self.device, dtype=dtype)
            
            # Warm up
            for _ in range(10):
                _ = torch.mm(a, b)
            torch.cuda.synchronize()
            
            # Benchmark
            start_time = time.perf_counter()
            for _ in range(100):
                result = torch.mm(a, b)
            torch.cuda.synchronize()
            end_time = time.perf_counter()
            
            avg_time = (end_time - start_time) / 100
            
            # Calculate FLOPS: (m x n) @ (n x m) costs 2 * m * n * m operations
            flops = 2 * m * n * m
            tflops = (flops / avg_time) / 1e12
            
            performance_results.append({
                'matrix_size': f'{m}x{n}',
                'time_ms': avg_time * 1000,
                'tflops': tflops,
                'tensor_core_aligned': (m % 8 == 0) and (n % 8 == 0)
            })
        
        return {
            'performance_results': performance_results,
            'max_tflops': max(r['tflops'] for r in performance_results),
            'aligned_vs_unaligned_speedup': self._calculate_alignment_speedup(performance_results)
        }
    
    def _calculate_recommended_scaling(self, precision_format: PrecisionFormat, underflow_threshold: Optional[float]) -> Optional[float]:
        """Calculate recommended loss scaling factor."""
        
        if precision_format == PrecisionFormat.BF16:
            # BF16 typically doesn't need gradient scaling
            return 1.0
        elif precision_format == PrecisionFormat.FP16 and underflow_threshold is not None:
            # Scale to bring gradients into representable range
            # Target: gradients around 1e-4 to 1e-2
            target_gradient_magnitude = 1e-3
            recommended_scaling = target_gradient_magnitude / underflow_threshold
            
            # Clamp to reasonable range
            return min(max(recommended_scaling, 1.0), 65536.0)
        
        return None
    
    def _calculate_alignment_speedup(self, performance_results: List[Dict]) -> Optional[float]:
        """Calculate speedup from Tensor Core alignment."""
        
        aligned_results = [r for r in performance_results if r['tensor_core_aligned']]
        unaligned_results = [r for r in performance_results if not r['tensor_core_aligned']]
        
        if aligned_results and unaligned_results:
            avg_aligned_tflops = np.mean([r['tflops'] for r in aligned_results])
            avg_unaligned_tflops = np.mean([r['tflops'] for r in unaligned_results])
            
            return avg_aligned_tflops / avg_unaligned_tflops if avg_unaligned_tflops > 0 else None
        
        return None

# Initialize and run precision analysis
print("🧪 Starting Comprehensive Precision Analysis...")

# Configure analysis
config = PrecisionAnalysisConfig(
    num_samples=5000,  # Reduce for faster execution in demo
    gradient_scaling=True,
    tensor_core_aligned=True
)

analyzer = PrecisionAnalyzer(config)

# Run numerical precision analysis
print("📊 Analyzing numerical precision characteristics...")
precision_results = analyzer.analyze_numerical_precision()

print("\n✅ Precision Analysis Complete!")
print("\n📈 Summary Results:")

for format_name, results in precision_results.items():
    print(f"\n{format_name.upper()}:")
    
    # Representable range
    range_info = results['representable_range']
    print(f"  • Range: {range_info['min_representable']:.2e} to {range_info['max_representable']:.2e}")
    print(f"  • Dynamic Range: {range_info['dynamic_range_db']:.1f} dB")
    
    # Precision loss
    if 'mean_relative_error' in results['precision_loss']:
        precision_info = results['precision_loss']
        print(f"  • Mean Precision Error: {precision_info['mean_relative_error']:.2e}")
        print(f"  • P95 Precision Error: {precision_info['p95_relative_error']:.2e}")
    
    # Gradient underflow
    underflow_info = results['gradient_underflow']
    print(f"  • Gradient Underflow Rate: {underflow_info['underflow_rate']:.1%}")
    if underflow_info['recommended_loss_scaling']:
        print(f"  • Recommended Loss Scaling: {underflow_info['recommended_loss_scaling']:.0f}x")
    
    # Tensor Core performance
    if 'max_tflops' in results['tensor_core_efficiency']:
        tc_info = results['tensor_core_efficiency']
        print(f"  • Peak Performance: {tc_info['max_tflops']:.1f} TFLOPS")
        if tc_info['aligned_vs_unaligned_speedup']:
            print(f"  • Alignment Speedup: {tc_info['aligned_vs_unaligned_speedup']:.1f}x")

print(f"\n🎯 Analysis complete for {len(config.formats)} precision formats!")

## 🔥 Advanced Mixed Precision Training Implementation

### Production-Grade Mixed Precision System

This section implements a comprehensive mixed precision training system with automatic loss scaling, gradient clipping, and performance optimization. The implementation demonstrates advanced techniques used in production LLM training.
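
One ordering detail in the trainer is easy to miss: gradients are unscaled before they are clipped (`_optimize_step` calls `unscale_` first). The pure-Python sketch below shows why, using a hypothetical two-element gradient and a loss scale of 65536. The helper names are illustrative and the clipping rule is a simplified global-norm clip, not the PyTorch implementation:

```python
import math


def global_norm(grads):
    """L2 norm over all gradient entries."""
    return math.sqrt(sum(g * g for g in grads))


def clip_by_global_norm(grads, max_norm):
    """Scale gradients down so their global norm is at most max_norm."""
    norm = global_norm(grads)
    coef = min(1.0, max_norm / (norm + 1e-6))
    return [g * coef for g in grads], norm


scale = 65536.0
true_grads = [0.3, -0.4]                          # global norm 0.5, below the threshold
scaled_grads = [g * scale for g in true_grads]    # what backward() sees after loss scaling

# Correct order: unscale first, then clip against the real threshold.
unscaled = [g / scale for g in scaled_grads]
clipped_ok, _ = clip_by_global_norm(unscaled, max_norm=1.0)

# Wrong order: clipping the scaled gradients shrinks them far too much.
clipped_scaled, _ = clip_by_global_norm(scaled_grads, max_norm=1.0)
clipped_bad = [g / scale for g in clipped_scaled]

print(f"unscale then clip: norm = {global_norm(clipped_ok):.4f}")
print(f"clip then unscale: norm = {global_norm(clipped_bad):.2e}")
```

In the wrong order the threshold is compared against gradients that are 65536 times too large, so a gradient that needed no clipping ends up shrunk by several orders of magnitude.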

In [None]:
class AdvancedMixedPrecisionTrainer:
    """Production-grade mixed precision trainer with advanced optimizations."""
    
    def __init__(self, 
                 model: nn.Module,
                 optimizer: torch.optim.Optimizer,
                 precision_format: str = 'fp16',
                 loss_scaling: str = 'dynamic',
                 max_grad_norm: float = 1.0,
                 device: Optional[torch.device] = None):
        
        self.model = model
        self.optimizer = optimizer
        self.precision_format = precision_format
        self.max_grad_norm = max_grad_norm
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        # Move model to device
        self.model = self.model.to(self.device)
        
        # Initialize gradient scaler for FP16
        self.scaler = None
        if precision_format == 'fp16':
            if loss_scaling == 'dynamic':
                self.scaler = GradScaler(
                    init_scale=2**16,  # Initial scale
                    growth_factor=2.0,  # Scale growth factor
                    backoff_factor=0.5,  # Scale reduction factor
                    growth_interval=2000,  # Steps before scale increase
                    enabled=True
                )
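            elif isinstance(loss_scaling, (int, float)):
                # GradScaler requires growth_factor > 1 and backoff_factor < 1, so a
                # fixed scale is approximated: it never grows but still halves on overflow.
                self.scaler = GradScaler(init_scale=float(loss_scaling), growth_factor=2.0,
                                       backoff_factor=0.5, growth_interval=2**31 - 1)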
        
        # Training statistics
        self.training_stats = {
            'total_steps': 0,
            'successful_steps': 0,
            'overflow_steps': 0,
            'scale_updates': [],
            'gradient_norms': [],
            'loss_values': [],
            'step_times': [],
            'memory_usage': []
        }
        
    def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """Perform a single training step with mixed precision."""
        
        step_start_time = time.time()
        
        # Clear gradients
        self.optimizer.zero_grad()
        
        # Mixed precision forward pass
        if self.precision_format in ['fp16', 'bf16']:
            with autocast(dtype=torch.float16 if self.precision_format == 'fp16' else torch.bfloat16):
                outputs = self.model(**batch)
                loss = outputs['loss'] if isinstance(outputs, dict) else outputs
        else:
            outputs = self.model(**batch)
            loss = outputs['loss'] if isinstance(outputs, dict) else outputs
        
        # Backward pass with scaling
        if self.scaler is not None:
            scaled_loss = self.scaler.scale(loss)
            scaled_loss.backward()
        else:
            loss.backward()
        
        # Gradient processing and optimization
        step_successful = self._optimize_step()
        
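        # Update statistics (synchronize so the timing includes queued GPU work)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        step_time = time.time() - step_start_time
        self._update_statistics(loss.item(), step_time, step_successful)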
        
        return {
            'loss': loss.item(),
            'step_time': step_time,
            'successful': step_successful,
            'current_scale': self.scaler.get_scale() if self.scaler else 1.0,
            'gradient_norm': self.training_stats['gradient_norms'][-1] if self.training_stats['gradient_norms'] else 0.0
        }
    
    def _optimize_step(self) -> bool:
        """Perform optimization step with gradient scaling and clipping."""
        
        if self.scaler is not None:
            # Unscale gradients for gradient clipping
            self.scaler.unscale_(self.optimizer)
            
            # Check for inf/nan gradients
            if self._has_inf_or_nan_gradients():
                self.scaler.update()  # Skip step and update scale
                return False
        
        # Calculate gradient norm before clipping
        grad_norm = self._calculate_gradient_norm()
        self.training_stats['gradient_norms'].append(grad_norm)
        
        # Gradient clipping
        if self.max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
        
        # Optimizer step
        if self.scaler is not None:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()
        
        return True
    
    def _has_inf_or_nan_gradients(self) -> bool:
        """Check if gradients contain inf or nan values."""
        
        for param in self.model.parameters():
            if param.grad is not None:
                if torch.isnan(param.grad).any() or torch.isinf(param.grad).any():
                    return True
        return False
    
    def _calculate_gradient_norm(self) -> float:
        """Calculate the norm of gradients."""
        
        total_norm = 0.0
        for param in self.model.parameters():
            if param.grad is not None:
                param_norm = param.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        
        return total_norm ** 0.5
    
    def _update_statistics(self, loss_value: float, step_time: float, successful: bool):
        """Update training statistics."""
        
        self.training_stats['total_steps'] += 1
        self.training_stats['loss_values'].append(loss_value)
        self.training_stats['step_times'].append(step_time)
        
        if successful:
            self.training_stats['successful_steps'] += 1
        else:
            self.training_stats['overflow_steps'] += 1
        
        # Track scale updates
        if self.scaler is not None:
            current_scale = self.scaler.get_scale()
            if not self.training_stats['scale_updates'] or self.training_stats['scale_updates'][-1] != current_scale:
                self.training_stats['scale_updates'].append(current_scale)
        
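        # Track memory usage (allocated at the end of the step, which can be lower
        # than the within-step peak reported by torch.cuda.max_memory_allocated())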
        if torch.cuda.is_available():
            memory_mb = torch.cuda.memory_allocated() / 1e6
            self.training_stats['memory_usage'].append(memory_mb)
    
    def get_training_statistics(self) -> Dict[str, Any]:
        """Get comprehensive training statistics."""
        
        stats = self.training_stats.copy()
        
        if stats['total_steps'] > 0:
            stats['success_rate'] = stats['successful_steps'] / stats['total_steps']
            stats['overflow_rate'] = stats['overflow_steps'] / stats['total_steps']
            stats['avg_step_time'] = np.mean(stats['step_times']) if stats['step_times'] else 0
            stats['avg_loss'] = np.mean(stats['loss_values']) if stats['loss_values'] else 0
            stats['avg_gradient_norm'] = np.mean(stats['gradient_norms']) if stats['gradient_norms'] else 0
            stats['max_memory_mb'] = max(stats['memory_usage']) if stats['memory_usage'] else 0
        
        return stats

# Create a simple transformer model for testing
class SimpleTransformerLayer(nn.Module):
    """Simple transformer layer for mixed precision testing."""
    
    def __init__(self, d_model: int = 1024, nhead: int = 16, dim_feedforward: int = 4096):
        super().__init__()
        
        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.feedforward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.GELU(),
            nn.Linear(dim_feedforward, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Self-attention with residual connection
        attn_out, _ = self.self_attn(x, x, x)
        x = self.norm1(x + attn_out)
        
        # Feedforward with residual connection
        ff_out = self.feedforward(x)
        x = self.norm2(x + ff_out)
        
        return x

class TestTransformerModel(nn.Module):
    """Test transformer model for mixed precision experiments."""
    
    def __init__(self, vocab_size: int = 32000, d_model: int = 1024, num_layers: int = 6):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([SimpleTransformerLayer(d_model) for _ in range(num_layers)])
        self.output_projection = nn.Linear(d_model, vocab_size)
        self.criterion = nn.CrossEntropyLoss()
        
    def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        # Embedding
        x = self.embedding(input_ids)
        
        # Transformer layers
        for layer in self.layers:
            x = layer(x)
        
        # Output projection
        logits = self.output_projection(x)
        
        output = {'logits': logits}
        
        if labels is not None:
            # Flatten for loss calculation
            loss = self.criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
            output['loss'] = loss
        
        return output

# Test mixed precision training
print("🔥 Testing Advanced Mixed Precision Training...")

# Create test model and data
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TestTransformerModel(vocab_size=1000, d_model=512, num_layers=4)

print(f"📊 Model Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")

# Test different precision formats
precision_formats = ['fp32', 'fp16']
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    precision_formats.append('bf16')

training_results = {}

for precision_format in precision_formats:
    print(f"\n🧪 Testing {precision_format.upper()} Training...")
    
    # Create fresh model and optimizer for each test
    test_model = TestTransformerModel(vocab_size=1000, d_model=512, num_layers=4)
    optimizer = torch.optim.AdamW(test_model.parameters(), lr=1e-4, weight_decay=0.01)
    
    # Create trainer
    trainer = AdvancedMixedPrecisionTrainer(
        model=test_model,
        optimizer=optimizer,
        precision_format=precision_format,
        loss_scaling='dynamic' if precision_format == 'fp16' else None,
        max_grad_norm=1.0,
        device=device
    )
    
    # Generate test batch
    batch_size = 8
    seq_length = 128
    
    # Run training steps
    num_steps = 50
    for step in range(num_steps):
        # Generate random batch
        input_ids = torch.randint(0, 1000, (batch_size, seq_length), device=device)
        labels = torch.randint(0, 1000, (batch_size, seq_length), device=device)
        
        batch = {'input_ids': input_ids, 'labels': labels}
        
        # Training step
        step_result = trainer.train_step(batch)
        
        # Print progress every 10 steps
        if (step + 1) % 10 == 0:
            print(f"  Step {step + 1}: Loss = {step_result['loss']:.4f}, "
                  f"Time = {step_result['step_time']*1000:.1f}ms, "
                  f"Scale = {step_result['current_scale']:.0f}")
    
    # Get final statistics
    final_stats = trainer.get_training_statistics()
    training_results[precision_format] = final_stats
    
    print(f"\n📈 {precision_format.upper()} Results:")
    print(f"  • Success Rate: {final_stats['success_rate']:.1%}")
    print(f"  • Average Step Time: {final_stats['avg_step_time']*1000:.1f}ms")
    print(f"  • Average Loss: {final_stats['avg_loss']:.4f}")
    print(f"  • Average Gradient Norm: {final_stats['avg_gradient_norm']:.4f}")
    if torch.cuda.is_available():
        print(f"  • Peak Memory: {final_stats['max_memory_mb']:.1f}MB")
    
    # Clear memory
    del test_model, optimizer, trainer
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print("\n✅ Mixed Precision Training Tests Complete!")

## 📊 Comprehensive Performance Analysis and Visualization

### Mixed Precision Training Comparison

This section creates comprehensive visualizations comparing the performance, memory usage, and numerical behavior of different precision formats in realistic training scenarios.

In [None]:
def create_mixed_precision_visualizations(precision_results: Dict, training_results: Dict):
    """Create comprehensive visualizations for mixed precision analysis."""
    
    # Create figure with subplots
    fig, axes = plt.subplots(2, 3, figsize=(20, 12))
    fig.suptitle('⚡ Mixed Precision Training Comprehensive Analysis', fontsize=16, y=0.98)
    
    # 1. Numerical Precision Comparison
    ax1 = axes[0, 0]
    
    formats = list(precision_results.keys())
    precision_errors = []
    dynamic_ranges = []
    
    for fmt in formats:
        if 'mean_relative_error' in precision_results[fmt]['precision_loss']:
            precision_errors.append(precision_results[fmt]['precision_loss']['mean_relative_error'])
        else:
            precision_errors.append(0)
        
        dynamic_ranges.append(precision_results[fmt]['representable_range']['dynamic_range_db'])
    
    x = np.arange(len(formats))
    ax1.bar(x, precision_errors, alpha=0.7, color=['blue', 'orange', 'green'][:len(formats)])
    ax1.set_xlabel('Precision Format')
    ax1.set_ylabel('Mean Relative Error')
    ax1.set_title('Numerical Precision Comparison')
    ax1.set_xticks(x)
    ax1.set_xticklabels([fmt.upper() for fmt in formats])
    ax1.set_yscale('log')
    ax1.grid(True, alpha=0.3)
    
    # 2. Training Performance Comparison
    ax2 = axes[0, 1]
    
    training_formats = list(training_results.keys())
    step_times = [training_results[fmt]['avg_step_time'] * 1000 for fmt in training_formats]  # Convert to ms
    memory_usage = [training_results[fmt].get('max_memory_mb', 0) for fmt in training_formats]
    
    x = np.arange(len(training_formats))
    width = 0.35
    
    bars1 = ax2.bar(x - width/2, step_times, width, label='Step Time (ms)', alpha=0.7)
    ax2_twin = ax2.twinx()
    bars2 = ax2_twin.bar(x + width/2, memory_usage, width, label='Memory (MB)', alpha=0.7, color='orange')
    
    ax2.set_xlabel('Precision Format')
    ax2.set_ylabel('Step Time (ms)', color='blue')
    ax2_twin.set_ylabel('Memory Usage (MB)', color='orange')
    ax2.set_title('Training Performance Comparison')
    ax2.set_xticks(x)
    ax2.set_xticklabels([fmt.upper() for fmt in training_formats])
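    ax2.grid(True, alpha=0.3)
    # One legend covering both y-axes; the bar labels are otherwise never displayed
    ax2.legend(handles=[bars1, bars2], loc='upper left')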
    
    # Add value labels on bars
    for bar, value in zip(bars1, step_times):
        ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.1,
                f'{value:.1f}', ha='center', va='bottom')
    
    # 3. Gradient Underflow Analysis
    ax3 = axes[0, 2]
    
    underflow_rates = []
    recommended_scalings = []
    
    for fmt in formats:
        underflow_info = precision_results[fmt]['gradient_underflow']
        underflow_rates.append(underflow_info['underflow_rate'])
        scaling = underflow_info.get('recommended_loss_scaling')
        recommended_scalings.append(scaling if scaling else 1.0)
    
    x = np.arange(len(formats))
    ax3.bar(x, underflow_rates, alpha=0.7, color=['red', 'orange', 'green'][:len(formats)])
    ax3.set_xlabel('Precision Format')
    ax3.set_ylabel('Gradient Underflow Rate')
    ax3.set_title('Gradient Underflow Analysis')
    ax3.set_xticks(x)
    ax3.set_xticklabels([fmt.upper() for fmt in formats])
    ax3.grid(True, alpha=0.3)
    
    # Add recommended scaling as text
    for i, (rate, scaling) in enumerate(zip(underflow_rates, recommended_scalings)):
        ax3.text(i, rate + 0.01, f'Scale: {scaling:.0f}x', ha='center', va='bottom', fontsize=8)
    
    # 4. Training Success Rate and Convergence
    ax4 = axes[1, 0]
    
    success_rates = [training_results[fmt]['success_rate'] for fmt in training_formats]
    avg_losses = [training_results[fmt]['avg_loss'] for fmt in training_formats]
    
    x = np.arange(len(training_formats))
    bars = ax4.bar(x, success_rates, alpha=0.7, color=['blue', 'orange', 'green'][:len(training_formats)])
    ax4.set_xlabel('Precision Format')
    ax4.set_ylabel('Training Success Rate')
    ax4.set_title('Training Stability Comparison')
    ax4.set_xticks(x)
    ax4.set_xticklabels([fmt.upper() for fmt in training_formats])
    ax4.set_ylim(0, 1.1)
    ax4.grid(True, alpha=0.3)
    
    # Add percentage labels
    for bar, rate in zip(bars, success_rates):
        ax4.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
                f'{rate:.1%}', ha='center', va='bottom')
    
    # 5. Memory Efficiency Analysis
    ax5 = axes[1, 1]
    
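    # Calculate memory savings relative to FP32
    # (accept either key spelling for the FP32 baseline; otherwise use the largest footprint)
    if training_formats and training_formats[0] in ('fp32', 'float32'):
        fp32_memory = memory_usage[0]
    else:
        fp32_memory = max(memory_usage, default=0)
    # Peak memory defaults to 0 when it was not recorded (e.g. CPU-only runs),
    # so guard against dividing by zero
    if fp32_memory > 0:
        memory_savings = [(fp32_memory - mem) / fp32_memory * 100 for mem in memory_usage]
    else:
        memory_savings = [0.0 for _ in memory_usage]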
    
    bars = ax5.bar(range(len(training_formats)), memory_savings, 
                   alpha=0.7, color=['blue', 'orange', 'green'][:len(training_formats)])
    ax5.set_xlabel('Precision Format')
    ax5.set_ylabel('Memory Savings (%)')
    ax5.set_title('Memory Efficiency Gains')
    ax5.set_xticks(range(len(training_formats)))
    ax5.set_xticklabels([fmt.upper() for fmt in training_formats])
    ax5.grid(True, alpha=0.3)
    
    # Add percentage labels
    for i, saving in enumerate(memory_savings):
        ax5.text(i, saving + 1, f'{saving:.1f}%', ha='center', va='bottom')
    
    # 6. Tensor Core Performance Analysis
    ax6 = axes[1, 2]
    
    tflops_data = []
    format_labels = []
    
    for fmt in formats:
        if 'max_tflops' in precision_results[fmt]['tensor_core_efficiency']:
            tflops = precision_results[fmt]['tensor_core_efficiency']['max_tflops']
            tflops_data.append(tflops)
            format_labels.append(fmt.upper())
    
    if tflops_data:
        bars = ax6.bar(range(len(format_labels)), tflops_data, alpha=0.7,
                      color=['blue', 'orange', 'green'][:len(format_labels)])
        ax6.set_xlabel('Precision Format')
        ax6.set_ylabel('Peak Performance (TFLOPS)')
        ax6.set_title('Tensor Core Performance')
        ax6.set_xticks(range(len(format_labels)))
        ax6.set_xticklabels(format_labels)
        ax6.grid(True, alpha=0.3)
        
        # Add TFLOPS labels
        for i, tflops in enumerate(tflops_data):
            ax6.text(i, tflops + max(tflops_data) * 0.02, f'{tflops:.1f}', 
                    ha='center', va='bottom')
    else:
        ax6.text(0.5, 0.5, 'Tensor Core\nAnalysis\nNot Available', 
                ha='center', va='center', transform=ax6.transAxes, fontsize=12)
        ax6.set_title('Tensor Core Performance')
    
    plt.tight_layout()
    plt.show()
    
    return fig

def generate_mixed_precision_recommendations(precision_results: Dict, training_results: Dict) -> Dict[str, Any]:
    """Generate comprehensive recommendations for mixed precision training."""
    
    recommendations = {
        'format_recommendations': {},
        'optimization_strategies': {},
        'hardware_considerations': {},
        'production_guidelines': {}
    }
    
    # Analyze each precision format
    for fmt in precision_results.keys():
        precision_info = precision_results[fmt]
        training_info = training_results.get(fmt, {})
        
        format_rec = {
            'memory_savings': 'Up to 50%' if fmt in ['float16', 'bfloat16'] else '0%',
            'speed_improvement': 'Up to 2x' if fmt in ['float16', 'bfloat16'] else 'Baseline',
            'numerical_stability': 'High' if fmt == 'bfloat16' else 'Medium' if fmt == 'float16' else 'Highest',
            'gradient_scaling_required': fmt == 'float16',
            'tensor_core_support': fmt in ['float16', 'bfloat16'],
            'recommended_use_cases': []
        }
        
        # Use case recommendations
        if fmt == 'float32':
            format_rec['recommended_use_cases'] = [
                'Debugging and development',
                'Small models where memory is not a constraint',
                'Research requiring highest numerical precision'
            ]
        elif fmt == 'float16':
            format_rec['recommended_use_cases'] = [
                'Large model training with V100/A100 GPUs',
                'Memory-constrained environments',
                'Production training with careful gradient scaling'
            ]
        elif fmt == 'bfloat16':
            format_rec['recommended_use_cases'] = [
                'Large model training with A100/H100 GPUs',
                'Production training requiring stability',
                'Training without gradient scaling complexity'
            ]
        
        # Performance analysis
        if training_info:
            format_rec['training_stability'] = training_info.get('success_rate', 0)
            format_rec['average_step_time_ms'] = training_info.get('avg_step_time', 0) * 1000
        
        recommendations['format_recommendations'][fmt] = format_rec
    
    # Optimization strategies
    recommendations['optimization_strategies'] = {
        'gradient_scaling': {
            'fp16': {
                'strategy': 'Dynamic scaling',
                'initial_scale': 2**16,
                'growth_factor': 2.0,
                'backoff_factor': 0.5,
                'growth_interval': 2000
            },
            'bf16': {
                'strategy': 'No scaling required',
                'reason': 'Same exponent range as FP32'
            }
        },
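        'gradient_clipping': {
            'recommended_norm': 1.0,
            'adaptive_clipping': True,
            # With loss scaling, unscale gradients first so the norm threshold
            # applies to the true gradient magnitudes
            'unscale_before_clipping': True
        },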
        'tensor_core_optimization': {
            'matrix_alignment': 'Multiple of 8 for FP16/BF16',
            'memory_layout': 'Contiguous tensors required',
            'operation_coverage': 'GEMM (linear and attention layers), convolutions'
        }
    }
    
    # Hardware considerations
    recommendations['hardware_considerations'] = {
        'V100': {
            'fp16_support': 'Excellent',
            'bf16_support': 'Not available',
            'tensor_cores': 'First generation',
            'recommended_format': 'FP16'
        },
        'A100': {
            'fp16_support': 'Excellent',
            'bf16_support': 'Excellent',
            'tensor_cores': 'Third generation',
            'recommended_format': 'BF16'
        },
        'H100': {
            'fp16_support': 'Excellent',
            'bf16_support': 'Excellent',
            'fp8_support': 'Experimental',
            'tensor_cores': 'Fourth generation',
            'recommended_format': 'BF16 (FP8 for research)'
        },
        'T4': {
            'fp16_support': 'Good',
            'bf16_support': 'Not available',
            'tensor_cores': 'Second generation',
            'recommended_format': 'FP16 (with caution)'
        }
    }
    
    # Production guidelines
    recommendations['production_guidelines'] = {
        'monitoring': [
            'Track gradient overflow rates',
            'Monitor loss scaling updates',
            'Watch for training instability',
            'Profile memory usage patterns'
        ],
        'best_practices': [
            'Start with FP32 for debugging',
            'Use BF16 for A100+ hardware when available',
            'Implement gradual precision reduction',
            'Test extensively before production deployment'
        ],
        'fallback_strategies': [
            'Automatic fallback to FP32 on overflow',
            'Layer-specific precision selection',
            'Dynamic precision adjustment during training'
        ]
    }
    
    return recommendations

# Create comprehensive visualizations
print("📊 Creating Comprehensive Mixed Precision Visualizations...")
fig = create_mixed_precision_visualizations(precision_results, training_results)

# Generate production recommendations
print("\n🎯 Generating Production Recommendations...")
recommendations = generate_mixed_precision_recommendations(precision_results, training_results)

print("\n" + "=" * 60)
print("⚡ MIXED PRECISION TRAINING RECOMMENDATIONS")
print("=" * 60)

# Format-specific recommendations
print("\n🎯 PRECISION FORMAT RECOMMENDATIONS:")
print("-" * 40)

for fmt, rec in recommendations['format_recommendations'].items():
    print(f"\n{fmt.upper()}:")
    print(f"  • Memory Savings: {rec['memory_savings']}")
    print(f"  • Speed Improvement: {rec['speed_improvement']}")
    print(f"  • Numerical Stability: {rec['numerical_stability']}")
    print(f"  • Gradient Scaling Required: {rec['gradient_scaling_required']}")
    print(f"  • Tensor Core Support: {rec['tensor_core_support']}")
    print("  • Use Cases:")
    for use_case in rec['recommended_use_cases']:
        print(f"    - {use_case}")

# Hardware recommendations
print("\n🖥️ HARDWARE-SPECIFIC RECOMMENDATIONS:")
print("-" * 40)

for gpu, specs in recommendations['hardware_considerations'].items():
    print(f"\n{gpu}:")
    print(f"  • Recommended Format: {specs['recommended_format']}")
    print(f"  • FP16 Support: {specs['fp16_support']}")
    print(f"  • BF16 Support: {specs['bf16_support']}")
    print(f"  • Tensor Cores: {specs['tensor_cores']}")

# Production guidelines
print("\n🏭 PRODUCTION DEPLOYMENT GUIDELINES:")
print("-" * 40)

print("\nMonitoring:")
for guideline in recommendations['production_guidelines']['monitoring']:
    print(f"  • {guideline}")

print("\nBest Practices:")
for practice in recommendations['production_guidelines']['best_practices']:
    print(f"  • {practice}")

print("\n" + "=" * 60)
print("✅ Chapter 5: Mixed Precision Training Mastery Complete!")

print("\n📚 Key Learning Outcomes:")
print("  • Deep understanding of FP16, BF16, and emerging FP8 formats")
print("  • Advanced gradient scaling and numerical stability techniques")
print("  • Tensor Core optimization strategies")
print("  • Production-grade mixed precision training implementation")
print("  • Comprehensive performance analysis and monitoring")

print("\n🎓 Next Chapter: Advanced Inference Optimization")
print("Continue to Chapter 6 for a deep dive into vLLM, continuous batching, and inference optimization!")
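As a closing sketch, the FP16 settings listed under `gradient_scaling` in the recommendations above (initial scale 2**16, growth factor 2.0, backoff factor 0.5, growth interval 2000) can be traced through a small pure-Python model of dynamic loss scaling. `DynamicLossScaler` and `clip_by_norm` are illustrative names introduced here, not library classes. The sketch assumes the scale is multiplied by the backoff factor on overflow, which is how PyTorch's `GradScaler` applies its `backoff_factor`, and it unscales gradients before clipping so the norm threshold applies to true gradient magnitudes.

```python
import math
from typing import List


class DynamicLossScaler:
    """Pure-Python sketch of dynamic loss scaling (illustrative, not a library class)."""

    def __init__(self, init_scale: float = 2.0 ** 16, growth_factor: float = 2.0,
                 backoff_factor: float = 0.5, growth_interval: int = 2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0  # consecutive steps without inf/NaN gradients

    def unscale(self, scaled_grads: List[float]) -> List[float]:
        """Recover true gradients: scaled_gradients / scale_factor."""
        return [g / self.scale for g in scaled_grads]

    def update(self, found_inf: bool) -> bool:
        """Adjust the scale after a step; return True if the optimizer step should run."""
        if found_inf:
            # Overflow: shrink the scale and skip this update
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps >= self.growth_interval:
            # Enough clean steps in a row: try a larger scale
            self.scale *= self.growth_factor
            self._good_steps = 0
        return True


def clip_by_norm(grads: List[float], max_norm: float = 1.0) -> List[float]:
    """Rescale gradients so their L2 norm does not exceed max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)
    return [g * max_norm / norm for g in grads]


scaler = DynamicLossScaler()
scaled_grads = [g * scaler.scale for g in [3.0, 4.0]]  # what backward() on the scaled loss yields
true_grads = scaler.unscale(scaled_grads)              # unscale first...
clipped = clip_by_norm(true_grads, max_norm=1.0)       # ...then clip the true gradients

scaler.update(found_inf=True)
print(scaler.scale)
# 32768.0
```

In a real training loop the same four settings would be passed to the framework's scaler rather than reimplemented; BF16 runs skip this machinery because BF16 shares FP32's exponent range.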