In [1]:
from einops import rearrange
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

# Helpers

## functions

In [2]:
def create_masks(question, reply_input, reply_target):
    """Build the attention and loss masks for one batch of token ids.

    Token id 0 is treated as padding in all three inputs.

    Args:
        question: (batch_size, max_words) encoder input ids.
        reply_input: (batch_size, max_words) decoder input ids.
        reply_target: (batch_size, max_words) decoder target ids.

    Returns:
        question_mask: bool, (batch_size, 1, 1, max_words); True for non-padding keys.
        reply_input_mask: bool, (batch_size, 1, max_words, max_words); True where
            query position i may attend to key position j (j <= i and j not padding).
        reply_target_mask: bool, (batch_size, max_words); True for non-padding targets.

    All three masks are moved to the notebook-level ``device`` so they end up on
    the same device regardless of where the inputs live.
    """
    seq_len = reply_input.size(-1)

    # (batch_size, max_words) -> (batch_size, 1, 1, max_words): broadcasts over heads and queries
    question_mask = (question != 0).unsqueeze(1).unsqueeze(1).to(device)

    # Lower-triangular causal mask (diagonal included), created directly as bool on
    # the input's device instead of as uint8 on the CPU followed by a conversion.
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=reply_input.device))
    # padding mask over keys (batch_size, 1, max_words) & causal (max_words, max_words)
    # -> (batch_size, max_words, max_words) -> (batch_size, 1, max_words, max_words)
    reply_input_mask = ((reply_input != 0).unsqueeze(1) & causal).unsqueeze(1).to(device)

    reply_target_mask = (reply_target != 0).to(device)   # (batch_size, max_words)

    return question_mask, reply_input_mask, reply_target_mask

# Models

## embedding

In [3]:
class Embedder(nn.Module):
    """Token embedding plus a learned positional embedding, followed by dropout.

    Args:
        vocab_size: number of token ids the embedding table holds.
        d_emb: embedding width.
        seq_length: maximum number of positions covered by the positional table.
        emb_drop: dropout probability applied to the summed embeddings.
    """

    def __init__(self, vocab_size, d_emb, seq_length, emb_drop):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_emb)
        # token embeddings are scaled by sqrt(d_emb) so they are not dwarfed by positions
        self.sr_d_emb = np.sqrt(d_emb)
        self.dropout = nn.Dropout(emb_drop)
        # learned positional table, one row per position, initialised uniformly in [0, 1)
        self.pe = nn.Parameter(torch.rand(seq_length, d_emb))

    def forward(self, x):
        """Embed ids ``x`` (position along dim 1) and add the matching positional rows."""
        n_positions = x.size(1)
        token_vectors = self.embed(x) * self.sr_d_emb
        with_positions = token_vectors + self.pe[:n_positions]
        return self.dropout(with_positions)

## transformer blocks

### Multi Head Attention

