# Multi-Head Attention

Multi-head attention (Vaswani et al., 2017) builds on the attention mechanism introduced for neural machine translation (Bahdanau et al., 2015). Instead of a single attention map, it learns multiple attention "heads" in parallel, allowing the model to capture different types of relationships across the same sequence.

**Unmasked vs. masked:**
- **Unmasked multi-head attention** allows each token to attend to all tokens (encoder-style).
- **Masked multi-head attention** uses a causal mask so tokens cannot look ahead (decoder-style).
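
As a rough illustration of the difference, the two kinds of mask could be built as follows. This is a minimal sketch that assumes the convention used by the implementation later in this notebook, where 1 means "may attend" and 0 means "blocked"; `seq_len`, `full_mask`, and `causal_mask` are names chosen only for this example.

```python
import torch

seq_len = 4  # illustrative sequence length (an assumption for this sketch)

# Unmasked (encoder-style): every position may attend to every position.
full_mask = torch.ones(seq_len, seq_len, dtype=torch.long)

# Masked (decoder-style): position i may attend only to positions j <= i,
# which corresponds to the lower triangle of the matrix.
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.long))

print(causal_mask)
```

Row `i` of `causal_mask` lists the positions that token `i` is allowed to see. A mask of shape `(seq_len, seq_len)` broadcasts across the batch and head dimensions.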

In practice, different heads often specialize in different alignment patterns (e.g., syntax, locality, long-range dependencies), and their outputs are combined to form a richer representation.




### Mathematical Formulation

Multi-head attention is defined as:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O$$

$$\text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$$
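
For concreteness, the shapes of the projection matrices can be written out. The numeric values below are the base configuration reported by Vaswani et al. (2017) and are given only as an example:

$$W_i^Q, W_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}, \quad W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}, \quad W^O \in \mathbb{R}^{h d_v \times d_{\text{model}}}$$

$$d_{\text{model}} = 512, \quad h = 8, \quad d_k = d_v = \frac{d_{\text{model}}}{h} = 64$$

Since each head works in a subspace of width $d_{\text{model}} / h$, concatenating the $h$ heads gives back a vector of width $h \, d_v = d_{\text{model}}$.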

**Masked version (causal):**

$$\text{Attention}(Q, K, V, M) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V$$

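Here, $M$ is an additive mask with $-\infty$ at positions that should be blocked and $0$ everywhere else, so blocked positions receive zero weight after the softmax. Setting $M = 0$ recovers the unmasked case.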
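
The effect of the additive mask can be checked with a small worked example. The sketch below uses only the Python standard library; the score values are made up for illustration, and `softmax` is a helper written just for this example.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats (may contain -inf)."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical scaled scores QK^T / sqrt(d_k) for a 3-token sequence.
scores = [
    [0.5, 1.0, -0.2],
    [0.1, 0.3, 0.8],
    [1.2, -0.4, 0.0],
]

# Additive causal mask M: 0 where attention is allowed, -inf where blocked.
neg_inf = float("-inf")
M = [[0.0 if j <= i else neg_inf for j in range(3)] for i in range(3)]

# softmax(scores + M), applied row by row.
weights = [
    softmax([s + m for s, m in zip(score_row, mask_row)])
    for score_row, mask_row in zip(scores, M)
]

for row in weights:
    print([round(w, 3) for w in row])
```

Because $e^{-\infty} = 0$, every blocked position ends up with weight exactly zero, and each row still sums to one over the positions that remain visible.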

### Visual Intuition

![Scaled Dot-Product Attention](./assets/self_attention_diagram.png)

![Multi-Head Attention](./assets/multihead_attention_diagram.png)

Each head performs scaled dot-product attention on its own learned projections, and the results are concatenated and linearly mixed to produce the final output.
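
One way to see how "its own learned projections" relates to a single large projection matrix is the sketch below. It assumes plain weight matrices in the paper's $xW$ convention (not `nn.Linear`), and all sizes and names are illustrative. It checks that projecting once with a `d_model × d_model` matrix and reshaping into heads gives the same result as projecting separately with `h` slices of width `d_k`.

```python
import torch

torch.manual_seed(0)
batch, seq_len, d_model, h = 2, 5, 8, 2   # small illustrative sizes
d_k = d_model // h

x = torch.randn(batch, seq_len, d_model)
W_q = torch.randn(d_model, d_model)       # fused projection covering all heads

# Fused route: project once, then split the last dimension into h heads.
fused = (x @ W_q).reshape(batch, seq_len, h, d_k).permute(0, 2, 1, 3)

# Per-head route: head i uses columns [i*d_k, (i+1)*d_k) of W_q as its W_i^Q.
per_head = torch.stack(
    [x @ W_q[:, i * d_k:(i + 1) * d_k] for i in range(h)], dim=1
)

print(fused.shape)                                   # (batch, h, seq_len, d_k)
print(torch.allclose(fused, per_head, atol=1e-5))    # both routes agree

# Merging the heads back is the inverse permute + reshape.
merged = fused.permute(0, 2, 1, 3).reshape(batch, seq_len, d_model)
print(torch.allclose(merged, x @ W_q))
```

This is why an implementation can keep one weight matrix per role (query, key, value) and rely on reshaping rather than storing `h` separate matrices.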

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


In [None]:
class MultiHeadAttention(nn.Module):
    
    def __init__(self, d_model: int, h: int, dropout: float):
        """
        Multi-head scaled dot-product attention.

        Args:
            d_model (int): Model embedding dimension.
            h (int): Number of attention heads; must divide d_model evenly.
            dropout (float): Dropout probability applied to the attention weights.

        Notation (kept consistent with the maths in Vaswani et al., 2017):
            d_model = size of each embedding vector
            h = number of heads
            d_k = dimension (width) of each head, d_model // h
        """
        super().__init__()
        self.d_model = d_model              # Also equivalent to h * d_k
        self.h = h
        assert d_model % h == 0, "d_model must be divisible by h"
        
        self.d_k = d_model // h             # Also equivalent to d_v in paper
        self.scale = math.sqrt(self.d_k)
        
        
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
        
        
    def split_heads(self, x, batch_size: int):
        """
        Split the model dimension into multiple heads.

        Args:
            x (torch.Tensor): Projected tensor (Q', K', or V') of shape (batch, seq_len, d_model).
            batch_size (int): Batch size for reshaping.

        Returns:
            torch.Tensor: Tensor of shape (batch, h, seq_len, d_k).
        """
        seq_len = x.size(1)
        x = x.reshape(batch_size, seq_len, self.h, self.d_k)
        return x.permute(0, 2, 1, 3)          
    
    def compute_attn(self, query, key, value, mask=None):
        """
        Compute scaled dot-product attention for each group
        of split heads (e.g., q1, k1, v1, ...).

        Args:
            query (torch.Tensor): Query tensor of shape (batch, h, seq_len, d_k).
            key (torch.Tensor): Key tensor of shape (batch, h, seq_len, d_k).
            value (torch.Tensor): Value tensor of shape (batch, h, seq_len, d_k).
            mask (torch.Tensor, optional): Attention mask broadcastable to
                (batch, h, seq_len, seq_len). Positions with 0 are masked.

        Returns:
            torch.Tensor: Attention output of shape (batch, h, seq_len, d_k).
        """
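        # Similarity scores scaled by sqrt(d_k): (batch, h, seq_len, seq_len)
        scores = torch.matmul(query, key.transpose(-2, -1)) / self.scale
        if mask is not None:
            # Blocked positions get -inf so that softmax gives them zero weight.
            # Note: a row in which every position is blocked would yield NaN.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        return torch.matmul(attn_weights, value)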
    
    def concat_heads(self, x, batch_size: int):
        """
        Concatenate attention heads back into the model dimension (tensor H).

        Args:
            x (torch.Tensor): Tensor of shape (batch, h, seq_len, d_k).
            batch_size (int): Batch size for reshaping.

        Returns:
            torch.Tensor: Tensor of shape (batch, seq_len, d_model).
        """
        x = x.permute(0, 2, 1, 3).contiguous()
        return x.view(batch_size, -1, self.d_model)
    
    def forward(self, Q, K, V, mask=None):
        """
        Project the inputs, attend within each head, and merge the heads.

        Args:
            Q (torch.Tensor): Query tensor of shape (batch, seq_len, d_model).
            K (torch.Tensor): Key tensor of shape (batch, seq_len, d_model).
            V (torch.Tensor): Value tensor of shape (batch, seq_len, d_model).
            mask (torch.Tensor, optional): Attention mask broadcastable to
                (batch, h, seq_len, seq_len). Positions with 0 are masked.

        Returns:
            torch.Tensor: Output tensor of shape (batch, seq_len, d_model).
        """
        batch_size = Q.size(0)
        
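        # Project the inputs, then split each into h heads: (batch, h, seq_len, d_k)
        query = self.split_heads(self.W_q(Q), batch_size)
        key = self.split_heads(self.W_k(K), batch_size)
        value = self.split_heads(self.W_v(V), batch_size)

        # Scaled dot-product attention, computed independently for each head
        attention = self.compute_attn(query, key, value, mask)

        # Merge the heads back to (batch, seq_len, d_model) before applying W_o
        output = self.concat_heads(attention, batch_size)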
        
        return self.W_o(output)
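
A quick way to sanity-check the masked variant is to perturb the last token and confirm that the outputs at earlier positions do not change. To keep the block self-contained, the sketch below uses a compact functional stand-in, `mha`, which is a hypothetical helper written for this check (plain weight matrices, no dropout) rather than the class above. The same test could be run against `MultiHeadAttention` in eval mode.

```python
import math
import torch

def mha(x, W_q, W_k, W_v, W_o, h, mask=None):
    """Hypothetical functional stand-in for multi-head self-attention."""
    b, n, d_model = x.shape
    d_k = d_model // h

    def split(t):  # (b, n, d_model) -> (b, h, n, d_k)
        return t.reshape(b, n, h, d_k).permute(0, 2, 1, 3)

    q, k, v = split(x @ W_q), split(x @ W_k), split(x @ W_v)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        # Same convention as the class above: positions with 0 are blocked.
        scores = scores.masked_fill(mask == 0, float("-inf"))
    out = torch.softmax(scores, dim=-1) @ v              # (b, h, n, d_k)
    out = out.permute(0, 2, 1, 3).reshape(b, n, d_model)
    return out @ W_o

torch.manual_seed(0)
b, n, d_model, h = 1, 6, 16, 4                           # illustrative sizes
W_q, W_k, W_v, W_o = (
    torch.randn(d_model, d_model) / math.sqrt(d_model) for _ in range(4)
)
x = torch.randn(b, n, d_model)
causal = torch.tril(torch.ones(n, n))

# Change only the last token of the sequence.
x_perturbed = x.clone()
x_perturbed[:, -1] += 1.0

y = mha(x, W_q, W_k, W_v, W_o, h, causal)
y_perturbed = mha(x_perturbed, W_q, W_k, W_v, W_o, h, causal)
y_full = mha(x, W_q, W_k, W_v, W_o, h)
y_full_perturbed = mha(x_perturbed, W_q, W_k, W_v, W_o, h)

# With the causal mask, earlier positions cannot see the last token.
causal_unchanged = torch.allclose(y[:, :-1], y_perturbed[:, :-1], atol=1e-6)
# Without a mask, the change leaks into earlier positions.
full_unchanged = torch.allclose(y_full[:, :-1], y_full_perturbed[:, :-1], atol=1e-6)

print(causal_unchanged, full_unchanged)
```

If the mask were built or applied incorrectly (for example, with the triangle flipped), the first check would fail, which makes this a cheap regression test for decoder-style attention.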

## References

1. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017). 
   **Attention Is All You Need.** *Advances in Neural Information Processing Systems, 30.* 
   [arXiv:1706.03762](https://arxiv.org/abs/1706.03762)
2. Bahdanau, D., Cho, K., & Bengio, Y. (2015). 
   **Neural Machine Translation by Jointly Learning to Align and Translate.** *International Conference on Learning Representations (ICLR).* 
   [arXiv:1409.0473](https://arxiv.org/abs/1409.0473)