# Question answering on the SQuAD dataset

## Colab requirements

Before restarting runtime (remeber to select GPU runtime)$\dots$

In [None]:
!git clone https://github.com/Wadaboa/squad-question-answering.git
!pip install -r squad-question-answering/init/base_requirements.txt

After restarting runtime$\dots$

In [None]:
import os, sys

sys.path.insert(0, "/content/squad-question-answering")
os.chdir("/content/squad-question-answering")

## Imports

In [1]:
import os
from functools import partial

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
import transformers
from transformers.trainer_utils import set_seed

import dataset
import model
import training
import tokenizer
import utils

%load_ext autoreload
%autoreload 2

## Initialization

### Matplotlib

In [2]:
%matplotlib inline
plt.rcParams['figure.figsize'] = [8, 6]
plt.rcParams['figure.dpi'] = 100
plt.rcParams['axes.xmargin'] = .05
plt.rcParams['axes.ymargin'] = .05
plt.style.use('ggplot')

### Weights & biases

In [3]:
%env WANDB_PROJECT=squad-qa
%env WANDB_ENTITY=wadaboa
%env WANDB_MODE=online
%env WANDB_RESUME=never
%env WANDB_WATCH=false
%env WANDB_SILENT=true

env: WANDB_PROJECT=squad-qa
env: WANDB_ENTITY=wadaboa
env: WANDB_MODE=online
env: WANDB_RESUME=never
env: WANDB_WATCH=false
env: WANDB_SILENT=true


In [4]:
!wandb login

In [5]:
!wandb enabled

W&B enabled.


### PyTorch and numpy

In [6]:
RANDOM_SEED = 42
set_seed(RANDOM_SEED)

In [7]:
DEVICE = utils.get_device()
DEVICE

  return torch._C._cuda_getDeviceCount() > 0


device(type='cpu')

## Preliminaries

### Raw data loading

In [8]:
DATA_FOLDER = os.path.join(os.getcwd(), "data")
TRAIN_DATA_FOLDER = os.path.join(DATA_FOLDER, "training")
TRAIN_SET_PATH = os.path.join(TRAIN_DATA_FOLDER, "training_set.json")
TEST_DATA_FOLDER = os.path.join(DATA_FOLDER, "testing")
TEST_SET_PATH = os.path.join(TEST_DATA_FOLDER, "test_set.json")

In [9]:
squad_dataset = dataset.SquadDataset(
    train_set_path=TRAIN_SET_PATH, test_set_path=TEST_SET_PATH, subset=0.01
)

In [10]:
squad_dataset.raw_train_df

Unnamed: 0,answer_start,answer,title,context,question_id,question,context_id,answer_end
0,167,1735,Institute_of_technology,The world's first institution of technology or...,56de4d9ecffd8e1900b4b7e2,What year was the Banská Akadémia founded?,1860,171
1,793,SOS-based speed,Film_speed,The standard specifies how speed ratings shoul...,572674a05951b619008f7319,What is another speed that can also be reporte...,9354,808
2,421,Sumerian temples and palaces,Sumer,The most impressive and famous of Sumerian bui...,5730bb058ab72b1400f9c72c,Where were the use of advanced materials and t...,17505,449
3,192,mayor,"Ann_Arbor,_Michigan",Ann Arbor has a council-manager form of govern...,572781a5f1498d1400e8fa1f,Who is elected every even numbered year?,10585,197
4,194,decide on the feasibility of building an ICBM ...,John_von_Neumann,"Shortly before his death, when he was already ...",572843ce4b864d190016485c,What was the purpose of top secret ICBM commit...,11497,284
...,...,...,...,...,...,...,...,...
870,449,The foundation was the biggest early backer of...,Bill_%26_Melinda_Gates_Foundation,A key aspect of the Gates Foundation's U.S. ef...,5725cad2ec44d21400f3d5a0,It was an early backer of what,8521,539
871,267,Copa del Rey final,FC_Barcelona,"On 4 January 2016, Barcelona's transfer ban en...",570c0fedec8fbc190045bc47,What event in February did Barcelona qualify f...,6319,285
872,1262,Greek manuscripts,Humanism,The humanists' close study of Latin literary t...,57327ed206a3a419008aca8d,What caused a large migration of Greek refuges...,18662,1279
873,246,American League,Exhibition_game,Several MLB teams used to play regular exhibit...,5727c05eff5b5019007d9449,What league are the Toronto Bluejays in?,13088,261


