# Multi Head Attention or Sequential Single Head Attentions

In [4]:
import torch
import torch.nn as nn

In [None]:
# single head self attention
class SHSL(nn.Module):
    def __init__(self, d_in, d_out, context_length, causal=True, dropout=0.5, qkv_biased=False):
        super(SHSL, self).__init__()
        self.causal = causal
        self.d_out = d_out
        self.liner_query = nn.Linear(d_in, d_out, bias=qkv_biased)   # default value of requires_grad is True
        self.liner_key = nn.Linear(d_in, d_out, bias=qkv_biased)
        self.liner_value = nn.Linear(d_in, d_out, bias=qkv_biased)
        self.droput = nn.Dropout(0.5)       # 50% probability of dropout, GPT model 0.1 or 0.2 is used
        
        if self.causal:
            # # for CPU it works fine, but when we put model into GPU "mask" doesn't follow the same
            # # and keeps running on CPU only, which leads to error
            # self.mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

            # model itself follows the device
            self.register_buffer(
                "mask",
                torch.triu(torch.ones(context_length, context_length),
                        diagonal=1)
            )
    
    def forward(self, x):
        b, num_tokens, d_in = x.shape
        queries = self.liner_query(x)
        keys = self.liner_key(x)
        values = self.liner_value(x)
        
        # Compute attention score
        att_score = queries @ keys.transpose(1,2)
        
        if self.causal:
            # MASKING future attention score, replace with inf that will be changed to zero by softmat
            # fill by -inf where it finds Ture
            att_score.masked_fill_(
                self.mask.bool()[:num_tokens, :num_tokens],
                -torch.inf
            )
            
            # att_score = att_score.masked_fill(self.mask.bool(), -torch.inf)
            # print('att_score', att_score)
        
        # attention weight
        norm_factor = keys.shape[-1] ** 0.5              # normalization factor    
        att_weights = torch.softmax(att_score/norm_factor, dim=-1)
        # print(f"Masked att_weights marix: {att_weights}")
        
        # Appling droput to masked att_weights
        droput_att_weights = self.droput(att_weights)
        # print(f"Dropout att_weights marix: {droput_att_weights}")
        
        # context matrix
        context = droput_att_weights @ values
        
        return context
        
        
        

In [None]:
# sequential multi head attention
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, num_heads=2, causal=True, dropout=0.5, qkv_biased=False):
        super().__init__()
        self.d_out = d_out
        self.heads = nn.ModuleList(
                                    [ SHSL(d_in, d_out, context_length, causal=causal, dropout=dropout, qkv_biased=qkv_biased) 
                                    for _ in range(num_heads) ]
                                   )
    def forward(self, x):
        return torch.cat( [ head.forward(x) for head in self.heads ] )
        

In [32]:
# input
torch.manual_seed(124)
inputs = torch.randn(4,3)
batch = torch.stack((inputs, inputs), dim=0)
batch

tensor([[[ 0.2922,  1.5814,  0.9303],
         [ 0.6592,  0.3796, -0.3670],
         [ 2.3163, -0.1895, -0.4247],
         [-0.6814,  1.6722, -0.6039]],

        [[ 0.2922,  1.5814,  0.9303],
         [ 0.6592,  0.3796, -0.3670],
         [ 2.3163, -0.1895, -0.4247],
         [-0.6814,  1.6722, -0.6039]]])

In [33]:
context_length = batch.shape[1]
print('context_length', context_length)
d_in = inputs.shape[1]
d_out = 3
mhsa = MultiHeadSelfAttention(d_in=d_in, d_out=d_out, context_length=context_length)
mhsa

context_length 4


MultiHeadSelfAttention(
  (heads): ModuleList(
    (0-1): 2 x SHSL(
      (liner_query): Linear(in_features=3, out_features=3, bias=False)
      (liner_key): Linear(in_features=3, out_features=3, bias=False)
      (liner_value): Linear(in_features=3, out_features=3, bias=False)
      (droput): Dropout(p=0.5, inplace=False)
    )
  )
)

In [34]:
with torch.no_grad():
    context = mhsa.forward(batch)
print(f"Context Matrix after multi head Self-Attention: {context}")

Context Matrix after multi head Self-Attention: tensor([[[ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.1171, -0.2019, -0.1344],
         [ 0.2901, -0.8044, -0.5432]],

        [[ 0.0000,  0.0000,  0.0000],
         [ 0.1390, -0.2397, -0.1596],
         [ 0.0995, -0.9639, -0.0860],
         [ 0.1711, -0.3428, -0.1796]],

        [[ 0.0000,  0.0000,  0.0000],
         [ 0.1648, -0.0429, -0.0987],
         [ 0.0000,  0.0000,  0.0000],
         [-0.0923, -0.5233, -0.6319]],

        [[ 0.3458, -0.5400, -0.6614],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.5621,  0.1401,  0.0528],
         [-0.2838, -0.3262, -0.3685]]])
