Commit

Merge pull request #39 from hill-a/deepq-fixes
DQN fixes
araffin committed Oct 1, 2018
2 parents 288f458 + 4c68739 commit 4983566
Showing 15 changed files with 191 additions and 112 deletions.
2 changes: 1 addition & 1 deletion docs/guide/custom_policy.rst
@@ -3,7 +3,7 @@
Custom Policy Network
---------------------

Stable baselines provides default policy networks for images (CNNPolicies)
Stable baselines provides default policy networks (see :ref:`Policies <policies>`) for images (CNNPolicies)
and other types of input features (MlpPolicies).
However, you can also easily define a custom architecture for the policy (or value) network:
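
For reference (not shown in this diff), a minimal sketch of such a custom policy with the stable-baselines API; the layer sizes and the choice of A2C are illustrative assumptions:

    import gym
    from stable_baselines.common.policies import FeedForwardPolicy
    from stable_baselines.common.vec_env import DummyVecEnv
    from stable_baselines import A2C

    # Custom MLP policy with three hidden layers of 128 units each
    class CustomPolicy(FeedForwardPolicy):
        def __init__(self, *args, **kwargs):
            super(CustomPolicy, self).__init__(*args, **kwargs,
                                               layers=[128, 128, 128],
                                               feature_extraction="mlp")

    env = DummyVecEnv([lambda: gym.make('CartPole-v1')])
    model = A2C(CustomPolicy, env, verbose=1)
    model.learn(total_timesteps=10000)
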

2 changes: 2 additions & 0 deletions docs/misc/changelog.rst
@@ -11,6 +11,8 @@ Pre Release 2.0.1.a0 (WIP)
**logging and bug fixes**

- added patch fix for equal function using `gym.spaces.MultiDiscrete` and `gym.spaces.MultiBinary`
- fixes for DQN action_probability
- re-added double DQN + refactored DQN policies **breaking changes**
- replaced `async` with `async_eigen_decomp` in ACKTR/KFAC for python 3.7 compat


3 changes: 1 addition & 2 deletions docs/modules/acer.rst
@@ -42,8 +42,7 @@ Example
import gym
from stable_baselines.common.policies import MlpPolicy, MlpLstmPolicy, MlpLnLstmPolicy, \
CnnPolicy, CnnLstmPolicy, CnnLnLstmPolicy
from stable_baselines.common.policies import MlpPolicy, MlpLstmPolicy, MlpLnLstmPolicy
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines import ACER
3 changes: 1 addition & 2 deletions docs/modules/acktr.rst
@@ -43,8 +43,7 @@ Example
import gym
from stable_baselines.common.policies import MlpPolicy, MlpLstmPolicy, MlpLnLstmPolicy, \
CnnPolicy, CnnLstmPolicy, CnnLnLstmPolicy
from stable_baselines.common.policies import MlpPolicy, MlpLstmPolicy, MlpLnLstmPolicy
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines import ACKTR
13 changes: 12 additions & 1 deletion docs/modules/ddpg.rst
@@ -13,6 +13,17 @@ DDPG
The DDPG model does not support ``stable_baselines.common.policies`` because it uses q-value instead
of value estimation; as a result, it must use its own policy models (see :ref:`ddpg_policies`).


.. rubric:: Available Policies

.. autosummary::
:nosignatures:

MlpPolicy
LnMlpPolicy
CnnPolicy
LnCnnPolicy

Notes
-----

@@ -47,7 +58,7 @@ Example
import gym
import numpy as np
from stable_baselines.ddpg.policies import MlpPolicy, CnnPolicy
from stable_baselines.ddpg.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.ddpg.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise, AdaptiveParamNoiseSpec
from stable_baselines import DDPG
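
The hunk above only shows the trimmed import; for context, a minimal sketch of how such a DDPG example typically continues (the environment id, noise scale, and timestep budget are assumptions, not taken from this diff):

    env = gym.make('MountainCarContinuous-v0')
    env = DummyVecEnv([lambda: env])

    # Add some exploration noise on the continuous actions
    n_actions = env.action_space.shape[-1]
    action_noise = OrnsteinUhlenbeckActionNoise(mean=np.zeros(n_actions),
                                                sigma=0.5 * np.ones(n_actions))

    model = DDPG(MlpPolicy, env, param_noise=None, action_noise=action_noise, verbose=1)
    model.learn(total_timesteps=400000)
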
12 changes: 11 additions & 1 deletion docs/modules/dqn.rst
@@ -14,6 +14,16 @@ and its extensions (Double-DQN, Dueling-DQN, Prioritized Experience Replay).
The DQN model does not support ``stable_baselines.common.policies``;
as a result, it must use its own policy models (see :ref:`deepq_policies`).

.. rubric:: Available Policies

.. autosummary::
:nosignatures:

MlpPolicy
LnMlpPolicy
CnnPolicy
LnCnnPolicy

Notes
-----

@@ -46,7 +56,7 @@ Example
import gym
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.deepq.policies import MlpPolicy, CnnPolicy
from stable_baselines.deepq.policies import MlpPolicy
from stable_baselines import DQN
env = gym.make('CartPole-v1')
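
Only the import changes in this hunk; for context, a minimal sketch of how the DQN example typically continues (the hyperparameters here are assumptions):

    env = DummyVecEnv([lambda: env])

    model = DQN(MlpPolicy, env, learning_rate=1e-3, prioritized_replay=True, verbose=1)
    model.learn(total_timesteps=25000)
    model.save("deepq_cartpole")
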
19 changes: 19 additions & 0 deletions docs/modules/policies.rst
@@ -5,6 +5,25 @@
Policy Networks
===============

Stable-baselines provides a set of default policies that can be used with most action spaces.
If you need more control over the policy architecture, you can also create a custom policy (see :ref:`custom_policy`).

.. note::

CnnPolicies are for images only. MlpPolicies are made for other types of features (e.g. robot joints).

.. rubric:: Available Policies

.. autosummary::
:nosignatures:

MlpPolicy
MlpLstmPolicy
MlpLnLstmPolicy
CnnPolicy
CnnLstmPolicy
CnnLnLstmPolicy
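
Any policy from the list above can be passed directly to a model constructor; a minimal sketch (not part of this diff; the choice of A2C and CartPole is illustrative):

    import gym
    from stable_baselines.common.policies import MlpPolicy
    from stable_baselines.common.vec_env import DummyVecEnv
    from stable_baselines import A2C

    env = DummyVecEnv([lambda: gym.make('CartPole-v1')])
    model = A2C(MlpPolicy, env, verbose=1)
    model.learn(total_timesteps=10000)
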


Base Classes
------------
3 changes: 1 addition & 2 deletions docs/modules/ppo1.rst
@@ -52,8 +52,7 @@ Example
import gym
from stable_baselines.common.policies import MlpPolicy, MlpLstmPolicy, MlpLnLstmPolicy, \
CnnPolicy, CnnLstmPolicy, CnnLnLstmPolicy
from stable_baselines.common.policies import MlpPolicy, MlpLstmPolicy, MlpLnLstmPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO1
3 changes: 1 addition & 2 deletions docs/modules/trpo.rst
@@ -43,8 +43,7 @@ Example
import gym
from stable_baselines.common.policies import MlpPolicy, MlpLstmPolicy, MlpLnLstmPolicy, \
CnnPolicy, CnnLstmPolicy, CnnLnLstmPolicy
from stable_baselines.common.policies import MlpPolicy, MlpLstmPolicy, MlpLnLstmPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import TRPO
34 changes: 16 additions & 18 deletions stable_baselines/deepq/build_graph.py
@@ -140,7 +140,7 @@ def build_act(q_func, ob_space, ac_space, stochastic_ph, update_eps_ph, sess):

policy = q_func(sess, ob_space, ac_space, 1, 1, None)
obs_phs = (policy.obs_ph, policy.processed_x)
deterministic_actions = policy.proba_distribution.mode()
deterministic_actions = tf.argmax(policy.q_values, axis=1)

batch_size = tf.shape(policy.obs_ph)[0]
n_actions = ac_space.nvec if isinstance(ac_space, MultiDiscrete) else ac_space.n
@@ -235,8 +235,8 @@ def perturb_vars(original_scope, perturbed_scope):
adaptive_policy = q_func(sess, ob_space, ac_space, 1, 1, None, obs_phs=obs_phs)
perturb_for_adaption = perturb_vars(original_scope="model", perturbed_scope="adaptive_model/model")
kl_loss = tf.reduce_sum(
tf.nn.softmax(policy.value_fn) *
(tf.log(tf.nn.softmax(policy.value_fn)) - tf.log(tf.nn.softmax(adaptive_policy.value_fn))),
tf.nn.softmax(policy.q_values) *
(tf.log(tf.nn.softmax(policy.q_values)) - tf.log(tf.nn.softmax(adaptive_policy.q_values))),
axis=-1)
mean_kl = tf.reduce_mean(kl_loss)
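
For intuition, a small NumPy mirror of this adaptive-param-noise KL term, now computed on q_values instead of value_fn (illustrative only, not part of the commit):

    import numpy as np

    def softmax(x, axis=-1):
        z = x - x.max(axis=axis, keepdims=True)
        e = np.exp(z)
        return e / e.sum(axis=axis, keepdims=True)

    def mean_kl(q_values, perturbed_q_values):
        # KL(softmax(q) || softmax(q_perturbed)), averaged over the batch
        p, q = softmax(q_values), softmax(perturbed_q_values)
        return np.mean(np.sum(p * (np.log(p) - np.log(q)), axis=-1))
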

@@ -259,8 +259,8 @@ def update_scale():
lambda: param_noise_threshold))

# Put everything together.
perturbed_deterministic_actions = tf.argmax(perturbable_policy.value_fn, axis=1)
deterministic_actions = tf.argmax(policy.value_fn, axis=1)
perturbed_deterministic_actions = tf.argmax(perturbable_policy.q_values, axis=1)
deterministic_actions = tf.argmax(policy.q_values, axis=1)
batch_size = tf.shape(policy.obs_ph)[0]
n_actions = ac_space.nvec if isinstance(ac_space, MultiDiscrete) else ac_space.n
random_actions = tf.random_uniform(tf.stack([batch_size]), minval=0, maxval=n_actions, dtype=tf.int64)
@@ -349,7 +349,7 @@ def build_train(q_func, ob_space, ac_space, optimizer, sess, grad_norm_clipping=
optimize the error in Bellman's equation. See the top of the file for details.
update_target: (function) copy the parameters from optimized Q function to the target Q function.
See the top of the file for details.
debug: ({str: function}) a bunch of functions to print debug data like q_values.
step_model: (DQNPolicy) Policy for evaluation
"""
n_actions = ac_space.nvec if isinstance(ac_space, MultiDiscrete) else ac_space.n
with tf.variable_scope("input", reuse=reuse):
@@ -364,23 +364,23 @@ def build_train(q_func, ob_space, ac_space, optimizer, sess, grad_norm_clipping=
act_f, obs_phs = build_act(q_func, ob_space, ac_space, stochastic_ph, update_eps_ph, sess)

# q network evaluation
with tf.variable_scope("eval_q_func", reuse=True, custom_getter=tf_util.outer_scope_getter("eval_q_func")):
eval_policy = q_func(sess, ob_space, ac_space, 1, 1, None, reuse=True, obs_phs=obs_phs)
with tf.variable_scope("step_model", reuse=True, custom_getter=tf_util.outer_scope_getter("step_model")):
step_model = q_func(sess, ob_space, ac_space, 1, 1, None, reuse=True, obs_phs=obs_phs)
q_func_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=tf.get_variable_scope().name + "/model")
# target q network evalution
# target q network evaluation

with tf.variable_scope("target_q_func", reuse=False):
target_policy = q_func(sess, ob_space, ac_space, 1, 1, None, reuse=False)
target_q_func_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
scope=tf.get_variable_scope().name + "/target_q_func")

# compute estimate of best possible value starting from state at t + 1
double_value_fn = None
double_q_values = None
double_obs_ph = target_policy.obs_ph
if double_q:
with tf.variable_scope("double_q", reuse=True, custom_getter=tf_util.outer_scope_getter("double_q")):
double_policy = q_func(sess, ob_space, ac_space, 1, 1, None, reuse=True)
double_value_fn = double_policy.value_fn
double_q_values = double_policy.q_values
double_obs_ph = double_policy.obs_ph

with tf.variable_scope("loss", reuse=reuse):
@@ -391,14 +391,14 @@ def build_train(q_func, ob_space, ac_space, optimizer, sess, grad_norm_clipping=
importance_weights_ph = tf.placeholder(tf.float32, [None], name="weight")

# q scores for actions which we know were selected in the given state.
q_t_selected = tf.reduce_sum(eval_policy.value_fn * tf.one_hot(act_t_ph, n_actions), 1)
q_t_selected = tf.reduce_sum(step_model.q_values * tf.one_hot(act_t_ph, n_actions), axis=1)

# compute estimate of best possible value starting from state at t + 1
if double_q:
q_tp1_best_using_online_net = tf.argmax(double_value_fn, 1)
q_tp1_best = tf.reduce_sum(target_policy.value_fn * tf.one_hot(q_tp1_best_using_online_net, n_actions), 1)
q_tp1_best_using_online_net = tf.argmax(double_q_values, axis=1)
q_tp1_best = tf.reduce_sum(target_policy.q_values * tf.one_hot(q_tp1_best_using_online_net, n_actions), axis=1)
else:
q_tp1_best = tf.reduce_max(target_policy.value_fn, 1)
q_tp1_best = tf.reduce_max(target_policy.q_values, axis=1)
q_tp1_best_masked = (1.0 - done_mask_ph) * q_tp1_best

# compute RHS of bellman equation
@@ -457,6 +457,4 @@ def build_train(q_func, ob_space, ac_space, optimizer, sess, grad_norm_clipping=
)
update_target = tf_util.function([], [], updates=[update_target_expr])

q_values = tf_util.function([obs_phs[0]], eval_policy.value_fn)

return act_f, train, update_target, {'q_values': q_values}
return act_f, train, update_target, step_model
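
To make the re-added Double DQN path concrete, here is a small NumPy sketch of the TD-target computation that the loss hunk above performs on q_values (illustrative only, not library code):

    import numpy as np

    def td_targets(rewards, dones, target_q_tp1, online_q_tp1=None, gamma=0.99):
        # target_q_tp1: target-network Q-values at t+1, shape (batch, n_actions)
        # online_q_tp1: online-network Q-values at t+1 (pass None to disable Double DQN)
        if online_q_tp1 is not None:
            # Double DQN: the online net selects the action, the target net evaluates it
            best_actions = np.argmax(online_q_tp1, axis=1)
            q_tp1_best = target_q_tp1[np.arange(len(target_q_tp1)), best_actions]
        else:
            # Vanilla DQN: max over the target net's own Q-values
            q_tp1_best = np.max(target_q_tp1, axis=1)
        q_tp1_best = (1.0 - dones) * q_tp1_best
        return rewards + gamma * q_tp1_best
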
36 changes: 23 additions & 13 deletions stable_baselines/deepq/dqn.py
@@ -1,3 +1,5 @@
from functools import partial

import tensorflow as tf
import numpy as np
import gym
@@ -77,8 +79,10 @@ def __init__(self, policy, env, gamma=0.99, learning_rate=5e-4, buffer_size=5000
self.graph = None
self.sess = None
self._train_step = None
self.step_model = None
self.update_target = None
self.act = None
self.proba_step = None
self.replay_buffer = None
self.beta_schedule = None
self.exploration = None
@@ -91,10 +95,16 @@ def __init__(self, policy, env, gamma=0.99, learning_rate=5e-4, buffer_size=5000

def setup_model(self):
with SetVerbosity(self.verbose):

assert not isinstance(self.action_space, gym.spaces.Box), \
"Error: DQN cannot output a gym.spaces.Box action space."
assert issubclass(self.policy, DQNPolicy), "Error: the input policy for the DQN model must be " \

# If the policy is wrapped in functools.partial (e.g. to disable dueling),
# unwrap it to check the class type
if isinstance(self.policy, partial):
test_policy = self.policy.func
else:
test_policy = self.policy
assert issubclass(test_policy, DQNPolicy), "Error: the input policy for the DQN model must be " \
"an instance of DQNPolicy."

self.graph = tf.Graph()
@@ -103,7 +113,7 @@ def setup_model(self):

optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)

self.act, self._train_step, self.update_target, _ = deepq.build_train(
self.act, self._train_step, self.update_target, self.step_model = deepq.build_train(
q_func=self.policy,
ob_space=self.observation_space,
ac_space=self.action_space,
@@ -113,7 +123,7 @@ def setup_model(self):
param_noise=self.param_noise,
sess=self.sess
)

self.proba_step = self.step_model.proba_step
self.params = find_trainable_variables("deepq")

# Initialize the parameters and copy them to the target network.
@@ -239,13 +249,13 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_

return self

def predict(self, observation, state=None, mask=None, deterministic=False):
def predict(self, observation, state=None, mask=None, deterministic=True):
observation = np.array(observation)
vectorized_env = self._is_vectorized_observation(observation, self.observation_space)

observation = observation.reshape((-1,) + self.observation_space.shape)
with self.sess.as_default():
actions = self.act(observation, stochastic=not deterministic)
actions, _, _ = self.step_model.step(observation, deterministic=deterministic)

if not vectorized_env:
actions = actions[0]
@@ -257,14 +267,14 @@ def action_probability(self, observation, state=None, mask=None):
vectorized_env = self._is_vectorized_observation(observation, self.observation_space)

observation = observation.reshape((-1,) + self.observation_space.shape)
actions_proba = self.proba_step(observation, state, mask)

if not vectorized_env:
if state is not None:
raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
actions_proba = actions_proba[0]

# Get the tensor just before the softmax function in the TensorFlow graph,
# then execute the graph from the input observation to this tensor.
tensor = self.graph.get_tensor_by_name('deepq/q_func/fully_connected_2/BiasAdd:0')
if vectorized_env:
return self._softmax(self.sess.run(tensor, feed_dict={'deepq/observation:0': observation}))
else:
return self._softmax(self.sess.run(tensor, feed_dict={'deepq/observation:0': observation}))[0]
return actions_proba

def save(self, save_path):
# params
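
Taken together, the setup_model, predict, and action_probability changes above mean a partial-wrapped policy now works end to end; a short sketch (not part of the commit; the environment and the dueling flag are illustrative):

    import gym
    from functools import partial

    from stable_baselines import DQN
    from stable_baselines.deepq.policies import MlpPolicy

    # Wrapping the policy in functools.partial (e.g. to disable dueling) now passes
    # the DQNPolicy check, because setup_model unwraps the partial before issubclass.
    policy = partial(MlpPolicy, dueling=False)

    env = gym.make('CartPole-v1')
    model = DQN(policy, env, verbose=1)

    obs = env.reset()
    action, _ = model.predict(obs)         # greedy by default now (deterministic=True)
    proba = model.action_probability(obs)  # probabilities from the policy's proba_step
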
7 changes: 6 additions & 1 deletion stable_baselines/deepq/experiments/enjoy_mountaincar.py
@@ -1,6 +1,7 @@
import argparse

import gym
import numpy as np

from stable_baselines.deepq import DQN

@@ -20,7 +21,11 @@ def main(args):
while not done:
if not args.no_render:
env.render()
action, _ = model.predict(obs)
# Epsilon-greedy
if np.random.random() < 0.02:
action = env.action_space.sample()
else:
action, _ = model.predict(obs, deterministic=True)
obs, rew, done, _ = env.step(action)
episode_rew += rew
print("Episode reward", episode_rew)
6 changes: 4 additions & 2 deletions stable_baselines/deepq/experiments/run_atari.py
@@ -1,4 +1,5 @@
import argparse
from functools import partial

from stable_baselines import bench, logger
from stable_baselines.common import set_global_seeds
@@ -14,8 +15,8 @@ def main():
parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
parser.add_argument('--seed', help='RNG seed', type=int, default=0)
parser.add_argument('--prioritized', type=int, default=1)
parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6)
parser.add_argument('--dueling', type=int, default=1)
parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6)
parser.add_argument('--num-timesteps', type=int, default=int(10e6))
parser.add_argument('--checkpoint-freq', type=int, default=10000)
parser.add_argument('--checkpoint-path', type=str, default=None)
@@ -26,10 +27,11 @@ def main():
env = make_atari(args.env)
env = bench.Monitor(env, logger.get_dir())
env = wrap_atari_dqn(env)
policy = partial(CnnPolicy, dueling=args.dueling == 1)

model = DQN(
env=env,
policy=CnnPolicy,
policy=policy,
learning_rate=1e-4,
buffer_size=10000,
exploration_fraction=0.1,
