# Deep Learning
## Exercise 9 - Attention Mechanism


### 1. Sentiment Classification with Soft Attention
Recall that we implemented a simple LSTM-based sentiment classification model in the 5th exercise. Now we want to use all hidden states of the LSTM layer, weighted by attention scores. We will use the same data and task setup:

In [None]:
# Similarly get the data
import random
import re
import torch
from torchtext import data, datasets, vocab
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

import numpy as np
from collections import Counter, OrderedDict
from sklearn.model_selection import train_test_split

# Global configuration for the whole notebook.
random_seed = 0
data_directory = './data'
debugging = False  # set to True to work on a stratified 20% subset (quick tests)

# Seed every random number generator used below, in the same order as before.
for _seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    _seed_fn(random_seed)
del _seed_fn

# Limits shared by the data pipeline and the models.
max_length = 200   # keep at most the first 200 tokens of every review
max_vocab = 20000  # upper bound for the vocabulary size

# torchtext tokenizer applied after text_tokenizer has cleaned the raw text.
tokenizer = data.get_tokenizer('basic_english')
# Cleaning patterns, compiled once because text_tokenizer is called for every
# review in every batch. Raw strings avoid the invalid "\w" escape warning the
# plain string literal produced (the pattern itself is unchanged).
_HTML_TAG_RE = re.compile(r'<\w{1,2} />')  # <br /> and similar short self-closing tags
_NON_WORD_RE = re.compile(r'[^\w\s]')      # any character that is neither word nor whitespace
_MULTI_SPACE_RE = re.compile(r'\s+')       # runs of whitespace


def text_tokenizer(entry):
    """Clean a raw review and split it into tokens.

    Short self-closing HTML tags such as ``<br />`` and all punctuation are
    replaced by spaces, whitespace runs are collapsed to one space, and the
    result is passed to the module-level torchtext ``tokenizer``.

    Args:
        entry: raw review text.

    Returns:
        The list of tokens produced by ``tokenizer``.
    """
    entry = _HTML_TAG_RE.sub(' ', entry)
    entry = _NON_WORD_RE.sub(' ', entry)
    entry = _MULTI_SPACE_RE.sub(' ', entry)
    return tokenizer(entry)

# Read the dataset (the first call also downloads it) and materialise the
# splits as lists so they can be split and indexed.
train_set, test_set = datasets.IMDB(root=data_directory)
train_set = list(train_set)
test_set = list(test_set)

if debugging:
    # Keep a stratified 20% of both splits for quick experiments.
    train_labels = [l for l, t in train_set]
    test_labels = [l for l, t in test_set]
    train_set, _ = train_test_split(train_set, train_size=0.2, stratify=train_labels, random_state=random_seed)
    test_set, _ = train_test_split(test_set, train_size=0.2, stratify=test_labels, random_state=random_seed)

# Hold out 30% of the training data (stratified by label) for validation.
train_labels = [l for l, t in train_set]
train_set, val_set = train_test_split(train_set, train_size=0.7, stratify=train_labels, random_state=random_seed)

# Build the vocabulary from the training split only.
counter = Counter()
for label, text in train_set:
    counter.update(text_tokenizer(text))

special_tokens = ['<unk>', '<pad>']
# Reserve room for the special tokens so the vocabulary never exceeds max_vocab;
# most_common(n) also avoids sorting every distinct token just to slice it.
vocabulary = vocab.vocab(OrderedDict(counter.most_common(max_vocab - len(special_tokens))))
for i, tok in enumerate(special_tokens):
    vocabulary.insert_token(tok, i)
# Out-of-vocabulary tokens map to '<unk>'.
vocabulary.set_default_index(vocabulary['<unk>'])

