In [5]:
# %pip install datasets transformers torch  # AutoTokenizer ships with transformers; datasets provides load_dataset

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
import torch

In [None]:
# 1. Load the raw IMDB training reviews; only their "text" field feeds the language model below.
dataset = load_dataset(path="imdb", split="train")
print("Number of examples in dataset:", len(dataset))

Train examples: 25000, Test examples: 25000


In [29]:
# 2. Initialize a tokenizer (DistilGPT-2 — smaller/faster, same style as GPT-2)
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
# DistilGPT-2 has no pad token by default; reuse EOS so padding utilities work.
# The guard keeps an existing pad token intact if the checkpoint is swapped for
# one that already defines its own.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
# 3. Tokenize the train split and a held-out split using batched `.map`
def tokenize_function(examples):
    """Tokenize a batch of raw reviews.

    Args:
        examples: batch dict passed by `Dataset.map(batched=True)`; only the
            "text" column is read.

    Returns:
        The tokenizer's batch encoding ('input_ids' and 'attention_mask' per review).
    """
    # No truncation: reviews may exceed the model's context length. Step 4
    # concatenates all tokens and re-splits them into `block_size` chunks anyway.
    return tokenizer(examples["text"], return_special_tokens_mask=False)


tokenized_ds = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

# Step 4 also chunks `tokenized_ds_val`, so it must be defined here. IMDB
# ships "train"/"test" splits; the held-out "test" split serves as validation.
dataset_val = load_dataset("imdb", split="test")
tokenized_ds_val = dataset_val.map(tokenize_function, batched=True, remove_columns=["text"])
# Both datasets now have columns like 'input_ids' and 'attention_mask' (plus the untouched 'label').

print(tokenized_ds[0]["input_ids"][:20])  # first 20 token IDs of the first training review, as a sanity check

Map: 100%|██████████| 25000/25000 [00:03<00:00, 7811.57 examples/s]
Map: 100%|██████████| 25000/25000 [00:02<00:00, 8694.28 examples/s]

[40, 26399, 314, 3001, 327, 47269, 20958, 12, 56, 23304, 3913, 422, 616, 2008, 3650, 780, 286, 477, 262, 10386]





In [None]:
# 4. Slice into training sequences of fixed length
# For language model training, often we concatenate all texts then split into blocks of e.g. 128 or 512 tokens.
block_size = 128

def _flatten(seqs):
    """Flatten to list of ints; works when elements are lists or numpy arrays."""
    out = []
    for s in seqs:
        out.extend(s.tolist() if hasattr(s, "tolist") else list(s))
    return out

def group_texts(examples):
    """Concatenate a batch of tokenized examples and re-cut it into `block_size` chunks.

    Args:
        examples: batch dict from `Dataset.map(batched=True)` containing
            "input_ids" and "attention_mask" (lists or arrays per example).

    Returns:
        dict with "input_ids" and "attention_mask", each a list of chunks of
        exactly `block_size` tokens. The trailing tokens that do not fill a whole
        block are dropped; the cut point is decided by "input_ids".
    """
    flat = {field: _flatten(examples[field]) for field in ("input_ids", "attention_mask")}

    # Only complete blocks are kept, so every chunk has the same length.
    n_blocks = len(flat["input_ids"]) // block_size
    starts = [block * block_size for block in range(n_blocks)]

    return {
        field: [tokens[start:start + block_size] for start in starts]
        for field, tokens in flat.items()
    }


# remove_columns: drop every existing column ("label" included) so the output holds
# only the new chunk rows; a batched map that changes the row count cannot keep
# the old per-example columns alongside the chunks.
batch_size_map = 1000


def _to_lm_blocks(tokenized):
    """Regroup a tokenized split into fixed-length LM chunks via `group_texts`."""
    return tokenized.map(
        group_texts,
        batched=True,
        batch_size=batch_size_map,
        remove_columns=tokenized.column_names,
    )


lm_ds = _to_lm_blocks(tokenized_ds)
lm_ds_val = _to_lm_blocks(tokenized_ds_val)
print(f"Train LM sequences: {len(lm_ds)}, Val LM sequences: {len(lm_ds_val)}")

# Sanity check: grouping should yield non-empty splits of exactly block_size tokens.
assert len(lm_ds) > 0 and len(lm_ds_val) > 0
assert len(lm_ds[0]["input_ids"]) == block_size

Map: 100%|██████████| 25000/25000 [00:02<00:00, 11098.53 examples/s]
Map: 100%|██████████| 25000/25000 [00:02<00:00, 11503.17 examples/s]

Train LM sequences: 51879, Val LM sequences: 51065





In [None]:
# 5. Create a DataLoader for the tokenized, grouped dataset
# Every row produced by step 4 is exactly `block_size` tokens long, so a batch
# can be stacked directly; no padding happens here.
def collate_fn(batch):
    """Stack fixed-length LM chunks into tensors for a causal-LM model.

    Args:
        batch: list of rows, each a dict with an "input_ids" list and optionally
            an "attention_mask" list of the same length. All rows must have the
            same length (torch.tensor raises ValueError otherwise).

    Returns:
        dict with "input_ids" and "labels" as int64 tensors of shape
        (batch, seq_len), plus "attention_mask" when the rows provide one.
    """
    input_ids = torch.tensor([example["input_ids"] for example in batch], dtype=torch.long)
    # Hugging Face causal-LM models shift labels by one position internally, so
    # labels are an unshifted copy of input_ids (cloned, so edits don't alias).
    labels = input_ids.clone()
    collated = {"input_ids": input_ids, "labels": labels}

    if batch and "attention_mask" in batch[0]:
        attention_mask = torch.tensor(
            [example["attention_mask"] for example in batch], dtype=torch.long
        )
        # Masked-out positions must not contribute to the loss; -100 is the
        # label value the Transformers LM loss ignores.
        labels[attention_mask == 0] = -100
        collated["attention_mask"] = attention_mask

    return collated

# A seeded generator makes the shuffle order reproducible across Restart & Run All.
train_loader = DataLoader(
    lm_ds,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn,
    generator=torch.Generator().manual_seed(42),
)

In [None]:
# 6. Pull a single batch to check that the pipeline works end to end
batch = next(iter(train_loader))
print("Batch shapes:", batch["input_ids"].shape, batch["labels"].shape)

# Labels mirror inputs, and every sequence is one full block.
assert batch["input_ids"].shape == batch["labels"].shape
assert batch["input_ids"].shape[1] == block_size

Batch shapes: torch.Size([8, 128]) torch.Size([8, 128])
