-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for multidimensional
spaces.MultiBinary
observations (#…
…1179) * Fix `get_obs_shape` for multidimensi onnal Multibinary space * Update changelog * more tests * fix multidiscrete one-hot encoding * refactor tests * Update changelog.rst * Update changelog.rst * batched obs and revert preprocess_obs changes * Add support for multidimensional ``spaces.MultiBinary`` observations Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
- Loading branch information
1 parent
6763a86
commit e39bc3d
Showing
5 changed files
with
77 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import torch | ||
from gym import spaces | ||
|
||
from stable_baselines3.common.preprocessing import get_obs_shape, preprocess_obs | ||
|
||
|
||
def test_get_obs_shape_discrete(): | ||
assert get_obs_shape(spaces.Discrete(3)) == (1,) | ||
|
||
|
||
def test_get_obs_shape_multidiscrete(): | ||
assert get_obs_shape(spaces.MultiDiscrete([3, 2])) == (2,) | ||
|
||
|
||
def test_get_obs_shape_multibinary(): | ||
assert get_obs_shape(spaces.MultiBinary(3)) == (3,) | ||
|
||
|
||
def test_get_obs_shape_multidimensional_multibinary(): | ||
assert get_obs_shape(spaces.MultiBinary([3, 2])) == (3, 2) | ||
|
||
|
||
def test_get_obs_shape_box(): | ||
assert get_obs_shape(spaces.Box(-2, 2, shape=(3,))) == (3,) | ||
|
||
|
||
def test_get_obs_shape_multidimensional_box(): | ||
assert get_obs_shape(spaces.Box(-2, 2, shape=(3, 2))) == (3, 2) | ||
|
||
|
||
def test_preprocess_obs_discrete(): | ||
actual = preprocess_obs(torch.tensor([2], dtype=torch.long), spaces.Discrete(3)) | ||
expected = torch.tensor([[0.0, 0.0, 1.0]], dtype=torch.float32) | ||
torch.testing.assert_close(actual, expected) | ||
|
||
|
||
def test_preprocess_obs_multidiscrete(): | ||
actual = preprocess_obs(torch.tensor([[2, 0]], dtype=torch.long), spaces.MultiDiscrete([3, 2])) | ||
expected = torch.tensor([[0.0, 0.0, 1.0, 1.0, 0.0]], dtype=torch.float32) | ||
torch.testing.assert_close(actual, expected) | ||
|
||
|
||
def test_preprocess_obs_multibinary(): | ||
actual = preprocess_obs(torch.tensor([[1, 0, 1]], dtype=torch.long), spaces.MultiBinary(3)) | ||
expected = torch.tensor([[1.0, 0.0, 1.0]], dtype=torch.float32) | ||
torch.testing.assert_close(actual, expected) | ||
|
||
|
||
def test_preprocess_obs_multidimensional_multibinary(): | ||
actual = preprocess_obs(torch.tensor([[[1, 0], [1, 1], [0, 1]]], dtype=torch.long), spaces.MultiBinary([3, 2])) | ||
expected = torch.tensor([[[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]], dtype=torch.float32) | ||
torch.testing.assert_close(actual, expected) | ||
|
||
|
||
def test_preprocess_obs_box(): | ||
actual = preprocess_obs(torch.tensor([[1.5, 0.3, -1.8]], dtype=torch.float32), spaces.Box(-2, 2, shape=(3,))) | ||
expected = torch.tensor([[1.5, 0.3, -1.8]], dtype=torch.float32) | ||
torch.testing.assert_close(actual, expected) | ||
|
||
|
||
def test_preprocess_obs_multidimensional_box(): | ||
actual = preprocess_obs( | ||
torch.tensor([[[1.5, 0.3, -1.8], [0.1, -0.6, -1.4]]], dtype=torch.float32), spaces.Box(-2, 2, shape=(3, 2)) | ||
) | ||
expected = torch.tensor([[[1.5, 0.3, -1.8], [0.1, -0.6, -1.4]]], dtype=torch.float32) | ||
torch.testing.assert_close(actual, expected) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters