Skip to content

Commit

Permalink
Merge 75cac18 into 66c6412
Browse files Browse the repository at this point in the history
  • Loading branch information
cash committed May 10, 2018
2 parents 66c6412 + 75cac18 commit 2407ec4
Show file tree
Hide file tree
Showing 4 changed files with 241 additions and 4 deletions.
112 changes: 111 additions & 1 deletion dworp/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
# Distributed under the terms of the Modified BSD License.

import matplotlib.pyplot as plt
from .observer import PauseObserver
import matplotlib.ticker
from .observer import Observer, PauseObserver

# Note: do not include this in __init__.py so that dworp does not have a hard requirement
# for matplotlib.
Expand All @@ -24,3 +25,112 @@ def __init__(self, delay, start=False, stop=False):

def pause(self):
plt.pause(self.delay)


class VariablePlotter(Observer): # pragma: no cover
"""Plot one or more variables from the Environment
Args:
var (string, list): Name or list of names of variable in Environment to plot
fmt (string, list): Optional matplotlib format string or list of strings (default is "b")
scrolling (int): Optional number of time steps in scroll or 0 for no scrolling
title(string): Optional figure title (default is name of variable)
xlabel(string): Optional x-axis label (default is "Time")
ylabel(string): Optional y-axis label (default is name of variable)
xlim(list): Optional starting x-axis limits as [xmin, xmax]
ylim(list): Optional starting y-axis limits as [ymin, ymax]
legend(bool, string): A string location for legend or False (default is False)
pause(float): Optional pause between updates (must be > 0)
"""
def __init__(self, var, fmt="b", scrolling=0, title=None, xlabel="Time", ylabel=None,
xlim=None, ylim=None, legend=False, pause=0.001):
self.var_names = [var] if isinstance(var, str) else var
self.fmt = self._prepare_format_option(fmt)
self.scrolling = scrolling
self.title = title if title else self._prepare_default_title()
self.xlabel = xlabel
self.ylabel = ylabel if ylabel else self._prepare_default_title()
self.xlim = xlim
self.ylim = ylim
self.legend = legend
self.pause = pause
self.fig = None
self.axes_margin = 0.01

self.time = []
self.data = {name: [] for name in self.var_names}

def _prepare_format_option(self, fmt):
# fmt must be either same length as var_names or length 1
format_ = [fmt] if isinstance(fmt, str) else fmt
assert(len(format_) == 1 or len(format_) == len(self.var_names))
if len(format_) != len(self.var_names):
format_ = format_ * len(self.var_names)
return dict(zip(self.var_names, format_))

def _prepare_default_title(self):
if len(self.var_names) == 0:
return self.var_names[0]
else:
return ' & '.join(self.var_names)

def start(self, now, agents, env):
self.prepare()
self.update(now, agents, env)

def step(self, now, agents, env):
self.update(now, agents, env)

def stop(self, now, agents, env):
plt.close(self.fig)

def plot(self, now, agents, env):
self.time.append(now)
for name in self.var_names:
self.data[name].append(getattr(env, name))
if self.scrolling:
plot_time = self.time[(-1 * self.scrolling):]
plot_data = {name: data[(-1 * self.scrolling):] for name, data in self.data.items()}
else:
plot_time = self.time
plot_data = self.data

for name, data in plot_data.items():
plt.plot(plot_time, data, self.fmt[name], label=name)
if self.legend:
plt.legend(loc=self.legend)
axes = self.fig.axes[0]
axes.set_xlabel(self.xlabel)
axes.set_ylabel(self.ylabel)
self.set_axes_limits(axes)

def set_axes_limits(self, axes):
if self.ylim:
ylim = axes.get_ylim()
margin = self.axes_margin * abs(ylim[1] - ylim[0])
ymin = min(self.ylim[0], min(min(self.data.values())) - margin)
ymax = max(self.ylim[1], max(max(self.data.values())) + margin)
axes.set_ylim([ymin, ymax])
if self.xlim:
xmin = min(self.xlim[0], min(self.time))
xmax = max(self.xlim[1], max(self.time))
axes.set_xlim([xmin, xmax])
axes.xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True))

def prepare(self):
# turn interactive mode on and create an empty figure
plt.ion()
self.fig = plt.figure()
self.fig.canvas.set_window_title(self.title)

def update(self, now, agents, env):
# clear figure, create new plot, and update figure
plt.clf()
self.plot(now, agents, env)
plt.draw()
# pause to give time for matplotlib to update figure
plt.pause(self.pause)
# if figure is closed, terminate
figures = plt.get_fignums()
if not figures:
quit()
36 changes: 33 additions & 3 deletions examples/birth_rates.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import argparse
import collections
import dworp
import dworp.plot
import logging
import math
import numpy as np
Expand All @@ -25,6 +27,24 @@ def step(self, now, env):
return [Person(self.color, self.fertility, self.rng) for x in range(num_children)]


class BirthEnvironment(dworp.Environment):
def __init__(self, data):
super().__init__(0)
self.data = collections.Counter()
self.red_count = data['red']
self.blue_count = data['blue']

def step(self, now, agents):
pass

def complete(self, now, agents):
self.data.clear()
for agent in agents:
self.data.update({agent.color: 1})
self.red_count = self.data['red']
self.blue_count = self.data['blue']


