In [1]:
import re
import string
import collections
from pprint import pprint

In [2]:
predictions = [
    {'prediction_text': '', 'id': '56e10a3be3433e1400422b22', 'no_answer_probability': 0.}, 
    {'prediction_text': 'Beyonce', 'id': '56d2051ce7d4791d0090260b', 'no_answer_probability': 0.}, 
    {'prediction_text': 'climate change in world', 'id': '5733b5344776f419006610e1', 'no_answer_probability': 0.},
    {'prediction_text': 'jakarta', 'id': '5733b5344776f419006610e2', 'no_answer_probability': 0.},
    {'prediction_text': 'bandung', 'id': '5733b5344776f419006610e3', 'no_answer_probability': 0.}
]
references = [
    {'answers': {'answer_start': [891], 'text': ['climate change in other world']}, 'id': '5733b5344776f419006610e1'},
    {'answers': {'answer_start': [891], 'text': ['jakarta']}, 'id': '5733b5344776f419006610e2'},
    {'answers': {'answer_start': [891], 'text': ['bandung']}, 'id': '5733b5344776f419006610e3'},
    {'answers': {'answer_start': [], 'text': []}, 'id': '56e10a3be3433e1400422b22'}, 
    {'answers': {'answer_start': [], 'text': []}, 'id': '56d2051ce7d4791d0090260b'}
]

In [3]:
ARTICLES_REGEX = re.compile(r"\b(a|an|the)\b", re.UNICODE)

In [4]:
def make_qid_to_has_ans(dataset):
    qid_to_has_ans = {}
    for article in dataset:
        for p in article["paragraphs"]:
            for qa in p["qas"]:
                qid_to_has_ans[qa["id"]] = bool(qa["answers"]["text"])
    return qid_to_has_ans

In [5]:
def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return ARTICLES_REGEX.sub(" ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))

In [6]:
def get_tokens(s):
    if not s:
        return []
    return normalize_answer(s).split()

In [7]:
def compute_exact(a_gold, a_pred):
    return int(normalize_answer(a_gold) == normalize_answer(a_pred))


def compute_f1(a_gold, a_pred):
    
    print("a_gold", a_gold)
    print("a_pred", a_pred)
    
    gold_toks = get_tokens(a_gold)
    pred_toks = get_tokens(a_pred)
    
    print("gold_toks", gold_toks)
    print("pred_toks", pred_toks)
    
    common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    print("common", common)
    num_same = sum(common.values())
    print("num_same", num_same)
    
    if len(gold_toks) == 0 or len(pred_toks) == 0:
        # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
        return int(gold_toks == pred_toks)
    if num_same == 0:
        return 0
    
    precision = 1.0 * num_same / len(pred_toks)
    recall = 1.0 * num_same / len(gold_toks)
    f1 = (2 * precision * recall) / (precision + recall)
    
    print("precision", precision)
    print("recall", recall)
    print("f1", f1)
    
    return f1

In [8]:
def get_raw_scores(dataset, preds):
    exact_scores = {}
    f1_scores = {}
    for article in dataset:
        for p in article["paragraphs"]:
            for qa in p["qas"]:
                qid = qa["id"]
                gold_answers = [t for t in qa["answers"]["text"] if normalize_answer(t)]
                
                print("gold_answers")
                print(gold_answers)
                
                if not gold_answers:
                    # For unanswerable questions, only correct answer is empty string
                    gold_answers = [""]
                if qid not in preds:
                    print(f"Missing prediction for {qid}")
                    continue

                a_pred = preds[qid]
                
                # Take max over all gold answers
                exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers)
                f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers)
                
    return exact_scores, f1_scores

In [9]:
def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
    new_scores = {}
    for qid, s in scores.items():
        pred_na = na_probs[qid] > na_prob_thresh
        if pred_na:
            new_scores[qid] = float(not qid_to_has_ans[qid])
        else:
            new_scores[qid] = s
    return new_scores

