In [1]:
# In this notebook, you learn:
#
# 1) What is Attention?
# 2) What is Multi Headed Attention?

In [2]:
# Resources to go through before continuing with this notebook:
#
# 1) https://www.youtube.com/watch?v=ySEx_Bqxvvo&t=827s
#       -- First discusses RNNs and then moved to Attention Mechanism.
#       -- MUST WATCH -- MUST WATCH -- MUST WATCH -- MUST WATCH -- MUST WATCH
# 2) https://jalammar.github.io/illustrated-transformer/
#       -- Explains the entire transformer model. We only need to focus on attention and 
#          Multi Headed Attention for this notebook.
# 3) https://peterbloem.nl/blog/transformers
#       -- Explains the transformer model with a bit more detail about the implementation.
#       -- Also, explains how the multi headed attention can be implemented using a single
#          set of matrices for Q, K, and V instead of creating h (Number of heads) matrices. 
#          This can be quite confusing and needs time to understand.
# 4) https://www.youtube.com/watch?v=wjZofJX0v4M
#       -- Gives good intuition about all the basics needed to understand Attention. In
#          particular, the intuition for word embeddings and how they work is very good.
#       -- However, this video is fast paced and it helps if you already have some idea
#          about the concepts from the other resources above.
# 5) https://drive.google.com/file/d/1465IgQ-jYDLFqfBZJ2NcyHwqfeJtkRmF/view?usp=drive_link
#       -- Shows visually how the input is transformed by the Multi-Headed Attention layer to
#          generate the output.
#       -- Please go through this resource before continuing with the notebook.

In [3]:
# NOTE: THIS IS A LONG NOTEBOOK. PLEASE TAKE YOUR TIME TO UNDERSTAND THE CONCEPTS.

In [4]:
from torch import nn, Tensor
from typing import Optional, Tuple

import copy
import math
import torch

In [5]:
# Generating input to experiment with the Multi-Headed Attention
def generate_batch_of_input_data(batch_size: int, seq_len: int, d_model: int) -> Tensor:
    """Draws a random batch of token embeddings from a standard normal distribution.

    Args:
        batch_size (int): Number of sequences in the batch.
        seq_len (int): Number of tokens in every sequence.
        d_model (int): Size of the embedding vector of every token.

    Returns:
        Tensor: Random tensor of shape [batch_size, seq_len, d_model].
    """
    batch_shape = (batch_size, seq_len, d_model)
    return torch.randn(batch_shape)

In [6]:
# Seed PyTorch's global random number generator so that everything random in this notebook
# (the input below, the weights of the linear layers and the random mask created later)
# comes out the same on every 'Restart Kernel -> Run All'.
torch.manual_seed(0)
# Input contains 2 sequences of length 3 (3 tokens) and each token embedding in the sequence is of size 8.
# In the actual model, d_model (the size of the token embeddings) is 512.
# NOTE: The name 'input' shadows the Python builtin 'input'. It is kept because the rest of
#       the notebook refers to this tensor by that name.
input = generate_batch_of_input_data(batch_size=2, seq_len=3, d_model=8)
print("shape: ", input.shape)
print("input: ", input)

shape:  torch.Size([2, 3, 8])
input:  tensor([[[-1.0202, -0.2347,  0.3195,  0.3942,  0.4593, -2.6094, -0.5602,
          -0.9243],
         [-0.6834,  0.0312, -1.2677,  0.4416, -0.3943, -0.6785, -1.0569,
          -0.1173],
         [ 1.4697,  0.1980,  0.3201,  0.2304, -1.1114, -0.3340,  1.0294,
          -0.4959]],

        [[-1.1367,  0.6126,  0.4391,  0.3038,  1.5510, -0.0983, -1.5240,
          -0.3695],
         [-0.6542,  0.6223, -0.5636, -1.5257, -1.9707, -0.6709,  0.3991,
           0.8756],
         [ 0.1549,  1.3908, -1.0694,  1.3774, -0.1948, -0.3535,  0.4294,
           0.4716]]])


In [7]:
# Hyperparameters used by every cell below. d_model, seq_len and batch_size must match the
# values that were used to create 'input' above (batch_size=2, seq_len=3, d_model=8).
#
# We are using 2 heads in the Multi-Headed Attention.
num_heads: int = 2
# The size of the word embeddings in the input.
d_model: int = 8
# This is the size of the query, key and value vectors for a token in a single head.
# d_model is split evenly across the heads, so d_model must be divisible by num_heads.
d_k: int = d_model // num_heads
print("d_k: ", d_k)
# This is the number of tokens per sentence.
seq_len = 3    # Referred to as 'n{t}' in the 'Input_Transformation_In_Multi_Headed_Attention.pdf'.
# This is the batch size i.e., the input batch contains 2 sentences.
batch_size = 2

d_k:  4


In [8]:
# Refer to 'understanding_nn_linear.ipynb' notebook (Add link to the notebook) to learn about how Linear Layer works.
# Each layer maps a d_model sized token vector to another d_model sized vector. The three
# layers are independent, so queries, keys and values get their own weights.
#
# A linear layer to generate the query vectors for each token in the sequence.
query_creator = nn.Linear(in_features=d_model, out_features=d_model, bias=True)
print("query_creator: ", query_creator)
# A linear layer to generate the key vectors for each token in the sequence.
key_creator = nn.Linear(in_features=d_model, out_features=d_model, bias=True)
print("key_creator: ", key_creator)
# A linear layer to generate the value vectors for each token in the sequence.
value_creator = nn.Linear(in_features=d_model, out_features=d_model, bias=True)
print("value_creator: ", value_creator)

