In [1]:
from diffrax import diffeqsolve, ODETerm, Dopri5, SaveAt, LinearInterpolation, Kvaerno5, PIDController
from jax import grad, jit, vmap
import jax.numpy as np
import pandas as pd
import numpy as onp

import matplotlib.pyplot as plt

In [2]:
def f(t, x, params_infer, params):
    """Right-hand side of the S / C / ε ODE system.

    Args:
        t: time. Unused (the system has no explicit time dependence) but kept
            so the signature matches the ODE right-hand-side convention.
        x: state vector [S, C, ε].
        params_infer: dict with the inferred parameters "β" and "θ".
        params: dict with the fixed parameters "γ", "σ", "δ", "α", "Nw", "κ".

    Returns:
        jax array [dS/dt, dC/dt, dε/dt].
    """
    S, C, ε = x

    # parameters to infer
    β = params_infer["β"]
    θ = params_infer["θ"]

    # p2p
    γ  = params["γ"]
    σ  = params["σ"]
    δ  = params["δ"]
    α  = params["α"]  # was params["δ"]: α must come from its own key
    Nw = params["Nw"]

    # fom
    κ = params["κ"]

    # total force: β * C / Nw from the C compartment plus the ε state
    λ1 = β * C / Nw
    λ2 = ε

    λ = λ1 + λ2

    sdot = (1-γ) * σ * Nw - λ * S - δ * S + α * C
    cdot = γ * σ * Nw     + λ * S - δ * C - α * C
    εdot = θ * C / Nw     - κ * ε

    return np.array([sdot, cdot, εdot])

def f0(params_infer, params):
    """Build the initial state [S0, C0, ε0] for the ODE system.

    A fraction γ of the Nw individuals starts in C and the remainder in S.
    ε starts at θ * C0 / Nw.

    Args:
        params_infer: dict providing "θ".
        params: dict providing "γ" and "Nw".

    Returns:
        jax array [S0, C0, ε0].
    """
    gamma_frac = params["γ"]
    total = params["Nw"]

    c_init = gamma_frac * total
    eps_init = params_infer["θ"] * c_init / total
    s_init = total - c_init

    return np.array([s_init, c_init, eps_init])

def g(x, params_infer, params):
    """Expected number of detected positives under a binomial observation model.

    Detected positives ~ Binomial(T, ρ * C / (S + C)), so the expected value
    is T * ρ * C / (S + C).

    Args:
        x: state [S, C, ε]. The first axis must have exactly 3 entries; each
            entry may be a scalar or an array (numFIM passes ``sim.ys.T``).
        params_infer: unused; accepted for signature symmetry with f and f0.
        params: dict providing "ρ" and "T".

    Returns:
        T * (ρ * C / (S + C)), with the same shape as each state entry.
    """
    susceptible, carriers, _unused = x

    rho = params["ρ"]
    n_tests = params["T"]

    prevalence = carriers / (susceptible + carriers)
    return n_tests * (rho * prevalence)

def _simulate_observations(p_infer, param_fixed, solver, saveat, stepsize_controller, t_max):
    """Integrate f from f0 over [0, t_max] and return g at the saveat times."""
    term = ODETerm(lambda t, x, args: f(t, x, p_infer, param_fixed))
    sim = diffeqsolve(term, solver,
                      t0=0, t1=t_max,
                      saveat=saveat, dt0=7, y0=f0(p_infer, param_fixed),
                      stepsize_controller=stepsize_controller)
    # sim.ys is transposed so that g can unpack the three state components
    return g(sim.ys.T, p_infer, param_fixed)


