## Students' details

Student1
* Name:
* ID:
* Username:

Student2
* Name:
* ID:
* Username:

Student3
* Name:
* ID:
* Username:

### General tip

While debugging you might want to use:
```python
import importlib
importlib.reload(model)
```

to reload the model module without repeating unnecessary cells.
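
The sketch below shows what `importlib.reload` does. It writes a throwaway module to a temporary directory, imports it, edits the file, and reloads it. The module name `scratch_mod` is hypothetical and stands in for `model.py`. `reload` re-executes the module's code, but objects created from the old version are not updated. A `WSDModel` instance built before the reload therefore keeps its old methods until you re-create it.

```python
import importlib
import sys
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    module_path = Path(tmp) / "scratch_mod.py"  # hypothetical stand-in for model.py
    # The two versions differ in length, so a cached .pyc is never mistaken for the new source.
    module_path.write_text("VALUE = 1\n")
    sys.path.insert(0, tmp)
    try:
        import scratch_mod
        before = scratch_mod.VALUE

        # Simulate editing the source file while the notebook kernel keeps running.
        module_path.write_text("VALUE = 200\n")
        importlib.reload(scratch_mod)
        after = scratch_mod.VALUE
    finally:
        # Clean up so the throwaway module does not leak into the session.
        sys.path.remove(tmp)
        sys.modules.pop("scratch_mod", None)

print(before, after)  # 1 200
```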

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

### Import relevant packages (you might need to `pip install` some of them)

In [None]:
import torch
import sys

# Make the exercise modules (data_loader, traineval, model) importable,
# both locally and from the mounted Google Drive folder on Colab.
sys.path.append('.')
sys.path.append('..')
sys.path.append('/content/drive/MyDrive/TAU/Advanced NLP/Ex1')

import data_loader
from traineval import train, evaluate
import model


import matplotlib.pyplot as plt

# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"device used is {device}")

import importlib
importlib.reload(model)

%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import random

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

In [None]:
seed = 42
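
Reseeding every random number generator before building each model is what makes the runs below comparable. The quick check below confirms that reseeding reproduces the same draws. It uses a hypothetical torch-free variant, `set_seed_cpu_only`, so it runs without PyTorch. The real `set_seed` above also seeds torch's CPU and CUDA generators.

```python
import random

import numpy as np


def set_seed_cpu_only(seed: int) -> None:
    """Hypothetical torch-free variant of set_seed: seeds only Python and NumPy."""
    random.seed(seed)
    np.random.seed(seed)


def draw() -> tuple:
    # One draw from each seeded generator.
    return random.random(), float(np.random.rand())


set_seed_cpu_only(42)
first = draw()
set_seed_cpu_only(42)
second = draw()
print(first == second)  # True
```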

## Loading the Data

The following cells invoke `data_loader`, which automatically downloads and extracts the dataset if needed.
They instantiate the following variables:
* `train_dataset`, `dev_dataset` - `WSDDataset` instances for the train and dev splits
* `tokens_vocab` - the vocabulary of sentence words
* `y_vocab` - the vocabulary of labels (senses)

Use the optional `sentence_count` kwarg to limit the number of sentences loaded.
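
For intuition about what the vocabularies hold, here is a minimal sketch of a token-to-index vocabulary. The class `ToyVocab`, its reserved `<pad>`/`<unk>` indices, and `encode` are all hypothetical. Only the `size()` method mirrors how `tokens_vocab` and `y_vocab` are used below.

```python
from typing import Dict, Iterable, List


class ToyVocab:
    """Hypothetical minimal vocabulary; data_loader's real class may differ."""

    def __init__(self, tokens: Iterable[str], pad: str = "<pad>", unk: str = "<unk>"):
        # Reserve index 0 for padding and 1 for unknown tokens (an assumption).
        self.itos: List[str] = [pad, unk]
        self.stoi: Dict[str, int] = {pad: 0, unk: 1}
        self.unk_id = self.stoi[unk]
        for tok in tokens:
            if tok not in self.stoi:
                self.stoi[tok] = len(self.itos)
                self.itos.append(tok)

    def size(self) -> int:
        return len(self.itos)

    def encode(self, sentence: List[str]) -> List[int]:
        # Tokens never seen when building the vocabulary map to <unk>.
        return [self.stoi.get(tok, self.unk_id) for tok in sentence]


vocab = ToyVocab("the bank of the river".split())
print(vocab.size())                                     # 6
print(vocab.encode("the river bank flooded".split()))   # [2, 5, 3, 1]
```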

In [None]:
train_dataset, tokens_vocab, y_vocab = data_loader.load_train_dataset()
train_dataset

In [None]:
dev_dataset = data_loader.load_dev_dataset(tokens_vocab, y_vocab)
dev_dataset

## Part 1: Query-Based Attention

Implement the relevant parts in the model.py module. You might want to check out this blog post about the [gather method](https://medium.com/analytics-vidhya/understanding-indexing-with-pytorch-gather-33717a84ebc4).
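
Query-based attention can be sketched in NumPy in three steps:

* pick out the query word's embedding with a gather,
* score it against every token,
* take the softmax-weighted sum.

`np.take_along_axis` is NumPy's counterpart of `torch.gather`. The function name `query_attention` is hypothetical. Using the raw embeddings as keys and values is an assumption and is not necessarily how `model.py` is structured.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the max for numerical stability before exponentiating.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def query_attention(X: np.ndarray, query_idx: np.ndarray):
    """X: (B, N, D) token embeddings; query_idx: (B,) position of the query word.

    Returns the attended vector (B, D) and the attention weights (B, N).
    Keys and values are the raw embeddings here (an assumption).
    """
    B, N, D = X.shape
    # Gather the query embedding along dim 1. The PyTorch equivalent is
    # X.gather(1, query_idx.view(B, 1, 1).expand(B, 1, D)).squeeze(1)
    idx = np.broadcast_to(query_idx.reshape(B, 1, 1), (B, 1, D))
    q = np.take_along_axis(X, idx, axis=1).squeeze(1)   # (B, D)
    scores = np.einsum("bnd,bd->bn", X, q)              # query . every token
    weights = softmax(scores, axis=-1)                  # (B, N), rows sum to 1
    out = np.einsum("bn,bnd->bd", weights, X)           # weighted sum of tokens
    return out, weights


rng = np.random.default_rng(0)
X = rng.normal(size=(2, 4, 3))
query_idx = np.array([1, 3])
out, weights = query_attention(X, query_idx)
print(out.shape, weights.shape)  # (2, 3) (2, 4)
```

Unlike NumPy indexing, `torch.gather` does not broadcast its index. Without the `expand` in the comment, it would return only the first feature of each query embedding.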

Load the model.

In [None]:
dropout = 0.25
D = 300
lr = 8e-5
batch_size=100
num_epochs=5
set_seed(seed)

m = model.WSDModel(
    tokens_vocab.size(), 
    y_vocab.size(), 
    D=D, 
    dropout_prob=dropout
).to(device)

optimizer = torch.optim.Adam(m.parameters(), lr=lr)

losses, train_acc, val_acc = train(m, 
                                   optimizer, 
                                   train_dataset, 
                                   dev_dataset, 
                                   num_epochs=num_epochs, 
                                   batch_size=batch_size)

After training, you should see the loss decreasing and the validation accuracy increasing from epoch to epoch.

In [None]:
print(f"Validation accuracy: {val_acc[-1]:.3f}, Training accuracy: {train_acc[-1]:.3f}")
assert round(val_acc[-1], 3) >= 0.514, "The last validation accuracy should be at least 0.514. Please check your implementation before you continue"

Plot the loss and training/validation accuracy. You should be getting ~54% validation accuracy after 10 epochs.

In [None]:
fig, axs = plt.subplots(nrows=2, figsize=(15, 6))

axs[0].plot(losses, '-', label='Train Loss');
axs[0].legend()
axs[1].plot(train_acc, '-o', label='Train Acc');
axs[1].plot(val_acc, '-o', label='Val Acc');
axs[1].legend()

plt.tight_layout()

Use the attention visualization to get a feel for what the model is attending to.

The query token is highlighted in green, and the model's attention weights are shown with a pink-blue gradient.
The loss is shown with a red gradient.

In [None]:
from traineval import higlight_samples

higlight_samples(m, dev_dataset, sample_size=5)

## Part 2: Padding

Implement the padding mask in the attention function in model.py.
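
A common way to implement the padding mask is to set the attention logits at padded positions to `-inf` before the softmax. Those positions then receive exactly zero weight. The NumPy sketch below assumes index 0 is the padding id; check what `tokens_vocab` actually reserves. In PyTorch, the same step is usually written with `Tensor.masked_fill(mask, float('-inf'))`.

```python
import numpy as np


def masked_softmax(scores: np.ndarray, token_ids: np.ndarray, pad_id: int = 0) -> np.ndarray:
    """scores: (B, N) attention logits; token_ids: (B, N) input ids.

    Padded positions get -inf, so exp() turns them into exactly zero weight.
    pad_id=0 is an assumption. Each row must contain at least one real token,
    otherwise the row becomes NaN.
    """
    masked = np.where(token_ids == pad_id, -np.inf, scores)
    masked = masked - masked.max(axis=-1, keepdims=True)
    e = np.exp(masked)  # exp(-inf) == 0
    return e / e.sum(axis=-1, keepdims=True)


scores = np.array([[2.0, 1.0, 3.0, 5.0]])
token_ids = np.array([[7, 4, 9, 0]])  # the last position is padding
w = masked_softmax(scores, token_ids)
print(w.round(3))
```

The padded position's high logit (5.0) is ignored entirely. Without the mask, it would have taken most of the attention.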

Load the model and retrain.

In [None]:
set_seed(seed)

m = model.WSDModel(
    tokens_vocab.size(), 
    y_vocab.size(), 
    D=D, 
    dropout_prob=dropout,
    use_padding=True
).to(device)

optimizer = torch.optim.Adam(m.parameters(), lr=lr)

losses, train_acc, val_acc = train(
    m, optimizer, train_dataset, dev_dataset, num_epochs=num_epochs, batch_size=batch_size)

In [None]:
print(f"Validation accuracy: {val_acc[-1]:.3f}, Training accuracy: {train_acc[-1]:.3f}")
assert round(val_acc[-1], 3) >= 0.527, "The last validation accuracy should be at least 0.527. Please check your implementation before you continue"

Plot the loss and training/validation accuracy.

In [None]:
fig, axs = plt.subplots(nrows=2, figsize=(15, 6))

axs[0].plot(losses, '-', label='Train Loss');
axs[0].legend()
axs[1].plot(train_acc, '-o', label='Train Acc');
axs[1].plot(val_acc, '-o', label='Val Acc');
axs[1].legend()

plt.tight_layout()

In [None]:
higlight_samples(m, dev_dataset, sample_size=5)

Examine additional examples, using the API and pandas as demonstrated below.

In [None]:
import pandas as pd
import numpy as np
from traineval import evaluate_verbose, highlight

# The full option name avoids an ambiguous-pattern error in recent pandas versions.
pd.set_option('display.max_columns', 100)

eval_df, attention_df = evaluate_verbose(m, dev_dataset, iter_lim=100)

Visualization of 5 incorrectly classified examples.

In [None]:
idxs = np.where(eval_df['y_true'] != eval_df['y_pred'])
idxs = list(idxs[0][:5])
highlight(eval_df, attention_df, idxs)

Visualization of examples with the query word "left".

In [None]:
idxs = np.where(eval_df['query_token'] == 'left')
idxs = list(idxs[0])  # same list-of-indices format as the previous cell
highlight(eval_df, attention_df, idxs)
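
Beyond individual examples, the same `eval_df` columns can be aggregated with pandas, for instance to find the query tokens the model gets wrong most often. The sketch below uses the column names from the cells above (`query_token`, `y_true`, `y_pred`). The frame `toy_df`, its rows, and its sense labels are invented for illustration.

```python
import pandas as pd

# Toy stand-in for eval_df: only the column names come from the cells above.
toy_df = pd.DataFrame({
    "query_token": ["left", "left", "bank", "bank", "bank"],
    "y_true":      ["sense_1", "sense_2", "sense_3", "sense_3", "sense_4"],
    "y_pred":      ["sense_1", "sense_1", "sense_3", "sense_3", "sense_3"],
})

# Per-query-token accuracy and support, weakest tokens first.
per_token = (
    toy_df.assign(correct=toy_df["y_true"] == toy_df["y_pred"])
    .groupby("query_token")["correct"]
    .agg(["mean", "count"])
    .sort_values("mean")
)
print(per_token)
```

On the real data, replace `toy_df` with `eval_df`. This assumes `y_true` and `y_pred` are directly comparable, as the `!=` filter above already does.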

## Part 3: Self-Attention

The method below converts the query-based instances in WSDDataset to sentence-level instances in WSDSentencesDataset for self-attention.

Notice how the number of samples now equals the number of sentences.
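
The sample count drops because of how the data is grouped. A word-level dataset has one sample per query word. A sentence-level dataset has one sample per sentence, carrying all of that sentence's labeled positions. The sketch below uses hypothetical `(sentence_id, query_position, sense)` tuples; the real `WSDSentencesDataset` layout may differ.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# Hypothetical word-level instances: (sentence_id, query_position, sense_label).
word_instances: List[Tuple[int, int, str]] = [
    (0, 1, "sense_a"),
    (0, 4, "sense_b"),
    (1, 0, "sense_c"),
    (2, 2, "sense_d"),
    (2, 5, "sense_e"),
]


def to_sentence_instances(instances) -> Dict[int, Dict[int, str]]:
    """Group query-level instances so each sentence becomes one sample
    that maps every labeled position to its sense."""
    sentences: Dict[int, Dict[int, str]] = defaultdict(dict)
    for sent_id, pos, label in instances:
        sentences[sent_id][pos] = label
    return dict(sentences)


sentence_instances = to_sentence_instances(word_instances)
print(len(word_instances), len(sentence_instances))  # 5 3
```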

In [None]:
sa_train_dataset = data_loader.WSDSentencesDataset.from_word_dataset(train_dataset)
sa_train_dataset

In [None]:
sa_dev_dataset = data_loader.WSDSentencesDataset.from_word_dataset(dev_dataset)
sa_dev_dataset

Implement self-attention in the model.
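
In self-attention every token acts as a query. The attention weights therefore become an `N x N` matrix per sentence, and every position gets its own contextual vector. The NumPy sketch below reuses the padding idea from Part 2 on the key axis. Using the raw embeddings as queries, keys, and values, and scaling by `sqrt(D)`, are assumptions, not necessarily what `model.py` does.

```python
import numpy as np


def self_attention(X: np.ndarray, token_ids: np.ndarray, pad_id: int = 0):
    """X: (B, N, D) embeddings; token_ids: (B, N).

    Returns the outputs (B, N, D) and the attention weights (B, N, N).
    Every position attends to every non-padded position of the same sentence.
    """
    B, N, D = X.shape
    scores = X @ X.transpose(0, 2, 1) / np.sqrt(D)     # (B, N, N): query i vs key j
    key_is_pad = (token_ids == pad_id)[:, None, :]     # (B, 1, N), broadcast over queries
    scores = np.where(key_is_pad, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)     # rows sum to 1
    return weights @ X, weights


rng = np.random.default_rng(0)
X = rng.normal(size=(2, 5, 4))
token_ids = np.array([[3, 8, 2, 0, 0], [5, 1, 7, 9, 4]])  # first sentence is padded
out, weights = self_attention(X, token_ids)
print(out.shape)  # (2, 5, 4)
```

Padded query rows still produce outputs here. They are presumably excluded from the loss rather than masked inside the attention.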

Load the model and retrain.

In [None]:
lr=2e-4
dropout = 0.2
D=300
batch_size=20
num_epochs=2
set_seed(seed)

m = model.WSDModel(
    tokens_vocab.size(), 
    y_vocab.size(), 
    D=D, 
    dropout_prob=dropout,
    use_padding=True
).to(device)

optimizer = torch.optim.Adam(m.parameters(), lr=lr)

losses, train_acc, val_acc = train(
    m, optimizer, sa_train_dataset, sa_dev_dataset, num_epochs=num_epochs, batch_size=batch_size)

Plot the loss and training/validation accuracy.

In [None]:
print(f"Validation accuracy: {val_acc[-1]:.3f}, Training accuracy: {train_acc[-1]:.3f}")
# assert val_acc[-1] >= 0.543, "The last validation accuracy should be at least 0.543. Please check your implementation before you continue"

In [None]:
fig, axs = plt.subplots(nrows=2, figsize=(15, 6))

axs[0].plot(losses, '-', label='Train Loss');
axs[0].legend()
axs[1].plot(train_acc, '-o', label='Train Acc');
axs[1].plot(val_acc, '-o', label='Val Acc');
axs[1].legend()

plt.tight_layout()

## Part 4: Positional Embeddings & Part 5: Causal Attention


We do not provide "your code here" comments for this part, as you should be familiar with the code by now.
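
The exact positional scheme controlled by `pos_exponent`, `pos_cutoff_position`, and `pos_normalize_magnitude` lives in `model.py` and is not reproduced here. For reference only, below is a sketch of the widely used sinusoidal encoding from "Attention Is All You Need". It is added to the token embeddings so that attention can distinguish positions. It is not necessarily what the parameters above implement.

```python
import numpy as np


def sinusoidal_positions(n_positions: int, d_model: int) -> np.ndarray:
    """Return an (n_positions, d_model) matrix of sinusoidal position encodings.

    Even columns use sin and odd columns use cos, with wavelengths forming a
    geometric progression from 2*pi up to 10000 * 2*pi. Assumes d_model is even.
    """
    positions = np.arange(n_positions)[:, None]                  # (N, 1)
    div = np.power(10000.0, np.arange(0, d_model, 2) / d_model)  # (D/2,)
    pe = np.zeros((n_positions, d_model))
    pe[:, 0::2] = np.sin(positions / div)
    pe[:, 1::2] = np.cos(positions / div)
    return pe


pe = sinusoidal_positions(n_positions=50, d_model=300)
print(pe.shape)  # (50, 300)
# Usage: X = X + pe[:N] for a (B, N, 300) embedding batch.
```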

In [None]:
from warnings import warn


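def train_eval_positional(num_epochs: int,
                          pos_exponent: float,
                          pos_cutoff_position: int,
                          pos_is_causal: bool,
                          pos_normalize_magnitude: bool,
                          DEBUG_dummy_train: bool = False) -> None:
    """Train a self-attention WSDModel with positional encodings and plot its curves.

    The pos_* arguments are forwarded unchanged to model.WSDModel. With
    DEBUG_dummy_train=True, the dev set is also used for training (debugging only).
    """
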
    sa_train_dataset = data_loader.WSDSentencesDataset.from_word_dataset(train_dataset)
    sa_dev_dataset = data_loader.WSDSentencesDataset.from_word_dataset(dev_dataset)

    if DEBUG_dummy_train:
        sa_train_dataset = sa_dev_dataset
        warn("using dev set for training")
        print()
    
    lr=2e-4
    dropout = 0.2
    D=300
    batch_size=20
    set_seed(seed)

    m = model.WSDModel(
        tokens_vocab.size(), 
        y_vocab.size(), 
        D=D, 
        dropout_prob=dropout,
        use_padding=True,
        use_positional_encodings=True,
        pos_exponent=pos_exponent,
        pos_cutoff_position=pos_cutoff_position,
        pos_is_causal=pos_is_causal,
        pos_normalize_magnitude=pos_normalize_magnitude
    ).to(device)

    optimizer = torch.optim.Adam(m.parameters(), lr=lr)

    losses, train_acc, val_acc = train(
        m, optimizer, sa_train_dataset, sa_dev_dataset, num_epochs=num_epochs, batch_size=batch_size)
    
    print(f"Validation accuracy: {val_acc[-1]:.3f}, Training accuracy:{train_acc[-1]:.3f}")

    fig, axs = plt.subplots(nrows=2, figsize=(15, 6))

    axs[0].plot(losses, '-', label='Train Loss');
    axs[0].legend()
    axs[1].plot(train_acc, '-o', label='Train Acc');
    axs[1].plot(val_acc, '-o', label='Val Acc');
    axs[1].legend()

    plt.tight_layout()
    plt.show()

In [None]:
for pos_is_causal in [False, True]:
    for pos_normalize_magnitude in [False, True]:
        for pos_exponent in [1, 1.5, 2]:
            print("\n\n")
            print("==================================================================================")
            print("pos_is_causal:", pos_is_causal, " | pos_normalize_magnitude:", pos_normalize_magnitude,
                  " | pos_exponent:", pos_exponent)
            print("==================================================================================")
            train_eval_positional(num_epochs=1,
                                  pos_exponent=pos_exponent,
                                  pos_cutoff_position=10,
                                  pos_is_causal=pos_is_causal,
                                  pos_normalize_magnitude=pos_normalize_magnitude,
                                  # True trains on the dev set (debugging only)
                                  DEBUG_dummy_train=False)
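
The `pos_is_causal` flag swept above toggles causal attention, where position `i` may attend only to positions `j <= i`. The NumPy sketch below builds the mask and applies a masked softmax. In PyTorch, the mask is typically `torch.tril(torch.ones(n, n, dtype=torch.bool))`. How `model.py` combines it with the padding mask is not shown here.

```python
import numpy as np


def causal_mask(n: int) -> np.ndarray:
    """Boolean (n, n) mask: True where query position i may attend to key j (j <= i)."""
    return np.tril(np.ones((n, n), dtype=bool))


def causal_softmax(scores: np.ndarray) -> np.ndarray:
    """Apply a causal mask to (..., N, N) attention scores and normalize each row."""
    n = scores.shape[-1]
    masked = np.where(causal_mask(n), scores, -np.inf)
    masked = masked - masked.max(axis=-1, keepdims=True)
    w = np.exp(masked)
    return w / w.sum(axis=-1, keepdims=True)


# With all-zero scores, row i spreads its weight uniformly over positions 0..i.
w = causal_softmax(np.zeros((4, 4)))
print(w.round(2))
```

Because the diagonal is always allowed, every row keeps at least one unmasked entry, so the softmax never divides by zero.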