# ShallowSpeed Advanced Benchmark: 8 GPUs, Mixed Precision, Gradient Accumulation & Profiling

An extended version of the ShallowSpeed GPU benchmark with:

| Feature | What it adds |
|---------|-------------|
| **8 GPU scaling** | Benchmarks across 1, 2, 4, and 8 GPUs to see where communication overhead dominates |
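| **Mixed Precision (AMP)** | FP16/BF16 autocast for faster matmuls. `torch.amp.autocast` keeps parameters and gradients in FP32, so AllReduce still moves FP32 data unless gradients are explicitly compressed (e.g. DDP's `fp16_compress_hook`) |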
| **Gradient Accumulation** | Simulate larger effective batch sizes without more memory — fewer AllReduce calls per epoch |
| **torch.profiler** | Actual GPU timeline traces showing compute vs NCCL overlap |
| **Fixed loss logging** | Properly averages loss across all ranks (fixes the local-loss artifact from the original notebook) |
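A rough way to size the communication cost is the textbook bandwidth-optimal ring AllReduce model, where each rank sends (and receives) about 2(N-1)/N times the gradient buffer per synchronization. This is a back-of-envelope sketch, not a description of what NCCL actually does: NCCL may pick tree or other algorithms depending on topology and message size. The helper name is hypothetical. Under plain `torch.amp.autocast` the gradients stay FP32 (4 bytes per element), so the 16-bit figure only applies if gradients are compressed before communication.

```python
def ring_allreduce_bytes_per_rank(n_params: int, bytes_per_elem: int, world_size: int) -> float:
    """Bytes each rank sends (and receives) in one ring AllReduce of n_params elements.

    Ring AllReduce = reduce-scatter + all-gather; each phase moves (N - 1)/N
    of the buffer per rank, so the total is 2 * (N - 1) / N * buffer size.
    """
    if world_size < 2:
        return 0.0
    buffer_bytes = n_params * bytes_per_elem
    return 2 * (world_size - 1) / world_size * buffer_bytes


if __name__ == "__main__":
    n_params = 249_751_562  # 'large' model (see model_common.py)
    for n in (2, 4, 8):
        fp32 = ring_allreduce_bytes_per_rank(n_params, 4, n)  # autocast: grads stay FP32
        half = ring_allreduce_bytes_per_rank(n_params, 2, n)  # only with explicit compression
        print(f"{n} GPUs: {fp32 / 1e9:.2f} GB/rank (FP32 grads), {half / 1e9:.2f} GB/rank (16-bit grads)")
```

Note that the per-rank volume approaches 2x the buffer as N grows, so adding GPUs does not shrink communication per rank; only the per-GPU compute shrinks.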

## Hardware Requirements
- **RunPod**: 2, 4, or 8 GPUs (e.g., 8x A100, 8x H100)
- Works on any multi-GPU machine with NCCL support
- For NVLink vs PCIe comparison: run this notebook on both SXM and PCIe instances, then compare the saved JSON results
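One way to do that comparison is to line up the two saved JSON files by configuration and take the ratio of average epoch times. A minimal sketch; the function names are hypothetical, and it assumes each result dict carries the keys the training scripts write (`model_size`, `mode`, `n_gpus`, `use_amp`, `grad_accum_steps`, `avg_epoch_time`).

```python
import json
from typing import Dict, List, Tuple

Key = Tuple[str, str, int, bool, int]


def index_results(results: List[dict]) -> Dict[Key, float]:
    """Map (model_size, mode, n_gpus, use_amp, grad_accum_steps) -> avg_epoch_time."""
    out = {}
    for r in results:
        key = (r["model_size"], r["mode"], r["n_gpus"],
               bool(r.get("use_amp", False)), r.get("grad_accum_steps", 1))
        out[key] = r["avg_epoch_time"]
    return out


def compare(results_a: List[dict], results_b: List[dict]) -> Dict[Key, float]:
    """time_b / time_a for every configuration present in both runs (> 1: B is slower)."""
    a, b = index_results(results_a), index_results(results_b)
    return {k: b[k] / a[k] for k in sorted(a.keys() & b.keys())}


def load(path: str) -> List[dict]:
    with open(path) as f:
        return json.load(f)
```

For example, `compare(load('benchmark_results_<sxm-tag>.json'), load('benchmark_results_<pcie-tag>.json'))`, where the tags are whatever GPU names the runner wrote into the filenames. Ratios well above 1 that grow with GPU count point at the interconnect rather than at compute.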

---
## 1. Environment Check

In [None]:
import torch
import os

NUM_GPUS = torch.cuda.device_count()
print(f"PyTorch version : {torch.__version__}")
print(f"CUDA available  : {torch.cuda.is_available()}")
print(f"NCCL available  : {torch.distributed.is_nccl_available()}")
print(f"BF16 supported  : {torch.cuda.is_bf16_supported() if torch.cuda.is_available() else 'N/A'}")
print(f"GPUs found      : {NUM_GPUS}")

for i in range(NUM_GPUS):
    props = torch.cuda.get_device_properties(i)
    mem = getattr(props, 'total_memory', getattr(props, 'total_mem', 0)) / 1024**3
    print(f"  GPU {i}: {props.name} ({mem:.1f} GB)")

# Detect interconnect type
if NUM_GPUS >= 2:
    import re
    try:
        result = !nvidia-smi topo -m 2>/dev/null | head -20
        for line in result:
            print(line)
        # Match real link entries such as "NV12"; the legend line "NV# = ..." must not count
        if any(re.search(r'\bNV\d+\b', str(line)) for line in result):
            print("\n>>> NVLink detected: expect fast AllReduce")
        else:
            print("\n>>> No NVLink entries found (PCIe): AllReduce will be slower")
    except Exception:
        pass

if NUM_GPUS < 2:
    print("\n⚠️  Only 1 GPU detected. Data-parallel benchmarks need >= 2 GPUs.")
    print("    Single GPU baseline + profiler will still run.")

# Store GPU name for results tagging
GPU_NAME = torch.cuda.get_device_properties(0).name if torch.cuda.is_available() else 'unknown'
print(f"\nGPU tag for results: {GPU_NAME}")

---
## 2. Training Scripts

All scripts now support:
- `use_amp`: Enable mixed precision (FP16/BF16)
- `grad_accum_steps`: Accumulate gradients over N micro-steps before AllReduce + optimizer step
- **Fixed loss logging**: `dist.all_reduce` on the loss scalar so all ranks report the true global average
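Summing each rank's mean loss and dividing by `world_size`, as the scripts do, gives the true batch mean only when all shards hold the same number of samples. The sharding rule hands any remainder to the last rank, so the two can differ when the batch size is not divisible by the GPU count; with a global batch of 4096 on 2, 4, or 8 GPUs the shards are equal and the simple average is exact. A pure-Python illustration with hypothetical helper names:

```python
from typing import List


def mean_of_rank_means(rank_losses: List[float]) -> float:
    """What SUM-AllReduce of per-rank mean losses, divided by world_size, computes."""
    return sum(rank_losses) / len(rank_losses)


def sample_weighted_mean(rank_losses: List[float], rank_counts: List[int]) -> float:
    """True mean over all samples in the global batch, valid for unequal shards."""
    total = sum(loss * n for loss, n in zip(rank_losses, rank_counts))
    return total / sum(rank_counts)


if __name__ == "__main__":
    losses = [2.0, 4.0]
    print(mean_of_rank_means(losses))            # each rank weighted equally
    print(sample_weighted_mean(losses, [1, 3]))  # rank 1 holds 3x the samples
```

In a distributed script the weighted version would all-reduce a two-element tensor holding `loss * n_local` and `n_local`, then divide the two sums.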

### 2a. Shared Model & Dataset

In [None]:
%%writefile model_common.py
"""
Shared model definition and dataset — used by all training scripts.

Model sizes are designed to stress H100/A100 GPUs:
  - base:   ~64M params  — minimum viable for multi-GPU benefit
  - large:  ~250M params — clear scaling benefits
  - xlarge: ~730M params — approaches billion-scale, strong scaling

Deeper + wider = more compute per step = more room for parallelism to help.
"""
import torch
import torch.nn as nn


def build_model(size='large'):
    configs = {
        'base':   [784, 4096, 4096, 4096, 4096, 2048, 1024, 10],              # ~64M params
        'large':  [784, 8192, 8192, 8192, 8192, 4096, 2048, 10],              # ~250M params
        'xlarge': [784, 16384, 16384, 16384, 8192, 4096, 2048, 10],           # ~730M params
    }
    sizes = configs[size]
    layers = []
    for i in range(len(sizes) - 1):
        layers.append(nn.Linear(sizes[i], sizes[i + 1]))
        if i < len(sizes) - 2:
            layers.append(nn.ReLU())
    model = nn.Sequential(*layers)
    n_params = sum(p.numel() for p in model.parameters())
    return model, n_params


def make_dataset(n_samples=65536, n_features=784, n_classes=10):
    """Larger dataset to increase compute per epoch."""
    torch.manual_seed(42)
    X = torch.randn(n_samples, n_features)
    y = torch.randint(0, n_classes, (n_samples,))
    return X, y


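if __name__ == '__main__':
    # Only when run directly (`python model_common.py`); importing this module from
    # the training scripts must not build all three models on the CPU in every process.
    print("model_common.py loaded OK")
    for sz in ['base', 'large', 'xlarge']:
        _, n = build_model(sz)
        print(f"  {sz:8s}: {n:>12,} params")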
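The parameter counts in the docstring follow directly from the layer widths: each `nn.Linear(d_in, d_out)` contributes `d_in * d_out` weights plus `d_out` biases. This standalone check (no torch needed; `mlp_param_count` is a hypothetical helper) reproduces what `build_model` reports and converts it to FP32 gradient bytes, which is the buffer each data-parallel sync has to AllReduce.

```python
def mlp_param_count(widths):
    """Weights + biases of a stack of nn.Linear layers with the given widths."""
    return sum(d_in * d_out + d_out for d_in, d_out in zip(widths, widths[1:]))


# Same widths as the configs in build_model()
CONFIGS = {
    'base':   [784, 4096, 4096, 4096, 4096, 2048, 1024, 10],
    'large':  [784, 8192, 8192, 8192, 8192, 4096, 2048, 10],
    'xlarge': [784, 16384, 16384, 16384, 8192, 4096, 2048, 10],
}

if __name__ == "__main__":
    for name, widths in CONFIGS.items():
        n = mlp_param_count(widths)
        print(f"{name:8s}: {n:>12,} params, {4 * n / 1e9:.2f} GB of FP32 gradients")
```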

### 2b. Single GPU Baseline

Now supports mixed precision and gradient accumulation.

In [None]:
%%writefile train_single_gpu.py
"""
SINGLE GPU BASELINE — with AMP + gradient accumulation support.
"""
import torch, torch.nn as nn, time, json, sys, os
from model_common import build_model, make_dataset

def main():
    config = json.loads(sys.argv[1])
    device = torch.device('cuda:0')
    use_amp = config.get('use_amp', False)
    grad_accum_steps = config.get('grad_accum_steps', 1)
    amp_dtype = torch.bfloat16 if (use_amp and torch.cuda.is_bf16_supported()) else torch.float16

    model, n_params = build_model(config['model_size'])
    model = model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=config['lr'])
    loss_fn = nn.CrossEntropyLoss()
    scaler = torch.amp.GradScaler('cuda', enabled=(use_amp and amp_dtype == torch.float16))

    X, y = make_dataset(config['n_samples'])
    X, y = X.to(device), y.to(device)
    bs = config['batch_size']

    # Warmup
    for _ in range(3):
        idx = torch.randint(0, len(X), (bs,))
        with torch.amp.autocast('cuda', enabled=use_amp, dtype=amp_dtype):
            out = model(X[idx])
            warmup_loss = loss_fn(out, y[idx])
        warmup_loss.backward()  # backward outside autocast, as the AMP docs recommend
        optimizer.zero_grad()
    torch.cuda.synchronize()

    epoch_times, losses = [], []
    for epoch in range(config['n_epochs']):
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        epoch_loss, nb = 0.0, 0
        optimizer.zero_grad()

        # Build list of batches
        batch_starts = list(range(0, len(X), bs))

        for step_idx, start in enumerate(batch_starts):
            xb = X[start:start+bs]
            yb = y[start:start+bs]

            with torch.amp.autocast('cuda', enabled=use_amp, dtype=amp_dtype):
                out = model(xb)
                loss = loss_fn(out, yb) / grad_accum_steps

            scaler.scale(loss).backward()
            epoch_loss += loss.item() * grad_accum_steps
            nb += 1

            # Step every grad_accum_steps or at end of epoch
            if (step_idx + 1) % grad_accum_steps == 0 or (step_idx + 1) == len(batch_starts):
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

        torch.cuda.synchronize()
        epoch_times.append(time.perf_counter() - t0)
        losses.append(epoch_loss / nb)

    # Accuracy
    model.eval()
    with torch.no_grad():
        correct = 0
        for start in range(0, len(X), bs):
            out = model(X[start:start+bs])
            correct += (out.argmax(1) == y[start:start+bs]).sum().item()
    acc = correct / len(X)

    print('RESULTS_JSON:' + json.dumps({
        'mode': 'single_gpu', 'n_gpus': 1, 'n_params': n_params,
        'model_size': config['model_size'],
        'use_amp': use_amp, 'grad_accum_steps': grad_accum_steps,
        'epoch_times': epoch_times, 'losses': losses,
        'final_accuracy': acc,
        'avg_epoch_time': sum(epoch_times) / len(epoch_times),
        'total_time': sum(epoch_times),
        'comm_time': 0.0,
    }))

if __name__ == '__main__':
    main()
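The accumulation condition shared by all four scripts decides after which micro-batches the optimizer steps (and, in the data-parallel scripts, when gradients are synchronized). A small standalone sketch of that rule, with a hypothetical helper name: with 65,536 samples and batch size 4096 there are 16 micro-batches per epoch, so GA=4 means 4 optimizer steps and 4 gradient syncs instead of 16.

```python
def optimizer_step_indices(n_batches: int, grad_accum_steps: int) -> list:
    """Micro-step indices after which the scripts call optimizer.step().

    Mirrors: (step_idx + 1) % grad_accum_steps == 0 or (step_idx + 1) == n_batches
    """
    return [i for i in range(n_batches)
            if (i + 1) % grad_accum_steps == 0 or (i + 1) == n_batches]


if __name__ == "__main__":
    n_batches = 65536 // 4096
    for ga in (1, 4):
        steps = optimizer_step_indices(n_batches, ga)
        print(f"GA={ga}: {len(steps)} optimizer steps per epoch at micro-steps {steps}")
```

Two caveats: a final group shorter than `grad_accum_steps` still divides its losses by `grad_accum_steps`, so its gradient is slightly under-weighted (not an issue here, since 16 is divisible by 4), and GA=4 quadruples the effective batch at the same learning rate, so its loss curves are not directly comparable with GA=1.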

### 2c. Naive Data-Parallel (Non-Interleaved AllReduce)

Full forward and full backward pass first, THEN a blocking AllReduce on every gradient, issued as one call per parameter tensor.
Now with AMP, gradient accumulation, and **fixed global loss logging**.
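Every rank holds the full dataset and slices its own shard of each global batch. The slicing rule, pulled out as a standalone function (hypothetical name) so it can be checked: shards are contiguous, cover the batch exactly once, and the last rank absorbs any remainder.

```python
def shard_bounds(batch_len: int, rank: int, world_size: int) -> tuple:
    """[start, end) slice of a global batch owned by `rank`, as in the DP scripts."""
    chunk = batch_len // world_size
    start = rank * chunk
    end = start + chunk if rank < world_size - 1 else batch_len  # last rank takes the remainder
    return start, end


if __name__ == "__main__":
    for ws in (2, 4, 8):
        sizes = [e - s for s, e in (shard_bounds(4096, r, ws) for r in range(ws))]
        print(ws, sizes)
```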

In [None]:
%%writefile train_dp_naive.py
"""
DATA-PARALLEL: NAIVE (NON-INTERLEAVED)
With AMP + gradient accumulation + fixed global loss averaging.
"""
import torch, torch.nn as nn, torch.distributed as dist
import time, json, sys, os
from model_common import build_model, make_dataset

def main():
    dist.init_process_group(backend='nccl')
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(f'cuda:{rank}')
    torch.cuda.set_device(device)

    config = json.loads(sys.argv[1])
    use_amp = config.get('use_amp', False)
    grad_accum_steps = config.get('grad_accum_steps', 1)
    amp_dtype = torch.bfloat16 if (use_amp and torch.cuda.is_bf16_supported()) else torch.float16

    torch.manual_seed(42)
    model, n_params = build_model(config['model_size'])
    model = model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=config['lr'])
    loss_fn = nn.CrossEntropyLoss()
    scaler = torch.amp.GradScaler('cuda', enabled=(use_amp and amp_dtype == torch.float16))

    X_all, y_all = make_dataset(config['n_samples'])
    X_all, y_all = X_all.to(device), y_all.to(device)
    bs = config['batch_size']

    # Warmup
    for _ in range(3):
        idx = torch.randint(0, len(X_all), (bs // world_size,))
        with torch.amp.autocast('cuda', enabled=use_amp, dtype=amp_dtype):
            out = model(X_all[idx])
            warmup_loss = loss_fn(out, y_all[idx])
        warmup_loss.backward()
        optimizer.zero_grad()
    torch.cuda.synchronize()
    dist.barrier()

    epoch_times, losses, comm_times = [], [], []
    for epoch in range(config['n_epochs']):
        torch.cuda.synchronize()
        dist.barrier()
        t0 = time.perf_counter()
        epoch_loss, nb, epoch_comm = 0.0, 0, 0.0
        optimizer.zero_grad()

        batch_starts = list(range(0, len(X_all), bs))

        for step_idx, start in enumerate(batch_starts):
            xb = X_all[start:start+bs]
            yb = y_all[start:start+bs]

            # Shard the batch across GPUs
            chunk = len(xb) // world_size
            s = rank * chunk
            e = s + chunk if rank < world_size - 1 else len(xb)
            x_local, y_local = xb[s:e], yb[s:e]

            with torch.amp.autocast('cuda', enabled=use_amp, dtype=amp_dtype):
                out = model(x_local)
                loss = loss_fn(out, y_local) / grad_accum_steps

            scaler.scale(loss).backward()

            # --- Fixed global loss logging ---
            loss_val = loss.detach() * grad_accum_steps
            dist.all_reduce(loss_val, op=dist.ReduceOp.SUM)
            epoch_loss += (loss_val.item() / world_size)
            nb += 1

            # AllReduce + step every grad_accum_steps or at end
            if (step_idx + 1) % grad_accum_steps == 0 or (step_idx + 1) == len(batch_starts):
                # AllReduce the still-scaled gradients; scaler.step() unscales them
                # afterwards, so every rank runs the inf/NaN check on identical grads.

                # === COMMUNICATION: AllReduce ALL gradients (BLOCKING) ===
                torch.cuda.synchronize()
                tc0 = time.perf_counter()

                for param in model.parameters():
                    if param.grad is not None:
                        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
                        param.grad /= world_size

                torch.cuda.synchronize()
                epoch_comm += time.perf_counter() - tc0

                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

        torch.cuda.synchronize()
        epoch_times.append(time.perf_counter() - t0)
        losses.append(epoch_loss / max(nb, 1))
        comm_times.append(epoch_comm)

    # Accuracy
    model.eval()
    with torch.no_grad():
        correct = 0
        for start in range(0, len(X_all), bs):
            out = model(X_all[start:start+bs])
            correct += (out.argmax(1) == y_all[start:start+bs]).sum().item()
    acc = correct / len(X_all)

    if rank == 0:
        print('RESULTS_JSON:' + json.dumps({
            'mode': 'dp_naive', 'n_gpus': world_size, 'n_params': n_params,
            'model_size': config['model_size'],
            'use_amp': use_amp, 'grad_accum_steps': grad_accum_steps,
            'epoch_times': epoch_times, 'losses': losses,
            'comm_times': comm_times,
            'final_accuracy': acc,
            'avg_epoch_time': sum(epoch_times) / len(epoch_times),
            'avg_comm_time': sum(comm_times) / len(comm_times),
            'total_time': sum(epoch_times),
            'comm_time': sum(comm_times),
        }))

    dist.destroy_process_group()

if __name__ == '__main__':
    main()

### 2d. Interleaved Data-Parallel (Non-blocking AllReduce During Backward)

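Per-parameter async AllReduce launched from gradient hooks during the backward pass, so gradients of later layers are communicated while earlier layers are still computing theirs. With AMP + gradient accumulation.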
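Why interleaving helps can be seen with an idealized timing model: in the naive version every AllReduce waits for the whole backward pass, while with hooks each gradient's AllReduce can start as soon as that gradient exists, provided the previous collective has finished (one communication stream). This is a simplified model under those assumptions, not a measurement: it ignores the forward pass, the optimizer step, launch overhead, and the slowdown compute kernels suffer when NCCL kernels share the GPU's SMs. Function names are hypothetical.

```python
from typing import List


def naive_step_time(compute_ms: List[float], comm_ms: List[float]) -> float:
    """Whole backward pass first, then every AllReduce back to back."""
    return sum(compute_ms) + sum(comm_ms)


def interleaved_step_time(compute_ms: List[float], comm_ms: List[float]) -> float:
    """Idealised overlap: entry i is the i-th gradient produced by backward.

    An AllReduce starts once its gradient is ready AND the previous AllReduce
    has finished (a single communication stream). Compute never slows down.
    """
    grad_ready = 0.0  # time at which backward has produced gradient i
    comm_free = 0.0   # time at which the communication stream becomes idle
    for c, m in zip(compute_ms, comm_ms):
        grad_ready += c
        start = max(grad_ready, comm_free)
        comm_free = start + m
    return max(grad_ready, comm_free)


if __name__ == "__main__":
    # Compute-bound: 2 ms of backward per layer, 1 ms AllReduce per layer
    print(naive_step_time([2, 2, 2], [1, 1, 1]), interleaved_step_time([2, 2, 2], [1, 1, 1]))
    # Communication-bound: 1 ms of backward per layer, 3 ms AllReduce per layer
    print(naive_step_time([1, 1, 1], [3, 3, 3]), interleaved_step_time([1, 1, 1], [3, 3, 3]))
```

The compute-bound case gives 9 ms naive vs 7 ms interleaved, the communication-bound case 12 ms vs 10 ms. In both, the AllReduce of the last gradient produced (the first layer) stays exposed, because no backward compute is left to hide it behind.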

In [None]:
%%writefile train_dp_interleaved.py
"""
DATA-PARALLEL: INTERLEAVED (Non-blocking AllReduce during backward)
With AMP + gradient accumulation + fixed global loss averaging.

Gradient hooks fire async AllReduce per-layer ONLY on the accumulation
step (when we actually need to sync). During accumulation micro-steps,
hooks are disabled so gradients accumulate locally.
"""
import torch, torch.nn as nn, torch.distributed as dist
import time, json, sys, os
from model_common import build_model, make_dataset


class InterleavedDP:
    def __init__(self, model, world_size):
        self.model = model
        self.world_size = world_size
        self._handles = []
        self._sync_enabled = True  # Toggle for grad accumulation

        for param in self.model.parameters():
            param.register_post_accumulate_grad_hook(self._make_hook(param))

    def _make_hook(self, param):
        def hook(p):
            if self._sync_enabled:
                handle = dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, async_op=True)
                self._handles.append((handle, p))
        return hook

    def finish_allreduce(self):
        for handle, param in self._handles:
            handle.wait()
            param.grad /= self.world_size
        self._handles.clear()


def main():
    dist.init_process_group(backend='nccl')
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(f'cuda:{rank}')
    torch.cuda.set_device(device)

    config = json.loads(sys.argv[1])
    use_amp = config.get('use_amp', False)
    grad_accum_steps = config.get('grad_accum_steps', 1)
    amp_dtype = torch.bfloat16 if (use_amp and torch.cuda.is_bf16_supported()) else torch.float16

    torch.manual_seed(42)
    model, n_params = build_model(config['model_size'])
    model = model.to(device)
    dp = InterleavedDP(model, world_size)
    optimizer = torch.optim.SGD(model.parameters(), lr=config['lr'])
    loss_fn = nn.CrossEntropyLoss()
    scaler = torch.amp.GradScaler('cuda', enabled=(use_amp and amp_dtype == torch.float16))

    X_all, y_all = make_dataset(config['n_samples'])
    X_all, y_all = X_all.to(device), y_all.to(device)
    bs = config['batch_size']

    # Warmup
    for _ in range(3):
        idx = torch.randint(0, len(X_all), (bs // world_size,))
        with torch.amp.autocast('cuda', enabled=use_amp, dtype=amp_dtype):
            out = model(X_all[idx])
            warmup_loss = loss_fn(out, y_all[idx])
        warmup_loss.backward()
        dp.finish_allreduce()
        optimizer.zero_grad()
    torch.cuda.synchronize()
    dist.barrier()

    epoch_times, losses = [], []
    for epoch in range(config['n_epochs']):
        torch.cuda.synchronize()
        dist.barrier()
        t0 = time.perf_counter()
        epoch_loss, nb = 0.0, 0
        optimizer.zero_grad()

        batch_starts = list(range(0, len(X_all), bs))

        for step_idx, start in enumerate(batch_starts):
            xb = X_all[start:start+bs]
            yb = y_all[start:start+bs]

            chunk = len(xb) // world_size
            s = rank * chunk
            e = s + chunk if rank < world_size - 1 else len(xb)
            x_local, y_local = xb[s:e], yb[s:e]

            is_sync_step = ((step_idx + 1) % grad_accum_steps == 0) or ((step_idx + 1) == len(batch_starts))

            # Only fire AllReduce hooks on the sync step
            dp._sync_enabled = is_sync_step

            with torch.amp.autocast('cuda', enabled=use_amp, dtype=amp_dtype):
                out = model(x_local)
                loss = loss_fn(out, y_local) / grad_accum_steps

            scaler.scale(loss).backward()

            # --- Fixed global loss logging ---
            loss_val = loss.detach() * grad_accum_steps
            dist.all_reduce(loss_val, op=dist.ReduceOp.SUM)
            epoch_loss += (loss_val.item() / world_size)
            nb += 1

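            if is_sync_step:
                # Wait for the in-flight async AllReduces before anything touches .grad
                # (unscaling first would race with NCCL); scaler.step() then unscales.
                dp.finish_allreduce()
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()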

        torch.cuda.synchronize()
        epoch_times.append(time.perf_counter() - t0)
        losses.append(epoch_loss / max(nb, 1))

    # Accuracy
    model.eval()
    with torch.no_grad():
        correct = 0
        for start in range(0, len(X_all), bs):
            out = model(X_all[start:start+bs])
            correct += (out.argmax(1) == y_all[start:start+bs]).sum().item()
    acc = correct / len(X_all)

    if rank == 0:
        print('RESULTS_JSON:' + json.dumps({
            'mode': 'dp_interleaved', 'n_gpus': world_size, 'n_params': n_params,
            'model_size': config['model_size'],
            'use_amp': use_amp, 'grad_accum_steps': grad_accum_steps,
            'epoch_times': epoch_times, 'losses': losses,
            'final_accuracy': acc,
            'avg_epoch_time': sum(epoch_times) / len(epoch_times),
            'total_time': sum(epoch_times),
        }))

    dist.destroy_process_group()

if __name__ == '__main__':
    main()

### 2e. PyTorch DDP (Reference)

Production-grade interleaved AllReduce with gradient bucketing.
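DDP does not launch one AllReduce per parameter; it groups gradients into buckets (capped by the `bucket_cap_mb` constructor argument, 25 MB by default) and reduces a bucket once all of its gradients are ready. The greedy, order-preserving grouping below is a simplified sketch of that idea with a hypothetical function name; the real implementation handles more cases, for example rebuilding buckets after the first iteration to match the observed gradient order.

```python
from typing import List


def bucket_params(param_bytes: List[int], cap_bytes: int) -> List[List[int]]:
    """Greedy, order-preserving bucketing of parameter indices.

    `param_bytes` is in the order gradients become ready. A bucket is closed
    once adding the next parameter would exceed `cap_bytes`; a parameter
    larger than the cap gets a bucket of its own.
    """
    buckets, current, current_bytes = [], [], 0
    for idx, nbytes in enumerate(param_bytes):
        if current and current_bytes + nbytes > cap_bytes:
            buckets.append(current)
            current, current_bytes = [], 0
        current.append(idx)
        current_bytes += nbytes
    if current:
        buckets.append(current)
    return buckets


if __name__ == "__main__":
    # FP32 gradient sizes of the 'large' model, roughly in the order backward produces them
    widths = [784, 8192, 8192, 8192, 8192, 4096, 2048, 10]
    sizes = []
    for d_in, d_out in zip(widths, widths[1:]):
        sizes += [4 * d_in * d_out, 4 * d_out]  # weight, bias
    sizes.reverse()  # backward visits the last layer first
    for bucket in bucket_params(sizes, 25 * 1024 * 1024):
        print(bucket, round(sum(sizes[i] for i in bucket) / 2**20, 1), "MiB")
```

In this simplified model, a single 8192 x 8192 FP32 weight (about 268 MB) is far above a 25 MB cap, so each large weight lands in a bucket of its own and bucketing mainly batches the small bias and output-layer tensors together.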

In [None]:
%%writefile train_ddp_builtin.py
"""
PyTorch DDP — with AMP + gradient accumulation + fixed loss.
Uses no_sync() context manager to skip AllReduce during accumulation steps.
"""
import torch, torch.nn as nn, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from contextlib import nullcontext
import time, json, sys, os
from model_common import build_model, make_dataset

def main():
    dist.init_process_group(backend='nccl')
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(f'cuda:{rank}')
    torch.cuda.set_device(device)

    config = json.loads(sys.argv[1])
    use_amp = config.get('use_amp', False)
    grad_accum_steps = config.get('grad_accum_steps', 1)
    amp_dtype = torch.bfloat16 if (use_amp and torch.cuda.is_bf16_supported()) else torch.float16

    torch.manual_seed(42)
    model, n_params = build_model(config['model_size'])
    model = model.to(device)
    model = DDP(model, device_ids=[rank])
    optimizer = torch.optim.SGD(model.parameters(), lr=config['lr'])
    loss_fn = nn.CrossEntropyLoss()
    scaler = torch.amp.GradScaler('cuda', enabled=(use_amp and amp_dtype == torch.float16))

    X_all, y_all = make_dataset(config['n_samples'])
    X_all, y_all = X_all.to(device), y_all.to(device)
    bs = config['batch_size']

    # Warmup
    for _ in range(3):
        idx = torch.randint(0, len(X_all), (bs // world_size,))
        with torch.amp.autocast('cuda', enabled=use_amp, dtype=amp_dtype):
            out = model(X_all[idx])
            warmup_loss = loss_fn(out, y_all[idx])
        warmup_loss.backward()
        optimizer.zero_grad()
    torch.cuda.synchronize()
    dist.barrier()

    epoch_times, losses = [], []
    for epoch in range(config['n_epochs']):
        torch.cuda.synchronize()
        dist.barrier()
        t0 = time.perf_counter()
        epoch_loss, nb = 0.0, 0
        optimizer.zero_grad()

        batch_starts = list(range(0, len(X_all), bs))

        for step_idx, start in enumerate(batch_starts):
            xb = X_all[start:start+bs]
            yb = y_all[start:start+bs]

            chunk = len(xb) // world_size
            s = rank * chunk
            e = s + chunk if rank < world_size - 1 else len(xb)
            x_local, y_local = xb[s:e], yb[s:e]

            is_sync_step = ((step_idx + 1) % grad_accum_steps == 0) or ((step_idx + 1) == len(batch_starts))

            # DDP's no_sync() skips AllReduce during accumulation steps
            sync_context = nullcontext() if is_sync_step else model.no_sync()

            with sync_context:
                with torch.amp.autocast('cuda', enabled=use_amp, dtype=amp_dtype):
                    out = model(x_local)
                    loss = loss_fn(out, y_local) / grad_accum_steps
                scaler.scale(loss).backward()

            # --- Fixed global loss logging ---
            loss_val = loss.detach() * grad_accum_steps
            dist.all_reduce(loss_val, op=dist.ReduceOp.SUM)
            epoch_loss += (loss_val.item() / world_size)
            nb += 1

            if is_sync_step:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

        torch.cuda.synchronize()
        epoch_times.append(time.perf_counter() - t0)
        losses.append(epoch_loss / max(nb, 1))

    # Accuracy
    model.eval()
    with torch.no_grad():
        correct = 0
        for start in range(0, len(X_all), bs):
            out = model(X_all[start:start+bs])
            correct += (out.argmax(1) == y_all[start:start+bs]).sum().item()
    acc = correct / len(X_all)

    if rank == 0:
        print('RESULTS_JSON:' + json.dumps({
            'mode': 'ddp_builtin', 'n_gpus': world_size, 'n_params': n_params,
            'model_size': config['model_size'],
            'use_amp': use_amp, 'grad_accum_steps': grad_accum_steps,
            'epoch_times': epoch_times, 'losses': losses,
            'final_accuracy': acc,
            'avg_epoch_time': sum(epoch_times) / len(epoch_times),
            'total_time': sum(epoch_times),
        }))

    dist.destroy_process_group()

if __name__ == '__main__':
    main()

---
## 3. Benchmark Runner

Trimmed matrix (8 configurations: 32 runs on a 2-GPU machine, 56 on 4 GPUs, 80 on 8 GPUs) that still covers the full range:
- **3 model sizes**: base (64M), large (250M), xlarge (730M), sized to stress H100s
- **GPU counts**: 1 (single-GPU baseline), then 2, 4, and 8 when available
- **Precision**: FP32 and AMP for all configs
- **Gradient accumulation**: GA=1 for all, GA=4 added for the `large` model only (to show the effect without doubling runtime)
- **65K samples, global batch size 4096**: enough data to keep GPUs busy; the batch is split across GPUs, so per-GPU micro-batches are 2048 / 1024 / 512 at 2 / 4 / 8 GPUs (strong scaling)
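The run counts follow from the matrix: 8 (model, precision, GA) configurations, each with one single-GPU baseline plus three data-parallel modes per usable GPU count. A standalone recount that mirrors the loop in `run_benchmarks.py` (helper names hypothetical):

```python
MODEL_SIZES = ['base', 'large', 'xlarge']
AMP_MODES = [False, True]
DP_SCRIPTS_PER_GPU_COUNT = 3  # naive, interleaved, PyTorch DDP


def build_test_matrix():
    """(model_size, use_amp, grad_accum_steps) tuples, same rule as run_benchmarks.py."""
    configs = []
    for msz in MODEL_SIZES:
        for use_amp in AMP_MODES:
            configs.append((msz, use_amp, 1))
            if msz == 'large':  # GA=4 only for the 'large' model
                configs.append((msz, use_amp, 4))
    return configs


def total_runs(num_gpus: int) -> int:
    """One single-GPU baseline plus 3 DP modes per usable GPU count, per config."""
    gpu_counts = [g for g in (2, 4, 8) if g <= num_gpus]
    return len(build_test_matrix()) * (1 + len(gpu_counts) * DP_SCRIPTS_PER_GPU_COUNT)


if __name__ == "__main__":
    for ng in (1, 2, 4, 8):
        print(f"{ng} GPU(s): {total_runs(ng)} runs")
```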

In [None]:
%%writefile run_benchmarks.py
import subprocess, json, sys, os, torch

NUM_GPUS = torch.cuda.device_count()
GPU_NAME = torch.cuda.get_device_properties(0).name if torch.cuda.is_available() else 'unknown'
print(f"Detected {NUM_GPUS} GPUs ({GPU_NAME})")

def run(cmd):
    print(f"  CMD: {' '.join(cmd[:6])}...")
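    try:
        r = subprocess.run(cmd, capture_output=True, text=True, timeout=900)
    except subprocess.TimeoutExpired:
        # A hung run (e.g. a stuck NCCL collective) should not abort the whole sweep
        print("  ERROR: timed out after 900 s")
        return None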
    if r.returncode != 0:
        print(f"  ERROR (rc={r.returncode}): {r.stderr[:400]}")
        return None
    for line in r.stdout.split('\n'):
        if line.startswith('RESULTS_JSON:'):
            result = json.loads(line[len('RESULTS_JSON:'):])
            result['gpu_name'] = GPU_NAME
            return result
    print(f"  No RESULTS_JSON. stdout: {r.stdout[:200]}")
    return None


base = {
    'n_samples': 65536,     # 2x larger dataset
    'batch_size': 4096,     # bigger batches to saturate H100 compute
    'lr': 0.01,
    'n_epochs': 5,
}

# --- What to benchmark ---
# Models sized for H100: base (64M), large (250M), xlarge (730M)
model_sizes = ['base', 'large', 'xlarge']
gpu_counts = [g for g in [2, 4, 8] if g <= NUM_GPUS]
amp_modes = [False, True]

# Build the test matrix: 8 configs (32 / 56 / 80 runs with 2 / 4 / 8 GPUs)
# All models: FP32 + AMP, GA=1
# Large model only: also GA=4 (to show gradient accumulation effect)
test_configs = []
for msz in model_sizes:
    for use_amp in amp_modes:
        test_configs.append((msz, use_amp, 1))
        # Add GA=4 only for 'large' model
        if msz == 'large':
            test_configs.append((msz, use_amp, 4))

all_results = []
port_counter = 29500
total_runs = len(test_configs) * (1 + len(gpu_counts) * 3)
run_num = 0

print(f"\nBenchmark plan: {len(test_configs)} configs × (1 single + {len(gpu_counts)} gpu_counts × 3 modes)")
print(f"Total estimated runs: ~{total_runs}\n")

for msz, use_amp, ga in test_configs:
    cfg = {**base, 'model_size': msz, 'use_amp': use_amp, 'grad_accum_steps': ga}
    cfg_str = json.dumps(cfg)
    amp_tag = 'AMP' if use_amp else 'FP32'
    ga_tag = f'GA={ga}'

    print(f"\n{'='*60}")
    print(f"Model: {msz} | {amp_tag} | {ga_tag}")
    print(f"{'='*60}")

    # --- Single GPU baseline ---
    run_num += 1
    print(f"\n  [{run_num}/{total_runs}] Single GPU baseline...")
    r = run(['python', 'train_single_gpu.py', cfg_str])
    if r:
        all_results.append(r)
        print(f"    avg_epoch={r['avg_epoch_time']:.4f}s")

    for ng in gpu_counts:
        for mode_idx, (script, mode_name) in enumerate([
            ('train_dp_naive.py', 'DP Naive'),
            ('train_dp_interleaved.py', 'DP Interleaved'),
            ('train_ddp_builtin.py', 'PyTorch DDP'),
        ]):
            port_counter += 1
            run_num += 1
            tr = ['torchrun', '--nproc_per_node', str(ng),
                  '--master_port', str(port_counter)]

            print(f"\n  [{run_num}/{total_runs}] {mode_name}, {ng} GPUs...")
            r = run(tr + [script, cfg_str])
            if r:
                all_results.append(r)
                comm_str = f", comm={r.get('avg_comm_time',0):.4f}s" if 'avg_comm_time' in r else ""
                print(f"    avg_epoch={r['avg_epoch_time']:.4f}s{comm_str}")

# Save with GPU tag in filename
gpu_tag = GPU_NAME.replace(' ', '_').replace('/', '_')
filename = f'benchmark_results_{gpu_tag}.json'
with open(filename, 'w') as f:
    json.dump(all_results, f, indent=2)
with open('benchmark_results.json', 'w') as f:
    json.dump(all_results, f, indent=2)

print(f"\n\nDone! {len(all_results)} results saved to {filename}")

In [None]:
!python run_benchmarks.py

---
## 4. Profiler: GPU Timeline Traces

This cell runs a few training steps with `torch.profiler` and exports Chrome traces.
You can view them at `chrome://tracing` to see exactly when compute and NCCL kernels overlap.

We profile 3 configs:
1. **Naive DP** — you should see NCCL kernels AFTER compute kernels (no overlap)
2. **Interleaved DP** — NCCL and compute kernels should overlap
3. **PyTorch DDP** — similar overlap to interleaved, but with bucketed NCCL calls

In [None]:
%%writefile run_profiler.py
"""
Profile one training step for each DP mode.
Exports Chrome traces to ./profiles/
"""
import torch, torch.nn as nn, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.profiler import profile, ProfilerActivity, schedule
import json, sys, os
from model_common import build_model, make_dataset

def main():
    dist.init_process_group(backend='nccl')
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(f'cuda:{rank}')
    torch.cuda.set_device(device)

    config = json.loads(sys.argv[1])
    mode = config['mode']
    use_amp = config.get('use_amp', False)
    amp_dtype = torch.bfloat16 if (use_amp and torch.cuda.is_bf16_supported()) else torch.float16

    os.makedirs('profiles', exist_ok=True)

    torch.manual_seed(42)
    model, _ = build_model(config.get('model_size', 'large'))
    model = model.to(device)

    # Setup based on mode
    handles_list = []
    if mode == 'ddp_builtin':
        model = DDP(model, device_ids=[rank])
    elif mode == 'dp_interleaved':
        # Register interleaved hooks
        for param in model.parameters():
            def make_hook(p):
                def hook(p_inner):
                    h = dist.all_reduce(p_inner.grad, op=dist.ReduceOp.SUM, async_op=True)
                    handles_list.append((h, p_inner))
                return hook
            param.register_post_accumulate_grad_hook(make_hook(param))

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.CrossEntropyLoss()

    X_all, y_all = make_dataset(8192)
    X_all, y_all = X_all.to(device), y_all.to(device)
    bs = 1024

    # Warmup
    for _ in range(5):
        chunk = bs // world_size
        idx = torch.randint(0, len(X_all), (chunk,))
        with torch.amp.autocast('cuda', enabled=use_amp, dtype=amp_dtype):
            out = model(X_all[idx])
            warmup_loss = loss_fn(out, y_all[idx])
        warmup_loss.backward()
        if mode == 'dp_interleaved':
            for h, p in handles_list:
                h.wait()
                p.grad /= world_size
            handles_list.clear()
        optimizer.zero_grad()
    torch.cuda.synchronize()
    dist.barrier()

    amp_tag = 'amp' if use_amp else 'fp32'
    trace_name = f'profiles/trace_{mode}_{world_size}gpu_{amp_tag}'

    # Profile: 2 warmup steps, 3 active steps
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        schedule=schedule(wait=0, warmup=2, active=3, repeat=1),
        on_trace_ready=lambda p: p.export_chrome_trace(f'{trace_name}_rank{rank}.json'),
        record_shapes=True,
        with_stack=True,
    ) as prof:
        for step in range(5):
            xb = X_all[:bs]
            yb = y_all[:bs]
            chunk = len(xb) // world_size
            x_local = xb[rank*chunk:(rank+1)*chunk]
            y_local = yb[rank*chunk:(rank+1)*chunk]

            optimizer.zero_grad()

            with torch.amp.autocast('cuda', enabled=use_amp, dtype=amp_dtype):
                out = model(x_local)
                loss = loss_fn(out, y_local)

            loss.backward()

            if mode == 'dp_naive':
                for param in (model.module.parameters() if hasattr(model, 'module') else model.parameters()):
                    if param.grad is not None:
                        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
                        param.grad /= world_size
            elif mode == 'dp_interleaved':
                for h, p in handles_list:
                    h.wait()
                    p.grad /= world_size
                handles_list.clear()
            # DDP handles AllReduce automatically

            optimizer.step()
            torch.cuda.synchronize()
            prof.step()

    if rank == 0:
        print(f'TRACE_SAVED:{trace_name}_rank0.json')

    dist.destroy_process_group()

if __name__ == '__main__':
    main()

In [None]:
import subprocess, json, torch

NUM_GPUS = torch.cuda.device_count()
ng = min(NUM_GPUS, 4)  # Profile with up to 4 GPUs

print(f"Profiling with {ng} GPUs...")
print("Traces will be saved to ./profiles/\n")

port = 29600
for mode in ['dp_naive', 'dp_interleaved', 'ddp_builtin']:
    for use_amp in [False, True]:
        amp_tag = 'AMP' if use_amp else 'FP32'
        print(f"  Profiling {mode} ({amp_tag})...")
        cfg = json.dumps({'mode': mode, 'model_size': 'large', 'use_amp': use_amp})
        port += 1
        cmd = ['torchrun', '--nproc_per_node', str(ng),
               '--master_port', str(port), 'run_profiler.py', cfg]
        r = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
        if r.returncode == 0:
            for line in r.stdout.split('\n'):
                if line.startswith('TRACE_SAVED:'):
                    print(f"    Saved: {line.split(':',1)[1]}")
        else:
            print(f"    ERROR: {r.stderr[:200]}")

print("\nProfiling complete!")
print("To view traces: open chrome://tracing and load the JSON files from ./profiles/")
print("\nWhat to look for:")
print("  - dp_naive:       NCCL kernels appear AFTER compute (sequential, no overlap)")
print("  - dp_interleaved: NCCL kernels OVERLAP with backward compute")
print("  - ddp_builtin:    Similar overlap but fewer, larger NCCL calls (bucketing)")
print("  - AMP vs FP32:    shorter compute kernels; NCCL transfer sizes unchanged (grads stay FP32 under autocast)")

In [None]:
# List generated trace files
import glob, os
traces = sorted(glob.glob('profiles/*.json'))
print(f"Generated {len(traces)} trace files:\n")
for t in traces:
    size_mb = os.path.getsize(t) / 1024 / 1024
    print(f"  {t:60s} ({size_mb:.1f} MB)")

print(f"\nDownload these and open at chrome://tracing to see GPU timelines.")
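Eyeballing traces works, but overlap can also be quantified. The sketch below (hypothetical helper names) reads a Chrome trace, treats GPU kernels whose name contains `nccl` as communication and all other kernels as compute, and reports how much NCCL time is covered by concurrent compute. It assumes the layout recent PyTorch/Kineto exports use: a top-level `traceEvents` list, kernels as complete events (`"ph": "X"`) with `"cat": "kernel"`, and `ts`/`dur` in microseconds. Check one of your files before relying on the numbers; the trace spans all 3 active steps, including the optimizer step.

```python
import json
from typing import Dict, List, Tuple

Interval = Tuple[float, float]


def merge_intervals(intervals: List[Interval]) -> List[Interval]:
    """Union of [start, end) intervals as a sorted, non-overlapping list."""
    merged: List[Interval] = []
    for start, end in sorted(intervals):
        if merged and start <= merged[-1][1]:
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            merged.append((start, end))
    return merged


def total_length(intervals: List[Interval]) -> float:
    return sum(end - start for start, end in intervals)


def intersection_length(a: List[Interval], b: List[Interval]) -> float:
    """Length of the intersection of two merged interval lists (two-pointer sweep)."""
    i = j = 0
    overlap = 0.0
    while i < len(a) and j < len(b):
        lo = max(a[i][0], b[j][0])
        hi = min(a[i][1], b[j][1])
        if hi > lo:
            overlap += hi - lo
        # Advance whichever interval ends first
        if a[i][1] < b[j][1]:
            i += 1
        else:
            j += 1
    return overlap


def nccl_overlap(events: List[dict]) -> Dict[str, float]:
    """NCCL vs compute kernel time and how much NCCL time overlaps compute (µs)."""
    nccl, compute = [], []
    for ev in events:
        if ev.get("ph") != "X" or ev.get("cat") != "kernel":
            continue  # skip CPU ops, runtime calls, memcpys, ...
        span = (ev["ts"], ev["ts"] + ev["dur"])
        if "nccl" in ev["name"].lower():
            nccl.append(span)
        else:
            compute.append(span)
    nccl, compute = merge_intervals(nccl), merge_intervals(compute)
    nccl_us = total_length(nccl)
    overlap_us = intersection_length(nccl, compute)
    return {
        "nccl_us": nccl_us,
        "compute_us": total_length(compute),
        "overlap_us": overlap_us,
        "overlap_frac": overlap_us / nccl_us if nccl_us else 0.0,
    }


def analyze_trace(path: str) -> Dict[str, float]:
    with open(path) as f:
        return nccl_overlap(json.load(f)["traceEvents"])
```

For example: `for t in traces: print(t, analyze_trace(t))`. One would expect `overlap_frac` near zero for `dp_naive` and noticeably higher for `dp_interleaved` and `ddp_builtin`.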

---
## 5. Visualize Results

In [None]:
import json, numpy as np, matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['figure.dpi'] = 130
matplotlib.rcParams['font.size'] = 10

with open('benchmark_results.json') as f:
    results = json.load(f)

N_SAMPLES = 65536  # matches our config
model_order = ['small', 'medium', 'large', 'xlarge']

print(f"Loaded {len(results)} results")
print(f"GPU: {results[0].get('gpu_name', 'unknown')}")
print(f"\n{'Model':<8} {'Mode':<20} {'GPUs':>4} {'AMP':>4} {'GA':>3} {'Params':>10} {'Epoch':>9} {'Acc':>7}")
print('─' * 75)
for r in results:
    amp_str = 'Yes' if r.get('use_amp') else 'No'
    ga = r.get('grad_accum_steps', 1)
    print(f"{r['model_size']:<8} {r['mode']:<20} {r['n_gpus']:>4} {amp_str:>4} {ga:>3} "
          f"{r['n_params']/1e6:>9.1f}M {r['avg_epoch_time']:>8.4f}s {r['final_accuracy']:>7.4f}")

In [None]:
# ========================================================
# CHART 1: FP32 vs AMP Epoch Time — Grouped by Model Size
# Shows the speedup from mixed precision
# ========================================================

# Filter to grad_accum=1, DDP mode for clarity
ddp_results = [r for r in results if r['mode'] == 'ddp_builtin' and r.get('grad_accum_steps', 1) == 1]

model_order = ['small', 'medium', 'large', 'xlarge']
model_sizes_present = sorted(set(r['model_size'] for r in ddp_results),
                              key=lambda x: model_order.index(x) if x in model_order else 99)
gpu_counts_present = sorted(set(r['n_gpus'] for r in ddp_results))

fig, axes = plt.subplots(1, len(model_sizes_present), figsize=(5*len(model_sizes_present), 5), sharey=False)
if len(model_sizes_present) == 1:
    axes = [axes]

for idx, ms in enumerate(model_sizes_present):
    ax = axes[idx]
    x_labels, fp32_vals, amp_vals = [], [], []

    for ng in gpu_counts_present:
        fp32 = [r for r in ddp_results if r['model_size'] == ms and r['n_gpus'] == ng and not r.get('use_amp')]
        amp = [r for r in ddp_results if r['model_size'] == ms and r['n_gpus'] == ng and r.get('use_amp')]
        if fp32 and amp:
            x_labels.append(f'{ng} GPU')
            fp32_vals.append(fp32[0]['avg_epoch_time'])
            amp_vals.append(amp[0]['avg_epoch_time'])

    # Add single GPU
    sg_fp32 = [r for r in results if r['model_size'] == ms and r['mode'] == 'single_gpu'
               and not r.get('use_amp') and r.get('grad_accum_steps', 1) == 1]
    sg_amp = [r for r in results if r['model_size'] == ms and r['mode'] == 'single_gpu'
              and r.get('use_amp') and r.get('grad_accum_steps', 1) == 1]
    if sg_fp32 and sg_amp:
        x_labels.insert(0, '1 GPU')
        fp32_vals.insert(0, sg_fp32[0]['avg_epoch_time'])
        amp_vals.insert(0, sg_amp[0]['avg_epoch_time'])

    x = np.arange(len(x_labels))
    w = 0.35
    ax.bar(x - w/2, fp32_vals, w, label='FP32', color='#FF7043', edgecolor='white')
    ax.bar(x + w/2, amp_vals, w, label='AMP (FP16/BF16)', color='#42A5F5', edgecolor='white')

    # Annotate speedup
    for i in range(len(x_labels)):
        if fp32_vals[i] > 0 and amp_vals[i] > 0:
            speedup = fp32_vals[i] / amp_vals[i]
            ax.text(x[i], max(fp32_vals[i], amp_vals[i]) * 1.05,
                    f'{speedup:.2f}×', ha='center', fontsize=8, fontweight='bold', color='#1565C0')

    ax.set_xticks(x)
    ax.set_xticklabels(x_labels)
    ax.set_ylabel('Epoch Time (s)')
    n_params = [r for r in results if r['model_size'] == ms][0]['n_params']
    ax.set_title(f'{ms.capitalize()} ({n_params/1e6:.1f}M)', fontweight='bold')
    ax.legend(fontsize=8)
    ax.grid(axis='y', alpha=0.3)

fig.suptitle('FP32 vs Mixed Precision (AMP) — DDP Epoch Time', fontsize=13, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig('chart_amp_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# ========================================================
# CHART 2: Scaling Efficiency — 1 to 8 GPUs
# FP32 vs AMP side by side
# ========================================================

fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)

for amp_idx, (use_amp, amp_label) in enumerate([(False, 'FP32'), (True, 'AMP')]):
    ax = axes[amp_idx]
    ga1_results = [r for r in results if r.get('use_amp') == use_amp and r.get('grad_accum_steps', 1) == 1]

    msizes = sorted(set(r['model_size'] for r in ga1_results),
                     key=lambda x: model_order.index(x) if x in model_order else 99)

    colors_map = {'small': '#FF7043', 'medium': '#42A5F5', 'large': '#66BB6A', 'xlarge': '#AB47BC'}

    for ms in msizes:
        baseline = [r for r in ga1_results if r['model_size'] == ms and r['mode'] == 'single_gpu']
        if not baseline:
            continue
        base_time = baseline[0]['avg_epoch_time']

        gpus, speedups = [1], [1.0]
        for ng in sorted(set(r['n_gpus'] for r in ga1_results if r['mode'] == 'ddp_builtin')):
            m = [r for r in ga1_results if r['model_size'] == ms and r['mode'] == 'ddp_builtin' and r['n_gpus'] == ng]
            if m:
                gpus.append(ng)
                speedups.append(base_time / m[0]['avg_epoch_time'])

        n_params = baseline[0]['n_params']
        ax.plot(gpus, speedups, '-o', color=colors_map.get(ms, 'gray'),
                label=f'{ms} ({n_params/1e6:.0f}M)', linewidth=2, markersize=7)

    max_g = max(r['n_gpus'] for r in ga1_results) if ga1_results else 8
    ax.plot([1, max_g], [1, max_g], ':k', alpha=0.3, label='Ideal linear')
    ax.set_xlabel('Number of GPUs')
    ax.set_ylabel('Speedup vs 1 GPU')
    ax.set_title(f'{amp_label} — DDP Scaling', fontweight='bold')
    ax.set_xticks(sorted({1} | {r['n_gpus'] for r in ga1_results}))  # only GPU counts actually present
    ax.legend(fontsize=8)
    ax.grid(alpha=0.3)
    ax.set_ylim(bottom=0)

fig.suptitle('GPU Scaling: FP32 vs AMP — PyTorch DDP', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig('chart_scaling_fp32_vs_amp.png', dpi=150, bbox_inches='tight')
plt.show()
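
Chart 2 plots speedup; the scaling efficiency named in its header is speedup divided by GPU count. A minimal sketch that computes both from the loaded `results`, assuming the same keys and filters the chart uses (`ddp_builtin` runs at GA=1 compared with the matching `single_gpu` run). `ddp_scaling_efficiency` is a hypothetical helper, not part of the benchmark:

```python
def ddp_scaling_efficiency(results, model_size, use_amp=False):
    """Return {n_gpus: (speedup, efficiency)} for ddp_builtin runs at GA=1 vs the single_gpu run."""
    def pick(mode):
        return [r for r in results
                if r['model_size'] == model_size and r['mode'] == mode
                and bool(r.get('use_amp')) == use_amp
                and r.get('grad_accum_steps', 1) == 1]

    baseline = pick('single_gpu')
    if not baseline:
        return {}  # no 1-GPU reference run for this config
    base_time = baseline[0]['avg_epoch_time']
    table = {1: (1.0, 1.0)}
    for r in sorted(pick('ddp_builtin'), key=lambda r: r['n_gpus']):
        speedup = base_time / r['avg_epoch_time']
        table[r['n_gpus']] = (speedup, speedup / r['n_gpus'])  # efficiency = speedup / N
    return table
```

For example, `ddp_scaling_efficiency(results, 'large', use_amp=True)` maps each GPU count to a `(speedup, efficiency)` pair. An efficiency near 1.0 means near-linear scaling; the GPU count where it falls off is roughly where communication starts to dominate.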

In [None]:
# ========================================================
# CHART 3: Gradient Accumulation Effect
# GA=1 vs GA=4 — fewer AllReduce calls per epoch
# ========================================================

# Use DDP + AMP for this comparison
ga_results = [r for r in results if r['mode'] == 'ddp_builtin' and r.get('use_amp', False)]

if len(set(r.get('grad_accum_steps', 1) for r in ga_results)) > 1:
    msizes = sorted(set(r['model_size'] for r in ga_results),
                     key=lambda x: model_order.index(x) if x in model_order else 99)

    fig, axes = plt.subplots(1, len(msizes), figsize=(5*len(msizes), 5), sharey=False)
    if len(msizes) == 1:
        axes = [axes]

    for idx, ms in enumerate(msizes):
        ax = axes[idx]
        for ga, color, style in [(1, '#FF7043', '-o'), (4, '#42A5F5', '-s')]:
            subset = [r for r in ga_results if r['model_size'] == ms and r.get('grad_accum_steps', 1) == ga]
            if not subset:
                continue
            gpus = sorted(set(r['n_gpus'] for r in subset))
            times = [next(r['avg_epoch_time'] for r in subset if r['n_gpus'] == ng) for ng in gpus]
            ax.plot(gpus, times, style, color=color, label=f'GA={ga}', linewidth=2, markersize=7)

        # Also add single GPU
        for ga, color, style in [(1, '#FF7043', '-o'), (4, '#42A5F5', '-s')]:
            sg = [r for r in results if r['model_size'] == ms and r['mode'] == 'single_gpu'
                  and r.get('use_amp', False) and r.get('grad_accum_steps', 1) == ga]
            if sg:
                ax.plot(1, sg[0]['avg_epoch_time'], 'D', color=color, markersize=8)

        ax.set_xlabel('Number of GPUs')
        ax.set_ylabel('Epoch Time (s)')
        n_params = [r for r in results if r['model_size'] == ms][0]['n_params']
        ax.set_title(f'{ms.capitalize()} ({n_params/1e6:.1f}M)', fontweight='bold')
        ax.legend(fontsize=8)
        ax.grid(alpha=0.3)
        ax.set_xticks(sorted({1} | {r['n_gpus'] for r in ga_results}))  # only GPU counts actually present

    fig.suptitle('Gradient Accumulation: GA=1 vs GA=4 (DDP + AMP)\n'
                 'GA=4 means 4× fewer AllReduce calls per epoch',
                 fontsize=13, fontweight='bold', y=1.05)
    plt.tight_layout()
    plt.savefig('chart_grad_accum.png', dpi=150, bbox_inches='tight')
    plt.show()
else:
    print("Only one grad_accum setting found — skipping GA comparison chart.")

In [None]:
# ========================================================
# CHART 4: Communication Overhead — Naive DP
# FP32 vs AMP: comm time drops only if gradients are all-reduced in FP16/BF16;
# under plain autocast the gradients (the AllReduce payload) stay FP32
# ========================================================

naive_results = [r for r in results if r['mode'] == 'dp_naive' and 'comm_times' in r
                 and r.get('grad_accum_steps', 1) == 1]

if naive_results:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    for amp_idx, (use_amp, amp_label) in enumerate([(False, 'FP32'), (True, 'AMP')]):
        ax = axes[amp_idx]
        subset = [r for r in naive_results if r.get('use_amp') == use_amp]
        if not subset:
            ax.text(0.5, 0.5, f'No {amp_label} naive results', ha='center', va='center', transform=ax.transAxes)
            continue

        labels, comp, comm = [], [], []
        for r in sorted(subset, key=lambda x: (model_order.index(x['model_size'])
                         if x['model_size'] in model_order else 99, x['n_gpus'])):
            labels.append(f"{r['model_size']}\n{r['n_gpus']}G")
            avg_comm = r.get('avg_comm_time', 0)
            avg_total = r['avg_epoch_time']
            comp.append(avg_total - avg_comm)
            comm.append(avg_comm)

        x = np.arange(len(labels))
        ax.bar(x, comp, 0.6, label='Computation', color='#42A5F5')
        ax.bar(x, comm, 0.6, bottom=comp, label='Communication', color='#FF7043')

        for i, (c, cm) in enumerate(zip(comp, comm)):
            total = c + cm
            pct = cm / total * 100 if total > 0 else 0
            ax.text(i, total + 0.01, f'{pct:.0f}%', ha='center', fontsize=8, fontweight='bold')

        ax.set_xticks(x)
        ax.set_xticklabels(labels)
        ax.set_ylabel('Time (s)')
        ax.set_title(f'Naive DP — {amp_label}', fontweight='bold')
        ax.legend(fontsize=8)
        ax.grid(axis='y', alpha=0.3)

    fig.suptitle('Communication Overhead: FP32 vs AMP\n'
                 'Gradient payload halves only if gradients are all-reduced in FP16/BF16',
                 fontsize=12, fontweight='bold', y=1.05)
    plt.tight_layout()
    plt.savefig('chart_comm_fp32_vs_amp.png', dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
# ========================================================
# CHART 5: Loss Convergence — Fixed Global Average
# All modes should now overlap closely within the same config (up to floating-point nondeterminism)
# ========================================================

# Pick the largest model at GA=1; FP32 and AMP are plotted side by side
all_msizes = sorted(set(r['model_size'] for r in results),
                     key=lambda x: model_order.index(x) if x in model_order else 99)
target_model = all_msizes[-1]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
style_map = {
    'single_gpu': ('-', '#37474F', 3),
    'dp_naive': ('--', '#FF5722', 2),
    'dp_interleaved': ('-', '#1E88E5', 2),
    'ddp_builtin': ('-.', '#2E7D32', 2),
}

for amp_idx, (use_amp, amp_label) in enumerate([(False, 'FP32'), (True, 'AMP')]):
    ax = axes[amp_idx]
    subset = [r for r in results if r['model_size'] == target_model
              and r.get('use_amp') == use_amp and r.get('grad_accum_steps', 1) == 1
              and 'losses' in r]

    for r in subset:
        ls, c, lw = style_map.get(r['mode'], ('-', 'gray', 1))
        lbl = f"{r['mode']} ({r['n_gpus']}G)"
        ax.plot(range(1, len(r['losses'])+1), r['losses'], linestyle=ls, color=c,
                linewidth=lw, marker='o', markersize=4, label=lbl)

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Training Loss (global average)')
    ax.set_title(f'{amp_label} — {target_model.capitalize()} Model', fontweight='bold')
    ax.legend(fontsize=7)
    ax.grid(alpha=0.3)

fig.suptitle('Loss Convergence — Fixed Global Average Logging\n'
             'All modes within same GPU count should now overlap (same math, same loss)',
             fontsize=12, fontweight='bold', y=1.05)
plt.tight_layout()
plt.savefig('chart_loss_convergence_fixed.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# ========================================================
# CHART 6: Naive vs Interleaved vs DDP — All GPU Counts
# ========================================================

# Use AMP + GA=1 for this
compare_results = [r for r in results if r.get('use_amp', False) and r.get('grad_accum_steps', 1) == 1]

msizes = sorted(set(r['model_size'] for r in compare_results),
                 key=lambda x: model_order.index(x) if x in model_order else 99)

fig, ax = plt.subplots(figsize=(max(14, len(msizes)*5), 6))
pair_labels, naive_t, interl_t, ddp_t = [], [], [], []

for ms in msizes:
    for ng in sorted(set(r['n_gpus'] for r in compare_results if r['n_gpus'] > 1)):
        n = [r for r in compare_results if r['model_size'] == ms and r['mode'] == 'dp_naive' and r['n_gpus'] == ng]
        il = [r for r in compare_results if r['model_size'] == ms and r['mode'] == 'dp_interleaved' and r['n_gpus'] == ng]
        dd = [r for r in compare_results if r['model_size'] == ms and r['mode'] == 'ddp_builtin' and r['n_gpus'] == ng]
        if n and il and dd:
            pair_labels.append(f"{ms}\n{ng}G")
            naive_t.append(n[0]['avg_epoch_time'])
            interl_t.append(il[0]['avg_epoch_time'])
            ddp_t.append(dd[0]['avg_epoch_time'])

x = np.arange(len(pair_labels))
w = 0.25
ax.bar(x - w, naive_t, w, label='Naive DP', color='#FF7043', edgecolor='white')
ax.bar(x, interl_t, w, label='Interleaved DP', color='#42A5F5', edgecolor='white')
ax.bar(x + w, ddp_t, w, label='PyTorch DDP', color='#66BB6A', edgecolor='white')

for i in range(len(pair_labels)):
    if naive_t[i] > 0 and interl_t[i] > 0:
        saving = (1 - interl_t[i] / naive_t[i]) * 100
        y = max(naive_t[i], interl_t[i], ddp_t[i]) + 0.01
        if saving > 0:
            ax.annotate(f'{saving:.0f}% less time', xy=(i, y), fontsize=7,
                        ha='center', fontweight='bold', color='#1565C0')

ax.set_xticks(x)
ax.set_xticklabels(pair_labels)
ax.set_ylabel('Epoch Time (s)')
ax.set_title('Naive vs Interleaved vs DDP (AMP) — All GPU Counts', fontweight='bold')
ax.legend()
ax.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.savefig('chart_mode_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# ========================================================
# CHART 7: Throughput (samples/sec) — The Real Metric
# ========================================================

# Use AMP + GA=1 for throughput comparison
tp_results = [r for r in results if r.get('use_amp', False) and r.get('grad_accum_steps', 1) == 1]

# Collect unique combos
combos = []
seen = set()
for r in tp_results:
    key = (r['mode'], r['n_gpus'])
    if key not in seen:
        seen.add(key)
        combos.append(key)

colors = {
    ('single_gpu', 1): '#37474F',
    ('dp_naive', 2): '#FF7043', ('dp_naive', 4): '#FF5722', ('dp_naive', 8): '#D84315',
    ('dp_interleaved', 2): '#42A5F5', ('dp_interleaved', 4): '#1E88E5', ('dp_interleaved', 8): '#1565C0',
    ('ddp_builtin', 2): '#66BB6A', ('ddp_builtin', 4): '#43A047', ('ddp_builtin', 8): '#2E7D32',
}

def combo_label(mode, ng):
    names = {'single_gpu': '1G', 'dp_naive': f'Naive {ng}G',
             'dp_interleaved': f'Interl {ng}G', 'ddp_builtin': f'DDP {ng}G'}
    return names.get(mode, f'{mode} ({ng}G)')

msizes = sorted(set(r['model_size'] for r in tp_results),
                 key=lambda x: model_order.index(x) if x in model_order else 99)

fig, ax = plt.subplots(figsize=(max(14, len(msizes)*4), 6))
tp_data = {}
for r in tp_results:
    key = (r['model_size'], r['mode'], r['n_gpus'])
    tp_data[key] = N_SAMPLES / r['avg_epoch_time']

for ms_idx, ms in enumerate(msizes):
    tp_vals, tp_colors_list = [], []
    for mode, ng in combos:
        key = (ms, mode, ng)
        if key in tp_data:
            tp_vals.append(tp_data[key])
            tp_colors_list.append(colors.get((mode, ng), f'C{len(tp_vals)}'))

    x = np.arange(len(tp_vals))
    offset = ms_idx * (len(combos) + 1.5)  # fixed group width keeps model groups from overlapping
    bars = ax.bar(x + offset, tp_vals, 0.8, color=tp_colors_list, edgecolor='white')
    n_params = [r for r in results if r['model_size'] == ms][0]['n_params']
    ax.text(offset + len(tp_vals)/2 - 0.5, -max(tp_data.values())*0.08,
            f'{ms.capitalize()}\n({n_params/1e6:.0f}M)',
            ha='center', fontweight='bold', fontsize=10)
    for bar, v in zip(bars, tp_vals):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(tp_data.values())*0.01,
                f'{v:.0f}', ha='center', fontsize=6, fontweight='bold', rotation=90)

ax.set_ylabel('Throughput (samples/sec)')
ax.set_title('Training Throughput (AMP) — All GPU Counts', fontweight='bold')
ax.set_xticks([])

from matplotlib.patches import Patch
legend_elements = [Patch(facecolor=colors.get((m,g), 'gray'), label=combo_label(m,g))
                   for m,g in combos if (m,g) in colors]
ax.legend(handles=legend_elements, fontsize=7, ncol=3)
ax.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.savefig('chart_throughput_amp.png', dpi=150, bbox_inches='tight')
plt.show()

---
## 6. Summary Table

In [None]:
print(f"GPU: {results[0].get('gpu_name', 'unknown')}")
print(f"\n{'Model':<8} {'Mode':<20} {'GPUs':>4} {'AMP':>4} {'GA':>3} {'Params':>10} {'Epoch':>9} "
      f"{'Comm':>8} {'Thruput':>10} {'Speedup':>8} {'Acc':>7}")
print('─' * 105)

for ms in sorted(set(r['model_size'] for r in results),
                  key=lambda x: model_order.index(x) if x in model_order else 99):
    # Baseline: single GPU FP32 GA=1
    baseline = [r for r in results if r['model_size'] == ms and r['mode'] == 'single_gpu'
                and not r.get('use_amp') and r.get('grad_accum_steps', 1) == 1]
    base_time = baseline[0]['avg_epoch_time'] if baseline else float('nan')  # no baseline: speedup shows nan

    subset = sorted([r for r in results if r['model_size'] == ms],
                     key=lambda x: (x.get('use_amp', False), x.get('grad_accum_steps', 1),
                                    x['n_gpus'], x['mode']))
    for r in subset:
        throughput = N_SAMPLES / r['avg_epoch_time']
        speedup = base_time / r['avg_epoch_time']
        comm = r.get('avg_comm_time', r.get('comm_time', 0))
        if isinstance(comm, list):
            comm = sum(comm) / len(comm)
        amp_str = 'Yes' if r.get('use_amp') else 'No'
        ga = r.get('grad_accum_steps', 1)
        print(f"{r['model_size']:<8} {r['mode']:<20} {r['n_gpus']:>4} {amp_str:>4} {ga:>3} "
              f"{r['n_params']/1e6:>9.1f}M {r['avg_epoch_time']:>8.4f}s "
              f"{comm:>7.4f}s {throughput:>8.0f}/s {speedup:>7.2f}x {r['final_accuracy']:>7.4f}")
    print()

---
## 7. Understanding the New Features

### Mixed Precision (AMP)

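When you enable `torch.amp.autocast`, eligible forward-pass ops (matmuls, linear layers, convolutions) run in FP16/BF16, while precision-sensitive ops such as softmax, reductions and most losses stay in FP32. The weights themselves remain FP32. What this changes:
- **Compute benefit**: FP16/BF16 matrix multiplications run on Tensor Cores at roughly 2× the peak throughput of TF32 on A100/H100; realized speedups are smaller when a step is memory- or communication-bound
- **Communication benefit (conditional)**: AllReduce moves half the data only when gradients are communicated in FP16/BF16, e.g. via a DDP comm hook such as `default_hooks.fp16_compress_hook`. Under plain autocast, parameter gradients keep the parameters' FP32 dtype, so the AllReduce payload is unchanged
- **Memory benefit**: activations saved for backward are stored in FP16/BF16 → can fit larger batches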

The `GradScaler` prevents underflow by dynamically scaling the loss before backward (only needed for FP16, not BF16).
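
To check which tensors actually change precision, the sketch below runs one step under autocast and reports dtypes. It uses CPU autocast with BF16 only so it runs without a GPU (an assumption for portability; the benchmark itself runs on CUDA), and `autocast_dtypes` is a hypothetical helper:

```python
import torch


def autocast_dtypes():
    """Run one forward/backward under CPU autocast and report the dtypes involved."""
    torch.manual_seed(0)
    model = torch.nn.Linear(64, 8)              # parameters are created in FP32
    x = torch.randn(16, 64)
    target = torch.randint(0, 8, (16,))
    # CPU + BF16 keeps the sketch GPU-free; on CUDA you would pass
    # device_type="cuda" with dtype=torch.float16 (plus a GradScaler) or torch.bfloat16.
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        logits = model(x)                       # linear is autocast to BF16
    loss = torch.nn.functional.cross_entropy(logits.float(), target)
    loss.backward()
    return {
        "param": model.weight.dtype,
        "activation": logits.dtype,
        "grad": model.weight.grad.dtype,        # dtype a default DDP AllReduce would send
    }
```

The activation comes back as BF16 while the weight gradient stays FP32, which is why a default DDP AllReduce under autocast still moves FP32 data.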

### Gradient Accumulation

Instead of syncing gradients every micro-batch:
```
GA=1: [FWD+BWD → AllReduce → Step] [FWD+BWD → AllReduce → Step] ...
GA=4: [FWD+BWD] [FWD+BWD] [FWD+BWD] [FWD+BWD → AllReduce → Step]
```

With GA=4, you do 4× fewer AllReduce calls per epoch. The gradients from 4 micro-batches accumulate locally, then one AllReduce syncs the sum. For that update to match a single large-batch step, each micro-batch loss is typically divided by GA. This:
- **Reduces communication overhead** (fewer AllReduce calls)
- **Simulates a larger batch size** (effective_batch = batch_size × GA × n_gpus, where batch_size is the per-GPU micro-batch)
- **Uses DDP's `no_sync()`** context manager to skip the AllReduce during accumulation steps
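
A minimal sketch of the accumulation loop described above. `accumulate_and_step` is a hypothetical helper, not the benchmark's code: it enters `no_sync()` when the model is wrapped in DDP and falls back to a no-op context for a plain module, so it also runs on a single device. Dividing the loss by `ga_steps` is an assumption about how the benchmark normalizes; it makes the accumulated gradient equal the mean over the effective batch:

```python
import contextlib

import torch


def accumulate_and_step(model, optimizer, micro_batches, ga_steps):
    """Gradient accumulation loop; returns the number of optimizer steps taken."""
    steps = 0
    for i, (x, y) in enumerate(micro_batches):
        sync_now = (i + 1) % ga_steps == 0
        # DDP exposes no_sync(); a plain nn.Module does not, so getattr returns None.
        skip_sync = getattr(model, "no_sync", None)
        if skip_sync is not None and not sync_now:
            ctx = skip_sync()                   # backward accumulates .grad, no AllReduce
        else:
            ctx = contextlib.nullcontext()      # this backward triggers the AllReduce under DDP
        with ctx:
            # Divide by ga_steps so the summed .grad is the mean over the effective batch.
            loss = torch.nn.functional.mse_loss(model(x), y) / ga_steps
            loss.backward()
        if sync_now:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            steps += 1
    return steps
```

With equal-sized micro-batches, four GA=4 micro-steps produce the same SGD update as one step on the concatenated batch. Trailing micro-batches that do not complete a group of `ga_steps` are left in `.grad` without a step; a real loop would step on them or drop them.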

### torch.profiler Traces

The profiler captures actual GPU kernel timelines. When you open the traces in `chrome://tracing`:
- **Compute kernels** (gemm, relu, etc.) appear on the CUDA stream rows
- **NCCL kernels** (AllReduce) appear on separate NCCL stream rows
- **Naive DP**: NCCL kernels start only after the backward compute kernels finish, so the compute stream sits idle while AllReduce runs
- **Interleaved/DDP**: NCCL kernels overlap with the remaining backward compute, so the two streams run concurrently
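
The overlap can also be quantified from the saved JSON instead of judged by eye. This sketch assumes the Chrome-trace layout that `torch.profiler` exports (complete events with `ph == "X"`, `ts`/`dur` in microseconds, GPU kernels tagged with category `kernel`) and that NCCL kernel names contain `nccl`; category names may differ between PyTorch versions, so check them against your own trace. All function names here are hypothetical:

```python
import json


def _merge(intervals):
    """Merge overlapping (start, end) intervals into a sorted, disjoint list."""
    merged = []
    for start, end in sorted(intervals):
        if merged and start <= merged[-1][1]:
            merged[-1][1] = max(merged[-1][1], end)
        else:
            merged.append([start, end])
    return merged


def _total(intervals):
    return sum(end - start for start, end in intervals)


def _overlap(a, b):
    """Total time covered by both merged interval lists (two-pointer sweep)."""
    i = j = 0
    shared = 0.0
    while i < len(a) and j < len(b):
        lo = max(a[i][0], b[j][0])
        hi = min(a[i][1], b[j][1])
        if hi > lo:
            shared += hi - lo
        # Advance whichever interval ends first.
        if a[i][1] < b[j][1]:
            i += 1
        else:
            j += 1
    return shared


def nccl_overlap_summary(events):
    """Split GPU kernels into NCCL vs compute and measure how much NCCL time overlaps compute."""
    nccl, compute = [], []
    for ev in events:
        if ev.get("ph") != "X" or str(ev.get("cat", "")).lower() != "kernel":
            continue  # skip CPU ops, memcpys, markers, etc.
        start = float(ev["ts"])
        span = (start, start + float(ev.get("dur", 0)))
        if "nccl" in str(ev.get("name", "")).lower():
            nccl.append(span)
        else:
            compute.append(span)
    nccl_m, comp_m = _merge(nccl), _merge(compute)
    nccl_us = _total(nccl_m)
    overlap_us = _overlap(nccl_m, comp_m)
    return {
        "compute_us": _total(comp_m),
        "nccl_us": nccl_us,
        "overlap_us": overlap_us,
        "overlap_frac": overlap_us / nccl_us if nccl_us else 0.0,
    }


def summarize_trace_file(path):
    with open(path) as f:
        trace = json.load(f)
    # Chrome traces are either {"traceEvents": [...]} or a bare list of events.
    events = trace.get("traceEvents", []) if isinstance(trace, dict) else trace
    return nccl_overlap_summary(events)
```

Running `summarize_trace_file` on each file listed from `./profiles/` should give an `overlap_frac` near 0 for `dp_naive` and a higher value for the interleaved and DDP runs, if the expected overlap is present.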

### Fixed Loss Logging

The original notebook logged `loss.item()` from rank 0's local micro-batch, creating artificial separation between GPU-count clusters. Now we do:
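```python
import torch.distributed as dist

world_size = dist.get_world_size()
loss_val = loss.detach().clone()                 # clone: all_reduce works in place
dist.all_reduce(loss_val, op=dist.ReduceOp.SUM)  # sum per-rank losses
global_loss = loss_val.item() / world_size       # true average
```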
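All modes within the same config should now produce matching loss curves, up to floating-point nondeterminism. Averaging per-rank mean losses gives the true global mean only when every rank processes the same number of samples per step.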
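
The difference between the old and new logging, and the equal-sample caveat, can be seen with plain numbers. The per-rank losses below are made-up illustrative values, not benchmark output, and `rank0_vs_global` is a hypothetical helper:

```python
def rank0_vs_global(per_rank_losses, per_rank_counts):
    """Compare the old rank-0 log, the SUM/world_size average, and the sample-weighted mean."""
    world_size = len(per_rank_losses)
    rank0 = per_rank_losses[0]                          # what the original notebook logged
    mean_of_means = sum(per_rank_losses) / world_size   # what all_reduce(SUM) / world_size yields
    weighted = (sum(l * n for l, n in zip(per_rank_losses, per_rank_counts))
                / sum(per_rank_counts))                 # true mean over all samples
    return rank0, mean_of_means, weighted


print(rank0_vs_global([0.9, 0.7, 0.8, 0.6], [32, 32, 32, 32]))
print(rank0_vs_global([1.0, 0.5], [30, 10]))
```

With equal sample counts the rank average and the weighted mean agree (about 0.75, while rank 0 alone reports 0.9). With 30 vs 10 samples they diverge (0.75 vs 0.875), which is the case where a sample-weighted reduction would be needed.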

---
## 8. NVLink vs PCIe Comparison Guide

To compare interconnects:

1. **Run this notebook on a PCIe machine** (e.g., RunPod `4x A100 80GB PCIe`)
   - Results auto-save to `benchmark_results_NVIDIA_A100_80GB_PCIe.json`

2. **Run again on an NVLink/SXM machine** (e.g., RunPod `4x A100 80GB SXM`)
   - Results save to `benchmark_results_NVIDIA_A100_80GB_SXM.json`

3. **Compare** using the cell below (upload both JSON files):

### What to expect:
| Metric | PCIe | NVLink |
|--------|------|--------|
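| Peak interconnect bandwidth (AllReduce achieves less) | ~32 GB/s per direction (PCIe 4.0 x16) | ~600 GB/s total bidirectional (NVLink 3.0, A100) |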
| Comm overhead (naive) | **High** | Low |
| Interleaved benefit | **Large** (more to hide) | Smaller (less to hide) |
| Scaling efficiency | Drops at 8 GPUs | Near-linear at 8 GPUs |
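
To translate bandwidth into time, the usual ring AllReduce cost model has each GPU send (and receive) 2(N-1)/N times the gradient payload. The sketch assumes that model, ignores latency and compute overlap, and takes `bus_gb_per_s` as an input you should measure (for example the `busbw` column reported by nccl-tests' `all_reduce_perf`) rather than the peak link figures in the table. `ring_allreduce_estimate` is a hypothetical helper:

```python
def ring_allreduce_estimate(n_params, bytes_per_elem, n_gpus, bus_gb_per_s):
    """Rough ring AllReduce cost: (payload bytes, bytes sent per GPU, seconds)."""
    payload = n_params * bytes_per_elem                  # gradient bytes held by each GPU
    if n_gpus < 2:
        return payload, 0.0, 0.0                         # nothing to exchange on one GPU
    sent = 2 * (n_gpus - 1) / n_gpus * payload           # reduce-scatter + all-gather phases
    seconds = sent / (bus_gb_per_s * 1e9)
    return payload, sent, seconds
```

For a hypothetical 100M-parameter model on 4 GPUs at a measured 25 GB/s, FP32 gradients (4 bytes per element) give about 0.024 s per AllReduce; gradients communicated in 2 bytes per element would halve that.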

In [None]:
# ========================================================
# Optional: Load two result files and compare
# Upload your PCIe and NVLink JSON files, then run this
# ========================================================

import glob

result_files = sorted(glob.glob('benchmark_results_*.json'))
print(f"Found {len(result_files)} result files:")
for f in result_files:
    print(f"  {f}")

if len(result_files) >= 2:
    all_hw_results = {}
    for rf in result_files:
        with open(rf) as f:
            data = json.load(f)
        gpu_name = data[0].get('gpu_name', rf)
        all_hw_results[gpu_name] = data

    # Compare DDP AMP GA=1 across hardware
    fig, ax = plt.subplots(figsize=(14, 6))
    hw_names = list(all_hw_results.keys())
    hw_colors = ['#FF7043', '#42A5F5', '#66BB6A', '#AB47BC']

    bar_groups = []
    for ms in model_order:
        for ng in [2, 4, 8]:
            has_data = False
            for hw in hw_names:
                match = [r for r in all_hw_results[hw]
                         if r['model_size'] == ms and r['mode'] == 'ddp_builtin'
                         and r['n_gpus'] == ng and r.get('use_amp') and r.get('grad_accum_steps', 1) == 1]
                if match:
                    has_data = True
            if has_data:
                bar_groups.append((ms, ng))

    x = np.arange(len(bar_groups))
    w = 0.8 / len(hw_names)

    for hw_idx, hw in enumerate(hw_names):
        vals = []
        for ms, ng in bar_groups:
            match = [r for r in all_hw_results[hw]
                     if r['model_size'] == ms and r['mode'] == 'ddp_builtin'
                     and r['n_gpus'] == ng and r.get('use_amp') and r.get('grad_accum_steps', 1) == 1]
            vals.append(match[0]['avg_epoch_time'] if match else 0)
        offset = (hw_idx - len(hw_names)/2 + 0.5) * w
        ax.bar(x + offset, vals, w, label=hw, color=hw_colors[hw_idx % len(hw_colors)], edgecolor='white')

    ax.set_xticks(x)
    ax.set_xticklabels([f"{ms}\n{ng}G" for ms, ng in bar_groups], fontsize=8)
    ax.set_ylabel('Epoch Time (s)')
    ax.set_title('Hardware Comparison — DDP + AMP', fontweight='bold')
    ax.legend(fontsize=8)
    ax.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.savefig('chart_hardware_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()
else:
    print("\nUpload a second benchmark_results_*.json file to enable comparison.")
    print("Run this notebook on different hardware and copy the result file here.")

---
## 9. RunPod Setup Guide

### Recommended configurations:

| Config | GPUs | Interconnect | What you'll learn |
|--------|------|-------------|------------------|
| **4x A100 80GB PCIe** | 4 | PCIe 4.0 (~32 GB/s per direction) | Baseline scaling, comm overhead is visible |
| **4x A100 80GB SXM** | 4 | NVLink 3.0 (~600 GB/s total bidirectional) | NVLink vs PCIe comparison |
| **8x A100 80GB SXM** | 8 | NVLink 3.0 | Full 8-GPU scaling, diminishing returns |
| **8x H100 80GB SXM** | 8 | NVLink 4.0 (~900 GB/s total bidirectional) | State-of-the-art scaling |
| **4x RTX 4090** | 4 | PCIe 4.0 | Consumer GPU, high comm overhead |

### Steps:
1. Create a RunPod instance with your chosen GPU config
2. Upload this notebook
3. Run all cells
4. Download `benchmark_results_<GPU_NAME>.json`
5. Repeat on different hardware
6. Copy all JSON files into the notebook's working directory and rerun the comparison cell above (it picks up every `benchmark_results_*.json`)