query_creator:  Linear(in_features=8, out_features=8, bias=True)
key_creator:  Linear(in_features=8, out_features=8, bias=True)
value_creator:  Linear(in_features=8, out_features=8, bias=True)


In [9]:
# A single linear layer (a single matrix multiplication) is going to generate the queries for all the heads.
#
# input: [batch_size, seq_len, d_model]
# queries: [batch_size, seq_len, d_model]
# 
# Lets run through a toy example to explain the query creation.
#
# For the toy example:
# d_model   = 4
# num_heads = 2
# d_k       = d_model // num_heads = 2
# seq_len  = 3
#
# Lets ignore the batch dimension for now. 
# 
# The input will look something like this:
# input = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]] 
# [1, 2, 3, 4] --> input vector for first token in the sequence.
# [5, 6, 7, 8] --> input vector for second token in the sequence.
# [9, 10, 11, 12] --> input vector for third token in the sequence.
# 
# Assume query_creator is an identity function for this example i.e., Q = I (Identity Matrix). So, the input is same as output.
# The output of query_creator will look something like this:
# output = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
# 
# This output contains the queries for all the heads for all the tokens in the sequence.
# [1, 2, 3, 4] --> [1, 2] | [3, 4] --> Logical division of the query vector for the first token.
# [1, 2] --> query vector for the first token in the first head.
# [3, 4] --> query vector for the first token in the second head.
#
# [5, 6, 7, 8] --> [5, 6] | [7, 8] --> Logical division of the query vector for the second token.
# [5, 6] --> query vector for the second token in the first head.
# [7, 8] --> query vector for the second token in the second head.
#
# [9, 10, 11, 12] --> [9, 10] | [11, 12] --> Logical division of the query vector for the third token.
# [9, 10] --> query vector for the third token in the first head.
# [11, 12] --> query vector for the third token in the second head.
#
# We basically are reducing the size of the query vector for each head so that the overall size is
# still the same as d_model.
#
# The above example is for a single sequence. The actual input will have a batch dimension as well.
# The exact same process is repeated for every sequence (or 2D matrix) in the batch.
# One matrix multiplication creates the queries of all the heads at once.
# queries: [batch_size, seq_len, d_model]
queries = query_creator(input)
print("shape: ", queries.shape)
print("queries: ", queries)

shape:  torch.Size([2, 3, 8])
queries:  tensor([[[-0.2101, -0.3526,  0.2313, -0.2076, -0.6798, -0.0843, -0.4807,
           0.4135],
         [ 0.3132, -0.7180, -0.4983, -0.1015, -0.0967,  0.1530,  0.2631,
          -0.0825],
         [ 0.0373,  0.0739,  0.2784,  0.5304,  0.2237, -0.7869, -0.9946,
           0.8966]],

        [[ 0.8011, -0.2928,  0.1292, -0.6189, -0.1394,  0.5197,  0.2627,
           0.4689],
         [ 0.4967, -0.3740,  0.0132, -0.6353, -0.8985,  0.1501, -0.5515,
          -0.3980],
         [ 0.2963, -0.3409, -0.2049, -0.0806,  0.6988,  0.1840, -0.6014,
           0.3676]]], grad_fn=<ViewBackward0>)


In [10]:
# Key and Value creation is similar to Query creation. We just use a different linear layer for each i.e.,
# key_creator for keys creation and value_creator for values creation respectively.
#
# keys  : [batch_size, seq_len, d_model]
# values: [batch_size, seq_len, d_model]
keys = key_creator(input)
print("shape: ", keys.shape)
print("keys: ", keys)
print("-" * 150)

values = value_creator(input)
print("shape: ", values.shape)
print("values: ", values)

