In [64]:
import torch
from torch import nn

# Simulation settings: a small batch of random vectors stands in for real
# token embeddings so the encoder can be exercised end to end.
BATCH_SIZE = 2
SEQ_LENGTH = 10
EMBEDDING_DIM = 64
MODEL_DIM = 64
# The attention projections are model_dim -> model_dim and the encoder adds its
# input back onto each sub-layer output, so embeddings must already be MODEL_DIM wide.
assert EMBEDDING_DIM == MODEL_DIM, "EMBEDDING_DIM must equal MODEL_DIM (no input projection)"

# Fix the global RNG so the fake input (and any weights initialised in later
# cells) is the same on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Fake input of shape (BATCH_SIZE, SEQ_LENGTH, EMBEDDING_DIM), uniform in [0, 1).
dummy_input = torch.rand(BATCH_SIZE, SEQ_LENGTH, EMBEDDING_DIM)
# The two sequences individually, each keeping a leading batch axis of size 1:
# shape (1, SEQ_LENGTH, EMBEDDING_DIM).
s1, s2 = dummy_input[0:1], dummy_input[1:2]

In [67]:
from torch import nn
import torch.nn.functional as F

class Attention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are all projected from the same input by
    bias-free ``model_dim -> model_dim`` linear layers, and the attention
    scores are divided by ``sqrt(model_dim)`` before the softmax.

    Args:
        model_dim: Size of the last (feature) dimension of input and output.
    """

    def __init__(self, model_dim):
        super().__init__()

        self.w_key = nn.Linear(model_dim, model_dim, bias=False)
        self.w_query = nn.Linear(model_dim, model_dim, bias=False)
        self.w_value = nn.Linear(model_dim, model_dim, bias=False)
        self.scale = model_dim ** 0.5

    def forward(self, batch):
        """Apply self-attention across the sequence axis.

        Args:
            batch: Tensor of shape ``(..., seq_len, model_dim)``. Any number of
                leading dimensions, including none, is accepted.

        Returns:
            Tensor with the same shape as ``batch``.
        """
        key = self.w_key(batch)
        query = self.w_query(batch)
        value = self.w_value(batch)

        # Swap the last two axes rather than axes 1 and 2. For 3-D input this
        # is identical, but it also handles unbatched (seq_len, model_dim) input
        # and extra leading dimensions correctly.
        scores = query @ key.transpose(-2, -1) / self.scale
        weights = F.softmax(scores, dim=-1)
        return weights @ value

In [68]:
class Encoder(nn.Module):
    """A single post-norm Transformer encoder layer.

    Self-attention and a position-wise feed-forward network are each wrapped
    in a residual connection followed by LayerNorm: ``norm(x + sublayer(x))``.

    Args:
        model_dim: Feature size of the input and output; the feed-forward
            hidden layer is four times this size.
    """

    def __init__(self, model_dim):
        super().__init__()
        hidden_dim = 4 * model_dim

        self.attention = Attention(model_dim)
        self.layer_norm_attn = nn.LayerNorm(normalized_shape=model_dim)
        self.layer_norm_ffn = nn.LayerNorm(normalized_shape=model_dim)
        # Expand to hidden_dim, apply ReLU, then project back to model_dim.
        self.ffn = nn.Sequential(
            nn.Linear(model_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, model_dim),
        )

    def add_and_norm(self, input_embedding, attention_vector, norm):
        """Residual connection followed by normalisation.

        Args:
            input_embedding: Input of the sub-layer (the residual branch).
            attention_vector: Output of the sub-layer. Despite the name, it is
                also passed the feed-forward output.
            norm: Normalisation module applied to the sum.

        Returns:
            ``norm(input_embedding + attention_vector)``.
        """
        residual_sum = input_embedding + attention_vector
        return norm(residual_sum)

    def forward(self, x):
        """Run the attention and feed-forward sub-layers on ``x``.

        Args:
            x: Input tensor whose last dimension is ``model_dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        attended = self.add_and_norm(x, self.attention(x), norm=self.layer_norm_attn)
        return self.add_and_norm(attended, self.ffn(attended), norm=self.layer_norm_ffn)

In [69]:
# Smoke test: run the (untrained) encoder layer once on the fake batch.
encoder = Encoder(model_dim=MODEL_DIM)
encoded = encoder(dummy_input)
# Every sub-layer output is added back onto its input, so the layer must
# preserve the input's shape.
assert encoded.shape == dummy_input.shape, encoded.shape
encoded

tensor([[[-0.5437, -0.2494, -2.4500,  ...,  0.8565,  1.9089,  0.4534],
         [-1.2249, -0.7665, -1.5751,  ...,  0.5661,  0.8805, -0.3863],
         [-0.8817, -1.6961, -1.9978,  ...,  0.0369,  0.4597,  0.6653],
         ...,
         [ 0.4076,  0.1817, -1.8120,  ...,  1.0638,  0.6535,  1.4683],
         [ 0.7207, -0.3525, -2.1541,  ...,  0.3371,  0.7966, -0.6754],
         [ 0.4961, -1.0684, -2.2194,  ..., -0.5459,  1.0376, -1.3120]],

        [[ 0.0836,  0.0376, -2.3489,  ...,  0.4005,  2.4390, -0.3639],
         [ 0.9453, -0.2103, -0.9645,  ..., -0.7632,  0.5418,  0.8988],
         [ 0.1297,  0.3100, -0.9045,  ...,  0.7685,  1.0156, -1.2542],
         ...,
         [ 0.0300, -1.4090, -1.9024,  ..., -0.9404,  1.8125, -0.2277],
         [-0.5926, -0.8747, -0.4620,  ...,  0.7621,  2.1470, -0.0145],
         [ 0.5113,  0.4476, -0.9220,  ...,  0.7993,  2.6659, -0.5982]]],
       grad_fn=<NativeLayerNormBackward0>)