In [None]:
# Reload edited modules (e.g. the local `icicl` package) automatically before
# each cell executes, so library edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys

sys.path.append("../")

import argparse
import torch
import neuralprocesses.torch as nps
import icicl
import stheno
import numpy as np
import matplotlib.pyplot as plt
import wbml.plot

from tqdm.auto import tqdm

torch.set_default_dtype(torch.float32)

# Seed the global RNGs this notebook draws from (torch for model initialisation,
# numpy for kernel selection; presumably also the GP samplers) so that
# Restart & Run All reproduces the same results.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Construct ICNP model.

In [None]:
# ICNP hyperparameters.
x_dim = 1
y_dim = 1
r_dim = 32
encoder_num_layers = 3
encoder_width = 128
decoder_num_layers = 3
decoder_width = 128

likelihood = icicl.likelihoods.HeteroscedasticNormalLikelihood()

# Encoder for the target task's context set: an MLP on concatenated (x, y)
# pairs wrapped in a DeepSet.
deepset_mlp = icicl.nn.MLP(
    in_dim=x_dim + y_dim,
    out_dim=r_dim,
    num_layers=encoder_num_layers,
    width=encoder_width,
)
# Use the MLP defined just above (the previous `encoder_mlp` name was undefined).
deepset = icicl.deepset.DeepSet(phi=deepset_mlp)

# Second, independently parameterised encoder, presumably applied to the
# in-context datasets `d_c` (TODO confirm in icicl.deepset.DatasetDeepSet).
dataset_deepset_mlp = icicl.nn.MLP(
    in_dim=x_dim + y_dim,
    out_dim=r_dim,
    num_layers=encoder_num_layers,
    width=encoder_width,
)
dataset_deepset = icicl.deepset.DatasetDeepSet(
    icicl.deepset.DeepSet(dataset_deepset_mlp)
)

encoder = icicl.deepset.ICDeepSet(deepset, dataset_deepset)

# TODO: make this decoder more like Wessel's implementation (presumably the
# `neuralprocesses` package imported above as `nps`).
# in_dim assumes ICDeepSet concatenates the two r_dim representations with the
# target x; TODO confirm.
decoder_mlp = icicl.nn.MLP(
    in_dim=2 * r_dim + x_dim,
    out_dim=likelihood.out_dim_multiplier * y_dim,
    num_layers=decoder_num_layers,
    width=decoder_width,
)
decoder = icicl.cnp.CNPDecoder(decoder_mlp)

model = icicl.models.ICNP(encoder, decoder, likelihood)

