Skip to content

Commit

Permalink
Add 'terminal_observation' info key to VecEnv objects (#412)
Browse files Browse the repository at this point in the history
* Add 'terminal_observation' info key to vecenvs

Fixes #400

* Review changes to terminal_observation fix

* Fix typo
  • Loading branch information
qxcv authored and AdamGleave committed Jul 22, 2019
1 parent 5cbdd31 commit edfc767
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 3 deletions.
2 changes: 2 additions & 0 deletions docs/guide/vec_envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ SubprocVecEnv ✔️ ✔️ ✔️ ✔️ ✔️
.. note::

When using vectorized environments, the environments are automatically reset at the end of each episode.
Thus, the observation returned for the i-th environment when ``done[i]`` is true will in fact be the first observation of the next episode, not the last observation of the episode that has just terminated.
You can access the "real" final observation of the terminated episode—that is, the one that accompanied the ``done`` event provided by the underlying environment—using the ``terminal_observation`` keys in the info dicts returned by the vecenv.

.. warning::

Expand Down
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ New Features:
policy in addition to the existing support for categorical stochastic policies.
- Add flag to `action_probability` to return log-probabilities.
- Added support for python lists and numpy arrays in ``logger.writekvs``. (@dwiel)
- The info dicts returned by VecEnvs now include a ``terminal_observation`` key providing access to the last observation in a trajectory. (@qxcv)

Bug Fixes:
^^^^^^^^^^
Expand Down Expand Up @@ -408,4 +409,4 @@ In random order...
Thanks to @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk @JohannesAck
@EliasHasle @mrakgr @Bleyddyn @antoine-galataud @junhyeokahn @AdamGleave @keshaviyengar @tperol
@XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs
@Miffyli @dwiel @miguelrass
@Miffyli @dwiel @miguelrass @qxcv
2 changes: 2 additions & 0 deletions stable_baselines/common/vec_env/dummy_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def step_wait(self):
obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] =\
self.envs[env_idx].step(self.actions[env_idx])
if self.buf_dones[env_idx]:
# save final observation where user can get it, then reset
self.buf_infos[env_idx]['terminal_observation'] = obs
obs = self.envs[env_idx].reset()
self._save_obs(env_idx, obs)
return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones),
Expand Down
2 changes: 2 additions & 0 deletions stable_baselines/common/vec_env/subproc_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ def _worker(remote, parent_remote, env_fn_wrapper):
if cmd == 'step':
observation, reward, done, info = env.step(data)
if done:
# save final observation where user can get it, then reset
info['terminal_observation'] = observation
observation = env.reset()
remote.send((observation, reward, done, info))
elif cmd == 'reset':
Expand Down
15 changes: 13 additions & 2 deletions stable_baselines/common/vec_env/vec_frame_stack.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

import numpy as np
from gym import spaces

Expand All @@ -11,7 +13,7 @@ class VecFrameStack(VecEnvWrapper):
:param venv: (VecEnv) the vectorized environment to wrap
:param n_stack: (int) Number of frames to stack
"""

def __init__(self, venv, n_stack):
self.venv = venv
self.n_stack = n_stack
Expand All @@ -24,9 +26,18 @@ def __init__(self, venv, n_stack):

def step_wait(self):
observations, rewards, dones, infos = self.venv.step_wait()
self.stackedobs = np.roll(self.stackedobs, shift=-observations.shape[-1], axis=-1)
last_ax_size = observations.shape[-1]
self.stackedobs = np.roll(self.stackedobs, shift=-last_ax_size, axis=-1)
for i, done in enumerate(dones):
if done:
if 'terminal_observation' in infos[i]:
old_terminal = infos[i]['terminal_observation']
new_terminal = np.concatenate(
(self.stackedobs[i, ..., :-last_ax_size], old_terminal), axis=-1)
infos[i]['terminal_observation'] = new_terminal
else:
warnings.warn(
"VecFrameStack wrapping a VecEnv without terminal_observation info")
self.stackedobs[i] = 0
self.stackedobs[..., -observations.shape[-1]:] = observations
return self.stackedobs, rewards, dones, infos
Expand Down
66 changes: 66 additions & 0 deletions tests/test_vec_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,72 @@ def make_env():
vec_env.close()


class StepEnv(gym.Env):
def __init__(self, max_steps):
"""Gym environment for testing that terminal observation is inserted
correctly."""
self.action_space = gym.spaces.Discrete(2)
self.observation_space = gym.spaces.Box(np.array([0]), np.array([999]),
dtype='int')
self.max_steps = max_steps
self.current_step = 0

def reset(self):
self.current_step = 0
return np.array([self.current_step], dtype='int')

def step(self, action):
prev_step = self.current_step
self.current_step += 1
done = self.current_step >= self.max_steps
return np.array([prev_step], dtype='int'), 0.0, done, {}


@pytest.mark.parametrize('vec_env_class', VEC_ENV_CLASSES)
@pytest.mark.parametrize('vec_env_wrapper', VEC_ENV_WRAPPERS)
def test_vecenv_terminal_obs(vec_env_class, vec_env_wrapper):
"""Test that 'terminal_observation' gets added to info dict upon
termination."""
step_nums = [i + 5 for i in range(N_ENVS)]
vec_env = vec_env_class([functools.partial(StepEnv, n) for n in step_nums])

if vec_env_wrapper is not None:
if vec_env_wrapper == VecFrameStack:
vec_env = vec_env_wrapper(vec_env, n_stack=2)
else:
vec_env = vec_env_wrapper(vec_env)

zero_acts = np.zeros((N_ENVS,), dtype='int')
prev_obs_b = vec_env.reset()
for step_num in range(1, max(step_nums) + 1):
obs_b, _, done_b, info_b = vec_env.step(zero_acts)
assert len(obs_b) == N_ENVS
assert len(done_b) == N_ENVS
assert len(info_b) == N_ENVS
env_iter = zip(prev_obs_b, obs_b, done_b, info_b, step_nums)
for prev_obs, obs, done, info, final_step_num in env_iter:
assert done == (step_num == final_step_num)
if not done:
assert 'terminal_observation' not in info
else:
terminal_obs = info['terminal_observation']

# do some rough ordering checks that should work for all
# wrappers, including VecNormalize
assert np.all(prev_obs < terminal_obs)
assert np.all(obs < prev_obs)

if not isinstance(vec_env, VecNormalize):
# more precise tests that we can't do with VecNormalize
# (which changes observation values)
assert np.all(prev_obs + 1 == terminal_obs)
assert np.all(obs == 0)

prev_obs_b = obs_b

vec_env.close()


SPACES = collections.OrderedDict([
('discrete', gym.spaces.Discrete(2)),
('multidiscrete', gym.spaces.MultiDiscrete([2, 3])),
Expand Down

0 comments on commit edfc767

Please sign in to comment.