In [1]:
import torch
from torch import nn
import numpy as np
import math
import copy


In [2]:
# Arch units
## Self Attention Unit
## Multi Head Attention
## Encode Decode Unit
## Norm + Residual Layer
## Feed Forward
## Input Positional Encoding 


In [3]:
# Hyper-parameters. Trailing comments give the value used in the paper
# wherever this notebook uses a scaled-down one.
word_emb_dim = 50   # width of the word embeddings fed to the model
N = 6               # number of stacked encoder / decoder cells
d_model = 32        # 512
h = 4               # 8 (attention heads)
d_k = d_model // h  # per-head query/key width
d_v = d_k           # per-head value width (same as d_k)
d_ff = 128          # 2048 (hidden width of the position-wise feed-forward net)
vocab_size = 100


In [4]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are projected by learned linear maps, split into
    ``h`` heads, combined per head as ``softmax(Q K^T / sqrt(d_k)) V`` and then
    merged back to ``d_model`` features by a final linear layer.

    Uses the notebook-level hyper-parameters ``d_model``, ``h``, ``d_k`` and
    ``d_v``.
    """

    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, h*d_k)
        self.W_K = nn.Linear(d_model, h*d_k)
        self.W_V = nn.Linear(d_model, h*d_v)
        self.W_O = nn.Linear(h*d_v, d_model)

    def forward(self, Q, K, V, mask=None):
        """Attend from ``Q`` over ``K``/``V``.

        Args:
            Q: queries, shape (b, q_len, d_model).
            K: keys, shape (b, kv_len, d_model).
            V: values, shape (b, kv_len, d_model).
            mask: optional tensor; positions where it equals 0 are excluded
                from attention. Either 4-D and broadcastable to
                (b, h, q_len, kv_len), or 3-D ((b, 1, kv_len) or
                (b, q_len, kv_len)), in which case a head axis is inserted.

        Returns:
            Tensor of shape (b, q_len, d_model).
        """
        b, q_len, _ = Q.size()
        # Keys/values may come from another sequence (encoder-decoder
        # attention), so their length must not be taken from Q.
        kv_len = K.size(1)

        query = self.W_Q(Q).view(b, q_len, h, d_k).transpose(1, 2)   # (b, h, q_len, d_k)
        key = self.W_K(K).view(b, kv_len, h, d_k).transpose(1, 2)    # (b, h, kv_len, d_k)
        value = self.W_V(V).view(b, kv_len, h, d_v).transpose(1, 2)  # (b, h, kv_len, d_v)

        scores = query.matmul(key.transpose(-2, -1)) / math.sqrt(d_k)  # (b, h, q_len, kv_len)

        # `if mask:` would raise for any tensor with more than one element.
        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)  # add a head axis so the mask broadcasts over heads
            # A large negative score becomes a zero weight after the softmax.
            scores = scores.masked_fill(mask == 0, -1e9)

        weights = nn.functional.softmax(scores, dim=-1)  # (b, h, q_len, kv_len)
        context = weights.matmul(value)                  # (b, h, q_len, d_v)
        merged = context.transpose(1, 2).contiguous().view(b, q_len, h*d_v)
        return self.W_O(merged)  # (b, q_len, d_model)


In [5]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable scale and shift.

    Args:
        d_mod: size of the last (feature) dimension; defaults to the
            notebook-level ``d_model``.
    """

    def __init__(self, d_mod=d_model):
        super(LayerNorm, self).__init__()
        self.d_mod = d_mod
        # https://stackoverflow.com/questions/39095252/fail-to-implement-layer-normalization-with-keras
        # https://stackoverflow.com/questions/50935345/understanding-torch-nn-parameter
        self.alpha = nn.Parameter(torch.ones(d_mod))   # learnable scale
        self.beta = nn.Parameter(torch.zeros(d_mod))   # learnable shift

    def forward(self, x, eps=1e-6):
        """Return ``alpha * (x - mean) / (std + eps) + beta`` along the last axis.

        ``eps`` keeps the division finite when a feature vector is constant.
        """
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        # Divide by the standard deviation; the previous expression
        # `(x - u)/(1/(sigma + eps))` multiplied by it instead.
        return self.alpha * (x - mean) / (std + eps) + self.beta
    

In [6]:
class EncoderCell(nn.Module):
    """One encoder layer: self-attention, then a position-wise feed-forward
    network, each wrapped in a residual connection followed by LayerNorm."""

    def __init__(self):
        super().__init__()
        self.attn = MultiHeadAttention()
        self.norm_1 = LayerNorm()
        feed_forward = [
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        ]
        self.pff = nn.Sequential(*feed_forward)
        self.norm_2 = LayerNorm()

    def forward(self, x, src_mask=None):
        """Run both sub-layers on ``x``; ``src_mask`` is handed to the attention module."""
        # Sub-layer 1: self-attention + residual + norm
        attended = self.attn(x, x, x, src_mask)
        after_attn = self.norm_1(x + attended)
        # Sub-layer 2: feed-forward + residual + norm
        transformed = self.pff(after_attn)
        return self.norm_2(after_attn + transformed)
    

