# Lesson 05: Modern GPT Blocks (RMSNorm, SwiGLU, RoPE)

In this lesson we build a decoder-only Transformer that is closer to modern LLMs:

- RMSNorm instead of LayerNorm
- SwiGLU feedforward instead of GELU
- Rotary positional embeddings (RoPE) instead of learned absolute positions

We train on wikitext-2-raw-v1, tokenized with a small BPE tokenizer (with a pure-Python fallback when the `tokenizers` library is unavailable), and compare to Lesson 04.

Comparison to Lesson 04:
- RMSNorm is simpler and can be a bit faster, but it does not center activations.
- SwiGLU adds a gated path and often improves quality, but it adds parameters and compute.
- RoPE encodes relative positions and often extrapolates better, but it is a bit more complex.


In [None]:
# Setup and config
import math
import os
import random
import json
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F

# Top-level configuration (toy defaults). These module-level constants are
# read by the cells below; edit them here rather than inside the model code.
BLOCK_SIZE = 128  # context length (tokens); also the size of the causal mask and RoPE cache
BATCH_SIZE = 32  # sequences per batch drawn by get_batch
D_MODEL = 256  # model width (must be divisible by N_HEADS)
N_HEADS = 4  # attention heads (D_MODEL // N_HEADS must be even for RoPE)
N_LAYERS = 4  # number of transformer blocks
FFN_MULT = 4  # SwiGLU hidden size = FFN_MULT * D_MODEL
DROPOUT = 0.1  # dropout probability shared by embeddings, attention and FFN

LR = 3e-4  # AdamW learning rate
MAX_STEPS = 3000  # number of optimization steps
EVAL_EVERY = 200  # run estimate_loss every this many steps
EVAL_ITERS = 50  # batches averaged per split in estimate_loss
SAMPLE_EVERY = 500  # print a sample generation every this many steps

TOKENIZER_VOCAB_SIZE = 2000  # target tokenizer vocab size passed to train_tokenizer
BPE_TRAIN_CHARS = 200_000  # only the first this-many characters are used for BPE training

ROPE_BASE = 10000  # base of the RoPE inverse-frequency schedule
NORM_EPS = 1e-8  # epsilon added under the square root in RMSNorm
CLIP_GRAD_NORM = 1.0  # max gradient norm for clipping; set to None to disable

# Larger, slower settings (commented out; uncomment to override the toy defaults)
# BLOCK_SIZE = 1024
# BATCH_SIZE = 64
# D_MODEL = 768
# N_HEADS = 12
# N_LAYERS = 12
# FFN_MULT = 4
# LR = 2e-4
# MAX_STEPS = 20000

# Seed the Python and torch RNGs (and every CUDA device, when present).
seed = 42  # random seed
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # device for tensors
print("Device:", device)

## Dataset + tokenizer

In [None]:
from datasets import load_dataset


def load_text_dataset():
    """Load WikiText-2 (raw) and return its splits as plain strings.

    Returns:
        A ``(train_text, val_text, name)`` tuple. The rows of each split are
        joined with newlines; ``val_text`` is the empty string when the loaded
        dataset has no "validation" split. ``name`` is the dataset config name.
    """
    # Function-scope import so the helper works even if the cell-level import
    # was skipped.
    from datasets import load_dataset

    corpus = load_dataset("wikitext", "wikitext-2-raw-v1")
    name = "wikitext-2-raw-v1"

    def _joined(split):
        # One string per split, rows separated by newlines.
        return "\n".join(corpus[split]["text"])

    train_text = _joined("train")
    if "validation" in corpus:
        val_text = _joined("validation")
    else:
        val_text = ""
    return train_text, val_text, name

# Load the corpus once; the raw strings are reused for tokenizer training and
# for numericalization further down.
train_text, val_text, dataset_name = load_text_dataset()
print("Dataset:", dataset_name)
# Character counts (not token counts) of each split.
print("Train chars:", len(train_text), "Val chars:", len(val_text))

## Tokenizer: load existing BPE or train a small one

In [None]:
# Special tokens are placed first in the vocabulary by both tokenizer trainers.
SPECIAL_TOKENS = ["<pad>", "<unk>"]  # special token strings
# A tokenizer found at this path is reused; otherwise a new one is trained and saved here.
TOKENIZER_PATH = "tokenizer_bpe.json"  # tokenizer file path


