# File Location: notebooks/12_ddp_single_node_walkthrough.ipynb

# Distributed Data Parallel (DDP) Single Node Walkthrough

This notebook demonstrates how to implement Distributed Data Parallel training on a single node with multiple GPUs using PyTorch Lightning. We'll cover the setup, configuration, and best practices for multi-GPU training.

## Learning Objectives
- Understand DDP concepts and benefits
- Implement multi-GPU training with Lightning
- Handle data loading and synchronization
- Monitor and optimize distributed training performance

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import os

# Report what CUDA hardware PyTorch can see from this process
gpu_count = torch.cuda.device_count()
cuda_ready = torch.cuda.is_available()
print(f"Available GPUs: {gpu_count}")
print(f"CUDA available: {cuda_ready}")
```

## 1. Understanding DDP Basics

```python
# Glossary of the DDP vocabulary used in the rest of the walkthrough
class DDPConcepts:
    """
    Terminology reference for DDP:
    1. Process Groups: Groups of processes that communicate
    2. Rank: Unique identifier for each process
    3. World Size: Total number of processes
    4. Local Rank: Process rank within a node
    5. All-Reduce: Synchronization operation across processes
    """

    @staticmethod
    def explain_ddp():
        """Print one "<concept>: <explanation>" line per core DDP idea."""
        glossary = (
            ("Data Parallelism", "Each GPU processes a different batch subset"),
            ("Gradient Synchronization", "Gradients averaged across all GPUs"),
            ("Model Replication", "Same model on each GPU"),
            ("Communication Backend", "NCCL for GPU, Gloo for CPU"),
        )

        for term, description in glossary:
            print(f"{term}: {description}")

DDPConcepts.explain_ddp()
```

## 2. Dataset Preparation for DDP

```python
class CIFAR10Dataset(Dataset):
    """Thin ``Dataset`` wrapper that delegates length and indexing to torchvision's CIFAR10."""

    def __init__(self, train=True, transform=None):
        # Build the wrapped torchvision dataset under ./data with download enabled
        cifar_kwargs = dict(root='./data', train=train, download=True, transform=transform)
        self.dataset = torchvision.datasets.CIFAR10(**cifar_kwargs)

    def __len__(self):
        wrapped = self.dataset
        return len(wrapped)

    def __getitem__(self, idx):
        # Whatever the wrapped dataset yields for this index is passed through untouched
        sample = self.dataset[idx]
        return sample

# Per-channel normalization statistics shared by the train and validation pipelines
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random flip + padded random crop, then tensor conversion and normalization
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
])

# Validation pipeline: no augmentation, only tensor conversion and normalization
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
])

# Instantiate both splits
train_dataset = CIFAR10Dataset(train=True, transform=train_transform)
val_dataset = CIFAR10Dataset(train=False, transform=val_transform)

for split_label, split_data in (("Training", train_dataset), ("Validation", val_dataset)):
    print(f"{split_label} samples: {len(split_data)}")
```

## 3. Lightning Module for DDP

```python
class DDPResNet(pl.LightningModule):
    """ResNet-18 classifier prepared for Lightning DDP training.

    Logs ``train_loss``/``train_acc`` and ``val_loss``/``val_acc``; every
    logged value is reduced across processes via ``sync_dist=True``.

    Args:
        num_classes: Output size of the replaced final linear layer.
        learning_rate: Base (single-process) learning rate. It is multiplied
            by the number of training processes in ``configure_optimizers``.
    """

    def __init__(self, num_classes=10, learning_rate=0.1):
        super().__init__()
        self.save_hyperparameters()

        # Model architecture: randomly initialised ResNet-18 with a new head.
        # NOTE(review): newer torchvision prefers `weights=None`; `pretrained=False`
        # is kept so older torchvision releases keep working - confirm target version.
        self.backbone = torchvision.models.resnet18(pretrained=False)
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, num_classes)

        # Loss function
        self.criterion = nn.CrossEntropyLoss()

        # Accuracy is computed directly from the logits (see _batch_accuracy).
        # The `pl.metrics` namespace is not part of current Lightning releases
        # (metrics moved to the separate torchmetrics package), so relying on
        # it made construction fail.

    def forward(self, x):
        return self.backbone(x)

    @staticmethod
    def _batch_accuracy(logits, targets):
        """Return the fraction of rows whose arg-max class equals the target."""
        return (logits.argmax(dim=1) == targets).float().mean()

    def _attached_trainer(self):
        """Return the attached Trainer, or None when the module is used standalone."""
        try:
            return self.trainer
        except RuntimeError:
            # Some Lightning versions raise instead of returning None when detached
            return None

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = self._batch_accuracy(logits, y)

        # Log metrics
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log('train_acc', acc, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = self._batch_accuracy(logits, y)

        # Log metrics
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log('val_acc', acc, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

        return loss

    def configure_optimizers(self):
        """Build SGD with a process-count-scaled LR and a per-epoch cosine schedule."""
        trainer = self._attached_trainer()

        # Scale by the number of processes actually training. Falling back to
        # the visible GPU count keeps the old behaviour without a trainer, and
        # the clamp avoids a learning rate of 0 on CPU-only machines.
        world_size = getattr(trainer, 'world_size', None) or torch.cuda.device_count()
        scaled_lr = self.hparams.learning_rate * max(1, world_size)

        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=scaled_lr,
            momentum=0.9,
            weight_decay=5e-4
        )

        # Let the cosine cycle span the configured run; 200 is the previous
        # hard-coded value, used when max_epochs is unknown or unbounded.
        max_epochs = getattr(trainer, 'max_epochs', None)
        t_max = max_epochs if max_epochs is not None and max_epochs > 0 else 200

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=t_max
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch'
            }
        }

