# Experiment 2.ii
### In this experiment we compare different metamodels on an image completion task framed as a 2D regression task 
We compare the amortised GI BNN with a convolutional conditional neural process. The models are evaluated on image datasets that are treated as metadatasets: each image is subsampled and treated as a 2D regression dataset. In this sub-experiment, we evaluate the models on the CelebA dataset.


In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to the local package are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torchvision
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import wbml.experiment
import wbml.plot
from tqdm import tqdm
import sys

# Make float64 the default floating-point dtype for tensors created from here on.
torch.set_default_dtype(torch.float64)

# Add the parent directory to the module search path; presumably this is what
# makes the local bnn_amort_inf package importable below.
sys.path.append("../")
import bnn_amort_inf
# NOTE(review): BernoulliLikelihood is imported from the `normal` module and is
# not referenced in the visible cells — confirm the module path is intended.
from bnn_amort_inf.models.likelihoods.normal import (
    NormalLikelihood,
    BernoulliLikelihood,
)
from bnn_amort_inf.models.bnn import gibnn, mfvi_bnn
from bnn_amort_inf.models.np import GridConvCNP, CNP
from bnn_amort_inf import utils

# NOTE(review): repeats the set_default_dtype call above; looks redundant —
# confirm whether the project imports are expected to change the default dtype.
torch.set_default_dtype(torch.float64)

In [None]:
# Working directory for this experiment; its root is also used as the CelebA
# download location in the dataset cell.
wd = wbml.experiment.WorkingDirectory(
    "./experiment_two_ii",
    log=None,
    override=False,
)

In [None]:
num_datasets = 1000  # note cannot be greater than 202,599
meta_datasets = {}
ratio = 0.2  # proportion of context pixels in training images

