In [1]:
import io
from enum import IntEnum
from typing import Tuple

import imageio
import matplotlib.pyplot as plt
import numba as nb
import numpy as np
from tqdm import tqdm

# Susceptible–Infected–Removed model

In [2]:
class State(IntEnum):
    """Compartment of a person in the Susceptible-Infected-Removed (SIR) model.

    The integer values are what is stored in the ``'state'`` (``i4``) field of
    ``Population.population``.
    """
    # Historical misspelling, kept as the canonical member so existing code
    # (and ``State(0).name``) keeps working.
    SUSPECTIBLE = 0
    # Correctly spelled alias of the same member: ``State.SUSCEPTIBLE is State.SUSPECTIBLE``.
    SUSCEPTIBLE = 0
    INFECTED = 1
    REMOVED = 2

    
class Population:
    """A population of agents on a square map for an agent-based SIR simulation.

    People live in a structured array ``self.population`` with fields ``x`` and
    ``y`` (float64 positions in ``[0, size_map]``) and ``state`` (int32, a
    ``State`` value). The array is kept sorted by ``x`` so that the neighbourhood
    search in ``infect`` only has to scan a narrow window of rows.
    """

    # numba-compiled version of ``_spread_infection``. It is built lazily on the
    # first call to ``infect`` and reused afterwards. Defining the ``@nb.njit``
    # function inside ``infect`` made numba recompile it on every step.
    _infect_kernel = None

    def __init__(self, n: int, size_map: float, neighborhood_radius: float,
                 propability_of_death: float, propability_of_contation: float):
        """Scatter ``n`` people uniformly at random over a ``size_map`` x ``size_map`` square.

        Args:
            n: number of people.
            size_map: side length of the square map.
            neighborhood_radius: infection radius (Euclidean distance).
            propability_of_death: per-step probability that an infected person
                is moved to ``State.REMOVED``.
            propability_of_contation: per-step probability that an infected person
                infects a given susceptible person within ``neighborhood_radius``.
        """
        self.n = n
        self.size_map = size_map
        self.neighborhood_radius = neighborhood_radius
        self.propability_of_death = propability_of_death
        self.propability_of_contation = propability_of_contation

        # np.zeros leaves everybody in state 0, i.e. State.SUSPECTIBLE.
        self.population = np.zeros(self.n, dtype=[('x', 'f8'), ('y', 'f8'), ('state', 'i4')])
        self.population['x'] = np.random.random(self.n) * self.size_map
        self.population['y'] = np.random.random(self.n) * self.size_map
        self.population = np.sort(self.population, kind='mergesort', order='x')

    def gaussian_move(self, step_std: float = 1.0):
        """Move everyone by a Gaussian step, clamp to the map, and re-sort by ``x``.

        Args:
            step_std: standard deviation of the step along each axis. The default
                of 1.0 reproduces the original unit-variance step.
        """
        self.population['x'] += np.random.randn(self.n) * step_std
        self.population['y'] += np.random.randn(self.n) * step_std

        # People who would leave the map stop at its border.
        for axis in ('x', 'y'):
            self.population[axis] = np.clip(self.population[axis], 0.0, self.size_map)

        # infect() relies on the table being sorted by x.
        self.population = np.sort(self.population, kind='mergesort', order='x')

    def seed_infected_people(self, n_infected):
        """Mark ``n_infected`` distinct, randomly chosen people as infected.

        Raises:
            ValueError: (from ``np.random.choice``) if ``n_infected`` exceeds the
                population size.
        """
        # replace=False: sampling with replacement could pick the same person
        # twice and seed fewer than ``n_infected`` infections.
        indexes = np.random.choice(self.n, n_infected, replace=False)
        self.population['state'][indexes] = State.INFECTED

    def remove(self):
        """Move each infected person to ``State.REMOVED`` with probability ``propability_of_death``."""
        state = self.population['state']  # field view: writes go back into self.population
        dies = np.random.random(self.n) < self.propability_of_death
        state[(state == State.INFECTED) & dies] = State.REMOVED

    @staticmethod
    def _spread_infection(table, infected_state, susceptible_state, radius,
                          propability_of_contation):
        """Infect, in place, susceptible people within ``radius`` of an infected person.

        This is plain Python so it can run without numba; ``infect`` JIT-compiles it.

        Args:
            table: structured array with ``x``, ``y``, ``state`` fields, sorted by ``x``.
            infected_state: integer value marking infected people.
            susceptible_state: integer value marking people who can be infected.
            radius: infection radius (strict ``<`` on Euclidean distance).
            propability_of_contation: probability that one spreader infects one
                susceptible neighbour.
        """
        xs = table['x']
        ys = table['y']
        states = table['state']
        n = len(table)
        radius_sq = radius * radius

        # Snapshot the spreaders before anyone new is infected. Otherwise people
        # infected earlier in this scan would spread again within the same step.
        spreaders = np.nonzero(states == infected_state)[0]

        for k in range(len(spreaders)):
            src = spreaders[k]
            sx = xs[src]
            sy = ys[src]
            # Walk right (step=1), then left (step=-1), while the x-distance alone
            # is still below the radius. The table is sorted by x, so nobody
            # further along in that direction can be within range.
            for step in (1, -1):
                j = src + step
                while j >= 0 and j < n and abs(xs[j] - sx) < radius:
                    # Only susceptible people can be infected; removed people stay removed.
                    if (states[j] == susceptible_state
                            and (xs[j] - sx) ** 2 + (ys[j] - sy) ** 2 < radius_sq
                            and np.random.rand() < propability_of_contation):
                        states[j] = infected_state
                    j += step

    def infect(self):
        """Run one infection step (numba-accelerated).

        Only people who are infected at the start of the step transmit, and only
        ``State.SUSPECTIBLE`` people become infected.

        Note: ``np.random`` inside numba-compiled code uses numba's own generator,
        which a plain-Python ``np.random.seed`` does not seed.
        """
        kernel = Population._infect_kernel
        if kernel is None:
            kernel = nb.njit(Population._spread_infection)
            Population._infect_kernel = kernel
        kernel(self.population, int(State.INFECTED), int(State.SUSPECTIBLE),
               float(self.neighborhood_radius), float(self.propability_of_contation))
                     

