In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [26]:
class SingleAttention(nn.Module):
    """Single-head scaled dot-product attention with its own Q/K/V and output projections.

    Args:
        emb_dim: size of the last dimension of ``query``/``key``/``value`` and of the output.
        model_dim: size each of query, key and value is projected to before attention.
        dropout: dropout probability applied to the final (output-projected) result.
    """

    def __init__(self, emb_dim, model_dim, dropout=0.1):
        super().__init__()
        self.emb_dim = emb_dim
        self.model_dim = model_dim

        # One [Q, K, V] triple stored inside a ModuleList; this layout keeps the
        # parameter names (qkv_weights.0.{0,1,2}.*) identical to the original notebook.
        q_weight = nn.Linear(emb_dim, model_dim)
        k_weight = nn.Linear(emb_dim, model_dim)
        v_weight = nn.Linear(emb_dim, model_dim)
        self.qkv_weights = nn.ModuleList([nn.ModuleList([q_weight, k_weight, v_weight])])

        # Maps the concatenated attention output back to the embedding size.
        self.out_proj = nn.Linear(model_dim, emb_dim)

        # Applied to the final projected output (not to the attention weights).
        self.drop = nn.Dropout(dropout)

    def attention_op(self, q, k, v):
        """Return ``softmax(q @ k^T / sqrt(d_k)) @ v``.

        ``d_k`` is taken from the last dimension of ``k`` instead of a notebook
        global, so the module works no matter what is defined in the namespace.
        Inputs are treated as ``(..., seq_len, features)`` (the last two dims are
        transposed for the dot product).
        """
        d_k = k.size(-1)
        # Dot product between every query and every key, scaled to keep softmax well-conditioned.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        weights = F.softmax(scores, dim=-1)  # normalise over the key positions
        return torch.matmul(weights, v)

    def forward(self, query, key, value):
        """Project the inputs, attend, project back to ``emb_dim`` and apply dropout.

        Args:
            query: tensor of shape ``(..., q_len, emb_dim)``.
            key: tensor of shape ``(..., kv_len, emb_dim)``.
            value: tensor of shape ``(..., kv_len, emb_dim)``.

        Returns:
            Tensor of shape ``(..., q_len, emb_dim)``.
        """
        heads = [
            self.attention_op(q_proj(query), k_proj(key), v_proj(value))
            for q_proj, k_proj, v_proj in self.qkv_weights
        ]
        output = torch.cat(heads, dim=-1)  # concatenate along the feature dimension
        return self.drop(self.out_proj(output))

In [27]:
# Seed the RNG so the random input below is identical on every run.
torch.manual_seed(0)

# Toy sizes: embedding width, projection width inside attention, tokens per sequence.
emb_dim, model_dim, seq_len = 512, 512, 10

# One batch of `seq_len` random token embeddings: shape (1, seq_len, emb_dim).
test = torch.rand(1, seq_len, emb_dim)
test

tensor([[[0.4963, 0.7682, 0.0885,  ..., 0.6673, 0.3561, 0.8091],
         [0.3613, 0.3136, 0.6259,  ..., 0.1876, 0.2099, 0.7210],
         [0.4650, 0.0278, 0.2117,  ..., 0.5025, 0.4458, 0.2083],
         ...,
         [0.1100, 0.0771, 0.6113,  ..., 0.7174, 0.6193, 0.0636],
         [0.8637, 0.4471, 0.2902,  ..., 0.9446, 0.1363, 0.9336],
         [0.9479, 0.9039, 0.5435,  ..., 0.6666, 0.7545, 0.5523]]])

In [28]:
# Dropout off so the comparison is deterministic (nn.MultiheadAttention defaults to dropout=0.0).
single = SingleAttention(emb_dim, model_dim, dropout=0.0)

# batch_first=True so torch reads `test` as (batch, seq, emb) like SingleAttention does.
# With the default (seq, batch, emb) layout the (1, 10, 512) input would be treated as
# 10 independent length-1 sequences and the output shapes would only match by coincidence.
expected = nn.MultiheadAttention(emb_dim, num_heads=1, batch_first=True)

# Share weights so the two outputs can be compared numerically, not just by shape.
# nn.MultiheadAttention packs Q/K/V (in that order) into one in_proj matrix of shape
# (3 * emb_dim, emb_dim), which only lines up when model_dim == emb_dim.
assert model_dim == emb_dim, "weight sharing with nn.MultiheadAttention needs model_dim == emb_dim"
with torch.no_grad():
    q_lin, k_lin, v_lin = single.qkv_weights[0]
    expected.in_proj_weight.copy_(torch.cat([q_lin.weight, k_lin.weight, v_lin.weight]))
    expected.in_proj_bias.copy_(torch.cat([q_lin.bias, k_lin.bias, v_lin.bias]))
    expected.out_proj.weight.copy_(single.out_proj.weight)
    expected.out_proj.bias.copy_(single.out_proj.bias)

In [29]:
# Self-attention on the same input through both modules, custom one first.
# nn.MultiheadAttention returns a tuple whose first element is the attention output.
out_mine, out_torch = (module(test, test, test) for module in (single, expected))

In [37]:
# Show both output shapes; out_torch is a tuple, so its first element is the output tensor.
print("mine: ", out_mine.shape)
print("torch: ", out_torch[0].shape)

# Turn the visual check into an invariant so a regression fails the notebook run.
assert out_mine.shape == out_torch[0].shape, "output shapes differ"

mine:  torch.Size([1, 10, 512])
torch:  torch.Size([1, 10, 512])
