In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class scaledDotProduct(nn.Module):
    '''
        Attention(Q, K, V ) = softmax( QK^T/√d_k)V 
    
    '''
    def __init__(self, dim, drop=0.0):
        super(scaledDotProduct, self).__init__()
        #dim is (d_k) when sqrt'd it is meant to counter small gradients in large sets of queries and keys
        self.d_k = np.sqrt(dim)
        #Simple drop out 
        self.drop = nn.Dropout(drop)

    def forward(self, q, k, v, mask=None):
        #first two dimensions are batch and number of heads?
        n = torch.matmul(q, k.transpose(2,3)) / self.d_k

        if mask:
            n = n.masked_fill_(mask==0, -1e9)
        #Drop out referenced later in paper but not in original diagram
        att = self.drop(F.softmax(n, -1))

        out = torch.matmul(n, v)

        return out, att 
        
        


In [3]:
#Scaled dot product attention testing
#dim should be size of q and k
scaled_dot = scaledDotProduct(3)
q = torch.rand(1,1,2,3)
k = torch.rand(1,1,2,3)
v = torch.rand(1,1,2,4)


scaled_dot(q,k,v)

(tensor([[[[0.3464, 0.2510, 0.2729, 0.0916],
           [0.3734, 0.2844, 0.3056, 0.1029]]]]),
 tensor([[[[0.4680, 0.5320],
           [0.4710, 0.5290]]]]))

In [4]:
class multiHeadedAttention(nn.Module):
    def __init__(self, n_heads, dims, d_k, d_v, dropout=0.0):
        super(multiHeadedAttention, self).__init__()
        #d_k=d_v = dims/h

        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        #Pre-attention projection matrices
        self.w_q = nn.Linear(dims, n_heads * d_k, bias=False)
        self.w_k = nn.Linear(dims, n_heads * d_k, bias=False)
        self.w_v = nn.Linear(dims, n_heads * d_v, bias=False)

        self.att = scaledDotProduct(d_k)
        #Final linear layer after concat and attention
        self.fc = nn.Linear(n_heads*d_v, dims)

        self.drop = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(dims,eps=1e-6)

    def forward(self, q, k, v, mask=None):
        d_k, d_v, heads = self.d_k, self.d_v, self.n_heads
        batch_len, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

        res = q

        #Pass through projection layers prior to attention layer batch x length of query x (nheads x value dimensionality)
        #View as batches x len of query x numbers of heads x dimensionality to sperate out heads dimension
        print(q.shape)
        q = self.w_q(q).view(batch_len, len_q, heads, d_k)
        k = self.w_k(k).view(batch_len, len_k, heads, d_k)
        v = self.w_v(v).view(batch_len, len_v, heads, d_v)


        #Transpose for attention
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if mask:
            mask = mask.unsqueeze(1)

        q, attn = self.att(q, k, v, mask=mask)
        #Move head dim back - batch x len query x heads x dimensionality
        #Combined all heads into one - batch x len query x (heads x dimensionality)
        q = q.transpose(1,2).contiguous().view(batch_len, len_q, -1)
        q = self.drop(self.fc(q))
        q += res

        q = self.norm(q)

        return q, attn

        


In [5]:
#heads, d_model, d_km d_v as per the paper
multiHead = multiHeadedAttention(8, 512, 64, 64)
#batches, dims, dimensionalityxn_heads
q = torch.rand(1,512,512)
k = torch.rand(1,512,512)
v = torch.rand(1,512,512)


multiHead(q,k,v)

torch.Size([1, 512, 512])


