# setup

In [1]:
%%capture
!pip install transformers

In [2]:
import json
import pickle
import torch
import numpy as np
from transformers import pipeline
from tqdm import tqdm

In [3]:
# Pick the first CUDA GPU when one is visible to torch; otherwise run on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print(device)

cpu


In [4]:
%%capture
nli = pipeline(model="facebook/bart-large-mnli", device=device)

In [5]:
# JSON inputs are decoded explicitly as UTF-8 so loading does not depend on the
# platform's default locale encoding.

# Looked up by query string in `answer_query`.
with open("aspects_cache.json", "r", encoding="utf-8") as f:
    aspects_cache = json.load(f)

# Looked up by aspect string in `negate_aspect`.
with open("negations_cache.json", "r", encoding="utf-8") as f:
    negations_cache = json.load(f)

# The Recipe-MPR rows (each with "query", "options" and "answer" keys).
with open("Recipe-MPR.json", "r", encoding="utf-8") as f:
    data = json.load(f)

# Entailment scores keyed by (premise, hypothesis) tuples; see `score_entailment`.
# SECURITY: unpickling can execute arbitrary code embedded in the file, so only
# load a cache file that you produced yourself.
with open("entailment_cache_2.pkl", "rb") as f:
    entailment_cache_2 = pickle.load(f)

# code

