## Multi Head Attention

In [167]:
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [168]:
# Toy problem size: a single sequence of four tokens, 512 features per token.
batch_size = 1
sequence_length = 4
input_dim = 512
# Width of the attention layer; here it equals the input width.
d_model = 512

# Random stand-in for the token embeddings, laid out as
# [batch_size, sequence_length, input_dim].
x = torch.randn(batch_size, sequence_length, input_dim)

In [169]:
x.size()

torch.Size([1, 4, 512])

In [170]:
x

tensor([[[-0.6145, -1.5876, -0.0061,  ..., -0.3973, -0.7069,  0.5202],
         [ 1.0035,  0.5081, -0.8010,  ..., -0.0953, -1.2909,  0.9203],
         [ 1.6107, -1.0099,  1.4375,  ..., -0.7651, -1.0811, -1.7743],
         [ 1.5170, -0.1520,  1.2073,  ...,  0.6237,  0.1760, -0.3910]]])

In [171]:
# A single linear projection that emits 3*d_model features per token; they are
# later split into the query, key and value vectors.
qkv_layer = nn.Linear(input_dim,3*d_model)

In [172]:
qkv = qkv_layer(x)

In [173]:
qkv.size()

torch.Size([1, 4, 1536])

In [174]:
num_attention_heads = 8
# Every head gets an equal share of the model width (integer division).
head_dim = d_model//num_attention_heads
# Split the fused 3*d_model axis into one slice per head; each slice is
# 3*head_dim wide and is cut into q, k and v by the chunk(3) call further down.
qkv =  qkv.reshape(batch_size,sequence_length,num_attention_heads,3*head_dim)

In [175]:
qkv.shape

torch.Size([1, 4, 8, 192])

In [176]:
# Swap the sequence and head axes so the last two axes are (sequence, features)
# for each head, which is the layout the matmul-based attention below operates on.
qkv = qkv.permute(0,2,1,3) # [batch_size,num_attention_heads,sequence_length,3*head_dim]
qkv.shape

torch.Size([1, 8, 4, 192])

In [177]:
# Cut the last axis into three equal pieces: query, key and value for each head.
q,k,v = qkv.chunk(3,dim=-1)
q.shape,k.shape,v.shape

(torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]))

### Self Attention for multiple heads

In [178]:
# Size of the last axis of q, used as the scaling factor below.
d_k = q.size()[-1]
# Dot every query with every key (k is transposed on its last two axes), then
# divide by sqrt(d_k). The result has one score per (query position, key position).
scaled = torch.matmul(q,k.transpose(-2,-1))/math.sqrt(d_k)
scaled.shape

torch.Size([1, 8, 4, 4])

In [179]:
# Small side demo of torch.transpose on a 2x3 tensor: swapping dims 0 and 1
# turns it into a 3x2 tensor.
y = torch.randn(2,3)
print(y)
torch.transpose(y,0,1)

tensor([[-0.0678, -0.2859, -1.1409],
        [ 0.8877,  0.8633, -0.9922]])


tensor([[-0.0678,  0.8877],
        [-0.2859,  0.8633],
        [-1.1409, -0.9922]])

In [180]:
torch.transpose(y,1,0)

tensor([[-0.0678,  0.8877],
        [-0.2859,  0.8633],
        [-1.1409, -0.9922]])

### Masking 


In [181]:
# Build an additive mask with the same shape as the scores, initially all -inf.
mask = torch.full(scaled.size(),float('-inf'))
# Keep -inf only strictly above the main diagonal of the last two axes; the
# diagonal and everything below become 0, so adding the mask leaves those
# scores unchanged and pushes the scores of later positions to -inf.
mask = torch.triu(mask,diagonal=1)
mask.size()

torch.Size([1, 8, 4, 4])

In [182]:
(scaled + mask)[0][0]


tensor([[ 0.3878,    -inf,    -inf,    -inf],
        [ 0.3305, -0.2549,    -inf,    -inf],
        [-0.1362, -0.1780, -0.1523,    -inf],
        [ 0.2660,  0.5332,  0.3430,  0.2431]], grad_fn=<SelectBackward0>)