def numFIM(params_infer, param_fixed, δ=None):
    """Numerically approximate the Fisher Information Matrix for β and θ.

    For each parameter k, only p_k is perturbed relative to its value,
    p_k * (1 ± δ[k]); every other parameter keeps its value from
    ``params_infer``. The model is integrated for both perturbations on a
    weekly grid over 4 years and passed through the observation model ``g``.
    A central difference then gives the sensitivity row

        X_k(t) = (g(p_k+) - g(p_k-)) / (2 * δ[k] * p_k)

    and the FIM is computed as X @ X^T.

    Args:
        params_infer: dict with the inferred parameters "β" and "θ".
        param_fixed: dict of fixed parameters, forwarded to f, f0 and g.
        δ: dict mapping "β" and "θ" to relative step sizes.
            Defaults to {"β": 1e-3, "θ": 1e-4}.

    Returns:
        numpy.matrix of shape (2, 2), rows/columns ordered (β, θ).

    Raises:
        ValueError: if a parameter is exactly zero, because a relative step
            would then be zero.
    """
    if δ is None:
        δ = {"β": 1e-3, "θ": 1e-4}

    # Loop-invariant integration setup: 4 years sampled weekly.
    num_years = 4
    t_max     = int(365 * num_years)
    t_save    = np.arange(0, t_max + 1, 7)  # days at which the solution is saved
    saveat        = SaveAt(ts=t_save)
    solver        = Dopri5()
    stepsize_cont = PIDController(rtol=1e-8, atol=1e-8)

    listX = []
    for k in ["β", "θ"]:
        if params_infer[k] == 0:
            raise ValueError(f"Parameter {k!r} is zero; relative step size is undefined.")

        # Fresh copies each iteration so that ONLY parameter k is perturbed.
        # The previous code reused the same copies across iterations, which
        # left β perturbed while differentiating with respect to θ.
        params_plus  = {**params_infer, k: params_infer[k] * (1 + δ[k])}
        params_minus = {**params_infer, k: params_infer[k] * (1 - δ[k])}

        obs_plus  = _simulate_observations(params_plus, param_fixed, solver,
                                           saveat, stepsize_cont, t_max)
        obs_minus = _simulate_observations(params_minus, param_fixed, solver,
                                           saveat, stepsize_cont, t_max)

        subX = (obs_plus - obs_minus) / (2 * δ[k] * params_infer[k])
        listX.append(subX.tolist())

    # Rows of X are parameters, columns are time points, so X @ X^T is 2x2.
    X   = onp.matrix(listX)
    FIM = onp.dot(X, X.transpose())
    return FIM


In [7]:
gammas  = [5/100]
kappas  = [1/(1 * 30), 1/(6*30), 1/(12*30), 1/(24*30)]
betas   = np.linspace(0.01, 0.1, 10)
tetas   = np.linspace(1e-3, 1e-2, 10)

# Axes: (FIM row, FIM col, γ index, κ index, β index, θ index).
# Entries stay NaN when numFIM fails for that grid point.
FIM_arr = onp.full((2, 2, len(gammas), len(kappas), len(betas), len(tetas)), np.nan)

for idx_g, gamma in enumerate(gammas):
    for idx_k, kappa in enumerate(kappas):

        # κ = 1/(m*30), so 1/(κ*30) recovers m. round() instead of int():
        # floating point gives e.g. 5.999999999999999 for m=6, and int()
        # would truncate that to 5.
        print(r"Running for $\kappa=${}".format( round(1/(kappa * 30))) )

        # fixed parameters
        params = {
                "γ":  gamma,
                "σ":  1/3 ,
                "δ":  1/3 ,
                "α":  1/120,
                "Nw": 2000,
                "κ" : kappa,
                "T" : 0.15 * 2000,
                "ρ" : 0.1
                }

        for idx_b, beta in enumerate(betas):
            for idx_t, teta in enumerate(tetas):
                params_infer = {
                        "β": beta,
                        "θ": teta
                        }
                try:
                    FIM_arr[:, :, idx_g, idx_k, idx_b, idx_t] = numFIM(params_infer, params, δ = {"β": 1e-3, "θ": 1e-4})
                except Exception as err:
                    # Best-effort scan: keep going, but report the actual cause
                    # (a bare except also swallowed KeyboardInterrupt).
                    print(r"Error with $\beta$ = {}, $\teta$ = {}".format(beta, teta))
                    print("    {}: {}".format(type(err).__name__, err))

Running for $\kappa=$1
