In [None]:
import logging
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
import scipy.stats.mstats
import torch
import torch.distributions
from torch import tensor
from sbi.inference import SNLE, SNPE, prepare_for_sbi
from sbi.utils.user_input_checks import process_prior
from sbi import utils as utils
from sbi import analysis as analysis

import inference.priors as priors

logging.basicConfig()
logging.getLogger().setLevel(logging.WARNING)

# The relevant files. Make sure the data files exist; the output directories are created below if missing
params_file = "../data/inference/params.feather"
stats_file = "../data/inference/stats.csv"
samples_file = "../output/inference/posterior_samples.csv"
posterior_file = "../output/inference/posterior.pkl"
plots_dir = "../plots/inference"

# Create the output directories up front, so that pickling the posterior and saving the samples/plots
# later in the notebook doesn't fail on a fresh checkout
for output_dir in (os.path.dirname(samples_file), os.path.dirname(posterior_file), plots_dir):
    os.makedirs(output_dir, exist_ok=True)

# Seed the torch and numpy random number generators so that re-running the notebook is more reproducible
random_seed = 0
torch.manual_seed(random_seed)
np.random.seed(random_seed)

# Can scale priors to similar sizes. Doesn't seem to help, so leave as False
use_normalised_priors = False

## Read in data

You'll need the params.feather file (`params_file` above) and the stats.csv file (`stats_file` above).  
Also make sure the output directories exist: `../output/inference` for the posterior and samples, and `../plots/inference` for the plots.

The stats CSV has a column named `index` that gives, for each row of statistics, the row of the params file whose parameters produced it.

In [None]:
# Summary statistics, one row per simulation, indexed by the row of the params file they come from.
# The random seed column is not a statistic, and rows with any missing statistic are dropped.
stats_df = (
    pd.read_csv(stats_file, index_col="index")
    .sort_index()
    .drop(columns=["random_seed"])
    .dropna()
)

# Select exactly the parameter rows that have summary statistics, in the same order as stats_df,
# so that row i of theta lines up with row i of x. Fail loudly rather than silently misaligning.
all_params_df = pd.read_feather(params_file)
missing_rows = stats_df.index.difference(all_params_df.index)
if len(missing_rows) > 0:
    raise ValueError(
        f"{len(missing_rows)} rows of {stats_file} refer to rows missing from {params_file}, "
        f"e.g. {list(missing_rows[:5])}"
    )
# .copy() so that the in-place scaling below operates on an independent frame
params_df = all_params_df.loc[stats_df.index].copy()

# Create joint prior, and check its parameters are the params file's columns, in the same order.
# (Comparing lists also handles a differing number of parameters, unlike elementwise array comparison.)
prior = priors.join_priors(normalise=use_normalised_priors)
if list(params_df.columns) != list(prior.params):
    raise ValueError(
        "Parameter names do not match between params file and prior: "
        f"file has {list(params_df.columns)}, prior has {list(prior.params)}"
    )

# Scale params table accordingly
if use_normalised_priors:
    for param in params_df:
        params_df[param] -= priors.transforms[param]['loc']
        params_df[param] /= priors.transforms[param]['scale']

# Get the parameters and outputs into torch tensors, one row per simulation
theta = tensor(params_df.values, dtype=torch.float32)
x = tensor(stats_df.values, dtype=torch.float32)
assert theta.shape[0] == x.shape[0], "theta and x must have one row per simulation"

num_simulations = len(stats_df.index)
num_params = len(params_df.columns)
num_stats = len(stats_df.columns)
num_samples = num_simulations
print(f'Using {num_simulations} simulations, {num_params} params ({"" if use_normalised_priors else "un"}normalised), {num_stats} summary stats and making {num_samples} samples')

In [None]:
# Compare the joint prior with (a) as many fresh draws from it as there are simulations and
# (b) the parameter sets stored in the feather file.
prior_samples = np.array(prior.sample((num_simulations,)))
file_samples = params_df.values
sample_sets = [prior_samples, file_samples]
sample_labels = ["Prior samples", "File samples"]

fig, axes = priors.plot_samples_vs_prior(prior, sample_sets, sample_labels)
fig.suptitle("Samples drawn from joint prior distribution, compared with parameters from feather file. Red line shows prior distribution.")
plt.show()

## Training

Train the density estimator and build the posterior from it. This should take ~3-4 minutes.

In [None]:
%%time
# Calculate the posterior
inference = SNPE(prior=process_prior(prior)[0], density_estimator='maf')
inference = inference.append_simulations(theta, x)
density_estimator = inference.train(show_train_summary=True)
posterior = inference.build_posterior(density_estimator)
with open(posterior_file, "wb") as f:
    pickle.dump(posterior, f)
    print(f'Dumped posterior object to "{posterior_file}"')
print(posterior)

Alternatively, if training has already been done, load the posterior from the pickled file. Only load pickle files you trust, because unpickling can execute arbitrary code:

In [None]:
# Read the posterior from the pkl file if the training cell hasn't been run in this kernel session.
# Only a missing `posterior` name triggers loading; any other error is no longer silently swallowed.
# NOTE(security): pickle.load can execute arbitrary code - only load posterior files you created yourself.
try:
    posterior
except NameError:
    with open(posterior_file, "rb") as f:
        posterior = pickle.load(f)
    print(f'Loaded posterior from "{posterior_file}"')
    print(posterior)

## Sampling & Plotting
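Once you have the posterior (trained or loaded above), you can run any of these tasks individually. The final plot also needs the posterior means computed in the cell just before it.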

**Draw `num_samples` samples from the posterior, given the first simulation's summary statistics (`x[0]`), save them to a CSV and plot them**

In [None]:
# Sample the posterior conditioned on the summary statistics of the first simulation (x[0]), then
# write the samples to samples_file with one column per parameter.
observed_x = x[0]
posterior_samples = posterior.sample((num_samples,), x=observed_x, sample_with_mcmc=False)
print(f'Done, saving to "{samples_file}"')
samples_table = pd.DataFrame(np.array(posterior_samples), columns=prior.params)
samples_table.to_csv(samples_file, index=False)

In [None]:
# Save each figure BEFORE plt.show(). With the notebook's inline backend the figure is closed once it is
# shown, so a plt.savefig afterwards would write a new, empty figure (the original pairplot bug).
fig, axes = priors.plot_samples_vs_prior(prior, posterior_samples, "Posterior samples", axsize=4, num_cols=4)
fig.suptitle(f'{posterior_samples.shape[0]} samples drawn from posterior distributions of each parameter')
plt.savefig("../plots/inference/posterior_samples.jpg")
plt.show()

# Pairplot of the posterior samples (sbi.analysis.pairplot creates the current figure)
analysis.pairplot(posterior_samples, labels=prior.params, figsize=(20, 20))
plt.savefig("../plots/inference/pairplot.jpg", bbox_inches='tight')
plt.show()

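**For each of the first `num_sample_sets` simulations, plot the mean of the posterior samples for each parameter against the actual value, with error bars spanning the 2.5%-97.5% sample quantiles**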

In [None]:
# For each of the first num_sample_sets simulations, sample the posterior conditioned on that simulation's
# summary statistics. For each parameter, record the sample mean and the distances from the mean to the
# 2.5% and 97.5% quantiles, which give asymmetric error bars for a 95% interval.
# Clamp to the number of simulations available, since x[i] is indexed below.
num_sample_sets = min(20, num_simulations)
quantile_probs = [0.025, 0.975]
means_per_set = []   # one (num_params,) array per sample set
errors_per_set = []  # one (num_params, 2) array per sample set
for i in range(num_sample_sets):
    print(f'Sampling posterior [{i + 1}/{num_sample_sets}]', end='', flush=True)
    samples = np.array(posterior.sample((num_samples,), x=x[i], sample_with_mcmc=False, show_progress_bars=False)) # (num_samples, num_params)

    means = np.mean(samples, axis=0) # (num_params,)
    # mquantiles returns a masked array; convert to a plain ndarray before stacking
    quantiles = np.asarray(scipy.stats.mstats.mquantiles(samples, prob=quantile_probs, axis=0)).T # (num_params, 2)
    errors = np.abs(quantiles - means[:, np.newaxis]) # (num_params, 2)

    means_per_set.append(means)
    errors_per_set.append(errors)
    print("\r                                  \r", end='')

# Build the result arrays once, sized from num_params rather than a hard-coded parameter count
# theta_sample_means is of shape (# of params, # of means calculated)
# theta_sample_errors is of shape (# of params, 2, # of means calculated)
if means_per_set:
    theta_sample_means = np.stack(means_per_set, axis=1)
    theta_sample_errors = np.stack(errors_per_set, axis=2)
else:
    theta_sample_means = np.empty((num_params, 0))
    theta_sample_errors = np.empty((num_params, 2, 0))

print(f'Sampled posterior {num_sample_sets} times.')

In [None]:
# One panel per parameter shows the posterior-sample mean, with its quantile error bars, against the
# actual parameter value for each of the first num_sample_sets simulations.
fig, axes = priors.create_params_plot(prior, axsize=4, num_cols=4)
actual_values = np.array(theta[:num_sample_sets]).T  # one row per parameter
panels = zip(axes, actual_values, theta_sample_means, theta_sample_errors, prior.params)
for ax, actual, mean, err, _param in panels:
    ax.errorbar(actual, mean, yerr=err, marker='x', linestyle='')
    ax.set_xlabel("Actual value")
    ax.set_ylabel("Mean of posterior samples")
    # Give the x axis the same limits as the y axis so both span the same value range
    ax.set_xlim(*ax.get_ylim())
plt.savefig("../plots/inference/samples_vs_actual.jpg", bbox_inches='tight')