# Epidemic Simulator
#### EPS 109 Final Project by Raymond Chau

This is a Python epidemic simulator inspired by Kevin Simler's *Outbreak*, which is written in JavaScript.

In [1]:
# %matplotlib notebook
%matplotlib widget

import numpy as np
import matplotlib.pyplot as plt
from enum import Enum
from matplotlib import colors
from matplotlib.animation import FFMpegWriter
from matplotlib.lines import Line2D
from matplotlib.widgets import Slider, Button
from datetime import datetime

In [2]:
%%html
<style> /*Remove interactive buttons in matplotlib plots*/
.output_wrapper button.btn.btn-default,
.output_wrapper .ui-dialog-titlebar {
  display: none;
}
</style>

In [3]:
# Tunable infection parameters. Each member's value is the index of the
# matching slider in EpidemicSimulation.sliders (travel radius, encounters
# per day, transmission rate, fatality rate).
Infection = Enum(
    "Infection",
    "travel_radius encounters_per_day transmission_rate fatality_rate",
    start=0,
)

# Display palette, one entry per Status value in value order (member i colours
# status value i). Member names are matplotlib named colours ("r" is red).
_Colors = Enum("_Colors", "whitesmoke lightpink r gray black", start=0)

# Health state of a Person. The integer values 0-4 are what the grid image
# plots, and they index the _Colors palette. `colors` is also a member: its
# value is the _Colors enum class itself (read via Status.colors.value), so
# iterating Status yields six members, not five.
Status = Enum(
    "Status",
    [
        ("susceptible", 0),
        ("infected_symptomless", 1),
        ("infected_symptoms", 2),
        ("recovered", 3),
        ("deceased", 4),
        ("colors", _Colors),
    ],
)

