- Manual training loop (no `Trainer`)
- Deterministic seed setup
- Fixed tokenizer + max sequence length
- Gradient accumulation
- AMP (bf16 / fp16)
- Gradient clipping
- LR scheduler
- Checkpoint + resume
- Log loss, grad norm, param norm, LR, GPU memory
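
The baseline items above could be sketched roughly as follows. This is a minimal, CPU-only illustration, not the project's actual loop: the toy `nn.Linear` model, the synthetic data, the warmup schedule, and the names `set_seed` and `train` are all assumptions made for the example. AMP, checkpointing, and GPU memory logging are left out to keep it short.

```python
import random

import torch
from torch import nn


def set_seed(seed: int) -> None:
    """Seed the Python and PyTorch RNGs so repeated runs match."""
    random.seed(seed)
    torch.manual_seed(seed)


def train(steps: int = 8, accum_steps: int = 2, max_grad_norm: float = 1.0, seed: int = 0):
    set_seed(seed)
    model = nn.Linear(16, 1)  # toy stand-in for the real model
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
    # Hypothetical schedule: linear warmup over the first 4 optimizer steps.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda s: min(1.0, (s + 1) / 4)
    )
    history = []
    for step in range(steps):
        optimizer.zero_grad(set_to_none=True)
        step_loss = 0.0
        for _ in range(accum_steps):
            x = torch.randn(4, 16)  # synthetic micro-batch
            y = x.sum(dim=1, keepdim=True)
            # Scale each micro-batch loss so accumulated grads form an average.
            loss = nn.functional.mse_loss(model(x), y) / accum_steps
            loss.backward()
            step_loss += loss.item()
        # clip_grad_norm_ returns the total norm measured before clipping.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        lr = optimizer.param_groups[0]["lr"]  # LR used for this step
        optimizer.step()
        scheduler.step()
        param_norm = torch.sqrt(sum(p.detach().pow(2).sum() for p in model.parameters()))
        history.append({
            "step": step,
            "loss": step_loss,
            "grad_norm": float(grad_norm),
            "param_norm": float(param_norm),
            "lr": lr,
        })
    return history
```

Calling `train()` twice with the same seed should produce identical logs, which is a cheap determinism check before moving on to multi-GPU work.
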
- PyTorch DDP reference run saved for comparison
- Deterministic dataset sharding per rank
- Same number of steps per rank
- Per-epoch seed control
- Verify no sample overlap across ranks
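
One way the sharding requirements above might be met is sketched below in plain Python. The function name `shard_indices` and its parameters are hypothetical. Dropping the tail so every rank gets the same count is an assumption of this sketch (PyTorch's `DistributedSampler` pads by default instead). Seeding the shuffle with `seed + epoch` follows the same idea as `DistributedSampler.set_epoch`.

```python
import random


def shard_indices(num_samples: int, rank: int, world_size: int, epoch: int, seed: int = 0):
    """Return the sample indices assigned to `rank` for `epoch`."""
    indices = list(range(num_samples))
    # The permutation depends only on (seed, epoch), so all ranks agree on it.
    random.Random(seed + epoch).shuffle(indices)
    # Drop the tail so every rank sees the same number of samples (and steps).
    usable = (num_samples // world_size) * world_size
    indices = indices[:usable]
    # Strided slice: rank r takes positions r, r + world_size, ...
    return indices[rank::world_size]


def check_no_overlap(shards) -> bool:
    """True if no sample index appears on more than one rank."""
    flat = [i for shard in shards for i in shard]
    return len(flat) == len(set(flat))
```

For example, `[shard_indices(103, r, 4, epoch=0) for r in range(4)]` yields four disjoint shards of 25 indices each, with three samples dropped.
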
- Broadcast model parameters from rank 0
- Backward pass on each rank
- All-reduce gradients
- Divide gradients by `world_size`
- Optimizer step after synchronization
- Weights match PyTorch DDP within tolerance (rtol=1e-5, atol=1e-6)
- param_norm diff < 1e-6
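
The synchronization steps above can be illustrated with a single-process simulation in which "ranks" are just list entries. This is only a sketch: the scalar linear model, plain SGD, and the helper names (`local_grad`, `all_reduce_sum`, `ddp_step`) are assumptions, and a real implementation would call a collective such as `torch.distributed.all_reduce` instead of the simulated one here.

```python
def local_grad(w, b, xs, ys):
    """MSE gradient of y = w*x + b, averaged over one rank's shard."""
    n = len(xs)
    gw = sum(2 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
    gb = sum(2 * (w * x + b - y) for x, y in zip(xs, ys)) / n
    return [gw, gb]


def all_reduce_sum(per_rank):
    """Simulated all-reduce: every rank ends up with the elementwise sum."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]


def ddp_step(params, shards, lr):
    """One data-parallel SGD step; returns the updated params held by each rank."""
    world_size = len(shards)
    # Backward pass on each rank, using only that rank's shard.
    grads = [local_grad(params[0], params[1], xs, ys) for xs, ys in shards]
    # All-reduce, then divide by world_size to obtain the mean gradient.
    summed = all_reduce_sum(grads)
    averaged = [[g / world_size for g in rank_grads] for rank_grads in summed]
    # Every rank applies the same update, so replicas stay identical.
    return [[p - lr * g for p, g in zip(params, rank_grads)] for rank_grads in averaged]
```

With equally sized shards, the averaged gradient equals the full-batch gradient, so the result can be compared against a single-process step using the tolerances listed above.
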
- Register backward hooks per parameter
- Gradient bucketing
- Async all-reduce per bucket
- Overlap backward compute with communication
- Single sync point before optimizer step
- `no_sync()` context manager for gradient accumulation
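
The bucketing and hook logic above might be prototyped without any GPUs, as in the sketch below. Everything here is hypothetical: `build_buckets` and `BucketTracker` are illustrative names, sizes are given in bytes by hand, and the "launch" is only recorded rather than issued as an async all-reduce. Filling buckets in reverse registration order is an assumption based on gradients usually becoming ready in roughly that order during backward.

```python
def build_buckets(param_sizes, bucket_cap_bytes):
    """Group (name, nbytes) pairs into buckets, walking parameters in reverse order."""
    buckets, current, current_bytes = [], [], 0
    for name, nbytes in reversed(param_sizes):
        # Close the current bucket if adding this parameter would exceed the cap.
        if current and current_bytes + nbytes > bucket_cap_bytes:
            buckets.append(current)
            current, current_bytes = [], 0
        current.append(name)
        current_bytes += nbytes
    if current:
        buckets.append(current)
    return buckets


class BucketTracker:
    """Counts ready gradients per bucket and records when each bucket would launch."""

    def __init__(self, buckets):
        self.bucket_of = {name: i for i, bucket in enumerate(buckets) for name in bucket}
        self.pending = [len(bucket) for bucket in buckets]
        self.launched = []  # stand-in for a list of async all-reduce handles

    def on_grad_ready(self, name):
        """Would be called from the parameter's backward hook."""
        i = self.bucket_of[name]
        self.pending[i] -= 1
        if self.pending[i] == 0:
            # Real code would start an async all-reduce for bucket i here.
            self.launched.append(i)

    def wait_all(self):
        """Single sync point before the optimizer step."""
        if any(p != 0 for p in self.pending):
            raise RuntimeError("some gradients never arrived")
        return list(self.launched)
```

Because a bucket launches as soon as its last gradient arrives, communication for late layers can overlap with backward compute for earlier ones.
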
- Partition optimizer state (e.g. Adam momentum, variance) across ranks
- Each rank holds only `1 / world_size` of optimizer state per parameter
- Gather optimizer state on demand when applying updates
- Reduce model-state memory per rank by up to ~4× for Adam-style optimizers under mixed precision (the limit as `world_size` grows)
- Verify training matches DDP baseline (same loss trajectory, checkpoint parity)
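
The partitioning idea above can be sketched as a single-process simulation over a flat parameter list. To keep it short, this example swaps Adam for SGD with momentum, so each rank holds one state buffer per owned element instead of two. The names `partition`, `momentum_update`, and `zero1_step` are hypothetical, and the "all-gather" is simulated by concatenating slices.

```python
def partition(n, world_size):
    """Contiguous [start, end) ranges; early ranks absorb the remainder."""
    base, extra = divmod(n, world_size)
    bounds, start = [], 0
    for rank in range(world_size):
        end = start + base + (1 if rank < extra else 0)
        bounds.append((start, end))
        start = end
    return bounds


def momentum_update(params, grads, momentum, lr=0.1, beta=0.9):
    """SGD with momentum over equal-length lists; returns (new_params, new_momentum)."""
    new_m = [beta * m + g for m, g in zip(momentum, grads)]
    new_p = [p - lr * m for p, m in zip(params, new_m)]
    return new_p, new_m


def zero1_step(params, grads, shard_momentum, world_size):
    """Each rank updates only its slice, then the updated slices are gathered."""
    gathered, new_states = [], []
    for rank, (s, e) in enumerate(partition(len(params), world_size)):
        # Rank `rank` keeps optimizer state only for params[s:e].
        p, m = momentum_update(params[s:e], grads[s:e], shard_momentum[rank])
        new_states.append(m)
        gathered.extend(p)  # simulated all-gather of updated parameters
    return gathered, new_states
```

Comparing `zero1_step` against a single unsharded `momentum_update` over the full list is the toy equivalent of the DDP parity check in the last item.
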
- Partition gradients so each rank owns `1 / world_size` of gradients
- Reduce each gradient slice only to the rank that owns it (reduce-scatter instead of all-reduce)
- Combine with ZeRO-1: partition both gradients and optimizer state
- Further reduce memory (gradient buffers no longer replicated)
- Verify correctness vs DDP / ZeRO-1
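
The difference between all-reduce and reduce-scatter can be shown with a small simulation where each rank's flat gradient is a Python list. This is a sketch under stated assumptions: the function names are hypothetical, and the flat gradient is assumed to be padded to a multiple of `world_size`. A real implementation would use a collective such as `torch.distributed.reduce_scatter`.

```python
def all_reduce_mean(per_rank_grads):
    """Reference: the full averaged gradient that every rank would hold under DDP."""
    world_size = len(per_rank_grads)
    return [sum(vals) / world_size for vals in zip(*per_rank_grads)]


def reduce_scatter_mean(per_rank_grads):
    """Rank r receives only slice r of the averaged gradient."""
    world_size = len(per_rank_grads)
    n = len(per_rank_grads[0])
    if n % world_size != 0:
        raise ValueError("sketch assumes the flat gradient is padded to a multiple of world_size")
    chunk = n // world_size
    shards = []
    for rank in range(world_size):
        lo, hi = rank * chunk, (rank + 1) * chunk
        shards.append([
            sum(g[i] for g in per_rank_grads) / world_size for i in range(lo, hi)
        ])
    return shards
```

Concatenating the shards returned by `reduce_scatter_mean` should reproduce `all_reduce_mean` exactly, while each rank stores only `1 / world_size` of the result.
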
- Partition model parameters across ranks; each rank holds `1 / world_size`
- All-gather parameters (or submodules) on demand before forward
- Free gathered parameters after backward for that layer (or use stream/cache)
- Combine with ZeRO-1 + ZeRO-2 for full redundancy removal
- Enable training models that do not fit on a single GPU
- Verify correctness and memory scaling with world size
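
For the memory-scaling check, a back-of-the-envelope model can give expected numbers to compare measurements against. The sketch below assumes mixed-precision Adam accounting of 2 bytes per parameter for fp16 weights, 2 for fp16 gradients, and 12 for fp32 optimizer state (master weights, momentum, variance). It ignores activations, buffers, and fragmentation, and the function name `model_state_bytes` is hypothetical.

```python
def model_state_bytes(num_params: int, world_size: int, stage: int) -> float:
    """Approximate per-rank model-state memory in bytes for ZeRO stage 0 to 3."""
    params = 2 * num_params   # fp16 parameters
    grads = 2 * num_params    # fp16 gradients
    optim = 12 * num_params   # fp32 master weights + Adam momentum + variance
    if stage >= 1:
        optim /= world_size   # ZeRO-1: partition optimizer state
    if stage >= 2:
        grads /= world_size   # ZeRO-2: also partition gradients
    if stage >= 3:
        params /= world_size  # ZeRO-3: also partition parameters
    return params + grads + optim


if __name__ == "__main__":
    for stage in range(4):
        gb = model_state_bytes(7_500_000_000, 64, stage) / 1e9
        print(f"stage {stage}: {gb:.2f} GB per rank")
```

Under these assumptions the stage 0 to stage 1 ratio approaches 4 as `world_size` grows, and stage 3 scales as `1 / world_size`.
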