# Sequential Autoregressive Network results

This notebook contains variations on the first 'SAN' models. In particular, we use a larger (wider, but not deeper) network, which gives better negative log-likelihood (NLL) scores during training, as well as seemingly improved predictive performance.

The diagram below illustrates the architecture of the _Sequential Autoregressive Network_, with Gaussian mixtures for $p(y_{d} \vert \mathbf{y}_{<d}, \mathbf{x})$.

<img src="https://share.maximerobeyns.com/san.svg" width="80%" />

In [None]:
import os
import time
import logging
import numpy as np
import torch as t
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgb, to_rgba

import agnfinder
import agnfinder.inference.san as san

from typing import Type, Any
from torchvision import transforms

from agnfinder import config as cfg
from agnfinder import nbutils as nbu
from agnfinder.types import ConfigClass, column_order
from agnfinder.simulation import Simulator_f
from agnfinder.inference import SAN
from agnfinder.inference.utils import load_simulated_data, normalise_phot_np

# One-time setup: the body runs only the first time this cell is executed in a
# kernel. A bare name lookup is used as the "already set up?" probe instead of
# `assert`, because assert statements are removed under `python -O`, which
# would make the NameError (and therefore the whole setup) never happen.
try:
    _SETUP
except NameError:
    cfg.configure_logging()
    # Move to the parent directory of the installed `agnfinder` package;
    # presumably so that relative paths used later in the notebook resolve.
    os.chdir(os.path.split(agnfinder.__path__[0])[0])
    dtype = t.float32
    device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")
    if device.type == "cuda":
        print('Using GPU for training')
        # !nvidia-smi
    else:
        print("CUDA is unavailable; training on CPU.")
    _SETUP = True

In [None]:
# Settings shared by the data-loading, training and plotting cells below.
class InferenceParams(ConfigClass):
    # Number of training epochs; forwarded to `trainmodel` by `train_san`.
    epochs: int = 20
    # Quoted in the plot descriptions below.
    # NOTE(review): the data-loading cells hard-code their own batch sizes
    # (1024 and 10000) rather than reading this value.
    batch_size: int = 1024
    # Forwarded to `load_simulated_data`; presumably the fraction of the
    # dataset used for training -- confirm against that helper.
    split_ratio: float = 0.9
    # dtype / device chosen in the one-time setup cell; forwarded to the SAN
    # constructor by `train_san`.
    dtype: t.dtype = dtype
    device: t.device = device
    # Forwarded to `trainmodel`; units (batches vs. samples) are not visible
    # here -- confirm.
    logging_frequency: int = 10000
    # Path handed to `load_simulated_data`.
    dataset_loc: str = './data/cubes/40M_shuffled.hdf5'
    # When False, `train_san` first tries to load saved weights from disk.
    retrain_model: bool = False  # Don't re-train an identical (existing) model
    # NOTE(review): `train_san` as written in this file does not read this
    # flag; it always saves after training.
    overwrite_results: bool = True  # If we do re-train an identical model, save it

ip = InferenceParams()
fp = cfg.FreeParams()

In [None]:
# Train/test loaders over the simulated dataset, using batches of 1024.
(train_loader_1024, test_loader_1024) = load_simulated_data(
    batch_size=1024,
    path=ip.dataset_loc,
    split_ratio=ip.split_ratio,
    transforms=[transforms.ToTensor()],
    normalise_phot=normalise_phot_np,
)

In [None]:
# Train/test loaders over the same dataset, using larger batches of 10000.
(train_loader_10000, test_loader_10000) = load_simulated_data(
    batch_size=10000,
    path=ip.dataset_loc,
    split_ratio=ip.split_ratio,
    transforms=[transforms.ToTensor()],
    normalise_phot=normalise_phot_np,
)
logging.info('Data loading complete.')

