In [1]:
import math

import torch
from torch import nn
import torch.nn.functional as F


class Attention(nn.Module):
    """Single-head scaled dot-product attention with learned Q/K/V projections.

    Parameters
    ----------
    d_model : int
        Width of each input encoding. The three bias-free linear projections
        all map ``d_model -> d_model``.
    row_dim : int
        Dimension index treated as "rows" when transposing the keys.
    col_dim : int
        Dimension index treated as "columns"; used for the key transpose,
        for the scaling factor and as the softmax dimension.
    """

    def __init__(self,
                 d_model=2,
                 row_dim=0,
                 col_dim=1):
        super().__init__()

        # Learned projections. Creation order (q, k, v) is kept stable so that
        # seeded initialisation stays reproducible.
        self.W_q = nn.Linear(in_features=d_model,
                             out_features=d_model,
                             bias=False)
        self.W_k = nn.Linear(in_features=d_model,
                             out_features=d_model,
                             bias=False)
        self.W_v = nn.Linear(in_features=d_model,
                             out_features=d_model,
                             bias=False)

        self.row_dim = row_dim
        self.col_dim = col_dim
        self.d_model = d_model

    def forward(self,
                encodings_q,
                encodings_k,
                encodings_v,
                mask=None):
        """Compute attention output for the given encodings.

        Parameters
        ----------
        encodings_q, encodings_k, encodings_v : torch.Tensor
            Inputs to the query, key and value projections respectively.
        mask : torch.Tensor, optional
            Boolean tensor broadcastable to the similarity matrix. Positions
            where ``mask`` is ``True`` are suppressed before the softmax.

        Returns
        -------
        torch.Tensor
            ``softmax(q @ k^T / sqrt(k.size(col_dim))) @ v``.
        """
        q = self.W_q(encodings_q)
        k = self.W_k(encodings_k)
        v = self.W_v(encodings_v)

        # Similarity of every query with every key.
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim,
                                           dim1=self.col_dim))

        # Scale by sqrt of the key width to keep the softmax well-conditioned.
        scaled_sims = sims / (k.size(self.col_dim) ** 0.5)

        if mask is not None:
            # A large *negative* number drives the softmax weight of masked
            # positions to zero. (The previous value, -1e-9, is ~0 and left
            # masked positions with substantial weight.) A finite value is
            # used instead of -inf so a fully masked row does not become NaN.
            scaled_sims = scaled_sims.masked_fill(mask=mask,
                                                  value=-1e9)

        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)
        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores

    
class MultiHeadAttention(nn.Module):
    """Several independent ``Attention`` heads whose outputs are concatenated.

    Parameters
    ----------
    d_model : int
        Encoding width handed to every head.
    row_dim : int
        Row dimension index handed to every head.
    col_dim : int
        Column dimension index handed to every head; also the dimension
        along which the head outputs are concatenated.
    num_heads : int
        Number of ``Attention`` heads to create.
    """

    def __init__(self,
                 d_model=2,
                 row_dim=0,
                 col_dim=1,
                 num_heads=1):
        super().__init__()

        # ModuleList so each head's parameters are registered with this module.
        self.heads = nn.ModuleList(
            [Attention(d_model=d_model, row_dim=row_dim, col_dim=col_dim)
             for _ in range(num_heads)]
        )

        self.col_dim = col_dim

    def forward(self,
                encodings_q,
                encodings_k,
                encodings_v,
                mask=None):
        """Run every head on the same inputs and concatenate along ``col_dim``.

        Parameters
        ----------
        encodings_q, encodings_k, encodings_v : torch.Tensor
            Passed unchanged to each head.
        mask : torch.Tensor, optional
            Passed unchanged to each head; ``None`` (the default) means no
            masking, which matches the previous behaviour of this method.

        Returns
        -------
        torch.Tensor
            Head outputs concatenated along ``col_dim``.
        """
        head_outputs = [
            head(encodings_q, encodings_k, encodings_v, mask=mask)
            for head in self.heads
        ]
        return torch.cat(head_outputs, dim=self.col_dim)

In [2]:
# Query, key and value inputs hold identical values; each is its own tensor.
encodings_q = torch.tensor([
    [1.3921, 1.2440],
    [1.2494, 1.1960],
    [3.4989, 2.2427],
])
encodings_k = encodings_q.clone()
encodings_v = encodings_q.clone()

# Fix the global RNG seed so the weight initialisation below is reproducible.
torch.manual_seed(42)

<torch._C.Generator at 0x21dff2b7d30>

In [3]:
# Three heads over 2-wide encodings; head outputs are concatenated on dim 1.
mha = MultiHeadAttention(d_model=2, row_dim=0, col_dim=1, num_heads=3)

In [4]:
# The same encodings are used as queries, keys and values.
mha(encodings_q=encodings_q,
    encodings_k=encodings_q,
    encodings_v=encodings_q)

tensor([[ 0.4082,  1.3614, -0.9187,  0.1218,  1.3103,  0.2117],
        [ 0.4108,  1.3661, -0.9077,  0.1238,  1.3070,  0.2119],
        [ 0.3628,  1.2806, -1.0814,  0.0919,  1.3522,  0.2101]],
       grad_fn=<CatBackward0>)

In [5]:
class PositionalEmbeddings(nn.Module):
    """Fixed sinusoidal position encodings added to token encodings.

    For position ``p`` and even column index ``2i``::

        pe[p, 2i]     = sin(p / 10000 ** (2i / d_model))
        pe[p, 2i + 1] = cos(p / 10000 ** (2i / d_model))

    Parameters
    ----------
    d_model : int
        Width of each encoding (number of columns of the table).
    max_seq_length : int
        Number of positions (rows) precomputed in the table.
    """

    def __init__(self,
                 d_model,
                 max_seq_length=512):
        super().__init__()

        pe = torch.zeros(max_seq_length, d_model)

        # Column vector of positions 0 .. max_seq_length - 1.
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)

        # One frequency per even column index: 10000 ** (-2i / d_model),
        # computed through exp/log.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns,
        # so drop the surplus frequency before filling the cosine slots.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Buffer: follows the module across devices and into state_dict,
        # but is not a trainable parameter.
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``x`` with the position table added.

        Parameters
        ----------
        x : torch.Tensor
            Tensor whose last two dimensions are ``(seq_len, d_model)``;
            any leading dimensions are broadcast over.

        Raises
        ------
        ValueError
            If ``seq_len`` exceeds the number of precomputed positions.
        """
        seq_len = x.size(-2)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_length {self.pe.size(0)}"
            )
        return x + self.pe[:seq_len]

        

In [8]:
import math 
d_model = 2

# Frequency for each even column index i: exp(-i * ln(10000) / d_model).
div_term = torch.exp(
    -(math.log(10000.0) / d_model) * torch.arange(0, d_model, 2).float()
)


In [9]:
# With d_model = 2 the only even index is 0, so the single frequency is exp(0) = 1.
div_term

tensor([1.])