# Answer Extraction Pipeline for TriviaQA

This notebook contains the dataset preprocessing, training, and evaluation code for the extractive question-answering task on TriviaQA (Wikipedia evidence) using BERT.

In [None]:
!nvidia-smi

Mon Jun  7 23:25:38 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 466.27       Driver Version: 466.27       CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ... WDDM  | 00000000:01:00.0  On |                  N/A |
| 18%   32C    P8    18W / 250W |   1223MiB / 11264MiB |      1%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Installs and Imports

In [None]:
!pip install transformers
!pip install datasets



In [None]:
!pip install pickle5
import pickle5 as pickle



In [None]:
import numpy as np
import os
import re
import random

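## Run in Google Drive

Run this cell only when working in Google Colab. It sets the same `data_folder` and `models_folder` variables as the *Run Locally* cell below, so execute only one of the two cells.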

In [None]:
# Colab only: mount Google Drive and point the data/model folders at the project folder on it.
# NOTE(review): the "Run Locally" cell assigns data_folder and models_folder as well,
# so whichever of the two cells is executed last determines the paths.
from google.colab import drive

drive.mount('/content/drive', force_remount=True)  

# Name of the project folder inside "My Drive"
ir_project_drive_folder = "IR Projekt"
data_folder = "/content/drive/My Drive/{}/data/wikipedia".format(ir_project_drive_folder)
models_folder = '/content/drive/My Drive/{}/saved_models'.format(ir_project_drive_folder)

## Run Locally

In [None]:
# Local setup: data and models live in the parent directory of the working directory.
# NOTE(review): this overwrites data_folder/models_folder from the Colab cell if both are run.
base_folder = '..'
data_folder = base_folder + '/data'
models_folder = base_folder + '/models'

## Files and Folders

In [None]:
# All paths below are derived from data_folder / models_folder chosen in one of the setup cells.
raw_folder = data_folder + '/raw'
# NOTE(review): preprocessed_folder is not referenced in the visible part of this notebook.
preprocessed_folder = data_folder + '/no-pron/preprocessed'

# parent folder of the fine-tuned question answering models (see model_directory)
question_answering_models_folder = models_folder + '/question_answering_task'

# qa filenames (TriviaQA question/answer json files)
qa_wikipedia_verified_dev_filename = raw_folder + '/qa/verified-wikipedia-dev.json'
qa_wikipedia_dev_filename = raw_folder + '/qa/wikipedia-dev.json'
qa_wikipedia_test_without_answers_filename = raw_folder + '/qa/wikipedia-test-without-answers.json'
qa_wikipedia_train_filename = raw_folder + '/qa/wikipedia-train.json'

# evidence files (pickled dict that is loaded into documents_dict)
wikipedia_evidence_file = raw_folder + '/wikipedia_evidence_dict.pkl'

In [None]:
def save_as_pickle(obj, filename):
    """
    save an object in a pickle file dump

    The parent directory of ``filename`` is created when it does not exist yet.
    A bare filename without a directory part is written to the current working directory.

    :param obj: object to dump
    :param filename: target file
    :return:
    """
    directory = os.path.dirname(filename)
    # os.path.dirname returns '' for a bare filename and os.makedirs('') raises,
    # so only create a directory when the path actually has one.
    if directory:
        # exist_ok avoids failing when the directory appears between a check and the creation
        os.makedirs(directory, exist_ok=True)

    with open(filename, 'wb') as file:
        pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL)


def load_pickle(filename):
    """
    read a pickle file dump and return the object stored in it

    NOTE: unpickling can execute arbitrary code, only load files from a trusted source.

    :param filename: source file
    :return: loaded object
    """
    with open(filename, 'rb') as pickle_file:
        loaded_object = pickle.load(pickle_file)
    return loaded_object

## Load Data

In [None]:
# Evidence data
# documents_dict is the lookup used by read_trivia_entries: cleaned document filename -> context.

documents_dict = load_pickle(wikipedia_evidence_file)
print(len(documents_dict))


73930


# Training & Evaluation 

In [None]:
from transformers import BertTokenizerFast
# Tokenizer of the same 'bert-base-uncased' checkpoint that the model is initialised from.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466062.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…




In [None]:
# Feature preparation settings.
# NOTE(review): apart from version_2_with_negative and negative_to_positive, these names are
# not read in the visible cells; presumably QuestionAnsweringTrainerFramework.py (run with -i
# in this namespace) consumes them — TODO confirm.
pad_to_max_length = True
pad_on_right = False # tokenizer.padding_side == "right"  # Padding side determines if we do (question|context) or (context|question).
doc_stride = 128
# use the maximum input length reported by the tokenizer
max_seq_length = tokenizer.model_max_length

# Postprocessing Config
# version_2_with_negative also selects the metric ("squad_v2" vs. "squad")
version_2_with_negative = True
n_best_size = 20
max_answer_length = 50
null_score_diff_threshold = 0

# Appears in dataset_name / model_name as the number of negatives per positive.
negative_to_positive = 1

## Import framework code adapted from [Huggingface Transformers Question Answering](https://github.com/huggingface/transformers/tree/v4.6.1/examples/pytorch/question-answering)

In [None]:
%run -i -n QuestionAnsweringTrainerFramework.py

## Utils

In [None]:
def read_trivia_entries(groups):
  """
  Convert a batch of TriviaQA entries into flat question/context examples.

  For every question, each entity page whose sanitized filename is a key of the
  module-level ``documents_dict`` yields one example, provided that at least one of the
  answer strings (value, aliases and, if present, human answers) occurs verbatim in the
  document. Pages with an unknown document or without any answer match are skipped.

  :param groups: batch as a dict of parallel sequences with the keys 'Answer',
                 'Question', 'QuestionId' and 'EntityPages'
  :return: dict of parallel lists with the keys 'id', 'context', 'question' and 'answer';
           every answer is {'text': [...], 'answer_start': [...]} with the matched answer
           strings and the character offset of their first occurrence in the context
  """
  result_ids = []
  result_contexts = []
  result_questions = []
  result_answers = []

  entries = zip(groups['Answer'], groups['Question'], groups['QuestionId'], groups['EntityPages'])
  for answer, question, question_id, entity_pages in entries:
    answer_texts = [answer['Value']] + answer['Aliases']
    if 'HumanAnswers' in answer and answer['HumanAnswers'] is not None:
      answer_texts += answer['HumanAnswers']

    for entity_page in entity_pages:
      document_name = entity_page['Filename']
      # Raw string: "\?" and "\*" are invalid escape sequences in a normal string literal.
      # Inside a character class the characters need no escaping at all.
      document_name_clean = re.sub(r'[?:*"]', "_", document_name)

      if document_name_clean not in documents_dict:
        continue

      context = documents_dict[document_name_clean]

      found_answer_texts = []
      found_answer_starts = []
      for answer_text in answer_texts:
        # str.find returns -1 when the answer does not occur in the context
        answer_start = context.find(answer_text)
        if answer_start < 0:
          continue
        found_answer_texts.append(answer_text)
        found_answer_starts.append(answer_start)

      if not found_answer_texts:
        continue

      # the id keeps the original (unsanitized) document name
      example_id = question_id + '__' + document_name

      result_ids.append(example_id)
      result_contexts.append(context)
      result_questions.append(question)
      result_answers.append({'text': found_answer_texts, 'answer_start': found_answer_starts})

  return {"id": result_ids, "context": result_contexts, "question": result_questions, "answer": result_answers}

In [None]:
from datasets import load_dataset

def create_qa_dataset(path):
  """
  Load a TriviaQA json file and turn its entries into question/context examples.

  :param path: json file whose 'Data' field holds the entries
  :return: result of mapping read_trivia_entries over the loaded entries in batches of 32,
           with the originally loaded columns removed
  """
  loaded_splits = load_dataset('json', data_files=[path], field='Data', keep_in_memory=False)
  raw_entries = loaded_splits["train"]
  original_columns = raw_entries.column_names
  return raw_entries.map(
      read_trivia_entries,
      batched=True,
      remove_columns=original_columns,
      batch_size=32
  )

## Training

### Training Data

In [None]:
# Both names embed negative_to_positive and are used as sub-paths of the dataset and model folders.
dataset_name = "bert/1positive-to-" + str(negative_to_positive) + "negative"
model_name = "bert/full_wikipedia-1to" + str(negative_to_positive) 

In [None]:
# For each split two datasets are stored: "base" (question/context examples) and
# "encodings" (the features produced from them by the prepare_*_features functions).
train_base_dataset_directory = data_folder + "/qa-datasets/" + dataset_name + "/train/base"
train_encodings_dataset_directory = data_folder + "/qa-datasets/" + dataset_name + "/train/encodings"

val_verified_base_dataset_directory = data_folder + "/qa-datasets/" + dataset_name + "/val_verified/base"
val_verified_encodings_dataset_directory = data_folder + "/qa-datasets/" + dataset_name + "/val_verified/encodings"

# output directory of the trainer (checkpoints and logs)
model_directory = question_answering_models_folder + '/' + model_name

#### Create Training Dataset and Store it

In [None]:
# Build the question/context examples from the TriviaQA training file ...
train_base_dataset = create_qa_dataset(qa_wikipedia_train_filename)

# ... and convert them into model features, replacing the example columns.
# NOTE(review): prepare_train_features is not defined in this notebook; presumably it comes
# from QuestionAnsweringTrainerFramework.py — TODO confirm.
train_dataset = train_base_dataset.map(
    prepare_train_features,
    batched = True,
    remove_columns = train_base_dataset.column_names,
    batch_size = 32
)

In [None]:
# Store examples and features on disk so that they can be loaded instead of being recreated.
train_base_dataset.save_to_disk(train_base_dataset_directory)
train_dataset.save_to_disk(train_encodings_dataset_directory)

#### Load existing Training Dataset

In [None]:
from datasets import Dataset

# Alternative to creating the training data: load what an earlier run stored on disk.
train_base_dataset = Dataset.load_from_disk(train_base_dataset_directory)
train_dataset = Dataset.load_from_disk(train_encodings_dataset_directory)

### Evaluation Data

#### Create Evaluation Dataset and Store it

In [None]:
# Build the question/context examples from the verified Wikipedia dev file.
val_verified_base_dataset = create_qa_dataset(qa_wikipedia_verified_dev_filename)

# number of evaluation examples
print(len(val_verified_base_dataset))

# Convert the examples into evaluation features, replacing the example columns.
# NOTE(review): prepare_validation_features is not defined in this notebook; presumably it
# comes from QuestionAnsweringTrainerFramework.py — TODO confirm.
val_verified_dataset = val_verified_base_dataset.map(
    prepare_validation_features,
    batched = True,
    remove_columns = val_verified_base_dataset.column_names,
    batch_size = 32
)

In [None]:
# Store examples and features on disk so that they can be loaded instead of being recreated.
val_verified_base_dataset.save_to_disk(val_verified_base_dataset_directory)
val_verified_dataset.save_to_disk(val_verified_encodings_dataset_directory)

#### Load existing Evaluation Dataset

In [None]:
from datasets import Dataset

# Alternative to creating the evaluation data: load what an earlier run stored on disk.
val_verified_base_dataset = Dataset.load_from_disk(val_verified_base_dataset_directory)
val_verified_dataset = Dataset.load_from_disk(val_verified_encodings_dataset_directory)

### Create Metrics

In [None]:
from datasets import load_metric
from transformers import EvalPrediction

# Metric matching the configuration: "squad_v2" when version_2_with_negative is set, otherwise "squad".
metric = load_metric("squad_v2" if version_2_with_negative else "squad")

def compute_metrics(p: EvalPrediction):
    """
    Score post-processed predictions with the module-level ``metric``.

    ``metric`` is either "squad_v2" or "squad", depending on ``version_2_with_negative``.
    Both variants are handled: only "squad_v2" takes ``no_answer_threshold`` and reports the
    exact-match score as 'exact', whereas "squad" reports it as 'exact_match'.

    :param p: evaluation result whose ``predictions`` and ``label_ids`` are passed to the metric
    :return: the result dict of the metric plus an 'eval_exact' entry with the exact-match score
    """
    if version_2_with_negative:
        metrics = metric.compute(predictions = p.predictions, references = p.label_ids, no_answer_threshold=.8)
        exact_key = "exact"
    else:
        metrics = metric.compute(predictions = p.predictions, references = p.label_ids)
        exact_key = "exact_match"
    # 'eval_exact' is the name configured as metric_for_best_model in the training arguments
    metrics["eval_exact"] = metrics[exact_key]
    return metrics

In [None]:
from transformers import BertForQuestionAnswering

# Start from the pretrained 'bert-base-uncased' checkpoint (same checkpoint as the tokenizer).
model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=570.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForQuestionAnswering: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased a

In [None]:
from transformers import TrainingArguments

# Training configuration: evaluation runs every 1000 steps and 'eval_exact' (the key that
# compute_metrics adds) is the metric used to pick the best model at the end.
# NOTE(review): save_steps is commented out, so the library default applies — confirm that
# it lines up with eval_steps so that the best evaluated step has a saved checkpoint.
training_args = TrainingArguments(
    output_dir = model_directory,# output directory
    num_train_epochs = 5,              # total number of training epochs
    per_device_train_batch_size = 8,  # batch size per device during training
    per_device_eval_batch_size = 32,   # batch size for evaluation
    warmup_steps = 500,                # number of warmup steps for learning rate scheduler
    weight_decay = 0.01,               # strength of weight decay
    logging_dir = model_directory + '/logs',            # directory for storing logs
    logging_steps = 50,
 #   save_steps = 1000,
    save_total_limit = 3,
    evaluation_strategy = "steps",
    eval_steps = 1000,
    load_best_model_at_end = True,
    metric_for_best_model = "eval_exact"
)

# NOTE(review): QuestionAnsweringTrainer and post_processing_function are not defined in this
# notebook; presumably they come from QuestionAnsweringTrainerFramework.py — TODO confirm.
# eval_dataset holds the features, eval_examples the examples they were created from.
trainer = QuestionAnsweringTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_verified_dataset,
    eval_examples=val_verified_base_dataset,
    tokenizer=tokenizer,
    post_process_function=post_processing_function,
    compute_metrics=compute_metrics
)

### Run Training

In [None]:
# Resume only when the output directory really holds a checkpoint. Testing for a non-empty
# directory is not sufficient: os.listdir raises when the directory does not exist, and a
# directory that only contains e.g. logs has no checkpoint to resume from.
from transformers.trainer_utils import get_last_checkpoint

last_checkpoint = get_last_checkpoint(model_directory) if os.path.isdir(model_directory) else None
continue_existing_training = last_checkpoint is not None

if continue_existing_training:
  print("Checkpoint directory exists. Continuing existing model...")
  train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
else:
  print("No existing checkpoint found. Starting fresh...")
  train_result = trainer.train()

No existing checkpoint found. Starting fresh...


Step,Training Loss,Validation Loss,Unnamed: 3,Exact,F1,Total,Exact Thresh,F1 Thresh
1000,2.1375,No log,622,13.344051,16.938468,622,0.0,0.0
2000,1.667,No log,622,30.546624,33.240749,622,0.0,0.0
3000,1.6194,No log,622,32.636656,36.808941,622,0.0,0.0
4000,1.4979,No log,622,32.636656,35.794672,622,0.0,0.0
5000,1.5831,No log,622,35.369775,38.348973,622,0.0,0.0
6000,1.5064,No log,622,38.263666,42.188794,622,0.0,0.0
7000,1.5219,No log,622,39.549839,42.90353,622,0.0,0.0
8000,1.404,No log,622,40.836013,44.905919,622,0.0,0.0
9000,1.4234,No log,622,40.675241,43.657301,622,0.0,0.0
10000,1.3674,No log,622,41.961415,45.730222,622,0.0,0.0


HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=622.0), HTML(value='')))




KeyboardInterrupt: 

### Run Evaluation

In [None]:
from transformers.trainer_utils import get_last_checkpoint

# Evaluate the most recent checkpoint found in model_directory.
last_checkpoint = get_last_checkpoint(model_directory)
if last_checkpoint is None:
    # get_last_checkpoint returns None when the directory holds no checkpoint;
    # fail with a clear message instead of handing None to from_pretrained
    raise FileNotFoundError("No checkpoint found in " + model_directory)

model = BertForQuestionAnswering.from_pretrained(last_checkpoint)

# Only the evaluation batch size matters here, no training is run with this trainer.
training_args = TrainingArguments(
   output_dir = model_directory,# output directory
   per_device_eval_batch_size = 32
)

trainer = QuestionAnsweringTrainer(
    model=model,
    args=training_args,
    eval_dataset=val_verified_dataset,
    eval_examples=val_verified_base_dataset,
    tokenizer=tokenizer,
    post_process_function=post_processing_function,
    compute_metrics=compute_metrics
)

In [None]:
# Evaluate on the verified validation data and show the resulting metrics as cell output.
metrics = trainer.evaluate()
metrics