# Lesson 04: Transformer Decoder from Scratch (GPT-style)

In this notebook you will build a minimal GPT-style decoder Transformer using PyTorch, step by step.

You will learn:
- token and learned positional embeddings
- causal self-attention and why masking matters
- multi-head attention and head mixing
- feedforward blocks with residual connections and LayerNorm
- how to train and sample from an autoregressive language model


In [None]:
# Setup and config
import math
import os
import random
import json
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F

# Top-level configuration (toy defaults, small enough for a quick run)
BLOCK_SIZE = 128  # context length in tokens; also sizes the positional table and causal mask
BATCH_SIZE = 32  # sequences sampled per batch
D_MODEL = 256  # model width (embedding size); must be divisible by N_HEADS
N_HEADS = 4  # attention heads per block
N_LAYERS = 4  # number of stacked transformer blocks
FFN_MULT = 4  # FFN hidden width is FFN_MULT * D_MODEL
DROPOUT = 0.1  # dropout probability used throughout the model

LR = 3e-4  # AdamW learning rate
MAX_STEPS = 3000  # number of optimizer steps
EVAL_EVERY = 200  # run evaluation every this many steps
EVAL_ITERS = 50  # batches averaged per split in each evaluation

TOKENIZER_VOCAB_SIZE = 2000  # target vocab size passed to tokenizer training
BPE_TRAIN_CHARS = 200_000  # tokenizer training only sees the first this-many characters

# Larger, slower settings (uncomment to override the toy defaults above)
# BLOCK_SIZE = 1024
# BATCH_SIZE = 64
# D_MODEL = 768
# N_HEADS = 12
# N_LAYERS = 12
# FFN_MULT = 4
# LR = 2e-4
# MAX_STEPS = 20000

