Skip to content

Commit

Permalink
Fixed path splitting in _get_latest_run_id() on Windows machines (#318)
Browse files Browse the repository at this point in the history
* Fixed path splitting in _get_latest_run_id() on Windows machines

* Returned to previous split method, replaced split delimiter with os.sep in _get_latest_run_id function. Wrote test for saving tensorboard data twice with the same logname

* Fixed tests for saving tensorboard twice with same logname

* Updated tensorboard tests

* Updated tensorboard tests. Added name and fix to changelog

* Update test_tensorboard.py

* Update test_tensorboard.py
  • Loading branch information
PatrickWalter214 authored and araffin committed May 11, 2019
1 parent bea2eed commit a13e12c
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Release 2.5.2a0 (WIP)

- Bugfix for ``VecEnvWrapper.__getattr__`` which enables access to class attributes inherited from parent classes.
- Removed ``get_available_gpus`` function which hadn't been used anywhere (@Pastafarianist)
- Fixed path splitting in ``TensorboardWriter._get_latest_run_id()`` on Windows machines (@PatrickWalter214)


Release 2.5.1 (2019-05-04)
Expand Down Expand Up @@ -299,4 +300,4 @@ In random order...

Thanks to @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk @JohannesAck
@EliasHasle @mrakgr @Bleyddyn @antoine-galataud @junhyeokahn @AdamGleave @keshaviyengar @tperol
@XMaster96 @kantneel @Pastafarianist @GerardMaggiolino
@XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214
4 changes: 2 additions & 2 deletions stable_baselines/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,8 +701,8 @@ def _get_latest_run_id(self):
:return: (int) latest run number
"""
max_run_id = 0
for path in glob.glob(self.tensorboard_log_path + "/{}_[0-9]*".format(self.tb_log_name)):
file_name = path.split("/")[-1]
for path in glob.glob("{}/{}_[0-9]*".format(self.tensorboard_log_path, self.tb_log_name)):
file_name = path.split(os.sep)[-1]
ext = file_name.split("_")[-1]
if self.tb_log_name == "_".join(file_name.split("_")[:-1]) and ext.isdigit() and int(ext) > max_run_id:
max_run_id = int(ext)
Expand Down
18 changes: 18 additions & 0 deletions tests/test_tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,25 @@

@pytest.mark.parametrize("model_name", MODEL_DICT.keys())
def test_tensorboard(model_name):
logname = model_name.upper()
algo, env_id = MODEL_DICT[model_name]
model = algo('MlpPolicy', env_id, verbose=1, tensorboard_log=TENSORBOARD_DIR)
model.learn(N_STEPS)
model.learn(N_STEPS, reset_num_timesteps=False)

assert os.path.isdir(TENSORBOARD_DIR + logname + "_1")
assert not os.path.isdir(TENSORBOARD_DIR + logname + "_2")

@pytest.mark.parametrize("model_name", MODEL_DICT.keys())
def test_multiple_runs(model_name):
logname = "tb_multiple_runs_" + model_name
algo, env_id = MODEL_DICT[model_name]
model = algo('MlpPolicy', env_id, verbose=1, tensorboard_log=TENSORBOARD_DIR)
model.learn(N_STEPS, tb_log_name=logname)
model.learn(N_STEPS, tb_log_name=logname)

assert os.path.isdir(TENSORBOARD_DIR + logname + "_1")
# Check that the log dir name increments correctly
assert os.path.isdir(TENSORBOARD_DIR + logname + "_2")


0 comments on commit a13e12c

Please sign in to comment.