# Experiment 1
### In this experiment we compare the predictive capabilities of three meta models across four types of dataset. 
We compare the new amortised global inducing point BNN with a vanilla neural process and an amortised MFVI BNN as two baseline models. The test datasets are:
- Noisy sample from squared exponential covariance GP prior
- Noisy sample from Laplacian covariance GP prior
- Noisy sample from periodic covariance GP prior
- Noisy cubic dataset with central gap

For each dataset, a non-meta model representing the 'ground truth' is used as a reference against which predictions are compared. For the GP-generated datasets, the posterior from a hyperparameter-optimised GP with the corresponding covariance function is used, and for the cubic dataset, the vanilla global inducing point BNN is used.


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import matplotlib.pyplot as plt
import gpytorch
import tqdm.auto as tqdm
import sys

torch.set_default_dtype(torch.float64)

sys.path.append("../")
from bnn_amort_inf.models.bnn import gibnn, mfvi_bnn
from bnn_amort_inf.models import gp, np
from bnn_amort_inf import utils

### Generate metadatasets

In [None]:
# Build the meta-datasets: `num_datasets` regression tasks per task family,
# every family being generated with the same [x_min, x_max] arguments.
num_datasets = 1000
kernels = ["se", "per", "lap"]

x_min = -5.0
x_max = 5.0

meta_datasets = {}

# GP-generated families: one MetaDataset per covariance function.
for kernel in kernels:
    gp_tasks = [
        utils.gp_datasets.gp_dataset_generator(kernel=kernel, x_min=x_min, x_max=x_max)
        for _ in range(num_datasets)
    ]
    meta_datasets[kernel] = utils.dataset_utils.MetaDataset(gp_tasks)

# Sawtooth family.
saw_tasks = [
    utils.dataset_utils.sawtooth_dataset(lower=x_min, upper=x_max)
    for _ in range(num_datasets)
]
meta_datasets["saw"] = utils.dataset_utils.MetaDataset(saw_tasks)

# Mixed family: tasks cycle through se, per, lap and sawtooth in turn.
# NOTE(review): unlike the entries above, this one is stored as a plain list
# rather than a MetaDataset -- confirm that is what train_metamodel expects.
mixed_tasks = []
for i in range(num_datasets):
    family = i % 4
    if family == 3:
        task = utils.dataset_utils.sawtooth_dataset(lower=x_min, upper=x_max)
    else:
        task = utils.gp_datasets.gp_dataset_generator(
            kernel=kernels[family], x_min=x_min, x_max=x_max
        )
    mixed_tasks.append(task)
meta_datasets["mix"] = mixed_tasks

### Generate test datasets

In [None]:
# One held-out dataset per task family, keyed by family name. Only the
# size bounds are passed for the GP and sawtooth sets; no input range is given.
test_datasets = {
    kernel: utils.gp_datasets.gp_dataset_generator(kernel=kernel, min_n=25, max_n=35)
    for kernel in kernels
}
test_datasets["saw"] = utils.dataset_utils.sawtooth_dataset(min_n=50, max_n=60)
test_datasets["cub"] = utils.dataset_utils.cubic_dataset()

### Define and train models

In [None]:
# Switch for the diagnostic training-metric plots: cells that test this flag
# only draw their plots when it is True.
plot_training_metrics = False

##### Amortised GIBNN

In [None]:
# Amortised GIBNN, trained on the mixed meta-dataset.
amortised_gibnn = gibnn.AmortisedGIBNN(
    x_dim=1,
    y_dim=1,
    hidden_dims=[50, 50],
    in_hidden_dims=[50, 50],
    noise=1e-1,
    train_noise=False,
)

agibnn_tracker = utils.training_utils.train_metamodel(
    amortised_gibnn,
    meta_datasets["mix"],
    min_es_iters=2_000,
    smooth_es_iters=500,
    batch_size=1,
)

