# Causal Single Head Self-Attention (Masked Attention)

In [2]:
# Masking the attention weights is required to prevent each token from using context from future words.
# We will reuse the Single Head Self-Attention code and mask its attention weights.

In [7]:
import torch
from torch import nn

In [5]:
# Strictly upper-triangular matrix of ones: entry (i, j) is 1 only where j > i,
# i.e. exactly the "future" positions that a causal mask has to hide.
torch.ones(3, 3).triu(diagonal=1)

tensor([[0., 1., 1.],
        [0., 0., 1.],
        [0., 0., 0.]])

In [16]:
# base code of selfAttentionImproved used from Single Head Self-Attention

class selfAttentionImproved(nn.Module):
    """Single-head scaled dot-product self-attention with a causal mask.

    Each position may attend only to itself and to earlier positions; scores
    for later ("future") positions are set to -inf before the softmax so their
    attention weights become exactly zero.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the query/key/value projections and of the output.
        qkv_biased: Whether the three projection layers use a bias term.
    """

    def __init__(self, d_in, d_out, qkv_biased=False):
        super().__init__()
        # NOTE: the attribute names (including the "liner" spelling) are kept
        # unchanged because they are the module's public attributes and
        # state_dict keys.
        self.liner_query = nn.Linear(d_in, d_out, bias=qkv_biased)   # default value of requires_grad is True
        self.liner_key = nn.Linear(d_in, d_out, bias=qkv_biased)
        self.liner_value = nn.Linear(d_in, d_out, bias=qkv_biased)

    def forward(self, x):
        """Compute the causally masked context vectors for ``x``.

        Args:
            x: Tensor of shape ``(seq_len, d_in)`` or, with any number of
                leading batch dimensions, ``(..., seq_len, d_in)``.

        Returns:
            Tensor of shape ``(..., seq_len, d_out)``.

        Side effects:
            Prints the masked attention-weight matrix.
        """
        x_q = self.liner_query(x)
        x_k = self.liner_key(x)
        x_v = self.liner_value(x)

        # Compute attention scores. Swap only the last two dimensions so that
        # batched input works as well (``.T`` is only correct for 2-D tensors).
        att_score = x_q @ x_k.transpose(-2, -1)

        # MASKING future attention scores: replace them with -inf, which the
        # softmax turns into zero weights. The sequence length is the size of
        # the last axis of the score matrix (not axis 0, which is the batch
        # axis for batched input), and the mask is created on the same device
        # as the scores to avoid a device mismatch.
        context_len = att_score.shape[-1]
        mask = torch.triu(
            torch.ones(context_len, context_len, device=att_score.device),
            diagonal=1,
        ).bool()
        masked_att_score = att_score.masked_fill(mask, -torch.inf)

        # Attention weights: scale by sqrt of the key dimension, then normalize
        # each row with softmax.
        norm_factor = x_k.shape[-1] ** 0.5              # normalization factor
        att_weights = torch.softmax(masked_att_score / norm_factor, dim=-1)
        print(f"Masked att_weights marix: {att_weights}")

        # Context matrix: attention-weighted sum of the value vectors.
        context = att_weights @ x_v

        return context
        
        
        

In [18]:
# Seed the RNG so the randomly initialized weights and the random input below
# are reproducible across kernel restarts.
torch.manual_seed(123)
sai = selfAttentionImproved(3,3)
# Call the module itself rather than .forward() so that nn.Module's call
# machinery (registered hooks etc.) is not bypassed.
context = sai(torch.randn(6,3))
print(f"Context Matrix after Self-Attention: {context}")

Masked att_weights marix: tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4493, 0.5507, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4710, 0.2108, 0.3181, 0.0000, 0.0000, 0.0000],
        [0.1872, 0.3602, 0.1760, 0.2767, 0.0000, 0.0000],
        [0.1236, 0.2554, 0.1453, 0.1739, 0.3018, 0.0000],
        [0.3141, 0.0865, 0.2031, 0.1809, 0.0669, 0.1485]],
       grad_fn=<SoftmaxBackward0>)
Context Matrix after Self-Attention: tensor([[-0.3362,  0.1060, -0.6693],
        [-0.0841,  0.1244, -0.4324],
        [-0.4453, -0.0922, -0.2784],
        [-0.1894,  0.0836, -0.4939],
        [-0.0856,  0.1486, -0.6032],
        [-0.4773, -0.1027, -0.3383]], grad_fn=<MmBackward0>)
