Bug fix in TRPO (logger with zero gradient) (#496)
* Update trpo_mpi.py

The bug is still there. I don't know how it works for you, but in my case, if I get a zero gradient then `mean_losses` is never defined, so it cannot be zipped in the logger for loop. The program always throws an exception and exits.
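A minimal sketch of the failure mode described above (simplified stand-in functions, not the actual `trpo_mpi.py` code; `loss_names` and the `logger.record_tabular`-style callback are assumptions based on the diff). In TRPO's update loop, `mean_losses` is only assigned on the branch taken when the policy gradient is non-zero; the logging loop used to run unconditionally afterwards, so a zero gradient left `mean_losses` undefined and the `zip` raised a `NameError`:

```python
import numpy as np

LOSS_NAMES = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

def log_losses_buggy(grad, compute_losses, record):
    """Pre-fix shape: the logging loop runs even when no update happened."""
    if np.allclose(grad, 0):
        print("Got zero gradient. not updating")
    else:
        mean_losses = compute_losses()
    # Raises NameError when the gradient was zero, because `mean_losses`
    # was never assigned on that branch.
    for loss_name, loss_val in zip(LOSS_NAMES, mean_losses):
        record(loss_name, loss_val)

def log_losses_fixed(grad, compute_losses, record):
    """Post-fix shape: logging moves inside the branch where
    `mean_losses` is actually defined, mirroring the commit's change."""
    if np.allclose(grad, 0):
        print("Got zero gradient. not updating")
    else:
        mean_losses = compute_losses()
        for loss_name, loss_val in zip(LOSS_NAMES, mean_losses):
            record(loss_name, loss_val)
```

With a zero gradient, the buggy version crashes and the fixed version simply logs nothing for that iteration, which matches the intent: there are no loss values to report when no update was performed.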

* Update changelog.rst

* Update changelog.rst

* Update changelog.rst
MarvineGothic authored and araffin committed Oct 30, 2019
1 parent dae212b commit 11e744e
Showing 2 changed files with 6 additions and 4 deletions.
4 changes: 3 additions & 1 deletion docs/misc/changelog.rst
@@ -22,6 +22,7 @@ Bug Fixes:
^^^^^^^^^^
- Fix seeding, so it is now possible to have deterministic results on cpu
- Fix a bug in DDPG where `predict` method with `deterministic=False` would fail
- Fix a bug in TRPO: `mean_losses` was not initialized, causing the logger to crash when there were no gradients (@MarvineGothic)

Deprecations:
^^^^^^^^^^^^^
@@ -518,4 +519,5 @@ In random order...
Thanks to @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk @JohannesAck
@EliasHasle @mrakgr @Bleyddyn @antoine-galataud @junhyeokahn @AdamGleave @keshaviyengar @tperol
@XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
@MarvineGothic
6 changes: 3 additions & 3 deletions stable_baselines/trpo_mpi/trpo_mpi.py
@@ -413,6 +413,9 @@ def fisher_vector_product(vec):
# list of tuples
paramsums = MPI.COMM_WORLD.allgather((thnew.sum(), self.vfadam.getflat().sum()))
assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

for (loss_name, loss_val) in zip(self.loss_names, mean_losses):
logger.record_tabular(loss_name, loss_val)

with self.timed("vf"):
for _ in range(self.vf_iters):
@@ -424,9 +427,6 @@ def fisher_vector_product(vec):
grad = self.allmean(self.compute_vflossandgrad(mbob, mbob, mbret, sess=self.sess))
self.vfadam.update(grad, self.vf_stepsize)

for (loss_name, loss_val) in zip(self.loss_names, mean_losses):
logger.record_tabular(loss_name, loss_val)

logger.record_tabular("explained_variance_tdlam_before",
explained_variance(vpredbefore, tdlamret))

