In [None]:
# Re-import modules automatically before each cell executes, so edits to
# imported source files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import matplotlib.pyplot as plt
import gpytorch
import tqdm.auto as tqdm
import sys
import wbml.plot

# Create all new floating-point tensors in double precision by default.
torch.set_default_dtype(torch.float64)

# Seed torch's global RNG so that "Restart Kernel -> Run All" is repeatable.
# NOTE(review): this only covers torch-based randomness; if the dataset
# generators draw from NumPy or `random`, those need seeding as well.
SEED = 0
torch.manual_seed(SEED)

# Put the parent directory on the import path so that the `bnn_amort_inf`
# imports below resolve when the notebook is run from its own folder.
sys.path.append("../")
from bnn_amort_inf.models.bnn import gibnn, mfvi_bnn
from bnn_amort_inf.models.np import cnp, convcnp, mlg_convcnp, makora_convcnp
from bnn_amort_inf.models import gp
from bnn_amort_inf import utils
from bnn_amort_inf.models.likelihoods.normal import (
    NormalLikelihood,
    HeteroscedasticNormalLikelihood,
)

# Generate metadatasets

In [None]:
# Training meta-datasets, keyed by the function family they are drawn from:
#   "se" / "per" / "lap" -- GP datasets generated with that kernel (plain lists),
#   "saw"                -- sawtooth datasets, wrapped in a MetaDataset,
#   "mix"                -- GP kernels and sawtooth interleaved in a cycle of four.
# NOTE(review): "saw" is the only entry wrapped in MetaDataset and the only one
# generated without min_n/max_n -- confirm that this difference is intended.
num_datasets = 1_000
kernels = ["se", "per", "lap"]

x_min = -2.0
x_max = 2.0


def _draw_gp_dataset(kernel_name):
    """Generate one GP dataset on [x_min, x_max] with min_n=10 and max_n=20."""
    return utils.gp_datasets.gp_dataset_generator(
        kernel=kernel_name, x_min=x_min, x_max=x_max, min_n=10, max_n=20
    )


def _draw_mix_dataset(index):
    """Generate entry `index` of the mixture.

    Within every group of four consecutive indices, positions 0-2 use the
    three GP kernels in order and position 3 is a sawtooth dataset.
    """
    slot = index % 4
    if slot == 3:
        return utils.dataset_utils.sawtooth_dataset(
            lower=x_min, upper=x_max, min_n=10, max_n=20
        )
    return _draw_gp_dataset(kernels[slot])


meta_datasets = {}
for kernel in kernels:
    meta_datasets[kernel] = [_draw_gp_dataset(kernel) for _ in range(num_datasets)]

sawtooth_datasets = [
    utils.dataset_utils.sawtooth_dataset(lower=x_min, upper=x_max)
    for _ in range(num_datasets)
]
meta_datasets["saw"] = utils.dataset_utils.MetaDataset(sawtooth_datasets)

meta_datasets["mix"] = [_draw_mix_dataset(i) for i in range(num_datasets)]

In [None]:
# Held-out datasets used for the qualitative prediction plots: one per GP
# kernel, plus a sawtooth and a cubic dataset.
x_min = -2.0
x_max = 2.0

test_datasets = {
    kernel: utils.gp_datasets.gp_dataset_generator(
        kernel=kernel, x_min=x_min, x_max=x_max, min_n=10, max_n=20
    )
    for kernel in kernels
}

# NOTE(review): the sawtooth test set relies on the generator's default
# lower/upper bounds rather than x_min/x_max -- confirm they agree.
test_datasets["saw"] = utils.dataset_utils.sawtooth_dataset(min_n=50, max_n=60)
test_datasets["cub"] = utils.dataset_utils.cubic_dataset()

# 500 evenly spaced inputs on [-2.25, 2.25] as a (500, 1) column, i.e. slightly
# wider than the [x_min, x_max] range used to generate the datasets.
x_test = torch.linspace(-2.25, 2.25, 500)[:, None]

# Amortised GIBNN

In [None]:
# Amortised GIBNN for 1-D regression, trained on the "se" meta-dataset.
amortised_gibnn = gibnn.AmortisedGIBNN(
    x_dim=1,
    y_dim=1,
    hidden_dims=[20, 20],
    in_hidden_dims=[20, 20],
    likelihood=NormalLikelihood(noise=0.01, train_noise=True),
)

