In [None]:
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
"""" Sequence length = maximum number of characters/words we can pass into the transformer at a time.
Dimension of embedding = size of vector representing each character / word."""

batch_size = 1        # number of sequences processed in parallel
sequence_length = 4   # number of tokens in the input, in our case "I live in MTL"
input_dim = 512       # size of the vector that represents each input token
d_model = 512         # size of the attention output vector for every token

# Shape convention: batch size first, then the position in the sequence,
# then the feature vector of the token at that position.
# Random data stands in for real embeddings, since positional encoding is not used yet.
x = torch.randn(batch_size, sequence_length, input_dim)

In [None]:
x.size()  # size is a method: without the parentheses the cell displays the bound method, not the shape

<function Tensor.size>

In [None]:
"""Here, we have to create query, key and values thats why we multiply 3 times the d_model."""
qkv_layer = nn.Linear(input_dim, 3 * d_model)

In [None]:
""" now pass our input through the q,k,v generator """
qkv = qkv_layer(x)

In [None]:
qkv.shape  # 1 batch, 4 tokens, and 3 x 512 = 1536 features per token (q, k and v concatenated)

torch.Size([1, 4, 1536])

In [None]:
num_heads = 8                     # number of attention heads
head_dim = d_model // num_heads   # features per head: d_model (512) // num_heads (8) = 64

# Split the last dimension of qkv into (num_heads, 3 * head_dim).
# The factor 3 is there because every head still holds its q, k and v side by side.
qkv = qkv.reshape(batch_size, sequence_length, num_heads, 3 * head_dim)

In [None]:
qkv.shape

torch.Size([1, 4, 8, 192])

In [None]:
""" we have to permute the 2nd and 3rd dimention switching between sequence length
and num_heads, which will be helpful for parallal computation """
qkv = qkv.permute(0, 2, 1, 3)
qkv.shape ## [ batch_size, num_heads, sequence_length, 3* head_dim]

torch.Size([1, 8, 4, 192])

In [None]:
# Split the last dimension (dim = -1) into 3 equal chunks: the query, key and value of every head.
q, k, v = qkv.chunk(3, dim = -1)
q.shape, k.shape, v.shape

(torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]))

## Self Attention

$$
\text{self attention} = \text{softmax}\bigg(\frac{QK^T}{\sqrt{d_k}}+M\bigg)
$$

$$
\text{new V} = \text{self attention} \cdot V
$$

In [None]:
# Attention scores: dot product of every query with every key, scaled by sqrt(d_k).
# For a 2-D matrix we could simply write k.T, but q and k are 4-D tensors, so we
# name the two axes to swap explicitly: the last two (sequence_length and head_dim).
# After the swap the key's shape lines up with the query for the matrix product.
d_k = q.size(-1)  # size of the last dimension, i.e. the per-head feature size
scaled = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)
scaled.shape

torch.Size([1, 8, 4, 4])

Masking, which is needed for the decoder: each position may only attend to itself and to earlier positions.

In [None]:
""" # we will mask the upcoming words,
 so placed with -inf that after the softmax func. there is no information.
 As softmax uses exponance(exp) of each element, so exp of -inf will be zero"""
# Start from a tensor filled with -inf and keep only the entries strictly above
# the diagonal; triu sets everything on and below the diagonal to 0.
mask = torch.triu(torch.full(scaled.size(), float('-inf')), diagonal=1)
mask[0][1]

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])

In [None]:
(scaled + mask) [0] [0]

tensor([[-0.1238,    -inf,    -inf,    -inf],
        [ 0.1354,  0.0572,    -inf,    -inf],
        [ 0.5056, -0.1130, -0.3738,    -inf],
        [ 0.0885,  0.2727,  0.0360,  0.0967]], grad_fn=<SelectBackward0>)

In [None]:
# How softmax works, on example values laid out like the 2nd row of a masked
# score matrix (only two entries are not -inf): the probability of the 1st
# element is exp(1st element) / sum of exp over all elements of the row.
np.exp(.2811) / (np.exp(.2811) + np.exp(-0.1152))

