In [None]:
import logging
import math
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import torch
from torch import tensor
from sbi import analysis as analysis
import scipy.stats
import scipy.stats.mstats

import inference
import inference.priors
import inference.plotting

# Set up basic root-logger configuration, then restrict output to warnings and above.
logging.basicConfig()
logging.getLogger().setLevel(logging.WARNING)

# The relevant files. Make sure the data files exist, and that the directories are set up for the output files
# Inputs: the parameter table (feather) and the summary-statistics table (csv).
params_file = "../data/inference/params_widen2.feather"
stats_file = "../data/inference/stats_widen2.csv"
# Outputs: samples_file is not referenced by the cells visible in this part of the notebook;
# posterior_file is where the trained posterior is dumped to / loaded from.
samples_file = "../output/inference/posterior_samples.csv"
posterior_file = "../output/inference/posterior.pkl"

## Read in data

You'll need to have the params.feather file (`params_file` above) and the stats.csv file (`stats_file` above).  
Then make sure that the output directories are created (../output/inference and ../plots/inference for the plots).

The stats csv has a column named 'index' which tells you which row of the params file the statistics come from.

In [None]:
# Load the summary statistics. The 'index' column says which row of the params file
# each simulation came from, so it becomes the frame's index. The random seed is not
# a statistic, and simulations with any missing statistic are discarded.
stats_df = (
    pd.read_csv(stats_file, index_col="index")
    .sort_values("index")
    .drop(columns=["random_seed"])
    .dropna()
)

# Only take the priors that there are summary statistics for in the table above.
# Rows are selected *by label, in stats_df's order* so that row i of theta is always
# the parameter set that produced row i of x. (A boolean `isin` filter keeps the
# params file's own order and multiplicity, which silently misaligns theta and x if
# the stats index has duplicates or the params index is not sorted.)
params_df = pd.read_feather(params_file)
missing_rows = stats_df.index.difference(params_df.index)
if len(missing_rows) > 0:
    raise ValueError(
        f"{len(missing_rows)} index value(s) in {stats_file} have no matching row in {params_file}"
    )
params_df = params_df.loc[stats_df.index]

# Get the parameters and outputs into torch tensors
theta = tensor(params_df.values, dtype=torch.float32)
x = tensor(stats_df.values, dtype=torch.float32)

# Create joint prior
prior = inference.priors.join_priors()
# Compare as plain lists: an element-wise numpy comparison does not yield a per-element
# result when the two sequences differ in length, so the intended error below would
# never be raised in that case.
# NOTE(review): assumes prior.params is an ordered sequence of parameter names - confirm.
if list(params_df.columns) != list(prior.params):
    raise ValueError("Parameter names do not match between params file and prior")
# Save prior samples to a feather file:
# pd.DataFrame(np.array(prior.sample((100000,))), columns=params_df.columns).to_feather("../data/inference/params_new.feather")

num_simulations = len(stats_df.index)
num_params = len(params_df.columns)
num_stats = len(stats_df.columns)
# Draw as many posterior samples as there are simulations.
num_samples = num_simulations
print(f'Using {num_simulations} simulations, {num_params} params, {num_stats} summary stats and making {num_samples} samples')

## Training

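Create the posterior. As a rough guide, training should take:
- ~3-4 minutes for NPE
- ~20 minutes for 10 rounds of SNPE (the cell below is set to `num_rounds=100`, so expect it to take considerably longer)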

In [None]:
%%time

# posterior = inference.NPE(theta, x, dump_to_file=posterior_file)

posterior = inference.SNPE(theta, x, num_rounds=100, density_estimator='maf', dump_to_file=posterior_file, training_args={'use_combined_loss': True})

# posterior = inference.NLE(theta, x, dump_to_file=posterior_file, mcmc_parameters={
#     "num_chains": 1,
#     "thin": 1,
#     "warmup_steps": 100,
#     "init_strategy": "sir",
#     "sir_batch_size": 1000,
#     "sir_num_batches": 100,
# })

print(posterior)

Alternatively, load the posterior from the pickled file, if training has been done previously:

In [None]:
# Load from disk when no posterior exists yet, or when a previous run of this cell
# has already armed the overwrite flag. Otherwise (a posterior is in memory and this
# is the first run of the cell) just warn and arm the flag, so an in-memory posterior
# is only replaced after a deliberate second run.
if 'posterior' not in globals() or 'overwrite_posterior' in globals():
    posterior = inference.load_posterior(posterior_file)
else:
    print('The posterior is already defined, run this cell again to overwrite it')
    overwrite_posterior = True

## Sampling & Plotting
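Once you have the posterior, you can run any of the tasks below individually (the final plot of sample means against actual values needs the sampling cell directly before it to be run first).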

##### **Plot the priors, together with samples from `params_file`**

In [None]:
# Plot the priors with samples from them, and the samples in the feather file
prior_samples = np.array(prior.sample((num_simulations,)))
file_samples = np.array(theta)
fig, axes = inference.plotting.plot_samples_vs_prior(prior, [prior_samples, file_samples], ["Prior samples", "File samples"])
fig.suptitle("Samples drawn from joint prior distribution, compared with parameters from feather file. Red line shows prior distribution.")
plt.show()

##### **Plot marginal samples**

In [None]:
# Sample the posterior conditioned on the first simulation's summary statistics.
posterior_samples = posterior.sample((num_samples,), x=x[0])
fig, axes = inference.plotting.plot_samples_vs_prior(
    prior,
    posterior_samples,
    "Posterior samples",
    axsize=4,
    num_cols=4,
)
# Mark the parameter value that actually generated x[0] on each parameter's axis.
for ax, actual in zip(axes, theta[0]):
    ax.axvline(
        actual,
        ymax=1/1.2,
        lw=1,
        c=(0, 0, 0, 0.5),
        label="Actual theta",
    )
    ax.legend()
fig.suptitle(f'{posterior_samples.shape[0]} samples drawn from posterior distributions of each parameter')
plt.savefig("../plots/inference/posterior_samples.jpg")
plt.show()

##### **Pairplot**

In [None]:
# Fresh posterior draw conditioned on the first simulation's statistics, shown as a
# pairwise plot labelled with the prior's parameter names.
posterior_samples = posterior.sample((num_samples,), x=x[0])
_ = analysis.pairplot(
    posterior_samples,
    labels=prior.params,
    figsize=(20, 20),
)
plt.savefig("../plots/inference/pairplot.jpg", bbox_inches='tight')
plt.show()

##### **Plot means of samples for each parameter, against the actual values**

In [None]:
# For each of the first `num_sample_sets` simulations, sample the posterior conditioned
# on that simulation's summary statistics and record, per parameter, the sample mean and
# the distance from the mean to the 2.5% / 97.5% quantiles (used as asymmetric error bars).

# Never index past the end of x when fewer than 100 simulations are available.
num_sample_sets = min(100, num_simulations)

# Collect per-simulation results in lists and assemble the arrays once at the end,
# rather than re-allocating with np.concatenate on every iteration.
per_set_means = []   # each entry: (num_params,)
per_set_errors = []  # each entry: (num_params, 2)
for i in range(num_sample_sets):
    print(f'Sampling posterior [{i}/{num_sample_sets}]', end='', flush=True)
    samples = np.array(posterior.sample((2000,), x=x[i], show_progress_bars=False)) # (num_samples, num_params)

    means = np.mean(samples, axis=0) # (num_params,)
    quantiles = scipy.stats.mstats.mquantiles(samples, prob=[0.025, 0.975], axis=0).T # (num_params, 2)
    # Distance of the lower / upper quantile from the mean, per parameter.
    errors = np.asarray(np.abs(quantiles - means[:, np.newaxis])) # (num_params, 2)

    per_set_means.append(means)
    per_set_errors.append(errors)
    # Wipe the progress message so the next one overwrites it in place.
    print("\r                                  \r", end='')

# The reshape keeps the shapes well-defined (with a trailing length-0 axis) even if
# no sample sets were drawn, and uses num_params rather than a hard-coded count.
theta_sample_means = np.array(per_set_means, dtype=float).reshape(-1, num_params).T
theta_sample_errors = np.array(per_set_errors, dtype=float).reshape(-1, num_params, 2).transpose(1, 2, 0)

print(f'Sampled posterior {num_sample_sets} times.')
# theta_sample_means is of shape (# of params, # of means calculated)
# theta_sample_errors is of shape (# of params, 2, # of means calculated)

In [None]:
# One panel per parameter: posterior sample mean (with quantile error bars) against
# the actual parameter value, plus a dashed y = x reference line.
fig, axes = inference.plotting.create_params_plot(prior, axsize=4, num_cols=4)
actual_per_param = np.array(theta[:num_sample_sets]).T
panels = zip(axes, actual_per_param, theta_sample_means, theta_sample_errors, prior.params)
for ax, actual, mean, err, param in panels:
    ax.set_xlabel("Actual value")
    ax.set_ylabel("Mean of posterior samples")
    ax.errorbar(
        actual,
        mean,
        yerr=err,
        marker='.',
        color=(0.3, 0.3, 1),
        linestyle='',
        elinewidth=1,
        ecolor=(1, 0.3, 0.3, 0.2),
    )
    # Span the reference diagonal over the union of the current x and y ranges.
    x_low, x_high = ax.get_xlim()
    y_low, y_high = ax.get_ylim()
    lim = (min(x_low, y_low), max(x_high, y_high))
    ax.plot(lim, lim, ls="--", c=(0, 0, 0, 0.3))
plt.savefig("../plots/inference/samples_vs_actual.jpg", bbox_inches='tight')