# Inference with Approximate Bayesian Computation (ABC)

In [None]:
import argparse
import os
import pickle
from functools import partial
from typing import Union

import matplotlib.pyplot as plt
from joblib import Parallel, delayed
import numpy as np
import pyabc
#from pyabc.sampler import RedisEvalParallelSampler
import scipy.stats as stats
from fitmulticell import model as morpheus_model
from fitmulticell.sumstat import SummaryStatistics

from load_bayesflow_model import load_model, EnsembleTrainer
from plotting_routines import sampling_parameter_cis, plot_posterior_1d, plot_sumstats_distance_stats
from summary_stats import compute_summary_stats, reduce_to_coordinates, span, euclidean_distance

In [None]:
# Run configuration. The values are fixed for a local run; the trailing
# comments show how they would be read from the SLURM environment instead.
on_cluster = False
test_id, n_procs = (
    2,   # int(os.environ.get('SLURM_ARRAY_TASK_ID', 0))
    10,  # int(os.environ.get('SLURM_CPUS_PER_TASK', 1))
)
print('test_id', test_id)

In [None]:
# number of particles requested from pyabc for every ABC-SMC generation
population_size = 1000

if on_cluster:
    # host and port are only consumed by the (currently commented-out) Redis sampler
    parser = argparse.ArgumentParser(description='Parse necessary arguments')
    parser.add_argument(
        '-pt', '--port',
        type=str,
        default="50004",
        help='Which port should be use?',
    )
    parser.add_argument(
        '-ip', '--ip',
        type=str,
        help='Dynamically passed - BW: Login Node 3',
    )
    args = parser.parse_args()

In [None]:
# base directory that holds the model file, the test data and the results
gp = (
    '/home/jarruda_hpc/CellMigration/synth_data_params_bayesflow'
    if on_cluster
    else os.getcwd()
)

# parameter name -> XPath of the corresponding constant in the Morpheus model XML
par_map = {
    'gradient_strength': './CellTypes/CellType/Constant[@symbol="gradient_strength"]',
    'move.strength': './CellTypes/CellType/Constant[@symbol="move.strength"]',
    'move.duration.mean': './CellTypes/CellType/Constant[@symbol="move.duration.mean"]',
    'cell_nodes_real': './Global/Constant[@symbol="cell_nodes_real"]',
}

dt = 30
model_path = f"{gp}/cell_movement_v24.xml"  # time step is 30sec, for inference

# settings of the summary statistics function
min_sequence_length, max_sequence_length = 0, 3600 // dt
only_longest_traj_per_cell = True  # mainly to keep the data batchable
cells_in_population = 143

def make_sumstat_dict(data: Union[dict, np.ndarray]) -> dict:
    """Compute the hand-crafted summary statistics of a single simulation.

    Args:
        data: Either an array or a dict; for a dict only the value stored
            under its first key is used. Only the first entry along the
            leading axis is evaluated, and that entry must be 3-dimensional.

    Returns:
        Dict with the flattened statistics under the keys 'msd', 'ta',
        'vel' and 'ad', as returned by ``compute_summary_stats``.
    """
    if isinstance(data, dict):
        # the simulation is stored under the first (only relevant) key
        data = data[list(data)[0]]
    single_simulation = data[0]  # only one full simulation
    assert single_simulation.ndim == 3

    msd, ta, vel, ad = compute_summary_stats(single_simulation, dt=dt)
    named_stats = zip(('msd', 'ta', 'vel', 'ad'), (msd, ta, vel, ad))
    return {stat_name: np.array(values).flatten() for stat_name, values in named_stats}


