```python
import numpy as np

import tensorflow as tf
layers = tf.keras.layers  # shorthand for Keras layers

import matplotlib.pyplot as plt
plt.style.use('fivethirtyeight')  # plot style for matplotlib figures
```

# Math behind attention in Transformers

## Bahdanau attention

<br><br>

**Bahdanau attention** is also known as *concat* or *additive* attention. It was introduced in the seminal paper [***Neural machine translation by jointly learning to align and translate***](https://arxiv.org/pdf/1409.0473.pdf) by Bahdanau et al., 2015.

<br><br><br>

<img src="https://cdn-images-1.medium.com/max/509/1*O-xKW4z-HWg1AC0vVFe3vg.png">

<br><br><br>

<img src="https://sigmoidal.io/wp-content/uploads/2020/09/attention-heatmap-2.png.webp" width=400>

<br><br><br>

To understand **Bahdanau attention**, let's start with the decoder function:


$$ \Large p(y_j|y_0, ..., y_{j-1}, \mathbf{x}) = g(y_{j-1}, s_j, c_j)$$


<br><br><br><br><br><br><br><br><br><br><br><br>

**Let's unpack it!**

<br><br>

* $g()$ is a non-linear, possibly multi-layer, function that outputs the probability of $y_j$

<br><br>

* $s_j$ is the decoder's **hidden state** at position $j$. It is computed by a recurrent function $f$ (e.g. an **RNN** cell):


$$\Large s_j = f(s_{j-1}, y_{j - 1}, c_j)$$

<br><br>

...and $c_j$, the **context vector**, is a **weighted sum** of the **encoder hidden states**:


💎💎💎$$\Large c_j = \sum_{i=1}^{T_x}\alpha_{i, j}h_i$$

<br><br>

...while the weights $\alpha_{i,j}$ are computed by applying a **softmax** to the energies $e_{i,j}$ over all input positions $i$:

$$\Large \alpha_{i, j} = \frac{\exp(e_{i,j})}{\sum_{i'=1}^{T_x} \exp(e_{i',j})}$$

<br><br>

...and $e_{i, j}$ is the **alignment energy**, defined as:

$$\Large e_{i, j} = NN(s_{j-1}, h_i)$$

<br><br>

where:

<br><br>

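* $h_i$ is the encoder's **hidden state** for the $i^{th}$ **input token**.

<br><br>

* $NN$ is a small feed-forward network (the *alignment model*), trained jointly with the rest of the model.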
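To make the equations above concrete, here is a minimal NumPy sketch of a single decoder step. It assumes the alignment model $NN$ is a one-hidden-layer network, $e_{i,j} = v_a^\top \tanh(W_a s_{j-1} + U_a h_i)$, which is the form given in the paper's appendix (rewritten in this notebook's indexing). The weight names `W_a`, `U_a`, `v_a` and all sizes are illustrative, and the weights are random rather than trained.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtracting the max leaves the result unchanged but avoids overflow in exp
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)


def additive_attention(s_prev, H, W_a, U_a, v_a):
    """One Bahdanau attention step for decoder position j.

    s_prev : previous decoder state s_{j-1}, shape (d_s,)
    H      : encoder states h_1..h_{T_x} stacked as rows, shape (T_x, d_h)
    W_a, U_a, v_a : alignment-model weights (random stand-ins here)
    """
    # e_{i,j} = v_a^T tanh(W_a s_{j-1} + U_a h_i), computed for all i at once
    energies = np.tanh(s_prev @ W_a.T + H @ U_a.T) @ v_a  # (T_x,)
    alphas = softmax(energies)                             # (T_x,), sums to 1
    context = alphas @ H                                   # c_j = sum_i alpha_ij h_i, (d_h,)
    return context, alphas


rng = np.random.default_rng(0)
T_x, d_h, d_s, d_a = 5, 4, 3, 8  # toy sizes, chosen arbitrarily
H = rng.normal(size=(T_x, d_h))
s_prev = rng.normal(size=d_s)
W_a = rng.normal(size=(d_a, d_s))
U_a = rng.normal(size=(d_a, d_h))
v_a = rng.normal(size=d_a)

c_j, alpha_j = additive_attention(s_prev, H, W_a, U_a, v_a)
print("alphas:", alpha_j.round(3))
print("context vector c_j:", c_j.round(3))
```

Note that the decoder state enters only through $s_{j-1}$: the attention weights are computed before the new state $s_j$ exists, and $c_j$ is then fed into $f$ and $g$.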

<br><br>


"*The probability $\alpha_{i,j}$, or its associated energy $e_{i,j}$, reflects the importance of the annotation $h_i$ with respect to the previous hidden state $s_{j-1}$ in deciding the next state $s_j$ and generating $y_j$. Intuitively, this implements a **mechanism of attention** in the decoder.*"

Bahdanau et al., 2015

<br><br><br><br><br><br>



## Transformers - scaled dot product self-attention

<br><br>

$$\Large Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V$$

<br><br>

where:

* $Q$ is a **query** matrix $\in \mathbb{R}^{L_Q \times D}$

<br><br>

* $K$ is a **key** matrix $\in \mathbb{R}^{L_K \times D}$

<br><br>

* $D$ is the embedding dimensionality

<br><br>

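* $L_Q$ and $L_K$ are the lengths of the query and key sequences. In self-attention they are equal; they differ when queries and keys come from different sequences (e.g. target and source sentences in translation)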

<br><br>

Let's stop here for a while and contemplate this sub-equation:

<br><br>

$$\Large W_A = softmax(QK^T)$$

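where $W_A \in \mathbb{R}^{L_Q \times L_K}$ and the softmax is applied row-wise, so row $i$ holds the weights that the $i^{th}$ query assigns to the $L_K$ keys and sums to 1.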
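A quick NumPy sketch of this sub-equation, with random $Q$ and $K$ standing in for real queries and keys (the sizes are arbitrary), shows the shape of $W_A$ and that every row is a probability distribution over the keys:

```python
import numpy as np


def softmax(x, axis=-1):
    # Row-wise softmax, shifted by the max for numerical stability
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)


rng = np.random.default_rng(1)
L_Q, L_K, D = 3, 5, 4
Q = rng.normal(size=(L_Q, D))  # one query vector per row
K = rng.normal(size=(L_K, D))  # one key vector per row

scores = Q @ K.T                # (L_Q, L_K): every query dotted with every key
W_A = softmax(scores, axis=-1)  # normalize over the keys, separately for each query

print(W_A.shape)        # (3, 5)
print(W_A.sum(axis=1))  # one value per query, each (numerically) equal to 1
```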

<br><br><br><br>

Now, let's add the $V$ matrix:

<br><br>

$$\Large Attention(Q, K, V) = softmax(QK^T)V$$

<br><br>

* $V$ is a **value** matrix $\in \mathbb{R}^{L_K \times D}$

...and in practice it is often the same matrix as $K$ (e.g. `tf.keras.layers.Attention` uses `value` as the key when no `key` is passed).
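Adding $V$ turns each row of $W_A$ into a recipe for mixing value vectors: every output row is a weighted average (a convex combination) of the rows of $V$. A minimal NumPy sketch, again with random toy matrices and with $V = K$ as in the common case mentioned above:

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)


def unscaled_attention(Q, K, V):
    # softmax(QK^T) V: every output row is a weighted average of the rows of V
    W_A = softmax(Q @ K.T, axis=-1)  # (L_Q, L_K), rows sum to 1
    return W_A @ V                   # (L_Q, D)


rng = np.random.default_rng(2)
L_Q, L_K, D = 3, 5, 4
Q = rng.normal(size=(L_Q, D))
K = rng.normal(size=(L_K, D))
V = K  # keys reused as values

out = unscaled_attention(Q, K, V)
print(out.shape)  # (3, 4)
```

Because the weights are non-negative and sum to 1, each output coordinate always lies between the smallest and largest value of that coordinate in $V$.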

<br><br> 

Let's try to make sense out of this:


<br><br>

<img src='content/att_2.jpg' width=700>

<br><br>

You can also think about attention as a **soft dictionary** object:


<br><br>

<img src="https://cdn-ak.f.st-hatena.com/images/fotolife/e/ey_nosukeru/20190622/20190622045649.png" width=600>
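The analogy fits in a few lines of code. The toy 2-D "embeddings" below are made up for illustration: a regular Python `dict` returns exactly one value for an exact key match, while the soft version (`soft_lookup`, our own helper name) returns a softmax-weighted blend of all values, dominated by the keys most similar to the query.

```python
import numpy as np


def softmax(x):
    x = x - np.max(x)
    e = np.exp(x)
    return e / e.sum()


# Hard dictionary: the query must match a key exactly, and one value comes back
hard = {"cat": 1.0, "dog": 2.0, "car": 10.0}
print(hard["dog"])  # 2.0

# Soft dictionary: keys and the query are vectors (made-up 2-D toy embeddings)
keys = np.array([[1.0, 0.0],   # "cat"
                 [0.9, 0.1],   # "dog", close to "cat"
                 [0.0, 1.0]])  # "car"
values = np.array([1.0, 2.0, 10.0])


def soft_lookup(query, keys, values):
    # Every value contributes, weighted by how well its key matches the query
    weights = softmax(keys @ query)
    return weights @ values, weights


result, weights = soft_lookup(np.array([5.0, 0.0]), keys, values)
print(weights.round(3), result.round(3))

# A larger query magnitude sharpens the weights toward a single key
sharp_result, sharp_weights = soft_lookup(np.array([50.0, 0.0]), keys, values)
print(sharp_weights.round(3), sharp_result.round(3))
```

With the query `[5, 0]` the result lands between the values of "cat" and "dog", while "car" contributes almost nothing. Multiplying the query by 10 pushes the weights toward one-hot, and the lookup behaves almost like the hard dictionary.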

<br><br>
<br><br>
<br><br>
<br><br>

We're still missing one element. Let's get back to the full formula:


$$\Large Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V$$

<br><br>


* $d_k$ is the dimensionality of the keys (and queries), so here $d_k = D$, the embedding dimensionality.

<br><br>
<br><br>

"*We suspect that for large values of $d_k$, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients.*"

Vaswani et al., 2017
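The paper backs this up in a footnote: if the components of $q$ and $k$ are independent with mean 0 and variance 1, then $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$ has mean 0 and variance $d_k$, so dividing by $\sqrt{d_k}$ brings the variance back to 1. The NumPy sketch below checks this empirically on random Gaussian vectors and implements the full scaled formula as `scaled_dot_product_attention` (our own helper name, not a library function):

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)


def scaled_dot_product_attention(Q, K, V):
    # softmax(QK^T / sqrt(d_k)) V, with d_k read off the query/key width
    d_k = Q.shape[-1]
    W_A = softmax(Q @ K.T / np.sqrt(d_k), axis=-1)
    return W_A @ V, W_A


rng = np.random.default_rng(3)

# Spread of raw vs. scaled dot products for random unit-variance vectors
stds = {}
for d_k in (4, 64, 512):
    q = rng.normal(size=(5_000, d_k))
    k = rng.normal(size=(5_000, d_k))
    dots = np.sum(q * k, axis=1)  # 5,000 sample dot products q . k
    stds[d_k] = (dots.std(), (dots / np.sqrt(d_k)).std())
    print(f"d_k={d_k:3d}  std(q.k)={stds[d_k][0]:6.2f}  "
          f"std(q.k / sqrt(d_k))={stds[d_k][1]:.2f}")

# Effect on the softmax for d_k = 512: unscaled scores are typically so spread
# out that the weights end up close to one-hot (the small-gradient regime)
Q = rng.normal(size=(1, 512))
K = rng.normal(size=(8, 512))
V = rng.normal(size=(8, 4))
_, W_scaled = scaled_dot_product_attention(Q, K, V)
W_unscaled = softmax(Q @ K.T, axis=-1)
print("largest weight, unscaled:", W_unscaled.max().round(3))
print("largest weight, scaled:  ", W_scaled.max().round(3))
```

The raw standard deviation grows roughly like $\sqrt{d_k}$ (about 2, 8 and 22.6 for the three sizes), while the scaled one stays near 1 regardless of $d_k$.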

<br><br>
<br><br>


### Causal self-attention

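In **causal self-attention** we mask the tokens *after* the current position, so each position can attend only to itself and to earlier positions. It's used in **generative** (autoregressive) **models**, which must not see future tokens while predicting the next one.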

$$\Large Attention(Q, K, V) = softmax(\frac{QK^T + M}{\sqrt{d_k}})V$$

<br><br>

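where $M$ is an $L \times L$ mask matrix ($L$ is the sequence length) with $0$s on and below the diagonal and $-\infty$ above it. Since $e^{-\infty} = 0$, the softmax gives zero weight to every future position.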
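A minimal NumPy sketch of the mask and of causal self-attention. To keep it short, it assumes $Q = K = V = X$ (no learned projections), which is a simplification; `causal_mask` and `causal_self_attention` are illustrative helper names:

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)


def causal_mask(L):
    # 0 on and below the diagonal, -inf strictly above it (the future positions)
    return np.triu(np.full((L, L), -np.inf), k=1)


def causal_self_attention(X):
    # Simplification: Q = K = V = X, i.e. no learned projections
    L, d_k = X.shape
    scores = (X @ X.T + causal_mask(L)) / np.sqrt(d_k)
    W_A = softmax(scores, axis=-1)  # exp(-inf) = 0, so future tokens get zero weight
    return W_A @ X, W_A


print(causal_mask(4))

X = np.random.default_rng(4).normal(size=(4, 3))  # 4 tokens, 3-dim toy embeddings
out, W_A = causal_self_attention(X)
print(W_A.round(2))  # lower-triangular; each row sums to 1
```

Row 0 of the weight matrix is always `[1, 0, 0, ...]`: the first token can only attend to itself, so its output equals its own vector.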


<br><br>
<br><br>

### Going multihead

<br><br>

<img src='content/multi.jpg' width=400>

<br><br>

$$\Large Multihead(Q, K, V) = Concat(h_1, ..., h_m)W^O$$

<br><br>

where:


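$$\Large h_i = Attention(QW^Q_i, KW^K_i, VW^V_i)$$

and $W^Q_i, W^K_i, W^V_i \in \mathbb{R}^{D \times d_k}$ are the learned projection matrices of the $i^{th}$ head, while $W^O \in \mathbb{R}^{m d_k \times D}$ maps the concatenated heads back to $D$ dimensions. In the original Transformer, $d_k = D / m$.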
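The two multi-head equations fit in a short NumPy sketch. The projection matrices are random stand-ins for learned weights, the per-head size $d_k = D/m$ follows the original Transformer, and the helper names `attention` and `multihead` are our own:

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)


def attention(Q, K, V):
    # Scaled dot-product attention for a single head
    d_k = Q.shape[-1]
    return softmax(Q @ K.T / np.sqrt(d_k), axis=-1) @ V


def multihead(Q, K, V, W_Q, W_K, W_V, W_O):
    # W_Q, W_K, W_V: lists of m per-head projections, each of shape (D, d_k)
    # W_O: output projection of shape (m * d_k, D)
    heads = [attention(Q @ wq, K @ wk, V @ wv)   # each head: (L_Q, d_k)
             for wq, wk, wv in zip(W_Q, W_K, W_V)]
    return np.concatenate(heads, axis=-1) @ W_O  # (L_Q, D)


rng = np.random.default_rng(5)
L_Q, L_K, D, m = 3, 5, 8, 2
d_k = D // m  # per-head size, as in the original Transformer

Q = rng.normal(size=(L_Q, D))
K = rng.normal(size=(L_K, D))
V = rng.normal(size=(L_K, D))
W_Q = [rng.normal(size=(D, d_k)) for _ in range(m)]
W_K = [rng.normal(size=(D, d_k)) for _ in range(m)]
W_V = [rng.normal(size=(D, d_k)) for _ in range(m)]
W_O = rng.normal(size=(m * d_k, D))

out = multihead(Q, K, V, W_Q, W_K, W_V, W_O)
print(out.shape)  # (3, 8)
```

In TensorFlow (imported at the top of this notebook), `tf.keras.layers.MultiHeadAttention(num_heads=..., key_dim=...)` packages the same computation with trainable projections, so in practice you would rarely write it by hand.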