In [None]:
import torch
import  torch.functional as F
from torch import nn
import math
import numpy as np

In [None]:
class Embeddings(nn.Module):
    """Token embedding lookup scaled by sqrt(d_model).

    :param d_model: dimension of each word vector
    :param vocab: size of the vocabulary table
    """
    def __init__(self, d_model, vocab):
        super().__init__()
        # Learnable lookup table: one d_model-sized row per vocabulary entry.
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        """Embed integer token ids and scale the result.

        :param x: integer token ids (any shape accepted by ``nn.Embedding``)
        :return: tensor of shape ``x.shape + (d_model,)``
        """
        # Scaling by sqrt(d_model) follows the original Transformer recipe, keeping
        # embedding magnitudes comparable to the positional encodings added later.
        scale = math.sqrt(self.d_model)
        embedded = self.lut(x)
        return embedded * scale

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to word embeddings, then apply dropout.

    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    :param d_model: embedding dimension (odd values are supported)
    :param dropout: dropout probability applied after adding the encoding
    :param max_len: maximum sequence length the precomputed table covers
    """
    def __init__(self, d_model, dropout=0.1, max_len=5000) -> None:
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        # One row per *position*: the range must run over max_len, not d_model,
        # otherwise the (max_len, d_model/2) assignments below have mismatched shapes.
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Shape (1, max_len, d_model) so it broadcasts over the batch dimension.
        # A buffer is saved in state_dict and moved by .to(), but is not a trainable parameter.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the encoding for the first ``x.size(1)`` positions.

        :param x: embeddings, assumed (batch, seq_len, d_model) with seq_len <= max_len
        :return: tensor of the same shape as ``x``, after dropout
        """
        # Buffers never require grad, so no clone/detach is needed here.
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)




