# Variations on the 'Sequential Autoregressive Network'

This notebook explores different architectural and likelihood choices for the sequential autoregressive network (SAN) model.

Variations include:

- Removing the reparametrised sampling for the marginal likelihoods
- Providing the conditioning data $\mathbf{x}$ directly to the inputs of the sequential blocks.
- Using mixture distributions for the marginals

The diagram below illustrates this new architecture, with Gaussian mixtures for $p(y_{d} \vert \mathbf{y}_{<d}, \mathbf{x})$.

<img src="https://share.maximerobeyns.com/san.svg" width="80%" />

In [None]:
import os
import time
import logging
import numpy as np
import torch as t

import agnfinder
import agnfinder.inference.san as san

from typing import Type, Any
from torchvision import transforms

from agnfinder import config as cfg
from agnfinder import nbutils as nbu
from agnfinder.types import ConfigClass, column_order
from agnfinder.inference import SAN
from agnfinder.inference.utils import load_simulated_data, normalise_phot_np

# One-time setup: the first run of this cell configures logging, changes the
# working directory and picks dtype/device. Re-running the cell in the same
# kernel is a no-op because `_SETUP` is then already defined.
# An explicit membership test is used instead of `assert(_SETUP)`: asserts are
# stripped under `python -O`, which would silently skip the whole setup.
if '_SETUP' not in globals():
    cfg.configure_logging()
    # Move to the parent of the agnfinder package directory, presumably so that
    # relative paths such as './data/...' used below resolve — confirm.
    os.chdir(os.path.split(agnfinder.__path__[0])[0])
    dtype = t.float32
    device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")
    if device == t.device("cuda"):
        print('Using GPU for training')
        # !nvidia-smi
    else:
        print("CUDA is unavailable; training on CPU.")
    _SETUP = True

In [None]:
class InferenceParams(ConfigClass):
    # Training / data-loading settings shared by every model in this notebook.
    epochs: int = 20  # number of training epochs, passed to trainmodel in train_san
    batch_size: int = 1024  # NOTE(review): the loaders hard-code their own batch sizes (1024 / 10000); this field is only reported in plot descriptions — confirm
    split_ratio: float = 0.9  # passed to load_simulated_data as split_ratio (presumably the train fraction — confirm)
    dtype: t.dtype = dtype  # defaults to the dtype chosen in the setup cell
    device: t.device = device  # defaults to the device chosen in the setup cell
    logging_frequency: int = 10000  # passed to trainmodel in train_san
    dataset_loc: str = './data/cubes/40M_shuffled.hdf5'  # relative to the working directory set in the setup cell
    retrain_model: bool = False  # Don't re-train an identical (existing) model
    overwrite_results: bool = True  # If we do re-train an identical model, save it

ip = InferenceParams()

In [None]:
# Loaders with a batch size of 1024; the training loader is the one
# train_san trains on.
train_loader_1024, test_loader_1024 = load_simulated_data(
    batch_size=1024,
    path=ip.dataset_loc,
    normalise_phot=normalise_phot_np,
    split_ratio=ip.split_ratio,
    transforms=[transforms.ToTensor()],
)

In [None]:
# A second pair of loaders with much larger batches (10000), used below to
# draw many test observations at once.
train_loader_10000, test_loader_10000 = load_simulated_data(
    batch_size=10000,
    path=ip.dataset_loc,
    normalise_phot=normalise_phot_np,
    split_ratio=ip.split_ratio,
    transforms=[transforms.ToTensor()],
)
logging.info('Data loading complete.')

