In [1]:
import torch

In [3]:
import torch.nn as nn
import torch.nn.functional as F

In [35]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are all projected from the same ``token_encodings``
    by three bias-free ``d_model -> d_model`` linear layers.

    Args:
        d_model: width of each token encoding and of each projection.
        row_dim: dimension swapped with ``col_dim`` when transposing the keys.
        col_dim: dimension the softmax normalizes over; its size on the keys is
            used as d_k for scaling.

    In this notebook ``token_encodings`` is a 2-D (tokens, d_model) tensor used
    with the defaults row_dim=0, col_dim=1.
    """

    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()

        def projection():
            return nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        # Creation order q, k, v fixes which random weights each layer receives
        # after torch.manual_seed, so keep it.
        self.W_q = projection()
        self.W_k = projection()
        self.W_v = projection()
        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, token_encodings):
        """Return the attention-weighted values, one row per input token."""
        queries = self.W_q(token_encodings)
        keys = self.W_k(token_encodings)
        values = self.W_v(token_encodings)

        keys_t = keys.transpose(dim0=self.row_dim, dim1=self.col_dim)
        # Standard scaled dot-product attention: divide by sqrt(d_k).
        scale = torch.tensor(keys.size(self.col_dim) ** 0.5)
        scores = torch.matmul(queries, keys_t) / scale
        weights = F.softmax(scores, dim=self.col_dim)
        return torch.matmul(weights, values)

In [36]:
# One row per token, two features per token (d_model = 2).
encoding_matrix = torch.tensor([
    [1.16, 0.23],
    [0.57, 1.36],
    [4.41, -2.16],
])

In [37]:
# Seed before building the module so its initial weights are reproducible.
# The trailing ';' suppresses the returned Generator's repr in the cell output.
torch.manual_seed(42);

<torch._C.Generator at 0x7fa54422cd30>

In [38]:
# Built right after seeding, so W_q, W_k and W_v get reproducible initial weights.
selfAttention = SelfAttention(d_model=2,
                              row_dim=0,
                              col_dim=1)

In [39]:
# Self-attention output: one attention-weighted row per token.
selfAttention(token_encodings=encoding_matrix)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

In [19]:
# nn.Linear stores its weight as (out_features, in_features); the transpose is the
# matrix the encodings are right-multiplied by.
selfAttention.W_q.weight.transpose(dim0=0, dim1=1)

tensor([[ 0.5406, -0.1657],
        [ 0.5869,  0.6496]], grad_fn=<TransposeBackward0>)

In [20]:
# Key projection weights, transposed to the (in, out) orientation.
selfAttention.W_k.weight.transpose(dim0=0, dim1=1)

tensor([[-0.1549, -0.3443],
        [ 0.1427,  0.4153]], grad_fn=<TransposeBackward0>)

In [21]:
# Value projection weights, transposed to the (in, out) orientation.
selfAttention.W_v.weight.transpose(dim0=0, dim1=1)

tensor([[ 0.6233,  0.6146],
        [-0.5188,  0.1323]], grad_fn=<TransposeBackward0>)

In [23]:
# Queries: every token encoding projected through W_q.
selfAttention.W_q(input=encoding_matrix)

tensor([[ 0.7621, -0.0428],
        [ 1.1063,  0.7890],
        [ 1.1164, -2.1336]], grad_fn=<MmBackward0>)

In [25]:
# Keys: every token encoding projected through W_k.
selfAttention.W_k(input=encoding_matrix)

tensor([[-0.1469, -0.3038],
        [ 0.1057,  0.3685],
        [-0.9914, -2.4152]], grad_fn=<MmBackward0>)

In [26]:
# Values: every token encoding projected through W_v.
selfAttention.W_v(input=encoding_matrix)

tensor([[ 0.6038,  0.7434],
        [-0.3502,  0.5303],
        [ 3.8695,  2.4246]], grad_fn=<MmBackward0>)

In [27]:
# Keep the queries for the step-by-step walkthrough below.
q = selfAttention.W_q(input=encoding_matrix)
q

tensor([[ 0.7621, -0.0428],
        [ 1.1063,  0.7890],
        [ 1.1164, -2.1336]], grad_fn=<MmBackward0>)

In [29]:
# Keep the keys for the step-by-step walkthrough below.
k = selfAttention.W_k(input=encoding_matrix)
k

tensor([[-0.1469, -0.3038],
        [ 0.1057,  0.3685],
        [-0.9914, -2.4152]], grad_fn=<MmBackward0>)

In [30]:
# Entry (i, j) is the dot product of token i's query with token j's key.
sims = q @ k.transpose(0, 1)
sims

tensor([[-0.0990,  0.0648, -0.6523],
        [-0.4022,  0.4078, -3.0024],
        [ 0.4842, -0.6683,  4.0461]], grad_fn=<MmBackward0>)

In [31]:
# Scale by sqrt(d_k), taking d_k from the keys (as SelfAttention.forward does)
# rather than a hard-coded 2, so this stays right if d_model changes.
scaled_sims = sims / torch.tensor(k.size(1) ** 0.5)
scaled_sims

tensor([[-0.0700,  0.0458, -0.4612],
        [-0.2844,  0.2883, -2.1230],
        [ 0.3424, -0.4725,  2.8610]], grad_fn=<DivBackward0>)

In [32]:
# Normalize each query's row so its attention weights sum to 1.
attention_percents = scaled_sims.softmax(dim=1)
attention_percents

tensor([[0.3573, 0.4011, 0.2416],
        [0.3410, 0.6047, 0.0542],
        [0.0722, 0.0320, 0.8959]], grad_fn=<SoftmaxBackward0>)

In [34]:
# Weighting the values by hand reproduces the selfAttention(encoding_matrix) output.
attention_percents @ selfAttention.W_v(encoding_matrix)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

In [40]:
## Masked Attention

In [42]:
class MaskedSelfAttention(nn.Module):
    """Scaled dot-product self-attention with an optional boolean mask.

    Queries, keys and values are projected from the same ``token_encodings`` by
    three bias-free ``d_model -> d_model`` linear layers.

    Args:
        d_model: width of each token encoding and of each projection.
        row_dim: dimension swapped with ``col_dim`` when transposing the keys.
        col_dim: dimension the softmax normalizes over; its size on the keys is
            used as d_k for scaling.
    """

    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()

        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, token_encodings, mask=None):
        """Return attention-weighted values.

        Args:
            token_encodings: input encodings (2-D (tokens, d_model) in this notebook).
            mask: optional boolean tensor broadcastable to the score matrix; True
                marks a (query, key) pair that must receive zero attention.
        """
        q = self.W_q(token_encodings)
        k = self.W_k(token_encodings)
        v = self.W_v(token_encodings)

        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))
        scaled_sims = sims / torch.tensor(k.size(self.col_dim) ** 0.5)
        if mask is not None:
            # Fill with the most negative finite value instead of -inf. Masked keys
            # still get exactly zero weight, but a row whose keys are ALL masked
            # becomes uniform instead of NaN (softmax over all -inf is 0/0).
            scaled_sims = scaled_sims.masked_fill(
                mask=mask, value=torch.finfo(scaled_sims.dtype).min
            )
        attention_percent = F.softmax(scaled_sims, dim=self.col_dim)
        attention = torch.matmul(attention_percent, v)

        return attention

In [43]:
encoding_matrix = torch.tensor([[1.16, 0.23], [0.57, 1.36], [4.41, -2.16]])
torch.manual_seed(42)

maskedSelfAttention = MaskedSelfAttention(d_model=2,
                                          row_dim=0,
                                          col_dim=1)

# Causal mask sized from the input rather than a hard-coded 3: True marks keys
# that come after the query token and must be hidden from it.
n_tokens = encoding_matrix.size(0)
mask = torch.tril(torch.ones(n_tokens, n_tokens)) == 0
mask


tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])

In [45]:
# Each token attends only to itself and earlier tokens.
maskedSelfAttention(encoding_matrix, mask=mask)

tensor([[ 0.6038,  0.7434],
        [-0.0062,  0.6072],
        [ 3.4989,  2.2427]], grad_fn=<MmBackward0>)

In [46]:
# Same seed as before, so these match the unmasked module's W_q.
maskedSelfAttention.W_q.weight.transpose(dim0=0, dim1=1)

tensor([[ 0.5406, -0.1657],
        [ 0.5869,  0.6496]], grad_fn=<TransposeBackward0>)

In [47]:
# Key projection weights, transposed to the (in, out) orientation.
maskedSelfAttention.W_k.weight.transpose(dim0=0, dim1=1)

tensor([[-0.1549, -0.3443],
        [ 0.1427,  0.4153]], grad_fn=<TransposeBackward0>)

In [48]:
# Value projection weights, transposed to the (in, out) orientation.
maskedSelfAttention.W_v.weight.transpose(dim0=0, dim1=1)

tensor([[ 0.6233,  0.6146],
        [-0.5188,  0.1323]], grad_fn=<TransposeBackward0>)

In [53]:
# Queries for the masked walkthrough.
q = maskedSelfAttention.W_q(input=encoding_matrix)
q

tensor([[ 0.7621, -0.0428],
        [ 1.1063,  0.7890],
        [ 1.1164, -2.1336]], grad_fn=<MmBackward0>)

In [54]:
# Keys for the masked walkthrough.
k = maskedSelfAttention.W_k(input=encoding_matrix)
k

tensor([[-0.1469, -0.3038],
        [ 0.1057,  0.3685],
        [-0.9914, -2.4152]], grad_fn=<MmBackward0>)

In [55]:
# Values for the masked walkthrough.
v = maskedSelfAttention.W_v(input=encoding_matrix)
v

tensor([[ 0.6038,  0.7434],
        [-0.3502,  0.5303],
        [ 3.8695,  2.4246]], grad_fn=<MmBackward0>)

In [56]:
# Query/key dot products, before masking.
sims = q @ k.transpose(0, 1)
sims

tensor([[-0.0990,  0.0648, -0.6523],
        [-0.4022,  0.4078, -3.0024],
        [ 0.4842, -0.6683,  4.0461]], grad_fn=<MmBackward0>)

In [57]:
# Scale by sqrt(d_k) taken from the keys instead of a hard-coded 2.
scaled_sims = sims / torch.tensor(k.size(1) ** 0.5)

In [58]:
# A huge negative score stands in for -inf: masked entries get ~0 weight after softmax.
masked_scaled_sims = scaled_sims.masked_fill(mask, -1e9)
masked_scaled_sims

tensor([[-6.9975e-02, -1.0000e+09, -1.0000e+09],
        [-2.8442e-01,  2.8833e-01, -1.0000e+09],
        [ 3.4241e-01, -4.7253e-01,  2.8610e+00]],
       grad_fn=<MaskedFillBackward0>)

In [59]:
# Row-wise softmax; masked (future) positions come out as 0.
attention_percents = masked_scaled_sims.softmax(dim=1)
attention_percents

tensor([[1.0000, 0.0000, 0.0000],
        [0.3606, 0.6394, 0.0000],
        [0.0722, 0.0320, 0.8959]], grad_fn=<SoftmaxBackward0>)

In [61]:
# Weighting the values by hand reproduces maskedSelfAttention(encoding_matrix, mask).
attention_percents @ maskedSelfAttention.W_v(encoding_matrix)

tensor([[ 0.6038,  0.7434],
        [-0.0062,  0.6072],
        [ 3.4989,  2.2427]], grad_fn=<MmBackward0>)

In [63]:
class Attention(nn.Module):
    """Scaled dot-product attention with separate query / key / value inputs.

    Passing the same encodings for all three gives self-attention; passing
    different encodings for the query than for the keys/values gives
    encoder-decoder style attention.

    Args:
        d_model: width of the encodings and of each bias-free projection.
        row_dim: dimension swapped with ``col_dim`` when transposing the keys.
        col_dim: dimension the softmax normalizes over; its size on the keys is
            used as d_k for scaling.
    """

    def __init__(self, d_model=2,
                 row_dim=0,
                 col_dim=1):

        super().__init__()

        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        """Return the attention-weighted values, one row per query.

        Args:
            encodings_for_q: encodings projected into queries.
            encodings_for_k: encodings projected into keys.
            encodings_for_v: encodings projected into values.
            mask: optional boolean tensor broadcastable to the (query, key) score
                matrix; True entries receive zero attention.
        """
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)

        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))

        scaled_sims = sims / torch.tensor(k.size(self.col_dim) ** 0.5)

        if mask is not None:
            # Use the dtype's most negative finite value. A fixed -1e9 is not
            # representable in float16 and makes masked_fill raise an overflow
            # error; in float32 the softmax result is unchanged.
            scaled_sims = scaled_sims.masked_fill(
                mask=mask, value=torch.finfo(scaled_sims.dtype).min
            )

        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)

        attention_output = torch.matmul(attention_percents, v)

        return attention_output

In [64]:

encodings_for_q = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])

## keys and values use the same encodings as the queries (independent copies)
encodings_for_k = encodings_for_q.clone()
encodings_for_v = encodings_for_q.clone()

## set the seed for the random number generator
torch.manual_seed(42)

## create an attention object
attention = Attention(d_model=2,
                      row_dim=0,
                      col_dim=1)

## calculate attention: q, k and v share encodings here, so this is self-attention
## (pass different encodings for q than for k/v to get encoder-decoder attention)
attention(encodings_for_q, encodings_for_k, encodings_for_v)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

In [67]:
class MultiHeadAttention(nn.Module):
    """Run several independent Attention heads and concatenate their outputs.

    Each head has its own d_model -> d_model projections. Head outputs are
    concatenated along ``col_dim``, so the output width is d_model * num_heads.
    No output projection is applied after concatenation.
    """

    def __init__(self, d_model=2,
                 row_dim=0,
                 col_dim=1,
                 num_heads=1) -> None:
        super().__init__()
        # Heads are built in order, so under a fixed seed head 0 gets the same
        # initial weights a standalone Attention would.
        self.heads = nn.ModuleList(
            [Attention(d_model, row_dim, col_dim) for _ in range(num_heads)]
        )
        self.col_dim = col_dim

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        """Apply every head and concatenate the results along ``col_dim``.

        Args:
            encodings_for_q / encodings_for_k / encodings_for_v: forwarded to each head.
            mask: optional boolean mask forwarded unchanged to every head
                (True entries receive zero attention); default None keeps the
                previous unmasked behaviour.
        """
        head_outputs = [
            head(encodings_for_q, encodings_for_k, encodings_for_v, mask=mask)
            for head in self.heads
        ]
        return torch.cat(head_outputs, dim=self.col_dim)

In [68]:
## set the seed for the random number generator
torch.manual_seed(42)

## create an attention object
multiHeadAttention = MultiHeadAttention(d_model=2,
                                        row_dim=0,
                                        col_dim=1,
                                        num_heads=1)

## calculate encoder-decoder attention
multiHeadAttention(encodings_for_q, encodings_for_k, encodings_for_v)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<CatBackward0>)

In [69]:
## set the seed for the random number generator
torch.manual_seed(42)

## create an attention object
multiHeadAttention = MultiHeadAttention(d_model=2,
                                        row_dim=0,
                                        col_dim=1,
                                        num_heads=2)

## calculate encoder-decoder attention
multiHeadAttention(encodings_for_q, encodings_for_k, encodings_for_v)

tensor([[ 1.0100,  1.0641, -0.7081, -0.8268],
        [ 0.2040,  0.7057, -0.7417, -0.9193],
        [ 3.4989,  2.2427, -0.7190, -0.8447]], grad_fn=<CatBackward0>)