In [None]:
import torch
import data_loader
from traineval import train, evaluate
import model as model

import matplotlib.pyplot as plt

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Loading the Data

The following cells invoke data_loader, which will automatically download and extract the dataset if needed.
They instantiate the following variables:
* tokens_vocab - the sentence words vocabulary
* y_vocab - the labels (senses) vocabulary
* train_dataset, dev_dataset - the training and development WSDDataset instances (loaded in the two cells below)

Use the optional sentence_count kwarg to limit the number of sentences loaded.

In [None]:
# Load the training split. The loader also returns the token vocabulary and
# the label vocabulary, both of which are passed on when loading the dev split.
train_dataset, tokens_vocab, y_vocab = data_loader.load_train_dataset()
# Bare last expression: the notebook displays the dataset's repr.
train_dataset

In [None]:
# Load the dev split, passing in the vocabularies returned with the training split
# (presumably so both splits share the same token/label ids -- confirm in data_loader).
dev_dataset = data_loader.load_dev_dataset(tokens_vocab, y_vocab)
# Bare last expression: the notebook displays the dataset's repr.
dev_dataset

## Part 1: Query-Based Attention

Implement the model.

Load the model.

In [None]:
# Part 1 model hyper-parameters.
dropout = 0.25
D = 300

# Size the model from both vocabularies, then move it to the selected device.
m = model.WSDModel(tokens_vocab.size(), y_vocab.size(), D=D, dropout_prob=dropout)
m = m.to(device)

Train the model.

In [None]:
# Part 1 optimisation settings.
lr = 8e-5
batch_size = 100
num_epochs = 10

optimizer = torch.optim.Adam(m.parameters(), lr=lr)

# train() returns the loss history and the train/validation accuracy histories.
losses, train_acc, val_acc = train(
    m,
    optimizer,
    train_dataset,
    dev_dataset,
    num_epochs=num_epochs,
    batch_size=batch_size,
)

# Report the last recorded accuracies.
print(f"Validation accuracy: {val_acc[-1]:.3f}, Training accuracy:{train_acc[-1]:.3f}")

Plot the loss and training/validation accuracy. You should be getting ~54% validation accuracy after 10 epochs.

In [None]:
# Two stacked panels: training loss on top, accuracy curves below.
fig, axs = plt.subplots(nrows=2, figsize=(15, 6))
loss_ax, acc_ax = axs

loss_ax.plot(losses, '-', label='Train Loss')
loss_ax.legend()

# Both accuracy curves share one panel and one marker style.
for series, series_label in ((train_acc, 'Train Acc'), (val_acc, 'Val Acc')):
    acc_ax.plot(series, '-o', label=series_label)
acc_ax.legend()

plt.tight_layout()

Use the attention visualization to get a feel for what the model is attending to.

The query token is highlighted in green, and the model's attention is shown with a pink-blue gradient.
In addition, the loss is shown with a red gradient.

In [None]:
# Note: `higlight_samples` (sic) is the name under which traineval exposes this helper.
from traineval import higlight_samples

# Visualize the Part 1 model on examples from the dev set.
# sample_size=5 presumably selects how many examples are shown -- confirm in traineval.
higlight_samples(m, dev_dataset, sample_size=5)

## Part 2: Padding

Implement the padding mask.

Load the model and retrain.

In [None]:
# Build a new model instance instead of continuing with the Part 1 one.
# D, dropout, lr, batch_size and num_epochs are not set here; they come from earlier cells.
m = model.WSDModel(tokens_vocab.size(), y_vocab.size(), D=D, dropout_prob=dropout)
m = m.to(device)

# A new optimizer is needed because it must track the new model's parameters.
optimizer = torch.optim.Adam(m.parameters(), lr=lr)

losses, train_acc, val_acc = train(
    m,
    optimizer,
    train_dataset,
    dev_dataset,
    num_epochs=num_epochs,
    batch_size=batch_size,
)

Plot the loss and training/validation accuracy.

In [None]:
# Two stacked panels: training loss on top, accuracy curves below.
fig, axs = plt.subplots(nrows=2, figsize=(15, 6))
loss_ax, acc_ax = axs

loss_ax.plot(losses, '-', label='Train Loss')
loss_ax.legend()