In [None]:
class Multi_Head_Attention(nn.Module):
    """Multi-head scaled dot-product attention.

    With ``decode=False``, queries, keys and values are all projected from ``x``
    (self-attention). With ``decode=True``, queries are projected from ``y`` and
    keys/values from ``x`` (cross-attention, e.g. decoder queries over encoder output).

    Args:
        d_emb: input and output feature width.
        d_hid: total projection width, split evenly over ``heads``.
        heads: number of attention heads; must divide ``d_hid``.
        decode: use separate query / key-value projections for cross-attention.

    Raises:
        ValueError: if ``d_hid`` is not divisible by ``heads``.
    """

    def __init__(self, d_emb, d_hid, heads, decode = False):
        super().__init__()
        if d_hid % heads != 0:
            raise ValueError(f"d_hid ({d_hid}) must be divisible by heads ({heads})")
        self.d_hid = d_hid
        self.heads = heads
        self.dim_per_head = self.d_hid // self.heads

        self.decode = decode
        if self.decode:
            self.q = nn.Linear(d_emb, self.d_hid, bias = False)
            self.kv = nn.Linear(d_emb, self.d_hid * 2, bias = False)
        else:
            self.qkv = nn.Linear(d_emb, self.d_hid * 3, bias = False)

        self.unifyheads = nn.Linear(self.d_hid, d_emb)

    def self_attention(self, q, k, v, mask):
        """Attend ``q`` over ``k``/``v`` (each ``(..., heads, len, dim_per_head)``).

        ``mask`` must broadcast to the score shape ``(..., heads, len_q, len_k)``;
        key positions where it is 0/False receive no attention weight.
        """
        scores = torch.einsum('...ij,...kj->...ik', q, k) / np.sqrt(self.dim_per_head)
        # Use the most negative finite value rather than -inf: a row whose keys are
        # all masked (e.g. an all-padding sequence) then gets uniform weights instead
        # of NaN, which would otherwise propagate into any loss summed over the batch.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        weights = F.softmax(scores, dim = -1)
        return torch.einsum('...ij,...jk->...ik', weights, v)

    def forward(self, x, mask, y = None):
        """Return attended features with the same leading shape as the query source.

        Args:
            x: source of keys/values (and of queries when ``decode`` is False).
            mask: attention mask broadcastable to ``(..., heads, len_q, len_k)``.
            y: source of queries; only used (and required) when ``decode`` is True.
        """
        if self.decode:
            q = self.q(y)
            k, v = self.kv(x).split(self.d_hid, dim=-1)
        else:
            q, k, v = self.qkv(x).split(self.d_hid, dim=-1)

        # (..., len, heads * dim_per_head) -> (..., heads, len, dim_per_head)
        q, k, v = (rearrange(t, '... i (h j) -> ... h i j', h = self.heads) for t in (q, k, v))

        attended = self.self_attention(q, k, v, mask)
        # merge heads back: (..., heads, len, dim_per_head) -> (..., len, d_hid)
        attended = rearrange(attended, '... h i j -> ... i (h j)').contiguous()

        return self.unifyheads(attended)

### Gated Linear Unit

In [12]:
class GLU(nn.Module):
    """Gated Linear Unit.

    Projects the input to ``2 * out_size`` features and returns the first half
    gated by the sigmoid of the second half.

    Args:
        in_size: number of input features (last dim of ``x``).
        out_size: number of output features.
    """

    def __init__(self, in_size, out_size):
        super().__init__()
        self.out_size = out_size
        self.linear = nn.Linear(in_size, out_size * 2)

    def forward(self, x):
        """Return ``a * sigmoid(b)`` where ``[a, b] = linear(x)`` split on the last dim."""
        # F.glu splits its input in half along `dim`: first half * sigmoid(second half)
        return F.glu(self.linear(x), dim=-1)

### Encoder

In [25]:
class Encoder_layer(nn.Module):
    """One post-norm Transformer encoder layer.

    The layer runs self-attention and then a position-wise feed-forward net.
    Each sub-layer is wrapped as ``norm(x + dropout(sublayer(x)))``.

    Args:
        d_emb: model (embedding) width.
        d_hid: total attention projection width, split over ``heads``.
        hidden_mult: the feed-forward hidden width is ``hidden_mult * d_emb``.
        heads: number of attention heads.
        enc_drop: dropout probability applied to each sub-layer's output.
    """

    def __init__(self, d_emb, d_hid, hidden_mult, heads, enc_drop):
        super().__init__()
        self.dropout = nn.Dropout(enc_drop)

        self.mha = Multi_Head_Attention(d_emb, d_hid, heads)
        self.norm_1 = nn.LayerNorm(d_emb)
        self.ff = nn.Sequential(
            nn.Linear(d_emb, hidden_mult * d_emb),
            nn.LeakyReLU(),
            nn.Linear(hidden_mult * d_emb, d_emb),
        )
        self.norm_2 = nn.LayerNorm(d_emb)

    def forward(self, x, q_mask):
        """Apply self-attention (masked by ``q_mask``) and then the feed-forward block."""
        # Dropout goes on the sub-layer output only; dropping the residual sum would
        # also randomly zero the identity path the residual connection exists to keep.
        x = self.norm_1(x + self.dropout(self.mha(x, q_mask)))
        x = self.norm_2(x + self.dropout(self.ff(x)))
        return x

### Decoder

