In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pickle

from transformers import (
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed
)

from dataclasses import dataclass, field
from transformers import DistilBertForQuestionAnswering
from transformers import BertConfig
import os
import logging

In [3]:
logger = logging.getLogger(__name__)

## Load questions

In [4]:
with open(r"answers.pkl", "rb") as input_file:
    question_answers = pickle.load(input_file)

In [5]:
assert len(question_answers) > 5000

In [6]:
question_answers[:10]

[('when is the last episode of season 8 of the walking dead?',
  'March 18 , 2018',
  {'answer_start': 193, 'answer_end': 208},
  "109 10 `` The Lost and the Plunderers '' TBA TBA March 4 , 2018 ( 2018 - 03 - 04 ) TBD 110 11 `` Dead or Alive Or '' TBA TBA March 11 , 2018 ( 2018 - 03 - 11 ) TBD 111 12 `` The Key '' TBA TBA March 18 , 2018 ( 2018 - 03 - 18 ) TBD Release ( edit ) The first trailer for the season was released on July 21 , 2017 , at San Diego Comic - Con"),
 ('in greek mythology who was the goddess of spring growth?',
  "Persephone ( / pərˈsɛfəni / ; Greek : Περσεφόνη ) , also called Kore ( / ˈkɔːriː / ; `` the maiden '' )",
  {'answer_start': 1437, 'answer_end': 1540},
  "Part of a series on Ancient Greek religion Features ( show ) Greek mythology Ancient Greek philosophy Hellenistic philosophy Ancient Greek religion Polytheism Henosis Monism Pantheism Orthopraxy Godheads ( show ) Olympians Aphrodite Apollo Ares Artemis Athena Demeter Dionysus Hades Hephaestus Hera Hermes 

### Split train / valid sets

In [7]:
val_size = len(question_answers) // 10

In [8]:
train_data = question_answers[:-val_size]
valid_data = question_answers[-val_size:]

In [9]:
assert(len(train_data) + len(valid_data) == len(question_answers))

In [10]:
train_data[:1]

[('when is the last episode of season 8 of the walking dead?',
  'March 18 , 2018',
  {'answer_start': 193, 'answer_end': 208},
  "109 10 `` The Lost and the Plunderers '' TBA TBA March 4 , 2018 ( 2018 - 03 - 04 ) TBD 110 11 `` Dead or Alive Or '' TBA TBA March 11 , 2018 ( 2018 - 03 - 11 ) TBD 111 12 `` The Key '' TBA TBA March 18 , 2018 ( 2018 - 03 - 18 ) TBD Release ( edit ) The first trailer for the season was released on July 21 , 2017 , at San Diego Comic - Con")]

### Prepare Tokenizers

In [11]:
def transform_dataframe(data):
    return [d[0] for d in data], [d[2] for d in data], [d[3] for d in data]

In [12]:
train_questions, train_answers, train_contexts = transform_dataframe(train_data)
valid_questions, valid_answers, valid_contexts = transform_dataframe(valid_data)

In [13]:
train_answers[:5]

[{'answer_start': 193, 'answer_end': 208},
 {'answer_start': 1437, 'answer_end': 1540},
 {'answer_start': 128, 'answer_end': 146},
 {'answer_start': 86, 'answer_end': 97},
 {'answer_start': 138, 'answer_end': 142}]

In [14]:
train_contexts[:2]

["109 10 `` The Lost and the Plunderers '' TBA TBA March 4 , 2018 ( 2018 - 03 - 04 ) TBD 110 11 `` Dead or Alive Or '' TBA TBA March 11 , 2018 ( 2018 - 03 - 11 ) TBD 111 12 `` The Key '' TBA TBA March 18 , 2018 ( 2018 - 03 - 18 ) TBD Release ( edit ) The first trailer for the season was released on July 21 , 2017 , at San Diego Comic - Con",
 "Part of a series on Ancient Greek religion Features ( show ) Greek mythology Ancient Greek philosophy Hellenistic philosophy Ancient Greek religion Polytheism Henosis Monism Pantheism Orthopraxy Godheads ( show ) Olympians Aphrodite Apollo Ares Artemis Athena Demeter Dionysus Hades Hephaestus Hera Hermes Hestia Poseidon Zeus Primordial deities Aether Aion Ananke Chaos Chronos Erebus Eros Gaia Hemera Nyx Phanes Pontus Thalassa Tartarus Uranus Lesser deities Alpheus Amphitrite Asclepius Bia Circe Deimos Eileithyia Enyo Eos Eris Harmonia Hebe Hecate Helios Heracles Iris Kratos Leto Metis Momus Morpheus Nemesis Nike Pan Persephone Phantasos Phobos Pr

In [15]:
train_questions[:5]

['when is the last episode of season 8 of the walking dead?',
 'in greek mythology who was the goddess of spring growth?',
 'what is the name of the most important jewish text?',
 "what is the name of spain's most famous soccer team?",
 'when was the first robot used in surgery?']

In [16]:
from transformers import DistilBertTokenizerFast
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

train_encodings = tokenizer(train_contexts, train_questions, truncation=True, padding=True)
val_encodings = tokenizer(valid_contexts, valid_questions, truncation=True, padding=True)

In [17]:
def add_token_positions(encodings, answers):
    start_positions = []
    end_positions = []
    for i in range(len(answers)):
        start_positions.append(encodings.char_to_token(i, answers[i]['answer_start']))
        end_positions.append(encodings.char_to_token(i, answers[i]['answer_end'] - 1))
        # if None, the answer passage has been truncated
        if start_positions[-1] is None:
            start_positions[-1] = tokenizer.model_max_length
        if end_positions[-1] is None:
            end_positions[-1] = tokenizer.model_max_length
    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})

add_token_positions(train_encodings, train_answers)
add_token_positions(val_encodings, valid_answers)

### Prepare dataset

In [18]:
PRE_TRAINED_MODEL_NAME = "distilbert-base-uncased"

In [19]:
import torch

class NaturalQuestionsDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

train_dataset = NaturalQuestionsDataset(train_encodings)
val_dataset = NaturalQuestionsDataset(val_encodings)

In [20]:
config = BertConfig.from_pretrained( PRE_TRAINED_MODEL_NAME, output_hidden_states=True)
model = DistilBertForQuestionAnswering.from_pretrained(PRE_TRAINED_MODEL_NAME)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForQuestionAnswering: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this mode

### Prepare Training

In [21]:
EPOCHS=1
MODEL_PATH = f'./google_nq_{PRE_TRAINED_MODEL_NAME}'

In [22]:
training_args = TrainingArguments(
    output_dir=MODEL_PATH,
    overwrite_output_dir=True,
    do_train=True,
    do_eval=True,
    per_gpu_train_batch_size=8,
    per_gpu_eval_batch_size=8,
    num_train_epochs=EPOCHS,
    logging_first_step=True,
    save_steps=5000,
    evaluate_during_training=True,
    fp16=True
)

In [23]:
set_seed(42)

In [24]:
from nlp import load_metric

metric = load_metric("squad_v2")

In [25]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=metric,
    prediction_loss_only=True
)

