# Effect of Amortisation on Acquired ELBO

In [None]:
import sys

# Extend the module search path with the parent directory *before* the project
# import further down (presumably that is where `bnn_amort_inf` lives when the
# notebook is run from its own folder -- confirm).
sys.path.append("../")

# import argparse
# import json
import torch
import matplotlib.pyplot as plt
import wbml.plot
import wbml.experiment

import bnn_amort_inf

# Use float64 as the default floating-point dtype for tensors created from here on.
torch.set_default_dtype(torch.float64)

In [None]:
# ---- Experiment knobs (the values varied between runs) ----
num_datasets = 1  # size of the training metadataset; 0 skips meta-training
seed = 0  # RNG seed applied before the training datasets are generated

# ---- Data generation ----
kernel = "se"
x_min, x_max = -2.0, 2.0  # input range, also used for the prediction grid
n_min_train, n_max_train = 10, 50  # passed as n_min / n_max for training data
n_min_test, n_max_test = 10, 50  # passed as n_min / n_max for test data
num_test_datasets = 5

# ---- Model architecture ----
hidden_dims = [32, 32]
in_hidden_dims = [50, 50]  # only given to the amortised model
min_num_inducing = 20  # inducing points are padded when a dataset is smaller

# ---- Optimisation ----
loss_fn = "loss"
amort_lr, reg_lr = 1e-3, 5e-3  # learning rates: amortised model / regular GIBNN
max_iters = 15_000
batch_size = 5
# The three *_es_iters values are forwarded to the training utilities
# (presumably early-stopping settings -- confirm against training_utils).
min_es_iters = 1000
ref_es_iters = smooth_es_iters = 500

# Whether to also fit a regular (non-amortised) GIBNN on each test dataset.
include_gibnn = True

### Generate Test Datasets (fixed seed)

In [None]:
# The test datasets are generated right after pinning the RNG to a constant
# seed (deliberately not `seed`), so this cell starts from the same RNG state
# regardless of how the training seed is varied.
torch.manual_seed(0)

test_datasets = bnn_amort_inf.utils.dataset_utils.gen_datasets(
    kernel=kernel,
    num_datasets=num_test_datasets,
    n_min=n_min_test,
    n_max=n_max_test,
    x_min=x_min,
    x_max=x_max,
)

### Generate Training Datasets (variable seed)

In [None]:
# Training data depends on the experiment's `seed`, so different seeds give
# different metadatasets while the test datasets stay tied to their own seed.
torch.manual_seed(seed)

train_metadataset = bnn_amort_inf.utils.dataset_utils.gen_datasets(
    kernel=kernel,
    num_datasets=num_datasets,
    n_min=n_min_train,
    n_max=n_max_train,
    x_min=x_min,
    x_max=x_max,
)

### Construct Metamodel

In [None]:
# Amortised model for 1-D inputs and 1-D outputs. The likelihood is created
# with noise=0.05 and train_noise=False (presumably a fixed, non-learned
# observation noise -- confirm against NormalLikelihood).
agibnn = bnn_amort_inf.models.bnn.gibnn.AmortisedGIBNN(
    likelihood=bnn_amort_inf.models.likelihoods.normal.NormalLikelihood(
        train_noise=False, noise=0.05
    ),
    hidden_dims=hidden_dims,
    in_hidden_dims=in_hidden_dims,
    x_dim=1,
    y_dim=1,
)

### Train Metamodel on Metadataset

In [None]:
# Meta-training only happens when there is at least one training dataset; with
# num_datasets == 0 the model is instead fitted per test dataset in the
# evaluation cell.
if num_datasets > 0:
    agibnn_tracker = bnn_amort_inf.utils.training_utils.train_metamodel(
        agibnn,
        dataset=train_metadataset,
        loss_fn=loss_fn,
        lr=amort_lr,
        batch_size=batch_size,
        max_iters=max_iters,
        min_es_iters=min_es_iters,
        ref_es_iters=ref_es_iters,
        smooth_es_iters=smooth_es_iters,
    )

    # One small figure per quantity recorded by the training tracker.
    for metric_name, metric_values in agibnn_tracker.items():
        plt.figure(figsize=(4, 1.5), dpi=200)
        plt.plot(metric_values)
        plt.ylabel(metric_name)
        wbml.plot.tweak()
        plt.show()

### Evaluate on Test Datasets

In [None]:
# Evaluate the amortised BNN (and optionally a regular GIBNN) on each test
# dataset: report an ELBO estimate for both and plot their predictions.


def _plot_tracker(tracker):
    """Show one small figure per quantity recorded in a training tracker.

    Args:
        tracker: Mapping from a metric name to the sequence of values to plot.
    """
    for name, values in tracker.items():
        plt.figure(figsize=(4, 1.5), dpi=200)
        plt.plot(values)
        plt.ylabel(name)
        wbml.plot.tweak()
        plt.show()


def _plot_prediction(x_plot, pred_loc, pred_std, *, ls="-", label=None):
    """Draw a predictive mean and its +/- 1.96 standard deviation band.

    Nothing is drawn when ``pred_loc`` is None; the band is skipped when
    ``pred_std`` is None.

    Args:
        x_plot: Inputs at which the prediction was made.
        pred_loc: Predictive mean at ``x_plot`` (or None).
        pred_std: Predictive standard deviation at ``x_plot`` (or None).
        ls: Line style of the mean, used to tell several models apart.
        label: Label attached to the mean line for any legend drawn later.
    """
    if pred_loc is None:
        return

    plt.plot(x_plot, pred_loc, style="pred", ls=ls, lw=2, zorder=1, label=label)

    if pred_std is not None:
        pred_upper = pred_loc + 1.96 * pred_std
        pred_lower = pred_loc - 1.96 * pred_std
        plt.fill_between(
            x_plot.squeeze(),
            pred_upper.squeeze(),
            pred_lower.squeeze(),
            style="pred",
            alpha=0.2,
            zorder=1,
        )


