# Question answering on the SQuAD dataset

In [50]:
import sys
import datetime
import random
from functools import partial

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import sklearn
import torch
import torch.nn as nn
import torch.optim as optim
import gensim
import gensim.downloader as gloader
import wandb
import transformers
import tokenizers
from transformers.trainer_pt_utils import nested_detach
from transformers.trainer_utils import EvalPrediction
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.normalizers import Sequence, StripAccents, Lowercase, Strip
from tokenizers.pre_tokenizers import Sequence as PreSequence
from tokenizers.pre_tokenizers import Whitespace, Punctuation
from IPython.display import display, HTML

import dataset
import model

%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [51]:
plt.rcParams['figure.figsize'] = [8, 6]
plt.rcParams['figure.dpi'] = 100
plt.rcParams['axes.xmargin'] = .05
plt.rcParams['axes.ymargin'] = .05
plt.style.use('ggplot')

In [70]:
WANDB_PROJECT = "squad-qa"
WANDB_ENTITY = "wadaboa"
WANDB_MODE = "online"
WANDB_RESUME = "never"

init_wandb = partial(
    wandb.init,
    project=WANDB_PROJECT,
    entity=WANDB_ENTITY,
    mode=WANDB_MODE,
    resume=WANDB_RESUME,
)

## Preliminaries

### Data loading/preparation

In [194]:
squad_dataset = dataset.SquadDataset()

In [195]:
squad_dataset.dataframe

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,5733be284776f41900661182,To whom did the Virgin Mary allegedly appear i...,University_of_Notre_Dame,0,"Architecturally, the school has a Catholic cha...",Saint Bernadette Soubirous,515,541
1,5733be284776f4190066117f,What is in front of the Notre Dame Main Building?,University_of_Notre_Dame,0,"Architecturally, the school has a Catholic cha...",a copper statue of Christ,188,213
2,5733be284776f41900661180,The Basilica of the Sacred heart at Notre Dame...,University_of_Notre_Dame,0,"Architecturally, the school has a Catholic cha...",the Main Building,279,296
3,5733be284776f41900661181,What is the Grotto at Notre Dame?,University_of_Notre_Dame,0,"Architecturally, the school has a Catholic cha...",a Marian place of prayer and reflection,381,420
4,5733be284776f4190066117e,What sits on top of the Main Building at Notre...,University_of_Notre_Dame,0,"Architecturally, the school has a Catholic cha...",a golden statue of the Virgin Mary,92,126
...,...,...,...,...,...,...,...,...
87594,5735d259012e2f140011a09d,In what US state did Kathmandu first establish...,Kathmandu,18890,"Kathmandu Metropolitan City (KMC), in order to...",Oregon,229,235
87595,5735d259012e2f140011a09e,What was Yangon previously known as?,Kathmandu,18890,"Kathmandu Metropolitan City (KMC), in order to...",Rangoon,414,421
87596,5735d259012e2f140011a09f,With what Belorussian city does Kathmandu have...,Kathmandu,18890,"Kathmandu Metropolitan City (KMC), in order to...",Minsk,476,481
87597,5735d259012e2f140011a0a0,In what year did Kathmandu create its initial ...,Kathmandu,18890,"Kathmandu Metropolitan City (KMC), in order to...",1975,199,203


In [196]:
squad_dataset.train_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,5733be284776f41900661182,To whom did the Virgin Mary allegedly appear i...,University_of_Notre_Dame,0,"Architecturally, the school has a Catholic cha...",Saint Bernadette Soubirous,515,541
1,5733be284776f4190066117f,What is in front of the Notre Dame Main Building?,University_of_Notre_Dame,0,"Architecturally, the school has a Catholic cha...",a copper statue of Christ,188,213
2,5733be284776f41900661180,The Basilica of the Sacred heart at Notre Dame...,University_of_Notre_Dame,0,"Architecturally, the school has a Catholic cha...",the Main Building,279,296
3,5733be284776f41900661181,What is the Grotto at Notre Dame?,University_of_Notre_Dame,0,"Architecturally, the school has a Catholic cha...",a Marian place of prayer and reflection,381,420
4,5733be284776f4190066117e,What sits on top of the Main Building at Notre...,University_of_Notre_Dame,0,"Architecturally, the school has a Catholic cha...",a golden statue of the Virgin Mary,92,126
...,...,...,...,...,...,...,...,...
70185,5735d259012e2f140011a09d,In what US state did Kathmandu first establish...,Kathmandu,18890,"Kathmandu Metropolitan City (KMC), in order to...",Oregon,229,235
70186,5735d259012e2f140011a09e,What was Yangon previously known as?,Kathmandu,18890,"Kathmandu Metropolitan City (KMC), in order to...",Rangoon,414,421
70187,5735d259012e2f140011a09f,With what Belorussian city does Kathmandu have...,Kathmandu,18890,"Kathmandu Metropolitan City (KMC), in order to...",Minsk,476,481
70188,5735d259012e2f140011a0a0,In what year did Kathmandu create its initial ...,Kathmandu,18890,"Kathmandu Metropolitan City (KMC), in order to...",1975,199,203


