Standardize the use of `from gym import spaces` (#1240)
* generalize the use of `from gym import spaces`

* command line get system info

* Documentation line length for doc

* update changelog

* add space before os platform to avoid ref to other issue

* format

* get_system_info update in changelog

* fix type check error

* fix get system info

* add comment about regex

* update version
qgallouedec committed Jan 2, 2023
1 parent 2bb8ef5 commit 4fa17dc
Showing 34 changed files with 219 additions and 196 deletions.
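The substantive change across these files is a single import convention: rather than reaching into `gym.spaces` through the top-level module, each file imports the submodule once with `from gym import spaces` and then uses `spaces.Box`, `spaces.Dict`, and so on. A minimal sketch of the before/after pattern (the `CustomEnv` class here is illustrative only, not a file from this commit):

```python
import gym
import numpy as np
from gym import spaces  # the convention this commit standardizes


class CustomEnv(gym.Env):
    """Illustrative environment showing the standardized import style."""

    def __init__(self):
        super().__init__()
        # Before this commit these were typically written as gym.spaces.Box(...)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(14,))
        self.action_space = spaces.Box(low=-1, high=1, shape=(6,))
```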
5 changes: 2 additions & 3 deletions .github/ISSUE_TEMPLATE/bug_report.yml
@@ -52,9 +52,8 @@ body:
* Versions of any other relevant libraries
You can use `sb3.get_system_info()` to print relevant packages info:
```python
import stable_baselines3 as sb3
sb3.get_system_info()
```sh
python -c 'import stable_baselines3 as sb3; sb3.get_system_info()'
```
- type: checkboxes
id: terms
30 changes: 15 additions & 15 deletions .github/ISSUE_TEMPLATE/custom_env.yml
@@ -36,27 +36,28 @@ body:
```python
import gym
import numpy as np
from gym import spaces
from stable_baselines3 import A2C
from stable_baselines3.common.env_checker import check_env
class CustomEnv(gym.Env):
def __init__(self):
super().__init__()
self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(14,))
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(6,))
def __init__(self):
super().__init__()
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(14,))
self.action_space = spaces.Box(low=-1, high=1, shape=(6,))
def reset(self):
return self.observation_space.sample()
def reset(self):
return self.observation_space.sample()
def step(self, action):
obs = self.observation_space.sample()
reward = 1.0
done = False
info = {}
return obs, reward, done, info
def step(self, action):
obs = self.observation_space.sample()
reward = 1.0
done = False
info = {}
return obs, reward, done, info
env = CustomEnv()
check_env(env)
@@ -86,9 +87,8 @@ body:
* Versions of any other relevant libraries
You can use `sb3.get_system_info()` to print relevant packages info:
```python
import stable_baselines3 as sb3
sb3.get_system_info()
```sh
python -c 'import stable_baselines3 as sb3; sb3.get_system_info()'
```
- type: checkboxes
id: terms
3 changes: 2 additions & 1 deletion CONTRIBUTING.md
@@ -38,7 +38,8 @@ pip install -e .[docs,tests,extra]

## Codestyle

We are using [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [isort](https://github.com/timothycrosley/isort) to sort the imports.
We use [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [isort](https://github.com/timothycrosley/isort) to sort the imports.
For the documentation, we use the default line length of 88 characters per line.

**Please run `make format`** to reformat your code. You can check the codestyle using `make check-codestyle` and `make lint`.

12 changes: 9 additions & 3 deletions docs/guide/custom_env.rst
@@ -27,14 +27,17 @@ That is to say, your environment must implement the following methods (and inher
.. code-block:: python
import gym
import numpy as np
from gym import spaces
class CustomEnv(gym.Env):
"""Custom Environment that follows gym interface"""
"""Custom Environment that follows gym interface."""
metadata = {"render.modes": ["human"]}
def __init__(self, arg1, arg2, ...):
super(CustomEnv, self).__init__()
super().__init__()
# Define action and observation space
# They must be gym.spaces objects
# Example when using discrete actions:
@@ -46,12 +49,15 @@ That is to say, your environment must implement the following methods (and inher
def step(self, action):
...
return observation, reward, done, info
def reset(self):
...
return observation # reward, done, info can't be included
def render(self, mode="human"):
...
def close (self):
def close(self):
...
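For context, an environment written against this skeleton is normally validated with `check_env` and then trained like any other env. A short usage sketch, assuming an argument-free `CustomEnv` such as the one in the issue template above (the timestep count is arbitrary):

```python
from stable_baselines3 import A2C
from stable_baselines3.common.env_checker import check_env

env = CustomEnv()  # argument-free variant, as in the issue template above
check_env(env)     # raises a descriptive error if the gym interface is violated

model = A2C("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=1_000)
```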
12 changes: 6 additions & 6 deletions docs/guide/custom_policy.rst
@@ -125,9 +125,9 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t

.. code-block:: python
import gym
import torch as th
import torch.nn as nn
from gym import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
@@ -140,7 +140,7 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t
This corresponds to the number of units for the last layer.
"""
def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
def __init__(self, observation_space: spaces.Box, features_dim: int = 256):
super().__init__(observation_space, features_dim)
# We assume CxHxW images (channels first)
# Re-ordering will be done by preprocessing or wrapper
@@ -199,7 +199,7 @@ downsampling and "vector" with a single linear layer.
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
class CustomCombinedExtractor(BaseFeaturesExtractor):
def __init__(self, observation_space: gym.spaces.Dict):
def __init__(self, observation_space: spaces.Dict):
# We do not know features-dim here before going over all the items,
# so put something dummy for now. PyTorch requires calling
# nn.Module.__init__ before adding modules
@@ -310,7 +310,7 @@ If your task requires even more granular control over the policy/value architect
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
import gym
from gym import spaces
import torch as th
from torch import nn
@@ -367,8 +367,8 @@ If your task requires even more granular control over the policy/value architect
class CustomActorCriticPolicy(ActorCriticPolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Callable[[float], float],
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
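The surrounding docs describe passing the custom extractor "to the model when training"; in SB3 that wiring goes through `policy_kwargs`. A hedged sketch, assuming the `CustomCNN` defined above and an image-observation env already bound to `env` (the `features_dim` value is arbitrary):

```python
from stable_baselines3 import PPO

policy_kwargs = dict(
    features_extractor_class=CustomCNN,
    features_extractor_kwargs=dict(features_dim=128),
)
# env: any image-observation gym environment (e.g. an Atari env)
model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
model.learn(total_timesteps=1_000)
```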
4 changes: 3 additions & 1 deletion docs/misc/changelog.rst
@@ -4,7 +4,7 @@ Changelog
==========


Release 1.7.0a10 (WIP)
Release 1.7.0a11 (WIP)
--------------------------

.. note::
@@ -71,6 +71,8 @@ Others:
- Upgraded GitHub CI/setup-python to v4 and checkout to v3
- Set tensors construction directly on the device (~8% speed boost on GPU)
- Monkey-patched ``np.bool = bool`` so gym 0.21 is compatible with NumPy 1.24+
- Standardized the use of ``from gym import spaces``
- Modified ``get_system_info`` to avoid issues when copy-pasting its output into a GitHub issue

Documentation:
^^^^^^^^^^^^^^
15 changes: 8 additions & 7 deletions stable_baselines3/common/base_class.py
@@ -11,6 +11,7 @@
import gym
import numpy as np
import torch as th
from gym import spaces

from stable_baselines3.common import utils
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback
@@ -101,7 +102,7 @@ def __init__(
seed: Optional[int] = None,
use_sde: bool = False,
sde_sample_freq: int = -1,
supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,
):
if isinstance(policy, str):
self.policy_class = self._get_policy_from_name(policy)
@@ -117,8 +118,8 @@ def __init__(
self._vec_normalize_env = unwrap_vec_normalize(env)
self.verbose = verbose
self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs
self.observation_space = None # type: Optional[gym.spaces.Space]
self.action_space = None # type: Optional[gym.spaces.Space]
self.observation_space = None # type: Optional[spaces.Space]
self.action_space = None # type: Optional[spaces.Space]
self.n_envs = None
self.num_timesteps = 0
# Used for updating schedules
@@ -175,13 +176,13 @@ def __init__(
)

# Catch common mistake: using MlpPolicy/CnnPolicy instead of MultiInputPolicy
if policy in ["MlpPolicy", "CnnPolicy"] and isinstance(self.observation_space, gym.spaces.Dict):
if policy in ["MlpPolicy", "CnnPolicy"] and isinstance(self.observation_space, spaces.Dict):
raise ValueError(f"You must use `MultiInputPolicy` when working with dict observation space, not {policy}")

if self.use_sde and not isinstance(self.action_space, gym.spaces.Box):
if self.use_sde and not isinstance(self.action_space, spaces.Box):
raise ValueError("generalized State-Dependent Exploration (gSDE) can only be used with continuous actions.")

if isinstance(self.action_space, gym.spaces.Box):
if isinstance(self.action_space, spaces.Box):
assert np.all(
np.isfinite(np.array([self.action_space.low, self.action_space.high]))
), "Continuous action space must have a finite lower and upper bound"
@@ -212,7 +213,7 @@ def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> Ve

if not is_vecenv_wrapped(env, VecTransposeImage):
wrap_with_vectranspose = False
if isinstance(env.observation_space, gym.spaces.Dict):
if isinstance(env.observation_space, spaces.Dict):
# If even one of the keys is an image space in need of transpose, apply transpose
# If the image spaces are not consistent (for instance one is channel first,
# the other channel last), VecTransposeImage will throw an error
3 changes: 1 addition & 2 deletions stable_baselines3/common/distributions.py
@@ -3,7 +3,6 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union

import gym
import numpy as np
import torch as th
from gym import spaces
@@ -659,7 +658,7 @@ def log_prob_correction(self, x: th.Tensor) -> th.Tensor:


def make_proba_distribution(
action_space: gym.spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
action_space: spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
) -> Distribution:
"""
Return an instance of Distribution for the correct type of action space
2 changes: 1 addition & 1 deletion stable_baselines3/common/env_checker.py
@@ -266,7 +266,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action

def _check_spaces(env: gym.Env) -> None:
"""
Check that the observation and action spaces are defined and inherit from gym.spaces.Space. For
Check that the observation and action spaces are defined and inherit from spaces.Space. For
envs that follow the goal-conditioned standard (previously, the gym.GoalEnv interface) we check
the observation space is gym.spaces.Dict
"""
40 changes: 21 additions & 19 deletions stable_baselines3/common/envs/identity_env.py
@@ -1,14 +1,16 @@
from typing import Optional, Union
from typing import Any, Dict, Generic, Optional, Tuple, TypeVar, Union

import gym
import numpy as np
from gym import Env, Space
from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete
from gym import spaces

from stable_baselines3.common.type_aliases import GymObs, GymStepReturn

T = TypeVar("T", int, np.ndarray)

class IdentityEnv(Env):
def __init__(self, dim: Optional[int] = None, space: Optional[Space] = None, ep_length: int = 100):

class IdentityEnv(gym.Env, Generic[T]):
def __init__(self, dim: Optional[int] = None, space: Optional[spaces.Space] = None, ep_length: int = 100):
"""
Identity environment for testing purposes
@@ -22,7 +24,7 @@ def __init__(self, dim: Optional[int] = None, space: Optional[Space] = None, ep_
if space is None:
if dim is None:
dim = 1
space = Discrete(dim)
space = spaces.Discrete(dim)
else:
assert dim is None, "arguments for both 'dim' and 'space' provided: at most one allowed"

@@ -32,13 +34,13 @@ def __init__(self, dim: Optional[int] = None, space: Optional[Space] = None, ep_
self.num_resets = -1 # Becomes 0 after __init__ exits.
self.reset()

def reset(self) -> GymObs:
def reset(self) -> T:
self.current_step = 0
self.num_resets += 1
self._choose_next_state()
return self.state

def step(self, action: Union[int, np.ndarray]) -> GymStepReturn:
def step(self, action: T) -> Tuple[T, float, bool, Dict[str, Any]]:
reward = self._get_reward(action)
self._choose_next_state()
self.current_step += 1
@@ -48,14 +50,14 @@ def step(self, action: Union[int, np.ndarray]) -> GymStepReturn:
def _choose_next_state(self) -> None:
self.state = self.action_space.sample()

def _get_reward(self, action: Union[int, np.ndarray]) -> float:
def _get_reward(self, action: T) -> float:
return 1.0 if np.all(self.state == action) else 0.0

def render(self, mode: str = "human") -> None:
pass


class IdentityEnvBox(IdentityEnv):
class IdentityEnvBox(IdentityEnv[np.ndarray]):
def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = 0.05, ep_length: int = 100):
"""
Identity environment for testing purposes
@@ -65,7 +67,7 @@ def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = 0.05, ep_l
:param eps: the epsilon bound for correct value
:param ep_length: the length of each episode in timesteps
"""
space = Box(low=low, high=high, shape=(1,), dtype=np.float32)
space = spaces.Box(low=low, high=high, shape=(1,), dtype=np.float32)
super().__init__(ep_length=ep_length, space=space)
self.eps = eps

@@ -80,31 +82,31 @@ def _get_reward(self, action: np.ndarray) -> float:
return 1.0 if (self.state - self.eps) <= action <= (self.state + self.eps) else 0.0


class IdentityEnvMultiDiscrete(IdentityEnv):
class IdentityEnvMultiDiscrete(IdentityEnv[np.ndarray]):
def __init__(self, dim: int = 1, ep_length: int = 100):
"""
Identity environment for testing purposes
:param dim: the size of the dimensions you want to learn
:param ep_length: the length of each episode in timesteps
"""
space = MultiDiscrete([dim, dim])
space = spaces.MultiDiscrete([dim, dim])
super().__init__(ep_length=ep_length, space=space)


class IdentityEnvMultiBinary(IdentityEnv):
class IdentityEnvMultiBinary(IdentityEnv[np.ndarray]):
def __init__(self, dim: int = 1, ep_length: int = 100):
"""
Identity environment for testing purposes
:param dim: the size of the dimensions you want to learn
:param ep_length: the length of each episode in timesteps
"""
space = MultiBinary(dim)
space = spaces.MultiBinary(dim)
super().__init__(ep_length=ep_length, space=space)


class FakeImageEnv(Env):
class FakeImageEnv(gym.Env):
"""
Fake image environment for testing purposes, it mimics Atari games.
@@ -128,11 +130,11 @@ def __init__(
self.observation_shape = (screen_height, screen_width, n_channels)
if channel_first:
self.observation_shape = (n_channels, screen_height, screen_width)
self.observation_space = Box(low=0, high=255, shape=self.observation_shape, dtype=np.uint8)
self.observation_space = spaces.Box(low=0, high=255, shape=self.observation_shape, dtype=np.uint8)
if discrete:
self.action_space = Discrete(action_dim)
self.action_space = spaces.Discrete(action_dim)
else:
self.action_space = Box(low=-1, high=1, shape=(5,), dtype=np.float32)
self.action_space = spaces.Box(low=-1, high=1, shape=(5,), dtype=np.float32)
self.ep_length = 10
self.current_step = 0

