In [2]:
import torch ## torch lets us create tensors and also provides helper functions
import torch.nn as nn ## torch.nn gives us nn.Module and nn.Linear()
import torch.nn.functional as F # This gives us the softmax()

In [3]:
class Attention(nn.Module):
    """Single-head scaled dot-product attention with learned Q/K/V projections.

    Queries, keys and values are each produced by a bias-free ``nn.Linear``
    of shape ``d_model -> d_model``. The output is
    ``softmax(q @ k^T / sqrt(k.size(col_dim))) @ v``.

    Args:
        d_model: Number of input and output features of each projection.
        row_dim: Dimension index treated as the "row" axis; swapped with
            ``col_dim`` when transposing the keys.
        col_dim: Dimension index treated as the "column" axis; used for the
            key transpose, the scaling factor and the softmax.

    NOTE(review): with the defaults (row_dim=0, col_dim=1) the inputs are
    assumed to be 2-D (rows x d_model), as in the example cells of this
    notebook; batched inputs would need different dimension indices.
    """

    def __init__(self, d_model=2,
                 row_dim=0,
                 col_dim=1):

        super().__init__()

        # Created in the order q, k, v so that weight initialisation under a
        # fixed seed stays reproducible.
        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        """Compute attention-weighted values.

        Args:
            encodings_for_q: Input projected to queries by ``W_q``.
            encodings_for_k: Input projected to keys by ``W_k``.
            encodings_for_v: Input projected to values by ``W_v``.
            mask: Optional boolean tensor broadcastable to the similarity
                matrix. Positions where it is True are filled with a large
                negative number before the softmax, so they receive
                (effectively) zero attention.

        Returns:
            The attention-weighted combination of the projected values.
        """
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)

        # Similarity of every query with every key.
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))

        # Scale by sqrt of the key width; a plain Python float is enough,
        # no tensor needs to be allocated for the divisor.
        scaled_sims = sims / (k.size(self.col_dim) ** 0.5)

        if mask is not None:
            # -1e9 cannot be represented in float16 and would make masked_fill
            # raise an overflow error, so clamp it to the dtype's minimum.
            # For float32/float64 the fill value is still exactly -1e9.
            fill_value = max(-1e9, torch.finfo(scaled_sims.dtype).min)
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=fill_value)

        # Turn the scaled similarities into weights that sum to 1 along col_dim.
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)

        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores

In [4]:
## example encodings: three rows of two values each, used as the
## query, key and value inputs (all three hold the same numbers)
encodings_for_q = torch.tensor([
    [1.16, 0.23],
    [0.57, 1.36],
    [4.41, -2.16],
])

## independent copies, so each name owns its own tensor
encodings_for_k = encodings_for_q.clone()

encodings_for_v = encodings_for_q.clone()

In [5]:
## seed PyTorch's global random number generator so later random draws are repeatable
torch.manual_seed(42)

<torch._C.Generator at 0x7e7a53849f90>

In [11]:
## create a single-head attention object and show its layers
attention = Attention(
    d_model=2,
    row_dim=0,
    col_dim=1,
)
attention

Attention(
  (W_q): Linear(in_features=2, out_features=2, bias=False)
  (W_k): Linear(in_features=2, out_features=2, bias=False)
  (W_v): Linear(in_features=2, out_features=2, bias=False)
)

In [12]:
## run single-head attention on the example encodings (no mask)
attention(encodings_for_q, encodings_for_k, encodings_for_v)

tensor([[0.6226, 0.1312],
        [0.5522, 0.2499],
        [0.5669, 0.2324]], grad_fn=<MmBackward0>)

