/
spec_util.py
263 lines (229 loc) · 9.05 KB
/
spec_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
# The spec module
# Manages specification to run things in lab
import itertools
import json
import os
from string import Template
import pydash as ps
from convlab.lib import logger, util
# Directory that holds the spec JSON files; listed by check_all() and prefixed onto filenames by get()
SPEC_DIR = 'convlab/spec'
'''
All spec values are already param, inferred automatically.
To change from a value into param range, e.g.
- single: "explore_anneal_epi": 50
- continuous param: "explore_anneal_epi": {"min": 50, "max": 100, "dist": "uniform"}
- discrete range: "explore_anneal_epi": {"values": [50, 75, 100]}
'''
# Top-level keys every spec must contain, and the type(s) each checked value must have.
# check() validates every agent/env entry of a spec against the single format entry
# in the corresponding list here, and the meta dict against "meta".
# A tuple lists the accepted types; type(None) means the value may be null.
# Commented-out entries are not enforced.
SPEC_FORMAT = {
    "agent": [{
        "name": str,
        "algorithm": dict,
        # "memory": dict,
        # "net": dict,
    }],
    "env": [{
        "name": str,
        "max_t": (type(None), int, float),
        # "max_frame": (int, float),
    }],
    # "body": {
    # "product": ["outer", "inner", "custom"],
    # "num": (int, list),
    # },
    "meta": {
        "eval_frequency": (int, float),
        "max_session": int,
        "max_trial": (type(None), int),
    },
    "name": str,
}
# NOTE: rebinds the imported `logger` module name to this module's logger instance
logger = logger.get_logger(__name__)
def check_comp_spec(comp_spec, comp_spec_format):
    '''
    Base method to check a component spec against its format; may update comp_spec in place.
    For each key of comp_spec_format, the format value is either
    - a list: the spec value must be one of the listed values
    - a type or tuple of types: the spec value must be an instance of it
    A float value whose format tuple also admits int is cast to int, but only when
    it is integral (e.g. 10000.0), so a fractional value is never silently truncated
    and inf/nan never reach int().
    @param dict comp_spec: component spec to check
    @param dict comp_spec_format: maps each required key to its allowed values or type(s)
    @raises KeyError: if comp_spec lacks a key of comp_spec_format
    @raises AssertionError: if a value is not allowed or has the wrong type
    '''
    for spec_k, spec_format_v in comp_spec_format.items():
        comp_spec_v = comp_spec[spec_k]
        if isinstance(spec_format_v, list):
            # a list format enumerates the allowed values
            v_set = spec_format_v
            assert comp_spec_v in v_set, f'Component spec value {ps.pick(comp_spec, spec_k)} needs to be one of {util.to_json(v_set)}'
            continue
        v_type = spec_format_v
        assert isinstance(comp_spec_v, v_type), f'Component spec {ps.pick(comp_spec, spec_k)} needs to be of type: {v_type}'
        accepts_int = isinstance(v_type, tuple) and int in v_type
        if accepts_int and isinstance(comp_spec_v, float) and comp_spec_v.is_integer():
            # cast only if it can be int without losing information
            comp_spec[spec_k] = int(comp_spec_v)
def check_body_spec(spec):
    '''
    Base method to check body spec for multi-agent multi-env.
    Reads `body.product` and `body.num` from spec, then:
    - 'outer': nothing to check
    - 'inner': the agent and env spec lists must have equal length
    - anything else (custom): `body.num` must be a list
    @param dict spec: full spec holding 'body', and 'agent'/'env' lists for the 'inner' case
    @raises AssertionError: if the check for the given product fails
    '''
    ae_product = ps.get(spec, 'body.product')
    body_num = ps.get(spec, 'body.num')
    if ae_product == 'outer':
        pass
    elif ae_product == 'inner':
        agent_num = len(spec['agent'])
        env_num = len(spec['env'])
        # must be an f-string so the offending lengths actually appear in the message
        assert agent_num == env_num, f'Agent and Env spec length must be equal for body `inner` product. Given {agent_num}, {env_num}'
    else:  # custom
        assert ps.is_list(body_num)
def check_compatibility(spec):
    '''
    Check compatibility among spec setups.
    Currently the only rule: when meta 'distributed' is 'synced', the first agent's
    `net.gpu` must compare equal to False.
    @raises AssertionError: if the rule is violated
    '''
    # TODO expand to be more comprehensive
    distributed_mode = spec['meta'].get('distributed')
    if distributed_mode != 'synced':
        return
    first_agent_gpu = ps.get(spec, 'agent.0.net.gpu')
    assert first_agent_gpu == False, 'Distributed mode "synced" works with CPU only. Set gpu: false.'
