In [None]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell executes so edits to local source files take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys

from collections import defaultdict

import torch
import matplotlib.pyplot as plt
import numpy as np

from tqdm.auto import tqdm

# Put the parent directory on the import path; presumably that is where the
# local `bnn_amort_inf` package lives relative to this notebook — confirm if
# the notebook is moved.
sys.path.append("../")
import bnn_amort_inf
from bnn_amort_inf import models, utils

# Create new floating-point tensors as float64 by default.
torch.set_default_dtype(torch.float64)

# Seed the global torch and numpy RNGs so that stochastic steps drawing from
# them are repeatable under Restart Kernel -> Run All.
# NOTE(review): assumes the dataset generator, model initialisation and
# prediction sampling use these global generators — confirm in bnn_amort_inf.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Generate GP datasets for training.

In [None]:
num_datasets = 1000

# Call the GP dataset generator once per training task and collect the results.
train_datasets = [
    utils.gp_datasets.gp_dataset_generator() for _ in range(num_datasets)
]

# Wrap the list of datasets in the MetaDataset container passed to the trainer.
meta_dataset = utils.dataset_utils.MetaDataset(train_datasets)

# Define and train the models.

## Define and train the amortised GIBNN

In [None]:
# Amortised GIBNN on 1-D inputs and 1-D outputs. Both the network itself and
# the `in_hidden_dims` network use two hidden layers of width 20.
# NOTE(review): `noise` / `train_noise` presumably set an initial, learnable
# observation noise — confirm against AmortisedGIBNN.
amortised_gibnn = models.gibnn.amortised_gibnn.AmortisedGIBNN(
    # Data dimensionality.
    x_dim=1,
    y_dim=1,
    # Architecture.
    hidden_dims=[20, 20],
    in_hidden_dims=[20, 20],
    # Noise settings.
    noise=1e-1,
    train_noise=True,
)

# Train on the meta-dataset with the trainer's default settings; the returned
# tracker maps metric names to the values recorded during training.
agibnn_tracker = utils.training_utils.train_metamodel(amortised_gibnn, meta_dataset)

## Plot metrics throughout training

In [None]:
# One stacked panel per tracked metric, sharing the x-axis.
num_metrics = len(agibnn_tracker.keys())

fig, axes = plt.subplots(
    num_metrics,
    1,
    figsize=(8, num_metrics * 4),
    dpi=100,
    sharex=True,
    # Always return a 2-D array of axes. Without this, a tracker holding a
    # single metric yields a bare Axes object and the zip below raises.
    squeeze=False,
)

for ax, (key, vals) in zip(axes.ravel(), agibnn_tracker.items()):
    ax.plot(vals)
    ax.set_ylabel(key)
    ax.grid()

plt.show()

## Define and train the CNP

In [None]:
# Conditional neural process on 1-D inputs and 1-D outputs, with a
# 64-dimensional embedding and two hidden layers of width 32 in both the
# encoder and the decoder.
#
# Experiment note: to see how the CNP behaves when constrained to tiny noise,
# additionally pass
#     noise=1e-1, train_noise=False, decoder_activation=torch.nn.Identity()
cnp = models.np.CNP(
    # Data dimensionality.
    x_dim=1,
    y_dim=1,
    # Architecture.
    embedded_dim=64,
    encoder_hidden_dims=[32, 32],
    decoder_hidden_dims=[32, 32],
)

# Train with the neural-process flag set and explicit optimisation settings.
# NOTE(review): the *_es_iters arguments look like early-stopping windows —
# confirm against train_metamodel.
cnp_tracker = utils.training_utils.train_metamodel(
    cnp,
    meta_dataset,
    neural_process=True,
    # Optimisation budget.
    lr=1e-3,
    max_iters=50_000,
    # Early-stopping related settings.
    min_es_iters=10_000,
    ref_es_iters=1_000,
    smooth_es_iters=500,
)

In [None]:
# One stacked panel per tracked CNP metric, sharing the x-axis.
num_metrics = len(cnp_tracker.keys())

fig, axes = plt.subplots(
    num_metrics,
    1,
    figsize=(8, num_metrics * 4),
    dpi=100,
    sharex=True,
    # Always return a 2-D array of axes. Without this, a tracker holding a
    # single metric yields a bare Axes object and the zip below raises.
    squeeze=False,
)

for ax, (key, vals) in zip(axes.ravel(), cnp_tracker.items()):
    ax.plot(vals)
    ax.set_ylabel(key)
    ax.grid()

plt.show()

# Generate predictions for each model

In [None]:
# Dense grid of 200 test inputs on [-4, 4], shaped (200, 1).
xs = torch.linspace(-4, 4, 200).unsqueeze(-1)

num_test_datasets = 5

# Draw the held-out datasets, calling the generator with min_n=5 and max_n=10.
test_datasets = [
    utils.gp_datasets.gp_dataset_generator(min_n=5, max_n=10)
    for _ in range(num_test_datasets)
]

## Amortised GIBNN predictions.

#### Train datasets

In [None]:
num_plots = min(len(train_datasets), 4)  # limit to 4 plots

# squeeze=False keeps `axes` a 2-D array even when num_plots == 1, so the loop
# below can always iterate over it.
fig, axes = plt.subplots(
    num_plots, 1, figsize=(8, 4 * num_plots), sharex=True, squeeze=False
)