seed = 42  # one seed shared by Python's RNG and PyTorch
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    # Seed every visible GPU as well
    torch.cuda.manual_seed_all(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # prefer GPU when present
print("Device:", device)

## Load dataset (wikitext-2-raw-v1 preferred)

We load Wikitext-2 (the `wikitext-2-raw-v1` configuration) via Hugging Face `datasets`.
Loading it needs the `datasets` package and, on the first run, network access to download the data.


In [None]:
try:
    from datasets import load_dataset
except ImportError:
    # Keep the notebook runnable without Hugging Face `datasets`;
    # load_text_dataset() then uses the built-in fallback corpus.
    load_dataset = None


# Tiny built-in corpus used only when Wikitext-2 cannot be loaded.
_FALLBACK_LINES = (
    "A language model reads a sequence of tokens and predicts the token that comes next .",
    "The decoder attends only to earlier positions , so it can never peek at the future .",
    "Each attention head compares queries with keys and mixes the matching values together .",
    "Residual connections and layer normalization keep deep stacks of blocks easy to train .",
    "After training , we sample one token at a time and feed it back in as new context .",
)


def _fallback_corpus(repeats=400, val_fraction=0.1):
    """Return (train_text, val_text) built by repeating the built-in lines."""
    lines = list(_FALLBACK_LINES) * repeats
    cut = int(len(lines) * (1.0 - val_fraction))
    return "\n".join(lines[:cut]), "\n".join(lines[cut:])


def load_text_dataset():
    """Load the training and validation text for the notebook.

    Tries Wikitext-2 (``wikitext-2-raw-v1``) through Hugging Face
    ``datasets`` first. If the package is missing or loading fails for any
    reason (for example, no network access), the reason is printed and a
    small built-in corpus is returned instead so the notebook still runs.

    Returns:
        tuple: ``(train_text, val_text, name)`` where the texts are
        newline-joined strings and ``name`` is ``"wikitext-2-raw-v1"`` or
        ``"local-fallback"``.
    """
    name = "wikitext-2-raw-v1"
    try:
        if load_dataset is None:
            raise ImportError("Hugging Face `datasets` is not installed")
        ds = load_dataset("wikitext", name)
        train_text = "\n".join(ds["train"]["text"])
        val_text = "\n".join(ds["validation"]["text"]) if "validation" in ds else ""
        return train_text, val_text, name
    except Exception as err:
        # Best-effort boundary: any loading failure falls back to local text.
        print("Could not load", name, "- using the built-in fallback corpus.")
        print("Reason:", repr(err))

    train_text, val_text = _fallback_corpus()
    return train_text, val_text, "local-fallback"

train_text, val_text, dataset_name = load_text_dataset()
print("Dataset:", dataset_name)
print("Train chars:", len(train_text), "Val chars:", len(val_text))

## Tokenizer: load existing BPE or train a small one

If `./tokenizer_bpe.json` exists, we load it. Otherwise we train a small BPE tokenizer and save it.


In [None]:
# Special tokens: padding and unknown. The simple BPE fallback places them
# first in its vocabulary, and they are passed to the `tokenizers` trainer.
SPECIAL_TOKENS = ["<pad>", "<unk>"]
# File checked for an existing tokenizer; matches train_tokenizer's default save_path.
TOKENIZER_PATH = "tokenizer_bpe.json"


def _get_stats(vocab):
    pairs = defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pairs[(symbols[i], symbols[i + 1])] += freq
    return pairs


def _merge_vocab(pair, vocab):
    bigram = " ".join(pair)
    replacement = "".join(pair)
    new_vocab = {}
    for word, freq in vocab.items():
        new_word = word.replace(bigram, replacement)
        new_vocab[new_word] = freq
    return new_vocab


def _build_simple_bpe_helpers(model):
    merges = [tuple(p) for p in model["merges"]]
    token_to_id = model["token_to_id"]
    id_to_token = {i: t for t, i in token_to_id.items()}

    def bpe_encode_word(word):
        # Start from characters, then apply merges greedily
        symbols = list(word) + ["</w>"]
        for a, b in merges:
            i = 0
            new_symbols = []
            while i < len(symbols):
                if i < len(symbols) - 1 and symbols[i] == a and symbols[i + 1] == b:
                    new_symbols.append(a + b)
                    i += 2
                else:
                    new_symbols.append(symbols[i])
                    i += 1
            symbols = new_symbols
        return symbols

    def encode_fn(text_in):
        ids = []
        for w in text_in.split():
            for sym in bpe_encode_word(w):
                ids.append(token_to_id.get(sym, token_to_id["<unk>"]))
        return ids

    def decode_fn(ids_in):
        tokens = [id_to_token[i] for i in ids_in]
        text_out = "".join([" " if t == "</w>" else t for t in tokens])
        return " ".join(text_out.split())

    return encode_fn, decode_fn


def _train_simple_bpe(text, vocab_size=2000, max_merges=200):
    """Train the pure-Python BPE fallback on whitespace-separated words.

    The number of merges is capped by both ``max_merges`` and the room left
    under ``vocab_size`` once base symbols and special tokens are counted, so
    the resulting vocabulary can be smaller than ``vocab_size``.

    Args:
        text: training text.
        vocab_size: target vocabulary size.
        max_merges: hard cap on the number of merge operations.

    Returns:
        ``(model, encode_fn, decode_fn, n_tokens, pad_id)`` where ``model``
        is a JSON-serializable dict with ``"merges"`` and ``"token_to_id"``.
    """

    def _symbol_inventory(entries):
        # Every distinct symbol appearing in the space-separated entries
        found = set()
        for entry in entries:
            found.update(entry.split())
        return found

    # Each word becomes "c h a r s </w>" with its corpus frequency
    word_freqs = defaultdict(int)
    for token in text.split():
        word_freqs[" ".join(token) + " </w>"] += 1

    base_vocab = len(_symbol_inventory(word_freqs)) + len(SPECIAL_TOKENS)
    merge_budget = max(0, min(max_merges, vocab_size - base_vocab))

    merges = []
    while len(merges) < merge_budget:
        pair_counts = _get_stats(word_freqs)
        if not pair_counts:
            break
        top_pair = max(pair_counts, key=pair_counts.get)
        word_freqs = _merge_vocab(top_pair, word_freqs)
        merges.append(top_pair)

    # Special tokens take the lowest ids, followed by the sorted symbols
    tokens = [*SPECIAL_TOKENS, *sorted(_symbol_inventory(word_freqs))]
    token_to_id = {token: index for index, token in enumerate(tokens)}
    model = {"merges": merges, "token_to_id": token_to_id}

    encode_fn, decode_fn = _build_simple_bpe_helpers(model)
    return model, encode_fn, decode_fn, len(tokens), token_to_id["<pad>"]


def load_tokenizer(path):
    """Load a previously saved tokenizer from ``path``, if one can be read.

    Two formats are tried in order: a Hugging Face ``tokenizers`` file, then
    the simple BPE JSON (a dict with ``"merges"`` and ``"token_to_id"``).
    Any failure in either attempt is swallowed on purpose so the caller can
    fall back to training a new tokenizer.

    Args:
        path: tokenizer file path.

    Returns:
        A dict with keys ``type``, ``encode``, ``decode``, ``vocab_size`` and
        ``pad_id``, or ``None`` when the file is missing or unreadable in
        both formats.
    """
    if not os.path.exists(path):
        return None

    # Attempt 1: Hugging Face `tokenizers` format
    try:
        from tokenizers import Tokenizer

        hf_tok = Tokenizer.from_file(path)
        n_tokens = len(hf_tok.get_vocab())
        pad = hf_tok.token_to_id("<pad>")
        return dict(
            type="tokenizers",
            encode=lambda s: hf_tok.encode(s).ids,
            decode=lambda ids: hf_tok.decode(ids),
            vocab_size=n_tokens,
            pad_id=0 if pad is None else pad,
        )
    except Exception:
        pass

    # Attempt 2: simple BPE JSON written by the pure-Python fallback
    try:
        with open(path, "r", encoding="utf-8") as fh:
            payload = json.load(fh)
        if "merges" not in payload or "token_to_id" not in payload:
            return None
        enc, dec = _build_simple_bpe_helpers(payload)
        vocab = payload["token_to_id"]
        return dict(
            type="simple_bpe",
            encode=enc,
            decode=dec,
            vocab_size=len(vocab),
            pad_id=vocab.get("<pad>", 0),
        )
    except Exception:
        return None


def train_tokenizer(text, vocab_size=2000, save_path="tokenizer_bpe.json"):
    """Train a BPE tokenizer on ``text`` and save it to ``save_path``.

    Only the first ``BPE_TRAIN_CHARS`` characters are used. Hugging Face
    ``tokenizers`` is preferred; if it is missing or fails, the reason is
    printed and the pure-Python BPE is trained and saved as JSON instead.

    Args:
        text: training text.
        vocab_size: target vocabulary size.
        save_path: where the trained tokenizer is written.

    Returns:
        A dict with keys ``type``, ``encode``, ``decode``, ``vocab_size`` and
        ``pad_id``.
    """
    corpus = text[:BPE_TRAIN_CHARS]
    try:
        from tokenizers import Tokenizer, models, trainers, pre_tokenizers, normalizers

        hf_tok = Tokenizer(models.BPE(unk_token="<unk>"))
        hf_tok.normalizer = normalizers.Sequence([normalizers.NFKC()])
        hf_tok.pre_tokenizer = pre_tokenizers.Whitespace()
        bpe_trainer = trainers.BpeTrainer(vocab_size=vocab_size, special_tokens=SPECIAL_TOKENS)

        hf_tok.train_from_iterator([corpus], trainer=bpe_trainer)
        hf_tok.save(save_path)

        n_tokens = len(hf_tok.get_vocab())
        pad = hf_tok.token_to_id("<pad>")
        return dict(
            type="tokenizers",
            encode=lambda s: hf_tok.encode(s).ids,
            decode=lambda ids: hf_tok.decode(ids),
            vocab_size=n_tokens,
            pad_id=0 if pad is None else pad,
        )
    except Exception as e:
        print("tokenizers not available or failed. Using simple BPE fallback.")
        print("Reason:", repr(e))
        bpe_model, enc, dec, n_tokens, pad = _train_simple_bpe(corpus, vocab_size=vocab_size)
        with open(save_path, "w", encoding="utf-8") as fh:
            json.dump(bpe_model, fh)
        return dict(
            type="simple_bpe",
            encode=enc,
            decode=dec,
            vocab_size=n_tokens,
            pad_id=pad,
        )


# Reuse a tokenizer saved by an earlier run; otherwise train one on the
# combined train + validation text (train_tokenizer saves it to its default path).
tok = load_tokenizer(TOKENIZER_PATH)  # dict: type / encode / decode / vocab_size / pad_id
if tok is not None:
    print("Loaded tokenizer from", TOKENIZER_PATH)
else:
    print("No tokenizer found. Training a small BPE tokenizer...")
    tok = train_tokenizer("\n".join((train_text, val_text)), vocab_size=TOKENIZER_VOCAB_SIZE)

# text -> ids and ids -> text functions used by the rest of the notebook
encode, decode = tok["encode"], tok["decode"]

print("Tokenizer type:", tok["type"])
print("Vocab size:", tok["vocab_size"])

## Numericalize the dataset

We convert text to token ids once, then sample contiguous blocks for training.


In [None]:
# Encode each split once into a 1-D tensor of token ids (dtype long)
train_ids, val_ids = (
    torch.tensor(encode(split_text), dtype=torch.long)
    for split_text in (train_text, val_text)
)

print("Train tokens:", train_ids.numel())
print("Val tokens:", val_ids.numel())

## Batch sampling for autoregressive language modeling

We sample random contiguous windows of length `BLOCK_SIZE`; the targets are the same window shifted one token to the right, so the model predicts the next token at every position.


In [None]:
def get_batch(split):
    """Sample a random batch of input/target windows for language modeling.

    Args:
        split: ``"train"`` selects the training ids; any other value selects
            the validation ids.

    Returns:
        ``(x, y)`` tensors of shape ``(BATCH_SIZE, BLOCK_SIZE)`` on
        ``device``, where ``y`` is ``x`` shifted one token to the right.

    Raises:
        ValueError: if the split has fewer than ``BLOCK_SIZE + 1`` tokens, so
            no full input/target window fits.
    """
    data = train_ids if split == "train" else val_ids
    # A window starting at s needs tokens s .. s + BLOCK_SIZE (inclusive), so
    # valid starts are 0 .. len(data) - BLOCK_SIZE - 1.
    n_starts = len(data) - BLOCK_SIZE
    if n_starts < 1:
        raise ValueError(
            f"split {split!r} has {len(data)} tokens; "
            f"need at least BLOCK_SIZE + 1 = {BLOCK_SIZE + 1}"
        )
    # torch.randint's upper bound is exclusive, so every valid start can be drawn.
    starts = torch.randint(0, n_starts, (BATCH_SIZE,)).tolist()
    x = torch.stack([data[s : s + BLOCK_SIZE] for s in starts])
    y = torch.stack([data[s + 1 : s + BLOCK_SIZE + 1] for s in starts])
    return x.to(device), y.to(device)

## LayerNorm intuition

LayerNorm normalizes each token's feature vector (across channels), which stabilizes training
when stacked with residual connections.


In [None]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with learned scale and shift.

    Args:
        n_embd: size of the last (feature) dimension.
        eps: small constant added to the variance before the square root.
    """

    def __init__(self, n_embd, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(n_embd))
        self.bias = nn.Parameter(torch.zeros(n_embd))

    def forward(self, x):
        """Normalize each feature vector, then apply ``weight`` and ``bias``."""
        mu = x.mean(dim=-1, keepdim=True)
        # Biased (population) variance over the feature dimension
        sigma2 = x.var(dim=-1, keepdim=True, unbiased=False)
        normed = (x - mu) / (sigma2 + self.eps).sqrt()
        return normed * self.weight + self.bias

## Self-attention: Q, K, V, masking, and weighted sum

Each token projects to queries, keys, and values. We compute attention scores between tokens,
mask future positions (causal mask), then take a weighted sum of values.


In [None]:
class SelfAttentionHead(nn.Module):
    """One causal self-attention head.

    Args:
        d_model: input feature size.
        head_dim: size of this head's query/key/value projections.
        block_size: maximum sequence length; sizes the causal mask.
        dropout: dropout probability applied to the attention weights.
        print_shapes: if True, print tensor shapes on the first forward call.
    """

    def __init__(self, d_model, head_dim, block_size, dropout, print_shapes=False):
        super().__init__()
        self.key = nn.Linear(d_model, head_dim, bias=False)
        self.query = nn.Linear(d_model, head_dim, bias=False)
        self.value = nn.Linear(d_model, head_dim, bias=False)
        # Lower-triangular ones: position t may attend to positions <= t
        causal = torch.ones(block_size, block_size).tril()
        self.register_buffer("mask", causal)
        self.dropout = nn.Dropout(dropout)
        self.print_shapes = print_shapes
        self._printed = False

    def forward(self, x):
        """Return the attention output of shape (B, T, head_dim) for x of shape (B, T, C)."""
        _, T, _ = x.shape
        q, k, v = self.query(x), self.key(x), self.value(x)

        # Scaled dot-product scores, shape (B, T, T)
        scale = math.sqrt(k.size(-1))
        scores = (q @ k.transpose(-2, -1)) / scale
        # Future positions get -inf so softmax gives them zero weight
        blocked = self.mask[:T, :T] == 0
        weights = F.softmax(scores.masked_fill(blocked, float("-inf")), dim=-1)
        weights = self.dropout(weights)

        if self.print_shapes and not self._printed:
            print("B,T,C:", x.shape)
            print("Attention weights shape:", weights.shape)
            self._printed = True

        # Weighted sum of the value vectors
        return weights @ v

## Multi-head attention

Multiple heads attend to different subspaces. We concatenate their outputs and project back.


In [None]:
class MultiHeadSelfAttention(nn.Module):
    """Several causal attention heads run side by side, then projected back.

    Each head works in a ``d_model // n_heads`` subspace. The head outputs are
    concatenated along the feature dimension, passed through a linear
    projection, and then through dropout.

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads (at least 1).
        block_size: maximum sequence length, forwarded to each head.
        dropout: dropout probability for the heads and the output.
        print_shapes: if True, only the first head prints shapes.

    Raises:
        ValueError: if ``n_heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, d_model, n_heads, block_size, dropout, print_shapes=False):
        super().__init__()
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if n_heads < 1 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive n_heads ({n_heads})"
            )
        head_dim = d_model // n_heads
        self.heads = nn.ModuleList(
            [
                SelfAttentionHead(
                    d_model,
                    head_dim,
                    block_size,
                    dropout,
                    print_shapes=print_shapes and i == 0,
                )
                for i in range(n_heads)
            ]
        )
        self.proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply all heads to ``x`` and mix their concatenated outputs."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        return self.dropout(out)