def check(spec):
    '''
    Check a single spec for validity.
    Verifies that spec has every top-level key of SPEC_FORMAT, checks each agent and
    env entry and the meta dict with check_comp_spec, then runs check_compatibility.
    Any failure is logged with the spec name and re-raised.
    @returns True when all checks pass
    '''
    try:
        spec_name = spec.get('name')
        missing_keys = set(SPEC_FORMAT.keys()) - set(spec.keys())
        assert not missing_keys, f'Spec needs to follow spec.SPEC_FORMAT. Given \n {spec_name}: {util.to_json(spec)}'
        # agent entries first, then env entries, each against its single format entry
        for comp_key in ('agent', 'env'):
            comp_format = SPEC_FORMAT[comp_key][0]
            for comp_spec in spec[comp_key]:
                check_comp_spec(comp_spec, comp_format)
        # check_comp_spec(spec['body'], SPEC_FORMAT['body'])
        check_comp_spec(spec['meta'], SPEC_FORMAT['meta'])
        # check_body_spec(spec)
        check_compatibility(spec)
    except Exception as err:
        logger.exception(f'spec {spec_name} fails spec check')
        raise err
    return True
def check_all():
    '''
    Check all specs in all spec files.
    Reads every `.json` file in SPEC_DIR whose name does not start with '_', fills in
    each spec's name and lab meta info, and runs check() on it.
    A failure is logged with the file name and re-raised.
    @returns True when every spec passes
    '''
    spec_files = [
        fname for fname in os.listdir(SPEC_DIR)
        if fname.endswith('.json') and not fname.startswith('_')
    ]
    for spec_file in spec_files:
        specs_in_file = util.read(f'{SPEC_DIR}/{spec_file}')
        for name, raw_spec in specs_in_file.items():
            # fill-in info at runtime
            raw_spec['name'] = name
            full_spec = extend_meta_spec(raw_spec)
            try:
                check(full_spec)
            except Exception as err:
                logger.exception(f'spec_file {spec_file} fails spec check')
                raise err
    logger.info(f'Checked all specs from: {ps.join(spec_files, ",")}')
    return True
def extend_meta_spec(spec):
    '''
    Extend the meta spec in place with information for lab functions, and return spec.
    The keys below always overwrite existing meta values; 'eval_model_prepath' is only
    set (to None) when it is absent.
    '''
    lab_meta = dict(
        # reset lab indices to -1 so that they tick to 0
        experiment=-1,
        trial=-1,
        session=-1,
        # taken from env var CUDA_OFFSET, 0 when unset
        cuda_offset=int(os.environ.get('CUDA_OFFSET', 0)),
        experiment_ts=util.get_ts(),
        prepath=None,
        # ckpt extends prepath, e.g. ckpt_str = ckpt-epi10-totalt1000
        ckpt=None,
        git_sha=util.get_git_sha(),
        random_seed=None,
    )
    meta_spec = spec['meta']
    meta_spec.update(lab_meta)
    meta_spec.setdefault('eval_model_prepath', None)
    return spec
def get(spec_file, spec_name):
    '''
    Get an experiment spec from spec_file, spec_name, and auto-check it.
    A spec_file containing 'data/' is read directly as a single spec and its path
    must contain spec_name; any other spec_file is looked up under SPEC_DIR and must
    hold spec_name as a key.
    @example
    spec = spec_util.get('base.json', 'base_case_openai')
    '''
    spec_file = spec_file.replace(SPEC_DIR, '')  # cleanup
    if 'data/' not in spec_file:
        spec_file = f'{SPEC_DIR}/{spec_file}'  # allow direct filename
        spec_dict = util.read(spec_file)
        assert spec_name in spec_dict, f'spec_name {spec_name} is not in spec_file {spec_file}. Choose from:\n {ps.join(spec_dict.keys(), ",")}'
        loaded_spec = spec_dict[spec_name]
    else:
        assert spec_name in spec_file, 'spec_file in data/ must be lab-generated and contains spec_name'
        loaded_spec = util.read(spec_file)
    # fill-in info at runtime
    loaded_spec['name'] = spec_name
    loaded_spec = extend_meta_spec(loaded_spec)
    check(loaded_spec)
    return loaded_spec
def get_eval_spec(spec_file, spec_name, prename=None):
    '''
    Get spec for eval mode: the spec from get(), with meta 'ckpt' set to 'eval'
    and meta 'eval_model_prepath' set to prename.
    '''
    eval_spec = get(spec_file, spec_name)
    eval_spec['meta'].update(ckpt='eval', eval_model_prepath=prename)
    return eval_spec
