# Final Evaluation

**Student Credentials:** sdi1800119, Vissarion Moutafis

In this notebook we use the trained models, saved on the Hugging Face Hub, and the datasets, stored as Kaggle datasets, in an attempt to replicate the results of the given paper. Every model is evaluated on every dataset, so the output below covers both in-domain and cross-domain performance.

We apply the same preprocessing as in the training process and remove all samples whose question is too long to fit in a single input window (the filter below compares the tokenized question against MAX_LENGTH).
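Long contexts are not dropped; the tokenizer splits them into overlapping windows. The following is a minimal sketch of that windowing idea on a plain list of tokens. It assumes the stride is the number of tokens shared by consecutive windows, and it ignores the question and special tokens that the real tokenizer also packs into each window. The function name `sliding_windows` is hypothetical and used only for illustration.

```python
def sliding_windows(context_tokens, window, stride):
    """Split a token list into windows of at most `window` tokens,
    where consecutive windows share `stride` tokens (illustrative only)."""
    if stride >= window:
        raise ValueError("stride must be smaller than window")
    step = window - stride  # how far each new window advances
    windows = []
    start = 0
    while True:
        windows.append(context_tokens[start:start + window])
        # stop once the current window reaches the end of the context
        if start + window >= len(context_tokens):
            break
        start += step
    return windows


# Toy example: 10 tokens, windows of 4 tokens that overlap by 2
tokens = list(range(10))
chunks = sliding_windows(tokens, window=4, stride=2)
for chunk in chunks:
    print(chunk)
```

Each sample can therefore produce several input windows, which is why the preprocessing keeps a mapping from windows back to the original samples.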

In [None]:
import pandas as pd
import numpy as np 
import matplotlib.pyplot as plt
import seaborn
import re
seaborn.set_style("ticks")

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc, classification_report

import torch
import torch.nn as nn
import torchtext
from torch.utils.data import SubsetRandomSampler
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

!pip install transformers datasets
!apt install git-lfs
import transformers
import datasets
from transformers import AutoModel, BertTokenizer, BertForSequenceClassification, AutoModelForSequenceClassification, AutoModelForQuestionAnswering
from datasets import load_dataset, load_metric

!pip install tqdm
from tqdm import tqdm, trange

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

Collecting datasets
  Downloading datasets-1.18.4-py3-none-any.whl (312 kB)
     |████████████████████████████████| 312 kB 779 kB/s            
Collecting xxhash
  Downloading xxhash-3.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)
     |████████████████████████████████| 212 kB 10.3 MB/s            
Collecting responses<0.19
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Installing collected packages: xxhash, responses, datasets
Successfully installed datasets-1.18.4 responses-0.18.0 xxhash-3.0.0



The following NEW packages will be installed:
  git-lfs
0 upgraded, 1 newly installed, 0 to remove and 42 not upgraded.
Need to get 3316 kB of archives.
After this operation, 11.1 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu focal/universe amd64 git-lfs amd64 2.9.2-1 [3316 kB]
Fetched 3316 kB in 2s (2155 kB/s)

7[0;23r8[1ASelecting previously unselected package git-lfs.
(Reading database ... 103272

## Load Datasets

In [None]:
token='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'

In [None]:
squad_dt = load_dataset('squad_v2', split='validation')

Downloading and preparing dataset squad_v2/squad_v2 (download: 44.34 MiB, generated: 122.41 MiB, post-processed: Unknown size, total: 166.75 MiB) to /root/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d...


Dataset squad_v2 downloaded and prepared to /root/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d. Subsequent calls will reuse this data.


In [None]:
triviaqa_dt = load_dataset('../input/squadlikeloader/squad_like.py', data_files={'validation':'../triviaqatosquad/triviaqa_dev.json', 'train':'../triviaqatosquad/triviaqa_dev.json'}, split='validation')

Downloading and preparing dataset squad_like/default to /root/.cache/huggingface/datasets/squad_like/default-7731d89230024ff0/0.0.0/c11bde73ef00f53b085b6a086d13514938f65b80af061fc874ce3e7514c24892...


Dataset squad_like downloaded and prepared to /root/.cache/huggingface/datasets/squad_like/default-7731d89230024ff0/0.0.0/c11bde73ef00f53b085b6a086d13514938f65b80af061fc874ce3e7514c24892. Subsequent calls will reuse this data.


In [None]:
nq_dt= load_dataset('../input/squadlikeloader/squad_like.py', data_files={'validation':'../nqtosquad/nq_dev.json', 'train':'../nqtosquad/nq_dev.json'}, split='validation')

Downloading and preparing dataset squad_like/default to /root/.cache/huggingface/datasets/squad_like/default-c80748befeddccb3/0.0.0/c11bde73ef00f53b085b6a086d13514938f65b80af061fc874ce3e7514c24892...


Dataset squad_like downloaded and prepared to /root/.cache/huggingface/datasets/squad_like/default-c80748befeddccb3/0.0.0/c11bde73ef00f53b085b6a086d13514938f65b80af061fc874ce3e7514c24892. Subsequent calls will reuse this data.


In [None]:
quac_dt = load_dataset('../input/squadlikeloader/squad_like.py', data_files={'validation':'../quactosquad/quac_val.json', 'train':'../quactosquad/quac_val.json'}, split='validation')

Downloading and preparing dataset squad_like/default to /root/.cache/huggingface/datasets/squad_like/default-6f35545f73d12f1c/0.0.0/c11bde73ef00f53b085b6a086d13514938f65b80af061fc874ce3e7514c24892...


Dataset squad_like downloaded and prepared to /root/.cache/huggingface/datasets/squad_like/default-6f35545f73d12f1c/0.0.0/c11bde73ef00f53b085b6a086d13514938f65b80af061fc874ce3e7514c24892. Subsequent calls will reuse this data.


In [None]:
newsqa_dt = load_dataset('../input/squadlikeloader/squad_like.py', data_files={'validation':'../newsqatosquad/newsqa_dev.json', 'train':'../newsqatosquad/newsqa_dev.json'}, split='validation')

Downloading and preparing dataset squad_like/default to /root/.cache/huggingface/datasets/squad_like/default-440ce43c40204aae/0.0.0/c11bde73ef00f53b085b6a086d13514938f65b80af061fc874ce3e7514c24892...


Dataset squad_like downloaded and prepared to /root/.cache/huggingface/datasets/squad_like/default-440ce43c40204aae/0.0.0/c11bde73ef00f53b085b6a086d13514938f65b80af061fc874ce3e7514c24892. Subsequent calls will reuse this data.


In [None]:
eval_datasets = {
    'SQuADv2' : squad_dt, 
    'TriviaQA' : triviaqa_dt,
    'NQ' : nq_dt,
    'QuAC' : quac_dt,
    'NewsQA' : newsqa_dt,
}

## Preprocess datasets

In [None]:
from transformers import AutoTokenizer
# initialize the tokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
MAX_LENGTH = 384
DOC_STRIDE = 128 # number of overlapping tokens between consecutive windows of a long context


In [None]:
def preprocess_squad(examples):
  # get the questions and the context
  questions = [q.strip() for q in examples["question"]]
  context = examples["context"]
  # tokenize questions along with the context 
  inputs = tokenizer(
        questions,
        context,
        max_length=MAX_LENGTH,
        stride=DOC_STRIDE,
        truncation="only_second",
        padding="max_length",
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
    )
  offset_mapping = inputs.pop("offset_mapping")
  sample_mapping = inputs.pop("overflow_to_sample_mapping")
  answers = examples["answers"]
  start_positions = []
  end_positions = []
  
  for i, offset in enumerate(offset_mapping):
    sample_index = sample_mapping[i]
    answer = examples["answers"][sample_index]
    # if there is no answer default to [CLS]
    if not answer["answer_start"]:
      start_positions.append(inputs['input_ids'][i].index(tokenizer.cls_token_id))
      end_positions.append(inputs['input_ids'][i].index(tokenizer.cls_token_id))
      continue 
    
    # get answer start and end positions
    start_char = answer["answer_start"][0]
    end_char = answer["answer_start"][0] + len(answer["text"][0])
    sequence_ids = inputs.sequence_ids(i)

    # Find the start and end of the context
    idx = 0
    while sequence_ids[idx] != 1:
      idx += 1
    context_start = idx
    while sequence_ids[idx] == 1:
      idx += 1
    context_end = idx - 1

    # If the answer lies entirely outside this context window, label it with the [CLS] token index
    if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
      start_positions.append(inputs['input_ids'][i].index(tokenizer.cls_token_id))
      end_positions.append(inputs['input_ids'][i].index(tokenizer.cls_token_id))
    else:
      # Otherwise it's the start and end token positions
      idx = context_start
      while idx <= context_end and offset[idx][0] <= start_char:
        idx += 1
      start_positions.append(idx - 1)

      idx = context_end
      while idx >= context_start and offset[idx][1] >= end_char:
        idx -= 1
      end_positions.append(idx + 1)

  inputs["start_positions"] = start_positions
  inputs["end_positions"] = end_positions
  return inputs
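The character-to-token mapping used above can be illustrated on a toy example. This is a minimal sketch, assuming whitespace tokens instead of BERT word pieces; the helper `char_span_to_token_span` is hypothetical. Note that the sketch rejects any answer not fully contained in the window, which is an assumption of this example and stricter than the overlap test in the function above.

```python
def char_span_to_token_span(offsets, start_char, end_char):
    """offsets: ordered (char_start, char_end) pairs, one per context token.
    Returns inclusive (start_token, end_token), or None when the answer
    is not fully contained in the window (assumption of this sketch)."""
    if offsets[0][0] > start_char or offsets[-1][1] < end_char:
        return None
    # walk forward past every token that starts at or before the answer start
    idx = 0
    while idx < len(offsets) and offsets[idx][0] <= start_char:
        idx += 1
    start_token = idx - 1
    # walk backward past every token that ends at or after the answer end
    idx = len(offsets) - 1
    while idx >= 0 and offsets[idx][1] >= end_char:
        idx -= 1
    end_token = idx + 1
    return start_token, end_token


# Build character offsets for whitespace tokens of a toy context
context = "BERT was released in 2018 by Google"
offsets, pos = [], 0
for word in context.split():
    start = context.index(word, pos)
    offsets.append((start, start + len(word)))
    pos = start + len(word)

answer = "2018"
start_char = context.index(answer)
span = char_span_to_token_span(offsets, start_char, start_char + len(answer))
print(span, context.split()[span[0]:span[1] + 1])
```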

In [None]:
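def appropriate_length(q, c):
    # keep only samples whose tokenized question fits in MAX_LENGTH tokens;
    # the context `c` is unused because long contexts are split into windows
    tq = tokenizer(q)['input_ids']
    return len(tq) <= MAX_LENGTH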


In [None]:
for dt_id in eval_datasets:
    eval_datasets[dt_id] = eval_datasets[dt_id].filter(appropriate_length, input_columns=['question', 'context']) # make sure that the questions cannot get segmented
    eval_datasets[dt_id] = eval_datasets[dt_id].map(preprocess_squad, batched=True, remove_columns=eval_datasets[dt_id].column_names) # preprocess the dataset


In [None]:
eval_loaders = {
    key : torch.utils.data.DataLoader(value, batch_size=24, shuffle=True) for key, value in eval_datasets.items()
}

## Evaluation Phase

We use the default metrics for the evaluation of SQuAD-formatted datasets from the Hugging Face Hub, loaded with the load_metric() routine.
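For intuition, the two reported scores can be sketched in plain Python. This is a simplified version in the spirit of the official SQuAD evaluation script (lowercasing, then stripping punctuation and articles, then comparing tokens); the metric loaded from the hub is what the notebook actually uses, and the function names here are only illustrative. The values are fractions in [0, 1], whereas the hub metric reports percentages.

```python
import re
import string
from collections import Counter


def normalize(text):
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(pred, truth):
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize(pred) == normalize(truth))


def f1_score(pred, truth):
    """Token-level F1 between the normalized prediction and reference."""
    pred_tokens = normalize(pred).split()
    true_tokens = normalize(truth).split()
    # no-answer case: both empty counts as a match, one empty as a miss
    if not pred_tokens or not true_tokens:
        return float(pred_tokens == true_tokens)
    overlap = sum((Counter(pred_tokens) & Counter(true_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(true_tokens)
    return 2 * precision * recall / (precision + recall)


print(exact_match("The Eiffel Tower.", "eiffel tower"))
print(f1_score("the Eiffel Tower", "Eiffel Tower in Paris"))
```

A partially correct span thus earns F1 credit without an exact match, which is why F1 is never lower than exact match in the results.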

In [None]:
for model_name in ['bert-on-squad2', 'bert-finetuned-triviaqa', 'bert-finetuned-nq', 'bert-finetuned-quac', 'bert-finetuned-newsqa']:
    model = AutoModelForQuestionAnswering.from_pretrained('vissa/'+model_name, use_auth_token=token)
    model.eval()
    model.to(device)
    metric = load_metric('squad_v2')
    for test_set, test_loader in eval_loaders.items():
        f1_total = []
        acc = []
        test_loss = []
        pbar = tqdm(test_loader)
        for i,batch in enumerate(pbar):
          torch.cuda.empty_cache()
          with torch.no_grad():
            args = {
                  "start_positions" : torch.LongTensor(batch["start_positions"]).to(device),
                  "end_positions" : torch.LongTensor(batch["end_positions"]).to(device),
                  "input_ids" : torch.stack(batch["input_ids"], axis=1).to(device),
                  "attention_mask" : torch.stack(batch["attention_mask"], axis=1).to(device)
                }
            outputs = model(**args)
            test_loss.append(outputs[0].item())
            # predicted start and end token index for every example of the batch
            start_pred = torch.argmax(outputs['start_logits'], dim=1)
            end_pred = torch.argmax(outputs['end_logits'], dim=1)

            # decode the predicted and gold token spans back to text; the metric then computes exact match and F1 on these strings
            pred = [tokenizer.decode(input_ids[s:e+1].tolist()) for input_ids, (s, e) in zip(args['input_ids'], zip(start_pred, end_pred))]
            true = [tokenizer.decode(input_ids[s:e+1].tolist()) for input_ids, (s, e) in zip(args['input_ids'], zip(args['start_positions'], args['end_positions']))]
            # accumulate this batch's predictions and references for the squad_v2 metric
            metric.add_batch(predictions=[{'id':24*i+id, 'prediction_text':text, 'no_answer_probability':0.0} for id,text in enumerate(pred)], 
                             references=[{'id':24*i+id, 'answers':{'text':[text], 'answer_start':[args['start_positions'][id]]}} for id,text in enumerate(true)])
        cur = test_set,metric.compute()
        print('{} on {}: exact:{}, f1:{}'.format(model_name, test_set, cur[1]['exact'], cur[1]['f1']))

100%|██████████| 506/506 [03:06<00:00,  2.71it/s]


bert-on-squad2 on SQuADv2: exact:64.58711059831877, f1:70.54734852507852


100%|██████████| 2222/2222 [13:09<00:00,  2.82it/s]


bert-on-squad2 on TriviaQA: exact:63.72824782478248, f1:64.83328227310953


100%|██████████| 143/143 [00:50<00:00,  2.82it/s]


bert-on-squad2 on NQ: exact:31.673014942865514, f1:38.958448587651986


100%|██████████| 307/307 [01:49<00:00,  2.80it/s]


bert-on-squad2 on QuAC: exact:33.763937992929016, f1:36.02606995687384


100%|██████████| 707/707 [04:11<00:00,  2.81it/s]


bert-on-squad2 on NewsQA: exact:55.569963456324416, f1:58.96617328644258


100%|██████████| 506/506 [03:01<00:00,  2.80it/s]


bert-finetuned-triviaqa on SQuADv2: exact:36.12164166804022, f1:39.30426229875678


100%|██████████| 2222/2222 [13:12<00:00,  2.80it/s]


bert-finetuned-triviaqa on TriviaQA: exact:71.43339333933393, f1:72.28513363883935


100%|██████████| 143/143 [00:51<00:00,  2.80it/s]


bert-finetuned-triviaqa on NQ: exact:33.78259595663639, f1:38.46840835592722


100%|██████████| 307/307 [01:49<00:00,  2.80it/s]


bert-finetuned-triviaqa on QuAC: exact:55.085667663856405, f1:55.7392841931602


100%|██████████| 707/707 [04:12<00:00,  2.80it/s]


bert-finetuned-triviaqa on NewsQA: exact:57.10833431569021, f1:59.07536924605947


100%|██████████| 506/506 [03:01<00:00,  2.79it/s]


bert-finetuned-nq on SQuADv2: exact:40.843909675292565, f1:44.44707749685987


100%|██████████| 2222/2222 [13:17<00:00,  2.79it/s]


bert-finetuned-nq on TriviaQA: exact:54.911116111611165, f1:56.44004324960148


100%|██████████| 143/143 [00:51<00:00,  2.79it/s]


bert-finetuned-nq on NQ: exact:53.472018751831236, f1:61.2643776520415


100%|██████████| 307/307 [01:51<00:00,  2.76it/s]


bert-finetuned-nq on QuAC: exact:59.260266521620885, f1:59.96489244864429


100%|██████████| 707/707 [04:14<00:00,  2.77it/s]


bert-finetuned-nq on NewsQA: exact:57.33820582341153, f1:59.331982155382995


100%|██████████| 506/506 [03:02<00:00,  2.77it/s]


bert-finetuned-quac on SQuADv2: exact:27.385857919894512, f1:28.41310490914049


100%|██████████| 2222/2222 [13:14<00:00,  2.80it/s]


bert-finetuned-quac on TriviaQA: exact:56.85756075607561, f1:56.98147817316199


100%|██████████| 143/143 [00:51<00:00,  2.80it/s]


bert-finetuned-quac on NQ: exact:14.38617052446528, f1:16.59844053954373


100%|██████████| 307/307 [01:50<00:00,  2.79it/s]


bert-finetuned-quac on QuAC: exact:50.53032363339679, f1:53.979574641934704


100%|██████████| 707/707 [04:13<00:00,  2.79it/s]


bert-finetuned-quac on NewsQA: exact:48.85064246139338, f1:49.16474056003159


100%|██████████| 506/506 [03:01<00:00,  2.79it/s]


bert-finetuned-newsqa on SQuADv2: exact:22.03725070051096, f1:31.61951532935127


100%|██████████| 2222/2222 [13:16<00:00,  2.79it/s]


bert-finetuned-newsqa on TriviaQA: exact:5.683693369336933, f1:7.964988250103812


100%|██████████| 143/143 [00:51<00:00,  2.77it/s]


bert-finetuned-newsqa on NQ: exact:27.746850278347495, f1:39.20854602167326


100%|██████████| 307/307 [01:50<00:00,  2.77it/s]


bert-finetuned-newsqa on QuAC: exact:0.013598041881968996, f1:4.149837385465861


100%|██████████| 707/707 [04:13<00:00,  2.79it/s]


bert-finetuned-newsqa on NewsQA: exact:16.45644229635742, f1:22.056876688079935
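
The printed result lines can be collected into a nested dictionary for easier comparison. The following is a minimal sketch, assuming the output was captured as text in the exact format printed by the evaluation loop; `parse_results` is a hypothetical helper, and the sample log contains two lines copied from the output above.

```python
import re

# matches lines such as "<model> on <dataset>: exact:<number>, f1:<number>"
LINE_RE = re.compile(
    r"^(?P<model>\S+) on (?P<dataset>\S+): "
    r"exact:(?P<exact>[\d.]+), f1:(?P<f1>[\d.]+)$"
)


def parse_results(log_text):
    """Return {model: {dataset: (exact, f1)}} from captured output text."""
    table = {}
    for line in log_text.splitlines():
        match = LINE_RE.match(line.strip())
        if match:  # progress bars and blank lines are skipped
            scores = (float(match["exact"]), float(match["f1"]))
            table.setdefault(match["model"], {})[match["dataset"]] = scores
    return table


log = """
some unrelated progress line
bert-on-squad2 on SQuADv2: exact:64.58711059831877, f1:70.54734852507852
bert-on-squad2 on TriviaQA: exact:63.72824782478248, f1:64.83328227310953
"""

results = parse_results(log)
for model, row in results.items():
    for dataset, (em, f1) in row.items():
        print(f"{model:<16}{dataset:<10}EM={em:6.2f}  F1={f1:6.2f}")
```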