In [None]:
def train_san(sp: ConfigClass, ip: InferenceParams) -> SAN:
    """Return a SAN configured by `sp`, loaded from disk or freshly trained.

    Unless `ip.retrain_model` is set, the function first tries to restore
    saved weights from the model's own `fpath()`. If that fails for any
    reason, or if retraining was requested, the model is trained on the
    notebook-global `train_loader_1024` and its state dict is saved.

    Args:
        sp: SAN hyperparameters; the fields `cond_dim`, `data_dim`,
            `module_shape`, `sequence_features`, `likelihood`,
            `likelihood_kwargs` and `batch_norm` are read.
        ip: inference parameters; `device`, `dtype`, `retrain_model`,
            `epochs` and `logging_frequency` are read.

    Returns:
        The loaded or trained model, moved to CUDA when `ip.device` is a
        CUDA device.

    NOTE(review): `ip.overwrite_results` is not consulted; a freshly trained
    model is always saved.
    """
    # Named `model` rather than `san` so that the imported `san` module is
    # not shadowed inside this function.
    model = SAN(cond_dim=sp.cond_dim, data_dim=sp.data_dim,
                module_shape=sp.module_shape,
                sequence_features=sp.sequence_features,
                likelihood=sp.likelihood,
                likelihood_kwargs=sp.likelihood_kwargs,
                batch_norm=sp.batch_norm, device=ip.device,
                dtype=ip.dtype)

    # Compare on device *type* so that e.g. 'cuda:0' also counts as CUDA.
    on_cuda = ip.device.type == 'cuda'

    savepath: str = model.fpath()
    if not ip.retrain_model:
        logging.info(f'Attempting to load {model.likelihood.name()} SAN model from {savepath}')
        try:
            # map_location lets a checkpoint saved on GPU load on a CPU-only
            # host (and vice versa) instead of failing and forcing a retrain.
            state = t.load(savepath, map_location=ip.device)
            model.load_state_dict(state)
        except Exception as e:
            # Best effort: any load failure falls back to training, but
            # KeyboardInterrupt / SystemExit are no longer swallowed.
            logging.info(f'Could not load model {savepath}; training...')
            logging.debug(f'Reason for load failure: {e!r}')
        else:
            logging.info('Successfully loaded')
            return model.cuda() if on_cuda else model

    model.trainmodel(train_loader_1024, ip.epochs, ip.logging_frequency)
    logging.info(f'Trained {model.likelihood.name()} SAN model')

    t.save(model.state_dict(), model.fpath())
    logging.info(f'Saved {model.likelihood.name()} SAN model as: {model.fpath()}')
    return model.cuda() if on_cuda else model

## SAN with Mixture of Gaussians

In [None]:
# Hyperparameters for the Gaussian-mixture SAN; `train_san` reads these
# fields and forwards them to the SAN constructor.
class MoGSANParams(ConfigClass):
    cond_dim: int = 8  # dimensions of conditioning info (e.g. photometry)
    data_dim: int = 9  # dimensions of data of interest (e.g. physical params)
    # TODO vary the width and depth
    module_shape: list[int] = [512, 512]  # shape of the network 'modules'
    sequence_features: int = 8  # features passed between sequential blocks
    # 'K' is reported as the number of mixture components in the plot
    # description further down this notebook.
    likelihood_kwargs: dict[str, Any] = {'K': 10}
    # Likelihood class handed to the SAN constructor (mixture of Gaussians).
    likelihood: Type[san.SAN_Likelihood] = san.MoG
    # NOTE(review): `train_san` does not forward this field to the SAN.
    lparams: int = 2  # number of parameters for likelihood p(y_d | y_<d, x)
    batch_norm: bool = True  # use batch normalisation in network?

mgsp = MoGSANParams()

In [None]:
# Load (or train) the Gaussian-mixture SAN used by the cells below.
mogsan = train_san(sp=mgsp, ip=ip)

In [None]:
# Pick a test example, draw posterior samples for it and show a corner plot.
xs, true_ys = nbu.new_sample(test_loader_1024)
true_ys = true_ys.to(dtype=dtype)
xs = xs.to(device, dtype)

n_samples = 10000
with t.inference_mode():
    # The timed region covers both sampling and the copy back to the CPU.
    start = time.time()
    samples = mogsan.sample(xs, n_samples=n_samples)
    samples = samples.cpu()
    sampling_time = 1e3 * (time.time() - start)
logging.info(f'Finished drawing {n_samples:,} samples in {sampling_time:.4f}ms.')
logging.info('Plotting results...')

description = f'{mogsan} trained for {ip.epochs} epochs (batch size {ip.batch_size})'

# One [0, 1] axis range per entry of `column_order`.
lims = np.array([[0., 1.]]).repeat(len(column_order), axis=0)
nbu.plot_corner(
    samples=samples.numpy(),
    true_params=true_ys.cpu().numpy(),
    lims=lims,
    labels=column_order,
    title='Gaussian Mixture "Sequential Autoregressive Network"',
    description=description,
)