# The returned tracker is used below as a mapping from metric name to the
# sequence of values recorded during training.
agibnn_tracker = utils.training_utils.train_metamodel(
    amortised_gibnn,
    meta_datasets["se"],
    loss_fn="loss",
    lr=1e-3,
    min_es_iters=2_000,
    smooth_es_iters=500,
    batch_size=1,
)

# One panel per tracked metric. `squeeze=False` keeps `axes` a 2-D array even
# when only one metric is tracked; with the default, plt.subplots returns a bare
# Axes object in that case and the zip below would raise.
num_metrics = len(agibnn_tracker.keys())
fig, axes = plt.subplots(
    num_metrics,
    1,
    figsize=(8, num_metrics * 2),
    dpi=100,
    sharex=True,
    squeeze=False,
)

for ax, (key, vals) in zip(axes[:, 0], agibnn_tracker.items()):
    ax.plot(vals)
    ax.set_ylabel(key)
    ax.grid()

plt.show()

In [None]:
# Predictive samples of the amortised GIBNN on the held-out "se" dataset.
# The last element returned by the model call is taken as the predictions;
# it is indexed and averaged along its first axis below (one entry per sample).
x, y = test_datasets["se"]
agibnn_preds = amortised_gibnn(x, y, x_test=x_test, num_samples=100)[-1]

fig = plt.figure(figsize=(8, 3), dpi=200)

agibnn_preds_detached = agibnn_preds.detach()
agibnn_sample_curves = agibnn_preds_detached.numpy()
sample_line_kwargs = dict(style="pred", alpha=0.1, zorder=0)

# Only the first line drawn carries a label, so the legend gets a single
# "Pred samples" entry instead of one per sample.
plt.plot(
    x_test, agibnn_sample_curves[-1], label="Pred samples", **sample_line_kwargs
)
for sample_curve in agibnn_sample_curves[:-1]:
    plt.plot(x_test, sample_curve, **sample_line_kwargs)

# Average over the sample axis for the mean prediction.
agibnn_mean_curve = agibnn_preds_detached.mean(0).numpy()
plt.plot(
    x_test,
    agibnn_mean_curve,
    label="Mean prediction",
    ls="--",
    style="pred",
    alpha=1.0,
    zorder=0,
)

plt.scatter(x, y, style="train", label="Data", zorder=1)
plt.ylim([-1.5, 1.5])
wbml.plot.tweak()
plt.show()

# Amortised MFVI BNN

In [None]:
# Amortised mean-field VI BNN for 1-D regression, trained on the "mix"
# meta-dataset.
amortised_mfbnn = mfvi_bnn.AmortisedMFVIBNN(
    x_dim=1,
    y_dim=1,
    hidden_dims=[20, 20],
    in_hidden_dims=[20, 20],
    likelihood=NormalLikelihood(noise=0.01, train_noise=True),
)

# The returned tracker is used below as a mapping from metric name to the
# sequence of values recorded during training.
amortised_mfbnn_tracker = utils.training_utils.train_metamodel(
    amortised_mfbnn,
    meta_datasets["mix"],
    loss_fn="loss",
    lr=1e-3,
    min_es_iters=2_000,
    smooth_es_iters=500,
    batch_size=3,
)

# One panel per tracked metric. `squeeze=False` keeps `axes` a 2-D array even
# when only one metric is tracked; with the default, plt.subplots returns a bare
# Axes object in that case and the zip below would raise.
num_metrics = len(amortised_mfbnn_tracker.keys())
fig, axes = plt.subplots(
    num_metrics,
    1,
    figsize=(8, num_metrics * 2),
    dpi=100,
    sharex=True,
    squeeze=False,
)

for ax, (key, vals) in zip(axes[:, 0], amortised_mfbnn_tracker.items()):
    ax.plot(vals)
    ax.set_ylabel(key)
    ax.grid(True)
    # wbml.plot.tweak() is called without an axis and so styles the *current*
    # axes; select this panel first, otherwise only the most recently created
    # panel would ever be tweaked.
    plt.sca(ax)
    wbml.plot.tweak()

plt.show()

In [None]:
# Predictive samples of the amortised MFVI BNN on the held-out "se" dataset.
# The last element returned by the model call is taken as the predictions;
# it is indexed and averaged along its first axis below (one entry per sample).
x, y = test_datasets["se"]
mfbnn_preds = amortised_mfbnn(x, y, x_test=x_test, num_samples=100)[-1]

