In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and set random seeds

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Autoregressive Generation and the KV Cache

*Part 1 of the Vizuara series on Inference & Scaling*
*Estimated time: 55 minutes*

## 1. Why Does This Matter?

You have trained a language model. The weights are learned, the loss has converged. Now you need to actually *use* it -- generate text, one token at a time.

But here is the problem: naive autoregressive generation is shockingly wasteful. Every time you generate a new token, the model re-processes the *entire* sequence from scratch. For a 1000-token prompt generating 100 new tokens, you end up processing over 100,000 tokens total -- for just 100 outputs.

The **KV cache** is arguably the single most important optimization in LLM inference. It removes the redundant recomputation of past tokens, and production LLM serving stacks rely on it as a matter of course.

In this notebook, we will:
1. Build naive autoregressive generation and **measure** the waste
2. Derive **why** the KV cache works from the attention equations
3. Implement a KV cache from scratch and see the speedup firsthand
4. Analyze the **memory tradeoff** -- because nothing is free

By the end, you will have a working KV cache implementation and a solid intuition for why token-by-token decoding is typically limited by memory rather than by compute.

In [None]:
# Setup -- run this cell first
!pip install -q torch matplotlib numpy

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import time

%matplotlib inline

torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Building Intuition

Let us start with a concrete analogy. Imagine you are writing a book review by hand. You write the first sentence. Then, to write the second sentence, you re-read the entire first sentence from the beginning. To write the third sentence, you re-read sentences one and two from the beginning. By paragraph ten, you are re-reading nine paragraphs just to write one new sentence.

That is insane. No human does this. You *remember* what you wrote and just continue from where you left off.

Autoregressive generation without a KV cache is exactly this insane re-reading process. The KV cache is the model's "memory" of what it has already processed.

Let us see this concretely. Here is what autoregressive generation looks like step by step:

**Step 1:** Input = "The capital of" --> Model processes 3 tokens --> Predicts "France"

**Step 2:** Input = "The capital of France" --> Model processes 4 tokens --> Predicts "is"

**Step 3:** Input = "The capital of France is" --> Model processes 5 tokens --> Predicts "Paris"

At step 3, the model is re-computing the representations for "The", "capital", and "of" -- even though those computations are *identical* to what it did at step 1. Under causal attention, a token's Key and Value depend only on itself and earlier tokens, never on future tokens. So these values cannot change.

### Think About This

If we generate $T$ tokens from a prompt of length $n$, how many total tokens get processed without caching? As in the example above, the first step processes the $n$ prompt tokens, the next step $n+1$, and so on up to $n+T-1$. The total is:

$$\text{Total} = \sum_{t=0}^{T-1}(n+t) = nT + \frac{T(T-1)}{2}$$

For $n=512$ and $T=256$: that is $512 \times 256 + \frac{256 \times 255}{2} = 131{,}072 + 32{,}640 = 163{,}712$ tokens processed for just 256 outputs. Roughly 640 tokens processed per token generated.
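
The count is easy to check by simulating the loop instead of trusting a closed form. The sketch below counts the way the three-step example does (the first pass sees only the prompt); the function names are hypothetical helpers, not part of the notebook's model code.

```python
def naive_tokens_processed(n_prompt: int, n_new: int) -> int:
    """Total tokens pushed through the model when every step re-reads the whole sequence."""
    total = 0
    seq_len = n_prompt
    for _ in range(n_new):
        total += seq_len   # this step re-processes everything seen so far
        seq_len += 1       # the predicted token is appended to the input
    return total


def cached_tokens_processed(n_prompt: int, n_new: int) -> int:
    """With a KV cache: one pass over the prompt, then one token per remaining step."""
    return n_prompt + (n_new - 1)


# The three-step "The capital of ..." example: 3 + 4 + 5 tokens
print(naive_tokens_processed(3, 3))    # 12
print(cached_tokens_processed(3, 3))   # 5

# The 1000-token prompt with 100 new tokens from the introduction
print(naive_tokens_processed(1000, 100))   # 104950
print(cached_tokens_processed(1000, 100))  # 1099
```

Token counts understate the gap in attention cost, since attention over a longer sequence is also more expensive per token, but they already show the shape of the problem.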

## 3. The Mathematics

### Self-Attention Recap

In a single attention head, each token $i$ produces three vectors from its hidden state $h_i$:

$$Q_i = W_Q h_i, \quad K_i = W_K h_i, \quad V_i = W_V h_i$$

The attention output for token $i$ is:

$$\text{Attn}(i) = \sum_{j \leq i} \frac{\exp(Q_i \cdot K_j / \sqrt{d_k})}{\sum_{m \leq i} \exp(Q_i \cdot K_m / \sqrt{d_k})} \cdot V_j$$

The critical observation: the Key $K_j$ and Value $V_j$ for token $j$ depend only on $h_j$ and the weight matrices $W_K, W_V$. Under causal attention ($j \leq i$), $h_j$ in every layer is itself built only from tokens $1, \ldots, j$, so $K_j$ and $V_j$ do not depend on any future token. Once computed, they never change.
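
This invariance can be checked numerically. The following is a minimal NumPy sketch of one causal attention head, assuming random weights and a row-vector convention (`h @ W` rather than $W h$); the names `causal_attention`, `W_Q`, `W_K`, `W_V` are illustrative and separate from the PyTorch model built later.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_k = 8, 4
W_Q = rng.normal(size=(d_model, d_k))
W_K = rng.normal(size=(d_model, d_k))
W_V = rng.normal(size=(d_model, d_k))


def causal_attention(h):
    """Single-head causal self-attention over hidden states h of shape (t, d_model)."""
    Q, K, V = h @ W_Q, h @ W_K, h @ W_V
    scores = Q @ K.T / np.sqrt(d_k)
    t = h.shape[0]
    future = np.triu(np.ones((t, t), dtype=bool), k=1)  # True where key j > query i
    scores = np.where(future, -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V, K, V


h = rng.normal(size=(5, d_model))          # hidden states for 5 tokens
out_3, K_3, V_3 = causal_attention(h[:3])  # only the first 3 tokens
out_5, K_5, V_5 = causal_attention(h)      # all 5 tokens

print(np.allclose(K_3, K_5[:3]))      # True
print(np.allclose(V_3, V_5[:3]))      # True
print(np.allclose(out_3, out_5[:3]))  # True
```

The last check matters for deep models: because the attention outputs of the first three tokens are unchanged by later tokens, the hidden states fed to the next layer are unchanged too, so the same argument applies layer by layer.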

### What the KV Cache Stores

At generation step $t$, we have already computed $K_1, K_2, \ldots, K_{t-1}$ and $V_1, V_2, \ldots, V_{t-1}$ from all previous steps. The KV cache stores these vectors.

For the new token at position $t$, we compute:
1. $Q_t, K_t, V_t$ from the new token's hidden state
2. Append $K_t$ to the cached keys: $[K_1, \ldots, K_{t-1}, K_t]$
3. Append $V_t$ to the cached values: $[V_1, \ldots, V_{t-1}, V_t]$
4. Compute attention: $Q_t$ attends to all cached keys

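The computation per step, per layer, goes from $O(t \cdot d^2 + t^2 \cdot d)$ (recomputing every projection and the full attention matrix) to $O(d^2 + t \cdot d)$ (projections for one token plus its dot products against the $t$ cached keys across all heads), where $d$ is the model dimension.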
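
The four steps above can be sketched for a single head in NumPy and checked against a full recompute. This is a toy illustration under assumed random weights and a row-vector convention; `decode_step` and `full_recompute_last` are hypothetical names, and the multi-head PyTorch version appears in Section 4.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_k = 8, 4
W_Q = rng.normal(size=(d_model, d_k))
W_K = rng.normal(size=(d_model, d_k))
W_V = rng.normal(size=(d_model, d_k))


def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def decode_step(h_t, K_cache, V_cache):
    """One cached step: project only the new token, append, attend over all keys."""
    q_t, k_t, v_t = h_t @ W_Q, h_t @ W_K, h_t @ W_V    # step 1
    K_cache = np.vstack([K_cache, k_t])                 # step 2
    V_cache = np.vstack([V_cache, v_t])                 # step 3
    weights = softmax(q_t @ K_cache.T / np.sqrt(d_k))   # step 4
    return weights @ V_cache, K_cache, V_cache


def full_recompute_last(h):
    """Reference: redo every projection, return the output at the last position."""
    Q, K, V = h @ W_Q, h @ W_K, h @ W_V
    weights = softmax(Q[-1] @ K.T / np.sqrt(d_k))  # the last query may see all keys
    return weights @ V


h = rng.normal(size=(6, d_model))
K_cache = np.empty((0, d_k))
V_cache = np.empty((0, d_k))
for t in range(h.shape[0]):
    out_cached, K_cache, V_cache = decode_step(h[t], K_cache, V_cache)
    # The cached result matches recomputing everything from scratch
    assert np.allclose(out_cached, full_recompute_last(h[: t + 1]))

print(K_cache.shape, V_cache.shape)  # (6, 4) (6, 4)
```

No causal mask is needed inside `decode_step`: the new token is the last position, so every cached key is one it is allowed to see.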

### Memory Cost

For a model with $L$ layers, $h$ heads, head dimension $d_k$, and sequence length $t$, the cache for a single sequence occupies:

$$\text{KV cache memory} = 2 \times L \times t \times h \times d_k \times \text{bytes per element}$$

The factor of 2 accounts for both K and V. This is the tradeoff: we save enormous compute but consume GPU memory proportional to the sequence length (and, when serving batches, to the number of sequences held at once).
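
As a worked instance of the formula, here is a small helper. The configuration (32 layers, 32 heads of dimension 128, FP16) is a hypothetical example chosen for round numbers, and the `batch_size` argument is an added assumption that simply multiplies the per-sequence cost.

```python
def kv_cache_bytes(n_layers, seq_len, n_heads, d_head, bytes_per_elem=2, batch_size=1):
    """KV cache size in bytes: 2 (K and V) x layers x tokens x heads x head dim x bytes."""
    return 2 * n_layers * seq_len * n_heads * d_head * bytes_per_elem * batch_size


# Hypothetical configuration: 32 layers, 32 heads of dimension 128, FP16 (2 bytes)
per_token = kv_cache_bytes(n_layers=32, seq_len=1, n_heads=32, d_head=128)
at_4096 = kv_cache_bytes(n_layers=32, seq_len=4096, n_heads=32, d_head=128)

print(per_token / 1024**2)  # 0.5  (MiB per token)
print(at_4096 / 1024**3)    # 2.0  (GiB for one 4096-token sequence)
```

Half a mebibyte per token sounds small until it is multiplied by thousands of tokens and many concurrent sequences.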

## 4. Let's Build It -- Component by Component

### 4.1 A Minimal Transformer Layer

We will first build multi-head attention with KV cache support, then stack it into a small multi-layer transformer.

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention with optional KV cache support."""

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, kv_cache=None):
        """
        Args:
            x: (batch, seq_len, d_model) -- full sequence or single new token
            kv_cache: tuple of (cached_K, cached_V) or None
        Returns:
            output: (batch, seq_len, d_model)
            new_kv_cache: tuple of (K, V) including new tokens
        """
        B, T, _ = x.shape

        # Project Q, K, V for new tokens
        Q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        # Q, K, V shape: (B, n_heads, T, d_k)

        # If we have a cache, concatenate previous K, V
        if kv_cache is not None:
            K_prev, V_prev = kv_cache
            K = torch.cat([K_prev, K], dim=2)  # (B, n_heads, T_prev + T, d_k)
            V = torch.cat([V_prev, V], dim=2)

        # Store the updated cache
        new_kv_cache = (K, V)

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)

        # Causal mask: only attend to current and previous positions
        # Q has T positions, K has T_total positions
        T_total = K.shape[2]
        # Each query position can attend to keys at positions <= its absolute position
        # Absolute positions of queries: T_total - T, ..., T_total - 1
        # A query at absolute position p can attend to keys at positions 0..p
        causal_mask = torch.triu(
            torch.ones(T, T_total, device=x.device, dtype=torch.bool),
            diagonal=T_total - T + 1
        )
        scores = scores.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)

        # Reshape and project
        output = output.transpose(1, 2).contiguous().view(B, T, self.d_model)
        output = self.W_o(output)

        return output, new_kv_cache


# Quick test
d_model, n_heads = 64, 4
attn = MultiHeadAttention(d_model, n_heads).to(device)

x = torch.randn(1, 5, d_model, device=device)
out, cache = attn(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")
print(f"Cache K shape: {cache[0].shape}")
print(f"Cache V shape: {cache[1].shape}")

### 4.2 Building the Full Transformer Block

In [None]:
class TransformerBlock(nn.Module):
    """A single transformer block: attention + FFN + layer norms."""

    def __init__(self, d_model, n_heads, d_ff=None):
        super().__init__()
        d_ff = d_ff or 4 * d_model
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x, kv_cache=None):
        # Pre-norm architecture (used by GPT-2, LLaMA, etc.)
        normed = self.ln1(x)
        attn_out, new_cache = self.attn(normed, kv_cache=kv_cache)
        x = x + attn_out
        x = x + self.ffn(self.ln2(x))
        return x, new_cache


class MiniGPT(nn.Module):
    """A minimal GPT-style model for demonstrating KV cache."""

    def __init__(self, vocab_size, d_model, n_heads, n_layers, max_seq_len=512):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads) for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, tokens, kv_caches=None, start_pos=0):
        """
        Args:
            tokens: (B, T) token indices
            kv_caches: list of KV caches, one per layer, or None
            start_pos: position offset for positional embeddings
        Returns:
            logits: (B, T, vocab_size)
            new_kv_caches: list of updated KV caches
        """
        B, T = tokens.shape
        positions = torch.arange(start_pos, start_pos + T, device=tokens.device)

        x = self.token_emb(tokens) + self.pos_emb(positions)

        new_caches = []
        for i, layer in enumerate(self.layers):
            cache = kv_caches[i] if kv_caches is not None else None
            x, new_cache = layer(x, kv_cache=cache)
            new_caches.append(new_cache)

        x = self.ln_f(x)
        logits = self.head(x)
        return logits, new_caches


# Create a small model
vocab_size = 256
d_model = 128
n_heads = 4
n_layers = 4

model = MiniGPT(vocab_size, d_model, n_heads, n_layers).to(device)
n_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {n_params:,}")

### 4.3 Naive Generation (No KV Cache)

In [None]:
@torch.no_grad()
def generate_naive(model, prompt_tokens, max_new_tokens=50):
    """
    Generate tokens WITHOUT KV cache.
    At each step, re-process the entire sequence from scratch.
    """
    tokens = prompt_tokens.clone()
    for _ in range(max_new_tokens):
        logits, _ = model(tokens, kv_caches=None, start_pos=0)
        next_logit = logits[:, -1, :]  # Only care about the last position
        next_token = torch.argmax(next_logit, dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)
    return tokens


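# Time the naive approach
prompt = torch.randint(0, vocab_size, (1, 32), device=device)

# CUDA kernels run asynchronously, so synchronize before reading the clock
if device.type == 'cuda':
    torch.cuda.synchronize()
start = time.perf_counter()
output_naive = generate_naive(model, prompt, max_new_tokens=100)
if device.type == 'cuda':
    torch.cuda.synchronize()
naive_time = time.perf_counter() - start

print(f"Naive generation: {naive_time:.3f}s for 100 tokens")
print(f"Output length: {output_naive.shape[1]} tokens")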

### 4.4 KV-Cached Generation

In [None]:
@torch.no_grad()
def generate_with_cache(model, prompt_tokens, max_new_tokens=50):
    """
    Generate tokens WITH KV cache.
    First pass: process full prompt, populate cache.
    Subsequent passes: process only the new token.
    """
    # Prefill: process the entire prompt
    logits, kv_caches = model(prompt_tokens, kv_caches=None, start_pos=0)
    next_logit = logits[:, -1, :]
    next_token = torch.argmax(next_logit, dim=-1, keepdim=True)

    all_tokens = [prompt_tokens, next_token]
    cur_pos = prompt_tokens.shape[1]

    for _ in range(max_new_tokens - 1):
        # Only feed the single new token
        logits, kv_caches = model(next_token, kv_caches=kv_caches, start_pos=cur_pos)
        cur_pos += 1
        next_logit = logits[:, -1, :]
        next_token = torch.argmax(next_logit, dim=-1, keepdim=True)
        all_tokens.append(next_token)

    return torch.cat(all_tokens, dim=1)


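# Time the cached approach (synchronize so pending CUDA work is counted)
if device.type == 'cuda':
    torch.cuda.synchronize()
start = time.perf_counter()
output_cached = generate_with_cache(model, prompt, max_new_tokens=100)
if device.type == 'cuda':
    torch.cuda.synchronize()
cached_time = time.perf_counter() - start

print(f"Cached generation: {cached_time:.3f}s for 100 tokens")
print(f"Speedup: {naive_time / cached_time:.1f}x")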

## 5. Your Turn

### TODO 1: Verify Output Equivalence

The KV cache is an *optimization*, not an approximation. The outputs should be **exactly identical** to naive generation. Your task: verify this.

In [None]:
# TODO: Compare the outputs from naive and cached generation.
# They should be identical (same tokens at every position).
# Hint: use torch.equal() or compare element-wise.

# YOUR CODE HERE
# are_equal = ...
# print(f"Outputs are identical: {are_equal}")

# If they are NOT identical, there is a bug in the implementation.
# Check the causal masking and position offset logic.

### TODO 2: Measure the Speedup Curve

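How does the speedup change as we increase the number of generated tokens? Generate 10, 50, 100, and 200 tokens with and without cache, and plot the speedup factor. If you try longer runs, keep prompt length plus generated tokens within the model's `max_seq_len` (512 here), or the positional embedding lookup will fail.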

In [None]:
# TODO: Create a plot showing speedup vs. number of generated tokens.
# Expected behavior: speedup increases with more tokens because
# the naive approach wastes more and more computation.

# YOUR CODE HERE
# generation_lengths = [10, 50, 100, 200]
# naive_times = []
# cached_times = []
#
# for n_tokens in generation_lengths:
#     # Time naive generation
#     ...
#     # Time cached generation
#     ...
#
# speedups = [n / c for n, c in zip(naive_times, cached_times)]
#
# plt.figure(figsize=(8, 5))
# plt.plot(generation_lengths, speedups, 'bo-', linewidth=2, markersize=8)
# plt.xlabel('Number of generated tokens')
# plt.ylabel('Speedup (naive / cached)')
# plt.title('KV Cache Speedup vs. Generation Length')
# plt.grid(True, alpha=0.3)
# plt.tight_layout()
# plt.show()

## 6. Putting It All Together

Let us now visualize the compute and memory tradeoffs to build deeper intuition.

In [None]:
# Compute analysis: rough operation counts for naive vs cached
def count_operations(n_prompt, n_generate, d_model, n_layers):
    """Estimate relative computation for naive vs cached generation.

    Counts only the Q/K/V projections and the attention scores; the output
    projection and FFN are ignored, so treat the result as a relative estimate.
    """
    # Naive: step t re-processes the n_prompt + t tokens seen so far (t = 0 .. n_generate - 1)
    naive_ops = 0
    for t in range(n_generate):
        seq_len = n_prompt + t
        # Projection cost: 3 * seq_len * d_model^2 per layer
        # Attention cost: seq_len^2 * d_model per layer
        naive_ops += n_layers * (3 * seq_len * d_model**2 + seq_len**2 * d_model)

    # Cached: prefill + per-token decode
    # Prefill: process n_prompt tokens once (this already yields the first new token)
    prefill_ops = n_layers * (3 * n_prompt * d_model**2 + n_prompt**2 * d_model)
    # Decode: process 1 token at each of the remaining n_generate - 1 steps
    decode_ops = 0
    for t in range(1, n_generate):
        seq_len = n_prompt + t  # keys visible to the new token, itself included
        # Projection: 3 * 1 * d_model^2 per layer
        # Attention: 1 * seq_len * d_model per layer (Q dot all K's)
        decode_ops += n_layers * (3 * d_model**2 + seq_len * d_model)
    cached_ops = prefill_ops + decode_ops

    return naive_ops, cached_ops

# Calculate for different prompt lengths
prompt_lengths = [32, 64, 128, 256, 512]
n_generate = 100

naive_ops_list = []
cached_ops_list = []

for n_prompt in prompt_lengths:
    naive, cached = count_operations(n_prompt, n_generate, d_model=128, n_layers=4)
    naive_ops_list.append(naive)
    cached_ops_list.append(cached)

# Plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Bar chart: compute comparison
x = np.arange(len(prompt_lengths))
width = 0.35
bars1 = ax1.bar(x - width/2, [n/1e9 for n in naive_ops_list], width, label='Naive', color='#e74c3c', alpha=0.8)
bars2 = ax1.bar(x + width/2, [c/1e9 for c in cached_ops_list], width, label='KV Cache', color='#2ecc71', alpha=0.8)
ax1.set_xlabel('Prompt Length')
ax1.set_ylabel('Operations (billions)')
ax1.set_title('Compute: Naive vs KV Cache')
ax1.set_xticks(x)
ax1.set_xticklabels(prompt_lengths)
ax1.legend()
ax1.set_yscale('log')

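# Memory analysis (single sequence, batch size 1)
# KV cache memory = 2 * n_layers * seq_len * d_model * bytes_per_element,
# where d_model = n_heads * d_k; values below are divided by 1024**3 (GiB)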
d_model_real = 4096  # Realistic model
n_layers_real = 32
bytes_per_element = 2  # FP16

seq_lengths = np.arange(1, 8193)
kv_memory_gb = (2 * n_layers_real * seq_lengths * d_model_real * bytes_per_element) / (1024**3)

ax2.plot(seq_lengths, kv_memory_gb, 'b-', linewidth=2)
ax2.axhline(y=16, color='r', linestyle='--', alpha=0.7, label='16 GB (T4 GPU)')
ax2.axhline(y=80, color='orange', linestyle='--', alpha=0.7, label='80 GB (A100 GPU)')
ax2.set_xlabel('Sequence Length (tokens)')
ax2.set_ylabel('KV Cache Memory (GB)')
ax2.set_title('KV Cache Memory vs Sequence Length\n(32-layer, d=4096 model, FP16)')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nKV cache at 2048 tokens: {kv_memory_gb[2047]:.2f} GB")
print(f"KV cache at 4096 tokens: {kv_memory_gb[4095]:.2f} GB")
print(f"KV cache at 8192 tokens: {kv_memory_gb[8191]:.2f} GB")

## 7. Training and Results

Since we are focused on *inference* rather than training in this notebook, let us quantify the empirical performance of our KV cache implementation.

In [None]:
# Benchmark: measure tokens per second for different configurations
configs = [
    {"prompt_len": 32, "gen_len": 50},
    {"prompt_len": 128, "gen_len": 100},
    {"prompt_len": 256, "gen_len": 200},
]

print("=" * 70)
print(f"{'Config':<25} {'Naive (tok/s)':<15} {'Cached (tok/s)':<15} {'Speedup':<10}")
print("=" * 70)

for cfg in configs:
    prompt = torch.randint(0, vocab_size, (1, cfg["prompt_len"]), device=device)

    # Warm up
    _ = generate_naive(model, prompt, max_new_tokens=5)
    _ = generate_with_cache(model, prompt, max_new_tokens=5)

    # Benchmark naive (synchronize so pending CUDA work is included in the timing)
    if device.type == 'cuda':
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(3):
        _ = generate_naive(model, prompt, max_new_tokens=cfg["gen_len"])
    if device.type == 'cuda':
        torch.cuda.synchronize()
    naive_t = (time.perf_counter() - start) / 3
    naive_tps = cfg["gen_len"] / naive_t

    # Benchmark cached
    if device.type == 'cuda':
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(3):
        _ = generate_with_cache(model, prompt, max_new_tokens=cfg["gen_len"])
    if device.type == 'cuda':
        torch.cuda.synchronize()
    cached_t = (time.perf_counter() - start) / 3
    cached_tps = cfg["gen_len"] / cached_t

    label = f"prompt={cfg['prompt_len']}, gen={cfg['gen_len']}"
    speedup = f"{cached_tps / naive_tps:.1f}x"  # format first so the 'x' stays attached
    print(f"{label:<25} {naive_tps:<15.1f} {cached_tps:<15.1f} {speedup:<10}")

print("=" * 70)

## 8. Final Output

Let us generate some actual text to see our complete system in action. We will use a character-level model on a tiny corpus so we can train it quickly.

In [None]:
# Train a tiny character-level model for demonstration
text = """The quick brown fox jumps over the lazy dog. The dog barked at the fox.
The fox ran away quickly. The lazy dog went back to sleep. The brown fox was clever.
The dog woke up and chased the fox again. The fox jumped over the fence."""

# Character-level tokenization
chars = sorted(set(text))
char_to_idx = {c: i for i, c in enumerate(chars)}
idx_to_char = {i: c for c, i in char_to_idx.items()}
v_size = len(chars)

# Create training data
data = torch.tensor([char_to_idx[c] for c in text], dtype=torch.long, device=device)

# Train the model
small_model = MiniGPT(v_size, d_model=64, n_heads=4, n_layers=2, max_seq_len=256).to(device)
optimizer = torch.optim.Adam(small_model.parameters(), lr=3e-3)

# Simple training loop
small_model.train()
seq_len = 32

for epoch in range(200):
    total_loss = 0
    n_batches = 0
    for i in range(0, len(data) - seq_len - 1, seq_len):
        x = data[i:i+seq_len].unsqueeze(0)
        y = data[i+1:i+seq_len+1].unsqueeze(0)
        logits, _ = small_model(x, start_pos=0)
        loss = F.cross_entropy(logits.view(-1, v_size), y.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
    if (epoch + 1) % 50 == 0:
        print(f"Epoch {epoch+1}: loss = {total_loss/n_batches:.4f}")

# Generate text with KV cache
small_model.eval()
prompt_text = "The quick"
prompt_ids = torch.tensor([[char_to_idx[c] for c in prompt_text]], device=device)

generated = generate_with_cache(small_model, prompt_ids, max_new_tokens=100)
generated_text = ''.join([idx_to_char[t.item()] for t in generated[0]])
print(f"\nGenerated text (KV cached):\n{generated_text}")

## 9. Reflection and Next Steps

### What We Learned

1. **Autoregressive generation** is inherently sequential: each token depends on all previous tokens. Without optimization, this leads to massive redundant computation.

2. **The KV cache** exploits the fact that under causal attention, past tokens' Keys and Values never change. By caching them, we reduce per-step computation from processing the entire sequence to processing a single token.

3. **The tradeoff is memory**: the KV cache grows linearly with sequence length and model depth. For large models (70B+ parameters) serving many concurrent users, KV cache memory management becomes the primary bottleneck.

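4. **The speedup grows with length**: naive generation redoes more work at every step, so the gap to cached generation widens as the sequence gets longer. On a model this small, the measured speedup also depends on the hardware and on per-call overhead.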
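
To make point 3 concrete, here is a back-of-the-envelope sketch of how many full-length sequences fit next to the model weights. Every number is an illustrative assumption (an 80 GB accelerator, roughly 14 GB of FP16 weights for a 7B-parameter model, and the 32-layer, $d=4096$ cache size from Section 6), and the hypothetical `max_concurrent_sequences` helper ignores activations and memory fragmentation.

```python
def max_concurrent_sequences(gpu_bytes, weight_bytes, kv_bytes_per_token, seq_len):
    """How many full-length sequences fit in the memory left after the weights."""
    free = gpu_bytes - weight_bytes
    if free <= 0:
        return 0
    return int(free // (kv_bytes_per_token * seq_len))


# Illustrative numbers only (assumptions, not measurements)
gpu_bytes = 80 * 10**9                   # an 80 GB accelerator
weight_bytes = 14 * 10**9                # roughly a 7B-parameter model in FP16
kv_bytes_per_token = 2 * 32 * 4096 * 2   # 2 (K, V) x 32 layers x d_model 4096 x 2 bytes

for seq_len in (2048, 4096, 8192):
    n_seqs = max_concurrent_sequences(gpu_bytes, weight_bytes, kv_bytes_per_token, seq_len)
    print(seq_len, n_seqs)
# 2048 61
# 4096 30
# 8192 15
```

Doubling the context length halves the number of sequences that can be served at once, which is why cache memory, not arithmetic, tends to set the serving limit.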

### Key Takeaways

| Metric | Without KV Cache | With KV Cache |
|--------|-----------------|---------------|
| Tokens processed per step | Entire sequence | 1 token |
| Computation scaling | $O(T^2)$ total | $O(T)$ total (for projections) |
| Memory overhead | Minimal | $O(L \times T \times d)$ |

### What is Next

In the next notebook, we will tackle the other half of the generation problem: **sampling strategies**. The KV cache tells us *how* to generate efficiently. Sampling strategies tell us *what* to generate -- how to choose the next token from the probability distribution in a way that produces coherent, diverse, and high-quality text.

### Further Reading

- Vaswani et al., "Attention Is All You Need" (2017) -- the original transformer and attention mechanism
- Pope et al., "Efficiently Scaling Transformer Inference" (2022) -- partitioning strategies and multi-query attention for large-scale inference
- Kwon et al., "Efficient Memory Management for Large Language Model Serving with PagedAttention" (2023) -- vLLM and PagedAttention