# Dark Energy Survey (DES) Colours Test

In this notebook, we use the DES catalogue and calculate 'colours' (pairwise differences between filter values, or 'maggies') to use as inputs to the network.

Colours are calculated as $\mathcal{C} = \big\{ (f_{i} - f_{j}) | i, j \in [1, N], i < j \big\}$, for filter values $\mathbf{f} = [f_{1}, \ldots, f_{N}]$.

There are good reasons to believe that this won't help much (other than slowing down the network). Since the filter values ('maggies') are already normalised as $\hat{\mathbf{f}} = \log \big(\frac{\mathbf{f} - \overline{\mathbf{f}}}{\text{std}(\mathbf{f})}\big)$, the benefit of inputting the relative distances between filter values is minimal. Worse, the conditioning input now has $N(N-1)/2$ dimensions instead of $N$, growing as $O(N^2)$ in the number of filters $N$, which significantly slows down the network. Owing to the depth of the SAN and the alternating widths / bottlenecks in its architecture, the network is likely to learn good representations on its own, without us needing to compute the colours by hand.

In [None]:
import os
import time
import logging
import numpy as np
import torch as t
import matplotlib.pyplot as plt

from matplotlib.colors import to_rgb, to_rgba
from typing import Type, Any
from torchvision import transforms

# TODO clean up these imports!
import agnfinder
import agnfinder.inference.san as san
import agnfinder.inference.inference as inf

from agnfinder import config as cfg
from agnfinder import nbutils as nbu
from agnfinder.types import ConfigClass, column_order
from agnfinder.inference import SAN
from agnfinder.inference.utils import load_simulated_data, normalise_phot_np, maggies_to_colours_np, get_colours_length

# One-time setup, guarded so that re-running this cell does not reconfigure
# logging or change the working directory a second time.
# An explicit globals() lookup is used instead of `assert _SETUP`: asserts are
# stripped under `python -O` (which would silently skip setup and leave
# `dtype` / `device` undefined), and a falsy `_SETUP` would otherwise raise an
# uncaught AssertionError.
if not globals().get('_SETUP', False):
    cfg.configure_logging()
    # Move to the directory that contains the agnfinder package, presumably so
    # that relative paths such as ip.dataset_loc ('./data/...') resolve.
    os.chdir(os.path.split(agnfinder.__path__[0])[0])
    dtype = t.float32
    device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")
    if device == t.device("cuda"):
        print('Using GPU for training')
        # !nvidia-smi
    else:
        print("CUDA is unavailable; training on CPU.")
    _SETUP = True

In [None]:
# Notebook-specific inference settings: overrides selected defaults of
# inf.InferenceParams for this DES colours experiment.
class InferenceParams(inf.InferenceParams):
    model: inf.model_t = san.SAN
    logging_frequency: int = 10000
    # Relative path: resolved against the current working directory.
    dataset_loc: str = './data/cubes/des_sample/photometry_simulation_40000000n_z_0p0000_to_6p0000.hdf5'
    # NOTE(review): presumably False means a previously saved model is reused
    # instead of retraining; confirm in inf.InferenceParams / trainmodel.
    retrain_model: bool = False
    overwrite_results: bool = True
    
ip = InferenceParams()
# NOTE(review): fp is not referenced in the other cells shown here.
fp = cfg.FreeParams()

In [None]:
# Train/test loaders yielding batches of 1,024 samples. Photometry is
# normalised with normalise_phot_np, and the inputs are passed through
# maggies_to_colours_np (via x_transforms) so the network sees colours.
train_loader_1024, test_loader_1024 = load_simulated_data(
    batch_size=1024,
    path=ip.dataset_loc,
    split_ratio=ip.split_ratio,
    normalise_phot=normalise_phot_np,
    x_transforms=[maggies_to_colours_np],
    transforms=[transforms.ToTensor()],
)

In [None]:
# Same dataset and preprocessing, but with batches of 10,000 samples; the
# test loader is used below for bulk posterior plotting.
# NOTE(review): this reads the dataset a second time; only the batch size
# differs from the 1,024-sample loaders.
train_loader_10000, test_loader_10000 = load_simulated_data(
    batch_size=10000,
    path=ip.dataset_loc,
    split_ratio=ip.split_ratio,
    normalise_phot=normalise_phot_np,
    x_transforms=[maggies_to_colours_np],
    transforms=[transforms.ToTensor()],
)
logging.info('Data loading complete.')