In [None]:
def train_iccnp(
    model,
    opt,
    objective,
    batch_size,
    iters,
    mixture_kernels,
    max_n_ic_datasets=10,
    max_n_context=30,
):
    """Train an ICNP on GP tasks drawn from a mixture of kernels.

    Every iteration samples ``batch_size`` tasks. For each task, a kernel is
    picked uniformly at random from ``mixture_kernels`` and a task is drawn with
    ``icicl.utils.gp_ic_sampler``, then ``objective`` is evaluated on it. The
    per-task values are averaged and back-propagated, and one optimiser step is
    taken, so the averaged value is what gets minimised.

    Args:
        model: The ICNP being trained.
        opt: Optimiser over ``model``'s parameters.
        objective: Callable ``objective(model, x_c, y_c, x_t, y_t, d_c)``
            returning a scalar tensor, e.g. ``icicl.objectives.iccnp_objective``.
            ``None`` falls back to ``icicl.objectives.iccnp_objective``, which
            the previous implementation always used regardless of this argument.
        batch_size: Number of tasks averaged per optimiser step (must be >= 1).
        iters: Number of optimiser steps.
        mixture_kernels: Mapping from kernel name to stheno kernel.
        max_n_ic_datasets: Forwarded to ``gp_ic_sampler`` (default matches the
            previously hard-coded value).
        max_n_context: Forwarded to ``gp_ic_sampler`` (default matches the
            previously hard-coded value).

    Returns:
        list[float]: Batch-averaged objective value for each iteration.

    Raises:
        ValueError: If ``batch_size`` < 1 or ``mixture_kernels`` is empty.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}.")
    if not mixture_kernels:
        raise ValueError("mixture_kernels must contain at least one kernel.")
    if objective is None:
        objective = icicl.objectives.iccnp_objective

    # Hoisted: the set of kernel names is fixed for the whole run.
    kernel_names = list(mixture_kernels)

    vals = []
    progress = tqdm(range(iters))
    for _step in progress:
        batch_vals = []
        for _task in range(batch_size):
            # Pick a GP kernel uniformly at random. The sampler presumably draws
            # the number of context points/datasets up to the given maxima.
            kernel_name = kernel_names[np.random.randint(low=0, high=len(kernel_names))]
            kernel = mixture_kernels[kernel_name]

            x_c, y_c, x_t, y_t, d_c = icicl.utils.gp_ic_sampler(
                max_n_ic_datasets=max_n_ic_datasets,
                max_n_context=max_n_context,
                kernel=kernel,
            )

            batch_vals.append(objective(model, x_c, y_c, x_t, y_t, d_c))

        opt.zero_grad()
        batch_val = sum(batch_vals) / len(batch_vals)
        batch_val.backward()
        opt.step()

        batch_val_float = batch_val.item()
        vals.append(batch_val_float)
        progress.set_postfix({"val": batch_val_float})

    return vals


def train_cnp(
    model,
    opt,
    objective,
    batch_size,
    iters,
    mixture_kernels,
    max_n_context=30,
):
    """Train a CNP on GP tasks drawn from a mixture of kernels.

    Every iteration samples ``batch_size`` tasks. For each task, a kernel is
    picked uniformly at random from ``mixture_kernels`` and a task is drawn with
    ``icicl.utils.gp_sampler``, then ``objective`` is evaluated on it. The
    per-task values are averaged and back-propagated, and one optimiser step is
    taken, so the averaged value is what gets minimised.

    Args:
        model: The CNP being trained.
        opt: Optimiser over ``model``'s parameters.
        objective: Callable ``objective(model, x_c, y_c, x_t, y_t)`` returning a
            scalar tensor, e.g. ``icicl.objectives.cnp_objective``. ``None``
            falls back to ``icicl.objectives.cnp_objective``, which the previous
            implementation always used regardless of this argument.
        batch_size: Number of tasks averaged per optimiser step (must be >= 1).
        iters: Number of optimiser steps.
        mixture_kernels: Mapping from kernel name to stheno kernel.
        max_n_context: Forwarded to ``gp_sampler`` (default matches the
            previously hard-coded value).

    Returns:
        list[float]: Batch-averaged objective value for each iteration.

    Raises:
        ValueError: If ``batch_size`` < 1 or ``mixture_kernels`` is empty.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}.")
    if not mixture_kernels:
        raise ValueError("mixture_kernels must contain at least one kernel.")
    if objective is None:
        objective = icicl.objectives.cnp_objective

    # Hoisted: the set of kernel names is fixed for the whole run.
    kernel_names = list(mixture_kernels)

    vals = []
    progress = tqdm(range(iters))
    for _step in progress:
        batch_vals = []
        for _task in range(batch_size):
            # Pick a GP kernel uniformly at random. The sampler presumably draws
            # the number of context points up to `max_n_context`.
            kernel_name = kernel_names[np.random.randint(low=0, high=len(kernel_names))]
            kernel = mixture_kernels[kernel_name]

            x_c, y_c, x_t, y_t = icicl.utils.gp_sampler(
                max_n_context=max_n_context, kernel=kernel, dtype=torch.float32
            )
            batch_vals.append(objective(model, x_c, y_c, x_t, y_t))

        opt.zero_grad()
        batch_val = sum(batch_vals) / len(batch_vals)
        batch_val.backward()
        opt.step()

        batch_val_float = batch_val.item()
        vals.append(batch_val_float)
        progress.set_postfix({"val": batch_val_float})

    return vals

In [None]:
# Set up the optimiser.
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# GP kernels to train on; each training task draws one uniformly at random.
mixture_kernels = {
    "se": stheno.EQ().stretch(0.25),
    # Disabled for ICNP training (similar kernels are used for the CNP below):
    # "matern": stheno.Matern52().stretch(0.25),
    # "weakly-periodic": stheno.EQ().stretch(0.5) * stheno.EQ().periodic(0.5),
    "period": stheno.EQ().periodic(0.5),
}

# The objective lives in `icicl.objectives`; nothing imports it into this
# namespace, so reference it through the package.
train_vals = train_iccnp(
    model,
    opt,
    icicl.objectives.iccnp_objective,
    batch_size=5,
    iters=50_000,
    mixture_kernels=mixture_kernels,
)

In [None]:
# Training curve: the batch-averaged objective recorded at every iteration.
plt.plot(train_vals)
plt.xlabel("Iteration")
plt.ylabel("Batch-mean training objective")
plt.title("ICNP training curve")
plt.show()