In [None]:
class ScaledDotProductAttention(nn.Module):
    """Compute softmax(Q K^T / scale_factor) V with optional masking.

    Mask convention: positions where ``mask`` is truthy (True / non-zero) are
    BLOCKED. This matches the masks built in this notebook: ``get_atten_pad_mask``
    marks padded keys with True and ``get_train_decode_mask`` marks future
    positions with 1.

    :param scale_factor: divisor applied to the queries, typically sqrt(d_k)
    :param dropout: dropout probability applied to the attention weights
    """
    def __init__(self, scale_factor, dropout=0.0) -> None:
        super(ScaledDotProductAttention, self).__init__()
        self.scale_factor = scale_factor
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        """
        :param q: queries (..., len_q, d_k)
        :param k: keys (..., len_k, d_k)
        :param v: values (..., len_k, d_v)
        :param mask: optional tensor broadcastable to (..., len_q, len_k); truthy = blocked
        :return: (output (..., len_q, d_v), attention weights (..., len_q, len_k))
        """
        # Transpose the last two dims so both 3-D and 4-D (multi-head) inputs work.
        scores = torch.matmul(q / self.scale_factor, k.transpose(-2, -1))
        if mask is not None:
            # Use the dtype's most negative value: -1e9 overflows to -inf in fp16.
            scores = scores.masked_fill(mask.bool(), torch.finfo(scores.dtype).min)
        attn = torch.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        output = torch.matmul(attn, v)
        return output, attn

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention with a residual connection.

    The same LayerNorm (``self.norm``) is applied to q, k and v before the input
    projections, and again to ``projected_output + residual`` at the end.
    NOTE(review): sharing one LayerNorm for both pre- and post-normalisation (and,
    in cross-attention, for the encoder memory) is unusual — confirm it is intended.

    :param n_head: number of attention heads
    :param d_model: input/output feature size
    :param d_k: per-head query/key size
    :param d_v: per-head value size
    :param dropout: dropout applied to the output projection before the residual add
    """
    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.0) -> None:
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        # Input projections producing all heads at once.
        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)

        # Output projection merging the concatenated heads back to d_model.
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        self.attn = ScaledDotProductAttention(scale_factor=math.sqrt(d_k))
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model, eps=1e-6)

    def _split_heads(self, x, proj, depth):
        """Normalise, project, and reshape (bz, len, d_model) -> (bz, n_head, len, depth)."""
        projected = proj(self.norm(x))
        return projected.view(x.size(0), x.size(1), self.n_head, depth).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """
        :param q: queries, (bz, len_q, d_model); also used as the residual
        :param k: keys, (bz, len_k, d_model)
        :param v: values, (bz, len_v, d_model)
        :param mask: optional mask, presumably (bz, len_q, len_k); a head axis is inserted at dim 1
        :return: (output (bz, len_q, d_model), per-head attention weights)
        """
        batch_size, query_len = q.size(0), q.size(1)
        residual = q

        q_heads = self._split_heads(q, self.w_qs, self.d_k)
        k_heads = self._split_heads(k, self.w_ks, self.d_k)
        v_heads = self._split_heads(v, self.w_vs, self.d_v)

        # Broadcast the same mask across every head.
        head_mask = mask.unsqueeze(1) if mask is not None else None

        context, attn = self.attn(q_heads, k_heads, v_heads, head_mask)

        # Merge heads: (bz, n_head, len_q, d_v) -> (bz, len_q, n_head * d_v).
        merged = context.transpose(1, 2).contiguous().view(batch_size, query_len, -1)
        projected = self.dropout(self.fc(merged))
        return self.norm(projected + residual), attn



In [None]:
class LayerNorm(nn.Module):
    """Hand-written layer normalisation over the last dimension.

    Normalises with the biased (population) variance, then applies a learnable
    per-feature scale ``gamma`` and shift ``beta``.

    :param d_model: size of the last (normalised) dimension
    :param eps: constant added to the variance for numerical stability
    """
    def __init__(self, d_model, eps=1e-12) -> None:
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor):
        centered = x - x.mean(dim=-1, keepdim=True)
        variance = x.var(dim=-1, unbiased=False, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        return normalized * self.gamma + self.beta

In [None]:
class PoswiseFeedForward(nn.Module):
    """Position-wise feed-forward sub-layer with a residual connection and post-LayerNorm.

    Computes ``LayerNorm(x + Dropout(W2 · ReLU(W1 · x)))`` independently at every position.

    :param d_model: input/output feature size
    :param d_ff: hidden size of the inner layer; defaults to ``d_model`` (the original
        behavior). The Transformer paper uses ``4 * d_model``.
    :param dropout: dropout on the sub-layer output before the residual add; the
        default 0.0 keeps the original (no-dropout) behavior
    :param eps: LayerNorm epsilon; the default keeps the original value
    """
    def __init__(self, d_model, d_ff=None, dropout=0.0, eps=1e-13) -> None:
        super(PoswiseFeedForward, self).__init__()
        hidden = d_model if d_ff is None else d_ff
        # The Sequential layout (fc.0 / fc.1 / fc.2) is unchanged so existing state_dicts still load.
        self.fc = nn.Sequential(
            nn.Linear(d_model, hidden, bias=False),
            nn.ReLU(),
            nn.Linear(hidden, d_model, bias=False),
        )
        # Parameter-free, so it adds no state_dict entries.
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model, eps=eps)

    def forward(self, x):
        """
        :param x: input of shape (..., d_model)
        :return: tensor of the same shape as ``x``
        """
        output = self.dropout(self.fc(x))
        return self.norm(output + x)

In [None]:
def get_atten_pad_mask(seq_q, seq_k=None, pad_idx=0):
    """Build an attention mask that blocks padded key positions.

    :param seq_q: query token ids, shape (batch, len_q)
    :param seq_k: key token ids, shape (batch, len_k); defaults to ``seq_q`` (self-attention)
    :param pad_idx: token id used for padding (default 0)
    :return: bool tensor (batch, len_q, len_k); True where the key position is padding
    :raises ValueError: if the two sequences have different batch sizes
    """
    if seq_k is None:
        seq_k = seq_q
    bz, len_q = seq_q.size()
    bz_k, len_k = seq_k.size()
    if bz != bz_k:
        raise ValueError(f"batch size mismatch: seq_q has {bz}, seq_k has {bz_k}")
    # (batch, len_k) -> (batch, 1, len_k) so every query row sees the same key padding.
    pad_atten_mask = seq_k.eq(pad_idx).unsqueeze(1)
    return pad_atten_mask.expand(bz, len_q, len_k)

In [None]:
class EncoderLayer(nn.Module):
    """One encoder block: multi-head self-attention followed by a position-wise feed-forward layer.

    :param n_head: number of attention heads
    :param d_model: model dimension
    :param d_k: per-head query/key size
    :param d_v: per-head value size
    :param dropout: dropout probability passed to the attention sub-layer
    """
    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1) -> None:
        super().__init__()
        self.selfattn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout)
        self.feedward = PoswiseFeedForward(d_model)

    def forward(self, inputs, input_masks):
        """
        :param inputs: encoder hidden states (q = k = v)
        :param input_masks: attention mask forwarded to the self-attention sub-layer
        :return: (layer output, self-attention weights)
        """
        attended, attn = self.selfattn(inputs, inputs, inputs, input_masks)
        return self.feedward(attended), attn

In [None]:
class Encoder(nn.Module):
    """Transformer encoder: token embedding + positional encoding + a stack of EncoderLayers.

    :param vocab: vocabulary size
    :param d_model: model dimension
    :param n_head: number of attention heads
    :param n_layers: number of stacked encoder layers
    :param dropout: dropout probability
    """
    def __init__(self, vocab, d_model, n_head, n_layers, dropout) -> None:
        super(Encoder, self).__init__()
        self.emb = Embeddings(d_model, vocab)
        self.pos_emb = PositionalEncoding(d_model, dropout)
        # NOTE(review): d_k = d_v = d_model *per head*, so the Q/K/V projections are
        # n_head * d_model wide; d_model // n_head is the usual choice. Kept as-is to
        # preserve parameter shapes.
        self.layers = nn.ModuleList(
            EncoderLayer(n_head, d_model, d_model, d_model, dropout) for _ in range(n_layers)
        )

    def forward(self, inputs, masks=None):
        """
        :param inputs: source token ids, presumably (batch, src_len)
        :param masks: optional attention mask; built from padding in ``inputs`` when omitted
        :return: (encoder output, list of per-layer self-attention weights)
        """
        output = self.pos_emb(self.emb(inputs))
        if masks is None:
            # Padding must be detected on the integer token ids, not on the float embeddings.
            masks = get_atten_pad_mask(inputs, inputs)
        self_attns = []
        for layer in self.layers:
            output, attn = layer(output, masks)
            self_attns.append(attn)
        return output, self_attns



In [None]:
def get_train_decode_mask(seq):
    """Build a causal ("subsequent position") mask for decoder self-attention.

    :param seq: tensor whose first two dims are (batch, seq_len), e.g. token ids or embeddings
    :return: uint8 tensor (batch, seq_len, seq_len) on ``seq.device``; entry [b, i, j]
        is 1 when j > i (a future position to hide) and 0 otherwise
    """
    attn_shape = (seq.size(0), seq.size(1), seq.size(1))
    # Build directly on seq's device so it can be combined with other masks on GPU
    # (the previous numpy round-trip always produced a CPU tensor).
    ones = torch.ones(attn_shape, dtype=torch.uint8, device=seq.device)
    return torch.triu(ones, diagonal=1)

In [None]:
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, encoder-decoder cross-attention, feed-forward.

    :param n_head: number of attention heads
    :param d_model: model dimension
    :param d_k: per-head query/key size
    :param d_v: per-head value size
    :param dropout: dropout probability passed to both attention sub-layers
    """
    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1) -> None:
        super(DecoderLayer, self).__init__()
        self.selfattn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout)
        self.cross_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout)
        self.feedward = PoswiseFeedForward(d_model)

    def forward(self, decode_input, encode_output, decode_mask, encode_mask):
        """
        :param decode_input: decoder hidden states
        :param encode_output: encoder output used as keys/values for cross-attention
        :param decode_mask: mask for decoder self-attention (causal + padding)
        :param encode_mask: mask for cross-attention (encoder padding)
        :return: (layer output, self-attention weights, cross-attention weights) — three
            values, as the Decoder's layer loop unpacks
        """
        decode_output, dec_attn = self.selfattn(decode_input, decode_input, decode_input, decode_mask)

        # Queries come from the decoder; keys and values from the encoder output.
        decode_output, dec_enc_attn = self.cross_attn(decode_output, encode_output, encode_output, encode_mask)

        decode_output = self.feedward(decode_output)
        return decode_output, dec_attn, dec_enc_attn

