# Experiment 2.ii
### In this experiment, we compare different metamodels on an image-completion task framed as a 2D regression problem
We compare the amortised GI BNN with a convolutional conditional neural process (ConvCNP). The models are evaluated on image datasets treated as metadatasets: each image is subsampled and treated as its own 2D regression dataset (2D pixel coordinates → RGB values). In this sub-experiment, we evaluate the models on the CelebA dataset.


In [None]:
# IPython autoreload: reload modules (e.g. bnn_amort_inf) before executing each cell,
# so edits to the library are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torchvision
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import wbml.experiment
import wbml.plot
from tqdm import tqdm
import sys
import os
import random
from PIL import Image

torch.set_default_dtype(torch.float64)

sys.path.append("../")
import bnn_amort_inf
from bnn_amort_inf.models.likelihoods.normal import (
    NormalLikelihood,
    HeteroscedasticNormalLikelihood,
)
from bnn_amort_inf.models.bnn import gibnn, mfvi_bnn
from bnn_amort_inf.models.np.convcnp import GridConvCNP
from bnn_amort_inf import utils

torch.set_default_dtype(torch.float64)

In [None]:
# Working directory for this experiment's outputs (log=None, override=False).
wd = wbml.experiment.WorkingDirectory(
    "./experiment_two_ii",
    log=None,
    override=False,
)

### Generate training metadataset for CelebA 64x64

In [None]:
num_train = 10000  # number of images used for meta-training (checked against the directory below)
tr_ratio = 0.5  # NOTE(review): not used by any visible cell
ratio_range = [0.05, 0.95]  # passed to random_mask for every training image
split_seed = 0  # seed for the train/test shuffle so reruns use the same images

celeba_dir = "./CelebA64x64/50k/"
# os.listdir order is arbitrary; sorting and then shuffling with a dedicated,
# seeded RNG makes the train/test split reproducible across runs and machines
# without touching the global `random` state.
img_file_paths = sorted(
    os.path.join(celeba_dir, f) for f in os.listdir(celeba_dir) if f.endswith(".jpg")
)
random.Random(split_seed).shuffle(img_file_paths)

if num_train > len(img_file_paths):
    raise ValueError(
        f"num_train={num_train} exceeds the {len(img_file_paths)} images "
        f"found in {celeba_dir}"
    )

train_paths = img_file_paths[:num_train]
transform = torchvision.transforms.ToTensor()

tr_md = []
for file_path in tqdm(train_paths):
    # The context manager closes the file promptly. convert("RGB") guarantees the
    # 3 channels the models are built for (y_dim=3).
    with Image.open(file_path) as pil_img:
        img = transform(pil_img.convert("RGB"))
    mask = utils.dataset_utils.random_mask(ratio_range, img)[1]
    tr_md.append((img, mask))
train_metadataset = utils.dataset_utils.MetaDataset(tr_md)

### Define and Train Models

In [None]:
plot_training_metrics = True  # if True, plot every tracked training metric after each model is trained

In [None]:
# Gridded ConvCNP that is called directly on (image, mask) pairs; y_dim=3 for RGB.
convcnp = GridConvCNP(
    x_dim=2,
    y_dim=3,
    embedded_dim=512,
    conv_kernel_size=5,
    cnn_kernel_size=5,
    unet=True,
    num_unet_layers=8,
    unet_starting_chans=32,
    likelihood=HeteroscedasticNormalLikelihood(image=True),
)

print("ConvCNP parameters: ", sum(p.numel() for p in convcnp.parameters()))

# Meta-train on the image metadataset. The returned tracker maps metric names to
# value sequences (presumably one entry per iteration); these are plotted below.
convcnp_tracker = utils.training_utils.train_metamodel(
    convcnp,
    train_metadataset,
    image=True,
    gridconv=True,
    lr=5e-4,
    max_iters=10_000,
    batch_size=3,
    min_es_iters=2000,
    ref_es_iters=500,
    smooth_es_iters=200,
    es=True,
    # man_thresh=(
    #     "ll",
    #     2000.0,
    # ),  # manual threshold that only stops training loop if ll>man_thresh... very hacky but fine for now
)