## Feedforward block (MLP)

After attention mixes information across tokens, the MLP mixes information within each token.


In [None]:
class FeedForward(nn.Module):
    """Position-wise MLP: expand, GELU, project back, dropout.

    Args:
        d_model: input and output feature size.
        ffn_mult: the hidden layer is ``ffn_mult * d_model`` wide.
        dropout: dropout probability applied to the output.
    """

    def __init__(self, d_model, ffn_mult=4, dropout=0.1):
        super().__init__()
        width = ffn_mult * d_model
        layers = [
            nn.Linear(d_model, width),
            nn.GELU(),
            nn.Linear(width, d_model),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.net(x)

## Transformer block

We use pre-norm: normalize before attention and MLP. Residual connections help gradients flow.


In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual.

    Args:
        d_model: model width.
        n_heads: number of attention heads.
        block_size: maximum sequence length.
        ffn_mult: MLP expansion multiplier.
        dropout: dropout probability.
        print_shapes: forwarded to the attention layer.
    """

    def __init__(self, d_model, n_heads, block_size, ffn_mult, dropout, print_shapes=False):
        super().__init__()
        self.ln1 = LayerNorm(d_model)
        self.ln2 = LayerNorm(d_model)
        self.attn = MultiHeadSelfAttention(
            d_model, n_heads, block_size, dropout, print_shapes=print_shapes
        )
        self.ffn = FeedForward(d_model, ffn_mult, dropout)

    def forward(self, x):
        """Normalize before each sub-layer and add its output back to ``x``."""
        attn_out = self.attn(self.ln1(x))
        x = x + attn_out
        mlp_out = self.ffn(self.ln2(x))
        return x + mlp_out

## Full GPT-style decoder model

We combine token embeddings, learned positional embeddings, stacked blocks, and a final LM head.


In [None]:
class GPT(nn.Module):
    """Minimal GPT-style decoder-only language model.

    Token embeddings plus learned positional embeddings feed a stack of
    transformer blocks, a final LayerNorm, and a linear head that produces
    logits over the vocabulary.

    Args:
        vocab_size: number of tokens in the vocabulary.
        block_size: maximum sequence length the model accepts.
        d_model: model width.
        n_heads: attention heads per block.
        n_layers: number of transformer blocks.
        ffn_mult: MLP expansion multiplier.
        dropout: dropout probability.
    """

    def __init__(
        self,
        vocab_size,
        block_size,
        d_model,
        n_heads,
        n_layers,
        ffn_mult,
        dropout,
    ):
        super().__init__()
        self.block_size = block_size
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(block_size, d_model)
        self.drop = nn.Dropout(dropout)

        # Only the first block is asked to print shapes, to keep output short.
        self.blocks = nn.ModuleList(
            [
                TransformerBlock(
                    d_model,
                    n_heads,
                    block_size,
                    ffn_mult,
                    dropout,
                    print_shapes=(i == 0),
                )
                for i in range(n_layers)
            ]
        )
        self.ln_f = LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, idx, targets=None):
        """Compute logits, and the loss when targets are given.

        Args:
            idx: token ids of shape ``(B, T)`` with ``T <= block_size``.
            targets: optional token ids of shape ``(B, T)``.

        Returns:
            ``(logits, loss)``: logits of shape ``(B, T, vocab_size)`` and
            the mean cross-entropy loss, or ``None`` when ``targets`` is None.

        Raises:
            ValueError: if ``T`` exceeds ``block_size``.
        """
        B, T = idx.shape
        if T > self.block_size:
            raise ValueError("Sequence length exceeds block size")

        tok = self.token_emb(idx)
        pos = self.pos_emb(torch.arange(T, device=idx.device))
        x = tok + pos  # positional embeddings broadcast over the batch
        x = self.drop(x)

        for block in self.blocks:
            x = block(x)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            # cross_entropy expects (N, vocab) logits and (N,) targets
            logits_flat = logits.view(-1, logits.size(-1))
            targets_flat = targets.view(-1)
            loss = F.cross_entropy(logits_flat, targets_flat)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        The model is switched to eval mode while sampling and returned to its
        previous mode afterwards, even if sampling raises.

        Args:
            idx: prompt token ids of shape ``(B, T)``.
            max_new_tokens: number of tokens to append.
            temperature: logits are divided by this (floored at 1e-6).
            top_k: if given, sample only from the ``top_k`` highest logits.

        Returns:
            Tensor of shape ``(B, T + max_new_tokens)``.

        Raises:
            ValueError: if ``top_k`` is given but less than 1, or if the
                prompt is empty while new tokens are requested.
        """
        if top_k is not None and top_k < 1:
            raise ValueError("top_k must be a positive integer or None")
        if max_new_tokens > 0 and idx.size(1) == 0:
            raise ValueError("generate() needs at least one prompt token")

        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                idx_cond = idx[:, -self.block_size :]  # crop to model context window
                logits, _ = self(idx_cond)
                # Keep only the last position's logits and apply temperature
                logits = logits[:, -1, :] / max(temperature, 1e-6)

                if top_k is not None:
                    k = min(top_k, logits.size(-1))
                    v, _ = torch.topk(logits, k)
                    cutoff = v[:, [-1]]  # k-th largest logit per row
                    logits = logits.masked_fill(logits < cutoff, float("-inf"))

                probs = F.softmax(logits, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, next_id], dim=1)
        finally:
            # Restore training mode even when sampling fails part-way.
            if was_training:
                self.train()
        return idx

