In [None]:
# Reload edited (project) modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import matplotlib.pyplot as plt
import gpytorch
import tqdm.auto as tqdm
import sys

# All tensors (data, networks, GP) are created in double precision.
torch.set_default_dtype(torch.float64)

# Seed torch's global RNG so stochastic steps below (e.g. torch.randperm for the
# inducing points, GP posterior sampling, network initialisation) are reproducible
# under Restart & Run All.
# NOTE(review): this assumes the project's dataset generators also draw from torch's
# RNG — if they use NumPy/`random`, seed those too.
SEED = 0
torch.manual_seed(SEED)

# Make the project package importable (path is relative to the working directory).
sys.path.append("../")
from bnn_amort_inf.models.bnn import gibnn
from bnn_amort_inf.models import gp
from bnn_amort_inf import utils

### Generate a meta-dataset for the amortised BNN

In [None]:
num_datasets = 1000

# One synthetic regression task per entry (presumably a GP draw, given the generator's
# name), all with observation noise 0.06; the collection is wrapped for meta-training.
train_datasets = [
    utils.gp_datasets.gp_dataset_generator(noise=0.06) for _ in range(num_datasets)
]
meta_dataset = utils.dataset_utils.MetaDataset(train_datasets)

In [None]:
# Hyperparameters of the amortised GI-BNN: 1-D inputs/outputs, two hidden layers of
# width 20, initial noise 0.1 with train_noise=True.
# NOTE(review): `in_hidden_dims` presumably sizes the amortised inference network —
# confirm against gibnn.AmortisedGIBNN.
amort_config = dict(
    x_dim=1,
    y_dim=1,
    hidden_dims=[20] * 2,
    in_hidden_dims=[20] * 2,
    noise=0.1,
    train_noise=True,
)
amort_model = gibnn.AmortisedGIBNN(**amort_config)

In [None]:
# Meta-train the amortised model with train_metamodel's default settings. The result
# is dict-like, mapping each logged quantity to its sequence of values.
agibnn_tracker = utils.training_utils.train_metamodel(amort_model, meta_dataset)

In [None]:
# One stacked panel per quantity logged during meta-training.
# squeeze=False keeps `axes` a 2-D array even when the tracker has a single key;
# otherwise plt.subplots returns a bare Axes and zip() below would fail on it.
fig, axes = plt.subplots(
    len(agibnn_tracker.keys()),
    1,
    figsize=(8, len(agibnn_tracker.keys()) * 4),
    dpi=100,
    sharex=True,
    squeeze=False,
)

for ax, (key, vals) in zip(axes[:, 0], agibnn_tracker.items()):
    ax.plot(vals)
    ax.set_ylabel(key)
    ax.grid()

plt.show()

### Generate a test dataset and compare the models on it

In [None]:
# Dense evaluation grid, slightly wider than the data range, shaped (100, 1).
xs = torch.linspace(-2.5, 2.5, 100).unsqueeze(1)

# A single held-out task: inputs in [-2, 2], presumably 20-30 points (min_n/max_n).
x, y = utils.gp_datasets.gp_dataset_generator(
    x_min=-2.0, x_max=2.0, min_n=20, max_n=30, noise=0.06
)
dataset = torch.utils.data.TensorDataset(x, y)

# The amortised model takes the task's (x, y) directly and returns predictions at xs.
# Its last output element holds the 100 predictive samples; after squeeze/.T each
# column is one sample path over xs.
amort_outputs = amort_model(x, y, x_test=xs, num_samples=100)
prediction_samps = amort_outputs[-1].squeeze().T

In [None]:
# Baseline: a GI-BNN fitted directly to this single test task (no amortisation).
# A random half of the training inputs is passed in as inducing points.
dataset_size = len(x)
num_induce = dataset_size // 2
rand_perm = torch.randperm(dataset_size)[:num_induce]
inducing_points = x[rand_perm]

# NOTE(review): positional args presumably mean (x_dim, y_dim, hidden_dims,
# num_inducing, inducing_points), mirroring AmortisedGIBNN's keywords — confirm.
gip_model = gibnn.GIBNN(
    1,
    1,
    [20, 20],
    num_induce,
    inducing_points,
    train_noise=True,
)

gibnn_tracker = utils.training_utils.train_model(
    gip_model,
    dataset,
    batch_size=128,
    lr=1e-2,
)

# One stacked panel per tracked quantity. squeeze=False keeps `axes` 2-D even when
# the tracker has a single key, so zip() always receives a sequence of Axes.
fig, axes = plt.subplots(
    len(gibnn_tracker.keys()),
    1,
    figsize=(8, len(gibnn_tracker.keys()) * 4),
    dpi=100,
    sharex=True,
    squeeze=False,
)

