In [None]:
import torch
import torch.nn as nn
torch.manual_seed(123)  # seed the global torch RNG so torch.rand and the nn.Linear initialisations below are reproducible

<torch._C.Generator at 0x79b3ce5086f0>

In [None]:
# Toy input: `no_of_tokens` random token embeddings of size `d_in`,
# replicated `batch` times into a [batch, no_of_tokens, d_in] tensor.
d_in = 4
batch = 2
no_of_tokens = 5
# hyperparameter: size of the query/key/value projections (attention output dim)
d_out = 3
inputToken = torch.rand(no_of_tokens, d_in)
# Build the batch from `batch` rather than hard-coding two copies, so that
# changing `batch` actually changes the batch size.
batch_data = torch.stack([inputToken] * batch)
print(batch_data)
print("Shape:", batch_data.shape)

tensor([[[0.2961, 0.5166, 0.2517, 0.6886],
         [0.0740, 0.8665, 0.1366, 0.1025],
         [0.1841, 0.7264, 0.3153, 0.6871],
         [0.0756, 0.1966, 0.3164, 0.4017],
         [0.1186, 0.8274, 0.3821, 0.6605]],

        [[0.2961, 0.5166, 0.2517, 0.6886],
         [0.0740, 0.8665, 0.1366, 0.1025],
         [0.1841, 0.7264, 0.3153, 0.6871],
         [0.0756, 0.1966, 0.3164, 0.4017],
         [0.1186, 0.8274, 0.3821, 0.6605]]])
Shape: torch.Size([2, 5, 4])


In [None]:
class MaskedSelfAttention(nn.Module):
    """Single-head causal (masked) self-attention.

    Each position may attend only to itself and earlier positions: score
    entries above the diagonal are set to -inf before the softmax.

    Args:
        context_len: maximum sequence length supported; sizes the causal mask.
        no_of_tokens: unused; kept so existing positional calls keep working.
        d_in: size of each input embedding (last dim of the input).
        d_out: size of the query/key/value projections and of the output.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, context_len, no_of_tokens, d_in, d_out, dropout):
        super().__init__()
        self.QueryM = nn.Linear(d_in, d_out)
        self.KeyM = nn.Linear(d_in, d_out)
        self.ValueM = nn.Linear(d_in, d_out)
        # 1s strictly above the diagonal mark "future" positions to hide.
        # Registered as a non-persistent buffer so it follows the module across
        # .to(device) calls without being added to state_dict.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_len, context_len), diagonal=1),
            persistent=False,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, inputToken):
        """Apply causal self-attention to `inputToken`.

        Args:
            inputToken: tensor whose last two dims are (tokens, d_in), e.g.
                [batch, tokens, d_in]; the token count must not exceed
                context_len.

        Returns:
            Tuple (context_vector, queries, keys, values, attention_weights,
            attention_scores). attention_scores are the raw masked scores,
            before scaling and softmax.

        Raises:
            ValueError: if the sequence is longer than context_len.
        """
        num_tokens = inputToken.size(-2)
        max_tokens = self.mask.size(0)
        if num_tokens > max_tokens:
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_len {max_tokens}"
            )

        queries = self.QueryM(inputToken)
        keys = self.KeyM(inputToken)
        values = self.ValueM(inputToken)

        # [..., T, T]; transpose the last two dims so any leading dims work.
        attention_scores = queries @ keys.transpose(-2, -1)

        # The [T, T] mask broadcasts over any leading (batch) dims.
        causal_mask = self.mask[:num_tokens, :num_tokens].bool()
        attention_scores = attention_scores.masked_fill(causal_mask, float('-inf'))
        # Scaled dot-product attention divides by sqrt(d_k), the key dimension
        # (d_out), not by the input embedding size.
        d_k = keys.size(-1)
        attention_weights = torch.softmax(attention_scores / d_k ** 0.5, dim=-1)
        attention_weights = self.dropout(attention_weights)
        context_vector = attention_weights @ values

        return context_vector, queries, keys, values, attention_weights, attention_scores

In [None]:
# Build the layer from the config cell's sizes instead of bare positional literals.
# context_len=6 sizes the causal mask; the sequence fed in must not be longer.
# dropout=0.0 makes nn.Dropout an identity, so outputs are deterministic.
AttentionObj = MaskedSelfAttention(
    context_len=6,
    no_of_tokens=no_of_tokens,
    d_in=d_in,
    d_out=d_out,
    dropout=0.0,
)
AttentionObj.mask.bool()

tensor([[False,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True],
        [False, False, False, False,  True,  True],
        [False, False, False, False, False,  True],
        [False, False, False, False, False, False]])

In [None]:
# Forward pass on the toy batch; besides the context vectors, the layer returns
# its intermediates so each stage can be inspected in later cells.
(
    context_vector,
    queries,
    keys,
    values,
    attention_weights,
    maskedAttention,  # raw attention scores after causal masking, before scaling/softmax
) = AttentionObj(batch_data)
print(context_vector)

tensor([[[-0.2011, -0.3797, -0.0952],
         [-0.2657, -0.5626, -0.2966],
         [-0.2420, -0.5398, -0.2594],
         [-0.2626, -0.4917, -0.2129],
         [-0.2497, -0.5053, -0.2158]],

        [[-0.2011, -0.3797, -0.0952],
         [-0.2657, -0.5626, -0.2966],
         [-0.2420, -0.5398, -0.2594],
         [-0.2626, -0.4917, -0.2129],
         [-0.2497, -0.5053, -0.2158]]], grad_fn=<UnsafeViewBackward0>)
