# Question answering on the SQuAD dataset

In [3]:
import sys
import random
from functools import partial

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
import transformers
import tokenizers

from transformers.trainer_utils import set_seed
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.normalizers import Sequence, StripAccents, Lowercase, Strip
from tokenizers.pre_tokenizers import Sequence as PreSequence
from tokenizers.pre_tokenizers import Whitespace, Punctuation
from tokenizers import BertWordPieceTokenizer

import dataset
import model
import training
import utils

%load_ext autoreload
%autoreload 2
%matplotlib inline

In [4]:
plt.rcParams['figure.figsize'] = [8, 6]
plt.rcParams['figure.dpi'] = 100
plt.rcParams['axes.xmargin'] = .05
plt.rcParams['axes.ymargin'] = .05
plt.style.use('ggplot')

In [5]:
WANDB_PROJECT = "squad-qa"
WANDB_ENTITY = "wadaboa"
WANDB_MODE = "online"
WANDB_RESUME = "never"

init_wandb = partial(
    wandb.init,
    project=WANDB_PROJECT,
    entity=WANDB_ENTITY,
    mode=WANDB_MODE,
    resume=WANDB_RESUME,
)

In [6]:
!wandb enabled

W&B enabled.


In [7]:
RANDOM_SEED = 42
set_seed(RANDOM_SEED)

In [8]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  return torch._C._cuda_getDeviceCount() > 0


## Preliminaries

### Raw data loading

In [9]:
squad_dataset = dataset.SquadDataset(subset=0.01)

In [10]:
squad_dataset.raw_train_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,56de4d9ecffd8e1900b4b7e2,What year was the Banská Akadémia founded?,Institute_of_technology,1860,The world's first institution of technology or...,1735,167,171
1,572674a05951b619008f7319,What is another speed that can also be reporte...,Film_speed,9354,The standard specifies how speed ratings shoul...,SOS-based speed,793,808
2,5730bb058ab72b1400f9c72c,Where were the use of advanced materials and t...,Sumer,17505,The most impressive and famous of Sumerian bui...,Sumerian temples and palaces,421,449
3,572781a5f1498d1400e8fa1f,Who is elected every even numbered year?,"Ann_Arbor,_Michigan",10585,Ann Arbor has a council-manager form of govern...,mayor,192,197
4,572843ce4b864d190016485c,What was the purpose of top secret ICBM commit...,John_von_Neumann,11497,"Shortly before his death, when he was already ...",decide on the feasibility of building an ICBM ...,194,284
...,...,...,...,...,...,...,...,...
870,5725cad2ec44d21400f3d5a0,It was an early backer of what,Bill_%26_Melinda_Gates_Foundation,8521,A key aspect of the Gates Foundation's U.S. ef...,The foundation was the biggest early backer of...,449,539
871,570c0fedec8fbc190045bc47,What event in February did Barcelona qualify f...,FC_Barcelona,6319,"On 4 January 2016, Barcelona's transfer ban en...",Copa del Rey final,267,285
872,57327ed206a3a419008aca8d,What caused a large migration of Greek refuges...,Humanism,18662,The humanists' close study of Latin literary t...,Greek manuscripts,1262,1279
873,5727c05eff5b5019007d9449,What league are the Toronto Bluejays in?,Exhibition_game,13088,Several MLB teams used to play regular exhibit...,American League,246,261


