# The Scalable Trainer (FSDP)

## Install Dependencies

In [1]:
%pip install lightning --quiet

     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.4/42.4 kB 1.7 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 827.9/827.9 kB 17.5 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 363.4/363.4 MB 4.1 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.8/13.8 MB 18.4 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 24.6/24.6 MB 17.6 MB/s eta 0:00:00

## Model Code

In [2]:
%%writefile model.py

import lightning as L
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling
from torch.optim import AdamW
from transformers import get_cosine_schedule_with_warmup
from datasets import load_dataset
from torch.utils.data import DataLoader
from lightning.pytorch.strategies import FSDPStrategy
from transformers.models.gpt2.modeling_gpt2 import GPT2Block

class BabyGPT(L.LightningModule):
    """Hugging Face causal language model wrapped as a LightningModule.

    The wrapped model computes its own language-modeling loss whenever
    ``labels`` are supplied, so every step simply forwards the batch and
    logs ``outputs.loss``.

    Args:
        model_name: Checkpoint identifier passed to
            ``AutoModelForCausalLM.from_pretrained``.
        lr: Peak learning rate for AdamW (reached at the end of warmup).
    """

    def __init__(self, model_name="TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", lr=2e-4):
        super().__init__()
        # Makes the constructor arguments available as self.hparams.* and
        # stores them in checkpoints.
        self.save_hyperparameters()

        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        # Recompute activations during backward instead of storing them:
        # slower steps, but a much smaller activation memory footprint.
        self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the wrapped model and return its output object.

        The returned object exposes ``.loss`` when ``labels`` is given.
        """
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self(**batch).loss

        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def _evaluate_and_log(self, batch, prefix):
        """Shared validation/test logic: log ``<prefix>_loss`` and ``<prefix>_perplexity``.

        Perplexity is ``exp(loss)``: the loss lives on a log scale
        (e.g. 2.5) while perplexity is on a linear scale (e.g. ~12.2).

        NOTE: the epoch-level perplexity is the average of per-batch
        perplexities, which is not the same as ``exp`` of the mean loss.
        """
        loss = self(**batch).loss
        perplexity = torch.exp(loss)

        # sync_dist=True reduces the metric across all ranks. Without it each
        # process only aggregates its own shard of the evaluation data, so the
        # reported number would not describe the full dataset when training
        # on multiple devices.
        self.log(f"{prefix}_loss", loss, prog_bar=True, sync_dist=True)
        self.log(f"{prefix}_perplexity", perplexity, prog_bar=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` and ``val_perplexity`` for one validation batch."""
        return self._evaluate_and_log(batch, "val")

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` and ``test_perplexity`` for one test batch."""
        return self._evaluate_and_log(batch, "test")

    def configure_optimizers(self):
        """Build AdamW with a cosine schedule and 10% linear warmup.

        Returns:
            A Lightning optimizer/scheduler configuration dict.
        """
        optimizer = AdamW(self.parameters(), lr=self.hparams.lr)
        # Total number of optimizer steps Lightning expects to run; the
        # schedule is defined over this horizon.
        total_steps = self.trainer.estimated_stepping_batches
        warmup_steps = int(total_steps * 0.1)
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                # CRITICAL: the schedule is expressed in optimizer steps, so
                # it must advance every 'step' (batch), not every 'epoch'.
                "interval": "step",
            },
        }

    def generate_text(self, tokenizer, prompt, max_new_tokens=50, temperature=0.7):
        """Sample a continuation of ``prompt`` and return it as a string.

        Args:
            tokenizer: Tokenizer matching the wrapped model.
            prompt: Text to continue.
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature (top-k is fixed at 50).

        Returns:
            The decoded text (prompt plus continuation) without special tokens.
        """
        inputs = tokenizer(prompt, return_tensors="pt").to(self.device)

        # Generate in eval mode so dropout does not perturb sampling, then
        # restore whatever mode the model was in (this may be called from a
        # callback in the middle of training).
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=temperature,
                    top_k=50,
                    pad_token_id=tokenizer.eos_token_id,
                )
        finally:
            self.model.train(was_training)

        return tokenizer.decode(outputs[0], skip_special_tokens=True)

Writing model.py


## Data Code

In [3]:
%%writefile data.py

import lightning as L
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling
from torch.optim import AdamW
from transformers import get_cosine_schedule_with_warmup
from datasets import load_dataset
from torch.utils.data import DataLoader
from lightning.pytorch.strategies import FSDPStrategy
from transformers.models.gpt2.modeling_gpt2 import GPT2Block

class WikiDataModule(L.LightningDataModule):
    """WikiText-2 (raw) data module for causal language modeling.

    Text is tokenized and truncated to ``max_length``; padding is applied per
    batch by the collator, so each batch is only as long as its longest
    sequence.

    Args:
        model_name: Checkpoint whose tokenizer is loaded.
        batch_size: Per-device batch size for all dataloaders.
        max_length: Maximum number of tokens kept per example.
        num_workers: DataLoader worker processes.
            Performance tip: set this close to your CPU count.
    """

    # Single source of truth so prepare_data() and setup() always agree on
    # which dataset is downloaded and which one is loaded.
    DATASET_PATH = "wikitext"
    DATASET_CONFIG = "wikitext-2-raw-v1"

    def __init__(self, model_name="gpt2", batch_size=32, max_length=128, num_workers=6):
        super().__init__()
        self.model_name = model_name
        self.batch_size = batch_size
        self.max_length = max_length
        self.num_workers = num_workers

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Some tokenizers (e.g. GPT-2) ship without a pad token, which the
        # collator needs in order to pad; reuse EOS in that case.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # mlm=False selects causal-LM collation (labels derived from
        # input_ids). One collator is shared by every dataloader.
        self.collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)

    def prepare_data(self):
        """Download the dataset into the local cache.

        Lightning does not run this hook on every process, so nothing is
        assigned to ``self`` here; ``setup`` re-loads from the cache on every
        process.
        """
        load_dataset(self.DATASET_PATH, self.DATASET_CONFIG)

    def setup(self, stage=None):
        """Tokenize the dataset and expose the splits required by ``stage``.

        Args:
            stage: ``"fit"``, ``"validate"``, ``"test"`` or ``None``.
                ``None`` prepares every split.
        """
        # 1. Load raw data (served from the cache filled by prepare_data).
        dataset = load_dataset(self.DATASET_PATH, self.DATASET_CONFIG)

        # 2. Tokenize, truncating so no sequence exceeds max_length tokens.
        def tokenize_function(examples):
            return self.tokenizer(
                examples["text"],
                truncation=True,
                max_length=self.max_length,
            )

        # 3. Apply tokenization; drop the raw 'text' column because the model
        #    only consumes token ids.
        tokenized_datasets = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=["text"],
        )

        # 4. Drop rows that tokenized to nothing (rows with empty text).
        def filter_empty(example):
            return len(example["input_ids"]) > 0

        tokenized_datasets = tokenized_datasets.filter(filter_empty)

        # 5. Expose the splits needed for the requested stage.
        if stage in ("fit", None):
            self.train_dataset = tokenized_datasets["train"]

        # "validate" is included so trainer.validate(datamodule=...) works
        # without a prior fit.
        if stage in ("fit", "validate", None):
            self.val_dataset = tokenized_datasets["validation"]

        if stage in ("test", None):
            self.test_dataset = tokenized_datasets["test"]

    def _make_loader(self, dataset, shuffle=False):
        """Build a DataLoader with the module's shared settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            # Dynamic padding happens here, batch by batch.
            collate_fn=self.collator,
            # Page-locked host memory speeds up host-to-GPU transfer.
            pin_memory=True,
        )

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the validation dataloader (not shuffled)."""
        return self._make_loader(self.val_dataset)

    def test_dataloader(self):
        """Return the test dataloader (not shuffled)."""
        return self._make_loader(self.test_dataset)

Writing data.py


## The Auto-Wrap Policy

In [4]:
%%writefile train.py

import lightning as L
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling
from torch.optim import AdamW
from transformers import get_cosine_schedule_with_warmup
from datasets import load_dataset
from torch.utils.data import DataLoader
from lightning.pytorch.strategies import FSDPStrategy
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
from model import BabyGPT
from data import WikiDataModule


def train_fsdp(model_name="gpt2-medium", batch_size=4, devices=2, max_epochs=3):
    """Fine-tune a GPT-2 family model on WikiText with Fully Sharded Data Parallel.

    Args:
        model_name: Checkpoint used for BOTH the tokenizer and the model, so
            the two can never drift apart. It must be a GPT-2 architecture:
            the auto-wrap policy below targets ``GPT2Block`` and would match
            nothing in a different model family.
        batch_size: Per-device batch size.
        devices: Number of GPUs to shard across.
        max_epochs: Number of training epochs.
    """
    fsdp_strategy = FSDPStrategy(
        # Wrap each transformer block as its own FSDP unit.
        auto_wrap_policy={GPT2Block},
        # Shard parameters, gradients and optimizer state across devices.
        sharding_strategy="FULL_SHARD",
        cpu_offload=False,
    )
    trainer = L.Trainer(
        accelerator="gpu",
        devices=devices,
        strategy=fsdp_strategy,
        precision="bf16-mixed",
        max_epochs=max_epochs,
        log_every_n_steps=10,
    )

    dm = WikiDataModule(model_name=model_name, batch_size=batch_size)
    model = BabyGPT(model_name=model_name)

    print("Starting FSDP Training... 🌶️")
    trainer.fit(model, datamodule=dm)


if __name__ == "__main__":
    train_fsdp()

Writing train.py


In [5]:
!python train.py

2025-11-22 18:21:45.596552: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763835705.814784      55 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763835705.874001      55 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
💡 Tip: For seamless cloud uploads and versioning, try installing [