# Does Sampling Order Matter?

In autoregressive models, we generate each dimension of the multivariate posterior distribution sequentially, with each new dimension depending on previous samples.

We might therefore suspect that the order in which we perform this sampling matters. Is this the case with the SAN model?

<img src="https://maximerobeyns.github.io/agnfinder/_images/san.svg" width="80%" />

There are good reasons why the sampling order might not matter: unlike usual autoregressive models, this one has 'sequence features' passed along the dimensions, in addition to the samples themselves.

In [None]:
import os
import time
import logging
import numpy as np
import torch as t
import matplotlib.pyplot as plt

from matplotlib.colors import to_rgb, to_rgba
from typing import Type, Any
from torchvision import transforms

# TODO clean up these imports!
import agnfinder
import agnfinder.inference.san as san
import agnfinder.inference.inference as inf

from agnfinder import config as cfg
from agnfinder import nbutils as nbu
from agnfinder.types import ConfigClass, column_order
from agnfinder.inference import SAN
from agnfinder.inference.utils import load_simulated_data, normalise_phot_np

# One-time notebook setup. Re-running this cell is a no-op once _SETUP exists.
# The guard is an explicit namespace lookup rather than `assert(_SETUP)`:
# assert statements are removed under `python -O`, which would make the
# NameError never fire and silently skip the setup below.
if "_SETUP" not in globals():
    cfg.configure_logging()
    # Work from the directory that contains the agnfinder package directory,
    # so that relative paths used later in the notebook resolve from there.
    os.chdir(os.path.split(agnfinder.__path__[0])[0])
    dtype = t.float32
    device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")
    if device.type == "cuda":
        print('Using GPU for training')
        # !nvidia-smi
    else:
        print("CUDA is unavailable; training on CPU.")
    _SETUP = True

In [None]:
# Inference settings for this notebook, overriding selected fields of
# inf.InferenceParams. Fields not listed here (e.g. split_ratio, which is read
# when the data loaders are built) are presumably inherited from the parent
# class -- TODO confirm.
class InferenceParams(inf.InferenceParams):
    model: inf.model_t = san.SAN  # model class used for inference
    logging_frequency: int = 10000  # presumably iterations between log messages -- TODO confirm
    # Simulated photometry file; a relative path, resolved against the current working directory.
    dataset_loc: str = './data/cubes/des_sample/photometry_simulation_4000000n_z_0p0000_to_1p0000.hdf5'
    retrain_model: bool = False  # Don't re-train an identical (existing) model
    overwrite_results: bool = False  # If we do re-train an identical model, save it uniquely

# Instantiate the inference settings and the free-parameter configuration.
ip = InferenceParams()
fp = cfg.FreeParams()

In [None]:
# Train/test loaders built with batch_size=1024; the training loader is the
# one passed to trainmodel, the test loader feeds the corner plots.
train_loader_1024, test_loader_1024 = load_simulated_data(
    batch_size=1024,
    path=ip.dataset_loc,
    split_ratio=ip.split_ratio,
    transforms=[transforms.ToTensor()],
    normalise_phot=normalise_phot_np,
)

In [None]:
# Second pair of loaders over the same dataset and split ratio, built with
# batch_size=10000; the test loader feeds the posterior plots.
train_loader_10000, test_loader_10000 = load_simulated_data(
    batch_size=10000,
    path=ip.dataset_loc,
    split_ratio=ip.split_ratio,
    transforms=[transforms.ToTensor()],
    normalise_phot=normalise_phot_np,
)
logging.info('Data loading complete.')

In [None]:
# Hyperparameters for a SAN whose likelihood is san.MoG (by its name, a
# mixture of Gaussians). Overrides fields of san.SANParams.
class MoGSANParams(san.SANParams):
    epochs: int = 20  # presumably the number of training epochs -- TODO confirm
    batch_size: int = 1024  # same value as the batch_size of train_loader_1024
    dtype: t.dtype = t.float32  # same dtype as the notebook-level `dtype`

    cond_dim: int = 7  # dimensions of conditioning info (e.g. photometry)
    data_dim: int = 9  # dimensions of data of interest (e.g. physical params)
    module_shape: list[int] = [512, 512]  # shape of the network 'modules'
    sequence_features: int = 8  # features passed between sequential blocks
    likelihood: Type[san.SAN_Likelihood] = san.MoG  # likelihood class for the model
    likelihood_kwargs: dict[str, Any] = {'K': 10}  # presumably K = number of mixture components -- TODO confirm
    batch_norm: bool = True  # use batch normalisation in network?

# Parameter instance passed to san.SAN in the training cells below.
mgsp = MoGSANParams()

## Order 1

This is the same column order that we have been using to train the models so far.

In [None]:
# Build a SAN from the MoG hyperparameters and train it on the
# batch-size-1024 training loader, using the inference settings in `ip`.
mogsan = san.SAN(mgsp)
mogsan.trainmodel(train_loader_1024, ip)

In [None]:
# Sample the trained SAN's posterior for test photometry and show the draws
# as a corner plot against the true parameter values.
# NOTE(review): nbu.new_sample presumably returns (photometry, true params)
# taken from the test loader -- confirm against nbutils.
xs, true_ys = nbu.new_sample(test_loader_1024)
xs = xs.to(device, dtype)
true_ys = true_ys.to(dtype=dtype)

