# GRU Language Model over BPE Tokens

This notebook teaches the historical importance of **RNNs** for language modeling and implements a small **GRU-based** autoregressive model on BPE tokens.

You will:
1. Load a text dataset (Wikitext-2 if available, otherwise tiny Shakespeare).
2. Load a BPE tokenizer from Lesson 02 if it exists, otherwise train a small one here.
3. Build fixed-length training sequences with **teacher forcing**.
4. Train a GRU language model with gradient clipping.
5. Generate text with temperature and top-k sampling.

The focus is clarity and interpretability rather than speed.


In [None]:
# Setup and config
import json
import math
import os
import random
import re
from collections import defaultdict
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset

# Configuration (toy settings) -- small enough to train quickly on modest hardware.
SEQ_LEN = 64        # tokens per training window
BATCH_SIZE = 32     # windows per optimizer step
EMBED_DIM = 256     # width of each token embedding vector
HIDDEN_DIM = 256    # size of the GRU hidden state
LR = 3e-4           # AdamW learning rate
MAX_STEPS = 800     # total optimizer steps
EVAL_EVERY = 100    # run validation every this many steps

# Production-ish example values (much more compute required):
# SEQ_LEN = 256
# BATCH_SIZE = 128
# EMBED_DIM = 512
# HIDDEN_DIM = 512
# MAX_STEPS = 20_000

# Seed Python's and PyTorch's RNGs (in that order) so runs are repeatable.
seed = 42
for _seed_fn in (random.seed, torch.manual_seed):
    _seed_fn(seed)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

In [None]:
# Load dataset: Wikitext-2 (raw) through Hugging Face `datasets`. If that fails
# (package missing, no network / cache, hub error) fall back to tiny
# Shakespeare, as promised in the introduction.
TINY_SHAKESPEARE_URL = (
    "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
)

try:
    from datasets import load_dataset

    ds = load_dataset("wikitext", "wikitext-2-raw-v1")
    dataset_name = "wikitext-2-raw-v1"
    # Concatenate text with explicit newlines to preserve some structure
    train_text = "\n".join(ds["train"]["text"])
    val_text = "\n".join(ds["validation"]["text"]) if "validation" in ds else ""
except Exception as e:
    # Best-effort fallback: report why Wikitext-2 failed, then download a small corpus.
    # If this download also fails, the error propagates -- there is no data to train on.
    print("Wikitext-2 unavailable; falling back to tiny Shakespeare.")
    print("Reason:", repr(e))
    from urllib.request import urlopen

    with urlopen(TINY_SHAKESPEARE_URL, timeout=60) as resp:
        train_text = resp.read().decode("utf-8")
    ds = None
    dataset_name = "tiny_shakespeare"
    val_text = ""  # validation data comes from the later 90/10 token split

print("Dataset:", dataset_name)
if ds is not None:
    print(ds)

full_text = train_text + "\n" + val_text
print("Characters in full_text:", len(full_text))
print("Sample snippet:\n", full_text[:400])

## Tokenization: load or train BPE

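If `./tokenizer_bpe.json` exists (from Lesson 02), we load it. Otherwise we train a small BPE tokenizer in this notebook and save it to that path. Training uses the Hugging Face `tokenizers` library if it is installed, or a minimal pure-Python BPE implementation as a fallback.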

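Why BPE? Word-level vocabularies are huge and cannot represent unseen words, while character-level sequences are very long. BPE sits in between. Starting from single characters, it repeatedly merges the most frequent adjacent pair of symbols, learning a vocabulary of frequent subword units.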


In [None]:
# Reserved token strings that every tokenizer built here includes in its vocabulary.
SPECIAL_TOKENS = ["<pad>", "<unk>"]
# File the tokenizer is loaded from (if present) or saved to after training.
TOKENIZER_PATH = Path("tokenizer_bpe.json")
# Target vocabulary size when a new tokenizer has to be trained.
TOKENIZER_VOCAB_SIZE = 2_000
# Only this many leading characters of the corpus are used for tokenizer training.
BPE_TRAIN_CHARS = 200_000

# --- Simple BPE fallback ---

def _get_stats(vocab):
    pairs = defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pairs[(symbols[i], symbols[i + 1])] += freq
    return pairs


def _merge_vocab(pair, vocab):
    bigram = " ".join(pair)
    replacement = "".join(pair)
    new_vocab = {}
    for word, freq in vocab.items():
        new_word = word.replace(bigram, replacement)
        new_vocab[new_word] = freq
    return new_vocab


def _train_simple_bpe(text, vocab_size=2000, max_merges=200):
    """Train a minimal pure-Python BPE tokenizer on ``text``.

    Used only when the `tokenizers` library is unavailable. Words come from
    ``text.split()``; each is spelled as its characters plus an end-of-word
    marker ``</w>``. The most frequent adjacent symbol pair is then merged
    repeatedly, with ties going to the pair seen first.

    Args:
        text: training corpus.
        vocab_size: target size of the final vocabulary (special tokens + base
            symbols + merges).
        max_merges: hard cap on the number of merges, keeping this pure-Python
            loop fast; the resulting vocabulary may be well below ``vocab_size``.

    Returns:
        ``(model, encode_fn, decode_fn, n_tokens)``. ``model`` is a
        JSON-serializable dict with ``"merges"`` and ``"token_to_id"``, and
        ``n_tokens`` is the vocabulary size.
    """
    # Word-frequency table keyed by space-separated spellings, e.g. "t h e </w>".
    vocab = defaultdict(int)
    for w in text.split():
        vocab[" ".join(w) + " </w>"] += 1

    symbols = set()
    for word in vocab:
        symbols.update(word.split())
    base_vocab = len(symbols) + len(SPECIAL_TOKENS)
    num_merges = max(0, min(max_merges, vocab_size - base_vocab))

    merges = []
    for _ in range(num_merges):
        pairs = _get_stats(vocab)
        if not pairs:
            break
        best = max(pairs, key=pairs.get)
        vocab = _merge_vocab(best, vocab)
        merges.append(best)

    # Final vocabulary: special tokens first, then every symbol that survives merging.
    symbols = set()
    for word in vocab:
        symbols.update(word.split())
    # A corpus symbol that equals a special token must not get a second id,
    # otherwise the first id would have no entry in id_to_token.
    symbols -= set(SPECIAL_TOKENS)

    tokens = list(SPECIAL_TOKENS) + sorted(symbols)
    token_to_id = {t: i for i, t in enumerate(tokens)}
    id_to_token = {i: t for t, i in token_to_id.items()}
    unk_id = token_to_id["<unk>"]

    def bpe_encode_word(word):
        # Replay the learned merges, in learning order, on one word.
        symbols = list(word) + ["</w>"]
        for a, b in merges:
            i = 0
            new_symbols = []
            while i < len(symbols):
                if i < len(symbols) - 1 and symbols[i] == a and symbols[i + 1] == b:
                    new_symbols.append(a + b)
                    i += 2
                else:
                    new_symbols.append(symbols[i])
                    i += 1
            symbols = new_symbols
        return symbols

    # Memoize ids per distinct word: a corpus repeats the same words many times,
    # and replaying every merge on every occurrence dominates encoding time.
    # The cache grows with the number of distinct words encoded.
    word_ids_cache = {}

    def encode_fn(text_in):
        ids = []
        for w in text_in.split():
            word_ids = word_ids_cache.get(w)
            if word_ids is None:
                word_ids = [token_to_id.get(sym, unk_id) for sym in bpe_encode_word(w)]
                word_ids_cache[w] = word_ids
            ids.extend(word_ids)
        return ids

    def decode_fn(ids_in):
        # Merged tokens carry the marker as a suffix (e.g. "the</w>"), so replace
        # it inside the joined string, not only when it is a whole token.
        text_out = "".join(id_to_token[i] for i in ids_in).replace("</w>", " ")
        return " ".join(text_out.split())

    model = {
        "merges": merges,
        "token_to_id": token_to_id,
    }
    return model, encode_fn, decode_fn, len(tokens)


def _load_simple_bpe(path):
    model = json.loads(Path(path).read_text())
    merges = [tuple(m) for m in model.get("merges", [])]
    token_to_id = model.get("token_to_id", {})
    id_to_token = {i: t for t, i in token_to_id.items()}

    def bpe_encode_word(word):
        symbols = list(word) + ["</w>"]
        for a, b in merges:
            i = 0
            new_symbols = []
            while i < len(symbols):
                if i < len(symbols) - 1 and symbols[i] == a and symbols[i + 1] == b:
                    new_symbols.append(a + b)
                    i += 2
                else:
                    new_symbols.append(symbols[i])
                    i += 1
            symbols = new_symbols
        return symbols

    def encode_fn(text_in):
        ids = []
        for w in text_in.split():
            for sym in bpe_encode_word(w):
                ids.append(token_to_id.get(sym, token_to_id["<unk>"]))
        return ids

    def decode_fn(ids_in):
        tokens = [id_to_token[i] for i in ids_in]
        text_out = "".join([" " if t == "</w>" else t for t in tokens])
        return " ".join(text_out.split())

    vocab_size_out = len(token_to_id)
    pad_id = token_to_id.get("<pad>", 0)
    return model, encode_fn, decode_fn, vocab_size_out, pad_id


def build_or_load_tokenizer(text, vocab_size=2000):
    """Load the tokenizer at ``TOKENIZER_PATH``, or train and save a new one.

    Resolution order:
      1. the existing file as a Hugging Face `tokenizers` JSON,
      2. the existing file as this notebook's simple-BPE JSON,
      3. a new `tokenizers` BPE trained on ``text[:BPE_TRAIN_CHARS]``,
      4. the pure-Python simple BPE if `tokenizers` is missing or fails.

    Returns:
        A dict with keys ``type``, ``tokenizer``, ``encode``, ``decode``,
        ``vocab_size``, ``pad_id`` and ``path``. ``pad_id`` may be None for a
        loaded `tokenizers` file that has no ``<pad>`` token.
    """
    text = text[:BPE_TRAIN_CHARS]

    def bundle(kind, tokenizer, encode_fn, decode_fn, vocab_size_out, pad_id):
        # Uniform result shape shared by every branch below.
        return {
            "type": kind,
            "tokenizer": tokenizer,
            "encode": encode_fn,
            "decode": decode_fn,
            "vocab_size": vocab_size_out,
            "pad_id": pad_id,
            "path": str(TOKENIZER_PATH),
        }

    def hf_bundle(tokenizer):
        # Wrap a Hugging Face `tokenizers.Tokenizer` in the common result shape.
        def encode_fn(s):
            return tokenizer.encode(s).ids

        def decode_fn(ids):
            return tokenizer.decode(ids)

        return bundle(
            "tokenizers",
            tokenizer,
            encode_fn,
            decode_fn,
            len(tokenizer.get_vocab()),
            tokenizer.token_to_id("<pad>"),
        )

    if TOKENIZER_PATH.exists():
        # Try tokenizers JSON first
        try:
            from tokenizers import Tokenizer

            return hf_bundle(Tokenizer.from_file(str(TOKENIZER_PATH)))
        except Exception as e:
            print("Tokenizer file exists but tokenizers load failed. Falling back.")
            print("Reason:", repr(e))
        # Only reached when the tokenizers load above failed.
        try:
            model, encode_fn, decode_fn, vocab_size_out, pad_id = _load_simple_bpe(TOKENIZER_PATH)
            return bundle("simple_bpe", model, encode_fn, decode_fn, vocab_size_out, pad_id)
        except Exception as e2:
            print("Simple BPE load failed. Training new tokenizer.")
            print("Reason:", repr(e2))

    # Train new tokenizer
    try:
        from tokenizers import Tokenizer, decoders, models, normalizers, pre_tokenizers, trainers

        # Mark word ends with "</w>" so BPEDecoder can turn them back into spaces.
        # Without a decoder, Tokenizer.decode joins every sub-word piece with a
        # space (e.g. "sequ ent ially").
        tokenizer = Tokenizer(models.BPE(unk_token="<unk>", end_of_word_suffix="</w>"))
        tokenizer.normalizer = normalizers.Sequence([normalizers.NFKC()])
        tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
        tokenizer.decoder = decoders.BPEDecoder(suffix="</w>")
        trainer = trainers.BpeTrainer(
            vocab_size=vocab_size,
            special_tokens=SPECIAL_TOKENS,
            end_of_word_suffix="</w>",
        )

        tokenizer.train_from_iterator([text], trainer=trainer)
        tokenizer.save(str(TOKENIZER_PATH))
        return hf_bundle(tokenizer)
    except Exception as e:
        # Deliberate best-effort fallback: any failure above (import or training)
        # switches to the pure-Python tokenizer instead of aborting the lesson.
        print("tokenizers not available or failed. Using simple BPE.")
        print("Reason:", repr(e))
        model, encode_fn, decode_fn, vocab_size_out = _train_simple_bpe(text, vocab_size=vocab_size)
        TOKENIZER_PATH.write_text(json.dumps(model))
        pad_id = model["token_to_id"]["<pad>"]
        return bundle("simple_bpe", model, encode_fn, decode_fn, vocab_size_out, pad_id)


# Build (or load) the tokenizer and expose its encode/decode callables globally.
bpe = build_or_load_tokenizer(full_text, vocab_size=TOKENIZER_VOCAB_SIZE)
encode, decode = bpe["encode"], bpe["decode"]

for label, key in (("Tokenizer type:", "type"), ("Tokenizer path:", "path"), ("Vocab size:", "vocab_size")):
    print(label, bpe[key])

# Round-trip a short sentence to show what the tokenizer produces.
sample_text = "RNNs process tokens sequentially."
encoded = encode(sample_text)
preview_ids = encoded[:20]
print("Sample text:", sample_text)
print("Encoded ids:", preview_ids)
print("Decoded text:", decode(preview_ids))

## Prepare training sequences (teacher forcing)

We build fixed-length sequences of token IDs. For each sequence, the model **sees** tokens `x` and is trained to **predict** the next tokens `y`. `y` is `x` shifted one position to the left, so the target at position `t` is the token at position `t + 1`.

This is called **teacher forcing**: at every step the model is fed the ground-truth previous tokens, not its own earlier predictions.


In [None]:
# Encode the whole corpus once, then hold out the final 10% of tokens for validation.
all_ids = encode(full_text)
print("Total tokens:", len(all_ids))

split = int(len(all_ids) * 0.9)
train_ids, val_ids = all_ids[:split], all_ids[split:]

class SequenceDataset(Dataset):
    """Sliding windows over a token-id sequence for next-token prediction.

    Item ``idx`` is ``(x, y)`` with ``x = ids[idx : idx + seq_len]`` and
    ``y = ids[idx + 1 : idx + seq_len + 1]``, i.e. the same window shifted one
    token to the left (the teacher-forcing targets). A window starts at every
    position (stride 1), so consecutive items overlap by ``seq_len - 1`` tokens.
    """

    def __init__(self, ids, seq_len):
        self.ids = torch.tensor(ids, dtype=torch.long)
        self.seq_len = seq_len

    def __len__(self):
        # A window needs seq_len + 1 tokens, so start indices run from 0 to
        # len(ids) - seq_len - 1 inclusive: len(ids) - seq_len windows in total.
        # Clamp at 0 so a too-short corpus gives an empty dataset, not a negative length.
        return max(0, len(self.ids) - self.seq_len)

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        chunk = self.ids[idx : idx + self.seq_len + 1]
        x = chunk[:-1]
        y = chunk[1:]
        return x, y

train_ds = SequenceDataset(train_ids, SEQ_LEN)
val_ds = SequenceDataset(val_ids, SEQ_LEN)

# Training samples every stride-1 window, which gives many distinct windows.
# For validation that would score each token ~SEQ_LEN times and make every
# evaluation ~SEQ_LEN times slower, so evaluate only non-overlapping windows.
val_eval_ds = Subset(val_ds, range(0, len(val_ds), SEQ_LEN))

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
val_loader = DataLoader(val_eval_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
print("Train windows:", len(train_ds), "| val windows (non-overlapping):", len(val_eval_ds))

x0, y0 = next(iter(train_loader))
print("Input batch shape:", x0.shape)
print("Target batch shape:", y0.shape)
print("Decoded input sample:", decode(x0[0][:20].tolist()))
print("Decoded target sample:", decode(y0[0][:20].tolist()))

## Model: Embedding + GRU + Linear

An RNN processes sequences step-by-step, maintaining a **hidden state** that summarizes the past.
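A **GRU** (gated recurrent unit) is a gated RNN. Its update and reset gates control how much of the previous hidden state is kept. This helps gradients flow across many time steps and makes training more stable than with a vanilla RNN.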

We use:
- Embedding layer
- GRU layer
- Linear projection to vocabulary logits

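Output shape: `(B, T, V)`, i.e. batch size × sequence length × vocabulary size (one logit per vocabulary entry at every position).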


In [None]:
class GRULM(nn.Module):
    """Token-level language model: Embedding -> GRU -> Linear over the vocabulary.

    Args:
        vocab_size: number of token ids (embedding rows and output logits).
        embed_dim: token embedding width.
        hidden_dim: GRU hidden-state size.
        num_layers: number of stacked GRU layers (default 1, the original model).
        dropout: dropout between stacked GRU layers; only used when num_layers > 1.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers=1, dropout=0.0):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(
            embed_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            # nn.GRU applies dropout only between layers and warns if it is set
            # for a single layer, so pass 0.0 in that case.
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        """Return ``(logits, hidden)`` for token ids ``x``.

        x: (B, T) token ids; hidden: optional initial state (num_layers, B, H).
        logits: (B, T, V); hidden: final GRU state (num_layers, B, H).
        """
        emb = self.embed(x)  # (B, T, D)
        out, hidden = self.gru(emb, hidden)  # out: (B, T, H)
        logits = self.fc(out)  # (B, T, V)
        return logits, hidden

# Size the model from the tokenizer's vocabulary and the config, then place it on the device.
lm_kwargs = dict(vocab_size=bpe["vocab_size"], embed_dim=EMBED_DIM, hidden_dim=HIDDEN_DIM)
model = GRULM(**lm_kwargs).to(device)
print(model)

## Training loop with gradient clipping

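We flatten the `(B, T, V)` logits to `(B*T, V)` and compute cross-entropy against the flattened `(B*T,)` targets.
Gradient clipping (`nn.utils.clip_grad_norm_` with `max_norm=1.0`) rescales the gradients whenever their global norm exceeds the threshold. This helps prevent exploding gradients in RNNs.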


In [None]:
def evaluate(model, loader, device=None):
    """Return the mean per-token cross-entropy of ``model`` over ``loader``.

    Losses are summed over every target token and divided by the token count,
    so a smaller final batch (``drop_last=False``) is weighted by its size
    instead of counting as much as a full batch.

    Args:
        model: module whose ``forward(x)`` returns ``(logits, hidden)`` with
            logits shaped (B, T, V).
        loader: iterable of ``(x, y)`` token-id tensors shaped (B, T).
        device: where batches are moved; defaults to the device of the model's
            first parameter (CPU for a parameterless model).

    Returns:
        The mean loss as a float (0.0 for an empty loader). The model's
        train/eval mode is restored afterwards.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device)
                logits, _ = model(x)
                vocab_size = logits.size(-1)
                loss = F.cross_entropy(
                    logits.reshape(-1, vocab_size), y.reshape(-1), reduction="sum"
                )
                total_loss += loss.item()
                total_tokens += y.numel()
    finally:
        model.train(was_training)
    return total_loss / max(1, total_tokens)


def show_predictions(model, batch_x, batch_y, num_tokens=12, device=None):
    """Print greedy next-token predictions for the first sequence of a batch.

    Prints the first ``num_tokens`` input tokens, the model's argmax prediction
    at each of those positions, and the true next tokens, all decoded with the
    module-level ``decode``.

    Args:
        model: module whose ``forward(x)`` returns ``(logits, hidden)``.
        batch_x, batch_y: (B, T) token-id tensors; only row 0 is shown.
        num_tokens: how many leading positions to display.
        device: where to run the model; defaults to the device of the model's
            first parameter.

    The model's train/eval mode is restored afterwards.
    """
    if device is None:
        device = next(model.parameters()).device
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Only row 0 is displayed, so run just that row.
            logits, _ = model(batch_x[:1].to(device))
            pred_ids = torch.argmax(logits, dim=-1).cpu()

        print("Input:", decode(batch_x[0][:num_tokens].tolist()))
        print("Pred :", decode(pred_ids[0][:num_tokens].tolist()))
        print("Tgt  :", decode(batch_y[0][:num_tokens].tolist()))
    finally:
        model.train(was_training)


optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

batches = iter(train_loader)
for step in range(1, MAX_STEPS + 1):
    # Pull the next batch, restarting the (reshuffled) loader when an epoch ends.
    batch = next(batches, None)
    if batch is None:
        batches = iter(train_loader)
        batch = next(batches)
    x, y = (t.to(device) for t in batch)

    # Teacher-forced forward pass: predict y[:, t] from x[:, :t + 1].
    logits, _ = model(x)
    vocab_size = logits.size(-1)
    loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))

    optimizer.zero_grad()
    loss.backward()
    # Rescale gradients whose global norm exceeds 1.0 before the update.
    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    should_report = step == 1 or step % EVAL_EVERY == 0
    if not should_report:
        continue
    val_loss = evaluate(model, val_loader)
    print(f"Step {step:4d} | train loss {loss.item():.4f} | val loss {val_loss:.4f} | val ppl {math.exp(val_loss):.2f}")
    show_predictions(model, x.cpu(), y.cpu())

## Inference: autoregressive generation

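We generate tokens one by one while carrying the GRU hidden state forward.
We can control randomness with **temperature** and **top-k sampling**. Temperature divides the logits: values below 1 make sampling more conservative, values above 1 more random. Top-k sampling draws only from the `k` most likely tokens.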


In [None]:
def sample_next_token(logits, temperature=1.0, top_k=None):
    """Sample one token id from a 1-D tensor of next-token logits.

    Args:
        logits: shape (V,) unnormalized scores.
        temperature: the logits are divided by this; <1 sharpens, >1 flattens.
            Clamped to at least 1e-6, so 0 behaves like (near-)greedy decoding.
        top_k: if a positive int, sample only among the ``top_k`` highest
            logits (values above V are clamped to V); None or <= 0 disables it.

    Returns:
        The sampled token id as a Python int.
    """
    logits = logits / max(1e-6, temperature)
    if top_k is not None and top_k > 0:
        # torch.topk raises if k exceeds the number of logits.
        k = min(int(top_k), logits.size(-1))
        values, indices = torch.topk(logits, k)
        # Everything outside the top-k gets -inf, i.e. zero probability after softmax.
        masked = torch.full_like(logits, float("-inf"))
        masked.scatter_(-1, indices, values)
        logits = masked
    probs = F.softmax(logits, dim=-1)
    return int(torch.multinomial(probs, num_samples=1).item())


def generate(model, prompt, max_new_tokens=80, temperature=1.0, top_k=50):
    """Autoregressively extend ``prompt`` and return the decoded text.

    The prompt is encoded with the module-level ``encode``. All prompt tokens
    except the last are fed in a single pass to build the GRU hidden state.
    New tokens are then drawn one at a time with ``sample_next_token``, each
    fed back as the next input. An empty prompt starts from the pad token.

    Runs under ``torch.no_grad()`` (no autograd graph is built) and restores
    the model's train/eval mode afterwards. The returned text includes the prompt.
    """
    was_training = model.training
    model.eval()
    ids = list(encode(prompt))
    if not ids:
        pad_id = bpe["pad_id"]
        # A loaded `tokenizers` file without "<pad>" reports None; fall back to id 0.
        ids = [pad_id if pad_id is not None else 0]

    try:
        with torch.no_grad():
            hidden = None
            if len(ids) > 1:
                # Warm up the hidden state with the prompt prefix in one batched call;
                # the GRU state equals that of feeding the tokens one by one.
                prefix = torch.tensor([ids[:-1]], dtype=torch.long, device=device)
                _, hidden = model(prefix)

            cur = ids[-1]
            for _ in range(max_new_tokens):
                x = torch.tensor([[cur]], dtype=torch.long, device=device)
                logits, hidden = model(x, hidden)
                cur = sample_next_token(logits[0, -1], temperature=temperature, top_k=top_k)
                ids.append(cur)
    finally:
        model.train(was_training)
    return decode(ids)

# Sample a continuation of a fixed prompt with mild temperature and top-40 filtering.
prompt = "The meaning of sequence"
generated_text = generate(model, prompt, max_new_tokens=80, temperature=0.9, top_k=40)
print(generated_text)

## Scaling notes: GRU vs Transformer

RNNs (including GRUs) were historically important because they can model sequences with a hidden state and were efficient on small data.

However, **Transformers** dominate today because:
- They process all positions of a training sequence in parallel, whereas an RNN must step through time one token at a time (much faster on GPUs).
- Self-attention captures long-range dependencies better.
- Scaling tends to improve performance more predictably.

GRUs are still useful for small models, streaming tasks, or environments where memory is tight.


## Exercises

1. Increase `SEQ_LEN` and observe changes in validation perplexity.
2. Compare `temperature=0.7` vs `1.2` during generation.
3. Set `top_k=None` and see how generation changes.
4. Increase `HIDDEN_DIM` and compare training stability.
5. Try a 2-layer GRU and compare perplexity vs runtime.