In [26]:
class Decoder_layer(nn.Module):
    """One post-norm Transformer decoder layer.

    The layer runs self-attention over the decoder sequence, then cross-attention
    from the decoder queries to the encoder output, then a position-wise
    feed-forward net. Each sub-layer is wrapped as ``norm(y + dropout(sublayer))``.

    Args:
        d_emb: model (embedding) width.
        d_hid: total attention projection width, split over ``heads``.
        hidden_mult: the feed-forward hidden width is ``hidden_mult * d_emb``.
        heads: number of attention heads.
        dec_drop: dropout probability applied to each sub-layer's output.
    """

    def __init__(self, d_emb, d_hid, hidden_mult, heads, dec_drop):
        super().__init__()
        self.dropout = nn.Dropout(dec_drop)

        self.mha_1 = Multi_Head_Attention(d_emb, d_hid, heads)
        self.norm_1 = nn.LayerNorm(d_emb)
        self.mha_2 = Multi_Head_Attention(d_emb, d_hid, heads, decode = True)
        self.norm_2 = nn.LayerNorm(d_emb)
        self.ff = nn.Sequential(
            nn.Linear(d_emb, hidden_mult * d_emb),
            nn.LeakyReLU(),
            nn.Linear(hidden_mult * d_emb, d_emb),
        )
        self.norm_3 = nn.LayerNorm(d_emb)

    def forward(self, x, q_mask, y, r_mask):
        """Update decoder states ``y`` given encoder output ``x``.

        ``r_mask`` masks the decoder self-attention (presumably the causal + padding
        mask from ``create_masks``). ``q_mask`` masks the encoder keys in
        cross-attention.
        """
        # Dropout goes on each sub-layer output, not on the residual sum, so the
        # identity path is never dropped.
        y = self.norm_1(y + self.dropout(self.mha_1(y, r_mask)))
        # cross-attention: queries from y, keys/values from the encoder output x
        y = self.norm_2(y + self.dropout(self.mha_2(x, q_mask, y)))
        y = self.norm_3(y + self.dropout(self.ff(y)))
        return y

## Transformer

