In [81]:
%matplotlib notebook

import numpy as np
import matplotlib.pyplot as plt
from enum import Enum
from matplotlib import colors
from matplotlib.lines import Line2D
from matplotlib.widgets import Slider, Button
from IPython.display import display
import time

In [83]:
%%html
<style> /*Remove interactive buttons in matplotlib plots*/
.output_wrapper button.btn.btn-default,
.output_wrapper .ui-dialog-titlebar {
  display: none;
}
</style>

In [94]:
class Infection(Enum):
    travel_radius = 0
    encounters_per_day = 1
    transmission_rate = 2
    fatality_rate = 3

class Status(Enum):
    susceptible = 0
    infected_symptomless = 1
    infected_symptoms = 2
    recovered = 3
    deceased = 4

class EpidemicSimulation:
    class Person:
        def __init__(self, status=Status.susceptible, days_in_incubation=7, days_with_symptoms=8):
            self.status = status
            self.days_infected = 0
            self.days_in_incubation = days_in_incubation
            self.days_with_symptoms = days_with_symptoms
        
        @property
        def status(self):
            return self.__status
        
        @status.setter
        def status(self, status):
            for key, val in EpidemicSimulation.status_types.items():
                if status == val or status is key:
                    self.__status = val
                    return
            self.__status = EpidemicSimulation.status_types['susceptible']
        
        def update_status(self, infection_vals):
            if (self.status == EpidemicSimulation.status_types['deceased'] or
                self.status == EpidemicSimulation.status_types['recovered']):
                return
            elif self.status != EpidemicSimulation.status_types['susceptible']:
                self.days_infected += 1
                if np.random.random() > infection_vals['Fatality Rate']:
                    self.status = EpidemicSimulation.status_types['deceased']
                elif self.days_infected >= self.days_in_incubation + self.days_with_symptoms:
                    self.status = EpidemicSimulation.status_types['recovered']
                elif self.days_infected >= self.days_in_incubation:
                    self.status = EpidemicSimulation.status_types['infected (symptoms)']
            return
        
