In [1]:
from diffrax import diffeqsolve, ODETerm, Dopri5, SaveAt, LinearInterpolation, Kvaerno5, PIDController
from jax import grad, jit, vmap
import jax.numpy as np
import pandas as pd
import numpy as onp
import sys
import os

import matplotlib.pyplot as plt

sys.path.insert(0, "../")
from global_config import config

results_dir  = config.get_property('results_dir')
paper_dir    = config.get_property('paper_dir')
data_dir     = config.get_property('data_dir')

figures_dir  = os.path.join(results_dir, 'figures')


In [2]:
def f(t, x, params_infer, params):
    S, C, ε = x

    # parameters to infer
    β = params_infer["β"]
    θ = params_infer["θ"]

    # p2p
    γ  = params["γ"]
    σ  = params["σ"]
    δ  = params["δ"]
    α  = params["α"]
    Nw = params["Nw"]

    # fom
    κ = params["κ"]

    λ1 = β * C / Nw
    λ2 = ε

    λ  = λ1 + λ2

    sdot = (1-γ) * σ * Nw - λ * S - δ * S + α * C
    cdot = γ * σ * Nw     + λ * S - δ * C - α * C
    εdot = θ * C / Nw     - κ * ε

    return np.array([sdot, cdot, εdot])

def f0(params_infer, params):
    c0 = params["γ"] * params["Nw"]
    ε0 = params_infer["θ"] * c0 / params["Nw"]

    return np.array([params["Nw"] - c0, c0, ε0])

def g(x, params_infer, params):
    """ Expected value of a binomial observational model.

    Args:
        x: state space vector, [S, C, ε]
        params: dictionary with 'fixed' parameters.

    Returns:
        Detected positives: ~ Binomial(Test, \rho * C/N), expected value is (Test * \rho * C/N)
    """
    S, C, _ = x

    c = C / (S + C)

    ρ = params["ρ"]
    T = params["T"]

    return T * (ρ * c)

def numFIM(params_infer, param_fixed, δ = {"β": 1e-3, "θ": 1e-4}):
    """ Numerical approximation of the sensitivity matrix.
            Fisher Information Matrix is computed as: Xi^T * Xi

    Args:
        times  (_type_): _description_
        params (_type_): _description_
        data   (_type_): _description_
        delta  (float, optional): _description_. Defaults to 0.001.

    Returns:
        _type_: _description_
    """

    listX    = []
    params_1 = params_infer.copy()
    params_2 = params_infer.copy()

    for k in ["β", "θ"]:
        params_1[k] = params_infer[k] * (1+δ[k])
        params_2[k] = params_infer[k] * (1-δ[k])

        num_years = 4
        t_max     = int((365) * num_years )  # simulate for 5 yrs but weekly
        t_save    = np.arange(0, t_max+1, 7) # list of days to simulate

        saveat = SaveAt(ts=t_save)

        model1  = lambda t, x, args: f(t, x, params_1, param_fixed)
        model2  = lambda t, x, args: f(t, x, params_2, param_fixed)

        term1   = ODETerm(model1)
        term2   = ODETerm(model2)

        solver        = Dopri5()
        stepsize_cont = PIDController(rtol=1e-8, atol=1e-8)

        y0_1 = f0(params_1, param_fixed)
        y0_2 = f0(params_2, param_fixed)

        sim1 = diffeqsolve(term1, solver,
                                        t0=0, t1=t_max,
                                        saveat=saveat, dt0=7, y0=y0_1,
                                        stepsize_controller = stepsize_cont)

        sim2 = diffeqsolve(term2, solver,
                                        t0=0, t1=t_max,
                                        saveat=saveat, dt0=7, y0=y0_2,
                                        stepsize_controller = stepsize_cont)

        # observe
        sim_data1 = g(sim1.ys.T, params_1, param_fixed)
        sim_data2 = g(sim2.ys.T, params_2, param_fixed)

        subX  = (sim_data1 - sim_data2) / (2 * δ[k] * params_infer[k])
        listX.append(subX.tolist())

    X   = onp.matrix(listX)
    FIM = onp.dot(X, X.transpose())
    return FIM


In [3]:
gammas  = [5/100, 25/100, 50/100, 75/100]
kappas  = [1/14, 1/(1*30), 1/(6*30), 1/(12*30)]
betas   = np.linspace(0, 0.1,  20)
tetas   = np.linspace(0, 1e-2, 20)

for idx_g, gamma in enumerate(gammas):
    for idx_k, kappa in enumerate(kappas):

        FIM_arr = onp.full((2, 2, len(betas), len(tetas)), np.nan)

        print(r"Running for $\kappa=${}".format( int(1/(kappa * 30))) )

        # fixed parameters
        params = {
                "γ":  gamma,
                "κ" : kappa,
                "σ":  1/3,
                "δ":  1/3,
                "α":  1/120,
                "Nw": 2000,
                "T" : 0.15 * 2000,
                "ρ" : 5/100
                }

        for idx_b, beta in enumerate(betas):
            for idx_t, teta in enumerate(tetas):

                params_infer = {
                        "β": beta,
                        "θ": teta}

                try:
                    FIM_arr[:, :, idx_b, idx_t] = numFIM(params_infer, params, δ = {"β": 1e-3, "θ": 1e-4})
                except:
                    print(r"Error with $\beta$ = {}, $\teta$ = {}".format(beta, teta))

        onp.savez_compressed(os.path.join(results_dir, f"{str(idx_g)}_gamma_{idx_k}_kappa_fim.npz"),
                                        FIM       = FIM_arr,
                                        gamma     = gamma,
                                        kappa     = kappa,
                                        params    = params)