In [11]:
squad_dataset.raw_test_df

Unnamed: 0,answer_start,answer,title,context,question_id,question,context_id,answer_end
0,250,instantaneously in action-reaction pairs,Force,Tension forces can be modeled using ideal stri...,57379ed81c456719005744d7,In what way do idea strings transmit tesion fo...,2058,290
1,187,Egg of Columbus,Nikola_Tesla,Tesla also explained the principles of the rot...,56e0ed557aa994140058e7dd,What was Tesla's device called?,181,202
2,539,NP-complete Boolean satisfiability problem,Computational_complexity_theory,What intractability means in practice is open ...,56e1febfe3433e140042323a,What is the example of another problem charact...,282,581
3,1332,Treaty provisions,European_Union_law,Although it is generally accepted that EU law ...,57269bb8708984140094cb98,What are EU Regulations essentially the same a...,759,1349
4,181,Gerhard,Martin_Luther,The Lutheran theologian Franz Pieper observed ...,56f884cba6d7ea1400e17708,What theologian differed in views about the so...,412,188
...,...,...,...,...,...,...,...,...
177,313,180,Steam_engine,With two-cylinder compounds used in railway wo...,5711475ca58dae1900cd6d8b,"In a 4-cylinder compound engine, what degree w...",587,316
178,465,Conservative,Scottish_Parliament,A procedural consequence of the establishment ...,572fdd03a23a5019007fcaa0,What party had a victory in the 2015 UK election?,1846,477
179,402,the Lek,Rhine,The other third of the water flows through the...,572ff56304bcaa1900d76f30,What does the Nederrijn change it's name to?,1781,409
180,276,other parts,Force,Newton's laws and Newtonian mechanics in gener...,5737a0acc3c5551400e51f4a,What may a force on one part of an object affect?,2059,287


### Embeddings

In [12]:
UNK_TOKEN = "[UNK]"
PAD_TOKEN = "[PAD]"

- FastText: 
    - _fasttext-wiki-news-subwords_ (dimensions: 300)
- GloVe:
    - _glove-twitter_ (dimensions: 25. 50, 100, 200)
    - _glove-wiki-gigaword_ (dimensions: 50, 100, 200, 300)
- Word2Vec:
    - _word2vec-google-news_ (dimensions: 300)
    - _word2vec-ruscorpora_ (dimensions: 300)

In [13]:
# See https://github.com/RaRe-Technologies/gensim-data
GLOVE_EMBEDDING_DIMENSION = 25
GLOVE_MODEL_NAME = "glove-twitter"
glove_embedding_model, glove_vocab = utils.load_embedding_model(
    GLOVE_MODEL_NAME,
    embedding_dimension=GLOVE_EMBEDDING_DIMENSION,
    unk_token=UNK_TOKEN,
    pad_token=PAD_TOKEN,
)

In [14]:
glove_embedding_layer = model.get_embedding_module(
    glove_embedding_model, pad_id=glove_vocab[PAD_TOKEN]
)

## Utils

In [54]:
TRAINER_ARGS = utils.get_default_trainer_args()

### Standard data loading

In [48]:
MAX_CONTEXT_TOKENS = 300

In [49]:
standard_tokenizer = tokenizer.get_standard_tokenizer(
    glove_vocab,
    MAX_CONTEXT_TOKENS,
    unk_token=UNK_TOKEN,
    pad_token=PAD_TOKEN,
    device=DEVICE,
)

In [50]:
standard_dm = dataset.SquadDataManager(squad_dataset, standard_tokenizer, device=DEVICE)

