# VAE Inference

In [None]:
import os
import time
import random
import corner
import logging
import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

import agnfinder
import agnfinder.inference as inference
import agnfinder.inference.san as san

from typing import Type, Any
from IPython.display import SVG, display
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.distributions import Normal, MultivariateNormal, StudentT, Laplace

from agnfinder import config as cfg
from agnfinder import nbutils as nbu
from agnfinder.types import ConfigClass, column_order
from agnfinder.inference import CMADE, SAN
from agnfinder.inference.base import CVAE, CVAEParams, cvae_t, arch_t
from agnfinder.inference.utils import Squareplus, squareplus_f, load_simulated_data, GalaxyDataset, normalise_phot_np
from agnfinder.simulation.utils import denormalise_theta, normalise_theta

# One-time setup: the body runs only the first time this cell is executed in a
# session. An explicit membership test is used instead of probing the name
# with `assert`, because assert statements are stripped under `python -O`,
# which would silently skip the setup altogether.
if "_SETUP" not in globals():
    cfg.configure_logging()
    # Change into the directory that contains the agnfinder package
    # (presumably so that relative paths such as './data/...' resolve).
    os.chdir(os.path.split(agnfinder.__path__[0])[0])
    dtype = t.float32
    device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")
    if device == t.device("cuda"):
        print('Using GPU for training')
        # !nvidia-smi
    else:
        print("CUDA is unavailable; training on CPU.")
    _SETUP = True

In [None]:
class InferenceParams(ConfigClass):
    """Settings for loading the simulated dataset and for training / loading the CVAE."""
    epochs: int = 5  # forwarded to the model's `trainmodel` call in train_cvae
    # Quoted in the plot description. NOTE: the data loaders in this notebook
    # are built with hard-coded batch sizes rather than this value.
    batch_size: int = 1024
    # Passed to load_simulated_data as `split_ratio`
    # (presumably the train fraction of the train/test split -- confirm).
    split_ratio: float = 0.9
    dtype: t.dtype = dtype  # taken from the one-time setup cell
    device: t.device = device  # taken from the one-time setup cell
    # Forwarded to `trainmodel`; units (batches vs. samples) are defined there -- confirm.
    logging_frequency: int = 10000
    dataset_loc: str = './data/cubes/40M_shuffled.hdf5'  # `path` given to load_simulated_data
    # dataset_loc: str = './data/cubes/latest_sample/'
    retrain_model: bool = False  # prefer an existing model over re-training
    overwrite_results: bool = True  # if we do re-train, save it

# Shared settings instance used by the cells below.
ip = InferenceParams()

In [None]:
def _load_split(batch_size: int):
    """Return the (train, test) loader pair for the simulated dataset at `batch_size`."""
    return load_simulated_data(
        path=ip.dataset_loc,
        split_ratio=ip.split_ratio,
        batch_size=batch_size,
        normalise_phot=normalise_phot_np,
        transforms=[transforms.ToTensor()],
    )


# Two loader pairs over the same dataset, differing only in batch size.
# The batch sizes are fixed here and do not follow `ip.batch_size`.
train_loader_1024, test_loader_1024 = _load_split(1024)
train_loader_10000, test_loader_10000 = _load_split(10000)
logging.info('Data loading complete.')

In [None]:
class GCVAEParams(ConfigClass, CVAEParams):
    """CVAE configuration using the FactorisedGaussian prior, encoder and decoder.

    Each `*_arch` attribute is an `arch_t` network description; every network
    here has two output heads, the second of which is passed through
    `Squareplus()`. How `arch_t` turns `layer_sizes` / `head_sizes` into layers
    is defined in agnfinder.inference.base, not in this notebook.
    """
    cond_dim = 8  # x; dimension of photometry
    data_dim = 9  # y; len(FreeParameters()); dimensions of physical params
    latent_dim = 2  # z
    adam_lr = 1e-3  # learning rate (for Adam, going by the name; consumed by CVAE)

    # (conditional) Gaussian prior network p_{theta}(z | x)
    # prior = inference.StandardGaussianPrior
    # prior_arch = None
    prior = inference.FactorisedGaussianPrior
    # Input width is cond_dim (the network is conditioned on x only).
    prior_arch = arch_t(
        layer_sizes=[cond_dim, 16],
        activations=nn.SiLU(),
        head_sizes=[latent_dim, latent_dim],
        head_activations=[None, Squareplus()],
        batch_norm=False)

    # Gaussian recognition model q_{phi}(z | y, x)
    encoder = inference.FactorisedGaussianEncoder
    # Input width is data_dim + cond_dim (presumably y and x concatenated -- confirm).
    # Heads: described as "mean and log_std" in the original note.
    # NOTE(review): the second head goes through Squareplus, which suggests a
    # positive scale rather than a log -- confirm against the encoder.
    enc_arch = arch_t(
        layer_sizes=[data_dim + cond_dim, 32, 16],
        activations=nn.ReLU(),
        head_sizes=[latent_dim, latent_dim],
        head_activations=[None, Squareplus()],
        batch_norm=False)


    # Gaussian generator network arch: p_{theta}(y | z, x)
    decoder = inference.FactorisedGaussianDecoder
    # Input width is latent_dim + cond_dim; both heads have width data_dim.
    dec_arch = arch_t(
        layer_sizes=[latent_dim + cond_dim, 32, 16],
        head_sizes=[data_dim, data_dim],
        activations=nn.ReLU(),
        head_activations=[None, Squareplus()],
        batch_norm=False)
    
