<a href="https://colab.research.google.com/github/DKH707/B-ODE-DM/blob/main/covid19_Cook_v3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

- https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1009472
- https://arxiv.org/abs/2302.09125
- https://github.com/stefanradev93/BayesFlow
- https://github.com/bayesflow-org/JANA-Paper
- https://github.com/stefanradev93/AIAgainstCorona/tree/main
- https://github.com/CSSEGISandData/COVID-19/tree/master/csse_covid_19_data/csse_covid_19_time_series

In [None]:
# IPython shell escapes: install BayesFlow from its GitHub repository, then
# upgrade/install the remaining packages this notebook imports.
! pip install git+https://github.com/stefanradev93/BayesFlow
! pip install -U ipython-autotime numpy pandas tensorflow wbgapi
# Restart the kernel (presumably so the freshly installed versions are the ones imported below).
get_ipython().kernel.do_shutdown(True)

In [None]:
%reload_ext autotime
import os, warnings, datetime, pathlib, shutil, google.colab, dataclasses, pickle, wbgapi
import numpy as np, pandas as pd, tensorflow as tf
import matplotlib.pyplot as plt, seaborn as sns
from functools import partial
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, GRU, LSTM
from bayesflow.networks import InvertibleNetwork, SequentialNetwork
from bayesflow.coupling_networks import CouplingLayer
from bayesflow.simulation import GenerativeModel, Prior, Simulator
from bayesflow.amortizers import AmortizedLikelihood, AmortizedPosterior, AmortizedPosteriorLikelihood
from bayesflow.trainers import Trainer
from bayesflow import default_settings
from bayesflow.helper_functions import build_meta_dict
import bayesflow.diagnostics as diag
from bayesflow.computational_utilities import maximum_mean_discrepancy
# Silence the warning whose message matches the date-format text below.
warnings.filterwarnings("ignore", message="Could not infer format, so each element will be parsed individually")
# Reduce TensorFlow's C++ log output.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
# NOTE(review): the preamble is a raw (non-f) string, so the doubled braces are literal;
# it is unused while "text.usetex" is False -- confirm before enabling usetex.
plt.rcParams.update({"text.usetex": False, "font.family": "serif", "text.latex.preamble": r"\usepackage{{amsmath}}"})
# Google Drive mount point; COVID.__post_init__ stores all outputs beneath it.
mnt = '/content/drive'
google.colab.drive.mount(mnt)
# Module-wide random generator (fixed seed) used by the prior and the simulator.
RNG = np.random.default_rng(42)
# Tolerance used by COVID.check_params for the strict-inequality constraints.
EPS = 1e-6

class MultiConvLayer(tf.keras.Model):
    """Inception-style block of parallel causal 1D convolutions (kernel sizes 2-7).

    The branch outputs are concatenated along the channel axis and projected
    to ``n_filters`` channels by a kernel-size-1 convolution.
    """

    def __init__(self, n_filters=32, strides=1):
        super().__init__()

        # One branch per kernel size; every branch emits n_filters // 2 channels.
        branches = []
        for kernel_size in range(2, 8):
            branches.append(
                tf.keras.layers.Conv1D(
                    filters=n_filters // 2,
                    kernel_size=kernel_size,
                    strides=strides,
                    padding="causal",
                    activation="relu",
                    kernel_initializer="glorot_uniform",
                )
            )
        self.convs = branches

        # Pointwise convolution that reduces the concatenated branches.
        self.dim_red = tf.keras.layers.Conv1D(
            filters=n_filters,
            kernel_size=1,
            strides=1,
            activation="relu",
            kernel_initializer="glorot_uniform",
        )

    def call(self, x):
        """Apply all branches to ``x`` (a time series: batch, timesteps, n_features) and reduce."""

        branch_outputs = [branch(x) for branch in self.convs]
        stacked = tf.concat(branch_outputs, axis=-1)
        return self.dim_red(stacked)


class MultiConvNet(tf.keras.Model):
    """Stack of ``MultiConvLayer`` blocks followed by an LSTM with ``n_filters`` units."""

    def __init__(self, n_layers=3, n_filters=64, strides=1):
        super().__init__()

        blocks = [MultiConvLayer(n_filters, strides) for _ in range(n_layers)]
        self.net = tf.keras.Sequential(blocks)
        self.lstm = LSTM(n_filters)

    def call(self, x, **args):
        """Run the conv stack, then the LSTM, on ``x`` (batch, timesteps, n_features).

        Extra keyword arguments are accepted and ignored.
        """

        features = self.net(x)
        return self.lstm(features)


class SummaryNet(tf.keras.Model):
    """Summary network with one ``MultiConvNet`` per third of the feature axis."""

    def __init__(self, n_summary):
        super().__init__()
        # Each of the three branches gets an equal share of the summary width.
        branch_width = n_summary // 3
        self.net_I = MultiConvNet(n_filters=branch_width)
        self.net_R = MultiConvNet(n_filters=branch_width)
        self.net_D = MultiConvNet(n_filters=branch_width)

    @tf.function
    def call(self, x, **args):
        """Split ``x`` into three equal parts on the last axis, summarise each, concatenate.

        Extra keyword arguments are accepted and ignored.
        """

        part_i, part_r, part_d = tf.split(x, 3, axis=-1)
        summary_i = self.net_I(part_i)
        summary_r = self.net_R(part_r)
        summary_d = self.net_D(part_d)
        return tf.concat([summary_i, summary_r, summary_d], axis=-1)

class MemoryNetwork(tf.keras.Model):
    """GRU that encodes the history of a target sequence together with a condition.

    ``meta`` must provide ``"n_hidden"`` (GRU width) and ``"n_params"`` (stored
    on the instance, not otherwise used by this class).
    """

    def __init__(self, meta):
        super().__init__()

        self.gru = GRU(meta["n_hidden"], return_sequences=True, return_state=True)
        self.h = meta["n_hidden"]
        self.n_params = meta["n_params"]

    @tf.function
    def call(self, target, condition):
        """Performs a teacher-forced forward pass through the network.

        Params:
        -------
        target    : tf.Tensor of shape (batch_size, time_steps, dim)
            The time-dependent signal to process.
        condition : tf.Tensor of shape (batch_size, time_steps, cond_dim)
            The conditional variables, e.g., parameters. It is concatenated
            with the shifted target on the last axis, so it must have the same
            batch and time dimensions as ``target``.

        Returns:
        --------
        tf.Tensor
            The GRU output sequence (one output per time step).
        """
        # Shift the target one step to the right so step t only sees targets < t.
        shift_target = target[:, :-1, :]
        # zeros_like on a one-step slice avoids reading the static batch size,
        # which is None when the batch dimension is dynamic inside tf.function.
        init = tf.zeros_like(target[:, :1, :])
        inp_teacher = tf.concat([init, shift_target], axis=1)
        inp_teacher_c = tf.concat([inp_teacher, condition], axis=-1)
        out, _ = self.gru(inp_teacher_c)
        return out

    def step_loop(self, target, condition, state):
        """Advance the GRU on ``[target, condition]`` starting from ``state``.

        Returns the GRU output sequence and the new hidden state.
        """
        out, new_state = self.gru(
            tf.concat([target, condition], axis=-1), initial_state=state
        )
        return out, new_state

class InvertibleNetworkWithMemory(tf.keras.Model):
    """Chain of conditional invertible blocks whose condition is augmented by a recurrent memory.

    A ``MemoryNetwork`` encodes the history of the targets; its output is
    concatenated with the user-supplied condition before every coupling layer.
    """

    def __init__(
        self,
        num_params,
        num_coupling_layers=4,
        coupling_settings=None,
        coupling_design="affine",
        permutation="fixed",
        use_act_norm=True,
        act_norm_init=None,
        use_soft_flow=False,
        soft_flow_bounds=(1e-3, 5e-2),
    ):
        """Initializes a custom invertible network with recurrent memory.

        ``use_soft_flow`` and ``soft_flow_bounds`` are stored on the instance;
        the forward and inverse passes in this class do not read them.
        """

        super().__init__()

        # Create settings dict for coupling layer
        settings = dict(
            latent_dim=num_params,
            coupling_settings=coupling_settings,
            coupling_design=coupling_design,
            permutation=permutation,
            use_act_norm=use_act_norm,
            act_norm_init=act_norm_init,
        )

        # Create sequence of coupling layers and store reference to dimensionality
        self.coupling_layers = [
            CouplingLayer(**settings) for _ in range(num_coupling_layers)
        ]

        # Store attributes
        self.soft_flow = use_soft_flow
        self.soft_low = soft_flow_bounds[0]
        self.soft_high = soft_flow_bounds[1]
        self.use_act_norm = use_act_norm
        self.latent_dim = num_params
        # The memory network's recorded dimensionality follows num_params
        # instead of a hard-coded 3, so it stays consistent with latent_dim.
        self.dynamic_summary_net = MemoryNetwork(
            {"n_hidden": 256, "n_params": num_params}
        )

    def call(self, targets, condition, inverse=False):
        """Performs one pass through an invertible chain (either inverse or forward).

        Parameters
        ----------
        targets   : tf.Tensor
            The estimation quantities of interest, shape (batch_size, time_steps, dim)
        condition : tf.Tensor
            The conditional data; it is sliced per time step in the inverse
            pass, so it needs shape (batch_size, time_steps, cond_dim)
        inverse   : bool, default: False
            Flag indicating whether to run the chain forward or backwards

        Returns
        -------
        (z, log_det_J)  :  tuple(tf.Tensor, tf.Tensor)
            If inverse=False: the transformed input and the summed log
            Jacobian determinants of all coupling layers

        target          :  tf.Tensor
            If inverse=True: the transformed output

        Important
        ---------
        If ``inverse=False``, the return is ``(z, log_det_J)``.\n
        If ``inverse=True``, the return is ``target``.
        """

        if inverse:
            return self.inverse(targets, condition)
        return self.forward(targets, condition)

    @tf.function
    def forward(self, targets, condition, **kwargs):
        """Performs a forward pass through the chain."""

        # Add memory condition
        memory = self.dynamic_summary_net(targets, condition)
        condition = tf.concat([memory, condition], axis=-1)

        z = targets
        log_det_Js = []
        for layer in self.coupling_layers:
            z, log_det_J = layer(z, condition, **kwargs)
            log_det_Js.append(log_det_J)
        # Sum Jacobian determinants for all layers (coupling blocks) to obtain total Jacobian.
        log_det_J = tf.add_n(log_det_Js)
        return z, log_det_J

    @tf.function
    def inverse(self, z, condition, **kwargs):
        """Performs a reverse pass through the chain, one time step at a time.

        Each generated step is fed back into the memory network as the input
        for the next step, starting from an all-zero input and state.
        """

        target = z
        T = z.shape[1]
        gru_inp = tf.zeros((z.shape[0], 1, z.shape[-1]))
        state = tf.zeros((z.shape[0], self.dynamic_summary_net.h))
        outs = []
        for t in range(T):
            # One step condition
            memory, state = self.dynamic_summary_net.step_loop(
                gru_inp, condition[:, t : t + 1, :], state
            )
            condition_t = tf.concat([memory, condition[:, t : t + 1, :]], axis=-1)
            target_t = target[:, t : t + 1, :]
            for layer in reversed(self.coupling_layers):
                target_t = layer(target_t, condition_t, inverse=True, **kwargs)
            outs.append(target_t)
            # Autoregressive feedback: the step just produced drives the next GRU step.
            gru_inp = target_t
        return tf.concat(outs, axis=1)


@dataclasses.dataclass
class COVID():
    """Compartmental epidemic simulator plus BayesFlow training/inference pipeline.

    Construction (``__post_init__``) creates the output directories, builds the
    prior/simulator/generative model, loads or computes calibration statistics
    and sets up the BayesFlow trainer.
    """
    # country: str = 'Germany'
    # Name of the run; used as the sub-directory under the Drive root for all outputs.
    name: str = 'covid_000'
    # Number of observed time steps; used as t_5 in the prior and as the length of the real-data window.
    n_steps: int = 100
    # Number of generative-model draws used by calibrate().
    n_calibrate: int = 5000
    # If True, __post_init__ deletes this run's whole output directory first.
    refresh: bool = False

    def read_or_create(self, file, fun=None, refresh=False):
        """Read or create pickle for simulation results"""
        try:
            if refresh:
                file.unlink(missing_ok=True)
            with open(file, "rb") as f:
                sims = pickle.load(f)
            print(f'{file} successfully read')
        except Exception as e:
            print(f'Running sims: {e}')
            with open(file, "wb") as f:
                sims = fun()
                pickle.dump(sims, f)
        return sims


    def check_params(self, p):
        """Round the integer-valued parameters in place and validate the parameter set.

        Returns ``p`` (the same, mutated dict) when every constraint holds,
        otherwise ``None``.
        """
        integer_keys = (
            'E_0', 'sim_diff',
            't_1', 't_2', 't_3', 't_4', 't_5',
            'delta_1', 'delta_2', 'delta_3', 'delta_4',
            'lag_I', 'lag_R', 'lag_D',
        )
        for name in integer_keys:
            p[name] = int(round(p[name]))
        # At least one initially exposed individual.
        p['E_0'] = max(p['E_0'], 1)

        # Every constraint is evaluated up front, then combined.
        # All values except the phi_* phases must be strictly positive.
        positive = all(val > EPS for key, val in p.items() if key[:3] != 'phi')
        alpha_below_one = p['alpha'] < 1 - EPS
        delta_below_one = p['delta'] < 1 - EPS
        # The burn-in must be longer than every reporting lag.
        burn_in_covers_lags = p['sim_diff'] > max(p['lag_I'], p['lag_R'], p['lag_D'])
        # Each transition must finish before the next change point starts.
        transitions_fit = [
            p[f't_{i}'] + p[f'delta_{i}'] <= p[f't_{i+1}'] for i in range(1, 5)
        ]

        if (positive and alpha_below_one and delta_below_one
                and burn_in_covers_lags and all(transitions_fit)):
            return p
        return None

    def prior_fun(self):
        """Draw one parameter set from the prior, retrying until ``check_params`` accepts it.

        Returns
        -------
        dict
            Parameter name -> value, with the integer-valued entries already
            rounded by ``check_params``.
        """
        # Shape parameters of the Beta prior used for A_I / A_R / A_D. beta_f is
        # chosen so that the Beta mean alpha_f / (alpha_f + beta_f) equals 0.7.
        # NOTE(review): 0.17 looks like the intended standard deviation -- confirm.
        alpha_f = (0.7**2) * ((1 - 0.7) / (0.17**2) - (1 - 0.7))
        beta_f = alpha_f * (1 / 0.7 - 1)
        # Rejection sampling: redraw until the constraints in check_params hold.
        while True:
            # The draw order below fixes how the shared RNG stream is consumed.
            # N, sim_diff and t_5 are not random.
            # NOTE(review): 'lambda' (without index) is drawn here but not read by
            # calc_lambda_array or the update equations in sir -- confirm it is needed.
            p = self.check_params({
                'N'       :86e6,
                'E_0'     :RNG.gamma(shape=2, scale=30),
                'alpha'   :RNG.uniform(low=0.005, high=0.99),
                'beta'    :RNG.lognormal(mean=np.log(0.25), sigma=0.3),
                'gamma'   :RNG.lognormal(mean=np.log(1/6.5), sigma=0.5),
                'delta'   :RNG.uniform(low=0.01, high=0.3),
                'epsilon' :RNG.uniform(low=1/14, high=1/3),
                'eta'     :RNG.lognormal(mean=np.log(1/3.2), sigma=0.5),
                'lambda'  :RNG.lognormal(mean=np.log(1.2), sigma=0.5),
                'mu'      :RNG.lognormal(mean=np.log(1/8), sigma=0.2),
                'theta'   :RNG.uniform(low=1/14, high=1/3),
                'sim_diff':16,
                't_1'     :RNG.normal(loc=8, scale=3),
                't_2'     :RNG.normal(loc=15, scale=3),
                't_3'     :RNG.normal(loc=22, scale=3),
                't_4'     :RNG.normal(loc=66, scale=3),
                't_5'     :self.n_steps,
                'delta_1' :RNG.lognormal(mean=np.log(3), sigma=0.3),
                'delta_2' :RNG.lognormal(mean=np.log(3), sigma=0.3),
                'delta_3' :RNG.lognormal(mean=np.log(3), sigma=0.3),
                'delta_4' :RNG.lognormal(mean=np.log(3), sigma=0.3),
                'lambda_0':RNG.lognormal(mean=np.log(1.20), sigma=0.5),
                'lambda_1':RNG.lognormal(mean=np.log(0.60), sigma=0.5),
                'lambda_2':RNG.lognormal(mean=np.log(0.30), sigma=0.5),
                'lambda_3':RNG.lognormal(mean=np.log(0.10), sigma=0.5),
                # 'lambda_4':RNG.lognormal(mean=np.log(0.10), sigma=0.5),
                'lambda_4':RNG.lognormal(mean=np.log(0.15), sigma=0.5),
                'A_I'     :RNG.beta(a=alpha_f, b=beta_f),
                'A_R'     :RNG.beta(a=alpha_f, b=beta_f),
                'A_D'     :RNG.beta(a=alpha_f, b=beta_f),
                'phi_I'   :RNG.vonmises(mu=0, kappa=0.01),
                'phi_R'   :RNG.vonmises(mu=0, kappa=0.01),
                'phi_D'   :RNG.vonmises(mu=0, kappa=0.01),
                'lag_I'   :RNG.lognormal(mean=np.log(8), sigma=0.2),
                'lag_R'   :RNG.lognormal(mean=np.log(8), sigma=0.2),
                'lag_D'   :RNG.lognormal(mean=np.log(8), sigma=0.2),
                'sigma_I' :RNG.gamma(shape=1, scale=5),
                'sigma_R' :RNG.gamma(shape=1, scale=5),
                'sigma_D' :RNG.gamma(shape=1, scale=5),
            })
            # check_params returns None for a rejected draw.
            if p:
                return p

    def calc_lambda_array(self, p):
        """Computes the array of time-varying contact rates/transimission probabilities."""
        # Array of initial lambdas
        lambd0_arr = np.array([p['lambda_0']] * (p['t_1'] + p['sim_diff'] - 1))

        # Compute lambd1 array
        if p['delta_1'] == 1:
            lambd1_arr = np.array([p['lambda_1']] * (p['t_2'] - p['t_1']))
        else:
            lambd1_arr = np.linspace(p['lambda_0'], p['lambda_1'], p['delta_1'])
            lambd1_arr = np.append(lambd1_arr, [p['lambda_1']] * (p['t_2'] - p['t_1'] - p['delta_1']))

        # Compute lambd2 array
        if p['delta_2'] == 1:
            lambd2_arr = np.array([p['lambda_2']] * (p['t_3'] - p['t_2']))
        else:
            lambd2_arr = np.linspace(p['lambda_1'], p['lambda_2'], p['delta_2'])
            lambd2_arr = np.append(lambd2_arr, [p['lambda_2']] * (p['t_3'] - p['t_2'] - p['delta_2']))

        # Compute lambd3 array
        if p['delta_3'] == 1:
            lambd3_arr = np.array([p['lambda_3']] * (p['t_4'] - p['t_3']))
        else:
            lambd3_arr = np.linspace(p['lambda_3'], p['lambda_4'], p['delta_3'])
            lambd3_arr = np.append(lambd3_arr, [p['lambda_3']] * (p['t_4'] - p['t_3'] - p['delta_3']))

        # Compute lambd4 array
        if p['delta_4'] == 1:
            lambd4_arr = np.array([p['lambda_4']] * (p['t_5'] - p['t_4']))
        else:
            lambd4_arr = np.linspace(p['lambda_3'], p['lambda_4'], p['delta_4'])
            lambd4_arr = np.append(lambd4_arr, [p['lambda_4']] * (p['t_5'] - p['t_4'] - p['delta_4']))

        return np.r_[lambd0_arr, lambd1_arr, lambd2_arr, lambd3_arr, lambd4_arr]


    def prep_params(self, params):
        """Merge free parameters with the constant ones and validate the result.

        Parameters
        ----------
        params : dict or sequence
            Either a ``name -> value`` dict, or a sequence of values ordered
            like ``self.param_names`` (list, array, pandas Series, ...).

        Returns
        -------
        dict or None
            The merged, rounded parameter dict, or ``None`` if ``check_params``
            rejects it.
        """
        # Explicit type dispatch instead of a bare ``except:`` around the merge,
        # which would also swallow KeyboardInterrupt and unrelated errors.
        if isinstance(params, dict):
            named = params
        else:
            named = dict(zip(self.param_names, params))
        return self.check_params(self.const_params | named)


    def sir(self, prior_draw):
        """Simulate the compartmental model and the noisy reported daily counts.

        Compartments: S(usceptible), E(xposed), C(arriers), I(nfected, detected),
        R(ecovered), D(ead). The model is stepped for ``t_5 + sim_diff - 1``
        time steps; the first ``sim_diff - 1`` steps are burn-in and are not
        part of the reported series.

        Parameters
        ----------
        prior_draw : dict or sequence
            Free parameters, in any form accepted by ``prep_params``.

        Returns
        -------
        pd.DataFrame with ``t_5`` rows (index ``t``) and columns
        ``S, E, C, I, R, D, dI_obs, dR_obs, dD_obs``, clipped to ``[0, N]``.

        Raises
        ------
        ValueError
            If the parameters violate the constraints in ``check_params``.
        """
        p = self.prep_params(prior_draw)
        if p is None:
            raise ValueError('Parameters violate the constraints enforced by check_params')
        sim_lag = p['sim_diff'] - 1
        lambd_arr = self.calc_lambda_array(p)

        # Initial conditions
        S, E, C, I, R, D = [p['N'] - p['E_0']], [p['E_0']], [0], [0], [0], [0]

        # Containers for the true (unreported) daily new cases
        I_news = []
        R_news = []
        D_news = []

        # Reported new cases and weekly modulation factors
        I_data = np.zeros(p['t_5'])
        R_data = np.zeros(p['t_5'])
        D_data = np.zeros(p['t_5'])
        fs_I = np.zeros(p['t_5'])
        fs_R = np.zeros(p['t_5'])
        fs_D = np.zeros(p['t_5'])

        # Simulate burn-in plus the observed period
        for t in range(p['t_5'] + sim_lag):
            # Calculate new exposed cases
            E_new = lambd_arr[t] * ((C[t] + p['beta'] * I[t]) / p['N']) * S[t]

            # Remove exposed from susceptible
            S_t = S[t] - E_new

            # Calculate current exposed by adding new exposed and
            # subtracting the exposed becoming carriers.
            E_t = E[t] + E_new - p['gamma'] * E[t]

            # Calculate current carriers by adding the new exposed and subtracting
            # those who will develop symptoms and become detected and those who
            # will go through the disease asymptomatically.
            C_t = C[t] + p['gamma'] * E[t] - (1 - p['alpha']) * p['eta'] * C[t] - p['alpha'] * p['theta'] * C[t]

            # Calculate current infected by adding the symptomatic carriers and
            # subtracting the dead and recovered. The newly infected are just the
            # carriers who get detected.
            I_t = I[t] + (1 - p['alpha']) * p['eta'] * C[t] - (1 - p['delta']) * p['mu'] * I[t] - p['delta'] * p['epsilon'] * I[t]
            I_new = (1 - p['alpha']) * p['eta'] * C[t]

            # Calculate current recovered by adding the symptomatic and asymptomatic
            # recovered. The newly recovered are only the detected recovered
            R_t = R[t] + p['alpha'] * p['theta'] * C[t] + (1 - p['delta']) * p['mu'] * I[t]
            R_new = (1 - p['delta']) * p['mu'] * I[t]

            # Calculate the current dead
            D_t = D[t] + p['delta'] * p['epsilon'] * I[t]
            D_new = p['delta'] * p['epsilon'] * I[t]

            # Ensure some numerical constraints
            S_t = np.clip(S_t, 0, p['N'])
            E_t = np.clip(E_t, 0, p['N'])
            C_t = np.clip(C_t, 0, p['N'])
            I_t = np.clip(I_t, 0, p['N'])
            R_t = np.clip(R_t, 0, p['N'])
            D_t = np.clip(D_t, 0, p['N'])

            # Keep track of process over time
            S.append(S_t)
            E.append(E_t)
            C.append(C_t)
            I.append(I_t)
            R.append(R_t)
            D.append(D_t)
            I_news.append(I_new)
            R_news.append(R_new)
            D_news.append(D_new)

            # After the burn-in, start filling the reported series. Each series
            # reports the new cases from lag_* steps earlier; check_params
            # guarantees sim_diff > lag_*, so the index never goes negative.
            if t >= sim_lag:
                day = t - sim_lag
                fs_I[day] = (1 - p['A_I']) * (
                    1 - np.abs(np.sin((np.pi / 7) * day - 0.5 * p['phi_I']))
                )
                fs_R[day] = (1 - p['A_R']) * (
                    1 - np.abs(np.sin((np.pi / 7) * day - 0.5 * p['phi_R']))
                )
                fs_D[day] = (1 - p['A_D']) * (
                    1 - np.abs(np.sin((np.pi / 7) * day - 0.5 * p['phi_D']))
                )
                I_data[day] = I_news[t - p['lag_I']]
                R_data[day] = R_news[t - p['lag_R']]
                D_data[day] = D_news[t - p['lag_D']]

        # Compute weekly modulation
        I_data = (1 - fs_I) * I_data
        R_data = (1 - fs_R) * R_data
        D_data = (1 - fs_D) * D_data

        # Add Student-t (df=4) noise scaled by sqrt(count) * sigma. One
        # independent draw per time step: a single scalar draw would apply the
        # same noise realisation to the whole series.
        n = I_data.shape[0]
        I_data = I_data + RNG.standard_t(4, size=n) * np.sqrt(I_data) * p['sigma_I']
        R_data = R_data + RNG.standard_t(4, size=n) * np.sqrt(R_data) * p['sigma_R']
        D_data = D_data + RNG.standard_t(4, size=n) * np.sqrt(D_data) * p['sigma_D']

        # Keep only the last n states so the compartments align with the reported series.
        return (
            pd.DataFrame({'S':S[-n:],'E':E[-n:],'C':C[-n:],'I':I[-n:],'R':R[-n:],'D':D[-n:],'dI_obs':I_data,'dR_obs':R_data,'dD_obs':D_data})
            .clip(0, p['N']).rename_axis('t'))

    def calibrate(self):
        c = self.generator(self.n_calibrate)
        d = {'prior':c['prior_draws'], 'data':c['sim_data']}
        funs = [np.mean, np.std, np.min, np.max]
        c |= {f'{key}_{f.__name__}': f(val, axis=0) for f in funs for key, val in d.items()}
        c |= {f'obs_{f.__name__}': c[f'data_{f.__name__}'][...,-self.n_obs:] for f in funs}
        return c




    def __post_init__(self):
        """Create output folders, build the generative model and set up the trainer.

        Steps, in order (later steps depend on earlier ones):
        1. Create ``<Drive>/MyDrive/bayesian disease modeling/<name>/model``
           (deleting the whole run folder first when ``refresh`` is True).
        2. Sample the prior 100 times to split parameters into constants
           (standard deviation below 1e-5) and free parameters.
        3. Build the BayesFlow prior, simulator and generative model.
        4. Load or compute the calibration statistics used for standardisation.
        5. Build the summary/inference networks and the trainer.

        Country-specific data is not loaded here; see ``fetch_data`` and
        ``get_predictive``.
        """
        self.root_path = pathlib.Path(mnt + f'/MyDrive/bayesian disease modeling/{self.name}')
        if self.refresh:
            # delete root_path and everything in it
            shutil.rmtree(self.root_path, ignore_errors=True)
        self.model_path = self.root_path / 'model/'
        self.model_path.mkdir(exist_ok=True, parents=True)
        self.calibration_file = self.model_path / 'calibration.pkl'
        self.diagnostic_file  = self.model_path / 'diagnostic.pkl'

        prior_samples = [self.prior_fun() for _ in range(100)]
        # Parameters that never vary across prior draws are treated as constants.
        self.const_params = pd.DataFrame(prior_samples).agg(['mean','std']).T.query('std < 1e-5')['mean'].to_dict()
        self.param_names = [key for key in prior_samples[0].keys() if key not in self.const_params]
        # Names listed here are typeset as-is; all others get a leading backslash
        # (e.g. $\alpha$). The backslash is written as '\\' because '\{' is an
        # invalid escape sequence in a non-raw string.
        self.param_latex = [f'${key}$' if key in ['N','E_0','sim_diff','t_1','t_2','t_3','t_4','A_I','A_R','A_D','lag_I','lag_R','lag_D'] else f'$\\{key}$' for key in self.param_names]
        self.n_params = len(self.param_names)
        # Observed columns are those of the simulator output whose name contains 'obs'.
        self.obs_classes = self.sir(prior_samples[0]).filter(like='obs').columns.tolist()
        self.n_obs = len(self.obs_classes)

        # The BayesFlow prior only exposes the free parameters, in param_names order.
        self.prior = Prior(prior_fun=lambda: [val for key, val in self.prior_fun().items() if key in self.param_names], param_names=self.param_names)
        self.simulator = Simulator(simulator_fun=self.sir)
        self.generator = GenerativeModel(self.prior, self.simulator)
        # Must be set before the trainer is built: pre_amortizer reads self.calibration.
        self.calibration = self.read_or_create(file=self.calibration_file, fun=self.calibrate)

        coupling_settings = {
            "dense_args": dict(units=128, activation="swish", kernel_regularizer=None),
            "num_dense": 2,
            "dropout": False,
        }
        summary_net = SummaryNet(n_summary=192)
        inference_net = InvertibleNetwork(
            num_params=len(self.param_names),
            num_coupling_layers=6,
            coupling_settings=coupling_settings,
        )
        self.model = Trainer(
            generative_model = self.generator,
            configurator = self.pre_amortizer,
            amortizer = AmortizedPosterior(summary_net=summary_net, inference_net=inference_net),# summary_loss_fun="MMD"),
            checkpoint_path = self.model_path, max_to_keep = 3, #memory = True, memory is broken for now
            skip_checks = True,
        )

    def pre_amortizer(self, dct):
        if 'sim_data' in dct:
            data  = np.array(dct['sim_data'])[...,-self.n_obs:]
            prior = np.array(dct['prior_draws'])
        elif 'real_data' in dct:
            data  = np.array(dct['real_data'])
            prior = np.full([1,self.n_params], np.nan)
        else:
            raise Exception
        dct['summary_conditions'] = np.float32((data - self.calibration['obs_mean']) / self.calibration['obs_std'])
        dct['parameters'] = np.float32((prior - self.calibration['prior_mean']) / self.calibration['prior_std'])
        return dct

    def post_amortizer(self, dct):
        for _ in range(3-dct['parameters_out'].ndim):
            dct['parameters_out'] = dct['parameters_out'][np.newaxis]
        dct['posterior_draws'] = dct['parameters_out'] * self.calibration['prior_std'] + self.calibration['prior_mean']
        return dct

    def draw_samples(self, n_posteriors, n_priors=None, real_data=None, ensemble=True):
        """Sample posteriors for real data or for fresh simulations.

        Parameters
        ----------
        n_posteriors : int
            Number of posterior draws requested from the amortizer.
        n_priors     : int, optional
            Number of simulated data sets to draw (used when ``real_data`` is None).
        real_data    : pd.DataFrame, optional
            Observed series; takes precedence over ``n_priors``.
        ensemble     : bool, default: True
            If True, also re-simulate the model for every valid posterior draw.

        Returns
        -------
        dict
            Configured inputs plus ``'parameters_out'`` and ``'posterior_draws'``;
            with ``ensemble=True`` also ``'valid_draws'``, ``'posterior'`` and
            ``'ensemble'``.
        """
        if real_data is not None:
            samples = {'real_data':[real_data]}
        elif n_priors is not None:
            samples = self.model.generative_model(n_priors)
        else:
            raise Exception('Must specify n_priors or real_data')

        samples = self.model.configurator(samples)
        samples['parameters_out'] = self.model.amortizer.sample(samples, n_posteriors)
        samples = self.post_amortizer(samples)
        if not ensemble:
            return samples

        # Keep only the draws that satisfy the simulator's parameter constraints.
        valid = []
        for draws in samples['posterior_draws']:
            valid.append([p for p in draws if self.prep_params(p) is not None])
        samples['valid_draws'] = valid

        # One row per valid draw, indexed by (data set, draw).
        per_prior = []
        for i, draws in enumerate(valid):
            rows = [
                pd.DataFrame([p], columns=self.param_names)
                .assign(prior_idx=i, posterior_idx=j)
                .set_index(['prior_idx','posterior_idx'])
                for j, p in enumerate(draws)
            ]
            per_prior.append(pd.concat(rows))
        samples['posterior'] = pd.concat(per_prior)

        # Re-simulate the model once per posterior draw.
        runs = []
        for (i, j), p in samples['posterior'].iterrows():
            run = self.sir(p).reset_index().assign(prior_idx=i, posterior_idx=j)
            runs.append(run.set_index(['prior_idx','posterior_idx','t']))
        samples['ensemble'] = pd.concat(runs)

        if real_data is not None:
            # Attach the observed series (and its dates) to every run via the time index.
            observed = real_data.reset_index().rename_axis('t')
            samples['ensemble'] = samples['ensemble'].join(observed).set_index('date',append=True)
        return samples

    def get_diagnostics(self, n_posteriors=50, refresh=False):
        """Load or compute diagnostic samples, then draw the loss, ECDF, histogram and recovery plots.

        The diagnostic set uses ``21 * n_posteriors`` simulated data sets with
        ``n_posteriors`` posterior draws each and is cached in ``diagnostic_file``.
        """
        def simulate():
            return self.draw_samples(n_posteriors=n_posteriors, n_priors=21*n_posteriors, ensemble=False)

        self.diagnostic = self.read_or_create(
            file=self.diagnostic_file, refresh=refresh, fun=simulate)

        self.plot('loss')
        for kind in ('ecdf', 'hist', 'recovery'):
            self.plot(kind, self.diagnostic)


    def plot(self, kind='loss', samples=None):
        """Draw one diagnostic figure, save it as ``<kind>.png`` in ``model_path`` and show it.

        ``kind`` is ``'loss'`` (training history) or one of ``'ecdf'``,
        ``'hist'``, ``'recovery'``; the latter three need ``samples`` with
        ``'parameters_out'`` and ``'parameters'``. Any other ``kind`` raises
        ``Exception``.
        """
        if kind == 'loss':
            fig = diag.plot_losses(**self.model.loss_history.get_plottable())
        else:
            opts = {'param_names':self.param_latex, 'post_samples':samples['parameters_out'], 'prior_samples':samples['parameters']}
            # Map each kind to the name of the bayesflow.diagnostics function to call.
            plotter_names = {
                'ecdf': 'plot_sbc_ecdf',
                'hist': 'plot_sbc_histograms',
                'recovery': 'plot_recovery',
            }
            if kind not in plotter_names:
                raise Exception(f'Unrecognized kind "{kind}"')
            plotter = getattr(diag, plotter_names[kind])
            fig = plotter(**opts)
        fig.savefig(self.model_path / f'{kind}.png')
        plt.show()


    # def load_data(self, country, start=39):
    #     """Download and prepare data from Johns Hopkins"""
    #     url = 'https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_'
    #     fetch = lambda cl: pd.read_csv(url+cl+'_global.csv', sep=",").drop(columns=['Province/State','Lat','Long']).groupby('Country/Region').sum().loc[country]
    #     real_data = (pd.DataFrame({'dI_real':fetch('confirmed'), 'dR_real':fetch('recovered'), 'dD_real':fetch('deaths')})
    #         .assign(date = lambda x: pd.to_datetime(x.index)).set_index('date')
    #         .diff().dropna().clip(0, self.tot_pop).astype(int).iloc[start:start+self.n_steps])
    #     return real_data

    def fetch_data(self, cntry, start=39):
        """Download population and daily case counts for one country.

        The 2020 total population comes from the World Bank (``SP.POP.TOTL``);
        cumulative confirmed / recovered / death counts come from the Johns
        Hopkins CSSE time series. Counts are differenced, clipped to
        ``[0, tot_pop]`` and cut to ``n_steps`` rows starting at ``start``.

        Returns a dict with keys ``cntry``, ``iso``, ``tot_pop``, ``real_data``.
        """
        iso = wbgapi.economy.coder(cntry)
        assert iso, f'Unrecognized country {cntry}'
        tot_pop = next(wbgapi.data.fetch('SP.POP.TOTL', economy=iso, time=2020))['value']

        url = 'https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_'

        def fetch(cl):
            # Sum all provinces of a country, then pick the requested country's row.
            table = pd.read_csv(url+cl+'_global.csv', sep=",")
            table = table.drop(columns=['Province/State','Lat','Long'])
            return table.groupby('Country/Region').sum().loc[cntry]

        cumulative = pd.DataFrame({'dI_real':fetch('confirmed'), 'dR_real':fetch('recovered'), 'dD_real':fetch('deaths')})
        dated = cumulative.assign(date = lambda x: pd.to_datetime(x.index)).set_index('date')
        # Cumulative totals -> daily new counts.
        daily = dated.diff().dropna().clip(0, tot_pop).astype(int)
        real_data = daily.iloc[start:start+self.n_steps]
        return dict(cntry=cntry, iso=iso, tot_pop=tot_pop, real_data=real_data)


    def get_predictive(self, cntry, n_posteriors=500, refresh=False):
        """Fit the amortized posterior to one country's data and plot the results.

        Produces (and saves under ``<root_path>/<cntry>/``):
        - ``predictive.pkl``: cached output of ``draw_samples`` for the real data,
        - ``predictive.png``: observed counts against 90/50/10% prediction intervals,
        - ``distributions.png``: prior vs. posterior histograms of selected parameters.

        Parameters
        ----------
        cntry        : str
            Country name as used in the Johns Hopkins 'Country/Region' column.
        n_posteriors : int, default: 500
            Number of posterior draws.
        refresh      : bool, default: False
            If True, ignore the cached ``predictive.pkl``.

        Returns
        -------
        dict
            The output of ``draw_samples`` (with ``ensemble=True``).
        """
        cntry_path = self.root_path / cntry
        cntry_path.mkdir(exist_ok=True, parents=True)
        predictive_file = cntry_path / 'predictive.pkl'
        dct = self.fetch_data(cntry)
        predictive = self.read_or_create(
            file=predictive_file, refresh=refresh,
            fun=lambda: self.draw_samples(n_posteriors=n_posteriors, real_data=dct['real_data'], ensemble=True))

        fig, ax = plt.subplots(self.n_obs, 1, sharex=True, figsize=(20,10))
        fig.suptitle(cntry)
        E = predictive['ensemble'].groupby('date')
        for i, obs in enumerate(self.obs_classes):
            real = obs.replace('obs','real')
            # The real series is identical in every ensemble member, so the median recovers it.
            ax[i].plot(E[real].median(), 'k.', label='Real')
            for ci, clr in {90:'green', 50:'red', 10:'blue'}.items():
                # Central interval: cut (100 - ci)/2 percent from each tail.
                tail = (100 - ci) / 200
                lb = E[obs].quantile(tail)
                ub = E[obs].quantile(1-tail)
                ax[i].fill_between(x=lb.index, y1=lb, y2=ub, color=clr, alpha=0.3, label=f'{ci}% Prediction Interval')
            ax[i].legend()
            ax[i].set_title(obs[:2])
        fig.savefig(cntry_path / 'predictive.png')
        plt.show()

        ode_params = [
            'alpha','beta','gamma','delta','epsilon','eta','mu','theta',
            'lambda_1','lambda_2','lambda_3','lambda_4','t_1','t_2','t_3','t_4',]
        # LaTeX labels: listed names are typeset as-is, all others get a leading
        # backslash. '\\' is used because '\{' is an invalid escape sequence.
        tx = [f'${key}$' if key in ['N','E_0','sim_diff','t_1','t_2','t_3','t_4','A_I','A_R','A_D','lag_I','lag_R','lag_D'] else f'$\\{key}$' for key in ode_params]
        pst = predictive['posterior'].assign(kind='posterior')
        # Draw as many prior samples as there are posterior samples for a like-for-like comparison.
        pir = pd.DataFrame(self.prior(pst.shape[0])['prior_draws'], columns=self.param_names).assign(kind='prior')
        Q = pd.concat([pir,pst]).set_index('kind')[ode_params]
        Q.columns = tx
        M = Q.melt(ignore_index=False).reset_index()
        sns.set_palette("Set2")
        grid = sns.FacetGrid(M, hue='kind', col='variable', col_wrap=4, sharex=False, sharey=False)
        grid.map(sns.histplot, 'value', kde=True, element='step', alpha=0.7)
        grid.add_legend()
        grid.savefig(cntry_path / 'distributions.png')
        plt.show()
        return predictive


