import logging
import os
import os.path as osp
from typing import Optional

from sacred.observers import FileStorageObserver
from stable_baselines3.common import logger as sb_logger
from stable_baselines3.common.vec_env import VecNormalize

import imitation.util.sacred as sacred_util
from imitation.data import rollout
from imitation.policies import serialize
from imitation.rewards.serialize import load_reward
from imitation.scripts.config.expert_demos import expert_demos_ex
from imitation.util import logger, util
from imitation.util.reward_wrapper import RewardVecEnvWrapper


@expert_demos_ex.main
def rollouts_and_policy(
    _run,
    _seed: int,
    env_name: str,
    total_timesteps: int,
    *,
    log_dir: str,
    num_vec: int,
    parallel: bool,
    max_episode_steps: Optional[int],
    normalize: bool,
    normalize_kwargs: dict,
    init_rl_kwargs: dict,
    n_episodes_eval: int,
    reward_type: Optional[str],
    reward_path: Optional[str],
    rollout_save_interval: int,
    rollout_save_final: bool,
    rollout_save_n_timesteps: Optional[int],
    rollout_save_n_episodes: Optional[int],
    policy_save_interval: int,
    policy_save_final: bool,
    init_tensorboard: bool,
) -> dict:
"""Trains an expert policy from scratch and saves the rollouts and policy.
Checkpoints:
At applicable training steps `step` (where step is either an integer or
"final"):
- Policies are saved to `{log_dir}/policies/{step}/`.
- Rollouts are saved to `{log_dir}/rollouts/{step}.pkl`.
Args:
env_name: The gym.Env name. Loaded as VecEnv.
total_timesteps: Number of training timesteps in `model.learn()`.
log_dir: The root directory to save metrics and checkpoints to.
num_vec: Number of environments in VecEnv.
parallel: If True, then use DummyVecEnv. Otherwise use SubprocVecEnv.
max_episode_steps: If not None, then environments are wrapped by
TimeLimit so that they have at most `max_episode_steps` steps per
episode.
normalize: If True, then rescale observations and reward.
normalize_kwargs: kwargs for `VecNormalize`.
init_rl_kwargs: kwargs for `init_rl`.
n_episodes_eval: The number of episodes to average over when calculating
the average ground truth reward return of the final policy.
reward_type: If provided, then load the serialized reward of this type,
wrapping the environment in this reward. This is useful to test
whether a reward model transfers. For more information, see
`imitation.rewards.serialize.load_reward`.
reward_path: A specifier, such as a path to a file on disk, used by
reward_type to load the reward model. For more information, see
`imitation.rewards.serialize.load_reward`.
rollout_save_interval: The number of training updates in between
intermediate rollout saves. If the argument is nonpositive, then
don't save intermediate updates.
rollout_save_final: If True, then save rollouts right after training is
finished.
rollout_save_n_timesteps: The minimum number of timesteps saved in every
file. Could be more than `rollout_save_n_timesteps` because
trajectories are saved by episode rather than by transition.
Must set exactly one of `rollout_save_n_timesteps`
and `rollout_save_n_episodes`.
rollout_save_n_episodes: The number of episodes saved in every
file. Must set exactly one of `rollout_save_n_timesteps` and
`rollout_save_n_episodes`.
policy_save_interval: The number of training updates between saves. Has
the same semantics are `rollout_save_interval`.
policy_save_final: If True, then save the policy right after training is
finished.
init_tensorboard: If True, then write tensorboard logs to {log_dir}/sb_tb
and "output/summary/...".
Returns:
The return value of `rollout_stats()` using the final policy.
"""
    os.makedirs(log_dir, exist_ok=True)
    sacred_util.build_sacred_symlink(log_dir, _run)

    sample_until = rollout.make_sample_until(
        rollout_save_n_timesteps, rollout_save_n_episodes
    )
    eval_sample_until = rollout.min_episodes(n_episodes_eval)

    logging.basicConfig(level=logging.INFO)
    logger.configure(
        folder=osp.join(log_dir, "rl"), format_strs=["tensorboard", "stdout"]
    )

    rollout_dir = osp.join(log_dir, "rollouts")
    policy_dir = osp.join(log_dir, "policies")
    os.makedirs(rollout_dir, exist_ok=True)
    os.makedirs(policy_dir, exist_ok=True)

    if init_tensorboard:
        # sb_tensorboard_dir = osp.join(log_dir, "sb_tb")
        # Convert sacred's ReadOnlyDict to dict so we can modify on next line.
        init_rl_kwargs = dict(init_rl_kwargs)
        # init_rl_kwargs["tensorboard_log"] = sb_tensorboard_dir
        # FIXME(sam): this is another hack to prevent SB3 from configuring the
        # logger on the first .learn() call. Remove it once SB3 issue #109 is
        # fixed.
        init_rl_kwargs["tensorboard_log"] = None
        # init_rl_kwargs["tensorboard_log"] = "{}/tensorboard".format(log_dir)

    venv = util.make_vec_env(
        env_name,
        num_vec,
        seed=_seed,
        parallel=parallel,
        log_dir=log_dir,
        max_episode_steps=max_episode_steps,
    )

    log_callbacks = []
    if reward_type is not None:
        reward_fn = load_reward(reward_type, reward_path, venv)
        venv = RewardVecEnvWrapper(venv, reward_fn)
        log_callbacks.append(venv.log_callback)
        logging.info(f"Wrapped env in reward {reward_type} from {reward_path}.")

    vec_normalize = None
    if normalize:
        venv = vec_normalize = VecNormalize(venv, **normalize_kwargs)

    policy = util.init_rl(venv, verbose=1, **init_rl_kwargs)

    # Make callback to save intermediate artifacts during training.
    step = 0

    def callback(locals_: dict, _) -> bool:
        nonlocal step
        step += 1
        policy = locals_["self"]

        # TODO(adam): make logging frequency configurable
        for log_callback in log_callbacks:
            log_callback(sb_logger)

        if rollout_save_interval > 0 and step % rollout_save_interval == 0:
            save_path = osp.join(rollout_dir, f"{step}.pkl")
            rollout.rollout_and_save(save_path, policy, venv, sample_until)
        if policy_save_interval > 0 and step % policy_save_interval == 0:
            logging.info(f"step: {step}")
            output_dir = os.path.join(policy_dir, f"{step:05d}")
            serialize.save_stable_model(output_dir, policy, vec_normalize)
        # Returning True tells Stable Baselines 3 to continue training.
        return True

    policy.learn(total_timesteps, callback=callback)

    # Save final artifacts after training is complete.
    if rollout_save_final:
        save_path = osp.join(rollout_dir, "final.pkl")
        rollout.rollout_and_save(save_path, policy, venv, sample_until)
    if policy_save_final:
        output_dir = os.path.join(policy_dir, "final")
        serialize.save_stable_model(output_dir, policy, vec_normalize)

    # Final evaluation of expert policy.
    trajs = rollout.generate_trajectories(policy, venv, eval_sample_until)
    stats = rollout.rollout_stats(trajs)
    return stats
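

# A minimal sketch (not used above) of the kind of termination condition that
# `rollout.make_sample_until` builds. It assumes each trajectory stores
# observations in `obs` with one more entry than there are transitions;
# `_min_timesteps_sample_until` is a hypothetical helper for illustration,
# not part of `imitation`.
def _min_timesteps_sample_until(n_timesteps: int):
    """Return a predicate that is True once `n_timesteps` transitions exist."""

    def sample_until(trajectories) -> bool:
        # Each trajectory contributes len(obs) - 1 transitions.
        total = sum(len(traj.obs) - 1 for traj in trajectories)
        return total >= n_timesteps

    return sample_until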


@expert_demos_ex.command
def rollouts_from_policy(
    _run,
    _seed: int,
    *,
    num_vec: int,
    rollout_save_n_timesteps: int,
    rollout_save_n_episodes: int,
    log_dir: str,
    policy_path: str,
    policy_type: str,
    env_name: str,
    parallel: bool,
    rollout_save_path: str,
    max_episode_steps: Optional[int],
) -> None:
"""Loads a saved policy and generates rollouts.
Unlisted arguments are the same as in `rollouts_and_policy()`.
Args:
policy_type: Argument to `imitation.policies.serialize.load_policy`.
policy_path: Argument to `imitation.policies.serialize.load_policy`.
rollout_save_path: Rollout pickle is saved to this path.
"""
    os.makedirs(log_dir, exist_ok=True)
    sacred_util.build_sacred_symlink(log_dir, _run)

    sample_until = rollout.make_sample_until(
        rollout_save_n_timesteps, rollout_save_n_episodes
    )

    venv = util.make_vec_env(
        env_name,
        num_vec,
        seed=_seed,
        parallel=parallel,
        log_dir=log_dir,
        max_episode_steps=max_episode_steps,
    )

    policy = serialize.load_policy(policy_type, policy_path, venv)
    rollout.rollout_and_save(rollout_save_path, policy, venv, sample_until)
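

# A hedged usage sketch: reading back a rollout pickle written by
# `rollout.rollout_and_save` above. This assumes the file holds a pickled
# sequence of trajectory objects as produced by `imitation.data.rollout`;
# `load_trajectories` is an illustrative helper, not part of `imitation`.
def load_trajectories(path: str):
    """Load the trajectories saved at `path` by `rollout_and_save`."""
    import pickle

    with open(path, "rb") as f:
        return pickle.load(f)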


def main_console(log_root, config_updates, named_configs):
    # observer = FileStorageObserver(osp.join("output", "sacred", "expert_demos"))
    observer = FileStorageObserver(log_root)
    expert_demos_ex.observers.append(observer)
    # expert_demos_ex.run_commandline()
    expert_demos_ex.run(config_updates=config_updates, named_configs=named_configs)


if __name__ == "__main__":  # pragma: no cover
    policy_type = "ppo"
    env_name = "CartPole-v1"
    log_root = f"expert/{policy_type}_{env_name}/"
    config_updates = {
        "log_dir": log_root,
        "policy_type": policy_type,
        "init_tensorboard": True,
        "parallel": False,
        "num_vec": 1,
    }
    dict_named_configs = {
        "CartPole-v1": ["cartpole"],
    }
    named_configs = dict_named_configs[env_name]
    main_console(f"{log_root}expert_demos", config_updates, named_configs)
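
# With the settings above, and following the checkpoint layout documented in
# `rollouts_and_policy`, outputs land under (a hedged sketch; exact step
# numbers depend on the save intervals in the `cartpole` config):
#   expert/ppo_CartPole-v1/policies/{step}/    - intermediate/final policies
#   expert/ppo_CartPole-v1/rollouts/{step}.pkl - pickled trajectory rollouts
#   expert/ppo_CartPole-v1/expert_demos/       - sacred FileStorageObserver runs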