In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pypesto import store

from inference.nlme_objective import ObjectiveFunctionNLME
from inference.helper_functions import compute_error_estimate, compute_variance_estimate

In [None]:
# Choose the individual-level model analysed in this notebook by its position in
# the list below (index 2 -> 'fröhlich-sde').
available_model_names = ['fröhlich-simple', 'fröhlich-detailed', 'fröhlich-sde']
model_name = available_model_names[2]

In [None]:
# Instantiate the individual-level model selected by `model_name`.
# Each branch imports its model class lazily, so only the selected model's module
# in the `models` package has to be importable.
# `load_best=True` is passed to every model; presumably it loads stored best
# settings for that model — TODO confirm in the model classes.
# NOTE: 'pharmacokinetic_model' and 'clairon_small_model' are not in the selector
# list of the previous cell; they are only reachable by setting `model_name` by hand.
if model_name == 'fröhlich-simple':
    from models.froehlich_model_simple import FroehlichModelSimple
    individual_model = FroehlichModelSimple(load_best=True)
elif model_name == 'fröhlich-detailed':
    from models.froehlich_model_detailed import FroehlichModelDetailed
    individual_model = FroehlichModelDetailed(load_best=True)
elif model_name == 'fröhlich-sde':
    from models.froehlich_model_sde import FroehlichModelSDE
    individual_model = FroehlichModelSDE(load_best=True)    
elif model_name == 'pharmacokinetic_model':
    from models.pharmacokinetic_model import PharmacokineticModel
    individual_model = PharmacokineticModel(load_best=True)    
elif model_name == 'clairon_small_model':
    from models.clairon_small_model import ClaironSmallModel
    individual_model = ClaironSmallModel(load_best=True)
else:
    # any other model name is rejected
    raise NotImplementedError('model not implemented')

# Assemble simulator and prior: the trainer is built from the network stored in the
# `../networks/` directory under the model's `network_name`; afterwards the model
# plots an example.
network_path = '../networks/' + individual_model.network_name
trainer = individual_model.build_trainer(network_path)
individual_model.plot_example()

In [None]:
# NLME objective function built from the individual model's prior.
# `param_samples` is only an empty (1, 1, 1) placeholder here — NOTE(review):
# presumably real samples are supplied later; confirm in ObjectiveFunctionNLME.
# No covariates are used and the population covariance is diagonal.
obj_fun_amortized = ObjectiveFunctionNLME(
    model_name=individual_model.name,
    param_samples=np.empty((1, 1, 1)),
    prior_mean=individual_model.prior_mean,
    prior_std=individual_model.prior_std,
    covariance_format='diag',
    covariates=None,
    covariate_mapping=None,
    prior_type=individual_model.prior_type,
    # bounds are only needed for a uniform prior; models without them get None
    prior_bounds=getattr(individual_model, 'prior_bounds', None),
)

In [None]:
# If True, parameter errors are measured relative to the true parameter values;
# otherwise the plain (mean squared) error is reported.
compute_relative_error: bool = False

In [None]:
# Evaluate the amortized NLME estimates (stored as pypesto results) for every
# combination of data-set size (#cells) and number of posterior samples used in the
# optimization.
test_n_cells = [50, 100, 500, 5000, 10000]
n_samples_opt_list = [10, 50, 100, 500]
n_runs = 104  # maximum number of optimization runs (multi-starts) stored per result file

# Pre-filled with NaN: slots of runs missing from a result file stay NaN, so summaries
# over runs should use NaN-aware reductions (e.g. np.nanmedian).
time_opt = np.full((len(test_n_cells), len(n_samples_opt_list), n_runs), np.nan)  # in hours
amortized_error = np.full((len(test_n_cells), len(n_samples_opt_list), n_runs), np.nan)
amortized_var = np.full((len(test_n_cells), len(n_samples_opt_list)), np.nan)

# loop-invariant settings
n_model_params = individual_model.n_params
use_bi_modal_error = 'Simple' in individual_model.name

