# Mixed Precision Training and AMP

**File Location:** `notebooks/04_performance_and_scaling/08_mixed_precision_amp.ipynb`

## Introduction

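This notebook covers mixed precision training with Automatic Mixed Precision (AMP) in PyTorch Lightning. It shows how to:

- speed up training and cut activation memory by running most operations in FP16 or BF16 while keeping FP32 master weights,
- debug numerical issues such as overflow and underflow,
- check accuracy against an FP32 baseline on GPUs with Tensor Cores.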
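Mixed precision depends on the range and resolution of the 16-bit formats. The snippet below is a small illustration, not part of the original notebook. It uses `torch.finfo` and dtype casts to show two things:

- FP16 overflows above 65504 and flushes very small values to zero. Tiny gradients are one example of such values.
- BF16 keeps FP32's exponent range but has fewer mantissa bits.

The example values are assumptions chosen for illustration.

```python
import torch

# Numeric limits of the three floating-point formats used in this notebook
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15s} max={info.max:.3e} tiny={info.tiny:.3e} eps={info.eps:.3e}")

small = torch.tensor(1e-8)     # e.g. a tiny gradient value
large = torch.tensor(70000.0)  # e.g. an activation or loss spike

fp16_small = small.to(torch.float16).item()   # below FP16's smallest subnormal -> 0.0
fp16_large = large.to(torch.float16).item()   # above FP16's max (65504) -> inf
bf16_small = small.to(torch.bfloat16).item()  # BF16 shares FP32's exponent range
bf16_large = large.to(torch.bfloat16).item()  # finite, but rounded to BF16's coarse grid

print(f"1e-8  -> fp16: {fp16_small}, bf16: {bf16_small:.3e}")
print(f"70000 -> fp16: {fp16_large}, bf16: {bf16_large}")
```

The FP16 underflow is why FP16 AMP needs loss scaling. BF16 avoids the range problem and trades away precision instead.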

## Mixed Precision Fundamentals

### Understanding Mixed Precision

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.callbacks import Callback
import time
import numpy as np

class MixedPrecisionDemo(pl.LightningModule):
    """Model to demonstrate mixed precision training"""
    
    def __init__(self, input_size=512, hidden_size=1024, num_classes=10, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()
        
        # Larger model to see memory benefits
        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
            
            nn.Linear(hidden_size, hidden_size),
            nn.BatchNorm1d(hidden_size), 
            nn.ReLU(),
            nn.Dropout(0.2),
            
            nn.Linear(hidden_size, hidden_size // 2),
            nn.BatchNorm1d(hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            
            nn.Linear(hidden_size // 2, num_classes)
        )
        
        from torchmetrics import Accuracy
        self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)
        
        # Track precision info
        self.precision_info = []
        
    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        
        preds = torch.argmax(logits, dim=1)
        self.train_acc(preds, y)
        
        # Log precision information
        if batch_idx % 50 == 0:
            self._log_precision_info(x, logits, loss)
        
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        self.log('train_acc', self.train_acc, on_epoch=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        
        preds = torch.argmax(logits, dim=1)
        self.val_acc(preds, y)
        
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_acc', self.val_acc, on_epoch=True, prog_bar=True)
        
        return loss
    
    def _log_precision_info(self, inputs, outputs, loss):
        """Log information about tensor precisions"""
        info = {
            'step': self.global_step,
            'input_dtype': str(inputs.dtype),
            'output_dtype': str(outputs.dtype),
            'loss_dtype': str(loss.dtype),
            'input_range': [inputs.min().item(), inputs.max().item()],
            'output_range': [outputs.min().item(), outputs.max().item()],
            'loss_value': loss.item()
        }
        self.precision_info.append(info)
        
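        # Log to Lightning (bools cast to float so they can be aggregated).
        # Under '16-mixed' autocast, expect:
        #   - inputs to stay FP32 (autocast does not cast the batch)
        #   - logits to be FP16 (output of Linear)
        #   - the loss to be FP32 (cross_entropy is autocast to FP32)
        self.log('input_is_fp16', float(inputs.dtype == torch.float16))
        self.log('output_is_fp16', float(outputs.dtype == torch.float16))
        self.log('loss_is_fp16', float(loss.dtype == torch.float16))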
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
        return [optimizer], [scheduler]

# Create larger synthetic dataset to see memory benefits
def create_large_dataset(num_samples=10000, input_size=512, num_classes=10):
    torch.manual_seed(42)
    X = torch.randn(num_samples, input_size)
    weights = torch.randn(input_size)
    logits = X @ weights
    y = torch.div(logits - logits.min(), (logits.max() - logits.min()) / (num_classes - 1), rounding_mode='floor').long()
    y = torch.clamp(y, 0, num_classes - 1)
    return X, y

X, y = create_large_dataset()
dataset = TensorDataset(X, y)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)

print("✓ Large model and dataset created")
print(f"Model parameters: {sum(p.numel() for p in MixedPrecisionDemo().parameters()):,}")
```
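The dtype checks in `_log_precision_info` depend on how autocast assigns dtypes:

- It does not cast the input batch or the stored parameters.
- It casts inside eligible ops such as `linear`, which emit lower-precision outputs.
- Ops it treats as precision-sensitive run in FP32. According to PyTorch's autocast op reference, on CUDA this includes `log_softmax` and `nll_loss`, which `cross_entropy` uses.

The sketch below is a minimal illustration. It assumes CPU autocast with BF16 as a stand-in for CUDA autocast with FP16, so it runs without a GPU. The names `probe_layer`, `batch_x` and `probe_out` are illustrative.

```python
import torch
import torch.nn as nn

probe_layer = nn.Linear(16, 4)
batch_x = torch.randn(8, 16)

# CPU autocast with bfloat16 stands in for CUDA autocast with float16 here
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    probe_out = probe_layer(batch_x)

print("input :", batch_x.dtype)             # the batch itself is not cast
print("weight:", probe_layer.weight.dtype)  # parameters stay FP32 ("master weights")
print("output:", probe_out.dtype)           # linear ran in the lower-precision dtype
```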

### Training without Mixed Precision (Baseline)

```python
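# Memory monitoring callback
class MemoryMonitorCallback(Callback):
    """Monitor GPU memory usage during training (values in GB)"""
    
    def __init__(self):
        self.memory_stats = []
    
    def on_fit_start(self, trainer, pl_module):
        # Reset CUDA's peak counters so each run reports its own peak
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
    
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if torch.cuda.is_available():
            self.memory_stats.append({
                'allocated': torch.cuda.memory_allocated() / 1024**3,
                'reserved': torch.cuda.memory_reserved() / 1024**3,
            })
    
    def get_peak_memory(self):
        """Return (peak_allocated_gb, peak_reserved_gb); (0.0, 0.0) without CUDA"""
        if not torch.cuda.is_available():
            return 0.0, 0.0
        return (torch.cuda.max_memory_allocated() / 1024**3,
                torch.cuda.max_memory_reserved() / 1024**3)

# Speed benchmarking callback
class SpeedBenchmarkCallback(Callback):
    """Record wall-clock time per training epoch and per training batch"""
    
    def __init__(self):
        self.epoch_times = []
        self.batch_times = []
    
    def on_train_epoch_start(self, trainer, pl_module):
        self.epoch_start_time = time.time()
    
    def on_train_epoch_end(self, trainer, pl_module):
        self.epoch_times.append(time.time() - self.epoch_start_time)
    
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # CUDA kernels run asynchronously
        self.batch_start_time = time.time()
    
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # wait for this batch's kernels before timing
        batch_time = time.time() - self.batch_start_time
        self.batch_times.append(batch_time)
        if batch_idx % 50 == 0:
            pl_module.log('batch_time_ms', batch_time * 1000)
    
    def get_average_times(self):
        avg_epoch = np.mean(self.epoch_times) if self.epoch_times else 0
        avg_batch = np.mean(self.batch_times) if self.batch_times else 0
        return avg_epoch, avg_batch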

print("=== Baseline Training (FP32) ===")

# Train without mixed precision
model_fp32 = MixedPrecisionDemo()
memory_callback = MemoryMonitorCallback()
speed_callback = SpeedBenchmarkCallback()

trainer_fp32 = pl.Trainer(
    max_epochs=5,
    precision="32-true",  # Explicit FP32
    callbacks=[memory_callback, speed_callback],
    enable_checkpointing=False,
    logger=False,
    enable_progress_bar=False
)

print("Training with FP32 precision...")
start_time = time.time()
trainer_fp32.fit(model_fp32, train_loader, val_loader)
fp32_total_time = time.time() - start_time

# Get baseline metrics
fp32_peak_alloc, fp32_peak_reserved = memory_callback.get_peak_memory()
fp32_avg_epoch, fp32_avg_batch = speed_callback.get_average_times()
fp32_final_acc = trainer_fp32.callback_metrics.get('val_acc', 0)

print(f"✓ FP32 Training completed in {fp32_total_time:.2f}s")
print(f"Peak memory: {fp32_peak_alloc:.2f}GB allocated, {fp32_peak_reserved:.2f}GB reserved")
print(f"Final accuracy: {fp32_final_acc:.4f}")
```

### Training with Mixed Precision (AMP)

```python
print("\n=== Mixed Precision Training (AMP) ===")

# Train with automatic mixed precision
model_amp = MixedPrecisionDemo()
memory_callback_amp = MemoryMonitorCallback()
speed_callback_amp = SpeedBenchmarkCallback()

trainer_amp = pl.Trainer(
    max_epochs=5,
    precision="16-mixed",  # Enable AMP
    callbacks=[memory_callback_amp, speed_callback_amp],
    enable_checkpointing=False,
    logger=False,
    enable_progress_bar=False
)

print("Training with mixed precision (AMP)...")
start_time = time.time()
trainer_amp.fit(model_amp, train_loader, val_loader)
amp_total_time = time.time() - start_time

# Get AMP metrics
amp_peak_alloc, amp_peak_reserved = memory_callback_amp.get_peak_memory()
amp_avg_epoch, amp_avg_batch = speed_callback_amp.get_average_times()
amp_final_acc = trainer_amp.callback_metrics.get('val_acc', 0)

print(f"✓ AMP Training completed in {amp_total_time:.2f}s")
print(f"Peak memory: {amp_peak_alloc:.2f}GB allocated, {amp_peak_reserved:.2f}GB reserved")
print(f"Final accuracy: {amp_final_acc:.4f}")

# Compare results
print(f"\n=== Performance Comparison ===")
print(f"Training time - FP32: {fp32_total_time:.2f}s, AMP: {amp_total_time:.2f}s")
print(f"Speedup: {fp32_total_time/amp_total_time:.2f}x")
if fp32_peak_alloc > 0:
    print(f"Memory reduction: {(fp32_peak_alloc - amp_peak_alloc)/fp32_peak_alloc*100:.1f}%")
else:
    print("Memory reduction: n/a (CUDA memory statistics unavailable)")
print(f"Accuracy - FP32: {fp32_final_acc:.4f}, AMP: {amp_final_acc:.4f}")
print(f"Accuracy difference: {amp_final_acc - fp32_final_acc:.4f}")
```
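`precision="16-mixed"` hides the per-step mechanics. A rough plain-PyTorch equivalent of one AMP training step is: autocast the forward pass, scale the loss before `backward`, unscale before clipping, then let the scaler step and update. This is a sketch, not Lightning's actual implementation; the exact order inside Lightning's precision plugin may differ in detail.

Assumptions:

- It uses FP16 with `GradScaler` on CUDA.
- It falls back to BF16 with a disabled (pass-through) scaler on CPU, so it runs anywhere.
- `amp_train_step` and the toy data are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumption: FP16 + GradScaler on CUDA, BF16 without scaling on CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16
# A disabled scaler passes values through unchanged
# (newer PyTorch also offers torch.amp.GradScaler("cuda"))
amp_scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

def amp_train_step(model, optimizer, scaler, x, y, max_norm=1.0):
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type=device, dtype=amp_dtype):
        loss = F.cross_entropy(model(x), y)
    scaler.scale(loss).backward()   # multiply the loss by the current scale, then backprop
    scaler.unscale_(optimizer)      # divide grads by the scale so clipping sees real values
    nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    scaler.step(optimizer)          # skipped if the grads contain Inf/NaN
    scaler.update()                 # grow or shrink the scale for the next step
    return loss.item()

torch.manual_seed(0)
demo_model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 2)).to(device)
demo_optimizer = torch.optim.AdamW(demo_model.parameters(), lr=1e-2)
xb = torch.randn(256, 32, device=device)
yb = (xb[:, 0] > 0).long()          # simple, learnable labels
losses = [amp_train_step(demo_model, demo_optimizer, amp_scaler, xb, yb) for _ in range(50)]
print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

Calling `unscale_` before clipping matters. Without it, `max_norm` would be compared against gradients that are still multiplied by the loss scale.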

## Advanced Mixed Precision Techniques

### Custom AMP Configuration

```python
class AdvancedAMPModel(pl.LightningModule):
    """Model with advanced AMP configuration"""
    
    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        
        # Model with different layer types
        self.feature_extractor = nn.Sequential(
            nn.Linear(512, 1024),
            nn.LayerNorm(1024),  # autocast runs layer_norm in FP32
            nn.GELU(),
            nn.Dropout(0.1),
        )
        
        self.transformer_block = nn.TransformerEncoderLayer(
            d_model=1024,
            nhead=8,
            dim_feedforward=2048,
            dropout=0.1,
            batch_first=True
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 10)
        )
        
        from torchmetrics import Accuracy
        self.train_acc = Accuracy(task="multiclass", num_classes=10)
        self.val_acc = Accuracy(task="multiclass", num_classes=10)
        
        # Track gradient scaling
        self.grad_scale_history = []
        
    def forward(self, x):
        # Treat each sample as a length-1 sequence for the transformer block
        # (attention over one token is trivial; this mainly exercises the kernels)
        x = self.feature_extractor(x)
        x = x.unsqueeze(1)  # [batch, 1, features]
        x = self.transformer_block(x)
        x = x.squeeze(1)  # [batch, features]
        x = self.classifier(x)
        return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        
        preds = torch.argmax(logits, dim=1)
        self.train_acc(preds, y)
        
        # Log gradient scaling information
        if batch_idx % 50 == 0:
            self._log_gradient_info()
        
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        self.log('train_acc', self.train_acc, on_epoch=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        
        preds = torch.argmax(logits, dim=1)
        self.val_acc(preds, y)
        
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_acc', self.val_acc, on_epoch=True, prog_bar=True)
        
        return loss
    
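    def _log_gradient_info(self):
        """Log the GradScaler's current loss scale (only present with precision='16-mixed')"""
        if hasattr(self.trainer, 'precision_plugin'):
            # bf16 and FP32 runs have no scaler, so nothing is logged for them
            scaler = getattr(self.trainer.precision_plugin, 'scaler', None)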
            if scaler is not None:
                scale = scaler.get_scale()
                self.grad_scale_history.append(scale)
                self.log('gradient_scale', scale)
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, 
            max_lr=1e-3, 
            total_steps=self.trainer.estimated_stepping_batches
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step'
            }
        }

# Train advanced AMP model
print("\n=== Advanced AMP Configuration ===")

model_advanced = AdvancedAMPModel()

trainer_advanced = pl.Trainer(
    max_epochs=5,
    precision="16-mixed",
    enable_checkpointing=False,
    logger=False,
    enable_progress_bar=False
)

print("Training advanced model with AMP...")
trainer_advanced.fit(model_advanced, train_loader, val_loader)

print(f"✓ Advanced AMP training completed")
print(f"Gradient scale history: {len(model_advanced.grad_scale_history)} updates")
if model_advanced.grad_scale_history:
    print(f"Final gradient scale: {model_advanced.grad_scale_history[-1]:.0f}")
```
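With `'16-mixed'`, Lightning creates a PyTorch `GradScaler`, and the `gradient_scale` values logged above come from it. The scaler divides gradients by the scale before the optimizer step. Its scale-update rule is roughly:

1. If any gradient is Inf/NaN, skip the step and multiply the scale by `backoff_factor`.
2. After `growth_interval` consecutive clean steps, multiply the scale by `growth_factor`.

The class below is a hypothetical, simplified re-implementation of that rule in plain Python, meant only to make the logged history easier to read. Its defaults mirror `GradScaler`'s documented defaults.

```python
class DynamicLossScaler:
    """Hypothetical simplified version of GradScaler's scale-update rule"""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._clean_steps = 0  # consecutive steps without Inf/NaN gradients

    def update(self, found_inf: bool) -> bool:
        """Adjust the scale; return True if the optimizer step should be applied"""
        if found_inf:
            self.scale *= self.backoff_factor  # overflow: shrink the scale...
            self._clean_steps = 0
            return False                       # ...and skip this step
        self._clean_steps += 1
        if self._clean_steps == self.growth_interval:
            self.scale *= self.growth_factor   # long stable run: try a larger scale
            self._clean_steps = 0
        return True

# Simulated overflow pattern, with a short growth interval for illustration
toy_scaler = DynamicLossScaler(init_scale=8.0, growth_interval=3)
history = []
for found_inf in [False, False, False, True, False]:
    applied = toy_scaler.update(found_inf)
    history.append((toy_scaler.scale, applied))
print(history)
# -> [(8.0, True), (8.0, True), (16.0, True), (8.0, False), (8.0, True)]
```

A slowly rising scale with occasional halvings is the healthy pattern. A scale that keeps falling suggests persistent overflow.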

### Manual Mixed Precision Control

```python
class ManualAMPModel(pl.LightningModule):
    """Model with manual AMP control for specific operations"""
    
    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        
        self.model = nn.Sequential(
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )
        
        from torchmetrics import Accuracy
        self.val_acc = Accuracy(task="multiclass", num_classes=10)
        
    def forward(self, x):
        return self.model(x)
    
    def _normalize(self, x):
        """L2-normalize inputs in FP32, outside autocast"""
        # Disable autocast for this op; .float() also upcasts inputs that are
        # already lower precision, since disabling autocast does not do that
        with torch.autocast(device_type=x.device.type, enabled=False):
            return F.normalize(x.float(), p=2, dim=1)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        
        # Manual precision control for a specific operation
        x_normalized = self._normalize(x)
        
        # Regular forward pass; Lightning runs this step under autocast
        logits = self(x_normalized)
        
        # Loss computation (autocast runs cross_entropy in FP32 automatically)
        loss = F.cross_entropy(logits, y)
        
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        
        # Apply the same normalization as in training so the model sees the same
        # input distribution; Lightning already wraps validation in autocast
        logits = self(self._normalize(x))
        loss = F.cross_entropy(logits, y)
        
        preds = torch.argmax(logits, dim=1)
        self.val_acc(preds, y)
        
        self.log('val_loss', loss, on_epoch=True)
        self.log('val_acc', self.val_acc, on_epoch=True)
        
        return loss
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

print("\n=== Manual AMP Control ===")

model_manual = ManualAMPModel()

trainer_manual = pl.Trainer(
    max_epochs=3,
    precision="16-mixed",
    enable_checkpointing=False,
    logger=False,
    enable_progress_bar=False
)

print("Training with manual AMP control...")
trainer_manual.fit(model_manual, train_loader, val_loader)
print("✓ Manual AMP training completed")
```
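Disabling autocast does not convert tensors back to FP32. Inside an `enabled=False` region, ops run in whatever dtype their inputs already have, and PyTorch's autocast documentation recommends casting with `.float()` there.

The sketch below demonstrates this on CPU, with BF16 standing in (an assumption) for CUDA FP16. The names `proj`, `proj_in` and `hidden` are illustrative. In `ManualAMPModel` the batch arrives in FP32, so its normalization already runs at full precision.

```python
import torch
import torch.nn as nn

proj = nn.Linear(8, 8)
proj_in = torch.randn(4, 8)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    hidden = proj(proj_in)  # produced by autocast in bfloat16
    with torch.autocast(device_type="cpu", enabled=False):
        kept_low = hidden @ hidden.T                  # autocast off: runs in hidden's dtype (bf16)
        promoted = hidden.float() @ hidden.float().T  # explicit cast gives genuine FP32 math

print(hidden.dtype, kept_low.dtype, promoted.dtype)
# -> torch.bfloat16 torch.bfloat16 torch.float32
```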

## Debugging Mixed Precision Issues

### Loss Scaling and Gradient Issues

```python
class DebuggingAMPModel(pl.LightningModule):
    """Model for debugging AMP issues"""
    
    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        
        self.model = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(), 
            nn.Linear(1024, 10)
        )
        
        # Debug tracking
        self.loss_history = []
        self.gradient_norms = []
        self.nan_count = 0
        self.inf_count = 0
        
    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        
        # Debug loss values
        self.loss_history.append(loss.item())
        
        # Check for NaN/Inf
        if torch.isnan(loss):
            self.nan_count += 1
            print(f"⚠️ NaN loss detected at step {self.global_step}")
        
        if torch.isinf(loss):
            self.inf_count += 1
            print(f"⚠️ Inf loss detected at step {self.global_step}")
        
        # Log debugging info
        self.log('train_loss', loss, on_step=True)
        self.log('nan_count', float(self.nan_count))
        self.log('inf_count', float(self.inf_count))
        
        return loss
    
    def on_before_optimizer_step(self, optimizer):
        # Lightning 2.x signature (no optimizer_idx). With '16-mixed', this hook runs
        # after gradients are unscaled and before clipping, so norms are true magnitudes
        total_norm = 0
        nan_grads = 0
        inf_grads = 0
        
        for p in self.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
                
                if torch.isnan(p.grad).any():
                    nan_grads += 1
                if torch.isinf(p.grad).any():
                    inf_grads += 1
        
        total_norm = total_norm ** (1. / 2)
        self.gradient_norms.append(total_norm)
        
        # Log gradient statistics
        self.log('grad_norm', total_norm)
        self.log('nan_grads', float(nan_grads))
        self.log('inf_grads', float(inf_grads))
        
        # Warning for problematic gradients
        if total_norm > 100:
            print(f"⚠️ Large gradient norm: {total_norm:.2f}")
        
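        # Occasional Inf/NaN gradients are expected with FP16: GradScaler skips that
        # optimizer step and lowers the scale. Persistent ones point to a real problem.
        if nan_grads > 0 or inf_grads > 0:
            print(f"⚠️ Problematic gradients: {nan_grads} NaN, {inf_grads} Inf")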
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        
        self.log('val_loss', loss, on_epoch=True)
        self.log('val_acc', acc, on_epoch=True)
        
        return loss
    
    def configure_optimizers(self):
        # AdamW here; swap in other optimizers to compare their stability under AMP
        return torch.optim.AdamW(self.parameters(), lr=1e-3, weight_decay=1e-4)

print("\n=== Debugging AMP Issues ===")

model_debug = DebuggingAMPModel()

# Train with debugging enabled
trainer_debug = pl.Trainer(
    max_epochs=3,
    precision="16-mixed",
    enable_checkpointing=False,
    logger=False,
    enable_progress_bar=False,
    gradient_clip_val=1.0,  # Gradient clipping for stability
    gradient_clip_algorithm="norm"
)

print("Training with AMP debugging...")
trainer_debug.fit(model_debug, train_loader, val_loader)

print(f"✓ Debug training completed")
print(f"NaN losses: {model_debug.nan_count}")
print(f"Inf losses: {model_debug.inf_count}")
print(f"Max gradient norm: {max(model_debug.gradient_norms) if model_debug.gradient_norms else 0:.4f}")
print(f"Min gradient norm: {min(model_debug.gradient_norms) if model_debug.gradient_norms else 0:.4f}")
```
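The loop in `on_before_optimizer_step` calls `.item()` once per parameter. On a GPU, each call waits for the device. A possible vectorized alternative reduces everything on-device and transfers both results in one host sync. This is a sketch, not a Lightning API, and `gradient_health` is a hypothetical helper name.

```python
import torch
import torch.nn as nn

def gradient_health(module: nn.Module):
    """Return (total L2 grad norm, number of parameters with non-finite grads)"""
    grads = [p.grad.detach() for p in module.parameters() if p.grad is not None]
    if not grads:
        return 0.0, 0
    # One norm per tensor (computed in FP32), then the norm of those norms = global L2 norm
    per_tensor = torch.stack([torch.linalg.vector_norm(g.float(), 2) for g in grads])
    total_norm = torch.linalg.vector_norm(per_tensor, 2)
    finite = torch.stack([torch.isfinite(g).all() for g in grads])
    # Pack both results into one tensor so .tolist() needs a single host sync
    stats = torch.stack([total_norm, (~finite).sum().to(total_norm.dtype)]).tolist()
    return stats[0], int(stats[1])

probe = nn.Linear(3, 2)
probe.weight.grad = torch.ones(2, 3)
probe.bias.grad = torch.zeros(2)
healthy = gradient_health(probe)
print(healthy)  # total norm sqrt(6) ≈ 2.449, 0 non-finite tensors

probe.bias.grad = torch.tensor([float("inf"), 0.0])
broken = gradient_health(probe)
print(broken)   # (inf, 1)
```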

## Best Practices and Recommendations

### Production AMP Setup

```python
class ProductionAMPModel(pl.LightningModule):
    """Production-ready model with AMP best practices"""
    
    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        
        # Model architecture optimized for AMP
        self.backbone = nn.Sequential(
            nn.Linear(512, 1024),
            nn.LayerNorm(1024),  # autocast runs layer_norm in FP32; independent of batch size
            nn.GELU(),           # GELU and ReLU both run fine under autocast
            nn.Dropout(0.1),
        )
        
        self.head = nn.Sequential(
            nn.Linear(1024, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 10)
        )
        
        from torchmetrics import Accuracy, MetricCollection
        metrics = MetricCollection({
            'accuracy': Accuracy(task="multiclass", num_classes=10),
        })
        
        self.train_metrics = metrics.clone(prefix='train_')
        self.val_metrics = metrics.clone(prefix='val_')
        
    def forward(self, x):
        features = self.backbone(x)
        return self.head(features)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        
        preds = torch.argmax(logits, dim=1)
        self.train_metrics(preds, y)
        
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log_dict(self.train_metrics, on_epoch=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        
        preds = torch.argmax(logits, dim=1)
        self.val_metrics(preds, y)
        
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log_dict(self.val_metrics, on_epoch=True, prog_bar=True)
        
        return loss
    
    def configure_optimizers(self):
        # Use AdamW with weight decay for better regularization
        optimizer = torch.optim.AdamW(
            self.parameters(), 
            lr=1e-3, 
            weight_decay=1e-4,
            eps=1e-4  # Conservative; '16-mixed' keeps optimizer state in FP32, where the default 1e-8 is usually fine
        )
        
        # Warmup + cosine annealing scheduler
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=1e-3,
            total_steps=self.trainer.estimated_stepping_batches,
            pct_start=0.1  # 10% warmup
        )
        
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step'
            }
        }

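def create_production_trainer(enable_checkpointing=False):
    """Create production-ready trainer with AMP.

    ModelCheckpoint is only added when checkpointing is enabled, because Lightning
    rejects a ModelCheckpoint callback combined with enable_checkpointing=False.
    """
    from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
    
    callbacks = [
        EarlyStopping(
            monitor='val_loss',
            patience=5,
            mode='min'
        )
    ]
    if enable_checkpointing:
        callbacks.append(
            ModelCheckpoint(
                monitor='val_accuracy',
                mode='max',
                save_top_k=3,
                filename='{epoch:02d}-{val_accuracy:.4f}'
            )
        )
    
    return pl.Trainer(
        # AMP configuration
        precision="16-mixed",
        
        # Training configuration (Lightning unscales gradients before clipping)
        max_epochs=20,
        gradient_clip_val=1.0,
        gradient_clip_algorithm="norm",
        
        # Callbacks
        callbacks=callbacks,
        
        # Logging
        log_every_n_steps=50,
        enable_checkpointing=enable_checkpointing,  # False for demo
        logger=False,  # For demo
        enable_progress_bar=True
    )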

print("\n=== Production AMP Setup ===")

model_production = ProductionAMPModel()
trainer_production = create_production_trainer()

print("Training production model with AMP...")
trainer_production.fit(model_production, train_loader, val_loader)

final_acc = trainer_production.callback_metrics.get('val_accuracy', 0)
print(f"✓ Production training completed with final accuracy: {final_acc:.4f}")

print("\n=== AMP Best Practices Summary ===")
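print("✅ Normalization: autocast runs LayerNorm in FP32; prefer it where the architecture allows")
print("✅ Activations: GELU and ReLU both work under autocast; choose for the model, not for AMP")
print("✅ Optimizer eps: larger values (e.g. 1e-4) matter mainly for FP16 optimizer state; '16-mixed' keeps it FP32")
print("✅ Enable gradient clipping (Lightning unscales gradients before clipping)")
print("✅ Monitor the gradient scale: repeated drops indicate overflow")
print("✅ Use learning-rate warmup to avoid early loss spikes")
print("✅ AdamW is a solid default; AMP works with SGD and other torch.optim optimizers too")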
print("✅ Test thoroughly and compare with FP32 baseline")
```

## Platform-Specific Optimizations

### Different Precision Modes

```python
# Test different precision modes available in PyTorch Lightning
precision_modes = [
    ("32-true", "Full FP32 precision"),
    ("16-mixed", "Automatic Mixed Precision with FP16"),
    ("bf16-mixed", "Mixed precision with BFloat16"),
    ("64-true", "Double precision FP64")
]

class PrecisionTestModel(pl.LightningModule):
    """Simple model to test different precision modes"""
    
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )
        
        from torchmetrics import Accuracy
        self.val_acc = Accuracy(task="multiclass", num_classes=10)
        
    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.val_acc(preds, y)
        
        self.log('val_loss', loss)
        self.log('val_acc', self.val_acc)
        
        return loss
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

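# Small test dataset with random labels: accuracy stays near chance (~10%),
# so this comparison only checks that each mode runs and how long it takes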
def create_test_dataset(num_samples=1000, input_size=128):
    torch.manual_seed(42)
    X = torch.randn(num_samples, input_size)
    y = torch.randint(0, 10, (num_samples,))
    return TensorDataset(X, y)

test_dataset = create_test_dataset()
train_size = int(0.8 * len(test_dataset))
val_size = len(test_dataset) - train_size
train_test, val_test = torch.utils.data.random_split(test_dataset, [train_size, val_size])

train_test_loader = DataLoader(train_test, batch_size=64, shuffle=True)
val_test_loader = DataLoader(val_test, batch_size=64, shuffle=False)

print("=== Testing Different Precision Modes ===")

precision_results = {}

for precision_mode, description in precision_modes:
    print(f"\nTesting {precision_mode}: {description}")
    
    try:
        model = PrecisionTestModel()
        
        trainer = pl.Trainer(
            max_epochs=2,
            precision=precision_mode,
            enable_checkpointing=False,
            logger=False,
            enable_progress_bar=False
        )
        
        start_time = time.time()
        trainer.fit(model, train_test_loader, val_test_loader)
        end_time = time.time()
        
        # Record results
        final_acc = trainer.callback_metrics.get('val_acc', 0)
        training_time = end_time - start_time
        
        precision_results[precision_mode] = {
            'accuracy': float(final_acc),
            'training_time': training_time,
            'status': 'SUCCESS'
        }
        
        print(f"✓ {precision_mode}: Accuracy={final_acc:.4f}, Time={training_time:.2f}s")
        
    except Exception as e:
        precision_results[precision_mode] = {
            'accuracy': 0.0,
            'training_time': 0.0,
            'status': f'FAILED: {str(e)[:50]}'
        }
        print(f"❌ {precision_mode}: Failed - {e}")

# Summary of precision mode comparison
print(f"\n=== Precision Mode Comparison ===")
for mode, results in precision_results.items():
    if results['status'] == 'SUCCESS':
        print(f"{mode:12s}: Acc={results['accuracy']:.4f}, Time={results['training_time']:.2f}s")
    else:
        print(f"{mode:12s}: {results['status']}")
```
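The precision modes above trade range against resolution. One rough way to see the resolution side is to round-trip the same random tensor through each 16-bit format and measure the relative error. This uses plain dtype casts and is an illustration, not part of the original benchmark.

BF16 keeps only 8 significand bits, against FP16's 11. For values inside FP16's range, its rounding error should therefore be roughly 8× larger.

```python
import torch

torch.manual_seed(0)
reference = torch.randn(100_000)  # FP32 values well inside FP16's range

def mean_relative_error(dtype):
    """Mean |x - round_trip(x)| / |x| after casting to `dtype` and back to FP32"""
    round_trip = reference.to(dtype).to(torch.float32)
    return ((reference - round_trip).abs() / reference.abs().clamp_min(1e-12)).mean().item()

errors = {str(d): mean_relative_error(d) for d in (torch.float16, torch.bfloat16)}
for name, err in errors.items():
    print(f"{name:15s} mean relative rounding error: {err:.2e}")
```

FP16 is the more precise format as long as values stay in range. BF16 gives up precision so that loss scaling becomes unnecessary.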

### Hardware Detection and Optimization

```python
def detect_hardware_capabilities():
    """Detect hardware capabilities for mixed precision"""
    info = {
        'cuda_available': torch.cuda.is_available(),
        'cuda_version': torch.version.cuda if torch.cuda.is_available() else None,
        'gpu_count': torch.cuda.device_count() if torch.cuda.is_available() else 0,
        'tensor_cores': False,
        'bf16_support': False,
        'amp_supported': False
    }
    
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        info['gpu_name'] = gpu_name
        
        # Compute capability is a more reliable signal than matching GPU names
        capability = torch.cuda.get_device_capability(0)
        info['compute_capability'] = f"{capability[0]}.{capability[1]}"
        
        # FP16 autocast runs on any CUDA GPU; the speedup depends on Tensor Cores
        info['amp_supported'] = True
        
        # Tensor Cores: Volta/Turing and newer (compute capability 7.0+).
        # Approximate: some 7.5 parts (e.g. GTX 16xx) lack Tensor Cores
        if capability[0] >= 7:
            info['tensor_cores'] = True
        
        # Native BF16: Ampere and newer (compute capability 8.0+), e.g. A100, RTX 30xx
        if capability[0] >= 8:
            info['bf16_support'] = True
    
    return info

# Hardware detection
print("=== Hardware Capability Detection ===")
hw_info = detect_hardware_capabilities()

print(f"CUDA Available: {hw_info['cuda_available']}")
if hw_info['cuda_available']:
    print(f"CUDA Version: {hw_info['cuda_version']}")
    print(f"GPU Count: {hw_info['gpu_count']}")
    print(f"GPU Name: {hw_info.get('gpu_name', 'Unknown')}")
    print(f"Compute Capability: {hw_info.get('compute_capability', 'Unknown')}")
    print(f"Tensor Cores: {hw_info['tensor_cores']}")
    print(f"AMP Supported: {hw_info['amp_supported']}")
    print(f"BF16 Supported: {hw_info['bf16_support']}")
    
    # Recommendations based on hardware
    print(f"\n=== Recommendations ===")
    if hw_info['tensor_cores']:
        print("✓ Use precision='16-mixed' for optimal performance")
        if hw_info['bf16_support']:
            print("✓ Consider precision='bf16-mixed' for numerical stability")
    else:
        print("⚠ Mixed precision may not provide significant speedup")
        print("  Consider using precision='32-true' for compatibility")
else:
    print("⚠ CUDA not available: FP16 AMP ('16-mixed') targets CUDA GPUs; 'bf16-mixed' can run on CPU, but speedups vary")
```
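The recommendations above can be folded into one function. This is one possible policy, an assumption rather than a Lightning default:

- Prefer BF16 on Ampere-or-newer GPUs, since it needs no loss scaling.
- Use FP16 AMP on Tensor Core GPUs.
- Use FP32 otherwise.

`choose_precision` is a hypothetical helper. It takes the hardware facts as arguments so the policy can be tested without a GPU.

```python
import torch

def choose_precision(cuda_available: bool, compute_major: int, bf16_supported: bool) -> str:
    """Map hardware facts to a Lightning `precision` string (heuristic, not an official rule)"""
    if not cuda_available:
        return "32-true"
    if compute_major >= 8 and bf16_supported:
        return "bf16-mixed"  # FP32 exponent range, so no loss scaling needed
    if compute_major >= 7:
        return "16-mixed"    # Tensor Cores; GradScaler handles FP16's narrow range
    return "32-true"         # older GPUs: AMP runs but rarely pays off

def choose_precision_for_current_device() -> str:
    """Apply the policy to GPU 0, or to the CPU if CUDA is unavailable"""
    if not torch.cuda.is_available():
        return choose_precision(False, 0, False)
    major, _minor = torch.cuda.get_device_capability(0)
    return choose_precision(True, major, torch.cuda.is_bf16_supported())

print("Suggested precision:", choose_precision_for_current_device())
```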

## Memory Optimization Strategies

### Memory-Efficient Training Patterns

```python
class MemoryEfficientModel(pl.LightningModule):
    """Model with memory optimization techniques"""
    
    def __init__(self, use_checkpointing=True, use_efficient_attention=True):
        super().__init__()
        self.save_hyperparameters()
        
        # Large model to demonstrate memory usage
        hidden_size = 2048
        
        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size if i > 0 else 512, hidden_size),
                nn.LayerNorm(hidden_size),
                nn.GELU(),
                nn.Dropout(0.1),
            ) for i in range(6)
        ])
        
        self.classifier = nn.Linear(hidden_size, 10)
        
        from torchmetrics import Accuracy
        self.val_acc = Accuracy(task="multiclass", num_classes=10)
        
    def forward(self, x):
        # Gradient checkpointing for memory efficiency
        if self.hparams.use_checkpointing and self.training:
            from torch.utils.checkpoint import checkpoint
            
            for layer in self.layers:
                x = checkpoint(layer, x, use_reentrant=False)
        else:
            for layer in self.layers:
                x = layer(x)
        
        return self.classifier(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        
        # Memory monitoring
        if torch.cuda.is_available() and batch_idx % 100 == 0:
            memory_before = torch.cuda.memory_allocated() / 1024**3
            
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        
        if torch.cuda.is_available() and batch_idx % 100 == 0:
            memory_after = torch.cuda.memory_allocated() / 1024**3
            self.log('memory_usage_gb', memory_after)
            self.log('memory_increase_mb', (memory_after - memory_before) * 1024)
        
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        
        preds = torch.argmax(logits, dim=1)
        self.val_acc(preds, y)
        
        self.log('val_loss', loss, on_epoch=True)
        self.log('val_acc', self.val_acc, on_epoch=True)
        
        return loss
    
    def configure_optimizers(self):
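        # These settings do not change optimizer memory (AdamW keeps exp_avg and
        # exp_avg_sq per parameter regardless); they are conservative stability choices
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=1e-4,  # lower LR for this deeper, 2048-wide model
            weight_decay=1e-4,
            eps=1e-4,  # conservative; mainly matters if optimizer state were FP16
            betas=(0.9, 0.95)  # lower beta2 adapts faster to gradient-scale changes (common for large models)
        )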
        return optimizer

print("\n=== Memory Optimization Strategies ===")

# Test with and without gradient checkpointing
memory_configs = [
    ("Standard", False, "32-true"),
    ("Standard + AMP", False, "16-mixed"),
    ("Checkpointing + AMP", True, "16-mixed")
]

memory_results = {}

for config_name, use_checkpointing, precision in memory_configs:
    print(f"\nTesting: {config_name}")
    
    try:
        model = MemoryEfficientModel(use_checkpointing=use_checkpointing)
        memory_callback = MemoryMonitorCallback()
        
        trainer = pl.Trainer(
            max_epochs=2,
            precision=precision,
            callbacks=[memory_callback],
            enable_checkpointing=False,
            logger=False,
            enable_progress_bar=False
        )
        
        trainer.fit(model, train_test_loader, val_test_loader)
        
        peak_alloc, peak_reserved = memory_callback.get_peak_memory()
        memory_results[config_name] = {
            'peak_allocated': peak_alloc,
            'peak_reserved': peak_reserved,
            'success': True
        }
        
        print(f"✓ Peak memory: {peak_alloc:.2f}GB allocated, {peak_reserved:.2f}GB reserved")
        
    except Exception as e:
        memory_results[config_name] = {
            'peak_allocated': 0,
            'peak_reserved': 0,
            'success': False,
            'error': str(e)[:100]
        }
        print(f"❌ Failed: {e}")

# Memory optimization summary
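print("\n=== Memory Usage Comparison ===")
baseline_memory = None
for config, results in memory_results.items():
    if not results['success']:
        print(f"{config:20s}: failed ({results['error']})")
        continue
    allocated = results['peak_allocated']
    if baseline_memory is None:
        baseline_memory = allocated  # first successful config serves as the baseline
    # Guard against a zero baseline (e.g. CPU-only runs report 0 GB allocated);
    # positive values mean less memory than the baseline
    reduction = (baseline_memory - allocated) / baseline_memory * 100 if baseline_memory else 0.0
    print(f"{config:20s}: {allocated:.2f}GB ({reduction:+.1f}% reduction vs baseline)")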
```
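Measured savings are usually well below a naive 2x. A rough back-of-the-envelope model shows why. The sketch below is an approximation, and the helper `estimate_training_memory_mb` is hypothetical. It assumes that AdamW keeps FP32 weights, FP32 gradients, and two FP32 moment buffers (16 bytes per parameter) in both modes. Under AMP, it also assumes that autocast caches 2-byte FP16 weight copies during the forward pass, and that activations saved for backward drop from 4 to 2 bytes per element. Allocator overhead, workspace buffers, and ops that autocast keeps in FP32 are ignored.

```python
# Rough per-component memory model for AdamW training (illustrative only).
BYTES_FP32 = 4
BYTES_FP16 = 2


def estimate_training_memory_mb(num_params: int, activation_elems: int, amp: bool) -> float:
    """Approximate peak memory in MB (1 MB = 1024**2 bytes).

    Assumptions (not measured): FP32 master weights + FP32 grads + two FP32
    AdamW moments in both modes; under AMP, FP16 weight copies cached by
    autocast and FP16 saved activations.
    """
    weights = num_params * BYTES_FP32
    grads = num_params * BYTES_FP32
    adam_moments = 2 * num_params * BYTES_FP32
    total = weights + grads + adam_moments
    if amp:
        total += num_params * BYTES_FP16          # cached FP16 weight copies
        total += activation_elems * BYTES_FP16    # activations saved for backward
    else:
        total += activation_elems * BYTES_FP32
    return total / 1024**2


# Hypothetical workload: 1M parameters, 10M saved activation elements
fp32_mb = estimate_training_memory_mb(1_000_000, 10_000_000, amp=False)
amp_mb = estimate_training_memory_mb(1_000_000, 10_000_000, amp=True)
print(f"FP32: {fp32_mb:.1f} MB, AMP: {amp_mb:.1f} MB, "
      f"reduction: {(1 - amp_mb / fp32_mb) * 100:.1f}%")
```

Under these assumptions, AMP mostly shrinks activation memory. Activation-heavy workloads (large batches, long sequences) therefore benefit most, while parameter-heavy models see smaller relative savings.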

## Gradient Accumulation with Mixed Precision

### Accumulation Strategies

```python
class GradientAccumulationModel(pl.LightningModule):
    """Model demonstrating gradient accumulation with mixed precision"""
    
    def __init__(self, effective_batch_size=512, actual_batch_size=64):
        super().__init__()
        self.save_hyperparameters()
        
        # Calculate accumulation steps
        self.accumulation_steps = effective_batch_size // actual_batch_size
        
        self.model = nn.Sequential(
            nn.Linear(512, 1024),
            nn.LayerNorm(1024),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(1024, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(512, 10)
        )
        
        from torchmetrics import Accuracy
        self.train_acc = Accuracy(task="multiclass", num_classes=10)
        self.val_acc = Accuracy(task="multiclass", num_classes=10)
        
        # Track gradient statistics
        self.grad_norms = []
        self.accumulated_losses = []
        
    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        
        # Do NOT divide the loss by accumulation_steps here: with
        # Trainer(accumulate_grad_batches=N), Lightning's automatic optimization
        # already normalizes the loss by N. Dividing again would shrink the
        # gradients (and the effective learning rate) by a further factor of N.
        
        preds = torch.argmax(logits, dim=1)
        self.train_acc(preds, y)
        
        # Track the per-batch loss
        self.accumulated_losses.append(loss.item())
        
        # Log the mean loss once per accumulation cycle
        if (batch_idx + 1) % self.accumulation_steps == 0:
            avg_loss = sum(self.accumulated_losses) / len(self.accumulated_losses)
            self.log('train_loss', avg_loss, on_step=True, prog_bar=True)
            self.accumulated_losses.clear()
        
        self.log('train_acc', self.train_acc, on_epoch=True)
        
        return loss
    
    def on_before_optimizer_step(self, optimizer):
        # Lightning 2.x hook signature (the optimizer_idx argument was removed).
        # Under "16-mixed", Lightning unscales gradients before calling this hook,
        # so the norm below is in true (unscaled) units.
        total_norm = 0.0
        for p in self.parameters():
            if p.grad is not None:
                param_norm = p.grad.detach().norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5
        
        self.grad_norms.append(total_norm)
        self.log('grad_norm', total_norm, on_step=True)
        
        # Log the current loss scale (a GradScaler exists only for "16-mixed")
        if hasattr(self.trainer, 'precision_plugin'):
            scaler = getattr(self.trainer.precision_plugin, 'scaler', None)
            if scaler is not None:
                self.log('grad_scale', scaler.get_scale(), on_step=True)
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        
        preds = torch.argmax(logits, dim=1)
        self.val_acc(preds, y)
        
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_acc', self.val_acc, on_epoch=True, prog_bar=True)
        
        return loss
    
    def configure_optimizers(self):
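        # Linear LR scaling heuristic relative to the per-step batch size.
        # For adaptive optimizers such as AdamW this can be aggressive (here up to 8x);
        # square-root scaling is a common, gentler alternative.
        base_lr = 1e-3
        scaled_lr = base_lr * (self.hparams.effective_batch_size / self.hparams.actual_batch_size)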
        
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=scaled_lr,
            weight_decay=1e-4
        )
        
        return optimizer

print("\n=== Gradient Accumulation with Mixed Precision ===")

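# Test different accumulation strategies. actual_batch_size is bookkeeping only:
# the effective batch size is correct only if train_test_loader actually yields
# batches of that size.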
accumulation_configs = [
    (128, 64, 2),   # Small accumulation
    (256, 64, 4),   # Medium accumulation  
    (512, 64, 8),   # Large accumulation
]

for effective_batch, actual_batch, expected_steps in accumulation_configs:
    print(f"\nTesting accumulation: {effective_batch} effective, {actual_batch} actual ({expected_steps} steps)")
    
    model = GradientAccumulationModel(
        effective_batch_size=effective_batch,
        actual_batch_size=actual_batch
    )
    
    trainer = pl.Trainer(
        max_epochs=2,
        precision="16-mixed",
        accumulate_grad_batches=expected_steps,
        enable_checkpointing=False,
        logger=False,
        enable_progress_bar=False
    )
    
    trainer.fit(model, train_test_loader, val_test_loader)
    
    avg_grad_norm = np.mean(model.grad_norms) if model.grad_norms else 0
    print(f"✓ Average gradient norm: {avg_grad_norm:.4f}")
    print(f"✓ Accumulation steps configured: {model.accumulation_steps}")
```
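The loss should not be divided manually, because averaging gradients over k equal-sized micro-batches (each loss divided by k once) already reproduces the full-batch gradient exactly. The pure-Python sketch below checks this with hypothetical helpers `mse_grad` and `accumulated_grad` on a one-parameter linear model trained with mean squared error. It also shows that dividing a second time shrinks the gradient by a further factor of k. That double division happens, for example, when `training_step` divides the loss while `accumulate_grad_batches` also normalizes it.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w * x - y)**2) with respect to w over one batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, micro_batch, extra_divisor=1):
    """Sum micro-batch gradients, each normalized by the number of micro-batches.

    extra_divisor models an additional, redundant manual loss division.
    """
    k = len(xs) // micro_batch
    total = 0.0
    for i in range(k):
        start, end = i * micro_batch, (i + 1) * micro_batch
        total += mse_grad(w, xs[start:end], ys[start:end]) / k / extra_divisor
    return total


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x + 1.0 for x in xs]
w = 0.5

full = mse_grad(w, xs, ys)
correct = accumulated_grad(w, xs, ys, micro_batch=2)                  # k = 4
double = accumulated_grad(w, xs, ys, micro_batch=2, extra_divisor=4)  # divided twice
print(f"full batch: {full:.4f}, accumulated: {correct:.4f}, double-divided: {double:.4f}")
```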

## Troubleshooting Common Issues

### NaN/Inf Detection and Recovery

```python
class RobustAMPModel(pl.LightningModule):
    """Model with comprehensive NaN/Inf detection and recovery"""
    
    def __init__(self):
        super().__init__()
        
        self.model = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 10)
        )
        
        # Tracking for debugging
        self.loss_history = []
        self.gradient_issues = {
            'nan_losses': 0,
            'inf_losses': 0,
            'nan_gradients': 0,
            'inf_gradients': 0,
            'large_gradients': 0,
            'scale_adjustments': 0
        }
        
        from torchmetrics import Accuracy
        self.val_acc = Accuracy(task="multiclass", num_classes=10)
        
    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        
        # Comprehensive loss validation
        self._validate_loss(loss, batch_idx)
        
        self.loss_history.append(loss.item())
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        
        return loss
    
    def _validate_loss(self, loss, batch_idx):
        """Validate loss for NaN/Inf issues"""
        if torch.isnan(loss):
            self.gradient_issues['nan_losses'] += 1
            print(f"🚨 NaN loss at step {self.global_step}, batch {batch_idx}")
            self.log('nan_losses', float(self.gradient_issues['nan_losses']))
            
        if torch.isinf(loss):
            self.gradient_issues['inf_losses'] += 1
            print(f"🚨 Inf loss at step {self.global_step}, batch {batch_idx}")
            self.log('inf_losses', float(self.gradient_issues['inf_losses']))
            
        if loss.item() > 100:  # Unusually large loss
            print(f"⚠️ Large loss detected: {loss.item():.2f} at step {self.global_step}")
            
    def on_before_optimizer_step(self, optimizer):
        # Lightning 2.x signature (no optimizer_idx); gradients are already unscaled here
        self._validate_gradients()
        
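    def _validate_gradients(self):
        """Validate gradients for numerical issues.

        Under "16-mixed", an occasional step with inf/NaN gradients is expected:
        GradScaler skips that optimizer step and lowers the scale. NaNs that
        persist across many consecutive steps indicate a real problem.
        """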
        total_norm = 0
        nan_params = 0
        inf_params = 0
        
        for name, param in self.named_parameters():
            if param.grad is not None:
                # Check for NaN/Inf in gradients
                if torch.isnan(param.grad).any():
                    nan_params += 1
                    if nan_params <= 3:  # Limit spam
                        print(f"🚨 NaN gradient in {name}")
                        
                if torch.isinf(param.grad).any():
                    inf_params += 1
                    if inf_params <= 3:  # Limit spam
                        print(f"🚨 Inf gradient in {name}")
                
                # Calculate gradient norm
                param_norm = param.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        
        total_norm = total_norm ** (1. / 2)
        
        # Update tracking
        if nan_params > 0:
            self.gradient_issues['nan_gradients'] += 1
        if inf_params > 0:
            self.gradient_issues['inf_gradients'] += 1
        if total_norm > 100:
            self.gradient_issues['large_gradients'] += 1
            print(f"⚠️ Large gradient norm: {total_norm:.2f}")
        
        # Log gradient statistics
        self.log('grad_norm', total_norm)
        self.log('nan_gradients', float(self.gradient_issues['nan_gradients']))
        self.log('inf_gradients', float(self.gradient_issues['inf_gradients']))
        
        # Check gradient scaler
        if hasattr(self.trainer, 'precision_plugin'):
            scaler = getattr(self.trainer.precision_plugin, 'scaler', None)
            if scaler is not None:
                current_scale = scaler.get_scale()
                self.log('gradient_scale', current_scale)
                
                # A lower scale than at the previous optimizer step means GradScaler found
                # inf/NaN gradients on that step, skipped it, and backed off the scale
                if hasattr(self, '_last_scale'):
                    if current_scale < self._last_scale:
                        self.gradient_issues['scale_adjustments'] += 1
                        print(f"📉 Gradient scale reduced: {self._last_scale} -> {current_scale}")
                        
                self._last_scale = current_scale
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        
        preds = torch.argmax(logits, dim=1)
        self.val_acc(preds, y)
        
        self.log('val_loss', loss, on_epoch=True)
        self.log('val_acc', self.val_acc, on_epoch=True)
        
        return loss
    
    def configure_optimizers(self):
        # Conservative optimizer settings for stability
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=1e-4,  # Lower learning rate
            weight_decay=1e-4,
            eps=1e-4,  # Larger than the 1e-8 default; a safeguard, since AMP keeps optimizer states in FP32
            betas=(0.9, 0.95)  # Lower beta2 than the 0.999 default adapts faster to gradient spikes
        )
        return optimizer
    
    def get_debugging_summary(self):
        """Get summary of numerical issues encountered"""
        return {
            'total_steps': self.global_step,
            'loss_history_length': len(self.loss_history),
            'numerical_issues': self.gradient_issues.copy(),
            'average_loss': np.mean(self.loss_history[-100:]) if self.loss_history else 0,
            'loss_std': np.std(self.loss_history[-100:]) if len(self.loss_history) > 1 else 0
        }

print("\n=== Robust AMP Training with Issue Detection ===")

# Train robust model with comprehensive debugging
model_robust = RobustAMPModel()

trainer_robust = pl.Trainer(
    max_epochs=5,
    precision="16-mixed",
    gradient_clip_val=1.0,  # Gradient clipping for stability
    gradient_clip_algorithm="norm",
    enable_checkpointing=False,
    logger=False,
    enable_progress_bar=False
)

print("Training with comprehensive debugging...")
trainer_robust.fit(model_robust, train_test_loader, val_test_loader)

# Get debugging summary
debug_summary = model_robust.get_debugging_summary()

print(f"\n=== Training Debug Summary ===")
print(f"Total training steps: {debug_summary['total_steps']}")
print(f"Average loss (last 100): {debug_summary['average_loss']:.4f}")
print(f"Loss std deviation: {debug_summary['loss_std']:.4f}")
print(f"\nNumerical Issues:")
for issue, count in debug_summary['numerical_issues'].items():
    if count > 0:
        print(f"  {issue}: {count}")
    
if sum(debug_summary['numerical_issues'].values()) == 0:
    print("✅ No numerical issues detected!")
else:
    print("⚠️ Some numerical issues detected - see logs above")
```
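The `gradient_scale` drops that `RobustAMPModel` reports come from dynamic loss scaling. The pure-Python sketch below imitates the update policy of PyTorch's `GradScaler` using its default hyperparameters (`init_scale=2**16`, `growth_factor=2.0`, `backoff_factor=0.5`, `growth_interval=2000`). `ToyLossScaler` is a hypothetical, simplified model. It detects overflow by comparing the scaled gradient against the FP16 maximum (65504), whereas the real scaler checks the FP16 gradients for inf/NaN after backward.

```python
import math

FP16_MAX = 65504.0  # largest finite FP16 value


class ToyLossScaler:
    """Simplified dynamic loss scaling (illustrative, not PyTorch's implementation)."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0
        self.skipped_steps = 0

    def step(self, max_abs_grad: float) -> bool:
        """Return True if the optimizer step would be applied, False if skipped."""
        scaled = max_abs_grad * self.scale
        if not math.isfinite(scaled) or scaled > FP16_MAX:
            # Overflow: skip this step and back off the scale
            self.scale *= self.backoff_factor
            self._good_steps = 0
            self.skipped_steps += 1
            return False
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            # A run of overflow-free steps: try a larger scale
            self.scale *= self.growth_factor
            self._good_steps = 0
        return True


scaler = ToyLossScaler(growth_interval=3)  # short interval to make growth visible
for step in range(8):
    applied = scaler.step(max_abs_grad=10.0)
    print(f"step {step}: applied={applied}, next scale={scaler.scale:g}")
```

A few overflow-and-backoff steps like these are normal early in training. Skips that continue for many consecutive steps point to a genuine numerical problem.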

## Performance Profiling and Optimization

### Detailed Performance Analysis

```python
class ProfilingAMPModel(pl.LightningModule):
    """Model with detailed performance profiling"""
    
    def __init__(self):
        super().__init__()
        
        self.model = nn.Sequential(
            nn.Linear(512, 2048),
            nn.LayerNorm(2048),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(2048, 1024),
            nn.LayerNorm(1024), 
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(1024, 10)
        )
        
        # Profiling data (times in seconds, memory in MB)
        self.profiling_data = {
            'forward_times': [],
            'backward_times': [],
            'optimizer_times': [],
            'memory_usage': []
        }
        self._profile_step = False
        self._backward_start = None
        self._optimizer_start = None
        
    def forward(self, x):
        return self.model(x)
    
    @staticmethod
    def _sync():
        # CUDA kernels run asynchronously; synchronize so timers measure finished work
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        
        # Profile every 20th batch (forward pass + loss)
        self._profile_step = batch_idx % 20 == 0
        if self._profile_step:
            self._sync()
            forward_start = time.perf_counter()
        
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        
        if self._profile_step:
            self._sync()
            forward_time = time.perf_counter() - forward_start
            self.profiling_data['forward_times'].append(forward_time)
            
            # Memory usage
            if torch.cuda.is_available():
                memory_mb = torch.cuda.memory_allocated() / 1024**2
                self.profiling_data['memory_usage'].append(memory_mb)
                self.log('memory_mb', memory_mb)
            
            self.log('forward_time_ms', forward_time * 1000)
        
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        return loss
    
    def on_before_backward(self, loss):
        # Start the backward timer on profiled batches
        if self._profile_step:
            self._sync()
            self._backward_start = time.perf_counter()
    
    def on_after_backward(self):
        if self._backward_start is not None:
            self._sync()
            self.profiling_data['backward_times'].append(time.perf_counter() - self._backward_start)
            self._backward_start = None
    
    def on_before_optimizer_step(self, optimizer):
        # Lightning 2.x signature (no optimizer_idx)
        if self._profile_step:
            self._sync()
            self._optimizer_start = time.perf_counter()
    
    def on_train_batch_end(self, outputs, batch, batch_idx):
        # Lightning does not call an `on_after_optimizer_step` hook on the LightningModule;
        # this hook runs after the optimizer step, so the interval approximates the step
        # (plus gradient clipping and GradScaler bookkeeping under AMP).
        if self._optimizer_start is not None:
            self._sync()
            optimizer_time = time.perf_counter() - self._optimizer_start
            self.profiling_data['optimizer_times'].append(optimizer_time)
            self.log('optimizer_time_ms', optimizer_time * 1000)
            self._optimizer_start = None
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        
        self.log('val_loss', loss, on_epoch=True)
        self.log('val_acc', acc, on_epoch=True)
        
        return loss
    
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3)
    
    def get_profiling_summary(self):
        """Summarize timings (converted from seconds to ms) and memory (already in MB)."""
        summary = {}
        
        for key, values in self.profiling_data.items():
            if not values:
                continue
            if key == 'memory_usage':
                # Memory samples are already in MB; do not scale them like timings
                summary[key] = {
                    'mean_mb': np.mean(values),
                    'max_mb': np.max(values),
                    'count': len(values)
                }
            else:
                summary[key] = {
                    'mean_ms': np.mean(values) * 1000,
                    'std_ms': np.std(values) * 1000,
                    'min_ms': np.min(values) * 1000,
                    'max_ms': np.max(values) * 1000,
                    'count': len(values)
                }
        
        return summary

print("\n=== Performance Profiling ===")

# Profile different precision modes
profiling_configs = [
    ("FP32", "32-true"),
    ("AMP", "16-mixed")
]

profiling_results = {}

for config_name, precision in profiling_configs:
    print(f"\nProfiling {config_name} training...")
    
    model_profile = ProfilingAMPModel()
    
    trainer_profile = pl.Trainer(
        max_epochs=2,
        precision=precision,
        enable_checkpointing=False,
        logger=False,
        enable_progress_bar=False
    )
    
    # Wall-clock time for the whole fit, including validation and setup.
    # The first configuration also absorbs one-time CUDA initialization.
    start_time = time.perf_counter()
    trainer_profile.fit(model_profile, train_test_loader, val_test_loader)
    total_time = time.perf_counter() - start_time
    
    # Get profiling results
    profile_summary = model_profile.get_profiling_summary()
    profile_summary['total_time_sec'] = total_time
    profiling_results[config_name] = profile_summary
    
    print(f"✓ {config_name} profiling completed in {total_time:.2f}s")

# Compare profiling results
print("\n=== Performance Comparison ===")
for config_name, results in profiling_results.items():
    print(f"\n{config_name} Results:")
    print(f"  Total time: {results['total_time_sec']:.2f}s")
    
    if 'forward_times' in results:
        forward = results['forward_times']
        print(f"  Forward pass: {forward['mean_ms']:.2f}±{forward['std_ms']:.2f}ms")
    
    if 'backward_times' in results:
        backward = results['backward_times']
        print(f"  Backward pass: {backward['mean_ms']:.2f}±{backward['std_ms']:.2f}ms")
        
    if 'optimizer_times' in results:
        optimizer = results['optimizer_times']
        print(f"  Optimizer step: {optimizer['mean_ms']:.2f}±{optimizer['std_ms']:.2f}ms")
        
    if 'memory_usage' in results:
        memory = results['memory_usage']
        print(f"  Memory usage: {memory['mean_mb']:.0f}MB avg, {memory['max_mb']:.0f}MB max (sampled)")

# Calculate speedup
if 'FP32' in profiling_results and 'AMP' in profiling_results:
    fp32_time = profiling_results['FP32']['total_time_sec']
    amp_time = profiling_results['AMP']['total_time_sec']
    speedup = fp32_time / amp_time
    print(f"\n🚀 AMP Speedup: {speedup:.2f}x")
```
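The profiling code above brackets each measurement with `torch.cuda.synchronize()` because CUDA kernels are launched asynchronously. Without a sync, the host-side timer stops before the GPU work finishes. The pattern can be factored into a small context manager. `timed_region` is a hypothetical helper, and its `sync` argument stands in for `torch.cuda.synchronize` (pass `None` on CPU), so the sketch runs with the standard library alone.

```python
import time
from contextlib import contextmanager
from typing import Callable, List, Optional


@contextmanager
def timed_region(results: List[float], sync: Optional[Callable[[], None]] = None):
    """Append the elapsed wall-clock time (seconds) of the `with` body to `results`.

    `sync` should block until queued device work is done, e.g. torch.cuda.synchronize.
    """
    if sync is not None:
        sync()  # drain work queued before the region
    start = time.perf_counter()
    try:
        yield
    finally:
        if sync is not None:
            sync()  # wait for work launched inside the region
        results.append(time.perf_counter() - start)


forward_times: List[float] = []
sync_calls = []
with timed_region(forward_times, sync=lambda: sync_calls.append(1)):
    sum(i * i for i in range(10_000))  # stand-in for a forward pass
print(f"{len(forward_times)} measurement(s), {forward_times[0] * 1000:.3f} ms, "
      f"{len(sync_calls)} sync call(s)")
```

Manual timers like this give coarse numbers. For per-operator breakdowns, `torch.profiler` or Lightning's `Trainer(profiler="simple")` / `"advanced"` give more detail.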

## Summary

### Key Takeaways and Best Practices

This comprehensive notebook covered mixed precision training and AMP optimization in PyTorch Lightning:

**Mixed Precision Fundamentals:**
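- **Memory Efficiency**: Most activations saved for the backward pass are stored in FP16/BF16, roughly halving activation memory. Under AMP, weights, gradients, and optimizer states stay in FP32, so total savings depend on how activation-heavy the model is
- **Speed Improvements**: Often a 1.5-2x training speedup on Tensor Core GPUs for matmul-heavy models; small or data-loading-bound models may see little gain
- **Maintained Accuracy**: With FP16, dynamic loss scaling keeps small gradients from underflowing and typically preserves FP32-level accuracy; BF16 shares FP32's exponent range and needs no loss scaling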
- **Hardware Requirements**: Tensor Cores (V100+, RTX 20/30/40 series, A100+) for optimal benefits
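The range and precision differences behind these points can be checked with the standard library alone. Python's `struct` module packs IEEE 754 half precision with the `'e'` format. BF16 has no `struct` code, so the hypothetical `to_bf16` helper below rounds an FP32 bit pattern to its top 16 bits (round-to-nearest-even, ignoring NaN handling). The sketch shows three things. First, FP16 underflows a tiny gradient to zero. Second, the same value survives after multiplication by a loss scale of 65536. Third, BF16 keeps both tiny and large magnitudes, at the cost of precision.

```python
import struct


def to_fp16(x: float) -> float:
    """Round x to IEEE 754 half precision (raises OverflowError beyond FP16 range)."""
    return struct.unpack('<e', struct.pack('<e', x))[0]


def to_bf16(x: float) -> float:
    """Round x to bfloat16 by keeping the top 16 bits of its FP32 encoding.

    Uses round-to-nearest-even; NaN/inf edge cases are not handled.
    """
    bits = struct.unpack('<I', struct.pack('<f', x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    bits &= 0xFFFF0000
    return struct.unpack('<f', struct.pack('<I', bits))[0]


tiny_grad = 1e-8
print("FP16(1e-8):        ", to_fp16(tiny_grad))          # below the smallest FP16 subnormal
print("FP16(1e-8 * 65536):", to_fp16(tiny_grad * 65536))  # representable after loss scaling
print("BF16(1e-8):        ", to_bf16(tiny_grad))          # FP32-like exponent range
print("BF16(70000.0):     ", to_bf16(70000.0))            # representable, but coarsely rounded
try:
    to_fp16(70000.0)
except OverflowError:
    print("FP16(70000.0):      overflow (max finite FP16 is 65504)")
```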

**Implementation Strategies:**
- **Automatic Mixed Precision**: Lightning's `precision="16-mixed"` handles complexity automatically
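- **BFloat16 Support**: `precision="bf16-mixed"` on Ampere or newer GPUs (A100, RTX 30/40 series); BF16 has the same exponent range as FP32, so it needs no loss scaling, at the cost of fewer mantissa bits than FP16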
- **Manual Control**: Fine-grained precision control for specific operations when needed
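- **Gradient Accumulation**: `Trainer(accumulate_grad_batches=N)` already normalizes the loss by N, so `training_step` should return the undivided loss; combined with mixed precision, this enables large effective batch sizes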
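A common way to act on these options is to pick the precision string from the GPU's compute capability. The helper below is a hypothetical sketch. In a real notebook, the capability tuple would come from `torch.cuda.get_device_capability()` (`torch.cuda.is_bf16_supported()` is another check). Here it is passed in, so the logic runs without a GPU. FP16 Tensor Cores arrived with compute capability 7.0 (Volta), and BF16 support with 8.0 (Ampere).

```python
from typing import Optional, Tuple


def choose_precision(capability: Optional[Tuple[int, int]]) -> str:
    """Map a CUDA compute capability (major, minor) to a Lightning precision string.

    `capability=None` means no CUDA device is available.
    """
    if capability is None:
        return "32-true"       # no CUDA device: keep full precision
    major, _minor = capability
    if major >= 8:
        return "bf16-mixed"    # Ampere or newer: BF16, no loss scaling needed
    if major >= 7:
        return "16-mixed"      # Volta/Turing: FP16 Tensor Cores with loss scaling
    return "32-true"           # older GPUs: no Tensor Cores


for cap in [None, (6, 1), (7, 0), (7, 5), (8, 0), (8, 6), (9, 0)]:
    print(cap, "->", choose_precision(cap))
```

The returned string can be passed directly to `pl.Trainer(precision=...)`.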

**Performance Optimization:**
- **Memory Strategies**: Gradient checkpointing, efficient attention, model parallelism
- **Hardware Detection**: Automatic capability detection and optimization recommendations
- **Profiling Tools**: Comprehensive performance analysis and bottleneck identification
- **Scaling Strategies**: Dynamic loss scaling and gradient clipping for numerical stability
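Gradient clipping and loss scaling interact. The clipping threshold is meant for true gradients, so gradients must be unscaled first. Lightning handles this ordering for `gradient_clip_val` under `"16-mixed"`. In hand-written PyTorch loops, `GradScaler.unscale_(optimizer)` serves the same purpose before `clip_grad_norm_`. The pure-Python sketch below uses hypothetical helpers with gradients as plain lists, and shows what goes wrong if the order is reversed.

```python
import math
from typing import List


def global_norm(grads: List[float]) -> float:
    """Global L2 norm over all gradient entries."""
    return math.sqrt(sum(g * g for g in grads))


def clip_by_global_norm(grads: List[float], max_norm: float) -> List[float]:
    """Rescale grads so their global L2 norm is at most max_norm."""
    norm = global_norm(grads)
    if norm <= max_norm:
        return list(grads)
    factor = max_norm / norm
    return [g * factor for g in grads]


true_grads = [3.0, 4.0]                 # global norm 5.0
scale = 1024.0                          # current loss scale
scaled_grads = [g * scale for g in true_grads]

# Correct: unscale, then clip to max_norm=1.0
correct = clip_by_global_norm([g / scale for g in scaled_grads], max_norm=1.0)
# Wrong: clip the still-scaled gradients, then unscale
wrong = [g / scale for g in clip_by_global_norm(scaled_grads, max_norm=1.0)]

print(f"correct norm: {global_norm(correct):.4f}")  # clipped to the intended 1.0
print(f"wrong norm:   {global_norm(wrong):.6f}")    # about 1/1024 of the intended value
```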

**Common Issues and Solutions:**
- **NaN/Inf Values**: Detection and logging of non-finite losses and gradients; with FP16, GradScaler automatically skips steps whose gradients overflow and lowers the scale
- **Gradient Underflow**: Automatic loss scaling handles small gradient magnitudes
- **Convergence Issues**: Conservative optimizer settings and learning rate schedules
- **Out-of-Memory Errors**: Gradient accumulation and checkpointing for large models
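When the target batch does not fit in memory, the usual fix is to use the largest micro-batch that fits and make up the difference with accumulation. `GradientAccumulationModel` computes `effective_batch_size // actual_batch_size`, which silently drops the remainder when the sizes do not divide evenly. The hypothetical helper below picks a micro-batch that divides the target exactly. `max_fitting_batch` is an assumed input, found for example by trial or with Lightning's batch-size finder (`Tuner.scale_batch_size`).

```python
from typing import Tuple


def plan_accumulation(effective_batch: int, max_fitting_batch: int) -> Tuple[int, int]:
    """Return (micro_batch, accumulation_steps) with micro_batch * steps == effective_batch.

    Chooses the largest divisor of effective_batch that does not exceed
    max_fitting_batch, so no samples are lost to integer division.
    """
    if effective_batch < 1 or max_fitting_batch < 1:
        raise ValueError("batch sizes must be positive")
    for micro in range(min(effective_batch, max_fitting_batch), 0, -1):
        if effective_batch % micro == 0:
            return micro, effective_batch // micro
    return 1, effective_batch  # not reached: 1 always divides effective_batch


print(plan_accumulation(512, 100))  # micro-batch 64, 8 steps
print(plan_accumulation(512, 512))  # fits: no accumulation needed
print(plan_accumulation(96, 40))    # micro-batch 32, 3 steps
```

The micro-batch would then be the DataLoader's `batch_size`, and the step count would go to `Trainer(accumulate_grad_batches=...)`.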

**Production Best Practices:**
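- **Architecture Choices**: Autocast already runs numerically sensitive ops such as softmax and layer_norm in FP32; LayerNorm, GELU, and a larger optimizer epsilon (as in these examples) are optional safeguards rather than requirements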
- **Training Stability**: Gradient clipping, warmup schedules, conservative learning rates
- **Monitoring**: Comprehensive logging of loss values, gradient norms, and memory usage
- **Validation**: Always compare with FP32 baseline for accuracy verification
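Warmup is commonly used to avoid large, unstable updates early in training, while Adam-style moment estimates are still noisy. The function below sketches one such schedule: linear warmup followed by cosine decay. Its name and defaults are hypothetical. It returns a multiplier in `[min_ratio, 1]`, the form expected by `torch.optim.lr_scheduler.LambdaLR`, so it could be wrapped as `LambdaLR(optimizer, lr_lambda=lambda s: warmup_cosine_factor(s, 500, 10_000))`.

```python
import math


def warmup_cosine_factor(step: int, warmup_steps: int, total_steps: int,
                         min_ratio: float = 0.0) -> float:
    """LR multiplier: linear warmup to 1.0, then cosine decay to min_ratio."""
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * progress))


base_lr = 1e-3
for step in [0, 4, 9, 10, 55, 100]:
    lr = base_lr * warmup_cosine_factor(step, warmup_steps=10, total_steps=100)
    print(f"step {step:3d}: lr = {lr:.2e}")
```

In Lightning, `configure_optimizers` can return `{"optimizer": opt, "lr_scheduler": {"scheduler": sched, "interval": "step"}}` so that the multiplier is applied per optimizer step rather than per epoch.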

**Platform Considerations:**
- **GPU Architecture**: Tensor Core utilization for maximum benefit
- **CUDA Versions**: Compatibility with mixed precision features
- **Memory Capacity**: Balance between model size and batch size
- **Multi-GPU**: Proper scaling across distributed training setups

**Development Workflow:**
1. **Baseline Establishment**: Train with FP32 to establish accuracy targets
2. **Hardware Verification**: Check Tensor Core availability and capabilities  
3. **Gradual Implementation**: Start with `precision="16-mixed"` and monitor closely
4. **Performance Profiling**: Measure speedup and memory improvements
5. **Stability Testing**: Extended training runs to verify convergence
6. **Production Deployment**: Comprehensive monitoring and fallback mechanisms
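Steps 1 and 4 of this workflow come down to comparing an AMP run against the FP32 baseline on both speed and accuracy. A minimal sketch of that check follows. `compare_to_baseline` and its `max_acc_drop` tolerance (0.5 percentage points) are hypothetical choices, and the input numbers are placeholders rather than measured results.

```python
from dataclasses import dataclass


@dataclass
class RunResult:
    name: str
    val_acc: float       # fraction in [0, 1]
    train_time_s: float


def compare_to_baseline(baseline: RunResult, candidate: RunResult,
                        max_acc_drop: float = 0.005) -> dict:
    """Report speedup and accuracy change of `candidate` relative to `baseline`."""
    speedup = baseline.train_time_s / candidate.train_time_s
    acc_delta = candidate.val_acc - baseline.val_acc
    return {
        "speedup": speedup,
        "acc_delta": acc_delta,
        "accepted": acc_delta >= -max_acc_drop,
    }


# Placeholder numbers for illustration only, not measurements
fp32 = RunResult("FP32", val_acc=0.912, train_time_s=120.0)
amp = RunResult("AMP", val_acc=0.910, train_time_s=75.0)
report = compare_to_baseline(fp32, amp)
print(f"speedup {report['speedup']:.2f}x, accuracy change {report['acc_delta'] * 100:+.2f} pts, "
      f"accepted={report['accepted']}")
```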

**Future Considerations:**
- **8-bit Training**: FP8 formats on recent GPUs (e.g., H100) and 8-bit optimizers for further memory reduction
- **Model Compilation**: Integration with `torch.compile` for additional speedups
- **Custom Precision**: Task-specific precision policies for optimal performance
- **Hardware Evolution**: Adaptation to new GPU architectures and features

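**Hardware Requirements:**
- **FP16 Tensor Cores**: CUDA compute capability 7.0+ (V100, T4, RTX 20/30/40 series, A100)
- **BF16 Tensor Cores**: CUDA compute capability 8.0+ (A100, RTX 30/40 series)
- **GPU Memory**: Enough for the model, gradients, optimizer states, and activations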

Mixed precision training is a standard technique in modern deep learning workflows. On supported hardware it can substantially reduce memory use and training time while keeping accuracy close to the FP32 baseline. The key is careful implementation, backed by monitoring and an FP32 comparison, to confirm that the speed and memory gains come without an accuracy loss.