In [1]:
import torch
from torch import nn
import numpy as np
import math


In [2]:
# Arch units
## Self Attention Unit
## Multi Head Attention
## Encoder / Decoder Unit
## Norm + Residual Layer
## Feed Forward
## Input Positional Encoding 


In [3]:
# Hyper-parameters, following the paper's naming.
# Where a trailing comment gives a number, that is the paper's value and the
# assigned value is the reduced one used in this notebook.
vocab_size = 100
word_emb_dim = 50
N = 6
h = 4  # 8
d_model = 32  # 512
d_ff = 128  # 2048
# Per-head key and value widths are the same: the model width split across heads.
d_v = d_model // h
d_k = d_v


In [4]:
class Attention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``."""

    def __init__(self):
        super(Attention, self).__init__()

    def forward(self, Q, K, V, mask=None):
        """Attend from queries ``Q`` to keys ``K`` and mix the values ``V``.

        Args:
            Q: queries, shape ``(..., seq_q, d_k)``.
            K: keys, shape ``(..., seq_k, d_k)``.
            V: values, shape ``(..., seq_k, d_k)``.
            mask: optional tensor broadcastable to ``(..., seq_q, seq_k)``.
                Positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape ``(..., seq_q, d_k)``.

        Raises:
            AssertionError: if the last dimensions of Q, K and V differ.
        """
        assert Q.shape[-1] == K.shape[-1] == V.shape[-1], \
            "Q, K and V must share the same feature width"
        # Scale by the key width actually received rather than a module-level
        # constant, so the layer is correct for any head size.
        scale = math.sqrt(K.shape[-1])
        # matmul batches over any number of leading dimensions.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale  # (..., seq_q, seq_k)
        # `mask` is a tensor, so its truth value is ambiguous; test against None.
        if mask is not None:
            # A large negative score drives the softmax weight of masked keys to 0.
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = nn.functional.softmax(scores, dim=-1)  # (..., seq_q, seq_k)
        return torch.matmul(weights, V)  # (..., seq_q, d_k)


In [5]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention with ``h`` independently parameterised heads.

    The query/key/value projections for all heads are stored as one fused
    linear layer each (output width ``h * d_k`` or ``h * d_v``). Each head reads
    its own slice of that output, so the heads have distinct weights.
    """

    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, h * d_k)
        self.W_K = nn.Linear(d_model, h * d_k)
        self.W_V = nn.Linear(d_model, h * d_v)
        self.attn = Attention()
        self.W_O = nn.Linear(h * d_v, d_model)

    def forward(self, Q, K, V, mask=None):
        """Run all heads and merge them back to the model width.

        Args:
            Q, K, V: tensors whose last dimension is ``d_model``
                (``(b, seq, d_model)`` in this notebook).
            mask: optional mask handed unchanged to every head's attention call.

        Returns:
            Tensor of size ``(b, seq, d_model)``.
        """
        # Project once, then cut the feature axis into one chunk per head.
        q_heads = self.W_Q(Q).split(d_k, dim=-1)
        k_heads = self.W_K(K).split(d_k, dim=-1)
        v_heads = self.W_V(V).split(d_v, dim=-1)
        # Forward the caller's mask so masking is not silently dropped.
        heads = [
            self.attn(q, k, v, mask=mask)
            for q, k, v in zip(q_heads, k_heads, v_heads)
        ]
        concat = torch.cat(heads, dim=-1)  # size (b, seq, h*d_v)
        return self.W_O(concat)  # size (b, seq, d_model)


