# Validate BayesFlow Posterior with MCMC

In this notebook we are going to validate the posterior from BayesFlow by comparing it to posteriors generated from MCMC.

In [None]:
import os
from functools import partial
from typing import Union
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from numba import njit
from pypesto import sample, optimize, visualize, FD, Objective, Problem, store
from tqdm import tqdm

In [None]:
# Notebook configuration: choose the model and the trained network to validate.
model_name = (
    'fröhlich-simple',
    'fröhlich-detailed',
    'pharmacokinetic_model',
    'clairon_small_model',
)[-1]  # change the index to switch models; -1 selects 'clairon_small_model'
network_idx = 0  # passed as `network_idx` to the model constructor
load_best_network = True  # passed as `load_best` to the model constructor

## Load individual model


In [None]:
# Import the class of the selected model together with its simulator helpers and instantiate it.
# The imports live inside the branches, so only the module of the chosen model is imported.
# `batch_simulator` is bound in every branch; `convert_bf_to_observables` is only bound for the
# pharmacokinetic and the Clairon model.
if model_name == 'fröhlich-simple':
    from models.froehlich_model_simple import FroehlichModelSimple, batch_simulator
    model = FroehlichModelSimple(network_idx=network_idx, load_best=load_best_network)
    
elif model_name == 'fröhlich-detailed':
    from models.froehlich_model_detailed import FroehlichModelDetailed, batch_simulator
    model = FroehlichModelDetailed(network_idx=network_idx, load_best=load_best_network)

elif model_name == 'pharmacokinetic_model':
    from models.pharmacokinetic_model import PharmacokineticModel, batch_simulator, convert_bf_to_observables
    model = PharmacokineticModel(network_idx=network_idx, load_best=load_best_network)
    
elif model_name == 'clairon_small_model':
    from models.clairon_small_model import ClaironSmallModel, batch_simulator, convert_bf_to_observables
    # prior family handed to the model; index 0 selects 'normal'
    prior_type = ['normal', 'uniform'][0]
    model = ClaironSmallModel(network_idx=network_idx, load_best=load_best_network, prior_type=prior_type)
else:
    # fail early on an unknown model name
    raise NotImplementedError('model not implemented')

# load network
# The path is relative to the notebook's working directory.
# NOTE(review): build_trainer presumably restores the trained network from this folder - confirm.
trainer = model.build_trainer('../networks/' + model.network_name)

## Load Data

In [None]:
# Load the observations of the selected model; switch `load_synthetic` on to use synthetic data.
load_synthetic = False
obs_data = model.load_data(synthetic=load_synthetic)

# Pick 10 random individuals/cells with a fixed seed.
# Indices are drawn with replacement, so an individual may be selected more than once.
np.random.seed(42)
individual_ids = np.random.randint(0, len(obs_data), size=10)  # obs_data can be list or numpy array
obs_data = [obs_data[idx] for idx in individual_ids]

if load_synthetic:
    # For synthetic data the true parameters are known; read the table that belongs to the model.
    # NOTE(review): the table is not reduced to `individual_ids` here - confirm how it is indexed.
    cell_param_log = pd.read_csv(
        {
            'fröhlich-sde': '../data/synthetic/synthetic_individual_cell_params_sde_model.csv',
            'fröhlich-detailed': '../data/synthetic/synthetic_individual_cell_params_detailed_model.csv',
        }.get(model_name, '../data/synthetic/synthetic_individual_cell_params.csv'),
        index_col=0,
        header=0,
    )

## Examine Posterior for a Single Individual/Cell

In [None]:
# First look at the posterior: draw samples from the network for every selected individual.
n_bayesflow_samples = 1000
obs_data_posterior_samples = model.draw_posterior_samples(
    data=obs_data,
    n_samples=n_bayesflow_samples,
)

In [None]:
# Overview figure: one panel per selected individual, built by `model.prepare_plotting` from the
# individual's data and its first 100 posterior samples.
rows = 4
fig, ax = plt.subplots(
    nrows=rows,
    ncols=int(np.ceil(len(obs_data) / rows)),
    tight_layout=True,
    figsize=(10, rows * 3),
    sharex='row',
    sharey='all',
)
axis = ax.flatten()