def prepare_sumstats(output_morpheus_model) -> dict:
    """Turn the raw Morpheus output into a fixed-size, NaN-padded trajectory tensor.

    The output is reduced to per-cell trajectories with
    ``reduce_to_coordinates`` and written into an array of shape
    ``(1, cells_in_population, max_sequence_length, 3)`` whose last axis holds
    x, y and t. Every trajectory is right-aligned along the time axis; all
    entries that are not covered by a trajectory stay NaN.

    Args:
        output_morpheus_model: Raw simulation output, passed unchanged to
            ``reduce_to_coordinates``.

    Returns:
        ``{'sim': tensor}``. The tensor is entirely NaN when no cell was
        tracked in the simulation.
    """
    sim_coordinates = reduce_to_coordinates(
        output_morpheus_model,
        minimal_length=min_sequence_length,
        maximal_length=max_sequence_length,
        only_longest_traj_per_cell=only_longest_traj_per_cell,
    )

    # we now do exactly the same as in the BayesFlow workflow, but here we get only one sample at a time
    data_transformed = np.full((1, cells_in_population, max_sequence_length, 3), np.nan)
    # each cell is of different length, each with x, y and t values, make a tensor out of it
    if len(sim_coordinates) != 0:
        # some cells were visible in the simulation
        for c_id, cell_sim in enumerate(sim_coordinates):
            # pre-pad with NaN: the values fill the END of the time axis
            for channel, coord in enumerate(('x', 'y', 't')):
                values = cell_sim[coord]
                n_steps = len(values)
                if n_steps == 0:
                    # `-0:` would address the whole axis and the assignment of an
                    # empty sequence would fail; an empty trajectory simply stays NaN
                    continue
                data_transformed[0, c_id, -n_steps:, channel] = values

    return {'sim': data_transformed}


sumstat = SummaryStatistics(sum_stat_calculator=prepare_sumstats)

# keyword arguments of the model object; the cluster run only adds the executable
_model_kwargs = {
    'par_map': par_map,
    'par_scale': "log10",
    'show_stdout': False,
    'show_stderr': False,
    'clean_simulation': True,
    'raise_on_error': False,
    'sumstat': sumstat,
}
if on_cluster:
    # note: remember also change tiff path in model.xml!
    _model_kwargs['executable'] = "ulimit -s unlimited; /home/jarruda_hpc/CellMigration/morpheus-2.3.7"

# define the model object
model = morpheus_model.MorpheusModel(model_path, **_model_kwargs)

# parameter values used to generate the synthetic data
obs_pars = {
    'gradient_strength': 100.,  # strength of the gradient of chemotaxis
    'move.strength': 10.,  # strength of directed motion
    'move.duration.mean': 0.1,  # mean of exponential distribution (seconds)
    'cell_nodes_real': 50.,  # area of the cell (\mu m^2), macrophages have a volume of 4990\mu m^3 -> radius of 17 if they were a sphere
}
obs_pars_log = {name: np.log10(value) for name, value in obs_pars.items()}

# prior bounds on the linear scale and their log10 counterparts
limits = {
    'gradient_strength': (1, 10000),  # (10 ** 4, 10 ** 8),
    'move.strength': (1, 100),
    'move.duration.mean': (1e-4, 30),  # (math.log10((10 ** -2) * 30), math.log10((10 ** 4))), # smallest time step in simulation 5
    'cell_nodes_real': (1, 300),
}
limits_log = {
    name: (np.log10(lower), np.log10(upper))
    for name, (lower, upper) in limits.items()
}

# independent uniform priors between the log10 bounds of every parameter
_prior_rvs = {
    name: pyabc.RV("uniform", loc=lb, scale=ub - lb)
    for name, (lb, ub) in limits_log.items()
}
prior = pyabc.Distribution(**_prior_rvs)

# LaTeX labels of the parameters; they are indexed in parallel with the
# entries of `limits` when plotting.
# Every backslash is doubled: '\l' is not a valid string escape and triggers a
# SyntaxWarning on recent Python versions (the resulting text is unchanged).
param_names = ['$m_{\\text{dir}}$', '$m_{\\text{rand}}$', '$w$', '$a$']
log_param_names = [
    '$\\log_{10}(m_{\\text{dir}})$',
    '$\\log_{10}(m_{\\text{rand}})$',
    '$\\log_{10}(w)$',
    '$\\log_{10}(a)$',
]
# show the data-generating parameter values and the log10 prior bounds
print(obs_pars)
print(limits_log)

In [None]:
# width and centre (x0, y0) of the field u1, plus the point (space_x0, space_y0)
# that is highlighted in the plot below
sigma0 = 550
space_x0, space_y0 = 1173/2, 1500/1.31/2
x0, y0 = 1173/2, (1500+1500/2+270)/1.31


