Skip to content

Commit

Permalink
Improve results plotter (#293)
Browse files Browse the repository at this point in the history
* Do not attempt to plot curves shorter than window

Otherwise numpy crashes with ValueError.

* Allow not trimming timesteps in results_plotter

* Update changelog for #293

* Update changelog.rst
  • Loading branch information
Pastafarianist authored and araffin committed Apr 28, 2019
1 parent 0693426 commit 0eac3f5
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
4 changes: 3 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ Pre-Release 2.5.1a0 (WIP)
``set_attr`` now returns ``None`` rather than a list of ``None``. (@kantneel)
- ``GAIL``: ``gail.dataset.ExpertDataset` supports loading from memory rather than file, and
``gail.dataset.record_expert`` supports returning in-memory rather than saving to file.
- fixed bug where result plotter would crash on very short runs (@Pastafarianist)
- added option to not trim output of result plotter by number of timesteps (@Pastafarianist)


Release 2.5.0 (2019-03-28)
Expand Down Expand Up @@ -282,4 +284,4 @@ In random order...

Thanks to @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk @JohannesAck
@EliasHasle @mrakgr @Bleyddyn @antoine-galataud @junhyeokahn @AdamGleave @keshaviyengar @tperol
@XMaster96 @kantneel
@XMaster96 @kantneel @Pastafarianist
12 changes: 8 additions & 4 deletions stable_baselines/results_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,11 @@ def plot_curves(xy_list, xaxis, title):
for (i, (x, y)) in enumerate(xy_list):
color = COLORS[i]
plt.scatter(x, y, s=2)
x, y_mean = window_func(x, y, EPISODES_WINDOW, np.mean) # So returns average of last EPISODE_WINDOW episodes
plt.plot(x, y_mean, color=color)
# Do not plot the smoothed curve at all if the timeseries is shorter than window size.
if x.shape[0] >= EPISODES_WINDOW:
# Compute and plot rolling mean with window of size EPISODE_WINDOW
x, y_mean = window_func(x, y, EPISODES_WINDOW, np.mean)
plt.plot(x, y_mean, color=color)
plt.xlim(minx, maxx)
plt.title(title)
plt.xlabel(xaxis)
Expand All @@ -98,7 +101,7 @@ def plot_results(dirs, num_timesteps, xaxis, task_name):
plot the results
:param dirs: ([str]) the save location of the results to plot
:param num_timesteps: (int) only plot the points below this value
:param num_timesteps: (int or None) only plot the points below this value
:param xaxis: (str) the axis for the x and y output
(can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
:param task_name: (str) the title of the task to plot
Expand All @@ -107,7 +110,8 @@ def plot_results(dirs, num_timesteps, xaxis, task_name):
tslist = []
for folder in dirs:
timesteps = load_results(folder)
timesteps = timesteps[timesteps.l.cumsum() <= num_timesteps]
if num_timesteps is not None:
timesteps = timesteps[timesteps.l.cumsum() <= num_timesteps]
tslist.append(timesteps)
xy_list = [ts2xy(timesteps_item, xaxis) for timesteps_item in tslist]
plot_curves(xy_list, xaxis, task_name)
Expand Down

0 comments on commit 0eac3f5

Please sign in to comment.