# Parameter Inference


This notebook contains various experiments and model comparisons relating to inferring physical galaxy parameters $y$ from photometric observations $x$.

## Table of Contents

- [Data Visualisation](#Data-Distributions)
- [Single-variable experiments](#Gaussians-for-Individual-Parameters)
- [Beyond 1 dimension](#Multivariate-Mixture-Distributions)
- ["Sequential Autoregressive Network"](#"Sequential-Autoregressive-Network"-(SAN))
- [Masked Autoencoder for Distribution Estimation](#MADE) (buggy, do not run)

In [None]:
import os
import time
import random
import corner
import logging
import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

import agnfinder
import agnfinder.inference as inference
import agnfinder.inference.made as made

from IPython.display import SVG, display
from typing import Type
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.distributions import Normal, MultivariateNormal, StudentT, Laplace, MixtureSameFamily, Categorical

from agnfinder import config as cfg
from agnfinder import nbutils as nbu
from agnfinder.types import ConfigClass, column_order
from agnfinder.inference import CMADE, SAN
from agnfinder.inference.made import MaskedLinear
from agnfinder.inference.base import CVAE, CVAEParams, cvae_t, arch_t
from agnfinder.inference.utils import Squareplus, squareplus_f, load_simulated_data, GalaxyDataset
from agnfinder.simulation.utils import denormalise_theta, normalise_theta

try: # One-time setup
    assert(_SETUP)
except NameError:
    cfg.configure_logging()
    os.chdir(os.path.split(agnfinder.__path__[0])[0])
    dtype = t.float32
    device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")
    if device == t.device("cuda"):
        print(f'Using GPU for training')
        # !nvidia-smi
    else:
        print("CUDA is unavailable; training on CPU.")
    _SETUP = True

## Data Distributions

Finding the distribution $p_{\theta}(y \vert x)$ is akin to finding a one-to-many mapping $f: \mathcal{X} \times \Theta \to \mathcal{Y}$ from photometric observations to physical galaxy parameters.

Before doing any inference, it might be wise to visualise the empirical distribution of photometric measurements $\mathcal{X}$ and physical galaxy parameters $\mathcal{Y}$ that we have in our (simulated) training dataset. 

This can help us to tell whether the required mapping $f$ is very complicated and non-linear or whether a linear combination of the photometric measurements will do (I suspect it's the former). We can also see whether the data is nicely normalised and further whether it is nicely distributed within its normalised range (e.g. $[0,1]$).

In [None]:
def normalise_phot_np(x: np.ndarray) -> np.ndarray:
    x_log = np.log(x)
    return (x_log - x_log.mean(0)) / x_log.std(0)

In [None]:
dataset_loc: str = './data/cubes/latest_sample/'
gd = GalaxyDataset(path=dataset_loc, transforms=[transforms.ToTensor()])
fp = cfg.FreeParams()

### Photometry distribution, $\mathcal{X}$

Let's first visualise the outputs of the forward model (and the inputs to our mapping).


In [None]:
xs = gd.get_xs().numpy()
small_xs = xs[:1000]
labels_x = [f'dimension {i}' for i in range(len(xs[0]))]
nbu.plot_corner(small_xs, title='Photometry Distribution', 
                description='Un-normalised outputs from the simulation process',
                labels=labels_x)

We can see that the data points span several orders of magnitude, which suggests they are better viewed on a logarithmic scale.

Taking the logarithm of the $x$ points, and applying z-score normalisation:

In [None]:
x_norm = normalise_phot_np(xs[:10000])
nbu.plot_corner(x_norm, title='Normalised Photometric Observations', 
                description='x-samples are passed through a logarithm before applying z-score normalisation', 
                labels=labels_x)
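The transform applied above can be checked in isolation on synthetic data. The sketch below uses log-normal "fluxes" as an assumed stand-in for the simulated photometry. It applies the same log-then-z-score recipe as `normalise_phot_np`, under the hypothetical name `log_zscore`. Afterwards, every band should have zero mean and unit variance. The function recomputes the mean and standard deviation from whatever array it receives. Applying it to different subsets of the data (such as the first 10,000 rows plotted above) will therefore give slightly different constants in general.

```python
import numpy as np


def log_zscore(x: np.ndarray) -> np.ndarray:
    """Log-transform positive fluxes, then z-score each column (band)."""
    x_log = np.log(x)
    return (x_log - x_log.mean(0)) / x_log.std(0)


rng = np.random.default_rng(0)
# synthetic stand-in for photometry: 8 bands spanning many orders of magnitude
fluxes = np.exp(rng.normal(loc=-20.0, scale=3.0, size=(10_000, 8)))
z = log_zscore(fluxes)

print(np.log10(fluxes.max(0) / fluxes.min(0)).round(1))  # dynamic range per band, in dex
print(z.mean(0).round(6), z.std(0).round(6))             # ~0 and ~1 per band
```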

### Physical Galaxy Parameter Distribution, $\mathcal{Y}$

Recall that we performed Latin-hypercube sampling to obtain the physical parameter values that we provided to the forward model. Because of this sampling procedure, we would expect their (prior) distribution to be roughly uniform.
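To build intuition for why the prior should look uniform, here is a generic Latin-hypercube sketch. It is not the sampler used in agnfinder's simulation code, which may differ in detail. Each of the `d` dimensions is divided into `n` equal strata, and every stratum receives exactly one point. Independent random permutations pair the strata across dimensions. As a result, each one-dimensional marginal is close to uniform on $[0, 1)$ by construction. `latin_hypercube` is a hypothetical name.

```python
import numpy as np


def latin_hypercube(n: int, d: int, rng: np.random.Generator) -> np.ndarray:
    """Draw n points in [0, 1)^d with exactly one point per stratum in every dimension."""
    # one random permutation of the n strata per dimension decides the pairing
    strata = np.stack([rng.permutation(n) for _ in range(d)], axis=1)
    # jitter each point uniformly within its stratum, then rescale to [0, 1)
    return (strata + rng.random((n, d))) / n


rng = np.random.default_rng(0)
lhs = latin_hypercube(1000, 9, rng)

# every column has one point in each of the 1000 bins [k/1000, (k+1)/1000)
bins = np.floor(lhs * 1000).astype(int)
print(all(np.array_equal(np.sort(bins[:, j]), np.arange(1000)) for j in range(9)))
```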

In [None]:
ys = gd.get_ys().numpy()[:10000]
nbu.plot_corner(ys, title="Physical Galaxy Parameters",
                description="Normalised values, as obtained from LHS")

Some setup for inference later:

In [None]:
class InferenceParams(ConfigClass):
    epochs: int = 5
    batch_size: int = 512
    split_ratio: float = 0.9
    dtype: t.dtype = dtype
    device: t.device = device
    logging_frequency: int = 10000
    dataset_loc: str = './data/cubes/latest_sample/'
    retrain_model: bool = False  # prefer an existing model over re-training
    overwrite_results: bool = True  # if we do re-train, save it

ip = InferenceParams()

In [None]:
train_loader_128, test_loader_128 = load_simulated_data(
    path=ip.dataset_loc,
    split_ratio=ip.split_ratio,
    batch_size=128,
    normalise_phot=normalise_phot_np,
    transforms=[transforms.ToTensor()]
)
train_loader_512, test_loader_512 = load_simulated_data(
    path=ip.dataset_loc,
    split_ratio=ip.split_ratio,
    batch_size=512,
    normalise_phot=normalise_phot_np,
    transforms=[transforms.ToTensor()]
)
train_loader_1024, test_loader_1024 = load_simulated_data(
    path=ip.dataset_loc,
    split_ratio=ip.split_ratio,
    batch_size=1024,
    normalise_phot=normalise_phot_np,
    transforms=[transforms.ToTensor()]
)
logging.info('Data loading complete.')

# Gaussians for Individual Parameters

As a sanity check to see whether there is useful signal in the data, here we maximise a Gaussian likelihood over just one parameter: $p(y_{i} | \mathbf{x}) = \mathcal{N}(y_{i}; \mu, \sigma^2)$. A simple feed-forward ANN predicts the parameters $\theta = \{\mu, \sigma\}$ and is trained by minimising the NLL.
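Written out, the per-example objective is $-\log \mathcal{N}(y_i; \mu, \sigma^2) = \log\sigma + \tfrac{1}{2}\log 2\pi + \frac{(y_i - \mu)^2}{2\sigma^2}$. Fix a residual $r = y_i - \mu$ and set the derivative with respect to $\sigma$ to zero: $1/\sigma - r^2/\sigma^3 = 0$, so $\sigma = \lvert r \rvert$. A network trained this way is therefore pushed to predict a larger spread exactly when its mean is further off. The standard-library sketch below checks this numerically on a grid. It is independent of the notebook's code, and `gaussian_nll` is a hypothetical helper.

```python
import math


def gaussian_nll(y: float, mu: float, sigma: float) -> float:
    """-log N(y; mu, sigma^2)."""
    return math.log(sigma) + 0.5 * math.log(2 * math.pi) + (y - mu) ** 2 / (2 * sigma ** 2)


# For a fixed residual r = y - mu, the NLL is smallest at sigma = |r|.
r = 0.3
sigmas = [0.01 * k for k in range(1, 201)]  # candidate scales 0.01 ... 2.00
best_sigma = min(sigmas, key=lambda s: gaussian_nll(r, 0.0, s))
print(best_sigma)
```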

For reference, the available parameters are $y_{i}$ for $i \in \{0, \ldots, 8\}$, corresponding (in order) to `redshift`, `log_mass`, `dust2`, `tage`, `log_tau`, `log_agn_mass`, `agn_eb_v`, `log_agn_torus_mass` and `inclination`.


In [None]:
class ANN(nn.Module):
    def __init__(self):
        super().__init__()
        net: list[nn.Module] = []
        layer_sizes: list[int] = [8, 16, 2]
        for l1, l2 in zip(layer_sizes, layer_sizes[1:]):
            net.extend([nn.Linear(l1, l2), nn.BatchNorm1d(l2), nn.ReLU()])
        net = net[:-2]  # remove final ReLU and BatchNorm1d
        self.net: nn.Module = nn.Sequential(*net).to(device, dtype)
        
    def forward(self, x: t.Tensor) -> t.Tensor:
        return self.net(x)
    
ann_param_idx = 0  # use the 1st dimension (arbitrarily); here redshift

def train_network(train_loader: DataLoader, ip: InferenceParams) -> ANN:
    savepath: str = './results/nbresults/ann.pt'
    if not ip.retrain_model:
        try:
            logging.info(f'Attempting to load model from {savepath}')
            net = t.load(savepath)
            logging.info(f'Successfully loaded')
            return net
        except:
            logging.info(f'No model {savepath} found; training...')
    
    net = ANN()
    opt = t.optim.Adam(net.parameters(), lr=1e-3)
    for e in range(ip.epochs):
        for i, (x, y) in enumerate(train_loader):
            x, y = x.to(device, dtype), y.to(device, dtype)
            params = net(x)
            gaussian = Normal(params[:,0], params[:,1]**2)
            loss = -gaussian.log_prob(y[:,ann_param_idx]).mean(0)
            opt.zero_grad()
            loss.backward()
            opt.step()

            if i % ip.logging_frequency == 0 or i == len(train_loader)-1:
                logging.info(
                    "Epoch: {:02d}/{:02d}, Batch: {:03d}/{:d}, Loss {:9.4f}"
                    .format(e+1, ip.epochs, i, len(train_loader)-1, loss.item()))
                
    if ip.overwrite_results:
        t.save(net, savepath)
        logging.info(f'Saved model to {savepath}')
    return net

In [None]:
net = train_network(train_loader_128, ip)

In [None]:
# look up the human-readable label of the modelled parameter in column_order
param_label = column_order[ann_param_idx]

xs, true_ys = nbu.new_sample(test_loader_128)
xs = xs.to(device, dtype)
true_ys = true_ys[ann_param_idx]

net.eval()
with t.inference_mode():
    params = net(xs.unsqueeze(0)).squeeze()
    # apply the same transform to the scale output as in training (params[:, 1]**2)
    mu, sigma = params[0].cpu().item(), (params[1] ** 2).cpu().item()
density = t.distributions.Normal(mu, sigma)
    
plot_xs = t.linspace(-1, 2, 200)
plot_ys = density.log_prob(plot_xs).exp().numpy()
fig, ax = plt.subplots(figsize=(8, 4), dpi=200)
ax.plot(plot_xs, plot_ys, label='probability density')
ax.vlines(true_ys, 0, max(plot_ys), label='True parameter value', color='k')
ax.legend()
ax.set_xlabel(param_label + ' (normalised)')
ax.set_ylabel(f'probability density of {param_label}')
ax.set_title('Single Gaussian Fit')

### Post mortem

This is clearly a poor model for the job. Nonetheless, it does learn adequate distributions for some of the physical parameters (redshift seems to be the easiest to constrain). As we might hope, the predicted variance is often large when the predicted mean is far from the true parameter value.

# Gaussian Density Mixture Network

Perhaps the Gaussian with its light tails is not the best likelihood. Short of using a more robust likelihood such as a StudentT or Laplace distribution, we can modify the above to find the parameters of a mixture of Gaussians.
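The density of a $K$-component mixture is $p(y \vert \mathbf{x}) = \sum_{k} \alpha_k \mathcal{N}(y; \mu_k, \sigma_k^2)$. Its log-likelihood is therefore the log of a sum, not a sum of logs. In PyTorch this is usually computed as `t.logsumexp(alpha.log() + LP, dim=-1)`, which is also what `MixtureSameFamily.log_prob` does internally. The plain-Python sketch below uses hypothetical helper names. It shows the max-shift trick that keeps this quantity finite even when every component density underflows.

```python
import math


def logsumexp(values: list[float]) -> float:
    """Stable log(sum(exp(v))): factor out the largest term first."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def normal_logpdf(y: float, mu: float, sigma: float) -> float:
    return -0.5 * ((y - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def mixture_nll(y: float, mus: list[float], sigmas: list[float], weights: list[float]) -> float:
    """-log sum_k alpha_k N(y; mu_k, sigma_k^2), evaluated in log space."""
    terms = [math.log(w) + normal_logpdf(y, m, s) for m, s, w in zip(mus, sigmas, weights)]
    return -logsumexp(terms)


mus, sigmas, weights = [0.2, 0.5, 0.8], [0.05, 0.1, 0.2], [0.3, 0.5, 0.2]

# agrees with the naive formula wherever the naive formula is representable ...
naive = -math.log(sum(w * math.exp(normal_logpdf(0.4, m, s))
                      for m, s, w in zip(mus, sigmas, weights)))
print(mixture_nll(0.4, mus, sigmas, weights), naive)

# ... and stays finite far from every component, where exp() underflows to 0
print(mixture_nll(50.0, mus, sigmas, weights))
```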

In [None]:
class GDMN(nn.Module):
    def __init__(self):
        super().__init__()
        # number of mixture components
        K = 8
        net: list[nn.Module] = []
        layer_sizes: list[int] = [8, 16, 32]
        for l1, l2 in zip(layer_sizes, layer_sizes[1:]):
            net.extend([nn.Linear(l1, l2), nn.BatchNorm1d(l2), nn.ReLU()])
        self.net: nn.Module = nn.Sequential(*net)
        self.heads: nn.ModuleList = nn.ModuleList([
            nn.Linear(layer_sizes[-1], K),
            nn.Linear(layer_sizes[-1], K),
            nn.Sequential(*[
                nn.Linear(layer_sizes[-1], K),
                nn.Softmax(dim=-1)
            ])
        ])
        
    def forward(self, x: t.Tensor) -> list[t.Tensor]:
        y = self.net(x)
        return [h(y) for h in self.heads]
    
# still focusing on the redshift.
gdmn_param_idx = 0

def train_gdmn(train_loader: DataLoader, ip: InferenceParams) -> GDMN:
    
    savepath: str = './results/nbresults/gdmn.pt'
    if not ip.retrain_model:
        try:
            logging.info(f'Attempting to load model from {savepath}')
            gdmn = t.load(savepath).to(ip.device, ip.dtype)
            logging.info(f'Successfully loaded')
            return gdmn
        except:
            logging.info(f'No model {savepath} found; training...')
    
    gdmn = GDMN().to(device, dtype)
    opt = t.optim.Adam(gdmn.parameters(), lr=1e-3)
    for e in range(ip.epochs):
        for i, (x, y) in enumerate(train_loader):
            x, y = x.to(device, dtype), y.to(device, dtype)
            y = y[:,gdmn_param_idx].unsqueeze(-1)
            
            mu, sigma, alpha = gdmn(x)
            normals = t.distributions.Normal(mu, squareplus_f(sigma))
            LP = normals.log_prob(y)
            # mixture log-likelihood: log sum_k alpha_k N(y; mu_k, sigma_k)
            loss = -t.logsumexp(LP + alpha.log(), dim=1).mean(0)
            opt.zero_grad()
            loss.backward()
            opt.step()

            if i % ip.logging_frequency == 0 or i == len(train_loader)-1:
                logging.info(
                    "Epoch: {:02d}/{:02d}, Batch: {:03d}/{:d}, Loss {:9.4f}"
                    .format(e+1, ip.epochs, i, len(train_loader)-1, loss.item()))
                
    if ip.overwrite_results:
        t.save(gdmn, savepath)
        logging.info(f'Saved model to {savepath}')
        
    return gdmn

In [None]:
gdmn = train_gdmn(train_loader_512, ip)

In [None]:
# look up the human-readable label of the modelled parameter in column_order
param_label = column_order[gdmn_param_idx]

xs, true_ys = nbu.new_sample(test_loader_512)
xs = xs.to(device, dtype).unsqueeze(0)
true_ys = true_ys[gdmn_param_idx]

gdmn.eval()
with t.inference_mode():
    params = gdmn(xs)
    mu, sigma, alpha = params[0].squeeze().cpu(), squareplus_f(params[1]).squeeze().cpu(), params[2].cpu()

density = t.distributions.Normal(mu, sigma)
plot_xs = t.linspace(0, 1, 200).unsqueeze(-1)
plot_ys = (density.log_prob(plot_xs).exp() * alpha).sum(1).cpu().numpy()
fig, ax = plt.subplots(figsize=(8, 4), dpi=200)
ax.plot(plot_xs, plot_ys, label='probability density')
ax.vlines(true_ys, 0, max(plot_ys), label='True parameter value', color='k')
ax.legend()
ax.set_xlabel(param_label + ' (normalised)')
ax.set_ylabel(f'probability density of {param_label}')
ax.set_title('Mixture of Gaussians')

### Post mortem

This is already looking more promising than the single-Gaussian fit. The redshift is still the easiest parameter to constrain.

When the true value is close to the mode of the distribution, the variance is also quite low. When the true value lies far from the mode, the variance is high, which is the behaviour we want.

## Variations on the above

To try to understand the data better using our very simple model, we can vary the

- batch size
- epochs
- learning rates
- mixture densities
- activation functions


In [None]:
class ABInferenceParams(ConfigClass):
    epochs: int = 3
    batch_size: int = 1024
    split_ratio: float = 0.9
    dtype: t.dtype = dtype
    device: t.device = device
    logging_frequency: int = 2000
    dataset_loc: str = './data/cubes/latest_sample/'
    retrain_model: bool = False
    overwrite_results: bool = True
ab_ip = ABInferenceParams()

In [None]:
ab_train_loader, ab_test_loader = load_simulated_data(
    path=ab_ip.dataset_loc,
    split_ratio=ab_ip.split_ratio,
    batch_size=ab_ip.batch_size,
    normalise_phot=normalise_phot_np,
    transforms=[transforms.ToTensor()]
)
logging.info('Data loading complete.')

In [None]:
class AB_GDMN(nn.Module):
    def __init__(self):
        super().__init__()
        # number of mixture components
        K = 6
        net: list[nn.Module] = []
        layer_sizes: list[int] = [8, 16, 32]
        for l1, l2 in zip(layer_sizes, layer_sizes[1:]):
            net.extend([nn.Linear(l1, l2), nn.BatchNorm1d(l2), nn.ReLU()])
        self.net: nn.Module = nn.Sequential(*net)
        self.heads: nn.ModuleList = nn.ModuleList([
            nn.Linear(layer_sizes[-1], K),
            nn.Linear(layer_sizes[-1], K),
            nn.Sequential(*[
                nn.Linear(layer_sizes[-1], K),
                nn.Softmax(dim=-1)
            ])
        ])
        
    def forward(self, x: t.Tensor) -> list[t.Tensor]:
        y = self.net(x)
        return [h(y) for h in self.heads]
    

# physical parameter index, no greater than 8
# idx 4 is the log_tau parameter
ab_param_idx = 4

def train_ab_gdmn(train_loader: DataLoader, ip: ABInferenceParams) -> AB_GDMN:
    savepath: str = './results/nbresults/ab_gdmn.pt'
    if not ip.retrain_model:
        try:
            logging.info(f'Attempting to load model from {savepath}')
            ab_gdmn = t.load(savepath).to(ip.device, ip.dtype)
            logging.info(f'Successfully loaded')
            return ab_gdmn
        except:
            logging.info(f'No model {savepath} found; training...')
    
    ab_gdmn = AB_GDMN().to(ip.device, ip.dtype)
    opt = t.optim.Adam(ab_gdmn.parameters(), lr=1e-3)
    ab_gdmn.train()
    for e in range(ip.epochs):
        for i, (x, y) in enumerate(train_loader):
            x, y = x.to(device, dtype), y.to(device, dtype)
            y = y[:, ab_param_idx].unsqueeze(-1)
            
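            loc, scale, alpha = ab_gdmn(x)
            # Student-T with 1 degree of freedom (i.e. a Cauchy distribution)
            students = t.distributions.StudentT(1., loc, squareplus_f(scale))
            LP = students.log_prob(y)
            # mixture log-likelihood: log sum_k alpha_k StudentT(y; loc_k, scale_k)
            loss = -t.logsumexp(LP + alpha.log(), dim=1).mean(0)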
            opt.zero_grad()
            loss.backward()
            opt.step()
            if i % ip.logging_frequency == 0 or i == len(train_loader)-1:
                logging.info(
                    "Epoch: {:02d}/{:02d}, Batch: {:03d}/{:d}, Loss {:9.4f}"
                    .format(e+1, ip.epochs, i, len(train_loader)-1, loss.item()))
                
    if ip.overwrite_results:
        t.save(ab_gdmn, savepath)
        logging.info(f'Saved model to {savepath}')
        
    return ab_gdmn

In [None]:
ab_gdmn = train_ab_gdmn(train_loader_1024, ab_ip)

In [None]:
# look up the human-readable label of the modelled parameter in column_order
param_label = column_order[ab_param_idx]

xs, true_ys = nbu.new_sample(test_loader_1024)
xs = xs.to(device, dtype).unsqueeze(0)
true_ys = true_ys[ab_param_idx]

ab_gdmn.eval()
with t.inference_mode():
    params = ab_gdmn(xs)
    loc, scale, alpha = params[0].squeeze().cpu(), squareplus_f(params[1]).squeeze().cpu(), params[2].cpu()

# 1 degree of freedom, matching training (i.e. a Cauchy distribution)
density = t.distributions.StudentT(1., loc, scale)

plot_xs = t.linspace(0, 1, 200).unsqueeze(-1)
plot_ys = density.log_prob(plot_xs).exp()
plot_ys = (plot_ys * alpha).sum(1).cpu().numpy()

fig, ax = plt.subplots(figsize=(8, 4), dpi=200)
ax.plot(plot_xs, plot_ys, label='probability density')
ax.vlines(true_ys, 0, max(plot_ys), label='True parameter value', color='k')
ax.legend()
ax.set_xlabel(param_label + ' (normalised)')
ax.set_ylabel(f'probability density of {param_label}')
ax.set_title('Maximising StudentT likelihood for hard-to-constrain parameters')

### Notes

The `log_tau` parameter is harder to constrain than `redshift`. Even so, maximising the likelihood of a heavy-tailed distribution like the StudentT gives stable training and reasonable (if over-simplistic) posterior estimates.

Changing the mixture density from a Gaussian to a more robust distribution like the StudentT helped to bring down the NLL.
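One way to see why the heavier-tailed likelihood helps is to write the NLL in terms of the standardised residual $z = (y - \mu)/s$. The Gaussian NLL grows quadratically, as $\tfrac{1}{2}z^2$. The StudentT NLL with one degree of freedom (a Cauchy distribution, as used above) grows only logarithmically, as $\log(1 + z^2)$. A few badly predicted examples therefore dominate the loss much less under the StudentT. A small standard-library comparison, with the scale fixed at 1 and hypothetical helper names:

```python
import math


def normal_nll(z: float) -> float:
    """Gaussian NLL for a standardised residual z (unit scale)."""
    return 0.5 * z ** 2 + 0.5 * math.log(2 * math.pi)


def cauchy_nll(z: float) -> float:
    """StudentT(df=1) NLL for a standardised residual z (unit scale)."""
    return math.log(math.pi) + math.log1p(z ** 2)


for z in (0.0, 1.0, 3.0, 10.0):
    print(f"z={z:5.1f}  normal={normal_nll(z):8.3f}  cauchy={cauchy_nll(z):6.3f}")
```

The gradients differ in the same way. The Gaussian gradient is $z$, while the Cauchy gradient is $2z/(1+z^2)$, which never exceeds 1 in magnitude.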

# Multivariate Mixture Distributions

Having had reasonable success treating the variables individually, let's now model several of them jointly. Viewed one parameter at a time, the marginals of such a model should resemble the single-parameter fits above. We begin with the humble factorised Gaussian.
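A factorised Gaussian is a `MultivariateNormal` whose `scale_tril` is diagonal, as in the `FGMM` below. Its log-density is then exactly the sum of $D$ independent univariate log-densities. This is why each marginal of a single factorised component behaves like the one-dimensional models above. The NumPy sketch below checks this identity. The helper names are hypothetical.

```python
import numpy as np


def mvn_logpdf_tril(y: np.ndarray, mu: np.ndarray, L: np.ndarray) -> float:
    """log N(y; mu, L L^T) for a lower-triangular Cholesky factor L."""
    dim = y.shape[-1]
    z = np.linalg.solve(L, y - mu)  # whitened residual: L z = y - mu
    # log|Sigma| = 2 * sum(log diag(L)), so -0.5 * log|Sigma| = -sum(log diag(L))
    return float(-0.5 * z @ z - np.log(np.diag(L)).sum() - 0.5 * dim * np.log(2 * np.pi))


def normal_logpdf(y: np.ndarray, mu: np.ndarray, sigma: np.ndarray) -> np.ndarray:
    return -0.5 * ((y - mu) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)


rng = np.random.default_rng(1)
dim = 9
y_demo, mu_demo = rng.random(dim), rng.random(dim)
sigma_demo = rng.uniform(0.05, 0.5, size=dim)

joint = mvn_logpdf_tril(y_demo, mu_demo, np.diag(sigma_demo))   # diagonal scale_tril
factorised = normal_logpdf(y_demo, mu_demo, sigma_demo).sum()   # sum of 1-D terms
print(joint, factorised)
```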

## Mixture of Factorised Gaussians

In [None]:
class FGMM(nn.Module):
    def __init__(self, K: int = 5, D: int = 2) -> None:
        """
        Args:
            K: the number of mixture components to use
            D: the number of dimensions of the individual Gaussians.
        """
        super().__init__()
        self.K = K
        self.D = D
        
        net: list[nn.Module] = []
        layer_sizes: list[int] = [8, 16, 32] # TODO: try something a little bit bigger.
        for l1, l2 in zip(layer_sizes, layer_sizes[1:]):
            net.extend([nn.Linear(l1, l2), nn.BatchNorm1d(l2), nn.ReLU()])
        self.net: nn.Module = nn.Sequential(*net)
        self.heads: nn.ModuleList = nn.ModuleList([
            nn.Linear(layer_sizes[-1], K*D), # mean vectors
            nn.Linear(layer_sizes[-1], K*D), # diagonals of the (diagonal) scale_tril matrices
            nn.Sequential(*[nn.Linear(layer_sizes[-1], K),nn.Softmax(dim=-1)]) # mixture weights
        ])
        
    def forward(self, x: t.Tensor) -> list[t.Tensor]:
        y = self.net(x)
        hs = [h(y) for h in self.heads]
        eyes = t.eye(self.D, dtype=dtype, device=device).repeat(x.size(0), self.K, 1)
        diags = eyes * squareplus_f(hs[1]).unsqueeze(-1)
        diags = diags.reshape(-1, self.K, self.D, self.D)
        
        # return mean vectors, lower-triangular matrices, and mixture weights
        return [hs[0].reshape(-1, self.K, self.D), diags, hs[2]]
    
fgmm_param_idxs: list[int] = [0, 1, 2, 3, 4, 5, 6, 7, 8]
K = 3
D = len(fgmm_param_idxs)

def train_fgmm(train_loader: DataLoader, ip: InferenceParams) -> FGMM:
        
    savepath: str = './results/nbresults/fgmm.pt'
    if not ip.retrain_model:
        try:
            logging.info(f'Attempting to load model from {savepath}')
            fgmm = t.load(savepath).to(ip.device, ip.dtype)
            logging.info(f'Successfully loaded')
            return fgmm
        except:
            logging.info(f'No model {savepath} found; training...')
    
    fgmm = FGMM(K, D).to(ip.device, ip.dtype)
    opt = t.optim.Adam(fgmm.parameters(), lr=1e-3)
    fgmm.train()
    for e in range(ip.epochs):
        for i, (x, y) in enumerate(train_loader):
            
            x, y = x.to(device, dtype), y.to(device, dtype)
            y = y[:, fgmm_param_idxs]
            if D == 1:
                y = y[:,None]
            
            locs, Ls, alpha = fgmm(x)
            normals = MultivariateNormal(locs, scale_tril=Ls)
            P = normals.log_prob(y.unsqueeze(-2))
            # mixture log-likelihood: log sum_k alpha_k N(y; mu_k, Sigma_k)
            loss = -t.logsumexp(P + alpha.log(), dim=1).mean(0)
            opt.zero_grad()
            loss.backward()
            opt.step()
            
            if i % ip.logging_frequency == 0 or i == len(train_loader)-1:
                logging.info(
                    "Epoch: {:02d}/{:02d}, Batch: {:03d}/{:d}, Loss {:9.4f}"
                    .format(e+1, ip.epochs, i, len(train_loader)-1, loss.item()))
                
    if ip.overwrite_results:
        t.save(fgmm, savepath)
        logging.info(f'Saved model to {savepath}')

    return fgmm

In [None]:
fgmm = train_fgmm(train_loader_1024, ip)

In [None]:
params = [column_order[i] for i in fgmm_param_idxs]

xs, true_ys = nbu.new_sample(test_loader_1024)
xs = xs.to(device, dtype).unsqueeze(0)
true_ys = t.tensor([true_ys[i] for i in fgmm_param_idxs])

fgmm.eval()
with t.inference_mode():
    locs, Ls, alpha = fgmm(xs)
    
mix = Categorical(alpha)
normals = MultivariateNormal(locs, scale_tril=Ls)
mixture = MixtureSameFamily(mix, normals)

n_samples = 10000
samples = mixture.sample((n_samples,)).cpu().squeeze(1)
# samples = normals.sample((n_samples,)).squeeze(1)

lims = np.array([[0.,1.]]).repeat(len(params),0)
nbu.plot_corner(samples.numpy(), true_ys.cpu().numpy(), lims=lims, labels=params,
                title="Mixture of Factorised Gaussians",
                description="3 mixture components, trained on full dataset, batch size 1024, 5 epochs")

## Post Mortem

The above matches what we would expect:

- the joint 'ellipses' are aligned with the coordinate axes due to the factorisation
- some parameters (e.g. redshift, masses) are easier to constrain than others (tage, inclination...)

# Full-Covariance Gaussians

Trying to move beyond the factorised case, can our simple model learn joint interactions between the physical parameter values?
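Allowing correlations means predicting a full Cholesky factor $L$ per component, with $\Sigma = LL^\top$. The factor has $D$ positive diagonal entries plus $D(D-1)/2$ unconstrained strictly lower-triangular entries, which is what `MVGMN.tril_size` counts. Any such $L$ with a strictly positive diagonal gives a valid (positive-definite) covariance, so the network needs no further constraints. For the $D = 9$ parameters here, that is $9 + 9 + 36 = 54$ outputs per component (means, diagonal, off-diagonal), against $18$ for the factorised model. The NumPy sketch below assembles such a factor from unconstrained values. The positive map is a squareplus with $b = 4$; this is an assumption and may not match `squareplus_f`. `build_scale_tril` is a hypothetical name.

```python
import numpy as np


def build_scale_tril(diag_raw: np.ndarray, offdiag: np.ndarray, dim: int) -> np.ndarray:
    """Assemble a Cholesky factor from unconstrained outputs (hypothetical helper)."""
    L = np.zeros((dim, dim))
    rows, cols = np.tril_indices(dim, k=-1)      # strictly lower-triangular slots
    L[rows, cols] = offdiag
    # squareplus (b = 4 assumed) maps any real number to a strictly positive one
    L[np.arange(dim), np.arange(dim)] = (diag_raw + np.sqrt(diag_raw ** 2 + 4.0)) / 2.0
    return L


dim = 9
n_offdiag = dim * (dim - 1) // 2                 # 36 for dim = 9
rng = np.random.default_rng(2)
L = build_scale_tril(rng.normal(size=dim), rng.normal(size=n_offdiag), dim)
Sigma = L @ L.T

print(n_offdiag, 2 * dim + n_offdiag)            # off-diagonal terms, outputs per component
print(np.linalg.eigvalsh(Sigma).min() > 0)       # positive definite
```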

In [None]:
class MVGMN(nn.Module):
    def __init__(self, K: int = 5, D: int = 2) -> None:
        """Multivariate Gaussian mixture network
        
        Args:
            K: the number of mixture components to use
            D: the number of dimensions of the individual Gaussians.
        """
        super().__init__()
        self.K = K
        self.D = D
        
        net: list[nn.Module] = []
        layer_sizes: list[int] = [8, 16, 32]
        for l1, l2 in zip(layer_sizes, layer_sizes[1:]):
            net.extend([nn.Linear(l1, l2), nn.BatchNorm1d(l2), nn.ReLU()])
        self.net: nn.Module = nn.Sequential(*net)
        self.heads: nn.ModuleList = nn.ModuleList([
            nn.Linear(layer_sizes[-1], K*D), # mean vectors
            nn.Linear(layer_sizes[-1], K*D), # diagonals of the Cholesky factors (made positive below)
            nn.Linear(layer_sizes[-1], K*self.tril_size(D)), # strictly lower-triangular terms of the Cholesky factors
            nn.Sequential(*[nn.Linear(layer_sizes[-1], K), nn.Softmax(dim=-1)]) # mixture weights
        ])
        
    def forward(self, x: t.Tensor) -> list[t.Tensor]:
        y = self.net(x)
        hs = [h(y) for h in self.heads]
        
        # Create covariance matrix
        # 1. create [N, K] diagonal matrices of side length D with positive diagonal entries
        eyes = t.eye(self.D, dtype=dtype, device=device).repeat(x.size(0), self.K, 1)
        diags = eyes * squareplus_f(hs[1]).unsqueeze(-1)
        diags = diags.reshape(-1, self.K, self.D, self.D)
        
        # 2. fill in the lower-triangular sections if D > 1
        if self.D > 1:
            ti = t.tril_indices(self.D, self.D, -1)
            diags[:,:,ti[0],ti[1]] = hs[2].reshape(x.size(0), self.K, -1)
        
        # return mean vectors, lower-triangular matrices, and mixture weights
        return [hs[0].reshape(-1, self.K, self.D), diags, hs[3]]
    
    def tril_size(self, N: int) -> int:
        """Number of strictly lower-triangular entries in an N x N matrix."""
        return N * (N - 1) // 2
    
mvgmn_param_idxs: list[int] = [0, 1, 2, 3, 4, 5, 6, 7, 8]
K = 3
D = len(mvgmn_param_idxs)

def train_mvgmn(train_loader: DataLoader, ip: InferenceParams) -> MVGMN:
    
    savepath: str = './results/nbresults/mvgmn.pt'
    if not ip.retrain_model:
        try:
            logging.info(f'Attempting to load model from {savepath}')
            mvgmn = t.load(savepath).to(ip.device, ip.dtype)
            logging.info(f'Successfully loaded')
            return mvgmn
        except:
            logging.info(f'No model {savepath} found; training...')
    
    mvgmn = MVGMN(K, D).to(ip.device, ip.dtype)
    opt = t.optim.Adam(mvgmn.parameters(), lr=1e-3)
    mvgmn.train()
    for e in range(ip.epochs):
        for i, (x, y) in enumerate(train_loader):
            
            x, y = x.to(device, dtype), y.to(device, dtype)
            y = y[:, mvgmn_param_idxs]
            if D == 1:
                y = y[:,None]
            
            locs, Ls, alpha = mvgmn(x)
            normals = MultivariateNormal(locs, scale_tril=Ls)
            P = normals.log_prob(y.unsqueeze(-2))
            # mixture log-likelihood: log sum_k alpha_k N(y; mu_k, Sigma_k)
            loss = -t.logsumexp(P + alpha.log(), dim=1).mean(0)
            opt.zero_grad()
            loss.backward()
            opt.step()
            
            if i % ip.logging_frequency == 0 or i == len(train_loader)-1:
                logging.info(
                    "Epoch: {:02d}/{:02d}, Batch: {:03d}/{:d}, Loss {:9.4f}"
                    .format(e+1, ip.epochs, i, len(train_loader)-1, loss.item()))
            
    if ip.overwrite_results:
        t.save(mvgmn, savepath)
        logging.info(f'Saved model to {savepath}')
        
    return mvgmn

In [None]:
mvgmn = train_mvgmn(train_loader_1024, ip)

In [None]:
params = [column_order[i] for i in mvgmn_param_idxs]

xs, true_ys = nbu.new_sample(test_loader_1024)
xs = xs.to(device, dtype).unsqueeze(0)
true_ys = t.tensor([true_ys[i] for i in mvgmn_param_idxs])

mvgmn.eval()
with t.inference_mode():
    locs, Ls, alpha = mvgmn(xs)
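mixture = MixtureSameFamily(Categorical(alpha), MultivariateNormal(locs, scale_tril=Ls))

# sample from the mixture (rather than averaging one sample from every component)
samples = mixture.sample((10000,)).squeeze(1).cpu()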

lims = np.array([[0.,1.]]).repeat(len(params),0)
nbu.plot_corner(samples.numpy(), true_ys.cpu().numpy(), lims=lims, labels=params,
                title="Mixture of Full-Covariance Gaussians",
                description="3 mixture components, trained on full dataset, batch size 1024, 10 epochs")

### Post Mortem

The accuracy of the model seems to have suffered somewhat compared to the factorised Gaussian. This is perhaps due to the larger number of parameters we now need to learn, so training for longer may be a rudimentary remedy.

Not much was gained, since the pairwise joints mostly still look independent, apart from a few sporadic exceptions. The above took about half an hour to train; despite this, the marginals are overly simplistic and not very accurate.

This is probably the extent of the usefulness of simple feed-forward networks parametrising mixture distributions.

# "Sequential Autoregressive Network" (SAN)

If $\mathbf{y} \in \mathbb{R}^{D}$ and $\mathbf{x} \in \mathbb{R}^{E}$, then by the chain rule of probability we can write the conditional density as
\begin{align*}
p(\mathbf{y} \vert \mathbf{x}) &= p(y_{1} \vert \mathbf{x})p(y_{2}\vert y_{1}, \mathbf{x})\cdots p(y_{D}\vert y_{1}, \ldots, y_{D-1}, \mathbf{x}) \\
&= \prod^{D}_{d=1}p(y_{d} \vert \mathbf{y}_{<d}, \mathbf{x}).
\end{align*}

This is the _autoregressive property_ that is exploited by models like MADE or Masked Autoregressive Flows. Any architecture satisfying this autoregressive property can straightforwardly be trained by minimising the resulting NLL:

$$
\ell(\mathbf{y}; \mathbf{x}) = - \sum^{D}_{d=1} \log p(y_d \vert \mathbf{y}_{<d}, \mathbf{x}).
$$
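As a concrete check of this factorisation, take $D = 2$ and suppose $p(\mathbf{y} \vert \mathbf{x})$ is a zero-mean bivariate Gaussian with standard deviations $s_1, s_2$ and correlation $\rho$. The dependence on $\mathbf{x}$ is left implicit. Then $p(y_1 \vert \mathbf{x}) = \mathcal{N}(y_1; 0, s_1^2)$ and $p(y_2 \vert y_1, \mathbf{x}) = \mathcal{N}\big(y_2; \rho \tfrac{s_2}{s_1} y_1,\, s_2^2(1-\rho^2)\big)$. The two autoregressive terms should add up to the joint log-density. A standard-library sketch with hypothetical helper names:

```python
import math


def normal_logpdf(y: float, mu: float, sigma: float) -> float:
    return -0.5 * ((y - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def bivariate_logpdf(y1: float, y2: float, s1: float, s2: float, rho: float) -> float:
    """Zero-mean bivariate Gaussian log-density, evaluated directly."""
    z1, z2 = y1 / s1, y2 / s2
    q = (z1 ** 2 - 2 * rho * z1 * z2 + z2 ** 2) / (1 - rho ** 2)
    return -0.5 * q - math.log(2 * math.pi * s1 * s2 * math.sqrt(1 - rho ** 2))


def autoregressive_logpdf(y1: float, y2: float, s1: float, s2: float, rho: float) -> float:
    """The same density via log p(y1) + log p(y2 | y1)."""
    log_p1 = normal_logpdf(y1, 0.0, s1)
    cond_mu = rho * (s2 / s1) * y1
    cond_sigma = s2 * math.sqrt(1 - rho ** 2)
    return log_p1 + normal_logpdf(y2, cond_mu, cond_sigma)


args = (0.3, -0.7, 0.5, 1.2, 0.6)  # y1, y2, s1, s2, rho
print(bivariate_logpdf(*args), autoregressive_logpdf(*args))
```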

The following architecture satisfies this property by explicitly modelling each $y_{d}, d \in \{1, \ldots, D\}$ in sequence and conditioning appropriately. It is formed of $D$ "_sequential blocks_": identically structured layers (which do not share weights). The $d$-th block is intended to learn features from which the likelihood parameters $\theta_{d}$ of the $d$-th output dimension can be determined, so that $y_{d} \sim p(y_{d}; \theta_{d})$.

Note that the likelihood needn't be Gaussian. We can choose to output an arbitrary number of parameters $\theta_{d}$ from each sequential block, and use these in arbitrary density functions. For instance, we could parametrise a mixture distribution.
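Sampling follows the same ordering. We draw $y_1$ given $\mathbf{x}$, feed it into the next block to draw $y_2$, and so on, so the $D$ blocks must run one after another. The NumPy sketch below illustrates this ancestral-sampling loop with toy linear "blocks" and Gaussian likelihoods. It is not the agnfinder `SAN` implementation. There, the blocks are small neural networks (`module_shape`) and may pass learned features between blocks. `make_linear_block` and `sample_autoregressive` are hypothetical names.

```python
import numpy as np
from typing import Callable

Block = Callable[[np.ndarray], tuple[np.ndarray, np.ndarray]]


def make_linear_block(in_dim: int, rng: np.random.Generator) -> Block:
    """Toy block: the mean is a random linear function of the input; fixed scale."""
    w = rng.normal(size=in_dim) / np.sqrt(in_dim)
    return lambda inp: (inp @ w, np.full(inp.shape[0], 0.1))


def sample_autoregressive(x: np.ndarray, blocks: list[Block], n_samples: int,
                          rng: np.random.Generator) -> np.ndarray:
    """Ancestral sampling: y_1 ~ p(y_1 | x), then y_2 ~ p(y_2 | y_1, x), ..."""
    x_rep = np.repeat(x[None, :], n_samples, axis=0)
    ys = np.zeros((n_samples, len(blocks)))
    for d, block in enumerate(blocks):
        # block d sees the conditioning info and the d dimensions already sampled
        mu, sigma = block(np.concatenate([x_rep, ys[:, :d]], axis=1))
        ys[:, d] = mu + sigma * rng.standard_normal(n_samples)
    return ys


rng = np.random.default_rng(3)
cond_dim, data_dim = 8, 9  # same sizes as SANParams
toy_blocks = [make_linear_block(cond_dim + d, rng) for d in range(data_dim)]
toy_samples = sample_autoregressive(rng.normal(size=cond_dim), toy_blocks, 1000, rng)
print(toy_samples.shape)
```

All `n_samples` draws are processed in parallel within each block. Only the loop over dimensions is sequential, which is why sampling stays fast for small $D$.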

In [None]:
display(SVG(filename='./notebooks/inference/san.svg'))

In [None]:
class SANParams(ConfigClass):
    cond_dim: int = 8  # dimensions of conditioning info (e.g. photometry)
    data_dim: int = 9  # dimensions of data of interest (e.g. physical params)
    module_shape: list[int] = [16, 16]  # shape of the network 'modules'
    batch_norm: bool = True  # use batch normalisation in network?
sp = SANParams() 

In [None]:
def train_san(model: Type[SAN], sp: SANParams, ip: InferenceParams) -> SAN:
    san = model(sp.cond_dim, sp.data_dim, sp.module_shape,
                device=ip.device, dtype=ip.dtype)

    savepath: str = san.fpath()
    if not ip.retrain_model:
        try:
            logging.info(f'Attempting to load {san.likelihood_name} SAN model from {savepath}')
            san = t.load(savepath).to(ip.device, ip.dtype)
            logging.info(f'Successfully loaded')
            return san.cuda() if ip.device == t.device('cuda') else san
        except:
            logging.info(f'No model {savepath} found; training...')

    san.trainmodel(train_loader_1024, ip.epochs, ip.logging_frequency)
    logging.info(f'Trained {san.likelihood_name} SAN model')

    # only save if requested, consistent with the other training functions
    if ip.overwrite_results:
        t.save(san, san.fpath())
        logging.info(f'Saved {san.likelihood_name} SAN model as: {san.fpath()}')
    return san.cuda() if ip.device == t.device('cuda') else san

## Gaussian SAN

A sequential autoregressive network with Gaussian likelihoods: $p(y_{d} \vert \mathbf{y}_{<d}, \mathbf{x}) = \mathcal{N}(y_{d}; \mu_{d}, \sigma^{2}_{d})$

In [None]:
gsan = train_san(Gaussian_SAN, sp, ip)

In [None]:
xs, true_ys = nbu.new_sample(test_loader_1024)
xs = xs.to(device, dtype)
true_ys = true_ys.to(device, dtype)
# true_ys = t.tensor([true_ys[i] for i in fgmm_param_idxs])

n_samples = 10000
with t.inference_mode():
    start = time.time()
    samples = gsan.sample(xs, n_samples=n_samples).cpu()
    sampling_time = (time.time() - start) * 1e3
logging.info(f'Finished drawing {n_samples:,} samples in {sampling_time:.4f}ms.')
logging.info('Plotting results...')

description = f'{gsan} trained for {ip.epochs} epochs (batch size {train_loader_1024.batch_size})'

lims = np.array([[0.,1.]]).repeat(len(column_order),0)
nbu.plot_corner(samples=samples.numpy(), true_params=true_ys.cpu().numpy(), 
                lims=lims, labels=column_order, title='Gaussian "Sequential Autoregressive Network"',
                description=description)

## Post Mortem

Training this model takes longer than e.g. MADE; however, sampling is fairly quick (e.g. 10,000 samples in 2.3ms). The model becomes relatively expensive for higher-dimensional data, since a 'sequential block' must be added for every output dimension. These extra blocks add parameters and must be evaluated one after another (i.e. the network becomes deeper, not wider), and this sequential computation parallelises poorly.

The conditional distributions additionally seem quite concentrated in instances where they perhaps should not be. This may be due to the light tails on the Gaussian likelihood. Below we attempt the same thing, this time using a Laplace likelihood:

## Laplace SAN

In [None]:
lsan = train_san(Laplace_SAN, sp, ip)

In [None]:
xs, true_ys = nbu.new_sample(test_loader_1024)
xs = xs.to(device, dtype)
true_ys = true_ys.to(device, dtype)
# true_ys = t.tensor([true_ys[i] for i in fgmm_param_idxs])

n_samples = 10000
with t.inference_mode():
    start = time.time()
    samples = lsan.sample(xs, n_samples=n_samples).cpu()
    sampling_time = (time.time() - start) * 1e3
logging.info(f'Finished drawing {n_samples:,} samples in {sampling_time:.4f}ms.')
logging.info('Plotting results...')

description = f'{lsan} trained for {ip.epochs} epochs (batch size {ip.batch_size})'

lims = np.array([[0.,1.]]).repeat(len(column_order),0)
nbu.plot_corner(samples=samples.numpy(), true_params=true_ys.cpu().numpy(), 
                lims=lims, labels=column_order, title='Laplace "Sequential Autoregressive Network"',
                description=description)

## MADE

**The model used in the following cells has a known bug in the mask sampling.**

As a result, some dimensions violate the autoregressive property (i.e. they are conditioned on 'too much'), while others are not conditioned on anything at all! For now, the only value of these cells is for comparing the speed of training and inference against e.g. the "sequential autoregressive network" above.
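For reference, below is a minimal sketch of the standard degree-based mask construction from the original MADE paper (Germain et al., 2015). It is independent of agnfinder's `made` module, and `made_masks` is a hypothetical helper.

- Each input is assigned a degree, which is its position in the ordering.
- Each hidden unit is assigned a degree in $[\min, D-2]$.
- A connection is allowed only from a unit of lower or equal degree. In the output layer, the source degree must be strictly lower.

Multiplying the masks gives a $D \times D$ connectivity matrix. Once this matrix is permuted into the chosen ordering, it must be strictly lower triangular. A dimension conditioned on 'too much' would show up as a non-zero entry on or above the diagonal. The conditioning input $\mathbf{x}$ can be connected to every hidden unit without breaking the property.

```python
import numpy as np

def made_masks(data_dim: int, hidden_sizes: list[int], rng: np.random.Generator,
               natural_ordering: bool = False) -> tuple[list[np.ndarray], np.ndarray]:
    """Degree-based MADE masks (after Germain et al., 2015).

    Returns one (out_features, in_features) 0/1 mask per layer, plus the input
    degrees: order[d] is the position of dimension d in the autoregressive ordering.
    """
    order = np.arange(data_dim) if natural_ordering else rng.permutation(data_dim)
    degrees = [order]
    for h in hidden_sizes:
        # Hidden degrees lie in [min(previous degrees), D - 2]: every unit sees at
        # least one input, and no unit sees the last dimension in the ordering.
        degrees.append(rng.integers(degrees[-1].min(), data_dim - 1, size=h))
    # Hidden unit k may connect to unit j of the previous layer iff deg(k) >= deg(j).
    masks = [(d_out[:, None] >= d_in[None, :]).astype(np.float32)
             for d_in, d_out in zip(degrees[:-1], degrees[1:])]
    # Output d may only see hidden units with degree strictly below order[d].
    masks.append((order[:, None] > degrees[-1][None, :]).astype(np.float32))
    return masks, order
```

A quick check of the property for a 9-dimensional model with two hidden layers: `M = masks[2] @ masks[1] @ masks[0]`, permuted with `np.argsort(order)`, should have an all-zero upper triangle, diagonal included.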

First, we try MADE with a single mask / input ordering.
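With any single mask, sampling from a MADE is still sequential. During training, one forward pass gives all $D$ conditionals because $\mathbf{y}$ is known. At sampling time, however, dimension $d$ can only be drawn once every dimension before it in the ordering has been drawn. Below is a minimal sketch of this ancestral sampling loop. It assumes a hypothetical `net(x, y)` that returns Gaussian $(\mu, \sigma)$ for every dimension and respects the ordering. It does not reproduce `CMADE.sample`.

```python
import torch as t
from torch.distributions import Normal

@t.no_grad()
def made_sample(net, x: t.Tensor, order: list[int], n_samples: int) -> t.Tensor:
    """Ancestral sampling from a Gaussian MADE (sketch).

    net(x, y) -> (mu, sigma), each (N, D), assumed autoregressive w.r.t. `order`,
    where order[d] is the position of dimension d in the ordering.
    """
    x = x.expand(n_samples, -1)           # x: (1, cond_dim)
    D = len(order)
    y = x.new_zeros(n_samples, D)         # not-yet-sampled dims are placeholders
    for d in sorted(range(D), key=lambda i: order[i]):
        mu, sigma = net(x, y)             # one full forward pass per dimension
        y[:, d] = Normal(mu[:, d], sigma[:, d]).sample()
    return y
```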

In [None]:
class MADEInferenceParameters(ConfigClass):
    epochs: int = 5
    batch_size: int = 1024
    split_ratio: float = 0.9
    dtype: t.dtype = dtype
    device: t.device = device
    logging_frequency: int = 10000
    dataset_loc: str = './data/cubes/40M_shuffled.hdf5'
    retrain_model: bool = False
    overwrite_results: bool = True
made_ip = MADEInferenceParameters()

class MADEParams(ConfigClass):
    cond_dim: int = 8  # x; dimensions of photometry
    data_dim: int = 9  # y; dimensions of physical parameters to be estimated
    hidden_sizes: list[int] = [128, 128]
    likelihood: Type[made.MADE_Likelihood] = made.Gaussian
    likelihood_kwargs = None

    # maximum number of different masks / orderings for connectivity / order agnostic training
    num_masks: int = 128
    
    conditional_all: bool = False

    # The number of samples of masks to average parameters over for each 
    # training iteration. 
    samples: int = 16

    # whether to factorise the joint data distribution in the same order as the
    # dimensions are naturally given.
    natural_ordering: bool = False
mp = MADEParams()

In [None]:
mademodel = CMADE(cond_dim=mp.cond_dim, data_dim=mp.data_dim,
                  hidden_sizes=mp.hidden_sizes,
                  likelihood=mp.likelihood,
                  likelihood_kwargs=mp.likelihood_kwargs,
                  num_masks=mp.num_masks, natural_ordering=mp.natural_ordering,
                  device=made_ip.device, dtype=made_ip.dtype)
if made_ip.device == t.device('cuda'):
    mademodel = mademodel.cuda()

In [None]:
made_train_loader, made_test_loader = load_simulated_data(
    path=made_ip.dataset_loc,
    split_ratio=made_ip.split_ratio,
    batch_size=made_ip.batch_size,
    normalise_phot=normalise_phot_np,
    transforms=[transforms.ToTensor()]
)
logging.info('Data loading complete')

In [None]:
mademodel.trainmodel(made_train_loader, made_ip.epochs, mp.samples, made_ip.logging_frequency)
logging.info("Trained MADE model")

t.save(mademodel, mademodel.fpath())
logging.info(f'Saved MADE model as: {mademodel.fpath()}')

In [None]:
mademodel = t.load(mademodel.fpath())

In [None]:
xs, true_ys = nbu.new_sample(made_test_loader)
xs = xs.to(device, dtype).unsqueeze(0)
true_ys = true_ys.to(device, dtype)

mask_idxs = t.randperm(mademodel.num_masks)[:1]
n_samples = 10000
mademodel.eval()
with t.no_grad():
    start = time.time()
    samples = mademodel.sample(xs, n_samples=n_samples, mask_idxs=mask_idxs).cpu()
    sampling_time = (time.time() - start) * 1e3
logging.info(f'Finished drawing {n_samples:,} samples in {sampling_time:.4f}ms.')
logging.info('Plotting results...')

description = f'Trained {mademodel} for {made_ip.epochs} epochs (batch size {made_ip.batch_size})'

lims = np.array([[0.,1.]]).repeat(len(column_order),0)
nbu.plot_corner(samples=samples.numpy(), true_params=true_ys.cpu().numpy(), 
                lims=lims, labels=column_order, title="MADE",
                description=description)

In [None]:
xs, true_ys = nbu.new_sample(made_test_loader, 10000)
xs = xs.to(device, dtype)
true_ys = true_ys.to(device, dtype)

# mask_idxs = t.randperm(mademodel.num_masks)[:1]
mask_idxs = None

# samples per posterior
N = 1000

with t.inference_mode():
    sample_list: list[t.Tensor] = [
        mademodel.sample(x.unsqueeze(0), n_samples=N, mask_idxs=mask_idxs).cpu() for x in xs
    ]
samples = t.cat(sample_list, 0)

description = f'Trained {mademodel} for {made_ip.epochs} epochs (batch size {made_ip.batch_size})'

true_ys = true_ys.repeat_interleave(N, 0).cpu().numpy()
nbu.plot_posteriors(samples.cpu().numpy(), true_ys, title="MADE",
                    description=description)

### Post Mortem

This is clearly finding some signal in the data, although perhaps not as much as we would like. What is the effect of ensembling over more masks?
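The original MADE paper ensembles over masks by treating the $K$ orderings as an equal-weight mixture, $p(\mathbf{y} \vert \mathbf{x}) = \frac{1}{K}\sum_{k} p_{k}(\mathbf{y} \vert \mathbf{x})$. Sampling from such a mixture means drawing one mask uniformly at random for each sample. Below is a minimal sketch of the mixture log-density, computed in log space for numerical stability. `ensemble_log_prob` is a hypothetical helper. It is not confirmed whether `CMADE.sample` combines masks this way when given several `mask_idxs`, or instead averages the predicted parameters, as the `samples` comment in `MADEParams` describes for training.

```python
import math
import torch as t

def ensemble_log_prob(per_mask_log_probs: t.Tensor) -> t.Tensor:
    """Log-density of an equal-weight mixture over K masks.

    per_mask_log_probs: shape (K, N), entry [k, n] = log p_k(y_n | x_n).
    Returns shape (N,): log((1/K) * sum_k p_k(y_n | x_n)).
    """
    K = per_mask_log_probs.shape[0]
    return t.logsumexp(per_mask_log_probs, dim=0) - math.log(K)
```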

In [None]:
class MADEParams2(ConfigClass):
    cond_dim: int = 8  # x; dimensions of photometry
    data_dim: int = 9  # y; dimensions of physical parameters to be estimated
    hidden_sizes: list[int] = [32, 32]

    # maximum number of different masks / orderings for connectivity / order agnostic training
    num_masks: int = 128

    # The number of samples of masks to average parameters over for each 
    # training iteration. 
    samples: int = 16

    # whether to factorise the joint data distribution in the same order as the
    # dimensions are naturally given.
    natural_ordering: bool = False
mp2 = MADEParams2()

made2 = CMADE(cond_dim=mp2.cond_dim, data_dim=mp2.data_dim,
              hidden_sizes=mp2.hidden_sizes,
              likelihood=made.Gaussian,
              likelihood_kwargs=None,
              num_masks=mp2.num_masks, natural_ordering=mp2.natural_ordering,
              device=made_ip.device, dtype=made_ip.dtype)
if made_ip.device == t.device('cuda'):
    made2 = made2.cuda()

In [None]:
made2.trainmodel(made_train_loader, made_ip.epochs, mp2.samples, made_ip.logging_frequency)
logging.info("Trained MADE model")

t.save(made2, made2.fpath())
logging.info(f'Saved MADE model as: {made2.fpath()}')

In [None]:
made2 = t.load(made2.fpath())

In [None]:
xs, true_ys = nbu.new_sample(made_test_loader)
xs = xs.to(device, dtype).unsqueeze(0)
true_ys = true_ys.to(device, dtype)

mask_idxs = t.randperm(made2.num_masks)[:30]
# mask_idxs = None

n_samples = 10000
with t.inference_mode():
    start = time.time()
    samples = made2.sample(xs, n_samples=n_samples, mask_idxs=mask_idxs).cpu()
    sampling_time = (time.time() - start) * 1e3
logging.info(f'Finished drawing {n_samples:,} samples in {sampling_time:.4f}ms.')
logging.info('Plotting results...')

description = f'Trained {made2} for {made_ip.epochs} epochs (batch size {made_ip.batch_size})'

lims = np.array([[0.,1.]]).repeat(len(column_order),0)
# lims = None
nbu.plot_corner(samples=samples.numpy(), true_params=true_ys.cpu().numpy(), 
                lims=lims, labels=column_order, title="MADE with Mask Ensemble",
                description=description)

This plot is not very instructive, since it ensembles over faulty masks; to put it nicely, the result is garbage.

In [None]:
xs, true_ys = nbu.new_sample(made_test_loader, 10000)
xs = xs.to(device, dtype)
true_ys = true_ys.to(device, dtype)

# mask_idxs = t.randperm(made2.num_masks)
mask_idxs = None

# samples per posterior
N = 1000

with t.inference_mode():
    sample_list: list[t.Tensor] = [
        made2.sample(x.unsqueeze(0), n_samples=N, mask_idxs=mask_idxs).cpu() for x in xs
    ]
samples = t.cat(sample_list, 0)

description = f'Trained {made2} for {made_ip.epochs} epochs (batch size {made_ip.batch_size})'

true_ys = true_ys.repeat_interleave(N, 0).cpu().numpy()
nbu.plot_posteriors(samples.cpu().numpy(), true_ys, title="MADE",
                    description=description)