In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys

from collections import defaultdict

import torch
import matplotlib.pyplot as plt
import numpy as np

from tqdm.auto import tqdm

sys.path.append("../")
from bnn_amort_inf.models.bnn import mfvi_bnn
from bnn_amort_inf import utils

# Generate GP datasets for training.

In [None]:
# Size of the meta-training set: one generated dataset per entry.
num_datasets = 1000

# Draw every training dataset independently from the GP dataset generator.
# NOTE(review): min_n / max_n presumably bound the number of points in each
# generated dataset -- confirm against utils.gp_datasets.
train_datasets = [
    utils.gp_datasets.gp_dataset_generator(min_n=10, max_n=20)
    for _ in range(num_datasets)
]

# Wrap the list of datasets in the container consumed by the meta-training loop.
meta_dataset = utils.dataset_utils.MetaDataset(train_datasets)

# Define and train model

In [None]:
# Architecture / likelihood settings for the amortised MFVI BNN.
# NOTE(review): `hidden_dims` and `in_hidden_dims` presumably size the BNN and
# the inference network respectively -- confirm against mfvi_bnn.
model_config = dict(
    x_dim=1,
    y_dim=1,
    hidden_dims=[20, 20],
    in_hidden_dims=[20, 20],
    noise=1e-2,
    train_noise=True,
)
amortised_mfvibnn = mfvi_bnn.AmortisedMFVIBNN(**model_config)

# Optimisation and early-stopping settings forwarded to the meta-training loop.
training_config = dict(
    lr=1e-3,
    num_samples=1,
    min_es_iters=3_000,
    es_thresh=1e-3,
    smooth_es_iters=200,
    batch_size=10,
)

# The returned tracker is plotted in the next cell as a mapping from metric
# name to a sequence of recorded values.
agibnn_tracker = utils.training_utils.train_metamodel(
    amortised_mfvibnn,
    meta_dataset,
    **training_config,
)

# Plot metrics during training

In [None]:
# One vertically stacked panel per tracked metric, sharing the x-axis.
num_metrics = len(agibnn_tracker.keys())

# squeeze=False keeps `axes` a 2-D array of shape (num_metrics, 1) even when
# only one metric is tracked; without it a single bare Axes is returned and
# the zip() below fails because an Axes is not iterable.
fig, axes = plt.subplots(
    num_metrics,
    1,
    figsize=(8, num_metrics * 4),
    dpi=100,
    sharex=True,
    squeeze=False,
)

# Iterate over the single column of axes, pairing each with one metric.
for ax, (key, vals) in zip(axes[:, 0], agibnn_tracker.items()):
    ax.plot(vals)
    ax.set_ylabel(key)
    ax.grid()

plt.show()

# Generate test datasets

In [None]:
# Number of held-out datasets used for the qualitative test plots.
num_test_datasets = 5

# Test datasets come from the same generator as the training datasets, but
# with smaller min_n / max_n arguments (5 and 10 instead of 10 and 20).
test_datasets = [
    utils.gp_datasets.gp_dataset_generator(min_n=5, max_n=10)
    for _ in range(num_test_datasets)
]

# Dense grid of 200 inputs on [-4, 4], as a column of shape (200, 1), at which
# predictions are evaluated and plotted.
xs = torch.linspace(-4, 4, 200).unsqueeze(-1)

# Predictions

### Training data

In [None]:
# Show at most four training datasets so the figure stays readable.
num_plots = min(len(train_datasets), 4)

# squeeze=False keeps `axes` a 2-D array of shape (num_plots, 1) even when
# num_plots == 1; without it a single bare Axes is returned and the zip()
# below fails because an Axes is not iterable.
fig, axes = plt.subplots(
    num_plots, 1, figsize=(8, 4 * num_plots), sharex=True, squeeze=False
)

for ax, (x, y) in zip(axes[:, 0], train_datasets[:num_plots]):
    # The last element of the model output is treated as the stack of
    # predictive samples evaluated at `xs`, indexed by sample along dim 0.
    # Detach once so the tensor can be converted to NumPy for plotting.
    ys_preds = amortised_mfvibnn(x, y, x_test=xs, num_samples=100)[-1].detach()

    # Faint line per sample; only the first carries a label so the legend
    # shows a single "Prediction samples" entry.
    for i, ys_pred in enumerate(ys_preds):
        ax.plot(
            xs,
            ys_pred.numpy(),
            color="C0",
            alpha=0.1,
            zorder=0,
            label="Prediction samples" if i == 0 else None,
        )

    # Mean over the sample dimension.
    ax.plot(
        xs,
        ys_preds.mean(0).numpy(),
        color="C0",
        alpha=1.0,
        ls="--",
        zorder=0,
        label="Mean prediction",
    )

    # Observed points of this dataset, drawn above the prediction lines.
    ax.scatter(x, y, color="C1", marker="x", label="Datapoints", zorder=1)

    ax.grid()
    ax.legend()
    ax.set_ylabel("y")
    ax.set_xlim([-4.0, 4.0])
    ax.set_ylim([-5.0, 5.0])

# The x-axis is shared, so label only the bottom panel.
axes[-1, 0].set_xlabel("x")

plt.show()

### Test datasets

In [None]:
# One panel per held-out test dataset.
num_test_plots = len(test_datasets)

# squeeze=False keeps `axes` a 2-D array of shape (num_test_plots, 1) even for
# a single test dataset; without it a single bare Axes is returned and the
# zip() below fails because an Axes is not iterable.
fig, axes = plt.subplots(
    num_test_plots, 1, figsize=(8, 4 * num_test_plots), sharex=True, squeeze=False
)

for ax, (x, y) in zip(axes[:, 0], test_datasets):
    # The last element of the model output is treated as the stack of
    # predictive samples evaluated at `xs`, indexed by sample along dim 0.
    # Detach once so the tensor can be converted to NumPy for plotting.
    ys_preds = amortised_mfvibnn(x, y, x_test=xs, num_samples=100)[-1].detach()

    # Faint line per sample; only the first carries a label so the legend
    # shows a single "Prediction samples" entry.
    for i, ys_pred in enumerate(ys_preds):
        ax.plot(
            xs,
            ys_pred.numpy(),
            color="C0",
            alpha=0.1,
            zorder=0,
            label="Prediction samples" if i == 0 else None,
        )

    # Mean over the sample dimension.
    ax.plot(
        xs,
        ys_preds.mean(0).numpy(),
        color="C0",
        alpha=1.0,
        ls="--",
        zorder=0,
        label="Mean prediction",
    )

    # Observed points of this dataset; marker="x" matches the styling used in
    # the training-data prediction plots.
    ax.scatter(x, y, color="C1", marker="x", label="Datapoints", zorder=1)

    ax.grid()
    ax.legend()
    ax.set_ylabel("y")
    ax.set_xlim([-4.0, 4.0])
    ax.set_ylim([-5.0, 5.0])

# The x-axis is shared, so label only the bottom panel.
axes[-1, 0].set_xlabel("x")

plt.show()