In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [15]:
class scaledDotProduct(nn.Module):
    '''
        Attention(Q, K, V ) = softmax( QK^T/√d_k)V 
    
    '''
    def __init__(self, dim, drop=0.0):
        super(scaledDotProduct, self).__init__()
        #Temp is the sqrt(d_k) meant to counter small gradients in large sets of queries and keys
        self.d_k = np.sqrt(dim)
        #Simple drop out 
        self.drop = nn.Dropout(drop)

    def forward(self, q, k, v, mask=None):
        print(q.shape)
        print(k.shape)
        #first two dimensions are batch and number of heads?
        n = torch.matmul(q, k.transpose(2,3)) / self.d_k

        if mask:
            n = n.masked_fill_(mask==0, -1e9)
        #Drop out referenced later in paper but not in original diagram
        att = self.drop(F.softmax(n, -1))

        out = torch.matmul(n, v)

        return out, att 
        
        


In [17]:
scaled_dot = scaledDotProduct(7)
q = torch.rand(1,1,2,3)
k = torch.rand(1,1,2,3)
v = torch.rand(1,1,2,4)


scaled_dot(q,k,v)

torch.Size([1, 1, 2, 3])
torch.Size([1, 1, 2, 3])


(tensor([[[[0.1676, 0.1477, 0.4742, 0.4178],
           [0.2135, 0.1883, 0.6043, 0.5325]]]]),
 tensor([[[[0.5175, 0.4825],
           [0.5221, 0.4779]]]]))