# Initialize model
model = DDPResNet(num_classes=10, learning_rate=0.1)
```

## 4. Data Module for DDP

```python
class CIFAR10DataModule(pl.LightningDataModule):
    """CIFAR-10 data module for DDP training.

    Args:
        batch_size: Batch size handed to each DataLoader. Under DDP every
            process builds its own loader, so this is the per-process size.
        num_workers: Worker processes per DataLoader; 0 loads in the main
            process.
    """

    def __init__(self, batch_size=128, num_workers=4):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        """Download both CIFAR-10 splits ahead of ``setup``.

        Lightning runs this hook outside the per-process ``setup`` calls, so
        the DDP ranks do not all start the same download at once.
        """
        for is_train in (True, False):
            torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` ('fit', 'validate' or None)."""
        if stage in ('fit', None):
            self.train_dataset = CIFAR10Dataset(train=True, transform=train_transform)
        # The validation split is needed for fitting and for standalone validation
        if stage in ('fit', 'validate', None):
            self.val_dataset = CIFAR10Dataset(train=False, transform=val_transform)

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader with the settings shared by both splits."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            # Pinned host memory only helps host-to-GPU copies
            pin_memory=torch.cuda.is_available(),
            # DataLoader rejects persistent_workers=True when num_workers == 0
            persistent_workers=self.num_workers > 0
        )

    def train_dataloader(self):
        # shuffle=True: Lightning handles DDP sampling
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False)

# Initialize data module
data_module = CIFAR10DataModule(batch_size=128, num_workers=4)
```

## 5. DDP Strategy Configuration

```python
# Keyword options for the DDP strategy, collected in one place
ddp_options = {
    "process_group_backend": "nccl",  # NCCL backend for GPU communication
    "find_unused_parameters": False,  # Switch to True if some parameters go unused
    "gradient_as_bucket_view": True,  # Memory optimization
    "static_graph": True,  # Optimization for static computation graphs
}
ddp_strategy = DDPStrategy(**ddp_options)

# Alternative: Auto-detect strategy
# strategy = "auto" if torch.cuda.device_count() > 1 else None

print(f"Using DDP strategy: {ddp_strategy}")
```

## 6. Callbacks for DDP Training

```python
# Checkpointing: keep the three checkpoints with the highest val_acc, plus the latest one
checkpoint_callback = ModelCheckpoint(
    monitor='val_acc',
    mode='max',
    save_top_k=3,
    save_last=True,
    dirpath='./checkpoints',
    filename='ddp-cifar10-{epoch:02d}-{val_acc:.2f}',
)

# Early stopping: watches val_acc (higher is better) with min_delta 0.001 and patience 10
early_stop_callback = EarlyStopping(
    monitor='val_acc',
    mode='max',
    patience=10,
    min_delta=0.001,
)

callbacks = [checkpoint_callback, early_stop_callback]
```

## 7. Trainer Configuration for DDP

