In [11]:
import json
from fuzzywuzzy import fuzz

In [17]:
import sys
import re
import string
from collections import Counter
import pickle

def normalize_answer(s):
    """Canonicalise an answer string before SQuAD/HotpotQA-style comparison.

    The transformations are applied in this order:
      1. lowercase the text;
      2. delete every character in ``string.punctuation``;
      3. replace the whole words "a", "an" and "the" with a space;
      4. collapse runs of whitespace to single spaces and trim both ends.

    Args:
        s: raw answer text.

    Returns:
        The normalised string (possibly empty).
    """
    punctuation = frozenset(string.punctuation)

    lowered = s.lower()
    without_punct = ''.join(c for c in lowered if c not in punctuation)
    # Articles are removed only after punctuation, so "the," is also caught.
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punct)
    return ' '.join(without_articles.split())


def f1_score(prediction, ground_truth):
    """Token-level F1 between a predicted and a gold answer (HotpotQA style).

    Both strings go through ``normalize_answer`` and are split on whitespace.

    Args:
        prediction: predicted answer text.
        ground_truth: gold answer text.

    Returns:
        A ``(f1, precision, recall)`` tuple. It is ``(0, 0, 0)`` when either
        side normalises to one of ``'yes'``, ``'no'``, ``'noanswer'`` and the
        two sides differ, or when the two token bags share no token (which
        includes either side being empty after normalisation).
    """
    pred_norm = normalize_answer(prediction)
    gold_norm = normalize_answer(ground_truth)
    zero = (0, 0, 0)

    # Special answers must match exactly; partial token credit is not given.
    special_answers = ('yes', 'no', 'noanswer')
    if pred_norm != gold_norm and (pred_norm in special_answers or gold_norm in special_answers):
        return zero

    pred_tokens = pred_norm.split()
    gold_tokens = gold_norm.split()
    # Multiset intersection: each shared token counts min(occurrences) times.
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if not overlap:
        return zero

    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall), precision, recall


def exact_match_score(prediction, ground_truth):
    """Return True iff the two answers are identical after ``normalize_answer``."""
    pred_norm = normalize_answer(prediction)
    gold_norm = normalize_answer(ground_truth)
    return pred_norm == gold_norm

In [2]:
# Load model predictions. Later cells read pred['answer'], a mapping from
# question _id to the predicted answer string.
# JSON text is UTF-8; state it explicitly so non-ASCII answers decode the same
# way regardless of the platform's default locale encoding.
with open('pred_codalab.json', 'r', encoding='utf-8') as f:
    pred = json.load(f)

In [3]:
# Load the gold HotpotQA dev set (distractor setting, per the file name).
# Later cells iterate over it and read each record's '_id' and 'answer'.
# Explicit UTF-8: the gold answers contain non-ASCII characters.
with open('../../data/external/hotpot_dev_distractor_v1.json', 'r', encoding='utf-8') as f:
    dev = json.load(f)

In [4]:
# Map each dev question's _id to its gold answer string.
# A comprehension keeps its loop variable out of the notebook namespace;
# later cells reuse short names like `_id` and `ans`.
dict_id2ans = {ins['_id']: ins['answer'] for ins in dev}
# Accuracy below divides by len(dev), which assumes one entry per question;
# a duplicate _id would otherwise silently overwrite an earlier answer.
assert len(dict_id2ans) == len(dev), "duplicate _id values in dev set"

In [22]:
# Compare every prediction with its gold answer:
#   em_cnt -- predictions that exactly match after normalize_answer
#   cnt    -- "near misses": not an exact match, but fuzz.ratio on the RAW
#             (un-normalised, case-sensitive) strings is >= the threshold.
#             These pairs are printed for manual inspection.
FUZZY_MATCH_THRESHOLD = 90

cnt = 0
em_cnt = 0
for _id, ans in pred['answer'].items():
    lbl = dict_id2ans[_id]
    # exact_match_score normalises both strings; evaluate it once per pair.
    if exact_match_score(ans, lbl):
        em_cnt += 1
    elif fuzz.ratio(lbl, ans) >= FUZZY_MATCH_THRESHOLD:
        print(lbl, "###", ans)
        cnt += 1

9,984 ### 9, 984
Robert Erskine Childers DSC ### Robert Erskine Childers
1,462 ### 1, 462
35,124 ### 35, 124
$10.5 million ### $ 10. 5 million
right-hand ### right - hand
super-regional shopping mall ### super - regional shopping mall
Slaughterhouse-Five ### Slaughterhouse - Five
3,384,569 ### 3, 384, 569
Pakistani ### Pakistan
76,416 ### 76, 416
728,000 ft² ### 728, 000 ft²
51,271 ### 51, 271
1.95 m ### 1. 95 m
4,613 ### 4, 613
The R-8 Human Rhythm Composer ### The R - 8 Human Rhythm Composer
Queen In-hyun's Man ### Queen In - hyun's Man
Symphony No. 7 ### Symphony No. 1
Jean-Loup Jacques Marie Chrétien ### Jean - Loup Jacques Marie Chrétien
natural-ingredients-only personal care products ### natural - ingredients - only personal care products
Hessians ### Hessian
video game ### videogame
gull-wing doors ### gull - wing doors
1,800 ### 1, 800
Regional Rural Bank ### Regional Rural Banks
media for the 65.8 million ### media for the 65. 8 million
Thomas Warburton ### Tom Warburton
26–30

In [21]:
# Number of fuzzy near misses (non-EM predictions with fuzz.ratio >= 90).
# NOTE(review): this cell's count (In [21]) precedes the loop cell's (In [22]),
# so the shown value may come from an earlier run; confirm with Restart & Run All.
cnt

250

In [24]:
# Exact-match accuracy over the full dev set; dev questions absent from
# pred['answer'] are never counted in em_cnt and so count as wrong.
em_cnt/len(dev)

0.5665091154625254

In [25]:
# Lenient accuracy: exact matches plus fuzzy near misses counted as correct.
# NOTE(review): only meaningful while `cnt` still holds the near-miss count;
# `cnt` is reassigned further down the notebook.
(em_cnt+cnt)/len(dev)

0.600270087778528

In [10]:
# NOTE(review): leftover debugging. This displays whatever `lbl` a loop last
# assigned, and its execution count (In [10]) predates the loop cells, so it
# depends on hidden kernel state. Consider deleting this cell.
lbl

'Norwood, Massachusetts'

In [29]:
# Count predictions that are exactly the empty string.
# NOTE: this deliberately reuses `cnt`, overwriting the fuzzy near-miss count.
# The following cells display and use this empty-answer count.
cnt = sum(1 for ans in pred['answer'].values() if ans == "")

In [30]:
# Number of empty-string predictions (if the empty-answer counting cell ran last).
cnt

151

In [31]:
# Here `cnt` is the empty-answer count, so this is presumably EM accuracy if
# every empty prediction had instead been an exact match. It is a rough upper
# bound on the gain from answering those questions. TODO confirm intent.
(em_cnt+cnt)/len(dev)

0.5869007427413909