Update DQN arguments: add double_q (#481)
* Update DQN arguments: add double_q

* Fix CI build

* [ci skip] Update link

* Replace pdf link with arxiv link
araffin committed Sep 21, 2019
1 parent 19ed2ca commit 4929e54
Showing 4 changed files with 25 additions and 16 deletions.
4 changes: 3 additions & 1 deletion docs/misc/changelog.rst
@@ -21,12 +21,14 @@ Breaking Changes:
wrapped in `if __name__ == '__main__'`. You can restore previous behavior
by explicitly setting `start_method = 'fork'`. See
`PR #428 <https://github.com/hill-a/stable-baselines/pull/428>`_.
- updated dependencies: tensorflow v1.8.0 is now required
- Updated dependencies: tensorflow v1.8.0 is now required
- Remove `checkpoint_path` and `checkpoint_freq` argument from `DQN` that were not used

New Features:
^^^^^^^^^^^^^
- **important change** Switch to using zip-archived JSON and Numpy `savez` for
storing models for better support across library/Python versions. (@Miffyli)
- Add `double_q` argument to `DQN` constructor

Bug Fixes:
^^^^^^^^^^
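The removed `checkpoint_freq` and `checkpoint_path` arguments noted in the changelog above were never used by `DQN`. If periodic checkpoints are needed, one option is a user-defined callback passed to `learn`; a minimal sketch, assuming the stable-baselines 2.x functional callback interface (the callback receives the training loop's locals and globals and returns a bool), with the save interval and file name purely illustrative:

from stable_baselines import DQN

n_calls = 0

def checkpoint_callback(_locals, _globals):
    # Illustrative periodic save: every 10000 callback invocations
    global n_calls
    n_calls += 1
    if n_calls % 10000 == 0:
        _locals['self'].save('dqn_checkpoint')
    return True  # returning False would stop training early

model = DQN('MlpPolicy', 'CartPole-v1', verbose=1)
model.learn(total_timesteps=25000, callback=checkpoint_callback)
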
12 changes: 10 additions & 2 deletions docs/modules/dqn.rst
@@ -27,7 +27,16 @@ and its extensions (Double-DQN, Dueling-DQN, Prioritized Experience Replay).
Notes
-----

- Original paper: https://arxiv.org/abs/1312.5602
- DQN paper: https://arxiv.org/abs/1312.5602
- Dueling DQN: https://arxiv.org/abs/1511.06581
- Double-Q Learning: https://arxiv.org/abs/1509.06461
- Prioritized Experience Replay: https://arxiv.org/abs/1511.05952

.. note::

By default, the DQN class has double q learning and dueling extensions enabled.
See `Issue #406 <https://github.com/hill-a/stable-baselines/issues/406>`_ for disabling dueling.
To disable double-q learning, you can change the default value in the constructor.
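A minimal sketch of the note above: `double_q` is the constructor argument added in this commit, while passing `dueling=False` through `policy_kwargs` follows the workaround discussed in Issue #406 and is an assumption here, not part of this diff.

from stable_baselines import DQN

# double_q=False disables Double-Q learning (new constructor argument);
# dueling=False is forwarded to the policy, per the Issue #406 workaround
model = DQN('MlpPolicy', 'CartPole-v1',
            double_q=False,
            policy_kwargs=dict(dueling=False),
            verbose=1)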


Can I use?
@@ -60,7 +69,6 @@ Example
from stable_baselines import DQN
env = gym.make('CartPole-v1')
env = DummyVecEnv([lambda: env])
model = DQN(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=25000)
23 changes: 11 additions & 12 deletions stable_baselines/deepq/dqn.py
@@ -15,7 +15,11 @@

class DQN(OffPolicyRLModel):
"""
The DQN model class. DQN paper: https://arxiv.org/pdf/1312.5602.pdf
The DQN model class.
DQN paper: https://arxiv.org/abs/1312.5602
Dueling DQN: https://arxiv.org/abs/1511.06581
Double-Q Learning: https://arxiv.org/abs/1509.06461
Prioritized Experience Replay: https://arxiv.org/abs/1511.05952
:param policy: (DQNPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, LnMlpPolicy, ...)
:param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
@@ -27,11 +31,7 @@ class DQN(OffPolicyRLModel):
:param exploration_final_eps: (float) final value of random action probability
:param train_freq: (int) update the model every `train_freq` steps. set to None to disable printing
:param batch_size: (int) size of a batched sampled from replay buffer for training
:param checkpoint_freq: (int) how often to save the model. This is so that the best version is restored at the
end of the training. If you do not wish to restore the best version
at the end of the training set this variable to None.
:param checkpoint_path: (str) replacement path used if you need to log to somewhere else than a temporary
directory.
:param double_q: (bool) Whether to enable Double-Q learning or not.
:param learning_starts: (int) how many steps of the model to collect transitions for before learning starts
:param target_network_update_freq: (int) update the target network every `target_network_update_freq` steps.
:param prioritized_replay: (bool) if True prioritized replay buffer will be used.
@@ -50,7 +50,7 @@
"""

def __init__(self, policy, env, gamma=0.99, learning_rate=5e-4, buffer_size=50000, exploration_fraction=0.1,
exploration_final_eps=0.02, train_freq=1, batch_size=32, checkpoint_freq=10000, checkpoint_path=None,
exploration_final_eps=0.02, train_freq=1, batch_size=32, double_q=True,
learning_starts=1000, target_network_update_freq=500, prioritized_replay=False,
prioritized_replay_alpha=0.6, prioritized_replay_beta0=0.4, prioritized_replay_beta_iters=None,
prioritized_replay_eps=1e-6, param_noise=False, verbose=0, tensorboard_log=None,
@@ -60,15 +60,13 @@ def __init__(self, policy, env, gamma=0.99, learning_rate=5e-4, buffer_size=5000
super(DQN, self).__init__(policy=policy, env=env, replay_buffer=None, verbose=verbose, policy_base=DQNPolicy,
requires_vec_env=False, policy_kwargs=policy_kwargs)

self.checkpoint_path = checkpoint_path
self.param_noise = param_noise
self.learning_starts = learning_starts
self.train_freq = train_freq
self.prioritized_replay = prioritized_replay
self.prioritized_replay_eps = prioritized_replay_eps
self.batch_size = batch_size
self.target_network_update_freq = target_network_update_freq
self.checkpoint_freq = checkpoint_freq
self.prioritized_replay_alpha = prioritized_replay_alpha
self.prioritized_replay_beta0 = prioritized_replay_beta0
self.prioritized_replay_beta_iters = prioritized_replay_beta_iters
Expand All @@ -79,6 +77,7 @@ def __init__(self, policy, env, gamma=0.99, learning_rate=5e-4, buffer_size=5000
self.gamma = gamma
self.tensorboard_log = tensorboard_log
self.full_tensorboard_log = full_tensorboard_log
self.double_q = double_q

self.graph = None
self.sess = None
@@ -131,7 +130,8 @@ def setup_model(self):
grad_norm_clipping=10,
param_noise=self.param_noise,
sess=self.sess,
full_tensorboard_log=self.full_tensorboard_log
full_tensorboard_log=self.full_tensorboard_log,
double_q=self.double_q
)
self.proba_step = self.step_model.proba_step
self.params = tf_util.get_trainable_vars("deepq")
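For context on what the `double_q` flag changes inside `build_train`: with Double-Q learning the greedy next-state action is selected by the online network but evaluated by the target network, while vanilla DQN uses the target network for both. A small NumPy sketch of the two TD targets (a conceptual sketch, not the TensorFlow graph built here):

import numpy as np

def td_target(rewards, dones, q_online_next, q_target_next, gamma=0.99, double_q=True):
    # q_online_next, q_target_next: (batch_size, n_actions) Q-values for the next states
    if double_q:
        # Select the action with the online network, evaluate it with the target network
        best_actions = np.argmax(q_online_next, axis=1)
        next_values = q_target_next[np.arange(len(best_actions)), best_actions]
    else:
        # Vanilla DQN: both selection and evaluation use the target network
        next_values = np.max(q_target_next, axis=1)
    return rewards + gamma * (1.0 - dones) * next_values
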
@@ -334,15 +334,14 @@ def get_parameter_list(self):
def save(self, save_path, cloudpickle=False):
# params
data = {
"checkpoint_path": self.checkpoint_path,
"double_q": self.double_q,
"param_noise": self.param_noise,
"learning_starts": self.learning_starts,
"train_freq": self.train_freq,
"prioritized_replay": self.prioritized_replay,
"prioritized_replay_eps": self.prioritized_replay_eps,
"batch_size": self.batch_size,
"target_network_update_freq": self.target_network_update_freq,
"checkpoint_freq": self.checkpoint_freq,
"prioritized_replay_alpha": self.prioritized_replay_alpha,
"prioritized_replay_beta0": self.prioritized_replay_beta0,
"prioritized_replay_beta_iters": self.prioritized_replay_beta_iters,
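Because `double_q` is now part of the saved data dict above, the setting should survive a save/load round trip. A short usage sketch (file name illustrative):

from stable_baselines import DQN

model = DQN('MlpPolicy', 'CartPole-v1', double_q=False)
model.save('dqn_cartpole')

loaded = DQN.load('dqn_cartpole')
print(loaded.double_q)  # expected: False
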
2 changes: 1 addition & 1 deletion tests/test_atari.py
@@ -63,7 +63,7 @@ def test_deepq():

model = DQN(env=env, policy=CnnPolicy, learning_rate=1e-4, buffer_size=10000, exploration_fraction=0.1,
exploration_final_eps=0.01, train_freq=4, learning_starts=10000, target_network_update_freq=1000,
gamma=0.99, prioritized_replay=True, prioritized_replay_alpha=0.6, checkpoint_freq=10000)
gamma=0.99, prioritized_replay=True, prioritized_replay_alpha=0.6)
model.learn(total_timesteps=NUM_TIMESTEPS)

env.close()