In [None]:
def train_san(sp: ConfigClass, ip: InferenceParams,
              train_loader: Any = None) -> SAN:
    """Load a saved SAN matching ``sp``, or train (and save) a new one.

    Args:
        sp: Model configuration. Must provide ``cond_dim``, ``data_dim``,
            ``module_shape``, ``sequence_features``, ``likelihood``,
            ``likelihood_kwargs`` and ``batch_norm``.
        ip: Inference parameters: ``device``/``dtype``, ``epochs``,
            ``logging_frequency`` and the ``retrain_model`` /
            ``overwrite_results`` flags.
        train_loader: Loader to train on. Defaults to the notebook-level
            ``train_loader_1024``, looked up when the function is called.

    Returns:
        The loaded or newly trained model, moved to CUDA when ``ip.device``
        is the CUDA device.
    """
    # Named `model` (not `san`) so it does not shadow the imported `san` module.
    model = SAN(cond_dim=sp.cond_dim, data_dim=sp.data_dim,
                module_shape=sp.module_shape,
                sequence_features=sp.sequence_features,
                likelihood=sp.likelihood,
                likelihood_kwargs=sp.likelihood_kwargs,
                batch_norm=sp.batch_norm, device=ip.device,
                dtype=ip.dtype)
    use_cuda = ip.device == t.device('cuda')

    savepath: str = model.fpath()
    if not ip.retrain_model:
        try:
            logging.info(f'Attempting to load {model.likelihood.name()} SAN model from {savepath}')
            # map_location lets a checkpoint written on a GPU be loaded on a
            # CPU-only machine instead of raising during deserialisation.
            # NOTE(review): from torch 2.6 t.load defaults to weights_only=True,
            # which cannot unpickle a whole module; such a failure is logged
            # below and the model is retrained.
            loaded = t.load(savepath, map_location=ip.device).to(ip.device, ip.dtype)
            logging.info('Successfully loaded')
            return loaded.cuda() if use_cuda else loaded
        except Exception as err:
            # Best-effort cache: any load failure falls back to training. Unlike
            # a bare `except:`, this lets KeyboardInterrupt through and logs the
            # actual reason instead of always claiming the file is missing.
            logging.info(f'Could not load model {savepath} ({err!r}); training...')

    if train_loader is None:
        train_loader = train_loader_1024
    model.trainmodel(train_loader, ip.epochs, ip.logging_frequency)
    logging.info(f'Trained {model.likelihood.name()} SAN model')

    # Re-query fpath() after training rather than reusing `savepath`, in case
    # the path depends on model state.
    outpath = model.fpath()
    # Honour ip.overwrite_results: never clobber an existing checkpoint unless
    # asked to.
    if ip.overwrite_results or not os.path.exists(outpath):
        t.save(model, outpath)
        logging.info(f'Saved {model.likelihood.name()} SAN model as: {outpath}')
    else:
        logging.info(f'Not overwriting existing model at {outpath} (overwrite_results is False)')
    return model.cuda() if use_cuda else model

## SAN with Mixture of Gaussians

In [None]:
class MoGSANParams(ConfigClass):
    # Configuration for a SAN whose per-dimension likelihoods
    # p(y_d | y_<d, x) are Gaussian mixtures (san.MoG).
    cond_dim: int = 8  # dimensions of conditioning info (e.g. photometry)
    data_dim: int = 9  # dimensions of data of interest (e.g. physical params)
    # TODO vary the width and depth
    module_shape: list[int] = [32, 64, 32]  # shape of the network 'modules'
    sequence_features: int = 8  # features passed between sequential blocks
    # Passed to SAN as likelihood_kwargs; 'K' is presumably the number of mixture
    # components — confirm in san.MoG. NOTE(review): this dict is a class attribute,
    # so it is shared by every instance; don't mutate it in place.
    likelihood_kwargs: dict[str, Any] = {'K': 10}
    likelihood: Type[san.SAN_Likelihood] = san.MoG
    # NOTE(review): train_san does not read lparams, and 2 looks stale for a
    # K-component mixture — confirm whether anything still uses it.
    lparams: int = 2  # number of parameters for likelihood p(y_d | y_<d, x)
    batch_norm: bool = True  # use batch normalisation in network?

mgsp = MoGSANParams() 

In [None]:
# Train the Gaussian-mixture SAN, or load a previously saved copy of it.
mogsan = train_san(sp=mgsp, ip=ip)

In [None]:
# Draw posterior samples from the Gaussian-mixture SAN for a test observation
# and show them in a corner plot alongside the true parameters.
xs, true_ys = nbu.new_sample(test_loader_1024)
xs = xs.to(device, dtype)
true_ys = true_ys.to(dtype=dtype)

n_samples = 10000
with t.inference_mode():
    # perf_counter is monotonic and high-resolution, so it suits interval
    # timing better than the wall clock (time.time). The .cpu() copy is inside
    # the timed region, so the timing includes waiting for sampling to finish.
    start = time.perf_counter()
    samples = mogsan.sample(xs, n_samples=n_samples).cpu()
    sampling_time = (time.perf_counter() - start) * 1e3
logging.info(f'Finished drawing {n_samples:,} samples in {sampling_time:.4f}ms.')
logging.info('Plotting results...')

description = f'{mogsan} trained for {ip.epochs} epochs (batch size {ip.batch_size})'

