In [None]:
!pip install sentence-transformers
!pip install nltk

import nltk
nltk.download('punkt')

## BART QA baseline

In [None]:
import json
import os
from transformers import BartTokenizer, BartForQuestionAnswering
import torch

from tqdm.notebook import tqdm, trange

import sys
sys.path.append(".")
sys.path.append("..") # Adds higher directory to python modules path.
from eval.eval import ClickbaitResolverEvaluator

In [None]:
# Directory holding the input files, read as final_<split>.json.
DATA_PATH = "../data/"
# Directory that receives one <split>.json answer file per processed split.
RESULT_PATH = "../data/baseline_results/bart_qa/"
# Names of the dataset splits to run and evaluate.
ENTRY_SETS = ["train", "dev"]

In [None]:
# Prefer the first CUDA device; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Tokenizer and question-answering model are loaded from the same checkpoint;
# the model is moved to the selected device right after loading.
tokenizer = BartTokenizer.from_pretrained('a-ware/bart-squadv2')
model = BartForQuestionAnswering.from_pretrained('a-ware/bart-squadv2').to(device)

In [None]:
def compute_bart_answer(entries, name, verbose=True):
    """Extract an answer span for every entry with the loaded BART QA model.

    For each entry the ``"title"`` is used as the question and the ``"text"``
    as the context. The tokens between the highest-scoring start position and
    the highest-scoring end position are decoded and normalised (special
    tokens and periods removed, surrounding whitespace stripped).

    Relies on the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        entries: Iterable of dicts providing the keys ``"id"``, ``"title"``
            and ``"text"``.
        name: Label shown on the progress bar.
        verbose: When true (the default, matching the previous behaviour)
            print ``question -> answer`` for every entry.

    Returns:
        A list of ``{"id": ..., "answer": ...}`` dicts, one per entry, in
        input order. An empty prediction is reported as ``'-'``.
    """
    # Hard limit of the model's positional embeddings; longer inputs would
    # fail inside the model, so the tokenizer truncates to this length.
    max_tokens = model.config.max_position_embeddings

    results = []
    for entry in tqdm(entries, desc=name):
        text = entry["text"]
        # startswith() instead of text[0] so an empty text does not raise.
        if text.startswith("."):
            text = text[1:].strip()
        # Character-level cap kept from the original baseline.
        if len(text) > 1024:
            text = text[:1023]

        question = entry["title"]

        # The character cap above does not bound the token count (one
        # character may expand to several byte-level tokens), hence the
        # additional token-level truncation.
        encoding = tokenizer(
            question,
            text,
            return_tensors='pt',
            truncation=True,
            max_length=max_tokens,
        ).to(device)
        input_ids = encoding['input_ids']
        attention_mask = encoding['attention_mask']

        # Pure inference: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            start_scores, end_scores = model(
                input_ids, attention_mask=attention_mask, output_attentions=False
            )[:2]

        start = int(torch.argmax(start_scores))
        end = int(torch.argmax(end_scores))

        # An end position before the start position yields an empty span.
        if end >= start:
            answer = tokenizer.decode(input_ids[0, start:end + 1].tolist())
        else:
            answer = ""

        # Special tokens become spaces (so adjacent words are not glued
        # together); all periods are dropped, including ones inside the answer.
        answer = answer.replace("<s>", " ").replace("</s>", " ").replace(".", "").strip()
        if answer == "":
            answer = '-'

        if verbose:
            print(f"{question} -> {answer}")
        results.append({"id": entry["id"], "answer": answer})
    return results

In [None]:
# Run the baseline on every split and store the answers as JSON.
os.makedirs(RESULT_PATH, exist_ok=True)

for s in ENTRY_SETS:
    # Read the split completely and close the file before the slow inference
    # starts. JSON is read as UTF-8 explicitly instead of relying on the
    # platform's default encoding.
    with open(f"{DATA_PATH}final_{s}.json", "r", encoding="utf-8") as entry_file:
        entries = json.load(entry_file)

    results = compute_bart_answer(entries, s)

    # ensure_ascii=False emits non-ASCII characters verbatim, so the output
    # file must be opened as UTF-8; a locale-dependent default encoding could
    # fail on or mis-encode such characters.
    with open(f"{RESULT_PATH}{s}.json", "w", encoding="utf-8") as result_file:
        json.dump(results, result_file, indent=2, ensure_ascii=False)

In [None]:
# Evaluate the stored answer file of every split against that split's data file
# and print the outcome.
evaluator = ClickbaitResolverEvaluator()

for s in ENTRY_SETS:
    answers_path = f"{RESULT_PATH}{s}.json"
    entries_path = f"{DATA_PATH}final_{s}.json"
    agg_results, results = evaluator.run_file(answers_path, entries_path)
    evaluator.print_results(agg_results, results, False)