In [197]:
squad_dataset.val_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,56be85543aeaaa14008c9063,When did Beyonce start becoming popular?,Beyoncé,55,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,in the late 1990s,269,286
1,56be85543aeaaa14008c9065,What areas did Beyonce compete in when she was...,Beyoncé,55,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,singing and dancing,207,226
2,56be85543aeaaa14008c9066,When did Beyonce leave Destiny's Child and bec...,Beyoncé,55,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,2003,526,530
3,56bf6b0f3aeaaa14008c9601,In what city and state did Beyonce grow up?,Beyoncé,55,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,"Houston, Texas",166,180
4,56bf6b0f3aeaaa14008c9602,In which decade did Beyonce become famous?,Beyoncé,55,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,late 1990s,276,286
...,...,...,...,...,...,...,...,...
17404,5732868bb3a91d1900202e0f,At what Augusta hole was the Eisenhower Pine l...,Dwight_D._Eisenhower,18592,"A loblolly pine, known as the ""Eisenhower Pine...",17th,74,78
17405,5732868bb3a91d1900202e10,How many meters away from the Masters tee on A...,Dwight_D._Eisenhower,18592,"A loblolly pine, known as the ""Eisenhower Pine...",192,110,113
17406,5732868bb3a91d1900202e11,What did Eisenhower want to be done to the Eis...,Dwight_D._Eisenhower,18592,"A loblolly pine, known as the ""Eisenhower Pine...",cut down,279,287
17407,5732868bb3a91d1900202e12,What damaged the Eisenhower Pine in February 2...,Dwight_D._Eisenhower,18592,"A loblolly pine, known as the ""Eisenhower Pine...",ice storm,478,487


In [198]:
squad_dataset.test_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,56be4db0acb8001400a502ec,Which NFL team represented the AFC at Super Bo...,Super_Bowl_50,0,Super Bowl 50 was an American football game to...,Denver Broncos,177,191
3,56be4db0acb8001400a502ed,Which NFL team represented the NFC at Super Bo...,Super_Bowl_50,0,Super Bowl 50 was an American football game to...,Carolina Panthers,249,266
6,56be4db0acb8001400a502ee,Where did Super Bowl 50 take place?,Super_Bowl_50,0,Super Bowl 50 was an American football game to...,"Santa Clara, California",403,426
7,56be4db0acb8001400a502ee,Where did Super Bowl 50 take place?,Super_Bowl_50,0,Super Bowl 50 was an American football game to...,Levi's Stadium,355,369
8,56be4db0acb8001400a502ee,Where did Super Bowl 50 take place?,Super_Bowl_50,0,Super Bowl 50 was an American football game to...,Levi's Stadium in the San Francisco Bay Area a...,355,427
...,...,...,...,...,...,...,...,...
34711,5737aafd1c456719005744fd,What is a very seldom used unit of mass in the...,Force,2066,"The pound-force has a metric counterpart, less...",slug,274,278
34712,5737aafd1c456719005744fd,What is a very seldom used unit of mass in the...,Force,2066,"The pound-force has a metric counterpart, less...",metric slug,267,278
34715,5737aafd1c456719005744fd,What is a very seldom used unit of mass in the...,Force,2066,"The pound-force has a metric counterpart, less...",the metric slug,263,278
34716,5737aafd1c456719005744fe,What seldom used term of a unit of force equal...,Force,2066,"The pound-force has a metric counterpart, less...",kip,712,715


In [95]:
squad_dataset.train_dataset[0]

('To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
 515,
 541)

### GloVe embeddings

In [8]:
UNK_TOKEN = "[UNK]"
PAD_TOKEN = "[PAD]"

