# Implementing Data Parallel Training in PyTorch

In [1]:
import os, math, argparse, time
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader, DistributedSampler

In [2]:
@dataclass
class TrainCfg:
    """Hyper-parameters for the toy data-parallel language-model run.

    Model shape, optimisation and synthetic-data settings live together so a
    single object can be handed to every worker process.
    """

    # --- model architecture ---
    vocab_size: int = 512  # number of distinct token ids
    seq_len: int = 64  # tokens per training sequence
    d_model: int = 256  # embedding / residual width
    n_heads: int = 8  # attention heads per block
    n_layers: int = 4  # number of transformer blocks
    d_ff: int = 1024  # hidden width of each MLP
    # --- optimisation ---
    batch_size: int = 32  # sequences per batch on each process
    steps: int = 50  # optimiser steps to run
    lr: float = 0.0003  # AdamW learning rate
    seed: int = 42  # base RNG seed
    # --- synthetic data ---
    dataset_size: int = 8_192  # number of random sequences generated

In [4]:
def setup_distributed_notebook(rank, world_size, backend="nccl", master_addr="localhost", master_port="12355"):
    """Join the default process group from inside a notebook-launched worker.

    Rendezvous uses the ``env://`` init method, so ``MASTER_ADDR`` and
    ``MASTER_PORT`` are written into this process's environment first.
    The launch is assumed to be single-node, so ``rank`` doubles as the local
    rank and the CUDA device index.

    Args:
        rank: Global rank of this process (also used as the GPU index).
        world_size: Total number of participating processes.
        backend: ``torch.distributed`` backend, e.g. ``"nccl"`` or ``"gloo"``.
        master_addr: Host of the rendezvous store.
        master_port: Port of the rendezvous store; ``str`` or ``int``.

    Returns:
        ``(rank, local_rank, world_size, device)``. ``device`` is
        ``cuda:<rank>`` whenever CUDA is used (always for ``"nccl"``, or when
        CUDA is available); otherwise it is ``cpu``.
    """
    # os.environ only accepts strings; allow callers to pass an int port.
    os.environ["MASTER_ADDR"] = str(master_addr)
    os.environ["MASTER_PORT"] = str(master_port)

    dist.init_process_group(
        backend=backend,
        init_method="env://",
        world_size=world_size,
        rank=rank,
    )

    local_rank = rank  # single-node assumption
    # NCCL requires a CUDA device; other backends (e.g. gloo) may run on CPU.
    if backend == "nccl" or torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
    else:
        device = torch.device("cpu")

    return rank, local_rank, world_size, device

def cleanup_distributed():
    """Tear down the default process group if one exists; otherwise do nothing.

    All ranks are lined up with a barrier before the group is destroyed.
    NOTE(review): if a peer rank has already died, this barrier can block
    until the process-group timeout.
    """
    if not dist.is_initialized():
        return
    dist.barrier()
    dist.destroy_process_group()

def set_seed(seed: int, rank: int = 0):
    """Seed torch's CPU RNG and every CUDA device's RNG with ``seed + rank``.

    Offsetting by ``rank`` gives each process its own reproducible stream;
    pass ``rank=0`` everywhere to get identical streams instead.
    """
    effective_seed = seed + rank
    torch.manual_seed(effective_seed)
    torch.cuda.manual_seed_all(effective_seed)

def broadcast_parameters_and_buffers(model, src=0):
    """Overwrite every parameter and buffer with rank ``src``'s copy.

    This makes model state identical on all ranks regardless of how each rank
    initialised it. All parameters are broadcast first, then all buffers, in
    the order ``model.parameters()`` / ``model.buffers()`` yield them, so
    every rank must call this on a structurally identical model.
    """
    state_tensors = [*model.parameters(), *model.buffers()]
    for tensor in state_tensors:
        # ``.data`` so the in-place receive is not tracked by autograd.
        dist.broadcast(tensor.data, src=src)

