# Lesson 06: KV Cache for Fast Generation

When a GPT-style model generates one token at a time, it repeatedly re-computes attention over the whole context. A KV cache saves the **keys** and **values** from previous steps so each new token only needs attention against the cached history.

You will learn:
- how a decoder Transformer produces queries, keys, values
- how to reuse cached K/V tensors during generation
- why caching cuts the attention cost of each generation step from O(T^2) to O(T), where T is the current context length

We'll build a small model, train briefly, and compare naive vs cached generation speed.


In [None]:
# Setup and config
import math
import os
import random
import json
import time
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F

# ---- Model shape (kept small so the notebook trains in minutes) ----
BLOCK_SIZE = 128  # maximum context length in tokens
BATCH_SIZE = 32  # sequences per training batch
D_MODEL = 256  # embedding / residual width
N_HEADS = 4  # attention heads (D_MODEL must be divisible by this)
N_LAYERS = 4  # number of Transformer blocks
FFN_MULT = 4  # FFN hidden width = FFN_MULT * D_MODEL
DROPOUT = 0.1  # dropout probability

# ---- Optimisation / evaluation schedule ----
LR = 3e-4  # AdamW learning rate
MAX_STEPS = 400  # number of optimizer steps
EVAL_EVERY = 100  # evaluate every N steps
EVAL_ITERS = 30  # batches averaged per evaluation

# ---- Tokenizer ----
TOKENIZER_VOCAB_SIZE = 2000  # requested BPE vocabulary size
BPE_TRAIN_CHARS = 200_000  # only this many characters are used to train BPE

# ---- Generation ----
GEN_TOKENS = 120  # new tokens sampled in the timing demo

# Seed every RNG the notebook touches (PyTorch CPU/GPU and Python's `random`)
seed = 42
torch.manual_seed(seed)
random.seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

## Dataset: Wikitext-2 (with fallback)

We load `wikitext-2-raw-v1` from Hugging Face. If that fails (e.g. offline, or the `datasets` package is missing), we fall back to a tiny local text so the notebook always runs.


In [None]:
from datasets import load_dataset


def load_text_dataset():
    """Return ``(train_text, val_text, dataset_name)``.

    Tries Hugging Face ``wikitext-2-raw-v1`` first. If importing ``datasets`` or
    loading the dataset raises (e.g. the package is missing or the machine is
    offline), a small built-in corpus is returned instead so the rest of the
    notebook still runs. The fallback text is repeated many times so that each
    split contains thousands of words, enough for several training windows.

    Returns:
        train_text: str, all training lines joined with newlines.
        val_text: str, validation lines joined with newlines ("" if the loaded
            dataset has no "validation" split).
        dataset_name: str naming which source was used.
    """
    try:
        from datasets import load_dataset

        ds = load_dataset("wikitext", "wikitext-2-raw-v1")
    except Exception as e:  # ImportError, network/cache errors, ... -> best-effort fallback
        print("Could not load wikitext-2-raw-v1; using a tiny local corpus instead.")
        print("Reason:", repr(e))
        sample = (
            "The quick brown fox jumps over the lazy dog. "
            "Language models predict the next token from the tokens that came before it. "
            "A key value cache stores past keys and values so generation does not repeat work."
        )
        # Repeat so both splits comfortably exceed one training window
        train_text = "\n".join([sample] * 200)
        val_text = "\n".join([sample] * 40)
        return train_text, val_text, "tiny-local-fallback"

    name = "wikitext-2-raw-v1"
    train_text = "\n".join(ds["train"]["text"])
    val_text = "\n".join(ds["validation"]["text"]) if "validation" in ds else ""
    return train_text, val_text, name

# Load the raw text once; later cells tokenize these strings
train_text, val_text, dataset_name = load_text_dataset()
print(f"Dataset: {dataset_name}")
print(f"Train chars: {len(train_text)} Val chars: {len(val_text)}")

