In [4]:
import torch
from torch import nn
import numpy as np
import math
import copy


In [None]:
# Arch units
## Self Attention Unit
## Multi Head Attention
## Encode Decode Unit
## Norm + Residual Layer
## Feed Forward
## Input Positional Encoding 


In [2]:
# Hyperparameters, following the paper's naming. A value in parentheses is the
# alternative the original author noted next to the smaller value used here
# (presumably the paper's full-size setting).
vocab_size = 100
word_emb_dim = 50
N = 6
h = 4         # (8)
d_model = 32  # (512)
d_ff = 128    # (2048)
# Per-head width for queries/keys and for values: the model width split over h heads
d_k = d_model // h
d_v = d_k


In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are each projected once, the projections are
    split into ``h`` heads, every head attends independently, and the
    concatenated heads go through the output projection ``W_O``.

    The sizes come from the module-level ``d_model``, ``h``, ``d_k`` and
    ``d_v`` at construction time; ``h``, ``d_k`` and ``d_v`` are stored on the
    instance so rebinding those globals later does not affect an existing
    module.
    """

    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.h, self.d_k, self.d_v = h, d_k, d_v
        # One wide projection per role; each holds the weights of all h heads,
        # so the heads learn different projections.
        self.W_Q = nn.Linear(d_model, self.h * self.d_k)
        self.W_K = nn.Linear(d_model, self.h * self.d_k)
        self.W_V = nn.Linear(d_model, self.h * self.d_v)
        self.W_O = nn.Linear(self.h * self.d_v, d_model)

    def _split_heads(self, x, d_head):
        """Reshape (..., seq, h*d_head) into (..., h, seq, d_head)."""
        return x.view(*x.shape[:-1], self.h, d_head).transpose(-3, -2)

    def forward(self, Q, K, V, mask=None):
        """Apply multi-head attention.

        Args:
            Q: queries, size (b, seq_q, d_model).
            K: keys, size (b, seq_k, d_model).
            V: values, size (b, seq_k, d_model).
            mask: optional tensor; positions where it equals 0 are not
                attended to. A mask with as many dimensions as ``Q``
                (e.g. (b, 1, seq_k) or (b, seq_q, seq_k)) gets a head axis
                inserted so it applies to every head.

        Returns:
            Tensor of size (b, seq_q, d_model).
        """
        q = self._split_heads(self.W_Q(Q), self.d_k)  # (b, h, seq_q, d_k)
        k = self._split_heads(self.W_K(K), self.d_k)  # (b, h, seq_k, d_k)
        v = self._split_heads(self.W_V(V), self.d_v)  # (b, h, seq_k, d_v)

        if mask is not None and mask.dim() == Q.dim():
            mask = mask.unsqueeze(-3)  # broadcast over the head axis

        heads = self.attention(q, k, v, mask=mask)  # (b, h, seq_q, d_v)

        # Put the head axis next to the feature axis and merge them: (b, seq_q, h*d_v)
        merged = heads.transpose(-3, -2).contiguous()
        merged = merged.view(*merged.shape[:-2], self.h * self.d_v)
        return self.W_O(merged)  # size (b, seq_q, d_model)

    def attention(self, Q, K, V, mask=None):
        """Scaled dot-product attention, softmax(Q K^T / sqrt(d)) V.

        Args:
            Q: (..., seq_q, d) queries.
            K: (..., seq_k, d) keys.
            V: (..., seq_k, d_v) values.
            mask: optional tensor broadcastable to (..., seq_q, seq_k);
                positions where it equals 0 receive zero attention weight.

        Returns:
            Tensor of size (..., seq_q, d_v).
        """
        # Both conditions are required: Q/K share the feature width, K/V the length
        assert Q.shape[-1] == K.shape[-1] and K.shape[-2] == V.shape[-2]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.shape[-1])
        if mask is not None:
            # The most negative finite value becomes weight 0 after softmax
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        weights = nn.functional.softmax(scores, dim=-1)  # (..., seq_q, seq_k)
        return torch.matmul(weights, V)

In [None]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable gain and bias.

    Args:
        d_mod: size of the last dimension of the inputs. When omitted, the
            module-level ``d_model`` is read at construction time, the same
            way the other modules in this notebook read it.
    """

    def __init__(self, d_mod=None):
        super(LayerNorm, self).__init__()
        if d_mod is None:
            d_mod = d_model
        self.d_mod = d_mod
        # https://stackoverflow.com/questions/39095252/fail-to-implement-layer-normalization-with-keras
        # https://stackoverflow.com/questions/50935345/understanding-torch-nn-parameter
        self.alpha = nn.Parameter(torch.ones(d_mod))   # gain, starts as identity
        self.beta = nn.Parameter(torch.zeros(d_mod))   # bias, starts at zero

    def forward(self, x, eps=1e-6):
        """Return ``alpha * (x - mean) / (std + eps) + beta`` along the last axis.

        Args:
            x: tensor whose last dimension has size ``d_mod``.
            eps: small constant added to the standard deviation so the
                division stays finite when the std is zero.
        """
        u = x.mean(-1, keepdim=True)
        sigma = x.std(-1, keepdim=True)
        # Divide by the standard deviation (dividing by its reciprocal would multiply)
        return self.alpha * (x - u) / (sigma + eps) + self.beta
    

