In [None]:
# Reload edited Python modules (e.g. the local `bnn_amort_inf` package) before each cell runs,
# so changes to the project source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys

import matplotlib.pyplot as plt
import torch
import tqdm.auto as tqdm

# Seed torch's global RNG once, up front. The toy dataset, the model's weight
# initialisation and any sampling done during training/prediction all draw from it,
# so Restart & Run All reproduces the same results.
SEED = 0
torch.manual_seed(SEED)

# Work in float64 throughout. This is set before the project import in case that
# module creates tensors at import time.
torch.set_default_dtype(torch.float64)

# Make the in-repo `bnn_amort_inf` package (one directory up) importable.
sys.path.append("../")
from bnn_amort_inf.models.gibnn import amortised_gibnn

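# Toy dataset

A noisy cubic, $y = x^3 + \epsilon$ with $\epsilon \sim \mathcal{N}(0, 4^2)$. Inputs are drawn uniformly from $[-4, -2] \cup [2, 4]$, so there is no data around $x = 0$. Both $x$ and $y$ are standardised before training.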

In [None]:
noise_std = torch.tensor(4.0)
dataset_size = 100

# Inputs come from two clusters, [-4, -2] and [2, 4], leaving a gap around zero.
# The two halves always sum to `dataset_size`. With `dataset_size // 2` for both,
# an odd size left `x` one element shorter than the noise vector, and the sum failed.
n_neg = dataset_size // 2
n_pos = dataset_size - n_neg
x_neg = torch.empty(n_neg).uniform_(-4, -2)
x_pos = torch.empty(n_pos).uniform_(2, 4)
x = torch.cat((x_neg, x_pos))

# Cubic trend plus Gaussian noise with standard deviation `noise_std`.
# The noise is shaped like `x`, so it always lines up element-wise.
y = x**3 + noise_std * torch.normal(torch.zeros_like(x), torch.ones_like(x))

# Standardise inputs and outputs to zero mean and unit (sample) standard deviation.
x = (x - x.mean()) / x.std()
y = (y - y.mean()) / y.std()

fig, ax = plt.subplots()
ax.scatter(x, y, marker="x")
ax.set_title("Toy Dataset")
ax.set_xlabel("Input Variable")
ax.set_ylabel("Output Variable")
ax.set_xlim(-2, 2)
ax.set_ylim(-3, 3)
plt.show()

In [None]:
# Build the amortised GI-BNN and its optimiser.
# NOTE(review): the positional arguments presumably are input dim, output dim and two
# lists of hidden-layer widths. Confirm against AmortisedGIBNN's signature.
amort_model = amortised_gibnn.AmortisedGIBNN(
    1, 1, [20, 20], [20, 20], train_noise=True
)

learning_rate = 1e-2
opt = torch.optim.Adam(amort_model.parameters(), lr=learning_rate)

print(amort_model)

In [None]:
# Per-epoch training history, stored as plain Python floats.
# Note: re-running this cell keeps training the existing `amort_model`/`opt` from
# their current state; only the history lists are reset.
loss_evo = []
ll_evo = []
kl_evo = []

num_epochs = 1000
epoch_iter = tqdm.tqdm(range(num_epochs), "Epoch")
for _ in epoch_iter:
    opt.zero_grad()

    loss, metrics = amort_model.loss(x.unsqueeze(1), y.unsqueeze(1), num_samples=1)

    # Record detached floats. Appending raw tensors would keep any autograd graph
    # they carry alive across epochs, and would show tensor reprs (with grad_fn)
    # in the progress bar.
    # NOTE(review): assumes metrics["exp_ll"] and metrics["kl"] are scalars
    # (Python numbers or one-element tensors). Confirm against the model.
    loss_val = loss.item()
    ll_val = float(metrics["exp_ll"])
    kl_val = float(metrics["kl"])
    loss_evo.append(loss_val)
    ll_evo.append(ll_val)
    kl_evo.append(kl_val)

    loss.backward()
    opt.step()

    epoch_iter.set_postfix({"loss": loss_val, "ll": ll_val, "kl": kl_val})

In [None]:
# Draw one figure per recorded training quantity, all against the epoch index.
curve_specs = (
    (loss_evo, "ELBO loss", "Loss Evolution"),
    (ll_evo, "expected log likelihood", "Expected Log Likelihood Evolution"),
    (kl_evo, "kl", "KL Evolution"),
)
for history, y_label, fig_title in curve_specs:
    fig, ax = plt.subplots()
    ax.plot(history)
    ax.set_ylabel(y_label)
    ax.set_xlabel("epoch")
    ax.set_title(fig_title)
    plt.show()

In [None]:
# 100 evenly spaced test inputs on [-2.5, 2.5], as a column of shape (100, 1).
xs = torch.linspace(-2.5, 2.5, 100).unsqueeze(1)

# Inference only: no_grad avoids building an autograd graph for all 100 samples.
with torch.no_grad():
    model_outputs = amort_model(
        x.unsqueeze(1), y.unsqueeze(1), x_test=xs, num_samples=100
    )

# NOTE(review): assumes the last returned element holds the predictive samples at
# `xs`, shaped (num_samples, len(xs), 1). Squeeze + transpose then gives one row per
# test input and one column per sample. Confirm against AmortisedGIBNN.forward.
prediction_samps = model_outputs[-1].squeeze().T

In [None]:
# The plotting cell below draws `xs` against the columns of `prediction_samps`,
# so the rows must line up with the test inputs.
assert prediction_samps.shape[0] == xs.shape[0], "expected one row per test input"
prediction_samps.shape

In [None]:
# Overlay the predictive sample paths on the training data. All but the last sample
# are drawn unlabelled, so the legend shows a single "Predictive Sample" entry.
samples_np = prediction_samps.detach().numpy()
line_style = dict(color="blue", linewidth=0.5, alpha=0.15)

fig, ax = plt.subplots(figsize=(10, 7))
ax.plot(xs, samples_np[:, :-1], **line_style)
ax.plot(xs, samples_np[:, -1], label="Predictive Sample", **line_style)
ax.set(
    title="Model Prediction Samples",
    xlabel="Input Variable",
    ylabel="Output Variable",
)
ax.scatter(x, y, marker="x", label="Training Data", color="red", linewidth=0.5)

ax.legend()
ax.set_xlim(-2.5, 2.5)
ax.set_ylim(-3.5, 3.5)
plt.show()