In [10]:
def make_eval_dict(exact_scores, f1_scores, qid_list=None):
    
    if not qid_list:
        total = len(exact_scores)
        return collections.OrderedDict(
            [
                ("exact", 100.0 * sum(exact_scores.values()) / total),
                ("f1", 100.0 * sum(f1_scores.values()) / total),
                ("total", total),
            ]
        )
    else:
        total = len(qid_list)
        return collections.OrderedDict(
            [
                ("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total),
                ("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total),
                ("total", total),
            ]
        )

In [11]:
def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
    num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
    cur_score = num_no_ans
    best_score = cur_score
    best_thresh = 0.0
    qid_list = sorted(na_probs, key=lambda k: na_probs[k])

    for i, qid in enumerate(qid_list):
        if qid not in scores:
            continue
        if qid_to_has_ans[qid]:
            diff = scores[qid]
        else:
            if preds[qid]:
                diff = -1
            else:
                diff = 0

        cur_score += diff
        
        if cur_score > best_score:
            best_score = cur_score
            best_thresh = na_probs[qid]
    return 100.0 * best_score / len(scores), best_thresh

In [12]:
def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
    
    best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
    best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
    
    main_eval["best_exact"] = best_exact
    main_eval["best_exact_thresh"] = exact_thresh
    main_eval["best_f1"] = best_f1
    main_eval["best_f1_thresh"] = f1_thresh

In [13]:
def _compute(predictions, references, no_answer_threshold=1.0):
    no_answer_probabilities = {p["id"]: p["no_answer_probability"] for p in predictions}
    dataset = [{"paragraphs": [{"qas": references}]}]
    predictions = {p["id"]: p["prediction_text"] for p in predictions}
    
    # print("no_answer_probabilities")
    # pprint(no_answer_probabilities)
    # print("dataset")
    # pprint(dataset)
    # print("predictions")
    # pprint(predictions)

    qid_to_has_ans = make_qid_to_has_ans(dataset)  # maps qid to True/False
    has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
    no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
    
    # print("qid_to_has_ans")
    # pprint(qid_to_has_ans)
    # print("has_ans_qids")
    # pprint(has_ans_qids)
    # print("no_ans_qids")
    # pprint(no_ans_qids)

    exact_raw, f1_raw = get_raw_scores(dataset, predictions)
    print("f1_raw")
    pprint(f1_raw)
    
    exact_thresh = apply_no_ans_threshold(exact_raw, no_answer_probabilities, qid_to_has_ans, no_answer_threshold)
    f1_thresh = apply_no_ans_threshold(f1_raw, no_answer_probabilities, qid_to_has_ans, no_answer_threshold)
    out_eval = make_eval_dict(exact_thresh, f1_thresh)
    
    print(out_eval)
    pprint(out_eval)

#     if has_ans_qids:
#         has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
#         merge_eval(out_eval, has_ans_eval, "HasAns")

#     if no_ans_qids:
#         no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
#         merge_eval(out_eval, no_ans_eval, "NoAns")
    
    find_all_best_thresh(out_eval, predictions, exact_raw, f1_raw, no_answer_probabilities, qid_to_has_ans)
    
    return dict(out_eval)

In [14]:
_compute(predictions, references)

gold_answers
['climate change in other world']
a_gold climate change in other world
a_pred climate change in world
gold_toks ['climate', 'change', 'in', 'other', 'world']
pred_toks ['climate', 'change', 'in', 'world']
common Counter({'climate': 1, 'change': 1, 'in': 1, 'world': 1})
num_same 4
precision 1.0
recall 0.8
f1 0.888888888888889
gold_answers
['jakarta']
a_gold jakarta
a_pred jakarta
gold_toks ['jakarta']
pred_toks ['jakarta']
common Counter({'jakarta': 1})
num_same 1
precision 1.0
recall 1.0
f1 1.0
gold_answers
['bandung']
a_gold bandung
a_pred bandung
gold_toks ['bandung']
pred_toks ['bandung']
common Counter({'bandung': 1})
num_same 1
precision 1.0
recall 1.0
f1 1.0
gold_answers
[]
a_gold 
a_pred 
gold_toks []
pred_toks []
common Counter()
num_same 0
gold_answers
[]
a_gold 
a_pred Beyonce
gold_toks []
pred_toks ['beyonce']
common Counter()
num_same 0
f1_raw
{'56d2051ce7d4791d0090260b': 0,
 '56e10a3be3433e1400422b22': 1,
 '5733b5344776f419006610e1': 0.888888888888889,
 '5733b

{'exact': 60.0,
 'f1': 77.77777777777777,
 'total': 5,
 'best_exact': 60.0,
 'best_exact_thresh': 0.0,
 'best_f1': 77.77777777777777,
 'best_f1_thresh': 0.0}