In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class scaledDotProduct(nn.Module):
    '''
        Attention(Q, K, V ) = softmax( QK^T/√d_k)V 
    
    '''
    def __init__(self, dim, drop=0.0):
        super(scaledDotProduct, self).__init__()
        #dim is (d_k) when sqrt'd it is meant to counter small gradients in large sets of queries and keys
        self.d_k = np.sqrt(dim)
        #Simple drop out 
        self.drop = nn.Dropout(drop)

    def forward(self, q, k, v, mask=None):
        #first two dimensions are batch and number of heads?
        n = torch.matmul(q, k.transpose(2,3)) / self.d_k

        if mask:
            n = n.masked_fill_(mask==0, -1e9)
        #Drop out referenced later in paper but not in original diagram
        att = self.drop(F.softmax(n, -1))

        out = torch.matmul(n, v)

        return out, att 
        
        


In [3]:
#Scaled dot product attention testing
#dim should be size of q and k
scaled_dot = scaledDotProduct(3)
q = torch.rand(1,1,2,3)
k = torch.rand(1,1,2,3)
v = torch.rand(1,1,2,4)


scaled_dot(q,k,v)

(tensor([[[[0.0842, 0.1299, 0.1133, 0.3195],
           [0.2432, 0.3760, 0.3280, 0.9247]]]]),
 tensor([[[[0.5121, 0.4879],
           [0.5356, 0.4644]]]]))

In [14]:
class multiHeadedAttention(nn.Module):
    def __init__(self, n_heads, dims, d_k, d_v, dropout=0.0):
        super(multiHeadedAttention, self).__init__()
        #d_k=d_v = dims/h

        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        #Pre-attention projection matrices
        self.w_q = nn.Linear(dims, n_heads * d_k, bias=False)
        self.w_k = nn.Linear(dims, n_heads * d_k, bias=False)
        self.w_v = nn.Linear(dims, n_heads * d_v, bias=False)

        self.att = scaledDotProduct(d_k)
        #Final linear layer after concat and attention
        self.fc = nn.Linear(n_heads*d_v, dims)

        self.drop = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(dims,eps=1e-6)

    def forward(self, q, k, v, mask=None):
        d_k, d_v, heads = self.d_k, self.d_v, self.n_heads
        batch_len, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

        res = q

        #Pass through projection layers prior to attention layer batch x length of query x (nheads x value dimensionality)
        #View as batches x len of query x numbers of heads x dimensionality to sperate out heads dimension
        print(q.shape)
        q = self.w_q(q).view(batch_len, len_q, heads, d_k)
        k = self.w_k(k).view(batch_len, len_k, heads, d_k)
        v = self.w_v(v).view(batch_len, len_v, heads, d_v)


        #Transpose for attention
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if mask:
            mask = mask.unsqueeze(1)

        q, attn = self.att(q, k, v, mask=mask)
        #Move head dim back - batch x len query x heads x dimensionality
        #Combined all heads into one - batch x len query x (heads x dimensionality)
        q = q.transpose(1,2).contiguous().view(batch_len, len_q, -1)
        q = self.drop(self.fc(q))
        q += res

        q = self.norm(q)

        return q, attn

        


In [15]:
#heads, d_model, d_km d_v as per the paper
multiHead = multiHeadedAttention(8, 512, 64, 64)
#batches, dims, dimensionalityxn_heads
q = torch.rand(1,512,512)
k = torch.rand(1,512,512)
v = torch.rand(1,512,512)


multiHead(q,k,v)

torch.Size([1, 512, 512])


(tensor([[[ 0.6771, -1.7764, -0.8633,  ...,  1.5172, -0.7129,  1.2825],
          [-0.1932, -0.1158,  1.4648,  ..., -0.8946,  0.8421,  0.3897],
          [-1.0017, -2.1038,  0.0425,  ..., -0.7560, -0.7943,  1.2451],
          ...,
          [ 0.9047, -1.7641,  1.1018,  ..., -1.2272,  0.8404,  2.0151],
          [ 0.6745, -0.8764,  0.3654,  ...,  0.6772,  1.0115,  2.1145],
          [-0.4383, -1.6716, -0.0755,  ...,  0.8128, -0.7611,  0.4531]]],
        grad_fn=<NativeLayerNormBackward0>),
 tensor([[[[0.0019, 0.0019, 0.0020,  ..., 0.0017, 0.0019, 0.0020],
           [0.0019, 0.0019, 0.0019,  ..., 0.0018, 0.0020, 0.0020],
           [0.0019, 0.0019, 0.0019,  ..., 0.0018, 0.0018, 0.0020],
           ...,
           [0.0019, 0.0019, 0.0019,  ..., 0.0017, 0.0019, 0.0020],
           [0.0019, 0.0019, 0.0019,  ..., 0.0017, 0.0019, 0.0019],
           [0.0020, 0.0019, 0.0019,  ..., 0.0018, 0.0019, 0.0019]],
 
          [[0.0020, 0.0021, 0.0019,  ..., 0.0020, 0.0020, 0.0022],
           [0.0020

In [None]:
class positionFeedFoward(nn.Module):
    def __init__(self, inp, hid, drop=0.0):
        super(positionFeedFoward, self).__init__()
        self.w1 = nn.Linear(inp,hid)
        self.w2 = nn.Linear(hid,inp)
        self.norm = nn.LayerNorm(inp, eps=1e-6)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        res = x

        x = self.w2(F.relu(self.w1(x)))
        x = self.drop(x)
        x += res
        x = self.norm(x)

        return x