# Dense input grid on which both models are asked for predictions.
x_test = torch.linspace(x_min, x_max, 1000).unsqueeze(-1)
for i, dataset in enumerate(test_datasets):
    x, y = dataset

    # Train metamodel on just test datasets if num_datasets == 0
    if num_datasets == 0:
        # A fresh model per test dataset, so nothing carries over between them.
        agibnn = bnn_amort_inf.models.bnn.gibnn.AmortisedGIBNN(
            x_dim=1,
            y_dim=1,
            hidden_dims=hidden_dims,
            in_hidden_dims=in_hidden_dims,
            likelihood=bnn_amort_inf.models.likelihoods.normal.NormalLikelihood(
                noise=0.05, train_noise=False
            ),
        )

        agibnn_tracker = bnn_amort_inf.utils.training_utils.train_model(
            agibnn,
            dataset=torch.utils.data.TensorDataset(x, y),
            batch_size=x.shape[0],
            lr=amort_lr,
            max_iters=max_iters,
            min_es_iters=min_es_iters,
            ref_es_iters=ref_es_iters,
            smooth_es_iters=smooth_es_iters,
        )

        _plot_tracker(agibnn_tracker)

    # Evaluate amortised BNN on test datasets
    with torch.no_grad():
        # The last element of the model output is used as the predictive
        # samples (presumably shaped (num_samples, ...) -- confirm).
        agibnn_pred_samples = agibnn(x, y, x_test=x_test, num_samples=100)[-1]

        # The negated loss is scaled by the number of data points; this assumes
        # the loss is a per-datapoint negative ELBO -- confirm against the model.
        agibnn_loss, _ = agibnn.loss(x, y, num_samples=100)
        agibnn_elbo = (-agibnn_loss) * x.shape[0]
        agibnn_metrics = {"elbo": agibnn_elbo.item()}

    # Summarise the samples by their mean and standard deviation over dim 0.
    agibnn_pred_loc = agibnn_pred_samples.mean(0)
    agibnn_pred_std = agibnn_pred_samples.std(0)

    # Train regular BNN on test datasets
    if include_gibnn:
        # Inducing points start at the observed inputs ...
        inducing_points = x.clone()

        # ... and are padded with evenly spaced inputs whenever the dataset has
        # fewer than `min_num_inducing` points.
        if min_num_inducing > x.shape[0]:
            inducing_points = torch.cat(
                [
                    inducing_points,
                    torch.linspace(
                        x_min, x_max, min_num_inducing - x.shape[0]
                    ).unsqueeze(-1),
                ],
                dim=0,
            )

        gibnn = bnn_amort_inf.models.bnn.gibnn.GIBNN(
            x_dim=1,
            y_dim=1,
            hidden_dims=hidden_dims,
            num_inducing=len(inducing_points),
            inducing_points=inducing_points,
            likelihood=bnn_amort_inf.models.likelihoods.normal.NormalLikelihood(
                noise=0.05, train_noise=False
            ),
            final_layer_prec=1e0,
        )

        gibnn_tracker = bnn_amort_inf.utils.training_utils.train_model(
            gibnn,
            dataset=torch.utils.data.TensorDataset(x, y),
            batch_size=x.shape[0],
            lr=reg_lr,
            max_iters=max_iters,
            min_es_iters=min_es_iters,
            ref_es_iters=ref_es_iters,
            smooth_es_iters=smooth_es_iters,
        )

        _plot_tracker(gibnn_tracker)

        with torch.no_grad():
            # Same ELBO scaling as for the amortised model above.
            gibnn_loss, _ = gibnn.loss(x, y, num_samples=100)
            gibnn_elbo = (-gibnn_loss) * x.shape[0]
            gibnn_metrics = {"elbo": gibnn_elbo.item()}

            # The first element of the model output is used as the predictive
            # samples (presumably shaped (num_samples, ...) -- confirm).
            gibnn_pred_samples = gibnn(x_test, num_samples=100)[0]

        gibnn_pred_loc = gibnn_pred_samples.mean(0)
        gibnn_pred_std = gibnn_pred_samples.std(0)
    else:
        gibnn_metrics = None
        gibnn_pred_loc = None
        gibnn_pred_std = None

    print(
        "kernel: ",
        kernel,
        ", test dataset: {}".format(i),
        "gibnn metrics: ",
        gibnn_metrics,
        "agibnn_metrics: ",
        agibnn_metrics,
    )

    # Plot the data together with both models' predictions.
    plt.figure(figsize=(4, 1.5), dpi=200)
    plt.scatter(x, y, style="train", zorder=2, s=10)

    # Both models share the "pred" style, so the regular GIBNN is drawn dashed
    # and both lines are labelled to keep the two predictions distinguishable.
    _plot_prediction(x_test, agibnn_pred_loc, agibnn_pred_std, ls="-", label="AGIBNN")
    _plot_prediction(x_test, gibnn_pred_loc, gibnn_pred_std, ls="--", label="GIBNN")

    plt.ylim([-2.5, 2.5])
    wbml.plot.tweak()
    plt.show()