- FastText: 
    - _fasttext-wiki-news-subwords_ (dimensions: 300)
- GloVe:
    - _glove-twitter_ (dimensions: 25. 50, 100, 200)
    - _glove-wiki-gigaword_ (dimensions: 50, 100, 200, 300)
- Word2Vec:
    - _word2vec-google-news_ (dimensions: 300)
    - _word2vec-ruscorpora_ (dimensions: 300)

In [9]:
def load_embedding_model(model_name, embedding_dimension=50):
    """
    Loads a pre-trained word embedding model via gensim library
    """
    model = f"{model_name}-{embedding_dimension}"
    try:
        return gloader.load(model)
    except Exception as e:
        print("Invalid embedding model name.")
        raise e


# See https://github.com/RaRe-Technologies/gensim-data
GLOVE_EMBEDDING_DIMENSION = 50
GLOVE_MODEL_NAME = "glove-twitter"
glove_embedding_model = load_embedding_model(
    GLOVE_MODEL_NAME, embedding_dimension=GLOVE_EMBEDDING_DIMENSION
)

In [10]:
glove_unk = np.mean(glove_embedding_model.vectors, axis=0)
glove_embedding_model.add(UNK_TOKEN, glove_unk)

In [11]:
glove_embedding_model[UNK_TOKEN]

array([-0.21896735,  0.17269313, -0.05617283,  0.06307325,  0.00960657,
       -0.23461065, -0.16731773, -0.25613925,  0.12990713, -0.34179848,
       -0.07411992,  0.00533567,  0.7090377 , -0.1139018 ,  0.10613882,
        0.09186497,  0.15880948,  0.03158554,  0.2241412 ,  0.20387109,
        0.05305386,  0.04961218,  0.11807557, -0.10199773, -0.18345806,
        0.56560194,  0.07183363,  0.04322447, -0.39442873,  0.06828266,
        0.39542177,  0.08794834,  0.41605434, -0.27820984, -0.5106833 ,
       -0.16443801,  0.0973425 ,  0.02233286,  0.19346187,  0.15909852,
        0.886585  , -0.01498107,  0.10211241, -0.12959567, -0.328366  ,
        0.13014658, -0.02061043,  0.05735753,  0.14008364,  0.22588447],
      dtype=float32)

In [12]:
list(glove_embedding_model.vocab.keys())[-1]

'[UNK]'

In [13]:
any(np.all(glove_embedding_model.vectors == 0, axis=1))

False

In [14]:
glove_embedding_model.add(PAD_TOKEN, np.zeros((1, GLOVE_EMBEDDING_DIMENSION)))
glove_embedding_model[PAD_TOKEN]

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
      dtype=float32)

In [15]:
list(glove_embedding_model.vocab.keys())[-1]

'[PAD]'

In [16]:
glove_embedding_model.vectors.shape

(1193516, 50)

In [17]:
glove_vocab = dict(
    zip(glove_embedding_model.index2word, range(len(glove_embedding_model.index2word)))
)

In [18]:
glove_embedding_layer = nn.Embedding(
    glove_embedding_model.vectors.shape[0],
    GLOVE_EMBEDDING_DIMENSION,
    padding_idx=glove_vocab[PAD_TOKEN],
)
glove_embedding_layer.weight = nn.Parameter(
    torch.from_numpy(glove_embedding_model.vectors)
)
glove_embedding_layer.weight.requires_grad = False

### Baseline tokenizer

In [19]:
MAX_CONTEXT_TOKENS = 300

In [96]:
baseline_question_tokenizer = Tokenizer(WordLevel(glove_vocab, unk_token=UNK_TOKEN))
baseline_question_tokenizer.normalizer = Sequence([StripAccents(), Lowercase(), Strip()])
baseline_question_tokenizer.pre_tokenizer = PreSequence([Whitespace(), Punctuation()])
baseline_question_tokenizer.enable_padding(
    direction="right",
    pad_id=glove_vocab[PAD_TOKEN],
    pad_type_id=1,
    pad_token=PAD_TOKEN
)

baseline_context_tokenizer = Tokenizer(WordLevel(glove_vocab, unk_token=UNK_TOKEN))
baseline_context_tokenizer.normalizer = Sequence([StripAccents(), Lowercase(), Strip()])
baseline_context_tokenizer.pre_tokenizer = PreSequence([Whitespace(), Punctuation()])
baseline_context_tokenizer.enable_padding(
    direction="right",
    pad_id=glove_vocab[PAD_TOKEN],
    pad_type_id=1,
    pad_token=PAD_TOKEN,
    length=MAX_CONTEXT_TOKENS,
)
baseline_context_tokenizer.enable_truncation(MAX_CONTEXT_TOKENS)

