In [5]:
!pip install -q transformers accelerate

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import time
import os

from torch.cuda import OutOfMemoryError

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
tokenizer.pad_token = tokenizer.eos_token
torch.distributed.init_process_group("gloo", rank=0, world_size=1)

print("Batch Size Scaling (FSDP)")

batch_size = 2
max_batch = 64
successes = []

while batch_size <= max_batch:
    input_ids = tokenizer(
        ["The future of AI is very bright."] * batch_size,
        return_tensors="pt",
        padding=True,
        truncation=True
    ).input_ids.to(device)
    labels = input_ids.clone()

    model = AutoModelForCausalLM.from_pretrained("distilgpt2").to(device)
    model = FSDP(model)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

    try:
        model.train()
        torch.cuda.reset_peak_memory_stats()
        start = time.time()
        outputs = model(input_ids, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        end = time.time()

        gpu_mem = torch.cuda.max_memory_allocated() / 1e6
        successes.append((batch_size, round(loss.item(), 4), round(end-start, 2), round(gpu_mem, 2)))
        print(f"Batch {batch_size} | Loss: {loss.item():.4f} | Time: {end-start:.2f}s | GPU: {gpu_mem:.2f} MB")
        batch_size *= 2

    except RuntimeError as e:
        print(f"Batch {batch_size} FAILED: {str(e).splitlines()[0]}")
        break

torch.distributed.destroy_process_group()
print("FSDP Batch Scaling Complete")


Batch Size Scaling (FSDP)
Batch 2 | Loss: 4.3525 | Time: 0.05s | GPU: 3341.88 MB
Batch 4 | Loss: 4.0875 | Time: 0.05s | GPU: 3340.76 MB
Batch 8 | Loss: 4.1603 | Time: 0.06s | GPU: 3029.83 MB
Batch 16 | Loss: 4.2236 | Time: 0.06s | GPU: 4036.94 MB
Batch 32 | Loss: 4.1540 | Time: 0.07s | GPU: 3766.24 MB
Batch 64 | Loss: 4.2348 | Time: 0.10s | GPU: 4860.31 MB
FSDP Batch Scaling Complete
