# Build DeepSeek from Scratch - Multi-Head Attention

## Overview
This notebook covers the transition from self-attention to multi-head attention, explaining the fundamental motivation, mathematical implementation, and practical benefits of having multiple attention heads in transformer architectures.

## Learning Objectives
- Understand the limitations of single-head self-attention
- Learn why multi-head attention is necessary
- Master the step-by-step implementation of multi-head attention
- Visualize how different heads capture different perspectives
- Prepare for advanced concepts like Key-Value caching and Multi-Head Latent Attention

## Recap: Self-Attention Foundation

### The Four-Step Process
**Self-Attention Mechanism**:
1. **Transformation**: Input embeddings → Query, Key, Value matrices via trainable weights (WQ, WK, WV)
2. **Attention Scores**: Query @ Key.transpose()
3. **Attention Weights**: Scale by √d_k, apply the causal mask, then apply softmax
4. **Context Vectors**: Attention weights @ Value vectors
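The four steps can be sketched in NumPy roughly as follows. This is a minimal illustration rather than the notebook's own code: the function name `causal_self_attention`, the random weights, the seed, and the 5-token input are all assumptions. The mask is applied before the softmax so that each row of weights sums to 1.

```python
import numpy as np

def causal_self_attention(X, WQ, WK, WV):
    """Single-head causal self-attention for X of shape [seq_len, d_in]."""
    # Step 1: project the inputs into queries, keys, and values
    Q, K, V = X @ WQ, X @ WK, X @ WV
    # Step 2: raw attention scores, one per (query token, key token) pair
    scores = Q @ K.T
    # Step 3: scale by sqrt(d_k), hide future positions, then softmax
    scores = scores / np.sqrt(K.shape[-1])
    future = np.triu(np.ones(scores.shape, dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Step 4: context vectors are weighted sums of the value vectors
    return weights @ V, weights

rng = np.random.default_rng(0)  # arbitrary seed, illustration only
X = rng.normal(size=(5, 8))     # 5 tokens, 8-dimensional embeddings
WQ, WK, WV = (rng.normal(size=(8, 4)) for _ in range(3))
context, weights = causal_self_attention(X, WQ, WK, WV)
print(context.shape, weights.shape)  # (5, 4) (5, 5)
```

Because the first token can only attend to itself, its context vector is simply its own value vector.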

### The Achievement
- **Input embeddings**: Contain only token semantics + position
- **Context vectors**: Contain semantics + relationships with the other tokens in the sequence
- **Result**: Richer representations that understand context and relationships

### Causal Attention Review
- **Constraint**: Tokens can only attend to past and current positions
- **Implementation**: Set the scores above the diagonal (future positions) to negative infinity before the softmax, so their attention weights become zero
- **Purpose**: Prevents "cheating" by looking into the future during training
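A small NumPy sketch of the mask, using made-up scores for a 4-token sequence (the values and variable names are assumptions for illustration). It also shows why the mask is usually applied as negative infinity before the softmax: zeroing weights after the softmax leaves rows that no longer sum to 1 unless they are renormalized.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

scores = np.arange(16, dtype=float).reshape(4, 4) / 10  # made-up scores
future = np.triu(np.ones((4, 4), dtype=bool), k=1)      # True above the diagonal

# Option A: set future scores to -inf, then softmax
masked_first = softmax(np.where(future, -np.inf, scores))

# Option B: softmax first, then zero out future weights
masked_after = np.where(future, 0.0, softmax(scores))

print(masked_first.sum(axis=-1))  # every row sums to 1
print(masked_after.sum(axis=-1))  # rows before the last sum to less than 1

# Renormalizing option B recovers option A
renormalized = masked_after / masked_after.sum(axis=-1, keepdims=True)
```

Both routes end at the same weights once option B is renormalized, but option A gets there in a single step.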

## The Problem with Single-Head Attention

### Core Limitation: Single Perspective
**The fundamental issue**: Self-attention can only capture **one perspective** of a given input sequence.

### Ambiguous Sentence Example
**Sentence**: "The artist painted the portrait of a woman with a brush"

**Two Possible Interpretations**:
1. **Artist has brush**: The artist used a brush to paint a portrait of a woman
2. **Woman has brush**: The artist painted a portrait of a woman who is holding a brush

### Attention Matrix Differences

**Interpretation 1 - Artist has brush**:
```
High attention scores:
- woman → portrait (she's in the portrait)
- artist → brush (artist holds the brush)
- brush → artist (brush is with artist)
```

**Interpretation 2 - Woman has brush**:
```
High attention scores:
- woman → portrait (she's in the portrait)
- woman → brush (woman holds the brush)
- brush → woman (brush is with woman)
- brush → portrait (brush appears in portrait)
```

### The Single-Head Limitation
- **Problem**: Self-attention produces only ONE attention matrix
- **Result**: Can only capture ONE interpretation at a time
- **Consequence**: Loss of semantic richness and multiple perspectives

## Why Multi-Head Attention is Necessary

### Real-World Text Complexity
**Example**: "The government should regulate free speech"

**Multiple Perspectives**:
1. **Restrictive interpretation**: Government should impose restrictions on free speech
2. **Protective interpretation**: Government should protect and preserve free speech

### The Business Case
**Without multi-head attention**:
- Model captures only one perspective
- Summarization loses nuance
- Understanding is impoverished

**With multi-head attention**:
- Model captures multiple perspectives simultaneously
- Richer understanding of ambiguous text
- Better summarization and generation

### Core Insight
**If one self-attention head captures one perspective, then multiple heads can capture multiple perspectives simultaneously.**

### The Solution Strategy
1. **Multiple self-attention mechanisms** in parallel
2. **Each head** captures a different perspective
3. **Merge results** to create richer context vectors
4. **Same output dimensions** but multiple perspectives embedded

## Multi-Head Attention: Step-by-Step Implementation

### Key Parameters
- **Input sentence**: "The artist painted the portrait of a woman with a brush"
- **Number of tokens**: 11
- **Input dimension (d_in)**: 8
- **Output dimension (d_out)**: 4
- **Number of heads**: 2
- **Head dimension**: d_out / num_heads = 4 / 2 = 2
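As a quick check of these numbers, the sketch below counts the tokens and derives the head dimension. It assumes one token per whitespace-separated word, which matches the count of 11 used here; a real subword tokenizer may split the sentence differently.

```python
sentence = "The artist painted the portrait of a woman with a brush"
tokens = sentence.split()   # assumes one token per word, for illustration

d_in, d_out, num_heads = 8, 4, 2
if d_out % num_heads != 0:
    raise ValueError("d_out must be divisible by num_heads")
head_dim = d_out // num_heads

print(len(tokens), head_dim)  # 11 2
```

The divisibility check matters because the output dimension has to be shared evenly among the heads.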

### Step 1: Start with Input Embeddings
```python
# Input embedding matrix
X = [11 × 8]  # 11 tokens, 8-dimensional embeddings
```

### Step 2: Single-Head Reference
**What single-head attention would do**:
```python
# Single head matrices
WQ = [8 × 4]  # Query weight matrix
WK = [8 × 4]  # Key weight matrix  
WV = [8 × 4]  # Value weight matrix

# Results
Q = X @ WQ = [11 × 4]
K = X @ WK = [11 × 4]
V = X @ WV = [11 × 4]
```

### Step 3: Split Weight Matrices for Multiple Heads
**Key insight**: Split the output dimension among heads

```python
# Head 1 matrices
WQ1 = [8 × 2]  # First 2 columns of original WQ
WK1 = [8 × 2]  # First 2 columns of original WK
WV1 = [8 × 2]  # First 2 columns of original WV

# Head 2 matrices
WQ2 = [8 × 2]  # Last 2 columns of original WQ
WK2 = [8 × 2]  # Last 2 columns of original WK
WV2 = [8 × 2]  # Last 2 columns of original WV
```

**Critical point**: We're not adding parameters, we're splitting existing ones!
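The split can be illustrated with `np.split` on a randomly initialized query matrix (the random values and seed are assumptions; only the shapes come from the example above). The two slices together hold exactly the same entries as the original matrix.

```python
import numpy as np

rng = np.random.default_rng(0)       # arbitrary seed, illustration only
WQ = rng.normal(size=(8, 4))         # [d_in × d_out]

# Column-wise split into num_heads equal slices
WQ1, WQ2 = np.split(WQ, 2, axis=1)   # each [8 × 2]

print(WQ1.shape, WQ2.shape)          # (8, 2) (8, 2)
print(WQ1.size + WQ2.size, WQ.size)  # 32 32
```

The same split applies to `WK` and `WV`.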

### Step 4: Generate Multiple Q, K, V Matrices
```python
# Head 1 vectors
Q1 = X @ WQ1 = [11 × 2]
K1 = X @ WK1 = [11 × 2]
V1 = X @ WV1 = [11 × 2]

# Head 2 vectors
Q2 = X @ WQ2 = [11 × 2]
K2 = X @ WK2 = [11 × 2]
V2 = X @ WV2 = [11 × 2]
```

**Important observation**: 
- Number of rows (tokens) remains the same: 11
- Number of columns (head dimension) is reduced: 2 instead of 4
- We have 2 sets of Q, K, V matrices instead of 1
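One way to see what the split does: each head's query matrix is just a column slice of the single-head query matrix. The sketch below checks this with random stand-in values (an assumption for illustration); the same holds for keys and values.

```python
import numpy as np

rng = np.random.default_rng(0)      # arbitrary seed, illustration only
X = rng.normal(size=(11, 8))        # 11 tokens, 8-dimensional embeddings
WQ = rng.normal(size=(8, 4))
WQ1, WQ2 = np.split(WQ, 2, axis=1)

Q = X @ WQ                          # single-head queries, [11 × 4]
Q1, Q2 = X @ WQ1, X @ WQ2           # per-head queries, [11 × 2] each
print(Q1.shape, Q2.shape)           # (11, 2) (11, 2)

# Stacking the per-head queries side by side reproduces Q
same = np.allclose(np.concatenate([Q1, Q2], axis=1), Q)
```

This equivalence is why implementations often compute one full projection and then slice or reshape it into heads instead of storing separate per-head matrices.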

### Step 5: Compute Attention Scores for Each Head
```python
# Head 1 attention scores
Attention_scores_1 = Q1 @ K1.T = [11 × 2] @ [2 × 11] = [11 × 11]

# Head 2 attention scores  
Attention_scores_2 = Q2 @ K2.T = [11 × 2] @ [2 × 11] = [11 × 11]
```

**Key insight**: Even though head dimension is reduced (2 vs 4), attention scores matrix remains [11 × 11] because it represents token-to-token relationships.
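A short shape check, using random stand-ins for the per-head queries and keys (assumed values, not outputs of a trained model). Entry `[i, j]` of each score matrix is the dot product of query `i` with key `j`, so the head dimension is summed away.

```python
import numpy as np

rng = np.random.default_rng(0)  # arbitrary seed, illustration only
Q1, K1 = rng.normal(size=(11, 2)), rng.normal(size=(11, 2))
Q2, K2 = rng.normal(size=(11, 2)), rng.normal(size=(11, 2))

scores_1 = Q1 @ K1.T
scores_2 = Q2 @ K2.T
print(scores_1.shape, scores_2.shape)  # (11, 11) (11, 11)

# One entry computed by hand: query of token 3 against key of token 5
entry = Q1[3] @ K1[5]
```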

### Step 6: Apply Scaling, Causal Masking, and Softmax
```python
# For each head, apply the same process:
# 1. Scale by √(head_dimension) = √2
# 2. Apply causal masking (upper triangular → -inf)
# 3. Apply softmax (masked positions receive weight 0)
# 4. Apply dropout (optional)

Attention_weights_1 = process_attention(Attention_scores_1)  # [11 × 11]
Attention_weights_2 = process_attention(Attention_scores_2)  # [11 × 11]
```
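A possible body for the `process_attention` helper, sketched in NumPy. The helper is hypothetical: the default `head_dim=2` is an assumption that matches the running example, and dropout is left out because it only applies during training.

```python
import numpy as np

def process_attention(scores, head_dim=2):
    """Hypothetical helper: scale, causal-mask, then softmax a [seq, seq] matrix."""
    scaled = scores / np.sqrt(head_dim)
    future = np.triu(np.ones(scaled.shape, dtype=bool), k=1)
    masked = np.where(future, -np.inf, scaled)
    exp = np.exp(masked - masked.max(axis=-1, keepdims=True))
    return exp / exp.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)            # arbitrary seed, illustration only
scores_1 = rng.normal(size=(11, 11))      # stand-in for head 1 scores
weights_1 = process_attention(scores_1)
print(weights_1.shape)                    # (11, 11)
```

Each row of the result is a probability distribution over the current and earlier tokens.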

### Step 7: Compute Context Vectors for Each Head
```python
# Head 1 context vectors
Context_1 = Attention_weights_1 @ V1 = [11 × 11] @ [11 × 2] = [11 × 2]

# Head 2 context vectors
Context_2 = Attention_weights_2 @ V2 = [11 × 11] @ [11 × 2] = [11 × 2]
```

### Step 8: Concatenate Results
```python
# Merge both heads
Final_Context = concat([Context_1, Context_2], dim=-1) = [11 × 4]
```

**Final result**: Same dimensions as single-head attention [11 × 4], but now contains multiple perspectives!
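The concatenation itself is a one-liner. In the sketch below the two context matrices are random stand-ins (an assumption for illustration); the point is only that head 1 fills the first two columns of the result and head 2 fills the last two.

```python
import numpy as np

rng = np.random.default_rng(0)        # arbitrary seed, illustration only
context_1 = rng.normal(size=(11, 2))  # stand-in for head 1 output
context_2 = rng.normal(size=(11, 2))  # stand-in for head 2 output

final_context = np.concatenate([context_1, context_2], axis=-1)
print(final_context.shape)  # (11, 4)
```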

## Mathematical Summary

### Dimension Tracking
```python
# Input
X: [batch_size, seq_len, d_in] = [1, 11, 8]

# Weight matrices (split across heads)
WQ_h, WK_h, WV_h: [d_in, head_dim] = [8, 2] for each head

# Per-head computations
Q_h, K_h, V_h: [batch_size, seq_len, head_dim] = [1, 11, 2]
Attention_scores_h: [batch_size, seq_len, seq_len] = [1, 11, 11]
Attention_weights_h: [batch_size, seq_len, seq_len] = [1, 11, 11]
Context_h: [batch_size, seq_len, head_dim] = [1, 11, 2]

# Final concatenation
Final_Context: [batch_size, seq_len, d_out] = [1, 11, 4]
```
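The same shapes can be traced in a batched NumPy sketch that stacks the heads along an extra axis instead of looping over them. This is one common arrangement, not necessarily the one the notebook uses: the `split_heads` helper, the random weights, and the seed are assumptions.

```python
import numpy as np

batch_size, seq_len, d_in, d_out, num_heads = 1, 11, 8, 4, 2
head_dim = d_out // num_heads

rng = np.random.default_rng(0)  # arbitrary seed, illustration only
X = rng.normal(size=(batch_size, seq_len, d_in))
WQ, WK, WV = (rng.normal(size=(d_in, d_out)) for _ in range(3))

def split_heads(T):
    # [batch, seq, d_out] -> [batch, heads, seq, head_dim]
    return T.reshape(batch_size, seq_len, num_heads, head_dim).transpose(0, 2, 1, 3)

Q, K, V = split_heads(X @ WQ), split_heads(X @ WK), split_heads(X @ WV)
scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(head_dim)   # [1, 2, 11, 11]
future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
scores = np.where(future, -np.inf, scores)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
context = weights @ V                                       # [1, 2, 11, 2]
# Move heads back next to head_dim and merge them: [1, 11, 4]
final = context.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, d_out)

for name, t in [("Q", Q), ("scores", scores), ("context", context), ("final", final)]:
    print(name, t.shape)
```

Here head `h` of `Q` corresponds to columns `h * head_dim` through `(h + 1) * head_dim - 1` of `X @ WQ`, which matches the column-wise split of the weight matrices.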

### Key Formula
```
head_dim = d_out / num_heads
```

### Multi-Head Attention in Matrix Form
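```python
import numpy as np

def multi_head_attention(X, WQ, WK, WV, num_heads):
    """Causal multi-head attention for one sequence X of shape [seq_len, d_in]."""
    seq_len, d_out = X.shape[0], WQ.shape[1]
    assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
    head_dim = d_out // num_heads

    # Split weight matrices column-wise, one slice per head
    WQ_heads = np.split(WQ, num_heads, axis=-1)
    WK_heads = np.split(WK, num_heads, axis=-1)
    WV_heads = np.split(WV, num_heads, axis=-1)

    # True above the diagonal marks future positions
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)

    contexts = []
    for i in range(num_heads):
        Q_i = X @ WQ_heads[i]
        K_i = X @ WK_heads[i]
        V_i = X @ WV_heads[i]

        scores_i = Q_i @ K_i.T / np.sqrt(head_dim)
        scores_i = np.where(future, -np.inf, scores_i)    # mask before softmax
        scores_i -= scores_i.max(axis=-1, keepdims=True)  # numerical stability
        weights_i = np.exp(scores_i)
        weights_i /= weights_i.sum(axis=-1, keepdims=True)
        context_i = weights_i @ V_i

        contexts.append(context_i)

    return np.concatenate(contexts, axis=-1)
```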

## Advantages and Trade-offs

### Advantages of Multi-Head Attention
1. **Multiple Perspectives**: Each head can capture different semantic relationships
2. **Richer Representations**: Final context vectors contain diverse information
3. **Specialized Learning**: Heads can specialize in different aspects, for example:
   - Head 1: Syntactic relationships (subject-verb-object)
   - Head 2: Semantic relationships (word meanings)
   - Head 3: Positional relationships (temporal sequences)
4. **Same Output Dimensions**: No increase in final output size
5. **Parallel Processing**: All heads computed simultaneously

### Trade-offs
1. **Reduced Per-Head Capacity**: Each head has fewer dimensions to work with
   - Single head: 4 dimensions per head
   - Multi-head (2 heads): 2 dimensions per head
2. **Divide and Conquer**: Trade individual head expressivity for multiple perspectives
3. **Computational Overhead**: More, smaller matrix operations (parallelizable, and the total attention cost stays about the same)

### The Trade-off Analysis
**Single Head**: 
- ✅ Full dimensional capacity (4D)
- ❌ Single perspective only

**Multi-Head**: 
- ✅ Multiple perspectives (2 heads)
- ❌ Reduced dimensional capacity per head (2D each)
- ✅ Net gain in representational power

**Empirical Evidence**: Multi-head attention is the standard choice in transformer language models, and in practice it tends to outperform a single head of the same total width.

## Practical Demonstration: Visualizing Different Heads

### Setup
- **Model**: Pre-trained BERT (bidirectional attention, so there is no causal mask and "woman" can attend to the later token "brush")
- **Sentence**: "The artist painted the portrait of a woman with a brush"
- **Analysis**: Layer 3, Heads 3 and 8 (out of 12 total heads)
- **Focus Token**: "woman"

### Head 3 Analysis
**Query Token**: "woman"
**Highest Attention**: "brush"

**Interpretation**: This head seems to capture the perspective where the woman is holding the brush.

### Head 8 Analysis  
**Query Token**: "woman"
**Highest Attention**: "portrait"

**Interpretation**: This head captures the perspective where the woman is the subject of the portrait (not necessarily holding the brush).

### Visualization Code Structure
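```python
import torch
from transformers import BertModel, BertTokenizer

sentence = "The artist painted the portrait of a woman with a brush"

# Load pre-trained model (assumes the Hugging Face `transformers` package and
# the `bert-base-uncased` checkpoint; the source does not name either)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased", output_attentions=True)
model.eval()

# Extract attention weights: one [batch, heads, seq, seq] tensor per layer
inputs = tokenizer(sentence, return_tensors="pt")
with torch.no_grad():
    attentions = model(**inputs).attentions

# Report the most-attended token for the focus token in each head
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
focus = tokens.index("woman")  # assumes "woman" is a single token
layer = 3  # treated as 0-based here; the source does not say which indexing it uses
for head in (3, 8):
    row = attentions[layer][0, head, focus]
    print(f"head {head}: woman -> {tokens[row.argmax().item()]}")
```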

### Key Insights from Demo
1. **Head 3**: woman → brush (high attention)
   - Suggests: "woman with a brush" interpretation
2. **Head 8**: woman → portrait (high attention)  
   - Suggests: "portrait of a woman" interpretation
3. **Different heads capture different semantic relationships**
4. **Pre-trained models naturally learn diverse perspectives**

## Implementation Considerations

### Parameter Count Analysis
**Single Head**:
```python
WQ: [d_in × d_out] = [8 × 4] = 32 parameters
WK: [d_in × d_out] = [8 × 4] = 32 parameters  
WV: [d_in × d_out] = [8 × 4] = 32 parameters
Total: 96 parameters
```

**Multi-Head (2 heads)**:
```python
WQ1: [8 × 2] = 16 parameters
WK1: [8 × 2] = 16 parameters
WV1: [8 × 2] = 16 parameters
WQ2: [8 × 2] = 16 parameters
WK2: [8 × 2] = 16 parameters
WV2: [8 × 2] = 16 parameters
Total: 96 parameters
```

**Key insight**: Same number of parameters, just organized differently!
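The count can be reproduced for any number of heads with a few lines of Python. The helper name `qkv_params` is made up for this sketch, and it counts only the three projection matrices, as the tables above do (no biases or other layers).

```python
d_in, d_out = 8, 4

def qkv_params(d_in, d_out, num_heads):
    head_dim = d_out // num_heads
    per_head = 3 * d_in * head_dim     # WQ_h, WK_h, WV_h
    return num_heads * per_head

for h in (1, 2, 4):
    print(h, qkv_params(d_in, d_out, h))  # 96 for every h
```

The head count cancels out because `num_heads * head_dim` is always `d_out`.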

### Memory Requirements
**Attention Matrices**: Each head requires [seq_len × seq_len] attention matrix
- Single head: 1 × [11 × 11] = 121 elements
- Multi-head: 2 × [11 × 11] = 242 elements
- **Memory scales linearly with number of heads**
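The element counts above follow directly from the shapes. The sketch below reproduces them and adds 12 heads as an extra, assumed data point.

```python
seq_len = 11

def attention_matrix_elements(seq_len, num_heads):
    # One [seq_len × seq_len] matrix of attention weights per head
    return num_heads * seq_len * seq_len

for h in (1, 2, 12):
    print(h, attention_matrix_elements(seq_len, h))
# 1 121
# 2 242
# 12 1452
```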

### Computational Complexity
**Per head**: O(seq_len² × head_dim) for the score and context matrix products
**Total**: O(seq_len² × d_out) - same as single head!

The computation per head is smaller but we have more heads, so total computation remains the same.
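A rough way to check this is to count multiply-accumulate operations for the two attention matrix products. The function below is an illustrative sketch with a made-up name; it ignores the Q/K/V projections and the softmax.

```python
def attention_macs(seq_len, d_out, num_heads):
    head_dim = d_out // num_heads
    scores = seq_len * seq_len * head_dim    # Q_h @ K_h.T
    context = seq_len * seq_len * head_dim   # weights_h @ V_h
    return num_heads * (scores + context)

print(attention_macs(11, 4, 1))  # 968
print(attention_macs(11, 4, 2))  # 968
```

The count depends on `num_heads * head_dim`, which is fixed at `d_out`, so changing the number of heads leaves it unchanged.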

## Connection to Modern LLMs

### Why All Modern LLMs Use Multi-Head Attention
1. **GPT models**: 12 (GPT-2 small) to 96 (GPT-3 175B) attention heads per layer
2. **BERT models**: 12 (base) or 16 (large) attention heads per layer
3. **T5 models**: 8 (small) to 128 (11B) attention heads per layer
4. **DeepSeek**: Multi-head latent attention, a variant of multi-head attention

### Head Specialization in Practice
**Research findings show heads often specialize in**:
- **Syntactic heads**: Subject-verb relationships, dependency parsing
- **Semantic heads**: Word meaning relationships, entity recognition
- **Positional heads**: Sequential patterns, temporal relationships
- **Long-range heads**: Long-range dependencies, discourse structure

### The Path Forward
**Next steps in our journey**:
1. ✅ Self-attention
2. ✅ Causal attention  
3. ✅ Multi-head attention
4. 🔄 **Next**: Key-Value caching (efficiency optimization)
5. 🔄 **Final**: Multi-head latent attention (DeepSeek's innovation)

### Why This Foundation Matters
- **KV caching**: Optimizes the key-value computations we just learned
- **Multi-head latent attention**: Modifies the multi-head structure for better efficiency
- **Understanding prerequisites**: Can't understand advanced concepts without mastering these basics

## Key Takeaways

### Core Concepts
1. **Problem**: Single-head attention captures only one perspective
2. **Solution**: Multiple heads capture multiple perspectives simultaneously  
3. **Implementation**: Split weight matrices across heads, compute in parallel, concatenate results
4. **Trade-off**: Reduced per-head capacity for increased perspective diversity

### Mathematical Insights
- **Same parameter count**: Multi-head doesn't add parameters, just reorganizes them
- **Same output dimensions**: Final context vectors have same size as single-head
- **Same computational complexity**: O(seq_len² × d_out) regardless of head count
- **Linear memory scaling**: Memory increases with number of heads

### Practical Benefits
- **Richer representations**: Multiple perspectives embedded in same output
- **Specialized learning**: Heads can focus on different aspects of language
- **Better performance**: In practice, generally stronger results than a single head of the same total width
- **Interpretability**: Can visualize and understand what each head learns

### The Foundation for Advanced Concepts
Multi-head attention is the essential building block for:
- **Transformer architecture**: Core component of all modern LLMs
- **Efficiency optimizations**: KV caching, gradient checkpointing
- **Architectural innovations**: Multi-head latent attention, mixture of experts
- **Understanding language**: How models capture complex linguistic relationships

This mechanism is fundamental to how language models understand and generate human language, making it one of the most important concepts in modern AI.