if plot_training_metrics:
    # One stacked panel per tracked metric. squeeze=False keeps `axes` a 2-D
    # array even when only a single metric is tracked, so the loop below
    # works for any number of metrics.
    num_metrics = len(agibnn_tracker.keys())
    fig, axes = plt.subplots(
        num_metrics,
        1,
        figsize=(8, num_metrics * 4),
        dpi=100,
        sharex=True,
        squeeze=False,
    )

    for ax, (key, vals) in zip(axes[:, 0], agibnn_tracker.items()):
        ax.plot(vals)
        ax.set_ylabel(key)
        ax.grid()

    plt.show()

##### Amortised GIBNN with NP loss

In [None]:
# Amortised GIBNN trained with np_loss=True on the mixed meta-dataset.
np_amortised_gibnn = gibnn.AmortisedGIBNN(
    x_dim=1,
    y_dim=1,
    hidden_dims=[50, 50],
    in_hidden_dims=[50, 50],
    noise=5e-2,
    train_noise=False,
)

npagibnn_tracker = utils.training_utils.train_metamodel(
    np_amortised_gibnn,
    meta_datasets["mix"],
    np_loss=True,
    min_es_iters=2_000,
    smooth_es_iters=500,
    batch_size=1,
)

# Honour the shared plotting flag like the other training cells (this was a
# hard-coded `if True:`).
if plot_training_metrics:
    # One stacked panel per tracked metric. squeeze=False keeps `axes` a 2-D
    # array even when only a single metric is tracked, so the loop below
    # works for any number of metrics.
    num_metrics = len(npagibnn_tracker.keys())
    fig, axes = plt.subplots(
        num_metrics,
        1,
        figsize=(8, num_metrics * 4),
        dpi=100,
        sharex=True,
        squeeze=False,
    )

    for ax, (key, vals) in zip(axes[:, 0], npagibnn_tracker.items()):
        ax.plot(vals)
        ax.set_ylabel(key)
        ax.grid()

    plt.show()

##### Amortised MFVI BNN

In [None]:
# Amortised MFVI BNN baseline, trained on the mixed meta-dataset with the
# trainer's default settings.
amortised_mfvibnn = mfvi_bnn.AmortisedMFVIBNN(
    x_dim=1,
    y_dim=1,
    hidden_dims=[20, 20],
    in_hidden_dims=[20, 20],
    noise=1e-2,
    train_noise=True,
)

# Stored under its own name so the amortised GIBNN's `agibnn_tracker` is not
# overwritten by this model's training history.
amfvibnn_tracker = utils.training_utils.train_metamodel(
    amortised_mfvibnn,
    meta_datasets["mix"],
)

if plot_training_metrics:
    # One stacked panel per tracked metric. squeeze=False keeps `axes` a 2-D
    # array even when only a single metric is tracked, so the loop below
    # works for any number of metrics.
    num_metrics = len(amfvibnn_tracker.keys())
    fig, axes = plt.subplots(
        num_metrics,
        1,
        figsize=(8, num_metrics * 4),
        dpi=100,
        sharex=True,
        squeeze=False,
    )

    for ax, (key, vals) in zip(axes[:, 0], amfvibnn_tracker.items()):
        ax.plot(vals)
        ax.set_ylabel(key)
        ax.grid()

    plt.show()

##### Neural Process

In [None]:
# CNP baseline. Note that `np` here is the project's neural-process module
# (imported from bnn_amort_inf.models), not NumPy.
cnp = np.CNP(
    x_dim=1,
    y_dim=1,
    embedded_dim=64,
    encoder_hidden_dims=[128, 128],
    decoder_hidden_dims=[128, 128],
    train_noise=True,
)

cnp_tracker = utils.training_utils.train_metamodel(
    cnp,
    meta_datasets["mix"],
    neural_process=True,
    lr=1e-3,
    max_iters=50_000,
    batch_size=10,
    min_es_iters=10_000,
    ref_es_iters=1_000,
    smooth_es_iters=500,
)

if plot_training_metrics:
    # One stacked panel per tracked metric. squeeze=False keeps `axes` a 2-D
    # array even when only a single metric is tracked, so the loop below
    # works for any number of metrics.
    num_metrics = len(cnp_tracker.keys())
    fig, axes = plt.subplots(
        num_metrics,
        1,
        figsize=(8, num_metrics * 4),
        dpi=100,
        sharex=True,
        squeeze=False,
    )

    for ax, (key, vals) in zip(axes[:, 0], cnp_tracker.items()):
        ax.plot(vals)
        ax.set_ylabel(key)
        ax.grid()

    plt.show()

