In [1]:
import torch
from torch import nn
import numpy as np
import math
import time
import copy


In [2]:
# Arch units
## Self Attention Unit
## Multi Head Attention
## Encode Decode Unit
## Norm + Residual Layer
## Feed Forward
## Input Positional Encoding 


In [3]:
# parameters from paper
N = 2 # 6
d_model = 512
h = 8
d_k = d_v = d_model//h
d_ff = 2048 #128
vocab_size = 11


In [4]:
class MultiHeadAttention(nn.Module):
    def __init__(self, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, Q, K, V, mask, dropout=None):
        b, q_seq, _ = Q.size()
        b, k_seq, _ = K.size()
        query = self.W_Q(Q).view(b, q_seq, h, d_k) # view (b, q_seq, h, d)
        key = self.W_K(K).view(b, k_seq, h, d_k) # view (b, k_seq, h, d)
        value  = self.W_V(V).view(b, k_seq, h, d_k) # view (b, k_seq, h, d)
        
        query = query.transpose(1, 2).contiguous() # view (b, h, q_seq, d)
        key = key.transpose(1, 2).contiguous() # view (b, h, k_seq, d)
        value = value.transpose(1, 2).contiguous() # view (b, h, k_seq, d)
        
        qk = query.matmul(key.transpose(-2,-1))
        scale_qk = qk/(math.sqrt(d_k)) # shape (b, h, q_seq, k_seq)
        
        if mask is not None: # mask size (b, 1, k_seq)
            mask = mask.unsqueeze(1) # mask size (b, 1, 1, k_seq)
            scale_qk = scale_qk.masked_fill(mask==0, 1e-9)
        
        softmax_qk = nn.functional.softmax(scale_qk, dim=-1) # (b, h, q_seq, k_seq)
        if dropout is not None:
            softmax_qk = self.dropout(softmax_qk)
        weighted_value = softmax_qk.matmul(value) # (b, h, q_seq, d)
        return self.W_O(weighted_value.transpose(2,1).contiguous().view(b, q_seq, h*d_k)) # (b, h, d_model)


In [5]:
class LayerNorm(nn.Module):
    def __init__(self, d_mod=d_model):
        super(LayerNorm, self).__init__()
        self.d_mod = d_mod
        # https://stackoverflow.com/questions/39095252/fail-to-implement-layer-normalization-with-keras
        # https://stackoverflow.com/questions/50935345/understanding-torch-nn-parameter
        self.alpha = nn.Parameter(torch.ones(d_mod))
        self.beta = nn.Parameter(torch.zeros(d_mod))
    def forward(self, x, eps=1e-6):
        u = x.mean(-1, keepdim=True)
        sigma = x.std(-1, keepdim=True)
        return self.alpha * (x - u)/(sigma + eps) + self.beta
    

In [6]:
class EncoderCell(nn.Module):
    def __init__(self, dropout=0.1, Adropout=0.1):
        super(EncoderCell, self).__init__()
        self.attn = MultiHeadAttention(Adropout)
        self.norm_1 = LayerNorm()
        self.pff = nn.Sequential(
                    nn.Linear(d_model, d_ff),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(d_ff, d_model))
        self.norm_2 = LayerNorm()
    
    def forward(self, x, src_mask=None):
        x_norm_1 = self.norm_1(x + self.attn(x, x, x, src_mask))  # Layer 1
        return self.norm_2(x_norm_1 + self.pff(x_norm_1)) # Layer 2


In [7]:
class DecoderCell(nn.Module):
    def __init__(self, dropout=0.1, Ddropout=0.1):
        super(DecoderCell, self).__init__()
        self.attn = MultiHeadAttention(Ddropout)
        self.norm_1 = LayerNorm()
        self.attn = MultiHeadAttention(Ddropout)
        self.norm_2 = LayerNorm()
        self.pff = nn.Sequential(
                    nn.Linear(d_model, d_ff),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(d_ff, d_model))
        self.norm_3 = LayerNorm()
        
    def forward(self, x, enc, src_mask=None, trg_mask=None):
        x_norm_1 = self.norm_1(x + self.attn(x, x, x, trg_mask))
        x_norm_2 = self.norm_2(x_norm_1 + self.attn(x_norm_1, enc, enc, src_mask))
        return self.norm_3(x_norm_2 + self.pff(x_norm_2)) # (b, seq, d_model)


In [8]:
class EncoderStack(nn.Module):
    def __init__(self, N, Edropout=0.1, Adropout=0.1):
        super(EncoderStack, self).__init__()
        self.N = N
        self.Edropout = Edropout
        self.Adropout = Adropout
        self.encoders = nn.ModuleList([EncoderCell(self.Edropout, self.Adropout) \
                                       for _ in range(self.N)])
 
    def forward(self, x, src_mask):
        for enc in self.encoders:
            x = enc(x, src_mask)
        return x


