Skip to content

Commit

Permalink
Gym Env Checker (#615)
Browse files Browse the repository at this point in the history
* Add Gym Env checker

* Test common failures

* Declare param as unused

* Update tests/test_envs.py

Co-Authored-By: Adam Gleave <adam@gleave.me>

* Update docs/guide/rl_tips.rst

Co-Authored-By: Adam Gleave <adam@gleave.me>

* Update stable_baselines/common/env_checker.py

Co-Authored-By: Adam Gleave <adam@gleave.me>

* Update docs/guide/rl_tips.rst

Co-Authored-By: Adam Gleave <adam@gleave.me>

* Split checks

* Split tests

* Update stable_baselines/common/env_checker.py

Co-Authored-By: Adam Gleave <adam@gleave.me>

* Update stable_baselines/common/env_checker.py

Co-Authored-By: Adam Gleave <adam@gleave.me>

* Update stable_baselines/common/env_checker.py

Co-Authored-By: Adam Gleave <adam@gleave.me>

* Update stable_baselines/common/env_checker.py

Co-Authored-By: Adam Gleave <adam@gleave.me>

* Update stable_baselines/common/env_checker.py

Co-Authored-By: Adam Gleave <adam@gleave.me>

* Update stable_baselines/common/env_checker.py

Co-Authored-By: Adam Gleave <adam@gleave.me>

* Update stable_baselines/common/env_checker.py

Co-Authored-By: Adam Gleave <adam@gleave.me>

* Update stable_baselines/common/env_checker.py

Co-Authored-By: Adam Gleave <adam@gleave.me>

* Update stable_baselines/common/env_checker.py

Co-Authored-By: Adam Gleave <adam@gleave.me>

* Update stable_baselines/common/env_checker.py

Co-Authored-By: Adam Gleave <adam@gleave.me>

* Update stable_baselines/common/env_checker.py

Co-Authored-By: Adam Gleave <adam@gleave.me>

* Reformat files
  • Loading branch information
araffin committed Dec 16, 2019
1 parent ea93850 commit ba51e25
Show file tree
Hide file tree
Showing 10 changed files with 404 additions and 5 deletions.
10 changes: 10 additions & 0 deletions .github/ISSUE_TEMPLATE/issue-template.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@ If you have any questions, feel free to create an issue with the tag [question].
If you wish to suggest an enhancement or feature request, add the tag [feature request].
If you are submitting a bug report, please fill in the following details.

If your issue is related to a custom gym environment, please check it first using:

```python
from stable_baselines.common.env_checker import check_env

env = CustomEnv(arg1, ...)
# It will check your custom environment and output additional warnings if needed
check_env(env)
```

**Describe the bug**
A clear and concise description of what the bug is.

Expand Down
7 changes: 7 additions & 0 deletions docs/common/env_checker.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
.. _env_checker:

Gym Environment Checker
========================

.. automodule:: stable_baselines.common.env_checker
:members:
18 changes: 15 additions & 3 deletions docs/guide/custom_env.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,22 @@ Then you can define and train a RL agent with:

.. code-block:: python
# Instantiate and wrap the env
env = DummyVecEnv([lambda: CustomEnv(arg1, ...)])
# Instantiate the env
env = CustomEnv(arg1, ...)
# Define and Train the agent
model = A2C(CnnPolicy, env).learn(total_timesteps=1000)
model = A2C('CnnPolicy', env).learn(total_timesteps=1000)
To check that your environment follows the gym interface, please use:

.. code-block:: python
from stable_baselines.common.env_checker import check_env
env = CustomEnv(arg1, ...)
# It will check your custom environment and output additional warnings if needed
check_env(env)
We have created a `colab notebook <https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/master/5_custom_gym_env.ipynb>`_ for
Expand Down
13 changes: 12 additions & 1 deletion docs/guide/rl_tips.rst
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,18 @@ Some basic advice:
- debug with random actions to check that your environment works and follows the gym interface:


Here is a code snippet to check that your environment runs without error.
We provide a helper to check that your environment runs without error:

.. code-block:: python
from stable_baselines.common.env_checker import check_env
env = CustomEnv(arg1, ...)
# It will check your custom environment and output additional warnings if needed
check_env(env)
If you want to quickly try a random agent on your environment, you can also do:

.. code-block:: python
Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ This toolset is a fork of OpenAI Baselines, with a major structural refactoring,
common/cmd_utils
common/schedules
common/evaluation
common/env_checker

.. toctree::
:maxdepth: 1
Expand Down
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ New Features:
- Add type checking and PEP 561 compliance.
Note: most functions are still not annotated, this will be a gradual process.
- DDPG, TD3 and SAC accept non-symmetric action spaces. (@Antymon)
- Add `check_env` util to check if a custom environment follows the gym interface (@araffin and @justinkterry)

Bug Fixes:
^^^^^^^^^^
Expand Down
5 changes: 5 additions & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,8 @@ Piecewise
csv
nvidia
visdom
tensorboard
preprocessed
namespace
sklearn
GoalEnv
222 changes: 222 additions & 0 deletions stable_baselines/common/env_checker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
import warnings
from typing import Union

import gym
from gym import spaces
import numpy as np

from stable_baselines.common.vec_env import DummyVecEnv, VecCheckNan


def _enforce_array_obs(observation_space: spaces.Space) -> bool:
"""
Whether to check that the returned observation is a numpy array
it is not mandatory for `Dict` and `Tuple` spaces.
"""
return not isinstance(observation_space, (spaces.Dict, spaces.Tuple))


def _check_image_input(observation_space: spaces.Box) -> None:
"""
Check that the input will be compatible with Stable-Baselines
when the observation is apparently an image.
"""
if observation_space.dtype != np.uint8:
warnings.warn("It seems that your observation is an image but the `dtype` "
"of your observation_space is not `np.uint8`. "
"If your observation is not an image, we recommend you to flatten the observation "
"to have only a 1D vector")

if np.any(observation_space.low != 0) or np.any(observation_space.high != 255):
warnings.warn("It seems that your observation space is an image but the "
"upper and lower bounds are not in [0, 255]. "
"Because the CNN policy normalize automatically the observation "
"you may encounter issue if the values are not in that range."
)

if observation_space.shape[0] < 36 or observation_space.shape[1] < 36:
warnings.warn("The minimal resolution for an image is 36x36 for the default CnnPolicy. "
"You might need to use a custom `cnn_extractor` "
"cf https://stable-baselines.readthedocs.io/en/master/guide/custom_policy.html")


def _check_unsupported_obs_spaces(env: gym.Env, observation_space: spaces.Space) -> None:
"""Emit warnings when the observation space used is not supported by Stable-Baselines."""

if isinstance(observation_space, spaces.Dict) and not isinstance(env, gym.GoalEnv):
warnings.warn("The observation space is a Dict but the environment is not a gym.GoalEnv "
"(cf https://github.com/openai/gym/blob/master/gym/core.py), "
"this is currently not supported by Stable Baselines "
"(cf https://github.com/hill-a/stable-baselines/issues/133), "
"you will need to use a custom policy. "
)

if isinstance(observation_space, spaces.Tuple):
warnings.warn("The observation space is a Tuple,"
"this is currently not supported by Stable Baselines "
"(cf https://github.com/hill-a/stable-baselines/issues/133), "
"you will need to flatten the observation and maybe use a custom policy. "
)


def _check_nan(env: gym.Env) -> None:
"""Check for Inf and NaN using the VecWrapper."""
vec_env = VecCheckNan(DummyVecEnv([lambda: env]))
for _ in range(10):
action = [env.action_space.sample()]
_, _, _, _ = vec_env.step(action)


def _check_obs(obs: Union[tuple, dict, np.ndarray, int],
observation_space: spaces.Space,
method_name: str) -> None:
"""
Check that the observation returned by the environment
correspond to the declared one.
"""
if not isinstance(observation_space, spaces.Tuple):
assert not isinstance(obs, tuple), ("The observation returned by the `{}()` "
"method should be a single value, not a tuple".format(method_name))

# The check for a GoalEnv is done by the base class
if isinstance(observation_space, spaces.Discrete):
assert isinstance(obs, int), "The observation returned by `{}()` method must be an int".format(method_name)
elif _enforce_array_obs(observation_space):
assert isinstance(obs, np.ndarray), ("The observation returned by `{}()` "
"method must be a numpy array".format(method_name))

assert observation_space.contains(obs), ("The observation returned by the `{}()` "
"method does not match the given observation space".format(method_name))


def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action_space: spaces.Space) -> None:
"""
Check the returned values by the env when calling `.reset()` or `.step()` methods.
"""
# because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists
obs = env.reset()

