In [1]:
import torch
from torch import nn
import numpy as np
import math


In [2]:
# Architecture units
## Scaled dot-product (self-)attention
## Multi-head attention
## Encoder / decoder layers
## Residual connection + layer norm
## Position-wise feed-forward network
## Input positional encoding


In [3]:
# Hyperparameters from "Attention Is All You Need". Several are scaled down for
# quick experiments; the paper's value is given in the trailing comment.
word_emb_dim = 50 # width of the incoming word-embedding vectors (input to W_in)
N = 6 # number of encoder/decoder layers (paper: 6)
d_model = 32 # 512 -- model width used by the attention, norm and FFN layers
h = 4 # 8 -- number of attention heads
d_k = d_v = d_model//h # per-head key / value width
d_ff = 128 # 2048 -- hidden width of the position-wise feed-forward network
vocab_size = 100 # width of the output projection (W_out)


In [4]:
class Attention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""

    def __init__(self):
        super(Attention, self).__init__()

    def forward(self, Q, K, V, mask=None):
        """Attend from queries Q over keys K, mixing values V.

        Args:
            Q: (..., seq_q, d_k) queries.
            K: (..., seq_k, d_k) keys.
            V: (..., seq_k, d_v) values; d_v may differ from d_k.
            mask: optional tensor broadcastable to (..., seq_q, seq_k). Non-zero / True
                marks positions that may be attended to; zero / False positions are
                excluded. This matches the masks built with ``subsequent_mask`` and
                ``(tokens != pad)`` in this notebook.

        Returns:
            (..., seq_q, d_v) tensor.
        """
        # Only Q and K must agree on the last dim; V's width (d_v) is independent.
        assert Q.shape[-1] == K.shape[-1], "Q and K must share the last (d_k) dimension"
        # Scale by the actual key width rather than a global, so the module is reusable.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.shape[-1])
        if mask is not None:  # `if mask:` would raise on a multi-element tensor
            # masked_fill is out-of-place, so the result must be reassigned. A very
            # negative fill (not 0) drives excluded weights to ~0 after softmax, and
            # finfo.min (unlike -inf) avoids NaNs when an entire row is masked.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        weights = nn.functional.softmax(scores, dim=-1)  # (..., seq_q, seq_k)
        return torch.matmul(weights, V)  # (..., seq_q, d_v)
    

In [5]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention: h scaled dot-product heads, concatenated and projected.

    Each of W_Q / W_K / W_V is a single Linear from d_model to h * d_k (h * d_v).
    Splitting its output into h chunks along the last dim is equivalent to h
    independent per-head projections, so every head gets its own weights.
    """

    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, h * d_k)
        self.W_K = nn.Linear(d_model, h * d_k)
        self.W_V = nn.Linear(d_model, h * d_v)
        self.attn = Attention()
        self.W_O = nn.Linear(h * d_v, d_model)

    def forward(self, Q, K, V, mask=None):
        """Run all heads and merge them.

        Args:
            Q, K, V: presumably (b, seq, d_model) tensors -- TODO confirm at call sites.
            mask: optional attention mask, forwarded unchanged to every head.

        Returns:
            (b, seq_q, d_model) tensor.
        """
        # One projection per input, then split into h per-head slices of width d_k / d_v.
        q_heads = self.W_Q(Q).split(d_k, dim=-1)
        k_heads = self.W_K(K).split(d_k, dim=-1)
        v_heads = self.W_V(V).split(d_v, dim=-1)
        heads = [self.attn(q, k, v, mask=mask) for q, k, v in zip(q_heads, k_heads, v_heads)]
        # Concatenate once (size (b, seq, h*d_v)) instead of growing a tensor in a loop.
        return self.W_O(torch.cat(heads, dim=-1))  # size (b, seq, d_model)
    

In [6]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable gain and bias.

    y = alpha * (x - mean) / (std + eps) + beta
    """

    def __init__(self, d_mod=None):
        """
        Args:
            d_mod: size of the normalised (last) dimension. Defaults to the global
                ``d_model``, looked up when the layer is built. A signature default
                would be frozen at class-definition time and go stale if the config
                cell is re-run.
        """
        super(LayerNorm, self).__init__()
        if d_mod is None:
            d_mod = d_model
        self.d_mod = d_mod
        # https://stackoverflow.com/questions/39095252/fail-to-implement-layer-normalization-with-keras
        # https://stackoverflow.com/questions/50935345/understanding-torch-nn-parameter
        self.alpha = nn.Parameter(torch.ones(d_mod))   # gain, starts as identity
        self.beta = nn.Parameter(torch.zeros(d_mod))   # bias, starts at zero

    def forward(self, x, eps=1e-6):
        """Normalise x along its last dimension; eps guards against a zero std."""
        u = x.mean(-1, keepdim=True)
        sigma = x.std(-1, keepdim=True)
        # Divide by the std (the previous form multiplied by it).
        return self.alpha * (x - u) / (sigma + eps) + self.beta
    

