Add some missing tests, update VecNormalize and RolloutBuffer (#50)
* Change saving/loading normalization parameters to use single pickle file

* Remove 'use_gae' from RolloutBuffer compute_returns function

* Add some missing tests for normalizer, nan-checker and PPO clip_value_fn argument

* Update changelog

* Fix typo

* Use proper pytest.raises for catching errors in tests

* Add comment on GAE and how to obtain non-GAE behaviour

* Remove save/load_running_average from VecNormalize in favor of load/save

* Update changelog

* Update docstring

* Add accidentally removed tests for VecNormalize

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Miffyli and araffin committed Jun 10, 2020
1 parent 44f8218 commit b833207
Showing 6 changed files with 70 additions and 82 deletions.
4 changes: 3 additions & 1 deletion docs/misc/changelog.rst
@@ -26,7 +26,8 @@ Breaking Changes:
 - ``check_env`` to ``check_for_correct_spaces`` (utils.py. Renamed to avoid confusion with environment checker tools)
 
 - Moved static function ``_is_vectorized_observation`` from common/policies.py to common/utils.py under name ``is_vectorized_observation``.
-
+- Removed ``{save,load}_running_average`` functions of ``VecNormalize`` in favor of ``load/save``.
+- Removed ``use_gae`` parameter from ``RolloutBuffer.compute_returns_and_advantage``.
 
 New Features:
 ^^^^^^^^^^^^^
@@ -50,6 +51,7 @@ Others:
 - Added bit further comments on register/getting policies ("MlpPolicy", "CnnPolicy").
 - Renamed ``progress`` (value from 1 in start of training to 0 in end) to ``progress_remaining``.
 - Added ``policies.py`` files for A2C/PPO, which define MlpPolicy/CnnPolicy (renamed ActorCriticPolicies).
+- Added some missing tests for ``VecNormalize``, ``VecCheckNan`` and ``PPO``.
 
 Documentation:
 ^^^^^^^^^^^^^^
53 changes: 20 additions & 33 deletions stable_baselines3/common/buffers.py
@@ -240,49 +240,36 @@ def reset(self) -> None:
 
     def compute_returns_and_advantage(self,
                                       last_value: th.Tensor,
-                                      dones: np.ndarray,
-                                      use_gae: bool = True) -> None:
+                                      dones: np.ndarray) -> None:
         """
         Post-processing step: compute the returns (sum of discounted rewards)
-        and advantage (A(s) = R - V(S)).
+        and GAE advantage.
         Adapted from Stable-Baselines PPO2.
+
+        Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
+        to compute the advantage. To obtain vanilla advantage (A(s) = R - V(S))
+        where R is the discounted reward with value bootstrap,
+        set ``gae_lambda=1.0`` during initialization.
+
         :param last_value: (th.Tensor)
         :param dones: (np.ndarray)
-        :param use_gae: (bool) Whether to use Generalized Advantage Estimation
-            or normal advantage for advantage computation.
         """
         # convert to numpy
         last_value = last_value.clone().cpu().numpy().flatten()
 
-        if use_gae:
-            last_gae_lam = 0
-            for step in reversed(range(self.buffer_size)):
-                if step == self.buffer_size - 1:
-                    next_non_terminal = 1.0 - dones
-                    next_value = last_value
-                else:
-                    next_non_terminal = 1.0 - self.dones[step + 1]
-                    next_value = self.values[step + 1]
-                delta = self.rewards[step] + self.gamma * next_value * next_non_terminal - self.values[step]
-                last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
-                self.advantages[step] = last_gae_lam
-            self.returns = self.advantages + self.values
-        else:
-            # Discounted return with value bootstrap
-            # Note: this is equivalent to GAE computation
-            # with gae_lambda = 1.0
-            last_return = 0.0
-            for step in reversed(range(self.buffer_size)):
-                if step == self.buffer_size - 1:
-                    next_non_terminal = 1.0 - dones
-                    next_value = last_value
-                    last_return = self.rewards[step] + next_non_terminal * next_value
-                else:
-                    next_non_terminal = 1.0 - self.dones[step + 1]
-                    last_return = self.rewards[step] + self.gamma * last_return * next_non_terminal
-                self.returns[step] = last_return
-            self.advantages = self.returns - self.values
+        last_gae_lam = 0
+        for step in reversed(range(self.buffer_size)):
+            if step == self.buffer_size - 1:
+                next_non_terminal = 1.0 - dones
+                next_value = last_value
+            else:
+                next_non_terminal = 1.0 - self.dones[step + 1]
+                next_value = self.values[step + 1]
+            delta = self.rewards[step] + self.gamma * next_value * next_non_terminal - self.values[step]
+            last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
+            self.advantages[step] = last_gae_lam
+        self.returns = self.advantages + self.values
 
     def add(self,
             obs: np.ndarray,
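For readers less familiar with GAE, here is a small self-contained NumPy sketch (not part of the diff; the function and argument names are only illustrative and mirror the buffer attributes above) of the same recursion, showing why the docstring says ``gae_lambda=1.0`` recovers the vanilla advantage A(s) = R - V(s) with bootstrapped discounted returns:

import numpy as np

def gae_advantages(rewards, values, dones, last_value, last_done, gamma=0.99, gae_lambda=0.95):
    # rewards, values, dones: arrays of shape (buffer_size,) for a single environment
    buffer_size = len(rewards)
    advantages = np.zeros(buffer_size)
    last_gae_lam = 0.0
    for step in reversed(range(buffer_size)):
        if step == buffer_size - 1:
            next_non_terminal = 1.0 - last_done
            next_value = last_value
        else:
            next_non_terminal = 1.0 - dones[step + 1]
            next_value = values[step + 1]
        # One-step TD error, then exponentially weighted (lambda) accumulation
        delta = rewards[step] + gamma * next_value * next_non_terminal - values[step]
        last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam
        advantages[step] = last_gae_lam
    # With gae_lambda=1.0, returns reduce to the discounted reward-to-go with value
    # bootstrap, so advantages reduce to A(s) = R - V(s).
    returns = advantages + values
    return advantages, returns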
30 changes: 10 additions & 20 deletions stable_baselines3/common/vec_env/vec_normalize.py
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from stable_baselines3.common.vec_env.base_vec_env import VecEnvWrapper
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
from stable_baselines3.common.running_mean_std import RunningMeanStd


Expand Down Expand Up @@ -160,35 +160,25 @@ def reset(self):
         return self.normalize_obs(obs)
 
     @staticmethod
-    def load(load_path, venv):
+    def load(load_path: str, venv: VecEnv) -> "VecNormalize":
         """
         Loads a saved VecNormalize object.
 
-        :param load_path: the path to load from.
-        :param venv: the VecEnv to wrap.
+        :param load_path: (str) the path to load from.
+        :param venv: (VecEnv) the VecEnv to wrap.
         :return: (VecNormalize)
         """
         with open(load_path, "rb") as file_handler:
             vec_normalize = pickle.load(file_handler)
         vec_normalize.set_venv(venv)
         return vec_normalize
 
-    def save(self, save_path):
+    def save(self, save_path: str) -> None:
+        """
+        Save current VecNormalize object with
+        all running statistics and settings (e.g. clip_obs)
+
+        :param save_path: (str) The path to save to
+        """
         with open(save_path, "wb") as file_handler:
             pickle.dump(self, file_handler)
-
-    def save_running_average(self, path):
-        """
-        :param path: (str) path to log dir
-        """
-        for rms, name in zip([self.obs_rms, self.ret_rms], ['obs_rms', 'ret_rms']):
-            with open(f"{path}/{name}.pkl", 'wb') as file_handler:
-                pickle.dump(rms, file_handler)
-
-    def load_running_average(self, path):
-        """
-        :param path: (str) path to log dir
-        """
-        for name in ['obs_rms', 'ret_rms']:
-            with open(f"{path}/{name}.pkl", 'rb') as file_handler:
-                setattr(self, name, pickle.load(file_handler))
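As a usage illustration (not part of the diff; the environment id and file name are arbitrary), the new single-pickle workflow looks roughly like this:

import gym

from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

venv = VecNormalize(DummyVecEnv([lambda: gym.make("Pendulum-v0")]))
# ... step the wrapped env / train so the running statistics get updated ...

# Save all running statistics and settings (e.g. clip_obs) in one pickle file
venv.save("vec_normalize.pkl")

# Later: wrap a fresh VecEnv and restore the saved statistics onto it
fresh_venv = DummyVecEnv([lambda: gym.make("Pendulum-v0")])
venv = VecNormalize.load("vec_normalize.pkl", fresh_venv)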
21 changes: 17 additions & 4 deletions tests/test_run.py
@@ -14,14 +14,27 @@ def test_td3(action_noise):
     model.learn(total_timesteps=1000, eval_freq=500)
 
 
-@pytest.mark.parametrize("model_class", [A2C, PPO])
 @pytest.mark.parametrize("env_id", ['CartPole-v1', 'Pendulum-v0'])
-def test_onpolicy(model_class, env_id):
-    model = model_class('MlpPolicy', env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
+def test_a2c(env_id):
+    model = A2C('MlpPolicy', env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
     model.learn(total_timesteps=1000, eval_freq=500)
 
 
-@pytest.mark.parametrize("ent_coef", ['auto', 0.01])
+@pytest.mark.parametrize("env_id", ['CartPole-v1', 'Pendulum-v0'])
+@pytest.mark.parametrize("clip_range_vf", [None, 0.2, -0.2])
+def test_ppo(env_id, clip_range_vf):
+    if clip_range_vf is not None and clip_range_vf < 0:
+        # Should throw an error
+        with pytest.raises(AssertionError):
+            model = PPO('MlpPolicy', env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True,
+                        clip_range_vf=clip_range_vf)
+    else:
+        model = PPO('MlpPolicy', env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True,
+                    clip_range_vf=clip_range_vf)
+        model.learn(total_timesteps=1000, eval_freq=500)
+
+
+@pytest.mark.parametrize("ent_coef", ['auto', 0.01, 'auto_0.01'])
 def test_sac(ent_coef):
     model = SAC('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[64, 64]),
                 learning_starts=100, verbose=1, create_eval_env=True, ent_coef=ent_coef,
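For context on the new ``clip_range_vf`` test: value-function clipping in PPO limits how far the new value prediction may move from the rollout-time prediction, so only a positive (or ``None``) range makes sense. Below is a hedged sketch of a common formulation of such a clipped value loss; it illustrates the idea only and is not SB3's exact loss code:

import torch as th

def clipped_value_loss(values, old_values, returns, clip_range_vf=0.2):
    # A negative clip range is meaningless, hence the assertion the test expects
    assert clip_range_vf > 0, "clip_range_vf must be positive"
    # Keep the new prediction within +/- clip_range_vf of the rollout-time prediction
    values_clipped = old_values + th.clamp(values - old_values, -clip_range_vf, clip_range_vf)
    loss_unclipped = (returns - values) ** 2
    loss_clipped = (returns - values_clipped) ** 2
    # Pessimistic (element-wise maximum) of the two squared errors
    return th.max(loss_unclipped, loss_clipped).mean()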
27 changes: 7 additions & 20 deletions tests/test_vec_check_nan.py
@@ -1,6 +1,7 @@
 import gym
 from gym import spaces
 import numpy as np
+import pytest
 
 from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan
 
@@ -40,32 +41,18 @@ def test_check_nan():
 
     env.step([[0]])
 
-    try:
+    with pytest.raises(ValueError):
         env.step([[float('NaN')]])
-    except ValueError:
-        pass
-    else:
-        assert False
 
-    try:
+    with pytest.raises(ValueError):
         env.step([[float('inf')]])
-    except ValueError:
-        pass
-    else:
-        assert False
 
-    try:
+    with pytest.raises(ValueError):
         env.step([[-1]])
-    except ValueError:
-        pass
-    else:
-        assert False
 
-    try:
+    with pytest.raises(ValueError):
         env.step([[1]])
-    except ValueError:
-        pass
-    else:
-        assert False
 
     env.step(np.array([[0, 1], [0, 1]]))
 
     env.reset()
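For orientation, here is a minimal sketch of the pattern these tests exercise. ``EchoEnv`` is a hypothetical stand-in for the NaN-producing helper env defined at the top of the test file, and the wrapper arguments follow VecCheckNan's documented options:

import gym
import numpy as np
import pytest
from gym import spaces

from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan


class EchoEnv(gym.Env):
    """Toy env (illustration only): echoes the action back as the observation."""
    action_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)
    observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)

    def step(self, action):
        return np.array(action, dtype=np.float64), 0.0, False, {}

    def reset(self):
        return np.zeros(1, dtype=np.float64)


env = VecCheckNan(DummyVecEnv([EchoEnv]), raise_exception=True)
env.reset()

# With raise_exception=True, a NaN flowing through step() raises a ValueError
with pytest.raises(ValueError):
    env.step([[float('nan')]])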
17 changes: 13 additions & 4 deletions tests/test_vec_normalize.py
@@ -1,3 +1,5 @@
+import os
+
 import gym
 import pytest
 import numpy as np
@@ -150,17 +152,24 @@ def test_sync_vec_normalize():

     env.reset()
     # Initialize running mean
+    latest_reward = None
     for _ in range(100):
-        env.step([env.action_space.sample()])
+        _, latest_reward, _, _ = env.step([env.action_space.sample()])
+
+    # Check that unnormalized reward is same as original reward
+    original_latest_reward = env.get_original_reward()
+    assert np.allclose(original_latest_reward, env.unnormalize_reward(latest_reward))
 
     obs = env.reset()
+    original_obs = env.get_original_obs()
     dummy_rewards = np.random.rand(10)
-    # Normalization must be different
-    original_obs = env.get_original_obs()
+    # Check that unnormalization works
+    assert np.allclose(original_obs, env.unnormalize_obs(obs))
+    # Normalization must be different (between different environments)
     assert not np.allclose(obs, eval_env.normalize_obs(original_obs))
 
     # Test syncing of parameters
     sync_envs_normalization(env, eval_env)
 
     # Now they must be synced
     assert np.allclose(obs, eval_env.normalize_obs(original_obs))
+    assert np.allclose(env.normalize_reward(dummy_rewards), eval_env.normalize_reward(dummy_rewards))
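For intuition about what ``unnormalize_obs`` and ``unnormalize_reward`` undo, the transforms are roughly the following; this is a sketch with the usual epsilon and clipping terms, not VecNormalize's literal code, and the argument names are illustrative:

import numpy as np

def normalize_obs(obs, obs_mean, obs_var, clip_obs=10.0, epsilon=1e-8):
    # Standardize with the running mean/variance, then clip
    return np.clip((obs - obs_mean) / np.sqrt(obs_var + epsilon), -clip_obs, clip_obs)

def unnormalize_obs(norm_obs, obs_mean, obs_var, epsilon=1e-8):
    # Inverse transform (exact whenever the clip was not active)
    return norm_obs * np.sqrt(obs_var + epsilon) + obs_mean

def normalize_reward(reward, return_var, clip_reward=10.0, epsilon=1e-8):
    # Rewards are scaled by the std of the discounted return, not mean-centered
    return np.clip(reward / np.sqrt(return_var + epsilon), -clip_reward, clip_reward)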