_check_obs(obs, observation_space, 'reset')

# Sample a random action
action = action_space.sample()
data = env.step(action)

assert len(data) == 4, "The `step()` method must return four values: obs, reward, done, info"

# Unpack
obs, reward, done, info = data

_check_obs(obs, observation_space, 'step')

# We also allow int because the reward will be cast to float
assert isinstance(reward, (float, int)), "The reward returned by `step()` must be a float"
assert isinstance(done, bool), "The `done` signal must be a boolean"
assert isinstance(info, dict), "The `info` returned by `step()` must be a python dictionary"

if isinstance(env, gym.GoalEnv):
# For a GoalEnv, the keys are checked at reset
assert reward == env.compute_reward(obs['achieved_goal'], obs['desired_goal'], info)


def _check_spaces(env: gym.Env) -> None:
"""
Check that the observation and action spaces are defined
and inherit from gym.spaces.Space.
"""
# Helper to link to the code, because gym has no proper documentation
gym_spaces = " cf https://github.com/openai/gym/blob/master/gym/spaces/"

assert hasattr(env, 'observation_space'), "You must specify an observation space (cf gym.spaces)" + gym_spaces
assert hasattr(env, 'action_space'), "You must specify an action space (cf gym.spaces)" + gym_spaces

assert isinstance(env.observation_space,
spaces.Space), "The observation space must inherit from gym.spaces" + gym_spaces
assert isinstance(env.action_space, spaces.Space), "The action space must inherit from gym.spaces" + gym_spaces


