Skip to content

Commit

Permalink
Add support for multidimensional spaces.MultiBinary observations (#…
Browse files Browse the repository at this point in the history
…1179)

* Fix `get_obs_shape` for multidimensi onnal Multibinary space

* Update changelog

* more tests

* fix multidiscrete one-hot encoding

* refactor tests

* Update changelog.rst

* Update changelog.rst

* batched obs and revert preprocess_obs changes

* Add support for multidimensional ``spaces.MultiBinary`` observations

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
  • Loading branch information
3 people committed Dec 8, 2022
1 parent 6763a86 commit e39bc3d
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 7 deletions.
5 changes: 3 additions & 2 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ New Features:
^^^^^^^^^^^^^
- Introduced mypy type checking
- Added ``with_bias`` argument to ``create_mlp``
- Added support for multidimensional ``spaces.MultiBinary`` observations

SB3-Contrib
^^^^^^^^^^^

Bug Fixes:
^^^^^^^^^^
- Fix return type of ``evaluate_actions`` in ``ActorCritcPolicy`` to reflect that entropy is an optional tensor (@Rocamonde)
- Fix type annotation of ``policy`` in ``BaseAlgorithm`` and ``OffPolicyAlgorithm``
- Fixed return type of ``evaluate_actions`` in ``ActorCritcPolicy`` to reflect that entropy is an optional tensor (@Rocamonde)
- Fixed type annotation of ``policy`` in ``BaseAlgorithm`` and ``OffPolicyAlgorithm``
- Allowed model trained with Python 3.7 to be loaded with Python 3.8+ without the ``custom_objects`` workaround
- Raise an error when the same gym environment instance is passed as separate environments when creating a vectorized environment with more than one environment. (@Rocamonde)
- Fix type annotation of ``model`` in ``evaluate_policy``
Expand Down
5 changes: 4 additions & 1 deletion stable_baselines3/common/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,10 @@ def get_obs_shape(
return (int(len(observation_space.nvec)),)
elif isinstance(observation_space, spaces.MultiBinary):
# Number of binary features
return (int(observation_space.n),)
if type(observation_space.n) in [tuple, list, np.ndarray]:
return tuple(observation_space.n)
else:
return (int(observation_space.n),)
elif isinstance(observation_space, spaces.Dict):
return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()}

Expand Down
6 changes: 3 additions & 3 deletions stable_baselines3/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,14 +299,14 @@ def is_vectorized_multibinary_observation(observation: np.ndarray, observation_s
:param observation_space: the observation space
:return: whether the given observation is vectorized or not
"""
if observation.shape == (observation_space.n,):
if observation.shape == observation_space.shape:
return False
elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n:
elif len(observation.shape) == len(observation_space.shape) + 1 and observation.shape[1:] == observation_space.shape:
return True
else:
raise ValueError(
f"Error: Unexpected observation shape {observation.shape} for MultiBinary "
+ f"environment, please use ({observation_space.n},) or "
+ f"environment, please use {observation_space.shape} or "
+ f"(n_env, {observation_space.n}) for the observation shape."
)

Expand Down
66 changes: 66 additions & 0 deletions tests/test_preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import torch
from gym import spaces

from stable_baselines3.common.preprocessing import get_obs_shape, preprocess_obs


def test_get_obs_shape_discrete():
assert get_obs_shape(spaces.Discrete(3)) == (1,)


def test_get_obs_shape_multidiscrete():
assert get_obs_shape(spaces.MultiDiscrete([3, 2])) == (2,)


def test_get_obs_shape_multibinary():
assert get_obs_shape(spaces.MultiBinary(3)) == (3,)


def test_get_obs_shape_multidimensional_multibinary():
assert get_obs_shape(spaces.MultiBinary([3, 2])) == (3, 2)


def test_get_obs_shape_box():
assert get_obs_shape(spaces.Box(-2, 2, shape=(3,))) == (3,)


def test_get_obs_shape_multidimensional_box():
assert get_obs_shape(spaces.Box(-2, 2, shape=(3, 2))) == (3, 2)


def test_preprocess_obs_discrete():
actual = preprocess_obs(torch.tensor([2], dtype=torch.long), spaces.Discrete(3))
expected = torch.tensor([[0.0, 0.0, 1.0]], dtype=torch.float32)
torch.testing.assert_close(actual, expected)


def test_preprocess_obs_multidiscrete():
actual = preprocess_obs(torch.tensor([[2, 0]], dtype=torch.long), spaces.MultiDiscrete([3, 2]))
expected = torch.tensor([[0.0, 0.0, 1.0, 1.0, 0.0]], dtype=torch.float32)
torch.testing.assert_close(actual, expected)


def test_preprocess_obs_multibinary():
actual = preprocess_obs(torch.tensor([[1, 0, 1]], dtype=torch.long), spaces.MultiBinary(3))
expected = torch.tensor([[1.0, 0.0, 1.0]], dtype=torch.float32)
torch.testing.assert_close(actual, expected)


def test_preprocess_obs_multidimensional_multibinary():
actual = preprocess_obs(torch.tensor([[[1, 0], [1, 1], [0, 1]]], dtype=torch.long), spaces.MultiBinary([3, 2]))
expected = torch.tensor([[[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]], dtype=torch.float32)
torch.testing.assert_close(actual, expected)


def test_preprocess_obs_box():
actual = preprocess_obs(torch.tensor([[1.5, 0.3, -1.8]], dtype=torch.float32), spaces.Box(-2, 2, shape=(3,)))
expected = torch.tensor([[1.5, 0.3, -1.8]], dtype=torch.float32)
torch.testing.assert_close(actual, expected)


def test_preprocess_obs_multidimensional_box():
actual = preprocess_obs(
torch.tensor([[[1.5, 0.3, -1.8], [0.1, -0.6, -1.4]]], dtype=torch.float32), spaces.Box(-2, 2, shape=(3, 2))
)
expected = torch.tensor([[[1.5, 0.3, -1.8], [0.1, -0.6, -1.4]]], dtype=torch.float32)
torch.testing.assert_close(actual, expected)
2 changes: 1 addition & 1 deletion tests/test_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def step(self, action):


@pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8)])
@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8), DummyMultiBinary((3, 2))])
def test_identity_spaces(model_class, env):
"""
Additional tests for DQ/SAC/TD3 to check observation space support
Expand Down

0 comments on commit e39bc3d

Please sign in to comment.