# II/ Seq2seq with attention

## A) Attention in general

> - Source:
>   - [Transformers from scratch - Peter Bloem](https://peterbloem.nl/blog/transformers)

- Before talking about Bahdanau attention, let's understand what people mean by attention.
- Attention is easiest to explain through the lens of movie recommendation.
- Let’s say you run a movie rental business with some movies and some users, and you would like to recommend to your users the movies they are likely to enjoy.
- One way to go about this is to:
    - create manual features for your movies: how much romance there is in the movie, and how much action
    - create manual features for your users: how much they enjoy romantic movies and how much they enjoy action-based movies. 
- If you did this, the dot product between the two feature vectors would give you a score for how well the attributes of the movie match what the user enjoys.
<p align="center"> <img src="./assets/dot_product.svg" height="500" width="1100" /></p> 
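
A tiny sketch of this idea with made-up numbers. The movie titles, feature values and user below are hypothetical and not taken from any dataset. Each vector is `[romance, action]`, and for the user a negative value means they dislike that genre.

```python
import numpy as np

# Hypothetical hand-made movie features: [romance, action]
movies = {
    "love_story": np.array([0.9, 0.1]),
    "explosions": np.array([0.1, 0.9]),
}
# Hypothetical user who enjoys romance (+) and dislikes action (-)
user = np.array([0.8, -0.5])

def score(user_features, movie_features):
    """Dot product: sum over features of (user preference * amount in the movie)."""
    return float(np.dot(user_features, movie_features))

scores = {title: score(user, feats) for title, feats in movies.items()}
for title, s in scores.items():
    print(f"{title}: {s:.2f}")

# Per-feature terms of the dot product: liked features push the score up, disliked ones pull it down
contributions = user * movies["explosions"]
print("explosions, per-feature terms:", contributions)

best = max(scores, key=scores.get)
print("Recommend:", best)
```

The romance-heavy movie gets a positive score (0.67) while the action movie gets a negative one (-0.37), because the large "action" term is multiplied by a negative preference.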

- For example, if:
    - the user enjoys romance and the movie has a lot of romance, then the product for that feature (its contribution to the dot product) will be positive.
    - the user hates romance and the movie has a lot of romance, then the product for that feature will be negative.
- This is the basic intuition behind attention: the dot product expresses how related two vectors are, which lets us represent relations between objects.
- How is the dot product expressed in neural networks? Through matrix multiplication, which is just many dot products computed at once: entry `(i, j)` of `X @ Y` is the dot product of row `i` of `X` with column `j` of `Y`.
- However, matrix multiplication does not normalize its inputs, so the raw scores also depend on the lengths of the vectors. As a result, if we compute the similarities between the columns of `A` with `A.T @ A`, the diagonal will not contain 1.0 as we would expect, and a column is not even guaranteed to score highest against itself (even though self-similarity should be maximal).
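
A quick check that a matrix product is a grid of dot products. The matrices are random and their sizes arbitrary, chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.random((3, 4))  # 3 vectors (rows) with 4 features each
Y = rng.random((4, 2))  # 2 vectors (columns) with 4 features each

# Entry (i, j) computed explicitly as the dot product of row i of X with column j of Y
loop_result = np.empty((3, 2))
for i in range(3):
    for j in range(2):
        loop_result[i, j] = np.dot(X[i], Y[:, j])

# The same 6 dot products computed in a single vectorized call
matmul_result = X @ Y
print(np.allclose(loop_result, matmul_result))  # True
```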

```python
import numpy as np

np.random.seed(42)
np.set_printoptions(precision=3)

A = np.array([
    [0.375, 0.951, 0.732, 0.599, 0.156, 0.156],
    [0.058, 0.866, 0.601, 0.708, 0.021, 0.97 ],
    [0.832, 0.212, 0.182, 0.183, 0.304, 0.525],
    [0.432, 0.291, 0.612, 0.139, 0.292, 0.366],
    [0.456, 0.785, 0.2,   0.514, 0.592, 0.046],
    [0.608, 0.171, 0.065, 0.949, 0.966, 0.808]
])

print(f"A = \n{A}")
print("--------------------")
# Divide each column by its L2 norm, so that every column of B has norm 1
n = np.linalg.norm(A, ord=2, axis=0)
B = A / n

print(f"norm of A = {n}")
print(f"Normalized A = \n{B}")
print("--------------------")

# The norms were computed on axis=0, so each column is one vector:
# B.T @ B holds the dot products (cosine similarities) between every pair of columns
print(f"Normalized dot product: \n{B.T @ B}")
# For unit vectors no score can exceed the self-similarity of 1.0 (Cauchy-Schwarz),
# so the maximum of each row sits on the diagonal
print(f"Indices of maximum value = {np.argmax(B.T @ B, axis=1)}")

print("--------------------")
print(f"Unnormalized dot product: \n{A.T @ A}")
print(f"Indices of maximum value = {np.argmax(A.T @ A, axis=1)}")

np.set_printoptions()
```

```
A = 
[[0.375 0.951 0.732 0.599 0.156 0.156]
 [0.058 0.866 0.601 0.708 0.021 0.97 ]
 [0.832 0.212 0.182 0.183 0.304 0.525]
 [0.432 0.291 0.612 0.139 0.292 0.366]
 [0.456 0.785 0.2   0.514 0.592 0.046]
 [0.608 0.171 0.065 0.949 0.966 0.808]]
--------------------
norm of A = [1.265 1.559 1.161 1.441 1.219 1.425]
Normalized A = 
[[0.296 0.61  0.63  0.416 0.128 0.109]
 [0.046 0.556 0.517 0.491 0.017 0.681]
 [0.658 0.136 0.157 0.127 0.249 0.368]
 [0.341 0.187 0.527 0.096 0.24  0.257]
 [0.36  0.504 0.172 0.357 0.486 0.032]
 [0.481 0.11  0.056 0.658 0.792 0.567]]
--------------------
Normalized dot product: 
[[1.    0.594 0.583 0.707 0.84  0.678]
 [0.594 1.    0.885 0.814 0.498 0.622]
 [0.583 0.885 1.    0.685 0.383 0.652]
 [0.707 0.814 0.685 1.    0.811 0.836]
 [0.84  0.498 0.383 0.811 1.    0.644]
 [0.678 0.622 0.652 0.836 0.644 1.   ]]
Indices of maximum value = [0 1 2 3 4 5]
--------------------
Unnormalized dot product: 
[[1.6   1.171 0.856 1.289 1.296 1.222]
 [1.171 2.429 1.601 1.828 0.9
```

- As we can see, the raw matrix product does not properly reflect the notion of “similarity”: a vector is not guaranteed to be the most similar to itself. One possible reason attention uses it anyway is that matrix multiplication is cheap and easy to parallelize, so engineers may have favored speed over “similarity precision” (normalizing adds extra computation).
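
A minimal 2-D sketch of this effect, using two made-up vectors: `b` points almost in the same direction as `a` but is about 3 times longer. The raw dot product ranks `b` above `a` itself, while cosine similarity (the dot product of the normalized vectors) keeps the self-similarity maximal.

```python
import numpy as np

# Hypothetical vectors: b is nearly parallel to a, but much longer
a = np.array([1.0, 0.0])
b = np.array([3.0, 0.3])

raw_self = a @ a    # 1.0
raw_other = a @ b   # 3.0, larger than the self-similarity
print(f"raw:    a.a = {raw_self:.3f}, a.b = {raw_other:.3f}")

def cosine(u, v):
    """Dot product of the L2-normalized vectors."""
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))

cos_self = cosine(a, a)    # 1.0
cos_other = cosine(a, b)   # slightly below 1.0
print(f"cosine: cos(a, a) = {cos_self:.3f}, cos(a, b) = {cos_other:.3f}")
```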

## B) Bahdanau attention

<p align="center"> <img src="./assets/lily-bahdanau.png" height="500" width="900" /></p> 

- Attention mechanism (Bahdanau):
    - **Goal**: help memorize long source sentences in neural machine translation (NMT).
    - **Structure**: At different steps, let a model "focus" on different parts of the input. At each decoder step, it decides which source parts are more important. In this setting, the encoder does not have to compress the whole source into a single vector - it gives representations for all source tokens (for example, all RNN states instead of the last one).
- The whole process looks like this:
    - **Decoder step `N`**:
        - At the first step, initialize the decoder hidden state with the last encoder hidden state
        - Compute the **attention scores**: score every encoder hidden state against the current decoder hidden state (`decoder hidden state N`)
        - Compute the **attention weights**: apply a softmax to the attention scores
        - Compute the **attention output** (context vector): the weighted sum of all encoder hidden states, using the attention weights
        - Concatenate the **attention output** with the embedding of the current input token, project it, and pass it together with **`decoder hidden state N`** to the recurrent cell to get **`decoder hidden state N+1`** (i.e. `self.gru(output, decoder_hidden)` in the code below)
    <p align="center"> <img src="./assets/bahdanau.png" height="500" width="900" /></p>
- As the figure shows, Bahdanau attention computes the score with a feed-forward network that has a single hidden layer (a `tanh` layer followed by a projection to a scalar).
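
The first three steps of one decoder step can be sketched in NumPy as follows. This is a sketch under assumptions: the sizes, the random (untrained) parameters and the names `W_enc`, `W_dec` and `v` are hypothetical, and the sketch stops before the recurrent update.

```python
import numpy as np

rng = np.random.default_rng(0)
src_len, hidden = 5, 8  # hypothetical sizes

# Hypothetical (untrained) inputs and parameters
encoder_states = rng.standard_normal((src_len, hidden))  # one hidden state per source token
decoder_state = rng.standard_normal(hidden)              # current decoder hidden state
W_enc = rng.standard_normal((hidden, hidden))            # projects each encoder state
W_dec = rng.standard_normal((hidden, hidden))            # projects the decoder state
v = rng.standard_normal(hidden)                          # maps each tanh feature vector to a scalar

# 1) Attention scores: one tanh layer plus a projection to a scalar, for every source token
scores = np.tanh(encoder_states @ W_enc.T + decoder_state @ W_dec.T) @ v  # (src_len,)

# 2) Attention weights: softmax over source positions (shifted by the max for numerical stability)
exp = np.exp(scores - scores.max())
weights = exp / exp.sum()                                                 # (src_len,)

# 3) Attention output (context vector): weighted sum of the encoder states
context = weights @ encoder_states                                         # (hidden,)

print("weights:", np.round(weights, 3), "sum =", round(float(weights.sum()), 6))
print("context shape:", context.shape)
```

The weights are positive and sum to 1, so the context vector is a convex combination of the encoder states; the decoder then feeds it (with the embedded input token) to its recurrent cell.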


```python
%load_ext autoreload
%autoreload 2

import random

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

# Helper module of this notebook; assumed to provide readLangs, filterPairs, tensorsFromPair,
# tensorFromSentence, MAX_LENGTH, SOS_token and EOS_token (they are used below but not defined here)
from attention_utils import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def prepareData(lang1, lang2, reverse=False):
    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(pairs))
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))
    # Keep only the first 100 pairs so that training is fast
    pairs = pairs[:100]
    print("Sampled %s sentence pairs (for faster training)" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)

    return {"input_lang": input_lang, "output_lang": output_lang, "pairs": pairs}

# reverse=True: French is the input language and English the output language
dataset = prepareData('eng', 'fra', True)
print(random.choice(dataset["pairs"]))
```

```
Reading lines...
Read 135842 sentence pairs
Trimmed to 10599 sentence pairs
Sampled 100 sentence pairs (for faster training)
Counting words...
Counted words:
fra 92
eng 62
['je suis certain .', 'i am sure .']
```


```python
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input, hidden):
        # The encoder is fed one token at a time: shape (seq_len=1, batch=1, hidden_size)
        embedded = self.embedding(input).view(1, 1, -1)
        # https://stackoverflow.com/a/48305882/8623609
        # GRU output: the hidden states of the final ("top") layer, for every time step
        # GRU hidden: the hidden states of every layer, for the last time step only ("last right column")
        last_layer_encoder_hidden_states, last_time_step_encoder_hidden_states = self.gru(embedded, hidden)
        return last_layer_encoder_hidden_states, last_time_step_encoder_hidden_states

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

class DecoderAttentionRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):
        super(DecoderAttentionRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        self.embedding = nn.Embedding(self.output_size, self.hidden_size)

        # Additive (Bahdanau) score: weight^T tanh(fc_hidden(decoder state) + fc_encoder(encoder state))
        self.fc_hidden = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.fc_encoder = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        # Explicitly initialized: an uninitialized tensor may contain arbitrary values (even NaN)
        self.weight = nn.Parameter(torch.empty(1, hidden_size))
        nn.init.xavier_uniform_(self.weight)
        self.attn_proj = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, decoder_input, decoder_hidden, last_layer_encoder_hidden_states_foreach_input):
        # decoder_hidden: (1, 1, H); encoder states: (max_length, H)
        embedded = self.embedding(decoder_input).view(1, 1, -1)
        embedded = self.dropout(embedded)

        # Broadcasting (1, 1, H) + (max_length, H) -> (1, max_length, H): one row per encoder position.
        # Note: the zero-padded positions beyond the input length are not masked out.
        x = torch.tanh(self.fc_hidden(decoder_hidden) + self.fc_encoder(last_layer_encoder_hidden_states_foreach_input))
        alignment_scores = x.bmm(self.weight.unsqueeze(2))            # (1, max_length, 1)
        attn_weights = F.softmax(alignment_scores.squeeze(2), dim=1)  # (1, max_length)
        # Weighted sum of the encoder states: (1, 1, max_length) x (1, max_length, H) -> (1, 1, H)
        context_vector = torch.bmm(attn_weights.unsqueeze(0), last_layer_encoder_hidden_states_foreach_input.unsqueeze(0))

        output = torch.cat((embedded, context_vector), -1).squeeze(0)  # (1, 2H)
        output = self.attn_proj(output).unsqueeze(0)                   # (1, 1, H)
        output = F.relu(output)
        last_layer_decoder_hidden_states, last_time_step_decoder_hidden_states = self.gru(output, decoder_hidden)
        # Log-probabilities over the output vocabulary: (1, output_size)
        log_probs = F.log_softmax(self.out(last_layer_decoder_hidden_states[0]), dim=1)
        return log_probs, last_time_step_decoder_hidden_states, attn_weights

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)
```
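
The line `x = torch.tanh(...)` in `DecoderAttentionRNN.forward` scores every encoder position at once through broadcasting. The NumPy sketch below checks that this matches scoring each position in an explicit loop. It uses random stand-in arrays with hypothetical sizes, not the trained model's tensors.

```python
import numpy as np

rng = np.random.default_rng(1)
max_length, H = 4, 3  # hypothetical sizes

projected_decoder = rng.standard_normal((1, 1, H))        # stands in for fc_hidden(decoder_hidden)
projected_encoder = rng.standard_normal((max_length, H))  # stands in for fc_encoder(encoder states)
w = rng.standard_normal((1, H))                           # stands in for self.weight

# Vectorized, as in forward(): (1, 1, H) + (max_length, H) broadcasts to (1, max_length, H)
x = np.tanh(projected_decoder + projected_encoder)
# Equivalent of x.bmm(weight.unsqueeze(2)) followed by squeezing to (max_length,)
vectorized_scores = (x @ w[:, :, None])[0, :, 0]

# Explicit loop: score each encoder position separately
loop_scores = np.array([
    np.tanh(projected_decoder[0, 0] + projected_encoder[j]) @ w[0]
    for j in range(max_length)
])

print(x.shape)                                      # (1, 4, 3)
print(np.allclose(vectorized_scores, loop_scores))  # True
```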

```python
teacher_forcing_ratio = 0.5

def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH):
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    # One row per source position; rows beyond input_length stay zero (padding)
    last_layer_encoder_hidden_states_foreach_input = torch.zeros(max_length, encoder.hidden_size, device=device)
    last_time_step_encoder_hidden_states = encoder.initHidden()
    loss = 0

    # Encode the source sentence one token at a time, keeping every encoder hidden state for the attention
    for i in range(input_length):
        last_layer_encoder_hidden_states, last_time_step_encoder_hidden_states = encoder(input_tensor[i], last_time_step_encoder_hidden_states)
        last_layer_encoder_hidden_states_foreach_input[i] = last_layer_encoder_hidden_states.squeeze()

    decoder_input = torch.tensor([[SOS_token]], device=device)
    # The decoder starts from the last encoder hidden state
    decoder_hidden = last_time_step_encoder_hidden_states

    use_teacher_forcing = random.random() < teacher_forcing_ratio

    if use_teacher_forcing:
        # Teacher forcing: feed the target token as the next input
        for i in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, last_layer_encoder_hidden_states_foreach_input)
            loss += criterion(decoder_output, target_tensor[i])
            decoder_input = target_tensor[i]
    else:
        # Without teacher forcing: use the model's own prediction as the next input
        for i in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, last_layer_encoder_hidden_states_foreach_input)
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()  # detach from history as input

            loss += criterion(decoder_output, target_tensor[i])
            if decoder_input.item() == EOS_token:
                break

    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length

def trainIters(encoder, decoder, dataset, n_iters, print_every=100):
    print_loss_total = 0  # Reset every print_every

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=0.001)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=0.001)
    # Draw n_iters random (input, target) pairs, with replacement
    training_pairs = [tensorsFromPair(dataset["input_lang"], dataset["output_lang"], random.choice(dataset["pairs"]))
                      for _ in range(n_iters)]

    # NLLLoss expects log-probabilities, which the decoder returns (log_softmax)
    criterion = nn.NLLLoss()

    for step in range(1, n_iters + 1):
        training_pair = training_pairs[step - 1]
        input_tensor = training_pair[0]
        target_tensor = training_pair[1]

        loss = train(input_tensor, target_tensor, encoder,
                     decoder, encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss

        if step % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print(f"Step: {step} ({step / n_iters * 100}%) Loss: {print_loss_avg}")

hidden_size = 256
encoder1 = EncoderRNN(dataset["input_lang"].n_words, hidden_size).to(device)
attn_decoder1 = DecoderAttentionRNN(hidden_size, dataset["output_lang"].n_words, dropout_p=0.1).to(device)

trainIters(encoder1, attn_decoder1, dataset, 1000)
```

```
Step: 100 (10.0%) Loss: 1.462863874832789
Step: 200 (20.0%) Loss: 0.9862939871946975
Step: 300 (30.0%) Loss: 0.69884103957812
Step: 400 (40.0%) Loss: 0.5858553817470871
Step: 500 (50.0%) Loss: 0.32183215591311454
Step: 600 (60.0%) Loss: 0.23454589036107062
Step: 700 (70.0%) Loss: 0.21993039281914625
Step: 800 (80.0%) Loss: 0.13522455321003998
Step: 900 (90.0%) Loss: 0.11338057126601536
Step: 1000 (100.0%) Loss: 0.08471504014978805
```


```python
def evaluate(encoder, decoder, dataset, sentence, max_length=MAX_LENGTH):
    with torch.no_grad():
        input_tensor = tensorFromSentence(dataset["input_lang"], sentence)
        input_length = input_tensor.size(0)

        last_time_step_encoder_hidden_states = encoder.initHidden()
        last_layer_encoder_hidden_states_foreach_input = torch.zeros(max_length, encoder.hidden_size, device=device)

        for i in range(input_length):
            last_layer_encoder_hidden_states, last_time_step_encoder_hidden_states = encoder(input_tensor[i], last_time_step_encoder_hidden_states)
            last_layer_encoder_hidden_states_foreach_input[i] = last_layer_encoder_hidden_states.squeeze()

        decoder_input = torch.tensor([[SOS_token]], device=device)  # SOS
        decoder_hidden = last_time_step_encoder_hidden_states

        decoded_words = []
        # Row i stores the attention weights used to produce output word i
        decoder_attentions = torch.zeros(max_length, max_length)

        # Greedy decoding: stop at <EOS> or after max_length words
        for i in range(max_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, last_layer_encoder_hidden_states_foreach_input)

            decoder_attentions[i] = decoder_attention
            topv, topi = decoder_output.topk(1)
            if topi.item() == EOS_token:
                decoded_words.append('<EOS>')
                break
            else:
                decoded_words.append(dataset["output_lang"].index2word[topi.item()])

            decoder_input = topi.squeeze().detach()

        return decoded_words, decoder_attentions[:i + 1]

def evaluateRandomly(encoder, decoder, dataset, n=3):
    for _ in range(n):
        pair = random.choice(dataset["pairs"])
        print("Input: ", pair[0])
        print("Expected:", pair[1])
        output_words, attentions = evaluate(encoder, decoder, dataset, pair[0])
        output_sentence = ' '.join(output_words)
        print("Pred: ", output_sentence)
        print('')

evaluateRandomly(encoder1, attn_decoder1, dataset)
```

```
Input:  il est mouille .
Expected: he s wet .
Pred:  he s wet . <EOS>

Input:  je suis tatillonne .
Expected: i m fussy .
Pred:  i m fussy . <EOS>

Input:  je suis repu !
Expected: i m full .
Pred:  i m full . <EOS>
```



- Bahdanau attention (also known as additive attention or concat attention) is defined as [follows](https://paperswithcode.com/method/additive-attention): $f_{att}\left(\textbf{h}_{i}, \textbf{s}_{j}\right) = w_{a}^{T}\tanh\left(\textbf{W}_{a}\left[\textbf{h}_{i};\textbf{s}_{j}\right]\right)$ (1)
- It is also often written as a sum: $f_{att}\left(\textbf{h}_{i}, \textbf{s}_{j}\right) = w_{a}^{T}\tanh\left(\textbf{W}_{a}\textbf{h}_{i} + \textbf{U}_{a}\textbf{s}_{j}\right)$ (2)
- This is because projecting (matrix-multiplying) the concatenation of two vectors gives the same result as summing the projections of each vector by the corresponding block of the matrix! ([source](https://stats.stackexchange.com/a/524729))
    > - Note: the $\textbf{W}_{a}$ in eq (1) and (2) are different matrices, so it would be clearer to rewrite (2) as $f_{att}\left(\textbf{h}_{i}, \textbf{s}_{j}\right) = w_{a}^{T}\tanh\left(\textbf{T}_{a}\textbf{h}_{i} + \textbf{B}_{a}\textbf{s}_{j}\right)$, with $\textbf{T}_{a}$ being the "top part" and $\textbf{B}_{a}$ the "bottom part" of the same $\textbf{W}_{a}$ from eq (1), i.e. the blocks that multiply $\textbf{h}_{i}$ and $\textbf{s}_{j}$ respectively
    > <p align="center"> <img src="./assets/concat-add-bahdanau.png" height="500" width="900" /></p>
    
    > - That's why the same mechanism goes by two names: additive attention and concat attention!
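
In block-matrix form, with column vectors, the identity behind eq (1) and (2) reads $\textbf{W}_{a}\left[\textbf{h}_{i};\textbf{s}_{j}\right] = \left[\textbf{T}_{a}\ \textbf{B}_{a}\right]\begin{bmatrix}\textbf{h}_{i}\\ \textbf{s}_{j}\end{bmatrix} = \textbf{T}_{a}\textbf{h}_{i} + \textbf{B}_{a}\textbf{s}_{j}$. In this convention $\textbf{W}_{a}$ splits column-wise. The blocks become the "top" and "bottom" parts when the product is written with row vectors, i.e. with the transposed matrix. The NumPy check below compares both scores numerically. It uses random values and hypothetical sizes.

```python
import numpy as np

rng = np.random.default_rng(0)
d_h, d_s, d_a = 4, 3, 5  # hypothetical sizes of h_i, s_j and of the attention layer

h = rng.standard_normal(d_h)
s = rng.standard_normal(d_s)
W = rng.standard_normal((d_a, d_h + d_s))  # W_a from eq (1), acting on the concatenation [h; s]
w = rng.standard_normal(d_a)               # w_a

# Eq (1): concat attention
concat_score = w @ np.tanh(W @ np.concatenate([h, s]))

# Eq (2): split W_a into the block that multiplies h and the block that multiplies s
T, B = W[:, :d_h], W[:, d_h:]
additive_score = w @ np.tanh(T @ h + B @ s)

print(np.isclose(concat_score, additive_score))  # True
```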

----
- Summary: <p align="center"> <img src="assets/part2-summary.png" height="400" width="700" /></p> 