# Compact and Optimized Causal Single-Head Self-Attention

In [64]:
import torch
from torch import nn

### Buffers in NN Modules
- PyTorch processes everything as tensors.
- A single `.to(device)` call moves an `nn.Module`'s parameters and submodules to the device, but not arbitrary tensors stored as plain attributes.
- The mask tensor in the attention mechanism is not a parameter of the module.
- So we either have to move it to the device (GPU) manually, or register it as a buffer so that it moves together with the module.

In [65]:
# Base code adapted from selfAttentionImproved (Single-Head Self-Attention)

class compacCasultSHSL(nn.Module):
    """Compact causal single-head self-attention.

    Projects the input into queries, keys and values, masks out attention to
    future tokens, normalizes the scores with a scaled softmax, applies dropout
    to the attention weights and returns the weighted sum of the values.

    Args:
        d_in: Size of the last dimension of the input tensor.
        d_out: Size of the query/key/value projections (and of the output).
        context_length: Maximum sequence length covered by the causal mask.
        dropout: Dropout probability applied to the attention weights
            (GPT-style models typically use 0.1-0.2).
        qkv_biased: Whether the query/key/value projections use a bias term.
        verbose: Keyword-only. If True (the default), ``forward`` prints the
            intermediate attention scores and weights for inspection.
    """

    def __init__(self, d_in, d_out, context_length, dropout=0.5,
                 qkv_biased=False, *, verbose=True):
        super().__init__()
        self.d_out = d_out
        self.verbose = verbose
        # nn.Linear weights have requires_grad=True by default.
        self.liner_query = nn.Linear(d_in, d_out, bias=qkv_biased)
        self.liner_key = nn.Linear(d_in, d_out, bias=qkv_biased)
        self.liner_value = nn.Linear(d_in, d_out, bias=qkv_biased)
        # Honour the caller's dropout rate; it was previously hard-coded to
        # 0.5, which silently ignored the ``dropout`` argument.
        self.droput = nn.Dropout(dropout)

        # A plain tensor attribute would stay on the CPU when the module is
        # moved to a GPU, causing a device mismatch in forward(). Registering
        # the mask as a buffer makes it follow the module's device.
        # Entries strictly above the diagonal (future positions) are 1.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

    def forward(self, x):
        """Apply causal self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.

        Raises:
            ValueError: If ``num_tokens`` exceeds the ``context_length`` the
                module was built with.
        """
        _, num_tokens, _ = x.shape
        max_tokens = self.mask.shape[0]
        if num_tokens > max_tokens:
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_length {max_tokens}"
            )

        queries = self.liner_query(x)
        keys = self.liner_key(x)
        values = self.liner_value(x)

        # Raw attention scores: (batch, num_tokens, num_tokens).
        att_score = queries @ keys.transpose(1, 2)

        # Mask future positions with -inf so softmax turns them into zero
        # weight. Slicing lets one buffer serve any shorter sequence.
        att_score.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens],
            -torch.inf,
        )
        if self.verbose:
            print('att_score', att_score)

        # Scale by sqrt(d_k) to keep the softmax out of its saturated region.
        norm_factor = keys.shape[-1] ** 0.5
        att_weights = torch.softmax(att_score / norm_factor, dim=-1)
        if self.verbose:
            print(f"Masked att_weights marix: {att_weights}")

        # Dropout on the (already masked) attention weights; active only in
        # training mode.
        droput_att_weights = self.droput(att_weights)
        if self.verbose:
            print(f"Dropout att_weights marix: {droput_att_weights}")

        # Context vectors: attention-weighted sum of the values.
        return droput_att_weights @ values
        
        
        

In [66]:
# Example input: a single random (4 tokens x 3 features) sequence, repeated
# twice along a new leading axis to form a batch of two identical sequences.
torch.manual_seed(123)
inputs = torch.randn(4, 3)
batch = torch.stack([inputs] * 2)
batch

tensor([[[-0.1115,  0.1204, -0.3696],
         [-0.2404, -1.1969,  0.2093],
         [-0.9724, -0.7550,  0.3239],
         [-0.1085,  0.2103, -0.3908]],

        [[-0.1115,  0.1204, -0.3696],
         [-0.2404, -1.1969,  0.2093],
         [-0.9724, -0.7550,  0.3239],
         [-0.1085,  0.2103, -0.3908]]])

In [73]:
# Size the layer from the example data: sequence length from the batch,
# feature size from a single sequence.
context_length = batch.size(1)
print('context_length', context_length)
d_in = inputs.size(-1)
d_out = 3
ccsa = compacCasultSHSL(
    d_in=d_in,
    d_out=d_out,
    context_length=context_length,
)
# Report where a projection weight and the registered mask buffer live.
query_weight_device = ccsa.liner_query.weight.device
print("W_query.device:", query_weight_device)
print("mask.device:", ccsa.mask.device)
ccsa

context_length 4
W_query.device: cpu
mask.device: cpu


compacCasultSHSL(
  (liner_query): Linear(in_features=3, out_features=3, bias=False)
  (liner_key): Linear(in_features=3, out_features=3, bias=False)
  (liner_value): Linear(in_features=3, out_features=3, bias=False)
  (droput): Dropout(p=0.5, inplace=False)
)

In [68]:
# Run the attention layer on the batch with autograd tracking disabled.
# ccsa is not switched to eval mode here, so dropout is still applied.
with torch.no_grad():
    # Call the module itself rather than .forward() so registered hooks run.
    context = ccsa(batch)
print(f"Context Matrix after Self-Attention: {context}")

att_score tensor([[[ 0.0356,    -inf,    -inf,    -inf],
         [-0.1919, -0.0616,    -inf,    -inf],
         [-0.2273, -0.2666, -0.4173,    -inf],
         [ 0.0489, -0.0534,  0.0238,  0.0548]],

        [[ 0.0356,    -inf,    -inf,    -inf],
         [-0.1919, -0.0616,    -inf,    -inf],
         [-0.2273, -0.2666, -0.4173,    -inf],
         [ 0.0489, -0.0534,  0.0238,  0.0548]]])
Masked att_weights marix: tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.4812, 0.5188, 0.0000, 0.0000],
         [0.3480, 0.3402, 0.3118, 0.0000],
         [0.2543, 0.2398, 0.2507, 0.2552]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.4812, 0.5188, 0.0000, 0.0000],
         [0.3480, 0.3402, 0.3118, 0.0000],
         [0.2543, 0.2398, 0.2507, 0.2552]]])
Dropout att_weights marix: tensor([[[2.0000, 0.0000, 0.0000, 0.0000],
         [0.9624, 0.0000, 0.0000, 0.0000],
         [0.6960, 0.6804, 0.0000, 0.0000],
         [0.5087, 0.4795, 0.5014, 0.5104]],

        [[0.0000, 0.0000, 0.0000, 0