# Multi  Head Attention

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
sequence_length=4
batch_size=1
input_dim=512 # vector dimension of every word that goes to the attention unit
d_model=512 # is the output of the attention unit for every single word

x=torch.randn((batch_size,sequence_length,input_dim))

In [4]:
x.size()

torch.Size([1, 4, 512])

In [5]:
qkv_layer=nn.Linear(input_dim,3*d_model)

In [6]:
qkv=qkv_layer(x)
qkv.shape

torch.Size([1, 4, 1536])

In [13]:
num_heads=8
head_dim=d_model//num_heads
qkv=qkv.reshape(batch_size,sequence_length,num_heads,3*head_dim)
qkv.shape

torch.Size([1, 4, 8, 192])

In [14]:
qkv=qkv.permute(0,2,1,3)
qkv.shape

torch.Size([1, 8, 4, 192])

In [16]:
q,k,v= qkv.chunk(3,dim=-1)
q.shape,k.shape,v.shape

(torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]))

## Self Attention

In [19]:
k.T.shape # we only want to transpose the last two dimensions and not all the dimensions 

torch.Size([64, 4, 8, 1])

In [20]:
k.transpose(-2,-1)==k.transpose(-1,-2)

tensor([[[[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          ...,
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]],

         [[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          ...,
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]],

         [[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          ...,
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]],

         ...,

         [[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          ...,
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]],

         [[True, True, True, True],
          [True, 

In [21]:
k.transpose(-2,-1).shape

torch.Size([1, 8, 64, 4])

In [17]:
d_k=k.shape[-1]
scaled=torch.matmul(q,k.transpose(-2,-1))/np.sqrt(d_k)
scaled.shape

torch.Size([1, 8, 4, 4])

In [27]:
mask=torch.full(scaled.shape,float('-inf'))
mask=torch.triu(mask,diagonal=1)
mask[0][1]

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])

In [28]:
(scaled+mask)[0][1]

tensor([[ 1.1730,    -inf,    -inf,    -inf],
        [-0.0710,  0.2076,    -inf,    -inf],
        [-0.2560,  0.5427,  0.6447,    -inf],
        [-0.2272, -0.3109,  0.4632,  0.0740]], grad_fn=<SelectBackward0>)

In [29]:
attention=torch.softmax(scaled+mask,dim=-1)
attention[0][1]

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.4308, 0.5692, 0.0000, 0.0000],
        [0.1759, 0.3910, 0.4330, 0.0000],
        [0.1899, 0.1747, 0.3788, 0.2567]], grad_fn=<SelectBackward0>)

In [30]:
newV=torch.matmul(attention,v)
newV.shape

torch.Size([1, 8, 4, 64])

In [31]:
def self_attention(q,k,v,mask=None):
    scaled=torch.matmul(q,k.transpose(-2,-1))/np.sqrt(k.shape[-1])
    if mask is not None:
        scaled+=mask
    attention=torch.softmax(scaled,dim=-1)
    newV=torch.matmul(attention,v)
    return newV,attention

In [32]:
values,attention=self_attention(q,k,v,mask)

In [33]:
attention[0][1]

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.4308, 0.5692, 0.0000, 0.0000],
        [0.1759, 0.3910, 0.4330, 0.0000],
        [0.1899, 0.1747, 0.3788, 0.2567]], grad_fn=<SelectBackward0>)

In [34]:
values.shape,attention.shape

(torch.Size([1, 8, 4, 64]), torch.Size([1, 8, 4, 4]))

In [42]:
values=values.reshape(batch_size,sequence_length,num_heads*head_dim)
values.shape   

In [41]:
linear_layer = nn.Linear(d_model, d_model)
out = linear_layer(values)

# Class

In [37]:
import torch 
import torch.nn as nn
import numpy

def self_attention(q,k,v,mask=None):
    scaled=torch.matmul(q,k.transpose(-2,-1))/np.sqrt(k.shape[-1])
    if mask is not None:
        scaled+=mask
    attention=torch.softmax(scaled,dim=-1)
    newV=torch.matmul(attention,v)
    return newV,attention

class MultiHeadAttention(nn.Module):
    def __init__(self,input_dim,d_model,num_heads):
        super().__init__()
        self.input_dim=input_dim
        self.d_model=d_model
        self.num_heads=num_heads
        self.head_dim=d_model//num_heads
        self.qkv_layer=nn.Linear(input_dim,3*d_model)
        self.linear_layer=nn.Linear(d_model,d_model)
    
    def forward(self,x,mask=None):
        batch_size,sequence_length,input_dim=x.shape
        print(f"x.size(): {x.size()}")
        qkv=self.qkv_layer(x)
        print(f"qkv.size(): {qkv.size()}")
        qkv=qkv.reshape(batch_size,sequence_length,self.num_heads,3*self.head_dim)
        print(f"qkv.size(): {qkv.size()}")
        qkv=qkv.permute(0,2,1,3)
        print(f"qkv.size(): {qkv.size()}")
        q,k,v=qkv.chunk(3,dim=-1)
        print(f"q size: {q.size()}, k size: {k.size()}, v size: {v.size()}, ")
        values,attention=self_attention(q,k,v,mask)
        print(f"values.size(): {values.size()}, attention.size:{ attention.size()} ")
        values=values.reshape(batch_size,sequence_length,self.num_heads*self.head_dim)
        print(f"values.size(): {values.size()}")
        out=self.linear_layer(values)
        print(f"out.size(): {out.size()}")
        return out


In [38]:
input_dim=1024
d_model=512
num_heads=8

batch_size=30
sequence_length=5
x=torch.randn((batch_size,sequence_length,input_dim))

model=MultiHeadAttention(input_dim=input_dim,d_model=d_model,num_heads=num_heads)
out=model(x)

x.size(): torch.Size([30, 5, 1024])
qkv.size(): torch.Size([30, 5, 1536])
qkv.size(): torch.Size([30, 5, 8, 192])
qkv.size(): torch.Size([30, 8, 5, 192])
q size: torch.Size([30, 8, 5, 64]), k size: torch.Size([30, 8, 5, 64]), v size: torch.Size([30, 8, 5, 64]), 
values.size(): torch.Size([30, 8, 5, 64]), attention.size:torch.Size([30, 8, 5, 5]) 
values.size(): torch.Size([30, 5, 512])
out.size(): torch.Size([30, 5, 512])
