Update VecNormalize normalization (#609)
* VecNormalize: Add public normalize_{obs,rew} methods

* Update changelog

* VecNormalize: get_original_{obs,rews}

* VecNormalize: Update rewards in reset()

Note that after the _update_rews() refactor, self.ret doesn't
update anymore if `not self.training`.

* update changelog

* renames

* changelog: fix indent

* changelog: nested list needs blank lines

* Add tests

* Address review, fix tests

* update tests

* More annotations

* Update stable_baselines/common/vec_env/vec_normalize.py

Co-Authored-By: Adam Gleave <adam@gleave.me>

* Address review comments

* Defensive copy
shwang authored and AdamGleave committed Dec 18, 2019
1 parent ba51e25 commit 99dcdba
Showing 3 changed files with 101 additions and 25 deletions.
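
Before the per-file diffs, here is a minimal usage sketch of the public methods this commit introduces (`normalize_obs`, `normalize_reward`, `get_original_obs`, `get_original_reward`). The CartPole warm-up is illustrative only and mirrors the added tests; it is not part of the diff.

import gym
import numpy as np

from stable_baselines.common.vec_env import DummyVecEnv, VecNormalize

# Wrap a single CartPole environment, as the new tests do.
venv = VecNormalize(DummyVecEnv([lambda: gym.make("CartPole-v1")]))
obs = venv.reset()

# Warm-start the running statistics with random actions.
for _ in range(100):
    obs, rewards, _, _ = venv.step([venv.action_space.sample()])

# Raw values from the most recent step, before normalization was applied.
raw_obs = venv.get_original_obs()
raw_rewards = venv.get_original_reward()

# Re-apply normalization to arbitrary arrays; statistics are not updated.
np.testing.assert_allclose(venv.normalize_obs(raw_obs), obs)
np.testing.assert_allclose(venv.normalize_reward(raw_rewards), rewards)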
9 changes: 8 additions & 1 deletion docs/misc/changelog.rst
@@ -24,7 +24,14 @@ New Features:
- Environments are automatically wrapped in a `DummyVecEnv` if needed when passing them to the model constructor
- Added `stable_baselines.common.make_vec_env` helper to simplify VecEnv creation
- Added `stable_baselines.common.evaluation.evaluate_policy` helper to simplify model evaluation
- `VecNormalize` now supports being pickled and unpickled.
- `VecNormalize` changes:

- Now supports being pickled and unpickled (@AdamGleave).
- New methods `.normalize_obs(obs)` and `.normalize_reward(rews)` apply normalization
to arbitrary observations or rewards without updating statistics (@shwang)
- `.get_original_reward()` returns the unnormalized rewards from the most recent timestep
- `.reset()` now updates the reward/return statistics (previously it only applied normalization)

- Add parameter `exploration_initial_eps` to DQN. (@jdossgollin)
- Add type checking and PEP 561 compliance.
Note: most functions are still not annotated, this will be a gradual process.
70 changes: 46 additions & 24 deletions stable_baselines/common/vec_env/vec_normalize.py
@@ -39,7 +39,8 @@ def __init__(self, venv, training=True, norm_obs=True, norm_reward=True,
self.training = training
self.norm_obs = norm_obs
self.norm_reward = norm_reward
self.old_obs = np.array([])
self.old_obs = None
self.old_rews = None

def __getstate__(self):
"""
@@ -88,48 +89,69 @@ def step_wait(self):
where 'news' is a boolean vector indicating whether each element is new.
"""
obs, rews, news, infos = self.venv.step_wait()
self.ret = self.ret * self.gamma + rews
self.old_obs = obs
obs = self._normalize_observation(obs)
if self.norm_reward:
if self.training:
self.ret_rms.update(self.ret)
rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward)
self.old_rews = rews

if self.training:
self.obs_rms.update(obs)
obs = self.normalize_obs(obs)

if self.training:
self._update_reward(rews)
rews = self.normalize_reward(rews)

self.ret[news] = 0
return obs, rews, news, infos

def _normalize_observation(self, obs):
def _update_reward(self, reward: np.ndarray) -> None:
"""Update reward normalization statistics."""
self.ret = self.ret * self.gamma + reward
self.ret_rms.update(self.ret)

def normalize_obs(self, obs: np.ndarray) -> np.ndarray:
"""
:param obs: (numpy tensor)
Normalize observations using this VecNormalize's observations statistics.
Calling this method does not update statistics.
"""
if self.norm_obs:
if self.training:
self.obs_rms.update(obs)
obs = np.clip((obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon), -self.clip_obs,
obs = np.clip((obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon),
-self.clip_obs,
self.clip_obs)
return obs
else:
return obs
return obs

def normalize_reward(self, reward: np.ndarray) -> np.ndarray:
"""
Normalize rewards using this VecNormalize's rewards statistics.
Calling this method does not update statistics.
"""
if self.norm_reward:
reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon),
-self.clip_reward, self.clip_reward)
return reward

def get_original_obs(self):
def get_original_obs(self) -> np.ndarray:
"""
returns the unnormalized observation
Returns an unnormalized version of the observations from the most recent
step or reset.
"""
return self.old_obs.copy()

:return: (numpy float)
def get_original_reward(self) -> np.ndarray:
"""
Returns an unnormalized version of the rewards from the most recent step.
"""
return self.old_obs
return self.old_rews.copy()

def reset(self):
"""
Reset all environments
"""
obs = self.venv.reset()
if len(np.array(obs).shape) == 1: # for when num_cpu is 1
self.old_obs = [obs]
else:
self.old_obs = obs
self.old_obs = obs
self.ret = np.zeros(self.num_envs)
return self._normalize_observation(obs)
if self.training:
self._update_reward(self.ret)
return self.normalize_obs(obs)

@staticmethod
def load(load_path, venv):
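
For reference, the scheme implemented above standardizes observations with running mean and variance, scales rewards by the standard deviation of a running estimate of the discounted return, and clips both. A standalone NumPy sketch of that arithmetic (illustrative only; the constants match the wrapper's usual defaults but are assumptions here, and the hand-rolled statistics stand in for `RunningMeanStd`):

import numpy as np

# Assumed constants (VecNormalize's usual defaults; not part of this diff).
gamma, epsilon = 0.99, 1e-8
clip_obs, clip_reward = 10.0, 10.0

# Pretend these running statistics were accumulated by obs_rms / ret_rms.
obs_mean, obs_var = np.zeros(4), np.ones(4)
ret_var = 4.0

def normalize_obs(obs):
    # Standardize with running mean/variance, then clip (mirrors normalize_obs).
    return np.clip((obs - obs_mean) / np.sqrt(obs_var + epsilon), -clip_obs, clip_obs)

def normalize_reward(rew):
    # Scale by the std of the discounted return, then clip (mirrors normalize_reward).
    return np.clip(rew / np.sqrt(ret_var + epsilon), -clip_reward, clip_reward)

# Each training step also refreshes the return estimate (mirrors _update_reward):
#     ret <- ret * gamma + rew, then ret_rms.update(ret)
print(normalize_obs(np.array([0.1, -0.2, 0.3, -0.4])))
print(normalize_reward(np.array([1.0])))  # ~0.5 with ret_var = 4.0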
47 changes: 47 additions & 0 deletions tests/test_vec_normalize.py
@@ -76,6 +76,53 @@ def test_vec_env(tmpdir):
check_vec_norm_equal(norm_venv, deserialized)


def _make_warmstart_cartpole():
"""Warm-start VecNormalize by stepping through CartPole"""
venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])
venv = VecNormalize(venv)
venv.reset()
venv.get_original_obs()

for _ in range(100):
actions = [venv.action_space.sample()]
venv.step(actions)
return venv


def test_get_original():
venv = _make_warmstart_cartpole()
for _ in range(3):
actions = [venv.action_space.sample()]
obs, rewards, _, _ = venv.step(actions)
obs = obs[0]
orig_obs = venv.get_original_obs()[0]
rewards = rewards[0]
orig_rewards = venv.get_original_reward()[0]

assert np.all(orig_rewards == 1)
assert orig_obs.shape == obs.shape
assert orig_rewards.dtype == rewards.dtype
assert not np.array_equal(orig_obs, obs)
assert not np.array_equal(orig_rewards, rewards)
np.testing.assert_allclose(venv.normalize_obs(orig_obs), obs)
np.testing.assert_allclose(venv.normalize_reward(orig_rewards), rewards)


def test_normalize_external():
venv = _make_warmstart_cartpole()

rewards = np.array([1, 1])
norm_rewards = venv.normalize_reward(rewards)
assert norm_rewards.shape == rewards.shape
# Episode return is almost always >= 1 in CartPole. So reward should shrink.
assert np.all(norm_rewards < 1)

# Don't have any guarantees on obs normalization, except shape, really.
obs = np.array([0, 0, 0, 0])
norm_obs = venv.normalize_obs(obs)
assert obs.shape == norm_obs.shape


def test_mpi_runningmeanstd():
"""Test RunningMeanStd object for MPI"""
return_code = subprocess.call(['mpirun', '--allow-run-as-root', '-np', '2',
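
A motivating use of these methods is evaluation with frozen statistics: because the new `normalize_obs`/`normalize_reward` do not update the running stats, a saved wrapper can normalize observations at test time without drift. A possible pattern, sketched under assumptions (the pickle path is hypothetical; `VecNormalize.load` is the staticmethod truncated in the diff above):

import gym
from stable_baselines.common.vec_env import DummyVecEnv, VecNormalize

venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])
# "vec_normalize.pkl" is a hypothetical path to statistics pickled during training.
venv = VecNormalize.load("vec_normalize.pkl", venv)
venv.training = False      # freeze running statistics
venv.norm_reward = False   # report raw rewards during evaluation

obs = venv.reset()
for _ in range(10):
    obs, rewards, _, _ = venv.step([venv.action_space.sample()])
    print(rewards, venv.get_original_reward())  # identical while norm_reward is False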