In [5]:
class ToySequenceDataset(Dataset):
    """Fixed corpus of uniformly random token sequences.

    Each sample is a 1-D integer tensor of length ``seq_len`` with values in
    ``[0, vocab_size)``. The whole corpus is drawn once from a dedicated
    generator seeded with ``seed``, so building the dataset with the same
    arguments always yields the same data. Samples are stored on the CPU even
    when they are generated on ``device``.
    """

    def __init__(self, num_samples=8192, seq_len=64, vocab_size=512, seed=1234, device="cpu"):
        rng = torch.Generator(device=device)
        rng.manual_seed(seed)
        tokens = torch.randint(0, vocab_size, (num_samples, seq_len), generator=rng, device=device)
        self.data = tokens.cpu()
        self.vocab_size = vocab_size
        self.seq_len = seq_len

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return self.data[idx]

def sinusoidal_positional_embedding(max_seq, d_model, device="cpu"):
    """Build the fixed sine/cosine positional-encoding table.

    Column ``2i`` holds ``sin(pos / 10000^(2i / d_model))`` and column ``2i+1``
    holds the matching cosine.

    Args:
        max_seq: Number of positions (rows).
        d_model: Embedding width (columns). May be odd; the final, unpaired
            column is then a sine column.
        device: Device for the returned tensor.

    Returns:
        Float tensor of shape ``[max_seq, d_model]``.
    """
    pe = torch.zeros(max_seq, d_model, device=device)
    pos = torch.arange(0, max_seq, device=device).unsqueeze(1)  # [T, 1]
    # One frequency per sin/cos pair: ceil(d_model / 2) values.
    div = torch.exp(torch.arange(0, d_model, 2, device=device) * (-math.log(10000.0) / d_model))
    angles = pos * div  # [T, ceil(D/2)]
    pe[:, 0::2] = torch.sin(angles)
    # For odd d_model there is one fewer cosine column than frequencies.
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
    return pe  # [T, D]

class SelfAttention(nn.Module):
    """Multi-head causal self-attention with a fused QKV projection.

    Input and output are ``[B, T, d_model]``. One bias-free linear layer
    produces queries, keys and values; attention is computed per head with
    ``F.scaled_dot_product_attention`` under a causal mask, and the heads are
    merged back through a bias-free output projection.
    """

    def __init__(self, d_model=256, n_heads=8):
        super().__init__()
        assert d_model % n_heads == 0
        self.d = d_model
        self.h = n_heads
        self.dh = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)

    def _split_heads(self, t, batch, length):
        # [B, T, D] -> [B, H, T, Dh]
        return t.view(batch, length, self.h, self.dh).transpose(1, 2)

    def forward(self, x):
        batch, length, width = x.shape
        q, k, v = (self._split_heads(part, batch, length) for part in self.qkv(x).chunk(3, dim=-1))
        context = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
        # [B, H, T, Dh] -> [B, T, D]
        merged = context.transpose(1, 2).contiguous().view(batch, length, width)
        return self.proj(merged)

