# Inference tutorial

This notebook will go over:

- how to simulate an offline training dataset;
- how to train the approximate posterior and the neural likelihood;
- how to run the HMC update procedure.

We begin by importing a few packages we will need:

In [None]:
import os
import time
import random
import logging
import torch as t
import numpy as np

from textwrap import wrap

import spt
import spt.config as cfg
import spt.inference.san as san
import spt.modelling.simulation as sim

from spt.types import Tensor
from spt.visualisation import plot_corner, plot_posteriors, ppplot
from spt.load_photometry import get_norm_theta, get_denorm_theta, get_denorm_theta_t, load_simulated_data, load_real_data

Since we're working inside a notebook, we'll change directory to the root of the SPItorch project so that we'll be able to access the example data and datasets in a portable way. We'll also perform some other one-time setup (selecting the device and defining a few helper functions), which will become relevant in later tutorials.

In [None]:
try: # One-time setup
    assert(_SETUP)
except NameError:
    os.chdir(os.path.split(spt.__path__[0])[0])
    dtype = t.float32
    device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")
    if device == t.device("cuda"):
        print(f'Using GPU for training')
        !nvidia-smi -L
    else:
        print("CUDA is unavailable; training on CPU.")
        
    def dc(x: Tensor) -> Tensor:
        return x.detach().cpu()
    def dcn(x: Tensor) -> np.ndarray:
        return x.detach().cpu().numpy()
        
    _SETUP = True

## Loading Configurations

We begin by loading configurations from the configuration file (`./spt/config.py`):

In [None]:
# Inference hyper-parameters and the forward-model configuration.
ip = cfg.InferenceParams()
fp = cfg.ForwardModelParams()

# Parameter de-normalisation helpers derived from the forward model
# (`dtt` is presumably the torch-tensor counterpart of `dt` — TODO confirm).
dt, dtt = get_denorm_theta(fp), get_denorm_theta_t(fp)

We can inspect the parameters for our Prospector forward model:

In [None]:
# Show the forward-model configuration loaded above.
print(fp)

## Simulating a Dataset

To simulate a dataset using the Prospector forward model, configure the `SamplingParams` in the configuration file, and provide these to the entrypoint of the `simulation` module (alias `sim`).

In [None]:
# Configuration for the offline simulation run (see `SamplingParams` in ./spt/config.py).
sp = cfg.SamplingParams()
print(sp)

In [None]:
# Run the simulation entrypoint with the sampling configuration above.
# NOTE(review): presumably this writes the dataset later read from
# `ip.dataset_loc`; confirm the two configs point at the same location.
sim.main(sp)

## Maximum Likelihood Training of SAN Posterior

We will now use the simulated dataset generated above to train a neural density estimator. Here we use the 'v2' variant of our _Sequential Autoregressive Network_ (SANv2), since it performs better.

<img src="https://share.maximerobeyns.com/sanv2.svg" max-width="800px" />

Here, we load the SANv2 parameters from the configuration, and initialise the neural density estimator:

In [None]:
# Build the approximate posterior Q (a SANv2) from its configuration.
mp = cfg.SANv2Params()
Q = san.SANv2(mp)
logging.info('Initialised %s', Q)

In [None]:
# This is not strictly necessary, but is useful for portability:
# record the target device on the model and move the model to it
# (assumes Q is a torch.nn.Module, so `.to` returns the moved module).
Q.device = device
Q = Q.to(device)

### Load the training data

Before we can proceed with training, we must load up the training dataset that we just simulated above:

In [None]:
# Build train/test loaders over the simulated dataset. Photometry goes through
# `np.log` and `t.from_numpy` (presumably applied in list order: log-transform,
# then convert to a tensor); parameters are normalised with `get_norm_theta(fp)`.
# The batch size is taken from the SAN's own parameters.
train_loader, test_loader = load_simulated_data(
    path=ip.dataset_loc,
    split_ratio=ip.split_ratio,
    batch_size=Q.params.batch_size,
    phot_transforms=[np.log, t.from_numpy],
    theta_transforms=[get_norm_theta(fp)],
)
logging.info('Created data loaders')

We will also create some convenience methods for later while we're at it:

In [None]:
# Convenience methods for drawing (photometry, parameters) pairs from the
# simulated test split, converted to tensors on the notebook's device/dtype.
tds = test_loader.dataset


