In [1]:
%matplotlib widget

from dataclasses import dataclass
import numpy as np
import matplotlib.pyplot as plt
import scipy

In [None]:
# Set up SIR problem
# Labels of the state vector x, in the order unpacked by f_p1 / f_p2.
X_NAMES = ["S", "I", "R"]
# Initial state: 900 susceptible, 100 infected, 0 recovered.
X_0 = np.array([900, 100, 0])
# Population size used in the delta*N inflow term of f_p1 / f_p2
# (matches the sum of X_0).
N = 1_000

# Parameter values used to generate the synthetic "true" data (problem 1).
BETA_TRUE = 0.02   # coefficient of the beta*I*S term (S -> I)
R_TRUE = 0.6       # coefficient of the r*I term (I -> R)
DELTA_TRUE = 0.15  # delta: inflow delta*N into S, outflow delta*X from each X
# Standard deviation of the additive measurement noise.
NOISE_STD = 50


def f_p1(t, x, beta, r, delta):
    """Right-hand side dx/dt of the problem-1 SIR model.

    dS/dt = delta*N - delta*S - beta*I*S
    dI/dt = beta*I*S - (r + delta)*I
    dR/dt = r*I - delta*R

    Args:
        t: Time; unused, present only so the function fits the
            ``fun(t, x)`` shape expected by ``solve_ivp``.
        x: Current state ``(S, I, R)``.
        beta, r, delta: Model parameters.

    Returns:
        np.ndarray ``[dS/dt, dI/dt, dR/dt]``.
    """
    susceptible, infected, recovered = x
    infection = beta * infected * susceptible  # flow S -> I
    d_susceptible = delta * N - delta * susceptible - infection
    d_infected = infection - (r + delta) * infected
    d_recovered = r * infected - delta * recovered
    return np.array([d_susceptible, d_infected, d_recovered])


def f_p2(t, x, gamma, kappa, r, delta):
    """Right-hand side dx/dt of the problem-2 SIR model.

    Identical to ``f_p1`` except that the infection coefficient is the
    product ``gamma * kappa`` instead of a single ``beta``:

    dS/dt = delta*N - delta*S - gamma*kappa*I*S
    dI/dt = gamma*kappa*I*S - (r + delta)*I
    dR/dt = r*I - delta*R

    Args:
        t: Time; unused, kept for the ``fun(t, x)`` solver signature.
        x: Current state ``(S, I, R)``.
        gamma, kappa, r, delta: Model parameters.

    Returns:
        np.ndarray ``[dS/dt, dI/dt, dR/dt]``.
    """
    # (gamma*kappa)*I*S is evaluated in the same order as f_p1's beta*I*S,
    # so delegating gives bit-identical results.
    effective_beta = gamma * kappa
    return f_p1(t, x, effective_beta, r, delta)


@dataclass
class Solution:
    """Simulated trajectory on the dense evaluation grid.

    Built by the simulate functions from the ``t`` / ``y`` fields of a
    ``scipy.integrate.solve_ivp`` result.
    """

    # Time points of the trajectory.
    t: np.ndarray
    # States, one row per state variable (S, I, R) and one column per time.
    y: np.ndarray


@dataclass
class Observation:
    """Measurements extracted from a Solution by the sensor model."""

    # Observation times (a subset of the solution's time points).
    t: np.ndarray
    # Observed states, one row per state and one column per observation time.
    y: np.ndarray


# Time grid. Observations are taken on a coarse grid of Tsteps points spanning
# [T0, Tend] (spacing 0.1 for the values below); the true solution is
# evaluated on a grid Tsteps_multiplier times finer.
T0 = 0
Tend = 6
Tsteps = 61
Tsteps_multiplier = 100  # increase the fidelity of the true solution
# (Tsteps - 1) intervals, each split into Tsteps_multiplier sub-intervals,
# plus the end point. With this count, every Tsteps_multiplier-th point of
# Teval lies exactly on the coarse grid, including Tend.
# (Tsteps * Tsteps_multiplier points would give spacing 6/6099 and
# observation times that drift off the 0.1 grid and never reach Tend.)
Teval = np.linspace(T0, Tend, (Tsteps - 1) * Tsteps_multiplier + 1)


