# Devices, Precision, and Training Strategies

**File Location:** `notebooks/05_strategies_and_ddp/11_devices_precision_strategies.ipynb`

## Introduction

This notebook covers Lightning's device management, precision settings, and training strategies. You will learn how to use CPUs and GPUs efficiently and how to choose a precision mode for good training performance.

## Device Management

### Device Detection and Configuration

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

class DeviceAwareModel(pl.LightningModule):
    """Small MLP classifier (128 features -> 10 classes) that reports its device.

    On the first training batch of each epoch it logs ``device_type`` (1.0 when
    the batch lives on CUDA, 0.0 otherwise) and ``device_index`` (the device
    ordinal, or -1 when the device has no index, e.g. CPU).
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters()

        layers = [
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ]
        self.model = nn.Sequential(*layers)

        from torchmetrics import Accuracy

        # Separate metric objects so train and val statistics never mix.
        self.train_acc = Accuracy(task="multiclass", num_classes=10)
        self.val_acc = Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        inputs, targets = batch

        # Only the first batch of an epoch reports device placement.
        if batch_idx == 0:
            device = inputs.device
            self.log('device_type', float(device.type == 'cuda'))
            self.log('device_index', -1 if device.index is None else float(device.index))

        logits = self(inputs)
        loss = F.cross_entropy(logits, targets)
        self.train_acc(logits.argmax(dim=1), targets)

        self.log('train_loss', loss, on_step=True, on_epoch=True)
        self.log('train_acc', self.train_acc, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, targets = batch
        logits = self(inputs)
        loss = F.cross_entropy(logits, targets)
        self.val_acc(logits.argmax(dim=1), targets)

        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_acc', self.val_acc, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Create dataset
def create_device_dataset(num_samples=2000, input_size=128, num_classes=10, seed=42):
    """Build a synthetic classification dataset with banded labels.

    Features are standard-normal. Each label comes from bucketing a random
    linear projection of its row into ``num_classes`` equal-width bins that
    span the projection's observed range.

    Note: this reseeds the *global* torch RNG via ``torch.manual_seed``, so
    later random operations in the session (e.g. ``random_split``, model
    initialisation) become reproducible as well.

    Args:
        num_samples: number of rows to generate.
        input_size: number of features per row.
        num_classes: number of label classes (must be >= 1).
        seed: value passed to ``torch.manual_seed``.

    Returns:
        ``(X, y)``: a float tensor of shape ``(num_samples, input_size)`` and
        int64 labels of shape ``(num_samples,)`` with values in
        ``[0, num_classes - 1]``.

    Raises:
        ValueError: if ``num_classes`` is smaller than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    torch.manual_seed(seed)
    X = torch.randn(num_samples, input_size)
    weights = torch.randn(input_size)
    logits = X @ weights

    # min()/max() are undefined on an empty tensor.
    if logits.numel() == 0:
        return X, torch.zeros(0, dtype=torch.long)

    low = logits.min()
    span = logits.max() - low
    if span == 0:
        # All projections are equal (e.g. a single sample), so there is only one bin.
        return X, torch.zeros(num_samples, dtype=torch.long)

    # A bin width of span/num_classes gives num_classes equal bins. Only the
    # maximum itself can land on index num_classes, and the clamp folds it into
    # the top bin. (Dividing by num_classes - 1 would leave the top class
    # holding nothing but that single maximum sample.)
    y = torch.div(logits - low, span / num_classes, rounding_mode='floor').long()
    y = torch.clamp(y, 0, num_classes - 1)
    return X, y

# Build the shared train/validation loaders used by every experiment below.
# Statement order matters: create_device_dataset() reseeds the global torch RNG
# and random_split() (no explicit generator) then draws from that same RNG, so
# the split is reproducible.
X, y = create_device_dataset()
dataset = TensorDataset(X, y)
train_size = int(0.8 * len(dataset))  # 80/20 train/validation split
val_size = len(dataset) - train_size  # remainder, so the two sizes always sum to len(dataset)
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)  # reshuffled each epoch
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)  # fixed order for evaluation

# Check available devices
def check_available_devices():
    """Print a summary of the compute devices PyTorch can see.

    Reports CUDA availability and, per GPU, its name, total memory and compute
    capability. It also reports the host's logical CPU count, the size of
    PyTorch's intra-op thread pool, and whether Apple's MPS backend is
    available.
    """
    import os

    print("=== Device Availability ===")
    print(f"CUDA available: {torch.cuda.is_available()}")

    if torch.cuda.is_available():
        print(f"CUDA devices: {torch.cuda.device_count()}")
        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            print(f"  GPU {i}: {props.name}")
            print(f"    Memory: {props.total_memory / 1024**3:.1f} GB")
            print(f"    Compute Capability: {props.major}.{props.minor}")

    # CPU info. os.cpu_count() is the number of logical CPUs. torch.get_num_threads()
    # is PyTorch's intra-op thread pool size, which can be set independently
    # (torch.set_num_threads / OMP_NUM_THREADS), so it is reported separately.
    print(f"CPU cores (logical): {os.cpu_count()}")
    print(f"PyTorch intra-op threads: {torch.get_num_threads()}")

    # MPS (Apple Silicon) support; older torch builds lack torch.backends.mps.
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        print("Apple MPS available: Yes")
    else:
        print("Apple MPS available: No")

check_available_devices()
print("✓ Device setup completed")
```

### Training on Different Devices

```python
def train_on_different_devices():
    """Fit a fresh DeviceAwareModel on every available device configuration.

    CPU (fp32) always runs. A single GPU (fp16 AMP) runs when CUDA is present,
    two GPUs when more than one is visible, and Apple MPS (fp32) when
    available. Each run is a short 2-epoch fit (30 train / 10 val batches).

    Returns:
        dict mapping configuration name -> {'time', 'accuracy', 'config'} for
        successful runs, or {'error': message} when the Trainer raised.
    """
    import time

    has_mps = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()

    candidates = [
        {'name': 'CPU', 'accelerator': 'cpu', 'devices': 1, 'precision': '32-true'},
    ]
    if torch.cuda.is_available():
        candidates.append(
            {'name': 'Single GPU', 'accelerator': 'gpu', 'devices': 1, 'precision': '16-mixed'}
        )
        if torch.cuda.device_count() > 1:
            candidates.append(
                {'name': 'Multi GPU', 'accelerator': 'gpu', 'devices': 2, 'precision': '16-mixed'}
            )
    if has_mps:
        candidates.append(
            {'name': 'Apple MPS', 'accelerator': 'mps', 'devices': 1, 'precision': '32-true'}
        )

    outcome = {}
    for cfg in candidates:
        label = cfg['name']
        print(f"\n=== Training on {label} ===")

        model = DeviceAwareModel()

        try:
            trainer = pl.Trainer(
                max_epochs=2,
                accelerator=cfg['accelerator'],
                devices=cfg['devices'],
                precision=cfg['precision'],
                enable_checkpointing=False,
                logger=False,
                enable_progress_bar=False,
                limit_train_batches=30,
                limit_val_batches=10,
            )

            t0 = time.time()
            trainer.fit(model, train_loader, val_loader)
            elapsed = time.time() - t0

            acc = trainer.callback_metrics.get('val_acc', 0)
            outcome[label] = {'time': elapsed, 'accuracy': acc, 'config': cfg}

            print(f"Training time: {elapsed:.2f}s")
            print(f"Final accuracy: {acc:.4f}")

        except Exception as exc:
            print(f"Failed to train on {label}: {exc}")
            outcome[label] = {'error': str(exc)}

    return outcome

device_results = train_on_different_devices()

# Display comparison
print(f"\n📊 Device Performance Comparison:")
successful_runs = {k: v for k, v in device_results.items() if 'error' not in v}

if successful_runs:
    # Speedup = baseline_time / run_time, so values > 1 mean faster than the
    # baseline. The CPU run is the baseline when it succeeded. Otherwise the
    # slowest successful run is used, so ratios remain >= 1x.
    # (Dividing the *fastest* time by each run's time gives values <= 1,
    # which is a relative speed, not a speedup.)
    if 'CPU' in successful_runs:
        baseline_name = 'CPU'
    else:
        baseline_name = max(successful_runs, key=lambda name: successful_runs[name]['time'])
    baseline_time = successful_runs[baseline_name]['time']
    print(f"(speedup relative to {baseline_name})")

    for device, result in successful_runs.items():
        speedup = baseline_time / result['time'] if result['time'] > 0 else float('inf')
        print(f"{device:12} | Time: {result['time']:5.2f}s | Speedup: {speedup:.2f}x | Acc: {result['accuracy']:.4f}")

# Report failed configurations instead of silently dropping them.
for device, result in device_results.items():
    if 'error' in result:
        print(f"{device:12} | ERROR: {result['error']}")
```

## Precision Strategies

### Comparing Different Precision Modes

```python
class PrecisionTestModel(pl.LightningModule):
    """MLP with BatchNorm and Dropout, used to compare Lightning precision modes.

    Besides the usual loss/accuracy logging, it appends every training loss
    (as a Python float) to ``self.loss_values``. Every 25th batch, when the
    trainer's precision plugin exposes a non-None ``scaler``, it also records
    and logs the scaler's current scale in ``self.gradient_scales``.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters()

        def hidden(n_in, n_out):
            # Linear -> BatchNorm -> ReLU -> Dropout. BatchNorm statistics are
            # the part most sensitive to reduced precision.
            return [nn.Linear(n_in, n_out), nn.BatchNorm1d(n_out), nn.ReLU(), nn.Dropout(0.2)]

        self.model = nn.Sequential(*hidden(128, 512), *hidden(512, 256), nn.Linear(256, 10))

        from torchmetrics import Accuracy
        self.train_acc = Accuracy(task="multiclass", num_classes=10)
        self.val_acc = Accuracy(task="multiclass", num_classes=10)

        # Per-instance histories read back after trainer.fit().
        self.gradient_scales = []
        self.loss_values = []

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        logits = self(inputs)
        loss = F.cross_entropy(logits, targets)

        self.loss_values.append(loss.item())
        self.train_acc(logits.argmax(dim=1), targets)

        if batch_idx % 25 == 0:
            self.log('loss_magnitude', torch.log10(loss.abs() + 1e-8))

            # Only AMP precision plugins carry a GradScaler. It is absent or None otherwise.
            plugin = getattr(self.trainer, 'precision_plugin', None)
            scaler = getattr(plugin, 'scaler', None)
            if scaler is not None:
                current_scale = scaler.get_scale()
                self.gradient_scales.append(current_scale)
                self.log('gradient_scale', current_scale)

        self.log('train_loss', loss, on_step=True, on_epoch=True)
        self.log('train_acc', self.train_acc, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, targets = batch
        logits = self(inputs)
        loss = F.cross_entropy(logits, targets)
        self.val_acc(logits.argmax(dim=1), targets)

        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_acc', self.val_acc, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        # eps=1e-4 is far above AdamW's 1e-8 default. It is intended as a
        # stability margin for reduced precision, but it applies to every run,
        # including FP32.
        return torch.optim.AdamW(self.parameters(), lr=1e-3, weight_decay=1e-4, eps=1e-4)

def compare_precision_strategies():
    """Train PrecisionTestModel once per precision mode on a single GPU.

    Tries FP32 ('32-true') and FP16 AMP ('16-mixed'), plus bfloat16 AMP
    ('bf16-mixed') when ``torch.cuda.is_bf16_supported()``. Returns ``{}``
    without training when CUDA is unavailable.

    Returns:
        dict mapping mode name -> {'time', 'accuracy', 'final_loss',
        'avg_loss', 'loss_std', 'gradient_scales'}, where 'gradient_scales' is
        the number of recorded GradScaler samples. A run that raised maps to
        {'error': message}.
    """
    if not torch.cuda.is_available():
        print("⚠️ CUDA not available, skipping precision comparison")
        return {}

    import time

    modes = [
        {'name': 'FP32', 'precision': '32-true'},
        {'name': 'Mixed Precision', 'precision': '16-mixed'},
    ]
    if torch.cuda.is_bf16_supported():
        modes.append({'name': 'BFloat16', 'precision': 'bf16-mixed'})

    summary = {}
    for mode in modes:
        label = mode['name']
        print(f"\n=== Testing {label} Precision ===")

        model = PrecisionTestModel()

        try:
            trainer = pl.Trainer(
                max_epochs=3,
                accelerator='gpu',
                devices=1,
                precision=mode['precision'],
                enable_checkpointing=False,
                logger=False,
                enable_progress_bar=False,
                limit_train_batches=50,
                limit_val_batches=15,
            )

            started = time.time()
            trainer.fit(model, train_loader, val_loader)
            elapsed = time.time() - started

            metrics = trainer.callback_metrics
            accuracy = metrics.get('val_acc', 0)
            last_val_loss = metrics.get('val_loss', float('inf'))

            history = model.loss_values
            mean_loss = np.mean(history) if history else 0
            spread = np.std(history) if history else 0

            summary[label] = {
                'time': elapsed,
                'accuracy': accuracy,
                'final_loss': last_val_loss,
                'avg_loss': mean_loss,
                'loss_std': spread,
                'gradient_scales': len(model.gradient_scales),
            }

            print(f"Training time: {elapsed:.2f}s")
            print(f"Final accuracy: {accuracy:.4f}")
            print(f"Loss stability (std): {spread:.6f}")

        except Exception as err:
            print(f"Failed with {label}: {err}")
            summary[label] = {'error': str(err)}

    return summary

precision_results = compare_precision_strategies()

# Display precision comparison (runs that raised are left out of the table)
print(f"\n📊 Precision Strategy Comparison:")
successful_runs = {name: res for name, res in precision_results.items() if 'error' not in res}

for precision, result in successful_runs.items():
    print(f"{precision:15} | Time: {result['time']:5.2f}s | Acc: {result['accuracy']:.4f} | Loss Std: {result['loss_std']:.6f}")
```

## Training Strategies

### Strategy Selection and Configuration

```python
class StrategyTestModel(pl.LightningModule):
    """Four-layer MLP classifier used to compare Lightning training strategies.

    On the first batch of each epoch it stores the active strategy's class
    name in ``self.strategy_info['strategy_name']``. It also logs
    ``device_count`` (the number of devices this Trainer run uses) and
    ``strategy_hash``, a small numeric tag derived deterministically from the
    strategy class name.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters()

        # Slightly larger model to benefit from parallelization
        self.model = nn.Sequential(
            nn.Linear(128, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 10)
        )

        from torchmetrics import Accuracy
        self.train_acc = Accuracy(task="multiclass", num_classes=10)
        self.val_acc = Accuracy(task="multiclass", num_classes=10)

        # Strategy-specific tracking
        self.strategy_info = {}

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        preds = torch.argmax(logits, dim=1)
        self.train_acc(preds, y)

        # Log strategy-specific information
        if batch_idx == 0:
            import zlib

            strategy = type(self.trainer.strategy).__name__
            self.strategy_info['strategy_name'] = strategy

            # trainer.num_devices is what this run actually uses.
            # torch.cuda.device_count() would report every visible GPU,
            # even on a single-device run.
            self.log('device_count', float(self.trainer.num_devices))
            # crc32 is deterministic. The built-in hash() of a str is salted per
            # process (PYTHONHASHSEED), so it varied between sessions and
            # between DDP ranks.
            self.log('strategy_hash', float(zlib.crc32(strategy.encode('utf-8')) % 1000))

        self.log('train_loss', loss, on_step=True, on_epoch=True)
        self.log('train_acc', self.train_acc, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        preds = torch.argmax(logits, dim=1)
        self.val_acc(preds, y)

        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_acc', self.val_acc, on_epoch=True, prog_bar=True)

        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

def test_training_strategies():
    """Fit StrategyTestModel once per training strategy usable on this machine.

    A single-device baseline (strategy='auto') always runs. When two or more
    GPUs are visible, 2-GPU DDP also runs. Lightning 2.x removed
    ``strategy='dp'`` (DataParallel), so that configuration is not attempted.

    In an interactive session (``sys.ps1`` set or ``sys.flags.interactive``,
    the same check Lightning applies), plain ``'ddp'`` is rejected because it
    relaunches the script, so the fork-based ``'ddp_notebook'`` is chosen
    instead. Fork-based DDP generally cannot start once CUDA has been
    initialised in this process (for example by the GPU baseline run). Such
    failures are caught and reported like any other error.

    Returns:
        dict mapping configuration name -> {'time', 'accuracy', 'strategy',
        'devices_used'}, or {'error': message} when the run raised.
    """
    import sys
    import time

    interactive = hasattr(sys, 'ps1') or bool(sys.flags.interactive)

    strategy_configs = [
        {'name': 'Single Device', 'strategy': 'auto', 'devices': 1},
    ]

    # Add multi-GPU DDP if available ('dp' no longer exists in Lightning 2.x).
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        strategy_configs.append({
            'name': 'DDP',
            'strategy': 'ddp_notebook' if interactive else 'ddp',
            'devices': 2,
        })

    results = {}

    for config in strategy_configs:
        print(f"\n=== Testing {config['name']} Strategy ===")

        model = StrategyTestModel()

        # Resolve hardware for this run into locals instead of mutating `config`.
        if torch.cuda.is_available():
            accelerator = 'gpu'
            devices = min(config['devices'], torch.cuda.device_count())
            strategy = config['strategy']
        else:
            accelerator = 'cpu'
            devices = 1
            strategy = 'auto'  # Only the auto strategy is used on CPU

        try:
            trainer = pl.Trainer(
                max_epochs=2,
                accelerator=accelerator,
                devices=devices,
                strategy=strategy,
                enable_checkpointing=False,
                logger=False,
                enable_progress_bar=False,
                limit_train_batches=25,
                limit_val_batches=10
            )

            start_time = time.time()
            trainer.fit(model, train_loader, val_loader)
            training_time = time.time() - start_time

            final_acc = trainer.callback_metrics.get('val_acc', 0)

            results[config['name']] = {
                'time': training_time,
                'accuracy': final_acc,
                'strategy': strategy,
                'devices_used': devices
            }

            print(f"Strategy: {strategy}")
            print(f"Devices used: {devices}")
            print(f"Training time: {training_time:.2f}s")
            print(f"Final accuracy: {final_acc:.4f}")

        except Exception as e:
            print(f"Failed with {config['name']} strategy: {e}")
            results[config['name']] = {'error': str(e)}

    return results

strategy_results = test_training_strategies()

# Display strategy comparison: failed runs print their error, the rest
# show time, accuracy and accuracy-per-second ("efficiency").
print(f"\n📊 Training Strategy Comparison:")
for strategy, result in strategy_results.items():
    if 'error' in result:
        print(f"{strategy:15} | ERROR: {result['error']}")
        continue
    efficiency = result['accuracy'] / result['time'] if result['time'] > 0 else 0
    print(f"{strategy:15} | Time: {result['time']:5.2f}s | Acc: {result['accuracy']:.4f} | Efficiency: {efficiency:.3f}")
```

## Advanced Device and Strategy Configuration

### Custom Strategy Implementation

```python
class CustomTrainingStrategy:
    """Pick Trainer settings from the available hardware and a coarse model size."""

    @staticmethod
    def get_optimal_config(task_type='classification', model_size='medium'):
        """Return accelerator/devices/precision/strategy settings for this machine.

        Args:
            task_type: currently unused. Kept for API compatibility.
            model_size: 'small', 'medium' or 'large'. This caps the GPU count
                at 1, 2 or 4 respectively (never more than are visible).
                Unknown values use one device.

        Returns:
            dict with keys 'accelerator', 'devices', 'precision', 'strategy',
            suitable for ``pl.Trainer(**config)``.
        """
        import sys

        config = {
            'accelerator': 'auto',
            'devices': 'auto',
            'precision': '32-true',
            'strategy': 'auto'
        }

        if torch.cuda.is_available():
            gpu_count = torch.cuda.device_count()
            config['accelerator'] = 'gpu'

            # Larger models may use more GPUs, but never more than are present.
            # (Previously 'large' with 2-3 GPUs fell back to a single device.)
            max_devices = {'small': 1, 'medium': 2, 'large': 4}.get(model_size, 1)
            config['devices'] = max(1, min(max_devices, gpu_count))

            # Precision based on GPU capability. bf16 keeps fp32's exponent
            # range and needs no loss scaling. Otherwise use fp16 AMP.
            config['precision'] = 'bf16-mixed' if torch.cuda.is_bf16_supported() else '16-mixed'

            # Strategy selection
            if config['devices'] > 1:
                # Lightning rejects script-relaunching 'ddp' in interactive
                # sessions (same sys.ps1 / sys.flags.interactive check) and
                # expects the fork-based 'ddp_notebook' there.
                interactive = hasattr(sys, 'ps1') or bool(sys.flags.interactive)
                config['strategy'] = 'ddp_notebook' if interactive else 'ddp'
            else:
                config['strategy'] = 'auto'

        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            config['accelerator'] = 'mps'
            config['devices'] = 1
            # NOTE(review): kept at full precision for MPS; confirm whether the
            # installed Lightning/torch versions support mixed precision there.
            config['precision'] = '32-true'

        else:  # CPU fallback
            config['accelerator'] = 'cpu'
            config['devices'] = 1
            config['precision'] = '32-true'

        return config

    @staticmethod
    def create_trainer(config, **trainer_kwargs):
        """Create a ``pl.Trainer`` from ``config`` plus extra keyword arguments.

        Precedence (lowest to highest): built-in defaults, ``trainer_kwargs``,
        then ``config``. Hardware settings in ``config`` therefore always win.
        """
        default_kwargs = {
            'max_epochs': 10,
            'enable_checkpointing': True,
            'log_every_n_steps': 50,
            'check_val_every_n_epoch': 1
        }

        # Merge configurations
        final_kwargs = {**default_kwargs, **trainer_kwargs}
        final_kwargs.update(config)

        return pl.Trainer(**final_kwargs)

# Test custom strategy configuration
def test_custom_strategy():
    """Print and smoke-test the recommended configuration for each model size.

    For 'small', 'medium' and 'large' this prints the dict returned by
    ``CustomTrainingStrategy.get_optimal_config``. It then attempts a tiny
    one-epoch fit of DeviceAwareModel (5 train / 2 val batches) and reports
    success or the exception message.
    """
    print("=== Custom Strategy Configuration ===")

    for model_size in ('small', 'medium', 'large'):
        recommended = CustomTrainingStrategy.get_optimal_config(model_size=model_size)
        print(f"\n{model_size.upper()} Model Configuration:")
        for setting, choice in recommended.items():
            print(f"  {setting}: {choice}")

        # Smoke-test the configuration with a very short run.
        try:
            model = DeviceAwareModel()
            trainer = CustomTrainingStrategy.create_trainer(
                recommended,
                max_epochs=1,
                limit_train_batches=5,
                limit_val_batches=2,
                enable_checkpointing=False,
                logger=False,
                enable_progress_bar=False,
            )
            trainer.fit(model, train_loader, val_loader)
        except Exception as err:
            print(f"  ❌ Configuration failed: {err}")
        else:
            print(f"  ✓ Configuration valid and tested")

test_custom_strategy()
```

## Production Configuration Guide

### Complete Configuration Templates

```python
def get_production_configs():
    """Return named ``pl.Trainer`` keyword-argument templates for common scenarios.

    Returns:
        dict mapping scenario name ('development', 'single_gpu_production',
        'multi_gpu_production', 'cpu_production', 'debugging') to a dict with
        a human-readable 'description' and a 'config' dict intended for
        ``pl.Trainer(**config)``. A fresh dict is built on every call, so
        callers may modify the result.
    """
    configs = {
        'development': {
            'description': 'Fast iteration for development',
            'config': {
                'accelerator': 'auto',
                'devices': 1,
                'precision': '32-true',
                'strategy': 'auto',
                'max_epochs': 5,
                'limit_train_batches': 100,
                'limit_val_batches': 50,
                'fast_dev_run': False,
                'enable_progress_bar': True,
                'log_every_n_steps': 10
            }
        },

        'single_gpu_production': {
            'description': 'Single GPU production training',
            'config': {
                'accelerator': 'gpu',
                'devices': 1,
                'precision': '16-mixed',
                'strategy': 'auto',
                'max_epochs': 100,
                'gradient_clip_val': 1.0,
                'accumulate_grad_batches': 1,
                'enable_progress_bar': True,
                'log_every_n_steps': 100,
                'check_val_every_n_epoch': 5
            }
        },

        'multi_gpu_production': {
            'description': 'Multi-GPU production training',
            'config': {
                'accelerator': 'gpu',
                'devices': 4,
                'precision': '16-mixed',
                'strategy': 'ddp',  # script launch; notebooks need 'ddp_notebook'
                'max_epochs': 100,
                'gradient_clip_val': 1.0,
                'accumulate_grad_batches': 2,
                'enable_progress_bar': True,
                'log_every_n_steps': 50,
                'check_val_every_n_epoch': 2,
                'sync_batchnorm': True
            }
        },

        'cpu_production': {
            'description': 'CPU-only production training',
            'config': {
                'accelerator': 'cpu',
                'devices': 1,
                'precision': '32-true',
                'strategy': 'auto',
                'max_epochs': 50,
                'gradient_clip_val': 1.0,
                'accumulate_grad_batches': 4,
                'enable_progress_bar': True,
                'log_every_n_steps': 25,
                'check_val_every_n_epoch': 1
            }
        },

        'debugging': {
            'description': 'Configuration for debugging issues',
            'config': {
                'accelerator': 'auto',
                'devices': 1,
                'precision': '32-true',  # Full precision for stability
                'strategy': 'auto',
                'max_epochs': 5,
                'fast_dev_run': False,
                'overfit_batches': 10,  # Overfit small subset
                # NOTE(review): with overfit_batches > 0, Lightning appears to use
                # it as the train/val batch limit, which may make the two limits
                # below ineffective. Confirm before relying on them.
                'limit_train_batches': 20,
                'limit_val_batches': 10,
                'enable_progress_bar': True,
                'log_every_n_steps': 1,
                'detect_anomaly': True
                # 'track_grad_norm' was removed from the Trainer in Lightning 2.0
                # (passing it raises TypeError). To monitor gradient norms, log
                # pytorch_lightning.utilities.grad_norm(self, norm_type=2) from
                # the LightningModule's on_before_optimizer_step hook instead.
            }
        }
    }

    return configs

def print_configuration_guide():
    """Print every template from get_production_configs() plus a selection guide.

    For each scenario this prints the description, each Trainer setting, and a
    copy-pasteable ``pl.Trainer(**{...})`` line. A short summary of when to
    use each scenario follows.
    """
    rule = "=" * 80

    print(rule)
    print("🚀 LIGHTNING CONFIGURATION GUIDE")
    print(rule)

    for scenario, info in get_production_configs().items():
        heading = scenario.upper().replace('_', ' ')
        settings = info['config']

        print(f"\n📋 {heading} CONFIGURATION")
        print(f"Description: {info['description']}")
        print("-" * 60)
        for key, value in settings.items():
            print(f"  {key:25} = {value}")

        print("\nUsage:")
        print(f"  trainer = pl.Trainer(**{settings})")

    guide_lines = (
        "💡 Configuration Selection Guide:",
        "  • Development: Fast iteration, debugging enabled",
        "  • Single GPU: Production training on one GPU",
        "  • Multi GPU: Distributed training on multiple GPUs",
        "  • CPU: Training without GPU acceleration",
        "  • Debugging: Troubleshooting training issues",
    )
    print("\n" + rule)
    for line in guide_lines:
        print(line)
    print(rule)

# Display configuration guide
print_configuration_guide()

# Test one of the configurations
def test_production_config():
    """Run a short fit with the 'development' template and print its accuracy.

    The development template uses ``accelerator='auto'`` with a single device,
    so it is safe on any machine. Checkpointing and logging are disabled for
    the test run.
    """
    dev_config = get_production_configs()['development']['config']

    print("\n=== Testing Development Configuration ===")

    model = DeviceAwareModel()
    trainer = pl.Trainer(**dev_config, enable_checkpointing=False, logger=False)
    trainer.fit(model, train_loader, val_loader)

    final_acc = trainer.callback_metrics.get('val_acc', 0)
    print("✓ Development configuration test completed")
    print(f"Final accuracy: {final_acc:.4f}")

test_production_config()
```

## Summary

This notebook covered devices, precision, and training strategies in PyTorch Lightning:

1. **Device Management**: Automatic detection and configuration of CPUs, GPUs, and MPS devices
2. **Precision Strategies**: Comparing FP32, mixed precision, and BFloat16 training
3. **Training Strategies**: Single device, data parallel, and distributed data parallel approaches  
4. **Custom Configuration**: Building optimal configurations based on hardware and model requirements
5. **Production Templates**: Ready-to-use configurations for different deployment scenarios

Key device considerations:
- **CPU**: Always available, good for small models and debugging
- **GPU**: Best performance for most deep learning tasks
- **MPS**: Apple Silicon acceleration, growing support
- **Multi-GPU**: Significant speedup for larger models and datasets

Precision trade-offs:
- **FP32**: Maximum accuracy, higher memory usage
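- **Mixed Precision (FP16)**: Often up to ~2x faster on GPUs with Tensor Cores, usually with minimal accuracy loss (Lightning applies gradient scaling automatically)
- **BFloat16**: Same exponent range as FP32, so it is numerically more robust than FP16 and needs no loss scaling; requires hardware support (e.g. NVIDIA Ampere or newer)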

Strategy selection:
- **Single Device**: Simplest setup, good for most cases
- **Data Parallel (`dp`)**: Single-process multi-GPU with notable limitations; removed in Lightning 2.0, so use DDP instead
- **DDP**: Best multi-GPU strategy, more complex setup

Production best practices:
- Choose precision based on model sensitivity and hardware
- Use DDP for multi-GPU production training
- Enable gradient clipping for training stability
- Monitor GPU utilization and memory usage
- Test configurations thoroughly before deployment
- Plan for debugging with appropriate fallback configs

Next notebook: We'll explore distributed data parallel (DDP) training in detail.