In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class MaskedSelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with an optional mask.

    Queries, keys and values are produced from the same token encodings by
    three bias-free linear layers, then combined as
    ``softmax(Q @ K^T / sqrt(d_k) [+ mask]) @ V``.

    Args:
        dim_model: Number of embedding values per token (input and output
            width of the Q/K/V projections).
        row_dim: Index of the token (sequence) axis of the encodings.
        col_dim: Index of the embedding axis of the encodings.
    """

    def __init__(self, dim_model=2, row_dim = 0, col_dim=1):
        super().__init__()
        # NOTE: the creation order (q, k, v) determines which random numbers
        # each layer receives under a fixed seed -- keep it stable.
        self.q_weights = nn.Linear(in_features=dim_model, out_features=dim_model, bias=False)
        self.k_weights = nn.Linear(in_features=dim_model, out_features=dim_model, bias=False)
        self.v_weights = nn.Linear(in_features=dim_model, out_features=dim_model, bias=False)

        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, token_encodings, mask=None):
        """Compute attention output for ``token_encodings``.

        Args:
            token_encodings: Floating-point tensor whose ``row_dim`` axis
                indexes tokens and whose ``col_dim`` axis has ``dim_model``
                entries.
            mask: Optional boolean tensor broadcastable to the
                (tokens x tokens) similarity matrix. ``True`` marks a
                query/key pair that must NOT be attended to. ``None``
                (default) disables masking.

        Returns:
            Tensor with the same shape as ``token_encodings`` holding the
            attention-weighted values.
        """
        q = self.q_weights(token_encodings)
        k = self.k_weights(token_encodings)
        v = self.v_weights(token_encodings)

        # Q @ K^T: similarity of every query token with every key token.
        similarities = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))

        # Scale by sqrt(d_k). A plain Python float follows the tensor's dtype,
        # whereas wrapping it in torch.tensor() would round it to float32.
        scaled_similarities = similarities / (k.size(self.col_dim) ** 0.5)

        if mask is not None:
            # Push blocked positions to the most negative finite value so they
            # get ~0 probability after softmax. A finite fill (instead of -inf)
            # keeps a fully masked row from turning into NaN.
            blocked_value = torch.finfo(scaled_similarities.dtype).min
            scaled_similarities = scaled_similarities.masked_fill(mask, blocked_value)

        attention_percents = F.softmax(scaled_similarities, dim=self.col_dim)
        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores




In [None]:
# Toy input: one row per token (3 tokens), one column per embedding value (2).
encodings_mat = torch.tensor(
    [
        [1.16, 0.23],
        [0.57, 1.36],
        [4.41, -2.16],
    ]
)

encodings_mat

tensor([[ 1.1600,  0.2300],
        [ 0.5700,  1.3600],
        [ 4.4100, -2.1600]])

In [None]:
# Fix PyTorch's global RNG seed so the randomly initialised nn.Linear weights
# created when the attention module is instantiated are the same on every run.
torch.manual_seed(42)

<torch._C.Generator at 0x1590d5b5390>

In [None]:
# Build the attention layer for 2-value embeddings; row_dim/col_dim name the
# token axis and the embedding axis of the 2-D encodings matrix.
masked_self_attention = MaskedSelfAttention(dim_model=2, row_dim=0, col_dim=1)

# NOTE: no mask is passed here, so this call computes plain (unmasked) self-attention.
masked_self_attention(encodings_mat)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

In [None]:
# nn.Linear stores `weight` as (out_features, in_features); transpose(0, 1)
# shows each matrix in the orientation that right-multiplies the encodings.
print(f"""Q Weights: {masked_self_attention.q_weights.weight.transpose(0,1)}
K Weights: {masked_self_attention.k_weights.weight.transpose(0,1)}
V Weights: {masked_self_attention.v_weights.weight.transpose(0,1)}""")

Q Weights: tensor([[ 0.5406, -0.1657],
        [ 0.5869,  0.6496]], grad_fn=<TransposeBackward0>)
K Weights: tensor([[-0.1549, -0.3443],
        [ 0.1427,  0.4153]], grad_fn=<TransposeBackward0>)
V Weights: tensor([[ 0.6233,  0.6146],
        [-0.5188,  0.1323]], grad_fn=<TransposeBackward0>)