def _get_stats(vocab):
    pairs = defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pairs[(symbols[i], symbols[i + 1])] += freq
    return pairs


def _merge_vocab(pair, vocab):
    bigram = " ".join(pair)
    replacement = "".join(pair)
    new_vocab = {}
    for word, freq in vocab.items():
        new_word = word.replace(bigram, replacement)
        new_vocab[new_word] = freq
    return new_vocab


def _build_simple_bpe_helpers(model):
    merges = [tuple(p) for p in model["merges"]]
    token_to_id = model["token_to_id"]
    id_to_token = {i: t for t, i in token_to_id.items()}

    def bpe_encode_word(word):
        # Start from characters, then apply merges greedily
        symbols = list(word) + ["</w>"]
        for a, b in merges:
            i = 0
            new_symbols = []
            while i < len(symbols):
                if i < len(symbols) - 1 and symbols[i] == a and symbols[i + 1] == b:
                    new_symbols.append(a + b)
                    i += 2
                else:
                    new_symbols.append(symbols[i])
                    i += 1
            symbols = new_symbols
        return symbols

    def encode_fn(text_in):
        ids = []
        for w in text_in.split():
            for sym in bpe_encode_word(w):
                ids.append(token_to_id.get(sym, token_to_id["<unk>"]))
        return ids

    def decode_fn(ids_in):
        tokens = [id_to_token[i] for i in ids_in]
        text_out = "".join([" " if t == "</w>" else t for t in tokens])
        return " ".join(text_out.split())

    return encode_fn, decode_fn


def _train_simple_bpe(text, vocab_size=2000, max_merges=200):
    """Train a small word-level BPE model in pure Python.

    Words are whitespace-separated; each starts as its characters plus an
    end-of-word marker ``"</w>"``. The most frequent adjacent pair is merged
    repeatedly.

    The vocabulary contains the special tokens, every base symbol and every
    merge product. Building it only from the symbols left after the final
    merge would drop base characters and intermediate merges that the encoder
    can still emit for unseen words, sending them to ``<unk>``.

    Args:
        text: Training text.
        vocab_size: Target vocabulary size; bounds the number of merges.
        max_merges: Hard cap on the number of merges.

    Returns:
        ``(model, encode_fn, decode_fn, vocab_size, pad_id)`` where ``model``
        is a JSON-serialisable dict with ``"merges"`` and ``"token_to_id"``.
    """
    # Word-frequency table keyed by the space-separated symbol sequence.
    vocab = defaultdict(int)
    for w in text.split():
        vocab[" ".join(w) + " </w>"] += 1

    base_symbols = set()
    for word in vocab:
        base_symbols.update(word.split())
    base_vocab = len(base_symbols) + len(SPECIAL_TOKENS)
    num_merges = max(0, min(max_merges, vocab_size - base_vocab))

    merges = []
    known_symbols = set(base_symbols)
    for _ in range(num_merges):
        pairs = _get_stats(vocab)
        if not pairs:
            break
        best = max(pairs, key=pairs.get)
        vocab = _merge_vocab(best, vocab)
        merges.append(best)
        # Each merge adds exactly one symbol the encoder may emit later.
        known_symbols.add("".join(best))

    # Also keep whatever symbols the final segmentation contains.
    for word in vocab:
        known_symbols.update(word.split())
    # Special tokens get their ids from the prefix below; avoid duplicates.
    known_symbols.difference_update(SPECIAL_TOKENS)

    tokens = list(SPECIAL_TOKENS) + sorted(known_symbols)
    token_to_id = {t: i for i, t in enumerate(tokens)}

    model = {
        "merges": merges,
        "token_to_id": token_to_id,
    }

    encode_fn, decode_fn = _build_simple_bpe_helpers(model)
    pad_id = token_to_id["<pad>"]
    return model, encode_fn, decode_fn, len(tokens), pad_id