In [11]:
squad_dataset.raw_test_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,57379ed81c456719005744d7,In what way do idea strings transmit tesion fo...,Force,2058,Tension forces can be modeled using ideal stri...,instantaneously in action-reaction pairs,250,290
1,56e0ed557aa994140058e7dd,What was Tesla's device called?,Nikola_Tesla,181,Tesla also explained the principles of the rot...,Egg of Columbus,187,202
2,56e1febfe3433e140042323a,What is the example of another problem charact...,Computational_complexity_theory,282,What intractability means in practice is open ...,NP-complete Boolean satisfiability problem,539,581
3,57269bb8708984140094cb98,What are EU Regulations essentially the same a...,European_Union_law,759,Although it is generally accepted that EU law ...,Treaty provisions,1332,1349
4,56f884cba6d7ea1400e17708,What theologian differed in views about the so...,Martin_Luther,412,The Lutheran theologian Franz Pieper observed ...,Gerhard,181,188
...,...,...,...,...,...,...,...,...
177,5711475ca58dae1900cd6d8b,"In a 4-cylinder compound engine, what degree w...",Steam_engine,587,With two-cylinder compounds used in railway wo...,180,313,316
178,572fdd03a23a5019007fcaa0,What party had a victory in the 2015 UK election?,Scottish_Parliament,1846,A procedural consequence of the establishment ...,Conservative,465,477
179,572ff56304bcaa1900d76f30,What does the Nederrijn change it's name to?,Rhine,1781,The other third of the water flows through the...,the Lek,402,409
180,5737a0acc3c5551400e51f4a,What may a force on one part of an object affect?,Force,2059,Newton's laws and Newtonian mechanics in gener...,other parts,276,287


### Embeddings

In [12]:
UNK_TOKEN = "[UNK]"
PAD_TOKEN = "[PAD]"

- FastText: 
    - _fasttext-wiki-news-subwords_ (dimensions: 300)
- GloVe:
    - _glove-twitter_ (dimensions: 25. 50, 100, 200)
    - _glove-wiki-gigaword_ (dimensions: 50, 100, 200, 300)
- Word2Vec:
    - _word2vec-google-news_ (dimensions: 300)
    - _word2vec-ruscorpora_ (dimensions: 300)

In [9]:
# See https://github.com/RaRe-Technologies/gensim-data
GLOVE_EMBEDDING_DIMENSION = 50
GLOVE_MODEL_NAME = "glove-twitter"
glove_embedding_model = utils.load_embedding_model(
    GLOVE_MODEL_NAME, embedding_dimension=GLOVE_EMBEDDING_DIMENSION
)

In [10]:
glove_unk = np.mean(glove_embedding_model.vectors, axis=0)
glove_embedding_model.add(UNK_TOKEN, glove_unk)

In [11]:
glove_embedding_model[UNK_TOKEN]

array([-0.21896735,  0.17269313, -0.05617283,  0.06307325,  0.00960657,
       -0.23461065, -0.16731773, -0.25613925,  0.12990713, -0.34179848,
       -0.07411992,  0.00533567,  0.7090377 , -0.1139018 ,  0.10613882,
        0.09186497,  0.15880948,  0.03158554,  0.2241412 ,  0.20387109,
        0.05305386,  0.04961218,  0.11807557, -0.10199773, -0.18345806,
        0.56560194,  0.07183363,  0.04322447, -0.39442873,  0.06828266,
        0.39542177,  0.08794834,  0.41605434, -0.27820984, -0.5106833 ,
       -0.16443801,  0.0973425 ,  0.02233286,  0.19346187,  0.15909852,
        0.886585  , -0.01498107,  0.10211241, -0.12959567, -0.328366  ,
        0.13014658, -0.02061043,  0.05735753,  0.14008364,  0.22588447],
      dtype=float32)

In [12]:
list(glove_embedding_model.vocab.keys())[-1]

'[UNK]'

In [13]:
any(np.all(glove_embedding_model.vectors == 0, axis=1))

False

In [14]:
glove_embedding_model.add(PAD_TOKEN, np.zeros((1, GLOVE_EMBEDDING_DIMENSION)))
glove_embedding_model[PAD_TOKEN]

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
      dtype=float32)

In [15]:
list(glove_embedding_model.vocab.keys())[-1]

'[PAD]'

In [16]:
glove_embedding_model.vectors.shape

(1193516, 50)

In [17]:
glove_vocab = dict(
    zip(glove_embedding_model.index2word, range(len(glove_embedding_model.index2word)))
)

