In [2]:
#necessary libraries to build a transformer model 
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

In [6]:
# We define a class called MultiHeadAttention that implements the
# multi-head attention mechanism as a PyTorch module.

# It is a subclass of nn.Module.
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The query, key and value inputs are each passed through a learned linear
    projection, split into ``num_heads`` heads of size ``d_k``, attended
    per head, then merged back and passed through an output projection.

    Args:
        d_model: Size of the last dimension of the inputs and of the output.
        num_heads: Number of parallel attention heads. Must divide
            ``d_model`` evenly.

    Raises:
        AssertionError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        # Initialise nn.Module so the Linear layers below are registered
        # as sub-modules (and their weights as parameters).
        super().__init__()

        # Each head gets an equal slice of the model dimension, so the
        # model dimension has to split evenly across the heads.
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head feature size

        # Projections for all heads are packed into one d_model x d_model
        # layer each; split_heads() later carves the result into heads.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        # Output projection applied after the heads are merged again.
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute softmax(Q @ K^T / sqrt(d_k)) @ V.

        Args:
            Q: Query tensor whose last dimension is ``d_k``.
            K: Key tensor; its last two dimensions are transposed before
                the matrix product with ``Q``.
            V: Value tensor, multiplied by the attention probabilities.
            mask: Optional tensor broadcastable to the attention-score
                shape. Positions where ``mask == 0`` are excluded from
                attention; all other positions are kept. ``None`` (the
                default) means no masking.

        Returns:
            The attention-weighted combination of ``V``.
        """
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # A large negative score becomes (effectively) zero probability
            # after the softmax, so masked positions are ignored.
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)

        # The softmax must run whether or not a mask was supplied; keeping
        # it inside the `if` left attn_probs undefined when mask is None.
        # Normalising over the last dimension makes each row sum to 1.
        attn_probs = torch.softmax(attn_scores, dim=-1)

        return torch.matmul(attn_probs, V)

    def split_heads(self, x):
        """Reshape (batch, seq, d_model) into (batch, num_heads, seq, d_k)."""
        batch_size, seq_length, _ = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """Inverse of split_heads: (batch, num_heads, seq, d_k) -> (batch, seq, d_model)."""
        batch_size, _, seq_length, _ = x.size()
        # transpose() returns a non-contiguous view; view() needs
        # contiguous memory, hence the contiguous() call in between.
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """Apply multi-head attention.

        Args:
            Q: Query input of shape (batch, seq, d_model).
            K: Key input of shape (batch, seq, d_model).
            V: Value input of shape (batch, seq, d_model).
            mask: Optional mask forwarded unchanged to
                ``scaled_dot_product_attention``.

        Returns:
            Tensor of shape (batch, query seq length, d_model).
        """
        # Project each input into its query/key/value space, then split the
        # projected features across the heads.
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        # Attend within every head, merge the heads back together and mix
        # them with the output projection.
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return output
        

`self.d_model = d_model` stores the constructor argument `d_model` in the instance variable `self.d_model`. This instance variable represents the dimensionality of the input and output of the multi-head attention mechanism.

`self.num_heads = num_heads` stores the constructor argument `num_heads` in the instance variable `self.num_heads`. This instance variable represents the number of parallel attention heads used in the multi-head attention mechanism.

self.d_k = d_model // num_heads calculates the dimensionality 'd_k' for each attention head. It is derived by dividing the d_model by the num_heads using the floor division operator (//). This calculation ensures that each attention head has an equal dimensionality, as the total d_model is distributed evenly among the num_heads.

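The d_k value is important because, in the multi-head attention mechanism, the input tensor is linearly projected into different attention heads. The projection is performed by multiplying the input tensor with projection weight matrices specific to each attention head. In this implementation the per-head weight matrices are packed into a single nn.Linear(d_model, d_model) layer, and its output is then split into num_heads chunks of size d_k. The d_k value represents the size of the projected space for each attention head.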

self.W_q = nn.Linear(d_model, d_model): This line creates an instance of the nn.Linear class and assigns it to the self.W_q instance variable. 
nn.Linear represents a linear transformation that maps an input tensor of size d_model to an output tensor of size d_model. 
self.W_q will be responsible for projecting the input tensor to the query space.

self.W_k = nn.Linear(d_model, d_model): self.W_k will be responsible for projecting the input tensor to the key space.

self.W_v = nn.Linear(d_model, d_model): self.W_v will be responsible for projecting the input tensor to the value space.

**For Query, Key and Value, refer to the paper "Attention Is All You Need" (Vaswani et al., 2017).**

self.W_o = nn.Linear(d_model, d_model): self.W_o will be responsible for transforming the concatenated outputs of the attention heads back to the original d_model dimension.