def load_tokenizer(path):
    """Load a tokenizer handle from ``path``, or return ``None``.

    Two formats are tried in order: a Hugging Face ``tokenizers`` file, then
    the simple BPE JSON written by the pure-Python trainer (a dict with
    ``"merges"`` and ``"token_to_id"``). Any failure in one attempt falls
    through to the next.

    Args:
        path: Tokenizer file path.

    Returns:
        A dict with keys ``"type"``, ``"encode"``, ``"decode"``,
        ``"vocab_size"`` and ``"pad_id"``, or ``None`` when the file is missing
        or matches neither format.
    """
    if not os.path.exists(path):
        return None

    # Attempt 1: Hugging Face tokenizers (library may be missing, or the file
    # may be in the other format; both are treated as "not this format").
    try:
        from tokenizers import Tokenizer

        hf_tok = Tokenizer.from_file(path)

        def _to_ids(s):
            return hf_tok.encode(s).ids

        def _to_text(ids):
            return hf_tok.decode(ids)

        n_tokens = len(hf_tok.get_vocab())
        pad_lookup = hf_tok.token_to_id("<pad>")
        return {
            "type": "tokenizers",
            "encode": _to_ids,
            "decode": _to_text,
            "vocab_size": n_tokens,
            "pad_id": 0 if pad_lookup is None else pad_lookup,
        }
    except Exception:
        pass

    # Attempt 2: simple BPE JSON
    try:
        with open(path, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
        if "merges" in payload and "token_to_id" in payload:
            to_ids, to_text = _build_simple_bpe_helpers(payload)
            vocab_map = payload["token_to_id"]
            return {
                "type": "simple_bpe",
                "encode": to_ids,
                "decode": to_text,
                "vocab_size": len(vocab_map),
                "pad_id": vocab_map.get("<pad>", 0),
            }
    except Exception:
        pass

    return None


def train_tokenizer(text, vocab_size=2000, save_path="tokenizer_bpe.json"):
    """Train a BPE tokenizer on ``text`` and save it to ``save_path``.

    The Hugging Face ``tokenizers`` library is tried first. If it is missing or
    any step fails, the reason is printed and the pure-Python BPE trainer is
    used instead, with its model written as JSON.

    Args:
        text: Training text; only the first ``BPE_TRAIN_CHARS`` characters are used.
        vocab_size: Target vocabulary size.
        save_path: Where the trained tokenizer is written.

    Returns:
        A dict with keys ``"type"``, ``"encode"``, ``"decode"``,
        ``"vocab_size"`` and ``"pad_id"``.
    """
    # Cap the amount of text fed to the trainer.
    text = text[:BPE_TRAIN_CHARS]
    try:
        from tokenizers import Tokenizer, models, trainers, pre_tokenizers, normalizers

        hf_tok = Tokenizer(models.BPE(unk_token="<unk>"))
        hf_tok.normalizer = normalizers.Sequence([normalizers.NFKC()])
        hf_tok.pre_tokenizer = pre_tokenizers.Whitespace()
        bpe_trainer = trainers.BpeTrainer(vocab_size=vocab_size, special_tokens=SPECIAL_TOKENS)

        hf_tok.train_from_iterator([text], trainer=bpe_trainer)
        hf_tok.save(save_path)

        def _to_ids(s):
            return hf_tok.encode(s).ids

        def _to_text(ids):
            return hf_tok.decode(ids)

        n_tokens = len(hf_tok.get_vocab())
        pad_lookup = hf_tok.token_to_id("<pad>")
        return {
            "type": "tokenizers",
            "encode": _to_ids,
            "decode": _to_text,
            "vocab_size": n_tokens,
            "pad_id": 0 if pad_lookup is None else pad_lookup,
        }
    except Exception as e:
        print("tokenizers not available or failed. Using simple BPE fallback.")
        print("Reason:", repr(e))
        bpe_model, to_ids, to_text, n_tokens, pad_id = _train_simple_bpe(
            text,
            vocab_size=vocab_size,
        )
        with open(save_path, "w", encoding="utf-8") as sink:
            json.dump(bpe_model, sink)
        return {
            "type": "simple_bpe",
            "encode": to_ids,
            "decode": to_text,
            "vocab_size": n_tokens,
            "pad_id": pad_id,
        }


# Reuse a tokenizer saved at TOKENIZER_PATH when one loads; otherwise train one.
# NOTE: a loaded tokenizer is used as-is and is not checked against
# TOKENIZER_VOCAB_SIZE; delete the file to force retraining.
tok = load_tokenizer(TOKENIZER_PATH)  # tokenizer handle
if tok is None:
    print("No tokenizer found. Training a small BPE tokenizer...")
    # Both splits are passed in, but train_tokenizer only uses the first
    # BPE_TRAIN_CHARS characters of the combined text.
    tok = train_tokenizer(train_text + "\n" + val_text, vocab_size=TOKENIZER_VOCAB_SIZE)
else:
    print("Loaded tokenizer from", TOKENIZER_PATH)

encode = tok["encode"]  # text-to-ids function
decode = tok["decode"]  # ids-to-text function

print("Tokenizer type:", tok["type"])
# Actual vocabulary size of the tokenizer in use; the model is sized from this.
print("Vocab size:", tok["vocab_size"])

## Numericalize the dataset

In [None]:
# Encode each split once, up front, into a flat list of token ids.
train_ids = encode(train_text)  # token ids for training split
val_ids = encode(val_text)  # token ids for validation split

# 1-D long tensors; get_batch slices fixed-length windows out of these.
train_data = torch.tensor(train_ids, dtype=torch.long)  # training split data
val_data = torch.tensor(val_ids, dtype=torch.long)  # validation split data

print("Train tokens:", len(train_data))
print("Val tokens:", len(val_data))

## Batch sampling for autoregressive language modeling

In [None]:
def get_batch(split):
    """Sample a random batch of (input, target) windows for next-token prediction.

    Args:
        split: ``"train"`` selects ``train_data``; any other value selects
            ``val_data``.

    Returns:
        ``(x, y)`` long tensors of shape ``(BATCH_SIZE, BLOCK_SIZE)`` on
        ``device``, where ``y`` is ``x`` shifted one token to the right.

    Raises:
        ValueError: If the split has fewer than ``BLOCK_SIZE + 1`` tokens.
    """
    data = train_data if split == "train" else val_data
    # A window starting at i needs tokens i .. i + BLOCK_SIZE inclusive (the
    # last one is the final target), so valid starts are 0 .. len - BLOCK_SIZE - 1.
    num_starts = len(data) - BLOCK_SIZE
    if num_starts < 1:
        raise ValueError("Dataset too small for the chosen block size")
    # torch.randint's upper bound is exclusive, so pass the number of valid
    # starts; otherwise the last window would never be drawn.
    ix = torch.randint(0, num_starts, (BATCH_SIZE,))
    x = torch.stack([data[i : i + BLOCK_SIZE] for i in ix])
    y = torch.stack([data[i + 1 : i + BLOCK_SIZE + 1] for i in ix])
    return x.to(device), y.to(device)

## RMSNorm

RMSNorm normalizes by the root mean square (RMS) of the features. It does not
subtract the mean like LayerNorm. The formula is:

- rms = sqrt(mean(x^2) + eps)
- y = (x / rms) * g

where g is a learned scale vector. This is simpler and often faster while
keeping stable activations in deep networks.


In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Computes ``x / sqrt(mean(x^2) + eps) * weight``. There is no mean
    subtraction and no bias; ``weight`` is a learned per-feature scale
    initialised to ones.

    Args:
        d_model: Size of the last (feature) dimension.
        eps: Constant added to the mean square before the square root.
    """

    def __init__(self, d_model, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        # Mean of squares over the feature dimension; keepdim so it broadcasts.
        mean_square = torch.mean(x.pow(2), dim=-1, keepdim=True)
        rms = torch.sqrt(mean_square + self.eps)
        return (x / rms) * self.weight

## SwiGLU feedforward

SwiGLU uses a gated activation: one linear path provides values, another
provides gates, and the gates are squashed with SiLU (Swish). In symbols:

- a = W1 x
- b = W2 x
- y = W3 (silu(a) * b)

This often improves quality compared to ReLU/GELU, but adds some compute.


In [None]:
class SwiGLU(nn.Module):
    """Gated feedforward layer: ``dropout(w3(silu(w1 x) * w2 x))``.

    Args:
        d_model: Input and output width.
        hidden_dim: Width of the gate and value projections.
        dropout: Dropout probability applied to the output.
    """

    def __init__(self, d_model, hidden_dim, dropout=0.1):
        super().__init__()
        # w1 feeds the SiLU gate, w2 the values, w3 projects back to d_model.
        self.w1 = nn.Linear(d_model, hidden_dim)
        self.w2 = nn.Linear(d_model, hidden_dim)
        self.w3 = nn.Linear(hidden_dim, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        projected = self.w3(gate * value)
        return self.dropout(projected)

## RoPE (rotary positional embeddings)

RoPE rotates query and key vectors in each attention head by a position-dependent
angle. This injects relative position information directly into attention.

We build a cosine/sine cache for all positions, then rotate the first half of
head_dim against the second half. The shapes below are the common multi-head
layout: (batch, heads, time, head_dim).


In [None]:
def build_rope_cache(seq_len, head_dim, device, base=10000):
    """Precompute the RoPE cosine and sine tables.

    Args:
        seq_len: Number of positions to cover.
        head_dim: Per-head feature size; must be even.
        device: Device on which the tables are created.
        base: Base of the inverse-frequency schedule.

    Returns:
        ``(cos, sin)``, each of shape ``(1, 1, seq_len, head_dim)``. The two
        halves of the last dimension hold the same angles.

    Raises:
        ValueError: If ``head_dim`` is odd.
    """
    # The rotation pairs feature j with feature j + head_dim // 2.
    if head_dim % 2 != 0:
        raise ValueError("head_dim must be even for RoPE")

    n_pairs = head_dim // 2
    exponents = torch.arange(0, n_pairs, device=device).float() / n_pairs
    inv_freq = 1.0 / (base ** exponents)
    pos = torch.arange(seq_len, device=device).float()

    # Angle for (position, frequency): outer product -> (seq_len, n_pairs)
    angles = pos.unsqueeze(1) * inv_freq.unsqueeze(0)

    # Repeat the angles so the tables span the full head_dim.
    full = torch.cat([angles, angles], dim=-1)
    cos = full.cos().unsqueeze(0).unsqueeze(0)  # (1, 1, T, head_dim)
    sin = full.sin().unsqueeze(0).unsqueeze(0)
    return cos, sin


def apply_rope(q, k, cos, sin, debug=False):
    """Rotate queries and keys with the RoPE tables.

    The last dimension is split into two halves ``(x1, x2)`` and mapped to
    ``(x1*cos - x2*sin, x1*sin + x2*cos)``. Only the first half of the
    ``cos``/``sin`` tables is read.

    Args:
        q: Query tensor of shape ``(B, H, T, D)``, with ``D`` the head dim.
        k: Key tensor of shape ``(B, H, T, D)``.
        cos: Cosine table broadcastable to ``q`` (e.g. ``(1, 1, T, D)``).
        sin: Sine table with the same shape as ``cos``.
        debug: When true, print input and output shapes.

    Returns:
        ``(q_rot, k_rot)`` with the same shapes as ``q`` and ``k``.
    """
    if debug:
        for label, tensor in (("q", q), ("k", k), ("cos", cos), ("sin", sin)):
            print(f"{label} shape:", tensor.shape)

    half = q.size(-1) // 2  # half of head_dim
    cos_half = cos[..., :half]
    sin_half = sin[..., :half]

    def _rotate(x):
        # (x1, x2) -> (x1*cos - x2*sin, x1*sin + x2*cos)
        x1, x2 = x[..., :half], x[..., half:]
        return torch.cat(
            [x1 * cos_half - x2 * sin_half, x1 * sin_half + x2 * cos_half],
            dim=-1,
        )

    q_rot = _rotate(q)
    k_rot = _rotate(k)

    if debug:
        print("q_rot shape:", q_rot.shape)
        print("k_rot shape:", k_rot.shape)

    return q_rot, k_rot

In [None]:
# Quick shape sanity check for RoPE
# Small dummy sizes: batch, heads, time steps, head dim (D must be even).
# NOTE: B, H, T, D, cos and sin become module-level names in this notebook.
B, H, T, D = 2, 4, 8, 64
dummy_q = torch.randn(B, H, T, D)  # dummy query tensor for shape check
dummy_k = torch.randn(B, H, T, D)  # dummy key tensor for shape check
# Build the tables on the same device as the dummy tensors.
cos, sin = build_rope_cache(T, D, device=dummy_q.device, base=ROPE_BASE)
# debug=True prints the input and output shapes; the rotated tensors are discarded.
_ = apply_rope(dummy_q, dummy_k, cos, sin, debug=True)  # unused placeholder


## Full model definition

In [None]:
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings.

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads; the per-head size must be even.
        block_size: Maximum sequence length (sizes the mask and RoPE tables).
        dropout: Dropout probability for attention weights and the output.
        rope_base: Base of the RoPE inverse-frequency schedule.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads`` or the
            resulting head size is odd.
    """

    def __init__(self, d_model, n_heads, block_size, dropout, rope_base=10000):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        per_head = d_model // n_heads
        if per_head % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")

        self.n_heads = n_heads
        self.head_dim = per_head
        # One fused projection for q, k and v; a separate output projection.
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

        # Lower-triangular mask: position t may attend to positions <= t.
        causal = torch.ones(block_size, block_size).tril()
        self.register_buffer("mask", causal)

        # Tables are built on CPU and move with the module as buffers.
        rope_cos, rope_sin = build_rope_cache(
            block_size, per_head, device=torch.device("cpu"), base=rope_base
        )
        self.register_buffer("cos_cached", rope_cos)
        self.register_buffer("sin_cached", rope_sin)

    def forward(self, x):
        B, T, C = x.shape

        # (B, T, 3C) -> (3, B, H, T, D), then peel off q, k, v.
        fused = self.qkv(x).view(B, T, 3, self.n_heads, self.head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(dim=0)

        # Rotate q and k using the first T positions of the cached tables.
        q, k = apply_rope(
            q,
            k,
            self.cos_cached[:, :, :T, :],
            self.sin_cached[:, :, :T, :],
        )

        # Scaled dot-product scores with future positions masked out.
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        is_future = self.mask[:T, :T] == 0
        scores = scores.masked_fill(is_future, float("-inf"))
        weights = self.dropout(F.softmax(scores, dim=-1))

        # Mix values, merge heads back into the channel dimension, project.
        mixed = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.dropout(self.proj(mixed))


class TransformerBlock(nn.Module):
    """Pre-norm decoder block: RMSNorm -> attention and RMSNorm -> SwiGLU, each with a residual.

    Args:
        d_model: Model width.
        n_heads: Number of attention heads.
        block_size: Maximum sequence length.
        ffn_mult: The feedforward hidden size is ``ffn_mult * d_model``.
        dropout: Dropout probability passed to attention and feedforward.
        rope_base: Base of the RoPE inverse-frequency schedule.
    """

    def __init__(self, d_model, n_heads, block_size, ffn_mult, dropout, rope_base=10000):
        super().__init__()
        self.norm1 = RMSNorm(d_model, eps=NORM_EPS)
        self.norm2 = RMSNorm(d_model, eps=NORM_EPS)
        self.attn = CausalSelfAttention(
            d_model,
            n_heads,
            block_size,
            dropout,
            rope_base=rope_base,
        )
        hidden = ffn_mult * d_model
        self.ffn = SwiGLU(d_model, hidden, dropout=dropout)

    def forward(self, x):
        # Normalize before each sub-layer and add its output back onto x.
        attended = self.attn(self.norm1(x))
        x = x + attended
        transformed = self.ffn(self.norm2(x))
        return x + transformed


class GPT(nn.Module):
    """Decoder-only Transformer language model (RMSNorm, SwiGLU, RoPE).

    There is no positional embedding table: position information comes only
    from the rotary embeddings inside each attention layer.

    Args:
        vocab_size: Number of token ids.
        block_size: Maximum sequence length accepted by ``forward``.
        d_model: Model width.
        n_heads: Attention heads per block.
        n_layers: Number of transformer blocks.
        ffn_mult: Feedforward expansion multiplier.
        dropout: Dropout probability.
        rope_base: Base of the RoPE inverse-frequency schedule.
    """

    def __init__(
        self,
        vocab_size,
        block_size,
        d_model,
        n_heads,
        n_layers,
        ffn_mult,
        dropout,
        rope_base=10000,
    ):
        super().__init__()
        self.block_size = block_size
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.drop = nn.Dropout(dropout)

        blocks = []
        for _ in range(n_layers):
            blocks.append(
                TransformerBlock(
                    d_model,
                    n_heads,
                    block_size,
                    ffn_mult,
                    dropout,
                    rope_base=rope_base,
                )
            )
        self.blocks = nn.ModuleList(blocks)
        self.norm_f = RMSNorm(d_model, eps=NORM_EPS)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, idx, targets=None):
        """Compute logits and, when ``targets`` is given, the cross-entropy loss.

        Args:
            idx: Token ids of shape ``(B, T)`` with ``T <= block_size``.
            targets: Optional token ids of shape ``(B, T)``.

        Returns:
            ``(logits, loss)`` where ``logits`` has shape ``(B, T, vocab_size)``
            and ``loss`` is ``None`` when ``targets`` is ``None``.

        Raises:
            ValueError: If ``T`` exceeds ``block_size``.
        """
        B, T = idx.shape
        if T > self.block_size:
            raise ValueError("Sequence length exceeds block size")

        tok = self.token_emb(idx)
        x = self.drop(tok)

        for block in self.blocks:
            x = block(x)

        x = self.norm_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            # Flatten batch and time so every position is one classification example.
            logits_flat = logits.view(-1, logits.size(-1))
            targets_flat = targets.view(-1)
            loss = F.cross_entropy(logits_flat, targets_flat)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Sample ``max_new_tokens`` tokens autoregressively after ``idx``.

        The model is put in eval mode while sampling. Its previous
        training/eval mode is restored afterwards, even if sampling raises.

        Args:
            idx: Prompt token ids of shape ``(B, T)``.
            max_new_tokens: Number of tokens to append.
            temperature: Logits are divided by ``max(temperature, 1e-6)``.
            top_k: If set, logits below the k-th largest are masked out
                before sampling.

        Returns:
            Token ids of shape ``(B, T + max_new_tokens)``.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Keep only the most recent block_size tokens as context.
                idx_cond = idx[:, -self.block_size :]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :] / max(temperature, 1e-6)

                if top_k is not None:
                    k = min(top_k, logits.size(-1))
                    v, _ = torch.topk(logits, k)
                    # The smallest of the top-k values is the per-row cutoff.
                    cutoff = v[:, [-1]]
                    logits = logits.masked_fill(logits < cutoff, float("-inf"))

                probs = F.softmax(logits, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, next_id], dim=1)
        finally:
            # Restore the caller's mode on both the success and error paths.
            self.train(was_training)
        return idx