### Pre-processing

In [97]:
def lost_answers_indexes(df, tokenized_contexts):
    whole_answers_start = df["answer_start"].tolist()
    whole_answers_end = df["answer_end"].tolist()
    lost_dirty, lost_truncated = [], []
    for i, (c, s, e) in enumerate(
        zip(tokenized_contexts, whole_answers_start, whole_answers_end)
    ):
        offsets = torch.tensor(c.offsets)[torch.tensor(c.attention_mask).bool()]
        start_index = torch.nonzero(offsets[:, 0] == s)
        end_index = torch.nonzero(offsets[:, 1] == e)
        if len(start_index) == 0 or len(end_index) == 0:
            if s > offsets[-1, 0] or e > offsets[-1, 1]:
                lost_truncated.append(i)
            else:
                lost_dirty.append(i)

    return lost_truncated, lost_dirty


tokenized_contexts = baseline_context_tokenizer.encode_batch(
    squad_dataset.dataframe["context"].tolist()
)
lost_truncated, lost_dirty = lost_answers_indexes(
    squad_dataset.dataframe, tokenized_contexts
)

In [98]:
lost_truncated_perc = len(lost_truncated) / len(squad_dataset.dataframe)
lost_dirty_perc = len(lost_dirty) / len(squad_dataset.dataframe)
print(
    f"Lost {len(lost_truncated)}/{len(squad_dataset.dataframe)} ({lost_truncated_perc:.2%}) answers, because of truncation"
)
print(
    f"Lost {len(lost_dirty)}/{len(squad_dataset.dataframe)} ({lost_dirty_perc:.2%}) answers, because of dirtyness"
)

Lost 183/87599 (0.21%) answers, because of truncation
Lost 260/87599 (0.30%) answers, because of dirtyness


In [99]:
def show_df_row(df, index):
    row = df.iloc[index]
    display(HTML(pd.DataFrame([row]).to_html()))

In [100]:
random_trunc_index = random.choice(lost_truncated)
show_df_row(squad_dataset.dataframe, random_trunc_index)

Unnamed: 0,index,question,title,context_id,context,answer,answer_start,answer_end
455,56bf97aba10cfb14005511a1,How much did the second world tour make in dollars?,Beyoncé,69,"On April 4, 2008, Beyoncé married Jay Z. She publicly revealed their marriage in a video montage at the listening party for her third studio album, I Am... Sasha Fierce, in Manhattan's Sony Club on October 22, 2008. I Am... Sasha Fierce was released on November 18, 2008 in the United States. The album formally introduces Beyoncé's alter ego Sasha Fierce, conceived during the making of her 2003 single ""Crazy in Love"", selling 482,000 copies in its first week, debuting atop the Billboard 200, and giving Beyoncé her third consecutive number-one album in the US. The album featured the number-one song ""Single Ladies (Put a Ring on It)"" and the top-five songs ""If I Were a Boy"" and ""Halo"". Achieving the accomplishment of becoming her longest-running Hot 100 single in her career, ""Halo""'s success in the US helped Beyoncé attain more top-ten singles on the list than any other woman during the 2000s. It also included the successful ""Sweet Dreams"", and singles ""Diva"", ""Ego"", ""Broken-Hearted Girl"" and ""Video Phone"". The music video for ""Single Ladies"" has been parodied and imitated around the world, spawning the ""first major dance craze"" of the Internet age according to the Toronto Star. The video has won several awards, including Best Video at the 2009 MTV Europe Music Awards, the 2009 Scottish MOBO Awards, and the 2009 BET Awards. At the 2009 MTV Video Music Awards, the video was nominated for nine awards, ultimately winning three including Video of the Year. Its failure to win the Best Female Video category, which went to American country pop singer Taylor Swift's ""You Belong with Me"", led to Kanye West interrupting the ceremony and Beyoncé improvising a re-presentation of Swift's award during her own acceptance speech. In March 2009, Beyoncé embarked on the I Am... World Tour, her second headlining worldwide concert tour, consisting of 108 shows, grossing $119.5 million.",119.5 million,1881,1894


