# Multi-Head Attention

## Paper Reference

**Vaswani, A., et al. (2017).** *Attention Is All You Need.* NeurIPS. [arXiv:1706.03762](https://arxiv.org/abs/1706.03762)

---

## Mathematical Derivation

### Core Formula

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O$$

Where each head is:

$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$
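
The $\text{Attention}$ function is not defined in this section. In the cited paper it is scaled dot-product attention, restated here for reference:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

The $1/\sqrt{d_k}$ factor corresponds to `self.scale` in the implementation further down.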

### Projection Matrices

- $W_i^Q \in \mathbb{R}^{d_{model} \times d_k}$
- $W_i^K \in \mathbb{R}^{d_{model} \times d_k}$
- $W_i^V \in \mathbb{R}^{d_{model} \times d_v}$
- $W^O \in \mathbb{R}^{hd_v \times d_{model}}$

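Typically $d_k = d_v = d_{model} / h$, so the total computational cost is similar to that of single-head attention with full dimensionality.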
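
The formulas use $h$ separate matrices $W_i^Q$, while the implementation in this notebook uses one `nn.Linear(d_model, d_model)` per projection. The following is a minimal sketch of why the two agree when $d_k = d_{model} / h$. It assumes random, hypothetical weights and ignores the bias term that `nn.Linear` adds.

```python
import torch

torch.manual_seed(0)
d_model, num_heads, seq_len = 64, 4, 5
d_k = d_model // num_heads

x = torch.randn(seq_len, d_model)

# Hypothetical per-head query matrices W_i^Q, each of shape (d_model, d_k).
per_head_weights = [torch.randn(d_model, d_k) for _ in range(num_heads)]

# Stack them column-wise into one fused matrix of shape (d_model, num_heads * d_k).
fused_weight = torch.cat(per_head_weights, dim=1)

# Fused path: one matmul, then split the last dimension into heads.
q_fused = (x @ fused_weight).view(seq_len, num_heads, d_k).transpose(0, 1)

# Per-head path: one matmul per head, stacked to (num_heads, seq_len, d_k).
q_per_head = torch.stack([x @ w for w in per_head_weights])

print(q_fused.shape)
print(torch.allclose(q_fused, q_per_head, atol=1e-4))
```

The fused form is usually preferred in practice because it replaces $h$ small matrix multiplications with one larger one.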

### Why Multiple Heads?

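1. **Subspace Learning**: Each head has its own projections $W_i^Q$, $W_i^K$, $W_i^V$, so it can learn an attention pattern in a different $d_k$-dimensional subspace
2. **Position Focus**: For the same query, different heads can attend to different positions, whereas a single head averages them into one distribution
3. **Ensemble Effect**: $W^O$ mixes the concatenated head outputs, aggregating information from all of these representations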

### Complexity Analysis

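Here $n$ is the sequence length and $d = d_{model}$.

| Operation | Time Complexity | Space Complexity |
|-----------|-----------------|------------------|
| Q/K/V projections | $O(n \cdot d^2)$ | $O(d^2)$ parameters, $O(n \cdot d)$ activations |
| Attention per head | $O(n^2 \cdot d_k)$ | $O(n^2)$ |
| All heads | $O(n^2 \cdot d)$ | $O(n^2 \cdot h)$ |
| Output projection | $O(n \cdot d^2)$ | $O(d^2)$ parameters, $O(n \cdot d)$ activations |
| **Total** | $O(n^2 \cdot d + n \cdot d^2)$ | $O(n^2 \cdot h + n \cdot d + d^2)$ |

The $n^2 \cdot d$ term dominates only when $n$ is large relative to $d$; for short sequences the projections account for most of the cost.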
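
To make the table concrete, here is a rough sketch that counts multiply-adds for one attention layer. The helper name `attention_costs` is hypothetical, and the counts ignore softmax, biases, and constant factors that depend on the hardware.

```python
def attention_costs(n: int, d_model: int, num_heads: int) -> dict:
    """Approximate multiply-add counts and attention-map size for one layer."""
    d_k = d_model // num_heads
    return {
        # W^Q, W^K, W^V and W^O: four (n x d_model) @ (d_model x d_model) products.
        "projections": 4 * n * d_model * d_model,
        # Per head: QK^T and weights @ V, each n * n * d_k multiply-adds.
        "attention": 2 * num_heads * n * n * d_k,
        # One (n x n) attention map is stored per head.
        "attn_map_elements": num_heads * n * n,
    }


for n in (128, 1024, 8192):
    costs = attention_costs(n, d_model=512, num_heads=8)
    print(n, costs["projections"], costs["attention"], costs["attn_map_elements"])
```

Under this count the two terms are equal at $n = 2d$; below that length the projections dominate, above it the quadratic attention term does.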

In [None]:
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns

plt.style.use('seaborn-v0_8-whitegrid')

In [None]:
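class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention (Vaswani et al., 2017).

    Uses one fused nn.Linear per Q/K/V projection; each projected tensor
    is split into num_heads chunks of size d_k = d_model // num_heads.
    """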
    
    def __init__(self, d_model: int, num_heads: int = 8, dropout: float = 0.0) -> None:
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.scale = 1.0 / math.sqrt(self.d_k)
        
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute multi-head attention.

        Args:
            query: (batch, seq_q, d_model).
            key, value: (batch, seq_k, d_model).
            mask: optional, broadcastable to (batch, heads, seq_q, seq_k);
                True (or nonzero) marks positions to exclude. A query row
                that is fully masked produces NaN after the softmax.

        Returns:
            The output of shape (batch, seq_q, d_model) and the attention
            weights (after dropout) of shape (batch, heads, seq_q, seq_k).
        """
        batch_size, seq_len_q = query.size(0), query.size(1)
        seq_len_k = key.size(1)
        
        # Project and reshape: (batch, heads, seq, d_k)
        q = self.w_q(query).view(batch_size, seq_len_q, self.num_heads, self.d_k).transpose(1, 2)
        k = self.w_k(key).view(batch_size, seq_len_k, self.num_heads, self.d_k).transpose(1, 2)
        v = self.w_v(value).view(batch_size, seq_len_k, self.num_heads, self.d_k).transpose(1, 2)
        
        # Attention scores: (batch, heads, seq_q, seq_k)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        
        if mask is not None:
            scores = scores.masked_fill(mask.bool(), float("-inf"))
        
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # Attended values
        attended = torch.matmul(attn_weights, v)
        attended = attended.transpose(1, 2).contiguous().view(batch_size, seq_len_q, self.d_model)
        
        return self.w_o(attended), attn_weights
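
The `forward` method treats `True` (or nonzero) mask entries as positions to hide, because it passes `mask.bool()` straight to `masked_fill`. The sketch below illustrates that convention with a causal mask, applied directly to a dummy score tensor rather than through the class; the all-zero scores are an assumption chosen so the resulting weights are easy to read.

```python
import torch
import torch.nn.functional as F

seq_len = 4

# True above the diagonal: each query position may not see later key positions.
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

# Dummy scores of shape (batch, heads, seq_q, seq_k); the mask broadcasts over
# the batch and head dimensions.
scores = torch.zeros(1, 1, seq_len, seq_len)
masked_scores = scores.masked_fill(causal_mask, float("-inf"))
weights = F.softmax(masked_scores, dim=-1)

print(causal_mask)
print(weights[0, 0])
```

Each row of `weights` is uniform over the positions up to and including the query position and zero afterwards. Passing `causal_mask` as the `mask` argument of `MultiHeadAttention.forward` applies the same masking to every head.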

In [None]:
# Demonstration: self-attention, so query, key, and value are the same tensor
torch.manual_seed(42)

batch_size, seq_len, d_model, num_heads = 1, 8, 64, 4
attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)

x = torch.randn(batch_size, seq_len, d_model)
output, weights = attention(x, x, x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {weights.shape}")

In [None]:
# Visualize attention patterns for each head
fig, axes = plt.subplots(1, num_heads, figsize=(16, 4))

for head_idx in range(num_heads):
    sns.heatmap(
        weights[0, head_idx].detach().numpy(),
        ax=axes[head_idx],
        cmap='viridis',
        square=True,
        cbar=head_idx == num_heads - 1
    )
    axes[head_idx].set_title(f'Head {head_idx}')
    axes[head_idx].set_xlabel('Key')
    if head_idx == 0:
        axes[head_idx].set_ylabel('Query')

plt.suptitle('Multi-Head Attention Patterns')
plt.tight_layout()
plt.show()

## When to Use Multi-Head Attention

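| Use Case | Recommendation |
|----------|----------------|
| Standard transformer | 8-16 heads (the original paper uses 8 for its base model and 16 for its big model) |
| Small models | 4-8 heads |
| Large models | 16-64 heads |
| Memory constrained | Consider multi-query attention (MQA) or grouped-query attention (GQA), which share key/value heads across query heads |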
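
As a rough illustration of the last row, the sketch below compares how many key/value elements must be kept per layer when the number of key/value heads shrinks. It assumes an autoregressive decoder that caches keys and values, which this notebook does not implement; the helper `kv_cache_elements` and the example sizes are hypothetical.

```python
def kv_cache_elements(seq_len: int, num_kv_heads: int, d_k: int, num_layers: int = 1) -> int:
    """Number of cached key and value elements (the factor 2 covers K and V)."""
    return 2 * num_layers * seq_len * num_kv_heads * d_k


num_heads, d_k, seq_len = 32, 128, 4096

# MHA keeps one K/V head per query head, GQA shares each K/V head across a
# group of query heads, and MQA uses a single K/V head for all query heads.
for name, kv_heads in [("MHA", num_heads), ("GQA", 8), ("MQA", 1)]:
    print(name, kv_cache_elements(seq_len, kv_heads, d_k))
```

The cache size scales linearly with the number of key/value heads, so in this setting MQA stores `num_heads` times fewer elements than standard multi-head attention.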