```python
# Configure trainer for DDP
# Resolve the hardware once so accelerator, devices, strategy and precision agree
num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
use_gpu = num_gpus > 0

# NOTE(review): when this runs inside a live Jupyter kernel, Lightning may
# require a notebook-compatible launcher (e.g. strategy='ddp_notebook') instead
# of a plain DDPStrategy - confirm for the target environment.
trainer = pl.Trainer(
    max_epochs=50,
    accelerator='gpu' if use_gpu else 'cpu',
    devices=num_gpus if use_gpu else 1,
    strategy=ddp_strategy if num_gpus > 1 else 'auto',
    callbacks=callbacks,
    log_every_n_steps=10,
    val_check_interval=1.0,
    # Mixed precision for faster training on GPU; full precision on the CPU fallback
    precision=16 if use_gpu else 32,
    gradient_clip_val=1.0,
    enable_checkpointing=True,
    enable_progress_bar=True
)

print(f"Trainer configuration:")
print(f"  Accelerator: {trainer.accelerator}")
print(f"  Devices: {trainer.num_devices}")
print(f"  Strategy: {trainer.strategy}")
```

## 8. Training with DDP

```python
# Kick off the run
print("Starting DDP training...")
trainer.fit(model, data_module)

# Reload the weights from the checkpoint the checkpoint callback reports as best
best_ckpt_path = checkpoint_callback.best_model_path
best_model = DDPResNet.load_from_checkpoint(best_ckpt_path)
print(f"Best model loaded from: {best_ckpt_path}")
```

## 9. DDP Performance Monitoring

```python
class DDPProfiler:
    """Helpers that print GPU memory usage and dataloader sizing information."""

    @staticmethod
    def log_gpu_memory():
        """Print allocated and reserved memory, in GB, for each CUDA device (no-op without CUDA)."""
        if not torch.cuda.is_available():
            return

        bytes_per_gb = 1024**3
        for device_idx in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(device_idx) / bytes_per_gb
            cached = torch.cuda.memory_reserved(device_idx) / bytes_per_gb
            print(f"GPU {device_idx}: {allocated:.2f}GB allocated, {cached:.2f}GB cached")

    @staticmethod
    def estimate_training_time(dataloader, num_epochs):
        """Print dataset size, batch size, batches per epoch and the epoch count.

        Only sizing figures are reported; no wall-clock estimate is computed.
        """
        # Gather everything first so nothing is printed if a lookup fails
        summary = (
            ("Samples per epoch", len(dataloader.dataset)),
            ("Batch size", dataloader.batch_size),
            ("Batches per epoch", len(dataloader)),
            ("Total epochs", num_epochs),
        )

        print("Training estimation:")
        for label, value in summary:
            print(f"  {label}: {value}")

# Monitor performance
# NOTE: train_dataloader() reads data_module.train_dataset, which is only
# assigned inside setup(); presumably trainer.fit has already triggered that
# by the time this cell runs - confirm when running cells out of order.
# The 50 passed below mirrors the Trainer's max_epochs=50.
profiler = DDPProfiler()
profiler.log_gpu_memory()
profiler.estimate_training_time(data_module.train_dataloader(), 50)
```

## 10. DDP Troubleshooting and Best Practices

```python
class DDPTroubleshooting:
    """Common DDP issues and solutions"""

    @staticmethod
    def check_environment():
        """Print the distributed-launch environment variables, or 'Not set' for missing ones."""
        # LOCAL_RANK is included because local rank is one of the key DDP
        # concepts this walkthrough introduces.
        env_vars = ['MASTER_ADDR', 'MASTER_PORT', 'WORLD_SIZE', 'RANK', 'LOCAL_RANK']

        print("Environment variables:")
        for var in env_vars:
            value = os.environ.get(var, 'Not set')
            print(f"  {var}: {value}")

    @staticmethod
    def best_practices():
        """Print a numbered list of DDP best practices."""
        practices = [
            "Use sync_dist=True for logging metrics",
            "Scale learning rate by number of GPUs",
            "Use pin_memory=True in DataLoader",
            "Set find_unused_parameters=False if possible",
            "Use NCCL backend for GPU communication",
            # Under DDP each process loads its own batches, so the DataLoader
            # batch size is per GPU; the old "divisible by number of GPUs"
            # advice only applies to single-process DataParallel.
            "Remember DataLoader batch_size is per GPU (effective batch = batch_size x number of GPUs)",
            "Use gradient clipping for stability"
        ]

        print("DDP Best Practices:")
        for i, practice in enumerate(practices, 1):
            print(f"{i}. {practice}")

troubleshooter = DDPTroubleshooting()
troubleshooter.check_environment()
troubleshooter.best_practices()
```