# Both accuracy curves share one panel and one marker style.
for series, series_label in ((train_acc, 'Train Acc'), (val_acc, 'Val Acc')):
    acc_ax.plot(series, '-o', label=series_label)
acc_ax.legend()

plt.tight_layout()

Use the visualization to verify that the model does not attend to padding tokens.

In [None]:
# Visualize the retrained Part 2 model on dev-set examples.
# Relies on `higlight_samples` having been imported in an earlier cell.
higlight_samples(m, dev_dataset, sample_size=5)

Examine additional examples, using the API and pandas as demonstrated below.

In [None]:
import pandas as pd
import numpy as np
from traineval import evaluate_verbose, highlight

# Use the fully qualified option key. pandas matches option names by pattern, and on
# pandas >= 1.4 the bare 'max_columns' also matches 'styler.render.max_columns',
# which makes set_option raise OptionError ("Pattern matched multiple keys").
pd.set_option('display.max_columns', 100)

# Evaluate the current model on the dev set and keep the two returned tables:
# per-example results (eval_df) and attention values (attention_df).
# Both are presumably pandas DataFrames, and iter_lim=100 presumably caps the number
# of evaluation iterations -- confirm both in traineval.
eval_df, attention_df = evaluate_verbose(m, dev_dataset, iter_lim=100)

Visualization of 5 incorrectly classified examples.

In [None]:
# Positions of rows whose prediction differs from the gold label.
# np.where returns a 1-tuple here, so take element [0] and keep the first five hits.
idxs = list(np.where(eval_df['y_true'] != eval_df['y_pred'])[0][:5])
highlight(eval_df, attention_df, idxs)

Visualization of examples with the query word "left".

In [None]:
# Positions of rows whose query token is "left".
# np.where called with only a condition returns a tuple of index arrays, so unpack
# element [0] and pass a plain list of positions, the same form the
# misclassified-examples cell hands to highlight().
idxs = list(np.where(eval_df['query_token'] == 'left')[0])
highlight(eval_df, attention_df, idxs)

## Part 3: Self-Attention

The method below converts the query-based instances in WSDDataset to sentence-level instances in WSDSentencesDataset for self-attention.

Notice how the number of samples now equals the number of sentences.

In [None]:
# Build the sentence-level training dataset for self-attention from the
# query-based training dataset.
sa_train_dataset = data_loader.WSDSentencesDataset.from_word_dataset(train_dataset)
# Bare last expression: the notebook displays the dataset's repr.
sa_train_dataset

In [None]:
# Build the sentence-level dev dataset for self-attention from the
# query-based dev dataset.
sa_dev_dataset = data_loader.WSDSentencesDataset.from_word_dataset(dev_dataset)
# Bare last expression: the notebook displays the dataset's repr.
sa_dev_dataset

Implement self-attention in the model.

Load the model and retrain.

In [None]:
# Hyper-parameters for the self-attention run; these overwrite the values set in earlier cells.
lr = 2e-4
dropout = 0.2
D = 300
batch_size = 100
num_epochs = 5

# Build a new model instance and move it to the selected device.
m = model.WSDModel(tokens_vocab.size(), y_vocab.size(), D=D, dropout_prob=dropout)
m = m.to(device)

optimizer = torch.optim.Adam(m.parameters(), lr=lr)

# Train and validate on the sentence-level datasets.
losses, train_acc, val_acc = train(
    m,
    optimizer,
    sa_train_dataset,
    sa_dev_dataset,
    num_epochs=num_epochs,
    batch_size=batch_size,
)

Plot the loss and training/validation accuracy.

In [None]:
# Two stacked panels: training loss on top, accuracy curves below.
fig, axs = plt.subplots(nrows=2, figsize=(15, 6))
loss_ax, acc_ax = axs

loss_ax.plot(losses, '-', label='Train Loss')
loss_ax.legend()

# Both accuracy curves share one panel and one marker style.
for series, series_label in ((train_acc, 'Train Acc'), (val_acc, 'Val Acc')):
    acc_ax.plot(series, '-o', label=series_label)
acc_ax.legend()

plt.tight_layout()

## Part 4: Position-Sensitive Attention

In [None]:
# TODO: your experiments here

## Part 5: Causal Attention

In [None]:
# TODO: your experiments here