# Every parameter is plotted on [0, 1]; presumably the targets are normalised
# to that range — confirm against the data pipeline.
lims = np.array([[0.,1.]]).repeat(len(column_order),0)
nbu.plot_corner(samples=samples.numpy(), true_params=true_ys.cpu().numpy(),
                lims=lims, labels=column_order, title='Gaussian Mixture "Sequential Autoregressive Network"',
                description=description)

In [None]:
# Draw N samples per posterior for a batch of test observations (presumably
# 8000 of them — confirm nbu.new_sample's second argument) and plot them
# against the true parameters.
xs, true_ys = nbu.new_sample(test_loader_10000, 8000)
xs = xs.to(device, dtype)
true_ys = true_ys.to(device, dtype)

# samples per posterior
N = 1000

with t.inference_mode():
    # Only the preprocessed xs are kept; the placeholder ys are discarded.
    xs, _ = mogsan.preprocess(xs, t.empty(xs.shape))
    # Each observation is repeated N times, so the output holds N rows
    # (samples) per observation.
    samples = mogsan.forward(xs.repeat_interleave(N, 0))

true_ys = true_ys.repeat_interleave(N, 0).cpu().numpy()
# Build the description from `ip` (as the corner-plot cell does) instead of
# hard-coding an epoch count that disagrees with ip.epochs.
nbu.plot_posteriors(samples.cpu().numpy(), true_ys, title="Gaussian Mixture 'Sequential Autoregressive Network'",
                    description=f'{mogsan} trained for {ip.epochs} epochs (batch size {ip.batch_size})')

## Gaussian Baseline (without reparametrised sampling)

In [None]:
class SANParams(ConfigClass):
    # Baseline configuration: a smaller SAN whose per-dimension likelihoods
    # p(y_d | y_<d, x) are single Gaussians (san.Gaussian).
    cond_dim: int = 8  # dimensions of conditioning info (e.g. photometry)
    data_dim: int = 9  # dimensions of data of interest (e.g. physical params)
    module_shape: list[int] = [16, 32]  # shape of the network 'modules'
    sequence_features: int = 4  # features passed between sequential blocks
    likelihood: Type[san.SAN_Likelihood] = san.Gaussian
    # Passed to SAN as likelihood_kwargs=None.
    # NOTE(review): unlike the sibling configs this field has no type annotation;
    # if ConfigClass discovers fields via __annotations__ it may be skipped — confirm.
    likelihood_kwargs = None
    # NOTE(review): train_san does not read lparams; presumably 2 = mean and scale.
    lparams: int = 2  # number of parameters for likelihood p(y_d | y_<d, x)
    batch_norm: bool = True  # use batch normalisation in network?

sp = SANParams() 

In [None]:
# Train the single-Gaussian baseline SAN, or load a previously saved copy of it.
gsan = train_san(sp=sp, ip=ip)

In [None]:
# Draw posterior samples from the Gaussian baseline SAN for a test observation
# and show them in a corner plot alongside the true parameters.
xs, true_ys = nbu.new_sample(test_loader_1024)
xs = xs.to(device, dtype)
true_ys = true_ys.to(device, dtype)

n_samples = 10000
with t.inference_mode():
    # perf_counter is monotonic and high-resolution, so it suits interval
    # timing better than the wall clock (time.time). The .cpu() copy is inside
    # the timed region, so the timing includes waiting for sampling to finish.
    start = time.perf_counter()
    samples = gsan.sample(xs, n_samples=n_samples).cpu()
    sampling_time = (time.perf_counter() - start) * 1e3
logging.info(f'Finished drawing {n_samples:,} samples in {sampling_time:.4f}ms.')
logging.info('Plotting results...')

description = f'{gsan} trained for {ip.epochs} epochs (batch size {ip.batch_size})'

# Every parameter is plotted on [0, 1]; presumably the targets are normalised
# to that range — confirm against the data pipeline.
lims = np.array([[0.,1.]]).repeat(len(column_order),0)
nbu.plot_corner(samples=samples.numpy(), true_params=true_ys.cpu().numpy(),
                lims=lims, labels=column_order, title='Gaussian "Sequential Autoregressive Network"',
                description=description)

In [None]:
# Draw N samples per posterior for a batch of test observations (presumably
# 10000 of them — confirm nbu.new_sample's second argument) and plot them
# against the true parameters.
xs, true_ys = nbu.new_sample(test_loader_10000, 10000)
xs = xs.to(device, dtype)
true_ys = true_ys.to(device, dtype)

# samples per posterior
N = 1000