## Training loop

We use AdamW, dropout (already in the model), and evaluate perplexity periodically.


In [None]:
# Build the model from the top-level configuration and the tokenizer's vocab size
vocab_size = tok["vocab_size"]
model_config = dict(
    vocab_size=vocab_size,
    block_size=BLOCK_SIZE,
    d_model=D_MODEL,
    n_heads=N_HEADS,
    n_layers=N_LAYERS,
    ffn_mult=FFN_MULT,
    dropout=DROPOUT,
)
model = GPT(**model_config).to(device)

# Quick sanity check on model size
param_count = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {param_count/1e6:.2f}M")

optimizer = torch.optim.AdamW(model.parameters(), lr=LR)


@torch.no_grad()
def estimate_loss():
    """Average the loss over ``EVAL_ITERS`` random batches for each split.

    The model is put in eval mode for the measurement and switched back to
    train mode before returning.

    Returns:
        dict mapping ``"train"`` and ``"val"`` to ``{"loss": mean_loss,
        "ppl": exp(mean_loss)}``.
    """
    model.eval()
    report = {}
    for split in ("train", "val"):
        batch_losses = [model(*get_batch(split))[1].item() for _ in range(EVAL_ITERS)]
        mean_loss = sum(batch_losses) / len(batch_losses)
        # Perplexity is the exponential of the mean cross-entropy
        report[split] = {"loss": mean_loss, "ppl": math.exp(mean_loss)}
    model.train()
    return report