## Tokenizer (small BPE)

We either load `tokenizer_bpe.json` or train a tiny BPE tokenizer on the dataset. This keeps the notebook self-contained.


In [None]:
# Reserved tokens; the simple BPE fallback assigns them the first ids (<pad>=0, <unk>=1)
SPECIAL_TOKENS = ["<pad>", "<unk>"]
# File the trained tokenizer is saved to and reloaded from on later runs
TOKENIZER_PATH = "tokenizer_bpe.json"


def _get_stats(vocab):
    # Count symbol pairs inside the current vocabulary
    pairs = defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pairs[(symbols[i], symbols[i + 1])] += freq
    return pairs


def _merge_vocab(pair, vocab):
    bigram = " ".join(pair)
    replacement = "".join(pair)
    new_vocab = {}
    for word, freq in vocab.items():
        new_word = word.replace(bigram, replacement)
        new_vocab[new_word] = freq
    return new_vocab


def _build_simple_bpe_helpers(model):
    merges = [tuple(p) for p in model["merges"]]
    token_to_id = model["token_to_id"]
    id_to_token = {i: t for t, i in token_to_id.items()}

    def bpe_encode_word(word):
        # Start from characters, then apply merges greedily
        symbols = list(word) + ["</w>"]
        for a, b in merges:
            i = 0
            new_symbols = []
            while i < len(symbols):
                if i < len(symbols) - 1 and symbols[i] == a and symbols[i + 1] == b:
                    new_symbols.append(a + b)
                    i += 2
                else:
                    new_symbols.append(symbols[i])
                    i += 1
            symbols = new_symbols
        return symbols

    def encode_fn(text_in):
        ids = []
        for w in text_in.split():
            for sym in bpe_encode_word(w):
                ids.append(token_to_id.get(sym, token_to_id["<unk>"]))
        return ids

    def decode_fn(ids_in):
        tokens = [id_to_token[i] for i in ids_in]
        text_out = "".join([" " if t == "</w>" else t for t in tokens])
        return " ".join(text_out.split())

    return encode_fn, decode_fn


def _train_simple_bpe(text, vocab_size=2000, max_merges=200):
    """Train a tiny character-level BPE model on whitespace-split ``text``.

    Args:
        text: training corpus; words come from ``str.split()``.
        vocab_size: upper bound on the final vocabulary size, special tokens included.
        max_merges: hard cap on the number of merges learned.

    Returns:
        ``(model, encode_fn, decode_fn, vocab_len, pad_id)`` where ``model`` is the
        JSON-serialisable dict ``{"merges": [...], "token_to_id": {...}}`` and
        ``vocab_len == len(model["token_to_id"])``.
    """
    # Word frequency table; each word is spelled as space-separated chars + "</w>"
    vocab = defaultdict(int)
    for w in text.split():
        vocab[" ".join(w) + " </w>"] += 1

    base_symbols = set()
    for word in vocab:
        base_symbols.update(word.split())
    base_vocab = len(base_symbols) + len(SPECIAL_TOKENS)
    num_merges = max(0, min(max_merges, vocab_size - base_vocab))

    merges = []
    for _ in range(num_merges):
        pairs = _get_stats(vocab)
        if not pairs:
            break
        best = max(pairs, key=pairs.get)
        vocab = _merge_vocab(best, vocab)
        merges.append(best)

    # The encoder starts from single characters and applies every merge in order,
    # so it can emit any base symbol or any merge product -- including ones that
    # were fully absorbed into longer symbols in the training vocab. Give all of
    # them ids so unseen words made of known characters don't collapse to <unk>.
    symbols = base_symbols | {a + b for a, b in merges}
    # Avoid duplicate ids if the corpus happens to contain a special-token string
    symbols -= set(SPECIAL_TOKENS)

    tokens = list(SPECIAL_TOKENS) + sorted(symbols)
    token_to_id = {t: i for i, t in enumerate(tokens)}

    model = {
        "merges": merges,
        "token_to_id": token_to_id,
    }

    encode_fn, decode_fn = _build_simple_bpe_helpers(model)
    pad_id = token_to_id["<pad>"]
    return model, encode_fn, decode_fn, len(tokens), pad_id


