In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import copy

import sys
sys.path.append("../")

from utils.data import DIR_DATA, GPDataset
from utils.data.helpers import DatasetMerger

from sklearn.gaussian_process.kernels import (
    RBF,
    ConstantKernel,
    DotProduct,
    ExpSineSquared,
    Matern,
    WhiteKernel,
)

def get_all_gp_datasets(**kwargs):
    """Return train / test / valid sets for all GP experiments.

    Calls every per-experiment builder in turn with the same keyword
    arguments and merges their results.

    Parameters
    ----------
    **kwargs
        Forwarded unchanged to each builder.

    Returns
    -------
    tuple of dict
        ``(train, test, valid)`` dictionaries. When two builders produce the
        same key, the builder that runs later wins.
    """
    builders = (
        get_datasets_single_gp,
        get_datasets_variable_hyp_gp,
        get_datasets_variable_kernel_gp,
    )

    train, test, valid = {}, {}, {}
    for build in builders:
        new_train, new_test, new_valid = build(**kwargs)
        train.update(new_train)
        test.update(new_test)
        valid.update(new_valid)

    return train, test, valid


def get_datasets_single_gp(**kwargs):
    """Return train / test / valid sets for 'Samples from a single GP'.

    One dataset is built per kernel, each with fixed hyperparameters.

    Parameters
    ----------
    **kwargs
        Extra keyword arguments forwarded to `get_gp_datasets`.

    Returns
    -------
    tuple of dict
        ``(train, test, valid)`` dictionaries keyed by kernel name.
    """
    kernels = {
        "RBF_Kernel": RBF(length_scale=0.2),
        "Periodic_Kernel": ExpSineSquared(length_scale=0.5, periodicity=0.5),
        # Noiseless Matern kernel is currently disabled:
        # "Matern_Kernel": Matern(length_scale=0.2, nu=1.5),
        "Noisy_Matern_Kernel": (
            WhiteKernel(noise_level=0.1) + Matern(length_scale=0.2, nu=1.5)
        ),
    }

    settings = dict(
        is_vary_kernel_hyp=False,  # use a single hyperparameter per kernel
        n_samples=10_000,  # number of different context-target sets
        n_points=128,  # size of target U context set for each sample
        is_reuse_across_epochs=False,  # never see the same example twice
    )
    return get_gp_datasets(kernels, **settings, **kwargs)


def get_datasets_variable_hyp_gp(**kwargs):
    """Return train / test / valid sets for 'Samples from GPs with varying Kernel hyperparameters'.

    A single Matern kernel is used, with `is_vary_kernel_hyp=True` so that the
    hyperparameters differ between samples.

    Parameters
    ----------
    **kwargs
        Extra keyword arguments forwarded to `get_gp_datasets`.

    Returns
    -------
    tuple of dict
        ``(train, test, valid)`` dictionaries keyed by kernel name.
    """
    kernels = {
        "Variable_Matern_Kernel": Matern(length_scale_bounds=(0.01, 0.3), nu=1.5),
    }

    settings = dict(
        is_vary_kernel_hyp=True,  # use a different hyp for each samples
        n_samples=50_000,  # number of different context-target sets
        n_points=128,  # size of target U context set for each sample
        is_reuse_across_epochs=False,  # never see the same example twice
    )
    return get_gp_datasets(kernels, **settings, **kwargs)


def get_datasets_variable_kernel_gp(**kwargs):
    """Return train / test / valid sets for 'Samples from GPs with varying Kernels'.

    Builds the single-GP datasets and merges all kernels of each split into
    one `DatasetMerger`, stored under the key ``"All_Kernels"``.

    Parameters
    ----------
    **kwargs
        Forwarded to `get_datasets_single_gp`.

    Returns
    -------
    tuple of dict
        ``(train, test, valid)``; each dictionary holds the single key
        ``"All_Kernels"``.
    """
    train, test, valid = get_datasets_single_gp(**kwargs)

    merged_train = {"All_Kernels": DatasetMerger(train.values())}
    merged_test = {"All_Kernels": DatasetMerger(test.values())}
    merged_valid = {"All_Kernels": DatasetMerger(valid.values())}

    return merged_train, merged_test, merged_valid


