# Multi-Head Attention
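Multi-head attention projects every word vector into queries, keys and values, then splits each projection into N smaller parts (heads). Each head attends over the sequence independently, using its own slice of the learned projection weights. This lets different heads capture different relationships between the words. Finally, the head outputs are concatenated and passed through one more linear layer.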

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
# Fix the RNG so that `x` and every randomly initialised nn.Linear created later
# in the notebook are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

sequence_length = 4  # number of tokens (words) in the sentence
batch_size = 1
input_dim = 512  # embedding size of every input token
d_model = 512  # embedding size of the attention output
# x stands in for the input sequence after token embedding + positional encoding
# (random values here, only used to follow the tensor shapes through the layer).
x = torch.randn((batch_size, sequence_length, input_dim))

In [3]:
x.size()  # (batch_size, sequence_length, input_dim)

torch.Size([1, 4, 512])

In [4]:
# A single Linear layer (a learnable weight matrix plus bias) produces the query,
# key and value projections together: its 3 * d_model outputs are the three
# projections concatenated. A `#` comment is used instead of a bare string so the
# cell does not echo the text as its output.
qkv_layer = nn.Linear(input_dim, 3 * d_model)

In [5]:
# Project every token at once: (batch, seq, input_dim) -> (batch, seq, 3 * d_model)
qkv = qkv_layer(x)

In [6]:
qkv.size()  # (batch, seq, 3 * d_model)

torch.Size([1, 4, 1536])

In [7]:
number_of_heads = 8
# Every head gets an equal slice of the model dimension, so it must divide evenly.
assert d_model % number_of_heads == 0, "d_model must be divisible by number_of_heads"
dim_per_head = d_model // number_of_heads  # derived from the head count, not hard-coded

# Split the last axis (3 * d_model) into one group of 3 * dim_per_head features per
# head; a later chunk(3) on that axis splits each group into the head's q, k and v.
# NOTE: this cell overwrites `qkv`; re-running it on its own reshapes the
# already-permuted tensor and silently scrambles it. Re-run from `qkv_layer(x)`.
qkv = qkv.reshape(batch_size, sequence_length, number_of_heads, 3 * dim_per_head)
# Move the head axis in front of the sequence axis:
# (batch, seq, heads, 3 * dim_per_head) -> (batch, heads, seq, 3 * dim_per_head)
qkv = qkv.permute(0, 2, 1, 3)

In [8]:
qkv.size()  # (batch, heads, seq, 3 * dim_per_head)

torch.Size([1, 8, 4, 192])

In [9]:
# Split each head's 3 * dim_per_head features into equal query / key / value parts.
Q, K, V = torch.chunk(qkv, 3, dim=-1)

In [10]:
Q.size()  # (batch, heads, seq, dim_per_head)

torch.Size([1, 8, 4, 64])

In [11]:
d_k = Q.size(-1)  # per-head feature size
# Scaled dot-product scores Q K^T / sqrt(d_k): one score per (query, key) pair,
# giving shape (batch, heads, seq, seq).
keys_T = K.transpose(-1, -2)
scaled_Q_matmul_k = (Q @ keys_T) / math.sqrt(d_k)

In [12]:
scaled_Q_matmul_k.size()  # (batch, heads, seq, seq)

torch.Size([1, 8, 4, 4])

In [13]:
# Raw scaled scores for every head, before masking.
# NOTE(review): this dumps all heads; scaled_Q_matmul_k[0, 0] would show a single head.
scaled_Q_matmul_k

tensor([[[[-0.2556,  0.1601, -0.0357,  0.1113],
          [-0.0708, -0.2969,  0.2796,  0.2689],
          [-0.6026, -0.4058,  0.6889,  0.0602],
          [ 0.0431,  0.4054,  0.6170, -0.1980]],

         [[ 0.4075, -0.2631, -0.2518, -0.3762],
          [ 0.4232, -0.0257,  0.0805, -0.3280],
          [ 0.3020,  0.0651, -0.0179, -0.0562],
          [-0.1143,  0.1181,  0.1633,  0.0483]],

         [[ 0.0107,  0.0161, -0.2362,  0.0865],
          [ 0.7051,  0.5968,  0.4943, -0.3331],
          [ 0.3352,  0.1637,  0.1089,  0.0669],
          [ 0.6016,  0.1963,  0.0076,  0.1804]],

         [[-0.2375,  0.3913, -0.5017, -0.0871],
          [ 0.2911,  0.2103, -0.2584,  0.1324],
          [-0.6113,  0.5588, -0.1236,  0.1377],
          [-0.0838, -0.6880, -0.3592,  0.1098]],

         [[ 0.1261,  0.1488, -0.1506,  0.2607],
          [ 0.1584, -0.5350,  0.4735,  0.2934],
          [ 0.1575,  0.7036, -0.3544, -0.5250],
          [-0.1817, -0.0067, -0.0210, -0.0542]],

         [[-0.5608, -0.1606,  

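## Masking the scaled $QK^\top$ scores
Causal (look-ahead) mask: position $i$ may only attend to positions $j \le i$. Scores above the diagonal are set to $-\infty$, so after the softmax their attention weight is exactly 0.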


In [65]:
# Start from a tensor shaped like the scores and filled entirely with -inf.
mask = torch.full(size=scaled_Q_matmul_k.shape, fill_value=float("-inf"))

In [66]:
# Keep -inf strictly above the main diagonal (future positions); zero on/below it.
mask = mask.triu(diagonal=1)

In [67]:
# Adding 0 leaves allowed scores untouched; adding -inf blocks future positions.
masked_scaled_q_matmul_k = torch.add(scaled_Q_matmul_k, mask)


In [68]:
# Scores after masking: every head has -inf above the diagonal.
masked_scaled_q_matmul_k

tensor([[[[ 0.1378,    -inf,    -inf,    -inf],
          [-0.0831,  0.4652,    -inf,    -inf],
          [-0.4318,  0.2800, -0.2073,    -inf],
          [ 0.2879, -0.1388,  0.1274, -0.1766]],

         [[-0.3567,    -inf,    -inf,    -inf],
          [-0.1704,  0.3626,    -inf,    -inf],
          [ 0.2982,  0.2223,  0.2475,    -inf],
          [-0.0281, -0.2246, -0.0802,  0.6789]],

         [[ 0.0020,    -inf,    -inf,    -inf],
          [ 0.1887, -0.5517,    -inf,    -inf],
          [-0.2751,  0.0403,  0.0251,    -inf],
          [ 0.0970, -0.0690,  0.2044,  0.2624]],

         [[ 0.0536,    -inf,    -inf,    -inf],
          [ 0.0969, -0.1592,    -inf,    -inf],
          [ 0.1331,  0.7730, -0.4895,    -inf],
          [ 0.0140,  0.4209,  0.4723, -0.2291]],

         [[ 0.0296,    -inf,    -inf,    -inf],
          [ 0.0609, -0.5630,    -inf,    -inf],
          [ 0.1409, -0.5189,  0.1368,    -inf],
          [-0.3432,  0.1459, -0.4373,  0.2193]],

         [[ 0.0577,    -inf,  

In [69]:
-math.inf  # the mask's fill value; exp(-inf) = 0, so softmax gives it zero weight

-inf

In [70]:
masked_scaled_q_matmul_k.size()  # masking keeps the shape: (batch, heads, seq, seq)

torch.Size([1, 8, 4, 4])

In [73]:
# Normalise over the key axis so each query's weights sum to 1.
attention = masked_scaled_q_matmul_k.softmax(dim=-1)

In [74]:
# First head of the first batch item: zeros above the diagonal, rows sum to 1.
attention[0, 0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.3663, 0.6337, 0.0000, 0.0000],
        [0.2331, 0.4751, 0.2918, 0.0000],
        [0.3192, 0.2083, 0.2719, 0.2006]], grad_fn=<SelectBackward0>)

In [75]:
# Weighted sum of value vectors per query: (batch, heads, seq, dim_per_head).
values = attention @ V

In [76]:
values.size()  # (batch, heads, seq, dim_per_head)

torch.Size([1, 8, 4, 64])

In [80]:
# Concatenate the heads back into one vector per token.
# `values` is (batch, heads, seq, d_k). The head axis must first move behind the
# sequence axis -> (batch, seq, heads, d_k), so that flattening the last two axes
# gives each token the concatenation of its own heads' outputs. Reshaping directly
# from (batch, heads, seq, d_k) would mix different tokens and heads into one row.
# NOTE: `values` is overwritten here, so re-running this cell alone fails (the
# tensor is already 3-D); re-run from the cell that computes `attention @ V`.
values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, number_of_heads * d_k)

In [81]:
values.size()  # (batch, seq, number_of_heads * d_k)

torch.Size([1, 4, 512])

In [84]:
# Final d_model -> d_model linear layer applied to the concatenated heads.
# Despite the variable name, this is multi-head attention's output projection
# (W^O), not the Transformer's two-layer position-wise feed-forward block.
feed_forward_layer = nn.Linear(in_features=d_model, out_features=d_model)
output = feed_forward_layer(values)

In [86]:
output.size()  # (batch, seq, d_model)

torch.Size([1, 4, 512])