In [None]:
import datasets
from sklearn.metrics import classification_report
from tqdm import tqdm
from transformers import pipeline
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM


In [None]:
DATASET_DIR = "./"  # directory previously written with `save_to_disk`; adjust as needed
input_dataset = datasets.load_from_disk(DATASET_DIR)
# Later cells index `input_dataset['test'][i]`, so fail fast if the saved object is
# not a DatasetDict that contains a test split.
assert isinstance(input_dataset, datasets.DatasetDict) and "test" in input_dataset, \
    "expected a DatasetDict with a 'test' split"

In [None]:
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
model_name = "vblagoje/bart_lfqa"
# Seq2seq generator used by get_answer; the tokenizer must come from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

In [None]:
# Place the zero-shot NLI model on the same device as the QA generator. Without an
# explicit `device=`, the pipeline's placement is not tied to `device` (older
# transformers releases default to CPU), which makes the per-question
# classification calls very slow.
classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    device=device,
)

In [None]:
# Question templates for the QA model; each "{}" pair is filled with the
# (head entity, tail entity) mentions of a relation.
questions = [
    "Is the {} involved in the development or progression of {}?",
    "Does the {} have a known association with the {}?",
    "Are there any studies that suggest a connection between the {} and the {}?",
]

# Zero-shot labels whose probability mass counts towards "related".
positive_candidates = [
    "{} is strongly implicated in the development or progression of {}",
    "{} has a moderate association with the {}",
]

# Zero-shot labels whose probability mass counts towards "not related".
negative_candidates = [
    "The relationship between {} and {} is uncertain or unclear",
    "{} has no known connection to the {}",
]

