# Import modules

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F # access the softmax function for calculating the attention weights





# Self-Attention class explanation
- `nn.Module` is the base class for all neural network modules that you build with PyTorch.
- `SelfAttention` inherits from this base class.
- `d_model` is the dimension of the model, in other words the number of Word Embedding values per token.
We use it to define the dimensions of the weight matrices for the Queries, Keys and Values. In this case they are going to have dimension 2x2.

- token_encodings are Word Embeddings + Positional Encoding for each input token




In [45]:
class SelfAttention(nn.Module):
    """Basic scaled dot-product self-attention (one set of weights, no mask).

    Each token encoding is projected into a query, a key and a value by three
    bias-free linear layers of size ``d_model x d_model``. The result for a
    token is the softmax-weighted sum of the value vectors, the weights being
    derived from the scaled query/key dot products.

    Args:
        d_model: Number of values per token encoding; used as both the input
            and the output width of the three projections.
        row_dim: Index of the row dimension of the encodings; swapped with
            ``col_dim`` when the keys are transposed.
        col_dim: Index of the column dimension of the encodings; also the
            dimension the softmax normalises over and whose size sets the
            scaling factor.
    """

    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()

        # Keep the q -> k -> v creation order: every nn.Linear draws its
        # initial weights from the global RNG, so the order decides which
        # random values each matrix receives for a given seed.
        projection_args = dict(in_features=d_model, out_features=d_model, bias=False)
        self.W_q = nn.Linear(**projection_args)  # produces the queries
        self.W_k = nn.Linear(**projection_args)  # produces the keys
        self.W_v = nn.Linear(**projection_args)  # produces the values

        self.row_dim, self.col_dim = row_dim, col_dim

    def forward(self, token_encodings):
        """Return the self-attention values for ``token_encodings``.

        Args:
            token_encodings: Tensor whose ``col_dim`` dimension has size
                ``d_model``.

        Returns:
            Tensor with the same shape as ``token_encodings`` holding the
            attention output for every token.
        """
        queries = self.W_q(token_encodings)
        keys = self.W_k(token_encodings)
        values = self.W_v(token_encodings)

        # Similarity of every query with every key: Q @ K^T.
        similarities = queries @ keys.transpose(self.row_dim, self.col_dim)

        # Scale by the square root of the key width before normalising.
        scale = torch.tensor(keys.size(self.col_dim) ** 0.5)

        # Softmax turns each row of scaled similarities into the fraction of
        # every token's value that contributes to the output.
        weights = F.softmax(similarities / scale, dim=self.col_dim)

        # Weighted sum of the value vectors.
        return weights @ values

In [55]:
# Token encodings (word embedding + positional encoding), one row per token.
encodings_matrix = torch.tensor(
    [
        [1.16, 0.23],
        [0.57, 1.36],
        [4.41, -2.16],
    ]
)

# Seed the random number generator *before* building the module so that the
# randomly initialised weight matrices are reproducible.
torch.manual_seed(42)

# Create a basic self-attention object for 2-dimensional encodings.
selfAttention = SelfAttention(d_model=2, row_dim=0, col_dim=1)

# Calculate basic self-attention for the token encodings.
selfAttention(encodings_matrix)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

# Verify calculations and print intermediate results

In [37]:
# Print the weight matrices that create the queries, keys and values (in that
# order). nn.Linear stores its weight as (out_features, in_features), so each
# one is transposed to show the matrix the encodings are multiplied by.
for projection in (selfAttention.W_q, selfAttention.W_k, selfAttention.W_v):
    print(projection.weight.transpose(0, 1))


tensor([[ 0.5406, -0.1657],
        [ 0.5869,  0.6496]], grad_fn=<TransposeBackward0>)
tensor([[-0.1549, -0.3443],
        [ 0.1427,  0.4153]], grad_fn=<TransposeBackward0>)
tensor([[ 0.6233,  0.6146],
        [-0.5188,  0.1323]], grad_fn=<TransposeBackward0>)


In [47]:
# Apply each projection to the token encodings to show the resulting
# query, key and value matrices (in that order).
for projection in (selfAttention.W_q, selfAttention.W_k, selfAttention.W_v):
    print(projection(encodings_matrix))




tensor([[ 0.7621, -0.0428],
        [ 1.1063,  0.7890],
        [ 1.1164, -2.1336]], grad_fn=<MmBackward0>)
tensor([[-0.1469, -0.3038],
        [ 0.1057,  0.3685],
        [-0.9914, -2.4152]], grad_fn=<MmBackward0>)
tensor([[ 0.6038,  0.7434],
        [-0.3502,  0.5303],
        [ 3.8695,  2.4246]], grad_fn=<MmBackward0>)


In [48]:
# Queries: the token encodings passed through the query projection W_q.
# Kept in `q` for the step-by-step verification in the following cells.
q = selfAttention.W_q(encodings_matrix)
q

tensor([[ 0.7621, -0.0428],
        [ 1.1063,  0.7890],
        [ 1.1164, -2.1336]], grad_fn=<MmBackward0>)

In [49]:
# Keys: the token encodings passed through the key projection W_k.
# Kept in `k` for the step-by-step verification in the following cells.
k = selfAttention.W_k(encodings_matrix)
k

tensor([[-0.1469, -0.3038],
        [ 0.1057,  0.3685],
        [-0.9914, -2.4152]], grad_fn=<MmBackward0>)

In [41]:
# Similarity scores: every query dotted with every key (Q @ K^T).
k_transposed = k.transpose(0, 1)
sims = q @ k_transposed
sims

tensor([[-0.0990,  0.0648, -0.6523],
        [-0.4022,  0.4078, -3.0024],
        [ 0.4842, -0.6683,  4.0461]], grad_fn=<MmBackward0>)

In [50]:
# Scale the similarities by the square root of the key width. The width is
# read from `k` (2 for these encodings) instead of being hard-coded, so this
# check uses the same formula as SelfAttention.forward and stays correct if
# the encoding size changes.
scaled_sims = sims / torch.tensor(k.size(1) ** 0.5)
scaled_sims

tensor([[-0.0700,  0.0458, -0.4612],
        [-0.2844,  0.2883, -2.1230],
        [ 0.3424, -0.4725,  2.8610]], grad_fn=<DivBackward0>)

In [51]:
# Softmax over each row (dim=1) turns the scaled similarities into the
# percentage of every token's value used for that row's output.
attention_percents = scaled_sims.softmax(dim=1)
attention_percents

tensor([[0.3573, 0.4011, 0.2416],
        [0.3410, 0.6047, 0.0542],
        [0.0722, 0.0320, 0.8959]], grad_fn=<SoftmaxBackward0>)

In [53]:
# Values: the token encodings passed through the value projection W_v.
v = selfAttention.W_v(encodings_matrix)

# Weight the values by the attention percentages and sum them up; this should
# reproduce the output of calling selfAttention(encodings_matrix) directly.
res = attention_percents @ v
res

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)