# Introduction to Computational Science - Assignment 2
Sander Broos, Nick van Santen

In [None]:
# Imports
from __future__ import annotations
from ipywidgets import *

import matplotlib.pyplot as plt
import numpy as np
import bisect
import math
from scipy.fft import fft, fftfreq

from typing import Callable, Dict, List

In [None]:
# Run this cell to increase font sizes. Useful when saving plots.
SMALL_SIZE = 16
MEDIUM_SIZE = 20
BIGGER_SIZE = 24

# Matplotlib rc group -> settings applied to that group
_FONT_SETTINGS = {
    'font': {'size': SMALL_SIZE},          # default text sizes
    'axes': {
        'titlesize': SMALL_SIZE,           # fontsize of the axes title
        'labelsize': MEDIUM_SIZE,          # fontsize of the x and y labels
    },
    'xtick': {'labelsize': SMALL_SIZE},    # fontsize of the x tick labels
    'ytick': {'labelsize': SMALL_SIZE},    # fontsize of the y tick labels
    'legend': {'fontsize': SMALL_SIZE},    # legend fontsize
    'figure': {'titlesize': BIGGER_SIZE},  # fontsize of the figure title
}

for _rc_group, _rc_kwargs in _FONT_SETTINGS.items():
    plt.rc(_rc_group, **_rc_kwargs)

# Gillespie’s Direct Algorithm

In [None]:
class Event:
    """
    A single transition of the stochastic model.

    name:  label of the event.
    rate:  callable that receives the groups mapping and returns the current rate.
    event: mapping of group key -> amount added to that group when the event occurs.
    """

    def __init__(self, name: str, rate: Callable, event: dict):
        self.name, self.rate, self.event = name, rate, event

    def occur(self, groups, time):
        """
        Applies the change of this event to every affected group at the given time.
        """
        for key in self.event:
            groups[key].add_value(self.event[key], time)

