# Question answering on the SQuAD dataset

## Colab requirements

Before restarting the runtime (remember to select a GPU runtime)$\dots$

In [None]:
!git clone https://github.com/Wadaboa/squad-question-answering.git
!pip install -r squad-question-answering/init/base_requirements.txt

After restarting the runtime$\dots$

In [None]:
import sys
import os

# Make the cloned repository importable and use it as the working directory,
# so the relative "data", "checkpoints" and "results" paths used later resolve.
COLAB_REPO_DIR = "/content/squad-question-answering"
sys.path.insert(0, COLAB_REPO_DIR)
os.chdir(COLAB_REPO_DIR)

## Imports

In [1]:
import os
from functools import partial

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
import transformers
from transformers.trainer_utils import set_seed

import dataset
import model
import training
import tokenizer
import utils

# Automatically reload imported modules (e.g. the local dataset, model,
# training, tokenizer and utils modules) before each cell runs, so edits to
# their source take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Initialization

### Matplotlib

In [2]:
%matplotlib inline
plt.rcParams['figure.figsize'] = [8, 6]
plt.rcParams['figure.dpi'] = 100
plt.rcParams['axes.xmargin'] = .05
plt.rcParams['axes.ymargin'] = .05
plt.style.use('ggplot')

### Weights & biases

In [3]:
# Weights & Biases configuration, read from the environment by wandb and by
# the Trainer's W&B integration. Keep comments on their own lines: anything
# after `=` on a `%env` line becomes part of the variable's value.
# Destination project and entity (user/team) for the logged runs.
%env WANDB_PROJECT=squad-qa
%env WANDB_ENTITY=wadaboa
# Sync runs to the W&B servers; always start a new run instead of resuming one.
%env WANDB_MODE=online
%env WANDB_RESUME=never
# Do not watch model gradients/parameters, and silence wandb's console logging.
%env WANDB_WATCH=false
%env WANDB_SILENT=true

env: WANDB_PROJECT=squad-qa
env: WANDB_ENTITY=wadaboa
env: WANDB_MODE=online
env: WANDB_RESUME=never
env: WANDB_WATCH=false
env: WANDB_SILENT=true


In [4]:
# Authenticate with W&B (interactive: prompts for an API key if none is stored).
# Never paste the key into the notebook itself.
!wandb login

In [5]:
# Make sure W&B syncing is on (undoes a previous `wandb disabled`).
!wandb enabled

W&B enabled.


### PyTorch and numpy

In [6]:
# A single seed, applied through transformers' helper, which seeds Python's
# `random`, NumPy and PyTorch together for reproducible runs.
RANDOM_SEED = 42
set_seed(seed=RANDOM_SEED)

In [7]:
# Torch device shared by the tokenizers, data managers and models below.
# The saved output shows `cpu`; presumably utils.get_device falls back to the
# CPU when no GPU is available. On Colab, select a GPU runtime (see the setup
# notes at the top of the notebook).
DEVICE = utils.get_device()
DEVICE

  return torch._C._cuda_getDeviceCount() > 0


device(type='cpu')

## Preliminaries

### Raw data loading

In [132]:
DATA_FOLDER = os.path.join(os.getcwd(), "data")

# Each split lives in its own sub-folder of DATA_FOLDER.
TRAIN_DATA_FOLDER, TEST_DATA_FOLDER = (
    os.path.join(DATA_FOLDER, split_dir) for split_dir in ("training", "testing")
)
TRAIN_SET_PATH = os.path.join(TRAIN_DATA_FOLDER, "training_set.json")
TEST_SET_PATH = os.path.join(TEST_DATA_FOLDER, "test_set.json")

In [133]:
# Portion of the raw data to load. This is presumably a 0-1 fraction; confirm
# in dataset.SquadDataset. The saved outputs below (874 training / 181 test
# rows) were produced with 0.01, which is only suitable as a quick smoke test.
# Raise it for real training runs.
DATASET_SUBSET = 0.01
squad_dataset = dataset.SquadDataset(
    train_set_path=TRAIN_SET_PATH, test_set_path=TEST_SET_PATH, subset=DATASET_SUBSET
)

