In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder model that builds its own attention masks.

    ``forward`` derives padding and causal masks from the raw token-id
    tensors, runs the encoder and the decoder, and projects the decoder
    output to ``dec_voc_size`` values per position.

    Args:
        src_pad_idx: Token id treated as padding in the source sequence.
        trg_pad_idx: Token id treated as padding in the target sequence.
        enc_voc_size: Vocabulary size forwarded to the encoder.
        dec_voc_size: Vocabulary size forwarded to the decoder; also the
            output width of the final projection.
        d_model: Model width forwarded to encoder/decoder; input width of
            the final projection.
        n_heads: Forwarded to the encoder and the decoder.
        ffn_hidden: Forwarded to the encoder and the decoder.
        n_layers: Forwarded to the encoder and the decoder.
        drop_prob: Forwarded to the encoder and the decoder.
        device: Stored on the instance and forwarded to encoder/decoder.
        max_len: Forwarded to the encoder and the decoder.
    """

    def __init__(self, src_pad_idx, trg_pad_idx, enc_voc_size, dec_voc_size, d_model, n_heads, ffn_hidden, n_layers, drop_prob, device, max_len):
        super().__init__()
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device
        # NOTE: Encoder and Decoder take (device, drop_prob) in a different
        # order; the positional arguments below mirror that on purpose.
        self.encoder = Encoder(enc_voc_size, max_len, d_model, ffn_hidden, n_heads, n_layers, device, drop_prob)
        self.decoder = Decoder(dec_voc_size, max_len, d_model, ffn_hidden, n_heads, n_layers, drop_prob, device)
        self.projection = nn.Linear(d_model, dec_voc_size)

    def make_pad_mask(self, q: torch.Tensor, k: torch.Tensor, pad_idx: int) -> torch.Tensor:
        """Return a boolean mask that is True where neither token is padding.

        Args:
            q: Query-side token ids, expected as ``(batch, len_q)``.
            k: Key-side token ids, expected as ``(batch, len_k)``.
            pad_idx: Token id that marks padding in both ``q`` and ``k``.

        Returns:
            Boolean tensor of shape ``(batch, 1, len_q, len_k)`` whose entry
            ``[b, 0, i, j]`` is True iff ``q[b, i] != pad_idx`` and
            ``k[b, j] != pad_idx``.
        """
        # (batch, 1, len_q, 1): validity of each query position.
        q_valid = q.ne(pad_idx).unsqueeze(1).unsqueeze(3)
        # (batch, 1, 1, len_k): validity of each key position.
        k_valid = k.ne(pad_idx).unsqueeze(1).unsqueeze(2)
        # Broadcasting yields the full (batch, 1, len_q, len_k) grid.
        return k_valid & q_valid

    def make_causal_mask(self, q: torch.Tensor) -> torch.Tensor:
        """Return a boolean lower-triangular mask for a batch of sequences.

        The mask is boolean (not float) so it can be combined with the
        padding mask via ``&``; bitwise AND is not defined for float
        tensors. It is created on ``q``'s device so both masks always live
        on the same device.

        Args:
            q: Token ids of shape ``(batch, len_q)``.

        Returns:
            Boolean tensor of shape ``(batch, 1, len_q, len_q)`` where entry
            ``[b, 0, i, j]`` is True iff ``j <= i``.
        """
        batch_size, len_q = q.size()
        lower = torch.tril(torch.ones(len_q, len_q, device=q.device)).bool()
        return lower.expand(batch_size, 1, len_q, len_q)

    def forward(self, src: torch.Tensor, trg: torch.Tensor) -> torch.Tensor:
        """Run the encoder and decoder and project the decoder output.

        Args:
            src: Source token ids, expected as ``(batch, src_len)``.
            trg: Target token ids of shape ``(batch, trg_len)``.

        Returns:
            The result of applying the ``d_model -> dec_voc_size`` linear
            projection to the decoder output.
        """
        src_mask = self.make_pad_mask(src, src, self.src_pad_idx)
        # Target positions may only attend to non-padding positions that are
        # not in their future.
        trg_mask = self.make_pad_mask(trg, trg, self.trg_pad_idx) & self.make_causal_mask(trg)
        enc_output = self.encoder(src, src_mask)
        # NOTE(review): src_mask is (batch, 1, src_len, src_len); if the
        # decoder applies it to cross-attention scores it presumably needs
        # (..., trg_len, src_len) -- confirm against Decoder when
        # trg_len != src_len.
        dec_output = self.decoder(trg, enc_output, trg_mask, src_mask)
        return self.projection(dec_output)