In [7]:
class EncoderCell(nn.Module):
    """One encoder layer: self-attention, then a position-wise feed-forward network.

    Each sub-layer is wrapped as ``LayerNorm(x + sublayer(x))`` (residual + norm).
    """

    def __init__(self):
        super(EncoderCell, self).__init__()
        self.attn = MultiHeadAttention()
        self.norm_1 = LayerNorm()
        self.pff = nn.Sequential(
                    nn.Linear(d_model, d_ff),
                    nn.ReLU(),
                    nn.Linear(d_ff, d_model))
        self.norm_2 = LayerNorm()

    def forward(self, x, mask=None):
        """
        Args:
            x: encoder input, presumably (b, seq, d_model) -- TODO confirm.
            mask: optional source (e.g. padding) mask forwarded to self-attention,
                mirroring ``DecoderCell``'s ``src_mask``. None means no masking.

        Returns:
            Tensor with the same shape as x.
        """
        x_norm_1 = self.norm_1(x + self.attn(x, x, x, mask))  # Layer 1
        return self.norm_2(x_norm_1 + self.pff(x_norm_1))  # Layer 2
    

In [8]:
class DecoderCell(nn.Module):
    """One decoder layer: masked self-attention, encoder-decoder attention, then an FFN.

    Each sub-layer is wrapped as ``LayerNorm(x + sublayer(...))``.
    """

    def __init__(self):
        super(DecoderCell, self).__init__()
        # Two distinct attention modules. Assigning one attribute twice would replace
        # the first module and make both sub-layers share a single set of weights.
        self.self_attn = MultiHeadAttention()
        self.norm_1 = LayerNorm()
        self.src_attn = MultiHeadAttention()
        self.norm_2 = LayerNorm()
        self.pff = nn.Sequential(
                    nn.Linear(d_model, d_ff),
                    nn.ReLU(),
                    nn.Linear(d_ff, d_model))
        self.norm_3 = LayerNorm()

    def forward(self, x, enc, src_mask=None, trg_mask=None):
        """
        Args:
            x: decoder input, presumably (b, seq_trg, d_model) -- TODO confirm.
            enc: encoder output used as keys/values in the second sub-layer.
            src_mask: optional mask for encoder-decoder attention.
            trg_mask: optional mask for decoder self-attention (e.g. look-ahead).

        Returns:
            Tensor with the same shape as x.
        """
        x_norm_1 = self.norm_1(x + self.self_attn(x, x, x, trg_mask))
        x_norm_2 = self.norm_2(x_norm_1 + self.src_attn(x_norm_1, enc, enc, src_mask))
        return self.norm_3(x_norm_2 + self.pff(x_norm_2))  # (b, seq, d_model)
    

