# Amortized Inference for an NLME Model

## Individual Fit

In [None]:
import os
import pickle
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
import pypesto.optimize as pesto_opt
from pypesto import FD, Objective, Problem, store
from scipy.stats import normaltest
from scipy import stats
from tqdm import tqdm

# for amortized inference
from inference.inference_functions import create_boundaries_from_prior
from inference.nlme_objective import get_covariance
from inference.empirical_bayes import ObjectiveFunctionEmpiricalBayes

In [None]:
# Select the model to analyse: the index after the tuple picks one of the
# supported names (-1 is the last entry).
model_name = (
    'fröhlich-simple',
    'fröhlich-detailed',
    'fröhlich-sde',
    'pharmacokinetic_model',
    'clairon_small_model',
)[-1]

## Load ODE model


In [None]:
# Instantiate the model selected by `model_name` and import its simulator helpers.
# The imports are branch-local, so only the chosen model's module is loaded:
# the 'fröhlich' variants bring in `batch_simulator`, the pharmacokinetic and
# Clairon models bring in `simulate_single_patient` and `convert_bf_to_observables`.
# `n_data` is later passed to `load_data` and is part of the result file names.
if model_name == 'fröhlich-simple':
    from models.froehlich_model_simple import FroehlichModelSimple, batch_simulator
    individual_model = FroehlichModelSimple(load_best=True)
    n_data = 500
elif model_name == 'fröhlich-detailed':
    from models.froehlich_model_detailed import FroehlichModelDetailed, batch_simulator
    individual_model = FroehlichModelDetailed(load_best=True)
    n_data = 500
elif model_name == 'fröhlich-sde':
    from models.froehlich_model_sde import FroehlichModelSDE, batch_simulator
    individual_model = FroehlichModelSDE(load_best=True)
    n_data = 500
elif model_name == 'pharmacokinetic_model':
    from models.pharmacokinetic_model import PharmacokineticModel, simulate_single_patient, convert_bf_to_observables
    individual_model = PharmacokineticModel(load_best=True)
    n_data = 47
elif model_name == 'clairon_small_model':
    from models.clairon_small_model import ClaironSmallModel, simulate_single_patient, convert_bf_to_observables
    individual_model = ClaironSmallModel(load_best=True)
    n_data = 742
else:
    raise NotImplementedError('model not implemented')

# build the trainer from the network directory named after the model's `network_name`
trainer = individual_model.build_trainer('../networks/' + individual_model.network_name)

In [None]:
# call the model's own example plot
individual_model.plot_example()

In [None]:
# load `n_data` observations and draw 100 posterior samples for the loaded data
obs_data = individual_model.load_data(n_data=n_data)
posterior_draws = individual_model.draw_posterior_samples(data=obs_data, n_samples=100)
# number of loaded entries (sanity check against n_data)
print(len(obs_data))

## Single Cell / Patient

In [None]:
individual_idx = 0

# Simulate noise-free trajectories from the posterior draws of the selected
# individual and collect the measurements and time grids used for plotting.
if 'fröhlich' in model_name:
    # fixed time grid: 180 points from 1/6 to 30 (inclusive)
    t_measurements_full = np.linspace(start=1 / 6, stop=30, num=180, endpoint=True)
    y_sim = batch_simulator(posterior_draws[individual_idx],
                            with_noise=False).squeeze(axis=2)
    t_measurements = t_measurements_full
    y = obs_data[individual_idx]
elif 'clairon' in model_name or 'pharma' in model_name:
    y_sim = simulate_single_patient(posterior_draws[individual_idx],
                                    patient_data=obs_data[individual_idx],
                                    full_trajectory=True,
                                    with_noise=False)
    observations = convert_bf_to_observables(obs_data[individual_idx])
    # this is the same for the clairon and pharma model
    y = observations[0]
    t_measurements = observations[1]
    doses_time_points = observations[2]
    # plotting grid: 100 points from 0 up to the last measurement time
    t_measurements_full = np.linspace(0, t_measurements[-1], 100)
else:
    # `NotImplemented` is a comparison sentinel, not an exception; raising it
    # fails with a TypeError, so raise the proper exception class instead.
    raise NotImplementedError('model not implemented')

In [None]:
# compute median and 95% percentiles
# (median and 2.5/97.5 percentiles are taken over the first axis of `y_sim`
# and drawn as a line plus a shaded band)
if 'pharma' in model_name:
    # pharma has two observables
    y_median_1 = np.median(y_sim[:, :, 0], axis=0)
    y_perc_1 = np.percentile(y_sim[:, :, 0], (2.5, 97.5), axis=0)
    y_median_2 = np.median(y_sim[:, :, 1], axis=0)
    y_perc_2 = np.percentile(y_sim[:, :, 1], (2.5, 97.5), axis=0)
    plt.plot(t_measurements_full, y_median_1, color='orange', alpha=0.2)
    plt.fill_between(t_measurements_full, y_perc_1[0], y_perc_1[1], color='orange', alpha=0.2)
    plt.plot(t_measurements_full, y_median_2, color='red', alpha=0.2)
    plt.fill_between(t_measurements_full, y_perc_2[0], y_perc_2[1], color='red', alpha=0.2)

    # both observables share the legend label 'measurements' and differ only in colour
    plt.scatter(t_measurements, y[:, 0], label='measurements', color='orange')
    plt.scatter(t_measurements, y[:, 1], label='measurements', color='red')
