In [1]:
import torch
from torch import nn
import numpy as np
import math
import time
import copy


In [2]:
# Arch units
## Self Attention Unit
## Multi Head Attention
## Encode Decode Unit
## Norm + Residual Layer
## Feed Forward
## Input Positional Encoding 


In [49]:
# parameters from paper
N = 2 # 6
d_model = 512
h = 8
d_k = d_v = d_model//h
d_ff = 2048 #128
vocab_size = 11


In [50]:
class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)
        
    def forward(self, Q, K, V, mask):
        b, q_seq, _ = Q.size()
        b, k_seq, _ = K.size()
        query = self.W_Q(Q).view(b, q_seq, h, d_k) # view (b, q_seq, h, d)
        key = self.W_K(K).view(b, k_seq, h, d_k) # view (b, k_seq, h, d)
        value  = self.W_V(V).view(b, k_seq, h, d_k) # view (b, k_seq, h, d)
        
        query = query.transpose(1, 2).contiguous() # view (b, h, q_seq, d)
        key = key.transpose(1, 2).contiguous() # view (b, h, k_seq, d)
        value = value.transpose(1, 2).contiguous() # view (b, h, k_seq, d)
        
        qk = query.matmul(key.transpose(-2,-1))
        scale_qk = qk/(math.sqrt(d_k)) # shape (b, h, q_seq, k_seq)
        
        if mask is not None: # mask size (b, 1, k_seq)
            mask = mask.unsqueeze(1) # mask size (b, 1, 1, k_seq)
            scale_qk = scale_qk.masked_fill(mask==0, 1e-9)
        
        softmax_qk = nn.functional.softmax(scale_qk, dim=-1) # (b, h, q_seq, k_seq)
        weighted_value = softmax_qk.matmul(value) # (b, h, q_seq, d)
        return self.W_O(weighted_value.transpose(2,1).contiguous().view(b, q_seq, h*d_k)) # (b, h, d_model)


In [51]:
class LayerNorm(nn.Module):
    def __init__(self, d_mod=d_model):
        super(LayerNorm, self).__init__()
        self.d_mod = d_mod
        # https://stackoverflow.com/questions/39095252/fail-to-implement-layer-normalization-with-keras
        # https://stackoverflow.com/questions/50935345/understanding-torch-nn-parameter
        self.alpha = nn.Parameter(torch.ones(d_mod))
        self.beta = nn.Parameter(torch.zeros(d_mod))
    def forward(self, x, eps=1e-6):
        u = x.mean(-1, keepdim=True)
        sigma = x.std(-1, keepdim=True)
        return self.alpha * (x - u)/(1/(sigma + eps)) + self.beta
    

In [169]:
class EncoderCell(nn.Module):
    def __init__(self):
        super(EncoderCell, self).__init__()
        self.attn = MultiHeadAttention()
        self.norm_1 = LayerNorm()
        self.pff = nn.Sequential(
                    nn.Linear(d_model, d_ff),
                    nn.ReLU(),
                    nn.Linear(d_ff, d_model))
        self.norm_2 = LayerNorm()
    
    def forward(self, x, src_mask=None):
        print("x ", x)
        x_norm_1 = self.norm_1(x + self.attn(x, x, x, src_mask))  # Layer 1
        print("norm_1", x_norm_1)
#         return self.norm_2(x_norm_1 + self.pff(x_norm_1)) # Layer 2
        alpha = self.norm_2(x_norm_1 + self.pff(x_norm_1))
        print("alpha", alpha)
        return alpha

In [170]:
class DecoderCell(nn.Module):
    def __init__(self):
        super(DecoderCell, self).__init__()
        self.attn = MultiHeadAttention()
        self.norm_1 = LayerNorm()
        self.attn = MultiHeadAttention()
        self.norm_2 = LayerNorm()
        self.pff = nn.Sequential(
                    nn.Linear(d_model, d_ff),
                    nn.ReLU(),
                    nn.Linear(d_ff, d_model))
        self.norm_3 = LayerNorm()
        
    def forward(self, x, enc, src_mask=None, trg_mask=None):
        x_norm_1 = self.norm_1(x + self.attn(x, x, x, trg_mask))
        x_norm_2 = self.norm_2(x_norm_1 + self.attn(x_norm_1, enc, enc, src_mask))
        return self.norm_3(x_norm_2 + self.pff(x_norm_2)) # (b, seq, d_model)


In [171]:
class EncoderStack(nn.Module):
    def __init__(self, N):
        super(EncoderStack, self).__init__()
        self.N = N
        
    def forward(self, x, src_mask):
        cell = EncoderCell()
        encoders = nn.ModuleList([copy.deepcopy(cell) for _ in range(self.N)])
        
        for enc in encoders:
            x = enc(x, src_mask)
        return x
    