for ax, (key, vals) in zip(axes[:, 0], gibnn_tracker.items()):
    ax.plot(vals)
    ax.set_ylabel(key)
    ax.grid()

plt.show()

In [None]:
# 100 predictive samples from the fitted GI-BNN on the test grid. The first output
# element holds the samples; dropping the trailing (presumably size-1 output) axis
# and transposing gives one sample path per column.
gip_outputs = gip_model(xs, num_samples=100)
gip_prediction_samps = gip_outputs[0].squeeze(-1).T

In [None]:
# Exact-GP baseline on the same test task, fitted by maximising the marginal likelihood.
gp_num_epochs = 200

likelihood = gpytorch.likelihoods.GaussianLikelihood()
# Flatten inputs/targets once for the 1-D GP. squeeze(-1) drops only the trailing
# axis, so (unlike a bare squeeze()) a single-point task would not collapse to 0-D.
gp_train_x = x.squeeze(-1)
gp_train_y = y.squeeze(-1)
gp_model = gp.GPModel(gp_train_x, gp_train_y, likelihood)

gp_model.train()
likelihood.train()
# NOTE(review): presumably GPModel is an ExactGP that registers `likelihood` as a
# submodule, so gp_model.parameters() also includes the noise — confirm.
opt = torch.optim.Adam(gp_model.parameters(), lr=1e-1)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, gp_model)
mll_evo = []

epoch_iter = tqdm.tqdm(range(gp_num_epochs), "Epoch")
for _ in epoch_iter:
    opt.zero_grad()
    output = gp_model(gp_train_x)
    loss = -mll(output, gp_train_y).sum()
    loss_value = loss.item()
    mll_evo.append(loss_value)
    loss.backward()
    opt.step()
    # Pass a plain float: tqdm str()-formats non-numeric values, so a tensor would show
    # up as "tensor(..., grad_fn=...)" in the progress bar.
    epoch_iter.set_postfix({"loss": loss_value})

fig, ax = plt.subplots()
ax.plot(mll_evo)
ax.set_ylabel("loss")
ax.set_xlabel("epoch")
# The plotted quantity is the minimised loss, i.e. the *negative* marginal log likelihood.
ax.set_title("Negative Marginal Log Likelihood")
plt.show()

In [None]:
# Switch both modules out of training mode, mirroring the .train() calls used for fitting.
gp_model.eval()
likelihood.eval()

# Draw 100 joint samples of the GP posterior at the test grid. gp_model(xs) is the
# latent-function posterior, so the Gaussian likelihood's noise is not included.
# No gradients are needed for inference, so skip building an autograd graph.
with torch.no_grad():
    gp_posterior = gp_model(xs)
    # .sample gives (num_samples, num_points); .T puts one sample path per column.
    gp_prediction_samps = gp_posterior.sample(torch.Size([100])).T

In [None]:
def _plot_prediction_panel(ax, samples, title, x_test, x_train, y_train):
    """Draw one model's predictive sample paths and the training data on ``ax``.

    Args:
        ax: Matplotlib Axes to draw on.
        samples: Tensor of sample paths with one row per test input and one column
            per sample (the layout produced by the ``.T`` in the prediction cells).
        title: Panel title.
        x_test: Test inputs the samples were evaluated at.
        x_train: Training inputs, shown as a scatter overlay.
        y_train: Training targets, shown as a scatter overlay.
    """
    paths = samples.detach().numpy()
    line_style = dict(color="blue", linewidth=0.5, alpha=0.15)
    # Label only the final path so the legend shows a single "Predictive Sample"
    # entry instead of one entry per sample.
    ax.plot(x_test, paths[:, :-1], **line_style)
    ax.plot(x_test, paths[:, -1], label="Predictive Sample", **line_style)
    ax.scatter(
        x_train,
        y_train,
        marker="x",
        label="Training Data",
        color="red",
        linewidth=0.5,
    )
    ax.set_title(title)
    ax.set_xlabel("Input Variable")
    ax.set_ylabel("Output Variable")
    ax.legend()
    ax.set_ylim(-4.0, 4.0)
    ax.set_xlim(-2.5, 2.5)


# Side-by-side comparison of the three models on the same test task; all panels use
# identical axis limits so the predictive spreads are directly comparable.
prediction_panels = [
    (prediction_samps, "Amortised Model Prediction Samples"),
    (gip_prediction_samps, "Global Inducing Point Model Prediction Samples"),
    (gp_prediction_samps, "Gaussian Process Model Prediction Samples"),
]

fig, axs = plt.subplots(1, len(prediction_panels), figsize=(18, 5))
for ax, (samples, title) in zip(axs, prediction_panels):
    _plot_prediction_panel(ax, samples, title, xs, x, y)

plt.show()