# Additional Results

This notebook generates additional plots that require several different models. The code is not explained in detail.

In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import logging
import os
import warnings

import matplotlib.pyplot as plt
import torch

# Move the working directory two levels up (presumably to the repository root) so that
# the project-local imports and the relative paths used in later cells resolve.
# NOTE(review): this relative chdir is not idempotent — re-running the cell moves up
# two more levels each time.
os.chdir("../..")

# Keep the notebook output clean: ignore all warnings and suppress every logging call
# of severity ERROR and below.
warnings.filterwarnings("ignore")
warnings.simplefilter("ignore")
logging.disable(logging.ERROR)

N_THREADS = 8  # number of CPU threads handed to torch below
IS_FORCE_CPU = False  # Nota Bene : notebooks don't deallocate GPU memory

if IS_FORCE_CPU:
    # An empty device list hides every GPU from CUDA for this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

torch.set_num_threads(N_THREADS)

In [2]:
from npf.utils.datasplit import (
    CntxtTrgtGetter,
    GetRandomIndcs,
    GridCntxtTrgtGetter,
    RandomMasker,
    get_all_indcs,
    no_masker,
)
from utils.data import cntxt_trgt_collate, get_test_upscale_factor
from utils.ntbks_helpers import get_all_gp_datasets, get_img_datasets

# DATASETS
# A single call stands in for the three separate loaders:
# get_datasets_single_gp, get_datasets_varying_hyp_gp, get_datasets_varying_kernel_gp
gp_datasets, gp_test_datasets, gp_valid_datasets = get_all_gp_datasets()


# CONTEXT TARGET SPLIT
# Context indices come from GetRandomIndcs(a=0.0, b=50) and target indices from
# get_all_indcs (presumably: a random number of context points up to 50, and every
# point as a target — confirm against npf.utils.datasplit).
_splitter_1d = CntxtTrgtGetter(
    contexts_getter=GetRandomIndcs(a=0.0, b=50),
    targets_getter=get_all_indcs,
)

# Wrapped as a collate function so it can be handed to the data iterators.
get_cntxt_trgt_1d = cntxt_trgt_collate(_splitter_1d)

In [3]:
from functools import partial

from npf import CNP, AttnCNP
from npf.architectures import MLP, merge_flat_input
from utils.helpers import count_parameters




# Size used for r_dim and as the base hidden size of every MLP below.
R_DIM = 128

# Keyword arguments shared by all 1D models.
KWARGS = {
    "XEncoder": partial(MLP, n_hidden_layers=1, hidden_size=R_DIM),
    "r_dim": R_DIM,
    "x_dim": 1,
    "y_dim": 1,
}

# Model factories for the 1D case, keyed by model name. Each value is a `partial`,
# so nothing is instantiated here.
model_1d = {
    # CNP: its XY encoder uses hidden layers twice as wide as R_DIM.
    "CNP": partial(
        CNP,
        # MLP takes a single input but we give x and y, so merge them
        XYEncoder=merge_flat_input(
            partial(MLP, n_hidden_layers=2, hidden_size=R_DIM * 2),
            is_sum_merge=True,
        ),
        **KWARGS,
    ),
    # AttnCNP: same encoder layout, with hidden layers of width R_DIM.
    "AttnCNP": partial(
        AttnCNP,
        # multi headed attention with normalization and skip connections
        attention="transformer",
        is_self_attn=False,
        # MLP takes a single input but we give x and y, so merge them
        XYEncoder=merge_flat_input(
            partial(MLP, n_hidden_layers=2, hidden_size=R_DIM),
            is_sum_merge=True,
        ),
        **KWARGS,
    ),
}

In [4]:
import skorch
from npf import CNPFLoss
from utils.ntbks_helpers import add_y_dim
from utils.train import train_models

# Training options shared by every run. Note that this rebinds the name KWARGS.
KWARGS = {
    "is_retrain": False,  # whether to load precomputed model or retrain
    "criterion": CNPFLoss,  # (approx) conditional ELBO Loss
    "chckpnt_dirname": "results/pretrained/",
    "device": None,  # use GPU if available
    "batch_size": 32,
    "lr": 1e-3,
    "decay_lr": 10,  # decrease learning rate by 10 during training
    "seed": 123,
}


# 1D
# The same context/target collate function is used for the training and the
# validation iterators.
trainers_1d = train_models(
    gp_datasets,
    model_1d,
    test_datasets=gp_test_datasets,
    iterator_train__collate_fn=get_cntxt_trgt_1d,
    iterator_valid__collate_fn=get_cntxt_trgt_1d,
    max_epochs=100,
    **KWARGS,
)