# build the train, validation and test dataloaders
def collate_fn(batch):
    """Turn a list of ``(label, text)`` samples into model-ready tensors.

    A label equal to ``'pos'`` becomes 1 and any other label becomes 0. Each
    text is tokenized with ``text_tokenizer``, truncated to ``max_length``
    tokens, mapped to vocabulary ids, and right-padded with the ``'<pad>'`` id.

    Args:
        batch: iterable of ``(label, text)`` pairs.

    Returns:
        ``(labels, padded_indices)``: an int64 tensor of shape ``(batch,)`` and
        an int64 tensor of shape ``(batch, longest_sequence)``.
    """
    labels, indexes = [], []
    for label, text in batch:
        # NOTE(review): assumes the dataset yields the string labels 'pos'/'neg';
        # some torchtext releases may yield integer labels instead - confirm.
        labels.append(1 if label == 'pos' else 0)

        # Truncate before the lookup so ids are only computed for kept tokens.
        tokens = text_tokenizer(text)[:max_length]
        # An explicit dtype keeps an empty review int64 instead of float32.
        indexes.append(torch.tensor([vocabulary[t] for t in tokens], dtype=torch.long))
    labels = torch.tensor(labels)
    padded_indices = pad_sequence(indexes, padding_value=vocabulary['<pad>'], batch_first=True)

    return labels, padded_indices

# Training batches are shuffled and the incomplete last batch is dropped.
train_dataloader = DataLoader(train_set, batch_size=32, shuffle=True,
                              collate_fn=collate_fn, drop_last=True)

# Evaluation loaders keep a fixed order and every sample. Shuffling combined
# with drop_last made each epoch's validation run on a different random subset,
# which added noise to the validation loss used for early stopping.
val_dataloader = DataLoader(val_set, batch_size=32, shuffle=False,
                            collate_fn=collate_fn, drop_last=False)

test_dataloader = DataLoader(test_set, batch_size=32, shuffle=False,
                             collate_fn=collate_fn, drop_last=False)

#### 1. Build an LSTM with attention

Reuse your code for the 5th exercise and add an attention mechanism. Specifically, you need to:
* Compute the soft attention scores by the dot product of each hidden state and the last hidden state. 
* Compute a context vector by the weighted sum of all hidden states.
* Apply a dropout layer with $p=0.3$ on the context vector.
* Use the context vector to predict the label, instead of the original final state.

$$Context = \sum_{i=0}^{L}\alpha_{i}  h_i, \; \text{where} \\
\alpha_{i} = \frac{e^{sim(h_{L}, h_{i})}}{\sum_{j} e^{sim(h_{L}, h_{j})}} \\
sim(a, b) = a \cdot b
$$

In [None]:
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

In [None]:
#ToDo: Implement an LSTM with Attention

