# Attention Mechanism

## RNN recap

In a basic RNN, each recurrent neuron receives the input for the current time step as well as the outputs of all recurrent neurons from the previous time step. This feedback loop is what makes the network 'recurrent'.
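
As a rough illustration of this recurrence, the sketch below implements one step of a basic RNN in NumPy and unrolls it over a short sequence. The names `rnn_step`, `W_x`, `W_h`, and `b`, the tanh activation, and the toy sizes are assumptions made for this example; they are not taken from the model used later in this chapter.

```python
import numpy as np

def rnn_step(x_t, h_prev, W_x, W_h, b):
    """One step of a basic RNN: mix the current input with the previous hidden state."""
    return np.tanh(x_t @ W_x + h_prev @ W_h + b)

rng = np.random.default_rng(0)
n_in, n_hidden, seq_len = 4, 3, 5   # toy sizes (hypothetical)

W_x = rng.normal(scale=0.5, size=(n_in, n_hidden))      # input-to-hidden weights
W_h = rng.normal(scale=0.5, size=(n_hidden, n_hidden))  # hidden-to-hidden (recurrent) weights
b = np.zeros(n_hidden)

xs = rng.normal(size=(seq_len, n_in))  # one toy input sequence
h = np.zeros(n_hidden)                 # initial hidden state

for x_t in xs:
    # Through W_h, every hidden unit sees all hidden units of the previous step
    h = rnn_step(x_t, h, W_x, W_h, b)

print(h.shape)
```

The same weights are reused at every time step; only the hidden state `h` changes as the sequence is read.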

In [3]:
### This cell should be hidden in the final version

import tensorflow as tf
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from keras.utils import pad_sequences
from jupyterquiz import display_quiz
from sklearn.metrics import accuracy_score
from keras.datasets import imdb


git_path="https://raw.githubusercontent.com/ChaosTheLegend/ML-Book/main/Quizes/"

num_words = 10000
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=num_words)

# Pad or truncate every review to exactly max_len tokens
max_len = 200
x_train = pad_sequences(x_train, maxlen=max_len, truncating='post')
x_test = pad_sequences(x_test, maxlen=max_len, truncating='post')

embedding_dim = 100
hidden_dim = 256
output_dim = 1
dropout_rate = 0.5


# model = tf.keras.Sequential([
#    tf.keras.layers.Embedding(input_dim=num_words, output_dim=embedding_dim, input_length=max_len),
#    tf.keras.layers.SimpleRNN(hidden_dim),
#    tf.keras.layers.Dense(output_dim, activation='sigmoid')
# ])

# model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Model is loaded from the file to save time

simpleRNN = tf.keras.models.load_model('simpleRNN.keras')

In [None]:
# make predictions and calculate accuracy

y_pred = simpleRNN.predict(x_test)
# Threshold the sigmoid outputs at 0.5 to get class labels (0 or 1)
y_pred = (y_pred > 0.5).astype(int).ravel()

simple_accuracy = accuracy_score(y_test, y_pred)

In [None]:
accuracy_epochs = pd.read_csv('simplernn_accuracy.csv')

plt.plot(accuracy_epochs['epoch'], accuracy_epochs['accuracy'])

plt.title('Accuracy of Simple RNN')

plt.xlabel('Epoch')

plt.ylabel('Accuracy')

# add final accuracy to the plot

plt.text(17, 0.52, 'Final Accuracy: ' + str(round(simple_accuracy, 2)))

## The Need for Attention Mechanism

Basic RNNs struggle with long sequences.

Training for more epochs barely improves the model's accuracy, because the model cannot learn the long-term dependencies in the data.

The underlying cause is known as the vanishing gradient problem.

### Vanishing Gradient Problem

The vanishing gradient problem occurs when the gradients of the loss function become progressively smaller as they are propagated back through time. The further apart an input and an output are in time, the weaker the learning signal that connects them.

This leads to the model "forgetting" the information from the earlier inputs, which makes it difficult to learn long-term dependencies.

![Simple RNN](https://raw.githubusercontent.com/ChaosTheLegend/ML-Book/main/Images/SimpleRNN.png)

### Math Behind Vanishing Gradient Problem

The vanishing gradient problem occurs because of the way gradients are computed in RNNs:

$$
\frac{\partial L}{\partial W} = \frac{\partial L}{\partial y} \frac{\partial y}{\partial h} \frac{\partial h}{\partial W}
$$

The gradient is computed by multiplying three terms: the gradient of the loss with respect to the output, the gradient of the output with respect to the hidden state, and the gradient of the hidden state with respect to the weights.
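
For a network unrolled over $T$ time steps, with the loss computed from the final output, this product can be written out more explicitly. The time subscripts ($h_t$, $y_T$) are notation introduced here for illustration, and $\partial h_t / \partial W$ denotes the direct dependence of $h_t$ on $W$ with $h_{t-1}$ held fixed. This is a sketch of the standard backpropagation-through-time expansion, not something specific to the model above:

$$
\frac{\partial L}{\partial W} = \sum_{t=1}^{T} \frac{\partial L}{\partial y_T} \frac{\partial y_T}{\partial h_T} \left( \prod_{k=t+1}^{T} \frac{\partial h_k}{\partial h_{k-1}} \right) \frac{\partial h_t}{\partial W}
$$

The product of $\partial h_k / \partial h_{k-1}$ terms gets longer the further the time step $t$ lies from the output.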

Since these gradients are multiplied together, factors with a magnitude below 1 at each time step have a compounding effect. Such factors are common with saturating activations: the derivative of the sigmoid is at most 0.25 and the derivative of tanh is at most 1. The further back in time you go, the smaller the gradients become.
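
To make the compounding effect concrete, here is a minimal sketch using a scalar RNN, $h_t = \tanh(w \cdot h_{t-1} + x_t)$. The function name `gradient_through_time`, the recurrent weight `w = 0.5`, and the constant toy input are all hypothetical choices for illustration.

```python
import math

def gradient_through_time(w, inputs, h0=0.0):
    """Return d h_T / d h_0 for the scalar RNN h_t = tanh(w * h_{t-1} + x_t)."""
    h, grad = h0, 1.0
    for x_t in inputs:
        h = math.tanh(w * h + x_t)
        # Chain rule: d h_t / d h_{t-1} = w * (1 - h_t^2)
        grad *= w * (1.0 - h * h)
    return grad

w = 0.5  # recurrent weight (hypothetical value)
for steps in (1, 5, 20, 50):
    inputs = [0.1] * steps  # constant toy input
    print(steps, gradient_through_time(w, inputs))
```

Every factor in the product is at most `w`, so with `w = 0.5` the gradient after 50 steps is bounded by $0.5^{50}$. The first input then has practically no influence on the final hidden state.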


![Simple RNN](https://raw.githubusercontent.com/ChaosTheLegend/ML-Book/main/Images/SimpleRNNProblem.png)

## Attention Mechanism

To combat the vanishing gradient problem, we can use an attention mechanism.

An attention mechanism is a way to help RNNs learn long-term dependencies by allowing the model to focus on the most relevant parts of the input sequence when producing a given output.

We do this by adding a context vector to the model, which is a weighted sum of the encoder's hidden states. The weights are computed using an alignment score function, which measures how well the input around a given position matches the output being produced.

![Attention Mechanism](https://raw.githubusercontent.com/ChaosTheLegend/ML-Book/main/Images/Attention.png)

In [None]:
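from keras.layers import Input, Embedding, Dense, Attention, SimpleRNN, Lambda

inputs = Input(shape=(max_len,))
embedding = Embedding(input_dim=num_words, output_dim=embedding_dim)(inputs)
# return_sequences=True keeps the hidden state of every time step
rnn = SimpleRNN(hidden_dim, return_sequences=True)(embedding)
# Self-attention: the RNN outputs serve as both query and value
attention = Attention()([rnn, rnn])
# Multiply the attention output with the hidden states and sum over time
# to get one fixed-size vector per review. The Lambda wrapper lets the
# TensorFlow op run on symbolic Keras tensors.
context = Lambda(lambda t: tf.reduce_sum(t[0] * t[1], axis=1))([attention, rnn])
outputs = Dense(output_dim, activation='sigmoid')(context)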

model = tf.keras.Model(inputs, outputs)

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

In [None]:
# load the model

model = tf.keras.models.load_model('simpleRNN_attention.keras')

### Math Behind Attention Mechanism

The context vector is computed as follows:

$$
c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
$$

where $c_i$ is the context vector at position $i$, $T_x$ is the length of the input sequence, $\alpha_{ij}$ is the attention weight (the normalized alignment score) between the output at position $i$ and the input at position $j$, and $h_j$ is the hidden state at position $j$.
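
The weighted sum above can be sketched in a few lines of NumPy. Dot-product scoring followed by a softmax is an assumption made for this example (other alignment score functions exist), and the names `softmax` and `attention_context` as well as the toy sizes are hypothetical.

```python
import numpy as np

def softmax(scores):
    """Numerically stable softmax over the last axis."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def attention_context(queries, hidden_states):
    """Return (context vectors c_i, weights alpha_ij) using dot-product scores."""
    scores = queries @ hidden_states.T   # alignment scores, shape (T_y, T_x)
    alpha = softmax(scores)              # each row sums to 1
    context = alpha @ hidden_states      # c_i = sum_j alpha_ij * h_j
    return context, alpha

rng = np.random.default_rng(0)
T_x, dim = 6, 4                          # toy sizes (hypothetical)
h = rng.normal(size=(T_x, dim))          # hidden states h_1 .. h_{T_x}

# Self-attention: the hidden states also serve as the queries
context, alpha = attention_context(h, h)
print(alpha.sum(axis=1))
print(context.shape)
```

Each row of `alpha` holds the weights $\alpha_{ij}$ for one position $i$, so the matching row of `context` is the corresponding $c_i$.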

![Attention Mechanism](https://raw.githubusercontent.com/ChaosTheLegend/ML-Book/main/Images/AttentionLive.png)

In [None]:
model.summary()

# Layer at index 2 of the loaded model. Check model.summary() to confirm
# which layer this is: in the architecture defined above, index 2 is the
# SimpleRNN layer (the default Attention layer has no trainable weights).
attention_layer = model.layers[2]

# Second weight array stored in that layer
weights = attention_layer.get_weights()[1]

# Sum across the columns of the weight matrix (one total per row)
weights = np.sum(weights, axis=1)

# Keep only the first 20 values for plotting
weights = weights[:20]

# Color negative weights red and non-negative weights green
colors = ['red' if w < 0 else 'green' for w in weights]

plt.bar(range(weights.shape[0]), weights, color=colors)

plt.title('Attention Weights')

plt.xlabel('Embedded word index')

plt.ylabel('Word Weight')

# Use whole numbers on the x axis
plt.xticks(range(weights.shape[0]), range(weights.shape[0]))

plt.show()

The lower the absolute value of a weight, the less relevant the word is to the output.

Green bars show positive weights, indicating that the word is associated with positive sentiment.

Red bars show negative weights, indicating that the word is associated with negative sentiment.

In [None]:
y_pred = model.predict(x_test)

# Threshold the sigmoid outputs at 0.5 to get class labels (0 or 1)
y_pred = (y_pred > 0.5).astype(int).ravel()

attention_accuracy = accuracy_score(y_test, y_pred)

attention_accuracy

By adding an attention mechanism, our model performs considerably better, even though it was trained for far fewer epochs.

In [None]:
# draw a bar chart to compare the accuracy of the two models

accuracy = [attention_accuracy, simple_accuracy]


plt.bar(['Attention (5 epochs)', 'Simple RNN (50 epochs)'], accuracy)

# add a title to the plot

plt.title('Accuracy of Simple RNN vs Attention Mechanism')

# add a label to the y-axis

plt.ylabel('Accuracy')