# Parameter instance handed to CVAE / train_cvae in the cells below.
gcp = GCVAEParams()

In [None]:
def train_cvae(cp: CVAEParams, ip: InferenceParams) -> CVAE:
    """Return a CVAE for `cp`, loaded from disk when possible and trained otherwise.

    Unless `ip.retrain_model` is set, a previously saved model at the CVAE's
    `fpath()` is loaded and returned. If loading fails, or re-training was
    requested, a fresh model is trained on the notebook-level
    `train_loader_1024` and, when `ip.overwrite_results` is set, saved to
    that same path.

    Args:
        cp: CVAE configuration used to construct the model.
        ip: inference settings (device, dtype, epochs, logging frequency and
            the retrain / overwrite flags).

    Returns:
        The loaded or newly trained CVAE.
    """
    # Build from the arguments, not from the notebook globals
    # (`gcp`, `device`, `dtype`), so the function honours what it is given.
    cvae = CVAE(cp, ip.device, ip.dtype)
    savepath: str = cvae.fpath()

    if not ip.retrain_model:
        logging.info(f'Attempting to load CVAE from {savepath}')
        try:
            # map_location lets a checkpoint saved on a GPU be restored on a
            # CPU-only machine instead of failing and forcing a retrain.
            # NOTE: t.load unpickles the file; only load trusted checkpoints.
            loaded = t.load(savepath, map_location=ip.device)
            loaded = loaded.to(ip.device, ip.dtype)
        except Exception as e:
            # Best effort: any load failure falls back to training, but the
            # reason is reported and KeyboardInterrupt is no longer swallowed.
            logging.info(
                f'Could not load a model from {savepath} ({e!r}); training...')
        else:
            logging.info('Successfully loaded')
            return loaded.cuda() if ip.device == t.device('cuda') else loaded

    cvae.trainmodel(train_loader_1024, ip.epochs, ip.logging_frequency)
    logging.info('Trained CVAE model')

    if ip.overwrite_results:
        t.save(cvae, savepath)
        logging.info(f'Saved CVAE model as {savepath}')
    else:
        logging.info(
            f'overwrite_results is False; not saving CVAE model to {savepath}')
    return cvae

In [None]:
# NOTE(review): sampling further down draws from `gcvae`, not from this
# separately constructed instance. The meaning of the trailing `True` is
# defined by CVAE's constructor, which is not visible here -- confirm this
# object is still needed.
cvae = CVAE(gcp, device, dtype, True)

# Obtain the model used for inference via train_cvae (loads a saved model
# unless ip.retrain_model is set, otherwise trains one).
gcvae = train_cvae(gcp, ip)

In [None]:
# Draw samples of y for one input x from the trained CVAE and show them as a
# corner plot against the true parameters.
xs, true_ys = nbu.new_sample(test_loader_1024)
xs = xs.to(device, dtype)
if xs.dim() == 1:
    xs = xs.unsqueeze(0)  # add a leading batch dimension
true_ys = true_ys.to(device, dtype)

n_samples = 1000  # number of latent draws z from the prior p(z | x)
n_post = 100  # number of y draws per latent z from the decoder p(y | z, x)

# Switch the model that is actually sampled from into eval mode; calling
# .eval() on the separately constructed `cvae` has no effect on `gcvae`.
gcvae.eval()
with t.inference_mode():
    # perf_counter is monotonic and intended for measuring intervals.
    start = time.perf_counter()
    zs = gcvae.prior(xs).sample((n_samples,)).squeeze(1)
    xs_e = xs.expand(n_samples, xs.size(-1))
    # Flatten the (n_post, n_samples, data_dim) draws into rows of parameters;
    # the trailing .cpu() also ensures the work has finished before timing stops.
    samples = (
        gcvae.decoder(zs, xs_e)
        .sample((n_post,))
        .reshape(-1, gcp.data_dim)
        .cpu()
    )
    sampling_time = (time.perf_counter() - start) * 1e3
logging.info(f'Finished drawing {n_post * n_samples:,} samples in {sampling_time:.4f}ms.')
logging.info('Plotting results...')

description = f'CVAE trained for {ip.epochs} epochs (batch size {ip.batch_size})'

# Every parameter is plotted on the unit interval.
lims = np.array([[0., 1.]]).repeat(len(column_order), 0)
# The title names the model actually used here (a Gaussian CVAE), not the
# "Sequential Autoregressive Network" of a different notebook.
nbu.plot_corner(samples=samples.numpy(), true_params=true_ys.cpu().numpy(),
                lims=lims, labels=column_order, title='Gaussian CVAE',
                description=description)