for p_id, data_indv in tqdm(enumerate(obs_data), total=len(obs_data)):
    axis[p_id] = model.prepare_plotting(data_indv, obs_data_posterior_samples[p_id, :100], axis[p_id])
    # the legend entries of the last panel are reused for the shared figure legend
    _, labels = axis[p_id].get_legend_handles_labels()

# remove the panels of the grid that were not needed
for unused_ax in axis[len(obs_data):]:
    unused_ax.remove()

fig.legend(labels, ncol=3, loc='upper center', bbox_to_anchor=(0.5, 1))
plt.show()

## Prepare MCMC Posterior

First we need to define the likelihood and the prior we want to use for MCMC.
Note: BayesFlow works without specifying a likelihood since it is a simulation-based method.

In [None]:
@njit
def log_likelihood_multiplicative_noise(log_measurements: np.ndarray,
                                        log_simulations: np.ndarray,
                                        sigmas: Union[float, np.ndarray]) -> float:
    """
    Log-likelihood for multiplicative normal noise (log-normal distribution).

    :param log_measurements: measurements on log-scale, 1D or of shape (n_measurements, n_observables)
    :param log_simulations: noise-free simulations on log-scale, same shape as `log_measurements`
    :param sigmas: standard deviation(s) on log-scale; a scalar, one value per observable or one
        value per measurement - anything that broadcasts against `log_measurements`
    :return: log-likelihood summed over all measurements
    """
    dif_sum = np.sum(((log_measurements - log_simulations) / sigmas) ** 2)
    # Broadcast log(sigma^2) to the shape of the data, so that every single measurement contributes
    # exactly one variance term, no matter whether sigma is a scalar, per observable or per
    # measurement. This also avoids an `isinstance` check inside the jitted function.
    log_det_sigma = np.sum(np.zeros(log_measurements.shape) + np.log(sigmas ** 2))
    # log_measurements.size = n_measurements * n_observables
    # `- log_measurements.sum()` is the 1/y term of the log-normal density
    llh = (-0.5 * log_measurements.size * np.log(2 * np.pi) - 0.5 * log_det_sigma
           - log_measurements.sum() - 0.5 * dif_sum)
    return llh


@njit
def log_likelihood_additive_noise(measurements: np.ndarray,
                                  simulations: np.ndarray,
                                  sigmas: Union[float, np.ndarray]) -> float:
    """
    Log-likelihood for additive normal noise; a proportional error can be captured in `sigmas`.

    :param measurements: measurements, 1D or of shape (n_measurements, n_observables)
    :param simulations: noise-free simulations, same shape as `measurements`
    :param sigmas: standard deviation(s); a scalar, one value per observable or one value per
        measurement - anything that broadcasts against `measurements`
    :return: log-likelihood summed over all measurements
    """
    dif_sum = np.sum(((measurements - simulations) / sigmas) ** 2)
    # Broadcast log(sigma^2) to the shape of the data, so that every single measurement contributes
    # exactly one variance term. Multiplying sum(log(sigma^2)) by len(measurements) is only valid
    # for one sigma per observable and over-counts the term for per-measurement sigmas.
    log_det_sigma = np.sum(np.zeros(measurements.shape) + np.log(sigmas ** 2))
    llh = -0.5 * measurements.size * np.log(2 * np.pi) - 0.5 * log_det_sigma - 0.5 * dif_sum
    return llh

In [None]:
@njit
def log_prior_density_normal(log_param: np.ndarray,
                             mean: np.ndarray,
                             inv_cov_matrix: np.ndarray,
                             prior_constant: float) -> float:
    """
    Log density of the normal prior on the log-parameters.

    :param log_param: parameter vector on log-scale
    :param mean: prior mean on log-scale
    :param inv_cov_matrix: inverse of the prior covariance matrix
    :param prior_constant: parameter-independent normalization term of the log density
    :return: prior_constant - 0.5 * (x - mean)^T inv_cov (x - mean) - sum(x)
    """
    deviation = log_param - mean
    mahalanobis_sq = np.dot(np.dot(deviation, inv_cov_matrix), deviation.T)
    # NOTE(review): `- log_param.sum()` is the change-of-variables term of a log-normal density on
    # exp(log_param); confirm this matches the space in which the sampler operates.
    log_jacobian = log_param.sum()
    return prior_constant - 0.5 * mahalanobis_sq - log_jacobian