class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear.

    Expands the last dimension from ``d_model`` to ``d_ff`` and projects it
    back; both linear layers carry a bias.
    """

    def __init__(self, d_model=256, d_ff=1024):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        hidden = F.gelu(self.fc1(x))
        return self.fc2(hidden)

class Block(nn.Module):
    """Pre-LayerNorm transformer block.

    Computes ``h = x + attn(ln1(x))`` followed by ``h + mlp(ln2(h))``.
    Submodules are registered in the order ln1, attn, ln2, mlp.
    """

    def __init__(self, d_model=256, n_heads=8, d_ff=1024):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = SelfAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = MLP(d_model, d_ff)

    def forward(self, x):
        h = x + self.attn(self.ln1(x))
        return h + self.mlp(self.ln2(h))

class TinyTransformerLM(nn.Module):
    """Small decoder-only transformer language model.

    Token embeddings plus fixed sinusoidal positional encodings feed a stack
    of ``Block``s, a final LayerNorm, and a bias-free projection to
    vocabulary logits.

    Args:
        vocab_size: Number of token ids.
        d_model: Embedding / residual width.
        n_heads: Attention heads per block.
        n_layers: Number of blocks.
        d_ff: Hidden width of each block's MLP.
        max_seq: Longest sequence the positional table supports.
    """

    def __init__(self, vocab_size=512, d_model=256, n_heads=8, n_layers=4, d_ff=1024, max_seq=64):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        pe = sinusoidal_positional_embedding(max_seq, d_model)
        # Non-persistent buffer: follows .to(device) but is not trained and is
        # not stored in state_dict (it is rebuilt here on construction).
        self.register_buffer("pos_emb", pe, persistent=False)
        self.blocks = nn.ModuleList([Block(d_model, n_heads, d_ff) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, idx):
        """Map token ids ``[B, T]`` to logits ``[B, T, vocab_size]``.

        Raises:
            ValueError: if ``T`` exceeds the ``max_seq`` the model was built with.
        """
        _, T = idx.shape
        max_seq = self.pos_emb.size(0)
        if T > max_seq:
            # Otherwise the slice below silently truncates and the addition
            # fails with an opaque broadcasting error.
            raise ValueError(f"sequence length {T} exceeds max_seq={max_seq}")
        x = self.tok_emb(idx) + self.pos_emb[:T, :].unsqueeze(0)
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        return self.head(x)  # [B, T, V]


In [6]:
def sync_grads_allreduce(model, world_size):
    """Average gradients across ranks in place, one all-reduce per parameter.

    Every existing ``.grad`` is summed over all ranks and then divided by
    ``world_size``. Parameters whose ``.grad`` is ``None`` are skipped.
    NOTE(review): every rank must skip the same parameters, otherwise the
    per-parameter collectives are mismatched across ranks.
    """
    present_grads = (p.grad for p in model.parameters() if p.grad is not None)
    for grad in present_grads:
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        grad.div_(world_size)

@torch.no_grad()
def sync_grads_rs_ag(model, world_size, device):
    """Average gradients across ranks via explicit reduce-scatter + all-gather.

    This is the two-phase decomposition a ring all-reduce performs
    internally. It is done by hand here on one flattened buffer to
    demonstrate both collectives:

    1. Flatten the gradients of all trainable parameters into one 1-D buffer.
       The buffer is zero-padded so its length divides evenly by ``world_size``.
    2. ``reduce_scatter``: split the buffer into ``world_size`` chunks. Each
       rank receives the cross-rank sum of its own chunk and divides it by
       ``world_size``.
    3. ``all_gather``: every rank collects all averaged chunks and writes them
       back into ``p.grad``.

    Only parameters with ``requires_grad=True`` take part, so frozen
    parameters are never given a ``.grad``. A trainable parameter whose
    ``.grad`` is ``None`` on this rank contributes zeros, and afterwards holds
    the averaged gradient. With ``world_size == 1`` averaging is the identity,
    so the function returns without communicating.

    Args:
        model: Module whose gradients are synchronised in place.
        world_size: Number of ranks in the default process group.
        device: Device for the zero-filled placeholder gradients.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    if world_size == 1 or not params:
        return

    pieces = []
    for p in params:
        if p.grad is None:
            pieces.append(torch.zeros(p.numel(), device=device, dtype=p.dtype))
        else:
            pieces.append(p.grad.detach().contiguous().view(-1))

    flat = torch.cat(pieces)  # [N]
    n = flat.numel()
    pad = (-n) % world_size
    if pad:
        flat = F.pad(flat, (0, pad))

    chunk_size = flat.numel() // world_size
    chunks = list(torch.split(flat, chunk_size))
    # reduce_scatter: sum chunk r over all ranks and deliver the sum to rank r.
    out_chunk = torch.empty_like(chunks[0])
    dist.reduce_scatter(out_chunk, chunks, op=dist.ReduceOp.SUM)
    out_chunk.div_(world_size)  # sum -> mean

    # all_gather the averaged chunks back to every rank, then drop the padding.
    gathered = [torch.empty_like(out_chunk) for _ in range(world_size)]
    dist.all_gather(gathered, out_chunk)
    flat_avg = torch.cat(gathered, dim=0)[:n]

    # Write back into p.grad in the same order the buffer was built.
    offset = 0
    for p in params:
        numel = p.numel()
        avg = flat_avg[offset:offset + numel].view_as(p)
        if p.grad is None:
            # copy=True so the new grad does not alias the shared buffer and
            # always matches the parameter's dtype.
            p.grad = avg.to(p.dtype, copy=True)
        else:
            p.grad.copy_(avg)
        offset += numel