else:
    y_median = np.median(y_sim, axis=0)
    y_perc = np.percentile(y_sim, (2.5, 97.5), axis=0)
    plt.plot(t_measurements_full, y_median, color='orange', alpha=0.2)
    plt.fill_between(t_measurements_full, y_perc[0], y_perc[1], color='orange', alpha=0.2)

    plt.scatter(t_measurements, y, label='measurements')

# `doses_time_points` is only set for the clairon / pharma models in the simulation cell
if 'clairon' in model_name or 'pharma' in model_name:
    plt.vlines(doses_time_points, 0, np.max(y), color='grey', alpha=0.2, label='doses')
plt.legend()
plt.title(f'Patient {individual_idx}')
plt.show()

## All Individuals

In [None]:
# load pypesto result
covariance_format = ['diag', 'cholesky'][0]  # index 0 selects 'diag'
filename_result_population = f'../output/{model_name}-{covariance_format}-n_data_{n_data}.hdf5'
result_optimization = store.read_result(filename_result_population)
# first stored parameter vector (presumably the best start — TODO confirm the ordering of optimize_result)
best_res = result_optimization.optimize_result.x[0]

# the first n_params entries are used as the population mean; the remaining
# entries are handed to get_covariance to build the population covariance
pop_mean = best_res[:individual_model.n_params]
pop_cov = get_covariance(best_res[individual_model.n_params:],
                         covariance_format=covariance_format, param_dim=individual_model.n_params)

In [None]:
# Build the optimisation bounds from the individual-level prior.
param_bounds = create_boundaries_from_prior(
    prior_mean=individual_model.prior_mean,
    prior_std=individual_model.prior_std,
    prior_type=individual_model.prior_type,
    # `prior_bounds` is optional on the model object; pass None when it is absent
    prior_bounds=getattr(individual_model, 'prior_bounds', None),
    boundary_width_from_prior=5,
    covariance_format=covariance_format,
)
# Only the bounds of the mean parameters (the first n_params columns) are needed:
# row 0 is used as the lower bound, row 1 as the upper bound.
lower_bound, upper_bound = (
    param_bounds[row, :individual_model.n_params] for row in (0, 1)
)

In [None]:
# Parameters whose name contains 'error' or 'sigma' are kept fixed at their
# population value; every other parameter is free.
fixed_indices = [
    idx
    for idx, param_name in enumerate(individual_model.param_names)
    if any(keyword in param_name for keyword in ('error', 'sigma'))
]
free_indices = [idx for idx in range(individual_model.n_params) if idx not in fixed_indices]
fixed_values = pop_mean[fixed_indices]
# start guess: population mean of the free parameters
guess_val = pop_mean[free_indices]

In [None]:
def get_empirical_bayes_helper(individual_idx: int, 
                                n_start: int,
                                obs_data: np.ndarray,
                                pop_mean: np.ndarray,
                                pop_cov: np.ndarray,
                                lower_bound: np.ndarray,
                                upper_bound: np.ndarray,
                                fixed_indices: np.ndarray,
                                fixed_values: np.ndarray,
                                verbose: bool = False,
                                get_likelihood: bool = False) -> 'pandas.DataFrame':
    """Run a multi-start optimisation of the empirical Bayes objective for one individual.

    Besides its arguments the function reads the notebook globals
    ``model_name``, ``individual_model`` and ``guess_val`` as well as the
    simulator helpers imported for the selected model (``batch_simulator`` or
    ``simulate_single_patient`` / ``convert_bf_to_observables``).

    Args:
        individual_idx: Index into ``obs_data`` of the individual to fit.
        n_start: Number of optimiser starts passed to pypesto.
        obs_data: Observations of all individuals.
        pop_mean: Population mean; its noise entries define ``sigmas``.
        pop_cov: Population covariance forwarded to the objective.
        lower_bound: Lower parameter bounds of the pypesto problem.
        upper_bound: Upper parameter bounds of the pypesto problem.
        fixed_indices: Indices of parameters held fixed during optimisation.
        fixed_values: Values of the fixed parameters.
        verbose: If True, print the parameter summary and show a progress bar.
        get_likelihood: Forwarded as ``ignore_conditional`` to the objective.

    Returns:
        ``result.optimize_result.as_dataframe()`` of the pypesto optimisation.

    Raises:
        NotImplementedError: If ``model_name`` matches none of the supported models.
    """
    # create objective function
    # prepare batch simulator
    noise_type = 'multiplicative'
    if 'fröhlich' in model_name:
        partial_batch_simulator = partial(batch_simulator,
                                          with_noise=False)  # only one simulation, so format should be (#simulations)
        y = obs_data[individual_idx].flatten()
        # error covariance: exponential of the last population-mean entry
        sigmas = np.exp(pop_mean[-1])
    elif 'clairon' in model_name or 'pharma' in model_name:
        partial_batch_simulator = partial(simulate_single_patient,
                                          patient_data=obs_data[individual_idx],
                                          with_noise=False)
        observations = convert_bf_to_observables(obs_data[individual_idx])
        # this is the same for the clairon and pharma model
        y = observations[0]

        if 'clairon' in model_name:
            # constant term plus a term proportional to the observations
            sigmas = np.exp(pop_mean[-2]) + np.exp(pop_mean[-1])*y
            noise_type = 'additive'  # with proportional variance
        else:
            # squared exp of entries 8:10, repeated once per row of y
            sigmas = np.array([np.exp(pop_mean[8:10])**2] * y.shape[0])
    else:
        # `raise NotImplemented` would fail with a TypeError because
        # NotImplemented is not an exception class.
        raise NotImplementedError('model not implemented')

    eb_obj_fun = ObjectiveFunctionEmpiricalBayes(
        data=y,
        pop_mean=pop_mean,
        pop_cov=pop_cov,
        sigmas=sigmas,
        batch_simulator=partial_batch_simulator,
        noise_type=noise_type,
        ignore_conditional=get_likelihood
    )

    # the objective is wrapped in pypesto's FD class before it is handed to the problem
    pesto_objective = FD(obj=Objective(fun=eb_obj_fun,
                                       x_names=individual_model.param_names))
    pesto_problem = Problem(objective=pesto_objective,
                            lb=lower_bound, ub=upper_bound,
                            x_fixed_indices=fixed_indices,
                            x_fixed_vals=fixed_values,
                            x_names=individual_model.param_names,
                            x_scales=['log']*individual_model.n_params,
                            x_guesses=[guess_val]  # notebook global: start guess for the free parameters
                            )
    if verbose:
        # NOTE(review): if print_parameter_summary prints by itself and returns
        # None, the outer print only adds a 'None' line — confirm and simplify.
        print(pesto_problem.print_parameter_summary())

    result = pesto_opt.minimize(
        problem=pesto_problem,
        optimizer=pesto_opt.ScipyOptimizer(),
        # engine=engine.MultiProcessEngine(10), # not working due to pickling issues
        n_starts=n_start,
        progress_bar=verbose)
    return result.optimize_result.as_dataframe()

