Skip to content

Commit

Permalink
Update log_interval ppo2 (#73)
Browse files Browse the repository at this point in the history
The previous implementation does not comply with the objective.
Additionally, this makes it more consistent with the other algorithms, although A2C is still a bit different.
  • Loading branch information
huvar authored and araffin committed Nov 4, 2018
1 parent 7c95b74 commit 6776f53
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions stable_baselines/ppo2/ppo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def _train_step(self, learning_rate, cliprange, obs, returns, masks, actions, va

return policy_loss, value_loss, policy_entropy, approxkl, clipfrac

def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="PPO2"):
def learn(self, total_timesteps, callback=None, seed=None, log_interval=1, tb_log_name="PPO2"):
with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name) as writer:
self._setup_learn(seed)

Expand Down Expand Up @@ -319,7 +319,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_
if callback is not None:
callback(locals(), globals())

if self.verbose >= 1 and ((update + 1) % log_interval//100 == 0 or update == 0):
if self.verbose >= 1 and ((update + 1) % log_interval == 0 or update == 0):
explained_var = explained_variance(values, returns)
logger.logkv("serial_timesteps", (update + 1) * self.n_steps)
logger.logkv("nupdates", (update + 1))
Expand Down

0 comments on commit 6776f53

Please sign in to comment.