In [18]:
glove_embedding_layer = nn.Embedding(
    glove_embedding_model.vectors.shape[0],
    GLOVE_EMBEDDING_DIMENSION,
    padding_idx=glove_vocab[PAD_TOKEN],
)
glove_embedding_layer.weight = nn.Parameter(
    torch.from_numpy(glove_embedding_model.vectors)
)
glove_embedding_layer.weight.requires_grad = False

### Standard tokenizer and preprocessing

In [185]:
MAX_CONTEXT_TOKENS = 300

In [223]:
standard_question_tokenizer = Tokenizer(WordLevel(glove_vocab, unk_token=UNK_TOKEN))
standard_question_tokenizer.normalizer = Sequence(
    [StripAccents(), Lowercase(), Strip()]
)
standard_question_tokenizer.pre_tokenizer = PreSequence([Whitespace(), Punctuation()])
standard_question_tokenizer.enable_padding(
    direction="right", pad_id=glove_vocab[PAD_TOKEN], pad_type_id=1, pad_token=PAD_TOKEN
)

standard_context_tokenizer = Tokenizer(WordLevel(glove_vocab, unk_token=UNK_TOKEN))
standard_context_tokenizer.normalizer = Sequence([StripAccents(), Lowercase(), Strip()])
standard_context_tokenizer.pre_tokenizer = PreSequence([Whitespace(), Punctuation()])
standard_context_tokenizer.enable_padding(
    direction="right",
    pad_id=glove_vocab[PAD_TOKEN],
    pad_type_id=1,
    pad_token=PAD_TOKEN,
    length=MAX_CONTEXT_TOKENS,
)
standard_context_tokenizer.enable_truncation(MAX_CONTEXT_TOKENS)

In [224]:
standard_tokenizer = dataset.StandardSquadTokenizer(
    standard_question_tokenizer, standard_context_tokenizer, device=DEVICE
)

In [225]:
standard_dm = dataset.SquadDataManager(squad_dataset, standard_tokenizer, device=DEVICE)

In [226]:
standard_dm.tokenizer = standard_tokenizer

In [227]:
standard_dm.train_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,56cc27346d243a140015eebe,The Tibetan leaders had a diplomacy with what ...,Sino-Tibetan_relations_during_the_Ming_dynasty,300,Some scholars note that Tibetan leaders during...,[Nepal],[162],[167]
1,56cd8ec662d2951400fa6706,Hostile spirits are also known as what?,The_Legend_of_Zelda:_Twilight_Princess,438,"When Link enters the Twilight Realm, the void ...",[Poes],[1014],[1018]
2,56cda10262d2951400fa6793,Who was in charge of overseeing audio production?,The_Legend_of_Zelda:_Twilight_Princess,452,The game's score was composed by Toru Minegish...,[Koji Kondo],[84],[94]
3,56cda64a62d2951400fa67c0,In what areas is the content of the GameStop b...,The_Legend_of_Zelda:_Twilight_Princess,457,A CD containing 20 musical selections from the...,"[Japan, Europe, and Australia]",[150],[178]
4,56ce3631aab44d1400b885d1,What date did To Kill a Mockingbird begin to c...,To_Kill_a_Mockingbird,738,"Ultimately, Lee spent over two and a half year...","[July 11, 1960]",[105],[118]
...,...,...,...,...,...,...,...,...
694,5733d4c24776f41900661306,What also happened in 1939 besides tyrothricin?,Antibiotics,211,"In 1939, coinciding with the start of World Wa...","[start of World War II,]",[29],[51]
695,57342435d058e614000b69e6,What electronic charge do cellular molecules h...,Infection,18780,Other microscopic procedures may also aid in i...,[negatively charged],[179],[197]
696,573424434776f41900661944,What city saw the largest growth?,Montana,155,The United States Census Bureau estimates that...,[Kalispell],[561],[570]
697,5735f8bc012e2f140011a106,What scriptures describe hunting as and accept...,Hunting,18805,Hindu scriptures describe hunting as an accept...,[Hindu],[0],[5]


