In [None]:
import numpy as np
import scipy
import matplotlib.pyplot as plt
from scipy.integrate import simpson
from scipy.optimize import least_squares
import pandas as pd

def SIR(y, t, beta, gamma):
    """
    Right-hand side of the SIR model ODEs.

    :param y: current state, a sequence of exactly three values (S, I, R)
    :param t: time; not used by the equations themselves, but kept in the
        signature so the function has the (y, t, *args) shape
    :param beta: infection rate coefficient
    :param gamma: recovery rate coefficient
    :return: the derivatives [dS/dt, dI/dt, dR/dt] as a list
    """
    susceptible, infected, _ = y
    # Flow from S to I and flow from I to R; R itself does not feed back.
    new_infections = beta * susceptible * infected
    recoveries = gamma * infected
    return [-new_infections, new_infections - recoveries, recoveries]

# Initial state: one infected individual, 9999 susceptible, nobody recovered.
S0, I0, R0 = 9999, 1, 0
y = S0, I0, R0

# Ground-truth parameters used to generate the synthetic trajectory.
beta, gamma = 0.00001, 0.01

# Time grid: 1000 evenly spaced points covering [0, 500].
t = np.linspace(0, 500, num=1000)

# Integrate the SIR equations over the grid, then split the solution
# columns into one trajectory per compartment.
solution = scipy.integrate.odeint(SIR, list(y), t, args=(beta, gamma))
S, I, R = np.transpose(solution)

def residual(paras, I_data, I_int, I_int2, int_double, I0):
    """
    Residual between observed infections and the integral-form prediction.

    The prediction is
    ``I0 + (alpha + beta*I0 - gamma) * I_int - beta * I_int2
    - beta * gamma * int_double``.

    :param paras: the three fitted parameters, unpacked as (beta, gamma, alpha)
    :param I_data: observed values of I
    :param I_int: term multiplied by (alpha + beta*I0 - gamma)
    :param I_int2: term multiplied by beta
    :param int_double: term multiplied by beta * gamma
    :param I0: initial value of I
    :return: I_data minus the predicted I
    """
    infection_rate, recovery_rate, alpha_rate = paras
    linear_coeff = (alpha_rate + infection_rate * I0) - recovery_rate
    prediction = (
        I0
        + linear_coeff * I_int
        - infection_rate * I_int2
        - infection_rate * recovery_rate * int_double
    )
    return I_data - prediction

# Starting point handed to the optimizer: rough guesses for beta, gamma
# and S0. The third fitted parameter is alpha = beta * S0, so its starting
# value is derived from the beta and S0 guesses.
S0_initial = 100
beta0, gamma0 = 0.01, 0.05
alpha_initial = S0_initial * beta0
x0 = [beta0, gamma0, alpha_initial]

def _running_simpson(values, times):
    """Return Simpson integrals of ``values`` over every prefix of ``times``."""
    # `x` is passed by keyword.
    # NOTE(review): some SciPy releases deprecate passing it positionally;
    # the keyword form is accepted either way.
    return np.array([
        simpson(values[:end], x=times[:end])
        for end in range(1, len(times) + 1)
    ])


def _plot_estimate_traces(traces, name, true_value):
    """Plot one optimizer trace per run for a parameter, plus its true value."""
    for run_index, trace in enumerate(traces):
        plt.plot(trace, label=f"{name}_{run_index}", marker='o')
    plt.axhline(y=true_value, label=f"True {name}", linestyle='-')
    plt.legend()
    plt.show()