## 11. Performance Comparison: Single GPU vs DDP

```python
# Performance comparison utility
def compare_training_performance(single_gpu_time=100, num_gpus=None, overhead=0.1):
    """Print and return an illustrative single-GPU vs multi-GPU comparison.

    The multi-GPU time is modelled as the single-GPU time divided evenly
    across the GPUs, inflated by a fixed fractional overhead.

    Args:
        single_gpu_time: Single-GPU training time in minutes (example value).
        num_gpus: Number of GPUs to model. Defaults to the number of visible
            CUDA devices; values below 1 are treated as 1 so that machines
            without GPUs do not trigger a division by zero.
        overhead: Fractional multi-GPU overhead (0.1 means 10%).

    Returns:
        Tuple ``(speedup, efficiency)`` where efficiency is a percentage.

    Raises:
        ValueError: If ``single_gpu_time`` is not positive.
    """
    if single_gpu_time <= 0:
        raise ValueError("single_gpu_time must be positive")

    if num_gpus is None:
        num_gpus = torch.cuda.device_count()
    num_gpus = max(1, num_gpus)

    multi_gpu_time = single_gpu_time / num_gpus * (1 + overhead)

    speedup = single_gpu_time / multi_gpu_time
    efficiency = speedup / num_gpus * 100

    print(f"Performance Comparison:")
    print(f"  Single GPU time: {single_gpu_time:.1f} minutes")
    print(f"  Multi-GPU time: {multi_gpu_time:.1f} minutes")
    print(f"  Speedup: {speedup:.2f}x")
    print(f"  Efficiency: {efficiency:.1f}%")

    return speedup, efficiency

if torch.cuda.device_count() > 1:
    speedup, efficiency = compare_training_performance()
```

## 12. Testing DDP Model

```python
# Test the trained model
def test_ddp_model(model, test_dataloader):
    """Evaluate ``model`` on ``test_dataloader`` and report top-1 accuracy.

    Args:
        model: Module mapping an input batch to per-class scores.
        test_dataloader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        Accuracy as a percentage (0.0 when the loader yields no samples).
    """
    model.eval()

    # Run on whatever device the model already lives on, so a CPU-resident
    # model is not fed CUDA tensors just because CUDA happens to be available.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    correct = 0
    total = 0

    with torch.no_grad():
        for x, y in test_dataloader:
            x, y = x.to(device), y.to(device)

            predicted = model(x).argmax(dim=1)
            total += y.size(0)
            correct += (predicted == y).sum().item()

    # Guard the empty-loader case instead of dividing by zero
    accuracy = 100 * correct / total if total else 0.0
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

# Create test dataloader
test_dataloader = DataLoader(val_dataset, batch_size=128, shuffle=False)

# Test the model
if torch.cuda.is_available():
    best_model = best_model.cuda()

test_accuracy = test_ddp_model(best_model, test_dataloader)
```

# Summary

This notebook demonstrated Distributed Data Parallel (DDP) training with PyTorch Lightning on a single node with multiple GPUs. Key takeaways:

## Key Concepts Covered
- **DDP Architecture**: Understanding process groups, ranks, and synchronization
- **Data Loading**: Proper dataset handling for distributed training
- **Model Implementation**: Lightning module optimized for DDP
- **Strategy Configuration**: Setting up DDP strategy with optimal parameters
- **Performance Monitoring**: Tracking GPU usage and training efficiency

## Best Practices Implemented
- Synchronized metric logging across processes
- Learning rate scaling for multiple GPUs
- Memory-efficient data loading with pin_memory
- Mixed precision training for speed
- Proper gradient clipping and checkpointing

## Performance Benefits
- Near-linear speedup with multiple GPUs (reduced by gradient-synchronization overhead)
- Efficient memory utilization across devices
- Automatic gradient synchronization
- Fault tolerance with checkpointing

## Next Steps
- Experiment with multi-node DDP setups
- Try different communication backends (NCCL vs Gloo)
- Implement custom strategies for specific use cases
- Explore advanced features like gradient compression

The DDP approach significantly reduces training time while maintaining model quality, making it essential for large-scale deep learning projects.