In [2]:
#from collections import defaultdict

from einops import rearrange
import numpy as np

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

#import random
import sys

from datetime import datetime as dt
#import time

# Helpers

## functions

In [4]:
def create_masks(question, reply_input, reply_target):
    """Build the padding and causal masks for one batch.

    Token id 0 is treated as padding in all three inputs. Every mask is a
    boolean tensor that lives on the device of the tensor it is derived from,
    so the function does not rely on a notebook-level ``device`` variable.

    Args:
        question: integer tensor of encoder token ids; assumed to be
            (batch_size, max_words).
        reply_input: integer tensor of decoder input token ids; assumed to be
            (batch_size, max_words).
        reply_target: integer tensor of decoder target token ids.

    Returns:
        A tuple ``(question_mask, reply_input_mask, reply_target_mask)``:

        * question_mask -- ``question != 0`` with two singleton axes inserted
          at position 1, i.e. (batch_size, 1, 1, max_words).
        * reply_input_mask -- (batch_size, 1, max_words, max_words); entry
          ``[b, 0, i, j]`` is True when ``j <= i`` and token ``j`` of reply
          ``b`` is not padding.
        * reply_target_mask -- ``reply_target != 0``, same shape as
          ``reply_target``.
    """

    def subsequent_mask(size, device):
        # Lower-triangular (diagonal included) boolean matrix: position i may
        # look at positions j <= i. Built directly as bool on the target
        # device instead of a CPU uint8 tensor that is converted afterwards.
        positions = torch.arange(size, device=device)
        allowed = positions.unsqueeze(0) <= positions.unsqueeze(1)
        return allowed.unsqueeze(0)

    question_mask = (question != 0).unsqueeze(1).unsqueeze(1)  # (batch_size, 1, 1, max_words)

    reply_input_mask = (reply_input != 0).unsqueeze(1)  # (batch_size, 1, max_words)
    reply_input_mask = reply_input_mask & subsequent_mask(
        reply_input.size(-1), reply_input.device
    )  # broadcasts to (batch_size, max_words, max_words)
    reply_input_mask = reply_input_mask.unsqueeze(1)  # (batch_size, 1, max_words, max_words)

    reply_target_mask = reply_target != 0  # (batch_size, max_words)

    return question_mask, reply_input_mask, reply_target_mask

# Models

## embedding

In [14]:
class Embedder(nn.Module):
    """Token embedding plus a learned positional table, followed by dropout.

    Args:
        vocab_size: number of rows of the token embedding table.
        d_emb: embedding width.
        seq_length: number of rows of the positional table ``pe``.
        emb_drop: dropout probability applied to the summed embeddings.
    """

    def __init__(self, vocab_size, d_emb, seq_length, emb_drop):
        super().__init__()
        self.d_emb, self.seq_length = d_emb, seq_length
        # Factor the token embeddings are multiplied by in ``forward``.
        self.sr_d_emb = np.sqrt(d_emb)

        self.embed = nn.Embedding(vocab_size, d_emb)
        self.dropout = nn.Dropout(emb_drop)
        # Trainable positional table, initialised uniformly in [0, 1).
        self.pe = nn.Parameter(torch.rand(seq_length, d_emb))

    def forward(self, x):
        """Embed token ids ``x`` and add the positional rows.

        ``x.size(1)`` selects how many leading rows of ``pe`` are added, so
        ``x`` is expected to carry the sequence on axis 1.
        """
        # Scale the token embeddings up relative to the positional table.
        scaled_tokens = self.embed(x) * self.sr_d_emb
        positions = self.pe[:x.size(1)]
        return self.dropout(scaled_tokens + positions)

## transformer blocks

### Multi Head Attention