def _as_device_tensor(v) -> Tensor:
    """Convert `v` (NumPy array or tensor) to a tensor on `device` with `dtype`."""
    if isinstance(v, np.ndarray):
        v = t.from_numpy(v)
    return v.to(device, dtype)


def sim_sample_at(ds_idx: int) -> tuple[int, tuple[Tensor, Tensor]]:
    """Return the index in the simulated dataset and the (x, y) pair stored at
    position `ds_idx` of the test split.

    `tds.indices` maps test-split positions back to the underlying dataset
    (assumes `tds` is a `torch.utils.data.Subset` — TODO confirm).
    """
    xys = tds[ds_idx]
    idx = tds.indices[ds_idx]
    return idx, (_as_device_tensor(xys[0]), _as_device_tensor(xys[1]))


def new_sim_sample() -> tuple[int, int, tuple[Tensor, Tensor]]:
    """Returns index in the test loader, index in the simulated dataset and (x, y) pair"""
    ds_idx = random.randrange(len(tds))  # uniform over test-loader positions
    idx, xy = sim_sample_at(ds_idx)
    return ds_idx, idx, xy

### Run the Training Procedure

We can now proceed to call the training method as follows:

In [None]:
# Maximum-likelihood training of the approximate posterior Q on the simulated
# training split, configured by `ip`.
Q.offline_train(train_loader, ip)
logging.info('ML training of approximate posterior complete.')

As a quick evaluation to see whether the trained model is any good, we can visualise some posteriors:

In [None]:
Q.eval()  # put the model in 'evaluation mode'
ds_idx, idx, (sim_xs, sim_ys) = new_sim_sample()  # pick a random sample

sim_xs, _ = Q.preprocess(sim_xs, sim_ys)
# Drawing samples needs no gradients; inference_mode avoids building an
# autograd graph over all of them.
n_corner_samples = 10000
with t.inference_mode():
    post_samples = dcn(Q.sample(sim_xs, n_corner_samples)).squeeze()

plot_corner(samples=post_samples,
            true_params=dcn(sim_ys).squeeze(),
            lims=fp.free_param_lims(normalised=True),
            title=f'$Q(\\theta \\vert x_{{{idx:,}}})$ simulated test point posterior',
            description="\n".join(wrap(str(Q), 160)))
# Release cached GPU memory after sampling (only meaningful when using CUDA).
if device.type == "cuda":
    t.cuda.empty_cache()

We can also plot samples from the approximate posterior against the 'ground truth' values for a sample of points:

In [None]:
# Draw `samples` posterior samples for each of (up to) `n` test points and plot
# them against the true parameter values.
n, samples = 1000, 1000
test_xs, test_ys = test_loader.dataset[:n]
test_xs = t.from_numpy(test_xs) if isinstance(test_xs, np.ndarray) else test_xs
test_ys = t.from_numpy(test_ys) if isinstance(test_ys, np.ndarray) else test_ys
test_xs, test_ys = Q.preprocess(test_xs, test_ys)
# The test split may hold fewer than `n` points, and the number of free
# parameters depends on the configuration, so read both from the data
# instead of hard-coding them.
n, n_params = test_ys.shape[0], test_ys.shape[-1]

with t.inference_mode():
    Q.eval()
    _ = Q(test_xs, True)
    test_y_hat = Q.likelihood._gmm_from_params(Q.last_params).sample((samples,)).reshape(-1, n_params)
    # Repeat each true parameter vector once per drawn sample so rows align
    # with `test_y_hat`.
    plot_ys = test_ys[None, :].expand((samples, n, n_params)).reshape(-1, n_params)

plot_posteriors(test_y_hat.cpu().numpy(), plot_ys.cpu().numpy(),
                labels=fp.ordered_free_params,
                title='Posterior samples for simulated test points',
                description=f'{samples} samples drawn for {n} test data points, plotted against the true values.')

## Maximum Likelihood training of neural likelihood

We can now repeat a similar procedure to train the neural likelihood. There are a couple of differences:
- we must remember to swap the dimensions of the inputs and outputs during preprocessing

  To help with this, `san.PModel` (or `san.PModelv2` for `SANv2`) implements the required preprocessing steps.
  
- we can configure the network to be smaller since the likelihood is a simpler distribution to approximate.

