-
Notifications
You must be signed in to change notification settings - Fork 2
/
inference_py.py
127 lines (97 loc) · 3.95 KB
/
inference_py.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from utils.exec_py import exec_py
from accelerate.utils import set_seed
from tqdm import tqdm
import argparse
import json
import numpy as np
import openai
import random
import torch
def select_models(args):
    """Build the (generator, evaluator) pair named on the command line.

    Generator names starting with "gpt" use the OpenAI backend; any other
    name is loaded as a HuggingFace model on CUDA. The evaluator is picked
    by name ("oracle", a "codellama" prefix, or a "gpt" prefix); any other
    evaluator name leaves evaluator as None.

    Imports are deferred into each branch so only the selected backend's
    module (and its heavy dependencies) gets loaded.
    """
    if args.generator_name.startswith("gpt"):
        from generators.openai_generator import OpenaiGenerator
        generator = OpenaiGenerator(args.generator_name)
    else:
        from generators.hf_generator import HFGenerator
        generator = HFGenerator(args.generator_name, device="cuda")

    evaluator = None
    eval_name = args.evaluator_name
    if eval_name == "oracle":
        from evaluators.oracle_evaluator_py import OracleEvaluatorPy
        evaluator = OracleEvaluatorPy(args.oracle_prob)
    elif eval_name.startswith("codellama"):
        from evaluators.codellama_evaluator_py import CodeLlamaEvaluatorPy
        evaluator = CodeLlamaEvaluatorPy(eval_name, device="cuda")
    elif eval_name.startswith("gpt"):
        from evaluators.openai_evaluator_py import OpenaiEvaluatorPy
        evaluator = OpenaiEvaluatorPy(eval_name)

    return generator, evaluator
def select_method(method_name):
    """Return the planning-method function registered under *method_name*.

    Known names: "mctot", "greedy", "rerank", "iter_corr". Imports are
    deferred so only the selected method's module is loaded.

    Raises:
        ValueError: if *method_name* is not a known method. (ValueError is
        a subclass of Exception, so callers that caught the old generic
        Exception still work.)
    """
    if method_name == "mctot":
        from planning_methods.mc_tot_py import mc_tot_py
        return mc_tot_py
    if method_name == "greedy":
        from planning_methods.greedy_py import greedy_py
        return greedy_py
    if method_name == "rerank":
        from planning_methods.rerank_py import rerank_py
        return rerank_py
    if method_name == "iter_corr":
        from planning_methods.iter_correction_py import iter_correction_py
        return iter_correction_py
    # Name the offending value so a CLI typo is diagnosable from the traceback.
    raise ValueError(f"Invalid method: {method_name!r}")
def inference(generator, evaluator, args):
    """Run the selected planning method over every test example and log results.

    For each example: generate a candidate solution, execute it with
    exec_py, and score 1 iff the executed answer equals ex["answer"].
    A per-example log (solution, predicted answer, accuracy) is written as
    JSON to log/<args.log_fname>; overall accuracy is printed at the end.

    Args:
        generator: model that proposes solutions (passed through to the method).
        evaluator: model that scores intermediate steps (passed through).
        args: parsed CLI namespace; uses test_fname, method_name, log_fname.
    """
    # Close the test file promptly instead of leaking the handle.
    with open(args.test_fname, encoding="utf-8") as f:
        test_data = json.load(f)
    method = select_method(args.method_name)
    results = []
    log = []
    for ex in tqdm(test_data):
        solution, example_log = method(ex, generator, evaluator, args)
        example_log["solution"] = solution
        try:
            answer = exec_py(solution)
        except Exception:
            # Generated code may raise anything; score it 0 and move on.
            # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
            results.append(0)
            example_log["pred_answer"] = "ERROR"
            example_log["acc"] = 0
            log.append(example_log)
            continue
        acc = int(answer is not None and answer == ex["answer"])
        results.append(acc)
        example_log["pred_answer"] = str(answer)
        example_log["acc"] = acc
        log.append(example_log)
    with open("log/" + args.log_fname, "w+", encoding="utf-8") as out:
        json.dump(log, out, indent=2)
    # Guard against an empty test set (previously a ZeroDivisionError).
    if results:
        print("Accuracy: {:<20.4f}".format(sum(results) / len(results)))
    else:
        print("Accuracy: n/a (no test examples)")
def set_seed_all(seed):
    """Seed every RNG the pipeline touches so runs are reproducible.

    Covers Python's random module, NumPy, PyTorch (CPU), accelerate's
    set_seed, and — when CUDA is available — all CUDA devices.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, set_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
if __name__ == "__main__":
    # CLI entry point: parse options, seed all RNGs, build the
    # generator/evaluator pair, then run inference over the test set.
    parser = argparse.ArgumentParser()
    option_specs = [
        ('--test_fname', str, 'data/spider_dev.json'),
        ('--log_fname', str, 'spider_dev.json'),
        ('--method_name', str, 'mctot'),
        ('--generator_name', str, ''),
        ('--evaluator_name', str, ''),  # e.g. codellama/CodeLlama-13b-Instruct-hf
        ('--oracle_prob', float, 1.0),
        ('--seed', int, 42),
        ('--generation_config', str, 'generation_configs/temp_sampling.json'),
        ('--evaluation_config', str, ''),
    ]
    for flag, flag_type, flag_default in option_specs:
        parser.add_argument(flag, type=flag_type, default=flag_default)
    cli_args = parser.parse_args()
    set_seed_all(cli_args.seed)
    generator, evaluator = select_models(cli_args)
    inference(generator, evaluator, cli_args)