In [1]:
from diffrax import diffeqsolve, ODETerm, Dopri5, SaveAt, LinearInterpolation, Kvaerno5
from jax import grad, jit, vmap
import jax.numpy as np
import pandas as pd
import numpy as onp

import matplotlib.pyplot as plt

In [2]:
def fomite_foi_sc(t, x, params):
    """Vector field of the S/C host model with an environmental (fomite) force of infection.

    The force of infection has a person-to-person part ``β * C / Nw`` and an
    environmental part ``ε``. ``ε`` is fed by C at rate ``θ / Nw`` and decays at rate ``κ``.

    Args:
        t: Current time. Unused (the system is autonomous), but required by the
            ``ODETerm`` signature ``f(t, y, args)``.
        x: State ``(S, C, ε)``. S and C are the two host compartments (presumably
            susceptible and colonized — confirm). ε is the environmental contribution
            to the force of infection.
        params: Mapping with keys "γ", "β", "σ", "δ", "α", "Nw", "θ", "κ".

    Returns:
        ``np.array([dS/dt, dC/dt, dε/dt])``.
    """
    S, C, ε = x

    # person-to-person (p2p) parameters
    γ  = params["γ"]   # fraction of the σ*Nw inflow that enters C instead of S
    β  = params["β"]   # p2p transmission rate
    σ  = params["σ"]   # per-capita inflow rate (total inflow is σ * Nw)
    δ  = params["δ"]   # per-capita outflow rate, applied to both S and C
    α  = params["α"]   # rate at which C returns to S
    Nw = params["Nw"]  # normalising population size

    # fomite (environment) parameters
    θ = params["θ"]    # rate at which C feeds ε
    κ = params["κ"]    # decay rate of ε

    # Total force of infection: p2p term plus environmental term.
    λ = β * C / Nw + ε

    # The inflow σ*Nw is split (1-γ) into S and γ into C. When σ == δ this keeps
    # S + C relaxing towards Nw.
    sdot = (1 - γ) * σ * Nw - λ * S - δ * S + α * C
    cdot = γ * σ * Nw + λ * S - δ * C - α * C
    εdot = θ * C / Nw - κ * ε

    return np.array([sdot, cdot, εdot])


In [3]:
# Parameter grid for the simulation sweep: one row per (γ, κ, β, θ) combination.
gammas = [5/100, 15/100, 25/100, 50/100, 75/100, 80/100]
betas  = [0.01, 0.025, 0.05, 0.1]
tetas  = [1e-3, 1.5e-3, 1e-2]
kappas = [1/(1 * 30),  1/(3*30), 1/(6*30), 1/(12*30), 1/(24*30), 1/(36*30)]

# Build all records first and create the DataFrame once. Concatenating inside the loop
# is quadratic and yields object-dtype columns. The nesting order (γ, then κ, then β,
# then θ innermost) fixes the row order and hence the sim_id numbering.
param_records = [
    {"γ": g, "β": b, "θ": θ, "κ": κ}
    for g in gammas
    for κ in kappas
    for b in betas
    for θ in tetas
]
params_sim_df = pd.DataFrame(param_records, columns=["γ", "β", "θ", "κ"])
params_sim_df["sim_id"] = range(1, len(params_sim_df) + 1)  # 1-based integer run id


In [4]:
import diffrax

num_years = 4.5
t_max     = int((365 / 7) * num_years)  # simulate 4.5 years; the model's time unit is one week
tsim      = np.arange(0, t_max + 1)     # weeks at which the solution is saved
num_sims  = len(params_sim_df.sim_id.unique())

# Solver configuration does not depend on the parameter set, so build it once.
saveat              = SaveAt(ts=tsim)
solver              = Dopri5()
stepsize_controller = diffrax.PIDController(rtol=1e-8, atol=1e-8)
# One term shared by every run; each run's parameters reach the vector field through
# `args`. A fresh closure per run would hand diffrax a new function object each time
# and defeat its compilation cache.
term                = ODETerm(lambda t, x, args: fomite_foi_sc(t, x, args))

run_frames = []  # one DataFrame per parameter set, concatenated once after the loop
for _, row in params_sim_df.iterrows():

    # The ×7 factors convert the per-day rates to the weekly time step.
    params = {
        "β":  row.β * 7,
        "γ":  row.γ,
        "σ":  1/3 * 7,
        "δ":  1/3 * 7,
        "α":  1/120 * 7,
        "Nw": 2000,
        "θ" : row.θ * 7,
        "κ" : row.κ * 7,
    }
    # diffeqsolve is wrapped in equinox.filter_jit, which traces arrays but holds other
    # leaves (such as Python floats) static. Passing the parameters as arrays therefore
    # lets every run reuse one compilation.
    args = {name: np.asarray(value) for name, value in params.items()}

    c0 = int(row.γ * params["σ"] * params["Nw"])
    ε0 = 1e-4  # alternative considered: row.θ/row.κ *1/(params["α"] + params["δ"])
    y0 = np.array([1800, c0, ε0])  # initial (S, C, ε); S0 = 1800 is hard-coded

    solution = diffeqsolve(term, solver,
                           t0=0, t1=t_max,
                           dt0=1, y0=y0, args=args,
                           saveat=saveat,
                           stepsize_controller=stepsize_controller)

    s, c, ε = solution.ys.T
    # Parameter columns hold the float64 values from params_sim_df (not float32 copies),
    # so later selections by γ/κ/β/θ compare against the exact grid values.
    run_frames.append(pd.DataFrame({
        "time": onp.asarray(solution.ts),
        "S":    onp.asarray(s),
        "C":    onp.asarray(c),
        "ε":    onp.asarray(ε),
        "γ":    row.γ,
        "κ":    row.κ,
        "β":    row.β,
        "θ":    row.θ,
    }))

solution_df = pd.concat(run_frames, ignore_index=True)


In [None]:
import sys

sys.path.insert(0, "../")
from utils_local import plot_utils
import seaborn as sns


In [None]:
# S(t) for the first γ: one panel per (θ, κ) pair; rows are θ, columns are κ, and
# colour encodes β.
g       = gammas[0]
fig, ax = plt.subplots(len(tetas), len(kappas), figsize=(14.5, 9.2), sharex=True, sharey=False)

# Select runs with a tolerance instead of `==` inside DataFrame.query. Exact float
# equality breaks as soon as the stored values differ from the grid by a rounding error
# (e.g. float32 storage).
at_gamma = onp.isclose(solution_df["γ"], g)

for idx_k, k in enumerate(kappas):
    # round() rather than int(): 1/k can land a hair below the whole number of days,
    # and int() would truncate it to the previous month.
    ax[0, idx_k].set_title(r"$1/\kappa=${} month".format(round(1 / k / 30)))

    for idx_teta, teta in enumerate(tetas):
        panel_mask = (at_gamma
                      & onp.isclose(solution_df["κ"], k)
                      & onp.isclose(solution_df["θ"], teta))
        sns.lineplot(ax=ax[idx_teta, idx_k], x="time", y="S",
                     data=solution_df[panel_mask], hue="β", palette="Reds")

# Label each row with its θ value so the grid can be read without the source.
for idx_teta, teta in enumerate(tetas):
    ax[idx_teta, 0].set_ylabel(f"S  (θ={teta:g})")

for axi in ax.flatten():
    axi.spines['right'].set_visible(False)
    axi.spines['top'].set_visible(False)
    # Only remove an existing legend; calling axi.legend() would first create one.
    legend = axi.get_legend()
    if legend is not None:
        legend.remove()