In [172]:
class DecoderStack(nn.Module):
    def __init__(self, N):
        super(DecoderStack, self).__init__()
        self.N = N
    
    def forward(self, x, enc, src_mask, trg_mask):
        cell = DecoderCell()
        decoders = nn.ModuleList([copy.deepcopy(cell) for _ in range(self.N)])
        for decdr in decoders:
            x = decdr(x, enc, src_mask, trg_mask)
        return x
    

In [173]:
class EmbeddingLayer(nn.Module):
    def __init__(self, vocab_size, d_model):
        super(EmbeddingLayer, self).__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(self.vocab_size, self.d_model)
        
    def forward(self, x):
        return self.embedding(x) * math.sqrt(d_model)


In [174]:
# https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dpout=0.1, max_seq=50):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dpout)
        
        pe_matx = torch.zeros(max_seq, d_model)
        position = torch.arange(0, max_seq, dtype=torch.float).unsqueeze(-1)
        w_t = torch.exp(torch.arange(0, d_model, 2).float() * -math.log(10000)/d_model)
        val = position * w_t
        pe_matx[:, 0::2] = torch.sin(val)
        pe_matx[:, 1::2] = torch.cos(val)
        pe_matx = pe_matx.unsqueeze(1)
        self.register_buffer("pe_matx", pe_matx)
        
    def forward(self, x):
        # x - (batch, seq, emb), pe_matrix - (max_seq, 1, d_model)
        x += self.pe_matx[:x.size(0), :]
        return(self.dropout(x))
    

In [175]:
class Transformer(nn.Module):
    def __init__(self, embedd = True, log_softmx=True):
        super(Transformer, self).__init__()
#         self.W_in = nn.Linear(word_emb_dim, d_model)
        self.encoder = EncoderStack(N)
        self.decoder = DecoderStack(N)
        # https://stats.stackexchange.com/questions/392213/understand-the-output-layer-of-transformer
        self.W_out = nn.Linear(d_model, vocab_size)
        self.sftmx = log_softmx
        self.embedd = embedd
        if self.embedd:
            embed_x = EmbeddingLayer(vocab_size, d_model)
            embed_y = EmbeddingLayer(vocab_size, d_model)
            pe_x = PositionalEncoding(d_model)
            pe_y = copy.deepcopy(pe_x)
            self.enc_x = nn.Sequential(embed_x, pe_x)
            self.enc_y = nn.Sequential(embed_y, pe_y)
        
    def forward(self, inp_x, inp_y, src_mask, trg_mask):
        if self.embedd:
            inp_x, inp_y = self.enc_x(inp_x), self.enc_y(inp_y)
        enc_x = self.encoder(inp_x, src_mask)
        dec_x = self.decoder(inp_y, enc_x, src_mask, trg_mask)
        if self.sftmx:
            return nn.functional.log_softmax(self.W_out(dec_x), dim=-1)
        return self.W_out(dec_x)
        

In [176]:
# https://www.reddit.com/r/MachineLearning/comments/bjgpt2/d_confused_about_using_masking_in_transformer/

In [177]:
# https://nlp.seas.harvard.edu/2018/04/03/attention.html
class Batch:
    def __init__(self, src, trg=None, pad=0): # size src, trg (b, seq)
        self.src = src 
        self.src_mask = (src != pad).unsqueeze(-2)
        if trg is not None:
            self.trg = trg[:,:-1] # size (b,0:seq-1)
            self.trg_y = trg[:,1:] # size (b,1:seq)
            self.trg_mask = self.std_mask(self.trg, pad)
            self.ntokens = (self.trg_y != pad).data.sum() # size (1)
    
    @staticmethod
    def std_mask(tgt, pad):
        tgt_mask = (tgt != pad).unsqueeze(-2) # size (b, 1, seq)
        mask = torch.from_numpy(np.triu(np.ones((1,tgt.shape[-1],tgt.shape[-1])), k=1).astype('uint8')) == 0
        return tgt_mask & mask # size (b, 1, seq) * (1, seq, seq) -> (b, seq, seq) 
    
#     @staticmethod
#     def subsequent_mask(size):
#         return torch.from_numpy(np.triu(np.ones((1,size,size)), k=1).astype('uint8')) == 0 # size (1, seq, seq)


In [178]:
# https://github.com/pytorch/pytorch/issues/7455    
def labelSmoothingLoss(x, y, epsilon, padding_value=0, cls=1, d=-1):
    # concat x, y batch as index_fill_ don't support vector dim > 1
#     x = x.view(-1, x.size(-1))    
    x=x.contiguous().view(-1, x.size(-1))
    y=y.contiguous().view(-1)
    
    x_ = x.data.clone()
    x_.fill_(epsilon / (x_.size(-1) - cls))
    x_.scatter_(d, y.data.unsqueeze(-1), (1 - epsilon))
    x_[:, padding_value] = 0
    mask = torch.nonzero(y.data == padding_value)
    if mask.dim() > 0:
        x_.index_fill_(0, mask.squeeze(), 0.0)
    return torch.mean(torch.sum(-x_*x), dim=d) # x_ is true distribution and x is prediction