fig = plt.figure(figsize=(8, 3), dpi=200)

mfbnn_preds_detached = mfbnn_preds.detach()
mfbnn_sample_curves = mfbnn_preds_detached.numpy()
sample_line_kwargs = dict(style="pred", alpha=0.1, zorder=0)

# Only the first line drawn carries a label, so the legend gets a single
# "Pred samples" entry instead of one per sample.
plt.plot(
    x_test, mfbnn_sample_curves[-1], label="Pred samples", **sample_line_kwargs
)
for sample_curve in mfbnn_sample_curves[:-1]:
    plt.plot(x_test, sample_curve, **sample_line_kwargs)

# Average over the sample axis for the mean prediction.
mfbnn_mean_curve = mfbnn_preds_detached.mean(0).numpy()
plt.plot(
    x_test,
    mfbnn_mean_curve,
    label="Mean prediction",
    ls="--",
    style="pred",
    alpha=1.0,
    zorder=0,
)

plt.scatter(x, y, style="train", label="Data", zorder=1)
wbml.plot.tweak()
plt.show()

# ConvCNP

In [None]:
# Discretisation density for the ConvCNP models, in points per unit of input.
# It is passed as `granularity` and also scales the encoder/decoder
# lengthscales of the small model below.
gran = 64

In [None]:
# Small ConvCNP with an explicit list of CNN channels; this is the model that
# is trained below.
convcnp_s = convcnp.ConvCNP(
    x_dim=1,
    y_dim=1,
    embedded_dim=16,
    likelihood=HeteroscedasticNormalLikelihood(),
    cnn_chans=[3, 16, 32, 16],
    kernel_size=5,
    granularity=gran,  # discretized points per unit
    encoder_lengthscale=10.0 / gran,
    decoder_lengthscale=1.0 / gran,
)

print("Small ConvCNP parameters: ", sum(p.numel() for p in convcnp_s.parameters()))

# Larger U-Net variant. In this cell it is only instantiated to report its
# parameter count; it is not passed to the training call.
convcnp_xl = convcnp.ConvCNP(
    x_dim=1,
    y_dim=1,
    embedded_dim=128,
    likelihood=HeteroscedasticNormalLikelihood(),
    granularity=gran,
    kernel_size=5,
    unet=True,
    num_unet_layers=12,
    unet_starting_chans=16,  # set to 64 for ConvCNP paper version
)

print("ConvCNP XL parameters: ", sum(p.numel() for p in convcnp_xl.parameters()))

# The returned tracker is used below as a mapping from metric name to the
# sequence of values recorded during training.
convcnp_tracker = utils.training_utils.train_metamodel(
    convcnp_s,
    meta_datasets["se"],
    loss_fn="npml_loss",
    lr=1e-3,
    max_iters=10_000,
    batch_size=5,
    min_es_iters=1_000,
    ref_es_iters=500,
    smooth_es_iters=200,
)

# One panel per tracked metric. `squeeze=False` keeps `axes` a 2-D array even
# when only one metric is tracked; with the default, plt.subplots returns a bare
# Axes object in that case and the zip below would raise.
num_metrics = len(convcnp_tracker.keys())
fig, axes = plt.subplots(
    num_metrics,
    1,
    figsize=(8, num_metrics * 4),
    dpi=100,
    sharex=True,
    squeeze=False,
)

for ax, (key, vals) in zip(axes[:, 0], convcnp_tracker.items()):
    ax.plot(vals)
    ax.set_ylabel(key)
    ax.grid(True)
    # wbml.plot.tweak() is called without an axis and so styles the *current*
    # axes; select this panel first, otherwise only the most recently created
    # panel would ever be tweaked.
    plt.sca(ax)
    wbml.plot.tweak()

plt.show()

In [None]:
# 500 evenly spaced inputs on [-2.25, 2.25], shaped (500, 1), used as the
# prediction grid for the ConvCNP plot.
# NOTE(review): this repeats the `x_test` definition from the test-dataset cell.
x_test = torch.linspace(-2.25, 2.25, 500).unsqueeze(-1)

In [None]:
# Predictive mean and 95% band of the small ConvCNP on the held-out "se" dataset.
x, y = test_datasets["se"]
convcnp_preds = convcnp_s(x, y, x_t=x_test)