In [None]:
def iccnp_visualise_1d(
    model,
    x_c,
    y_c,
    x_t,
    y_t,
    d_c,
    kernel=None,
    noise=0.05,
    x_min=-2.0,
    x_max=2.0,
    num_points=200,
):
    """Plot an ICNP's 1D predictive distribution over a regular input grid.

    The context and target points are scattered. The model's mean is drawn
    with a 95% band (mean +/- 1.96 standard deviations). If ``kernel`` is
    given, the exact GP posterior conditioned on the context is overlaid as
    "Truth".

    Args:
        model: ICNP called as ``model(x_c, y_c, x, d_c)``. Assumed to return a
            distribution with ``.loc`` (mean) and ``.scale`` (standard
            deviation), like ``torch.distributions.Normal``. TODO confirm for
            ``HeteroscedasticNormalLikelihood``. Also assumes a single task
            with ``y_dim == 1``, so outputs can be flattened.
        x_c, y_c: Context inputs/outputs of the target task.
        x_t, y_t: Target inputs/outputs, plotted for reference only.
        d_c: In-context datasets passed through to ``model``.
        kernel: Optional stheno kernel for the ground-truth GP posterior.
        noise: Noise argument given to stheno when conditioning on the context.
        x_min, x_max, num_points: Prediction grid (defaults match the original).
    """
    x = torch.linspace(x_min, x_max, num_points).unsqueeze(-1)
    with torch.no_grad():
        pred_y_t = model(x_c, y_c, x, d_c)
        # `.scale` is used as-is. Squaring it would give the variance, which is
        # the wrong quantity for a +/- 1.96 std band.
        mean = pred_y_t.loc.reshape(-1)
        std = pred_y_t.scale.reshape(-1)
    # fill_between only accepts 1-D arrays; `x` is (num_points, 1).
    x_grid = x.reshape(-1)

    plt.figure(figsize=(8, 6))

    # Plot context and target.
    plt.scatter(x_c, y_c, label="Context", style="train", s=20)
    plt.scatter(x_t, y_t, label="Target", style="test", s=20)

    # Plot prediction with a 95% band.
    err = 1.96 * std
    plt.plot(x_grid, mean, label="Prediction", style="pred")
    plt.fill_between(x_grid, mean - err, mean + err, style="pred")

    if kernel is not None:
        f = stheno.GP(kernel)
        f_post = f | (f(x_c, noise), y_c)
        gp_mean, gp_lower, gp_upper = f_post(x).marginal_credible_bounds()

        plt.plot(x_grid, gp_mean, label="Truth", style="pred2")
        plt.plot(x_grid, gp_lower, style="pred2")
        plt.plot(x_grid, gp_upper, style="pred2")

    plt.xlim(x_min, x_max)
    wbml.plot.tweak()

    plt.show()

In [None]:
# Draw one random task from a randomly chosen training kernel and plot the
# ICNP's prediction against the exact GP posterior. The sampler is called
# exactly as in training, so it presumably picks the number of context
# points/datasets itself (up to the given maxima).
kernel_names = list(mixture_kernels)
kernel_name = kernel_names[np.random.randint(low=0, high=len(kernel_names))]
kernel = mixture_kernels[kernel_name]

# `gp_ic_sampler` is not imported into this namespace; call it via `icicl.utils`.
x_c, y_c, x_t, y_t, d_c = icicl.utils.gp_ic_sampler(
    max_n_ic_datasets=10,
    max_n_context=30,
    kernel=kernel,
)

iccnp_visualise_1d(model, x_c, y_c, x_t, y_t, d_c, kernel=kernel)

# Construct CNP model.

In [None]:
# CNP hyperparameters: a smaller network than the ICNP above.
x_dim, y_dim = 1, 1
r_dim = 16
encoder_num_layers, encoder_width = 3, 64
decoder_num_layers, decoder_width = 3, 64
likelihood = icicl.likelihoods.HeteroscedasticNormalLikelihood()

# Encoder: an MLP over each (x, y) context pair, wrapped in a DeepSet.
cnp_encoder_mlp = icicl.nn.MLP(
    width=encoder_width,
    num_layers=encoder_num_layers,
    out_dim=r_dim,
    in_dim=x_dim + y_dim,
)
cnp_encoder = icicl.deepset.DeepSet(phi=cnp_encoder_mlp)

# Decoder: an MLP over [representation, target x], presumably concatenated.
# Its output width is scaled by the likelihood's parameter count per output.
# TODO: could aggregations be chained together instead of needing separate classes?
cnp_decoder_mlp = icicl.nn.MLP(
    width=decoder_width,
    num_layers=decoder_num_layers,
    out_dim=likelihood.out_dim_multiplier * y_dim,
    in_dim=r_dim + x_dim,
)
cnp_decoder = icicl.cnp.CNPDecoder(cnp_decoder_mlp)
cnp_model = icicl.models.NP(cnp_encoder, cnp_decoder, likelihood)

