In [None]:
from jax import jit, vmap, random
import jax.numpy as np
import numpy as onp
import jax

# Enable 64-bit floats. `config` is taken from the top-level `jax` package:
# the `jax.config` submodule import path was deprecated and later removed,
# while `from jax import config` works on both older and newer releases.
from jax import config
config.update('jax_enable_x64', True)

In [None]:
from utils_response import create_df_response
import pandas as pd
import datetime


In [None]:
# https://github.com/google/jax/issues/10144
def binomial_transition(xi, τ, dt=1):
    """
    Draw how many of `xi` individuals leave a compartment during one step.

    Each individual transitions with probability 1 - exp(-τ * dt), floored
    at 0. The draw comes from NumPy's global random state
    (`onp.random.binomial`), not from a JAX key, and the result is handed
    back to JAX with `jax.device_put`.
    """
    prob_leave = 1.0 - np.exp(-τ * dt)
    prob_leave = np.maximum(prob_leave, 0)

    # NumPy's sampler needs integer trial counts.
    n_trials = onp.int64(xi)
    draws = onp.random.binomial(n_trials, prob_leave)

    return jax.device_put(draws)

def checkpopvars(x, pop):
    """Clamp `x` element-wise to the closed interval [0, pop]."""
    lower_bound = 0
    clipped = np.clip(x, lower_bound, pop)
    return clipped

def f(t, x, β, γ, N, dt=1):
    """
    Process model: advance the compartments by one stochastic step.

    Rows 0, 1 and 2 of `x` are read as S, I and R. Two binomial draws give
    the S->I and I->R flows; the returned array stacks the updated S, I, R
    and, in row 3, the S->I flow of this step. `t` is accepted but not used.
    """
    susceptible, infected, recovered = x[0, :], x[1, :], x[2, :]

    # Draw order matters for the random stream: infections first, then recoveries.
    new_infections = binomial_transition(susceptible, β * infected / N, dt)
    new_recoveries = binomial_transition(infected, γ, dt)

    return np.array([
        susceptible - new_infections,
        infected + new_infections - new_recoveries,
        recovered + new_recoveries,
        new_infections,
    ])

def g(t, x, θ):
    """
    Observational model.

    Returns row 3 of the state `x`. `t` and `θ` are accepted but not used.
    """
    observed = x[3, :]
    return observed

def f0(pop=1e+6, inf_init=1/100, m=300):
    """
    Initial guess of the state space.

    Parameters
    ----------
    pop : total population size.
    inf_init : fraction of the population that starts infected.
    m : number of ensemble members (columns of the result).

    Returns
    -------
    Array of shape (4, m) whose rows are S0, I0, R0 and C0, each value
    repeated across the m columns.
    """
    # Previously hard-coded as pop * 0.01, which silently ignored `inf_init`.
    I0 = pop * inf_init
    # Subtract the infected head count (not the fraction) so S0 + I0 == pop.
    S0 = pop - I0
    R0 = 0
    C0 = 0

    initial_column = np.expand_dims(np.array([S0, I0, R0, C0]), -1)
    x0 = np.ones((4, m)) * initial_column

    return x0

In [None]:
N = 1000  # population size used by the process model

β_truth = 0.9  # value of β passed to f for the synthetic truth
γ_truth = 1/7  # value of γ passed to f for the synthetic truth


T    = 70   # number of simulated time steps
dt   = 1    # step length forwarded to f
ens  = 500  # number of simulated trajectories

# binomial_transition and the trajectory pick below draw from NumPy's global
# RNG; seed it so the synthetic observations are the same on every run.
onp.random.seed(1)

x0 = f0(N, 1/10, m=ens)

# x_sim is indexed [state, time, trajectory]; NaN marks steps not simulated yet.
x_sim = np.full((4, T, ens), np.nan)
x_sim = x_sim.at[:, 0, :].set(x0)

for t in range(1, T):
    x     = f(t, x_sim[:, t-1, :], β_truth, γ_truth, N, dt)
    x_sim = x_sim.at[:, t, :].set(x)


# Row 3 of the state (the per-step S->I flow written by f), shape (T, ens).
C = x_sim[3, :, :]

k = jax.random.PRNGKey(1)

# Use one randomly chosen trajectory as the "truth", then add Gaussian noise
# (standard deviation 0.5) and floor the result at 0.
obs_use   = C[:, onp.random.randint(ens)]
obs_use_n = np.maximum(obs_use + np.squeeze(jax.random.normal(k, shape=(T, 1))*0.5), 0)


