In [48]:
import collections
import json
import os
import re
import string
import sys

import numpy as np

In [49]:
def normalize_answer(s):
    """Canonicalize an answer string before comparing it with another one.

    The transformations run in this order:
      1. lowercase the whole string;
      2. delete every character in ``string.punctuation``;
      3. replace the whole words "a", "an" and "the" with a space;
      4. collapse each run of whitespace to a single space and trim both ends.

    Order matters: punctuation is removed before articles are matched, so
    "the-end" becomes "theend" and no article is stripped from it.
    """
    strip_punct = str.maketrans("", "", string.punctuation)
    article_re = re.compile(r"\b(a|an|the)\b", re.UNICODE)

    lowered = s.lower()
    no_punct = lowered.translate(strip_punct)
    no_articles = article_re.sub(" ", no_punct)
    return " ".join(no_articles.split())


def get_tokens(s):
    """Split the normalized form of ``s`` into whitespace tokens.

    Falsy input (for example ``""`` or ``None``) yields an empty list without
    calling ``normalize_answer``.
    """
    return normalize_answer(s).split() if s else []


def compute_exact(a_gold, a_pred):
    """Exact-match score: 1 if both answers normalize to the same string, else 0."""
    gold_norm = normalize_answer(a_gold)
    pred_norm = normalize_answer(a_pred)
    return 1 if gold_norm == pred_norm else 0


def compute_f1(a_gold, a_pred):
    """Token-level F1 between a gold answer and a predicted answer.

    Both strings are tokenized with ``get_tokens`` (normalized, whitespace
    split) and overlap is counted as a multiset intersection.

    Returns:
        int 1 or 0 when either side has no tokens (1 only if both are empty,
        i.e. both are "no answer"); int 0 when no token is shared; otherwise
        the float harmonic mean of precision and recall.
    """
    gold_toks = get_tokens(a_gold)
    pred_toks = get_tokens(a_pred)

    # No-answer case: the score is 1 only when both sides agree there is no answer.
    if not gold_toks or not pred_toks:
        return int(gold_toks == pred_toks)

    overlap = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    num_same = sum(overlap.values())
    if num_same == 0:
        return 0

    precision = num_same / len(pred_toks)
    recall = num_same / len(gold_toks)
    return 2 * precision * recall / (precision + recall)

In [50]:
def get_raw_scores(
    preds,
    answer_key="gemini-1.0-pro-latest_answer",
    no_answer_token="<no_answer>",
    failed_token="<API_failed>",
):
    """Score every prediction against its gold answers (SQuAD v2 style).

    Args:
        preds: iterable of rows. Each scored row must provide ``row["id"]``,
            ``row["answers"]["text"]`` (list of gold answer strings) and
            ``row[answer_key]`` (the model's prediction string).
        answer_key: row field holding the model prediction. Defaults to the
            Gemini 1.0 Pro column this notebook was written for.
        no_answer_token: prediction value meaning the model abstained. It is
            scored as the empty string, which earns full credit only on
            questions whose gold answers are all empty (unanswerable).
        failed_token: prediction value meaning the API call failed. Such rows
            are omitted from both returned dicts. They therefore also drop out
            of any denominator computed from these dicts (e.g. ``len``).

    Returns:
        ``(exact_scores, f1_scores)``: dicts keyed by question id. Each
        question's score is the maximum over all of its gold answers.
    """
    exact_scores = {}
    f1_scores = {}
    for row in preds:
        a_pred = row[answer_key]
        # A failed request has no prediction to score; skip it before reading
        # any other field so incomplete failed rows cannot raise.
        if a_pred == failed_token:
            continue
        if a_pred == no_answer_token:
            a_pred = ""
        # Gold answers that normalize to nothing are dropped. If none remain,
        # the question is unanswerable and only an empty prediction matches.
        gold_answers = [a for a in row["answers"]["text"] if normalize_answer(a)]
        if not gold_answers:
            gold_answers = [""]
        qid = row["id"]
        exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers)
        f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers)
    return exact_scores, f1_scores

