## Bahdanau et al. (2016)

https://arxiv.org/abs/1409.0473

https://machinelearningmastery.com/the-bahdanau-attention-mechanism/

Say we have an input sequence (usually text) of $n$ elements:

$$ X = x_1 , x_2 , \dots , x_n $$

which needs to be aligned with an output sequence of $m$ elements (the two lengths need not be equal):

$$ Y = y_1 , y_2 , \dots , y_m $$

The model is an encoder-decoder network, whose encoder in the paper is a bi-directional RNN. Traditional encoders create a single fixed-size vector from the input sequence, which limits the amount of information the decoder can use. Furthermore, the encoder has to compress all of its hidden states into that one vector, which can lose information, especially for long sentences.

What attention does:
- Finds the probability that an input element $x_j$ is aligned with an output element $y_i$
- The original work was on machine translation, hence the framing as alignment modelling

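Say the encoder hidden state for input position $j$ is $h_j$, and the previous decoder hidden state is $s_{i - 1}$.

Then the attention probability of $y_i$ for $x_j$ is a softmax over the scores $e_{ij}$ (softmax since it is a probability!):

$$ \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k = 1}^{n} \exp(e_{ik})} $$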
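As a small numerical illustration of the softmax step, here is a minimal sketch. The scores in `e_i` are made up for the example and do not come from any model.

```python
import torch
import torch.nn.functional as F

# Made-up alignment scores e_ij for one output position i and n = 4 input positions
e_i = torch.tensor([2.0, 0.5, -1.0, 0.5])

# The softmax runs over the input positions j, so for a fixed i the weights sum to 1
alpha_i = F.softmax(e_i, dim=-1)

print(alpha_i)
print(alpha_i.sum())
```

The largest score gets the largest weight, and equal scores get equal weights.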

Now what is $e$ here? The encoder produces a hidden state for every element in $X$, and attention doesn't just compare one element from each sequence. It compares one output position against every input position and asks: how related is the element at position $i$ in the output to each of the elements $1, \dots, n$ in the input?

Why from output to input? In machine translation you are looking for a proper translation of a source sentence. In other words, you want your decoder to search the source for the parts that matter for the next word, and attention provides that search information. Moving on: $e$, the attention score, comes from the attention module (the alignment model $a$), which is a feed-forward network trained jointly with the rest of the system.

$e_{ij} = a(s_{i -1}, h_j)$

Then a context vector $c_i$ is computed as the sum of the encoder hidden states $h_j$, weighted by $\alpha_{ij}$:

$$
 c_i = \sum_{j = 1}^{n} \alpha_{ij}h_j
$$
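The weighted sum can be written as a single matrix product. The sketch below checks that against an explicit loop over $j$; the sizes `n` and `d` and the random tensors are placeholders chosen for illustration.

```python
import torch

torch.manual_seed(0)

n, d = 4, 6                       # hypothetical sequence length and hidden size
h = torch.randn(n, d)             # stand-in for the encoder states h_1 .. h_n
alpha_i = torch.softmax(torch.randn(n), dim=-1)  # stand-in attention weights for one i

# Explicit sum over j, exactly as in the formula
c_loop = torch.zeros(d)
for j in range(n):
    c_loop += alpha_i[j] * h[j]

# Same computation as a vector-matrix product: (n,) @ (n, d) -> (d,)
c_i = alpha_i @ h

print(torch.allclose(c_i, c_loop, atol=1e-6))
```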

In layman's terms, $c_i$ is a summary of the entire input $X$ in which each $h_j$ counts in proportion to how relevant it is for producing $y_i$.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

So to be more specific:

1. The input goes to the encoder, which gives a hidden state $h_j$ for each input position
2. Each $h_j$ goes to the alignment module together with the previous decoder state $s_{i - 1}$, which gives the attention score $e_{ij}$
3. Take the softmax of the attention scores over $j$ to get the weights $\alpha_{ij}$
4. Get the context vector $c_i$ as the sum of the encoder hidden states, weighted by $\alpha_{ij}$
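The steps above can be sketched end to end for a single decoder step. This is a minimal sketch under stated assumptions, not the paper's full model: the class name `AdditiveAttention`, the attention width `attn_dim=40` and the all-zero `s_prev` are hypothetical choices, the layers are untrained, and the encoder's per-position outputs are used as the $h_j$.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class AdditiveAttention(nn.Module):
    """Scores e_ij = v^T tanh(W s_{i-1} + U h_j) (hypothetical helper class)."""

    def __init__(self, dec_dim: int, enc_dim: int, attn_dim: int):
        super().__init__()
        self.W = nn.Linear(dec_dim, attn_dim, bias=False)  # projects s_{i-1}
        self.U = nn.Linear(enc_dim, attn_dim, bias=False)  # projects every h_j
        self.v = nn.Linear(attn_dim, 1, bias=False)        # reduces to one score

    def forward(self, s_prev: torch.Tensor, h: torch.Tensor):
        # s_prev: (dec_dim,), h: (n, enc_dim); W(s_prev) broadcasts over the n rows
        e = self.v(torch.tanh(self.W(s_prev) + self.U(h))).squeeze(-1)  # (n,)
        alpha = torch.softmax(e, dim=-1)                                # (n,)
        c = alpha @ h                                                   # (enc_dim,)
        return c, alpha


embedding = nn.Embedding(100, 20)
encoder = nn.RNN(20, 30, bidirectional=True)
attn = AdditiveAttention(dec_dim=30, enc_dim=60, attn_dim=40)

tokens = torch.arange(50)  # dummy input of 50 token ids
with torch.no_grad():
    # First return value: one state per input position, both directions concatenated
    annotations, _ = encoder(embedding(tokens))  # (50, 60)
    s_prev = torch.zeros(30)                     # placeholder for s_{i-1}
    c_i, alpha_i = attn(s_prev, annotations)

print(annotations.shape, alpha_i.shape, c_i.shape)
```

Here there is one weight per input position (50 of them), and the context vector has the same width as an encoder state (60).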


Let's generate some random tensors for the hidden states of encoder and the decoder.

In [2]:
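# dummy input: 50 token ids
x = torch.arange(50)

embedding = nn.Embedding(100, 20)
encoder = nn.RNN(20, 30, bidirectional=True)
decoder = nn.RNN(30, 30)  # not run below; only its hidden size (30) matters here

with torch.no_grad():
    x = embedding(x)  # (50, 20), unbatched
    # nn.RNN returns (output, h_n). h_n is the final hidden state of each
    # direction, shape (2, 30), and is used here as a small stand-in for the h_j.
    # The per-position states are in `output`, shape (50, 60).
    _, h_j = encoder(x)

# no decoder step has been run, so h_j doubles as a stand-in for s_{i-1}
s_previous = h_j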

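There are two ways to write the attention score in terms of the hidden states.

1. Concatenate both and pass them through a FFN with a single weight matrix, $W$ (concat; the code below calls this mode `"dot"`, although no dot product between the two states is involved)
2. Or, have one weight matrix per hidden state, apply them separately and add the results (additive). This is the form used in the paper.

The two describe the same function: splitting $W$ column-wise into $[W_1, W_2]$ gives $W[h_j; s_{i - 1}] = W_1 h_j + W_2 s_{i - 1}$.

In both cases the result is passed through a tanh activation, and a weight vector $v$ then reduces it to a single scalar score:

$$ e_{ij} = v^\top \tanh(W_1 h_j + W_2 s_{i - 1}) $$

The weight matrices and the weight vector are implemented as linear layers.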
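The concat and additive forms give the same score when the single matrix is the column-wise concatenation of the two separate ones. The sketch below checks this numerically; the dimensions `d_h`, `d_s`, `d_a` and the random weights are arbitrary placeholders.

```python
import torch

torch.manual_seed(0)

d_h, d_s, d_a = 6, 4, 5           # hypothetical sizes: encoder, decoder, attention
h_j = torch.randn(d_h)
s_prev = torch.randn(d_s)

W1 = torch.randn(d_a, d_h)        # acts on the encoder state
W2 = torch.randn(d_a, d_s)        # acts on the decoder state
v = torch.randn(d_a)              # weight vector that produces the scalar score

# Additive form: project each state separately, then add
e_add = v @ torch.tanh(W1 @ h_j + W2 @ s_prev)

# Concat form: one matrix W = [W1 | W2] applied to the concatenated states
W = torch.cat([W1, W2], dim=1)                    # (d_a, d_h + d_s)
e_cat = v @ torch.tanh(W @ torch.cat([h_j, s_prev]))

print(torch.allclose(e_add, e_cat, atol=1e-5))
```

So the choice between the two is about how the parameters are organised, not about what the scoring function can express.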

In [3]:
@torch.no_grad()
def attention(s_previous: torch.Tensor, h_j: torch.Tensor, mode: str = "dot") -> torch.Tensor:
    # NOTE: the layers are created (randomly initialised) on every call and never
    # trained; this only demonstrates the shapes and the order of operations.
    if mode == "dot":
        # concat form (the mode name "dot" is kept for the calls below)
        hidden_states = torch.cat([h_j, s_previous], dim=-1)

        # single weight matrix W as a linear layer with no bias;
        # its output width is a free choice, since V reduces it to one score per row
        W = nn.Linear(hidden_states.size(-1), h_j.size(-1), bias=False)
        projected = torch.tanh(W(hidden_states))
    elif mode == "add":
        # additive form: one weight matrix per hidden state;
        # both project to the same width so the results can be added
        W1 = nn.Linear(h_j.size(-1), h_j.size(-1), bias=False)
        W2 = nn.Linear(s_previous.size(-1), h_j.size(-1), bias=False)
        projected = torch.tanh(W1(h_j) + W2(s_previous))
    else:
        raise ValueError(f"unknown mode: {mode!r}")

    # weight vector v maps each row to a scalar score e_ij, shape (n, 1)
    V = nn.Linear(projected.size(-1), 1, bias=False)
    attention_scores = V(projected)

    # transpose to (1, n) so the softmax runs over the input positions
    alpha = F.softmax(attention_scores.T, dim=-1)
    # context vector: weighted sum of the rows of h_j, shape (1, d)
    c = alpha @ h_j

    return c

In [4]:
# concat
ctx = attention(s_previous, h_j)
print(ctx)
print(ctx.size())

tensor([[-0.1915, -0.7699,  0.0801,  0.0558,  0.6161,  0.0018,  0.2433,  0.1940,
         -0.2244,  0.4273,  0.1298, -0.3667, -0.0507, -0.4133,  0.0799,  0.5204,
          0.2218,  0.0478, -0.0245,  0.6666,  0.2871,  0.4599,  0.4257,  0.6712,
          0.0661, -0.1901, -0.1643, -0.4719,  0.1610, -0.5516]])
torch.Size([1, 30])


In [5]:
# additive
ctx = attention(s_previous, h_j, "add")
print(ctx)
print(ctx.size())

tensor([[-0.2746, -0.7863,  0.0674,  0.0451,  0.6449,  0.0602,  0.3130,  0.1456,
         -0.2343,  0.4156,  0.0650, -0.3128, -0.0150, -0.3874,  0.1069,  0.4908,
          0.2967,  0.0543, -0.0060,  0.6706,  0.3313,  0.4690,  0.3695,  0.6550,
          0.0606, -0.1467, -0.1490, -0.5003,  0.1544, -0.5247]])
torch.Size([1, 30])