In [134]:
# Raw training examples: one answer per row, with character offsets
# answer_start/answer_end into the context.
squad_dataset.raw_train_df

Unnamed: 0,answer_start,answer,title,context,question_id,question,context_id,answer_end
0,421,3000,Dissolution_of_the_Soviet_Union,"On the next day, December 18, protests turned ...",5727891b708984140094e032,How many people were estimated by authorities ...,12157,426
1,0,Roman Catholicism,Germans,Roman Catholicism was the sole established rel...,57295cd51d04691400779321,"Until the Reformation, what was the establishe...",9466,17
2,379,March 1969,Gamal_Abdel_Nasser,Israel retaliated against Egyptian shelling wi...,572832e92ca10214002da07f,When did the war start up again?,13342,389
3,0,Zen,Buddhism,Zen Buddhist teaching is often full of paradox...,56d249a4b329da140004ecf1,What Buddhist teachings are often full of para...,1180,3
4,642,Encyclopaedia of Islam,Quran,Sahih al-Bukhari narrates Muhammad describing ...,572eda2cdfa6aa1500f8d453,In which work did Welch express his belief tha...,16097,664
...,...,...,...,...,...,...,...,...
870,162,1954,Department_store,After World War II Hudson's realized that the ...,57267be6708984140094c79b,In what year did Hudson's address the issue of...,9620,166
871,10,Yellow Emperor,Humanism,"In China, Yellow Emperor is regarded as the hu...",57327bd90fdd8d15006c6b01,Who was known as being a founder of humanism t...,18657,24
872,204,Economic Partnership Agreement,Tuvalu,In July 2013 Tuvalu signed the Memorandum of U...,5730e32fb7151e1900c015b0,What does the trade agreement encompass?,17534,234
873,627,Neal Purvis and Robert Wade,Spectre_(2015_film),"Despite being an original story, Spectre draws...",56cdd97f62d2951400fa68f2,Who were the writers of Spectre?,474,654


In [135]:
# Raw test examples, with the same columns as the raw training set.
squad_dataset.raw_test_df

Unnamed: 0,answer_start,answer,title,context,question_id,question,context_id,answer_end
0,243,"over $2,000,",Nikola_Tesla,"Near the end of his life, Tesla walked to the ...",56e11ba9cd28a01900c675d8,How much did Tesla spend on the injured pigeon?,220,255
1,375,protective radiation shield,Oxygen,Trioxygen (O\n3) is usually known as ozone and...,571c9074dd7acb1400e4c103,What function does ozone perform for the planet?,634,402
2,239,a stroke,Nikola_Tesla,"On 24 March 1879, Tesla was returned to Gospić...",56e0cd33231d4119001ac3c2,What was one of theories as to what caused Tes...,159,247
3,290,Chartered,Teacher,"Teaching may be carried out informally, within...",56e74bf937bdd419002c3e35,What is another type of accountant other than ...,292,299
4,1213,the contemporary Orient,Imperialism,Some have described the internal strife betwee...,5730bb522461fd1900a9d015,Who does Edward Said say is being attacked by ...,1922,1236
...,...,...,...,...,...,...,...,...
177,207,immunization,Immune_system,Long-term active memory is acquired following ...,5729ffda1d046914007796b0,What is the process of vaccination also known as?,1257,219
178,309,chlorophyll b,Chloroplast,"The chloroplastidan chloroplasts, or green chl...",57295b5b1d04691400779319,What do green chloroplasts have instead of phy...,1678,322
179,359,the Miller–Rabin primality test,Prime_number,The property of being prime (or not) is called...,57296f293f37b319004783a5,What is the name of one algorithm useful for c...,1734,390
180,624,lack of understanding,Civil_disobedience,Many of the same decisions and principles that...,5728e8212ca10214002daa6f,What reasons cause failure of the disobedience...,1290,645


### Embeddings

In [26]:
# Special tokens shared by the embedding vocabulary and the standard tokenizer.
UNK_TOKEN, PAD_TOKEN = "[UNK]", "[PAD]"

- FastText: 
    - _fasttext-wiki-news-subwords_ (dimensions: 300)
- GloVe:
    - _glove-twitter_ (dimensions: 25, 50, 100, 200)
    - _glove-wiki-gigaword_ (dimensions: 50, 100, 200, 300)
