Skip to content

Commit

Permalink
Merge 6c92a10 into 66c6412
Browse files Browse the repository at this point in the history
  • Loading branch information
cash committed May 10, 2018
2 parents 66c6412 + 6c92a10 commit 28be595
Show file tree
Hide file tree
Showing 4 changed files with 269 additions and 4 deletions.
106 changes: 105 additions & 1 deletion dworp/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Distributed under the terms of the Modified BSD License.

import matplotlib.pyplot as plt
from .observer import PauseObserver
from .observer import Observer, PauseObserver

# Note: do not include this in __init__.py so that dworp does not have a hard requirement
# for matplotlib.
Expand All @@ -24,3 +24,107 @@ def __init__(self, delay, start=False, stop=False):

def pause(self):
plt.pause(self.delay)


class VariablePlotter(Observer): # pragma: no cover
"""Plot one or more variables from the Environment
Args:
var (string, list): Name or list of names of variable in Environment to plot
fmt (string, list): Optional matplotlib format string or list of strings (default is "b")
scrolling (int): Optional number of time steps in scroll or 0 for no scrolling
title(string): Optional figure title (default is name of variable)
xlabel(string): Optional x-axis label (default is "Time")
ylabel(string): Optional y-axis label (default is name of variable)
xlim(list): Optional starting x-axis limits as [xmin, xmax]
ylim(list): Optional starting y-axis limits as [ymin, ymax]
pause(float): Optional pause between updates (must be > 0)
"""
def __init__(self, var, fmt="b", scrolling=0, title=None, xlabel="Time", ylabel=None,
xlim=None, ylim=None, pause=0.001):
self.var_names = [var] if isinstance(var, str) else var
self.fmt = self._prepare_format_option(fmt)
self.scrolling = scrolling
self.title = title if title else self._prepare_default_title()
self.xlabel = xlabel
self.ylabel = ylabel if ylabel else self._prepare_default_title()
self.xlim = xlim
self.ylim = ylim
self.pause = pause
self.fig = None
self.axes_margin = 0.01

self.time = []
self.data = {name: [] for name in self.var_names}

def _prepare_format_option(self, fmt):
# fmt must be either same length as var_names or length 1
format_ = [fmt] if isinstance(fmt, str) else fmt
assert(len(format_) == 1 or len(format_) == len(self.var_names))
if len(format_) != len(self.var_names):
format_ = format_ * len(self.var_names)
return dict(zip(self.var_names, format_))

def _prepare_default_title(self):
if len(self.var_names) == 0:
return self.var_names[0]
else:
return ' & '.join(self.var_names)

def start(self, now, agents, env):
self.prepare()
self.update(now, agents, env)

def step(self, now, agents, env):
self.update(now, agents, env)

def stop(self, now, agents, env):
plt.close(self.fig)

def plot(self, now, agents, env):
self.time.append(now)
for name in self.var_names:
self.data[name].append(getattr(env, name))
if self.scrolling:
plot_time = self.time[(-1 * self.scrolling):]
plot_data = {name: data[(-1 * self.scrolling):] for name, data in self.data.items()}
else:
plot_time = self.time
plot_data = self.data

for name, data in plot_data.items():
plt.plot(plot_time, data, self.fmt[name])
axes = self.fig.axes[0]
axes.set_xlabel(self.xlabel)
axes.set_ylabel(self.ylabel)
self.set_axes_limits(axes)

def set_axes_limits(self, axes):
if self.ylim:
ylim = axes.get_ylim()
margin = self.axes_margin * abs(ylim[1] - ylim[0])
ymin = min(self.ylim[0], min(min(self.data.values())) - margin)
ymax = max(self.ylim[1], max(max(self.data.values())) + margin)
axes.set_ylim([ymin, ymax])
if self.xlim:
xmin = min(self.xlim[0], min(self.time))
xmax = max(self.xlim[1], max(self.time))
axes.set_xlim([xmin, xmax])

def prepare(self):
# turn interactive mode on and create an empty figure
plt.ion()
self.fig = plt.figure()
self.fig.canvas.set_window_title(self.title)

def update(self, now, agents, env):
# clear figure, create new plot, and update figure
plt.clf()
self.plot(now, agents, env)
plt.draw()
# pause to give time for matplotlib to update figure
plt.pause(self.pause)
# if figure is closed, terminate
figures = plt.get_fignums()
if not figures:
quit()
71 changes: 68 additions & 3 deletions examples/birth_rates.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import argparse
import collections
import dworp
import logging
import math
import matplotlib.pyplot as plt
import numpy as np


Expand All @@ -25,6 +27,21 @@ def step(self, now, env):
return [Person(self.color, self.fertility, self.rng) for x in range(num_children)]


class BirthEnvironment(dworp.Environment):
def __init__(self, colors):
super().__init__(0)
self.colors = colors
self.data = collections.Counter()

def step(self, now, agents):
pass

def complete(self, now, agents):
self.data.clear()
for agent in agents:
self.data.update({agent.color: 1})


class BirthObserver(dworp.Observer):
"""Writes simulation state to stdout"""
def start(self, time, agents, env):
Expand All @@ -43,6 +60,49 @@ def get_counts(agents):
return num_red, num_blue


class MatplotlibObservor(dworp.Observer):
def __init__(self):
self.time = []
self.blue = []
self.red = []

def start(self, now, agents, env):
plt.ion()
plt.figure()
plt.gcf().canvas.set_window_title("Populations")