0.597798371649087

In [None]:
scaled += mask

In [None]:
""" we want to apply softmax function in the last dimention, which will apply the tensor row by row"""
attention = F.softmax(scaled, dim = -1)

In [None]:
attention[0][0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5195, 0.4805, 0.0000, 0.0000],
        [0.5118, 0.2757, 0.2124, 0.0000],
        [0.2404, 0.2890, 0.2281, 0.2424]], grad_fn=<SelectBackward0>)

In [None]:
values = torch.matmul(attention, v)
values.shape

torch.Size([1, 8, 4, 64])

Let's wrap what we have done so far into a function.

In [None]:
def scaled_dot_product(q, k, v, mask=None):
    """Compute scaled dot-product attention.

    Args:
        q: Query tensor of shape (..., seq_q, d_k).
        k: Key tensor of shape (..., seq_k, d_k).
        v: Value tensor of shape (..., seq_k, d_v).
        mask: Optional additive mask, broadcastable with the (..., seq_q, seq_k)
            score tensor. Use -inf where attention must be zero and 0 elsewhere.

    Returns:
        A ``(values, attention)`` tuple. ``attention`` holds the softmax weights
        of shape (..., seq_q, seq_k); ``values`` is ``attention @ v``.
    """
    d_k = q.size(-1)
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
        # Out-of-place add: `scaled += mask` raises when the mask broadcasts the
        # scores to a larger shape (e.g. a batched mask with unbatched q and k).
        scaled = scaled + mask
    attention = F.softmax(scaled, dim=-1)  # normalise every row of scores
    values = torch.matmul(attention, v)
    return values, attention

In [None]:
# lets execute the function
values, attention  = scaled_dot_product(q, k, v, mask = None)

In [None]:
attention.shape

torch.Size([1, 8, 4, 4])

In [None]:
attention [0] [0]  # no mask was passed here; passing the causal mask (mask=mask) would zero out the weights on the later positions

tensor([[0.1958, 0.2397, 0.2388, 0.3257],
        [0.2661, 0.2461, 0.2624, 0.2254],
        [0.4303, 0.2318, 0.1786, 0.1593],
        [0.2404, 0.2890, 0.2281, 0.2424]], grad_fn=<SelectBackward0>)

In [None]:
values.size()

torch.Size([1, 8, 4, 64])

In [None]:
""" Now we will combine/concanate all the heads all together"""
# `values` is [batch, heads, seq, head_dim]. Move the head axis back next to
# head_dim before merging: reshaping directly would mix features of different
# tokens instead of concatenating the heads of each token.
values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, num_heads * head_dim)
values.size()

torch.Size([1, 4, 512])

In [None]:
""" as the heads can communicate between each other with the information that they gained,
a linear_layer is applied """
linear_layer = nn.Linear(d_model, d_model)

In [None]:
out = linear_layer(values)

In [None]:
out.shape  # the output vectors are now much more context-aware and carry more information than the initial input vectors

torch.Size([1, 4, 512])

**Now let's turn everything we did so far into a class and a function.**

In [None]:
import torch
import torch.nn as nn
import math

# this function we already created above
def scaled_dot_product(q, k, v, mask=None):
    """Compute scaled dot-product attention.

    Args:
        q: Query tensor of shape (..., seq_q, d_k).
        k: Key tensor of shape (..., seq_k, d_k).
        v: Value tensor of shape (..., seq_k, d_v).
        mask: Optional additive mask, broadcastable with the (..., seq_q, seq_k)
            score tensor. Use -inf where attention must be zero and 0 elsewhere.

    Returns:
        A ``(values, attention)`` tuple. ``attention`` holds the softmax weights
        of shape (..., seq_q, seq_k); ``values`` is ``attention @ v``.
    """
    d_k = q.size(-1)
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
        # Out-of-place add: `scaled += mask` raises when the mask broadcasts the
        # scores to a larger shape (e.g. a batched mask with unbatched q and k).
        scaled = scaled + mask
    attention = F.softmax(scaled, dim=-1)  # normalise every row of scores
    values = torch.matmul(attention, v)
    return values, attention