def _format_metrics(metrics):
    """Render train/val loss and perplexity from estimate_loss() on one line."""
    return (
        f"train loss {metrics['train']['loss']:.3f}, train ppl {metrics['train']['ppl']:.2f} | "
        f"val loss {metrics['val']['loss']:.3f}, val ppl {metrics['val']['ppl']:.2f}"
    )


print("Initial eval:")
metrics = estimate_loss()  # baseline before any training
print(_format_metrics(metrics))

# Optimization loop: one random training batch per step
for step in range(1, MAX_STEPS + 1):
    xb, yb = get_batch("train")
    _, loss = model(xb, yb)  # forward pass returns (logits, loss)

    # Clear old gradients, backpropagate, update parameters
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if step % EVAL_EVERY != 0:
        continue
    metrics = estimate_loss()
    print(f"step {step}: " + _format_metrics(metrics))

## Inference: generate text

We sample from the trained model using temperature and top-k sampling.


In [None]:
# Prompts to continue with the trained model
prompts = [
    "The meaning of life is",
    "Once upon a time",
    "In the middle of the night",
]

separator = "-" * 60
for prompt in prompts:
    # Batch of one: shape (1, prompt_length)
    prompt_ids = torch.tensor([encode(prompt)], dtype=torch.long, device=device)
    sampled = model.generate(prompt_ids, max_new_tokens=80, temperature=0.9, top_k=40)
    print("PROMPT:", prompt)
    print(decode(sampled[0].tolist()))
    print(separator)

## Scaling notes

Toy settings (this notebook):
- D_MODEL=256, N_LAYERS=4, N_HEADS=4, BLOCK_SIZE=128, BATCH_SIZE=32

Production-ish example (much larger):
- D_MODEL=768 or 1024, N_LAYERS=12 to 24, N_HEADS=12 to 16, BLOCK_SIZE=1024 or 2048

Why bigger models need more data and compute:
- Attention is O(T^2) in sequence length, so longer context is expensive
- More parameters increase capacity, which needs more training tokens to avoid overfitting
- Larger models often require longer training and larger batch sizes to converge well


## Exercises

1. Increase `BLOCK_SIZE` to 256 and compare validation perplexity.
2. Try `N_LAYERS=6` and `D_MODEL=384` and see how training speed changes.
3. Add gradient clipping and see if it stabilizes training at higher learning rates.
4. Compare outputs with `temperature=0.7` vs `temperature=1.2`.
5. Implement tied embeddings (share token embedding and LM head weights).
