
Commit

Fix find trainable vars (#364)
* Remove buggy `find_trainable_variables` and replace it with `tf_util.get_trainable_vars`

* Patch loading of old DDPG models

* Fix indentation
araffin authored and hill-a committed Jun 11, 2019
1 parent 65ed396 commit 72dab6a
Showing 9 changed files with 37 additions and 29 deletions.
4 changes: 4 additions & 0 deletions docs/misc/changelog.rst
@@ -31,6 +31,10 @@ Pre-Release 2.6.0a0 (WIP)
- fixed ``num_timesteps`` (total_timesteps) variable in PPO2 that was wrongly computed.
- fixed a bug in DDPG/DQN/SAC that occurred when the number of samples in the replay buffer was lower than the batch size
  (thanks to @dwiel for spotting the bug)
+- **removed** ``a2c.utils.find_trainable_variables``; please use ``common.tf_util.get_trainable_vars`` instead.
+  ``find_trainable_variables`` returned all trainable variables, discarding the scope argument.
+  This bug caused the model to save duplicated parameters (for DDPG and SAC)
+  but did not affect performance.

**Breaking Change:** DDPG replay buffer was unified with DQN/SAC replay buffer. As a result,
when loading a DDPG model trained with stable_baselines<2.6.0, it throws an import error.
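The scope bug described in the changelog entry above is easy to reproduce in isolation. The following TensorFlow 1.x sketch is illustrative only (it is not part of the commit, and the scope and variable names are made up); it shows that entering a variable scope does not filter tf.trainable_variables(), while a scope-filtered collection lookup does:

    import tensorflow as tf  # TF 1.x API, as used by stable-baselines at the time

    with tf.variable_scope("model"):
        tf.get_variable("w", shape=(2, 2))
    with tf.variable_scope("target"):
        tf.get_variable("w", shape=(2, 2))

    # Buggy pattern: the enclosing scope is ignored, so *all* trainable
    # variables in the graph are returned.
    with tf.variable_scope("model"):
        print(len(tf.trainable_variables()))  # 2

    # Scope-aware lookup: only the variables under "model" are returned.
    print(len(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="model")))  # 1

Because every call returned the full variable list, self.params and self.target_params overlapped, which is how DDPG and SAC ended up saving duplicated parameters.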
4 changes: 2 additions & 2 deletions stable_baselines/a2c/a2c.py
@@ -9,7 +9,7 @@
from stable_baselines.common import explained_variance, tf_util, ActorCriticRLModel, SetVerbosity, TensorboardWriter
from stable_baselines.common.policies import ActorCriticPolicy, RecurrentActorCriticPolicy
from stable_baselines.common.runners import AbstractEnvRunner
-from stable_baselines.a2c.utils import discount_with_dones, Scheduler, find_trainable_variables, mse, \
+from stable_baselines.a2c.utils import discount_with_dones, Scheduler, mse, \
total_episode_reward_logger
from stable_baselines.ppo2.ppo2 import safe_mean

@@ -137,7 +137,7 @@ def setup_model(self):
tf.summary.scalar('value_function_loss', self.vf_loss)
tf.summary.scalar('loss', loss)

-self.params = find_trainable_variables("model")
+self.params = tf_util.get_trainable_vars("model")
grads = tf.gradients(loss, self.params)
if self.max_grad_norm is not None:
grads, _ = tf.clip_by_global_norm(grads, self.max_grad_norm)
12 changes: 0 additions & 12 deletions stable_baselines/a2c/utils.py
@@ -308,18 +308,6 @@ def discount_with_dones(rewards, dones, gamma):
discounted.append(ret)
return discounted[::-1]


-
-def find_trainable_variables(key):
-    """
-    Returns the trainable variables within a given scope
-
-    :param key: (str) The variable scope
-    :return: ([TensorFlow Tensor]) the trainable variables
-    """
-    with tf.variable_scope(key):
-        return tf.trainable_variables()
-
-
def make_path(path):
"""
For a given path, create the folders if they do not exist
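For reference, the replacement recommended above, common.tf_util.get_trainable_vars, boils down to a scope-filtered collection lookup. The sketch below is consistent with its documented behaviour, though the exact implementation in stable-baselines may differ in details:

    import tensorflow as tf

    def get_trainable_vars(name):
        """Return the trainable variables whose scope starts with the given name.

        :param name: (str) the variable scope
        :return: ([tf.Variable]) the trainable variables under that scope
        """
        return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=name)

Here the scope argument is actually applied, so a call such as get_trainable_vars("model") no longer sweeps up "target" or noise variables.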
4 changes: 2 additions & 2 deletions stable_baselines/acer/acer_simple.py
@@ -5,7 +5,7 @@
from gym.spaces import Discrete, Box

from stable_baselines import logger
-from stable_baselines.a2c.utils import batch_to_seq, seq_to_batch, Scheduler, find_trainable_variables, EpisodeStats, \
+from stable_baselines.a2c.utils import batch_to_seq, seq_to_batch, Scheduler, EpisodeStats, \
get_by_index, check_shape, avg_norm, gradient_add, q_explained_variance, total_episode_reward_logger
from stable_baselines.acer.buffer import Buffer
from stable_baselines.common import ActorCriticRLModel, tf_util, SetVerbosity, TensorboardWriter
@@ -194,7 +194,7 @@ def setup_model(self):
step_model = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs, 1,
n_batch_step, reuse=False, **self.policy_kwargs)

-self.params = find_trainable_variables("model")
+self.params = tf_util.get_trainable_vars("model")

with tf.variable_scope("train_model", reuse=True,
custom_getter=tf_util.outer_scope_getter("train_model")):
4 changes: 2 additions & 2 deletions stable_baselines/acktr/acktr_disc.py
@@ -12,7 +12,7 @@
from stable_baselines import logger
from stable_baselines.common import explained_variance, ActorCriticRLModel, tf_util, SetVerbosity, TensorboardWriter
from stable_baselines.a2c.a2c import A2CRunner
-from stable_baselines.a2c.utils import Scheduler, find_trainable_variables, calc_entropy, mse, \
+from stable_baselines.a2c.utils import Scheduler, calc_entropy, mse, \
total_episode_reward_logger
from stable_baselines.acktr import kfac
from stable_baselines.common.policies import ActorCriticPolicy, RecurrentActorCriticPolicy
@@ -130,7 +130,7 @@ def setup_model(self):
self.model = step_model = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs,
1, n_batch_step, reuse=False, **self.policy_kwargs)

-self.params = params = find_trainable_variables("model")
+self.params = params = tf_util.get_trainable_vars("model")

with tf.variable_scope("train_model", reuse=True,
custom_getter=tf_util.outer_scope_getter("train_model")):
24 changes: 20 additions & 4 deletions stable_baselines/ddpg/ddpg.py
@@ -17,7 +17,7 @@
from stable_baselines.common.mpi_adam import MpiAdam
from stable_baselines.ddpg.policies import DDPGPolicy
from stable_baselines.common.mpi_running_mean_std import RunningMeanStd
-from stable_baselines.a2c.utils import find_trainable_variables, total_episode_reward_logger
+from stable_baselines.a2c.utils import total_episode_reward_logger
from stable_baselines.deepq.replay_buffer import ReplayBuffer


@@ -425,8 +425,10 @@ def setup_model(self):
tf.summary.scalar('actor_loss', self.actor_loss)
tf.summary.scalar('critic_loss', self.critic_loss)

-self.params = find_trainable_variables("model")
-self.target_params = find_trainable_variables("target")
+self.params = tf_util.get_trainable_vars("model") \
+    + tf_util.get_trainable_vars('noise/') + tf_util.get_trainable_vars('noise_adapt/')
+
+self.target_params = tf_util.get_trainable_vars("target")
self.obs_rms_params = [var for var in tf.global_variables()
if "obs_rms" in var.name]
self.ret_rms_params = [var for var in tf.global_variables()
@@ -1106,7 +1108,21 @@ def load(cls, load_path, env=None, **kwargs):
model.__dict__.update(kwargs)
model.set_env(env)
model.setup_model()

+# Patch for version < v2.6.0: duplicated keys were saved
+if len(params) > len(model.get_parameter_list()):
+    n_params = len(model.params)
+    n_target_params = len(model.target_params)
+    n_normalisation_params = len(model.obs_rms_params) + len(model.ret_rms_params)
+    # Check that the issue is the one from
+    # https://github.com/hill-a/stable-baselines/issues/363
+    assert len(params) == 2 * (n_params + n_target_params) + n_normalisation_params, \
+        "The number of parameters saved differs from the number of parameters" \
+        " that should be loaded: {}!={}".format(len(params), len(model.get_parameter_list()))
+    # Remove duplicates
+    params_ = params[:n_params + n_target_params]
+    if n_normalisation_params > 0:
+        params_ += params[-n_normalisation_params:]
+    params = params_
model.load_parameters(params)

return model
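To make the slicing in the loading patch above concrete, here is a toy walk-through with made-up counts (purely illustrative; the real lists hold TensorFlow variables, not strings):

    # Hypothetical counts: 3 "model" params, 2 "target" params, 1 normalisation stat.
    n_params, n_target_params, n_normalisation_params = 3, 2, 1

    trainable = ["m0", "m1", "m2", "t0", "t1"]        # one copy of every trainable variable
    saved = trainable + trainable + ["obs_rms/mean"]  # pre-2.6.0 files stored that copy twice

    assert len(saved) == 2 * (n_params + n_target_params) + n_normalisation_params

    # Keep a single copy of the trainable variables, then re-append the
    # normalisation statistics from the tail of the saved list.
    params_ = saved[:n_params + n_target_params]
    if n_normalisation_params > 0:
        params_ += saved[-n_normalisation_params:]
    print(params_)  # ['m0', 'm1', 'm2', 't0', 't1', 'obs_rms/mean']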
4 changes: 2 additions & 2 deletions stable_baselines/deepq/dqn.py
@@ -10,7 +10,7 @@
from stable_baselines.common.schedules import LinearSchedule
from stable_baselines.deepq.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
from stable_baselines.deepq.policies import DQNPolicy
-from stable_baselines.a2c.utils import find_trainable_variables, total_episode_reward_logger
+from stable_baselines.a2c.utils import total_episode_reward_logger


class DQN(OffPolicyRLModel):
@@ -134,7 +134,7 @@ def setup_model(self):
full_tensorboard_log=self.full_tensorboard_log
)
self.proba_step = self.step_model.proba_step
-self.params = find_trainable_variables("deepq")
+self.params = tf_util.get_trainable_vars("deepq")

# Initialize the parameters and copy them to the target network.
tf_util.initialize(self.sess)
6 changes: 3 additions & 3 deletions stable_baselines/sac/sac.py
@@ -7,7 +7,7 @@
import numpy as np
import tensorflow as tf

-from stable_baselines.a2c.utils import find_trainable_variables, total_episode_reward_logger
+from stable_baselines.a2c.utils import total_episode_reward_logger
from stable_baselines.common import tf_util, OffPolicyRLModel, SetVerbosity, TensorboardWriter
from stable_baselines.common.vec_env import VecEnv
from stable_baselines.deepq.replay_buffer import ReplayBuffer
@@ -311,8 +311,8 @@ def setup_model(self):
tf.summary.scalar('learning_rate', tf.reduce_mean(self.learning_rate_ph))

# Retrieve parameters that must be saved
-self.params = find_trainable_variables("model")
-self.target_params = find_trainable_variables("target/values_fn/vf")
+self.params = get_vars("model")
+self.target_params = get_vars("target/values_fn/vf")

# Initialize Variables and target network
with self.sess.as_default():
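Note that SAC calls a module-level get_vars helper rather than tf_util.get_trainable_vars directly. Assuming it is a thin alias defined in sac.py (the name is taken from the diff above; the body is only a sketch), it amounts to:

    from stable_baselines.common import tf_util

    def get_vars(scope):
        """Alias for tf_util.get_trainable_vars (assumed to be a thin wrapper)."""
        return tf_util.get_trainable_vars(scope)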
4 changes: 2 additions & 2 deletions stable_baselines/trpo_mpi/trpo_mpi.py
@@ -14,7 +14,7 @@
from stable_baselines.common.mpi_adam import MpiAdam
from stable_baselines.common.cg import conjugate_gradient
from stable_baselines.common.policies import ActorCriticPolicy
-from stable_baselines.a2c.utils import find_trainable_variables, total_episode_reward_logger
+from stable_baselines.a2c.utils import total_episode_reward_logger
from stable_baselines.trpo_mpi.utils import traj_segment_generator, add_vtarg_and_adv, flatten_lists


@@ -250,7 +250,7 @@ def allmean(arr):
self.proba_step = self.policy_pi.proba_step
self.initial_state = self.policy_pi.initial_state

-self.params = find_trainable_variables("model")
+self.params = tf_util.get_trainable_vars("model") + tf_util.get_trainable_vars("oldpi")
if self.using_gail:
self.params.extend(self.reward_giver.get_trainable_variables())

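With every scope now listed explicitly, the parameter list of a model should contain each variable exactly once. A quick sanity check, assuming model is an already constructed stable-baselines model (any algorithm touched by this commit):

    # `model` is assumed to exist, e.g. model = DDPG('MlpPolicy', 'Pendulum-v0').
    names = [var.name for var in model.get_parameter_list()]
    assert len(names) == len(set(names)), "a duplicated variable would be saved twice"

get_parameter_list() is the same accessor the DDPG loading patch above relies on.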