In [None]:
# Draw N samples for each of 10000 test examples and plot the posteriors.
xs, true_ys = nbu.new_sample(test_loader_10000, 10000)
xs, true_ys = xs.to(device, dtype), true_ys.to(device, dtype)

# samples per posterior
N = 100

with t.inference_mode():
    # The second argument is only a placeholder; its processed counterpart is
    # discarded.
    xs, _ = mogsan.preprocess(xs, t.empty(xs.shape))
    # Each input row is repeated N times, giving N samples per example.
    samples = mogsan.forward(xs.repeat_interleave(N, dim=0))
logging.info('Finished sampling. Plotting')

# Repeat the true values in the same way so rows line up with `samples`.
true_ys = true_ys.repeat_interleave(N, dim=0).cpu().numpy()
nbu.plot_posteriors(
    samples.cpu().numpy(),
    true_ys,
    title="Sequential Autoregressive Network",
    description=(f'{mogsan} \ntrained for {ip.epochs} epochs, '
                 f'batch size {ip.batch_size} and '
                 f'{mgsp.likelihood_kwargs["K"]} mixture components.'),
)

## Posterior Plots

Here we plot the simulated photometry from posterior parameter samples.

To do this, we sample a single galaxy at random from the testing dataset and draw N ($\approx 100$) samples from the posterior distribution over the $\mathbf{y}$ parameters produced by the SAN (keeping hold of the parameters of this Gaussian mixture distribution). We then pass these samples through the simulator and plot them with an opacity that corresponds to the likelihood of the original $\mathbf{y}$ sample under the GMM. The true $\mathbf{x}$ values (the photometry) picked from the test dataset are plotted in a distinctive colour.

In [None]:
# Take a single test galaxy and draw N posterior samples for it.
xs, ys = nbu.new_sample(test_loader_1024, 1)
xs = xs.to(device, dtype)
xs = xs.unsqueeze(0)  # add a leading dimension

N = 100

with t.inference_mode():
    xs, _ = mogsan.preprocess(xs, t.empty(xs.shape))
    samples = mogsan.forward(xs.repeat_interleave(N, dim=0))
    # Despite the name, `ll` ends up holding likelihoods (exp of log-prob).
    ll = mogsan.likelihood.log_prob(samples, mogsan.last_params)
    ll = ll.exp()

def _scatter(ax: plt.Axes, xs: np.ndarray, ys: np.ndarray, alpha: np.ndarray):
    """Scatter `ys` against `xs` on `ax` with a per-point opacity.

    All points share the base colour '#35b4db'; the i-th point uses
    `alpha[i]` as its alpha channel. The points are labelled
    'posterior samples'.
    """
    base_rgb = to_rgb('#35b4db')
    # One RGBA row per point: the shared RGB plus that point's opacity.
    color = np.array([(*base_rgb, opacity) for opacity in alpha])
    ax.scatter(xs, ys, c=color, label='posterior samples')
    
# Number of parameters per sample, read from the samples themselves instead
# of hard-coding 9 in three separate places.
n_params = samples.shape[-1]

# x positions 0..n_params-1, repeated once for each of the N samples.
plot_xs = t.arange(n_params).expand(N, -1).numpy().flatten()
plot_ys = samples.cpu().numpy().flatten()

# Min-max scale the likelihoods into [0, 1] for use as point opacities.
ll_min = ll.min()
ll_range = ll.max() - ll_min
if ll_range > 0:
    plot_alpha = (ll - ll_min) / ll_range
else:
    # Range is zero (all likelihoods equal) or NaN: scaling would produce
    # 0/0 = NaN opacities, so draw every point fully opaque instead.
    plot_alpha = t.ones_like(ll)
plot_alpha = plot_alpha.cpu().numpy().flatten()

fig, ax = plt.subplots(figsize=(6, 3), dpi=200)
_scatter(ax, plot_xs, plot_ys, plot_alpha)
ax.scatter(np.arange(n_params), ys.cpu().numpy(), c='r', label='true values')
ax.set_ylim(0, 1)
plt.xticks(np.arange(n_params), fp.raw_members, rotation=45)
ax.set_xlabel('Parameters')
ax.set_ylabel('Normalised Values')
ax.set_title('Single Galaxy Parameter Samples')