In [None]:
# Reload imported modules automatically before each cell runs, so edits to
# the local package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import matplotlib.pyplot as plt
import sys
import tqdm.auto as tqdm

# Use double precision as the default dtype for floating-point tensors
# created from here on.
torch.set_default_dtype(torch.float64)

# Put the parent directory on the import path so the `bnn_amort_inf` imports
# below can resolve (presumably the repository root -- confirm). This must stay
# ahead of those imports.
sys.path.append("../")
from bnn_amort_inf.models.bnn import gibnn
from bnn_amort_inf import utils

### Toy dataset: noisy cubic with a gap between two input clusters

In [None]:
# Build the toy regression dataset: y = x^3 + Gaussian noise, with inputs
# drawn from two separate clusters so there is a gap in the middle.

# Seed the global torch RNG so the data (and every later random draw in the
# notebook) is identical after Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

noise_std = torch.tensor(4.0)  # standard deviation of the observation noise
n = 100  # total number of points, split evenly between the two clusters

# torch.rand is uniform on [0, 1), so rand * (-2) - 2 lies in (-4, -2]
# and rand * (-2) + 4 lies in (2, 4].
x_neg = torch.rand(n // 2, 1) * (-2) - 2
x_pos = torch.rand(n // 2, 1) * (-2) + 4

x = torch.cat((x_neg, x_pos), dim=0)
y = x**3 + noise_std * torch.randn_like(x)

# Standardise inputs and targets to zero mean and unit standard deviation.
x = (x - x.mean()) / x.std()
y = (y - y.mean()) / y.std()

dataset = torch.utils.data.TensorDataset(x, y)

In [None]:
# Scatter plot of the standardised training data.
fig, ax = plt.subplots(figsize=(8, 6), dpi=100)

ax.scatter(x, y, marker="x")

ax.set_title("Toy Dataset")
ax.set_xlabel("Input Variable")
ax.set_ylabel("Output Variable")
ax.set_xlim(-2, 2)
ax.set_ylim(-3, 3)
ax.grid()

plt.show()

### Fit a GIBNN to the toy dataset and plot its predictive samples

In [None]:
# Use a random subset of the training inputs as the inducing points handed
# to the model.
num_inducing = 10
rand_perm = torch.randperm(n)[:num_inducing]
inducing_points = x[rand_perm]

# Constructor arguments gathered in one place so they are easy to tweak.
gibnn_config = dict(
    x_dim=1,
    y_dim=1,
    hidden_dims=[20, 20],
    num_inducing=num_inducing,
    inducing_points=inducing_points,
    train_noise=True,
)
gibnn_model = gibnn.GIBNN(**gibnn_config)

In [None]:
# Fit the model; the returned tracker maps metric names to their recorded
# values and is plotted in the following cell.
# NOTE(review): batch_size (128) is larger than the 100-point dataset, so each
# pass presumably uses the full dataset as one batch -- confirm in train_model.
training_options = dict(batch_size=128, lr=1e-2)
gibnn_tracker = utils.training_utils.train_model(
    gibnn_model, dataset, **training_options
)

In [None]:
# Plot every tracked training quantity in its own panel, stacked vertically
# with a shared x-axis.
num_panels = len(gibnn_tracker.keys())

# squeeze=False makes plt.subplots always return a 2-D array of Axes. Without
# it, a tracker with a single key yields a bare Axes object and the zip below
# fails because Axes is not iterable.
fig, axes_grid = plt.subplots(
    num_panels,
    1,
    figsize=(8, num_panels * 4),
    dpi=100,
    sharex=True,
    squeeze=False,
)
axes = axes_grid[:, 0]  # 1-D array of Axes, one per tracked key

for ax, (key, vals) in zip(axes, gibnn_tracker.items()):
    ax.plot(vals)
    ax.set_ylabel(key)
    ax.grid()

plt.show()

In [None]:
# Evaluate the trained model on a dense grid that extends slightly beyond the
# plotted input range.
xs = torch.linspace(-2.5, 2.5, 100).unsqueeze(1)  # shape (100, 1)

# Inference only: skip autograd bookkeeping, since every consumer of
# ys_preds detaches it anyway.
with torch.no_grad():
    # NOTE(review): element 0 of the model output is taken to be the stack of
    # predictive samples, which the plotting cell iterates over -- confirm
    # against GIBNN.forward.
    ys_preds = gibnn_model(xs, num_samples=100)[0]

In [None]:
# Overlay all predictive samples (faint blue lines) on the training data.
pred_fig, pred_ax = plt.subplots(figsize=(8, 6), dpi=100)

# Convert the input grid once instead of on every loop iteration.
xs_np = xs.detach().numpy()

for sample_idx, ys_pred in enumerate(ys_preds):
    pred_ax.plot(
        xs_np,
        ys_pred.detach().numpy().squeeze(-1),
        color="C0",
        linewidth=1.0,
        alpha=0.1,
        # Label only one line so the legend has a single entry for all samples.
        label="predictive samples" if sample_idx == 0 else None,
    )

pred_ax.set_title("Model Prediction Samples")
pred_ax.scatter(
    x, y, marker="x", label="Training Data", color="red", linewidth=0.5
)
pred_ax.legend()
pred_ax.set_xlim(-2.5, 2.5)
pred_ax.set_ylim(-3.5, 3.5)
pred_ax.grid()

plt.show()