(tensor([[[-0.5688,  1.4184,  0.3202,  ..., -1.0021, -0.2049, -0.3221],
          [-0.7246,  1.0414,  1.2804,  ..., -1.4171,  0.1137, -0.1924],
          [-0.7711,  0.8162,  1.2456,  ..., -1.2875,  0.1525, -0.1983],
          ...,
          [-1.8663,  1.2953,  1.9278,  ..., -0.9809, -0.2540, -0.2188],
          [-1.0484,  1.1081,  1.3690,  ..., -1.3603,  0.0067, -0.0589],
          [-0.9783,  1.1937,  1.2425,  ..., -1.3367, -0.2213, -0.0561]]],
        grad_fn=<NativeLayerNormBackward0>),
 tensor([[[[0.0019, 0.0022, 0.0020,  ..., 0.0019, 0.0019, 0.0021],
           [0.0018, 0.0021, 0.0019,  ..., 0.0019, 0.0020, 0.0020],
           [0.0018, 0.0021, 0.0020,  ..., 0.0018, 0.0019, 0.0020],
           ...,
           [0.0018, 0.0022, 0.0020,  ..., 0.0020, 0.0019, 0.0021],
           [0.0018, 0.0021, 0.0020,  ..., 0.0018, 0.0020, 0.0022],
           [0.0018, 0.0022, 0.0019,  ..., 0.0019, 0.0019, 0.0021]],
 
          [[0.0020, 0.0021, 0.0020,  ..., 0.0021, 0.0021, 0.0021],
           [0.0020

In [6]:
class positionFeedFoward(nn.Module):
    def __init__(self, inp, hid, drop=0.0):
        super(positionFeedFoward, self).__init__()
        self.w1 = nn.Linear(inp,hid)
        self.w2 = nn.Linear(hid,inp)
        self.norm = nn.LayerNorm(inp, eps=1e-6)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        res = x

        x = self.w2(F.relu(self.w1(x)))
        x = self.drop(x)
        x += res
        x = self.norm(x)

        return x

In [7]:
class EncoderLayer(nn.Module):
    '''Combinds MultiHeadedAttention and FeedForward, two layers'''
    def __init__(self, dims, hid, nheads, d_k, d_v, drop=0.0):
        super(EncoderLayer, self).__init__()
        self.attn = multiHeadedAttention(nheads, dims,d_k, d_v, dropout=drop)
        self.ffn = positionFeedFoward(dims, hid, drop=drop)

    def forward(self, inp, mask=None):
        out, attn = self.attn(
            inp, inp, inp, mask
        )
        out = self.ffn(out)

        return out, attn
    
class DecoderLayer(nn.Module):
    '''Combinds MultiHeadedAttention and FeeForward, three layers'''
    def __init__(self, dims, hid, nheads, d_k, d_v, drop=0.0):
        super(EncoderLayer, self).__init__()
        self.slf_attn = multiHeadedAttention(nheads, dims,d_k, d_v, dropout=drop)
        self.enc_attn = multiHeadedAttention(nheads, dims,d_k, d_v, dropout=drop)
        self.ffn = positionFeedFoward(dims, hid, drop=drop)

    def forward(self, inp, enc_out, slf_mask, enc_mask=None):
        dec_out, dec_attn = self.slf_attn(
            inp, inp, inp, slf_mask
        )

        dec_out, enc_attn = self.enc_attn(
            dec_out, enc_out, enc_out, enc_mask
        )
        dec_out = self.ffn(dec_out)

        return dec_out, dec_attn, enc_attn

In [8]:
#heads, d_model, d_km d_v as per the paper
enc = EncoderLayer(512, 20, 8, 64, 64)
#batches, dims, dimensionalityxn_heads
q = torch.rand(1,512,512)
k = torch.rand(1,512,512)
v = torch.rand(1,512,512)


enc(v)

torch.Size([1, 512, 512])


(tensor([[[-0.9075,  0.3117,  0.8335,  ..., -0.0801, -0.3563, -0.8519],
          [ 0.5633, -0.3112, -0.0830,  ..., -0.6547, -0.8622, -0.9671],
          [ 0.4893,  1.1514, -1.5643,  ..., -0.2486, -1.9703,  0.0916],
          ...,
          [ 0.3537,  0.3940, -0.5105,  ..., -0.5003, -1.7616, -1.0855],
          [-0.2221,  0.1675, -0.5074,  ...,  0.5400, -1.4164,  0.1038],
          [-0.1265,  0.8833, -0.6364,  ..., -0.2704, -1.9819, -0.3983]]],
        grad_fn=<NativeLayerNormBackward0>),
 tensor([[[[0.0019, 0.0018, 0.0020,  ..., 0.0018, 0.0020, 0.0020],
           [0.0020, 0.0017, 0.0020,  ..., 0.0018, 0.0020, 0.0020],
           [0.0020, 0.0018, 0.0020,  ..., 0.0019, 0.0020, 0.0020],
           ...,
           [0.0020, 0.0018, 0.0019,  ..., 0.0018, 0.0021, 0.0021],
           [0.0020, 0.0018, 0.0019,  ..., 0.0019, 0.0019, 0.0020],
           [0.0020, 0.0018, 0.0020,  ..., 0.0018, 0.0019, 0.0020]],
 
          [[0.0019, 0.0021, 0.0019,  ..., 0.0019, 0.0019, 0.0019],
           [0.0018

In [9]:
#Pytoch version adapted from here https://pub.aimind.so/creating-sinusoidal-positional-embedding-from-scratch-in-pytorch-98c49e153d6
class PosEncoding(nn.Module):
    def __init__(self, hid, n_pos=200):
        super(PosEncoding, self).__init__()

        self.register_buffer('table', self._get_sinusoid_encoding_table(n_pos, hid))

    def _get_sinusoid_encoding_table(self, n_pos, hid):

        if hid %2 != 0:
            raise ValueError("Sinusoidal positional embedding cannot apply to odd token embedding dim={}".format(hid))
        
        positions = torch.arange(0,n_pos).unsqueeze_(1)
        embeds = torch.zeros(n_pos, hid)

        denom = torch.pow(10000, 2 * torch.arange(0, hid//2)/2)
        embeds[:, 0::2] = torch.sin(positions/denom)
        embeds[:, 1::2] = torch.cos(positions/denom)

        return embeds
    
    def forward(self, x):
        return x + self.pos_table[:, :x.size(1)].clone().detach()




In [10]:
class Encoder(nn.Module):
    '''Encoder model'''
    def __init__(
            self, n_vocab, d_word, n_layers, n_head, d_k, d_v, dims, hid, pad, dropout=0.0, n_pos=200, scale_emb=False
    ):
        super(Encoder, self).__init__()

        self.word_emb = nn.Embedding(n_vocab, d_word, padding_idx=pad)
        self.pos_enc = PosEncoding(d_word, n_pos=n_pos)
        self.drop = nn.Dropout(p=dropout)
        self.stack = nn.ModuleList([
            EncoderLayer(dims, hid, n_head, d_k, d_v, drop=dropout)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(dims, eps=1e-6)
        self.scale_emb = scale_emb
        self.dims = dims

    def forward(self, seq, mask, ret_attns=False):
        enc_slf_attn_list = []

        enc_out = self.word_emb(seq)
        if self.scale_emb:
            enc_out *= self.dims ** 0.5
        enc_out = self.pos_enc(enc_out)
        enc_out = self.drop(enc_out)
        enc_out = self.layer_norm(enc_out)

        for enc_layer in self.stack:
            enc_out, enc_slf_attn = enc_layer(enc_out, mask=mask)
            enc_slf_attn_list += [enc_slf_attn] if ret_attns else []

        if ret_attns:
            return enc_out, enc_slf_attn_list
        return enc_out

class Decoder(nn.Module):
    '''Decoder model'''
    def __init__(
            self, n_vocab, d_word, n_layers, n_head, d_k, d_v, dims, hid, pad, dropout=0.0 , n_pos=200, scale_emb=False
    ):
        super(Decoder, self).__init__()

        self.word_emb = nn.Embedding(n_vocab, d_word, padding_idx=pad)
        self.pos_enc = PosEncoding(d_word, n_pos=n_pos)
        self.drop = nn.Dropout(p=dropout)
        self.stack = nn.ModuleList([
            DecoderLayer(dims, hid, n_head, d_k, d_v, drop=dropout)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(dims, eps=1e-6)
        self.scale_emb = scale_emb
        self.dims = dims

    def forward(self, seq, mask, enc_out, src_mask, ret_attns=False):
        dec_slf_attn_list, dec_enc_attn_list = [],[]

        dec_out = self.word_emb(seq)
        if self.scale_emb:
            dec_out *= self.dims ** 0.5
        dec_out = self.pos_enc(dec_out)
        dec_out = self.drop(dec_out)
        dec_out = self.layer_norm(dec_out)

        for dec_layer in self.stack:
            dec_out, dec_self_attn, dec_enc_attn = dec_layer(
                dec_out, enc_out, slf_mask=mask, enc_mask=src_mask
            )
            dec_slf_attn_list += [dec_self_attn] if ret_attns else []
            dec_enc_attn_list += [dec_enc_attn] if ret_attns else []

        if ret_attns:
            return dec_out, dec_slf_attn_list, dec_enc_attn_list
        return dec_out

