### Scaled Dot-Product Attention

In [8]:
from tensorflow import matmul, math, cast, float32 
from tensorflow.keras.layers import Layer
from keras.backend import softmax
from numpy import random

In [3]:
class DotProductAttention(Layer):
    """Scaled dot-product attention layer: softmax(Q K^T / sqrt(d_k) + mask_penalty) V.

    The layer has no weights of its own; queries, keys, values, the key
    dimensionality and an optional mask are all supplied to ``call``.
    """

    def __init__(self, **kwargs):
        # Forward standard Keras Layer kwargs (e.g. name, dtype) to the base class.
        # (Previously misspelled ``__int__``, which defined int() conversion instead.)
        super().__init__(**kwargs)

    def call(self, queries, keys, values, d_k, mask=None):
        """Attend from ``queries`` over ``keys`` and return the weighted ``values``.

        Args:
            queries: query tensor; assumed (..., seq_q, d_k) — TODO confirm with callers.
            keys: key tensor; its last axis must match the last axis of ``queries``
                because it is transposed into the matmul.
            values: value tensor; presumably (..., seq_k, d_v) — TODO confirm.
            d_k: dimensionality of queries/keys, used for the 1/sqrt(d_k) scaling.
            mask: optional tensor/array broadcastable to the score matrix. Entries
                equal to 1 add a -1e9 penalty to that score, so those positions get
                (near) zero attention weight; entries equal to 0 leave scores unchanged.

        Returns:
            ``softmax(scores) @ values``, the attention-weighted sum of the values.
        """
        # Score the queries against the keys (keys transposed on their last two axes)
        # and scale by sqrt(d_k) to keep the softmax out of its saturated regime.
        scores = matmul(queries, keys, transpose_b=True) / math.sqrt(cast(d_k, float32))

        if mask is not None:
            # Cast first so integer/boolean masks (and float64 NumPy masks) can be
            # combined with the floating-point scores without a dtype error.
            scores += -1e9 * cast(mask, scores.dtype)

        # Turn scores into weights (keras.backend.softmax normalises over the last axis by default).
        weights = softmax(scores)

        # Weighted sum of the value vectors.
        return matmul(weights, values)
        

#### Test

In [16]:
# Sizes for the attention smoke test below:
#   d_k        - width of the linearly projected queries and keys
#   d_v        - width of the linearly projected values
#   batch_size - number of sequences per batch, as in training
d_k, d_v, batch_size = 64, 64, 64


In [17]:
input_seq_length = 5  # Maximum length of the input sequence

# Seeded generator: makes the random inputs, and therefore the attention output
# printed below, reproducible under "Restart Kernel -> Run All".
rng = random.default_rng(42)

# Uniform [0, 1) stand-ins for the projected queries, keys and values.
queries = rng.random((batch_size, input_seq_length, d_k))
keys = rng.random((batch_size, input_seq_length, d_k))
values = rng.random((batch_size, input_seq_length, d_v))

In [19]:
attention = DotProductAttention()
attention_output = attention(queries, keys, values, d_k)

# Sanity check: one d_v-dimensional output vector per query position in every batch.
assert tuple(attention_output.shape) == (batch_size, input_seq_length, d_v)

print(attention_output)

tf.Tensor(
[[[0.17039935 0.5660157  0.49454716 ... 0.63606805 0.731504   0.46568075]
  [0.17068729 0.5583192  0.5042453  ... 0.63850045 0.74013174 0.4734642 ]
  [0.17546858 0.5731396  0.49296683 ... 0.6524726  0.72995955 0.47620583]
  [0.17138234 0.5680681  0.490462   ... 0.6407421  0.7280969  0.47297087]
  [0.17335843 0.5582293  0.49005768 ... 0.64072454 0.7280288  0.4742112 ]]

 [[0.39731753 0.40955257 0.4291552  ... 0.37298998 0.5416086  0.7824889 ]
  [0.3965895  0.4205944  0.42082188 ... 0.37335843 0.5425651  0.79109925]
  [0.35964045 0.40663934 0.42679647 ... 0.38441384 0.5283919  0.7844816 ]
  [0.3748286  0.41582245 0.41625062 ... 0.3807641  0.5327348  0.7933047 ]
  [0.35876435 0.39259434 0.39888084 ... 0.3928527  0.5334414  0.8034639 ]]

 [[0.16549972 0.28848526 0.4118323  ... 0.5087974  0.6443383  0.62525296]
  [0.17033522 0.26947588 0.41163826 ... 0.511412   0.6396189  0.6172437 ]
  [0.16947632 0.29201823 0.4199351  ... 0.49825054 0.64014316 0.615296  ]
  [0.1696158  0.2635906