In [6]:
def iter_recipe_mpr(path="Recipe-MPR.json"):
    """Yield ``(query, options, answer)`` for every row of a Recipe-MPR JSON file.

    Parameters
    ----------
    path : str, optional
        JSON file to read. Defaults to ``"Recipe-MPR.json"``, the path that
        used to be hard-coded, so zero-argument calls behave as before.

    Yields
    ------
    tuple
        ``query`` is ``row["query"]``; ``options`` is the values view of the
        ``row["options"]`` dict; ``answer`` is the value stored in
        ``row["options"]`` under the key named by ``row["answer"]``.

    Raises
    ------
    KeyError
        If a row lacks ``"query"``, ``"options"`` or ``"answer"``, or if
        ``row["answer"]`` is not a key of ``row["options"]``.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        rows = json.load(f)

    for row in rows:
        query = row["query"]
        options = row["options"]

        yield query, options.values(), options[row["answer"]]

In [7]:
def score_entailment(premise, hypothesis):
    """Return the NLI score for ``premise`` => ``hypothesis``, cached on disk.

    The score is looked up in ``entailment_cache_2`` under the
    ``(premise, hypothesis)`` key. On a miss, ``nli(premise, hypothesis)`` is
    called, the first entry of its ``"scores"`` list is stored in the cache,
    and the whole cache is re-pickled to ``entailment_cache_2.pkl``.

    NOTE(review): indexing ``result["scores"]`` assumes ``nli`` returns
    zero-shot-classification style output; in that pipeline the second
    argument is treated as a candidate label and wrapped in a hypothesis
    template -- confirm that phrasing is what is intended here.
    """
    import os  # local import: only needed for the atomic cache rewrite below

    cache = entailment_cache_2
    key = (premise, hypothesis)

    if key in cache:
        return cache[key]

    result = nli(premise, hypothesis)

    print(f"nli {premise} => {hypothesis}")
    score = result["scores"][0]
    cache[key] = score

    # Dump to a sibling temp file and swap it in afterwards. Rewriting the
    # cache in place would leave a truncated, unloadable pickle if the kernel
    # were interrupted mid-dump, throwing away every score computed so far.
    cache_path = "entailment_cache_2.pkl"
    tmp_path = cache_path + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(cache, f)
    os.replace(tmp_path, cache_path)

    return score

def negate_aspect(aspect):
    """Look up the stored negation of ``aspect`` in ``negations_cache``.

    Raises ``KeyError`` when ``aspect`` has no entry in the cache.
    """
    negated = negations_cache[aspect]
    return negated

In [8]:
def score_aspect(aspect, option):
    """Score how well ``option`` satisfies ``aspect`` versus its negation.

    The option is scored with ``score_entailment`` against both the aspect and
    its negation (from ``negate_aspect``). The two scores are passed through a
    softmax, turned into odds ``pos / neg``, mapped to ``odds / (1 + odds)``
    and finally to a natural log.

    Returns a dict with:
      * ``"score"``     -- the log value described above;
      * ``"raw_score"`` -- the post-softmax ``"pos"`` / ``"neg"`` scores, each
        paired with the aspect text it was computed for.
    """
    negated = negate_aspect(aspect)

    # Entailment of the option against the aspect, then against its negation.
    entailment_pair = np.array([
        score_entailment(option, aspect),
        score_entailment(option, negated),
    ])

    # Two-way softmax over the pair.
    exponentials = np.exp(entailment_pair)
    p_pos, p_neg = exponentials / np.sum(exponentials)

    # odds -> odds / (1 + odds) -> log
    odds = p_pos / p_neg
    log_score = np.log(odds / (1 + odds))

    return {
        "score": log_score,
        "raw_score": {
            "pos": {"score": p_pos, "aspect": aspect},
            "neg": {"score": p_neg, "aspect": negated},
        },
    }

def score_option(aspects, option):
    """Score ``option`` against every aspect and total the results.

    Each aspect is scored with ``score_aspect``; the option's ``"score"`` is
    the sum of the per-aspect ``"score"`` values (``0`` when ``aspects`` is
    empty). The per-aspect result dicts are returned under
    ``"aspect_scores"`` in the order the aspects were given.
    """
    per_aspect = [score_aspect(aspect, option) for aspect in aspects]
    total = sum(entry["score"] for entry in per_aspect)

    return {
        "option": option,
        "score": total,
        "aspect_scores": per_aspect,
    }

In [9]:
def answer_query(query, options):
    """Rank ``options`` for ``query``, best score first.

    The aspects for the query are read from ``aspects_cache`` (``KeyError`` if
    the query is missing). Every option is scored with ``score_option`` and
    the resulting dicts are returned sorted by ``"score"`` in descending order.
    """
    query_aspects = aspects_cache[query]

    scored = [score_option(query_aspects, candidate) for candidate in options]

    return sorted(scored, key=lambda entry: entry["score"], reverse=True)

In [10]:
def log(query, answer, ranking, f):
    """Write a plain-text report for one query to the writable object ``f``.

    The report starts with the query, the expected answer, the predicted
    option (the first entry of ``ranking``) and whether the two match. Every
    ranked option follows with its score -- the one equal to ``answer`` is
    tagged ``**ANSWER**`` -- and, per aspect, the positive and negative scores
    with their aspect text and the aspect's own score.

    Raises ``IndexError`` when ``ranking`` is empty.
    """
    write = f.write
    top_choice = ranking[0]["option"]

    write(f"  QUERY: {query}\n")
    write(f" ANSWER: {answer}\n")
    write(f"   PRED: {top_choice}\n")
    write(f"CORRECT: {top_choice == answer}\n\n")

    for entry in ranking:
        header = (
            f"    ({entry['score']:.5f}) {entry['option']} **ANSWER**\n"
            if entry["option"] == answer
            else f"    ({entry['score']:.5f}) {entry['option']}\n"
        )
        write(header)

        for detail in entry["aspect_scores"]:
            # Positive line first, then the negated one.
            for polarity in ("pos", "neg"):
                side = detail["raw_score"][polarity]
                write(f"        {side['score']:.5f} => {side['aspect']}\n")

            write("        -------\n")
            write(f"        {detail['score']:.5f}\n\n")

In [12]:
# Evaluate every query: rank its options, compare the top-ranked option with
# the expected answer, and write a per-query report to the log file.
log_f = "logs.txt"

correct = 0
incorrect = 0

n = 0

# The context manager closes the log even if scoring raises part-way through;
# UTF-8 keeps non-ASCII query/option text writable regardless of locale.
with open(log_f, "w", encoding="utf-8") as f:
    for query, options, answer in iter_recipe_mpr():
        ranking = answer_query(query, options)

        prediction = ranking[0]["option"]

        if prediction == answer:
            correct += 1
        else:
            incorrect += 1

        log(query, answer, ranking, f)

        n += 1
        if n == 500:  # evaluate at most 500 queries
            break

# Report NaN instead of raising ZeroDivisionError when nothing was evaluated.
total = correct + incorrect
accuracy = correct / total if total else float("nan")
print(f"accuracy: {accuracy:.5f}")


accuracy: 0.82200