def load_tokenizer(path):
    """Load a previously saved tokenizer from ``path``.

    Two on-disk formats are tried in order:
      1. a Hugging Face ``tokenizers`` JSON file (needs the ``tokenizers`` package);
      2. the simple-BPE JSON written by the fallback trainer
         (``{"merges": ..., "token_to_id": ...}``).
    Failures in either attempt are deliberately swallowed so the caller can
    retrain.

    Returns:
        dict with keys "type", "encode", "decode", "vocab_size", "pad_id"; or
        ``None`` if the file does not exist or neither format could be loaded.
    """
    if not os.path.exists(path):
        return None

    # Attempt 1: Hugging Face tokenizers
    try:
        from tokenizers import Tokenizer

        hf_tok = Tokenizer.from_file(path)
        hf_vocab_size = len(hf_tok.get_vocab())
        hf_pad = hf_tok.token_to_id("<pad>")
        return {
            "type": "tokenizers",
            "encode": lambda s: hf_tok.encode(s).ids,
            "decode": lambda ids: hf_tok.decode(ids),
            "vocab_size": hf_vocab_size,
            "pad_id": 0 if hf_pad is None else hf_pad,
        }
    except Exception:
        pass

    # Attempt 2: our own simple BPE JSON
    try:
        with open(path, "r", encoding="utf-8") as fh:
            saved = json.load(fh)
        if "merges" in saved and "token_to_id" in saved:
            enc, dec = _build_simple_bpe_helpers(saved)
            return {
                "type": "simple_bpe",
                "encode": enc,
                "decode": dec,
                "vocab_size": len(saved["token_to_id"]),
                "pad_id": saved["token_to_id"].get("<pad>", 0),
            }
    except Exception:
        pass

    return None


def train_tokenizer(text, vocab_size=2000, save_path="tokenizer_bpe.json"):
    """Train a BPE tokenizer on a prefix of ``text`` and save it to ``save_path``.

    Only the first ``BPE_TRAIN_CHARS`` characters are used. Hugging Face
    ``tokenizers`` is tried first; if importing or training with it raises,
    the pure-Python simple BPE in this notebook is trained instead (the reason
    is printed).

    Returns:
        dict with keys "type", "encode", "decode", "vocab_size", "pad_id".
    """
    text = text[:BPE_TRAIN_CHARS]
    try:
        from tokenizers import Tokenizer, decoders, models, normalizers, pre_tokenizers, trainers

        # Tag word-final pieces with "</w>" and decode with BPEDecoder so decode()
        # restores word boundaries. Without a decoder, Tokenizer.decode joins every
        # sub-word piece with a space (e.g. "mean ing").
        tokenizer = Tokenizer(models.BPE(unk_token="<unk>", end_of_word_suffix="</w>"))
        tokenizer.normalizer = normalizers.Sequence([normalizers.NFKC()])
        tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
        tokenizer.decoder = decoders.BPEDecoder(suffix="</w>")
        trainer = trainers.BpeTrainer(
            vocab_size=vocab_size,
            special_tokens=SPECIAL_TOKENS,
            end_of_word_suffix="</w>",
        )

        tokenizer.train_from_iterator([text], trainer=trainer)
        tokenizer.save(save_path)

        def encode_fn(s):
            return tokenizer.encode(s).ids

        def decode_fn(ids):
            return tokenizer.decode(ids)

        vocab_size_out = len(tokenizer.get_vocab())
        pad_id = tokenizer.token_to_id("<pad>")
        if pad_id is None:
            pad_id = 0
        return {
            "type": "tokenizers",
            "encode": encode_fn,
            "decode": decode_fn,
            "vocab_size": vocab_size_out,
            "pad_id": pad_id,
        }
    except Exception as e:
        # Deliberate best-effort: any failure (missing package, API mismatch, ...)
        # falls back to the pure-Python BPE so the notebook keeps running.
        print("tokenizers not available or failed. Using simple BPE fallback.")
        print("Reason:", repr(e))
        model, encode_fn, decode_fn, vocab_size_out, pad_id = _train_simple_bpe(
            text,
            vocab_size=vocab_size,
        )
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(model, f)
        return {
            "type": "simple_bpe",
            "encode": encode_fn,
            "decode": decode_fn,
            "vocab_size": vocab_size_out,
            "pad_id": pad_id,
        }