if plot_training_metrics:
    num_metrics = len(convcnp_tracker.keys())
    # plt.subplots rejects nrows=0, so only plot when something was tracked.
    if num_metrics > 0:
        # squeeze=False keeps `axes` a 2D array even for a single metric. With the
        # default squeeze=True, a lone subplot comes back as a bare Axes and zip() fails.
        fig, axes = plt.subplots(
            num_metrics,
            1,
            figsize=(8, num_metrics * 4),
            dpi=100,
            sharex=True,
            squeeze=False,
        )

        for ax, (key, vals) in zip(axes[:, 0], convcnp_tracker.items()):
            ax.plot(vals)
            ax.set_ylabel(key)
            ax.grid()

        plt.show()

In [None]:
num_test = 4
test_ratio = 0.3  # passed to random_mask for each test image
# Held-out images are taken from directly after the training slice.
test_paths = img_file_paths[num_train : num_train + num_test]
# Slicing past the end silently yields fewer images; fail loudly instead.
if len(test_paths) < num_test:
    raise ValueError(
        f"Expected {num_test} test images after the first {num_train}, "
        f"but only {len(test_paths)} are available."
    )

test_metadataset = []
for file_path in test_paths:
    # Close the file promptly; convert("RGB") guarantees the 3 channels the models expect.
    with Image.open(file_path) as pil_img:
        img = transform(pil_img.convert("RGB"))
    mask = utils.dataset_utils.random_mask(test_ratio, img)[1]
    test_metadataset.append((img, mask))

In [None]:
plot_std = True