--- Loading RBF_Kernel/CNP/run_0 ---

RBF_Kernel/CNP/run_0 | best epoch: None | train loss: 8.4633 | valid loss: None | test log likelihood: -16.1129

--- Loading RBF_Kernel/AttnCNP/run_0 ---

RBF_Kernel/AttnCNP/run_0 | best epoch: None | train loss: -157.9417 | valid loss: None | test log likelihood: 149.158

--- Loading Periodic_Kernel/CNP/run_0 ---

Periodic_Kernel/CNP/run_0 | best epoch: None | train loss: 129.0426 | valid loss: None | test log likelihood: -126.4177

--- Loading Periodic_Kernel/AttnCNP/run_0 ---

Periodic_Kernel/AttnCNP/run_0 | best epoch: None | train loss: 21.2395 | valid loss: None | test log likelihood: -25.4617

--- Loading Noisy_Matern_Kernel/CNP/run_0 ---

Noisy_Matern_Kernel/CNP/run_0 | best epoch: None | train loss: 111.3382 | valid loss: None | test log likelihood: -115.7692

--- Loading Noisy_Matern_Kernel/AttnCNP/run_0 ---

Noisy_Matern_Kernel/AttnCNP/run_0 | best epoch: None | train loss: 87.5571 | valid loss: None | test log likelihood: -91.5147

---

In [5]:
from utils.ntbks_helpers import PRETTY_RENAMER, plot_multi_posterior_samples_1d
from utils.visualize import giffify


def multi_posterior_gp_gif(filename, trainers, datasets, seed=123, **kwargs):
    """Save a GIF of 1D posterior predictions while sweeping the number of context points.

    Every frame is produced by `plot_multi_posterior_samples_1d`; `giffify` assembles
    the frames and writes them to ``jupyter/gifs/{filename}.gif``.

    Parameters
    ----------
    filename : str
        Base name of the GIF (no directory, no extension).

    trainers : dict
        Trained models, forwarded unchanged as the `trainers` plotting argument.

    datasets : dict
        Datasets, forwarded unchanged as the `datasets` plotting argument.

    seed : int, optional
        Seed forwarded to `giffify`.

    **kwargs
        Additional keyword arguments forwarded to `giffify`. They take precedence
        over the defaults set below, so e.g. `fps`, `sweep_values` or
        `plot_config_kwargs` can be overridden per call instead of raising a
        duplicate-keyword `TypeError`.
    """
    giffify_kwargs = dict(
        save_filename=f"jupyter/gifs/{filename}.gif",
        gen_single_fig=plot_multi_posterior_samples_1d,  # core plotting
        sweep_parameter="n_cntxt",  # param over which to sweep
        sweep_values=[0, 2, 5, 7, 10, 15, 20, 30, 50, 100],
        fps=1.2,  # gif speed
        # PLOTTING KWARGS
        trainers=trainers,
        datasets=datasets,
        is_plot_generator=True,  # plot underlying GP
        is_plot_real=False,  # don't plot sampled / underlying function
        is_plot_std=True,  # plot the predictive std
        is_fill_generator_std=False,  # do not fill predictive of GP
        pretty_renamer=PRETTY_RENAMER,  # prettify names of modules + data
        # Fix formatting for coherent GIF
        plot_config_kwargs=dict(
            set_kwargs=dict(ylim=[-3, 3]), rc={"legend.loc": "upper right"}
        ),
        seed=seed,
    )
    # Caller-supplied options win over the defaults above.
    giffify_kwargs.update(kwargs)
    giffify(**giffify_kwargs)

In [6]:
def filter_rbf(d):
    """Return a new dict with only the entries of `d` whose key contains "RBF"."""
    selected = {}
    for name, value in d.items():
        if "RBF" in name:
            selected[name] = value
    return selected

# Keep only the RBF entries of the trained models and of the test datasets.
rbf_trainers = filter_rbf(trainers_1d)
rbf_test_datasets = filter_rbf(gp_test_datasets)

multi_posterior_gp_gif(
    "CNP_AttnCNP_rbf_extrap",
    trainers=rbf_trainers,
    datasets=rbf_test_datasets,
    # NOTE(review): presumably extends/shifts the signal by 2 on the left-hand side;
    # the original comment said "right" for both arguments — confirm.
    left_extrap=-2,
    right_extrap=2,  # shift signal 2 to the right for extrapolation
)