def sample_gp_dataset_like(dataset, **kwargs):
    """Return a copy of `dataset` that holds freshly drawn samples.

    The copy is taken first; the new samples are then drawn from the original
    `dataset` (not from the copy) and stored on the copy.

    Parameters
    ----------
    dataset
        Object exposing `get_samples(**kwargs)` and `set_samples_(...)`.
    **kwargs
        Forwarded to `dataset.get_samples`.

    Returns
    -------
    A deep copy of `dataset` whose samples were replaced via `set_samples_`.
    """
    clone = copy.deepcopy(dataset)
    fresh_samples = dataset.get_samples(**kwargs)
    clone.set_samples_(*fresh_samples)
    return clone


def get_gp_datasets(
    kernels, save_file=f"{os.path.join(DIR_DATA, 'gp_dataset.hdf5')}", **kwargs
):
    """
    Return a train, test and validation set for all the given kernels (dict).

    Parameters
    ----------
    kernels : dict
        Maps a dataset name to the kernel handed to `GPDataset`.
    save_file : str or None
        Path passed on (paired with the dataset name) as `save_file`; `None`
        is passed through unchanged.
    **kwargs
        Forwarded to every `GPDataset` constructor.

    Returns
    -------
    tuple of dict
        ``(train, test, valid)`` dictionaries keyed like `kernels`. Test sets
        are sampled with ``idx_chunk=-1`` and 10000 samples; validation sets
        with ``idx_chunk=-2`` and a tenth of the training set's `n_samples`.
    """

    def get_save_file(name, save_file=save_file):
        # Pair the file with the dataset name unless saving is disabled.
        return None if save_file is None else (save_file, name)

    datasets = {
        name: GPDataset(kernel=kernel, save_file=get_save_file(name), **kwargs)
        for name, kernel in kernels.items()
    }

    # All test sets are sampled before any validation set.
    datasets_test = dict()
    for name, train_set in datasets.items():
        datasets_test[name] = sample_gp_dataset_like(
            train_set, save_file=get_save_file(name), idx_chunk=-1, n_samples=10000
        )

    datasets_valid = dict()
    for name, train_set in datasets.items():
        datasets_valid[name] = sample_gp_dataset_like(
            train_set,
            save_file=get_save_file(name),
            idx_chunk=-2,
            n_samples=train_set.n_samples // 10,
        )

    return datasets, datasets_test, datasets_valid

In [None]:
# Build the train / test / validation dataset dicts (keyed by kernel name)
# for the single-GP experiments.
gp_datasets, gp_test_datasets, gp_valid_datasets = get_datasets_single_gp()

In [None]:
from npf.utils.datasplit import CntxtTrgtGetter, GetRandomIndcs, get_all_indcs
from utils.data import cntxt_trgt_collate

# Context/target splitter wrapped by `cntxt_trgt_collate`: context indices come
# from `GetRandomIndcs`, target indices from `get_all_indcs`.
# NOTE(review): presumably `a` / `b` bound how many context points are drawn —
# confirm against the npf documentation.
get_cntxt_trgt_1d = cntxt_trgt_collate(
    CntxtTrgtGetter(
        contexts_getter=GetRandomIndcs(a=0.0, b=50), targets_getter=get_all_indcs,
    )
)

In [None]:
# Display the RBF-kernel training dataset as the cell output (quick sanity check).
gp_datasets["RBF_Kernel"]

# Build model.

In [None]:
import numpy as np
import torch
from npf import ICCNP, CNP
from npf.architectures import MLP, merge_flat_input

# Sizes passed to the model constructor as `r_dim` and `ic_r_dim`.
r_dim = 64
ic_r_dim = 64

# Model for 1-D inputs and 1-D outputs. The commented-out line swaps in a plain
# CNP instead; the training and plotting cells then need their commented-out
# variants that call the model without `dc`.
model = ICCNP(x_dim=1, y_dim=1, r_dim=r_dim, ic_r_dim=ic_r_dim)
# model = CNP(x_dim=1, y_dim=1, r_dim=r_dim)

# Train model

