# 16 Implementing Multi-Head Attention in Keras

In [1]:
import numpy as np
from tensorflow import cast, float32, math, matmul, reshape, shape, transpose
from tensorflow.keras.backend import softmax
from tensorflow.keras.layers import Dense, Layer

random_seed = 42

## 16.1 Recap of Multi-Head Attention

Okay, so actually we left out some pretty important details in Ch. 15 but we'll go over them now.  

First off, within each attention block there are actually multiple attention mechanisms ("heads") working in parallel (literally, the input is fed to them in parallel). This, in theory, allows the model to pay various "kinds" of attention. In the NLP context, you could think of this as allowing the model to extract various "aspects" or "qualities" (e.g. temporal, gender, cardinality, et cetera) of words in the sequence from the word embeddings during training.  

Secondly, there are multiple "linear projection matrices". There is one per attention head for each of Q, K and V. Essentially these are trainable weight matrices for queries, keys and values that generate different subspace representations of them. Each attention head then works on one of these projected versions of Q, K and V. There is also one right at the end which produces a projection of the concatenated outputs of all the different heads. Once again, these weights are learned during training. (You can think of each as a Dense/FF layer.)  

Did you catch that?! The outputs of the various scaled dot product attention heads are joined via a concatenation operation. That is the third important detail.  
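
Putting the three details together, multi-head attention is usually written as below. This is a restatement in the notation of the AIAYN paper, where $W_i^Q$, $W_i^K$ and $W_i^V$ are the per-head projection matrices and $W^O$ is the final output projection applied to the concatenated heads:

$$
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^O, \qquad \mathrm{head}_i = \mathrm{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)
$$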

By the way, in the AIAYN transformer they had 8 attention heads. And one more thing that we only mentioned in passing is that the "encoder block" and "decoder block" are actually _stacks_ of architecturally identical blocks. In the AIAYN paper they had 6 of them. I guess we'll get to that eventually, when we code up the entire transformer.  

**Note:** There is nothing magical about the aforementioned numbers (`6` and `8`).

## 16.2 Implementing Multi-Head Attention from Scratch

First, let us bring back our scaled dot-product attention layer from the previous chapter.

In [2]:
class DotProductAttention(Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, queries, keys, values, mask=None):
        d_k = shape(keys)[-1]
        # Score the queries against the keys after transposing the latter, and then scale
        scores = matmul(queries, keys, transpose_b=True) / math.sqrt(
            cast(d_k, float32)
        )
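        # Apply mask to the attention scores. The mask is additive: it holds 1.0 at
        # positions to hide and 0.0 elsewhere. A large finite negative value is used
        # instead of -inf because -inf * 0 evaluates to NaN.
        if mask is not None:
            scores += -1e9 * mask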
        # Compute the weights using a softmax operation
        weights = softmax(scores)
        # Compute attention by a weighted sum of the value vectors
        return matmul(weights, values)
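
A quick aside on the masking step, since it is easy to get wrong. The snippet below is a minimal NumPy sketch, not the layer itself: `masked_softmax` and its `fill` argument are made-up names for illustration, and it assumes the mask holds `1.0` at positions that should be hidden. It shows that adding a large negative number drives the masked weight to zero, and why multiplying the mask by `-inf` is risky.

```python
import numpy as np


def masked_softmax(scores, mask, fill=-1e9):
    """Softmax over the last axis after adding `fill` wherever mask == 1.0.

    Hypothetical helper for illustration only.
    """
    masked = scores + fill * mask
    # Subtract the row maximum for numerical stability
    shifted = masked - masked.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


scores = np.array([[1.0, 2.0, 3.0]])
mask = np.array([[0.0, 0.0, 1.0]])  # hide the last position

weights = masked_softmax(scores, mask)

# Multiplying by -inf instead turns every *unmasked* entry into NaN,
# because -inf * 0.0 is undefined in floating-point arithmetic.
with np.errstate(invalid="ignore"):
    bad_offsets = float("-inf") * mask
```

The masked position ends up with a weight of zero and the remaining weights still sum to one, whereas `bad_offsets` contains `nan` at every position that was supposed to be left alone.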

Now we proceed to define our Multi-Head Attention layer. Things are about to get very, VERY messy and confusing. Part 3 of the "Transformers Explained Visually" series on Towards Data Science ([here](https://towardsdatascience.com/transformers-explained-visually-part-3-multi-head-attention-deep-dive-1c1ff1024853)) was extremely helpful in understanding what is about to come. (A word of warning: Although the images on that page help immensely in clarifying the logic, the visual representations of dimensions, i.e. the _visual_ height/width/depth of lines, are _not_ to be taken too literally. Focus on the dimension labels instead).  

Here's what's about to happen: The outputs of the linear layers that produce the Q, K and V matrices, that is to say the Q, K and V matrices themselves, are going to be "split" between the different attention heads. But this is _not_ a "physical" split. It is a "logical" one. That is to say, each attention head is going to process logically separate sections of the _same, single_ Q (or K or V) matrix. So, in effect, all attention heads share the same linear layer, but operate on their "own" logical section of each data matrix. This is just so that the computations of all attention heads can be performed in a single matrix operation rather than N separate operations (vectorization/parallelization ftw). This keeps the model simple (due to fewer linear layers being needed) while achieving the power of independent attention heads.  

Let's forget about the `batch_size` dimension for now (but keep it in the background of our minds!) and focus on one example input/target sequence for simplicity. The single example embedded sequence comes in to the linear layer with dimensions $(L_{seq} \times d_{model})$, gets matrix-multiplied by the $(d_{model} \times d_{model})$ `W_q`, `W_k` and `W_v` matrices to yield the Q, K and V matrices, still of dimensions $(L_{seq} \times d_{model})$. Then these get rearranged. How? Let's focus on the query matrix Q (the other two follow an identical logic). If we have $h$ heads, then let $s = d_{model} \div h$. We will first reshape our Q to have dimensions $(L_{seq} \times h \times s)$, then transpose its first two axes to get dimensions $(h \times L_{seq} \times s)$. (A second plain reshape would not do here: it would scramble which elements end up in which head.)  
Again, it doesn't matter why they are "physically" split like this. What matters are the _logical_ splits, and those are visualized quite well in the article linked to earlier.  

I told you this was going to get messy! But fortunately we don't have to keep track of everything as `tf.reshape()` and `tf.transpose()` will take care of the grunt work for us.  

In short: 
- We need to reshape the linearly projected queries, keys, and values so that attention heads can work in parallel.
- Queries, keys and values come in with dimensions `(batch_size, seq_length, d_model)`.
- They are then linearly projected by `W_q`, `W_k` and `W_v`, which keeps the dimensions at `(batch_size, seq_length, d_model)`.
- Then they are rearranged to have dimensions `(batch_size, n_heads, seq_length, depth)` using the helper method `reshape_tensor()`. (Note: `depth` is the same thing as $s$ above and `n_heads` is just $h$).  
  But this is done in two steps:
  - First they're reshaped to dimensions `(batch_size, seq_length, n_heads, depth)`.
  - Then the second and third dimensions are transposed.  
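
The two-step rearrangement above can be sketched in plain NumPy (the real layer uses the TensorFlow equivalents). The tiny sizes and variable names here are made up for illustration. The sketch checks that head `i` ends up with columns `i * depth` to `(i + 1) * depth` of the projected matrix, that the operation can be undone, and that a single direct reshape would not give the same layout.

```python
import numpy as np

batch_size, seq_length, n_heads, depth = 2, 3, 4, 5
d_model = n_heads * depth

# Stand-in for a projected Q matrix; distinct integers make the layout easy to trace
q = np.arange(batch_size * seq_length * d_model, dtype=np.float32).reshape(
    batch_size, seq_length, d_model
)

# Step 1: (batch_size, seq_length, d_model) -> (batch_size, seq_length, n_heads, depth)
# Step 2: swap the second and third axes -> (batch_size, n_heads, seq_length, depth)
split = q.reshape(batch_size, seq_length, n_heads, depth).transpose(0, 2, 1, 3)

# Head 1 holds exactly columns depth..2*depth of the original matrix
head_1_expected = q[:, :, depth : 2 * depth]

# Undo: transpose back, then merge the last two axes
merged = split.transpose(0, 2, 1, 3).reshape(batch_size, seq_length, d_model)

# A single reshape straight to the target shape mixes up sequence positions and heads
direct = q.reshape(batch_size, n_heads, seq_length, depth)
```

Here `split[:, 1]` equals `head_1_expected` and `merged` equals `q`, while `direct` has the right shape but the wrong contents.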
  
With all of that said and done, `d_k` and `d_v` below will both equal `d_model / h`. To be honest, I'm not sure why they are allowed to have fixed, independent values in the code from the book I'm following. The book doesn't have the clearest explanations (hence all the other resources I find, study and link to). I guess they were trying to preserve generality, and to be fair, the AIAYN paper does the same, but neither explains how this can be useful. In any case, going forward I will modify their code and impose the AIAYN implementation where $d_k = d_v = d_{model} / h$.  

**Notes**:
1. Notice how the number of parameters for multi-head attention is the same as the number of parameters in the equivalent single-head attention. The parameters are merely divided between heads.  
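2. We could have split the _input_ matrices (i.e. the queries, keys and values themselves) between the attention heads _before_ linearly projecting them to get Q, K and V. That would not give quite the same result, though: each head would then only see its own slice of the input embedding, whereas projecting first lets every head draw on all $d_{model}$ input features. The way we've done it is also more streamlined.  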
3. The `reshape_tensor()` method below also has a 'flag' argument that allows us to undo (revert) the operation. This is useful for "stitching" (concatenating) the outputs of all the attention heads back together.
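
Note 1 is easy to check with a little arithmetic. The sketch below counts only the weights of the three input projections (no biases, no `W_o`), assuming the AIAYN sizes of `d_model = 512` and `h = 8`. The variable names are mine.

```python
d_model, h = 512, 8
d_k = d_model // h  # per-head dimension, 64 here

# Single-head attention: W_q, W_k and W_v are each (d_model x d_model)
single_head_weights = 3 * d_model * d_model

# Multi-head attention: h heads, each with its own (d_model x d_k) slice
# of W_q, W_k and W_v
multi_head_weights = 3 * h * d_model * d_k

print(single_head_weights, multi_head_weights)
```

Both counts come out the same because `h * d_k` is just `d_model` again.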

In [3]:
class MultiHeadAttention(Layer):
    def __init__(self, n_heads, d_model, **kwargs):
        super().__init__(**kwargs)

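        # d_model must split evenly across the heads. (Use the n_heads argument here,
        # not a global `h`, which may not be defined yet.)
        assert d_model % n_heads == 0
        d_k = d_model // n_heads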
        # We assume d_v always equals d_k
        d_v = d_k

        self.attention = DotProductAttention()  # Scaled dot product attention
        self.n_heads = n_heads  # Number of attention heads
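        self.d_k = d_k  # Per-head dim of the linearly projected queries and keys
        self.d_v = d_v  # Per-head dim of the linearly projected values
        # Each projection covers all heads at once, so its width is n_heads * d_k
        # (= d_model). reshape_tensor() later splits it into n_heads pieces of size d_k.
        self.W_q = Dense(n_heads * d_k)  # Learned projection matrix for the queries, ...
        self.W_k = Dense(n_heads * d_k)  # ... for the keys
        self.W_v = Dense(n_heads * d_v)  # ... for the values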
        self.W_o = Dense(d_model)  # ... for the multi-head output

    def reshape_tensor(self, x, n_heads, do_split_flag):
        if do_split_flag:
            # Tensor shape after reshaping and transposing:
            # (batch_size, n_heads, seq_length, -1)
            x = reshape(x, shape=(shape(x)[0], shape(x)[1], n_heads, -1))
            x = transpose(x, perm=(0, 2, 1, 3))
        else:
            # Reverting the reshaping and transposing operations:
            # (batch_size, seq_length, d_model)
            x = transpose(x, perm=(0, 2, 1, 3))
            x = reshape(x, shape=(shape(x)[0], shape(x)[1], -1))
        return x

    def call(self, queries, keys, values, mask=None):
        # Rearrange the queries to be able to compute all heads in parallel
        q_reshaped = self.reshape_tensor(self.W_q(queries), self.n_heads, True)
        # Resulting tensor shape: (batch_size, n_heads, input_seq_length, -1)

        # Rearrange the keys to be able to compute all heads in parallel
        k_reshaped = self.reshape_tensor(self.W_k(keys), self.n_heads, True)
        # Resulting tensor shape: (batch_size, n_heads, input_seq_length, -1)

        # Rearrange the values to be able to compute all heads in parallel
        v_reshaped = self.reshape_tensor(self.W_v(values), self.n_heads, True)
        # Resulting tensor shape: (batch_size, n_heads, input_seq_length, -1)

        # Compute the multi-head attention output using the reshaped queries, keys,
        # and values
        o_reshaped = self.attention(q_reshaped, k_reshaped, v_reshaped, mask)
        # Resulting tensor shape: (batch_size, n_heads, input_seq_length, -1)

        # Rearrange back the output into concatenated form
        output = self.reshape_tensor(o_reshaped, self.n_heads, False)
        # Resulting tensor shape: (batch_size, input_seq_length, n_heads * depth)

        # Apply one final linear projection to the output to generate the multi-head
        # attention. Resulting tensor shape: (batch_size, input_seq_length, d_model)
        return self.W_o(output)
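
To convince ourselves that the "logical" split really behaves like independent heads, here is a small NumPy sketch (again, not the Keras layer). It ignores the batch dimension, biases, masking and the final `W_o` projection, and all names and sizes are made up for illustration. It computes attention for all heads in one batched operation, then again one head at a time using each head's own column slice of the weight matrices, and compares the two.

```python
import numpy as np


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def attention(q, k, v):
    # Scaled dot-product attention over the last two axes
    d_k = q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d_k)
    return softmax(scores) @ v


rng = np.random.default_rng(0)
seq_length, n_heads, depth = 4, 2, 3
d_model = n_heads * depth

x = rng.normal(size=(seq_length, d_model))  # one embedded sequence
W_q, W_k, W_v = (rng.normal(size=(d_model, d_model)) for _ in range(3))


def split_heads(t):
    # (seq_length, d_model) -> (n_heads, seq_length, depth)
    return t.reshape(seq_length, n_heads, depth).transpose(1, 0, 2)


# All heads at once: project, split, attend, then stitch back together
batched = attention(split_heads(x @ W_q), split_heads(x @ W_k), split_heads(x @ W_v))
batched = batched.transpose(1, 0, 2).reshape(seq_length, d_model)

# One head at a time, each using its own column slice of the weight matrices
per_head = []
for i in range(n_heads):
    cols = slice(i * depth, (i + 1) * depth)
    per_head.append(attention(x @ W_q[:, cols], x @ W_k[:, cols], x @ W_v[:, cols]))
looped = np.concatenate(per_head, axis=-1)

print(np.allclose(batched, looped))
```

The two results agree up to floating-point error, which is the whole point of the reshape-and-transpose trick: one big matrix operation instead of a loop over heads.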

## 16.3 Testing Out the Code

In [25]:
input_seq_length = 5  # Maximum length of the input sequence
h = 8  # Number of self-attention heads
d_model = 512  # Dimensionality of the model (the input embeddings, as well as all its sub-layers' outputs)
batch_size = 64  # Batch size from the training process; a training hyperparameter

rng = np.random.default_rng(random_seed)

queries = rng.random((batch_size, input_seq_length, d_model))
keys = rng.random((batch_size, input_seq_length, d_model))
values = rng.random((batch_size, input_seq_length, d_model))

multihead_attention = MultiHeadAttention(h, d_model)
output = multihead_attention(queries, keys, values)

print(output)

tf.Tensor(
[[[ 8.14943686e-02 -2.07826361e-01 -1.29091278e-01 ... -5.07743597e-01
    2.35847443e-01 -2.45192349e-01]
  [ 1.06212638e-01 -2.19399706e-01 -1.43369183e-01 ... -5.07689774e-01
    2.30445623e-01 -2.42180824e-01]
  [ 8.59445482e-02 -2.10692257e-01 -1.42490596e-01 ... -5.08839250e-01
    2.28606790e-01 -2.38618150e-01]
  [ 1.00529008e-01 -2.03786239e-01 -1.13968298e-01 ... -5.02856731e-01
    2.35798791e-01 -2.50318110e-01]
  [ 8.30031931e-02 -2.15402111e-01 -1.52123138e-01 ... -5.22031546e-01
    2.31339246e-01 -2.45917723e-01]]

 [[ 6.23010769e-02 -8.63183886e-02 -1.23630837e-01 ... -5.23672342e-01
    1.13774933e-01 -5.15483208e-02]
  [ 5.61044067e-02 -8.52782354e-02 -1.30698130e-01 ... -5.21288097e-01
    1.23439729e-01 -5.24104871e-02]
  [ 4.14301269e-02 -7.81833306e-02 -1.33477762e-01 ... -5.34742653e-01
    1.25398427e-01 -5.44013903e-02]
  [ 3.21061462e-02 -8.77115801e-02 -1.62450403e-01 ... -5.03533721e-01
    1.33375034e-01 -7.47317523e-02]
  [ 5.66560365e-02 -8.22