Running for $\kappa=$0


In [None]:
eig_vals_arr = onp.full((2, len(gammas), len(kappas), len(betas), len(tetas)), np.nan)
det_arr      = onp.full((len(gammas), len(kappas), len(betas), len(tetas)), np.nan)

for idx_g, gamma in enumerate(gammas):
    for idx_k, kappa in enumerate(kappas):
        for idx_b, beta in enumerate(betas):
            for idx_t, teta in enumerate(tetas):

                try:
                    det_arr[idx_g, idx_k, idx_b, idx_t]         = onp.linalg.det(FIM_arr[:, :, idx_g, idx_k, idx_b, idx_t])
                    eig_vals_arr[:, idx_g, idx_k, idx_b, idx_t] = onp.linalg.eigvals(FIM_arr[:, :, idx_g, idx_k, idx_b, idx_t])
                except:
                    det_arr[idx_g, idx_k, idx_b, idx_t]         = np.nan
                    eig_vals_arr[:, idx_g, idx_k, idx_b, idx_t] = np.nan

In [None]:
import sys
sys.path.insert(0, '../')

from utils_local import plot_utils
import seaborn as sns

palette     = sns.color_palette("flare", as_cmap=True)

idx_g = 0

fig, axes = plt.subplots(1, len(kappas), figsize=(15.2, 5.5), sharey=True, sharex=True)

for idx_k in range(len(kappas)):

    ax_hm      = axes[idx_k]
    det_fim_df = pd.DataFrame(det_arr[idx_g, idx_k, :, :], index=betas, columns=tetas)

    sns.heatmap(ax = ax_hm, data=onp.log10(det_fim_df.iloc[::-1]), cmap=palette,  cbar=True)

    xticks         = det_fim_df.columns
    keptxticksidx  = np.linspace(0,len(xticks), 6)
    xtickslabels   = list(xticks[ np.maximum(keptxticksidx.astype(int)-1,0) ])
    xtickslabels   = ['{:.3f}'.format(l) for l in xtickslabels]

    yticks         = det_fim_df.iloc[::-1].index
    keptyticksidx  = np.linspace(0,len(yticks), 6)
    ytickslabels   = list(yticks[ np.maximum(keptyticksidx.astype(int)-1,0) ])
    ytickslabels   = ['{:.2f}'.format(l) for l in ytickslabels]

    ax_hm.set_xticks(keptxticksidx)
    ax_hm.set_yticks(keptyticksidx)

    ax_hm.set_xticklabels(xtickslabels, rotation=0)
    ax_hm.set_yticklabels(ytickslabels)

    ax_hm.tick_params( which='both', axis='x', labelrotation=90)
    ax_hm.tick_params( which='both', axis='both')

fig.supxlabel(r"Environmental transmission rate, $\theta$", y=-0.09)
fig.supylabel(r"Nosocomial transmission rate, $\beta$", x=-0.00001)


In [None]:

fig, ax = plt.subplots(1, len(kappas), figsize=(12.2, 5.5), sharex=True)

for idx_k in range(len(kappas)):
    det = det_arr[idx_g, idx_k, :, :]
    for idx_t in range(len(tetas)):
        ax[idx_k].plot(betas, np.log10(det[:, idx_t]), color=sns.dark_palette("#b285bc",  len(tetas))[idx_t])

fig.supxlabel(r"Nosocomial transmission rate, $\beta$")

ax[0].set_ylabel(r"$\log_{10}(det($FIM$))$")
plt.tight_layout()


In [None]:
fig, axes = plt.subplots(2, len(kappas), figsize=(14.2, 7.5), sharey=True, sharex=True)

for idx_k in range(len(kappas)):

    λ1_fim_df = pd.DataFrame(eig_vals_arr[0, idx_g, idx_k, :, :], index=betas, columns=tetas)
    λ2_fim_df = pd.DataFrame(eig_vals_arr[1, idx_g, idx_k, :, :], index=betas, columns=tetas)

    sns.heatmap(ax = axes[0, idx_k], data=λ1_fim_df.iloc[::-1], cmap=palette,  cbar=True)
    sns.heatmap(ax = axes[1, idx_k], data=onp.log10(λ2_fim_df.iloc[::-1]), cmap=palette,  cbar=True)

for ax in axes.flatten():

    xticks         = det_fim_df.columns
    keptxticksidx  = np.linspace(0,len(xticks),6)
    xtickslabels   = list(xticks[ np.maximum(keptxticksidx.astype(int)-1,0) ])
    xtickslabels   = ['{:.3f}'.format(l) for l in xtickslabels]

    yticks         = det_fim_df.iloc[::-1].index
    keptyticksidx  = np.linspace(0,len(yticks),6)
    ytickslabels   = list(yticks[ np.maximum(keptyticksidx.astype(int)-1,0) ])
    ytickslabels   = ['{:.2f}'.format(l) for l in ytickslabels]

    ax.set_xticks(keptxticksidx)
    ax.set_yticks(keptyticksidx)

    ax.set_xticklabels(xtickslabels, rotation=0)
    ax.set_yticklabels(ytickslabels)

    ax.tick_params( which='both', axis='x', labelrotation=90)
    ax.tick_params( which='both', axis='both')

fig.supxlabel(r"Environmental transmission rate, $\theta$", y=-0.05)
fig.supylabel(r"Nosocomial transmission rate, $\beta$")

plt.tight_layout()