def u1(space_x, space_y):
    """Evaluate the 2D Gaussian-shaped field centred at (x0, y0) with width sigma0."""
    squared_distance = (space_x - x0)**2 + (space_y - y0)**2
    return 7/(2*np.pi*(sigma0**2)) * np.exp(-1/2*squared_distance/(sigma0**2))


# plot the function
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
x = np.linspace(0, 1173, 100)
y = np.linspace(0, 2500, 100)
X, Y = np.meshgrid(x, y)
Z = u1(X, Y)
ax.plot_surface(X, Y, Z, alpha=0.5)
# plot start points
ax.scatter(space_x0, space_y0, u1(space_x0, space_y0), color='r', s=100)

ax.set_xlabel('space_x')
ax.set_ylabel('space_y')
plt.show()

# notebook output: highlighted point, field value there and at the centre
space_x0, space_y0, u1(space_x0, space_y0), u1(x0, y0)

In [None]:
# load test data
# one prior draw after seeding with 42 + test_id
# (presumably the parameters the stored test simulation was generated with -- confirm)
np.random.seed(42+test_id)
test_params = np.array(list(prior.rvs().values()))

_test_sim_file = os.path.join(gp, f'test_sim_{test_id}.npy')
if not os.path.exists(_test_sim_file):
    raise FileNotFoundError('Test data not found')
test_sim = np.load(_test_sim_file)
test_sim_full = {'sim_data': test_sim}

# sub-directory (relative to gp) that collects all outputs of this test case
results_path = f'abc_results_{test_id}'
test_sim.shape

In [None]:
# 1000 draws from the prior, one row per draw; used below as the reference
# for the prior mean/std and handed to the posterior plot
prior_draws = np.array([list(prior.rvs().values()) for _ in range(1000)])

## ABC with Wasserstein distance

In [None]:
def obj_func_wass_helper(sim: dict, obs: dict, key: str) -> float:
    """Wasserstein distance between the values stored under ``key`` in ``sim`` and ``obs``.

    Returns ``np.inf`` when either of the two entries is empty.
    """
    samples = [np.array(source[key]) for source in (sim, obs)]
    if any(sample.size == 0 for sample in samples):
        # no distance can be computed without data on both sides
        return np.inf
    simulated, observed = samples
    return stats.wasserstein_distance(simulated, observed)

# one Wasserstein distance per hand-crafted summary statistic
distances = {
    stat_key: pyabc.distance.FunctionDistance(partial(obj_func_wass_helper, key=stat_key))
    for stat_key in ('msd', 'ta', 'vel', 'ad')
}

# adaptive distance: the aggregation weights are written to this log file
log_file_weights = f"{results_path}/adaptive_distance_log_{test_id}.txt"

adaptive_wasserstein_distance = pyabc.distance.AdaptiveAggregatedDistance(
    distances=[*distances.values()],
    scale_function=span,
    log_file=log_file_weights,
)

In [None]:
#redis_sampler = RedisEvalParallelSampler(host=args.ip, port=args.port,
#                                         adapt_look_ahead_proposal=False,
#                                         look_ahead=False)

# ABC-SMC with the hand-crafted summary statistics and the adaptive Wasserstein distance
abc = pyabc.ABCSMC(model, prior,
                   distance_function=adaptive_wasserstein_distance,
                   summary_statistics=make_sumstat_dict,
                   population_size=population_size,
                   sampler=pyabc.sampler.MulticoreEvalParallelSampler(n_procs=n_procs)
                   #sampler=redis_sampler
                   )

db_path = os.path.join(gp, f"{results_path}/synthetic_{test_id}_test_wasserstein_sumstats_adaptive.db")
if not os.path.exists(db_path):
    # make sure the output directories exist before anything is written there
    # (the database and the weight log are plain files in these directories)
    for _out_dir in (os.path.dirname(db_path), os.path.dirname(log_file_weights)):
        if _out_dir:
            os.makedirs(_out_dir, exist_ok=True)

    history_abc = abc.new("sqlite:///" + db_path, make_sumstat_dict(test_sim))

    # start the abc fitting
    abc.run(min_acceptance_rate=1e-2, max_nr_populations=15)
    print('Done!')
