# Introduction to Computational Science - Assignment 2
Sander Broos, Nick van Santen

In [74]:
# Imports
from __future__ import annotations
from ipywidgets import *

import matplotlib.pyplot as plt
import numpy as np

from typing import Callable, List

In [75]:
# Run this cell to enlarge the default font sizes; useful when saving plots.
SMALL_SIZE = 16
MEDIUM_SIZE = 20
BIGGER_SIZE = 24

plt.rc('font', size=SMALL_SIZE)                                  # default text size
plt.rc('axes', titlesize=SMALL_SIZE, labelsize=MEDIUM_SIZE)      # axes title and x/y labels
plt.rc(('xtick', 'ytick'), labelsize=SMALL_SIZE)                 # tick labels on both axes
plt.rc('legend', fontsize=SMALL_SIZE)                            # legend text
plt.rc('figure', titlesize=BIGGER_SIZE)                          # figure (suptitle) title

In [76]:
class Event:
    """One possible event (reaction channel) in a Gillespie simulation.

    Attributes:
        name: human-readable label of the event.
        rate: callable that receives the groups mapping and returns this
            event's current rate.
        event: mapping from group key to the amount added to that group
            each time the event occurs.
    """

    def __init__(self, name: str, rate: Callable, event: dict):
        self.name, self.event, self.rate = name, event, rate

    def occur(self, groups, time):
        """Apply this event once: add each configured amount to its group at ``time``."""
        for key in self.event:
            groups[key].add_value(self.event[key], time)

class Simulator:
    """Gillespie stochastic-simulation driver.

    Each step draws an exponentially distributed waiting time from the total
    rate of all events, then picks one event with probability proportional to
    its rate and applies it to the groups, until ``max_time`` is reached.
    """

    def __init__(self, groups: dict[str, Group], events: list[Event], max_time: float):
        """Set up a simulation.

        Args:
            groups: mapping from group key to Group. Event rates and effects
                look groups up by these keys.
            events: the possible events; each provides ``rate(groups)`` and
                ``occur(groups, time)``.
            max_time: time at which the simulation stops.
        """
        self.groups = groups

        self.events = events

        self.time = 0
        self.max_time = max_time
        # Times at which an event was applied, starting from t = 0.
        self.time_steps = [0]

    def update(self):
        """Advance the simulation by one Gillespie step.

        If no event can occur (total rate 0) the clock jumps to ``max_time``.
        An event whose time would fall after ``max_time`` is discarded.

        Raises:
            RuntimeError: if no event could be selected (e.g. a rate is NaN).
        """
        # Evaluate every rate exactly once; determine_event reuses these.
        rates = [event.rate(self.groups) for event in self.events]
        total_rate = sum(rates)

        if total_rate == 0:
            self.time = self.max_time
            return

        # np.random.rand() is in [0, 1) and may return 0.0, so log(r1) could be
        # -inf and end the run early; 1 - rand() lies in (0, 1] instead.
        r1 = 1.0 - np.random.rand()
        self.time += -np.log(r1) / total_rate

        if self.time > self.max_time:
            return

        self.time_steps.append(self.time)

        p = np.random.rand() * total_rate
        event = self.determine_event(p, rates)
        if event is None:
            raise RuntimeError(
                "No event could be selected; check that all event rates are "
                "finite and non-negative."
            )

        # Apply event
        event.occur(self.groups, self.time)

    def determine_event(self, p: float, rates: list[float] | None = None):
        """Select the event whose cumulative-rate interval contains ``p``.

        Args:
            p: a value in ``[0, total_rate)``.
            rates: optional rates of ``self.events`` in the same order. When
                omitted they are evaluated from the current groups.

        Returns:
            The first event whose cumulative rate exceeds ``p``. If none does,
            an error is printed and ``None`` is returned.
        """
        if rates is None:
            rates = [event.rate(self.groups) for event in self.events]

        cumulative = 0
        for event, rate in zip(self.events, rates):
            cumulative += rate
            # Strict '>' so zero-rate events are never chosen.
            if cumulative > p:
                return event

        print("ERROR: No event found")
        return None

    def run(self):
        """Step until ``max_time``, then record every group's final count at ``max_time``."""
        while self.time < self.max_time:
            self.update()

        # Close each history at max_time so plots extend to the end of the window.
        for group in self.groups.values():
            group.append_to_history(group.number, self.max_time)

    def plot_group_levels(self):
        """Plot each group's recorded counts as a step function and show the figure."""
        for group in self.groups.values():
            plt.plot(group.time_steps, group.history, label=group.name, drawstyle="steps-post")

        plt.legend()
        plt.show()
        plt.close()

class Group:
    """A named population count that keeps a record of its values over time."""

    def __init__(self, name: str, initial: int):
        self.name = name
        self.number = initial
        # Parallel lists of recorded counts and the times they were recorded.
        self.history, self.time_steps = [initial], [0]

    def add_value(self, value: int, time: float):
        """Add ``value`` to the current count and record the new count at ``time``."""
        self.number = self.number + value
        self.append_to_history(self.number, time)

    def append_to_history(self, value: int, time: float):
        """Record ``value`` at ``time`` without changing the current count."""
        self.history += [value]
        self.time_steps += [time]


In [77]:
def sir_gillespie(beta=3, gamma=1, s_init=1000, i_init=5, r_init=0, max_time=10):
    """Run one Gillespie realisation of the SIR model and plot the group counts.

    Events:
        transmission: S -> I at rate beta * S * I / N, with N = S + I + R.
        recovery:     I -> R at rate gamma * I.

    Args:
        beta: transmission rate coefficient.
        gamma: recovery rate per infected individual.
        s_init, i_init, r_init: initial susceptible, infected and recovered counts.
        max_time: length of the simulated time window.
    """
    groups = {
        "susceptible": Group("susceptible", s_init),
        "infected": Group("infected", i_init),
        "recovered": Group("recovered", r_init),
    }

    def transmission_rate(state):
        total = sum(group.number for group in state.values())
        # The sliders allow every initial count to be 0; then N == 0 and the
        # frequency-dependent rate would divide by zero. No one can be
        # infected in an empty population, so the rate is 0.
        if total == 0:
            return 0
        return beta * state["infected"].number * state["susceptible"].number / total

    events = [
        Event(name="transmission",
              rate=transmission_rate,
              event={"infected": 1, "susceptible": -1}),
        Event(name="recovery",
              rate=lambda state: gamma * state["infected"].number,
              event={"infected": -1, "recovered": 1}),
    ]

    sim = Simulator(groups, events, max_time)
    sim.run()
    sim.plot_group_levels()

# Interactive SIR run: each (min, max) tuple becomes a slider for the
# matching sir_gillespie parameter.
interactive(
    sir_gillespie,
    beta=(0, 5.0),
    gamma=(0, 5.0),
    s_init=(0, 2000),
    i_init=(0, 1000),
    r_init=(0, 1000),
    max_time=(0, 100),
)

interactive(children=(FloatSlider(value=3.0, description='beta', max=5.0), FloatSlider(value=1.0, description=…