def simulate_p1(params: np.ndarray) -> Solution:
    """Integrate the problem-1 model (``f_p1``) from ``X_0`` over ``Teval``.

    Args:
        params: Sequence whose first three entries are ``(beta, r, delta)``
            as used by ``f_p1``.

    Returns:
        Solution with ``t`` = evaluation times and ``y`` = states, one row
        per state (S, I, R) and one column per time point.
    """
    beta, r, delta = params[0], params[1], params[2]

    def f(t, x):
        return f_p1(t, x, beta, r, delta)

    sol = scipy.integrate.solve_ivp(
        fun=f, t_span=(T0, Tend), y0=X_0, method='RK45', t_eval=Teval)

    return Solution(sol.t, sol.y)  # y: (3 states, number of time points)


def simulate_p2(params: np.ndarray) -> Solution:
    """Integrate the problem-2 model (``f_p2``) from ``X_0`` over ``Teval``.

    Previously this function was also named ``simulate_p1`` and silently
    replaced the problem-1 version; it now has its own name.

    Args:
        params: Sequence whose first four entries are
            ``(gamma, kappa, r, delta)`` as used by ``f_p2``.

    Returns:
        Solution with ``t`` = evaluation times and ``y`` = states, one row
        per state (S, I, R) and one column per time point.
    """
    gamma, kappa, r, delta = params[0], params[1], params[2], params[3]

    def f(t, x):
        return f_p2(t, x, gamma, kappa, r, delta)

    sol = scipy.integrate.solve_ivp(
        fun=f, t_span=(T0, Tend), y0=X_0, method='RK45', t_eval=Teval)

    return Solution(sol.t, sol.y)  # y: (3 states, number of time points)


def sensor(solution: Solution) -> Observation:
    """Sample a dense Solution down to the observation grid.

    Keeps every ``Tsteps_multiplier``-th time point (starting with the first)
    and the first three state rows.

    Args:
        solution: Dense simulated trajectory.

    Returns:
        Observation holding the sampled times and states.
    """
    stride = Tsteps_multiplier
    n_points = solution.t.shape[0]
    # Integer-array indexing returns copies, so in-place edits of the
    # observation (e.g. the noise added in generate_data) leave the
    # underlying Solution untouched.
    picks = np.arange(n_points)[::stride]
    times = solution.t[picks]
    states = solution.y[:3, picks]
    return Observation(times, states)


def generate_data(true_beta=BETA_TRUE, true_r=R_TRUE, true_delta=DELTA_TRUE,
                  noise_std=NOISE_STD, rng=None):
    """Simulate problem 1 and return the true trajectory plus noisy observations.

    Args:
        true_beta, true_r, true_delta: Parameters passed to ``simulate_p1``
            as ``[beta, r, delta]``.
        noise_std: Standard deviation of the additive Gaussian noise applied
            independently to every observed value.
        rng: Optional ``numpy.random.Generator`` used to draw the noise, e.g.
            ``np.random.default_rng(0)`` for reproducible data. When None
            (default), the legacy global ``np.random`` state is used, exactly
            as before.

    Returns:
        Tuple ``(true_sim, sensor_vals)``: the noise-free Solution and the
        Observation with noise added.
    """
    true_sim = simulate_p1([true_beta, true_r, true_delta])
    sensor_vals = sensor(true_sim)
    shape = sensor_vals.y.shape
    if rng is None:
        noise = np.random.randn(*shape)
    else:
        noise = rng.standard_normal(shape)
    # Build a new array rather than adding in place, so the noise can never
    # leak into true_sim even if sensor() returns a view of its data.
    sensor_vals.y = sensor_vals.y + noise_std * noise
    return true_sim, sensor_vals