celeba = datasets.CelebA(
    root=wd.root,
    split="all",
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
# One shuffled pass over CelebA; every later cell draws its images from this
# same iterator, so no image is used twice.
celeba_iter = iter(torch.utils.data.DataLoader(celeba, shuffle=True))

# Reference image (batch dimension kept) handed to random_mask for every mask
# drawn in this notebook.
img_samp = next(celeba_iter)[0]

# NOTE(review): I and M_c are not read in the visible cells before the
# evaluation loop reassigns them; this draw still consumes one image.
I = next(celeba_iter)[0].squeeze(0)
M_c = utils.dataset_utils.random_mask(ratio, img_samp)[1]


def _draw_training_task():
    """Return (img, mask, x_c, y_c, x_t, y_t) for the next image in celeba_iter."""
    image = next(celeba_iter)[0].squeeze(0)
    ctx_mask = utils.dataset_utils.random_mask(ratio, img_samp)[1]
    x_ctx, y_ctx, x_trg, y_trg = utils.dataset_utils.img_for_reg(image, ctx_mask)
    return (image, ctx_mask, x_ctx, y_ctx, x_trg, y_trg)


# One tuple per image, drawn sequentially with a progress bar.
m_d = [_draw_training_task() for _ in tqdm(range(num_datasets))]

meta_datasets["celeba"] = utils.dataset_utils.MetaDataset(m_d)

In [None]:
# Display the number of entries in the CelebA meta-dataset's `datasets`
# attribute (expected to match num_datasets — TODO confirm).
len(meta_datasets["celeba"].datasets)

### Define and Train Models

In [None]:
# When True, each training cell below plots every metric in its tracker.
plot_training_metrics = True

In [None]:
# Grid ConvCNP for the image-completion task: 2 input dimensions and
# 3 output dimensions (presumably pixel coordinates and RGB channels).
convcnp = GridConvCNP(
    x_dim=2,
    y_dim=3,
    embedded_dim=128,
    cnn_chans=[64, 64],
    conv_kernel_size=13,
    cnn_kernel_size=9,
    res=True,
)

# Train on the CelebA meta-dataset; the returned tracker maps metric names to
# the values recorded during training (it is consumed via .keys()/.items()).
convcnp_tracker = utils.training_utils.train_metamodel(
    convcnp,
    meta_datasets["celeba"],
    image=True,
    gridconv=True,
    lr=1e-4,
    max_iters=5_000,
    batch_size=1,
    min_es_iters=150,
    ref_es_iters=50,
    smooth_es_iters=50,
    es=True,
    # man_thresh=(
    #     "ll",
    #     2000.0,
    # ),  # manual threshold that only stops training loop if ll>man_thresh... very hacky but fine for now
)

# Skip plotting when nothing was tracked: plt.subplots rejects zero rows.
if plot_training_metrics and len(convcnp_tracker.keys()) > 0:
    num_metrics = len(convcnp_tracker.keys())

    # squeeze=False always returns a 2-D array of Axes. Without it, a tracker
    # holding a single metric makes plt.subplots return a bare Axes object,
    # which is not iterable and would break the zip below.
    fig, axes = plt.subplots(
        num_metrics,
        1,
        figsize=(8, num_metrics * 4),
        dpi=100,
        sharex=True,
        squeeze=False,
    )

    # One row per tracked metric; axes[:, 0] is the single column of Axes.
    for ax, (key, vals) in zip(axes[:, 0], convcnp_tracker.items()):
        ax.plot(vals)
        ax.set_ylabel(key)
        ax.grid()

    plt.show()

In [None]:
# Amortised GI BNN for the same task: 2 input dimensions and 3 output
# dimensions, with a Normal likelihood configured with noise=0.01 and
# train_noise=False.
agibnn = gibnn.AmortisedGIBNN(
    x_dim=2,
    y_dim=3,
    hidden_dims=[50, 50],
    in_hidden_dims=[100, 100],
    likelihood=NormalLikelihood(noise=0.01, train_noise=False),
)

# Train on the CelebA meta-dataset; the returned tracker maps metric names to
# the values recorded during training (it is consumed via .keys()/.items()).
agibnn_tracker = utils.training_utils.train_metamodel(
    agibnn,
    meta_datasets["celeba"],
    image=True,
    np_loss=True,
    np_kl=False,
    lr=5e-3,
    max_iters=5_000,
    batch_size=1,
    min_es_iters=150,
    ref_es_iters=50,
    smooth_es_iters=50,
    es=True,
    # man_thresh=("elbo", -2.0e10),
)

# Skip plotting when nothing was tracked: plt.subplots rejects zero rows.
if plot_training_metrics and len(agibnn_tracker.keys()) > 0:
    num_metrics = len(agibnn_tracker.keys())

    # squeeze=False always returns a 2-D array of Axes. Without it, a tracker
    # holding a single metric makes plt.subplots return a bare Axes object,
    # which is not iterable and would break the zip below.
    fig, axes = plt.subplots(
        num_metrics,
        1,
        figsize=(8, num_metrics * 4),
        dpi=100,
        sharex=True,
        squeeze=False,
    )

    # One row per tracked metric; axes[:, 0] is the single column of Axes.
    for ax, (key, vals) in zip(axes[:, 0], agibnn_tracker.items()):
        ax.plot(vals)
        ax.set_ylabel(key)
        ax.grid()

    plt.show()

### Generate Test Datasets

In [None]:
num_test = 4  # number of held-out images to visualise
num_models = 2  # number of models compared; used to size the evaluation figure


def _draw_test_task():
    """Return (img, mask, x_c, y_c, x_t, y_t) for the next image in celeba_iter."""
    image = next(celeba_iter)[0].squeeze(0)
    # NOTE(review): the literal 0.2 equals `ratio` from the dataset cell but is
    # not tied to it — confirm whether the two should move together.
    ctx_mask = utils.dataset_utils.random_mask(0.2, img_samp)[1]
    x_ctx, y_ctx, x_trg, y_trg = utils.dataset_utils.img_for_reg(image, ctx_mask)
    return (image, ctx_mask, x_ctx, y_ctx, x_trg, y_trg)


# Images come from the same shuffled iterator used for training data, so they
# are ones not yet consumed by the earlier cells.
test_datasets = [_draw_test_task() for _ in range(num_test)]

In [None]:
# When True, a standard-deviation panel is shown next to each model's mean.
plot_std = True

for dataset in test_datasets:
    I, M_c, x_c, y_c, x_t, y_t = dataset
    ctx_img = utils.dataset_utils.vis_ctxt_img(M_c, I)

    # I is channel-first (C, H, W) — it is permuted to (H, W, C) for display —
    # so derive the channel-last display shape from it instead of hard-coding
    # the CelebA resolution.
    img_shape = (I.shape[1], I.shape[2], I.shape[0])

    # Inference only: no_grad avoids building an autograd graph for the
    # forward passes.
    with torch.no_grad():
        convcnp_pred_dist = convcnp(I, M_c)
        convcnp_pred_img = convcnp_pred_dist.loc.reshape(img_shape).detach()
        convcnp_pred_std = convcnp_pred_dist.scale.reshape(img_shape).detach()

        # Mean and spread are taken over the sample dimension (dim 0) of the
        # last element returned by the model.
        agibnn_pred_samps = agibnn(x_c, y_c, x_test=x_t, num_samples=3)[-1].detach()
        pred_img = agibnn_pred_samps.mean(0).reshape(img_shape)
        pred_std = agibnn_pred_samps.std(0).reshape(img_shape)

    # Build the panels in display order: ground truth, context image, then
    # mean (and optionally std) for the AGIBNN followed by the ConvCNP.
    # Sizing the figure from this list keeps the axes count correct when
    # plot_std is False; the previous fixed indices (axes[4]) overran the
    # 4-axes figure in that case.
    panels = [I.permute(1, 2, 0), ctx_img]
    model_outputs = (
        (pred_img, pred_std),
        (convcnp_pred_img, convcnp_pred_std),
    )
    for mean_img, std_img in model_outputs:
        panels.append(mean_img)
        if plot_std:
            panels.append(std_img)

    fig, axes = plt.subplots(1, len(panels))
    for ax, panel in zip(axes, panels):
        ax.imshow(panel)
        ax.axis(False)

    plt.show()