# Build the pipeline under the run name 'radev_model_03'. The instance is bound to
# the module-level name `self`, which the cells below use.
self = COVID(
    name='radev_model_03',
    # refresh=True,
    # n_calibrate = 50,
)
# Uncomment to train online (1000 epochs of 320 iterations, batch size 32).
# h = self.model.train_online(epochs=1000, iterations_per_epoch=32*10, batch_size=32, validation_sims=500)

In [None]:
# Country to analyse; get_predictive passes it to fetch_data, which looks it up in the
# JHU 'Country/Region' column and in the World Bank coder.
cntry = 'Germany'
# `pred` holds the draw_samples output (posterior draws and simulated ensemble).
pred = self.get_predictive(cntry=cntry)#, refresh=True)


In [None]:
# NOTE: relies on `ode_params` (defined in a later cell) and on `pred`.
# Build the values in ode_params order from ONE joint prior draw. Filtering
# prior_fun().items() would return them in prior_fun's own key order, which
# differs from ode_params (the t_* entries come before the lambda_* ones there).
p = Prior(prior_fun=lambda: [draw[key] for draw in [self.prior_fun()] for key in ode_params])
p(10)
pred['posterior'][ode_params]

In [None]:
ode_params = [
    'alpha','beta','gamma','delta','epsilon','eta','mu','theta',
    'lambda_0','lambda_1','lambda_2','lambda_3','lambda_4','t_1','t_2','t_3','t_4',]
# The prior must return its values in the same order as `param_names`, otherwise
# the prior columns are mislabelled. Each call uses ONE joint prior draw, so the
# constraints between parameters are preserved.
f = diag.plot_posterior_2d(
    posterior_draws=pred['posterior'][ode_params],
    prior=Prior(prior_fun=lambda: [draw[key] for draw in [self.prior_fun()] for key in ode_params]),
    param_names=ode_params,
    )

In [None]:
# Quick look at the distribution of every posterior column in one figure.
sns.displot(pred['posterior'])

In [None]:
# Prior vs. posterior histograms for the selected parameters (same figure that
# get_predictive saves as distributions.png).
ode_params = [
    'alpha','beta','gamma','delta','epsilon','eta','mu','theta',
    'lambda_1','lambda_2','lambda_3','lambda_4','t_1','t_2','t_3','t_4',

    ]
# LaTeX labels; '\\' is used because '\{' is an invalid escape sequence.
tx = [f'${key}$' if key in ['N','E_0','sim_diff','t_1','t_2','t_3','t_4','A_I','A_R','A_D','lag_I','lag_R','lag_D'] else f'$\\{key}$' for key in ode_params]
pst = pred['posterior'].assign(kind='posterior')
# Match the number of prior draws to the number of posterior draws
# (`pst`, not the undefined name `post`).
pir = pd.DataFrame(self.prior(pst.shape[0])['prior_draws'], columns=self.param_names).assign(kind='prior')
Q = pd.concat([pir,pst]).set_index('kind')[ode_params]
Q.columns = tx
M = Q.melt(ignore_index=False).reset_index()
sns.set_palette("Set2")
g = sns.FacetGrid(M, hue='kind', col='variable', col_wrap=4, sharex=False, sharey=False)
g.map(sns.histplot, 'value', kde=True, element='step', alpha=0.7)
g.add_legend()


In [None]:
pred['posterior'].shape

In [None]:
# IPython help query.
diag.plot_posterior_2d?

In [None]:
# Run the predictive pipeline for every country in the Johns Hopkins table.
df = pd.read_csv('https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_confirmed_global.csv', sep=",")
# sample(frac=1) shuffles the rows, so countries are visited in random order.
for cntry in df['Country/Region'].sample(frac=1).unique():
# for cntry in ['Germany','US','Israel','Italy','Sweden']:
    try:
        self.get_predictive(cntry=cntry)
    except Exception as e:
        # Best effort: report the failure and continue with the next country.
        print(e)

In [None]:
self.get_diagnostics(refresh=True)

In [None]:
# Manual version of the per-class download: cumulative deaths for
# self.country, summed over provinces.
# pd.read_csv(f
# 'https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_{cl}_global.csv', sep=",")
url = 'https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_'
cl = 'deaths'
df = pd.read_csv(url+cl+'_global.csv', sep=",").drop(columns=['Province/State','Lat','Long']).groupby('Country/Region').sum().loc[self.country]
df#.query('`Country/Region` == @self.country')


In [None]:
# Look up the World Bank economy code and the 2020 total population for one
# country name.
country = 'US'
iso = wbgapi.economy.coder(country)
print(iso)
# Query with the code resolved above; the original passed `self.iso`, which
# belongs to the model's country rather than to `country`.
tot_pop = next(wbgapi.data.fetch('SP.POP.TOTL', economy=iso, time=2020))['value']
print(tot_pop)

In [None]:
# self.get_diagnostics(refresh=True)

In [None]:
self.get_predictive(refresh=True)

In [None]:
# Keep a handle on the simulated ensemble; later cells read the global `E`.
E = self.predictive['ensemble']
# E['dD_obs'].describe()
# E.join(self.real_data.reset_index())
# self.real_data.reset_index().rename_axis('t')
# E.join(self.real_data.reset_index().rename_axis('t')).set_index('date',append=True)

In [None]:
# IPython help query.
plt.figure?

In [None]:
# One subplot per observed class: real data as black dots, overlaid with
# shaded central prediction intervals computed from the ensemble.
fig, ax = plt.subplots(self.n_obs, 1, sharex=True, figsize=(20,10))
fig.suptitle(self.country)
# Rebinds the global `E` to a groupby-by-date object.
E = self.predictive['ensemble'].groupby('date')
for i, obs in enumerate(self.obs_classes):
    # Column holding the real counterpart of this simulated column.
    real = obs.replace('obs','real')
    ax[i].plot(E[real].median(), 'k.', label='Real')
    for ci, clr in {90:'green', 50:'red', 10:'blue'}.items():
        # Tail mass on each side of a central ci% interval.
        x = (100 - ci) / 200
        lb = E[obs].quantile(x)
        ub = E[obs].quantile(1-x)
        ax[i].fill_between(x=lb.index, y1=lb, y2=ub, color=clr, alpha=0.3, label=f'{ci}% Prediction Interval')
    ax[i].legend()
    # First two characters of the column name (e.g. 'dI') serve as the title.
    ax[i].set_title(obs[:2])
fig.savefig(self.cntry_path / f'predictive.png')
        # rl = obs.replace('obs','real')
        # pr


In [None]:
# IPython help query.
plt.plot?

In [None]:
# IPython help query.
ax[0].fill_between?

In [None]:
# Disabled experiment: per-class quantile plots of the ensemble. Only the final
# plt.show() is live; it is kept at column 0 because the loop it used to belong
# to is commented out (an indented statement here raises IndentationError).
# for obs in self.obs_classes:
#     e = E.groupby('t')[obs]
#     for ci in [75, 95]:
#         x = (100-ci) / 200
#         lb = e.quantile(x)
#         ub = e.quantile(1-x)
#         plt.plot(lb)



    # .agg(median='median', mean='mean', a=quantile)
#     e = E.groupby('t')[obs].quantile([0.25,0.75])
#     e = E.groupby('t')[obs].describe(percentiles=[0.5,0.9])

    # q50 = e.median()
    # e

    # plt.plot(Q[obs],label=obs)
    # rl = obs.replace('obs','real')
    # plt.plot(Q[rl],label=rl)
    # plt.legend()
plt.show()

# Q.plot()

In [None]:
samples['posterior_draws'].shape
samples['parameters_out'].shape


In [None]:
# One row per posterior draw, indexed by (prior_idx, posterior_idx); draws for
# which self.prep_params returns None are dropped.
P = pd.concat([pd.concat([pd.DataFrame([p], columns=self.param_names).assign(prior_idx=prior_idx, posterior_idx=posterior_idx).set_index(['prior_idx','posterior_idx']) for posterior_idx, p in enumerate(P) if self.prep_params(p) is not None]) for prior_idx, P in enumerate(samples['posterior_draws'])])
P

In [None]:
# Same filtering in three steps: valid draws -> parameter frame -> simulated
# ensemble (one self.sir run per remaining draw).
samples['valid_draws'] = [[p for p in d if self.prep_params(p) is not None] for d in samples['posterior_draws']]
samples['posterior'] = pd.concat([pd.concat([pd.DataFrame([p], columns=self.param_names).assign(prior_idx=i, posterior_idx=j).set_index(['prior_idx','posterior_idx']) for j, p in enumerate(d)]) for i, d in enumerate(samples['valid_draws'])])
samples['ensemble'] = pd.concat([self.sir(p).reset_index().assign(prior_idx=i, posterior_idx=j).set_index(['prior_idx','posterior_idx','t']) for (i,j),p in samples['posterior'].iterrows()])
# NOTE(review): this displays the global `E` from an earlier cell, not the
# samples['ensemble'] just built - confirm which one was meant.
E

In [None]:
# Unfiltered variant of the frame above, followed by a row-wise validity mask.
P = pd.concat([pd.concat([pd.DataFrame([p], columns=self.param_names).assign(prior_idx=prior_idx, posterior_idx=posterior_idx).set_index(['prior_idx','posterior_idx']) for posterior_idx, p in enumerate(P)]) for prior_idx, P in enumerate(samples['posterior_draws'])])
# self.check_params(P.iloc[0].values)
# dict(zip(self.param_names,P.iloc[0].values))
# self.prep_params(P.iloc[0])
# P.to_dict('records')[0]

# True for rows that self.prep_params accepts.
g = lambda p: self.prep_params(p) is not None
# P['valid'] = P.apply(g, axis=1)
# P.query('valid')
P[P.apply(g, axis=1)]

In [None]:
self.prior(1)['prior_draws'].shape

In [None]:
# Sample the amortized posterior conditioned on the real data.
n_posteriors = 10
samples = self.pre_amortizer({'real_data':self.load_data()})
samples['parameters_out'] = self.model.amortizer.sample(samples, n_posteriors)

In [None]:
samples = self.draw_samples(n_priors=5, n_posteriors=7)

In [None]:
self.plot('recovery',samples)

In [None]:
self.plot('loss')

In [None]:
# NOTE(review): incomplete expression left over from tab completion; this cell
# is a syntax error as it stands.
self.

In [None]:
# Simulate 300 data sets, draw 500 posterior samples for each and undo the
# standardisation with the calibration mean/std.
samples = self.model.configurator(self.model.generative_model(300))
samples['parameters_out'] = self.model.amortizer.sample(samples, 500)
samples['posterior_draws'] = samples['parameters_out'] * self.calibration['prior_std'] + self.calibration['prior_mean']
# samples = self.post_amortizer(samples)
samples.keys()

In [None]:
# Recovery plot on the standardised scale (network output vs. configured
# parameters).
# self.calibration['prior_draws'].shape, self.calibration['sim_data'].shape
# f = diag.plot_sbc_ecdf(samples['posterior_draws'], samples["parameters"], param_names=self.param_names)
f = diag.plot_recovery(post_samples=samples['parameters_out'], prior_samples=samples['parameters'], param_names=self.param_names)

#   (samples['posterior_draws'], samples["parameters"]

In [None]:
def generate_samples(self, n_priors=2, n_sims=None, n_samples=1):
    """Simulate data, sample the amortizer and post-process the result.

    Parameters
    ----------
    self : object
        Must provide ``generate_data``, ``pre_amortizer``, ``post_amortizer``
        and ``model.amortizer.sample``.
    n_priors : int, default 2
        Forwarded to ``generate_data``.
    n_sims : int, optional
        Forwarded to ``generate_data``; any falsy value (``None`` or ``0``)
        is replaced by ``20 * n_priors``.
    n_samples : int, default 1
        Forwarded to ``model.amortizer.sample``.

    Returns
    -------
    Whatever ``post_amortizer`` returns for the batch, which at that point
    carries the amortizer output under ``'parameters_out'``.
    """
    sim_count = n_sims or 20 * n_priors
    batch = self.pre_amortizer(
        self.generate_data(n_priors=n_priors, n_sims=sim_count)
    )
    batch['parameters_out'] = self.model.amortizer.sample(batch, n_samples=n_samples)
    return self.post_amortizer(batch)

In [None]:
# Check which keys the generative model returns.
w = self.generator(5)
'sim_data' in w
# hasattr(w,'sim_data')
# w.keys()


In [None]:
# Inspect how a NumPy function exposes its name.
f = np.mean
dir(f)
f.__name__