else:
    # re-use the stored results; if the database holds several runs, take the last one
    history_abc = abc.load("sqlite:///" + db_path)
    _n_runs = len(history_abc.all_runs())
    if _n_runs > 1:
        history_abc = abc.load("sqlite:///" + db_path, abc_id=_n_runs)

# aggregation weights of the last generation, needed in both cases
adaptive_weights = pyabc.storage.load_dict_from_json(log_file_weights)[history_abc.max_t]

## ABC with neural network summary statistics

In [None]:
_validation_file = os.path.join(gp, 'validation_data.pickle')
if not os.path.exists(_validation_file):
    raise FileNotFoundError('Validation data not found')
with open(_validation_file, 'rb') as f:
    valid_data = pickle.load(f)

# normalisation constants handed to `load_model`:
# parameters are averaged over the draws, simulations over the first three axes (NaNs ignored)
p_mean = np.mean(valid_data['prior_draws'], axis=0)
p_std = np.std(valid_data['prior_draws'], axis=0)
x_mean = np.nanmean(valid_data['sim_data'], axis=(0, 1, 2))
x_std = np.nanstd(valid_data['sim_data'], axis=(0, 1, 2))
print('Mean and std of data:', x_mean, x_std)
print('Mean and std of parameters:', p_mean, p_std)

In [None]:
# use trained neural net as summary statistics
def make_sumstat_dict_nn(data: Union[dict, np.ndarray], use_npe_summaries: bool = True) -> dict:
    """Summarise simulated data with the summary network of a trained model.

    Args:
        data: Simulation array, or a dict whose first value is that array.
        use_npe_summaries: Selects the model that is loaded: id 0 if True,
            id 10 otherwise.

    Returns:
        ``{'summary_net': flattened network output}``. For model 10 the
        output is additionally rescaled with ``p_std`` and ``p_mean``.
    """
    model_id = 0 if use_npe_summaries else 10
    if isinstance(data, dict):
        # the simulation is stored under the first key
        data = data[list(data)[0]]

    # the model is loaded anew for every call and released again at the end
    trainer = load_model(
        model_id=model_id,
        x_mean=x_mean,
        x_std=x_std,
        p_mean=p_mean,
        p_std=p_std,
    )

    # configures the input for the network
    config_input = trainer.configurator({"sim_data": data})
    # get the summary statistics
    summary_net = trainer.amortizer.summary_net
    if isinstance(trainer, EnsembleTrainer):
        summaries = summary_net(config_input).flatten()
    else:
        summaries = summary_net(config_input['summary_conditions']).numpy().flatten()

    if model_id == 10:
        # renormalize the parameters
        summaries = summaries * p_std + p_mean

    del summary_net
    del trainer
    return {'summary_net': summaries}


# keyword arguments of the second model object; the cluster run only adds the executable
_model_nn_kwargs = {
    'par_map': par_map,
    'par_scale': "log10",
    'show_stdout': False,
    'show_stderr': False,
    'clean_simulation': True,
    'raise_on_error': False,
    'sumstat': sumstat,
}
if on_cluster:
    # note: remember also change tiff path in model.xml!
    _model_nn_kwargs['executable'] = "ulimit -s unlimited; /home/jarruda_hpc/CellMigration/morpheus-2.3.7"

# define the model object
model_nn = morpheus_model.MorpheusModel(model_path, **_model_nn_kwargs)

In [None]:
# %%time
# print(make_sumstat_dict_nn(test_sim), test_params)
#
# p = 1
# summary_error = (np.abs(make_sumstat_dict_nn(test_sim, use_npe_summaries=False)['summary_net']-test_params)**p).sum() ** (1 / p)
# print(make_sumstat_dict_nn(test_sim, use_npe_summaries=False), test_params)
# print('error:', summary_error)

In [None]:
# abc with summary net trained on posterior mean
abc_nn = pyabc.ABCSMC(model_nn, prior, # here we use now the Euclidean distance, Wasserstein distance is not possible
                      population_size=population_size,
                      summary_statistics=partial(make_sumstat_dict_nn, use_npe_summaries=False),
                      sampler=pyabc.sampler.MulticoreEvalParallelSampler(n_procs=n_procs)
                      #sampler=redis_sampler
                      )

db_path = os.path.join(gp, f"{results_path}/synthetic_{test_id}_test_nn_sumstats_posterior_mean.db")