In [102]:
random_dirty_index = random.choice(lost_dirty)
show_df_row(squad_dataset.dataframe, random_dirty_index)

Unnamed: 0,index,question,title,context_id,context,answer,answer_start,answer_end
14434,56e080dc231d4119001ac20d,How many satellites provide the link to the internet?,Saint_Helena,3142,"Saint Helena has a 10/3.6 Mbit/s internet link via Intelsat 707 provided by SURE. Serving a population of more than 4,000, this single satellite link is considered inadequate in terms of bandwidth.",1,19,20


In [103]:
to_remove = lost_truncated + lost_dirty
df = squad_dataset.dataframe.drop(to_remove)
assert len(df) == len(squad_dataset.dataframe) - len(to_remove), (
    f"Before {len(squad_dataset.dataframe)}, "
    f"after {len(df)}, "
    f"removed {len(to_remove)}"
)

In [104]:
squad_dataset.update_df(df)
squad_dataset.dataframe

Unnamed: 0,index,question,title,context_id,context,answer,answer_start,answer_end
0,5733be284776f41900661182,To whom did the Virgin Mary allegedly appear i...,University_of_Notre_Dame,0,"Architecturally, the school has a Catholic cha...",Saint Bernadette Soubirous,515,541
1,5733be284776f4190066117f,What is in front of the Notre Dame Main Building?,University_of_Notre_Dame,0,"Architecturally, the school has a Catholic cha...",a copper statue of Christ,188,213
2,5733be284776f41900661180,The Basilica of the Sacred heart at Notre Dame...,University_of_Notre_Dame,0,"Architecturally, the school has a Catholic cha...",the Main Building,279,296
3,5733be284776f41900661181,What is the Grotto at Notre Dame?,University_of_Notre_Dame,0,"Architecturally, the school has a Catholic cha...",a Marian place of prayer and reflection,381,420
4,5733be284776f4190066117e,What sits on top of the Main Building at Notre...,University_of_Notre_Dame,0,"Architecturally, the school has a Catholic cha...",a golden statue of the Virgin Mary,92,126
...,...,...,...,...,...,...,...,...
87594,5735d259012e2f140011a09d,In what US state did Kathmandu first establish...,Kathmandu,18890,"Kathmandu Metropolitan City (KMC), in order to...",Oregon,229,235
87595,5735d259012e2f140011a09e,What was Yangon previously known as?,Kathmandu,18890,"Kathmandu Metropolitan City (KMC), in order to...",Rangoon,414,421
87596,5735d259012e2f140011a09f,With what Belorussian city does Kathmandu have...,Kathmandu,18890,"Kathmandu Metropolitan City (KMC), in order to...",Minsk,476,481
87597,5735d259012e2f140011a0a0,In what year did Kathmandu create its initial ...,Kathmandu,18890,"Kathmandu Metropolitan City (KMC), in order to...",1975,199,203


### Metrics and logs

In [28]:
def exact_match(labels, preds):
    assert labels.shape == preds.shape
    total = labels.shape[0]
    matches = np.count_nonzero((labels == preds).all(axis=1))
    return matches / total

In [40]:
def compute_metrics(eval_prediction):
    # Labels are stored as a tuple of tensors
    # (one for answer start and one for answer end)
    labels = torch.cat(
        [
            eval_prediction.label_ids[0].view(-1, 1),
            eval_prediction.label_ids[1].view(-1, 1),
        ],
        dim=1,
    ).numpy()
    preds = eval_prediction.predictions.numpy()
    f_labels, f_preds = labels.flatten(), preds.flatten()
    return {
        "f1": sklearn.metrics.f1_score(f_labels, f_preds, average="macro"),
        "accuracy": sklearn.metrics.accuracy_score(f_labels, f_preds),
        "em": exact_match(labels, preds),
    }

