<a href="https://colab.research.google.com/github/amankiitg/5DParallel/blob/main/DataParallelization_GradientSynchronization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Data Parallelism: Gradient Synchronization Strategies

## Naive DP (No Overlap) vs DDP (With Overlap)

This notebook demonstrates the **key difference** between:

1. **Naive DP** — All gradients are synchronized **AFTER** the entire backward pass completes (sequential: compute → communicate)
2. **DDP with Overlap** — Gradients are synchronized **DURING** the backward pass as buckets become ready (overlapped: compute + communicate simultaneously)

---

### ⚠️ Critical Design Note

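A common mistake (and what the original code did) is implementing "naive" DP by calling `all_reduce` from a `register_hook()` callback on each parameter. **Hooks fire during backward** as each gradient is computed, so each all-reduce is launched while gradients for earlier layers are still being computed. Communication therefore already overlaps with computation, which is the core of what DDP does (DDP adds gradient bucketing on top).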

**The fix:** For truly naive (no-overlap) DP, we must:
- Do `loss.backward()` with **no hooks** and **no DDP wrapper**
- **After** backward completes, manually loop over all parameters and call `all_reduce` on each gradient

This ensures communication is fully sequential after computation.
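
A toy timeline model shows why launching each all-reduce as soon as its gradient is ready (what a per-parameter hook does) already counts as overlap. This is a sketch with made-up per-layer costs, not a measurement. It assumes a single communication stream that runs concurrently with compute without slowing it down:

```python
from typing import List


def naive_total_ms(compute_ms: List[float], comm_ms: List[float]) -> float:
    """Backward runs to completion first, then every all-reduce runs back to back."""
    return sum(compute_ms) + sum(comm_ms)


def overlapped_total_ms(compute_ms: List[float], comm_ms: List[float]) -> float:
    """Each gradient's all-reduce is launched as soon as that gradient is ready.

    Entries are in the order backward produces gradients (last layer first).
    All-reduces share one communication stream, so they run one at a time,
    but concurrently with the remaining backward compute.
    """
    grad_ready = 0.0  # time at which the current gradient has been computed
    comm_free = 0.0   # time at which the communication stream becomes idle
    for c, m in zip(compute_ms, comm_ms):
        grad_ready += c
        start = max(grad_ready, comm_free)  # wait for the gradient AND a free stream
        comm_free = start + m
    return max(grad_ready, comm_free)


compute = [10.0] * 4  # hypothetical per-layer backward cost (ms)
comm = [8.0] * 4      # hypothetical per-layer all-reduce cost (ms)
print(naive_total_ms(compute, comm))       # 72.0
print(overlapped_total_ms(compute, comm))  # 48.0
```

In the overlapped case the only fully exposed communication is the all-reduce of the last gradient produced (the first layer's), because no backward compute remains to hide it.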

## Step 1: Install Dependencies

In [None]:
# Uncomment and run once if needed
# !pip install torch torchvision tensorboard torch-tb-profiler -q

## Step 2: Create the Naive DP Script (NO Overlap)

**Key idea:** We use the plain `nn.Sequential` model with **NO hooks** and no DDP wrapper. After `loss.backward()` finishes completely, we manually call `dist.all_reduce()` on every gradient in a separate loop. This creates a clear **sequential** pattern:

```
backward (all compute) ──────────────> | all_reduce (all communication) ──────────────>
```

We also use a **deeper and wider model** (40 layers × 4096 width) so the backward pass is long enough to clearly see the separation in the trace.
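
The arithmetic behind this size choice, as a small sketch. It assumes fp32 gradients and, for the traffic estimate, a ring all-reduce (NCCL may choose a different algorithm, such as a tree, at runtime):

```python
def mlp_gradient_bytes(depth: int = 40, width: int = 4096, bytes_per_elem: int = 4) -> int:
    """Gradient bytes per step for `depth` bias-free width x width Linear layers."""
    return depth * width * width * bytes_per_elem


def ring_allreduce_bytes_per_rank(payload_bytes: int, world_size: int) -> float:
    """Bytes each rank sends in a ring all-reduce (reduce-scatter + all-gather)."""
    return 2 * (world_size - 1) / world_size * payload_bytes


per_layer = mlp_gradient_bytes(depth=1)
total = mlp_gradient_bytes()
print(per_layer / 2**20)                                 # 64.0 (MiB per layer)
print(total / 2**30)                                     # 2.5 (GiB per step)
print(ring_allreduce_bytes_per_rank(total, 2) / 2**30)   # 2.5 (GiB sent per rank, 2 GPUs)
```

So every step moves roughly 2.5 GiB of gradients through NCCL, which is large enough to show up clearly next to the backward kernels.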

In [None]:
naive_dp_code = '''
import os, argparse
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.profiler import profile, ProfilerActivity, record_function


def build_model(depth=40, width=4096):
    """Build a deep MLP to make backward pass long enough to see overlap differences.

    Using 40 layers x 4096 width gives us:
    - A long backward pass with many sequential gradient computations
    - Large gradients (4096x4096 ≈ 16.8M params, 64 MiB in fp32, per layer) that make all_reduce visible
    """
    layers = []
    for _ in range(depth):
        layers += [nn.Linear(width, width, bias=False), nn.ReLU()]
    return nn.Sequential(*layers)


def main():
    p = argparse.ArgumentParser()
    p.add_argument("--logdir", type=str, required=True)
    p.add_argument("--steps", type=int, default=10)
    p.add_argument("--batch", type=int, default=64)
    args = p.parse_args()

    # Initialize distributed training
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")

    # =========================================================================
    # NAIVE DP: Plain model with NO hooks, NO DDP wrapper
    # We will manually all-reduce gradients AFTER backward completes
    # =========================================================================
    model = build_model().cuda()
    opt = torch.optim.SGD(model.parameters(), lr=0.01)

    # Broadcast initial parameters so all ranks start with the same weights
    for p_tensor in model.parameters():
        dist.broadcast(p_tensor.data, src=0)

    # Synthetic data
    x = torch.randn(args.batch, 4096, device="cuda")
    target = torch.randn(args.batch, 4096, device="cuda")

    # Warmup (3 steps to stabilize CUDA kernels and NCCL)
    for _ in range(3):
        y = model(x)
        loss = (y - target).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        # Manual all-reduce after backward
        for p_tensor in model.parameters():
            if p_tensor.grad is not None:
                dist.all_reduce(p_tensor.grad, op=dist.ReduceOp.SUM)
                p_tensor.grad /= world_size
        opt.step()
        torch.cuda.synchronize()

    # Profile
    worker_name = f"rank{rank}"
    prof_schedule = torch.profiler.schedule(wait=2, warmup=1, active=3, repeat=1)

    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        schedule=prof_schedule,
        on_trace_ready=torch.profiler.tensorboard_trace_handler(
            args.logdir, worker_name=worker_name
        ),
        record_shapes=True,
        profile_memory=True,
        with_stack=False,
    ) as prof:
        for step in range(args.steps):
            dist.barrier()
            torch.cuda.synchronize()

            with record_function(f"step_{step}"):

                with record_function("forward"):
                    y = model(x)

                with record_function("loss"):
                    loss = (y - target).pow(2).mean()

                opt.zero_grad()

                # ---------------------------------------------------------
                # NAIVE BACKWARD: Pure computation, NO communication here
                # ---------------------------------------------------------
                with record_function("backward_COMPUTE_ONLY"):
                    loss.backward()

                # Force backward to fully complete on GPU before communication
                torch.cuda.synchronize()

                # ---------------------------------------------------------
                # NAIVE ALL-REDUCE: ALL communication happens here, AFTER
                # backward is 100% done. This is the "no overlap" part.
                # ---------------------------------------------------------
                with record_function("allreduce_ALL_GRADS_SEQUENTIAL"):
                    for p_tensor in model.parameters():
                        if p_tensor.grad is not None:
                            dist.all_reduce(p_tensor.grad, op=dist.ReduceOp.SUM)
                            p_tensor.grad /= world_size

                torch.cuda.synchronize()

                with record_function("optimizer"):
                    opt.step()

                torch.cuda.synchronize()

            prof.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
'''

with open("naive_dp_profile.py", "w") as f:
    f.write(naive_dp_code)

print("✓ Created naive_dp_profile.py")
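
The manual `all_reduce(SUM)` followed by `grad /= world_size` in the script above is meant to leave every rank holding the average gradient. A pure-Python sketch can check why that is the right quantity. It uses no `torch.distributed`; the "ranks" are just lists, and the scalar model and data are made up. With equal per-rank batch sizes, averaging per-rank gradients of a mean-squared-error loss gives the same gradient as a single process computing the loss over the combined batch:

```python
from typing import List


def mse_grad(w: float, xs: List[float], ys: List[float]) -> float:
    """d/dw of mean((w*x - y)^2) over one local batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def simulated_allreduce_mean(local_grads: List[float]) -> float:
    """What all_reduce(SUM) followed by `grad /= world_size` leaves on every rank."""
    return sum(local_grads) / len(local_grads)


w = 0.5
rank_data = [([1.0, 2.0], [1.0, 3.0]),   # hypothetical shard on rank 0
             ([3.0, 4.0], [2.0, 5.0])]   # hypothetical shard on rank 1
local = [mse_grad(w, xs, ys) for xs, ys in rank_data]
all_xs = [x for xs, _ in rank_data for x in xs]
all_ys = [y for _, ys in rank_data for y in ys]

print(local)                                                          # [-4.5, -13.5]
print(simulated_allreduce_mean(local), mse_grad(w, all_xs, all_ys))  # -9.0 -9.0
```

If per-rank batch sizes differed, a plain division by `world_size` would weight samples on smaller shards more heavily than a global mean would.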

## Step 3: Create the DDP Script (WITH Overlap)

**Key idea:** PyTorch's `DistributedDataParallel` (DDP) groups parameters into **buckets** and fires `all_reduce` on each bucket as soon as all gradients in that bucket are ready — while backward is still running for earlier layers. This creates an **overlapped** pattern:

```
backward: [layer40 grad]──[layer39 grad]──[layer38 grad]──[layer37 grad]──...
comms:                     [allreduce bucket1]──────────[allreduce bucket2]──────...
```

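We use a **small bucket size** (`bucket_cap_mb=5`, down from DDP's default of 25) to push DDP toward more frequent, smaller all-reduce calls, which makes overlap easier to see in a trace. In this particular model, though, each 4096×4096 fp32 weight is about 64 MiB, larger than either cap, so DDP should place each weight in its own bucket and issue roughly one all-reduce per layer either way. The cap matters more for models with many small parameters.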
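
A simplified greedy model of size-capped bucketing helps predict how many all-reduce calls to expect. It is an approximation, not DDP's actual reducer: the real implementation also groups by dtype and device, uses a smaller first bucket, and may regroup buckets after the first iteration based on the observed gradient order. Here parameters are visited in reverse registration order (roughly the order backward produces gradients), and a bucket closes as soon as it reaches the cap:

```python
from typing import List


def assign_buckets(param_bytes: List[int], cap_bytes: int) -> List[List[int]]:
    """Greedy size-capped bucketing (simplified model, not PyTorch's code).

    Returns lists of parameter indices. Because a bucket closes once its total
    size reaches the cap, any parameter larger than the cap sits alone.
    """
    buckets: List[List[int]] = []
    current: List[int] = []
    current_size = 0
    for idx in reversed(range(len(param_bytes))):
        current.append(idx)
        current_size += param_bytes[idx]
        if current_size >= cap_bytes:
            buckets.append(current)
            current, current_size = [], 0
    if current:  # leftover parameters form a final, under-filled bucket
        buckets.append(current)
    return buckets


MiB = 2**20
weights = [64 * MiB] * 40                       # 40 bias-free 4096x4096 fp32 weights
print(len(assign_buckets(weights, 5 * MiB)))    # 40: one weight per bucket
print(len(assign_buckets(weights, 25 * MiB)))   # 40: same with the default cap
small = [1 * MiB] * 40                          # hypothetical model with 1 MiB params
print(len(assign_buckets(small, 5 * MiB)))      # 8 buckets of 5 params each
```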

In [None]:
ddp_overlap_code = '''
import os, argparse
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.profiler import profile, ProfilerActivity, record_function


def build_model(depth=40, width=4096):
    """Same model architecture as naive DP for fair comparison."""
    layers = []
    for _ in range(depth):
        layers += [nn.Linear(width, width, bias=False), nn.ReLU()]
    return nn.Sequential(*layers)


def main():
    p = argparse.ArgumentParser()
    p.add_argument("--logdir", type=str, required=True)
    p.add_argument("--steps", type=int, default=10)
    p.add_argument("--batch", type=int, default=64)
    p.add_argument("--bucket_cap_mb", type=int, default=5)
    args = p.parse_args()

    # Initialize distributed training
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")

    # =========================================================================
    # DDP: Wraps model with automatic gradient bucketing + overlapped all-reduce
    # - bucket_cap_mb=5: small cap → more, smaller buckets when parameters
    #   are small. Each 64 MiB weight in this model exceeds the cap, so it
    #   gets its own bucket (about one all-reduce per layer).
    # - gradient_as_bucket_view=True: Avoid extra gradient copy
    # =========================================================================
    model = build_model().cuda()
    ddp = DDP(
        model,
        device_ids=[local_rank],
        broadcast_buffers=False,
        bucket_cap_mb=args.bucket_cap_mb,
        gradient_as_bucket_view=True,
    )
    opt = torch.optim.SGD(ddp.parameters(), lr=0.01)

    # Synthetic data
    x = torch.randn(args.batch, 4096, device="cuda")
    target = torch.randn(args.batch, 4096, device="cuda")

    # Warmup
    for _ in range(3):
        y = ddp(x)
        loss = (y - target).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        torch.cuda.synchronize()

    # Profile
    worker_name = f"rank{rank}"
    prof_schedule = torch.profiler.schedule(wait=2, warmup=1, active=3, repeat=1)

    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        schedule=prof_schedule,
        on_trace_ready=torch.profiler.tensorboard_trace_handler(
            args.logdir, worker_name=worker_name
        ),
        record_shapes=True,
        profile_memory=True,
        with_stack=False,
    ) as prof:
        for step in range(args.steps):
            dist.barrier()
            torch.cuda.synchronize()

            with record_function(f"step_{step}"):

                with record_function("forward"):
                    y = ddp(x)

                with record_function("loss"):
                    loss = (y - target).pow(2).mean()

                opt.zero_grad()

                # ---------------------------------------------------------
                # DDP BACKWARD: Computation and communication are OVERLAPPED
                # As each bucket of gradients is computed, DDP automatically
                # fires all_reduce on that bucket while backward continues
                # computing gradients for earlier layers.
                # ---------------------------------------------------------
                with record_function("backward_WITH_OVERLAP"):
                    loss.backward()

                torch.cuda.synchronize()

                with record_function("optimizer"):
                    opt.step()

                torch.cuda.synchronize()

            prof.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
'''

with open("ddp_overlap_profile.py", "w") as f:
    f.write(ddp_overlap_code)

print("✓ Created ddp_overlap_profile.py")

## Step 4: Run Experiment 1 — Naive DP (No Overlap)

In [None]:
import subprocess
import time

print("=" * 70)
print("EXPERIMENT 1: Naive Data Parallelism (NO overlap)")
print("=" * 70)

# Clean old logs
subprocess.run("rm -rf logs/naive_dp", shell=True)

# Run naive DP
result = subprocess.run(
    "torchrun --standalone --nproc_per_node=2 naive_dp_profile.py "
    "--logdir logs/naive_dp --steps 10 --batch 64",
    shell=True,
    capture_output=True,
    text=True,
)
print(result.stdout)
if result.returncode != 0:
    print("STDERR:", result.stderr[-2000:])  # Last 2000 chars of error
else:
    print("✓ Naive DP profiling complete!")

## Step 5: Run Experiment 2 — DDP with Overlap

In [None]:
print("\n" + "=" * 70)
print("EXPERIMENT 2: DDP with Communication Overlap")
print("=" * 70)

time.sleep(2)

# Clean old logs
subprocess.run("rm -rf logs/ddp_overlap", shell=True)

# Run DDP with overlap (small bucket size to maximize visible overlap)
result = subprocess.run(
    "torchrun --standalone --nproc_per_node=2 ddp_overlap_profile.py "
    "--logdir logs/ddp_overlap --steps 10 --batch 64 --bucket_cap_mb 5",
    shell=True,
    capture_output=True,
    text=True,
)
print(result.stdout)
if result.returncode != 0:
    print("STDERR:", result.stderr[-2000:])
else:
    print("✓ DDP overlap profiling complete!")

print("\n" + "=" * 70)
print("✓ Both experiments completed!")
print("=" * 70)
print("\nGenerated trace logs:")
print("  - logs/naive_dp/      (NO overlap)")
print("  - logs/ddp_overlap/   (WITH overlap)")

## Step 6: View Results in TensorBoard

### Launch TensorBoard

```bash
# In a terminal on RunPod:
cd /workspace
pkill -f tensorboard
tensorboard --logdir logs --host 0.0.0.0 --port 6007
```

Then SSH tunnel from your local machine:
```bash
ssh root@<RUNPOD_IP> -p <PORT> -i ~/.ssh/id_ed25519 -L 6007:localhost:6007
```

Open `http://localhost:6007` → **PYTORCH_PROFILER** tab → **Trace** view.

---

### What to Look For in the Trace View

#### Naive DP (No Overlap) — `naive_dp`

You should see **two distinct, sequential blocks**:

```
CPU:  |── backward_COMPUTE_ONLY ──|── allreduce_ALL_GRADS_SEQUENTIAL ──|
GPU:  |███ compute kernels ███████|░░░░░░░░ idle ░░░░░░░░░░░░░░░░░░░░░░|
NCCL: |░░░░░░░░ idle ░░░░░░░░░░░░|███ nccl:all_reduce ████████████████|
```

- `backward_COMPUTE_ONLY` region: only compute kernels (matmul, ReLU backward), **zero NCCL calls**
- `allreduce_ALL_GRADS_SEQUENTIAL` region: a **burst of nccl:all_reduce** calls, one per parameter, all bunched together
- The GPU compute stream is **idle** during the all-reduce phase

#### DDP with Overlap — `ddp_overlap`

You should see **interleaved compute and communication**:

```
CPU:  |──────────── backward_WITH_OVERLAP ──────────────|
GPU:  |███ compute ██ compute ██ compute ██ compute ████|
NCCL: |░░░░░░░░░███ allreduce ████ allreduce ████ allreduce █████|
```

- Inside `backward_WITH_OVERLAP`: you see both compute kernels AND `nccl:all_reduce` calls **happening simultaneously**
- `c10d::allreduce_` and `record_param_comms` operations appear **interleaved** with backward compute
- The NCCL stream shows all-reduce calls **overlapping** with the GPU compute stream
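- Overall step time should be **shorter**, since most communication is hidden behind computation; only the all-reduce of the last bucket (the first layer's gradients) has no remaining backward compute to hide behind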