if not os.path.exists(db_path):
    # The observed data must be summarised with the SAME network as the simulations
    # (use_npe_summaries=False); the default (True) would compare the output of two
    # different networks.
    # NOTE(review): a database created before this fix was fitted against the
    # summaries of the other network -- consider deleting it and re-running.
    history_nn = abc_nn.new("sqlite:///" + db_path,
                            make_sumstat_dict_nn(test_sim, use_npe_summaries=False))

    # start the abc fitting
    abc_nn.run(min_acceptance_rate=1e-2, max_nr_populations=15)
    print('Done!')
else:
    # re-use the stored results; if the database holds several runs, take the last one
    history_nn = abc_nn.load("sqlite:///" + db_path)
    if len(history_nn.all_runs()) > 1:
        history_nn = abc_nn.load("sqlite:///" + db_path, abc_id=len(history_nn.all_runs()))

In [None]:
#redis_sampler = RedisEvalParallelSampler(host=args.ip, port=args.port,
#                                         adapt_look_ahead_proposal=False,
#                                         look_ahead=False)

# abc with summary net trained with NPE
abc_npe = pyabc.ABCSMC(
    model_nn,
    prior,  # here we use now the Euclidean distance, Wasserstein distance is not possible
    population_size=population_size,
    summary_statistics=partial(make_sumstat_dict_nn, use_npe_summaries=True),
    sampler=pyabc.sampler.MulticoreEvalParallelSampler(n_procs=n_procs),
    # sampler=redis_sampler
)

db_path = os.path.join(gp, f"{results_path}/synthetic_{test_id}_test_nn_sumstats.db")
if os.path.exists(db_path):
    # results are already stored: load them instead of fitting again
    history_npe = abc_npe.load("sqlite:///" + db_path)
    _n_npe_runs = len(history_npe.all_runs())
    if _n_npe_runs > 1:
        # first run failed
        history_npe = abc_npe.load("sqlite:///" + db_path, abc_id=_n_npe_runs)
else:
    _observed_npe = make_sumstat_dict_nn(test_sim, use_npe_summaries=True)
    history_npe = abc_npe.new("sqlite:///" + db_path, _observed_npe)

    # start the abc fitting
    abc_npe.run(min_acceptance_rate=1e-2, max_nr_populations=15)
    print('Done!')

In [None]:
# For every ABC run: (1) the marginal KDE of each parameter for all generations
# and (2) a row of pyabc diagnostic plots. Both figures are saved as PDF below
# gp/results_path.
for hist, name in zip([history_abc, history_nn, history_npe], ['abc', 'abc_mean', 'abc_npe']):
    if hist is None:
        continue
    print(name, 'Generations:', hist.max_t)
    # one panel per parameter, one KDE curve per generation t
    fig, ax = plt.subplots(1, len(param_names), tight_layout=True, figsize=(12, 4))
    for i, param in enumerate(limits.keys()):
        for t in range(hist.max_t + 1):
            df, w = hist.get_distribution(m=0, t=t)
            pyabc.visualization.plot_kde_1d(
                df,
                w,
                xmin=limits_log[param][0],
                xmax=limits_log[param][1],
                x=param,
                xname=log_param_names[i],
                ax=ax[i],
                label=f"PDF t={t}",
            )
        # show a small margin around the log10 prior bounds
        ax[i].set_xlim((limits_log[param][0]-0.2, limits_log[param][1]+0.2))
    plt.savefig(os.path.join(gp, f'{results_path}/synthetic_{test_id}_population_kdes_{name}.pdf'), bbox_inches='tight')
    plt.show()

    # diagnostics: sample numbers, walltime, epsilons, effective sample sizes, acceptance rates
    fig, arr_ax = plt.subplots(1, 5, figsize=(12, 3), tight_layout=True)
    arr_ax = arr_ax.flatten()
    pyabc.visualization.plot_sample_numbers(hist, ax=arr_ax[0])
    arr_ax[0].get_legend().remove()
    pyabc.visualization.plot_walltime(hist, ax=arr_ax[1], unit='h')
    arr_ax[1].get_legend().remove()
    pyabc.visualization.plot_epsilons(hist, ax=arr_ax[2])
    pyabc.visualization.plot_effective_sample_sizes(hist, ax=arr_ax[3])
    pyabc.visualization.plot_acceptance_rates_trajectory(hist, ax=arr_ax[4])
    plt.savefig(os.path.join(gp, f'{results_path}/synthetic_{test_id}_diagnostics_{name}.pdf'), bbox_inches='tight')
    plt.show()

