# Multi-Head Attention

In the previous notebook, we saw that self-attention helps focus on the most relevant words when processing a sentence. But what if there are multiple ways to determine relevance?

Consider the previous example:

The cat sat on the mat.

If we ask, "Where did the cat sit?", the word "mat" is important. However, there are other relationships in the sentence:

- The word "cat" is related to "sat" (because the cat is the subject of the action).
- The word "on" is also relevant because it describes the spatial relationship between "sat" and "mat".

With only one attention head, each word gets a single set of attention weights, so the model might capture only one of these relationships at a time.

Solution:
Each attention head looks at the sentence from a different perspective. This way, the model can capture multiple relationships between words.

The input to the multi-head attention layer is a sequence of vectors, each of which is the output of the previous layer (e.g., the word embeddings). The layer consists of multiple attention heads. Each head applies its own learned linear transformations to the input vectors and computes attention on the result. The outputs of the heads are then concatenated and linearly transformed to produce the final output of the layer.
<img src="mha_img_original.png" width="500" style="background-color:white;">

In simple terms, we apply multiple *scaled dot-product attention* operations in parallel.
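
Written as a formula, following the formulation in the original Transformer paper ("Attention Is All You Need"), the concatenation and final linear transformation look as follows. The symbols $h$ (number of heads) and $W^O$ (output projection matrix) are notation borrowed from that paper rather than from this notebook, and $\text{head}_i$ denotes the output of the $i$-th attention head:

$$ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) \, W^O, \qquad W^O \in \mathbb{R}^{h d_v \times d_{\text{model}}} $$

Each head outputs vectors of size $d_v$, so the concatenation has size $h d_v$, and $W^O$ maps it back to $d_{\text{model}}$.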

<img src="mha_visualization-930x1030.png" width="500" style="background-color:white;">


## Step 1: Single-Head Attention

Recall:

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V $$


In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [2]:
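def scaled_dot_product_attention(Q, K, V):
    """Return (output, attention_weights) for 2-D Q, K, V of shape (seq_len, dim)."""
    d_k = K.shape[-1]  # Dimension of the key (and query) vectors
    scores = Q @ K.T / np.sqrt(d_k)  # Scaled dot-product scores, shape (len_q, len_k)
    # Softmax over the key axis; subtracting the row max avoids overflow in exp
    exp_scores = np.exp(scores - scores.max(axis=-1, keepdims=True))
    attention_weights = exp_scores / exp_scores.sum(axis=-1, keepdims=True)
    output = attention_weights @ V  # Weighted sum of values
    return output, attention_weights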
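
As a quick sanity check, the function above can be tried on random toy data. This is only a sketch: the sizes, the seed, and the random inputs are arbitrary assumptions, and the function is repeated in compact form so the cell runs on its own.

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    # Compact copy of the function above so this cell is self-contained
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)   # fixed seed, arbitrary choice
seq_len, d_k, d_v = 6, 4, 4      # toy sizes: 6 tokens, as in "The cat sat on the mat"
Q = rng.normal(size=(seq_len, d_k))
K = rng.normal(size=(seq_len, d_k))
V = rng.normal(size=(seq_len, d_v))

output, attention_weights = scaled_dot_product_attention(Q, K, V)
print(output.shape)             # (6, 4)
print(attention_weights.shape)  # (6, 6)
print(np.allclose(attention_weights.sum(axis=-1), 1.0))  # True
```

Each row of `attention_weights` is a probability distribution over the six tokens, which is why every row sums to 1.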

Each head in multi-head attention is responsible for computing scaled dot-product attention in a lower-dimensional subspace.
Mathematically, a single head is computed as:

$$ \text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V) $$

where:
- $W_i^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $W_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$ are learned linear transformations

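This means each head has its own learnable weight matrices $(W_i^Q, W_i^K, W_i^V)$, which project the input queries, keys, and values into lower-dimensional subspaces before attention is applied. In the original Transformer, $d_k = d_v = d_{\text{model}} / h$ for $h$ heads, so the total cost is similar to that of a single head with full dimensionality.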
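
The single-head formula can be sketched in NumPy as follows. This is a minimal illustration, not a trained model: the helper names `softmax` and `attention_head`, the toy sizes, and the random matrices standing in for learned weights are all assumptions made for the example.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention_head(Q, K, V, W_q, W_k, W_v):
    """head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V) for 2-D inputs."""
    q, k, v = Q @ W_q, K @ W_k, V @ W_v            # project into the head's subspace
    weights = softmax(q @ k.T / np.sqrt(k.shape[-1]))  # scaled dot-product weights
    return weights @ v, weights

rng = np.random.default_rng(0)               # fixed seed, arbitrary choice
seq_len, d_model, d_k, d_v = 6, 8, 2, 2      # hypothetical toy sizes
X = rng.normal(size=(seq_len, d_model))      # stand-in for token embeddings
W_q = rng.normal(size=(d_model, d_k))        # random, in place of learned weights
W_k = rng.normal(size=(d_model, d_k))
W_v = rng.normal(size=(d_model, d_v))

# Self-attention: queries, keys, and values all come from the same input X
head, weights = attention_head(X, X, X, W_q, W_k, W_v)
print(head.shape)     # (6, 2)
print(weights.shape)  # (6, 6)
```

The head maps each 8-dimensional input vector to a 2-dimensional output, while the attention weights remain a 6 × 6 matrix over token pairs regardless of the head's dimensionality.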
