# Attention Mechanisms: Mathematical Derivations

This notebook derives the forward and backward passes for attention mechanisms used in transformers,
providing complete mathematical proofs and implementation verification.

## Glossary of Terms

| Term | Definition |
|------|------------|
| **Attention** | A mechanism that computes weighted combinations of values based on query-key similarity. Allows the model to focus on relevant parts of the input. |
| **Query (Q)** | A vector representing "what am I looking for?" Derived from the current position's representation. Shape: $(T, d)$ or $(B, h, T, d)$ for multi-head. |
| **Key (K)** | A vector representing "what do I contain?" Used to compute compatibility with queries. Shape: $(T_{kv}, d)$, which equals the shape of Q in self-attention. |
| **Value (V)** | A vector representing "what information do I provide?" The actual content that gets aggregated based on attention weights. Shape: $(T_{kv}, d)$, which equals the shape of Q in self-attention. |
| **Attention scores** | Compatibility scores measuring how well each query matches each key. The raw scores are $QK^T$; this notebook writes $S$ for the scaled version $QK^T / \sqrt{d}$. Shape: $(T_q, T_{kv})$. |
| **Attention weights/probabilities** | Normalized scores after softmax: $P = \text{softmax}(S)$. Each row sums to 1, forming a probability distribution over keys. |
| **Scaled dot-product** | Dividing attention scores by $\sqrt{d}$ to prevent large values that saturate softmax. Critical for stable gradients. |
| **Multi-head attention (MHA)** | Running multiple attention operations in parallel with different learned projections, then concatenating results. Allows learning diverse attention patterns. |
| **Head** | One of $h$ parallel attention computations in MHA. Each head has dimension $d = D/h$. |
| **Causal mask** | A lower-triangular mask that prevents positions from attending to future positions. Essential for autoregressive generation (GPT-style models). |
| **Self-attention** | When Q, K, V all come from the same sequence. The sequence attends to itself. |
| **Cross-attention** | When Q comes from one sequence and K, V from another (e.g., in encoder-decoder models). |
| **Projection matrices** | Learned weight matrices $W^Q, W^K, W^V, W^O$ that transform inputs into queries, keys, values, and final outputs. |
| **Softmax saturation** | When softmax inputs are very large, outputs approach one-hot vectors with near-zero gradients. The $\sqrt{d}$ scaling prevents this. |
| **Jacobian** | Matrix of all partial derivatives. For softmax: $J_{ij} = \frac{\partial p_i}{\partial x_j}$. |

## Formulas and Theorems

### Core Attention Formulas

| Formula | Description |
|---------|-------------|
| $\text{softmax}(\mathbf{x})_i = \frac{e^{x_i}}{\sum_j e^{x_j}}$ | Softmax function |
| $\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$ | Scaled dot-product attention |
| $S = \frac{QK^T}{\sqrt{d_k}}$ | Attention scores (scaled) |
| $P = \text{softmax}(S)$ | Attention probabilities |
| $O = PV$ | Attention output |

### Multi-Head Attention

| Formula | Description |
|---------|-------------|
| $Q_i = XW^Q_i$ | Query projection for head $i$ |
| $K_i = XW^K_i$ | Key projection for head $i$ |
| $V_i = XW^V_i$ | Value projection for head $i$ |
| $\text{head}_i = \text{Attention}(Q_i, K_i, V_i)$ | Single head output |
| $\text{MHA}(X) = \text{Concat}(\text{head}_1, ..., \text{head}_h) W^O$ | Multi-head attention |

### Backward Pass Formulas

| Formula | Description |
|---------|-------------|
| $\frac{\partial \text{softmax}_i}{\partial x_j} = \text{softmax}_i (\delta_{ij} - \text{softmax}_j)$ | Softmax Jacobian |
| $\frac{\partial L}{\partial S} = P \odot \left(\frac{\partial L}{\partial P} - \text{rowsum}\left(\frac{\partial L}{\partial P} \odot P\right)\right)$ | Softmax backward (row-wise; $\text{rowsum}$ sums along each row and broadcasts the result back across that row) |
| $\frac{\partial L}{\partial V} = P^T \frac{\partial L}{\partial O}$ | Gradient w.r.t. values |
| $\frac{\partial L}{\partial Q} = \frac{1}{\sqrt{d_k}} \frac{\partial L}{\partial S} K$ | Gradient w.r.t. queries |
| $\frac{\partial L}{\partial K} = \frac{1}{\sqrt{d_k}} \left(\frac{\partial L}{\partial S}\right)^T Q$ | Gradient w.r.t. keys |

### Causal Masking

| Formula | Description |
|---------|-------------|
| $M_{ij} = \begin{cases} 0 & \text{if } i \geq j \\ -\infty & \text{if } i < j \end{cases}$ | Causal mask (lower triangular) |
| $\text{softmax}(S + M)$ | Masked attention scores |

## Prerequisites

This notebook assumes familiarity with:

### 1. Matrix Multiplication and Transposes

For matrices $A \in \mathbb{R}^{m \times n}$ and $B \in \mathbb{R}^{n \times p}$:
- $(AB)_{ij} = \sum_{k=1}^{n} A_{ik} B_{kj}$
- $(AB)^T = B^T A^T$
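- $\frac{\partial (XW)}{\partial X} = W^T$ as shorthand: the upstream gradient is right-multiplied by $W^T$ (gradient flow to the input)
- $\frac{\partial (XW)}{\partial W} = X^T$ as shorthand: the upstream gradient is left-multiplied by $X^T$ (parameter gradients)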

### 2. The Chain Rule for Matrices

If $L$ is a scalar loss and we have $Y = f(X)$ and $L = g(Y)$, then:
$$\frac{\partial L}{\partial X_{ij}} = \sum_{k,l} \frac{\partial L}{\partial Y_{kl}} \frac{\partial Y_{kl}}{\partial X_{ij}}$$

In practice, we express this using matrix operations. For $Y = XW$:
$$\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y} W^T$$
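
As a quick sanity check of the rule above, the following sketch compares the matrix-form gradients for $Y = XW$ against central finite differences. The loss $L = \sum (Y \odot G)$ with a fixed random $G$ (so that $\frac{\partial L}{\partial Y} = G$), the shapes, and the helper name `finite_diff` are illustrative assumptions, not part of the notebook's own code.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(3, 4))
W = rng.normal(size=(4, 2))
G = rng.normal(size=(3, 2))   # plays the role of dL/dY (assumed for illustration)

def loss(X_, W_):
    # L = sum(Y * G) with Y = X W, so dL/dY = G exactly
    return ((X_ @ W_) * G).sum()

def finite_diff(f, A, eps=1e-6):
    """Central-difference gradient of scalar f with respect to array A."""
    grad = np.zeros_like(A)
    for idx in np.ndindex(*A.shape):
        A_plus, A_minus = A.copy(), A.copy()
        A_plus[idx] += eps
        A_minus[idx] -= eps
        grad[idx] = (f(A_plus) - f(A_minus)) / (2 * eps)
    return grad

dX = G @ W.T      # dL/dX = dL/dY W^T
dW = X.T @ G      # dL/dW = X^T dL/dY

dX_num = finite_diff(lambda A: loss(A, W), X)
dW_num = finite_diff(lambda A: loss(X, A), W)

print("dX matches:", np.allclose(dX, dX_num, atol=1e-6))
print("dW matches:", np.allclose(dW, dW_num, atol=1e-6))
```

Note that each gradient has the same shape as the quantity it differentiates with respect to, which is a useful check when deciding where the transpose goes.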

### 3. Softmax Function

The softmax function converts a vector of real numbers into a probability distribution:
$$\text{softmax}(\mathbf{x})_i = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}$$

Properties:
- Output is always positive: $\text{softmax}(\mathbf{x})_i > 0$
- Outputs sum to 1: $\sum_i \text{softmax}(\mathbf{x})_i = 1$
- Translation invariant: $\text{softmax}(\mathbf{x} + c\mathbf{1}) = \text{softmax}(\mathbf{x})$ for any scalar $c$

### 4. The Kronecker Delta

$$\delta_{ij} = \begin{cases} 1 & \text{if } i = j \\ 0 & \text{if } i \neq j \end{cases}$$

This notation is used in the softmax Jacobian derivation.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
np.set_printoptions(precision=4, suppress=True)

---

## Part 1: Softmax Derivation

Before deriving attention, we need to fully understand softmax, as it's the core nonlinearity in attention.

### Forward Pass

Given input vector $\mathbf{x} = [x_1, x_2, ..., x_n]$:

$$p_i = \text{softmax}(\mathbf{x})_i = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}$$

For numerical stability, we subtract the maximum:

$$p_i = \frac{e^{x_i - \max(\mathbf{x})}}{\sum_{j=1}^{n} e^{x_j - \max(\mathbf{x})}}$$

This doesn't change the output (softmax is translation invariant) but prevents overflow.
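
To make the overflow concrete, the sketch below compares a direct implementation of the formula with the max-subtracted version on a vector shifted by 1000. The function names `softmax_naive` and `softmax_stable` and the shift value are illustrative assumptions.

```python
import numpy as np

def softmax_naive(x):
    # Direct formula; np.exp overflows to inf for large inputs
    e = np.exp(x)
    return e / e.sum()

def softmax_stable(x):
    # Subtract the max first; the largest exponent is then exp(0) = 1
    e = np.exp(x - x.max())
    return e / e.sum()

x_small = np.array([2.0, 1.0, 0.1])
x_large = x_small + 1000.0   # same vector shifted by a constant

# Silence the overflow/invalid warnings that the naive version triggers
with np.errstate(over="ignore", invalid="ignore"):
    naive_large = softmax_naive(x_large)   # inf / inf gives nan

print("naive, shifted:  ", naive_large)
print("stable, shifted: ", softmax_stable(x_large))
print("stable, original:", softmax_stable(x_small))
```

The stable version returns the same distribution for both inputs, as translation invariance predicts, while the naive version produces `nan` once the exponentials overflow.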

In [None]:
def softmax(x):
    """Numerically stable softmax along last axis."""
    z = x - x.max(axis=-1, keepdims=True)  # Subtract max for stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

# Example
x = np.array([2.0, 1.0, 0.1])
p = softmax(x)
print(f"Input:  {x}")
print(f"Output: {p}")
print(f"Sum:    {p.sum():.6f} (should be 1.0)")

### Backward Pass: Deriving the Softmax Jacobian

We need to compute $\frac{\partial p_i}{\partial x_j}$ for all $i, j$.

Let $S = \sum_k e^{x_k}$. Then $p_i = \frac{e^{x_i}}{S}$.

**Case 1: When $i = j$ (diagonal elements)**

Using the quotient rule $\frac{d}{dx}\frac{u}{v} = \frac{u'v - uv'}{v^2}$:

$$\frac{\partial p_i}{\partial x_i} = \frac{e^{x_i} \cdot S - e^{x_i} \cdot e^{x_i}}{S^2} = \frac{e^{x_i}}{S} - \frac{e^{x_i}}{S} \cdot \frac{e^{x_i}}{S} = p_i - p_i^2 = p_i(1 - p_i)$$

**Case 2: When $i \neq j$ (off-diagonal elements)**

$$\frac{\partial p_i}{\partial x_j} = \frac{0 \cdot S - e^{x_i} \cdot e^{x_j}}{S^2} = -\frac{e^{x_i}}{S} \cdot \frac{e^{x_j}}{S} = -p_i p_j$$

**Combined using Kronecker delta:**

$$\frac{\partial p_i}{\partial x_j} = p_i(\delta_{ij} - p_j)$$

This can be written in matrix form as the Jacobian:
$$J = \text{diag}(\mathbf{p}) - \mathbf{p}\mathbf{p}^T$$
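
One way to gain confidence in this closed form is to compare it with a Jacobian built from central finite differences. The sketch below does this for a three-element input; the helper name `softmax_1d` and the step size `eps` are assumptions made for illustration.

```python
import numpy as np

def softmax_1d(x):
    e = np.exp(x - x.max())
    return e / e.sum()

x = np.array([2.0, 1.0, 0.1])
p = softmax_1d(x)

# Analytic Jacobian: J[i, j] = p[i] * (delta[i, j] - p[j])
J_analytic = np.diag(p) - np.outer(p, p)

# Numerical Jacobian: column j holds d p / d x_j (central differences)
eps = 1e-6
J_numeric = np.zeros((x.size, x.size))
for j in range(x.size):
    step = np.zeros_like(x)
    step[j] = eps
    J_numeric[:, j] = (softmax_1d(x + step) - softmax_1d(x - step)) / (2 * eps)

print("max abs difference:", np.abs(J_analytic - J_numeric).max())
print("symmetric:", np.allclose(J_analytic, J_analytic.T))
```

The Jacobian is symmetric because both $\text{diag}(\mathbf{p})$ and $\mathbf{p}\mathbf{p}^T$ are symmetric, so multiplying by $J$ or $J^T$ gives the same result.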

In [None]:
def softmax_jacobian(p):
    """Compute the full Jacobian matrix of softmax."""
    # J[i,j] = p[i] * (delta[i,j] - p[j])
    # = diag(p) - p @ p.T
    return np.diag(p) - np.outer(p, p)

# Verify the Jacobian
x = np.array([2.0, 1.0, 0.1])
p = softmax(x)
J = softmax_jacobian(p)

print("Softmax output p:", p)
print("\nJacobian matrix:")
print(J)
print(f"\nDiagonal (p_i * (1 - p_i)): {p * (1 - p)}")
print(f"Row sums (should be ~0):    {J.sum(axis=1)}")

### Efficient Softmax Backward

Given upstream gradient $\frac{\partial L}{\partial \mathbf{p}}$, we need $\frac{\partial L}{\partial \mathbf{x}}$.

By chain rule:
$$\frac{\partial L}{\partial x_j} = \sum_i \frac{\partial L}{\partial p_i} \frac{\partial p_i}{\partial x_j} = \sum_i \frac{\partial L}{\partial p_i} p_i (\delta_{ij} - p_j)$$

Expanding:
$$\frac{\partial L}{\partial x_j} = \frac{\partial L}{\partial p_j} p_j - p_j \sum_i \frac{\partial L}{\partial p_i} p_i$$

Let $\mathbf{g} = \frac{\partial L}{\partial \mathbf{p}}$ (upstream gradient). Then:
$$\frac{\partial L}{\partial x_j} = g_j p_j - p_j \sum_i g_i p_i = p_j \left(g_j - \sum_i g_i p_i\right)$$

In vector form:
$$\frac{\partial L}{\partial \mathbf{x}} = \mathbf{p} \odot \left(\mathbf{g} - (\mathbf{g} \cdot \mathbf{p}) \mathbf{1}\right)$$

where $\odot$ is element-wise multiplication and $\mathbf{g} \cdot \mathbf{p}$ is the dot product.

In [None]:
def softmax_backward(dp, p):
    """Efficient softmax backward pass.
    
    Args:
        dp: Upstream gradient (same shape as p)
        p: Softmax output from forward pass
    
    Returns:
        dx: Gradient w.r.t. input x
    """
    # dot = sum(dp * p) - the dot product
    dot = (dp * p).sum(axis=-1, keepdims=True)
    # dx = p * (dp - dot)
    return p * (dp - dot)

# Verify against Jacobian method
x = np.array([2.0, 1.0, 0.1])
p = softmax(x)
dp = np.array([0.5, -0.3, 0.2])  # Arbitrary upstream gradient

# Method 1: Full Jacobian (inefficient but clear)
J = softmax_jacobian(p)
dx_jacobian = J.T @ dp

# Method 2: Efficient formula
dx_efficient = softmax_backward(dp, p)

print(f"Via Jacobian:  {dx_jacobian}")
print(f"Via efficient: {dx_efficient}")
print(f"Match: {np.allclose(dx_jacobian, dx_efficient)}")

---

## Part 2: Scaled Dot-Product Attention

Scaled dot-product attention is the fundamental attention operation introduced in "Attention Is All You Need" (Vaswani et al., 2017).

### The Intuition

Attention computes a weighted combination of **values** ($V$), where the weights are determined by how well **queries** ($Q$) match **keys** ($K$).

- **Query ($Q$)**: "What am I looking for?"
- **Key ($K$)**: "What do I contain?"
- **Value ($V$)**: "What information do I provide?"

### Forward Pass

Given:
- $Q \in \mathbb{R}^{T_q \times d}$ (queries)
- $K \in \mathbb{R}^{T_{kv} \times d}$ (keys)
- $V \in \mathbb{R}^{T_{kv} \times d}$ (values)

**Step 1: Compute attention scores**
$$S = \frac{QK^T}{\sqrt{d}}$$

Shape: $S \in \mathbb{R}^{T_q \times T_{kv}}$

Dividing by $\sqrt{d}$ is crucial. Without it, the standard deviation of the dot products grows like $\sqrt{d}$ (for unit-variance components), pushing softmax into saturation where gradients vanish.

**Step 2: Apply softmax to get attention probabilities**
$$P = \text{softmax}(S)$$

Each row of $P$ sums to 1, representing how much attention each query pays to each key.

**Step 3: Compute weighted combination of values**
$$O = PV$$

Shape: $O \in \mathbb{R}^{T_q \times d}$
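
The shapes above allow $T_q \neq T_{kv}$, as in cross-attention. The following minimal sketch walks through the three steps with different query and key lengths so the shape bookkeeping is visible; the sizes and random inputs are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
T_q, T_kv, d = 2, 5, 4
Q = rng.normal(size=(T_q, d))
K = rng.normal(size=(T_kv, d))
V = rng.normal(size=(T_kv, d))

# Step 1: scaled scores, one row per query and one column per key
S = Q @ K.T / np.sqrt(d)                          # (T_q, T_kv)

# Step 2: row-wise softmax (max subtracted for stability)
E = np.exp(S - S.max(axis=-1, keepdims=True))
P = E / E.sum(axis=-1, keepdims=True)             # (T_q, T_kv)

# Step 3: each output row is a weighted average of the rows of V
O = P @ V                                         # (T_q, d)

print("S:", S.shape, " P:", P.shape, " O:", O.shape)
print("row sums of P:", P.sum(axis=-1))
```

The output has one row per query, regardless of how many keys and values were attended over.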

In [None]:
def attention_forward(Q, K, V, mask=None):
    """Scaled dot-product attention forward pass.
    
    Args:
        Q: Queries (T_q, d)
        K: Keys (T_kv, d)
        V: Values (T_kv, d)
        mask: Optional additive mask (T_q, T_kv)
    
    Returns:
        O: Output (T_q, d)
        cache: (Q, K, V, P, d) for backward pass
    """
    T_q, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    
    # Step 1: Scaled attention scores
    S = scale * (Q @ K.T)  # (T_q, T_kv)
    
    # Optional: Apply mask
    if mask is not None:
        S = S + mask
    
    # Step 2: Softmax to get probabilities
    P = softmax(S)  # (T_q, T_kv)
    
    # Step 3: Weighted combination of values
    O = P @ V  # (T_q, d)
    
    return O, (Q, K, V, P, d)

# Example
np.random.seed(42)
T, d = 4, 3
Q = np.random.randn(T, d)
K = np.random.randn(T, d)
V = np.random.randn(T, d)

O, cache = attention_forward(Q, K, V)
print(f"Q shape: {Q.shape}")
print(f"Output shape: {O.shape}")
print(f"\nAttention probabilities P (rows sum to 1):")
print(cache[3])
print(f"Row sums: {cache[3].sum(axis=1)}")

### Visualizing Attention Weights

In [None]:
_, (Q, K, V, P, d) = attention_forward(Q, K, V)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))

# Attention weights
im = axes[0].imshow(P, cmap='Blues', vmin=0, vmax=1)
axes[0].set_xlabel('Key position')
axes[0].set_ylabel('Query position')
axes[0].set_title('Attention Probabilities P')
plt.colorbar(im, ax=axes[0])

# Raw scores (before softmax)
S = (Q @ K.T) / np.sqrt(d)
im = axes[1].imshow(S, cmap='RdBu', vmin=-2, vmax=2)
axes[1].set_xlabel('Key position')
axes[1].set_ylabel('Query position')
axes[1].set_title('Attention Scores S (before softmax)')
plt.colorbar(im, ax=axes[1])

plt.tight_layout()
plt.show()

### Why Scale by $\sqrt{d}$?

Consider the dot product $q \cdot k = \sum_{i=1}^{d} q_i k_i$.

If $q_i$ and $k_i$ are independent with mean 0 and variance 1:
- $E[q_i k_i] = 0$ (product of independent zero-mean RVs)
- $\text{Var}(q_i k_i) = E[q_i^2 k_i^2] - (E[q_i k_i])^2 = E[q_i^2] E[k_i^2] - 0 = 1$

By independence of the $d$ terms:
$$\text{Var}(q \cdot k) = d \cdot \text{Var}(q_i k_i) = d$$

So $\text{Std}(q \cdot k) = \sqrt{d}$. As $d$ grows, dot products become larger.

Large values push softmax into saturation:
- $\text{softmax}([10, 0, 0]) \approx [1, 0, 0]$ (nearly one-hot)
- Gradients of softmax near saturation are near zero!

Dividing by $\sqrt{d}$ brings the variance of the scores back to 1, independent of $d$.
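
The saturation effect can be observed directly by evaluating the softmax Jacobian $\text{diag}(\mathbf{p}) - \mathbf{p}\mathbf{p}^T$ at inputs of growing magnitude. The sketch below uses the pattern $[1, 0, 0]$ multiplied by an increasing factor; the specific factors and the helper name `softmax_1d` are illustrative assumptions.

```python
import numpy as np

def softmax_1d(x):
    e = np.exp(x - x.max())
    return e / e.sum()

max_grads = []
for scale in [1.0, 10.0, 100.0]:           # same pattern, growing magnitude
    x = scale * np.array([1.0, 0.0, 0.0])
    p = softmax_1d(x)
    J = np.diag(p) - np.outer(p, p)         # softmax Jacobian at x
    max_grads.append(np.abs(J).max())
    print(f"scale {scale:>5}: p = {np.round(p, 4)}, max |J| = {max_grads[-1]:.2e}")
```

As the inputs grow, the output approaches a one-hot vector and every entry of the Jacobian shrinks toward zero, so almost no gradient passes back through the softmax.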

In [None]:
# Demonstrate the variance scaling
dims = [8, 64, 512, 2048]
n_samples = 10000

print("Dimension | Unscaled Std | Scaled Std | sqrt(d)")
print("-" * 50)
for d in dims:
    q = np.random.randn(n_samples, d)
    k = np.random.randn(n_samples, d)
    
    dots = (q * k).sum(axis=1)  # Dot products
    scaled_dots = dots / np.sqrt(d)
    
    print(f"{d:^9} | {dots.std():^12.2f} | {scaled_dots.std():^10.2f} | {np.sqrt(d):^8.2f}")

---

## Part 3: Attention Backward Pass

Now we derive the gradients. Given upstream gradient $\frac{\partial L}{\partial O}$, we need:
- $\frac{\partial L}{\partial Q}$
- $\frac{\partial L}{\partial K}$
- $\frac{\partial L}{\partial V}$

### Step 1: Gradient w.r.t. V

From $O = PV$, treating each row independently:
$$O_i = \sum_j P_{ij} V_j \quad \Rightarrow \quad \frac{\partial O_i}{\partial V_k} = P_{ik}$$

By chain rule:
$$\frac{\partial L}{\partial V_k} = \sum_i \frac{\partial L}{\partial O_i} P_{ik} = \sum_i P_{ik} \frac{\partial L}{\partial O_i}$$

In matrix form:
$$\boxed{\frac{\partial L}{\partial V} = P^T \frac{\partial L}{\partial O}}$$

### Step 2: Gradient w.r.t. P

From $O = PV$:
$$\frac{\partial L}{\partial P_{ij}} = \sum_k \frac{\partial L}{\partial O_{ik}} \frac{\partial O_{ik}}{\partial P_{ij}}$$

Since $O_{ik} = \sum_l P_{il} V_{lk}$, we have $\frac{\partial O_{ik}}{\partial P_{ij}} = V_{jk}$.

Therefore:
$$\frac{\partial L}{\partial P_{ij}} = \sum_k \frac{\partial L}{\partial O_{ik}} V_{jk}$$

In matrix form:
$$\boxed{\frac{\partial L}{\partial P} = \frac{\partial L}{\partial O} V^T}$$
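
The two boxed matrix forms for $\frac{\partial L}{\partial V}$ and $\frac{\partial L}{\partial P}$ can be checked against their index-level sums with explicit loops. This is a minimal sketch with arbitrary small shapes; `dO` is a random stand-in for the upstream gradient $\frac{\partial L}{\partial O}$.

```python
import numpy as np

rng = np.random.default_rng(2)
T_q, T_kv, d = 3, 4, 2
P = rng.random(size=(T_q, T_kv))
P /= P.sum(axis=-1, keepdims=True)      # make rows valid distributions
V = rng.normal(size=(T_kv, d))
dO = rng.normal(size=(T_q, d))          # stand-in for dL/dO

# Index form: dL/dV[k] = sum_i P[i, k] * dL/dO[i]
dV_loop = np.zeros_like(V)
for k in range(T_kv):
    for i in range(T_q):
        dV_loop[k] += P[i, k] * dO[i]

# Index form: dL/dP[i, j] = sum_k dL/dO[i, k] * V[j, k]
dP_loop = np.zeros_like(P)
for i in range(T_q):
    for j in range(T_kv):
        dP_loop[i, j] = np.dot(dO[i], V[j])

print("dV matches P^T dO:", np.allclose(dV_loop, P.T @ dO))
print("dP matches dO V^T:", np.allclose(dP_loop, dO @ V.T))
```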

### Step 3: Gradient through Softmax

Using our efficient softmax backward formula, for each row:
$$\frac{\partial L}{\partial S_i} = P_i \odot \left(\frac{\partial L}{\partial P_i} - \left(\frac{\partial L}{\partial P_i} \cdot P_i\right) \mathbf{1}\right)$$

In matrix form (applied row-wise):
$$\boxed{\frac{\partial L}{\partial S} = P \odot \left(\frac{\partial L}{\partial P} - \text{rowsum}\left(\frac{\partial L}{\partial P} \odot P\right)\right)}$$

where $\text{rowsum}$ means summing along each row and broadcasting back.
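
The row-wise matrix form can be compared with the explicit per-row Jacobian product. In the sketch below, `keepdims=True` keeps the row sums as a column so that they broadcast across each row; the shapes and the random stand-in `dP` for $\frac{\partial L}{\partial P}$ are assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)
S = rng.normal(size=(3, 5))          # scores: 3 queries, 5 keys
dP = rng.normal(size=(3, 5))         # stand-in for dL/dP

# Row-wise softmax
E = np.exp(S - S.max(axis=-1, keepdims=True))
P = E / E.sum(axis=-1, keepdims=True)

# Matrix form: rowsum has shape (3, 1) and broadcasts across each row
rowsum = (dP * P).sum(axis=-1, keepdims=True)
dS_matrix = P * (dP - rowsum)

# Reference: build each row's Jacobian explicitly and apply it
dS_rows = np.stack([
    (np.diag(P[i]) - np.outer(P[i], P[i])) @ dP[i]
    for i in range(S.shape[0])
])

print("match:", np.allclose(dS_matrix, dS_rows))
print("row sums of dS:", dS_matrix.sum(axis=-1))
```

Each row of the result sums to zero (up to round-off), which follows from the rows of $P$ summing to 1.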

### Step 4: Gradient w.r.t. Q and K

Recall $S = \frac{1}{\sqrt{d}} Q K^T$.

For $Q$: $S_{ij} = \frac{1}{\sqrt{d}} \sum_k Q_{ik} K_{jk}$

$$\frac{\partial L}{\partial Q_{ik}} = \sum_{j} \frac{\partial L}{\partial S_{ij}} \frac{1}{\sqrt{d}} K_{jk}$$

In matrix form:
$$\boxed{\frac{\partial L}{\partial Q} = \frac{1}{\sqrt{d}} \frac{\partial L}{\partial S} K}$$

For $K$: By similar derivation (or using the symmetry of transposed multiplication):
$$\boxed{\frac{\partial L}{\partial K} = \frac{1}{\sqrt{d}} \left(\frac{\partial L}{\partial S}\right)^T Q}$$
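
For completeness, the key gradient can be written out index by index in the same way as the query gradient. Differentiating $S_{ij} = \frac{1}{\sqrt{d}} \sum_k Q_{ik} K_{jk}$ with respect to $K_{jk}$ and summing over the query index $i$:

$$\frac{\partial L}{\partial K_{jk}} = \sum_{i} \frac{\partial L}{\partial S_{ij}} \frac{\partial S_{ij}}{\partial K_{jk}} = \frac{1}{\sqrt{d}} \sum_{i} \frac{\partial L}{\partial S_{ij}} Q_{ik} = \frac{1}{\sqrt{d}} \left[\left(\frac{\partial L}{\partial S}\right)^T Q\right]_{jk}$$

The sum runs over the first index of $\frac{\partial L}{\partial S}$, which is why the transpose appears in the matrix form.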

In [None]:
def attention_backward(dO, cache):
    """Backward pass for scaled dot-product attention.
    
    Args:
        dO: Upstream gradient (T_q, d)
        cache: (Q, K, V, P, d) from forward pass
    
    Returns:
        dQ, dK, dV: Gradients w.r.t. Q, K, V
    """
    Q, K, V, P, d = cache
    scale = 1.0 / np.sqrt(d)
    
    # Step 1: dV = P.T @ dO
    dV = P.T @ dO
    
    # Step 2: dP = dO @ V.T
    dP = dO @ V.T
    
    # Step 3: Softmax backward (row-wise)
    # dS = P * (dP - sum(dP * P, axis=-1, keepdims=True))
    rowdot = (dP * P).sum(axis=-1, keepdims=True)
    dS = P * (dP - rowdot)
    
    # Step 4: dQ and dK
    dQ = scale * (dS @ K)
    dK = scale * (dS.T @ Q)
    
    return dQ, dK, dV

# Verify with numerical gradient
np.random.seed(42)
T, d = 4, 3
Q = np.random.randn(T, d)
K = np.random.randn(T, d)
V = np.random.randn(T, d)

# Forward
O, cache = attention_forward(Q, K, V)

# Backward with arbitrary upstream gradient
dO = np.random.randn(*O.shape)
dQ, dK, dV = attention_backward(dO, cache)

print("Analytic gradients computed!")
print(f"dQ shape: {dQ.shape}")
print(f"dK shape: {dK.shape}")
print(f"dV shape: {dV.shape}")

### Numerical Gradient Verification

We verify our derivation using finite differences:
$$\frac{\partial L}{\partial x} \approx \frac{f(x + \epsilon) - f(x - \epsilon)}{2\epsilon}$$
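
To illustrate why the central form is a reasonable choice, the sketch below compares it with a one-sided (forward) difference on a function whose derivative is known exactly. The test function $\sin$, the point `x0`, and the step sizes are assumptions chosen for illustration.

```python
import numpy as np

x0 = 1.0
exact = np.cos(x0)            # exact derivative of sin at x0

forward_errs, central_errs = [], []
for eps in [1e-2, 1e-3, 1e-4]:
    forward = (np.sin(x0 + eps) - np.sin(x0)) / eps
    central = (np.sin(x0 + eps) - np.sin(x0 - eps)) / (2 * eps)
    forward_errs.append(abs(forward - exact))
    central_errs.append(abs(central - exact))
    print(f"eps={eps:.0e}  forward err={forward_errs[-1]:.2e}  "
          f"central err={central_errs[-1]:.2e}")
```

For these step sizes the central-difference error is several orders of magnitude smaller, consistent with its truncation error shrinking like $\epsilon^2$ rather than $\epsilon$.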

In [None]:
def numerical_gradient(f, x, eps=1e-5):
    """Compute numerical gradient using central differences."""
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=['multi_index'], op_flags=['readwrite'])
    while not it.finished:
        idx = it.multi_index
        old_val = x[idx]
        
        x[idx] = old_val + eps
        fplus = f(x)
        x[idx] = old_val - eps
        fminus = f(x)
        
        grad[idx] = (fplus - fminus) / (2 * eps)
        x[idx] = old_val
        it.iternext()
    return grad

# Define loss function: L = sum(O * dO) (dot product with arbitrary vector)
def loss_wrt_Q(Q_test):
    O_test, _ = attention_forward(Q_test, K, V)
    return (O_test * dO).sum()

def loss_wrt_K(K_test):
    O_test, _ = attention_forward(Q, K_test, V)
    return (O_test * dO).sum()

def loss_wrt_V(V_test):
    O_test, _ = attention_forward(Q, K, V_test)
    return (O_test * dO).sum()

# Compute numerical gradients
dQ_num = numerical_gradient(loss_wrt_Q, Q.copy())
dK_num = numerical_gradient(loss_wrt_K, K.copy())
dV_num = numerical_gradient(loss_wrt_V, V.copy())

# Compare
print("Gradient verification:")
print(f"  dQ max error: {np.abs(dQ - dQ_num).max():.2e}")
print(f"  dK max error: {np.abs(dK - dK_num).max():.2e}")
print(f"  dV max error: {np.abs(dV - dV_num).max():.2e}")
print(f"\nAll gradients match: {np.allclose(dQ, dQ_num) and np.allclose(dK, dK_num) and np.allclose(dV, dV_num)}")

---

## Part 4: Causal Masking

In autoregressive models (like GPT), each position can attend only to itself and earlier positions. This is enforced with a **causal mask**.

### The Mask

For a sequence of length $T$, the causal mask is:
$$M_{ij} = \begin{cases} 0 & \text{if } j \leq i \text{ (can attend)} \\ -\infty & \text{if } j > i \text{ (blocked)} \end{cases}$$

When we add this to scores before softmax:
$$P = \text{softmax}(S + M) \quad \text{(softmax applied row-wise)}$$

The $-\infty$ values become 0 after softmax (since $e^{-\infty} = 0$).
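
The following sketch builds the mask with a literal `-np.inf` to mirror the formula and confirms that blocked positions receive exactly zero probability. Using `-np.inf` here is an illustrative choice; a large finite negative value behaves the same way in practice because its exponential underflows to zero.

```python
import numpy as np

T = 4
idx = np.arange(T)
# M[i, j] = 0 where j <= i (can attend), -inf where j > i (blocked)
M = np.where(idx[None, :] <= idx[:, None], 0.0, -np.inf)

rng = np.random.default_rng(3)
S = rng.normal(size=(T, T))

Z = S + M
# Every row keeps a finite diagonal entry, so the row max is finite
E = np.exp(Z - Z.max(axis=-1, keepdims=True))   # exp(-inf) evaluates to 0.0
P = E / E.sum(axis=-1, keepdims=True)

print(np.round(P, 3))
print("strictly upper triangle all zero:", np.all(np.triu(P, k=1) == 0.0))
print("first query attends only to itself:", P[0, 0] == 1.0)
```

Each row is still a valid distribution, renormalized over the positions that remain visible.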

### Visual Representation

In [None]:
def causal_mask(seq_len, fill=-1e9):
    """Build causal (lower triangular) attention mask."""
    i = np.arange(seq_len)
    # mask[i,j] = fill if j > i else 0
    return np.where(i[:, None] >= i[None, :], 0.0, fill)

T = 6
mask = causal_mask(T)

fig, axes = plt.subplots(1, 3, figsize=(12, 4))

# Raw mask
axes[0].imshow(mask, cmap='RdBu', vmin=-10, vmax=0)
axes[0].set_title('Causal Mask\n(0 = attend, -∞ = block)')
axes[0].set_xlabel('Key position j')
axes[0].set_ylabel('Query position i')

# Random scores
np.random.seed(0)
S = np.random.randn(T, T)
axes[1].imshow(S, cmap='RdBu', vmin=-2, vmax=2)
axes[1].set_title('Raw Attention Scores S')

# After softmax with mask
P = softmax(S + mask)
axes[2].imshow(P, cmap='Blues', vmin=0, vmax=1)
axes[2].set_title('Attention Probs (masked)\nLower triangular only')

plt.tight_layout()
plt.show()

print("\nAttention probabilities (notice zeros above diagonal):")
print(np.round(P, 3))

---

## Part 5: Multi-Head Attention

Instead of performing a single attention function, Multi-Head Attention (MHA) runs $h$ attention "heads" in parallel, each with different learned projections.

### Motivation

1. **Different representation subspaces**: Each head can learn different aspects (e.g., syntax vs. semantics)
2. **More expressive**: Multiple attention patterns simultaneously
3. **Computational efficiency**: The model dimension $D$ is split across heads ($d = D/h$), so the total cost is comparable to a single head of dimension $D$

### Architecture

Given input $X \in \mathbb{R}^{B \times T \times D}$ where $D = h \times d$:

**Step 1: Linear projections**
$$Q = XW^Q, \quad K = XW^K, \quad V = XW^V$$
where $W^Q, W^K, W^V \in \mathbb{R}^{D \times D}$

**Step 2: Split into heads**
Reshape $(B, T, D) \to (B, h, T, d)$

**Step 3: Parallel attention**
Apply scaled dot-product attention to each head independently

**Step 4: Concatenate and project**
$$\text{MHA}(X) = \text{Concat}(\text{head}_1, ..., \text{head}_h) W^O$$
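
Steps 1 and 2 can be read as applying a separate projection $W^Q_i$ per head, where $W^Q_i$ is a block of $d$ columns of the full $W^Q$. The sketch below checks that reading numerically and shows that merging heads undoes the split; the sizes and the random `Wq` are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(4)
B, T, D, h = 2, 3, 8, 4
d = D // h

X = rng.normal(size=(B, T, D))
Wq = rng.normal(size=(D, D))

# Steps 1-2: one full projection, then reshape (B, T, D) -> (B, h, T, d)
Q = (X @ Wq).reshape(B, T, h, d).transpose(0, 2, 1, 3)

# Equivalent per-head view: head i uses columns [i*d, (i+1)*d) of Wq
for i in range(h):
    Wq_i = Wq[:, i * d:(i + 1) * d]          # (D, d)
    assert np.allclose(Q[:, i], X @ Wq_i)    # (B, T, d)

# Merging heads (used before the output projection) undoes the split
Q_merged = Q.transpose(0, 2, 1, 3).reshape(B, T, D)

print("Q shape in head layout:", Q.shape)
print("merge undoes split:", np.allclose(Q_merged, X @ Wq))
```

Computing one $D \times D$ projection and reshaping is therefore equivalent to $h$ separate $D \times d$ projections, but needs only a single matrix multiply.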

In [None]:
class MultiHeadAttention:
    """Multi-Head Attention with full forward/backward implementation."""
    
    def __init__(self, d_model, n_heads, seed=42):
        assert d_model % n_heads == 0
        self.D = d_model
        self.h = n_heads
        self.d = d_model // n_heads
        
        # Initialize projections
        rng = np.random.default_rng(seed)
        scale = np.sqrt(2.0 / d_model)
        self.Wq = rng.normal(0, scale, (d_model, d_model)).astype(np.float32)
        self.Wk = rng.normal(0, scale, (d_model, d_model)).astype(np.float32)
        self.Wv = rng.normal(0, scale, (d_model, d_model)).astype(np.float32)
        self.Wo = rng.normal(0, scale, (d_model, d_model)).astype(np.float32)
        
        self._cache = None
    
    def split_heads(self, X):
        """(B, T, D) -> (B, h, T, d)"""
        B, T, _ = X.shape
        return X.reshape(B, T, self.h, self.d).transpose(0, 2, 1, 3)
    
    def combine_heads(self, X):
        """(B, h, T, d) -> (B, T, D)"""
        B, h, T, d = X.shape
        return X.transpose(0, 2, 1, 3).reshape(B, T, h * d)
    
    def forward(self, X, mask=None):
        """Forward pass.
        
        Args:
            X: Input (B, T, D)
            mask: Optional mask (1, 1, T, T) or broadcastable
        
        Returns:
            Y: Output (B, T, D)
        """
        B, T, D = X.shape
        
        # Linear projections
        Q_lin = X @ self.Wq  # (B, T, D)
        K_lin = X @ self.Wk
        V_lin = X @ self.Wv
        
        # Split heads
        Q = self.split_heads(Q_lin)  # (B, h, T, d)
        K = self.split_heads(K_lin)
        V = self.split_heads(V_lin)
        
        # Attention scores
        scale = 1.0 / np.sqrt(self.d)
        S = scale * np.einsum('bhtd,bhsd->bhts', Q, K)  # (B, h, T, T)
        
        if mask is not None:
            S = S + mask
        
        P = softmax(S)  # (B, h, T, T)
        
        # Weighted values
        O = np.einsum('bhts,bhsd->bhtd', P, V)  # (B, h, T, d)
        
        # Combine heads and output projection
        H = self.combine_heads(O)  # (B, T, D)
        Y = H @ self.Wo
        
        self._cache = (X, Q, K, V, P, H)
        return Y
    
    def backward(self, dY):
        """Backward pass - returns dX and computes parameter gradients."""
        X, Q, K, V, P, H = self._cache
        B, T, D = X.shape
        scale = 1.0 / np.sqrt(self.d)
        
        # Output projection backward
        dWo = H.reshape(-1, D).T @ dY.reshape(-1, D)
        dH = dY @ self.Wo.T  # (B, T, D)
        
        # Split dH back to heads
        dO = self.split_heads(dH)  # (B, h, T, d)
        
        # Attention backward
        dV = np.einsum('bhts,bhtd->bhsd', P, dO)  # (B, h, T, d)
        dP = np.einsum('bhtd,bhsd->bhts', dO, V)  # (B, h, T, T)
        
        # Softmax backward (row-wise on last two dims)
        rowdot = (dP * P).sum(axis=-1, keepdims=True)
        dS = P * (dP - rowdot)
        
        # Q, K backward
        dQ = scale * np.einsum('bhts,bhsd->bhtd', dS, K)
        dK = scale * np.einsum('bhts,bhtd->bhsd', dS, Q)
        
        # Combine heads for projection backward
        dQ_lin = self.combine_heads(dQ)
        dK_lin = self.combine_heads(dK)
        dV_lin = self.combine_heads(dV)
        
        # Projection gradients
        Xf = X.reshape(-1, D)
        dWq = Xf.T @ dQ_lin.reshape(-1, D)
        dWk = Xf.T @ dK_lin.reshape(-1, D)
        dWv = Xf.T @ dV_lin.reshape(-1, D)
        
        # Input gradient
        dX = (dQ_lin @ self.Wq.T + dK_lin @ self.Wk.T + dV_lin @ self.Wv.T)
        
        # Store gradients
        self.dWq, self.dWk, self.dWv, self.dWo = dWq, dWk, dWv, dWo
        
        return dX

# Test the implementation
np.random.seed(42)
B, T, D = 2, 5, 12
n_heads = 3

mha = MultiHeadAttention(D, n_heads)
X = np.random.randn(B, T, D).astype(np.float32)
mask = causal_mask(T)[None, None, :, :]  # (1, 1, T, T)

Y = mha.forward(X, mask)
print(f"Input shape:  {X.shape}")
print(f"Output shape: {Y.shape}")
print(f"\nNumber of heads: {n_heads}")
print(f"Dimension per head: {D // n_heads}")

### Verify Multi-Head Attention Gradients

In [None]:
# Numerical gradient check for MHA.
# The check runs in float64: with float32 inputs, round-off in the loss is
# comparable to the finite-difference step and can swamp the comparison.
X64 = X.astype(np.float64)
dY = np.random.randn(B, T, D)

mha.forward(X64, mask)  # refresh the cache with float64 activations
dX = mha.backward(dY)

def loss_mha(X_test):
    Y_test = mha.forward(X_test, mask)
    return (Y_test * dY).sum()

# Numerical gradient for X (central differences)
eps = 1e-4
dX_num = np.zeros_like(X64)
for b in range(B):
    for t in range(T):
        for j in range(D):  # j (not d) so the head-dimension name is not shadowed
            X_plus = X64.copy()
            X_plus[b, t, j] += eps
            X_minus = X64.copy()
            X_minus[b, t, j] -= eps
            dX_num[b, t, j] = (loss_mha(X_plus) - loss_mha(X_minus)) / (2 * eps)

print("Multi-Head Attention gradient check:")
print(f"  dX max error: {np.abs(dX - dX_num).max():.2e}")
print(f"  dX matches: {np.allclose(dX, dX_num, rtol=1e-3, atol=1e-5)}")

---

## Part 6: Summary and Key Insights

### What We Derived

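1. **Softmax backward**: $\frac{\partial L}{\partial \mathbf{x}} = \mathbf{p} \odot (\mathbf{g} - (\mathbf{g} \cdot \mathbf{p}) \mathbf{1})$, where $\mathbf{g} = \frac{\partial L}{\partial \mathbf{p}}$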

2. **Scaled dot-product attention**:
   - Forward: $O = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right) V$
   - Backward: Chain rule through softmax, matrix multiplies

3. **Multi-head attention**: Same operations, just with head dimension management

### Key Insights

1. **Scaling matters**: Without $\sqrt{d}$ scaling, the standard deviation of attention scores grows as $\sqrt{d}$, which saturates softmax

2. **Softmax is row-wise**: Each query has its own probability distribution over keys

3. **Gradients flow through both Q and K**: Changes to queries AND keys affect attention

4. **Multi-head gives expressivity**: Different heads can learn different attention patterns

### Computational Complexity

For sequence length $T$ and dimension $D$:
- Attention: $O(T^2 D)$ - quadratic in sequence length!
- This is why long sequences are expensive (motivation for Flash Attention, etc.)
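
The quadratic growth can be made concrete with a rough operation count. The sketch below counts only the two matrix products $QK^T$ and $PV$ (projections and softmax are ignored); the function name `attention_cost` and the choice $D = 512$ are illustrative assumptions.

```python
def attention_cost(T, D):
    """Rough multiply-add count for QK^T and PV, summed over all heads."""
    score_entries = T * T            # entries in S (and P) per head
    mult_adds = 2 * T * T * D        # T*T*D for QK^T plus T*T*D for PV
    return score_entries, mult_adds

for T in [1024, 2048, 4096]:
    entries, macs = attention_cost(T, D=512)
    print(f"T={T:>5}: score entries={entries:>10,}  multiply-adds={macs:>14,}")
```

Doubling the sequence length quadruples both the number of score entries and the multiply-add count.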

In [None]:
# Visualize per-head attention patterns (randomly initialized, untrained projections)
np.random.seed(123)
B, T, D = 1, 8, 16
n_heads = 4

mha = MultiHeadAttention(D, n_heads, seed=42)
X = np.random.randn(B, T, D).astype(np.float32)

_ = mha.forward(X, mask=causal_mask(T)[None, None, :, :])
_, _, _, _, P, _ = mha._cache  # Get attention patterns

fig, axes = plt.subplots(1, 4, figsize=(14, 3))
for h in range(4):
    im = axes[h].imshow(P[0, h], cmap='Blues', vmin=0, vmax=1)
    axes[h].set_title(f'Head {h+1}')
    axes[h].set_xlabel('Key position')
    if h == 0:
        axes[h].set_ylabel('Query position')
plt.colorbar(im, ax=axes[-1])
plt.suptitle('Attention patterns per head (with causal mask)')
plt.tight_layout()
plt.show()