# 15 Implementing Scaled Dot-Product Attention in Keras

```python
from tensorflow.keras.backend import softmax
from tensorflow import cast, float32, math, matmul
from tensorflow.keras.layers import Layer
from numpy import random
```

## 15.1 Recap of the Transformer Architecture

The encoder and the decoder share much of their architecture. At the heart of their numerous, stacked, multi-head attention blocks is the *scaled dot-product attention* mechanism.
In the multi-head attention block of the encoder, the queries, keys and values (stacked row-wise into the matrices $Q$, $K$ and $V$) are all derived from the same source: the embedded, positionally encoded input sequence (see ch. 14). Similarly, on the decoder side, the first attention block takes its queries, keys and values from the embedded, positionally encoded _target_ sequence. The _second_ attention block of the decoder, however, takes its keys and values from the final output of the encoder stack and its queries from the (normalized) output of the decoder's own first attention block. (These queries can be thought of as the decoder output from the "previous time step", but keep in mind that there is no recurrence here: at least during training, the whole sequence is fed to the model at once.)
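
To make the difference between the encoder's self-attention and the decoder's second (encoder-decoder) attention block concrete, here is a minimal NumPy sketch. It assumes a single unbatched sequence, skips the learned linear projections of real multi-head attention, and uses random stand-ins (`encoder_output` and `decoder_hidden`, both hypothetical names) for the actual sequences. The point is the shapes: the attention weights always have one row per query and one column per key, so in the cross-attention case the output length follows the decoder, not the encoder.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row-wise max for numerical stability before exponentiating
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v):
    # Scaled dot-product attention for a single (unbatched) sequence
    d_k = q.shape[-1]
    weights = softmax(q @ k.T / np.sqrt(d_k))
    return weights @ v, weights

rng = np.random.default_rng(0)
d_k = d_v = 8
encoder_output = rng.random((5, d_k))   # 5 source positions (hypothetical stand-in)
decoder_hidden = rng.random((3, d_k))   # 3 target positions (hypothetical stand-in)

# Encoder self-attention: queries, keys and values all come from the same sequence
self_out, self_w = attention(encoder_output, encoder_output, encoder_output)

# Decoder cross-attention: queries from the decoder, keys and values from the encoder
cross_out, cross_w = attention(decoder_hidden, encoder_output, encoder_output)

print(self_out.shape, self_w.shape)    # (5, 8) (5, 5)
print(cross_out.shape, cross_w.shape)  # (3, 8) (3, 5)
```
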
We will denote the dimensionality of queries and keys with $d_k$ and that of values with $d_v$.
First, we compute the matrix product $QK^T$, whose entries are the dot products between every query vector and every key vector. Then we divide the result by $\sqrt{d_k}$ to get the scaled _attention scores_. Feeding these scores to the $softmax$ function, applied row-wise (over the keys), yields the _attention weights_. Finally, we multiply the weights by $V$, so that each output vector is a weighted sum of the value vectors.
$$attention(Q, K, V) = softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
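
The four steps map one-to-one onto the formula. The following NumPy sketch (the function name `scaled_dot_product_attention` is illustrative, not taken from any library) implements them for inputs of shape `(..., sequence length, dimension)` and runs a tiny example that can be checked by hand.

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Return (output, weights) for inputs of shape (..., seq, dim)."""
    d_k = Q.shape[-1]
    # Step 1: dot products of every query with every key -> (..., seq_q, seq_k)
    scores = Q @ np.swapaxes(K, -1, -2)
    # Step 2: divide by sqrt(d_k) to obtain the scaled attention scores
    scores = scores / np.sqrt(d_k)
    # Step 3: softmax over the key axis turns each row into attention weights
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Step 4: each output row is a weighted sum of the value vectors
    return weights @ V, weights

# Tiny worked example: one query, two keys, two values (d_k = d_v = 2)
Q = np.array([[1.0, 0.0]])
K = np.array([[1.0, 0.0],
              [0.0, 1.0]])
V = np.array([[10.0, 0.0],
              [0.0, 10.0]])
out, w = scaled_dot_product_attention(Q, K, V)
print(w)    # roughly [[0.67 0.33]]: the query is more similar to the first key
print(out)  # roughly [[6.7 3.3]]
```

Here $QK^T = [1, 0]$, which becomes $[0.707, 0]$ after scaling; the softmax turns this into weights of roughly $0.67$ and $0.33$, so the output is roughly two thirds of the first value vector plus one third of the second.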

A "mask" can optionally be applied to the attention scores before they are fed to the $softmax$ function. Two typical applications are:

- A "look-ahead mask" (as used in the first attention block of the decoder) prevents the model from, you guessed it, "looking ahead" and attending to succeeding tokens in the target sequence, i.e. positions for which it has not yet output a prediction.
- A "padding mask" prevents the padding tokens (often with id 0) from being attended to alongside the meaningful tokens, in both the encoder and the decoder.

Masking works by replacing the attention scores at masked positions with $-\infty$, so that the $softmax$ assigns them zero weight. When the mask is applied additively (mask value 1 at positions to hide, 0 elsewhere), it is usually multiplied by a large negative number such as $-10^9$ rather than by $-\infty$, because $-\infty \cdot 0$ evaluates to NaN at the unmasked positions.
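
Both masks can be sketched in NumPy as follows. The sketch assumes the additive convention above (1 marks a position to hide, 0 a position to keep), a hypothetical token-id sequence whose last entry is padding, and equal raw scores so that the effect of each mask is easy to read off. The look-ahead mask is strictly upper triangular (`np.triu` with `k=1`), so query position $i$ may only attend to key positions $0, \dots, i$; the padding mask is a single row that broadcasts over all query positions.

```python
import numpy as np

def softmax(x):
    # Row-wise, numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

seq_len = 4
# Look-ahead mask: 1 strictly above the diagonal marks "future" key positions
look_ahead_mask = np.triu(np.ones((seq_len, seq_len)), k=1)

# Padding mask: 1 wherever the (hypothetical) token id equals the padding id 0
token_ids = np.array([7, 3, 9, 0])
padding_mask = (token_ids == 0).astype(float)[np.newaxis, :]  # shape (1, 4)

scores = np.zeros((seq_len, seq_len))  # equal raw scores for every query/key pair

# Add a large negative number at the masked positions before the softmax
weights_la = softmax(scores + -1e9 * look_ahead_mask)
weights_pad = softmax(scores + -1e9 * padding_mask)

print(np.round(weights_la, 2))   # row i spreads its weight evenly over keys 0..i
print(np.round(weights_pad, 2))  # every row ignores the final (padding) key
```
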

## 15.2 Implementing the Scaled Dot-Product Attention from Scratch

```python
# Implementing the Scaled Dot-Product Attention
class DotProductAttention(Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, queries, keys, values, d_k, mask=None):
        # Score the queries against the keys (transposing the latter), then scale by sqrt(d_k)
        scores = matmul(queries, keys, transpose_b=True) / math.sqrt(cast(d_k, float32))
        # Apply the mask (1 = masked, 0 = kept) by pushing masked scores to a large negative value;
        # multiplying by float("-inf") instead would turn the unmasked (0) entries into NaN
        if mask is not None:
            scores += -1e9 * mask
        # Compute the attention weights with a softmax over the keys (last axis)
        weights = softmax(scores)
        # Compute the output as a weighted sum of the value vectors
        return matmul(weights, values)
```
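
Why multiply the mask by a large finite negative number rather than by $-\infty$? Under IEEE 754 floating-point arithmetic, which TensorFlow tensors follow as well, $-\infty \cdot 0$ is NaN, so multiplying a 0/1 mask by `float("-inf")` poisons every *unmasked* score. This short NumPy check, with toy scores and a toy mask chosen purely for illustration, shows the difference.

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over a 1-D array of scores
    e = np.exp(x - x.max())
    return e / e.sum()

scores = np.array([2.0, 1.0, 3.0])   # toy attention scores
mask = np.array([0.0, 0.0, 1.0])     # 1 = masked, 0 = kept

# Multiplying by -inf: -inf * 0 is NaN, so the *kept* positions become NaN
with np.errstate(invalid="ignore"):
    inf_version = scores + float("-inf") * mask
print(inf_version)   # NaN at both kept positions, -inf at the masked one

# Multiplying by a large finite negative number leaves the kept scores untouched
big_version = scores + -1e9 * mask
weights = softmax(big_version)
print(weights)       # roughly 0.73 and 0.27 for the kept positions, exactly 0 for the masked one
```

With $-10^9$, the masked score underflows to exactly zero weight after exponentiation, while the remaining weights still sum to one.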

## 15.3 Testing Out the Code

```python
d_k = 64  # Dimensionality of the linearly projected queries and keys
d_v = 64  # Dimensionality of the linearly projected values
batch_size = 64  # Batch size from the training process

# Dummy data follows...
# In reality, these would be obtained from the tokenized and then embedded sequences.
input_seq_length = 5  # Maximum length of the input sequence
queries = random.random((batch_size, input_seq_length, d_k))
keys = random.random((batch_size, input_seq_length, d_k))
values = random.random((batch_size, input_seq_length, d_v))

attention = DotProductAttention()
print(attention(queries, keys, values, d_k))
```

```
tf.Tensor(
[[[0.37687194 0.42425776 0.40963838 ... 0.5839578  0.32307482 0.7153933 ]
  [0.39388978 0.3984782  0.42378533 ... 0.5752644  0.31323692 0.7163614 ]
  [0.37789828 0.41287902 0.41379955 ... 0.5935992  0.31470624 0.72527605]
  [0.37479573 0.4177801  0.42555138 ... 0.5980845  0.31870538 0.7384004 ]
  [0.38047948 0.42033958 0.43337804 ... 0.5881334  0.31998193 0.7339838 ]]

 [[0.66044503 0.47343627 0.3590861  ... 0.61529505 0.8354962  0.47700217]
  [0.65794665 0.49121913 0.34872597 ... 0.6100477  0.8346374  0.4796377 ]
  [0.6600553  0.48846138 0.324326   ... 0.59198326 0.838455   0.46564752]
  [0.66188097 0.47298568 0.35112163 ... 0.6091814  0.83281344 0.48945424]
  [0.6658252  0.46457112 0.32625848 ... 0.59076834 0.83756536 0.47890344]]

 [[0.6367261  0.7159094  0.6269977  ... 0.3774765  0.2877917  0.56857604]
  [0.6354828  0.7148559  0.62548065 ... 0.38535112 0.2851437  0.5614736 ]
  [0.64709723 0.7224838  0.60133654 ... 0.39898255 0.28389564 0.57082427]
  ...
```

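**Note:** The output shape is `(batch_size, input_seq_length, d_v)`, here `(64, 5, 64)`: one $d_v$-dimensional output vector per query position in each sequence of the batch.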
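
As a sanity check on the shape and on the range of the printed values, the following NumPy sketch repeats the test with the same dimensions. The seed and the use of NumPy's `default_rng` are arbitrary choices, so the numbers will not match the printout above. Because the attention weights are non-negative and sum to one along the key axis, each output vector is a convex combination of the value vectors of its batch element; with values drawn uniformly from $[0, 1)$, every output entry therefore also lies in $[0, 1)$, consistent with the printed tensor.

```python
import numpy as np

batch_size, input_seq_length, d_k, d_v = 64, 5, 64, 64
rng = np.random.default_rng(42)  # arbitrary seed, for reproducibility only
queries = rng.random((batch_size, input_seq_length, d_k))
keys = rng.random((batch_size, input_seq_length, d_k))
values = rng.random((batch_size, input_seq_length, d_v))

# Scaled scores, softmax over the key axis, then weighted sum of the values
scores = queries @ np.swapaxes(keys, -1, -2) / np.sqrt(d_k)   # (64, 5, 5)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
output = weights @ values                                     # (64, 5, 64)

print(output.shape)  # (64, 5, 64)

# Each output entry lies between the min and max of its value column (per batch element)
lo = values.min(axis=1, keepdims=True)   # (64, 1, 64)
hi = values.max(axis=1, keepdims=True)   # (64, 1, 64)
print(bool(np.all((output >= lo - 1e-12) & (output <= hi + 1e-12))))  # True
```
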