# Part 2.1: Introduction to Distributed Computing in PyTorch

**PyTorch Translation of JAX Tutorial**

**Author:** Translated from Phillip Lippe's JAX tutorial


Recent success in deep learning has been driven by the availability of large datasets and the ability to train large models on these datasets. However, training large models on large datasets is computationally expensive and usually goes beyond the capability of a single accelerator like a GPU. To speed up training, we can use parallelism to distribute the computation across multiple devices.

This notebook introduces the basic concepts of distributed, multi-device processing in PyTorch. Unlike JAX's approach with explicit sharding and `shard_map`, PyTorch uses a different paradigm with process groups, collective operations, and abstractions like DistributedDataParallel (DDP).

PyTorch's distributed computing is built around the concept of **processes** rather than **devices**. Each process typically manages one GPU and communicates with other processes through collective operations. This is different from JAX's device-centric approach where you explicitly shard arrays across devices.

For this tutorial, we simulate multiple processes inside a single Python process to demonstrate the concepts. In practice, you would launch separate Python processes, each managing one GPU.
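To make the process-centric model concrete, here is a minimal sketch of how each process could discover its identity when started by a launcher such as `torchrun`, which sets the `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` environment variables. The `ProcessInfo` dataclass and the `read_process_info` helper are hypothetical names introduced for illustration, and the single-process defaults are an assumption.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class ProcessInfo:
    rank: int        # global index of this process
    local_rank: int  # index of this process on its machine (often the GPU index)
    world_size: int  # total number of processes


def read_process_info(env: Optional[Mapping[str, str]] = None) -> ProcessInfo:
    """Read launcher-provided variables, defaulting to a single-process run."""
    env = os.environ if env is None else env
    return ProcessInfo(
        rank=int(env.get("RANK", "0")),
        local_rank=int(env.get("LOCAL_RANK", "0")),
        world_size=int(env.get("WORLD_SIZE", "1")),
    )


# Every process runs the same script; only these values differ between them.
info = read_process_info({"RANK": "2", "LOCAL_RANK": "2", "WORLD_SIZE": "4"})
print(info)
```

Passing the environment as an argument keeps the helper easy to test; called without arguments it reads `os.environ`.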


In [1]:
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import numpy as np
from typing import Any, Dict, Tuple, Optional
import functools
import warnings

# Set USE_CPU_ONLY = True to hide all GPUs; the simulations below run in a single process
USE_CPU_ONLY = False
WORLD_SIZE = 4  # Number of simulated processes (plays the role of the device count in JAX)

if USE_CPU_ONLY:
    device = torch.device("cpu")
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")
print(f"Simulating {WORLD_SIZE} processes")


Using device: cuda
Simulating 4 processes


## Distributed Computing in PyTorch

This section introduces the basic concepts of distributed computing in PyTorch. Unlike JAX's explicit array sharding, PyTorch uses a process-based approach where each process typically owns one GPU and communicates through collective operations.

### Key Differences from JAX:

1. **Process-based vs Device-based**: PyTorch uses separate processes, each typically managing one GPU
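2. **Implicit vs Explicit sharding**: PyTorch usually splits the data for you (e.g., `DistributedSampler` gives each process its own subset), while DDP keeps a full model replica per process and synchronizes gradients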
3. **Collective operations**: Similar communication patterns but different API
4. **Process groups**: Logical grouping of processes for communication

### Basic Setup

In PyTorch distributed computing, we need to:
1. Choose a communication backend (NCCL for GPU, Gloo for CPU)
2. Define the rank (process ID) and world size (total number of processes)
3. Initialize the process group with these settings
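
The three steps above can be sketched in a single process with `world_size=1`, which is enough to exercise the API without launching several workers. The file-based rendezvous (`init_method="file://..."`) and the temporary path are assumptions chosen to avoid a fixed port; real multi-process jobs usually rely on `MASTER_ADDR`/`MASTER_PORT` or on `torchrun`.

```python
import os
import tempfile

import torch.distributed as dist

# Backend choice: NCCL needs CUDA devices, Gloo works on CPU.
# This demo stays on the CPU, so it always uses Gloo.
backend = "gloo"

# Rendezvous through a fresh file instead of a TCP port (assumption for the demo).
store_path = os.path.join(tempfile.mkdtemp(), "rendezvous")

dist.init_process_group(
    backend=backend,
    init_method=f"file://{store_path}",
    rank=0,        # this process's ID
    world_size=1,  # total number of processes
)

rank = dist.get_rank()
world_size = dist.get_world_size()
backend_in_use = dist.get_backend()
print(f"rank={rank}, world_size={world_size}, backend={backend_in_use}")

dist.destroy_process_group()
```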


In [2]:
def setup_distributed(rank: int, world_size: int, backend: str = "gloo"):
    """Initialize distributed training setup.
    
    Args:
        rank: Process rank (0 to world_size-1)
        world_size: Total number of processes
        backend: Communication backend ('nccl' for GPU, 'gloo' for CPU)
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    
    # Initialize process group
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    
    print(f"Process {rank}/{world_size} initialized")
    return rank, world_size

def cleanup_distributed():
    """Clean up distributed training."""
    if dist.is_initialized():
        dist.destroy_process_group()

# For demonstration, we'll create utility functions that simulate distributed behavior
class DistributedSimulator:
    """Simulates distributed operations for tutorial purposes."""
    
    def __init__(self, world_size: int = 8):
        self.world_size = world_size
        self.rank = 0  # Current simulated rank
    
    def simulate_tensor_across_ranks(self, tensor: torch.Tensor) -> Dict[int, torch.Tensor]:
        """Simulate how a tensor would be distributed across ranks."""
        if tensor.dim() not in (1, 2):
            raise NotImplementedError("Only 1D and 2D tensors supported in this demo")
        # Shard along the first dimension (the batch dimension for 2D tensors).
        # Trailing elements are dropped if the size is not divisible by world_size.
        chunk_size = tensor.size(0) // self.world_size
        return {
            rank: tensor[rank * chunk_size:(rank + 1) * chunk_size].clone()
            for rank in range(self.world_size)
        }
    
    def visualize_distribution(self, tensor_dict: Dict[int, torch.Tensor], name: str = "tensor"):
        """Visualize how tensor is distributed across ranks."""
        print(f"\n{name} distribution across {self.world_size} processes:")
        for rank, chunk in tensor_dict.items():
            print(f"  Rank {rank}: shape {chunk.shape}, data: {chunk.flatten()[:5].tolist()}{'...' if chunk.numel() > 5 else ''}")

# Create simulator instance
sim = DistributedSimulator(WORLD_SIZE)


### Basic Tensor Distribution

Let's start with a simple example of how tensors would be distributed across processes in PyTorch. Unlike JAX's explicit sharding, PyTorch typically handles this through data loaders, which give each process its own part of the data, and collective operations, which combine the per-process results.


In [3]:
# Create a simple tensor (equivalent to JAX's jnp.arange(8))
a = torch.arange(8, dtype=torch.float32)
print("Original tensor:", a)
print("Device:", a.device)

# Simulate distribution across processes
distributed_a = sim.simulate_tensor_across_ranks(a)
sim.visualize_distribution(distributed_a, "tensor 'a'")

# Apply operation to each chunk (simulating distributed computation)
distributed_tanh = {}
for rank, chunk in distributed_a.items():
    distributed_tanh[rank] = torch.tanh(chunk)

sim.visualize_distribution(distributed_tanh, "tanh(a)")

# Gather results back (equivalent to collecting sharded results)
gathered_result = torch.cat([distributed_tanh[rank] for rank in range(WORLD_SIZE)])
print(f"\nGathered result: {gathered_result}")
print(f"Verification - matches torch.tanh(a): {torch.allclose(gathered_result, torch.tanh(a))}")


Original tensor: tensor([0., 1., 2., 3., 4., 5., 6., 7.])
Device: cpu

tensor 'a' distribution across 4 processes:
  Rank 0: shape torch.Size([2]), data: [0.0, 1.0]
  Rank 1: shape torch.Size([2]), data: [2.0, 3.0]
  Rank 2: shape torch.Size([2]), data: [4.0, 5.0]
  Rank 3: shape torch.Size([2]), data: [6.0, 7.0]

tanh(a) distribution across 4 processes:
  Rank 0: shape torch.Size([2]), data: [0.0, 0.7615941762924194]
  Rank 1: shape torch.Size([2]), data: [0.9640275835990906, 0.9950547814369202]
  Rank 2: shape torch.Size([2]), data: [0.9993293285369873, 0.9999092221260071]
  Rank 3: shape torch.Size([2]), data: [0.9999877214431763, 0.9999983310699463]

Gathered result: tensor([0.0000, 0.7616, 0.9640, 0.9951, 0.9993, 0.9999, 1.0000, 1.0000])
Verification - matches torch.tanh(a): True
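
The simulator above hands each rank a contiguous slice. In real training code the split is usually done by `torch.utils.data.DistributedSampler`, which assigns indices in a strided pattern instead. A small sketch, passing `num_replicas` and `rank` explicitly so that no process group is needed; the eight-element dataset mirrors tensor `a`, and shuffling is turned off to keep the result deterministic:

```python
import torch
from torch.utils.data import DistributedSampler, TensorDataset

world_size = 4
dataset = TensorDataset(torch.arange(8, dtype=torch.float32))

indices_per_rank = {}
for rank in range(world_size):
    sampler = DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=False
    )
    indices_per_rank[rank] = list(sampler)
    print(f"Rank {rank}: indices {indices_per_rank[rank]}")
# Rank 0 receives indices [0, 4], rank 1 receives [1, 5], and so on.
```

Every sample is still seen by exactly one rank, so the two layouts are equivalent for data-parallel training; only the assignment of samples to ranks differs.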


## Collective Operations in PyTorch

PyTorch provides collective operations similar to JAX's communication primitives. Here are the main equivalents:

| JAX Operation | PyTorch Equivalent | Description |
|---------------|-------------------|-------------|
| `jax.lax.psum` | `dist.all_reduce` | Sum across processes (other reductions via `op`) |
| `jax.lax.all_gather` | `dist.all_gather` | Gather tensors from all processes |
| `jax.lax.psum_scatter` | `dist.reduce_scatter` | Reduce then scatter |
| `jax.lax.ppermute` | `dist.send`/`dist.recv` | Point-to-point communication |

Let's demonstrate these operations:
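
As a first step, here is a hedged sketch of the real call shapes. The JAX primitives return new arrays, whereas `dist.all_reduce` modifies its argument in place and `dist.all_gather` fills a pre-allocated list of output tensors. The sketch runs with a single process (`world_size=1`) so that it stays self-contained; with one process the collectives leave the values unchanged, so only the signatures are of interest. The file-based rendezvous is an assumption for the demo.

```python
import os
import tempfile

import torch
import torch.distributed as dist

store_path = os.path.join(tempfile.mkdtemp(), "rendezvous")
dist.init_process_group(
    "gloo", init_method=f"file://{store_path}", rank=0, world_size=1
)
world_size = dist.get_world_size()

# all_reduce works in place: `x` is overwritten with the sum over all ranks.
x = torch.tensor([1.0, 2.0])
dist.all_reduce(x, op=dist.ReduceOp.SUM)

# all_gather writes into a list holding one output tensor per rank.
local = torch.tensor([3.0, 4.0])
gathered = [torch.empty_like(local) for _ in range(world_size)]
dist.all_gather(gathered, local)
full = torch.cat(gathered)

print("all_reduce result:", x)
print("all_gather result:", full)

dist.destroy_process_group()
```

The cell that follows simulates the same operations across several ranks without starting any processes.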


In [None]:
class CollectiveOperations:
    """Demonstrates PyTorch collective operations."""
    
    def __init__(self, world_size: int):
        self.world_size = world_size
    
    def simulate_all_reduce(self, tensors: Dict[int, torch.Tensor], op: str = "sum") -> Dict[int, torch.Tensor]:
        """Simulate dist.all_reduce operation."""
        # Compute the reduction
        if op == "sum":
            reduced = sum(tensors.values())
        elif op == "mean":
            reduced = sum(tensors.values()) / len(tensors)
        else:
            raise ValueError(f"Unsupported operation: {op}")
        
        # All processes get the same result
        return {rank: reduced.clone() for rank in range(self.world_size)}
    
    def simulate_all_gather(self, tensors: Dict[int, torch.Tensor]) -> Dict[int, torch.Tensor]:
        """Simulate dist.all_gather operation."""
        # Gather all tensors
        gathered = torch.cat([tensors[rank] for rank in range(self.world_size)])
        
        # All processes get the full gathered tensor
        return {rank: gathered.clone() for rank in range(self.world_size)}
    
    def simulate_reduce_scatter(self, tensors: Dict[int, torch.Tensor]) -> Dict[int, torch.Tensor]:
        """Simulate dist.reduce_scatter operation."""
        # First reduce (sum all tensors)
        reduced = sum(tensors.values())
        
        # Then scatter (each process gets a chunk)
        chunk_size = reduced.size(0) // self.world_size
        scattered = {}
        for rank in range(self.world_size):
            start_idx = rank * chunk_size
            end_idx = start_idx + chunk_size
            scattered[rank] = reduced[start_idx:end_idx].clone()
        
        return scattered

# Create collective operations simulator
collective_ops = CollectiveOperations(WORLD_SIZE)
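
A useful relationship between these three collectives is that an all-reduce can be decomposed into a reduce-scatter followed by an all-gather, which is how ring-based implementations typically proceed. The sketch below checks that identity on plain tensors. The `reduce_scatter` and `all_gather` helpers are local stand-ins written for this example, not the `torch.distributed` functions.

```python
from typing import Dict

import torch

WORLD_SIZE = 4


def reduce_scatter(tensors: Dict[int, torch.Tensor]) -> Dict[int, torch.Tensor]:
    """Sum over ranks, then give rank r the r-th chunk of the sum."""
    total = torch.stack(list(tensors.values())).sum(dim=0)
    chunks = total.chunk(len(tensors))
    return {rank: chunks[rank].clone() for rank in tensors}


def all_gather(tensors: Dict[int, torch.Tensor]) -> Dict[int, torch.Tensor]:
    """Concatenate the per-rank chunks and hand the result to every rank."""
    full = torch.cat([tensors[rank] for rank in sorted(tensors)])
    return {rank: full.clone() for rank in tensors}


torch.manual_seed(0)
inputs = {rank: torch.randn(8) for rank in range(WORLD_SIZE)}

# Reference: a plain all-reduce (sum) over the four inputs.
expected = torch.stack(list(inputs.values())).sum(dim=0)

# Two-phase version: each rank first owns 1/4 of the sum, then collects the rest.
two_phase = all_gather(reduce_scatter(inputs))

for rank in range(WORLD_SIZE):
    assert torch.allclose(two_phase[rank], expected)
print("reduce_scatter + all_gather matches all_reduce on every rank")
```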


### All-Reduce Operation (equivalent to JAX's psum)

All-reduce is one of the most important operations in distributed training. It sums (or averages) tensors across all processes and gives each process the result.


In [None]:
# Create different values on each "process"
values_per_rank = {rank: torch.tensor([float(rank + 1)]) for rank in range(WORLD_SIZE)}

print("Before all_reduce:")
for rank, value in values_per_rank.items():
    print(f"  Rank {rank}: {value.item()}")

# Simulate all_reduce sum
reduced_sum = collective_ops.simulate_all_reduce(values_per_rank, "sum")
print("\nAfter all_reduce (sum):")
for rank, value in reduced_sum.items():
    print(f"  Rank {rank}: {value.item()}")

# Simulate all_reduce mean
reduced_mean = collective_ops.simulate_all_reduce(values_per_rank, "mean")
print("\nAfter all_reduce (mean):")
for rank, value in reduced_mean.items():
    print(f"  Rank {rank}: {value.item()}")

# Practical example: averaging gradients in distributed training
print("\n--- Practical Example: Gradient Averaging ---")
torch.manual_seed(42)
gradients_per_rank = {rank: torch.randn(3) for rank in range(WORLD_SIZE)}

print("Gradients on each rank:")
for rank, grad in gradients_per_rank.items():
    print(f"  Rank {rank}: {grad.numpy()}")

averaged_gradients = collective_ops.simulate_all_reduce(gradients_per_rank, "mean")
print("\nAveraged gradients (same on all ranks):")
print(f"  All ranks: {averaged_gradients[0].numpy()}")
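
In real code the averaging shown above is commonly written as a sum all-reduce followed by a division by the world size, since a built-in averaging op is not available on every backend. A minimal sketch: `all_reduce_mean` is a hypothetical helper name, and the fallback for a non-initialized process group is an assumption that lets the same code run in a plain single-process script.

```python
import torch
import torch.distributed as dist


def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """Return the element-wise mean of `tensor` over all processes."""
    if not (dist.is_available() and dist.is_initialized()):
        # No process group: behave like a single-process run.
        return tensor.clone()
    result = tensor.clone()                        # keep the caller's tensor intact
    dist.all_reduce(result, op=dist.ReduceOp.SUM)  # in-place sum over ranks
    return result / dist.get_world_size()


local_grad = torch.tensor([0.5, -1.0, 2.0])
avg_grad = all_reduce_mean(local_grad)
print(avg_grad)
```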


## Practical Example: Distributed Data Parallel Training

Let's build an example that shows how PyTorch's data-parallel training works. It mirrors JAX's `shard_map` approach, but uses PyTorch's process-based paradigm: every rank computes gradients on its own data shard, and the gradients are then averaged across ranks.


In [None]:
class SimpleModel(nn.Module):
    """Simple neural network for distributed training demonstration."""
    
    def __init__(self, input_dim: int = 64, hidden_dim: int = 128, output_dim: int = 10):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = self.relu(self.linear1(x))
        x = self.linear2(x)
        return x

def simulate_distributed_training_step(model, data_per_rank, targets_per_rank, world_size):
    """Simulate one step of distributed training."""
    
    # Each rank computes loss and gradients on its data chunk
    losses_per_rank = {}
    gradients_per_rank = {}
    
    for rank in range(world_size):
        # Forward pass on local data
        local_data = data_per_rank[rank]
        local_targets = targets_per_rank[rank]
        
        outputs = model(local_data)
        loss = F.cross_entropy(outputs, local_targets)
        losses_per_rank[rank] = loss.detach()  # store only the value, not the autograd graph
        
        # Backward pass
        loss.backward()
        
        # Collect gradients
        local_gradients = []
        for param in model.parameters():
            if param.grad is not None:
                local_gradients.append(param.grad.clone())
        gradients_per_rank[rank] = local_gradients
        
        # Clear gradients for next rank simulation
        model.zero_grad()
    
    return losses_per_rank, gradients_per_rank

# Create model and data
torch.manual_seed(0)
model = SimpleModel(input_dim=64, hidden_dim=128, output_dim=10)
batch_size = 32
input_dim = 64

# Create synthetic data distributed across ranks
data = torch.randn(batch_size, input_dim)
targets = torch.randint(0, 10, (batch_size,))

# Distribute data across ranks
data_per_rank = sim.simulate_tensor_across_ranks(data)
targets_per_rank = sim.simulate_tensor_across_ranks(targets)

print("Data distribution:")
sim.visualize_distribution(data_per_rank, "training data")

# Simulate distributed training step
losses, gradients = simulate_distributed_training_step(model, data_per_rank, targets_per_rank, WORLD_SIZE)

print(f"\nLosses per rank:")
for rank, loss in losses.items():
    print(f"  Rank {rank}: {loss.item():.4f}")

# Simulate gradient averaging (all_reduce)
print(f"\nGradient averaging across ranks:")
averaged_gradients = []
for param_idx in range(len(gradients[0])):
    param_grads = {rank: gradients[rank][param_idx] for rank in range(WORLD_SIZE)}
    avg_grad = collective_ops.simulate_all_reduce(param_grads, "mean")
    averaged_gradients.append(avg_grad[0])  # All ranks have same averaged gradient

print(f"Gradients averaged across {WORLD_SIZE} ranks")
print(f"Number of parameter tensors: {len(averaged_gradients)}")
for i, grad in enumerate(averaged_gradients[:2]):  # Show first 2 parameter gradients
    print(f"  Param {i} gradient shape: {grad.shape}, norm: {grad.norm().item():.4f}")
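
Why is averaging per-rank gradients the right thing to do? With equally sized shards and a mean-reduced loss, the average of the shard gradients equals the gradient of the loss over the full batch, so the data-parallel step matches a single-device step on all 32 samples. A small check of that claim, using a single `nn.Linear` layer as a stand-in model (an assumption to keep the example short):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
world_size = 4
model = nn.Linear(64, 10)
data = torch.randn(32, 64)
targets = torch.randint(0, 10, (32,))

# Reference: gradient of the mean loss over the full batch.
model.zero_grad()
F.cross_entropy(model(data), targets).backward()
full_grad = model.weight.grad.clone()

# Data-parallel version: one gradient per equally sized shard, then the mean.
shard_grads = []
for x, y in zip(data.chunk(world_size), targets.chunk(world_size)):
    model.zero_grad()
    F.cross_entropy(model(x), y).backward()
    shard_grads.append(model.weight.grad.clone())
averaged_grad = torch.stack(shard_grads).mean(dim=0)

max_diff = (averaged_grad - full_grad).abs().max().item()
print(f"max |averaged - full| = {max_diff:.2e}")
```

The two gradients agree up to floating-point rounding. If the shards had different sizes, a plain mean would weight samples unevenly and the equality would no longer hold exactly.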


## Real PyTorch Distributed Training

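In practice, PyTorch distributed training is much simpler than our simulation: DDP averages the gradients and `DistributedSampler` splits the data, so neither has to be done by hand. The code below is a template rather than a runnable cell, since `YourDataset` and `num_epochs` are placeholders: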


In [None]:
def real_distributed_training_example():
    """Shows how real PyTorch distributed training would work."""
    
    # This is the code you would actually write for distributed training
    code_example = '''
def train_distributed(rank, world_size):
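    # 1. Setup: NCCL for GPU training, Gloo as the CPU fallback
    use_cuda = torch.cuda.is_available()
    setup_distributed(rank, world_size, backend="nccl" if use_cuda else "gloo")
    
    # 2. Create model and wrap with DDP
    model = SimpleModel()
    if use_cuda:
        torch.cuda.set_device(rank)  # bind this process to its own GPU
        model = model.cuda(rank)
    model = DDP(model, device_ids=[rank] if use_cuda else None)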
    
    # 3. Create distributed sampler for data
    from torch.utils.data import DataLoader, DistributedSampler
    dataset = YourDataset()  # placeholder: substitute your own Dataset
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
    
    # 4. Training loop
    optimizer = torch.optim.Adam(model.parameters())
    
    for epoch in range(num_epochs):  # num_epochs: placeholder for your epoch count
        sampler.set_epoch(epoch)  # Important for proper shuffling
        
        for batch_idx, (data, targets) in enumerate(dataloader):
            if torch.cuda.is_available():
                data, targets = data.cuda(rank), targets.cuda(rank)
            
            # Forward pass
            outputs = model(data)
            loss = F.cross_entropy(outputs, targets)
            
            # Backward pass - DDP automatically handles gradient averaging
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
    cleanup_distributed()

# Launch training with:
# mp.spawn(train_distributed, args=(world_size,), nprocs=world_size, join=True)
'''
    
    print("Real PyTorch Distributed Training Code:")
    print(code_example)
    
    print("\nKey differences from our simulation:")
    print("1. DistributedDataParallel (DDP) automatically handles gradient averaging")
    print("2. DistributedSampler ensures each process gets different data")
    print("3. You launch multiple processes with mp.spawn() or torchrun")
    print("4. Each process runs the same code but on different data")

# Show the example
real_distributed_training_example()
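
For a version that actually executes, the sketch below runs the same ingredients (process group, DDP wrapper, `DistributedSampler`, optimizer step) in one CPU process with `world_size=1`. The synthetic `TensorDataset`, the `nn.Linear` model, the SGD learning rate, and the file-based rendezvous are all assumptions made for the demo; with several processes the code inside the loop would stay the same.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

store_path = os.path.join(tempfile.mkdtemp(), "rendezvous")
dist.init_process_group(
    "gloo", init_method=f"file://{store_path}", rank=0, world_size=1
)

torch.manual_seed(0)
dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 4, (64,)))
sampler = DistributedSampler(dataset)  # rank and world size come from the group
loader = DataLoader(dataset, batch_size=16, sampler=sampler)

model = DDP(nn.Linear(16, 4))  # CPU module, so no device_ids
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
initial_weight = model.module.weight.detach().clone()

sampler.set_epoch(0)  # reseeds the shuffle; call once per epoch
for data, targets in loader:
    optimizer.zero_grad()
    loss = F.cross_entropy(model(data), targets)
    loss.backward()  # DDP all-reduces gradients during the backward pass
    optimizer.step()

weight_changed = not torch.equal(initial_weight, model.module.weight.detach())
num_batches = len(loader)
print(f"batches: {num_batches}, weights updated: {weight_changed}")

dist.destroy_process_group()
```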


## Summary: JAX vs PyTorch Distributed Computing

Here's a comprehensive comparison of distributed computing concepts between JAX and PyTorch:

### Architecture Differences

| Aspect | JAX | PyTorch |
|--------|-----|---------|
| **Paradigm** | Device-centric sharding | Process-centric communication |
| **Explicit Control** | High (explicit sharding) | Medium (abstracted by DDP) |
| **Setup** | Mesh + PartitionSpec | Process groups + backends |
| **SPMD** | `shard_map` functions | Multiple processes with same code |

### Communication Operations

| Operation | JAX | PyTorch | Notes |
|-----------|-----|---------|-------|
| **Sum/Average** | `jax.lax.psum` | `dist.all_reduce` | Both sum across devices/processes |
| **Gather** | `jax.lax.all_gather` | `dist.all_gather` | Collect data from all devices/processes |
| **Scatter-Sum** | `jax.lax.psum_scatter` | `dist.reduce_scatter` | Reduce then distribute chunks |
| **Point-to-point** | `jax.lax.ppermute` | `dist.send`/`dist.recv` | Direct device-to-device communication |

### Practical Usage

**JAX Approach:**
- Explicit control over device placement
- Write per-device code with `shard_map`
- Manual handling of communication
- Great for research and custom parallelism strategies

**PyTorch Approach:**
- Higher-level abstractions (DDP, FSDP)
- Write single-process code, run multiple processes
- Automatic gradient synchronization
- Better for standard training workflows

### When to Choose What?

**Choose JAX if:**
- You need fine-grained control over device placement
- You're implementing custom parallelism strategies
- You prefer functional programming paradigms
- You're doing research requiring explicit control

**Choose PyTorch if:**
- You want easier setup for standard distributed training
- You prefer object-oriented programming
- You need extensive ecosystem support
- You're building production systems

Both frameworks are powerful for distributed computing, but they approach the problem differently. JAX gives you more explicit control, while PyTorch provides higher-level abstractions that are easier to use for common scenarios.
