# DeepSpeed ZeRO: A Complete Guide to Sharded Data Parallelism

## Fine-tuning a 7B Language Model with ZeRO Stages 0–3

---

This notebook is a deep dive into **DeepSpeed ZeRO** (Zero Redundancy Optimizer), one of the most widely used
techniques for training large language models whose training state doesn't fit in a single GPU's memory.

We will:
1. Understand **why** ZeRO exists — the memory problem in distributed training
2. Learn **how** each ZeRO stage works — what gets sharded and why
3. Train a **real 7B model** (Pythia-6.9B) on instruction-following data (Alpaca)
4. Compare ZeRO stages 0→3 on **memory, throughput, and convergence**
5. Run **inference** with the fine-tuned model to verify quality

### Hardware
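- **Recommended:** 2× NVIDIA B200 (192 GB each) or 2× H100/A100 80GB. On 80 GB cards, expect the ZeRO-0 baseline to run out of memory, since it needs about 112 GB of model state per GPU (see Part 1).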
- **Minimum:** 2× A100 40GB (use Pythia-2.8B instead of 6.9B)

### RunPod Configuration
```
GPU Pod: 2× B200 192GB  (or 2× H100 80GB, or 2× A100 80GB)
Template: RunPod PyTorch 2.1+
Disk: 100 GB (model weights + checkpoints)
```


---
# Part 1: The Memory Problem

## Where Does GPU Memory Go During Training?

When you train a model with the Adam optimizer in mixed precision (bf16), each parameter consumes:

| Component | Bytes per Parameter | For 7B Model | For 70B Model |
|-----------|-------------------|-------------|---------------|
| Model parameters (bf16) | 2 | 14 GB | 140 GB |
| Gradients (bf16) | 2 | 14 GB | 140 GB |
| Adam momentum `m` (fp32) | 4 | 28 GB | 280 GB |
| Adam variance `v` (fp32) | 4 | 28 GB | 280 GB |
| Adam master weights (fp32) | 4 | 28 GB | 280 GB |
| **Total per GPU (no sharding)** | **16** | **112 GB** | **1,120 GB** |

On top of this comes **activation memory**, which depends on batch size and sequence length.

### The Insight

In standard DDP (DistributedDataParallel), **every GPU holds a complete copy** of all this state.
With 8 GPUs, you have 8 identical copies of the optimizer states — that's **7 copies wasted**.

ZeRO eliminates this redundancy by **partitioning** (sharding) state across GPUs.

## The 4 Stages of ZeRO

```
┌─────────────────────────────────────────────────────────────────┐
│                     What Each GPU Stores                         │
├──────────┬────────────┬────────────┬────────────────────────────┤
│          │ Parameters │ Gradients  │ Optimizer (m, v, master)   │
├──────────┼────────────┼────────────┼────────────────────────────┤
│ ZeRO-0   │ FULL copy  │ FULL copy  │ FULL copy                  │
│ (= DDP)  │            │            │                            │
├──────────┼────────────┼────────────┼────────────────────────────┤
│ ZeRO-1   │ FULL copy  │ FULL copy  │ SHARDED (1/N)              │
│          │            │            │ Total: up to 4× smaller    │
├──────────┼────────────┼────────────┼────────────────────────────┤
│ ZeRO-2   │ FULL copy  │ SHARDED    │ SHARDED (1/N)              │
│          │            │ (1/N)      │ Total: up to 8× smaller    │
├──────────┼────────────┼────────────┼────────────────────────────┤
│ ZeRO-3   │ SHARDED    │ SHARDED    │ SHARDED (1/N)              │
│          │ (1/N)      │ (1/N)      │ Total: N× smaller          │
└──────────┴────────────┴────────────┴────────────────────────────┘
```

### Memory per GPU (7B model, 2 GPUs, bf16)

```
ZeRO-0:  14 + 14 + 84 = 112 GB  (everything replicated)
ZeRO-1:  14 + 14 + 42 =  70 GB  (optimizer sharded → 84/2 = 42)
ZeRO-2:  14 +  7 + 42 =  63 GB  (+ grads sharded → 14/2 = 7)
ZeRO-3:   7 +  7 + 42 =  56 GB  (+ params sharded → 14/2 = 7)
```
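
The arithmetic above can be reproduced with a few lines of Python. This is a minimal sketch that counts model state only (no activations or communication buffers); the function name `model_state_gb` is hypothetical and the byte counts are the ones from the table in this section.

```python
# Bytes per parameter, taken from the table above:
# bf16 params, bf16 grads, and fp32 Adam state (m + v + master weights = 12 bytes).
BYTES_PER_PARAM = {"params": 2, "grads": 2, "optimizer": 12}


def model_state_gb(n_params: float, n_gpus: int, stage: int) -> float:
    """Per-GPU model-state memory in GB (1 GB = 1e9 bytes), activations excluded."""
    if stage not in (0, 1, 2, 3):
        raise ValueError(f"unknown ZeRO stage: {stage}")
    # Which components are sharded (divided by n_gpus) at each stage.
    sharded = {
        "optimizer": stage >= 1,
        "grads": stage >= 2,
        "params": stage >= 3,
    }
    total_bytes = 0.0
    for name, per_param in BYTES_PER_PARAM.items():
        divisor = n_gpus if sharded[name] else 1
        total_bytes += n_params * per_param / divisor
    return total_bytes / 1e9


for stage in range(4):
    print(f"ZeRO-{stage}: {model_state_gb(7e9, n_gpus=2, stage=stage):.0f} GB")
```

Changing `n_gpus` shows how quickly the sharded components shrink, while the replicated ones stay constant.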

## Communication Cost at Each Stage

There's no free lunch — less memory means more communication:

| Stage | Communication Pattern | Volume (per step) |
|-------|----------------------|-------------------|
| ZeRO-0 | All-reduce on gradients (during backward) | 2Φ |
| ZeRO-1 | All-reduce on gradients + all-gather optimizer results | 2Φ |
| ZeRO-2 | Reduce-scatter gradients + all-gather updated params | 2Φ |
| ZeRO-3 | All-gather params before fwd+bwd + reduce-scatter grads | 3Φ |

Here Φ is the total size of the parameters in bytes. ZeRO-3 communicates **50% more** than ZeRO-0/1/2,
which is why it is slower: the extra communication is the price of lower memory.
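
To put rough numbers on the table, here is a small sketch that converts the Φ multipliers into gigabytes. It assumes ring-style collectives, where each GPU sends a fraction (N-1)/N of the nominal volume; that factor and the function name `comm_volume_gb` are assumptions for illustration, not something measured in this notebook.

```python
def comm_volume_gb(param_bytes: float, stage: int, n_gpus: int) -> float:
    """Approximate data sent per GPU per step, in GB (1 GB = 1e9 bytes)."""
    if stage not in (0, 1, 2, 3):
        raise ValueError(f"unknown ZeRO stage: {stage}")
    phi_multiplier = 3 if stage == 3 else 2   # 2Φ for stages 0-2, 3Φ for stage 3
    ring_factor = (n_gpus - 1) / n_gpus       # assumption: ring collectives; tends to 1 for large N
    return phi_multiplier * ring_factor * param_bytes / 1e9


phi = 7e9 * 2  # 7B parameters stored in bf16
for stage in range(4):
    print(f"ZeRO-{stage}: {comm_volume_gb(phi, stage, n_gpus=2):.0f} GB per step")
```

Whatever the GPU count, the ratio between ZeRO-3 and the other stages stays at 1.5 in this model.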


---
# Part 2: Environment Setup


In [None]:
%%bash
# Install dependencies
pip install -q deepspeed transformers datasets tokenizers accelerate matplotlib

# Optional: Flash Attention 2 (strongly recommended for B200/H100)
pip install -q flash-attn --no-build-isolation 2>/dev/null || echo 'Flash Attention not installed (non-critical)'

echo ''
echo '=== Environment Check ==='
python -c '
import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.version.cuda}")
print(f"GPUs: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
    name = torch.cuda.get_device_name(i)
    mem = torch.cuda.get_device_properties(i).total_memory / 1e9
    print(f"  GPU {i}: {name} ({mem:.0f} GB)")
'
python -c 'import deepspeed; print(f"DeepSpeed: {deepspeed.__version__}")'
echo '=== Ready! ==='


### Pre-download model and data (so timing measurements are clean)

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

MODEL_NAME = 'EleutherAI/pythia-6.9b-deduped'
# For smaller GPUs (A100 40GB), use:
# MODEL_NAME = 'EleutherAI/pythia-2.8b-deduped'

print(f'Downloading tokenizer for {MODEL_NAME}...')
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print(f'Downloading model weights for {MODEL_NAME}...')
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype='auto')
n_params = sum(p.numel() for p in model.parameters())
print(f'  Parameters: {n_params:,} ({n_params/1e9:.2f}B)')
del model  # free memory

print('Downloading Alpaca dataset...')
ds = load_dataset('tatsu-lab/alpaca')
print(f'  Training examples: {len(ds["train"]):,}')

import torch; torch.cuda.empty_cache()
print('\n✓ Everything cached and ready!')


---
# Part 3: Understanding the Training Script

Before running experiments, let's walk through the key pieces of `deepspeed_train.py`.

## 3.1 — DeepSpeed Initialization

The core of DeepSpeed is `deepspeed.initialize()`. This single call:
1. Reads your JSON config to determine the ZeRO stage
2. Wraps your model for distributed training
3. Creates the optimizer with the appropriate sharding
4. Sets up gradient communication hooks

```python
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,                    # Your vanilla PyTorch model
    model_parameters=model.parameters(),
    config=args.deepspeed_config,   # JSON file controls EVERYTHING
)
```

The returned `model_engine` replaces your model. You call:
- `model_engine(input_ids=...)` for forward pass
- `model_engine.backward(loss)` instead of `loss.backward()`
- `model_engine.step()` instead of `optimizer.step()`
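
Put together, the three calls above form a training loop along these lines. This is a sketch, not the contents of `deepspeed_train.py`: the function name `train` is hypothetical, and it assumes each batch is a dict of keyword arguments that includes `labels`, so that the wrapped Hugging Face model returns an output with a `.loss` attribute.

```python
def train(model_engine, dataloader, max_steps):
    """Run up to max_steps micro-batches and return the recorded loss values."""
    losses = []
    for step, batch in enumerate(dataloader):
        if step >= max_steps:
            break
        outputs = model_engine(**batch)   # forward pass through the wrapped model
        loss = outputs.loss               # assumption: batch contains labels, so .loss is set
        model_engine.backward(loss)       # replaces loss.backward()
        model_engine.step()               # replaces optimizer.step(); DeepSpeed also handles
                                          # gradient accumulation boundaries and zeroing grads
        losses.append(float(loss))
    return losses
```

Because the loop only touches the engine through these three calls, it does not need to know which ZeRO stage is active.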

## 3.2 — The Config is King

Here's what changes between ZeRO stages — it's ONLY the JSON config:

**ZeRO-0** (baseline DDP):
```json
{ "zero_optimization": { "stage": 0 } }
```

**ZeRO-1** (shard optimizer):
```json
{ "zero_optimization": { "stage": 1 } }
```

**ZeRO-2** (shard optimizer + gradients):
```json
{
  "zero_optimization": {
    "stage": 2,
    "overlap_comm": true,        // Overlap comm with compute
    "reduce_scatter": true,      // Use reduce-scatter (efficient)
    "contiguous_gradients": true  // Coalesce grads in memory
  }
}
```

**ZeRO-3** (shard everything):
```json
{
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "stage3_prefetch_bucket_size": 2e8,          // Prefetch next layer's params
    "stage3_param_persistence_threshold": 1e5,   // Keep tiny params local
    "stage3_gather_16bit_weights_on_model_save": true  // Consolidate for saving
  }
}
```

**The training loop is IDENTICAL across all stages** (checkpoint saving, covered in 3.3, is the one place
the script branches on the stage). This is DeepSpeed's design philosophy:
separate the parallelism strategy from the training code.
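
One way to keep the four configs in sync is to generate them from a single function. The sketch below mirrors the `zero_optimization` fragments shown above; the helper name `build_ds_config` is hypothetical, and the surrounding keys (`train_micro_batch_size_per_gpu`, `gradient_accumulation_steps`, `bf16`) are assumptions about what the real `ds_zero*.json` files contain, which this notebook does not show. The real files likely also carry optimizer and scheduler settings.

```python
import json


def build_ds_config(stage: int, micro_batch: int = 4, grad_accum: int = 2) -> dict:
    """Build a DeepSpeed config dict whose only stage-dependent part is zero_optimization."""
    zero = {"stage": stage}
    if stage == 2:
        zero.update({
            "overlap_comm": True,           # overlap communication with compute
            "reduce_scatter": True,         # use reduce-scatter for gradients
            "contiguous_gradients": True,   # coalesce gradients in memory
        })
    elif stage == 3:
        zero.update({
            "overlap_comm": True,
            "stage3_prefetch_bucket_size": 200_000_000,        # 2e8, written as an integer
            "stage3_param_persistence_threshold": 100_000,     # 1e5
            "stage3_gather_16bit_weights_on_model_save": True,
        })
    return {
        # Assumed to match the --batch_size 4 --grad_accum 2 flags used in this notebook.
        "train_micro_batch_size_per_gpu": micro_batch,
        "gradient_accumulation_steps": grad_accum,
        "bf16": {"enabled": True},
        "zero_optimization": zero,
    }


print(json.dumps(build_ds_config(2), indent=2))
```

Writing each dict out with `json.dump` would produce one file per stage, and a diff between any two files would show only the `zero_optimization` block changing.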

## 3.3 — Saving with ZeRO-3

With ZeRO-3, parameters are distributed across GPUs — no single GPU has the full model.
To save, DeepSpeed must **consolidate** parameters from all ranks:

```python
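if zero_stage == 3:
    # Collective call: every rank must enter it so the shards can be gathered
    model_engine.save_16bit_model(output_path)  # Gathers from all GPUs
else:
    # Every rank already holds the full parameters, so a local save works
    # (in a multi-GPU run, guard this so only one rank writes the files)
    model_engine.module.save_pretrained(output_path)  # Local save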
```


---
# Part 4: Running the Experiments

We'll train Pythia-6.9B on Alpaca with each ZeRO stage and compare:
- **Peak GPU memory** — how much does sharding save?
- **Throughput** — what's the communication overhead?
- **Loss convergence** — do all stages produce the same result?

Each run does **500 steps** (enough to see convergence trends and measure throughput).
At the end, we'll do a **full training run** with the best config for inference.


## Experiment 1: ZeRO Stage 0 (No Sharding = DDP)

This is the **baseline**. Every GPU holds a complete copy of everything.
All gradient synchronization uses all-reduce during backward (overlapped).

**Expected:** Highest memory usage, fastest throughput (minimal communication).


In [None]:
import subprocess, time, os

os.makedirs('output', exist_ok=True)

MODEL = 'EleutherAI/pythia-6.9b-deduped'
# MODEL = 'EleutherAI/pythia-2.8b-deduped'  # Use this for smaller GPUs
STEPS = 500  # Comparison runs: 500 steps each

print('═'*72)
print('  ZeRO Stage 0: NO Sharding (DDP Baseline)')
print('═'*72)

t0 = time.time()
result = subprocess.run(
    f'deepspeed --num_gpus=2 deepspeed_train.py '
    f'--deepspeed_config ds_zero0.json '
    f'--model_name {MODEL} '
    f'--max_steps {STEPS} --batch_size 4 --grad_accum 2 '
    f'--output_dir ./output --run_name zero0',
    shell=True, capture_output=True, text=True, cwd='/workspace'
)
elapsed = time.time() - t0
print(result.stdout[-3000:])  # Last 3000 chars
if result.returncode != 0:
    print('STDERR:', result.stderr[-3000:])
else:
    print(f'\n  ✓ ZeRO-0 complete in {elapsed:.0f}s')


## Experiment 2: ZeRO Stage 1 (Shard Optimizer States)

Now the Adam optimizer states (momentum `m`, variance `v`, and fp32 master weights)
are **partitioned** across GPUs. Each GPU stores only 1/N of them.

During the optimizer step (`model_engine.step()`), each GPU updates only its own partition; the updated
parameters are then all-gathered so every GPU again holds the full set.

**Expected:** Lower memory (optimizer is the biggest component!), similar throughput.
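
To see why sharding the optimizer does not change the result, here is a toy simulation in plain Python. It is a sketch of the idea only, not DeepSpeed's implementation: `adam_step` and `zero1_step` are hypothetical helpers, ranks are simulated with lists, and the all-gather is modeled as list concatenation.

```python
import math


def adam_step(params, grads, m, v, t, lr=0.1, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update on plain lists; returns new (params, m, v)."""
    out_p, out_m, out_v = [], [], []
    for p, g, mi, vi in zip(params, grads, m, v):
        mi = b1 * mi + (1 - b1) * g          # first moment (momentum)
        vi = b2 * vi + (1 - b2) * g * g      # second moment (variance)
        m_hat = mi / (1 - b1 ** t)           # bias correction
        v_hat = vi / (1 - b2 ** t)
        out_p.append(p - lr * m_hat / (math.sqrt(v_hat) + eps))
        out_m.append(mi)
        out_v.append(vi)
    return out_p, out_m, out_v


def zero1_step(params, grads, m_shards, v_shards, t):
    """Each rank updates only its slice; the slices are then concatenated (the all-gather)."""
    n_ranks = len(m_shards)
    size = len(params) // n_ranks  # assumes len(params) divides evenly across ranks
    gathered = []
    for rank in range(n_ranks):
        lo, hi = rank * size, (rank + 1) * size
        new_p, m_shards[rank], v_shards[rank] = adam_step(
            params[lo:hi], grads[lo:hi], m_shards[rank], v_shards[rank], t)
        gathered.extend(new_p)
    return gathered


params = [1.0, 2.0, 3.0, 4.0]
grads = [0.1, -0.2, 0.3, -0.4]
full, _, _ = adam_step(params, grads, [0.0] * 4, [0.0] * 4, t=1)
sharded = zero1_step(params, grads, [[0.0] * 2, [0.0] * 2], [[0.0] * 2, [0.0] * 2], t=1)
print(sharded == full)  # → True
```

Adam is element-wise, so each rank can update its slice without seeing the optimizer state held by any other rank.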


In [None]:
print('═'*72)
print('  ZeRO Stage 1: Shard Optimizer States')
print('═'*72)

t0 = time.time()
result = subprocess.run(
    f'deepspeed --num_gpus=2 deepspeed_train.py '
    f'--deepspeed_config ds_zero1.json '
    f'--model_name {MODEL} '
    f'--max_steps {STEPS} --batch_size 4 --grad_accum 2 '
    f'--output_dir ./output --run_name zero1',
    shell=True, capture_output=True, text=True, cwd='/workspace'
)
elapsed = time.time() - t0
print(result.stdout[-3000:])
if result.returncode != 0:
    print('STDERR:', result.stderr[-3000:])
else:
    print(f'\n  ✓ ZeRO-1 complete in {elapsed:.0f}s')


## Experiment 3: ZeRO Stage 2 (Shard Optimizer + Gradients)

In addition to optimizer sharding, **gradients are now reduce-scattered** instead of
all-reduced. This means:
1. During backward, each GPU computes full gradients (same as before)
2. Instead of all-reduce (everyone gets all gradients), we use **reduce-scatter**:
   each GPU receives only its 1/N partition of the averaged gradients
3. Each GPU updates its 1/N partition of the optimizer
4. Updated parameters are all-gathered

**Key insight:** In ZeRO-2, `reduce_scatter + all_gather` replaces `all_reduce`.
The total communication volume is the same (2Φ), but gradients don't persist in full.

**This is the sweet spot for most use cases** — good memory savings, minimal throughput loss.
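
The identity behind the key insight can be checked with a toy simulation. The sketch below models each rank's gradient as a Python list and implements the three collectives naively; the function names mirror the collective names, but none of this is real communication code. In ZeRO-2 the all-gather moves updated parameters rather than gradients, so the example demonstrates only the collective identity.

```python
def all_reduce(rank_grads):
    """Every rank ends up with the element-wise mean of all ranks' gradients."""
    n = len(rank_grads)
    mean = [sum(col) / n for col in zip(*rank_grads)]
    return [list(mean) for _ in range(n)]


def reduce_scatter(rank_grads):
    """Rank r ends up with only the r-th slice of the averaged gradient."""
    n = len(rank_grads)
    size = len(rank_grads[0]) // n  # assumes the length divides evenly across ranks
    mean = [sum(col) / n for col in zip(*rank_grads)]
    return [mean[r * size:(r + 1) * size] for r in range(n)]


def all_gather(shards):
    """Every rank ends up with the concatenation of all shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


grads = [[1.0, 2.0, 3.0, 4.0],   # rank 0's local gradient
         [3.0, 4.0, 5.0, 6.0]]   # rank 1's local gradient

print(reduce_scatter(grads))  # → [[2.0, 3.0], [4.0, 5.0]]
print(all_gather(reduce_scatter(grads)) == all_reduce(grads))  # → True
```

After the reduce-scatter each rank holds only half of the averaged gradient, which is exactly the part it needs to update its own optimizer partition.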


In [None]:
print('═'*72)
print('  ZeRO Stage 2: Shard Optimizer + Gradients')
print('═'*72)

t0 = time.time()
result = subprocess.run(
    f'deepspeed --num_gpus=2 deepspeed_train.py '
    f'--deepspeed_config ds_zero2.json '
    f'--model_name {MODEL} '
    f'--max_steps {STEPS} --batch_size 4 --grad_accum 2 '
    f'--output_dir ./output --run_name zero2',
    shell=True, capture_output=True, text=True, cwd='/workspace'
)
elapsed = time.time() - t0
print(result.stdout[-3000:])
if result.returncode != 0:
    print('STDERR:', result.stderr[-3000:])
else:
    print(f'\n  ✓ ZeRO-2 complete in {elapsed:.0f}s')


## Experiment 4: ZeRO Stage 3 (Shard Everything)

The most aggressive stage: **parameters themselves are distributed**. No GPU holds
the full model at any point during training.

The lifecycle of a parameter in ZeRO-3:

```
  ┌──────────────┐      ┌──────────────┐      ┌──────────────┐
  │  Parameters   │      │  Parameters   │      │  Parameters   │
  │  (sharded     │─────>│  (all-gather  │─────>│  (use in     │
  │   1/N each)   │      │   full layer) │      │   forward)    │
  └──────────────┘      └──────────────┘      └──────┬───────┘
                                                      │
                                                      v
  ┌──────────────┐      ┌──────────────┐      ┌──────────────┐
  │  Free full    │      │  Reduce-     │      │  Compute     │
  │  params,      │<─────│  scatter     │<─────│  gradients   │
  │  keep shard   │      │  gradients   │      │  (backward)  │
  └──────────────┘      └──────────────┘      └──────────────┘
```

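Relative to ZeRO-2, this adds **1Φ of communication**: parameters are all-gathered once for forward and
again for backward (2Φ), replacing ZeRO-2's single all-gather of updated parameters (1Φ).

**Use ZeRO-3 when the model is too large to hold full parameters on each GPU.**
On 2× 80 GB GPUs, a 7B model fits without ZeRO-3 (ZeRO-2 needs about 63 GB of model state per GPU).
A 70B model does not: its bf16 parameters alone take 140 GB.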
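
The gather, use, free cycle can be illustrated with a toy forward pass that tracks how many weights one rank holds at its peak. This is a sketch under simplifying assumptions: `zero3_forward` is a hypothetical helper, each "layer" just multiplies the input by the sum of its weights, and the all-gather is modeled as list concatenation.

```python
def zero3_forward(layer_shards, x, rank=0):
    """layer_shards[l][r] is rank r's slice of layer l's weights.

    Returns (output, peak number of weights resident on `rank`).
    """
    own = sum(len(layer[rank]) for layer in layer_shards)  # weights this rank always holds
    peak = own
    for layer in layer_shards:
        full = [w for shard in layer for w in shard]       # all-gather: rebuild the full layer
        peak = max(peak, own + len(full) - len(layer[rank]))
        x = sum(w * x for w in full)                       # toy stand-in for the layer's compute
        del full                                           # free the gathered copy, keep the shard
    return x, peak


# 3 layers, 4 weights each, split across 2 ranks (12 weights in total)
layers = [[[0.5, 0.5], [0.5, 0.5]] for _ in range(3)]
out, peak = zero3_forward(layers, x=1.0)
print(out, peak)  # → 8.0 8
```

The rank never holds more than 8 of the 12 weights: its own 6 plus the other half of whichever layer is currently gathered.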


In [None]:
print('═'*72)
print('  ZeRO Stage 3: Shard Everything')
print('═'*72)

t0 = time.time()
result = subprocess.run(
    f'deepspeed --num_gpus=2 deepspeed_train.py '
    f'--deepspeed_config ds_zero3.json '
    f'--model_name {MODEL} '
    f'--max_steps {STEPS} --batch_size 4 --grad_accum 2 '
    f'--output_dir ./output --run_name zero3',
    shell=True, capture_output=True, text=True, cwd='/workspace'
)
elapsed = time.time() - t0
print(result.stdout[-3000:])
if result.returncode != 0:
    print('STDERR:', result.stderr[-3000:])
else:
    print(f'\n  ✓ ZeRO-3 complete in {elapsed:.0f}s')


---
# Part 5: Results Comparison


In [None]:
import subprocess
result = subprocess.run('python compare_zero_stages.py ./output',
                        shell=True, capture_output=True, text=True, cwd='/workspace')
print(result.stdout)
if result.returncode != 0:
    print('STDERR:', result.stderr[-2000:])


### Visualization

In [None]:
import json, glob, os
import matplotlib.pyplot as plt
import numpy as np

results = []
for f in sorted(glob.glob('/workspace/output/results_zero*.json')):
    with open(f) as fh:
        results.append(json.load(fh))

if not results:
    print('No results found — run the experiments first!')
else:
    results.sort(key=lambda r: r['zero_stage'])
    stages = [f"ZeRO-{r['zero_stage']}" for r in results]

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # Memory
    mems = [r['peak_memory_gb'] for r in results]
    colors = ['#ef4444','#f59e0b','#22c55e','#4a9eed']
    axes[0].bar(stages, mems, color=colors[:len(stages)], edgecolor='white', linewidth=1.5)
    axes[0].set_ylabel('Peak Memory (GB)')
    axes[0].set_title('GPU Memory', fontweight='bold')
    for i, v in enumerate(mems):
        axes[0].text(i, v+0.5, f'{v:.1f}', ha='center', fontweight='bold')

    # Throughput
    tputs = [r['avg_throughput_tok_s'] for r in results]
    axes[1].bar(stages, tputs, color=colors[:len(stages)], edgecolor='white', linewidth=1.5)
    axes[1].set_ylabel('Tokens/sec')
    axes[1].set_title('Throughput', fontweight='bold')
    for i, v in enumerate(tputs):
        axes[1].text(i, v+100, f'{v:.0f}', ha='center', fontweight='bold')

    # Loss curves
    for r in results:
        hist = r.get('loss_history', [])
        if hist:
            axes[2].plot(range(1, len(hist)+1), hist, label=f"ZeRO-{r['zero_stage']}", linewidth=2)
    axes[2].set_xlabel('Step')
    axes[2].set_ylabel('Loss')
    axes[2].set_title('Training Loss (should overlap!)', fontweight='bold')
    axes[2].legend()
    axes[2].grid(alpha=0.3)

    plt.suptitle('DeepSpeed ZeRO Comparison', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()


### Interpreting the Results

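**Memory:** You should see a clear staircase: ZeRO-0 highest, ZeRO-3 lowest.
For a 7B model on 2 GPUs, the Part 1 arithmetic gives a floor for model state alone; measured peaks
will sit above it because of activations and communication buffers:
- ZeRO-0: at least ~112 GB (full replication)
- ZeRO-1: at least ~70 GB (optimizer sharded)
- ZeRO-2: at least ~63 GB (+ gradients sharded)
- ZeRO-3: at least ~56 GB (+ params sharded)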

**Throughput:** ZeRO-0/1/2 should be within ~5% of each other. ZeRO-3 will be
noticeably slower (20-40%) due to the extra parameter all-gather communication.

**Loss curves:** These should be IDENTICAL (or very close). The math is the same —
ZeRO only changes how memory is managed, not what's computed. Small differences
are due to floating-point ordering in reductions.
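
A quick numerical check complements the visual comparison of the loss curves. The sketch below assumes the same result dicts loaded in the visualization cell (keys `zero_stage` and `loss_history`); the helper name `max_loss_gap` is hypothetical.

```python
def max_loss_gap(results, baseline_stage=0):
    """Largest absolute per-step loss difference of each stage versus the baseline stage."""
    by_stage = {r["zero_stage"]: r.get("loss_history", []) for r in results}
    base = by_stage.get(baseline_stage, [])
    if not base:
        raise ValueError(f"no loss history for baseline stage {baseline_stage}")
    gaps = {}
    for stage, hist in by_stage.items():
        if stage == baseline_stage or not hist:
            continue  # nothing to compare
        n = min(len(base), len(hist))  # compare only the steps both runs recorded
        gaps[stage] = max(abs(a - b) for a, b in zip(base[:n], hist[:n]))
    return gaps
```

Calling `max_loss_gap(results)` after the experiments returns one number per stage; what counts as "close enough" is a judgment call that depends on precision and reduction order.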


---
# Part 6: Full Training Run (for Inference Quality)

The 500-step comparison runs are too short for meaningful instruction-following.
Now let's do a proper training run with ZeRO-2 (best memory/speed trade-off).

We'll train for **2000 steps** (~half an epoch of Alpaca), which takes about 15-30 minutes
on 2× B200.
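
The epoch estimate follows from the batch arithmetic. The sketch below assumes that `--max_steps` counts optimizer steps (each consuming `grad_accum` micro-batches per GPU) and that all 52,002 examples commonly cited for the Alpaca train split are used for training; the download cell earlier in this notebook prints the actual count. If the script counts micro-batches instead, the fraction halves.

```python
def epochs_covered(max_steps, micro_batch, grad_accum, n_gpus, n_examples):
    """Fraction of the dataset seen, assuming one optimizer step consumes a full global batch."""
    global_batch = micro_batch * grad_accum * n_gpus
    return max_steps * global_batch / n_examples


ALPACA_TRAIN_EXAMPLES = 52_002  # assumption: commonly cited size of the train split
print(f"{epochs_covered(2000, 4, 2, 2, ALPACA_TRAIN_EXAMPLES):.2f} epochs")  # → 0.62 epochs
```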


In [None]:
print('═'*72)
print('  FULL TRAINING: ZeRO-2, 2000 steps')
print('  This will take 15-30 minutes...')
print('═'*72)

import subprocess, time
t0 = time.time()

result = subprocess.run(
    f'deepspeed --num_gpus=2 deepspeed_train.py '
    f'--deepspeed_config ds_zero2.json '
    f'--model_name {MODEL} '
    f'--max_steps 2000 --batch_size 4 --grad_accum 2 '
    f'--warmup_steps 100 --eval_interval 200 '
    f'--output_dir ./output --run_name zero2_full '
    f'--save_model',
    shell=True, capture_output=True, text=True, cwd='/workspace'
)
elapsed = time.time() - t0
print(result.stdout[-5000:])
if result.returncode != 0:
    print('STDERR:', result.stderr[-3000:])
else:
    print(f'\n  ✓ Full training complete in {elapsed:.0f}s ({elapsed/60:.1f} min)')


---
# Part 7: Inference — Before vs After

Let's see if the fine-tuning actually worked! We'll compare the base model's
responses to the fine-tuned model's responses on the same prompts.

The base model (Pythia-6.9B) was pretrained on the Pile, a large general-purpose text corpus. It can
generate coherent English but doesn't follow instructions well. After Alpaca fine-tuning, it should
produce structured, helpful responses to instructions.


In [None]:
import subprocess

print('═'*72)
print('  INFERENCE: Base vs Fine-tuned Comparison')
print('═'*72)

result = subprocess.run(
    f'python run_inference.py '
    f'--model_path ./output/zero2_full '
    f'--base_model {MODEL}',
    shell=True, capture_output=True, text=True, cwd='/workspace'
)
print(result.stdout)
if result.returncode != 0:
    print('STDERR:', result.stderr[-3000:])


---
# Part 8: When to Use What — Decision Framework

```
                   Does model + optimizer fit on 1 GPU?
                              │
                    ┌─────────┴─────────┐
                    YES                  NO
                    │                    │
              Use ZeRO-0            Does model fit on 1 GPU
              (plain DDP)           but optimizer doesn't?
                                         │
                               ┌─────────┴─────────┐
                               YES                  NO
                               │                    │
                          Use ZeRO-1/2         Use ZeRO-3
                          (shard optimizer)    (shard everything)
                               │
                    Need max memory savings?
                               │
                     ┌─────────┴─────────┐
                     YES                  NO
                     │                    │
                Use ZeRO-2          Use ZeRO-1
                (most popular)      (simplest)
```

### Rules of Thumb

| Scenario | Recommendation | Reason |
|----------|---------------|--------|
| 7B model, 4× A100 80GB | ZeRO-1 or ZeRO-2 | Model fits, save optimizer memory for larger batches |
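| 13B model, 2× A100 80GB | ZeRO-3 + CPU offload, or more GPUs | 208 GB of model state is 104 GB per GPU even fully sharded |
| 70B model, 8× A100 80GB | ZeRO-3 + CPU offload, or more GPUs | 1,120 GB of model state is 140 GB per GPU even fully sharded |
| 7B model, 2× A100 40GB | ZeRO-3 + CPU offload | 56 GB per GPU even fully sharded exceeds 40 GB |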
| Any model, want max throughput | ZeRO-0 or ZeRO-1 | Less communication overhead |
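
The decision tree and table above can be approximated in code by picking the lowest stage whose model state fits in GPU memory. This is a rough sketch: `recommend_stage` is a hypothetical helper, it uses the 16 bytes per parameter from Part 1, and the `headroom_gb` reserve for activations and buffers is an arbitrary placeholder that should be tuned for the actual batch size and sequence length.

```python
def recommend_stage(n_params, n_gpus, gpu_gb, headroom_gb=10.0):
    """Return the lowest ZeRO stage whose per-GPU model state fits, or None if none does."""
    for stage in range(4):
        optimizer = 12 * n_params / (n_gpus if stage >= 1 else 1)  # fp32 m, v, master weights
        grads = 2 * n_params / (n_gpus if stage >= 2 else 1)       # bf16 gradients
        params = 2 * n_params / (n_gpus if stage >= 3 else 1)      # bf16 parameters
        needed_gb = (optimizer + grads + params) / 1e9 + headroom_gb
        if needed_gb <= gpu_gb:
            return stage
    return None  # does not fit even fully sharded


print(recommend_stage(7e9, n_gpus=2, gpu_gb=192))    # → 0
print(recommend_stage(7e9, n_gpus=4, gpu_gb=80))     # → 1
print(recommend_stage(70e9, n_gpus=64, gpu_gb=80))   # → 3
```

Picking the lowest stage that fits favors throughput; a team following the ZeRO-2 default would start the search at stage 2 instead.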

### ZeRO-2 is the Default Choice for Most Teams

It gives you significant memory savings with almost no throughput penalty.
Only move to ZeRO-3 when you need the extra memory, and only stay at ZeRO-0
when your model is small enough that memory isn't a concern.

### Beyond ZeRO: When Data Parallelism Isn't Enough

If you have 100+ GPUs, data parallelism alone hits a batch size wall.
You need **3D parallelism**:
- **Tensor Parallelism**: Split matrix multiplies across GPUs (within a node)
- **Pipeline Parallelism**: Split layers across GPUs (across nodes)
- **Data Parallelism (ZeRO)**: Replicate the above across groups of GPUs

That's the next lecture.