In [27]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer built from ``Embedder``, ``Encoder_layer`` and ``Decoder_layer``.

    ``model_hp`` is unpacked as
    ``(vocab_size_in, vocab_size_out, d_emb, hidden_mult, seq_length, order, dropouts, parallel)``:

    - ``order``: sized iterable of ``(d_hid, heads)`` pairs; one encoder layer and
      one decoder layer are built per pair.
    - ``dropouts``: ``(emb_drop, enc_drop, dec_drop)``.
    - ``parallel``: if truthy, encoder and decoder layers are interleaved, so decoder
      layer i cross-attends to the output of encoder layer i. Otherwise every
      decoder layer cross-attends to the final encoder output.

    ``forward`` returns log-probabilities over ``vocab_size_out`` along the last dim.
    """

    def __init__(self, model_hp):
        super().__init__()
        # training bookkeeping; not read by the forward pass
        self.epochs = 0
        self.losses = []

        (vocab_size_in, vocab_size_out, d_emb, hidden_mult,
         seq_length, order, dropouts, parallel) = model_hp
        emb_drop, enc_drop, dec_drop = dropouts
        self.flag_parallel = parallel
        self.deep = len(order)

        # NOTE(review): equal vocabulary *sizes* are taken to mean a shared vocabulary,
        # in which case one Embedder serves both sides — confirm this holds for the data.
        self.flag_voc_same = vocab_size_in == vocab_size_out
        if not self.flag_voc_same:
            self.embedder_enc = Embedder(vocab_size_in, d_emb, seq_length, emb_drop)
            self.embedder_dec = Embedder(vocab_size_out, d_emb, seq_length, emb_drop)
        else:
            self.embedder = Embedder(vocab_size_in, d_emb, seq_length, emb_drop)

        # layers are constructed pairwise (encoder i, then decoder i)
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        for d_hid, heads in order:
            self.encoder.append(Encoder_layer(d_emb, d_hid, hidden_mult, heads, enc_drop))
            self.decoder.append(Decoder_layer(d_emb, d_hid, hidden_mult, heads, dec_drop))

        # output head: d_emb -> 2 * d_emb -> vocab_size_out logits
        self.out = nn.Sequential(
            nn.Linear(d_emb, d_emb * 2),
            nn.LeakyReLU(),
            nn.Linear(d_emb * 2, vocab_size_out),
        )

    def _embed_question(self, x):
        """Embed encoder-side ids with the shared or the encoder-specific embedder."""
        embedder = self.embedder if self.flag_voc_same else self.embedder_enc
        return embedder(x)

    def _embed_reply(self, y):
        """Embed decoder-side ids with the shared or the decoder-specific embedder."""
        embedder = self.embedder if self.flag_voc_same else self.embedder_dec
        return embedder(y)

    def encode(self, x, q_mask):
        """Embed ``x`` and pass it through every encoder layer in turn."""
        hidden = self._embed_question(x)
        for layer in self.encoder:
            hidden = layer(hidden, q_mask)
        return hidden

    def decode(self, x, q_mask, y, r_mask):
        """Embed ``y`` and pass it through every decoder layer, each attending to ``x``."""
        hidden = self._embed_reply(y)
        for layer in self.decoder:
            hidden = layer(x, q_mask, hidden, r_mask)
        return hidden

    def encode_decode(self, x, q_mask, y, r_mask):
        """Interleaved pass: decoder layer i cross-attends to the output of encoder layer i."""
        src = self._embed_question(x)
        tgt = self._embed_reply(y)
        for enc_layer, dec_layer in zip(self.encoder, self.decoder):
            src = enc_layer(src, q_mask)
            tgt = dec_layer(src, q_mask, tgt, r_mask)
        return tgt

    def forward(self, x, q_mask, y, r_mask):
        """Return log-softmax scores over the output vocabulary for each decoder position."""
        if not self.flag_parallel:
            decoded = self.decode(self.encode(x, q_mask), q_mask, y, r_mask)
        else:
            decoded = self.encode_decode(x, q_mask, y, r_mask)
        return F.log_softmax(self.out(decoded), dim = -1)

## Optimizer

In [8]:
class AdamWarmup:
    """Drive an optimizer with the warmup / inverse-square-root schedule of
    "Attention Is All You Need":

        lr(step) = d_emb ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

    Before each ``optimizer.step()``, the learning rate of every param group is
    overwritten with the scheduled value.

    Args:
        d_emb: model width used to scale the schedule.
        warmup_steps: number of steps over which the lr grows linearly.
        optimizer: optimizer exposing ``param_groups`` and ``step()``.
    """

    def __init__(self, d_emb, warmup_steps, optimizer):
        self.d_emb = d_emb
        self.warmup_steps = warmup_steps
        self.optimizer = optimizer
        self.current_step = 0
        self.lr = 0   # lr most recently applied; 0 until the first step

    def get_lr(self):
        """Return the scheduled lr for ``self.current_step`` (0.0 before the first step)."""
        if self.current_step == 0:
            # step ** -0.5 is undefined at 0 (ZeroDivisionError); the schedule's value there is 0
            return 0.0
        return self.d_emb ** (-0.5) * min(self.current_step ** (-0.5),
                                          self.current_step * self.warmup_steps ** (-1.5))

    def step(self):
        """Advance the schedule by one step, apply the new lr, then step the optimizer."""
        self.current_step += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        self.lr = lr
        self.optimizer.step()

## Loss function

In [9]:
class LossWithLS(nn.Module):
    """Label-smoothed KL-divergence loss, averaged over unmasked (non-padding) tokens.

    The smoothed target distribution puts ``1 - smooth`` on the true token and
    spreads ``smooth`` uniformly over the remaining ``vocab_size_out - 1`` tokens.

    Args:
        vocab_size_out: output vocabulary size (last dim of ``prediction``).
        smooth: probability mass moved off the true token.
    """

    def __init__(self, vocab_size_out, smooth):
        super().__init__()
        self.criterion = nn.KLDivLoss(reduction = 'none')
        self.confidence = 1.0 - smooth
        self.smooth = smooth
        self.size = vocab_size_out

    def forward(self, prediction, target, mask):
        """
        prediction: log-probabilities of shape (batch_size, max_words, vocab_size)
        target and mask: shape (batch_size, max_words); mask is truthy for tokens to score.

        Any leading dims are accepted, provided ``prediction`` has one extra trailing
        vocab dim. Returns a scalar: the per-token KL summed over the vocab and
        averaged over unmasked tokens (0 when no token is unmasked).
        """
        prediction = prediction.reshape(-1, prediction.size(-1))   # (N, vocab_size)
        target = target.reshape(-1)                                 # (N,)
        mask = mask.reshape(-1).to(prediction.dtype)               # (N,)
        # build the smoothed labels directly on the prediction's device / dtype
        labels = torch.full_like(prediction, self.smooth / (self.size - 1))
        labels.scatter_(1, target.unsqueeze(1), self.confidence)
        loss = self.criterion(prediction, labels)                   # (N, vocab_size)
        # clamp avoids 0 / 0 = NaN for a batch in which every token is masked out
        return (loss.sum(1) * mask).sum() / mask.sum().clamp(min=1)