In [7]:
!pip install -q transformers accelerate

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.profiler import profile, record_function, ProfilerActivity
import torch.distributed as dist
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
input_ids = tokenizer("The future of AI is", return_tensors="pt").input_ids.to(device)
labels = input_ids.clone()

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
dist.init_process_group("gloo", rank=0, world_size=1)

fsdp_model = AutoModelForCausalLM.from_pretrained("distilgpt2").to(device)
fsdp_model = FSDP(fsdp_model)
optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=5e-5)

print("Profiling FSDP")
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             record_shapes=True,
             with_stack=True) as prof:
    with record_function("FSDP_Train_Step"):
        outputs = fsdp_model(input_ids, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
dist.destroy_process_group()


Profiling FSDP
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            aten::copy_         0.17%     275.088us         0.33%     540.959us      20.036us      53.734ms        54.09%      53.734ms       1.990ms            27  
                                        FSDP_Train_Step        10.55%      17.446ms        42.80%      70.756ms      70.756ms       0.000us         0.00%      33.573ms      33.573ms           