# Question answering on the SQuAD dataset

In [289]:
import sys
import random
from functools import partial

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import gensim
import gensim.downloader as gloader
import wandb
import transformers
import tokenizers
from transformers.trainer_pt_utils import nested_detach
from transformers.trainer_utils import EvalPrediction
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.normalizers import NFD, Sequence, StripAccents, Lowercase, Strip
from tokenizers.pre_tokenizers import Sequence as PreSequence
from tokenizers.pre_tokenizers import Whitespace, Punctuation
from tokenizers import BertWordPieceTokenizer

import dataset
import model
import training

# Reload edited modules (e.g. the local dataset/model/training modules)
# automatically before each cell runs, so code changes are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline below the cells that create them.
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [290]:
# Notebook-wide matplotlib defaults: figure size/resolution and a small
# margin around the data, followed by the ggplot look.
plt.rcParams.update({
    'figure.figsize': [8, 6],
    'figure.dpi': 100,
    'axes.xmargin': .05,
    'axes.ymargin': .05,
})
plt.style.use('ggplot')

In [291]:
# Weights & Biases settings shared by every experiment in this notebook.
WANDB_PROJECT = "squad-qa"
WANDB_ENTITY = "wadaboa"
WANDB_MODE = "online"
WANDB_RESUME = "never"

# Pre-bind the shared settings, so each experiment only supplies its own
# run-specific arguments (name, group, ...) when calling init_wandb(...).
_wandb_defaults = {
    "project": WANDB_PROJECT,
    "entity": WANDB_ENTITY,
    "mode": WANDB_MODE,
    "resume": WANDB_RESUME,
}
init_wandb = partial(wandb.init, **_wandb_defaults)

## Preliminaries

### Raw data loading

In [336]:
# Load the raw SQuAD splits (inspected below as raw_train_df / raw_test_df).
squad_dataset = dataset.SquadDataset()

In [337]:
# Raw training split: one row per (question, answer) pair; answer_start and
# answer_end are character offsets of the answer inside `context`.
squad_dataset.raw_train_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,5733be284776f41900661182,To whom did the Virgin Mary allegedly appear i...,University_of_Notre_Dame,0,"Architecturally, the school has a Catholic cha...",Saint Bernadette Soubirous,515,541
1,5733be284776f4190066117f,What is in front of the Notre Dame Main Building?,University_of_Notre_Dame,0,"Architecturally, the school has a Catholic cha...",a copper statue of Christ,188,213
2,5733be284776f41900661180,The Basilica of the Sacred heart at Notre Dame...,University_of_Notre_Dame,0,"Architecturally, the school has a Catholic cha...",the Main Building,279,296
3,5733be284776f41900661181,What is the Grotto at Notre Dame?,University_of_Notre_Dame,0,"Architecturally, the school has a Catholic cha...",a Marian place of prayer and reflection,381,420
4,5733be284776f4190066117e,What sits on top of the Main Building at Notre...,University_of_Notre_Dame,0,"Architecturally, the school has a Catholic cha...",a golden statue of the Virgin Mary,92,126
...,...,...,...,...,...,...,...,...
87594,5735d259012e2f140011a09d,In what US state did Kathmandu first establish...,Kathmandu,18890,"Kathmandu Metropolitan City (KMC), in order to...",Oregon,229,235
87595,5735d259012e2f140011a09e,What was Yangon previously known as?,Kathmandu,18890,"Kathmandu Metropolitan City (KMC), in order to...",Rangoon,414,421
87596,5735d259012e2f140011a09f,With what Belorussian city does Kathmandu have...,Kathmandu,18890,"Kathmandu Metropolitan City (KMC), in order to...",Minsk,476,481
87597,5735d259012e2f140011a0a0,In what year did Kathmandu create its initial ...,Kathmandu,18890,"Kathmandu Metropolitan City (KMC), in order to...",1975,199,203


In [338]:
# Raw test split: a question with several reference answers appears once per
# answer (note the repeated question_id rows).
squad_dataset.raw_test_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,56be4db0acb8001400a502ec,Which NFL team represented the AFC at Super Bo...,Super_Bowl_50,0,Super Bowl 50 was an American football game to...,Denver Broncos,177,191
1,56be4db0acb8001400a502ed,Which NFL team represented the NFC at Super Bo...,Super_Bowl_50,0,Super Bowl 50 was an American football game to...,Carolina Panthers,249,266
2,56be4db0acb8001400a502ee,Where did Super Bowl 50 take place?,Super_Bowl_50,0,Super Bowl 50 was an American football game to...,"Santa Clara, California",403,426
3,56be4db0acb8001400a502ee,Where did Super Bowl 50 take place?,Super_Bowl_50,0,Super Bowl 50 was an American football game to...,Levi's Stadium,355,369
4,56be4db0acb8001400a502ee,Where did Super Bowl 50 take place?,Super_Bowl_50,0,Super Bowl 50 was an American football game to...,Levi's Stadium in the San Francisco Bay Area a...,355,427
...,...,...,...,...,...,...,...,...
18211,5737aafd1c456719005744fd,What is a very seldom used unit of mass in the...,Force,2066,"The pound-force has a metric counterpart, less...",slug,274,278
18212,5737aafd1c456719005744fd,What is a very seldom used unit of mass in the...,Force,2066,"The pound-force has a metric counterpart, less...",metric slug,267,278
18213,5737aafd1c456719005744fd,What is a very seldom used unit of mass in the...,Force,2066,"The pound-force has a metric counterpart, less...",the metric slug,263,278
18214,5737aafd1c456719005744fe,What seldom used term of a unit of force equal...,Force,2066,"The pound-force has a metric counterpart, less...",kip,712,715


### Embeddings

In [8]:
UNK_TOKEN = "[UNK]"
PAD_TOKEN = "[PAD]"

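Pre-trained embedding models available through `gensim-data`; `load_embedding_model` below appends the chosen dimension to the name (e.g. `glove-twitter-50`):

- FastText: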
    - _fasttext-wiki-news-subwords_ (dimensions: 300)
- GloVe:
    - _glove-twitter_ (dimensions: 25, 50, 100, 200)
    - _glove-wiki-gigaword_ (dimensions: 50, 100, 200, 300)
- Word2Vec:
    - _word2vec-google-news_ (dimensions: 300)
    - _word2vec-ruscorpora_ (dimensions: 300)

In [9]:
def load_embedding_model(model_name, embedding_dimension=50):
    """
    Load a pre-trained word embedding model through the gensim downloader.

    The gensim-data identifier is built as ``f"{model_name}-{embedding_dimension}"``
    (e.g. ``"glove-twitter"`` and ``50`` give ``"glove-twitter-50"``); see
    https://github.com/RaRe-Technologies/gensim-data for the available ids.

    Args:
        model_name: gensim-data model family, without the dimension suffix.
        embedding_dimension: size of the embedding vectors (default 50).

    Returns:
        Whatever ``gensim.downloader.load`` returns for that identifier.

    Raises:
        Re-raises, unchanged, any exception raised by ``gensim.downloader.load``.
    """
    # Named `model_id` (not `model`) so it does not shadow the project
    # `model` module imported at the top of the notebook.
    model_id = f"{model_name}-{embedding_dimension}"
    try:
        return gloader.load(model_id)
    except Exception:
        # A failure is not necessarily a bad name (the download itself can
        # fail), so report the requested id and re-raise the original error
        # with its traceback intact.
        print(f"Could not load embedding model '{model_id}'.")
        raise


# See https://github.com/RaRe-Technologies/gensim-data
# Load the 50-dimensional Twitter GloVe vectors (gensim id "glove-twitter-50").
GLOVE_EMBEDDING_DIMENSION = 50
GLOVE_MODEL_NAME = "glove-twitter"
glove_embedding_model = load_embedding_model(
    GLOVE_MODEL_NAME, embedding_dimension=GLOVE_EMBEDDING_DIMENSION
)

In [10]:
# The [UNK] vector is the centroid of every GloVe vector; it is appended to
# the vocabulary as a new entry.
glove_unk = glove_embedding_model.vectors.mean(axis=0)
glove_embedding_model.add(UNK_TOKEN, glove_unk)

In [11]:
# Inspect the [UNK] vector just added (mean of all GloVe vectors).
glove_embedding_model[UNK_TOKEN]

array([-0.21896735,  0.17269313, -0.05617283,  0.06307325,  0.00960657,
       -0.23461065, -0.16731773, -0.25613925,  0.12990713, -0.34179848,
       -0.07411992,  0.00533567,  0.7090377 , -0.1139018 ,  0.10613882,
        0.09186497,  0.15880948,  0.03158554,  0.2241412 ,  0.20387109,
        0.05305386,  0.04961218,  0.11807557, -0.10199773, -0.18345806,
        0.56560194,  0.07183363,  0.04322447, -0.39442873,  0.06828266,
        0.39542177,  0.08794834,  0.41605434, -0.27820984, -0.5106833 ,
       -0.16443801,  0.0973425 ,  0.02233286,  0.19346187,  0.15909852,
        0.886585  , -0.01498107,  0.10211241, -0.12959567, -0.328366  ,
        0.13014658, -0.02061043,  0.05735753,  0.14008364,  0.22588447],
      dtype=float32)

In [12]:
# The most recently added vocabulary entry should now be [UNK].
# Indexing index2word is O(1); list(vocab.keys())[-1] copied ~1.19M keys
# just to read the last one.
glove_embedding_model.index2word[-1]

'[UNK]'

In [13]:
# Sanity check before using an all-zero vector for [PAD]: no existing row
# (words or the [UNK] mean) may already be all zeros.
# Vectorised: the builtin any() iterated the 1.19M-row result in Python.
bool((glove_embedding_model.vectors == 0).all(axis=1).any())

False

In [14]:
# Append an all-zero [PAD] vector (unambiguous, since the check above found no
# existing all-zero row) and display it.
glove_embedding_model.add(PAD_TOKEN, np.zeros((1, GLOVE_EMBEDDING_DIMENSION)))
glove_embedding_model[PAD_TOKEN]

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
      dtype=float32)

In [15]:
# The most recently added vocabulary entry should now be [PAD].
# O(1) lookup instead of materialising every vocabulary key into a list.
glove_embedding_model.index2word[-1]

'[PAD]'

In [16]:
# (vocabulary size, embedding dimension), now including [UNK] and [PAD].
glove_embedding_model.vectors.shape

(1193516, 50)

In [17]:
# Token -> row index in the GloVe matrix, used as the word-level tokenizer
# vocabulary and to locate the [PAD] row.
glove_vocab = {
    token: row for row, token in enumerate(glove_embedding_model.index2word)
}

In [18]:
# Frozen embedding layer initialised directly from the GloVe matrix
# (including the appended [UNK] and [PAD] rows); [PAD] is the padding index.
# from_pretrained reuses the given tensor as the weight, which avoids randomly
# initialising a throw-away ~1.19M x 50 matrix first. freeze=True sets
# requires_grad=False, so the GloVe vectors are not trained.
glove_embedding_layer = nn.Embedding.from_pretrained(
    torch.from_numpy(glove_embedding_model.vectors),
    freeze=True,
    padding_idx=glove_vocab[PAD_TOKEN],
)

### Standard tokenizer and preprocessing

In [19]:
# Contexts are truncated/padded to exactly this many tokens by the context
# tokenizer below (questions are padded per batch instead).
MAX_CONTEXT_TOKENS = 300

In [295]:
def build_glove_tokenizer(vocab, max_length=None):
    """
    Build a word-level tokenizer over a GloVe-style vocabulary.

    Text is NFD-decomposed, stripped of accents, lowercased and stripped of
    surrounding whitespace, then split on whitespace and punctuation.
    Out-of-vocabulary words map to UNK_TOKEN; sequences are right-padded
    with PAD_TOKEN.

    Args:
        vocab: mapping token -> id; must contain UNK_TOKEN and PAD_TOKEN.
        max_length: if given, every encoding is truncated and padded to
            exactly this many tokens; if None, padding is to the longest
            sequence in the batch and nothing is truncated.

    Returns:
        A configured tokenizers.Tokenizer.
    """
    tokenizer = Tokenizer(WordLevel(vocab, unk_token=UNK_TOKEN))
    # StripAccents only removes combining marks, so the text must be
    # decomposed with NFD first; otherwise precomposed characters such as the
    # "é" in "Beyoncé" keep their accent and stripping silently does nothing.
    tokenizer.normalizer = Sequence([NFD(), StripAccents(), Lowercase(), Strip()])
    tokenizer.pre_tokenizer = PreSequence([Whitespace(), Punctuation()])
    tokenizer.enable_padding(
        direction="right",
        pad_id=vocab[PAD_TOKEN],
        pad_type_id=1,
        pad_token=PAD_TOKEN,
        length=max_length,
    )
    if max_length is not None:
        tokenizer.enable_truncation(max_length)
    return tokenizer


# Questions: padded to the longest question in each batch, never truncated.
standard_question_tokenizer = build_glove_tokenizer(glove_vocab)
# Contexts: fixed length, truncated/padded to MAX_CONTEXT_TOKENS.
standard_context_tokenizer = build_glove_tokenizer(
    glove_vocab, max_length=MAX_CONTEXT_TOKENS
)

In [339]:
# Combine the question and context tokenizers into the project's SQuAD
# tokenizer; it is used as the data collator for the GloVe-based models.
standard_tokenizer = dataset.StandardSquadTokenizer(
    standard_question_tokenizer, standard_context_tokenizer
)

In [340]:
# Train/validation/test frames and datasets built from the raw splits, paired
# with the standard (GloVe word-level) tokenizer.
standard_dm = dataset.SquadDataManager(squad_dataset, standard_tokenizer)

In [341]:
# Processed training frame: one row per question, with answer, answer_start
# and answer_end stored as lists.
standard_dm.train_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,56cc239e6d243a140015eeb7,Who were Wang Jiawei and Nyima Gyaincain?,Sino-Tibetan_relations_during_the_Ming_dynasty,299,The exact nature of relations between Tibet an...,[Mainland Chinese scholars],[274],[299]
1,56cc27346d243a140015eeba,What important trade did the Ming Dynasty have...,Sino-Tibetan_relations_during_the_Ming_dynasty,300,Some scholars note that Tibetan leaders during...,[horse trade],[338],[349]
2,56cc27346d243a140015eebb,During what years did the Mongol leader Kublai...,Sino-Tibetan_relations_during_the_Ming_dynasty,300,Some scholars note that Tibetan leaders during...,[1402–1424],[739],[748]
3,56cc27346d243a140015eebc,Who did the Yongle Emperor try to build a reli...,Sino-Tibetan_relations_during_the_Ming_dynasty,300,Some scholars note that Tibetan leaders during...,[Deshin Shekpa],[821],[834]
4,56cc27346d243a140015eebd,Deshin Shekpa was the head of what school?,Sino-Tibetan_relations_during_the_Ming_dynasty,300,Some scholars note that Tibetan leaders during...,[the Karma Kagyu school],[863],[885]
...,...,...,...,...,...,...,...,...
69864,573636bf9c79961900ff7e06,What Botswana was resently forced to do?,Hunting,18832,"In contrast, Botswana has recently been forced...",[ban trophy hunting],[50],[68]
69865,573636bf9c79961900ff7e07,What animal declined across Botswana?,Hunting,18832,"In contrast, Botswana has recently been forced...",[antelope],[126],[134]
69866,573636bf9c79961900ff7e08,What animal numbers have increased in Botswana?,Hunting,18832,"In contrast, Botswana has recently been forced...",[hippopotamus],[251],[263]
69867,573636bf9c79961900ff7e09,What animal numbers remain stable in Botswana?,Hunting,18832,"In contrast, Botswana has recently been forced...",[elephant],[214],[222]


In [342]:
# Processed validation frame, same layout as train_df.
standard_dm.val_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,56be85543aeaaa14008c9063,When did Beyonce start becoming popular?,Beyoncé,55,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,[in the late 1990s],[269],[286]
1,56be85543aeaaa14008c9065,What areas did Beyonce compete in when she was...,Beyoncé,55,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,[singing and dancing],[207],[226]
2,56be85543aeaaa14008c9066,When did Beyonce leave Destiny's Child and bec...,Beyoncé,55,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,[2003],[526],[530]
3,56be86cf3aeaaa14008c9076,"After her second solo album, what other entert...",Beyoncé,56,Following the disbandment of Destiny's Child i...,[acting],[207],[213]
4,56be86cf3aeaaa14008c9078,Which artist did Beyonce marry?,Beyoncé,56,Following the disbandment of Destiny's Child i...,[Jay Z],[369],[374]
...,...,...,...,...,...,...,...,...
17282,573445bbacc1501500babd6d,Why is cycling popular in Tucson?,"Tucson,_Arizona",16649,Cycling is popular in Tucson due to its flat t...,[its flat terrain and dry climate],[36],[68]
17283,573445bbacc1501500babd6e,What is The Loop?,"Tucson,_Arizona",16649,Cycling is popular in Tucson due to its flat t...,[a network of seven linear parks],[254],[285]
17284,573445bbacc1501500babd6f,How many miles of trails are in The Loop?,"Tucson,_Arizona",16649,Cycling is popular in Tucson due to its flat t...,[over 100],[297],[305]
17285,573445bbacc1501500babd70,What organization advises the Tucson governmen...,"Tucson,_Arizona",16649,Cycling is popular in Tucson due to its flat t...,[Tucson-Pima County Bicycle Advisory Committee],[429],[474]


In [343]:
# Processed test frame: questions with several reference answers keep all of
# them in the answer/answer_start/answer_end lists.
standard_dm.test_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,56be4db0acb8001400a502ec,Which NFL team represented the AFC at Super Bo...,Super_Bowl_50,0,Super Bowl 50 was an American football game to...,[Denver Broncos],[177],[191]
1,56be4db0acb8001400a502ed,Which NFL team represented the NFC at Super Bo...,Super_Bowl_50,0,Super Bowl 50 was an American football game to...,[Carolina Panthers],[249],[266]
2,56be4db0acb8001400a502ee,Where did Super Bowl 50 take place?,Super_Bowl_50,0,Super Bowl 50 was an American football game to...,"[Santa Clara, California, Levi's Stadium, Levi...","[403, 355, 355]","[426, 369, 427]"
3,56be4db0acb8001400a502ef,Which NFL team won Super Bowl 50?,Super_Bowl_50,0,Super Bowl 50 was an American football game to...,[Denver Broncos],[177],[191]
4,56be4db0acb8001400a502f0,What color was used to emphasize the 50th anni...,Super_Bowl_50,0,Super Bowl 50 was an American football game to...,[gold],[521],[525]
...,...,...,...,...,...,...,...,...
10525,5737aafd1c456719005744fb,What is the metric term less used than the New...,Force,2066,"The pound-force has a metric counterpart, less...","[kilogram-force, pound-force, kilogram-force (...","[82, 4, 82, 78]","[96, 15, 102, 98]"
10526,5737aafd1c456719005744fc,What is the kilogram-force sometimes reffered ...,Force,2066,"The pound-force has a metric counterpart, less...",[kilopond],[114],[122]
10527,5737aafd1c456719005744fd,What is a very seldom used unit of mass in the...,Force,2066,"The pound-force has a metric counterpart, less...","[slug, metric slug, the metric slug]","[274, 267, 263]","[278, 278, 278]"
10528,5737aafd1c456719005744fe,What seldom used term of a unit of force equal...,Force,2066,"The pound-force has a metric counterpart, less...",[kip],[712],[715]


### Metrics and logs

## Baseline model

In [344]:
# Baseline QA model on top of the shared GloVe embedding layer; it receives
# MAX_CONTEXT_TOKENS, the fixed length the context tokenizer pads to.
baseline_model = model.QABaselineModel(glove_embedding_layer, MAX_CONTEXT_TOKENS)
# NOTE(review): the reported count is far below the ~60M entries of the GloVe
# matrix, so count_parameters presumably counts only trainable parameters.
baseline_model.count_parameters()

101400

In [345]:
# Training configuration for the baseline model.
# Checkpoints and TensorBoard logs go to a model-specific sub-directory: with
# overwrite_output_dir=True and one directory shared by every model, each
# run would overwrite the previous model's checkpoints and mix their logs.
baseline_args = transformers.TrainingArguments(
    output_dir="./checkpoints/baseline",
    logging_dir="./runs/baseline",
    logging_first_step=True,
    logging_steps=5,
    overwrite_output_dir=True,
    evaluation_strategy="epoch",  # evaluate on eval_dataset after every epoch
    learning_rate=1e-3,
    num_train_epochs=10,
    # Keep every dataset column; presumably the collator (a SQuAD tokenizer)
    # needs raw fields that are not model-forward arguments.
    remove_unused_columns=False,
    per_device_train_batch_size=256,
    per_device_eval_batch_size=64,
    # Input keys the trainer treats as labels (answer span boundaries).
    label_names=["answer_start", "answer_end"],
)

In [346]:
# Trainer for the baseline. The standard SQuAD tokenizer is the data
# collator, so the Trainer tokenizes each batch of examples on the fly.
baseline_trainer = training.SquadTrainer(
    model=baseline_model,
    args=baseline_args,
    data_collator=standard_dm.tokenizer,
    train_dataset=standard_dm.train_dataset,
    eval_dataset=standard_dm.val_dataset,
    compute_metrics=training.compute_metrics,
)

In [85]:
# Start a W&B run for the baseline using the shared project settings.
# NOTE(review): `utils` is not imported in the imports cell at the top of this
# notebook, so this raises NameError on a fresh kernel; add `import utils`
# (presumably a local module next to dataset/model/training).
baseline_run_name = utils.get_run_name()
baseline_wandb_logger = init_wandb(
    name=baseline_run_name, group="baseline", reinit=True,
)

In [347]:
# Train the baseline, evaluating after every epoch.
# NOTE(review): the saved output shows this run failing with KeyError: 'ids'.
baseline_trainer.train()

./checkpoints


KeyError: 'ids'

In [87]:
# Close the baseline W&B run (uploads its summary and history).
baseline_wandb_logger.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train/loss,11.31863
train/learning_rate,0.001
train/epoch,0.03663
train/f1,0.0
train/accuracy,0.0
train/em,0.0
_step,10.0
_runtime,118.0
_timestamp,1610467892.0


0,1
train/loss,█▆▁
train/learning_rate,█▅▁
train/epoch,▁▄█
train/f1,█▇▁
train/accuracy,█▃▁
train/em,▁▁▁
_step,▁▄█
_runtime,▁▅█
_timestamp,▁▅█


## BiDAF

In [354]:
# BiDAF model on top of the shared, frozen GloVe embedding layer.
bidaf_model = model.BiDAFModel(glove_embedding_layer)
bidaf_model.count_parameters()

314350

In [355]:
# Training configuration for BiDAF.
# Checkpoints and TensorBoard logs go to a model-specific sub-directory: with
# overwrite_output_dir=True and one directory shared by every model, each
# run would overwrite the previous model's checkpoints and mix their logs.
bidaf_args = transformers.TrainingArguments(
    output_dir="./checkpoints/bidaf",
    logging_dir="./runs/bidaf",
    logging_first_step=True,
    logging_steps=5,
    overwrite_output_dir=True,
    evaluation_strategy="epoch",  # evaluate on eval_dataset after every epoch
    # NOTE(review): bidaf_trainer is given its own Adadelta optimizer and
    # scheduler, so the Trainer does not build its default optimizer from
    # this value; it is presumably unused. Confirm in SquadTrainer.
    learning_rate=1e-3,
    num_train_epochs=10,
    # Keep every dataset column; presumably the collator (a SQuAD tokenizer)
    # needs raw fields that are not model-forward arguments.
    remove_unused_columns=False,
    per_device_train_batch_size=256,
    per_device_eval_batch_size=64,
    # Input keys the trainer treats as labels (answer span boundaries).
    label_names=["answer_start", "answer_end"],
)

In [356]:
# Adadelta with lr=0.5 (presumably following the BiDAF paper's setup), plus an
# exponential learning-rate decay: lr *= 0.999 on every scheduler step.
# NOTE(review): the HF Trainer steps the scheduler after every optimizer step,
# not every epoch, so the decay is applied per batch. Confirm that is intended
# (BiDAF's 0.999 was an EMA decay on the weights, not an LR decay).
bidaf_optimizer = optim.Adadelta(bidaf_model.parameters(), lr=0.5)
bidaf_lr_scheduler = optim.lr_scheduler.ExponentialLR(bidaf_optimizer, gamma=.999)

In [357]:
# BiDAF trainer: same data, collator and metrics as the baseline, but with the
# custom Adadelta optimizer and LR scheduler instead of the Trainer defaults.
bidaf_trainer = training.SquadTrainer(
    model=bidaf_model,
    args=bidaf_args,
    data_collator=standard_dm.tokenizer,
    train_dataset=standard_dm.train_dataset,
    eval_dataset=standard_dm.val_dataset,
    optimizers=(bidaf_optimizer, bidaf_lr_scheduler),
    compute_metrics=training.compute_metrics,
)

In [None]:
# Start a W&B run for BiDAF using the shared project settings.
# NOTE(review): `utils` is not imported in the imports cell at the top of this
# notebook, so this raises NameError on a fresh kernel; add `import utils`.
bidaf_run_name = utils.get_run_name()
bidaf_wandb_logger = init_wandb(name=bidaf_run_name, group="bidaf", reinit=True,)

In [358]:
# Train BiDAF, evaluating after every epoch.
# NOTE(review): the saved output contains leftover debug prints (tensor shapes
# and loss values), and the run was interrupted manually.
bidaf_trainer.train()

torch.Size([256, 300, 1]) torch.Size([256, 300, 1])
tensor(15.0205, grad_fn=<NllLossBackward>) tensor(13.8060, grad_fn=<NllLossBackward>)


KeyboardInterrupt: 

In [None]:
# Close the BiDAF W&B run.
bidaf_wandb_logger.finish()

## BERT

In [None]:
# WordPiece tokenizer from the bert-base-uncased vocabulary file, lowercasing
# to match the uncased checkpoint loaded further down.
bert_tokenizer = BertWordPieceTokenizer("data/bert-base-uncased-vocab.txt", lowercase=True)
# Right-pad batches with [PAD]; pad_id keeps the library default.
# NOTE(review): BERT conventionally pads token_type_ids with 0, whereas
# pad_type_id=1 labels padding as the second segment. This is harmless only if
# padded positions are masked out downstream; confirm.
bert_tokenizer.enable_padding(
    direction="right",
    pad_type_id=1,
    pad_token=PAD_TOKEN,
)

In [None]:
# Wrap the raw WordPiece tokenizer in the project's SQuAD-specific tokenizer.
# NOTE(review): this rebinds `bert_tokenizer` to the wrapper, so re-running
# the cell wraps the wrapper again. Give the raw tokenizer its own name to
# make the cell idempotent.
bert_tokenizer = dataset.BertSquadTokenizer(bert_tokenizer)

In [None]:
# Data manager for BERT: same raw splits, tokenized with the WordPiece-based
# SQuAD tokenizer.
bert_dm = dataset.SquadDataManager(squad_dataset, bert_tokenizer)

In [216]:
# Extractive QA needs a span-prediction head. BertForQuestionAnswering adds a
# linear layer producing start/end logits per token, and returns a loss when
# `start_positions`/`end_positions` are passed. The bare BertModel only returns
# hidden states, consistent with the earlier KeyError: 'loss'.
# NOTE(review): the collated batch shown in the training output carries only
# input_ids/token_type_ids/attention_mask. BertSquadTokenizer (or
# SquadTrainer) must also supply the answer positions before a loss can be
# computed.
bert_model = transformers.BertForQuestionAnswering.from_pretrained("bert-base-uncased")

In [58]:
# Training configuration for BERT fine-tuning.
# Checkpoints and TensorBoard logs go to a model-specific sub-directory: with
# overwrite_output_dir=True and one directory shared by every model, each
# run would overwrite the previous model's checkpoints and mix their logs.
bert_args = transformers.TrainingArguments(
    output_dir="./checkpoints/bert",
    logging_dir="./runs/bert",
    logging_first_step=True,
    logging_steps=5,
    overwrite_output_dir=True,
    evaluation_strategy="epoch",  # evaluate on eval_dataset after every epoch
    # 1e-3 is far above the 2e-5..5e-5 range recommended for BERT fine-tuning
    # (Devlin et al., 2019) and tends to wreck the pre-trained weights.
    learning_rate=3e-5,
    num_train_epochs=10,
    # Keep every dataset column; presumably the collator (a SQuAD tokenizer)
    # needs raw fields that are not model-forward arguments.
    remove_unused_columns=False,
    # NOTE(review): 256 sequences per device is likely too large for bert-base
    # in GPU memory; BERT's SQuAD recipe uses batches of about 32. Confirm for
    # the target hardware.
    per_device_train_batch_size=256,
    per_device_eval_batch_size=64,
    # Input keys the trainer treats as labels (answer span boundaries).
    label_names=["answer_start", "answer_end"],
)

In [None]:
# BERT trainer. No optimizers are passed, so the Trainer builds its default
# optimizer/scheduler from bert_args. The BERT SQuAD tokenizer is the collator.
bert_trainer = training.SquadTrainer(
    model=bert_model,
    args=bert_args,
    data_collator=bert_dm.tokenizer,
    train_dataset=bert_dm.train_dataset,
    eval_dataset=bert_dm.val_dataset,
    compute_metrics=training.compute_metrics,
)

In [None]:
# Start a W&B run for BERT using the shared project settings.
# NOTE(review): `utils` is not imported in the imports cell at the top of this
# notebook, so this raises NameError on a fresh kernel; add `import utils`.
bert_run_name = utils.get_run_name()
bert_wandb_logger = init_wandb(name=bert_run_name, group="bert", reinit=True,)

In [288]:
# Fine-tune BERT, evaluating after every epoch.
# NOTE(review): the saved output shows the collated batch carrying only
# input_ids/token_type_ids/attention_mask and the run failing with
# KeyError: 'loss'.
bert_trainer.train()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])


KeyError: 'loss'

In [None]:
# Close the BERT W&B run.
bert_wandb_logger.finish()