In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class Attention(nn.Module):
    """Single-head scaled dot-product attention.

    The query, key and value encodings are each projected by their own
    bias-free ``d_model -> d_model`` linear layer, and the block returns
    ``softmax(Q @ K^T / sqrt(d_k)) @ V`` with ``d_k = K.size(col_dim)``.

    Args:
        d_model: size of each encoding (the last dimension, which ``nn.Linear``
            acts on).
        row_dim: dimension of K swapped with ``col_dim`` when K is transposed.
        col_dim: dimension of K swapped with ``row_dim`` for the transpose. It
            is also the dimension the softmax is taken over and the one whose
            size is used for scaling. The defaults (0, 1) fit the unbatched 2-D
            inputs used in this notebook. Batched 3-D inputs presumably need
            (1, 2); that case is not exercised here.
    """

    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, encodings_for_q, encodings_for_k, ncodings_for_v, mask=None):
        """Attend from the query encodings to the key/value encodings.

        Args:
            encodings_for_q: encodings projected into the queries.
            encodings_for_k: encodings projected into the keys.
            ncodings_for_v: encodings projected into the values. The name is a
                typo for ``encodings_for_v``; it is kept so that keyword
                callers keep working.
            mask: optional boolean tensor that can be broadcast against the
                similarity matrix. ``True`` entries are excluded from attention.

        Returns:
            The attention-weighted combination of the projected values. For
            2-D inputs this is one row per query.
        """
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(ncodings_for_v)

        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))
        # A Python float avoids allocating a fresh CPU tensor on every call.
        scaled_sims = sims / (k.size(self.col_dim) ** 0.5)

        if mask is not None:
            # A hard-coded -1e9 is not representable in float16, and masked_fill
            # raises on overflow. The smallest finite value of the actual dtype
            # works for every float dtype. Masked weights still become 0 after
            # the softmax; a fully masked row becomes uniform.
            scaled_sims = scaled_sims.masked_fill(
                mask=mask, value=torch.finfo(scaled_sims.dtype).min
            )

        attention_percent = F.softmax(scaled_sims, dim=self.col_dim)
        return torch.matmul(attention_percent, v)

In [4]:
# Toy input: three tokens, each with a 2-dimensional encoding. Queries, keys and
# values share the same values here (self-attention). They are still separate
# tensors so each role can be changed independently in later experiments.
ecodings_for_q = torch.tensor([[1.16, 0.23],
                               [0.57, 1.36],
                               [4.41, -2.16]])
ecodings_for_k = ecodings_for_q.clone()
ecodings_for_v = ecodings_for_q.clone()

# Seed right before construction so the randomly initialised nn.Linear weights
# are reproducible.
torch.manual_seed(42)

attention = Attention(d_model=2, row_dim=0, col_dim=1)
# The positional order is (query, key, value).
attention(ecodings_for_q, ecodings_for_k, ecodings_for_v)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

In [5]:
class MutiHeadAttention(nn.Module):
    """Run several independent ``Attention`` heads and concatenate their outputs.

    There is no output projection: the concatenated head outputs are returned
    as they are.

    Args:
        d_model: encoding size passed to every head.
        n_heads: ignored. The head count is controlled by ``num_heads``. The
            parameter is kept only so existing calls that pass it keep working.
        row_dim: forwarded to each head.
        col_dim: forwarded to each head. Also the dimension along which the
            head outputs are concatenated.
        num_heads: number of heads to create.
    """

    def __init__(self, d_model=2, n_heads=2, row_dim=0, col_dim=1, num_heads=1):
        super().__init__()
        self.heads = nn.ModuleList(
            [Attention(d_model, row_dim, col_dim) for _ in range(num_heads)]
        )
        self.col_dim = col_dim

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        """Apply every head to the same inputs and concatenate the results.

        Args:
            encodings_for_q: encodings projected into the queries.
            encodings_for_k: encodings projected into the keys.
            encodings_for_v: encodings projected into the values.
            mask: optional boolean mask, forwarded unchanged to every head.
                ``True`` entries are excluded from attention.

        Returns:
            The head outputs concatenated along ``col_dim``. For the 2-D inputs
            used in this notebook that gives ``num_heads * d_model`` columns.
        """
        head_outputs = [
            head(encodings_for_q, encodings_for_k, encodings_for_v, mask=mask)
            for head in self.heads
        ]
        return torch.cat(head_outputs, dim=self.col_dim)

In [7]:
# Same seed as the standalone Attention cell, so this single head is initialised
# with the same weights and produces the same output.
torch.manual_seed(42)
# n_heads is ignored by MutiHeadAttention; num_heads sets the head count.
mutiHeadAttention = MutiHeadAttention(d_model=2, row_dim=0, col_dim=1, num_heads=1)
# The positional order is (query, key, value).
mutiHeadAttention(ecodings_for_q, ecodings_for_k, ecodings_for_v)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<CatBackward0>)

In [8]:
# With the same seed, the first head matches the single-head result. The second
# head has its own weights, and its output is appended as extra columns.
torch.manual_seed(42)
# n_heads is ignored by MutiHeadAttention; num_heads sets the head count.
mutiHeadAttention = MutiHeadAttention(d_model=2, row_dim=0, col_dim=1, num_heads=2)
# The positional order is (query, key, value).
mutiHeadAttention(ecodings_for_q, ecodings_for_k, ecodings_for_v)

tensor([[ 1.0100,  1.0641, -0.7081, -0.8268],
        [ 0.2040,  0.7057, -0.7417, -0.9193],
        [ 3.4989,  2.2427, -0.7190, -0.8447]], grad_fn=<CatBackward0>)