In [None]:
# Inspect the cached calibration statistics.
c = self.calibration
c.keys()
c['obs_mean'].shape

In [None]:
# Scratch comparison of two ways to reduce an array along an axis.
a = RNG.uniform(size=[2,3])
fun = np.mean
fun(a,axis=1)
# apply_over_axes calls func(a, axis), so it needs a callable; the string
# 'sum' used originally raised TypeError.
np.apply_over_axes(np.sum, a, 0)

In [None]:
# Shapes of the draws and simulations; note the generator is called twice, so
# the two shapes come from independent batches.
self.generator(10)['prior_draws'].shape, self.generator(10)['sim_data'].shape

In [None]:
def calibrate():
    """Simulate a calibration batch and attach summary statistics to it.

    Reads the notebook-global ``self``. The batch returned by
    ``self.generator(n_priors=self.n_calibrate, n_sims=1)`` is extended with

    * ``prior_<stat>`` - ``c['prior_samples'].agg(stat)``
    * ``data_<stat>``  - ``stat`` of the concatenated ``c['data']`` grouped by ``t``
    * ``obs_<stat>``   - the last ``self.n_obs`` entries along the final axis
      of ``data_<stat>``

    for ``stat`` in mean, std, min and max, each cast with ``np.float32``.

    Returns
    -------
    The generator output with the statistics added.
    """
    c = self.generator(n_priors=self.n_calibrate, n_sims=1)
    sources = {
        'prior': c['prior_samples'],
        'data': pd.concat(c['data']).groupby('t'),
    }
    stat_names = ('mean', 'std', 'min', 'max')

    # Statistics of the priors and of the per-time-step simulations.
    for stat in stat_names:
        for label, source in sources.items():
            c[f'{label}_{stat}'] = np.float32(source.agg(stat))

    # Observable columns are the trailing self.n_obs entries of the data stats.
    for stat in stat_names:
        c[f'obs_{stat}'] = c[f'data_{stat}'][..., -self.n_obs:]
    return c


In [None]:
self.obs_classes

In [None]:
# NOTE(review): in the COVID class shown later in this file, prior_fun()
# returns a one-row DataFrame, so prior_samples[0] is a DataFrame rather than a
# dict - confirm that self.sir accepts it.
prior_samples = [self.prior_fun() for _ in range(100)]
self.sir(prior_samples[0])

In [None]:
# Inspect types and shapes of a configured batch.
# type(self.generator(n_priors=2, n_sims=3)['data'][0])
# np.float32(self.generator(n_priors=2, n_sims=3)['data'])
# NOTE(review): the next two lines are identical; the first result is
# immediately overwritten.
samples = self.pre_amortizer(self.generator(n_priors=2, n_sims=3))
samples = self.pre_amortizer(self.generator(n_priors=2, n_sims=3))
{key:type(val) for key,val in samples.items()}
{key:val.shape for key,val in samples.items() if key != 'data'}

In [None]:
self.prior(10000)

In [None]:
self.prior(10000)

In [None]:
%reload_ext autotime
import os, datetime, pathlib, shutil, google.colab, dataclasses, pickle, wbgapi
import matplotlib.pyplot as plt, seaborn as sns
import numpy as np, pandas as pd, tensorflow as tf
from functools import partial
from scipy import stats
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, GRU, LSTM
from bayesflow.networks import InvertibleNetwork, SequentialNetwork
from bayesflow.coupling_networks import CouplingLayer
from bayesflow.simulation import GenerativeModel, Prior, Simulator
from bayesflow.amortizers import AmortizedLikelihood, AmortizedPosterior, AmortizedPosteriorLikelihood
from bayesflow.trainers import Trainer
from bayesflow import default_settings
from bayesflow.helper_functions import build_meta_dict
import bayesflow.diagnostics as diag
from bayesflow.computational_utilities import maximum_mean_discrepancy
# NOTE(review): this is set after `import tensorflow` above - confirm it still
# has the intended effect on TensorFlow's log output.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
plt.rcParams.update({"text.usetex": False, "font.family": "serif", "text.latex.preamble": r"\usepackage{{amsmath}}"})
# Google Drive mount point; COVID.__post_init__ builds its root path from it.
mnt = '/content/drive'
google.colab.drive.mount(mnt)
# Shared seeded generator used by COVID.prior_fun and COVID.sir.
RNG = np.random.default_rng(42)
# SCALE = 1000
# Tolerance used by COVID.check_params for its strict bounds.
EPS = 1e-6

class MultiConvLayer(tf.keras.Model):
    """Inception-style block of parallel causal 1D convolutions.

    Six ``Conv1D`` branches with kernel sizes 2 through 7, each producing
    ``n_filters // 2`` channels, are applied to the same input. Their outputs
    are concatenated along the last axis and reduced to ``n_filters`` channels
    by a kernel-size-1 convolution.
    """

    def __init__(self, n_filters=32, strides=1):
        super().__init__()

        # Settings shared by every branch; only the kernel size differs.
        branch_settings = dict(
            filters=n_filters // 2,
            strides=strides,
            padding="causal",
            activation="relu",
            kernel_initializer="glorot_uniform",
        )
        self.convs = [
            tf.keras.layers.Conv1D(kernel_size=width, **branch_settings)
            for width in range(2, 8)
        ]
        self.dim_red = tf.keras.layers.Conv1D(
            filters=n_filters,
            kernel_size=1,
            strides=1,
            activation="relu",
            kernel_initializer="glorot_uniform",
        )

    def call(self, x):
        """Apply all branches to the time series ``x`` and merge their channels."""

        branch_outputs = [conv(x) for conv in self.convs]
        return self.dim_red(tf.concat(branch_outputs, axis=-1))


class MultiConvNet(tf.keras.Model):
    """Stack of ``MultiConvLayer`` blocks followed by an LSTM.

    ``n_layers`` blocks are chained in a ``tf.keras.Sequential`` and their
    output sequence is passed to an ``LSTM`` with ``n_filters`` units.
    """

    def __init__(self, n_layers=3, n_filters=64, strides=1):
        super().__init__()

        conv_blocks = [MultiConvLayer(n_filters, strides) for _ in range(n_layers)]
        self.net = tf.keras.Sequential(conv_blocks)

        self.lstm = LSTM(n_filters)

    def call(self, x, **args):
        """Run the time series ``x`` through the conv stack, then the LSTM.

        Extra keyword arguments are accepted and ignored.
        """

        return self.lstm(self.net(x))


class SummaryNet(tf.keras.Model):
    """Three independent ``MultiConvNet`` branches, one per third of the input.

    The last axis of the input is split into three equal parts; each part is
    summarised by its own network with ``n_summary // 3`` filters and the
    three summaries are concatenated.
    """

    def __init__(self, n_summary):
        super().__init__()
        branch_width = n_summary // 3
        self.net_I = MultiConvNet(n_filters=branch_width)
        self.net_R = MultiConvNet(n_filters=branch_width)
        self.net_D = MultiConvNet(n_filters=branch_width)

    @tf.function
    def call(self, x, **args):
        """Summarise ``x``; extra keyword arguments are accepted and ignored."""

        part_i, part_r, part_d = tf.split(x, 3, axis=-1)
        summaries = [
            self.net_I(part_i),
            self.net_R(part_r),
            self.net_D(part_d),
        ]
        return tf.concat(summaries, axis=-1)

class MemoryNetwork(tf.keras.Model):
    """GRU that summarises the previously seen part of a target sequence.

    ``meta`` must contain ``"n_hidden"`` (number of GRU units, also stored as
    ``self.h``) and ``"n_params"`` (stored as ``self.n_params``).
    """

    def __init__(self, meta):
        super().__init__()

        hidden_units = meta["n_hidden"]
        self.gru = GRU(hidden_units, return_sequences=True, return_state=True)
        self.h = hidden_units
        self.n_params = meta["n_params"]

    @tf.function
    def call(self, target, condition):
        """Run the GRU over the whole sequence with teacher forcing.

        Params:
        -------
        target    : tf.Tensor of shape (batch_size, time_steps, dim)
            The time-dependent signal to process.
        condition : tf.Tensor of shape (batch_size, time_steps, cond_dim)
            Conditioning variables; concatenated with the shifted target along
            the last axis, so the leading two dimensions must match.

        Returns the GRU output sequence; the final state is discarded.
        """
        # Shift the target one step to the right and prepend zeros, so that the
        # input at step t only contains target values from before t.
        start_token = tf.zeros((target.shape[0], 1, target.shape[2]))
        teacher_input = tf.concat([start_token, target[:, :-1, :]], axis=1)
        gru_input = tf.concat([teacher_input, condition], axis=-1)
        sequence, _ = self.gru(gru_input)
        return sequence

    def step_loop(self, target, condition, state):
        """Advance the GRU from ``state`` on ``target`` joined with ``condition``.

        Returns the GRU output and the new state.
        """
        step_input = tf.concat([target, condition], axis=-1)
        memory, next_state = self.gru(step_input, initial_state=state)
        return memory, next_state

class InvertibleNetworkWithMemory(tf.keras.Model):
    """Chain of conditional coupling layers whose condition includes a GRU memory."""

    def __init__(
        self,
        num_params,
        num_coupling_layers=4,
        coupling_settings=None,
        coupling_design="affine",
        permutation="fixed",
        use_act_norm=True,
        act_norm_init=None,
        use_soft_flow=False,
        soft_flow_bounds=(1e-3, 5e-2),
    ):
        """Build the coupling layers and the memory network.

        All coupling layers share the same settings. The soft-flow options are
        stored as attributes only; nothing in this class reads them.
        """

        super().__init__()

        # Keyword arguments shared by every coupling layer.
        layer_settings = {
            "latent_dim": num_params,
            "coupling_settings": coupling_settings,
            "coupling_design": coupling_design,
            "permutation": permutation,
            "use_act_norm": use_act_norm,
            "act_norm_init": act_norm_init,
        }
        self.coupling_layers = [
            CouplingLayer(**layer_settings) for _ in range(num_coupling_layers)
        ]

        # Plain configuration attributes.
        self.soft_flow = use_soft_flow
        self.soft_low = soft_flow_bounds[0]
        self.soft_high = soft_flow_bounds[1]
        self.use_act_norm = use_act_norm
        self.latent_dim = num_params
        self.dynamic_summary_net = MemoryNetwork({"n_hidden": 256, "n_params": 3})

    def call(self, targets, condition, inverse=False):
        """Dispatch to ``forward`` or ``inverse``.

        Parameters
        ----------
        targets   : tf.Tensor
            Quantities to transform (``inverse=False``) or latent values to
            map back (``inverse=True``).
        condition : tf.Tensor
            Conditioning data.
        inverse   : bool, default: False
            Selects the direction of the pass.

        Returns
        -------
        ``(z, log_det_J)`` when ``inverse=False``; the reconstructed target
        tensor when ``inverse=True``.
        """

        run = self.inverse if inverse else self.forward
        return run(targets, condition)

    @tf.function
    def forward(self, targets, condition, **kwargs):
        """Map ``targets`` to latent space and accumulate the log-Jacobian."""

        # Extend the condition with the memory computed over the target sequence.
        memory = self.dynamic_summary_net(targets, condition)
        full_condition = tf.concat([memory, condition], axis=-1)

        z = targets
        layer_log_dets = []
        for layer in self.coupling_layers:
            z, layer_log_det = layer(z, full_condition, **kwargs)
            layer_log_dets.append(layer_log_det)
        # The total log-determinant is the sum over all coupling layers.
        return z, tf.add_n(layer_log_dets)

    @tf.function
    def inverse(self, z, condition, **kwargs):
        """Map latent ``z`` back to target space one time step at a time.

        The memory network is fed the target reconstructed at the previous
        step, starting from zeros.
        """

        n_steps = z.shape[1]
        previous = tf.zeros((z.shape[0], 1, z.shape[-1]))
        state = tf.zeros((z.shape[0], self.dynamic_summary_net.h))
        reconstructed = []
        for t in range(n_steps):
            step_condition = condition[:, t : t + 1, :]
            memory, state = self.dynamic_summary_net.step_loop(
                previous, step_condition, state
            )
            full_condition = tf.concat([memory, step_condition], axis=-1)
            step_target = z[:, t : t + 1, :]
            # Undo the coupling layers in reverse order.
            for layer in reversed(self.coupling_layers):
                step_target = layer(step_target, full_condition, inverse=True, **kwargs)
            reconstructed.append(step_target)
            previous = step_target
        return tf.concat(reconstructed, axis=1)


@dataclasses.dataclass
class COVID():
    country: str = 'Germany'
    name: str = 'covid_000'
    n_steps: int = 100
    n_calibrate: int = 5000
    refresh: bool = False

    def load_data(self, start=39):
        """Download Johns Hopkins counts for ``self.country`` as daily increments.

        The cumulative confirmed, recovered and death series are summed over
        provinces, differenced, clipped to ``[0, self.tot_pop]`` and cast to
        int. Rows ``start`` to ``start + self.n_steps`` of the result are
        returned, indexed by date ``t`` with columns ``dI_real``, ``dR_real``
        and ``dD_real``.
        """
        def fetch(cl):
            table = pd.read_csv(f'https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_{cl}_global.csv', sep=",")
            per_country = (
                table.drop(columns=['Province/State','Lat','Long'])
                .groupby('Country/Region')
                .sum()
            )
            return per_country.loc[self.country]

        cumulative = pd.DataFrame({
            'dI_real': fetch('confirmed'),
            'dR_real': fetch('recovered'),
            'dD_real': fetch('deaths'),
        })
        # Replace the date-string index by a datetime index named 't'.
        cumulative = cumulative.assign(t=pd.to_datetime(cumulative.index)).set_index('t')
        daily = cumulative.diff().dropna().clip(0, self.tot_pop).astype(int)
        return daily.iloc[start:start + self.n_steps]

    def read_or_create(self, file, fun=None, refresh=False):
        """Return cached results from ``file``, recomputing them when needed.

        Parameters
        ----------
        file : path-like
            Location of the pickle cache.
        fun : callable, optional
            Zero-argument function producing the results. It is called only
            when ``refresh`` is true or the cache cannot be read.
        refresh : bool, default False
            Ignore an existing cache and recompute.

        Returns
        -------
        The unpickled cache content, or the fresh return value of ``fun``.

        Notes
        -----
        ``fun`` runs *before* the cache file is opened for writing, so an
        exception raised by ``fun`` propagates without truncating a
        previously valid cache.
        """
        if refresh:
            print('Running sims due to refresh')
        else:
            try:
                with open(file, "rb") as f:
                    sims = pickle.load(f)
            except Exception as e:
                # Best effort: a missing, truncated or incompatible cache is
                # not fatal, it just triggers a rebuild.
                print(f'Running sims due to error: {e}')
            else:
                print(f'{file} successfully read')
                return sims

        sims = fun()
        with open(file, "wb") as f:
            pickle.dump(sims, f)
        return sims


    def check_params(self, p):
        """Round the integer-valued parameters of ``p`` in place and validate it.

        Returns ``p`` itself when every constraint holds and ``None`` otherwise:

        * every value except the ``phi*`` entries is greater than ``EPS``
        * ``alpha`` and ``delta`` are below ``1 - EPS``
        * ``sim_diff`` exceeds the largest reporting lag
        * each change point plus its transition length does not pass the next
          change point (``t_i + delta_i <= t_{i+1}`` for i = 1..4)
        """
        integer_keys = (
            'E_0', 'sim_diff',
            't_1', 't_2', 't_3', 't_4', 't_5',
            'delta_1', 'delta_2', 'delta_3', 'delta_4',
            'lag_I', 'lag_R', 'lag_D',
        )
        for key in integer_keys:
            p[key] = int(round(p[key]))
        # At least one initially exposed individual.
        p['E_0'] = max(p['E_0'], 1)

        # Evaluate every constraint up front, then combine them.
        positive = all(val > EPS for key, val in p.items() if not key.startswith('phi'))
        alpha_ok = p['alpha'] < 1 - EPS
        delta_ok = p['delta'] < 1 - EPS
        warmup_ok = p['sim_diff'] > max(p['lag_I'], p['lag_R'], p['lag_D'])
        ordered = [p[f't_{i}'] + p[f'delta_{i}'] <= p[f't_{i+1}'] for i in range(1, 5)]

        if positive and alpha_ok and delta_ok and warmup_ok and all(ordered):
            return p
        return None

    def prior_fun(self, batch_size=1):
        """Draw ``batch_size`` valid parameter sets by rejection sampling.

        Candidates are drawn from the shared ``RNG`` and passed through
        ``check_params``; rejected candidates are discarded and redrawn.
        Returns a DataFrame with one row per accepted draw, indexed by
        ``prior_idx``. ``N``, ``sim_diff`` and ``t_5`` are not random.
        """
        # Shape parameters of the Beta prior shared by f_I, f_R and f_D.
        alpha_f = (0.7**2) * ((1 - 0.7) / (0.17**2) - (1 - 0.7))
        beta_f = alpha_f * (1 / 0.7 - 1)
        L = []
        while len(L) < batch_size:
            # The order of the entries fixes the order in which RNG is
            # consumed, so reordering them changes the draws for a given seed.
            p = self.check_params({
                'N'       :self.tot_pop,
                'E_0'     :RNG.gamma(shape=2, scale=30),
                'alpha'   :RNG.uniform(low=0.005, high=0.99),
                'beta'    :RNG.lognormal(mean=np.log(0.25), sigma=0.3),
                'gamma'   :RNG.lognormal(mean=np.log(1/6.5), sigma=0.5),
                'delta'   :RNG.uniform(low=0.01, high=0.3),
                'epsilon' :RNG.uniform(low=1/14, high=1/3),
                'eta'     :RNG.lognormal(mean=np.log(1/3.2), sigma=0.5),
                # NOTE(review): 'lambda' is not read by calc_lambda_array or sir
                # (they use lambda_0..lambda_4) - confirm it is still needed.
                'lambda'  :RNG.lognormal(mean=np.log(1.2), sigma=0.5),
                'mu'      :RNG.lognormal(mean=np.log(1/8), sigma=0.2),
                'theta'   :RNG.uniform(low=1/14, high=1/3),
                'sim_diff':16,
                't_1'     :RNG.normal(loc=8, scale=3),
                't_2'     :RNG.normal(loc=15, scale=3),
                't_3'     :RNG.normal(loc=22, scale=3),
                't_4'     :RNG.normal(loc=66, scale=3),
                't_5'     :self.n_steps,
                'delta_1' :RNG.lognormal(mean=np.log(3), sigma=0.3),
                'delta_2' :RNG.lognormal(mean=np.log(3), sigma=0.3),
                'delta_3' :RNG.lognormal(mean=np.log(3), sigma=0.3),
                'delta_4' :RNG.lognormal(mean=np.log(3), sigma=0.3),
                'lambda_0':RNG.lognormal(mean=np.log(1.20), sigma=0.5),
                'lambda_1':RNG.lognormal(mean=np.log(0.60), sigma=0.5),
                'lambda_2':RNG.lognormal(mean=np.log(0.30), sigma=0.5),
                'lambda_3':RNG.lognormal(mean=np.log(0.10), sigma=0.5),
                # 'lambda_4':RNG.lognormal(mean=np.log(0.10), sigma=0.5),
                'lambda_4':RNG.lognormal(mean=np.log(0.15), sigma=0.5),
                'f_I'     :RNG.beta(a=alpha_f, b=beta_f),
                'f_R'     :RNG.beta(a=alpha_f, b=beta_f),
                'f_D'     :RNG.beta(a=alpha_f, b=beta_f),
                'phi_I'   :RNG.vonmises(mu=0, kappa=0.01),
                'phi_R'   :RNG.vonmises(mu=0, kappa=0.01),
                'phi_D'   :RNG.vonmises(mu=0, kappa=0.01),
                'lag_I'   :RNG.lognormal(mean=np.log(8), sigma=0.2),
                'lag_R'   :RNG.lognormal(mean=np.log(8), sigma=0.2),
                'lag_D'   :RNG.lognormal(mean=np.log(8), sigma=0.2),
                'sigma_I' :RNG.gamma(shape=1, scale=5),
                'sigma_R' :RNG.gamma(shape=1, scale=5),
                'sigma_D' :RNG.gamma(shape=1, scale=5),
            })
            # check_params returns None for a rejected candidate.
            if p:
                L.append(p)
        return pd.DataFrame(L).rename_axis('prior_idx')
        # return L)

    def calc_lambda_array(self, p):
        """Compute the array of time-varying contact rates/transmission probabilities.

        The array starts with ``t_1 + sim_diff - 1`` entries of ``lambda_0``.
        For each change point ``k = 1..4`` it then covers ``t_{k+1} - t_k``
        steps: a linear transition from ``lambda_{k-1}`` to ``lambda_k`` over
        ``delta_k`` steps followed by a plateau at ``lambda_k``. When
        ``delta_k == 1`` the whole segment is the plateau.

        Parameters
        ----------
        p : dict
            Must hold integer ``sim_diff``, ``t_1..t_5``, ``delta_1..delta_4``
            and numeric ``lambda_0..lambda_4``.

        Returns
        -------
        numpy.ndarray
            One-dimensional array of length ``t_5 + sim_diff - 1``.
        """
        # Initial rate, including the warm-up steps before the first observation.
        segments = [np.array([p['lambda_0']] * (p['t_1'] + p['sim_diff'] - 1))]

        for k in range(1, 5):
            previous, current = p[f'lambda_{k - 1}'], p[f'lambda_{k}']
            span = p[f't_{k + 1}'] - p[f't_{k}']
            ramp = p[f'delta_{k}']
            if ramp == 1:
                segment = np.array([current] * span)
            else:
                # Every transition starts at the previous segment's rate. The
                # original third segment ramped lambda_3 -> lambda_4 instead of
                # lambda_2 -> lambda_3.
                segment = np.append(
                    np.linspace(previous, current, ramp),
                    [current] * (span - ramp),
                )
            segments.append(segment)

        return np.concatenate(segments)


    def sir(self, prior_draw):
        """Simulate one epidemic trajectory and its noisy daily reports.

        Parameters
        ----------
        prior_draw : dict, pandas.Series or sequence
            A dict, or a Series whose index contains every name in
            ``self.param_names``, is merged with ``self.const_params`` by key.
            Anything else is read positionally in the order of
            ``self.param_names``.

        Returns
        -------
        pandas.DataFrame
            Indexed by ``t`` with ``t_5`` rows: the last ``t_5`` values of the
            compartments S, E, C, I, R, D and the reported daily counts
            ``dI_obs``, ``dR_obs``, ``dD_obs``, all clipped to
            ``[0, self.tot_pop]``.

        Raises
        ------
        AssertionError
            If ``check_params`` rejects the merged parameter set.
        """
        # Merge constants with the draw. The input kind is tested explicitly
        # instead of relying on a bare ``except``.
        if isinstance(prior_draw, pd.Series) and all(
            name in prior_draw.index for name in self.param_names
        ):
            prior_draw = prior_draw.to_dict()
        if isinstance(prior_draw, dict):
            p = self.const_params | prior_draw
        else:
            p = self.const_params | dict(zip(self.param_names, prior_draw))

        # check_params also rounds the integer-valued entries of p in place, so
        # it must run even when assertions are disabled (python -O).
        if self.check_params(p) is None:
            raise AssertionError('parameter draw violates the model constraints')

        sim_lag = p['sim_diff'] - 1
        n_days = p['t_5']
        lambd_arr = self.calc_lambda_array(p)

        # Initial conditions
        S, E, C, I, R, D = [self.tot_pop - p['E_0']], [p['E_0']], [0], [0], [0], [0]

        # Newly detected / recovered / dead per simulated step
        I_news, R_news, D_news = [], [], []

        # Simulate the warm-up steps plus the observed window
        for t in range(n_days + sim_lag):
            # Calculate new exposed cases
            E_new = lambd_arr[t] * ((C[t] + p['beta'] * I[t]) / self.tot_pop) * S[t]

            # Remove exposed from susceptible
            S_t = S[t] - E_new

            # Calculate current exposed by adding new exposed and
            # subtracting the exposed becoming carriers.
            E_t = E[t] + E_new - p['gamma'] * E[t]

            # Calculate current carriers by adding the new exposed and subtracting
            # those who will develop symptoms and become detected and those who
            # will go through the disease asymptomatically.
            C_t = C[t] + p['gamma'] * E[t] - (1 - p['alpha']) * p['eta'] * C[t] - p['alpha'] * p['theta'] * C[t]

            # Calculate current infected by adding the symptomatic carriers and
            # subtracting the dead and recovered. The newly infected are just the
            # carriers who get detected.
            I_new = (1 - p['alpha']) * p['eta'] * C[t]
            I_t = I[t] + I_new - (1 - p['delta']) * p['mu'] * I[t] - p['delta'] * p['epsilon'] * I[t]

            # Calculate current recovered by adding the symptomatic and asymptomatic
            # recovered. The newly recovered are only the detected recovered
            R_new = (1 - p['delta']) * p['mu'] * I[t]
            R_t = R[t] + p['alpha'] * p['theta'] * C[t] + R_new

            # Calculate the current dead
            D_new = p['delta'] * p['epsilon'] * I[t]
            D_t = D[t] + D_new

            # Keep every compartment inside [0, population]
            S.append(np.clip(S_t, 0, self.tot_pop))
            E.append(np.clip(E_t, 0, self.tot_pop))
            C.append(np.clip(C_t, 0, self.tot_pop))
            I.append(np.clip(I_t, 0, self.tot_pop))
            R.append(np.clip(R_t, 0, self.tot_pop))
            D.append(np.clip(D_t, 0, self.tot_pop))
            I_news.append(I_new)
            R_news.append(R_new)
            D_news.append(D_new)

        days = np.arange(n_days)

        def reported(news, lag, f, phi, sigma):
            """Delay, weekly-modulate and perturb one series of new cases."""
            # Observation day k reports the cases of simulation step
            # k + sim_lag - lag; check_params guarantees lag <= sim_lag.
            offset = sim_lag - lag
            delayed = np.asarray(news[offset:offset + n_days], dtype=float)
            # Weekly modulation
            fs = (1 - f) * (1 - np.abs(np.sin((np.pi / 7) * days - 0.5 * phi)))
            expected = (1 - fs) * delayed
            # One independent Student-t draw per day; the original drew a single
            # scalar and applied it to the whole series.
            noise = RNG.standard_t(4, size=n_days)
            return expected + noise * np.sqrt(expected) * sigma

        I_data = reported(I_news, p['lag_I'], p['f_I'], p['phi_I'], p['sigma_I'])
        R_data = reported(R_news, p['lag_R'], p['f_R'], p['phi_R'], p['sigma_R'])
        D_data = reported(D_news, p['lag_D'], p['f_D'], p['phi_D'], p['sigma_D'])

        n = I_data.shape[0]
        return (
            pd.DataFrame({'S':S[-n:],'E':E[-n:],'C':C[-n:],'I':I[-n:],'R':R[-n:],'D':D[-n:],'dI_obs':I_data,'dR_obs':R_data,'dD_obs':D_data})
            .clip(0, self.tot_pop).rename_axis('t'))


    # def generate_data(self, n_priors=1, n_sims=1):
    #     prior = self.prior(n_priors).reset_index()
    #     prior = pd.concat([prior.assign(sim_idx=j).set_index(['prior_idx','sim_idx']) for j in range(n_sims)])
    #     data = [self.sir(p).reset_index().assign(prior_idx=i, sim_idx=j).set_index(['prior_idx','sim_idx','t']) for (i,j), p in prior.iterrows()]
    #     return dict(prior=prior, data=data)
    #     # return self.prep_data(prior, data)

    # def get_real_data(self):
    #     prior = pd.DataFrame([np.full(len(self.param_names), np.nan)], columns=self.param_names)
    #     data = [self.load_data().reset_index().assign(sim_idx=0, prior_idx=0).set_index(['sim_idx','prior_idx','t'])]
    #     return dict(prior=prior, data=data)
    #     # return self.prep_data(prior, data)

    # # def prep_data(self, prior, data):
    # #     return dict(prior=prior, prior_draws=self.to_tf(prior, 2, False), data=pd.concat(data), obs=self.to_tf(data, 3, True))

    # def pre_amortizer(self, dct):
    #     # dct['prior_samples'] = np.float32(dct['prior'])
    #     # dct['ensemble'] = pd.concat(dct['data'])
    #     # dct['data'] = np.float32(dct['data'])

    #     dct['summary_conditions'] = (np.float32(dct['data'])[...,-self.n_obs:] - self.calibration['obs_mean']) / self.calibration['obs_std']
    #     dct['parameters'] = (np.float32(dct['prior']) - self.calibration['prior_mean']) / self.calibration['prior_std']

    #     # dct['summary_conditions'] = (self.to_tf(dct['data'], 3, True) - self.calibration['obs_mean_tf']) / self.calibration['obs_std_tf']
    #     # dct['parameters'] = (self.to_tf(dct['prior'], 2, False) - self.calibration['prior_mean_tf']) / self.calibration['prior_std_tf']


    #     return dct

    # def post_amortizer(self, dct):
    #     dct['posterior_samples'] = np.float32(dct['parameters_out']) * self.calibration['prior_std'] + self.calibration['prior_mean']
    #     return dct

    # def generate_samples(self, n_priors=2, n_sims=None, n_samples=1):
    #     n_sims = n_sims if n_sims else 20*n_priors
    #     samples = self.generate_data(n_priors=n_priors, n_sims=n_sims)
    #     samples = self.pre_amortizer(samples)
    #     samples['parameters_out'] = self.model.amortizer.sample(samples, n_samples=n_samples)
    #     samples = self.post_amortizer(samples)
    #     return samples
    #     # return self.post_amortizer(dct)


    def __post_init__(self):
        """Set up paths, country metadata, parameter bookkeeping and the generative model.

        Side effects: deletes ``root_path`` when ``self.refresh`` is true,
        creates the model and country directories, and queries the World Bank
        API for the ISO code and the 2020 population of ``self.country``.

        Raises
        ------
        AssertionError
            If the country name cannot be resolved to an economy code.
        """
        self.root_path = pathlib.Path(mnt + f'/MyDrive/bayesian disease modeling/{self.name}')
        if self.refresh:
            # delete root_path and everything in it
            shutil.rmtree(self.root_path, ignore_errors=True)
        self.model_path = self.root_path / 'model/'
        self.model_path.mkdir(exist_ok=True, parents=True)
        self.file = {key: self.model_path / f'{key}.pkl' for key in ['calibration','diagnostic','predictive','error']}

        self.iso = wbgapi.economy.coder(self.country)
        assert self.iso, f'Unrecognized country {self.country}'
        self.tot_pop = next(wbgapi.data.fetch('SP.POP.TOTL', economy=self.iso, time=2020))['value']
        self.path = self.root_path / self.iso
        self.path.mkdir(exist_ok=True)

        # A parameter counts as constant when it does not vary across a small
        # batch of prior draws; every other column is inferred.
        prior_samples = self.prior_fun(10)
        self.const_params = prior_samples.agg(['mean','std']).T.query('std < 1e-5')['mean'].to_dict()
        self.param_names = [key for key in prior_samples.columns if key not in self.const_params]

        # LaTeX labels for the inferred parameters (the original iterated over
        # the constants). Names that are not LaTeX commands are kept as-is,
        # all others get a leading backslash, e.g. $\alpha$.
        plain = {'E_0', 't_1', 't_2', 't_3', 't_4', 'f_I', 'f_R', 'f_D', 'lag_I', 'lag_R', 'lag_D'}
        self.param_latex = [f'${key}$' if key in plain else f'$\\{key}$' for key in self.param_names]
        self.n_params = len(self.param_names)

        # Run one simulation to discover the observable columns. A draw already
        # at hand is reused, because self.prior is only assigned below and
        # calling it here raised AttributeError.
        self.obs_classes = self.sir(prior_samples.iloc[0].to_dict()).filter(like='obs').columns.tolist()

        # self.prior = lambda batch_size=1: self.prior_fun(batch_size)[self.param_names]
        self.prior = Prior(batch_prior_fun=lambda batch_size=1: self.prior_fun(batch_size)[self.param_names], param_names=self.param_names)
        self.simulator = Simulator(simulator_fun=self.sir)
        self.generator = GenerativeModel(self.prior, self.simulator)

    #     self.n_obs = len(self.obs_classes)
    #     self.calibration = self.read_or_create(file=self.file['calibration'], fun=self.calibrate)

    #     coupling_settings = {
    #         "dense_args": dict(units=128, activation="swish", kernel_regularizer=None),
    #         "num_dense": 2,
    #         "dropout": False,
    #     }
    #     summary_net = SummaryNet(n_summary=192)
    #     # summary_net = SequentialNetwork()
    #     inference_net = InvertibleNetwork(
    #         num_params=len(self.param_names),
    #         num_coupling_layers=6,
    #         coupling_settings=coupling_settings,
    #     )
    #     self.model = Trainer(
    #         # generative_model = lambda n: self.generate_data(n_prior=n),
    #         generative_model = partial(self.generate_data, n_sims=1),
    #         # configurator = self.configurator,
    #         configurator = self.pre_amortizer,
    #         amortizer = AmortizedPosterior(summary_net=summary_net, inference_net=inference_net),# summary_loss_fun="MMD"),
    #         checkpoint_path = self.model_path, max_to_keep = 3, #memory = True, memory is broken for now
    #         # skip_checks = True,
    #     )

    # def plot_diagnostics(self, n_samples=100, refresh=False):
    #     self.diagnostic = self.read_or_create(
    #         file=self.file['diagnostic'], refresh=refresh,
    #         fun=lambda: self.sample(n_samples=n_samples, n_sims=20*n_samples, ensemble=False))
    #     self.plot('loss')
    #     self.plot('ecdf', self.diagnostic)
    #     self.plot('hist', self.diagnostic)
    #     self.plot('recovery', self.diagnostic)


    # def plot(self, kind='loss', samples=None):
    #     if kind == 'loss':
    #         fig = diag.plot_losses(**self.model.loss_history.get_plottable())
    #     else:
    #         opts = {'param_names':self.param_names, 'post_samples':samples['posterior_samples'], 'prior_samples':samples['prior_samples']}
    #         if kind == 'ecdf':
    #             fig = diag.plot_sbc_ecdf(**opts)
    #         elif kind == 'hist':
    #             fig = diag.plot_sbc_histograms(**opts)
    #         elif kind == 'recovery':
    #             fig = diag.plot_recovery(**opts)
    #         else:
    #             raise Exception(f'Unrecognized kind "{kind}"')
    #     fig.savefig(self.path / f'{kind}.png')
    #     plt.show()

    # def calibrate(self):
    #     c = self.generate_data(n_priors=self.n_calibrate, n_sims=1)
    #     d = {'prior':c['prior'], 'data':pd.concat(c['data']).groupby('t')}
    #     funs = ['mean','std','min','max']
    #     c |= {f'{key}_{fun}': np.float32(val.agg(fun)) for fun in funs for key, val in d.items()}
    #     c |= {f'obs_{fun}'  : c[f'data_{fun}'][...,-self.n_obs:] for fun in funs}
    #     # c |= {f'prior_{fun}': c[f'prior_{fun}'].values for fun in funs}

    #     # c |= {f'obs_{fun}_tf'  : self.to_tf(c[f'data_{fun}' ], 3, True ) for fun in funs}
    #     # c |= {f'prior_{fun}_tf': self.to_tf(c[f'prior_{fun}'], 2, False) for fun in funs}
    #     return c