class Simulator:
    """
    Stochastic simulator using Gillespie's direct algorithm.

    groups:   mapping of group key -> group object (needs `number`, `history`,
              `time_steps`, `name`, `reset()` and `append_to_history()`).
    events:   list of events, each with a `rate(groups)` callable and an
              `occur(groups, time)` method.
    max_time: time at which the simulation stops.
    """

    def __init__(self, groups: Dict[str, "Group"], events: List["Event"], max_time: float):

        self.groups = groups

        self.events = events

        self.time = 0
        self.max_time = max_time
        self.time_steps = [0]

    def update(self):
        """
        Performs one Gillespie step: draws the waiting time until the next event,
        selects which event occurs and applies it. If no event can occur (total
        rate is zero) the time jumps to max_time. If the drawn time lies beyond
        max_time no event is applied.
        """

        # Evaluate every rate once and reuse the values for the event selection
        rates = [event.rate(self.groups) for event in self.events]
        total_rate = sum(rates)

        if total_rate == 0:
            self.time = self.max_time
            return

        # rand() samples from [0, 1); using 1 - rand() gives (0, 1], which keeps
        # the logarithm finite while following the same distribution.
        r1 = 1.0 - np.random.rand()
        delta_time = -1 / total_rate * np.log(r1)
        self.time += delta_time

        if self.time > self.max_time:
            return

        self.time_steps.append(self.time)

        r2 = np.random.rand()
        P = r2 * total_rate
        event = self.determine_event(P, rates)

        # Apply event
        event.occur(self.groups, self.time)

    def determine_event(self, p: float, rates=None):
        """
        Returns the first event whose cumulative rate exceeds p.

        rates: optional pre-computed rates (same order as self.events). When
               omitted the rates are evaluated from the current groups.

        Due to floating point round-off p can end up equal to the total rate. In
        that case the last event with a positive rate is returned. Raises a
        RuntimeError when none of the events has a positive rate.
        """

        if rates is None:
            rates = [event.rate(self.groups) for event in self.events]

        cumulative = 0
        last_possible = None
        for event, rate in zip(self.events, rates):

            cumulative += rate

            if cumulative > p:
                return event

            if rate > 0:
                last_possible = event

        if last_possible is None:
            raise RuntimeError("No event found: none of the events has a positive rate.")

        return last_possible

    def total_n(self):
        """
        Returns the summed size of all groups.
        """
        return sum([group.number for group in self.groups.values()])

    def run(self):
        """
        Runs the simulation until max_time is reached and stores the final levels.
        """

        while self.time < self.max_time:
            self.update()

        self.finalise_results()

    def reset(self):
        """
        Resets all groups and the clock to their initial state.
        """

        for group in self.groups.values():
            group.reset()

        self.time = 0
        self.time_steps = [0]

    def finalise_results(self):
        """
        Appends the current level of every group at max_time, so each history
        extends up to the end of the simulation.
        """

        for group in self.groups.values():

            group.append_to_history(group.number, self.max_time)

    def plot_group_levels(self):
        """
        Plots the level of every group over time as a step function.
        """

        for group in self.groups.values():

            plt.plot(group.time_steps, group.history, label=group.name, drawstyle="steps-post")

        plt.title("Evolution of the system")
        plt.xlabel("Time")
        plt.ylabel("Cases")
        plt.grid(alpha=0.3)
        plt.legend()
        plt.show()

    def plot_FFT(self, group_name: str, start_time = 20):
        """
        Plots the fast fourier transformation of a given group name.

        Only the part of the history after start_time is used. Prints an error
        and returns without plotting when start_time is not before the last
        recorded time, or when fewer than two data points remain.
        """

        if group_name not in self.groups:
            raise Exception(f"Error: Invalid group name. {group_name} is not one of the current groups. Choises are {[group.name for group in self.groups.values()]}")

        group = self.groups[group_name]

        if start_time >= group.time_steps[-1]:
            print("ERROR: start time of FFT must be less than the maximum simulation time.")
            return

        start_index = bisect.bisect_left(group.time_steps, start_time)

        # https://docs.scipy.org/doc/scipy/reference/tutorial/fft.html

        y = np.array(group.history[start_index:])
        N = len(y)

        # A single point has no spacing and an empty spectrum
        if N < 2:
            print("ERROR: not enough data points after the start time to compute the FFT.")
            return

        # Sample spacing, assumes that each event has the same time interval.
        T = (group.time_steps[-1] - group.time_steps[start_index]) / N

        xf = fftfreq(N, T)[:N//2]
        yf = fft(y - np.mean(y))
        yf = 2.0/N * np.abs(yf[:N//2])

        # Get the max amplitude and the frequency at the max amplitude
        peak = int(np.argmax(yf))
        max_amp = yf[peak]
        max_freq = xf[peak]

        plt.text(0.95,0.95, f"Max amplitude: {max_amp:.3e}\n Max frequency: {max_freq:.3e}",transform=plt.gca().transAxes, ha="right", va="top")

        plt.plot(xf, yf)

        plt.title("Fast Fourier Transform")
        plt.xlabel("Frequency")
        plt.ylabel("Amplitude")

        plt.grid(alpha=0.3)

    def print_group_levels(self):
        """
        Prints the current level of every group.
        """

        for group in self.groups.values():
            print(f"{group.name}: {group.number}")

class Group:
    """
    A named compartment of the population. Keeps the full history of its size
    together with the times at which the size was recorded.
    """

    def __init__(self, name: str, initial: int):
        self.name = name
        self.initial = initial
        self.time_steps = [0]
        self.history = [initial]

    @property
    def number(self):
        """The most recently recorded size of the group."""
        return self.history[-1]

    def append_to_history(self, value: int, time: float):
        """Records the size `value` at time `time`."""
        self.time_steps.append(time)
        self.history.append(value)

    def add_value(self, value: int, time: float):
        """Changes the current size by `value` and records it at time `time`."""
        new_size = self.number + value
        self.append_to_history(new_size, time)

    def reset(self):
        """Restores the group to its initial size and clears the history."""
        self.__init__(self.name, self.initial)

    def __str__(self):
        return "{}: {}".format(self.name, self.history)


In [None]:
class SIRSimulator(Simulator):
    """
    Stochastic SIR model with demography and imports.

    beta:    transmission rate
    gamma:   recovery rate
    delta:   scale of the "move in" import rate (adds an infected)
    epsilon: scale of the "pass through" import rate (infects a susceptible)
    mu:      per-capita birth and death rate
    s_init, i_init, r_init: initial sizes of the three groups
    max_time: time at which the simulation stops

    Additional keyword arguments are accepted and ignored, so the class can be
    constructed with `**locals()` from the wrapper functions.
    """

    def __init__(self, beta=3.0, gamma=1.0, delta=0.0, epsilon=0.0, mu=0.0, s_init=1000, i_init=5, r_init=0, max_time=10.0, **kwargs):

        groups = {
            "susceptible": Group("susceptible", s_init),
            "infected": Group("infected", i_init),
            "recovered": Group("recovered", r_init),
        }

        def transmission_rate(groups):
            # An empty population cannot transmit; this also avoids dividing by zero
            n = self.total_n()
            if n <= 0:
                return 0.0
            return beta * groups["susceptible"].number * groups["infected"].number / n

        def import_pass_through_rate(groups):
            # Skip the R0 computation when the import is disabled, so gamma + mu == 0
            # does not raise while epsilon is zero.
            n = self.total_n()
            if epsilon == 0 or n <= 0:
                return 0.0
            r0 = beta / (gamma + mu)
            # For R0 < 1 the expression is negative; a rate cannot be negative
            return max(0.0, epsilon * groups["susceptible"].number * (r0 - 1) / n**0.5)

        events = [
            Event(name="birth",
                rate=lambda groups: mu * self.total_n(),
                event={"susceptible": 1}),
            Event(name="transmission",
                rate=transmission_rate,
                event={"susceptible": -1, "infected": 1}),
            Event(name="recovery",
                rate=lambda groups: gamma * groups["infected"].number,
                event={"infected": -1, "recovered": 1}),
            Event(name="death_s",
                rate=lambda groups: mu * groups["susceptible"].number,
                event={"susceptible": -1}),
            Event(name="death_i",
                rate=lambda groups: mu * groups["infected"].number,
                event={"infected": -1}),
            Event(name="death_r",
                rate=lambda groups: mu * groups["recovered"].number,
                event={"recovered": -1}),
            Event(name="import_move_in",
                rate=lambda groups: delta * self.total_n()**0.5,
                event={"infected": 1}),
            Event(name="import_pass_through",
                rate=import_pass_through_rate,
                event={"susceptible": -1, "infected": 1}),
        ]

        super().__init__(groups, events, max_time)

In [None]:
class SIRDeterminisic:
    """
    Deterministic SIR model on normalised group sizes, integrated with fixed
    time steps. The level of every group is stored after each step.
    """

    def __init__(self, susceptible, infected, recovered, infection_rate, recovery_rate, time_step, max_time, natural_death_rate, infection_death_probability):

        # Normalise the groups so that they sum to one
        population = susceptible + infected + recovered
        self.susceptible = susceptible / population
        self.infected = infected / population
        self.recovered = recovered / population

        self.infection_rate = infection_rate
        self.recovery_rate = recovery_rate
        self.natural_death_rate = natural_death_rate
        self.infection_death_probability = infection_death_probability

        self.time = 0
        self.time_step = time_step
        self.max_time = max_time

        # History of the system. The times are stored separately, which allows the
        # time step to be changed during the simulation.
        self.time_steps = [self.time]
        self.susceptible_at_timestep = [self.susceptible]
        self.infected_at_timestep = [self.infected]
        self.recovered_at_timestep = [self.recovered]

    def step(self):
        """Advances the system by one time step and records the new levels."""

        # All changes are computed from the old state before any group is updated
        d_susceptible, d_infected, d_recovered = (
            self.calc_susceptible_change(),
            self.calc_infected_change(),
            self.calc_recovered_change(),
        )

        self.susceptible += d_susceptible
        self.infected += d_infected
        self.recovered += d_recovered
        self.time += self.time_step

        self.time_steps.append(self.time)
        self.susceptible_at_timestep.append(self.susceptible)
        self.infected_at_timestep.append(self.infected)
        self.recovered_at_timestep.append(self.recovered)

    def calc_susceptible_change(self):
        """Change of the susceptible group during one time step."""
        infections = self.infection_rate * self.susceptible * self.infected
        deaths = self.natural_death_rate * self.susceptible
        births = self.natural_death_rate * (self.susceptible + self.infected + self.recovered)
        return (-infections - deaths + births) * self.time_step

    def calc_infected_change(self):
        """Change of the infected group during one time step."""

        # A death probability of one would divide by zero below
        if self.infection_death_probability == 1:
            return -self.infected

        infections = self.infection_rate * self.susceptible * self.infected
        recoveries = self.recovery_rate * self.infected
        natural_deaths = self.natural_death_rate * self.infected
        death_odds = self.infection_death_probability / (1 - self.infection_death_probability)
        infection_deaths = death_odds * (self.recovery_rate + self.natural_death_rate) * self.infected
        return (infections - recoveries - natural_deaths - infection_deaths) * self.time_step

    def calc_recovered_change(self):
        """Change of the recovered group during one time step."""
        recoveries = self.recovery_rate * self.infected
        deaths = self.natural_death_rate * self.recovered
        return (recoveries - deaths) * self.time_step

    def run(self, t=None):
        """Steps the system until time t is reached (max_time when t is None)."""
        end_time = self.max_time if t is None else t

        while self.time < end_time:
            self.step()

    def plot_SIR_levels(self):
        """Plots the recorded levels of the three groups over time."""
        series = (
            (self.susceptible_at_timestep, "Susceptible"),
            (self.infected_at_timestep, "Infected"),
            (self.recovered_at_timestep, "Recovered"),
        )
        for levels, label in series:
            plt.plot(self.time_steps, levels, label=label)

        plt.title("SIR levels over time")
        plt.xlabel("Time")
        plt.ylabel("Cases")
        plt.grid(alpha=0.3)
        plt.legend()
        plt.tight_layout()

    def print_SIR_levels(self):
        """Prints the current levels of the three groups."""
        print(f"Susceptible: {self.susceptible}")
        print(f"Infected: {self.infected}")
        print(f"Recovered: {self.recovered}")
 

In [None]:

def sir_gillespie(beta=3.0, gamma=1.0, delta=0.0, epsilon=0.0, mu=0.0, s_init=1000, i_init=5, r_init=0, max_time=10.0, show_plot=True):
    """
    Runs a single stochastic SIR simulation with the given parameters and
    returns the simulator. When show_plot is set, the current figure is cleared
    and the group levels are plotted.
    """

    sim = SIRSimulator(
        beta=beta,
        gamma=gamma,
        delta=delta,
        epsilon=epsilon,
        mu=mu,
        s_init=s_init,
        i_init=i_init,
        r_init=r_init,
        max_time=max_time,
        show_plot=show_plot,
    )
    sim.run()

    if not show_plot:
        return sim

    plt.clf()
    sim.plot_group_levels()
    return sim


In [None]:
%matplotlib widget

# Expose the parameters of sir_gillespie as interactive ipywidgets controls, using the ranges given below.
interactive(sir_gillespie, beta=(0, 5.0), gamma=(0, 5.0), delta=(0, 5.0), epsilon=(0, 5.0, 0.01), mu=(0, 1.0, 0.01), s_init=(0, 2000), i_init=(0, 1000), r_init=(0, 1000), max_time=(0, 1000), show_plot=True)

## Variance

In [None]:
%matplotlib inline
from statistics import variance

beta_values = np.linspace(0, 3, 50, endpoint=False)
variances = []

# Show the variance in the final value of R for different values of beta
for beta in beta_values:

    # NOTE: despite its name this list holds the final size of the *recovered*
    # group of every run. The name is kept so later cells can still refer to it.
    end_infected = []
    print(round(beta, 8), end='--')

    # 200 independent runs per beta value
    for _ in range(200):

        sim = sir_gillespie(beta=beta, gamma=1.0, delta=0.0, epsilon=0.0, mu=0.0, s_init=1000, i_init=25, r_init=0, max_time=100.0, show_plot=False)
        end_infected.append(sim.groups["recovered"].number)

    variances.append(variance(end_infected))

plt.title(r"Variance dependence on $\beta$")
# The plotted quantity is the variance of the final recovered count
plt.ylabel("Variance of final recovered")
plt.xlabel(r"$\beta$")
plt.plot(beta_values, variances)
plt.show()

## Negative covariances

In [None]:
def find_elem(elem, sorted_list):
    """
    Binary search for elem in sorted_list. Returns the index of its leftmost
    occurrence, or -1 when elem is not in the list.
    """

    position = bisect.bisect_left(sorted_list, elem)
    found = position < len(sorted_list) and sorted_list[position] == elem

    return position if found else -1

def filter_common_times(groups, group1, group2, from_time=0):
    """
    Returns the levels of two groups at the time stamps that both groups have
    in common.

    groups:    mapping of group key -> group (with sorted `time_steps` and `history`)
    group1:    key of the first group
    group2:    key of the second group
    from_time: fraction (0 to 1) of the last time stamp of group1. Only time
               stamps from the one closest to this point onwards are used.

    Returns a tuple (data1, data2) of equally long lists. Both lists are empty
    when group1 has no time stamps.
    """

    times1 = groups[group1].time_steps
    history1 = groups[group1].history
    times2 = groups[group2].time_steps
    history2 = groups[group2].history

    data1 = []
    data2 = []

    if len(times1) == 0:
        return data1, data2

    # get the index of the time stamp closest to "from_time" x the maximum time
    slice_time = int(np.argmin(np.abs(np.array(times1) - (times1[-1] * from_time))))

    for time1, value1 in zip(times1[slice_time:], history1[slice_time:]):

        # The binary search already yields the position in group2, so no
        # additional linear search through the list is needed.
        index2 = bisect.bisect_left(times2, time1)

        if index2 < len(times2) and times2[index2] == time1:
            data1.append(value1)
            data2.append(history2[index2])

    return data1, data2

In [None]:
%matplotlib inline
# Single run with demography; the first 10% of the run is left out of the analysis
sim = sir_gillespie(mu=0.2, max_time=100, show_plot=True)

s_data, i_data = filter_common_times(sim.groups, "susceptible", "infected", from_time=0.1)

# The off-diagonal element of the covariance matrix is the S-I covariance
print(f"Covariance: {np.cov(s_data, i_data)[0, 1]}")

plt.plot(s_data, i_data)
plt.title("Stochastic phase space of susceptible and infected")
plt.xlabel("Susceptible")
plt.ylabel("Infected")
plt.show()

In [None]:
N_values = np.linspace(100, 10000, 20, endpoint=False)
cov_averages = []

for N in N_values:
    print(N)

    # Keep simulating until 20 runs are collected in which the infection is
    # still present at the end of the run.
    covs = []
    while len(covs) < 20:

        sim = sir_gillespie(beta=15, gamma=1.216, mu=0.015, epsilon=1.06, max_time=200, s_init=N, i_init=0, show_plot=False)
        s_data, i_data = filter_common_times(sim.groups, "susceptible", "infected", from_time=0.1)

        if sim.groups["infected"].history[-1] == 0:
            continue

        covs.append(np.cov(s_data, i_data)[0][1])

    cov_averages.append(np.mean(covs))

plt.plot(N_values, cov_averages)
plt.title("Covariance dependence on population size")
plt.xlabel("Population size")
plt.ylabel("Covariance")
plt.show()

## Transients

In [None]:
%matplotlib widget
def compare_determinisic_with_stochastic(beta=1.4, gamma=0.6, delta=0.0, epsilon=0.0, mu=0.06, s_init=1000, i_init=30, r_init=0, max_time=100.0, show_plot=True):
    """
    Runs the stochastic and the deterministic model with the same parameters and
    plots the infected group of both in a single figure. Only the infected are
    shown to keep the plot readable.
    """

    sim_stochastic = SIRSimulator(
        beta=beta,
        gamma=gamma,
        delta=delta,
        epsilon=epsilon,
        mu=mu,
        s_init=s_init,
        i_init=i_init,
        r_init=r_init,
        max_time=max_time,
        show_plot=show_plot,
    )
    sim_deterministic = SIRDeterminisic(s_init, i_init, r_init, beta, gamma, 0.01, max_time, mu, 0)

    sim_stochastic.run()
    sim_deterministic.run()
    plt.clf()

    # Stochastic run: step plot of the infected group
    group = sim_stochastic.groups["infected"]
    plt.plot(group.time_steps, group.history, label=f"{group.name} stochastic", drawstyle="steps-post")

    # Deterministic run: the model is normalised, so scale back to absolute numbers
    N = s_init + r_init + i_init
    infected_deterministic = np.array(sim_deterministic.infected_at_timestep) * N
    plt.plot(sim_deterministic.time_steps, infected_deterministic, label="Infected deterministic")

    plt.title("Deterministic vs stochastic model")
    plt.xlabel("Time")
    plt.ylabel("Cases")
    plt.legend()
    plt.tight_layout()

# Expose the parameters of the deterministic-vs-stochastic comparison as interactive ipywidgets controls.
interactive(compare_determinisic_with_stochastic, beta=(0, 5.0), gamma=(0, 5.0), delta=(0, 5.0), epsilon=(0, 5.0, 0.01), mu=(0, 1.0, 0.01), s_init=(0, 2000, 50), i_init=(0, 1000, 50), r_init=(0, 1000, 50), max_time=(0, 1000, 10), show_plot=True)

## Stochastic resonance

In [None]:
%matplotlib widget

def plot_fourier(beta=1.4, gamma=0.6, delta=0.0, epsilon=0.0, mu=0.06, s_init=1000, i_init=30, r_init=0, max_time=100.0, show_plot=True):
    """
    Runs one stochastic simulation and shows the group levels (left subplot)
    next to the fourier transform of the infected group from t=50 onwards
    (right subplot).
    """
    sim = SIRSimulator(
        beta=beta,
        gamma=gamma,
        delta=delta,
        epsilon=epsilon,
        mu=mu,
        s_init=s_init,
        i_init=i_init,
        r_init=r_init,
        max_time=max_time,
        show_plot=show_plot,
    )
    sim.run()

    plt.clf()

    # Left: evolution of the groups
    plt.subplot(121)
    sim.plot_group_levels()

    # Right: frequency spectrum of the infected
    plt.subplot(122)
    sim.plot_FFT("infected", start_time=50)

# Expose the parameters of plot_fourier as interactive ipywidgets controls.
interactive(plot_fourier, beta=(0, 5.0), gamma=(0, 5.0), delta=(0, 5.0), epsilon=(0, 5.0, 0.01), mu=(0, 1.0, 0.01), s_init=(0, 2000, 50), i_init=(0, 1000, 50), r_init=(0, 1000, 50), max_time=(0, 1000, 10), show_plot=True)

## Extinction

In [None]:
%matplotlib widget

def calc_critical_size(gamma, r0, mu = 0.3, max_time=100, treshold=0.8, precision=1, nruns=10, size_guess=100):
    """
    Calculates the critical size of a system. At the critical size the system goes extinct
    <treshold> percent of the time. It uses some version of binary search to find the critical
    size. At first it calculates if the initial size results in extinction, then it decreases
    or increases the size exponentially until the opposite result of the initial system is found.
    (the crossover) Now we know at which region the critical size should be. We then start reducing
    our step sizes until we are at our critical size.

    gamma:      recovery rate
    r0:         basic reproduction number, used to derive beta
    mu:         birth and death rate
    max_time:   duration of every simulation run
    treshold:   fraction (0 to 1) of runs that must go extinct
    precision:  the search stops once the step size is no larger than this
    nruns:      number of runs per population size, also used for the initial size
    size_guess: population size at which the search starts

    Returns the last population size considered. This is also the case when the
    search leaves the supported range of 10 to 10000.
    """

    # r0 = (beta / (gamma + mu) ==> beta = (r0 * (gamma + mu))
    beta = (r0 * (gamma + mu))

    infection_percentage = 0.1

    def extinct_at(size):
        # Checks if a population of the given size goes extinct often enough
        s_init = int((1 - infection_percentage) * size)
        i_init = int(size - s_init)
        sim = SIRSimulator(beta=beta, gamma=gamma, mu=mu, s_init=s_init, i_init=i_init, max_time=max_time)
        return is_extinct_rate_above_treshold(sim, treshold, nruns=nruns)

    N = size_guess
    dN = N

    # Get the result of the initial system, using the same number of runs as
    # the rest of the search so both results are comparable.
    initial_extinct = extinct_at(N)

    crossover_found = False

    while dN > precision:

        print(f"\rCalc crit size for gamma={gamma:.3e} and r0={r0:.3e} N: {N}, dN {dN} ", end="")

        # Upper- and lower bounds of the sizes we're checking. This is done since the results either dont
        # make sense (populations smaller than 10???) or because the runs would take too long
        if N > 10000 or N < 10:

            print("ERROR: Critical size is out of range")
            break

        extinct = extinct_at(N)

        if not crossover_found:

            crossover_found = initial_extinct != extinct

        # Update the step size: shrink once the crossover is known or while
        # moving down, otherwise grow exponentially
        if crossover_found or not extinct:
            dN /= 2
        else:
            dN *= 2

        # Update the size: larger when extinct, smaller when not
        if extinct:
            N += dN
        else:
            N -= dN

    print()

    return N

def goes_extinct(sim):
    """
    Steps a SIR simulator until either the infected group is empty or the
    maximum time is reached. Returns True when the infection died out (the
    results are then finalised), otherwise False.
    """

    extinct = False

    while not extinct and sim.time < sim.max_time:
        sim.update()
        extinct = bool(sim.groups["infected"].number == 0.0)

    if extinct:
        sim.finalise_results()

    return extinct

def is_extinct_rate_above_treshold(sim, treshold, nruns=10):
    """
    Checks if a simulator results in a extinction for at least <treshold> percent of the runs.

    sim:      simulator, it is reset before every run
    treshold: required fraction of extinctions, has a range of 0 to 1
    nruns:    maximum number of runs

    Returns True when at least ceil(nruns * treshold) runs go extinct, else False.
    """

    # Calculate the number of extinctions or non extinctions needed to reach the treshold
    # This is used to pre-emptively obtain the results. For example if we have 100 runs
    # and a treshold of 50% and the first 50 runs result in an extinction, then we dont
    # need to run the other 50 runs, since we will always reach the treshold.
    n_extinctions_treshold = math.ceil(nruns * treshold)
    n_non_extinctions_treshold = nruns - n_extinctions_treshold

    n_extinctions = 0
    n_non_extinctions = 0

    for _ in range(nruns):

        sim.reset()

        if goes_extinct(sim):
            n_extinctions += 1
        else:
            n_non_extinctions += 1

        # The treshold is reached: the remaining runs cannot change the result
        if n_extinctions >= n_extinctions_treshold:
            return True

        # Too many runs survived: the treshold can no longer be reached
        if n_non_extinctions > n_non_extinctions_treshold:
            return False

    return True

def obtain_extinction_data(gammas, r0s):
    """
    Calculates the critical size for every combination of r0 and gamma.
    Returns a list with one row per r0 value, each row holding the critical
    sizes for all gamma values.
    """

    return [
        [calc_critical_size(gamma, r0, max_time=100, nruns=200, precision=1) for gamma in gammas]
        for r0 in r0s
    ]

# Parameter grid: 10 log-spaced gamma values from 1e-1 down to 1e-3 and
# 10 log-spaced R0 values from 10**0.01 up to 10.
gammas = np.logspace(-1, -3, 10)
r0s = np.logspace(0.01, 1, 10)
# NOTE: this computes a critical size for every (r0, gamma) pair with nruns=200, so it is likely slow.
data = obtain_extinction_data(gammas, r0s)

In [None]:
import matplotlib.ticker as mticker

# https://stackoverflow.com/a/67774238
def log_tick_formatter(val, pos=None):
    """
    Tick formatter for axes that show log10 values: converts the exponent
    `val` back to its value and formats it in e-notation. `pos` is unused.
    """
    value = 10 ** val
    return "{:.1e}".format(value)

def plot_extinction(gammas, r0s, sizes):
    """
    Plots the critical sizes as a 3D surface over gamma and R0. All three axes
    show log10 values, labelled with the original values in e-notation.
    """

    X, Y = np.meshgrid(gammas, r0s)
    log_sizes = np.log10(np.array(sizes))

    fig = plt.figure()
    ax = plt.axes(projection='3d')
    ax.plot_surface(np.log10(X), np.log10(Y), log_sizes,  rstride=1, cstride=1, cmap='viridis', edgecolor='none')

    # Log scale for x, y, and z-axis
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_major_formatter(mticker.FuncFormatter(log_tick_formatter))

    # Limit the number of ticks on the axis, otherwise the axis would become too cluttered
    plt.locator_params(nbins=4)

    fig.suptitle("Critical size")
    ax.set_xlabel("gamma")
    ax.set_ylabel("R0")
    ax.set_zlabel("size")
    plt.tight_layout()
    
# Plot the critical sizes computed by obtain_extinction_data for the gamma / R0 grid.
plot_extinction(gammas, r0s, data)

# Spatial models

In [None]:
class SIRSpatialModel:
    """
    SIR model for several coupled groups with scaled noise on every process.

    X, Y, Z: susceptible, infected and recovered per group (1D arrays)
    rho:     (ngroups, ngroups) coupling matrix; rho[i][j] weighs the infected
             of group j in the force of infection on group i
    beta:    transmission rate per group
    nu:      birth rate per group
    mu:      death rate per group
    gamma:   recovery rate per group
    max_time: time at which the simulation stops
    dt:      size of a time step

    The state arrays are copied as floats, so the arrays passed in are never
    modified and integer input is supported.
    """

    def __init__(self, X: np.ndarray, Y: np.ndarray, Z: np.ndarray, rho: np.ndarray, beta: np.ndarray, nu: np.ndarray, mu: np.ndarray, gamma: np.ndarray, max_time: float, dt: float):

        # Dimension checking of arrays
        shape = X.shape
        for nparray in [X, Y, Z, beta, nu, mu, gamma]:

            if nparray.shape != shape:
                raise ValueError("Error: The dimensions of the lists don't match.")

        if rho.shape != (shape[0], shape[0]):
            raise ValueError("Error: rho has wrong dimensions")

        self.ngroups = shape[0]

        # Work on float copies, the state is updated with fractional changes
        self.X = np.array(X, dtype=float)
        self.Y = np.array(Y, dtype=float)
        self.Z = np.array(Z, dtype=float)

        self.group_sizes = self.calc_group_sizes()
        self.rho = rho
        self.beta = beta
        self.nu = nu
        self.mu = mu
        self.gamma = gamma

        self.time = 0
        self.max_time = max_time
        self.dt = dt

        # These lists keep track of the evolution of the system
        self.X_history = []
        self.Y_history = []
        self.Z_history = []
        self.time_steps = []

        self.update_lists()

    def update_lists(self):
        """Stores a copy of the current state and time in the history lists."""

        self.X_history.append(np.copy(self.X))
        self.Y_history.append(np.copy(self.Y))
        self.Z_history.append(np.copy(self.Z))
        self.time_steps.append(self.time)

    def run(self):
        """Updates the system until max_time is reached."""

        while self.time < self.max_time:
            self.update()

    @staticmethod
    def _with_noise(term, xi):
        # Adds noise scaled with the square root of the process rate
        return term + np.sqrt(term) * xi

    def update(self):
        """Advances the system by one time step of size dt and records the state."""

        lamda = self.calc_lamda()

        xis = self.calc_xis()

        # Every process gets its own noise term. A process that moves individuals
        # between two groups uses the same noisy value on both sides.
        births = self._with_noise(self.nu * self.group_sizes, xis[0])
        transmissions = self._with_noise(lamda * self.X, xis[1])
        deaths_X = self._with_noise(self.mu * self.X, xis[2])
        recoveries = self._with_noise(self.gamma * self.Y, xis[3])
        deaths_Y = self._with_noise(self.mu * self.Y, xis[4])
        deaths_Z = self._with_noise(self.mu * self.Z, xis[5])

        dX = births - transmissions - deaths_X
        dY = transmissions - recoveries - deaths_Y
        # The recovered die at a rate based on Z, not on Y
        dZ = recoveries - deaths_Z

        # Due to the noise the X, Y, and Z groups might become negative. This makes sure that they remain positive
        self.X = np.maximum(self.X + dX * self.dt, 0)
        self.Y = np.maximum(self.Y + dY * self.dt, 0)
        self.Z = np.maximum(self.Z + dZ * self.dt, 0)

        self.group_sizes = self.calc_group_sizes()

        self.time += self.dt

        self.update_lists()

    def calc_group_sizes(self):
        """Returns the total size of every group."""

        return self.X + self.Y + self.Z

    def calc_lamda(self):
        """
        Returns the force of infection per group:
        lamda_i = beta_i * sum_j(rho_ij * Y_j) / N_i
        Empty groups (N_i == 0) get a force of infection of zero.
        """

        pressure = np.asarray(self.rho, dtype=float) @ self.Y
        sizes = self.group_sizes

        per_capita = np.divide(pressure, sizes, out=np.zeros_like(pressure), where=sizes > 0)

        return self.beta * per_capita

    def calc_xis(self):
        """Returns one normally distributed noise value per process and group, scaled by 1/sqrt(dt)."""

        return np.random.normal(size=(6, self.ngroups)) / np.sqrt(self.dt)

    def plot_group_levels(self):
        """Plots the evolution of X, Y and Z of every group, one line style per group."""

        line_styles = ["-", "--", "-.", ":"]
        for i in range(self.ngroups):

            # Only label the first group, otherwise the legend would become too cluttered
            if i == 0:
                label_sus = "susceptible"
                label_inf = "infected"
                label_rec = "recovered"
            else:
                label_sus = ""
                label_inf = ""
                label_rec = ""

            # Cycle through the available styles when there are more groups than styles
            line_style = line_styles[i % len(line_styles)]

            plt.plot(self.time_steps, [row[i] for row in self.X_history], label=label_sus, ls=line_style, c="#1f77b4")
            plt.plot(self.time_steps, [row[i] for row in self.Y_history], label=label_inf, ls=line_style, c="#ff7f0e")
            plt.plot(self.time_steps, [row[i] for row in self.Z_history], label=label_rec, ls=line_style, c="#2ca02c")

        plt.title("Evolution of system")
        plt.xlabel("Time")
        plt.ylabel("Cases")
        plt.legend()
        plt.tight_layout()
        plt.show()

    def print_group_levels(self):
        """Prints the complete history of X, Y and Z."""

        print(self.X_history)
        print(self.Y_history)
        print(self.Z_history)
        
class UniformSIRSpatialModel(SIRSpatialModel):
    """
    Spatial SIR model in which all groups start with the same sizes and share
    the same parameters.

    ngroups: number of groups
    x, y, z: initial susceptible, infected and recovered of every group
    rhoii:   coupling of a group with itself (diagonal of rho)
    rhoij:   coupling between two different groups (off-diagonal of rho)
    beta, nu, mu, gamma: rates shared by all groups
    max_time, dt: end time and step size of the simulation
    """

    def __init__(self, ngroups: int, x: float, y: float, z: float, rhoii: float, rhoij: float, beta=10., nu=0., mu=0., gamma=0.5, max_time=10., dt=0.01):

        def uniform(value):
            # Explicit float dtype: np.full would otherwise create an integer
            # array for integer arguments.
            return np.full(ngroups, value, dtype=float)

        # Float dtype so that a fractional rhoii is not truncated when rhoij is an integer
        rho = np.full((ngroups, ngroups), rhoij, dtype=float)
        np.fill_diagonal(rho, rhoii)

        super().__init__(
            uniform(x),
            uniform(y),
            uniform(z),
            rho,
            uniform(beta),
            uniform(nu),
            uniform(mu),
            uniform(gamma),
            max_time,
            dt
        )

In [None]:
%matplotlib widget

# Four identical groups, each starting with x=10000, y=1 and z=0. The coupling
# is 1 within a group (rhoii) and 0.01 between groups (rhoij).
sim = UniformSIRSpatialModel(4, 10000., 1., 0., 1, 0.01, max_time=20)
sim.run()
sim.plot_group_levels()

## NDlib

In [None]:
import networkx as nx
import ndlib.models.epidemics as ep
import ndlib.models.ModelConfig as mc
from bokeh.io import output_notebook, show
from ndlib.viz.mpl.DiffusionTrend import DiffusionTrend
from ndlib.viz.mpl.DiffusionPrevalence import DiffusionPrevalence
# from ndlib.viz.mpl.MultiPlot import MultiPlot

In [None]:
def configure_ndlib_model(network, beta, gamma, fraction_infected, number_iterations, infect_highest_degrees=False):
    """
    Configures a NDlib SIR model on the given network, runs it for
    number_iterations iterations and returns the resulting trends.

    When infect_highest_degrees is set, the nodes with the highest degree are
    infected initially (their number is fraction_infected of all nodes, rounded).
    Otherwise fraction_infected is passed to the model as a parameter.
    """

    # Model Selection
    model = ep.SIRModel(network)

    # Model Configuration
    config = mc.Configuration()
    for parameter, value in (('beta', beta), ('gamma', gamma)):
        config.add_model_parameter(parameter, value)

    if not infect_highest_degrees:
        # infect randomly
        config.add_model_parameter("fraction_infected", fraction_infected)
    else:
        number_infected = round(network.number_of_nodes()*fraction_infected)
        by_degree = sorted(network.degree, key=lambda x: x[1], reverse=True)
        nodes_to_infect = [node for node, _ in by_degree[:number_infected]]

        config.add_model_initial_configuration("Infected", nodes_to_infect)

    model.set_initial_status(config)

    # Simulation
    iterations = model.iteration_bunch(number_iterations)

    return model.build_trends(iterations)

def plot_trends_prevalence(trends):
    """
    Plots the node counts with status 0, 1 and 2 of the first trend, labelled
    as the susceptible, infected and recovered population, per iteration.
    """
    node_count = trends[0]["trends"]["node_count"]
    S, I, R = node_count[0], node_count[1], node_count[2]

    series = (
        (S, "Susceptible population"),
        (I, "Infected population"),
        (R, "Recovered population"),
    )
    for counts, label in series:
        plt.plot(counts, label=label)

    plt.xlabel("iterations")
    plt.ylabel("number in population")
    plt.legend()
    plt.show()

In [None]:
%matplotlib widget

# Network Definition
# g = nx.barabasi_albert_graph(1000, 53)
g = nx.watts_strogatz_graph(1000, 100, 1)
# g = nx.erdos_renyi_graph(1000, 0.1)

# print(nx.average_shortest_path_length(g))
# nx.info() was removed in NetworkX 3.0, so the summary is built from the
# graph itself, which works on both old and new versions.
print(f"Graph with {g.number_of_nodes()} nodes and {g.number_of_edges()} edges")

# beta=0.001, gamma=0.01, 1% of the nodes infected initially, 200 iterations
trends = configure_ndlib_model(g, 0.001, 0.01, 0.01, 200, infect_highest_degrees=True)
plot_trends_prevalence(trends)