Fix wrong truncation in HER replay buffer (#1543)
* fix episode start idx that leads to wrong episode length

* add episode length test

* Update changelog

* Reformat files

* Use replay_buffer.dones to test HER truncation warning

* truncate_last_trajectory: sample truncated episode and handle infinite horizon tasks

* make test_truncate_last_trajectory independent of learning

* Add timeout comment HER truncate_last_trajectory

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update version.txt

* Update version

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
lbergmann1 and qgallouedec committed Jun 7, 2023
1 parent 4fcda6b commit 32778dd
Showing 4 changed files with 142 additions and 18 deletions.
5 changes: 3 additions & 2 deletions docs/misc/changelog.rst
@@ -3,7 +3,7 @@
Changelog
==========

Release 2.0.0a11 (WIP)
Release 2.0.0a12 (WIP)
--------------------------

**Gymnasium support**
@@ -41,6 +41,7 @@ Bug Fixes:
- Fixed loading DQN changes ``target_update_interval`` (@tobirohrer)
- Fixed env checker to properly reset the env before calling ``step()`` when checking
for ``Inf`` and ``NaN`` (@lutogniew)
- Fixed HER ``truncate_last_trajectory()`` (@lbergmann1)

Deprecations:
^^^^^^^^^^^^^
@@ -1350,4 +1351,4 @@ And all the contributors:
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
@lutogniew
@lutogniew @lbergmann1
43 changes: 29 additions & 14 deletions stable_baselines3/her/her_replay_buffer.py
@@ -162,16 +162,24 @@ def add(
# When episode ends, compute and store the episode length
for env_idx in range(self.n_envs):
if done[env_idx]:
episode_start = self._current_ep_start[env_idx]
episode_end = self.pos
if episode_end < episode_start:
# Occurs when the buffer becomes full, the storage resumes at the
# beginning of the buffer. This can happen in the middle of an episode.
episode_end += self.buffer_size
episode_indices = np.arange(episode_start, episode_end) % self.buffer_size
self.ep_length[episode_indices, env_idx] = episode_end - episode_start
# Update the current episode start
self._current_ep_start[env_idx] = self.pos
self._compute_episode_length(env_idx)

def _compute_episode_length(self, env_idx: int) -> None:
"""
Compute and store the episode length for environment with index env_idx.

:param env_idx: index of the environment for which the episode length should be computed
"""
episode_start = self._current_ep_start[env_idx]
episode_end = self.pos
if episode_end < episode_start:
# Occurs when the buffer becomes full, the storage resumes at the
# beginning of the buffer. This can happen in the middle of an episode.
episode_end += self.buffer_size
episode_indices = np.arange(episode_start, episode_end) % self.buffer_size
self.ep_length[episode_indices, env_idx] = episode_end - episode_start
# Update the current episode start
self._current_ep_start[env_idx] = self.pos
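A minimal sketch of the wraparound arithmetic above, assuming a toy buffer of size 8 and made-up positions (illustration only, not part of the change):

import numpy as np

buffer_size = 8
episode_start = 6  # episode began near the end of the buffer
episode_end = 2    # current write position; the buffer wrapped around mid-episode
if episode_end < episode_start:
    episode_end += buffer_size  # 2 -> 10
episode_indices = np.arange(episode_start, episode_end) % buffer_size  # array([6, 7, 0, 1])
episode_length = episode_end - episode_start  # 4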

def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples:
"""
@@ -375,12 +383,19 @@ def truncate_last_trajectory(self) -> None:
If not called, we assume that we continue the same trajectory (same episode).
"""
# If we are at the start of an episode, no need to truncate
if (self.ep_start[self.pos] != self.pos).any():
if (self._current_ep_start != self.pos).any():
warnings.warn(
"The last trajectory in the replay buffer will be truncated.\n"
"If you are in the same episode as when the replay buffer was saved,\n"
"you should use `truncate_last_trajectory=False` to avoid that issue."
)
self.ep_start[-1] = self.pos
# set done = True for current episodes
self.dones[self.pos - 1] = True
# only consider episodes that are not finished
for env_idx in np.where(self._current_ep_start != self.pos)[0]:
# set done = True for last episodes
self.dones[self.pos - 1, env_idx] = True
# make sure that last episodes can be sampled and
# update next episode start (self._current_ep_start)
self._compute_episode_length(env_idx)
# handle infinite horizon tasks
if self.handle_timeout_termination:
self.timeouts[self.pos - 1, env_idx] = True # not an actual timeout, but it allows bootstrapping
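For context, a usage sketch of the path this fix exercises: save a HER replay buffer (possibly mid-episode), then reload it with truncation so unfinished trajectories become sampleable again. The environment, hyperparameters, and file name below are illustrative placeholders, not part of the change:

from stable_baselines3 import SAC, HerReplayBuffer
from stable_baselines3.common.envs import BitFlippingEnv

env = BitFlippingEnv(n_bits=4, continuous=True)
model = SAC(
    "MultiInputPolicy",
    env,
    replay_buffer_class=HerReplayBuffer,
    buffer_size=1_000,
    learning_starts=50,
    verbose=0,
)
model.learn(total_timesteps=100)
model.save_replay_buffer("her_buffer.pkl")  # saving may happen in the middle of an episode

# Later (e.g. when resuming a run): truncate the unfinished trajectory so it can be
# sampled. This calls truncate_last_trajectory(), which marks the last stored transition
# of every unfinished episode as done and, when handle_timeout_termination=True, also as
# a timeout so the value target can still bootstrap from the next observation.
model.load_replay_buffer("her_buffer.pkl", truncate_last_traj=True)
model.learn(total_timesteps=100, reset_num_timesteps=False)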
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
2.0.0a11
2.0.0a12
110 changes: 109 additions & 1 deletion tests/test_her.py
@@ -267,7 +267,7 @@ def env_fn():

model.load_replay_buffer(path, truncate_last_traj=truncate_last_trajectory)

if truncate_last_trajectory:
if truncate_last_trajectory and (old_replay_buffer.dones[old_replay_buffer.pos - 1] == 0).any():
assert len(recwarn) == 1
warning = recwarn.pop(UserWarning)
assert "The last trajectory in the replay buffer will be truncated" in str(warning.message)
@@ -323,6 +323,114 @@ def env_fn():
model.learn(total_timesteps=100)


@pytest.mark.parametrize("n_envs", [1, 2])
@pytest.mark.parametrize("n_steps", [4, 5])
@pytest.mark.parametrize("handle_timeout_termination", [False, True])
def test_truncate_last_trajectory(n_envs, recwarn, n_steps, handle_timeout_termination):
"""
Test if 'truncate_last_trajectory' works correctly
"""
# remove gym warnings
warnings.filterwarnings(action="ignore", category=DeprecationWarning)
warnings.filterwarnings(action="ignore", category=UserWarning, module="gym")

n_bits = 4

def env_fn():
return BitFlippingEnv(n_bits=n_bits, continuous=True)

venv = make_vec_env(env_fn, n_envs)

replay_buffer = HerReplayBuffer(
buffer_size=int(1e4),
observation_space=venv.observation_space,
action_space=venv.action_space,
env=venv,
n_envs=n_envs,
n_sampled_goal=2,
goal_selection_strategy="future",
)

observations = venv.reset()
for _ in range(n_steps):
actions = np.random.rand(n_envs, n_bits)
next_observations, rewards, dones, infos = venv.step(actions)
replay_buffer.add(observations, next_observations, actions, rewards, dones, infos)
observations = next_observations

old_replay_buffer = deepcopy(replay_buffer)
pos = replay_buffer.pos
if handle_timeout_termination:
env_idx_not_finished = np.where(replay_buffer._current_ep_start != pos)[0]

# Check that there is no warning
assert len(recwarn) == 0

replay_buffer.truncate_last_trajectory()

if (old_replay_buffer.dones[pos - 1] == 0).any():
# at least one episode in the replay buffer did not finish
assert len(recwarn) == 1
warning = recwarn.pop(UserWarning)
assert "The last trajectory in the replay buffer will be truncated" in str(warning.message)
else:
# all episodes in the replay buffer are finished
assert len(recwarn) == 0

# next episode starts at current pos
assert (replay_buffer._current_ep_start == pos).all()
# done = True for last episodes
assert (replay_buffer.dones[pos - 1] == 1).all()
# for all episodes that are not finished before truncate_last_trajectory: timeouts should be 1
if handle_timeout_termination:
assert (replay_buffer.timeouts[pos - 1, env_idx_not_finished] == 1).all()
# episode length should be != 0 -> episode can be sampled
assert (replay_buffer.ep_length[pos - 1] != 0).all()

# replay buffer should not have changed after truncate_last_trajectory (except dones[pos-1])
for key in ["observation", "desired_goal", "achieved_goal"]:
assert np.allclose(old_replay_buffer.observations[key], replay_buffer.observations[key])
assert np.allclose(old_replay_buffer.next_observations[key], replay_buffer.next_observations[key])
assert np.allclose(old_replay_buffer.actions, replay_buffer.actions)
assert np.allclose(old_replay_buffer.rewards, replay_buffer.rewards)
# we might change the last done of the last trajectory so we don't compare it
assert np.allclose(old_replay_buffer.dones[: pos - 1], replay_buffer.dones[: pos - 1])
assert np.allclose(old_replay_buffer.dones[pos:], replay_buffer.dones[pos:])

for _ in range(10):
actions = np.random.rand(n_envs, n_bits)
next_observations, rewards, dones, infos = venv.step(actions)
replay_buffer.add(observations, next_observations, actions, rewards, dones, infos)
observations = next_observations

# old observations must remain unchanged
for key in ["observation", "desired_goal", "achieved_goal"]:
assert np.allclose(old_replay_buffer.observations[key][:pos], replay_buffer.observations[key][:pos])
assert np.allclose(old_replay_buffer.next_observations[key][:pos], replay_buffer.next_observations[key][:pos])
assert np.allclose(old_replay_buffer.actions[:pos], replay_buffer.actions[:pos])
assert np.allclose(old_replay_buffer.rewards[:pos], replay_buffer.rewards[:pos])
assert np.allclose(old_replay_buffer.dones[: pos - 1], replay_buffer.dones[: pos - 1])

# new observations must differ from old observations
end_pos = replay_buffer.pos
for key in ["observation", "desired_goal", "achieved_goal"]:
assert not np.allclose(old_replay_buffer.observations[key][pos:end_pos], replay_buffer.observations[key][pos:end_pos])
assert not np.allclose(
old_replay_buffer.next_observations[key][pos:end_pos], replay_buffer.next_observations[key][pos:end_pos]
)
assert not np.allclose(old_replay_buffer.actions[pos:end_pos], replay_buffer.actions[pos:end_pos])
assert not np.allclose(old_replay_buffer.rewards[pos:end_pos], replay_buffer.rewards[pos:end_pos])
assert not np.allclose(old_replay_buffer.dones[pos - 1 : end_pos], replay_buffer.dones[pos - 1 : end_pos])

# all entries with index >= replay_buffer.pos must remain unchanged
for key in ["observation", "desired_goal", "achieved_goal"]:
assert np.allclose(old_replay_buffer.observations[key][end_pos:], replay_buffer.observations[key][end_pos:])
assert np.allclose(old_replay_buffer.next_observations[key][end_pos:], replay_buffer.next_observations[key][end_pos:])
assert np.allclose(old_replay_buffer.actions[end_pos:], replay_buffer.actions[end_pos:])
assert np.allclose(old_replay_buffer.rewards[end_pos:], replay_buffer.rewards[end_pos:])
assert np.allclose(old_replay_buffer.dones[end_pos:], replay_buffer.dones[end_pos:])


@pytest.mark.parametrize("n_bits", [10])
def test_performance_her(n_bits):
"""
