In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [18]:
class scaledDotProduct(nn.Module):
    '''
        Attention(Q, K, V ) = softmax( QK^T/âˆšd_k)V 
    
    '''
    def __init__(self, dim, drop=0.0):
        super(scaledDotProduct, self).__init__()
        #dim is (d_k) when sqrt'd it is meant to counter small gradients in large sets of queries and keys
        self.d_k = np.sqrt(dim)
        #Simple drop out 
        self.drop = nn.Dropout(drop)

    def forward(self, q, k, v, mask=None):
        #first two dimensions are batch and number of heads?
        n = torch.matmul(q, k.transpose(2,3)) / self.d_k

        if mask:
            n = n.masked_fill_(mask==0, -1e9)
        #Drop out referenced later in paper but not in original diagram
        att = self.drop(F.softmax(n, -1))

        out = torch.matmul(n, v)

        return out, att 
        
        


In [20]:
#Scaled dot product attention testing
#dim should be size of q and k
scaled_dot = scaledDotProduct(3)
q = torch.rand(1,1,2,3)
k = torch.rand(1,1,2,3)
v = torch.rand(1,1,2,4)


scaled_dot(q,k,v)

(tensor([[[[0.5498, 0.5309, 0.4403, 0.5063],
           [0.3494, 0.3224, 0.2702, 0.3095]]]]),
 tensor([[[[0.4935, 0.5065],
           [0.5173, 0.4827]]]]))

In [None]:
class multiHeadedAttention(nn.Module):
    def __init__(self, n_heads, dims, d_k, d_v, dropout=0.0):
        super(multiHeadedAttention, self).__init__()

        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        #Pre-attention projection matrices
        self.w_q = nn.Linear(dims, n_heads * d_k, bias=False)
        self.w_k = nn.Linear(dims, n_heads * d_k, bias=False)
        self.w_v = nn.Linear(dims, n_heads * d_v, bias=False)

        self.att = scaledDotProduct(d_k)
        #Final linear layer after concat and attention
        self.fc = nn.Linear(n_heads*d_v, dims)