@njit
def log_prior_density_uniform(log_param: np.ndarray,
                              prior_constant: float) -> float:
    """
    Log density of the uniform prior on the log-parameters.

    :param log_param: parameter vector on log-scale
    :param prior_constant: parameter-independent normalization term of the log density
    :return: prior_constant - sum(log_param)
    """
    # NOTE(review): no check against the prior bounds is done here; the value is returned for
    # every input - confirm the bounds are enforced elsewhere.
    log_jacobian = log_param.sum()
    return prior_constant - log_jacobian


# Pre-compute the parameter-independent part of the log prior density once.
if model.prior_type == 'normal':
    # slogdet returns (sign, logabsdet); indexing works for the plain tuple of older NumPy
    # versions as well as for the named tuple of newer ones
    log_prior_constant = -0.5 * model.n_params * np.log(2 * np.pi) - 0.5 * np.linalg.slogdet(model.prior_cov)[1]
    inv_cov = np.linalg.inv(model.prior_cov)
elif model.prior_type == 'uniform':
    # NOTE(review): assumes np.diff over the last axis of prior_bounds yields the interval widths
    # (i.e. one [lower, upper] row per parameter) - confirm.
    log_prior_constant = -np.log(np.diff(model.prior_bounds).prod())
else:
    # fail here instead of leaving `log_prior_constant` undefined (or stale from an earlier run)
    raise NotImplementedError(f'prior type {model.prior_type!r} not implemented')

In [None]:
# Select the individual whose posterior is validated with MCMC.
# `individual_id` indexes the list of randomly selected individuals, not the full data set;
# it is also part of the file name under which the MCMC result is stored.
individual_id = 1  # patient 5 for pharma, fro-detailed 0
obs_data_indv = obs_data[individual_id]

In [None]:
# Build a noise-free simulator for the selected individual together with the matching noise model.
if 'Froehlich' in model.name:
    # data are on log-scale
    simulator = partial(batch_simulator, n_obs=180, with_noise=False)
    noise_model = 'multiplicative'  # additive on log-scale
    index_sigma = -1  # index of sigma in parameter vector
    obs_data_indv_prepared = obs_data_indv.flatten()  # just one measurement per time point, already on log-scale
elif 'Pharma' in model.name:
    # data are on log-scale
    obs_data_indv_prepared, t_measurement, doses_time_points, dos, wt = convert_bf_to_observables(obs_data_indv)
    simulator = partial(
        batch_simulator,
        t_measurement=t_measurement,
        t_doses=doses_time_points,
        wt=wt,
        dos=dos,
        with_noise=False,
        convert_to_bf_batch=False,
    )
    noise_model = 'multiplicative'  # additive on log-scale
    index_sigma = [-3, -2]  # index of sigmas in parameter vector
elif 'Clairon' in model.name:
    # data are on linear scale
    # NOTE(review): `dose_amount` is unpacked but not handed to the simulator - confirm this is intended.
    obs_data_indv_prepared, t_measurements, dose_amount, doses_time_points = convert_bf_to_observables(obs_data_indv)
    simulator = partial(
        batch_simulator,
        t_measurements=t_measurements,
        t_doses=doses_time_points,
        with_noise=False,
        convert_to_bf_batch=False,
    )
    noise_model = 'proportional'  # additive on linear scale
    index_sigma = [-2, -1]  # index of a, b in parameter vector of y+(a+by)*e
else:
    raise NotImplementedError('model not implemented')

# the likelihood compares simulation and data element-wise, so their shapes have to agree
assert simulator(model.prior_mean).shape == obs_data_indv_prepared.shape, 'simulator output shape does not match data shape'

