In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class Attention(nn.Module):
    """Single-head scaled dot-product attention with learned Q/K/V projections.

    Computes ``softmax(Q @ K^T / sqrt(d)) @ V`` where Q, K and V are bias-free
    linear projections of the supplied encodings.

    Args:
        dim_model: Number of embedding values per token; used as both the
            input and output width of the three projections.
        row_dim: Index of the token (row) dimension of the inputs.
        col_dim: Index of the feature (column) dimension of the inputs. It is
            also the dimension the softmax is taken over and the dimension
            whose size supplies the ``sqrt(d)`` scale.

    With the defaults ``row_dim=0, col_dim=1`` the inputs are expected to be
    2-D tensors laid out as ``(tokens, dim_model)``.
    """

    def __init__(self, dim_model=2, row_dim=0, col_dim=1):
        super().__init__()
        self.q_weights = nn.Linear(in_features=dim_model, out_features=dim_model, bias=False)
        self.k_weights = nn.Linear(in_features=dim_model, out_features=dim_model, bias=False)
        self.v_weights = nn.Linear(in_features=dim_model, out_features=dim_model, bias=False)

        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, q_encodings, k_encodings, v_encodings, mask=None):
        """Return the attention output for the given encodings.

        Args:
            q_encodings: Encodings projected into queries.
            k_encodings: Encodings projected into keys.
            v_encodings: Encodings projected into values.
            mask: Optional boolean tensor broadcastable to the similarity
                matrix. Positions that are ``True`` are hidden from the
                softmax (this is what enables masked self-attention).

        Returns:
            Tensor with one row per query and ``dim_model`` columns.
        """
        q = self.q_weights(q_encodings)
        k = self.k_weights(k_encodings)
        v = self.v_weights(v_encodings)

        # Q @ K^T: similarity of every query with every key.
        similarities = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))
        # Divide by sqrt(d). A plain Python float keeps the scale in the
        # tensor's own dtype/device instead of allocating a float32 tensor.
        scaled_similarities = similarities / (k.size(self.col_dim) ** 0.5)

        if mask is not None:
            # -1e9 is kept wherever it is representable; for narrower dtypes
            # (float16) it would overflow inside masked_fill, so fall back to
            # the most negative finite value of the dtype.
            fill_value = max(-1e9, torch.finfo(scaled_similarities.dtype).min)
            scaled_similarities = scaled_similarities.masked_fill(mask=mask, value=fill_value)

        # Normalise each query's similarities into attention weights.
        attention_percents = F.softmax(scaled_similarities, dim=self.col_dim)
        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores




In [3]:
# Encodings for three tokens with two features each.
# Queries, keys and values each get their own tensor because, in general, the
# three inputs to attention may differ; here they hold the same values.
encodings_for_q = torch.tensor([
    [1.16, 0.23],
    [0.57, 1.36],
    [4.41, -2.16],
])
encodings_for_k = encodings_for_q.clone()
encodings_for_v = encodings_for_q.clone()

In [4]:
# Fix the global RNG seed so the nn.Linear weights created in the next cell
# (and therefore the printed attention values) are reproducible.
torch.manual_seed(42)

<torch._C.Generator at 0x28e95a743b0>

In [5]:
# Single-head attention for 2-feature encodings; row_dim=0 / col_dim=1 match
# the 2-D (token, feature) layout of the encoding tensors defined above.
attention = Attention(dim_model=2, row_dim=0, col_dim=1)

In [16]:

# Causal mask for a 3-token sequence: True marks the positions to hide, i.e.
# the strict upper triangle (every key that comes after the query).
# NOTE(review): this cell's execution count (16) is out of order, and `mask`
# is not passed to any attention call in the cells visible here.
mask = torch.triu(torch.ones(3, 3), diagonal=1).bool()

mask

tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])

In [6]:
# Unmasked attention: no mask argument is given, so every token attends to
# all three tokens. The result has one row per token and dim_model columns.
attention(encodings_for_q, 
          encodings_for_k,
           encodings_for_v)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

In [7]:
class MultiHead(nn.Module):
    """Multi-head attention built from independent ``Attention`` heads.

    Every head receives the same inputs and owns its own Q/K/V projections;
    the per-head outputs are concatenated along ``col_dim``.

    Args:
        dim_model: Number of embedding values per token, forwarded to each head.
        row_dim: Index of the token (row) dimension, forwarded to each head.
        col_dim: Index of the feature (column) dimension, forwarded to each
            head and used as the concatenation dimension.
        heads_num: Number of attention heads to create; must be at least 1.

    Raises:
        ValueError: If ``heads_num`` is smaller than 1 (there would be nothing
            to concatenate in ``forward``).
    """

    def __init__(self, dim_model=2, row_dim=0, col_dim=1, heads_num=1):
        super().__init__()

        if heads_num < 1:
            raise ValueError(f"heads_num must be >= 1, got {heads_num}")

        # ModuleList registers each head so its parameters are tracked.
        self.heads = nn.ModuleList(
            [Attention(dim_model, row_dim, col_dim) for _ in range(heads_num)]
        )

        self.col_dim = col_dim

    def forward(self, q_encodings, k_encodings, v_encodings, mask=None):
        """Run every head on the same inputs and concatenate the results.

        Args:
            q_encodings: Encodings used for the queries.
            k_encodings: Encodings used for the keys.
            v_encodings: Encodings used for the values.
            mask: Optional mask handed unchanged to every head; ``None``
                (the default) means unmasked attention.

        Returns:
            The head outputs concatenated along ``col_dim``.
        """
        head_outputs = [
            head(q_encodings, k_encodings, v_encodings, mask=mask)
            for head in self.heads
        ]
        return torch.cat(head_outputs, dim=self.col_dim)

In [8]:
# Re-seed so the head created in the next cell draws the same initial weights
# as the standalone `attention` module above.
torch.manual_seed(42)

<torch._C.Generator at 0x28e95a743b0>

In [9]:
# Sanity check: with a single head (and the same seed) MultiHead reproduces
# the values printed for the plain Attention module above.
multi_head_attention = MultiHead(dim_model=2, row_dim=0, col_dim=1, heads_num=1) # heads_num=1 gives the same result as single-head attention
multi_head_attention(encodings_for_q, encodings_for_k, encodings_for_v)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<CatBackward0>)

In [11]:
# Re-seed again so the first head of the two-head model below starts from the
# same weights as the single-head runs above.
torch.manual_seed(42)

<torch._C.Generator at 0x28e95a743b0>

In [12]:
# Two heads: each head emits dim_model (2) columns and the outputs are
# concatenated along col_dim, so the result has 4 columns. The first two
# columns match the single-head output; the last two come from the second head.
multi_head_attention2 = MultiHead(dim_model=2, row_dim=0, col_dim=1, heads_num=2) # heads_num=2 -> output has 2 * dim_model columns
multi_head_attention2(encodings_for_q, encodings_for_k, encodings_for_v)

tensor([[ 1.0100,  1.0641, -0.7081, -0.8268],
        [ 0.2040,  0.7057, -0.7417, -0.9193],
        [ 3.4989,  2.2427, -0.7190, -0.8447]], grad_fn=<CatBackward0>)