In [None]:
import torch
from torch import nn
from utls import MultiHeadAttention, LayerNorm, PositionwiseFeedForward, TransformerEmbedding, Encoder

In [None]:
class DecoderLayer(nn.Module):
    """One decoder block: self-attention, cross-attention, then a feed-forward network.

    Each of the three sub-layers is followed by dropout, a residual addition
    of that sub-layer's own input, and a layer normalisation.

    Args:
        embedding_dim: Feature size passed to the attention, norm and
            feed-forward modules.
        ffn_hidden: Hidden size passed to ``PositionwiseFeedForward``.
        n_head: Number of heads passed to both ``MultiHeadAttention`` modules.
        dropout: Dropout probability used after every sub-layer (and passed
            to ``PositionwiseFeedForward``).
    """

    def __init__(self, embedding_dim, ffn_hidden, n_head, dropout):
        super().__init__()
        # Sub-layer 1: self-attention over the decoder input.
        self.attention1 = MultiHeadAttention(embedding_dim, n_head)
        self.norm1 = LayerNorm(embedding_dim)
        self.dropout1 = nn.Dropout(dropout)
        # Sub-layer 2: attention from the decoder state onto the encoder output.
        self.cross_attention = MultiHeadAttention(embedding_dim, n_head)
        self.norm2 = LayerNorm(embedding_dim)
        self.dropout2 = nn.Dropout(dropout)
        # Sub-layer 3: position-wise feed-forward network.
        self.ffn = PositionwiseFeedForward(embedding_dim, ffn_hidden, dropout)
        self.norm3 = LayerNorm(embedding_dim)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, dec, enc, trg_mask, src_mask):
        """Run the three sub-layers on the decoder state.

        Args:
            dec: Decoder-side input (used as query, key and value of the
                self-attention). Assumed to be (batch, seq, embedding_dim)
                -- TODO confirm against ``MultiHeadAttention``.
            enc: Encoder output, used as key and value of the cross-attention.
            trg_mask: Target mask forwarded to the self-attention (may be ``None``).
            src_mask: Source mask forwarded to the cross-attention (may be ``None``).

        Returns:
            The output of the final layer normalisation.
        """
        # Self-attention sub-layer.
        residual = dec
        x = self.attention1(dec, dec, dec, trg_mask)
        x = self.dropout1(x)
        x = self.norm1(x + residual)

        # Cross-attention sub-layer: queries come from the decoder, keys/values from the encoder.
        residual = x
        x = self.cross_attention(x, enc, enc, src_mask)
        x = self.dropout2(x)
        x = self.norm2(x + residual)

        # Feed-forward sub-layer. The residual must be the FFN's own input
        # (the output of norm2), not the stale pre-cross-attention tensor.
        residual = x
        x = self.ffn(x)
        x = self.dropout3(x)
        x = self.norm3(x + residual)
        return x
        

In [None]:
class Decoder(nn.Module):
    """Token embedding, a stack of ``DecoderLayer`` blocks, and a vocabulary projection.

    Args:
        voc_size: Vocabulary size; used for the embedding and as the output
            width of the final linear layer.
        embedding_dim: Feature size shared by the embedding, every decoder
            layer and the input of the final linear layer.
        max_len: Maximum length forwarded to ``TransformerEmbedding``.
        n_layers: How many ``DecoderLayer`` blocks to stack.
        ffn_hidden: Hidden size forwarded to each ``DecoderLayer``.
        n_head: Number of attention heads forwarded to each ``DecoderLayer``.
        dropout: Dropout probability forwarded to the embedding and the layers.
        device: Device the layer stack and the output projection are moved to
            (also forwarded to ``TransformerEmbedding``).
    """

    def __init__(self, voc_size, embedding_dim, max_len, n_layers, ffn_hidden, n_head, dropout=0.1, device='cpu'):
        super().__init__()
        self.embedding = TransformerEmbedding(voc_size, embedding_dim, max_len, dropout, device)

        # Build the blocks one by one, then move the whole stack to the device.
        stack = nn.ModuleList()
        for _ in range(n_layers):
            stack.append(DecoderLayer(embedding_dim, ffn_hidden, n_head, dropout))
        self.layers = stack.to(device)

        # Projects each position's features onto the vocabulary.
        self.fc = nn.Linear(embedding_dim, voc_size).to(device)

    def forward(self, dec, enc, t_mask, s_mask):
        """Embed ``dec``, pass it through every layer, and project to the vocabulary.

        Args:
            dec: Decoder input handed to the embedding (token ids in the
                notebook's usage).
            enc: Encoder output given to every decoder layer.
            t_mask: Target mask forwarded to every layer (may be ``None``).
            s_mask: Source mask forwarded to every layer (may be ``None``).

        Returns:
            The output of the final linear layer.
        """
        hidden = self.embedding(dec)
        for block in self.layers:
            hidden = block(hidden, enc, t_mask, s_mask)
        return self.fc(hidden)

In [None]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook still runs end-to-end on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Small demo models: 5-token vocabulary, sequences of up to 8 positions.
decoder = Decoder(voc_size=5, embedding_dim=512, max_len=8, n_layers=3, ffn_hidden=256, n_head=8, dropout=0.1, device=device)
encoder = Encoder(voc_size=5, embedding_dim=512, max_len=8, n_layers=3, ffn_hidden=256, n_head=8, dropout=0.1, device=device)

In [None]:
# Source batch: two sequences of 8 token ids (the trailing zeros in the second row are presumably padding).
x = torch.tensor(
    [
        [1, 2, 3, 4, 2, 3, 1, 1],
        [2, 3, 4, 1, 0, 0, 0, 0],
    ]
)

# Encode the source batch without a mask.
enc = encoder(x.to(device), mask=None)

# Target batch, padded to the same length as the source: because of the
# cross-attention, some padding is needed here to align the enc and dec.
dec = torch.tensor(
    [
        [1, 2, 3, 2, 0, 0, 0, 0],
        [1, 2, 3, 3, 0, 0, 0, 0],
    ]
)

# Decode without target/source masks; `dec` is rebound from token ids to the decoder's output.
dec = decoder(dec.to(device), enc, None, None)