# The LLMDataModule

## Install the required packages

In [1]:
%pip install torch lightning datasets --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.4/42.4 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m827.9/827.9 kB[0m [31m23.1 MB/s[0m eta [36m0:00:00[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m43.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m831.6/831.6 kB[0m [31m44.7 MB/s[0m eta [36m0:00:00[0m
[?25h

## Step 1: The Tokenizer & Collate Function

In [None]:
from transformers import AutoTokenizer, DataCollatorForLanguageModeling
from torch.utils.data import DataLoader

In [3]:
# 1. Load the Tokenizer
# We use GP
# T-2, a classic standard for Causal LLMs.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [4]:
# CRITICAL FIX for GPT-2:
# GPT-2 was trained without a "pad" token.
# If we don't manually assign one, the code will crash when we try to pad unequal sentences.
# We tell it: "Use the End-Of-Sentence token as the Pad token."
tokenizer.pad_token = tokenizer.eos_token

In [5]:
# 2. The Magic Component: DataCollator
# This function runs EVERY time we fetch a batch.
# It checks the longest sentence in the batch and pads the others to match it.
# mlm=False means "Masked Language Modeling = False".
# We are doing Causal LM (Next Token Prediction), not BERT-style masking.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

## Step 2: The Lightning Data Module

In [None]:
import lightning as L
from datasets import load_dataset

In [8]:
class LLMDataModule(L.LightningDataModule):
    def __init__(self, model_name="gpt2", batch_size=32, max_length=128):
        super().__init__()
        self.model_name = model_name
        self.batch_size = batch_size
        self.max_length = max_length

        # Performance Tip: Set num_workers to your CPU count to load data faster.
        self.num_workers = 4

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # The fix we discussed earlier
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def prepare_data(self):
        self.dataset = load_dataset("wikitext", "wikitext-103-raw-v1")

    def setup(self, stage=None):
        # 1. Load raw data
        dataset = load_dataset('wikitext', 'wikitext-2-raw-v1')

        # 2. Define the tokenizer logic
        def tokenize_function(examples):
            # We truncate here to ensure no sequence exceeds our max memory
            return self.tokenizer(
                examples["text"],
                truncation=True,
                max_length=self.max_length
            )

        # 3. Apply tokenization (Map)
        # We remove the 'text' column because the model only needs numbers (input_ids).
        tokenized_datasets = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=["text"]
        )

        # 4. Split for training phases
        if stage == 'fit' or stage is None:
            self.train_dataset = tokenized_datasets["train"]
            self.val_dataset = tokenized_datasets["validation"]

        if stage == 'test':
            self.test_dataset = tokenized_datasets["test"]

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True, # Always shuffle training data!
            num_workers=self.num_workers,
            # This is where Dynamic Padding happens:
            collate_fn=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            pin_memory=True # Speed boost for data transfer to GPU
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            pin_memory=True
        )


## Testing the Pipeline

In [9]:
def debug_datamodule():
    # Initialize the module
    dm = LLMDataModule(batch_size=4)

    # Manually run the steps usually handled by Trainer
    dm.prepare_data()
    dm.setup()

    # Get a single batch from the loader
    dataloader = dm.train_dataloader()
    batch = next(iter(dataloader))

    print("Keys available:", batch.keys())
    # Expected: dict_keys(['input_ids', 'attention_mask', 'labels'])

    print("Input Shape:", batch['input_ids'].shape)
    # Expected: torch.Size([4, <dynamic_length>])

    # Verify the data makes sense (Decode back to text)
    decoded = dm.tokenizer.decode(batch['input_ids'][0])
    print(f"\n--- Sample Text (Decoded) ---\n{decoded[:100]}...")

    # Check for Labels
    # In Causal LM, the 'labels' are usually just the 'input_ids' shifted by one.
    # The DataCollator creates this 'labels' key for us automatically!
    print("\nLabels included?", 'labels' in batch)

In [10]:
debug_datamodule()

README.md: 0.00B [00:00, ?B/s]

wikitext-103-raw-v1/test-00000-of-00001.(…):   0%|          | 0.00/733k [00:00<?, ?B/s]

wikitext-103-raw-v1/train-00000-of-00002(…):   0%|          | 0.00/157M [00:00<?, ?B/s]

wikitext-103-raw-v1/train-00001-of-00002(…):   0%|          | 0.00/157M [00:00<?, ?B/s]

wikitext-103-raw-v1/validation-00000-of-(…):   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/1801350 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

wikitext-2-raw-v1/train-00000-of-00001.p(…):   0%|          | 0.00/6.36M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

Map:   0%|          | 0/4358 [00:00<?, ? examples/s]

Map:   0%|          | 0/36718 [00:00<?, ? examples/s]

Map:   0%|          | 0/3760 [00:00<?, ? examples/s]



Keys available: KeysView({'input_ids': tensor([[14489,   220,   198, 50256, 50256, 50256, 50256, 50256, 50256, 50256],
        [50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256],
        [50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256],
        [  796, 15120,   357,  2321,  2008,   983,  1267,   796,   220,   198]]), 'attention_mask': tensor([[1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[14489,   220,   198,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [  796, 15120,   357,  2321,  2008,   983,  1267,   796,   220,   198]])})
Input Shape: torch.Size([4, 10])

--- Sample Text (Decoded) ---
 1982 
<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|e

## Challenge

Research "Sequence Packing."

- *Current method:* `[Sentence A, Pad, Pad]` and `[Sentence B]`
- *Packed method:* `[Sentence A, Sentence B, Sentence C]` (All concatenated to fill the context window).
- This removes padding entirely and is how top-tier LLMs are trained.

In [11]:
from itertools import chain

class PackedDataModule(L.LightningDataModule):
    def __init__(self, model_name="gpt2", batch_size=32, block_size=128):
        super().__init__()
        self.model_name = model_name
        self.batch_size = batch_size
        self.block_size = block_size # This is the fixed window size (context length)
        self.num_workers = 4
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def prepare_data(self):
        load_dataset('wikitext', 'wikitext-2-raw-v1')

    def setup(self, stage=None):
        dataset = load_dataset('wikitext', 'wikitext-2-raw-v1')

        # 1. Basic Tokenization (No padding/truncation yet)
        def tokenize_function(examples):
            return self.tokenizer(examples["text"])

        tokenized_datasets = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=["text"]
        )

        # 2. THE PACKING LOGIC
        # This function concatenates all texts and chops them into blocks
        def group_texts(examples):
            # Concatenate all texts in this batch
            concatenated = {k: list(chain(*examples[k])) for k in examples.keys()}
            total_length = len(concatenated[list(examples.keys())[0]])

            # We drop the small remainder at the end to keep shapes perfect
            if total_length >= self.block_size:
                total_length = (total_length // self.block_size) * self.block_size

            # Split by chunks of block_size
            result = {
                k: [t[i : i + self.block_size] for i in range(0, total_length, self.block_size)]
                for k, t in concatenated.items()
            }

            # Create labels (copies of input_ids) used for training
            result["labels"] = result["input_ids"].copy()
            return result

        # Apply the packing
        lm_datasets = tokenized_datasets.map(
            group_texts,
            batched=True,
        )

        if stage == 'fit' or stage is None:
            self.train_dataset = lm_datasets["train"]
            self.val_dataset = lm_datasets["validation"]

    def train_dataloader(self):
        # Note: We use default_data_collator now because everything is ALREADY
        # perfectly sized to 'block_size'. No dynamic padding needed!
        from transformers import default_data_collator
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            collate_fn=default_data_collator,
            pin_memory=True
        )

In [12]:
def debug_packeddatamodule():
    # Initialize the module
    dm = PackedDataModule(batch_size=4)

    # Manually run the steps usually handled by Trainer
    dm.prepare_data()
    dm.setup()

    # Get a single batch from the loader
    dataloader = dm.train_dataloader()
    batch = next(iter(dataloader))

    print("Keys available:", batch.keys())
    # Expected: dict_keys(['input_ids', 'attention_mask', 'labels'])

    print("Input Shape:", batch['input_ids'].shape)
    # Expected: torch.Size([4, <dynamic_length>])

    # Verify the data makes sense (Decode back to text)
    decoded = dm.tokenizer.decode(batch['input_ids'][0])
    print(f"\n--- Sample Text (Decoded) ---\n{decoded[:100]}...")

    # Check for Labels
    # In Causal LM, the 'labels' are usually just the 'input_ids' shifted by one.
    # The DataCollator creates this 'labels' key for us automatically!
    print("\nLabels included?", 'labels' in batch)

In [13]:
debug_packeddatamodule()

Map:   0%|          | 0/4358 [00:00<?, ? examples/s]

Map:   0%|          | 0/36718 [00:00<?, ? examples/s]

Map:   0%|          | 0/3760 [00:00<?, ? examples/s]

Map:   0%|          | 0/4358 [00:00<?, ? examples/s]

Map:   0%|          | 0/36718 [00:00<?, ? examples/s]

Map:   0%|          | 0/3760 [00:00<?, ? examples/s]



Keys available: dict_keys(['input_ids', 'attention_mask', 'labels'])
Input Shape: torch.Size([4, 128])

--- Sample Text (Decoded) ---
 Valentin Alkan from the rear , as in some photographs we have seen . His intelligent and original p...

Labels included? True
