# Joint Model

In [None]:
import numpy as np
from scipy.special import expit  # sigmoid aka logistic function
import matplotlib.pyplot as plt

In [None]:
insidenz = lambda t: np.sin(t*2*np.pi/100)/100 + 2/100 

In [None]:
time = np.linspace(0, 100, 100)
plt.plot(time, insidenz(time), label='insidenz')
plt.title('Joint Model')
plt.legend()
plt.show()

# Estimation of the hazard function

In [None]:
import itertools
from typing import Optional
from functools import partial
from bayesflow.simulation import Simulator
from inference.base_nlme_model import NlmeBaseAmortizer
from bayesflow import diagnostics

In [None]:
class linearModel(NlmeBaseAmortizer):
    def __init__(self, name: str = 'myModel'):
        # define names of parameters
        param_names = ['intercept', 'slope', 'error-variance']

        # define prior values (for log-parameters)
        prior_mean = np.array([np.log(2500), -3, -2])
        prior_cov = np.diag([1, 1, 0.1])
        self.prior_type = 'normal'

        super().__init__(name=name,
                         param_names=param_names,
                         prior_mean=prior_mean,
                         prior_cov=prior_cov)

        self.simulator = Simulator(batch_simulator_fun=partial(self.batch_simulator))
        
    # define simulator
    @staticmethod
    def batch_simulator(param_batch: np.ndarray, time: Optional[np.ndarray] = None, 
                        maxtime: int = 100, with_noise: bool = True) -> np.ndarray:
        param_batch = np.exp(param_batch)

        if param_batch.ndim == 1:  # so not (batch_size, params)
            # just a single parameter set
            param_batch = param_batch[np.newaxis, :]

        if time is None:
            n_measurements = np.random.randint(1, 5)
            time = np.random.uniform(1, maxtime, n_measurements)
            time.sort()

        # simulate data
        y = np.zeros((param_batch.shape[0], time.size, 2))
        for p_id, params in enumerate(param_batch):
            # model plus noise
            y[p_id, :, 0] = params[0] * np.exp(-params[1] * time)
            if with_noise:
                y[p_id, :, 0] *= ( 1 + params[2] * np.random.normal(loc=0, scale=1, size=time.size))
            y[p_id, :, 1] = time / maxtime

        # add censoring
        y[:, :, 0] = np.log(np.clip(y[:, :, 0], a_min=0.001, a_max=2500))
        return y
        
    def load_amortizer_configuration(self, model_idx: int = 0, load_best: bool = False) -> str:
        self.n_epochs = 10
        bidirectional_LSTM = [False]
        n_coupling_layers = [2, 3]
        n_dense_layers_in_coupling = [2]
        coupling_design = ['affine']
        summary_network_type = ['sequence']

        combinations = list(itertools.product(bidirectional_LSTM, n_coupling_layers,
                                              n_dense_layers_in_coupling, coupling_design, summary_network_type))

        (self.bidirectional_LSTM,
         self.n_coupling_layers,
         self.n_dense_layers_in_coupling,
         self.coupling_design,
         self.summary_network_type) = combinations[model_idx]

        model_name = f'amortizer-toyModel-{self.prior_type}' \
                     f'-{self.summary_network_type}-summary' \
                     f'-{"Bi-LSTM" if self.bidirectional_LSTM else "LSTM"}' \
                     f'-{self.n_coupling_layers}layers' \
                     f'-{self.n_dense_layers_in_coupling}coupling-{self.coupling_design}' \
                     f'-{self.n_epochs}epochs'
        return model_name
        
    def plot_example(self, params: Optional[np.ndarray] = None) -> None:
        raise NotImplementedError('Not implemented yet.')
        
    def prepare_plotting(self, data: np.ndarray, params: np.ndarray, ax: Optional[plt.Axes] = None) -> plt.Axes:
        raise NotImplementedError('Not implemented yet.')

In [None]:
toy_model = linearModel()

In [None]:
for i in range(10):
    prior_sample = toy_model.prior(1)['prior_draws']
    yt = toy_model.simulator(prior_sample)['sim_data']
    plt.plot(yt[0, :, 1], np.exp(yt[0, :, 0]))
