In [None]:
# Automatically reload edited project modules (e.g. bnn_amort_inf) before each
# cell executes, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import matplotlib.pyplot as plt
# NOTE(review): tqdm is not used in any of the visible cells.
import tqdm.auto as tqdm

# Every floating-point tensor created later without an explicit dtype (data
# tensors below, and presumably the model's parameters) will be float64.
torch.set_default_dtype(torch.float64)

import sys

# Make the project package `bnn_amort_inf` importable; assumes the notebook is
# run from a directory one level below the repository root — TODO confirm.
sys.path.append("../")
from bnn_amort_inf.models.bnn import gibnn
from bnn_amort_inf import utils

# Toy dataset: noisy cubic with a gap in the inputs

In [None]:
# Toy regression problem: y = x^3 + Gaussian noise, with inputs drawn from two
# disjoint intervals [-4, -2] and [2, 4] so there is no data around x = 0.
# Both inputs and outputs are then standardised.
SEED = 0
torch.manual_seed(SEED)  # make the sampled dataset reproducible across runs

noise_std = torch.tensor(4.0)
dataset_size = 100

# Split the points between the two clusters. The positive cluster takes the
# extra point when dataset_size is odd, so len(x) == dataset_size always holds
# and matches the length of the noise vector below.
num_neg = dataset_size // 2
num_pos = dataset_size - num_neg
x_neg = torch.empty(num_neg).uniform_(-4, -2)
x_pos = torch.empty(num_pos).uniform_(2, 4)
x = torch.cat((x_neg, x_pos))

y = x**3 + noise_std * torch.randn(dataset_size)

# Standardise to zero mean / unit std (torch.std uses Bessel's correction).
x = (x - x.mean()) / x.std()
y = (y - y.mean()) / y.std()

plt.scatter(x, y, marker="x")
plt.title("Toy Dataset")
plt.xlabel("Input Variable")
plt.ylabel("Output Variable")
plt.xlim(-2, 2)
plt.ylim(-3, 3)
plt.show()

# Column vectors of shape (dataset_size, 1) for inputs and targets.
dataset = torch.utils.data.TensorDataset(x.unsqueeze(-1), y.unsqueeze(-1))

In [None]:
# Amortised GI-BNN with 1 input and 1 output. The two [20, 20] lists presumably
# give hidden-layer widths for two sub-networks, and train_noise=True presumably
# makes the observation noise learnable — confirm against gibnn.AmortisedGIBNN.
# No optimizer is created here: utils.training_utils.train_model is given its
# own learning rate in the next cell, and an Adam instance built here was never
# passed to it.
amort_model = gibnn.AmortisedGIBNN(1, 1, [20, 20], [20, 20], train_noise=True)
amort_model

In [None]:
# Training hyper-parameters. The dataset has 100 points, so a batch size of 128
# presumably means every batch is the full dataset (full-batch training) —
# confirm against utils.training_utils.train_model.
train_config = {"batch_size": 128, "lr": 1e-2}

# Returns a mapping from metric name to its recorded history (plotted below).
abnn_tracker = utils.training_utils.train_model(amort_model, dataset, **train_config)

In [None]:
# One stacked panel per quantity recorded in abnn_tracker, sharing the x-axis.
num_panels = len(abnn_tracker)
fig, axes = plt.subplots(
    num_panels,
    1,
    figsize=(8, num_panels * 4),
    dpi=100,
    sharex=True,
    # Always return a 2-D array of Axes; with the default squeeze=True a single
    # tracked key yields a bare Axes object and the zip() below would fail.
    squeeze=False,
)

for ax, (key, vals) in zip(axes[:, 0], abnn_tracker.items()):
    ax.plot(vals)
    ax.set_ylabel(key)
    ax.grid()

plt.show()

In [None]:
# Dense grid of test inputs as a column vector of shape (100, 1).
xs = torch.linspace(-2.5, 2.5, 100).unsqueeze(1)

# Condition the amortised model on the full training set and draw 100 predictive
# samples at xs. This is inference only, so skip building an autograd graph.
with torch.no_grad():
    model_output = amort_model(
        x.unsqueeze(1), y.unsqueeze(1), x_test=xs, num_samples=100
    )

# The last element of the output is presumably the test-point predictions.
# After squeezing singleton dims and transposing, rows index test inputs
# (matching xs) and columns index samples, as the plotting cell expects.
prediction_samps = model_output[-1].squeeze().T

In [None]:
# Overlay every predictive sample curve on the training data. Only the last
# curve carries a label, so the legend shows a single "Predictive Sample" entry.
samples_np = prediction_samps.detach().numpy()
sample_line_style = dict(color="blue", linewidth=0.5, alpha=0.15)

fig, ax = plt.subplots(figsize=(10, 7))
ax.plot(xs, samples_np[:, :-1], **sample_line_style)
ax.plot(xs, samples_np[:, -1], label="Predictive Sample", **sample_line_style)
ax.set_title("Model Prediction Samples")
ax.set_xlabel("Input Variable")
ax.set_ylabel("Output Variable")
ax.scatter(x, y, marker="x", label="Training Data", color="red", linewidth=0.5)

ax.legend()
ax.set_xlim(-2.5, 2.5)
ax.set_ylim(-3.5, 3.5)
plt.show()