# Introduction to Computational Science – Assignment 2
Sander Broos, Nick van Santen

In [None]:
# Imports
from __future__ import annotations
from ipywidgets import FloatSlider, IntSlider, widgets
import matplotlib.pyplot as plt
import numpy as np
from typing import Callable, List

In [None]:
# Run this cell to enlarge all plot font sizes. Useful when saving plots.
SMALL_SIZE = 16
MEDIUM_SIZE = 20
BIGGER_SIZE = 24

# Same rcParams that plt.rc('<group>', <key>=...) would set, in the same order.
plt.rcParams.update({
    'font.size': SMALL_SIZE,          # default text size
    'axes.titlesize': SMALL_SIZE,     # axes title
    'axes.labelsize': MEDIUM_SIZE,    # x and y axis labels
    'xtick.labelsize': SMALL_SIZE,    # x tick labels
    'ytick.labelsize': SMALL_SIZE,    # y tick labels
    'legend.fontsize': SMALL_SIZE,    # legend text
    'figure.titlesize': BIGGER_SIZE,  # figure title
})

In [None]:
class Event:
    """A transition that can occur during the simulation.

    Attributes:
        name: Label of the event.
        rate: Zero-argument callable returning the event's current rate.
        event: Mapping from group name to the amount that group changes by
            when the event occurs.
    """

    def __init__(self, name: str, rate: Callable, event: dict):
        self.name, self.rate, self.event = name, rate, event

    def occur(self, groups, time):
        """Apply every change in ``self.event`` to ``groups`` at ``time``."""
        for group_name, change in self.event.items():
            groups[group_name].add_value(change, time)

class Simulator:
    """Stochastic event simulator using the Gillespie direct method.

    Each step evaluates the rates of all events, draws the waiting time until
    the next event from an exponential distribution with the total rate, and
    picks one event with probability proportional to its rate. The chosen
    event then updates the groups it affects.
    """

    def __init__(self, groups: dict[str, Group], events: List[Event], max_time: float):
        """Create a simulator.

        Args:
            groups: Mapping from group name to group; events address groups
                by these names.
            events: Possible events; each provides ``rate()`` and
                ``occur(groups, time)``.
            max_time: Time at which the simulation stops.
        """
        self.groups = groups
        self.events = events

        self.time = 0
        self.max_time = max_time
        # Times at which an event was applied, starting at t = 0.
        self.time_steps = [0]

    def update(self):
        """Advance the simulation by one event, or jump to ``max_time``.

        If the total rate is zero, no event can ever happen again, so the
        clock is moved straight to ``max_time``. If the drawn waiting time
        overshoots ``max_time``, the event is discarded.
        """
        # Evaluate every rate once and reuse it for the event selection.
        rates = [event.rate() for event in self.events]
        total_rate = sum(rates)

        if total_rate <= 0:
            self.time = self.max_time
            return

        # 1 - rand() lies in (0, 1], so the logarithm is always finite.
        r1 = 1.0 - np.random.rand()
        delta_time = -np.log(r1) / total_rate

        self.time += delta_time
        if self.time > self.max_time:
            return

        r2 = np.random.rand()
        event = self.determine_event(r2 * total_rate, rates)

        self.time_steps.append(self.time)
        event.occur(self.groups, self.time)

    def determine_event(self, p: float, rates: List[float] | None = None):
        """Select the event whose cumulative-rate interval contains ``p``.

        Args:
            p: Value in ``[0, total_rate)``.
            rates: Rates of ``self.events`` in the same order. When omitted,
                they are evaluated here.

        Returns:
            The selected event.

        Raises:
            RuntimeError: If ``p`` is not below the sum of the rates.
        """
        if rates is None:
            rates = [event.rate() for event in self.events]

        cumulative = 0
        for event, rate in zip(self.events, rates):
            cumulative += rate
            if cumulative > p:
                return event

        raise RuntimeError(
            f"No event found for p={p!r} (total rate {cumulative!r})"
        )

    def run(self):
        """Simulate until ``max_time``, then close every group's history.

        The current size of each group is recorded once more at ``max_time``
        so that step plots extend to the end of the time window.
        """
        while self.time < self.max_time:
            self.update()

        for group in self.groups.values():
            group.append_to_history(group.inhabitants, self.max_time)

    def plot_group_levels(self):
        """Plot every group's size over time as a step function and show it."""
        for group in self.groups.values():
            plt.plot(group.time_steps, group.history, label=group.name, drawstyle="steps-post")

        plt.legend()
        plt.show()

