In [1]:
import torch ## torch let's us create tensors and also provides helper functions
import torch.nn as nn ## torch.nn gives us nn.module() and nn.Linear()
import torch.nn.functional as F # This gives us the softmax()

In [2]:
class MaskedSelfAttention(nn.Module):
    """Single-head masked (causal) self-attention.

    Projects the token encodings to queries, keys and values with three
    bias-free ``nn.Linear`` layers, computes scaled dot-product similarities,
    optionally blocks the positions selected by ``mask``, and returns the
    attention-weighted sum of the value vectors.

    Args:
        d_model: Size of each token encoding; also the output size of the
            query/key/value projections.
        row_dim: Index of the token (row) dimension of the encodings.
        col_dim: Index of the encoding (column) dimension of the encodings.
            The defaults (0, 1) fit an unbatched ``(n_tokens, d_model)`` matrix.
    """

    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()

        # Bias-free linear projections producing queries, keys and values.
        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, token_encodings, mask=None):
        """Return the masked self-attention output for ``token_encodings``.

        Args:
            token_encodings: Tensor whose last dimension has size ``d_model``
                (required by the ``nn.Linear`` projections).
            mask: Optional boolean tensor, broadcastable to the similarity
                matrix, where ``True`` marks a (query, key) pair that must get
                no attention (e.g. a later token). ``None`` disables masking.

        Returns:
            For the default 2-D ``(n_tokens, d_model)`` input, an
            ``(n_tokens, d_model)`` tensor whose rows are attention-weighted
            sums of the value vectors.
        """
        q = self.W_q(token_encodings)
        k = self.W_k(token_encodings)
        v = self.W_v(token_encodings)

        # Similarity between every query and every key.
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))

        # Scale by sqrt(d_k). A plain Python float works for any dtype/device,
        # so there is no need to allocate a new tensor on every call.
        scaled_sims = sims / (k.size(self.col_dim) ** 0.5)

        if mask is not None:
            # Use the most negative finite value of this dtype instead of a
            # hard-coded -1e9, which overflows (RuntimeError) in float16.
            # After softmax these positions receive (numerically) zero weight.
            scaled_sims = scaled_sims.masked_fill(
                mask=mask, value=torch.finfo(scaled_sims.dtype).min
            )

        # Normalise each query's similarities into attention weights.
        attn_percents = F.softmax(scaled_sims, dim=self.col_dim)

        # Weighted sum of the value vectors.
        attn_scores = torch.matmul(attn_percents, v)

        return attn_scores


In [3]:
## Token encodings to run through attention: one row per token (3 tokens),
## one column per encoding dimension.
encodings_matrix = torch.tensor(
    [
        [1.16, 0.23],   # token 1
        [0.57, 1.36],   # token 2
        [4.41, -2.16],  # token 3
    ]
)

In [4]:
torch.manual_seed(seed=42)  # seed before building the module so its nn.Linear weights are reproducible

<torch._C.Generator at 0x7e1e7932d910>

In [5]:
## Build the masked self-attention module; its three nn.Linear weight
## matrices are initialised from the current torch RNG state.
maskedSelfAttention = MaskedSelfAttention(
    d_model=2,  # size of each token encoding
    row_dim=0,  # tokens are laid out along dim 0 (rows)
    col_dim=1,  # encoding values are laid out along dim 1 (columns)
)

In [8]:
# Build the causal mask from the number of tokens (rows of encodings_matrix)
# instead of a hard-coded 3, so it stays in sync if the encodings change.
n_tokens = encodings_matrix.size(0)
mask = torch.tril(torch.ones(n_tokens, n_tokens))

# torch.tril keeps the main diagonal and everything below it and sets every
# entry above the diagonal to 0. For 3 tokens, a matrix of ones
#
# 111
# 111
# 111
#
# becomes
#
# 100
# 110
# 111

In [9]:
# masked_fill hides positions where the mask is True, so turn the 0/1
# lower-triangular mask into a boolean one: 0 (above the diagonal, i.e. a
# later token) -> True, 1 -> False.
# Guarded so re-running this cell is a no-op: applying `== 0` to an already
# boolean mask would silently invert it.
if mask.dtype != torch.bool:
    mask = mask == 0
mask  # display the boolean mask

tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])

In [11]:
# Run the token encodings through masked self-attention and display the result.
maskedSelfAttention(encodings_matrix, mask=mask)

tensor([[ 0.6038,  0.7434],
        [-0.0062,  0.6072],
        [ 3.4989,  2.2427]], grad_fn=<MmBackward0>)