Fixed traj_segment_generator ep_lens recording (#303)
* Fixed traj_segment_generator ep_lens recording

* Updated changelog.rst with changes.

* current_ep_len reset to 0 when an episode ends

* Update changelog.rst
GerardMaggiolino authored and araffin committed May 3, 2019
1 parent 333c593 commit f238a4c
Showing 2 changed files with 8 additions and 11 deletions.
docs/misc/changelog.rst: 3 changes (2 additions & 1 deletion)
@@ -27,6 +27,7 @@ Pre-Release 2.5.1a0 (WIP)
 - added option to not trim output of result plotter by number of timesteps (@Pastafarianist)
 - clarified the public interface of ``BasePolicy`` and ``ActorCriticPolicy``. **Breaking change** when using custom policies: ``masks_ph`` is now called ``dones_ph``.
 - support for custom stateful policies.
+- fixed episode length recording in ``trpo_mpi.utils.traj_segment_generator`` (@GerardMaggiolino)


 Release 2.5.0 (2019-03-28)
@@ -289,4 +290,4 @@ In random order...

 Thanks to @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk @JohannesAck
 @EliasHasle @mrakgr @Bleyddyn @antoine-galataud @junhyeokahn @AdamGleave @keshaviyengar @tperol
-@XMaster96 @kantneel @Pastafarianist
+@XMaster96 @kantneel @Pastafarianist @GerardMaggiolino
stable_baselines/trpo_mpi/utils.py: 16 changes (6 additions & 10 deletions)
@@ -36,6 +36,7 @@ def traj_segment_generator(policy, env, horizon, reward_giver=None, gail=False):

     cur_ep_ret = 0  # return in current episode
     current_it_len = 0  # len of current iteration
+    current_ep_len = 0  # len of current episode
     cur_ep_true_ret = 0
     ep_true_rets = []
     ep_rets = []  # returns of completed episodes in this segment
@@ -59,12 +60,6 @@ def traj_segment_generator(policy, env, horizon, reward_giver=None, gail=False):
         # before returning segment [0, T-1] so we get the correct
         # terminal value
         if step > 0 and step % horizon == 0:
-            # Fix to avoid "mean of empty slice" warning when there is only one episode
-            if len(ep_rets) == 0:
-                current_it_timesteps = current_it_len
-            else:
-                current_it_timesteps = sum(ep_lens) + current_it_len
-
             yield {
                 "ob": observations,
                 "rew": rews,
@@ -77,15 +72,15 @@ def traj_segment_generator(policy, env, horizon, reward_giver=None, gail=False):
                 "ep_rets": ep_rets,
                 "ep_lens": ep_lens,
                 "ep_true_rets": ep_true_rets,
-                "total_timestep": current_it_timesteps
+                "total_timestep": current_it_len
             }
             _, vpred, _, _ = policy.step(observation.reshape(-1, *observation.shape))
             # Be careful!!! if you change the downstream algorithm to aggregate
             # several of these batches, then be sure to do a deepcopy
             ep_rets = []
             ep_true_rets = []
             ep_lens = []
-            # make sure current_it_timesteps increments correctly
+            # Reset current iteration length
             current_it_len = 0
         i = step % horizon
         observations[i] = observation
@@ -111,13 +106,14 @@ def traj_segment_generator(policy, env, horizon, reward_giver=None, gail=False):
         cur_ep_ret += rew
         cur_ep_true_ret += true_rew
         current_it_len += 1
+        current_ep_len += 1
         if done:
             ep_rets.append(cur_ep_ret)
             ep_true_rets.append(cur_ep_true_ret)
-            ep_lens.append(current_it_len)
+            ep_lens.append(current_ep_len)
             cur_ep_ret = 0
             cur_ep_true_ret = 0
-            current_it_len = 0
+            current_ep_len = 0
             if not isinstance(env, VecEnv):
                 observation = env.reset()
         step += 1
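The gist of the fix: ``current_it_len`` is reset at every horizon boundary (each time a segment is yielded), so using it to record episode lengths under-counts any episode that straddles a boundary. The commit adds a dedicated ``current_ep_len`` that is only reset when an episode ends. Below is a minimal standalone sketch of both behaviours, assuming a horizon of 100 steps and a single 150-step episode (illustrative numbers, not part of the commit):

# Pre-fix sketch: one shared counter, reset at every yield boundary.
horizon = 100
episode_length = 150  # hypothetical episode straddling one horizon boundary

counter = 0
for step in range(episode_length):
    if step > 0 and step % horizon == 0:
        counter = 0  # reset when the segment is yielded
    counter += 1
print(counter)  # 50: the value appended to ep_lens (too short)

# Post-fix sketch: a per-episode counter that only resets on episode end.
current_it_len = 0
current_ep_len = 0
for step in range(episode_length):
    if step > 0 and step % horizon == 0:
        current_it_len = 0  # per-iteration counter still resets at the yield
    current_it_len += 1
    current_ep_len += 1
print(current_ep_len)  # 150: the true episode length (correct)

In the diff above, ``current_it_len`` keeps feeding ``total_timestep`` while the new ``current_ep_len`` feeds ``ep_lens``, so the two bookkeeping concerns no longer share one counter.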