# Bind everything except the individual index, so callers only pass the index
# (keyword arguments such as n_start or verbose can still be overridden per call).
get_empirical_bayes = partial(
    get_empirical_bayes_helper,
    n_start=10,
    obs_data=obs_data,
    pop_mean=pop_mean,
    pop_cov=pop_cov,
    lower_bound=lower_bound,
    upper_bound=upper_bound,
    fixed_indices=fixed_indices,
    fixed_values=fixed_values,
)

# Same bindings plus get_likelihood=True (forwarded as `ignore_conditional`
# to the objective); a partial of a partial carries over the bound arguments.
get_likelihoods = partial(get_empirical_bayes, get_likelihood=True)

In [None]:
%%time
individual_idx = 10
# single verbose start; n_start=1 overrides the n_start=10 bound in the partial
result = get_empirical_bayes(individual_idx, n_start=1, verbose=True)

In [None]:
# Draw 50 amortized posterior samples for the same individual and take the median over axis 0.
# NOTE(review): the indexing below assumes the samples come back as
# (n_samples, n_params) for a single individual — TODO confirm.
posterior_samples = individual_model.draw_posterior_samples(data=obs_data[individual_idx][np.newaxis, :], n_samples=50)
posterior_median = np.median(posterior_samples, axis=0)

# compare the first row of the optimisation result, the posterior median and the
# population mean on the free parameters; the fixed entries are printed separately
print('empirical bayes ', result.x[0][free_indices])
print('posterior       ', posterior_median[free_indices])
print('population      ', pop_mean[free_indices])
print('error population', pop_mean[fixed_indices])

In [None]:
%%time
# Empirical Bayes estimates for every individual, cached on disk.
filename = f'../output/empirical_bayes-{model_name}-{covariance_format}-n_data_{n_data}.pkl'

if not os.path.exists(filename):
    # Compute the estimates for ALL individuals. The previous debug limit of
    # two individuals wrote a truncated list into a cache file whose name
    # claims n_data entries, and every later run silently reloaded it.
    empirical_bayes_res_full = [get_empirical_bayes(i) for i in range(len(obs_data))]

    with open(filename, 'wb') as f:
        pickle.dump(empirical_bayes_res_full, f)
else:
    # NOTE: pickle executes code on load; only read cache files written by this notebook.
    with open(filename, 'rb') as f:
        empirical_bayes_res_full = pickle.load(f)

In [None]:
%%time
# Individual likelihood-only estimates for every individual, cached on disk.
filename_l = f'../output/likelihoods-{model_name}-{covariance_format}-n_data_{n_data}.pkl'

if not os.path.exists(filename_l):
    # Compute the estimates for ALL individuals. The previous debug limit of
    # two individuals wrote a truncated list into a cache file whose name
    # claims n_data entries, and every later run silently reloaded it.
    likelihoods_res_full = [get_likelihoods(i) for i in range(len(obs_data))]

    with open(filename_l, 'wb') as f:
        pickle.dump(likelihoods_res_full, f)
else:
    # NOTE: pickle executes code on load; only read cache files written by this notebook.
    with open(filename_l, 'rb') as f:
        likelihoods_res_full = pickle.load(f)

