-
Notifications
You must be signed in to change notification settings - Fork 2
/
eval_wikiqa.py
292 lines (224 loc) · 10.4 KB
/
eval_wikiqa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
"""Script to evaluate models on the WikiQA dataset
This script will train the specified model on the WikiQA train set and provide predictions on the WikiQA test set in the TREC format.
The TREC format includes a "qrels" and a "pred" file
It can then be evaluated using
./trec_eval qrels pred
(Get trec_eval from misc_scripts/)
The script will automatically save the qrels and pred file with a distinguishable name.
For example,
pred_mp_wikiqa : pred file of MatchPyramid model on WikiQA test
pred_bidaf_t_wikiqa : pred file of BiDAF_T model on WikiQA test after pretraining on SQUAD-T
pred_bidaf_t_finetuned_wikiqa : pred file of BiDAF_T on WikiQA test after pretraining on SQUAD-T and finetuning on WikiQA train
Since we have a predefined split of WikiQA into test, train and dev (unlike InsuranceQA) the qrels file will always be the
same for all the models.
Example Usage
-------------
$ python eval_wikiqa.py # evaluates on MatchPyramid, DRMM_TKS, BiDAF_T
$ python eval_wikiqa.py --model_type mp # evaluates on MatchPyramid
model_type : {mp, dtks, bidaf_t}
mp : MatchPyramid
dtks : DRMM_TKS
bidaf_t : BiDirectional Attention Flow (senTence level)
"""
import sys
sys.path.append('../..')
import sys
import os
from sl_eval.models import MatchPyramid, DRMM_TKS, BiDAF_T
from data_readers import WikiReaderIterable, WikiReaderStatic
import gensim.downloader as api
import argparse
def save_qrels(test_data, fname):
    """Write the WikiQA ground-truth relevance judgements ("qrels") in TREC format.

    The qrels file is model-independent: WikiQA ships with a fixed test
    split, so the same file serves every evaluated model.

    Each output line holds four tab-separated columns:
    query_id, "0" (a placeholder trec_eval ignores), document_id, relevance.

    Example
    -------
    Q1  0   D1-0    0
    Q1  0   D1-3    1

    Parameters
    ----------
    test_data : tuple
        (queries, doc_group, label_group, query_ids, doc_id_group) as
        produced by WikiReaderStatic.get_data()
    fname : str
        Path of the qrels file to write.
    """
    queries, doc_group, label_group, query_ids, doc_id_group = test_data
    with open(fname, 'w') as out:
        for _, docs, labels, q_id, d_ids in zip(queries, doc_group, label_group,
                                                query_ids, doc_id_group):
            for _, label, d_id in zip(docs, labels, d_ids):
                out.write('\t'.join([q_id, '0', str(d_id), str(label)]) + '\n')
    print("qrels done. Saved as %s" % fname)
def save_model_pred(test_data, fname, similarity_fn):
    """Score every (query, document) pair and write a TREC-format "pred" file.

    Each output line holds six tab-separated columns:
    query_id, "Q0", document_id, "99", model_score, "STANDARD".
    The "Q0", rank ("99" — an arbitrary value) and "STANDARD" columns are
    placeholders that trec_eval ignores.

    Example
    -------
    Q1  Q0  D1-0    99  0.64426434  STANDARD
    Q1  Q0  D1-1    99  0.26972288  STANDARD

    Parameters
    ----------
    test_data : tuple
        (queries, doc_group, label_group, query_ids, doc_id_group) as
        produced by WikiReaderStatic.get_data()
    fname : str
        Path of the pred file to write.
    similarity_fn : callable
        similarity_fn(query, doc) -> float relevance score, where both
        arguments are lists of str tokens.
    """
    queries, doc_group, label_group, query_ids, doc_id_group = test_data
    with open(fname, 'w') as out:
        for query, docs, labels, q_id, d_ids in zip(queries, doc_group, label_group,
                                                    query_ids, doc_id_group):
            for doc, _, d_id in zip(docs, labels, d_ids):
                score = similarity_fn(query, doc)
                out.write('\t'.join([q_id, 'Q0', str(d_id), '99', str(score),
                                     'STANDARD']) + '\n')
    print("Prediction done. Saved as %s" % fname)
