In [1]:
# !pip install transformers, AutoTokenizer, torch

In [2]:
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
import torch

In [3]:
# 1. Load a text dataset (we use a small example dataset for demonstration)
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")  # raw text WikiText-2
print(f"Number of lines in dataset: {len(dataset)}")

Number of lines in dataset: 36718


In [4]:
model_name = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Tokenizer loaded:", model_name)
print("Vocab size:", tokenizer.vocab_size)
print("Pad token:", tokenizer.pad_token, "eos token:", tokenizer.eos_token)


Tokenizer loaded: facebook/opt-125m
Vocab size: 50265
Pad token: <pad> eos token: </s>


In [5]:
from typing import List

# 3. Tokenize the dataset efficiently 
# Approach:
#  - take the list of raw text strings
#  - run tokenizer in Python in batches (tokenizer(texts, add_special_tokens=False))
#  - concatenate all token ids into a single list of ids
#  - split into non-overlapping blocks of block_size tokens
#  - create a PyTorch Dataset that yields these blocks

def batch_tokenize_texts(texts: List[str], tokenizer, batch_size: int = 512):
    """
    Tokenize a list of strings in batches and return a flat list of token ids lists.
    We use add_special_tokens=False so block splitting is straightforward.
    """
    all_input_ids = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        enc = tokenizer(batch, add_special_tokens=False)
        all_input_ids.extend(enc["input_ids"])
    return all_input_ids

# Parameters
block_size = 128
batch_tokenize_batch_size = 512

# Extract raw texts
text_column = "text" if "text" in dataset.column_names else dataset.column_names[0]
raw_texts = [x[text_column] for x in dataset]

# Tokenize all lines in batches
tokenized_lines = batch_tokenize_texts(raw_texts, tokenizer, batch_size=batch_tokenize_batch_size)

# Concatenate token ids + EOS
all_ids = []
for ids in tokenized_lines:
    all_ids.extend(ids + [tokenizer.eos_token_id])

# Split into blocks
num_full_blocks = len(all_ids) // block_size
examples = []
for i in range(num_full_blocks):
    start = i * block_size
    block_ids = all_ids[start : start + block_size]
    examples.append(block_ids)

print(f"Total tokens: {len(all_ids)}; Full {block_size}-token blocks: {len(examples)}")


Total tokens: 2428602; Full 128-token blocks: 18973


In [6]:
from torch.utils.data import Dataset, DataLoader

class BlockDataset(Dataset):
    def __init__(self, blocks):
        # blocks: list of list[int]
        self.blocks = blocks

    def __len__(self):
        return len(self.blocks)

    def __getitem__(self, idx):
        ids = torch.tensor(self.blocks[idx], dtype=torch.long)
        return {"input_ids": ids, "labels": ids.clone()}

lm_ds = BlockDataset(examples)

# 5. Create a DataLoader for the tokenized, grouped dataset
def collate_fn(batch):
    input_ids = torch.stack([item["input_ids"] for item in batch])
    labels = torch.stack([item["labels"] for item in batch])
    return {"input_ids": input_ids, "labels": labels}

train_loader = DataLoader(lm_ds, batch_size=8, shuffle=True, collate_fn=collate_fn)
print("DataLoader ready. Number of batches per epoch (approx):", len(train_loader))


DataLoader ready. Number of batches per epoch (approx): 2372


In [7]:
# 5. Create a DataLoader for the tokenized, grouped dataset
def collate_fn(batch):
    # Since our sequences are fixed length after grouping, we might just stack them.
    # If they weren't fixed, we could use tokenizer.pad to pad to max length in batch.
    input_ids = torch.tensor([example["input_ids"] for example in batch], dtype=torch.long)
    # For language modeling, labels are the input_ids shifted by one, but 
    # Transformers' CausalLM models usually handle that internally if we provide labels = input_ids.
    return {"input_ids": input_ids, "labels": input_ids.clone()}

train_loader = DataLoader(lm_ds, batch_size=8, shuffle=True, collate_fn=collate_fn)

In [8]:
import torch
from torch.utils.data import DataLoader

# Safe collate function that works whether __getitem__ returns lists or tensors
def collate_fn(batch):
    # batch is a list of dicts like {"input_ids": <list or tensor>, "labels": <list or tensor>}
    # convert each field to a tensor if needed, then stack
    input_ids_list = []
    labels_list = []
    for item in batch:
        inp = item["input_ids"]
        lab = item["labels"]
        if not isinstance(inp, torch.Tensor):
            inp = torch.tensor(inp, dtype=torch.long)
        if not isinstance(lab, torch.Tensor):
            lab = torch.tensor(lab, dtype=torch.long)
        input_ids_list.append(inp)
        labels_list.append(lab)
    input_ids = torch.stack(input_ids_list, dim=0)
    labels = torch.stack(labels_list, dim=0)
    return {"input_ids": input_ids, "labels": labels}

# Recreate the DataLoader (adjust batch_size if you want)
batch_size = 8
train_loader = DataLoader(lm_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# Sanity-check: iterate one batch
for batch in train_loader:
    print("Batch input_ids shape:", batch["input_ids"].shape)   # (batch_size, block_size)
    print("Batch labels shape:", batch["labels"].shape)
    print("First example token ids (first 20):", batch["input_ids"][0][:20].tolist())
    print("Decoded (first 200 chars):", tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)[:200])
    break


Batch input_ids shape: torch.Size([8, 128])
Batch labels shape: torch.Size([8, 128])
First example token ids (first 20): [12096, 7403, 1746, 1325, 1131, 1416, 16392, 11, 5, 315, 532, 479, 252, 4596, 11, 59, 155, 787, 6, 1039]
Decoded (first 200 chars):  000 burn injuries receive medical treatment yearly in the United States . They resulted in about 3 @,@ 300 deaths in 2008 . Most burns ( 70 % ) and deaths from burns occur in males . The highest inci
