# Overview

**Note: The images are from the Credit section**

In [Implementing Self-Attention Mechanism](https://www.kaggle.com/code/aisuko/implementing-self-attention-mechanism/notebook) we implemented the self-attention mechanism. In [Encoder in Transformers Architecture](https://www.kaggle.com/code/aisuko/encoder-in-transformers-architecture), we saw that transformers use a module called **multi-head attention**. In this notebook, we look at how multi-head attention relates to self-attention by implementing it in code.

In scaled dot-product attention (self-attention), the input sequence was transformed using three matrices representing the query, key, and value. In the context of multi-head attention, these three matrices together can be considered a single attention head. The figure below summarizes this single attention head, which we covered in the previous notebook linked above.

<div style="text-align: center"><img src="https://files.mastodon.social/media_attachments/files/111/814/921/458/820/088/original/1fb77b4eb89e6718.png" width="80%" heigh="80%" alt="Scaled-dot-product attention"></div>

# Multi-Head Attention

As its name implies, multi-head attention involves multiple such heads, each consisting of query, key, and value matrices. This concept is similar to the use of multiple kernels in convolutional neural networks.

<div style="text-align: center"><img src="https://files.mastodon.social/media_attachments/files/111/814/925/242/665/171/original/821c33b401832d5d.png" width="80%" heigh="80%" alt="multi-head attention"></div>
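To make the "multiple heads" idea concrete, here is a minimal sketch that assumes random inputs and the notebook's dimensions (d=16, d_q=d_k=24, d_v=28, h=3). Each head has its own independently initialized query, key, and value matrices and runs scaled dot-product attention on its own. The per-head outputs are then concatenated. The helper `single_head_attention` and the random input `x` are hypothetical and exist only for illustration. Full Transformer implementations usually apply an extra linear output projection after the concatenation; that step is omitted here.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Assumed dimensions, mirroring the notebook: sequence length, embedding size,
# query/key size, value size, and number of heads
seq_len, d, d_q, d_k, d_v, h = 13, 16, 24, 24, 28, 3
x = torch.rand(seq_len, d)  # hypothetical embedded input sequence

def single_head_attention(x, W_q, W_k, W_v):
    """Scaled dot-product attention for every position of x (one head)."""
    queries = x @ W_q.T  # (seq_len, d_q)
    keys = x @ W_k.T     # (seq_len, d_k)
    values = x @ W_v.T   # (seq_len, d_v)
    scores = queries @ keys.T / keys.shape[-1] ** 0.5  # (seq_len, seq_len)
    weights = F.softmax(scores, dim=-1)                # each row sums to 1
    return weights @ values                            # (seq_len, d_v)

# Each head gets its own, independently initialized projection matrices,
# much like each convolution kernel in a CNN has its own weights
heads = [
    (torch.rand(d_q, d), torch.rand(d_k, d), torch.rand(d_v, d))
    for _ in range(h)
]

# Run every head separately, then concatenate along the feature dimension
head_outputs = [single_head_attention(x, *W) for W in heads]
multihead_output = torch.cat(head_outputs, dim=-1)
print(multihead_output.shape)  # torch.Size([13, 84])
```

The concatenated feature size is h * d_v = 3 * 28 = 84, which is the same $h \times d_v$ context size the notebook arrives at below.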

Here is the code from the previous notebook:

```python
import torch
import torch.nn.functional as F

inputs="According to the news, it it hard to say Melbourne is safe now"
d_q, d_k, d_v=24,24,28

input_ids={s:i for i,s in enumerate(sorted(inputs.replace(',','').split()))}
input_tokens=torch.tensor([input_ids[s] for s in inputs.replace(',','').split()])

torch.manual_seed(123)
# 13 rows: enough for the largest token id (12); embedding size 16
embed=torch.nn.Embedding(13,16)
embedded_sentence=embed(input_tokens).detach()
d=embedded_sentence.shape[1]

# defining the weight matrices
W_query=torch.nn.Parameter(torch.rand(d_q, d))
W_key=torch.nn.Parameter(torch.rand(d_k, d))
W_value=torch.nn.Parameter(torch.rand(d_v, d))

# only computing the attention vector for the second input element
# In this example, the second input element acts as the query
x_2 = embedded_sentence[1]
query_2 = W_query.matmul(x_2)
key_2 = W_key.matmul(x_2)
value_2 = W_value.matmul(x_2)

# computing the key and value for all inputs
keys = W_key.matmul(embedded_sentence.T).T
values = W_value.matmul(embedded_sentence.T).T

# computing the unnormalized attention weights w

# in this example, we compute the unnormalized weight between the query and the 5th input element (index position 4)
w_2_4=query_2.dot(keys[4])

# compute the unnormalized attention weights for all the input tokens
w_2=query_2.matmul(keys.T)

# scaled softmax: divide by sqrt(d_k) before normalizing over the tokens
attention_weights_2=F.softmax(w_2/d_k**0.5, dim=0)

# The final context vector (an attention-weighted version of the original query input x_2)
context_vector_2=attention_weights_2.matmul(values)
context_vector_2
```

```
tensor([-2.1845, -3.4618, -2.5052, -3.4871, -2.2224, -3.4605, -3.9543, -4.4065,
        -4.7564, -4.1877, -2.8166, -4.1730, -3.3587, -3.0407, -4.5513, -4.7335,
        -1.3817, -2.6396, -2.3683, -2.7940, -3.4905, -4.4358, -5.2125, -4.3044,
        -3.0761, -3.4201, -4.7494, -3.3475], grad_fn=<SqueezeBackward4>)
```
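The cell above computes the context vector for a single query. As a sketch, the same weight matrices can produce the context vectors for all 13 tokens in one pass by stacking every query into a matrix; row 1 of the result should then match `context_vector_2`. The block repeats the setup from the cell above so that it runs on its own. The names `words`, `queries`, `all_weights`, and `context_vectors` are our own additions.

```python
import torch
import torch.nn.functional as F

# Same setup as the cell above (same sentence, seed, and dimensions)
inputs = "According to the news, it it hard to say Melbourne is safe now"
d_q, d_k, d_v = 24, 24, 28
words = inputs.replace(',', '').split()
input_ids = {s: i for i, s in enumerate(sorted(words))}
input_tokens = torch.tensor([input_ids[s] for s in words])

torch.manual_seed(123)
embed = torch.nn.Embedding(13, 16)
embedded_sentence = embed(input_tokens).detach()
d = embedded_sentence.shape[1]

W_query = torch.nn.Parameter(torch.rand(d_q, d))
W_key = torch.nn.Parameter(torch.rand(d_k, d))
W_value = torch.nn.Parameter(torch.rand(d_v, d))

# Keys and values for all tokens (equivalent to W.matmul(embedded_sentence.T).T)
keys = embedded_sentence @ W_key.T      # (13, d_k)
values = embedded_sentence @ W_value.T  # (13, d_v)

# Single-query path, as in the notebook (query = second token)
query_2 = W_query.matmul(embedded_sentence[1])
context_vector_2 = F.softmax(query_2.matmul(keys.T) / d_k**0.5, dim=0).matmul(values)

# All-queries path: one row of queries per token
queries = embedded_sentence @ W_query.T                       # (13, d_q)
all_weights = F.softmax(queries @ keys.T / d_k**0.5, dim=-1)  # (13, 13), rows sum to 1
context_vectors = all_weights @ values                        # (13, d_v)

print(context_vectors.shape)  # torch.Size([13, 28])
print(torch.allclose(context_vectors[1], context_vector_2, atol=1e-5))
```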

To illustrate this in code, suppose we have 3 attention heads. We then extend the $d' \times d$ dimensional weight matrices to $3 \times d' \times d$, where $d'$ stands for $d_q$, $d_k$, or $d_v$:

```python
h=3
multihead_W_query=torch.nn.Parameter(torch.rand(h, d_q, d))
multihead_W_key=torch.nn.Parameter(torch.rand(h, d_k, d))
multihead_W_value=torch.nn.Parameter(torch.rand(h, d_v, d))
```

Consequently, each query element is now $3 \times d_q$ dimensional, where $d_q = 24$. Here, we keep the focus on the second input element, at index position 1, which is the `x_2` from the code above:

```python
multihead_query_2=multihead_W_query.matmul(x_2)
multihead_query_2
```

```
tensor([[-3.8033, -1.8514, -3.0982, -1.6475, -1.7888, -3.1605, -2.3619, -0.5279,
         -3.6521, -3.4834, -3.6471, -3.2028, -0.6245, -1.6851, -1.0399, -3.3090,
         -2.1283, -5.2142, -1.6018, -0.4544, -3.1030, -0.0287, -4.3965, -2.2998],
        [-4.1517, -4.6697, -1.6747, -1.1715, -3.5441, -0.4090, -1.6129, -4.4261,
         -2.1847, -2.9327, -2.6157, -3.1685, -1.9501, -2.9855, -3.1613, -1.2670,
         -0.5295, -1.1895, -0.4661, -2.3916, -0.9902,  0.3367, -0.4596, -2.9863],
        [-1.6977, -1.6078, -3.5137, -4.9699, -4.1886, -0.7016, -3.3832, -3.2597,
         -2.1036, -4.3422, -1.9974, -1.7627, -2.9813, -1.5485,  0.0060, -1.7442,
         -5.0369, -4.2576, -1.7272, -0.5214, -2.1458, -2.9699, -1.4175, -0.9593]],
       grad_fn=<UnsafeViewBackward0>)
```

```python
multihead_query_2.shape
```

```
torch.Size([3, 24])
```
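Why does a single `matmul` give one query per head? When the first argument is 3-D and the second is 1-D, `torch.matmul` treats the leading dimension as a batch of matrices and multiplies each one by the vector. The sketch below checks this against an explicit per-head loop, using random stand-ins (hypothetical, not the notebook's tensors) with the same shapes as `multihead_W_query` and `x_2`.

```python
import torch

torch.manual_seed(0)
h, d_q, d = 3, 24, 16

# Hypothetical stand-ins for multihead_W_query and x_2 from the notebook
multihead_W_query = torch.rand(h, d_q, d)
x_2 = torch.rand(d)

# Batched: matmul treats the leading dimension as a batch of h matrices
batched = multihead_W_query.matmul(x_2)  # (h, d_q)

# Explicit: one 2-D matrix-vector product per head, then stack the results
per_head = torch.stack([multihead_W_query[i] @ x_2 for i in range(h)])

print(batched.shape)  # torch.Size([3, 24])
print(torch.allclose(batched, per_head))
```

So row `i` of `multihead_query_2` is simply the query that head `i` produces for `x_2`.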

Next, we compute the keys and values for the same element in the same way:

```python
multihead_key_2=multihead_W_key.matmul(x_2)
multihead_value_2=multihead_W_value.matmul(x_2)
```

These key and value elements are specific to the query element. But, as before, we also need the keys and values for the other sequence elements in order to compute the attention scores for the query. We can get them by stacking the input sequence embeddings 3 times, once per attention head:

```python
embedded_sentence.shape
```

```
torch.Size([13, 16])
```

```python
stacked_inputs=embedded_sentence.T.repeat(3,1,1)
```

```python
stacked_inputs.shape
```

```
torch.Size([3, 16, 13])
```
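`repeat` allocates a new tensor holding three physical copies of the data. If the copy is a concern, `expand` can build the same 3 x 16 x 13 tensor as a view: the new leading dimension gets stride 0, so all three "copies" read from the original storage. The sketch below compares the two on a random stand-in for `embedded_sentence` (hypothetical data with the notebook's shape). Because an expanded view shares memory across the head dimension, it should be treated as read-only.

```python
import torch

torch.manual_seed(0)
embedded_sentence = torch.rand(13, 16)  # hypothetical stand-in (13 tokens, d=16)

# repeat(): new tensor with 3 physical copies of the transposed embeddings
repeated = embedded_sentence.T.repeat(3, 1, 1)

# expand(): a view; -1 keeps the existing size of that dimension
expanded = embedded_sentence.T.expand(3, -1, -1)

print(repeated.shape, expanded.shape)   # both torch.Size([3, 16, 13])
print(torch.equal(repeated, expanded))  # identical values
print(expanded.stride())                # (0, 1, 16): stride 0 along the head dimension
```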

Now we can compute all the keys and values using `torch.bmm()` (batch matrix multiplication):

```python
multihead_keys=torch.bmm(multihead_W_key, stacked_inputs)
multihead_values=torch.bmm(multihead_W_value, stacked_inputs)
print(multihead_keys.shape)
print(multihead_values.shape)
```

```
torch.Size([3, 24, 13])
torch.Size([3, 28, 13])
```
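`torch.bmm` does not broadcast: both arguments must be 3-D with the same batch size, which is why the input was stacked first. `torch.matmul` does broadcast, so it can multiply the 3 x d_k x d weight tensor directly by the 2-D transposed input. The sketch below checks that both routes give the same keys, using random stand-ins (hypothetical) with the notebook's shapes.

```python
import torch

torch.manual_seed(0)
h, d, d_k, seq_len = 3, 16, 24, 13

# Hypothetical stand-ins for the notebook's tensors
multihead_W_key = torch.rand(h, d_k, d)
embedded_sentence = torch.rand(seq_len, d)

# Notebook approach: stack h copies of the input, then batch-multiply
stacked_inputs = embedded_sentence.T.repeat(h, 1, 1)   # (h, d, seq_len)
keys_bmm = torch.bmm(multihead_W_key, stacked_inputs)  # (h, d_k, seq_len)

# Broadcasting approach: matmul broadcasts the 2-D input across the head dimension
keys_matmul = torch.matmul(multihead_W_key, embedded_sentence.T)  # (h, d_k, seq_len)

print(keys_matmul.shape)  # torch.Size([3, 24, 13])
print(torch.allclose(keys_bmm, keys_matmul))
```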


We now have tensors whose first dimension indexes the three attention heads. The second dimension is the key (or value) size and the third is the number of words. To make the keys and values easier to interpret, we swap the second and third dimensions, giving tensors with the same layout as the original input sequence, **embedded_sentence** (words first, features second):

```python
multihead_keys=multihead_keys.permute(0,2,1)
multihead_values=multihead_values.permute(0,2,1)
print(multihead_keys.shape)
print(multihead_values.shape)
```

```
torch.Size([3, 13, 24])
torch.Size([3, 13, 28])
```
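After the swap, each head's slice has exactly the layout of the single-head `keys` from the first cell: one row per word. The sketch below checks this per head, and also that `permute(0, 2, 1)` on a 3-D tensor is the same operation as `transpose(1, 2)`. It uses random stand-ins (hypothetical) with the notebook's shapes; `keys_raw`, `same_as_transpose`, and `matches` are names introduced here.

```python
import torch

torch.manual_seed(0)
h, d, d_k, seq_len = 3, 16, 24, 13

multihead_W_key = torch.rand(h, d_k, d)     # hypothetical stand-in
embedded_sentence = torch.rand(seq_len, d)  # hypothetical stand-in

stacked_inputs = embedded_sentence.T.repeat(h, 1, 1)
keys_raw = torch.bmm(multihead_W_key, stacked_inputs)  # (h, d_k, seq_len)
multihead_keys = keys_raw.permute(0, 2, 1)             # (h, seq_len, d_k)

# For a 3-D tensor, permute(0, 2, 1) swaps the last two dims, like transpose(1, 2)
same_as_transpose = torch.equal(multihead_keys, keys_raw.transpose(1, 2))

# Head i's keys equal the single-head formula W_key.matmul(embedded_sentence.T).T
matches = [
    torch.allclose(multihead_keys[i], multihead_W_key[i].matmul(embedded_sentence.T).T)
    for i in range(h)
]

print(multihead_keys.shape)  # torch.Size([3, 13, 24])
print(same_as_transpose, matches)
```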


We follow the same steps as before: compute the unscaled attention weights $w$, apply the scaled softmax to obtain the attention weights $\alpha$, and use them to obtain an $h \times d_v$ (here $3 \times d_v$) dimensional context vector $z$ for the input element $x^{(2)}$.
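Those steps can be sketched as follows. This is our own illustration, not code from the notebook: it uses random stand-ins (hypothetical) with the same shapes as `multihead_query_2` (3 x 24), `multihead_keys` (3 x 13 x 24), and `multihead_values` (3 x 13 x 28), and the names `multihead_w_2`, `multihead_attention_weights_2`, `multihead_z_2`, and `z_2` are introduced here. Each head computes its own weights and context vector; the three context vectors are then concatenated into one vector of size 3 * 28 = 84.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
h, seq_len, d_k, d_v = 3, 13, 24, 28

# Hypothetical stand-ins with the same shapes as the notebook's tensors
multihead_query_2 = torch.rand(h, d_k)          # (h, d_q), with d_q = d_k
multihead_keys = torch.rand(h, seq_len, d_k)    # (h, seq_len, d_k)
multihead_values = torch.rand(h, seq_len, d_v)  # (h, seq_len, d_v)

# Unscaled attention weights w: one score per head and per input token
# bmm: (h, seq_len, d_k) x (h, d_k, 1) -> (h, seq_len, 1)
multihead_w_2 = torch.bmm(multihead_keys, multihead_query_2.unsqueeze(-1)).squeeze(-1)

# Scaled softmax over the tokens, separately for each head -> alpha, (h, seq_len)
multihead_attention_weights_2 = F.softmax(multihead_w_2 / d_k**0.5, dim=-1)

# Per-head context vectors: attention-weighted sum of that head's values
# bmm: (h, 1, seq_len) x (h, seq_len, d_v) -> (h, 1, d_v)
multihead_z_2 = torch.bmm(multihead_attention_weights_2.unsqueeze(1), multihead_values).squeeze(1)

# Concatenate the heads (head 0, then head 1, then head 2) into one h * d_v vector z
z_2 = multihead_z_2.flatten()
print(multihead_z_2.shape, z_2.shape)  # torch.Size([3, 28]) torch.Size([84])
```

Each row of `multihead_z_2` is what the single-head code from the first cell would produce with that head's query, keys, and values.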