# Download TriviaQA

In [None]:
!wget http://nlp.cs.washington.edu/triviaqa/data/triviaqa-rc.tar.gz
!tar xf triviaqa-rc.tar.gz

--2022-03-13 17:39:27--  http://nlp.cs.washington.edu/triviaqa/data/triviaqa-rc.tar.gz
Resolving nlp.cs.washington.edu (nlp.cs.washington.edu)... 128.208.3.120, 2607:4000:200:12::78
Connecting to nlp.cs.washington.edu (nlp.cs.washington.edu)|128.208.3.120|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2665779500 (2.5G) [application/x-gzip]
Saving to: ‘triviaqa-rc.tar.gz’


2022-03-13 17:42:02 (16.4 MB/s) - ‘triviaqa-rc.tar.gz’ saved [2665779500/2665779500]



# Useful functions for converter

In [None]:
import json


def write_json_to_file(json_object, json_file, mode='w', encoding='utf-8'):
    with open(json_file, mode, encoding=encoding) as outfile:
        json.dump(json_object, outfile, indent=4, sort_keys=True, ensure_ascii=False)


def get_file_contents(filename, encoding='utf-8'):
    with open(filename, encoding=encoding) as f:
        content = f.read()
    return content


def read_json(filename, encoding='utf-8'):
    contents = get_file_contents(filename, encoding=encoding)
    return json.loads(contents)


def get_file_contents_as_list(file_path, encoding='utf-8', ignore_blanks=True):
    contents = get_file_contents(file_path, encoding=encoding)
    lines = contents.split('\n')
    lines = [line for line in lines if line != ''] if ignore_blanks else lines
    return lines

In [None]:
import os
from tqdm import tqdm
import random
import nltk
import argparse


def get_text(qad, domain):
    local_file = os.path.join(args.web_dir, qad['Filename']) if domain == 'SearchResults' else os.path.join(args.wikipedia_dir, qad['Filename'])
    return get_file_contents(local_file, encoding='utf-8')


def select_relevant_portion(text):
    paras = text.split('\n')
    selected = []
    done = False
    for para in paras:
        sents = sent_tokenize.tokenize(para)
        for sent in sents:
            words = nltk.word_tokenize(sent)
            for word in words:
                selected.append(word)
                if len(selected) >= args.max_num_tokens:
                    done = True
                    break
            if done:
                break
        if done:
            break
        selected.append('\n')
    st = ' '.join(selected).strip()
    return st


def add_triple_data(datum, page, domain):
    qad = {'Source': domain}
    for key in ['QuestionId', 'Question', 'Answer']:
        qad[key] = datum[key]
    for key in page:
        qad[key] = page[key]
    return qad


def get_qad_triples(data):
    qad_triples = []
    for datum in data['Data']:
        for key in ['EntityPages', 'SearchResults']:
            for page in datum.get(key, []):
                qad = add_triple_data(datum, page, key)
                qad_triples.append(qad)
    return qad_triples


def convert_to_squad_format(qa_json_file, squad_file):
    qa_json = read_triviaqa_data(qa_json_file)
    qad_triples = get_qad_triples(qa_json)

    random.seed(args.seed)
    random.shuffle(qad_triples)

    data = []
    for qad in tqdm(qad_triples):
        qid = qad['QuestionId']

        text = get_text(qad, qad['Source'])
        selected_text = select_relevant_portion(text)

        question = qad['Question']
        para = {'context': selected_text, 'qas': [{'question': question, 'answers': []}]}
        data.append({'paragraphs': [para]})
        qa = para['qas'][0]
        qa['id'] = get_question_doc_string(qid, qad['Filename'])
        qa['qid'] = qid

        ans_string, index = answer_index_in_document(qad['Answer'], selected_text)
        if index == -1:
            if qa_json['Split'] == 'train':
                continue
        else:
            qa['answers'].append({'text': ans_string, 'answer_start': index})

        if qa_json['Split'] == 'train' and len(data) >= args.sample_size and qa_json['Domain'] == 'Web':
            break

    squad = {'data': data, 'version': qa_json['Version']}
    write_json_to_file(squad, squad_file)
    print ('Added', len(data))


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--triviaqa_file', help='Triviaqa file')
    parser.add_argument('--squad_file', help='Squad file')
    parser.add_argument('--wikipedia_dir', help='Wikipedia doc dir')
    parser.add_argument('--web_dir', help='Web doc dir')

    parser.add_argument('--seed', default=10, type=int, help='Random seed')
    parser.add_argument('--max_num_tokens', default=800, type=int, help='Maximum number of tokens from a document')
    parser.add_argument('--sample_size', default=80000, type=int, help='Random seed')
    parser.add_argument('--tokenizer', default='tokenizers/punkt/english.pickle', help='Sentence tokenizer')
    args = parser.parse_args()
    return args

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [None]:
def get_key_to_ground_truth(data):
    if data['Domain'] == 'Wikipedia':
        return {datum['QuestionId']: datum['Answer'] for datum in data['Data']}
    else:
        return get_qd_to_answer(data)


