In [None]:
import logging
import math
import os
from os.path import exists

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import torch
from torch import tensor
from sbi import analysis as analysis
import scipy.stats
import scipy.stats.mstats

import allel
from sim.model import GenotypeData
from sim.sum_stats import simple_sum

import inference
import inference.priors
import inference.analysis

logging.basicConfig()
logging.getLogger().setLevel(logging.WARNING)

# Seed the random number generators so prior/posterior sampling and network training are reproducible
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# The relevant files. Make sure the data files exist; the output directories are created below if missing
FILE_SUFFIX = "_real"
params_file = "../data/inference/params_widen2mean.feather"
stats_file = "../data/inference/stats_widen2mean.csv"
x_observed_file = "../data/inference/stats_observed.csv"
samples_file = f'../output/inference/posterior_samples{FILE_SUFFIX}.csv'
posterior_file = f'../output/inference/posterior{FILE_SUFFIX}.pkl'

# Create the output directories up front, so a long training run does not fail at the very end
# when it writes the posterior, the samples, or the plots
for output_dir in (os.path.dirname(samples_file), os.path.dirname(posterior_file), "../plots/inference"):
    os.makedirs(output_dir, exist_ok=True)

## Read in data

You need the parameters file (`params_file` above, a `.feather` file) and the summary-statistics file (`stats_file` above, a `.csv` file).  
Also make sure the output directories exist: `../output/inference` for the posterior and samples, and `../plots/inference` for the plots.

The stats CSV has a column named `index` that gives, for each row of statistics, the row of the params file whose parameters produced them.

In [None]:
# Drop the random seed column and any simulations with missing summary statistics
stats_df = pd.read_csv(stats_file, index_col="index").sort_values("index").drop(columns=["random_seed"]).dropna()

# Only keep the parameter sets that have summary statistics in the table above (and vice versa),
# then check that row i of params_df and row i of stats_df describe the same simulation.
params_df = pd.read_feather(params_file)
params_df = params_df[params_df.index.isin(stats_df.index)]
stats_df = stats_df[stats_df.index.isin(params_df.index)]
if not params_df.index.equals(stats_df.index):
    raise ValueError("Rows of the params file and the stats file do not line up (duplicate 'index' values in the stats file?)")

# Split the simulations in half: the first half trains the posterior, the second half is held out for testing
batch_size = len(params_df) // 2

# Get the parameters and outputs into torch tensors
theta = tensor(params_df.values[:batch_size], dtype=torch.float32)
x = tensor(stats_df.values[:batch_size], dtype=torch.float32)

theta_test = tensor(params_df.values[batch_size:batch_size * 2], dtype=torch.float32)
x_test = tensor(stats_df.values[batch_size:batch_size * 2], dtype=torch.float32)

# Create joint prior
prior = inference.priors.join_priors()
# Compare as plain lists: an element-wise numpy comparison breaks (instead of reporting the
# mismatch) when the two name lists have different lengths
if list(params_df.columns) != list(prior.params):
    raise ValueError("Parameter names do not match between params file and prior")
# Save prior samples to a feather file:
# pd.DataFrame(np.array(prior.sample((100000,))), columns=params_df.columns).to_feather("../data/inference/params_new.feather")

num_simulations = theta.shape[0]
num_tests = theta_test.shape[0]
num_params = len(params_df.columns)
num_stats = len(stats_df.columns)
# Draw as many posterior samples as there are held-out test simulations
num_samples = num_tests
print(f'Using {num_simulations} simulations, {num_tests} test samples, {num_params} params, {num_stats} summary stats and making {num_samples} samples')

In [None]:
def read_vcf(vcf_file, info_file, seq_length=64340295):
    """Read a VCF file and its sample info file into a GenotypeData object.

    Args:
        vcf_file: Path to the VCF file, read with ``allel.read_vcf``.
        info_file: Path to a whitespace-delimited sample info file with one header row;
            its second column (``usecols=1``) holds each sample's population code.
        seq_length: Sequence length passed to ``GenotypeData``. The default is the value
            previously hard-coded here.

    Returns:
        ``(subpops, data)``: ``subpops`` maps each population name to the array of sample
        indices in that population, with the codes "CAP", "DOM" and "WILD" renamed to
        "captive", "domestic" and "wild", plus "all_pops" covering every sample.
        ``data`` is the ``GenotypeData`` built from the call set.

    Raises:
        KeyError: if any of the population codes "CAP", "DOM" or "WILD" is absent from
            the info file.
    """
    callset = allel.read_vcf(vcf_file)
    pop = np.genfromtxt(info_file, dtype="str", usecols=1, skip_header=1)
    subpops = {pop_name: np.where(pop == pop_name)[0] for pop_name in np.unique(pop)}

    # Rename the short population codes used in the info file to descriptive names
    changekeys = [("CAP", "captive"), ("DOM", "domestic"), ("WILD", "wild")]
    missing = [old_name for old_name, _ in changekeys if old_name not in subpops]
    if missing:
        raise KeyError(f"Population codes {missing} not found in {info_file}; found {sorted(subpops)}")
    for old_name, new_name in changekeys:
        subpops[new_name] = subpops.pop(old_name)

    subpops["all_pops"] = np.arange(len(pop))
    data = GenotypeData(callset=callset, subpops=subpops, seq_length=seq_length)
    return subpops, data