shape:  torch.Size([2, 3, 8])
keys:  tensor([[[-0.1720, -0.0974, -0.4074, -0.3137, -0.7831, -1.3113,  0.1043,
          -0.3691],
         [-0.3618, -0.0844, -0.5938, -0.0148, -0.2940, -0.6218,  0.5029,
          -0.4131],
         [-0.1265,  0.2863,  0.0526,  0.3498, -0.4518, -0.1813, -0.0158,
          -0.1033]],

        [[-0.0239, -1.1511, -0.4069, -1.3810,  0.2647, -0.6113, -0.0314,
           0.5024],
         [-1.4270,  0.5460,  0.1839,  0.6636, -0.1671, -0.5133,  0.8250,
          -0.3704],
         [-0.1410, -0.6418, -0.3549, -0.5290, -0.7715, -0.4057,  0.6064,
           0.0792]]], grad_fn=<ViewBackward0>)
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([2, 3, 8])
values:  tensor([[[ 0.5662,  1.2474,  0.6110,  0.3846,  0.7495,  0.0327, -0.5575,
          -0.4441],
         [ 0.3909,  0.2873,  0.0760,  0.7672,  0.7780,  0.0809, -0.5373,
          -0.4977],


In [11]:
# Refer to 'understanding_tensor_manipulations_part_6.ipynb' (Add link to the notebook) to 
# understand more about the 'view' and 'reshape' operations.
#
# Now, let's separate the queries, keys and values for each head. We already have a 3D tensor.
# Now, we will obtain a 4D tensor, where each 3D tensor holds the queries for all the heads for a
# single sequence (or sentence).
#
# queries             : [batch_size, seq_len, d_model]
# transformed_queries : [batch_size, seq_len, num_heads, d_k]
#
# It is important to know how the elements are rearranged by this operation. So, please spend some time
# to understand the 'view' operation in detail. This will keep coming up in the Multi-Headed Attention.
#
# NOTE: 'queries' is overwritten with its 4D version here, so the 3D tensor from the previous
#       cell is no longer available under this name after this cell runs.
queries = queries.view(batch_size, seq_len, num_heads, d_k)
print("shape: ", queries.shape)
print("queries: ", queries)

shape:  torch.Size([2, 3, 2, 4])
queries:  tensor([[[[-0.2101, -0.3526,  0.2313, -0.2076],
          [-0.6798, -0.0843, -0.4807,  0.4135]],

         [[ 0.3132, -0.7180, -0.4983, -0.1015],
          [-0.0967,  0.1530,  0.2631, -0.0825]],

         [[ 0.0373,  0.0739,  0.2784,  0.5304],
          [ 0.2237, -0.7869, -0.9946,  0.8966]]],


        [[[ 0.8011, -0.2928,  0.1292, -0.6189],
          [-0.1394,  0.5197,  0.2627,  0.4689]],

         [[ 0.4967, -0.3740,  0.0132, -0.6353],
          [-0.8985,  0.1501, -0.5515, -0.3980]],

         [[ 0.2963, -0.3409, -0.2049, -0.0806],
          [ 0.6988,  0.1840, -0.6014,  0.3676]]]], grad_fn=<ViewBackward0>)


In [12]:
# Keys and Values are also transformed in the same way as queries.
#
# keys, values (before): [batch_size, seq_len, d_model]
# keys, values (after) : [batch_size, seq_len, num_heads, d_k]
keys = keys.view(batch_size, seq_len, num_heads, d_k)
print("shape: ", keys.shape)
print("keys: ", keys)
print("-" * 150)

values = values.view(batch_size, seq_len, num_heads, d_k)
print("shape: ", values.shape)
print("values: ", values)

shape:  torch.Size([2, 3, 2, 4])
keys:  tensor([[[[-0.1720, -0.0974, -0.4074, -0.3137],
          [-0.7831, -1.3113,  0.1043, -0.3691]],

         [[-0.3618, -0.0844, -0.5938, -0.0148],
          [-0.2940, -0.6218,  0.5029, -0.4131]],

         [[-0.1265,  0.2863,  0.0526,  0.3498],
          [-0.4518, -0.1813, -0.0158, -0.1033]]],


        [[[-0.0239, -1.1511, -0.4069, -1.3810],
          [ 0.2647, -0.6113, -0.0314,  0.5024]],

         [[-1.4270,  0.5460,  0.1839,  0.6636],
          [-0.1671, -0.5133,  0.8250, -0.3704]],

         [[-0.1410, -0.6418, -0.3549, -0.5290],
          [-0.7715, -0.4057,  0.6064,  0.0792]]]], grad_fn=<ViewBackward0>)
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([2, 3, 2, 4])
values:  tensor([[[[ 0.5662,  1.2474,  0.6110,  0.3846],
          [ 0.7495,  0.0327, -0.5575, -0.4441]],

         [[ 0.3909,  0.2873,  0.0760,  0.7672],
     

In [13]:
# Refer to 'understanding_tensor_manipulations_part_6.ipynb' (link to the notebook) to understand 
# more about the transpose operation.
#
# queries           : [batch_size, seq_len, num_heads, d_k]
# transposed_queries: [batch_size, num_heads, seq_len, d_k]
#
# Matrix multiplication is performed (by Pytorch) on the last two dimensions of the tensors. Within
# every head, each token has to be scored against every token of the sentence (including itself),
# i.e., the query vector of a token is dot-multiplied with the key vectors of all the tokens. For this
# to work, the last two dimensions must be [seq_len, d_k]: tokens iterate on dimension 2 and the
# elements of each token vector iterate on dimension 3. This is why we swap dimensions 1 and 2 of all
# the tensors here.
queries = queries.transpose(dim0=1, dim1=2)
print("shape: ", queries.shape)
print("queries: ", queries)

shape:  torch.Size([2, 2, 3, 4])
queries:  tensor([[[[-0.2101, -0.3526,  0.2313, -0.2076],
          [ 0.3132, -0.7180, -0.4983, -0.1015],
          [ 0.0373,  0.0739,  0.2784,  0.5304]],

         [[-0.6798, -0.0843, -0.4807,  0.4135],
          [-0.0967,  0.1530,  0.2631, -0.0825],
          [ 0.2237, -0.7869, -0.9946,  0.8966]]],


        [[[ 0.8011, -0.2928,  0.1292, -0.6189],
          [ 0.4967, -0.3740,  0.0132, -0.6353],
          [ 0.2963, -0.3409, -0.2049, -0.0806]],

         [[-0.1394,  0.5197,  0.2627,  0.4689],
          [-0.8985,  0.1501, -0.5515, -0.3980],
          [ 0.6988,  0.1840, -0.6014,  0.3676]]]],
       grad_fn=<TransposeBackward0>)


In [14]:
# Keys and Values are also transformed in the same way as queries.
#
# keys, values (before): [batch_size, seq_len, num_heads, d_k]
# keys, values (after) : [batch_size, num_heads, seq_len, d_k]
keys = keys.transpose(dim0=1, dim1=2)
print("shape: ", keys.shape)
print("keys: ", keys)
print("-" * 150)

values = values.transpose(dim0=1, dim1=2)
print("shape: ", values.shape)
print("values: ", values)

shape:  torch.Size([2, 2, 3, 4])
keys:  tensor([[[[-0.1720, -0.0974, -0.4074, -0.3137],
          [-0.3618, -0.0844, -0.5938, -0.0148],
          [-0.1265,  0.2863,  0.0526,  0.3498]],

         [[-0.7831, -1.3113,  0.1043, -0.3691],
          [-0.2940, -0.6218,  0.5029, -0.4131],
          [-0.4518, -0.1813, -0.0158, -0.1033]]],


        [[[-0.0239, -1.1511, -0.4069, -1.3810],
          [-1.4270,  0.5460,  0.1839,  0.6636],
          [-0.1410, -0.6418, -0.3549, -0.5290]],

         [[ 0.2647, -0.6113, -0.0314,  0.5024],
          [-0.1671, -0.5133,  0.8250, -0.3704],
          [-0.7715, -0.4057,  0.6064,  0.0792]]]],
       grad_fn=<TransposeBackward0>)
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([2, 2, 3, 4])
values:  tensor([[[[ 0.5662,  1.2474,  0.6110,  0.3846],
          [ 0.3909,  0.2873,  0.0760,  0.7672],
          [-0.3288, -0.8719,  1.3444, -0.0021]]

In [15]:
# We now have queries, keys and values for each head in the Multi-Headed Attention. The next step is to 
# calculate the attention scores for each token within each head.

In [16]:
# queries: [batch_size, num_heads, seq_len, d_k]
# keys   : [batch_size, num_heads, seq_len, d_k]
# 
# To calculate the attention scores, we need to compute [queries * keys^{Transpose}] for each sentence. 
# * here represents matrix multiplication. So, we transpose keys and then perform the matrix multiplication.
#
# transposed_keys: [batch_size, num_heads, d_k, seq_len]
#
# NOTE: 'keys' is overwritten with its transposed version. Running this cell a second time
#       swaps the last two dimensions back, so run it exactly once.
keys = keys.transpose(dim0=2, dim1=3)
print("shape: ", keys.shape)
print("keys: ", keys)

shape:  torch.Size([2, 2, 4, 3])
keys:  tensor([[[[-0.1720, -0.3618, -0.1265],
          [-0.0974, -0.0844,  0.2863],
          [-0.4074, -0.5938,  0.0526],
          [-0.3137, -0.0148,  0.3498]],

         [[-0.7831, -0.2940, -0.4518],
          [-1.3113, -0.6218, -0.1813],
          [ 0.1043,  0.5029, -0.0158],
          [-0.3691, -0.4131, -0.1033]]],


        [[[-0.0239, -1.4270, -0.1410],
          [-1.1511,  0.5460, -0.6418],
          [-0.4069,  0.1839, -0.3549],
          [-1.3810,  0.6636, -0.5290]],

         [[ 0.2647, -0.1671, -0.7715],
          [-0.6113, -0.5133, -0.4057],
          [-0.0314,  0.8250,  0.6064],
          [ 0.5024, -0.3704,  0.0792]]]], grad_fn=<TransposeBackward0>)


In [17]:
# queries: [batch_size, num_heads, seq_len, d_k]
# keys   : [batch_size, num_heads, d_k, seq_len]
#
# We scale the scores by the square root of the dimension of the vectors (d_k) to ensure that the scores
# don't grow too large in magnitude. We later apply softmax on these scores to normalize them. If the
# scores going into softmax are very large in magnitude, softmax saturates (one weight is close to 1 and
# the rest are close to 0) and the gradients flowing back through it become very small. To avoid this,
# we scale the scores to bring them down.
#
# attention_scores: [batch_size, num_heads, seq_len, seq_len]
#
# Let's ignore the batch dimension for now.
# Each 2D matrix in the attention_scores tensor will contain the attention scores for each token
# in the sequence with every token in the sequence (including itself).
# Example: 
# [0.0207,  -0.0143,  -0.0674] --> Attention scores for the first token in the sequence.
# 
# 0.0207  --> Attention score for the first token with the first token.
# -0.0143 --> Attention score for the first token with the second token.
# -0.0674 --> Attention score for the first token with the third token.
# 
# Of course, the numbers will be different if the input or the linear layers are created again.
attention_scores = torch.matmul(queries, keys) / math.sqrt(d_k)
print("shape: ", attention_scores.shape)
print("attention_scores: \n", attention_scores)

shape:  torch.Size([2, 2, 3, 3])
attention_scores: 
 tensor([[[[ 0.0207, -0.0143, -0.0674],
          [ 0.1255,  0.1223, -0.1535],
          [-0.1467, -0.0965,  0.1083]],

         [[ 0.2201, -0.0801,  0.1437],
          [-0.0335,  0.0499,  0.0102],
          [ 0.2111, -0.2235, -0.0176]]],


        [[[ 0.5600, -0.8450,  0.1783],
          [ 0.6453, -0.6661,  0.2507],
          [ 0.2900, -0.3500,  0.1462]],

         [[-0.0637, -0.1002,  0.0466],
          [-0.2561, -0.1172,  0.1331],
          [ 0.1381, -0.4218, -0.4746]]]], grad_fn=<DivBackward0>)


In [18]:
# You can skip this cell (and the next one) and come back to this after completing the entire 
# notebook by ignoring masks. We mask the attention_scores if a mask is provided.
# We have 1 mask for 1 sequence in the batch. We will create a random mask for now for experimentation.
# Please note that this random mask might not even follow all the rules that the actual mask should 
# follow (for example, it can mask every position in a row). Refer to
# 'step_5_data_batching_and_masking.ipynb' to understand more about creating masks for the
# transformer inputs.
#
# attention_scores: [batch_size, num_heads, seq_len, seq_len]
# So, we also need the mask to have 4 dimensions. Every head uses the same mask since it is 
# essentially the same sentence being used in all the heads. This is why the mask has size 1 on
# the head dimension: [batch_size, 1, seq_len, seq_len].
# Remember that 'True' in 'mask_random' means the value should not be masked and 'False' means the 
# value should be masked.
#
# The integers are drawn from [-50, 50), so comparing with 0 gives a boolean tensor in which
# roughly half of the entries are True.
mask_random = torch.randint(low=-50, high=50, size=(batch_size, 1, seq_len, seq_len)) < 0
print("shape: ", mask_random.shape)
print("mask_random: \n", mask_random)

shape:  torch.Size([2, 1, 3, 3])
mask_random: 
 tensor([[[[False,  True,  True],
          [ True, False, False],
          [ True,  True, False]]],


        [[[ True, False,  True],
          [False, False, False],
          [ True, False, False]]]])


In [19]:
# One mask per sentence is shared by all of its heads. Every position where 'mask_random'
# is False is overwritten with a very large negative score (-1e9) in attention_scores, so
# that softmax gives it (practically) zero weight in the next step.
# Do not use '-inf' as the fill value: it causes softmax to return 'nan' values in some
# cases --> You can try setting it to '-inf' and see what happens.
attention_scores = attention_scores.masked_fill(~mask_random, -1e9)
print("shape: ", attention_scores.shape)
print("attention_scores: \n", attention_scores)

shape:  torch.Size([2, 2, 3, 3])
attention_scores: 
 tensor([[[[-1.0000e+09, -1.4259e-02, -6.7406e-02],
          [ 1.2546e-01, -1.0000e+09, -1.0000e+09],
          [-1.4671e-01, -9.6459e-02, -1.0000e+09]],

         [[-1.0000e+09, -8.0123e-02,  1.4368e-01],
          [-3.3508e-02, -1.0000e+09, -1.0000e+09],
          [ 2.1106e-01, -2.2352e-01, -1.0000e+09]]],


        [[[ 5.6001e-01, -1.0000e+09,  1.7827e-01],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09],
          [ 2.8996e-01, -1.0000e+09, -1.0000e+09]],

         [[-6.3654e-02, -1.0000e+09,  4.6578e-02],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09],
          [ 1.3805e-01, -1.0000e+09, -1.0000e+09]]]],
       grad_fn=<MaskedFillBackward0>)


In [20]:
# The attention scores for the tokens which have been masked are set to zero by the softmax function.
# When we say set to zero, we mean the value is so small that it is practically zero. This means 
# these tokens won't contribute to the value calculation.
#
# Exception: if every position in a row is masked (possible with our random mask), all the scores
# in that row are equal (-1e9) and softmax returns equal weights (0.3333 each for 3 tokens)
# instead of zeros.
#
# NOTE: 'attention_scores' is overwritten with the normalized scores. Running this cell a second
#       time applies softmax on top of the already normalized values, so run it exactly once.
attention_scores = attention_scores.softmax(dim=-1)
print("shape: ", attention_scores.shape)
print("attention_scores: ", attention_scores)

shape:  torch.Size([2, 2, 3, 3])
attention_scores:  tensor([[[[0.0000, 0.5133, 0.4867],
          [1.0000, 0.0000, 0.0000],
          [0.4874, 0.5126, 0.0000]],

         [[0.0000, 0.4443, 0.5557],
          [1.0000, 0.0000, 0.0000],
          [0.6070, 0.3930, 0.0000]]],


        [[[0.5943, 0.0000, 0.4057],
          [0.3333, 0.3333, 0.3333],
          [1.0000, 0.0000, 0.0000]],

         [[0.4725, 0.0000, 0.5275],
          [0.3333, 0.3333, 0.3333],
          [1.0000, 0.0000, 0.0000]]]], grad_fn=<SoftmaxBackward0>)


In [22]:
# We now have the attention scores for each token in each head. The next step is to calculate the 
# weighted sum of the values using these scores. These are called the attention heads.
#
# attention_scores: [batch_size, num_heads, seq_len, seq_len]
# values          : [batch_size, num_heads, seq_len, d_k]
#
# Let's consider 2 Matrices Mat1 and Mat2 and Output = Mat1 * Mat2.
# Mat1: [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
# Mat2: [[10, 11, 12], [13, 14, 15], [16, 17, 18]]
# 
# Output[0] = Mat1[0] * Mat2 = [1, 2, 3] * [[10, 11, 12], [13, 14, 15], [16, 17, 18]]
#          = 1 * [10, 11, 12] + 2 * [13, 14, 15] + 3 * [16, 17, 18]
#
# So, the first row of the output matrix is the weighted sum of all the rows of Mat2, where the
# weights are the elements of the first row of Mat1.
#
# Output[1] = 4 * [10, 11, 12] + 5 * [13, 14, 15] + 6 * [16, 17, 18]
# Output[2] = 7 * [10, 11, 12] + 8 * [13, 14, 15] + 9 * [16, 17, 18]
# 
# For our case, 1, 2, 3 are the attention scores and [10, 11, 12], [13, 14, 15], [16, 17, 18]
# are the corresponding value vectors.
#
# Each row of a [seq_len, d_k] matrix in the values tensor is the value vector of one token in
# the sequence. So, the weighted sum of the values for a token in the sequence is calculated by
# multiplying the attention scores for that token with the values for all the tokens in the 
# sequence.
#
# attention_heads: [batch_size, num_heads, seq_len, d_k]
attention_heads = torch.matmul(attention_scores, values)
print("shape: ", attention_heads.shape)
print("attention_heads: ", attention_heads)

shape:  torch.Size([2, 2, 3, 4])
attention_heads:  tensor([[[[ 0.0406, -0.2769,  0.6934,  0.3928],
          [ 0.5662,  1.2474,  0.6110,  0.3846],
          [ 0.4764,  0.7553,  0.3367,  0.5807]],

         [[ 0.6527,  0.2197, -0.0136, -0.3940],
          [ 0.7495,  0.0327, -0.5575, -0.4441],
          [ 0.7607,  0.0516, -0.5495, -0.4652]]],


        [[[ 0.4334,  0.2426,  0.3511,  0.6911],
          [ 0.0757,  0.1837,  0.3198,  0.4638],
          [ 0.7716,  0.8189, -0.1598,  0.7462]],

         [[-0.1464,  0.5686, -0.6139, -0.1575],
          [ 0.1065,  0.4336, -0.2820, -0.3727],
          [-0.2320,  0.1364, -0.8177, -0.1223]]]],
       grad_fn=<UnsafeViewBackward0>)


In [23]:
# We now have attention heads for each token in each head. The next step is to concatenate the
# attention heads for each token from all the heads. Before concatenating, we need to rearrange
# the dimensions of the attention heads tensor so that all the heads of a token sit next to each
# other. This prepares the tensor for the concatenation in the next step.
#
# attention_heads           : [batch_size, num_heads, seq_len, d_k]
# transposed_attention_heads: [batch_size, seq_len, num_heads, d_k]
attention_heads = attention_heads.transpose(dim0=1, dim1=2)
print("shape: ", attention_heads.shape)
print("attention_heads: ", attention_heads)

shape:  torch.Size([2, 3, 2, 4])
attention_heads:  tensor([[[[ 0.0406, -0.2769,  0.6934,  0.3928],
          [ 0.6527,  0.2197, -0.0136, -0.3940]],

         [[ 0.5662,  1.2474,  0.6110,  0.3846],
          [ 0.7495,  0.0327, -0.5575, -0.4441]],

         [[ 0.4764,  0.7553,  0.3367,  0.5807],
          [ 0.7607,  0.0516, -0.5495, -0.4652]]],


        [[[ 0.4334,  0.2426,  0.3511,  0.6911],
          [-0.1464,  0.5686, -0.6139, -0.1575]],

         [[ 0.0757,  0.1837,  0.3198,  0.4638],
          [ 0.1065,  0.4336, -0.2820, -0.3727]],

         [[ 0.7716,  0.8189, -0.1598,  0.7462],
          [-0.2320,  0.1364, -0.8177, -0.1223]]]],
       grad_fn=<TransposeBackward0>)


In [24]:
# We now have attention heads in the required shape. The next step is to concatenate the attention
# heads for each token from all the heads to get a single attention head vector per token.
# 
# attention_heads             : [batch_size, seq_len, num_heads, d_k]
# concatenated_attention_heads: [batch_size, seq_len, d_model]
#
# Since the tensor is not contiguous (because of the transpose in the previous cell), we use
# reshape instead of view. view might fail if the tensor is not contiguous.
print("is_contiguous: ", attention_heads.is_contiguous())
attention_heads = attention_heads.reshape(batch_size, seq_len, d_model)
print("shape: ", attention_heads.shape)
print("attention_heads: ", attention_heads)
print("is_contiguous: ", attention_heads.is_contiguous())

is_contiguous:  False
shape:  torch.Size([2, 3, 8])
attention_heads:  tensor([[[ 0.0406, -0.2769,  0.6934,  0.3928,  0.6527,  0.2197, -0.0136,
          -0.3940],
         [ 0.5662,  1.2474,  0.6110,  0.3846,  0.7495,  0.0327, -0.5575,
          -0.4441],
         [ 0.4764,  0.7553,  0.3367,  0.5807,  0.7607,  0.0516, -0.5495,
          -0.4652]],

        [[ 0.4334,  0.2426,  0.3511,  0.6911, -0.1464,  0.5686, -0.6139,
          -0.1575],
         [ 0.0757,  0.1837,  0.3198,  0.4638,  0.1065,  0.4336, -0.2820,
          -0.3727],
         [ 0.7716,  0.8189, -0.1598,  0.7462, -0.2320,  0.1364, -0.8177,
          -0.1223]]], grad_fn=<UnsafeViewBackward0>)
is_contiguous:  True


In [25]:
# We finally pass our concatenated attention heads through a linear layer to get the output of the
# multi-headed attention layer. This layer maps every d_model sized concatenated vector to a
# d_model sized output vector.
output_creator = nn.Linear(in_features=d_model, out_features=d_model)
print("output_creator: ", output_creator)

output_creator:  Linear(in_features=8, out_features=8, bias=True)


In [26]:
# attention_heads            : [batch_size, seq_len, d_model]
# multi_head_attention_output: [batch_size, seq_len, d_model]
#
# The output has the same shape as the input we started with: one d_model sized vector per token.
multi_head_attention_output = output_creator(attention_heads)
print("shape: ", multi_head_attention_output.shape)
print("multi_head_attention_output: ", multi_head_attention_output)

shape:  torch.Size([2, 3, 8])
multi_head_attention_output:  tensor([[[ 0.5358,  0.0076,  0.2317,  0.2802, -0.4445,  0.3625, -0.0881,
           0.4434],
         [ 0.6170, -0.0066, -0.5055,  0.6572,  0.0556,  0.0129, -0.3898,
           0.8623],
         [ 0.5143,  0.0054, -0.3420,  0.5535, -0.0222, -0.0262, -0.3684,
           0.8703]],

        [[ 0.0982, -0.3812, -0.0171,  0.1120, -0.1386,  0.3311, -0.1053,
           0.6196],
         [ 0.3642, -0.3084,  0.0067,  0.1829, -0.1968,  0.3673, -0.0892,
           0.5677],
         [-0.0052, -0.3488, -0.3012,  0.1479,  0.1463, -0.0018, -0.2324,
           0.7958]]], grad_fn=<ViewBackward0>)


## Multi-Headed Attention In Transformers

In [27]:
# Creates a copy (deepcopy) of the module and returns ModuleList containing the copies.
def clone_module(module: nn.Module, num_clones: int) -> nn.ModuleList:
    """Builds a ModuleList holding independent deep copies of the given module.

    Args:
        module (nn.Module): Module to be copied. It is not added to the returned list itself.
        num_clones (int): Number of copies to create.

    Returns:
        nn.ModuleList: List with 'num_clones' deep copies of 'module'. Every copy owns its
                       own parameters, initialised to the same values as the original.
    """
    clones = nn.ModuleList()
    for _ in range(num_clones):
        clones.append(copy.deepcopy(module))
    return clones

In [28]:
def construct_attention_heads(queries: Tensor, keys: Tensor, values: Tensor, mask: Optional[Tensor]=None, dropout_layer: Optional[nn.Module]=None) -> Tuple[Tensor, Tensor]:
    """Calculates the attention scores for each token in the sequence with every other token in the sequence.
       Applies the mask if provided and then normalizes the scores using softmax. It then calculates the 
       attention heads for each token in the sequence.

    Args:
        queries (Tensor): [batch_size, num_heads, seq_len, d_k]
        keys (Tensor): [batch_size, num_heads, seq_len, d_k]
        values (Tensor): [batch_size, num_heads, seq_len, d_k]
        mask (Optional[Tensor], optional): [batch_size, 1, seq_len, seq_len]. Positions holding False (or 0) 
                                           are masked out. Defaults to None.
        dropout_layer (Optional[nn.Module], optional): Dropout module that is applied to the normalized 
                                                       attention scores. Defaults to None (no dropout).

    Returns:
        Tuple[Tensor, Tensor]: Returns the attention heads and the attention scores. The returned scores 
                               are the ones used to compute the heads i.e., after softmax and dropout.
                               attention_heads: [batch_size, num_heads, seq_len, d_k]
                               attention_scores: [batch_size, num_heads, seq_len, seq_len]
    """
    # Size of the vectors for each token for each head in the sequence.
    d_k = queries.shape[-1]
    # Calculate the attention scores for each token in the sequence with every other token in the sequence.
    # The last two dimensions of keys are swapped (counted from the end), so that the matrix
    # multiplication computes the dot product of every query vector with every key vector.
    attention_scores = torch.matmul(queries, keys.transpose(dim0=-2, dim1=-1)) / math.sqrt(d_k)
    # Mask the attention scores if a mask is provided. Mask is used in two different ways:
    # 1) To prevent the model from attending to the padding tokens --> This applies for both src and tgt sentences.
    # 2) To prevent the model from attending to the future tokens in the sequence --> This applies only for tgt sentences.
    if mask is not None:
        # Please do not set the masked values to float('-inf') as it sometimes (not in every case) causes softmax to return nan.
        attention_scores = attention_scores.masked_fill(mask == 0, -1e9)
    # Normalize the attention scores using softmax.
    attention_scores = attention_scores.softmax(dim=-1)
    # Apply dropout regularization to prevent overfitting problems. The dropout layer returns a new
    # tensor, so its result has to be assigned back. Otherwise the dropout is silently ignored.
    if dropout_layer is not None:
        attention_scores = dropout_layer(attention_scores)
    # Calculate the attention heads for each token in the sequence. The head for each token is calculated by
    # taking the weighted average (averaged by attention scores) of the values for all the tokens in the 
    # sequence for the token of interest.
    attention_heads = torch.matmul(attention_scores, values)
    return attention_heads, attention_scores

In [29]:
# We are just going to combine everything above and put it into a class. This class will be used
# to create the Multi-Headed Attention layer in the Transformer model.
# Refer to 'using_modules.ipynb' (Add link to the notebook) to understand more about Pytorch modules.
class MultiHeadedAttention(nn.Module):
    """Multi-Headed Attention layer.

    Projects the inputs into queries, keys and values, splits them into 'num_heads' heads of
    size d_k = d_model // num_heads, runs scaled dot-product attention in every head and
    projects the concatenated heads through a final linear layer.
    """

    def __init__(self, num_heads: int, d_model: int, dropout_prob: float = 0.1):
        """Creates the layer.

        Args:
            num_heads (int): Number of attention heads. Must divide d_model.
            d_model (int): Size of the token vectors at the input and the output of the layer.
            dropout_prob (float, optional): Dropout probability applied to the attention scores. Defaults to 0.1.
        """
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads."
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_model // num_heads
        # We use dropout to prevent overfitting.
        self.dropout_layer = nn.Dropout(p=dropout_prob)
        # Creating the linear layers that generate queries, keys and values for each token in the sequence.
        # Also, creating an additional linear layer to generate the output of the Multi-Headed Attention from concatenated attention heads.
        # Order of the layers: [query layer, key layer, value layer, output layer].
        self.linear_layers = clone_module(module=nn.Linear(in_features=d_model, out_features=d_model), num_clones=4)

    def _split_heads(self, data: Tensor, batch_size: int) -> Tensor:
        """Splits the d_model sized vectors into heads: [batch_size, seq_len, d_model] --> [batch_size, num_heads, seq_len, d_k]."""
        # The sequence length is inferred (-1) from the tensor itself, so queries and keys / values
        # are allowed to have different sequence lengths.
        return data.view(batch_size, -1, self.num_heads, self.d_k).transpose(dim0=1, dim1=2)

    def forward(self, query_input: Tensor, key_input: Tensor, value_input: Tensor, mask: Optional[Tensor]=None) -> Tensor:
        """Forward pass of the Multi-Headed Attention layer. 

        Args:
            query_input (Tensor): Input to be used for query creation.
                                  query_input: [batch_size, query_seq_len, d_model]
            key_input (Tensor): Input to be used for key creation.
                                key_input  : [batch_size, key_seq_len, d_model]
            value_input (Tensor): Input to be used for value creation. Must have the same sequence
                                  length as key_input.
                                  value_input: [batch_size, key_seq_len, d_model]
            mask (Optional[Tensor], optional): Mask to be applied to the attention scores. Default is None. 
                                               Same mask will be applied to all the heads in the 
                                               Multi-Headed Attention layer.
                                               mask: [batch_size, 1, query_seq_len, key_seq_len]

        Returns:
            Tensor: Output of the Multi-Headed Attention layer. Generates one output vector for each 
                    token in the query sequence. Does this for each sequence in the batch.
                    output: [batch_size, query_seq_len, d_model]
        """
        batch_size = query_input.shape[0]
        # Generates the queries, keys and values for each token in the sequence and separates them
        # into heads. zip stops after the first three linear layers; the fourth one is the output layer.
        # Shape after the linear layers: [batch_size, seq_len, d_model]
        # Shape after splitting        : [batch_size, num_heads, seq_len, d_k]
        queries, keys, values = [
            self._split_heads(data=linear_layer(layer_input), batch_size=batch_size)
            for linear_layer, layer_input in zip(self.linear_layers, (query_input, key_input, value_input))
        ]
        # Calculate the attention heads for each token in the sequence. The attention scores are
        # not needed here.
        # attention_heads: [batch_size, num_heads, query_seq_len, d_k]
        attention_heads, _ = construct_attention_heads(queries=queries, keys=keys, values=values, mask=mask, dropout_layer=self.dropout_layer)
        # Concatenate the attention heads for each token from all the heads. The transposed tensor
        # is not contiguous, so reshape (and not view) is used.
        # attention_heads: [batch_size, query_seq_len, d_model]
        attention_heads = attention_heads.transpose(dim0=1, dim1=2).reshape(batch_size, -1, self.d_model)
        # Generate the output of the Multi-Headed Attention layer.
        return self.linear_layers[-1](attention_heads)

In [30]:
# Create the layer with the same configuration as the step by step walkthrough above:
# 2 heads and token vectors of size 8 (so d_k = 4).
multi_headed_attention_layer = MultiHeadedAttention(num_heads=2, d_model=8, dropout_prob=0.1)
print(multi_headed_attention_layer)

MultiHeadedAttention(
  (dropout_layer): Dropout(p=0.1, inplace=False)
  (linear_layers): ModuleList(
    (0-3): 4 x Linear(in_features=8, out_features=8, bias=True)
  )
)


In [31]:
# Self attention: the same tensor is used to create the queries, the keys and the values.
# No mask is passed, so every token attends to every token of its sentence.
# input : [batch_size, seq_len, d_model]
# output: [batch_size, seq_len, d_model]
print("input shape: ", input.shape)
print("input: ", input)
transformer_multi_headed_attention_output = multi_headed_attention_layer(query_input=input, key_input=input, value_input=input)
print("output shape: ", transformer_multi_headed_attention_output.shape)
print("output: ", transformer_multi_headed_attention_output)

input shape:  torch.Size([2, 3, 8])
input:  tensor([[[-0.3415,  1.7753,  0.0380,  0.1188,  0.4282,  0.2899, -0.4513,
           0.3557],
         [-0.6632, -1.2565, -0.1667, -1.9468, -0.0124, -0.7232,  0.2141,
          -0.5374],
         [-0.2168,  0.7195, -2.5153,  1.2357, -0.7998, -0.4453,  0.7830,
           1.4223]],

        [[-1.2573,  1.0951,  0.9087, -0.7431, -0.2753,  1.4350,  1.1018,
          -0.7295],
         [ 0.3248, -0.9626, -0.2803, -1.0175,  0.2620, -0.6605, -1.0310,
           0.5023],
         [ 1.0583, -1.7163, -0.7393,  1.1486,  1.1051, -1.3013,  0.2931,
          -1.4220]]])
output shape:  torch.Size([2, 3, 8])
output:  tensor([[[-0.1944,  0.1598,  0.3522,  0.1647, -0.2321,  0.1088, -0.2963,
           0.4579],
         [ 0.1834,  0.0292,  0.2885,  0.3071, -0.1512,  0.2481, -0.1918,
           0.3766],
         [-0.3755, -0.0380,  0.4772,  0.3606, -0.0559,  0.1052, -0.1906,
           0.7190]],

        [[ 0.3711, -0.1907,  0.3214,  0.0076, -0.2554, -0.1442,  0.