In [None]:
# Reuse a tokenizer saved by an earlier run when possible; otherwise train one now
tok = load_tokenizer(TOKENIZER_PATH)
if tok is not None:
    print("Loaded tokenizer from", TOKENIZER_PATH)
else:
    print("No tokenizer found. Training a small BPE tokenizer...")
    tok = train_tokenizer(train_text + "\n" + val_text, vocab_size=TOKENIZER_VOCAB_SIZE)

# The rest of the notebook only needs these two callables
encode, decode = tok["encode"], tok["decode"]

print("Tokenizer type:", tok["type"])
print("Vocab size:", tok["vocab_size"])

## Numericalize the dataset

We convert text to token ids once, then sample contiguous blocks for training.


In [None]:
# Tokenize each split exactly once; get_batch slices windows out of these 1-D tensors
train_ids = torch.tensor(encode(train_text), dtype=torch.long)
val_ids = torch.tensor(encode(val_text), dtype=torch.long)

print(f"Train tokens: {train_ids.numel()}")
print(f"Val tokens: {val_ids.numel()}")

## Batch sampling for autoregressive language modeling

We sample random contiguous windows of length `BLOCK_SIZE` and predict the next token.


In [None]:
def get_batch(split):
    """Sample a random batch of next-token-prediction windows.

    Args:
        split: "train" selects ``train_ids``; any other value selects ``val_ids``.

    Returns:
        ``(x, y)`` LongTensors of shape (BATCH_SIZE, BLOCK_SIZE) on ``device``,
        where ``y`` is ``x`` shifted one token to the left.

    Raises:
        ValueError: if the split has fewer than ``BLOCK_SIZE + 1`` tokens.
    """
    data = train_ids if split == "train" else val_ids
    # A window starting at i needs tokens i .. i + BLOCK_SIZE (inclusive) for its
    # target, so valid starts are 0 .. len(data) - BLOCK_SIZE - 1. randint's upper
    # bound is exclusive, hence len(data) - BLOCK_SIZE.
    max_start = len(data) - BLOCK_SIZE
    if max_start <= 0:
        raise ValueError(
            f"split {split!r} has {len(data)} tokens; need at least {BLOCK_SIZE + 1}"
        )
    idx = torch.randint(0, max_start, (BATCH_SIZE,))
    x = torch.stack([data[i : i + BLOCK_SIZE] for i in idx])
    y = torch.stack([data[i + 1 : i + BLOCK_SIZE + 1] for i in idx])
    return x.to(device), y.to(device)

## GPT-style decoder with KV cache

Below is a minimal decoder-only Transformer. The key change is in the attention module, which can accept a `past_kv` cache and return updated keys/values.