- Word2Vec:
    - _word2vec-google-news_ (dimensions: 300)
    - _word2vec-ruscorpora_ (dimensions: 300)

In [27]:
# Pretrained vectors come from gensim-data, see
# https://github.com/RaRe-Technologies/gensim-data
# The model name and dimension are presumably combined into a gensim-data id
# such as "glove-twitter-25" inside utils.load_embedding_model. Confirm there.
GLOVE_MODEL_NAME = "glove-twitter"
GLOVE_EMBEDDING_DIMENSION = 25
glove_embedding_model, glove_vocab = utils.load_embedding_model(
    GLOVE_MODEL_NAME,
    pad_token=PAD_TOKEN,
    unk_token=UNK_TOKEN,
    embedding_dimension=GLOVE_EMBEDDING_DIMENSION,
)

In [28]:
# Torch embedding module built from the GloVe vectors. The [PAD] id is passed
# along, presumably so the module treats that index as padding.
glove_pad_id = glove_vocab[PAD_TOKEN]
glove_embedding_layer = model.get_embedding_module(glove_embedding_model, pad_id=glove_pad_id)

### Standard tokenizer and preprocessing

In [136]:
# Context length handed to the standard tokenizer and to the baseline model.
# Presumably contexts are truncated/padded to this many tokens; confirm in
# tokenizer.get_standard_tokenizer.
MAX_CONTEXT_TOKENS = 300

In [137]:
# Tokenizer for the GloVe-based models (baseline and BiDAF). It presumably maps
# text to glove_vocab ids and uses the same special tokens as that vocabulary.
standard_tokenizer = tokenizer.get_standard_tokenizer(
    glove_vocab,
    MAX_CONTEXT_TOKENS,
    device=DEVICE,
    pad_token=PAD_TOKEN,
    unk_token=UNK_TOKEN,
)

In [138]:
# Train/validation/test splits of squad_dataset, prepared for the standard tokenizer.
standard_dm = dataset.SquadDataManager(
    squad_dataset,
    standard_tokenizer,
    device=DEVICE,
)

In [139]:
# Training split after processing. Note that answer/answer_start/answer_end
# are now lists per question (bracketed values in the output).
standard_dm.train_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,56cd8b5862d2951400fa66a4,When is the HD version of Twilight Princess sl...,The_Legend_of_Zelda:_Twilight_Princess,433,"At the time of its release, Twilight Princess ...",[March 2016],[530],[540]
1,56cdd97f62d2951400fa68f2,Who were the writers of Spectre?,Spectre_(2015_film),474,"Despite being an original story, Spectre draws...",[Neal Purvis and Robert Wade],[627],[654]
2,56ce42afaab44d1400b88620,Who was given the grandiose title?,Sino-Tibetan_relations_during_the_Ming_dynasty,362,"Sonam Gyatso, after being granted the grandios...",[Sonam Gyatso],[0],[12]
3,56ce9211aab44d1400b88895,What is an example of a passive solar technique?,Solar_energy,833,It is an important source of renewable energy ...,[orienting a building to the Sun],[397],[428]
4,56ce9464aab44d1400b8889a,What is solar energy's yearly potential?,Solar_energy,834,The large magnitude of solar energy available ...,"[1,575–49,837 exajoules (EJ)]",[226],[253]
...,...,...,...,...,...,...,...,...
694,57359fece853931400426a84,How many cities were present in the Kathmandu ...,Kathmandu,18842,"During the later part of the Malla era, Kathma...",[four],[67],[71]
695,5735c92f012e2f140011a046,What are done with the bodies of Kirants after...,Kathmandu,18882,The Bagmati River which flows through Kathmand...,[buried],[275],[281]
696,5735d1a86c16ec1900b92834,From what city does Arkefly offer nonstop flig...,Kathmandu,18889,The main international airport serving Kathman...,[Amsterdam],[698],[707]
697,57361c88012e2f140011a1a7,Who are federal excise taxes are distributed to?,Hunting,18818,"Each year, nearly $200 million in hunters' fed...",[state agencies],[83],[97]


