diff --git a/dworp/observer.py b/dworp/observer.py index 75765bc..61321c7 100644 --- a/dworp/observer.py +++ b/dworp/observer.py @@ -53,7 +53,10 @@ class ChainedObserver(Observer): *observers: Variable length arguments of Observer objects """ def __init__(self, *observers): - self.observers = observers + self.observers = list(observers) + + def append(self, observer): + self.observers.append(observer) def start(self, time, agents, env): for observer in self.observers: diff --git a/examples/segregation.py b/examples/segregation.py index 49e2dcc..0555929 100644 --- a/examples/segregation.py +++ b/examples/segregation.py @@ -2,10 +2,13 @@ import dworp import dworp.plot import logging -import matplotlib.colors -import matplotlib.pyplot as plt import numpy as np -import seaborn as sns +import sys +try: + import pygame +except ImportError: + # vis will be turned off + pass class Household(dworp.Agent): @@ -112,35 +115,43 @@ def test(self, time, agents, env): return SegObserver.get_happiness(agents) >= 100 -class HeatmapPlotObserver(dworp.Observer): - """Plot the segregration grid""" - def __init__(self, colors): - self.data = None - self.colors = colors - cmap = matplotlib.colors.ListedColormap([colors[0], "white", colors[1]]) - self.options = { - 'cmap': cmap, 'cbar': False, 'linewidths': 0.2, - 'xticklabels': False, 'yticklabels': False - } +class PyGameRenderer(dworp.Observer): + def __init__(self, size, zoom, fps): + self.zoom = zoom + self.fps = fps + self.width = size[0] + self.height = size[1] + + pygame.init() + pygame.display.set_caption("Segregation") + self.screen = pygame.display.set_mode((self.zoom * self.width, self.zoom * self.height)) + self.background = pygame.Surface((self.screen.get_size())) + self.background = self.background.convert() + self.clock = pygame.time.Clock() def start(self, time, agents, env): - self.data = np.zeros(env.grid.data.shape) - plt.ion() - self.plot(env.grid) + self.draw(agents) def step(self, time, agents, env): - self.plot(env.grid) + self.draw(agents) - def plot(self, grid): - for x in range(self.data.shape[0]): - for y in range(self.data.shape[1]): - if grid.data[x, y] is None: - self.data[x, y] = 0 - elif grid.data[x, y].color == self.colors[0]: - self.data[x, y] = -1 - else: - self.data[x, y] = 1 - sns.heatmap(self.data, **self.options) + def done(self, agents, env): + pygame.quit() + + def draw(self, agents): + side = self.zoom - 1 + self.background.fill((255, 255, 255)) + for agent in agents: + x = self.zoom * agent.x + y = self.zoom * agent.y + color = (255, 128, 0) if agent.color == "orange" else (0, 0, 255) + pygame.draw.rect(self.background, color, (x, y, side, side), 0) + self.screen.blit(self.background, (0, 0)) + pygame.display.flip() + self.clock.tick(self.fps) + for event in pygame.event.get(): + if event.type == pygame.QUIT: + quit() class SegregationParams: @@ -186,6 +197,9 @@ def __init__(self, params, observer): parser.add_argument("--similar", help="desired similarity (0-100)", default=30, type=int) parser.add_argument("--size", help="grid size formatted as XXXxYYY", default="50x50") parser.add_argument("--seed", help="seed of RNG", default=42, type=int) + parser.add_argument("--fps", help="frames per second", default="2", type=int) + parser.add_argument("--no-vis", dest='vis', action='store_false') + parser.set_defaults(vis=True) args = parser.parse_args() # prepare parameters of simulation @@ -195,14 +209,16 @@ def __init__(self, params, observer): similarity = args.similar / float(100) grid_size = [int(dim) for dim in args.size.split("x")] seed = args.seed + vis_flag = args.vis and 'pygame' in sys.modules + # vis does not support different colors colors = ["blue", "orange"] params = SegregationParams(density, similarity, grid_size, seed, colors) # create and run one realization of the simulation observer = dworp.ChainedObserver( SegObserver(), - HeatmapPlotObserver(colors), - dworp.plot.PlotPauseObserver(delay=1, start=True) ) + if vis_flag: + observer.append(PyGameRenderer(grid_size, 10, args.fps)) sim = SegregationSimulation(params, observer) sim.run() diff --git a/tests/test_observer.py b/tests/test_observer.py index 66b2831..372cc75 100644 --- a/tests/test_observer.py +++ b/tests/test_observer.py @@ -40,6 +40,17 @@ def test_stop_called_for_all_observers(self): self.assertEqual([mock.call.stop([], None)], obs1.mock_calls) self.assertEqual([mock.call.stop([], None)], obs2.mock_calls) + def test_append(self): + obs1 = mock.create_autospec(spec=Observer) + obs2 = mock.create_autospec(spec=Observer) + obs = ChainedObserver(obs1) + obs.append(obs2) + + obs.step(1, [], None) + + self.assertEqual([mock.call.step(1, [], None)], obs1.mock_calls) + self.assertEqual([mock.call.step(1, [], None)], obs2.mock_calls) + class KeyPauseObserverTest(unittest.TestCase): def setUp(self):