# 🚀 Distributed Training in PyTorch: **DDP & FSDP** — Complete, Working Notebook

This notebook gives you **end-to-end, working code** for both **DDP (DistributedDataParallel)** and **FSDP (FullyShardedDataParallel)** using a tiny GPT-style model and a synthetic token-stream dataset.

It is designed to **actually run** across 1+ GPUs (or CPU) by writing runnable training scripts and launching them via `torchrun`.

## What you get
- Minimal yet realistic **MiniGPT** model (decoder-only)
- **Synthetic token-stream** dataset (no external data required)
- **AMP**, **Grad Accum**, **Cosine LR warmup/decay**, gradient **clipping**
- **DDP trainer** and **FSDP trainer** scripts generated to disk
- **Checkpoint saving** (with correct unwrapping), rank-0 logging
- **GPU/CPU** compatible (NCCL on CUDA, GLOO otherwise) — no `torch.compile`
- Lots of comments + troubleshooting tips
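
The cosine warmup/decay schedule in the list above is a multiplier applied to the base learning rate. The sketch below mirrors the `cosine_factor` helper used in the trainer scripts and evaluates it at a few steps, assuming the DDP trainer's defaults of 300 steps with a 100-step warmup and a base rate of 3e-4:

```python
import math

def cosine_factor(step, max_steps, warmup=200, min_factor=0.1):
    """LR multiplier: linear warmup, then cosine decay down to min_factor."""
    if step < warmup:
        return max(1e-8, (step + 1) / max(1, warmup))
    progress = (step - warmup) / max(1, max_steps - warmup)
    return min_factor + 0.5 * (1 - min_factor) * (1 + math.cos(math.pi * progress))

base_lr = 3e-4
for step in (0, 50, 99, 100, 200, 300):
    fac = cosine_factor(step, max_steps=300, warmup=100)
    print(f"step {step:3d}  factor {fac:.3f}  lr {base_lr * fac:.2e}")
```

The factor climbs from 0.01 to 1.0 over the warmup, then decays along a half cosine to `min_factor` (0.1) at the final step.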

### How to use
1. Run each code cell **in order**.
2. Use the provided **launcher cells** to run DDP or FSDP with `torchrun`.
3. If you're on CPU only, you can still run with `--nproc_per_node=1` to validate end-to-end behavior.

> ⚠️ Multi-process distributed training is best launched from the shell. This notebook writes training scripts that you can run from here or from your terminal.

In [1]:
# 0) Environment sanity + global guards (disable PT2 compile just in case)
!pip -q install torch tokenizers tqdm
import os, torch
os.environ['TORCH_COMPILE_DISABLE'] = '1'
try:
    torch._dynamo.reset()
except Exception:
    pass
print('PyTorch:', torch.__version__, '| CUDA available:', torch.cuda.is_available(), '| GPUs:', torch.cuda.device_count())


PyTorch: 2.9.0 | CUDA available: False | GPUs: 0


## 1) Shared utilities: model + dataset (written to `lib_minigpt.py`)

Create a tiny GPT-style model and a simple token-stream dataset. Both trainers import this file. The cell below also writes first versions of the two trainer scripts.
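
The token-stream dataset turns one long sequence into overlapping (input, target) windows, where the target is the input shifted one token ahead. A minimal sketch of that indexing, using a plain Python list as a stand-in for the tensor; `ListTokenStream` is a hypothetical name that copies the slicing logic of `TokenStreamDataset`:

```python
class ListTokenStream:
    """Stand-in for TokenStreamDataset that slices a plain list instead of a tensor."""
    def __init__(self, data, block_size):
        self.data = data
        self.block = block_size

    def __len__(self):
        return max(0, len(self.data) - self.block - 1)

    def __getitem__(self, i):
        x = self.data[i:i + self.block]          # input window
        y = self.data[i + 1:i + 1 + self.block]  # same window shifted one token ahead
        return x, y

stream = list(range(10))                 # tokens 0..9
ds = ListTokenStream(stream, block_size=4)
x, y = ds[0]
print(len(ds), x, y)
```

With 10 tokens and a block size of 4 there are 5 windows; the first pairs `[0, 1, 2, 3]` with `[1, 2, 3, 4]`.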

In [14]:
# Run this once to (re)generate the 3 files and verify they exist.
from pathlib import Path

lib_code = r"""import math, torch
from torch import nn
from torch.nn import functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, n_embed, n_head, block_size, dropout=0.1):
        super().__init__()
        assert n_embed % n_head == 0
        self.n_head = n_head
        self.qkv = nn.Linear(n_embed, 3 * n_embed, bias=False)
        self.proj = nn.Linear(n_embed, n_embed, bias=False)
        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)
        mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('mask', mask.view(1,1,block_size,block_size))
    def forward(self, x):
        B,T,C = x.shape
        qkv = self.qkv(x); q,k,v = qkv.chunk(3, dim=-1)
        nh = self.n_head
        q = q.view(B,T,nh,-1).transpose(1,2)
        k = k.view(B,T,nh,-1).transpose(1,2)
        v = v.view(B,T,nh,-1).transpose(1,2)
        att = (q @ k.transpose(-2,-1)) / math.sqrt(k.size(-1))
        att = att.masked_fill(self.mask[:,:,:T,:T]==0, float('-inf'))
        att = self.attn_drop(att.softmax(dim=-1))
        y = att @ v
        y = y.transpose(1,2).contiguous().view(B,T,-1)
        return self.resid_drop(self.proj(y))

class Block(nn.Module):
    def __init__(self, n_embed, n_head, block_size, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embed)
        self.attn = CausalSelfAttention(n_embed, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embed)
        self.mlp = nn.Sequential(nn.Linear(n_embed, 4*n_embed), nn.GELU(),
                                 nn.Linear(4*n_embed, n_embed), nn.Dropout(dropout))
    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

class MiniGPT(nn.Module):
    def __init__(self, vocab_size, n_embed=384, n_head=6, n_layer=6, block_size=256, dropout=0.1):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embed)
        self.pos_emb = nn.Embedding(block_size, n_embed)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([Block(n_embed, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embed)
        self.head = nn.Linear(n_embed, vocab_size, bias=False)
        self.apply(self._init)
    def _init(self, m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None: nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
    def forward(self, idx, targets=None):
        B,T = idx.shape
        pos = torch.arange(0, T, device=idx.device)
        x = self.drop(self.tok_emb(idx) + self.pos_emb(pos)[None,:,:])
        for blk in self.blocks: x = blk(x)
        logits = self.head(self.ln_f(x))
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(B*T, -1), targets.view(B*T))
        return logits, loss

def make_stream(n_tokens:int, vocab_size:int):
    return torch.randint(0, vocab_size, (n_tokens,), dtype=torch.long)

class TokenStreamDataset(torch.utils.data.Dataset):
    def __init__(self, tensor_data, block_size):
        self.data = tensor_data
        self.block = block_size
    def __len__(self): return max(0, len(self.data) - self.block - 1)
    def __getitem__(self, i):
        x = self.data[i:i+self.block]
        y = self.data[i+1:i+1+self.block]
        return x, y
"""

ddp_code = r"""import os, math, torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torch.cuda.amp import GradScaler, autocast
from lib_minigpt import MiniGPT, make_stream, TokenStreamDataset

def setup_dist():
    backend = 'nccl' if torch.cuda.is_available() else 'gloo'
    dist.init_process_group(backend=backend)
    if torch.cuda.is_available():
        torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', 0)))

def is_main():
    return int(os.environ.get('RANK', '0')) == 0

def cosine_factor(step, max_steps, warmup=200, min_factor=0.1):
    if step < warmup:
        return max(1e-8, (step+1)/max(1, warmup))
    progress = (step - warmup) / max(1, max_steps - warmup)
    return min_factor + 0.5*(1-min_factor)*(1 + math.cos(math.pi*progress))

def main():
    setup_dist()
    device = torch.device(f"cuda:{os.environ.get('LOCAL_RANK','0')}" if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        try: torch.set_float32_matmul_precision('medium')
        except Exception: pass

    BLOCK = int(os.environ.get('BLOCK_SIZE', '256'))
    VOCAB = int(os.environ.get('VOCAB_SIZE', '8000'))
    TRAIN_TOK = int(os.environ.get('TRAIN_TOKENS', '600000'))
    VAL_TOK   = int(os.environ.get('VAL_TOKENS',   '60000'))
    BATCH = int(os.environ.get('BATCH_SIZE', '32'))
    ACCUM = int(os.environ.get('GRAD_ACCUM', '2'))
    LR    = float(os.environ.get('LR', '3e-4'))
    MAX_STEPS = int(os.environ.get('MAX_STEPS', '300'))
    WARMUP = int(os.environ.get('WARMUP_STEPS', '100'))
    CLIP = float(os.environ.get('CLIP_NORM', '1.0'))
    USE_AMP = torch.cuda.is_available()

    train_stream = make_stream(TRAIN_TOK, VOCAB)
    val_stream   = make_stream(VAL_TOK,   VOCAB)
    train_ds = TokenStreamDataset(train_stream, BLOCK)
    val_ds   = TokenStreamDataset(val_stream,   BLOCK)
    train_samp = DistributedSampler(train_ds, shuffle=True, drop_last=True)
    val_samp   = DistributedSampler(val_ds,   shuffle=False, drop_last=False)
    train_dl = DataLoader(train_ds, batch_size=BATCH, sampler=train_samp, num_workers=4, pin_memory=True)
    val_dl   = DataLoader(val_ds,   batch_size=BATCH, sampler=val_samp,   num_workers=2, pin_memory=True)

    model = MiniGPT(VOCAB, n_embed=384, n_head=6, n_layer=6, block_size=BLOCK, dropout=0.1).to(device)
    model = DDP(model, device_ids=[device] if device.type == 'cuda' else None)
    opt = torch.optim.AdamW(model.parameters(), lr=LR)
    scaler = GradScaler(enabled=USE_AMP)

    def run_val():
        model.eval()
        total, count = torch.tensor(0.0, device=device), torch.tensor(0, device=device)
        with torch.no_grad():
            for xb, yb in val_dl:
                xb, yb = xb.to(device), yb.to(device)
                _, loss = model(xb, yb)
                total += loss.detach(); count += 1
        dist.all_reduce(total, op=dist.ReduceOp.SUM)
        dist.all_reduce(count, op=dist.ReduceOp.SUM)
        model.train()
        return (total / count).item() if int(count.item())>0 else float('nan')

    # Tiny warmup
    for _ in range(2):
        for xb, yb in train_dl:
            xb, yb = xb.to(device), yb.to(device)
            with autocast(enabled=USE_AMP, dtype=torch.bfloat16 if USE_AMP else None):
                _, loss = model(xb, yb)
            loss.backward(); opt.zero_grad(set_to_none=True); break

    step = 0
    while step < MAX_STEPS:
        train_samp.set_epoch(step)
        for xb, yb in train_dl:
            fac = cosine_factor(step, MAX_STEPS, warmup=WARMUP, min_factor=0.1)
            for pg in opt.param_groups: pg['lr'] = LR * fac
            with autocast(enabled=USE_AMP, dtype=torch.bfloat16 if USE_AMP else None):
                _, loss = model(xb.to(device, non_blocking=True), yb.to(device, non_blocking=True))
                loss = loss / max(1, ACCUM)
            scaler.scale(loss).backward()
            if (step+1) % max(1, ACCUM) == 0:
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)
                scaler.step(opt); scaler.update()
                # Zero grads only after the optimizer step so micro-batches actually accumulate
                opt.zero_grad(set_to_none=True)
            step += 1
            if step % 100 == 0 or step == MAX_STEPS:
                # run_val() calls all_reduce, so every rank must enter it; only rank 0 prints
                vl = run_val()
                if is_main():
                    print(f"step {step}/{MAX_STEPS} | val_loss {vl:.4f}")
            if step >= MAX_STEPS: break

    if is_main():
        torch.save(model.module.state_dict(), 'minigpt_ddp.pt')
    dist.barrier(); dist.destroy_process_group()

if __name__ == '__main__':
    main()
"""

fsdp_code = r"""import os, math, torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from torch.cuda.amp import GradScaler, autocast
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.distributed.fsdp import StateDictType, FullStateDictConfig
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper
from lib_minigpt import MiniGPT, make_stream, TokenStreamDataset, Block

def setup_dist():
    backend = 'nccl' if torch.cuda.is_available() else 'gloo'
    dist.init_process_group(backend=backend)
    if torch.cuda.is_available():
        torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', 0)))

def is_main():
    return int(os.environ.get('RANK', '0')) == 0

def cosine_factor(step, max_steps, warmup=200, min_factor=0.1):
    if step < warmup:
        return max(1e-8, (step+1)/max(1, warmup))
    progress = (step - warmup) / max(1, max_steps - warmup)
    return min_factor + 0.5*(1-min_factor)*(1 + math.cos(math.pi*progress))

def main():
    setup_dist()
    device = torch.device(f"cuda:{os.environ.get('LOCAL_RANK','0')}" if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        try: torch.set_float32_matmul_precision('medium')
        except Exception: pass

    BLOCK = int(os.environ.get('BLOCK_SIZE', '256'))
    VOCAB = int(os.environ.get('VOCAB_SIZE', '8000'))
    TRAIN_TOK = int(os.environ.get('TRAIN_TOKENS', '600000'))
    VAL_TOK   = int(os.environ.get('VAL_TOKENS',   '60000'))
    BATCH = int(os.environ.get('BATCH_SIZE', '32'))
    LR    = float(os.environ.get('LR', '3e-4'))
    MAX_STEPS = int(os.environ.get('MAX_STEPS', '300'))
    WARMUP = int(os.environ.get('WARMUP_STEPS', '100'))
    USE_AMP = torch.cuda.is_available()

    train_stream = make_stream(TRAIN_TOK, VOCAB)
    val_stream   = make_stream(VAL_TOK,   VOCAB)
    train_ds = TokenStreamDataset(train_stream, BLOCK)
    val_ds   = TokenStreamDataset(val_stream,   BLOCK)
    train_samp = DistributedSampler(train_ds, shuffle=True, drop_last=True)
    val_samp   = DistributedSampler(val_ds,   shuffle=False, drop_last=False)
    train_dl = DataLoader(train_ds, batch_size=BATCH, sampler=train_samp, num_workers=4, pin_memory=True)
    val_dl   = DataLoader(val_ds,   batch_size=BATCH, sampler=val_samp,   num_workers=2, pin_memory=True)

    base_model = MiniGPT(VOCAB, n_embed=384, n_head=6, n_layer=6, block_size=BLOCK, dropout=0.1)

    # optional activation checkpointing per block
    for i, blk in enumerate(base_model.blocks):
        base_model.blocks[i] = checkpoint_wrapper(blk)

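    # The policy is a callback FSDP invokes per module; bind the layer classes with partial
    from functools import partial
    auto_wrap = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})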

    mp_policy = None
    if torch.cuda.is_available():
        from torch.distributed.fsdp import MixedPrecision
        mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.bfloat16)

    model = FSDP(base_model, auto_wrap_policy=auto_wrap, mixed_precision=mp_policy).to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=LR)
    scaler = GradScaler(enabled=USE_AMP)

    @torch.no_grad()
    def run_val():
        model.eval()
        total, count = torch.tensor(0.0, device=device), torch.tensor(0, device=device)
        for xb, yb in val_dl:
            xb, yb = xb.to(device), yb.to(device)
            with autocast(enabled=USE_AMP, dtype=torch.bfloat16 if USE_AMP else None):
                _, loss = model(xb, yb)
            total += loss.detach(); count += 1
        dist.all_reduce(total, op=dist.ReduceOp.SUM)
        dist.all_reduce(count, op=dist.ReduceOp.SUM)
        model.train()
        return (total / count).item() if int(count.item())>0 else float('nan')

    # tiny warmup
    for _ in range(2):
        for xb, yb in train_dl:
            with autocast(enabled=USE_AMP, dtype=torch.bfloat16 if USE_AMP else None):
                _, loss = model(xb.to(device), yb.to(device))
            loss.backward(); opt.zero_grad(set_to_none=True); break

    step = 0
    ACCUM = 2
    CLIP = 1.0
    while step < MAX_STEPS:
        train_samp.set_epoch(step)
        for xb, yb in train_dl:
            fac = cosine_factor(step, MAX_STEPS, warmup=WARMUP, min_factor=0.1)
            for pg in opt.param_groups: pg['lr'] = float(LR) * fac
            with autocast(enabled=USE_AMP, dtype=torch.bfloat16 if USE_AMP else None):
                _, loss = model(xb.to(device, non_blocking=True), yb.to(device, non_blocking=True))
                loss = loss / ACCUM
            scaler.scale(loss).backward()
            if (step+1) % ACCUM == 0:
                scaler.unscale_(opt)
                # Gradients are sharded under FSDP, so clip with the wrapper's own method
                model.clip_grad_norm_(CLIP)
                scaler.step(opt); scaler.update()
                # Zero grads only after the optimizer step so micro-batches actually accumulate
                opt.zero_grad(set_to_none=True)
            step += 1
            if step % 100 == 0 or step == MAX_STEPS:
                # run_val() runs collectives (FSDP forward, all_reduce): every rank must enter it
                vl = run_val()
                if is_main():
                    print(f"step {step}/{MAX_STEPS} | val_loss {vl:.4f}")
            if step >= MAX_STEPS: break

    # state_dict() is a collective under FSDP, so every rank must call it;
    # with rank0_only=True only rank 0 receives the full tensors and saves them.
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT,
                              FullStateDictConfig(offload_to_cpu=True, rank0_only=True)):
        sd = model.state_dict()
    if is_main():
        torch.save(sd, 'minigpt_fsdp_full.pt')

    dist.barrier(); dist.destroy_process_group()

if __name__ == '__main__':
    main()
"""

# write files
Path('lib_minigpt.py').write_text(lib_code)
Path('train_ddp_minigpt.py').write_text(ddp_code)
Path('train_fsdp_minigpt.py').write_text(fsdp_code)

# verify
missing = [p for p in ['lib_minigpt.py','train_ddp_minigpt.py','train_fsdp_minigpt.py'] if not Path(p).exists()]
print("Wrote files." if not missing else f"Missing: {missing}")


Wrote files.
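
The DDP script saves `model.module.state_dict()`, so its checkpoint keys carry no wrapper prefix. A checkpoint saved from the wrapped model instead has every key prefixed with `module.`, which an unwrapped model will not accept. A minimal sketch of stripping that prefix, with plain dicts standing in for real state dicts; the helper name `strip_ddp_prefix` is hypothetical:

```python
def strip_ddp_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading DDP wrapper prefix removed from each key."""
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

wrapped = {"module.tok_emb.weight": 1, "module.head.weight": 2}
clean = strip_ddp_prefix(wrapped)
print(sorted(clean))
```

The cleaned dict can then be passed to `load_state_dict` on a plain `MiniGPT` instance built with the same hyperparameters.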


## 3) Write FSDP trainer (`train_fsdp_minigpt.py`)

This script uses PyTorch FSDP to shard parameters, gradients, and optimizer state across GPUs. Running the cell overwrites the FSDP script written by the previous cell.
- Mixed precision (`bf16` if CUDA is available)
- Auto-wrapped transformer blocks
- Full-state-dict checkpointing (saved on rank 0)

> NOTE: FSDP on CPU will run, but its benefit comes when you have multiple GPUs.
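
To make the sharding benefit concrete, here is a rough estimate of per-GPU memory for parameters, gradients, and AdamW state. It assumes the MiniGPT defaults from `lib_minigpt.py` and fp32 storage (4 bytes each for a parameter and its gradient, 8 bytes for the two Adam moments). It ignores activations, buffers, and the temporarily unsharded parameters FSDP materializes during forward and backward. The function names are illustrative.

```python
def minigpt_param_count(vocab=8000, n_embed=384, n_layer=6, block_size=256):
    """Count trainable parameters of MiniGPT as defined in lib_minigpt.py."""
    emb = vocab * n_embed + block_size * n_embed          # token + position embeddings
    attn = n_embed * 3 * n_embed + n_embed * n_embed      # qkv + proj, both bias-free
    mlp = (n_embed * 4 * n_embed + 4 * n_embed) + (4 * n_embed * n_embed + n_embed)
    norms = 2 * (2 * n_embed)                             # ln1 + ln2 (weight and bias)
    block = attn + mlp + norms
    head = n_embed * vocab                                # bias-free output head
    return emb + n_layer * block + 2 * n_embed + head     # 2 * n_embed is the final LayerNorm

def state_bytes_per_gpu(n_params, world_size, sharded):
    """fp32 params + grads + Adam moments: 16 bytes per parameter."""
    total = 16 * n_params
    return total / world_size if sharded else total

n = minigpt_param_count()
for gpus in (1, 2, 8):
    ddp = state_bytes_per_gpu(n, gpus, sharded=False) / 2**20
    fsdp = state_bytes_per_gpu(n, gpus, sharded=True) / 2**20
    print(f"{gpus} GPU(s): DDP ~{ddp:.0f} MiB, FSDP ~{fsdp:.0f} MiB")
```

Under these assumptions DDP keeps the full training state on every GPU, while fully sharded FSDP divides it by the number of GPUs. For a model this small the saving is modest; it matters once the state no longer fits on one device.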


In [15]:
fsdp_path = Path('train_fsdp_minigpt.py')
fsdp_code = r'''import os, math, time, torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from torch.cuda.amp import GradScaler, autocast
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.distributed.fsdp import StateDictType, FullStateDictConfig
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper
from lib_minigpt import MiniGPT, make_stream, TokenStreamDataset, Block

def setup_dist():
    backend = 'nccl' if torch.cuda.is_available() else 'gloo'
    dist.init_process_group(backend=backend)
    if torch.cuda.is_available():
        torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', 0)))

def is_main():
    return int(os.environ.get('RANK', '0')) == 0

def cosine_factor(step, max_steps, warmup=200, min_factor=0.1):
    if step < warmup:
        return max(1e-8, (step+1)/max(1, warmup))
    progress = (step - warmup) / max(1, max_steps - warmup)
    return min_factor + 0.5 * (1-min_factor) * (1 + math.cos(math.pi * progress))

def main():
    setup_dist()
    device = torch.device(f"cuda:{os.environ['LOCAL_RANK']}" if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        try: torch.set_float32_matmul_precision('medium')
        except Exception: pass

    # Config
    BLOCK = int(os.environ.get('BLOCK_SIZE', '256'))
    VOCAB = int(os.environ.get('VOCAB_SIZE', '8000'))
    TRAIN_TOK = int(os.environ.get('TRAIN_TOKENS', '600000'))
    VAL_TOK   = int(os.environ.get('VAL_TOKENS',   '60000'))
    BATCH = int(os.environ.get('BATCH_SIZE', '32'))
    ACCUM = int(os.environ.get('GRAD_ACCUM', '2'))
    LR    = float(os.environ.get('LR', '3e-4'))
    MAX_STEPS = int(os.environ.get('MAX_STEPS', '500'))
    WARMUP = int(os.environ.get('WARMUP_STEPS', '200'))
    CLIP = float(os.environ.get('CLIP_NORM', '1.0'))
    USE_AMP = torch.cuda.is_available()

    # Data
    train_stream = make_stream(TRAIN_TOK, VOCAB)
    val_stream   = make_stream(VAL_TOK,   VOCAB)
    train_ds = TokenStreamDataset(train_stream, BLOCK)
    val_ds   = TokenStreamDataset(val_stream,   BLOCK)
    train_samp = DistributedSampler(train_ds, shuffle=True, drop_last=True)
    val_samp   = DistributedSampler(val_ds,   shuffle=False, drop_last=False)
    train_dl = DataLoader(train_ds, batch_size=BATCH, sampler=train_samp, num_workers=4, pin_memory=True)
    val_dl   = DataLoader(val_ds,   batch_size=BATCH, sampler=val_samp,   num_workers=2, pin_memory=True)

    # Base model on CPU first (FSDP will move shards to GPU)
    base_model = MiniGPT(VOCAB, n_embed=384, n_head=6, n_layer=6, block_size=BLOCK, dropout=0.1)

    # Optional activation checkpointing on each transformer Block for memory
    for i, blk in enumerate(base_model.blocks):
        base_model.blocks[i] = checkpoint_wrapper(blk)

    # Auto-wrap policy for transformer Blocks
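    # The policy is a callback FSDP invokes per module; bind the layer classes with partial
    from functools import partial
    auto_wrap = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})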

    # Mixed precision config for FSDP (use bf16 on CUDA if available)
    mp_policy = None
    if torch.cuda.is_available():
        from torch.distributed.fsdp import MixedPrecision
        mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.bfloat16)

    # Wrap with FSDP
    model = FSDP(base_model, auto_wrap_policy=auto_wrap, mixed_precision=mp_policy)
    model = model.to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=LR)
    scaler = GradScaler(enabled=USE_AMP)

    def run_val():
        model.eval()
        total, count = torch.tensor(0.0, device=device), torch.tensor(0, device=device)
        with torch.no_grad():
            for xb, yb in val_dl:
                xb, yb = xb.to(device), yb.to(device)
                with autocast(enabled=USE_AMP, dtype=torch.bfloat16 if USE_AMP else None):
                    _, loss = model(xb, yb)
                total += loss.detach()
                count += 1
        dist.all_reduce(total, op=dist.ReduceOp.SUM)
        dist.all_reduce(count, op=dist.ReduceOp.SUM)
        model.train()
        return (total / count).item() if int(count.item())>0 else float('nan')

    # Warmup micro-steps
    for _ in range(3):
        for xb, yb in train_dl:
            with autocast(enabled=USE_AMP, dtype=torch.bfloat16 if USE_AMP else None):
                _, loss = model(xb.to(device), yb.to(device))
            loss.backward(); opt.zero_grad(set_to_none=True); break

    # Train
    step = 0
    while step < MAX_STEPS:
        train_samp.set_epoch(step)
        for xb, yb in train_dl:
            fac = cosine_factor(step, MAX_STEPS, warmup=WARMUP, min_factor=0.1)
            for pg in opt.param_groups: pg['lr'] = LR * fac
            with autocast(enabled=USE_AMP, dtype=torch.bfloat16 if USE_AMP else None):
                _, loss = model(xb.to(device, non_blocking=True), yb.to(device, non_blocking=True))
                loss = loss / max(1, ACCUM)  # honor GRAD_ACCUM instead of a hard-coded 2
            scaler.scale(loss).backward()
            if (step+1) % max(1, ACCUM) == 0:
                scaler.unscale_(opt)
                # Gradients are sharded under FSDP, so clip with the wrapper's own method
                model.clip_grad_norm_(CLIP)
                scaler.step(opt); scaler.update()
                # Zero grads only after the optimizer step so micro-batches actually accumulate
                opt.zero_grad(set_to_none=True)
            step += 1
            if step % 100 == 0 or step == MAX_STEPS:
                # run_val() runs collectives (FSDP forward, all_reduce): every rank must enter it
                vl = run_val()
                if is_main():
                    print(f"step {step}/{MAX_STEPS} | val_loss {vl:.4f}")
            if step >= MAX_STEPS: break

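    # Gather the full state dict. state_dict() is a collective under FSDP, so every rank
    # must call it; with rank0_only=True only rank 0 receives the full tensors and saves them.
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True)):
        sd = model.state_dict()
    if is_main():
        torch.save(sd, 'minigpt_fsdp_full.pt')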

    dist.barrier(); dist.destroy_process_group()

if __name__ == '__main__':
    main()
'''
fsdp_path.write_text(fsdp_code)
print(f'Wrote {fsdp_path.resolve()}')

Wrote /Users/ankushraj/Desktop/Rising Sun Labs Resource/R-S-L-Repositories/SunForgeLLM/SunForgeLLM/Knowledge-Resource/L5_distributed_training/train_fsdp_minigpt.py
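
Both trainers divide each micro-batch loss by the accumulation factor before calling `backward()`. The arithmetic below sketches why, using a one-parameter least-squares loss with hand-written gradients (a toy stand-in, not the MiniGPT loss). Summing the scaled micro-batch gradients reproduces the gradient of the mean loss over the combined batch, provided the micro-batches are the same size.

```python
def grad_mean_sq_error(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over the given samples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]

full_grad = grad_mean_sq_error(w, xs, ys)   # one batch of 4 samples

accum = 2
micro = len(xs) // accum
accumulated = 0.0
for k in range(accum):
    lo, hi = k * micro, (k + 1) * micro
    # Dividing by accum here plays the role of `loss = loss / ACCUM` in the trainers
    accumulated += grad_mean_sq_error(w, xs[lo:hi], ys[lo:hi]) / accum

print(full_grad, accumulated)
```

The equality only holds if gradients are left in place between micro-batches, so `zero_grad` belongs after the optimizer step rather than before every backward pass.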


## 4) Launchers — How to run DDP / FSDP from here

### DDP (single machine)
Use all visible GPUs:
```bash
torchrun --standalone --nnodes=1 \
  --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count() or 1)") \
  train_ddp_minigpt.py
```
Or pick specific GPUs:
```bash
CUDA_VISIBLE_DEVICES=0,1 torchrun --standalone --nnodes=1 --nproc_per_node=2 train_ddp_minigpt.py
```

### FSDP (single machine)
```bash
torchrun --standalone --nnodes=1 \
  --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count() or 1)") \
  train_fsdp_minigpt.py
```
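
`torchrun` tells each worker who it is through environment variables (`RANK`, `LOCAL_RANK`, `WORLD_SIZE`), which is what `setup_dist()` and `is_main()` in the scripts rely on. A small sketch of reading them with single-process fallbacks; the `RankInfo` container and `read_rank_info` helper are hypothetical names, and the mapping passed in below simulates what worker 3 of 4 would see:

```python
import os
from typing import Mapping, NamedTuple

class RankInfo(NamedTuple):
    rank: int        # global index of this worker
    local_rank: int  # index on this machine, used to pick the GPU
    world_size: int  # total number of workers

def read_rank_info(env: Mapping[str, str] = os.environ) -> RankInfo:
    """Read torchrun's per-worker variables, defaulting to a single process."""
    return RankInfo(
        rank=int(env.get("RANK", "0")),
        local_rank=int(env.get("LOCAL_RANK", "0")),
        world_size=int(env.get("WORLD_SIZE", "1")),
    )

info = read_rank_info({"RANK": "3", "LOCAL_RANK": "1", "WORLD_SIZE": "4"})
print(info, "| main process:", info.rank == 0)
```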

> **CPU-only test:** set `--nproc_per_node=1`; backend switches to GLOO automatically.

You can also tweak env variables inline, for example:
```bash
BLOCK_SIZE=256 VOCAB_SIZE=8000 MAX_STEPS=200 BATCH_SIZE=32 \
torchrun --standalone --nproc_per_node=2 train_ddp_minigpt.py
```
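
The scripts read each of these settings from the environment with a string default, which is why inline overrides work. A sketch of that pattern, using a hypothetical `env_int` helper and an explicit mapping so it can be tried without touching the real environment:

```python
import os

def env_int(name, default, env=os.environ):
    """Parse an integer setting from the environment, falling back to a default."""
    return int(env.get(name, str(default)))

overrides = {"MAX_STEPS": "200", "BATCH_SIZE": "32"}   # as if set inline before torchrun
max_steps = env_int("MAX_STEPS", 300, overrides)
block_size = env_int("BLOCK_SIZE", 256, overrides)     # not overridden, so the default applies
print(max_steps, block_size)
```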

## 5) (Optional) Run a tiny smoke test locally
This just verifies imports and that scripts exist. For true distributed runs, use the launcher commands above.

In [16]:
import importlib.util, sys
for p in ['lib_minigpt.py', 'train_ddp_minigpt.py', 'train_fsdp_minigpt.py']:
    assert Path(p).exists(), f'Missing {p}'
print('Files present ✅')

spec = importlib.util.spec_from_file_location('lib_minigpt', 'lib_minigpt.py')
m = importlib.util.module_from_spec(spec); spec.loader.exec_module(m)
model = m.MiniGPT(vocab_size=8000, block_size=256)
xb = torch.randint(0, 8000, (2,256)); yb = torch.randint(0,8000,(2,256))
with torch.no_grad():
    _, loss = model(xb, yb)
print('Model forward OK, loss=', float(loss))

Files present ✅
Model forward OK, loss= 9.060871124267578
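
The loss printed above is close to what an untrained model should produce: with near-uniform predictions over 8000 tokens, cross-entropy is about ln(8000). A quick check of that baseline:

```python
import math

vocab_size = 8000
uniform_loss = math.log(vocab_size)   # cross-entropy of a uniform guess, in nats
observed = 9.060871124267578          # value printed by the smoke test above
print(f"uniform baseline {uniform_loss:.4f}, observed {observed:.4f}, "
      f"gap {observed - uniform_loss:.4f}")
```

Because `make_stream` draws tokens uniformly at random, validation loss in the trainers should stay near this baseline rather than fall well below it. The notebook exercises the training machinery, not learning.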