plt.xlabel('time')
plt.ylabel('antibody')
plt.yscale('log')
plt.title('Simulated data')
plt.ylim(0.001, 2500)
plt.show()

In [None]:
trainer = toy_model.build_trainer('../networks/' + toy_model.network_name)

In [None]:
history = trainer.train_online(epochs=10,
                               iterations_per_epoch=100,
                               batch_size=128,
                               early_stopping=True,
                               validation_sims=100)

In [None]:
new_sims = toy_model.generate_simulations_from_prior(trainer=trainer, n_samples=2500)
posterior_draws = toy_model.draw_posterior_samples(data=new_sims['summary_conditions'], n_samples=100)

In [None]:
fig_sbc = diagnostics.plot_sbc_histograms(post_samples=posterior_draws,
                                          prior_samples=new_sims['parameters'],
                                          param_names=toy_model.log_param_names)

In [None]:
from inference.helper_functions import create_mixed_effect_model_param_names

np.random.seed(42)
toy_samples = toy_model.prior(100)['prior_draws']
toy_samples[:, -2:] = toy_model.prior_mean[-2:]
toy_data = toy_model.simulator(toy_samples)['sim_data']

mixed_effect_params_names = create_mixed_effect_model_param_names(toy_model.param_names, 
                                                                  cov_type='diag')

In [None]:
fixed_indices = np.array([4,5])
fixed_values = np.array([-np.log(0.001), -np.log(0.001)])

In [None]:
from inference.inference_functions import run_population_optimization

result_optimization, obj_fun_amortized = run_population_optimization(
    individual_model=toy_model,
    data=toy_data,
    param_names=mixed_effect_params_names,
    cov_type='diag',
    n_multi_starts=10,
    n_samples_opt=100,
    covariates_bounds=None,
    covariates=None,
    n_covariates_params=0,
    covariate_mapping=None,
    x_fixed_indices=fixed_indices,
    x_fixed_vals=fixed_values,
    file_name=None, # f'../output/{model_name}-{cov_type}-n_data_{n_data}.hdf5',
    verbose=True,
    trace_record=True,
    pesto_multi_processes=10,
    result=None #result_optimization
    )

In [None]:
result_optimization.optimize_result.x[0]

In [None]:
toy_samples.mean(axis=0), -np.log(toy_samples.var(axis=0))

In [None]:
# todo: then try to fit the hazard function with the joint model

In [None]:
# generate data
np.random.seed(42)

linear_pars = 2500, 0.03
hazard_pars = 10, 0.1
random_variance = 1
error_variance = 0.1

def hazard_function(beta, t, log_ad_t):
    return beta[0] * insidenz(t) * ( 1 - expit(beta[1] * log_ad_t ))

In [None]:
def generate_data():
    re = 1 #np.exp(np.random.normal(0, random_variance))
    log_params = np.log(np.array([linear_pars[0]*re, linear_pars[1], error_variance]))
    random_start = np.random.uniform(0, 50)
    time = np.linspace(random_start, 100, int(100-random_start))
    obs_data = toy_model.batch_simulator(log_params, time=time-random_start)[0]  # only return antibody data
    
    # get hazard function
    h_obs = hazard_function(hazard_pars, time, obs_data[:, 0])
    # sample infection event
    infection_event = np.random.uniform(0, 1, len(h_obs)) < h_obs
    infection_event_index = next((i for i, x in enumerate(infection_event) if x), None)
    if infection_event_index is None:
        event_is_censored = True
        infection_event_index = len(infection_event) - 1
    else:
        event_is_censored = False
    # cut of data
    obs_data = obs_data[:infection_event_index]
    # maximum of 4 measurements
    m_index = np.random.choice(len(time[:infection_event_index]), min(4, len(time[:infection_event_index])), replace=False)
    m_index.sort()
    return log_params, obs_data[m_index], time[m_index], time[infection_event_index], event_is_censored, random_start

