<a href="https://colab.research.google.com/github/JSJeong-me/AI-Innovation-2024/blob/main/Transformer/5-1-Scaled-Dot-Product-Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from tensorflow import cast, float32, math, matmul, shape
from tensorflow.keras.activations import softmax # Import softmax from tensorflow.keras.activations
from tensorflow.keras.layers import Layer

# Implementing the Scaled-Dot Product Attention
class DotProductAttention(Layer):
    """Scaled dot-product attention layer.

    Computes ``softmax(Q K^T / sqrt(d_k) + mask_bias) V`` where the softmax is
    taken over the last axis of the score matrix.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, queries, keys, values, d_k=None, mask=None):
        """Compute scaled dot-product attention.

        Args:
            queries: Query tensor; assumed (..., seq_q, d_k) -- its last axis is
                contracted against the last axis of ``keys``.
            keys: Key tensor; assumed (..., seq_k, d_k). Its last two axes are
                transposed before the product (``transpose_b=True``).
            values: Value tensor; assumed (..., seq_k, d_v).
            d_k: Dimensionality used for the ``1 / sqrt(d_k)`` scaling. When
                ``None`` it is read from the last axis of ``queries`` at run time.
            mask: Optional tensor broadcastable to the scores. It is scaled by a
                large negative constant and added to the scores, so entries equal
                to 1 receive (near) zero attention weight. Numeric and bool masks
                are both accepted.

        Returns:
            The attention-weighted sum of ``values``.
        """
        # Score every query against every key (keys transposed on their last two axes).
        scores = matmul(queries, keys, transpose_b=True)

        if d_k is None:
            # Fall back to the queries' feature size when the caller does not pass d_k.
            d_k = shape(queries)[-1]

        # Scale in the scores' own dtype: a hard-coded float32 divisor breaks
        # float16 (mixed precision) and float64 inputs with a dtype mismatch.
        scores = scores / math.sqrt(cast(d_k, scores.dtype))

        if mask is not None:
            # -1e9 overflows to -inf in float16, and -inf * 0 yields NaN at the
            # unmasked positions; clamp the penalty to the dtype's finite minimum.
            large_negative = max(-1e9, scores.dtype.min)
            scores += large_negative * cast(mask, scores.dtype)

        # Normalise the scores into attention weights over the key axis.
        weights = softmax(scores)

        # Attention output: weighted sum of the value vectors.
        return matmul(weights, values)

In [2]:
from numpy import random

input_seq_length = 5  # Maximum length of the input sequence
d_k = 64  # Dimensionality of the linearly projected queries and keys
d_v = 64  # Dimensionality of the linearly projected values
batch_size = 64  # Batch size from the training process
SEED = 42  # Fixed seed so Restart & Run All reproduces the printed output

# Seeded generator: the unseeded global random state made every run print different numbers.
rng = random.default_rng(SEED)
queries = rng.random((batch_size, input_seq_length, d_k))
keys = rng.random((batch_size, input_seq_length, d_k))
values = rng.random((batch_size, input_seq_length, d_v))

attention = DotProductAttention()
# Pass d_k as a keyword argument
output = attention(queries, keys, values, d_k=d_k)

# One d_v-sized attended vector per query position in every batch element.
assert tuple(output.shape) == (batch_size, input_seq_length, d_v)
print(output)

tf.Tensor(
[[[0.406152   0.6481798  0.61734873 ... 0.3330139  0.62651014 0.47233817]
  [0.41655076 0.6454326  0.61695933 ... 0.347259   0.62359726 0.4805547 ]
  [0.41623163 0.6274299  0.6196872  ... 0.32457232 0.6173672  0.48080087]
  [0.40989777 0.6294844  0.61528504 ... 0.32892424 0.6237523  0.47626266]
  [0.42396507 0.6435493  0.6124284  ... 0.3489979  0.62224984 0.4792506 ]]

 [[0.6930096  0.6724482  0.6079855  ... 0.4888395  0.33188948 0.57895654]
  [0.69275737 0.6841609  0.6082085  ... 0.47907    0.33110777 0.58557   ]
  [0.68699646 0.6717293  0.6222161  ... 0.49429387 0.3361022  0.57459044]
  [0.6810884  0.67848057 0.62003136 ... 0.4848075  0.35164532 0.5719061 ]
  [0.6803761  0.6697084  0.62488127 ... 0.49267438 0.35136554 0.56842065]]

 [[0.7187425  0.59889805 0.4584425  ... 0.5093596  0.24952105 0.52193904]
  [0.7213501  0.5821472  0.4867586  ... 0.5209122  0.2686088  0.5374378 ]
  [0.7069361  0.6041473  0.46155518 ... 0.5231636  0.2664616  0.567302  ]
  [0.72728294 0.603428 