##### SE GP

In [None]:
# Reference model for the SE test set: a GP with the "se" kernel, fitted to
# that single dataset.
x, y = test_datasets["se"]
likelihood = gpytorch.likelihoods.GaussianLikelihood()
se_gp_model = gp.GPModel(x.squeeze(), y.squeeze(), likelihood, kernel="se")
dataset = torch.utils.data.TensorDataset(x, y)

se_gp_tracker = utils.training_utils.train_gp(
    se_gp_model, likelihood, dataset, dataset_size=x.size(0)
)

# Optional diagnostic: the "loss" series recorded during training.
if plot_training_metrics:
    plt.plot(se_gp_tracker["loss"])
    plt.ylabel("marginal likelihood")
    plt.grid()
    plt.show()

##### Periodic GP

In [None]:
# Reference model for the periodic test set: a GP with the "per" kernel,
# fitted to that single dataset.
x, y = test_datasets["per"]
likelihood = gpytorch.likelihoods.GaussianLikelihood()
per_gp_model = gp.GPModel(x.squeeze(), y.squeeze(), likelihood, kernel="per")
dataset = torch.utils.data.TensorDataset(x, y)

per_gp_tracker = utils.training_utils.train_gp(
    per_gp_model, likelihood, dataset, dataset_size=x.size(0)
)

# Optional diagnostic: the "loss" series recorded during training.
if plot_training_metrics:
    plt.plot(per_gp_tracker["loss"])
    plt.ylabel("marginal likelihood")
    plt.grid()
    plt.show()

##### Laplacian GP

In [None]:
# Reference model for the Laplacian test set: a GP with the "lap" kernel,
# fitted to that single dataset.
x, y = test_datasets["lap"]
likelihood = gpytorch.likelihoods.GaussianLikelihood()
lap_gp_model = gp.GPModel(x.squeeze(), y.squeeze(), likelihood, kernel="lap")
dataset = torch.utils.data.TensorDataset(x, y)

# Unlike the other GP cells, the observation noise is set to 1e-2 before
# training starts.
likelihood.noise_covar.noise = 1e-2

lap_gp_tracker = utils.training_utils.train_gp(
    lap_gp_model, likelihood, dataset, dataset_size=x.size(0)
)

# Optional diagnostic: the "loss" series recorded during training.
if plot_training_metrics:
    plt.plot(lap_gp_tracker["loss"])
    plt.ylabel("marginal likelihood")
    plt.grid()
    plt.show()

##### Periodic GP for sawtooth

In [None]:
# Reference model for the sawtooth test set.
likelihood = gpytorch.likelihoods.GaussianLikelihood()
x, y = test_datasets["saw"]
dataset = torch.utils.data.TensorDataset(x, y)
# Both this cell's heading and the comparison plot's title describe this model
# as a periodic-covariance GP, so the "per" kernel is requested. The previous
# value, "saw", is a dataset name, not one of the kernel names used elsewhere
# in this notebook ("se", "per", "lap").
saw_gp_model = gp.GPModel(x.squeeze(), y.squeeze(), likelihood, kernel="per")

saw_gp_tracker = utils.training_utils.train_gp(
    saw_gp_model,
    likelihood,
    dataset,
    dataset_size=x.shape[0],
)

# Optional diagnostic: the "loss" series recorded during training.
if plot_training_metrics:
    plt.plot(saw_gp_tracker["loss"])
    plt.ylabel("marginal likelihood")
    plt.grid()
    plt.show()

##### GIBNN 

In [None]:
# Non-amortised GIBNN fitted directly to the cubic test set; it serves as the
# reference model for that dataset in the comparison plots.
x, y = test_datasets["cub"]
dataset = torch.utils.data.TensorDataset(x, y)
n = x.shape[0]
num_inducing = 10
# Initialise the inducing points at a random subset of the training inputs.
rand_perm = torch.randperm(n)[:num_inducing]
inducing_points = x[rand_perm]