# Convert the predictive location and scale to NumPy once and reuse them for
# both the mean curve and the interval bounds.
convcnp_pred_loc = convcnp_preds.loc.detach().numpy()
convcnp_pred_std = convcnp_preds.scale.detach().numpy()

fig = plt.figure(figsize=(8, 3), dpi=200)

plt.plot(
    x_test,
    convcnp_pred_loc,
    style="pred",
    alpha=1.0,
    zorder=1,
    label="Mean prediction",
)

# Mean +/- 1.96 standard deviations: the central 95% interval of a normal
# predictive distribution.
convcnp_pred_upper = convcnp_pred_loc + 1.96 * convcnp_pred_std
convcnp_pred_lower = convcnp_pred_loc - 1.96 * convcnp_pred_std
plt.fill_between(
    x_test.squeeze(),
    convcnp_pred_upper.squeeze(),
    convcnp_pred_lower.squeeze(),
    style="pred",
    alpha=0.3,
    zorder=0,
    label="95% confidence interval",
)

plt.scatter(x, y, style="train", label="Data", zorder=1)
wbml.plot.tweak()
plt.show()

# MLG ConvCNP implementation

In [None]:
# NOTE(review): disabled cell. It would train the `mlg_convcnp` ConvCNP on the
# "se" meta-dataset with the same optimiser/early-stopping arguments as the
# small ConvCNP, then plot the training tracker. Nothing here is executed.

# mlg_convcnp_model = mlg_convcnp.ConvCNP(
#     rho=mlg_convcnp.UNet(),
#     points_per_unit=64,
# )

# mlg_convcnp_tracker = utils.training_utils.train_metamodel(
#     mlg_convcnp_model,
#     meta_datasets["se"],
#     loss_fn="npml_loss",
#     lr=1e-3,
#     max_iters=10_000,
#     batch_size=5,
#     min_es_iters=1_000,
#     ref_es_iters=500,
#     smooth_es_iters=200,
# )

# fig, axes = plt.subplots(
#     len(mlg_convcnp_tracker.keys()),
#     1,
#     figsize=(8, len(mlg_convcnp_tracker.keys()) * 4),
#     dpi=100,
#     sharex=True,
# )

# for ax, (key, vals) in zip(axes, mlg_convcnp_tracker.items()):
#     ax.plot(vals)
#     ax.set_ylabel(key)
#     ax.grid()
#     wbml.plot.tweak()

# plt.show()

In [None]:
# NOTE(review): disabled cell. It would plot the `mlg_convcnp` model's
# predictions on the "se" test dataset; it depends on `mlg_convcnp_model` from
# the (also disabled) training cell.
# NOTE(review): if re-enabled, the mean curve is plotted from
# `mlg_convcnp_preds.loc` without the `squeeze(0)` that is applied when computing
# `mlg_convcnp_pred_loc` / `mlg_convcnp_pred_std`, and those two arrays are then
# never used -- check the shapes before running.

# x, y = test_datasets["se"]
# mlg_convcnp_preds = mlg_convcnp_model(
#     x.unsqueeze(0), y.unsqueeze(0), x_out=x_test.unsqueeze(0)
# )
# mlg_convcnp_pred_loc = mlg_convcnp_preds.loc.squeeze(0).detach().numpy()
# mlg_convcnp_pred_std = mlg_convcnp_preds.scale.squeeze(0).detach().numpy()

# fig = plt.figure(figsize=(8, 3), dpi=200)

# plt.plot(
#     x_test,
#     mlg_convcnp_preds.loc.detach().numpy(),
#     style="pred",
#     alpha=1.0,
#     zorder=1,
#     label="Mean prediction",
# )

# mlg_convcnp_pred_upper = (
#     (mlg_convcnp_preds.loc + 1.96 * mlg_convcnp_preds.scale).detach().numpy()
# )
# mlg_convcnp_pred_lower = (
#     (mlg_convcnp_preds.loc - 1.96 * mlg_convcnp_preds.scale).detach().numpy()
# )
# plt.fill_between(
#     x_test.squeeze(),
#     mlg_convcnp_pred_upper.squeeze(),
#     mlg_convcnp_pred_lower.squeeze(),
#     style="pred",
#     alpha=0.3,
#     zorder=0,
#     label="95% confidence interval",
# )

# plt.scatter(x, y, style="train", label="Data", zorder=1)
# wbml.plot.tweak()
# plt.show()