<a href="https://colab.research.google.com/github/StanleyLiangYork/2024_journal_club_Transformer_AI/blob/main/Scaled_dot_product_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook implements scaled dot-product attention, a building block of the transformer.

The transformer architecture follows an encoder-decoder structure.

*   Encoder: maps an input sequence to a sequence of continuous representations.
*   Decoder: receives the output of the encoder together with its own output from the previous time step, then generates the output sequence.
*   Unlike RNN-based models, the transformer relies on neither recurrence nor convolution.

Both the encoder and the decoder rely on multi-head attention, which is built on scaled dot-product attention, to capture patterns across long sequences.

The scaled dot-product attention receives queries (Q), keys (K), and values (V) as inputs. It first computes the dot product of the queries with the keys. The result is then divided by the square root of $d_{k}$, the dimensionality of the keys (the number of columns of K), to get the attention scores.
The attention scores are fed into a softmax function to get a set of attention weights. The output is a weighted sum of the values (V), with the attention weights as coefficients. The whole process can be expressed as the attention function defined below:

$\text{attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_{k}}}\right)V$
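
As a minimal sketch of this computation, the NumPy snippet below works through a tiny example: score the queries against the keys, divide by $\sqrt{d_k}$, apply softmax over each row, and multiply the resulting weights by V. The helper names `softmax_rows` and `scaled_dot_product_attention` and the toy inputs are illustrative assumptions, not part of the notebook's code.

```python
import numpy as np

def softmax_rows(x):
    """Numerically stable softmax over the last axis."""
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)

def scaled_dot_product_attention(q, k, v):
    """Return (output, weights): softmax of the scaled scores, then times v."""
    d_k = k.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d_k)
    weights = softmax_rows(scores)
    return weights @ v, weights

# Toy inputs (hypothetical values): 3 positions, d_k = 4, d_v = 2
rng = np.random.default_rng(0)
q = rng.random((3, 4))
k = rng.random((3, 4))
v = rng.random((3, 2))

output, weights = scaled_dot_product_attention(q, k, v)
print(weights.sum(axis=-1))  # each row of weights sums to 1
print(output.shape)          # (3, 2): one d_v-dimensional vector per query
```

Note that the softmax is applied to the scaled scores only; the multiplication by V happens afterwards, so every output row is a convex combination of the rows of V.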

Note that each multi-head attention block applies the scaled dot-product attention operation in every head; the head outputs are then concatenated and linearly projected.
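
A rough sketch of how the heads fit together, assuming the usual arrangement in which each head runs scaled dot-product attention on its own slice of the projected inputs and the results are concatenated and passed through an output projection. The sizes, the random projection matrices (`w_q`, `w_k`, `w_v`, `w_o`), and the helper names are hypothetical stand-ins for learned parameters, not code from this notebook.

```python
import numpy as np

def softmax_rows(x):
    # Numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention(q, k, v):
    # Scaled dot-product attention; works on batched (head, seq, dim) arrays
    d_k = k.shape[-1]
    return softmax_rows(q @ np.swapaxes(k, -1, -2) / np.sqrt(d_k)) @ v

# Hypothetical sizes: 5 positions, model width 8, 2 heads
seq_len, d_model, num_heads = 5, 8, 2
d_head = d_model // num_heads

rng = np.random.default_rng(0)
x = rng.random((seq_len, d_model))

# Random stand-ins for the learned projection matrices
w_q, w_k, w_v, w_o = (rng.random((d_model, d_model)) for _ in range(4))

def split_heads(t):
    # (seq_len, d_model) -> (num_heads, seq_len, d_head)
    return t.reshape(seq_len, num_heads, d_head).transpose(1, 0, 2)

# One scaled dot-product attention per head, computed in a single batched call
heads = attention(split_heads(x @ w_q), split_heads(x @ w_k), split_heads(x @ w_v))

# Concatenate the heads back to (seq_len, d_model), then apply the output projection
concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
multi_head_output = concat @ w_o
print(multi_head_output.shape)  # (5, 8)
```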

The scaled dot-product attention in the decoder also applies a mask to the attention scores before feeding them into the softmax function. The mask prevents each position from attending to later positions in the sequence.
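
The effect of such a mask can be sketched in NumPy. The example assumes the convention used by the layer in this notebook, where a mask entry of 1 marks a position to hide and the masked scores are computed as `scores + (-1e9 * mask)`. The sequence length and the random scores are hypothetical.

```python
import numpy as np

seq_len = 4
# 1 marks a position to hide: entry (i, j) is 1 when j > i (a later position)
mask = np.triu(np.ones((seq_len, seq_len)), k=1)

rng = np.random.default_rng(0)
scores = rng.random((seq_len, seq_len))  # stand-in for the scaled scores

# Hidden positions receive a very large negative score
masked_scores = scores + (-1e9 * mask)

# Softmax over the last axis
exp = np.exp(masked_scores - masked_scores.max(axis=-1, keepdims=True))
weights = exp / exp.sum(axis=-1, keepdims=True)

print(np.round(weights, 3))  # entries above the diagonal are 0
```

Because the exponential of a very large negative number underflows to zero, the masked positions receive no weight, and each row still sums to 1 over the positions that remain visible.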


In [1]:
from numpy import random
from tensorflow import matmul, math, cast, float32
from tensorflow.keras.layers import Layer
from tensorflow.keras.backend import softmax

In [2]:
input_seq_length = 5  # Maximum length of the input sequence
d_k = 64  # Dimensionality of the linearly projected queries and keys
d_v = 64  # Dimensionality of the linearly projected values
batch_size = 64  # Batch size from the training process
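
The value of `d_k` chosen here also sets the scaling factor $\sqrt{d_k}$. A small sketch of why the scaling matters, under the assumption that query and key components are independent with zero mean and unit variance: the unscaled dot product then has a standard deviation of about $\sqrt{d_k}$, which would push the softmax toward very peaked outputs, while the scaled version stays near 1. The sample size `n` and the normally distributed inputs are illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k = 64
n = 10_000

# Hypothetical queries and keys with zero-mean, unit-variance components
q = rng.standard_normal((n, d_k))
k = rng.standard_normal((n, d_k))

raw = np.sum(q * k, axis=-1)   # unscaled dot products, one per (query, key) pair
scaled = raw / np.sqrt(d_k)    # divided by sqrt(d_k), as in the attention scores

print(raw.std())     # close to sqrt(d_k) = 8
print(scaled.std())  # close to 1
```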

In [3]:
# Implementing the scaled dot-product attention
# Inherit from the TensorFlow Keras Layer class

class DotProductAttention(Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, queries, keys, values, d_k, mask=None):
        # Scoring the queries against the keys after transposing the latter, and scaling
        scores = matmul(queries, keys, transpose_b=True) / math.sqrt(cast(d_k, float32))

        # Apply the mask: entries equal to 1 get a large negative score,
        # so the softmax assigns them a weight close to 0
        if mask is not None:
            scores += -1e9 * mask

        # Computing the weights by a softmax operation
        weights = softmax(scores)

        # Computing the attention by a weighted sum of the value vectors
        return matmul(weights, values)

In [4]:
queries = random.random((batch_size, input_seq_length, d_k))
keys = random.random((batch_size, input_seq_length, d_k))
values = random.random((batch_size, input_seq_length, d_v))

In [5]:
# Create the dot-product attention layer
attention = DotProductAttention()

In [7]:
# The layer returns the attention output (weights applied to the values), not the raw scores
attention_output = attention(queries, keys, values, d_k)
print(attention_output.shape)
print(attention_output)

(64, 5, 64)
tf.Tensor(
[[[0.36580184 0.4946162  0.396619   ... 0.574539   0.3585462  0.6239938 ]
  [0.36968642 0.4919685  0.3965832  ... 0.5787249  0.3551683  0.6248512 ]
  [0.3598659  0.4799033  0.39663547 ... 0.5955414  0.3411876  0.60285866]
  [0.38125867 0.5037129  0.40927058 ... 0.58126813 0.34202945 0.6259441 ]
  [0.3876143  0.5160778  0.40697813 ... 0.559788   0.3605237  0.6431714 ]]

 [[0.3741589  0.5075558  0.5138777  ... 0.72573763 0.5537499  0.4150493 ]
  [0.39473665 0.5120391  0.5254037  ... 0.7333318  0.5548649  0.4134689 ]
  [0.37159356 0.49106467 0.5198252  ... 0.71670765 0.5640816  0.42884263]
  [0.39580414 0.5126434  0.5281597  ... 0.7302247  0.56063825 0.41995424]
  [0.37776154 0.50595164 0.51373225 ... 0.72351795 0.55618274 0.42122775]]

 [[0.540817   0.43442604 0.5778113  ... 0.40964317 0.49428675 0.34082535]
  [0.5585393  0.44039524 0.60080713 ... 0.40660763 0.47136644 0.33223078]
  [0.5562529  0.44410172 0.5933695  ... 0.41243148 0.48775625 0.33052337]
  [0.574377