In [228]:
standard_dm.val_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,56be99b53aeaaa14008c913f,Who did Beyonce donate the money to earned fro...,Beyoncé,73,"In 2011, documents obtained by WikiLeaks revea...",[Clinton Bush Haiti Fund],[367],[390]
1,56beb5b23aeaaa14008c9287,Beyonce was coached for her Spanish songs by w...,Beyoncé,90,"Beyoncé's music is generally R&B, but she also...",[Rudy Perez],[516],[526]
2,56bed07e3aeaaa14008c94a9,Which soda company has Beyonce partnered with ...,Beyoncé,110,"Beyoncé has worked with Pepsi since 2002, and ...",[Pepsi],[24],[29]
3,56bfda91a10cfb1400551339,What did she agree to do for 50 million dollar...,Beyoncé,110,"Beyoncé has worked with Pepsi since 2002, and ...",[endorse Pepsi],[191],[204]
4,56cbdfbf6d243a140015edba,What short poem spoke of Frédéric's popularity...,Frédéric_Chopin,225,In 1817 the Saxon Palace was requisitioned by ...,[Nasze Przebiegi],[566],[581]
...,...,...,...,...,...,...,...,...
166,573220fce99e3014001e653a,What means guaranteed the Jews and Judaism in ...,Religion_in_ancient_Rome,18187,For at least a century before the establishmen...,[treaty],[130],[136]
167,5732321ce17f3d1400422717,What edict defined imperial ideas as being tho...,Religion_in_ancient_Rome,18194,Constantine successfully balanced his own role...,[edict of Milan],[238],[252]
168,573409fbd058e614000b684a,When did the mail stagecoaches stop running?,"Tucson,_Arizona",16601,"Arizona, south of the Gila River was legally b...",[August 1861],[611],[622]
169,57340b224776f4190066179d,What was Tucson's population in 2006?,"Tucson,_Arizona",16603,"By 1900, 7,531 people lived in the city. The p...","[535,000]",[608],[615]