# For each held-out image, show: original | context | ConvCNP mean | [ConvCNP std] | linear interpolation.
for img, mask in test_metadataset:
    ctx_img = utils.dataset_utils.vis_ctxt_img(mask, img)

    # Inference only: no_grad avoids building an autograd graph for every test image.
    with torch.no_grad():
        convcnp_preds = convcnp(img, mask)
    # Predictions are channel-first like `img`; permute to (H, W, C) for imshow.
    convcnp_pred_img = convcnp_preds.loc.detach().permute(1, 2, 0).numpy()
    convcnp_pred_std = convcnp_preds.scale.detach().permute(1, 2, 0).numpy()
    linear_interpolation = utils.dataset_utils.linearly_interpolate(img, mask)

    # (title, image, colormap) for each panel, in display order.
    panels = [
        ("Original", img.permute(1, 2, 0).numpy(), None),
        ("Context", ctx_img, None),
        ("ConvCNP mean", convcnp_pred_img, None),
    ]
    if plot_std:
        # Per-pixel std averaged over the colour channels.
        panels.append(("ConvCNP std", convcnp_pred_std.mean(-1), "viridis"))
    panels.append(("Linear interp.", linear_interpolation.numpy(), None))

    fig, axes = plt.subplots(1, len(panels))
    for ax, (title, data, cmap) in zip(axes, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(title, fontsize=8)
        ax.axis(False)

    plt.show()

In [None]:
# Amortised GI BNN regressing RGB values (y_dim=3) on 2D pixel coordinates,
# with a fixed (train_noise=False) Gaussian observation noise of 0.05.
agibnn = gibnn.AmortisedGIBNN(
    x_dim=2,
    y_dim=3,
    hidden_dims=[150, 150, 150],
    in_hidden_dims=[200, 200, 200],
    likelihood=NormalLikelihood(noise=0.05, train_noise=False, image=True),
)

print("Amortised BNN parameters: ", sum(p.numel() for p in agibnn.parameters()))

# Meta-train on the image metadataset. The returned tracker maps metric names to
# value sequences (presumably one entry per iteration); these are plotted below.
agibnn_tracker = utils.training_utils.train_metamodel(
    agibnn,
    train_metadataset,
    image=True,
    loss_fn="npml_loss",  # options are "npml_loss", "npvi_loss", "loss"
    lr=1e-5,
    max_iters=10_000,
    batch_size=3,
    min_es_iters=2000,
    ref_es_iters=500,
    smooth_es_iters=200,
    es=True,
    # man_thresh=("elbo", -2.0e10),
)

if plot_training_metrics:
    num_metrics = len(agibnn_tracker.keys())
    # plt.subplots rejects nrows=0, so only plot when something was tracked.
    if num_metrics > 0:
        # squeeze=False keeps `axes` a 2D array even for a single metric. With the
        # default squeeze=True, a lone subplot comes back as a bare Axes and zip() fails.
        fig, axes = plt.subplots(
            num_metrics,
            1,
            figsize=(8, num_metrics * 4),
            dpi=100,
            sharex=True,
            squeeze=False,
        )

        for ax, (key, vals) in zip(axes[:, 0], agibnn_tracker.items()):
            ax.plot(vals)
            ax.set_ylabel(key)
            ax.grid()

        plt.show()

In [None]:
num_test = 4
test_ratio = 0.5  # passed to random_mask for each test image
# Held-out images are taken from directly after the training slice.
test_paths = img_file_paths[num_train : num_train + num_test]
# Slicing past the end silently yields fewer images; fail loudly instead.
if len(test_paths) < num_test:
    raise ValueError(
        f"Expected {num_test} test images after the first {num_train}, "
        f"but only {len(test_paths)} are available."
    )

test_metadataset = []
for file_path in test_paths:
    # Close the file promptly; convert("RGB") guarantees the 3 channels the models expect.
    with Image.open(file_path) as pil_img:
        img = transform(pil_img.convert("RGB"))
    mask = utils.dataset_utils.random_mask(test_ratio, img)[1]
    test_metadataset.append((img, mask))

In [None]:
plot_std = True
num_pred_samples = 20  # value passed as num_samples to agibnn for each test image

# For each held-out image, show: original | context | AGIBNN mean | [AGIBNN std] | linear interpolation.
for img, mask in test_metadataset:
    # Context points from the mask; the target split returned by img_for_reg is unused here.
    x_c, y_c, _, _ = utils.dataset_utils.img_for_reg(img, mask)
    ctx_img = utils.dataset_utils.vis_ctxt_img(mask, img)
    x_test = utils.dataset_utils.test_grid(img.shape[-2:])

    # (H, W, C) output layout derived from the image instead of hard-coding (64, 64, 3).
    # NOTE(review): assumes test_grid enumerates pixels row-major over (H, W), as the
    # original 64x64 reshape implicitly did. Confirm before using non-square images.
    pred_shape = (*img.shape[-2:], img.shape[0])

    # Inference only: no_grad avoids building an autograd graph for every test image.
    with torch.no_grad():
        preds = agibnn(x_c, y_c, x_test=x_test, num_samples=num_pred_samples)[-1]
        pred_means = agibnn.likelihood(preds).mean.detach()
    # Mean / std over dim 0 of the likelihood means. Presumably dim 0 is the
    # num_samples dimension (TODO confirm), so the std reflects spread across samples.
    pred_img = pred_means.mean(0).reshape(pred_shape).numpy()
    pred_std = pred_means.std(0).reshape(pred_shape).numpy()

    linear_interpolation = utils.dataset_utils.linearly_interpolate(img, mask)

    # (title, image, colormap) for each panel, in display order.
    panels = [
        ("Original", img.permute(1, 2, 0).numpy(), None),
        ("Context", ctx_img, None),
        ("AGIBNN mean", pred_img, None),
    ]
    if plot_std:
        # Per-pixel std averaged over the colour channels.
        panels.append(("AGIBNN std", pred_std.mean(-1), "viridis"))
    panels.append(("Linear interp.", linear_interpolation.numpy(), None))

    fig, axes = plt.subplots(1, len(panels))
    for ax, (title, data, cmap) in zip(axes, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(title, fontsize=8)
        ax.axis(False)

    plt.show()

### Evaluate models on test datasets

In [None]:
num_test = 4
test_ratio = 0.5  # passed to random_mask for each test image
# Held-out images are taken from directly after the training slice.
test_paths = img_file_paths[num_train : num_train + num_test]
# Slicing past the end silently yields fewer images; fail loudly instead.
if len(test_paths) < num_test:
    raise ValueError(
        f"Expected {num_test} test images after the first {num_train}, "
        f"but only {len(test_paths)} are available."
    )

test_metadataset = []
for file_path in test_paths:
    # Close the file promptly; convert("RGB") guarantees the 3 channels the models expect.
    with Image.open(file_path) as pil_img:
        img = transform(pil_img.convert("RGB"))
    mask = utils.dataset_utils.random_mask(test_ratio, img)[1]
    test_metadataset.append((img, mask))

In [None]:
plot_std = True
num_pred_samples = 20  # value passed as num_samples to agibnn for each test image

# Side-by-side comparison per held-out image:
# original | context | AGIBNN mean | [AGIBNN std] | ConvCNP mean | [ConvCNP std] | linear interpolation.
for img, mask in test_metadataset:
    # Context points from the mask; the target split returned by img_for_reg is unused here.
    x_c, y_c, _, _ = utils.dataset_utils.img_for_reg(img, mask)
    ctx_img = utils.dataset_utils.vis_ctxt_img(mask, img)
    x_test = utils.dataset_utils.test_grid(img.shape[-2:])

    # (H, W, C) output layout derived from the image instead of hard-coding (64, 64, 3).
    # NOTE(review): assumes test_grid enumerates pixels row-major over (H, W), as the
    # original 64x64 reshape implicitly did. Confirm before using non-square images.
    pred_shape = (*img.shape[-2:], img.shape[0])

    # Inference only: no_grad avoids building autograd graphs for every test image.
    with torch.no_grad():
        convcnp_preds = convcnp(img, mask)
        preds = agibnn(x_c, y_c, x_test=x_test, num_samples=num_pred_samples)[-1]
        pred_means = agibnn.likelihood(preds).mean.detach()

    # ConvCNP outputs are channel-first like `img`; permute to (H, W, C) for imshow.
    convcnp_pred_img = convcnp_preds.loc.detach().permute(1, 2, 0).numpy()
    convcnp_pred_std = convcnp_preds.scale.detach().permute(1, 2, 0).numpy()

    # Mean / std over dim 0 of the AGIBNN likelihood means (presumably the
    # num_samples dimension; TODO confirm).
    pred_img = pred_means.mean(0).reshape(pred_shape).numpy()
    pred_std = pred_means.std(0).reshape(pred_shape).numpy()

    linear_interpolation = utils.dataset_utils.linearly_interpolate(img, mask)

    # (title, image, colormap) for each panel, in display order. Titles matter here:
    # with up to seven panels, it is otherwise impossible to tell the models apart.
    panels = [
        ("Original", img.permute(1, 2, 0).numpy(), None),
        ("Context", ctx_img, None),
        ("AGIBNN mean", pred_img, None),
    ]
    if plot_std:
        # Std panels show the per-pixel std averaged over the colour channels.
        panels.append(("AGIBNN std", pred_std.mean(-1), "viridis"))
        panels.append(("ConvCNP mean", convcnp_pred_img, None))
        panels.append(("ConvCNP std", convcnp_pred_std.mean(-1), "viridis"))
    else:
        panels.append(("ConvCNP mean", convcnp_pred_img, None))
    panels.append(("Linear interp.", linear_interpolation.numpy(), None))

    fig, axes = plt.subplots(1, len(panels))
    for ax, (title, data, cmap) in zip(axes, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(title, fontsize=8)
        ax.axis(False)

    plt.show()