<a href="https://www.kaggle.com/code/aisuko/coding-the-multi-head-attention?scriptVersionId=160449289" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Overview

**Note: The images are taken from the sources listed in the Credit section.**

In [Coding the Self-Attention Mechanism](https://www.kaggle.com/code/aisuko/coding-the-self-attention-mechanism), we implemented the self-attention mechanism. From [Encoder in Transformers Architecture](https://www.kaggle.com/code/aisuko/encoder-in-transformers-architecture), we can see that transformers use a module called **multi-head attention**. In this notebook, we look at how that module relates to self-attention by implementing it in code.

In scaled dot-product attention (self-attention), the input sequence is transformed using three matrices representing the query, key, and value. In the context of multi-head attention, these three matrices can be considered a single attention head. The figure below summarizes the single attention head covered in the previous notebook.

<div style="text-align: center"><img src="https://files.mastodon.social/media_attachments/files/111/814/921/458/820/088/original/1fb77b4eb89e6718.png" width="80%" heigh="80%" alt="Scaled-dot-product attention"></div>

# Multi-Head Attention

As its name implies, multi-head attention involves multiple such heads, each consisting of query, key, and value matrices. This concept is similar to the use of multiple kernels in convolutional neural networks.

<div style="text-align: center"><img src="https://files.mastodon.social/media_attachments/files/111/814/925/242/665/171/original/821c33b401832d5d.png" width="80%" heigh="80%" alt="multi-head attention"></div>

Here is the code from the previous notebook:

In [1]:
import torch
import torch.nn.functional as F

inputs="According to the news, it it hard to say Melbourne is safe now"
d_q, d_k, d_v=24,24,28

input_ids={s:i for i,s in enumerate(sorted(inputs.replace(',','').split()))}
input_tokens=torch.tensor([input_ids[s] for s in inputs.replace(',','').split()])

torch.manual_seed(123)
embed=torch.nn.Embedding(13,16)
embedded_sentence=embed(input_tokens).detach()
d=embedded_sentence.shape[1]

# defining the Weight Matrices
W_query=torch.nn.Parameter(torch.rand(d_q, d))
W_key=torch.nn.Parameter(torch.rand(d_k, d))
W_value=torch.nn.Parameter(torch.rand(d_v, d))

# only computing the attention-vector for the second input element
# In this example, the second input element acts as the query
x_2 = embedded_sentence[1]
query_2 = W_query.matmul(x_2)
key_2 = W_key.matmul(x_2)
value_2 = W_value.matmul(x_2)

# computing the key and value for all inputs
keys = W_key.matmul(embedded_sentence.T).T
values = W_value.matmul(embedded_sentence.T).T

# computing the unnormalized attention weights w

# in this example, we compute the unnormalized attention weight between the query and the 5th input element (index position 4) as follows
w_2_4=query_2.dot(keys[4])

# compute the unnormalized attention weight for all the input tokens
w_2=query_2.matmul(keys.T)

attention_weights_2=F.softmax(w_2/d_k**0.5, dim=0)

# The final context vector(an attention-weighted version of the original query input x_2)
context_vector_2=attention_weights_2.matmul(values)
context_vector_2

tensor([-2.1845, -3.4618, -2.5052, -3.4871, -2.2224, -3.4605, -3.9543, -4.4065,
        -4.7564, -4.1877, -2.8166, -4.1730, -3.3587, -3.0407, -4.5513, -4.7335,
        -1.3817, -2.6396, -2.3683, -2.7940, -3.4905, -4.4358, -5.2125, -4.3044,
        -3.0761, -3.4201, -4.7494, -3.3475], grad_fn=<SqueezeBackward4>)

To illustrate this in code, suppose we have 3 attention heads. We extend the $d^{'} \times d$ dimensional weight matrices to $3 \times d^{'} \times d$:

In [2]:
h=3
multihead_W_query=torch.nn.Parameter(torch.rand(h, d_q, d))
multihead_W_key=torch.nn.Parameter(torch.rand(h, d_k, d))
multihead_W_value=torch.nn.Parameter(torch.rand(h, d_v, d))

Consequently, each query element is now $3 \times d_{q}$ dimensional, where $d_{q}=24$ (here, we keep the focus on the second input element, `x_2`, at index position 1):

In [3]:
multihead_query_2=multihead_W_query.matmul(x_2)
multihead_query_2

tensor([[-3.8033, -1.8514, -3.0982, -1.6475, -1.7888, -3.1605, -2.3619, -0.5279,
         -3.6521, -3.4834, -3.6471, -3.2028, -0.6245, -1.6851, -1.0399, -3.3090,
         -2.1283, -5.2142, -1.6018, -0.4544, -3.1030, -0.0287, -4.3965, -2.2998],
        [-4.1517, -4.6697, -1.6747, -1.1715, -3.5441, -0.4090, -1.6129, -4.4261,
         -2.1847, -2.9327, -2.6157, -3.1685, -1.9501, -2.9855, -3.1613, -1.2670,
         -0.5295, -1.1895, -0.4661, -2.3916, -0.9902,  0.3367, -0.4596, -2.9863],
        [-1.6977, -1.6078, -3.5137, -4.9699, -4.1886, -0.7016, -3.3832, -3.2597,
         -2.1036, -4.3422, -1.9974, -1.7627, -2.9813, -1.5485,  0.0060, -1.7442,
         -5.0369, -4.2576, -1.7272, -0.5214, -2.1458, -2.9699, -1.4175, -0.9593]],
       grad_fn=<UnsafeViewBackward0>)

In [4]:
multihead_query_2.shape

torch.Size([3, 24])

Let's compute the keys and values in the same way:

In [5]:
multihead_key_2=multihead_W_key.matmul(x_2)
multihead_value_2=multihead_W_value.matmul(x_2)

Now, these key and value elements are specific to the query element. But, similar to earlier, we also need the keys and values for the other sequence elements in order to compute the attention scores for the query. We can do this by repeating the input sequence embeddings 3 times (once per attention head):

In [6]:
embedded_sentence.shape

torch.Size([13, 16])

In [7]:
stacked_inputs=embedded_sentence.T.repeat(3,1,1)

In [8]:
stacked_inputs.shape

torch.Size([3, 16, 13])

Now, we can compute all the keys and values using `torch.bmm()` (batch matrix multiplication):

In [9]:
multihead_keys=torch.bmm(multihead_W_key, stacked_inputs)
multihead_values=torch.bmm(multihead_W_value, stacked_inputs)
print(multihead_keys.shape)
print(multihead_values.shape)

torch.Size([3, 24, 13])
torch.Size([3, 28, 13])
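As a side note, the explicit `repeat` may not be strictly necessary. The sketch below assumes random stand-in tensors with the same shapes as above (the values are not the notebook's values) and compares the `torch.bmm()` approach with `torch.matmul()`, which broadcasts a 2-D input across the head dimension without copying it first.

```python
import torch

torch.manual_seed(123)
h, d_k, d, seq_len = 3, 24, 16, 13

# Stand-ins for the tensors above: same shapes, random values (assumption)
embedded_sentence = torch.randn(seq_len, d)
multihead_W_key = torch.rand(h, d_k, d)

# Approach used above: copy the inputs once per head, then batch-multiply
stacked_inputs = embedded_sentence.T.repeat(h, 1, 1)    # (3, 16, 13)
keys_bmm = torch.bmm(multihead_W_key, stacked_inputs)   # (3, 24, 13)

# Broadcasting alternative: the 2-D input is shared by all heads
keys_broadcast = torch.matmul(multihead_W_key, embedded_sentence.T)  # (3, 24, 13)

print(keys_bmm.shape, keys_broadcast.shape)
print("max abs difference:", (keys_bmm - keys_broadcast).abs().max().item())
```

Both variants should agree up to floating-point rounding; the broadcasting version simply avoids materializing three copies of the input.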


We now have tensors that represent the three attention heads in their first dimension. The third and second dimensions refer to the number of words and the projection size ($d_k=24$ for keys, $d_v=28$ for values), respectively. To make the keys and values more intuitive to interpret, we swap the second and third dimensions, so the tensors have the same layout as the original input sequence, **embedded_sentence** (one row per word):

In [10]:
multihead_keys=multihead_keys.permute(0,2,1)
multihead_values=multihead_values.permute(0,2,1)
print(multihead_keys.shape)
print(multihead_values.shape)

torch.Size([3, 13, 24])
torch.Size([3, 13, 28])


We then follow the same steps as before: compute the unnormalized attention weights $w$, apply the scaled softmax to obtain the attention weights $\alpha$, and use them to weight the values. This yields an $h \times d_{v}$ (here $3 \times d_{v}$) dimensional context vector $z$ for the input element $x^{(2)}$.
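The remaining steps for the query `x_2` can be sketched as follows. This is a minimal illustration, assuming random stand-in tensors with the same shapes as above; the names `w_2`, `alpha_2`, `z_2_per_head`, and `z_2` are chosen here for illustration and do not appear in the original notebook.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(123)
h, d_q, d_k, d_v, d, seq_len = 3, 24, 24, 28, 16, 13

# Stand-ins with the same shapes as the tensors above (random values, assumption)
embedded_sentence = torch.randn(seq_len, d)
x_2 = embedded_sentence[1]
multihead_W_query = torch.rand(h, d_q, d)
multihead_W_key = torch.rand(h, d_k, d)
multihead_W_value = torch.rand(h, d_v, d)

multihead_query_2 = multihead_W_query.matmul(x_2)                                  # (3, 24)
multihead_keys = multihead_W_key.matmul(embedded_sentence.T).permute(0, 2, 1)      # (3, 13, 24)
multihead_values = multihead_W_value.matmul(embedded_sentence.T).permute(0, 2, 1)  # (3, 13, 28)

# Unnormalized attention weights w: one score per head and per token
# (3, 13, 24) x (3, 24, 1) -> (3, 13, 1) -> (3, 13)
w_2 = torch.bmm(multihead_keys, multihead_query_2.unsqueeze(-1)).squeeze(-1)

# Scaled softmax over the token dimension gives the attention weights alpha
alpha_2 = F.softmax(w_2 / d_k**0.5, dim=-1)                                        # (3, 13)

# Per-head context vectors: (3, 1, 13) x (3, 13, 28) -> (3, 1, 28) -> (3, 28)
z_2_per_head = torch.bmm(alpha_2.unsqueeze(1), multihead_values).squeeze(1)

# Concatenate the heads into a single h * d_v = 84 dimensional context vector
z_2 = z_2_per_head.reshape(-1)
print(z_2_per_head.shape, z_2.shape)
```

Each head produces its own 28-dimensional context vector for `x_2`, and concatenating the three of them gives the $3 \times d_{v}$ values mentioned above.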

Let's summarize the code in a compact `MultiHeadAttentionWrapper` class:

In [11]:
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out_kq, d_out_v):
        super().__init__()
        self.d_out_kq=d_out_kq
        self.W_query=nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_key=nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_value=nn.Parameter(torch.rand(d_in, d_out_v))
        
    def forward(self, x):
        keys=x.matmul(self.W_key)
        queries=x.matmul(self.W_query)
        values=x.matmul(self.W_value)
        
        # unnormalized attention weights
        attn_scores=queries.matmul(keys.T)
        
        attn_weights=torch.softmax(
            attn_scores/self.d_out_kq**0.5, dim=-1
        )
        
        context_vec=attn_weights.matmul(values)
        return context_vec

class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out_kq, d_out_v, num_heads):
        super().__init__()
        self.heads=nn.ModuleList(
            [SelfAttention(d_in, d_out_kq, d_out_v) for _ in range(num_heads)]
        )
        
    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1) #dim=-1, the last dimension
    

torch.manual_seed(123)
# Let's suppose we have a single Self-Attention head with output dimension 1 to keep it simple
# d_in: Dimension of the input feature vector(embedded vector, here is 16, see above)
# d_out_kq: Dimension for both query and key outputs
# d_out_v: Dimension for value outputs
# num_heads: Number of attention heads
d_in,d_out_kq,d_out_v=16,2,1

sa=SelfAttention(d_in, d_out_kq, d_out_v)
sa(embedded_sentence)

tensor([[ 3.1373],
        [-2.6629],
        [ 3.2493],
        [-2.4250],
        [-2.3970],
        [-2.3970],
        [-2.6269],
        [-2.6629],
        [-2.4197],
        [-2.4037],
        [ 3.3277],
        [-2.4113],
        [-2.6131]], grad_fn=<MmBackward0>)

In [12]:
torch.manual_seed(123)

mha=MultiHeadAttentionWrapper(d_in, d_out_kq, d_out_v, num_heads=4)

context_vecs=mha(embedded_sentence)

print(context_vecs)
print(f"Shape:{context_vecs.shape}")

tensor([[ 3.1373,  2.1516, -1.6622, -1.1571],
        [-2.6629, -1.9705, -2.2620, -3.2353],
        [ 3.2493, -2.0611, -2.0625, -3.0395],
        [-2.4250, -1.9014, -2.1996, -3.3473],
        [-2.3970, -1.8704, -2.0173, -3.3243],
        [-2.3970, -1.8704, -2.0173, -3.3243],
        [-2.6269, -2.3509, -2.2068, -3.2777],
        [-2.6629, -1.9705, -2.2620, -3.2353],
        [-2.4197, -1.7686, -2.1098, -3.1776],
        [-2.4037, -1.2218, -0.3614, -3.1320],
        [ 3.3277,  2.1627, -0.5280, -2.0840],
        [-2.4113, -1.7658, -2.2009, -3.4129],
        [-2.6131, -1.7475, -1.7935, -2.9953]], grad_fn=<CatBackward0>)
Shape:torch.Size([13, 4])


Based on the output above, we can see that the output of the single self-attention head created earlier is now the first column of the multi-head output tensor, because both were initialized with the same random seed.

Notice that the multi-head attention result is a 13x4-dimensional tensor: We have 13 input tokens and 4 self-attention heads, where each self-attention head returns a 1-dimensional output(we set `d_out_v=1`). 

In practice, why do we even need multiple attention heads if we can regulate the output embedding size in the `SelfAttention` class itself?

The distinction between increasing the output dimension of a single self-attention head and using multiple attention heads **lies in how the model processes and learns from the data**. While both approaches increase the capacity of the model to represent different features or aspects of the data, they do so in fundamentally different ways.

For instance, **each attention head in multi-head attention can potentially learn to focus on different parts of the input sequence**, **capturing various aspects** or **relationships within the data**. This **diversity in representation is key to the success of multi-head attention**.
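To make the difference concrete, the sketch below compares one head with `d_out_v=4` against four heads with `d_out_v=1`. It assumes random stand-in inputs and weights, and the helper `attention_weights` is a hypothetical function written for this illustration. Both options return a 13x4 output, but the single head uses one attention matrix for all four output dimensions, while the four heads each compute their own.

```python
import torch

torch.manual_seed(123)
seq_len, d_in, d_out_kq = 13, 16, 2
x = torch.randn(seq_len, d_in)  # stand-in for embedded_sentence (assumption)

def attention_weights(x, W_query, W_key):
    """Scaled-softmax attention weights, one row per query token."""
    scores = (x @ W_query) @ (x @ W_key).T
    return torch.softmax(scores / W_key.shape[1] ** 0.5, dim=-1)

# Option A: one head with d_out_v=4 -> a single attention matrix
W_q = torch.rand(d_in, d_out_kq)
W_k = torch.rand(d_in, d_out_kq)
W_v = torch.rand(d_in, 4)
attn_single = attention_weights(x, W_q, W_k)   # (13, 13)
out_single = attn_single @ (x @ W_v)           # (13, 4)

# Option B: four heads with d_out_v=1 -> four independent attention matrices
heads = [
    (torch.rand(d_in, d_out_kq), torch.rand(d_in, d_out_kq), torch.rand(d_in, 1))
    for _ in range(4)
]
attn_multi = torch.stack([attention_weights(x, q, k) for q, k, _ in heads])  # (4, 13, 13)
out_multi = torch.cat(
    [a @ (x @ v) for a, (_, _, v) in zip(attn_multi, heads)], dim=-1
)                                              # (13, 4)

print(out_single.shape, out_multi.shape)
print(attn_single.shape, attn_multi.shape)
```

The output shapes match, yet only the multi-head variant lets each output column attend to the sequence with a different weighting.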

Multi-head attention can also **be more efficient, especially in terms of parallel computation**. Each head can be processed independently, making it well-suited for modern hardware accelerators like GPUs or TPUs that excel at parallel processing.
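One possible way to exploit this parallelism is to stack the per-head weights into single tensors and compute all heads with batched matrix multiplications instead of a Python loop. The class name `BatchedMultiHeadAttention` and the stand-in input are assumptions made for this sketch; it is not the implementation used elsewhere in this notebook.

```python
import torch
import torch.nn as nn

class BatchedMultiHeadAttention(nn.Module):  # hypothetical name
    def __init__(self, d_in, d_out_kq, d_out_v, num_heads):
        super().__init__()
        self.d_out_kq = d_out_kq
        # One weight tensor per role, with the heads stacked in the first dimension
        self.W_query = nn.Parameter(torch.rand(num_heads, d_in, d_out_kq))
        self.W_key = nn.Parameter(torch.rand(num_heads, d_in, d_out_kq))
        self.W_value = nn.Parameter(torch.rand(num_heads, d_in, d_out_v))

    def forward(self, x):
        # x: (seq_len, d_in), broadcast against (num_heads, d_in, d_out)
        queries = torch.matmul(x, self.W_query)          # (h, seq_len, d_out_kq)
        keys = torch.matmul(x, self.W_key)               # (h, seq_len, d_out_kq)
        values = torch.matmul(x, self.W_value)           # (h, seq_len, d_out_v)
        attn_scores = queries @ keys.transpose(-2, -1)   # (h, seq_len, seq_len)
        attn_weights = torch.softmax(attn_scores / self.d_out_kq ** 0.5, dim=-1)
        context = attn_weights @ values                  # (h, seq_len, d_out_v)
        # Put the heads next to each other: (seq_len, h * d_out_v)
        return context.permute(1, 0, 2).reshape(x.shape[0], -1)

torch.manual_seed(123)
d_in, d_out_kq, d_out_v, num_heads = 16, 2, 1, 4
x = torch.randn(13, d_in)  # stand-in for embedded_sentence (assumption)
mha = BatchedMultiHeadAttention(d_in, d_out_kq, d_out_v, num_heads)
out = mha(x)

# Reference: the same computation, one head at a time
ref = []
for i in range(num_heads):
    q, k, v = x @ mha.W_query[i], x @ mha.W_key[i], x @ mha.W_value[i]
    ref.append(torch.softmax(q @ k.T / d_out_kq ** 0.5, dim=-1) @ v)
ref = torch.cat(ref, dim=-1)

print(out.shape)
print("max abs difference:", (out - ref).abs().max().item())
```

The batched version should match the per-head loop up to floating-point rounding, while giving the hardware a few large matrix multiplications rather than many small ones.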

In short, the use of multiple attention heads is not just about increasing the model's capacity but about enhancing its ability to learn a diverse set of features and relationships within the data. For example, the 7B Llama 2 model uses 32 attention heads.

# Credit

* https://magazine.sebastianraschka.com?utm_source=navbar&utm_medium=web&r=fbe14
* https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html