In [1]:
import torch
import torch.nn as nn
import numpy as np

In [None]:
class Encoder(nn.Module):
    """Transformer-style encoder that applies one shared layer repeatedly.

    Each iteration performs self-attention followed by a position-wise
    feed-forward network, with a residual connection and layer
    normalisation after each of the two sub-layers.

    NOTE: a single set of attention / feed-forward / LayerNorm weights is
    reused for every iteration and for both sub-layers (weight tying).
    Give each layer its own modules if independent weights are wanted.

    Args:
        model_dim: Size of the feature (embedding) dimension.
        num_heads: Number of attention heads; must divide ``model_dim``.
        num_stacks: How many times the shared layer is applied.
        ff_dim: Hidden width of the feed-forward network.

    Raises:
        ValueError: If ``model_dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        model_dim: int = 512,
        num_heads: int = 8,
        num_stacks: int = 6,
        ff_dim: int = 2048,
    ):
        super().__init__()
        if model_dim % num_heads != 0:
            raise ValueError(
                f"model_dim ({model_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        self.num_stacks = num_stacks
        # TODO: implement multi-head attention from scratch.
        self.attention = nn.MultiheadAttention(
            embed_dim=model_dim, num_heads=num_heads
        )
        self.ffnn = nn.Sequential(
            nn.Linear(model_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, model_dim),
        )
        self.layernorm = nn.LayerNorm(normalized_shape=model_dim)

        # Query/key/value projections. They are nn.Parameter so that they
        # are trained, follow .to(device) and appear in state_dict().
        # The output width is model_dim (all heads side by side) because
        # nn.MultiheadAttention requires embed_dim-sized inputs and does the
        # per-head split itself. Scaling by 1/sqrt(model_dim) keeps the
        # projected activations at roughly the scale of the input.
        init_scale = model_dim ** -0.5
        self.queryM = nn.Parameter(torch.randn(model_dim, model_dim) * init_scale)
        self.keyM = nn.Parameter(torch.randn(model_dim, model_dim) * init_scale)
        self.valueM = nn.Parameter(torch.randn(model_dim, model_dim) * init_scale)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Encode ``inputs`` by applying the shared layer ``num_stacks`` times.

        Args:
            inputs: Tensor whose last dimension is ``model_dim``. Because the
                attention module is built with the default
                ``batch_first=False``, a 3-D input is interpreted as
                ``(seq_len, batch, model_dim)``.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        return self._stack(inputs, self.num_stacks)

    def _stack(self, input, num_stacks):
        """Apply the encoder layer ``num_stacks`` times.

        Returns ``input`` unchanged when ``num_stacks`` is less than 1.
        """
        hidden = input
        for _ in range(num_stacks):
            hidden = self._layer(hidden)
        return hidden

    def _layer(self, hidden: torch.Tensor) -> torch.Tensor:
        """Run one attention + feed-forward layer with residual connections."""
        # Matrix products, not element-wise ones: (..., d) @ (d, d) -> (..., d).
        query = hidden @ self.queryM
        key = hidden @ self.keyM
        value = hidden @ self.valueM

        # The attention module returns (output, weights); only the output
        # is needed, so skip computing the averaged attention weights.
        attended, _ = self.attention(
            query=query, key=key, value=value, need_weights=False
        )
        hidden = self.layernorm(hidden + attended)

        return self.layernorm(hidden + self.ffnn(hidden))