In [7]:
class DecoderCell(nn.Module):
    """One decoder layer: masked self-attention, encoder-decoder attention and
    a position-wise feed-forward network, each followed by a residual
    connection and LayerNorm."""

    def __init__(self):
        super(DecoderCell, self).__init__()
        self.attn = MultiHeadAttention()        # self-attention over the target
        self.norm_1 = LayerNorm()
        # Separate module for attention over the encoder output. Previously
        # this was assigned to `self.attn` again, so both attention steps
        # shared a single set of weights.
        self.cross_attn = MultiHeadAttention()
        self.norm_2 = LayerNorm()
        self.pff = nn.Sequential(
                    nn.Linear(d_model, d_ff),
                    nn.ReLU(),
                    nn.Linear(d_ff, d_model))
        self.norm_3 = LayerNorm()

    def forward(self, x, enc, src_mask=None, trg_mask=None):
        """Decode one layer.

        Args:
            x: target-side activations.
            enc: encoder output used as keys and values in the second attention.
            src_mask: mask passed to the encoder-decoder attention.
            trg_mask: mask passed to the target self-attention.
        """
        x_norm_1 = self.norm_1(x + self.attn(x, x, x, trg_mask))
        x_norm_2 = self.norm_2(x_norm_1 + self.cross_attn(x_norm_1, enc, enc, src_mask))
        return self.norm_3(x_norm_2 + self.pff(x_norm_2)) # (b, seq, d_model)


In [8]:
class EncoderStack(nn.Module):
    """A stack of ``N`` encoder cells applied one after another.

    The cells are created once in ``__init__`` and kept in an
    ``nn.ModuleList`` so that their weights are registered as parameters of
    this module (visible to ``parameters()``, the optimizer and
    ``state_dict()``) and persist between forward passes. Building them inside
    ``forward`` produced fresh, untrained, unregistered layers on every call.
    """

    def __init__(self, N):
        super(EncoderStack, self).__init__()
        self.N = N
        # Independent instances, so each layer has its own weights.
        self.layers = nn.ModuleList([EncoderCell() for _ in range(N)])

    def forward(self, x, src_mask=None):
        """Pass ``x`` through every cell in order, giving each the same ``src_mask``."""
        for layer in self.layers:
            x = layer(x, src_mask)
        return x

In [9]:
class DecoderStack(nn.Module):
    """A stack of ``N`` decoder cells applied one after another.

    The cells are created once in ``__init__`` and kept in an
    ``nn.ModuleList`` so that their weights are registered as parameters of
    this module and persist between forward passes. Building them inside
    ``forward`` produced fresh, untrained, unregistered layers on every call.
    """

    def __init__(self, N):
        super(DecoderStack, self).__init__()
        self.N = N
        # Independent instances, so each layer has its own weights.
        self.layers = nn.ModuleList([DecoderCell() for _ in range(N)])

    def forward(self, x, enc, src_mask=None, trg_mask=None):
        """Pass ``x`` through every cell in order.

        ``enc``, ``src_mask`` and ``trg_mask`` are forwarded unchanged to each cell.
        """
        for layer in self.layers:
            x = layer(x, enc, src_mask, trg_mask)
        return x

In [10]:
class Transformer(nn.Module):
    """Encoder-decoder model operating on already-embedded inputs.

    Args:
        log_softmx: when true, ``forward`` returns log-probabilities over the
            vocabulary; otherwise it returns the raw output of ``W_out``.
    """

    def __init__(self, log_softmx=True):
        super().__init__()
        # Projects word embeddings to the model width.
        self.W_in = nn.Linear(word_emb_dim, d_model)
        self.encoder = EncoderStack(N)
        self.decoder = DecoderStack(N)
        # https://stats.stackexchange.com/questions/392213/understand-the-output-layer-of-transformer
        self.W_out = nn.Linear(d_model, vocab_size)
        self.sftmx = log_softmx

    def forward(self, inp_x, inp_y, src_mask, trg_mask):
        """Encode ``inp_x``, decode ``inp_y`` against it and score the vocabulary.

        Both inputs are divided by ``sqrt(d_model)`` before the shared input
        projection: (b, seq, word_embedding) -> (b, seq, d_model).
        """
        scale = math.sqrt(d_model)
        src_proj = self.W_in(inp_x / scale)
        tgt_proj = self.W_in(inp_y / scale)

        memory = self.encoder(src_proj, src_mask)
        decoded = self.decoder(tgt_proj, memory, src_mask, trg_mask)
        scores = self.W_out(decoded)

        if not self.sftmx:
            return scores
        return nn.functional.log_softmax(scores, dim=-1)
        

In [11]:
class EmbeddingLayer(nn.Module):
    """Lookup table from token ids to ``word_emb_dim``-wide vectors, with the
    result scaled by ``sqrt(d_model)``."""

    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, word_emb_dim)

    def forward(self, x):
        """Embed the integer ids in ``x`` and apply the ``sqrt(d_model)`` scaling."""
        vectors = self.embedding(x)
        scale = math.sqrt(d_model)
        return vectors * scale