[34m[1mwandb[0m: Wandb version 0.9.5 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [26]:
eval_data_loader = trainer.get_eval_dataloader()

Using deprecated `--per_gpu_eval_batch_size` argument which will be removed in a future version. Using `--per_device_eval_batch_size` is preferred.


In [27]:
%%time
trainer.train()

Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future version. Using `--per_device_train_batch_size` is preferred.
Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future version. Using `--per_device_train_batch_size` is preferred.


Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


HBox(children=(FloatProgress(value=0.0, description='Epoch', max=1.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2938.0, style=ProgressStyle(description_w…

[34m[1mwandb[0m: Wandb version 0.9.5 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 8192.0


Using deprecated `--per_gpu_eval_batch_size` argument which will be removed in a future version. Using `--per_device_eval_batch_size` is preferred.


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=327.0, style=ProgressStyle(description_w…




Using deprecated `--per_gpu_eval_batch_size` argument which will be removed in a future version. Using `--per_device_eval_batch_size` is preferred.


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=327.0, style=ProgressStyle(description_w…


Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 8192.0


CPU times: user 6min 45s, sys: 1.71 s, total: 6min 47s
Wall time: 6min 44s


TrainOutput(global_step=2938, training_loss=1.872358276445984)

### Evaluate

In [28]:
vl_dl = trainer.get_eval_dataloader(val_dataset)

Using deprecated `--per_gpu_eval_batch_size` argument which will be removed in a future version. Using `--per_device_eval_batch_size` is preferred.


In [29]:
sample_tensor = (next(iter(vl_dl)))
sample_tensor.keys()

dict_keys(['input_ids', 'attention_mask', 'start_positions', 'end_positions'])

In [30]:
result = trainer.evaluate()

output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
if trainer.is_world_master():
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results *****")
        for key, value in result.items():
            logger.info("  %s = %s", key, value)
            writer.write("%s = %s\n" % (key, value))

Using deprecated `--per_gpu_eval_batch_size` argument which will be removed in a future version. Using `--per_device_eval_batch_size` is preferred.


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=327.0, style=ProgressStyle(description_w…




#### Save the model

In [31]:
trainer.save_model()
# For convenience, we also re-save the tokenizer to the same directory,
# so that you can share your model easily on huggingface.co/models =)
if trainer.is_world_master():
    tokenizer.save_pretrained(training_args.output_dir)