def step(self, now, agents, env):
self.plot(now, agents, env)

def plot(self, now, agents, env):
plt.clf()

self.time.append(now)
self.blue.append(env.data['blue'])
self.red.append(env.data['red'])

plt.plot(self.time[-20:], self.blue[-20:], 'b')
plt.plot(self.time[-20:], self.red[-20:], 'r')

plt.xlabel("Generations")
plt.ylabel("Population")

axes = plt.gca()
ylim = axes.get_ylim()
margin = 0.01 * abs(ylim[1] - ylim[0])
if len(self.time) < 20:
axes.set_xlim([0, 20])
ymin = min(300, min(self.red) - margin, min(self.blue) - margin)
ymax = max(700, max(self.red) + margin, max(self.blue) + margin)
axes.set_ylim([ymin, ymax])

plt.draw()
plt.pause(0.001)
figures = plt.get_fignums()
if not figures:
quit()


class BirthTerminator(dworp.Terminator):
"""Stop when only one people color is left"""
def test(self, now, agents, env):
Expand All @@ -63,7 +123,7 @@ class BirthSimulation(dworp.BasicSimulation):
def __init__(self, params, observer):
self.params = params
self.rng = np.random.RandomState(params.seed)
env = dworp.NullEnvironment()
env = BirthEnvironment(['red', 'blue'])
time = dworp.InfiniteTime()
scheduler = dworp.BasicScheduler()
terminator = BirthTerminator()
Expand All @@ -88,6 +148,7 @@ def run(self):
self.agents.extend(new_people)
# death
self.reap()
self.env.complete(current_time, self.agents)
self.observer.step(current_time, self.agents, self.env)
if self.terminator.test(current_time, self.agents, self.env):
break
Expand Down Expand Up @@ -121,8 +182,12 @@ def reap(self):
assert(1 <= args.capacity <= 4000)
assert(0 <= args.red <= 10)
assert(0 <= args.blue <= 10)
params = BirthParams(args.capacity, args.red, args. blue, args.seed)
params = BirthParams(args.capacity, args.red, args.blue, args.seed)

# create and run one realization of the simulation
sim = BirthSimulation(params, BirthObserver())
observer = dworp.ChainedObserver(
BirthObserver(),
MatplotlibObservor()
)
sim = BirthSimulation(params, observer)
sim.run()
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[nosetests]
attr=!interactive
94 changes: 94 additions & 0 deletions tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
from dworp.plot import *
import unittest
import unittest.mock as mock
import matplotlib
import matplotlib.pyplot as plt
import random
import time


class PlotPauseObserverTest(unittest.TestCase):
Expand All @@ -23,3 +26,94 @@ def test_pause(self):
obs = PlotPauseObserver(1)
obs.step(0, [], None)
self.assertTrue(plt.pause.called)


class NonInteractiveVariablePlotterTest(unittest.TestCase):
# let developers filter the plotting tests
plot = True

def test_fmt_string_with_var_list(self):
plotter = VariablePlotter(['data1', 'data2'])
self.assertEqual({'data1': 'b', 'data2': 'b'}, plotter.fmt)

def test_fmt_list_with_var_list(self):
plotter = VariablePlotter(var=['data1', 'data2'], fmt=['r', 'g'])
self.assertEqual({'data1': 'r', 'data2': 'g'}, plotter.fmt)

def test_incorrect_fmt_list_len(self):
self.assertRaises(AssertionError, VariablePlotter, var=['data1'], fmt=['r', 'g'])


class InteractiveVariablePlotterTest(unittest.TestCase):
interactive = True

@classmethod
def setupClass(cls):
print("Running interactive plotting tests")
print("matplotlib.backend: {}".format(matplotlib.get_backend()))

def confirm(self):
result = input("Correct? ([y]/n): ") or "y"
self.assertEqual("y", result)

def test_1_default_arguments(self):
plotter = VariablePlotter('data', title="Basic plot test")
env = mock.Mock()
agents = []

env.data = 0
plotter.start(0, agents, env)
for x in range(1, 50):
env.data = random.random()
plotter.step(x, agents, env)
time.sleep(1)
plotter.stop(x, agents, env)

self.confirm()

def test_2_axes_limits(self):
plotter = VariablePlotter('data', title="Limit test and format test", xlim=[0, 20], ylim=[0, 5], fmt="r--")
env = mock.Mock()
agents = []

env.data = 0
plotter.start(0, agents, env)
for x in range(1, 50):
env.data = 10 * random.random()
plotter.step(x, agents, env)
time.sleep(1)
plotter.stop(x, agents, env)

self.confirm()

def test_3_scrolling(self):
plotter = VariablePlotter('data', title="Scrolling test", scrolling=20, fmt='g')
env = mock.Mock()
agents = []

env.data = 0
plotter.start(0, agents, env)
for x in range(1, 50):
env.data = 10 * random.random()
plotter.step(x, agents, env)
time.sleep(1)
plotter.stop(x, agents, env)

self.confirm()

def test_4_two_variables(self):
plotter = VariablePlotter(var=['data1', 'data2'], fmt=['b', 'r'], title="Two variable test")
env = mock.Mock()
agents = []

env.data1 = 0
env.data2 = 0
plotter.start(0, agents, env)
for x in range(1, 50):
env.data1 = 10 * random.random()
env.data2 = 5 * random.random()
plotter.step(x, agents, env)
time.sleep(1)
plotter.stop(x, agents, env)

self.confirm()

0 comments on commit 28be595

Please sign in to comment.