In [179]:
model = Transformer()

In [180]:
# init all parameters as we used deepcopy to save computation tym
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
        

In [181]:
optimizer = torch.optim.Adam(model.parameters())


In [182]:
# https://nlp.seas.harvard.edu/2018/04/03/attention.html#synthetic-data
def data_generation(V, batch, nbatches):
    for i in range(nbatches):
        data = torch.randint(1, V, size=(batch, 10))
        data[:, 0] = 1
        src = data.clone().detach()
        trg = data.clone().detach()
        yield Batch(src, trg, 0)


In [183]:
# https://nlp.seas.harvard.edu/2018/04/03/attention.html#synthetic-data
def run_epoch(data_itr, model, optimizer):
    start = time.time()
    total_token = 0
    total_loss = 0
    tokens = 0
    
    for i, batch in enumerate(data_itr):
        optimizer.zero_grad()
        
        outp = model(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
        loss = labelSmoothingLoss(outp, batch.trg_y, batch.ntokens)
        
        break
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        
        total_loss+=loss
        total_token+=batch.ntokens
        tokens+=batch.ntokens
        
        if i%30 == 1:
            elapsed = time.time() - start
            print("Epoch Step: %d Loss: %f Tokens per Sec: %f" %
                    (i, loss / batch.ntokens, tokens / elapsed))
            start = time.time()
            tokens = 0
#     return total_loss/total_token


In [184]:
run_epoch(data_generation(vocab_size, 30, 20), model, optimizer)


x  tensor([[[-0.0000,  1.2420,  2.6834,  ...,  2.0979, -2.0528, -1.1805],
         [ 2.1321,  0.8550,  1.8131,  ...,  1.8008,  1.2272, -0.5109],
         [ 2.6156,  3.6029, -1.1444,  ..., -0.8462,  0.3334,  3.3015],
         ...,
         [ 1.7315, -0.0000,  0.0000,  ..., -0.8977, -1.1905,  2.2879],
         [-1.0748, -0.3836,  2.5188,  ...,  2.3092,  0.4682, -1.4502],
         [ 0.0000, -0.1461,  1.0798,  ..., -0.8977, -1.1905,  2.2879]],

        [[ 0.1791,  0.7312,  3.5966,  ...,  2.0979, -2.0527, -0.0000],
         [-0.1398, -0.8943,  0.0000,  ...,  2.3092,  0.0000, -0.0000],
         [-0.1398, -0.8943,  3.4320,  ...,  2.3092,  0.4683, -1.4502],
         ...,
         [ 3.0671,  0.3442,  2.7263,  ...,  1.8008,  1.2273, -0.5109],
         [ 3.4505, -1.7813,  1.1816,  ..., -1.0663,  1.4006,  2.9691],
         [ 3.0671,  0.3442,  2.7263,  ...,  0.0000,  1.2273, -0.5109]],

        [[ 0.2544, -0.0000,  3.7239,  ...,  2.0979, -2.0525, -1.1805],
         [ 3.5259, -0.0000,  1.3089,  ...,

In [164]:
mah = MultiHeadAttention()

In [166]:
q = torch.rand(5, 10, d_model)
k = torch.rand(5, 10, d_model)
v = torch.rand(5, 10, d_model)

In [168]:
mah(q,k,v,None)

tensor([[[-0.3094,  0.0192,  0.1885,  ...,  0.2809,  0.1203, -0.1954],
         [-0.3081,  0.0196,  0.1878,  ...,  0.2825,  0.1199, -0.1942],
         [-0.3095,  0.0203,  0.1880,  ...,  0.2813,  0.1213, -0.1940],
         ...,
         [-0.3083,  0.0193,  0.1872,  ...,  0.2823,  0.1200, -0.1944],
         [-0.3101,  0.0210,  0.1881,  ...,  0.2816,  0.1207, -0.1940],
         [-0.3092,  0.0209,  0.1863,  ...,  0.2812,  0.1212, -0.1941]],

        [[-0.4010,  0.0192,  0.1289,  ...,  0.2678,  0.1162, -0.1459],
         [-0.4009,  0.0195,  0.1298,  ...,  0.2670,  0.1171, -0.1450],
         [-0.4010,  0.0190,  0.1286,  ...,  0.2680,  0.1162, -0.1451],
         ...,
         [-0.3988,  0.0197,  0.1286,  ...,  0.2680,  0.1168, -0.1469],
         [-0.3995,  0.0194,  0.1286,  ...,  0.2676,  0.1172, -0.1450],
         [-0.4004,  0.0189,  0.1288,  ...,  0.2686,  0.1178, -0.1457]],

        [[-0.3997,  0.0504,  0.1788,  ...,  0.1710,  0.1605, -0.1597],
         [-0.3989,  0.0492,  0.1800,  ...,  0