This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
/
squad.py
85 lines (63 loc) · 2.62 KB
/
squad.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""Functions taken from [the official evaluation script]
(https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/)
for SQuAD version 2.0.
"""
import collections
import re
import string
from typing import Callable, Sequence, TypeVar, Tuple
def make_qid_to_has_ans(dataset):
    """Map every question id in ``dataset`` to whether it has any gold answer.

    ``dataset`` is iterated as a sequence of articles; each article exposes a
    ``"paragraphs"`` list, each paragraph a ``"qas"`` list, and each qa entry
    an ``"id"`` and an ``"answers"`` value.  The flag stored for an id is the
    truthiness of its ``"answers"`` value.  If an id occurs more than once,
    the last occurrence determines the flag.
    """
    return {
        question["id"]: bool(question["answers"])
        for entry in dataset
        for paragraph in entry["paragraphs"]
        for question in paragraph["qas"]
    }
def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace.

    The steps run in a fixed order: lowercase, strip every character found in
    ``string.punctuation``, replace the standalone words "a", "an" and "the"
    with a space, then collapse all whitespace runs into single spaces (which
    also trims both ends).
    """
    lowered = s.lower()

    # Punctuation is dropped before articles are matched, so a token such as
    # "a.n" first becomes "an" and is then removed as an article.
    punctuation = set(string.punctuation)
    without_punc = "".join([ch for ch in lowered if ch not in punctuation])

    article_pattern = re.compile(r"\b(a|an|the)\b", re.UNICODE)
    without_articles = article_pattern.sub(" ", without_punc)

    return " ".join(without_articles.split())
def get_tokens(s):
    """Split ``s`` into normalized whitespace-separated tokens.

    A falsy ``s`` (for example ``""`` or ``None``) yields an empty list
    without being normalized; anything else is passed through
    ``normalize_answer`` and split on whitespace.
    """
    return normalize_answer(s).split() if s else []
def compute_exact(a_pred: str, a_gold: str) -> int:
    """Return 1 when prediction and gold answer are equal after normalization, else 0."""
    normalized_pred = normalize_answer(a_pred)
    normalized_gold = normalize_answer(a_gold)
    return 1 if normalized_pred == normalized_gold else 0
def compute_f1(a_pred: str, a_gold: str) -> float:
    """Compute the token-level F1 score between a prediction and a gold answer.

    Both strings are tokenized with ``get_tokens``.  The overlap is counted as
    a multiset intersection, so a repeated token only matches as many times as
    it occurs on both sides.

    Returns:
        ``1.0`` or ``0.0`` when either side has no tokens (1.0 only if both
        are empty), ``0.0`` when no token is shared, otherwise the harmonic
        mean of token precision and recall.
    """
    predicted = get_tokens(a_pred)
    expected = get_tokens(a_gold)

    if not predicted or not expected:
        # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
        return float(predicted == expected)

    overlap = collections.Counter(predicted) & collections.Counter(expected)  # type: ignore[var-annotated]
    shared = sum(overlap.values())
    if shared == 0:
        # Guard against 0 / 0 in the harmonic mean below.
        return 0.0

    precision = shared / len(predicted)
    recall = shared / len(expected)
    return (2 * precision * recall) / (precision + recall)
# Type variables for ``metric_max_over_ground_truths``: the prediction type,
# the ground-truth type, and the (orderable) score type returned by the metric.
_P = TypeVar("_P")
_G = TypeVar("_G")
_T = TypeVar("_T", int, float, Tuple[int, ...], Tuple[float, ...])


def metric_max_over_ground_truths(
    metric_fn: Callable[[_P, _G], _T], prediction: _P, ground_truths: Sequence[_G]
) -> _T:
    """Return the highest score ``metric_fn`` gives ``prediction`` over all ground truths.

    Args:
        metric_fn: Called as ``metric_fn(prediction, ground_truth)`` once per
            ground truth, in order.
        prediction: The predicted value, passed as the first argument.
        ground_truths: The reference values to score against.

    Returns:
        The maximum of the individual scores.

    Raises:
        ValueError: If ``ground_truths`` is empty (raised by ``max``).
    """
    # Every score is computed before max() runs, so metric_fn is called for
    # each ground truth before any scores are compared.
    scores = [metric_fn(prediction, reference) for reference in ground_truths]
    return max(scores)
def get_metric_score(prediction: str, gold_answers: Sequence[str]) -> Tuple[int, float]:
    """Score ``prediction`` against every gold answer and keep the best of each metric.

    Returns:
        A ``(exact, f1)`` pair: the maximum ``compute_exact`` score and the
        maximum ``compute_f1`` score over ``gold_answers``.  The two maxima
        are taken independently of each other.
    """
    return (
        metric_max_over_ground_truths(compute_exact, prediction, gold_answers),
        metric_max_over_ground_truths(compute_f1, prediction, gold_answers),
    )