class Group:
    """A population compartment whose size is recorded over time.

    Attributes:
        name: Label of the group.
        inhabitants: Current number of members.
        history: Recorded sizes, one per entry of ``time_steps``.
        time_steps: Times at which the ``history`` entries were recorded.
    """

    def __init__(self, name: str, initial: int):
        self.name = name
        self.inhabitants = initial

        # Both records start with the initial state at t = 0.
        self.time_steps = [0]
        self.history = [initial]

    def add_value(self, value: int, time: float):
        """Change the size by ``value`` and record the new size at ``time``."""
        self.inhabitants += value
        self.append_to_history(self.inhabitants, time)

    def append_to_history(self, value: int, time: float):
        """Record ``value`` as the group's size at ``time``."""
        self.time_steps.append(time)
        self.history.append(value)


In [None]:
class SIRSimulator(Simulator):
    """Stochastic SIR (susceptible / infected / recovered) simulation.

    Keyword arguments override entries of ``self.parameters``:
    ``init_susceptible``, ``init_infected``, ``init_recovered`` (initial
    group sizes), ``inf_rate`` (transmission rate constant), ``r_rate``
    (recovery rate constant) and ``max_time`` (end time). Keys other than
    these are stored in ``self.parameters`` but never read.
    """

    def __init__(self, **kwargs):
        defaults = {
            "init_susceptible": 90,
            "init_infected": 10,
            "init_recovered": 0,
            "inf_rate": 2,
            "r_rate": 1,
            "max_time":10,
        }
        # Caller-supplied values take precedence over the defaults.
        self.parameters = {**defaults, **kwargs}

        groups = {}
        for label, init_key in (
            ("susceptible", "init_susceptible"),
            ("infected", "init_infected"),
            ("recovered", "init_recovered"),
        ):
            groups[label] = Group(label, self.parameters[init_key])

        # S -> I and I -> R transitions; rates are evaluated from the current state.
        transmission = Event("transmission", self.transmission_rate, {"infected": 1, "susceptible": -1})
        recovery = Event("recovery", self.recovery_rate,  {"infected": -1, "recovered": 1})

        super().__init__(groups, [transmission, recovery], self.parameters["max_time"])

    def transmission_rate(self):
        """Rate of S -> I events: ``inf_rate * I * S``."""
        infected = self.groups["infected"].inhabitants
        susceptible = self.groups["susceptible"].inhabitants
        return self.parameters["inf_rate"] * infected * susceptible

    def recovery_rate(self):
        """Rate of I -> R events: ``r_rate * I``."""
        infected = self.groups["infected"].inhabitants
        return self.parameters["r_rate"] * infected


In [None]:
# Initial group sizes.
sus_slider = IntSlider(value=90, min=0, max=1000, step=10, description="Susceptible")
inf_slider = IntSlider(value=10, min=0, max=1000, step=10, description="Infected")
rec_slider = IntSlider(value=0, min=0, max=1000, step=10, description="Recovered")

# Rate constants of the transmission and recovery events.
inf_rate_slider = FloatSlider(value=1.6, min=0, max=10, step=0.01, description="Infection rate")
rec_rate_slider = FloatSlider(value=0.3, min=0, max=10, step=0.01, description="Recovery rate")

# End time of every simulation run.
max_time_slider = FloatSlider(value=10, min=1, max=100, step=1, description="Max time")

In [None]:
def create_container_SIR():
    """Lay out the SIR sliders in three columns: initial sizes, rates, end time."""
    columns = [
        widgets.VBox([sus_slider, inf_slider, rec_slider]),
        widgets.VBox([inf_rate_slider, rec_rate_slider]),
        widgets.VBox([max_time_slider]),
    ]
    return widgets.HBox(columns)


container_SIR = create_container_SIR()

In [None]:
%matplotlib widget

def plot_SIR(**kwargs):
    
    sim = SIRSimulator(**kwargs)
    sim.run()

    plt.clf()
    sim.plot_group_levels()
    plt.tight_layout()

args = {
    'init_susceptible': sus_slider,
    'init_infected': inf_slider,
    'init_recovered': rec_slider,
    'inf_rate': inf_rate_slider,
    'rec_rate': rec_rate_slider,
    'max_time': max_time_slider,
}

out = widgets.interactive_output(plot_SIR, args)

# Display both the container containing the slider and button widgets, and the plot output
display(container_SIR, out)