Skip to content

Commit

Permalink
Implemented Vectorized Action Noise (#34)
Browse files Browse the repository at this point in the history
* Implemented Vectorized Action Noise

Vectorized Action Noise allows for multiple instances of
ActionNoiseProcesses to run in parallel. This makes it easier to
run TD3/SAC/DDPG with VecEnv.

* fixed linting issues

* make test function name consistent

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* sanity checks and more detailed test

* Update stable_baselines3/common/noise.py

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Added assertion error message in noises setter

* Corrected tests to reflect change to AssertionError from ValueError

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
  • Loading branch information
m-rph and araffin committed May 27, 2020
1 parent 9b42b97 commit 78e8d40
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 3 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ New Features:
- Added ``cmd_util`` and ``atari_wrappers``
- Added support for ``MultiDiscrete`` and ``MultiBinary`` observation spaces (@rolandgvc)
- Added ``MultiCategorical`` and ``Bernoulli`` distributions for PPO/A2C (@rolandgvc)
- Added ``VectorizedActionNoise`` for continuous vectorized environments (@PartiallyTyped)

Bug Fixes:
^^^^^^^^^^
Expand Down Expand Up @@ -229,4 +230,4 @@ And all the contributors:
@XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
@flodorner @KuKuXia @NeoExtended @solliet @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur
83 changes: 81 additions & 2 deletions stable_baselines3/common/noise.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional
from typing import Optional, List, Iterable
from abc import ABC, abstractmethod
import copy

import numpy as np

Expand Down Expand Up @@ -45,7 +46,7 @@ def __repr__(self) -> str:

class OrnsteinUhlenbeckActionNoise(ActionNoise):
"""
A Ornstein Uhlenbeck action noise, this is designed to aproximate brownian motion with friction.
An Ornstein Uhlenbeck action noise, this is designed to aproximate brownian motion with friction.
Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
Expand Down Expand Up @@ -84,3 +85,81 @@ def reset(self) -> None:

def __repr__(self) -> str:
return f'OrnsteinUhlenbeckActionNoise(mu={self._mu}, sigma={self._sigma})'


class VectorizedActionNoise(ActionNoise):
"""
A Vectorized action noise for parallel environments.
:param base_noise: ActionNoise The noise generator to use
:param n_envs: (int) The number of parallel environments
"""

def __init__(self, base_noise: ActionNoise, n_envs: int):
try:
self.n_envs = int(n_envs)
assert self.n_envs > 0
except (TypeError, AssertionError):
raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0")

self.base_noise = base_noise
self.noises = [copy.deepcopy(self.base_noise) for _ in range(n_envs)]

def reset(self, indices: Optional[Iterable[int]] = None) -> None:
"""
Reset all the noise processes, or those listed in indices
:param indices: Optional[Iterable[int]] The indices to reset. Default: None.
If the parameter is None, then all processes are reset to their initial position.
"""
if indices is None:
indices = range(len(self.noises))

for index in indices:
self.noises[index].reset()

def __repr__(self) -> str:
return f"VecNoise(BaseNoise={repr(self.base_noise)}), n_envs={len(self.noises)})"

def __call__(self) -> np.ndarray:
"""
Generate and stack the action noise from each noise object
"""
noise = np.stack([noise() for noise in self.noises])
return noise

@property
def base_noise(self) -> ActionNoise:
return self._base_noise

@base_noise.setter
def base_noise(self, base_noise: ActionNoise):
if base_noise is None:
raise ValueError("Expected base_noise to be an instance of ActionNoise, not None", ActionNoise)
if not isinstance(base_noise, ActionNoise):
raise TypeError("Expected base_noise to be an instance of type ActionNoise", ActionNoise)
self._base_noise = base_noise

@property
def noises(self) -> List[ActionNoise]:
return self._noises

@noises.setter
def noises(self, noises: List[ActionNoise]) -> None:
noises = list(noises) # raises TypeError if not iterable
assert len(noises) == self.n_envs, f"Expected a list of {self.n_envs} ActionNoises, found {len(noises)}."

different_types = [
i for i, noise in enumerate(noises)
if not isinstance(noise, type(self.base_noise))
]

if len(different_types):
raise ValueError(
f"Noise instances at indices {different_types} don't match the type of base_noise",
type(self.base_noise)
)

self._noises = noises
for noise in noises:
noise.reset()
35 changes: 35 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.cmd_util import make_vec_env, make_atari_env
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.noise import (
VectorizedActionNoise, OrnsteinUhlenbeckActionNoise, ActionNoise)


@pytest.mark.parametrize("env_id", ['CartPole-v1', lambda: gym.make('CartPole-v1')])
Expand Down Expand Up @@ -107,3 +109,36 @@ def dummy_callback(locals_, _globals):

episode_rewards, _ = evaluate_policy(model, model.get_env(), n_eval_episodes, return_episode_rewards=True)
assert len(episode_rewards) == n_eval_episodes


def test_vec_noise():
num_envs = 4
num_actions = 10
mu = np.zeros(num_actions)
sigma = np.ones(num_actions) * 0.4
base: ActionNoise = OrnsteinUhlenbeckActionNoise(mu, sigma)
with pytest.raises(ValueError):
vec = VectorizedActionNoise(base, -1)
with pytest.raises(ValueError):
vec = VectorizedActionNoise(base, None)
with pytest.raises(ValueError):
vec = VectorizedActionNoise(base, "whatever")

vec = VectorizedActionNoise(base, num_envs)
assert vec.n_envs == num_envs
assert vec().shape == (num_envs, num_actions)
assert not (vec() == base()).all()
with pytest.raises(ValueError):
vec = VectorizedActionNoise(None, num_envs)
with pytest.raises(TypeError):
vec = VectorizedActionNoise(12, num_envs)
with pytest.raises(AssertionError):
vec.noises = []
with pytest.raises(TypeError):
vec.noises = None
with pytest.raises(ValueError):
vec.noises = [None] * vec.n_envs
with pytest.raises(AssertionError):
vec.noises = [base] * (num_envs - 1)
assert all(isinstance(noise, type(base)) for noise in vec.noises)
assert len(vec.noises) == num_envs

0 comments on commit 78e8d40

Please sign in to comment.