# Synthetic Tests

## Compare Posterior Samples

In [None]:
# get posterior samples: 1000 resampled particles of the last generation per ABC run
_histories = {'abc': history_abc, 'abc_mean': history_nn, 'abc_npe': history_npe}
posterior_samples = {}
for name, hist in _histories.items():
    if hist is not None:
        abc_df, abc_w = hist.get_distribution()
        # columns are ordered like the entries of `limits`
        posterior_samples[name] = pyabc.resample(abc_df[limits.keys()].values, abc_w, n=1000)

# add bayesflow posterior samples
# NOTE(review): this path is relative to the working directory, not to `gp` -- confirm
posterior_samples['npe'] = np.load(f'abc_results_{test_id}/posterior_samples_npe.npy')

In [None]:
# method key -> (legend label, colour)
labels_colors = dict(
    abc=('ABC with hand-crafted summaries', '#9AB8D7'),
    abc_mean=('ABC with posterior mean summaries', '#C4B7D4'),
    abc_npe=('ABC with inference-tailored summaries', '#EEBC88'),
    npe=('NPE with jointly learned summaries', '#A7CE97'),
)

# labels and colours in the order of the available posterior samples
labels = [labels_colors[method][0] for method in posterior_samples]
colors = [labels_colors[method][1] for method in posterior_samples]

In [None]:
# Plot the posterior samples of all methods with `plotting_routines.plot_posterior_1d`;
# the prior draws, the test data/parameters and the output file are handed over as well.
fig = plot_posterior_1d(
    posterior_samples=posterior_samples,
    prior_draws=prior_draws,
    log_param_names=log_param_names,
    test_sim=test_sim,
    test_params=test_params,
    labels_colors=labels_colors,
    make_sumstat_dict_nn=make_sumstat_dict_nn,
    save_path=os.path.join(gp, f'{results_path}/synthetic_posterior_all_rows.pdf')
)

In [None]:
# reference statistics of the prior, estimated from the prior draws
prior_mean, prior_std = np.mean(prior_draws, axis=0), np.std(prior_draws, axis=0)


def compute_z_score(posterior_mean):
    """Shift of the posterior mean from the prior mean in units of the prior std.

    Note: the reference is the prior, not the true parameter value.
    """
    shift = posterior_mean - prior_mean
    return shift / prior_std


def compute_contraction(posterior_std):
    """One minus the ratio of the posterior to the prior standard deviation."""
    std_ratio = posterior_std / prior_std
    return 1. - std_ratio


# per-method summary of the posterior samples (all statistics are taken per parameter)
posterior_stats = {}
for name, ps in posterior_samples.items():
    _ps_mean = np.mean(ps, axis=0)
    _ps_std = np.std(ps, axis=0)
    posterior_stats[name] = {
        'mean': _ps_mean,
        'std': _ps_std,
        'median': np.median(ps, axis=0),
        'z_score': compute_z_score(_ps_mean),
        'contraction': compute_contraction(_ps_std),
    }

In [None]:
# collect the statistics in the order of `posterior_samples` (matches `colors` and `labels`)
z_scores = [posterior_stats[name]['z_score'] for name in posterior_samples.keys()]
contractions = [posterior_stats[name]['contraction'] for name in posterior_samples.keys()]

# Plotting Z-scores and contractions for all methods
fig, ax1 = plt.subplots(figsize=(8, 4), tight_layout=True)

# Z-Scores: one group of bars per parameter, one bar per method (shifted by 0.2 each)
for i, z in enumerate(z_scores):
    ax1.bar(np.arange(len(param_names)) + 0.2 * i, z, width=0.15, align='center', color=colors[i],
            label=labels[i])
ax1.set_ylabel('Z-Score')
ax1.tick_params(axis='y')
# tick position 0.3 is the middle of a group of four bars (offsets 0, 0.2, 0.4, 0.6)
ax1.set_xticks(np.arange(len(param_names)) + 0.3)
ax1.set_xticklabels(log_param_names)
ax1.axhline(0, color='gray', linestyle='--', linewidth=0.8)
ax1.grid()

