# Day 16: Self-Attention — The Core of Transformers

**Building LLMs from Scratch** — Following Andrej Karpathy's makemore lectures.

---

## 1. Introduction

Every architecture so far (MLP, WaveNet) mixes the context window with **fixed, content-independent** weights. The way position 5 combines information from the other positions is the same for every input, so it has no way to pick out which other tokens are most relevant to it.

**Self-attention** solves this: each token **queries** all other tokens to find relevant ones, then **aggregates** their information weighted by relevance.

The mechanism:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

- $Q$ (queries): what this token is looking for
- $K$ (keys): what each token has to offer  
- $V$ (values): the actual content to aggregate
- $\sqrt{d_k}$: scaling to prevent softmax saturation

For language modeling, we use **causal (masked) self-attention**: token $i$ can only attend to tokens at positions $j \leq i$ (no peeking at the future).

## 2. The Intuition: Weighted Aggregation

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

torch.manual_seed(42)

# Toy batch used throughout this section.
B, T, C = 2, 6, 8  # batch=2, seq_len=6, channels=8
x = torch.randn(B, T, C)

# Version 1: the explicit double loop. xbow[b, t] is the mean of
# x[b, 0], ..., x[b, t] -- each position averages itself and every earlier token.
xbow = torch.zeros_like(x)  # "bag of words"
for batch_idx in range(B):
    for pos in range(T):
        history = x[batch_idx, : pos + 1]  # (pos+1, C): tokens 0..pos
        xbow[batch_idx, pos] = history.mean(dim=0)

# Spot-check one position against a direct computation.
expected = x[0, :4].mean(0)
print("xbow[0, 3] (avg of tokens 0-3):", xbow[0, 3])
print("Manual check:", expected)
print("Match:", torch.allclose(xbow[0, 3], expected))

## 3. Efficient Averaging via Lower-Triangular Mask

In [None]:
# Version 2: the same causal average as one matrix multiply.
# Row t of `wei` holds 1/(t+1) in columns 0..t and 0 for every future column.
wei = torch.tril(torch.ones(T, T))   # lower triangular
row_sums = wei.sum(dim=1, keepdim=True)
wei = wei / row_sums                 # each row now sums to 1

# (T, T) @ (B, T, C): the (T, T) matrix is broadcast over the batch -> (B, T, C)
xbow2 = wei @ x

print("Weight matrix (T=6 x T=6):")
print(wei.round(decimals=2))
print(f"\nxbow2 matches xbow: {torch.allclose(xbow, xbow2)}")

# Heatmap of the averaging weights: query rows vs. key columns.
fig, ax = plt.subplots(figsize=(5, 4))
im = ax.imshow(wei.numpy(), cmap='Blues')
ax.set(
    title='Causal Attention Mask (uniform weights)',
    xlabel='Key position',
    ylabel='Query position',
)
plt.colorbar(im)
plt.tight_layout()
plt.show()

## 4. Softmax Masking — Data-Dependent Weights

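Uniform averaging treats all past tokens equally. Self-attention makes the weights **data-dependent** via Q·K dot products. The cell below still starts from all-zero scores, so the masked softmax just reproduces the uniform average. In Section 5 those zeros are replaced by query–key dot products.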

In [None]:
# Version 3: masked softmax -- the form real attention uses.
# Start from all-zero scores and send future positions to -inf. exp(-inf) = 0,
# so softmax gives the future zero weight and splits each row evenly over the past.
tril = torch.tril(torch.ones(T, T))
is_future = tril == 0
wei = F.softmax(torch.zeros(T, T).masked_fill(is_future, float('-inf')), dim=-1)

print("Masked softmax weights (uniform, since all zeros before masking):")
print(wei.round(decimals=3))

## 5. Single-Head Self-Attention from Scratch

In [None]:
class SingleHeadAttention(nn.Module):
    """One head of causal (masked) self-attention.

    Projects the input into keys, queries and values with bias-free linear
    layers. Each query is scored against every key with a dot product scaled by
    ``head_size ** -0.5``. Future positions are masked out, and the output is
    the softmax-weighted sum of the values.

    Args:
        embed_dim: size of the last dimension of the input.
        head_size: size of the key/query/value projections (``d_k``).
        block_size: longest sequence the causal mask supports. The default of
            256 matches the size that used to be hard-coded.
    """
    def __init__(self, embed_dim, head_size, block_size=256):
        super().__init__()
        self.key   = nn.Linear(embed_dim, head_size, bias=False)
        self.query = nn.Linear(embed_dim, head_size, bias=False)
        self.value = nn.Linear(embed_dim, head_size, bias=False)
        # Causal mask registered as a buffer (not a parameter): it is not
        # trained, but moves with .to(device) and is stored in state_dict.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.head_size = head_size
        self.block_size = block_size

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: input of shape (B, T, embed_dim) with T <= block_size.

        Returns:
            (out, weights): ``out`` has shape (B, T, head_size) and
            ``weights`` has shape (B, T, T). Row t of ``weights`` is zero for
            keys after position t and sums to 1.

        Raises:
            ValueError: if T exceeds ``block_size``. Slicing the mask would
                otherwise silently return a too-small (block_size x block_size)
                mask and fail later with a shape mismatch.
        """
        _, T, _ = x.shape
        if T > self.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.block_size}"
            )

        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)

        # Scaled dot-product scores; the 1/sqrt(d_k) factor keeps softmax
        # out of saturation.
        scale = self.head_size ** -0.5
        scores = q @ k.transpose(-2, -1) * scale  # (B, T, T)

        # Causal mask: future keys get -inf so softmax assigns them weight 0.
        scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)        # (B, T, T)

        # Weighted aggregation of the values.
        out = weights @ v  # (B, T, head_size)
        return out, weights


# Smoke test: run one head over a tiny random batch (2 sequences x 6 tokens).
embed_dim, head_size = 16, 8
attn = SingleHeadAttention(embed_dim, head_size)
x = torch.randn(2, 6, embed_dim)
out, weights = attn(x)

for line in (f"Input:   {x.shape}", f"Output:  {out.shape}", f"Weights: {weights.shape}"):
    print(line)

print(f"\nAttention weights for batch=0 (causal = lower triangular):")
first_example = weights[0].detach()
print(first_example.round(decimals=3))

## 6. Visualising Attention Patterns

In [None]:
# Visualize the attention weights returned by `attn` in the previous cell.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Left: the full (T x T) weight matrix for the first example in the batch.
w = weights[0].detach().numpy()
im = axes[0].imshow(w, cmap='Blues', vmin=0, vmax=1)
axes[0].set_title('Self-Attention Weights\n(causal: upper triangle = 0)')
axes[0].set_xlabel('Key (which token to attend to)')
axes[0].set_ylabel('Query (which token is asking)')
plt.colorbar(im, ax=axes[0])

# Right: one row of that matrix, i.e. how a single query spreads its attention
# over itself and the earlier tokens. Use token 4, or the last token if the
# sequence is shorter than 5.
seq_len = w.shape[-1]
query_pos = min(4, seq_len - 1)
axes[1].bar(range(seq_len), w[query_pos], color='steelblue')
axes[1].set_title(f'Token {query_pos} attends to tokens 0-{query_pos}\n(what it finds most relevant)')
axes[1].set_xlabel('Token position')
axes[1].set_ylabel('Attention weight')
axes[1].set_xticks(range(seq_len))

plt.tight_layout()
plt.show()

## 7. Why Scale by √d_k?

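Without scaling, the dot product of two vectors with independent unit-variance entries has variance $d_k$ (= `head_size`), so typical scores grow like $\sqrt{d_k}$. Large scores push softmax into saturation: one token gets weight ≈ 1 and all others ≈ 0, so the softmax gradients vanish. Dividing by $\sqrt{d_k}$ brings the score variance back to 1.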

In [None]:
# Demonstrate saturation. q and k have independent unit-variance entries, so
# q·k has variance d_k and the unscaled scores spread out like sqrt(d_k).
d_k = 64        # head size for this demo
n_tokens = 6    # number of keys the query attends over

q = torch.randn(1, n_tokens, d_k)
k = torch.randn(1, n_tokens, d_k)

scores_unscaled = (q @ k.transpose(-2, -1))[0, 0]   # query 0 vs. every key
scores_scaled   = scores_unscaled / (d_k ** 0.5)

w_unscaled = F.softmax(scores_unscaled, dim=-1)
w_scaled   = F.softmax(scores_scaled, dim=-1)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].bar(range(n_tokens), w_unscaled.numpy(), color='tomato')
axes[0].set_title(f'Unscaled: max = {w_unscaled.max().item():.3f}\n(nearly one-hot, vanishing gradients)')
axes[0].set_ylim(0, 1)

axes[1].bar(range(n_tokens), w_scaled.numpy(), color='steelblue')
axes[1].set_title(f'Scaled by 1/√{d_k}: max = {w_scaled.max().item():.3f}\n(spread out, healthy gradients)')
axes[1].set_ylim(0, 1)

plt.suptitle('Effect of √d_k Scaling on Attention Weights')
plt.tight_layout()
plt.show()

# Shannon entropy H(w) = -sum(w * log w). torch.special.entr treats the
# 0 * log(0) term as 0, so a weight that underflows to exactly 0 (the saturated
# case, likely for larger d_k) doesn't turn the result into NaN the way
# w * w.log() would.
print(f"Unscaled entropy: {torch.special.entr(w_unscaled).sum():.4f}")
print(f"Scaled entropy:   {torch.special.entr(w_scaled).sum():.4f}")
print("(Higher entropy = more spread out = better gradients)")

## 8. Building a Simple Attention-Based LM

One attention head + a feedforward layer (each with a residual connection and pre-LayerNorm) = a baby transformer block.

In [None]:
class AttentionLM(nn.Module):
    """Minimal language model built around a single causal self-attention head.

    Token embeddings plus learned position embeddings feed one pre-LayerNorm
    block: attention and then a 4x-wide ReLU feed-forward, each wrapped in a
    residual connection. A final linear layer produces vocabulary logits.

    Args:
        vocab_size: number of distinct token ids.
        embed_dim: width of the embeddings and of the residual stream.
        head_size: output size of the attention head; ``self.proj`` maps it
            back to ``embed_dim``.
        block_size: maximum sequence length (rows in the position table).
    """
    def __init__(self, vocab_size, embed_dim=32, head_size=32, block_size=8):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(block_size, embed_dim)   # learned positions
        self.attn    = SingleHeadAttention(embed_dim, head_size)
        self.proj    = nn.Linear(head_size, embed_dim)
        self.ff      = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.ReLU(),
            nn.Linear(4 * embed_dim, embed_dim),
        )
        self.ln1  = nn.LayerNorm(embed_dim)
        self.ln2  = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)
        self.block_size = block_size

    def forward(self, idx):
        """Return logits of shape (B, T, vocab_size).

        Args:
            idx: integer token ids of shape (B, T) with T <= block_size.

        Raises:
            ValueError: if T exceeds ``block_size``. Without this check the
                position-embedding lookup fails with an opaque index error.
        """
        _, T = idx.shape
        if T > self.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.block_size}"
            )
        tok = self.tok_emb(idx)                          # (B, T, C)
        # Create the position ids on the input's device so the model keeps
        # working after .to('cuda') (a bare torch.arange is always on CPU).
        pos = self.pos_emb(torch.arange(T, device=idx.device))  # (T, C)
        x = tok + pos                                    # broadcast over batch
        # Attention with residual (pre-LayerNorm)
        attn_out, _ = self.attn(self.ln1(x))
        x = x + self.proj(attn_out)
        # FFN with residual (pre-LayerNorm)
        x = x + self.ff(self.ln2(x))
        return self.head(x)                              # (B, T, vocab_size)


# Smoke test: random token ids in, one row of vocabulary logits per position out.
vocab_size = 27  # letters + special
lm = AttentionLM(vocab_size=vocab_size)
idx = torch.randint(0, vocab_size, (2, 8))
logits = lm(idx)

n_params = sum(param.numel() for param in lm.parameters())
print(f"Input shape:  {idx.shape}")
print(f"Output shape: {logits.shape}  (B, T, vocab_size)")
print(f"Parameters:   {n_params:,}")

---

**Building LLMs from Scratch** — [Day 16: Self-Attention](https://omkarray.com/llm-day16.html) | [← Prev](llm_day15_wavenet.ipynb) | [Next →](llm_day17_multihead_attention.ipynb)