In [7]:
def make_model(cfg: TrainCfg):
    """Instantiate a fresh ``TinyTransformerLM`` sized by ``cfg``.

    ``cfg.seq_len`` becomes the model's ``max_seq``, so the positional table
    covers exactly one training sequence.
    """
    architecture = dict(
        vocab_size=cfg.vocab_size,
        d_model=cfg.d_model,
        n_heads=cfg.n_heads,
        n_layers=cfg.n_layers,
        d_ff=cfg.d_ff,
        max_seq=cfg.seq_len,
    )
    return TinyTransformerLM(**architecture)

def iterate_batches(dataset, batch_size, sampler, pin_memory=False, num_workers=0):
    """Wrap ``dataset`` in a ``DataLoader`` that drops the final partial batch.

    Despite the name, this returns the ``DataLoader`` itself (a re-iterable
    of batches). Sample order is delegated to ``sampler``.
    """
    loader_options = {
        "batch_size": batch_size,
        "sampler": sampler,
        "pin_memory": pin_memory,
        "num_workers": num_workers,
        "drop_last": True,
    }
    return DataLoader(dataset, **loader_options)

def lm_loss(logits, tokens):
    """Mean next-token cross-entropy.

    The logits at position ``t`` are scored against ``tokens[:, t + 1]``; the
    last position has no target and is dropped.

    Args:
        logits: ``[B, T, V]`` model outputs.
        tokens: ``[B, T]`` input token ids.
    """
    shifted_logits = logits[:, :-1, :]  # [B, T-1, V]
    targets = tokens[:, 1:]  # [B, T-1]
    batch, steps, vocab = shifted_logits.shape
    flat_logits = shifted_logits.reshape(batch * steps, vocab)
    flat_targets = targets.reshape(batch * steps)
    return F.cross_entropy(flat_logits, flat_targets)


In [8]:
def _endless_batches(loader, sampler):
    """Yield batches from ``loader`` forever, advancing ``sampler``'s epoch on each pass.

    Raises:
        RuntimeError: if one full pass over ``loader`` yields no batches, which
            would otherwise loop forever.
    """
    epoch = 0
    while True:
        sampler.set_epoch(epoch)
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise RuntimeError(
                "DataLoader yielded no batches; dataset_size is too small for "
                "batch_size * world_size"
            )
        epoch += 1


