Add parameter to DQN for initial probability of epsilon greedy exploration (#559)

* Add parameter to DQN for initial probability of random action, mirroring the existing parameter for final probability of random action

* Simplify changelog text
jdossgollin authored and Miffyli committed Nov 19, 2019
1 parent a1ab7a1 commit be6ef31
Showing 3 changed files with 33 additions and 4 deletions.
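
For readers unfamiliar with the change, here is a minimal usage sketch of the new parameter; the environment, policy, and parameter values below are illustrative and not part of this commit.

import gym

from stable_baselines import DQN

env = gym.make("CartPole-v1")

# Epsilon-greedy exploration is annealed linearly from exploration_initial_eps
# to exploration_final_eps over the first exploration_fraction of training.
model = DQN(
    "MlpPolicy",
    env,
    exploration_fraction=0.1,
    exploration_initial_eps=0.5,   # new in this commit; previously fixed at 1.0
    exploration_final_eps=0.02,
    verbose=1,
)
model.learn(total_timesteps=10000)
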
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
@@ -18,6 +18,7 @@ New Features:
^^^^^^^^^^^^^
- Add `n_cpu_tf_sess` to model constructor to choose the number of threads used by Tensorflow
- `VecNormalize` now supports being pickled and unpickled.
- Add parameter `exploration_initial_eps` to DQN. (@jdossgollin)

Bug Fixes:
^^^^^^^^^^
@@ -525,4 +526,4 @@ Thanks to @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk
@EliasHasle @mrakgr @Bleyddyn @antoine-galataud @junhyeokahn @AdamGleave @keshaviyengar @tperol
@XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
@MarvineGothic
@MarvineGothic @jdossgollin
6 changes: 4 additions & 2 deletions stable_baselines/deepq/dqn.py
@@ -29,6 +29,7 @@ class DQN(OffPolicyRLModel):
:param exploration_fraction: (float) fraction of entire training period over which the exploration rate is
annealed
:param exploration_final_eps: (float) final value of random action probability
:param exploration_initial_eps: (float) initial value of random action probability
:param train_freq: (int) update the model every `train_freq` steps.
:param batch_size: (int) size of a batch sampled from replay buffer for training
:param double_q: (bool) Whether to enable Double-Q learning or not.
@@ -54,7 +55,7 @@ class DQN(OffPolicyRLModel):
If None, the number of CPUs of the current machine will be used.
"""
def __init__(self, policy, env, gamma=0.99, learning_rate=5e-4, buffer_size=50000, exploration_fraction=0.1,
exploration_final_eps=0.02, train_freq=1, batch_size=32, double_q=True,
exploration_final_eps=0.02, exploration_initial_eps=1.0, train_freq=1, batch_size=32, double_q=True,
learning_starts=1000, target_network_update_freq=500, prioritized_replay=False,
prioritized_replay_alpha=0.6, prioritized_replay_beta0=0.4, prioritized_replay_beta_iters=None,
prioritized_replay_eps=1e-6, param_noise=False,
@@ -76,6 +77,7 @@ def __init__(self, policy, env, gamma=0.99, learning_rate=5e-4, buffer_size=5000
self.prioritized_replay_beta0 = prioritized_replay_beta0
self.prioritized_replay_beta_iters = prioritized_replay_beta_iters
self.exploration_final_eps = exploration_final_eps
self.exploration_initial_eps = exploration_initial_eps
self.exploration_fraction = exploration_fraction
self.buffer_size = buffer_size
self.learning_rate = learning_rate
@@ -178,7 +180,7 @@ def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="D

# Create the schedule for exploration starting from 1.
self.exploration = LinearSchedule(schedule_timesteps=int(self.exploration_fraction * total_timesteps),
initial_p=1.0,
initial_p=self.exploration_initial_eps,
final_p=self.exploration_final_eps)

episode_rewards = [0.0]
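
The `LinearSchedule` constructed above interpolates linearly from `initial_p` to `final_p` over `schedule_timesteps` and holds `final_p` afterwards; below is a standalone sketch of the equivalent computation (an illustrative re-implementation, not the library's code).

def linear_eps(step, schedule_timesteps, initial_p, final_p):
    # Fraction of the schedule completed, clamped to 1.0 once the schedule ends.
    fraction = min(float(step) / schedule_timesteps, 1.0)
    return initial_p + fraction * (final_p - initial_p)

# With exploration_initial_eps=1.0 and exploration_final_eps=0.02 over 10,000 steps:
print(linear_eps(0, 10_000, 1.0, 0.02))       # 1.0
print(linear_eps(5_000, 10_000, 1.0, 0.02))   # 0.51
print(linear_eps(20_000, 10_000, 1.0, 0.02))  # 0.02 (held after the schedule ends)
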
28 changes: 27 additions & 1 deletion tests/test_schedules.py
@@ -1,6 +1,6 @@
import numpy as np

from stable_baselines.common.schedules import ConstantSchedule, PiecewiseSchedule
from stable_baselines.common.schedules import ConstantSchedule, PiecewiseSchedule, LinearSchedule


def test_piecewise_schedule():
@@ -31,3 +31,29 @@ def test_constant_schedule():
constant_sched = ConstantSchedule(5)
for i in range(-100, 100):
assert np.isclose(constant_sched.value(i), 5)


def test_linear_schedule():
"""
test LinearSchedule
"""
linear_sched = LinearSchedule(schedule_timesteps=100, initial_p=0.2, final_p=0.8)
assert np.isclose(linear_sched.value(50), 0.5)
assert np.isclose(linear_sched.value(0), 0.2)
assert np.isclose(linear_sched.value(100), 0.8)

linear_sched = LinearSchedule(schedule_timesteps=100, initial_p=0.8, final_p=0.2)
assert np.isclose(linear_sched.value(50), 0.5)
assert np.isclose(linear_sched.value(0), 0.8)
assert np.isclose(linear_sched.value(100), 0.2)

linear_sched = LinearSchedule(schedule_timesteps=100, initial_p=-0.6, final_p=0.2)
assert np.isclose(linear_sched.value(50), -0.2)
assert np.isclose(linear_sched.value(0), -0.6)
assert np.isclose(linear_sched.value(100), 0.2)

linear_sched = LinearSchedule(schedule_timesteps=100, initial_p=0.2, final_p=-0.6)
assert np.isclose(linear_sched.value(50), -0.2)
assert np.isclose(linear_sched.value(0), 0.2)
assert np.isclose(linear_sched.value(100), -0.6)
