## **Library Imports**

In [None]:
import pandas as pd
import numpy as np
import math
import string
import re

from collections import namedtuple, Counter
import json

from IPython.display import Markdown, display
from matplotlib import pyplot as plt
import seaborn as sns
from itertools import cycle

from tqdm import tqdm
import time
from os import path
import sys

import nltk
nltk.download('punkt')

if 'transformers' not in sys.modules:
  !{sys.executable} -m pip install transformers

from transformers import BertTokenizerFast, BertForQuestionAnswering, BatchEncoding, get_linear_schedule_with_warmup

import torch
import torch.nn as nn


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
Collecting transformers
  Downloading transformers-4.17.0-py3-none-any.whl (3.8 MB)
[K     |████████████████████████████████| 3.8 MB 3.8 MB/s 
[?25hCollecting sacremoses
  Downloading sacremoses-0.0.47-py2.py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 43.1 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.4.0-py3-none-any.whl (67 kB)
[K     |████████████████████████████████| 67 kB 5.0 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 44.6 MB/s 
Collecting tokenizers!=0.11.3,>=0.11.1
  Downloading tokenizers-0.11.6-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.5 MB)
[K     |████████████████████████████████| 6.5 MB 29.3 MB/s 
Installing collected packa

## **Configurations for using Google Drive and Hugging Face Hub**

All the **datasets** below are **cached/saved in Google Drive**. The (best) **models** are also uploaded to a personal repository in the **Hugging Face Model Hub**. Enabling these options helps with **reproducibility** and **reduces the execution time**.

In [None]:
#@title Google Drive and Hugging Face Model Hub options
# When True, the bash cells below download the already converted dataset archives from
# Google Drive instead of rebuilding them from the raw sources
load_cached_datasets = True #@param {type:"boolean"}
# NOTE(review): not referenced in this part of the notebook; presumably it makes later cells
# load the fine-tuned models from the Hugging Face Model Hub - confirm against its usages
load_cached_models = True #@param {type:"boolean"}

## **Popular QA datasets download and conversion to SQuAD 2.0**

Dataset creation for both the **train** and **dev** sets of SQuAD 2.0 by parsing their **JSON** representations. Each dataset contains the **context text**, the **question text**, the **answer start and end indices of the characters** in the context (they are fixed in case of misalignment) and the **text for every valid answer** (the main answer is considered to be the first one and is separated from the other answers).

In [None]:
# Start and end positions for some answers are not aligned to the real context by one to three characters
def fixAnswerAlignment(context, answer, answer_start_idx, answer_end_idx, max_shift = 3):
  """Re-align an answer span with the place where the answer text really is in the context.

  The span is moved by 0, 1, ..., max_shift characters, trying the left shift before the
  right shift for every distance, until the slice of the context equals the answer.

  Args:
    context: the paragraph text.
    answer: the answer text that should be found in the context.
    answer_start_idx: index of the first character of the answer (possibly misaligned).
    answer_end_idx: index of the last character of the answer, inclusive (possibly misaligned).
    max_shift: maximum number of characters the span may be moved in either direction.

  Returns:
    A (start, end) tuple with inclusive character indices. When no shift inside the window
    matches, the given indices are returned unchanged, so callers can always unpack the result.
  """
  for shift in range(max_shift + 1):
    left_start = answer_start_idx - shift
    right_start = answer_start_idx + shift
    # A negative slice start would silently wrap around to the end of the string and could
    # produce a false match with negative indices, so such candidates are skipped
    if left_start >= 0 and context[left_start:(answer_end_idx + 1 - shift)] == answer:
      return left_start, answer_end_idx - shift
    if right_start >= 0 and context[right_start:(answer_end_idx + 1 + shift)] == answer:
      return right_start, answer_end_idx + shift

  # No alignment found inside the window: keep the span as it was given
  return answer_start_idx, answer_end_idx

def parseSquad(squad_file):
  """Parse a SQuAD 2.0 formatted JSON file into a pandas DataFrame.

  Every question becomes one row with the columns 'context', 'question', 'answer',
  'answer_start_idx', 'answer_end_idx' and 'other_answers'. The first listed answer is the
  main answer; its inclusive character span is re-aligned with fixAnswerAlignment.
  Questions without an answer get answer = None and both indices set to -1.

  Args:
    squad_file: path of the JSON file. When the path contains the substring 'train' the file
      is treated as a training set and the additional answers are not collected.

  Returns:
    The DataFrame described above. Columns that are empty for every row are dropped, which
    is why training sets have no 'other_answers' column.
  """
  with open(squad_file, 'rb') as f:
    squad_json = json.load(f)

  train_mode = 'train' in squad_file
  columns = ['context', 'question', 'answer', 'answer_start_idx', 'answer_end_idx', 'other_answers']
  squad_list = []

  for section in squad_json['data']:
    for paragraph in section['paragraphs']:
      context = paragraph['context'].strip()
      for qa in paragraph['qas']:
        question = qa['question'].strip()
        # Entries without the 'is_impossible' key (SQuAD 1.1 style) are treated as answerable
        answers = qa['answers'] if not qa.get('is_impossible', False) else None
        if not answers:
          squad_list.append({'context': context, 'question': question, 'answer': None, 'answer_start_idx': -1, 'answer_end_idx': -1, 'other_answers': None})
          continue

        # Always the first answer is selected as the main answer
        answer = answers[0]['text'].strip()
        answer_start_idx = answers[0]['answer_start']
        answer_end_idx = answer_start_idx + len(answer) - 1
        aligned_span = fixAnswerAlignment(context, answer, answer_start_idx, answer_end_idx)
        # Keep the computed span when the helper could not find an alignment
        if aligned_span is not None:
          answer_start_idx, answer_end_idx = aligned_span
        # The other answers are also included only for the evaluation/dev dataset and not for the training set
        other_answers = None if train_mode else [extra['text'].strip() for extra in answers[1:]]
        squad_list.append({'context': context, 'question': question, 'answer': answer, 'answer_start_idx': answer_start_idx, 'answer_end_idx': answer_end_idx, 'other_answers': other_answers})

  return pd.DataFrame(squad_list, columns = columns).dropna(axis = 1, how = 'all')


Dataset selection among five QA datasets: **SQuAD2.0**, **TriviaQA**, **NQ**, **QuAC**, **NewsQA**

In [None]:
#@title Datasets for fine tuning
# One switch per dataset: when True the cells below download (or convert) the training split
# of that dataset and parse it into a DataFrame
train_squad = True #@param {type:"boolean"}
train_triviaqa = True  #@param {type:"boolean"}
train_nq = True  #@param {type:"boolean"}
train_quac = True  #@param {type:"boolean"}
train_newsqa = True  #@param {type:"boolean"}

In [None]:
#@title Datasets for evaluation
# One switch per dataset: when True the cells below download (or convert) the dev split
# of that dataset and parse it into a DataFrame
eval_squad = True #@param {type:"boolean"}
eval_triviaqa = True  #@param {type:"boolean"}
eval_nq = True #@param {type:"boolean"}
eval_quac = True #@param {type:"boolean"}
eval_newsqa = True  #@param {type:"boolean"}

At this point the datasets are **created and transformed to have a SQuAD 2.0 format**. The transformation is achieved using the converters that can be found in this GitHub [repository](https://github.com/amazon-research/qa-dataset-converter). The steps are written in a separate **bash script** for each dataset and the commands can be executed either inside the notebook or locally. There is also one more **bash script** for each dataset in order to **load the already transformed datasets** (in json files) from Google Drive.

### **SQuAD 2.0**

In [None]:
%%bash -s "$train_squad" "$eval_squad"
# SQuAD 2.0 download
# $1 = train_squad flag, $2 = eval_squad flag (compared with the string "True")
# A split is fetched only when its flag is set and its JSON file is not already present
if [[ $1 = "True" && ! -f squad2_train.json ]]
      then
        wget -nv https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json -O squad2_train.json
fi
if [[ $2 = "True" && ! -f squad2_dev.json ]]
      then
        wget -nv https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json -O squad2_dev.json
fi

2022-03-12 12:55:10 URL:https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json [42123633/42123633] -> "squad2_train.json" [1]
2022-03-12 12:55:10 URL:https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json [4370528/4370528] -> "squad2_dev.json" [1]


In [None]:
# Parse the SQuAD 2.0 JSON files selected by the flags, preview the first ten rows
# and report the number of questions in each split
if train_squad:
  squad_training_set = parseSquad('squad2_train.json') #.sample(frac = 0.8, random_state = 1).sort_index()
  display(squad_training_set.head(10))
  print(f'SQuAD training set size: {len(squad_training_set)}')
if eval_squad:
  squad_validation_set = parseSquad('squad2_dev.json')
  display(squad_validation_set.head(10))
  print(f'SQuAD validation set size: {len(squad_validation_set)}')

Unnamed: 0,context,question,answer,answer_start_idx,answer_end_idx
0,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,When did Beyonce start becoming popular?,in the late 1990s,269,285
1,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,What areas did Beyonce compete in when she was...,singing and dancing,207,225
2,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,When did Beyonce leave Destiny's Child and bec...,2003,526,529
3,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,In what city and state did Beyonce grow up?,"Houston, Texas",166,179
4,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,In which decade did Beyonce become famous?,late 1990s,276,285
5,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,In what R&B group was she the lead singer?,Destiny's Child,320,334
6,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,What album made her a worldwide known artist?,Dangerously in Love,505,523
7,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,Who managed the Destiny's Child group?,Mathew Knowles,360,373
8,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,When did Beyoncé rise to fame?,late 1990s,276,285
9,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,What role did Beyoncé have in Destiny's Child?,lead singer,290,300


SQuAD training set size: 130319


Unnamed: 0,context,question,answer,answer_start_idx,answer_end_idx,other_answers
0,The Normans (Norman: Nourmands; French: Norman...,In what country is Normandy located?,France,159,164,"[France, France, France]"
1,The Normans (Norman: Nourmands; French: Norman...,When were the Normans in Normandy?,10th and 11th centuries,94,116,"[in the 10th and 11th centuries, 10th and 11th..."
2,The Normans (Norman: Nourmands; French: Norman...,From which countries did the Norse originate?,"Denmark, Iceland and Norway",256,282,"[Denmark, Iceland and Norway, Denmark, Iceland..."
3,The Normans (Norman: Nourmands; French: Norman...,Who was the Norse leader?,Rollo,308,312,"[Rollo, Rollo, Rollo]"
4,The Normans (Norman: Nourmands; French: Norman...,What century did the Normans first gain their ...,10th century,671,682,"[the first half of the 10th century, 10th, 10th]"
5,The Normans (Norman: Nourmands; French: Norman...,Who gave their name to Normandy in the 1000's ...,,-1,-1,
6,The Normans (Norman: Nourmands; French: Norman...,What is France a region of?,,-1,-1,
7,The Normans (Norman: Nourmands; French: Norman...,Who did King Charles III swear fealty to?,,-1,-1,
8,The Normans (Norman: Nourmands; French: Norman...,When did the Frankish identity emerge?,,-1,-1,
9,"The Norman dynasty had a major political, cult...",Who was the duke in the battle of Hastings?,William the Conqueror,1022,1042,"[William the Conqueror, William the Conqueror]"


SQuAD validation set size: 11873


### **TriviaQA to SQuAD 2.0**

In [None]:
%%bash -s "$train_triviaqa" "$eval_triviaqa" "$load_cached_datasets"
# triviaqa from google drive
# $1 = train flag, $2 = eval flag, $3 = load_cached_datasets flag (compared with the string "True")
# Runs only when cached loading is enabled and at least one requested JSON file is missing
if [[ $3 = "True" ]] && ([[ $1 = "True" && ! -f triviaqa_train.json ]] || [[ $2 = "True" && ! -f triviaqa_dev.json ]])
  then
    # The inner wget requests the download page and stores its cookies, sed extracts the
    # "confirm" token from that page, and the outer wget reuses cookies and token to fetch the archive
    wget -nv --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1ZSNKeqmajCoDU54ZbcEMrilxBzkMCKbN' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1ZSNKeqmajCoDU54ZbcEMrilxBzkMCKbN" -O triviaqa_datasets.tar.gz  && rm -rf /tmp/cookies.txt 
    # Extract only the missing splits; --strip-components=1 drops the TriviaQA/ folder of the archive
    if [[ $1 = "True" && ! -f triviaqa_train.json ]]
      then
        tar -zxf triviaqa_datasets.tar.gz TriviaQA/triviaqa_train.json --strip-components=1
    fi
    if [[ $2 = "True" && ! -f triviaqa_dev.json ]]
      then
      tar -zxf triviaqa_datasets.tar.gz TriviaQA/triviaqa_dev.json --strip-components=1
    fi
    rm triviaqa_datasets.tar.gz
fi

2022-03-12 12:55:18 URL:https://doc-00-00-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/doorul15regfgiimmk2nembil14krv00/1647089700000/10578725460718190134/*/1ZSNKeqmajCoDU54ZbcEMrilxBzkMCKbN?e=download [191712585/191712585] -> "triviaqa_datasets.tar.gz" [1]


In [None]:
%%bash -s "$train_triviaqa" "$eval_triviaqa"
# triviaqa from scratch
# $1 = train flag, $2 = eval flag (compared with the string "True")
# Runs only when a requested JSON file is still missing, i.e. when it was not
# already obtained by the previous (Google Drive) cell
if [[ $1 = "True" && ! -f triviaqa_train.json ]] || [[ $2 = "True" && ! -f triviaqa_dev.json ]]
  then
    # The converter repository is shared by all datasets, so it is cloned only once
    if [[ ! -d qa-dataset-converter ]]
      then
        git clone https://github.com/amazon-research/qa-dataset-converter.git
    fi
    # The converter script is copied into the cloned triviaqa repository and executed from there
    git clone https://github.com/mandarjoshi90/triviaqa
    cp qa-dataset-converter/triviaqa/triviaqa_to_squad.py triviaqa
    wget https://nlp.cs.washington.edu/triviaqa/data/triviaqa-rc.tar.gz -P triviaqa
    tar -xf triviaqa/triviaqa-rc.tar.gz -C triviaqa
    if [[ $1 = "True" && ! -f triviaqa_train.json ]]
      then
        python triviaqa/triviaqa_to_squad.py --triviaqa_file triviaqa/qa/wikipedia-train.json --data_dir triviaqa/evidence/wikipedia/ --output_file triviaqa_train.json
    fi
    if [[ $2 = "True" && ! -f triviaqa_dev.json ]]
      then
        python triviaqa/triviaqa_to_squad.py --triviaqa_file triviaqa/qa/wikipedia-dev.json --data_dir triviaqa/evidence/wikipedia/ --output_file triviaqa_dev.json
    fi
    # Remove the working directory with the raw data once the JSON files are written
    rm -rf triviaqa
fi

In [None]:
# Parse the converted TriviaQA JSON files selected by the flags, preview the first ten rows
# and report the number of questions in each split
if train_triviaqa:
  triviaqa_training_set = parseSquad('triviaqa_train.json') #.sample(frac = 0.8, random_state = 1).sort_index()
  display(triviaqa_training_set.head(10))
  print(f'TriviaQA training set size: {len(triviaqa_training_set)}')
if eval_triviaqa:
  triviaqa_validation_set = parseSquad('triviaqa_dev.json')
  display(triviaqa_validation_set.head(10))
  print(f'TriviaQA validation set size: {len(triviaqa_validation_set)}')

Unnamed: 0,context,question,answer,answer_start_idx,answer_end_idx
0,England is a country that is part of the Unite...,Where in England was Dame Judi Dench born?,York,1569,1572
1,"Dame Judith Olivia `` Judi '' Dench , ( born 9...",Where in England was Dame Judi Dench born?,York,2248,2251
2,A nation state is a type of state that conjoin...,From which country did Angola achieve independ...,Portuga,2899,2905
3,"Angola , officially the Republic of Angola ( ;...",From which country did Angola achieve independ...,,-1,-1
4,The Angolan Civil War ( ) was a major civil co...,From which country did Angola achieve independ...,Portuga,199,205
5,"David Soul ( born August 28 , 1943 ) is an Ame...",Which city does David Soul come from?,Chicago,309,315
6,Super Bowl XX was an American football game be...,Who won Super Bowl XX?,Chicago Bears,102,114
7,The ethnic groups in Europe are the focus of E...,Which was the first European country to abolis...,,-1,-1
8,"Capital punishment , also known as the death p...",Which was the first European country to abolis...,,-1,-1
9,Integrated Services Digital Network ( ISDN ) i...,In which country did he widespread use of ISDN...,Japan,4124,4128


TriviaQA training set size: 110647


Unnamed: 0,context,question,answer,answer_start_idx,answer_end_idx,other_answers
0,"Andrew Lloyd Webber , Baron Lloyd-Webber ( bor...",Which Lloyd Webber musical premiered in the US...,,-1,-1,
1,The Prime Minister of the United Kingdom of Gr...,Who was the next British Prime Minister after ...,,-1,-1,
2,"Arthur James Balfour , 1st Earl of Balfour , (...",Who was the next British Prime Minister after ...,,-1,-1,
3,`` Kiss You All Over '' is a 1978 song perform...,Who had a 70s No 1 hit with Kiss You All Over?,Exile,62,66,[]
4,"Kathleen Mary Ferrier , CBE ( 22 April 1912 - ...",What claimed the life of singer Kathleen Ferrier?,Cancer,2488,2493,[]
5,"Lauren Bacall ( , born Betty Joan Perske ; Sep...",Which actress was voted Miss Greenwich Village...,Bacall,7,12,[]
6,"Michael Joseph Jackson ( August 29 , 1958 – Ju...",What was the name of Michael Jackson's autobio...,moonwalk,1508,1515,[]
7,"Tanzania , This approximates the Kiswahili pro...",Which volcano in Tanzania is the highest mount...,Kilimanjaro,478,488,[]
8,"Mount Kilimanjaro , with its three volcanic co...",Which volcano in Tanzania is the highest mount...,Kilimanjaro,6,16,[]
9,The flag of Libya was originally introduced in...,The flag of Libya is a plain rectangle of whic...,green,1411,1415,[]


TriviaQA validation set size: 14229


### **NQ to SQuAD 2.0**

In [None]:
%%bash -s "$train_nq" "$eval_nq" "$load_cached_datasets"
# nq from google drive
# $1 = train flag, $2 = eval flag, $3 = load_cached_datasets flag (compared with the string "True")
# Runs only when cached loading is enabled and at least one requested JSON file is missing
if [[ $3 = "True" ]] && ([[ $1 = "True" && ! -f nq_train.json ]] || [[ $2 = "True" && ! -f nq_dev.json ]])
  then
    # The inner wget requests the download page and stores its cookies, sed extracts the
    # "confirm" token from that page, and the outer wget reuses cookies and token to fetch the archive
    wget -nv --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1fD4Lc4_iDsaU77XZ0VNKnpLj_NNufq0-' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1fD4Lc4_iDsaU77XZ0VNKnpLj_NNufq0-" -O nq_datasets.tar.gz  && rm -rf /tmp/cookies.txt 
    # Extract only the missing splits; --strip-components=1 drops the NQ/ folder of the archive
    if [[ $1 = "True" && ! -f nq_train.json ]]
      then
        tar -zxf nq_datasets.tar.gz NQ/nq_train.json --strip-components=1
    fi
    if [[ $2 = "True" && ! -f nq_dev.json ]]
      then
      tar -zxf nq_datasets.tar.gz NQ/nq_dev.json --strip-components=1
    fi
    rm nq_datasets.tar.gz
fi

2022-03-12 12:55:50 URL:https://doc-10-00-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/sr303b23696aecgcptptv1rgig00f80f/1647089700000/10578725460718190134/*/1fD4Lc4_iDsaU77XZ0VNKnpLj_NNufq0-?e=download [34372360/34372360] -> "nq_datasets.tar.gz" [1]


In [None]:
%%bash -s "$train_nq" "$eval_nq"
# nq from scratch
# $1 = train flag, $2 = eval flag (compared with the string "True")
# Runs only when a requested JSON file is still missing, i.e. when it was not
# already obtained by the previous (Google Drive) cell
if [[ $1 = "True" && ! -f nq_train.json ]] || [[ $2 = "True" && ! -f nq_dev.json ]]
  then
    # The converter repository is shared by all datasets, so it is cloned only once
    if [[ ! -d qa-dataset-converter ]]
      then
        git clone https://github.com/amazon-research/qa-dataset-converter.git
    fi
    # Temporary working directory holding the converter script and the raw data
    mkdir nq
    cp qa-dataset-converter/nq/nq_to_squad.py nq
    if [[ $1 = "True" && ! -f nq_train.json ]]
      then
        # Disabled alternative that copies only the first 41 listed train files:
        #gsutil ls gs://natural_questions/v1.0/train | head -n 41 | gsutil -m cp -I nq/train
        gsutil -m cp -r gs://natural_questions/v1.0/train nq
        python nq/nq_to_squad.py --nq_dir nq/train/ --output_file nq_train.json
    fi
    if [[ $2 = "True" && ! -f nq_dev.json ]]
      then
        # Disabled alternative that copies only the first 5 listed dev files:
        #gsutil ls gs://natural_questions/v1.0/dev | head -n 5 | gsutil -m cp -I nq/dev
        gsutil -m cp -r gs://natural_questions/v1.0/dev nq
        python nq/nq_to_squad.py --nq_dir nq/dev/ --output_file nq_dev.json
    fi
    # Remove the working directory with the raw data once the JSON files are written
    rm -rf nq
fi

In [None]:
# Parse the converted NQ JSON files selected by the flags, preview the first ten rows
# and report the number of questions in each split
if train_nq:
  nq_training_set = parseSquad('nq_train.json') #.sample(frac = 0.8, random_state = 1).sort_index()
  display(nq_training_set.head(10))
  print(f'NQ training set size: {len(nq_training_set)}')
if eval_nq:
  nq_validation_set = parseSquad('nq_dev.json')
  display(nq_validation_set.head(10))
  print(f'NQ validation set size: {len(nq_validation_set)}')

Unnamed: 0,context,question,answer,answer_start_idx,answer_end_idx
0,"""It's Gonna Be Me"" is a song by American boy b...",who sings it's going to be me,NSYNC,50,54
1,It was released in North America on September ...,when does the clown movie it come out,,-1,-1
2,Francis Albert Sinatra (/sɪˈnɑːtrə/; Italian: ...,who sang the song come fly with me,Francis Albert Sinatra,0,21
3,The brightest star is the magnitude 2.2 Alpha ...,what is the brightest star in corona borealis,the magnitude 2.2 Alpha Coronae Borealis,22,61
4,In determining the contours of riparian rights...,who owns the land under this navigable river,state,188,192
5,The photic sneeze reflex (also known as photop...,why do you sneeze if you look at the sun,,-1,-1
6,Red Sea is a direct translation of the Greek E...,where does the red sea get its name,,-1,-1
7,The Battle of Hastings[a] was fought on 14 Oct...,who fought who in the battle of hastings,,-1,-1
8,"At 1,564,116 km2 (603,909 sq mi), Mongolia is ...",where is mongolia located on the world map,,-1,-1
9,Harry Potter and the Goblet of Fire is the fou...,when does harry potter and the goblet of fire ...,,-1,-1


NQ training set size: 110865


Unnamed: 0,context,question,answer,answer_start_idx,answer_end_idx,other_answers
0,Mandalay Bay is a 43-story luxury resort and c...,who is the owner of the mandalay bay in vegas,MGM Resorts International,124,148,[]
1,"A kick-off is used to start each half of play,...",who kicks the ball first to start a football game,awarded to the team that lost the pre-game coi...,127,178,[the team that lost the pre-game coin toss]
2,"In 1865, Everest was given its official Englis...",mount everest how did it get its name,,-1,-1,
3,Players are currently inducted into the Hall o...,who votes in the baseball hall of fame,either the Baseball Writers' Association of Am...,73,162,[the Baseball Writers' Association of America ...
4,Taylor Hayes is a fictional character from the...,who played taylor on the bold and beautiful,Hunter Tylo,112,122,[]
5,An acronym is a word or name formed as an abbr...,what do you call initials that stand for somet...,acronym,3,9,[An acronym]
6,"Louis XIII's successor, Louis XIV, had a great...",who expanded the palace of versailles to its p...,Louis XIV,24,32,[]
7,The direct involvement of the Netherlands in W...,when did holland become involved in world war 2,10 May 1940,101,111,"[15 May 1940, with its invasion by Nazi German..."
8,"In chemistry, Roman numerals are often used to...",when do you use the roman numerals in chemistry,to denote the groups of the periodic table,44,85,[in the IUPAC nomenclature of inorganic chemis...
9,"In 1975, the F-150 was introduced in between t...",when was the first ford f 150 made,1975,3,6,[]


NQ validation set size: 3369


### **QuAC to SQuAD 2.0**

In [None]:
%%bash -s "$train_quac" "$eval_quac" "$load_cached_datasets"
# quac from google drive
# $1 = train flag, $2 = eval flag, $3 = load_cached_datasets flag (compared with the string "True")
# Runs only when cached loading is enabled and at least one requested JSON file is missing
if [[ $3 = "True" ]] && ([[ $1 = "True" && ! -f quac_train.json ]] || [[ $2 = "True" && ! -f quac_dev.json ]])
  then
    # The inner wget requests the download page and stores its cookies, sed extracts the
    # "confirm" token from that page, and the outer wget reuses cookies and token to fetch the archive
    wget -nv --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1aLD34ve_fWNJbjlHfqU9s1CZF9B0J2Ua' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1aLD34ve_fWNJbjlHfqU9s1CZF9B0J2Ua" -O quac_datasets.tar.gz  && rm -rf /tmp/cookies.txt 
    # Extract only the missing splits; --strip-components=1 drops the QuAC/ folder of the archive
    if [[ $1 = "True" && ! -f quac_train.json ]]
      then
        tar -zxf quac_datasets.tar.gz QuAC/quac_train.json --strip-components=1
    fi
    if [[ $2 = "True" && ! -f quac_dev.json ]]
      then
      tar -zxf quac_datasets.tar.gz QuAC/quac_dev.json --strip-components=1
    fi
    rm quac_datasets.tar.gz
fi

2022-03-12 12:56:02 URL:https://doc-0c-00-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/chb7qu4rfgqvitv97p0tr3qfr6fgjath/1647089700000/10578725460718190134/*/1aLD34ve_fWNJbjlHfqU9s1CZF9B0J2Ua?e=download [18228196/18228196] -> "quac_datasets.tar.gz" [1]


In [None]:
%%bash -s "$train_quac" "$eval_quac"
# quac from scratch
# $1 = train flag, $2 = eval flag (compared with the string "True")
# Runs only when a requested JSON file is still missing, i.e. when it was not
# already obtained by the previous (Google Drive) cell
if [[ $1 = "True" && ! -f quac_train.json ]] || [[ $2 = "True" && ! -f quac_dev.json ]]
  then
    # The converter repository is shared by all datasets, so it is cloned only once
    if [[ ! -d qa-dataset-converter ]]
      then
        git clone https://github.com/amazon-research/qa-dataset-converter.git
    fi
    # Temporary working directory holding the converter script and the raw data
    mkdir quac
    cp qa-dataset-converter/quac/quac_to_squad.py quac
    if [[ $1 = "True" && ! -f quac_train.json ]]
      then
        wget https://s3.amazonaws.com/my89public/quac/train_v0.2.json -P quac
        python quac/quac_to_squad.py --quac_file quac/train_v0.2.json --output_file quac_train.json
    fi
    if [[ $2 = "True" && ! -f quac_dev.json ]]
      then
        wget https://s3.amazonaws.com/my89public/quac/val_v0.2.json -P quac
        python quac/quac_to_squad.py --quac_file quac/val_v0.2.json --output_file quac_dev.json
    fi
    # Remove the working directory with the raw data once the JSON files are written
    rm -rf quac
fi

In [None]:
# Parse the converted QuAC JSON files selected by the flags, preview the first ten rows
# and report the number of questions in each split
if train_quac:
  quac_training_set = parseSquad('quac_train.json') #.sample(frac = 0.8, random_state = 1).sort_index()
  display(quac_training_set.head(10))
  print(f'QuAC training set size: {len(quac_training_set)}')
if eval_quac:
  quac_validation_set = parseSquad('quac_dev.json')
  display(quac_validation_set.head(10))
  print(f'QuAC validation set size: {len(quac_validation_set)}')

Unnamed: 0,context,question,answer,answer_start_idx,answer_end_idx
0,"According to the Indian census of 2001, there ...",Where is Malayali located?,"30,803,747 speakers of Malayalam in Kerala, ma...",51,161
1,"According to the Indian census of 2001, there ...",What other languages are spoken there?,"33,015,420 spoke the standard dialects, 19,643...",640,775
2,"According to the Indian census of 2001, there ...",What else is this place known for?,"World Malayalee Council, the organisation work...",1862,2022
3,"According to the Indian census of 2001, there ...",Were they ever successful in doing this?,,-1,-1
4,"According to the Indian census of 2001, there ...",Do they produce anything from here?,,-1,-1
5,"According to the Indian census of 2001, there ...",Is this population still growing?,"In 2010, the Census of Population of Singapore...",1461,1563
6,"According to the Indian census of 2001, there ...",Is the country thriving?,,-1,-1
7,Malayalam is the language spoken by the Malaya...,what language do they speak?,Malayalam is the language spoken by the Malaya...,0,49
8,Malayalam is the language spoken by the Malaya...,Do they speak any other languages?,Malayalam is derived from old Tamil and Sanskr...,51,118
9,Malayalam is the language spoken by the Malaya...,any literary items of interest?,Malayalam literature is ancient in origin. The...,478,596


QuAC training set size: 83568


Unnamed: 0,context,question,answer,answer_start_idx,answer_end_idx,other_answers
0,"In May 1983, she married Nikos Karvelas, a com...",what happened in 1983?,"In May 1983, she married Nikos Karvelas, a com...",0,51,[]
1,"In May 1983, she married Nikos Karvelas, a com...",did they have any children?,in November she gave birth to her daughter Sofia.,92,140,[]
2,"In May 1983, she married Nikos Karvelas, a com...",did she have any other children?,,-1,-1,
3,"In May 1983, she married Nikos Karvelas, a com...",what collaborations did she do with nikos?,"Since 1975, all her releases have become gold ...",213,306,[]
4,"In May 1983, she married Nikos Karvelas, a com...",what influences does he have in her music?,,-1,-1,
5,"In May 1983, she married Nikos Karvelas, a com...",what were some of the songs?,"one of her most famous songs, titled ""Dodeka"" ...",879,944,[]
6,"In May 1983, she married Nikos Karvelas, a com...",how famous was it?,reached gold status selling 80.000 units.,950,990,[]
7,"In May 1983, she married Nikos Karvelas, a com...",did she have any other famous songs?,"The album included the hit Pragmata (""Things"")...",1049,1113,[]
8,"In September 2016 Vladimir Markin, official sp...",Did they have any clues?,probably FSB) are known to have targeted the w...,1908,2022,[]
9,"In September 2016 Vladimir Markin, official sp...",How did they target her email?,"On 5 December 2005, RFIS initiated an attack a...",2024,2150,[]


QuAC validation set size: 7354


### **NewsQA to SQuAD 2.0**

In [None]:
%%bash -s "$train_newsqa" "$eval_newsqa" "$load_cached_datasets"
# newsqa from google drive
# $1 = train flag, $2 = eval flag, $3 = load_cached_datasets flag (compared with the string "True")
# Runs only when cached loading is enabled and at least one requested JSON file is missing
if [[ $3 = "True" ]] && ([[ $1 = "True" && ! -f newsqa_train.json ]] || [[ $2 = "True" && ! -f newsqa_dev.json ]])
  then
    # The inner wget requests the download page and stores its cookies, sed extracts the
    # "confirm" token from that page, and the outer wget reuses cookies and token to fetch the archive
    wget -nv --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1qgtRKYYJNuY6fr_fz86ZLh3daN91aB8i' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1qgtRKYYJNuY6fr_fz86ZLh3daN91aB8i" -O newsqa_datasets.tar.gz  && rm -rf /tmp/cookies.txt 
    # Extract only the missing splits; --strip-components=1 drops the NewsQA/ folder of the archive
    if [[ $1 = "True" && ! -f newsqa_train.json ]]
      then
        tar -zxf newsqa_datasets.tar.gz NewsQA/newsqa_train.json --strip-components=1
    fi
    if [[ $2 = "True" && ! -f newsqa_dev.json ]]
      then
      tar -zxf newsqa_datasets.tar.gz NewsQA/newsqa_dev.json --strip-components=1
    fi
    rm newsqa_datasets.tar.gz
fi

2022-03-12 12:56:12 URL:https://doc-0o-00-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/ve0e0mh4jvcmmlrcra3ji3ku3hsauii2/1647089700000/10578725460718190134/*/1qgtRKYYJNuY6fr_fz86ZLh3daN91aB8i?e=download [141199792/141199792] -> "newsqa_datasets.tar.gz" [1]


In [None]:
%%bash -s "$train_newsqa" "$eval_newsqa"
# newsqa from scratch
# $1 = train flag, $2 = eval flag (compared with the string "True")
# Runs only when a requested JSON file is still missing, i.e. when it was not
# already obtained by the previous (Google Drive) cell
if [[ $1 = "True" && ! -f newsqa_train.json ]] || [[ $2 = "True" && ! -f newsqa_dev.json ]]
  then
    # The converter repository is shared by all datasets, so it is cloned only once
    if [[ ! -d qa-dataset-converter ]]
      then
        git clone https://github.com/amazon-research/qa-dataset-converter.git
    fi
    # Temporary working directory holding the converter script and the raw data
    mkdir newsqa
    cp qa-dataset-converter/newsqa/newsqa_to_squad.py newsqa
    # split_data has been generated locally as shown in https://github.com/Maluuba/newsqa and uploaded to google drive
    # Same cookie/confirm-token download as used for the cached archives
    wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=18-N4PQHEcxn464o6au6THteBmUGkqEzY' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=18-N4PQHEcxn464o6au6THteBmUGkqEzY" -O newsqa/split_data.tar.gz  && rm -rf /tmp/cookies.txt
    tar -xf newsqa/split_data.tar.gz -C newsqa
    if [[ $1 = "True" && ! -f newsqa_train.json ]]
      then
        python newsqa/newsqa_to_squad.py --newsqa_file newsqa/split_data/train.csv --output_file newsqa_train.json
    fi
    if [[ $2 = "True" && ! -f newsqa_dev.json ]]
      then
        python newsqa/newsqa_to_squad.py --newsqa_file newsqa/split_data/dev.csv --output_file newsqa_dev.json
    fi
    # Remove the working directory with the raw data once the JSON files are written
    rm -rf newsqa
fi

In [None]:
# Parse the converted NewsQA JSON files selected by the flags, preview the first ten rows
# and report the number of questions in each split
if train_newsqa:
  newsqa_training_set = parseSquad('newsqa_train.json') #.sample(frac = 0.8, random_state = 1).sort_index()
  display(newsqa_training_set.head(10))
  print(f'NewsQA training set size: {len(newsqa_training_set)}')
if eval_newsqa:
  newsqa_validation_set = parseSquad('newsqa_dev.json')
  display(newsqa_validation_set.head(10))
  print(f'NewsQA validation set size: {len(newsqa_validation_set)}')

Unnamed: 0,context,question,answer,answer_start_idx,answer_end_idx
0,"NEW DELHI , India -LRB- CNN -RRB- -- A high co...",What was the amount of children murdered ?,19,305,306
1,-LRB- CNN -RRB- -- Fighting in the volatile Su...,Where was one employee killed ?,Sudanese region of Darfur,44,68
2,Johannesburg -LRB- CNN -RRB- -- Miffed by a vi...,who did say South Africa did not issue a visa ...,Archbishop Desmond Tutu,114,136
3,-LRB- CNN -RRB- -- England international footb...,How many years old was the businessman ?,29-year-old,540,550
4,"BAGHDAD , Iraq -LRB- CNN -RRB- -- At least 6,0...",What frightened the families ?,a series of killings and threats by Muslim ext...,697,797
5,-LRB- CNN -RRB- -- Pope John Paul II used to b...,what Pope used to beat himself ?,John Paul II,24,35
6,CNN affiliates report on where job seekers are...,Who is hiring ?,the federal government,307,328
7,WASHINGTON -LRB- CNN -RRB- -- One of the Marin...,What war was the Iwo Jima battle a part of ?,World War II,67,78
8,-LRB- CNN -RRB- -- Jewish organizations called...,Who is Radu Mazare ?,mayor of the town of Constanta,203,232
9,-LRB- CNN -RRB- -- A phone hacking scandal may...,How many followers does Rupert have ?,45000,339,344


NewsQA training set size: 92549


Unnamed: 0,context,question,answer,answer_start_idx,answer_end_idx,other_answers
0,"TEHRAN , Iran -LRB- CNN -RRB- -- Iran 's parli...",Iran criticizes who ?,U.S. President-elect Barack Obama,75,107,[]
1,"LONDON , England -LRB- CNN -RRB- -- Israeli mi...",What happened to the U.N. compound ?,hit and set on fire,3246,3264,[]
2,WASHINGTON -LRB- CNN -RRB- -- There are no imm...,Who said there is no immediate plans for deplo...,President Obama,122,136,[]
3,"LOS ANGELES , California -LRB- CNN -RRB- -- Fo...",Will Lieberman investigate further ?,intends to follow up with,1980,2004,[]
4,-LRB- CNN -RRB- -- A Colorado prosecutor Frida...,Who spent nine years in prison ?,Tim Masters,112,122,[]
5,-LRB- CNN -RRB- -- Women in Somalia 's third-l...,Women who do n't conform will risk spending ho...,12 hours in,1040,1050,[]
6,-LRB- Mental Floss -RRB- -- If you think comic...,Who did Superman battle in `` Clan of the Fier...,Ku Klux,190,196,[]
7,-LRB- CNN -RRB- -- It was just after midday on...,Who was greeted in Seoul ?,the announcement,246,261,[]
8,-LRB- CNN -RRB- -- Dr. Rajiv Shah President Ob...,Where did the deadly earthquake happen ?,Haiti,218,222,[]
9,"LONDON , England -LRB- CNN -RRB- -- After a we...",What did Steve Bruce describe Amire Zaki as ?,unprofessional,886,899,[]


NewsQA validation set size: 5166


## **Fine tuning Bert for question answering in SQuAD 2.0**

Checking for **GPU availability**. GPU is a very powerful resource for running similar tasks which use deep neural architectures like transformers

In [None]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU, and report the choice
if torch.cuda.is_available():
  device = torch.device('cuda')
  print('Running on GPU/Cuda')
else:
  device = torch.device('cpu')
  print('Running on CPU')

Running on GPU/Cuda


The two most common **metrics** to evaluate the question answering performance in SQuAD are the **Exact Match score** and the **f1 score**. The implementation of those scores is inspired by the squad_metrics.py of the Hugging Face library

In [None]:
def normalize_text(text):
    """Normalize an answer string before it is compared by the metrics.

    The steps are applied in this order: drop every character found in
    `string.punctuation` and lowercase the rest, replace the standalone
    words "a", "an" and "the" with a space, then collapse every run of
    whitespace into a single space (leading/trailing whitespace is removed).

    Args:
        text: The string to normalize.

    Returns:
        The normalized string, which may be empty.
    """
    # Lowercasing and punctuation removal (the lookup set is built once per call, not once per character)
    punctuation = set(string.punctuation)
    normalized_text = "".join(ch.lower() for ch in text if ch not in punctuation)
    # Article removal (only whole words match, so e.g. "theatre" is left untouched)
    article_regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
    normalized_text = article_regex.sub(" ", normalized_text)
    # Whitespace removal
    normalized_text = " ".join(normalized_text.split())

    return normalized_text

def exact_match_score_metric(pred_answer, gold_answer):
    """Return 1 when both answers are equal after `normalize_text`, otherwise 0."""
    normalized_prediction = normalize_text(pred_answer)
    normalized_gold = normalize_text(gold_answer)
    return 1 if normalized_prediction == normalized_gold else 0

def f1_score_metric(pred_answer, gold_answer):
  """Token level F1 between a predicted answer and a gold answer.

  Both answers are normalized with `normalize_text` and split on whitespace.
  Returns an int (0 or 1) when one side has no tokens or nothing overlaps,
  otherwise the harmonic mean of precision and recall as a float.
  """
  predicted = normalize_text(pred_answer).split()
  expected = normalize_text(gold_answer).split()

  # An empty side means "no answer": the score is 1 only if both sides are empty
  if not predicted or not expected:
      return int(predicted == expected)

  # Multiset intersection counts every shared token as often as it appears on both sides
  overlap = sum((Counter(expected) & Counter(predicted)).values())
  if overlap == 0:
    return 0

  precision = 1.0 * overlap / len(predicted)
  recall = 1.0 * overlap / len(expected)

  return 2 * (precision * recall) / (precision + recall)

Useful function for **performance visualization**. This function is used for plotting **learning curves** (loss vs. epochs)

In [None]:
def plotLearningCurves(x_axis, y_axis, x_label, y_label, curve_ids, legend_loc = 'lower right', x_best = None, y_best = None):
  """Plot one curve per entry of `y_axis` against the shared `x_axis`.

  Args:
    x_axis: x values shared by all the curves.
    y_axis: iterable with the y values of each curve.
    x_label, y_label: axis labels.
    curve_ids: legend labels, paired with `y_axis` in order.
    legend_loc: matplotlib legend location.
    x_best, y_best: when both are given, the point (x_best, y_best) is annotated
      with the text 'Best score ' followed by x_best.
  """
  fig = plt.figure(figsize = (7, 7))
  axes = fig.add_subplot(111)

  if x_best is not None and y_best is not None:
    axes.annotate('Best score ' + str(x_best), xy = (x_best, y_best), arrowprops = dict(facecolor = 'black', shrink = 0.05))

  for curve, curve_id in zip(y_axis, curve_ids):
    axes.plot(x_axis, curve, label = curve_id, linewidth = 3)

  axes.set_title('Learning Curves')
  axes.set_xlabel(x_label)
  axes.set_ylabel(y_label)
  axes.legend(loc = legend_loc)
  plt.show()

 Popular **learning rate schedulers**. The **linear warmup** scheduler is the one most commonly used with Bert

In [None]:
# Factories that build a learning rate schelduler for a given optimizer, keyed by name
scheldulers = dict(
  # Initial lr * 0.1^(t / step)
  Exponential = lambda optimizer, step: torch.optim.lr_scheduler.LambdaLR(optimizer, lambda t: 0.1 ** (t / step)),
  # Initial lr / (1 + t / step)
  Power = lambda optimizer, step: torch.optim.lr_scheduler.LambdaLR(optimizer, lambda t: 1 / (1 + t / step)),
  # lr for 1 <= epoch < m1, lr * 0.1 for m1 <= epoch < m2 ...
  Piecewise = lambda optimizer, milestones: torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones = milestones, gamma = 0.1),
  # lr is multiplied by `factor` once the monitored value has stopped decreasing for `patience` steps
  Performance = lambda optimizer, patience, factor: torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode = 'min', patience = patience, factor = factor),
  # One cycle policy with linear annealing
  OneCycle = lambda optimizer, steps_per_epoch, epochs, max_lr: torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr = max_lr, anneal_strategy = 'linear', steps_per_epoch = steps_per_epoch, epochs = epochs),
  # Hugging Face linear schedule, called directly with its own keyword arguments
  LinearWarmup = get_linear_schedule_with_warmup,
)

Some useful **building blocks** are specified in order to help with the development of the actual Bert model. First, the initial SQuAD datasets must be **transformed** into a format that Bert "understands" (e.g. token ids, segment ids, attention mask). For that reason a new **custom dataset** is created that **keeps a reference** to the initial dataset and stores every SQuAD record of that dataset as an encoded, Bert-based record. The encoding is performed by the `encodeSquadForBert` method. This method performs the **tokenization** and the **truncation** to produce the **input tokens**, and it also creates the **segment ids** and the **attention mask** from the extracted tokens. Meanwhile, it **translates/maps** the character-based answer spans of the context to **token-based spans** in order to indicate the start and end positions to the Bert model. This mapping information is stored in a separate list only if it is going to be used later (to reduce memory usage). All the remaining information is stored in a SQuAD record, and this record is used as input to Bert. Each record also keeps a unique id, so that functionality which requires a mapping and is **related to the initial dataset** (e.g. building an answer span from the predictions for a specific record of the dataset) can be supported. Finally, one problem is that the input lengths may differ, so some padding must also be applied. In order to reduce **memory usage** and **execution time** (and to get better results), some external functions are specified to apply the **padding dynamically**. Dynamic padding adds 0s to all the above encodings up to the **longest length** in the batch currently being processed by the model.

In [None]:
# Custom dataset for SQuAD
class BertSquadDataset(torch.utils.data.Dataset):
  """Dataset of SQuAD-style rows that are encoded for Bert at construction time.

  Every row of `squad_df` (columns `context`, `question`, `answer_start_idx`,
  `answer_end_idx`) becomes one `BertSquadRecord`. The token-to-character offsets
  are kept in `offset_mappings` only when `preserve_token_mappings` is True;
  they are required by `buildBertAnswer`.
  """

  # Record type: the Bert inputs, the token based answer span and the index label of the source row
  BertSquadRecord = namedtuple('BertSquadRecord', ['input_ids', 'token_type_ids', 'attention_mask', 'start_position', 'end_position', 'squad_id'])

  # Tokenizer shared by all the instances (loaded once, when the class body is executed)
  tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

  def __init__(self, squad_df, max_length = None, preserve_token_mappings = True):
    self.squad_df = squad_df
    self.max_length = max_length
    self.squad_records = []
    if preserve_token_mappings:
      self.offset_mappings = []
    else:
      self.offset_mappings = None

    # Each row is encoded right away and its index label is kept as the record id
    for row in squad_df.itertuples():
      self.add_item(row.context, row.question, row.answer_start_idx, row.answer_end_idx, row.Index)

  def __len__(self):
    return len(self.squad_records)

  def __getitem__(self, idx):
    return self.squad_records[idx]

  def add_item(self, context, question, answer_start_idx, answer_end_idx, squad_id = None):
    """Encode one question/context pair and append the resulting record to the dataset."""
    encoded = BertSquadDataset.encodeSquadForBert(context, question, answer_start_idx, answer_end_idx, self.max_length)
    # The offsets are not part of the record; they are stored on the side only when requested
    token_offsets = encoded.pop('offset_mapping')
    if self.offset_mappings is not None:
      self.offset_mappings.append(token_offsets)
    record = BertSquadDataset.BertSquadRecord(**encoded, squad_id = squad_id)
    self.squad_records.append(record)

  @staticmethod
  def _contextTokenBounds(input_ids):
    """Return the indices of the first and the last context token of an encoded (unpadded) pair."""
    # The first context token is located right after the first occurrence of the SEP token
    first_context_token = input_ids.index(BertSquadDataset.tokenizer.sep_token_id) + 1
    # The last context token is located right before the last SEP token
    last_context_token = len(input_ids) - 2
    return first_context_token, last_context_token

  def buildBertAnswer(self, idx, start_position, end_position):
    """Build the answer text of a record from a token span (mostly used for the spans predicted by Bert).

    Returns '' for the (0, 0) span (no answer), None when the span is invalid
    (start after end, or start outside of the context tokens) and otherwise the
    slice of the context covered by the span. An end that lies past the context
    is clipped to the last context token.

    Raises an Exception when the dataset was created without token mappings.

    NOTE(review): `idx` is used both as a position in the record lists and as a
    key of `squad_df['context']` - these agree only if the frame index matches
    the row positions; confirm against the callers.
    """
    if self.offset_mappings is None:
      raise Exception('This dataset does not preserve any mapping to build an answer. Please create a dataset with preserve_token_mappings = True')

    # Start position cannot be after the end position
    if start_position > end_position:
      return None

    # No answer
    if start_position == 0 and end_position == 0:
      return ''

    # Character offsets and token ids that were computed by the tokenizer for this record
    token_offsets = self.offset_mappings[idx]
    first_context_token, last_context_token = BertSquadDataset._contextTokenBounds(self.squad_records[idx].input_ids)

    # An answer is not allowed to start in the question or in a padding area
    if not (first_context_token <= start_position <= last_context_token):
      return None

    # Whatever exceeds the context (e.g. padding) is cut off from the end of the span
    last_answer_token = min(end_position, last_context_token)

    # Token span -> character span according to the offset mapping
    first_char = token_offsets[start_position][0]
    end_char = token_offsets[last_answer_token][1]

    return self.squad_df['context'][idx][first_char:end_char]

  @staticmethod
  def encodeSquadForBert(context, question, answer_start_idx, answer_end_idx, max_length = None):
    """Tokenize a question/context pair and locate the answer span in tokens.

    The returned encoding holds the tokenizer outputs (including
    `offset_mapping`) plus `start_position` and `end_position`. The positions
    stay (0, 0) - the CLS token - when an answer index is negative or when the
    answer does not lie completely inside the (possibly truncated) context.
    """
    # Whenever max length is given then truncation is applied only to context/passage
    if max_length:
      truncation = 'only_second'
    else:
      truncation = False

    encoding = BertSquadDataset.tokenizer(question, context, max_length = max_length, truncation = truncation, return_offsets_mapping = True)
    # (0, 0) indicates the CLS token which is used when an answer cannot be found
    encoding.update({'start_position': 0, 'end_position': 0})
    # This mapping is used instead of the char_to_token and token_to_char functions of the encoding for more flexibility
    token_offsets = encoding['offset_mapping']

    # If either the start or the end of the answer is missing then the positions cannot be found
    if answer_start_idx < 0 or answer_end_idx < 0:
      return encoding

    first_context_token, last_context_token = BertSquadDataset._contextTokenBounds(encoding['input_ids'])

    # Character range (inclusive) of the context that survived the truncation
    first_context_char = token_offsets[first_context_token][0]
    last_context_char = token_offsets[last_context_token][1] - 1

    # If the answer is out of the context bounds (e.g. has been truncated) there is nothing to search for
    if first_context_char > answer_start_idx or last_context_char < answer_end_idx:
      return encoding

    # The span is chosen so that it covers the whole answer without losing information.
    # A token that mixes answer and non-answer characters is kept, which may add a little noise.

    # Move forward to the first token whose last character reaches the start of the answer
    span_start = first_context_token
    while token_offsets[span_start][1] - 1 < answer_start_idx:
      span_start += 1

    # Move backward to the last token that does not begin after the end of the answer
    span_end = last_context_token
    while token_offsets[span_end][0] > answer_end_idx:
      span_end -= 1

    encoding.update({'start_position': span_start, 'end_position': span_end})

    return encoding

def padSequence(batch, padding_value = 0): 
  """Right-pad every sequence of `batch` with `padding_value` up to the longest length.

  Returns a new list of lists; the input sequences are left untouched.
  """
  longest = max([len(sequence) for sequence in batch])
  return [list(sequence) + [padding_value] * (longest - len(sequence)) for sequence in batch]

# Collate function that is going to be used by the Dataloader in order to add padding to each batch
def padCollate(batch, padding_value = 0):
  """Collate a list of `BertSquadRecord`s into padded tensors on the global `device`.

  Args:
    batch: records whose fields are, in order, input ids, token type ids,
      attention mask, start position, end position and record id.
    padding_value: id used to pad `input_ids` (defaults to 0). This argument
      used to be ignored; it is now forwarded to the input ids. Segment ids and
      attention masks are always padded with 0.

  Returns:
    A tuple (features, labels, ids): a `BatchEncoding` of padded tensors, a dict
    with the `start_positions` / `end_positions` tensors and the tuple of record ids.
  """
  (batch_input_ids, batch_token_type_ids, batch_attention_masks, batch_start_positions, batch_end_positions, batch_ids) = zip(*batch)

  batch_labels = {
      'start_positions': torch.tensor(batch_start_positions, device = device),
      'end_positions': torch.tensor(batch_end_positions, device = device)
  } 
  
  # Padding all the sequences to have equal lengths (the length is equal to the length of the longest sequence in the batch)
  # Padding is applied to the segment ids and the attention masks as well (always with 0, whatever the padding value of the ids is)
  batch_padded_features = {'input_ids': padSequence(batch_input_ids, padding_value), 
                           'token_type_ids': padSequence(batch_token_type_ids),
                           'attention_mask': padSequence(batch_attention_masks)
                          }
  
  return BatchEncoding(batch_padded_features, tensor_type = 'pt').to(device), batch_labels, batch_ids
          

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

This is the actual class of **Bert for SQuAD Question Answering**. This class is a wrapper around **BertForQuestionAnswering**, which is already implemented and provided by the Hugging Face library, and it has been extended with some extra functionality in order to automate the training and evaluation process as much as possible. In particular, at the evaluation stage the above metrics (Exact Match and f1 score) are incorporated for **calculating the scores** of every single record in the validation/dev dataset. In order for the model to predict better answers, there is a kind of **mini exhaustive search** among the **top k** (k is predefined and can be treated as an extra hyperparameter) possible and feasible spans. Furthermore, there are two functions for **training**, both supporting **gradient clipping**, but the one with more options (stats, early stopping, etc.) is the one that is going to be used below. Moreover, these functions support **learning rate scheduling** and, in every epoch, they call the **batch** training submethod to train on all the batches. Finally, there is one method for making predictions for any given unseen input, which uses a naive method to extract them, and one method for saving the model (config & weights) locally

In [None]:
class BertForSquad(nn.Module):
    """Wrapper around Hugging Face's `BertForQuestionAnswering` with training and evaluation loops."""

    def __init__(self, bert_model = 'bert-base-uncased'):
        """Load the pretrained question answering model named (or located at) `bert_model`."""
        super().__init__()

        # Instantiate Bert for QA
        self.bert = BertForQuestionAnswering.from_pretrained(bert_model)

    def forward(self, bertInputConfig, start_positions = None, end_positions = None):
        """Run Bert on the encoded inputs; the gold positions are forwarded to the wrapped model when given."""
        return self.bert(**bertInputConfig, start_positions = start_positions, end_positions = end_positions)

    def batchTrain(self, batch_loader, optimizer, schelduler = None, max_clip_norm = None):
        """Train on every batch of `batch_loader` once and return the average batch loss.

        Args:
            batch_loader: loader yielding (features, labels, ids) triples.
            optimizer: optimizer that updates the model parameters.
            schelduler: optional schelduler that is stepped after every batch.
            max_clip_norm: when not None, gradients are clipped to this norm.
        """
        total_loss = 0.0

        for x_batch, y_batch, _ in tqdm(batch_loader, 'Batch Training'):
            # Gradients are restored to zero
            optimizer.zero_grad()

            loss = self(x_batch, **y_batch).loss
            total_loss += loss.item()

            # Backpropagation starting from the loss calculated for this batch
            loss.backward()
            # Gradient clipping
            if max_clip_norm is not None:
                nn.utils.clip_grad_norm_(self.parameters(), max_clip_norm)
            # Model's weights update based on the gradients calculated during backprop
            optimizer.step()

            # Some scheldulers can make steps inside the batches such as OneCycle schelduler
            if schelduler is not None:
                schelduler.step()

        return total_loss / len(batch_loader) # Average loss

    def batchEval(self, batch_loader, search_best_span_limit = 20):
        """Evaluate the model on every batch of `batch_loader`.

        For each record the `search_best_span_limit` best start and end tokens
        are combined, and the valid span (as decided by the dataset's
        `buildBertAnswer`) with the highest start + end logit becomes the
        predicted answer; the empty string is predicted when no span is valid.
        The prediction is scored against the gold answer and its alternatives.

        Returns:
            [average batch loss, average Exact Match, average f1 score].

        Raises:
            ZeroDivisionError: if the loader yields no batches.
        """
        total_loss = 0.0
        total = 0.0
        em_scores = 0.0
        f1_scores = 0.0

        dataset = batch_loader.dataset

        with torch.no_grad():
            for x_batch, y_batch, batch_ids in tqdm(batch_loader, 'Batch Evaluation'):
                pred = self(x_batch, **y_batch)
                total_loss += pred.loss.item()

                # topk fails when k exceeds the (padded) sequence length, so k is capped
                top_k = min(search_best_span_limit, pred.start_logits.size(1))
                # Instead of the single best start/end (argmax), which might be invalid, the range is extended to the top k
                top_starts = torch.topk(pred.start_logits, top_k, dim = 1)
                top_ends = torch.topk(pred.end_logits, top_k, dim = 1)

                # Indices and scores are moved to the CPU once per batch instead of once per candidate span
                start_indices = top_starts.indices.detach().cpu().numpy()
                start_scores = top_starts.values.detach().cpu().numpy()
                end_indices = top_ends.indices.detach().cpu().numpy()
                end_scores = top_ends.values.detach().cpu().numpy()

                for i, record_id in enumerate(batch_ids):
                    pred_answer = None
                    best_span_score = float('-inf')
                    # Searching for a feasible span with the best score
                    for start, start_score in zip(start_indices[i], start_scores[i]):
                        for end, end_score in zip(end_indices[i], end_scores[i]):
                            span_score = start_score + end_score
                            if span_score > best_span_score:
                                span_text = dataset.buildBertAnswer(record_id, start, end)
                                if span_text is not None: # If answer is valid
                                    pred_answer = span_text
                                    best_span_score = span_score

                    if pred_answer is None: # If no valid answer found at all
                        pred_answer = ''

                    answer = dataset.squad_df['answer'][record_id]
                    if answer is not None:
                        other_answers = dataset.squad_df['other_answers'][record_id] or []
                        all_answers = [*other_answers, answer]
                    else:
                        all_answers = ['']
                    # Exact Match and f1 score are the best ones over all the gold/ground truth answers
                    em_scores += max(exact_match_score_metric(pred_answer, gold_answer) for gold_answer in all_answers)
                    f1_scores += max(f1_score_metric(pred_answer, gold_answer) for gold_answer in all_answers)
                    total += 1

        return [total_loss / len(batch_loader), em_scores / total, f1_scores / total]

    def fineTuning(self, training_set, optimizer, lr_schelduler = None, max_clip_norm = None, epochs = 3, batch_size = 32):
        """Plain training loop without validation.

        Args:
            training_set: dataset that is batched with `padCollate`.
            optimizer: optimizer that updates the model parameters.
            lr_schelduler: optional dict with the keys 'name', 'schelduler' and
                'epoch_step' (True: one step per epoch, False: one step per batch).
            max_clip_norm: optional gradient clipping norm.
            epochs: number of passes over the training set.
            batch_size: number of records per batch.

        Returns:
            {'train_losses': [average loss of every epoch]}.

        Raises:
            Exception: for the 'Performance' schelduler, which needs a validation loss.
        """
        # Switching to training mode
        self.train()

        loss_scores = []

        # Training batches
        batch_loader = torch.utils.data.DataLoader(training_set, batch_size = batch_size, shuffle = True, collate_fn = padCollate)

        # Creating the right schelduling function
        schelduler = None
        if lr_schelduler is not None:
            if lr_schelduler['name'] == 'Performance':
                raise Exception('Performance schelduling cannot be applied')
            schelduler = lr_schelduler['schelduler']

        for epoch in range(1, epochs + 1):
            print(f'\nEpoch: {epoch}')
            if lr_schelduler is not None and not lr_schelduler['epoch_step']:  # For OneCycle schelduler for example is better to make steps inside the batches
                loss_score = self.batchTrain(batch_loader, optimizer, schelduler, max_clip_norm)
            else:
                loss_score = self.batchTrain(batch_loader, optimizer, None, max_clip_norm)

            # Learning rate modification for the next epoch
            if lr_schelduler is not None and lr_schelduler['epoch_step']:
                schelduler.step()

            loss_scores.append(loss_score)
            print(f'Loss = {loss_score:.5f}')

        return {'train_losses': loss_scores}

    def fineTuningWithOptions(self, training_set, optimizer, validation_set = None, lr_schelduler = None, max_clip_norm = None, epochs = 3, batch_size = 32, stats = False, early_stopping_params = None):
        """Training loop with per-epoch validation, optional stats, early stopping and performance schelduling.

        Falls back to `fineTuning` when none of stats, early stopping or the
        'Performance' schelduler is requested.

        Args:
            validation_set: dataset evaluated after every epoch (required for the options above).
            stats: print the scores of every epoch and plot the learning curves at the end.
            early_stopping_params: optional dict with the key 'patience'; training stops
                once the validation loss has not improved for that many consecutive epochs.
            (the remaining arguments are the same as in `fineTuning`)

        Returns:
            A dict with 'train_losses', 'valid_losses', 'valid_em_scores' and
            'valid_f1_scores' (one entry per executed epoch). The dict is
            returned whether or not `stats` is set; previously None was
            returned when `stats` was False.

        Raises:
            Exception: if a validation set is needed but missing.
        """
        needs_validation = stats != False or early_stopping_params is not None or (lr_schelduler is not None and lr_schelduler['name'] == 'Performance')

        if not needs_validation:
            return self.fineTuning(training_set, optimizer, lr_schelduler, max_clip_norm, epochs, batch_size)

        if validation_set is None:
            raise Exception('Validation set must not be None in order to calculate stats or to do early stopping or to do performance schelduling')

        train_losses = []
        valid_losses = []
        valid_em_scores = []
        valid_f1_scores = []
        best_score = float('inf')
        best_iter = 1
        counter = 0

        # Training batches
        training_batch_loader = torch.utils.data.DataLoader(training_set, batch_size = batch_size, shuffle = True, collate_fn = padCollate)

        # Validation batches
        validation_batch_loader = torch.utils.data.DataLoader(validation_set, batch_size = batch_size, collate_fn = padCollate)

        # Creating the right schelduling function
        schelduler = lr_schelduler['schelduler'] if lr_schelduler is not None else None
        # For OneCycle schelduler for example is better to make steps inside the batches
        step_per_batch = lr_schelduler is not None and not lr_schelduler['epoch_step']

        for epoch in range(1, epochs + 1):
            if stats:
                print(f'\nEpoch: {epoch}')
            # Switching to training mode
            self.train()
            # Batch training and calculation of the average training loss
            train_loss = self.batchTrain(training_batch_loader, optimizer, schelduler if step_per_batch else None, max_clip_norm)
            train_losses.append(train_loss)

            # Switching to evaluation mode
            self.eval()
            # Calculation of validation batch scores
            valid_loss, valid_em_score, valid_f1_score = self.batchEval(validation_batch_loader)

            valid_losses.append(valid_loss)
            valid_em_scores.append(valid_em_score)
            valid_f1_scores.append(valid_f1_score)

            if valid_loss < best_score:
                best_score = valid_loss
                best_iter = epoch
                counter = 0
            else:
                counter += 1

            # A validation loss that has not improved for `patience` epochs suggests overfitting, so training stops
            if early_stopping_params is not None and counter == early_stopping_params['patience']:
                break

            # Learning rate modification for the next epoch
            if lr_schelduler is not None:
                if lr_schelduler['name'] == 'Performance': # Validation loss must be passed to the performance schelduler
                    schelduler.step(valid_loss)
                elif lr_schelduler['epoch_step']:
                    schelduler.step()

            if stats:
                print(f'Training: Loss = {train_loss:.5f}')
                print(f'Validation: Loss = {valid_loss:.5f}, Exact match = {valid_em_score:.5f}, F1 score = {valid_f1_score:.5f}')

        if stats:
            # Learning curve plot (the x axis covers exactly the epochs that were executed)
            plotLearningCurves(list(range(1, len(train_losses) + 1)), [valid_losses, train_losses], 'Number of Iterations', 'Loss', ['validation', 'training'], 'upper right', best_iter, best_score)

        return {'train_losses': train_losses, 'valid_losses': valid_losses, 'valid_em_scores': valid_em_scores, 'valid_f1_scores': valid_f1_scores}

    def predict(self, X):
        """Return a tensor with one (start, end) token pair per input, taken naively as the argmax of each logit vector."""
        # Switching to evaluation mode
        self.eval()
        with torch.no_grad():
            # Model computes the span of each answer
            pred = self(X)

            # A naive solution for extracting the answer
            pred_start_positions = torch.argmax(pred.start_logits, dim = 1)
            pred_end_positions = torch.argmax(pred.end_logits, dim = 1)

            return torch.cat([pred_start_positions.unsqueeze(1), pred_end_positions.unsqueeze(1)], dim = 1)

    def save(self, dir = 'model'):
        """Save the wrapped Hugging Face model into the directory `dir`."""
        self.bert.save_pretrained(dir)

Functions for **fine tuning** and **evaluation** on SQuAD 2.0

In [None]:
def SquadEvaluation(model, test_set, batch_size = 8):
  """Evaluate `model` on `test_set` and return its scores as {'Loss', 'EM', 'F1'}."""
  # Testing batches (padded dynamically by the collate function)
  loader = torch.utils.data.DataLoader(test_set, batch_size = batch_size, collate_fn = padCollate)
  # Switching to evaluation mode before scoring
  model.eval()
  loss_value, em_value, f1_value = model.batchEval(loader)
  return dict(Loss = loss_value, EM = em_value, F1 = f1_value)

def SquadFineTuning(params, training_set, validation_set = None):
  """Fine tune a fresh 'bert-base-uncased' question answering model.

  Args:
    params: dict with the keys 'batch_size', 'epochs', 'max_clip_norm' and
      'learning_rate', the latter being a dict with 'val' (initial learning rate)
      and 'schelduler' (one of 'Exponential', 'Power', 'Piecewise', 'Performance',
      'OneCycle', 'Linear', 'LinearWarmup', or a falsy value for no schelduling;
      any other name is ignored).
    training_set: dataset used for training.
    validation_set: optional dataset; when given, every epoch is validated and
      the stats are printed and plotted.

  Returns:
    A tuple (model, scores) where scores is the dict returned by the training loop.
  """
  model = BertForSquad('bert-base-uncased').to(device)
      
  learning_rate = params['learning_rate']
  # AdamW is an improved version of Adam to better handle the weight decay factor (thats why it is prefered against Adam)
  optimizer = torch.optim.AdamW(model.parameters(), learning_rate['val'], weight_decay = 1e-2)

  schelduler = None
  schelduler_name = learning_rate['schelduler']
  # Creation of the schelduler with some default values (e.g steps, etc..)
  if schelduler_name:
    epochs = params['epochs']
    # Kept as an int: OneCycleLR rejects a non-integer steps_per_epoch
    steps_per_epoch = math.ceil(len(training_set) / params['batch_size'])
    total_steps = steps_per_epoch * epochs
    if schelduler_name == 'Exponential':
      schelduler = {'name':  'Exponential', 'schelduler': scheldulers['Exponential'](optimizer, 5), 'epoch_step': True}
    elif schelduler_name == 'Power':
      schelduler = {'name':  'Power', 'schelduler': scheldulers['Power'](optimizer, 5), 'epoch_step': True}
    elif schelduler_name == 'Piecewise':
      # Milestones are epoch indices: the 20% / 50% / 80% marks are rounded to whole epochs (at least 1),
      # because a fractional milestone never equals an epoch counter and would silently never fire
      milestones = sorted({max(1, round(fraction * epochs)) for fraction in (0.2, 0.5, 0.8)})
      schelduler = {'name':  'Piecewise', 'schelduler': scheldulers['Piecewise'](optimizer, milestones), 'epoch_step': True} 
    elif schelduler_name == 'Performance':
      schelduler = {'name':  'Performance', 'schelduler': scheldulers['Performance'](optimizer, 3, 0.5), 'epoch_step': True}
    elif schelduler_name == 'OneCycle':
      schelduler = {'name':  'OneCycle', 'schelduler': scheldulers['OneCycle'](optimizer, steps_per_epoch, epochs, learning_rate['val'] * 100), 'epoch_step': False}
    elif schelduler_name == 'Linear': # Linear learning rate decay (thats why the warmup steps are 0)
      schelduler = {'name':  'Linear', 'schelduler': scheldulers['LinearWarmup'](optimizer, num_warmup_steps = 0, num_training_steps = total_steps), 'epoch_step': False}
    elif schelduler_name == 'LinearWarmup': # Linear warmup 20% of the total steps 
      schelduler = {'name':  'LinearWarmup', 'schelduler': scheldulers['LinearWarmup'](optimizer, num_warmup_steps = 0.2 * total_steps, num_training_steps = total_steps), 'epoch_step': False} 

  param_list = f'Epochs: {params["epochs"]}, Batch size: {params["batch_size"]}, Learning rate: {learning_rate["val"]}, Schelduler: {learning_rate["schelduler"] if schelduler is not None else None}, Max clipping: {params["max_clip_norm"]}'
  display(Markdown('**' + param_list + '**'))
  
  # Timekeeping the training process
  timestamp_start = time.time()
  if validation_set is None:
    scores = model.fineTuning(training_set, optimizer, schelduler, params['max_clip_norm'], params['epochs'], params['batch_size'])
  else:
    scores = model.fineTuningWithOptions(training_set, optimizer, validation_set, schelduler, params['max_clip_norm'], params['epochs'], params['batch_size'], True, None)
  timestamp_end = time.time()
  print(f'\nTotal time: {timestamp_end - timestamp_start:.2f} seconds')

  return model, scores

## **Model fine tuning and evaluation section using the above datasets**

In this section the goal is to **reproduce** the results as shown in the Table 3 of this [paper](https://arxiv.org/pdf/2004.03490.pdf)

**Training** and **validation** dataset split for all datasets including all the **encoded Bert inputs** of the records. Also a max length value is defined for **length truncation**. This value is specified properly in order not to lose much information bearing in mind the memory limits as well 

In [None]:
# Maximum number of tokens per encoded question/context pair (longer contexts are truncated)
max_length = 384

# Training sets are encoded only when a model is actually going to be trained (they do not keep
# the token mappings); validation sets are encoded only when the respective evaluation is enabled

###### SQuAD #######
if train_squad and not load_cached_models:
  bert_squad_training_set = BertSquadDataset(squad_training_set, max_length, False)
else:
  bert_squad_training_set = None
if eval_squad:
  bert_squad_validation_set = BertSquadDataset(squad_validation_set, max_length)
else:
  bert_squad_validation_set = None

###### TriviaQA #######
if train_triviaqa and not load_cached_models:
  bert_triviaqa_training_set = BertSquadDataset(triviaqa_training_set, max_length, False)
else:
  bert_triviaqa_training_set = None
if eval_triviaqa:
  bert_triviaqa_validation_set = BertSquadDataset(triviaqa_validation_set, max_length)
else:
  bert_triviaqa_validation_set = None

###### NQ #######
if train_nq and not load_cached_models:
  bert_nq_training_set = BertSquadDataset(nq_training_set, max_length, False)
else:
  bert_nq_training_set = None
if eval_nq:
  bert_nq_validation_set = BertSquadDataset(nq_validation_set, max_length)
else:
  bert_nq_validation_set = None

###### QuAC #######
if train_quac and not load_cached_models:
  bert_quac_training_set = BertSquadDataset(quac_training_set, max_length, False)
else:
  bert_quac_training_set = None
if eval_quac:
  bert_quac_validation_set = BertSquadDataset(quac_validation_set, max_length)
else:
  bert_quac_validation_set = None

###### NewsQA #######
if train_newsqa and not load_cached_models:
  bert_newsqa_training_set = BertSquadDataset(newsqa_training_set, max_length, False)
else:
  bert_newsqa_training_set = None
if eval_newsqa:
  bert_newsqa_validation_set = BertSquadDataset(newsqa_validation_set, max_length)
else:
  bert_newsqa_validation_set = None

Creation of a separate list for every dataset in order to store the calculated **f1 scores** for a fine tuned model on a single dataset (that is also stored)

In [None]:
# Names of the fine tuned models, followed by one list of f1 scores per evaluation dataset
# (six independent empty lists, filled in the same order as the models are fine tuned)
(fine_tuned_on,
 f1_scores_on_squad,
 f1_scores_on_triviaqa,
 f1_scores_on_nq,
 f1_scores_on_quac,
 f1_scores_on_newsqa) = ([] for _ in range(6))

### **Fine tuning on SQuAD 2.0**

In [None]:
if train_squad:
  if load_cached_models:
    # The already fine tuned model is fetched from the hub and only evaluated
    squad_model = BertForSquad('OrfeasTsk/bert-base-uncased-finetuned-squadv2').to(device)
    f1_score = SquadEvaluation(squad_model, bert_squad_validation_set)['F1'] if eval_squad else None
  else:
    params = dict(
      batch_size = 8,
      learning_rate = {'val': 5e-5, 'schelduler': 'Linear'},
      max_clip_norm = None,
      epochs = 2,
    )
    squad_model, scores = SquadFineTuning(params, bert_squad_training_set, bert_squad_validation_set)
    # F1 score of the last epoch
    f1_score = scores['valid_f1_scores'][-1] if eval_squad else None
    squad_model.save('squad_model')

  # The model name and its in-domain score open a new entry in the result lists
  fine_tuned_on.append('SQuAD')
  f1_scores_on_squad.append(f1_score)

Downloading:   0%|          | 0.00/673 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/415M [00:00<?, ?B/s]

Batch Evaluation: 100%|██████████| 1485/1485 [11:55<00:00,  2.07it/s]


#### **SQuAD 2.0 evaluation on TriviaQA**

In [None]:
if train_squad:
  # The SQuAD model is scored on the TriviaQA validation set (None is stored when that evaluation is disabled)
  if eval_triviaqa:
    f1_score = SquadEvaluation(squad_model, bert_triviaqa_validation_set)['F1']
  else:
    f1_score = None
  f1_scores_on_triviaqa.append(f1_score)

Batch Evaluation: 100%|██████████| 1779/1779 [19:04<00:00,  1.55it/s]


#### **SQuAD 2.0 evaluation on NQ**

In [None]:
if train_squad:
  # The SQuAD model is scored on the NQ validation set (None is stored when that evaluation is disabled)
  if eval_nq:
    f1_score = SquadEvaluation(squad_model, bert_nq_validation_set)['F1']
  else:
    f1_score = None
  f1_scores_on_nq.append(f1_score)

Batch Evaluation: 100%|██████████| 422/422 [03:41<00:00,  1.91it/s]


#### **SQuAD 2.0 evaluation on QuAC**

In [None]:
if train_squad:
  # The SQuAD model is scored on the QuAC validation set (None is stored when that evaluation is disabled)
  if eval_quac:
    f1_score = SquadEvaluation(squad_model, bert_quac_validation_set)['F1']
  else:
    f1_score = None
  f1_scores_on_quac.append(f1_score)

Batch Evaluation: 100%|██████████| 920/920 [09:51<00:00,  1.56it/s]


#### **SQuAD 2.0 evaluation on NewsQA**

In [None]:
if train_squad:
  # The SQuAD model is scored on the NewsQA validation set (None is stored when that evaluation is disabled)
  if eval_newsqa:
    f1_score = SquadEvaluation(squad_model, bert_newsqa_validation_set)['F1']
  else:
    f1_score = None
  f1_scores_on_newsqa.append(f1_score)

Batch Evaluation: 100%|██████████| 646/646 [06:52<00:00,  1.57it/s]


### **Fine tuning on TriviaQA**

In [None]:
if train_triviaqa:
  if load_cached_models:
    # The already fine tuned model is fetched from the hub and only evaluated
    triviaqa_model = BertForSquad('OrfeasTsk/bert-base-uncased-finetuned-triviaqa').to(device)
    f1_score = SquadEvaluation(triviaqa_model, bert_triviaqa_validation_set)['F1'] if eval_triviaqa else None
  else:
    params = dict(
      batch_size = 8,
      learning_rate = {'val': 5e-5, 'schelduler': 'Linear'},
      max_clip_norm = None,
      epochs = 2,
    )
    triviaqa_model, scores = SquadFineTuning(params, bert_triviaqa_training_set, bert_triviaqa_validation_set)
    # F1 score of the last epoch
    f1_score = scores['valid_f1_scores'][-1] if eval_triviaqa else None
    triviaqa_model.save('triviaqa_model')

  # The model name and its in-domain score open a new entry in the result lists
  fine_tuned_on.append('TriviaQA')
  f1_scores_on_triviaqa.append(f1_score)

Downloading:   0%|          | 0.00/673 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/415M [00:00<?, ?B/s]

Batch Evaluation: 100%|██████████| 1779/1779 [19:08<00:00,  1.55it/s]


#### **TriviaQA evaluation on SQuAD 2.0**

In [None]:
# Cross-dataset check: score the TriviaQA model on the SQuAD 2.0 validation set.
if train_triviaqa:
  if eval_squad:
    f1_score = SquadEvaluation(triviaqa_model, bert_squad_validation_set)['F1']
  else:
    # Evaluation on SQuAD is switched off: record None so the list still
    # receives exactly one entry for this model.
    f1_score = None
  f1_scores_on_squad.append(f1_score)

Batch Evaluation: 100%|██████████| 1485/1485 [11:38<00:00,  2.13it/s]


#### **TriviaQA evaluation on NQ**

In [None]:
# Cross-dataset check: score the TriviaQA model on the NQ validation set.
if train_triviaqa:
  if eval_nq:
    f1_score = SquadEvaluation(triviaqa_model, bert_nq_validation_set)['F1']
  else:
    # Evaluation on NQ is switched off: record None so the list still
    # receives exactly one entry for this model.
    f1_score = None
  f1_scores_on_nq.append(f1_score)

Batch Evaluation: 100%|██████████| 422/422 [03:39<00:00,  1.92it/s]


#### **TriviaQA evaluation on QuAC**

In [None]:
# Cross-dataset check: score the TriviaQA model on the QuAC validation set.
if train_triviaqa:
  if eval_quac:
    f1_score = SquadEvaluation(triviaqa_model, bert_quac_validation_set)['F1']
  else:
    # Evaluation on QuAC is switched off: record None so the list still
    # receives exactly one entry for this model.
    f1_score = None
  f1_scores_on_quac.append(f1_score)

Batch Evaluation: 100%|██████████| 920/920 [09:47<00:00,  1.57it/s]


#### **TriviaQA evaluation on NewsQA**

In [None]:
# Cross-dataset check: score the TriviaQA model on the NewsQA validation set.
if train_triviaqa:
  if eval_newsqa:
    f1_score = SquadEvaluation(triviaqa_model, bert_newsqa_validation_set)['F1']
  else:
    # Evaluation on NewsQA is switched off: record None so the list still
    # receives exactly one entry for this model.
    f1_score = None
  f1_scores_on_newsqa.append(f1_score)

Batch Evaluation: 100%|██████████| 646/646 [06:54<00:00,  1.56it/s]


### **Fine tuning on NQ**

In [None]:
# Obtain the NQ model and its in-domain F1 score.
# With load_cached_models set, the model is built from the
# 'OrfeasTsk/bert-base-uncased-finetuned-nq-large-batch' checkpoint name and
# evaluated; otherwise fine-tuning is run here and the resulting model is saved.
if train_nq:
  if load_cached_models:
    nq_model = BertForSquad('OrfeasTsk/bert-base-uncased-finetuned-nq-large-batch').to(device)
    if eval_nq:
      f1_score = SquadEvaluation(nq_model, bert_nq_validation_set)['F1']
    else:
      f1_score = None
  else:
    # NOTE(review): the 'schelduler' spelling is presumably the key that
    # SquadFineTuning looks up - confirm against its definition before renaming.
    params = {
      'batch_size': 24,
      'learning_rate': {'val': 3e-5, 'schelduler': 'Linear'},
      'max_clip_norm': None,
      'epochs': 2,
    }
    nq_model, scores = SquadFineTuning(params, bert_nq_training_set, bert_nq_validation_set)
    if eval_nq:
      # Keep only the validation F1 reported for the final epoch.
      f1_score = scores['valid_f1_scores'][-1]
    else:
      f1_score = None
    nq_model.save('nq_model')

  # Register the model name and its in-domain score together, one entry each.
  fine_tuned_on.append('NQ')
  f1_scores_on_nq.append(f1_score)

Downloading:   0%|          | 0.00/673 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/415M [00:00<?, ?B/s]

Batch Evaluation: 100%|██████████| 422/422 [03:41<00:00,  1.91it/s]


#### **NQ evaluation on SQuAD 2.0**

In [None]:
# Cross-dataset check: score the NQ model on the SQuAD 2.0 validation set.
if train_nq:
  if eval_squad:
    f1_score = SquadEvaluation(nq_model, bert_squad_validation_set)['F1']
  else:
    # Evaluation on SQuAD is switched off: record None so the list still
    # receives exactly one entry for this model.
    f1_score = None
  f1_scores_on_squad.append(f1_score)

Batch Evaluation: 100%|██████████| 1485/1485 [11:43<00:00,  2.11it/s]


#### **NQ evaluation on TriviaQA**

In [None]:
# Cross-dataset check: score the NQ model on the TriviaQA validation set.
if train_nq:
  if eval_triviaqa:
    f1_score = SquadEvaluation(nq_model, bert_triviaqa_validation_set)['F1']
  else:
    # Evaluation on TriviaQA is switched off: record None so the list still
    # receives exactly one entry for this model.
    f1_score = None
  f1_scores_on_triviaqa.append(f1_score)

Batch Evaluation: 100%|██████████| 1779/1779 [19:07<00:00,  1.55it/s]


#### **NQ evaluation on QuAC**

In [None]:
# Cross-dataset check: score the NQ model on the QuAC validation set.
if train_nq:
  if eval_quac:
    f1_score = SquadEvaluation(nq_model, bert_quac_validation_set)['F1']
  else:
    # Evaluation on QuAC is switched off: record None so the list still
    # receives exactly one entry for this model.
    f1_score = None
  f1_scores_on_quac.append(f1_score)

Batch Evaluation: 100%|██████████| 920/920 [09:53<00:00,  1.55it/s]


#### **NQ evaluation on NewsQA**

In [None]:
# Cross-dataset check: score the NQ model on the NewsQA validation set.
if train_nq:
  if eval_newsqa:
    f1_score = SquadEvaluation(nq_model, bert_newsqa_validation_set)['F1']
  else:
    # Evaluation on NewsQA is switched off: record None so the list still
    # receives exactly one entry for this model.
    f1_score = None
  f1_scores_on_newsqa.append(f1_score)

Batch Evaluation: 100%|██████████| 646/646 [06:56<00:00,  1.55it/s]


### **Fine tuning on QuAC**

In [None]:
# Obtain the QuAC model and its in-domain F1 score.
# With load_cached_models set, the model is built from the
# 'OrfeasTsk/bert-base-uncased-finetuned-quac' checkpoint name and evaluated;
# otherwise fine-tuning is run here and the resulting model is saved.
if train_quac:
  if load_cached_models:
    quac_model = BertForSquad('OrfeasTsk/bert-base-uncased-finetuned-quac').to(device)
    if eval_quac:
      f1_score = SquadEvaluation(quac_model, bert_quac_validation_set)['F1']
    else:
      f1_score = None
  else:
    # NOTE(review): the 'schelduler' spelling is presumably the key that
    # SquadFineTuning looks up - confirm against its definition before renaming.
    params = {
      'batch_size': 8,
      'learning_rate': {'val': 5e-5, 'schelduler': 'Linear'},
      'max_clip_norm': None,
      'epochs': 2,
    }
    quac_model, scores = SquadFineTuning(params, bert_quac_training_set, bert_quac_validation_set)
    if eval_quac:
      # Keep only the validation F1 reported for the final epoch.
      f1_score = scores['valid_f1_scores'][-1]
    else:
      f1_score = None
    quac_model.save('quac_model')

  # Register the model name and its in-domain score together, one entry each.
  fine_tuned_on.append('QuAC')
  f1_scores_on_quac.append(f1_score)

Downloading:   0%|          | 0.00/673 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/415M [00:00<?, ?B/s]

Batch Evaluation: 100%|██████████| 920/920 [09:55<00:00,  1.54it/s]


#### **QuAC evaluation on SQuAD 2.0**

In [None]:
# Cross-dataset check: score the QuAC model on the SQuAD 2.0 validation set.
if train_quac:
  if eval_squad:
    f1_score = SquadEvaluation(quac_model, bert_squad_validation_set)['F1']
  else:
    # Evaluation on SQuAD is switched off: record None so the list still
    # receives exactly one entry for this model.
    f1_score = None
  f1_scores_on_squad.append(f1_score)

Batch Evaluation: 100%|██████████| 1485/1485 [11:47<00:00,  2.10it/s]


#### **QuAC evaluation on TriviaQA**

In [None]:
# Cross-dataset check: score the QuAC model on the TriviaQA validation set.
if train_quac:
  if eval_triviaqa:
    f1_score = SquadEvaluation(quac_model, bert_triviaqa_validation_set)['F1']
  else:
    # Evaluation on TriviaQA is switched off: record None so the list still
    # receives exactly one entry for this model.
    f1_score = None
  f1_scores_on_triviaqa.append(f1_score)

Batch Evaluation: 100%|██████████| 1779/1779 [18:58<00:00,  1.56it/s]


#### **QuAC evaluation on NQ**

In [None]:
# Cross-dataset check: score the QuAC model on the NQ validation set.
if train_quac:
  if eval_nq:
    f1_score = SquadEvaluation(quac_model, bert_nq_validation_set)['F1']
  else:
    # Evaluation on NQ is switched off: record None so the list still
    # receives exactly one entry for this model.
    f1_score = None
  f1_scores_on_nq.append(f1_score)

Batch Evaluation: 100%|██████████| 422/422 [03:38<00:00,  1.93it/s]


#### **QuAC evaluation on NewsQA**

In [None]:
# Cross-dataset check: score the QuAC model on the NewsQA validation set.
if train_quac:
  if eval_newsqa:
    f1_score = SquadEvaluation(quac_model, bert_newsqa_validation_set)['F1']
  else:
    # Evaluation on NewsQA is switched off: record None so the list still
    # receives exactly one entry for this model.
    f1_score = None
  f1_scores_on_newsqa.append(f1_score)

Batch Evaluation: 100%|██████████| 646/646 [06:53<00:00,  1.56it/s]


### **Fine tuning on NewsQA**

In [None]:
# Obtain the NewsQA model and its in-domain F1 score.
# With load_cached_models set, the model is built from the
# 'OrfeasTsk/bert-base-uncased-finetuned-newsqa-large-batch' checkpoint name and
# evaluated; otherwise fine-tuning is run here and the resulting model is saved.
if train_newsqa:
  if load_cached_models:
    newsqa_model = BertForSquad('OrfeasTsk/bert-base-uncased-finetuned-newsqa-large-batch').to(device)
    if eval_newsqa:
      f1_score = SquadEvaluation(newsqa_model, bert_newsqa_validation_set)['F1']
    else:
      f1_score = None
  else:
    # NOTE(review): the 'schelduler' spelling is presumably the key that
    # SquadFineTuning looks up - confirm against its definition before renaming.
    params = {
      'batch_size': 24,
      'learning_rate': {'val': 3e-5, 'schelduler': 'Linear'},
      'max_clip_norm': None,
      'epochs': 2,
    }
    newsqa_model, scores = SquadFineTuning(params, bert_newsqa_training_set, bert_newsqa_validation_set)
    if eval_newsqa:
      # Keep only the validation F1 reported for the final epoch.
      f1_score = scores['valid_f1_scores'][-1]
    else:
      f1_score = None
    newsqa_model.save('newsqa_model')

  # Register the model name and its in-domain score together, one entry each.
  fine_tuned_on.append('NewsQA')
  f1_scores_on_newsqa.append(f1_score)

Downloading:   0%|          | 0.00/673 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/415M [00:00<?, ?B/s]

Batch Evaluation: 100%|██████████| 646/646 [06:50<00:00,  1.57it/s]


#### **NewsQA evaluation on SQuAD 2.0**

In [None]:
# Cross-dataset check: score the NewsQA model on the SQuAD 2.0 validation set.
if train_newsqa:
  if eval_squad:
    f1_score = SquadEvaluation(newsqa_model, bert_squad_validation_set)['F1']
  else:
    # Evaluation on SQuAD is switched off: record None so the list still
    # receives exactly one entry for this model.
    f1_score = None
  f1_scores_on_squad.append(f1_score)

Batch Evaluation: 100%|██████████| 1485/1485 [11:32<00:00,  2.15it/s]


#### **NewsQA evaluation on TriviaQA**

In [None]:
# Cross-dataset check: score the NewsQA model on the TriviaQA validation set.
if train_newsqa:
  if eval_triviaqa:
    f1_score = SquadEvaluation(newsqa_model, bert_triviaqa_validation_set)['F1']
  else:
    # Evaluation on TriviaQA is switched off: record None so the list still
    # receives exactly one entry for this model.
    f1_score = None
  f1_scores_on_triviaqa.append(f1_score)

Batch Evaluation: 100%|██████████| 1779/1779 [18:48<00:00,  1.58it/s]


#### **NewsQA evaluation on NQ**

In [None]:
# Cross-dataset check: score the NewsQA model on the NQ validation set.
if train_newsqa:
  if eval_nq:
    f1_score = SquadEvaluation(newsqa_model, bert_nq_validation_set)['F1']
  else:
    # Evaluation on NQ is switched off: record None so the list still
    # receives exactly one entry for this model.
    f1_score = None
  f1_scores_on_nq.append(f1_score)

Batch Evaluation: 100%|██████████| 422/422 [03:36<00:00,  1.95it/s]


#### **NewsQA evaluation on QuAC**

In [None]:
# Cross-dataset check: score the NewsQA model on the QuAC validation set.
if train_newsqa:
  if eval_quac:
    f1_score = SquadEvaluation(newsqa_model, bert_quac_validation_set)['F1']
  else:
    # Evaluation on QuAC is switched off: record None so the list still
    # receives exactly one entry for this model.
    f1_score = None
  f1_scores_on_quac.append(f1_score)

Batch Evaluation: 100%|██████████| 920/920 [09:41<00:00,  1.58it/s]


### **All Scores**

In [None]:
# Summary table of F1 scores: one column per evaluation dataset, one row per
# entry in fine_tuned_on (the datasets the models were fine-tuned on).
f1_columns_by_dataset = {
  'SQuAD': f1_scores_on_squad,
  'TriviaQA': f1_scores_on_triviaqa,
  'NQ': f1_scores_on_nq,
  'QuAC': f1_scores_on_quac,
  'NewsQA': f1_scores_on_newsqa,
}
display(pd.DataFrame(f1_columns_by_dataset, index = fine_tuned_on))

Unnamed: 0,SQuAD,TriviaQA,NQ,QuAC,NewsQA
SQuAD,0.768707,0.453962,0.47152,0.208758,0.34917
TriviaQA,0.406863,0.616261,0.469505,0.198951,0.22297
NQ,0.523388,0.463316,0.741013,0.214633,0.187447
QuAC,0.309943,0.354544,0.336785,0.349126,0.210997
NewsQA,0.373489,0.497677,0.483716,0.201628,0.530305
