In [1]:
import torch
import torch.nn as nn
import math

# Seed torch's global RNG once so random weight initialisation (e.g. nn.Linear) is
# reproducible on Restart & Run All; the trailing ';' hides the Generator repr.
SEED = 42
torch.manual_seed(SEED);

<torch._C.Generator at 0x7f8f8019f530>

In [2]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Projects the input into queries, keys and values with three bias-free
    linear layers and returns ``softmax(Q @ K^T / sqrt(hidden_dim)) @ V``.

    Args:
        hidden_dim: size of the input's last dimension; also the width of the
            Q/K/V projections and therefore of the output.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim

        # Keep the creation order (query, key, value) fixed: under a given
        # torch.manual_seed it determines which random weights each layer gets.
        self.query = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.value = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, X, attention_mask=None):
        """Apply self-attention to ``X``.

        Args:
            X: float tensor whose last dimension is ``hidden_dim``; attention is
                taken over the second-to-last dimension, e.g. shape
                ``(batch, seq_len, hidden_dim)`` (other leading dims presumably
                broadcast like ``@`` — confirm for your use case).
            attention_mask: optional tensor broadcastable to the score shape
                ``(..., seq_len, seq_len)``. Entries equal to 0 / ``False`` mark
                key positions a query must not attend to. ``None`` (default)
                attends to every position, i.e. the original behaviour.

        Returns:
            Tensor of shape ``(..., seq_len, hidden_dim)`` — the same shape as ``X``.
        """
        Q = self.query(X)
        K = self.key(X)
        V = self.value(X)

        # Dividing by sqrt(hidden_dim) keeps the logits' scale independent of the
        # width, so the softmax does not saturate for larger hidden sizes.
        scores = (Q @ K.transpose(-1, -2)) / math.sqrt(self.hidden_dim)
        if attention_mask is not None:
            # Use the most negative finite value instead of -inf: a fully masked
            # row then gets uniform weights rather than NaN from softmax.
            scores = scores.masked_fill(attention_mask == 0, torch.finfo(scores.dtype).min)

        attention_weight = torch.softmax(scores, dim=-1)
        return attention_weight @ V

In [3]:
# Toy input: one sequence of 2 tokens with hidden size 4, repeated to make a
# batch of 2 identical sequences -> shape (batch=2, seq_len=2, hidden=4).
sequence = torch.tensor(
    [[1, 2, 3, 4],
     [5, 6, 7, 8]],
    dtype=torch.float32,
)
X = torch.stack([sequence, sequence])
X.shape

torch.Size([2, 2, 4])

In [4]:
# Size the layer from the data rather than a hard-coded 4, so the projection
# width always matches the inputs' last dimension.
net = SelfAttention(hidden_dim=X.shape[-1])
net(X)

tensor([[[1.5909, 0.5973, 2.5884, 1.2866],
         [1.3072, 0.4228, 2.2864, 1.0799]],

        [[1.5909, 0.5973, 2.5884, 1.2866],
         [1.3072, 0.4228, 2.2864, 1.0799]]], grad_fn=<UnsafeViewBackward0>)