Fix support of image like normalized inputs (#1214)
* Fix support of image like normalized inputs

* Improve docstring and warning message.

* Don't check if obs is image when normalize_images is False (lil opt)

* Comment fix

* Fix normalize_images not passed to parent

* Check for subclasses too

* Remove useless multiline

* Update version and add comment

* Fix some typos

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
araffin and qgallouedec committed Dec 20, 2022
1 parent ca944fe commit 8452106
Showing 20 changed files with 163 additions and 37 deletions.
2 changes: 1 addition & 1 deletion docs/common/logger.rst
@@ -105,7 +105,7 @@ train/
- ``loss``: Current total loss value
- ``n_updates``: Number of gradient updates applied so far
- ``policy_gradient_loss``: Current value of the policy gradient loss (its value does not have much meaning)
- ``value_loss``: Current value for the value function loss for on-policy algorithms, usually error between value function output and Monte-Carle estimate (or TD(lambda) estimate)
- ``value_loss``: Current value for the value function loss for on-policy algorithms, usually error between value function output and Monte-Carlo estimate (or TD(lambda) estimate)
- ``std``: Current standard deviation of the noise when using generalized State-Dependent Exploration (gSDE)


10 changes: 7 additions & 3 deletions docs/guide/custom_env.rst
@@ -8,9 +8,13 @@ That is to say, your environment must implement the following methods (and inher


.. note::
If you are using images as input, the observation must be of type ``np.uint8`` and be contained in [0, 255]
is normalized (dividing by 255 to have values in [0, 1]) when using CNN policies. Images can be either
channel-first or channel-last.
If you are using images as input, the observation must be of type ``np.uint8`` and be contained in [0, 255].
By default, the observation is normalized by SB3 pre-processing (dividing by 255 to have values in [0, 1]) when using CNN policies.
Images can be either channel-first or channel-last.

If you want to use ``CnnPolicy`` or ``MultiInputPolicy`` with image-like observation (3D tensor) that are already normalized, you must pass ``normalize_images=False``
to the policy (using ``policy_kwargs`` parameter, ``policy_kwargs=dict(normalize_images=False)``)
and make sure your image is in the **channel-first** format.
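
As a minimal sketch of the note above, the flag can be forwarded through ``policy_kwargs`` as follows. The ``make_model`` helper and its ``env`` argument are hypothetical names used only for illustration; ``env`` is assumed to be your own environment whose image observations are already normalized and channel-first.

```python
# Keyword arguments forwarded to the policy constructor.
# normalize_images=False asks SB3 not to divide observations by 255.
policy_kwargs = dict(normalize_images=False)


def make_model(env):
    """Build a PPO model for an env with already-normalized, channel-first images.

    `env` is a placeholder for your own environment (an assumption of this sketch).
    """
    # Imported lazily so the snippet can be read without SB3 installed.
    from stable_baselines3 import PPO

    return PPO("CnnPolicy", env, policy_kwargs=policy_kwargs)
```

The same ``policy_kwargs`` dictionary should apply to ``MultiInputPolicy`` when the Dict observation contains image-like entries.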


.. note::
6 changes: 3 additions & 3 deletions docs/guide/custom_policy.rst
@@ -139,7 +139,7 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t
"""
def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
super(CustomCNN, self).__init__(observation_space, features_dim)
super().__init__(observation_space, features_dim)
# We assume CxHxW images (channels first)
# Re-ordering will be done by pre-preprocessing or wrapper
n_input_channels = observation_space.shape[0]
@@ -201,7 +201,7 @@ downsampling and "vector" with a single linear layer.
# We do not know features-dim here before going over all the items,
# so put something dummy for now. PyTorch requires calling
# nn.Module.__init__ before adding modules
super(CustomCombinedExtractor, self).__init__(observation_space, features_dim=1)
super().__init__(observation_space, features_dim=1)
extractors = {}
@@ -374,7 +374,7 @@ If your task requires even more granular control over the policy/value architect
**kwargs,
):
super(CustomActorCriticPolicy, self).__init__(
super().__init__(
observation_space,
action_space,
lr_schedule,
2 changes: 1 addition & 1 deletion docs/guide/examples.rst
@@ -194,7 +194,7 @@ You can use environments with dictionary observation spaces. This is useful in t
concatenate observations such as an image from a camera combined with a vector of servo sensor data (e.g., rotation angles).
Stable Baselines3 provides ``SimpleMultiObsEnv`` as an example of this kind of of setting.
The environment is a simple grid world but the observations for each cell come in the form of dictionaries.
These dictionaries are randomly initilaized on the creation of the environment and contain a vector observation and an image observation.
These dictionaries are randomly initialized on the creation of the environment and contain a vector observation and an image observation.

.. code-block:: python
2 changes: 1 addition & 1 deletion docs/guide/export.rst
@@ -35,7 +35,7 @@ As of June 2021, ONNX format `doesn't support <https://github.com/onnx/onnx/iss

The following examples are for ``MlpPolicy`` only, and are general examples. Note that you have to preprocess the observation the same way stable-baselines3 agent does (see ``common.preprocessing.preprocess_obs``).

For PPO, assuming a shared feature extactor.
For PPO, assuming a shared feature extractor.

.. warning::

9 changes: 8 additions & 1 deletion docs/misc/changelog.rst
@@ -4,7 +4,7 @@ Changelog
==========


Release 1.7.0a6 (WIP)
Release 1.7.0a7 (WIP)
--------------------------

Breaking Changes:
@@ -13,12 +13,16 @@ Breaking Changes:
please use an ``EvalCallback`` instead
- Removed deprecated ``sde_net_arch`` parameter
- Removed ``ret`` attributes in ``VecNormalize``, please use ``returns`` instead
- ``VecNormalize`` now updates the observation space when normalizing images

New Features:
^^^^^^^^^^^^^
- Introduced mypy type checking
- Added ``with_bias`` argument to ``create_mlp``
- Added support for multidimensional ``spaces.MultiBinary`` observations
- Features extractors now properly support unnormalized image-like observations (3D tensor)
when passing ``normalize_images=False``
- Added ``normalized_image`` parameter to ``NatureCNN`` and ``CombinedExtractor``

SB3-Contrib
^^^^^^^^^^^
@@ -31,6 +35,8 @@ Bug Fixes:
- Raise an error when the same gym environment instance is passed as separate environments when creating a vectorized environment with more than one environment. (@Rocamonde)
- Fix type annotation of ``model`` in ``evaluate_policy``
- Fixed ``Self`` return type using ``TypeVar``
- Fixed the env checker, the key was not passed when checking images from Dict observation space
- Fixed ``normalize_images`` which was not passed to parent class in some cases

Deprecations:
^^^^^^^^^^^^^
@@ -58,6 +64,7 @@ Documentation:
- Changed ``env`` to ``vec_env`` when environment is vectorized
- Updated custom policy docs to better explain the ``mlp_extractor``'s dimensions (@AlexPasqua)
- Update custom policy documentation (@athatheo)
- Clarify doc when using image-like input

Release 1.6.2 (2022-10-10)
--------------------------
2 changes: 1 addition & 1 deletion docs/misc/projects.rst
@@ -155,7 +155,7 @@ Driving policies can be trained in different scenarios, and several notebooks us
tactile-gym
-------------------

Suite of RL environments focussed on using a simulated tactile sensor as the primary source of observations. Sim-to-Real results across 4 out of 5 proposed envs.
Suite of RL environments focused on using a simulated tactile sensor as the primary source of observations. Sim-to-Real results across 4 out of 5 proposed envs.

| Author: Alex Church
| GitHub: https://github.com/ac-93/tactile_gym
10 changes: 7 additions & 3 deletions stable_baselines3/common/env_checker.py
@@ -21,11 +21,15 @@ def _check_image_input(observation_space: spaces.Box, key: str = "") -> None:
"""
Check that the input will be compatible with Stable-Baselines
when the observation is apparently an image.
:param observation_space: Observation space
:key: When the observation space comes from a Dict space, we pass the
corresponding key to have more precise warning messages. Defaults to "".
"""
if observation_space.dtype != np.uint8:
warnings.warn(
f"It seems that your observation {key} is an image but the `dtype` "
"of your observation_space is not `np.uint8`. "
f"It seems that your observation {key} is an image but its `dtype` "
f"is ({observation_space.dtype}) whereas it has to be `np.uint8`. "
"If your observation is not an image, we recommend you to flatten the observation "
"to have only a 1D vector"
)
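
To illustrate what the reworded warning looks like for a Dict observation space, here is a small stand-alone sketch. `warn_if_not_uint8` is a hypothetical, simplified stand-in for the dtype check in `_check_image_input`; it takes the dtype as a plain string so that it does not depend on NumPy or Gym.

```python
import warnings


def warn_if_not_uint8(dtype_name: str, key: str = "") -> None:
    """Simplified stand-in for the dtype check in `_check_image_input`.

    `dtype_name` replaces `observation_space.dtype` (an assumption of this sketch).
    """
    if dtype_name != "uint8":
        warnings.warn(
            f"It seems that your observation {key} is an image but its `dtype` "
            f"is ({dtype_name}) whereas it has to be `np.uint8`. "
            "If your observation is not an image, we recommend you to flatten the observation "
            "to have only a 1D vector"
        )


# For a Dict observation space, the key of the offending sub-space is passed along.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_if_not_uint8("float32", key="camera")
    warn_if_not_uint8("uint8", key="depth")  # correct dtype: no warning

message = str(caught[0].message)
```

Only the `camera` entry triggers a warning, and the message names both the key and the actual dtype.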
@@ -180,7 +184,7 @@ def _check_box_obs(observation_space: spaces.Box, key: str = "") -> None:
# If image, check the low and high values, the type and the number of channels
# and the shape (minimal value)
if len(observation_space.shape) == 3:
_check_image_input(observation_space)
_check_image_input(observation_space, key)

if len(observation_space.shape) not in [1, 3]:
warnings.warn(
2 changes: 1 addition & 1 deletion stable_baselines3/common/logger.py
@@ -73,7 +73,7 @@ def __init__(self, image: Union[th.Tensor, np.ndarray, str], dataformats: str):

class HParam:
"""
Hyperparameter data class storing hyperparameters and metrics in dictionnaries
Hyperparameter data class storing hyperparameters and metrics in dictionaries
:param hparam_dict: key-value pairs of hyperparameters to log
:param metric_dict: key-value pairs of metrics to log
5 changes: 4 additions & 1 deletion stable_baselines3/common/policies.py
@@ -87,6 +87,9 @@ def __init__(

self.features_extractor_class = features_extractor_class
self.features_extractor_kwargs = features_extractor_kwargs
# Automatically deactivate dtype and bounds checks
if normalize_images is False and issubclass(features_extractor_class, (NatureCNN, CombinedExtractor)):
self.features_extractor_kwargs.update(dict(normalized_image=True))

def _update_features_extractor(
self,
@@ -430,6 +433,7 @@ def __init__(
optimizer_class=optimizer_class,
optimizer_kwargs=optimizer_kwargs,
squash_output=squash_output,
normalize_images=normalize_images,
)

# Default network architecture, from stable-baselines
@@ -446,7 +450,6 @@ def __init__(
self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
self.features_dim = self.features_extractor.features_dim

self.normalize_images = normalize_images
self.log_std_init = log_std_init
dist_kwargs = None
# Keyword arguments for gSDE distribution
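
The effect of the `policies.py` change on the features extractor arguments can be sketched in isolation. The classes below are empty stand-ins for the SB3 classes of the same names, while `CustomCNN` and `build_extractor_kwargs` are hypothetical names introduced only for this example.

```python
class NatureCNN:  # stand-in for the SB3 class of the same name
    pass


class CombinedExtractor:  # stand-in for the SB3 class of the same name
    pass


class FlattenExtractor:  # stand-in; does not accept `normalized_image`
    pass


class CustomCNN(NatureCNN):  # hypothetical user-defined subclass
    pass


def build_extractor_kwargs(features_extractor_class, normalize_images, features_extractor_kwargs=None):
    """Mirror the check added to the policy constructor."""
    kwargs = dict(features_extractor_kwargs or {})
    # Only the image-aware extractors understand `normalized_image`;
    # issubclass() also covers user-defined subclasses.
    if normalize_images is False and issubclass(features_extractor_class, (NatureCNN, CombinedExtractor)):
        kwargs.update(dict(normalized_image=True))
    return kwargs


subclass_kwargs = build_extractor_kwargs(CustomCNN, normalize_images=False)
default_kwargs = build_extractor_kwargs(NatureCNN, normalize_images=True)
flatten_kwargs = build_extractor_kwargs(FlattenExtractor, normalize_images=False)
```

Under these assumptions, only the first call adds `normalized_image=True`; the other two leave the keyword arguments untouched.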
15 changes: 11 additions & 4 deletions stable_baselines3/common/preprocessing.py
@@ -27,6 +27,7 @@ def is_image_space_channels_first(observation_space: spaces.Box) -> bool:
def is_image_space(
observation_space: spaces.Space,
check_channels: bool = False,
normalized_image: bool = False,
) -> bool:
"""
Check if a observation space has the shape, limits and dtype
@@ -38,15 +39,21 @@ def is_image_space(
:param observation_space:
:param check_channels: Whether to do or not the check for the number of channels.
e.g., with frame-stacking, the observation space may have more channels than expected.
:param normalized_image: Whether to assume that the image is already normalized
or not (this disables dtype and bounds checks): when True, it only checks that
the space is a Box and has 3 dimensions.
Otherwise, it checks that it has expected dtype (uint8) and bounds (values in [0, 255]).
:return:
"""
check_dtype = check_bounds = not normalized_image
if isinstance(observation_space, spaces.Box) and len(observation_space.shape) == 3:
# Check the type
if observation_space.dtype != np.uint8:
if check_dtype and observation_space.dtype != np.uint8:
return False

# Check the value range
if np.any(observation_space.low != 0) or np.any(observation_space.high != 255):
incorrect_bounds = np.any(observation_space.low != 0) or np.any(observation_space.high != 255)
if check_bounds and incorrect_bounds:
return False

# Skip channels check
@@ -57,7 +64,7 @@ def is_image_space(
n_channels = observation_space.shape[0]
else:
n_channels = observation_space.shape[-1]
# RGB, RGBD, GrayScale
# GrayScale, RGB, RGBD
return n_channels in [1, 3, 4]
return False

@@ -99,7 +106,7 @@ def preprocess_obs(
:return:
"""
if isinstance(observation_space, spaces.Box):
if is_image_space(observation_space) and normalize_images:
if normalize_images and is_image_space(observation_space):
return obs.float() / 255.0
return obs.float()

39 changes: 30 additions & 9 deletions stable_baselines3/common/torch_layers.py
@@ -18,7 +18,7 @@ class BaseFeaturesExtractor(nn.Module):
:param features_dim: Number of features extracted.
"""

def __init__(self, observation_space: gym.Space, features_dim: int = 0):
def __init__(self, observation_space: gym.Space, features_dim: int = 0) -> None:
super().__init__()
assert features_dim > 0
self._observation_space = observation_space
@@ -37,7 +37,7 @@ class FlattenExtractor(BaseFeaturesExtractor):
:param observation_space:
"""

def __init__(self, observation_space: gym.Space):
def __init__(self, observation_space: gym.Space) -> None:
super().__init__(observation_space, get_flattened_obs_dim(observation_space))
self.flatten = nn.Flatten()

@@ -55,19 +55,31 @@ class NatureCNN(BaseFeaturesExtractor):
:param observation_space:
:param features_dim: Number of features extracted.
This corresponds to the number of unit for the last layer.
:param normalized_image: Whether to assume that the image is already normalized
or not (this disables dtype and bounds checks): when True, it only checks that
the space is a Box and has 3 dimensions.
Otherwise, it checks that it has expected dtype (uint8) and bounds (values in [0, 255]).
"""

def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512):
def __init__(
self,
observation_space: gym.spaces.Box,
features_dim: int = 512,
normalized_image: bool = False,
) -> None:
super().__init__(observation_space, features_dim)
# We assume CxHxW images (channels first)
# Re-ordering will be done by pre-preprocessing or wrapper
assert is_image_space(observation_space, check_channels=False), (
assert is_image_space(observation_space, check_channels=False, normalized_image=normalized_image), (
"You should use NatureCNN "
f"only with images not with {observation_space}\n"
"(you are probably using `CnnPolicy` instead of `MlpPolicy` or `MultiInputPolicy`)\n"
"If you are using a custom environment,\n"
"please check it using our env checker:\n"
"https://stable-baselines3.readthedocs.io/en/master/common/env_checker.html"
"https://stable-baselines3.readthedocs.io/en/master/common/env_checker.html.\n"
"If you are using `VecNormalize` or already normalized channel-first images "
"you should pass `normalize_images=False`: \n"
"https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html"
)
n_input_channels = observation_space.shape[0]
self.cnn = nn.Sequential(
@@ -167,7 +179,7 @@ def __init__(
net_arch: List[Union[int, Dict[str, List[int]]]],
activation_fn: Type[nn.Module],
device: Union[th.device, str] = "auto",
):
) -> None:
super().__init__()
device = get_device(device)
shared_net: List[nn.Module] = []
@@ -247,18 +259,27 @@ class CombinedExtractor(BaseFeaturesExtractor):
:param observation_space:
:param cnn_output_dim: Number of features to output from each CNN submodule(s). Defaults to
256 to avoid exploding network sizes.
:param normalized_image: Whether to assume that the image is already normalized
or not (this disables dtype and bounds checks): when True, it only checks that
the space is a Box and has 3 dimensions.
Otherwise, it checks that it has expected dtype (uint8) and bounds (values in [0, 255]).
"""

def __init__(self, observation_space: gym.spaces.Dict, cnn_output_dim: int = 256):
def __init__(
self,
observation_space: gym.spaces.Dict,
cnn_output_dim: int = 256,
normalized_image: bool = False,
) -> None:
# TODO we do not know features-dim here before going over all the items, so put something there. This is dirty!
super().__init__(observation_space, features_dim=1)

extractors: Dict[str, nn.Module] = {}

total_concat_size = 0
for key, subspace in observation_space.spaces.items():
if is_image_space(subspace):
extractors[key] = NatureCNN(subspace, features_dim=cnn_output_dim)
if is_image_space(subspace, normalized_image=normalized_image):
extractors[key] = NatureCNN(subspace, features_dim=cnn_output_dim, normalized_image=normalized_image)
total_concat_size += cnn_output_dim
else:
# The observation key is a vector, flatten it if needed
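
How `CombinedExtractor` routes each key of a Dict observation can be sketched without PyTorch. `plan_combined_extractor` and its `is_image` predicate are hypothetical names; the sketch also assumes that non-image entries contribute their flattened size to the total, which is not shown in the hunk above.

```python
import math
from typing import Callable, Dict, Tuple

Shape = Tuple[int, ...]


def plan_combined_extractor(
    shapes: Dict[str, Shape],
    is_image: Callable[[Shape], bool],
    cnn_output_dim: int = 256,
) -> Tuple[Dict[str, str], int]:
    """Decide, per key, between a CNN and a flatten layer and sum the feature sizes."""
    plan: Dict[str, str] = {}
    total_concat_size = 0
    for key, shape in shapes.items():
        if is_image(shape):
            plan[key] = "cnn"
            total_concat_size += cnn_output_dim
        else:
            # Assumption: vectors are flattened and contribute prod(shape) features.
            plan[key] = "flatten"
            total_concat_size += math.prod(shape)
    return plan, total_concat_size


# With normalized_image=True, the image check reduces to "is the space 3D?".
plan, total = plan_combined_extractor(
    {"camera": (3, 64, 64), "servo": (6,)},
    is_image=lambda shape: len(shape) == 3,
)
```

Here the `camera` entry goes through a CNN with 256 output features and the `servo` vector adds 6 more.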
2 changes: 1 addition & 1 deletion stable_baselines3/common/vec_env/stacked_observations.py
@@ -187,7 +187,7 @@ def __init__(

def stack_observation_space(self, observation_space: spaces.Dict) -> spaces.Dict:
"""
Returns the stacked verson of a Dict observation space
Returns the stacked version of a Dict observation space
:param observation_space: Dict observation space to stack
:return: stacked observation space
27 changes: 27 additions & 0 deletions stable_baselines3/common/vec_env/vec_normalize.py
@@ -6,6 +6,7 @@
import numpy as np

from stable_baselines3.common import utils
from stable_baselines3.common.preprocessing import is_image_space
from stable_baselines3.common.running_mean_std import RunningMeanStd
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper

@@ -50,9 +51,35 @@ def __init__(
if isinstance(self.observation_space, gym.spaces.Dict):
self.obs_spaces = self.observation_space.spaces
self.obs_rms = {key: RunningMeanStd(shape=self.obs_spaces[key].shape) for key in self.norm_obs_keys}
# Update observation space when using image
# See explanation below and GH #1214
for key in self.obs_rms.keys():
if is_image_space(self.obs_spaces[key]):
self.observation_space.spaces[key] = gym.spaces.Box(
low=-clip_obs,
high=clip_obs,
shape=self.obs_spaces[key].shape,
dtype=np.float32,
)

else:
self.obs_spaces = None
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
# Update observation space when using image
# See GH #1214
# This is to raise proper error when
# VecNormalize is used with an image-like input and
# normalize_images=True.
# For correctness, we should also update the bounds
# in other cases but this will cause backward-incompatible change
# and break already saved policies.
if is_image_space(self.observation_space):
self.observation_space = gym.spaces.Box(
low=-clip_obs,
high=clip_obs,
shape=self.observation_space.shape,
dtype=np.float32,
)

self.ret_rms = RunningMeanStd(shape=())
self.clip_obs = clip_obs
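
The observation-space update in `VecNormalize` can be summarized with a dependency-free sketch. `BoxSpec`, `looks_like_raw_image` and `normalized_space` are hypothetical names, and `clip_obs=10.0` is assumed here as the default clipping value.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class BoxSpec:
    """Minimal stand-in for `gym.spaces.Box` with scalar bounds (assumption)."""

    shape: Tuple[int, ...]
    dtype: str
    low: float
    high: float


def looks_like_raw_image(space: BoxSpec) -> bool:
    """Default image check: 3D, uint8, values in [0, 255]."""
    return len(space.shape) == 3 and space.dtype == "uint8" and space.low == 0 and space.high == 255


def normalized_space(space: BoxSpec, clip_obs: float = 10.0) -> BoxSpec:
    """Advertise float32 bounds in [-clip_obs, clip_obs] for image spaces only."""
    if looks_like_raw_image(space):
        return BoxSpec(shape=space.shape, dtype="float32", low=-clip_obs, high=clip_obs)
    # Other spaces are left unchanged to stay backward compatible.
    return space


raw = BoxSpec(shape=(3, 64, 64), dtype="uint8", low=0, high=255)
vector = BoxSpec(shape=(8,), dtype="float32", low=-1.0, high=1.0)
wrapped = normalized_space(raw)
```

Because `wrapped` no longer passes the default image check, a CNN extractor would reject it unless `normalize_images=False` is passed to the policy, which is the error path the comments in the diff describe.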