In [9]:
class DecoderStack(nn.Module):
    def __init__(self, N, Ddropout=0.1, Adropout=0.1):
        super(DecoderStack, self).__init__()
        self.N = N
        self.Ddropout = Ddropout
        self.Adropout = Adropout
        self.decoders = nn.ModuleList([DecoderCell(self.Ddropout, self.Adropout) \
                                       for _ in range(self.N)])
        
    def forward(self, x, enc, src_mask, trg_mask):
        for decdr in self.decoders:
            x = decdr(x, enc, src_mask, trg_mask)
        return x


In [10]:
class EmbeddingLayer(nn.Module):
    def __init__(self, vocab_size, d_model):
        super(EmbeddingLayer, self).__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(self.vocab_size, self.d_model)
        
    def forward(self, x):
        return self.embedding(x) * math.sqrt(d_model)


In [11]:
# https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dpout=0.1, max_seq=50):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dpout)
        
        pe_matx = torch.zeros(max_seq, d_model, requires_grad=False)
        position = torch.arange(0, max_seq, dtype=torch.float).unsqueeze(-1)
        w_t = torch.exp(torch.arange(0, d_model, 2).float() * -math.log(10000)/d_model)
        val = position * w_t
        pe_matx[:, 0::2] = torch.sin(val)
        pe_matx[:, 1::2] = torch.cos(val)
        pe_matx = pe_matx.unsqueeze(1)
        self.register_buffer("pe_matx", pe_matx)
        
    def forward(self, x):
        # x - (batch, seq, emb), pe_matrix - (max_seq, 1, d_model)
        x += self.pe_matx[:x.size(0), :]
        return(self.dropout(x))
    

In [12]:
class Transformer(nn.Module):
    def __init__(self, embedd = True, dropout=0.1):
        super(Transformer, self).__init__()
#         self.W_in = nn.Linear(word_emb_dim, d_model)
        self.encoderStack = EncoderStack(N, Adropout=dropout)
        self.decoderStack = DecoderStack(N, Adropout=dropout)
        # https://stats.stackexchange.com/questions/392213/understand-the-output-layer-of-transformer
        self.W_out = nn.Linear(d_model, vocab_size)
        self.embedd = embedd
        if self.embedd:
            embed_x = EmbeddingLayer(vocab_size, d_model)
            embed_y = EmbeddingLayer(vocab_size, d_model)
            pe_x = PositionalEncoding(d_model)
            pe_y = copy.deepcopy(pe_x)
            self.enc_x = nn.Sequential(embed_x, pe_x)
            self.enc_y = nn.Sequential(embed_y, pe_y)
        
    def forward(self, inp_x, inp_y, src_mask, trg_mask, sftmx=True):
        
        dec_x = self.decoder(inp_y, self.encoder(inp_x, src_mask), 
                             src_mask, trg_mask)
        
        if sftmx:
            return nn.functional.log_softmax(self.W_out(dec_x), dim=-1)
        return self.W_out(dec_x)
    
    def encoder(self, inp_x, src_mask):
        if self.embedd:
            inp_x = self.enc_x(inp_x)
        return self.encoderStack(inp_x, src_mask)
    
    def decoder(self, inp_y, enc_x, src_mask, trg_mask):
        if self.embedd:
            inp_y = self.enc_y(inp_y)
        return self.decoderStack(inp_y, enc_x, src_mask, trg_mask)


In [13]:
# https://www.reddit.com/r/MachineLearning/comments/bjgpt2/d_confused_about_using_masking_in_transformer/

In [14]:
# https://nlp.seas.harvard.edu/2018/04/03/attention.html
class Batch:
    def __init__(self, src, trg=None, pad=0): # size src, trg (b, seq)
        self.src = src 
        self.src_mask = (src != pad).unsqueeze(-2)
        if trg is not None:
            self.trg = trg[:,:-1] # size (b,0:seq-1)
            self.trg_y = trg[:,1:] # size (b,1:seq)
            self.trg_mask = self.std_mask(self.trg, pad)
            self.ntokens = (self.trg_y != pad).data.sum() # size (1)
    
    @staticmethod
    def std_mask(tgt, pad):
        tgt_mask = (tgt != pad).unsqueeze(-2) # size (b, 1, seq)
        mask = torch.from_numpy(np.triu(np.ones((1,tgt.shape[-1],tgt.shape[-1])), k=1).astype('uint8')) == 0
        return tgt_mask & mask # size (b, 1, seq) * (1, seq, seq) -> (b, seq, seq) 
    


In [15]:
# https://github.com/pytorch/pytorch/issues/7455    
kldivLoss = nn.KLDivLoss(size_average=False)
# CELoss = nn.CrossEntropyLoss()

def labelSmoothingLoss(x, y, epsilon=0.0, padding_value=0, cls=2, d=-1):
    # concat x, y batch as index_fill_ don't support vector dim > 1