In [None]:
# Neural likelihood P_w(x | θ). PModelv2's preprocessing swaps the inputs and
# outputs relative to the posterior model.
slp = cfg.SANv2LikelihoodParams()
P = san.PModelv2(slp)
# Mirror the device placement done for Q, so that P and Q are on the same
# device when they are used together in the HMC update.
P.device = device
P = P.to(device)
ip.ident = "ML_likelihood"

In [None]:
# Maximum-likelihood training of the neural likelihood P on the same training
# split. `ip.ident` was set to "ML_likelihood" beforehand; presumably this names
# the saved outputs of this run — TODO confirm.
P.offline_train(train_loader, ip)

### Evaluating the Neural Likelihood

To check that the training was reasonably successful, we can plot sampled points against the 'ground truths':

In [None]:
n = 1000
test_xs, test_ys = test_loader.dataset[:n]
test_xs = t.from_numpy(test_xs) if isinstance(test_xs, np.ndarray) else test_xs
test_ys = t.from_numpy(test_ys) if isinstance(test_ys, np.ndarray) else test_ys
# PModelv2's preprocessing swaps inputs and outputs, so P predicts `test_ys`
# (the photometry) from `test_xs` (the parameters).
test_xs, test_ys = P.preprocess(test_xs, test_ys)

with t.inference_mode():
    # Put P in evaluation mode (as done for Q) before running inference.
    P.eval()
    _ = P(test_xs, True)
    # Use the mean of P's predicted mixture as a point prediction.
    test_y_hat = P.likelihood._gmm_from_params(P.last_params).mean
    plot_ys = test_ys

test_y_hat, plot_ys = map(dcn, (test_y_hat, plot_ys))
# The test split may hold fewer than `n` points; label one panel per output
# dimension that is actually being plotted.
n, out_dim = plot_ys.shape[0], plot_ys.shape[-1]
plot_posteriors(test_y_hat, plot_ys,
                labels=list(range(out_dim)), lims=False,
                title='$P_{w}(x \\vert \\theta)$ for simulated test points',
                description=f'Expected (normalised) flux values for {n} test data points, plotted against the true values.')

# HMC Update Procedure

Here we run the HMC update procedure on the weights of the approximate posterior, with examples using both real and simulated data.

## HMC update procedure with real data

For real surveys, we will want to run the HMC update procedure on real data from (a subset of) a survey.

We begin by recreating some data loaders, using the HMC update batch size (to allow us to control memory usage).

In [None]:
# Train/test loaders over the real catalogue, batched with the HMC update batch
# size to bound memory use. Photometry goes through `np.log` (`x_transforms`),
# and `t.from_numpy` is applied via `transforms` (exact application order is
# defined by `load_real_data` — TODO confirm).
real_train_loader, real_test_loader = load_real_data(
    path=ip.catalogue_loc, filters=fp.filters, split_ratio=ip.split_ratio,
    batch_size=ip.hmc_update_batch_size, 
    transforms=[t.from_numpy], x_transforms=[np.log],
)

In [None]:
# HMC update of Q's weights on the real-data training split, using the trained
# neural likelihood P. Switching `ip.ident` presumably keeps this run's outputs
# separate from the ML-trained model's (TODO confirm).
ip.ident = ip.hmc_update_real_ident
Q.hmc_retrain_procedure(real_train_loader, ip, P=P, epochs=ip.hmc_update_real_epochs,
                        K=ip.hmc_update_real_K, lr=3e-4, decay=1e-4)
logging.info('Updated on real data')

## HMC update procedure on simulated data

Alternatively, we can run the update procedure on the simulated data, which will allow us to create evaluation plots against the 'ground truth' values:

In [None]:
# Loaders over the simulated dataset, batched with the HMC update batch size.
# The photometry transforms must match those used to train Q
# (`[np.log, t.from_numpy]`): the previous order `[t.from_numpy, np.log]`
# applied `np.log` to a torch tensor instead of the raw NumPy array.
hmc_train_loader, hmc_test_loader = load_simulated_data(
    path=ip.dataset_loc,
    split_ratio=ip.split_ratio,
    batch_size=ip.hmc_update_batch_size,
    phot_transforms=[np.log, t.from_numpy],
    theta_transforms=[get_norm_theta(fp)],
)
logging.info('Created data loaders')

In [None]:
# HMC update of Q's weights on the simulated training split, using P. This
# continues from Q's current state, i.e. after the real-data update above if
# that cell was run.
ip.ident = ip.hmc_update_sim_ident
Q.hmc_retrain_procedure(hmc_train_loader, ip, P=P, epochs=ip.hmc_update_sim_epochs, 
                        K=ip.hmc_update_sim_K, lr=3e-4, decay=1e-4)