In [51]:
standard_dm.train_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,56cd788162d2951400fa65e7,What was Apple's highest quarterly profit as o...,IPod,415,"On January 22, 2008, Apple reported the best q...",[$1.58 billion],[184],[197]
1,56cda41362d2951400fa679b,How many people would be in the orchestra Kond...,The_Legend_of_Zelda:_Twilight_Princess,453,Media requests at the trade show prompted Kond...,[50],[217],[219]
2,56cda8a662d2951400fa67d6,Where did Twilight Princess place among Wii ti...,The_Legend_of_Zelda:_Twilight_Princess,460,Twilight Princess received the awards for Best...,[4th],[777],[780]
3,56cdcfee62d2951400fa687a,Who performs research for Bond?,Spectre_(2015_film),467,Bond disobeys M's order and travels to Rome to...,[Moneypenny],[388],[398]
4,56cdd28562d2951400fa68bd,Who does M fight with?,Spectre_(2015_film),470,Bond and Swann return to London where they mee...,[C],[105],[106]
...,...,...,...,...,...,...,...,...
692,5734580c879d6814001ca544,What official is in charge of Richmond's execu...,"Richmond,_Virginia",17037,Richmond city government consists of a city co...,[mayor],[183],[188]
693,57359c97e853931400426a3f,What is the English translation of Mandap?,Kathmandu,18835,The city of Kathmandu is named after Kasthaman...,[covered shelter],[155],[170]
694,57359f64e853931400426a7e,What notable Nepali figure died in a Kathmandu...,Kathmandu,18841,The Licchavi era was followed by the Malla era...,[Abhaya Malla],[427],[439]
695,5735a721e853931400426aa9,"What are Madhyapur Thimi, Kirtipur and Bhaktapur?",Kathmandu,18846,The agglomeration of Kathmandu has not yet bee...,[municipalities],[441],[455]


