# In [None]:
# distributed_training.ipynb

# -------------------------------
# 1. Distributed Setup
# -------------------------------
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import AutoModelForCausalLM, AutoTokenizer, AdamW
from datasets import load_dataset
from torch.utils.data import DataLoader, DistributedSampler
import time
import warnings

# Suppress every Python warning (all categories, all modules) for this
# process. NOTE(review): this also hides deprecation warnings from the
# libraries imported above — consider narrowing the filter.
warnings.filterwarnings("ignore")

def setup(rank, world_size):
    """Join the NCCL process group as worker ``rank``.

    Args:
        rank: Index of this process in the job. It is also used as the CUDA
            device index, so each process drives exactly one GPU.
        world_size: Total number of processes taking part in training.
    """
    # setdefault keeps a rendezvous address/port that was already exported
    # (e.g. to avoid a port clash) and only falls back to the single-node
    # defaults when nothing is set.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    # Make this rank's GPU the current CUDA device before the NCCL group is
    # created, so collectives and default allocations use it rather than
    # whatever device happens to be current (normally cuda:0).
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    """Destroy the default process group if one is currently initialised.

    Calling this when no group exists (for example on an error path where
    initialisation never completed) is a no-op instead of raising.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

# -------------------------------
# 2. Load Dataset & Tokenizer
# -------------------------------
def get_dataloader(batch_size, rank, world_size):
    """Build this rank's DataLoader over the tokenized WikiText-2 train split.

    Args:
        batch_size: Number of sequences per batch on this rank.
        rank: Index of the calling process; selects its shard of the data.
        world_size: Total number of processes the data is sharded across.

    Returns:
        A DataLoader whose batches are dicts with ``"input_ids"`` and
        ``"attention_mask"`` tensors, each padded/truncated to 128 tokens
        per sequence.
    """
    raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    # Drop blank / whitespace-only lines: they tokenize to zero real tokens
    # and would become rows made entirely of padding.
    raw_dataset = raw_dataset.filter(lambda example: bool(example["text"].strip()))

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
    tokenizer.pad_token = tokenizer.eos_token  # for causal LM

    def tokenize_function(example):
        return tokenizer(example["text"], padding="max_length", truncation=True, max_length=128)

    tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
    # Without this the dataset yields plain Python lists and the default
    # collate function produces a list of per-position tensors instead of a
    # (batch, seq_len) tensor, so ``batch["input_ids"].to(device)`` fails.
    tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])

    sampler = DistributedSampler(tokenized_dataset, num_replicas=world_size, rank=rank, shuffle=True)
    dataloader = DataLoader(tokenized_dataset, batch_size=batch_size, sampler=sampler)
    return dataloader

# -------------------------------
# 3. Training Loop
# -------------------------------
def train(rank, world_size):
    """Run one DDP training worker on GPU ``rank``.

    Joins the process group, wraps the causal LM in DistributedDataParallel,
    trains on this rank's shard of the data and always tears the process
    group down again, even if training raises.

    Args:
        rank: Index of this process; also the CUDA device it trains on.
        world_size: Total number of worker processes.
    """
    setup(rank, world_size)
    try:
        device = torch.device(f"cuda:{rank}")
        model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m").to(device)
        model = DDP(model, device_ids=[rank])

        dataloader = get_dataloader(batch_size=8, rank=rank, world_size=world_size)
        # torch.optim.AdamW replaces the deprecated transformers.AdamW.
        # eps and weight_decay are pinned to the values transformers.AdamW
        # used by default so the optimisation hyper-parameters do not change.
        optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

        model.train()
        for epoch in range(1):  # Increase for real training
            # DistributedSampler seeds its shuffle from the epoch number;
            # without set_epoch every epoch would replay the same order.
            sampler = dataloader.sampler
            if isinstance(sampler, DistributedSampler):
                sampler.set_epoch(epoch)

            total_loss = 0.0
            for step, batch in enumerate(dataloader):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)

                # Padding positions must not contribute to the loss; -100 is
                # the index ignored by the model's cross-entropy loss.
                labels = input_ids.masked_fill(attention_mask == 0, -100)

                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                # The model shifts labels by one internally, so position 0 is
                # never a target. If nothing is left to predict the mean loss
                # is NaN; use a zero loss that is still attached to the graph
                # so backward() runs and this rank stays in step with the
                # gradient all-reduce on the other ranks.
                if bool((labels[:, 1:] != -100).any()):
                    loss = outputs.loss
                else:
                    loss = outputs.logits.sum() * 0.0

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # One device-to-host sync per step, reused for logging.
                loss_value = loss.item()
                total_loss += loss_value

                if step % 100 == 0 and rank == 0:
                    print(f"[GPU {rank}] Step {step}, Loss: {loss_value:.4f}")

            if rank == 0:
                # max(..., 1) avoids a ZeroDivisionError on an empty loader.
                num_batches = max(len(dataloader), 1)
                print(f"✅ Epoch complete. Rank {rank} Avg Loss: {total_loss/num_batches:.4f}")
    finally:
        # Release the process group on failure as well as on success.
        cleanup()

# -------------------------------
# 4. Multiprocessing Launch
# -------------------------------
def main():
    """Launch one training process per visible GPU.

    Prints a message and does nothing else when fewer than two CUDA devices
    are available; otherwise spawns ``train`` once per device and waits for
    all workers to finish.
    """
    gpu_count = torch.cuda.device_count()
    if gpu_count >= 2:
        # mp.spawn passes the process index as the first argument (rank);
        # the remaining positional arguments come from ``args``.
        mp.spawn(train, nprocs=gpu_count, args=(gpu_count,), join=True)
    else:
        print("❌ Requires at least 2 GPUs for DDP.")

# Launch training only when this file is executed directly, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
