Skip to content

Commit

Permalink
Fix: Reshape action in DictRolloutBuffer (#1395)
Browse files Browse the repository at this point in the history
* reshape action in DictRolloutBuffer

* improve buffer test

* update changelog

* add comment

* Update comments and version

---------

Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
  • Loading branch information
younik and araffin committed Mar 29, 2023
1 parent b6aa507 commit a60b017
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 6 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Changelog
==========


Release 1.8.0a11 (WIP)
Release 1.8.0a12 (WIP)
--------------------------

.. warning::
Expand Down Expand Up @@ -46,6 +46,7 @@ Bug Fixes:
- Added the argument ``dtype`` (default to ``float32``) to the noise for consistency with gym action (@sidney-tio)
- Fixed PPO train/n_updates metric not accounting for early stopping (@adamfrly)
- Fixed loading of normalized image-based environments
- Fixed `DictRolloutBuffer.add` with multidimensional action space (@younik)

Deprecations:
^^^^^^^^^^^^^
Expand Down
9 changes: 6 additions & 3 deletions stable_baselines3/common/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def add(
obs = obs.reshape((self.n_envs, *self.obs_shape))
next_obs = next_obs.reshape((self.n_envs, *self.obs_shape))

# Same, for actions
# Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
action = action.reshape((self.n_envs, self.action_dim))

# Copy to avoid modification by reference
Expand Down Expand Up @@ -430,7 +430,7 @@ def add(
if isinstance(self.observation_space, spaces.Discrete):
obs = obs.reshape((self.n_envs, *self.obs_shape))

# Same reshape, for actions
# Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
action = action.reshape((self.n_envs, self.action_dim))

self.observations[self.pos] = np.array(obs).copy()
Expand Down Expand Up @@ -588,7 +588,7 @@ def add(
next_obs[key] = next_obs[key].reshape((self.n_envs,) + self.obs_shape[key])
self.next_observations[key][self.pos] = np.array(next_obs[key]).copy()

# Same reshape, for actions
# Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
action = action.reshape((self.n_envs, self.action_dim))

self.actions[self.pos] = np.array(action).copy()
Expand Down Expand Up @@ -741,6 +741,9 @@ def add(
obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key])
self.observations[key][self.pos] = obs_

# Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
action = action.reshape((self.n_envs, self.action_dim))

self.actions[self.pos] = np.array(action).copy()
self.rewards[self.pos] = np.array(reward).copy()
self.episode_starts[self.pos] = np.array(episode_start).copy()
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.8.0a11
1.8.0a12
3 changes: 2 additions & 1 deletion tests/test_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ class DummyDictEnv(gym.Env):
"""

def __init__(self):
self.action_space = spaces.Box(1, 5, (1,))
# Test for multi-dim action space
self.action_space = spaces.Box(1, 5, shape=(10, 7))
space = spaces.Box(1, 5, (1,))
self.observation_space = spaces.Dict({"observation": space, "achieved_goal": space, "desired_goal": space})
self._observations = [1, 2, 3, 4, 5]
Expand Down

0 comments on commit a60b017

Please sign in to comment.