Fix PPO2 tensorboard duplicate entry (#822)
* Fix #81 on PPO2 logging system.

* Update changelog.rst

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Enderdead and araffin committed Apr 28, 2020
1 parent d699d6f commit a57c80e
Showing 2 changed files with 9 additions and 6 deletions.
docs/misc/changelog.rst: 3 changes (3 additions, 0 deletions)
@@ -25,6 +25,7 @@ Bug Fixes:
 - Added ``**kwarg`` pass through for ``reset`` method in ``atari_wrappers.FrameStack`` (@solliet)
 - Fix consistency in ``setup_model()`` for SAC, ``target_entropy`` now uses ``self.action_space`` instead of ``self.env.action_space`` (@solliet)
 - Fix reward threshold in ``test_identity.py``
+- Partially fix tensorboard indexing for PPO2 (@enderdead)
 
 Deprecations:
 ^^^^^^^^^^^^^
@@ -693,3 +694,5 @@ Thanks to @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk
 @Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
 @MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
 @flodorner @KuKuXia @NeoExtended @solliet @mmcenta @richardwu @tirafesi @caburu @johannes-dornheim @kvenkman @aakash94
+@enderdead
+
stable_baselines/ppo2/ppo2.py: 12 changes (6 additions, 6 deletions)
@@ -275,9 +275,9 @@ def _train_step(self, learning_rate, cliprange, obs, returns, masks, actions, va
             td_map[self.clip_range_vf_ph] = cliprange_vf
 
         if states is None:
-            update_fac = self.n_batch // self.nminibatches // self.noptepochs + 1
+            update_fac = max(self.n_batch // self.nminibatches // self.noptepochs, 1)
         else:
-            update_fac = self.n_batch // self.nminibatches // self.noptepochs // self.n_steps + 1
+            update_fac = max(self.n_batch // self.nminibatches // self.noptepochs // self.n_steps, 1)
 
         if writer is not None:
             # run loss backprop with summary, but once every 10 runs save the metadata (memory, compute time, ...)
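Why max(..., 1) instead of + 1: the divisor is meant to be the number of environment steps consumed per minibatch update, and the old "+ 1" made it one too large (while the new form also guards against a zero divisor for very small batches). A minimal arithmetic sketch, using assumed hyperparameter values rather than anything from the repository:

    # Illustrative only; the real attributes live on the PPO2 instance.
    n_batch, nminibatches, noptepochs = 128, 4, 4

    steps_per_minibatch_old = n_batch // nminibatches // noptepochs + 1      # 9: off by one
    steps_per_minibatch_new = max(n_batch // nminibatches // noptepochs, 1)  # 8: exact, never 0

    # 128 environment steps feed nminibatches * noptepochs = 16 minibatch updates,
    # i.e. 8 steps each, so only the new expression matches the data actually consumed.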
@@ -346,28 +346,28 @@ def learn(self, total_timesteps, callback=None, log_interval=1, tb_log_name="PPO
                 self.ep_info_buf.extend(ep_infos)
                 mb_loss_vals = []
                 if states is None:  # nonrecurrent version
-                    update_fac = self.n_batch // self.nminibatches // self.noptepochs + 1
+                    update_fac = max(self.n_batch // self.nminibatches // self.noptepochs, 1)
                     inds = np.arange(self.n_batch)
                     for epoch_num in range(self.noptepochs):
                         np.random.shuffle(inds)
                         for start in range(0, self.n_batch, batch_size):
-                            timestep = self.num_timesteps // update_fac + ((self.noptepochs * self.n_batch + epoch_num *
+                            timestep = self.num_timesteps // update_fac + ((epoch_num *
                                                                             self.n_batch + start) // batch_size)
                             end = start + batch_size
                             mbinds = inds[start:end]
                             slices = (arr[mbinds] for arr in (obs, returns, masks, actions, values, neglogpacs))
                             mb_loss_vals.append(self._train_step(lr_now, cliprange_now, *slices, writer=writer,
                                                                  update=timestep, cliprange_vf=cliprange_vf_now))
                 else:  # recurrent version
-                    update_fac = self.n_batch // self.nminibatches // self.noptepochs // self.n_steps + 1
+                    update_fac = max(self.n_batch // self.nminibatches // self.noptepochs // self.n_steps, 1)
                     assert self.n_envs % self.nminibatches == 0
                     env_indices = np.arange(self.n_envs)
                     flat_indices = np.arange(self.n_envs * self.n_steps).reshape(self.n_envs, self.n_steps)
                     envs_per_batch = batch_size // self.n_steps
                     for epoch_num in range(self.noptepochs):
                         np.random.shuffle(env_indices)
                         for start in range(0, self.n_envs, envs_per_batch):
-                            timestep = self.num_timesteps // update_fac + ((self.noptepochs * self.n_envs + epoch_num *
+                            timestep = self.num_timesteps // update_fac + ((epoch_num *
                                                                             self.n_envs + start) // envs_per_batch)
                             end = start + envs_per_batch
                             mb_env_inds = env_indices[start:end]
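The sketch below (not from the repository) enumerates the tensorboard step produced for every minibatch across two consecutive iterations, under the same assumed hyperparameters and assuming num_timesteps has already advanced by n_batch per rollout. It shows why the removed self.noptepochs * self.n_batch offset, combined with the "+ 1" divisor, made steps from one iteration collide with the next, producing the duplicate entries this commit targets:

    # Standalone sketch with assumed hyperparameters; names mirror the PPO2
    # attributes, but nothing here imports stable_baselines.
    n_batch, nminibatches, noptepochs = 128, 4, 4
    batch_size = n_batch // nminibatches

    def tb_steps(update_fac, extra_offset, num_timesteps):
        # Tensorboard step for every minibatch of one learn() iteration.
        return [num_timesteps // update_fac
                + (extra_offset + epoch_num * n_batch + start) // batch_size
                for epoch_num in range(noptepochs)
                for start in range(0, n_batch, batch_size)]

    old_fac = n_batch // nminibatches // noptepochs + 1      # 9
    new_fac = max(n_batch // nminibatches // noptepochs, 1)  # 8

    for it in range(2):
        ts = (it + 1) * n_batch  # assumed num_timesteps after each rollout
        print("old:", tb_steps(old_fac, noptepochs * n_batch, ts))
        print("new:", tb_steps(new_fac, 0, ts))

    # old: iteration 1 ends at steps 44-45 and iteration 2 starts at 44 -> duplicates.
    # new: steps 16..31, then 32..47 -> every minibatch gets a unique, contiguous step.

With the new formula the base index advances by exactly nminibatches * noptepochs per iteration, which is the number of minibatch updates performed, so the ranges tile without overlap.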