In [13]:
class MultiHearAttention(nn.Module):
    """Multi-head attention: several independent ``Attention`` heads whose
    outputs are concatenated along ``col_dim``.

    NOTE(review): the class name looks like a typo for "MultiHeadAttention";
    it is kept unchanged because other cells refer to it by this name.

    Args:
        d_model: Feature size passed to every ``Attention`` head.
        row_dim: Row dimension index passed to every head.
        col_dim: Column dimension index passed to every head; also the
            dimension along which the head outputs are concatenated.
        num_heads: Number of ``Attention`` heads to create.
    """

    def __init__(self,d_model=2,row_dim=0,col_dim=1,num_heads=1):

        super().__init__()

        ## one independent Attention module (with its own weights) per head
        self.heads = nn.ModuleList(
            [Attention(d_model, row_dim, col_dim)
             for _ in range(num_heads)]
        )
        self.col_dim = col_dim

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        """Run every head on the same inputs and concatenate the results.

        Args:
            encodings_for_q: Query-side input, passed to every head.
            encodings_for_k: Key-side input, passed to every head.
            encodings_for_v: Value-side input, passed to every head.
            mask: Optional mask forwarded unchanged to every head
                (``None`` means no masking, which is the previous behaviour).

        Returns:
            The head outputs concatenated along ``col_dim``.
        """
        head_outputs = [
            head(encodings_for_q, encodings_for_k, encodings_for_v, mask=mask)
            for head in self.heads
        ]
        return torch.cat(head_outputs, dim=self.col_dim)

In [39]:
## re-seed so the random weight initialisation below is repeatable
torch.manual_seed(42)

## create a multi-head attention object with two heads
multiHeadAttention = MultiHearAttention(
    d_model=2, row_dim=0, col_dim=1, num_heads=2
)


In [40]:
## show the module structure (the heads and their layers)
multiHeadAttention

MultiHearAttention(
  (heads): ModuleList(
    (0-1): 2 x Attention(
      (W_q): Linear(in_features=2, out_features=2, bias=False)
      (W_k): Linear(in_features=2, out_features=2, bias=False)
      (W_v): Linear(in_features=2, out_features=2, bias=False)
    )
  )
)

In [41]:
## run all heads on the example encodings; head outputs are concatenated along col_dim
multiHeadAttention(encodings_for_q, encodings_for_k, encodings_for_v)

tensor([[ 1.0100,  1.0641, -0.7081, -0.8268],
        [ 0.2040,  0.7057, -0.7417, -0.9193],
        [ 3.4989,  2.2427, -0.7190, -0.8447]], grad_fn=<CatBackward0>)

In [24]:
## re-run the single-head attention object on the same inputs for comparison
attention(encodings_for_q, encodings_for_k, encodings_for_v)

tensor([[0.6226, 0.1312],
        [0.5522, 0.2499],
        [0.5669, 0.2324]], grad_fn=<MmBackward0>)

In [42]:
## build two stand-alone Attention heads, the same way MultiHearAttention does internally
heads = nn.ModuleList(
    [Attention(d_model=2, row_dim=0, col_dim=1) for _ in range(2)]
)
heads

ModuleList(
  (0-1): 2 x Attention(
    (W_q): Linear(in_features=2, out_features=2, bias=False)
    (W_k): Linear(in_features=2, out_features=2, bias=False)
    (W_v): Linear(in_features=2, out_features=2, bias=False)
  )
)

In [43]:
## concatenate the heads' outputs side by side (along dim 1)
torch.cat(
    [head(encodings_for_q, encodings_for_k, encodings_for_v) for head in heads],
    dim=1,
)

tensor([[0.6226, 0.1312, 1.0106, 0.8625],
        [0.5522, 0.2499, 1.4153, 1.0420],
        [0.5669, 0.2324, 0.3679, 0.5894]], grad_fn=<CatBackward0>)

In [37]:
## per-head outputs, before any concatenation
[
    head(encodings_for_q, encodings_for_k, encodings_for_v)
    for head in heads
]

[tensor([[-0.6674,  0.5665],
         [-0.5970,  1.5640],
         [-0.7832, -0.0405]], grad_fn=<MmBackward0>),
 tensor([[ 0.7700, -0.9269],
         [ 0.7713, -0.9210],
         [ 0.7669, -0.8751]], grad_fn=<MmBackward0>),
 tensor([[-0.2467, -0.9469],
         [-0.1830, -0.9897],
         [-0.4784, -0.8523]], grad_fn=<MmBackward0>)]