In [None]:
#np.random.seed(42)
log_params_i, obs_data, time, infection_event, event_is_censored, vaccine_date = generate_data()
log_params_i, obs_data, time, infection_event, vaccine_date, event_is_censored

In [None]:
plt.scatter(time, np.exp(obs_data[:, 0]), label='antibody')
plt.xlim(0,100)
plt.ylim(0,2500)
plt.vlines(infection_event, ymin=0, ymax=2500, color='r')
plt.show()

In [None]:
import scipy.integrate as integrate

In [None]:
def log_hazard_term(evet_time_i, event_i_is_censored, hazard_pars, log_mixed_pars, batch_simulator):
    
    h_func = lambda t: float(hazard_function(hazard_pars, t, 
                                             log_ad_t=batch_simulator(log_mixed_pars, 
                                                             time=np.array([t]), 
                                                             with_noise=False)[0, :, 0]
                                             )
                             )
    integrand, err = integrate.quad(h_func, 0, evet_time_i)
    s = np.exp(-integrand)
    
    if event_i_is_censored:
        return np.log(s)
    else:
        return np.log(s * float(h_func(evet_time_i)))

In [130]:
param_samples = toy_model.draw_posterior_samples(data=obs_data, n_samples=100)
hazard_pars = 10, 0.1
hazard_pars, infection_event, event_is_censored

((10, 0.1), 39.52704858306626, False)

In [131]:
joint_model_obs = [{
    'obs_data': obs_data,
    'time': time,
    'infection_event': infection_event,
    'event_is_censored': event_is_censored,
    'vaccine_date': vaccine_date
}]

In [134]:
def compute_hazard(param_samples: np.ndarray, hazard_pars: np.ndarray) -> np.ndarray:
    n_sim, n_samples, _ = param_samples.shape
    res = np.zeros((n_sim, n_samples))
    for ns in range(n_sim):
        infection_event_time_i = joint_model_obs[ns]['infection_event']
        event_i_is_censored = joint_model_obs[ns]['event_is_censored']
        for ps in range(n_samples):
            
            h_func = lambda t: hazard_pars[0] * insidenz(t) * ( 1 - expit(hazard_pars[1] * toy_model.batch_simulator(
                param_batch=param_samples[ns, ps],
                time=np.array([t]), 
                with_noise=False)[0, :, 0] ))
            integrand, err = integrate.quad(h_func, 0, infection_event_time_i)
            s = np.exp(-integrand)
            
            if event_i_is_censored:
                res[ns, ps] = np.log(s)
            else:
                res[ns, ps] = np.log(s * float(h_func(infection_event_time_i)))
    return res

In [135]:
%%time
res = compute_hazard(param_samples[np.newaxis, :, :], hazard_pars)
res

CPU times: user 415 ms, sys: 5.19 ms, total: 420 ms
Wall time: 440 ms


array([[-5.90630652, -5.98826905, -5.87694083, -5.87669564, -5.8843968 ,
        -5.95653943, -5.92233104, -5.90107245, -5.95696995, -5.91473662,
        -6.00671019, -5.89935448, -5.91628525, -5.93252014, -5.87360933,
        -5.87677577, -6.02840375, -5.87214631, -5.8827012 , -5.90082876,
        -5.92050639, -5.89905767, -6.16098295, -5.87694083, -5.87211389,
        -5.88830355, -5.94718302, -5.93563017, -5.86188786, -5.94920743,
        -5.9657974 , -5.86020858, -6.01727707, -5.86662148, -6.00087552,
        -5.8935193 , -5.92293933, -5.97608958, -5.87694083, -5.87893309,
        -5.93592609, -5.94936971, -5.86801958, -5.90482503, -5.90392522,
        -5.87694083, -5.93851572, -5.87694083, -6.31997226, -6.24351385,
        -5.96603058, -5.9574572 , -5.89211134, -5.97635919, -5.93336535,
        -5.94645165, -5.87694083, -6.05585599, -5.97503831, -5.8919526 ,
        -5.96120727, -5.87694083, -5.86610219, -5.93092544, -5.90985912,
        -5.94383757, -5.88156107, -5.89233174, -5.8