def _check_render(env: gym.Env, warn=True, headless=False) -> None:
"""
Check the declared render modes and the `render()`/`close()`
method of the environment.
:param env: (gym.Env) The environment to check
:param warn: (bool) Whether to output additional warnings
:param headless: (bool) Whether to disable render modes
that require a graphical interface. False by default.
"""
render_modes = env.metadata.get('render.modes')
if render_modes is None:
if warn:
warnings.warn("No render modes was declared in the environment "
" (env.metadata['render.modes'] is None or not defined), "
"you may have trouble when calling `.render()`")

else:
# Don't check render mode that require a
# graphical interface (useful for CI)
if headless and 'human' in render_modes:
render_modes.remove('human')
# Check all declared render modes
for render_mode in render_modes:
env.render(mode=render_mode)
env.close()


def check_env(env: gym.Env, warn=True, skip_render_check=True) -> None:
"""
Check that an environment follows Gym API.
This is particularly useful when using a custom environment.
Please take a look at https://github.com/openai/gym/blob/master/gym/core.py
for more information about the API.
It also optionally check that the environment is compatible with Stable-Baselines.
:param env: (gym.Env) The Gym environment that will be checked
:param warn: (bool) Whether to output additional warnings
mainly related to the interaction with Stable Baselines
:param skip_render_check: (bool) Whether to skip the checks for the render method.
True by default (useful for the CI)
"""
assert isinstance(env, gym.Env), ("You environment must inherit from gym.Env class "
" cf https://github.com/openai/gym/blob/master/gym/core.py")

# ============= Check the spaces (observation and action) ================
_check_spaces(env)

# Define aliases for convenience
observation_space = env.observation_space
action_space = env.action_space

# Warn the user if needed.
# A warning means that the environment may run but not work properly with Stable Baselines algorithms
if warn:
_check_unsupported_obs_spaces(env, observation_space)

# If image, check the low and high values, the type and the number of channels
# and the shape (minimal value)
if isinstance(observation_space, spaces.Box) and len(observation_space.shape) == 3:
_check_image_input(observation_space)

if isinstance(observation_space, spaces.Box) and len(observation_space.shape) not in [1, 3]:
warnings.warn("Your observation has an unconventional shape (neither an image, nor a 1D vector). "
"We recommend you to flatten the observation "
"to have only a 1D vector")

# Check for the action space, it may lead to hard-to-debug issues
if (isinstance(action_space, spaces.Box) and
(np.abs(action_space.low) != np.abs(action_space.high)
or np.abs(action_space.low) > 1 or np.abs(action_space.high) > 1)):
warnings.warn("We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) "
"cf https://stable-baselines.readthedocs.io/en/master/guide/rl_tips.html")

# ============ Check the returned values ===============
_check_returned_values(env, observation_space, action_space)

# ==== Check the render method and the declared render modes ====
if not skip_render_check:
_check_render(env, warn=warn)

# The check only works with numpy arrays
if _enforce_array_obs(observation_space):
_check_nan(env)
2 changes: 1 addition & 1 deletion stable_baselines/common/identity_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


class IdentityEnv(Env):
def __init__(self, dim, ep_length=100):
def __init__(self, dim=1, ep_length=100):
"""
Identity environment for testing purposes
Expand Down

0 comments on commit ba51e25

Please sign in to comment.