v_gibnn = gibnn.GIBNN(
    x_dim=1,
    y_dim=1,
    hidden_dims=[20, 20],
    num_inducing=num_inducing,
    inducing_points=inducing_points,
    train_noise=True,
)

v_gibnn_tracker = utils.training_utils.train_model(
    v_gibnn,
    dataset,
    batch_size=128,
    lr=1e-2,
)

if plot_training_metrics:
    # One stacked panel per tracked metric. squeeze=False keeps `axes` a 2-D
    # array even when only a single metric is tracked, so the loop below
    # works for any number of metrics.
    num_metrics = len(v_gibnn_tracker.keys())
    fig, axes = plt.subplots(
        num_metrics,
        1,
        figsize=(8, num_metrics * 4),
        dpi=100,
        sharex=True,
        squeeze=False,
    )

    for ax, (key, vals) in zip(axes[:, 0], v_gibnn_tracker.items()):
        ax.plot(vals)
        ax.set_ylabel(key)
        ax.grid()

    plt.show()

### Compare predictions

In [None]:
# Number of stacked panels in each comparison figure.
num_models = 5
# Dense grid of 1000 test inputs on [-5, 5], stored as a (1000, 1) column.
x_test = torch.linspace(-5.0, 5.0, steps=1000)[:, None]

##### SE covariance GP-generated test dataset

In [None]:
# Compare every model on the SE test set, one panel per model.
fig, axes = plt.subplots(num_models, 1, figsize=(6, 3 * num_models), sharex=True)
x, y = test_datasets["se"]

# generate predictions
agibnn_preds = amortised_gibnn(x, y, x_test=x_test, num_samples=100)[-1]
npagibnn_preds = np_amortised_gibnn(x, y, x_test=x_test, num_samples=100)[-1]
amfvibnn_preds = amortised_mfvibnn(x, y, x_test=x_test, num_samples=100)[-1]
cnp_preds = cnp(x, y, x_t=x_test)
se_gp_model.eval()
se_gp_preds = se_gp_model(x_test)

# Panels drawn from prediction samples: (axis, samples, title).
sample_panels = [
    (axes[0], agibnn_preds, "Amortised GIBNN"),
    (axes[1], amfvibnn_preds, "Amortised MFVIBNN"),
    (axes[4], npagibnn_preds, "Amortised GIBNN, NP Loss"),
]
for ax, preds, title in sample_panels:
    samples = preds.detach()
    # Only one sample curve carries the legend label, to avoid duplicate entries.
    ax.plot(
        x_test,
        samples[-1].numpy(),
        color="C0",
        alpha=0.1,
        zorder=0,
        label="Prediction samples",
    )
    for sample in samples[:-1]:
        ax.plot(x_test, sample.numpy(), color="C0", alpha=0.1, zorder=0)
    ax.plot(
        x_test,
        samples.mean(0).numpy(),
        color="C0",
        alpha=1.0,
        ls="--",
        zorder=0,
        label="Mean prediction",
    )
    ax.set_title(title, fontsize=18)

# Panels drawn as a mean curve with a +/- 2 std band: (axis, mean, std, title).
band_panels = [
    (
        axes[2],
        cnp_preds.loc.detach().numpy(),
        cnp_preds.scale.detach().numpy(),
        "Vanilla CNP",
    ),
    (
        axes[3],
        se_gp_preds.mean.detach().numpy(),
        torch.sqrt(se_gp_preds.variance.detach()).numpy(),
        "SE covariance GP",
    ),
]
for ax, pred_mean, pred_std, title in band_panels:
    ax.plot(
        x_test,
        pred_mean,
        color="C0",
        alpha=1.0,
        ls="--",
        zorder=0,
        label="Mean prediction",
    )
    ax.fill_between(
        x_test.squeeze(),
        (pred_mean + 2 * pred_std).squeeze(),
        (pred_mean - 2 * pred_std).squeeze(),
        alpha=0.3,
        label="95% confidence interval",
    )
    ax.set_title(title, fontsize=18)