In [None]:
class Decoder(nn.Module):
    """Transformer decoder: token embedding + positional encoding + a stack of DecoderLayers.

    :param vocab: vocabulary size
    :param d_model: model dimension
    :param n_head: number of attention heads
    :param n_layers: number of stacked decoder layers
    :param dropout: dropout probability
    """
    def __init__(self, vocab, d_model, n_head, n_layers, dropout) -> None:
        super(Decoder, self).__init__()
        self.emb = Embeddings(d_model, vocab)
        self.pos_emb = PositionalEncoding(d_model, dropout)
        self.layers = nn.ModuleList(
            DecoderLayer(n_head, d_model, d_model, d_model, dropout) for _ in range(n_layers)
        )

    def forward(self, decode_input, encode_input, encode_output):
        """
        :param decode_input: target token ids, presumably (batch, tgt_len)
        :param encode_input: source token ids, presumably (batch, src_len); only used for its padding
        :param encode_output: encoder output
        :return: (decoder output, per-layer self-attention weights, per-layer cross-attention weights)
        """
        # Keep decode_input as token ids: the masks below are derived from the ids,
        # not from the float embeddings.
        output = self.pos_emb(self.emb(decode_input))

        # Self-attention: block future positions OR padded target keys (True = blocked).
        subsequent_mask = get_train_decode_mask(decode_input).to(decode_input.device)
        decode_pad_mask = get_atten_pad_mask(decode_input, decode_input)
        decode_attn_mask = subsequent_mask.bool() | decode_pad_mask
        # Cross-attention: block padded source keys.
        cross_mask = get_atten_pad_mask(decode_input, encode_input)

        decode_self_attn, decode_cross_attn = [], []
        for layer in self.layers:
            output, self_attn, cross_attn = layer(output, encode_output, decode_attn_mask, cross_mask)
            decode_self_attn.append(self_attn)
            decode_cross_attn.append(cross_attn)

        return output, decode_self_attn, decode_cross_attn
        
        


