# Lesson 95: Distributed Training and Gradient Synchronization

**Week 19 - Day 5: Federated & Distributed Learning**  
**Difficulty Level:** Expert

---

## Introduction

As machine learning models grow larger and datasets expand to billions of samples, training on a single device becomes impractical or impossible. **Distributed training** enables us to leverage multiple GPUs or machines to train models faster and handle larger-scale problems.

In this lesson, we'll explore:
- **Data parallelism** vs **model parallelism**
- **Gradient synchronization** strategies (synchronous vs asynchronous)
- **PyTorch DistributedDataParallel (DDP)** for efficient multi-GPU training
- Real-world considerations for distributed systems

### Why Distributed Training Matters

1. **Speed**: Training time can drop from weeks to days or hours when the workload scales well across devices
2. **Scale**: Handle models too large for single GPU memory
3. **Cost efficiency**: Optimize compute resource utilization
4. **Research**: Enable experiments previously computationally infeasible

### Learning Objectives

By the end of this lesson, you will:
- ✅ Understand data parallelism and model parallelism strategies
- ✅ Implement distributed training with PyTorch DDP
- ✅ Compare synchronous vs asynchronous gradient synchronization
- ✅ Recognize communication bottlenecks and optimization techniques
- ✅ Apply distributed training to real-world scenarios

---

## Core Concepts: Distributed Training Fundamentals

### 1. Data Parallelism vs Model Parallelism

#### Data Parallelism
- **Strategy**: Split the dataset across multiple devices, replicate the model on each device
- **Process**: Each device processes a different batch, computes gradients, then synchronizes
- **Best for**: Models that fit on a single device but need faster training
- **Example**: Training ResNet-50 on ImageNet across 8 GPUs

```
GPU 0: Model Copy → Batch 0 → Gradients ─┐
GPU 1: Model Copy → Batch 1 → Gradients ─┼→ Average Gradients → Update All Models
GPU 2: Model Copy → Batch 2 → Gradients ─┤
GPU 3: Model Copy → Batch 3 → Gradients ─┘
```
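
The flow in the diagram can be sketched in a few lines of plain Python. This is a toy illustration, not how PyTorch implements it: the 1-D model `y = w * x`, the squared-error loss, and the helper names `local_gradient` and `data_parallel_step` are all assumptions made for the example, and the batch is assumed to divide evenly across workers.

```python
# Toy data-parallel step for a 1-D linear model y = w * x with squared-error loss.
# All names here are illustrative and not part of any library.

def local_gradient(w, shard):
    """Mean gradient of 0.5 * (w*x - y)^2 over one worker's shard."""
    return sum((w * x - y) * x for x, y in shard) / len(shard)

def data_parallel_step(w, batch, num_workers, lr):
    """Split the batch into equal shards, average the shard gradients, update w."""
    shard_size = len(batch) // num_workers
    shards = [batch[i * shard_size:(i + 1) * shard_size] for i in range(num_workers)]
    grads = [local_gradient(w, shard) for shard in shards]  # one gradient per "GPU"
    avg_grad = sum(grads) / num_workers                     # the all-reduce average
    return w - lr * avg_grad, avg_grad

batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]  # samples of y = 2x
w_new, g_parallel = data_parallel_step(w=0.0, batch=batch, num_workers=2, lr=0.1)
g_single = local_gradient(0.0, batch)  # same batch processed on one device
print(g_parallel, g_single, w_new)
```

With equal shard sizes, the averaged gradient matches the gradient a single device would compute on the whole batch, which is why every replica can apply the same update.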

#### Model Parallelism
- **Strategy**: Split the model itself across devices (e.g., different layers on different GPUs)
- **Process**: Forward/backward pass flows through devices sequentially
- **Best for**: Models too large to fit on a single device
- **Example**: GPT-3 with 175B parameters

```
Input → GPU 0 (Layers 1-5) → GPU 1 (Layers 6-10) → GPU 2 (Layers 11-15) → Output
```

### 2. Gradient Synchronization Strategies

#### Synchronous Training (All-Reduce)
- All workers compute gradients simultaneously
- Wait for all workers to finish before averaging gradients
- Update all model copies with synchronized gradients
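- **Pros**: Every replica applies the same update, so training behaves like single-device SGD with a larger batch; convergence is usually more stable than with asynchronous updates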
- **Cons**: Speed limited by slowest worker (straggler problem)
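
To get a feel for the straggler cost, here is a small sketch. It assumes, purely for illustration, that per-worker compute times are independent and uniformly distributed between 0.8 and 1.2 time units; real clusters will not follow this distribution. The function names are hypothetical.

```python
import random

def expected_sync_step_time(num_workers, low=0.8, high=1.2):
    """Expected maximum of num_workers i.i.d. Uniform(low, high) compute times."""
    return low + (high - low) * num_workers / (num_workers + 1)

def simulated_sync_step_time(num_workers, trials=5000, low=0.8, high=1.2, seed=0):
    """Monte Carlo estimate of the same quantity."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        # A synchronous step finishes only when the slowest worker finishes
        total += max(rng.uniform(low, high) for _ in range(num_workers))
    return total / trials

for n in (1, 4, 16, 64):
    print(n, round(expected_sync_step_time(n), 3), round(simulated_sync_step_time(n), 3))
```

Under this assumption the mean compute time per worker stays at 1.0, but the synchronous step time creeps toward the worst case (1.2) as workers are added.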

#### Asynchronous Training (Parameter Server)
- Workers compute gradients independently
- Send gradients to parameter server immediately when ready
- Parameter server updates central model asynchronously
- **Pros**: No waiting, higher throughput
- **Cons**: Stale gradients, potential instability
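
The effect of stale gradients can be seen on a one-dimensional toy problem. The sketch below minimizes $f(\theta) = \tfrac{1}{2}\theta^2$ (gradient $\theta$) and applies gradients that were computed on parameters several steps old. The function name, the fixed staleness, and the chosen learning rates are assumptions for illustration only; real asynchronous systems have variable staleness.

```python
def run_delayed_sgd(lr, staleness, steps=50, theta0=1.0):
    """Minimize f(theta) = 0.5 * theta**2 using gradients that are `staleness` steps old."""
    history = [theta0] * (staleness + 1)       # pretend theta was constant before step 0
    for _ in range(steps):
        stale_theta = history[-1 - staleness]  # parameters the worker actually saw
        grad = stale_theta                     # gradient of 0.5 * theta^2 at that point
        history.append(history[-1] - lr * grad)
    return history[staleness:]                 # trajectory starting at theta0

fresh = run_delayed_sgd(lr=0.9, staleness=0)
stale = run_delayed_sgd(lr=0.9, staleness=3)
careful = run_delayed_sgd(lr=0.1, staleness=3, steps=500)
print(abs(fresh[-1]), max(abs(t) for t in stale), abs(careful[-1]))
```

With fresh gradients the learning rate 0.9 converges quickly; with a staleness of 3 the same learning rate oscillates with growing amplitude, while a smaller learning rate converges again. This is one reason asynchronous setups often need more conservative hyperparameters.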

### 3. Communication Patterns

#### Ring All-Reduce
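- Widely used for data parallelism (for example, in NCCL and Horovod)
- Workers arranged in a ring topology
- Each worker splits its gradients into $N$ chunks; chunks are passed around the ring and accumulated incrementally (reduce-scatter), then the finished sums are passed around again (all-gather)
- Bandwidth optimal: each worker sends and receives about $2(N-1)/N$ times the gradient size in total, nearly independent of the number of workers $N$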
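
A ring all-reduce can be simulated in plain Python to make the two phases concrete. This is a single-process sketch under simplifying assumptions (vector length divisible by the number of workers, all sends in a step happening at once); `ring_all_reduce` is a hypothetical helper, not an API of NCCL or PyTorch.

```python
def ring_all_reduce(vectors):
    """Sum equal-length vectors across workers with a simulated ring.

    Returns (per-worker results, number of elements each worker sent).
    """
    n = len(vectors)
    size = len(vectors[0])
    assert size % n == 0, "this sketch assumes the length splits evenly into n chunks"
    chunk = size // n
    # bufs[r][c] is worker r's current copy of chunk c
    bufs = [[list(v[c * chunk:(c + 1) * chunk]) for c in range(n)] for v in vectors]
    sent = 0

    # Phase 1, reduce-scatter: after n-1 steps worker r holds the full sum of chunk (r+1) % n
    for step in range(n - 1):
        msgs = [((r + 1) % n, (r - step) % n, list(bufs[r][(r - step) % n]))
                for r in range(n)]
        for dst, c, payload in msgs:
            bufs[dst][c] = [a + b for a, b in zip(bufs[dst][c], payload)]
        sent += chunk

    # Phase 2, all-gather: pass the finished chunks around the ring
    for step in range(n - 1):
        msgs = [((r + 1) % n, (r + 1 - step) % n, list(bufs[r][(r + 1 - step) % n]))
                for r in range(n)]
        for dst, c, payload in msgs:
            bufs[dst][c] = payload
        sent += chunk

    results = [[x for c in range(n) for x in bufs[r][c]] for r in range(n)]
    return results, sent

grads = [[float(r * 10 + i) for i in range(8)] for r in range(4)]  # 4 workers, 8 values each
summed, sent_per_worker = ring_all_reduce(grads)
averaged = [[x / len(grads) for x in row] for row in summed]
print(averaged[0], sent_per_worker)
```

In this run each worker sends 12 elements for an 8-element vector, which is $2(N-1)/N$ times the vector size for $N = 4$. Dividing the summed result by the number of workers gives the averaged gradient used in data parallelism.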

#### Parameter Server Architecture
- Central server(s) store model parameters
- Workers push gradients and pull updated parameters
- Can become bottleneck with many workers

### Mathematical Foundation

For data parallelism with $N$ workers:

**Local gradient on worker $i$:**
$$g_i = \nabla_{\theta} L(x_i, y_i; \theta)$$

**Synchronized gradient:**
$$g = \frac{1}{N} \sum_{i=1}^{N} g_i$$

**Model update:**
$$\theta_{t+1} = \theta_t - \eta \cdot g$$

**Effective batch size:**
$$B_{\text{effective}} = N \times B_{\text{local}}$$

Here $B_{\text{local}}$ is the per-worker batch size and $(x_i, y_i)$ denotes the local mini-batch on worker $i$.
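
To see why averaging the local gradients is equivalent to training on one large batch, here is a short derivation. It assumes (this is not stated explicitly above) that each worker's loss is the mean of a per-sample loss $\ell$ over its $B_{\text{local}}$ samples, with $(x_{i,j}, y_{i,j})$ denoting the $j$-th sample on worker $i$:

$$g = \frac{1}{N} \sum_{i=1}^{N} g_i = \frac{1}{N} \sum_{i=1}^{N} \frac{1}{B_{\text{local}}} \sum_{j=1}^{B_{\text{local}}} \nabla_{\theta}\, \ell(x_{i,j}, y_{i,j}; \theta) = \frac{1}{B_{\text{effective}}} \sum_{i=1}^{N} \sum_{j=1}^{B_{\text{local}}} \nabla_{\theta}\, \ell(x_{i,j}, y_{i,j}; \theta)$$

Under that assumption, one synchronous step is the same as a single-device SGD step on a batch of $B_{\text{effective}}$ samples. For example, $N = 4$ workers with $B_{\text{local}} = 32$ give $B_{\text{effective}} = 128$.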

---

## Practical Implementation

Let's implement distributed training using PyTorch's DistributedDataParallel (DDP).

In [None]:
# Import required libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import time
from typing import List, Tuple

# For visualization
import seaborn as sns
sns.set_style('whitegrid')

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"Number of GPUs: {torch.cuda.device_count()}")

### Step 1: Define a Simple Neural Network

We'll create a simple CNN for demonstration purposes.

In [None]:
class SimpleCNN(nn.Module):
    """Simple CNN for distributed training demonstration."""
    
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(256 * 4 * 4, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )
    
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

# Create model instance
model = SimpleCNN(num_classes=10)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

### Step 2: Simulated Distributed Training

Since we may not have multiple GPUs available, we'll simulate distributed training to demonstrate the concepts.

In [None]:
class DistributedTrainingSimulator:
    """Simulate distributed training with multiple workers."""
    
    def __init__(self, model, num_workers=4, sync_mode='synchronous'):
        """
        Args:
            model: PyTorch model to train
            num_workers: Number of simulated workers
            sync_mode: 'synchronous' or 'asynchronous'
        """
        self.num_workers = num_workers
        self.sync_mode = sync_mode
        
        # Create model copies for each worker
        self.workers = []
        for i in range(num_workers):
            worker_model = SimpleCNN(num_classes=10)
            worker_model.load_state_dict(model.state_dict())  # Initialize with same weights
            self.workers.append({
                'id': i,
                'model': worker_model,
                'optimizer': optim.SGD(worker_model.parameters(), lr=0.01),
                'gradients': None,
                'compute_time': 0.0
            })
        
        self.global_model = model
        self.iteration = 0
        self.history = {'sync_times': [], 'worker_times': []}
    
    def simulate_forward_backward(self, worker_id, batch_data, batch_labels):
        """Simulate forward and backward pass for a worker."""
        worker = self.workers[worker_id]
        
        # Simulate variable computation time (some workers slower than others)
        compute_time = np.random.uniform(0.8, 1.2)  # Relative time
        
        # Forward pass
        outputs = worker['model'](batch_data)
        loss = nn.CrossEntropyLoss()(outputs, batch_labels)
        
        # Backward pass
        worker['optimizer'].zero_grad()
        loss.backward()
        
        # Store gradients
        worker['gradients'] = [p.grad.clone() for p in worker['model'].parameters() if p.grad is not None]
        worker['compute_time'] = compute_time
        
        return loss.item(), compute_time
    
    def synchronous_update(self):
        """Synchronous gradient averaging (All-Reduce)."""
        # Wait for all workers (slowest worker determines sync time)
        max_compute_time = max(w['compute_time'] for w in self.workers)
        
        # Average gradients across all workers
        avg_gradients = []
        num_params = len(self.workers[0]['gradients'])
        
        for param_idx in range(num_params):
            grads = [w['gradients'][param_idx] for w in self.workers]
            avg_grad = torch.stack(grads).mean(dim=0)
            avg_gradients.append(avg_grad)
        
        # Update all workers with averaged gradients
        for worker in self.workers:
            for param, avg_grad in zip(worker['model'].parameters(), avg_gradients):
                if param.grad is not None:
                    param.grad.copy_(avg_grad)
            worker['optimizer'].step()
        
        # Simulate communication overhead (10% of compute time)
        comm_time = max_compute_time * 0.1
        total_time = max_compute_time + comm_time
        
        self.history['sync_times'].append(total_time)
        self.history['worker_times'].append([w['compute_time'] for w in self.workers])
        
        return total_time
    
    def asynchronous_update(self):
        """Asynchronous gradient updates (Parameter Server)."""
        worker_times = []

        # Workers push to the central model in the order they finish computing
        for worker in sorted(self.workers, key=lambda w: w['compute_time']):
            lr = worker['optimizer'].param_groups[0]['lr']

            # The server applies this worker's (possibly stale) gradients immediately
            with torch.no_grad():
                for param, grad in zip(self.global_model.parameters(), worker['gradients']):
                    param.add_(grad, alpha=-lr)

            # The worker pulls the latest central parameters before its next batch
            worker['model'].load_state_dict(self.global_model.state_dict())

            # Simulate communication (can happen in parallel)
            comm_time = worker['compute_time'] * 0.05
            total_time = worker['compute_time'] + comm_time
            worker_times.append(total_time)

        # Average time is mean across all workers (no waiting)
        avg_time = np.mean(worker_times)

        self.history['sync_times'].append(avg_time)
        self.history['worker_times'].append([w['compute_time'] for w in self.workers])

        return avg_time
    
    def train_step(self, batches):
        """Perform one training step across all workers."""
        losses = []
        
        # Each worker processes its batch
        for worker_id, (batch_data, batch_labels) in enumerate(batches):
            loss, _ = self.simulate_forward_backward(worker_id, batch_data, batch_labels)
            losses.append(loss)
        
        # Synchronize gradients based on mode
        if self.sync_mode == 'synchronous':
            sync_time = self.synchronous_update()
        else:
            sync_time = self.asynchronous_update()
        
        self.iteration += 1
        return np.mean(losses), sync_time

print("DistributedTrainingSimulator class defined successfully")

### Step 3: Generate Synthetic Data and Run Simulation

In [None]:
# Generate synthetic data
def generate_batch(batch_size=32, num_classes=10):
    """Generate synthetic image data and labels."""
    data = torch.randn(batch_size, 3, 32, 32)
    labels = torch.randint(0, num_classes, (batch_size,))
    return data, labels

# Run simulations
num_workers = 4
num_iterations = 50
batch_size_per_worker = 32

print("=" * 60)
print("SYNCHRONOUS TRAINING SIMULATION")
print("=" * 60)

# Synchronous simulation
model_sync = SimpleCNN(num_classes=10)
simulator_sync = DistributedTrainingSimulator(model_sync, num_workers=num_workers, sync_mode='synchronous')

sync_losses = []
sync_times = []

for iteration in range(num_iterations):
    # Generate batches for each worker
    batches = [generate_batch(batch_size_per_worker) for _ in range(num_workers)]
    loss, step_time = simulator_sync.train_step(batches)
    sync_losses.append(loss)
    sync_times.append(step_time)
    
    if (iteration + 1) % 10 == 0:
        print(f"Iteration {iteration+1}/{num_iterations} | Loss: {loss:.4f} | Time: {step_time:.3f}s")

print("\n" + "=" * 60)
print("ASYNCHRONOUS TRAINING SIMULATION")
print("=" * 60)

# Asynchronous simulation
model_async = SimpleCNN(num_classes=10)
model_async.load_state_dict(model_sync.state_dict())  # Start from same initialization
simulator_async = DistributedTrainingSimulator(model_async, num_workers=num_workers, sync_mode='asynchronous')

async_losses = []
async_times = []

for iteration in range(num_iterations):
    batches = [generate_batch(batch_size_per_worker) for _ in range(num_workers)]
    loss, step_time = simulator_async.train_step(batches)
    async_losses.append(loss)
    async_times.append(step_time)
    
    if (iteration + 1) % 10 == 0:
        print(f"Iteration {iteration+1}/{num_iterations} | Loss: {loss:.4f} | Time: {step_time:.3f}s")

print("\n" + "=" * 60)
print("SIMULATION SUMMARY")
print("=" * 60)
print(f"Synchronous - Total time: {sum(sync_times):.2f}s | Avg time/iteration: {np.mean(sync_times):.3f}s")
print(f"Asynchronous - Total time: {sum(async_times):.2f}s | Avg time/iteration: {np.mean(async_times):.3f}s")
print(f"Speedup (Async vs Sync): {sum(sync_times) / sum(async_times):.2f}x")

### Step 4: Visualize Results

In [None]:
# Create comprehensive visualization
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Plot 1: Training loss comparison
axes[0, 0].plot(sync_losses, label='Synchronous', linewidth=2, alpha=0.8)
axes[0, 0].plot(async_losses, label='Asynchronous', linewidth=2, alpha=0.8)
axes[0, 0].set_xlabel('Iteration')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Training Loss: Sync vs Async')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Plot 2: Time per iteration
axes[0, 1].plot(sync_times, label='Synchronous', linewidth=2, alpha=0.8)
axes[0, 1].plot(async_times, label='Asynchronous', linewidth=2, alpha=0.8)
axes[0, 1].set_xlabel('Iteration')
axes[0, 1].set_ylabel('Time (seconds)')
axes[0, 1].set_title('Time per Iteration')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Plot 3: Cumulative time
cumulative_sync = np.cumsum(sync_times)
cumulative_async = np.cumsum(async_times)
axes[1, 0].plot(cumulative_sync, label='Synchronous', linewidth=2, alpha=0.8)
axes[1, 0].plot(cumulative_async, label='Asynchronous', linewidth=2, alpha=0.8)
axes[1, 0].set_xlabel('Iteration')
axes[1, 0].set_ylabel('Cumulative Time (seconds)')
axes[1, 0].set_title('Cumulative Training Time')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Plot 4: Worker time distribution (last 10 iterations)
last_sync_times = simulator_sync.history['worker_times'][-10:]
last_async_times = simulator_async.history['worker_times'][-10:]

# Average compute time per worker over the last 10 iterations
sync_avg = [np.mean([times[i] for times in last_sync_times]) for i in range(num_workers)]
async_avg = [np.mean([times[i] for times in last_async_times]) for i in range(num_workers)]

x = np.arange(num_workers)
width = 0.35
axes[1, 1].bar(x - width/2, sync_avg, width, label='Synchronous', alpha=0.8)
axes[1, 1].bar(x + width/2, async_avg, width, label='Asynchronous', alpha=0.8)
axes[1, 1].set_xlabel('Worker ID')
axes[1, 1].set_ylabel('Avg Compute Time (seconds)')
axes[1, 1].set_title('Worker Compute Time (Last 10 Iterations)')
axes[1, 1].set_xticks(x)
axes[1, 1].set_xticklabels([f'W{i}' for i in range(num_workers)])
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print("\n📊 Key Observations:")
print("  1. Synchronous training waits for the slowest worker (straggler problem)")
print("  2. Asynchronous training achieves higher throughput but may have noisier convergence")
print("  3. Effective batch size = num_workers × batch_size_per_worker = "
      f"{num_workers} × {batch_size_per_worker} = {num_workers * batch_size_per_worker}")

---

## Real-World Distributed Training with PyTorch DDP

Here's how you would set up actual multi-GPU training with PyTorch DistributedDataParallel:

In [None]:
# Example: PyTorch DDP setup (conceptual - requires multi-GPU environment)

ddp_example = '''
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def setup(rank, world_size):
    """Initialize the distributed environment."""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    """Clean up the distributed environment."""
    dist.destroy_process_group()

def train(rank, world_size, num_epochs=10):  # num_epochs default is illustrative
    """Training function for each process."""
    setup(rank, world_size)
    
    # Create model and move to GPU
    model = SimpleCNN().to(rank)
    ddp_model = DDP(model, device_ids=[rank])
    
    # Create distributed sampler
    train_dataset = YourDataset()  # placeholder: substitute your own Dataset subclass
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    train_loader = DataLoader(train_dataset, batch_size=32, sampler=train_sampler)
    
    # Training loop
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()
    
    for epoch in range(num_epochs):
        train_sampler.set_epoch(epoch)  # Shuffle data differently each epoch
        
        for batch_data, batch_labels in train_loader:
            batch_data, batch_labels = batch_data.to(rank), batch_labels.to(rank)
            
            optimizer.zero_grad()
            outputs = ddp_model(batch_data)
            loss = criterion(outputs, batch_labels)
            loss.backward()  # Gradients automatically synchronized across GPUs
            optimizer.step()
    
    cleanup()

# Launch training
if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
'''

print("PyTorch DDP Setup Example:")
print("=" * 60)
print(ddp_example)
print("=" * 60)
print("\n🔑 Key DDP Components:")
print("  • dist.init_process_group(): Initialize distributed backend (NCCL for GPUs)")
print("  • DistributedDataParallel: Wraps model for automatic gradient synchronization")
print("  • DistributedSampler: Ensures each GPU gets different data")
print("  • Gradients synced automatically during backward() via All-Reduce")
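
To make the role of `DistributedSampler` and `set_epoch` more concrete, the following pure-Python sketch mimics the general idea: every rank builds the same epoch-dependent permutation, pads it so it divides evenly, and takes a strided slice. It is an illustration only, not the library's implementation; `shard_indices` and its parameters are hypothetical names, and it assumes the dataset has at least as many samples as there are ranks.

```python
import random

def shard_indices(dataset_len, world_size, rank, epoch=0, seed=0, shuffle=True):
    """Return the sample indices one rank would see in one epoch (illustrative)."""
    indices = list(range(dataset_len))
    if shuffle:
        # Same seed on every rank, so all ranks agree on the permutation;
        # adding the epoch changes the order from one epoch to the next
        random.Random(seed + epoch).shuffle(indices)
    # Pad by repeating from the front so every rank gets the same count
    per_rank = -(-dataset_len // world_size)  # ceiling division
    total = per_rank * world_size
    indices += indices[: total - len(indices)]
    # Strided slice: rank r takes positions r, r + world_size, ...
    return indices[rank:total:world_size]

shards = [shard_indices(10, world_size=4, rank=r, epoch=0) for r in range(4)]
print(shards)
```

Each rank receives the same number of samples, and together the ranks cover the whole dataset (with a few repeated samples when the size does not divide evenly). Without a sampler like this, every process would iterate over the full dataset.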

---

## Hands-On Activity: Optimize Communication Overhead

### Challenge

You're training a large model across 8 GPUs. You notice that communication overhead is 30% of your total training time. Implement gradient compression to reduce communication cost.

**Tasks:**
1. Implement gradient quantization (reducing precision)
2. Compare compressed vs uncompressed gradient sizes
3. Analyze the trade-off between compression ratio and accuracy

In [None]:
class GradientCompressor:
    """Compress gradients to reduce communication overhead."""
    
    @staticmethod
    def quantize(tensor, num_bits=8):
        """
        Quantize tensor to reduce precision.
        
        Args:
            tensor: Input tensor (float32)
            num_bits: Number of bits for quantization (1-8)
        
        Returns:
            Quantized tensor, scale, zero_point
        """
        # Find min/max values
        min_val = tensor.min()
        max_val = tensor.max()
        
        # Calculate scale and zero point
        qmin = 0
        qmax = 2**num_bits - 1
        # Clamp the range so a constant tensor does not cause division by zero
        scale = (max_val - min_val).clamp(min=1e-12) / (qmax - qmin)
        zero_point = qmin - min_val / scale
        
        # Quantize
        quantized = torch.clamp(torch.round(tensor / scale + zero_point), qmin, qmax)
        
        return quantized.to(torch.uint8), scale, zero_point
    
    @staticmethod
    def dequantize(quantized, scale, zero_point):
        """
        Dequantize tensor back to float32.
        
        Args:
            quantized: Quantized tensor (uint8)
            scale: Scale factor
            zero_point: Zero point
        
        Returns:
            Dequantized tensor (float32)
        """
        return scale * (quantized.float() - zero_point)
    
    @staticmethod
    def top_k_sparsification(tensor, k=0.1):
        """
        Keep only the largest-magnitude fraction k of gradients (sparse communication).
        
        Args:
            tensor: Input tensor
            k: Fraction of gradients to keep (0-1)
        
        Returns:
            Sparse tensor (most values zeroed out)
        """
        flat = tensor.flatten()
        num_keep = max(1, int(flat.numel() * k))  # number of entries to keep
        # Magnitude of the smallest entry that still makes the top-k cut
        threshold = torch.topk(flat.abs(), num_keep)[0][-1]
        
        mask = tensor.abs() >= threshold
        return tensor * mask

# Test compression techniques
print("Testing Gradient Compression Techniques")
print("=" * 60)

# Create sample gradient tensor
sample_gradient = torch.randn(1000, 1000) * 0.01  # Typical gradient magnitudes

# Original size
original_size = sample_gradient.element_size() * sample_gradient.nelement()
print(f"\n📦 Original gradient size: {original_size / 1024 / 1024:.2f} MB (float32)")

# Quantization compression
compressor = GradientCompressor()
quantized, scale, zp = compressor.quantize(sample_gradient, num_bits=8)
dequantized = compressor.dequantize(quantized, scale, zp)

quantized_size = quantized.element_size() * quantized.nelement() + 8  # +8 for scale/zp
compression_ratio = original_size / quantized_size
quantization_error = (sample_gradient - dequantized).abs().mean().item()

print(f"\n🗜️  8-bit Quantization:")
print(f"  Compressed size: {quantized_size / 1024 / 1024:.2f} MB")
print(f"  Compression ratio: {compression_ratio:.2f}x")
print(f"  Mean absolute error: {quantization_error:.6f}")

# Top-K sparsification
for k in [0.1, 0.01, 0.001]:
    sparse = compressor.top_k_sparsification(sample_gradient, k=k)
    sparsity = (sparse == 0).float().mean().item()
    
    # Idealized estimate: counts only the kept values and ignores the
    # index of each kept entry, which a real sparse format must also send
    theoretical_size = original_size * k
    theoretical_ratio = original_size / theoretical_size
    sparsification_error = (sample_gradient - sparse).abs().mean().item()
    
    print(f"\n✂️  Top-{k*100:.1f}% Sparsification:")
    print(f"  Sparsity: {sparsity*100:.1f}% of values zeroed")
    print(f"  Theoretical compression: {theoretical_ratio:.2f}x")
    print(f"  Mean absolute error: {sparsification_error:.6f}")

print("\n" + "=" * 60)
print("\n💡 Key Insights:")
print("  • Quantization: 8-bit gives about 4x compression over float32 with small error")
print("  • Top-K: Higher compression but drops small gradients")
print("  • Trade-off: Communication cost vs convergence speed")
print("  • In practice: Techniques can be combined (quantized + sparse)")
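
To relate these compression ratios back to the 30% communication overhead in the challenge, a rough estimate helps. The sketch below assumes communication time shrinks in proportion to the bytes sent, that compute time is unchanged, and that nothing overlaps; compression and decompression costs are ignored. `step_speedup` is a hypothetical helper.

```python
def step_speedup(comm_fraction, compression_ratio):
    """Speedup of one training step if only the communication part shrinks."""
    compute = 1.0 - comm_fraction                # unaffected by compression
    comm = comm_fraction / compression_ratio     # assumed to scale with bytes sent
    return 1.0 / (compute + comm)

for ratio in (4, 10, 100):
    print(f"{ratio:>3}x compression -> {step_speedup(0.30, ratio):.2f}x faster step")
```

Under these assumptions the gain is capped at $1 / 0.7 \approx 1.43\times$ no matter how aggressive the compression is, so very high compression ratios buy little extra once communication is no longer the dominant cost.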

### Activity: Experiment with Compression

**Your Turn!** Try the following:

1. **Vary quantization bits**: Test 4-bit, 2-bit, and 1-bit quantization. How does error change?
2. **Gradient accumulation**: Instead of syncing every iteration, accumulate gradients for N steps. How does this affect training?
3. **Mixed precision**: Research how FP16/BF16 training reduces memory and communication costs

In [None]:
# YOUR CODE HERE
# Experiment with different compression strategies

# Example starter code:
# for num_bits in [1, 2, 4, 8]:
#     quantized, scale, zp = compressor.quantize(sample_gradient, num_bits=num_bits)
#     # Analyze results...

print("Experiment with gradient compression techniques here!")

---

## Advanced Considerations

### 1. Scaling Laws

**Linear Scaling Rule** (Goyal et al., 2017):
- When multiplying the batch size by $k$, multiply the learning rate by $k$
- Example: batch 256 with lr=0.1 → batch 1024 with lr=0.4
- Goyal et al. report that the rule holds up to a batch size of about 8k for ResNet-50 on ImageNet; beyond that, accuracy degrades

**Gradual Warmup**:
- Start with small learning rate, gradually increase to target
- Prevents instability when using large batch sizes
- Example: Goyal et al. ramp the learning rate up linearly over the first 5 epochs
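
The two rules combine naturally into a small schedule function. This is a minimal sketch: the function names are hypothetical, and warmup is counted in generic "steps" (which could be iterations or epochs, depending on the setup).

```python
def scaled_lr(base_lr, base_batch, batch):
    """Linear scaling rule: learning rate grows in proportion to batch size."""
    return base_lr * batch / base_batch

def lr_with_warmup(step, warmup_steps, base_lr, target_lr):
    """Ramp linearly from base_lr to target_lr, then hold target_lr."""
    if step >= warmup_steps:
        return target_lr
    return base_lr + (target_lr - base_lr) * step / warmup_steps

target = scaled_lr(base_lr=0.1, base_batch=256, batch=1024)  # 0.4, as in the example above
schedule = [lr_with_warmup(s, warmup_steps=5, base_lr=0.1, target_lr=target) for s in range(8)]
print([round(lr, 2) for lr in schedule])
# prints [0.1, 0.16, 0.22, 0.28, 0.34, 0.4, 0.4, 0.4]
```

In PyTorch, a schedule like this could be applied by setting the optimizer's learning rate each step or by wrapping the function in a learning-rate scheduler.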

### 2. Communication Optimization

**Gradient Compression Techniques:**
- **Quantization**: Reduce gradient precision (8-bit, 4-bit)
- **Sparsification**: Send only top-K gradients by magnitude
- **Error Feedback**: Accumulate quantization errors, send in next iteration
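
Error feedback is easiest to see with an extreme compressor. The sketch below sends only the single largest-magnitude entry per step and carries everything else forward as a residual. It is a toy illustration with hypothetical names, not the scheme of any particular paper or library.

```python
def top1_with_error_feedback(gradients):
    """Send only the largest-magnitude entry each step; carry the rest forward."""
    dim = len(gradients[0])
    residual = [0.0] * dim      # what has not been sent yet
    sent_total = [0.0] * dim    # what the receiver has accumulated so far
    for grad in gradients:
        corrected = [g + r for g, r in zip(grad, residual)]   # add back earlier error
        k = max(range(dim), key=lambda i: abs(corrected[i]))  # index to transmit
        sent_total[k] += corrected[k]
        residual = corrected[:]
        residual[k] = 0.0       # the transmitted entry leaves no error behind
    return sent_total, residual

grads = [[0.5, 0.25, -0.125]] * 3   # the same gradient for three steps
sent, residual = top1_with_error_feedback(grads)
print(sent, residual)
```

Nothing is lost permanently: what was sent plus what remains in the residual always equals the sum of all gradients. Small entries are delayed rather than dropped, and they get transmitted once they have accumulated enough magnitude.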

**Overlapping Communication and Computation:**
- Start synchronizing gradients for the last layers while the backward pass is still computing gradients for earlier layers
- PyTorch DDP does this automatically with bucketing

### 3. Fault Tolerance

**Challenges:**
- What if one worker fails mid-training?
- How to handle stragglers (slow workers)?

**Solutions:**
- **Checkpointing**: Save model state frequently
- **Elastic training**: Dynamically add/remove workers (Torch Elastic, now part of `torch.distributed.elastic` and launched with `torchrun`)
- **Backup workers**: Keep redundant workers to replace failures

### 4. Beyond Data Parallelism

**Pipeline Parallelism:**
- Split model into stages across devices
- Process multiple micro-batches in pipeline
- Reduces idle time vs simple model parallelism
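
The idle-time reduction can be quantified with a simple schedule. The sketch below models only the forward pass and assumes every stage takes one time unit per micro-batch; real pipelines also interleave backward passes. The helper names are hypothetical.

```python
def pipeline_schedule(num_stages, num_microbatches):
    """Forward-only schedule: grid[s][t] is the micro-batch stage s runs at time t, or None."""
    total_steps = num_stages + num_microbatches - 1
    grid = [[None] * total_steps for _ in range(num_stages)]
    for s in range(num_stages):
        for m in range(num_microbatches):
            grid[s][s + m] = m   # stage s sees micro-batch m after s hand-offs
    return grid

def bubble_fraction(grid):
    """Share of (stage, time) slots in which a stage sits idle."""
    slots = [cell for row in grid for cell in row]
    return sum(cell is None for cell in slots) / len(slots)

for m in (1, 4, 16):
    print(m, round(bubble_fraction(pipeline_schedule(4, m)), 3))
# prints: 1 0.75, then 4 0.429, then 16 0.158
```

With $S$ stages and $M$ micro-batches, the idle share in this model is $(S-1)/(M+S-1)$. A single batch ($M = 1$) corresponds to simple model parallelism, where each of 4 GPUs is idle 75% of the time; more micro-batches shrink the bubble.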

**Tensor Parallelism:**
- Split individual layers across devices
- Each device computes part of matrix multiplication
- Used in Megatron-LM for training massive language models
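
One common way to split a matrix multiplication is by columns of the weight matrix. The following plain-Python sketch shows that each device can compute its slice of the output independently and that concatenating the slices reproduces the full result. The function names are hypothetical, the column count is assumed to divide evenly, and this is only the column-split variant rather than a description of Megatron-LM's full scheme.

```python
def matmul(a, b):
    """Plain-Python matrix product of a (m x k) and b (k x n)."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def column_parallel_matmul(x, w, num_devices):
    """Split w by columns, multiply each shard separately, then concatenate the outputs."""
    cols = len(w[0])
    per_dev = cols // num_devices   # assumes cols divides evenly
    partial_outputs = []
    for d in range(num_devices):
        w_shard = [row[d * per_dev:(d + 1) * per_dev] for row in w]  # would live on device d
        partial_outputs.append(matmul(x, w_shard))
    # Gather the per-device output columns back into full rows
    return [sum((part[i] for part in partial_outputs), []) for i in range(len(x))]

x = [[1, 2], [3, 4]]
w = [[1, 0, 2, 1], [0, 1, 1, 3]]
print(column_parallel_matmul(x, w, num_devices=2))
print(matmul(x, w))
```

Each device stores only its share of the weight columns, which is what lets a layer that is too large for one device fit across several.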

**3D Parallelism:**
- Combine data + pipeline + tensor parallelism
- Used to train models with 100B+ parameters
- Example: GPT-3-scale language models (Narayanan et al., 2021)

In [None]:
# Visualize different parallelism strategies
fig, axes = plt.subplots(1, 3, figsize=(16, 4))

# Data Parallelism
axes[0].text(0.5, 0.85, 'Data Parallelism', ha='center', fontsize=14, weight='bold')
for i in range(4):
    y = 0.7 - i*0.15
    axes[0].add_patch(plt.Rectangle((0.15, y-0.05), 0.7, 0.08, 
                                     facecolor=f'C{i}', edgecolor='black', linewidth=2))
    axes[0].text(0.5, y, f'GPU {i}: Full Model\nBatch {i}', ha='center', va='center', fontsize=10)
axes[0].text(0.5, 0.05, 'All GPUs sync gradients\n(All-Reduce)', ha='center', fontsize=9, style='italic')
axes[0].set_xlim(0, 1)
axes[0].set_ylim(0, 1)
axes[0].axis('off')

# Model Parallelism
axes[1].text(0.5, 0.85, 'Model Parallelism', ha='center', fontsize=14, weight='bold')
layer_names = ['Input → L1-3', 'Layers 4-6', 'Layers 7-9', 'L10-12 → Out']
for i in range(4):
    y = 0.7 - i*0.15
    axes[1].add_patch(plt.Rectangle((0.15, y-0.05), 0.7, 0.08, 
                                     facecolor=f'C{i}', edgecolor='black', linewidth=2))
    axes[1].text(0.5, y, f'GPU {i}:\n{layer_names[i]}', ha='center', va='center', fontsize=9)
    if i < 3:
        axes[1].arrow(0.5, y-0.06, 0, -0.035, head_width=0.04, head_length=0.01, fc='black', ec='black')
axes[1].text(0.5, 0.05, 'Sequential forward/backward\nthrough GPUs', ha='center', fontsize=9, style='italic')
axes[1].set_xlim(0, 1)
axes[1].set_ylim(0, 1)
axes[1].axis('off')

# Pipeline Parallelism
axes[2].text(0.5, 0.85, 'Pipeline Parallelism', ha='center', fontsize=14, weight='bold')
pipeline_data = [
    ['B0', 'B1', 'B2', 'B3'],
    ['', 'B0', 'B1', 'B2'],
    ['', '', 'B0', 'B1'],
    ['', '', '', 'B0']
]
for i in range(4):
    y = 0.7 - i*0.15
    axes[2].add_patch(plt.Rectangle((0.05, y-0.05), 0.9, 0.08, 
                                     facecolor='lightgray', edgecolor='black', linewidth=1))
    axes[2].text(0.02, y, f'G{i}', ha='center', va='center', fontsize=9, weight='bold')
    for j, batch in enumerate(pipeline_data[i]):
        if batch:
            x = 0.15 + j*0.18
            axes[2].add_patch(plt.Rectangle((x, y-0.03), 0.15, 0.06, 
                                             facecolor=f'C{j}', edgecolor='black', linewidth=1))
            axes[2].text(x+0.075, y, batch, ha='center', va='center', fontsize=8)
axes[2].text(0.5, 0.05, 'Micro-batches flow through\npipeline stages', ha='center', fontsize=9, style='italic')
axes[2].set_xlim(0, 1)
axes[2].set_ylim(0, 1)
axes[2].axis('off')

plt.tight_layout()
plt.show()

print("\n🎯 Parallelism Strategy Selection:")
print("  • Data Parallelism: Model fits on 1 GPU, want faster training")
print("  • Model Parallelism: Model too large for 1 GPU memory")
print("  • Pipeline Parallelism: Reduce bubble time in model parallelism")
print("  • Hybrid: Combine all three for maximum scale (e.g., GPT-3-scale models)")

---

## Key Takeaways

### 🎯 Core Concepts

1. **Data Parallelism** is the most common distributed training strategy
   - Replicate model across devices, split data
   - Synchronize gradients via All-Reduce
   - Effective batch size = num_devices × local_batch_size

2. **Synchronous vs Asynchronous Training**
   - Synchronous: Better convergence, limited by stragglers
   - Asynchronous: Higher throughput, potential stability issues
   - Synchronous All-Reduce is the default in mainstream tools such as PyTorch DDP and Horovod

3. **Communication is the Bottleneck**
   - Gradient synchronization can take a large share of iteration time (30% in the activity above), depending on model size and interconnect bandwidth
   - Optimize via compression, overlapping compute/communication
   - Ring All-Reduce is bandwidth-optimal

4. **Scaling Considerations**
   - Linear scaling rule: scale learning rate with batch size
   - Warmup needed for large batches
   - Diminishing returns beyond certain scale

### 📚 Practical Guidelines

**When to use distributed training:**
- ✅ Training takes >24 hours on single GPU
- ✅ Model barely fits in GPU memory
- ✅ Need to experiment with larger batch sizes
- ✅ Have access to multi-GPU infrastructure

**When NOT to use distributed training:**
- ❌ Model trains in <1 hour (overhead not worth it)
- ❌ Dataset is very small
- ❌ Debugging new model architecture

### ⚠️ Common Pitfalls

1. **Forgetting to scale learning rate** with batch size
2. **Not using DistributedSampler** → all GPUs see same data
3. **Improper gradient accumulation** → effective batch size confusion
4. **Ignoring communication overhead** → poor scaling efficiency
5. **No batch norm adjustments** → statistics are computed per GPU and become noisy with small local batches (consider `torch.nn.SyncBatchNorm`)

### 🚀 Next Steps

1. **Hands-on Practice**: Train a real model with PyTorch DDP
2. **Read Papers**: 
   - "Accurate, Large Minibatch SGD" (Goyal et al., 2017)
   - "Deep Gradient Compression" (Lin et al., 2018)
3. **Explore Tools**: Horovod, DeepSpeed, Megatron-LM
4. **Advanced Topics**: ZeRO optimizer, 3D parallelism, gradient checkpointing

---

## Further Resources

### 📖 Documentation
- [PyTorch Distributed Training Tutorial](https://pytorch.org/tutorials/beginner/dist_overview.html)
- [PyTorch DDP Documentation](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html)
- [NVIDIA NCCL Documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/index.html)

### 📝 Research Papers
1. **Goyal et al. (2017)**: [Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour](https://arxiv.org/abs/1706.02677)
2. **Lin et al. (2018)**: [Deep Gradient Compression](https://arxiv.org/abs/1712.01887)
3. **Rajbhandari et al. (2020)**: [ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://arxiv.org/abs/1910.02054)
4. **Narayanan et al. (2021)**: [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473)

### 🛠️ Tools and Frameworks
- **Horovod**: Easy distributed deep learning framework
- **DeepSpeed**: Microsoft's optimization library for large models
- **Megatron-LM**: NVIDIA's framework for training massive language models
- **Ray**: Distributed computing framework with ML support
- **Torch Elastic**: Fault-tolerant distributed training

### 🎥 Videos and Courses
- [Stanford CS231n: Distributed Training](https://cs231n.stanford.edu/)
- [PyTorch Distributed Training Webinar](https://www.youtube.com/watch?v=0fKT4WVq6AQ)
- [Distributed Deep Learning with Horovod](https://www.youtube.com/watch?v=D1By2hy4Ecw)

### 🏢 Cloud Platform Guides
- [AWS SageMaker Distributed Training](https://docs.aws.amazon.com/sagemaker/latest/dg/distributed-training.html)
- [Google Cloud AI Platform Training](https://cloud.google.com/ai-platform/training/docs/distributed-training)
- [Azure Machine Learning Distributed Training](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-train-distributed-gpu)

---

## 🎓 Congratulations!

You've completed Lesson 95 on **Distributed Training and Gradient Synchronization**. You now understand:

✅ The fundamentals of data and model parallelism  
✅ How gradient synchronization works (synchronous vs asynchronous)  
✅ Communication patterns and optimization techniques  
✅ Practical implementation with PyTorch DDP  
✅ Real-world considerations for scaling distributed training  

**You're now ready to train models at scale!** 🚀

Continue to **Week 20** to explore more advanced machine learning topics!