In [9]:
from tensorflow.keras.layers import Dense, Layer
from tensorflow.keras.backend import softmax
from tensorflow import matmul,cast,float32,math,reshape,shape,transpose

In [4]:
class DotProductAttention(Layer):
    """Scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k) + penalty) @ V.

    The layer creates no weights of its own; it only combines the tensors
    handed to ``call``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, queries, keys, values, d_k, mask=None):
        """Compute the attention-weighted combination of ``values``.

        Args:
            queries: tensor whose last two axes are multiplied against the
                transposed last two axes of ``keys``.
            keys: tensor with the same feature width as ``queries``.
            values: tensor multiplied by the attention weights.
            d_k: scalar; the scores are divided by ``sqrt(d_k)``.
            mask: optional tensor broadcastable to the score matrix. Entries
                equal to 1 (or True) receive a -1e9 penalty so that softmax
                gives them a weight of practically zero; entries equal to 0
                are left untouched. Any dtype that can be cast to the score
                dtype is accepted.

        Returns:
            ``matmul(softmax(scores), values)``.
        """
        # Similarity of every query with every key, scaled by 1/sqrt(d_k).
        scores = matmul(queries, keys, transpose_b=True) / math.sqrt(cast(d_k, float32))

        # Apply the mask so that a position cannot attend to the words ahead of it.
        if mask is not None:
            # Cast first: boolean or integer masks cannot be multiplied by a
            # float constant or added to float scores without a dtype error.
            scores += -1e9 * cast(mask, scores.dtype)

        # softmax is applied with its default axis (the last one).
        weights = softmax(scores)

        return matmul(weights, values)

In [11]:
class MultiHeadAttention(Layer):
    """Multi-head attention built from four Dense projections and one
    scaled dot-product attention layer.

    Queries, keys and values are projected with ``W_q``, ``W_k`` and ``W_v``,
    split into ``h`` heads, attended in parallel, merged back together and
    finally projected to ``d_model`` features with ``W_o``.

    Args:
        h: number of attention heads.
        d_k: units of the query and key projections (split across the heads).
        d_v: units of the value projection (split across the heads).
        d_model: units of the final output projection.
        **kwargs: forwarded to ``Layer.__init__``.

    Raises:
        ValueError: if ``h`` is not positive, or ``d_k`` / ``d_v`` cannot be
            split evenly across ``h`` heads.
    """

    def __init__(self, h, d_k, d_v, d_model, **kwargs):
        super().__init__(**kwargs)
        # Fail early with a clear message; otherwise the same configuration
        # only blows up later inside reshape() with an opaque shape error.
        if h <= 0 or d_k % h != 0 or d_v % h != 0:
            raise ValueError(
                f"d_k ({d_k}) and d_v ({d_v}) must both be divisible by a "
                f"positive number of heads (got h={h})"
            )
        self.attention = DotProductAttention()  # scaled dot-product attention
        self.heads = h  # number of attention heads to use
        self.d_k = d_k  # dimensionality of the linearly projected queries and keys
        self.d_v = d_v  # dimensionality of the linearly projected values
        self.d_model = d_model  # dimensionality of the model
        self.W_q = Dense(d_k)  # learned projection matrix for the queries
        self.W_k = Dense(d_k)  # learned projection matrix for the keys
        self.W_v = Dense(d_v)  # learned projection matrix for the values
        self.W_o = Dense(d_model)  # learned projection matrix for the multi-head output

    def reshape_tensor(self, x, heads, flag):
        """Split a projected tensor into heads, or merge the heads back.

        Args:
            x: rank-3 tensor ``(batch_size, seq_length, features)`` when
                ``flag`` is true; rank-4 tensor
                ``(batch_size, heads, seq_length, depth)`` otherwise.
            heads: number of heads to split into (only used when ``flag`` is true).
            flag: True to split into heads, False to undo the split.

        Returns:
            ``(batch_size, heads, seq_length, features / heads)`` when
            splitting, ``(batch_size, seq_length, d_v)`` when merging.
        """
        if flag:
            # (batch_size, seq_length, features) -> (batch_size, heads, seq_length, -1)
            x = reshape(x, shape=(shape(x)[0], shape(x)[1], heads, -1))
            x = transpose(x, perm=(0, 2, 1, 3))
        else:
            # (batch_size, heads, seq_length, depth) -> (batch_size, seq_length, d_v).
            # The tensor merged here is the attention output, whose width comes
            # from the value projection, so d_v (not d_k) is the right size.
            # An explicit size is used instead of -1 so the last dimension is
            # given as a constant rather than inferred.
            x = transpose(x, perm=(0, 2, 1, 3))
            x = reshape(x, shape=(shape(x)[0], shape(x)[1], self.d_v))
        return x

    def call(self, queries, keys, values, mask=None):
        """Run multi-head attention.

        Args:
            queries: tensor ``(batch_size, query_seq_len, features)``.
            keys: tensor ``(batch_size, key_seq_len, features)``.
            values: tensor ``(batch_size, key_seq_len, features)``.
            mask: optional mask forwarded unchanged to the attention layer.

        Returns:
            Tensor of shape ``(batch_size, query_seq_len, d_model)``.
        """
        # Rearrange the queries to be able to compute all heads in parallel.
        # Resulting shape: (batch_size, heads, query_seq_len, d_k / heads)
        q_reshaped = self.reshape_tensor(self.W_q(queries), self.heads, True)

        # Rearrange the keys the same way.
        # Resulting shape: (batch_size, heads, key_seq_len, d_k / heads)
        k_reshaped = self.reshape_tensor(self.W_k(keys), self.heads, True)

        # Rearrange the values the same way.
        # Resulting shape: (batch_size, heads, key_seq_len, d_v / heads)
        v_reshaped = self.reshape_tensor(self.W_v(values), self.heads, True)

        # Compute the attention output of every head from the reshaped q, k, v.
        # Resulting shape: (batch_size, heads, query_seq_len, d_v / heads)
        # NOTE(review): the scale factor passed here is the full projected
        # width self.d_k, while each head only sees d_k / heads features after
        # the split -- confirm this is the intended scaling.
        o_reshaped = self.attention(q_reshaped, k_reshaped, v_reshaped, self.d_k, mask)

        # Rearrange the output back into concatenated form.
        # Resulting shape: (batch_size, query_seq_len, d_v)
        output = self.reshape_tensor(o_reshaped, self.heads, False)

        # Apply one final linear projection to produce the multi-head attention.
        # Resulting shape: (batch_size, query_seq_len, d_model)
        return self.W_o(output)

### testing using dummy values

In [12]:
from numpy import random

# Settings for a quick smoke test of the multi-head attention layer.
input_seq_len = 5   # maximum length of the input sequence
h = 8               # number of attention heads
d_k = 64            # dimensionality of the linearly projected queries and keys
d_v = 64            # dimensionality of the linearly projected values
d_model = 512       # dimensionality of the model sub-layers' outputs
batch_size = 64     # batch size from the training process

# Random arrays stand in for real query / key / value inputs.
queries = random.random(size=(batch_size, input_seq_len, d_k))
keys = random.random(size=(batch_size, input_seq_len, d_k))
values = random.random(size=(batch_size, input_seq_len, d_v))

# Build the layer and print its output for the dummy batch.
multihead_attention = MultiHeadAttention(h, d_k, d_v, d_model)
print(multihead_attention(queries, keys, values))

tf.Tensor(
[[[-0.0434286   0.02936906  0.05576456 ... -0.56693375 -0.19152977
   -0.21011189]
  [-0.04031079  0.03237865  0.05654873 ... -0.5679413  -0.18873364
   -0.20874096]
  [-0.03539615  0.02926255  0.05997868 ... -0.56486785 -0.1920926
   -0.20644829]
  [-0.04033906  0.03066664  0.0610002  ... -0.5660929  -0.19030315
   -0.21040261]
  [-0.03978018  0.02913226  0.05819973 ... -0.56733584 -0.19054492
   -0.2096827 ]]

 [[-0.0372023  -0.058926    0.12408636 ... -0.51994145 -0.11749138
   -0.14175162]
  [-0.0357124  -0.06040482  0.12341858 ... -0.5191991  -0.11573078
   -0.1413696 ]
  [-0.03682811 -0.05871838  0.12145889 ... -0.5193449  -0.11685009
   -0.14421883]
  [-0.03612378 -0.05992437  0.12749702 ... -0.51628196 -0.11853358
   -0.14406869]
  [-0.03476172 -0.05852611  0.12721835 ... -0.5173383  -0.1182031
   -0.14478958]]

 [[ 0.09031312  0.01680743  0.1951084  ... -0.5136962  -0.18819906
   -0.13954961]
  [ 0.09156879  0.01353643  0.19411293 ... -0.5133157  -0.18917224
   -0.1