# Plot contractions of all methods on the secondary axis
ax2 = ax1.twinx()
for i, c in enumerate(contractions):
    ax2.plot(np.arange(0.3, len(param_names)), c, color=colors[i], marker='o', linestyle='--')
ax2.set_ylabel('--●--  Contraction')
ax2.tick_params(axis='y')
# use the same y-range on both axes: the union of the two automatic ranges
max_y = max(ax1.get_ylim()[1], ax2.get_ylim()[1])
min_y = min(ax1.get_ylim()[0], ax2.get_ylim()[0])
ax1.set_ylim(min_y, max_y)
ax2.set_ylim(min_y, max_y)

# Legend: only the bars carry labels, so the handles of ax1 are sufficient
handles1, labels1 = ax1.get_legend_handles_labels()
fig.legend(handles1, labels1,
           loc='lower center', bbox_to_anchor=(0.5, -0.12), ncol=2)
fig.savefig(os.path.join(gp, f'{results_path}/synthetic_z_scores_contraction.pdf'), bbox_inches='tight')
plt.show()

In [None]:
#ordering = [0,4,1,5,2,6,3,7]
# The samples of the four methods are stacked column-wise (method after method);
# `ordering` interleaves them so that the four columns of one parameter are adjacent.
ordering = np.concatenate([[i,i+4, i+8, i+12] for i in range(4)])
all_params = np.concatenate((posterior_samples['abc'],
                             posterior_samples['abc_mean'],
                             posterior_samples['abc_npe'],
                             posterior_samples['npe']), axis=-1)

# Row labels: only the first row of each group carries the parameter symbol.
# '\\qquad' is written with a doubled backslash because '\q' is not a valid
# string escape (SyntaxWarning on recent Python versions); the text is unchanged.
n_params = len(log_param_names)
log_param_names_plot = np.array(
    [f'{n} $\\qquad$ ABC hand-crafted' for n in log_param_names] +
    ['ABC posterior mean'] * n_params +
    ['ABC inference tailored'] * n_params +
    ['NPE'] * n_params
)[ordering]
param_names_plot = np.array(
    [f'{n} $\\qquad$ ABC hand-crafted' for n in param_names] +
    ['ABC posterior mean'] * len(param_names) +
    ['ABC inference tailored'] * len(param_names) +
    ['NPE'] * len(param_names)
)[ordering]
color_list = colors*len(param_names)

# credible intervals on the log10 scale
ax = sampling_parameter_cis(
    all_params[:, ordering],
    true_param=np.concatenate((test_params, test_params, test_params, test_params))[ordering] if test_params is not None else None,
    prior_bounds=limits_log.values() if test_params is not None else None,
    param_names=log_param_names_plot,
    alpha=[95, 90 , 80],
    color_list=color_list,
    show_median=False if test_params is not None else True,
    size=(7, 4),
    legend_bbox_to_anchor=(0.45,1) if test_params is not None else (0.31,1)
)
plt.savefig(os.path.join(gp, f'{results_path}/synthetic_posterior_credible_intervals_log.pdf'))
plt.show()

# the same plot on the linear scale (samples and true values are back-transformed)
all_params = np.power(10, all_params)
ax = sampling_parameter_cis(
    all_params[:, ordering],
    true_param=np.power(10, np.concatenate((test_params, test_params, test_params, test_params))[ordering]) if test_params is not None else None,
    param_names=param_names_plot,
    alpha=[95, 90 , 80],
    color_list=color_list,
    show_median=False if test_params is not None else True,
    size=(7, 4),
    legend_bbox_to_anchor=(0.99,0.35)
)
#plt.savefig(os.path.join(gp, f'{results_path}/synthetic_posterior_credible_intervals.pdf'))
plt.show()

## Compare Simulations from Posterior Samples

In [None]:
# Simulate the model for posterior samples of every method. The simulations are
# cached in a pickle file and only computed when that file does not exist yet.
file_name = os.path.join(gp, f'{results_path}/synthetic_posterior_simulations.pickle')
if os.path.exists(file_name):
    # NOTE: unpickling can execute arbitrary code -- only load files written by this notebook
    with open(file_name, 'rb') as f:
        posterior_simulations = pickle.load(f)
