Save parameters as a dictionary & load_parameters (#344)
* Add `get_parameters` function (returns all loadable/saveable tensorflow Variables)
* Add `load_parameters` function (loads model parameters from a file, a file-like object or a list of ndarrays)
* Update A2C, ACER, ACKTR, DDPG, DQN, PPOs, SAC and TRPO to use `get_parameters` to define
  the parameters necessary for correctly loading/saving models.

* Switch from using lists of parameters to dicts of
  variable name -> ndarray.
  * Includes support for loading from older .pkl files
    with a list of parameters

* Renamed `get_parameters` to `_get_parameter_list`
  * `get_parameters` returns a dictionary of variable name -> ndarray
  * `_get_parameter_list` returns the list of tensorflow Variables
    that should be saved/loaded

* Updated changelog for these changes

* Clarified name of a function parameter

* Updated contributors list

* Fix a few PEP8 errors

* Update docs to reflect variable name

* Fix PEP8/style in test_load_parameters

* Requested small typo/doc changes and removed unused parameter in tests

* Add `exact_match` parameter for `load_parameters` with tests

* Add tests for `load_parameters` from file/file-like objects

* Use format-function and small line-length change

* Add warning about not updating trainer parameters upon `load_parameters`

* Add a simple ES example demonstrating use of load/get parameters

* Use OrderedDict for `get_parameters` rather than normal dict

* Make `_get_parameter_list` a public function

* Update load_parameters example with A2C+ES hybrid and only mutating specific parameters
Miffyli authored and araffin committed Jun 2, 2019
1 parent e78a29d commit 204bc9a
Showing 13 changed files with 376 additions and 46 deletions.
89 changes: 89 additions & 0 deletions docs/guide/examples.rst
@@ -44,6 +44,11 @@ In the following example, we will train, save and load an A2C model on the Lunar
LunarLander requires the python package `box2d`.
You can install it using ``apt install swig`` and then ``pip install box2d box2d-kengz``

.. note::
    The ``load`` function re-creates the model from scratch on each call, which can be slow.
    If you need to e.g. evaluate the same model with multiple different sets of parameters,
    consider using ``load_parameters`` instead.

.. code-block:: python

    import gym
@@ -311,6 +316,90 @@ However, you can also easily define a custom architecture for the policy network
    model.learn(total_timesteps=100000)

Accessing and modifying model parameters
----------------------------------------

You can access a model's parameters via the ``load_parameters`` and ``get_parameters`` functions, which
use dictionaries that map variable names to NumPy arrays.

These functions are useful when you need to e.g. evaluate a large set of models with the same network structure,
visualize different layers of the network or modify parameters manually.

You can access the original Tensorflow Variables with the ``get_parameter_list`` function.
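
For instance, a minimal sketch of inspecting a trained model's parameters (here ``model`` stands
for any trained stable-baselines model):

.. code-block:: python

    # `get_parameters` returns an OrderedDict of variable name -> ndarray
    params = model.get_parameters()
    for name, value in params.items():
        print(name, value.shape)

    # `get_parameter_list` returns the underlying tensorflow Variables,
    # in the same order as the dictionary keys
    variables = model.get_parameter_list()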

The following example demonstrates reading parameters, modifying some of them and loading them into a model
by implementing an `evolution strategy <http://blog.otoro.net/2017/10/29/visual-evolution-strategies/>`_
for solving the ``CartPole-v1`` environment. The initial guess for the parameters is obtained by running
A2C policy gradient updates on the model.

.. code-block:: python

    import gym
    import numpy as np

    from stable_baselines.common.policies import MlpPolicy
    from stable_baselines.common.vec_env import DummyVecEnv
    from stable_baselines import A2C


    def mutate(params):
        """Mutate parameters by adding normal noise to them"""
        return dict((name, param + np.random.normal(size=param.shape))
                    for name, param in params.items())


    def evaluate(env, model):
        """Return mean fitness (sum of episodic rewards) for given model"""
        episode_rewards = []
        for _ in range(10):
            reward_sum = 0
            done = False
            obs = env.reset()
            while not done:
                action, _states = model.predict(obs)
                obs, reward, done, info = env.step(action)
                reward_sum += reward
            episode_rewards.append(reward_sum)
        return np.mean(episode_rewards)


    # Create env
    env = gym.make('CartPole-v1')
    env = DummyVecEnv([lambda: env])
    # Create policy with a small network
    model = A2C(MlpPolicy, env, ent_coef=0.0, learning_rate=0.1,
                policy_kwargs={'net_arch': [8, ]})

    # Use traditional actor-critic policy gradient updates to
    # find good initial parameters
    model.learn(total_timesteps=5000)

    # Get the parameters as the starting point for ES
    mean_params = model.get_parameters()

    # Include only variables with "/pi/" (policy) or "/shared" (shared layers)
    # in their name: only these ones affect the action
    mean_params = dict((key, value) for key, value in mean_params.items()
                       if ("/pi/" in key or "/shared" in key))

    for iteration in range(10):
        # Create population of candidates and evaluate them
        population = []
        for population_i in range(100):
            candidate = mutate(mean_params)
            # Load new policy parameters to agent.
            # Tell function that it should only update parameters
            # we give it (policy parameters)
            model.load_parameters(candidate, exact_match=False)
            fitness = evaluate(env, model)
            population.append((candidate, fitness))
        # Take top 10% and use average over their parameters as next mean parameter
        top_candidates = sorted(population, key=lambda x: x[1], reverse=True)[:10]
        mean_params = dict(
            (name, np.stack([top_candidate[0][name] for top_candidate in top_candidates]).mean(0))
            for name in mean_params.keys()
        )
        mean_fitness = sum(top_candidate[1] for top_candidate in top_candidates) / 10.0
        print("Iteration {:<3} Mean top fitness: {:.2f}".format(iteration, mean_fitness))

Recurrent Policies
------------------

4 changes: 4 additions & 0 deletions docs/misc/changelog.rst
@@ -14,6 +14,9 @@ Release 2.5.2a0 (WIP)
- The parameter ``filter_size`` of the function ``conv`` in A2C utils now supports passing a list/tuple of two integers (height and width), in order to have non-squared kernel matrix. (@yutingsz)
- fixed a bug where initial learning rate is logged instead of its placeholder in ``A2C.setup_model`` (@sc420)
- fixed a bug where number of timesteps is incorrectly updated and logged in ``A2C.learn`` and ``A2C._train_step`` (@sc420)
- added ``load_parameters`` and ``get_parameters`` for most learning algorithms.
  With these methods, users are able to load and get parameters to/from an existing model, without touching tensorflow. (@Miffyli)
- **important change** switched to using dictionaries rather than lists when storing parameters, with tensorflow Variable names being the keys. (@Miffyli)

Release 2.5.1 (2019-05-04)
--------------------------
@@ -303,3 +306,4 @@ In random order...
Thanks to @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk @JohannesAck
@EliasHasle @mrakgr @Bleyddyn @antoine-galataud @junhyeokahn @AdamGleave @keshaviyengar @tperol
@XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs
@Miffyli
4 changes: 2 additions & 2 deletions stable_baselines/a2c/a2c.py
@@ -292,9 +292,9 @@ def save(self, save_path):
"policy_kwargs": self.policy_kwargs
}

params = self.sess.run(self.params)
params_to_save = self.get_parameters()

self._save_to_file(save_path, data=data, params=params)
self._save_to_file(save_path, data=data, params=params_to_save)


class A2CRunner(AbstractEnvRunner):
4 changes: 2 additions & 2 deletions stable_baselines/acer/acer_simple.py
@@ -566,9 +566,9 @@ def save(self, save_path):
"policy_kwargs": self.policy_kwargs
}

params = self.sess.run(self.params)
params_to_save = self.get_parameters()

self._save_to_file(save_path, data=data, params=params)
self._save_to_file(save_path, data=data, params=params_to_save)


class _Runner(AbstractEnvRunner):
4 changes: 2 additions & 2 deletions stable_baselines/acktr/acktr_disc.py
@@ -364,6 +364,6 @@ def save(self, save_path):
"policy_kwargs": self.policy_kwargs
}

params = self.sess.run(self.params)
params_to_save = self.get_parameters()

self._save_to_file(save_path, data=data, params=params)
self._save_to_file(save_path, data=data, params=params_to_save)
117 changes: 113 additions & 4 deletions stable_baselines/common/base_class.py
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from collections import OrderedDict
import os
import glob
import warnings
@@ -43,6 +44,7 @@ def __init__(self, policy, env, verbose=0, *, requires_vec_env, policy_base, pol
        self.graph = None
        self.sess = None
        self.params = None
        self._param_load_ops = None

        if env is not None:
            if isinstance(env, str):
@@ -153,6 +155,49 @@ def _setup_learn(self, seed):
        if seed is not None:
            set_global_seeds(seed)

    @abstractmethod
    def get_parameter_list(self):
        """
        Get tensorflow Variables of model's parameters

        This includes all variables necessary for continuing training (saving / loading).

        :return: (list) List of tensorflow Variables
        """
        pass

    def get_parameters(self):
        """
        Get current model parameters as dictionary of variable name -> ndarray.

        :return: (OrderedDict) Dictionary of variable name -> ndarray of model's parameters.
        """
        parameters = self.get_parameter_list()
        parameter_values = self.sess.run(parameters)
        return_dictionary = OrderedDict((param.name, value) for param, value in zip(parameters, parameter_values))
        return return_dictionary

    def _setup_load_operations(self):
        """
        Create tensorflow operations for loading model parameters
        """
        # Assume tensorflow graphs are static -> check
        # that we only call this function once
        if self._param_load_ops is not None:
            raise RuntimeError("Parameter load operations have already been created")
        # For each loadable parameter, create appropriate
        # placeholder and an assign op, and store them to
        # self._param_load_ops as dict of variable.name -> (placeholder, assign)
        loadable_parameters = self.get_parameter_list()
        # Use OrderedDict to store order for backwards compatibility with
        # list-based params
        self._param_load_ops = OrderedDict()
        with self.graph.as_default():
            for param in loadable_parameters:
                placeholder = tf.placeholder(dtype=param.dtype, shape=param.shape)
                # param.name is unique (tensorflow variables have unique names)
                self._param_load_ops[param.name] = (placeholder, param.assign(placeholder))

    @abstractmethod
    def _get_pretrain_placeholders(self):
        """
@@ -312,6 +357,70 @@ def action_probability(self, observation, state=None, mask=None, actions=None):
        """
        pass

    def load_parameters(self, load_path_or_dict, exact_match=True):
        """
        Load model parameters from a file or a dictionary

        Dictionary keys should be tensorflow variable names, which can be obtained
        with the ``get_parameters`` function. If ``exact_match`` is True, the dictionary
        should contain keys for all the model's parameters, otherwise a RuntimeError
        is raised. If False, only variables included in the dictionary will be updated.

        This does not load agent's hyper-parameters.

        .. warning::
            This function does not update trainer/optimizer variables (e.g. momentum).
            As such, training after using this function may lead to less-than-optimal results.

        :param load_path_or_dict: (str or file-like or dict) Save parameter location
            or dict of parameters as variable.name -> ndarrays to be loaded.
        :param exact_match: (bool) If True, expects load dictionary to contain keys for
            all variables in the model. If False, loads parameters only for variables
            mentioned in the dictionary. Defaults to True.
        """
        # Make sure we have assign ops
        if self._param_load_ops is None:
            self._setup_load_operations()

        params = None
        if isinstance(load_path_or_dict, dict):
            # Assume `load_path_or_dict` is dict of variable.name -> ndarrays we want to load
            params = load_path_or_dict
        elif isinstance(load_path_or_dict, list):
            warnings.warn("Loading model parameters from a list. This has been replaced " +
                          "with parameter dictionaries with variable names and parameters. " +
                          "If you are loading from a file, consider re-saving the file.",
                          DeprecationWarning)
            # Assume `load_path_or_dict` is list of ndarrays.
            # Create param dictionary assuming the parameters are in same order
            # as `get_parameter_list` returns them.
            params = dict()
            for i, param_name in enumerate(self._param_load_ops.keys()):
                params[param_name] = load_path_or_dict[i]
        else:
            # Assume a filepath or file-like.
            # Use existing deserializer to load the parameters
            _, params = BaseRLModel._load_from_file(load_path_or_dict)

        feed_dict = {}
        param_update_ops = []
        # Keep track of not-updated variables
        not_updated_variables = set(self._param_load_ops.keys())
        for param_name, param_value in params.items():
            placeholder, assign_op = self._param_load_ops[param_name]
            feed_dict[placeholder] = param_value
            # Create list of tf.assign operations for sess.run
            param_update_ops.append(assign_op)
            # Keep track which variables are updated
            not_updated_variables.remove(param_name)

        # Check that we updated all parameters if exact_match=True
        if exact_match and len(not_updated_variables) > 0:
            raise RuntimeError("Load dictionary did not contain all variables. " +
                               "Missing variables: {}".format(", ".join(not_updated_variables)))

        self.sess.run(param_update_ops, feed_dict=feed_dict)

    @abstractmethod
    def save(self, save_path):
        """
@@ -541,6 +650,9 @@ def action_probability(self, observation, state=None, mask=None, actions=None):

        return actions_proba

    def get_parameter_list(self):
        return self.params

    @abstractmethod
    def save(self, save_path):
        pass
@@ -560,10 +672,7 @@ def load(cls, load_path, env=None, **kwargs):
        model.set_env(env)
        model.setup_model()

        restores = []
        for param, loaded_p in zip(model.params, params):
            restores.append(param.assign(loaded_p))
        model.sess.run(restores)
        model.load_parameters(params)

        return model

25 changes: 8 additions & 17 deletions stable_baselines/ddpg/ddpg.py
@@ -988,6 +988,12 @@ def action_probability(self, observation, state=None, mask=None, actions=None):
warnings.warn("Warning: action probability is meaningless for DDPG. Returning None")
return None

def get_parameter_list(self):
return (self.params +
self.target_params +
self.obs_rms_params +
self.ret_rms_params)

    def save(self, save_path):
        data = {
            "observation_space": self.observation_space,
@@ -1020,20 +1026,12 @@ def save(self, save_path):
"policy_kwargs": self.policy_kwargs
}

params = self.sess.run(self.params)
target_params = self.sess.run(self.target_params)
norm_obs_params = self.sess.run(self.obs_rms_params)
norm_ret_params = self.sess.run(self.ret_rms_params)
params_to_save = self.get_parameters()

params_to_save = params \
+ target_params \
+ norm_obs_params \
+ norm_ret_params
self._save_to_file(save_path,
data=data,
params=params_to_save)


@classmethod
def load(cls, load_path, env=None, **kwargs):
data, params = cls._load_from_file(load_path)
@@ -1049,13 +1047,6 @@ def load(cls, load_path, env=None, **kwargs):
        model.set_env(env)
        model.setup_model()

        restores = []
        params_to_load = model.params \
            + model.target_params \
            + model.obs_rms_params \
            + model.ret_rms_params
        for param, loaded_p in zip(params_to_load, params):
            restores.append(param.assign(loaded_p))
        model.sess.run(restores)
        model.load_parameters(params)

        return model
12 changes: 6 additions & 6 deletions stable_baselines/deepq/dqn.py
@@ -310,6 +310,9 @@ def action_probability(self, observation, state=None, mask=None, actions=None):

        return actions_proba

    def get_parameter_list(self):
        return self.params

    def save(self, save_path):
        # params
        data = {
@@ -338,9 +341,9 @@ def save(self, save_path):
"policy_kwargs": self.policy_kwargs
}

params = self.sess.run(self.params)
params_to_save = self.get_parameters()

self._save_to_file(save_path, data=data, params=params)
self._save_to_file(save_path, data=data, params=params_to_save)

@classmethod
def load(cls, load_path, env=None, **kwargs):
@@ -357,9 +360,6 @@ def load(cls, load_path, env=None, **kwargs):
        model.set_env(env)
        model.setup_model()

        restores = []
        for param, loaded_p in zip(model.params, params):
            restores.append(param.assign(loaded_p))
        model.sess.run(restores)
        model.load_parameters(params)

        return model
4 changes: 2 additions & 2 deletions stable_baselines/ppo1/pposgd_simple.py
@@ -354,6 +354,6 @@ def save(self, save_path):
"policy_kwargs": self.policy_kwargs
}

params = self.sess.run(self.params)
params_to_save = self.get_parameters()

self._save_to_file(save_path, data=data, params=params)
self._save_to_file(save_path, data=data, params=params_to_save)
