
Commit 411ff69
Ensure train/n_updates metric accounts for early stopping of training loop (#1311)

* Correct _n_updates when target_kl stops loop early

* Update changelog

* Simplify code

---------

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
adamfrly and araffin committed Feb 6, 2023
1 parent d0c1a87 commit 411ff69
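
The bug in one sentence: PPO may break out of its epoch loop early when the approximate KL divergence exceeds the target_kl threshold, but _n_updates was incremented by the full n_epochs after the loop, so the logged train/n_updates metric over-counted whenever early stopping fired. A minimal standalone sketch of the before/after counting (illustrative values, not SB3's actual train() method):

    # Minimal sketch of the counting bug (illustrative, not SB3's actual train()).
    n_epochs = 10    # PPO hyperparameter: passes over the rollout buffer
    stop_after = 4   # pretend the target_kl check fires during epoch 4

    n_updates = 0    # what the fixed code computes
    continue_training = True
    for epoch in range(n_epochs):
        # ... gradient steps over minibatches; SB3 sets continue_training = False
        # when the approximate KL divergence exceeds the target_kl threshold ...
        if epoch + 1 == stop_after:
            continue_training = False
        n_updates += 1  # after the fix: count each epoch that actually ran
        if not continue_training:
            break

    n_updates_before_fix = n_epochs  # the old code: always += n_epochs
    print(n_updates, n_updates_before_fix)  # -> 4 10

Note that the new increment sits inside the epoch loop but outside the minibatch loop, so _n_updates keeps counting epochs (as before), just only the ones that actually ran.
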
Showing 2 changed files with 2 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
@@ -26,6 +26,7 @@ New Features:
 Bug Fixes:
 ^^^^^^^^^^
 - Fixed Atari wrapper that missed the reset condition (@luizapozzobon)
+- Fixed PPO train/n_updates metric not accounting for early stopping (@adamfrly)
 
 Deprecations:
 ^^^^^^^^^^^^^
3 changes: 1 addition & 2 deletions stable_baselines3/ppo/ppo.py
@@ -189,7 +189,6 @@ def train(self) -> None:
         clip_fractions = []
 
         continue_training = True
-
         # train for n_epochs epochs
         for epoch in range(self.n_epochs):
             approx_kl_divs = []
@@ -271,10 +270,10 @@ def train(self) -> None:
                 th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                 self.policy.optimizer.step()
 
+            self._n_updates += 1
             if not continue_training:
                 break
 
-        self._n_updates += self.n_epochs
         explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
 
         # Logs
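
For context, continue_training is flipped by the target_kl check earlier in the same epoch loop, which this diff does not show. A rough paraphrase of that check, based on the surrounding SB3 code rather than on this commit:

    # Inside the minibatch loop (rough paraphrase, not shown in this diff):
    if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
        continue_training = False
        if self.verbose >= 1:
            print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
        break  # abandon the remaining minibatches of this epoch

With that in place, the new self._n_updates += 1 still counts the epoch in which early stopping fires (its minibatches partially ran) and then stops, instead of unconditionally crediting all n_epochs.
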
