Skip to content

Commit

Permalink
Add the argument dtype (default to float32) to the noise (#1301)
Browse files Browse the repository at this point in the history
* Fixed noise to return float32

* Updated changelog

* Fixed test to use numpy arrays instead of python floats

* Sorted imports for tests

* Added dtype to constructor

* Removed dtype parameter for VectorizedActionNoise

* __init__ -> None; Capitalize and period in docstring when needed; fix dtype type hint; dtype in docstring

* fix dtype type hint

* Update version

* Clarify changelog [skip ci]

* empty commit to run ci

* Update docs/misc/changelog.rst

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
  • Loading branch information
3 people committed Feb 7, 2023
1 parent 2e4a450 commit 489b1fd
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 20 deletions.
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ New Features:
Bug Fixes:
^^^^^^^^^^
- Fixed Atari wrapper that missed the reset condition (@luizapozzobon)
- Added the ``dtype`` argument (defaults to ``float32``) to ``NormalActionNoise`` and ``OrnsteinUhlenbeckActionNoise`` for consistency with gym actions (@sidney-tio)
- Fixed PPO train/n_updates metric not accounting for early stopping (@adamfrly)

Deprecations:
Expand Down
44 changes: 25 additions & 19 deletions stable_baselines3/common/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Iterable, List, Optional

import numpy as np
from numpy.typing import DTypeLike


class ActionNoise(ABC):
Expand All @@ -15,7 +16,7 @@ def __init__(self) -> None:

def reset(self) -> None:
"""
call end of episode reset for the noise
Call end of episode reset for the noise
"""
pass

Expand All @@ -26,19 +27,21 @@ def __call__(self) -> np.ndarray:

class NormalActionNoise(ActionNoise):
"""
A Gaussian action noise
A Gaussian action noise.
:param mean: the mean value of the noise
:param sigma: the scale of the noise (std here)
:param mean: Mean value of the noise
:param sigma: Scale of the noise (std here)
:param dtype: Type of the output noise
"""

def __init__(self, mean: np.ndarray, sigma: np.ndarray):
def __init__(self, mean: np.ndarray, sigma: np.ndarray, dtype: DTypeLike = np.float32) -> None:
self._mu = mean
self._sigma = sigma
self._dtype = dtype
super().__init__()

def __call__(self) -> np.ndarray:
return np.random.normal(self._mu, self._sigma)
return np.random.normal(self._mu, self._sigma).astype(self._dtype)

def __repr__(self) -> str:
return f"NormalActionNoise(mu={self._mu}, sigma={self._sigma})"
Expand All @@ -50,11 +53,12 @@ class OrnsteinUhlenbeckActionNoise(ActionNoise):
Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
:param mean: the mean of the noise
:param sigma: the scale of the noise
:param theta: the rate of mean reversion
:param dt: the timestep for the noise
:param initial_noise: the initial value for the noise output, (if None: 0)
:param mean: Mean of the noise
:param sigma: Scale of the noise
:param theta: Rate of mean reversion
:param dt: Timestep for the noise
:param initial_noise: Initial value for the noise output, (if None: 0)
:param dtype: Type of the output noise
"""

def __init__(
Expand All @@ -64,11 +68,13 @@ def __init__(
theta: float = 0.15,
dt: float = 1e-2,
initial_noise: Optional[np.ndarray] = None,
):
dtype: DTypeLike = np.float32,
) -> None:
self._theta = theta
self._mu = mean
self._sigma = sigma
self._dt = dt
self._dtype = dtype
self.initial_noise = initial_noise
self.noise_prev = np.zeros_like(self._mu)
self.reset()
Expand All @@ -81,7 +87,7 @@ def __call__(self) -> np.ndarray:
+ self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._mu.shape)
)
self.noise_prev = noise
return noise
return noise.astype(self._dtype)

def reset(self) -> None:
"""
Expand All @@ -97,11 +103,11 @@ class VectorizedActionNoise(ActionNoise):
"""
A Vectorized action noise for parallel environments.
:param base_noise: ActionNoise The noise generator to use
:param n_envs: The number of parallel environments
:param base_noise: Noise generator to use
:param n_envs: Number of parallel environments
"""

def __init__(self, base_noise: ActionNoise, n_envs: int):
def __init__(self, base_noise: ActionNoise, n_envs: int) -> None:
try:
self.n_envs = int(n_envs)
assert self.n_envs > 0
Expand All @@ -113,9 +119,9 @@ def __init__(self, base_noise: ActionNoise, n_envs: int):

def reset(self, indices: Optional[Iterable[int]] = None) -> None:
"""
Reset all the noise processes, or those listed in indices
Reset all the noise processes, or those listed in indices.
:param indices: Optional[Iterable[int]] The indices to reset. Default: None.
:param indices: The indices to reset. Default: None.
If the parameter is None, then all processes are reset to their initial position.
"""
if indices is None:
Expand All @@ -129,7 +135,7 @@ def __repr__(self) -> str:

def __call__(self) -> np.ndarray:
"""
Generate and stack the action noise from each noise object
Generate and stack the action noise from each noise object.
"""
noise = np.stack([noise() for noise in self.noises])
return noise
Expand Down
5 changes: 4 additions & 1 deletion tests/test_deterministic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import pytest

from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
Expand All @@ -15,7 +16,9 @@ def test_deterministic_training_common(algo):
kwargs = {"policy_kwargs": dict(net_arch=[64])}
env_id = "Pendulum-v1"
if algo in [TD3, SAC]:
kwargs.update({"action_noise": NormalActionNoise(0.0, 0.1), "learning_starts": 100, "train_freq": 4})
kwargs.update(
{"action_noise": NormalActionNoise(np.zeros(1), 0.1 * np.ones(1)), "learning_starts": 100, "train_freq": 4}
)
else:
if algo == DQN:
env_id = "CartPole-v1"
Expand Down

0 comments on commit 489b1fd

Please sign in to comment.