In [8]:
from tensorflow import matmul,cast,float32,math
from tensorflow.keras.layers import Layer
from tensorflow.keras.backend import softmax

In [9]:
class DotProductAttention(Layer):
    """Scaled dot-product attention packaged as a Keras layer.

    Computes ``softmax(Q @ K^T / sqrt(d_k) + penalty) @ V`` where the
    optional penalty is ``-1e9 * mask``. The layer defines no weights of
    its own; it only combines the tensors handed to ``call``.
    """

    def __init__(self, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)

    def call(self, queries, keys, values, d_k, mask=None):
        """Apply attention of ``queries`` over ``keys`` to ``values``.

        Args:
            queries: Query tensor; multiplied with the transpose of ``keys``
                over the last two axes.
                (assumed to be (..., seq_q, d_k) -- TODO confirm with callers)
            keys: Key tensor, same trailing dimension as ``queries``.
            values: Value tensor that is mixed according to the attention
                weights.
            d_k: Scaling dimensionality; scores are divided by ``sqrt(d_k)``.
            mask: Optional tensor broadcastable to the score matrix. Non-zero
                entries receive a ``-1e9`` penalty before the softmax, so
                those positions get (numerically) zero weight. May be
                boolean, integer or float.

        Returns:
            The product of the softmax-normalised scores with ``values``.
        """
        # Raw similarity, scaled by sqrt(d_k).
        scores = matmul(queries, keys, transpose_b=True) / math.sqrt(cast(d_k, float32))

        if mask is not None:
            # Cast first: multiplying -1e9 by a bool/int mask (or adding a
            # tensor of a different float width) raises a dtype error in
            # TensorFlow. For a mask already in scores.dtype this is a no-op.
            scores += -1e9 * cast(mask, scores.dtype)

        # softmax normalises over the last axis (the keys axis) by default.
        weights = softmax(scores)

        return matmul(weights, values)

### Testing the layer with dummy values

In [10]:

from numpy import random

input_seq_length = 5  # Maximum length of the input sequence
d_k = 64  # Dimensionality of the linearly projected queries and keys
d_v = 64  # Dimensionality of the linearly projected values
batch_size = 64  # Batch size from the training process

# Seeded generator: the dummy inputs (and therefore the printed result)
# are identical on every Restart Kernel -> Run All.
rng = random.default_rng(42)

# Uniform [0, 1) stand-ins for the projected queries, keys and values.
queries = rng.random((batch_size, input_seq_length, d_k))
keys = rng.random((batch_size, input_seq_length, d_k))
values = rng.random((batch_size, input_seq_length, d_v))

attention = DotProductAttention()
output = attention(queries, keys, values, d_k)

# Sanity check: one d_v-sized vector per query position per batch element.
assert tuple(output.shape) == (batch_size, input_seq_length, d_v)
print(output)

tf.Tensor(
[[[0.43439865 0.57446873 0.5006186  ... 0.32303673 0.46497482 0.501368  ]
  [0.44587904 0.58248234 0.46138436 ... 0.3071723  0.46704707 0.48796767]
  [0.44831926 0.5751674  0.46460605 ... 0.31604692 0.4724028  0.48221013]
  [0.43756688 0.57769686 0.47251344 ... 0.31773275 0.46963233 0.4862868 ]
  [0.46419376 0.5821735  0.4840212  ... 0.30262208 0.45929372 0.5048124 ]]

 [[0.36041713 0.40153107 0.55443263 ... 0.60989    0.5205752  0.5911077 ]
  [0.35975575 0.3907556  0.54939795 ... 0.6028986  0.5025436  0.61038756]
  [0.36530334 0.40771726 0.56081146 ... 0.6109898  0.5102244  0.5901164 ]
  [0.3749388  0.39237887 0.55807114 ... 0.6059265  0.49486136 0.60034543]
  [0.36788118 0.38355082 0.547907   ... 0.60731846 0.52602214 0.5921198 ]]

 [[0.53434294 0.7053262  0.50499725 ... 0.40913734 0.5711663  0.5944548 ]
  [0.52527404 0.70513326 0.52652967 ... 0.39810777 0.5990136  0.62535447]
  [0.5188073  0.70701504 0.5384836  ... 0.40575576 0.6082389  0.6280854 ]
  [0.54295874 0.7033777