In [None]:
# Set up the optimiser.
cnp_opt = torch.optim.Adam(cnp_model.parameters(), lr=1e-3)

# GP kernels to train on; each training task draws one uniformly at random.
mixture_kernels = {
    "se": stheno.EQ().stretch(0.25),
    "matern": stheno.Matern52().stretch(0.25),
    "weakly-periodic": stheno.EQ().stretch(0.5) * stheno.EQ().periodic(0.25),
}

# The objective lives in `icicl.objectives`; nothing imports it into this
# namespace, so reference it through the package.
train_vals = train_cnp(
    cnp_model,
    cnp_opt,
    icicl.objectives.cnp_objective,
    batch_size=10,
    iters=50_000,
    mixture_kernels=mixture_kernels,
)

In [None]:
# Training curve: the batch-averaged objective recorded at every iteration.
plt.plot(train_vals)
plt.xlabel("Iteration")
plt.ylabel("Batch-mean training objective")
plt.title("CNP training curve")
plt.show()

In [None]:
def cnp_visualise_1d(
    model,
    x_c,
    y_c,
    x_t,
    y_t,
    kernel=None,
    noise=0.05,
    x_min=-2.0,
    x_max=2.0,
    num_points=200,
):
    """Plot a CNP's 1D predictive distribution over a regular input grid.

    The context and target points are scattered. The model's mean is drawn
    with a 95% band (mean +/- 1.96 standard deviations). If ``kernel`` is
    given, the exact GP posterior conditioned on the context is overlaid as
    "Truth".

    Args:
        model: CNP called as ``model(x_c, y_c, x_t=x)``. Assumed to return a
            distribution with ``.loc`` (mean) and ``.scale`` (standard
            deviation), like ``torch.distributions.Normal``. TODO confirm for
            ``HeteroscedasticNormalLikelihood``. Also assumes a single task
            with ``y_dim == 1``, so outputs can be flattened.
        x_c, y_c: Context inputs/outputs.
        x_t, y_t: Target inputs/outputs, plotted for reference only.
        kernel: Optional stheno kernel for the ground-truth GP posterior.
        noise: Noise argument given to stheno when conditioning on the context.
        x_min, x_max, num_points: Prediction grid (defaults match the original).
    """
    x = torch.linspace(x_min, x_max, num_points).unsqueeze(-1)
    with torch.no_grad():
        pred_y_t = model(x_c, y_c, x_t=x)
        # `.scale` is used as-is. Squaring it would give the variance, which is
        # the wrong quantity for a +/- 1.96 std band.
        mean = pred_y_t.loc.reshape(-1)
        std = pred_y_t.scale.reshape(-1)
    # fill_between only accepts 1-D arrays; `x` is (num_points, 1).
    x_grid = x.reshape(-1)

    plt.figure(figsize=(8, 6))

    # Plot context and target.
    plt.scatter(x_c, y_c, label="Context", style="train", s=20)
    plt.scatter(x_t, y_t, label="Target", style="test", s=20)

    # Plot prediction with a 95% band.
    err = 1.96 * std
    plt.plot(x_grid, mean, label="Prediction", style="pred")
    plt.fill_between(x_grid, mean - err, mean + err, style="pred")

    if kernel is not None:
        f = stheno.GP(kernel)
        f_post = f | (f(x_c, noise), y_c)
        gp_mean, gp_lower, gp_upper = f_post(x).marginal_credible_bounds()

        plt.plot(x_grid, gp_mean, label="Truth", style="pred2")
        plt.plot(x_grid, gp_lower, style="pred2")
        plt.plot(x_grid, gp_upper, style="pred2")

    plt.xlim(x_min, x_max)
    wbml.plot.tweak()

    plt.show()

In [None]:
# Draw one random task from a randomly chosen training kernel and plot the
# CNP's prediction against the exact GP posterior. The sampler is called
# exactly as in training, so it presumably picks the number of context points
# itself (up to `max_n_context`).
kernel_names = list(mixture_kernels)
kernel_name = kernel_names[np.random.randint(low=0, high=len(kernel_names))]
kernel = mixture_kernels[kernel_name]

# `gp_sampler` is not imported into this namespace; call it via `icicl.utils`.
x_c, y_c, x_t, y_t = icicl.utils.gp_sampler(
    max_n_context=30, kernel=kernel, dtype=torch.float32
)

cnp_visualise_1d(cnp_model, x_c, y_c, x_t, y_t, kernel=kernel)