for nc, n_cells in enumerate(test_n_cells):
    for ns, n_samples in enumerate(n_samples_opt_list):
        # load results
        filename = f'synthetic_results_amortized/{individual_model.name}_cells_{n_cells}_samples_{n_samples}.hd5'
        result_optimization = store.read_result(filename)
        results_params = np.array(result_optimization.optimize_result.x)
        n_found_runs = results_params.shape[0]
        # fewer runs than n_runs is fine (remaining slots stay NaN); more would not fit
        if n_found_runs > n_runs:
            raise ValueError(f'{filename} contains {n_found_runs} runs, but n_runs is only {n_runs}')

        # load true population parameters
        true_pop_parameters = individual_model.load_synthetic_parameter(n_data=n_cells)
        # set very small variances to 0.001
        true_pop_parameters[n_model_params:][true_pop_parameters[n_model_params:] < 0.001] = 0.001

        estimated_params_full = []
        for i_r, res in enumerate(results_params):
            # transform results: the first n_params entries are used as they are,
            # the next n_params entries are mapped to variances via exp(-x)
            estimated_beta = res[:n_model_params]
            estimated_var = np.exp(-res[n_model_params:2 * n_model_params])
            estimated_params = np.concatenate((estimated_beta, estimated_var))

            # error of this single run (relative to the true values only if compute_relative_error)
            amortized_error[nc, ns, i_r] = compute_error_estimate(estimated_params,
                                                                  true_pop_parameters,
                                                                  bi_modal=use_bi_modal_error,
                                                                  relative_error=compute_relative_error)
            estimated_params_full.append(estimated_params)

        # variance estimate over all runs of this setting
        amortized_var[nc, ns] = compute_variance_estimate(np.array(estimated_params_full))

        # optimization durations are reported in seconds; stored here in hours
        time_opt[nc, ns, :n_found_runs] = np.array(result_optimization.optimize_result.time) / 60 / 60

In [None]:
# Read the Monolix baseline results for the same synthetic data sets.
# Monolix results are only wired up for the simple and detailed models; for any other
# model (e.g. the default 'fröhlich-sde') the baseline arrays stay NaN, so the plots
# below simply show no baseline instead of the notebook aborting here.
if 'simple' in model_name:
    # index list that reorders the Monolix estimates (NOTE(review): presumably into
    # the parameter order expected by compute_error_estimate — confirm)
    reorder_monolix_params = [0, 1, 2, 3, 4, 10, 5, 6, 7, 8, 9]
elif 'detailed' in model_name:
    reorder_monolix_params = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 20, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
else:
    reorder_monolix_params = None

monolix_errors = np.full(len(test_n_cells), np.nan)
monolix_var = np.full(len(test_n_cells), np.nan)
timing_monolix = np.full(len(test_n_cells), np.nan)  # in hours

if reorder_monolix_params is None:
    print(f'No Monolix results available for {model_name!r}; skipping the baseline.')
else:
    n_model_params = individual_model.n_params
    use_bi_modal_error = 'Simple' in individual_model.name

    for cell_idx, n_cells in enumerate(test_n_cells):
        # 200-cell results are skipped for the detailed model
        if 'detailed' in model_name and n_cells == 200:
            continue

        estimates_monolix = pd.read_csv(f'synthetic_results_monolix/{model_name}/estimates/synthetic_{n_cells}_poppars.csv',
                                        index_col=0, header=0)

        true_pop_parameters = individual_model.load_synthetic_parameter(n_data=n_cells)
        # set very small variances to 0.001
        true_pop_parameters[n_model_params:][true_pop_parameters[n_model_params:] < 0.001] = 0.001

        # one column per Monolix run (multi-start)
        results_to_compare = []
        for col in estimates_monolix.columns:
            # fancy indexing returns a copy, so the DataFrame itself is not modified
            temp_res = estimates_monolix[col].values[reorder_monolix_params]
            temp_res[n_model_params - 1] = np.log(temp_res[n_model_params - 1])  # standard deviation is not on log-scale
            temp_res = np.concatenate((temp_res, [0.001]))  # add variance of noise
            results_to_compare.append(temp_res)
        results_to_compare = np.array(results_to_compare)

        error_mono = compute_error_estimate(results_to_compare,
                                            true_pop_parameters,
                                            bi_modal=use_bi_modal_error,
                                            relative_error=compute_relative_error)
        # take the median over multi-starts
        monolix_errors[cell_idx] = np.median(error_mono)
        monolix_var[cell_idx] = compute_variance_estimate(results_to_compare)

        # get timing
        # NOTE(review): `best_runs` is not used in the rest of this notebook section
        if 'detailed' in model_name:
            # likelihood were not always available, results are sorted
            best_runs = pd.read_csv(f'synthetic_results_monolix/{model_name}/estimates/synthetic_{n_cells}_complete_likelihoods.csv',
                                    index_col=0, header=0)['run']
        else:
            # results are sorted
            best_runs = pd.read_csv(f'synthetic_results_monolix/{model_name}/estimates/synthetic_{n_cells}_likelihoods.csv',
                                    index_col=0, header=0)['run']

        timing_monolix_df = pd.read_csv(f'synthetic_results_monolix/{model_name}/optimization_times/synthetic_{n_cells}_timings.csv',
                                        header=0)
        timing_monolix[cell_idx] = np.median(timing_monolix_df.saem) / 60 / 60  # in hours