# Shared decoration: overlay the observations, fix the view, hide all ticks.
for ax in axes:
    ax.scatter(x, y, color="C1", label="Datapoints", zorder=1, s=10)
    ax.set_ylim([-5.0, 5.0])
    ax.set_xlim([-5.0, 5.0])
    ax.set_xticks([])
    ax.set_yticks([])

plt.show()

##### Periodic covariance GP-generated test dataset

In [None]:
# Compare every model on the periodic test set, one panel per model.
fig, axes = plt.subplots(num_models, 1, figsize=(6, 3 * num_models), sharex=True)
x, y = test_datasets["per"]

# generate predictions
agibnn_preds = amortised_gibnn(x, y, x_test=x_test, num_samples=100)[-1]
npagibnn_preds = np_amortised_gibnn(x, y, x_test=x_test, num_samples=100)[-1]
amfvibnn_preds = amortised_mfvibnn(x, y, x_test=x_test, num_samples=100)[-1]
cnp_preds = cnp(x, y, x_t=x_test)
per_gp_model.eval()
per_gp_preds = per_gp_model(x_test)

# Panels drawn from prediction samples: (axis, samples, title).
sample_panels = [
    (axes[0], agibnn_preds, "Amortised GIBNN"),
    (axes[1], amfvibnn_preds, "Amortised MFVIBNN"),
    (axes[4], npagibnn_preds, "Amortised GIBNN, NP Loss"),
]
for ax, preds, title in sample_panels:
    samples = preds.detach()
    # Only one sample curve carries the legend label, to avoid duplicate entries.
    ax.plot(
        x_test,
        samples[-1].numpy(),
        color="C0",
        alpha=0.1,
        zorder=0,
        label="Prediction samples",
    )
    for sample in samples[:-1]:
        ax.plot(x_test, sample.numpy(), color="C0", alpha=0.1, zorder=0)
    ax.plot(
        x_test,
        samples.mean(0).numpy(),
        color="C0",
        alpha=1.0,
        ls="--",
        zorder=0,
        label="Mean prediction",
    )
    ax.set_title(title, fontsize=18)

# Panels drawn as a mean curve with a +/- 2 std band: (axis, mean, std, title).
band_panels = [
    (
        axes[2],
        cnp_preds.loc.detach().numpy(),
        cnp_preds.scale.detach().numpy(),
        "Vanilla CNP",
    ),
    (
        axes[3],
        per_gp_preds.mean.detach().numpy(),
        torch.sqrt(per_gp_preds.variance.detach()).numpy(),
        "Periodic covariance GP",
    ),
]
for ax, pred_mean, pred_std, title in band_panels:
    ax.plot(
        x_test,
        pred_mean,
        color="C0",
        alpha=1.0,
        ls="--",
        zorder=0,
        label="Mean prediction",
    )
    ax.fill_between(
        x_test.squeeze(),
        (pred_mean + 2 * pred_std).squeeze(),
        (pred_mean - 2 * pred_std).squeeze(),
        alpha=0.3,
        label="95% confidence interval",
    )
    ax.set_title(title, fontsize=18)

# Shared decoration: overlay the observations, fix the view, hide all ticks.
for ax in axes:
    ax.scatter(x, y, color="C1", label="Datapoints", zorder=1, s=10)
    ax.set_ylim([-5.0, 5.0])
    ax.set_xlim([-5.0, 5.0])
    ax.set_xticks([])
    ax.set_yticks([])

plt.show()

##### Laplacian covariance GP-generated test dataset

In [None]:
# Compare every model on the Laplacian test set, one panel per model.
fig, axes = plt.subplots(num_models, 1, figsize=(6, 3 * num_models), sharex=True)
x, y = test_datasets["lap"]

# generate predictions
agibnn_preds = amortised_gibnn(x, y, x_test=x_test, num_samples=100)[-1]
npagibnn_preds = np_amortised_gibnn(x, y, x_test=x_test, num_samples=100)[-1]
amfvibnn_preds = amortised_mfvibnn(x, y, x_test=x_test, num_samples=100)[-1]
cnp_preds = cnp(x, y, x_t=x_test)
lap_gp_model.eval()
lap_gp_preds = lap_gp_model(x_test)