In [79]:
class SquadTrainer(transformers.Trainer):
    def __init__(self, *args, **kwargs):
        super(SquadTrainer, self).__init__(*args, **kwargs)

    def compute_loss(self, model, inputs):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
        """
        outputs = model(**inputs)

        if (self.state.global_step == 1 and self.args.logging_first_step) or (
            self.args.logging_steps > 0
            and self.state.global_step > 0
            and self.state.global_step % self.args.logging_steps == 0
        ):
            has_labels = all(inputs.get(k) is not None for k in self.label_names)
            if has_labels:
                labels = nested_detach(
                    tuple(inputs.get(name) for name in self.label_names)
                )
                if len(labels) == 1:
                    labels = labels[0]
            else:
                labels = None

            if labels is not None:
                metrics = self.compute_metrics(
                    EvalPrediction(predictions=outputs["outputs"], label_ids=labels)
                )
                self.log(metrics)

        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]
        # We don't use .loss here since the model may return tuples instead of ModelOutput.
        return outputs["loss"] if isinstance(outputs, dict) else outputs[0]

## Baseline

In [80]:
class BaselineDataCollatorWithPadding:
    """
    Data collator that will dynamically pad the inputs received
    """

    def __init__(self, question_tokenizer, context_tokenizer):
        self.question_tokenizer = question_tokenizer
        self.context_tokenizer = context_tokenizer

    def find_tokenized_answer_indexes(self, offsets, query, dim):
        index = torch.nonzero(offsets[:, dim] == query)
        assert len(index) in (0, 1)
        return index.item() if len(index) > 0 else -1

    def __call__(self, inputs):
        (questions, contexts, answers_start, answers_end) = zip(*inputs)
        tokenized_questions = self.question_tokenizer.encode_batch(questions)
        tokenized_contexts = self.context_tokenizer.encode_batch(contexts)
        batch_size = len(tokenized_questions)
        questions_shape = (batch_size, len(tokenized_questions[0].ids))
        contexts_shape = (batch_size, len(tokenized_contexts[0].ids))

        batch = {
            "question_ids": torch.empty(questions_shape, dtype=torch.long),
            "question_type_ids": torch.empty(questions_shape, dtype=torch.long),
            "question_attention_mask": torch.empty(questions_shape, dtype=torch.bool),
            "question_special_tokens_mask": torch.empty(
                questions_shape, dtype=torch.bool
            ),
            "question_lenghts": torch.empty((batch_size,), dtype=torch.long),
            "context_ids": torch.empty(contexts_shape, dtype=torch.long),
            "context_type_ids": torch.empty(contexts_shape, dtype=torch.long),
            "context_attention_mask": torch.empty(contexts_shape, dtype=torch.bool),
            "context_special_tokens_mask": torch.empty(
                contexts_shape, dtype=torch.bool
            ),
            "context_offsets": torch.empty((*contexts_shape, 2), dtype=torch.long),
            "context_lenghts": torch.empty((batch_size,), dtype=torch.long),
            # "answer_start": torch.tensor(answers_start, dtype=torch.long),
            # "answer_end": torch.tensor(answers_end, dtype=torch.long),
            "answer_start": torch.empty((batch_size,), dtype=torch.long),
            "answer_end": torch.empty((batch_size,), dtype=torch.long),
        }
        for i in range(batch_size):
            batch["question_ids"][i] = torch.tensor(tokenized_questions[i].ids)
            batch["question_type_ids"][i] = torch.tensor(
                tokenized_questions[i].type_ids
            )
            batch["question_attention_mask"][i] = torch.tensor(
                tokenized_questions[i].attention_mask
            )
            batch["question_special_tokens_mask"][i] = torch.tensor(
                tokenized_questions[i].special_tokens_mask
            )
            batch["question_lenghts"][i] = torch.count_nonzero(
                ~batch["question_special_tokens_mask"][i]
            )
            batch["context_ids"][i] = torch.tensor(tokenized_contexts[i].ids)
            batch["context_type_ids"][i] = torch.tensor(tokenized_contexts[i].type_ids)
            batch["context_attention_mask"][i] = torch.tensor(
                tokenized_contexts[i].attention_mask
            )
            batch["context_special_tokens_mask"][i] = torch.tensor(
                tokenized_contexts[i].special_tokens_mask
            )
            batch["context_offsets"][i] = torch.tensor(tokenized_contexts[i].offsets)
            batch["context_lenghts"][i] = torch.count_nonzero(
                ~batch["context_special_tokens_mask"][i]
            )
            masked_offsets = batch["context_offsets"][i][batch["context_attention_mask"][i]]
            batch["answer_start"][i] = self.find_tokenized_answer_indexes(
                masked_offsets, answers_start[i], 0
            )
            batch["answer_end"][i] = self.find_tokenized_answer_indexes(
                masked_offsets, answers_end[i], 1
            )

        return batch

### Baseline model

In [81]:
baseline_model = model.QABaselineModel(glove_embedding_layer, MAX_CONTEXT_TOKENS)
baseline_model.count_parameters()

101400

In [82]:
baseline_args = transformers.TrainingArguments(
    output_dir="./checkpoints",
    logging_dir="./runs",
    logging_first_step=True,
    logging_steps=5,
    overwrite_output_dir=True,
    evaluation_strategy="epoch",
    learning_rate=1e-3,
    num_train_epochs=10,
    remove_unused_columns=False,
    per_device_train_batch_size=256,
    per_device_eval_batch_size=64,
    label_names=["answer_start", "answer_end"],
)

In [83]:
baseline_trainer = SquadTrainer(
    model=baseline_model,
    args=baseline_args,
    data_collator=BaselineDataCollatorWithPadding(
        baseline_question_tokenizer, baseline_context_tokenizer
    ),
    train_dataset=squad_dataset.train_dataset,
    eval_dataset=squad_dataset.val_dataset,
    compute_metrics=compute_metrics,
)

In [85]:
run_name = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
wandb_logger = init_wandb(name=run_name, group="baseline", reinit=True,)

In [86]:
baseline_trainer.train()

./checkpoints


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

In [87]:
wandb_logger.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train/loss,11.31863
train/learning_rate,0.001
train/epoch,0.03663
train/f1,0.0
train/accuracy,0.0
train/em,0.0
_step,10.0
_runtime,118.0
_timestamp,1610467892.0


0,1
train/loss,█▆▁
train/learning_rate,█▅▁
train/epoch,▁▄█
train/f1,█▇▁
train/accuracy,█▃▁
train/em,▁▁▁
_step,▁▄█
_runtime,▁▅█
_timestamp,▁▅█


### BiDAF

In [354]:
bidaf_model = model.BiDAFModel(glove_embedding_layer)
bidaf_model.count_parameters()

314350

In [355]:
bidaf_args = transformers.TrainingArguments(
    output_dir="./checkpoints",
    logging_dir="./runs",
    overwrite_output_dir=True,
    evaluation_strategy="epoch",
    num_train_epochs=3,
    remove_unused_columns=False,
    per_device_train_batch_size=256,
    per_device_eval_batch_size=64,
    label_names=["answer_start", "answer_end"]
)

In [356]:
bidaf_optimizer = optim.Adadelta(bidaf_model.parameters(), lr=0.5)
bidaf_lr_scheduler = optim.lr_scheduler.ExponentialLR(bidaf_optimizer, gamma=.999)

In [357]:
bidaf_trainer = transformers.Trainer(
    model=bidaf_model,
    args=bidaf_args,
    data_collator=BaselineDataCollatorWithPadding(baseline_question_tokenizer, baseline_context_tokenizer),
    train_dataset=squad_dataset.train_dataset,
    eval_dataset=squad_dataset.val_dataset,
    optimizers=(bidaf_optimizer, bidaf_lr_scheduler),
    #callbacks=[transformers.integrations.WandbCallback]
)

In [358]:
baseline_trainer.train()

torch.Size([256, 300, 1]) torch.Size([256, 300, 1])
tensor(15.0205, grad_fn=<NllLossBackward>) tensor(13.8060, grad_fn=<NllLossBackward>)


KeyboardInterrupt: 

## BERT

In [277]:
tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
model = transformers.BertModel.from_pretrained("bert-base-uncased")

In [58]:
args = transformers.TrainingArguments(
    output_dir="./checkpoints",
    logging_dir="./runs",
    overwrite_output_dir=True,
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=3,
    remove_unused_columns=False,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    label_names=["answer_start", "answer_end"],
)

In [59]:
class DataCollatorWithPadding:
    """
    Data collator that will dynamically pad the inputs received
    """

    def __init__(self, tokenizer, padding=True, max_lenght=None):
        self.tokenizer = tokenizer
        self.padding = True
        self.max_length = None

    def __call__(self, inputs):
        (questions, contexts, _, _) = zip(*inputs)
        tokenized = self.tokenizer(questions, contexts)
        batch = self.tokenizer.pad(
            tokenized,
            padding=self.padding,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return batch

In [60]:
data_collator = DataCollatorWithPadding(tokenizer)

NameError: name 'tokenizer' is not defined

In [None]:
trainer = transformers.Trainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=squad_dataset.train_dataset,
    eval_dataset=squad_dataset.val_dataset
)

In [288]:
trainer.train()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])


KeyError: 'loss'