## SAN with Mixture of Gaussians

In [None]:
# SAN hyperparameters for a Mixture-of-Gaussians likelihood, conditioned on
# colours rather than on the raw photometry.
class MoGSANParams(san.SANParams):
    epochs: int = 2
    # NOTE(review): the data loaders are built with an explicit batch_size of
    # 1024; keep the two in sync.
    batch_size: int = 1024
    dtype: t.dtype = t.float32
    
    # x; number of colours derived from the 7 photometry values.
    # (Set cond_dim = 7 to condition on the raw photometry instead.)
    cond_dim: int = get_colours_length(7)  # x; number of colours
    data_dim: int = 9  # dimensions of data of interest (e.g. physical params)
    # shape of the network 'modules' (a larger alternative is [512, 512])
    module_shape: list[int] = [64, 64]  # shape of the network 'modules'
    sequence_features: int = 8  # features passed between sequential blocks
    likelihood: Type[san.SAN_Likelihood] = san.MoG
    # presumably K is the number of mixture components; confirm in san.MoG
    likelihood_kwargs: dict[str, Any] = {'K': 10}
    batch_norm: bool = True  # use batch normalisation in network?

mgsp = MoGSANParams()

In [None]:
# Build the SAN from the MoG hyperparameters and train it on the
# 1,024-batch training loader using the inference settings in `ip`.
mogsan = san.SAN(mgsp)
# NOTE(review): ip.retrain_model is False; presumably trainmodel reuses a saved
# model when one exists rather than retraining. Confirm in san.SAN.trainmodel.
mogsan.trainmodel(train_loader_1024, ip)

In [None]:
# Draw posterior samples for an example taken from the test set, time the
# sampling, and show the samples on a corner plot against the true values.
xs, true_ys = nbu.new_sample(test_loader_1024)
xs = xs.to(device, dtype)
true_ys = true_ys.to(dtype=dtype)

n_samples = 10000
with t.inference_mode():
    # perf_counter is monotonic and high-resolution. Unlike the wall clock
    # returned by time.time(), it cannot jump, so it is the right clock for
    # measuring an interval. The trailing .cpu() copy needs the sampled tensor
    # to be ready, so the device work is included in the measured time.
    start = time.perf_counter()
    samples = mogsan.sample(xs, n_samples=n_samples).cpu()
    sampling_time = (time.perf_counter() - start) * 1e3
logging.info(f'Finished drawing {n_samples:,} samples in {sampling_time:.4f}ms.')
logging.info('Plotting results...')

# Every parameter axis is plotted on [0, 1]. Presumably the parameters are
# normalised to the unit interval; TODO confirm against the data loader.
lims = np.array([[0., 1.]]).repeat(len(column_order), 0)
nbu.plot_corner(samples=samples.numpy(), true_params=true_ys.cpu().numpy(),
                lims=lims, labels=column_order, title='SAN on DES Catalogue with Colours',
                description=str(mogsan))

In [None]:
# Draw N samples from the posterior of each example in a batch taken from the
# 10,000-batch test loader (presumably 10,000 examples), then plot the
# posteriors against the true values.
xs, true_ys = nbu.new_sample(test_loader_10000, 10000)
xs = xs.to(device, dtype)
true_ys = true_ys.to(device, dtype)

# samples per posterior
N = 100

with t.inference_mode():
    # preprocess takes (x, y), and only the processed x is kept, so y is a
    # placeholder. empty_like allocates it with the same shape, device and
    # dtype as xs, whereas t.empty(xs.shape) would allocate on the CPU with
    # the default dtype and risk a device mismatch when running on GPU.
    xs, _ = mogsan.preprocess(xs, t.empty_like(xs))
    # Each input row is repeated N times so that forward produces N
    # draws per example (assumes forward returns one sample per row; confirm).
    samples = mogsan.forward(xs.repeat_interleave(N, 0))
logging.info('Finished sampling. Plotting')

# Repeat the true values the same way so that they line up row-for-row with
# the samples.
true_ys = true_ys.repeat_interleave(N, 0).cpu().numpy()
nbu.plot_posteriors(samples.cpu().numpy(), true_ys, title="SAN on DES Catalogue with Colours",
                    description=f'{mogsan}')