# Experiment 2
### In this experiment we compare different metamodels on an image completion task framed as a 2D regression task 
We compare the amortised GI BNN against two baseline models: a convolutional conditional neural process (ConvCNP) and an amortised MFVI BNN. Each dataset is treated as a meta-dataset: every image is subsampled and treated as its own 2D regression dataset, mapping pixel coordinates to pixel values. The following datasets will be considered:
- MNIST
- CelebA, including an out-of-distribution qualitative evaluation on the Ellen Oscars Selfie
- CIFAR10

For each test case, a linear interpolator will be used as a lower-bound benchmark.


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torchvision
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import wbml.experiment
import wbml.plot
from tqdm import tqdm
import sys

torch.set_default_dtype(torch.float64)

sys.path.append("../")
import bnn_amort_inf
from bnn_amort_inf.models.likelihoods.normal import (
    NormalLikelihood,
    HeteroscedasticNormalLikelihood,
)
from bnn_amort_inf.models.likelihoods.bernoulli import BernoulliLikelihood
from bnn_amort_inf.models.bnn import gibnn, mfvi_bnn
from bnn_amort_inf.models.np.convcnp import GridConvCNP
from bnn_amort_inf import utils

torch.set_default_dtype(torch.float64)

In [None]:
# Working directory for this experiment's artefacts, rooted at ./experiment_two
# (the MNIST download in the data cell is stored under `wd.root`).
# NOTE(review): `log=None` / `override=False` are passed straight through to
# wbml.experiment.WorkingDirectory — see that class for their exact meaning.
wd = wbml.experiment.WorkingDirectory("./experiment_two", log=None, override=False)

### Handle Data and Generate Metadatasets

In [None]:
num_datasets = 9000  # note cannot be greater than 9000
meta_datasets = {}
ratio = 0.25  # proportion of context pixels in training images

