# In[1]:
import torch
import torch.nn as nn
import numpy as np

# In[None]:
class Encoder(nn.Module):
    """Transformer encoder built from one layer applied ``num_stacks`` times.

    Each application is: attention over the input (through the extra
    ``queryM``/``keyM``/``valueM`` projections), residual add + LayerNorm,
    then a position-wise feed-forward network, residual add + LayerNorm.

    Because ``_stack`` recurses into the same modules, every stack shares
    the same parameters (a weight-tied encoder, unlike the independent
    layers of "Attention Is All You Need").

    Args:
        model_dim: Feature size of each token (``embed_dim`` of the attention).
        num_heads: Number of attention heads; ``nn.MultiheadAttention``
            requires it to divide ``model_dim``.
    """

    def __init__(self, model_dim: int = 512, num_heads: int = 8):
        super().__init__()
        # TODO: IMPLEMENT FROM SCRATCH (replace nn.MultiheadAttention).
        # batch_first is left at its default (False), so inputs are laid out
        # as (seq_len, batch, model_dim), or (seq_len, model_dim) unbatched.
        self.attention = nn.MultiheadAttention(embed_dim=model_dim, num_heads=num_heads)
        self.ffnn = nn.Sequential(
            nn.Linear(model_dim, 2048),
            nn.ReLU(),
            nn.Linear(2048, model_dim),
        )
        # NOTE(review): a single LayerNorm is shared by both sub-layers; the
        # original Transformer uses a separate LayerNorm per sub-layer.
        self.layernorm = nn.LayerNorm(normalized_shape=model_dim)

        # Scale by 1/sqrt(model_dim) so `x @ W` keeps roughly the variance of
        # x. Unit-variance entries would inflate activations ~sqrt(model_dim)x
        # per projection and saturate the attention softmax.
        scale = model_dim ** -0.5
        self.queryM = nn.Parameter(torch.randn(model_dim, model_dim) * scale)
        self.keyM = nn.Parameter(torch.randn(model_dim, model_dim) * scale)
        self.valueM = nn.Parameter(torch.randn(model_dim, model_dim) * scale)

    def forward(self, inputs, num_stacks: int = 6):
        """Encode ``inputs`` by applying the shared layer ``num_stacks`` times.

        Args:
            inputs: Tensor whose last dimension is ``model_dim``.
            num_stacks: Number of layer applications; values < 1 return
                ``inputs`` unchanged.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        return self._stack(inputs, num_stacks)

    def _stack(self, input, num_stacks: int = 6):
        """Apply one encoder layer to ``input``, then recurse for the rest."""
        if num_stacks < 1:
            return input

        query = input @ self.queryM
        key = input @ self.keyM
        value = input @ self.valueM

        # nn.MultiheadAttention returns (attn_output, attn_weights); only the
        # output joins the residual. need_weights=False skips computing the
        # unused weights, and calling the module (not .forward) runs hooks.
        attn_output, _ = self.attention(query, key, value, need_weights=False)
        layer = self.layernorm(input + attn_output)

        output = self.layernorm(layer + self.ffnn(layer))

        return self._stack(output, num_stacks - 1)

# In[None]:
class Decoder(nn.Module):
    """Transformer decoder built from one layer applied ``num_stacks`` times.

    Each application is:
      1. causally-masked self-attention over the decoder input,
      2. encoder-decoder attention (queries from step 1, keys/values from
         the encoder output),
      3. a position-wise feed-forward network,
    each followed by a residual add and LayerNorm. ``_stack`` recurses into
    the same modules, so all stacks share parameters.

    Args:
        model_dim: Feature size of each token (``embed_dim`` of the attention).
        num_heads: Number of attention heads; ``nn.MultiheadAttention``
            requires it to divide ``model_dim``.
    """

    def __init__(self, model_dim: int = 512, num_heads: int = 8):
        super().__init__()
        # Masked self-attention. batch_first is left at its default (False),
        # so sequences are laid out as (seq_len, batch, model_dim), or
        # (seq_len, model_dim) unbatched.
        self.attention = nn.MultiheadAttention(embed_dim=model_dim, num_heads=num_heads)
        # Separate module for encoder-decoder attention so it does not share
        # its internal projections with the self-attention sub-layer.
        self.cross_attention = nn.MultiheadAttention(embed_dim=model_dim, num_heads=num_heads)
        self.ffnn = nn.Sequential(
            nn.Linear(model_dim, 2048),
            nn.ReLU(),
            nn.Linear(2048, model_dim),
        )
        # NOTE(review): one LayerNorm is shared by all three sub-layers; the
        # original Transformer uses a separate LayerNorm per sub-layer.
        self.layernorm = nn.LayerNorm(normalized_shape=model_dim)

        # Scale by 1/sqrt(model_dim) so `x @ W` keeps roughly the variance of
        # x instead of inflating it ~sqrt(model_dim)x.
        scale = model_dim ** -0.5
        # Projections for the self-attention sub-layer.
        self.queryMSelf = nn.Parameter(torch.randn(model_dim, model_dim) * scale)
        self.keyMSelf = nn.Parameter(torch.randn(model_dim, model_dim) * scale)
        self.valueMSelf = nn.Parameter(torch.randn(model_dim, model_dim) * scale)

        # Projections for the encoder-decoder sub-layer: the query comes from
        # the decoder, keys/values from the encoder output.
        self.queryMDec = nn.Parameter(torch.randn(model_dim, model_dim) * scale)
        self.keyMEnc = nn.Parameter(torch.randn(model_dim, model_dim) * scale)
        self.valueEnc = nn.Parameter(torch.randn(model_dim, model_dim) * scale)

    def forward(self, encoder_output, num_stacks: int = 6, target=None):
        """Decode against ``encoder_output`` by applying the layer ``num_stacks`` times.

        Args:
            encoder_output: Encoder memory; last dimension is ``model_dim``.
            num_stacks: Number of layer applications; values < 1 return the
                decoder input unchanged.
            target: Decoder input (e.g. shifted target embeddings), last
                dimension ``model_dim``. Defaults to ``encoder_output``,
                which the original code passed as the decoder input.
                NOTE(review): a real seq2seq model should always pass it.

        Returns:
            Tensor with the same shape as the decoder input.
        """
        decoder_input = encoder_output if target is None else target
        return self._stack(decoder_input, encoder_output, num_stacks)

    @staticmethod
    def _causal_mask(seq_len, device):
        """Return a (seq_len, seq_len) bool mask that is True above the diagonal.

        For ``nn.MultiheadAttention`` a True entry forbids attention, so
        position i can only attend to positions <= i. The mask is
        deterministic; it is not randomly initialised.
        """
        return torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=device),
            diagonal=1,
        )

    def _stack(self, input, encoder_output, num_stacks: int = 6):
        """Apply one decoder layer to ``input``, then recurse for the rest."""
        if num_stacks < 1:
            return input

        # 1. Masked self-attention over the decoder input.
        querySelf = input @ self.queryMSelf
        keySelf = input @ self.keyMSelf
        valueSelf = input @ self.valueMSelf
        # Dim 0 is the target sequence length because batch_first=False.
        causal_mask = self._causal_mask(input.size(0), input.device)
        # The attention module returns (attn_output, attn_weights).
        self_attn, _ = self.attention(
            querySelf, keySelf, valueSelf, attn_mask=causal_mask, need_weights=False
        )
        layer = self.layernorm(input + self_attn)

        # 2. Encoder-decoder attention: queries come from the self-attention
        # output, keys/values from the encoder output.
        queryDec = layer @ self.queryMDec
        keyEnc = encoder_output @ self.keyMEnc
        valueEnc = encoder_output @ self.valueEnc
        cross_attn, _ = self.cross_attention(queryDec, keyEnc, valueEnc, need_weights=False)
        layer = self.layernorm(layer + cross_attn)

        # 3. Position-wise feed-forward network.
        output = self.layernorm(layer + self.ffnn(layer))

        return self._stack(output, encoder_output, num_stacks - 1)






