In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import math

In [None]:
class DecoderLayer(nn.Module):
    """One decoder block: self-attention, optional cross-attention, feed-forward.

    Every sub-layer output goes through dropout, is added to the sub-layer's
    input (residual connection) and is then layer-normalised.

    Args:
        d_model: Feature size passed to the attention modules, the
            feed-forward module and each ``nn.LayerNorm``.
        n_head: Head count forwarded to both ``MultiHeadAttention`` modules.
        ffn_hidden: Hidden width forwarded to ``PositionwiseFeedForward``.
        drop_prob: Dropout probability used after every sub-layer and
            forwarded to the feed-forward module.
    """

    def __init__(self, d_model, n_head, ffn_hidden, drop_prob=0.1):
        super().__init__()

        # Sub-layer 1: attention of the decoder input over itself.
        self.attention1 = MultiHeadAttention(d_model, n_head)
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(drop_prob)

        # Sub-layer 2: attention of the decoder state over the encoder output.
        self.cross_attention = MultiHeadAttention(d_model, n_head)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(drop_prob)

        # Sub-layer 3: position-wise feed-forward network.
        self.ffn = PositionwiseFeedForward(d_model, ffn_hidden, drop_prob)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout3 = nn.Dropout(drop_prob)

    def forward(self, dec, enc, src_mask, tgt_mask):
        """Run the block on ``dec``.

        Args:
            dec: Decoder-side input; used as query, key and value of the
                self-attention. Assumed to have ``d_model`` features on its
                last axis (required by ``nn.LayerNorm``) -- TODO confirm the
                full layout against ``MultiHeadAttention``.
            enc: Encoder output used as key and value of the cross-attention,
                or ``None`` to skip the cross-attention sub-layer entirely.
            src_mask: Mask handed to the cross-attention.
            tgt_mask: Mask handed to the self-attention.

        Returns:
            The output of the final ``LayerNorm`` (``norm3``).
        """
        self_attended = self.attention1(dec, dec, dec, tgt_mask)
        hidden = self.norm1(self.dropout1(self_attended) + dec)

        if enc is not None:
            cross_attended = self.cross_attention(hidden, enc, enc, src_mask)
            hidden = self.norm2(self.dropout2(cross_attended) + hidden)

        transformed = self.ffn(hidden)
        return self.norm3(self.dropout3(transformed) + hidden)

class Decoder(nn.Module):
    """Embedding, a stack of ``DecoderLayer`` blocks, and a final linear layer.

    Args:
        dec_voc_size: Vocabulary size; passed to the embedding and used as
            the output width of the final linear layer.
        max_len: Forwarded to ``TransformerEmbedding``.
        d_model: Feature size shared by the embedding, every layer and the
            input of the final linear layer.
        ffn_hidden: Hidden width forwarded to every ``DecoderLayer``.
        n_head: Head count forwarded to every ``DecoderLayer``.
        n_layers: Number of ``DecoderLayer`` blocks to stack.
        drop_prob: Dropout probability forwarded to the embedding and layers.
        device: Forwarded to ``TransformerEmbedding``.
    """

    def __init__(self, dec_voc_size, max_len, d_model, ffn_hidden, n_head, n_layers, drop_prob, device):
        super().__init__()
        self.embedding = TransformerEmbedding(dec_voc_size, d_model, device, drop_prob, max_len)

        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            self.layers.append(DecoderLayer(d_model, n_head, ffn_hidden, drop_prob))

        self.fc = nn.Linear(d_model, dec_voc_size)

    def forward(self, dec, enc, tgt_mask, src_mask):
        """Embed ``dec``, run it through every layer, and apply ``fc``.

        Args:
            dec: Decoder input handed to the embedding (presumably token
                ids -- verify against ``TransformerEmbedding``).
            enc: Encoder output forwarded unchanged to every layer.
            tgt_mask: Mask each layer uses for its self-attention.
            src_mask: Mask each layer uses for its cross-attention.

        Returns:
            Output of the final linear layer; its last axis has
            ``dec_voc_size`` entries.
        """
        hidden = self.embedding(dec)
        for block in self.layers:
            # DecoderLayer.forward takes src_mask BEFORE tgt_mask, the reverse
            # of this method's own parameter order.
            hidden = block(hidden, enc, src_mask, tgt_mask)
        return self.fc(hidden)