# Criteria for LNPF

In this notebook we investigate the impact of training members of the LNPF with either the maximum-likelihood (ML) or the ELBO objective.
We also investigate the effect of, and the need for, a lower bound on the standard deviation of the latent variable and of the posterior predictive.
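
To make the two criteria concrete, here is a minimal sketch on a toy model with a scalar Gaussian latent. It is not the implementation used by the losses in this notebook: the function names, the `(mean, std)` tuples, and the Gaussian likelihood centred at `z` are all assumptions made for illustration. The ML objective averages the likelihood over latent samples drawn from the context-conditioned distribution, whereas the ELBO samples from a distribution that has also seen the targets and pays a KL penalty for doing so.

```python
import math
import random


def normal_logpdf(x, mean, std):
    """Log density of a univariate Gaussian."""
    return -0.5 * math.log(2 * math.pi) - math.log(std) - 0.5 * ((x - mean) / std) ** 2


def log_mean_exp(values):
    """Numerically stable log of the mean of exp(values)."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))


def kl_normal(mean_q, std_q, mean_p, std_p):
    """KL(N(mean_q, std_q) || N(mean_p, std_p)) in closed form."""
    return (
        math.log(std_p / std_q)
        + (std_q ** 2 + (mean_q - mean_p) ** 2) / (2 * std_p ** 2)
        - 0.5
    )


def npml_objective(y_targets, prior, sigma, n_z_samples, rng):
    """ML objective: log of the sample-averaged likelihood, z ~ p(z | context)."""
    log_liks = []
    for _ in range(n_z_samples):
        z = rng.gauss(*prior)
        log_liks.append(sum(normal_logpdf(y, z, sigma) for y in y_targets))
    return log_mean_exp(log_liks)


def elbo_objective(y_targets, prior, posterior, sigma, n_z_samples, rng):
    """ELBO: expected log likelihood under q(z | context, target) minus the KL to the prior."""
    total = 0.0
    for _ in range(n_z_samples):
        z = rng.gauss(*posterior)
        total += sum(normal_logpdf(y, z, sigma) for y in y_targets)
    return total / n_z_samples - kl_normal(*posterior, *prior)


y_targets = [0.1, -0.2, 0.3]
prior = (0.0, 1.0)       # hypothetical p(z | context)
posterior = (0.05, 0.4)  # hypothetical q(z | context, target)
print(npml_objective(y_targets, prior, 0.5, 32, random.Random(0)))
print(elbo_objective(y_targets, prior, posterior, 0.5, 1, random.Random(0)))
```

Both functions return quantities to be maximized; a training loss would be their negative. The sample counts in the two calls mirror the many-samples versus single-sample setup used for the two objectives further down.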


In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import logging
import os
import warnings

import matplotlib.pyplot as plt
import torch

os.chdir("../..")

warnings.filterwarnings("ignore")
warnings.simplefilter("ignore")
logging.disable(logging.ERROR)

N_THREADS = 8
IS_FORCE_CPU = False  # Nota Bene : notebooks don't deallocate GPU memory

if IS_FORCE_CPU:
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

torch.set_num_threads(N_THREADS)

## Initialization

Let's load the data. Here we will only work with Gaussian processes drawn from a single underlying kernel. For more details, see the {doc}`data <Datasets>` notebook.

In [2]:
from utils.ntbks_helpers import get_datasets_single_gp

# DATASET
gp_datasets, gp_test_datasets, gp_valid_datasets = get_datasets_single_gp()

In [3]:
from npf.utils.datasplit import CntxtTrgtGetter, GetRandomIndcs
from utils.data import cntxt_trgt_collate

# CONTEXT TARGET SPLIT
get_cntxt_trgt_1d = cntxt_trgt_collate(
    CntxtTrgtGetter(contexts_getter=GetRandomIndcs(a=0.0, b=50))
)

Let us now make the models. We will make one model for every member of the LNPF. Each member is trained with both losses, with or without a lower bound on the std of the latent distribution, and with or without a lower bound on the std of the predictive distribution.
This gives a total of 24 models (3 members × 2 losses × 2 × 2 lower-bound settings), so we build them in a loop. Note that, apart from the training setup, the models are the same as in the other notebooks.
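
As a quick sanity check of that count, the grid can be enumerated with the standard library alone. This sketch only builds the names: the naming pattern copies the `get_name` helper defined in the next cell, and the lowercase variable names are ours.

```python
from itertools import product

members = ["LNP", "AttnLNP", "ConvLNP"]
flags = [True, False]

# One name per (member, objective, predictive-std bound, latent-std bound) combination.
names = [
    f"{lnpf}_ELBO{is_elbo}_LatLB{is_lat_lb}_SigLB{is_sigma_lb}"
    for lnpf, is_elbo, is_sigma_lb, is_lat_lb in product(members, flags, flags, flags)
]

print(len(names))
print(names[0])
```

Every combination gets a distinct name, which is what later allows the models to be split by objective with a simple substring test.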

In [4]:
from functools import partial
from npf import LNP, AttnLNP, ConvLNP
import torch
import torch.nn as nn
import torch.nn.functional as F
from npf.architectures import (
    CNN,
    MLP,
    ResConvBlock,
    SetConv,
    discard_ith_arg,
    merge_flat_input,
)
from utils.helpers import count_parameters

R_DIM = 128
KWARGS = dict(
    XEncoder=partial(MLP, n_hidden_layers=1, hidden_size=R_DIM),
    Decoder=merge_flat_input(  # MLP takes single input but we give x and R so merge them
        partial(MLP, n_hidden_layers=4, hidden_size=R_DIM), is_sum_merge=True,
    ),
    r_dim=R_DIM,
)


def get_std_processing_kwargs(min_sigma_pred=0.01, min_lat=None):
    """Return the kwargs that map raw scale outputs to lower-bounded stds."""
    kwargs = dict(
        p_y_scale_transformer=lambda y_scale: min_sigma_pred
        + (1 - min_sigma_pred) * F.softplus(y_scale)
    )

    # if `min_lat` is None, the model's default latent scale transformer is kept
    if min_lat is not None:
        kwargs["q_z_scale_transformer"] = lambda z_scale: min_lat + (
            1 - min_lat
        ) * F.softplus(z_scale)

    return kwargs


def get_lnp(
    is_mle=True, min_sigma_pred=0.01, min_lat=None,
):

    KWARGS = dict(
        is_q_zCct=not is_mle,  # use MLE instead of ELBO
        n_z_samples_train=32 if is_mle else 1,  # going to be more expensive
        n_z_samples_test=32,
        XEncoder=partial(MLP, n_hidden_layers=1, hidden_size=R_DIM),
        Decoder=merge_flat_input(  # MLP takes single input but we give x and R so merge them
            partial(MLP, n_hidden_layers=4, hidden_size=R_DIM), is_sum_merge=True,
        ),
        r_dim=R_DIM,
        **get_std_processing_kwargs(min_sigma_pred=min_sigma_pred, min_lat=min_lat),
    )

    # 1D case
    model_1d = partial(
        LNP,
        x_dim=1,
        y_dim=1,
        XYEncoder=merge_flat_input(  # MLP takes single input but we give x and y so merge them
            partial(MLP, n_hidden_layers=2, hidden_size=R_DIM * 2), is_sum_merge=True,
        ),
        **KWARGS,
    )
    
    return model_1d


def get_attnlnp(
    is_mle=True, min_sigma_pred=0.01, min_lat=None,
):

    KWARGS = dict(
        is_q_zCct=not is_mle,  # use MLE instead of ELBO
        n_z_samples_train=8 if is_mle else 1,  # going to be more expensive
        n_z_samples_test=8,
        r_dim=R_DIM,
        attention="transformer",
        **get_std_processing_kwargs(min_sigma_pred=min_sigma_pred, min_lat=min_lat),
    )

    # 1D case
    model_1d = partial(
        AttnLNP,
        x_dim=1,
        y_dim=1,
        XYEncoder=merge_flat_input(  # MLP takes single input but we give x and y so merge them
            partial(MLP, n_hidden_layers=2, hidden_size=R_DIM), is_sum_merge=True,
        ),
        is_self_attn=False,
        **KWARGS,
    )

    return model_1d


def get_convlnp(
    is_mle=True, min_sigma_pred=0.01, min_lat=None, z_dim=16
):
    KWARGS = dict(
        is_q_zCct=not is_mle,  # use MLE instead of ELBO
        n_z_samples_train=16 if is_mle else 1,  # going to be more expensive
        n_z_samples_test=32,
        r_dim=R_DIM,
        Decoder=discard_ith_arg(
            torch.nn.Linear, i=0
        ),  # use small decoder because already went through CNN
        z_dim=z_dim,  #! NPVI requires smaller number of latent channels due to the KL
        **get_std_processing_kwargs(min_sigma_pred=min_sigma_pred, min_lat=min_lat),
    )

    CNN_KWARGS = dict(
        ConvBlock=ResConvBlock,
        is_chan_last=True,  # all computations are done with channel last in our code
        n_conv_layers=2,
        n_blocks=4,
    )

    # 1D case
    model_1d = partial(
        ConvLNP,
        x_dim=1,
        y_dim=1,
        CNN=partial(
            CNN,
            Conv=torch.nn.Conv1d,
            Normalization=torch.nn.BatchNorm1d,
            kernel_size=19,
            **CNN_KWARGS,
        ),
        density_induced=64,  # size of discretization
        is_global=False, #! Global representation does not work well with NPVI
        **KWARGS,
    )

    return model_1d


lnpf_getters = dict(LNP=get_lnp, AttnLNP=get_attnlnp, ConvLNP=get_convlnp)


def get_name(lnpf, is_elbo, is_lat_LB, is_sigma_LB):
    return f"{lnpf}_ELBO{str(is_elbo)}_LatLB{str(is_lat_LB)}_SigLB{str(is_sigma_LB)}"

models = {
    get_name(lnpf, is_elbo, is_lat_LB, is_sigma_LB): lnpf_getters[
        lnpf
    ](
        is_mle=not is_elbo,
        min_sigma_pred=0.01 if is_sigma_LB else 1e-4,  # 1e-4 is effectively no lower bound
        min_lat=None if is_lat_LB else 1e-4,  # None keeps the model's default latent std transform
    )
    for lnpf in ["LNP", "AttnLNP", "ConvLNP"]
    for is_elbo in [True, False]
    for is_sigma_LB in [True, False]
    for is_lat_LB in [True, False]
}

In [5]:
len(models)

24
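
The explicit lower bounds set in `get_std_processing_kwargs` all go through the same transform, `min + (1 - min) * softplus(raw)`. The plain-Python sketch below uses `math` rather than `torch`, and the helper names `softplus` and `bounded_scale` are hypothetical. It illustrates that the output never drops below the chosen minimum, and that a minimum of `1e-4` leaves the std almost unconstrained compared with `0.01`.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bounded_scale(raw, min_scale):
    """Map an unconstrained value to a std that is at least `min_scale`."""
    return min_scale + (1 - min_scale) * softplus(raw)


for raw in (-20.0, 0.0, 5.0):
    print(raw, bounded_scale(raw, 0.01), bounded_scale(raw, 1e-4))
```

For very negative raw outputs the std saturates at the minimum, which is the regime where the choice between the two settings matters.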

### Training

The main function for training is `train_models`, which trains a dictionary of models on a dictionary of datasets and returns all the trained models.
See its docstring for possible parameters.
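
The dictionary-in, dictionary-out pattern can be pictured with the toy sketch below. It is not `train_models` itself: `train_all`, `toy_fit`, and the stand-in datasets and models are hypothetical, and the `dataset/model/run_0` key format is only inferred from the run labels printed in the training log.

```python
def train_all(datasets, models, fit):
    """Fit every model factory on every dataset and collect the results by name."""
    trainers = {}
    for data_name, dataset in datasets.items():
        for model_name, make_model in models.items():
            trainers[f"{data_name}/{model_name}/run_0"] = fit(make_model(), dataset)
    return trainers


# Stand-ins: a "model" is a dict and "fitting" just records the dataset size.
toy_datasets = {"RBF_Kernel": [0.1, 0.4, 0.9]}
toy_models = {
    "LNP_ELBOTrue": lambda: {"loss": "elbo"},
    "LNP_ELBOFalse": lambda: {"loss": "nll"},
}


def toy_fit(model, dataset):
    return {**model, "n_seen": len(dataset)}


# Split the models by objective through their names, then merge the results.
npvi = train_all(
    toy_datasets, {k: v for k, v in toy_models.items() if "ELBOTrue" in k}, toy_fit
)
npml = train_all(
    toy_datasets, {k: v for k, v in toy_models.items() if "ELBOTrue" not in k}, toy_fit
)
trainers = {**npml, **npvi}
print(sorted(trainers))
```

Passing model factories rather than instantiated models keeps each run independent, since a fresh model is built for every dataset.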

In [None]:
import skorch
from npf import NLLLossLNPF, ELBOLossLNPF
from utils.ntbks_helpers import add_y_dim
from utils.train import train_models

KWARGS = dict(
    is_retrain=False,  # whether to load precomputed model or retrain
    chckpnt_dirname="results/pretrained/",
    device=None,  # use GPU if available
    batch_size=32,
    lr=1e-3,
    decay_lr=10,  # decrease learning rate by a factor of 10 during training
    seed=123,
    test_datasets=gp_test_datasets,
    train_split=None,  # No need for validation as the training data is generated on the fly
    iterator_train__collate_fn=get_cntxt_trgt_1d,
    iterator_valid__collate_fn=get_cntxt_trgt_1d,
    max_epochs=100,
)

# NPVI
trainers_1d_NPVI = train_models(
    gp_datasets,
    {k: v for k, v in models.items() if "ELBOTrue" in k},
    criterion=ELBOLossLNPF,  # NPVI
    **KWARGS,
)

# NPML
trainers_1d_NPML = train_models(
    gp_datasets,
    {k: v for k, v in models.items() if "ELBOTrue" not in k},
    criterion=NLLLossLNPF,  # NPML
    **KWARGS,
)

trainers_1d = {**trainers_1d_NPML, **trainers_1d_NPVI}


--- Training RBF_Kernel/LNP_ELBOTrue_LatLBTrue_SigLBTrue/run_0 ---



  epoch    train_loss    cp       dur
-------  ------------  ----  --------
      1      177.9034     +   99.9683
      2      176.0549     +  103.8027
      3      175.7554     +  103.4337
      4      173.0566     +  103.6515
      5      171.7896     +  104.9342
      7      167.6460     +  103.3266
      8      167.8569     +  101.0704
     11      158.6918     +   98.3054
     12      157.9744     +   97.5705
     13      157.0392     +   97.8102
     14      154.9269     +   98.0612
     15      153.2298     +   97.5336
     16      152.5868     +   97.4834
     17      153.0487     +   97.3626
     18      152.5024     +   97.7235
     19      151.1178     +   97.9192
     20      150.9641     +   98.3273
     21      151.3775     +   97.3894
     22      149.3304     +   98.1378
     23      148.4318     +   97.7226
     24      147.7837     +   98.0750
     25      148.1965     +   97.5027
     26      146.1609     +   97.9366
     27      146.9731     +   97.5539
     28      145.2895     +   97.9362
     29      145.0878     +   98.3813
     30      145.3347     +   97.8217
     31      144.1234     +   97.7468
     32      142.0720     +   97.6152
     33      142.2372     +   98.0121
     34      142.6986     +   97.4348
     35      139.1740     +   98.6267
     36      141.6484     +   97.2340
     37      140.7108     +   97.6357
     38      140.0815     +   98.0206
     39      139.4058     +   98.1340
     40      140.4177     +   97.6486
     41      136.3484     +   97.5496
     42      139.8401     +   97.5934
     43      137.8045     +   98.5111
     44      139.4644     +   98.0663
     45      139.0457     +   98.0865
     46      139.7116     +   97.2322
     47      139.5585     +   97.4667
     48      138.8185     +   97.6057
     49      138.7101     +   97.5344
     50      135.9002     +   98.2025
     51      137.6867     +   97.7650
     52      134.7001     +   98.3819
     53      135.8487     +   97.4162
     54      135.7094     +   97.4410
     55      134.5240     +   97.6732
     56      134.9099     +   98.2643
     57      135.1254     +   98.5580
     58      134.6373     +   98.4074
     59      133.8455     +   97.4992
     60      133.8302     +   97.6191
     61      134.2900     +   97.8006
     62      133.4884     +   97.6579
     63      134.8651     +   97.2369
     64      132.5709     +   98.0840
     65      132.6529     +   98.2662
     66      133.3572     +   97.5864
     67      133.7605     +   97.4355
     68      133.1031     +   98.2409
     69      131.0208     +   97.4424
     70      132.0883     +   97.8003
     71      132.9164     +   97.5532
     72      131.5082     +   97.7186
     73      131.0979     +   98.0441
     74      131.5093     +   97.3941
     75      132.6045     +   98.1931
     76      132.8279     +   97.6116
     77      131.5157     +   97.9979
     78      130.2629     +   98.5640
     79      132.3553     +   97.4284
     80      130.7610     +   98.5992
     81      131.5630     +   98.6281
     82      130.7517     +   98.8304
     83      130.1373     +   97.0826
     84      127.6356     +   98.5816
     85      129.1720     +   98.4943
     86      127.7872     +   98.1833
     87      129.8711     +   97.8563
     88      128.2503     +   98.0189
     89      126.4321     +   98.4900
     90      128.8310     +   98.3424
     91      129.0092     +   98.3268
     92      127.6994     +   98.8018
     93      127.4778     +   98.2814
     94      128.2189     +   98.3001
     95      127.2165     +   98.9056
     96      127.0985     +   97.8566
     97      126.7246     +   98.2600
     98      128.0278     +   98.3087
     99      127.4678     +   98.3036
    100      127.1846     +   60.9133
Re-initializing module.
Re-initializing optimizer.

RBF_Kernel/LNP_ELBOTrue_LatLBTrue_SigLBTrue/run_0 | best epoch: None | train loss: 126.4321 | valid loss: None | test log likelihood: -99.6728

--- Training RBF_Kernel/LNP_ELBOTrue_LatLBFalse_SigLBTrue/run_0 ---



  epoch    train_loss    cp       dur
-------  ------------  ----  --------
      1      177.5167     +   47.1347
      2      176.1950     +   47.0247
      3      176.2243     +  121.7584
      4      175.6326     +  121.2091
      5      173.0347     +  119.9323
      6      171.3825     +  129.5171
      7      169.2744     +  122.9550
      8      167.4270     +  120.7581
      9      166.5722     +  118.9969
     10      165.9699     +  121.4509
     11      163.1006     +  122.4282
     12      162.0550     +  122.5874
     13      160.8986     +  125.4537
     14      158.1635     +  121.6101
     15      158.0399     +  125.9440
     16      157.3807     +  126.9261
     17      157.2003     +  119.8089
     18      156.3180     +  124.8057
     19      153.5280     +  122.0002
     20      152.8159     +  120.7129
     21      153.2944     +  124.8109
     22      151.1410     +  124.5688
     23      151.9573     +  122.8012
     24      151.7100     +  120.8507
     25      152.5642     +  128.8707
     26      150.7401     +  122.1747
     27      151.4953     +  121.1391
     28      150.1921     +  124.2655
     29      149.2672     +  122.6661
     30      150.5048     +  118.7845
     31      149.7712     +  115.7306
     32      148.8626     +  132.9415
     33      148.3325     +  125.8767
     34      148.0805     +  121.6976
     35      144.9834     +  124.1851
     36      146.5789     +  124.8723
     37      145.6047     +  117.9410
     38      145.1558     +  129.2036
     39      144.8601     +  124.0504
     40      145.8569     +  120.0572
     41      142.1505     +  126.6578
     43      143.2517     +  119.8431
     44      144.8645     +  123.6119
     46      144.8875     +  127.7962
     47      144.0867     +  123.9012
     48      143.9844     +  132.1709
     49      143.6388     +  125.5236
     50      141.3484     +  124.7762
     51      141.5272     +  121.1404
     52      139.2386     +  123.1780
     53      140.4608     +  124.2100
     54      139.8948     +  123.0784
     55      139.4725     +  122.5045
     56      139.9162     +  126.6443
     57      139.4686     +  121.9294
     58      139.2390     +  124.3668
     64      136.2612     +  121.8139
     65      135.5048     +  123.8501
     71      134.4895     +  115.4791
     72      133.4933     +  125.0601
     75      134.1131     +  121.2255
     76      134.4727     +  123.8040
     77      133.5413     +  119.5893
     78      131.7764     +  131.9032
     79      133.9863     +  119.5108
     80      132.6986     +  119.3395
     81      133.3597     +  116.0368
     82      132.4482     +  129.0467
     83      132.3939     +  129.6232
     84      130.1123     +  126.2507
     85      131.7504     +  125.0865
     86      130.9726     +  126.6415
     87      133.1359     +  123.3736
     88      132.0433     +  122.8515
     89      130.1766     +  124.5150
     90      132.3507     +  122.4752
     91      132.4591     +  125.7490
     92      131.2000     +  124.3235
     93      131.1157     +  119.5062
     94      131.4404     +  124.6957
     95      131.3519     +  116.7177
     96      130.3248     +  120.4392
     97      130.7883     +  120.2955
     98      131.7250     +  125.7427
     99      131.4555     +  114.5677
    100      130.8418     +  123.9908
Re-initializing module.
Re-initializing optimizer.

RBF_Kernel/LNP_ELBOTrue_LatLBFalse_SigLBTrue/run_0 | best epoch: None | train loss: 130.1123 | valid loss: None | test log likelihood: -106.0453

--- Training RBF_Kernel/LNP_ELBOTrue_LatLBTrue_SigLBFalse/run_0 ---



  epoch    train_loss    cp       dur
-------  ------------  ----  --------
      1      177.2794     +  118.9804
      2      176.1856     +  123.0297
      3      176.1592     +  124.5061
      4      175.6906     +  116.8558
      5      173.0330     +  116.9929
      6      171.3944     +  126.6705
      7      170.5599     +  120.1905
      8      166.8541     +  121.4413
      9      163.1285     +  120.3632
     10      162.3892     +  126.5041
     11      160.3659     +  122.4341
     12      158.6204     +  124.1829
     13      158.0072     +  123.1067
     14      156.4918     +  127.1771
     15      156.1746     +  117.7971
     16      153.6619     +  124.4471
     17      153.0137     +  123.9071
     18      152.7521     +  121.5588
     19      151.5968     +  124.5699
     20      151.4086     +  114.3636
     21      152.1705     +  122.0490
     22      150.0752     +  124.8225
     23      150.9014     +  121.7228
     24      150.6777     +  120.0460
     25      151.3387     +  125.9888
     26      149.5504     +  125.4678
     27      150.5301     +  124.1062
     28      149.1941     +  127.8418
     29      148.5228     +  124.5722
     30      149.6296     +  121.8310
     31      148.6274     +  126.6713
     32      146.5755     +  128.4150
     33      145.9512     +  124.3861
     34      145.9749     +  120.2166
     35      143.1122     +  120.7755
     36      145.0419     +  121.2655
     37      143.7465     +  119.8937
     38      142.7146     +  123.1037
     39      141.5364     +  128.5324
     40      142.0796     +  126.2781
     41      137.7383     +  118.4869
     42      140.6690     +  119.9360
     43      138.6252     +  117.0899
     44      140.2514     +  122.6779

     47      139.5466     +  131.5968
     48      139.1531     +  121.5369
     49      138.7080     +  120.3405
     50      136.6153     +  120.9868
     51      138.2546     +  118.4245
     52      136.5814     +  127.1860
     53      137.9850     +  120.7419
     54      137.5678     +  138.7421
     55      137.2607     +  132.4058
     56      137.3235     +  128.5715
     57      136.4699     +  117.1797
     58      135.4029     +  123.3717
     59      134.1692     +  119.8564
     60      134.4049     +  120.2216
     61      134.9325     +  121.7918
     62      133.8455     +  125.2576
     63      134.9473     +  126.4617
     64      132.8049     +  121.4981
     65      132.7600     +  123.3299
     66      133.3386     +  121.1755
     67      133.7423     +  126.1595
     68      133.2933     +  120.0755
     69      131.0707     +  118.3386
     70      132.0340     +  120.3853
     71      132.7145     +  119.4634
     72      131.5445     +  122.9022
     73      131.1901     +  119.5370
     74      131.4614     +  113.1507
     75      132.4998     +  124.9866
     76      133.0701     +  122.0474
     77      131.8148     +  123.8133
     78      130.2228     +  125.0314
     79      132.3617     +  128.3347
     80      131.1387     +  124.9052
     81      131.8412     +  119.8158
     82      130.8733     +  124.9982
     83      130.9141     +  123.5878
     84      128.6382     +  127.2600
     85      130.3762     +  137.4278
     86      129.6734     +  123.8823
     87      131.7930     +  130.1953
     88      130.6604     +  123.6005
     89      128.7408     +  126.5042
     90      131.0632     +  118.9503
     91      131.0526     +  127.4573
     92      129.8431     +  119.0735
     93      129.6796     +  126.6794
     94      130.3082     +  118.4387
     95      129.8932     +  122.4460
     96      128.9455     +  125.6986
     97      129.3637     +  125.0290
     98      130.5691     +  126.9080
     99      130.1566     +  120.7102
    100      129.6066     +  130.3810
Re-initializing module.
Re-initializing optimizer.


RBF_Kernel/LNP_ELBOTrue_LatLBTrue_SigLBFalse/run_0 | best epoch: None | train loss: 128.6382 | valid loss: None | test log likelihood: -104.5312

--- Training RBF_Kernel/LNP_ELBOTrue_LatLBFalse_SigLBFalse/run_0 ---



  epoch    train_loss    cp       dur
-------  ------------  ----  --------
      1      177.3238     +  122.6260
      2      176.2274     +  117.2256

      8      169.6941     +  125.0885
      9      167.3747     +  119.0108

     15      158.3354     +  120.9263
     16      157.5610     +  118.3842

     22      151.8063     +  118.8181
     23      152.2116     +  125.5918

     29      149.2996     +  120.2549

     60      136.2734     +  118.6461
     61      136.3960     +  127.9309

     63      136.1954     +  127.5010
     64      134.0120     +  121.9638

     66      134.5598     +  119.1033

     69      132.1590     +  123.5208

     16       45.3207     +  141.5777
     17        8.4905     +  139.4558

     23      -21.8435     +  138.5721
     24      -39.7917     +  143.5775

     30      -67.9619     +  136.5317

     36      -82.4276     +  145.4545
     37      -81.9875     +  142.1072

     68     -122.7620     +  132.6887

     74     -132.8603     +  142.9228
     75     -118.5824     +  137.5915

     81     -130.6337     +  131.1396
     82     -130.5481     +  141.4557

     88     -143.0017     +  142.0758
     89     -148.6532     +  148.6397

     19      -41.9811     +  139.4701

     22      -67.8781     +  140.6351
     23      -56.6119     +  142.5643
     24      -67.3616     +  132.7415
     25      -45.2894     +  135.4896
     26      -68.1793     +  144.8964
     27      -68.5434     +  142.2838
     28      -74.4836     +  141.9397
     29      -81.5417     +  136.5605
     30      -81.1983     +  143.9262
     31      -81.1302     +  139.8843
     32      -85.5258     +  139.0885
     33      -84.7346     +  141.5345
     34      -86.6957     +  141.7420
     35      -55.5498     +  136.5411
     36      -47.6318     +  151.5500
     37      -75.2834     +  139.7237

### Plots

Let's visualize how well the model performs in different settings.

#### GPs Dataset

Let's define a plotting function that we will use in this section. We'll reuse the function defined in the {doc}`CNP notebook <CNP>`, but with `n_samples = 20`, so that each figure shows multiple posterior predictives, each conditioned on a different latent sample.
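
To make "multiple posterior predictives conditioned on different latent samples" concrete, here is a minimal, self-contained sketch of the idea. It does not use the `npf` API: `toy_decoder`, `sample_posterior_predictives`, and the Gaussian latent parameters are hypothetical stand-ins chosen only for illustration.

```python
import math
import random


def toy_decoder(z, x):
    """Hypothetical decoder: maps a latent sample z and an input x to (mean, std)."""
    mean = z * math.sin(x)
    std = 0.1 + 0.05 * abs(z)  # strictly positive by construction
    return mean, std


def sample_posterior_predictives(z_mean, z_std, xs, n_samples=20, seed=123):
    """Draw n_samples latents and decode each one into a predictive curve."""
    rng = random.Random(seed)
    curves = []
    for _ in range(n_samples):
        z = rng.gauss(z_mean, z_std)  # z ~ N(z_mean, z_std^2)
        curves.append([toy_decoder(z, x) for x in xs])
    return curves


xs = [i / 10 for i in range(-20, 21)]  # 41 target locations in [-2, 2]
curves = sample_posterior_predictives(z_mean=1.0, z_std=0.5, xs=xs)
```

Each element of `curves` corresponds to one latent sample and would be drawn as one mean curve (with its own standard deviation band) in the figure.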

In [None]:
from utils.ntbks_helpers import PRETTY_RENAMER, plot_multi_posterior_samples_1d
from utils.visualize import giffify


def multi_posterior_gp_gif(filename, trainers, datasets, seed=123, **kwargs):
    giffify(
        save_filename=f"jupyter/gifs/{filename}.gif",
        gen_single_fig=plot_multi_posterior_samples_1d,  # core plotting
        sweep_parameter="n_cntxt",  # param over which to sweep
        sweep_values=[0, 2, 5, 7, 10, 15, 20, 30, 50, 100],
        fps=0.5,  # gif speed
        # PLOTTING KWARGS
        trainers=trainers,
        datasets=datasets,
        is_plot_generator=True,  # plot underlying GP
        is_plot_real=False,  # don't plot sampled / underlying function
        is_plot_std=True,  # plot the predictive std
        is_fill_generator_std=False,  # do not fill predictive of GP
        pretty_renamer=PRETTY_RENAMER,  # prettify names of modules + data
        # Fix formatting for coherent GIF
        plot_config_kwargs=dict(
            set_kwargs=dict(ylim=[-3, 3]), rc={"legend.loc": "upper right"}
        ),
        seed=seed,
        **kwargs,
    )

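Let us visualize each member of the LNPF when it is trained on samples from a single GP. For every combination of lower bounds, we compare the model trained with ML against the one trained with the ELBO.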

In [None]:
def filter_npf(d, lnpf, is_elbo, is_lat_LB, is_sigma_LB):
    """Select only the entries of `d` trained with the given LNPF configuration."""
    name = "/" + get_name(lnpf, is_elbo, is_lat_LB, is_sigma_LB)
    return {k: v for k, v in d.items() if name in k}

for lnpf in ["LNP", "AttnLNP", "ConvLNP"]:
    for is_sigma_LB in [True, False]:
        for is_lat_LB in [True, False]:
            multi_posterior_gp_gif(
                f"singlegp_{lnpf}_LatLB{str(is_lat_LB)}_SigLB{str(is_sigma_LB)}",
                trainers=filter_npf(trainers_1d, lnpf, is_elbo=False, is_lat_LB=is_lat_LB, is_sigma_LB=is_sigma_LB),
                trainers_compare=filter_npf(trainers_1d, lnpf, is_elbo=True, is_lat_LB=is_lat_LB, is_sigma_LB=is_sigma_LB),
                datasets=gp_test_datasets,
                n_samples=20,  # 20 samples from the latent
                title="{model_name} | {data_name} | C={n_cntxt}",
                imgsize=(6, 3),
            )
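
The filtering above relies on substring matching, and the leading `"/"` is what keeps `"LNP"` from also matching `"AttnLNP"` and `"ConvLNP"`. The sketch below illustrates this on placeholder data. The `get_name` shown here is a hypothetical stand-in that only mirrors the run names printed in the training logs, `filter_npf` is restated so the block runs on its own, and the AttnLNP and ConvLNP keys are assumed by analogy with the logged LNP key.

```python
def get_name(lnpf, is_elbo, is_lat_LB, is_sigma_LB):
    """Hypothetical stand-in for the notebook's `get_name` (mirrors the logged run names)."""
    return f"{lnpf}_ELBO{is_elbo}_LatLB{is_lat_LB}_SigLB{is_sigma_LB}"


def filter_npf(d, lnpf, is_elbo, is_lat_LB, is_sigma_LB):
    """Keep the entries whose key contains "/<configuration name>"."""
    name = "/" + get_name(lnpf, is_elbo, is_lat_LB, is_sigma_LB)
    return {k: v for k, v in d.items() if name in k}


# Placeholder strings instead of trained models.
trainers_demo = {
    "RBF_Kernel/LNP_ELBOTrue_LatLBFalse_SigLBTrue/run_0": "lnp",
    "RBF_Kernel/AttnLNP_ELBOTrue_LatLBFalse_SigLBTrue/run_0": "attnlnp",
    "RBF_Kernel/ConvLNP_ELBOTrue_LatLBFalse_SigLBTrue/run_0": "convlnp",
}

selected = filter_npf(
    trainers_demo, "LNP", is_elbo=True, is_lat_LB=False, is_sigma_LB=True
)
print(sorted(selected))
# ['RBF_Kernel/LNP_ELBOTrue_LatLBFalse_SigLBTrue/run_0']
```

Without the `"/"` prefix, the pattern `LNP_ELBOTrue_LatLBFalse_SigLBTrue` would be a substring of all three keys, so the LNP figure would mix in the other two models.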

Let's now visualize all of these plots.

### LNP

#### No Lower bounds

```{figure} ../gifs/singlegp_LNP_LatLBFalse_SigLBFalse.gif
---
width: 60em
name: singlegp_LNP_LatLBFalse_SigLBFalse
---
```

#### Lower bounded std of latent

```{figure} ../gifs/singlegp_LNP_LatLBTrue_SigLBFalse.gif
---
width: 60em
name: singlegp_LNP_LatLBTrue_SigLBFalse
---
```

#### Lower bounded std of predictive

```{figure} ../gifs/singlegp_LNP_LatLBFalse_SigLBTrue.gif
---
width: 60em
name: singlegp_LNP_LatLBFalse_SigLBTrue
---
```

#### Both Lower Bounds


```{figure} ../gifs/singlegp_LNP_LatLBTrue_SigLBTrue.gif
---
width: 60em
name: singlegp_LNP_LatLBTrue_SigLBTrue
---
```

### AttnLNP

#### No Lower bounds

```{figure} ../gifs/singlegp_AttnLNP_LatLBFalse_SigLBFalse.gif
---
width: 60em
name: singlegp_AttnLNP_LatLBFalse_SigLBFalse
---
```

#### Lower bounded std of latent

```{figure} ../gifs/singlegp_AttnLNP_LatLBTrue_SigLBFalse.gif
---
width: 60em
name: singlegp_AttnLNP_LatLBTrue_SigLBFalse
---
```

#### Lower bounded std of predictive

```{figure} ../gifs/singlegp_AttnLNP_LatLBFalse_SigLBTrue.gif
---
width: 60em
name: singlegp_AttnLNP_LatLBFalse_SigLBTrue
---
```

#### Both Lower Bounds


```{figure} ../gifs/singlegp_AttnLNP_LatLBTrue_SigLBTrue.gif
---
width: 60em
name: singlegp_AttnLNP_LatLBTrue_SigLBTrue
---
```

### ConvLNP


```{warning} 

For NPVI to train with ConvLNP we had to remove the global representation and decrease the number of channels to `z_dim=16`.
The NPVI and NPML models are thus slightly different.
```


#### No Lower bounds

```{figure} ../gifs/singlegp_ConvLNP_LatLBFalse_SigLBFalse.gif
---
width: 60em
name: singlegp_ConvLNP_LatLBFalse_SigLBFalse
---
```

#### Lower bounded std of latent

```{figure} ../gifs/singlegp_ConvLNP_LatLBTrue_SigLBFalse.gif
---
width: 60em
name: singlegp_ConvLNP_LatLBTrue_SigLBFalse
---
```

#### Lower bounded std of predictive

```{figure} ../gifs/singlegp_ConvLNP_LatLBFalse_SigLBTrue.gif
---
width: 60em
name: singlegp_ConvLNP_LatLBFalse_SigLBTrue
---
```

#### Both Lower Bounds


```{figure} ../gifs/singlegp_ConvLNP_LatLBTrue_SigLBTrue.gif
---
width: 60em
name: singlegp_ConvLNP_LatLBTrue_SigLBTrue
---
```


In [11]:
###### ADDITIONAL 1D PLOTS ######

#TO Chose