In [None]:
class EncoderCell(nn.Module):
    """One encoder block: self-attention followed by a position-wise feed-forward net.

    Each sub-layer is applied as ``norm(x + sublayer(x))``.
    """

    def __init__(self):
        super(EncoderCell, self).__init__()
        self.attn = MultiHeadAttention()
        self.norm_1 = LayerNorm()
        # Feed-forward: d_model -> d_ff -> d_model with a ReLU in between
        self.pff = nn.Sequential(
                    nn.Linear(d_model, d_ff),
                    nn.ReLU(),
                    nn.Linear(d_ff, d_model))
        self.norm_2 = LayerNorm()

    def forward(self, x, src_mask=None):
        """Encode ``x``.

        Args:
            x: input tensor; used as query, key and value of the attention.
            src_mask: optional mask handed to the attention module as its
                ``mask`` argument. Defaults to ``None`` (no masking), which
                matches calling ``forward(x)``.

        Returns:
            The output of the second normalisation.
        """
        x_norm_1 = self.norm_1(x + self.attn(x, x, x, src_mask))  # Layer 1
        return self.norm_2(x_norm_1 + self.pff(x_norm_1)) # Layer 2
    

In [None]:
class DecoderCell(nn.Module):
    """One decoder block.

    Three sub-layers, each applied as ``norm(x + sublayer(x))``:
    self-attention over the decoder input, attention over the encoder output,
    and a position-wise feed-forward net.
    """

    def __init__(self):
        super(DecoderCell, self).__init__()
        # Two separate attention modules: assigning both to one attribute
        # would leave a single set of weights shared by both sub-layers.
        self.self_attn = MultiHeadAttention()
        self.norm_1 = LayerNorm()
        self.cross_attn = MultiHeadAttention()
        self.norm_2 = LayerNorm()
        # Feed-forward: d_model -> d_ff -> d_model with a ReLU in between
        self.pff = nn.Sequential(
                    nn.Linear(d_model, d_ff),
                    nn.ReLU(),
                    nn.Linear(d_ff, d_model))
        self.norm_3 = LayerNorm()

    def forward(self, x, enc, src_mask=None, trg_mask=None):
        """Decode ``x`` against the encoder output ``enc``.

        Args:
            x: decoder input; query, key and value of the self-attention.
            enc: encoder output; key and value of the second attention.
            src_mask: optional mask passed to the attention over ``enc``.
            trg_mask: optional mask passed to the self-attention.

        Returns:
            The output of the third normalisation, (b, seq, d_model).
        """
        x_norm_1 = self.norm_1(x + self.self_attn(x, x, x, trg_mask))
        x_norm_2 = self.norm_2(x_norm_1 + self.cross_attn(x_norm_1, enc, enc, src_mask))
        return self.norm_3(x_norm_2 + self.pff(x_norm_2)) # (b, seq, d_model)