def train_ddp_worker(rank, world_size, cfg, mode="ddp", sync_method="allreduce", work_dir="./runs", save_tag=""):
    """Train the toy LM for exactly ``cfg.steps`` optimizer steps in one worker process.

    Intended as the target of one process per GPU. The worker does the following:
    - joins the process group;
    - builds a model and broadcasts rank 0's weights to every rank;
    - trains on this rank's shard of the synthetic dataset, re-iterating the
      shard (with the sampler's epoch advanced) if ``cfg.steps`` exceeds one pass;
    - has rank 0 save a checkpoint;
    - always tears the process group down on exit.

    Args:
        rank: Global rank of this process (also its GPU index).
        world_size: Number of worker processes.
        cfg: ``TrainCfg`` with model / optimizer / data settings.
        mode: ``"ddp"`` wraps the model in ``DistributedDataParallel``;
            ``"manual"`` averages gradients with ``sync_grads_allreduce`` or
            ``sync_grads_rs_ag``.
        sync_method: ``"allreduce"`` or ``"rs_ag"``; only used when
            ``mode == "manual"``.
        work_dir: Directory the rank-0 checkpoint is written to.
        save_tag: Checkpoint file stem; defaults to ``"<mode>_<sync>"``.

    Raises:
        ValueError: for an unknown ``mode`` or, in manual mode, ``sync_method``.
    """
    # Validate before joining the group so a bad argument cannot leave peers
    # waiting in a collective.
    if mode not in ("ddp", "manual"):
        raise ValueError(f"unknown mode {mode!r}; expected 'ddp' or 'manual'")
    if mode == "manual" and sync_method not in ("allreduce", "rs_ag"):
        raise ValueError(f"unknown sync_method {sync_method!r}; expected 'allreduce' or 'rs_ag'")

    sync_label = sync_method if mode == "manual" else "default"
    try:
        rank, local_rank, world_size, device = setup_distributed_notebook(rank, world_size)
        set_seed(cfg.seed, rank)

        model = make_model(cfg).to(device)
        # Ranks were seeded differently, so copy rank 0's initial state everywhere.
        broadcast_parameters_and_buffers(model, src=0)
        dist.barrier()

        # The fixed seed gives every rank the same dataset; the sampler then
        # hands each rank its own shard.
        dataset = ToySequenceDataset(
            num_samples=cfg.dataset_size,
            seq_len=cfg.seq_len,
            vocab_size=cfg.vocab_size,
            seed=1234,
        )
        sampler = DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=False,
            drop_last=True,
        )
        loader = iterate_batches(dataset, cfg.batch_size, sampler, pin_memory=True, num_workers=0)

        if mode == "ddp":
            train_model = DDP(model, device_ids=[local_rank], output_device=local_rank)
        else:
            train_model = model
        optimizer = torch.optim.AdamW(train_model.parameters(), lr=cfg.lr)

        train_model.train()
        t0 = time.time()
        batches = _endless_batches(loader, sampler)

        for step in range(cfg.steps):
            batch = next(batches).to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            logits = train_model(batch)
            loss = lm_loss(logits, batch)
            loss.backward()

            # In manual mode gradients are averaged here instead of by DDP hooks.
            if mode == "manual":
                if sync_method == "allreduce":
                    sync_grads_allreduce(model, world_size)
                else:
                    sync_grads_rs_ag(model, world_size, device)

            optimizer.step()

            # Every rank evaluates the same condition, so the all_reduce below
            # is entered by all ranks or by none.
            if step % 10 == 0 or step == cfg.steps - 1:
                with torch.no_grad():
                    summed = loss.detach().clone()
                    dist.all_reduce(summed, op=dist.ReduceOp.SUM)
                    mean_loss = summed.item() / world_size
                if rank == 0:
                    elapsed = time.time() - t0
                    print(f"[{mode}/{sync_label}] "
                          f"step {step:04d}  loss {mean_loss:.4f}  time {elapsed:.2f}s")

        if rank == 0:
            os.makedirs(work_dir, exist_ok=True)
            tag = save_tag if save_tag else f"{mode}_{sync_label}"
            checkpoint_path = os.path.join(work_dir, f"{tag}.pt")
            # Save the unwrapped module so keys carry no DDP "module." prefix.
            torch.save({
                "model": model.state_dict(),
                "cfg": cfg.__dict__
            }, checkpoint_path)
            print(f"[{mode}] Saved checkpoint to {checkpoint_path}")

        dist.barrier()

    finally:
        cleanup_distributed()

In [12]:
def _wait_for_workers(processes, poll_interval=0.5):
    """Block until every process in ``processes`` exits; raise as soon as any fails.

    Raises:
        RuntimeError: when a process exits with a non-zero code.
    """
    pending = list(processes)
    while pending:
        for p in list(pending):
            p.join(timeout=poll_interval)
            if p.exitcode is None:
                continue  # still running
            pending.remove(p)
            if p.exitcode != 0:
                raise RuntimeError(f"Process failed with exit code {p.exitcode}")


