# Large-Scale Training Strategies

Training modern models often involves multiple GPUs or even multiple nodes. This notebook surveys data, tensor, and pipeline parallelism, illustrates PyTorch's Distributed Data Parallel (DDP) API, and demonstrates two utilities: activation checkpointing and throughput logging.

## Learning Objectives

- Distinguish between data, tensor/model, and pipeline parallelism.
- Understand the scaffolding required to launch DDP jobs with `torchrun`.
- Apply activation checkpointing to trade compute for memory.
- Log throughput across workers to monitor scaling efficiency.

## Parallelism Overview

| Strategy | Concept | Best for |
|----------|---------|----------|
| Data Parallel (DDP) | Replicate the model on every device, split each batch across replicas | Models that fit in a single GPU's memory (the usual starting point) |
| Tensor/Model Parallel | Split weight matrices across devices | Models whose layers exceed single-GPU memory |
| Pipeline Parallel | Partition layers into stages, stream microbatches through them | Deep sequential networks |

Large-scale systems often combine these strategies, for example data parallelism across nodes with tensor parallelism within a node.
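
To make the tensor-parallel row of the table concrete, the sketch below splits a small weight matrix by output columns across two simulated devices and checks that concatenating the partial outputs reproduces the full product. It uses plain Python lists instead of real GPUs, and `matmul` and `split_columns` are illustrative helpers, not PyTorch APIs.

```python
def matmul(x, w):
    """Multiply a (batch x in) matrix by an (in x out) matrix, both nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)] for row in x]

def split_columns(w, parts):
    """Split the output columns of w into `parts` equal shards, one per device."""
    out_dim = len(w[0])
    assert out_dim % parts == 0, "output dimension must divide evenly"
    step = out_dim // parts
    return [[row[i * step:(i + 1) * step] for row in w] for i in range(parts)]

x = [[1.0, 2.0], [3.0, 4.0]]          # batch of 2, in_features = 2
w = [[1.0, 0.0, 2.0, 1.0],
     [0.0, 1.0, 1.0, 3.0]]            # in_features = 2, out_features = 4

full = matmul(x, w)                   # what a single device would compute
shards = split_columns(w, parts=2)    # each simulated device holds half the columns
partials = [matmul(x, shard) for shard in shards]

# Concatenate the per-device slices along the feature dimension (an all-gather).
gathered = [sum((p[r] for p in partials), []) for r in range(len(x))]
print(gathered == full)
```

Each simulated device stores only half of `w`, which is the memory saving that motivates the strategy; the cost is the gather step needed to reassemble the output.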

In [None]:
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(1024, 2048), nn.ReLU(),
            nn.Linear(2048, 2048), nn.ReLU(),
            nn.Linear(2048, 1024)
        )

    def forward(self, x):
        return self.layers(x)

print("Parameters:", sum(p.numel() for p in ToyModel().parameters()))
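
The count printed above can be checked by hand, and it also gives a rough sense of how much memory the training state needs. The sketch below redoes the arithmetic in plain Python. The 16 bytes per parameter figure is an assumption (fp32 weights, fp32 gradients, and two fp32 AdamW moment buffers) and excludes activations; `linear_params` is a hypothetical helper.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of one fully connected layer."""
    return in_features * out_features + out_features

layer_shapes = [(1024, 2048), (2048, 2048), (2048, 1024)]  # mirrors ToyModel
total_params = sum(linear_params(i, o) for i, o in layer_shapes)

# Assumption: 4 B weights + 4 B gradients + 8 B for the two AdamW moments.
BYTES_PER_PARAM = 4 + 4 + 8
state_mib = total_params * BYTES_PER_PARAM / 2**20

print(f"{total_params:,} parameters, about {state_mib:.0f} MiB of training state")
```

At 8,393,728 parameters the model is far below single-GPU limits, so plain data parallelism is the natural fit for it.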


## Distributed Data Parallel Skeleton

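DDP launches one process per GPU, and each process holds a full replica of the model. The cell below only defines the functions; nothing trains until a launcher such as `torchrun` starts the processes and each one calls `train_ddp` with its own rank.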

In [None]:
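import os
import torch.distributed as dist

def setup(rank, world_size):
    # torchrun exports MASTER_ADDR and MASTER_PORT; these defaults only cover
    # single-node launches that bypass it (for example mp.spawn).
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    if torch.cuda.is_available():
        # Single-node assumption: the global rank doubles as the local GPU index.
        torch.cuda.set_device(rank)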

def cleanup():
    dist.destroy_process_group()

def train_ddp(rank, world_size):
    setup(rank, world_size)
    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
    model = ToyModel().to(device)
    ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[rank] if torch.cuda.is_available() else None)
    optimizer = torch.optim.AdamW(ddp_model.parameters(), lr=1e-3)
    dataset = torch.utils.data.TensorDataset(torch.randn(1024, 1024), torch.randn(1024, 1024))
    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)
    for epoch in range(2):
        sampler.set_epoch(epoch)
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            preds = ddp_model(xb)
            loss = nn.functional.mse_loss(preds, yb)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    cleanup()

print("DDP skeleton ready (launch with torchrun)")
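
The `DistributedSampler` in the skeleton is what keeps replicas from training on the same samples. The sketch below imitates its index assignment in plain Python for the unshuffled case: pad the index list to a multiple of the world size, then give each rank a strided slice. This reflects the sampler's default behaviour (`drop_last=False`) as I understand it; `shard_indices` is a hypothetical helper, and it assumes the dataset has at least `world_size` samples.

```python
import math

def shard_indices(dataset_len, world_size, rank):
    """Indices one rank would see, without shuffling (illustrative only)."""
    per_rank = math.ceil(dataset_len / world_size)
    total = per_rank * world_size
    indices = list(range(dataset_len))
    indices += indices[: total - dataset_len]  # pad by repeating early samples
    return indices[rank:total:world_size]      # strided slice for this rank

world_size, batch_size = 4, 32                 # values matching the skeleton above
shards = [shard_indices(1024, world_size, r) for r in range(world_size)]
steps_per_epoch = math.ceil(len(shards[0]) / batch_size)

print(len(shards[0]), "samples per rank,", steps_per_epoch, "steps per epoch")
```

With four workers each rank sees 256 of the 1024 samples, so an epoch takes 8 optimizer steps per rank instead of 32. Calling `sampler.set_epoch(epoch)` changes the shuffle seed so the assignment differs between epochs.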


### Launch Command Example

```
torchrun --nproc_per_node=4 --rdzv_backend=c10d --rdzv_endpoint=localhost:29500 train_ddp.py
```

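For multi-node jobs, also pass `--nnodes` and point `--rdzv_endpoint` at a host that every node can reach. `torchrun` exports `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` to each process, so `train_ddp.py` should read them from the environment and pass them to `train_ddp`.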
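
Because `train_ddp(rank, world_size)` takes its rank as an argument, the script that `torchrun` starts needs a small entry point that pulls those values from the environment. A minimal sketch follows; `read_launch_env` is a hypothetical helper, and the fallbacks for a plain `python` launch are an assumption.

```python
import os

def read_launch_env(env=None):
    """Collect the rank settings torchrun exports, with single-process fallbacks."""
    env = os.environ if env is None else env
    return {
        "rank": int(env.get("RANK", 0)),              # global rank across all nodes
        "local_rank": int(env.get("LOCAL_RANK", 0)),  # GPU index on this node
        "world_size": int(env.get("WORLD_SIZE", 1)),  # total number of processes
    }

# Simulate what the third of four workers on one node would see.
cfg = read_launch_env({"RANK": "2", "LOCAL_RANK": "2", "WORLD_SIZE": "4"})
print(cfg)
```

In `train_ddp.py` the entry point would then call something like `train_ddp(cfg["rank"], cfg["world_size"])`. On multi-node jobs the device index should come from `local_rank`, because the global rank can exceed the number of GPUs on a node.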

## Activation Checkpointing

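Activation checkpointing discards the intermediate activations inside a block after the forward pass and recomputes them during the backward pass. The cost is roughly one extra forward pass through each checkpointed block; the benefit is lower peak memory.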

In [None]:
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(),
            nn.Linear(dim, dim), nn.ReLU()
        )

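    def forward(self, x):
        # Activations inside self.block are dropped after the forward pass and
        # recomputed during backward. use_reentrant=False selects the variant
        # PyTorch recommends and avoids the warning about an unset flag.
        return checkpoint(self.block, x, use_reentrant=False)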

block = CheckpointedBlock(1024)
sample = torch.randn(4, 1024, requires_grad=True)
block(sample).sum().backward()
print("Checkpointed backward completed")
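
How many blocks to checkpoint is a tuning choice. The sketch below uses a deliberately simplified cost model to show why splitting a deep network into about the square root of its layer count is a common starting point: with `k` segments, the model keeps `k` segment inputs plus the activations of the one segment being recomputed. It assumes all layer outputs are the same size and ignores everything else in memory; `peak_activations` is a hypothetical helper, not a measurement.

```python
import math

def peak_activations(num_layers, num_segments):
    """Simplified peak count of stored layer outputs under segment checkpointing."""
    return num_segments + math.ceil(num_layers / num_segments)

num_layers = 64
best = min(range(1, num_layers + 1), key=lambda k: peak_activations(num_layers, k))

print(f"best segment count: {best}")
print(f"peak stored activations: {peak_activations(num_layers, best)} "
      f"(vs {num_layers} without checkpointing)")
```

Under this model a 64-layer network is best split into 8 segments, which cuts the stored activations from 64 to 16. Real savings depend on layer shapes and should be measured.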


## Mini Task – Throughput Logger

Implement a utility that aggregates processed samples across workers (using `dist.all_reduce`) and prints global throughput.

In [None]:
def log_throughput(local_count, elapsed, world_size):
    # TODO: all_reduce local_count when distributed initialized, compute samples/sec
    raise NotImplementedError


In [None]:
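def log_throughput(local_count, elapsed, world_size):
    if dist.is_available() and dist.is_initialized():
        # NCCL can only reduce CUDA tensors, so keep the counter on this rank's GPU.
        on_gpu = dist.get_backend() == "nccl"
        device = torch.device("cuda", torch.cuda.current_device()) if on_gpu else torch.device("cpu")
        # float64 keeps large sample counts exact where float32 would round them.
        total = torch.tensor([local_count], dtype=torch.float64, device=device)
        dist.all_reduce(total, op=dist.ReduceOp.SUM)
        total = total.item()
    else:
        # No process group: only this process's samples are counted.
        total = float(local_count)
    throughput = total / max(elapsed, 1e-6)
    print(f"Throughput: {throughput:.2f} samples/sec across {world_size} workers")
    return throughput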

_ = log_throughput(local_count=1024, elapsed=1.5, world_size=4)
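
Global throughput becomes a scaling-efficiency figure once it is compared with a single-worker baseline: efficiency is the measured throughput divided by the ideal of `num_workers` times the baseline. A small sketch, where `scaling_efficiency` is a hypothetical helper and the numbers are made up for illustration:

```python
def scaling_efficiency(throughput_n, throughput_1, num_workers):
    """Fraction of ideal linear speedup achieved by `num_workers` workers."""
    if throughput_1 <= 0 or num_workers <= 0:
        raise ValueError("baseline throughput and worker count must be positive")
    return throughput_n / (num_workers * throughput_1)

# Hypothetical measurements, not real benchmark results.
eff = scaling_efficiency(throughput_n=3400.0, throughput_1=1000.0, num_workers=4)
print(f"Scaling efficiency: {eff:.0%}")
```

A value well below 100% usually points to communication or data-loading overhead growing with the worker count.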


## Comprehensive Exercise – Distributed Training Blueprint

Draft the skeleton for a distributed training script that supports:

- Argument parsing for world size, epochs, batch size, checkpoint directory.
- Process group initialization and cleanup.
- Optimizer, scheduler, GradScaler setup.
- Checkpoint saving and restoration for fault tolerance.

In [None]:
def main():
    # TODO: parse args, spawn processes, handle checkpoints
    raise NotImplementedError

def train(rank, args):
    # TODO: set device, wrap model with DDP, train with checkpointing and scaler
    raise NotImplementedError


In [None]:
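import argparse
import os
import torch.multiprocessing as mp

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints")
    parser.add_argument("--world_size", type=int, default=torch.cuda.device_count() or 1)
    return parser.parse_args([])  # empty list for notebook demo

def train(rank, args):
    setup(rank, args.world_size)
    use_cuda = torch.cuda.is_available()
    device = torch.device(f"cuda:{rank}" if use_cuda else "cpu")
    model = ToyModel().to(device)
    ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[rank] if use_cuda else None)
    optimizer = torch.optim.AdamW(ddp_model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    # Newer PyTorch releases prefer torch.amp.GradScaler("cuda", ...).
    scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

    # Fault tolerance: resume from the last saved epoch if a checkpoint exists.
    ckpt_path = os.path.join(args.checkpoint_dir, "latest.pt")
    start_epoch = 0
    if os.path.exists(ckpt_path):
        state = torch.load(ckpt_path, map_location=device)
        ddp_model.module.load_state_dict(state["model"])
        optimizer.load_state_dict(state["optimizer"])
        scheduler.load_state_dict(state["scheduler"])
        start_epoch = state["epoch"] + 1

    dataset = torch.utils.data.TensorDataset(torch.randn(2048, 1024), torch.randn(2048, 1024))
    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, sampler=sampler)

    for epoch in range(start_epoch, args.epochs):
        sampler.set_epoch(epoch)  # reshuffle differently each epoch
        ddp_model.train()
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            # Mixed precision is only enabled when a GPU is present.
            with torch.autocast(device_type=device.type, enabled=use_cuda):
                preds = ddp_model(xb)
                loss = nn.functional.mse_loss(preds, yb)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        scheduler.step()
        if rank == 0:  # replicas hold identical weights, so one writer is enough
            os.makedirs(args.checkpoint_dir, exist_ok=True)
            torch.save({"model": ddp_model.module.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "scheduler": scheduler.state_dict(),
                        "epoch": epoch}, ckpt_path)
    cleanup()

def main():
    args = parse_args()
    # The default env:// rendezvous needs these when torchrun is not the launcher.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    if args.world_size > 1:
        mp.spawn(train, args=(args,), nprocs=args.world_size, join=True)
    else:
        train(0, args)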

print("Blueprint ready; adapt for actual training script")
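
For the fault-tolerance requirement in the exercise, one detail worth sketching is how to avoid leaving a half-written checkpoint behind if a job dies mid-save. The example below writes to a temporary file and renames it into place, then finds the newest checkpoint to resume from. All names here (`atomic_write`, `latest_checkpoint`, the `epoch_NNNN.pt` naming scheme) are assumptions for illustration, and the demo writes placeholder bytes where a real script would call `torch.save`.

```python
import os
import re
import tempfile

def atomic_write(path, write_fn):
    """Write via a temp file in the same directory, then rename into place."""
    directory = os.path.dirname(os.path.abspath(path))
    os.makedirs(directory, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    os.close(fd)
    try:
        write_fn(tmp_path)          # e.g. lambda p: torch.save(state, p)
        os.replace(tmp_path, path)  # readers see the old file or the new one, never a partial
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

def latest_checkpoint(directory):
    """Return (epoch, path) of the newest epoch_NNNN.pt file, or None."""
    pattern = re.compile(r"epoch_(\d+)\.pt$")
    found = []
    if os.path.isdir(directory):
        for name in os.listdir(directory):
            match = pattern.match(name)
            if match:
                found.append((int(match.group(1)), os.path.join(directory, name)))
    return max(found) if found else None

def write_bytes(payload):
    """Stand-in for torch.save so the demo runs without PyTorch."""
    def _write(target):
        with open(target, "wb") as handle:
            handle.write(payload)
    return _write

demo_dir = tempfile.mkdtemp()
for epoch in range(3):
    atomic_write(os.path.join(demo_dir, f"epoch_{epoch:04d}.pt"), write_bytes(b"state"))

resume = latest_checkpoint(demo_dir)
print("resume from epoch", resume[0] + 1)
```

In a DDP job only rank 0 would call `atomic_write`, while every rank would call `latest_checkpoint` at startup to decide where to resume.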


## Further Reading

- PyTorch Distributed Overview: https://pytorch.org/tutorials/beginner/dist_overview.html
- Megatron-LM and DeepSpeed for hybrid parallelism
- NVIDIA NCCL tuning guide for multi-node setups
- ZeRO (DeepSpeed) for sharding optimizer states, gradients, and parameters across data-parallel workers