# Panels drawn from prediction samples: (axis, samples, title).
sample_panels = [
    (axes[0], agibnn_preds, "Amortised GIBNN"),
    (axes[1], amfvibnn_preds, "Amortised MFVIBNN"),
    (axes[4], npagibnn_preds, "Amortised GIBNN, NP Loss"),
]
for ax, preds, title in sample_panels:
    samples = preds.detach()
    # Only one sample curve carries the legend label, to avoid duplicate entries.
    ax.plot(
        x_test,
        samples[-1].numpy(),
        color="C0",
        alpha=0.1,
        zorder=0,
        label="Prediction samples",
    )
    for sample in samples[:-1]:
        ax.plot(x_test, sample.numpy(), color="C0", alpha=0.1, zorder=0)
    ax.plot(
        x_test,
        samples.mean(0).numpy(),
        color="C0",
        alpha=1.0,
        ls="--",
        zorder=0,
        label="Mean prediction",
    )
    ax.set_title(title, fontsize=18)

# Panels drawn as a mean curve with a +/- 2 std band: (axis, mean, std, title).
band_panels = [
    (
        axes[2],
        cnp_preds.loc.detach().numpy(),
        cnp_preds.scale.detach().numpy(),
        "Vanilla CNP",
    ),
    (
        axes[3],
        lap_gp_preds.mean.detach().numpy(),
        torch.sqrt(lap_gp_preds.variance.detach()).numpy(),
        "Laplacian covariance GP",
    ),
]
for ax, pred_mean, pred_std, title in band_panels:
    ax.plot(
        x_test,
        pred_mean,
        color="C0",
        alpha=1.0,
        ls="--",
        zorder=0,
        label="Mean prediction",
    )
    ax.fill_between(
        x_test.squeeze(),
        (pred_mean + 2 * pred_std).squeeze(),
        (pred_mean - 2 * pred_std).squeeze(),
        alpha=0.3,
        label="95% confidence interval",
    )
    ax.set_title(title, fontsize=18)

# Shared decoration: overlay the observations, fix the view, hide all ticks.
for ax in axes:
    ax.scatter(x, y, color="C1", label="Datapoints", zorder=1, s=10)
    ax.set_ylim([-5.0, 5.0])
    ax.set_xlim([-5.0, 5.0])
    ax.set_xticks([])
    ax.set_yticks([])

plt.show()

##### Sawtooth dataset

In [None]:
# Compare every model on the sawtooth test set, one panel per model.
fig, axes = plt.subplots(num_models, 1, figsize=(6, 3 * num_models), sharex=True)
x, y = test_datasets["saw"]

# generate predictions
agibnn_preds = amortised_gibnn(x, y, x_test=x_test, num_samples=100)[-1]
npagibnn_preds = np_amortised_gibnn(x, y, x_test=x_test, num_samples=100)[-1]
amfvibnn_preds = amortised_mfvibnn(x, y, x_test=x_test, num_samples=100)[-1]
cnp_preds = cnp(x, y, x_t=x_test)
saw_gp_model.eval()
saw_gp_preds = saw_gp_model(x_test)

# Panels drawn from prediction samples: (axis, samples, title).
sample_panels = [
    (axes[0], agibnn_preds, "Amortised GIBNN"),
    (axes[1], amfvibnn_preds, "Amortised MFVIBNN"),
    (axes[4], npagibnn_preds, "Amortised GIBNN, NP Loss"),
]
for ax, preds, title in sample_panels:
    samples = preds.detach()
    # Only one sample curve carries the legend label, to avoid duplicate entries.
    ax.plot(
        x_test,
        samples[-1].numpy(),
        color="C0",
        alpha=0.1,
        zorder=0,
        label="Prediction samples",
    )
    for sample in samples[:-1]:
        ax.plot(x_test, sample.numpy(), color="C0", alpha=0.1, zorder=0)
    ax.plot(
        x_test,
        samples.mean(0).numpy(),
        color="C0",
        alpha=1.0,
        ls="--",
        zorder=0,
        label="Mean prediction",
    )
    # Every panel title uses fontsize=18, matching the other comparison figures.
    ax.set_title(title, fontsize=18)