In [None]:
# Observation table: the noisy series (y1), the quantity later passed to eakf
# as `oev` (oev1), and one calendar date per time step.
observation_df = pd.DataFrame(obs_use_n, columns=['y1']).assign(
    oev1=lambda frame: 1 + (0.2 * frame["y1"])**2,
    date=pd.date_range(start=datetime.datetime(2020, 1, 1), periods=T, freq='D'),
)
observation_df

In [None]:
import jax.numpy as np
import numpy as onp
import jax

from tqdm import tqdm

from utils_probability import sample_uniform, sample_normal
from eakf import check_param_space, check_state_space, eakf

# NOTE: `model_settings` and `if_settings` are defined again, with the same
# values, in the next cell.
model_settings = dict(
    m=300,                                   # ensemble members
    p=2,                                     # parameters to be estimated
    k=1,                                     # observations
    n=4,                                     # state variables
    param_name=["β", "γ"],
    dates=observation_df["date"].values,
)

if_settings = dict(
    Nif=100,                                 # iterated-filtering passes
    type_cooling="geometric",
    shrinkage_factor=0.9,
    assimilation_dates=observation_df["date"].values,
)

# Placeholder; a later cell assigns the actual perturbation from the parameter ranges.
perturbation = None

In [None]:
from ifeakf import cooling

model_settings = dict(
    m=300,                                   # ensemble members
    p=2,                                     # parameters to be estimated
    k=1,                                     # observations
    n=4,                                     # state variables
    param_name=["β", "γ"],
    dates=observation_df["date"].values,
)

if_settings = dict(
    Nif=100,                                 # iterated-filtering passes
    type_cooling="geometric",
    shrinkage_factor=0.9,
    assimilation_dates=observation_df["date"].values,
)

# The iterated filter calls both the process model and the observational model
# as (t, x, θ), although in practice the time t and the parameters θ are not
# used by the observational model. The wrappers below adapt f and g to that
# common signature.

def f_if(t, x, θ):
    """Process model with the (t, x, θ) signature; row 0 of θ is β, row 1 is γ."""
    return f(t, x, θ.at[0,:].get(), θ.at[1, :].get(), N)


def g_if(t, x, θ):
    """Observational model with the (t, x, θ) signature."""
    return g(t, x, θ)


βmin, βmax = 0.2, 1.5
γmin, γmax = 1/20, 1/4

# One [min, max] row per state variable / per parameter.
state_space_range = np.array([[0, N]] * 4)
parameters_range  = np.array([[βmin, βmax], [γmin, γmax]])

observations_df = observation_df.set_index("date")

In [None]:
# Cooling schedule, indexed by iterated-filtering pass.
cooling_sequence = cooling(
    if_settings["Nif"],
    type_cool=if_settings["type_cooling"],
    cooling_factor=if_settings["shrinkage_factor"],
)

# k: number of observations, p: number of parameters (to be estimated),
# n: number of state variables, m: number of trajectories / particles / ensembles.
k, p, n, m = (model_settings[dim] for dim in ("k", "p", "n", "m"))

sim_dates, assim_dates = model_settings["dates"], if_settings["assimilation_dates"]

param_range = parameters_range.copy()
std_param   = param_range[:, 1] - param_range[:, 0]   # width of each parameter range

SIG          = std_param ** 2 / 4   # initial covariance of the parameters
perturbation = std_param ** 2 / 4


assimilation_times = len(observations_df)

# NaN-filled containers: per-member posteriors for every pass and assimilation
# time, and one ensemble-mean column per pass (plus the initial one).
θpost = np.full((p, m, if_settings["Nif"], assimilation_times), np.nan)
θmean = np.full((p, if_settings["Nif"] + 1), np.nan)

key     = jax.random.PRNGKey(0)
keys_if = jax.random.split(key, if_settings["Nif"])

In [9]:
from utils_probability import sample_uniform, sample_normal, truncated_normal
from eakf import check_param_space, check_state_space, eakf
from ifeakf import random_walk_perturbation