def get_param_specs(spec):
    '''
    Return a list of specs with substituted spec_params.
    The spec is serialized to JSON and used as a string.Template; one spec is produced
    for every combination (cartesian product) of the spec_params values, with the
    `${key}` placeholders substituted and the joined values appended to the spec name.
    Values may be non-strings (e.g. numbers); they are stringified for the name suffix.
    NOTE: 'spec_params' is popped from the given spec, i.e. the input is mutated.
    @param dict spec: spec with a 'spec_params' dict mapping placeholder name to a list of values,
        and meta keys 'max_session', 'param_spec_process', 'cuda_offset'
    @returns list of substituted specs, in product order
    @raises AssertionError: if spec has no 'spec_params' key
    '''
    assert 'spec_params' in spec, 'Parametrized spec needs a spec_params key'
    spec_params = spec.pop('spec_params')
    spec_template = Template(json.dumps(spec))
    keys = list(spec_params)
    specs = []
    for idx, vals in enumerate(itertools.product(*spec_params.values())):
        spec_str = spec_template.substitute(dict(zip(keys, vals)))
        param_spec = json.loads(spec_str)
        # str() each value: joining raw numeric values would raise TypeError
        name_suffix = '_'.join(str(val) for val in vals)
        param_spec['name'] += f'_{name_suffix}'
        # offset to prevent parallel-run GPU competition, to mod in util.set_cuda_id
        cuda_id_gap = int(param_spec['meta']['max_session'] / param_spec['meta']['param_spec_process'])
        param_spec['meta']['cuda_offset'] += idx * cuda_id_gap
        specs.append(param_spec)
    return specs
def override_dev_spec(spec):
    '''Override meta spec in place to 1 session and 2 trials (dev run, per the function name); return spec.'''
    spec['meta'].update(max_session=1, max_trial=2)
    return spec
#def override_enjoy_spec(spec):
# spec['meta']['max_session'] = 1
# return spec
def override_eval_spec(spec):
    '''Override meta spec in place to a single session (eval run, per the function name); return spec.'''
    meta_spec = spec['meta']
    meta_spec['max_session'] = 1
    # evaluate by episode is set in env clock init in env/base.py
    return spec
def override_test_spec(spec):
    '''
    Override spec in place with small values for a quick run (test run, per the function name), and return it.
    - every agent: training frequency 1 if its memory is 'OnPolicyReplay' else 8,
      and training start step / iter / batch iter all 1
    - every env: max_frame 40, max_t 12
    - meta: log and eval frequency 10, 1 session, 2 trials
    An agent without a memory spec is handled as not on-policy (frequency 8) instead
    of raising KeyError, since SPEC_FORMAT does not require 'memory'.
    @param dict spec: spec with 'agent' and 'env' lists and a 'meta' dict
    '''
    for agent_spec in spec['agent']:
        # memory may be absent or null for agents that do not use one
        memory_name = (agent_spec.get('memory') or {}).get('name')
        # onpolicy freq is episodic
        freq = 1 if memory_name == 'OnPolicyReplay' else 8
        agent_spec['algorithm'].update(
            training_frequency=freq,
            training_start_step=1,
            training_iter=1,
            training_batch_iter=1,
        )
    for env_spec in spec['env']:
        env_spec['max_frame'] = 40
        env_spec['max_t'] = 12
    spec['meta'].update(
        log_frequency=10,
        eval_frequency=10,
        max_session=1,
        max_trial=2,
    )
    return spec
def save(spec, unit='experiment'):
    '''
    Save spec to its proper path: the prepath for the given unit, suffixed with `_spec.json`.
    Documented as called at Experiment or Trial init.
    '''
    spec_path = f'{util.get_prepath(spec, unit)}_spec.json'
    util.write(spec, spec_path)
def tick(spec, unit):
    '''
    Tick a lab unit (experiment, trial, session) in the meta spec to advance its index.
    Lower lab indices are reset to -1 so that they tick to 0; when a lower unit is
    ticked while a higher index is still -1, that higher index is advanced to 0 first.
    Afterwards meta 'prepath' and the per-folder prepaths (graph, info, log, model)
    are set, and the folders are created on disk.
    @example
    spec_util.tick(spec, 'session')
    session = Session(spec)
    @raises ValueError: if unit is not one of 'experiment', 'trial', 'session'
    '''
    meta = spec['meta']
    if unit not in ('experiment', 'trial', 'session'):
        raise ValueError(f'Unrecognized lab unit to tick: {unit}')

    if unit == 'experiment':
        meta['experiment_ts'] = util.get_ts()
        meta['experiment'] += 1
        meta['trial'] = -1
        meta['session'] = -1
    else:
        # trial or session: make sure the enclosing experiment has started
        if meta['experiment'] == -1:
            meta['experiment'] += 1
        if unit == 'trial':
            meta['trial'] += 1
            meta['session'] = -1
        else:
            # session: make sure the enclosing trial has started
            if meta['trial'] == -1:
                meta['trial'] += 1
            meta['session'] += 1

    # set prepath since it is determined at this point
    prepath = util.get_prepath(spec, unit)
    meta['prepath'] = prepath
    for folder in ('graph', 'info', 'log', 'model'):
        folder_prepath = util.insert_folder(prepath, folder)
        folder_dir = os.path.dirname(util.smart_path(folder_prepath))
        os.makedirs(folder_dir, exist_ok=True)
        meta[f'{folder}_prepath'] = folder_prepath
    return spec