In [6]:
class Multi_Head_Attention(nn.Module):
    """Multi-head scaled dot-product attention.

    With ``decode = False`` queries, keys and values are all projected from
    the same input ``x``. With ``decode = True`` the queries are projected
    from ``y`` while keys and values are projected from ``x``.

    Args:
        d_emb: width of the input and output features.
        hidden_dim: total width of the projected queries/keys/values; it is
            split evenly across ``heads``.
        heads: number of attention heads; must divide ``hidden_dim``.
        decode: select the two-input (``x`` and ``y``) variant.

    Raises:
        ValueError: if ``heads`` is not positive or does not divide
            ``hidden_dim``.
    """

    def __init__(self, d_emb, hidden_dim, heads, decode = False):
        super().__init__()
        # Fail fast: a non-divisible configuration would otherwise only show
        # up as an einops shape error on the first forward pass.
        if heads <= 0 or hidden_dim % heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by a positive "
                f"number of heads ({heads})"
            )

        self.h_dim = hidden_dim
        self.heads = heads
        self.dim_per_head = self.h_dim // self.heads

        self.decode = decode
        if self.decode:
            self.q = nn.Linear(d_emb, self.h_dim, bias = False)
            self.kv = nn.Linear(d_emb, self.h_dim * 2, bias = False)
        else:
            # One fused projection; split into q, k, v in ``forward``.
            self.qkv = nn.Linear(d_emb, self.h_dim * 3, bias = False)

        # Maps the concatenated heads back to the embedding width.
        self.unifyheads = nn.Linear(self.h_dim, d_emb)

    def self_attention(self, q, k, v, mask):
        """Scaled dot-product attention over the last two axes.

        Positions where ``mask == 0`` get a score of ``-inf`` before the
        softmax, so they receive zero weight. ``mask`` must broadcast against
        the (..., queries, keys) score tensor.
        """
        scores = torch.einsum('...ij,...kj->...ik', q, k) / np.sqrt(self.dim_per_head)
        scores = scores.masked_fill(mask == 0, -float('inf'))
        weights = F.softmax(scores, dim = -1)
        return torch.einsum('...ij,...jk->...ik', weights, v)

    def forward(self, x, mask, y = None):
        """Attend and project back to ``d_emb`` features.

        Args:
            x: source of queries, keys and values, or of keys and values only
                when ``decode`` is set.
            mask: attention mask passed to ``self_attention``.
            y: source of the queries; required when ``decode`` is set and
                ignored otherwise.

        Raises:
            TypeError: if ``decode`` is set and ``y`` is not given.
        """
        if self.decode:
            if y is None:
                raise TypeError("forward() requires `y` when decode=True")
            q = self.q(y)
            k, v = self.kv(x).chunk(2, dim = -1)
        else:
            q, k, v = self.qkv(x).chunk(3, dim = -1)

        # (..., seq, heads * dim_per_head) -> (..., heads, seq, dim_per_head)
        q = rearrange(q, '... i (h j) -> ... h i j', h = self.heads)
        k = rearrange(k, '... i (h j) -> ... h i j', h = self.heads)
        v = rearrange(v, '... i (h j) -> ... h i j', h = self.heads)

        attended = self.self_attention(q, k, v, mask)
        # Merge the heads back into one feature axis.
        attended = rearrange(attended, '... h i j -> ... i (h j)').contiguous()

        return self.unifyheads(attended)

### Encoder

In [3]:
class Encoder_layer(nn.Module):
    """One encoder block: attention over ``x`` followed by a feed-forward net.

    Each sub-layer output is added to its input, the sum goes through the
    shared dropout, and the result is layer-normalised.

    Args:
        d_emb: feature width of the input and output.
        hidden_dim: projection width handed to the attention module.
        hidden_mult: the feed-forward hidden width is ``hidden_mult * d_emb``.
        heads: number of attention heads.
        enc_drop: dropout probability.
    """

    def __init__(self, d_emb, hidden_dim, hidden_mult, heads, enc_drop):
        super().__init__()
        ff_width = hidden_mult * d_emb

        self.dropout = nn.Dropout(enc_drop)

        self.mha = Multi_Head_Attention(d_emb, hidden_dim, heads)
        self.norm_1 = nn.LayerNorm(d_emb)
        # Position-wise feed-forward net with a GELU non-linearity.
        self.ff = nn.Sequential(
            nn.Linear(d_emb, ff_width),
            nn.GELU(),
            nn.Linear(ff_width, d_emb),
        )
        self.norm_2 = nn.LayerNorm(d_emb)

    def forward(self, x, q_mask):
        """Run attention (masked by ``q_mask``) and the feed-forward net on ``x``."""
        stages = (
            (lambda t: self.mha(t, q_mask), self.norm_1),
            (self.ff, self.norm_2),
        )
        for sublayer, norm in stages:
            # residual add -> dropout -> layer norm
            x = norm(self.dropout(sublayer(x) + x))
        return x