In [52]:
standard_dm.val_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,56be8e353aeaaa14008c90c6,"""Charlie's Angels"" featured which single from ...",Beyoncé,63,"The remaining band members recorded ""Independe...",[Independent Women Part I],[37],[61]
1,56be9eea3aeaaa14008c9184,How many awards did Beyonce take home with her...,Beyoncé,79,At the 57th Annual Grammy Awards in February 2...,[three],[108],[113]
2,56bf89cfa10cfb1400551163,What album caused a lawsuit to be filed in 2001?,Beyoncé,63,"The remaining band members recorded ""Independe...",[Survivor],[593],[601]
3,56bf9c70a10cfb14005511bb,In what year did Beyonce have her hiatus?,Beyoncé,72,Beyoncé announced a hiatus from her music care...,[2010],[60],[64]
4,56cc27346d243a140015eebb,During what years did the Mongol leader Kublai...,Sino-Tibetan_relations_during_the_Ming_dynasty,300,Some scholars note that Tibetan leaders during...,[1402–1424],[739],[748]
...,...,...,...,...,...,...,...,...
167,57351496879d6814001cab11,Where can safari hunters go which are uninviti...,Hunting,18825,A variety of industries benefit from hunting a...,[remote areas],[367],[379]
168,5735fcb96c16ec1900b928c8,What forbid hunting in the woods with hounds a...,Hunting,18806,"From early Christian times, hunting has been f...",[Corpus Juris Canonici],[98],[119]
169,5735ffb96c16ec1900b928e1,Game animals were introduced here by whom?,Hunting,18808,New Zealand has a strong hunting culture. The ...,[acclimatisation societies],[189],[214]
170,573605726c16ec1900b92906,Why are assistants used?,Hunting,18812,"Shooting as practised in Britain, as opposed t...",[help load shotguns],[305],[323]


In [53]:
standard_dm.test_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,56be5438acb8001400a5031c,Which California venue was one of three consid...,Super_Bowl_50,5,The league eventually narrowed the bids to thr...,[San Francisco Bay Area's Levi's Stadium],[128],[167]
1,56be5523acb8001400a5032e,When did Lev's Stadium open?,Super_Bowl_50,7,"On May 21, 2013, NFL owners at their spring me...",[2014],[144],[148]
2,56beab833aeaaa14008c91d2,Which Denver linebacker was named Super Bowl MVP?,Super_Bowl_50,2,The Broncos took an early lead in Super Bowl 5...,[linebacker Von Miller],[237],[258]
3,56beb6533aeaaa14008c9290,Who was first on the team in total tackles?,Super_Bowl_50,14,The Broncos' defense ranked first in the NFL y...,[Brandon Marshall],[458],[474]
4,56bebc383aeaaa14008c9321,What yard marker on the field was painted gold?,Super_Bowl_50,25,Various gold-themed promotions and initiatives...,[50],[232],[234]
...,...,...,...,...,...,...,...,...
177,57379829c3c5551400e51f3d,What does the W and Z boson exchange create?,Force,2056,The weak force is due to the exchange of the h...,[weak force],[4],[14]
178,57379a4b1c456719005744d0,What is the force that causes rigid strength i...,Force,2057,The normal force is due to repulsive forces of...,[normal],[298],[304]
179,57379ed81c456719005744d7,In what way do idea strings transmit tesion fo...,Force,2058,Tension forces can be modeled using ideal stri...,[instantaneously in action-reaction pairs],[250],[290]
180,5737a0acc3c5551400e51f4a,What may a force on one part of an object affect?,Force,2059,Newton's laws and Newtonian mechanics in gener...,[other parts],[276],[287]


## Baseline model

In [60]:
baseline_optimizer = optim.Adam(baseline_model.parameters(), lr=1e-3)
baseline_lr_scheduler = transformers.get_constant_schedule(baseline_optimizer)

In [61]:
%env WANDB_RUN_GROUP=baseline
baseline_run_name = utils.get_run_name()
baseline_args = partial(
    TRAINER_ARGS,
    output_dir=f"./checkpoints/{os.getenv('WANDB_RUN_GROUP')}/{baseline_run_name}",
    num_train_epochs=30,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
)

env: WANDB_RUN_GROUP=baseline


### Training and validation

In [41]:
baseline_model = model.QABaselineModel(
    glove_embedding_layer, MAX_CONTEXT_TOKENS, device=DEVICE
)
print(f"The baseline model has {baseline_model.count_parameters()} parameters")

The baseline model has 41000 parameters


In [43]:
baseline_trainer = training.SquadTrainer(
    model=baseline_model,
    args=baseline_args(run_name=baseline_run_name),
    data_collator=standard_dm.tokenizer,
    train_dataset=standard_dm.train_dataset,
    eval_dataset=standard_dm.val_dataset,
    optimizers=(baseline_optimizer, baseline_lr_scheduler),
)

env: WANDB_RUN_GROUP=baseline


In [44]:
baseline_trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1,Em,Runtime,Samples Per Second
1,7.321,7.646826,0.048099,0.061143,0.092297,0.065407,0.005814,1.7094,100.619
2,7.4226,7.61481,0.048099,0.061143,0.092297,0.065407,0.005814,1.7229,99.829
3,7.424,7.547461,0.048099,0.061143,0.092297,0.065407,0.005814,2.7028,63.638
4,7.3784,7.057219,0.031507,0.038092,0.125975,0.046912,0.0,1.7108,100.537
5,7.0541,6.916319,0.018038,0.032375,0.085083,0.026914,0.0,1.7301,99.418


TrainOutput(global_step=30, training_loss=7.2997305234273275, metrics={'train_runtime': 68.643, 'train_samples_per_second': 0.437, 'total_flos': 0, 'epoch': 5.0})

### Training only

In [62]:
baseline_model = model.QABaselineModel(
    glove_embedding_layer, MAX_CONTEXT_TOKENS, device=DEVICE
)
print(f"The baseline model has {baseline_model.count_parameters()} parameters")

The baseline model has 41000 parameters


In [63]:
baseline_trainer = training.SquadTrainer(
    model=baseline_model,
    args=baseline_args(run_name=f"{baseline_run_name}-whole", evaluation_strategy="no"),
    data_collator=standard_dm.tokenizer,
    train_dataset=standard_dm.whole_dataset,
    optimizers=(baseline_optimizer, baseline_lr_scheduler),
)

In [64]:
baseline_trainer.train()

Step,Training Loss
1,5.9843
5,6.3311
10,6.1952
15,6.2592
20,6.2147
25,6.1515
30,6.2805
35,6.2771


TrainOutput(global_step=35, training_loss=6.234272003173828, metrics={'train_runtime': 79.2375, 'train_samples_per_second': 0.442, 'total_flos': 0, 'epoch': 5.0})

### Testing

In [65]:
baseline_test_output = baseline_trainer.predict(standard_dm.test_dataset)
baseline_test_output.metrics

{'test_loss': 6.0533928871154785,
 'test_accuracy': 0.048774934275635486,
 'test_precision': 0.04997091627749208,
 'test_recall': 0.539648141845944,
 'test_f1': 0.08496900937695989,
 'test_em': 0.0,
 'test_runtime': 1.8771,
 'test_samples_per_second': 96.959}

In [66]:
baseline_answers_path = "results/answers/baseline.json"
utils.save_answers(baseline_answers_path, baseline_test_output.predictions[-1])
wandb.save(baseline_answers_path);
wandb.finish()

## BiDAF

In [42]:
bidaf_optimizer = optim.Adadelta(bidaf_model.parameters(), lr=0.5)
bidaf_lr_scheduler = transformers.get_constant_schedule(bidaf_optimizer)

In [43]:
%env WANDB_RUN_GROUP=bidaf
bidaf_run_name = utils.get_run_name()
bidaf_args = partial(
    TRAINER_ARGS,
    output_dir=f"./checkpoints/{os.getenv('WANDB_RUN_GROUP')}/{bidaf_run_name}",
    num_train_epochs=12,
    per_device_train_batch_size=60,
    per_device_eval_batch_size=60,
)

env: WANDB_RUN_GROUP=bidaf


### Training and validation

In [41]:
bidaf_model = model.QABiDAFModel(glove_embedding_layer, device=DEVICE)
print(f"The BiDAF model has {bidaf_model.count_parameters()} parameters")

The BiDAF model has 79675 parameters


In [None]:
bidaf_trainer = training.SquadTrainer(
    model=bidaf_model,
    args=bidaf_args(run_name=bidaf_run_name),
    data_collator=standard_dm.tokenizer,
    train_dataset=standard_dm.train_dataset,
    eval_dataset=standard_dm.val_dataset,
    optimizers=(bidaf_optimizer, bidaf_lr_scheduler),
)

In [44]:
bidaf_trainer.train()

Epoch,Training Loss,Validation Loss,F1,Accuracy,Em,Runtime,Samples Per Second
1,6.6194,5.197407,0.0,0.0,0.0,0.3057,6.542
2,6.6194,5.198681,0.0,0.0,0.0,0.3101,6.45
3,6.6194,6.93373,0.0,0.0,0.0,0.3094,6.464
4,6.6194,6.934238,0.0,0.0,0.0,0.3092,6.469
5,5.9164,6.934334,0.0,0.0,0.0,0.3059,6.538
6,5.9164,6.93491,0.0,0.0,0.0,0.3039,6.581
7,5.9164,6.9356,0.0,0.0,0.0,0.3126,6.397
8,5.9164,6.935662,0.0,0.0,0.0,0.3026,6.61
9,5.9164,6.935297,0.0,0.0,0.0,0.3155,6.339
10,6.6018,6.632412,0.0,0.0,0.0,0.3157,6.335


TrainOutput(global_step=12, training_loss=6.486756801605225, metrics={'train_runtime': 26.8579, 'train_samples_per_second': 0.447, 'total_flos': 0, 'epoch': 12.0})

### Training only

In [None]:
bidaf_model = model.QABiDAFModel(glove_embedding_layer, device=DEVICE)
print(f"The BiDAF model has {bidaf_model.count_parameters()} parameters")

In [None]:
bidaf_trainer = training.SquadTrainer(
    model=bidaf_model,
    args=bidaf_args(run_name=f"{bidaf_run_name}-whole", evaluation_strategy="no"),
    data_collator=standard_dm.tokenizer,
    train_dataset=standard_dm.whole_dataset,
    optimizers=(bidaf_optimizer, bidaf_lr_scheduler),
)

In [None]:
bidaf_trainer.train()

### Testing

In [45]:
bidaf_test_output = bidaf_trainer.predict(standard_dm.test_dataset)
bidaf_test_output.metrics

{'test_loss': 8.894977569580078,
 'test_f1': 0.0,
 'test_accuracy': 0.0,
 'test_em': 0.0,
 'test_runtime': 0.2739,
 'test_samples_per_second': 3.651}

In [46]:
bidaf_answers_path = "results/answers/bidaf.json"
utils.save_answers(bidaf_answers_path, bidaf_test_output.predictions[-1])
wandb.save(bidaf_answers_path);
wandb.finish()

## Transformers data loading

In [14]:
MAX_BERT_TOKENS = 512

In [15]:
bert_tokenizer = tokenizer.get_bert_tokenizer(max_tokens=MAX_BERT_TOKENS, device=DEVICE)

In [16]:
bert_dm = dataset.SquadDataManager(squad_dataset, bert_tokenizer, device=DEVICE)

In [17]:
bert_dm.train_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,56cdd28562d2951400fa68bd,Who does M fight with?,Spectre_(2015_film),470,Bond and Swann return to London where they mee...,[C],[105],[106]
1,56de4d9ecffd8e1900b4b7e2,What year was the Banská Akadémia founded?,Institute_of_technology,1860,The world's first institution of technology or...,[1735],[167],[171]
2,572674a05951b619008f7319,What is another speed that can also be reporte...,Film_speed,9354,The standard specifies how speed ratings shoul...,[SOS-based speed],[793],[808]
3,5726ef98708984140094d66e,What conferences became a requirement after Va...,Pope_Paul_VI,10862,Some critiqued Paul VI's decision; the newly c...,[National Bishop Conferences],[347],[374]
4,572843ce4b864d190016485c,What was the purpose of top secret ICBM commit...,John_von_Neumann,11497,"Shortly before his death, when he was already ...",[decide on the feasibility of building an ICBM...,[194],[284]
5,5730bb058ab72b1400f9c72c,Where were the use of advanced materials and t...,Sumer,17505,The most impressive and famous of Sumerian bui...,[Sumerian temples and palaces],[421],[449]


In [18]:
bert_dm.val_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,570e1a2a0dc6ce1900204dbf,How many species of fungi have been found on A...,Antarctica,6902,About 1150 species of fungi have been recorded...,[1150],[6],[10]
1,572781a5f1498d1400e8fa1f,Who is elected every even numbered year?,"Ann_Arbor,_Michigan",10585,Ann Arbor has a council-manager form of govern...,[mayor],[192],[197]


In [19]:
bert_dm.test_df

Unnamed: 0,question_id,question,title,context_id,context,answer,answer_start,answer_end
0,57379ed81c456719005744d7,In what way do idea strings transmit tesion fo...,Force,2058,Tension forces can be modeled using ideal stri...,[instantaneously in action-reaction pairs],[250],[290]


## BERT

In [21]:
bert_optimizer = optim.Adam(bert_model.parameters(), lr=5e-5)
bert_lr_scheduler = transformers.get_constant_schedule(bert_optimizer)

In [22]:
%env WANDB_RUN_GROUP=bert
bert_run_name = utils.get_run_name()
bert_args = partial(
    TRAINER_ARGS,
    output_dir=f"./checkpoints/{os.getenv('WANDB_RUN_GROUP')}/{bert_run_name}",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
)

env: WANDB_RUN_GROUP=bert


### Training and validation

In [20]:
bert_model = model.QABertModel(device=DEVICE)
print(f"The BERT model has {bert_model.count_parameters()} parameters")

In [None]:
bert_trainer = training.SquadTrainer(
    model=bert_model,
    args=bert_args(run_name=bert_run_name),
    data_collator=bert_dm.tokenizer,
    train_dataset=bert_dm.train_dataset,
    eval_dataset=bert_dm.val_dataset,
    optimizers=(bert_optimizer, bert_lr_scheduler),
)

In [23]:
bert_trainer.train()

Epoch,Training Loss,Validation Loss,F1,Accuracy,Em,Runtime,Samples Per Second
1,6.5272,6.653205,0.166667,0.25,0.0,1.0424,1.919
2,6.5272,7.183872,0.166667,0.25,0.0,1.0671,1.874
3,6.5272,7.202362,0.166667,0.25,0.0,0.9527,2.099


TrainOutput(global_step=3, training_loss=5.523826281229655, metrics={'train_runtime': 45.0591, 'train_samples_per_second': 0.067, 'total_flos': 0, 'epoch': 3.0})

### Training only

In [None]:
bert_model = model.QABertModel(device=DEVICE)
print(f"The BERT model has {bert_model.count_parameters()} parameters")

In [None]:
bert_trainer = training.SquadTrainer(
    model=bert_model,
    args=bert_args(run_name=f"{bert_run_name}-whole", evaluation_strategy="no"),
    data_collator=bert_dm.tokenizer,
    train_dataset=bert_dm.whole_dataset,
    optimizers=(bert_optimizer, bert_lr_scheduler),
)

In [None]:
bert_trainer.train()

### Testing

In [24]:
bert_test_output = bert_trainer.predict(bert_dm.test_dataset)
bert_test_output.metrics

{'test_loss': 9.000775337219238,
 'test_f1': 0.0,
 'test_accuracy': 0.0,
 'test_em': 0.0,
 'test_runtime': 0.6044,
 'test_samples_per_second': 1.655}

In [25]:
bert_answers_path = "results/answers/bert.json"
utils.save_answers(bert_answers_path, bert_test_output.predictions[-1])
wandb.save(bert_answers_path);
wandb.finish()