#     infection_variables = {
#         'Travel Radius': 0,
#         'Encounters per day': 1,
#         'Transmission Rate': 2,
#         'Fatality Rate': 3
#     }
    
    color_types = ['lightgray', 'lightcoral', 'r', 'dimgray', 'black']
    
    def __init__(self, size, figsize=(7, 7)):
        assert size > 1, "Grid size should be larger than 1"
    
        self.size = size//2 * 2 + 1 #make an odd-shaped grid to have a center
        self.figsize = figsize
        self.population = self.init_population()
        
        self.fig, self.ax = self.setup_display()
        self.buttons = self.setup_btns()
        self.sliders = self.setup_sliders()
        
        
    @property
    def infection_vals(self):
        if not hasattr(self, 'sliders'):
            return {
                'Travel Radius': 5,
                'Encounters per day': 10,
                'Transmission Rate': 0.3,
                'Fatality Rate': 0.03
            }
        return {
            'Travel Radius': self.sliders[EpidemicSimulation.infection_variables['Travel Radius']] or 5,
            'Encounters per day': self.sliders[EpidemicSimulation.infection_variables['Encounters per day']] or 10,
            'Transmission Rate': self.sliders[EpidemicSimulation.infection_variables['Transmission Rate']] or 0.3,
            'Fatality Rate': self.sliders[EpidemicSimulation.infection_variables['Fatality Rate']] or 0.03
        }
        
    def init_population(self):
        population = np.array([[EpidemicSimulation.Person() for _ in range(self.size)] for _ in range(self.size)])
        center = self.size//2
        for p in population[center-1:center+2, center]:
            p.status = EpidemicSimulation.status_types['infected (symptomless)']
        for p in population[center, center-1:center+2]:
            p.status = EpidemicSimulation.status_types['infected (symptomless)']
        return population
    
    def update_population(self):
        for x in range(self.size):
            for y in range(self.size):
                return
        return
        
    def setup_display(self):
        color_map = colors.ListedColormap(EpidemicSimulation.color_types)
        bounds = np.arange(0, len(EpidemicSimulation.color_types) + 1)
        norm = colors.BoundaryNorm(bounds, color_map.N)

        fig, ax = plt.subplots(figsize=self.figsize)

        plt.subplots_adjust(left=0, bottom=0.275, right=0.6, top=1, wspace=2, hspace=0)
        ax.imshow([[p.status for p in row] for row in self.population], cmap=color_map, norm=norm)
        ax.grid(which='major', axis='both', linestyle='-', color='white', linewidth=1.2)
        ax.set_xticks(np.arange(-.5, self.size, 1));
        ax.set_yticks(np.arange(-.5, self.size, 1));
        
        custom_lines = [Line2D([0], [0], 
                               color=EpidemicSimulation.color_types[EpidemicSimulation.status_types[status]]) 
                        for status, color in EpidemicSimulation.status_types.items()]
        ax.legend(custom_lines, 
                  EpidemicSimulation.status_types.keys(), 
                  bbox_to_anchor=(1.6, 1),
                 title='EPIDEMIC SIMULATION LEGEND')

        for tic in ax.xaxis.get_major_ticks():
            tic.tick1line.set_visible(False)
            tic.label1.set_visible(False)

        for tic in ax.yaxis.get_major_ticks():
            tic.tick1line.set_visible(False)
            tic.label1.set_visible(False)
            
        return fig, ax

    def setup_btns(self):
        reset_ax = plt.axes([0, 0.025 + 0.25, 0.19, 0.04])
        play_ax = plt.axes([0.2, 0.025 + 0.25, 0.2, 0.04])
        step_ax = plt.axes([0.41, 0.025 + 0.25, 0.19, 0.04])
        
        reset_button = Button(reset_ax, 'Reset')
        play_button = Button(play_ax, 'Play')
        step_button = Button(step_ax, 'Step')
        
        reset_button.on_clicked(lambda e: self.reset(e))
        play_button.on_clicked(lambda e: self.play(e))
        step_button.on_clicked(lambda e: self.step(e))
        return [reset_button, play_button, step_button]
        
    def setup_sliders(self):
        travel_ax = plt.axes([0.2, 0.025 + 0.2, 0.4, 0.04])
        encounter_ax = plt.axes([0.2, 0.025 + 0.15, 0.4, 0.04])
        transmission_ax = plt.axes([0.2, 0.025 + 0.1, 0.4, 0.04])
        fatality_ax = plt.axes([0.2, 0.025 + 0.05, 0.4, 0.04])

        travel_slider = Slider(travel_ax, 
                               'Travel Radius',
                               valinit=5,
                               valmin=0, 
                               valmax=25,
                               valstep=1,
                               valfmt='%1d'
                              )
        encounter_slider = Slider(encounter_ax, 
                                  'Encounters per day',
                                  valinit=10,
                                  valmin=1,
                                  valmax=30,
                                  valstep=1,
                                  valfmt='%1d'
                                 )
        transmission_slider = Slider(transmission_ax, 
                                     'Transmission Rate',
                                     valinit=0.3,
                                     valmin=0, 
                                     valmax=1.0,
                                     valstep=0.01
                                    )
        fatality_slider = Slider(fatality_ax, 
                                 'Fatality Rate',
                                 valinit=0.03,
                                 valmin=0, 
                                 valmax=0.3,
                                 valstep=0.01
                                )
        return [travel_slider, encounter_slider, transmission_slider, fatality_slider]
    
    def redraw(self, times=3):
        for i in range(times):
            population = np.random.rand(self.size, self.size) * 5
            color_map = colors.ListedColormap(EpidemicSimulation.color_types)
            bounds = np.append(0, np.arange(1, len(EpidemicSimulation.color_types) + 1) + 0.1)
            norm = colors.BoundaryNorm(bounds, color_map.N)
            self.ax.imshow(population, cmap=color_map, norm=norm)
            self.fig.canvas.draw()
            if i != times - 1:
                time.sleep(1)
        return
        
    def play(self, event):
        self.redraw()
    
    def step(self, event):
        self.redraw(1)
    
    def reset(self, event):
        self.redraw()

In [95]:
epi = EpidemicSimulation(10, figsize=(9,9))

AttributeError: type object 'EpidemicSimulation' has no attribute 'status_types'