# BPE Tokenization + N-gram MLP Language Model

This notebook teaches two ideas that sit at the core of modern language modeling:

- **Why tokenization exists** and how **Byte Pair Encoding (BPE)** creates a compact vocabulary.
- **How a tiny neural n-gram language model works**: use the previous *N* tokens to predict the next token with an embedding table and a small MLP.

You will:
1. Load a text dataset (Wikitext-2, via the Hugging Face `datasets` library).
2. Train a tiny BPE tokenizer locally in the notebook.
3. Build an n-gram dataset (context, target).
4. Train a small neural language model and inspect predictions.
5. Generate text with temperature sampling.

This is a teaching notebook: it prioritizes clarity and explicit prints over speed.


In [None]:
# Setup and config
import math
import os
import random
import json
from collections import Counter, defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Configuration (toy settings)
CONTEXT_SIZE = 16  # context length (tokens)
EMBED_DIM = 128  # token embedding dimension
HIDDEN_DIM = 256  # hidden layer size
BATCH_SIZE = 128  # batch size
LR = 3e-4  # learning rate
MAX_STEPS = 600  # training steps
EVAL_EVERY = 100  # eval interval (steps)

# Tokenizer settings
TOKENIZER_VOCAB_SIZE = 2000  # target tokenizer vocab size
BPE_TRAIN_CHARS = 200_000  # limit for quick demo training

# Production-ish example values (much more compute required):
# CONTEXT_SIZE = 256
# EMBED_DIM = 512
# HIDDEN_DIM = 2048
# BATCH_SIZE = 512
# MAX_STEPS = 20_000

seed = 42  # random seed
random.seed(seed)
torch.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # device for tensors
print("Device:", device)

In [None]:
# Load dataset (Wikitext-2, raw version)
from datasets import load_dataset

ds = load_dataset("wikitext", "wikitext-2-raw-v1")  # dataset object
dataset_name = "wikitext-2-raw-v1"  # dataset name

print("Dataset:", dataset_name)
print(ds)

# Concatenate text with explicit newlines to preserve some structure
train_text = "\n".join(ds["train"]["text"])  # training text
val_text = "\n".join(ds["validation"]["text"]) if "validation" in ds else ""  # validation text

# Keep a single large text for tokenizer training and LM training
full_text = train_text + "\n" + val_text  # concatenated dataset text
print("Characters in full_text:", len(full_text))
print("Sample snippet:\n", full_text[:400])

## Tokenization and BPE (intuition)

Why tokenization?
- Models operate on **discrete symbols**, but raw text is a stream of characters.
- Word-level tokens create huge vocabularies and break on new words.
- Character-level tokens are robust but long and slow for models.
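
To make the trade-off concrete, here is a small sketch comparing the two extremes. The sentence and the tiny `known_words` vocabulary are made up for illustration and are not part of the notebook's pipeline.

```python
# Illustrative only: a made-up sentence and a made-up word-level vocabulary.
sentence = "tokenizers handle untokenizable words"
known_words = {"tokenizers", "handle", "words"}  # hypothetical vocab

word_tokens = sentence.split()
char_tokens = list(sentence)

# Word-level: short sequence, but any unseen word collapses to <unk>.
word_level = [w if w in known_words else "<unk>" for w in word_tokens]

print("word-level:", word_level, "->", len(word_level), "tokens")
print("char-level:", len(char_tokens), "tokens, no <unk> needed")
```

The word-level sequence is short but loses the unseen word entirely; the character-level sequence keeps everything at the cost of a much longer input.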

**BPE** is a compromise:
1. Start with a small vocabulary (often characters).
2. Count the most common adjacent pairs.
3. Merge the most common pair into a new token.
4. Repeat to build subword units that capture frequent patterns.

Result: a compact vocabulary that can represent **any word** while still compressing frequent word pieces.
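
The merge loop described above can be sketched on a toy corpus. This is a minimal illustration, not the tokenizer trained below: the words, their frequencies, and the helper names `count_pairs` and `merge_pair` are all made up for the example.

```python
from collections import Counter

# Toy corpus: each word is a tuple of symbols with a made-up frequency.
corpus = {
    ("h", "u", "g"): 10,
    ("p", "u", "g"): 5,
    ("p", "u", "n"): 12,
    ("b", "u", "n"): 4,
    ("h", "u", "g", "s"): 5,
}

def count_pairs(corpus):
    """Count adjacent symbol pairs, weighted by word frequency."""
    pairs = Counter()
    for symbols, freq in corpus.items():
        for left, right in zip(symbols, symbols[1:]):
            pairs[(left, right)] += freq
    return pairs

def merge_pair(corpus, pair):
    """Replace every adjacent occurrence of `pair` with one merged symbol."""
    merged = {}
    for symbols, freq in corpus.items():
        out, i = [], 0
        while i < len(symbols):
            if symbols[i:i + 2] == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        merged[tuple(out)] = merged.get(tuple(out), 0) + freq
    return merged

merges = []
for step in range(2):
    pairs = count_pairs(corpus)
    best = max(pairs, key=pairs.get)  # most frequent adjacent pair
    corpus = merge_pair(corpus, best)
    merges.append(best)
    print(f"merge {step + 1}: {best} -> {''.join(best)}")

print(corpus)
```

Each merge adds exactly one new symbol to the vocabulary, so the number of merges directly controls the final vocabulary size.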


In [None]:
# Train a tiny BPE tokenizer
# Uses the `tokenizers` library if available, otherwise a minimal pure-Python BPE fallback.

SPECIAL_TOKENS = ["<pad>", "<unk>"]  # special token strings

# --- Simple BPE fallback ---

def _get_stats(vocab):
    pairs = defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pairs[(symbols[i], symbols[i + 1])] += freq
    return pairs


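def _merge_vocab(pair, vocab):
    # Match the pair only at symbol boundaries. A plain str.replace would also
    # match inside longer symbols (e.g. "a b" inside "xa b", producing "xab").
    import re
    pattern = re.compile(r"(?<!\S)" + re.escape(" ".join(pair)) + r"(?!\S)")
    replacement = "".join(pair)
    new_vocab = {}
    for word, freq in vocab.items():
        # Use a function so backslashes in the replacement are taken literally
        new_word = pattern.sub(lambda m: replacement, word)
        new_vocab[new_word] = freq
    return new_vocab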


def _train_simple_bpe(text, vocab_size=2000, max_merges=200):
    # Build word frequency table
    words = [w for w in text.split() if w]
    vocab = defaultdict(int)
    for w in words:
        vocab[" ".join(list(w)) + " </w>"] += 1

    # Estimate how many merges we can do
    symbols = set()
    for word in vocab:
        symbols.update(word.split())
    base_vocab = len(symbols) + len(SPECIAL_TOKENS)
    num_merges = max(0, min(max_merges, vocab_size - base_vocab))

    merges = []
    for _ in range(num_merges):
        pairs = _get_stats(vocab)
        if not pairs:
            break
        best = max(pairs, key=pairs.get)
        vocab = _merge_vocab(best, vocab)
        merges.append(best)

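    # Build final symbol set: keep the base symbols collected above and add
    # every merged symbol, so intermediate pieces produced at encode time
    # map to a real id instead of falling back to <unk>.
    symbols = set(symbols)
    symbols.update(a + b for a, b in merges)
    for word in vocab:
        symbols.update(word.split())

    tokens = list(SPECIAL_TOKENS) + sorted(symbols)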
    token_to_id = {t: i for i, t in enumerate(tokens)}
    id_to_token = {i: t for t, i in token_to_id.items()}

    def bpe_encode_word(word):
        symbols = list(word) + ["</w>"]
        for a, b in merges:
            i = 0
            new_symbols = []
            while i < len(symbols):
                if i < len(symbols) - 1 and symbols[i] == a and symbols[i + 1] == b:
                    new_symbols.append(a + b)
                    i += 2
                else:
                    new_symbols.append(symbols[i])
                    i += 1
            symbols = new_symbols
        return symbols

    def encode_fn(text_in):
        ids = []
        for w in text_in.split():
            for sym in bpe_encode_word(w):
                ids.append(token_to_id.get(sym, token_to_id["<unk>"]))
        return ids

    def decode_fn(ids_in):
        tokens = [id_to_token[i] for i in ids_in]
        text_out = "".join([" " if t == "</w>" else t for t in tokens])
        # Normalize whitespace a bit
        return " ".join(text_out.split())

    model = {
        "merges": merges,
        "token_to_id": token_to_id,
    }
    return model, encode_fn, decode_fn, len(tokens)


def train_tokenizer(text, vocab_size=2000, save_path="tokenizer_bpe.json"):
    text = text[:BPE_TRAIN_CHARS]
    try:
        from tokenizers import Tokenizer, models, trainers, pre_tokenizers, decoders, normalizers

        tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
        tokenizer.normalizer = normalizers.Sequence([normalizers.NFKC()])
        tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
        trainer = trainers.BpeTrainer(vocab_size=vocab_size, special_tokens=SPECIAL_TOKENS)

        tokenizer.train_from_iterator([text], trainer=trainer)
        tokenizer.save(save_path)
        tokenizer = Tokenizer.from_file(save_path)

        def encode_fn(s):
            return tokenizer.encode(s).ids

        def decode_fn(ids):
            return tokenizer.decode(ids)

        vocab_size_out = len(tokenizer.get_vocab())
        pad_id = tokenizer.token_to_id("<pad>")
        return {
            "type": "tokenizers",
            "tokenizer": tokenizer,
            "encode": encode_fn,
            "decode": decode_fn,
            "vocab_size": vocab_size_out,
            "pad_id": pad_id,
        }
    except Exception as e:
        print("tokenizers not available or failed. Using simple BPE fallback.")
        print("Reason:", repr(e))
        model, encode_fn, decode_fn, vocab_size_out = _train_simple_bpe(text, vocab_size=vocab_size)
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(model, f)
        with open(save_path, "r", encoding="utf-8") as f:
            _ = json.load(f)

        pad_id = model["token_to_id"]["<pad>"]
        return {
            "type": "simple_bpe",
            "tokenizer": model,
            "encode": encode_fn,
            "decode": decode_fn,
            "vocab_size": vocab_size_out,
            "pad_id": pad_id,
        }


bpe = train_tokenizer(full_text, vocab_size=TOKENIZER_VOCAB_SIZE)  # BPE tokenizer wrapper
encode = bpe["encode"]  # text-to-ids function
decode = bpe["decode"]  # ids-to-text function

print("Tokenizer type:", bpe["type"])
print("Vocab size:", bpe["vocab_size"])

sample_text = "Tokenization lets models read text like numbers."  # sample text string
encoded = encode(sample_text)  # encoded token ids
print("Sample text:", sample_text)
print("Encoded ids:", encoded[:30])
print("Decoded text:", decode(encoded[:30]))

In [None]:
# Encode full text and build (context, target) pairs

all_ids = encode(full_text)  # all token ids
print("Total tokens:", len(all_ids))

split = int(0.9 * len(all_ids))  # train/val split index
train_ids = all_ids[:split]  # token ids for training split
val_ids = all_ids[split:]  # token ids for validation split

class NGramDataset(Dataset):
    def __init__(self, ids, context_size):
        self.ids = torch.tensor(ids, dtype=torch.long)
        self.context_size = context_size

    def __len__(self):
        return len(self.ids) - self.context_size

    def __getitem__(self, idx):
        context = self.ids[idx:idx + self.context_size]
        target = self.ids[idx + self.context_size]
        return context, target

train_ds = NGramDataset(train_ids, CONTEXT_SIZE)  # training dataset
val_ds = NGramDataset(val_ids, CONTEXT_SIZE)  # validation dataset

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)  # training DataLoader
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)  # validation DataLoader

# Peek at a batch
batch = next(iter(train_loader))  # current batch
ctx, tgt = batch
print("Context batch shape:", ctx.shape)
print("Target batch shape:", tgt.shape)
print("Decoded context example:", decode(ctx[0].tolist()))
print("Decoded target example:", decode([tgt[0].item()]))

## Model: Embedding + MLP

We embed each token into a vector, **flatten the context**, and run it through a 2-layer MLP to predict the next token. This is a neural version of an n-gram model:

- Input: `B x CONTEXT_SIZE` token IDs
- Embedding: `B x CONTEXT_SIZE x EMBED_DIM`
- Flatten: `B x (CONTEXT_SIZE * EMBED_DIM)`
- MLP -> logits over vocabulary
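
The lookup-and-flatten step can be illustrated in plain Python with tiny, made-up sizes (a 4-token vocabulary, context of 3, embedding dimension of 2). The `embedding_table` values are arbitrary and only there to make the ordering visible.

```python
# Tiny made-up sizes so the shapes are easy to read.
context_size, embed_dim = 3, 2
embedding_table = {  # hypothetical 4-token vocabulary
    0: [0.0, 0.1],
    1: [1.0, 1.1],
    2: [2.0, 2.1],
    3: [3.0, 3.1],
}

context_ids = [2, 0, 3]  # one example, i.e. B = 1

# Embedding lookup: one vector per context position -> (context_size, embed_dim)
embedded = [embedding_table[i] for i in context_ids]

# Flatten: concatenate the vectors in order -> (context_size * embed_dim,)
flat = [value for vector in embedded for value in vector]

print(embedded)
print(flat)
```

Because the vectors are concatenated in order, each context position feeds its own slice of the first linear layer's weights, which is how this model distinguishes token order.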


In [None]:
class NGramMLP(nn.Module):
    def __init__(self, vocab_size, context_size, embed_dim, hidden_dim):
        super().__init__()
        self.context_size = context_size
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(context_size * embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, vocab_size),
        )

    def forward(self, x):
        # x: (B, C)
        emb = self.embed(x)  # (B, C, D)
        flat = emb.view(x.size(0), -1)  # (B, C*D)
        logits = self.mlp(flat)  # (B, V)
        return logits

model = NGramMLP(  # model instance
    vocab_size=bpe["vocab_size"],
    context_size=CONTEXT_SIZE,
    embed_dim=EMBED_DIM,
    hidden_dim=HIDDEN_DIM,
).to(device)

print(model)

## Training loop

We minimize cross-entropy loss between predicted logits and the true next token. We also report **perplexity**, a standard language-model metric:

`perplexity = exp(loss)`
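
Here `loss` is the mean cross-entropy in nats (natural log), which is what `F.cross_entropy` returns. A useful reference point is a model that guesses uniformly over the vocabulary; the sketch below assumes a vocabulary of exactly 2000 tokens, which may differ slightly from the trained tokenizer's actual size.

```python
import math

def perplexity(mean_loss_nats):
    """Perplexity for a mean cross-entropy loss measured in nats."""
    return math.exp(mean_loss_nats)

vocab_size = 2000  # assumed; matches TOKENIZER_VOCAB_SIZE in the config cell

# A uniform guesser assigns probability 1/V to every token,
# so its cross-entropy is log(V) and its perplexity is V.
uniform_loss = math.log(vocab_size)
print(f"uniform loss: {uniform_loss:.3f} | perplexity: {perplexity(uniform_loss):.1f}")
```

A validation loss at or near this value means the model has learned essentially nothing yet; training should push it well below.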


In [None]:
def evaluate(model, loader):
    model.eval()
    total_loss, total_count = 0.0, 0
    with torch.no_grad():
        for ctx, tgt in loader:
            ctx = ctx.to(device)
            tgt = tgt.to(device)
            logits = model(ctx)
            # Sum per-example losses so a smaller final batch is weighted correctly
            loss = F.cross_entropy(logits, tgt, reduction="sum")
            total_loss += loss.item()
            total_count += tgt.size(0)
    model.train()
    return total_loss / max(1, total_count)


@torch.no_grad()  # inspection only; no need to build the autograd graph
def show_predictions(model, dataset, num_examples=3):
    model.eval()
    for _ in range(num_examples):
        idx = random.randint(0, len(dataset) - 1)
        ctx, tgt = dataset[idx]
        ctx = ctx.unsqueeze(0).to(device)
        logits = model(ctx)
        pred_id = int(torch.argmax(logits, dim=-1).item())

        ctx_text = decode(ctx[0].tolist())
        pred_text = decode([pred_id])
        tgt_text = decode([int(tgt.item())])

        print("Context:", ctx_text)
        print("Pred next:", pred_text, "| Target:", tgt_text)
        print("-")
    model.train()


optimizer = torch.optim.AdamW(model.parameters(), lr=LR)  # optimizer instance

train_iter = iter(train_loader)  # training iterator
# optimization steps
for step in range(1, MAX_STEPS + 1):
    try:
        ctx, tgt = next(train_iter)
    except StopIteration:
        train_iter = iter(train_loader)
        ctx, tgt = next(train_iter)

    ctx = ctx.to(device)
    tgt = tgt.to(device)

    logits = model(ctx)
    loss = F.cross_entropy(logits, tgt)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % EVAL_EVERY == 0 or step == 1:
        train_loss = loss.item()
        val_loss = evaluate(model, val_loader)
        print(f"Step {step:4d} | train loss {train_loss:.4f} | val loss {val_loss:.4f} | val ppl {math.exp(val_loss):.2f}")
        show_predictions(model, val_ds, num_examples=2)

## Inference: autoregressive generation

We generate one token at a time, feeding the last `CONTEXT_SIZE` tokens back into the model. Temperature controls randomness:

- **T < 1.0**: sharper, more deterministic
- **T = 1.0**: baseline
- **T > 1.0**: more diverse but noisier
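
The effect of temperature can be seen on a handful of made-up logits. This plain-Python sketch mirrors the `logits / temperature` followed by softmax used in the generation cell; the helper name `softmax_with_temperature` and the example logits are illustrative.

```python
import math

def softmax_with_temperature(logits, temperature):
    """Softmax over logits divided by temperature (plain-Python sketch)."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]  # made-up scores for three tokens
results = {}
for t in (0.5, 1.0, 2.0):
    results[t] = softmax_with_temperature(logits, t)
    print(f"T={t}: {[round(p, 3) for p in results[t]]}")
```

Lower temperatures concentrate probability on the top-scoring token, while higher temperatures flatten the distribution toward uniform.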


In [None]:
@torch.no_grad()  # sampling needs no gradients; skip building the autograd graph
def generate(model, prompt, max_new_tokens=80, temperature=1.0):
    model.eval()
    ids = encode(prompt)
    pad_id = bpe["pad_id"]

    for _ in range(max_new_tokens):
        context = ids[-CONTEXT_SIZE:]
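        if len(context) < CONTEXT_SIZE:
            # Left-pad short prompts. Training contexts are unlikely to contain
            # <pad>, so the first few generated tokens may be lower quality.
            context = [pad_id] * (CONTEXT_SIZE - len(context)) + context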

        x = torch.tensor([context], dtype=torch.long).to(device)
        logits = model(x)[0] / max(1e-6, temperature)
        probs = F.softmax(logits, dim=-1)
        next_id = int(torch.multinomial(probs, num_samples=1).item())
        ids.append(next_id)

    model.train()
    return decode(ids)

prompt = "The meaning of tokenization"  # input prompt string
print(generate(model, prompt, max_new_tokens=80, temperature=0.9))

## Scaling notes (toy vs production-ish)

Toy settings (this notebook):
- Short context (16 tokens)
- Small embeddings (128)
- 2-layer MLP
- Tiny vocab (~2k)

Production-ish knobs to turn:
- Longer context (128-1024+)
- Larger embeddings (512-2048)
- Deeper MLP or transformer layers
- Larger BPE vocab (30k-100k)

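Each knob increases memory and compute. In this architecture the first linear layer has `CONTEXT_SIZE * EMBED_DIM * HIDDEN_DIM` weights and the output layer has `HIDDEN_DIM * vocab_size`, so context length, embedding size, and vocabulary size each scale the parameter count directly.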
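
As a rough illustration of how the knobs scale, the sketch below counts parameters for the Embedding + 2-layer MLP used in this notebook. The helper `ngram_mlp_params` is hypothetical, the toy vocabulary is assumed to be exactly 2000, and the 30k vocabulary is an assumed value from the low end of the range listed above.

```python
def ngram_mlp_params(vocab_size, context_size, embed_dim, hidden_dim):
    """Parameter count for Embedding -> Linear -> ReLU -> Linear (with biases)."""
    embedding = vocab_size * embed_dim
    hidden = context_size * embed_dim * hidden_dim + hidden_dim
    output = hidden_dim * vocab_size + vocab_size
    return embedding + hidden + output

# Toy settings from the config cell (vocab size assumed to be exactly 2000).
toy = ngram_mlp_params(vocab_size=2000, context_size=16, embed_dim=128, hidden_dim=256)

# "Production-ish" values from the config cell, with an assumed 30k vocabulary.
big = ngram_mlp_params(vocab_size=30_000, context_size=256, embed_dim=512, hidden_dim=2048)

print(f"toy:            {toy:,} parameters")
print(f"production-ish: {big:,} parameters")
print(f"ratio:          {big / toy:.0f}x")
```

Most of the growth comes from the first linear layer, because its input width is the product of context length and embedding size.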


## Exercises

1. Increase `CONTEXT_SIZE` and compare validation perplexity.
2. Change the number of BPE merges (or vocab size). How do the tokens change?
3. Add dropout to the MLP and observe overfitting behavior.
4. Replace the MLP with a single linear layer and compare performance.
5. Try different temperatures during generation and evaluate coherence.