# Animating the pandemic

In [4]:
# Simulation parameters (see Population.__init__ for their meaning).
SEED = 0
N_PEOPLE = 3_000
MAP_SIZE = 100
NEIGHBORHOOD_RADIUS = 0.5
P_DEATH = 0.10
P_CONTAGION = 0.6
N_STEPS = 100
GIF_PATH = 'pandemic.gif'


@nb.njit
def _seed_numba_rng(seed):
    # numba-compiled code has its own RNG state. Only np.random.seed called from
    # inside jitted code seeds the generator used by Population.infect.
    np.random.seed(seed)


np.random.seed(SEED)
_seed_numba_rng(SEED)

test = Population(N_PEOPLE, MAP_SIZE, NEIGHBORHOOD_RADIUS, P_DEATH, P_CONTAGION)
test.seed_infected_people(1)

with imageio.get_writer(GIF_PATH, mode='I') as gif_maker:
    for _ in tqdm(range(N_STEPS)):
        test.gaussian_move()
        test.infect()
        test.remove()

        states = test.population['state']
        suspectible = test.population[states == State.SUSPECTIBLE]
        infected = test.population[states == State.INFECTED]
        removed = test.population[states == State.REMOVED]

        fig, ax = plt.subplots(figsize=(16, 16), dpi=80)
        ax.scatter(suspectible['x'], suspectible['y'], c='b')
        ax.scatter(infected['x'], infected['y'], c='r')
        ax.scatter(removed['x'], removed['y'], c='black')
        # Fixed limits so the view does not rescale from frame to frame.
        ax.set(xlim=(0, MAP_SIZE), ylim=(0, MAP_SIZE), aspect='equal')
        ax.axis('off')

        # Render the frame in memory instead of round-tripping through a PNG on disk.
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png', bbox_inches='tight')
        plt.close(fig)  # release the figure; otherwise every frame stays in memory
        gif_maker.append_data(imageio.imread(buffer.getvalue()))

        # Stop once nobody is left who can transmit the disease.
        if len(infected) == 0:
            break


 11%|█         | 11/100 [00:14<01:54,  1.28s/it]