# Size the embedding and LM head from the tokenizer actually in use, which may
# differ from TOKENIZER_VOCAB_SIZE.
vocab_size = tok["vocab_size"]  # vocabulary size
model = GPT(  # model instance
    vocab_size=vocab_size,
    block_size=BLOCK_SIZE,
    d_model=D_MODEL,
    n_heads=N_HEADS,
    n_layers=N_LAYERS,
    ffn_mult=FFN_MULT,
    dropout=DROPOUT,
    rope_base=ROPE_BASE,
).to(device)

# Counts every parameter tensor returned by model.parameters().
param_count = sum(p.numel() for p in model.parameters())  # parameter count
print(f"Model parameters: {param_count/1e6:.2f}M")

# AdamW with a constant learning rate; no schedule is configured here.
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)  # optimizer instance


## Training loop

In [None]:
@torch.no_grad()
def estimate_loss():
    """Estimate mean loss and perplexity on the train and validation splits.

    Averages the loss over ``EVAL_ITERS`` random batches per split with the
    model in eval mode. The model's previous training/eval mode is restored
    afterwards instead of forcing training mode, so calling this on a model
    that was in eval mode leaves it there.

    Returns:
        ``{"train": {"loss": ..., "ppl": ...}, "val": {"loss": ..., "ppl": ...}}``
        where ``ppl`` is ``exp(loss)``.
    """
    was_training = model.training
    model.eval()
    out = {}
    try:
        for split in ["train", "val"]:
            losses = []
            for _ in range(EVAL_ITERS):
                xb, yb = get_batch(split)
                _, loss = model(xb, yb)
                losses.append(loss.item())
            avg = sum(losses) / len(losses)
            out[split] = {"loss": avg, "ppl": math.exp(avg)}
    finally:
        # Restore the caller's mode on both the success and error paths.
        model.train(was_training)
    return out