In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer.

    Source and target share one vocabulary size (``vocab``), which is used for both
    embeddings and the output projection.

    :param vocab: vocabulary size
    :param d_model: model dimension
    :param n_head: number of attention heads
    :param en_n_layers: number of encoder layers
    :param de_n_layers: number of decoder layers
    :param dropout: dropout probability
    """
    def __init__(self, vocab, d_model, n_head, en_n_layers, de_n_layers, dropout=0.1) -> None:
        super().__init__()
        self.encoder = Encoder(vocab, d_model, n_head, en_n_layers, dropout)
        self.decoder = Decoder(vocab, d_model, n_head, de_n_layers, dropout)
        # Final projection from model space onto vocabulary scores.
        self.fc = nn.Linear(d_model, vocab, bias=False)

    def forward(self, encode_inputs, decode_inputs):
        """
        :param encode_inputs: source token ids
        :param decode_inputs: target token ids
        :return: tuple of
            - raw (pre-softmax) vocabulary logits — use these with ``nn.CrossEntropyLoss``
            - softmax probabilities over the last dimension of those logits
            - encoder self-attention weights (list per layer)
            - decoder self-attention weights (list per layer)
            - decoder cross-attention weights (list per layer)
        """
        enc_out, enc_attns = self.encoder(encode_inputs)
        dec_out, dec_self_attns, dec_cross_attns = self.decoder(decode_inputs, encode_inputs, enc_out)
        scores = self.fc(dec_out)
        probs = torch.softmax(scores, dim=-1)
        return scores, probs, enc_attns, dec_self_attns, dec_cross_attns