In [None]:
# Estimation error of the amortized approach (one panel per number of posterior
# samples, median over optimization runs with the variance estimate as error bar),
# compared with the Monolix baseline.
figure, axis = plt.subplots(nrows=1, ncols=len(n_samples_opt_list), tight_layout=True,
                            sharex='col', sharey='row', figsize=(15, 5))
axis = np.atleast_1d(axis)  # a single panel is returned as a bare Axes, not an array

for j, n_samples_opt in enumerate(n_samples_opt_list):
    # NaN-aware median: slots of runs missing from a result file are NaN
    median_error = np.nanmedian(amortized_error[:, j], axis=1)
    axis[j].errorbar(np.array(test_n_cells), median_error, np.sqrt(amortized_var[:, j]), alpha=0.5,
                     linestyle='None', marker='x', capsize=3, label=f'#posterior samples: {n_samples_opt}')

    axis[j].errorbar(test_n_cells, monolix_errors, np.sqrt(monolix_var), label='Monolix',
                     linestyle='None', marker='x', capsize=3)

    axis[j].set_xscale('log')
    axis[j].set_xlabel('#cells')
    axis[j].set_xticks(ticks=test_n_cells, labels=test_n_cells, rotation=60)
    axis[j].legend()
axis[0].set_ylabel('Relative mean squared error' if compute_relative_error else 'Mean squared error')
axis[len(n_samples_opt_list) // 2].set_title('Error (compared to true population parameters)')
plt.show()

In [None]:
# Average training time of the amortized networks per model, in hours (it is drawn on
# the hour axis of the timing plot). The detailed and SDE values are each the sum of two
# components; the first key found in `model_name` wins.
_training_time_by_key = (
    ('simple', 6.11),
    ('detailed', 5.83 + 10.56),
    ('sde', 2.16 + 5.12),
)
for _key, _hours in _training_time_by_key:
    if _key in model_name:
        average_training_time = _hours
        break
else:
    raise NotImplementedError('model not implemented')

In [None]:
# Optimization time (in hours) for a new data set: amortized approach (median over
# runs) vs. the Monolix baseline, with the average network training time as reference.
figure, axis = plt.subplots(nrows=1, ncols=len(n_samples_opt_list), sharey='row', figsize=(15, 5))
axis = np.atleast_1d(axis)  # a single panel is returned as a bare Axes, not an array

for j, n_samples_opt in enumerate(n_samples_opt_list):
    axis[j].hlines(average_training_time, xmin=test_n_cells[0], xmax=test_n_cells[-1], color='grey', linestyle='--',
                   label='average training time of BayesFlow')
    # time_opt is already stored in hours, so no further unit conversion here;
    # NaN-aware median because slots of missing runs are NaN
    median_time_hours = np.nanmedian(time_opt[:, j], axis=-1)
    axis[j].plot(test_n_cells, median_time_hours, label=f'#posterior samples: {n_samples_opt}')

    axis[j].plot(test_n_cells, timing_monolix, label='baseline')

    axis[j].set_xscale('log')
    axis[j].set_yscale('log')
    axis[j].set_title('Optimization Time For a New Data Set')
    axis[j].legend()
    axis[j].set_xlabel('#cells')
    axis[j].set_xticks(ticks=test_n_cells, labels=test_n_cells, rotation=60)
# raw string: '\,' is a LaTeX thin space, not a Python escape sequence
axis[0].set_ylabel(r'$t\,[h]$')
plt.tight_layout()
plt.show()