In [2]:
import json
import torch
import transformers
import numpy as np
from tqdm import tqdm
from transformers import pipeline
from transformers import AutoTokenizer, BertForQuestionAnswering
transformers.logging.set_verbosity_error()

In [None]:
# Replace this with your own checkpoint
model_checkpoint = "exp/test_0"
# model_checkpoint = "deepset/bert-base-uncased-squad2"
qa_pipeline = pipeline("question-answering", model=model_checkpoint, tokenizer=AutoTokenizer.from_pretrained("bert-base-uncased"))

In [3]:
with open("./data/squad_v2/raw/dev-v2.0.json", "r") as source_file:
    raw_data = json.load(source_file)["data"]

In [4]:
model = BertForQuestionAnswering.from_pretrained('deepset/bert-base-uncased-squad2').to('cuda:1')
tokenizer = AutoTokenizer.from_pretrained("deepset/bert-base-uncased-squad2")

def get_answer(question, context):
    inputs = tokenizer(question, context, max_length=384, padding="max_length", truncation="only_second", return_tensors="pt").to("cuda:1")
    with torch.no_grad():
        outputs = model(**inputs)

    answer_start_index = outputs.start_logits.argmax()
    answer_end_index = outputs.end_logits.argmax()

    predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
    answer = tokenizer.decode(predict_answer_tokens)
    return answer

In [5]:
answer_dict = {}
for article in tqdm(raw_data):
        title = article["title"]
        paragraphs = article["paragraphs"]
        for paragraph in paragraphs:
            context = paragraph["context"]
            qas = paragraph["qas"]
            for qa in qas:
                qid = qa["id"]
                question = qa["question"]
                # answer = qa_pipeline(question=question, context=context, handle_impossible_answer=True)
                # answer_dict[qid] = answer["answer"]
                answer = get_answer(question=question, context=context)
                answer_dict[qid] = answer

100%|██████████| 35/35 [01:46<00:00,  3.05s/it]


In [None]:
with open("./data/squad_v2/processed/pred.json", "w") as wf:
    wf.write(json.dumps(answer_dict))

In [1]:
!python squad_eval.py ./data/squad_v2/raw/dev-v2.0.json ./data/squad_v2/processed/pred_squad.json 

{
  "exact": 22.732249642044977,
  "f1": 27.523669015646554,
  "total": 11873,
  "HasAns_exact": 32.05128205128205,
  "HasAns_f1": 41.647861373612024,
  "HasAns_total": 5928,
  "NoAns_exact": 13.43986543313709,
  "NoAns_f1": 13.43986543313709,
  "NoAns_total": 5945
}