In [9]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer over pre-computed word embeddings.

    Inputs are embedding tensors of width ``word_emb_dim``. They are projected to
    ``d_model``, scaled, given sinusoidal positions, and run through ``N`` encoder
    and ``N`` decoder layers. The result is projected to ``vocab_size``.
    """

    def __init__(self):
        super(Transformer, self).__init__()
        # Shared projection for source and target embeddings into the model width.
        self.W_in = nn.Linear(word_emb_dim, d_model)
        # N stacked layers, as in the paper.
        self.encoder_units = nn.ModuleList([EncoderCell() for _ in range(N)])
        self.decoder_units = nn.ModuleList([DecoderCell() for _ in range(N)])
        # https://stats.stackexchange.com/questions/392213/understand-the-output-layer-of-transformer
        self.W_out = nn.Linear(d_model, vocab_size)

    @staticmethod
    def _positional_encoding(seq_len, dim, device=None, dtype=torch.float32):
        """Sinusoidal position table of shape (seq_len, dim) (paper, Sec. 3.5)."""
        position = torch.arange(seq_len, device=device, dtype=dtype).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, dim, 2, device=device, dtype=dtype) * (-math.log(10000.0) / dim)
        )
        angles = position * div_term  # (seq_len, ceil(dim / 2))
        pe = torch.zeros(seq_len, dim, device=device, dtype=dtype)
        pe[:, 0::2] = torch.sin(angles)
        pe[:, 1::2] = torch.cos(angles[:, : dim // 2])  # odd dim: one fewer cos column
        return pe

    def _embed(self, x):
        """Project (b, seq, word_emb_dim) inputs to d_model, scale, and add positions."""
        # The paper multiplies embeddings by sqrt(d_model) (Sec. 3.4). Dividing the
        # raw inputs before the projection scaled them in the wrong direction.
        x = self.W_in(x) * math.sqrt(d_model)
        return x + self._positional_encoding(x.size(-2), x.size(-1), x.device, x.dtype)

    def forward(self, inp_x, inp_y):
        """
        Args:
            inp_x: source embeddings (b, seq_src, word_emb_dim).
            inp_y: target embeddings (b, seq_trg, word_emb_dim).

        Returns:
            (b, seq_trg, vocab_size) tensor.
        """
        enc_x = self._embed(inp_x)
        for cell in self.encoder_units:
            enc_x = cell(enc_x)
        dec_x = self._embed(inp_y)
        for cell in self.decoder_units:
            dec_x = cell(dec_x, enc_x)
        return self.W_out(dec_x)
    

In [10]:
# Fixed seed so the random inputs and the model's initial weights are reproducible.
torch.manual_seed(0)
a = torch.rand(5,4,word_emb_dim)  # source batch: (batch=5, src_seq=4, word_emb_dim)
b = torch.rand(5,3,word_emb_dim)  # target batch: (batch=5, tgt_seq=3, word_emb_dim)
t = Transformer()

In [11]:
# Forward pass on the random source/target batches; output is (batch, tgt_seq, vocab_size).
c = t(a,b)

In [12]:
# One vocab_size-wide vector per target position; fail loudly if the shape drifts.
assert c.shape == (5, 3, vocab_size)
c.shape

torch.Size([5, 3, 100])

In [None]:
# Reference: discussion of how masking is used in the Transformer
# https://www.reddit.com/r/MachineLearning/comments/bjgpt2/d_confused_about_using_masking_in_transformer/

In [25]:
# Scratch: how inserting a singleton axis at position 1 reshapes a (4, 3) tensor.
# NOTE: rebinds `c`, which previously held the model output.
c = torch.rand((4, 3))
print(c)
c[:, None]

tensor([[0.8239, 0.6005, 0.5948],
        [0.5414, 0.9904, 0.6901],
        [0.7687, 0.7560, 0.6091],
        [0.8969, 0.8435, 0.7998]])


tensor([[[0.8239, 0.6005, 0.5948]],

        [[0.5414, 0.9904, 0.6901]],

        [[0.7687, 0.7560, 0.6091]],

        [[0.8969, 0.8435, 0.7998]]])

In [23]:
# NOTE(review): this cell ran as In [23], before the In [25] cell above, so the
# values shown come from an earlier `c`, not the tensor printed there.
c.view(-1,3)

tensor([[0.2774, 0.3314, 0.7153],
        [0.4680, 0.3764, 0.4510],
        [0.3530, 0.0466, 0.2651],
        [0.4527, 0.1481, 0.5217]])

In [94]:
# Two random integer "token" batches with values in [0, 4), shape (3, 4, 5).
# NOTE: this rebinds `t`, which earlier held the Transformer model.
s, t = (torch.randint(0, 4, (3, 4, 5)) for _ in range(2))
p = 0  # id treated as padding when building the masks below


In [95]:
a = (s != p)[..., None, :]  # padding mask of s with a singleton axis inserted at -2


In [97]:
# Display the random target-token batch (shape (3, 4, 5)).
t

tensor([[[1, 2, 1, 3, 1],
         [0, 3, 2, 1, 2],
         [1, 0, 1, 0, 2],
         [1, 0, 1, 3, 1]],

        [[2, 3, 1, 3, 1],
         [2, 0, 1, 2, 1],
         [1, 2, 3, 1, 0],
         [3, 1, 0, 2, 0]],

        [[0, 1, 2, 3, 3],
         [1, 3, 1, 3, 0],
         [2, 3, 1, 1, 0],
         [2, 0, 1, 3, 3]]])

In [98]:
# Shift along dim 1: `t_1` drops the last position, `t_y` drops the first
# (the same split Batch uses for trg / trg_y).
t_1, t_y = t[:, :-1], t[:, 1:]

In [99]:
# Decoder-input slice: every position of t except the last along dim 1.
t_1

tensor([[[1, 2, 1, 3, 1],
         [0, 3, 2, 1, 2],
         [1, 0, 1, 0, 2]],

        [[2, 3, 1, 3, 1],
         [2, 0, 1, 2, 1],
         [1, 2, 3, 1, 0]],

        [[0, 1, 2, 3, 3],
         [1, 3, 1, 3, 0],
         [2, 3, 1, 1, 0]]])

In [100]:
# Shifted-target slice: every position of t except the first along dim 1.
t_y

tensor([[[0, 3, 2, 1, 2],
         [1, 0, 1, 0, 2],
         [1, 0, 1, 3, 1]],

        [[2, 0, 1, 2, 1],
         [1, 2, 3, 1, 0],
         [3, 1, 0, 2, 0]],

        [[1, 3, 1, 3, 0],
         [2, 3, 1, 1, 0],
         [2, 0, 1, 3, 3]]])

In [101]:
# Target mask = padding mask AND look-ahead (lower-triangular) mask, built inline
# because `make_std_mask` is not defined anywhere in this notebook (the cell
# failed on a fresh kernel). True marks positions that may be attended to.
t_mask = (t_1 != p).unsqueeze(-2) & torch.tril(
    torch.ones(1, t_1.size(-1), t_1.size(-1), dtype=torch.bool, device=t_1.device)
)

  after removing the cwd from sys.path.


In [128]:
# https://nlp.seas.harvard.edu/2018/04/03/attention.html
class Batch:
    """Hold a batch of token ids together with the attention masks they need.

    Adapted from the Annotated Transformer. Masks are boolean; True marks a
    position that may be attended to.

    Attributes:
        src: source token ids, as given.
        src_mask: ``src != pad`` with a singleton axis inserted at -2.
        trg: target ids without the last position along dim 1 (decoder input).
        trg_y: target ids without the first position along dim 1 (targets).
        trg_mask: padding mask of ``trg`` AND the look-ahead mask.
    """

    def __init__(self, src, trg=None, pad=0):
        self.src = src
        self.src_mask = (src != pad).unsqueeze(-2)
        if trg is not None:
            self.trg = trg[:, :-1]
            self.trg_y = trg[:, 1:]
            self.trg_mask = self.std_mask(self.trg, pad)

    @staticmethod
    def std_mask(tgt, pad):
        """Combine tgt's padding mask with the look-ahead mask (broadcast over batch)."""
        tgt_mask = (tgt != pad).unsqueeze(-2)
        return tgt_mask & Batch.subsequent_mask(tgt.size(-1), device=tgt.device)

    @staticmethod
    def subsequent_mask(size, device=None):
        """Return a (1, size, size) bool mask that is True on and below the diagonal."""
        return torch.tril(torch.ones(1, size, size, dtype=torch.bool, device=device))