In [None]:
# Keep only the estimates: the 'x' column of every per-individual result table
# becomes one array, and these are stacked into a single array
# (later indexed as [individual, start, parameter]).
empirical_bayes_res = np.array(
    [np.array(list(result_table['x'])) for result_table in empirical_bayes_res_full]
)

In [None]:
# Compare the empirical random effects (best start minus population mean)
# of the free parameters with the estimated population variances.
print(np.array(individual_model.param_names)[free_indices])
random_vars = empirical_bayes_res[:, 0, free_indices] - pop_mean[free_indices]
print(np.mean(random_vars, axis=0))
# Take the diagonal first and then select the free entries. Slicing rows
# first (pop_cov[free_indices, :]) pairs row free_indices[k] with column k,
# which is only correct when the free parameters are exactly the first ones.
print(pop_cov.diagonal()[free_indices])
# atleast_2d keeps this working when only one parameter is free (np.cov is 0-d then)
print(np.atleast_2d(np.cov(random_vars.T)).diagonal())

In [None]:
# per-parameter normality test of the random effects (column-wise, axis=0);
# True where the p-value is at least 0.05
normaltest(random_vars, axis=0).pvalue >= 0.05 # null hypothesis cannot be rejected

In [None]:
# Simulate every individual with the different parameter estimates and collect
# the predictions next to the observations.
# NOTE(review): `convert_to_observables`, `model`, `likelihood_res`, `map_res`,
# `monolix_res` and `pop_params` are not defined in the visible part of this
# notebook (the cells above provide `convert_bf_to_observables` and
# `individual_model`), and `batch_simulator` is only imported for the
# 'fröhlich' models — confirm where these names come from before running.
n_samples = 50
with_noise = False
obs = []
pred_median = []
posterior_median = []
posterior_all_samples = []
pred_empirical_bayes = []
pred_likelihood = []
pred_map = []
pred_monolix = []
pred_pop = []

for i in tqdm(range(len(obs_data))):
    y, t_measurements, dose_amount, doses_time_points = convert_to_observables(obs_data[i])
    obs.append(y)

    # compute individual posterior-median of simulations
    posterior_samples_i = model.draw_posterior_samples(data=obs_data[i][np.newaxis, :],
                                                       n_samples=n_samples)
    posterior_median_i = np.median(posterior_samples_i, axis=0)
    posterior_median.append(posterior_median_i)
    posterior_all_samples.append(posterior_samples_i)
    sim_data_i = batch_simulator(posterior_samples_i,
                                 t_measurements=t_measurements,
                                 t_doses=doses_time_points,
                                 dose_amount=dose_amount,
                                 with_noise=with_noise,
                                 convert_to_bf_batch=False)
    pred_median.append(np.median(sim_data_i, axis=0))

    # simulate with the best empirical bayes
    # (index 0 along the start axis of empirical_bayes_res)
    sim_data_eb = batch_simulator(empirical_bayes_res[i, 0][np.newaxis, :],
                                t_measurements=t_measurements,
                                t_doses=doses_time_points,
                                dose_amount=dose_amount,
                                with_noise=with_noise,
                                 convert_to_bf_batch=False)
    pred_empirical_bayes.append(sim_data_eb)

    # simulate with likelihood estimate
    sim_data_l = batch_simulator(likelihood_res[i][np.newaxis, :],
                                t_measurements=t_measurements,
                                t_doses=doses_time_points,
                                dose_amount=dose_amount,
                                with_noise=with_noise,
                                 convert_to_bf_batch=False)
    pred_likelihood.append(sim_data_l)

    # simulate with map
    sim_data_map = batch_simulator(map_res[i][np.newaxis, :],
                                t_measurements=t_measurements,
                                t_doses=doses_time_points,
                                dose_amount=dose_amount,
                                with_noise=with_noise,
                                 convert_to_bf_batch=False)
    pred_map.append(sim_data_map)

    # simulate with monolix
    sim_data_monolix = batch_simulator(monolix_res[i][np.newaxis, :],
                                t_measurements=t_measurements,
                                t_doses=doses_time_points,
                                dose_amount=dose_amount,
                                with_noise=with_noise,
                                 convert_to_bf_batch=False)
    pred_monolix.append(sim_data_monolix)

    # simulate with covariates but no random effects (population)
    sim_data_pop = batch_simulator(pop_params[np.newaxis, :],
                                t_measurements=t_measurements,
                                t_doses=doses_time_points,
                                dose_amount=dose_amount,
                                with_noise=with_noise,
                                 convert_to_bf_batch=False)
    pred_pop.append(sim_data_pop)

# stack the per-individual posterior medians into one array
posterior_median = np.array(posterior_median)

In [None]:
# Fixed seed so that the same 15 distinct individuals are picked on every run.
np.random.seed(40)
rand_p_ids = np.random.choice(np.arange(len(obs_data)), size=15, replace=False)

In [None]:
# choose random patients and plot their fits
# NOTE(review): `convert_to_observables` and `monolix_res` are not defined in
# the visible part of this notebook, and `batch_simulator` is only imported
# for the 'fröhlich' models — confirm before running.
rows = 3
fig, ax = plt.subplots(rows, int(np.ceil(len(rand_p_ids) / rows)), sharex='all', sharey='all',
                       tight_layout=True, figsize=(10, rows*3))