In [51]:
def make_eval_dict(exact_scores, f1_scores, qid_list=None):
    """Turn per-question scores into dataset-level percentages.

    Args:
        exact_scores: mapping question id -> exact-match score.
        f1_scores: mapping question id -> F1 score.
        qid_list: optional ids to restrict the aggregate to. When it is None
            *or empty*, every value in ``exact_scores`` and ``f1_scores`` is
            summed and ``total`` is ``len(exact_scores)``.

    Returns:
        OrderedDict with keys "exact" and "f1" (score sum * 100 / total) and
        "total" (the number of questions averaged over).

    Raises:
        ZeroDivisionError: if ``qid_list`` is falsy and ``exact_scores`` is empty.
        KeyError: if an id in ``qid_list`` is missing from either mapping.
    """
    if qid_list:
        total = len(qid_list)
        exact_sum = sum(exact_scores[qid] for qid in qid_list)
        f1_sum = sum(f1_scores[qid] for qid in qid_list)
    else:
        total = len(exact_scores)
        exact_sum = sum(exact_scores.values())
        f1_sum = sum(f1_scores.values())

    metrics = [
        ("exact", 100.0 * exact_sum / total),
        ("f1", 100.0 * f1_sum / total),
        ("total", total),
    ]
    return collections.OrderedDict(metrics)

In [52]:
# Load the Gemini 1.0 Pro predictions for the SQuAD 2.0 validation split: a list
# of rows consumed by get_raw_scores below. The encoding is pinned because
# open()'s default is platform/locale dependent, and the answer text may
# contain non-ASCII characters.
with open("../gemini-1.0-pro-latest_squad2.0_validation.json", "r", encoding="utf-8") as fin:
    data = json.load(fin)

In [53]:
# Per-question exact-match and F1 scores, keyed by question id.
exact_scores, f1_scores = get_raw_scores(preds=data)

In [54]:
# Aggregate over every scored question; rows whose API call failed were not
# scored by get_raw_scores, so they are not part of "total".
make_eval_dict(exact_scores, f1_scores)

OrderedDict([('exact', 76.77725118483413),
             ('f1', 80.81593955483237),
             ('total', 11605)])

In [56]:
# Question ids whose Gemini request failed (prediction "<API_failed>").
# get_raw_scores skips these rows, so they are excluded from the metrics above.
failed_ids = [
    row["id"]
    for row in data
    if row["gemini-1.0-pro-latest_answer"] == "<API_failed>"
]

print(len(failed_ids))
print(failed_ids)

268
['56de15104396321400ee25b7', '5ad3ee2d604f3c001a3ff7e1', '56de1563cffd8e1900b4b5c3', '5ad3f028604f3c001a3ff825', '5ad3fb01604f3c001a3ffb36', '56de3dbacffd8e1900b4b6d2', '5ad3fb6e604f3c001a3ffb5f', '5ad567055b96ef001a10adeb', '5ad04de377cf76001a686fa6', '570d2c20fed7b91900d45ca7', '57109275b654c5140001f9a1', '571114cfb654c5140001fb0a', '571114cfb654c5140001fb0c', '57111ab8a58dae1900cd6c40', '571144d1a58dae1900cd6d6e', '5ad3d689604f3c001a3ff30d', '5ad3ed37604f3c001a3ff7a4', '5ad415fd604f3c001a40032b', '571c3e8cdd7acb1400e4c0a7', '571a4d1a4faf5e1900b8a95a', '571c4132dd7acb1400e4c0b0', '571cebc05efbb31900334e48', '571cebc05efbb31900334e49', '571cebc05efbb31900334e4a', '571cebc05efbb31900334e4c', '5ad2678ad7d075001a42922c', '5ad2678ad7d075001a42922f', '571ce9bddd7acb1400e4c1a1', '5ad2685dd7d075001a429279', '571c7d55dd7acb1400e4c0c4', '571c8198dd7acb1400e4c0cf', '5ad24180d7d075001a428970', '5ad24ce8d7d075001a428c0f', '571c9348dd7acb1400e4c118', '571cc5c45efbb31900334ddb', '5ad258b4d7d075