In [1]:
from diffrax import diffeqsolve, ODETerm, Dopri5, SaveAt, LinearInterpolation, Kvaerno5
from jax import grad, jit, vmap
import jax.numpy as np
import pandas as pd
import numpy as onp

import matplotlib.pyplot as plt

In [None]:
def f(t, x, params_infer, params):
    """Right-hand side of the S / C / ε ODE system.

    Args:
        t: time; not used by the dynamics (the system is autonomous), kept so
            the signature matches an ODE-solver callback.
        x: state vector ``[S, C, ε]``.
        params_infer: dict of parameters being inferred; must provide
            ``"β"`` and ``"θ"``.
        params: dict of fixed parameters; must provide ``"γ"``, ``"σ"``,
            ``"δ"``, ``"Nw"`` and ``"κ"``. ``"α"`` is read when present;
            otherwise ``"δ"`` is used for α (legacy behaviour).

    Returns:
        ``np.array([dS/dt, dC/dt, dε/dt])`` (``np`` is ``jax.numpy`` here).
    """
    S, C, ε = x

    # parameters to infer
    β = params_infer["β"]
    θ = params_infer["θ"]

    # p2p
    γ  = params["γ"]
    σ  = params["σ"]
    δ  = params["δ"]
    # α previously read params["δ"] (copy-paste slip), so α was always equal
    # to δ. Use the dedicated "α" entry; fall back to "δ" for callers whose
    # params dict predates that key.
    α  = params["α"] if "α" in params else params["δ"]
    Nw = params["Nw"]

    # fom
    κ = params["κ"]

    # Total rate λ applied to S: a term proportional to C / Nw plus the
    # state variable ε itself.
    λ1 = β * C / Nw
    λ2 = ε
    λ = λ1 + λ2

    # Inflow σ·Nw is split into fractions (1-γ) -> S and γ -> C; both
    # compartments are removed at rate δ; α moves C back to S.
    sdot = (1 - γ) * σ * Nw - λ * S - δ * S + α * C
    cdot = γ * σ * Nw       + λ * S - δ * C - α * C
    # ε is produced in proportion to C / Nw and decays at rate κ.
    εdot = θ * C / Nw       - κ * ε

    return np.array([sdot, cdot, εdot])

def f0(params_infer, params):
    """Initial state ``[S0, C0, ε0]`` for the ODE system in ``f``.

    ``C0 = γ · Nw``, ``S0 = Nw - C0`` and ``ε0 = θ · C0 / Nw``.

    Args:
        params_infer: dict of inferred parameters; ``"θ"`` is read from here
            when present (``f`` treats θ as an inferred parameter).
        params: dict of fixed parameters; must provide ``"γ"`` and ``"Nw"``,
            and ``"θ"`` if ``params_infer`` does not contain it.

    Returns:
        ``np.array([S0, C0, ε0])`` (``np`` is ``jax.numpy`` here).
    """
    Nw = params["Nw"]
    c0 = params["γ"] * Nw

    # θ is inferred (f reads it from params_infer); take it from there so the
    # initial ε tracks the value being fitted. Fall back to params for callers
    # that still store θ among the fixed parameters.
    θ = params_infer["θ"] if "θ" in params_infer else params["θ"]
    ε0 = θ * c0 / Nw

    return np.array([Nw - c0, c0, ε0])

def g(x, params_infer, params):
    """Expected detected positives under a binomial observation model.

    Detections are modelled as Binomial(T, ρ · C / (S + C)); this returns the
    mean of that distribution, ``T · ρ · C / (S + C)``.

    Args:
        x: state vector ``[S, C, ε]``; ε is not used here.
        params_infer: not used; kept so the signature parallels ``f``/``f0``.
        params: dict providing ``"ρ"`` and ``"T"``.

    Returns:
        Expected number of detected positives, ``T * (ρ * C / (S + C))``.
    """
    s_val, c_val, _eps = x

    # Fraction of the S + C population that is in C.
    frac_c = c_val / (s_val + c_val)

    rho = params["ρ"]
    trials = params["T"]
    return trials * (rho * frac_c)


def _observe_with_sir_ode(times, p, data):
    """Simulate the legacy ``sir_ode`` model for parameters ``p``; return observations.

    NOTE(review): ``ode`` and ``sir_ode`` are not defined in this notebook
    section; they look like leftovers from an earlier model API. Pass
    ``observe=`` to ``numFIM`` to avoid depending on them.
    """
    res = ode(sir_ode.model, sir_ode.x0fcn(p, data), times, args=(p,))
    return sir_ode.yfcn(res, p)


def numFIM(times, params, data, delta=0.001, *, observe=None):
    """Approximate the Fisher Information Matrix with central finite differences.

    Each parameter is perturbed on its own by a relative amount
    ``±delta`` while all others stay at their base values. The sensitivity of
    the (flattened) observations to parameter ``i`` is
    ``(y(p_i·(1+delta)) - y(p_i·(1-delta))) / (2·delta·p_i)``. Stacking these
    as columns of Xi gives ``FIM = Xi^T · Xi``.

    Args:
        times: observation times, forwarded unchanged to ``observe``.
        params: 1-D sequence of parameter values.
        data: forwarded unchanged to ``observe``.
        delta: relative perturbation size. Defaults to 0.001.
        observe: optional callable ``observe(times, p, data)`` returning the
            observations for parameter vector ``p``. Defaults to the legacy
            ``sir_ode`` simulation path.

    Returns:
        ``numpy.ndarray`` of shape ``(n_params, n_params)``.

    Raises:
        ValueError: if a parameter (or ``delta``) is zero, so the relative
            step vanishes and the difference quotient is undefined.
    """
    if observe is None:
        observe = _observe_with_sir_ode

    # Real numpy: the perturbed copies are mutated in place, which jax arrays
    # (the ``np`` alias in this file) do not allow.
    base = onp.asarray(params, dtype=float)

    rows = []
    for i, value in enumerate(base):
        step = delta * value
        if step == 0:
            raise ValueError(
                f"numFIM: relative step for parameter {i} is zero "
                f"(value={value}, delta={delta})"
            )

        # Fresh copies each iteration so only parameter i is perturbed.
        p_up = base.copy()
        p_up[i] = value * (1 + delta)
        p_down = base.copy()
        p_down[i] = value * (1 - delta)

        y_up = onp.asarray(observe(times, p_up, data), dtype=float)
        y_down = onp.asarray(observe(times, p_down, data), dtype=float)
        rows.append(((y_up - y_down) / (2 * step)).ravel())

    if not rows:
        return onp.zeros((0, 0))

    # sens has shape (n_params, n_obs), i.e. Xi^T; FIM = Xi^T Xi.
    sens = onp.vstack(rows)
    return sens @ sens.T