axis = ax.flatten()

for ax_i, p_id in tqdm(enumerate(rand_p_ids), total=len(rand_p_ids)):
    y, t_measurements, dose_amount, doses_time_points = convert_to_observables(obs_data[p_id])
    # hard-coded plotting grid: 100 points on [0, 700]
    t_measurements_full = np.linspace(0, 700, 100)
    # iterate the starts in reverse so that start 0 is drawn last (solid, opaque, on top)
    for j_start in reversed(range(empirical_bayes_res.shape[1])):
        sim_data = batch_simulator(empirical_bayes_res[p_id, j_start][np.newaxis, :],
                                   t_measurements=t_measurements_full,
                                   t_doses=doses_time_points,
                                   dose_amount=dose_amount,
                                   with_noise=False, 
                                   convert_to_bf_batch=False)

        eb_handle, = axis[ax_i].plot(t_measurements_full, sim_data, 'b', 
                                     linestyle='--' if j_start > 0 else '-',
                                     alpha=0.2 if j_start > 0 else 1,
                                     label='Empirical Bayes')
    data_handle = axis[ax_i].scatter(t_measurements, y, color='g', label='data')
    # dose times as dashed vertical lines with a hard-coded height of 2500
    axis[ax_i].vlines(doses_time_points, 0, 2500, color='grey', linestyles='--', alpha=0.5)

    # empirical bayes based on monolix
    sim_data = batch_simulator(monolix_res[p_id][np.newaxis, :],
                               t_measurements=t_measurements_full,
                               t_doses=doses_time_points,
                               dose_amount=dose_amount,
                               with_noise=False, 
                               convert_to_bf_batch=False)

    monolix_handle, = axis[ax_i].plot(t_measurements_full, sim_data, 'r', label='EB - Monolix')
# one shared legend built from the handles of the last subplot
fig.legend(handles=[data_handle, eb_handle, monolix_handle], loc='lower center',
           bbox_to_anchor=(0.5, -0.05), ncol=3)
#plt.savefig(f'plots/comparison_fits_clairon.png')
plt.show()

In [None]:
# Observed vs. simulated values, one panel per estimate type.
# Each entry: (predictions, scatter label, panel title)
fit_panels = (
    (pred_median, 'Median of Individual Posterior', 'Median of Individual Posterior'),
    (pred_empirical_bayes, 'Empirical Bayes', 'Empirical Bayes'),
    (pred_pop, 'Population', 'Population (no random effect)'),
)

fig, ax = plt.subplots(1, 3, figsize=(15,5), sharey='all', sharex='all', tight_layout=True)
for panel_ax, (prediction, scatter_label, panel_title) in zip(ax, fit_panels):
    panel_ax.scatter(prediction, obs, color='orange', label=scatter_label)
    panel_ax.set_xlabel('Simulation')
    panel_ax.set_title(panel_title)
    panel_ax.set_aspect('equal', 'box')

# the y axis is shared, so only the first panel gets a label
ax[0].set_ylabel('Measurements')
#plt.savefig(f'plots/empirical_bayes_clairon{"_no" if obj_fun_amortized.covariance_format == "diag" else ""}_corr.png')
plt.show()

In [None]:
# compute shrinkage for every parameter (1 - variance of eta / variance of population), i.e. 1 is full shrinkage
# NOTE(review): `model`, `pop_params`, `monolix_res`, `pop_mean_monolix` and
# `pop_cov_monolix` are not defined in the visible part of this notebook — confirm before running.
shrinkage = []
shrinkage_posterior = []
shrinkage_monolix = []
# diagonals of the covariance matrices, used as the random-effect variances
var_random_effects = pop_cov.diagonal()
var_random_effects_monolix = pop_cov_monolix.diagonal()

for p_i in range(model.n_params):
    # from empirical bayes
    # (best start, i.e. index 0 along the start axis)
    eta_i = empirical_bayes_res[:, 0, p_i] - pop_params[p_i]
    shrinkage_i = 1 - np.var(eta_i) / var_random_effects[p_i]
    shrinkage.append(shrinkage_i)

    # from posterior
    eta_pos_i = posterior_median[:, p_i] - pop_params[p_i]
    shrinkage_i = 1 - np.var(eta_pos_i) / var_random_effects[p_i]
    shrinkage_posterior.append(shrinkage_i)

    # for monolix
    eta_monolix_i = monolix_res[:, p_i] - pop_mean_monolix[p_i]
    shrinkage_i = 1 - np.var(eta_monolix_i) / var_random_effects_monolix[p_i]
    shrinkage_monolix.append(shrinkage_i)