In [None]:
class DecoderLayer(nn.Module):
    """Placeholder module: it defines no submodules and no forward pass yet."""

    def __init__(self):
        # nn.Module keeps its parameter/submodule registries on the instance;
        # they only exist after the base initialiser has run.
        super(DecoderLayer, self).__init__()
    

In [None]:
class EncoderLayer(nn.Module):
    """Placeholder module: it defines no submodules and no forward pass yet."""

    def __init__(self):
        # nn.Module keeps its parameter/submodule registries on the instance;
        # they only exist after the base initialiser has run.
        super(EncoderLayer, self).__init__()

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder model over pre-embedded inputs.

    Inputs are scaled, projected from ``word_emb_dim`` to ``d_model``, passed
    through one encoder cell and one decoder cell, and projected to
    ``vocab_size`` scores.

    NOTE(review): only a single encoder cell and a single decoder cell are
    built here; ``N`` is not used by this class.
    """

    def __init__(self):
        super(Transformer, self).__init__()
        # Kept on the instance so forward() uses the width the layers were built with
        self.d_model = d_model
        self.W_in = nn.Linear(word_emb_dim, d_model)
        # based on N values
        self.encoder_unit = EncoderCell()
        self.decoder_unit = DecoderCell()
        # https://stats.stackexchange.com/questions/392213/understand-the-output-layer-of-transformer
        self.W_out = nn.Linear(d_model, vocab_size)

    def forward(self, inp_x, inp_y, src_mask=None, trg_mask=None):
        """Run the model.

        Args:
            inp_x: source embeddings, (b, seq, word_emb_dim).
            inp_y: target embeddings, (b, seq, word_emb_dim).
            src_mask: optional mask forwarded to the decoder cell as its
                ``src_mask`` argument.
            trg_mask: optional mask forwarded to the decoder cell as its
                ``trg_mask`` argument.

        Returns:
            Output of ``W_out`` applied to the decoder output; last dimension
            is ``vocab_size``.
        """
        # NOTE(review): EmbeddingLayer multiplies by sqrt(d_model); dividing here
        # cancels that if the inputs come from it — confirm the intended scaling.
        scale = math.sqrt(self.d_model)
        inp_x, inp_y = inp_x / scale, inp_y / scale
        inp_x, inp_y = self.W_in(inp_x), self.W_in(inp_y) # (b, seq, word_embedding) -> (b, seq, d_model)
        enc_x = self.encoder_unit(inp_x)
        dec_x = self.decoder_unit(inp_y, enc_x, src_mask=src_mask, trg_mask=trg_mask)
        return self.W_out(dec_x)


In [None]:
class EmbeddingLayer(nn.Module):
    """Lookup table from token ids to ``word_emb_dim``-wide vectors, scaled by sqrt(d_model)."""

    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size,
                                      embedding_dim=word_emb_dim)

    def forward(self, x):
        """Embed the ids in ``x`` and multiply the vectors by sqrt(d_model)."""
        scale = math.sqrt(d_model)
        vectors = self.embedding(x)
        return vectors * scale

In [None]:
# https://www.reddit.com/r/MachineLearning/comments/bjgpt2/d_confused_about_using_masking_in_transformer/

In [None]:
# https://nlp.seas.harvard.edu/2018/04/03/attention.html
class Batch:
    """Holds a source/target batch together with the masks built from it.

    Attributes set for every batch:
        src: the source ids as given, (b, seq).
        src_mask: True where ``src`` differs from ``pad``, (b, 1, seq).

    Attributes set only when ``trg`` is given:
        trg: target ids without the last position (decoder input).
        trg_y: target ids without the first position (expected output).
        trg_mask: padding mask combined with the causal mask, (b, seq-1, seq-1).
        ntokens: number of non-pad entries in ``trg_y`` (0-dim tensor).
    """

    def __init__(self, src, trg=None, pad=0): # size src, trg (b, seq)
        self.src = src
        self.src_mask = (src != pad).unsqueeze(-2)
        if trg is not None:
            self.trg = trg[:, :-1] # size (b,0:seq-1)
            self.trg_y = trg[:, 1:] # size (b,1:seq)
            self.trg_mask = self.std_mask(self.trg, pad)
            self.ntokens = (self.trg_y != pad).sum() # size (1)

    @staticmethod
    def std_mask(tgt, pad):
        """Return a mask that hides both padding and later positions.

        Args:
            tgt: target ids, (b, seq).
            pad: id that marks padding.

        Returns:
            Boolean tensor (b, seq, seq); entry [i, q, k] is True when
            position k is not padding and k <= q.
        """
        tgt_mask = (tgt != pad).unsqueeze(-2) # size (b, 1, seq)
        # subsequent_mask is a static method of this class, so it is looked up
        # through the class; the result is moved next to tgt_mask before combining.
        causal = Batch.subsequent_mask(tgt.shape[-1]).to(tgt_mask.device)
        return tgt_mask & causal # size (b, 1, seq) & (1, seq, seq) -> (b, seq, seq)

    @staticmethod
    def subsequent_mask(size):
        """Return a (1, size, size) boolean mask that is True on and below the diagonal."""
        return torch.ones(1, size, size).tril().bool() # size (1, seq, seq)

In [None]:
# Scratch check of a tensor's shape.
# NOTE(review): `attn` is not assigned in any cell visible here, so this cell
# depends on earlier kernel state — confirm where it comes from.
attn.shape

In [None]:
# Scratch: replaces the entries of `attn` where `b.src_mask` is 0 with 1e-9.
# NOTE(review): neither `attn` nor `b` is assigned in the visible cells. If `attn`
# holds pre-softmax scores, a large negative fill value would be needed for the
# masked positions to end up with ~0 weight — confirm what `attn` is.
attn.masked_fill(b.src_mask==0, 1e-9)

In [15]:
# Random integers drawn from {0, 1, 2, 3}, turned into floats by the scalar multiply; shape (3, 4, 5)
a = torch.randint(low=0, high=4, size=(3, 4, 5)).mul(1.0)

In [16]:
a

tensor([[[2., 2., 3., 3., 3.],
         [0., 2., 0., 1., 3.],
         [1., 0., 0., 3., 3.],
         [0., 1., 2., 0., 3.]],

        [[2., 2., 1., 2., 2.],
         [1., 3., 2., 0., 3.],
         [2., 3., 3., 1., 2.],
         [3., 2., 1., 3., 0.]],

        [[3., 1., 3., 2., 2.],
         [0., 2., 2., 0., 3.],
         [3., 3., 1., 1., 1.],
         [2., 1., 0., 0., 1.]]])

In [12]:
h = 2

In [13]:
l1 = nn.Linear(5, 5 * h, bias=False)

In [17]:
a1 = l1(a)

In [24]:
a2 = a1.view(3,4,h,5)

In [32]:
a2.transpose(1,2).conti

torch.Size([3, 2, 4, 5])

In [43]:
# Left operand for the batched-matmul shape check: random integers in {0, 1, 2, 3}
a = torch.randint(low=0, high=4, size=(3, 4, 2, 5, 6))

In [49]:
# Right operand for the batched-matmul shape check: random integers in {0, 1, 2, 3}
b = torch.randint(low=0, high=4, size=(3, 4, 2, 6, 4))

In [51]:
# Batched matmul: the leading (3, 4, 2) dims are treated as batch dims and the
# trailing matrices multiply as (5, 6) @ (6, 4) -> (5, 4).
a.matmul(b).shape

torch.Size([3, 4, 2, 5, 4])