Bug fix when not enough samples in the replay buffer (#354)
* Bug fix when not enough samples in the replay buffer

* Correct typo
araffin authored and hill-a committed Jun 6, 2019
1 parent fefff48 commit 65ed396
Showing 7 changed files with 60 additions and 4 deletions.
4 changes: 3 additions & 1 deletion docs/misc/changelog.rst
@@ -29,6 +29,8 @@ Pre-Release 2.6.0a0 (WIP)
- **important change** switched to using dictionaries rather than lists when storing parameters, with tensorflow Variable names being the keys. (@Miffyli)
- added specific hyperparameter for PPO2 to clip the value function (``cliprange_vf``)
- fixed ``num_timesteps`` (total_timesteps) variable in PPO2 that was wrongly computed.
- fixed a bug in DDPG/DQN/SAC when the number of samples in the replay buffer was lower than the batch size
(thanks to @dwiel for spotting the bug)

**Breaking Change:** DDPG replay buffer was unified with DQN/SAC replay buffer. As a result,
when loading a DDPG model trained with stable_baselines<2.6.0, it throws an import error.
@@ -342,4 +344,4 @@ In random order...
Thanks to @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk @JohannesAck
@EliasHasle @mrakgr @Bleyddyn @antoine-galataud @junhyeokahn @AdamGleave @keshaviyengar @tperol
@XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs
@Miffyli
@Miffyli @dwiel
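The changelog entry above describes the root cause: sampling a training batch when the buffer holds fewer transitions than batch_size. The sketch below is a standalone, illustrative reproduction of that failure mode and of the guard this commit adds to the DDPG/DQN/SAC training loops; TinyBuffer is a hypothetical stand-in, not the library class.

import random

class TinyBuffer:
    """Minimal stand-in for a replay buffer (illustration only)."""

    def __init__(self):
        self.storage = []

    def add(self, transition):
        self.storage.append(transition)

    def can_sample(self, n_samples):
        # Same check as the helper added in this commit
        return len(self.storage) >= n_samples

    def sample(self, batch_size):
        # random.sample raises ValueError when batch_size > len(self.storage)
        return random.sample(self.storage, batch_size)


buffer = TinyBuffer()
for step in range(3):
    buffer.add((step, 0, 0.0, step + 1, False))

batch_size = 64
# Before this commit, the training code reached sample() in this situation;
# with the guard, the gradient step is simply skipped.
if buffer.can_sample(batch_size):
    batch = buffer.sample(batch_size)
else:
    print("Not enough samples in the replay buffer, skipping training step")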
4 changes: 4 additions & 0 deletions stable_baselines/ddpg/ddpg.py
@@ -914,6 +914,10 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_
epoch_critic_losses = []
epoch_adaptive_distances = []
for t_train in range(self.nb_train_steps):
# Not enough samples in the replay buffer
if not self.replay_buffer.can_sample(self.batch_size):
break

# Adapt param noise, if necessary.
if len(self.replay_buffer) >= self.batch_size and \
t_train % self.param_noise_adaption_interval == 0:
8 changes: 6 additions & 2 deletions stable_baselines/deepq/dqn.py
@@ -229,7 +229,11 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_
episode_rewards.append(0.0)
reset = True

if self.num_timesteps > self.learning_starts and self.num_timesteps % self.train_freq == 0:
# Do not train if the warmup phase is not over
# or if there are not enough samples in the replay buffer
can_sample = self.replay_buffer.can_sample(self.batch_size)
if can_sample and self.num_timesteps > self.learning_starts \
and self.num_timesteps % self.train_freq == 0:
# Minimize the error in Bellman's equation on a batch sampled from replay buffer.
if self.prioritized_replay:
experience = self.replay_buffer.sample(self.batch_size,
@@ -261,7 +265,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_
new_priorities = np.abs(td_errors) + self.prioritized_replay_eps
self.replay_buffer.update_priorities(batch_idxes, new_priorities)

if self.num_timesteps > self.learning_starts and \
if can_sample and self.num_timesteps > self.learning_starts and \
self.num_timesteps % self.target_network_update_freq == 0:
# Update target network periodically.
self.update_target(sess=self.sess)
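Distilled from the dqn.py hunk above, the two conditions that now gate training and target-network updates can be read as small predicates. The helper names below are illustrative only, not part of the library; the logic mirrors the diff (warmup over, on the right frequency, and a sampleable buffer).

def should_train(can_sample, num_timesteps, learning_starts, train_freq):
    # Train only after the warmup phase, on the training frequency,
    # and when the replay buffer can fill a whole batch.
    return (can_sample
            and num_timesteps > learning_starts
            and num_timesteps % train_freq == 0)


def should_update_target(can_sample, num_timesteps, learning_starts,
                         target_network_update_freq):
    # The target-network update now also requires a sampleable buffer.
    return (can_sample
            and num_timesteps > learning_starts
            and num_timesteps % target_network_update_freq == 0)


# Early in training the buffer may hold fewer transitions than batch_size:
assert not should_train(False, num_timesteps=100, learning_starts=50, train_freq=4)
assert should_train(True, num_timesteps=100, learning_starts=50, train_freq=4)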
10 changes: 10 additions & 0 deletions stable_baselines/deepq/replay_buffer.py
@@ -30,6 +30,16 @@ def buffer_size(self):
"""float: Max capacity of the buffer"""
return self._maxsize

def can_sample(self, n_samples):
"""
Check if n_samples samples can be sampled
from the buffer.
:param n_samples: (int)
:return: (bool)
"""
return len(self) >= n_samples

def is_full(self):
"""
Check whether the replay buffer is full or not.
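A minimal usage sketch of the new helper, assuming ReplayBuffer is constructed with a single capacity argument (as in the Baselines implementation this module derives from) and that add/sample keep the signatures visible elsewhere in this diff.

from stable_baselines.deepq.replay_buffer import ReplayBuffer

replay_buffer = ReplayBuffer(50000)  # assumed: single capacity argument

# Transitions are (obs_t, action, reward, obs_tp1, done)
replay_buffer.add(0, 1, 0.0, 1, False)
replay_buffer.add(1, 0, 1.0, 2, True)

print(replay_buffer.can_sample(2))   # True: two transitions stored
print(replay_buffer.can_sample(32))  # False: fewer transitions than the batch size

batch_size = 32
if replay_buffer.can_sample(batch_size):
    obses_t, actions, rewards, obses_tp1, dones = replay_buffer.sample(batch_size)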
11 changes: 11 additions & 0 deletions stable_baselines/her/replay_buffer.py
@@ -82,9 +82,20 @@ def add(self, obs_t, action, reward, obs_tp1, done):
def sample(self, *args, **kwargs):
return self.replay_buffer.sample(*args, **kwargs)

def can_sample(self, n_samples):
"""
Check if n_samples samples can be sampled
from the buffer.
:param n_samples: (int)
:return: (bool)
"""
return self.replay_buffer.can_sample(n_samples)

def __len__(self):
return len(self.replay_buffer)


def _sample_achieved_goal(self, episode_transitions, transition_idx):
"""
Sample an achieved goal according to the sampling strategy.
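The HER wrapper forwards can_sample to the buffer it wraps. The class below is a hypothetical, simplified wrapper (not the HER class itself) showing why the delegation matters: transitions may only reach the inner buffer at episode boundaries, so the wrapper has to report the inner buffer's true size rather than the number of environment steps taken.

class EpisodicBufferWrapper:
    """Illustrative wrapper: transitions are stored per episode and only
    flushed into the wrapped replay buffer when the episode terminates."""

    def __init__(self, replay_buffer):
        self.replay_buffer = replay_buffer
        self.episode_transitions = []

    def add(self, obs_t, action, reward, obs_tp1, done):
        self.episode_transitions.append((obs_t, action, reward, obs_tp1, done))
        if done:
            # Flush the whole episode into the wrapped buffer
            for transition in self.episode_transitions:
                self.replay_buffer.add(*transition)
            self.episode_transitions = []

    def can_sample(self, n_samples):
        # Delegation, as in the HER wrapper above
        return self.replay_buffer.can_sample(n_samples)

    def __len__(self):
        return len(self.replay_buffer)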
5 changes: 4 additions & 1 deletion stable_baselines/sac/sac.py
@@ -436,7 +436,10 @@ def learn(self, total_timesteps, callback=None, seed=None,
mb_infos_vals = []
# Update policy, critics and target networks
for grad_step in range(self.gradient_steps):
if self.num_timesteps < self.batch_size or self.num_timesteps < self.learning_starts:
# Break if the warmup phase is not over
# or if there are not enough samples in the replay buffer
if not self.replay_buffer.can_sample(self.batch_size) \
or self.num_timesteps < self.learning_starts:
break
n_updates += 1
# Compute current learning_rate
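The SAC change replaces a proxy test on num_timesteps (which counts environment steps, not stored transitions) with a direct check on the buffer. A hedged sketch of the resulting gradient-step loop, with illustrative names rather than the library internals:

def run_gradient_steps(replay_buffer, batch_size, gradient_steps,
                       num_timesteps, learning_starts, train_step):
    # Stop early if the warmup phase is not over or the buffer
    # cannot fill a whole batch (mirrors the sac.py hunk above).
    n_updates = 0
    for _ in range(gradient_steps):
        if not replay_buffer.can_sample(batch_size) or num_timesteps < learning_starts:
            break
        train_step(replay_buffer.sample(batch_size))
        n_updates += 1
    return n_updates

The previous condition, num_timesteps < batch_size, could pass even though the buffer was still empty, which is exactly the situation exercised by the new test below.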
22 changes: 22 additions & 0 deletions tests/test_her.py
@@ -45,6 +45,28 @@ def test_her(model_class, goal_selection_strategy, discrete_obs_space):
model.learn(1000)


@pytest.mark.parametrize('model_class', [DDPG, SAC, DQN])
def test_long_episode(model_class):
"""
Check that the model does not break when the replay buffer is still empty
after the first rollout (because the episode is not over).
"""
# n_bits > nb_rollout_steps
n_bits = 10
env = BitFlippingEnv(n_bits, continuous=model_class in [DDPG, SAC],
max_steps=n_bits)
kwargs = {}
if model_class == DDPG:
kwargs['nb_rollout_steps'] = 9 # < n_bits
elif model_class in [DQN, SAC]:
kwargs['batch_size'] = 8 # < n_bits
kwargs['learning_starts'] = 0

model = HER('MlpPolicy', env, model_class, n_sampled_goal=4, goal_selection_strategy='future',
verbose=0, **kwargs)
model.learn(200)


@pytest.mark.parametrize('goal_selection_strategy', [list(KEY_TO_GOAL_STRATEGY.keys())[0]])
@pytest.mark.parametrize('model_class', [DQN, SAC, DDPG])
def test_model_manipulation(model_class, goal_selection_strategy):
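If needed, the new regression test can also be exercised directly for a single model class, assuming the imports already present at the top of tests/test_her.py:

# DQN (like DDPG and SAC) is imported at the top of tests/test_her.py
test_long_episode(DQN)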