In [None]:
# compute correlation between random effects and covariates
# (column 0 of `covariates` is treated as age, column 1 as gender)
# NOTE(review): `model`, `pop_params` and `covariates` are not defined in the
# visible part of this notebook — confirm before running.
correlation_age = []
correlation_gender = []
correlation_age_posterior = []
correlation_gender_posterior = []
for p_i in range(model.n_params):
    # from empirical bayes
    eta_i = empirical_bayes_res[:, 0, p_i] - pop_params[p_i]
    # [0, 1] picks the off-diagonal entry of the 2x2 correlation matrix
    corr_i_age = np.corrcoef(eta_i, covariates[:, 0])[0, 1]
    corr_i_gender = np.corrcoef(eta_i, covariates[:, 1])[0, 1]
    correlation_age.append(corr_i_age)
    correlation_gender.append(corr_i_gender)

    # from posterior
    eta_pos_i = posterior_median[:, p_i] - pop_params[p_i]
    corr_i_age = np.corrcoef(eta_pos_i, covariates[:, 0])[0, 1]
    corr_i_gender = np.corrcoef(eta_pos_i, covariates[:, 1])[0, 1]
    correlation_age_posterior.append(corr_i_age)
    correlation_gender_posterior.append(corr_i_gender)

# indexed as [covariate][parameter]
correlation = [correlation_age, correlation_gender]
correlation_posterior = [correlation_age_posterior, correlation_gender_posterior]

In [None]:
# Histograms of the individual estimates (one column per parameter, one row
# per estimation method) overlaid with the estimated population densities.
# The last two parameters are left out of the figure.
# NOTE(review): `model`, `param_names`, `prior_mean`, `prior_std`,
# `monolix_res`, `likelihood_res`, `pop_mean_monolix` and `pop_cov_monolix`
# are not defined in the visible part of this notebook — confirm before running.
fig, ax = plt.subplots(7, model.n_params-2, figsize=(15,20), sharex='col',  sharey='row',
                       tight_layout=True)
for p_id in range(model.n_params-2):
    ax[0, p_id].hist(monolix_res[:, p_id], density=True, bins=20, color='red', alpha=0.5)
    # '\\log' keeps the literal backslash; a bare '\l' is an invalid escape sequence
    ax[0, p_id].set_title(f'$\\log$ {param_names[p_id]}\nshrinkage {shrinkage_monolix[p_id]:.2f}')

    ax[1, p_id].hist(empirical_bayes_res[:, 0, p_id], density=True, bins=20, color='blue', alpha=0.5)
    ax[1, p_id].set_title(f'shrinkage {shrinkage[p_id]:.2f}')

    ax[2, p_id].hist(empirical_bayes_res[:, :, p_id].flatten(),
                     density=True, bins=20, color='blue', alpha=0.5)

    ax[3, p_id].hist(likelihood_res[:, p_id], density=True, bins=20, color='blue', alpha=0.5)
    #ax[3, p_id].set_title(f'$\log$ {param_names[p_id]}\n{shrinkage[p_id]:.2f}')

    # ax[4, p_id].hist(map_res[:, p_id], density=True, bins=20, color='blue', alpha=0.5) # todo: still old prior
    #ax[4, p_id].set_title(f'$\log$ {param_names[p_id]}\n{shrinkage[p_id]:.2f}')

    ax[5, p_id].hist(posterior_median[:, p_id], density=True, bins=20, color='blue', alpha=0.5)

    temp = np.concatenate([s[:, p_id] for s in posterior_all_samples])
    ax[6, p_id].hist(temp, density=True, bins=20, color='blue', alpha=0.5)
    #ax[6, p_id].set_title(f'$\log$ {param_names[p_id]}\n{shrinkage[p_id]:.2f}')

    # plot expected distribution
    # 2.58 is the two-sided 99% quantile of the standard normal, so it has to
    # scale the standard deviation; the diagonal of pop_cov is the variance
    # (the pdf below takes its square root), which was used unscaled before.
    pop_std_p = np.sqrt(pop_cov.diagonal()[p_id])
    x = np.linspace(pop_mean[p_id] - 2.58*pop_std_p,
                    pop_mean[p_id] + 2.58*pop_std_p, 100)
    for i in range(1, ax.shape[0]):
        pop_handle, = ax[i, p_id].plot(x, stats.norm.pdf(x, pop_mean[p_id], pop_std_p),
                      color='blue', label='Estimated Population')
    # plot prior
    if model.prior_type == 'gaussian':
        prior_handle, = ax[6, p_id].plot(x, stats.norm.pdf(x, prior_mean[p_id], prior_std[p_id]),
                         color='orange', label='Individual Prior', linestyle='--')
        prior_handle, = ax[4, p_id].plot(x, stats.norm.pdf(x, prior_mean[p_id], prior_std[p_id]),
                         color='orange', label='Individual Prior', linestyle='--')

    # monolix (range of +-5 standard deviations, same variance-vs-std fix as above)
    monolix_std_p = np.sqrt(pop_cov_monolix.diagonal()[p_id])
    x = np.linspace(pop_mean_monolix[p_id] - 5*monolix_std_p,
                    pop_mean_monolix[p_id] + 5*monolix_std_p, 100)
    monolix_handle, = ax[0, p_id].plot(x, stats.norm.pdf(x, pop_mean_monolix[p_id], monolix_std_p),
                     color='red', label='Monolix Population')
    #ax[0, p_id].set_xlim(lower_bound[p_id], upper_bound[p_id])

