Skip to content

Commit

Permalink
Refactor and clarify doc for load_results (#583)
Browse files Browse the repository at this point in the history
* Refactor and clarify doc for load_results (#582)

* Updates from PR feedback

Added unit test for load_results and get_monitor_files
Add @jbulow in changelog

* Update tests/test_monitor.py

Co-Authored-By: Antonin RAFFIN <antonin.raffin@ensta.org>

* Update tests/test_monitor.py

Co-Authored-By: Antonin RAFFIN <antonin.raffin@ensta.org>

* Updates from PR feedback

* Updates from PR feedback

* Updates from PR feedback

Convert path object to string to pass pytype's type check.
  • Loading branch information
jbulow authored and araffin committed Nov 27, 2019
1 parent 05c5717 commit b461adb
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 4 deletions.
4 changes: 3 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ Others:
- Add upper bound for Tensorflow version (<2.0.0).
- Refactored test to remove duplicated code
- Add pull request template
- Replaced redundant code in load_results (@jbulow)

Documentation:
^^^^^^^^^^^^^^
Expand All @@ -62,6 +63,7 @@ Documentation:
- Fix multiprocessing example (@rusu24edward)
- Fix `result_plotter` example
- Fix typo in algos.rst, "containes" to "contains" (@SyllogismRXS)
- Fix outdated source documentation for load_results

Release 2.8.0 (2019-09-29)
--------------------------
Expand Down Expand Up @@ -542,4 +544,4 @@ Thanks to @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk
@EliasHasle @mrakgr @Bleyddyn @antoine-galataud @junhyeokahn @AdamGleave @keshaviyengar @tperol
@XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow
6 changes: 3 additions & 3 deletions stable_baselines/bench/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,13 @@ def get_monitor_files(path):

def load_results(path):
"""
Load results from a given file
Load all Monitor logs from a given directory path matching ``*monitor.csv`` and ``*monitor.json``
:param path: (str) the path to the log file
:param path: (str) the directory path containing the log file(s)
:return: (Pandas DataFrame) the logged data
"""
# get both csv and (old) json files
monitor_files = (glob(os.path.join(path, "*monitor.json")) + glob(os.path.join(path, "*monitor.csv")))
monitor_files = (glob(os.path.join(path, "*monitor.json")) + get_monitor_files(path))
if not monitor_files:
raise LoadMonitorResultsError("no monitor files of the form *%s found in %s" % (Monitor.EXT, path))
data_frames = []
Expand Down
50 changes: 50 additions & 0 deletions tests/test_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import gym

from stable_baselines.bench import Monitor
from stable_baselines.bench.monitor import get_monitor_files, load_results


def test_monitor():
Expand Down Expand Up @@ -34,3 +35,52 @@ def test_monitor():
assert set(last_logline.keys()) == {'l', 't', 'r'}, "Incorrect keys in monitor logline"
file_handler.close()
os.remove(mon_file)

def test_monitor_load_results(tmp_path):
"""
test load_results on log files produced by the monitor wrapper
"""
tmp_path = str(tmp_path)
env1 = gym.make("CartPole-v1")
env1.seed(0)
monitor_file1 = os.path.join(tmp_path, "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4()))
monitor_env1 = Monitor(env1, monitor_file1)

monitor_files = get_monitor_files(tmp_path)
assert len(monitor_files) == 1
assert monitor_file1 in monitor_files

monitor_env1.reset()
episode_count1 = 0
for _ in range(1000):
_, _, done, _ = monitor_env1.step(monitor_env1.action_space.sample())
if done:
episode_count1 += 1
monitor_env1.reset()

results_size1 = len(load_results(os.path.join(tmp_path)).index)
assert results_size1 == episode_count1

env2 = gym.make("CartPole-v1")
env2.seed(0)
monitor_file2 = os.path.join(tmp_path, "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4()))
monitor_env2 = Monitor(env2, monitor_file2)
monitor_files = get_monitor_files(tmp_path)
assert len(monitor_files) == 2
assert monitor_file1 in monitor_files
assert monitor_file2 in monitor_files

monitor_env2.reset()
episode_count2 = 0
for _ in range(1000):
_, _, done, _ = monitor_env2.step(monitor_env2.action_space.sample())
if done:
episode_count2 += 1
monitor_env2.reset()

results_size2 = len(load_results(os.path.join(tmp_path)).index)

assert results_size2 == (results_size1 + episode_count2)

os.remove(monitor_file1)
os.remove(monitor_file2)

0 comments on commit b461adb

Please sign in to comment.