# Build the model instance stored under the name 'radev_model_03'.
# NOTE: refresh=True is destructive -- the __post_init__ shown above removes
# the whole per-name root folder (shutil.rmtree) before recreating it, so
# anything previously saved under that folder is lost on every run of this cell.
self = COVID(
    name='radev_model_03',
    refresh=True,
    n_calibrate = 50,
)
# Online training call, currently disabled:
# h = self.model.train_online(epochs=1000, iterations_per_epoch=32*1, batch_size=32, validation_sims=500)

In [None]:
prior_samples = self.prior_fun(100)

prior_samples.iloc[0]
self.sir(prior_samples.iloc[0])
# self.sir(self.prior(1))

In [None]:
sam = self.generator(10)
sam.keys()

In [None]:
samples = self.model.generative_model(2)
#  = partial(self.generate_data, n_sims=1),
#             # configurator = self.configurator,
#             configurator = self.pre_amortizer,

In [None]:
samples = self.model.configurator(samples)

In [None]:
# samples.keys()
self.model.amortizer.sample(input_dict=samples, n_samples=3)

In [None]:
{key: val.shape for key, val in samples.items()}

In [None]:
samples = self.model.configurator(self.generate_data(n_priors=300, n_sims=1))
{key: val.shape for key, val in samples.items()}

In [None]:
samples['posterior_samples'] = self.model.amortizer.sample(samples, n_samples=100)

In [None]:
{key: val.shape for key, val in samples.items()}

In [None]:
samples = self.generate_samples(n_priors=300, n_sims=2, n_samples=500)

In [None]:
samples['prior_samples'].shape, samples['posterior_samples'].shape

In [None]:
# self.plot('loss')
self.plot('recovery', samples)

In [None]:
opts = {'param_names':self.param_names, 'post_samples':samples['posterior_draws'], 'prior_samples':samples['prior_samples'].to_numpy()}
fig = diag.plot_recovery(**opts)

In [None]:
{key:type(val) for key, val in samples.items()}

samples['posterior_draws'].shape, samples['prior'].shape

In [None]:
samples = self.generate_sample(n_prior=2)

In [None]:
# self.plot('loss')
self.plot('recovery', samples)
# # samples.keys()
# opts = {'param_names':self.param_names, 'post_samples':samples['posterior_draws'], 'prior_samples':samples['prior_draws']}
# fig = diag.plot_sbc_ecdf(**opts)

In [None]:
# Scratch: replicate every prior draw n_sims times and simulate each copy once.
n_prior = 2
n_sims = 3
# NOTE(review): assumes self.prior returns a DataFrame whose index is named
# 'prior_idx', so that reset_index() yields a 'prior_idx' column -- confirm.
prior = self.prior(n_prior).reset_index()
# Stack n_sims tagged copies; rows end up keyed by (prior_idx, sim_idx).
prior = pd.concat([prior.assign(sim_idx=j).set_index(['prior_idx','sim_idx']) for j in range(n_sims)])
# One simulation per (prior_idx, sim_idx) row, re-keyed by (prior_idx, sim_idx, t).
# NOTE(review): presumes self.sir returns a frame whose index is named 't' -- confirm.
data = [self.sir(p).reset_index().assign(prior_idx=i, sim_idx=j).set_index(['prior_idx','sim_idx','t']) for (i,j), p in prior.iterrows()]
# # # data = [self.sir(p).assign(sim_idx=i, prior_idx=j).set_index(['prior_idx','sim_idx'], append=True) for (i,j), p in prior.iterrows()]
# # # data = [self.sir(p) for j, p in prior.iterrows()]
# len(data)
# data[1]

In [None]:
# tf.convert_to_tensor()
# tf.convert_to_tensor(self.calibration['prior_mean'])
c = self.generate_data(n_prior=50)
c['prior'].agg('mean')

In [None]:
self.calibration['prior_draws'].shape

In [None]:
self.plot(kind='loss')

In [None]:
# np.full((
pd.DataFrame(np.full((1,len(self.param_names)), np.nan), columns=self.param_names)
pd.DataFrame([np.full(len(self.param_names), np.nan)], columns=self.param_names)

In [None]:
dct['parameters_out'].shape
dct['obs'].shape
# # dct['summary_conditions'].shape
dct['parameters_out'].shape

# dct['prior'].shape
dct['posterior_draws'] = dct['parameters_out'] * self.calibration['prior_std'] + self.calibration['prior_mean']
dct['posterior_draws']
dct['prior']

In [None]:
# dct['prior_draws'].shape
# dct['parameters'].shape
# d=2
dct = self.generate_data(n_prior_draws=d, n_sims=20*d)
dct = self.model.configurator(dct)
self.model.amortizer.sample(dct, 3)
# self.model(dct, 3)

In [None]:
self.sample(dct=dct, n_samples=3)

In [None]:
# self.calibration.keys()
# self.n_obs
np.repeat(dct['prior_draws'], 2, axis=0).shape

In [None]:
self.generate_obs(n_sims=2, n_prior_draws=5)['obs'].shape
# dct = self.get_real_data()
# dct['parameters'] = (
# dct['obs'].shape
# self.n_obs
# dct['prior_draws'] - pad(self.calibration['prior_mean'],2)
# ) / self.calibration['prior_std'
# dct = self.pre_amortizer(dct)
# dct['parameters'].shape, dct['obs'].shape
# P['data_df']
# P['data']#.shape
# P['prior_draws'].shape


In [None]:
[1]*2

In [None]:
P = self.prior(3)
w = [p for i,p in P.iterrows() for j in range(2)]
w = [p for i,p in P.iterrows() for j in range(2)]
len(w)

In [None]:
P = self.generator(n_sims=2, n_prior_draws=3)['prior_draws']
len(P)
def g(X):
    """Return a fixed Series with entries 'a' and 'n'; the argument X is ignored.

    Each entry holds a two-element list, so the result is the same for every
    row this function is applied to.
    """
    fixed_cells = {'a': [17, 17], 'n': [19, 21]}
    return pd.Series(fixed_cells)
# ?P[['new','newner']] =
P.apply(g, axis=1)
# P

In [None]:
[k for k, p in self.enumerate(generator(n_sims=2, n_prior_draws=3)['prior_draws'])]

In [None]:
p = pad(self.prior(4),3)
# a = [p for _ in range(3)]
# a = [self.prior(4).values]*3
# display(a[0])
# display(a[1])
np.repeat(p,3, axis=0).shape

In [None]:
A = self.model.configurator(self.model.generative_model(5))
A['summary_conditions'].shape
A['data'][0][1].shape

# A = self.pre_amortizer(self.model_generative_model(n_prior_draws=3, n_sims=2))
# A['summary_conditions'].shape, A['parameters'].shape

In [None]:
self.calibration['sim_std']

In [None]:
# len(self.generator(n_prior_draws=10)['data'][0])
# display(self.generator(n_prior_draws=10)['data'][0][0])
A = self.generator(n_sims=2, n_prior_draws=3)
dfs = dict()
dfs['prior_draws'] = A['prior_draws'][0]
dfs['data'] = pd.concat([pd.concat([C.reset_index().assign(sim_idx=j, prior_idx=i).set_index(['sim_idx','prior_idx','t']) for i, C in enumerate(B)]) for j, B in enumerate(A['data'])])
dfs['data']

# D
# display(pd.concat(D[0]))

In [None]:
dct = self.generator(n_sims=2, n_prior_draws=1)
# dct = self.get_real_data()
dct = self.pre_amortizer(dct)
dct['summary_conditions'].shape
dct['parameters'].shape
self.model.amortizer(dct, n_samples=7)

In [None]:
# Scratch: run the simulator for every prior draw, repeated n_sims times.
n_prior_draws=2
n_sims = 3
p = self.prior(n_prior_draws)
# The list holds n_sims references to the *same* object `p` (no copies are made).
prior_draws = [p for j in range(n_sims)]
len(prior_draws[0])
prior_draws[0].shape
# [[self.sir(p) for i,p in prior_draws.iterrows()] for j in range(n_sims)]
# Outer list: one entry per repetition; inner list: one simulation per row.
# NOTE(review): relies on self.prior returning a DataFrame (uses .iterrows()) -- confirm.
[[self.sir(p) for i,p in P.iterrows()] for P in prior_draws]

In [None]:
n_sims = 3
[[self.sir(p) for i,p in prior_draws.iterrows()] for j in range(n_sims)]

In [None]:
self.generator()['prior_draws'][0]

In [None]:
dct = self.get_real_data()
# pad(dct['data'], 4)[...,-self.n_obs:].shape
# pad(self.calibration['sim_mean'], 4)[...,-self.n_obs:].shape
# # self.n_obs

In [None]:
# self.prior_fun()
# self.sir(self.prior_fun())
# len(self.generator(n_prior_draws=3,n_sims=2)['data'])#[0][0]

In [None]:
# pad(dct['data'], 3)[...,-self.n_obs]
# Scratch: standardise simulated observations with the calibration statistics.
# This first `dct` is overwritten by the dict built further down in this cell.
dct = self.generator(n_sims=2, n_prior_draws=3)
# dct = dict(data =
# self.real_data = self.load_data()
# dict(prior_draws=np.full([1, self.n_params], np.nan), data=[self.real_data])#, obs_data=obs_data)

n_prior_draws = 10
prior_draws = self.prior(n_prior_draws)
# data = [self.sir(p) for p in prior_draws]
# One simulation per prior draw (row of prior_draws).
dct = dict(prior_draws=prior_draws, data=[self.sir(p) for i,p in prior_draws.iterrows()])

# Keep only the trailing self.n_obs entries of the last axis, then z-score them.
# NOTE(review): `pad` is defined elsewhere in the notebook; presumably it
# converts to an array of the given rank -- confirm before relying on shapes.
x   = pad(dct['data'], 3)[...,-self.n_obs:]
mu  = pad(self.calibration['sim_mean'], 3)[...,-self.n_obs:]
std = pad(self.calibration['sim_std' ], 3)[...,-self.n_obs:]
(x - mu)/ std

In [None]:
self.pre_amortizer(self.generator(n_sims=2, n_prior_draws=3))
self.pre_amortizer(self.generator(n_sims=2, n_prior_draws=3))

In [None]:
self.calibration['sim_std'].shape

In [None]:
self.calibration['prior_mean']

In [None]:
c = self.generator()['data'][0].columns.tolist()
c
# c.
# c.filter(like='obs').shape[1]

In [None]:
prior_draws = self.prior(3)
[x for i, x in prior_draws.iterrows()]

In [None]:
# self.prior(5)
n_prior_draws=3
n_sims = 2
# self.prior_draws = [[param.assign(prior_idx=i, sim_idx=j).set_index(['prior_idx','sim_idx']) for j in range(n_sims)] for i, param in self.prior(n_prior_draws).iterrows()]
# self.prior_draws = pd.concat([[param for j in range(n_sims)] for i, param in self.prior(n_prior_draws).iterrows()])

# [prior_idx, p in self.prior(n_prior_draws).to_dict('index').items()
# a = [[(sim_idx, prior_idx, p) for prior_idx, p in self.prior(n_prior_draws).iterrows()] for sim_idx in range(n_sims)]
a = [[(sim_idx, prior_idx, p) for sim_idx in range(n_sims)] for prior_idx, p in self.prior(n_prior_draws).iterrows()]
a = [[p.to_frame().assign(sim_idx=sim_idx, prior_idx=prior_idx) for sim_idx in range(n_sims)] for prior_idx, p in self.prior(n_prior_draws).iterrows()]
a = [[p for sim_idx in range(n_sims)] for prior_idx, p in self.prior(n_prior_draws).iterrows()]
a[0]
# pd.concat(a[0])

# prior_draws = [p.assign(sim_idx=sim_idx).set_index(['sim_idx','prior_idx']) for sim_idx in range(n_sims)]
# # prior = pd.concat(prior_draws)
# # prior_draws = np.array(prior_draws)

# a = [[p for prior_idx, p in enumerate(sim)] for sim_idx, sim in enumerate(prior_draws)]
# a[0]
# data = [[self.sir(p).assign(sim_idx=sim_idx).set_index(['sim_idx','prior_idx']) for prior_idx, p in enumerate(sim)] for sim_idx, sim in enumerate(prior_draws)]
# ensemble = pd.concat([pd.concat(x) for x in data])
# data = np.array(data)
# prior_draws.shape, data.shape
# ensemble

# [self.sir(p).reset_index().assign(sim_idx=i,prior_idx=j).set_index(['sim_idx','prior_idx','t']) for (i,j),p in parameters.iterrows()]

# prior_draws = pd.concat([params.assign(sim_idx=i).set_index('sim_idx',append=True) for i in range(n_sims)])
# prior_draws = [params.assign(sim_idx=i).set_index('sim_idx',append=True) for i in range(n_sims)]
# np.array(prior_draws).shape
# self.prior_draws = [[param.assign( for j in range(n_sims)] for i, param in self.prior(n_prior_draws)]

# len(self.prior_draws[0])
# self.prior_draws

In [None]:
# v = self.model.generative_model(5)
v = self.generator(n_sims=2, n_prior_draws=3)
# c = self.configurator(v)
# n_posterior_draws=7
# s = self.model.amortizer.sample(c, n_samples=n_posterior_draws)
# # c.keys()
# s.shape
v['ensemble']
# v['parameters']

