In [None]:
import numpy as np
import torch
import torch.nn as nn


In [None]:
# Side length of the pre-computed causal (lower-triangular) mask used by
# MaskedMultiheadAttention(mask=True); decoder inputs must not be longer than this.
# Chosen to match PositionalEncoding's default max_len, which already caps the
# sequence length the model can embed. Lower it to save buffer memory
# (MAX_LEN**2 floats per causal attention layer).
MAX_LEN = 4096

class MaskedMultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention with optional causal and key-padding masks.

    Hyper-parameters are read from the notebook-global ``args``:
    ``args.nhid_tran`` (model width, must be divisible by ``args.nhead``),
    ``args.nhead`` (number of heads) and ``args.attn_pdrop`` (dropout applied
    to the attention weights).

    Args:
        mask: if True, register a lower-triangular ``(MAX_LEN, MAX_LEN)`` buffer
            named ``mask`` so that query position ``t`` can only attend to key
            positions ``<= t``.
    """

    def __init__(self, mask=False):
        super(MaskedMultiheadAttention, self).__init__()
        assert args.nhid_tran % args.nhead == 0
        self.key = nn.Linear(args.nhid_tran, args.nhid_tran)
        self.query = nn.Linear(args.nhid_tran, args.nhid_tran)
        self.value = nn.Linear(args.nhid_tran, args.nhid_tran)
        # regularization
        self.attn_drop = nn.Dropout(args.attn_pdrop)
        # output projection
        self.proj = nn.Linear(args.nhid_tran, args.nhid_tran)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        if mask:
            self.register_buffer("mask", torch.tril(torch.ones(MAX_LEN, MAX_LEN)))
        self.nhead = args.nhead
        self.d_k = args.nhid_tran // args.nhead

    def _split_heads(self, t):
        """Reshape ``(B, T, C)`` into ``(B, nhead, T, C // nhead)``."""
        return t.reshape(t.shape[0], t.shape[1], self.nhead, -1).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: queries, ``(B, T_q, C)``.
            k, v: keys and values, ``(B, T, C)``.
            mask: optional ``(B, T)`` key-padding mask; key positions where it
                equals 0 receive no attention.

        The causal mask is applied whenever this module was built with
        ``mask=True``; the padding mask is applied whenever one is passed.
        With neither, plain (unmasked) attention is computed.

        Returns:
            ``(B, T_q, C)`` projected attention output.

        Raises:
            ValueError: if the sequence is longer than the causal buffer.
        """
        q = self._split_heads(self.query(q))
        k = self._split_heads(self.key(k))
        v = self._split_heads(self.value(v))

        # similarity: (B, nhead, T_q, T)
        similarity = torch.matmul(q, k.transpose(-1, -2)) / self.d_k ** 0.5
        t_q, t_k = similarity.shape[-2], similarity.shape[-1]

        # Buffers live in self._buffers, so hasattr is only True when built with mask=True.
        if hasattr(self, "mask"):
            if t_q > self.mask.shape[0] or t_k > self.mask.shape[1]:
                raise ValueError(
                    f"sequence length ({t_q}, {t_k}) exceeds causal mask size "
                    f"{tuple(self.mask.shape)}; increase MAX_LEN"
                )
            # (T_q, T) broadcasts over batch and head dimensions.
            causal = self.mask[:t_q, :t_k]
            similarity = similarity.masked_fill(causal == 0, float("-inf"))

        if mask is not None:
            # (B, T) -> (B, 1, 1, T) broadcasts over heads and query positions.
            key_pad = mask[:, None, None, :]
            similarity = similarity.masked_fill(key_pad == 0, float("-inf"))

        weights = self.attn_drop(torch.softmax(similarity, dim=-1))
        attn_out = torch.matmul(weights, v).transpose(1, 2)
        attn_out = attn_out.reshape(attn_out.shape[0], attn_out.shape[1], -1)
        return self.proj(attn_out)


In [None]:
class TransformerEncLayer(nn.Module):
    """Pre-LayerNorm Transformer encoder layer: self-attention then feed-forward.

    Each sub-layer normalises its input, transforms it, and adds the (dropped-out)
    result back onto the un-normalised residual stream:
    ``x = x + dropout(sublayer(ln(x)))``. Sizes come from the notebook-global
    ``args`` (``nhid_tran``, ``nff``, ``resid_pdrop``).
    """

    def __init__(self):
        super(TransformerEncLayer, self).__init__()
        self.ln1 = nn.LayerNorm(args.nhid_tran)
        self.ln2 = nn.LayerNorm(args.nhid_tran)
        self.attn = MaskedMultiheadAttention()
        self.dropout1 = nn.Dropout(args.resid_pdrop)
        self.dropout2 = nn.Dropout(args.resid_pdrop)
        self.ff = nn.Sequential(
            nn.Linear(args.nhid_tran, args.nff),
            nn.ReLU(), 
            nn.Linear(args.nff, args.nhid_tran)
        )

    def forward(self, x, mask=None):
        """Apply the layer.

        Args:
            x: hidden states, presumably ``(B, T, nhid_tran)`` — TODO confirm.
            mask: forwarded unchanged to ``self.attn`` as its key-padding mask.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # The residual must bypass the LayerNorm; adding onto ln1(x) would
        # re-normalise the identity path at every layer.
        h = self.ln1(x)
        x = x + self.dropout1(self.attn(h, h, h, mask))

        x = x + self.dropout2(self.ff(self.ln2(x)))
        return x

In [None]:
class TransformerDecLayer(nn.Module):
    """Pre-LayerNorm Transformer decoder layer.

    Three sub-layers, each of the form ``x = x + dropout(sublayer(ln(x)))``:
    causal self-attention (``attn1``, built with ``mask=True``), target-to-source
    attention over the encoder output (``attn2``), and a feed-forward network.
    Sizes come from the notebook-global ``args`` (``nhid_tran``, ``nff``,
    ``resid_pdrop``).
    """

    def __init__(self):
        super(TransformerDecLayer, self).__init__()
        self.ln1 = nn.LayerNorm(args.nhid_tran)
        self.ln2 = nn.LayerNorm(args.nhid_tran)
        self.ln3 = nn.LayerNorm(args.nhid_tran)
        self.dropout1 = nn.Dropout(args.resid_pdrop)
        self.dropout2 = nn.Dropout(args.resid_pdrop)
        self.dropout3 = nn.Dropout(args.resid_pdrop)
        self.attn1 = MaskedMultiheadAttention(mask=True) # self-attention 
        self.attn2 = MaskedMultiheadAttention() # tgt to src attention
        self.ff = nn.Sequential(
            nn.Linear(args.nhid_tran, args.nff),
            nn.ReLU(), 
            nn.Linear(args.nff, args.nhid_tran)
        )
        
    def forward(self, x, enc_o, enc_mask=None):
        """Apply the layer.

        Args:
            x: decoder hidden states, presumably ``(B, T_q, nhid_tran)`` — TODO confirm.
            enc_o: encoder output used as keys/values of ``attn2``.
            enc_mask: forwarded to ``attn2`` as the source key-padding mask.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Residuals are added onto the un-normalised stream x, not onto ln(x).
        h = self.ln1(x)
        x = x + self.dropout1(self.attn1(h, h, h))

        h = self.ln2(x)
        x = x + self.dropout2(self.attn2(h, enc_o, enc_o, enc_mask))

        x = x + self.dropout3(self.ff(self.ln3(x)))
        return x


In [None]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a ``(B, T, dim)`` input.

    Even feature columns ``2i`` hold ``sin(pos / 10000**(2i/dim))`` and odd
    columns ``2i+1`` hold the matching ``cos``. The table is stored as the
    float32 buffer ``pe`` of shape ``(max_len, dim)``.

    Args:
        max_len: number of positions to pre-compute; longer inputs are rejected.
        dim: feature size; defaults to ``args.nhid_tran`` from the notebook's
            global scope. Odd sizes are supported.
    """

    def __init__(self, max_len=4096, dim=None):
        super().__init__()
        if dim is None:
            dim = args.nhid_tran
        # Build in float64 for accuracy, then store as float32.
        pos = torch.arange(max_len, dtype=torch.float64).unsqueeze(1)   # (max_len, 1)
        two_i = torch.arange(0, dim, 2, dtype=torch.float64)             # 0, 2, 4, ...
        angles = pos / (10000.0 ** (two_i / dim))                         # (max_len, ceil(dim/2))

        pe = torch.zeros(max_len, dim, dtype=torch.float64)
        pe[:, 0::2] = torch.sin(angles)
        # For odd dim there is one fewer cos column than sin column.
        pe[:, 1::2] = torch.cos(angles[:, : dim // 2])

        self.register_buffer('pe', pe.float())

    def forward(self, x):
        """Return ``x + pe[:T]`` where ``T = x.shape[1]``.

        Raises:
            ValueError: if ``T`` exceeds the pre-computed ``max_len``.
        """
        seq_len = x.shape[1]
        if seq_len > self.pe.shape[0]:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {self.pe.shape[0]}"
            )
        return x + self.pe[:seq_len]

class TransformerEncoder(nn.Module):
    """Source-side stack: token embedding + positions, N encoder layers, final LayerNorm.

    Vocabulary size and widths come from the notebook globals ``src_ntoken``
    and ``args`` (``nhid_tran``, ``embd_pdrop``, ``nlayers_transformer``).
    """

    def __init__(self):
        super(TransformerEncoder, self).__init__()
        # Embedding stem: token lookup, sinusoidal positions, dropout.
        self.tok_emb = nn.Embedding(src_ntoken, args.nhid_tran)
        self.pos_enc = PositionalEncoding()
        self.dropout = nn.Dropout(args.embd_pdrop)
        # Stack of encoder layers.
        layers = [TransformerEncLayer() for _ in range(args.nlayers_transformer)]
        self.transform = nn.ModuleList(layers)
        # Normalisation applied to the final hidden states.
        self.ln_f = nn.LayerNorm(args.nhid_tran)

    def forward(self, x, mask):
        """Encode token ids ``x``; ``mask`` is handed to every encoder layer unchanged."""
        hidden = self.dropout(self.pos_enc(self.tok_emb(x)))
        for layer in self.transform:
            hidden = layer(hidden, mask)
        return self.ln_f(hidden)

In [None]:
class TransformerDecoder(nn.Module):
    """Target-side stack: embeddings + positions, N decoder layers, LayerNorm, tied output head.

    Vocabulary size and widths come from the notebook globals ``trg_ntoken``
    and ``args``. The output projection shares its weight matrix with the
    token embedding (its bias stays separate).
    """

    def __init__(self):
        super(TransformerDecoder, self).__init__()
        self.tok_emb = nn.Embedding(trg_ntoken, args.nhid_tran)
        self.pos_enc = PositionalEncoding()
        self.dropout = nn.Dropout(args.embd_pdrop)
        layers = [TransformerDecLayer() for _ in range(args.nlayers_transformer)]
        self.transform = nn.ModuleList(layers)
        self.ln_f = nn.LayerNorm(args.nhid_tran)
        self.lin_out = nn.Linear(args.nhid_tran, trg_ntoken)
        # Weight tying: the output head reuses the embedding matrix.
        self.lin_out.weight = self.tok_emb.weight

    def forward(self, x, enc_o, enc_mask):
        """Return vocabulary logits for target ids ``x`` attending over ``enc_o``."""
        hidden = self.dropout(self.pos_enc(self.tok_emb(x)))
        for layer in self.transform:
            hidden = layer(hidden, enc_o, enc_mask)
        logits = self.lin_out(self.ln_f(hidden))

        logits /= args.nhid_tran ** 0.5 # Scaling logits. Do not modify this
        return logits

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer wrapper.

    In training (or with ``teacher_forcing=True``) it returns decoder logits for
    ``y[:, :-1]``. Otherwise it decodes greedily, starting from ``y[:, :1]``.
    """

    def __init__(self):
        super(Transformer, self).__init__()
        self.encoder = TransformerEncoder()
        self.decoder = TransformerDecoder()

    @staticmethod
    def _length_mask(x, length_x):
        """Build a ``(B, T)`` float mask: 1 where position < length, 0 on padding."""
        lengths = torch.as_tensor(length_x, device=x.device)
        positions = torch.arange(x.shape[1], device=x.device)
        return (positions.unsqueeze(0) < lengths.unsqueeze(1)).float()

    def forward(self, x, y, length_x, max_len=None, teacher_forcing=True):
        """Run the model.

        Args:
            x: source token ids, ``(B, T_src)``.
            y: target token ids; with teacher forcing all but the last position
                are fed to the decoder, otherwise only ``y[:, :1]`` seeds decoding.
            length_x: per-example source lengths (list or tensor), or None for
                no source padding mask.
            max_len: greedy-decoding length, defaults to ``y.shape[1]``.
            teacher_forcing: use ``y`` as decoder input (always on in training).

        Returns:
            Decoder logits. In greedy mode these come from the final
            ``max_len - 1`` step prefix (None if ``max_len <= 1``).
        """
        if max_len is None:
            max_len = y.shape[1]

        # Previously enc_mask was unbound when length_x was None; the mask is now
        # built vectorised on x's device instead of a Python double loop.
        enc_mask = None if length_x is None else self._length_mask(x, length_x)
        enc_o = self.encoder(x, enc_mask)

        if teacher_forcing or self.training:
            return self.decoder(y[:, :-1], enc_o, enc_mask)

        dec_input = y[:, :1]
        dec_output = None
        for _ in range(max_len - 1):
            # No KV cache: the decoder re-reads the whole prefix every step.
            dec_output = self.decoder(dec_input, enc_o, enc_mask)
            next_token = dec_output[:, -1:].argmax(-1)
            dec_input = torch.cat((dec_input, next_token), dim=1)
        return dec_output
