# Distributed PyTorch Tutorial

This tutorial covers the lower-level distributed PyTorch APIs, including:
- `torch.distributed` basics
- Process groups and initialization  
- Collective operations (AllReduce, Broadcast, etc.)
- Distributed Data Parallel (DDP)
- Custom distributed training loops

## Prerequisites
- PyTorch (a CUDA-enabled build is required for the GPU/NCCL examples)
- Multiple GPUs, or multiple CPU processes using the `gloo` backend
- Basic understanding of PyTorch tensors and models

## How to Run This Notebook
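`torchrun` launches Python scripts, not `.ipynb` files, so export the notebook to a script first:

```bash
# One-time: convert the notebook to a plain Python script
jupyter nbconvert --to script distributed_pytorch_tutorial.ipynb

# For 2 GPUs
torchrun --nproc_per_node=2 distributed_pytorch_tutorial.py

# For 4 GPUs
torchrun --nproc_per_node=4 distributed_pytorch_tutorial.py
```

Running the cells directly inside Jupyter starts a single process, so `RANK`, `LOCAL_RANK` and `WORLD_SIZE` are not set and the distributed examples below print a "skipping" message.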


## 1. Basic Distributed Setup

First, let's understand the key concepts and environment variables.


In [2]:
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
import time
import numpy as np

# Report what the local CUDA runtime exposes before any process group exists.
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device count: {torch.cuda.device_count()}")

# These variables describe this process's place in a distributed job; each one
# that is missing from the environment is reported as 'Not set'.
env_vars = ['RANK', 'LOCAL_RANK', 'WORLD_SIZE']
for var in env_vars:
    print(f"{var}: {os.getenv(var, 'Not set')}")


CUDA available: True
CUDA device count: 4
RANK: Not set
LOCAL_RANK: Not set
WORLD_SIZE: Not set


## 2. Process Group Initialization

The foundation of distributed PyTorch is the process group.


In [3]:
def setup_distributed():
    """Initialize the default process group from the launcher's environment.

    Reads ``RANK``, ``LOCAL_RANK`` and ``WORLD_SIZE`` from the environment,
    falling back to a single-process configuration (0, 0, 1) when they are
    unset. A process group is created only when ``WORLD_SIZE`` is greater
    than 1 and no group has been initialized yet.

    The backend follows the hardware: NCCL when CUDA is available, Gloo
    otherwise, so the same code runs on CPU-only machines.

    Returns:
        tuple[int, int, int]: ``(rank, local_rank, world_size)``.
    """
    rank = int(os.environ.get('RANK', 0))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))

    # The is_initialized() guard keeps a re-run of this cell from raising
    # because the default group already exists.
    if world_size > 1 and not dist.is_initialized():
        use_cuda = torch.cuda.is_available()

        if use_cuda:
            # Bind this process to its own GPU *before* the NCCL communicator
            # is created, so that every rank does not end up on device 0.
            torch.cuda.set_device(local_rank)

        dist.init_process_group(
            backend='nccl' if use_cuda else 'gloo',
            init_method='env://',
            world_size=world_size,
            rank=rank,
        )

    return rank, local_rank, world_size


# Initialize distributed
rank, local_rank, world_size = setup_distributed()
print(f"Process {rank}/{world_size} (local rank: {local_rank})")
print(f"Distributed initialized: {dist.is_initialized()}")


Process 0/1 (local rank: 0)
Distributed initialized: False


## 3. Collective Operations

Collective operations are the building blocks of distributed training.


In [4]:
def demonstrate_collectives():
    """Run an AllReduce and a Broadcast on a one-element tensor, printing each step.

    Does nothing except print a notice when no process group is initialized.
    """
    if not dist.is_initialized():
        print("Distributed not initialized, skipping collectives")
        return

    if torch.cuda.is_available():
        device = torch.cuda.current_device()
    else:
        device = 'cpu'

    # Every process contributes a distinct value: its rank plus one.
    local_value = torch.tensor([rank + 1], device=device, dtype=torch.float32)
    print(f"Process {rank}: Initial tensor = {local_value}")

    # AllReduce: sum across processes; the result lands on every process.
    summed = local_value.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)
    print(f"Process {rank}: After AllReduce = {summed}")

    # Broadcast: rank 0 supplies the payload, everyone else a zero buffer
    # that the call overwrites.
    received = local_value.clone() if rank == 0 else torch.zeros_like(local_value)
    dist.broadcast(received, src=0)
    print(f"Process {rank}: After Broadcast = {received}")


demonstrate_collectives()


Distributed not initialized, skipping collectives


In [None]:
def benchmark_allreduce(sizes_mb=(1, 10, 100, 500), warmup_iters=3, timed_iters=10):
    """Time ``dist.all_reduce`` on float32 tensors of several sizes.

    For each size the collective is run ``warmup_iters`` times untimed and
    then ``timed_iters`` times timed. Rank 0 prints the mean wall time and
    the tensor size divided by that time as a bandwidth figure. When no
    process group is initialized, a notice is printed and nothing is run.

    Args:
        sizes_mb: Tensor sizes to benchmark, as whole numbers of MiB.
        warmup_iters: Untimed iterations run before measuring each size.
        timed_iters: Timed iterations averaged for each size (at least 1).

    Raises:
        ValueError: If ``timed_iters`` is less than 1.
    """
    if not dist.is_initialized():
        print("Distributed not initialized, skipping benchmark")
        return

    if timed_iters < 1:
        raise ValueError("timed_iters must be at least 1")

    # Query the process group directly instead of relying on variables that
    # happen to have been defined by another notebook cell.
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    use_cuda = torch.cuda.is_available()
    device = torch.cuda.current_device() if use_cuda else 'cpu'

    if rank == 0:
        print(f"\nAllReduce Benchmark (World Size: {world_size})")
        print("Size (MB) | Time (ms) | Bandwidth (GB/s)")
        print("-" * 40)

    for size_mb in sizes_mb:
        elements = int(size_mb * 1024 * 1024 / 4)  # float32 = 4 bytes
        tensor = torch.randn(elements, device=device, dtype=torch.float32)

        # Warmup
        for _ in range(warmup_iters):
            test_tensor = tensor.clone()
            dist.all_reduce(test_tensor, op=dist.ReduceOp.SUM)
            if use_cuda:
                torch.cuda.synchronize()

        # Benchmark
        times = []
        for _ in range(timed_iters):
            # all_reduce works in place, so each run gets a fresh copy; the
            # copy is made before the timer starts.
            test_tensor = tensor.clone()

            # CUDA work is asynchronous: synchronize on both sides of the
            # timed region so only the collective is measured.
            if use_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()

            dist.all_reduce(test_tensor, op=dist.ReduceOp.SUM)

            if use_cuda:
                torch.cuda.synchronize()
            end = time.perf_counter()

            times.append(end - start)

        avg_time = np.mean(times)
        bandwidth = (size_mb / 1024) / avg_time  # GB/s

        if rank == 0:
            print(f"{size_mb:8d} | {avg_time*1000:8.1f} | {bandwidth:12.1f}")


benchmark_allreduce()


In [None]:
class SimpleModel(nn.Module):
    """Small fully connected network: two ReLU hidden layers and a linear head.

    Args:
        input_size: Number of input features per sample.
        hidden_size: Width of both hidden layers.
        output_size: Number of values produced per sample.
    """

    def __init__(self, input_size=1000, hidden_size=500, output_size=10):
        super().__init__()
        # Layers are registered in the order the forward pass applies them.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply fc1 and fc2 with ReLU after each, then fc3 with no activation."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = self.relu(layer(hidden))
        return self.fc3(hidden)

def ddp_training():
    """Train ``SimpleModel`` for three steps under DistributedDataParallel.

    Each process builds the model, wraps it in DDP, and runs three SGD steps
    on one fixed batch of random data that it generates itself. Rank 0 prints
    a header, the parameter count and the loss after every step. When no
    process group is initialized, a notice is printed and nothing is run.
    """
    if not dist.is_initialized():
        print("Distributed not initialized, skipping DDP training")
        return

    # Query the process group directly instead of relying on variables that
    # happen to have been defined by another notebook cell.
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    use_cuda = torch.cuda.is_available()
    device = torch.cuda.current_device() if use_cuda else 'cpu'

    # Create model and wrap with DDP (device_ids only applies to GPU modules)
    model = SimpleModel().to(device)
    model = DDP(model, device_ids=[device] if use_cuda else None)

    optimizer = optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    # Create dummy data
    batch_size = 32
    x = torch.randn(batch_size, 1000, device=device)
    y = torch.randint(0, 10, (batch_size,), device=device)

    if rank == 0:
        print(f"\nDDP Training (World Size: {world_size})")
        print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Training loop
    for epoch in range(3):
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()

        # DDP automatically handles gradient synchronization during
        # backward(), so no explicit all_reduce is needed here.
        optimizer.step()

        if rank == 0:
            print(f"Epoch {epoch}: Loss = {loss.item():.4f}")


ddp_training()


In [None]:
def manual_distributed_training():
    """Train ``SimpleModel`` for three steps, synchronizing gradients by hand.

    This mirrors what DDP automates: the initial model state is broadcast
    from rank 0 so every process starts from the same weights, and after each
    backward pass every gradient is summed across processes and divided by
    the world size. Rank 0 prints a header and the loss after every step.
    When no process group is initialized, a notice is printed and nothing
    is run.
    """
    if not dist.is_initialized():
        print("Distributed not initialized, skipping training")
        return

    # Query the process group directly instead of relying on variables that
    # happen to have been defined by another notebook cell.
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'

    # Create model and move to device
    model = SimpleModel().to(device)

    # Every process initialized its weights from its own random state. Copy
    # rank 0's state to all processes; otherwise the averaged gradients would
    # be applied to different models and the replicas would never agree.
    # state_dict() tensors share storage with the live parameters, so the
    # in-place broadcast updates the model itself.
    for state_tensor in model.state_dict().values():
        dist.broadcast(state_tensor, src=0)

    optimizer = optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    # Create dummy data
    batch_size = 32
    x = torch.randn(batch_size, 1000, device=device)
    y = torch.randint(0, 10, (batch_size,), device=device)

    if rank == 0:
        print(f"\nManual Distributed Training (World Size: {world_size})")

    for epoch in range(3):
        # Forward pass
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)

        # Backward pass
        loss.backward()

        # Manual gradient synchronization: sum over processes, then divide
        # so each process steps with the mean gradient.
        for param in model.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
                param.grad /= world_size  # Average gradients

        # Update parameters
        optimizer.step()

        if rank == 0:
            print(f"Epoch {epoch}: Loss = {loss.item():.4f}")


manual_distributed_training()


In [None]:
def cleanup():
    """Destroy the process group if one is initialized; otherwise do nothing."""
    if not dist.is_initialized():
        return

    dist.destroy_process_group()
    print(f"Process {rank}: Distributed resources cleaned up")


cleanup()


In [1]:
# Scratch cell: prints a fixed string and uses nothing else from the notebook.
# NOTE(review): looks like a leftover kernel sanity check — confirm it can be removed.
print('hi')

hi