In [None]:
def ic_batch_generator(gp_datasets=gp_datasets, batch_size=16, n=128, max_n_cntxt=60, max_n_dc=2):
    """Draw one training batch plus extra datasets sampled from the same kernel.

    A kernel is picked uniformly at random, then ``(1 + n_dc) * batch_size``
    datasets of `n` points are sampled from it. The first `batch_size` datasets
    become the tasks (split into context / target); the rest are grouped into
    `n_dc` extra datasets per task.

    Parameters
    ----------
    gp_datasets : dict
        Maps a kernel name to an object exposing
        ``get_samples(n_samples=..., n_points=...)``.
    batch_size : int
        Number of tasks in the batch.
    n : int
        Number of points per sampled dataset.
    max_n_cntxt : int
        Exclusive upper bound on the number of context points.
    max_n_dc : int
        Exclusive upper bound on the number of extra datasets per task.

    Returns
    -------
    tuple
        ``(xc, yc, xt, yt, dc)`` where ``dc = (x_dc, y_dc)`` and both entries
        are reshaped to ``(batch_size, n_dc, n, 1)``.

    Notes
    -----
    `np.random.randint(k)` draws from ``[0, k)``, so both the number of context
    points and `n_dc` can be zero. The same context indices are shared by every
    task in the batch.
    """
    # Randomly select kernel.
    chosen = np.random.randint(len(gp_datasets))
    kernel_names = list(gp_datasets.keys())
    sampler = gp_datasets[kernel_names[chosen]]

    # One dataset per task plus n_dc extra datasets for each of them.
    n_dc = np.random.randint(max_n_dc)
    n_total = (1 + n_dc) * batch_size
    x, y = sampler.get_samples(n_samples=n_total, n_points=n)

    # The leading batch_size datasets are the tasks; targets are all their points.
    xt = x[:batch_size]
    yt = y[:batch_size]

    # Context set: a random subset of point indices.
    n_cntxt = np.random.randint(max_n_cntxt)
    idx_cntxt = torch.randperm(n)[:n_cntxt]
    xc = x[:batch_size, idx_cntxt]
    yc = y[:batch_size, idx_cntxt]

    # Remaining datasets, regrouped so each task owns n_dc of them.
    dc_shape = (batch_size, n_dc, n, 1)
    x_dc = x[batch_size:].reshape(dc_shape)
    y_dc = y[batch_size:].reshape(dc_shape)

    return xc, yc, xt, yt, (x_dc, y_dc)

In [None]:
from npf import CNPFLoss
from tqdm.auto import tqdm

# Training hyperparameters.
batch_size = 3
lr = 1e-3
# Total factor by which the learning rate is divided over the whole run.
decay_lr = 10
max_epochs = 10
iters = 50000 # batches per epoch.
loss = CNPFLoss()

# NOTE(review): this rebuilds the datasets already created in an earlier cell
# and rebinds the same global names.
gp_datasets, gp_test_datasets, gp_valid_datasets = get_datasets_single_gp()
opt = torch.optim.Adam(model.parameters(), lr=lr)
# Per-epoch multiplier chosen so that gamma ** max_epochs == 1 / decay_lr.
gamma = (1 / decay_lr) ** (1 / max_epochs)
schedule = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=gamma)

for epoch in range(max_epochs):
    print(f"Epoch: {epoch}.")

    tqdm_iter = tqdm(range(iters), desc="Iters")
    for _ in tqdm_iter:
        # Generate batch of data.
        xc, yc, xt, yt, dc = ic_batch_generator(gp_datasets, batch_size)

        # Forward pass; the model outputs are unpacked into the loss together
        # with the targets. The commented-out line is the variant without `dc`.
        nll = loss.get_loss(*model(xc, yc, xt, yt, dc), yt)
        # nll = loss.get_loss(*model(xc, yc, xt, yt), yt)
        # Average to a scalar before back-propagating.
        nll = nll.mean()
        opt.zero_grad()
        nll.backward()
        opt.step()

        # Show the current loss value in the progress bar.
        tqdm_iter.set_postfix({"nll": nll.item()})

    # Reduce learning rate.
    schedule.step()

In [None]:
import wbml.plot
import matplotlib.pyplot as plt