In [6]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable gain and bias.

    Computes ``alpha * (x - mean) / (std + eps) + beta`` where the mean and
    standard deviation are taken along the last axis.
    """

    def __init__(self, d_mod=d_model):
        """
        Args:
            d_mod: size of the last (feature) dimension being normalised.
        """
        super(LayerNorm, self).__init__()
        self.d_mod = d_mod
        # https://stackoverflow.com/questions/39095252/fail-to-implement-layer-normalization-with-keras
        # https://stackoverflow.com/questions/50935345/understanding-torch-nn-parameter
        self.alpha = nn.Parameter(torch.ones(d_mod))   # gain, starts at 1
        self.beta = nn.Parameter(torch.zeros(d_mod))   # bias, starts at 0

    def forward(self, x, eps=1e-6):
        """Normalise ``x`` along its last dimension.

        Args:
            x: tensor whose last dimension has size ``d_mod``.
            eps: small constant added to the standard deviation to avoid
                division by zero.

        Returns:
            Tensor with the same shape as ``x``.
        """
        u = x.mean(-1, keepdim=True)
        sigma = x.std(-1, keepdim=True)
        # Divide by the standard deviation; multiplying by it would amplify
        # the spread instead of normalising it.
        return self.alpha * (x - u) / (sigma + eps) + self.beta
    

In [7]:
class EncoderCell(nn.Module):
    """One encoder block: self-attention, then a position-wise feed-forward net.

    Each of the two sub-layers is wrapped as ``norm(input + sublayer(input))``.
    """

    def __init__(self):
        super().__init__()
        # Sub-layer 1: multi-head self-attention and its normalisation.
        self.attn = MultiHeadAttention()
        self.norm_1 = LayerNorm()
        # Sub-layer 2: d_model -> d_ff -> d_model with a ReLU in between.
        widen = nn.Linear(d_model, d_ff)
        narrow = nn.Linear(d_ff, d_model)
        self.pff = nn.Sequential(widen, nn.ReLU(), narrow)
        self.norm_2 = LayerNorm()

    def forward(self, x):
        """Apply both sub-layers to ``x`` and return the result."""
        # Layer 1: queries, keys and values all come from x.
        attended = self.attn(x, x, x)
        hidden = self.norm_1(x + attended)
        # Layer 2: feed-forward with its own residual connection.
        transformed = self.pff(hidden)
        return self.norm_2(hidden + transformed)
    

In [8]:
class DecoderCell(nn.Module):
    """One decoder block with three residual + norm sub-layers.

    1. Self-attention over the decoder input (``self.attn``).
    2. Attention from the decoder state to the encoder output (``self.enc_attn``).
    3. Position-wise feed-forward network (``self.pff``).

    The two attention sub-layers are separate modules with their own weights.
    """

    def __init__(self):
        super(DecoderCell, self).__init__()
        self.attn = MultiHeadAttention()      # decoder self-attention
        self.norm_1 = LayerNorm()
        # Must not reuse the name `attn`: a second assignment would replace the
        # self-attention module and make both sub-layers share one set of weights.
        self.enc_attn = MultiHeadAttention()  # encoder-decoder attention
        self.norm_2 = LayerNorm()
        self.pff = nn.Sequential(
                    nn.Linear(d_model, d_ff),
                    nn.ReLU(),
                    nn.Linear(d_ff, d_model))
        self.norm_3 = LayerNorm()

    def forward(self, x, enc, src_mask=None, trg_mask=None):
        """Run the decoder block.

        Args:
            x: decoder input.
            enc: encoder output, used as keys and values of the second sub-layer.
            src_mask: optional mask passed to the encoder-decoder attention.
            trg_mask: optional mask passed to the decoder self-attention.

        Returns:
            Tensor of size ``(b, seq, d_model)``.
        """
        x_norm_1 = self.norm_1(x + self.attn(x, x, x, trg_mask))
        x_norm_2 = self.norm_2(x_norm_1 + self.enc_attn(x_norm_1, enc, enc, src_mask))
        return self.norm_3(x_norm_2 + self.pff(x_norm_2)) # (b, seq, d_model)


In [None]:
class DecoderLayer(nn.Module):
    """Placeholder module: it registers no sub-modules and defines no forward yet.

    TODO: implement. The intended contents are not visible in this notebook.
    """

    def __init__(self):
        # nn.Module keeps its parameter/sub-module registries on the instance;
        # without this call even `parameters()` or `to()` raises AttributeError.
        super(DecoderLayer, self).__init__()
    

In [None]:
class EncoderLayer(nn.Module):
    """Placeholder module: it registers no sub-modules and defines no forward yet.

    TODO: implement. The intended contents are not visible in this notebook.
    """

    def __init__(self):
        # nn.Module keeps its parameter/sub-module registries on the instance;
        # without this call even `parameters()` or `to()` raises AttributeError.
        super(EncoderLayer, self).__init__()

In [9]:
class Transformer(nn.Module):
    """Encoder/decoder model mapping embedded sequences to vocabulary logits.

    Inputs are projected from ``word_emb_dim`` to ``d_model``, passed through
    one encoder unit and one decoder unit, then projected to ``vocab_size``.
    """

    def __init__(self):
        super(Transformer, self).__init__()
        self.W_in = nn.Linear(word_emb_dim, d_model)
        # based on N values
        # TODO: only a single encoder unit and a single decoder unit are built
        # here; nothing in this class stacks several of them yet.
        self.encoder_unit = EncoderCell()
        self.decoder_unit = DecoderCell()
        # https://stats.stackexchange.com/questions/392213/understand-the-output-layer-of-transformer
        self.W_out = nn.Linear(d_model, vocab_size)

    def forward(self, inp_x, inp_y, src_mask=None, trg_mask=None):
        """Compute output logits for a source/target pair.

        Args:
            inp_x: source sequence, ``(b, seq, word_embedding)``.
            inp_y: target sequence, ``(b, seq, word_embedding)``.
            src_mask: optional source mask, forwarded to the decoder unit.
                The encoder unit is called with the input alone.
            trg_mask: optional target mask, forwarded to the decoder unit.
                Omitting it runs the decoder unmasked.

        Returns:
            The output of ``W_out`` applied to the decoder output; its last
            dimension is ``vocab_size``.
        """
        scale = math.sqrt(d_model)
        # NOTE(review): the inputs are *divided* by sqrt(d_model) here; confirm
        # this is the intended direction of the scaling.
        inp_x, inp_y = inp_x / scale, inp_y / scale
        inp_x, inp_y = self.W_in(inp_x), self.W_in(inp_y) # (b, seq, word_embedding) -> (b, seq, d_model)
        enc_x = self.encoder_unit(inp_x)
        dec_x = self.decoder_unit(inp_y, enc_x, src_mask=src_mask, trg_mask=trg_mask)
        return self.W_out(dec_x)


In [242]:
class EmbeddingLayer(nn.Module):
    """Token-id lookup table whose output is multiplied by ``sqrt(d_model)``."""

    def __init__(self):
        super().__init__()
        # One row of width word_emb_dim for each of the vocab_size ids.
        self.embedding = nn.Embedding(num_embeddings=vocab_size,
                                      embedding_dim=word_emb_dim)

    def forward(self, x):
        """Look up the ids in ``x`` and scale the resulting vectors."""
        vectors = self.embedding(x)
        factor = math.sqrt(d_model)
        return vectors * factor

In [None]:
# https://www.reddit.com/r/MachineLearning/comments/bjgpt2/d_confused_about_using_masking_in_transformer/

In [250]:
# https://nlp.seas.harvard.edu/2018/04/03/attention.html
class Batch:
    """Holds one batch of source/target token ids together with their masks.

    Attributes:
        src: source ids, size (b, seq).
        src_mask: True where ``src`` is not padding, size (b, 1, seq).
        trg: target ids without the last position (only set when ``trg`` is given).
        trg_y: target ids without the first position (only set when ``trg`` is given).
        trg_mask: padding mask combined with the causal mask, size (b, seq-1, seq-1).
        ntokens: number of non-padding entries in ``trg_y`` (0-dim tensor).
    """

    def __init__(self, src, trg=None, pad=0): # size src, trg (b, seq)
        self.src = src
        self.src_mask = (src != pad).unsqueeze(-2)
        if trg is not None:
            self.trg = trg[:, :-1] # size (b,0:seq-1)
            self.trg_y = trg[:, 1:] # size (b,1:seq)
            self.trg_mask = self.std_mask(self.trg, pad)
            self.ntokens = (self.trg_y != pad).sum() # 0-dim tensor

    @staticmethod
    def std_mask(tgt, pad):
        """Return a mask hiding both padding and later positions.

        Args:
            tgt: target ids, size (b, seq).
            pad: id that marks padding.

        Returns:
            Boolean tensor of size (b, seq, seq); True means "may attend".
        """
        tgt_mask = (tgt != pad).unsqueeze(-2) # size (b, 1, seq)
        # subsequent_mask is a static method of this class, so it has to be
        # reached through the class; a bare name lookup fails.
        causal = Batch.subsequent_mask(tgt.shape[-1]).to(tgt.device)
        return tgt_mask & causal # (b, 1, seq) & (1, seq, seq) -> (b, seq, seq)

    @staticmethod
    def subsequent_mask(size):
        """Return a (1, size, size) boolean mask that is True on and below the diagonal."""
        above_diagonal = torch.triu(torch.ones(1, size, size), diagonal=1)
        return above_diagonal == 0 # size (1, seq, seq)

In [263]:
# Scratch cell: inspect the shape of `attn` (defined outside the cells shown here).
attn.shape

torch.Size([3, 5, 5])

In [267]:
# Scratch cell: apply a batch's source mask to `attn` (`attn` and `b` are
# defined outside the cells shown here).
# NOTE(review): 1e-9 is a tiny *positive* value, so a softmax over the result
# still gives the masked column weight; suppressing it needs a large negative
# fill such as -1e9.
attn.masked_fill(b.src_mask==0, 1e-9)

tensor([[[2.2052e-01, 8.3386e-01, 2.0095e-01, 1.0000e-09, 3.4726e-02],
         [4.5698e-01, 9.6899e-01, 5.4661e-01, 1.0000e-09, 5.8170e-01],
         [7.3435e-01, 4.7575e-01, 6.9755e-01, 1.0000e-09, 6.0524e-01],
         [8.3406e-01, 4.3392e-01, 7.2352e-01, 1.0000e-09, 4.2996e-02],
         [3.6918e-01, 3.2248e-01, 6.3152e-01, 1.0000e-09, 2.7349e-01]],

        [[8.4413e-01, 7.4469e-01, 8.5517e-01, 2.1864e-01, 3.6536e-01],
         [7.1498e-02, 3.2991e-01, 6.4716e-01, 2.0532e-01, 9.3745e-01],
         [9.9492e-01, 6.3234e-01, 5.1108e-01, 6.3176e-01, 4.2131e-01],
         [4.9687e-01, 2.7079e-01, 7.6468e-01, 4.8283e-01, 2.2096e-01],
         [5.8373e-01, 5.5180e-01, 5.6868e-03, 6.8920e-01, 5.6148e-02]],

        [[3.7395e-02, 9.0728e-01, 1.7958e-01, 2.0664e-01, 2.6232e-01],
         [4.3243e-01, 5.1016e-01, 4.2045e-01, 2.6085e-01, 2.8659e-02],
         [3.8285e-01, 3.2779e-01, 7.3841e-01, 5.9082e-01, 2.3451e-01],
         [9.3807e-01, 6.3969e-01, 9.0029e-01, 9.3782e-01, 6.3724e-01],
  

RuntimeError: bool value of Tensor with more than one value is ambiguous