# Causal Single-Head Self-Attention (Masked Attention)

In [None]:
# Dropout in DL is a technique where randomly selected hidden units are ignored during training.
# It helps prevent the model from overfitting.

In [3]:
import torch
from torch import nn

In [7]:
# Dropout demo: with p=0.1 each element is zeroed with probability 0.1 and the
# surviving elements are scaled by 1 / (1 - p), which is why they show as 1.1111.
all_ones = torch.ones(5, 5)
droput = nn.Dropout(p=0.1)
all_ones, droput(all_ones)

(tensor([[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]]),
 tensor([[1.1111, 0.0000, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111, 0.0000],
         [1.1111, 1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 0.0000, 1.1111]]))

In [14]:
# selfAttentionImproved builds on the base code from Single Head Self-Attention

class selfAttentionImproved(nn.Module):
    """Single-head causal (masked) self-attention with dropout on the attention weights.

    Every position may attend only to itself and to earlier positions: the
    scores of later positions are replaced with -inf before the softmax, so
    they receive an attention weight of exactly zero.

    Args:
        d_in: Size of the last dimension of the input (token embedding size).
        d_out: Output size of the query/key/value projections.
        qkv_biased: Whether the three projection layers use a bias term.
        dropout: Probability of zeroing an attention weight while training.
            Defaults to 0.5 (demonstration value); the original author notes
            that GPT models use 0.1 or 0.2.
        verbose: When True (the default), print the attention-weight matrix
            before and after dropout on every forward pass.
    """

    def __init__(self, d_in, d_out, qkv_biased=False, dropout=0.5, verbose=True):
        super().__init__()
        # nn.Linear parameters have requires_grad=True by default.
        self.liner_query = nn.Linear(d_in, d_out, bias=qkv_biased)
        self.liner_key = nn.Linear(d_in, d_out, bias=qkv_biased)
        self.liner_value = nn.Linear(d_in, d_out, bias=qkv_biased)
        self.droput = nn.Dropout(dropout)
        self.verbose = verbose

    def forward(self, x):
        """Compute the causal self-attention context vectors.

        Args:
            x: Tensor whose last two dimensions are (seq_len, d_in); any
                leading dimensions (e.g. a batch dimension) are preserved.

        Returns:
            Tensor with the same leading dimensions and trailing shape
            (seq_len, d_out).
        """
        x_q = self.liner_query(x)
        x_k = self.liner_key(x)
        x_v = self.liner_value(x)

        # Attention scores. transpose(-2, -1) swaps only the last two axes, so
        # this works for both unbatched (seq, d) and batched (..., seq, d) input,
        # whereas .T reverses every axis.
        att_score = x_q @ x_k.transpose(-2, -1)

        # Mask future positions with -inf; softmax turns those entries into 0.
        # The mask is built on the scores' device so non-CPU inputs work too.
        context_len = att_score.shape[-1]
        mask = torch.triu(
            torch.ones(context_len, context_len, device=att_score.device),
            diagonal=1,
        ).bool()
        masked_att_score = att_score.masked_fill(mask, -torch.inf)

        # Scaled softmax: divide by sqrt(d_k), the key projection size.
        norm_factor = x_k.shape[-1] ** 0.5
        att_weights = torch.softmax(masked_att_score / norm_factor, dim=-1)
        if self.verbose:
            print(f"Masked att_weights marix: {att_weights}")

        # Dropout on the attention weights (identity when the module is in eval mode).
        droput_att_weights = self.droput(att_weights)
        if self.verbose:
            print(f"Dropout att_weights marix: {droput_att_weights}")

        # Context vectors: attention-weighted sum of the value vectors.
        context = droput_att_weights @ x_v

        return context
        
        
        

In [15]:
# Seed the RNG so the weight initialisation, the random input and the dropout
# mask are the same on every Restart & Run All.
torch.manual_seed(0)
sai = selfAttentionImproved(3, 3)
# Call the module itself rather than .forward() so nn.Module's hooks run.
context = sai(torch.randn(6, 3))
print(f"Context Matrix after Self-Attention: {context}")

Masked att_weights marix: tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3230, 0.6770, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1463, 0.3685, 0.4853, 0.0000, 0.0000, 0.0000],
        [0.1727, 0.2275, 0.2213, 0.3785, 0.0000, 0.0000],
        [0.1681, 0.1327, 0.1180, 0.3515, 0.2296, 0.0000],
        [0.0911, 0.1999, 0.2203, 0.1507, 0.1380, 0.2001]],
       grad_fn=<SoftmaxBackward0>)
Dropout att_weights marix: tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.3540, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.7369, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4549, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1822, 0.3998, 0.4405, 0.0000, 0.0000, 0.4002]],
       grad_fn=<MulBackward0>)
Context Matrix after Self-Attention: tensor([[ 0.0000,  0.0000,  0.0000],
        [-1.7454, -1.6705, -0.5202],
        [-0.9499, -0.9091, -0.2831],
        [-0.5864, -0.5612, -0.1748]