In [229]:
standard_dm.test_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,56beaa4a3aeaaa14008c91c3,Who did the Broncos prevent from going to the ...,Super_Bowl_50,1,The Panthers finished the regular season with ...,[the New England Patriots],[368],[392]
1,56beb4343aeaaa14008c925e,How many balls did Josh Norman intercept?,Super_Bowl_50,11,"The Panthers defense gave up just 308 points, ...",[four],[1104],[1108]
2,56beb4343aeaaa14008c925f,Who registered the most sacks on the team this...,Super_Bowl_50,11,"The Panthers defense gave up just 308 points, ...",[Short],[199],[204]
3,56beb90c3aeaaa14008c92c7,When was Manning picked #1 in the NFL Draft?,Super_Bowl_50,19,This was the first Super Bowl to feature a qua...,[1998],[149],[153]
4,56bec7a63aeaaa14008c941a,Who kicked a field goal for Denver?,Super_Bowl_50,45,Denver took the opening kickoff and started ou...,[McManus],[544],[551]
...,...,...,...,...,...,...,...,...
173,57373d0cc3c5551400e51e88,Where did Aristotle believe the natural place ...,Force,2025,Aristotle provided a philosophical discussion ...,[ground],[388],[394]
174,573792ee1c456719005744b9,What was dificult to reconcile the photoelectr...,Force,2053,"However, attempting to reconcile electromagnet...",[electromagnetic theory],[33],[55]
175,5737958ac3c5551400e51f2b,What is needed to pack electrons densely toget...,Force,2054,It is a common misconception to ascribe the st...,[energy],[579],[585]
176,5737a4511c456719005744e1,To calculate instant angular acceleration of a...,Force,2061,Torque is the rotation equivalent of force in ...,[Newton's Second Law of Motion],[375],[404]


## Utils

In [23]:
TRAINER_ARGS = partial(
    transformers.TrainingArguments,
    output_dir="./checkpoints",
    logging_dir="./runs",
    logging_first_step=True,
    logging_steps=5,
    overwrite_output_dir=True,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    evaluation_strategy="epoch",
    learning_rate=1e-3,
    num_train_epochs=1,
    remove_unused_columns=False,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    label_names=["answers"],
    seed=RANDOM_SEED
)

## Baseline model

In [239]:
baseline_model = model.QABaselineModel(glove_embedding_layer, MAX_CONTEXT_TOKENS, device=DEVICE)
print(f"The baseline model has {baseline_model.count_parameters()} parameters")

The baseline model has 101400 parameters


In [240]:
baseline_args = TRAINER_ARGS()
baseline_trainer = training.SquadTrainer(
    model=baseline_model,
    args=baseline_args,
    data_collator=standard_dm.tokenizer,
    train_dataset=standard_dm.train_dataset,
    eval_dataset=standard_dm.val_dataset,
)

In [198]:
baseline_run_name = utils.get_run_name()
baseline_wandb_logger = init_wandb(
    name=baseline_run_name, group="baseline", reinit=True,
)

[34m[1mwandb[0m: wandb version 0.10.14 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [241]:
baseline_trainer.train()

Epoch,Training Loss,Validation Loss,F1,Accuracy,Em,Runtime,Samples Per Second
1,13.2616,13.711234,0.003027,0.011696,0.0,2.0848,82.021


TrainOutput(global_step=11, training_loss=13.365315350619229, metrics={'train_runtime': 22.6835, 'train_samples_per_second': 0.485, 'total_flos': 0, 'epoch': 1.0})

In [242]:
baseline_test_output = baseline_trainer.predict(standard_dm.test_dataset)

In [243]:
baseline_test_output.metrics

{'test_loss': 13.67077350616455,
 'test_f1': 0.00019309864064133845,
 'test_accuracy': 0.0056179775280898875,
 'test_em': 0.0,
 'test_runtime': 2.1768,
 'test_samples_per_second': 81.772}

In [205]:
baseline_answers_path = "results/baseline.json"
utils.save_answers(baseline_answers_path, baseline_test_output.predictions[-1])
baseline_wandb_logger.save(baseline_answers_path)

['/root/jupyter/squad-question-answering/wandb/run-20210118_133945-2cbsx7pg/files/results/baseline.json']

In [206]:
baseline_wandb_logger.finish()

## BiDAF

In [244]:
bidaf_model = model.QABiDAFModel(glove_embedding_layer, device=DEVICE)
print(f"The BiDAF model has {bidaf_model.count_parameters()} parameters")

The BiDAF model has 314350 parameters


In [245]:
bidaf_optimizer = optim.Adadelta(bidaf_model.parameters(), lr=0.5)
bidaf_lr_scheduler = optim.lr_scheduler.ExponentialLR(bidaf_optimizer, gamma=.999)

In [246]:
bidaf_args = TRAINER_ARGS()
bidaf_trainer = training.SquadTrainer(
    model=bidaf_model,
    args=bidaf_args,
    data_collator=standard_dm.tokenizer,
    train_dataset=standard_dm.train_dataset,
    eval_dataset=standard_dm.val_dataset,
    optimizers=(bidaf_optimizer, bidaf_lr_scheduler),
)

In [210]:
bidaf_run_name = utils.get_run_name()
bidaf_wandb_logger = init_wandb(name=bidaf_run_name, group="bidaf", reinit=True,)

[34m[1mwandb[0m: Currently logged in as: [33mwadaboa[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.14 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [247]:
bidaf_trainer.train()

Epoch,Training Loss,Validation Loss,F1,Accuracy,Em,Runtime,Samples Per Second
1,12.7738,13.520892,0.006754,0.011696,0.0,6.0142,28.433


TrainOutput(global_step=11, training_loss=12.868615410544656, metrics={'train_runtime': 71.9392, 'train_samples_per_second': 0.153, 'total_flos': 0, 'epoch': 1.0})

In [248]:
bidaf_test_output = bidaf_trainer.predict(standard_dm.test_dataset)

In [249]:
bidaf_test_output.metrics

{'test_loss': 13.101156234741211,
 'test_f1': 0.004649616256759114,
 'test_accuracy': 0.011235955056179775,
 'test_em': 0.0,
 'test_runtime': 6.7952,
 'test_samples_per_second': 26.195}

In [None]:
bidaf_answers_path = "results/bidaf.json"
utils.save_answers(bidaf_answers_path, bidaf_test_output.predictions[-1])
bidaf_wandb_logger.save(bidaf_answers_path)

In [601]:
bidaf_wandb_logger.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

## BERT

In [27]:
MAX_BERT_TOKENS = 512

In [28]:
bert_model = model.QABertModel(device=DEVICE)

In [15]:
bert_wp_tokenizer = BertWordPieceTokenizer("data/bert-base-uncased-vocab.txt", lowercase=True)
bert_wp_tokenizer.enable_padding(
    direction="right",
    pad_type_id=1,
)
bert_wp_tokenizer.enable_truncation(MAX_BERT_TOKENS)

In [16]:
bert_tokenizer = dataset.BertSquadTokenizer(bert_wp_tokenizer, device=DEVICE)

In [17]:
bert_dm = dataset.SquadDataManager(squad_dataset, bert_tokenizer, device=DEVICE)

In [18]:
bert_dm.tokenizer = bert_tokenizer

In [19]:
bert_dm.train_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,56cd788162d2951400fa65e7,What was Apple's highest quarterly profit as o...,IPod,415,"On January 22, 2008, Apple reported the best q...",[$1.58 billion],[184],[197]
1,56cda41362d2951400fa679b,How many people would be in the orchestra Kond...,The_Legend_of_Zelda:_Twilight_Princess,453,Media requests at the trade show prompted Kond...,[50],[217],[219]
2,56cda8a662d2951400fa67d6,Where did Twilight Princess place among Wii ti...,The_Legend_of_Zelda:_Twilight_Princess,460,Twilight Princess received the awards for Best...,[4th],[777],[780]
3,56cdcfee62d2951400fa687a,Who performs research for Bond?,Spectre_(2015_film),467,Bond disobeys M's order and travels to Rome to...,[Moneypenny],[388],[398]
4,56cdd28562d2951400fa68bd,Who does M fight with?,Spectre_(2015_film),470,Bond and Swann return to London where they mee...,[C],[105],[106]
...,...,...,...,...,...,...,...,...
693,5734580c879d6814001ca544,What official is in charge of Richmond's execu...,"Richmond,_Virginia",17037,Richmond city government consists of a city co...,[mayor],[183],[188]
694,57359c97e853931400426a3f,What is the English translation of Mandap?,Kathmandu,18835,The city of Kathmandu is named after Kasthaman...,[covered shelter],[155],[170]
695,57359f64e853931400426a7e,What notable Nepali figure died in a Kathmandu...,Kathmandu,18841,The Licchavi era was followed by the Malla era...,[Abhaya Malla],[427],[439]
696,5735a721e853931400426aa9,"What are Madhyapur Thimi, Kirtipur and Bhaktapur?",Kathmandu,18846,The agglomeration of Kathmandu has not yet bee...,[municipalities],[441],[455]


In [20]:
bert_dm.val_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,56be8e353aeaaa14008c90c6,"""Charlie's Angels"" featured which single from ...",Beyoncé,63,"The remaining band members recorded ""Independe...",[Independent Women Part I],[37],[61]
1,56be9eea3aeaaa14008c9184,How many awards did Beyonce take home with her...,Beyoncé,79,At the 57th Annual Grammy Awards in February 2...,[three],[108],[113]
2,56bf89cfa10cfb1400551163,What album caused a lawsuit to be filed in 2001?,Beyoncé,63,"The remaining band members recorded ""Independe...",[Survivor],[593],[601]
3,56bf9c70a10cfb14005511bb,In what year did Beyonce have her hiatus?,Beyoncé,72,Beyoncé announced a hiatus from her music care...,[2010],[60],[64]
4,56cc27346d243a140015eebb,During what years did the Mongol leader Kublai...,Sino-Tibetan_relations_during_the_Ming_dynasty,300,Some scholars note that Tibetan leaders during...,[1402–1424],[739],[748]
...,...,...,...,...,...,...,...,...
168,57351496879d6814001cab11,Where can safari hunters go which are uninviti...,Hunting,18825,A variety of industries benefit from hunting a...,[remote areas],[367],[379]
169,5735fcb96c16ec1900b928c8,What forbid hunting in the woods with hounds a...,Hunting,18806,"From early Christian times, hunting has been f...",[Corpus Juris Canonici],[98],[119]
170,5735ffb96c16ec1900b928e1,Game animals were introduced here by whom?,Hunting,18808,New Zealand has a strong hunting culture. The ...,[acclimatisation societies],[189],[214]
171,573605726c16ec1900b92906,Why are assistants used?,Hunting,18812,"Shooting as practised in Britain, as opposed t...",[help load shotguns],[305],[323]


In [21]:
bert_dm.test_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,56be5438acb8001400a5031c,Which California venue was one of three consid...,Super_Bowl_50,5,The league eventually narrowed the bids to thr...,[San Francisco Bay Area's Levi's Stadium],[128],[167]
1,56be5523acb8001400a5032e,When did Lev's Stadium open?,Super_Bowl_50,7,"On May 21, 2013, NFL owners at their spring me...",[2014],[144],[148]
2,56beab833aeaaa14008c91d2,Which Denver linebacker was named Super Bowl MVP?,Super_Bowl_50,2,The Broncos took an early lead in Super Bowl 5...,[linebacker Von Miller],[237],[258]
3,56beb6533aeaaa14008c9290,Who was first on the team in total tackles?,Super_Bowl_50,14,The Broncos' defense ranked first in the NFL y...,[Brandon Marshall],[458],[474]
4,56bebc383aeaaa14008c9321,What yard marker on the field was painted gold?,Super_Bowl_50,25,Various gold-themed promotions and initiatives...,[50],[232],[234]
...,...,...,...,...,...,...,...,...
176,57379829c3c5551400e51f3d,What does the W and Z boson exchange create?,Force,2056,The weak force is due to the exchange of the h...,[weak force],[4],[14]
177,57379a4b1c456719005744d0,What is the force that causes rigid strength i...,Force,2057,The normal force is due to repulsive forces of...,[normal],[298],[304]
178,57379ed81c456719005744d7,In what way do idea strings transmit tesion fo...,Force,2058,Tension forces can be modeled using ideal stri...,[instantaneously in action-reaction pairs],[250],[290]
179,5737a0acc3c5551400e51f4a,What may a force on one part of an object affect?,Force,2059,Newton's laws and Newtonian mechanics in gener...,[other parts],[276],[287]


In [29]:
bert_args = TRAINER_ARGS(per_device_train_batch_size=16, per_device_eval_batch_size=16,)
bert_trainer = training.SquadTrainer(
    model=bert_model,
    args=bert_args,
    data_collator=bert_dm.tokenizer,
    train_dataset=bert_dm.train_dataset,
    eval_dataset=bert_dm.val_dataset,
)

In [25]:
bert_run_name = utils.get_run_name()
bert_wandb_logger = init_wandb(name=bert_run_name, group="bert", reinit=True,)

[34m[1mwandb[0m: Currently logged in as: [33mwadaboa[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.14 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [30]:
bert_trainer.train()

Epoch,Training Loss,Validation Loss




KeyboardInterrupt: 

In [None]:
bert_test_output = bert_trainer.predict(bert_dm.test_dataset)

In [None]:
bert_test_output.metrics

In [None]:
bert_answers_path = "results/bert.json"
utils.save_answers(bert_answers_path, bert_test_output.predictions[-1])
bert_wandb_logger.save(bert_answers_path)

In [40]:
bert_wandb_logger.finish()