In [None]:
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Linear -> Dropout.

    Input and output width are ``d_model``; the hidden width is
    ``ffn_mult * d_model``.
    """

    def __init__(self, d_model, ffn_mult=4, dropout=0.1):
        super().__init__()
        width = d_model * ffn_mult
        layers = [
            nn.Linear(d_model, width),
            nn.GELU(),
            nn.Linear(width, d_model),
            nn.Dropout(dropout),
        ]
        # Kept as one Sequential named `net` so parameter names stay net.0.* / net.2.*
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = x
        for layer in self.net:
            out = layer(out)
        return out


class MultiHeadSelfAttention(nn.Module):
    """Causal multi-head self-attention that can extend a key/value cache.

    ``forward(x, past_kv, use_cache)`` takes ``x`` of shape (B, T, C). If
    ``past_kv = (past_k, past_v)`` is given (each (B, n_heads, P, head_dim)), the
    new keys/values are appended to it along the time axis. The new queries then
    attend to all P cached positions plus the causal prefix of the new chunk.
    Returns ``(out, new_kv)`` where ``out`` is (B, T, C) and ``new_kv`` is the
    concatenated ``(k, v)`` when ``use_cache`` is true, else ``None``.
    """

    def __init__(self, d_model, n_heads, block_size, dropout=0.1, print_shapes=False):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # Bias-free projections, created in q, k, v, out order
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
        # mask[i, j] == 1 iff query position i may look at key position j (j <= i)
        self.register_buffer("mask", torch.ones(block_size, block_size).tril())
        self.print_shapes = print_shapes
        self._printed = False

    def _split_heads(self, t, B, T):
        # (B, T, C) -> (B, n_heads, T, head_dim)
        return t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, past_kv=None, use_cache=False):
        B, T, C = x.shape

        q = self._split_heads(self.q_proj(x), B, T)
        k = self._split_heads(self.k_proj(x), B, T)
        v = self._split_heads(self.v_proj(x), B, T)

        past_len = 0
        if past_kv is not None:
            past_k, past_v = past_kv
            past_len = past_k.size(2)
            # Extend the cache with this chunk along the time axis
            k = torch.cat((past_k, k), dim=2)
            v = torch.cat((past_v, v), dim=2)

        total_len = k.size(2)
        if total_len > self.mask.size(0):
            raise ValueError("KV cache is longer than block size.")

        # Scores: (B, n_heads, T, total_len)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # The new queries occupy rows past_len .. total_len-1 of the full causal mask;
        # with no cache this is simply the top-left (T, T) corner.
        visible = self.mask[past_len:total_len, :total_len]
        att = att.masked_fill(visible == 0, float("-inf"))
        att = self.dropout(F.softmax(att, dim=-1))

        # Mix values, then merge heads back: -> (B, T, C)
        out = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
        out = self.dropout(self.out_proj(out))

        if self.print_shapes and not self._printed:
            print("x:", x.shape)
            print("q:", q.shape, "k:", k.shape, "v:", v.shape)
            print("att:", att.shape, "out:", out.shape, "past_len:", past_len)
            self._printed = True

        return out, ((k, v) if use_cache else None)

In [None]:
class TransformerBlock(nn.Module):
    """Pre-LayerNorm decoder block with residual connections.

    Computes ``h = x + Attn(LN1(x))`` then ``h + FFN(LN2(h))`` and passes the
    attention layer's ``(k, v)`` cache (or ``None``) through unchanged.
    """

    def __init__(self, d_model, n_heads, block_size, ffn_mult=4, dropout=0.1, print_shapes=False):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttention(
            d_model, n_heads, block_size, dropout=dropout, print_shapes=print_shapes
        )
        self.ffn = FeedForward(d_model, ffn_mult=ffn_mult, dropout=dropout)

    def forward(self, x, past_kv=None, use_cache=False):
        normed = self.ln1(x)
        attn_out, new_kv = self.attn(normed, past_kv=past_kv, use_cache=use_cache)
        h = x + attn_out
        return h + self.ffn(self.ln2(h)), new_kv


class GPT(nn.Module):
    """Decoder-only Transformer language model with optional per-layer K/V caching.

    Token ids are embedded and summed with learned absolute position
    embeddings. The result passes through ``n_layers`` TransformerBlocks, a
    final LayerNorm, and a bias-free projection to ``vocab_size`` logits.
    """

    def __init__(
        self,
        vocab_size,
        block_size,
        d_model,
        n_heads,
        n_layers,
        ffn_mult,
        dropout,
    ):
        super().__init__()
        self.block_size = block_size  # max cached + new tokens per forward
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(block_size, d_model)
        self.drop = nn.Dropout(dropout)
        # Only layer 0 is asked to print its attention tensor shapes
        self.blocks = nn.ModuleList(
            [
                TransformerBlock(
                    d_model,
                    n_heads,
                    block_size,
                    ffn_mult=ffn_mult,
                    dropout=dropout,
                    print_shapes=layer == 0,
                )
                for layer in range(n_layers)
            ]
        )
        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, idx, targets=None, past_kv=None, use_cache=False):
        """Run the model on token ids ``idx`` of shape (B, T).

        Args:
            targets: optional (B, T) ids; if given, mean cross-entropy is returned.
            past_kv: optional list with one ``(k, v)`` (or ``None``) per layer.
            use_cache: if true, return the updated per-layer ``(k, v)`` list.

        Returns:
            ``(logits, loss, new_kv)``; ``loss`` is ``None`` without targets and
            ``new_kv`` is ``None`` unless ``use_cache``.

        Raises:
            ValueError: if cached length plus T exceeds ``block_size``.
        """
        _, T = idx.shape
        if past_kv is None:
            layer_caches = [None] * len(self.blocks)
            past_len = 0
        else:
            layer_caches = past_kv
            head = past_kv[0]
            past_len = 0 if head is None else head[0].size(2)

        if past_len + T > self.block_size:
            raise ValueError("Sequence length exceeds block size")

        # New tokens continue the position numbering after the cached ones
        positions = torch.arange(past_len, past_len + T, device=idx.device)
        x = self.drop(self.token_emb(idx) + self.pos_emb(positions))

        collected = [] if use_cache else None
        for block, layer_cache in zip(self.blocks, layer_caches):
            x, kv = block(x, past_kv=layer_cache, use_cache=use_cache)
            if use_cache:
                collected.append(kv)

        logits = self.lm_head(self.ln_f(x))
        if targets is None:
            return logits, None, collected
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss, collected

## Training loop

We keep training short (just enough to make generation non-random-ish).


In [None]:
# Instantiate the model from the top-level config and place it on the active device
vocab_size = tok["vocab_size"]
model = GPT(
    vocab_size=vocab_size,
    block_size=BLOCK_SIZE,
    d_model=D_MODEL,
    n_heads=N_HEADS,
    n_layers=N_LAYERS,
    ffn_mult=FFN_MULT,
    dropout=DROPOUT,
)
model = model.to(device)

# Total number of scalar parameters, reported in millions as a size sanity check
param_count = sum(map(torch.numel, model.parameters()))
print(f"Model parameters: {param_count/1e6:.2f}M")

optimizer = torch.optim.AdamW(model.parameters(), lr=LR)


@torch.no_grad()
def estimate_loss():
    """Average loss and perplexity over ``EVAL_ITERS`` random batches per split.

    The global ``model`` is switched to eval mode (disabling dropout) for the
    measurement and then returned to whichever mode it was in before, rather
    than always being forced into train mode. The mode is restored even if
    evaluation raises.

    Returns:
        ``{"train": {"loss": float, "ppl": float}, "val": {...}}`` where
        ``ppl = exp(loss)``.
    """
    was_training = model.training
    model.eval()
    try:
        out = {}
        for split in ["train", "val"]:
            losses = []
            for _ in range(EVAL_ITERS):
                xb, yb = get_batch(split)
                _, loss, _ = model(xb, yb)
                losses.append(loss.item())
            avg = sum(losses) / len(losses)
            # Perplexity is exp(cross-entropy)
            out[split] = {"loss": avg, "ppl": math.exp(avg)}
    finally:
        if was_training:
            model.train()
    return out


def _format_metrics(m):
    # One-line summary of estimate_loss() output
    tr, va = m["train"], m["val"]
    return (
        f"train loss {tr['loss']:.3f}, train ppl {tr['ppl']:.2f} | "
        f"val loss {va['loss']:.3f}, val ppl {va['ppl']:.2f}"
    )


print("Initial eval:")
metrics = estimate_loss()
print(_format_metrics(metrics))

for step in range(1, MAX_STEPS + 1):
    # One optimisation step on a fresh random training batch
    xb, yb = get_batch("train")
    _, loss, _ = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Report only every EVAL_EVERY steps
    if step % EVAL_EVERY:
        continue
    metrics = estimate_loss()
    print(f"step {step}: " + _format_metrics(metrics))

## Inference: naive vs cached generation

We compare two approaches:
- **naive**: recompute attention over the full context each step
- **cached**: reuse keys/values from previous steps


In [None]:
@torch.no_grad()
def sample_from_logits(logits, temperature=1.0, top_k=None):
    """Draw one token id per row from unnormalized ``logits``.

    Args:
        logits: (B, vocab) or (vocab,) scores.
        temperature: logits are divided by ``max(temperature, 1e-6)``; small
            values make sampling close to greedy.
        top_k: keep only the ``top_k`` highest-scoring tokens per row before
            sampling. ``None`` or a value <= 0 disables the filter (previously
            0 crashed inside the indexing step).

    Returns:
        LongTensor of sampled ids with shape (B, 1), or (1,) for 1-D input.
    """
    logits = logits / max(temperature, 1e-6)
    if top_k is not None and top_k > 0:
        k = min(top_k, logits.size(-1))
        # k-th largest score per row, kept as a trailing size-1 dim for broadcasting
        cutoff = torch.topk(logits, k).values[..., -1:]
        logits = logits.masked_fill(logits < cutoff, float("-inf"))
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


@torch.no_grad()
def generate_naive(model, idx, max_new_tokens, temperature=1.0, top_k=None):
    """Sample ``max_new_tokens`` tokens, re-running the full context every step.

    At each step the last ``model.block_size`` tokens of ``idx`` are fed through
    the model without a cache, and one token is sampled from the final
    position's logits. The model is put in eval mode for generation and
    switched back to train mode afterwards if it was training.

    Returns:
        ``idx`` with the sampled ids appended along dim 1.
    """
    restore_train = model.training
    model.eval()

    window = model.block_size
    for _ in range(max_new_tokens):
        # Only the most recent `window` tokens fit in the model's context
        logits = model(idx[:, -window:])[0]
        last = logits[:, -1, :]
        chosen = sample_from_logits(last, temperature=temperature, top_k=top_k)
        idx = torch.cat([idx, chosen], dim=1)

    if restore_train:
        model.train()
    return idx


def crop_kv_cache(past_kv, max_len):
    """Keep only the last ``max_len`` time steps of every layer's ``(k, v)``.

    Tensors are sliced on dim 2. A layer is cropped only when its key tensor
    is longer than ``max_len``; otherwise the original tensors are returned
    as-is. ``None`` passes through unchanged.
    """
    if past_kv is None:
        return None
    return [
        (k[:, :, -max_len:, :], v[:, :, -max_len:, :]) if k.size(2) > max_len else (k, v)
        for k, v in past_kv
    ]


@torch.no_grad()
def generate_cached(model, idx, max_new_tokens, temperature=1.0, top_k=None):
    """Sample ``max_new_tokens`` tokens, reusing cached keys/values between steps.

    The prompt (cropped to ``model.block_size``) is run once to fill the per-layer
    K/V cache. After that, each step feeds only the newly sampled token together
    with the cache.

    The model uses learned *absolute* position embeddings, so cached K/V are tied
    to the positions they were computed at. Cropping the cache when it is full
    would give every new token position ``block_size - 1`` while the cached
    entries keep stale positions. Instead, once the cache holds ``block_size``
    tokens it is rebuilt from the most recent ``block_size`` tokens. That is the
    same window ``generate_naive`` feeds, at naive cost for those steps.

    Args:
        model: GPT-style model returning ``(logits, loss, new_kv)``.
        idx: (B, T0) prompt token ids.
        max_new_tokens: number of tokens to append.
        temperature, top_k: forwarded to ``sample_from_logits``.

    Returns:
        ``idx`` with ``max_new_tokens`` sampled ids appended along dim 1.
    """
    was_training = model.training
    model.eval()
    try:
        # Prime the cache on the prompt (cropped to the context window)
        logits, _, past_kv = model(idx[:, -model.block_size :], use_cache=True)

        for step in range(max_new_tokens):
            next_id = sample_from_logits(logits[:, -1, :], temperature=temperature, top_k=top_k)
            idx = torch.cat([idx, next_id], dim=1)
            if step == max_new_tokens - 1:
                break  # nothing will sample from further logits

            cached_len = past_kv[0][0].size(2)
            if cached_len < model.block_size:
                # Fast path: attend the single new token against the cache
                logits, _, past_kv = model(next_id, past_kv=past_kv, use_cache=True)
            else:
                # Cache is full: rebuild it from the latest window so positions stay correct
                logits, _, past_kv = model(idx[:, -model.block_size :], use_cache=True)
    finally:
        if was_training:
            model.train()
    return idx


def time_generation(label, fn, model, idx, max_new_tokens, **kwargs):
    """Time ``fn(model, idx, max_new_tokens, **kwargs)``, print a summary and return its result.

    Throughput is ``max_new_tokens / elapsed`` (prompt tokens are not counted).
    """

    def _sync():
        # CUDA work is asynchronous; wait for it so the wall-clock reading is honest
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    _sync()
    t0 = time.perf_counter()
    result = fn(model, idx, max_new_tokens, **kwargs)
    _sync()
    elapsed = time.perf_counter() - t0
    print(f"{label}: {elapsed:.3f}s, {max_new_tokens / elapsed:.1f} tokens/sec")
    return result

In [None]:
prompt = "The meaning of life is"
idx = torch.tensor([encode(prompt)], dtype=torch.long, device=device)

# Re-seed with the same value before each run so both generators consume the
# same random stream; given identical logits they then pick identical tokens.
sample_seed = 123

print("--- Timing comparison ---")

torch.manual_seed(sample_seed)
out_naive = time_generation(
    "naive", generate_naive, model, idx.clone(), GEN_TOKENS, temperature=0.9, top_k=40
)

torch.manual_seed(sample_seed)
out_cached = time_generation(
    "cached", generate_cached, model, idx.clone(), GEN_TOKENS, temperature=0.9, top_k=40
)

print("\n--- Naive output ---")
print(decode(out_naive[0].tolist()))
print("\n--- Cached output ---")
print(decode(out_cached[0].tolist()))

## Scaling notes for production inference

A few real-world considerations when serving large models:
- **Batching**: group multiple requests so attention uses larger matrix multiplies.
- **Paged KV cache**: store K/V in blocks so you can evict or swap old tokens efficiently.
- **Quantized KV**: store K/V in lower precision (e.g., FP8 or INT8) to save memory bandwidth.
- **Speculative decoding**: draft tokens with a small model, then verify with a large model.
- **Streaming + early exit**: send tokens to users as soon as they are ready.


## Exercises

1. Increase `BLOCK_SIZE` to 256, and raise `GEN_TOKENS` so that prompt plus generated tokens actually fill the longer context. See how the speed gap changes.
2. Try greedy decoding (`torch.argmax`) instead of sampling and compare output consistency.
3. Measure memory usage with and without KV cache (hint: `torch.cuda.memory_allocated`).
4. Add a maximum prompt length and show how cache cropping affects output.
5. Implement top-p (nucleus) sampling and compare with top-k.
