Skip to content

Commit

Permalink
Add methods for calling env methods/setting attributes inside a VecEn…
Browse files Browse the repository at this point in the history
…v (#71)

* added utility method for calling custom env methods

* style compliance; docstring

* added env_method() in dummy

* fixed docstring formatting; looping over dummy envs

* added get/set_attr feature to vectorized envs

* added note about picklability

* added 'indices' arg to set_attr() to target specific sub_envs

* fixed iteration over unspecified number of indices

* made test and added 'indices' arg to DummyVecEnv.set_attr()

* removed import numpy

* Fix VecEnv test

* Style fixes

* [ci skip] Update changelog

* Remove unused 'self'

* Attempt to fix memory issue for travis
  • Loading branch information
bjmuld authored and araffin committed Nov 4, 2018
1 parent 08de556 commit 7c95b74
Show file tree
Hide file tree
Showing 7 changed files with 202 additions and 13 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ For download links, please look at `Github release page <https://github.com/hill
Pre-Release 2.1.2a (WIP)
-----------------------

- Add ``async_eigen_decomp`` parameter for ACKTR and set it to ``False`` by default (remove deprecation warnings)
- added ``async_eigen_decomp`` parameter for ACKTR and set it to ``False`` by default (remove deprecation warnings)
- added methods for calling env methods/setting attributes inside a VecEnv (thanks to @bjmuld)


Release 2.1.1 (2018-10-20)
Expand Down
14 changes: 7 additions & 7 deletions stable_baselines/common/cmd_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
def make_atari_env(env_id, num_env, seed, wrapper_kwargs=None, start_index=0, allow_early_resets=True):
"""
Create a wrapped, monitored SubprocVecEnv for Atari.
:param env_id: (str) the environment ID
:param num_env: (int) the number of environment you wish to have in subprocesses
:param seed: (int) the inital seed for RNG
Expand All @@ -45,7 +45,7 @@ def _thunk():
def make_mujoco_env(env_id, seed, allow_early_resets=True):
"""
Create a wrapped, monitored gym.Env for MuJoCo.
:param env_id: (str) the environment ID
:param seed: (int) the inital seed for RNG
:param allow_early_resets: (bool) allows early reset of the environment
Expand All @@ -62,7 +62,7 @@ def make_mujoco_env(env_id, seed, allow_early_resets=True):
def make_robotics_env(env_id, seed, rank=0, allow_early_resets=True):
"""
Create a wrapped, monitored gym.Env for MuJoCo.
:param env_id: (str) the environment ID
:param seed: (int) the inital seed for RNG
:param rank: (int) the rank of the environment (for logging)
Expand All @@ -82,7 +82,7 @@ def make_robotics_env(env_id, seed, rank=0, allow_early_resets=True):
def arg_parser():
"""
Create an empty argparse.ArgumentParser.
:return: (ArgumentParser)
"""
import argparse
Expand All @@ -92,7 +92,7 @@ def arg_parser():
def atari_arg_parser():
"""
Create an argparse.ArgumentParser for run_atari.py.
:return: (ArgumentParser) parser {'--env': 'BreakoutNoFrameskip-v4', '--seed': 0, '--num-timesteps': int(1e7)}
"""
parser = arg_parser()
Expand All @@ -105,7 +105,7 @@ def atari_arg_parser():
def mujoco_arg_parser():
"""
Create an argparse.ArgumentParser for run_mujoco.py.
:return: (ArgumentParser) parser {'--env': 'Reacher-v2', '--seed': 0, '--num-timesteps': int(1e6), '--play': False}
"""
parser = arg_parser()
Expand All @@ -119,7 +119,7 @@ def mujoco_arg_parser():
def robotics_arg_parser():
"""
Create an argparse.ArgumentParser for run_mujoco.py.
:return: (ArgumentParser) parser {'--env': 'FetchReach-v0', '--seed': 0, '--num-timesteps': int(1e6)}
"""
parser = arg_parser()
Expand Down
35 changes: 35 additions & 0 deletions stable_baselines/common/vec_env/dummy_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,38 @@ def _obs_from_buf(self):
return self.buf_obs[None]
else:
return self.buf_obs

def env_method(self, method_name, *method_args, **method_kwargs):
"""
Provides an interface to call arbitrary class methods of vectorized environments
:param method_name: (str) The name of the env class method to invoke
:param method_args: (tuple) Any positional arguments to provide in the call
:param method_kwargs: (dict) Any keyword arguments to provide in the call
:return: (list) List of items retured by the environment's method call
"""
return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in self.envs]

def get_attr(self, attr_name):
"""
Provides a mechanism for getting class attribues from vectorized environments
:param attr_name: (str) The name of the attribute whose value to return
:return: (list) List of values of 'attr_name' in all environments
"""
return [getattr(env_i, attr_name) for env_i in self.envs]

def set_attr(self, attr_name, value, indices=None):
"""
Provides a mechanism for setting arbitrary class attributes inside vectorized environments
:param attr_name: (str) Name of attribute to assign new value
:param value: (obj) Value to assign to 'attr_name'
:param indices: (list,int) Indices of envs to assign value
:return: (list) in case env access methods might return something, they will be returned in a list
"""
if indices is None:
indices = range(len(self.envs))
elif isinstance(indices, int):
indices = [indices]
return [setattr(env_i, attr_name, value) for env_i in [self.envs[i] for i in indices]]
54 changes: 54 additions & 0 deletions stable_baselines/common/vec_env/subproc_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ def _worker(remote, parent_remote, env_fn_wrapper):
break
elif cmd == 'get_spaces':
remote.send((env.observation_space, env.action_space))
elif cmd == 'env_method':
method = getattr(env, data[0])
remote.send(method(*data[1], **data[2]))
elif cmd == 'get_attr':
remote.send(getattr(env, data))
elif cmd == 'set_attr':
remote.send(setattr(env, data[0], data[1]))
else:
raise NotImplementedError
except EOFError:
Expand Down Expand Up @@ -107,3 +114,50 @@ def get_images(self):
pipe.send(('render', {"mode": 'rgb_array'}))
imgs = [pipe.recv() for pipe in self.remotes]
return imgs

def env_method(self, method_name, *method_args, **method_kwargs):
"""
Provides an interface to call arbitrary class methods of vectorized environments
:param method_name: (str) The name of the env class method to invoke
:param method_args: (tuple) Any positional arguments to provide in the call
:param method_kwargs: (dict) Any keyword arguments to provide in the call
:return: (list) List of items retured by each environment's method call
"""

for remote in self.remotes:
remote.send(('env_method', (method_name, method_args, method_kwargs)))
return [remote.recv() for remote in self.remotes]

def get_attr(self, attr_name):
"""
Provides a mechanism for getting class attribues from vectorized environments
(note: attribute value returned must be picklable)
:param attr_name: (str) The name of the attribute whose value to return
:return: (list) List of values of 'attr_name' in all environments
"""

for remote in self.remotes:
remote.send(('get_attr', attr_name))
return [remote.recv() for remote in self.remotes]

def set_attr(self, attr_name, value, indices=None):
"""
Provides a mechanism for setting arbitrary class attributes inside vectorized environments
(note: this is a broadcast of a single value to all instances)
(note: the value must be picklable)
:param attr_name: (str) Name of attribute to assign new value
:param value: (obj) Value to assign to 'attr_name'
:param indices: (list,tuple) Iterable containing indices of envs whose attr to set
:return: (list) in case env access methods might return something, they will be returned in a list
"""

if indices is None:
indices = range(len(self.remotes))
elif isinstance(indices, int):
indices = [indices]
for remote in [self.remotes[i] for i in indices]:
remote.send(('set_attr', (attr_name, value)))
return [remote.recv() for remote in [self.remotes[i] for i in indices]]
15 changes: 11 additions & 4 deletions stable_baselines/ppo2/run_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,29 @@
from stable_baselines.common.policies import CnnPolicy, CnnLstmPolicy, CnnLnLstmPolicy, MlpPolicy


def train(env_id, num_timesteps, seed, policy):
def train(env_id, num_timesteps, seed, policy,
n_envs=8, nminibatches=4, n_steps=128):
"""
Train PPO2 model for atari environment, for testing purposes
:param env_id: (str) the environment id string
:param num_timesteps: (int) the number of timesteps to run
:param seed: (int) Used to seed the random generator.
:param policy: (Object) The policy model to use (MLP, CNN, LSTM, ...)
:param n_envs: (int) Number of parallel environments
:param nminibatches: (int) Number of training minibatches per update. For recurrent policies,
the number of environments run in parallel should be a multiple of nminibatches.
:param n_steps: (int) The number of steps to run for each environment per update
(i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
"""

env = VecFrameStack(make_atari_env(env_id, 8, seed), 4)
env = VecFrameStack(make_atari_env(env_id, n_envs, seed), 4)
policy = {'cnn': CnnPolicy, 'lstm': CnnLstmPolicy, 'lnlstm': CnnLnLstmPolicy, 'mlp': MlpPolicy}[policy]
model = PPO2(policy=policy, env=env, n_steps=128, nminibatches=4, lam=0.95, gamma=0.99, noptepochs=4, ent_coef=.01,
model = PPO2(policy=policy, env=env, n_steps=n_steps, nminibatches=nminibatches,
lam=0.95, gamma=0.99, noptepochs=4, ent_coef=.01,
learning_rate=lambda f: f * 2.5e-4, cliprange=lambda f: f * 0.1, verbose=1)
model.learn(total_timesteps=num_timesteps)

del model

def main():
"""
Expand Down
4 changes: 3 additions & 1 deletion tests/test_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ def test_ppo2(policy):
:param policy: (str) the policy to test for PPO2
"""
ppo2_atari.train(env_id=ENV_ID, num_timesteps=NUM_TIMESTEPS, seed=SEED, policy=policy)
ppo2_atari.train(env_id=ENV_ID, num_timesteps=NUM_TIMESTEPS,
seed=SEED, policy=policy, n_envs=NUM_CPU,
nminibatches=NUM_CPU, n_steps=16)


@pytest.mark.slow
Expand Down
90 changes: 90 additions & 0 deletions tests/test_vec_envs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import pytest
import gym
import numpy as np

from stable_baselines.common.vec_env import DummyVecEnv, SubprocVecEnv

N_ENVS = 3


class CustomGymEnv(gym.Env):
def __init__(self):
"""
Custom gym environment for testing purposes
"""
self.action_space = gym.spaces.Discrete(2)
self.observation_space = self.action_space
self.current_step = 0
self.ep_length = 4

def reset(self):
self.current_step = 0
self._choose_next_state()
return self.state

def step(self, action):
reward = self._get_reward(action)
self._choose_next_state()
self.current_step += 1
done = self.current_step >= self.ep_length
return self.state, reward, done, {}

def _choose_next_state(self):
self.state = self.action_space.sample()

def render(self, mode='human'):
pass

@staticmethod
def custom_method(dim_0=1, dim_1=1):
"""
Dummy method to test call to custom method
from VecEnv
:param dim_0: (int)
:param dim_1: (int)
:return: (np.ndarray)
"""
return np.ones((dim_0, dim_1))


@pytest.mark.parametrize("vec_env_class", [DummyVecEnv, SubprocVecEnv])
def test_vecenv_custom_calls(vec_env_class):
"""Test access to methods/attributes of vectorized environments"""
vec_env = vec_env_class([CustomGymEnv for _ in range(N_ENVS)])
env_method_results = vec_env.env_method('custom_method', 1, dim_1=2)
setattr_results = []
# Set current_step to an arbitrary value
for env_idx in range(N_ENVS):
setattr_results.append(vec_env.set_attr('current_step', env_idx, indices=env_idx))
# Retrieve the value for each environment
getattr_results = vec_env.get_attr('current_step')

assert len(env_method_results) == N_ENVS
assert len(setattr_results) == N_ENVS
assert len(getattr_results) == N_ENVS

for env_idx in range(N_ENVS):
assert (env_method_results[env_idx] == np.ones((1, 2))).all()
assert setattr_results[env_idx][0] is None
assert getattr_results[env_idx] == env_idx

# Test to change value for all the environments
setattr_result = vec_env.set_attr('current_step', 42, indices=None)
getattr_result = vec_env.get_attr('current_step')
assert setattr_result == [None for _ in range(N_ENVS)]
assert getattr_result == [42 for _ in range(N_ENVS)]

# Additional tests for setattr that does not affect all the environments
vec_env.reset()
setattr_result = vec_env.set_attr('current_step', 12, indices=[0, 1])
getattr_result = vec_env.get_attr('current_step')
assert setattr_result == [None for _ in range(2)]
assert getattr_result == [12 for _ in range(2)] + [0 for _ in range(N_ENVS - 2)]

vec_env.reset()
# Change value only for first and last environment
setattr_result = vec_env.set_attr('current_step', 12, indices=[0, -1])
getattr_result = vec_env.get_attr('current_step')
assert setattr_result == [None for _ in range(2)]
assert getattr_result == [12] + [0 for _ in range(N_ENVS - 2)] + [12]

0 comments on commit 7c95b74

Please sign in to comment.