def evaluate(method, num_of_iteration, bounds=None, plot=True):
    """
    Fit (beta, gamma, alpha) to noisy copies of the simulated ``I`` curve.

    Each run adds zero-mean Gaussian noise (standard deviation 5% of ``I``)
    to the module-level ``I``, builds the integral terms on the module-level
    grid ``t``, and calls ``least_squares`` on ``residual`` starting from the
    module-level ``x0``. S0 is recovered as alpha / beta. The averaged
    estimates and their relative errors (in percent) against the module-level
    ``beta``, ``gamma`` and ``S0`` are printed, followed by the solver
    diagnostics of the last run.

    :param method: ``least_squares`` method name. For ``'lm'`` neither a
        callback nor bounds are passed to the solver.
    :param num_of_iteration: number of independent noisy fits; must be >= 1
    :param bounds: optional bounds forwarded to ``least_squares`` for
        methods other than ``'lm'``
    :param plot: when true and ``method`` is not ``'lm'``, plot the
        per-iteration optimizer traces of beta, gamma and S0 for every run
    :raises ValueError: if ``num_of_iteration`` is smaller than 1
    :return: None
    """
    if num_of_iteration < 1:
        # Without at least one fit there is nothing to average or report.
        raise ValueError("num_of_iteration must be at least 1")

    print(f"--- Method : {method} ---")
    print(f"bounds {bounds}")
    if method == 'lm' and bounds:
        # The header above echoes the bounds, so say explicitly that this
        # branch does not apply them.
        print("note: bounds are not passed to the solver for method 'lm'")

    estimated_beta_list = []
    estimated_gamma_list = []
    estimated_S0_list = []

    plot_beta = []
    plot_gamma = []
    plot_S0 = []
    for _ in range(num_of_iteration):
        I_data = I + np.random.normal(0, 0.05 * I, size=I.shape)

        I_int = _running_simpson(I_data, t)
        I_int2 = _running_simpson(I_data ** 2, t)
        # Iterated integral of I times its running integral equals half the
        # squared running integral.
        int_double = 1/2 * (I_int ** 2)

        beta_list = []
        gamma_list = []
        S0_list = []

        def callback(x):
            # Record the optimizer's current parameter vector; S0 = alpha / beta.
            beta_list.append(x[0])
            gamma_list.append(x[1])
            S0_list.append(x[2] / x[0])

        solver_kwargs = {"method": method}
        if method != 'lm':
            solver_kwargs["callback"] = callback
            if bounds:
                solver_kwargs["bounds"] = bounds

        res = least_squares(
            residual,
            x0,
            args=(I_data, I_int, I_int2, int_double, I0),
            **solver_kwargs,
        )

        estimated_beta_list.append(res.x[0])
        estimated_gamma_list.append(res.x[1])
        estimated_S0_list.append(res.x[2] / res.x[0])

        plot_beta.append(beta_list)
        plot_gamma.append(gamma_list)
        plot_S0.append(S0_list)

    estimated_beta = np.mean(estimated_beta_list)
    estimated_gamma = np.mean(estimated_gamma_list)
    estimated_S0 = np.mean(estimated_S0_list)
    print(f"Average estimated beta among {num_of_iteration}: {estimated_beta} --- real beta {beta}"
      f"\nAverage estimated gamma among {num_of_iteration}: {estimated_gamma} --- real gamma {gamma}"
      f"\nAverage estimated S0 among {num_of_iteration}: {estimated_S0} -- real S0 {S0}"
          f"\nBeta error {abs((estimated_beta - beta) / beta) * 100}"
          f"\nGamma error {abs((estimated_gamma - gamma) / gamma) * 100}"
          f"\nS0 error {abs((estimated_S0 - S0) / S0) * 100}")

    # Solver diagnostics below describe the LAST run only, not an aggregate.
    print("nfev:", res.nfev)
    print("njev:", res.njev)
    print("status:", res.status)
    print("message:", res.message)

    # Traces are only collected through the callback, which the 'lm' branch
    # does not install, so there is nothing to plot for 'lm'.
    if plot and method != 'lm':
        _plot_estimate_traces(plot_beta, "beta", beta)
        _plot_estimate_traces(plot_gamma, "gamma", gamma)
        _plot_estimate_traces(plot_S0, "S0", S0)


In [None]:
# 20 noisy fits with the 'trf' method, no bounds, traces plotted.
evaluate(method='trf', num_of_iteration=20)

In [None]:
# 20 noisy fits with the 'dogbox' method, no bounds, traces plotted.
evaluate(method='dogbox', num_of_iteration=20)

In [None]:
# 10 noisy fits with the 'lm' method; evaluate() draws no plots for 'lm'.
evaluate(method='lm', num_of_iteration=10)

In [None]:
# 20 noisy fits with 'trf', all three parameters constrained to [0, inf).
evaluate(
    method='trf',
    num_of_iteration=20,
    bounds=([0, 0, 0], [np.inf] * 3),
)

In [None]:
# 20 noisy fits with 'dogbox', all three parameters constrained to [0, inf).
evaluate(
    method='dogbox',
    num_of_iteration=20,
    bounds=([0, 0, 0], [np.inf] * 3),
)

In [None]:
# 10 noisy fits with 'lm'. NOTE: evaluate() does not forward `bounds` to the
# solver when the method is 'lm', so this run is effectively unconstrained
# even though the bounds are echoed in the printed header.
evaluate(
    method='lm',
    num_of_iteration=10,
    bounds=([0, 0, 0], [np.inf] * 3),
)

In [None]:
# Side-by-side comparison without plots, 50 noisy fits each: 'trf' and
# 'dogbox' are run unbounded and then with non-negative bounds; 'lm' is run
# unbounded only.
nonnegative_bounds = ([0, 0, 0], [np.inf, np.inf, np.inf])
for solver_name in ('trf', 'dogbox'):
    evaluate(solver_name, 50, plot=False)
    evaluate(solver_name, 50, bounds=nonnegative_bounds, plot=False)
evaluate('lm', 50, plot=False)