def train_multi_gpu(
    num_gpus=None,
    mode="ddp",
    sync_method="allreduce",
    steps=50,
    batch_size=32,
    seq_len=64,
    vocab_size=512,
    work_dir="./runs",
    save_tag="",
    seed=42
):
    """Launch one ``train_ddp_worker`` process per GPU from a notebook and wait for them.

    If any worker fails, the remaining workers are terminated rather than left
    blocked in a collective waiting for their dead peer.

    Args:
        num_gpus: Number of GPUs to use (default: all available)
        mode: "ddp" or "manual"
        sync_method: "allreduce" or "rs_ag" (only for manual mode)
        steps: Number of training steps
        batch_size: Batch size per GPU
        seq_len: Sequence length
        vocab_size: Vocabulary size
        work_dir: Directory to save checkpoints
        save_tag: Tag for checkpoint filename
        seed: Random seed

    Raises:
        ValueError: unknown ``mode`` / ``sync_method``, or ``num_gpus`` is
            negative or larger than the number of visible CUDA devices.
        RuntimeError: no GPUs to use, or a worker exited with a non-zero code.

    Example:
        # In a Jupyter notebook cell:
        train_multi_gpu(num_gpus=2, mode="ddp", steps=100)
    """
    if mode not in ("ddp", "manual"):
        raise ValueError(f"unknown mode {mode!r}; expected 'ddp' or 'manual'")
    if mode == "manual" and sync_method not in ("allreduce", "rs_ag"):
        raise ValueError(f"unknown sync_method {sync_method!r}; expected 'allreduce' or 'rs_ag'")

    available = torch.cuda.device_count()
    if num_gpus is None:
        num_gpus = available

    if num_gpus == 0:
        raise RuntimeError("No CUDA devices available!")
    if num_gpus < 0 or num_gpus > available:
        raise ValueError(f"num_gpus={num_gpus} is invalid: {available} CUDA device(s) visible")

    print(f"🚀 Launching training on {num_gpus} GPUs...")
    print(f"   Mode: {mode}")
    if mode == "manual":
        print(f"   Sync method: {sync_method}")
    print(f"   Steps: {steps}")
    print(f"   Batch size per GPU: {batch_size}")
    print()

    cfg = TrainCfg(
        vocab_size=vocab_size,
        seq_len=seq_len,
        batch_size=batch_size,
        steps=steps,
        seed=seed
    )

    # 'fork' presumably because 'spawn' must import the target by name, which
    # fails for functions defined in a notebook. 'fork' is POSIX-only.
    # NOTE(review): the parent must not have created a CUDA context before
    # forking; confirm nothing earlier in the notebook does.
    ctx = mp.get_context('fork')
    processes = []
    try:
        for rank in range(num_gpus):
            p = ctx.Process(
                target=train_ddp_worker,
                args=(rank, num_gpus, cfg, mode, sync_method, work_dir, save_tag)
            )
            p.start()
            processes.append(p)
        _wait_for_workers(processes)
    finally:
        # On failure (or interrupt) surviving workers may be stuck in a
        # collective forever; stop them and reap every child.
        for p in processes:
            if p.is_alive():
                p.terminate()
        for p in processes:
            p.join()

    print(f"\n✅ Training complete!")



In [13]:
# Launch DDP training with the default settings on every visible CUDA device.
train_multi_gpu()

🚀 Launching training on 4 GPUs...
   Mode: ddp
   Steps: 50
   Batch size per GPU: 32





[ddp/default] step 0000  loss 6.4049  time 2.07s
[ddp/default] step 0010  loss 6.2795  time 2.21s
[ddp/default] step 0020  loss 6.2496  time 2.33s
[ddp/default] step 0030  loss 6.2507  time 2.45s
[ddp/default] step 0040  loss 6.2449  time 2.57s




[ddp/default] step 0049  loss 6.2457  time 2.69s




[ddp] Saved checkpoint to ./runs/ddp_default.pt





✅ Training complete!
