-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
125 lines (79 loc) · 3.2 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from argparse import ArgumentParser
import json
import torch
import random
import numpy as np
from data import EACDataset
from util import generate_vocabs, score_graphs_gold_AI
import os
from model import Model
from pprint import pprint
from joblib import Parallel, delayed
import pickle as pkl
from tqdm import tqdm
# Fix all RNG seeds at import time so prompting/prediction runs are
# reproducible across torch, Python's random, and NumPy.
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)
def main():
    """Run event-argument classification in one of two modes.

    ``prompting``: score every (sentence, event) pair with the language
    model and pickle the resulting per-event argument score lists.

    ``predict``: load the previously pickled score lists, decode final
    argument predictions into a JSON-lines event file, then score the
    predictions against gold with gold argument identification.

    All inputs come from the command line; see ``--help`` for flags.
    """
    parser = ArgumentParser()
    parser.add_argument('--input_dir', type=str, required=True,
                        help='path to the input directory')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='path to the output directory')
    parser.add_argument('--dataset_name', type=str, required=True,
                        help='name of the dataset')
    parser.add_argument('--split', type=str, required=True,
                        help='split of the dataset')
    parser.add_argument('--lm_name', type=str, required=True,
                        help='name of language model')
    parser.add_argument('--model_name', type=str, required=True,
                        help='name of model')
    parser.add_argument('--mode', type=str, required=True,
                        help='mode of program (prompting or predict)',
                        choices=['prompting', 'predict'])
    args = parser.parse_args()

    # Build the model from its static resource files (argument role/type
    # maps, role-entity-type constraints, and a time-expression lexicon).
    model = Model(arg_role_map_path='src/arg_role_map.json',
                  arg_type_map_path='src/arg_type_map.json',
                  arg_role_entity_type_constraint_map_path=\
                      'src/arg_role_entity_type_constraint_map.json',
                  time_expression_lexicon_path='src/time_expression_lexicon.pickle')

    input_file = os.path.join(args.input_dir,
                              '{}.event.json'.format(args.split))
    input_dataset = EACDataset(input_file)
    vocabs = generate_vocabs([input_dataset])
    input_dataset.numberize(vocabs)

    # Both modes read/write the same pickle of per-event score lists;
    # build the path once instead of duplicating the format string.
    score_list_path = os.path.join(
        args.output_dir,
        '{}_{}_{}_arg_score_list_dataset.pickle'.format(
            args.lm_name, args.dataset_name, args.split))

    if args.mode == 'prompting':
        # One score list per event, grouped per instance.
        arg_score_list_dataset = []
        for instance in tqdm(input_dataset):
            arg_score_list_instance = [
                model.prompting(instance.sentence, event)
                for event in instance.events
            ]
            arg_score_list_dataset.append(arg_score_list_instance)
        # Fix: the original passed open(...) directly to pkl.dump,
        # leaking the file handle; a context manager closes it reliably.
        with open(score_list_path, 'wb') as score_fw:
            pkl.dump(arg_score_list_dataset, score_fw)
    elif args.mode == 'predict':
        output_file = os.path.join(args.output_dir,
                                   "{}_{}_output.event.json".format(args.model_name,
                                                                    args.dataset_name))
        # Fix: same file-handle leak as above, on the read side.
        with open(score_list_path, 'rb') as score_fr:
            arg_score_list_dataset = pkl.load(score_fr)
        with open(output_file, 'w') as fw:
            # zip pairs each instance with its pickled score lists; both
            # must come from the same split for the alignment to hold.
            for instance, arg_score_list_instance in \
                    tqdm(zip(input_dataset, arg_score_list_dataset)):
                output = model.predict(instance, arg_score_list_instance)
                fw.write(json.dumps(output) + '\n')
                fw.flush()

        ## Evaluate predictions against gold annotations.
        gold_dataset = EACDataset(input_file)
        pred_dataset = EACDataset(output_file)
        vocabs = generate_vocabs([gold_dataset, pred_dataset])
        gold_dataset.numberize(vocabs)
        pred_dataset.numberize(vocabs)
        # Fix: dropped a dead counter variable `i` that was incremented
        # but never read; comprehensions collect the aligned graphs.
        paired = list(zip(gold_dataset, pred_dataset))
        gold_graphs = [gold_inst.graph for gold_inst, _ in paired]
        pred_graphs = [pred_inst.graph for _, pred_inst in paired]
        score_graphs_gold_AI(gold_graphs, pred_graphs)


if __name__ == "__main__":
    main()