# Criteria for LNPF

In this notebook we will investigate the impact of training members of LNPF with the maximum-likelihood (ML) objective or with the ELBO objective.
We will also investigate the effect of, and the need for, a lower bound on the standard deviation of the latent variable and of the posterior predictive.
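As a minimal sketch of what the two objectives compute, consider a toy model assumed purely for illustration. It is not the npf implementation behind `NLLLossLNPF` and `ELBOLossLNPF`. The model has a scalar latent with conditional prior p(z|C) = N(prior_mean, prior_std²) and targets y_t | z ~ N(z, sigma²). The ML objective (NPML) estimates log p(y_T|C) as a log-mean-exp of the target log-likelihood over latent samples drawn from p(z|C). The ELBO (NPVI) instead averages the log-likelihood under an approximate posterior q(z|C,T) and subtracts KL(q(z|C,T) || p(z|C)). Because the toy model is conjugate, its exact log p(y_T|C) is also available as a reference. All function names and numbers below are hypothetical.

```python
import math
import random


def log_normal(x, mean, std):
    """Log-density of N(mean, std**2) at x."""
    return -0.5 * math.log(2 * math.pi * std**2) - (x - mean) ** 2 / (2 * std**2)


def npml_objective(y_trgt, prior_mean, prior_std, sigma, n_samples, rng):
    """Monte Carlo estimate of log p(y_T | C) = log E_{p(z|C)}[p(y_T | z)]."""
    log_liks = []
    for _ in range(n_samples):
        z = rng.gauss(prior_mean, prior_std)  # z_l ~ p(z | C)
        log_liks.append(sum(log_normal(y, z, sigma) for y in y_trgt))
    # numerically stable log-mean-exp over the latent samples
    m = max(log_liks)
    return m + math.log(sum(math.exp(ll - m) for ll in log_liks) / n_samples)


def elbo_objective(y_trgt, prior_mean, prior_std, q_mean, q_std, sigma):
    """E_q[log p(y_T | z)] - KL(q(z | C, T) || p(z | C)), both terms in closed form."""
    exp_log_lik = sum(
        -0.5 * math.log(2 * math.pi * sigma**2)
        - ((y - q_mean) ** 2 + q_std**2) / (2 * sigma**2)
        for y in y_trgt
    )
    kl = (
        math.log(prior_std / q_std)
        + (q_std**2 + (q_mean - prior_mean) ** 2) / (2 * prior_std**2)
        - 0.5
    )
    return exp_log_lik - kl


def exact_log_marginal(y_trgt, prior_mean, prior_std, sigma):
    """Exact log p(y_T | C) for this conjugate toy model (used as a reference)."""
    prec = 1 / prior_std**2 + len(y_trgt) / sigma**2
    post_mean = (prior_mean / prior_std**2 + sum(y_trgt) / sigma**2) / prec
    post_std = prec ** -0.5
    # Bayes' rule at z = post_mean: log p(y) = log p(y|z) + log p(z) - log p(z|y)
    return (
        sum(log_normal(y, post_mean, sigma) for y in y_trgt)
        + log_normal(post_mean, prior_mean, prior_std)
        - log_normal(post_mean, post_mean, post_std)
    )


y_trgt = [0.3, 0.5, 0.1]
print(exact_log_marginal(y_trgt, 0.0, 1.0, sigma=0.5))
print(npml_objective(y_trgt, 0.0, 1.0, sigma=0.5, n_samples=20000, rng=random.Random(0)))
# with q equal to the prior, the ELBO is a valid but loose lower bound
print(elbo_objective(y_trgt, 0.0, 1.0, q_mean=0.0, q_std=1.0, sigma=0.5))
```

With `q` equal to the exact posterior, the ELBO matches the exact value, while a poor `q` only gives a loose bound. The NPML estimate is the log of a Monte Carlo average, so it approaches the exact value only with enough latent samples. This is in line with the notebook training NPML with several latent samples (`n_z_samples_train`) and NPVI with a single one.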


In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import logging
import os
import warnings

import matplotlib.pyplot as plt
import torch

os.chdir("../..")

warnings.filterwarnings("ignore")
warnings.simplefilter("ignore")
logging.disable(logging.ERROR)

N_THREADS = 8
IS_FORCE_CPU = False  # Nota Bene : notebooks don't deallocate GPU memory

if IS_FORCE_CPU:
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

torch.set_num_threads(N_THREADS)

## Initialization

Let's load the data. Here we will only work with samples from Gaussian processes with a single underlying kernel. For more details, see the {doc}`data <Datasets>` notebook.

In [2]:
from utils.ntbks_helpers import get_datasets_single_gp

# DATASET
gp_datasets, gp_test_datasets, gp_valid_datasets = get_datasets_single_gp()

In [3]:
from npf.utils.datasplit import CntxtTrgtGetter, GetRandomIndcs
from utils.data import cntxt_trgt_collate

# CONTEXT TARGET SPLIT
get_cntxt_trgt_1d = cntxt_trgt_collate(
    CntxtTrgtGetter(contexts_getter=GetRandomIndcs(a=0.0, b=50))
)

Let us now make the models. We will make one model for every member of LNPF and train each of them with both losses, with or without a lower bound on the std of the latent distribution, and with or without a lower bound on the std of the predictive distribution.
This is a total of 24 models, so we create them in a loop. Apart from training, the models are the same as in the other notebooks (see the warning in the ConvLNP section for an exception).
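The count of 24 follows from the grid of configurations: 3 LNPF members × 2 objectives × 2 latent-std settings × 2 predictive-std settings. A minimal sketch that enumerates this grid with `itertools.product`, reusing the naming scheme of `get_name` from the next cell (re-declared here so the block runs on its own):

```python
from itertools import product


def get_name(lnpf, is_elbo, is_lat_LB, is_sigma_LB):
    # same naming scheme as in the notebook's model-definition cell
    return f"{lnpf}_ELBO{is_elbo}_LatLB{is_lat_LB}_SigLB{is_sigma_LB}"


names = [
    get_name(lnpf, is_elbo, is_lat_LB, is_sigma_LB)
    for lnpf, is_elbo, is_sigma_LB, is_lat_LB in product(
        ["LNP", "AttnLNP", "ConvLNP"], [True, False], [True, False], [True, False]
    )
]
print(len(names))
print(names[0])
```

Each name encodes the full configuration, which is what later lets the notebook select models and trainers by substring (for instance `"ELBOTrue"`).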

In [4]:
from functools import partial
from npf import LNP,ConvLNP, AttnLNP
import torch
import torch.nn as nn
import torch.nn.functional as F
from npf.architectures import (
    CNN,
    MLP,
    ResConvBlock,
    SetConv,
    discard_ith_arg,
    merge_flat_input,
)
from utils.helpers import count_parameters

R_DIM = 128
KWARGS = dict(
    XEncoder=partial(MLP, n_hidden_layers=1, hidden_size=R_DIM),
    Decoder=merge_flat_input(  # MLP takes single input but we give x and R so merge them
        partial(MLP, n_hidden_layers=4, hidden_size=R_DIM), is_sum_merge=True,
    ),
    r_dim=R_DIM,
)


def get_std_processing_kwargs(min_sigma_pred=0.01, min_lat=None):
    """Return kwargs that lower bound the std of the predictive and, optionally, of the latent."""
    kwargs = dict(
        p_y_scale_transformer=lambda y_scale: min_sigma_pred
        + (1 - min_sigma_pred) * F.softplus(y_scale)
    )

    # min_lat=None keeps the model's default transformation of the latent std
    if min_lat is not None:
        kwargs["q_z_scale_transformer"] = lambda z_scale: min_lat + (
            1 - min_lat
        ) * F.softplus(z_scale)

    return kwargs


def get_lnp(
    is_mle=True, min_sigma_pred=0.01, min_lat=None,
):

    KWARGS = dict(
        is_q_zCct=not is_mle,  # ELBO: infer z with q(z|C,T) during training; MLE: with p(z|C)
        n_z_samples_train=32 if is_mle else 1,  # MLE uses several latent samples (more expensive)
        n_z_samples_test=32,
        XEncoder=partial(MLP, n_hidden_layers=1, hidden_size=R_DIM),
        Decoder=merge_flat_input(  # MLP takes single input but we give x and R so merge them
            partial(MLP, n_hidden_layers=4, hidden_size=R_DIM), is_sum_merge=True,
        ),
        r_dim=R_DIM,
        **get_std_processing_kwargs(min_sigma_pred=min_sigma_pred, min_lat=min_lat),
    )

    # 1D case
    model_1d = partial(
        LNP,
        x_dim=1,
        y_dim=1,
        XYEncoder=merge_flat_input(  # MLP takes single input but we give x and y so merge them
            partial(MLP, n_hidden_layers=2, hidden_size=R_DIM * 2), is_sum_merge=True,
        ),
        **KWARGS,
    )
    
    return model_1d


def get_attnlnp(
    is_mle=True, min_sigma_pred=0.01, min_lat=None,
):

    KWARGS = dict(
        is_q_zCct=not is_mle,  # ELBO: infer z with q(z|C,T) during training; MLE: with p(z|C)
        n_z_samples_train=8 if is_mle else 1,  # MLE uses several latent samples (more expensive)
        n_z_samples_test=8,
        r_dim=R_DIM,
        attention="transformer",
        **get_std_processing_kwargs(min_sigma_pred=min_sigma_pred, min_lat=min_lat),
    )

    # 1D case
    model_1d = partial(
        AttnLNP,
        x_dim=1,
        y_dim=1,
        XYEncoder=merge_flat_input(  # MLP takes single input but we give x and y so merge them
            partial(MLP, n_hidden_layers=2, hidden_size=R_DIM), is_sum_merge=True,
        ),
        is_self_attn=False,
        **KWARGS,
    )

    return model_1d


def get_convlnp(
    is_mle=True, min_sigma_pred=0.01, min_lat=None, z_dim=16
):
    KWARGS = dict(
        is_q_zCct=not is_mle,  # ELBO: infer z with q(z|C,T) during training; MLE: with p(z|C)
        n_z_samples_train=16 if is_mle else 1,  # MLE uses several latent samples (more expensive)
        n_z_samples_test=32,
        r_dim=R_DIM,
        Decoder=discard_ith_arg(
            torch.nn.Linear, i=0
        ),  # use small decoder because already went through CNN
        z_dim=z_dim,  #! NPVI requires a smaller number of latent channels due to the KL
        **get_std_processing_kwargs(min_sigma_pred=min_sigma_pred, min_lat=min_lat),
    )

    CNN_KWARGS = dict(
        ConvBlock=ResConvBlock,
        is_chan_last=True,  # all computations are done with channel last in our code
        n_conv_layers=2,
        n_blocks=4,
    )

    # 1D case
    model_1d = partial(
        ConvLNP,
        x_dim=1,
        y_dim=1,
        CNN=partial(
            CNN,
            Conv=torch.nn.Conv1d,
            Normalization=torch.nn.BatchNorm1d,
            kernel_size=19,
            **CNN_KWARGS,
        ),
        density_induced=64,  # size of discretization
        is_global=False, #! Global representation does not work well with NPVI
        **KWARGS,
    )

    return model_1d


lnpf_getters = dict(LNP=get_lnp, AttnLNP=get_attnlnp, ConvLNP=get_convlnp)


def get_name(lnpf, is_elbo, is_lat_LB, is_sigma_LB):
    return f"{lnpf}_ELBO{str(is_elbo)}_LatLB{str(is_lat_LB)}_SigLB{str(is_sigma_LB)}"

models = {
    get_name(lnpf, is_elbo, is_lat_LB, is_sigma_LB): lnpf_getters[
        lnpf
    ](
        is_mle=not is_elbo,
        min_sigma_pred=0.01 if is_sigma_LB else 1e-4,
        min_lat=None if is_lat_LB else 1e-4,  # None keeps the model's default latent-std transformation
    )
    for lnpf in ["LNP", "AttnLNP", "ConvLNP"]
    for is_elbo in [True, False]
    for is_sigma_LB in [True, False]
    for is_lat_LB in [True, False]
}
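`get_std_processing_kwargs` maps a raw network output to a standard deviation via `min_std + (1 - min_std) * softplus(raw)`. Since softplus is strictly positive, the result always exceeds `min_std`, which is what "lower bound" means here. A minimal pure-Python sketch of the same transformation, evaluated with the two values this notebook uses for the predictive std (the `softplus` and `bounded_std` helpers are illustrative re-implementations, not npf functions):

```python
import math


def softplus(x):
    # numerically stable log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bounded_std(raw, min_std):
    # same form as the lambdas in get_std_processing_kwargs
    return min_std + (1 - min_std) * softplus(raw)


for min_std in (0.01, 1e-4):  # SigLB=True vs SigLB=False in this notebook
    print(min_std, [round(bounded_std(r, min_std), 6) for r in (-20.0, 0.0, 3.0)])
```

With `min_std=1e-4`, a very negative raw output can push the std almost to zero, so the predicted density can become extremely peaked; with `0.01` it cannot shrink below that floor.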

In [5]:
len(models)

8

### Training

The main function for training is `train_models`, which trains a dictionary of models on a dictionary of datasets and returns all the trained models.
See its docstring for the possible parameters.
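In the training cell, `decay_lr=10` is described as decreasing the learning rate by a factor of 10 during training. The actual schedule lives in `utils.train` and is not shown here. One schedule with that property, assumed purely for illustration, is an exponential decay whose per-epoch factor is `gamma = decay_lr ** (-1 / max_epochs)` (the `lr_at_epoch` helper is hypothetical):

```python
def lr_at_epoch(epoch, lr=1e-3, decay_lr=10, max_epochs=100):
    """Hypothetical exponential schedule reaching lr / decay_lr after max_epochs."""
    gamma = decay_lr ** (-1 / max_epochs)  # per-epoch multiplicative factor
    return lr * gamma**epoch


for epoch in (0, 50, 100):
    print(epoch, lr_at_epoch(epoch))
```

In plain PyTorch, the same per-epoch factor could be passed as `gamma` to `torch.optim.lr_scheduler.ExponentialLR`.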

In [6]:
import skorch
from npf import NLLLossLNPF, ELBOLossLNPF
from utils.ntbks_helpers import add_y_dim
from utils.train import train_models

KWARGS = dict(
    is_retrain=True, # whether to load precomputed model or retrain
    chckpnt_dirname="results/pretrained/",
    device=None,  # use GPU if available
    batch_size=32,
    lr=1e-3,
    decay_lr=10,  # decrease learning rate by 10 during training
    seed=123,
    test_datasets=gp_test_datasets,
    train_split=None,  # No need for validation as the training data is generated on the fly
    iterator_train__collate_fn=get_cntxt_trgt_1d,
    iterator_valid__collate_fn=get_cntxt_trgt_1d,
    max_epochs=100,
)

# NPVI
trainers_1d_NPVI = train_models(
    gp_datasets,
    {k:v for k,v in models.items() if "ELBOTrue" in k},
    criterion=ELBOLossLNPF,  # NPVI
    **KWARGS
)

# NPML
trainers_1d_NPML = train_models(
    gp_datasets,
    {k:v for k,v in models.items() if "ELBOTrue" not in k},
    criterion=NLLLossLNPF,  # NPML
    **KWARGS
)

trainers_1d = {**trainers_1d_NPML, **trainers_1d_NPVI}


--- Training RBF_Kernel/ConvLNP_ELBOTrue_LatLBTrue_SigLBTrue/run_0 ---

  epoch    train_loss    cp       dur
-------  ------------  ----  --------
      1      129.7038     +  239.3549
      2      101.4122     +  238.1240
      3       97.4390     +  238.3368
      4       92.0397     +  238.1451
      5       87.2336     +  238.0833
      6       84.0536     +  238.4104
      7       79.8557     +  238.3659
      8       83.7127     +  238.3537
      9       81.0838     +  238.0101
     10       80.6303     +  238.3044
     11       76.1241     +  238.4895
     12       75.7715     +  238.7090
     13       72.1939     +  238.3308
     14       57.8714     +  238.4733
     15       54.2789     +  238.5238
     16       44.9809     +  238.2171
     17       40.3735     +  238.2151
     18       37.7643     +  238.8333
     19       32.3731     +  238.4594
     20       26.4459     +  238.3471
     21       30.2063     +  238.1728
     22       21.5765     +  238.5040
     23       29.4292     +  238.0256
     24       26.3641     +  238.4433
     25       30.0697     +  238.1362
     26       18.7951     +  238.4706
     27       27.3368     +  238.1425
     28       19.0851     +  238.8775
     29       16.5148     +  238.7939
     30       20.4330     +  244.5869
     31       18.0716     +  242.0035
     32       14.7289     +  238.3282
     33       18.4413     +  238.1761
     34       23.1330     +  238.7662
     35        7.4369     +  238.5179
     36       22.0328     +  238.4506
     37       18.8443     +  238.4743
     38       12.9400     +  238.5806
     39       12.1659     +  238.3955
     40       16.9035     +  238.0642
     41       -1.1496     +  238.0272
     42       15.4954     +  238.4320
     43        5.0167     +  238.4243
     44       15.3910     +  238.3112
     45       11.9266     +  237.9933
     46       17.7535     +  238.7897
     47       15.8736     +  238.8321
     48       13.3900     +  238.4870
     49       12.9759     +  238.4963
     50        0.9926     +  238.4289
     51       10.8183     +  238.2757
     52        0.6480     +  238.7333
     53       10.0863     +  238.6328
     54        8.7167     +  238.8326
     55        6.4785     +  238.6568
     56        6.8276     +  238.2139
     57        8.8167     +  238.5518
     58        8.1628     +  238.4619
     59        4.0275     +  238.5637
     60        6.7623     +  238.2706
     61        5.9025     +  238.2616
     62        5.2453     +  238.3772
     63        9.3041     +  238.3914
     64        0.3898     +  238.4141
     65        0.4236     +  238.8593
     66        3.1711     +  238.4839
     67        7.8289     +  238.6293
     68        4.1811     +  238.3139
     69       -5.6372     +  238.2651
     70       -1.3111     +  238.5241
     71        4.3642     +  238.3230
     72       -2.0597     +  238.5121
     73       -4.1778     +  238.6974
     74       -3.6988     +  238.4022
     75        2.4047     +  238.4325
     76        5.9431     +  238.6175
     77       -2.3961     +  238.3699
     78       -7.9538     +  238.2562
     79        2.3153     +  238.4240
     80       -4.7170     +  238.5504
     81       -0.3072     +  238.5983
     82       -4.7691     +  238.5618
     83       -3.7401     +  238.2080
     84      -15.8007     +  238.4962
     85       -5.2551     +  238.7403
     86      -10.2924     +  238.5412
     87       -0.8481     +  243.8790
     88       -6.0228     +  245.8084
     89      -13.6405     +  237.6706
     90       -2.3882     +  238.7418
     91       -2.2633     +  238.7118
     92       -7.2702     +  238.5521
     93       -8.4949     +  238.5709
     94       -4.8103     +  238.1701
     95       -8.1157     +  238.0129
     96      -10.1068     +  237.8133
     97       -9.7320     +  238.2391
     98       -3.6467     +  238.4105
     99       -6.8139     +  237.9917
    100       -8.0791     +  238.7962
Re-initializing module.
Re-initializing optimizer.

RuntimeError: CUDA out of memory. Tried to allocate 24.00 GiB (GPU 0; 10.92 GiB total capacity; 420.09 MiB already allocated; 1.43 GiB free; 1.17 GiB reserved in total by PyTorch)

### Plots

Let's visualize how well the models perform in different settings.

#### GPs Dataset

Let's define a plotting function that we will use in this section. We'll reuse the function defined in the {doc}`CNP notebook <CNP>`, but with `n_samples = 20` to plot multiple posterior predictives, each conditioned on a different latent sample.
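Each of the `n_samples = 20` latent samples yields its own Gaussian predictive N(mu_l, sigma_l²) at a target input, so the overall LNPF predictive is an equally weighted mixture of them. A minimal sketch of the mixture's mean and standard deviation via the law of total variance (this is not part of `plot_multi_posterior_samples_1d`, whose internals are not shown; the helper name and example values are illustrative):

```python
import math


def mixture_mean_std(means, stds):
    """Mean and std of an equally weighted mixture of Gaussians N(means[l], stds[l]**2)."""
    n = len(means)
    mean = sum(means) / n
    # law of total variance: average within-sample variance + variance of the sample means
    var = sum(s**2 for s in stds) / n + sum((m - mean) ** 2 for m in means) / n
    return mean, math.sqrt(var)


print(mixture_mean_std([0.0, 1.0], [0.5, 0.5]))
```

The second variance term measures how much the sampled predictive means disagree, which is the variation visible across the plotted posterior predictives.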

In [10]:
from utils.ntbks_helpers import PRETTY_RENAMER, plot_multi_posterior_samples_1d
from utils.visualize import giffify


def multi_posterior_gp_gif(filename, trainers, datasets, seed=123, **kwargs):
    giffify(
        save_filename=f"jupyter/gifs/{filename}.gif",
        gen_single_fig=plot_multi_posterior_samples_1d,  # core plotting
        sweep_parameter="n_cntxt",  # param over which to sweep
        sweep_values=[0, 2, 5, 7, 10, 15, 20, 30, 50, 100],
        fps=0.5,  # gif speed
        # PLOTTING KWARGS
        trainers=trainers,
        datasets=datasets,
        is_plot_generator=True,  # plot underlying GP
        is_plot_real=False,  # don't plot sampled / underlying function
        is_plot_std=True,  # plot the predictive std
        is_fill_generator_std=False,  # do not fill predictive of GP
        pretty_renamer=PRETTY_RENAMER,  # prettify names of modules + data
        # Fix formatting for coherent GIF
        plot_config_kwargs=dict(
            set_kwargs=dict(ylim=[-3, 3]), rc={"legend.loc": "upper right"}
        ),
        seed=seed,
        **kwargs,
    )

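Let us visualize each member of LNPF trained on samples from a single GP, comparing the model trained with NPML (`is_elbo=False`) to the one trained with NPVI (`is_elbo=True`) for every combination of lower bounds.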

In [11]:
def filter_npf(d, lnpf, is_elbo, is_lat_LB, is_sigma_LB):
    """Select the trainers of a given LNPF member and training configuration."""
    return {k: v for k, v in d.items() if "/"+get_name(lnpf, is_elbo, is_lat_LB, is_sigma_LB) in k}

for lnpf in ["LNP", "AttnLNP", "ConvLNP"]:
    for is_sigma_LB in [True, False]:
        for is_lat_LB in [True, False]:
            multi_posterior_gp_gif(
                f"singlegp_{lnpf}_LatLB{str(is_lat_LB)}_SigLB{str(is_sigma_LB)}",
                trainers=filter_npf(trainers_1d, lnpf, is_elbo=False, is_lat_LB=is_lat_LB, is_sigma_LB=is_sigma_LB),
                trainers_compare=filter_npf(trainers_1d, lnpf, is_elbo=True, is_lat_LB=is_lat_LB, is_sigma_LB=is_sigma_LB),
                datasets=gp_test_datasets,
                n_samples=20,  # 20 samples from the latent
                title="{model_name} | {data_name} | C={n_cntxt}",
                imgsize=(6, 3),
            )

Let's now visualize all of these plots.
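The plotting loop writes one GIF per LNPF member and lower-bound combination, each comparing NPML with NPVI, so 3 × 2 × 2 = 12 files under `jupyter/gifs/`. A small sketch reproducing the file names that the figures below point to (it only mirrors the f-string used in the loop):

```python
from itertools import product

gif_names = [
    f"singlegp_{lnpf}_LatLB{is_lat_LB}_SigLB{is_sigma_LB}.gif"
    for lnpf, is_sigma_LB, is_lat_LB in product(
        ["LNP", "AttnLNP", "ConvLNP"], [True, False], [True, False]
    )
]
print(len(gif_names))
print(gif_names[:4])
```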

### LNP

#### No Lower bounds

```{figure} ../gifs/singlegp_LNP_LatLBFalse_SigLBFalse.gif
---
width: 60em
name: singlegp_LNP_LatLBFalse_SigLBFalse
---
```

#### Lower bounded std of latent

```{figure} ../gifs/singlegp_LNP_LatLBTrue_SigLBFalse.gif
---
width: 60em
name: singlegp_LNP_LatLBTrue_SigLBFalse
---
```

#### Lower bounded std of predictive

```{figure} ../gifs/singlegp_LNP_LatLBFalse_SigLBTrue.gif
---
width: 60em
name: singlegp_LNP_LatLBFalse_SigLBTrue
---
```

#### Both Lower Bounds


```{figure} ../gifs/singlegp_LNP_LatLBTrue_SigLBTrue.gif
---
width: 60em
name: singlegp_LNP_LatLBTrue_SigLBTrue
---
```

### AttnLNP

#### No Lower bounds

```{figure} ../gifs/singlegp_AttnLNP_LatLBFalse_SigLBFalse.gif
---
width: 60em
name: singlegp_AttnLNP_LatLBFalse_SigLBFalse
---
```

#### Lower bounded std of latent

```{figure} ../gifs/singlegp_AttnLNP_LatLBTrue_SigLBFalse.gif
---
width: 60em
name: singlegp_AttnLNP_LatLBTrue_SigLBFalse
---
```

#### Lower bounded std of predictive

```{figure} ../gifs/singlegp_AttnLNP_LatLBFalse_SigLBTrue.gif
---
width: 60em
name: singlegp_AttnLNP_LatLBFalse_SigLBTrue
---
```

#### Both Lower Bounds


```{figure} ../gifs/singlegp_AttnLNP_LatLBTrue_SigLBTrue.gif
---
width: 60em
name: singlegp_AttnLNP_LatLBTrue_SigLBTrue
---
```

### ConvLNP


```{warning} 

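For NPVI to train with ConvLNP, we had to remove the global representation (`is_global=False`) and decrease the number of latent channels to `z_dim=16`.
In this notebook these settings are used for both the NPVI and the NPML ConvLNP, so both differ slightly from the ConvLNP of the other notebooks.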
```
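The comment in `get_convlnp` attributes the smaller `z_dim` to the KL term of the ELBO. A minimal sketch of one plausible reason, assuming a factorized Gaussian latent with `z_dim` channels at each point of the discretization grid: the KL between factorized Gaussians is a sum of one-dimensional terms, so for the same per-dimension mismatch it grows linearly with the number of channels. The helper, the grid size `n_grid`, and the comparison value `z_dim=128` are hypothetical.

```python
import math


def kl_diag_gaussians(q_means, q_stds, p_means, p_stds):
    """KL(q || p) between factorized Gaussians: a sum of 1-D KL terms."""
    return sum(
        math.log(ps / qs) + (qs**2 + (qm - pm) ** 2) / (2 * ps**2) - 0.5
        for qm, qs, pm, ps in zip(q_means, q_stds, p_means, p_stds)
    )


n_grid = 64  # hypothetical number of discretization points
for z_dim in (16, 128):  # hypothetical comparison of channel counts
    n = z_dim * n_grid
    # same small mismatch between q(z|C,T) and p(z|C) in every latent dimension
    kl = kl_diag_gaussians([0.1] * n, [0.9] * n, [0.0] * n, [1.0] * n)
    print(z_dim, round(kl, 2))
```

A larger KL weighs more heavily against the expected log-likelihood in the ELBO, which is consistent with, though not a proof of, NPVI training more easily with fewer latent channels.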


#### No Lower bounds

```{figure} ../gifs/singlegp_ConvLNP_LatLBFalse_SigLBFalse.gif
---
width: 60em
name: singlegp_ConvLNP_LatLBFalse_SigLBFalse
---
```

#### Lower bounded std of latent

```{figure} ../gifs/singlegp_ConvLNP_LatLBTrue_SigLBFalse.gif
---
width: 60em
name: singlegp_ConvLNP_LatLBTrue_SigLBFalse
---
```

#### Lower bounded std of predictive

```{figure} ../gifs/singlegp_ConvLNP_LatLBFalse_SigLBTrue.gif
---
width: 60em
name: singlegp_ConvLNP_LatLBFalse_SigLBTrue
---
```

#### Both Lower Bounds


```{figure} ../gifs/singlegp_ConvLNP_LatLBTrue_SigLBTrue.gif
---
width: 60em
name: singlegp_ConvLNP_LatLBTrue_SigLBTrue
---
```


In [11]:
###### ADDITIONAL 1D PLOTS ######

# TO CHOOSE