# II/ Seq2seq with attention

## A) Attention in general

> - Source:
>   - [Transformers from scratch - Peter Bloem](https://peterbloem.nl/blog/transformers)

- Before talking about Bahdanau attention, let's understand what people mean by attention.
- Attention can be better explained through the lens of movie recommendation.
- Let’s say you run a movie rental business and you have some movies, and some users, and you would like to recommend movies to your users that they are likely to enjoy.
- One way to go about this is to:
    - create manual features for your movies, such as how much romance there is in the movie, and how much action
    - design corresponding features for your users: how much they enjoy romantic movies and how much they enjoy action-based movies. 
- If you did this, the dot product between the two feature vectors would give you a score for how well the attributes of the movie match what the user enjoys.
<p align="center"> <img src="./assets/dot_product.svg" height="500" width="1100" /></p> 

- If for example:
    - the user enjoys romance and the movie has a lot of romance, then that feature's contribution to the dot product will be positive.
    - the user hates romance and the movie has a lot of romance, then that feature's contribution to the dot product will be negative.
- This is the basic intuition behind attention. The dot product lets us represent objects and their relations by expressing how related two vectors in the input are. The output vectors are just weighted sums over the whole input, with the weights determined by these dot products.
- How is the dot product expressed in neural networks? Through matrix multiplication, which is just a batch of dot products computed at once!
- However, there is a problem: matrix multiplication does not normalize its inputs! As such, if we compute the similarity between `A` and `A.T`, we won't get a score of 1.0 on the diagonal as we would expect (because the similarity of a vector with itself should be maximal).
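
The movie example and the "weighted sums over the whole input" idea can be sketched in a few lines of NumPy. This is only an illustration under assumptions: the feature order `[romance, action]`, the numbers, and the names `user`, `movies` and `softmax` are made up for the example, and the softmax step is the usual way to turn dot products into weights rather than something derived above.

```python
import numpy as np

# Hypothetical feature order: [romance, action]
user = np.array([1.0, -0.5])   # enjoys romance, mildly dislikes action
movies = np.array([
    [0.9, 0.1],   # mostly romance
    [0.1, 0.9],   # mostly action
    [0.5, 0.5],   # a mix of both
])

# One dot product per movie: how well each movie matches the user's taste
scores = movies @ user


def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


# Parameter-free self-attention over the movie vectors:
# every output row is a weighted sum of all input rows.
raw_weights = movies @ movies.T          # pairwise dot products, shape (3, 3)
weights = softmax(raw_weights, axis=1)   # each row sums to 1
outputs = weights @ movies               # shape (3, 2)

print(scores)
print(outputs)
```

The romance-heavy movie gets the highest score for this user and the action-heavy one gets a negative score, which matches the intuition in the bullets above.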

In [4]:
import numpy as np

np.random.seed(42)
np.set_printoptions(precision=3)

A = np.array([
    [0.375, 0.951, 0.732, 0.599, 0.156, 0.156],
    [0.058, 0.866, 0.601, 0.708, 0.021, 0.97 ],
    [0.832, 0.212, 0.182, 0.183, 0.304, 0.525],
    [0.432, 0.291, 0.612, 0.139, 0.292, 0.366],
    [0.456, 0.785, 0.2,   0.514, 0.592, 0.046],
    [0.608, 0.171, 0.065, 0.949, 0.966, 0.808]
])

print(f"A = \n{A}")
print("--------------------")
# Divide each column by its L2 norm so that every column of B has norm 1
n = np.linalg.norm(A, ord=2, axis=0)
B = A / n

print(f"norm of A = {n}")
print(f"Normalized A = \n{B}")
print("--------------------")

# The norms were computed along axis=0, so each column of B is one unit-norm vector
# => B.T @ B holds the dot product (cosine similarity) between every pair of columns
print(f"Normalized dot product: \n{B.T @ B}")
# Each row's maximum lies on the diagonal because every column has unit norm
print(f"Indices of maximum value = {np.argmax(B.T @ B, axis=1)}")

print("--------------------")
print(f"Unormaliazed Dot product: \n{A.T @ A}")
print(f"Indices of maximum value = {np.argmax(A.T @ A, axis=1)}")

np.set_printoptions(precision=8)  # restore NumPy's default precision (a bare call changes nothing)

A = 
[[0.375 0.951 0.732 0.599 0.156 0.156]
 [0.058 0.866 0.601 0.708 0.021 0.97 ]
 [0.832 0.212 0.182 0.183 0.304 0.525]
 [0.432 0.291 0.612 0.139 0.292 0.366]
 [0.456 0.785 0.2   0.514 0.592 0.046]
 [0.608 0.171 0.065 0.949 0.966 0.808]]
--------------------
norm of A = [1.265 1.559 1.161 1.441 1.219 1.425]
Normalized A = 
[[0.296 0.61  0.63  0.416 0.128 0.109]
 [0.046 0.556 0.517 0.491 0.017 0.681]
 [0.658 0.136 0.157 0.127 0.249 0.368]
 [0.341 0.187 0.527 0.096 0.24  0.257]
 [0.36  0.504 0.172 0.357 0.486 0.032]
 [0.481 0.11  0.056 0.658 0.792 0.567]]
--------------------
Normalized dot product: 
[[1.    0.594 0.583 0.707 0.84  0.678]
 [0.594 1.    0.885 0.814 0.498 0.622]
 [0.583 0.885 1.    0.685 0.383 0.652]
 [0.707 0.814 0.685 1.    0.811 0.836]
 [0.84  0.498 0.383 0.811 1.    0.644]
 [0.678 0.622 0.652 0.836 0.644 1.   ]]
Indices of maximum value = [0 1 2 3 4 5]
--------------------
Unormaliazed Dot product: 
[[1.6   1.171 0.856 1.289 1.296 1.222]
 [1.171 2.429 1.601 1.828 0.9

- As we can see, the plain matrix multiplication does not properly reflect the notion of “similarity”: the diagonal is no longer 1.0, because each dot product also scales with the norms of the vectors involved. One possible reason it is still used as-is: matrix multiplication is easy to parallelize, so engineers may have favored speed over “similarity precision” (normalizing may add extra overhead).

## B) Bahdanau attention

<p align="center"> <img src="./assets/lily-bahdanau.png" height="500" width="900" /></p> 

- Attention mechanism (Bahdanau):
    - **Goal**: introduced to help the model cope with long source sentences in neural machine translation (NMT).
    - **Structure**: At different steps, let a model "focus" on different parts of the input. At each decoder step, it decides which source parts are more important. In this setting, the encoder does not have to compress the whole source into a single vector: it gives representations for all source tokens (for example, all RNN states instead of the last one).
- The whole process looks like this:
    - **Decoder `Hidden layer Nth`**:
        - Initialize the decoder hidden state with the last encoder output
        - Compute the **attention score**: use all encoder hidden states and the current decoder hidden state (**`decoder hidden state Nth`**)
        - Compute the **attention weights**: apply softmax to the attention score
        - Compute the **attention output**: weighted sum of all encoder states, using the attention weights
        - Pass the **attention output** and **`decoder hidden state Nth`** to the recurrent layer to get **`decoder hidden state Nth+1`** (i.e. `self.lstm(attention_output, hidden_nth)`)
    <p align="center"> <img src="./assets/bahdanau.png" height="500" width="900" /></p>
- So we can see that Bahdanau attention computes the score through a single-layer feed-forward neural network.
- Bahdanau attention (also known as additive attention or concat attention) is defined as [follows](https://paperswithcode.com/method/additive-attention): $f_{att}\left(\textbf{h}_{i}, \textbf{s}_{j}\right) = w_{a}^{T}\tanh\left(\textbf{W}_{a}\left[\textbf{h}_{i};\textbf{s}_{j}\right]\right)$ (1)
- It is sometimes also written as a sum: $f_{att}\left(\textbf{h}_{i}, \textbf{s}_{j}\right) = w_{a}^{T}\tanh\left(\textbf{W}_{a}\textbf{h}_{i} + \textbf{U}_{a}\textbf{s}_{j}\right)$ (2)
- This is because projecting (matmul) the concatenation of 2 vectors is equivalent to summing the projections of each vector! ([source](https://stats.stackexchange.com/a/524729))
    > - Note: the $\textbf{W}_{a}$ in eq (1) and (2) are different matrices. It would be clearer to rewrite (2) as $f_{att}\left(\textbf{h}_{i}, \textbf{s}_{j}\right) = w_{a}^{T}\tanh\left(\textbf{T}_{a}\textbf{h}_{i} + \textbf{B}_{a}\textbf{s}_{j}\right)$ with $\textbf{T}_{a}$ being the "Top part" and $\textbf{B}_{a}$ the "Bottom part" of the $\textbf{W}_{a}$ from eq (1) (top/bottom when $\textbf{W}_{a}$ multiplies row vectors; with column vectors these are its left and right column blocks)
    > <p align="center"> <img src="./assets/concat-add-bahdanau.png" height="500" width="900" /></p>
    > - That's why they have different names (additive or concat attention)! 
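
The equivalence between eq (1) and eq (2) can be checked numerically. The sketch below is only an illustration: the sizes `d_h`, `d_s`, `d_a` and the random values are arbitrary assumptions, and the column-vector convention is used, so the two blocks of `W` are its left and right column blocks.

```python
import numpy as np

rng = np.random.default_rng(0)
d_h, d_s, d_a = 4, 3, 5                  # hypothetical sizes: encoder, decoder, attention

h = rng.normal(size=d_h)                 # one encoder hidden state h_i
s = rng.normal(size=d_s)                 # one decoder hidden state s_j
W = rng.normal(size=(d_a, d_h + d_s))    # W_a from eq (1)
w = rng.normal(size=d_a)                 # w_a

# Eq (1): project the concatenation [h; s]
score_concat = w @ np.tanh(W @ np.concatenate([h, s]))

# Eq (2): split W into the block that multiplies h and the block that multiplies s
W_h, W_s = W[:, :d_h], W[:, d_h:]
score_sum = w @ np.tanh(W_h @ h + W_s @ s)

same = np.allclose(score_concat, score_sum)
print(same)
```

The two scores agree up to floating-point rounding, which is why "concat" and "additive" describe the same score function.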


In [6]:
# TODO: code
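
As a starting point, here is a minimal NumPy sketch of the score, softmax and weighted-sum steps listed above, using the additive form of eq (2). It is a sketch under assumptions, not a reference implementation: the function name `additive_attention`, the shapes and the random parameters are hypothetical, nothing is trained, and the recurrent update (`self.lstm(...)`) is left out.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    exp = np.exp(x - x.max())
    return exp / exp.sum()


def additive_attention(enc_states, dec_state, W_a, U_a, w_a):
    """One attention step with the additive score of eq (2).

    enc_states: (T, d_h)  all encoder hidden states h_1..h_T
    dec_state:  (d_s,)    current decoder hidden state s
    W_a: (d_a, d_h), U_a: (d_a, d_s), w_a: (d_a,)
    """
    # Attention score for every source position: w_a^T tanh(W_a h_i + U_a s)
    scores = np.tanh(enc_states @ W_a.T + dec_state @ U_a.T) @ w_a   # (T,)
    # Attention weights: softmax over the source positions
    weights = softmax(scores)                                        # (T,)
    # Attention output: weighted sum of the encoder states
    context = weights @ enc_states                                   # (d_h,)
    return context, weights


# Hypothetical sizes and random (untrained) parameters, for illustration only
rng = np.random.default_rng(0)
T, d_h, d_s, d_a = 6, 4, 3, 5
enc_states = rng.normal(size=(T, d_h))
dec_state = rng.normal(size=d_s)
W_a = rng.normal(size=(d_a, d_h))
U_a = rng.normal(size=(d_a, d_s))
w_a = rng.normal(size=d_a)

context, weights = additive_attention(enc_states, dec_state, W_a, U_a, w_a)
print(weights.shape, context.shape)
```

In a full decoder, `context` would then be fed, together with the current decoder hidden state, to the recurrent layer to produce the next hidden state.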

## C) CNN + attention

In [7]:
# TODO: https://towardsdatascience.com/transformers-141e32e69591

In [1]:
# TODO: https://github.com/bentrevett/pytorch-seq2seq/blob/master/5%20-%20Convolutional%20Sequence%20to%20Sequence%20Learning.ipynb

## D) Towards a unified model of attention

- TODO:
    - [How to truly understand attention mechanism in transformers?](https://www.reddit.com/r/MachineLearning/comments/qidpqx/d_how_to_truly_understand_attention_mechanism_in/)
    - https://johnthickstun.com/docs/transformers.pdf

> - Source:
>   - [A review on the attention mechanism of deep learning](https://www.sciencedirect.com/science/article/abs/pii/S092523122100477X)

- The way we compute Bahdanau attention can be reframed in a more unified way. Researchers call it the "unified model of attention", which is divided into 2 steps:
    - Compute attention distribution on input information **(Green)**
    - Compute context vector using attention distribution **(Red)**
    <p align="center"> <img src="./assets/unified_model.png" height="500" width="1100" /></p>

#### Computing the attention distribution **(Green)**

- Here is the process to compute attention distribution:
    - **Key `(K)`**: encoding of input information. Can take different forms:
        - Certain area of an image
        - Word embedding of a document
        - Hidden states of RNN (cf Bahdanau attention) 
    - **Query `(Q)`**: corresponds to a sentence representation of the immediate token history. In Bahdanau attention, this is the previous decoder hidden state (`hidden_next = self.lstm(attention_output, hidden_prev)`). Can take different forms:
        - Matrix
        - 2 vectors
    - **Score function `(e)`**: defines how **Queries** and **Keys** are matched/combined to reflect how important each **Key** is to a given **Query** when deciding the next output: `e = f(Q, K)`
        <p align="center"> <img src="./assets/all_attention.png" height="500" width="1000" /></p>
        
        > - The 2 most commonly used score functions are:
        >   - additive attention
        >   - multiplicative (dot-product) attention (because it's less expensive to compute)
    - **Attention output** (the attention distribution, called *attention weights* in the Bahdanau section): obtained by normalizing the scores to a probability distribution using softmax

#### Computing the context vector using the attention distribution **(Red)**

- Here is the process to compute context vector using attention distribution:
    - **Value (V)**: In many architectures, **Keys** and **Values** are the same representation of input data.
    - **Context vector (C):** $C = \phi(\text{attention\_output}, \text{values})$. Usually, $\phi$ is a weighted sum.
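
The two steps of the unified model can be sketched as a single function with a pluggable score function. This is a minimal sketch under assumptions: the names `attention` and `dot_score`, the shapes and the random inputs are hypothetical, the dot-product score is used only because it is the simplest choice, and $\phi$ is taken to be a weighted sum.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    exp = np.exp(x - x.max(axis=axis, keepdims=True))
    return exp / exp.sum(axis=axis, keepdims=True)


def dot_score(Q, K):
    """Multiplicative (dot-product) score: one score per (query, key) pair."""
    return Q @ K.T                      # (n_q, n_k)


def attention(Q, K, V, score_fn=dot_score):
    # Step 1 (Green): attention distribution over the keys
    e = score_fn(Q, K)                  # scores, shape (n_q, n_k)
    a = softmax(e, axis=-1)             # each row sums to 1
    # Step 2 (Red): context vectors, with phi chosen as a weighted sum of the values
    C = a @ V                           # (n_q, d_v)
    return C, a


rng = np.random.default_rng(0)
K = rng.normal(size=(5, 4))             # 5 keys of size 4 (e.g. encoder states)
V = K                                   # keys and values share the same representation
Q = rng.normal(size=(2, 4))             # 2 queries of size 4

C, a = attention(Q, K, V)
print(a.shape, C.shape)
```

Swapping `score_fn` for an additive score would give Bahdanau-style attention without touching the rest of the function.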

---
- Transformer attention uses the `Q K V` notation of the unified model. We can rewrite Bahdanau attention with the same notation:
    - $f_{att}\left(\textbf{h}_{i}, \textbf{s}_{j}\right) = f_{att}\left(\textbf{K}, \textbf{Q}\right) = w_{a}^{T}\tanh\left(\textbf{W}_{a}\left[\textbf{K};\textbf{Q}\right]\right)$ (3)
    - where `Q=s` is the last decoder hidden state and `K=h` is the set of all encoder hidden states ([source](https://d2l.ai/chapter_attention-mechanisms-and-transformers/bahdanau-attention.html#model)).
    > - In Bahdanau, `V` = `K`

----
- Summary: <p align="center"> <img src="assets/part2-summary.png" height="400" width="700" /></p> 