In [140]:
# Validation split. Presumably carved out of the raw training data, since only
# training and test files are loaded; confirm in dataset.SquadDataManager.
standard_dm.val_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,56be8e353aeaaa14008c90c7,"How many weeks did their single ""Independent W...",Beyoncé,63,"The remaining band members recorded ""Independe...",[eleven],[216],[222]
1,56beb2a43aeaaa14008c9239,Beyonce along with Jay Z met with whom's famil...,Beyoncé,87,"Following the death of Freddie Gray, Beyoncé a...",[Freddie Gray],[23],[35]
2,56bec3303aeaaa14008c9391,What characteristics has Beyonce received accl...,Beyoncé,97,Beyoncé has received praise for her stage pres...,[stage presence and voice],[36],[60]
3,56bf940da10cfb140055118b,How high did ''Deja Vu'' climb on the Billboar...,Beyoncé,67,Beyoncé's second solo album B'Day was released...,[top five],[342],[350]
4,56bfc281a10cfb14005512b6,Who chose her as number one on his list of Bes...,Beyoncé,97,Beyoncé has received praise for her stage pres...,[Jarett Wieselman],[87],[103]
...,...,...,...,...,...,...,...,...
169,5732b3a5328d981900602019,What does CRA stand for?,Financial_crisis_of_2007%E2%80%9308,1646,A 2000 United States Department of the Treasur...,[Community Reinvestment Act],[162],[188]
170,573334094776f41900660786,"Per Bernanke, how much did the U.S. current ac...",Financial_crisis_of_2007%E2%80%9308,1596,"Bernanke explained that between 1996 and 2004,...",[$650 billion],[93],[105]
171,57335c77d058e614000b5909,Who founded the Atlanta-based Intercontinental...,Financial_crisis_of_2007%E2%80%9308,1619,In testimony before the Senate Committee on Co...,"[Goldman Sachs, Morgan Stanley and BP]",[281],[317]
172,5733703c4776f41900660ad9,What do money market funds frequently invest in?,Financial_crisis_of_2007%E2%80%9308,1578,"In September 2008, the crisis hit its most cri...",[commercial paper issued by corporations],[152],[191]