def get_question_doc_string(qid, doc_name):
    return '{}--{}'.format(qid, doc_name)

def get_qd_to_answer(data):
    key_to_answer = {}
    for datum in data['Data']:
        for page in datum.get('EntityPages', []) + datum.get('SearchResults', []):
            qd_tuple = get_question_doc_string(datum['QuestionId'], page['Filename'])
            key_to_answer[qd_tuple] = datum['Answer']
    return key_to_answer


def read_clean_part(datum):
    for key in ['EntityPages', 'SearchResults']:
        new_page_list = []
        for page in datum.get(key, []):
            if page['DocPartOfVerifiedEval']:
                new_page_list.append(page)
        datum[key] = new_page_list
    assert len(datum['EntityPages']) + len(datum['SearchResults']) > 0
    return datum


def read_triviaqa_data(qajson):
    data = read_json(qajson)
    # read only documents and questions that are a part of clean data set
    if data['VerifiedEval']:
        clean_data = []
        for datum in data['Data']:
            if datum['QuestionPartOfVerifiedEval']:
                if data['Domain'] == 'Web':
                    datum = read_clean_part(datum)
                clean_data.append(datum)
        data['Data'] = clean_data
    return data


def answer_index_in_document(answer, document):
    answer_list = answer['NormalizedAliases']
    for answer_string_in_doc in answer_list:
        index = document.lower().find(answer_string_in_doc)
        if index != -1:
            return answer_string_in_doc, index
    return answer['NormalizedValue'], -1

In [None]:
import os
import argparse
import json
import nltk
nltk.download('punkt')


def answer_index_in_document(answer, document):
    answer_list = answer['Aliases'] + answer['NormalizedAliases']
    for answer_string_in_doc in answer_list:
        index = document.find(answer_string_in_doc)
        if index != -1:
            return answer_string_in_doc, index
    return answer['NormalizedValue'], -1


def select_relevant_portion(text):
    paras = text.split('\n')
    selected = []
    done = False
    for para in paras:
        sents = sent_tokenize.tokenize(para)
        for sent in sents:
            words = nltk.word_tokenize(sent)
            for word in words:
                selected.append(word)
                if len(selected) >= 800:
                    done = True
                    break
            if done:
                break
        if done:
            break
        selected.append('\n')
    st = ' '.join(selected).strip()
    return st


def triviaqa_to_squad_format(triviaqa_file, data_dir, output_file):
    triviaqa_json = read_triviaqa_data(triviaqa_file)
    qad_triples = get_qad_triples(triviaqa_json)

    data = []

    for triviaqa_example in qad_triples:
        question_text = triviaqa_example['Question']
        text = get_file_contents(os.path.join(data_dir, triviaqa_example['Filename']), encoding='utf-8')
        context = select_relevant_portion(text)

        para = {'context': context, 'qas': [{'question': question_text, 'answers': []}]}
        data.append({'paragraphs': [para]})
        qa = para['qas'][0]
        qa['id'] = get_question_doc_string(triviaqa_example['QuestionId'], triviaqa_example['Filename'])
        qa['is_impossible'] = True
        ans_string, index = answer_index_in_document(triviaqa_example['Answer'], context)

        if index != -1:
            qa['answers'].append({'text': ans_string, 'answer_start': index})
            qa['is_impossible'] = False

    triviaqa_as_squad = {'data': data, 'version': '2.0'}

    with open(output_file, 'w', encoding='utf-8') as outfile:
        json.dump(triviaqa_as_squad, outfile, indent=2, sort_keys=True, ensure_ascii=False)



sent_tokenize = nltk.data.load('tokenizers/punkt/english.pickle')

triviaqa_to_squad_format("qa/wikipedia-train.json", "evidence/wikipedia/", "triviaqa_train.json")

triviaqa_to_squad_format("qa/wikipedia-dev.json", "evidence/wikipedia/", "triviaqa_dev.json")

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


# Imports, Downloads & Cuda availability | MUST-RUN

In [None]:
!pip install transformers

import torch
from tqdm import tqdm
import pandas as pd
from transformers import BertTokenizer
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import get_linear_schedule_with_warmup
import random
import numpy as np
from sklearn import metrics
from sklearn.metrics import f1_score
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score
from sklearn.metrics import roc_curve, auc
from transformers import AutoTokenizer,AdamW,BertForQuestionAnswering,DistilBertForQuestionAnswering
from transformers import BertForSequenceClassification, AdamW, BertConfig