In [183]:
scaled += mask

In [184]:
# Hand computation of a two-element softmax, exp(a) / (exp(a) + exp(b)), using
# hard-coded scores.
# NOTE(review): the literals look like they came from an earlier run's second
# row of masked scores; they do not match the values printed above -- confirm.
np.exp(0.2971)/(np.exp(0.2971)+np.exp(-0.5782))


0.7058473197476167

In [185]:
# Normalize the masked scores along the last axis (over key positions).
attention = F.softmax(scaled,dim = -1)


In [186]:
attention.shape

torch.Size([1, 8, 4, 4])

In [187]:
v.shape

torch.Size([1, 8, 4, 64])

In [188]:
# Weight the value vectors by the attention matrix: each output row is a
# weighted combination of the rows of v for that head.
values = torch.matmul(attention,v)
values.shape

torch.Size([1, 8, 4, 64])

### Function

In [189]:
import math

def scaled_dot_product(q, k, v, mask=None):
    """Compute scaled dot-product attention over the last two axes.

    The scores are ``q @ k^T`` (``k`` transposed on its last two axes) divided
    by the square root of the size of ``q``'s last axis. When ``mask`` is given
    it is added to the scores in place before the softmax over the last axis.

    Args:
        q: Query tensor.
        k: Key tensor.
        v: Value tensor.
        mask: Optional additive mask that must broadcast to the score tensor
            (for example ``-inf`` at positions that must receive zero weight).

    Returns:
        A tuple ``(attention, values)``: the softmax weights and the result of
        multiplying those weights with ``v``.
    """
    keys_t = k.transpose(-2, -1)
    scores = (q @ keys_t) / math.sqrt(q.shape[-1])
    if mask is not None:
        # In-place add on the freshly created score tensor.
        scores += mask
    weights = torch.softmax(scores, dim=-1)
    return weights, weights @ v

attention,values = scaled_dot_product(q,k,v,mask=mask)






In [190]:
values.size()

torch.Size([1, 8, 4, 64])

In [191]:
# `values` is laid out as [batch, heads, seq, head_dim]. Move the head axis back
# next to head_dim first, so that flattening the last two axes concatenates the
# heads of the SAME token. Reshaping directly would keep the right shape but
# interleave features from different sequence positions.
values = values.permute(0,2,1,3).reshape(batch_size,sequence_length,num_attention_heads*head_dim)
values.shape

torch.Size([1, 4, 512])

In [192]:
# Final d_model -> d_model projection applied to the flattened per-head values.
linear_layer = nn.Linear(d_model,d_model)
out = linear_layer(values)

In [193]:
out.shape

torch.Size([1, 4, 512])

In [194]:
out


tensor([[[ 0.1020, -0.2873, -0.1127,  ..., -0.0105,  0.0454, -0.0719],
         [-0.2045, -0.3514, -0.2858,  ..., -0.1326, -0.3167,  0.1571],
         [ 0.3738,  0.0124,  0.2707,  ..., -0.0969,  0.0523,  0.1896],
         [ 0.1110, -0.0442, -0.0504,  ..., -0.2251, -0.3435,  0.2780]]],
       grad_fn=<ViewBackward0>)

### Class

In [206]:
import torch
import torch.nn as nn
import math

