# Assorted sampling
This notebook gives users templates for fitting their data with several different samplers. It assumes that you have already read and run emulator_demo. Explanation is kept to a minimum, but each section contains what you need to use a particular sampling technique. We cover four sampling techniques here. Many more exist, along with packages that implement them, and they should be straightforward to adopt as long as you know how to call RTFAST and convolve its output with an instrument response.

## Nested sampling (dynesty)
This is the technique used in emulator_demo. This section is largely a repeat of that seen in emulator_demo.

We can also explore our parameter space with nested sampling. You can find out more about dynesty and how nested sampling works here: https://dynesty.readthedocs.io/en/stable/. What we care about is how it can help us map the posterior of our fairly complex and degenerate model.

Nested sampling has some notable advantages, principally that it provides statistical uncertainty estimates and can sample complex, multimodal distributions. It also tends to be quite fast compared with emcee, and it computes the Bayesian evidence, which is key for model comparison.
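
dynesty handles all of the bookkeeping, but the core nested sampling loop is short. The plain-NumPy sketch below is a toy illustration, not dynesty's implementation. It uses a hypothetical 2-D Gaussian likelihood with a uniform prior on [-5, 5]², whose evidence is known analytically (Z ≈ 1/100). At each step it discards the lowest-likelihood live point, credits it with the prior volume shell X_{i-1} - X_i (with X_i = exp(-i/nlive)), and replaces it with a prior draw of higher likelihood. The prior is written as a unit-cube transform, which is the convention dynesty also uses. The naive rejection step is far less efficient than dynesty's bounding and sampling schemes. All function names and parameters here are illustrative.

```python
import numpy as np

def toy_log_likelihood(theta):
    # Normalised 2-D standard normal likelihood (a toy stand-in for the emulator fit)
    return -0.5 * np.sum(theta**2, axis=-1) - np.log(2 * np.pi)

def toy_prior_transform(u, lo=-5.0, hi=5.0):
    # Map unit-cube draws to a uniform prior on [lo, hi] in every dimension
    return lo + (hi - lo) * u

def toy_nested_sampling(loglike, ptform, ndim, nlive=200, dlogz=0.01, seed=0):
    rng = np.random.default_rng(seed)
    live_u = rng.random((nlive, ndim))
    live_logl = loglike(ptform(live_u))
    logz, logx = -np.inf, 0.0  # running log-evidence and log remaining prior volume
    log_shell = np.log1p(-np.exp(-1.0 / nlive))  # log(1 - exp(-1/nlive))
    # Continue while the live points could still raise log Z by at least dlogz
    while np.logaddexp(logz, live_logl.max() + logx) - logz >= dlogz:
        worst = np.argmin(live_logl)
        logl_min = live_logl[worst]
        # Weight of the discarded point: L_min * (X_{i-1} - X_i)
        logz = np.logaddexp(logz, logl_min + logx + log_shell)
        logx -= 1.0 / nlive
        # Replace the worst point with a prior draw that has L > L_min (naive rejection)
        while True:
            cand = rng.random((1000, ndim))
            cand_logl = loglike(ptform(cand))
            ok = np.flatnonzero(cand_logl > logl_min)
            if ok.size:
                live_u[worst], live_logl[worst] = cand[ok[0]], cand_logl[ok[0]]
                break
    # Add the remaining live points, each carrying prior volume X / nlive
    return np.logaddexp(logz, logx - np.log(nlive) + np.logaddexp.reduce(live_logl))

toy_logz = toy_nested_sampling(toy_log_likelihood, toy_prior_transform, ndim=2)
print(f"log Z = {toy_logz:.2f} (analytic: {np.log(0.01):.2f})")
```

The evidence comes out as a by-product of the run, which is why nested sampling is convenient for model comparison.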

## MCMC (Emcee)
X-ray astronomers tend to be familiar with emcee already, since it is used in the various X-ray fitting packages available. The rtdist model has been used with emcee before, but convergence took on the order of months, which makes it impractical for real fitting. emcee still needs some time to converge with the emulator, but this drops to the order of hours on a standard laptop, which makes investigations with emcee far more achievable.

In [None]:
import numpy as np
import emcee
import corner
import matplotlib.pyplot as plt
from scipy.optimize import minimize

Before running emcee, we need a sensible starting point for our walkers. We use scipy's minimize optimizer to find a good starting point, so that the walkers don't get stuck in low-probability regions, and then perturb it to form a small Gaussian ball of walkers around that point.

In [None]:
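# specify number of dimensions (parameters) and number of walkers
ndim, nwalkers = 20, 100

# minimize finds minima, so optimise the negative log-likelihood
nll = lambda *args: -log_likelihood(*args)
# random starting guess drawn uniformly within the prior limits
p0 = np.random.rand(ndim)
initial = (p0*(limits[:,1]-limits[:,0])) + limits[:,0]
# here the optimiser starts from the true parameters nn_pars; for real data,
# where the truth is unknown, start from `initial` instead
soln = minimize(nll, np.asarray(nn_pars))

# Gaussian ball of walkers around the optimum, shape (nwalkers, ndim)
start_pos = soln.x + 0.01*np.random.randn(nwalkers, ndim)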

#create emcee sampler
sampler = emcee.EnsembleSampler(nwalkers, ndim, log_likelihood)
print("Sampler started")
max_n = 50000 # maximum iterations, set this larger if not converging

# We'll track how the average autocorrelation time estimate changes
index = 0
autocorr = np.empty(max_n)

# This will be useful for testing convergence
old_tau = np.inf

# Now we'll sample for up to max_n steps
for sample in sampler.sample(start_pos, iterations=max_n, progress=True):
    # Only check convergence every 500 steps
    if sampler.iteration % 500:
        continue

    # Compute the autocorrelation time so far
    # Using tol=0 means that we'll always get an estimate even
    # if it isn't trustworthy
    tau = sampler.get_autocorr_time(tol=0)
    autocorr[index] = np.mean(tau)
    index += 1

    # Check convergence
    converged = np.all(tau * 100 < sampler.iteration)
    converged &= np.all(np.abs(old_tau - tau) / tau < 0.01)
    if converged:
        print("Converged")
        break
    old_tau = tau

We should check the acceptance fraction and whether the chains have actually converged.

In [None]:
print(
    "Mean acceptance fraction: {0:.3f}".format(
        np.mean(sampler.acceptance_fraction)
    )
)
print(
    "Mean autocorrelation time: {0:.3f} steps".format(
        np.mean(sampler.get_autocorr_time())
    )
)
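
`get_autocorr_time` returns the integrated autocorrelation time τ, and the loop above treats the chains as converged once they are longer than 100τ and τ has stabilised. The sketch below is a plain-NumPy estimator written in the spirit of the method described in the emcee documentation, not emcee's own code. It computes the FFT autocorrelation function averaged over walkers and truncates the sum with Sokal's automated window (c = 5). The test chain is a hypothetical AR(1) process, for which τ = (1 + φ)/(1 - φ) is known exactly.

```python
import numpy as np

def autocorr_func_1d(x):
    # Normalised autocorrelation function via FFT (zero-padded to avoid wrap-around)
    x = np.asarray(x, dtype=float)
    n = len(x)
    f = np.fft.rfft(x - x.mean(), n=2 * n)
    acf = np.fft.irfft(f * np.conj(f), n=2 * n)[:n]
    return acf / acf[0]

def integrated_time(chain, c=5.0):
    # chain: (nsteps, nwalkers) for one parameter; average the ACF over walkers
    rho = np.mean([autocorr_func_1d(chain[:, k]) for k in range(chain.shape[1])], axis=0)
    taus = 2.0 * np.cumsum(rho) - 1.0
    # Sokal's automated windowing: first lag M with M >= c * tau(M)
    m = np.arange(len(taus)) < c * taus
    window = len(taus) - 1 if np.all(m) else np.argmin(m)
    return taus[window]

# Toy AR(1) chains x_t = phi * x_{t-1} + noise, with exact tau = (1 + phi) / (1 - phi) = 3
toy_rng = np.random.default_rng(1)
toy_phi, toy_steps, toy_walkers = 0.5, 20000, 8
toy_chain = np.empty((toy_steps, toy_walkers))
toy_chain[0] = toy_rng.standard_normal(toy_walkers)
for t in range(1, toy_steps):
    toy_chain[t] = toy_phi * toy_chain[t - 1] + np.sqrt(1 - toy_phi**2) * toy_rng.standard_normal(toy_walkers)
toy_tau = integrated_time(toy_chain)
print(f"tau = {toy_tau:.2f}; passes the 100*tau rule: {100 * toy_tau < toy_steps}")
```

A chain that fails the 100τ rule is not necessarily wrong, but its τ estimate (and hence any thinning based on it) should not yet be trusted.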

As a word of caution, rtdist tends to have a low acceptance fraction, so make sure the chains have genuinely converged. Let's plot our chains.

In [None]:
fig, axes = plt.subplots(len(labels), figsize=(10, len(labels)*3), sharex=True)
samples = sampler.get_chain()
true_pars = np.asarray(nn_pars)
for i in range(ndim):
    ax = axes[i]
    ax.plot(samples[:, :, i], "k", alpha=0.3)
    ax.set_xlim(0, len(samples))
    ax.set_ylabel(labels[i])
    ax.yaxis.set_label_coords(-0.1, 0.5)
    ax.axhline(true_pars[i],ls="--",c="b")  # mark the true parameter value

axes[-1].set_xlabel("step number");

Let's discard the first 10000 steps to be safe, thin the chains by a factor of 3, and plot the posteriors.

In [None]:
flat_samples = sampler.get_chain(discard=10000,thin=3,flat=True)
figure = corner.corner(
    flat_samples,
    labels=labels,show_titles=True,
    truths = np.asarray(nn_pars),
    title_kwargs={"fontsize": 12},
)
plt.show()
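
The fixed `discard=10000` and `thin=3` above are conservative hand-picked values. An alternative heuristic, which appears in examples in the emcee documentation, sets them from the autocorrelation time: discard about 2·max(τ) steps and thin by about min(τ)/2. The sketch applies that rule to a NumPy array shaped like `sampler.get_chain()` (steps, walkers, parameters). `burn_and_thin` is a hypothetical helper, and emcee's own `get_chain(discard=..., thin=..., flat=True)` may index the retained steps slightly differently.

```python
import numpy as np

def burn_and_thin(chain, tau):
    # chain: (nsteps, nwalkers, ndim); tau: autocorrelation time per parameter
    # Heuristic: discard a few autocorrelation times, thin by about half of one
    discard = int(2 * np.max(tau))
    thin = max(1, int(0.5 * np.min(tau)))
    kept = chain[discard::thin]  # drop burn-in steps, keep every `thin`-th step
    # Merge the step and walker axes into one flat list of samples
    return kept.reshape(-1, chain.shape[-1]), discard, thin

toy_chain = np.random.default_rng(0).standard_normal((1000, 4, 3))
toy_flat, toy_discard, toy_thin = burn_and_thin(toy_chain, tau=np.array([10.0, 20.0, 30.0]))
print(toy_flat.shape, toy_discard, toy_thin)  # (752, 3) 60 5
```

With a real run, the inputs would be `sampler.get_chain()` and `sampler.get_autocorr_time()`.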

## Hamiltonian Monte Carlo (Pyro)
Suppose you're tired of waiting hours for your fits, or you want to fit multiple objects simultaneously. Your parameter space then becomes considerably larger, techniques like MCMC and nested sampling begin to struggle to sample it effectively, and the time to converge becomes even longer. Instead, we can use a variant of Monte Carlo sampling called Hamiltonian Monte Carlo (HMC).

HMC differs in one key respect: it uses the gradient of the log-posterior with respect to the parameters to propose moves that follow the shape of the posterior, which raises the acceptance rate of its proposals. There are also several GPU-compatible HMC implementations in Python. Because our neural network is GPU compatible, we can load it onto a GPU and run the fit there, which greatly speeds up model evaluation and makes parallelisation straightforward.

This achieves two things. First, sample evaluations run in parallel, so we can run considerably more chains (giving better coverage of the posterior). Second, each evaluation is faster and fewer evaluations are needed to converge.

HMC scales well to higher dimensions, which lets us start building hierarchical models that constrain population statistics and other properties that all sources should share. This opens up many interesting research avenues that we could not really access by hacking X-ray fitting software such as Xspec.
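
To make the role of the gradient concrete, here is a minimal HMC sampler in plain NumPy on a toy 2-D standard normal target. It illustrates the technique and is not Pyro's implementation. Each iteration draws a fresh random momentum and follows Hamiltonian dynamics with the leapfrog integrator, which needs ∇ log p. It then accepts or rejects the endpoint with a Metropolis test on the total energy. The step size and number of leapfrog steps are fixed by hand here (both are illustrative choices); the NUTS sampler used below adapts the step size during warmup and chooses the trajectory length automatically.

```python
import numpy as np

def toy_log_prob(q):
    return -0.5 * np.dot(q, q)  # toy standard normal target

def toy_grad_log_prob(q):
    return -q  # gradient of the log-density above

def hmc(log_prob, grad, q0, n_samples=5000, step=0.1, n_leapfrog=10, seed=0):
    rng = np.random.default_rng(seed)
    q = np.asarray(q0, dtype=float)
    draws, n_accept = [], 0
    for _ in range(n_samples):
        p = rng.standard_normal(q.shape)        # fresh momentum each iteration
        q_new, p_new = q.copy(), p.copy()
        p_new += 0.5 * step * grad(q_new)       # leapfrog: half step in momentum
        for i in range(n_leapfrog):
            q_new += step * p_new               # full step in position
            if i < n_leapfrog - 1:
                p_new += step * grad(q_new)     # full step in momentum
        p_new += 0.5 * step * grad(q_new)       # final half step in momentum
        # Metropolis test on the Hamiltonian H = -log p(q) + |p|^2 / 2
        h_old = -log_prob(q) + 0.5 * p @ p
        h_new = -log_prob(q_new) + 0.5 * p_new @ p_new
        if np.log(rng.random()) < h_old - h_new:
            q, n_accept = q_new, n_accept + 1
        draws.append(q.copy())
    return np.array(draws), n_accept / n_samples

hmc_draws, hmc_acc = hmc(toy_log_prob, toy_grad_log_prob, q0=np.zeros(2))
print(f"acceptance {hmc_acc:.2f}, mean {hmc_draws.mean(axis=0).round(2)}, var {hmc_draws.var(axis=0).round(2)}")
```

Because the leapfrog integrator nearly conserves energy, almost every long, gradient-guided move is accepted. This is the property that lets HMC keep working as the number of parameters grows.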

We'll use Pyro in this notebook because it is built on PyTorch, the same framework as our emulator.

In [None]:
import numpy as np
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro.infer import MCMC, NUTS
from network import RTFAST
from pyro.nn import PyroModule

Let's write a Pyro model that samples the free parameters from their priors and compares the emulator prediction with the data.

In [None]:
emulator = SpectralEmulator()
# The network weights stay fixed during the fit (only the physical parameters are
# sampled), so the emulator is called directly as an ordinary torch module

def pyro_model(data):
    # Uniform priors on the free parameters (log10 where the range spans decades)
    h   = pyro.sample("h",   dist.Uniform(np.log10(1.5), np.log10(10)))
    a   = pyro.sample("a",   dist.Uniform(0.5, 0.998))
    inc = pyro.sample("inc", dist.Uniform(np.log10(30), np.log10(80)))
    rin = pyro.sample("rin", dist.Uniform(np.log10(1), np.log10(20)))
    gam = pyro.sample("gam", dist.Uniform(2, 2.75))
    dis = pyro.sample("dis", dist.Uniform(np.log10(3.5e5), np.log10(5e7)))
    afe = pyro.sample("afe", dist.Uniform(np.log10(0.5), np.log10(3)))
    lNe = pyro.sample("lNe", dist.Uniform(15,20))
    nH  = pyro.sample("nH",  dist.Uniform(np.log10(1e-3), np.log10(1)))
    Ano = pyro.sample("Ano", dist.Uniform(np.log10(1e-4),np.log10(1e-1)))

    # Fixed (non-fitted) parameters as float64 scalars
    fixed = lambda v: torch.tensor(v, dtype=torch.float64)
    rout = fixed(np.log10(2e4))
    z    = fixed(0.024917)
    kte  = fixed(np.log10(50))
    bst  = fixed(np.log10(1))
    mas  = fixed(np.log10(3e6))
    hor  = fixed(0.02)
    b1   = fixed(0.0)
    b2   = fixed(0.0)
    pAB  = fixed(-0.8)
    g    = fixed(0.3)
    # Stack parameters in RTFAST input order, casting everything to float64
    params = torch.stack([p.double() for p in
                          (h,a,inc,rin,rout,z,gam,dis,afe,lNe,kte,nH,bst,mas,hor,b1,b2,pAB,g,Ano)])

    # call emulator and convolve with instrument response
    pred = emulator(params)
    convolved = torch.matmul(pred,resp.resp_matrix.double())
    pred = torch.transpose(convolved,0,-1)

    # Floor the predicted counts so no channel has a zero Poisson rate
    pred = torch.clamp(pred, min=1e-11)

    # Each channel is an independent Poisson observation
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Poisson(pred), obs=data)
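
The `dist.Poisson(pred)` statement with `obs=data` scores each channel with the Poisson log-likelihood, k log λ - λ - log k!. The NumPy sketch below computes the same quantity outside Pyro (`poisson_loglike` is a hypothetical helper, not part of the notebook) and shows why the 1e-11 floor matters. Without it, a channel with observed counts but a predicted rate of exactly zero would contribute log(0) = -inf and break the gradients.

```python
import math
import numpy as np

def poisson_loglike(counts, model, floor=1e-11):
    # Poisson log-likelihood summed over channels: sum(k log(lam) - lam - log(k!))
    counts = np.asarray(counts, dtype=float)
    lam = np.maximum(np.asarray(model, dtype=float), floor)  # same floor as the Pyro model
    log_kfact = np.array([math.lgamma(k + 1.0) for k in counts])
    return float(np.sum(counts * np.log(lam) - lam - log_kfact))

print(poisson_loglike([0, 1, 2], [1.0, 1.0, 1.0]))  # -3 - log(2) ≈ -3.693
print(poisson_loglike([2, 3], [0.0, 2.0]))          # ≈ -53.06: heavily penalised, but finite
```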


Now that we've defined our model in pyro, let's run a NUTS sampler to retrieve our posterior.

In [None]:
nuts_kernel = NUTS(pyro_model)

# Set up MCMC with the NUTS sampler
# num_samples=10 with a single chain is only a quick test; increase both for a real fit
mcmc = MCMC(nuts_kernel, num_samples=10, warmup_steps=400, num_chains=1)

# Run MCMC to sample from the posterior
mcmc.run(torch.Tensor(pois_obs))

# Get samples from the posterior
posterior_samples = mcmc.get_samples()

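Let's check whether the sampler has converged by calculating the Gelman-Rubin statistic, R-hat. R-hat is close to or above 1, and values close to 1 (commonly below 1.01, or below 1.1 as a looser rule) indicate convergence. Running more than one chain makes this check much more informative.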

In [None]:
import arviz as az

# Convert Pyro samples to InferenceData
inference_data = az.from_pyro(mcmc)
# Calculate Gelman-Rubin statistic (PSRF)
rhat = az.rhat(inference_data)
print(rhat)
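
By default `az.rhat` computes a rank-normalised split R-hat. The simpler classic split R-hat shows the underlying idea, and the NumPy sketch below implements that classic version (not arviz's code). Each chain is split in half, and the between-chain variance of the halves is compared with the average within-chain variance. If all halves explore the same distribution, the ratio is close to 1. With a single chain, as in the run above, splitting still gives two halves, so R-hat can only flag drift within that one chain. The chains used here are toy data.

```python
import numpy as np

def split_rhat(chains):
    # chains: (n_chains, n_draws) for one parameter; split each chain in half
    chains = np.asarray(chains, dtype=float)
    n = chains.shape[1] // 2
    split = np.concatenate([chains[:, :n], chains[:, n:2 * n]], axis=0)
    w = split.var(axis=1, ddof=1).mean()        # mean within-chain variance
    b = n * split.mean(axis=1).var(ddof=1)      # between-chain variance
    var_plus = (n - 1) / n * w + b / n          # pooled estimate of the posterior variance
    return np.sqrt(var_plus / w)

rhat_rng = np.random.default_rng(0)
mixed = rhat_rng.standard_normal((4, 2000))
unmixed = mixed + np.array([[0.0], [0.0], [3.0], [3.0]])  # two chains stuck elsewhere
print(f"mixed chains: {split_rhat(mixed):.3f}, unmixed chains: {split_rhat(unmixed):.3f}")
```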

Once we're happy that the sampler has converged, we can plot our posteriors.

In [None]:
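# corner needs an (n_samples, n_params) array rather than a dict of tensors, so
# stack the draws by name in the RTFAST input order of the free parameters
hmc_names = ["h", "a", "inc", "rin", "gam", "dis", "afe", "lNe", "nH", "Ano"]
hmc_array = np.column_stack([posterior_samples[n].cpu().numpy() for n in hmc_names])
figure = corner.corner(hmc_array, labels=hmc_names, show_titles=True,
    truths = np.asarray(nn_pars_hmc),
    title_kwargs={"fontsize": 12},
)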
plt.show()

Finally, let's plot draws from our posteriors to see if our model outputs are reasonable compared to the data.

In [None]:
# Stack the posterior draws into an (n_samples, n_params) array, selecting the
# parameters by name in the same order as the RTFAST inputs
hmc_names = ["h", "a", "inc", "rin", "gam", "dis", "afe", "lNe", "nH", "Ano"]
hmc_samples = np.column_stack([posterior_samples[n].cpu().numpy() for n in hmc_names])

inds = np.random.randint(len(hmc_samples), size=100)
model_draws = []
for ind in inds:
    sample = hmc_samples[ind]
    model_eval = convolve_sim_fixed(sample)
    model_draws.append(model_eval)
    plt.plot(emid, model_eval, "C1", alpha=0.1)

model_draws = np.asarray(model_draws)
plt.plot(emid,np.mean(model_draws,axis=0),"r",label="Mean model")
plt.plot(emid, convolved, "k", label="truth",ls="--")
plt.legend(fontsize=14)
plt.xlabel("Energy (keV)")
plt.ylabel("Photons")
plt.xscale("log")
plt.yscale("log")
plt.ylim(1)
plt.show()

## Simulation based inference (SBI)
Yet another alternative fitting procedure is simulation-based inference (SBI). SBI trains a neural network on simulated data so that it learns the posterior of the simulator's parameters given an observation. This can be very fast, but be cautious: papers have shown that SBI can give overly optimistic posteriors in comparison to MCMC: https://arxiv.org/abs/2110.06581. SBI also works differently from traditional likelihood-based methods: instead of evaluating a likelihood, it relies on simulations that reproduce the instrumental effects and the source's Poisson noise.
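
The simplest form of simulation-based inference is rejection approximate Bayesian computation (ABC). It is a different and much cruder technique than the neural posterior estimation used below, but it shares the same ingredients: draws from the prior, a simulator that includes the noise, and a comparison with the observation. Neural methods replace the accept/reject step with a learned density estimator, so far fewer simulations are wasted. The toy model here is hypothetical: a flat 20-channel spectrum with a single amplitude, Poisson noise, and total counts as the summary statistic.

```python
import numpy as np

abc_rng = np.random.default_rng(42)
toy_nbins, toy_amp = 20, 5.0
toy_shape = np.ones(toy_nbins)                # toy flat spectral shape; model = amp * shape
toy_observed = abc_rng.poisson(toy_amp * toy_shape)

def toy_simulator(amp):
    # Forward model including Poisson noise: one simulated spectrum per parameter draw
    return abc_rng.poisson(amp[:, None] * toy_shape)

# 1. Draw parameters from a uniform prior and simulate a spectrum for each draw
toy_theta = abc_rng.uniform(0.0, 20.0, size=200_000)
toy_sims = toy_simulator(toy_theta)
# 2. Keep the draws whose simulated total counts are close to the observed total
toy_accepted = toy_theta[np.abs(toy_sims.sum(axis=1) - toy_observed.sum()) <= 2]
print(f"{toy_accepted.size} accepted; amplitude {toy_accepted.mean():.2f} +/- {toy_accepted.std():.2f}")
```

The accepted draws approximate the posterior, and tightening the tolerance or using richer summaries trades more rejected simulations for a more accurate posterior.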

Let's start by importing the packages needed for sbi.

In [None]:
import torch

from sbi.analysis import pairplot
from sbi.inference import SNPE, simulate_for_sbi
from sbi.utils import BoxUniform
from sbi.utils.user_input_checks import (
    check_sbi_inputs,
    process_prior,
    process_simulator,
)

Here we use the convolved simulation function that we used in nested sampling as our simulator, since it reproduces what our final data look like. Because SBI learns from simulated observations, the simulator must also include the noise, so we draw Poisson counts from the predicted spectrum:

In [None]:
def simulator_sbi(theta):
    pred = convolve_sim_fixed(theta)
    return np.random.poisson(pred)

We then need to define our priors in sbi's format. Only three parameters (Γ, N_H and A_norm) are left free in this example.

In [None]:
num_dim = 3  # number of free parameters fitted here (Gamma, nH, anorm)

height_range = [np.log10(1.5),np.log10(10)]
spin_range = [0.5,0.998]
inclination_range = [np.log10(30),np.log10(80)]
r_inner_range = [np.log10(1),np.log10(20)]
Gamma_range = [2,2.75]
distance_range = [np.log10(3.5e5),np.log10(5e7)]
Afe_range = [np.log10(0.5),np.log10(3)]
logNe_range = [15,20]
nH_range = [np.log10(1e-3),np.log10(1)]
anorm_range = [np.log10(1e-4),np.log10(1e-1)]

# only Gamma, nH and the normalisation are left free in this example
prior_ranges = [Gamma_range,nH_range,anorm_range]

prior_ranges = np.asarray(prior_ranges)

prior = BoxUniform(low=torch.Tensor(prior_ranges[:,0]), high=torch.Tensor(prior_ranges[:,1]))
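
`BoxUniform` places an independent uniform prior on each parameter, so its log-density is the constant -Σ log(high - low) inside the box and -inf outside. The NumPy class below mirrors that behaviour for the three ranges above, which can be handy for sanity-checking the ranges (for example, that the true values lie inside) before training. `ToyBoxPrior` is hypothetical and not part of sbi.

```python
import numpy as np

class ToyBoxPrior:
    # Hypothetical stand-in for a box-uniform prior: independent uniforms per dimension
    def __init__(self, low, high, seed=0):
        self.low, self.high = np.asarray(low, float), np.asarray(high, float)
        self.rng = np.random.default_rng(seed)

    def sample(self, n):
        return self.rng.uniform(self.low, self.high, size=(n, self.low.size))

    def log_prob(self, theta):
        theta = np.atleast_2d(theta)
        inside = np.all((theta >= self.low) & (theta <= self.high), axis=1)
        # Constant density 1 / prod(high - low) inside the box, zero outside
        return np.where(inside, -np.sum(np.log(self.high - self.low)), -np.inf)

toy_ranges = np.array([[2, 2.75], [np.log10(1e-3), np.log10(1)], [np.log10(1e-4), np.log10(1e-1)]])
toy_prior = ToyBoxPrior(toy_ranges[:, 0], toy_ranges[:, 1])
# First point lies inside the box; the second has Gamma = 3 > 2.75
print(toy_prior.log_prob([[2.45, np.log10(5e-2), np.log10(4e-4)], [3.0, -1.0, -2.0]]))
```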

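We then simulate observations for the network to train on. Here we run two rounds of sequential neural posterior estimation (SNPE). Each round draws parameters from the current proposal, simulates spectra, and trains the density estimator; the resulting posterior, conditioned on our observation, becomes the proposal for the next round.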

In [None]:
prior, num_parameters, prior_returns_numpy = process_prior(prior)
simulator_sbi = process_simulator(simulator_sbi, prior, prior_returns_numpy)
check_sbi_inputs(simulator_sbi, prior)

inference = SNPE(prior)

posteriors = []
proposal = prior

num_rounds = 2

for round_idx in range(num_rounds):
    print(f"Round {round_idx + 1} of {num_rounds}")
    theta, x = simulate_for_sbi(simulator_sbi, proposal, num_simulations=10000)
    density_estimator = inference.append_simulations(
        theta, x, proposal=proposal
    ).train()
    posterior = inference.build_posterior(density_estimator)
    posteriors.append(posterior)
    proposal = posterior.set_default_x(pois_obs)

Let's plot the resulting posteriors.

In [None]:
for posterior in posteriors:
    samples = posterior.sample((10000,), x=pois_obs)
    _ = pairplot(samples,points=np.asarray([2.45,np.log10(5e-2),np.log10(4e-4)]),
             labels=[r"$\Gamma$",r"$N_H$",r"$A_{norm}$"],
             quantiles=[0.03,99.7])
    plt.show()

And plot some posterior draws.

In [None]:
# Draw from the final-round posterior, which is the one targeted at our observation
samples = posteriors[-1].sample((10000,), x=pois_obs)
fig, axs = plt.subplots(2,sharex=True,figsize=(8,8))
inds = np.random.randint(0,samples.shape[0],size=500)
model_draws = []
for ind in inds:
    sample = samples[ind]
    model_eval = convolve_sim_fixed(sample)
    model_draws.append(model_eval)
    axs[0].plot(emid, model_eval, "C1", alpha=0.1)

model_draws = np.asarray(model_draws)
axs[0].plot(emid,np.mean(model_draws,axis=0),"r",label="Mean model")
axs[0].plot(emid, pois_obs, "k", label="observation",ls="--")
axs[0].legend(fontsize=14)
fig.supxlabel("Energy (keV)")
axs[0].set_ylabel("Photons")
axs[0].set_xscale("log")
axs[0].set_yscale("log")
axs[0].set_ylim(1)
axs[1].scatter(emid,(pois_obs-np.mean(model_draws,axis=0))/np.sqrt(pois_obs),s=1,marker="+")
axs[1].set_ylabel(r"(data-model)/$\sqrt{\mathrm{data}}$")
plt.tight_layout()
plt.show()