tensor([[[ True, False, False],
         [ True,  True, False],
         [ True,  True,  True]]])

In [140]:
t_q = (t_1 != p)[..., None, :]  # padding mask of t_1 with a singleton axis inserted at -2


In [138]:
def subsequent_mask(size):
    """Look-ahead mask of shape (1, size, size).

    True on and below the diagonal (positions that may be attended to) and False
    strictly above it.
    """
    strictly_upper = np.triu(np.ones((1, size, size), dtype=np.uint8), k=1)
    return torch.from_numpy(strictly_upper).eq(0)

In [153]:
# Padding-mask shape (3, 3, 1, 5): the singleton axis lets it broadcast against a
# (1, 5, 5) look-ahead mask.
t_q.shape

torch.Size([3, 3, 1, 5])

In [159]:
# Look-ahead mask for the last-axis length of t_q (5): shape (1, 5, 5).
subsequent_mask(t_q.shape[-1]).shape

torch.Size([1, 5, 5])

tensor([[[[ True,  True,  True,  True,  True]],

         [[False,  True,  True,  True,  True]],

         [[ True, False,  True, False,  True]]],


        [[[ True,  True,  True,  True,  True]],

         [[ True, False,  True,  True,  True]],

         [[ True,  True,  True,  True, False]]],


        [[[False,  True,  True,  True,  True]],

         [[ True,  True,  True,  True, False]],

         [[ True,  True,  True,  True, False]]]])

In [158]:
# Re-display t (same values as the earlier display). NOTE(review): the boolean
# output just above has no producing cell; its (3, 3, 1, 5) shape suggests it was t_q.
t

tensor([[[1, 2, 1, 3, 1],
         [0, 3, 2, 1, 2],
         [1, 0, 1, 0, 2],
         [1, 0, 1, 3, 1]],

        [[2, 3, 1, 3, 1],
         [2, 0, 1, 2, 1],
         [1, 2, 3, 1, 0],
         [3, 1, 0, 2, 0]],

        [[0, 1, 2, 3, 3],
         [1, 3, 1, 3, 0],
         [2, 3, 1, 1, 0],
         [2, 0, 1, 3, 3]]])