import string
import nltk
import unicodedata
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk.stem.porter import PorterStemmer
from nltk.stem import WordNetLemmatizer
nltk.download('punkt')
nltk.download('wordnet')
import re
from itertools import cycle

import matplotlib.pyplot as plt

import time
import datetime


if torch.cuda.is_available():    
    device = torch.device("cuda")
    print('Using gpu')
else:
    device = torch.device("cpu")
    print('Using cpu')

import json
from pathlib import Path

!mkdir squad
!wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json -O squad/train-v2.0.json
!wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json -O squad/dev-v2.0.json

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to /usr/share/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
Using gpu
mkdir: cannot create directory ‘squad’: File exists
--2022-03-13 18:28:05--  https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json
Resolving rajpurkar.github.io (rajpurkar.github.io)... 185.199.111.153, 185.199.110.153, 185.199.108.153, ...
Connecting to rajpurkar.github.io (rajpurkar.github.io)|185.199.111.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 42123633 (40M) [application/json]
Saving to: ‘squad/train-v2.0.json’


2022-03-13 18:28:06 (219 MB/s) - ‘squad/train-v2.0.json’ saved [42123633/42123633]

--2022-03-13 18:28:07--  https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json
Resolving rajpurkar.github.io (rajpurkar.github.io)... 185.199.111.153, 185.199.108.153, 185.199.109.153, ...

# Preprocessing & Print functions

In [None]:
def remove_punctuation(text):
    punctuationfree="".join([i for i in text if (i not in string.punctuation)])
    return punctuationfree

def add_prefix_NOT_(text):
  neg_array = ["n’t","n't", "not", "no", "never"]
  final_txt = ""
  flg = 0;
  for i in text.split():
    if flg == 1:
      final_txt = final_txt + " NOT_" + i
    else:
      final_txt = final_txt + " " + i
    if i.endswith(tuple(neg_array)):
      flg = 1
    else:
      flg = 0
  return final_txt

def strip_accents(s):
   return ''.join(c for c in unicodedata.normalize('NFD', s)
                  if unicodedata.category(c) != 'Mn')
   
def remove_emoji(string):
    emoji_pattern = re.compile("["
                               u"\U0001F600-\U0001F64F"  # emoticons
                               u"\U0001F300-\U0001F5FF"  # symbols & pictographs
                               u"\U0001F680-\U0001F6FF"  # transport & map symbols
                               u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
                               u"\U00002500-\U00002BEF"  # chinese char
                               u"\U00002702-\U000027B0"
                               u"\U00002702-\U000027B0"
                               u"\U000024C2-\U0001F251"
                               u"\U0001f926-\U0001f937"
                               u"\U00010000-\U0010ffff"
                               u"\u2640-\u2642"
                               u"\u2600-\u2B55"
                               u"\u200d"
                               u"\u23cf"
                               u"\u23e9"
                               u"\u231a"
                               u"\ufe0f"  # dingbats
                               u"\u3030"
                               "]+", flags=re.UNICODE)
    return emoji_pattern.sub(r'', string)
stemmer = PorterStemmer()
lemmatizer = WordNetLemmatizer()
def stem(text):
    return [stemmer.stem(word) for word in text]
def lem(text):
  return [lemmatizer.lemmatize(word) for word in text]
def split(text):
  return re.split(' ',text)
def stringify(text):
  return ' '.join(text)

class Dataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, id):
        return {key: torch.tensor(val[id]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

def compute_f1(prediction, truth):
    pred_tokens = normalize_text(prediction).split()
    truth_tokens = normalize_text(truth).split()

    # if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise
    if len(pred_tokens) == 0 or len(truth_tokens) == 0:
        return int(pred_tokens == truth_tokens)

    common_tokens = set(pred_tokens) & set(truth_tokens)

    # if there are no common tokens then f1 = 0
    if len(common_tokens) == 0:
        return 0

    prec = len(common_tokens) / len(pred_tokens)
    rec = len(common_tokens) / len(truth_tokens)

    return 2 * (prec * rec) / (prec + rec)

def normalize_text(s):
    import string, re

    def remove_articles(text):
        regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
        return re.sub(regex, " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))

  
def print_function(iters,losses,losses_Val):
  
  plt.title("Learning Curve ")
  plt.plot(iters, losses, label="Train")
  plt.plot(iters, losses_Val, label="Val")
  plt.xlabel("Epoch")
  plt.ylabel("Loss")
  plt.legend(loc='best')
  plt.show()

# Read train & validation files

In [None]:
path_train = Path('triviaqa_train.json')
path_val = Path('triviaqa_dev.json')
path_val_sq = Path('squad/dev-v2.0.json')
with open(path_train, 'rb') as f:
    dictionary_train = json.load(f)
    
con_train, quest_train, ans_train = [],[],[]

for data in dictionary_train['data']:
    for paragraphs in data['paragraphs']:
        context = paragraphs['context']
        for qas in paragraphs['qas']:
            question = qas['question']
            for answers in qas['answers']:
                con_train.append(context)
                quest_train.append(question)
                ans_train.append(answers)


# Fine tuning DistilBert for qa in TriviaQA w/ lr=2e-5 eps=5e-9

In [None]:
torch.cuda.empty_cache()
for answer in ans_train:
    answer['answer_end'] = answer['answer_start'] + len(answer['text'])

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

enc_train = tokenizer(con_train, quest_train, truncation=True,max_length = 350, padding='max_length')

spos_list = []
epos_list = []

for i in range(len(ans_train)):
  spos_list.append(enc_train.char_to_token(i, ans_train[i]['answer_start']))
  epos_list.append(enc_train.char_to_token(i, ans_train[i]['answer_end']))
  if spos_list[-1] is None:
    spos_list[-1] = tokenizer.model_max_length
  if epos_list[-1] is None:
    epos_list[-1] = tokenizer.model_max_length

enc_train.update({'start_positions': spos_list, 'end_positions': epos_list})

train_dataset = Dataset(enc_train)

#Batch size & dataloaders for training loop
batch_size = 4

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


#Using bert-base-uncased BertForSequenceClassification for our model
model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased').to(device)

optimizer = AdamW(model.parameters(),lr = 2e-5,eps = 5e-9)

#For Bert fine tuning we will need to train the pre-trained model for our dataset
epochs = 3

sch_steps = len(train_dataloader) * epochs

scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps = 0,num_training_steps = sch_steps)

losses = []
losses_Val = []

n_epoch=0
iters = []

for curr_epoch in range(0, epochs):
    batch_lo = []
    val_batch_los = []
    

    print(f"\nEpoch: {curr_epoch}")
    print("Training:")
    model.train()

    for batch in tqdm(train_dataloader):

        #Get batche's data
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)

        #Clearning gradients
        model.zero_grad()

        #Run a forward pass
        outputs = model(input_ids, attention_mask=attention_mask, start_positions=start_positions, end_positions=end_positions)

        #Get loss and logits data
        loss = outputs[0]
        # do a backwards pass 
        loss.backward()

        #Clipping that helps for exploding gradients problem
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()

        batch_lo.append(loss.item())

        scheduler.step()


    losses.append(sum(batch_lo)/len(train_dataloader))

    print(f" Total Average loss : {sum(batch_lo)/len(train_dataloader)}")
  


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForQuestionAnswering: ['vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this mode


Epoch: 0
Training:


100%|██████████| 19278/19278 [29:31<00:00, 10.88it/s]


 Total Average loss : 1.425250962463926

Epoch: 1
Training:


100%|██████████| 19278/19278 [29:30<00:00, 10.89it/s]


 Total Average loss : 0.8337224429889346

Epoch: 2
Training:


100%|██████████| 19278/19278 [29:25<00:00, 10.92it/s]

 Total Average loss : 0.5417537616457181





# Validation in TriviaQA

In [None]:
with open(path_val, 'rb') as f:
    data = json.load(f)
    
model.eval()
f1_list = []
iters = 0

for group in tqdm(data['data']):
    for passage in group['paragraphs']:
        context = passage['context']
        for qa in passage['qas']:
            question = qa['question']
            for answer in qa['answers']:
                enc = tokenizer.encode_plus(context,question,truncation=True, return_tensors = 'pt')
                enc.to(device)
                outputs = model(**enc)
                start = torch.argmax(outputs.start_logits)
                end = torch.argmax(outputs.end_logits) + 1
                answer_pred = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(enc['input_ids'][0][start:end]))
                f1 = compute_f1(answer_pred,answer['text'])
                f1_list.append(f1)
                iters+=1
                    

print(sum(f1_list)/iters)

100%|██████████| 14229/14229 [02:28<00:00, 95.92it/s] 

0.23284294569437827





# Validation in SQuAD 2.0

In [None]:
with open(path_val_sq, 'rb') as f:
    data = json.load(f)
    
model.eval()
f1_list = []
iters = 0

for group in tqdm(data['data']):
    for passage in group['paragraphs']:
        context = passage['context']
        for qa in passage['qas']:
            question = qa['question']
            for answer in qa['answers']:
                enc = tokenizer.encode_plus(context,question,truncation=True, return_tensors = 'pt')
                enc.to(device)
                outputs = model(**enc)
                start = torch.argmax(outputs.start_logits)
                end = torch.argmax(outputs.end_logits) + 1
                answer_pred = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(enc['input_ids'][0][start:end]))
                f1 = compute_f1(answer_pred,answer['text'])
                f1_list.append(f1)
                iters+=1
                    

print(sum(f1_list)/iters)

100%|██████████| 35/35 [02:24<00:00,  4.13s/it]

0.18114353810558886