# plot bounds for empirical bayes estimates
# (y-limits are read once from the first column; rows share their y axis)
ax0_ylim = ax[0, 0].get_ylim()
ax1_ylim = ax[1, 0].get_ylim()
ax5_ylim = ax[-1, 0].get_ylim()  # limits of the LAST row, despite the name
for p_id in range(model.n_params-2):
    bound_handle = ax[0, p_id].vlines(lower_bound[p_id], ax0_ylim[0], ax0_ylim[1], color='green', linestyle='--', label='Bounds')
    ax[0, p_id].vlines(upper_bound[p_id], ax0_ylim[0], ax0_ylim[1], color='green', linestyle='--')
    ax[1, p_id].vlines(lower_bound[p_id], ax1_ylim[0], ax1_ylim[1], color='green', linestyle='--')
    ax[1, p_id].vlines(upper_bound[p_id], ax1_ylim[0], ax1_ylim[1], color='green', linestyle='--')

    if model.prior_type == 'uniform':
        prior_handle = ax[6, p_id].vlines(model.prior_bounds[p_id, 0], ax5_ylim[0], ax5_ylim[1],
                             color='orange', label='Individual Prior', linestyle='--')
        prior_handle = ax[6, p_id].vlines(model.prior_bounds[p_id, 1], ax5_ylim[0], ax5_ylim[1],
                             color='orange', label='Individual Prior', linestyle='--')

ax[0, 0].set_ylabel('EB - Monolix')
ax[1, 0].set_ylabel('Empirical Bayes')
ax[2, 0].set_ylabel('Empirical Bayes\nAll Starts')
ax[3, 0].set_ylabel('Individual Likelihood')
ax[4, 0].set_ylabel('Individual Map')
ax[5, 0].set_ylabel('Individual\nPosterior Median')
ax[6, 0].set_ylabel('Combined\nIndividual Posterior\nSamples')
# NOTE(review): prior_handle only exists when prior_type is 'gaussian' or 'uniform'
fig.legend(handles=[pop_handle, monolix_handle, prior_handle, bound_handle],
           loc='lower center', bbox_to_anchor=(0.5, -0.02), ncols=4)

#plt.savefig(f'plots/comparison_histograms_clairon.png')
plt.show()

In [None]:
# Two-row comparison per parameter: Monolix estimates (top) against the
# empirical Bayes estimates of all starts (bottom).
# NOTE(review): `model`, `param_names`, `monolix_res`, `pop_mean_monolix` and
# `pop_cov_monolix` are not defined in the visible part of this notebook — confirm before running.
fig, ax = plt.subplots(2,
                       model.n_params-2, figsize=(15,10), sharex='col', sharey='row',
                       tight_layout=True)
for p_id in range(model.n_params-2):
    ax[0, p_id].hist(monolix_res[:, p_id], density=True, bins=50, color='red', alpha=0.5)
    # '\\log' keeps the literal backslash; a bare '\l' is an invalid escape sequence
    ax[0, p_id].set_title(f'$\\log$ {param_names[p_id]}\nshrinkage {shrinkage_monolix[p_id]:.2f}')

    ax[1, p_id].hist(empirical_bayes_res[:, :, p_id].flatten(),
                     density=True, bins=50, color='blue', alpha=0.5)
    ax[1, p_id].set_title(f'shrinkage {shrinkage[p_id]:.2f}')

    # plot expected distribution
    # 2.58 is the two-sided 99% quantile of the standard normal, so it has to
    # scale the standard deviation; the diagonal of pop_cov is the variance
    # (the pdf below takes its square root), which was used unscaled before.
    pop_std_p = np.sqrt(pop_cov.diagonal()[p_id])
    x = np.linspace(pop_mean[p_id] - 2.58*pop_std_p,
                    pop_mean[p_id] + 2.58*pop_std_p, 100)
    for i in range(1, ax.shape[0]):
        pop_handle, = ax[i, p_id].plot(x, stats.norm.pdf(x, pop_mean[p_id], pop_std_p),
                      color='blue', label='Estimated Population')

    # monolix (range of +-5 standard deviations, same variance-vs-std fix as above)
    monolix_std_p = np.sqrt(pop_cov_monolix.diagonal()[p_id])
    x = np.linspace(pop_mean_monolix[p_id] - 5*monolix_std_p,
                    pop_mean_monolix[p_id] + 5*monolix_std_p, 100)
    monolix_handle, = ax[0, p_id].plot(x, stats.norm.pdf(x, pop_mean_monolix[p_id], monolix_std_p),
                     color='red', label='Monolix Population')

ax[0, 0].set_ylabel('EB - Monolix')
ax[1, 0].set_ylabel('Empirical Bayes')

#fig.legend(handles=[pop_handle, monolix_handle, prior_handle, bound_handle], 
#           loc='lower center', bbox_to_anchor=(0.5, -0.02), ncols=4)

#plt.savefig(f'plots/comparison_histograms_clairon.png')
plt.show()

In [None]:
# plotting all samples
# Pool the posterior samples of all individuals along the first axis.
# np.concatenate accepts the list directly; wrapping it in np.array first
# only added a copy and fails if the per-individual arrays differ in length.
# NOTE(review): `model` and `param_names` are not defined in the visible part
# of this notebook — confirm before running.
test = np.concatenate(posterior_all_samples, axis=0)

