Skip to content

Commit

Permalink
Make VecNormalize pickleable (#525)
Browse files Browse the repository at this point in the history
* Make VecNormalize pickleable

* Docstrings and load/save methods

* Test serializing VecNormalize

* Bugfix in tests

* Fix lint errors

* VecNormalize: make venv mandatory

* Update example in documentation with new VecNormalize save routine
  • Loading branch information
AdamGleave authored and araffin committed Nov 2, 2019
1 parent 2de0bb6 commit a71f9db
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 14 deletions.
4 changes: 2 additions & 2 deletions docs/guide/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,10 @@ will compute a running average and standard deviation of input features (it can
model = PPO2(MlpPolicy, env)
model.learn(total_timesteps=2000)
# Don't forget to save the running average when saving the agent
# Don't forget to save the VecNormalize statistics when saving the agent
log_dir = "/tmp/"
model.save(log_dir + "ppo_reacher")
env.save_running_average(log_dir)
env.save(os.path.join(log_dir, "vec_normalize.pkl"))
Custom Policy Network
Expand Down
2 changes: 2 additions & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Add `n_cpu_tf_sess` to model constructor to choose the number of threads used by Tensorflow
- `VecNormalize` now supports being pickled and unpickled.

Bug Fixes:
^^^^^^^^^^
Expand All @@ -28,6 +29,7 @@ Deprecations:
^^^^^^^^^^^^^
- `nprocs` (ACKTR) and `num_procs` (ACER) are deprecated in favor of `n_cpu_tf_sess` which is now common
to all algorithms
- `VecNormalize`: `load_running_average` and `save_running_average` are deprecated in favor of using pickle.

Others:
^^^^^^^
Expand Down
73 changes: 72 additions & 1 deletion stable_baselines/common/vec_env/vec_normalize.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pickle
import warnings

import numpy as np

Expand All @@ -9,7 +10,10 @@
class VecNormalize(VecEnvWrapper):
"""
A moving average, normalizing wrapper for vectorized environment.
has support for saving/loading moving average,
It is pickleable; pickling saves the moving averages and configuration parameters.
The wrapped environment `venv` is not saved and must be restored manually with
`set_venv` after unpickling.
:param venv: (VecEnv) the vectorized environment to wrap
:param training: (bool) Whether to update or not the moving average
Expand Down Expand Up @@ -37,6 +41,45 @@ def __init__(self, venv, training=True, norm_obs=True, norm_reward=True,
self.norm_reward = norm_reward
self.old_obs = np.array([])

def __getstate__(self):
"""
Gets state for pickling.
Excludes self.venv, as in general VecEnv's may not be pickleable."""
state = self.__dict__.copy()
# these attributes are not pickleable
del state['venv']
del state['class_attributes']
# these attributes depend on the above and so we would prefer not to pickle
del state['ret']
return state

def __setstate__(self, state):
"""
Restores pickled state.
User must call set_venv() after unpickling before using.
:param state: (dict)"""
self.__dict__.update(state)
assert 'venv' not in state
self.venv = None

def set_venv(self, venv):
    """
    Sets the vector environment to wrap to venv.

    Also sets attributes derived from this such as `num_envs`.

    :param venv: (VecEnv) the vectorized environment to wrap
    :raises ValueError: if a venv is already set, or if the observation
        space shape does not match the shape of the stored observation
        statistics
    """
    if self.venv is not None:
        raise ValueError("Trying to set venv of already initialized VecNormalize wrapper.")
    VecEnvWrapper.__init__(self, venv)
    if self.obs_rms.mean.shape != self.observation_space.shape:
        # Detach the rejected venv again; otherwise the guard above would
        # refuse every later call and the wrapper could never be given a
        # compatible environment.
        # NOTE(review): assumes VecEnvWrapper.__init__ assigns self.venv - confirm.
        self.venv = None
        raise ValueError("venv is incompatible with current statistics.")
    # One running return accumulator per wrapped environment.
    self.ret = np.zeros(self.num_envs)

def step_wait(self):
"""
Apply sequence of actions to sequence of environments
Expand Down Expand Up @@ -88,18 +131,46 @@ def reset(self):
self.ret = np.zeros(self.num_envs)
return self._normalize_observation(obs)

@staticmethod
def load(load_path, venv):
"""
Loads a saved VecNormalize object.
:param load_path: the path to load from.
:param venv: the VecEnv to wrap.
:return: (VecNormalize)
"""
with open(load_path, "rb") as file_handler:
vec_normalize = pickle.load(file_handler)
vec_normalize.set_venv(venv)
return vec_normalize

def save(self, save_path):
    """
    Pickles this object and writes it to disk.

    :param save_path: the path of the file to write; it is opened in
        binary write mode.
    """
    with open(save_path, "wb") as stream:
        pickle.Pickler(stream).dump(self)

def save_running_average(self, path):
    """
    Pickles `obs_rms` and `ret_rms` into `<path>/obs_rms.pkl` and
    `<path>/ret_rms.pkl`.

    :param path: (str) path to log dir

    .. deprecated:: 2.9.0
        This function will be removed in a future version
    """
    # stacklevel=2 attributes the warning to the caller's line, so it is
    # reported against user code instead of this module.
    warnings.warn("Usage of `save_running_average` is deprecated. Please "
                  "use `save` or pickle instead.", DeprecationWarning,
                  stacklevel=2)
    for rms, name in zip([self.obs_rms, self.ret_rms], ['obs_rms', 'ret_rms']):
        with open("{}/{}.pkl".format(path, name), 'wb') as file_handler:
            pickle.dump(rms, file_handler)

def load_running_average(self, path):
    """
    Replaces `obs_rms` and `ret_rms` with the objects unpickled from
    `<path>/obs_rms.pkl` and `<path>/ret_rms.pkl`.

    NOTE: the files are read with `pickle.load`, which can execute
    arbitrary code; only load files that come from a trusted source.

    :param path: (str) path to log dir

    .. deprecated:: 2.9.0
        This function will be removed in a future version
    """
    # stacklevel=2 attributes the warning to the caller's line, so it is
    # reported against user code instead of this module.
    warnings.warn("Usage of `load_running_average` is deprecated. Please "
                  "use `load` or pickle instead.", DeprecationWarning,
                  stacklevel=2)
    for name in ['obs_rms', 'ret_rms']:
        with open("{}/{}.pkl".format(path, name), 'rb') as file_handler:
            setattr(self, name, pickle.load(file_handler))
54 changes: 43 additions & 11 deletions tests/test_vec_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
ENV_ID = 'Pendulum-v0'


def make_env():
    """Create one instance of the environment registered under ENV_ID."""
    env = gym.make(ENV_ID)
    return env


def test_runningmeanstd():
"""Test RunningMeanStd object"""
for (x_1, x_2, x_3) in [
Expand All @@ -28,20 +32,48 @@ def test_runningmeanstd():
assert np.allclose(moments_1, moments_2)


def test_vec_env():
"""Test VecNormalize Object"""
def check_rms_equal(rmsa, rmsb):
    """Assert that both objects hold equal `mean`, `var` and `count` values."""
    for field in ("mean", "var", "count"):
        assert np.all(getattr(rmsa, field) == getattr(rmsb, field))


def check_vec_norm_equal(norma, normb):
assert norma.observation_space == normb.observation_space
assert norma.action_space == normb.action_space
assert norma.num_envs == normb.num_envs

def make_env():
return gym.make(ENV_ID)
check_rms_equal(norma.obs_rms, normb.obs_rms)
check_rms_equal(norma.ret_rms, normb.ret_rms)
assert norma.clip_obs == normb.clip_obs
assert norma.clip_reward == normb.clip_reward
assert norma.norm_obs == normb.norm_obs
assert norma.norm_reward == normb.norm_reward

env = DummyVecEnv([make_env])
env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10., clip_reward=10.)
_, done = env.reset(), [False]
obs = None
assert np.all(norma.ret == normb.ret)
assert norma.gamma == normb.gamma
assert norma.epsilon == normb.epsilon
assert norma.training == normb.training


def test_vec_env(tmpdir):
"""Test VecNormalize Object"""
clip_obs = 0.5
clip_reward = 5.0

orig_venv = DummyVecEnv([make_env])
norm_venv = VecNormalize(orig_venv, norm_obs=True, norm_reward=True, clip_obs=clip_obs, clip_reward=clip_reward)
_, done = norm_venv.reset(), [False]
while not done[0]:
actions = [env.action_space.sample()]
obs, _, done, _ = env.step(actions)
assert np.max(obs) <= 10
actions = [norm_venv.action_space.sample()]
obs, rew, done, _ = norm_venv.step(actions)
assert np.max(np.abs(obs)) <= clip_obs
assert np.max(np.abs(rew)) <= clip_reward

path = str(tmpdir.join("vec_normalize"))
norm_venv.save(path)
deserialized = VecNormalize.load(path, venv=orig_venv)
check_vec_norm_equal(norm_venv, deserialized)


def test_mpi_runningmeanstd():
Expand Down

0 comments on commit a71f9db

Please sign in to comment.