def scaled_dot_product(q,k,v,mask=None):
    """Compute scaled dot-product attention over the last two axes.

    Args:
        q: Query tensor.
        k: Key tensor.
        v: Value tensor.
        mask: Optional additive mask that must broadcast to the score tensor
            (for example ``-inf`` at positions that must receive zero weight).

    Returns:
        A tuple ``(values, attention)``: the attention-weighted values and the
        softmax weights that produced them.
    """
    d_k = q.size()[-1]
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
        scaled += mask
    # torch.softmax keeps this cell self-contained: it imports torch, nn and
    # math, but not torch.nn.functional.
    attention = torch.softmax(scaled, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention


class MultiheadAttention(nn.Module):
    """Multi-head self-attention with a fused q/k/v projection.

    Args:
        input_dim: Number of features of each input token.
        d_model: Total width of the attention layer (all heads together).
        num_attention_heads: Number of heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_attention_heads``.
    """

    def __init__(self, input_dim, d_model, num_attention_heads):
        super().__init__()
        if d_model % num_attention_heads != 0:
            # Integer division would silently drop features and the reshape in
            # forward() would fail later with a much less helpful message.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by "
                f"num_attention_heads ({num_attention_heads})"
            )
        self.input_dim = input_dim
        self.d_model = d_model
        self.num_attention_heads = num_attention_heads
        self.head_dim = d_model // num_attention_heads
        # One projection produces q, k and v for all heads at once.
        self.qkv_layer = nn.Linear(input_dim, 3 * d_model)
        # Output projection applied after the heads are concatenated.
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self,x,mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape ``[batch_size, sequence_length, input_dim]``.
            mask: Optional additive mask forwarded to ``scaled_dot_product``.

        Returns:
            Tensor of shape ``[batch_size, sequence_length, d_model]``.

        The intermediate tensor sizes are printed for illustration.
        """
        batch_size, sequence_length, _ = x.size()
        print(f"x.size(): {x.size()}")
        qkv = self.qkv_layer(x)
        print(f"qkv.size(): {qkv.size()}")
        # Give every head its own slice of width 3*head_dim (q, k and v).
        qkv = qkv.reshape(batch_size, sequence_length, self.num_attention_heads, 3 * self.head_dim)
        print(f"qkv.size(): {qkv.size()}")
        # [batch, seq, heads, 3*head_dim] -> [batch, heads, seq, 3*head_dim]
        qkv = qkv.permute(0, 2, 1, 3)
        print(f"qkv.size(): {qkv.size()}")
        q, k, v = qkv.chunk(3, dim=-1)
        print(f"q size: {q.size()}, k size: {k.size()}, v size: {v.size()}, ")
        values, attention = scaled_dot_product(q, k, v, mask)
        print(f"values.size(): {values.size()}, attention.size:{ attention.size()} ")
        # Undo the earlier permute before flattening: [batch, heads, seq, head_dim]
        # -> [batch, seq, heads, head_dim]. Without this the reshape still yields
        # the expected shape but mixes features of different sequence positions
        # into one output row.
        values = values.permute(0, 2, 1, 3).reshape(
            batch_size, sequence_length, self.num_attention_heads * self.head_dim
        )
        print(f"values.size(): {values.size()}")
        out = self.linear_layer(values)
        print(f"out.size(): {out.size()}")
        return out
 
        




In [207]:
# Layer configuration: the input width differs from d_model here, which the
# fused q/k/v projection handles.
input_dim = 1024
d_model = 512
num_attention_heads = 8

# A batch of 30 random sequences with 5 tokens each.
batch_size = 30
sequence_length = 5
x = torch.randn( (batch_size, sequence_length, input_dim) )

# Pass `num_attention_heads` as defined above; the previously used name
# `num_heads` is not defined in any cell shown in this notebook.
model = MultiheadAttention(input_dim, d_model, num_attention_heads)
# Call the module itself rather than .forward so nn.Module.__call__ (and any
# registered hooks) is used.
out = model(x)

x.size(): torch.Size([30, 5, 1024])
qkv.size(): torch.Size([30, 5, 1536])
qkv.size(): torch.Size([30, 5, 8, 192])
qkv.size(): torch.Size([30, 8, 5, 192])
q size: torch.Size([30, 8, 5, 64]), k size: torch.Size([30, 8, 5, 64]), v size: torch.Size([30, 8, 5, 64]), 
values.size(): torch.Size([30, 8, 5, 64]), attention.size:torch.Size([30, 8, 5, 5]) 
values.size(): torch.Size([30, 5, 512])
out.size(): torch.Size([30, 5, 512])