def visualise_1d(mean, scale, xc, yc, xt, yt):
    """Plot context points, target points and the prediction with a 1.96-sigma band.

    All inputs are flattened to 1-D first, so batched arrays such as
    ``(1, n, 1)`` are accepted; `plt.plot` and `plt.fill_between` require 1-D
    data. The prediction is drawn in order of increasing target location so
    that the line and the band do not zig-zag when `xt` is unsorted.

    Parameters
    ----------
    mean, scale : array-like or torch.Tensor
        Predictive mean and standard deviation, one value per target point.
    xc, yc : array-like or torch.Tensor
        Context inputs and outputs.
    xt, yt : array-like or torch.Tensor
        Target inputs and outputs.

    Raises
    ------
    ValueError
        If `mean`, `scale` and `xt` do not hold the same number of values.
    """
    import numpy as np  # local import keeps this function self-contained

    def _flat(values):
        """Return `values` as a flat NumPy array, detaching torch tensors first."""
        if hasattr(values, "detach"):
            values = values.detach().cpu().numpy()
        return np.asarray(values).reshape(-1)

    mean, scale = _flat(mean), _flat(scale)
    xc, yc = _flat(xc), _flat(yc)
    xt, yt = _flat(xt), _flat(yt)

    if not (mean.size == scale.size == xt.size):
        raise ValueError(
            "mean, scale and xt must hold the same number of values, got "
            f"{mean.size}, {scale.size} and {xt.size}."
        )

    # Draw the prediction left to right regardless of the sampling order.
    order = np.argsort(xt, kind="stable")
    xt_sorted = xt[order]
    mean_sorted = mean[order]
    err = 1.96 * scale[order]

    plt.figure(figsize=(8, 6))

    # Plot context and target.
    # NOTE(review): the `style=` keyword is not a stock matplotlib argument;
    # presumably importing `wbml.plot` adds it — confirm.
    plt.scatter(xc, yc, label="Context", style="train", s=20)
    plt.scatter(xt, yt, label="Target", style="test", s=20)

    # Plot prediction.
    plt.plot(xt_sorted, mean_sorted, label="Prediction", style="pred")
    plt.fill_between(xt_sorted, mean_sorted - err, mean_sorted + err, style="pred")

    plt.xlim(xt.min(), xt.max())
    wbml.plot.tweak()

    plt.show()


def ic_posterior_samples(model, n_samples=10, gp_datasets=gp_datasets, n=200, max_n_cntxt=60, max_n_dc=10):
    """Plot predictions of `model` on `n_samples` freshly drawn single-task batches.

    Each batch comes from `ic_batch_generator` with ``batch_size=1``; the model
    is called with the extra datasets `dc` and the result is plotted with
    `visualise_1d`.

    Parameters
    ----------
    model
        Callable as ``model(xc, yc, xt, yt, dc)``; the first returned item must
        expose ``base_dist.loc`` and ``base_dist.scale``.
    n_samples : int
        Number of figures to draw.
    gp_datasets, n, max_n_cntxt, max_n_dc
        Forwarded to `ic_batch_generator`.
    """
    batch_options = dict(batch_size=1, n=n, max_n_cntxt=max_n_cntxt, max_n_dc=max_n_dc)

    for _ in range(n_samples):
        xc, yc, xt, yt, dc = ic_batch_generator(gp_datasets, **batch_options)

        # Only the first model output (the predictive distribution) is used.
        p_yCc, *_unused = model(xc, yc, xt, yt, dc)
        predictive = p_yCc.base_dist
        mean = predictive.loc.detach().numpy()
        scale = predictive.scale.detach().numpy()

        visualise_1d(mean, scale, xc, yc, xt, yt)

def cnp_posterior_samples(model, n_samples=10, gp_datasets=gp_datasets, n=200, max_n_cntxt=60, max_n_dc=10):
    """Plot predictions of a model that takes no extra datasets.

    Batches are drawn exactly as in `ic_posterior_samples` (``batch_size=1``),
    but the model is called as ``model(xc, yc, xt, yt)``; the extra datasets
    returned by `ic_batch_generator` are ignored.

    Parameters
    ----------
    model
        Callable as ``model(xc, yc, xt, yt)``; the first returned item must
        expose ``base_dist.loc`` and ``base_dist.scale``.
    n_samples : int
        Number of figures to draw.
    gp_datasets, n, max_n_cntxt, max_n_dc
        Forwarded to `ic_batch_generator`.
    """
    batch_options = dict(batch_size=1, n=n, max_n_cntxt=max_n_cntxt, max_n_dc=max_n_dc)

    for _ in range(n_samples):
        xc, yc, xt, yt, _dc = ic_batch_generator(gp_datasets, **batch_options)

        # Only the first model output (the predictive distribution) is used.
        p_yCc, *_unused = model(xc, yc, xt, yt)
        predictive = p_yCc.base_dist
        mean = predictive.loc.detach().numpy()
        scale = predictive.scale.detach().numpy()

        visualise_1d(mean, scale, xc, yc, xt, yt)

In [None]:
# Plot predictions of the trained model on freshly sampled tasks. Use the
# commented-out call instead when `model` is the plain CNP (no `dc` input).
ic_posterior_samples(model)
# cnp_posterior_samples(model)