# Scaled Dot Product Attention

> Exploration of scaled dot product attention.

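The main idea of scaled dot product attention is to take a set of queries ($Q$) and compare them with a set of keys ($K$) by computing the *dot product* between them. Intuitively, we take a query, $q_i$, and compare it with every key, $k_j$, to find the keys most similar to it. The similarity between $q_i$ and $k_j$ can be written as follows:
$$
A_{ij} = q_i \cdot k_j = \sum_{d} q_{i,d}\, k_{j,d}
$$

where $A_{ij}$ is the similarity score between the query $q_i$ and the key $k_j$, and $d$ runs over the components of the vectors. Collecting the scores of $q_i$ against all the keys gives the row vector $A_i$.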
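Written out as code, the sum over $d$ is just a loop over the vector components. A minimal sketch, using the same query and keys as the cells below; the helper name `dot_loop` is our own and not from any library:

```python
import numpy as np

def dot_loop(q, k):
    """Explicit-loop dot product: sum over d of q[d] * k[d]."""
    assert len(q) == len(k), "query and key must have the same dimension"
    total = 0.0
    for d in range(len(q)):
        total += q[d] * k[d]
    return total

q_i = np.linspace(1, 6, num=6)          # [1, 2, 3, 4, 5, 6]
k_1 = np.array([1, 0, -1, 2, 7, 4])
k_2 = np.array([1, 2, 3, 4, 7, 6])

print(dot_loop(q_i, k_1), dot_loop(q_i, k_2))  # → 65.0 101.0
```

The larger score for $k_2$ matches the intuition that it points in a direction closer to $q_i$.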

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from einops import rearrange, repeat

Define a sample $q_i$:

In [None]:
q_i = np.linspace(1,6, num=6)
q_i

And a set of keys:

In [None]:
k_1 = np.array([1,0,-1,2,7,4])
k_2 = np.array([1,2,3,4,7,6])

k_1, k_2

Now we can calculate how close $q_i$ is to $k_1$ and $k_2$. We would expect it to be much closer to $k_2$ than to $k_1$:

> We're going to use the explicit sum definition of the dot product to stay consistent with the math, and then we'll switch to vector and matrix multiplications.

In [None]:
def dot_prod(v1, v2): return np.sum(v1*v2)

In [None]:
dot_prod(q_i, k_1), dot_prod(q_i, k_2)

By stacking all the keys into a matrix we obtain $K$, where each row corresponds to one key:

In [None]:
K = np.concatenate([k_1[None,:], k_2[None,:]], axis=0)
K

This allows us to calculate the similarity of $q_i$ to all the $k_j$ at once with a single matrix multiplication:

In [None]:
# We're adding an empty dim so that numpy treats q_i as a row vector
A_i = q_i[None,:] @ K.T
A_i, A_i.shape

The same operation can be expressed using Einstein summation:

In [None]:
# d indexes the vector components (summed over); j indexes the keys
np.einsum("d,jd->j", q_i, K)
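The same Einstein notation extends naturally to batches of query and key matrices, which is the kind of input the Keras layer at the end of this notebook receives. A sketch assuming inputs of shape `(batch, seq_len, d)`; the subscript letters are arbitrary choices:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_q, n_k, d = 2, 3, 4, 6
Q = rng.normal(size=(batch, n_q, d))   # a batch of query matrices
K = rng.normal(size=(batch, n_k, d))   # a batch of key matrices

# b: batch, q: query index, k: key index; the repeated d is summed over.
A = np.einsum("bqd,bkd->bqk", Q, K)
print(A.shape)                                    # → (2, 3, 4)
print(np.allclose(A[0], Q[0] @ K[0].T))           # → True
```

Each slice `A[b]` is exactly the unbatched score matrix for batch element `b`.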

Furthermore, if we stack a set of queries into a matrix $Q$ (as we did with the keys), we can calculate the similarity of every query with every key in a single operation, obtaining the attention matrix $A$:

In [None]:
q_1 = np.linspace(1,6, num=6)
q_2 = np.random.randint(-1, 7, size=6)
Q = np.concatenate([q_1[None,:], q_2[None,:]], axis=0)
Q

In [None]:
A = Q @ K.T
A

In this matrix, the element $A_{ij}$ represents how similar $q_i$ is to $k_j$.

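After obtaining this matrix, we use it to weight the values, $v_j$: each output row is $O_i = \sum_j A_{ij} v_j$, or $O = AV$ in matrix form. Before doing so, so that the rows of $A$ can be interpreted as weights (non-negative and summing to 1), we apply the softmax row-wise. When the vectors have high dimensionality, the dot products tend to grow large in magnitude, which pushes the softmax into regions where its gradients are extremely small. To counteract this, we first divide $A$ by $\sqrt{d_k}$, the square root of the query/key dimension; this is the **scaling** in the name:
$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V
$$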
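To see why the $\sqrt{d_k}$ factor helps, assume (purely as an illustration) that query and key components are independent with zero mean and unit variance. Then $q \cdot k$ has variance $d_k$, so raw scores spread out as the dimension grows, while dividing by $\sqrt{d_k}$ brings the variance back to about 1. A quick empirical check under that assumption:

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 5_000
results = {}

for d_k in (6, 64, 512):
    q = rng.standard_normal((n_samples, d_k))   # unit-variance query components
    k = rng.standard_normal((n_samples, d_k))   # unit-variance key components
    scores = np.sum(q * k, axis=1)              # one dot product per sample
    scaled = scores / np.sqrt(d_k)
    results[d_k] = (scores.var(), scaled.var())
    print(f"d_k={d_k:4d}  var(raw)={scores.var():8.2f}  var(scaled)={scaled.var():.2f}")
```

Without scaling, scores of order $\sqrt{d_k}$ in magnitude make the softmax output nearly one-hot, which is where its gradients vanish.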

In [None]:
d_k = K.shape[-1]  # dimension of the queries and keys (6 here)
A = A / np.sqrt(d_k)
A

In [None]:
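# Row-wise softmax; subtracting each row's max avoids overflow in np.exp
A = np.exp(A - A.max(axis=1, keepdims=True))
A = A / A.sum(axis=1, keepdims=True)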
A
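A common way to compute the row-wise softmax safely is to subtract each row's maximum before exponentiating. This does not change the result, because softmax is invariant to adding the same constant to every element of a row, but it keeps `np.exp` from overflowing. A small sketch comparing a naive version with the stable one (both helpers are our own, for illustration):

```python
import numpy as np

def softmax_naive(x):
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def softmax_stable(x):
    # Shift each row so its maximum is 0; exp of the shifted values is at most 1.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

small = np.array([[1.0, 2.0, 3.0]])
large = small + 1000.0                  # same row, shifted by a constant

print(np.allclose(softmax_stable(small), softmax_stable(large)))  # → True
with np.errstate(over="ignore", invalid="ignore"):
    print(softmax_naive(large))         # exp(1001) overflows to inf, inf/inf → [[nan nan nan]]
```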

Define some value vectors:

In [None]:
v_1 = np.random.randint(-1, 7, size=6)
v_2 = np.random.randint(-1, 7, size=6)
V = np.concatenate([v_1[None,:], v_2[None,:]], axis=0)
V

And finally calculate the output:

In [None]:
output = A @ V
output.shape

In [None]:
output

This output matrix can be understood as a new $V$ matrix: its row $i$ is a weighted sum of the original value vectors $v_j$, where each weight $A_{ij}$ depends on how similar the query $q_i$ is to the key $k_j$.
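Putting the steps together, the whole computation is $\mathrm{softmax}(QK^\top/\sqrt{d_k})\,V$. Below is a compact NumPy sketch; the function name and the `return_attn` argument are our own choices (the latter mirrors the Keras layer below). It also checks that each output row really is $\sum_j A_{ij} v_j$:

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V, return_attn=False):
    """Q: (n_q, d_k), K: (n_k, d_k), V: (n_k, d_v) -> output (n_q, d_v)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                          # (n_q, n_k)
    scores = scores - scores.max(axis=-1, keepdims=True)     # numerical stability
    A = np.exp(scores)
    A = A / A.sum(axis=-1, keepdims=True)                    # row-wise softmax
    output = A @ V                                           # (n_q, d_v)
    return (A, output) if return_attn else output

rng = np.random.default_rng(0)
Q = rng.normal(size=(2, 6))
K = rng.normal(size=(3, 6))
V = rng.normal(size=(3, 4))

A, out = scaled_dot_product_attention(Q, K, V, return_attn=True)

# Row i of the output is the A-weighted sum of the value vectors v_j.
manual = np.stack([sum(A[i, j] * V[j] for j in range(K.shape[0]))
                   for i in range(Q.shape[0])])
print(out.shape, np.allclose(out, manual))   # → (2, 4) True
```

A useful sanity check: if all keys are identical, every query attends uniformly and each output row is simply the mean of the value vectors.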

## Implementing it as a *Keras* `Layer`

> Now that we have an intuition for how scaled dot product attention works, we can build a *Keras* layer to use it in our models.

In [None]:
#| hide
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

In [None]:
#| export
import tensorflow as tf
from tensorflow.keras import layers

In [None]:
#| export
class ScaledDotProductSA(layers.Layer):
    """Scaled dot product self-attention layer."""

    def __init__(self,
                 dk, # Dim of the queries and keys.
                 dv, # Dim of the values.
                 return_attn=False, # Whether to return the attention matrix or not.
                 **kwargs,
                 ):
        super(ScaledDotProductSA, self).__init__(**kwargs)
        self.dk = dk
        self.dv = dv
        self.return_attn = return_attn
    
    def build(self,
              input_shape,
              ):
        self.Q = layers.Dense(self.dk)
        self.K = layers.Dense(self.dk)
        self.V = layers.Dense(self.dv)
    
    def call(self,
             inputs,
             **kwargs,
             ):
        ## 1. Project the input sequence into Q, K and V.
        Q, K, V = self.Q(inputs), self.K(inputs), self.V(inputs)

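        ## 2. Dot product between queries and keys.
        # Inputs are batched, so transpose only the last two axes of K.
        A = tf.matmul(Q, K, transpose_b=True)

        ## 3. Scale A and apply softmax row-wise.
        # tf.math.sqrt needs a float, so cast dk to A's dtype first.
        A = tf.nn.softmax(A / tf.math.sqrt(tf.cast(self.dk, A.dtype)), axis=-1)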

        ## 4. Use A to weight the values V.
        output = A @ V

        if self.return_attn: return A, output
        else: return output
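Inside the layer, `inputs` typically has shape `(batch, seq_len, features)`, so $QK^\top$ must transpose only the last two axes of $K$. The NumPy sketch below (NumPy stands in for TensorFlow here, purely to illustrate the shapes) shows that `.T` on a 3-D array reverses *all* axes, whereas swapping the last two gives the intended per-example product. In TensorFlow, `tf.matmul(Q, K, transpose_b=True)` performs the latter.

```python
import numpy as np

batch, seq_len, d_k = 2, 5, 8
rng = np.random.default_rng(0)
Q = rng.normal(size=(batch, seq_len, d_k))
K = rng.normal(size=(batch, seq_len, d_k))

print(K.T.shape)                      # → (8, 5, 2): every axis reversed
K_t = np.swapaxes(K, -1, -2)          # (batch, d_k, seq_len)
A = Q @ K_t                           # batched matmul → (batch, seq_len, seq_len)
print(A.shape)                        # → (2, 5, 5)

# Each batch element matches the unbatched formula Q_b K_b^T.
print(all(np.allclose(A[b], Q[b] @ K[b].T) for b in range(batch)))  # → True
```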