### Decoder

In [4]:
class Decoder_layer(nn.Module):
    """One decoder block: attention over ``y``, attention over ``x``, feed-forward.

    Each sub-layer output is added to the running ``y``, the sum goes through
    the shared dropout, and the result is layer-normalised.

    Args:
        d_emb: feature width of the input and output.
        hidden_dim: projection width handed to both attention modules.
        hidden_mult: the feed-forward hidden width is ``hidden_mult * d_emb``.
        heads: number of attention heads.
        dec_drop: dropout probability.
    """

    def __init__(self, d_emb, hidden_dim, hidden_mult, heads, dec_drop):
        super().__init__()
        ff_width = hidden_mult * d_emb

        self.dropout = nn.Dropout(dec_drop)

        # Attention of the reply over itself.
        self.mha_1 = Multi_Head_Attention(d_emb, hidden_dim, heads)
        self.norm_1 = nn.LayerNorm(d_emb)
        # Attention with queries from the reply and keys/values from ``x``.
        self.mha_2 = Multi_Head_Attention(d_emb, hidden_dim, heads, decode = True)
        self.norm_2 = nn.LayerNorm(d_emb)
        # Position-wise feed-forward net with a GELU non-linearity.
        self.ff = nn.Sequential(
            nn.Linear(d_emb, ff_width),
            nn.GELU(),
            nn.Linear(ff_width, d_emb),
        )
        self.norm_3 = nn.LayerNorm(d_emb)

    def forward(self, x, q_mask, y, r_mask):
        """Update ``y`` using ``x`` as the attended-to memory.

        ``r_mask`` masks the first attention (over ``y``); ``q_mask`` masks the
        second attention (over ``x``).
        """
        stages = (
            (lambda t: self.mha_1(t, r_mask), self.norm_1),
            (lambda t: self.mha_2(x, q_mask, t), self.norm_2),
            (self.ff, self.norm_3),
        )
        for sublayer, norm in stages:
            # residual add -> dropout -> layer norm
            y = norm(self.dropout(sublayer(y) + y))
        return y

## Transformer

In [7]:
class Transformer(nn.Module):
    """Encoder-decoder model assembled from the embedding and layer classes.

    Args:
        model_hp: 8-item sequence unpacked as ``(vocab_size_in,
            vocab_size_out, d_emb, hidden_dim, hidden_mult, seq_length,
            heads_order, dropouts)``. ``heads_order`` yields the head count of
            each layer (one encoder and one decoder layer per entry) and
            ``dropouts`` is ``(emb_drop, enc_drop, dec_drop)``.

    Attributes set here besides the sub-modules: ``epochs`` (starts at 0) and
    ``losses`` (starts as an empty list).
    """

    def __init__(self, model_hp):
        super().__init__()
        self.epochs = 0
        self.losses = []

        (vocab_size_in, vocab_size_out, d_emb, hidden_dim,
         hidden_mult, seq_length, heads_order, dropouts) = model_hp
        emb_drop, enc_drop, dec_drop = dropouts

        self.embedder_enc = Embedder(vocab_size_in, d_emb, seq_length, emb_drop)
        self.embedder_dec = Embedder(vocab_size_out, d_emb, seq_length, emb_drop)

        # All encoder layers are built before any decoder layer.
        encoder_layers = []
        for heads in heads_order:
            encoder_layers.append(
                Encoder_layer(d_emb, hidden_dim, hidden_mult, heads, enc_drop)
            )
        self.encoder = nn.ModuleList(encoder_layers)

        decoder_layers = []
        for heads in heads_order:
            decoder_layers.append(
                Decoder_layer(d_emb, hidden_dim, hidden_mult, heads, dec_drop)
            )
        self.decoder = nn.ModuleList(decoder_layers)

        # Output head: d_emb -> 2 * d_emb -> vocab_size_out.
        self.out = nn.Sequential(
            nn.Linear(d_emb, d_emb * 2),
            nn.GELU(),
            nn.Linear(d_emb * 2, vocab_size_out),
        )

    def encode(self, x, q_mask):
        """Embed ``x`` and pass it through every encoder layer in order."""
        hidden = self.embedder_enc(x)
        for layer in self.encoder:
            hidden = layer(hidden, q_mask)
        return hidden

    def decode(self, x, q_mask, y, r_mask):
        """Embed ``y`` and pass it through every decoder layer, attending to ``x``."""
        hidden = self.embedder_dec(y)
        for layer in self.decoder:
            hidden = layer(x, q_mask, hidden, r_mask)
        return hidden

    def forward(self, x, q_mask, y, r_mask):
        """Return log-probabilities over the output vocabulary (last axis)."""
        memory = self.encode(x, q_mask)
        logits = self.out(self.decode(memory, q_mask, y, r_mask))
        return F.log_softmax(logits, dim = -1)