In [None]:
def neg_log_prop_posterior(log_param: np.ndarray) -> float:
    """
    Negative log-posterior (up to the evidence) of the selected individual, used as MCMC objective.

    Reads the notebook globals `simulator`, `noise_model`, `index_sigma`, `obs_data_indv_prepared`,
    `model`, `log_prior_constant` and, for the normal prior, `inv_cov`.

    :param log_param: parameter vector on log-scale
    :return: -(log-likelihood + log-prior)
    """
    y_sim = simulator(log_param)
    if noise_model == 'multiplicative':
        llh = log_likelihood_multiplicative_noise(log_measurements=obs_data_indv_prepared,
                                                  log_simulations=y_sim,
                                                  sigmas=np.exp(log_param[index_sigma]))
    else:  # noise_model == 'proportional':
        # Error model y + (a + b*y)*e: the standard deviation scales with the noise-free
        # simulation y, not with the noisy observation.
        prop_sigma = np.exp(log_param[index_sigma[0]]) + y_sim * np.exp(log_param[index_sigma[1]])
        llh = log_likelihood_additive_noise(measurements=obs_data_indv_prepared,
                                            simulations=y_sim,
                                            sigmas=prop_sigma)

    if model.prior_type == 'normal':
        log_prior = log_prior_density_normal(log_param=log_param, mean=model.prior_mean, inv_cov_matrix=inv_cov,
                                             prior_constant=log_prior_constant)
    else:
        log_prior = log_prior_density_uniform(log_param=log_param, prior_constant=log_prior_constant)
    return -(llh + log_prior)

In [None]:
# Sanity check: evaluate the objective once at the prior mean; the value is shown as cell output.
neg_log_prop_posterior(model.prior_mean)

## Run MCMC

In [None]:
n_chains = 10
n_samples = 1_000_000  # integer number of MCMC samples (not the float 1e6)
# the stored result is specific to the model and the selected individual
filename = f'sampling_results/mcmc_{model.name}_individual_{individual_id}.hdf5'

# create objective function; gradients are approximated by finite differences
pesto_objective = FD(
    obj=Objective(fun=neg_log_prop_posterior),
    x_names=model.log_param_names,
)

# box constraints: five prior standard deviations around the prior mean
lb = model.prior_mean - 5 * model.prior_std
ub = model.prior_mean + 5 * model.prior_std

# create pypesto problem
pesto_problem = Problem(
    objective=pesto_objective,
    lb=lb,
    ub=ub,
    x_names=model.log_param_names,
    x_scales=['log'] * len(model.log_param_names),
)
pesto_problem.print_parameter_summary()

In [None]:
# Run the multi-start optimization only if no stored result exists for this individual;
# otherwise load the stored result from disk.
if not os.path.exists(filename):
    result = optimize.minimize(
        problem=pesto_problem,
        optimizer=optimize.ScipyOptimizer(),
        n_starts=n_chains * 10,
    )
else:
    result = store.read_result(filename)

In [None]:
# Inspect the optimization result with pypesto's parameter and waterfall plots.
visualize.parameters(result)
visualize.waterfall(result)
# Objective value at the first entry of the optimization result.
# NOTE(review): pypesto presumably sorts the starts by objective value, so entry 0 is the best - confirm.
print(neg_log_prop_posterior(result.optimize_result.x[0]))

In [None]:
# One panel per optimization result, for the first `axis.size` entries of the result list.
fig, ax = plt.subplots(
    nrows=2,
    ncols=3,
    tight_layout=True,
    figsize=(10, 5),
    sharex='row',
    sharey='all',
)
axis = ax.flatten()

for p_id, panel in tqdm(enumerate(axis), total=axis.size):
    axis[p_id] = model.prepare_plotting(obs_data_indv, result.optimize_result.x[p_id], panel)
    # the legend entries of the last panel are reused for the shared figure legend
    _, labels = axis[p_id].get_legend_handles_labels()

fig.legend(labels, ncol=3, loc='upper center', bbox_to_anchor=(0.5, 1))
axis[0].set_title(f'best {axis.size} fits')
plt.show()

In [None]:
# Parallel tempering with `n_chains` chains, each using an adaptive Metropolis sampler internally.
sampler = sample.AdaptiveParallelTemperingSampler(
    internal_sampler=sample.AdaptiveMetropolisSampler(),
    n_chains=n_chains,
)

In [None]:
# Sample only if no stored result exists for this individual.
if not os.path.exists(filename):
    # the chains are started at the first `n_chains` entries of the optimization result
    result = sample.sample(
        pesto_problem,
        n_samples=n_samples,
        sampler=sampler,
        x0=list(result.optimize_result.x)[:n_chains],
        result=result,
    )