# Evaluations

In [None]:
# We re-use this prospector instance to make plotting a little faster later
dummy_obs = spt.load_observation()
p = spt.Prospector(dummy_obs)

## Set up a baseline for comparison

In [None]:
# Train a fresh SANv2 with no HMC update, to serve as a baseline for the
# HMC-updated Q. Re-creating `ip` discards the `ident` changes made above.
ip = cfg.InferenceParams()
mp = cfg.SANv2Params()
Q_base = san.SANv2(mp)
# Place the baseline on the notebook's device *before* training, matching
# the order used for Q.
Q_base.device = device
Q_base = Q_base.to(device)
Q_base.offline_train(train_loader, ip)

In [None]:
# Record the target device on the baseline model and move it there, as was done for Q.
Q_base.device = device
Q_base = Q_base.to(device)

In [None]:
Q.eval()
Q_base.eval()

ds_idx, idx, (sim_xs, sim_ys) = new_sim_sample()
logging.info(f'ds_idx: {ds_idx}')

sim_xs, sim_ys = Q.preprocess(sim_xs, sim_ys)
with t.inference_mode():
    san_mode = Q.mode(sim_xs, 10000)
    base_mode = Q_base.mode(sim_xs, 10000)
# De-normalise the modes (and the true parameters) for the forward model.
san_mode = dt(san_mode.cpu().squeeze().numpy())
base_mode = dt(base_mode.cpu().squeeze().numpy())
true_ys = dt(sim_ys.cpu().squeeze().numpy())

# The photometry was log-transformed when loaded, so exponentiate it back
# before building the Prospector observation for this simulated point.
phot_obs = np.exp(sim_xs.squeeze().cpu().numpy())
obs = spt.load_photometry.sim_observation(fp.filters, phot_obs, index=idx, dset=ip.dataset_loc)
p.set_new_obs(obs)

# Set to True to also overlay the forward model at the true parameters.
show_truth = False
thetas = [san_mode, base_mode]
theta_labels = ["SAN (HMC update)", "Baseline SAN"]
if show_truth:
    thetas.append(true_ys)
    theta_labels.append("True")

p.visualise_model(theta=thetas, theta_labels=theta_labels,
                  show=True, save=False, title=f'Forward Model Predictions (simulated point, index {ds_idx})')

In [None]:
# Load the simulated dataset (log photometry, normalised parameters) and keep
# the first 10,000 points for evaluation.
sim_ds = spt.load_photometry.InMemoryObsDataset(
    ip.dataset_loc,
    phot_transforms=[lambda x: t.from_numpy(np.log(x))],
    theta_transforms=[get_norm_theta(fp)])


def _ensure_tensor(arr):
    """Wrap NumPy arrays as tensors; pass anything else through unchanged."""
    return t.from_numpy(arr) if isinstance(arr, np.ndarray) else arr


first_xs = sim_ds.get_xs()[:10000]
first_ys = sim_ds.get_ys()[:10000]
sim_xs, sim_ys = Q.preprocess(_ensure_tensor(first_xs), _ensure_tensor(first_ys))

In [None]:
# Draw posterior samples for every evaluation point. The point and parameter
# counts are read from `sim_ys` rather than hard-coded, so this still works if
# the dataset holds fewer than 10,000 simulations or the number of free
# parameters changes.
n_pp_samples = 100
n_pp_points, n_pp_params = sim_ys.shape[0], sim_ys.shape[-1]
with t.inference_mode():
    Q.eval()
    _ = Q(sim_xs, True)
    sim_y_hat = Q.likelihood._gmm_from_params(Q.last_params).sample((n_pp_samples,)).reshape(-1, n_pp_params)
    # Repeat each true parameter vector once per drawn sample so rows align.
    plot_ys = sim_ys[None, :].expand((n_pp_samples, n_pp_points, n_pp_params)).reshape(-1, n_pp_params)

In [None]:
# Probability-probability plot of the posterior samples against the true
# simulated parameters, labelled by free parameter.
ppplot(sim_y_hat.cpu().numpy(), plot_ys.cpu().numpy(),
       labels=fp.ordered_free_params,
       title='Probability-Probability plot',
       description='"True" Simulated CDF vs Prediction CDF')