fig, ax = plt.subplots(1, model.n_params-2, figsize=(15,5), tight_layout=True)
for p_id in range(model.n_params-2):
    handle_pos = ax[p_id].hist(test[:, p_id], density=True, label='all individual posterior samples')
    # '\\log' keeps the literal backslash; a bare '\l' is an invalid escape sequence
    ax[p_id].set_title(f'$\\log$ {param_names[p_id]}')

    # plot expected distribution
    # NOTE(review): the range is +-1 times the VARIANCE (diagonal of pop_cov),
    # while the pdf uses its square root — confirm the intended width.
    x = np.linspace(pop_mean[p_id] - 1*pop_cov.diagonal()[p_id],
                    pop_mean[p_id] + 1*pop_cov.diagonal()[p_id], 100)
    handle_pop, = ax[p_id].plot(x, stats.norm.pdf(x, pop_mean[p_id], np.sqrt(pop_cov.diagonal()[p_id])),
                  color='blue', label='Estimated Population')

    # Span the prior bounds over THIS subplot's y-range; the limits used
    # before (`ax5_ylim`) belonged to a different figure from an earlier cell.
    y_low, y_high = ax[p_id].get_ylim()
    handle_prior = ax[p_id].vlines(model.prior_bounds[p_id, 0], y_low, y_high,
                             color='orange', label='Individual Prior', linestyle='--')
    ax[p_id].vlines(model.prior_bounds[p_id, 1], y_low, y_high,
                             color='orange', label='Individual Prior', linestyle='--')

    #ax[p_id].set_xlim(lower_bound[p_id], upper_bound[p_id])
ax[0].set_ylabel('Combined\nIndividual Posterior\nSamples')
fig.legend(['all individual posterior samples', 'Estimated Population', 'Individual Prior'], loc='lower center', bbox_to_anchor=(0.5, -0.06), ncols=3)
plt.show()

In [None]:
# plot covariates against individual parameters
# One figure per covariate: random effects from empirical Bayes (left column)
# and from the individual posterior median (right column), one row per parameter.
# NOTE(review): `model`, `covariates`, `pop_params` and `param_names` are not
# defined in the visible part of this notebook — confirm before running.
covariate_names = ['age', 'gender']

for covariate_i, c in enumerate(covariate_names):
    #y_axis_scale = lambda x: np.exp(x)
    y_axis_scale = lambda x: x

    fig, ax = plt.subplots(model.n_params, 2, figsize=(15,15), sharey='all', sharex='all', tight_layout=True)
    for p_id in range(model.n_params):
        # empirical_bayes_res is indexed [individual, start, parameter]; take
        # parameter p_id of the best start (index 0), as the shrinkage and
        # correlation cells do. `[:, p_id]` selected start p_id instead.
        ax[p_id, 0].scatter(covariates[:, covariate_i],
                            y_axis_scale(empirical_bayes_res[:, 0, p_id])-y_axis_scale(pop_params[p_id]),
                            color='orange')
        ax[p_id, 1].scatter(covariates[:, covariate_i],
                            y_axis_scale(posterior_median[:, p_id])-y_axis_scale(pop_params[p_id]),
                            color='blue')
        # '\\log' keeps the literal backslash; a bare '\l' is an invalid escape sequence
        ax[p_id, 0].set_ylabel(f'$\\log$ {param_names[p_id]}')
        # annotate each panel (axes coordinates) with shrinkage and correlation
        ax[p_id, 0].text(0.6, 0.5, f'shrinkage {shrinkage[p_id]:.2f}\n'
                                     f'correlation {correlation[covariate_i][p_id]:.2f}',
                         transform=ax[p_id, 0].transAxes)
        ax[p_id, 1].text(0.6, 0.5, f'shrinkage {shrinkage_posterior[p_id]:.2f}\n'
                                        f'correlation {correlation_posterior[covariate_i][p_id]:.2f}',
                         transform=ax[p_id, 1].transAxes)
        # ax[p_id, 0].hlines(y_axis_scale(lower_bound[p_id])-y_axis_scale(pop_params[p_id]), 
        #                   np.min(covariates[:, covariate_i]), np.max(covariates[:, covariate_i]),
        #                   color='red', linestyles='--')
        # ax[p_id, 1].hlines(y_axis_scale(lower_bound[p_id])-y_axis_scale(pop_params[p_id]), 
        #                   np.min(covariates[:, covariate_i]), np.max(covariates[:, covariate_i]),
        #                   color='red', linestyles='--')
        # ax[p_id, 0].hlines(y_axis_scale(upper_bound[p_id])-y_axis_scale(pop_params[p_id]), 
        #                   np.min(covariates[:, covariate_i]), np.max(covariates[:, covariate_i]),
        #                   color='red', linestyles='--')
        # ax[p_id, 1].hlines(y_axis_scale(upper_bound[p_id])-y_axis_scale(pop_params[p_id]), 
        #                   np.min(covariates[:, covariate_i]), np.max(covariates[:, covariate_i]),
        #                   color='red', linestyles='--')

    ax[-1, 0].set_xlabel(f'{c}')
    ax[-1, 1].set_xlabel(f'{c}')
    ax[0, 0].set_title('Empirical Bayes')
    ax[0, 1].set_title('Individual Posterior Median\n(without population information)')

    #plt.savefig(f'plots/covariates_clairon{"_no" if obj_fun_amortized.covariance_format == "diag" else ""}_corr_{covariate_names[covariate_i]}.png')
    plt.show()