# MNIST training split; ToTensor maps each PIL image to a tensor in [0, 1].
mnist = datasets.MNIST(
    root=wd.root,
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
# Shuffled stream of single-image batches. Later cells keep drawing from this
# same iterator, so it must remain a module-level name.
mnist_iter = iter(torch.utils.data.DataLoader(mnist, shuffle=True))


def _next_mnist_task(image_stream, context_ratio):
    """Draw one image from `image_stream` and build a regression task from it.

    Returns a ``(grayscale, binarised, context_mask)`` triple, where the
    binarised image rounds every pixel and the mask comes from
    ``utils.dataset_utils.random_mask(context_ratio, grayscale)``.
    """
    grey = next(image_stream)[0].squeeze(0)  # drop the batch dimension
    binary = grey.round()  # convert to binary
    ctx_mask = utils.dataset_utils.random_mask(context_ratio, grey)[1]
    return grey, binary, ctx_mask


m_d = [_next_mnist_task(mnist_iter, ratio) for _ in range(num_datasets)]
meta_datasets["mnist"] = utils.dataset_utils.MetaDataset(m_d)

### Define and Train Models

In [None]:
# When True, each training cell below plots every metric recorded in the
# tracker returned by utils.training_utils.train_metamodel.
plot_training_metrics = True

In [None]:
convcnp_lik = BernoulliLikelihood()

# Gridded ConvCNP baseline: 2D pixel coordinates in, one output channel out.
convcnp = GridConvCNP(
    x_dim=2,
    y_dim=1,
    embedded_dim=256,
    cnn_chans=[128, 128, 128],
    conv_kernel_size=5,
    cnn_kernel_size=3,
    res=True,
    likelihood=convcnp_lik,
)

# Meta-train on the MNIST meta-dataset. The *_es_iters / es arguments configure
# early stopping (semantics defined in utils.training_utils.train_metamodel).
convcnp_tracker = utils.training_utils.train_metamodel(
    convcnp,
    meta_datasets["mnist"],
    image=True,
    gridconv=True,
    lr=5e-4,
    max_iters=5_000,
    batch_size=12,
    min_es_iters=1000,
    ref_es_iters=400,
    smooth_es_iters=100,
    es=True,
    # man_thresh=(
    #     "ll",
    #     2000.0,
    # ),  # manual threshold that only stops training loop if ll>man_thresh... very hacky but fine for now
)

num_convcnp_metrics = len(convcnp_tracker.keys())
if plot_training_metrics and num_convcnp_metrics > 0:
    # One stacked panel per tracked metric. squeeze=False keeps `axes` a 2D
    # array even when only a single metric is tracked; otherwise subplots()
    # returns a bare Axes and the zip below would fail to iterate over it.
    fig, axes = plt.subplots(
        num_convcnp_metrics,
        1,
        figsize=(8, num_convcnp_metrics * 4),
        dpi=100,
        sharex=True,
        squeeze=False,
    )

    for ax, (key, vals) in zip(axes[:, 0], convcnp_tracker.items()):
        ax.plot(vals)
        ax.set_ylabel(key)
        ax.grid()

    plt.show()

In [None]:
agibnn_lik = BernoulliLikelihood()

# Amortised GI BNN: 2D pixel coordinates in, one output channel out.
agibnn = gibnn.AmortisedGIBNN(
    x_dim=2,
    y_dim=1,
    hidden_dims=[64, 64],
    in_hidden_dims=[100, 100],
    likelihood=agibnn_lik,
)

# Meta-train on the MNIST meta-dataset. The *_es_iters / es arguments configure
# early stopping (semantics defined in utils.training_utils.train_metamodel).
agibnn_tracker = utils.training_utils.train_metamodel(
    agibnn,
    meta_datasets["mnist"],
    image=True,
    loss_fn="npml_loss",  # options are "npml_loss", "npvi_loss", "loss"
    lr=1e-3,
    max_iters=10_000,
    batch_size=3,
    min_es_iters=500,
    ref_es_iters=200,
    smooth_es_iters=100,
    es=True,
    # man_thresh=("elbo", -160),
)

num_agibnn_metrics = len(agibnn_tracker.keys())
if plot_training_metrics and num_agibnn_metrics > 0:
    # One stacked panel per tracked metric. squeeze=False keeps `axes` a 2D
    # array even when only a single metric is tracked; otherwise subplots()
    # returns a bare Axes and the zip below would fail to iterate over it.
    fig, axes = plt.subplots(
        num_agibnn_metrics,
        1,
        figsize=(8, num_agibnn_metrics * 4),
        dpi=100,
        sharex=True,
        squeeze=False,
    )

    for ax, (key, vals) in zip(axes[:, 0], agibnn_tracker.items()):
        ax.plot(vals)
        ax.set_ylabel(key)
        ax.grid()

    plt.show()

### Generate Test Datasets

In [None]:
num_test = 5
# Proportion of context pixels in the test images. Defaults to the training
# `ratio` so evaluation is in-distribution; set a different value here to probe
# sparser or denser contexts instead of editing a magic number in the loop.
test_ratio = ratio

test_datasets = []
for _ in range(num_test):
    # Keep drawing from the same shuffled MNIST iterator used for meta-training,
    # so (within one epoch) the test images are distinct from training images.
    img = next(mnist_iter)[0].squeeze(0)
    bin_img = img.round()
    mask = utils.dataset_utils.random_mask(test_ratio, img)[1]
    test_datasets.append((img, bin_img, mask))

In [None]:
plot_std = True  # also plot the per-pixel predictive std of both models
round_preds = False  # round predicted means before plotting
round_orig_img = False  # round the ground-truth image before plotting

for dataset in test_datasets:
    img, bin_img, mask = dataset
    # Reshape predictions to the spatial size of the input image instead of a
    # hard-coded 28x28; the trailing -1 infers the channel count (1 for MNIST,
    # i.e. identical to the previous (28, 28, 1)).
    out_shape = (*img.shape[-2:], -1)

    x_c, y_c, x_t, y_t = utils.dataset_utils.img_for_reg(bin_img, mask)
    ctx_img = utils.dataset_utils.vis_ctxt_img(mask, bin_img)
    x_test = utils.dataset_utils.test_grid(img.shape[-2:])

    # Pure inference: no autograd graph is needed for either model.
    with torch.no_grad():
        # ConvCNP works directly on the gridded (image, mask) pair.
        convcnp_preds = convcnp(bin_img, mask)
        if isinstance(convcnp_lik, BernoulliLikelihood):
            # Monte Carlo mean/std from 1000 predictive samples.
            # NOTE(review): if `convcnp_preds` is a Bernoulli distribution this
            # std reflects per-pixel observation noise, whereas the GI-BNN std
            # below is a spread across model samples — not directly comparable.
            convcnp_pred_samps = convcnp_preds.sample(
                sample_shape=torch.Size([1000])
            ).detach()
            convcnp_pred_img = convcnp_pred_samps.mean(0).reshape(out_shape)
            convcnp_pred_std = convcnp_pred_samps.std(0).reshape(out_shape)
        else:
            convcnp_pred_img = convcnp_preds.loc.detach().reshape(out_shape).numpy()
            convcnp_pred_std = (
                convcnp_preds.scale.detach().reshape(out_shape).numpy()
            )

        # Amortised GI BNN: condition on the context points and predict on the
        # full pixel grid with num_samples=10 draws.
        preds = agibnn(x_c, y_c, x_test=x_test, num_samples=10)[-1]
        pred_dists = agibnn.likelihood(preds)
        if isinstance(agibnn_lik, BernoulliLikelihood):
            pred_probs = pred_dists.probs.detach()
        else:
            pred_probs = pred_dists.mean.detach()
        # Mean / std over dim 0 (presumably the num_samples dimension).
        pred_img = pred_probs.mean(0).reshape(out_shape)
        pred_std = pred_probs.std(0).reshape(out_shape)

    if round_preds:
        pred_img = pred_img.round()
        convcnp_pred_img = convcnp_pred_img.round()
    if round_orig_img:
        img = img.round()

    # Panels: ground truth, context, GI-BNN mean [+ std], ConvCNP mean [+ std].
    panels = [
        ("original", img.permute(1, 2, 0).numpy(), "gray"),
        ("context", ctx_img, None),
        ("GI-BNN mean", pred_img, "gray"),
    ]
    if plot_std:
        panels.append(("GI-BNN std", pred_std, "viridis"))
    panels.append(("ConvCNP mean", convcnp_pred_img, "gray"))
    if plot_std:
        panels.append(("ConvCNP std", convcnp_pred_std, "viridis"))

    fig, axes = plt.subplots(1, len(panels))
    for ax, (title, image, cmap) in zip(axes, panels):
        ax.imshow(image, cmap=cmap)
        ax.set_title(title, fontsize=8)
        ax.axis(False)

    plt.show()

# ConvCNP

In [None]:
# from functools import partial

# import sys

# sys.path.append("../../Neural-Process-Family")

# from npf import ConvCNP, GridConvCNP, CNPFLoss
# from npf.architectures import CNN, MLP, ResConvBlock, SetConv, discard_ith_arg
# from npf.utils.helpers import CircularPad2d, make_abs_conv, make_padded_conv
# from utils.ntbks_helpers import get_img_datasets
# from utils.helpers import count_parameters
# from utils.data import cntxt_trgt_collate, get_test_upscale_factor
# from npf.utils.datasplit import GridCntxtTrgtGetter, RandomMasker, no_masker

# img_datasets, img_test_datasets = get_img_datasets(["mnist"])

# # Random masker, masking between 0% and 30% of input points.
# get_cntxt_trgt_2d = cntxt_trgt_collate(
#     GridCntxtTrgtGetter(
#         context_masker=RandomMasker(a=0.0, b=0.3),
#         target_masker=no_masker,
#     ),
#     is_return_masks=True,  # will be using grid conv CNP => can work directly with mask
# )

In [None]:
# r_dim = 32
# model_kwargs = dict(
#     r_dim=r_dim,
#     Decoder=discard_ith_arg(  # disregards the target features to be translation equivariant
#         partial(MLP, n_hidden_layers=4, hidden_size=r_dim), i=0
#     ),
# )


# cnn_kwargs = dict(
#     ConvBlock=ResConvBlock,
#     is_chan_last=True,  # all computations are done with channel last in our code
#     n_conv_layers=2,  # layers per block
# )


# # on the grid
# model_2d = partial(
#     GridConvCNP,
#     x_dim=1,  # for gridded conv it's the mask shape
#     CNN=partial(
#         CNN,
#         Conv=torch.nn.Conv2d,
#         Normalization=torch.nn.BatchNorm2d,
#         n_blocks=3,
#         kernel_size=5,
#         **cnn_kwargs,
#     ),
#     y_dim=img_datasets["mnist"].shape[
#         0
#     ],  # seems to just be the number of output channels
#     **model_kwargs,
# )

# n_params_2d = count_parameters(model_2d())
# print(f"Number Parameters (2D): {n_params_2d:,d}")

# Train ConvCNP

In [None]:
# import skorch
# from npf import CNPFLoss
# from utils.ntbks_helpers import add_y_dim, get_img_datasets
# from utils.train import train_models
# from utils.data import cntxt_trgt_collate, get_test_upscale_factor
# from npf.utils.datasplit import GridCntxtTrgtGetter, RandomMasker, no_masker

# img_datasets, img_test_datasets = get_img_datasets(["mnist"])

# train_kwargs = dict(
#     criterion=CNPFLoss,
#     is_retrain=True,
#     device="cpu",
#     lr=1e-3,
#     decay_lr=10,
#     seed=123,
#     batch_size=32,
# )

# # 2D
# trainers_2d = train_models(
#     img_datasets,
#     {"mnist": model_2d},
#     test_datasets=img_test_datasets,
#     train_split=skorch.dataset.CVSplit(0.1),  # use 10% of training for validation
#     iterator_train__collate_fn=get_cntxt_trgt_2d,
#     iterator_valid__collate_fn=get_cntxt_trgt_2d,
#     max_epochs=50,
#     **train_kwargs
# )