Skip to content

Commit

Permalink
Merge branch 'master' into sde
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin committed Jun 4, 2020
2 parents 45fd302 + 353ea81 commit 0fa3733
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 19 deletions.
28 changes: 28 additions & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,34 @@
Changelog
==========


Pre-Release 0.7.0a0 (WIP)
------------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^
- ``render()`` method of ``VecEnvs`` now only accept one argument: ``mode``

New Features:
^^^^^^^^^^^^^

Bug Fixes:
^^^^^^^^^^
- Fixed ``render()`` method for ``VecEnvs``
- Fixed ``seed()``` method for ``SubprocVecEnv``

Deprecations:
^^^^^^^^^^^^^

Others:
^^^^^^^
- Re-enable unsafe ``fork`` start method in the tests (was causing a deadlock with tensorflow)
- Added a test for seeding ``SubprocVecEnv``` and rendering

Documentation:
^^^^^^^^^^^^^^


Pre-Release 0.6.0 (2020-06-01)
------------------------------

Expand Down
14 changes: 7 additions & 7 deletions stable_baselines3/common/vec_env/base_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,22 +162,22 @@ def step(self, actions):
self.step_async(actions)
return self.step_wait()

def get_images(self, *args, **kwargs) -> Sequence[np.ndarray]:
def get_images(self) -> Sequence[np.ndarray]:
"""
Return RGB images from each environment
"""
raise NotImplementedError

def render(self, *args, mode: str = 'human', **kwargs):
def render(self, mode: str = 'human'):
"""
Gym environment rendering
:param mode: the rendering type
"""
try:
imgs = self.get_images(*args, **kwargs)
imgs = self.get_images()
except NotImplementedError:
logger.warn('Render not defined for {}'.format(self))
logger.warn(f'Render not defined for {self}')
return

# Create a big image by tiling images from subprocesses
Expand All @@ -189,7 +189,7 @@ def render(self, *args, mode: str = 'human', **kwargs):
elif mode == 'rgb_array':
return bigimg
else:
raise NotImplementedError
raise NotImplementedError(f'Render mode {mode} is not supported by VecEnvs')

@abstractmethod
def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
Expand Down Expand Up @@ -268,8 +268,8 @@ def seed(self, seed=None):
def close(self):
return self.venv.close()

def render(self, *args, **kwargs):
return self.venv.render(*args, **kwargs)
def render(self, mode: str = 'human'):
return self.venv.render(mode=mode)

def get_images(self):
return self.venv.get_images()
Expand Down
10 changes: 5 additions & 5 deletions stable_baselines3/common/vec_env/dummy_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ def close(self):
for env in self.envs:
env.close()

def get_images(self, *args, **kwargs) -> Sequence[np.ndarray]:
return [env.render(*args, mode='rgb_array', **kwargs) for env in self.envs]
def get_images(self) -> Sequence[np.ndarray]:
return [env.render(mode='rgb_array') for env in self.envs]

def render(self, *args, **kwargs):
def render(self, mode: str = 'human'):
"""
Gym environment rendering. If there are multiple environments then
they are tiled together in one image via ``BaseVecEnv.render()``.
Expand All @@ -82,9 +82,9 @@ def render(self, *args, **kwargs):
:param mode: The rendering type.
"""
if self.num_envs == 1:
return self.envs[0].render(*args, **kwargs)
return self.envs[0].render(mode=mode)
else:
return super().render(*args, **kwargs)
return super().render(mode=mode)

def _save_obs(self, env_idx, obs):
for key in self.keys:
Expand Down
10 changes: 6 additions & 4 deletions stable_baselines3/common/vec_env/subproc_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ def _worker(remote, parent_remote, env_fn_wrapper):
info['terminal_observation'] = observation
observation = env.reset()
remote.send((observation, reward, done, info))
elif cmd == 'seed':
remote.send(env.seed(data))
elif cmd == 'reset':
observation = env.reset()
remote.send(observation)
elif cmd == 'render':
remote.send(env.render(*data[0], **data[1]))
remote.send(env.render(data))
elif cmd == 'close':
remote.close()
break
Expand All @@ -39,7 +41,7 @@ def _worker(remote, parent_remote, env_fn_wrapper):
elif cmd == 'set_attr':
remote.send(setattr(env, data[0], data[1]))
else:
raise NotImplementedError
raise NotImplementedError(f"`{cmd}` is not implemented in the worker")
except EOFError:
break

Expand Down Expand Up @@ -129,11 +131,11 @@ def close(self):
process.join()
self.closed = True

def get_images(self, *args, **kwargs) -> Sequence[np.ndarray]:
def get_images(self) -> Sequence[np.ndarray]:
for pipe in self.remotes:
# gather images from subprocesses
# `mode` will be taken into account later
pipe.send(('render', (args, {'mode': 'rgb_array', **kwargs})))
pipe.send(('render', 'rgb_array'))
imgs = [pipe.recv() for pipe in self.remotes]
return imgs

Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.6.0
0.7.0a0
14 changes: 12 additions & 2 deletions tests/test_vec_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ def _choose_next_state(self):
self.state = self.observation_space.sample()

def render(self, mode='human'):
if mode == 'rgb_array':
return np.zeros((4, 4, 3))

def seed(self, seed=None):
pass

@staticmethod
Expand Down Expand Up @@ -71,6 +75,11 @@ def make_env():
else:
vec_env = vec_env_wrapper(vec_env)

# Test seed method
vec_env.seed(0)
# Test render method call
# vec_env.render() # we need a X server to test the "human" mode
vec_env.render(mode='rgb_array')
env_method_results = vec_env.env_method('custom_method', 1, indices=None, dim_1=2)
setattr_results = []
# Set current_step to an arbitrary value
Expand Down Expand Up @@ -271,9 +280,10 @@ def obs_assert(obs):
def test_subproc_start_method():
start_methods = [None]
# Only test thread-safe methods. Others may deadlock tests! (gh/428)
safe_methods = {'forkserver', 'spawn'}
# Note: adding unsafe `fork` method as we are now using PyTorch
all_methods = {'forkserver', 'spawn', 'fork'}
available_methods = multiprocessing.get_all_start_methods()
start_methods += list(safe_methods.intersection(available_methods))
start_methods += list(all_methods.intersection(available_methods))
space = gym.spaces.Discrete(2)

def obs_assert(obs):
Expand Down

0 comments on commit 0fa3733

Please sign in to comment.