# Day 17: Multi-Head Attention & Positional Encoding

**Building LLMs from Scratch** — Following Andrej Karpathy's makemore lectures.

---

## 1. Introduction

A single attention head can only learn **one type of relationship** between tokens — e.g., "what comes before me syntactically". Real language has many simultaneous relationships: grammatical agreement, semantic similarity, coreference, positional proximity.

**Multi-head attention** runs $h$ attention heads in **parallel**, each learning different relationship patterns, then concatenates their outputs:

$$\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, ..., \text{head}_h) W^O$$

where $\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$.

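**Positional encoding** solves the fact that self-attention on its own is **order-agnostic** (strictly, permutation-*equivariant*: shuffling the input tokens merely shuffles the outputs in the same way) — without it, each word in "dog bites man" and "man bites dog" would receive exactly the same representation, regardless of where it appears in the sentence.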

## 2. Multi-Head Attention

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import math

torch.manual_seed(42)

class Head(nn.Module):
    """One causal self-attention head.

    The input is projected to keys, queries and values of width
    ``head_size``; scaled dot-product scores are masked with a
    lower-triangular matrix so a position only attends to itself and to
    earlier positions.

    Args:
        embed_dim: Size of the last dimension of the input.
        head_size: Output width of the key/query/value projections.
        block_size: Side length of the stored causal mask; ``forward``
            slices the top-left ``T x T`` corner out of it.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, embed_dim, head_size, block_size, dropout=0.1):
        super().__init__()
        # Bias-free projections, created in key/query/value order.
        self.key = nn.Linear(embed_dim, head_size, bias=False)
        self.query = nn.Linear(embed_dim, head_size, bias=False)
        self.value = nn.Linear(embed_dim, head_size, bias=False)
        # Buffer (not a parameter): ones on and below the diagonal.
        causal_mask = torch.ones(block_size, block_size).tril()
        self.register_buffer('tril', causal_mask)
        self.dropout = nn.Dropout(dropout)
        self.head_size = head_size

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, C).

        Returns:
            A pair ``(out, weights)``: ``out`` is (B, T, head_size) and
            ``weights`` is the (B, T, T) attention matrix after dropout.
        """
        _, seq_len, _ = x.shape
        queries = self.query(x)
        keys = self.key(x)
        values = self.value(x)

        # Scale by 1/sqrt(head_size) before the softmax.
        scale = self.head_size ** -0.5
        scores = queries @ keys.transpose(-2, -1) * scale

        # Future positions get -inf so softmax assigns them zero weight.
        is_future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(is_future, float('-inf'))

        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)
        return weights @ values, weights


class MultiHeadAttention(nn.Module):
    """Multi-head attention built from independent ``Head`` modules.

    ``embed_dim`` is divided evenly between ``num_heads`` heads; their
    outputs are concatenated along the last dimension and passed through a
    linear output projection followed by dropout.

    Args:
        embed_dim: Size of the last dimension of the input and output.
        num_heads: Number of heads; must divide ``embed_dim``.
        block_size: Causal-mask size forwarded to every ``Head``.
        dropout: Dropout probability, used inside each head and on the
            projected output.
    """

    def __init__(self, embed_dim, num_heads, block_size, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0
        per_head = embed_dim // num_heads
        head_modules = []
        for _ in range(num_heads):
            head_modules.append(Head(embed_dim, per_head, block_size, dropout))
        self.heads = nn.ModuleList(head_modules)
        self.proj = nn.Linear(embed_dim, embed_dim)   # output projection
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``(out, weights)``: the projected output and a list with
        one attention-weight tensor per head."""
        # The heads are conceptually parallel, but here they are evaluated
        # one after another in a Python loop.
        head_outs = []
        all_weights = []
        for head in self.heads:
            head_out, head_weights = head(x)
            head_outs.append(head_out)
            all_weights.append(head_weights)

        merged = torch.cat(head_outs, dim=-1)  # (B, T, embed_dim)
        projected = self.proj(merged)
        return self.dropout(projected), all_weights


# Smoke test: push one random batch through the naive multi-head module and
# report the shapes involved.
embed_dim, num_heads, block_size = 32, 4, 16
mha = MultiHeadAttention(
    embed_dim=embed_dim,
    num_heads=num_heads,
    block_size=block_size,
)
x = torch.randn(2, 8, embed_dim)  # (batch=2, time=8, embed_dim)
out, weights = mha(x)             # weights: one tensor per head
print(f"Input:   {x.shape}")
print(f"Output:  {out.shape}  (same as input — embed_dim preserved)")
print(f"Each head size: {embed_dim // num_heads}  ({num_heads} heads x {embed_dim // num_heads} = {embed_dim})")

## 3. What Different Heads Learn

In [None]:
# Visualize every head's attention pattern on the same input.
#
# The `weights` returned by the earlier forward pass were produced in training
# mode, so dropout has zeroed some entries and rescaled the rest by 1/(1-p):
# rows no longer sum to 1 and values can exceed vmax=1. Recompute them in eval
# mode (dropout off, no autograd) so the plot shows the real softmax output.
was_training = mha.training
mha.eval()
with torch.no_grad():
    _, eval_weights = mha(x)
mha.train(was_training)  # leave the module in the mode we found it in

n_heads = len(eval_weights)
# squeeze=False keeps `axes` two-dimensional even with a single head.
fig, axes = plt.subplots(1, n_heads, figsize=(4 * n_heads, 4), squeeze=False)
for i, (w, ax) in enumerate(zip(eval_weights, axes[0])):
    # w is (B, T, T); show the first sequence of the batch.
    im = ax.imshow(w[0].numpy(), cmap='Blues', vmin=0, vmax=1)
    ax.set_title(f'Head {i+1}')
    ax.set_xlabel('Key position')
    if i == 0:
        ax.set_ylabel('Query position')
    plt.colorbar(im, ax=ax)

# NOTE: `mha` is untrained here, so the differences between panels come from
# each head's independent random initialisation.
plt.suptitle('Multi-Head Attention — Each Head Learns Different Patterns', y=1.02)
plt.tight_layout()
plt.show()
print("Each head attends to different positions — this is the power of multi-head attention.")

## 4. Positional Encoding — Why We Need It

Self-attention treats the input as a **set**, not a sequence — "dog bites man" = "man bites dog" without positional info. We add a positional signal to each token embedding.

In [None]:
# Show that a bag-of-embeddings view cannot tell two word orders apart.
embed_dim = 8
vocab_size = 5
emb = nn.Embedding(vocab_size, embed_dim)

# Token ids: dog=1, bites=2, man=3
seq1 = torch.tensor([[1, 2, 3]])  # "dog bites man"
seq2 = seq1.flip(dims=[1])        # "man bites dog" -> [[3, 2, 1]]

e1, e2 = emb(seq1), emb(seq2)     # each (1, 3, 8)

# Adding up the token vectors discards the order they arrived in.
bag1 = e1.sum(dim=1)
bag2 = e2.sum(dim=1)
print(f"Sum of embeddings (dog bites man): {bag1[0].tolist()}")
print(f"Sum of embeddings (man bites dog): {bag2[0].tolist()}")
print(f"Same? {torch.allclose(bag1, bag2)}  <- problem!")

## 5. Learned vs Sinusoidal Positional Encoding

In [None]:
# Option A: Learned positional embedding (GPT-style)
class LearnedPosEncoding(nn.Module):
    """Adds a trainable vector per position to the input.

    Args:
        block_size: Number of rows in the position table, i.e. how many
            distinct positions have an embedding.
        embed_dim: Width of each position vector.
    """

    def __init__(self, block_size, embed_dim):
        super().__init__()
        self.pos_emb = nn.Embedding(block_size, embed_dim)

    def forward(self, x):
        """Return ``x`` (B, T, C) plus the embeddings of positions 0..T-1."""
        _, seq_len, _ = x.shape
        # Build the indices on x's device so the lookup matches the input.
        positions = torch.arange(seq_len, device=x.device)
        offsets = self.pos_emb(positions)   # (T, C)
        return x + offsets                  # broadcast across the batch


# Option B: Sinusoidal positional encoding (original Transformer)
def sinusoidal_pos_encoding(T, embed_dim):
    """Build the fixed sinusoidal position table of shape (T, embed_dim).

    Even columns hold sines and odd columns hold cosines of the same angle:
        PE[pos, 2i]   = sin(pos / 10000^(2i/d))
        PE[pos, 2i+1] = cos(pos / 10000^(2i/d))
    """
    positions = torch.arange(T).unsqueeze(1).float()    # (T, 1)
    even_idx = torch.arange(0, embed_dim, 2).float()    # 0, 2, 4, ...
    # exp(-2i * ln(10000) / d) == 1 / 10000^(2i/d)
    inv_freq = torch.exp(even_idx * (-math.log(10000) / embed_dim))

    table = torch.zeros(T, embed_dim)
    table[:, 0::2] = torch.sin(positions * inv_freq)
    # There are embed_dim // 2 odd columns, so trim inv_freq to match.
    table[:, 1::2] = torch.cos(positions * inv_freq[:embed_dim // 2])
    return table


# Plot the sinusoidal table two ways: as a heat-map and as per-dimension curves.
T, d = 50, 64
pe = sinusoidal_pos_encoding(T, d)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
heat_ax, line_ax = axes

# Left panel: transpose so positions run along x and dimensions along y.
im = heat_ax.imshow(pe.numpy().T, aspect='auto', cmap='RdBu', vmin=-1, vmax=1)
heat_ax.set_xlabel('Position')
heat_ax.set_ylabel('Dimension')
heat_ax.set_title('Sinusoidal Positional Encoding\n(each row = one dimension, alternating sin/cos)')
plt.colorbar(im, ax=heat_ax)

# Right panel: a handful of individual dimensions traced over position.
for col in (0, 1, 4, 10, 20):
    line_ax.plot(pe[:, col].numpy(), label=f'dim {col}')
line_ax.set_xlabel('Position')
line_ax.set_ylabel('Encoding value')
line_ax.set_title('Sinusoidal Encoding per Dimension\n(low dims = high freq, high dims = low freq)')
line_ax.legend()
line_ax.grid(True)

plt.tight_layout()
plt.show()

## 6. Positional Encoding Makes Order Matter

In [None]:
# With positional encoding, dog bites man ≠ man bites dog
T = 3
pos_enc = sinusoidal_pos_encoding(T, embed_dim)  # (3, 8)

e1_pos = e1 + pos_enc  # dog(pos0) + bites(pos1) + man(pos2)
e2_pos = e2 + pos_enc  # man(pos0) + bites(pos1) + dog(pos2)

# Summing over positions cannot show the effect: sum(e + pe) = sum(e) + sum(pe),
# and both sentences share the same sum(e) and the same sum(pe), so the sums
# stay equal. Compare what the model sees per token instead: "dog" sits at
# position 0 in the first sentence and at position 2 in the second.
dog_plain_1, dog_plain_2 = e1[0, 0], e2[0, 2]
dog_pos_1, dog_pos_2 = e1_pos[0, 0], e2_pos[0, 2]

print(f"Without positional encoding, 'dog' is the same vector in both sentences? "
      f"{torch.allclose(dog_plain_1, dog_plain_2)}")
print(f"With positional encoding:")
print(f"'dog' in 'dog bites man' (pos 0): {dog_pos_1.tolist()}")
print(f"'dog' in 'man bites dog' (pos 2): {dog_pos_2.tolist()}")
print(f"Same? {torch.allclose(dog_pos_1, dog_pos_2)}  <- fixed!")

# Similarity between positions: positions close together should be similar
pe = sinusoidal_pos_encoding(20, 64)
# Plain dot products. With an even embed_dim every row has the same norm
# (each sin/cos pair contributes 1), so this is proportional to cosine similarity.
sim = pe @ pe.T

plt.figure(figsize=(7, 6))
plt.imshow(sim.numpy(), cmap='viridis')
plt.colorbar()
plt.title('Positional Encoding Similarity\n(nearby positions are more similar)')
plt.xlabel('Position j')
plt.ylabel('Position i')
plt.tight_layout()
plt.show()

## 7. Efficient Multi-Head Attention (Batched)

The parallel head implementation above runs each head sequentially in Python. The standard efficient approach reshapes Q, K, V to run all heads **in one batched matrix multiply**.

In [None]:
class EfficientMultiHeadAttention(nn.Module):
    """Efficient causal MHA: all heads computed in a single batched matmul.

    One linear layer produces Q, K and V for every head at once; the result
    is reshaped to (B, H, T, head_size) so the attention of all heads is
    evaluated by batched matrix multiplies instead of a Python loop.

    Args:
        embed_dim: Size of the last dimension of the input and output.
        num_heads: Number of heads; must divide ``embed_dim``.
        block_size: Side length of the stored causal mask; ``forward``
            slices the top-left ``T x T`` corner out of it.
        dropout: Dropout probability, applied to the attention weights and
            to the projected output.
    """

    def __init__(self, embed_dim, num_heads, block_size, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_size = embed_dim // num_heads

        self.qkv  = nn.Linear(embed_dim, 3 * embed_dim, bias=False)  # Q, K, V in one shot
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        """Apply causal multi-head self-attention to ``x`` of shape (B, T, C).

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.shape
        H, hs = self.num_heads, self.head_size

        # Project to Q, K, V all at once
        qkv = self.qkv(x)                              # (B, T, 3*C)
        q, k, v = qkv.split(self.embed_dim, dim=2)     # each (B, T, C)

        # Reshape to (B, H, T, head_size): channel h*hs + j belongs to head h
        q = q.view(B, T, H, hs).transpose(1, 2)
        k = k.view(B, T, H, hs).transpose(1, 2)
        v = v.view(B, T, H, hs).transpose(1, 2)

        # Attention scores: (B, H, T, T); the (T, T) mask broadcasts over B and H
        scores = q @ k.transpose(-2, -1) * (hs ** -0.5)
        scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))

        # Aggregate: (B, H, T, hs) -> (B, T, C); contiguous() is required
        # before view() because transpose() only changes strides.
        out = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)

        # Dropout on the projected output as well, matching the loop-based
        # MultiHeadAttention so both behave the same in training mode.
        return self.dropout(self.proj(out))


# Smoke test: output shape and parameter count of the batched implementation
emha = EfficientMultiHeadAttention(embed_dim=32, num_heads=4, block_size=16)
x = torch.randn(2, 8, 32)
out = emha(x)
print(f"Efficient MHA output: {out.shape}")
print(f"Parameters: {sum(p.numel() for p in emha.parameters()):,}")

# Verify same output as naive version (with same weights).
# Copy the naive module's weights into a fresh batched module: the fused qkv
# matrix is [all query rows; all key rows; all value rows], each group being
# the per-head matrices stacked in head order.
naive_ref = MultiHeadAttention(32, 4, 16)
emha_ref = EfficientMultiHeadAttention(32, 4, 16)
with torch.no_grad():
    stacked = [
        torch.cat([getattr(h, name).weight for h in naive_ref.heads], dim=0)
        for name in ("query", "key", "value")
    ]
    emha_ref.qkv.weight.copy_(torch.cat(stacked, dim=0))
    emha_ref.proj.weight.copy_(naive_ref.proj.weight)
    emha_ref.proj.bias.copy_(naive_ref.proj.bias)

# Eval mode switches dropout off so the two outputs are directly comparable.
naive_ref.eval()
emha_ref.eval()
with torch.no_grad():
    naive_out, _ = naive_ref(x)
    outputs_match = torch.allclose(emha_ref(x), naive_out, atol=1e-5)
print(f"Matches naive version: {outputs_match}")

# Speed comparison
import time
x_large = torch.randn(32, 64, 128)
emha_large = EfficientMultiHeadAttention(128, 8, 64)
mha_large  = MultiHeadAttention(128, 8, 64)

# no_grad: time the forward pass only, without building autograd graphs.
with torch.no_grad():
    # One untimed call each so one-time setup cost is not charged to
    # whichever module happens to be timed first.
    emha_large(x_large)
    mha_large(x_large)

    # perf_counter is the monotonic, high-resolution clock meant for timing.
    t0 = time.perf_counter()
    for _ in range(100):
        emha_large(x_large)
    t1 = time.perf_counter()
    for _ in range(100):
        mha_large(x_large)
    t2 = time.perf_counter()

print(f"\nEfficient (batched): {(t1-t0)*1000:.1f} ms for 100 forward passes")
print(f"Naive (loop):        {(t2-t1)*1000:.1f} ms for 100 forward passes")
print(f"Speedup: {(t2-t1)/(t1-t0):.1f}x")

---

**Building LLMs from Scratch** — [Day 17: Multi-Head Attention](https://omkarray.com/llm-day17.html) | [← Prev](llm_day16_self_attention.ipynb) | [Next →](llm_day18_transformer_block.ipynb)