with t.inference_mode():
    # Only the preprocessed xs are kept; the placeholder ys are discarded.
    xs, _ = gsan.preprocess(xs, t.empty(xs.shape))
    # Each observation is repeated N times, so the output holds N rows
    # (samples) per observation.
    samples = gsan.forward(xs.repeat_interleave(N, 0))

true_ys = true_ys.repeat_interleave(N, 0).cpu().numpy()
# Build the description from `ip` (as the corner-plot cell does) instead of
# hard-coding an epoch count that disagrees with ip.epochs.
nbu.plot_posteriors(samples.cpu().numpy(), true_ys, title="Gaussian 'Sequential Autoregressive Network'",
                    description=f'{gsan} trained for {ip.epochs} epochs (batch size {ip.batch_size})')

## Mixture of StudentT Distributions

In [None]:
class MoSTSANParams(ConfigClass):
    # Configuration for a SAN whose per-dimension likelihoods
    # p(y_d | y_<d, x) are mixtures of Student-t distributions (san.MoST).
    # NOTE(review): unlike the sibling configs there is no `lparams` field;
    # train_san does not read it, but confirm nothing else does.
    cond_dim: int = 8  # dimensions of conditioning info (e.g. photometry)
    data_dim: int = 9  # dimensions of data of interest (e.g. physical params)
    module_shape: list[int] = [16, 32]  # shape of the network 'modules'
    sequence_features: int = 4  # features passed between sequential blocks
    # Passed to SAN as likelihood_kwargs; 'K' is presumably the number of mixture
    # components — confirm in san.MoST. Shared class-level dict: don't mutate it.
    likelihood_kwargs: dict[str, Any] = {'K': 5}
    likelihood: Type[san.SAN_Likelihood] = san.MoST
    batch_norm: bool = True  # use batch normalisation in network?

mostsp = MoSTSANParams() 

In [None]:
# Train the Student-t-mixture SAN, or load a previously saved copy of it.
mostsan = train_san(sp=mostsp, ip=ip)

In [None]:
# Draw posterior samples from the Student-t-mixture SAN for a test observation
# and show them in a corner plot alongside the true parameters.
xs, true_ys = nbu.new_sample(test_loader_1024)
xs = xs.to(device, dtype)
true_ys = true_ys.to(device, dtype)

n_samples = 10000
with t.inference_mode():
    # perf_counter is monotonic and high-resolution, so it suits interval
    # timing better than the wall clock (time.time). The .cpu() copy is inside
    # the timed region, so the timing includes waiting for sampling to finish.
    start = time.perf_counter()
    samples = mostsan.sample(xs, n_samples=n_samples).cpu()
    sampling_time = (time.perf_counter() - start) * 1e3
logging.info(f'Finished drawing {n_samples:,} samples in {sampling_time:.4f}ms.')
logging.info('Plotting results...')

description = f'{mostsan} trained for {ip.epochs} epochs (batch size {ip.batch_size})'

# Every parameter is plotted on [0, 1]; presumably the targets are normalised
# to that range — confirm against the data pipeline.
lims = np.array([[0.,1.]]).repeat(len(column_order),0)
nbu.plot_corner(samples=samples.numpy(), true_params=true_ys.cpu().numpy(),
                lims=lims, labels=column_order, title='StudentT Mixture "Sequential Autoregressive Network"',
                description=description)

In [None]:
# Draw N samples per posterior for a batch of test observations (presumably
# 1000 of them — confirm nbu.new_sample's second argument) and plot them
# against the true parameters.
xs, true_ys = nbu.new_sample(test_loader_1024, 1000)
xs = xs.to(device, dtype)
true_ys = true_ys.to(device, dtype)

# samples per posterior
N = 100

with t.inference_mode():
    # Only the preprocessed xs are kept; the placeholder ys are discarded.
    xs, _ = mostsan.preprocess(xs, t.empty(xs.shape))
    # Each observation is repeated N times, so the output holds N rows
    # (samples) per observation.
    samples = mostsan.forward(xs.repeat_interleave(N, 0))

true_ys = true_ys.repeat_interleave(N, 0).cpu().numpy()
# Build the description from `ip` (as the corner-plot cell does) instead of
# hard-coding an epoch count that disagrees with ip.epochs.
nbu.plot_posteriors(samples.cpu().numpy(), true_ys, title="StudentT Mixture 'Sequential Autoregressive Network'",
                    description=f'{mostsan} trained for {ip.epochs} epochs (batch size {ip.batch_size})')