# Part 4: Multi-Head Attention

## Why Multiple Heads?

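A single attention head produces one set of attention weights per query, so it must blend every kind of relationship into a single weighted average. Multi-head attention runs **multiple attention operations in parallel**, each with its own learned projections, so each head can pick up a different pattern: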

- One head might focus on **syntactic relationships** (subject-verb)
- Another on **semantic similarity** (synonyms)
- Another on **positional proximity** (nearby words)

---


In [None]:
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)

def softmax(x, axis=-1):
    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)


## The Multi-Head Attention Architecture

```
Input X (seq_len, d_model)
    |
    +---> Head 1: attention(X @ W_Q1, X @ W_K1, X @ W_V1)
    |
    +---> Head 2: attention(X @ W_Q2, X @ W_K2, X @ W_V2)
    |
    +---> Head 3: attention(X @ W_Q3, X @ W_K3, X @ W_V3)
    |
    ...
    |
    +---> Head h: attention(X @ W_Qh, X @ W_Kh, X @ W_Vh)
    |
    v
Concatenate all heads: [Head1 ; Head2 ; ... ; Head_h]
    |
    v
Final projection: Concat @ W_O
    |
    v
Output (seq_len, d_model)
```
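
The diagram draws a separate W_Qi for each head, while the implementation below stores one (d_model, d_model) matrix per projection. A minimal sketch, assuming toy sizes and random weights, showing that these are the same thing: head i's matrix is simply column block i of the big matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, num_heads = 5, 16, 4
d_k = d_model // num_heads

X = rng.standard_normal((seq_len, d_model))
W_Q = rng.standard_normal((d_model, d_model))  # one "big" query projection

# Slice the big matrix into per-head matrices W_Q1 ... W_Qh, each (d_model, d_k)
W_Q_heads = [W_Q[:, i * d_k:(i + 1) * d_k] for i in range(num_heads)]

# Path A: project once, then split the columns into heads
Q_big = X @ W_Q                                                    # (seq_len, d_model)
Q_split = Q_big.reshape(seq_len, num_heads, d_k).transpose(1, 0, 2)  # (num_heads, seq_len, d_k)

# Path B: project separately with each head's own matrix
Q_separate = np.stack([X @ W for W in W_Q_heads])                  # (num_heads, seq_len, d_k)

print(np.allclose(Q_split, Q_separate))  # True
```

This is why one large matmul followed by a reshape can replace h smaller matmuls without changing the result.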

### Key Insight: Smaller Dimensions Per Head

If we have 8 heads and d_model=512:
- Each head works with d_k = d_v = 512/8 = 64 dimensions
- The Q/K/V/output parameter count and the attention FLOPs stay the same as a single head with 512 dimensions
- But we get **8 different attention patterns**!
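
The arithmetic behind these bullets, as a quick check (weights only, biases ignored): splitting 512 dimensions across 8 heads leaves the Q/K/V parameter count unchanged.

```python
d_model, num_heads = 512, 8
d_k = d_model // num_heads  # 64

# Eight heads, each with three (d_model, d_k) projections for Q, K, V
multi_head_qkv_params = num_heads * 3 * (d_model * d_k)
# One head with three full (d_model, d_model) projections
single_head_qkv_params = 3 * (d_model * d_model)

print(d_k, multi_head_qkv_params, single_head_qkv_params)
# 64 786432 786432
```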


In [None]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """Single-head attention from previous notebook."""
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    
    if mask is not None:
        scores = np.where(mask == 0, -1e9, scores)
    
    weights = softmax(scores, axis=-1)
    output = weights @ V
    return output, weights


class MultiHeadAttention:
    """
    Multi-Head Attention implemented from scratch.
    """
    
    def __init__(self, d_model, num_heads):
        """
        d_model: Total dimension of the model
        num_heads: Number of attention heads
        """
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Dimension per head
        
        # One (d_model, d_model) matrix per projection holds the weights of all heads:
        # columns [i*d_k:(i+1)*d_k] of W_Q form head i's W_Qi (likewise for W_K, W_V).
        # This equals using separate per-head matrices but needs only one matmul each.
        self.W_Q = np.random.randn(d_model, d_model) * 0.1
        self.W_K = np.random.randn(d_model, d_model) * 0.1
        self.W_V = np.random.randn(d_model, d_model) * 0.1
        self.W_O = np.random.randn(d_model, d_model) * 0.1
    
    def split_heads(self, x):
        """
        Split the last dimension into (num_heads, d_k).
        
        Input: (seq_len, d_model)
        Output: (num_heads, seq_len, d_k)
        """
        seq_len = x.shape[0]
        # Reshape: (seq_len, d_model) -> (seq_len, num_heads, d_k)
        x = x.reshape(seq_len, self.num_heads, self.d_k)
        # Transpose: (seq_len, num_heads, d_k) -> (num_heads, seq_len, d_k)
        return x.transpose(1, 0, 2)
    
    def combine_heads(self, x):
        """
        Reverse of split_heads.
        
        Input: (num_heads, seq_len, d_k)
        Output: (seq_len, d_model)
        """
        # Transpose: (num_heads, seq_len, d_k) -> (seq_len, num_heads, d_k)
        x = x.transpose(1, 0, 2)
        seq_len = x.shape[0]
        # Reshape: (seq_len, num_heads, d_k) -> (seq_len, d_model)
        return x.reshape(seq_len, self.d_model)
    
    def forward(self, X, mask=None):
        """
        Compute multi-head attention.
        
        X: Input (seq_len, d_model)
        mask: Optional attention mask
        
        Returns: output (seq_len, d_model), attention_weights (num_heads, seq_len, seq_len)
        """
        # Step 1: Linear projections
        Q = X @ self.W_Q  # (seq_len, d_model)
        K = X @ self.W_K
        V = X @ self.W_V
        
        # Step 2: Split into multiple heads
        Q = self.split_heads(Q)  # (num_heads, seq_len, d_k)
        K = self.split_heads(K)
        V = self.split_heads(V)
        
        # Step 3: Apply attention to each head
        head_outputs = []
        attention_weights = []
        
        for i in range(self.num_heads):
            output, weights = scaled_dot_product_attention(
                Q[i], K[i], V[i], mask
            )
            head_outputs.append(output)
            attention_weights.append(weights)
        
        # Stack outputs: list of (seq_len, d_k) -> (num_heads, seq_len, d_k)
        head_outputs = np.stack(head_outputs, axis=0)
        attention_weights = np.stack(attention_weights, axis=0)
        
        # Step 4: Combine heads
        concat_output = self.combine_heads(head_outputs)  # (seq_len, d_model)
        
        # Step 5: Final linear projection
        output = concat_output @ self.W_O  # (seq_len, d_model)
        
        return output, attention_weights


# Test our implementation
d_model = 32
num_heads = 4
seq_len = 6

mha = MultiHeadAttention(d_model, num_heads)
X = np.random.randn(seq_len, d_model)

output, attention_weights = mha.forward(X)

print(f"Input shape: {X.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attention_weights.shape}")
print(f"\nWe have {num_heads} attention patterns, each of size ({seq_len}, {seq_len})")
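
The two reshaping helpers must be exact inverses, or head outputs get scrambled when they are concatenated. A standalone sketch (free-function versions of `split_heads` and `combine_heads`, rewritten here so the snippet runs on its own) that checks the round trip and shows which columns each head owns:

```python
import numpy as np

def split_heads(x, num_heads):
    """(seq_len, d_model) -> (num_heads, seq_len, d_k)"""
    seq_len, d_model = x.shape
    return x.reshape(seq_len, num_heads, d_model // num_heads).transpose(1, 0, 2)

def combine_heads(x):
    """(num_heads, seq_len, d_k) -> (seq_len, d_model)"""
    num_heads, seq_len, d_k = x.shape
    return x.transpose(1, 0, 2).reshape(seq_len, num_heads * d_k)

x = np.arange(6 * 8, dtype=float).reshape(6, 8)
heads = split_heads(x, num_heads=4)

print(heads.shape)                              # (4, 6, 2)
print(np.array_equal(combine_heads(heads), x))  # True
print(heads[1, 0])                              # [2. 3.]  (columns 2:4 of row 0)
```

Each head sees a contiguous block of d_k features, which is exactly the column-block view of the weight matrices.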


In [None]:
# Visualize attention patterns from each head
# The word labels are for readability only: X is random, not an embedding of these words
sentence = ["The", "cat", "sat", "on", "mat", "."]

fig, axes = plt.subplots(1, num_heads, figsize=(4 * num_heads, 4))

for i in range(num_heads):
    ax = axes[i]
    ax.imshow(attention_weights[i], cmap='Blues')
    ax.set_title(f'Head {i+1}')
    ax.set_xticks(range(len(sentence)))
    ax.set_yticks(range(len(sentence)))
    ax.set_xticklabels(sentence, rotation=45)
    ax.set_yticklabels(sentence)
    if i == 0:
        ax.set_ylabel('From')
    ax.set_xlabel('To')

plt.suptitle('Different Heads Produce Different Patterns (Untrained)', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

print("Notice: each head produces a different pattern, but the weights are random (untrained).")
print("In a trained model, these patterns would be meaningful.")
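
The `forward` method above accepts an optional `mask`, where 0 marks positions to block. A minimal sketch, assuming a causal (decoder-style) mask built with `np.tril`, which is one common choice but not something this notebook defines, showing that a single (seq_len, seq_len) mask broadcasts across all heads:

```python
import numpy as np

def softmax(x, axis=-1):
    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

rng = np.random.default_rng(0)
num_heads, seq_len, d_k = 4, 6, 8
Q = rng.standard_normal((num_heads, seq_len, d_k))
K = rng.standard_normal((num_heads, seq_len, d_k))

# Causal mask: 1 where query i may look at key j (j <= i), 0 elsewhere
mask = np.tril(np.ones((seq_len, seq_len)))

scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)  # (num_heads, seq_len, seq_len)
scores = np.where(mask == 0, -1e9, scores)         # the 2-D mask broadcasts over heads
weights = softmax(scores, axis=-1)

print(np.allclose(np.triu(weights, k=1), 0))  # True: no attention to future tokens
print(np.allclose(weights.sum(axis=-1), 1))   # True: each row is still a distribution
```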


## What Different Heads Learn

Analyses of trained Transformers have found that some heads specialize in recognizable patterns like these:

| Head Type | What It Attends To |
|-----------|-------------------|
| Positional | Adjacent tokens (i-1, i+1) |
| Syntactic | Subject-verb pairs |
| Semantic | Related concepts |
| Delimiter | Punctuation, sentence boundaries |
| Rare token | Unusual or important words |
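
Labels like "positional head" typically come from measuring where a trained head puts its weight. A hypothetical sketch of one such measurement (the function `previous_token_score` is illustrative, not a standard API): average the weight each token places on the token just before it.

```python
import numpy as np

def previous_token_score(weights):
    """Average attention each query i (for i >= 1) puts on token i-1.

    weights: (seq_len, seq_len) attention matrix whose rows sum to 1.
    Returns a value in [0, 1]; 1 means a pure previous-token head.
    """
    # offset=-1 picks the entries weights[i, i-1]
    return float(np.mean(np.diagonal(weights, offset=-1)))

n = 6
uniform = np.ones((n, n)) / n
prev = np.eye(n, k=-1)
prev[0, 0] = 1.0  # token 0 has no predecessor, so let it attend to itself

print(previous_token_score(uniform))  # ~0.167 (= 1/6)
print(previous_token_score(prev))     # 1.0
```

Similar scores (weight on the first token, on punctuation, and so on) can be defined for the other rows of the table.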


In [None]:
# Simulate different specialized heads (hand-crafted patterns, not learned)
sentence = ["The", "quick", "brown", "fox", "jumps", "."]
n = len(sentence)


def show_pattern(ax, weights, title, vmax=1.0):
    """Row-normalize a hand-made pattern and plot it.

    Real attention rows come out of a softmax, so each row sums to 1.
    """
    weights = weights / weights.sum(axis=1, keepdims=True)
    ax.imshow(weights, cmap='Blues', vmin=0, vmax=vmax)
    ax.set_title(title, fontsize=12)
    ax.set_xticks(range(n))
    ax.set_yticks(range(n))
    ax.set_xticklabels(sentence, rotation=45)
    ax.set_yticklabels(sentence)


fig, axes = plt.subplots(2, 3, figsize=(14, 9))

# Head 1: Positional (attend mostly to the previous token)
positional = np.eye(n) * 0.3
for i in range(1, n):
    positional[i, i-1] = 0.7
show_pattern(axes[0, 0], positional, 'Positional Head\n(Previous token)')

# Head 2: Syntactic (subject-verb relationship)
syntactic = np.eye(n) * 0.1
syntactic[4, 0] = 0.6  # "jumps" attends to "The" (article of subject)
syntactic[4, 3] = 0.8  # "jumps" attends to "fox" (subject)
show_pattern(axes[0, 1], syntactic, 'Syntactic Head\n(Subject-verb)')

# Head 3: Semantic (similar concepts)
semantic = np.eye(n) * 0.2
semantic[2, 1] = 0.5  # "brown" attends to "quick" (both adjectives)
semantic[1, 2] = 0.5  # vice versa
show_pattern(axes[0, 2], semantic, 'Semantic Head\n(Similar types)')

# Head 4: Beginning-of-sequence (everyone attends to the first token)
bos = np.eye(n) * 0.3
bos[:, 0] += 0.7
show_pattern(axes[1, 0], bos, 'BOS Head\n(First token)')

# Head 5: Delimiter (everyone attends to the period)
delimiter = np.eye(n) * 0.2
delimiter[:, -1] = 0.6
show_pattern(axes[1, 1], delimiter, 'Delimiter Head\n(Punctuation)')

# Head 6: Uniform (background); lower vmax so the flat 1/n values stay visible
uniform = np.ones((n, n))
show_pattern(axes[1, 2], uniform, 'Uniform Head\n(General context)', vmax=0.3)

plt.suptitle('Specialized Attention Patterns (Simulated)', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()


## The Math: Why It Works

### Single Head
```
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
```

### Multi-Head
```
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W_O

where head_i = Attention(Q W_Qi, K W_Ki, V W_Vi)
```
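
One way to read the W_O step: multiplying the concatenation by W_O is the same as giving each head its own (d_k, d_model) row block of W_O and summing the results, so W_O is where the heads' outputs get mixed. A quick numerical check, using random arrays as stand-ins for the head outputs (an assumption for illustration):

```python
import numpy as np

rng = np.random.default_rng(1)
seq_len, d_model, h = 5, 12, 3
d_k = d_model // h

heads = [rng.standard_normal((seq_len, d_k)) for _ in range(h)]  # stand-ins for head_i
W_O = rng.standard_normal((d_model, d_model))

# Left-hand side: concatenate along the feature axis, then project
concat_then_project = np.concatenate(heads, axis=-1) @ W_O

# Right-hand side: each head times its own row block of W_O, then sum
sum_of_blocks = sum(heads[i] @ W_O[i * d_k:(i + 1) * d_k, :] for i in range(h))

print(np.allclose(concat_then_project, sum_of_blocks))  # True
```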

### Computational Cost

For input of size (seq_len, d_model):

| Operation | Single Head | Multi-Head (h heads) |
|-----------|-------------|---------------------|
| Q, K, V projections | O(seq_len * d_model^2) | O(seq_len * d_model^2) |
| Attention (QK^T and weights @ V) | O(seq_len^2 * d_model) | O(seq_len^2 * d_k) * h |
| Output projection | O(seq_len * d_model^2) | O(seq_len * d_model^2) |

Since d_k = d_model / h, we have h * seq_len^2 * d_k = seq_len^2 * d_model, so the total cost is the **same**!


In [None]:
# Demonstrate computational equivalence (counting multiply-accumulate operations)
d_model = 512
seq_len = 100
h = 8
d_k = d_model // h

# Q, K, V and output projections: four (seq_len, d_model) @ (d_model, d_model) matmuls,
# identical for single-head and multi-head
projection_ops = 4 * seq_len * d_model * d_model

# Single head: Q @ K^T and weights @ V each cost seq_len * seq_len * d_model
single_head_attention_ops = 2 * seq_len * seq_len * d_model
single_head_total = projection_ops + single_head_attention_ops

# Multi-head with h=8: the same two matmuls per head, but with d_k = d_model / h
multi_head_attention_ops = h * (2 * seq_len * seq_len * d_k)
multi_head_total = projection_ops + multi_head_attention_ops

print(f"d_model = {d_model}, seq_len = {seq_len}, num_heads = {h}")
print(f"\nSingle head attention operations: {single_head_attention_ops:,}")
print(f"Multi-head attention operations:  {multi_head_attention_ops:,}")
print(f"\nThey're equal! {single_head_attention_ops == multi_head_attention_ops}")
print(f"Totals including projections: {single_head_total:,} vs {multi_head_total:,}")
print(f"\nBut multi-head gives us {h} different attention patterns.")


## Efficient Implementation with Batched Operations

In practice, implementations don't loop over heads; they treat the head axis as a batch dimension and use batched matrix operations:


In [None]:
class EfficientMultiHeadAttention:
    """
    Efficient Multi-Head Attention using batched operations.
    No explicit loops over heads!
    """
    
    def __init__(self, d_model, num_heads):
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Combined weight matrices
        self.W_Q = np.random.randn(d_model, d_model) * 0.1
        self.W_K = np.random.randn(d_model, d_model) * 0.1
        self.W_V = np.random.randn(d_model, d_model) * 0.1
        self.W_O = np.random.randn(d_model, d_model) * 0.1
    
    def forward(self, X, mask=None):
        seq_len = X.shape[0]
        
        # Project to Q, K, V
        Q = X @ self.W_Q  # (seq_len, d_model)
        K = X @ self.W_K
        V = X @ self.W_V
        
        # Reshape for multi-head: (seq_len, d_model) -> (num_heads, seq_len, d_k)
        Q = Q.reshape(seq_len, self.num_heads, self.d_k).transpose(1, 0, 2)
        K = K.reshape(seq_len, self.num_heads, self.d_k).transpose(1, 0, 2)
        V = V.reshape(seq_len, self.num_heads, self.d_k).transpose(1, 0, 2)
        
        # Batched attention: (num_heads, seq_len, d_k) @ (num_heads, d_k, seq_len)
        scores = np.matmul(Q, K.transpose(0, 2, 1)) / np.sqrt(self.d_k)
        
        if mask is not None:
            scores = np.where(mask == 0, -1e9, scores)
        
        attention_weights = softmax(scores, axis=-1)
        
        # (num_heads, seq_len, seq_len) @ (num_heads, seq_len, d_k)
        context = np.matmul(attention_weights, V)
        
        # Reshape back: (num_heads, seq_len, d_k) -> (seq_len, d_model)
        context = context.transpose(1, 0, 2).reshape(seq_len, self.d_model)
        
        # Final projection
        output = context @ self.W_O
        
        return output, attention_weights

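# Compare both implementations: copy the loop version's weights so the outputs should match
efficient_mha = EfficientMultiHeadAttention(d_model=32, num_heads=4)
for name in ("W_Q", "W_K", "W_V", "W_O"):
    setattr(efficient_mha, name, getattr(mha, name))

X_test = np.random.randn(6, 32)

output_loop, weights_loop = mha.forward(X_test)
output_efficient, weights_efficient = efficient_mha.forward(X_test)

print("Efficient implementation:")
print(f"  Output shape: {output_efficient.shape}")
print(f"  Weights shape: {weights_efficient.shape}")
print(f"  Output matches loop version: {np.allclose(output_loop, output_efficient)}")
print(f"  Weights match loop version:  {np.allclose(weights_loop, weights_efficient)}")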
print("\nNo loops over heads - all done with batched matrix operations!")
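
Why the batched version is allowed to drop the loop: `np.matmul` treats every axis before the last two as a batch axis. A small sketch comparing the loop, `np.matmul`, and an equivalent `np.einsum` formulation (einsum is an alternative spelling shown for illustration, not what `EfficientMultiHeadAttention` uses):

```python
import numpy as np

rng = np.random.default_rng(2)
num_heads, seq_len, d_k = 4, 6, 8
Q = rng.standard_normal((num_heads, seq_len, d_k))
K = rng.standard_normal((num_heads, seq_len, d_k))

# 1) Explicit loop over heads, as in MultiHeadAttention
loop_scores = np.stack([Q[i] @ K[i].T for i in range(num_heads)])

# 2) Batched matmul: the leading head axis is treated as a batch dimension
batched_scores = np.matmul(Q, K.transpose(0, 2, 1))

# 3) einsum: h = head, q = query position, k = key position, d = feature
einsum_scores = np.einsum('hqd,hkd->hqk', Q, K)

print(np.allclose(loop_scores, batched_scores), np.allclose(batched_scores, einsum_scores))
# True True
```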


## Summary: Multi-Head Attention

### Architecture
```
Input X
    |
    v
[Q = X @ W_Q]  [K = X @ W_K]  [V = X @ W_V]
    |               |               |
    v               v               v
Split into h heads (each with d_k = d_model/h dimensions)
    |
    v
Run h parallel attention operations
    |
    v
Concatenate all heads
    |
    v
Output projection (W_O)
    |
    v
Output (same shape as input)
```

### Key Benefits

1. **Multiple perspectives**: Each head can learn different relationships
2. **No extra cost**: Same computation as single-head with full dimension
3. **Richer representations**: Combines information from all heads
4. **Specialization**: In trained models, heads often specialize in distinct patterns

### Typical Configurations

| Model | d_model | num_heads | d_k |
|-------|---------|-----------|-----|
| Transformer Base | 512 | 8 | 64 |
| Transformer Large | 1024 | 16 | 64 |
| GPT-2 Small | 768 | 12 | 64 |
| GPT-3 | 12288 | 96 | 128 |
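
The d_k column follows directly from d_model / num_heads. A quick sketch that recomputes it and, assuming the 4 * d_model^2 weight layout used in this notebook (W_Q, W_K, W_V, W_O with biases ignored; real models may add biases), the attention weight count per layer:

```python
# (d_model, num_heads) pairs from the table above
configs = {
    "Transformer Base": (512, 8),
    "Transformer Large": (1024, 16),
    "GPT-2 Small": (768, 12),
    "GPT-3": (12288, 96),
}

for name, (d_model, num_heads) in configs.items():
    assert d_model % num_heads == 0      # same check as MultiHeadAttention.__init__
    d_k = d_model // num_heads
    qkvo_params = 4 * d_model * d_model  # W_Q, W_K, W_V, W_O (no biases)
    print(f"{name:18s} d_k={d_k:4d}  attention weights per layer={qkvo_params:,}")
```

Note that d_k stays small (64 or 128) even as d_model grows; larger models add heads rather than widening each head.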

---

**Next: 05_transformer_block.ipynb** - Assembling the complete Transformer block!