#     x = x.view(-1, x.size(-1))    
    x=x.contiguous().view(-1, x.size(-1))
    y=y.contiguous().view(-1)
    
    x_ = x.data.clone()
    x_.fill_(epsilon / (x_.size(-1) - cls))
    x_.scatter_(d, y.data.unsqueeze(-1), (1 - epsilon))
    x_[:, padding_value] = 0
    mask = torch.nonzero(y.data == padding_value)
    if mask.dim() > 0:
        x_.index_fill_(0, mask.squeeze(), 0.0)
#     print(x, x_)
#     return CELoss(x, x_, )
#     return torch.mean(torch.sum(-x_*x), dim=d) # x_ is true distribution and x is prediction
    return kldivLoss(x, copy.deepcopy(x_))




In [16]:
# https://nlp.seas.harvard.edu/2018/04/03/attention.html#synthetic-data
def data_generation(V, batch, nbatches):
    for i in range(nbatches):
        data = torch.randint(1, V, size=(batch, 10))
        data[:, 0] = 1
        src = data.clone().detach()
        trg = data.clone().detach()
        yield Batch(src, trg, 0)


In [17]:
# https://nlp.seas.harvard.edu/2018/04/03/attention.html#synthetic-data
def run_epoch(data_itr, model, opt):
    start = time.time()
    total_token = 0
    total_loss = 0
    tokens = 0
    
    for i, batch in enumerate(data_itr):
        opt.optimizer.zero_grad()
        
        outp = model(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
        loss = labelSmoothingLoss(outp, batch.trg_y)
#         loss = loss/batch.ntokens
        loss.backward()
#         torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
#         loss.data = loss.data * batch.ntokens

        opt.step()
        
        total_loss+=loss
        total_token+=batch.ntokens
        tokens+=batch.ntokens
        
        if i%50 == 1:
            elapsed = time.time() - start
            print("Epoch Step: %d Loss: %f Tokens per Sec: %f" %
                    (i, loss / batch.ntokens, tokens / elapsed))
            start = time.time()
            tokens = 0
    return total_loss/total_token


In [18]:
class NoamOpt:
    "Optim wrapper that implements rate."
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0
        
    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()
        
    def rate(self, step = None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        return self.factor * \
            (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))


In [19]:
model = Transformer()


In [20]:
# init all parameters as we used deepcopy to save computation tym
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)


In [21]:
# optimizer = torch.optim.Adam(model.parameters())
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

# https://nlp.seas.harvard.edu/2018/04/03/attention.html
model_opt = NoamOpt(d_model, 1, 400,
        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))


In [22]:
for epoch in range(10):
    model.train()
    run_epoch(data_generation(vocab_size, 30, 20), model, model_opt)
    model.eval()
    print(run_epoch(data_generation(vocab_size, 30, 5), model, model_opt).data)
    
#     scheduler.step()
    

Epoch Step: 1 Loss: 3.009474 Tokens per Sec: 716.724121
Epoch Step: 1 Loss: 1.846428 Tokens per Sec: 808.009216
tensor(1.8439)
Epoch Step: 1 Loss: 1.788058 Tokens per Sec: 755.007202
Epoch Step: 1 Loss: 1.551715 Tokens per Sec: 807.668640
tensor(1.5134)
Epoch Step: 1 Loss: 1.517454 Tokens per Sec: 729.454773
Epoch Step: 1 Loss: 1.463484 Tokens per Sec: 810.447144
tensor(1.4462)
Epoch Step: 1 Loss: 1.507231 Tokens per Sec: 758.847900
Epoch Step: 1 Loss: 0.941215 Tokens per Sec: 797.798157
tensor(0.9294)
Epoch Step: 1 Loss: 0.945399 Tokens per Sec: 694.597107
Epoch Step: 1 Loss: 0.927456 Tokens per Sec: 805.195007
tensor(0.8918)
Epoch Step: 1 Loss: 0.788258 Tokens per Sec: 752.071228
Epoch Step: 1 Loss: 0.821191 Tokens per Sec: 764.545532
tensor(0.8560)
Epoch Step: 1 Loss: 0.765289 Tokens per Sec: 614.856995
Epoch Step: 1 Loss: 0.754950 Tokens per Sec: 803.661316
tensor(0.6905)
Epoch Step: 1 Loss: 0.657492 Tokens per Sec: 759.692383
Epoch Step: 1 Loss: 0.782052 Tokens per Sec: 802.910645

In [None]:
model

In [None]:
# run_epoch(data_generation(vocab_size, 30, 20), model, optimizer)


In [None]:
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
    for i in range(max_len-1):
        out = model.decode(memory, src_mask, 
                           Variable(ys), 
                           Variable(subsequent_mask(ys.size(1))
                                    .type_as(src.data)))
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim = 1)
        next_word = next_word.data[0]
        ys = torch.cat([ys, 
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
    return ys

model.eval()
src = torch.tensor(torch.LongTensor([[1,2,3,4,5,6,7,8,9,10]]) )
src_mask = torch.tensor(torch.ones(1, 1, 10) )
print(greedy_decode(model, src, src_mask, max_len=10, start_symbol=1))