for ax, (x, y) in zip(axes.ravel(), train_datasets[:num_plots]):
    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        # The last element of the model output holds the prediction samples at
        # `xs`, indexed by sample along the first dimension.
        ys_preds = amortised_gibnn(x, y, x_test=xs, num_samples=100)[-1]
    ys_preds = ys_preds.detach()

    # Draw every sample faintly; label only one of them so the legend gets a
    # single "Prediction samples" entry.
    for i, ys_pred in enumerate(ys_preds):
        ax.plot(
            xs,
            ys_pred.numpy(),
            color="C0",
            alpha=0.1,
            zorder=0,
            label="Prediction samples" if i == 0 else None,
        )

    # Average over the sample dimension for the mean prediction.
    ax.plot(
        xs,
        ys_preds.mean(0).numpy(),
        color="C0",
        alpha=1.0,
        ls="--",
        zorder=0,
        label="Mean prediction",
    )

    ax.scatter(x, y, color="C1", marker="x", label="Datapoints", zorder=1)

    ax.grid()
    ax.legend()
    ax.set_xlim([-4.0, 4.0])
    ax.set_ylim([-5.0, 5.0])

plt.show()

#### Test datasets

In [None]:
# squeeze=False keeps `axes` a 2-D array even for a single test dataset, so
# the loop below can always iterate over it.
fig, axes = plt.subplots(
    len(test_datasets),
    1,
    figsize=(8, 4 * len(test_datasets)),
    sharex=True,
    squeeze=False,
)

for ax, (x, y) in zip(axes.ravel(), test_datasets):
    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        # The last element of the model output holds the prediction samples at
        # `xs`, indexed by sample along the first dimension.
        ys_preds = amortised_gibnn(x, y, x_test=xs, num_samples=100)[-1]
    ys_preds = ys_preds.detach()

    # Draw every sample faintly; label only one of them so the legend gets a
    # single "Prediction samples" entry.
    for i, ys_pred in enumerate(ys_preds):
        ax.plot(
            xs,
            ys_pred.numpy(),
            color="C0",
            alpha=0.1,
            zorder=0,
            label="Prediction samples" if i == 0 else None,
        )

    # Average over the sample dimension for the mean prediction.
    ax.plot(
        xs,
        ys_preds.mean(0).numpy(),
        color="C0",
        alpha=1.0,
        ls="--",
        zorder=0,
        label="Mean prediction",
    )

    ax.scatter(x, y, color="C1", label="Datapoints", zorder=1)

    ax.grid()
    ax.legend()
    ax.set_xlim([-4.0, 4.0])
    ax.set_ylim([-5.0, 5.0])

plt.show()

## CNP Predictions

#### Train datasets

In [None]:
num_plots = min(len(train_datasets), 4)  # limit to 4 plots

# squeeze=False keeps `axes` a 2-D array even when num_plots == 1, so the loop
# below can always iterate over it.
fig, axes = plt.subplots(
    num_plots, 1, figsize=(8, 4 * num_plots), sharex=True, squeeze=False
)

for ax, (x, y) in zip(axes.ravel(), train_datasets[:num_plots]):
    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        # The CNP returns an object exposing `.loc` and `.scale` at `xs`.
        ys = cnp(x, y, x_t=xs)
    ys_pred = ys.loc.detach().numpy()
    ys_std = ys.scale.detach().numpy()

    ax.plot(
        xs,
        ys_pred,
        color="C0",
        alpha=1.0,
        ls="--",
        label="Mean prediction",
        zorder=0,
    )
    # Shade mean +/- 2 * scale.
    ax.fill_between(
        xs.squeeze(),
        (ys_pred + 2 * ys_std).squeeze(),
        (ys_pred - 2 * ys_std).squeeze(),
        alpha=0.3,
        label="95% confidence interval",
    )
    ax.scatter(x, y, color="C1", label="Datapoints", zorder=1)

    ax.grid()
    ax.legend()
    ax.set_xlim([-4.0, 4.0])
    ax.set_ylim([-5.0, 5.0])

plt.show()

#### Test datasets

In [None]:
# squeeze=False keeps `axes` a 2-D array even for a single test dataset, so
# the loop below can always iterate over it.
fig, axes = plt.subplots(
    len(test_datasets),
    1,
    figsize=(8, 4 * len(test_datasets)),
    sharex=True,
    squeeze=False,
)

for ax, (x, y) in zip(axes.ravel(), test_datasets):
    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        # The CNP returns an object exposing `.loc` and `.scale` at `xs`.
        ys = cnp(x, y, x_t=xs)
    ys_pred = ys.loc.detach().numpy()
    ys_std = ys.scale.detach().numpy()

    ax.plot(
        xs,
        ys_pred,
        color="C0",
        alpha=1.0,
        ls="--",
        label="Mean prediction",
        zorder=0,
    )
    # Shade mean +/- 2 * scale.
    ax.fill_between(
        xs.squeeze(),
        (ys_pred + 2 * ys_std).squeeze(),
        (ys_pred - 2 * ys_std).squeeze(),
        alpha=0.3,
        label="95% confidence interval",
    )
    ax.scatter(x, y, color="C1", label="Datapoints", zorder=1)

    ax.grid()
    ax.legend()
    ax.set_xlim([-4.0, 4.0])
    ax.set_ylim([-5.0, 5.0])

plt.show()