In [1]:
from jax import jit, vmap, random
import jax.numpy as np
import numpy as onp
import jax

# enable float 64
from jax.config import config
config.update('jax_enable_x64', True)

In [3]:
from utils_response import create_df_response
import pandas as pd
import datetime


In [4]:
# https://github.com/google/jax/issues/10144
def binomial_transition(xi, τ, dt=1):
    kb    = np.maximum(1.0 - np.exp(-τ*dt), 0)
    pop   = onp.random.binomial(onp.int64(xi), kb )
    return jax.device_put(pop)

def checkpopvars(x, pop):
    return np.clip(x, 0, pop)

def f(t, x, β, γ, N, dt=1):
    """
    Process model
    """
    S = x.at[0, :].get()
    I = x.at[1, :].get()
    R = x.at[2, :].get()
    C = x.at[3, :].get()


    s2i = binomial_transition(S, β * I / N, dt)
    i2r = binomial_transition(I, γ, dt)

    S  = S - s2i
    I  = I + s2i - i2r
    R  = R + i2r
    C  = s2i

    return np.array([S, I, R, C])

def g(t, x, θ):
    """
    Observational model.
    """
    return x.at[3, :].get()

def f0(pop=1e+6, inf_init=1/100, m =300):
    """
    Initial guess of the state space.
    """
    I0 = pop * 0.01
    S0 = pop - inf_init
    R0 = 0
    C0 = 0

    x0  = np.ones((4, m)) * np.expand_dims(np.array([S0, I0, R0, C0]), -1)

    return x0

In [5]:
N = 1000

β_truth = 0.9 # days
γ_truth = 1/7 # days


T    = 70
dt   = 1
ens  = 500
x0    = f0(N, 1/10, m=ens)
x_sim = x0

x_sim = np.full((4, T, ens), np.nan)
x_sim = x_sim.at[:, 0, :].set(x0)

for t in range(1, T):
    x     = f(t, x_sim.at[:, t-1, :].get(), β_truth, γ_truth, N)
    x_sim = x_sim.at[:, t, :].set(x)


C =  x_sim.at[3, :, :].get()

k = jax.random.PRNGKey(1)

obs_use   = C.at[:, onp.random.randint(ens)].get()
obs_use_n = np.maximum(obs_use + np.squeeze(jax.random.normal(k, shape=(T, 1))*0.5), 0)


In [6]:
observation_df          = pd.DataFrame(obs_use_n, columns=['y1'])
observation_df["oev1"]  = 1 +( 0.2 * observation_df["y1"])**2
observation_df["date"] = pd.date_range(start=datetime.datetime(2020, 1, 1), periods=T, freq='D')
observation_df

Unnamed: 0,y1,oev1,date
0,0.000000,1.000000,2020-01-01
1,13.273525,8.047459,2020-01-02
2,16.528501,11.927653,2020-01-03
3,32.555398,43.394156,2020-01-04
4,49.474932,98.910756,2020-01-05
...,...,...,...
65,0.000000,1.000000,2020-03-06
66,0.000000,1.000000,2020-03-07
67,0.000000,1.000000,2020-03-08
68,0.000000,1.000000,2020-03-09


In [8]:
import jax.numpy as np
import numpy as onp
import jax

from tqdm import tqdm

from utils_probability import sample_uniform, sample_normal
from eakf import check_param_space, check_state_space, eakf

model_settings ={
    "m": 300,
    "p": 2,
    "k": 1,
    "n": 4,
    "param_name": ["β", "γ"],
    "dates": observation_df["date"].values
    }

if_settings = {
   "Nif"                : 100,
   "type_cooling"       : "geometric",
   "shrinkage_factor"   : 0.9,
   "assimilation_dates" : observation_df["date"].values
}


In [9]:
from ifeakf import cooling


cooling_sequence   = cooling(if_settings["Nif"], type_cool=if_settings["type_cooling"], cooling_factor=if_settings["shrinkage_factor"])

k           = model_settings["k"] # Number of observations
p           = model_settings["p"] # Number of parameters (to be estimated)
n           = model_settings["n"] # Number of state variable
m           = model_settings["m"] # Number of stochastic trajectories / particles / ensembles

sim_dates   = model_settings["dates"]
assim_dates = if_settings["assimilation_dates"]