In [10]:
class AdamWarmup:
    """Optimizer wrapper that sets a warm-up/decay learning rate on each step.

    The rate at step ``s`` is
    ``d_emb ** -0.5 * min(s ** -0.5, s * warmup_steps ** -1.5)``: it grows
    linearly in ``s`` while the second term is the smaller one and decays as
    ``s ** -0.5`` afterwards.

    Args:
        d_emb: model width used in the scale factor.
        warmup_steps: step count appearing in the linear term.
        optimizer: object exposing ``param_groups`` (dicts with an ``'lr'``
            key) and ``step()``.

    Attributes:
        current_step: number of ``step()`` calls made so far.
        lr: learning rate applied by the most recent ``step()`` (0 before the
            first one).
    """

    def __init__(self, d_emb, warmup_steps, optimizer):
        self.d_emb = d_emb
        self.warmup_steps = warmup_steps
        self.optimizer = optimizer
        self.current_step = 0
        self.lr = 0

    def get_lr(self):
        """Return the learning rate for ``current_step``.

        Before the first ``step()`` the rate is 0.0; evaluating the formula
        there would raise ``ZeroDivisionError`` from ``0 ** -0.5``.
        """
        step = self.current_step
        if step <= 0:
            return 0.0
        return self.d_emb ** (-0.5) * min(step ** (-0.5),
                                          step * self.warmup_steps ** (-1.5))

    def step(self):
        """Advance one step, write the new rate to every param group, then step the optimizer."""
        # Increment first so the very first update uses the step-1 rate.
        self.current_step += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        # Remember the rate that was just applied.
        self.lr = lr
        self.optimizer.step()

In [11]:
class LossWithLS(nn.Module):
    """Masked KL-divergence loss against label-smoothed targets.

    The target distribution puts ``1 - smooth`` on the true class and spreads
    ``smooth`` evenly over the remaining ``vocab_size_out - 1`` classes.

    Args:
        vocab_size_out: number of output classes.
        smooth: probability mass moved away from the true class.
    """

    def __init__(self, vocab_size_out, smooth):
        super().__init__()
        self.criterion = nn.KLDivLoss(reduction = 'none')
        self.confidence = 1.0 - smooth
        self.smooth = smooth
        self.size = vocab_size_out

    def forward(self, prediction, target, mask):
        """
        prediction of shape: (batch_size, max_words, vocab_size), given as
        log-probabilities (the input form ``nn.KLDivLoss`` expects)
        target and mask of shape: (batch_size, max_words)

        Returns the per-position loss summed over the vocabulary, averaged
        over the positions selected by ``mask``. If ``mask`` selects no
        position the result is NaN (0 / 0).
        """
        prediction = prediction.reshape(-1, prediction.size(-1))   # (batch_size * max_words, vocab_size)
        target = target.reshape(-1)                                # (batch_size * max_words)
        # reshape (unlike view) also accepts non-contiguous masks
        mask = mask.reshape(-1).float()                            # (batch_size * max_words)

        # Build the smoothed targets on the prediction's own device and dtype
        # so no notebook-level `device` variable is needed.
        labels = torch.full_like(prediction, self.smooth / (self.size - 1))
        labels.scatter_(1, target.unsqueeze(1), self.confidence)

        loss = self.criterion(prediction, labels)    # (batch_size * max_words, vocab_size)
        loss = (loss.sum(1) * mask).sum() / mask.sum()
        return loss