Bypass out-of-sync Gym registry in SubprocVecEnv by resolving EnvSpec #160

Merged · 10 commits · Jan 22, 2020
17 changes: 16 additions & 1 deletion src/imitation/util/util.py
@@ -49,15 +49,30 @@ def make_vec_env(env_name: str,
max_episode_steps: If specified, wraps VecEnv in TimeLimit wrapper with
this episode length before returning.
Member suggested a change:
- this episode length before returning.
+ this episode length before returning. Otherwise, defaults to `max_episode_steps` for `env_name` in the Gym registry.

@shwang (Member, Author) commented on Jan 22, 2020:

I expanded this comment a bit more in 902fe96. Wanted to note that taking `max_episode_steps` from the Gym registry is the default behavior of `gym.make`.
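
For reference, a minimal sketch of the registry behavior being discussed, assuming a Gym version contemporary with this PR (the TimeLimit wrapping is paraphrased from `gym/envs/registration.py`; exact conditions vary across Gym releases):

```python
import gym
from gym.wrappers.time_limit import TimeLimit

def make_like_gym_make(env_name: str) -> gym.Env:
    """Approximates what `gym.make(env_name)` adds on top of `spec.make()`."""
    spec = gym.spec(env_name)
    # `spec.make()` only instantiates the environment class...
    env = spec.make()
    # ...while `gym.make` additionally applies the registry's episode limit,
    # which is why the diff below adds the same wrapping by hand.
    if spec.max_episode_steps is not None and not spec.tags.get('vnc'):
        env = TimeLimit(env, max_episode_steps=spec.max_episode_steps)
    return env
```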

"""
# Resolve the spec outside of the subprocess first, so that it is available to
# subprocesses running `make_env` via automatic pickling.
spec = gym.spec(env_name)

def make_env(i, this_seed):
env = gym.make(env_name)
# Previously, we directly called `gym.make(env_name)`, but running
# `imitation.scripts.train_adversarial` within `imitation.scripts.parallel`
# created a weird interaction between Gym and Ray -- `gym.make` would fail
# inside this function for any of our custom environments unless those
# environments were also `gym.register()`ed inside `make_env`. Even
# registering the custom environment in the scope of `make_vec_env` didn't
# work. For more discussion and hypotheses on this issue see PR #160:
# https://github.com/HumanCompatibleAI/imitation/pull/160.
env = spec.make()

# Seed each environment with a different, non-sequential seed for diversity
# (even if caller is passing us sequentially-assigned base seeds). int() is
# necessary to work around gym bug where it chokes on numpy int64s.
env.seed(int(this_seed))

if max_episode_steps is not None:
env = TimeLimit(env, max_episode_steps)
elif (spec.max_episode_steps is not None) and not spec.tags.get('vnc'):
env = TimeLimit(env, max_episode_steps=spec.max_episode_steps)

# Use Monitor to record statistics needed for Baselines algorithms logging
# Optionally, save to disk
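
To illustrate why resolving the spec in the parent process helps, here is a minimal sketch of the failure mode, assuming Stable Baselines' `SubprocVecEnv` (which pickles the env-constructing callables into its worker processes); the environment ID is hypothetical:

```python
import gym
from stable_baselines.common.vec_env import SubprocVecEnv

# Suppose "CustomAnt-v0" was gym.register()ed only in this (parent)
# process. (Hypothetical ID, standing in for imitation's custom envs.)

# Brittle: the registry lookup runs inside each worker process, so it
# fails if the workers' registries never saw the register() call.
bad_fns = [lambda: gym.make("CustomAnt-v0") for _ in range(2)]

# Robust: gym.spec() resolves the EnvSpec here, in the parent; the spec
# object is then pickled into each worker along with the closure, so no
# registry lookup happens on the worker side.
spec = gym.spec("CustomAnt-v0")
env_fns = [lambda: spec.make() for _ in range(2)]
venv = SubprocVecEnv(env_fns)
```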
36 changes: 34 additions & 2 deletions tests/test_scripts.py
@@ -144,9 +144,9 @@ def test_transfer_learning(tmpdir):
dict(
sacred_ex_name="expert_demos",
base_named_configs=["cartpole", "fast"],
n_seeds=2,
search_space={
"config_updates": {
"seed": tune.grid_search([0, 1]),
"init_rl_kwargs": {
"learning_rate": tune.grid_search([3e-4, 1e-4]),
},
@@ -171,7 +171,6 @@
),
]


PARALLEL_CONFIG_LOW_RESOURCE = {
# CI server only has 2 cores.
"init_kwargs": {"num_cpus": 2},
@@ -194,6 +193,39 @@ def test_parallel(config_updates):
assert run.status == 'COMPLETED'


def _generate_test_rollouts(tmpdir: str, env_named_config: str) -> str:
expert_demos_ex.run(
named_configs=[env_named_config, "fast"],
config_updates=dict(
rollout_save_interval=0,
log_dir=tmpdir,
))
rollout_path = osp.abspath(f"{tmpdir}/rollouts/final.pkl")
return rollout_path


def test_parallel_train_adversarial_custom_env(tmpdir):
env_named_config = "custom_ant"
rollout_path = _generate_test_rollouts(tmpdir, env_named_config)

config_updates = dict(
sacred_ex_name="train_adversarial",
n_seeds=1,
base_named_configs=[env_named_config, "fast"],
base_config_updates=dict(
init_trainer_kwargs=dict(
parallel=True,
num_vec=2,
),
rollout_path=rollout_path,
),
)
config_updates.update(PARALLEL_CONFIG_LOW_RESOURCE)
run = parallel_ex.run(named_configs=["debug_log_root"],
config_updates=config_updates)
assert run.status == 'COMPLETED'


@pytest.mark.parametrize("run_names", ([], list("adab")))
def test_analyze_imitation(tmpdir: str, run_names: List[str]):
sacred_logs_dir = tmpdir