def sample_generation(prompt, max_new_tokens=80):
    """Generate a continuation of ``prompt`` and return it as decoded text.

    Sampling uses temperature 0.9 and top-k 40. The returned text includes
    the prompt tokens.

    Args:
        prompt: Text to condition on.
        max_new_tokens: Number of tokens to sample.
    """
    prompt_ids = encode(prompt)
    # Add a batch dimension of 1 for the model.
    batch = torch.tensor([prompt_ids], dtype=torch.long, device=device)
    generated = model.generate(
        batch,
        max_new_tokens=max_new_tokens,
        temperature=0.9,
        top_k=40,
    )
    first_row = generated[0].tolist()
    return decode(first_row)


# Baseline metrics for the untrained model, before any optimizer step.
print("Initial eval:")
metrics = estimate_loss()  # evaluation metrics dict
print(
    f"train loss {metrics['train']['loss']:.3f}, train ppl {metrics['train']['ppl']:.2f} | "
    f"val loss {metrics['val']['loss']:.3f}, val ppl {metrics['val']['ppl']:.2f}"
)

# optimization steps (1-based so the interval checks below fire on multiples)
for step in range(1, MAX_STEPS + 1):
    xb, yb = get_batch("train")
    _, loss = model(xb, yb)

    # Clear old gradients, backpropagate, optionally clip, then update.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    if CLIP_GRAD_NORM is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRAD_NORM)
    optimizer.step()

    # Periodic evaluation on freshly sampled train and validation batches.
    if step % EVAL_EVERY == 0:
        metrics = estimate_loss()
        print(
            f"step {step}: "
            f"train loss {metrics['train']['loss']:.3f}, train ppl {metrics['train']['ppl']:.2f} | "
            f"val loss {metrics['val']['loss']:.3f}, val ppl {metrics['val']['ppl']:.2f}"
        )

    # Periodic qualitative check: sample from a fixed prompt.
    if step % SAMPLE_EVERY == 0:
        print("Sample generation:")
        print(sample_generation("The meaning of life is"))
        print("-" * 60)