n_samples = 10000
with t.inference_mode():
    # perf_counter is a monotonic, high-resolution clock intended for
    # measuring intervals; time.time() can jump if the system clock is
    # adjusted mid-measurement.
    start = time.perf_counter()
    # The .cpu() copy sits inside the timed region, so the reported time
    # includes the transfer of the samples back to host memory.
    samples = mogsan.sample(xs, n_samples=n_samples).cpu()
    sampling_time = (time.perf_counter() - start) * 1e3  # seconds -> milliseconds
logging.info(f'Finished drawing {n_samples:,} samples in {sampling_time:.4f}ms.')
logging.info('Plotting results...')

description = f'Column order: {agnfinder.types.column_order}'
# One [0, 1] axis range per column.
lims = np.array([[0., 1.]]).repeat(len(column_order), 0)
nbu.plot_corner(samples=samples.numpy(), true_params=true_ys.cpu().numpy(),
                lims=lims, labels=column_order, title='SAN: order 1',
                description=description)

In [None]:
# Posterior plot for order 1: take test data from the batch-size-10000
# loader, sample N times per input, and compare against the true values.
# NOTE(review): the second argument to nbu.new_sample is presumably the
# number of examples requested -- confirm against nbutils.
xs, true_ys = nbu.new_sample(test_loader_10000, 10000)
xs, true_ys = xs.to(device, dtype), true_ys.to(device, dtype)

# samples per posterior
N = 100

with t.inference_mode():
    # preprocess takes an (x, y) pair; an uninitialised tensor stands in for
    # y, and the corresponding output is discarded.
    xs, _ = mogsan.preprocess(xs, t.empty(xs.shape))
    # Repeat each input row N times along dim 0 before the forward pass.
    samples = mogsan.forward(xs.repeat_interleave(N, dim=0))
logging.info('Finished sampling. Plotting')

description = f'Column order: {agnfinder.types.column_order}'
# Repeat the true values in the same way so rows line up with `samples`.
true_ys = true_ys.repeat_interleave(N, dim=0).cpu().numpy()
nbu.plot_posteriors(
    samples.cpu().numpy(),
    true_ys,
    title="Order 1",
    description=description,
)

## Order 2

We now order the columns according to how easily parameters appear to be constrained, beginning with the easiest.

In [None]:
# Build and train a SAN for the second column ordering.
# NOTE(review): this cell is identical to the order-1 training cell, and
# ip.retrain_model is False ("Don't re-train an identical (existing) model").
# Nothing visible in this notebook changes the column order, so confirm that
# the order was actually changed (and that a stale model is not reused)
# before interpreting the results below.
mogsan = san.SAN(mgsp)
mogsan.trainmodel(train_loader_1024, ip)

In [None]:
# Sample the re-trained SAN's posterior for test photometry and show the
# draws as a corner plot against the true parameter values.
# NOTE(review): nbu.new_sample presumably returns (photometry, true params)
# taken from the test loader -- confirm against nbutils.
xs, true_ys = nbu.new_sample(test_loader_1024)
xs = xs.to(device, dtype)
true_ys = true_ys.to(dtype=dtype)

n_samples = 10000
with t.inference_mode():
    # perf_counter is a monotonic, high-resolution clock intended for
    # measuring intervals; time.time() can jump if the system clock is
    # adjusted mid-measurement.
    start = time.perf_counter()
    # The .cpu() copy sits inside the timed region, so the reported time
    # includes the transfer of the samples back to host memory.
    samples = mogsan.sample(xs, n_samples=n_samples).cpu()
    sampling_time = (time.perf_counter() - start) * 1e3  # seconds -> milliseconds
logging.info(f'Finished drawing {n_samples:,} samples in {sampling_time:.4f}ms.')
logging.info('Plotting results...')

description = f'Column order: {agnfinder.types.column_order}'
# One [0, 1] axis range per column.
lims = np.array([[0., 1.]]).repeat(len(column_order), 0)
nbu.plot_corner(samples=samples.numpy(), true_params=true_ys.cpu().numpy(),
                lims=lims, labels=column_order, title='SAN: order 2',
                description=description)

In [None]:
# Posterior plot for order 2: take test data from the batch-size-10000
# loader, sample N times per input, and compare against the true values.
# NOTE(review): the second argument to nbu.new_sample is presumably the
# number of examples requested -- confirm against nbutils.
xs, true_ys = nbu.new_sample(test_loader_10000, 10000)
xs, true_ys = xs.to(device, dtype), true_ys.to(device, dtype)

# samples per posterior
N = 100

with t.inference_mode():
    # preprocess takes an (x, y) pair; an uninitialised tensor stands in for
    # y, and the corresponding output is discarded.
    xs, _ = mogsan.preprocess(xs, t.empty(xs.shape))
    # Repeat each input row N times along dim 0 before the forward pass.
    samples = mogsan.forward(xs.repeat_interleave(N, dim=0))
logging.info('Finished sampling. Plotting')

description = f'Column order: {agnfinder.types.column_order}'
# Repeat the true values in the same way so rows line up with `samples`.
true_ys = true_ys.repeat_interleave(N, dim=0).cpu().numpy()
nbu.plot_posteriors(
    samples.cpu().numpy(),
    true_ys,
    title="Order 2",
    description=description,
)