In [None]:
# Convergence diagnostics of the sampling run. Keep the order of the calls.
# NOTE(review): pypesto's geweke_test is expected to store the burn-in index in
# result.sample_result.burn_in, which later cells read - confirm against the pypesto docs.
geweke_test = sample.geweke_test(result)
print('geweke_test', geweke_test)

# the auto-correlation is reused later as thinning step of the chain
auto_correlation = sample.auto_correlation(result)
print('auto_correlation', auto_correlation)

effective_sample_size = sample.effective_sample_size(result)
print('effective_sample_size', effective_sample_size)

In [None]:
# Parameter traces of the sampling run, drawn within the problem bounds;
# the trailing `;` suppresses the text output of the cell.
visualize.sampling_parameter_traces(result, use_problem_bounds=True);

In [None]:
# Trace of the objective function values of the sampling run;
# the trailing `;` suppresses the text output of the cell.
visualize.sampling_fval_traces(result);

In [None]:
# Persist problem, optimization and sampling result; an existing file is never overwritten.
if not os.path.exists(filename):
    store.write_result(result=result, filename=filename, problem=True, optimize=True, sample=True)

In [None]:
# Take the trace stored at index 0 of the sampling result.
# NOTE(review): for parallel tempering, index 0 is presumably the chain that targets the actual
# posterior - confirm against the pypesto docs.
pesto_samples = result.sample_result.trace_x[0]
print(pesto_samples.shape)

In [None]:
# discard the burn-in phase of the chain
burn_in = result.sample_result.burn_in
pesto_samples_adjusted = pesto_samples[burn_in:, :]

# Thin the chain by its auto-correlation. The step has to be an integer >= 1: an auto-correlation
# below 1 would truncate to 0 ("slice step cannot be zero") and a non-finite value cannot be
# converted to int at all, so both cases fall back to no thinning.
thinning_step = max(1, int(auto_correlation)) if np.isfinite(auto_correlation) else 1
thinned_samples = pesto_samples_adjusted[::thinning_step, :]
print(pesto_samples_adjusted.shape)
print(thinned_samples.shape)

In [None]:
# MAP estimate from sampling: the post-burn-in sample of trace 0 with the smallest negative
# log-posterior, compared with the objective at the first optimization result.
MAP_idx = np.argmin(result.sample_result.trace_neglogpost[0, burn_in:])
MAP = result.sample_result.trace_x[0, burn_in + MAP_idx, :]
print('MAP (optimizing)', neg_log_prop_posterior(result.optimize_result.x[0]))
print('MAP (sampling)', neg_log_prop_posterior(MAP))

if model_name == 'fröhlich-simple':
    # it is known, that this model's posterior should have two modes (in the first two parameters):
    # evaluate the objective at the MAP with these two parameters swapped
    other_MAP = MAP.copy()
    other_MAP[[0, 1]] = MAP[[1, 0]]
    print('MAP-2 (sampling)', neg_log_prop_posterior(other_MAP))

In [None]:
# Single panel built by `model.prepare_plotting` from the data and the sampling-based MAP estimate.
fig, ax = plt.subplots(nrows=1, ncols=1, tight_layout=True, figsize=(10, 5))

ax = model.prepare_plotting(obs_data_indv, MAP, ax)
ax.set_title('MAP fit')

_, labels = ax.get_legend_handles_labels()
fig.legend(labels, ncol=4, loc='lower center', bbox_to_anchor=(0.5, 1))
plt.show()

## Compare BayesFlow and MCMC

In [None]:
# Use the same number of samples from both methods: the first BayesFlow samples and a random
# subset (without replacement) of the MCMC samples after burn-in.
n_samples_umap = min(obs_data_posterior_samples[individual_id].shape[0], pesto_samples_adjusted.shape[0])
bayes_flow_samples = obs_data_posterior_samples[individual_id, :n_samples_umap]
mcmc_smaples = pesto_samples_adjusted[
    np.random.choice(pesto_samples_adjusted.shape[0], n_samples_umap, replace=False)
]

