# Experiment 2
### In this experiment, we compare different metamodels on an image-completion task framed as 2D regression
We compare the amortised GI BNN against two baseline models: a convolutional conditional neural process (ConvCNP) and an amortised MFVI BNN. Each image dataset is treated as a metadataset: every image is subsampled and treated as its own 2D regression dataset, mapping 2D pixel coordinates to pixel values. The following datasets are considered:
- MNIST
- CelebA, including an out-of-distribution qualitative evaluation on the Ellen Oscars Selfie
- CIFAR10

For each test case, a linear interpolator is used as a lower-bound baseline.


In [None]:
# IPython autoreload: re-import edited modules (e.g. the bnn_amort_inf package)
# before each cell executes, so library changes show up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torchvision
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
from tqdm import tqdm
import sys

# Use double precision for newly created floating-point tensors (including
# model parameters created after this point) throughout the notebook.
torch.set_default_dtype(torch.float64)

# Make the project package importable from the notebook's directory
# (assumes the notebook lives one level below the repository root).
sys.path.append("../")
from bnn_amort_inf.models.bnn import gibnn, mfvi_bnn
from bnn_amort_inf.models.np import GridConvCNP, CNP
from bnn_amort_inf import utils

### Handle Data and Generate Metadatasets

In [None]:
# NOTE(review): the original note says this "cannot be greater than 9000"; the
# reason for that limit is not visible here -- confirm before raising it.
num_datasets = 1000
meta_datasets = {}
ratio = 0.5  # proportion of context pixels in training images

mnist = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
# A single shuffled pass over MNIST, one image per batch (DataLoader's default
# batch_size is 1). Every image drawn from this iterator is distinct, so later
# draws (e.g. test images) never repeat a training image.
mnist_iter = iter(torch.utils.data.DataLoader(mnist, shuffle=True))

# Reference image passed to random_mask to draw every context mask (presumably
# only its shape matters). It is not added to any dataset.
# NOTE(review): img_samp keeps the DataLoader batch dimension, whereas the
# dataset images below are squeeze(0)-ed -- confirm random_mask and the model
# expect a mask of that shape.
img_samp = next(mnist_iter)[0]

# Each task is an (image, context mask) pair; squeeze(0) drops the batch
# dimension added by the DataLoader. random_mask(...)[1] is used as the mask.
meta_datasets["mnist"] = utils.dataset_utils.MetaDataset(
    [
        (
            next(mnist_iter)[0].squeeze(0),
            utils.dataset_utils.random_mask(ratio, img_samp)[1],
        )
        for _ in range(num_datasets)
    ]
)

### Define and Train Models

In [None]:
# Toggle: when set, each training cell below plots its recorded training metrics.
plot_training_metrics: bool = True

In [None]:
# ConvCNP with a UNet backbone on the pixel grid: 2D pixel coordinates -> 1 output value.
unetconvcnp = GridConvCNP(
    x_dim=2,
    y_dim=1,
    embedded_dim=128,
    conv_kernel_size=5,
    cnn_kernel_size=3,
    unet=True,
    # 2**((num_unet_layers / 2) - 1) must divide image width (/height) in pixels
    # (here 2**2 = 4, which divides MNIST's 28-pixel side).
    num_unet_layers=6,
    unet_starting_chans=64,
    pool="max",
    target_only_loss=False,  # only the pixels not in context set are used in loss if True
)

# Fit on the MNIST metadataset. es / *_es_iters appear to configure early
# stopping (exact semantics live in train_metamodel). The returned tracker maps
# each metric name to its recorded values.
unetconvcnp_tracker = utils.training_utils.train_metamodel(
    unetconvcnp,
    meta_datasets["mnist"],
    gridconv=True,
    lr=2e-4,
    max_iters=5_000,
    batch_size=16,
    min_es_iters=150,
    ref_es_iters=50,
    smooth_es_iters=50,
    es=True,
    man_thresh=1600,  # manual threshold that only stops training loop if ll>man_thresh... very hacky but fine for now
)