In [141]:
# Test split, built from the examples loaded from TEST_SET_PATH.
standard_dm.test_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,56bec8243aeaaa14008c942f,What Super Bowl was the last where a fumble wa...,Super_Bowl_50,46,"After each team punted, Panthers quarterback C...",[XXVIII],[635],[641]
1,56bec9133aeaaa14008c9445,Who tackled Mike Tolbert and caused a fumble?,Super_Bowl_50,48,On Carolina's next possession fullback Mike To...,[Stewart],[103],[110]
2,56bf49993aeaaa14008c95b8,Where did the Broncos practice at for Super Bo...,Super_Bowl_50,23,The Panthers used the San Jose State practice ...,[Stanford],[117],[125]
3,56d602631c85041400946eda,Who headlined the halftime show for Super Bowl...,Super_Bowl_50,3,"CBS broadcast Super Bowl 50 in the U.S., and c...",[Coldplay],[194],[202]
4,56d6ee6e0d65d21400198257,What was the third city that was considered?,Super_Bowl_50,5,The league eventually narrowed the bids to thr...,[San Francisco Bay Area's],[128],[152]
...,...,...,...,...,...,...,...,...
177,57373d0cc3c5551400e51e86,What was the concept of force an integral part...,Force,2025,Aristotle provided a philosophical discussion ...,[Aristotelian cosmology],[95],[117]
178,57376a1bc3c5551400e51ec4,What do forces have with regard to additive qu...,Force,2035,"Historically, forces were first quantitatively...",[magnitude and direction],[248],[271]
179,57376a1bc3c5551400e51ec6,"When forces are acting on an extended body, wh...",Force,2035,"Historically, forces were first quantitatively...",[respective lines of application],[863],[894]
180,57377083c3c5551400e51ee2,Objects of constant density are proportional t...,Force,2038,A static equilibrium between two forces is the...,[force of gravity],[430],[446]


## Utils

In [142]:
# Factory for Trainer arguments. It is called below with per-experiment
# overrides (output_dir, run_name, epochs, batch sizes). It is presumably a
# partial of transformers.TrainingArguments carrying shared defaults; confirm
# in utils.get_default_trainer_args.
TRAINER_ARGS = utils.get_default_trainer_args()

## Baseline model

In [171]:
# Baseline QA model on top of the GloVe embedding layer, sized for MAX_CONTEXT_TOKENS.
baseline_model = model.QABaselineModel(
    glove_embedding_layer, MAX_CONTEXT_TOKENS, device=DEVICE
)
baseline_param_count = baseline_model.count_parameters()
print(f"The baseline model has {baseline_param_count} parameters")

The baseline model has 41000 parameters


In [172]:
# Adam with a constant learning rate (no warmup or decay).
baseline_optimizer = optim.Adam(params=baseline_model.parameters(), lr=1e-3)
baseline_lr_scheduler = transformers.get_constant_schedule(optimizer=baseline_optimizer)

In [173]:
%env WANDB_RUN_GROUP=baseline
baseline_run_name = utils.get_run_name()
baseline_args = TRAINER_ARGS(
    output_dir=f"./checkpoints/{os.getenv('WANDB_RUN_GROUP')}/{baseline_run_name}",
    run_name=baseline_run_name,
    num_train_epochs=30,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
)
baseline_trainer = training.SquadTrainer(
    model=baseline_model,
    args=baseline_args,
    data_collator=standard_dm.tokenizer,
    train_dataset=standard_dm.train_dataset,
    eval_dataset=standard_dm.val_dataset,
    optimizers=(baseline_optimizer, baseline_lr_scheduler),
)

env: WANDB_RUN_GROUP=baseline


In [None]:
# Train the baseline, validating on standard_dm.val_dataset.
# NOTE(review): this cell has no execution count in the saved notebook, so the
# test metrics reported below may come from an untrained model. Re-run to confirm.
baseline_trainer.train()

In [39]:
# Evaluate the baseline on the test split. The last expression displays the
# metrics dict (loss, F1, accuracy, EM, runtime).
baseline_test_output = baseline_trainer.predict(standard_dm.test_dataset)
baseline_test_output.metrics

{'test_loss': 5.292969703674316,
 'test_f1': 0.0,
 'test_accuracy': 0.0,
 'test_em': 0.0,
 'test_runtime': 0.0829,
 'test_samples_per_second': 12.058}

In [40]:
# Write the predicted answers to disk, upload the file to the active W&B run,
# then close the run. The answers are presumably the last element of the
# prediction output; confirm in training.SquadTrainer.
baseline_answers_path = "results/answers/baseline.json"
baseline_predicted_answers = baseline_test_output.predictions[-1]
utils.save_answers(baseline_answers_path, baseline_predicted_answers)
wandb.save(baseline_answers_path);
wandb.finish()

## BiDAF

In [41]:
# BiDAF model built on the same GloVe embedding layer as the baseline.
# NOTE(review): glove_embedding_layer is the same module instance used by the
# baseline. That is only safe if the embeddings are frozen; the small parameter
# counts suggest they are. Confirm in model.get_embedding_module.
bidaf_model = model.QABiDAFModel(glove_embedding_layer, device=DEVICE)
bidaf_param_count = bidaf_model.count_parameters()
print(f"The BiDAF model has {bidaf_param_count} parameters")

The BiDAF model has 79675 parameters


In [42]:
# AdaDelta with an initial learning rate of 0.5, as in the original BiDAF paper
# (Seo et al., 2017), kept constant by the scheduler.
bidaf_optimizer = optim.Adadelta(params=bidaf_model.parameters(), lr=0.5)
bidaf_lr_scheduler = transformers.get_constant_schedule(optimizer=bidaf_optimizer)

In [43]:
%env WANDB_RUN_GROUP=bidaf
bidaf_run_name = utils.get_run_name()
bidaf_args = TRAINER_ARGS(
    output_dir=f"./checkpoints/{os.getenv('WANDB_RUN_GROUP')}/{bidaf_run_name}",
    run_name=bidaf_run_name,
    num_train_epochs=12,
    per_device_train_batch_size=60,
    per_device_eval_batch_size=60,
)
bidaf_trainer = training.SquadTrainer(
    model=bidaf_model,
    args=bidaf_args,
    data_collator=standard_dm.tokenizer,
    train_dataset=standard_dm.train_dataset,
    eval_dataset=standard_dm.val_dataset,
    optimizers=(bidaf_optimizer, bidaf_lr_scheduler),
)

env: WANDB_RUN_GROUP=bidaf


In [44]:
# Train BiDAF, validating on standard_dm.val_dataset. In the saved run
# (subset=0.01) every validation F1/EM is zero, so treat it as a smoke test
# rather than a result.
bidaf_trainer.train()

Epoch,Training Loss,Validation Loss,F1,Accuracy,Em,Runtime,Samples Per Second
1,6.6194,5.197407,0.0,0.0,0.0,0.3057,6.542
2,6.6194,5.198681,0.0,0.0,0.0,0.3101,6.45
3,6.6194,6.93373,0.0,0.0,0.0,0.3094,6.464
4,6.6194,6.934238,0.0,0.0,0.0,0.3092,6.469
5,5.9164,6.934334,0.0,0.0,0.0,0.3059,6.538
6,5.9164,6.93491,0.0,0.0,0.0,0.3039,6.581
7,5.9164,6.9356,0.0,0.0,0.0,0.3126,6.397
8,5.9164,6.935662,0.0,0.0,0.0,0.3026,6.61
9,5.9164,6.935297,0.0,0.0,0.0,0.3155,6.339
10,6.6018,6.632412,0.0,0.0,0.0,0.3157,6.335


TrainOutput(global_step=12, training_loss=6.486756801605225, metrics={'train_runtime': 26.8579, 'train_samples_per_second': 0.447, 'total_flos': 0, 'epoch': 12.0})

In [45]:
# Evaluate BiDAF on the test split and display the metrics dict.
bidaf_test_output = bidaf_trainer.predict(standard_dm.test_dataset)
bidaf_test_output.metrics

{'test_loss': 8.894977569580078,
 'test_f1': 0.0,
 'test_accuracy': 0.0,
 'test_em': 0.0,
 'test_runtime': 0.2739,
 'test_samples_per_second': 3.651}

In [46]:
# Write the predicted answers to disk, upload the file to the active W&B run,
# then close the run. The answers are presumably the last element of the
# prediction output; confirm in training.SquadTrainer.
bidaf_answers_path = "results/answers/bidaf.json"
bidaf_predicted_answers = bidaf_test_output.predictions[-1]
utils.save_answers(bidaf_answers_path, bidaf_predicted_answers)
wandb.save(bidaf_answers_path);
wandb.finish()

## BERT

In [14]:
# Maximum BERT input length; 512 matches the maximum sequence length of the
# standard BERT checkpoints.
# NOTE(review): this constant is not referenced in the cells shown.
# get_bert_tokenizer is called without it, so confirm whether the tokenizer
# (or data manager) should receive it.
MAX_BERT_TOKENS = 512

In [15]:
# Tokenizer for the BERT pipeline. Unlike the standard tokenizer, it is not
# given a vocabulary or a maximum context length here.
bert_tokenizer = tokenizer.get_bert_tokenizer(device=DEVICE)

In [16]:
# Train/validation/test splits of squad_dataset, prepared for the BERT tokenizer.
bert_dm = dataset.SquadDataManager(
    squad_dataset,
    bert_tokenizer,
    device=DEVICE,
)

In [17]:
# Training split for BERT.
# NOTE(review): the saved output (6 rows) comes from a much earlier execution
# (In [17]) than the standard pipeline's outputs (In [139]). It may reflect a
# different squad_dataset, so re-run the notebook top to bottom before
# comparing the two pipelines.
bert_dm.train_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,56cdd28562d2951400fa68bd,Who does M fight with?,Spectre_(2015_film),470,Bond and Swann return to London where they mee...,[C],[105],[106]
1,56de4d9ecffd8e1900b4b7e2,What year was the Banská Akadémia founded?,Institute_of_technology,1860,The world's first institution of technology or...,[1735],[167],[171]
2,572674a05951b619008f7319,What is another speed that can also be reporte...,Film_speed,9354,The standard specifies how speed ratings shoul...,[SOS-based speed],[793],[808]
3,5726ef98708984140094d66e,What conferences became a requirement after Va...,Pope_Paul_VI,10862,Some critiqued Paul VI's decision; the newly c...,[National Bishop Conferences],[347],[374]
4,572843ce4b864d190016485c,What was the purpose of top secret ICBM commit...,John_von_Neumann,11497,"Shortly before his death, when he was already ...",[decide on the feasibility of building an ICBM...,[194],[284]
5,5730bb058ab72b1400f9c72c,Where were the use of advanced materials and t...,Sumer,17505,The most impressive and famous of Sumerian bui...,[Sumerian temples and palaces],[421],[449]


In [18]:
# Validation split for BERT.
bert_dm.val_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,570e1a2a0dc6ce1900204dbf,How many species of fungi have been found on A...,Antarctica,6902,About 1150 species of fungi have been recorded...,[1150],[6],[10]
1,572781a5f1498d1400e8fa1f,Who is elected every even numbered year?,"Ann_Arbor,_Michigan",10585,Ann Arbor has a council-manager form of govern...,[mayor],[192],[197]


In [19]:
# Test split for BERT.
bert_dm.test_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,57379ed81c456719005744d7,In what way do idea strings transmit tesion fo...,Force,2058,Tension forces can be modeled using ideal stri...,[instantaneously in action-reaction pairs],[250],[290]


In [20]:
# BERT-based QA model. Its checkpoint/configuration are chosen inside
# model.QABertModel and are not visible here.
# NOTE(review): unlike the baseline and BiDAF cells, the parameter count is not printed.
bert_model = model.QABertModel(device=DEVICE)

In [21]:
# Adam with a small, constant learning rate for fine-tuning (no warmup or decay).
bert_optimizer = optim.Adam(params=bert_model.parameters(), lr=5e-5)
bert_lr_scheduler = transformers.get_constant_schedule(optimizer=bert_optimizer)

In [22]:
%env WANDB_RUN_GROUP=bert
bert_run_name = utils.get_run_name()
bert_args = TRAINER_ARGS(
    output_dir=f"./checkpoints/{os.getenv('WANDB_RUN_GROUP')}/{bert_run_name}",
    run_name=bert_run_name,
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
)
bert_trainer = training.SquadTrainer(
    model=bert_model,
    args=bert_args,
    data_collator=bert_dm.tokenizer,
    train_dataset=bert_dm.train_dataset,
    eval_dataset=bert_dm.val_dataset,
    optimizers=(bert_optimizer, bert_lr_scheduler),
)

env: WANDB_RUN_GROUP=bert


In [23]:
# Fine-tune BERT, validating on bert_dm.val_dataset. The saved run used only a
# handful of examples (see the splits above), so its metrics are a smoke test.
bert_trainer.train()

Epoch,Training Loss,Validation Loss,F1,Accuracy,Em,Runtime,Samples Per Second
1,6.5272,6.653205,0.166667,0.25,0.0,1.0424,1.919
2,6.5272,7.183872,0.166667,0.25,0.0,1.0671,1.874
3,6.5272,7.202362,0.166667,0.25,0.0,0.9527,2.099


TrainOutput(global_step=3, training_loss=5.523826281229655, metrics={'train_runtime': 45.0591, 'train_samples_per_second': 0.067, 'total_flos': 0, 'epoch': 3.0})

In [24]:
# Evaluate BERT on the test split and display the metrics dict.
bert_test_output = bert_trainer.predict(bert_dm.test_dataset)
bert_test_output.metrics

{'test_loss': 9.000775337219238,
 'test_f1': 0.0,
 'test_accuracy': 0.0,
 'test_em': 0.0,
 'test_runtime': 0.6044,
 'test_samples_per_second': 1.655}

In [25]:
# Write the predicted answers to disk, upload the file to the active W&B run,
# then close the run. The answers are presumably the last element of the
# prediction output; confirm in training.SquadTrainer.
bert_answers_path = "results/answers/bert.json"
bert_predicted_answers = bert_test_output.predictions[-1]
utils.save_answers(bert_answers_path, bert_predicted_answers)
wandb.save(bert_answers_path);
wandb.finish()