In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import matplotlib.pyplot as plt
import sys
import tqdm.auto as tqdm  # NOTE(review): tqdm is not used in the visible cells

# Every tensor created without an explicit dtype below will be float64.
torch.set_default_dtype(torch.float64)

# Make the project package in the parent directory importable.
# NOTE(review): relative to the current working directory, so this assumes the
# notebook is launched from its own folder — confirm.
sys.path.append("../")
from bnn_amort_inf.models.bnn import gibnn
from bnn_amort_inf import utils
from bnn_amort_inf.models.likelihoods.normal import NormalLikelihood

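### Toy dataset

A 1-D regression problem. Inputs are drawn uniformly from the two intervals $[-4, -2]$ and $[2, 4]$, which leaves a gap with no training data around zero. Targets are $y = x^3 + \varepsilon$ with $\varepsilon \sim \mathcal{N}(0, 4^2)$. Inputs and targets are then standardised to zero mean and unit standard deviation.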

In [None]:
# Seed torch's global RNG so the generated data is identical on every run.
# This also fixes later draws from the same RNG, such as the random choice of
# inducing points further down.
seed = 0
torch.manual_seed(seed)

noise_std = torch.tensor(4.0)  # std of the additive Gaussian observation noise
n = 100  # total number of training points

# Split the points between the two clusters. n_pos absorbs the remainder so the
# dataset has exactly n points even when n is odd.
n_neg = n // 2
n_pos = n - n_neg

# torch.rand is uniform on [0, 1), so x_neg lies in (-4, -2] and x_pos in (2, 4].
x_neg = torch.rand(n_neg, 1) * (-2) - 2
x_pos = torch.rand(n_pos, 1) * (-2) + 4

x_raw = torch.cat((x_neg, x_pos), dim=0)
y_raw = x_raw**3 + noise_std * torch.randn_like(x_raw)

# Standardise inputs and targets to zero mean and unit (sample) standard deviation.
x = (x_raw - x_raw.mean()) / x_raw.std()
y = (y_raw - y_raw.mean()) / y_raw.std()

dataset = torch.utils.data.TensorDataset(x, y)

In [None]:
# Scatter plot of the standardised training data.
fig, ax = plt.subplots(figsize=(8, 6), dpi=100)

ax.scatter(x, y, marker="x")
ax.set(
    title="Toy Dataset",
    xlabel="Input Variable",
    ylabel="Output Variable",
    xlim=(-2, 2),
    ylim=(-3, 3),
)
ax.grid()

plt.show()

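### Fit a GIBNN model

Build a GIBNN whose inducing points are a random subset of the training inputs and train it. Then inspect the training curves and the resulting predictive distribution.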

In [None]:
# Number of inducing points. Their locations are a random subset (drawn without
# replacement) of the training inputs.
num_inducing = 30

# Index off len(x) rather than n so the draw stays in bounds whatever the actual
# dataset size. Fail loudly instead of silently selecting fewer than
# num_inducing points, which would disagree with num_inducing passed below.
if num_inducing > len(x):
    raise ValueError(
        f"num_inducing={num_inducing} exceeds the number of training points ({len(x)})"
    )
rand_perm = torch.randperm(len(x))[:num_inducing]
inducing_points = x[rand_perm]

gibnn_model = gibnn.GIBNN(
    x_dim=1,
    y_dim=1,
    hidden_dims=[20, 20],  # presumably two hidden layers of width 20 — confirm
    num_inducing=num_inducing,
    inducing_points=inducing_points,
    # The targets are standardised to unit std, so an initial noise of 1.0 is on
    # the scale of y. train_noise=True presumably makes it learnable — confirm.
    likelihood=NormalLikelihood(noise=1.0, train_noise=True),
    learn_final_layer_mu=True,
    learn_final_layer_prec=True,
)

In [None]:
# Training hyper-parameters, gathered in one place for easy tuning.
# batch_size=128 is larger than the n = 100 training points, so each batch
# presumably covers the whole dataset. The *_es_iters settings presumably
# configure early stopping — confirm against train_model.
train_config = dict(
    batch_size=128,
    num_samples=5,
    lr=1e-2,
    min_es_iters=1_000,
    ref_es_iters=300,
    smooth_es_iters=100,
)

gibnn_tracker = utils.training_utils.train_model(gibnn_model, dataset, **train_config)

In [None]:
# One stacked panel per tracked training metric, sharing the x-axis.
num_metrics = len(gibnn_tracker.keys())
if num_metrics == 0:
    # plt.subplots(0, 1) would raise an opaque error; say what actually went wrong.
    raise ValueError("gibnn_tracker contains no metrics to plot")

# squeeze=False keeps `axes` a 2-D array even when only one metric is tracked.
# Otherwise plt.subplots returns a bare Axes, which cannot be zipped over.
fig, axes = plt.subplots(
    num_metrics,
    1,
    figsize=(8, num_metrics * 4),
    dpi=100,
    sharex=True,
    squeeze=False,
)

for ax, (key, vals) in zip(axes[:, 0], gibnn_tracker.items()):
    ax.plot(vals)
    ax.set_ylabel(key)
    ax.grid()

plt.show()

In [None]:
# Dense evaluation grid on [-2.5, 2.5], shaped (100, 1).
xs = torch.linspace(-2.5, 2.5, 100).unsqueeze(1)

# Inference only: no_grad avoids building an autograd graph for the 100-sample
# forward pass.
with torch.no_grad():
    model_outputs = gibnn_model(xs, num_samples=100)

# NOTE(review): assumes model_outputs[0] has shape (num_samples, len(xs), 1),
# so after squeeze(-1) each row is one sampled prediction — confirm.
ys_preds = model_outputs[0].detach().numpy().squeeze(-1)
ys_pred_mean = ys_preds.mean(0)
ys_preds_std = ys_preds.std(0)  # numpy default ddof=0 (population std)

In [None]:
# True: draw the predictive mean with a ±1.96·std band computed from the samples.
# False: draw the individual predictive samples instead.
plot_distribution = True

xs_grid = xs.squeeze().numpy()  # 1-D copy of the evaluation grid for matplotlib

fig, ax = plt.subplots(figsize=(10, 4), dpi=100)

if not plot_distribution:
    # Every sample is drawn faintly. Only the final one carries a label, so the
    # legend shows a single "Predictive samples" entry.
    for sample_path in ys_preds[:-1]:
        ax.plot(xs_grid, sample_path, color="C0", linewidth=1.0, alpha=0.1)
    ax.plot(
        xs_grid,
        ys_preds[-1],
        color="C0",
        linewidth=1.0,
        alpha=0.1,
        label="Predictive samples",
    )
else:
    band_upper = ys_pred_mean + 1.96 * ys_preds_std
    band_lower = ys_pred_mean - 1.96 * ys_preds_std

    ax.plot(xs_grid, ys_pred_mean, color="C0", label="Predictive mean")
    ax.fill_between(
        xs_grid,
        band_upper,
        band_lower,
        color="C0",
        alpha=0.3,
        label="95% Confidence",
    )

# Overlay the training data on top of the predictive.
ax.scatter(x, y, marker="2", label="Training Data", color="red", linewidth=1.0)
ax.legend(fontsize=15)
ax.set_xlim(-2.5, 2.5)
ax.set_ylim(-3.5, 3.5)
ax.grid()
ax.tick_params(axis="both", labelsize=15)
plt.show()