# PyTorch Tutorial: Distributed Training (DDP & FSDP)

When your model is too big for one GPU, or your dataset is so large that training on a single GPU would take weeks, you need **Distributed Training**: spreading the work across many GPUs, often on several machines. It is a standard skill for FAANG AI roles working on foundation models.

## Learning Objectives
- Understand **Data Parallelism** vs **Model Parallelism**
- Learn the structure of **DDP (Distributed Data Parallel)**
- Get an introduction to **FSDP (Fully Sharded Data Parallel)** for massive models

## 1. Vocabulary First

- **Node**: A physical machine (server). A node can have multiple GPUs.
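- **Rank**: The unique ID of a process. With the usual setup of one process per GPU, a job on 4 GPUs has ranks 0, 1, 2, 3.
- **World Size**: Total number of processes in the training job (usually equal to the total number of GPUs).
- **Master Address/Port**: The address and port of the rank 0 process, which every other process contacts to join the job (typically set through the `MASTER_ADDR` and `MASTER_PORT` environment variables).
- **Scatter/Gather**: Scatter splits data held by one process across all processes; gather collects results from all processes back onto one.
- **All-Reduce**: A collective operation that combines a tensor (for example, by summing) across all processes and leaves the result on every process. DDP uses it to average gradients across GPUs.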
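To make the collective operations concrete, here is a toy, single-process simulation in plain Python. Each list entry stands in for one rank's data. It only illustrates the semantics: real collectives run across processes through `torch.distributed` (for example `dist.all_reduce`). The helper names below are hypothetical.

```python
# Toy simulation of collectives: index i of each list plays the role of rank i.
# Helper names are hypothetical; real code uses torch.distributed collectives.

def scatter(data, world_size):
    """Split one list held by a 'root' rank into world_size equal chunks."""
    assert len(data) % world_size == 0, "toy version needs an even split"
    size = len(data) // world_size
    return [data[r * size:(r + 1) * size] for r in range(world_size)]

def gather(chunks):
    """Collect every rank's chunk back onto a single 'root' rank."""
    return [x for chunk in chunks for x in chunk]

def all_reduce_mean(per_rank_grads):
    """Average element-wise across ranks; every rank receives the same result."""
    world_size = len(per_rank_grads)
    averaged = [sum(vals) / world_size for vals in zip(*per_rank_grads)]
    return [list(averaged) for _ in range(world_size)]

chunks = scatter([1, 2, 3, 4, 5, 6, 7, 8], world_size=4)
print(chunks)          # [[1, 2], [3, 4], [5, 6], [7, 8]]
print(gather(chunks))  # [1, 2, 3, 4, 5, 6, 7, 8]

grads = [[1.0, 2.0], [3.0, 4.0]]  # gradients computed on rank 0 and rank 1
print(all_reduce_mean(grads))     # [[2.0, 3.0], [2.0, 3.0]]
```

By default, `dist.all_reduce` sums (`ReduceOp.SUM`). DDP then divides by the world size to turn that sum into an average.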

## 2. Distributed Data Parallel (DDP)

**How it works:**
1. Copy the model to every GPU.
2. Split the dataset (each GPU gets a different chunk).
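3. The forward pass runs independently on each GPU.
4. The backward pass computes local gradients on each GPU.
5. **All-Reduce**: Gradients are averaged across all GPUs. DDP overlaps this communication with the backward pass, bucket by bucket.
6. The optimizer updates the weights. Every GPU applies the same averaged gradient, so the weights stay identical on all GPUs.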
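The six steps above can be checked on a toy problem. The plain-Python sketch below simulates one DDP step for a one-parameter model `y = w * x` with 2 "ranks" (list entries). It then compares the result with a single-GPU update on the full batch. `local_grad` and `ddp_step` are hypothetical helpers, and the strided data split is an assumption chosen for illustration.

```python
# Toy single-process simulation of one DDP step for the model y = w * x.
# Each "rank" is a list entry; nothing here is the torch.distributed API.

def local_grad(w, xs, ys):
    """d/dw of the mean squared error over this rank's shard."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def ddp_step(w, xs, ys, world_size, lr):
    # Steps 1-2: every rank holds the same w but a different slice of data
    shards = [(xs[r::world_size], ys[r::world_size]) for r in range(world_size)]
    # Steps 3-4: forward + backward independently on each rank
    grads = [local_grad(w, sx, sy) for sx, sy in shards]
    # Step 5: all-reduce (average) so every rank sees the same gradient
    avg = sum(grads) / world_size
    # Step 6: identical optimizer (plain SGD) update on every rank
    return [w - lr * avg for _ in range(world_size)]

xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]  # true relationship: y = 2x
weights = ddp_step(w=0.0, xs=xs, ys=ys, world_size=2, lr=0.01)
single_gpu = 0.0 - 0.01 * local_grad(0.0, xs, ys)
print(weights, single_gpu)
```

With equal-sized shards, the mean of the per-rank mean gradients equals the full-batch mean gradient. That is why DDP matches single-GPU training with the same global batch size.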

### The Code Structure
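Run DDP from a script rather than a notebook. `mp.spawn` launches new Python processes that must import the training function, and that generally fails for functions defined interactively in a notebook. Here is the template you would use: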

```python
# ddp_script.py (Template)

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def setup(rank, world_size):
    # Rendezvous point every process connects to (single-node example)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    # Bind this process to its own GPU before creating the NCCL group
    torch.cuda.set_device(rank)
    # Initialize the process group (NCCL is the standard GPU backend)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def train(rank, world_size):
    setup(rank, world_size)

    # 1. Create the model and move it to this process's GPU
    model = torch.nn.Linear(10, 10).to(rank)

    # 2. Wrap with DDP (gradients are all-reduced during backward)
    ddp_model = DDP(model, device_ids=[rank])

    # 3. Loss and optimizer
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001)

    # 4. One training step (a real loop would iterate over a DataLoader)
    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10).to(rank))
    loss = loss_fn(outputs, torch.randn(20, 10).to(rank))
    loss.backward()  # DDP averages gradients across ranks here
    optimizer.step()

    cleanup()

def main():
    world_size = 2  # Number of GPUs (one process per GPU)
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()
```
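The template feeds random tensors to every rank, so step 2 of the DDP recipe (each GPU reads a different chunk of the dataset) is not shown. In PyTorch this is normally handled by passing a `torch.utils.data.distributed.DistributedSampler` to the `DataLoader`. The plain-Python sketch below mimics, to the best of our understanding, how that sampler assigns indices when `shuffle=False` and `drop_last=False`. It pads the index list to a multiple of the world size by wrapping around to the start, then gives each rank every `world_size`-th index starting at its own rank. `shard_indices` is a hypothetical helper, not a PyTorch function.

```python
import math

def shard_indices(dataset_len, rank, world_size):
    """Hypothetical helper mimicking DistributedSampler(shuffle=False, drop_last=False)."""
    indices = list(range(dataset_len))
    num_samples = math.ceil(dataset_len / world_size)
    total_size = num_samples * world_size
    padding = total_size - len(indices)
    if padding > 0:
        # Wrap around to the start so every rank gets the same number of samples
        indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Strided split: rank r takes indices r, r + world_size, r + 2 * world_size, ...
    return indices[rank:total_size:world_size]

for rank in range(3):
    print(rank, shard_indices(10, rank, world_size=3))
# 0 [0, 3, 6, 9]
# 1 [1, 4, 7, 0]
# 2 [2, 5, 8, 1]
```

When the dataset size is not divisible by the world size, the padding means a few samples are seen twice per epoch. With `shuffle=True`, the real sampler also permutes the indices using a seed plus the epoch number. That is why training loops call `sampler.set_epoch(epoch)` at the start of each epoch.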

## 3. Fully Sharded Data Parallel (FSDP)

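DDP replicates the *entire model* on every GPU, along with its gradients and optimizer state. If that footprint exceeds one GPU's memory (say the model alone needs 100GB and your GPU has 80GB), DDP runs out of memory.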

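**FSDP** solves this by **sharding** (splitting) the model parameters, gradients, and optimizer states across GPUs, so each GPU persistently holds only a piece of the model. Just before a wrapped module runs its forward or backward pass, FSDP all-gathers that module's full parameters and frees them again afterward. Gradients are then reduce-scattered, so each GPU keeps only the averaged slice it owns.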
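The shard / all-gather / reduce-scatter cycle described above can be sketched in plain Python on a flat list of parameters. This is a toy illustration of the idea behind FSDP's full-sharding strategy, not its implementation. Each list entry is one rank, and `shard`, `all_gather`, and `reduce_scatter_mean` are hypothetical helpers.

```python
# Toy simulation of FSDP-style sharding on a flat parameter list.
# Each list entry is one rank; helper names are hypothetical, not the FSDP API.

def shard(flat_params, world_size):
    """Pad to a multiple of world_size, then give each rank one contiguous slice."""
    size = -(-len(flat_params) // world_size)  # ceiling division
    padded = flat_params + [0.0] * (size * world_size - len(flat_params))
    return [padded[r * size:(r + 1) * size] for r in range(world_size)]

def all_gather(shards, numel):
    """Before forward/backward: rebuild the full parameters (freed after use)."""
    return [x for s in shards for x in s][:numel]

def reduce_scatter_mean(full_grads_per_rank, world_size):
    """After backward: average full gradients, but each rank keeps only its slice."""
    averaged = [sum(v) / world_size for v in zip(*full_grads_per_rank)]
    return shard(averaged, world_size)

params = [0.1, 0.2, 0.3, 0.4, 0.5]
shards = shard(params, world_size=2)
print(shards)                           # [[0.1, 0.2, 0.3], [0.4, 0.5, 0.0]]
print(all_gather(shards, len(params)))  # [0.1, 0.2, 0.3, 0.4, 0.5]

grads = [[1.0] * 5, [3.0] * 5]  # full gradients computed on rank 0 and rank 1
print(reduce_scatter_mean(grads, world_size=2))
# [[2.0, 2.0, 2.0], [2.0, 2.0, 0.0]]
```

Between steps, each rank stores only its slice of parameters and gradients (and, in real FSDP, of the optimizer state). This is where the memory saving over DDP comes from.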

### When to use FSDP?
- Training LLMs (Llama 3, GPT-4 class models).
- When the model's weights, gradients, and optimizer state together exceed a single GPU's memory.
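A back-of-the-envelope estimate shows why the whole training state matters, not just the weights. The sketch below assumes fp32 weights and gradients (4 bytes each) and the Adam optimizer's two fp32 moment buffers (8 bytes), so 16 bytes per parameter. It ignores activations, temporary all-gather buffers, and framework overhead. Mixed-precision setups change these numbers, and the 7B-parameter model is hypothetical.

```python
# Rough per-GPU memory for model state only, assuming fp32 weights and
# gradients (4 bytes each) plus Adam's two fp32 moment buffers (8 bytes).
# Activations, temporary all-gather buffers, and overhead are ignored.
BYTES_PER_PARAM = 4 + 4 + 8

def ddp_gb_per_gpu(num_params):
    """DDP keeps a full replica of weights, grads, and optimizer state per GPU."""
    return num_params * BYTES_PER_PARAM / 1e9

def fsdp_gb_per_gpu(num_params, world_size):
    """Full sharding splits all three roughly evenly across GPUs."""
    return ddp_gb_per_gpu(num_params) / world_size

n = 7_000_000_000  # a hypothetical 7B-parameter model
print(ddp_gb_per_gpu(n))                 # 112.0
print(fsdp_gb_per_gpu(n, world_size=8))  # 14.0
```

Under these assumptions, the model's fp32 weights alone take only 28 GB, yet its full DDP footprint exceeds an 80 GB GPU. Sharded over 8 GPUs, it fits comfortably.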

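```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Inside train(rank, world_size), after setup(rank, world_size):
model = torch.nn.Linear(10, 10).to(rank)
fsdp_model = FSDP(model, device_id=rank)  # wrap your model
# Build the optimizer AFTER wrapping so it sees FSDP's sharded parameters
optimizer = torch.optim.SGD(fsdp_model.parameters(), lr=0.001)
```

For real models, also pass an `auto_wrap_policy` (see `torch.distributed.fsdp.wrap`) so that each large submodule becomes its own FSDP unit. If you wrap only the root module, the entire model's parameters are gathered at once during forward and backward.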

## Key Takeaways

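1. **DDP** is fast and the standard choice whenever a full replica (weights, gradients, optimizer state) fits on one GPU.
2. **FSDP** is necessary for giant models (LLMs) whose replica does not fit; it shards that state across GPUs.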
3. **`dist.init_process_group`** is the handshake that starts distributed training.