class EpidemicSimulation:
    """Grid-based epidemic simulation shown in an interactive matplotlib figure.

    Each cell of a ``size x size`` grid holds one Person. A plus-shaped cluster
    of five people at the centre starts infected. Each step (one simulated day),
    every infected person may infect others within a travel radius. The figure
    has Reset/Play/Step buttons and sliders for the infection parameters.
    """

    class Person:
        """One grid cell: a person's health status and infection timeline in days."""

        def __init__(self, status=Status.susceptible, days_in_incubation=7, days_with_symptoms=8):
            self.status = status
            self.days_infected = 0
            self.days_in_incubation = days_in_incubation
            self.days_with_symptoms = days_with_symptoms

        @property
        def status(self):
            """The person's current Status."""
            return self.__status

        @status.setter
        def status(self, status):
            # Accept a Status member or its numeric value. Floats are truncated
            # to int first. Values matching no member are stored unchanged.
            if isinstance(status, float):
                status = int(status)
            for s in Status:
                if s.value == status:
                    self.__status = s
                    return
            self.__status = status

        def update_status(self, infection_vals):
            """Advance an infected person by one day.

            Susceptible, recovered and deceased people are left unchanged. Anyone
            else first ages one day. They then die with probability
            ``infection_vals[Infection.fatality_rate]``. Otherwise they recover
            once incubation plus symptomatic days have elapsed, or become
            symptomatic once the incubation period has elapsed.
            """
            if self.status in (Status.deceased, Status.recovered, Status.susceptible):
                return
            self.days_infected += 1
            # Strict '<' so a fatality rate of 0 never kills (random() can return 0.0).
            if np.random.random() < infection_vals[Infection.fatality_rate]:
                self.status = Status.deceased
            elif self.days_infected >= self.days_in_incubation + self.days_with_symptoms:
                self.status = Status.recovered
            elif self.days_infected >= self.days_in_incubation:
                self.status = Status.infected_symptoms

    # Initial slider positions. These are also the parameter source before the
    # sliders exist, so both paths start from the same values. Fatality is the
    # raw slider value; infection_vals squares it.
    _DEFAULTS = {
        Infection.travel_radius: 5,
        Infection.encounters_per_day: 10,
        Infection.transmission_rate: 0.3,
        Infection.fatality_rate: 0.03,
    }

    def __init__(self, size=100, figsize=(12, 12)):
        """Build the population and the interactive figure.

        size: requested grid side (2..150). It is rounded up to the next odd
            number so the grid has a centre cell.
        figsize: figure size passed to ``plt.subplots``.
        """
        assert size > 1 and size <= 150, "Grid size should be larger than 1 and less than 151"

        self.size = size // 2 * 2 + 1  # odd side length -> a well-defined centre cell
        self.figsize = figsize
        self.bins = self.size
        self.factor = self.bins / (2 * np.pi)  # maps an angle in [0, 2*pi] onto bins 0..bins
        self.reset_simulation()

        self.fig, self.ax = self.setup_display()
        # Keep references to the widgets; otherwise they stop responding.
        self.buttons = self.setup_btns()
        self.sliders = self.setup_sliders()

    def reset_simulation(self):
        """Start over: a fresh seeded population, cleared distance bins and counters."""
        self.infected_population = []  # (row, col) positions of currently infected people
        self.population = self.init_population()
        # Furthest infected distance from the centre seen so far, per angular bin
        # (keys 0..bins inclusive, since arctan2 + pi can reach exactly 2*pi).
        self.maxdist_bins = {b: 0 for b in range(self.bins + 1)}
        # Only recovered/deceased tallies are updated. Status.colors is a
        # palette holder, not a health state, so it is excluded.
        self.status_counts = {s: 0 for s in Status if s is not Status.colors}

    @property
    def get_stats(self):
        """Counts per Status (only recovered and deceased are incremented)."""
        return self.status_counts

    @property
    def get_bins(self):
        """Furthest infected distance from the centre per angular bin."""
        return self.maxdist_bins

    @property
    def population_values(self):
        """The grid as nested lists of integer Status values (what gets plotted)."""
        return [[p.status.value for p in row] for row in self.population]

    @property
    def infection_vals(self):
        """Current infection parameters keyed by Infection.

        Values come from the sliders once they exist, otherwise from
        ``_DEFAULTS``. Travel radius and encounters are converted to int. The
        fatality value is squared before use, e.g. 0.03 becomes 0.0009 per day.
        """
        if hasattr(self, 'sliders'):
            raw = {inf: self.sliders[inf.value].val for inf in Infection}
        else:
            raw = dict(self._DEFAULTS)
        return {
            Infection.travel_radius: int(raw[Infection.travel_radius]),
            Infection.encounters_per_day: int(raw[Infection.encounters_per_day]),
            Infection.transmission_rate: raw[Infection.transmission_rate],
            Infection.fatality_rate: np.power(raw[Infection.fatality_rate], 2),
        }

    def init_population(self):
        """Create a susceptible grid and infect a plus-shaped seed at the centre.

        Seed positions are appended to ``self.infected_population``.
        Returns a 2-D numpy object array of Person.
        """
        population = np.array(
            [[EpidemicSimulation.Person() for _ in range(self.size)] for _ in range(self.size)]
        )
        c = self.size // 2
        seeds = [(c - 1, c), (c, c - 1), (c, c), (c, c + 1), (c + 1, c)]
        for x, y in seeds:
            population[x][y].status = Status.infected_symptomless
            self.infected_population.append((x, y))
        return population

    def update_population(self):
        """Advance the simulation by one day.

        Only people infected before this call act. Anyone infected during the
        step is appended to ``infected_population`` but first acts on the next
        step. Each acting person ages one day. Those who die or recover are
        tallied and removed at the end of the step, and they do not spread the
        infection. Everyone else makes ``encounters_per_day`` attempts, each
        succeeding with probability ``transmission_rate``.
        """
        vals = self.infection_vals  # read sliders once per step, not per person/encounter
        encounters = vals[Infection.encounters_per_day]
        transmission_rate = vals[Infection.transmission_rate]
        travel_radius = vals[Infection.travel_radius]
        center = self.size // 2

        remove_infected = set()
        infected_num = len(self.infected_population)
        for idx in range(infected_num):
            x, y = self.infected_population[idx]
            p = self.population[x][y]
            p.update_status(vals)
            removed = p.status in (Status.deceased, Status.recovered)
            if removed:
                remove_infected.add(idx)
                self.status_counts[p.status] += 1

            # Track the furthest infected distance per angular bin. People whose
            # distance is within a fraction (growing with outbreak size) of their
            # bin's maximum are skipped. Presumably this is a speed-up, since
            # cells well inside the front are mostly infected already.
            dist = np.hypot(x - center, y - center)
            curr_bin = int((np.arctan2(y - center, x - center) + np.pi) * self.factor)
            max_dist = self.maxdist_bins[curr_bin]
            if dist <= (1.2e-4 * infected_num) * max_dist:
                continue
            if dist > max_dist:
                self.maxdist_bins[curr_bin] = dist

            if removed:
                # The deceased and the newly recovered no longer transmit.
                continue

            for _ in range(encounters):
                if np.random.random() > transmission_rate:
                    continue
                self._infect_near(x, y, travel_radius)

        self.infected_population = [
            pos for i, pos in enumerate(self.infected_population) if i not in remove_infected
        ]

    def _infect_near(self, x, y, travel_radius):
        """Try to infect one random cell within ``travel_radius`` of (x, y).

        The angle and distance are drawn uniformly, then the point is rounded
        to the nearest cell. Off-grid or non-susceptible targets are ignored.
        """
        rand_ang = np.random.random() * 2 * np.pi
        rand_dist = np.random.random() * travel_radius
        # Round instead of int()-truncating. Truncation skews spread toward
        # lower indices: with radius 1 it can never reach x+1 or y+1, and
        # points just past the top/left edge would land on row/column 0.
        tx = int(round(x + np.cos(rand_ang) * rand_dist))
        ty = int(round(y + np.sin(rand_ang) * rand_dist))
        if 0 <= tx < self.size and 0 <= ty < self.size:
            target = self.population[tx][ty]
            if target.status is Status.susceptible:
                target.status = Status.infected_symptomless
                self.infected_population.append((tx, ty))

    @staticmethod
    def _colormap_and_norm():
        """Colormap and norm mapping integer Status value i to the i-th _Colors entry."""
        palette = [c.name for c in Status.colors.value]
        color_map = colors.ListedColormap(palette)
        norm = colors.BoundaryNorm(np.arange(0, len(palette) + 1), color_map.N)
        return color_map, norm

    def setup_display(self):
        """Create the figure and grid image, and return ``(fig, ax)``.

        The image handle is kept in ``self._image`` so redraws can update it in
        place instead of stacking new images.
        """
        color_map, norm = self._colormap_and_norm()
        fig, ax = plt.subplots(figsize=self.figsize)

        # Leave room at the bottom for the buttons/sliders and at the right for the legend.
        plt.subplots_adjust(left=0, bottom=0.275, right=0.6, top=1, wspace=2, hspace=0)
        self._image = ax.imshow(self.population_values, cmap=color_map, norm=norm)
        # Thinner cell borders on larger grids. 2 - 0.015*size goes negative
        # once size exceeds 133, so clamp at 0.
        ax.grid(which='major', axis='both', linestyle='-', color='white',
                linewidth=max(0.0, 2 - 0.015 * self.size))
        # Ticks on cell edges so the grid lines outline each cell.
        ax.set_xticks(np.arange(-.5, self.size, 1))
        ax.set_yticks(np.arange(-.5, self.size, 1))

        custom_lines = [Line2D([0], [0], color=c.name) for c in Status.colors.value]
        ax.legend(custom_lines,
                  ['Susceptible', 'Infected (symptomless)', 'Infected (symptoms)', 'Recovered', 'Deceased'],
                  bbox_to_anchor=(1.4, 1), title='EPIDEMIC SIMULATION LEGEND')

        # Hide tick marks and labels; the ticks only exist to place grid lines.
        for axis in (ax.xaxis, ax.yaxis):
            for tic in axis.get_major_ticks():
                tic.tick1line.set_visible(False)
                tic.label1.set_visible(False)

        return fig, ax

    def setup_btns(self):
        """Create the Reset / Play / Step buttons and return them as a list."""
        reset_ax = plt.axes([0, 0.025 + 0.25, 0.19, 0.04])
        play_ax = plt.axes([0.2, 0.025 + 0.25, 0.2, 0.04])
        step_ax = plt.axes([0.41, 0.025 + 0.25, 0.19, 0.04])

        reset_button = Button(reset_ax, 'Reset')
        play_button = Button(play_ax, 'Play')
        step_button = Button(step_ax, 'Step')

        reset_button.on_clicked(self.reset)
        play_button.on_clicked(self.play)
        step_button.on_clicked(self.step)
        return [reset_button, play_button, step_button]

    def setup_sliders(self):
        """Create the parameter sliders, returned in Infection-value order."""
        travel_ax = plt.axes([0.2, 0.025 + 0.2, 0.4, 0.04])
        encounter_ax = plt.axes([0.2, 0.025 + 0.15, 0.4, 0.04])
        transmission_ax = plt.axes([0.2, 0.025 + 0.1, 0.4, 0.04])
        fatality_ax = plt.axes([0.2, 0.025 + 0.05, 0.4, 0.04])

        travel_slider = Slider(travel_ax,
                               'Travel Radius',
                               valinit=self._DEFAULTS[Infection.travel_radius],
                               valmin=0,
                               valmax=25,
                               valstep=1,
                               valfmt='%1d'
                               )
        encounter_slider = Slider(encounter_ax,
                                  'Encounters per day',
                                  valinit=self._DEFAULTS[Infection.encounters_per_day],
                                  valmin=1,
                                  valmax=30,
                                  valstep=1,
                                  valfmt='%1d'
                                  )
        transmission_slider = Slider(transmission_ax,
                                     'Transmission Rate',
                                     valinit=self._DEFAULTS[Infection.transmission_rate],
                                     valmin=0,
                                     valmax=1.0,
                                     valstep=0.01
                                     )
        fatality_slider = Slider(fatality_ax,
                                 'Fatality Rate',
                                 valinit=self._DEFAULTS[Infection.fatality_rate],
                                 valmin=0,
                                 valmax=0.3,
                                 valstep=0.01
                                 )
        # Order must match Infection values; infection_vals indexes by .value.
        return [travel_slider, encounter_slider, transmission_slider, fatality_slider]

    def display_population(self):
        """Refresh the grid image with the current statuses and redraw the canvas."""
        # Update the existing image in place; calling imshow each frame would
        # pile up one extra AxesImage per step.
        self._image.set_data(self.population_values)
        self.fig.canvas.draw()

    def animate_population(self, iterations=999999):
        """Step and redraw until ``iterations`` steps have run or nobody is infected."""
        curr_iter = 0
        while curr_iter < iterations and self.infected_population:
            curr_iter += 1
            self.update_population()
            self.display_population()

    def record(self):
        """Run until the outbreak ends, saving each frame to a timestamped MP4 via FFmpeg."""
        title = 'EpidemicSimulation'
        metadata = dict(title=title, artist='Raymond Chau')
        writer = FFMpegWriter(fps=6, metadata=metadata)
        curr_time = datetime.now().strftime("%Y%m%d-%H%M%S")
        with writer.saving(self.fig, f"{title}{curr_time}.mp4", dpi=200):
            writer.grab_frame()
            while self.infected_population:
                self.update_population()
                self.display_population()
                plt.pause(0.1)
                writer.grab_frame()

    def play(self, event):
        """Play button callback: animate until the outbreak ends."""
        self.animate_population()

    def step(self, event):
        """Step button callback: advance one day and redraw."""
        self.update_population()
        self.display_population()

    def reset(self, event):
        """Reset button callback: restart the simulation and redraw."""
        self.reset_simulation()
        self.display_population()

In [8]:
# Build and show the interactive simulation (size 100 becomes a 101 x 101 grid).
sim = EpidemicSimulation(size=100, figsize=(12, 12))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [5]:
# sim.record()
# sim.get_stats #only recovered and deceased counts are implemented