# Panels drawn as a mean curve with a +/- 2 std band: (axis, mean, std, title).
band_panels = [
    (
        axes[2],
        cnp_preds.loc.detach().numpy(),
        cnp_preds.scale.detach().numpy(),
        "Vanilla CNP",
    ),
    (
        axes[3],
        saw_gp_preds.mean.detach().numpy(),
        torch.sqrt(saw_gp_preds.variance.detach()).numpy(),
        "Periodic covariance GP",
    ),
]
for ax, pred_mean, pred_std, title in band_panels:
    ax.plot(
        x_test,
        pred_mean,
        color="C0",
        alpha=1.0,
        ls="--",
        zorder=0,
        label="Mean prediction",
    )
    ax.fill_between(
        x_test.squeeze(),
        (pred_mean + 2 * pred_std).squeeze(),
        (pred_mean - 2 * pred_std).squeeze(),
        alpha=0.3,
        label="95% confidence interval",
    )
    ax.set_title(title, fontsize=18)

# Shared decoration: overlay the observations, fix the view, hide all ticks.
for ax in axes:
    ax.scatter(x, y, color="C1", label="Datapoints", zorder=1, s=10)
    ax.set_ylim([-5.0, 5.0])
    ax.set_xlim([-5.0, 5.0])
    ax.set_xticks([])
    ax.set_yticks([])

plt.show()

##### Cubic dataset

In [None]:
# Compare every model on the cubic test set, one panel per model. The
# reference here is the non-amortised GIBNN rather than a GP.
fig, axes = plt.subplots(num_models, 1, figsize=(6, 3 * num_models), sharex=True)
x, y = test_datasets["cub"]

# generate predictions
agibnn_preds = amortised_gibnn(x, y, x_test=x_test, num_samples=100)[-1]
npagibnn_preds = np_amortised_gibnn(x, y, x_test=x_test, num_samples=100)[-1]
amfvibnn_preds = amortised_mfvibnn(x, y, x_test=x_test, num_samples=100)[-1]
cnp_preds = cnp(x, y, x_t=x_test)
v_gibnn_preds = v_gibnn(x_test, num_samples=100)[0]

# Panels drawn from prediction samples: (axis, samples, title).
sample_panels = [
    (axes[0], agibnn_preds, "Amortised GIBNN"),
    (axes[1], amfvibnn_preds, "Amortised MFVIBNN"),
    (axes[3], v_gibnn_preds, "Vanilla GIBNN"),
    (axes[4], npagibnn_preds, "Amortised GIBNN, NP Loss"),
]
for ax, preds, title in sample_panels:
    samples = preds.detach()
    # Only one sample curve carries the legend label, to avoid duplicate entries.
    ax.plot(
        x_test,
        samples[-1].numpy(),
        color="C0",
        alpha=0.1,
        zorder=0,
        label="Prediction samples",
    )
    for sample in samples[:-1]:
        ax.plot(x_test, sample.numpy(), color="C0", alpha=0.1, zorder=0)
    ax.plot(
        x_test,
        samples.mean(0).numpy(),
        color="C0",
        alpha=1.0,
        ls="--",
        zorder=0,
        label="Mean prediction",
    )
    ax.set_title(title, fontsize=18)

# CNP panel: mean curve with a +/- 2 std band.
pred_mean = cnp_preds.loc.detach().numpy()
pred_std = cnp_preds.scale.detach().numpy()
axes[2].plot(
    x_test,
    pred_mean,
    color="C0",
    alpha=1.0,
    ls="--",
    zorder=0,
    label="Mean prediction",
)
axes[2].fill_between(
    x_test.squeeze(),
    (pred_mean + 2 * pred_std).squeeze(),
    (pred_mean - 2 * pred_std).squeeze(),
    alpha=0.3,
    label="95% confidence interval",
)
axes[2].set_title("Vanilla CNP", fontsize=18)

# Shared decoration: overlay the observations, fix the view, hide all ticks.
for ax in axes:
    ax.scatter(x, y, color="C1", label="Datapoints", zorder=1, s=10)
    ax.set_ylim([-5.0, 5.0])
    ax.set_xlim([-5.0, 5.0])
    ax.set_xticks([])
    ax.set_yticks([])

plt.show()