In [None]:
pd.concat([pd.DataFrame([params], columns=self.param_names).assign( for param_idx, sim in enumerate(s) for j, draw in enumerate(sim)])

In [None]:
self.model.loss_history.

In [None]:
self.model.train_online??

In [None]:
self.calibration['prior_max']

In [None]:
self.model.loss_history.get_plottable()

In [None]:
# fig = diag.plot_losses(train_losses=self.model.loss_history.get_plottable()['train_losses'])
fig = diag.plot_losses(**self.model.loss_history.get_plottable())

In [None]:
c = self.configurator(self.generator(n_prior_draws=5))
c.keys()

In [None]:
def g(x):
    """Return x when it is evenly divisible by 2, otherwise None."""
    return x if x % 2 == 0 else None

[g(x) for x in range(10)]

In [None]:
# Scratch: compare what the configurator produces for two kinds of input.
# 1) A dict carrying only the observed data (no parameter draws supplied).
forward_dict = {'data':self.real_data}
c = self.configurator(forward_dict)
print(c.keys(), c['parameters'].shape, c['summary_conditions'].shape)
# 2) The output of the generator for a single prior draw / single simulation.
forward_dict = self.generator(n_sims=1, n_prior_draws=1)
c = self.configurator(forward_dict)
print(c.keys(), c['parameters'].shape, c['summary_conditions'].shape)


In [None]:
self.real_data#.describe()
# R

In [None]:
self.calibration['sim_mean']

In [None]:
g = self.generator(n_sims=2, n_prior_draws=3)
g['data'].shape
g['ensemble']
g['parameters']
# E = pd.concat(D)
# E.groupby(['sim_idx','param_idx'])
# D[0]
# len(D)
# pd.concat()

In [None]:
self.model

In [None]:
self.model.loss_history.get_plottable()

In [None]:
g = lambda n: self.generator(n_prior_draws=n)
w = self.configurator(g(10))
w['summary_conditions'].shape, w['parameters'].shape

In [None]:
sel

In [None]:
# def to_df(pd.concat([fun(val).assign(sim=i, sample=j).set_index(['sim', 'sample'], append=append) for i, sim in enumerate(arr) for j, val in enumerate(sim)])

c = self.generator(n_prior_draws=3, n_sims=2)

c['parameters']
c['data'][1]
# self.configurator(c)
# d = (pad(c['data'],4)[...,-self.n_obs:] - pad(self.calibration['sim_mean'],3)[...,-self.n_obs:]) / pad(self.calibration['sim_std'],3)[...,-self.n_obs:]
# (pad(c['parameters'],3) - pad(self.calibration['prior_mean'],1)) / pad(self.calibration['prior_std'],1)
# self.calibration['prior_mean'].shape
# d.shape
# (pad(forward_dict['data'],4)[...,-self.n_obs] - self.calibration['sim_mean'].iloc[:,-self.n_obs]) / self.calibration['sim_std'].iloc[:,-self.n_obs]

In [None]:
# Scratch: check which index labels iterrows() yields after stacking a frame on itself.
w = self.prior(5)
# pd.concat without ignore_index keeps the original labels, so every label of
# `w` appears twice in `A` (assuming self.prior returns a DataFrame -- confirm).
A = pd.concat([w]*2)
for prior_idx, p in A.iterrows():
    print(prior_idx)

In [None]:
m = self.calibration['sim_mean'].columns.str.contains('obs')
S[...,m].shape

In [None]:
self.calibration['sim_mean']

In [None]:
# self.calibration['sim_mean']
params = self.prior(2)
# p
c = self.generator(params=params, n_sims=2)
# c = self.generator(n_prior_draws=10, n_sims=2)
# # c['sim_data'][0]
c['sim_data'][0][0].filter(like='obs')
# [pd.concat(x).assign(prior_idx=i) for i, x in enumerate(c['sim_data'])][0]

#  .to_dict('records'))]
# [x for i, x in c['prior_draws'].to_dict('index').items()]
# [pd.concat(x).assign(prior_idx=i) for i, x in c['sim_data'].to_dict('index').items()]
# pd.concat([pd.concat(x) for x in c['sim_data'].to_dict('index')])
# S = pd.concat([[pd.concat(y) for y in x] for x in c['sim_data']])
# S = pd.concat([[pd.concat(y) for y in x] for x in c['sim_data']])

In [None]:
# Summary statistics (mean / std / min / max) of the prior draws and of the
# simulated data, stored back into the generator output under
# 'prior_<stat>' and 'sim_<stat>'.
calib = self.generator(n_prior_draws=10)
P, S = calib['prior_draws'], calib['sim_data']
reducers = {'mean': np.mean, 'std': np.std, 'min': np.min, 'max': np.max}
# Prior draws are reduced over the leading axis only ...
for label, reduce_fn in reducers.items():
    calib[f'prior_{label}'] = reduce_fn(P, axis=0)
# ... simulated data over the two leading axes.
for label, reduce_fn in reducers.items():
    calib[f'sim_{label}'] = reduce_fn(S, axis=(0,1))


# d = dict(posterior=g['prior_draws'], sim=g['sim_data'])
# {key:{stat:df
# P =

# # len(g['sim_data'])
# # np.array(g['sim_data']).mean(axis=(0,1)).shape
# S = g['sim_data']
# cols = S[0][0].columns
# pd.DataFrame(np.mean(S, axis=(0,1)), columns=cols)



In [None]:
# draws = self.prior(100)
self.prior(100)['prior_draws'].shape, self.prior_fun(100).shape
# self.simulator(self.prior(10))

In [None]:
self.calibration['const']

In [None]:
pd.DataFrame([self.prior_fun() for _ in range(100)])

In [None]:
%reload_ext autotime
import os, datetime, pathlib, shutil, google.colab, dataclasses, pickle, wbgapi
import matplotlib.pyplot as plt, seaborn as sns
import numpy as np, pandas as pd, tensorflow as tf
from functools import partial
from scipy import stats
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, GRU, LSTM
from bayesflow.networks import InvertibleNetwork
from bayesflow.coupling_networks import CouplingLayer
from bayesflow.simulation import GenerativeModel, Prior, Simulator
from bayesflow.amortizers import AmortizedLikelihood, AmortizedPosterior, AmortizedPosteriorLikelihood
from bayesflow.trainers import Trainer
from bayesflow import default_settings
from bayesflow.helper_functions import build_meta_dict
from bayesflow.diagnostics import plot_sbc_ecdf, plot_sbc_histograms, plot_losses
from bayesflow.computational_utilities import maximum_mean_discrepancy
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
plt.rcParams.update({"text.usetex": False, "font.family": "serif", "text.latex.preamble": r"\usepackage{{amsmath}}"})
mnt = '/content/drive'
google.colab.drive.mount(mnt)
RNG = np.random.default_rng(42)
SCALE = 1000
EPS = 1e-6

@dataclasses.dataclass
class COVID():
    """Compartmental epidemic simulator with a rejection-sampled prior.

    The simulator tracks six compartments (S, E, C, I, R, D -- susceptible,
    exposed, carriers, infected, recovered, dead) and produces three noisy
    reported series (``dI_obs``, ``dR_obs``, ``dD_obs``).  Construction touches
    the mounted Drive folder and queries ``wbgapi`` for the population.

    Relies on the module-level names ``mnt``, ``RNG``, ``SCALE`` and ``EPS``.
    """
    country: str = 'Germany'    # looked up with wbgapi.economy.coder
    name: str = 'covid_000'     # sub-folder name under the Drive root folder
    n_steps: int = 100          # number of reported time steps (used as t_5)
    refresh: bool = False       # when True, the root folder is deleted on construction

    def __post_init__(self):
        """Create folders, fetch the population and load/create the prior calibration.

        Raises:
            ValueError: if ``wbgapi`` cannot resolve ``self.country``.
        """
        self.root_path = pathlib.Path(mnt + f'/MyDrive/bayesian disease modeling/{self.name}')
        if self.refresh:
            # Destructive: removes root_path and everything below it.
            shutil.rmtree(self.root_path, ignore_errors=True)
        self.model_path = self.root_path / 'model'
        self.model_path.mkdir(exist_ok=True, parents=True)
        self.file = {key: self.model_path / f'{key}.pkl' for key in ['calibration','diagnostic','predictive','error']}

        self.iso = wbgapi.economy.coder(self.country)
        if not self.iso:
            # Explicit raise instead of `assert`, which is stripped under `python -O`.
            raise ValueError(f'Unrecognized country {self.country}')
        self.tot_pop = next(wbgapi.data.fetch('SP.POP.TOTL', economy=self.iso, time=2020))['value']

        self.calibration = self.read_or_create(self.file['calibration'], fun=self.calibrate, refresh=False)
        self.param_names = [key for key in self.calibration['draws'].columns if key not in self.calibration['const']]
        plain_labels = ('E_0', 'lag_I', 'lag_R', 'lag_D')
        self.param_latex = [f'${key}$' if key in plain_labels else f'$\\{key}$' for key in self.param_names]
        # Prior restricted to the non-constant parameters.
        self.prior = lambda batch_size=1: self.prior_fun(batch_size)[self.param_names]

    def calibrate(self):
        """Summarise 10000 prior draws.

        Returns:
            dict with
              ``draws``       -- the raw draws (DataFrame),
              ``stats``       -- per-parameter mean/std table,
              ``prior_means`` / ``prior_stds`` -- mean and std of the
                                 parameters whose std exceeds ``EPS``,
              ``const``       -- ``{name: mean}`` for the remaining
                                 (constant) parameters.

        NOTE: a calibration pickle written by the earlier version of this
        method stores the means under ``prior_stds`` as well; recreate it
        (``refresh=True``) to pick up the corrected values.
        """
        draws = self.prior_fun(10000)
        summary = draws.agg(['mean','std']).T
        varies = summary['std'] > EPS
        prior_means, prior_stds = summary[varies].T.values
        const = summary.loc[~varies,'mean'].to_dict()
        return dict(draws=draws, stats=summary, prior_means=prior_means, prior_stds=prior_stds, const=const)

    def read_or_create(self, file, fun=None, refresh=False):
        """Return the object pickled in ``file``; (re)create it with ``fun()`` if needed.

        Args:
            file: path of the pickle cache.
            fun: zero-argument callable producing the object to cache.
            refresh: when True, skip reading and always recompute.

        The result of ``fun()`` is computed *before* the file is opened for
        writing, so a failure inside ``fun`` no longer leaves an empty,
        unreadable cache file behind.
        """
        if refresh:
            print('Running sims: refresh requested')
        else:
            try:
                # NOTE: pickle must only be used on trusted files; this reads
                # the cache this class wrote itself.
                with open(file, "rb") as f:
                    sims = pickle.load(f)
            except Exception as e:
                # Any read failure (missing or corrupt file) falls back to recomputation.
                print(f'Running sims due to error: {e}')
            else:
                print(f'{file} successfully read')
                return sims
        sims = fun()
        with open(file, "wb") as f:
            pickle.dump(sims, f)
        return sims

    def check_params(self, p):
        """Round the integer-valued parameters of ``p`` in place and validate it.

        Args:
            p: mutable mapping of parameter name -> value.  It is modified in
               place: times, ramp lengths, lags and ``E_0`` become ``int``
               and ``E_0`` is raised to at least 1.

        Returns:
            ``p`` itself when valid, otherwise ``False``.  Valid means every
            value except the ``phi*`` entries is > 0 and
            ``t_i + delta_i <= t_{i+1}`` for i = 1..3.
            NOTE(review): the last segment (``t_4 + delta_4 <= t_5``) is not
            checked -- confirm whether that is intended.
        """
        for key in ['E_0','t_1','t_2','t_3','t_4','t_5','delta_1','delta_2','delta_3','delta_4','lag_I','lag_R','lag_D']:
            p[key] = int(round(p[key]))
        p['E_0'] = max(p['E_0'], 1)
        valid = all(val > 0 for key, val in p.items() if not key.startswith('phi'))
        for i in range(1,4):
            valid = valid and (p[f't_{i}'] + p[f'delta_{i}'] <= p[f't_{i+1}'])
        return p if valid else False

    def prior_fun(self, batch_size=1):
        """Draw ``batch_size`` valid parameter sets by rejection sampling.

        Candidate sets are drawn from ``RNG`` and passed through
        ``check_params``; rejected candidates are redrawn.

        Returns:
            DataFrame with one row per accepted draw and one column per
            parameter (including the constants ``N`` and ``t_5``).
        """
        # Beta shape parameters for the f_* entries; beta_f is chosen so that
        # the Beta mean alpha_f / (alpha_f + beta_f) equals 0.7.
        alpha_f = (0.7**2) * ((1 - 0.7) / (0.17**2) - (1 - 0.7))
        beta_f = alpha_f * (1 / 0.7 - 1)
        accepted = []
        while len(accepted) < batch_size:
            p = self.check_params({
                'N'       :self.tot_pop,
                'E_0'     :RNG.gamma(shape=2, scale=30),
                'alpha'   :RNG.uniform(low=0.005, high=0.99),
                'beta'    :RNG.lognormal(mean=np.log(0.25), sigma=0.3),
                'gamma'   :RNG.lognormal(mean=np.log(1/6.5), sigma=0.5),
                'delta'   :RNG.uniform(low=0.01, high=0.3),
                'epsilon' :RNG.uniform(low=1/14, high=1/3),
                'eta'     :RNG.lognormal(mean=np.log(1/3.2), sigma=0.5),
                'lambda'  :RNG.lognormal(mean=np.log(1.2), sigma=0.5),
                'mu'      :RNG.lognormal(mean=np.log(1/8), sigma=0.2),
                'theta'   :RNG.uniform(low=1/14, high=1/3),
                't_1'     :RNG.normal(loc=8, scale=3),
                't_2'     :RNG.normal(loc=15, scale=3),
                't_3'     :RNG.normal(loc=22, scale=3),
                't_4'     :RNG.normal(loc=66, scale=3),
                't_5'     :self.n_steps,
                'delta_1' :RNG.lognormal(mean=np.log(3), sigma=0.3),
                'delta_2' :RNG.lognormal(mean=np.log(3), sigma=0.3),
                'delta_3' :RNG.lognormal(mean=np.log(3), sigma=0.3),
                'delta_4' :RNG.lognormal(mean=np.log(3), sigma=0.3),
                'lambda_0':RNG.lognormal(mean=np.log(1.20), sigma=0.5),
                'lambda_1':RNG.lognormal(mean=np.log(0.60), sigma=0.5),
                'lambda_2':RNG.lognormal(mean=np.log(0.30), sigma=0.5),
                'lambda_3':RNG.lognormal(mean=np.log(0.10), sigma=0.5),
                # 'lambda_4':RNG.lognormal(mean=np.log(0.10), sigma=0.5),
                'lambda_4':RNG.lognormal(mean=np.log(0.15), sigma=0.5),
                'f_I'     :RNG.beta(a=alpha_f, b=beta_f),
                'f_R'     :RNG.beta(a=alpha_f, b=beta_f),
                'f_D'     :RNG.beta(a=alpha_f, b=beta_f),
                'phi_I'   :RNG.vonmises(mu=0, kappa=0.01),
                'phi_R'   :RNG.vonmises(mu=0, kappa=0.01),
                'phi_D'   :RNG.vonmises(mu=0, kappa=0.01),
                'lag_I'   :RNG.lognormal(mean=np.log(8), sigma=0.2),
                'lag_R'   :RNG.lognormal(mean=np.log(8), sigma=0.2),
                'lag_D'   :RNG.lognormal(mean=np.log(8), sigma=0.2),
                'sigma_I' :RNG.gamma(shape=1, scale=5),
                'sigma_R' :RNG.gamma(shape=1, scale=5),
                'sigma_D' :RNG.gamma(shape=1, scale=5),
            })
            if p:
                accepted.append(p)
        return pd.DataFrame(accepted)

    @staticmethod
    def _lambda_segment(start, stop, ramp_len, seg_len):
        """Return one contact-rate segment: a linear ramp start->stop, then constant ``stop``."""
        if ramp_len == 1:
            return np.array([stop] * seg_len)
        ramp = np.linspace(start, stop, ramp_len)
        return np.append(ramp, [stop] * (seg_len - ramp_len))

    def calc_lambda_array(self, p):
        """Compute the array of time-varying contact rates/transmission probabilities.

        Args:
            p: parameter mapping containing ``lambda_0..4``, ``t_1..5``,
               ``delta_1..4`` and ``sim_lag`` (modified in place by
               ``check_params``).

        Returns:
            1-D array: ``t_1 + sim_lag`` entries of ``lambda_0`` followed by
            four segments, each ramping linearly from the previous rate to
            ``lambda_i`` over ``delta_i`` steps and then staying at
            ``lambda_i`` until ``t_{i+1}``.

        Raises:
            ValueError: if ``check_params`` rejects ``p``.
        """
        # check_params also casts the times/ramp lengths to int, which the
        # list repetitions below require -- so it must not live in an `assert`.
        if not self.check_params(p):
            raise ValueError('Invalid parameter set passed to calc_lambda_array')

        # Array of initial lambdas
        segments = [np.array([p['lambda_0']] * (p['t_1'] + p['sim_lag']))]
        for i in range(1, 5):
            # Segment i ramps from lambda_{i-1} to lambda_i (segment 3
            # previously ramped from lambda_3 to lambda_4 by mistake).
            segments.append(self._lambda_segment(
                p[f'lambda_{i-1}'], p[f'lambda_{i}'],
                p[f'delta_{i}'], p[f't_{i+1}'] - p[f't_{i}'],
            ))
        return np.concatenate(segments)


    def sir(self, prior_draw, sim_diff=16, observation_model=True):
        """Simulate the compartment model for one parameter draw.

        Args:
            prior_draw: mapping (dict or Series) of parameter name -> value.
                It is copied, and the constants in
                ``self.calibration['const']`` are merged into the copy.
            sim_diff: number of warm-up steps + 1; must exceed every
                reporting lag (``lag_I``, ``lag_R``, ``lag_D``).
            observation_model: when True, return only the ``*_obs`` columns.

        Returns:
            DataFrame with ``t_5`` rows and columns S, E, C, I, R, D,
            dI_obs, dR_obs, dD_obs (or only the ``obs`` columns), clipped to
            ``[0, self.tot_pop]`` and divided by ``SCALE``.

        Raises:
            ValueError: if the parameters are invalid or ``sim_diff`` does
                not exceed the largest lag.
        """
        # Work on a copy so the caller's object is left untouched.
        p = dict(prior_draw)
        p.update(self.calibration['const'])
        p = self.check_params(p)
        if not p:
            raise ValueError('Invalid parameter set passed to sir')
        if sim_diff <= max(p['lag_I'], p['lag_R'], p['lag_D']):
            raise ValueError('sim_diff must exceed lag_I, lag_R and lag_D')
        p['sim_lag'] = sim_diff - 1
        lambd_arr = self.calc_lambda_array(p)

        # Initial conditions
        S, E, C, I, R, D = [self.tot_pop - p['E_0']], [p['E_0']], [0], [0], [0], [0]

        # Containers
        I_news = []
        R_news = []
        D_news = []

        # Reported new cases
        I_data = np.zeros(p['t_5'])
        R_data = np.zeros(p['t_5'])
        D_data = np.zeros(p['t_5'])
        fs_I = np.zeros(p['t_5'])
        fs_R = np.zeros(p['t_5'])
        fs_D = np.zeros(p['t_5'])

        # Simulate t_5 + sim_lag timesteps
        for t in range(p['t_5'] + p['sim_lag']):
            # Calculate new exposed cases
            E_new = lambd_arr[t] * ((C[t] + p['beta'] * I[t]) / self.tot_pop) * S[t]

            # Remove exposed from susceptible
            S_t = S[t] - E_new

            # Calculate current exposed by adding new exposed and
            # subtracting the exposed becoming carriers.
            E_t = E[t] + E_new - p['gamma'] * E[t]

            # Calculate current carriers by adding the new exposed and subtracting
            # those who will develop symptoms and become detected and those who
            # will go through the disease asymptomatically.
            C_t = C[t] + p['gamma'] * E[t] - (1 - p['alpha']) * p['eta'] * C[t] - p['alpha'] * p['theta'] * C[t]

            # Calculate current infected by adding the symptomatic carriers and
            # subtracting the dead and recovered. The newly infected are just the
            # carriers who get detected.
            I_t = I[t] + (1 - p['alpha']) * p['eta'] * C[t] - (1 - p['delta']) * p['mu'] * I[t] - p['delta'] * p['epsilon'] * I[t]
            I_new = (1 - p['alpha']) * p['eta'] * C[t]

            # Calculate current recovered by adding the symptomatic and asymptomatic
            # recovered. The newly recovered are only the detected recovered
            R_t = R[t] + p['alpha'] * p['theta'] * C[t] + (1 - p['delta']) * p['mu'] * I[t]
            R_new = (1 - p['delta']) * p['mu'] * I[t]

            # Calculate the current dead
            D_t = D[t] + p['delta'] * p['epsilon'] * I[t]
            D_new = p['delta'] * p['epsilon'] * I[t]

            # Ensure some numerical constraints
            S_t = np.clip(S_t, 0, self.tot_pop)
            E_t = np.clip(E_t, 0, self.tot_pop)
            C_t = np.clip(C_t, 0, self.tot_pop)
            I_t = np.clip(I_t, 0, self.tot_pop)
            R_t = np.clip(R_t, 0, self.tot_pop)
            D_t = np.clip(D_t, 0, self.tot_pop)

            # Keep track of process over time
            S.append(S_t)
            E.append(E_t)
            C.append(C_t)
            I.append(I_t)
            R.append(R_t)
            D.append(D_t)
            I_news.append(I_new)
            R_news.append(R_new)
            D_news.append(D_new)

            # After the warm-up, start recording reported cases, each series
            # delayed by its own lag.
            if t >= p['sim_lag']:
                k = t - p['sim_lag']
                # Weekly modulation factors and lagged new cases
                fs_I[k] = (1 - p['f_I']) * (
                    1 - np.abs(np.sin((np.pi / 7) * k - 0.5 * p['phi_I']))
                )
                fs_R[k] = (1 - p['f_R']) * (
                    1 - np.abs(np.sin((np.pi / 7) * k - 0.5 * p['phi_R']))
                )
                fs_D[k] = (1 - p['f_D']) * (
                    1 - np.abs(np.sin((np.pi / 7) * k - 0.5 * p['phi_D']))
                )
                I_data[k] = I_news[t - p['lag_I']]
                R_data[k] = R_news[t - p['lag_R']]
                D_data[k] = D_news[t - p['lag_D']]

        # Compute weekly modulation
        I_data = (1 - fs_I) * I_data
        R_data = (1 - fs_R) * R_data
        D_data = (1 - fs_D) * D_data

        # Add noise: one independent Student-t(4) draw per time step, matching
        # the reference formulation
        #   stats.t(df=4, loc=data, scale=np.sqrt(data) * sigma).rvs()
        # (a single scalar draw would shift the whole series by one common factor).
        n = I_data.shape[0]
        I_data = I_data + RNG.standard_t(4, size=n) * np.sqrt(I_data) * p['sigma_I']
        R_data = R_data + RNG.standard_t(4, size=n) * np.sqrt(R_data) * p['sigma_R']
        D_data = D_data + RNG.standard_t(4, size=n) * np.sqrt(D_data) * p['sigma_D']
        Y = pd.DataFrame({'S':S[-n:],'E':E[-n:],'C':C[-n:],'I':I[-n:],'R':R[-n:],'D':D[-n:],'dI_obs':I_data,'dR_obs':R_data,'dD_obs':D_data}).clip(0, self.tot_pop) / SCALE
        if observation_model:
            Y = Y.filter(like='obs')
        return Y


    def generator(self, n_sims=1, n_priors=1, prior_draws=None):
        """Return prior draws for a batch (work in progress: no simulation yet).

        Args:
            n_sims: currently unused.
            n_priors: number of prior draws to sample when ``prior_draws``
                is not supplied.
            prior_draws: optional pre-drawn parameters, returned unchanged.

        Returns:
            The supplied ``prior_draws``, or ``self.prior(n_priors)``.
        """
        if prior_draws is None:
            prior_draws = self.prior(n_priors)
        # self.sir(prior_draws)
        return prior_draws


# Build a COVID instance with all defaults (refresh left disabled).
self = COVID(
    # refresh=True,
)
# Three draws of the non-constant parameters ...
prior_draws = self.prior(3)
# ... with the constant parameters added back as extra columns.
# (sir() merges self.calibration['const'] into its argument as well.)
for k,v in self.calibration['const'].items():
    prior_draws[k] = v

prior_draws.to_dict('records')
# One full-state simulation (all columns, not only the *_obs ones) per draw.
x = [self.sir(p, observation_model=False) for p in prior_draws.to_dict('records')]
# x = [self.sir(p, observation_model=False) for p in prior_draws.to_dict('index').values()]
# x = [p for i,p in prior_draws.to_dict('index')]
# NOTE(review): sir() builds a DataFrame with string column labels, so `[0]`
# on x[0] looks up a column labelled 0 -- presumably `.iloc[0]` or a named
# column was intended; confirm.
x[0][0]

# prior_draws.to_dict('index')
# self.check_params(prior_draws.iloc[0])
# self.sir(self.generator(5).iloc[0])
# self.prior(5)
# self.calibration['const']
# for params in self.prior(5).to
# calibration_draws = pd.DataFrame([self.prior_fun() for k in range(100)])

In [None]:
self.calibration['const']

In [None]:
df = self.prior(5)
# [v for k,v in df.to_dict(orient='index').items()]
[x for i,x in df.iterrows()]
# f = lambda x: [x.values, x.values]
# df.apply(f, axis=1)

In [None]:
self.const_params

In [None]:
def prior_fun():
    """Draw one full parameter set from the prior, retrying until it validates.

    Every candidate is passed through ``self.check_params``; the first truthy
    result is returned. Random numbers come from the module-level ``RNG`` and
    are drawn in the key order of the dictionary below.
    """
    # Beta shape parameters derived from the constants 0.7 and 0.17.
    m, s = 0.7, 0.17
    alpha_f = (m**2) * ((1 - m) / (s**2) - (1 - m))
    beta_f = alpha_f * (1 / m - 1)

    def log_normal(median, sigma):
        return RNG.lognormal(mean=np.log(median), sigma=sigma)

    def uniform(lo, hi):
        return RNG.uniform(low=lo, high=hi)

    def normal(loc, scale):
        return RNG.normal(loc=loc, scale=scale)

    def reporting_fraction():
        return RNG.beta(a=alpha_f, b=beta_f)

    def phase():
        return RNG.vonmises(mu=0, kappa=0.01)

    def lag():
        return log_normal(8, 0.2)

    def noise_scale():
        return RNG.gamma(shape=1, scale=5)

    accepted = None
    while not accepted:
        candidate = {
            'N'       : self.tot_pop,
            'E_0'     : RNG.gamma(shape=2, scale=30),
            'alpha'   : uniform(0.005, 0.99),
            'beta'    : log_normal(0.25, 0.3),
            'gamma'   : log_normal(1/6.5, 0.5),
            'delta'   : uniform(0.01, 0.3),
            'epsilon' : uniform(1/14, 1/3),
            'eta'     : log_normal(1/3.2, 0.5),
            'lambda'  : log_normal(1.2, 0.5),
            'mu'      : log_normal(1/8, 0.2),
            'theta'   : uniform(1/14, 1/3),
            't_1'     : normal(8, 3),
            't_2'     : normal(15, 3),
            't_3'     : normal(22, 3),
            't_4'     : normal(66, 3),
            't_5'     : self.n_steps,
            'delta_1' : log_normal(3, 0.3),
            'delta_2' : log_normal(3, 0.3),
            'delta_3' : log_normal(3, 0.3),
            'delta_4' : log_normal(3, 0.3),
            'lambda_0': log_normal(1.20, 0.5),
            'lambda_1': log_normal(0.60, 0.5),
            'lambda_2': log_normal(0.30, 0.5),
            'lambda_3': log_normal(0.10, 0.5),
            'lambda_4': log_normal(0.15, 0.5),
            'f_I'     : reporting_fraction(),
            'f_R'     : reporting_fraction(),
            'f_D'     : reporting_fraction(),
            'phi_I'   : phase(),
            'phi_R'   : phase(),
            'phi_D'   : phase(),
            'lag_I'   : lag(),
            'lag_R'   : lag(),
            'lag_D'   : lag(),
            'sigma_I' : noise_scale(),
            'sigma_R' : noise_scale(),
            'sigma_D' : noise_scale(),
        }
        accepted = self.check_params(candidate)
    return accepted

In [None]:
# Inspection cell: display the model's `const_params` attribute again.
self.const_params

In [None]:
# Split calibration parameters into varying and (near-)constant ones based
# on their sample standard deviation. `calibration_draws` must already exist.
EPS = 1e-4
# One row per parameter, with columns 'mean' and 'std'.
stat = calibration_draws.agg(['mean','std']).T
# True for parameters whose spread exceeds the tolerance.
mask = stat['std'] > EPS
# Names of the parameters treated as constant.
c = stat[~mask].index.tolist()
# Bare expression: the result is discarded (not the last line of the cell).
stat[mask].T.to_dict(orient='index')
# Row 0 of the transposed table is 'mean', row 1 is 'std'.
mean, stds = stat[mask].T.values
mean

# mean = stat.loc[mask,'mean'].to_dict()
# mean
# mean = stat

In [None]:
# ode_params = ['alpha','beta','gamma','delta','epsilon','eta','lambda','mu','theta']
# self.lower_bound = np.array([EPS for key in self.param_names])
# self.upper_bound = np.array([1-EPS if key in ['alpha','delta','psi_I','psi_R','psi_D'] else np.inf for key in self.param_names])
# self.classes = self.sir(prior_draws[0]).columns.tolist()

In [None]:
def calc_lambda_array(p):
    """Compute the array of time-varying contact rates / transmission probabilities.

    The rate starts at ``lambda_0`` and moves through ``lambda_1`` ...
    ``lambda_4`` at the change points ``t_1`` ... ``t_4``. Each transition is
    a linear ramp of ``delta_k`` steps from the previous level to the new one,
    followed by a plateau at the new level until the next change point
    (``t_5`` closes the last segment).

    Parameters
    ----------
    p : dict
        Parameter dictionary; must pass ``check_params`` and contain the
        integer-valued keys ``t_1``..``t_5``, ``delta_1``..``delta_4``,
        ``sim_lag`` and the levels ``lambda_0``..``lambda_4``.

    Returns
    -------
    np.ndarray
        Concatenation of the initial plateau and the four ramp/plateau
        segments.
    """
    assert check_params(p)

    def ramp_then_hold(start, end, ramp, length):
        # A one-step "ramp" is an immediate jump to the new level.
        if ramp == 1:
            return np.array([end] * length)
        return np.append(np.linspace(start, end, ramp), [end] * (length - ramp))

    segments = [np.array([p['lambda_0']] * (p['t_1'] + p['sim_lag']))]
    for k in range(1, 5):
        # Fix: every ramp runs from the previous level to the current one.
        # The third segment used to interpolate lambda_3 -> lambda_4.
        segments.append(
            ramp_then_hold(
                p[f'lambda_{k - 1}'],
                p[f'lambda_{k}'],
                p[f'delta_{k}'],
                p[f't_{k + 1}'] - p[f't_{k}'],
            )
        )
    return np.r_[tuple(segments)]


def non_stationary_SEICR(prior_draw, N, T, sim_diff=16, observation_model=True):
    """Run one forward simulation of the non-stationary SEICR model.

    Parameters
    ----------
    prior_draw : dict or sequence
        Parameter values, either keyed by name or ordered like the
        module-level ``param_names``. They are merged on top of the
        module-level ``const_params``.
    N : number
        Population size; also the upper clipping bound of all series.
    T : int
        Accepted for interface compatibility. The number of observed steps is
        taken from ``p['t_5']``.
    sim_diff : int, default 16
        Number of warm-up steps plus one; must exceed every reporting lag.
    observation_model : bool, default True
        If True, return only the ``*_obs`` columns.

    Returns
    -------
    pd.DataFrame
        ``t_5`` rows holding the compartments S, E, C, I, R, D (last ``t_5``
        simulated states) and the observed series ``dI_obs``, ``dR_obs``,
        ``dD_obs``, clipped to ``[0, N]`` and divided by the module-level
        ``SCALE``.
    """
    try:
        p = const_params | prior_draw  # prior_draw passed as dict
    except TypeError:
        # prior_draw passed as a plain sequence of values
        p = const_params | dict(zip(param_names, prior_draw))
    assert check_params(p)
    assert sim_diff > max(p['lag_I'], p['lag_R'], p['lag_D'])
    p['sim_lag'] = sim_diff - 1
    lambd_arr = calc_lambda_array(p)

    n_obs = p['t_5']
    sim_lag = p['sim_lag']

    # Initial conditions
    S, E, C, I, R, D = [N - p['E_0']], [p['E_0']], [0], [0], [0], [0]
    I_news, R_news, D_news = [], [], []

    for t in range(n_obs + sim_lag):
        # Flows out of each compartment during this step
        E_new = lambd_arr[t] * ((C[t] + p['beta'] * I[t]) / N) * S[t]
        incubated = p['gamma'] * E[t]
        detected = (1 - p['alpha']) * p['eta'] * C[t]
        undetected = p['alpha'] * p['theta'] * C[t]
        recovered = (1 - p['delta']) * p['mu'] * I[t]
        died = p['delta'] * p['epsilon'] * I[t]

        # Update compartments, keeping them inside [0, N]
        S.append(np.clip(S[t] - E_new, 0, N))
        E.append(np.clip(E[t] + E_new - incubated, 0, N))
        C.append(np.clip(C[t] + incubated - detected - undetected, 0, N))
        I.append(np.clip(I[t] + detected - recovered - died, 0, N))
        R.append(np.clip(R[t] + undetected + recovered, 0, N))
        D.append(np.clip(D[t] + died, 0, N))

        # Only detected cases contribute to the reported series
        I_news.append(detected)
        R_news.append(recovered)
        D_news.append(died)

    day = np.arange(n_obs)

    def reported(news, lag):
        # Observation k shows the flow of simulation step k + sim_lag - lag.
        first = sim_lag - lag
        return np.asarray(news, dtype=float)[first:first + n_obs]

    def weekly_damping(f, phi):
        return (1 - f) * (1 - np.abs(np.sin((np.pi / 7) * day - 0.5 * phi)))

    def add_noise(series, sigma):
        # One Student-t draw per time step, matching the per-element
        # `stats.t(df=4, loc=..., scale=...).rvs()` formulation this replaces.
        # A single scalar draw would shift the whole series coherently.
        return series + RNG.standard_t(4, size=series.shape) * np.sqrt(series) * sigma

    # Lagged reporting, weekly modulation, then observation noise
    I_data = (1 - weekly_damping(p['f_I'], p['phi_I'])) * reported(I_news, p['lag_I'])
    R_data = (1 - weekly_damping(p['f_R'], p['phi_R'])) * reported(R_news, p['lag_R'])
    D_data = (1 - weekly_damping(p['f_D'], p['phi_D'])) * reported(D_news, p['lag_D'])
    I_data = add_noise(I_data, p['sigma_I'])
    R_data = add_noise(R_data, p['sigma_R'])
    D_data = add_noise(D_data, p['sigma_D'])

    # The state lists contain the warm-up steps as well; keep only the last
    # n_obs entries so that every column has the same length.
    Y = pd.DataFrame({
        'S': S[-n_obs:], 'E': E[-n_obs:], 'C': C[-n_obs:],
        'I': I[-n_obs:], 'R': R[-n_obs:], 'D': D[-n_obs:],
        'dI_obs': I_data, 'dR_obs': R_data, 'dD_obs': D_data,
    }).clip(0, N) / SCALE
    if observation_model:
        Y = Y.filter(like='obs')
    return Y
# Smoke test: simulate one prior draw. `real_data` must already exist
# (it is assigned from load_data() in a later cell).
non_stationary_SEICR(prior_fun(), real_data['N'], real_data['T'])
# prior_fun()

In [None]:
def load_data():
    """Download the Johns Hopkins time series for Germany and prepare daily counts.

    Returns a dict with the daily new confirmed / recovered / death counts
    (``x``, divided by ``SCALE`` and clipped to ``[0, N]``), the window length
    ``T``, the scaled population ``N`` and the column-wise ``mean`` and
    ``std`` of ``x``.

    Note: differencing drops the first row of the window, so ``x`` has fewer
    rows than the returned ``T``.
    """
    N = 83e6  / SCALE
    T = 100
    start = 39

    def fetch(cl):
        # Cumulative counts of one category, summed over provinces.
        table = pd.read_csv(f'https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_{cl}_global.csv', sep=",")
        table = table.drop(columns=['Province/State','Lat','Long'])
        return table.groupby('Country/Region').sum().loc['Germany']

    cumulative = pd.DataFrame({
        'dI_real': fetch('confirmed'),
        'dR_real': fetch('recovered'),
        'dD_real': fetch('deaths'),
    })
    window = cumulative.iloc[start:start+T]
    window = window.assign(date = lambda x: pd.to_datetime(x.index)).set_index('date')
    # Cumulative -> daily increments, then scale and bound.
    data = window.diff().dropna().div(SCALE).clip(0, N)
    return {'x': data, 'T': T, 'N': N, 'mean': data.mean(), 'std': data.std()}
real_data = load_data()

In [None]:
# non_stationary_SEICR(


In [None]:


def prior_sir():
    """Draw one sample of the 27 non-SECIR parameters from their priors.

    The returned list is ordered as: four change points, four ramp lengths,
    five contact rates, ``mu``, three (reporting fraction, weekly phase)
    pairs for I / R / D, three reporting delays, ``E0`` and three noise
    scales. Draws are taken from NumPy's global random state in exactly that
    order; ``alpha_f`` and ``beta_f`` are read from module scope.
    """

    def log_normal(median, sigma):
        return np.random.lognormal(mean=np.log(median), sigma=sigma)

    change_points = [
        np.random.normal(loc=loc, scale=scale)
        for loc, scale in ((8, 3), (15, 1), (22, 1), (66, 1))
    ]
    ramp_lengths = [log_normal(3, 0.3) for _ in range(4)]
    contact_rates = [log_normal(level, 0.5) for level in (1.2, 0.6, 0.3, 0.1, 0.1)]
    mu = log_normal(1 / 8, 0.2)

    # (f, phi) pairs for infected, recovered and dead, interleaved.
    reporting = []
    for _ in range(3):
        reporting.append(np.random.beta(a=alpha_f, b=beta_f))
        reporting.append(stats.vonmises(kappa=0.01).rvs())

    delays = [log_normal(8, 0.2) for _ in range(3)]
    E0 = np.random.gamma(shape=2, scale=30)
    noise_scales = [np.random.gamma(shape=1, scale=5) for _ in range(3)]

    return [
        *change_points,
        *ramp_lengths,
        *contact_rates,
        mu,
        *reporting,
        *delays,
        E0,
        *noise_scales,
    ]


def prior_secir():
    """Draw one sample of the seven SECIR rate parameters from their priors.

    Returns ``[alpha, beta, gamma, eta, theta, delta, d]``, drawn in that
    order from NumPy's global random state.
    """

    def uniform(lo, hi):
        return np.random.uniform(low=lo, high=hi)

    def log_normal(median, sigma):
        return np.random.lognormal(mean=np.log(median), sigma=sigma)

    return [
        uniform(0.005, 0.9),       # alpha
        log_normal(0.25, 0.3),     # beta
        log_normal(1 / 6.5, 0.5),  # gamma
        log_normal(1 / 3.2, 0.3),  # eta
        uniform(1 / 14, 1 / 3),    # theta
        uniform(0.01, 0.3),        # delta
        uniform(1 / 14, 1 / 3),    # d
    ]


def calc_lambda_array(
    sim_lag,
    lambd0,
    lambd1,
    lambd2,
    lambd3,
    lambd4,
    t1,
    t2,
    t3,
    t4,
    delta_t1,
    delta_t2,
    delta_t3,
    delta_t4,
    T,
):
    """Compute the array of time-varying contact rates / transmission probabilities.

    The rate stays at ``lambd0`` for ``t1 + sim_lag`` steps. At each change
    point ``t_k`` it ramps linearly over ``delta_t_k`` steps from the previous
    level to ``lambd_k`` and then holds that level until the next change point
    (``T`` closes the last segment).

    Parameters
    ----------
    sim_lag : int
        Number of warm-up steps prepended to the first plateau.
    lambd0 ... lambd4 : float
        Contact-rate levels.
    t1 ... t4 : int
        Change points.
    delta_t1 ... delta_t4 : int
        Ramp lengths; a value of 1 means an immediate jump.
    T : int
        End of the last segment.

    Returns
    -------
    np.ndarray
        Concatenated rate array. Its length is ``T + sim_lag`` when every
        ramp fits inside its segment.
    """

    def ramp_then_hold(start, end, ramp, length):
        # A one-step "ramp" is an immediate jump to the new level.
        if ramp == 1:
            return np.array([end] * length)
        return np.append(np.linspace(start, end, ramp), [end] * (length - ramp))

    lambd0_arr = np.array([lambd0] * (t1 + sim_lag))
    lambd1_arr = ramp_then_hold(lambd0, lambd1, delta_t1, t2 - t1)
    lambd2_arr = ramp_then_hold(lambd1, lambd2, delta_t2, t3 - t2)
    # Fix: this ramp must go lambd2 -> lambd3; it used to interpolate
    # lambd3 -> lambd4, which made the rate jump at t3.
    lambd3_arr = ramp_then_hold(lambd2, lambd3, delta_t3, t4 - t3)
    lambd4_arr = ramp_then_hold(lambd3, lambd4, delta_t4, T - t4)

    return np.r_[lambd0_arr, lambd1_arr, lambd2_arr, lambd3_arr, lambd4_arr]


def non_stationary_SEICR(
    params_sir, params_secir, N, T, sim_diff=16, observation_model=True
):
    """Run one forward simulation of the non-stationary SEICR model.

    Parameters
    ----------
    params_sir : sequence of 27 values
        Change points, ramp lengths, contact rates, ``mu``, reporting
        fraction / phase pairs, reporting delays, ``E0`` and noise scales, in
        the order unpacked below.
    params_secir : sequence of 7 values
        ``alpha, beta, gamma, eta, theta, delta, d``.
    N : number
        Population size; also the upper clipping bound of all series.
    T : int
        Number of observed time steps.
    sim_diff : int, default 16
        Number of warm-up steps plus one; must exceed every reporting delay.
    observation_model : bool, default True
        If True, return only the ``*_obs`` columns.

    Returns
    -------
    pd.DataFrame
        ``T`` rows holding the compartments S, E, C, I, R, D (last ``T``
        simulated states) and the observed series ``dI_obs``, ``dR_obs``,
        ``dD_obs``, clipped to ``[0, N]`` and divided by the module-level
        ``SCALE``.

    Raises
    ------
    AssertionError
        If the (rounded) parameters violate the ordering / positivity
        constraints checked below.
    """

    # Extract parameters
    (
        t1,
        t2,
        t3,
        t4,
        delta_t1,
        delta_t2,
        delta_t3,
        delta_t4,
        lambd0,
        lambd1,
        lambd2,
        lambd3,
        lambd4,
        mu,
        f_i,
        phi_i,
        f_r,
        phi_r,
        f_d,
        phi_d,
        delay_i,
        delay_r,
        delay_d,
        E0,
        scale_I,
        scale_R,
        scale_D,
    ) = params_sir
    alpha, beta, gamma, eta, theta, delta, d = params_secir

    # Round integer parameters
    t1, t2, t3, t4 = (int(round(v)) for v in (t1, t2, t3, t4))
    delta_t1, delta_t2, delta_t3, delta_t4 = (
        int(round(v)) for v in (delta_t1, delta_t2, delta_t3, delta_t4)
    )
    E0 = max(1, np.round(E0))
    delay_i, delay_r, delay_d = (int(round(v)) for v in (delay_i, delay_r, delay_d))

    # Impose constraints
    assert sim_diff > delay_i
    assert sim_diff > delay_r
    assert sim_diff > delay_d
    assert t1 > 0 and t2 > 0 and t3 > 0 and t4 > 0
    assert t1 < t2 < t3 < t4
    assert delta_t1 > 0 and delta_t2 > 0 and delta_t3 > 0 and delta_t4 > 0
    assert (
        t2 - t1 >= delta_t1
        and t3 - t2 >= delta_t2
        and t4 - t3 >= delta_t3
        and T - t4 >= delta_t4
    )

    # Lambda0 is the initial contact rate, which the later levels replace
    # step by step at the change points.
    sim_lag = sim_diff - 1
    lambd_arr = calc_lambda_array(
        sim_lag,
        lambd0,
        lambd1,
        lambd2,
        lambd3,
        lambd4,
        t1,
        t2,
        t3,
        t4,
        delta_t1,
        delta_t2,
        delta_t3,
        delta_t4,
        T,
    )

    # Initial conditions
    S, E, C, I, R, D = [N - E0], [E0], [0], [0], [0], [0]
    I_news, R_news, D_news = [], [], []

    for t in range(T + sim_lag):
        # Flows out of each compartment during this step
        E_new = lambd_arr[t] * ((C[t] + beta * I[t]) / N) * S[t]
        incubated = gamma * E[t]
        detected = (1 - alpha) * eta * C[t]
        undetected = alpha * theta * C[t]
        recovered = (1 - delta) * mu * I[t]
        died = delta * d * I[t]

        # Update compartments, keeping them inside [0, N]
        S.append(np.clip(S[t] - E_new, 0, N))
        E.append(np.clip(E[t] + E_new - incubated, 0, N))
        C.append(np.clip(C[t] + incubated - detected - undetected, 0, N))
        I.append(np.clip(I[t] + detected - recovered - died, 0, N))
        R.append(np.clip(R[t] + undetected + recovered, 0, N))
        D.append(np.clip(D[t] + died, 0, N))

        # Only detected cases contribute to the reported series
        I_news.append(detected)
        R_news.append(recovered)
        D_news.append(died)

    day = np.arange(T)

    def reported(news, delay):
        # Observation k shows the flow of simulation step k + sim_lag - delay.
        first = sim_lag - delay
        return np.asarray(news, dtype=float)[first:first + T]

    def weekly_damping(f, phi):
        return (1 - f) * (1 - np.abs(np.sin((np.pi / 7) * day - 0.5 * phi)))

    # Lagged reporting with weekly modulation
    I_data = (1 - weekly_damping(f_i, phi_i)) * reported(I_news, delay_i)
    R_data = (1 - weekly_damping(f_r, phi_r)) * reported(R_news, delay_r)
    D_data = (1 - weekly_damping(f_d, phi_d)) * reported(D_news, delay_d)

    # Add noise
    I_data = stats.t(df=4, loc=I_data, scale=np.sqrt(I_data) * scale_I).rvs()
    R_data = stats.t(df=4, loc=R_data, scale=np.sqrt(R_data) * scale_R).rvs()
    D_data = stats.t(df=4, loc=D_data, scale=np.sqrt(D_data) * scale_D).rvs()

    # Fix: the state lists hold T + sim_lag + 1 entries while the observed
    # series hold T, so building the frame from the full lists always raised
    # a length-mismatch ValueError. Keep the last T states, as COVID.sir does.
    Y = pd.DataFrame({
        'S': S[-T:], 'E': E[-T:], 'C': C[-T:],
        'I': I[-T:], 'R': R[-T:], 'D': D[-T:],
        'dI_obs': I_data, 'dR_obs': R_data, 'dD_obs': D_data,
    }).clip(0, N) / SCALE
    if observation_model:
        Y = Y.filter(like='obs')
    return Y

def simulate(*, N, T, sim_diff=21, observation_model=True, n_sim=1, n_params=1, params=None):
    """Simulate ``n_sim`` data sets for a single parameter vector.

    Parameters
    ----------
    N : number
        Population size passed to the simulator.
    T : int
        Number of observed time steps.
    sim_diff : int, default 21
        Warm-up length passed to the simulator.
    observation_model : bool, default True
        Passed to the simulator; if True only observed series are returned.
    n_sim : int, default 1
        Number of independent simulations to run.
    n_params : int, default 1
        Accepted for interface compatibility; currently unused.
    params : sequence, optional
        Concatenated parameter vector whose last seven entries are the SECIR
        parameters. When omitted, one vector is drawn from ``prior_sir`` and
        ``prior_secir``.
        NOTE(review): the original ``params is None`` branch was left empty;
        drawing from the priors is the assumed intent -- confirm.

    Returns
    -------
    np.ndarray
        Stack of the ``n_sim`` simulator outputs.
    """
    # The previous cell held two unfinished definitions (an invalid signature
    # with a duplicated argument and an empty `if` body); this is the single
    # consolidated version using the keyword-only signature.
    if params is None:
        params = prior_sir() + prior_secir()

    theta1, theta2 = params[:-7], params[-7:]
    return np.array([
        non_stationary_SEICR(
            theta1,
            theta2,
            N=N,
            T=T,
            sim_diff=sim_diff,
            observation_model=observation_model,
        )
        for _ in range(n_sim)
    ])

def data_generator(batch_size, T=None, N=None, sim_diff=21, observation_model=True, seed=None, scale=1.0):
    """
    Runs the forward model 'batch_size' times by first sampling from the prior
    theta ~ p(theta) and then running x ~ p(x|theta).
    ----------

    Arguments:
    batch_size        : int -- the number of samples to draw from the prior
    T                 : int -- number of observed time steps
    N                 : number -- population size
    sim_diff          : int -- warm-up length passed to the simulator
    observation_model : bool -- passed to the simulator (was previously ignored)
    seed              : int or None -- if given, seeds NumPy's global random state
    scale             : float -- extra divisor applied to the simulated data.
        NOTE(review): the body used undefined names `seed` and `scale`; they
        are now keyword arguments. The simulator already divides by SCALE, so
        the default of 1.0 applies no further scaling -- confirm.
    ----------

    Output:
    forward_dict : dict
        The expected outputs for a BayesFlow pipeline
    """

    if seed is not None:
        np.random.seed(seed)

    # Generate data
    # x is a np.ndarray of shape (batch_size, n_obs, x_dim)
    x = []
    theta = []
    for _ in range(batch_size):
        # Reject parameter draws that violate the simulator's constraints.
        # Only AssertionError is treated as a rejection, so genuine bugs and
        # KeyboardInterrupt are no longer swallowed by the retry loop.
        while True:
            theta1 = prior_sir()
            theta2 = prior_secir()
            try:
                x_i = non_stationary_SEICR(
                    theta1,
                    theta2,
                    N,
                    T,
                    sim_diff=sim_diff,
                    observation_model=observation_model,
                )
            except AssertionError:
                continue
            break
        x.append(x_i)
        theta.append(theta1 + theta2)

    # Clip negative and normalize
    x = np.clip(np.array(x), 0.0, np.inf) / scale
    theta = np.array(theta)

    forward_dict = {"prior_draws": theta, "sim_data": x}
    return forward_dict

class MultiConvLayer(tf.keras.Model):
    """Inception-inspired block of parallel causal 1D convolutions.

    Six convolutions with kernel sizes 2..7 (each with ``n_filters // 2``
    filters) are applied to the same input, concatenated along the feature
    axis and reduced to ``n_filters`` channels by a kernel-size-1 convolution.
    """

    def __init__(self, n_filters=32, strides=1):
        super().__init__()

        branch_filters = n_filters // 2
        branches = []
        for width in range(2, 8):
            branches.append(
                tf.keras.layers.Conv1D(
                    branch_filters,
                    kernel_size=width,
                    strides=strides,
                    padding="causal",
                    activation="relu",
                    kernel_initializer="glorot_uniform",
                )
            )
        self.convs = branches
        self.dim_red = tf.keras.layers.Conv1D(
            n_filters,
            kernel_size=1,
            strides=1,
            activation="relu",
            kernel_initializer="glorot_uniform",
        )

    def call(self, x):
        """Apply all branches to the time series ``x`` and merge them.

        ``x`` is documented by the authors as (batch, timestamps, n_features).
        """

        branch_outputs = [conv(x) for conv in self.convs]
        merged = tf.concat(branch_outputs, axis=-1)
        return self.dim_red(merged)


class MultiConvNet(tf.keras.Model):
    """Stack of ``MultiConvLayer`` blocks followed by an LSTM.

    The LSTM has ``n_filters`` units and returns its final output only.
    """

    def __init__(self, n_layers=3, n_filters=64, strides=1):
        super().__init__()

        blocks = [MultiConvLayer(n_filters, strides) for _ in range(n_layers)]
        self.net = tf.keras.Sequential(blocks)
        self.lstm = LSTM(n_filters)

    def call(self, x, **args):
        """Encode the time series ``x``; extra keyword arguments are ignored.

        ``x`` is documented by the authors as (batch, timestamps, n_features).
        """

        features = self.net(x)
        return self.lstm(features)


class SummaryNet(tf.keras.Model):
    """Summary network with one ``MultiConvNet`` per observed series.

    The input is split into three equal parts along the last axis; each part
    is encoded by its own network with ``n_summary // 3`` units, and the
    three encodings are concatenated.
    """

    def __init__(self, n_summary):
        super().__init__()
        units_per_series = n_summary // 3
        self.net_I = MultiConvNet(n_filters=units_per_series)
        self.net_R = MultiConvNet(n_filters=units_per_series)
        self.net_D = MultiConvNet(n_filters=units_per_series)

    @tf.function
    def call(self, x, **args):
        """Encode ``x``; extra keyword arguments are ignored.

        ``x`` is documented by the authors as (batch, timestamps, n_features).
        """

        part_i, part_r, part_d = tf.split(x, 3, axis=-1)
        summaries = [self.net_I(part_i), self.net_R(part_r), self.net_D(part_d)]
        return tf.concat(summaries, axis=-1)

class MemoryNetwork(tf.keras.Model):
    """GRU that summarises the already-seen part of a sequence.

    ``meta`` must provide ``"n_hidden"`` (number of GRU units) and
    ``"n_params"`` (stored as an attribute, not used by the methods here).
    """

    def __init__(self, meta):
        super().__init__()

        hidden_units = meta["n_hidden"]
        self.gru = GRU(hidden_units, return_sequences=True, return_state=True)
        self.h = hidden_units
        self.n_params = meta["n_params"]

    @tf.function
    def call(self, target, condition):
        """Run the GRU over the whole sequence with teacher forcing.

        Params:
        -------
        target    : tf.Tensor of shape (batch_size, time_steps, dim)
            The time-dependent signal to process.
        condition : tf.Tensor
            The conditioning variables, e.g., parameters. It is concatenated
            with the shifted target along the last axis, so it needs the same
            leading (batch, time) dimensions as ``target``.

        Returns the GRU output sequence; the final state is discarded.
        """
        # Shift the target one step to the right and start with zeros, so
        # step t only sees targets from steps before t.
        start = tf.zeros((target.shape[0], 1, target.shape[2]))
        previous_targets = target[:, :-1, :]
        teacher_input = tf.concat([start, previous_targets], axis=1)
        gru_input = tf.concat([teacher_input, condition], axis=-1)
        sequence, _ = self.gru(gru_input)
        return sequence

    def step_loop(self, target, condition, state):
        """Advance the GRU from ``state`` on one chunk; return (output, new state)."""
        step_input = tf.concat([target, condition], axis=-1)
        out, new_state = self.gru(step_input, initial_state=state)
        return out, new_state

class InvertibleNetworkWithMemory(tf.keras.Model):
    """Chain of conditional coupling layers whose condition is augmented by a recurrent memory.

    In the forward pass a ``MemoryNetwork`` summarises the (shifted) targets
    and its output is concatenated to the condition. In the inverse pass the
    sequence is generated step by step, feeding each generated step back into
    the memory network.
    """

    def __init__(
        self,
        num_params,
        num_coupling_layers=4,
        coupling_settings=None,
        coupling_design="affine",
        permutation="fixed",
        use_act_norm=True,
        act_norm_init=None,
        use_soft_flow=False,
        soft_flow_bounds=(1e-3, 5e-2),
    ):
        """Initializes a custom invertible network with recurrent memory.

        ``num_params`` is used as the latent dimension of every coupling
        layer; the remaining coupling arguments are forwarded to
        ``CouplingLayer`` unchanged. ``use_soft_flow`` and
        ``soft_flow_bounds`` are only stored as attributes; the methods of
        this class do not read them.
        """

        super().__init__()

        # Create settings dict for coupling layer
        settings = dict(
            latent_dim=num_params,
            coupling_settings=coupling_settings,
            coupling_design=coupling_design,
            permutation=permutation,
            use_act_norm=use_act_norm,
            act_norm_init=act_norm_init,
        )

        # Create sequence of coupling layers and store reference to dimensionality
        self.coupling_layers = [
            CouplingLayer(**settings) for _ in range(num_coupling_layers)
        ]

        # Store attributes
        self.soft_flow = use_soft_flow
        self.soft_low = soft_flow_bounds[0]
        self.soft_high = soft_flow_bounds[1]
        self.use_act_norm = use_act_norm
        self.latent_dim = num_params
        # Memory network with hard-coded size: 256 GRU units, n_params=3.
        self.dynamic_summary_net = MemoryNetwork({"n_hidden": 256, "n_params": 3})
        # NOTE(review): redundant, latent_dim was already assigned above.
        self.latent_dim = num_params

    def call(self, targets, condition, inverse=False):
        """Performs one pass through the invertible chain (either inverse or forward).

        Parameters
        ----------
        targets   : tf.Tensor
            The estimation quantities of interest, shape (batch_size, ...)
        condition : tf.Tensor
            The conditional data x. It is indexed per time step in the
            inverse pass, so it is expected to carry a time axis.
        inverse   : bool, default: False
            Flag indicating whether to run the chain forward or backwards

        Returns
        -------
        (z, log_det_J)  :  tuple(tf.Tensor, tf.Tensor)
            If inverse=False: the transformed input and the summed log
            Jacobian determinant of the transformation.

        target          :  tf.Tensor
            If inverse=True: the transformed output.
        """

        if inverse:
            return self.inverse(targets, condition)
        return self.forward(targets, condition)

    @tf.function
    def forward(self, targets, condition, **kwargs):
        """Performs a forward pass through the chain.

        Extra keyword arguments are forwarded to every coupling layer.
        """

        # Add memory condition
        memory = self.dynamic_summary_net(targets, condition)
        condition = tf.concat([memory, condition], axis=-1)

        z = targets
        log_det_Js = []
        for layer in self.coupling_layers:
            z, log_det_J = layer(z, condition, **kwargs)
            log_det_Js.append(log_det_J)
        # Sum Jacobian determinants for all layers (coupling blocks) to obtain total Jacobian.
        log_det_J = tf.add_n(log_det_Js)
        return z, log_det_J

    @tf.function
    def inverse(self, z, condition, **kwargs):
        """Performs a reverse pass through the chain, one time step at a time.

        Each step inverts the coupling layers in reverse order, conditioned on
        the GRU memory of the previously generated steps.
        """

        target = z
        T = z.shape[1]
        # The first GRU input is all zeros, mirroring the zero start token
        # used by the memory network in the forward pass.
        gru_inp = tf.zeros((z.shape[0], 1, z.shape[-1]))
        state = tf.zeros((z.shape[0], self.dynamic_summary_net.h))
        outs = []
        for t in range(T):
            # One step condition
            memory, state = self.dynamic_summary_net.step_loop(
                gru_inp, condition[:, t : t + 1, :], state
            )
            condition_t = tf.concat([memory, condition[:, t : t + 1, :]], axis=-1)
            target_t = target[:, t : t + 1, :]
            for layer in reversed(self.coupling_layers):
                target_t = layer(target_t, condition_t, inverse=True, **kwargs)
            outs.append(target_t)
            # Feed the step just generated back in as the next GRU input.
            gru_inp = target_t
        return tf.concat(outs, axis=1)

# def configurator(forward_dict):
#     """Customized preprocessing for the Covid simulator."""

#     out = {"posterior_inputs": {}, "likelihood_inputs": {}}

#     # Extract data
#     x = forward_dict["sim_data"].astype(np.float32)
#     x_means = np.mean(x, axis=1, keepdims=True)
#     x_std = np.std(x, axis=1, keepdims=True)
#     x = (x - x_means) / x_std
#     log_mu = np.log2(1 + x_means[:, 0, :])
#     log_std = np.log2(1 + x_std[:, 0, :])

#     # Extract params
#     p = forward_dict["prior_draws"].astype(np.float32)
#     p = (p - theta_mu) / theta_std

#     # Repeat condition
#     cond = np.concatenate([p, log_mu, log_std], axis=-1)
#     cond = np.stack([cond] * x.shape[1], axis=1)

#     # Likelihood inputs
#     out["likelihood_inputs"]["observables"] = x.astype(np.float32)
#     out["likelihood_inputs"]["conditions"] = np.concatenate([cond], axis=-1).astype(
#         np.float32
#     )

#     # Posterior inputs
#     out["posterior_inputs"]["parameters"] = p
#     out["posterior_inputs"]["summary_conditions"] = out["likelihood_inputs"][
#         "observables"
#     ]
#     out["posterior_inputs"]["direct_conditions"] = np.concatenate(
#         [log_mu, log_std], axis=-1
#     )

#     return out


def configurator(forward_dict):
    """Turn raw simulator output into the input dicts for the amortizers.

    Every simulated series is standardized along axis 1 and the prior draws
    are standardized with the module-level ``theta_mu`` / ``theta_std``.

    Parameters
    ----------
    forward_dict : dict
        Must provide the arrays ``"sim_data"`` (indexed below as a 3-D array,
        axis 1 being the one that is averaged over) and ``"prior_draws"``.

    Returns
    -------
    dict
        The *same* ``forward_dict`` object, updated in place with the extra
        keys ``"likelihood_inputs"`` and ``"posterior_inputs"``.
    """

    # Standardize each series; its mean/std are kept as extra conditions
    x = forward_dict["sim_data"].astype(np.float32)
    x_means = np.mean(x, axis=1, keepdims=True)
    x_std = np.std(x, axis=1, keepdims=True)
    # NOTE(review): a constant series has zero std and produces inf/NaN here.
    x = (x - x_means) / x_std
    log_mu = np.log2(1 + x_means[:, 0, :])
    log_std = np.log2(1 + x_std[:, 0, :])

    # Standardize parameters with the precomputed prior moments
    p = forward_dict["prior_draws"].astype(np.float32)
    p = (p - theta_mu) / theta_std

    # Static condition (parameters + series statistics), repeated per step
    cond = np.concatenate([p, log_mu, log_std], axis=-1)
    cond = np.stack([cond] * x.shape[1], axis=1)

    # Likelihood network: standardized series given the repeated condition
    forward_dict['likelihood_inputs'] = {
        'observables': x.astype(np.float32),
        'conditions': np.float32(np.concatenate([cond], axis=-1)),
    }

    # Posterior network: parameters given summarized series + statistics
    forward_dict['posterior_inputs'] = {
        'parameters': p,
        'summary_conditions': forward_dict['likelihood_inputs']['observables'],
        'direct_conditions': np.concatenate([log_mu, log_std], axis=-1),
    }
    return forward_dict

# param_names = [
#     r"$t_1$",
#     r"$t_2$",
#     r"$t_3$",
#     r"$t_4$",
#     r"$\Delta t_1$",
#     r"$\Delta t_2$",
#     r"$\Delta t_3$",
#     r"$\Delta t_4$",
#     r"$\lambda_0$",
#     r"$\lambda_1$",
#     r"$\lambda_2$",
#     r"$\lambda_3$",
#     r"$\lambda_4$",
#     r"$\mu$",
#     r"$A_I$",
#     r"$\phi_I$",
#     r"$A_R$",
#     r"$\phi_R$",
#     r"$A_D$",
#     r"$\phi_D$",
#     r"$L_I$",
#     r"$L_R$",
#     r"$L_D$",
#     r"$E_0$",
#     r"$\sigma_I$",
#     r"$\sigma_R$",
#     r"$\sigma_D$",
#     r"$\alpha$",
#     r"$\beta$",
#     r"$\gamma$",
#     r"$\eta$",
#     r"$\theta$",
#     r"$\delta$",
#     r"$d$",
# ]

# Output locations below the mounted Drive folder (`mnt` is set where the
# drive is mounted at the top of the notebook).
name = 'radev_model'
root_path = pathlib.Path(mnt + f'/MyDrive/bayesian disease modeling/{name}')
# Directory handed to the Trainer as `checkpoint_path`
model_path = root_path / f'model/'
# Create the checkpoint directory (and missing parents); no error if present
model_path.mkdir(exist_ok=True, parents=True)
# Observed data dict; its "N" and "T" entries parameterize the simulator
real_data = load_data()

In [None]:
# Estimate per-parameter prior mean/std from 5000 independent prior draws.
# NOTE(review): no seed is set here, so the estimates vary between runs.
theta1_s = np.array([prior_sir() for _ in range(5000)])
theta2_s = np.array([prior_secir() for _ in range(5000)])
theta1_mu = np.mean(theta1_s, axis=0, keepdims=True)
theta2_mu = np.mean(theta2_s, axis=0, keepdims=True)
theta1_std = np.std(theta1_s, axis=0, keepdims=True)
theta2_std = np.std(theta2_s, axis=0, keepdims=True)
# Column-wise concatenation: first block's moments, then the second block's.
# `configurator` reads these two globals to standardize the prior draws.
theta_mu = np.c_[theta1_mu, theta2_mu]
theta_std = np.c_[theta1_std, theta2_std]

In [None]:
# Simulator with population size and series length fixed to the observed data
generative_model = partial(data_generator, N=real_data["N"], T=real_data["T"])
# Settings shared by the coupling layers of both invertible networks
coupling_settings = {
    "dense_args": dict(units=128, activation="swish", kernel_regularizer=None),
    "num_dense": 2,
    "dropout": False,
}
# Surrogate likelihood network (num_params=3 presumably matches the three
# simulated series — TODO confirm)
likelihood_net = InvertibleNetworkWithMemory(
    num_params=3, num_coupling_layers=8, coupling_settings=coupling_settings
)
# Posterior network over all named model parameters
posterior_net = InvertibleNetwork(
    num_params=len(param_names),
    num_coupling_layers=6,
    coupling_settings=coupling_settings,
)
summary_net = SummaryNet(n_summary=192)
amortized_posterior = AmortizedPosterior(
    posterior_net, summary_net, summary_loss_fun="MMD"
)
amortized_likelihood = AmortizedLikelihood(likelihood_net)
# Joint amortizer wrapping the posterior and the likelihood amortizers
joint_amortizer = AmortizedPosteriorLikelihood(
    amortized_posterior, amortized_likelihood
)
model = Trainer(
    amortizer=joint_amortizer,
    generative_model=generative_model,
    configurator=configurator,
    checkpoint_path=model_path,
    memory=False,
    max_to_keep=1,
)
# Online training: this call is ACTIVE. Comment it out to skip training
# (e.g. when only working with an existing checkpoint).
h = model.train_online(epochs=100, iterations_per_epoch=32*31, batch_size=32, validation_sims=150)

In [None]:
# NOTE(review): this repeats the training call that already ends the previous
# cell; running both cells trains for 2 x 100 epochs. Comment out to skip.
h = model.train_online(epochs=100, iterations_per_epoch=32*31, batch_size=32, validation_sims=150)

# Validation

## Loss Trajectories

In [None]:
# Take the history from the trainer instead of the `h` returned by
# train_online, because `h` only exists in sessions where training was run.
h = model.loss_history.get_plottable()
f = plot_losses(h["train_losses"], h["val_losses"])
# Store the loss curves next to the model folder
f.savefig(root_path / "loss_history.pdf", dpi=300);

In [None]:
# Small end-to-end check: simulate, configure, then sample from the amortizer
n_sims = 3
n_samples = 2
samples = model.generative_model(n_sims)
samples = model.configurator(samples)
# Keep the amortizer output next to the configured inputs
samples['parameters_out'] = model.amortizer.sample(samples, n_samples=n_samples)
# Show which entries are available (displayed as the cell output)
samples.keys()

# Paper Plots

In [None]:
import matplotlib.ticker as ticker


def publication_plot(to_plot, sim_out, real_data):
    """Plot surrogate (top row) against simulator (bottom row) trajectories.

    For each of the three series (columns) the median over axis 0 is drawn
    as a dotted line together with shaded 50% and 95% quantile bands.

    Parameters
    ----------
    to_plot : array
        Surrogate trajectories; indexed as (sample, time, series).
    sim_out : array
        Simulator trajectories with the same layout.
    real_data : dict
        Only ``real_data["T"]`` is read, to build the time axis 1..T.

    Returns
    -------
    The created matplotlib figure.
    """

    palette = ["#000080", "#008000", "#800000"]
    panel_names = ["Infected", "Recovered", "Dead"]
    fig, axes = plt.subplots(2, 3, figsize=(15, 9))
    days = np.arange(1, real_data["T"] + 1)

    def _summaries(samples):
        # Median plus 95% and 50% quantile bands across trajectories
        return (
            np.median(samples, axis=0),
            np.quantile(samples, axis=0, q=(0.025, 0.975)),
            np.quantile(samples, axis=0, q=(0.25, 0.75)),
        )

    def _draw_row(row_axes, samples, bottom_row):
        center, band_95, band_50 = _summaries(samples)
        # Only the bottom row carries legend labels and an x-axis label
        if bottom_row:
            line_kw = {"label": "Median"}
            inner_kw = {"label": "50%-CI"}
            outer_kw = {"label": "95%-CI"}
        else:
            line_kw, inner_kw, outer_kw = {}, {}, {}

        for col, ax in enumerate(row_axes):
            shade = palette[col]
            ax.plot(
                days,
                center[:, col],
                color=shade,
                lw=4,
                linestyle="dotted",
                alpha=0.9,
                **line_kw,
            )
            ax.fill_between(
                days,
                band_50[0, :, col],
                band_50[1, :, col],
                color=shade,
                alpha=0.5,
                **inner_kw,
            )
            ax.fill_between(
                days,
                band_95[0, :, col],
                band_95[1, :, col],
                color=shade,
                alpha=0.3,
                **outer_kw,
            )

            # Shared cosmetics
            sns.despine(ax=ax)
            ax.grid(alpha=0.25)
            if bottom_row:
                ax.set_xlabel("Time (days)", fontsize=25)
            ax.set_ylabel(f"# {panel_names[col]}", fontsize=25)
            ax.tick_params(axis="both", which="major", labelsize=17)
            # Tick labels get a "k" suffix (callers pass counts divided by 1000)
            ax.yaxis.set_major_formatter(
                ticker.FuncFormatter(lambda x, pos: "{:,.0f}".format(x) + "k")
            )

    _draw_row(axes[0], to_plot, bottom_row=False)  # surrogate outputs
    _draw_row(axes[1], sim_out, bottom_row=True)  # simulator outputs

    # Rotated row captions placed left of the first column
    for row, caption in enumerate(("Surrogate", "Simulator")):
        axes[row, 0].text(
            -0.42,
            0.5,
            caption,
            horizontalalignment="left",
            verticalalignment="center",
            rotation=90,
            fontsize=30,
            transform=axes[row, 0].transAxes,
        )

    fig.tight_layout()
    return fig

### Teaser Figure

In [None]:
# Number of simulator / surrogate trajectories to generate
n_trajectories = 1000

# Simulate data and format for likelihood network

# Draw ONE parameter vector from the prior and repeat it n_trajectories times
# (active path; comment these two lines out to use stored parameters instead)
out = generative_model(1)
pars = np.array([out['prior_draws'][0].astype(np.float32)] * n_trajectories)

# Alternative: load the stored parameters of the paper figure. According to
# the original note they are already repeated n_trajectories = 1000 times
# over axis 0.
# pars = np.load("assets/parameters.npy")

# Generate simulations from the true simulator given parameters
# (divided by 1000, the same factor as data_generator's default `scale`)
sim_out = (
    simulate_given_params(n_trajectories, pars[0], real_data["N"], real_data["T"])
    / 1000
)
net_in = configurator({"sim_data": sim_out, "prior_draws": pars})
# Per-trajectory mean/std over axis 1, used later to undo the standardization
means_out = np.mean(sim_out, axis=1)
stds_out = np.std(sim_out, axis=1)

# Generate surrogate simulations given parameters
net_out = amortized_likelihood.sample(net_in["likelihood_inputs"], real_data["T"])

In [None]:
# Create summary representations of simulator and surrogate trajectories
rep_sim = summary_net(net_in["posterior_inputs"]["summary_conditions"])
rep_sur = summary_net(net_out)

# MMD of all surrogate summaries vs. the simulator summaries, and the MMD of
# each single surrogate trajectory vs. the simulator summaries
total_mmd = maximum_mean_discrepancy(rep_sim, rep_sur)
mmds_all = [
    maximum_mean_discrepancy(rep_sim, rep_sur[i : (i + 1), :]).numpy()
    for i in range(n_trajectories)
]

# Remove the 1% with the highest MMD - posterior net criticizes likelihood net.
# The cut-off follows n_trajectories (990 for the default of 1000) instead of
# being hard-coded, so changing n_trajectories keeps the 1% rule intact.
n_keep = n_trajectories - n_trajectories // 100
idx_good = np.argsort(mmds_all)[:n_keep]

# Undo the per-trajectory standardization for the retained trajectories
to_plot = (
    net_out[idx_good] * stds_out[idx_good, np.newaxis, :]
    + means_out[idx_good, np.newaxis, :]
)

In [None]:
# Surrogate vs. simulator comparison for the teaser figure
f = publication_plot(to_plot, sim_out, real_data)
# Optional export of the teaser figure (disabled)
# f.savefig('figures/covid_teaser_nolegend.pdf', dpi=300)

### Appendix Plots

In [None]:
n_sims = 10
n_trajectories = 1000
# Keep 99% of the surrogate trajectories; derived from n_trajectories
# (990 for 1000) so the 1% rule survives a change of n_trajectories.
n_keep = n_trajectories - n_trajectories // 100

for sim in range(n_sims):
    # Draw one parameter vector from the prior and repeat it
    out = generative_model(1)
    pars = np.array([out["prior_draws"][0].astype(np.float32)] * n_trajectories)

    # Generate simulations from the true simulator given parameters
    sim_out = (
        simulate_given_params(n_trajectories, pars[0], real_data["N"], real_data["T"])
        / 1000
    )
    net_in = configurator({"sim_data": sim_out, "prior_draws": pars})
    # Per-trajectory statistics used to undo the standardization below
    means_out = np.mean(sim_out, axis=1)
    stds_out = np.std(sim_out, axis=1)

    # Generate surrogate simulations given parameters
    net_out = amortized_likelihood.sample(net_in["likelihood_inputs"], real_data["T"])

    # Create summary representations and compute MMDs
    rep_sim = summary_net(net_in["posterior_inputs"]["summary_conditions"])
    rep_sur = summary_net(net_out)
    total_mmd = maximum_mean_discrepancy(rep_sim, rep_sur)
    mmds_all = [
        maximum_mean_discrepancy(rep_sim, rep_sur[i : (i + 1), :]).numpy()
        for i in range(n_trajectories)
    ]

    # Remove 1% with the highest MMD - posterior net criticizes likelihood net
    idx_good = np.argsort(mmds_all)[:n_keep]
    to_plot = (
        net_out[idx_good] * stds_out[idx_good, np.newaxis, :]
        + means_out[idx_good, np.newaxis, :]
    )

    f = publication_plot(to_plot, sim_out, real_data)
    # NOTE(review): relative path; the "figures" directory must already exist.
    f.savefig(f"figures/surrogate_{sim}.pdf", dpi=300, bbox_inches="tight")

# Calibration

In [None]:
# Calibration set: number of simulated data sets and of posterior draws
# requested per data set in the following cell
n_test_cal = 1000
n_posterior_samples = 100
gen_out = generative_model(n_test_cal)

In [None]:
# Configure simulator output
conf = configurator(gen_out)

# Obtain surrogate time series given prior draws
# (mean/std per data set over axis 1 are needed to undo the standardization)
means_out = np.mean(gen_out["sim_data"], axis=1)
stds_out = np.std(gen_out["sim_data"], axis=1)
x_sim_s_u = joint_amortizer.sample_data(
    conf["likelihood_inputs"], n_samples=real_data["T"]
)
# Map the surrogate output back to the scale of the simulated data
x_sim_s = x_sim_s_u * stds_out[:, np.newaxis, :] + means_out[:, np.newaxis, :]

# Configure surrogate outputs (paired with the same prior draws)
conf_s = configurator({"sim_data": x_sim_s, "prior_draws": gen_out["prior_draws"]})

# Sample from approx. posteriors given surrogate simulator outputs
post_samples_s = joint_amortizer.sample_parameters(
    conf_s["posterior_inputs"], n_samples=n_posterior_samples
)

# Sample from approx. posteriors given true simulator outputs
post_samples_t = joint_amortizer.sample_parameters(
    conf["posterior_inputs"], n_samples=n_posterior_samples
)

# Extract prior samples (already standardized by the configurator)
prior_samples = conf["posterior_inputs"]["parameters"]

## Joint Calibration

In [None]:
from assets.custom_plots import plot_sbc_ecdf_appendix

### Posterior

In [None]:
# Calibration plot for posterior draws conditioned on TRUE simulator outputs
f = plot_sbc_ecdf_appendix(
    post_samples_t,
    prior_samples,
    param_names=param_names,
    difference=True,
    rank_ecdf_color="#000080",
    label_fontsize=24,
    legend_fontsize=24,
    title_fontsize=40,
)
# NOTE(review): relative path; the "figures" directory must already exist.
f.savefig("figures/sbc_post_ecdf.pdf", dpi=300)

### Joint

In [None]:
# Calibration plot for posterior draws conditioned on SURROGATE outputs
f = plot_sbc_ecdf_appendix(
    post_samples_s,
    prior_samples,
    param_names=param_names,
    difference=True,
    rank_ecdf_color="#800000",
    label_fontsize=24,
    legend_fontsize=24,
    title_fontsize=40,
)
# NOTE(review): relative path; the "figures" directory must already exist.
f.savefig("figures/sbc_joint_ecdf.pdf", dpi=300)

In [None]:
import os, datetime

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from functools import partial
from scipy import stats
import pickle

import tensorflow as tf
# Comment out, if you want tensorflow warnings
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, GRU, LSTM
from bayesflow.networks import InvertibleNetwork
from bayesflow.coupling_networks import CouplingLayer
from bayesflow.amortizers import (
    AmortizedLikelihood,
    AmortizedPosterior,
    AmortizedPosteriorLikelihood,
)
from bayesflow.trainers import Trainer
from bayesflow import default_settings
from bayesflow.helper_functions import build_meta_dict
from bayesflow.diagnostics import plot_sbc_ecdf, plot_sbc_histograms, plot_losses
from bayesflow.computational_utilities import maximum_mean_discrepancy

def load_data():
    """Download the three global time-series CSVs and extract Germany.

    The confirmed, recovered and death tables are read from the CSSE GitHub
    repository, restricted to the first row whose "Country/Region" equals
    "Germany" and to the date columns 2020-03-01 .. 2020-05-21, stacked as
    columns and differenced along time.

    Returns
    -------
    dict
        ``x``    : differenced series, one column per table,
        ``T``    : number of rows of ``x``,
        ``N``    : the constant 83e6,
        ``Mean`` : column means of ``x``,
        ``Std``  : column standard deviations of ``x``.
    """
    source_urls = (
        "https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_confirmed_global.csv",
        "https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_recovered_global.csv",
        "https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_deaths_global.csv",
    )
    tables = [pd.read_csv(url, sep=",") for url in source_urls]

    def _column_label(day):
        # Column headers look like month/day/two-digit-year without padding
        return "{}/{}/{}".format(day.month, day.day, str(day.year)[2:4])

    first_col = _column_label(datetime.date(2020, 3, 1))
    last_col = _column_label(datetime.date(2020, 5, 21))

    def _germany_row(table):
        # First matching row, restricted to the requested date columns
        is_germany = table["Country/Region"] == "Germany"
        return np.array(table.loc[is_germany, first_col:last_col])[0]

    stacked = np.stack([_germany_row(table) for table in tables]).T
    # First differences along time (the raw columns are presumably cumulative)
    daily = np.diff(stacked, axis=0)

    return dict(
        x=daily,
        T=daily.shape[0],
        N=83e6,
        Mean=np.mean(daily, axis=0),
        Std=np.std(daily, axis=0),
    )

# Global matplotlib style: serif fonts, no external LaTeX rendering.
# NOTE(review): in this raw string the doubled braces are literal (there is no
# .format call), so the preamble reads \usepackage{{amsmath}}; this has no
# effect while "text.usetex" is False.
plt.rcParams.update(
    {
        "text.usetex": False,
        "font.family": "serif",
        "text.latex.preamble": r"\usepackage{{amsmath}}",
    }
)

# Shape parameters of the Beta prior used for f_i, f_r and f_d in prior_sir().
# beta_f is tied to alpha_f so that the Beta mean alpha/(alpha+beta) is 0.7;
# 0.17 presumably plays the role of a spread parameter — TODO confirm source.
alpha_f = (0.7**2) * ((1 - 0.7) / (0.17**2) - (1 - 0.7))
beta_f = alpha_f * (1 / 0.7 - 1)


def prior_sir():
    """
    Draw ONE sample of the 27 parameters of the first parameter block.

    Returned list order: t1..t4, delta_t1..delta_t4, lambd0..lambd4, mu,
    (f_i, phi_i), (f_r, phi_r), (f_d, phi_d), D_i, D_r, D_d, E0,
    scale_I, scale_R, scale_D. The random draws are made in exactly this
    order, so results are reproducible under a fixed global NumPy seed.
    """

    rnd = np.random

    # Change points t1..t4
    draws = [rnd.normal(loc=8, scale=3)]
    draws += [rnd.normal(loc=center, scale=1) for center in (15, 22, 66)]

    # Transition lengths delta_t1..delta_t4
    draws += [rnd.lognormal(mean=np.log(3), sigma=0.3) for _ in range(4)]

    # Rates lambd0..lambd4
    draws += [
        rnd.lognormal(mean=np.log(median), sigma=0.5)
        for median in (1.2, 0.6, 0.3, 0.1, 0.1)
    ]

    # mu
    draws.append(rnd.lognormal(mean=np.log(1 / 8), sigma=0.2))

    # (f, phi) pairs for I, R and D, drawn alternately
    for _ in range(3):
        draws.append(rnd.beta(a=alpha_f, b=beta_f))
        draws.append(stats.vonmises(kappa=0.01).rvs())

    # D_i, D_r, D_d
    draws += [rnd.lognormal(mean=np.log(8), sigma=0.2) for _ in range(3)]

    # E0
    draws.append(rnd.gamma(shape=2, scale=30))

    # scale_I, scale_R, scale_D
    draws += [rnd.gamma(shape=1, scale=5) for _ in range(3)]

    return draws


def prior_secir():
    """
    Draw ONE sample of the seven parameters of the second parameter block.

    Returns the list [alpha, beta, gamma, eta, theta, delta, d]; the random
    draws are made in that order from the global NumPy random state.
    """

    rnd = np.random
    return [
        rnd.uniform(low=0.005, high=0.9),  # alpha
        rnd.lognormal(mean=np.log(0.25), sigma=0.3),  # beta
        rnd.lognormal(mean=np.log(1 / 6.5), sigma=0.5),  # gamma
        rnd.lognormal(mean=np.log(1 / 3.2), sigma=0.3),  # eta
        rnd.uniform(low=1 / 14, high=1 / 3),  # theta
        rnd.uniform(low=0.01, high=0.3),  # delta
        rnd.uniform(low=1 / 14, high=1 / 3),  # d
    ]


def _lambda_segment(start, end, ramp_len, seg_len):
    """Return one segment: a ramp from start to end, then a hold at end."""

    if ramp_len == 1:
        # Immediate switch: the whole segment sits at the new value
        return np.array([end] * seg_len)
    ramp = np.linspace(start, end, ramp_len)
    return np.append(ramp, [end] * (seg_len - ramp_len))


def calc_lambda_array(
    sim_lag,
    lambd0,
    lambd1,
    lambd2,
    lambd3,
    lambd4,
    t1,
    t2,
    t3,
    t4,
    delta_t1,
    delta_t2,
    delta_t3,
    delta_t4,
    T,
):
    """Computes the array of time-varying contact rates/transmission probabilities.

    The array consists of five consecutive segments:

    * ``t1 + sim_lag`` steps at ``lambd0``;
    * for k = 1..4, a segment starting at change point ``t_k`` that moves
      linearly from ``lambd_{k-1}`` to ``lambd_k`` over ``delta_t_k`` steps
      and then stays at ``lambd_k`` until the next change point (or ``T``
      for the last segment). With ``delta_t_k == 1`` the segment is constant
      at ``lambd_k``.

    All ``t*``, ``delta_t*``, ``sim_lag`` and ``T`` must be integers; callers
    are expected to have checked ``t1 < t2 < t3 < t4`` and that every
    transition fits into its segment.

    Returns
    -------
    np.ndarray of length ``T + sim_lag``.
    """

    segments = [
        # Initial rate before the first change point
        np.array([lambd0] * (t1 + sim_lag)),
        _lambda_segment(lambd0, lambd1, delta_t1, t2 - t1),
        _lambda_segment(lambd1, lambd2, delta_t2, t3 - t2),
        # Fixed: this transition previously ramped lambd3 -> lambd4 and then
        # fell back to lambd3, instead of ramping lambd2 -> lambd3.
        _lambda_segment(lambd2, lambd3, delta_t3, t4 - t3),
        _lambda_segment(lambd3, lambd4, delta_t4, T - t4),
    ]
    return np.concatenate(segments)


def non_stationary_SEICR(
    params_sir, params_secir, N, T, sim_diff=16, observation_model=True
):
    """
    Performs one forward simulation of the non-stationary SEICR model.

    Parameters
    ----------
    params_sir        : sequence of 27 values, unpacked in the order
        t1..t4, delta_t1..delta_t4, lambd0..lambd4, mu, f_i, phi_i, f_r,
        phi_r, f_d, phi_d, delay_i, delay_r, delay_d, E0, scale_I, scale_R,
        scale_D (the order returned by ``prior_sir``).
    params_secir      : sequence of 7 values
        (alpha, beta, gamma, eta, theta, delta, d).
    N                 : population size; the initial susceptibles are N - E0.
    T                 : number of reported time steps.
    sim_diff          : int; ``sim_diff - 1`` extra steps are simulated before
        the first reported step. Must exceed every rounded reporting delay.
    observation_model : bool; selects the return value (see below).

    Returns
    -------
    np.ndarray
        If ``observation_model`` is True: shape (T, 3), the noisy reported
        new cases with columns (I, R, D). Otherwise the compartment
        trajectories stacked in the order (S, E, I, C, R, D), one row per
        state including the initial one.

    Raises
    ------
    AssertionError
        If the rounded change points, transition lengths or delays violate
        the constraints checked below.
    """

    # Extract parameters
    (
        t1,
        t2,
        t3,
        t4,
        delta_t1,
        delta_t2,
        delta_t3,
        delta_t4,
        lambd0,
        lambd1,
        lambd2,
        lambd3,
        lambd4,
        mu,
        f_i,
        phi_i,
        f_r,
        phi_r,
        f_d,
        phi_d,
        delay_i,
        delay_r,
        delay_d,
        E0,
        scale_I,
        scale_R,
        scale_D,
    ) = params_sir
    alpha, beta, gamma, eta, theta, delta, d = params_secir

    # Round integer parameters
    t1, t2, t3, t4 = int(round(t1)), int(round(t2)), int(round(t3)), int(round(t4))
    delta_t1, delta_t2, delta_t3, delta_t4 = (
        int(round(delta_t1)),
        int(round(delta_t2)),
        int(round(delta_t3)),
        int(round(delta_t4)),
    )
    # At least one initially exposed individual
    E0 = max(1, np.round(E0))
    delay_i = int(round(delay_i))
    delay_r = int(round(delay_r))
    delay_d = int(round(delay_d))

    # Impose constraints
    # (data_generator relies on these raising to reject invalid prior draws)
    assert sim_diff > delay_i
    assert sim_diff > delay_r
    assert sim_diff > delay_d
    assert t1 > 0 and t2 > 0 and t3 > 0 and t4 > 0
    assert t1 < t2 < t3 < t4
    assert delta_t1 > 0 and delta_t2 > 0 and delta_t3 > 0 and delta_t4 > 0
    assert (
        t2 - t1 >= delta_t1
        and t3 - t2 >= delta_t2
        and t4 - t3 >= delta_t3
        and T - t4 >= delta_t4
    )

    # Calculate lambda arrays
    # Lambda0 is the initial contact rate which will be consecutively
    # reduced via the government measures
    sim_lag = sim_diff - 1
    lambd_arr = calc_lambda_array(
        sim_lag,
        lambd0,
        lambd1,
        lambd2,
        lambd3,
        lambd4,
        t1,
        t2,
        t3,
        t4,
        delta_t1,
        delta_t2,
        delta_t3,
        delta_t4,
        T,
    )

    # Initial conditions
    S, E, C, I, R, D = [N - E0], [E0], [0], [0], [0], [0]

    # Containers for the new cases of every simulated step
    I_news = []
    R_news = []
    D_news = []

    # Reported new cases and weekly modulation factors, one entry per
    # reported step
    I_data = np.zeros(T)
    R_data = np.zeros(T)
    D_data = np.zeros(T)
    fs_i = np.zeros(T)
    fs_r = np.zeros(T)
    fs_d = np.zeros(T)

    # Simulate T + sim_lag time steps (the first sim_lag are not reported)
    for t in range(T + sim_lag):
        # Calculate new exposed cases
        E_new = lambd_arr[t] * ((C[t] + beta * I[t]) / N) * S[t]

        # Remove exposed from susceptible
        S_t = S[t] - E_new

        # Calculate current exposed by adding new exposed and
        # subtracting the exposed becoming carriers.
        E_t = E[t] + E_new - gamma * E[t]

        # Calculate current carriers by adding the new exposed and subtracting
        # those who will develop symptoms and become detected and those who
        # will go through the disease asymptomatically.
        C_t = C[t] + gamma * E[t] - (1 - alpha) * eta * C[t] - alpha * theta * C[t]

        # Calculate current infected by adding the symptomatic carriers and
        # subtracting the dead and recovered. The newly infected are just the
        # carriers who get detected.
        I_t = (
            I[t] + (1 - alpha) * eta * C[t] - (1 - delta) * mu * I[t] - delta * d * I[t]
        )
        I_new = (1 - alpha) * eta * C[t]

        # Calculate current recovered by adding the symptomatic and asymptomatic
        # recovered. The newly recovered are only the detected recovered
        R_t = R[t] + alpha * theta * C[t] + (1 - delta) * mu * I[t]
        R_new = (1 - delta) * mu * I[t]

        # Calculate the current dead
        D_t = D[t] + delta * d * I[t]
        D_new = delta * d * I[t]

        # Ensure some numerical constraints: keep every compartment in [0, N]
        S_t = np.clip(S_t, 0, N)
        E_t = np.clip(E_t, 0, N)
        C_t = np.clip(C_t, 0, N)
        I_t = np.clip(I_t, 0, N)
        R_t = np.clip(R_t, 0, N)
        D_t = np.clip(D_t, 0, N)

        # Keep track of process over time
        S.append(S_t)
        E.append(E_t)
        C.append(C_t)
        I.append(I_t)
        R.append(R_t)
        D.append(D_t)
        I_news.append(I_new)
        R_news.append(R_new)
        D_news.append(D_new)

        # From here on, fill the reported series; each series looks back by
        # its own reporting delay (delay_i, delay_r, delay_d)
        if t >= sim_lag:
            # Compute weekly modulation factors and delayed new cases
            fs_i[t - sim_lag] = (1 - f_i) * (
                1 - np.abs(np.sin((np.pi / 7) * (t - sim_lag) - 0.5 * phi_i))
            )
            fs_r[t - sim_lag] = (1 - f_r) * (
                1 - np.abs(np.sin((np.pi / 7) * (t - sim_lag) - 0.5 * phi_r))
            )
            fs_d[t - sim_lag] = (1 - f_d) * (
                1 - np.abs(np.sin((np.pi / 7) * (t - sim_lag) - 0.5 * phi_d))
            )
            I_data[t - sim_lag] = I_news[t - delay_i]
            R_data[t - sim_lag] = R_news[t - delay_r]
            D_data[t - sim_lag] = D_news[t - delay_d]

    # Compute weekly modulation
    I_data = (1 - fs_i) * I_data
    R_data = (1 - fs_r) * R_data
    D_data = (1 - fs_d) * D_data

    # Add Student-t noise (df=4) whose scale grows with sqrt of the signal.
    # NOTE(review): drawn even when observation_model is False, so it always
    # consumes random numbers.
    I_data = stats.t(df=4, loc=I_data, scale=np.sqrt(I_data) * scale_I).rvs()
    R_data = stats.t(df=4, loc=R_data, scale=np.sqrt(R_data) * scale_R).rvs()
    D_data = stats.t(df=4, loc=D_data, scale=np.sqrt(D_data) * scale_D).rvs()

    if observation_model:
        return np.stack((I_data, R_data, D_data)).T
    # Latent trajectories; note the column order S, E, I, C, R, D
    return np.stack((S, E, I, C, R, D)).T


def simulate_given_params(n_sim, params, N, T, sim_diff=21, observation_model=True):
    """Simulate ``n_sim`` trajectories of the full model for ONE parameter vector.

    ``params`` is split into its last seven entries (second parameter block)
    and everything before them (first block); both are passed unchanged to
    ``non_stationary_SEICR`` for every repetition.

    Returns the stacked simulations with negative values clipped to zero.
    """

    first_block = params[:-7]
    second_block = params[-7:]
    runs = [
        non_stationary_SEICR(
            first_block,
            second_block,
            N=N,
            T=T,
            sim_diff=sim_diff,
            observation_model=observation_model,
        )
        for _ in range(n_sim)
    ]
    return np.clip(np.array(runs), 0, np.inf)

def data_generator(batch_size, T=None, N=None, sim_diff=21, seed=None, scale=1000):
    """
    Runs the forward model 'batch_size' times by first sampling from the prior
    theta ~ p(theta) and then running x ~ p(x|theta).

    Parameter draws for which the simulator rejects its inputs are discarded
    and redrawn, so exactly ``batch_size`` valid pairs are returned.

    Arguments
    ---------
    batch_size : int -- the number of (theta, x) pairs to generate
    T          : int -- number of reported time steps (required)
    N          : number -- population size (required)
    sim_diff   : int -- forwarded to ``non_stationary_SEICR``
    seed       : int or None -- if given, seeds the global NumPy RNG first
    scale      : number -- the simulated data are divided by this value

    Output
    ------
    forward_dict : dict
        ``"prior_draws"``: array of the concatenated parameter blocks, one
        row per simulation; ``"sim_data"``: array of the clipped, scaled
        simulations, one entry per simulation.

    Raises
    ------
    ValueError
        If ``T`` or ``N`` is not supplied (previously this led to an endless
        retry loop because every simulation attempt failed silently).
    """

    if T is None or N is None:
        raise ValueError("data_generator requires both T and N to be set.")

    if seed is not None:
        np.random.seed(seed)

    x = []
    theta = []
    for _ in range(batch_size):
        # Rejection sampling: redraw until the simulator accepts the draw
        while True:
            theta1 = prior_sir()
            theta2 = prior_secir()
            try:
                x_i = non_stationary_SEICR(theta1, theta2, N, T, sim_diff=sim_diff)
            except (AssertionError, ValueError, IndexError, ArithmeticError):
                # Only failures caused by an invalid draw are retried.
                # KeyboardInterrupt and programming errors (NameError,
                # TypeError, ...) propagate instead of looping forever.
                continue
            break
        x.append(x_i)
        theta.append(theta1 + theta2)

    # Clip negative values and rescale
    x = np.clip(np.array(x), 0.0, np.inf) / scale
    theta = np.array(theta)

    return {"prior_draws": theta, "sim_data": x}

# Reproducible estimate of the prior moments from 500 draws per block
np.random.seed(42)
theta1_s = np.array([prior_sir() for _ in range(500)])
theta2_s = np.array([prior_secir() for _ in range(500)])
theta1_mu = np.mean(theta1_s, axis=0, keepdims=True)
theta2_mu = np.mean(theta2_s, axis=0, keepdims=True)
theta1_std = np.std(theta1_s, axis=0, keepdims=True)
theta2_std = np.std(theta2_s, axis=0, keepdims=True)

# Concatenate both blocks column-wise; `configurator` reads these two globals
# to standardize the prior draws
theta_mu = np.c_[theta1_mu, theta2_mu]
theta_std = np.c_[theta1_std, theta2_std]

class MultiConvLayer(tf.keras.Model):
    """Parallel causal 1D convolutions with kernel sizes 2..7, merged by a 1x1 convolution."""

    def __init__(self, n_filters=32, strides=1):
        super().__init__()

        # One branch per kernel size, each with half the output filters
        branch_filters = n_filters // 2
        branches = []
        for kernel_width in range(2, 8):
            branches.append(
                tf.keras.layers.Conv1D(
                    branch_filters,
                    kernel_size=kernel_width,
                    strides=strides,
                    padding="causal",
                    activation="relu",
                    kernel_initializer="glorot_uniform",
                )
            )
        self.convs = branches

        # Kernel-size-1 convolution mapping the concatenated branches to n_filters
        self.dim_red = tf.keras.layers.Conv1D(
            n_filters, 1, 1, activation="relu", kernel_initializer="glorot_uniform"
        )

    def call(self, x):
        """Apply all branches to ``x`` and merge them along the last axis.

        ``x`` is presumably a time series of shape (batch, time, features) — TODO confirm.
        """

        branch_outputs = [branch(x) for branch in self.convs]
        merged = tf.concat(branch_outputs, axis=-1)
        return self.dim_red(merged)


class MultiConvNet(tf.keras.Model):
    """Stack of ``MultiConvLayer`` blocks followed by an LSTM."""

    def __init__(self, n_layers=3, n_filters=64, strides=1):
        super().__init__()

        conv_blocks = [MultiConvLayer(n_filters, strides) for _ in range(n_layers)]
        self.net = tf.keras.Sequential(conv_blocks)

        # LSTM with n_filters units applied to the convolutional features
        self.lstm = LSTM(n_filters)

    def call(self, x, **args):
        """Run ``x`` through the convolutional stack and then the LSTM.

        ``x`` is presumably a time series of shape (batch, time, features) — TODO confirm.
        Extra keyword arguments are accepted but not used.
        """

        features = self.net(x)
        return self.lstm(features)


class SummaryNet(tf.keras.Model):
    """Summary network with one ``MultiConvNet`` per third of the last input axis."""

    def __init__(self, n_summary):
        super().__init__()
        # Each sub-network contributes n_summary // 3 output units
        per_series = n_summary // 3
        self.net_I = MultiConvNet(n_filters=per_series)
        self.net_R = MultiConvNet(n_filters=per_series)
        self.net_D = MultiConvNet(n_filters=per_series)

    @tf.function
    def call(self, x, **args):
        """Split ``x`` into three equal parts along the last axis and summarize each.

        The three summaries are concatenated along the last axis in the order
        net_I, net_R, net_D. Extra keyword arguments are accepted but not used.
        """

        part_i, part_r, part_d = tf.split(x, 3, axis=-1)
        summary_i = self.net_I(part_i)
        summary_r = self.net_R(part_r)
        summary_d = self.net_D(part_d)
        return tf.concat([summary_i, summary_r, summary_d], axis=-1)


class MemoryNetwork(tf.keras.Model):
    """GRU that summarizes the past of a series together with its conditions."""

    def __init__(self, meta):
        super().__init__()

        hidden_units = meta["n_hidden"]
        self.gru = GRU(hidden_units, return_sequences=True, return_state=True)
        self.h = hidden_units
        self.n_params = meta["n_params"]

    @tf.function
    def call(self, target, condition):
        """Performs a teacher-forced pass over the whole series.

        The target is shifted one step to the right (a zero step is put in
        front, the last step is dropped), so the output at step t depends
        only on targets before t.

        Params:
        -------
        target    : tf.Tensor of shape (batch_size, time_steps, dim)
            The time-dependent signal to process.
        condition : tf.Tensor
            Concatenated to the shifted target along the last axis, so it
            needs matching leading dimensions.

        Returns the GRU output sequence.
        """
        zero_step = tf.zeros((target.shape[0], 1, target.shape[2]))
        shifted = tf.concat([zero_step, target[:, :-1, :]], axis=1)
        gru_input = tf.concat([shifted, condition], axis=-1)
        sequence, _ = self.gru(gru_input)
        return sequence

    def step_loop(self, target, condition, state):
        """Advance the GRU from ``state`` on the given input; return (output, new state)."""
        gru_input = tf.concat([target, condition], axis=-1)
        out, new_state = self.gru(gru_input, initial_state=state)
        return out, new_state

class InvertibleNetworkWithMemory(tf.keras.Model):
    """Chain of conditional coupling layers whose condition is extended by a recurrent memory."""

    def __init__(
        self,
        num_params,
        num_coupling_layers=4,
        coupling_settings=None,
        coupling_design="affine",
        permutation="fixed",
        use_act_norm=True,
        act_norm_init=None,
        use_soft_flow=False,
        soft_flow_bounds=(1e-3, 5e-2),
    ):
        """Builds the coupling layers and the memory network.

        All arguments except the soft-flow ones are forwarded to every
        ``CouplingLayer``; the soft-flow settings are only stored.
        """

        super().__init__()

        # Keyword arguments shared by all coupling layers
        layer_kwargs = {
            "latent_dim": num_params,
            "coupling_settings": coupling_settings,
            "coupling_design": coupling_design,
            "permutation": permutation,
            "use_act_norm": use_act_norm,
            "act_norm_init": act_norm_init,
        }
        self.coupling_layers = [
            CouplingLayer(**layer_kwargs) for _ in range(num_coupling_layers)
        ]

        # Plain attributes
        self.soft_flow = use_soft_flow
        self.soft_low = soft_flow_bounds[0]
        self.soft_high = soft_flow_bounds[1]
        self.use_act_norm = use_act_norm
        self.latent_dim = num_params

        # Recurrent memory providing the additional condition
        self.dynamic_summary_net = MemoryNetwork({"n_hidden": 256, "n_params": 3})

    def call(self, targets, condition, inverse=False):
        """Runs the chain in forward or inverse direction.

        Parameters
        ----------
        targets   : tf.Tensor
            Forward: the series to transform. Inverse: the latent series.
        condition : tf.Tensor
            Conditioning variables; sliced per time step in the inverse pass
            and concatenated with the memory output along the last axis.
        inverse   : bool, default: False
            Selects the direction.

        Returns
        -------
        (z, log_det_J) if ``inverse=False``, otherwise the transformed tensor.
        """

        run = self.inverse if inverse else self.forward
        return run(targets, condition)

    @tf.function
    def forward(self, targets, condition, **kwargs):
        """Transforms ``targets`` to latent space; returns (z, summed log-det-Jacobian)."""

        # Extend the condition by the teacher-forced memory of the targets
        memory = self.dynamic_summary_net(targets, condition)
        full_condition = tf.concat([memory, condition], axis=-1)

        latent = targets
        jacobian_terms = []
        for block in self.coupling_layers:
            latent, block_log_det = block(latent, full_condition, **kwargs)
            jacobian_terms.append(block_log_det)

        # Total log-determinant is the sum over all coupling layers
        return latent, tf.add_n(jacobian_terms)

    @tf.function
    def inverse(self, z, condition, **kwargs):
        """Transforms latent ``z`` back step by step, feeding each output to the memory."""

        n_steps = z.shape[1]
        # The first memory input is a zero step; the GRU state starts at zero
        previous = tf.zeros((z.shape[0], 1, z.shape[-1]))
        hidden = tf.zeros((z.shape[0], self.dynamic_summary_net.h))
        generated = []
        for step in range(n_steps):
            cond_slice = condition[:, step : step + 1, :]
            memory, hidden = self.dynamic_summary_net.step_loop(
                previous, cond_slice, hidden
            )
            step_condition = tf.concat([memory, cond_slice], axis=-1)

            current = z[:, step : step + 1, :]
            for block in reversed(self.coupling_layers):
                current = block(current, step_condition, inverse=True, **kwargs)

            generated.append(current)
            # The generated step becomes the memory input of the next step
            previous = current
        return tf.concat(generated, axis=1)

def configurator(forward_dict):
    """Customized preprocessing for the Covid simulator.

    Standardizes the simulated data along axis 1, standardizes the prior
    draws with the module-level ``theta_mu`` / ``theta_std``, and assembles
    the inputs for the joint posterior/likelihood amortizer.

    Parameters
    ----------
    forward_dict : dict
        Must contain ``"sim_data"`` (indexed here as a rank-3 array) and
        ``"prior_draws"`` (concatenated here with rank-2 arrays along the
        last axis).

    Returns
    -------
    out : dict
        ``out["likelihood_inputs"]`` holds ``"observables"`` (standardized
        data) and ``"conditions"`` (standardized parameters plus the
        log-transformed mean and std, repeated along axis 1).
        ``out["posterior_inputs"]`` holds ``"parameters"``,
        ``"summary_conditions"`` (the same array object as the likelihood
        observables) and ``"direct_conditions"`` (log-transformed mean and
        std). All arrays are float32.
    """

    out = {"posterior_inputs": {}, "likelihood_inputs": {}}

    # Extract data and standardize every series along axis 1
    x = forward_dict["sim_data"].astype(np.float32)
    x_means = np.mean(x, axis=1, keepdims=True)
    x_std = np.std(x, axis=1, keepdims=True)
    # A series that is constant along axis 1 has zero standard deviation;
    # dividing by it would put NaN/inf into the network inputs. Divide by 1
    # in exactly those positions so the centered series stays all-zero.
    safe_std = np.where(x_std == 0, np.float32(1.0), x_std)
    x = (x - x_means) / safe_std
    # The summary statistics themselves are passed on in log2(1 + .) form,
    # computed from the unmodified mean and std.
    log_mu = np.log2(1 + x_means[:, 0, :])
    log_std = np.log2(1 + x_std[:, 0, :])

    # Extract params and standardize with the module-level constants
    p = forward_dict["prior_draws"].astype(np.float32)
    p = (p - theta_mu) / theta_std

    # Repeat the condition once per position along axis 1 of the data
    cond = np.concatenate([p, log_mu, log_std], axis=-1)
    cond = np.stack([cond] * x.shape[1], axis=1)

    # Likelihood inputs
    observables = x.astype(np.float32)
    out["likelihood_inputs"]["observables"] = observables
    out["likelihood_inputs"]["conditions"] = cond.astype(np.float32)

    # Posterior inputs
    out["posterior_inputs"]["parameters"] = p
    out["posterior_inputs"]["summary_conditions"] = observables
    out["posterior_inputs"]["direct_conditions"] = np.concatenate(
        [log_mu, log_std], axis=-1
    )

    return out

# Load the data set once; its "N" and "T" entries are reused below.
real_data = load_data()
# Pre-bind N and T so the simulator only needs its remaining argument(s)
# (it is later called as generative_model(batch_size)).
generative_model = partial(data_generator, N=real_data["N"], T=real_data["T"])

# LaTeX labels for the model parameters, in order. The length of this list
# sets the dimensionality of the posterior network.
param_names = [
    # Indexed families: t_1..t_4, \Delta t_1..\Delta t_4, \lambda_0..\lambda_4
    *(rf"$t_{k}$" for k in range(1, 5)),
    *(rf"$\Delta t_{k}$" for k in range(1, 5)),
    *(rf"$\lambda_{k}$" for k in range(5)),
    r"$\mu$",
    # (A, \phi) pairs, then L, for each of the subscripts I, R, D
    *(rf"${symbol}_{c}$" for c in "IRD" for symbol in ("A", r"\phi")),
    *(rf"$L_{c}$" for c in "IRD"),
    r"$E_0$",
    *(rf"$\sigma_{c}$" for c in "IRD"),
    # Remaining scalar symbols
    r"$\alpha$",
    r"$\beta$",
    r"$\gamma$",
    r"$\eta$",
    r"$\theta$",
    r"$\delta$",
    r"$d$",
]

# Settings passed as ``coupling_settings`` to both invertible networks below.
coupling_settings = {
    "dense_args": {
        "units": 128,
        "activation": "swish",
        "kernel_regularizer": None,
    },
    "num_dense": 2,
    "dropout": False,
}

# Likelihood network: the memory-augmented invertible network with 8 coupling
# layers. num_params=3 is presumably the number of observed series per time
# step -- TODO confirm against the simulator output.
likelihood_net = InvertibleNetworkWithMemory(
    num_params=3, num_coupling_layers=8, coupling_settings=coupling_settings
)
# Posterior network: one dimension per entry of param_names, 6 coupling layers.
posterior_net = InvertibleNetwork(
    num_params=len(param_names),
    num_coupling_layers=6,
    coupling_settings=coupling_settings,
)
# Summary network for the posterior's summary conditions.
summary_net = SummaryNet(n_summary=192)
# Posterior amortizer with the "MMD" summary loss.
amortized_posterior = AmortizedPosterior(
    posterior_net, summary_net, summary_loss_fun="MMD"
)
amortized_likelihood = AmortizedLikelihood(likelihood_net)
# Joint amortizer wrapping the posterior and the likelihood amortizers.
joint_amortizer = AmortizedPosteriorLikelihood(
    amortized_posterior, amortized_likelihood
)


# Trainer tying together the joint amortizer, the simulator and the
# configurator defined above.
trainer = Trainer(
    amortizer=joint_amortizer,
    generative_model=generative_model,
    configurator=configurator,
    # NOTE(review): this is a relative path, so it resolves against the
    # current working directory. If "/content/temp" (absolute) was intended,
    # confirm and adjust; existing checkpoints live wherever this resolves.
    checkpoint_path="content/temp",
    memory=False,
    max_to_keep=1,
)


# Marker that the setup cell ran to completion.
print('done')

In [None]:
# Online training: 100 epochs of 32 iterations each, batch size 32, with 150
# validation simulations. The return value is kept in `h` (presumably the
# training history -- confirm against the Trainer documentation).
h = trainer.train_online(epochs=100, iterations_per_epoch=32, batch_size=32, validation_sims=150)

In [None]:
# Smoke check: call the simulator with argument 1 (presumably the batch
# size) and display the result as the cell output.
out = generative_model(1)
out