In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class MaskedSelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with an optional mask.

    Queries, keys and values come from three bias-free linear projections of
    size ``d_model -> d_model``. Calling :meth:`forward` without a mask gives
    ordinary self-attention; passing a mask gives masked self-attention.

    Args:
        d_model: Size of each token encoding (input and output feature size).
        row_dim: Dimension swapped with ``col_dim`` when transposing the keys
            (the token dimension for a 2-D ``(n_tokens, d_model)`` input).
        col_dim: Dimension used to transpose the keys, to read the key size for
            scaling, and as the softmax dimension of the similarity matrix.
            For a 2-D input with the defaults, this is both the feature
            dimension of the encodings and the key-position dimension of
            the scores.
    """

    def __init__(self, d_model = 2,
                 row_dim = 0,
                 col_dim = 1):

        super().__init__()

        # Bias-free projections producing queries, keys and values.
        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, token_encodings, mask=None):
        """Compute (masked) self-attention outputs for ``token_encodings``.

        Args:
            token_encodings: Tensor whose last dimension is ``d_model``. The
                notebook's example uses a 2-D ``(n_tokens, d_model)`` tensor
                with ``row_dim=0, col_dim=1``.
            mask: Optional boolean tensor broadcastable to the similarity
                matrix. ``True`` marks positions to block; ``False`` positions
                are left as they are. ``None`` (the default) gives plain,
                unmasked self-attention.

        Returns:
            The attention-weighted sum of the value vectors, one row per query.
        """
        # nn.Linear multiplies the last dimension by each weight matrix.
        q = self.W_q(token_encodings)
        k = self.W_k(token_encodings)
        v = self.W_v(token_encodings)

        # Raw query-key similarities: q @ k^T.
        similarity_score = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))

        # Divide by sqrt(d_k), where d_k is the key size (d_model here). This
        # keeps the logits from growing with d_k and saturating the softmax.
        d_k = k.size(self.col_dim)
        scaled_sim = similarity_score / torch.tensor(d_k ** 0.5)

        if mask is not None:
            # masked_fill writes -1e9 wherever mask is True and leaves every
            # False position untouched (they are NOT set to 0). After softmax,
            # -1e9 becomes ~0 weight, so it acts like -inf. Unlike -inf, it
            # cannot turn a row whose every position is masked into NaNs.
            scaled_sim = scaled_sim.masked_fill(mask=mask, value=-1e9)

        # Each row becomes a probability distribution over the keys.
        attention_percents = F.softmax(scaled_sim, dim=self.col_dim)

        # Weight the value vectors by those percentages and sum them.
        attention_score = torch.matmul(attention_percents, v)

        return attention_score


In [4]:
# now let us test to ensure it works
# first create a matrix of token encodings: one row per token, one column per feature

encodings_matrix = torch.tensor([[1.16, 0.23],
                                 [0.57, 1.36],
                                 [4.41, -2.16]])

# Derive sizes from the data instead of hard-coding them, so the model and the
# mask stay in sync if tokens or features are added.
n_tokens, d_model = encodings_matrix.shape

# seed the random generator so the weight initialisation (and thus every output
# below) is reproducible
torch.manual_seed(42)

maskedSelfAttention = MaskedSelfAttention(d_model=d_model, row_dim=0, col_dim=1)

# Causal mask: torch.tril keeps the lower triangle (incl. the diagonal) of an
# n_tokens x n_tokens matrix of ones and zeroes the upper triangle. Comparing with
# 0 turns the kept 1's into False and the zeroed entries into True, so True marks
# the "future" tokens each row must not attend to.
mask = torch.tril(torch.ones(n_tokens, n_tokens)) == 0
print(mask)


tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])


In [6]:
# run masked self-attention over the token encodings: because of the causal mask,
# token i only attends to tokens 0..i

maskedSelfAttention(encodings_matrix, mask=mask)


tensor([[ 0.6038,  0.7434],
        [-0.0062,  0.6072],
        [ 3.4989,  2.2427]], grad_fn=<MmBackward0>)