class MultiheadAttention(nn.Module):
    """Multi-head self-attention layer.

    The input is projected once into concatenated query/key/value features,
    split into ``num_heads`` heads, passed through ``scaled_dot_product``
    (defined next to this class in the notebook), and the head outputs are
    concatenated again and mixed by a final linear layer.

    ``forward`` prints the size of every intermediate tensor so the data flow
    can be followed when the cell is run.

    Args:
        input_dim: Size of the feature vector of each input token.
        d_model: Size of the attention output per token (all heads together).
        num_heads: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, input_dim, d_model, num_heads):
        super().__init__()
        if d_model % num_heads != 0:
            # Otherwise head_dim is truncated and the reshape in forward() fails
            # later with a much less helpful size-mismatch error.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.input_dim = input_dim
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # One projection yields q, k and v at once, hence 3 * d_model outputs.
        self.qkv_layer = nn.Linear(input_dim , 3 * d_model)
        # Mixes the concatenated head outputs.
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape (batch_size, sequence_length, input_dim).
            mask: Optional additive mask forwarded to ``scaled_dot_product``.

        Returns:
            Tensor of shape (batch_size, sequence_length, d_model).
        """
        batch_size, sequence_length, _ = x.size()
        print(f"x.size(): {x.size()}")
        qkv = self.qkv_layer(x)
        print(f"qkv.size(): {qkv.size()}")
        # Split the features into heads; each head keeps its q, k, v side by side.
        qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim)
        print(f"qkv.size(): {qkv.size()}")
        # [batch, seq, heads, 3*head_dim] -> [batch, heads, seq, 3*head_dim]
        qkv = qkv.permute(0, 2, 1, 3)
        print(f"qkv.size(): {qkv.size()}")
        q, k, v = qkv.chunk(3, dim=-1)
        print(f"q size: {q.size()}, k size: {k.size()}, v size: {v.size()}, ")
        values, attention = scaled_dot_product(q, k, v, mask)
        print(f"values.size(): {values.size()}, attention.size:{ attention.size()} ")
        # Move the head axis back next to head_dim before merging the heads:
        # reshaping [batch, heads, seq, head_dim] directly would mix features of
        # different tokens instead of concatenating the heads of each token.
        values = values.permute(0, 2, 1, 3).reshape(
            batch_size, sequence_length, self.num_heads * self.head_dim
        )
        print(f"values.size(): {values.size()}")
        out = self.linear_layer(values)
        print(f"out.size(): {out.size()}")
        return out

Apply and use the class and function

In [None]:
# Model hyper-parameters.
input_dim = 1024
d_model = 512
num_heads = 8

# Random input batch: 30 sequences of 5 tokens each.
batch_size = 30
sequence_length = 5
x = torch.randn( (batch_size, sequence_length, input_dim) )

model = MultiheadAttention(input_dim, d_model, num_heads)
# Call the module itself rather than .forward(): nn.Module.__call__ runs any
# registered hooks around forward(), while a direct forward() call skips them.
out = model(x)

x.size(): torch.Size([30, 5, 1024])
qkv.size(): torch.Size([30, 5, 1536])
qkv.size(): torch.Size([30, 5, 8, 192])
qkv.size(): torch.Size([30, 8, 5, 192])
q size: torch.Size([30, 8, 5, 64]), k size: torch.Size([30, 8, 5, 64]), v size: torch.Size([30, 8, 5, 64]), 
values.size(): torch.Size([30, 8, 5, 64]), attention.size:torch.Size([30, 8, 5, 5]) 
values.size(): torch.Size([30, 5, 512])
out.size(): torch.Size([30, 5, 512])