def dtks_similarity_fn(q, d):
    """Score a single (query, doc) pair with the trained DRMM_TKS model.

    Relies on the module-level ``drmm_tks_model`` created in ``__main__``.

    Parameters
    ----------
    q : list of str
        Tokenized query.
    d : list of str
        Tokenized candidate document.

    Returns
    -------
    float
        Relevance score of `d` with respect to `q`.
    """
    # predict takes batched input; wrap the pair and unwrap the single score.
    batch_scores = drmm_tks_model.predict([q], [[d]])
    return batch_scores[0][0]
def mp_similarity_fn(q, d):
    """Similarity Function for MatchPyramid
    (BUG FIX: the docstring previously said "DRMM TKS" — this scorer uses
    the module-level ``mp_model`` created in ``__main__``.)

    Parameters
    ----------
    q : list of str
        Tokenized query.
    d : list of str
        Tokenized candidate document.

    Returns
    -------
    similarity_score : float
        Relevance score of `d` with respect to `q`.
    """
    return mp_model.predict([q], [[d]])[0][0]
if __name__ == '__main__':
    wikiqa_folder = os.path.join('..', '..', 'data', 'WikiQACorpus')
    squad_t_path = os.path.join('..', '..', 'data', 'SQUAD-T-QA.tsv')

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_type', required=False,
                        help='the model to be evaluated (mp, dtks, bidaf_t)')
    args = parser.parse_args()
    model_type = args.model_type

    # Decide which models to evaluate.
    do_bidaf_t, do_mp, do_dtks = False, False, False
    if model_type == 'dtks':
        # BUG FIX: this branch previously set do_bidaf_t = True, so
        # `--model_type dtks` ran BiDAF_T instead of DRMM_TKS.
        do_dtks = True
    elif model_type == 'mp':
        do_mp = True
    elif model_type == 'bidaf_t':
        do_bidaf_t = True
    else:  # no (or unrecognised) --model_type: evaluate all three models
        do_bidaf_t, do_mp, do_dtks = True, True, True

    # Streaming readers over the fixed WikiQA train/dev/test splits.
    q_iterable = WikiReaderIterable('query', os.path.join(wikiqa_folder, 'WikiQA-train.tsv'))
    d_iterable = WikiReaderIterable('doc', os.path.join(wikiqa_folder, 'WikiQA-train.tsv'))
    l_iterable = WikiReaderIterable('label', os.path.join(wikiqa_folder, 'WikiQA-train.tsv'))

    q_val_iterable = WikiReaderIterable('query', os.path.join(wikiqa_folder, 'WikiQA-dev.tsv'))
    d_val_iterable = WikiReaderIterable('doc', os.path.join(wikiqa_folder, 'WikiQA-dev.tsv'))
    l_val_iterable = WikiReaderIterable('label', os.path.join(wikiqa_folder, 'WikiQA-dev.tsv'))

    q_test_iterable = WikiReaderIterable('query', os.path.join(wikiqa_folder, 'WikiQA-test.tsv'))
    d_test_iterable = WikiReaderIterable('doc', os.path.join(wikiqa_folder, 'WikiQA-test.tsv'))
    l_test_iterable = WikiReaderIterable('label', os.path.join(wikiqa_folder, 'WikiQA-test.tsv'))

    # Fully materialised test set, needed for the TREC-format output files.
    test_data = WikiReaderStatic(os.path.join(wikiqa_folder, 'WikiQA-test.tsv')).get_data()

    num_samples_wikiqa = 9000  # number of WikiQA train samples (drives steps_per_epoch)
    num_embedding_dims = 300

    qrels_save_path = 'qrels_wikiqa'
    mp_pred_save_path = 'pred_mp_wikiqa'
    dtks_pred_save_path = 'pred_dtks_wikiqa'
    bidaf_t_pred_save_path = 'pred_bidaf_t_wikiqa'
    bidaf_t_finetuned_pred_save_path = 'pred_bidaf_t_finetuned_wikiqa'

    # qrels depend only on the fixed test split, never on the model.
    print('Saving qrels for WikiQA test data')
    save_qrels(test_data, qrels_save_path)

    # Pretrained GloVe vectors shared by all models.
    kv_model = api.load('glove-wiki-gigaword-' + str(num_embedding_dims))

    if do_bidaf_t:

        def _save_bidaf_pred(model, data, fname):
            """Batch-predict with a BiDAF_T model and write a TREC pred file."""
            queries, doc_group, label_group, query_ids, doc_id_group = data
            with open(fname, 'w') as f:
                for q, doc, labels, q_id, d_ids in zip(queries, doc_group, label_group,
                                                       query_ids, doc_id_group):
                    batch_score = model.batch_predict(q, doc)
                    for d, l, d_id, bscore in zip(doc, labels, d_ids, batch_score):
                        # bscore[1]: score of the second (relevant) class — TODO confirm
                        f.write(q_id + '\t' + 'Q0' + '\t' + str(d_id) + '\t' + '99' + '\t' +
                                str(bscore[1]) + '\t' + 'STANDARD' + '\n')
            print("Prediction done. Saved as %s" % fname)

        q_squad = WikiReaderIterable('query', squad_t_path)
        d_squad = WikiReaderIterable('doc', squad_t_path)
        l_squad = WikiReaderIterable('label', squad_t_path)
        num_squad_samples = 53968

        n_epochs = 20
        batch_size = 100
        steps_per_epoch_squad = num_squad_samples // batch_size

        print('Pretraining on SQUAD-T dataset')
        # BUG FIX: pretraining previously passed the WikiQA train iterables even
        # though the message (and steps_per_epoch_squad) refer to SQUAD-T;
        # pretrain on the SQUAD-T readers instead.
        bidaf_t_model = BiDAF_T(q_squad, d_squad, l_squad, kv_model, n_epochs=n_epochs,
                                steps_per_epoch=steps_per_epoch_squad)

        print('Testing on WikiQA-test')
        _save_bidaf_pred(bidaf_t_model, test_data, bidaf_t_pred_save_path)

        print('FineTuning on WikiQA-train set')
        finetune_batch_size = 100
        steps_per_epoch = num_samples_wikiqa // finetune_batch_size
        bidaf_t_model.train(queries=q_iterable, docs=d_iterable, labels=l_iterable,
                            batch_size=finetune_batch_size, steps_per_epoch=steps_per_epoch)

        print('Testing on WikiQA-test after finetuning')
        _save_bidaf_pred(bidaf_t_model, test_data, bidaf_t_finetuned_pred_save_path)

    if do_mp:
        n_epochs = 2
        batch_size = 10
        text_maxlen = 100
        steps_per_epoch = num_samples_wikiqa // batch_size

        # Train MatchPyramid on the WikiQA train split.
        mp_model = MatchPyramid(
            queries=q_iterable, docs=d_iterable, labels=l_iterable, word_embedding=kv_model,
            epochs=n_epochs, steps_per_epoch=steps_per_epoch, batch_size=batch_size,
            text_maxlen=text_maxlen, unk_handle_method='zero'
        )

        print('Test set results')
        mp_model.evaluate(q_test_iterable, d_test_iterable, l_test_iterable)

        print('Saving prediction on test data in TREC format')
        save_model_pred(test_data, mp_pred_save_path, mp_similarity_fn)

    if do_dtks:
        batch_size = 10
        steps_per_epoch = num_samples_wikiqa // batch_size
        n_epochs = 6

        # Train DRMM_TKS on the WikiQA train split.
        drmm_tks_model = DRMM_TKS(
            queries=q_iterable, docs=d_iterable, labels=l_iterable, word_embedding=kv_model,
            epochs=n_epochs, topk=20, steps_per_epoch=steps_per_epoch, batch_size=batch_size
        )

        print('Test set results')
        drmm_tks_model.evaluate(q_test_iterable, d_test_iterable, l_test_iterable)

        print('Saving prediction on test data in TREC format')
        save_model_pred(test_data, dtks_pred_save_path, dtks_similarity_fn)