
# Overview

Multi-head self-attention is a key component of the Transformer models used in Natural Language Processing (NLP). In a nutshell, it lets the model attend to different positions of the input sequence at the same time, capturing various aspects of the information. **Multi-Head** means that the model runs several attention **heads** in parallel, each of which can focus on different parts of the input, so that together they capture a richer range of information.
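
As a rough illustration of what "multiple heads" means at the tensor level, the sketch below splits an embedding into equally sized per-head slices. The sizes are small made-up values chosen for readability, not taken from any particular model.

```python
import torch

# Illustrative sizes (assumptions for this sketch only)
batch_size, seq_len, embed_size, num_heads = 2, 4, 8, 2
head_dim = embed_size // num_heads  # each head sees 4 of the 8 features

x = torch.randn(batch_size, seq_len, embed_size)

# Split the last dimension into (num_heads, head_dim)
x_heads = x.reshape(batch_size, seq_len, num_heads, head_dim)

# Head 0 holds the first head_dim features of every token, head 1 the rest
assert torch.equal(x_heads[:, :, 0, :], x[:, :, :head_dim])
assert torch.equal(x_heads[:, :, 1, :], x[:, :, head_dim:])

print(x_heads.shape)  # torch.Size([2, 4, 2, 4])
```

Each head then computes its own attention weights over the sequence using only its slice of the features.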

# Implement it with PyTorch

In [1]:
import torch  # needed for torch.einsum and torch.softmax in forward()
from torch import nn

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_size, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_size=embed_size
        self.num_heads=num_heads
        self.head_dim=embed_size//num_heads
        
        assert(self.head_dim * num_heads==embed_size), "Embedding size needs to be divisible by num_heads"
        
        self.values=nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys=nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries=nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out=nn.Linear(num_heads*self.head_dim, embed_size)
        
        self.layer_norm=nn.LayerNorm(embed_size) # Layer normalization
    
    def forward(self, values, keys, query, mask):
        N=query.shape[0]
        value_len,key_len,query_len=values.shape[1], keys.shape[1],query.shape[1]
        
        # Splitting the embedding into self.num_heads different pieces
        values=values.reshape(N, value_len, self.num_heads,self.head_dim)
        keys=keys.reshape(N, key_len,self.num_heads, self.head_dim)
        query=query.reshape(N, query_len, self.num_heads, self.head_dim)
        
        values=self.values(values)
        keys=self.keys(keys)
        queries=self.queries(query)
        
        # Get the dot product between queries and keys, and apply mask
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        if mask is not None:
            energy =energy.masked_fill(mask==0, float("-1e20"))
        # Scale by sqrt(head_dim), the per-head key dimension, then softmax over the key axis
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)
        
        out=torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, query_len, self.num_heads*self.head_dim)
        
        out=self.fc_out(out)
        
        # Applying layer normalization
        out=self.layer_norm(out)
        return out

# Testing the Multi-Head Self Attention

In [2]:
import torch

def test_multi_head_self_attention():
    batch_size=64
    sequence_length=100
    embed_size=512
    num_heads=8
    
    # Create a MultiHeadSelfAttention instance
    attention = MultiHeadSelfAttention(embed_size, num_heads)
    
    # Create some random data to use as input
    values=torch.randn(batch_size, sequence_length, embed_size)
    keys=torch.randn(batch_size, sequence_length, embed_size)
    query=torch.randn(batch_size, sequence_length, embed_size)
    
    # Create a random mask
    mask=torch.randint(0,2,(batch_size,1,1,sequence_length)).to(torch.bool)
    
    # Pass the data through the attention mechanism
    out=attention(values, keys, query, mask)
    
    # Check that the output has the right shape
    assert out.shape==(batch_size, sequence_length, embed_size),"Output shape is incorrect"
    
test_multi_head_self_attention()
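
To see what the masking step in `forward` does to the attention weights, here is a minimal sketch on a single row of scores. The score values are hypothetical; only the `masked_fill` and `softmax` calls mirror the class above.

```python
import torch

# Hypothetical attention scores for one query over four key positions
energy = torch.tensor([[2.0, 1.0, 0.5, 3.0]])

# True = position may be attended to, False = position is masked out
mask = torch.tensor([[True, True, False, False]])

# Same masking step as in forward(): masked scores become very negative
masked_energy = energy.masked_fill(mask == 0, float("-1e20"))
weights = torch.softmax(masked_energy, dim=-1)

# Masked positions end up with zero weight; the rest still sum to 1
print(weights)
```

Because the masked scores are pushed to a very large negative value, their exponentials underflow to zero and the remaining positions share all of the probability mass.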

# A Linear Transformation

A linear transformation typically refers to transforming the input data with a set of weights and, optionally, biases. It is often written mathematically as `y=Wx+b`, where `W` is the weight matrix, `x` is the input data, `b` is the bias, and `y` is the output data.

In [3]:
# Define a linear layer with 5 input features and 3 output features
linear_layer = nn.Linear(in_features=5, out_features=3)

# Now the linear layer can be used with input of size [batch_size, num_features]
input_tensor = torch.randn(10, 5)
output = linear_layer(input_tensor)

print(input_tensor)
print(output)

tensor([[ 0.2439, -0.1287, -0.0315, -1.3184,  0.7279],
        [-0.8809,  0.8527, -1.2368,  1.0713,  1.9605],
        [-1.9387, -1.7888, -0.0498,  0.7842, -0.8771],
        [ 0.0944, -2.0787,  1.4073,  0.5511, -0.1695],
        [-1.3802,  0.5898, -0.8544,  0.7628, -1.7508],
        [-0.0255, -0.9585, -0.1014, -0.0206,  0.1431],
        [-0.0028, -1.9332, -0.9490,  1.3220,  0.7912],
        [-1.2679, -1.9436, -0.7273, -1.1826,  0.7140],
        [-0.0336,  0.6147, -0.6111, -0.8465,  0.2832],
        [ 0.3057, -0.8951,  1.0789,  0.1866, -1.6518]])
tensor([[-0.4170, -0.4357,  0.6642],
        [ 0.0235, -0.9622, -0.1988],
        [ 0.6236,  0.8460, -0.0572],
        [-0.0265,  0.6817,  0.1711],
        [ 1.4543, -0.1226,  0.3021],
        [ 0.0364, -0.0507,  0.3718],
        [ 0.0096, -0.0601, -0.0108],
        [-0.5574,  0.3546,  0.3470],
        [ 0.0921, -0.6850,  0.6240],
        [ 0.7344,  0.3885,  0.5738]], grad_fn=<AddmmBackward0>)


The weights and bias of the linear layer are automatically initialized and kept as parameters of the layer. We can access them as follows:

In [4]:
print(linear_layer.weight)
print(linear_layer.bias)

Parameter containing:
tensor([[-0.0537,  0.1790, -0.0896,  0.2578, -0.4216],
        [-0.2518, -0.3187,  0.2433, -0.0111, -0.1438],
        [ 0.1896,  0.0263, -0.0685, -0.2402, -0.1494]], requires_grad=True)
Parameter containing:
tensor([ 0.2631, -0.3176,  0.4113], requires_grad=True)
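
As a quick sanity check of the `y=Wx+b` formula, the sketch below recomputes the layer output by hand. Note that `nn.Linear` stores its weight with shape `[out_features, in_features]`, so for a batch of row vectors the computation is `x @ W.T + b`. The fixed seed and the variable names are choices made for this example only.

```python
import torch
from torch import nn

torch.manual_seed(0)  # fixed seed only so the sketch is repeatable

layer = nn.Linear(in_features=5, out_features=3)
x = torch.randn(10, 5)

# Weight has shape [out_features, in_features], so each input row
# is transformed as x @ W^T + b
manual = x @ layer.weight.T + layer.bias

print(layer.weight.shape)  # torch.Size([3, 5])
print(torch.allclose(layer(x), manual, atol=1e-6))  # True
```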


# `torch.einsum` function

`torch.einsum` is a very powerful function that lets you perform operations on tensors in a flexible way. The string argument `"nqhd,nkhd->nhqk"` passed to `einsum` specifies the operation in a compact way.

* "nqhd,nkhd": This part before the -> describes the dimensions of the input tensors.
* "->nhqk": This part after the -> describes the dimensions of the output tensor.

The operation performed by `einsum` in this case is a sum-product over the shared dimension $d$ (the last dimension of both input tensors), which appears in the inputs but not in the output. This is equivalent to calculating, for every batch element $n$ and head $h$, the dot product between each query and each key along the $d$ dimension, resulting in a new tensor with dimensions represented by `nhqk`.
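
One way to convince yourself of what this `einsum` string computes is to compare it with an explicit permute-and-matmul version. The following is a minimal sketch with small, arbitrary sizes; the variable names are chosen for this example.

```python
import torch

N, q_len, k_len, heads, d = 2, 3, 4, 2, 5  # small illustrative sizes
queries = torch.randn(N, q_len, heads, d)
keys = torch.randn(N, k_len, heads, d)

# einsum form: d appears in both inputs but not in the output, so it is summed
energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

# Equivalent explicit form: move heads next to the batch dim, then Q @ K^T
q = queries.permute(0, 2, 1, 3)  # (N, heads, q_len, d)
k = keys.permute(0, 2, 3, 1)     # (N, heads, d, k_len)
energy_matmul = q @ k            # (N, heads, q_len, k_len)

print(energy.shape)  # torch.Size([2, 2, 3, 4])
print(torch.allclose(energy, energy_matmul, atol=1e-5))  # True
```

The `einsum` version avoids the explicit permutes, which is why it is convenient when the head dimension sits in the middle of the tensor.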


# The Softmax Function

The softmax function converts a vector of real numbers into a probability distribution. That is, after applying softmax, each element of the output vector lies in the range (0, 1), and the elements sum to 1. The softmax function is defined as follows:

$$\mathrm{softmax}(x_i) = \frac{\exp(x_i)}{\sum_{j} \exp(x_j)}$$

Where:

* $x_i$ is the i-th element of the input vector
* $\exp$ is the exponential function
* The denominator $\sum_{j} \exp(x_j)$ is the sum of the exponentials of all elements $x_j$ in the input vector

In [5]:
import torch.nn.functional as F

# a tensor
x=torch.tensor([1.0,2.0,3.0])

# apply softmax
y=F.softmax(x, dim=0)

print(y)

tensor([0.0900, 0.2447, 0.6652])
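
The formula can also be written out by hand and compared against `F.softmax`. The sketch below does that, and additionally shows the common max-subtraction trick, which gives the same result while avoiding overflow for large inputs. The trick is an addition for illustration and is not required by the definition.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([1.0, 2.0, 3.0])

# Direct translation of the formula: exp(x_i) / sum_j exp(x_j)
manual = torch.exp(x) / torch.exp(x).sum()

# Subtracting the max first leaves the result unchanged
# but keeps the exponentials from overflowing
shifted = torch.exp(x - x.max())
stable = shifted / shifted.sum()

print(torch.allclose(manual, F.softmax(x, dim=0)))  # True
print(torch.allclose(stable, F.softmax(x, dim=0)))  # True
```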