class BirthObserver(dworp.Observer):
"""Writes simulation state to stdout"""
def start(self, time, agents, env):
Expand Down Expand Up @@ -63,7 +83,6 @@ class BirthSimulation(dworp.BasicSimulation):
def __init__(self, params, observer):
self.params = params
self.rng = np.random.RandomState(params.seed)
env = dworp.NullEnvironment()
time = dworp.InfiniteTime()
scheduler = dworp.BasicScheduler()
terminator = BirthTerminator()
Expand All @@ -74,6 +93,8 @@ def __init__(self, params, observer):
blue = [Person('blue', params.blue_fertility, self.rng) for x in range(num_people)]
people = red + blue

env = BirthEnvironment({'red': len(red), 'blue': len(blue)})

super().__init__(people, env, time, scheduler, observer, terminator)

def run(self):
Expand All @@ -88,6 +109,7 @@ def run(self):
self.agents.extend(new_people)
# death
self.reap()
self.env.complete(current_time, self.agents)
self.observer.step(current_time, self.agents, self.env)
if self.terminator.test(current_time, self.agents, self.env):
break
Expand Down Expand Up @@ -121,8 +143,16 @@ def reap(self):
assert(1 <= args.capacity <= 4000)
assert(0 <= args.red <= 10)
assert(0 <= args.blue <= 10)
params = BirthParams(args.capacity, args.red, args. blue, args.seed)
params = BirthParams(args.capacity, args.red, args.blue, args.seed)

# create and run one realization of the simulation
sim = BirthSimulation(params, BirthObserver())
observer = dworp.ChainedObserver(
BirthObserver(),
dworp.PauseAtEndObserver(3),
dworp.plot.VariablePlotter(['red_count', 'blue_count'], ['r', 'b'],
title="Birth and Death Sim",
ylim=[300, 700], xlim=[0, 20],
xlabel='Generations', ylabel='Population')
)
sim = BirthSimulation(params, observer)
sim.run()
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[nosetests]
attr=!interactive
95 changes: 95 additions & 0 deletions tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
from dworp.plot import *
import unittest
import unittest.mock as mock
import matplotlib
import matplotlib.pyplot as plt
import random
import time


class PlotPauseObserverTest(unittest.TestCase):
Expand All @@ -23,3 +26,95 @@ def test_pause(self):
obs = PlotPauseObserver(1)
obs.step(0, [], None)
self.assertTrue(plt.pause.called)


class NonInteractiveVariablePlotterTest(unittest.TestCase):
# let developers filter the plotting tests
plot = True

def test_fmt_string_with_var_list(self):
plotter = VariablePlotter(['data1', 'data2'])
self.assertEqual({'data1': 'b', 'data2': 'b'}, plotter.fmt)

def test_fmt_list_with_var_list(self):
plotter = VariablePlotter(var=['data1', 'data2'], fmt=['r', 'g'])
self.assertEqual({'data1': 'r', 'data2': 'g'}, plotter.fmt)

def test_incorrect_fmt_list_len(self):
self.assertRaises(AssertionError, VariablePlotter, var=['data1'], fmt=['r', 'g'])


class InteractiveVariablePlotterTest(unittest.TestCase):
interactive = True

@classmethod
def setupClass(cls):
print("Running interactive plotting tests")
print("matplotlib.backend: {}".format(matplotlib.get_backend()))

def confirm(self):
result = input("Correct? ([y]/n): ") or "y"
self.assertEqual("y", result)

def test_1_default_arguments(self):
plotter = VariablePlotter('data', title="Basic plot test")
env = mock.Mock()
agents = []

env.data = 0
plotter.start(0, agents, env)
for x in range(1, 50):
env.data = random.random()
plotter.step(x, agents, env)
time.sleep(1)
plotter.stop(x, agents, env)

self.confirm()

def test_2_axes_limits(self):
plotter = VariablePlotter('data', title="Limit test and format test", xlim=[0, 20], ylim=[0, 10], fmt="r--")
env = mock.Mock()
agents = []

env.data = 0
plotter.start(0, agents, env)
for x in range(1, 50):
env.data = 10 * random.random()
plotter.step(x, agents, env)
time.sleep(1)
plotter.stop(x, agents, env)

self.confirm()

def test_3_scrolling(self):
plotter = VariablePlotter('data', title="Scrolling test", scrolling=20, fmt='g')
env = mock.Mock()
agents = []

env.data = 0
plotter.start(0, agents, env)
for x in range(1, 50):
env.data = 10 * random.random()
plotter.step(x, agents, env)
time.sleep(1)
plotter.stop(x, agents, env)

self.confirm()

def test_4_two_variables(self):
plotter = VariablePlotter(var=['data1', 'data2'], fmt=['b', 'r'],
title="Two variable test", legend="upper right")
env = mock.Mock()
agents = []

env.data1 = 0
env.data2 = 0
plotter.start(0, agents, env)
for x in range(1, 50):
env.data1 = 10 * random.random()
env.data2 = 5 * random.random()
plotter.step(x, agents, env)
time.sleep(1)
plotter.stop(x, agents, env)

self.confirm()

0 comments on commit 2407ec4

Please sign in to comment.