# Iterated filtering: every pass runs the ensemble filter over the whole
# observation series, starting from a parameter ensemble drawn around the mean
# of the previous pass.
# NOTE: the loop index `n` overwrites the module-level `n` (number of state
# variables); the name is kept because code after the loop reads it.
for n in tqdm(range(if_settings["Nif"])):

    # A JAX key must not be reused: the same key always yields the same noise.
    # Split this pass's key into one key for the initial parameter draw and a
    # carry key that is split again at every assimilation step.
    key_n, key_init = jax.random.split(keys_if[n])

    if n == 0:
        θ     = sample_uniform(key_init, param_range[:,0], param_range[:,1], p, m)
        θmean = θmean.at[:, n].set(np.mean(θ, -1))
    else:
        pmean = θmean.at[:,n].get()
        pvar  = SIG * (if_settings["shrinkage_factor"]**n)**2
        θ     = truncated_normal(key_init, pmean, pvar,  param_range.at[:,0].get(), param_range.at[:,1].get(), p, m)

    # Start from the same population N and ensemble size m that the process
    # model and state_space_range use (the bare f0() defaulted to pop=1e6).
    x = f0(pop=N, m=m)

    t_assim = 0
    ycum    = np.zeros((k, m))

    for t, date in enumerate(sim_dates):
        x    = f_if(t, x, θ)
        y    = g_if(t, x, θ)
        ycum += y   # accumulate simulated observations since the last assimilation

        # Guard the index so simulation dates past the last assimilation date
        # do not raise IndexError.
        if t_assim < len(assim_dates) and date == assim_dates[t_assim]:
            date_infer = assim_dates[t_assim]

            # Fresh, independent keys for this step's two random operations.
            key_n, key_walk, key_check = jax.random.split(key_n, 3)

            σp = perturbation*cooling_sequence.at[n].get()
            θ  = random_walk_perturbation(key_walk, θ, σp, p, m)

            # Measured observations
            z     = observations_df.loc[date_infer][[f"y{i+1}" for i in range(k)]].values
            oev   = observations_df.loc[date_infer][[f"oev{i+1}" for i in range(k)]].values

            # Update state space
            x, y = eakf(x, ycum, z, oev)
            x    = check_state_space(x, state_space_range)

            # Update parameter space
            θ, y = eakf(θ, ycum, z, oev)
            θ    = check_param_space(key_check, θ, param_range)

            θpost = θpost.at[:, :, n, t_assim].set(θ)

            ycum     = np.zeros((k, m))
            t_assim  += 1

    # Average this pass's posterior over assimilation times (last axis) and
    # then over ensemble members; it becomes the mean used by the next pass.
    θtime = θpost.at[:, :, n, :].get()
    θmean = θmean.at[:,n+1].set(θtime.mean(-1).mean(-1))


KeyboardInterrupt: 

In [None]:

def sample_normal(key, θ_min, θ_max, μ, cov, p, m=300):
    """
    Draw an ensemble from a truncated normal distribution.

    Requests samples of shape (m, p) from `truncated_multivariate_normal`
    with bounds `θ_min` / `θ_max` and returns their transpose.

    NOTE(review): `truncated_multivariate_normal` is not defined or imported
    in the visible notebook, and this definition shadows the `sample_normal`
    imported from `utils_probability` - confirm both are intended.
    """
    draws = truncated_multivariate_normal(
        key, μ, cov, shape=(m, p), lower=θ_min, upper=θ_max
    )
    return draws.T

In [None]:
def truncated_normal(key, mean, sd, lower, upper, p, m, dtype=np.float64):
    """
    Sample from a normal distribution truncated to [lower, upper].

    Parameters
    ----------
    key : JAX PRNG key.
    mean : location of the untruncated normal, broadcastable against (m, p).
    sd : spread argument. NOTE(review): as in the original body, the square
        root of this value is used as the scale, i.e. it is treated as a
        variance (the call in this notebook passes `pvar`) - confirm the
        intended meaning of the name.
    lower, upper : truncation bounds on the same scale as `mean`.
    p, m : number of parameters and of ensemble members.
    dtype : floating dtype of the draws.

    Returns
    -------
    Array of shape (m, p) with every sample inside [lower, upper].
    """
    # Use the arguments; the previous body read the globals `pvar` / `pmean`.
    scale = np.sqrt(sd)

    # jax.random.truncated_normal truncates a *standard* normal, so the bounds
    # must be standardized before the draws are rescaled and shifted.
    lower_std = (lower - mean) / scale
    upper_std = (upper - mean) / scale

    z = jax.random.truncated_normal(key, lower=lower_std, upper=upper_std, shape=(m, p), dtype=dtype)
    return z * scale + mean

# Ad-hoc check of the function defined above. It relies on `n`, `pmean` and
# `pvar` left over from the iterated-filtering loop cell, so that cell must
# have run (past its first pass) before this line is executed.
truncated_normal(keys_if[n], pmean, pvar,  param_range.at[:,0].get(), param_range.at[:,1].get(), p, m)