# Load the observed summary statistics if they were computed before; otherwise compute them from the
# real data and cache them. An explicit existence check (rather than a bare `except:`) means a corrupt
# or unreadable cache file raises an error instead of being silently recomputed and overwritten.
if exists(x_observed_file):
    x_observed = np.loadtxt(x_observed_file, delimiter=',')
    print(f'Loaded summary statistics from "{x_observed_file}"')
else:
    print("Reading vcf...")
    geno_pops, geno_data = read_vcf("../data/E1/E1.vcf", "../data/E1/SampleInfo.txt")
    print("Calculating summary statistics...")
    x_observed = np.array(list(simple_sum(geno_data).values()))
    np.savetxt(x_observed_file, x_observed, delimiter=',')

# Standard deviations away from mean of each statistic (relative to the training simulations)
mu = x.mean(0)
sigma = x.std(0)
x_observed_err = (tensor(x_observed) - mu) / sigma
num_within = int((torch.abs(x_observed_err) < 3).sum())
print(f'{num_within}/{len(x_observed_err)} statistics within 3 standard devs of simulated statistics')

## Training

Train the posterior. Approximate run times:
- ~3-4 minutes for NPE
- ~20 minutes for 10 rounds of SNPE

In [None]:
%%time

if exists(posterior_file) and not 'overwrite_posterior_file' in globals():
    print(f'The posterior file already exists ({posterior_file}), run this cell again to overwrite it')
    overwrite_posterior_file = True
else:
    # posterior = inference.NPE(theta, x, dump_to_file=posterior_file)

    posterior = inference.SNPE(theta, x, x_o=x_observed, num_rounds=100, density_estimator='maf', dump_to_file=posterior_file, training_args={'use_combined_loss': True})

    # posterior = inference.NLE(theta, x, dump_to_file=posterior_file, mcmc_parameters={
    #     "num_chains": 1,
    #     "thin": 1,
    #     "warmup_steps": 100,
    #     "init_strategy": "sir",
    #     "sir_batch_size": 1000,
    #     "sir_num_batches": 100,
    # })

    print(posterior)
    posterior_samples = posterior.sample((num_samples,), x=x_test[0])
    print(f'Taken {num_samples} samples, saving to "{samples_file}"')
    pd.DataFrame(np.array(posterior_samples), columns=prior.params).to_csv(samples_file, index=False)

Alternatively, load the posterior from the pickled file, if training has been done previously:

In [None]:
# Guard against replacing a posterior already defined in this session: the first execution only sets
# an overwrite flag, and the next execution reloads the posterior from `posterior_file`.
if 'posterior' in globals() and 'overwrite_posterior' not in globals():
    print('The posterior is already defined, run this cell again to overwrite it')
    overwrite_posterior = True
else:
    # posterior_file is a pickle (.pkl): only load files you produced yourself, since unpickling can run code
    posterior = inference.load_posterior(posterior_file)
    # Re-arm the guard so a later accidental re-run warns again before replacing the posterior
    globals().pop('overwrite_posterior', None)

## Sampling & Plotting
Once you have the posterior, you can run any of the following tasks independently.

##### **Plot the priors, together with samples from `params_file`**

In [None]:
# Compare fresh draws from the joint prior with the training parameters read from the feather file.
# Drawing num_simulations prior samples keeps both sample sets the same size.
file_samples = np.array(theta)
prior_samples = np.array(prior.sample((num_simulations,)))
sample_sets = {
    "Prior samples": prior_samples,
    "File samples": file_samples,
}
fig, axes = inference.analysis.plot_samples_vs_prior(prior, list(sample_sets.values()), list(sample_sets.keys()))
fig.suptitle(
    "Samples drawn from joint prior distribution, compared with parameters from feather file. Red line shows prior distribution."
)
plt.show()

