Merge pull request #216 from Huizerd/master
Better scaling for reward plot + axis labels
djsaunde committed Mar 20, 2019
2 parents b9318d4 + f439111 commit b3437de
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions bindsnet/pipeline/__init__.py
@@ -2,6 +2,7 @@
 import torch
 import matplotlib.pyplot as plt
 import numpy as np
+import pandas as pd
 
 from typing import Callable, Optional

@@ -52,6 +53,7 @@ def __init__(self, network: Network, environment: Environment, encoding: Callabl
         :param str output: String name of the layer from which to take output from.
         :param float plot_length: Relative time length of the plotted record data. Relative to parameter time.
         :param str plot_type: Type of plotting ('color' or 'line').
+        :param int reward_window: Moving average window for the reward plot.
         :param int reward_delay: How many iterations to delay delivery of reward.
         """
         self.network = network
@@ -82,6 +84,7 @@ def __init__(self, network: Network, environment: Environment, encoding: Callabl
         self.render_interval = kwargs.get('render_interval', None)
         self.plot_length = kwargs.get('plot_length', 1.0)
         self.plot_type = kwargs.get('plot_type', 'color')
+        self.reward_window = kwargs.get('reward_window', None)
         self.reward_delay = kwargs.get('reward_delay', None)
 
         self.dt = network.dt
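For orientation, a minimal usage sketch of the new reward_window kwarg. Everything below beyond the kwargs visible in the hunks above (the positional arguments, the network and environment setup) is an assumption for illustration, not taken from this diff:

# Hypothetical sketch -- Pipeline's full signature is not shown in this diff.
from bindsnet.network import Network
from bindsnet.environment import GymEnvironment
from bindsnet.pipeline import Pipeline

network = Network(dt=1.0)                    # assumed minimal network
environment = GymEnvironment('CartPole-v0')  # assumed Gym task

pipeline = Pipeline(
    network,
    environment,
    plot_type='line',
    reward_window=10,   # new kwarg: plot a 10-episode moving average
    reward_delay=None,
)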
@@ -223,19 +226,28 @@ def plot_obs(self) -> None:
     def plot_reward(self) -> None:
         # language=rst
         """
-        Plot the change of accumulated reward for each episodes
+        Plot the accumulated reward for each episode.
         """
+        # Compute moving average of rewards.
+        if self.reward_window is not None:
+            # Clamp window size to the length of the reward list.
+            window = max(min(len(self.reward_list), self.reward_window), 0)
+
+            # Rolling mean via pandas; min_periods=1 avoids NaNs at the start.
+            reward_list_ = pd.Series(self.reward_list).rolling(window=window, min_periods=1).mean().values
+        else:
+            reward_list_ = self.reward_list[:]
+
         if self.reward_im is None and self.reward_ax is None:
-            fig, self.reward_ax = plt.subplots()
-            self.reward_ax.set_title('Reward')
-            self.reward_plot, = self.reward_ax.plot(self.reward_list)
+            self.reward_im, self.reward_ax = plt.subplots()
+            self.reward_ax.set_title('Accumulated reward')
+            self.reward_ax.set_xlabel('Episode')
+            self.reward_ax.set_ylabel('Reward')
+            self.reward_plot, = self.reward_ax.plot(reward_list_)
         else:
-            reward_array = np.array(self.reward_list)
-            y_min = reward_array.min()
-            y_max = reward_array.max()
-            self.reward_ax.set_xlim(left=0, right=self.episode)
-            self.reward_ax.set_ylim(bottom=y_min, top=y_max)
-            self.reward_plot.set_data(range(self.episode), self.reward_list)
+            self.reward_plot.set_data(range(self.episode), reward_list_)
+            self.reward_ax.relim()
+            self.reward_ax.autoscale_view()
 
     def plot_data(self) -> None:
         # language=rst
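The smoothing above is a plain pandas rolling mean: min_periods=1 makes early episodes average over however many rewards exist so far instead of producing NaN, and the relim()/autoscale_view() pair lets matplotlib recompute axis limits from the updated line, replacing the manual set_xlim/set_ylim bookkeeping that was removed. A self-contained illustration of the same call on made-up rewards (note the diff clamps the window with a floor of 0, which some pandas versions reject on an empty reward list; a floor of 1 is used here):

import pandas as pd

rewards = [0.0, 10.0, 5.0, 20.0, 15.0, 30.0]  # made-up per-episode rewards
window = max(min(len(rewards), 3), 1)         # clamp to list length, floor of 1

# min_periods=1: the first points average over all values seen so far,
# rather than yielding NaN until a full window is available.
smoothed = pd.Series(rewards).rolling(window=window, min_periods=1).mean().values
print(smoothed)  # [0.  5.  5.  11.67  13.33  21.67] (approx.)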
