import argparse
import json
from os.path import join
from typing import List
import numpy as np
import pandas as pd
from tqdm import tqdm
from docqa import trainer
from docqa.data_processing.document_splitter import MergeParagraphs, TopTfIdf, ShallowOpenWebRanker, FirstN
from docqa.data_processing.preprocessed_corpus import preprocess_par
from docqa.data_processing.qa_training_data import ParagraphAndQuestionDataset
from docqa.data_processing.span_data import TokenSpans
from docqa.data_processing.text_utils import NltkPlusStopWords
from docqa.dataset import FixedOrderBatcher
from docqa.eval.ranked_scores import compute_ranked_scores
from docqa.evaluator import Evaluator, Evaluation
from docqa.model_dir import ModelDir
from build_span_corpus import XQADataset
from docqa.triviaqa.read_data import normalize_wiki_filename
from docqa.triviaqa.training_data import DocumentParagraphQuestion, ExtractMultiParagraphs, \
ExtractMultiParagraphsPerQuestion
from docqa.triviaqa.trivia_qa_eval import exact_match_score as trivia_em_score
from docqa.triviaqa.trivia_qa_eval import f1_score as trivia_f1_score
from docqa.utils import ResourceLoader, print_table
# Needed for the (TriviaQA-inherited) web-dev/web-test branch below; assumes docqa.config defines TRIVIA_QA
from docqa.config import TRIVIA_QA
"""
Evaluate on XQA data
Modified from docqa/eval/triviaqa_full_document_eval.py
"""

class RecordParagraphSpanPrediction(Evaluator):
    def __init__(self, bound: int, record_text_ans: bool):
        self.bound = bound
        self.record_text_ans = record_text_ans

    def tensors_needed(self, prediction):
        span, score = prediction.get_best_span(self.bound)
        needed = dict(spans=span, model_scores=score)
        return needed

    def evaluate(self, data: List[DocumentParagraphQuestion], true_len, **kargs):
        spans, model_scores = np.array(kargs["spans"]), np.array(kargs["model_scores"])

        pred_f1s = np.zeros(len(data))
        pred_em = np.zeros(len(data))
        text_answers = []

        for i in tqdm(range(len(data)), total=len(data), ncols=80, desc="scoring"):
            point = data[i]
            if point.answer is None and not self.record_text_ans:
                continue
            text = point.get_context()
            pred_span = spans[i]
            pred_text = " ".join(text[pred_span[0]:pred_span[1] + 1])
            if self.record_text_ans:
                text_answers.append(pred_text)
                if point.answer is None:
                    continue

            f1 = 0
            em = False
            for answer in data[i].answer.answer_text:
                ans = answer
                f1 = max(f1, trivia_f1_score(pred_text, ans))
                if not em:
                    em = trivia_em_score(pred_text, ans)

            pred_f1s[i] = f1
            pred_em[i] = em

        results = {}
        results["n_answers"] = [0 if x.answer is None else len(x.answer.answer_spans) for x in data]
        if self.record_text_ans:
            results["text_answer"] = text_answers
        results["predicted_score"] = model_scores
        results["predicted_start"] = spans[:, 0]
        results["predicted_end"] = spans[:, 1]
        results["text_f1"] = pred_f1s
        results["rank"] = [x.rank for x in data]
        results["text_em"] = pred_em
        results["para_start"] = [x.para_range[0] for x in data]
        results["para_end"] = [x.para_range[1] for x in data]
        results["question_id"] = [x.question_id for x in data]
        results["doc_id"] = [x.doc_id for x in data]
        return Evaluation({}, results)


def main():
    parser = argparse.ArgumentParser(description='Evaluate a model on XQA data')
    parser.add_argument('model', help='model directory')
    parser.add_argument('-p', '--paragraph_output', type=str,
                        help="Save fine-grained results for each paragraph in csv format")
    parser.add_argument('-o', '--official_output', type=str,
                        help="Build an official output file with the model's most confident span "
                             "for each (question, doc) pair")
    parser.add_argument('--no_ema', action="store_true", help="Don't use EMA weights even if they exist")
    parser.add_argument('--n_processes', type=int, default=None,
                        help="Number of processes to do the preprocessing (selecting paragraphs+loading context) with")
    parser.add_argument('-i', '--step', default=None,
                        help="Checkpoint step to load, or \"latest\"; defaults to the best weights if available")
    parser.add_argument('-n', '--n_sample', type=int, default=None, help="Number of questions to evaluate on")
    parser.add_argument('-a', '--async', type=int, default=10)
    parser.add_argument('-t', '--tokens', type=int, default=400,
                        help="Max tokens per paragraph")
    parser.add_argument('-g', '--n_paragraphs', type=int, default=15,
                        help="Number of paragraphs to run the model on")
    parser.add_argument('-f', '--filter', type=str, default=None, choices=["tfidf", "truncate", "linear"],
                        help="How to select paragraphs")
    parser.add_argument('-b', '--batch_size', type=int, default=200,
                        help="Batch size, larger sizes might be faster but will take more memory")
    parser.add_argument('--max_answer_len', type=int, default=8,
                        help="Max answer span length to select")
    parser.add_argument('-c', '--corpus',
                        choices=["en_dev",
                                 "en_test",
                                 "fr_dev",
                                 "fr_test",
                                 "de_dev",
                                 "de_test",
                                 "ru_dev",
                                 "ru_test",
                                 "pt_dev",
                                 "pt_test",
                                 "zh_dev",
                                 "zh_test",
                                 "pl_dev",
                                 "pl_test",
                                 "uk_dev",
                                 "uk_test",
                                 "ta_dev",
                                 "ta_test",
                                 "fr_trans_en_dev",
                                 "fr_trans_en_test",
                                 "de_trans_en_dev",
                                 "de_trans_en_test",
                                 "ru_trans_en_dev",
                                 "ru_trans_en_test",
                                 "pt_trans_en_dev",
                                 "pt_trans_en_test",
                                 "zh_trans_en_dev",
                                 "zh_trans_en_test",
                                 "pl_trans_en_dev",
                                 "pl_trans_en_test",
                                 "uk_trans_en_dev",
                                 "uk_trans_en_test",
                                 "ta_trans_en_dev",
                                 "ta_trans_en_test"],
                        required=True)
    parser.add_argument('--dump_data_pickle_only', action="store_true", default=False)
    args = parser.parse_args()

    model_dir = ModelDir(args.model)
    model = model_dir.get_model()
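
    # args.corpus encodes both the corpus and the split, e.g. "fr_trans_en_dev" ->
    # corpus "fr_trans_en" evaluated on its dev set.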
    corpus_name = args.corpus[:args.corpus.rfind("_")]
    eval_set = args.corpus[args.corpus.rfind("_") + 1:]
    dataset = XQADataset(corpus_name)
    if eval_set == "dev":
        test_questions = dataset.get_dev()
    elif eval_set == "test":
        test_questions = dataset.get_test()
    else:
        raise AssertionError()

    corpus = dataset.evidence
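    # MergeParagraphs groups each document's paragraphs into chunks of at most args.tokens tokens.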
    splitter = MergeParagraphs(args.tokens)

    per_document = args.corpus.startswith("web")  # wiki and web are both multi-document; always False for the XQA corpora here
    filter_name = args.filter
    if filter_name is None:
        # Pick a default depending on the kind of data we are using
        if per_document:
            filter_name = "tfidf"
        else:
            filter_name = "linear"

    print("Selecting %d paragraphs using method \"%s\" per %s" % (
        args.n_paragraphs, filter_name, ("question-document pair" if per_document else "question")))
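
    # "tfidf" ranks paragraphs by TF-IDF similarity to the question, "truncate" keeps the first
    # n_paragraphs in document order, and "linear" uses docqa's ShallowOpenWebRanker.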
    if filter_name == "tfidf":
        para_filter = TopTfIdf(NltkPlusStopWords(punctuation=True), args.n_paragraphs)
    elif filter_name == "truncate":
        para_filter = FirstN(args.n_paragraphs)
    elif filter_name == "linear":
        para_filter = ShallowOpenWebRanker(args.n_paragraphs)
    else:
        raise ValueError()

    n_questions = args.n_sample
    if n_questions is not None:
        test_questions.sort(key=lambda x: x.question_id)
        np.random.RandomState(0).shuffle(test_questions)
        test_questions = test_questions[:n_questions]

    print("Building question/paragraph pairs...")
    # Loads the relevant questions/documents, selects the right paragraphs, and runs the model's preprocessor
    if per_document:
        prep = ExtractMultiParagraphs(splitter, para_filter, model.preprocessor, require_an_answer=False)
    else:
        prep = ExtractMultiParagraphsPerQuestion(splitter, para_filter, model.preprocessor, require_an_answer=False)
    prepped_data = preprocess_par(test_questions, corpus, prep, args.n_processes, 1000)
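
    # Flatten into one DocumentParagraphQuestion per (question, paragraph) pair; the loop index i
    # records the paragraph's rank as assigned by the paragraph filter.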
    data = []
    for q in prepped_data.data:
        for i, p in enumerate(q.paragraphs):
            if q.answer_text is None:
                ans = None
            else:
                ans = TokenSpans(q.answer_text, p.answer_spans)
            data.append(DocumentParagraphQuestion(q.question_id, p.doc_id,
                                                  (p.start, p.end), q.question, p.text,
                                                  ans, i))

    # Reverse so our first batch will be the largest (so OOMs happen early)
    questions = sorted(data, key=lambda x: (x.n_context_words, len(x.question)), reverse=True)

    if args.dump_data_pickle_only:
        # dump eval data for bert (so the BERT evaluation pipeline can reuse the same paragraph selection)
        import pickle
        with open("%s_%d.pkl" % (args.corpus, args.n_paragraphs), "wb") as f:
            pickle.dump(questions, f)
        return
print("Done, starting eval")
if args.step is not None:
if args.step == "latest":
checkpoint = model_dir.get_latest_checkpoint()
else:
checkpoint = model_dir.get_checkpoint(int(args.step))
else:
checkpoint = model_dir.get_best_weights()
if checkpoint is not None:
print("Using best weights")
else:
print("Using latest checkpoint")
checkpoint = model_dir.get_latest_checkpoint()
test_questions = ParagraphAndQuestionDataset(questions, FixedOrderBatcher(args.batch_size, True))
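
    # Run the model over every selected paragraph; RecordParagraphSpanPrediction collects one row of
    # per-paragraph predictions and scores for each (question, paragraph) pair.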
    # "async" is a reserved word in Python 3.7+, so read the flag with getattr instead of args.async
    evaluation = trainer.test(model,
                              [RecordParagraphSpanPrediction(args.max_answer_len, True)],
                              {args.corpus: test_questions}, ResourceLoader(), checkpoint,
                              not args.no_ema, getattr(args, "async"))[args.corpus]

    if not all(len(x) == len(data) for x in evaluation.per_sample.values()):
        raise RuntimeError("Evaluator did not return a value for every example")

    df = pd.DataFrame(evaluation.per_sample)

    if args.official_output is not None:
        print("Saving question result")

        fns = {}
        if per_document:
            # This branch is inherited from the TriviaQA script and is unreachable for the XQA corpora above.
            # I didn't store the unnormalized filenames exactly, so unfortunately we have to reload
            # the source data to get the exact filenames for the official test script
            print("Loading proper filenames")
            if args.corpus == 'web-test':
                source = join(TRIVIA_QA, "qa", "web-test-without-answers.json")
            elif args.corpus == "web-dev":
                source = join(TRIVIA_QA, "qa", "web-dev.json")
            else:
                raise AssertionError()

            with open(join(source)) as f:
                data = json.load(f)["Data"]
            for point in data:
                for doc in point["EntityPages"]:
                    filename = doc["Filename"]
                    fn = join("wikipedia", filename[:filename.rfind(".")])
                    fn = normalize_wiki_filename(fn)
                    fns[(point["QuestionId"], fn)] = filename

        answers = {}
        scores = {}
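        # For each question (or question-document pair when per_document), keep the answer text from
        # the highest-scoring paragraph.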
        for q_id, doc_id, start, end, txt, score in df[["question_id", "doc_id", "para_start", "para_end",
                                                        "text_answer", "predicted_score"]].itertuples(index=False):
            filename = dataset.evidence.file_id_map[doc_id]
            if per_document:
                if filename.startswith("web"):
                    true_name = filename[4:] + ".txt"
                else:
                    true_name = fns[(q_id, filename)]
                key = q_id + "--" + true_name
            else:
                key = q_id

            prev_score = scores.get(key)
            if prev_score is None or prev_score < score:
                scores[key] = score
                answers[key] = txt

        with open(args.official_output, "w") as f:
            json.dump(answers, f)

    output_file = args.paragraph_output
    if output_file is not None:
        print("Saving paragraph result")
        df.to_csv(output_file, index=False)

    print("Computing scores")

    if per_document:
        group_by = ["question_id", "doc_id"]
    else:
        group_by = ["question_id"]

    # Print a table of scores as more paragraphs are used
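    # compute_ranked_scores reports, for each k, the EM/F1 obtained when answering with the
    # highest-confidence span among each question's top-k ranked paragraphs.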
    df.sort_values(group_by + ["rank"], inplace=True)
    f1 = compute_ranked_scores(df, "predicted_score", "text_f1", group_by)
    em = compute_ranked_scores(df, "predicted_score", "text_em", group_by)

    table = [["N Paragraphs", "EM", "F1"]]
    table += list([str(i + 1), "%.4f" % e, "%.4f" % f] for i, (e, f) in enumerate(zip(em, f1)))
    print_table(table)


if __name__ == "__main__":
    main()