In [12]:
# https://www.reddit.com/r/MachineLearning/comments/bjgpt2/d_confused_about_using_masking_in_transformer/

In [13]:
# https://nlp.seas.harvard.edu/2018/04/03/attention.html
class Batch:
    """Holds one batch of source/target token ids together with their masks.

    Args:
        src: source token ids, shape (b, seq).
        trg: optional target token ids, shape (b, seq). When given it is split
            into decoder input ``trg`` (all but the last token) and prediction
            target ``trg_y`` (all but the first token).
        pad: token id that marks padding.

    Attributes:
        src_mask: (b, 1, seq) boolean mask, True where ``src`` is not padding.
        trg_mask: (b, seq-1, seq-1) boolean mask combining padding and
            causality (only set when ``trg`` is given).
        ntokens: number of non-padding tokens in ``trg_y`` (0-dim tensor).
    """

    def __init__(self, src, trg=None, pad=0): # size src, trg (b, seq)
        self.src = src
        self.src_mask = (src != pad).unsqueeze(-2)  # (b, 1, seq)
        if trg is not None:
            self.trg = trg[:, :-1]   # size (b, 0:seq-1)
            self.trg_y = trg[:, 1:]  # size (b, 1:seq)
            self.trg_mask = self.std_mask(self.trg, pad)
            self.ntokens = (self.trg_y != pad).sum()  # 0-dim tensor

    @staticmethod
    def std_mask(tgt, pad):
        """Mask that hides both padding and future positions of ``tgt``."""
        tgt_mask = (tgt != pad).unsqueeze(-2)  # size (b, 1, seq)
        # subsequent_mask is a staticmethod, so it has to be reached through
        # the class; the bare name is not in scope here. Move it to the
        # target's device so the `&` works for non-CPU tensors too.
        causal = Batch.subsequent_mask(tgt.shape[-1]).to(tgt.device)  # (1, seq, seq)
        return tgt_mask & causal  # (b, 1, seq) & (1, seq, seq) -> (b, seq, seq)

    @staticmethod
    def subsequent_mask(size):
        """Return a (1, size, size) mask that is True on and below the diagonal,
        i.e. position i may attend to positions <= i."""
        return torch.tril(torch.ones((1, size, size), dtype=torch.uint8)) == 1


In [24]:
# https://github.com/pytorch/pytorch/issues/7455    
def labelSmoothingLoss(x, y, epsilon, padding_value=0, cls=2, d=-1):
    """Cross-entropy against a label-smoothed target distribution.

    Args:
        x: log-probabilities with the classes on the last axis. Any leading
            axes are flattened, so both (n, vocab) and (b, seq, vocab) work.
        y: integer (int64) target class indices matching the leading axes of ``x``.
        epsilon: probability mass moved away from the true class.
        padding_value: class index that marks padding. Its column is given
            zero probability and rows whose target is padding contribute zero.
        cls: number of classes excluded when spreading ``epsilon`` (with the
            default 2: the true class and the padding class).
        d: class axis of the flattened 2-D input; leave at -1.

    Returns:
        0-dim tensor: the mean over all rows (padding rows included, each
        contributing 0) of ``sum(-target * x)`` along the class axis.

    Note:
        An earlier debugging ``return x_, mask`` made the loss computation
        unreachable; it has been removed so the function returns the loss.
    """
    # Flatten to (rows, classes) / (rows,) so row-wise masking is a plain
    # index_fill_ along dim 0.
    x = x.reshape(-1, x.size(-1))
    y = y.reshape(-1)

    # The target distribution is a constant: build it detached from the graph.
    x_ = x.data.clone()
    x_.fill_(epsilon / (x_.size(-1) - cls))
    x_.scatter_(d, y.data.unsqueeze(-1), (1 - epsilon))
    x_[:, padding_value] = 0

    # nonzero() always returns a 2-D index tensor, so test for emptiness with
    # numel() and flatten it to the 1-D index that index_fill_ expects.
    pad_rows = torch.nonzero(y.data == padding_value).reshape(-1)
    if pad_rows.numel() > 0:
        x_.index_fill_(0, pad_rows, 0.0)

    # Sum over classes first, then average over rows.
    return torch.mean(torch.sum(-x_ * x, dim=d))


In [49]:
# Build the model with Transformer's default arguments (log_softmx=True, so
# forward() returns log-probabilities over the vocabulary).
model = Transformer()

In [52]:
# Re-initialise every parameter with more than one dimension (i.e. the weight
# matrices) with Xavier-uniform; 1-D parameters are left untouched.
# Original note: needed because deepcopy was used to save computation time.
for p in filter(lambda param: param.dim() > 1, model.parameters()):
    nn.init.xavier_uniform_(p)

In [50]:
# Adam over every parameter registered on `model`. No hyper-parameters are
# passed, so torch.optim.Adam's defaults apply.
optimizer = torch.optim.Adam(model.parameters())