In [None]:
class AttentionSentimentClassifier(nn.Module):
    """Bidirectional LSTM sentiment classifier with dot-product soft attention.

    The final hidden states of both LSTM directions are concatenated into a
    query vector. Every per-token hidden state is scored against the query by
    a dot product, and the scores are softmax-normalised over the sequence
    (padding positions excluded). The weighted sum of hidden states (the
    context vector) then goes through dropout (p=0.3), a linear layer and a
    sigmoid.

    Args:
        vocab_size: number of embedding rows; defaults to ``len(vocabulary)``.
        pad_idx: padding token id; defaults to ``vocabulary['<pad>']``.
            Positions holding this id receive zero attention.
    """

    def __init__(self, vocab_size=None, pad_idx=None):
        super(AttentionSentimentClassifier, self).__init__()
        if vocab_size is None:
            vocab_size = len(vocabulary)
        if pad_idx is None:
            pad_idx = vocabulary['<pad>']
        self.pad_idx = pad_idx
        self.embedding = nn.Embedding(vocab_size, 100, padding_idx=pad_idx)
        self.rnn = nn.LSTM(100, 200, num_layers=1, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(400, 1)
        self.dropout = nn.Dropout(0.3)
        self.classify = nn.Sigmoid()
        self.softm = nn.Softmax(1)

    def forward(self, text):
        """Classify a batch of token-id sequences.

        Args:
            text: integer tensor of shape ``(batch, seq_len)``.

        Returns:
            ``(probs, attention)``: positive-class probabilities of shape
            ``(batch,)`` and attention weights of shape ``(batch, 1, seq_len)``.
        """
        embedded = self.embedding(text)
        # hidden: (batch, seq_len, 400) because batch_first=True and bidirectional.
        hidden, (last_hid, _) = self.rnn(embedded)
        # h_n is (num_layers * 2, batch, 200); the last two rows are the top
        # layer's forward and backward final states. Concatenate them as the query.
        # NOTE(review): the padded batch is fed unpacked, so the forward
        # direction's final state has also consumed trailing '<pad>' tokens.
        query = torch.cat((last_hid[-2], last_hid[-1]), -1).unsqueeze(-1)  # (batch, 400, 1)

        sims = torch.bmm(hidden, query).squeeze(-1)  # (batch, seq_len)
        # Keep padding out of the softmax. A row made entirely of padding is
        # left unmasked so the softmax stays finite instead of producing NaN.
        pad_mask = text == self.pad_idx
        pad_mask = pad_mask & ~pad_mask.all(dim=1, keepdim=True)
        sims = sims.masked_fill(pad_mask, float('-inf'))

        attention = self.softm(sims).unsqueeze(-2)  # (batch, 1, seq_len)
        context = torch.bmm(attention, hidden)      # (batch, 1, 400)
        drop_context = self.dropout(context)
        linear = self.fc(drop_context)
        output = self.classify(linear)

        return output.flatten(), attention


#### 2. Train your model

Train your model as in the 5th exercise (use binary cross-entropy loss and the Adam optimizer, train for a maximum of 20 epochs and implement early stopping).

In [None]:
#ToDo: Train your model

In [None]:
from tqdm import tqdm

def train(num_epochs, model, loss_funtion, optimizer, train_loader, val_loader, break_criterium, model_name):
    """Train ``model`` with early stopping on the validation loss.

    After every epoch the model is evaluated on ``val_loader``. Whenever the
    validation loss improves, the weights are saved to ``f'{model_name}.pt'``.
    Training stops after ``num_epochs`` epochs or once ``break_criterium``
    consecutive epochs bring no improvement. In either case the best saved
    weights are loaded back into ``model`` before returning.

    Args:
        num_epochs: maximum number of epochs.
        model: module whose forward returns ``(probabilities, extra)``.
        loss_funtion: loss called as ``loss_funtion(probabilities, float_labels)``.
            The misspelt name is kept so existing keyword calls keep working.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: loader yielding ``(labels, indices)`` training batches.
        val_loader: loader yielding ``(labels, indices)`` validation batches.
        break_criterium: early-stopping patience, in epochs.
        model_name: file stem of the checkpoint.
    """
    # Use the loss passed in, not whatever global happens to be named loss_function.
    loss_function = loss_funtion
    checkpoint_path = f'{model_name}.pt'
    best_val_loss = float('inf')
    has_checkpoint = False
    no_improve = 0
    for epoch in range(num_epochs):
        model.train()
        for labels, indices in tqdm(train_loader, desc='Train Iter', ascii=True):
            optimizer.zero_grad()
            output, _ = model(indices)
            loss = loss_function(output, labels.float())
            loss.backward()
            optimizer.step()
        acc, val_loss = evaluate(model, val_loader, loss_function)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            no_improve = 0
            torch.save(model.state_dict(), checkpoint_path)
            has_checkpoint = True
        else:
            no_improve += 1
        print(f"Epoch {epoch} \t Loss {val_loss:.5f} \t Accuracy {acc:.5f}")
        if no_improve >= break_criterium:
            break
    # Restore the best weights whether training stopped early or ran out of epochs.
    if has_checkpoint:
        model.load_state_dict(torch.load(checkpoint_path))
                
def evaluate(model, test_loader, loss_function):
    """Compute the accuracy and mean loss of ``model`` over ``test_loader``.

    The model is switched to eval mode and run without gradient tracking. An
    output above 0.5 counts as a positive prediction.

    Args:
        model: module whose forward returns ``(probabilities, extra)``.
        test_loader: loader yielding ``(labels, indices)`` batches.
        loss_function: loss called as ``loss_function(probabilities, float_labels)``.

    Returns:
        ``(accuracy, loss)`` as Python floats. The loss is averaged per sample,
        assuming ``loss_function`` returns the batch mean (the default reduction
        of ``BCELoss``), so a smaller final batch is weighted correctly.

    Raises:
        ZeroDivisionError: if the loader yields no batches.
    """
    model.eval()
    correct = 0
    total_entries = 0
    cum_loss = 0.0
    with torch.no_grad():
        for labels, indices in tqdm(test_loader, desc='Test Iter', ascii=True):
            output, _ = model(indices)
            preds = (output > 0.5).long()
            batch_size = labels.shape[0]
            correct += (preds == labels).sum().item()
            total_entries += batch_size
            cum_loss += loss_function(output, labels.float()).item() * batch_size
    return correct / total_entries, cum_loss / total_entries
    
# Fit the attention model (BCE loss, Adam, at most 15 epochs, patience of 5
# epochs on the validation loss) and report its test metrics.
ASC = AttentionSentimentClassifier()
loss_function = torch.nn.BCELoss()
optimizer = torch.optim.Adam(ASC.parameters())
train(15, ASC, loss_function, optimizer, train_dataloader, val_dataloader,
      break_criterium=5, model_name='ASC')
asc_test_results = evaluate(ASC, test_dataloader, loss_function)
print(asc_test_results)



#### 3. Visualize or print the attention

Visualize or print the attention scores of a test instance.

In [None]:
#ToDo: Visualize the attention scores

In [None]:
# Find a short test review so the per-token attention is easy to read.
sample_sents = [entry for entry in test_set if 5 < len(text_tokenizer(entry[1])) < 25]
print(sample_sents)
sentence = sample_sents[0]
print(sentence)
tokens = text_tokenizer(sentence[1])
token_ids = torch.tensor([[vocabulary[token] for token in tokens]])
print(tokens)
print(token_ids)

# Get the attention scores. Set eval mode explicitly so dropout does not
# perturb the prediction, even if this cell runs before any evaluation.
ASC.eval()
with torch.no_grad():
    output, attention = ASC(token_ids)
print(attention.shape)

print(f'Prediction: {output.item():.4f}')
from matplotlib import pyplot as plt

plt.bar(range(len(tokens)), attention.flatten(), tick_label=tokens)
plt.xticks(rotation=45, ha='right')
plt.show()

In [None]:
from IPython.display import display_html
import numpy as np
def to_html(word, importance):
    """Wrap ``word`` in an HTML ``<mark>`` tag coloured by ``importance``.

    ``importance`` is clipped to [-1, 1]. Positive values are shaded green
    (darker for larger values). Zero and negative values are shaded red, with
    0 rendered white.

    Args:
        word: token to display; it is HTML-escaped before insertion.
        importance: a number or a 0-d tensor (anything ``float`` accepts).

    Returns:
        The HTML snippet as a string.
    """
    import html

    def _get_color(attr):
        # Clip values to prevent CSS errors (values should be in [-1, 1]).
        attr = max(-1.0, min(1.0, float(attr)))
        if attr > 0:
            hue, sat, lig = 120, 75, 100 - int(50 * attr)
        else:
            hue, sat, lig = 0, 75, 100 - int(-40 * attr)
        return "hsl({}, {}%, {}%)".format(hue, sat, lig)

    color = _get_color(importance)
    return ('<mark style="background-color: {color}; opacity:1.0; line-height:1.75">'
            '<font color="black"> {word} </font></mark>').format(
                color=color, word=html.escape(str(word)))
# The attention weights are softmax-normalised, so they are all small and
# to_html would render them almost white. Scale by the largest weight so the
# most-attended token gets full colour; the ranking between tokens is unchanged.
weights = attention.detach().flatten()
peak = weights.max()
if peak > 0:
    weights = weights / peak
res = ''.join(to_html(word, att) for word, att in zip(tokens, weights))
display_html(res, raw=True)

### 2. Sentiment Classification with Transformer
Recall that we implemented a very simple LSTM-based sentiment classification model in the 5th exercise. Now it is time to replace the LSTM module with a Transformer. Again, we will use the same IMDB dataset and the transformer implementation from `torch.nn`. 

#### 1. Define the Model

Your model should consist of:
* An Embedding layer with embedding dimension 200. Use `nn.Embedding`.
* A positional encoding layer to include position information for each token. Use the given `PositionalEncoding` module, which we took from [this tutorial](https://pytorch.org/tutorials/beginner/transformer_tutorial.html) (recommended reading).
* A self-attention layer with 4 heads and 0.3 dropout rate. Use `nn.MultiheadAttention`.
* A fully connected layer to map the attention based representation to 400 dimension hidden states. Use `nn.Linear`.
* A fully connected layer to map the ***first step*** of the hidden states to the prediction output, use the sigmoid activation for the output. 


In [None]:
from torch import nn

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to embeddings, then apply dropout.

    Adapted from https://pytorch.org/tutorials/beginner/transformer_tutorial.html

    Args:
        embedding_dim: size of the last input dimension (odd sizes are supported).
        dropout: dropout probability applied after adding the encodings.
        max_len: number of positions in the encoding table; defaults to the
            notebook-level ``max_length``.
        batch_first: if False (the default, as in the tutorial) the input is
            ``(seq_len, batch, embedding_dim)``; if True it is
            ``(batch, seq_len, embedding_dim)``.
    """
    def __init__(self, embedding_dim, dropout=0.1, max_len=None, batch_first=False):
        super(PositionalEncoding, self).__init__()
        if max_len is None:
            max_len = max_length
        self.batch_first = batch_first
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, embedding_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-torch.log(torch.tensor(10000.0)) / embedding_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For an odd embedding_dim there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: embedding_dim // 2])
        # Buffer layout (max_len, 1, embedding_dim), as in the tutorial, so
        # existing state_dicts still load.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        if self.batch_first:
            # Index the table by the sequence axis (dim 1), broadcast over the batch.
            x = x + self.pe[:x.size(1)].transpose(0, 1)
        else:
            x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [None]:
#ToDo: Fill the __init__() and forward() functions. Add arguments if needed.
class TransformerClassifier(nn.Module):
    def __init__(self, ):
        super(TransformerClassifier, self).__init__()
        # ToDo: define the layers here.

    def forward(self, ):
        # ToDo: implement the forward pass. The placeholder body keeps this
        # template cell syntactically valid.
        raise NotImplementedError('TransformerClassifier.forward is not implemented yet')

In [None]:
class TransformerClassifier(nn.Module):
    """Single-layer self-attention sentiment classifier.

    Pipeline: token ids -> 200-d embeddings -> sinusoidal position encodings
    -> 4-head self-attention (dropout 0.3) -> linear layer to 400-d -> the
    first position's vector -> linear layer to one value -> sigmoid.

    Args:
        vocab_size: number of embedding rows.
        max_len: number of positions supported by the position encoding.
        pad_idx: padding token id; keys holding it are masked out of the
            attention. Defaults to ``vocabulary['<pad>']``.
    """
    def __init__(self, vocab_size, max_len, pad_idx=None):
        super(TransformerClassifier, self).__init__()
        if pad_idx is None:
            pad_idx = vocabulary['<pad>']
        self.pad_idx = pad_idx
        self.embed = nn.Embedding(vocab_size, embedding_dim=200)
        self.pos_encoding = PositionalEncoding(embedding_dim=200, max_len=max_len)
        self.self_attention = nn.MultiheadAttention(embed_dim=200, num_heads=4, dropout=0.3, batch_first=True)
        self.full_1 = nn.Linear(in_features=200, out_features=400)
        self.full_2 = nn.Linear(in_features=400, out_features=1)
        self.sigm = nn.Sigmoid()

    def forward(self, x):
        """Classify a batch of token ids of shape ``(batch, seq_len)``.

        Returns:
            ``(probs, weights)``: probabilities of shape ``(batch,)`` and the
            head-averaged attention weights returned by ``nn.MultiheadAttention``.
        """
        embs = self.embed(x)  # (batch, seq_len, 200)
        # PositionalEncoding's default layout indexes its table along dim 0,
        # i.e. it expects (seq_len, batch, dim). Feeding the batch-first tensor
        # directly would encode the batch index instead of the token position.
        pos_encs = self.pos_encoding(embs.transpose(0, 1)).transpose(0, 1)
        # True marks padded keys to ignore. A row made entirely of padding is
        # left unmasked so the attention softmax does not produce NaN.
        pad_mask = x == self.pad_idx
        pad_mask = pad_mask & ~pad_mask.all(dim=1, keepdim=True)
        outp, weights = self.self_attention(pos_encs, pos_encs, pos_encs, key_padding_mask=pad_mask)
        full_1 = self.full_1(outp)
        full_2 = self.full_2(full_1[:, 0, :])  # classify from the first position
        pred = self.sigm(full_2).flatten()
        return pred, weights


# Smoke test: run an untrained model on the first test review. The review is
# truncated to max_length because the position table only has that many rows.
TC = TransformerClassifier(len(vocabulary), max_length)
first_review_tokens = text_tokenizer(test_set[0][1])[:max_length]
with torch.no_grad():
    print(TC(torch.tensor([[vocabulary[token] for token in first_review_tokens]])))
        

#### 2. Train the model and evaluate it

You don't need to re-implement a training loop if you implemented the one above generally enough.

In [None]:
#ToDo: Train the model and evaluate it

In [None]:
# Train the transformer with the same loop as the LSTM model: BCE loss, Adam,
# at most 25 epochs, patience of 10 epochs on the validation loss.
TC = TransformerClassifier(len(vocabulary), max_length)
loss_function = torch.nn.BCELoss()
optimizer = torch.optim.Adam(TC.parameters())
train(25, TC, loss_function, optimizer, train_dataloader, val_dataloader,
      break_criterium=10, model_name='TC')
tc_test_results = evaluate(TC, test_dataloader, loss_function)
print(tc_test_results)

#### 3. Visualize or print the attention scores

For the sample sentence: `This is a good movie` visualize the attention scores. A heatmap is recommended.

In [None]:
# Sentence whose self-attention weights are to be visualised.
sample = 'This is a good movie'

#ToDo: visualize the attention scores.

In [None]:
sample = 'This is a good movie'
tokens = text_tokenizer(sample)
token_ids = torch.tensor([[vocabulary[token] for token in tokens]])

# Eval mode disables dropout (including attention dropout), so the prediction
# and the attention map are deterministic regardless of cell execution order.
TC.eval()
with torch.no_grad():
    pred, att = TC(token_ids)
att = att.squeeze(0).numpy()  # drop the batch axis -> (len(tokens), len(tokens))
pred = pred.item()
print(f'Prediction: {pred:.3f}, {"pos" if pred>0.5 else "neg"}')

import matplotlib.pyplot as plt
import matplotlib
# Heatmap layout adapted from
# https://matplotlib.org/stable/gallery/images_contours_and_fields/image_annotated_heatmap.html

fig, ax = plt.subplots()
im = ax.imshow(att, cmap='Blues')

cbar = ax.figure.colorbar(im, ax=ax)
cbar.ax.set_ylabel('Attention', rotation=-90, va="bottom")

# One tick per token on both axes. Per the nn.MultiheadAttention docs, rows
# index queries and columns index keys.
ax.set_xticks(np.arange(att.shape[1]))
ax.set_yticks(np.arange(att.shape[0]))
ax.set_xticklabels(tokens)
ax.set_yticklabels(tokens)

# Put the column labels above the heatmap and rotate them for readability.
ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor")

# White grid between cells.
ax.set_xticks(np.arange(att.shape[1] + 1) - .5, minor=True)
ax.set_yticks(np.arange(att.shape[0] + 1) - .5, minor=True)
ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
ax.tick_params(which="minor", bottom=False, left=False)

# Annotate every cell with its weight, using white text on dark cells.
threshold = im.norm(att.max()) / 2.
for i in range(att.shape[0]):
    for j in range(att.shape[1]):
        text_color = 'black' if im.norm(att[i, j]) < threshold else 'white'
        ax.text(j, i, f"{att[i, j]:.2f}", color=text_color,
                horizontalalignment='center', verticalalignment='center')

plt.show()

