# Attention Mechanism from First Principles

**Inference Engineering Series - Notebook 4**

---

The attention mechanism is the beating heart of all transformer-based models. It's what allows a token to "look at" every other token in the sequence and decide what information to gather. Understanding attention deeply is critical for inference engineering because:

- Attention is often the **most memory-intensive operation** during inference, especially at long context lengths
- It's where the **KV cache** lives (covered in the next notebook)
- Its cost scales **quadratically** with sequence length
- Many inference optimizations target attention specifically (FlashAttention, PagedAttention, etc.)

In this notebook, we'll build attention from absolute scratch -- first in NumPy, then in PyTorch -- and visualize every step.

## What You'll Learn

1. **The intuition behind attention** - queries, keys, and values
2. **Scaled dot-product attention** - the complete formula
3. **Implementing attention in NumPy** - step by step
4. **Visualizing attention weights** - what the model "pays attention to"
5. **Multi-head attention** - why multiple heads matter
6. **Causal masking** - how autoregressive models prevent "cheating"
7. **Compute and memory costs** - arithmetic intensity of attention

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch
import torch.nn as nn
import torch.nn.functional as F

np.random.seed(42)
torch.manual_seed(42)
print(f"PyTorch version: {torch.__version__}")

## Part 1: The Intuition Behind Attention

Think of attention as a **database lookup**:

- **Query (Q)**: "What am I looking for?" - Each token generates a query vector describing what information it needs.
- **Key (K)**: "What do I contain?" - Each token generates a key vector advertising what information it has.
- **Value (V)**: "Here's my information." - Each token generates a value vector containing the actual information to share.

The attention mechanism:
1. Compares each query with all keys (dot product) to get a **similarity score**
2. Normalizes scores with **softmax** to get **attention weights** (probabilities)
3. Uses these weights to take a **weighted sum of values**

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
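To make the database analogy concrete, here is a minimal NumPy sketch with toy numbers (not taken from any model; `toy_attention`, `keys_db` and `values_db` are illustrative names). With one-hot keys and a query that strongly matches key 1, the softmax weights concentrate on that key and the output is close to value 1, which is nearly a hard lookup. A weak query spreads the weight across keys and returns a blend of values.

```python
import numpy as np

def toy_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for 2-D arrays."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

# Three "database entries": one-hot keys, arbitrary 2-D values
keys_db = np.eye(3)
values_db = np.array([[1.0, 0.0],
                      [0.0, 1.0],
                      [5.0, 5.0]])

# A query that strongly matches key 1 behaves almost like a hard lookup
sharp_out, sharp_w = toy_attention(np.array([[0.0, 10.0, 0.0]]), keys_db, values_db)
# A weak query keeps the weights spread out, so the output is a blend
soft_out, soft_w = toy_attention(np.array([[0.0, 1.0, 0.0]]), keys_db, values_db)

print("sharp weights:", sharp_w.round(3), "output:", sharp_out.round(3))
print("soft weights: ", soft_w.round(3), "output:", soft_out.round(3))
```

The only difference from a dictionary lookup is that the match is graded rather than exact, which is what makes the whole operation differentiable.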

In [None]:
# Let's build a concrete example
# Imagine we have a 4-token sequence: "The cat sat down"
# Each token has an 8-dimensional embedding (tiny, for visualization)

tokens = ["The", "cat", "sat", "down"]
seq_len = len(tokens)
d_model = 8  # Embedding dimension
d_k = 4      # Key/Query dimension (often d_model / num_heads)

# Simulated token embeddings (normally these come from the embedding layer)
X = np.random.randn(seq_len, d_model)

print("Input matrix X (each row is a token embedding):")
print(f"Shape: {X.shape} (seq_len={seq_len}, d_model={d_model})")
for i, token in enumerate(tokens):
    print(f"  {token:6s}: {X[i].round(2)}")

## Part 2: Computing Q, K, V Projections

The raw embeddings are projected into Q, K, V spaces using learned weight matrices:

$$Q = X W_Q, \quad K = X W_K, \quad V = X W_V$$
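Why use separate matrices for queries and keys instead of one? A quick NumPy check with random toy weights (the `_demo` names are illustrative only): if queries and keys shared one projection $W$, the score matrix $(XW)(XW)^T$ would always be symmetric, so token $i$ would score token $j$ exactly as $j$ scores $i$. Separate $W_Q$ and $W_K$ allow asymmetric scores.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tok, dim_model, dim_k = 4, 8, 4                 # toy sizes
X_demo = rng.standard_normal((n_tok, dim_model))  # stand-in embeddings

W_shared = rng.standard_normal((dim_model, dim_k))
W_q_demo = rng.standard_normal((dim_model, dim_k))
W_k_demo = rng.standard_normal((dim_model, dim_k))

# Shared projection: scores = (X W)(X W)^T, which is always symmetric
shared_scores = (X_demo @ W_shared) @ (X_demo @ W_shared).T
# Separate projections: scores = (X W_q)(X W_k)^T, generally not symmetric
separate_scores = (X_demo @ W_q_demo) @ (X_demo @ W_k_demo).T

print("shared symmetric:  ", np.allclose(shared_scores, shared_scores.T))
print("separate symmetric:", np.allclose(separate_scores, separate_scores.T))
```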

In [None]:
# Projection matrices (normally these are learned during training)
W_Q = np.random.randn(d_model, d_k) * 0.5
W_K = np.random.randn(d_model, d_k) * 0.5
W_V = np.random.randn(d_model, d_k) * 0.5

# Project into Q, K, V spaces
Q = X @ W_Q  # (seq_len, d_k)
K = X @ W_K  # (seq_len, d_k)
V = X @ W_V  # (seq_len, d_k)

print(f"Q (Queries) shape: {Q.shape}")
print(f"K (Keys) shape:    {K.shape}")
print(f"V (Values) shape:  {V.shape}")

print("\nQ matrix (each row is a query for one token):")
for i, token in enumerate(tokens):
    print(f"  {token:6s}: {Q[i].round(3)}")

In [None]:
# Visualize Q, K, V matrices
fig, axes = plt.subplots(1, 3, figsize=(16, 4))

for ax, matrix, name, color in [(axes[0], Q, 'Queries (Q)', 'Reds'),
                                  (axes[1], K, 'Keys (K)', 'Blues'),
                                  (axes[2], V, 'Values (V)', 'Greens')]:
    im = ax.imshow(matrix, cmap=color, aspect='auto')
    ax.set_xlabel('Dimension')
    ax.set_ylabel('Token')
    ax.set_yticks(range(len(tokens)))
    ax.set_yticklabels(tokens)
    ax.set_title(name, fontsize=13, fontweight='bold')
    plt.colorbar(im, ax=ax, shrink=0.8)
    
    for i in range(matrix.shape[0]):
        for j in range(matrix.shape[1]):
            ax.text(j, i, f'{matrix[i,j]:.2f}', ha='center', va='center', fontsize=8)

plt.suptitle('Q, K, V Projections', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## Part 3: Computing Attention Scores

Step 1: Compute the dot product between each query and all keys:

$$\text{scores} = Q K^T$$

This gives us a `(seq_len, seq_len)` matrix where `scores[i, j]` is the unnormalized score of query `i` against key `j`: the higher it is, the more token `i` will attend to token `j` once softmax is applied.
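The matrix product is a compact way of taking every query-key dot product. A minimal sketch with random toy arrays (`Q_demo` and `K_demo` are illustrative names) confirms that entry `[i, j]` of `Q @ K.T` is the dot product of query row `i` with key row `j`:

```python
import numpy as np

rng = np.random.default_rng(1)
Q_demo = rng.standard_normal((4, 3))  # 4 queries, d_k = 3
K_demo = rng.standard_normal((4, 3))  # 4 keys,    d_k = 3

# Explicit double loop: one dot product per (query i, key j) pair
loop_scores = np.empty((4, 4))
for i in range(4):
    for j in range(4):
        loop_scores[i, j] = np.dot(Q_demo[i], K_demo[j])

matmul_scores = Q_demo @ K_demo.T
print(np.allclose(loop_scores, matmul_scores))  # True
```

The double loop also makes the cost visible: `seq_len * seq_len` dot products, each of length `d_k`.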

In [None]:
# Step 1: Raw attention scores
scores = Q @ K.T  # (seq_len, seq_len)

print(f"Attention scores shape: {scores.shape}")
print(f"\nRaw scores (Q @ K^T):")
print(f"{'':8s}", end="")
for t in tokens:
    print(f"{t:>8s}", end="")
print()
for i, t in enumerate(tokens):
    print(f"{t:8s}", end="")
    for j in range(len(tokens)):
        print(f"{scores[i,j]:8.3f}", end="")
    print()

print(f"\nscores[i,j] = how much token i 'attends to' token j")

In [None]:
# Step 2: Scale by sqrt(d_k) to prevent softmax from becoming too peaky
scale = np.sqrt(d_k)
scaled_scores = scores / scale

print(f"Scaling factor: sqrt({d_k}) = {scale:.2f}")
print(f"\nWhy scale? Without scaling, large d_k means large dot products,")
print(f"which push softmax into saturated regions with near-zero gradients.")
print(f"\nBefore scaling: scores range [{scores.min():.2f}, {scores.max():.2f}]")
print(f"After scaling:  scores range [{scaled_scores.min():.2f}, {scaled_scores.max():.2f}]")
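The claim that large $d_k$ produces large dot products can be checked empirically. Assuming query and key components are independent, zero-mean and unit-variance (an idealization; trained activations are not exactly like this), $q \cdot k$ is a sum of $d_k$ unit-variance terms and therefore has variance $d_k$. Dividing by $\sqrt{d_k}$ brings it back to about 1, whatever the head size:

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 5000
variances = {}

for dim in [4, 64, 256]:
    # Independent standard-normal query and key vectors (idealized assumption)
    q = rng.standard_normal((n_samples, dim))
    k = rng.standard_normal((n_samples, dim))
    dots = np.sum(q * k, axis=1)                  # n_samples dot products
    variances[dim] = (dots.var(), (dots / np.sqrt(dim)).var())
    print(f"d_k={dim:4d}  var(q.k)={variances[dim][0]:8.2f}  "
          f"var(q.k/sqrt(d_k))={variances[dim][1]:.2f}")
```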

In [None]:
# Step 3: Apply softmax to get attention weights (probabilities)
def softmax(x):
    """Numerically stable softmax."""
    e_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e_x / np.sum(e_x, axis=-1, keepdims=True)

attention_weights = softmax(scaled_scores)

print("Attention weights (after softmax):")
print(f"{'':8s}", end="")
for t in tokens:
    print(f"{t:>8s}", end="")
print()
for i, t in enumerate(tokens):
    print(f"{t:8s}", end="")
    for j in range(len(tokens)):
        print(f"{attention_weights[i,j]:8.3f}", end="")
    print(f"  (sum={attention_weights[i].sum():.3f})")

print("\nEach row sums to 1.0 -- it's a probability distribution!")
print("Each row tells us where that token 'looks' in the sequence.")
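Why subtract the row maximum? Softmax is unchanged when the same constant is added to every score in a row, but `np.exp` overflows float64 for inputs above roughly 709.8. A short sketch comparing a naive softmax with the max-subtracted one (both function names are illustrative):

```python
import numpy as np

def naive_softmax(x):
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def stable_softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

big_scores = np.array([1000.0, 1001.0, 1002.0])

with np.errstate(over="ignore", invalid="ignore"):
    naive = naive_softmax(big_scores)   # exp overflows to inf, then inf/inf -> nan
stable = stable_softmax(big_scores)

print("naive: ", naive)
print("stable:", stable.round(4))
# Shifting every score by the same constant does not change the result
print(np.allclose(stable, stable_softmax(big_scores - 1000.0)))  # True
```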

In [None]:
# Visualize attention weights as a heatmap
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Raw scores
ax = axes[0]
im = ax.imshow(scores, cmap='YlOrRd')
ax.set_xticks(range(len(tokens)))
ax.set_yticks(range(len(tokens)))
ax.set_xticklabels(tokens)
ax.set_yticklabels(tokens)
ax.set_title('Raw Scores (Q @ K^T)', fontsize=13, fontweight='bold')
ax.set_xlabel('Key (attending to)')
ax.set_ylabel('Query (from)')
for i in range(len(tokens)):
    for j in range(len(tokens)):
        ax.text(j, i, f'{scores[i,j]:.2f}', ha='center', va='center', fontsize=10)
plt.colorbar(im, ax=ax, shrink=0.8)

# Scaled scores
ax = axes[1]
im = ax.imshow(scaled_scores, cmap='YlOrRd')
ax.set_xticks(range(len(tokens)))
ax.set_yticks(range(len(tokens)))
ax.set_xticklabels(tokens)
ax.set_yticklabels(tokens)
ax.set_title(f'Scaled Scores (/ sqrt({d_k}))', fontsize=13, fontweight='bold')
ax.set_xlabel('Key (attending to)')
ax.set_ylabel('Query (from)')
for i in range(len(tokens)):
    for j in range(len(tokens)):
        ax.text(j, i, f'{scaled_scores[i,j]:.2f}', ha='center', va='center', fontsize=10)
plt.colorbar(im, ax=ax, shrink=0.8)

# Attention weights
ax = axes[2]
im = ax.imshow(attention_weights, cmap='YlOrRd', vmin=0, vmax=1)
ax.set_xticks(range(len(tokens)))
ax.set_yticks(range(len(tokens)))
ax.set_xticklabels(tokens)
ax.set_yticklabels(tokens)
ax.set_title('Attention Weights (softmax)', fontsize=13, fontweight='bold')
ax.set_xlabel('Key (attending to)')
ax.set_ylabel('Query (from)')
for i in range(len(tokens)):
    for j in range(len(tokens)):
        ax.text(j, i, f'{attention_weights[i,j]:.2f}', ha='center', va='center', fontsize=10)
plt.colorbar(im, ax=ax, shrink=0.8)

plt.suptitle('From Raw Scores to Attention Weights', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Step 4: Compute the output as weighted sum of values
output = attention_weights @ V  # (seq_len, d_k)

print("Output = Attention_Weights @ V")
print(f"Shape: {output.shape}")
print("\nEach output token is a weighted combination of all value vectors:")
for i, token in enumerate(tokens):
    contributions = [f"{attention_weights[i,j]:.2f}*V({tokens[j]})" for j in range(len(tokens))]
    print(f"  output({token}) = {' + '.join(contributions)}")
    print(f"            = {output[i].round(3)}")

## Part 4: Complete Scaled Dot-Product Attention

Let's wrap this into a clean function.

In [None]:
def scaled_dot_product_attention_numpy(Q, K, V, mask=None):
    """
    Scaled dot-product attention in NumPy.
    
    Args:
        Q: Queries (seq_len, d_k)
        K: Keys (seq_len, d_k)
        V: Values (seq_len, d_v)
        mask: Optional mask (seq_len, seq_len)
    
    Returns:
        output: (seq_len, d_v)
        attention_weights: (seq_len, seq_len)
    """
    d_k = Q.shape[-1]
    
    # Step 1: Compute scores
    scores = Q @ K.T / np.sqrt(d_k)
    
    # Step 2: Apply mask (if any)
    if mask is not None:
        scores = np.where(mask, scores, -1e9)
    
    # Step 3: Softmax
    attention_weights = softmax(scores)
    
    # Step 4: Weighted sum of values
    output = attention_weights @ V
    
    return output, attention_weights

# Test it
output_np, weights_np = scaled_dot_product_attention_numpy(Q, K, V)
print(f"Output shape: {output_np.shape}")
print(f"Weights shape: {weights_np.shape}")
print(f"Weights sum per row: {weights_np.sum(axis=-1).round(3)}")

## Part 5: Causal Masking for Autoregressive Models

In autoregressive models (GPT, Llama, etc.), each token can only attend to tokens **before it** (including itself). This prevents the model from "seeing the future" during both training and inference.

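We achieve this by masking out all positions where `j > i` in the attention score matrix: their scores are set to $-\infty$, so softmax assigns them exactly zero weight. (The NumPy function above uses a large negative constant, `-1e9`, as a finite stand-in for $-\infty$.)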
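An equivalent formulation, common in practice, is an additive mask: 0 where attention is allowed and $-\infty$ strictly above the diagonal, added to the scores before softmax. Since $e^{-\infty} = 0$, the future positions receive exactly zero weight. A minimal NumPy sketch on random toy scores:

```python
import numpy as np

n = 4
rng = np.random.default_rng(0)
toy_scores = rng.standard_normal((n, n))

# 0 on and below the diagonal, -inf strictly above it (the "future")
additive_mask = np.triu(np.full((n, n), -np.inf), k=1)
masked = toy_scores + additive_mask

# Stable softmax; each row max is finite because the diagonal is never masked
e = np.exp(masked - masked.max(axis=-1, keepdims=True))
causal_weights = e / e.sum(axis=-1, keepdims=True)

print(causal_weights.round(3))
print("upper triangle all zero:", np.all(np.triu(causal_weights, k=1) == 0))
print("first row:", causal_weights[0])  # [1. 0. 0. 0.]
```

The first query can see only itself, so its whole probability mass lands on position 0 regardless of the scores.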

In [None]:
# Create a causal mask
causal_mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))

print("Causal mask (True = can attend, False = masked):")
print(f"{'':8s}", end="")
for t in tokens:
    print(f"{t:>8s}", end="")
print()
for i, t in enumerate(tokens):
    print(f"{t:8s}", end="")
    for j in range(len(tokens)):
        symbol = "  YES " if causal_mask[i,j] else "   -- "
        print(f"{symbol:>8s}", end="")
    print()

print("\n'The' can only see itself.")
print("'cat' can see 'The' and itself.")
print("'sat' can see 'The', 'cat', and itself.")
print("'down' can see everything.")

In [None]:
# Compute causal attention
output_causal, weights_causal = scaled_dot_product_attention_numpy(Q, K, V, mask=causal_mask)

# Compare causal vs non-causal attention weights
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Causal mask
ax = axes[0]
im = ax.imshow(causal_mask.astype(float), cmap='RdYlGn', vmin=0, vmax=1)
ax.set_xticks(range(len(tokens)))
ax.set_yticks(range(len(tokens)))
ax.set_xticklabels(tokens)
ax.set_yticklabels(tokens)
ax.set_title('Causal Mask', fontsize=13, fontweight='bold')
ax.set_xlabel('Key position')
ax.set_ylabel('Query position')
for i in range(len(tokens)):
    for j in range(len(tokens)):
        ax.text(j, i, 'OK' if causal_mask[i,j] else 'X', 
               ha='center', va='center', fontsize=11, fontweight='bold')

# Non-causal attention
ax = axes[1]
im = ax.imshow(weights_np, cmap='YlOrRd', vmin=0, vmax=1)
ax.set_xticks(range(len(tokens)))
ax.set_yticks(range(len(tokens)))
ax.set_xticklabels(tokens)
ax.set_yticklabels(tokens)
ax.set_title('Bidirectional Attention', fontsize=13, fontweight='bold')
ax.set_xlabel('Key (attending to)')
ax.set_ylabel('Query (from)')
for i in range(len(tokens)):
    for j in range(len(tokens)):
        ax.text(j, i, f'{weights_np[i,j]:.2f}', ha='center', va='center', fontsize=10)
plt.colorbar(im, ax=ax, shrink=0.8)

# Causal attention
ax = axes[2]
im = ax.imshow(weights_causal, cmap='YlOrRd', vmin=0, vmax=1)
ax.set_xticks(range(len(tokens)))
ax.set_yticks(range(len(tokens)))
ax.set_xticklabels(tokens)
ax.set_yticklabels(tokens)
ax.set_title('Causal Attention (autoregressive)', fontsize=13, fontweight='bold')
ax.set_xlabel('Key (attending to)')
ax.set_ylabel('Query (from)')
for i in range(len(tokens)):
    for j in range(len(tokens)):
        ax.text(j, i, f'{weights_causal[i,j]:.2f}', ha='center', va='center', fontsize=10)
plt.colorbar(im, ax=ax, shrink=0.8)

plt.suptitle('Bidirectional vs Causal (Autoregressive) Attention', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## Part 6: Multi-Head Attention

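A single attention head produces one set of attention weights per query, so it tends to capture one kind of relationship at a time. **Multi-head attention** runs several attention operations in parallel, each with its own learned projections into a smaller $d_k = d_{model} / h$ subspace, and concatenates the results.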

For self-attention on an input $X$:

$$\text{MultiHead}(X) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W_O$$

where each head is $\text{head}_i = \text{Attention}(XW_Q^i, XW_K^i, XW_V^i)$.
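Looping over heads, as the NumPy class in this part does, is easy to read, but implementations typically reshape once and run all heads as one batched matmul. A minimal sketch with toy random weights (causal masking omitted for brevity; all names are illustrative) checks that the batched form matches a per-head loop over column slices:

```python
import numpy as np

rng = np.random.default_rng(0)
S, D, H = 5, 8, 4            # toy sequence length, model dim, number of heads
dk = D // H
X_demo = rng.standard_normal((S, D))
Wq, Wk, Wv = (rng.standard_normal((D, D)) for _ in range(3))

def row_softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def split_heads(M):
    # (S, D) -> (H, S, dk): head h owns columns h*dk:(h+1)*dk
    return M.reshape(S, H, dk).transpose(1, 0, 2)

Qh, Kh, Vh = (split_heads(X_demo @ W) for W in (Wq, Wk, Wv))

# Batched: one matmul over the leading head axis
batched = row_softmax(Qh @ Kh.transpose(0, 2, 1) / np.sqrt(dk)) @ Vh  # (H, S, dk)
batched_concat = batched.transpose(1, 0, 2).reshape(S, D)              # (S, D)

# Reference: explicit per-head loop on column slices, then concatenate
loop_outputs = []
for h in range(H):
    cols = slice(h * dk, (h + 1) * dk)
    q, k, v = (X_demo @ W[:, cols] for W in (Wq, Wk, Wv))
    loop_outputs.append(row_softmax(q @ k.T / np.sqrt(dk)) @ v)
loop_concat = np.concatenate(loop_outputs, axis=-1)

print(np.allclose(batched_concat, loop_concat))  # True
```

The reshape works because each head simply owns a contiguous block of `dk` columns of the full projection; no extra parameters are involved.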

In [None]:
class MultiHeadAttentionNumPy:
    """Multi-head attention implemented in NumPy."""
    
    def __init__(self, d_model, num_heads):
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Initialize projection matrices for all heads at once
        scale = np.sqrt(2.0 / (d_model + self.d_k))
        self.W_Q = np.random.randn(d_model, d_model) * scale
        self.W_K = np.random.randn(d_model, d_model) * scale
        self.W_V = np.random.randn(d_model, d_model) * scale
        self.W_O = np.random.randn(d_model, d_model) * scale
    
    def forward(self, X, causal=True):
        seq_len = X.shape[0]
        
        # Project to Q, K, V
        Q = X @ self.W_Q  # (seq_len, d_model)
        K = X @ self.W_K
        V = X @ self.W_V
        
        # Reshape to separate heads: (seq_len, d_model) -> (num_heads, seq_len, d_k)
        Q = Q.reshape(seq_len, self.num_heads, self.d_k).transpose(1, 0, 2)
        K = K.reshape(seq_len, self.num_heads, self.d_k).transpose(1, 0, 2)
        V = V.reshape(seq_len, self.num_heads, self.d_k).transpose(1, 0, 2)
        
        # Compute attention for each head
        all_weights = []
        all_outputs = []
        
        mask = np.tril(np.ones((seq_len, seq_len), dtype=bool)) if causal else None
        
        for h in range(self.num_heads):
            out, weights = scaled_dot_product_attention_numpy(Q[h], K[h], V[h], mask=mask)
            all_weights.append(weights)
            all_outputs.append(out)
        
        # Concatenate heads: (num_heads, seq_len, d_k) -> (seq_len, d_model)
        concat = np.concatenate(all_outputs, axis=-1)  # (seq_len, d_model)
        
        # Final projection
        output = concat @ self.W_O
        
        return output, all_weights

# Create multi-head attention with 4 heads
mha = MultiHeadAttentionNumPy(d_model=8, num_heads=4)
output_mha, head_weights = mha.forward(X, causal=True)

print(f"Input shape: {X.shape}")
print(f"Output shape: {output_mha.shape}")
print(f"Number of attention heads: {len(head_weights)}")
print(f"Each head's weight shape: {head_weights[0].shape}")

In [None]:
# Visualize attention patterns across all heads
fig, axes = plt.subplots(1, 4, figsize=(18, 4.5))

for h, (ax, weights) in enumerate(zip(axes, head_weights)):
    im = ax.imshow(weights, cmap='YlOrRd', vmin=0, vmax=1)
    ax.set_xticks(range(len(tokens)))
    ax.set_yticks(range(len(tokens)))
    ax.set_xticklabels(tokens, fontsize=10)
    ax.set_yticklabels(tokens, fontsize=10)
    ax.set_title(f'Head {h+1}', fontsize=13, fontweight='bold')
    ax.set_xlabel('Key')
    if h == 0:
        ax.set_ylabel('Query')
    
    for i in range(len(tokens)):
        for j in range(len(tokens)):
            ax.text(j, i, f'{weights[i,j]:.2f}', ha='center', va='center', fontsize=9)

plt.suptitle('Attention Patterns Across Heads (Causal)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("These heads use random (untrained) weights, so their patterns are arbitrary.")
print("In trained models, different heads often specialize, for example:")
print("- Some heads focus on the immediately previous token")
print("- Some track specific syntactic relationships")
print("- Some attend broadly across the whole context")

## Part 7: PyTorch Implementation

Let's implement the same thing in PyTorch and verify it matches.

In [None]:
class ScaledDotProductAttentionPyTorch(nn.Module):
    """Scaled dot-product attention in PyTorch."""
    
    def forward(self, Q, K, V, mask=None):
        d_k = Q.size(-1)
        
        # (batch, heads, seq_len, d_k) @ (batch, heads, d_k, seq_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)
        
        return output, attention_weights

# Test
attn = ScaledDotProductAttentionPyTorch()

Q_torch = torch.tensor(Q, dtype=torch.float32).unsqueeze(0).unsqueeze(0)  # Add batch & head dims
K_torch = torch.tensor(K, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
V_torch = torch.tensor(V, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

# Causal mask for PyTorch
mask_torch = torch.tril(torch.ones(seq_len, seq_len))

output_torch, weights_torch = attn(Q_torch, K_torch, V_torch, mask=mask_torch)

# Compare with NumPy implementation
print("Comparing NumPy vs PyTorch implementations:")
print(f"Weights match: {np.allclose(weights_causal, weights_torch.squeeze().numpy(), atol=1e-5)}")
print(f"Output match:  {np.allclose(output_causal, output_torch.squeeze().numpy(), atol=1e-5)}")

In [None]:
# Full Multi-Head Attention in PyTorch
class MultiHeadAttentionPyTorch(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
    
    def forward(self, x, causal=True):
        batch_size, seq_len, _ = x.shape
        
        # Linear projections
        Q = self.W_q(x)  # (batch, seq, d_model)
        K = self.W_k(x)
        V = self.W_v(x)
        
        # Reshape: (batch, seq, d_model) -> (batch, num_heads, seq, d_k)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        
        # Attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.d_k)
        
        # Causal mask
        if causal:
            mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device))
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        # Softmax and weighted sum
        weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(weights, V)
        
        # Reshape back: (batch, num_heads, seq, d_k) -> (batch, seq, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        
        # Final projection
        output = self.W_o(attn_output)
        
        return output, weights

# Test with a realistic-sized model
d_model = 512
num_heads = 8
seq_len_test = 64

mha_torch = MultiHeadAttentionPyTorch(d_model, num_heads)
x_test = torch.randn(1, seq_len_test, d_model)

with torch.no_grad():
    output, weights = mha_torch(x_test)

print(f"Input: {x_test.shape}")
print(f"Output: {output.shape}")
print(f"Attention weights: {weights.shape} (batch, heads, seq, seq)")
print(f"\nParameters:")
for name, param in mha_torch.named_parameters():
    print(f"  {name}: {param.shape} ({param.numel():,} params)")
total = sum(p.numel() for p in mha_torch.parameters())
print(f"  Total: {total:,} parameters")
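The total can be predicted without instantiating the module: four bias-free $d_{model} \times d_{model}$ matrices give $4 d_{model}^2$ parameters, independent of the number of heads, because heads only partition the same columns. A small check using a hypothetical helper, `mha_param_count`:

```python
def mha_param_count(d_model: int, num_heads: int, bias: bool = False) -> int:
    """Parameters in W_q, W_k, W_v, W_o, each a d_model x d_model projection."""
    assert d_model % num_heads == 0
    per_matrix = d_model * d_model + (d_model if bias else 0)
    return 4 * per_matrix  # num_heads does not affect the total

print(mha_param_count(512, 8))    # 1048576
print(mha_param_count(512, 16))   # 1048576
print(mha_param_count(4096, 32))  # 67108864
```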

## Part 8: Visualizing Attention in a Real Model

Let's look at actual attention patterns from GPT-2.

In [None]:
!pip install transformers -q

In [None]:
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModel.from_pretrained('gpt2', output_attentions=True)
model.eval()

text = "The cat sat on the mat and looked at the dog"
inputs = tokenizer(text, return_tensors='pt')
token_labels = [tokenizer.decode([t]) for t in inputs['input_ids'][0]]

with torch.no_grad():
    outputs = model(**inputs)

# outputs.attentions is a tuple of (n_layers) tensors, each (batch, heads, seq, seq)
attentions = outputs.attentions
print(f"Number of layers: {len(attentions)}")
print(f"Attention shape per layer: {attentions[0].shape}")
print(f"Tokens: {token_labels}")

In [None]:
# Visualize attention patterns from the first layer (index 0), first 8 of GPT-2's 12 heads
layer_idx = 0
attn_layer = attentions[layer_idx][0]  # Remove batch dim

fig, axes = plt.subplots(2, 4, figsize=(20, 10))

for h in range(8):
    ax = axes[h // 4, h % 4]
    weights = attn_layer[h].numpy()
    
    im = ax.imshow(weights, cmap='YlOrRd', vmin=0)
    ax.set_xticks(range(len(token_labels)))
    ax.set_yticks(range(len(token_labels)))
    ax.set_xticklabels(token_labels, rotation=45, ha='right', fontsize=8)
    ax.set_yticklabels(token_labels, fontsize=8)
    ax.set_title(f'Head {h+1}', fontsize=11, fontweight='bold')

plt.suptitle(f'GPT-2 Layer {layer_idx+1} Attention Patterns\n"{text}"', 
            fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Show attention across layers for head 0
fig, axes = plt.subplots(3, 4, figsize=(20, 14))

for layer_idx in range(12):
    ax = axes[layer_idx // 4, layer_idx % 4]
    weights = attentions[layer_idx][0, 0].numpy()  # batch=0, head=0
    
    im = ax.imshow(weights, cmap='YlOrRd', vmin=0)
    ax.set_xticks(range(len(token_labels)))
    ax.set_yticks(range(len(token_labels)))
    ax.set_xticklabels(token_labels, rotation=45, ha='right', fontsize=7)
    ax.set_yticklabels(token_labels, fontsize=7)
    ax.set_title(f'Layer {layer_idx+1}, Head 1', fontsize=10, fontweight='bold')

plt.suptitle(f'GPT-2 Attention Across All 12 Layers (Head 1)\n"{text}"', 
            fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("A commonly reported tendency (check it against the plots for this input):")
print("early layers have more diffuse attention; later layers show more specialized, focused patterns.")
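"Diffuse" versus "focused" can be measured instead of judged by eye. One common choice (an assumption here; the plots above do not compute it) is the Shannon entropy of each query's attention row, averaged over queries: a uniform row over $n$ positions has entropy $\ln n$, and a one-hot row has entropy 0. A minimal sketch with a hypothetical helper, checked on two synthetic causal patterns:

```python
import numpy as np

def mean_attention_entropy(weights: np.ndarray) -> float:
    """Mean Shannon entropy (nats) of the rows of a (S, S) attention matrix."""
    # Replace zeros by 1 inside the log so 0 * log(0) is treated as 0
    safe = np.where(weights > 0, weights, 1.0)
    row_entropy = (weights * np.log(1.0 / safe)).sum(axis=-1)
    return float(row_entropy.mean())

S = 4
# Fully focused: every query attends only to itself
self_only = np.eye(S)
# Fully diffuse causal pattern: query i spreads weight evenly over positions 0..i
uniform_causal = np.tril(np.ones((S, S))) / np.arange(1, S + 1)[:, None]

print(mean_attention_entropy(self_only))                  # 0.0
print(round(mean_attention_entropy(uniform_causal), 4))   # 0.7945 = mean(ln 1, ln 2, ln 3, ln 4)
```

Applied to the GPT-2 tensors, for example `mean_attention_entropy(attentions[layer_idx][0, 0].numpy())` for each layer, this yields one number per layer and head. Comparisons are only meaningful at a fixed sequence length, since the maximum possible entropy grows with it.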

## Part 9: Compute and Memory Costs of Attention

Understanding attention's compute and memory costs is crucial for inference engineering.
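The function in the next cell implements the following per-layer FLOP counts (one multiply-add counted as 2 FLOPs, batch size 1, $h \, d_k = d_{model}$). The four projections are linear in the sequence length $S$; the two attention matmuls are quadratic, and setting them equal gives the crossover point:

$$
\begin{aligned}
\text{FLOPs}_{\text{proj}} &= \underbrace{3 \cdot 2 S d_{model}^2}_{Q,\,K,\,V} + \underbrace{2 S d_{model}^2}_{W_O} = 8 S d_{model}^2 \\
\text{FLOPs}_{\text{attn}} &= \underbrace{2 h S^2 d_k}_{QK^T} + \underbrace{2 h S^2 d_k}_{\text{weights} \cdot V} = 4 S^2 d_{model} \\
\text{FLOPs}_{\text{attn}} = \text{FLOPs}_{\text{proj}} &\iff S = 2 d_{model}
\end{aligned}
$$

For $d_{model} = 4096$, the attention matmuls therefore overtake the projections at $S = 8192$.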

In [None]:
def attention_costs(seq_len, d_model, num_heads, batch_size=1, dtype_bytes=2):
    """Calculate compute and memory costs of attention."""
    d_k = d_model // num_heads
    
    # FLOPs for QKV projection: 3 matmuls of (B*S, d_model) @ (d_model, d_model)
    qkv_flops = 3 * 2 * batch_size * seq_len * d_model * d_model
    
    # FLOPs for attention scores: B*H matmuls of (S, d_k) @ (d_k, S)
    score_flops = 2 * batch_size * num_heads * seq_len * d_k * seq_len
    
    # FLOPs for attention output: B*H matmuls of (S, S) @ (S, d_k)
    attn_out_flops = 2 * batch_size * num_heads * seq_len * seq_len * d_k
    
    # FLOPs for output projection: (B*S, d_model) @ (d_model, d_model)
    out_proj_flops = 2 * batch_size * seq_len * d_model * d_model
    
    # Memory for attention scores matrix: B * H * S * S
    attn_score_memory = batch_size * num_heads * seq_len * seq_len * dtype_bytes
    
    # Memory for KV cache (just K and V): 2 * B * H * S * d_k
    kv_cache_memory = 2 * batch_size * num_heads * seq_len * d_k * dtype_bytes
    
    return {
        'QKV Projection': qkv_flops,
        'Attention Scores': score_flops,
        'Attention x Values': attn_out_flops,
        'Output Projection': out_proj_flops,
        'Total FLOPs': qkv_flops + score_flops + attn_out_flops + out_proj_flops,
        'Attention Score Memory': attn_score_memory,
        'KV Cache Memory': kv_cache_memory,
    }

# Analyze a single attention layer at Llama-7B scale (d_model=4096, 32 heads)
configs = [
    ('Short (256)', 256),
    ('Medium (2048)', 2048),
    ('Long (8192)', 8192),
    ('Very Long (32768)', 32768),
]

d_model = 4096
num_heads = 32

print(f"Per-layer attention costs for d_model={d_model}, num_heads={num_heads}, batch=1, FP16")
print("=" * 80)

# Loop variable is `sl` so the global `seq_len` used by earlier cells is not overwritten
for name, sl in configs:
    costs = attention_costs(sl, d_model, num_heads)
    print(f"\n{name} (seq_len={sl}):")
    print(f"  Total FLOPs:             {costs['Total FLOPs']/1e9:10.1f} GFLOPs")
    print(f"  Attention Score Memory:  {costs['Attention Score Memory']/1e6:10.1f} MB")
    print(f"  KV Cache Memory:         {costs['KV Cache Memory']/1e6:10.1f} MB")

In [None]:
# Visualize how attention costs scale with sequence length
seq_lengths = np.array([128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768])

qkv_flops = []
attn_flops = []
kv_memory = []
score_memory = []

for sl in seq_lengths:
    costs = attention_costs(sl, 4096, 32)
    qkv_flops.append((costs['QKV Projection'] + costs['Output Projection']) / 1e9)  # matches the 'QKV + Output Proj' label
    attn_flops.append((costs['Attention Scores'] + costs['Attention x Values']) / 1e9)
    kv_memory.append(costs['KV Cache Memory'] / 1e6)
    score_memory.append(costs['Attention Score Memory'] / 1e6)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# FLOPs scaling
ax1.plot(seq_lengths, qkv_flops, 'o-', label='QKV + Output Proj (linear in S)', color='#4ECDC4', linewidth=2)
ax1.plot(seq_lengths, attn_flops, 's-', label='QK^T + Attn x V (quadratic in S)', color='#FF6B6B', linewidth=2)
ax1.set_xlabel('Sequence Length', fontsize=12)
ax1.set_ylabel('GFLOPs', fontsize=12)
ax1.set_title('Attention FLOPs vs Sequence Length', fontsize=13, fontweight='bold')
ax1.legend(fontsize=10)
ax1.set_xscale('log', base=2)
ax1.set_yscale('log')
ax1.grid(True, alpha=0.3)

# Memory scaling
ax2.plot(seq_lengths, kv_memory, 'o-', label='KV Cache (linear in S)', color='#4ECDC4', linewidth=2)
ax2.plot(seq_lengths, score_memory, 's-', label='Attention Scores (quadratic in S)', color='#FF6B6B', linewidth=2)
ax2.set_xlabel('Sequence Length', fontsize=12)
ax2.set_ylabel('Memory (MB)', fontsize=12)
ax2.set_title('Attention Memory vs Sequence Length', fontsize=13, fontweight='bold')
ax2.legend(fontsize=10)
ax2.set_xscale('log', base=2)
ax2.set_yscale('log')
ax2.grid(True, alpha=0.3)

plt.suptitle('The Quadratic Cost of Attention (per layer, Llama-7B scale)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("At short sequences, the linear projections (QKV + output) dominate compute.")
print("At long sequences, the two quadratic attention matmuls (QK^T and Attn x V) dominate.")
print("FlashAttention-style kernels avoid materializing the S x S score matrix in GPU memory (HBM);")
print("they cut memory traffic and footprint, but the quadratic FLOP count remains.")

## Part 10: Arithmetic Intensity of Attention Operations

Let's analyze whether attention operations are compute-bound or memory-bound.

In [None]:
def attention_arithmetic_intensity(seq_len, d_model, num_heads, batch_size=1):
    """Calculate arithmetic intensity for different attention sub-operations."""
    d_k = d_model // num_heads
    
    results = {}
    
    # QKV Projection: (B*S, d_model) @ (d_model, d_model)
    M, K, N = batch_size * seq_len, d_model, d_model
    flops = 2 * M * K * N
    bytes_accessed = (M * K + K * N + M * N) * 2  # FP16
    results['QKV Projection'] = flops / bytes_accessed
    
    # Attention Scores (per head): (S, d_k) @ (d_k, S)
    M, K, N = seq_len, d_k, seq_len
    flops = 2 * M * K * N
    bytes_accessed = (M * K + K * N + M * N) * 2
    results['Attention Scores'] = flops / bytes_accessed
    
    # Attention x Values (per head): (S, S) @ (S, d_k)
    M, K, N = seq_len, seq_len, d_k
    flops = 2 * M * K * N
    bytes_accessed = (M * K + K * N + M * N) * 2
    results['Attn x Values'] = flops / bytes_accessed
    
    return results

# Compare across sequence lengths
print(f"Arithmetic Intensity (ops/byte) for d_model=4096, heads=32")
print(f"{'Seq Len':>10s} {'QKV Proj':>12s} {'Attn Scores':>14s} {'Attn x V':>12s}")
print("-" * 55)

gpu_ridge = 156  # A100: ~312 TFLOPS dense FP16 / ~2 TB/s HBM, i.e. ~156 FLOPs per byte
for sl in [1, 8, 64, 256, 1024, 4096]:
    ai = attention_arithmetic_intensity(sl, 4096, 32, batch_size=1)
    def fmt(v):
        bound = 'C' if v > gpu_ridge else 'M'
        return f"{v:.1f}({bound})"
    print(f"{sl:>10d} {fmt(ai['QKV Projection']):>12s} {fmt(ai['Attention Scores']):>14s} {fmt(ai['Attn x Values']):>12s}")

print(f"\nC=Compute-bound, M=Memory-bound (A100 ridge point: {gpu_ridge} ops/byte)")
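print("\nNote: the seq_len=1 row uses one query AND one key, so it only approximates decode.")
print("During decode, one new query attends over all S cached keys/values, and every")
print("attention operation is heavily memory-bound. Even for long prefills, the per-head")
print("attention matmuls stay below the ridge here because the S x S score matrix is")
print("written to and read back from memory (the traffic FlashAttention eliminates).")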
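To model one decode step more directly: for a single head, one new query attends over all `S` cached keys and values. The sketch below rests on stated assumptions, not measurements: FP16, the length-`S` score vector stays on-chip (as in fused kernels), and only the query, the K/V cache and the output touch HBM. For contrast, it also computes a decode-time weight matmul, whose intensity grows with batch size because every sequence shares the same weights. Both helper names are hypothetical.

```python
def decode_attention_ai(cache_len, d_k, batch_size=1, dtype_bytes=2):
    """FLOPs per byte for one decode step of one attention head (one layer)."""
    flops = batch_size * (2 * cache_len * d_k      # q . K^T
                          + 2 * cache_len * d_k)   # weights . V
    # Each sequence reads its own K and V cache, plus its query, and writes its output
    bytes_moved = batch_size * dtype_bytes * (2 * cache_len * d_k + 2 * d_k)
    return flops / bytes_moved

def decode_projection_ai(d_model, batch_size=1, dtype_bytes=2):
    """FLOPs per byte for a (batch, d_model) @ (d_model, d_model) matmul."""
    flops = 2 * batch_size * d_model * d_model
    # The weight matrix is read once and shared by every sequence in the batch
    bytes_moved = dtype_bytes * (d_model * d_model + 2 * batch_size * d_model)
    return flops / bytes_moved

for b in [1, 8, 64]:
    print(f"batch={b:3d}  attention AI={decode_attention_ai(4096, 128, b):.3f}  "
          f"projection AI={decode_projection_ai(4096, b):.1f}")
```

Under these assumptions attention stays near 1 FLOP/byte at any batch size, far below the A100 ridge point used above, while the projection's intensity climbs roughly linearly with batch size.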

---

## Key Takeaways

1. **Attention is a soft lookup mechanism**: Queries ask for information, Keys advertise information, Values provide it. The dot product between Q and K determines how much each token attends to every other token.

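2. **The scaling factor $1/\sqrt{d_k}$** keeps scores at a sensible scale: if query and key components are roughly independent with unit variance, $q \cdot k$ has variance about $d_k$, and dividing by $\sqrt{d_k}$ brings it back to about 1. Without it, softmax saturates into a near one-hot distribution with vanishing gradients.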

3. **Causal masking** ensures autoregressive models can't see the future. Each token can only attend to tokens at the same or earlier positions.

4. **Multi-head attention** runs multiple attention operations in parallel with different projections. Each head can learn to focus on different types of relationships.

5. **Attention cost scales quadratically** with sequence length ($O(S^2)$ for score computation and memory). This is the fundamental bottleneck for long-context inference.

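6. **During decode, attention is memory-bound** because each step processes a single new query token per sequence, so every cached K and V element loaded from memory feeds only a couple of FLOPs. Unlike the weight matmuls, batching does not fix this: each sequence reads its own KV cache. This is why KV cache optimization is critical.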

7. **Real model attention patterns** vary across layers and heads. A commonly reported tendency is that early layers have more diffuse patterns, while later layers develop more specialized focus.

---

**Next notebook:** We'll explore the KV cache -- the crucial optimization that prevents recomputing attention from scratch at each decode step.