# In[ ]:
import torch
from torch import nn
from torch.nn import functional as F
import math

# In[ ]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer wrapper that builds the attention masks.

    Owns an ``Encoder`` and a ``Decoder`` (defined elsewhere in this notebook),
    both constructed from the same hyper-parameters. ``forward`` builds the
    padding and causal masks and runs the encoder once, then the decoder once.

    Args:
        src_pad_idx: Token id treated as padding in ``src``.
        trg_pad_idx: Token id treated as padding in ``trg``.
        encoder_voc_size: First argument passed to ``Encoder``
            (presumably the source vocabulary size).
        decoder_voc_size: First argument passed to ``Decoder``
            (presumably the target vocabulary size).
        max_len, d_model, n_head, ffn_hidden, n_layer, dropout, device:
            Forwarded unchanged, in this order, to both ``Encoder`` and
            ``Decoder``. ``device`` is also stored on ``self.device``.
    """

    def __init__(
        self,
        src_pad_idx,
        trg_pad_idx,
        encoder_voc_size,
        decoder_voc_size,
        max_len,
        d_model,
        n_head,
        ffn_hidden,
        n_layer,
        dropout,
        device,
    ):
        super().__init__()
        self.encoder = Encoder(
            encoder_voc_size,
            max_len,
            d_model,
            n_head,
            ffn_hidden,
            n_layer,
            dropout,
            device,
        )
        self.decoder = Decoder(
            decoder_voc_size,
            max_len,
            d_model,
            n_head,
            ffn_hidden,
            n_layer,
            dropout,
            device,
        )
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_pad_mask(self, Q, K, pad_idx_q, pad_idx_k):
        """Build a boolean padding mask for attention from ``Q`` positions to ``K`` positions.

        Args:
            Q: Query-side token ids, indexed as (batch, len_q).
            K: Key-side token ids, indexed as (batch, len_k).
            pad_idx_q: Padding id to look for in ``Q``.
            pad_idx_k: Padding id to look for in ``K``.

        Returns:
            Bool tensor of shape (batch, 1, len_q, len_k). Entry ``[b, 0, i, j]``
            is True only when both ``Q[b, i]`` and ``K[b, j]`` are non-padding.
            The singleton dim is presumably broadcast over attention heads
            downstream.
        """
        q_valid = Q.ne(pad_idx_q)[:, None, :, None]  # (B, 1, Lq, 1)
        k_valid = K.ne(pad_idx_k)[:, None, None, :]  # (B, 1, 1, Lk)
        # Broadcasting expands both operands to (B, 1, Lq, Lk); no explicit repeat needed.
        return q_valid & k_valid

    def make_causal_mask(self, Q, K):
        """Build a lower-triangular ("no peeking ahead") boolean mask.

        Args:
            Q: Query-side tensor; only ``Q.size(1)`` (len_q) and its device are used.
            K: Key-side tensor; only ``K.size(1)`` (len_k) is used.

        Returns:
            Bool tensor of shape (len_q, len_k) on ``Q``'s device. Entry ``[i, j]``
            is True iff ``j <= i``.
        """
        len_q = Q.size(1)
        len_k = K.size(1)
        # Allocate on the inputs' device so the result can be combined with the
        # padding mask (also on Q's device) without a host/device mismatch.
        return torch.tril(torch.ones(len_q, len_k, device=Q.device)).bool()

    def forward(self, src, trg):
        """Encode ``src`` once, then decode ``trg`` against the encoder output.

        Args:
            src: Source token ids, indexed as (batch, src_len).
            trg: Target token ids, indexed as (batch, trg_len).

        Returns:
            Whatever ``self.decoder`` returns.
        """
        # Encoder self-attention: src attends to src, ignoring padding.
        src_mask = self.make_pad_mask(src, src, self.src_pad_idx, self.src_pad_idx)
        # Decoder cross-attention: trg queries attend to src keys. The mask must be
        # (B, 1, trg_len, src_len), which src_mask is not when the lengths differ.
        # NOTE(review): assumes Decoder's 4th argument is its cross-attention mask.
        src_trg_mask = self.make_pad_mask(trg, src, self.trg_pad_idx, self.src_pad_idx)
        # Decoder self-attention: hide both padding and future positions.
        trg_mask = self.make_pad_mask(
            trg, trg, self.trg_pad_idx, self.trg_pad_idx
        ) & self.make_causal_mask(trg, trg)

        encoder_output = self.encoder(src, src_mask)
        return self.decoder(trg, encoder_output, trg_mask, src_trg_mask)