In [None]:
# Marginal histograms of the posterior samples, one panel per parameter and one figure per scale:
# first on log-scale (MCMC only; the BayesFlow histogram is switched off there), then on linear
# scale for both methods.
bins = 40
for title_prefix, to_scale, with_bayesflow in (('log ', np.asarray, False), ('', np.exp, True)):
    fig, ax = plt.subplots(nrows=2, ncols=int(np.ceil(model.n_params / 2)), tight_layout=True, figsize=(16, 12))
    axis = ax.flatten()
    for i, name in enumerate(model.param_names):
        axis[i].set_title(title_prefix + name)
        if with_bayesflow:
            axis[i].hist(to_scale(bayes_flow_samples[:, i]), bins=bins, density=True,
                         label='BayesFlow', color='blue')
        axis[i].hist(to_scale(mcmc_smaples[:, i]), bins=bins, density=True,
                     label='MCMC', alpha=0.6, color='red')
        axis[i].legend()

    # remove the panels of the grid that were not needed
    for _ax in axis[model.n_params:]:
        _ax.remove()
    # to save a figure, call plt.savefig(...) here, before plt.show()
    plt.show()

In [None]:
# Side-by-side panels built by `model.prepare_plotting`: BayesFlow samples on the left, thinned
# MCMC samples (at most `n_bayesflow_samples` of them) on the right.
fig, ax = plt.subplots(
    nrows=1,
    ncols=2,
    tight_layout=True,
    figsize=(16, 6),
    sharex='row',
    sharey='all',
)

ax[0] = model.prepare_plotting(obs_data_indv, obs_data_posterior_samples[individual_id], ax[0])
ax[0].set_title('BayesFlow Posterior Predictive')

ax[1] = model.prepare_plotting(obs_data_indv, thinned_samples[:n_bayesflow_samples], ax[1])
ax[1].set_title('MCMC Posterior Predictive')
ax[1].set_ylabel('')  # the y-axis is shared, so the label of the left panel is enough

_, labels = ax[0].get_legend_handles_labels()
fig.legend(labels, ncol=3, loc='lower center', bbox_to_anchor=(0.5, -0.01))
# to save the figure, call plt.savefig(...) here, before plt.show()
plt.show()

In [None]:
import ot

In [None]:
# compute the 2-Wasserstein distance on the original samples
# The ground cost is the squared Euclidean distance (the default of ot.dist, made explicit here),
# so the optimal transport cost returned by emd2 is the SQUARED 2-Wasserstein distance.
m = ot.dist(bayes_flow_samples, mcmc_smaples, metric='sqeuclidean')
sample_weights_bf = np.ones(bayes_flow_samples.shape[0]) / bayes_flow_samples.shape[0]  # uniform
sample_weights_mcmc = np.ones(mcmc_smaples.shape[0]) / mcmc_smaples.shape[0]  # uniform
# take the square root to report an actual distance
w_dist = np.sqrt(ot.emd2(sample_weights_bf, sample_weights_mcmc, m))

print(f'Wasserstein distance between posteriors {w_dist}')

## Dimensionality Reduction

To check visually whether the samples differ, we map the posterior samples into a two-dimensional space using UMAP.

In [None]:
import umap
from sklearn.preprocessing import StandardScaler

In [None]:
# Stack the samples (BayesFlow first, MCMC second) and standardize every parameter jointly.
all_samples = np.concatenate([bayes_flow_samples, mcmc_smaples], axis=0)
scaled_samples = StandardScaler().fit_transform(all_samples)

# Two-dimensional UMAP embedding of all samples;
# fixed seed and a single job for reproducibility (densmap is not used).
reducer = umap.UMAP(random_state=42, n_jobs=1)
umap_embedding = reducer.fit_transform(scaled_samples)

In [None]:
# Scatter plot of the embedding: the first `n_samples_umap` rows belong to BayesFlow, the rest to MCMC.
fig = plt.figure(tight_layout=True, figsize=(8, 6))
for group_label, group_color, group_points in (
    ('BayesFlow', 'blue', umap_embedding[:n_samples_umap]),
    ('MCMC', 'red', umap_embedding[n_samples_umap:]),
):
    plt.scatter(group_points[:, 0], group_points[:, 1], label=group_label, alpha=0.7, color=group_color)
plt.legend()
plt.title('Umap Based Representation of Posterior Distributions')

# to save the figure, call plt.savefig(...) here, before plt.show()
plt.show()