In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class SelfAttention(nn.Module):
  """Single-head scaled dot-product self-attention.

  Every token encoding is projected into a query, a key and a value by three
  bias-free linear layers. The query/key dot products are scaled, turned into
  percentages with softmax, and used to take a weighted sum of the values.
  """

  def __init__(self, d_model=2, row_dim=0, col_dim=1):
    """Create the query, key and value projection layers.

    Args:
      d_model: the number of embedding values per token. Used as both the
        input and the output width of each projection.
      row_dim: index of the dimension that holds the rows (tokens).
      col_dim: index of the dimension that holds the columns (embedding
        values).
    """
    super().__init__()
    ## d_model = the number of embedding values per token.
    self.d_model = d_model
    ## row_dim, col_dim = the indices we should use to access rows or columns
    self.row_dim = row_dim
    self.col_dim = col_dim
    ## Initialize the Weights (W) that we'll use to create the query (q),
    ## key (k) and value (v) for each token.
    ## NOTE: the layers are created in q, k, v order; with a fixed random seed,
    ## changing this order changes which initial weights each layer receives.
    self.Wq = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
    self.Wk = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
    self.Wv = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

  def forward(self, token_encodings):
    """Compute the self-attention values for a matrix of token encodings.

    Args:
      token_encodings: tensor whose `row_dim` indexes tokens and whose
        `col_dim` holds the `d_model` embedding values of each token.

    Returns:
      A tensor with the same shape as the value projection of
      `token_encodings`: for each token, the softmax-weighted sum of all
      tokens' values.
    """
    ## Create the query, key and values using the encoding numbers
    q = self.Wq(token_encodings)
    k = self.Wk(token_encodings)
    v = self.Wv(token_encodings)
    ## Compute similarity scores: (q * k^T)
    sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))
    ## Scale the similarities by dividing by sqrt(number of columns in k).
    ## Divide by a plain Python float rather than torch.tensor(...): a freshly
    ## built tensor would use the default dtype (float32) and round the scale
    ## factor for float64 inputs, and it would be re-allocated on every call.
    scaled_sims = sims / (k.size(self.col_dim) ** 0.5)
    ## Apply softmax to determine what percent of each token's value to use
    ## in the final attention values.
    attention_percents = F.softmax(scaled_sims, dim=self.col_dim)
    ## Scale the values by their associated percentages and add them up.
    attention_scores = torch.matmul(attention_percents, v)
    return attention_scores


In [3]:
## Seed the random number generator first so the randomly initialized
## weights of the attention layers are the same on every run.
torch.manual_seed(42)

## Token encodings: one row per token, one column per embedding value.
encodings_matrix = torch.tensor([
    [1.16, 0.23],
    [0.57, 1.36],
    [4.41, -2.16],
])

## Build the self-attention module and run the encodings through it.
selfAttention = SelfAttention(d_model=2, row_dim=0, col_dim=1)
selfAttention(encodings_matrix)


tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

In [5]:
## Print out the weight matrix that creates the queries.
## nn.Linear stores its weight as (out_features, in_features), so we transpose
## it to show the matrix that the token encodings are multiplied by.
selfAttention.Wq.weight.transpose(0, 1)

tensor([[ 0.5406, -0.1657],
        [ 0.5869,  0.6496]], grad_fn=<TransposeBackward0>)

In [6]:
## Print out the weight matrix that creates the keys.
## nn.Linear stores its weight as (out_features, in_features), so we transpose
## it to show the matrix that the token encodings are multiplied by.
selfAttention.Wk.weight.transpose(0, 1)

tensor([[-0.1549, -0.3443],
        [ 0.1427,  0.4153]], grad_fn=<TransposeBackward0>)

In [7]:
## Print out the weight matrix that creates the values.
## nn.Linear stores its weight as (out_features, in_features), so we transpose
## it to show the matrix that the token encodings are multiplied by.
selfAttention.Wv.weight.transpose(0, 1)

tensor([[ 0.6233,  0.6146],
        [-0.5188,  0.1323]], grad_fn=<TransposeBackward0>)

In [9]:
## Calculate the queries by passing the token encodings through the
## bias-free query layer (one row of query values per token).
selfAttention.Wq(encodings_matrix)

tensor([[ 0.7621, -0.0428],
        [ 1.1063,  0.7890],
        [ 1.1164, -2.1336]], grad_fn=<MmBackward0>)

In [10]:
## Calculate the keys by passing the token encodings through the
## bias-free key layer (one row of key values per token).
selfAttention.Wk(encodings_matrix)

tensor([[-0.1469, -0.3038],
        [ 0.1057,  0.3685],
        [-0.9914, -2.4152]], grad_fn=<MmBackward0>)

In [11]:
## Calculate the values by passing the token encodings through the
## bias-free value layer (one row of values per token).
selfAttention.Wv(encodings_matrix)

tensor([[ 0.6038,  0.7434],
        [-0.3502,  0.5303],
        [ 3.8695,  2.4246]], grad_fn=<MmBackward0>)