### 🧩 **PyTorch Profiling a Distributed Workload**

The goal of this tutorial is to understand the **core concepts behind profiling a distributed workload**.  
It directly builds on the previous tutorial that covered profiling a single-GPU run.

Here we focus on what changes when a workload runs under **PyTorch Distributed (DDP or otherwise)** — not on the mechanics of DDP itself, but on how profiling behaves when multiple processes participate.


### ⚙️ How Distributed Profiling Works

This example uses **`torchrun`** to launch the job. The key concept is that `torchrun` starts **one identical Python process per GPU** you are using (the number of processes per machine is set with `--nproc_per_node`).

When you profile this distributed workload, you are really profiling **multiple identical programs** at the same time, each running on its own GPU.



#### What "Rank" Actually Means

To coordinate, each process needs a unique identity, called its **rank**. `torchrun` supplies this identity through three environment variables:

* **`WORLD_SIZE`**: The **total number** of processes in the entire job (e.g., 2 nodes $\times$ 8 GPUs/node = 16).
* **`RANK`**: The **global ID** for this specific process, from `0` to `WORLD_SIZE - 1`.
* **`LOCAL_RANK`**: The **local GPU index** *on this specific machine* (e.g., `0`, `1`, ... `7` for an 8-GPU node).
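These three variables are related by simple arithmetic. The sketch below enumerates what each process would see in the 2-node, 8-GPU example. It assumes a homogeneous job in which each node receives a contiguous block of global ranks. `torchrun_env_table` is a hypothetical helper, and `node_rank` stands for the value of torchrun's `--node_rank` flag rather than an environment variable the script reads.

```python
def torchrun_env_table(nnodes: int, nproc_per_node: int) -> list[dict[str, int]]:
    """Hypothetical helper: the rank variables each process would see,
    assuming node `n` owns a contiguous block of global ranks."""
    world_size = nnodes * nproc_per_node
    table = []
    for node_rank in range(nnodes):          # value of --node_rank on that machine
        for local_rank in range(nproc_per_node):
            table.append({
                "node_rank": node_rank,
                "LOCAL_RANK": local_rank,                          # GPU index on this machine
                "RANK": node_rank * nproc_per_node + local_rank,   # global ID
                "WORLD_SIZE": world_size,                          # identical for every process
            })
    return table


table = torchrun_env_table(nnodes=2, nproc_per_node=8)
print(len(table))  # 16
print(table[9])    # {'node_rank': 1, 'LOCAL_RANK': 1, 'RANK': 9, 'WORLD_SIZE': 16}
```

Note how `LOCAL_RANK` repeats on every node while `RANK` stays unique across the whole job.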

Every process runs the *same Python program*, but it uses these rank variables to figure out what to do:
* It uses `LOCAL_RANK` to select its GPU (e.g., `cuda:0`, `cuda:1`, …).
* It passes `RANK` (together with `WORLD_SIZE`) to `init_process_group` so it can take part in collective operations, such as DDP's gradient all-reduce.

---

#### Why naive profiling breaks
If every rank writes to the same `trace.json`, the independent processes will **race and overwrite** each other’s output.

**To avoid file races (required):**
- **Use a unique filename per rank**, such as `rank{rank}_iter{start}_{end}.json`.  
  This ensures that concurrent processes never write to the same file.
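A quick way to see the race is to count how many distinct output paths the ranks would use. `trace_filename` is a hypothetical helper that mirrors the naming pattern above; the iteration numbers are placeholders.

```python
def trace_filename(rank: int, start: int, end: int) -> str:
    # Embedding the global rank guarantees that no two processes share a path.
    return f"rank{rank}_iter{start}_{end}.json"


world_size = 8
shared = {"trace.json" for _ in range(world_size)}                # naive: every rank picks one name
unique = {trace_filename(r, 15, 17) for r in range(world_size)}  # one name per rank

print(len(shared))  # 1 -> all 8 processes would write the same file
print(len(unique))  # 8 -> one file per rank
```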

**To reduce trace volume (optional):**
- **Enable profiling only for selected ranks** (e.g., `"0"` or `"0,3"`).  
  Every rank still executes the full workload, but only the chosen ranks record traces—keeping the data manageable and export overhead low.
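Selecting ranks boils down to parsing a small string such as `"0,3"` or `"all"` into a set of global ranks. The sketch below is a standalone, hypothetical version of that filter (`parse_profile_ranks` is not part of PyTorch); unlike the helper in the script, it also rejects ranks outside `0..WORLD_SIZE-1`.

```python
def parse_profile_ranks(spec: str, world_size: int) -> set[int]:
    """Hypothetical helper: turn "all", "0" or "0, 2, 4" into a set of global ranks."""
    spec = spec.strip()
    if spec == "all":
        return set(range(world_size))
    # Skip empty items so inputs like "0, 2," still parse; int() ignores surrounding spaces.
    ranks = {int(x) for x in spec.split(",") if x.strip()}
    bad = sorted(r for r in ranks if not 0 <= r < world_size)
    if bad:
        raise ValueError(f"ranks {bad} are outside 0..{world_size - 1}")
    return ranks


world_size = 8
selected = parse_profile_ranks("0, 2, 4", world_size)
for rank in range(world_size):
    # Every rank runs the workload; only the selected ones would record a trace.
    print(rank, "profiles" if rank in selected else "runs without profiling")
```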

The helper **`make_profiler_ctx()`** in the first cell implements both of these ideas:
- It filters which ranks actually profile.  
- It embeds the rank number into the trace filename so outputs are always unique.


> 🔗 Optional references:  
> • [DDP Tutorial — PyTorch](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html)  
> • [`torchrun` launcher docs](https://pytorch.org/docs/stable/elastic/run.html)



In [None]:
%%writefile train_ddp.py
import os
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchvision.models as models
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.profiler import profile, ProfilerActivity, schedule as profiler_schedule

###############################################################################
# DDP init / cleanup
###############################################################################

def setup():
    # torchrun sets these:
    #   RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR, MASTER_PORT
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])

    torch.cuda.set_device(local_rank)
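    dist.init_process_group(
        backend="nccl",
        rank=rank,
        world_size=world_size,
        # Bind the process group to this process's GPU; older releases
        # accept only a torch.device here, not a bare int.
        device_id=torch.device("cuda", local_rank),
    )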

    return rank, world_size, local_rank

def cleanup():
    dist.destroy_process_group()


###############################################################################
# Core workload
###############################################################################

def make_model_and_data(local_rank):
    """
    Match the profiling guide:
      - model: torchvision.models.resnet18()
      - dtype: bfloat16 on GPU
      - dummy_input:  [B, 3, 224, 224]  (B is set below)
      - dummy_target: [B, 1000]
    The GPU is chosen with LOCAL_RANK (index on this machine), not the global
    RANK, so the same code also works for multi-node runs.
    """
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16  # same as profiling guide
    model = models.resnet18().to(device).to(dtype)
    ddp_model = DDP(model, device_ids=[local_rank] if device.type == "cuda" else None)
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=1e-3)

    # fixed random batch (no dataloader), same as guide
    B, C, H, W = 10, 3, 224, 224
    num_classes = 1000

    dummy_input = torch.randn(B, C, H, W, device=device, dtype=dtype)
    dummy_target = torch.randn(B, num_classes, device=device, dtype=dtype)

    return ddp_model, optimizer, dummy_input, dummy_target, device


def train_step(ddp_model, optimizer, X, Y):
    """
    Single step:
      forward -> mse_loss -> backward -> step
    """
    optimizer.zero_grad(set_to_none=True)
    out = ddp_model(X)
    loss = F.mse_loss(out, Y)
    loss.backward()
    optimizer.step()
    return loss



###############################################################################
# Profiler setup (same schedule as the profiling guide)
###############################################################################

def make_profiler_ctx(rank, profile_ranks, traces_dir):
    """
    We use the same schedule pattern as in the notebook:
      wait=10, warmup=5, active=3, repeat=1
    Every time an 'active' window finishes, we dump a Chrome trace
    file for that rank using export_chrome_trace.
    """
    class NullProfiler:
        def __enter__(self):
            return self
        def __exit__(self, exc_type, exc_val, exc_tb):
            pass
        def step(self):
            pass

    # Decide if this rank should actually profile
    if profile_ranks.strip() != "all":
        profile_ranks_int = {int(x) for x in profile_ranks.split(",") if x.strip()}
        if rank not in profile_ranks_int:
            return NullProfiler()

    os.makedirs(traces_dir, exist_ok=True)

    active = 3  # number of recorded steps per cycle
    sched = profiler_schedule(
        wait=10,
        warmup=5,
        active=active,
        repeat=1,
    )

    def trace_handler(p):
        # on_trace_ready runs inside the prof.step() call that follows the last
        # active iteration, so p.step_num is already one past that iteration.
        end_iter = p.step_num - 1
        start_iter = end_iter - active + 1
        trace_path = os.path.join(traces_dir, f"rank{rank}_iter{start_iter}_{end_iter}.json")
        p.export_chrome_trace(trace_path)

    return profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        schedule=sched,
        record_shapes=True,
        with_stack=True,
        on_trace_ready=trace_handler,
    )


###############################################################################
# Per-rank entry point
###############################################################################

def ddp_worker(rank, world_size, profile_ranks, traces_dir):
    rank, world_size, local_rank = setup()
    print(f"[rank {rank}] starting worker")

    ddp_model, optimizer, dummy_input, dummy_target, device = make_model_and_data(local_rank)

    total_steps = 100
    
    prof = make_profiler_ctx(rank, profile_ranks, traces_dir)
    with prof:
        for step_idx in range(total_steps):
            loss = train_step(ddp_model, optimizer, dummy_input, dummy_target)
            prof.step()
            if rank == 0 and step_idx % 10 == 0:
                print(f"[rank {rank}] step {step_idx}/{total_steps} "
                      f"loss={loss.item():.4f}", flush=True)

    # sync before cleanup; device_ids takes the GPU index on this machine
    dist.barrier(device_ids=[local_rank])
    cleanup()
    print(f"[rank {rank}] finished worker")

def main():
    # torchrun sets these per process
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # hard-coded for simplicity; these could also come from CLI args
    profile_ranks = "0, 2, 4"     # comma-separated global ranks, or "all"
    traces_dir = "./traces"

    ddp_worker(rank, world_size, profile_ranks, traces_dir)


if __name__ == "__main__":
    main()

In [None]:
import sys
import subprocess

num_procs = 8  # how many GPUs / ranks you want
script_path = "train_ddp.py"  # path to your script

cmd = [
    sys.executable,
    "-m", "torch.distributed.run",   # this is effectively torchrun
    f"--nproc_per_node={num_procs}",
    script_path,
]

print("Launching:", " ".join(cmd))

result = subprocess.run(cmd, capture_output=False, check=False)
print("Return code:", result.returncode)


You should now be able to see the generated trace files in the `./traces` directory.  
Each profiled rank writes its own Chrome trace file named `rank{rank}_iter{start}_{end}.json`. With `profile_ranks = "0, 2, 4"`, that means one file each for ranks 0, 2, and 4, and none for the remaining ranks.
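The iteration numbers in those filenames come from the profiler schedule (`wait=10, warmup=5, active=3, repeat=1`). The sketch below is a simplified, pure-Python re-implementation of how `torch.profiler.schedule` maps step numbers to phases. It is not the library code, and it assumes the default `skip_first=0`. It works out which loop iterations are captured and what `p.step_num` holds when `on_trace_ready` fires.

```python
def phase(step: int, wait: int, warmup: int, active: int, repeat: int) -> str:
    """Simplified mirror of the phase torch.profiler.schedule assigns to a step
    (skip_first assumed to be 0)."""
    cycle = wait + warmup + active
    if repeat > 0 and step >= repeat * cycle:
        return "NONE"                      # all requested cycles are finished
    pos = step % cycle
    if pos < wait:
        return "NONE"                      # profiler idle
    if pos < wait + warmup:
        return "WARMUP"                    # profiler running, results discarded
    return "RECORD_AND_SAVE" if pos == cycle - 1 else "RECORD"


cfg = dict(wait=10, warmup=5, active=3, repeat=1)
total_steps = 100

# Loop iteration i runs under phase(i); prof.step() then advances step_num to i + 1.
recorded = [i for i in range(total_steps) if phase(i, **cfg).startswith("RECORD")]
# on_trace_ready fires inside the prof.step() call after the last recorded
# iteration, so step_num has already moved one past it.
step_num_in_handler = recorded[-1] + 1

print(recorded)             # [15, 16, 17]
print(step_num_in_handler)  # 18
```

So the trace covers loop iterations 15 to 17, and inside a trace handler the last recorded iteration is `p.step_num - 1`.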

---

### 🌐 Multi-node note

For multi-node runs, **nothing changes conceptually** from the profiler’s point of view.  
If you have four nodes with eight GPUs each (32 total ranks):

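- The global ranks simply span **0 → 31** across all nodes, while `LOCAL_RANK` restarts at 0–7 on every node. This is why each process must select its GPU with `LOCAL_RANK` rather than `RANK`.  
- The same rank-filtering and unique-filename logic still applies.  
- Each process writes its trace to the **storage path accessible to that process**.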
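The launch cell above starts processes on a single machine only. For a multi-node job, the same script is started once per node, and torchrun's static rendezvous flags (`--nnodes`, `--node_rank`, `--master_addr`, `--master_port`) tell each launcher its place in the job. The sketch below builds those per-node commands in the same style as the launch cell. `multinode_cmd` is a hypothetical helper, and the host name `node0.example` and port `29500` are placeholders. How each command is actually started on its node (a scheduler, ssh, etc.) is outside its scope.

```python
import sys


def multinode_cmd(node_rank: int, nnodes: int, nproc_per_node: int,
                  master_addr: str, master_port: int, script: str) -> list[str]:
    """Hypothetical helper: the torchrun command to run on ONE node of a static job."""
    return [
        sys.executable, "-m", "torch.distributed.run",
        f"--nnodes={nnodes}",
        f"--nproc_per_node={nproc_per_node}",
        f"--node_rank={node_rank}",        # differs per node: 0, 1, 2, 3
        f"--master_addr={master_addr}",    # address of node 0, reachable by all nodes
        f"--master_port={master_port}",
        script,
    ]


# Placeholder cluster: 4 nodes x 8 GPUs = 32 ranks.
for node_rank in range(4):
    print(" ".join(multinode_cmd(node_rank, 4, 8, "node0.example", 29500, "train_ddp.py")))
```

Only `--node_rank` changes between nodes; with 8 processes per node, node `n` hosts global ranks `8n` to `8n + 7`.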

If you’re writing to **local disk paths** (e.g., `./traces`), each node will contain only the trace files of the ranks it hosts that were selected for profiling.  
For example:
- Node 0 → `rank0–7` traces  
- Node 1 → `rank8–15` traces  
- …and so on.  

If instead you write to a **shared NFS or network-mounted directory**, all profiled ranks' traces appear in one place.  
That’s the only difference: the profiling logic itself remains identical.
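To make the local-disk case concrete, the sketch below lists which trace files would end up on each node when one rank per node is profiled. It assumes the contiguous rank layout described above and uses placeholder iteration numbers; `traces_per_node` is a hypothetical helper.

```python
def traces_per_node(nnodes: int, nproc_per_node: int, profile_ranks: set[int],
                    start: int, end: int) -> dict[int, list[str]]:
    """Hypothetical helper: trace files written to each node's local disk,
    assuming node n hosts ranks n*nproc_per_node .. (n+1)*nproc_per_node - 1."""
    layout = {}
    for node in range(nnodes):
        first = node * nproc_per_node
        hosted = range(first, first + nproc_per_node)
        layout[node] = [f"rank{r}_iter{start}_{end}.json"
                        for r in hosted if r in profile_ranks]
    return layout


# 4 nodes x 8 GPUs, profiling the first rank on each node.
layout = traces_per_node(4, 8, {0, 8, 16, 24}, 15, 17)
for node, files in layout.items():
    print(f"node {node}: {files}")
# With a shared NFS directory, these four files would sit in a single folder instead.
```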