In [None]:
def get_answer(model, tokenizer, question, context):
    """Generate an answer to ``question`` conditioned on ``context``.

    Builds the ``"question: ... context: <P> ..."`` prompt and decodes it with
    deterministic beam search (8 beams, 64-256 tokens, no repeated trigrams).

    Args:
        model: seq2seq model exposing ``generate`` and ``device`` (the model is
            expected to already live on its target device).
        tokenizer: tokenizer matching ``model``.
        question: question text.
        context: a single passage (``str``) or an iterable of passages; every
            passage is prefixed with the ``<P>`` separator.

    Returns:
        list[str]: decoded answers (a single element, since
        ``num_return_sequences=1``).
    """
    # Accept a lone string as well as a list of passages; iterating a str directly
    # would split it into characters.
    passages = [context] if isinstance(context, str) else list(context)
    conditioned_doc = "<P> " + " <P> ".join(passages)
    query_and_docs = "question: {} context: {}".format(question, conditioned_doc)

    model_input = tokenizer(query_and_docs, truncation=True, padding=True, return_tensors="pt")

    # Send inputs to wherever the model lives rather than relying on a global `device`.
    target_device = model.device
    generated_answers_encoded = model.generate(
        input_ids=model_input["input_ids"].to(target_device),
        attention_mask=model_input["attention_mask"].to(target_device),
        min_length=64,
        max_length=256,
        do_sample=False,
        early_stopping=True,
        num_beams=8,
        # Sampling knobs below are inert because do_sample=False; kept as-is.
        temperature=1.0,
        top_k=None,
        top_p=None,
        eos_token_id=tokenizer.eos_token_id,
        no_repeat_ngram_size=3,
        num_return_sequences=1,
    )
    return tokenizer.batch_decode(
        generated_answers_encoded,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

In [None]:
def process_candidates(positive_candidates, negative_candidates, zh_res):
    """Collapse zero-shot label scores into ``[positive_prob, negative_prob]``.

    Args:
        positive_candidates: label strings whose scores count as positive.
        negative_candidates: label strings whose scores count as negative.
        zh_res: zero-shot output with parallel ``"labels"`` and ``"scores"`` lists.

    Returns:
        list: ``[sum of positive scores, sum of negative scores]``. A label present
        in both collections counts as positive only; labels in neither are ignored.
    """
    scored = list(zip(zh_res["labels"], zh_res["scores"]))
    pos_total = sum(score for lbl, score in scored if lbl in positive_candidates)
    neg_total = sum(
        score
        for lbl, score in scored
        if lbl not in positive_candidates and lbl in negative_candidates
    )
    return [pos_total, neg_total]

In [None]:
def _entity_mentions(row):
    """Map each normalized entity id in ``row`` to ``"<surface text> <entity type>"``."""
    text = row['text']
    # As before, a norm that appears more than once keeps its last mention.
    return {
        norm: f"{text[span[0]:span[1]]} {ner_type}"
        for norm, span, ner_type in zip(row['ner_norms'], row['spans'], row['ner_labels'])
    }


def classify_answer(dataset, i, pipe, clf, questions, positive_candidates, negative_candidates):
    """Score every annotated relation of test example ``i``.

    For each relation, each question template is filled with the two entity
    mentions and answered by the QA model via ``get_answer``. The answer is then
    classified by ``clf`` against the filled positive + negative candidate labels.

    Args:
        dataset: mapping with a ``'test'`` split whose rows provide ``text``,
            ``ner_norms``, ``spans``, ``ner_labels``, ``relations`` and
            ``relations_labels``.
        i: index of the test row.
        pipe: seq2seq QA model passed to ``get_answer``.
        clf: zero-shot classification callable.
        questions: question templates with two ``{}`` slots.
        positive_candidates / negative_candidates: label templates with two ``{}`` slots.

    Returns:
        (y_preds, ys): ``y_preds`` holds, per relation, one
        ``[positive_prob, negative_prob]`` pair per question; ``ys`` holds the
        integer gold labels.

    Note:
        Uses the notebook-global ``tokenizer`` for the QA model.
    """
    # Fetch the row once; each `dataset['test'][i]` lookup re-reads and re-formats it.
    row = dataset['test'][i]
    norm2text = _entity_mentions(row)

    y_preds, ys = [], []
    for rel, label in zip(row['relations'], row['relations_labels']):
        head, tail = norm2text[rel[0]], norm2text[rel[1]]
        # Fill the label templates once and reuse them both as the zero-shot label
        # set and for the positive/negative bookkeeping.
        pos_labels = [c.format(head, tail) for c in positive_candidates]
        neg_labels = [c.format(head, tail) for c in negative_candidates]
        candidates = pos_labels + neg_labels

        probs_per_question = []
        for question in questions:
            qa_res = get_answer(pipe, tokenizer, question.format(head, tail), row['text'])
            zh_res = clf(qa_res[0], candidates, multi_label=False)
            probs_per_question.append(process_candidates(pos_labels, neg_labels, zh_res))
        y_preds.append(probs_per_question)
        ys.append(int(label))
    return y_preds, ys

In [None]:
def predict_data(dataset, num, pipe, clf, questions, positive_candidates, negative_candidates):
    """Run ``classify_answer`` over test rows ``0 .. num-1`` and flatten the results.

    Returns:
        (y_preds, ys): per-relation question probabilities and gold labels,
        concatenated across all processed rows in row order.
    """
    all_preds, all_labels = [], []
    for row_idx in tqdm(range(num)):
        row_preds, row_labels = classify_answer(
            dataset, row_idx, pipe, clf, questions, positive_candidates, negative_candidates
        )
        all_preds += row_preds
        all_labels += row_labels
    return all_preds, all_labels

In [None]:
# `input_dataset` is a DatasetDict, so len(input_dataset) counts *splits*, not
# examples. Iterate over the test split itself; set N_EXAMPLES to a small int for a
# quick smoke run.
N_EXAMPLES = None
num_test = len(input_dataset['test'])
num_examples = num_test if N_EXAMPLES is None else min(N_EXAMPLES, num_test)

y_preds, ys = predict_data(
    input_dataset, num_examples,
    model, classifier,
    questions,
    positive_candidates, negative_candidates,
)

In [None]:
# Each entry of y_preds holds one [positive_prob, negative_prob] pair per question.
# Predict "related" when the mean positive probability over *all* questions exceeds
# the threshold. The previous expression summed only the first two questions,
# discarding the third question's (fully computed) scores. For two questions this
# rule is identical to the old `sum > 1`.
POSITIVE_THRESHOLD = 0.5
y_pred_labels = [
    int(sum(pos for pos, _neg in per_question) / max(len(per_question), 1) > POSITIVE_THRESHOLD)
    for per_question in y_preds
]
print(classification_report(ys, y_pred_labels))