else:
    posterior_simulations = {}
    for name, ps in posterior_samples.items():
        # the wrapper is re-created for every method so that it reads the current `ps`
        @delayed
        def wrapper_fun(sample_i):
            # pair the sample values with the parameter names (in the order of `obs_pars`)
            _sim_dict = {key: p for key, p in zip(obs_pars.keys(), ps[sample_i])}
            _posterior_sim = model(_sim_dict)
            return _posterior_sim['sim']

        # only the first 10 samples of each method are simulated
        sim_list = Parallel(n_jobs=n_procs, verbose=1)(wrapper_fun(i) for i in range(10))
        posterior_simulations[name] = np.concatenate(sim_list)

    with open(file_name, 'wb') as f:
        pickle.dump(posterior_simulations, f)

# per method: one dict per posterior simulation holding the hand-crafted
# statistics plus the network summaries under the key 'nn'
simulation_sumstats = {}
for name, ps in posterior_simulations.items():
    _per_simulation = []
    for _p_sim in ps:
        _single = _p_sim[np.newaxis]  # both functions expect a leading axis
        _sumstats = make_sumstat_dict(_single)
        _sumstats['nn'] = make_sumstat_dict_nn(_single, use_npe_summaries=True)['summary_net']
        _per_simulation.append(_sumstats)
    simulation_sumstats[name] = _per_simulation

In [None]:
def obj_func_comparison(sim: dict, obs: dict, return_marginal: bool = False, weights: Union[dict, list] = None) -> Union[float, np.ndarray]:
    """Distance between simulated and observed summary statistics.

    The entry 'nn' is compared with ``euclidean_distance`` and is never
    weighted; every other key is compared with the matching entry of
    ``distances``.

    Args:
        sim: Summary statistics of a simulation.
        obs: Summary statistics of the observed data.
        return_marginal: If True, return one distance per key of ``sim``
            instead of their sum.
        weights: Optional factors for the non-'nn' distances, either a dict
            (looked up by key) or a list (looked up by the position of the
            key in ``sim``).

    Raises:
        ValueError: If ``weights`` is neither None, a dict nor a list.
    """
    total = np.zeros(len(sim))
    for k_i, key in enumerate(sim):
        if key == 'nn':
            # for the neural network summary statistics we use the Euclidean distance
            total[k_i] = euclidean_distance(sim['nn'], obs['nn'])
            continue  # no weights applied

        total[k_i] = distances[key](sim, obs)
        if weights is None:
            continue
        if isinstance(weights, dict):
            factor = weights[key]
        elif isinstance(weights, list):
            factor = weights[k_i]
        else:
            raise ValueError('Weights must be a list or a dictionary')
        total[k_i] *= factor

    return total if return_marginal else total.sum()

In [None]:
# summary statistics of the test data in the same layout as `simulation_sumstats`:
# hand-crafted statistics plus the network summaries under 'nn'
test_sim_dict = make_sumstat_dict(test_sim)
test_sim_dict['nn'] = make_sumstat_dict_nn(test_sim, use_npe_summaries=True)['summary_net']

In [None]:
# compare the posterior simulations of all methods with the test data, unweighted;
# the output file is passed as `path`
plot_sumstats_distance_stats(obj_func_comparison,
                             test_sim_dict,
                             [ps for ps in simulation_sumstats.values()],
                             labels=labels,
                             title='Wasserstein Distance',
                             colors=colors,
                             ylog_scale=True,
                             path=os.path.join(gp, f'{results_path}/synthetic_sumstats_comparison.pdf')
                             )

print(*test_sim_dict.keys())
print(adaptive_weights)
# same comparison with the adaptive weights of the last generation applied
# (the 'nn' term is never weighted); no `path` is passed here
plot_sumstats_distance_stats(partial(obj_func_comparison, weights=adaptive_weights),
                             test_sim_dict,
                             [ps for ps in simulation_sumstats.values()],
                             labels=labels,
                             colors=colors,
                             ylog_scale=True,
                             #path=os.path.join(gp, f'{results_path}/synthetic_sumstats_comparison.pdf')
                             )