## Inference: prompted generation

In [None]:
prompts = [  # list of prompts
    "The meaning of life is",
    "Once upon a time",
    "In the middle of the night",
]

# Same sampling settings as sample_generation: 80 new tokens, temperature 0.9, top-k 40.
for prompt in prompts:
    # Encode the prompt and add a batch dimension of 1.
    idx = torch.tensor([encode(prompt)], dtype=torch.long, device=device)
    out = model.generate(idx, max_new_tokens=80, temperature=0.9, top_k=40)
    print("PROMPT:", prompt)
    # The decoded output includes the prompt tokens followed by the continuation.
    print(decode(out[0].tolist()))
    print("-" * 60)

## Scaling notes

Toy defaults in this notebook:
- D_MODEL=256, N_HEADS=4, N_LAYERS=4, BLOCK_SIZE=128, BATCH_SIZE=32
- LR=3e-4, MAX_STEPS around 3000

Production-ish suggestions (much slower):
- D_MODEL=768 or 1024, N_HEADS=12 or 16, N_LAYERS=12 or 24
- BLOCK_SIZE=512 or 1024, BATCH_SIZE as large as memory allows
- More training steps and a learning rate schedule

When scaling up, consider gradient accumulation, mixed precision, and
periodic evaluation on a held-out validation set.


## Exercises

1) Replace RoPE with learned absolute positional embeddings and compare results.
2) Try GELU instead of SwiGLU and measure training speed and loss.
3) Vary the RoPE base (ROPE_BASE) and see how it affects long-range behavior.
4) Dropout is already applied to the attention weights in CausalSelfAttention; remove it (or set DROPOUT = 0.0) and compare with/without.
5) Implement weight tying between token embeddings and the LM head.