# One row per tracked metric, sharing the x-axis. squeeze=False keeps `axes` a
# 2-D array even when only one metric is tracked; plt.subplots would otherwise
# return a bare Axes, which zip cannot iterate. Skip plotting when nothing was
# tracked, since plt.subplots rejects zero rows.
num_metrics = len(unetconvcnp_tracker.keys())
if plot_training_metrics and num_metrics > 0:
    fig, axes = plt.subplots(
        num_metrics,
        1,
        figsize=(8, num_metrics * 4),
        dpi=100,
        sharex=True,
        squeeze=False,
    )

    for ax, (key, vals) in zip(axes[:, 0], unetconvcnp_tracker.items()):
        ax.plot(vals)
        ax.set_ylabel(key)
        ax.grid()

    plt.show()

In [None]:
# ConvCNP with a residual CNN backbone (three 128-channel layers) on the pixel
# grid: 2D pixel coordinates -> 1 output value.
resconvcnp = GridConvCNP(
    x_dim=2,
    y_dim=1,
    embedded_dim=128,
    cnn_chans=[128, 128, 128],
    conv_kernel_size=5,
    cnn_kernel_size=3,
    res=True,
)

# Fit on the MNIST metadataset. es / *_es_iters appear to configure early
# stopping (exact semantics live in train_metamodel). The returned tracker maps
# each metric name to its recorded values.
resconvcnp_tracker = utils.training_utils.train_metamodel(
    resconvcnp,
    meta_datasets["mnist"],
    gridconv=True,
    lr=2e-4,
    max_iters=5_000,
    batch_size=16,
    min_es_iters=150,
    ref_es_iters=50,
    smooth_es_iters=50,
    es=True,
    man_thresh=1600,  # manual threshold that only stops training loop if ll>man_thresh... very hacky but fine for now
)

# One row per tracked metric, sharing the x-axis. squeeze=False keeps `axes` a
# 2-D array even when only one metric is tracked; plt.subplots would otherwise
# return a bare Axes, which zip cannot iterate. Skip plotting when nothing was
# tracked, since plt.subplots rejects zero rows.
num_metrics = len(resconvcnp_tracker.keys())
if plot_training_metrics and num_metrics > 0:
    fig, axes = plt.subplots(
        num_metrics,
        1,
        figsize=(8, num_metrics * 4),
        dpi=100,
        sharex=True,
        squeeze=False,
    )

    for ax, (key, vals) in zip(axes[:, 0], resconvcnp_tracker.items()):
        ax.plot(vals)
        ax.set_ylabel(key)
        ax.grid()

    plt.show()

### Generate Test Datasets

In [None]:
num_test = 4
num_models = 2  # number of models compared in the evaluation plots
test_ratio = 0.2  # proportion of context pixels in test images (training uses `ratio`)

# Test tasks are (image, context mask) pairs drawn from the same shuffled
# iterator as the training tasks, so they never reuse a training image.
test_datasets = [
    (
        next(mnist_iter)[0].squeeze(0),
        utils.dataset_utils.random_mask(test_ratio, img_samp)[1],
    )
    for _ in range(num_test)
]

In [None]:
plot_std = True

# Models to compare, in plotting order, as (panel title prefix, model).
eval_models = [
    ("UNet ConvCNP", unetconvcnp),
    ("Res ConvCNP", resconvcnp),
]
panels_per_model = 1 + int(plot_std)  # predicted mean, plus predicted std if requested

# One figure per test task with these panels: target image, context
# visualisation, then each model's predicted mean (and std).
# NOTE(review): the models are not switched to eval mode here; if they contain
# dropout / batch-norm layers, call .eval() before this cell -- confirm.
for I, M_c in test_datasets:
    ctx_img = utils.dataset_utils.vis_ctxt_img(M_c, I)

    # Inference only: no_grad skips building an autograd graph for the forward passes.
    with torch.no_grad():
        model_preds = [(name, model(I, M_c)) for name, model in eval_models]

    # Each panel is (image array, colormap or None for the default, title).
    panels = [
        (I.permute(1, 2, 0).numpy(), "gray", "Target"),
        (ctx_img, None, "Context"),
    ]
    for name, preds in model_preds:
        panels.append(
            (preds.loc.detach().permute(1, 2, 0).numpy(), "gray", f"{name} mean")
        )
        if plot_std:
            panels.append(
                (preds.scale.detach().permute(1, 2, 0).numpy(), "viridis", f"{name} std")
            )

    fig, axes = plt.subplots(1, 2 + len(eval_models) * panels_per_model)
    for ax, (img, cmap, title) in zip(axes, panels):
        ax.imshow(img, cmap=cmap)
        ax.set_title(title, fontsize=8)
        ax.axis(False)

    plt.show()