##### **Plot means of samples for each parameter, against the actual values** (using test data)

In [None]:
# Means of posterior samples vs the true parameters, for the first held-out test simulations
num_plotted_tests = 100
fig, axes = inference.analysis.plot_means_against_theta(
    prior,
    posterior,
    theta_test[:num_plotted_tests],
    x_test[:num_plotted_tests],
)
plt.savefig(f'../plots/inference/samples_vs_actual{FILE_SUFFIX}.jpg', bbox_inches='tight')
plt.show()

##### **Plot marginal samples** (using test data)

In [None]:
# Posterior marginals for the first held-out simulation, with its true parameter values marked
posterior_samples = posterior.sample((num_samples,), x=x_test[0])
fig, axes = inference.analysis.plot_samples_vs_prior(prior, posterior_samples, "Posterior samples", axsize=4, num_cols=4)

true_theta = theta_test[0]
actual_line_style = dict(ymax=1 / 1.2, lw=1, c=(0, 0, 0, 0.5), label="Actual theta")
for ax, actual in zip(axes, true_theta):
    ax.axvline(actual, **actual_line_style)
    ax.legend()

num_drawn = posterior_samples.shape[0]
fig.suptitle(f'{num_drawn} samples drawn from posterior distributions of each parameter')
plt.savefig(f'../plots/inference/posterior_samples{FILE_SUFFIX}.jpg')
plt.show()

##### **Pairplot** (using test data)

In [None]:
# Pairwise and marginal posterior plots for the first held-out simulation
posterior_samples = posterior.sample((num_samples,), x=x_test[0])
_ = analysis.pairplot(posterior_samples, labels=prior.params, figsize=(20, 20))
# FILE_SUFFIX is expected to start with "_", so it is appended directly (as for the other output
# files) rather than after an extra "_", which produced a doubled underscore in the file name
plt.savefig(f'../plots/inference/pairplot{FILE_SUFFIX}.jpg', bbox_inches='tight')
plt.show()

##### **KL Divergence** (using real data)

In [None]:
# Alternative (test data): KL divergence for each of the first 50 held-out simulations, then its spread
# kl = [inference.analysis.kl_divergence(prior, posterior, x=x_test[i], num_samples=2000, base=2) for i in range(50)]
# print(f'{min(kl)=}, {max(kl)=}, {np.mean(kl)=}, {np.std(kl)=}')

# KL divergence between prior and posterior conditioned on the observed statistics
# (base=2 presumably means the result is in bits — confirm against inference.analysis.kl_divergence)
divergence = inference.analysis.kl_divergence(
    prior,
    posterior,
    x=x_observed,
    num_samples=num_samples,
    base=2,
)
divergence

##### **Two NN Intrinsic Dimension Estimation** (using real data)

In [None]:
# Two-NN estimate of the intrinsic dimension of posterior samples conditioned on the observed statistics
posterior_samples = posterior.sample(
    (num_samples,),
    x=x_observed,
)
dimensions = inference.analysis.twonn_dimension(posterior_samples)
print(dimensions)

##### **Plot marginal samples** (using real data)

In [None]:
# Posterior marginals conditioned on the observed (real) summary statistics
posterior_samples = posterior.sample((num_samples,), x=x_observed)
fig, axes = inference.analysis.plot_samples_vs_prior(prior, posterior_samples, "Posterior samples", axsize=4, num_cols=4)
fig.suptitle(f'{posterior_samples.shape[0]} samples drawn from posterior distributions of each parameter')
# FILE_SUFFIX is expected to start with "_", so it is appended directly (as for the other output
# files) rather than after an extra "_", which produced a doubled underscore in the file name
plt.savefig(f'../plots/inference/posterior_samples_real{FILE_SUFFIX}.jpg')
plt.show()

##### **Pairplot** (using real data)

In [None]:
# Pairwise and marginal posterior plots conditioned on the observed (real) summary statistics
posterior_samples = posterior.sample((num_samples,), x=x_observed)
_ = analysis.pairplot(posterior_samples, labels=prior.params, figsize=(20, 20))
# FILE_SUFFIX is expected to start with "_", so it is appended directly (as for the other output
# files) rather than after an extra "_", which produced a doubled underscore in the file name
plt.savefig(f'../plots/inference/pairplot_real{FILE_SUFFIX}.jpg', bbox_inches='tight')
plt.show()