The easiest way to compare the performance of the original RCNP implementation with ours is to replace the train function and its corresponding loglik objective with our own methods, while keeping all other functions the same.

# 1D regression task

We start by importing the necessary dependencies. This implementation is based on PyTorch.

In [1]:
import numpy as np
import numpy.random as npr
import torch
import torch.nn as nn
import torch.optim as optim
import collections
import matplotlib.pyplot as plt
import gc
import datetime
import os
import sys
import time
import warnings
from functools import partial
import experiment as exp
import lab as B
import wbml.out as out
from matrix.util import ToDenseWarning
from wbml.experiment import WorkingDirectory
import neuralprocesses.torch as nps
from neuralprocesses.numdata import num_data

# Allow duplicate copies of the OpenMP runtime to coexist in the process
# (presumably a workaround for a crash seen on some installs — TODO confirm
# it is still required).
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

# Use the GPU when one is available and print which choice was made.
USE_CUDA = torch.cuda.is_available()
print(USE_CUDA)
device = torch.device("cuda") if USE_CUDA else torch.device("cpu")

# Seeded random state that is passed through the train/eval functions below,
# and the device registered with `lab` as its global device.
state = B.create_random_state(torch.float32, seed=0)
B.set_global_device(device)

True


In [8]:
# Model and optimisation settings read by the cells of this notebook.
config = {
    "default": {
        "epochs": None,
        "rate": None,
        "also_ar": False,
    },
    "epsilon": 1e-8,
    "epsilon_start": 1e-2,
    "cholesky_retry_factor": 1e6,
    "fix_noise": None,
    "fix_noise_epochs": 3,
    "width": 256,
    "dim_embedding": 256,
    "relational_width": 64,
    "dim_relational_embeddings": 128,
    "enc_same": False,
    "num_heads": 8,
    "num_layers": 6,
    "num_relational_layers": 3,
    "unet_channels": (64,) * 6,
    "unet_strides": (1,) + (2,) * 5,
    "conv_channels": 64,
    "encoder_scales": None,
    "fullconvgnp_kernel_factor": 2,
    "mean_diff": 0,
    # The performance of the ConvGNP is sensitive to this parameter. Moreover,
    # it doesn't make sense to set it to a value higher than the size of the
    # last hidden layer of the CNN architecture. We therefore set it to 64.
    "num_basis_functions": 64,
    "dim_x": 1,
}

# Experiment settings for this run, kept under the names that the
# attribute-style lookups (`args.data`, `args.epochs`, ...) in later cells use.
args = {
    "dim_x": 1,
    "dim_y": 1,
    "data": "eq",
    "batch_size": 16,
    "epochs": 100,
    "rate": 3e-4,
    "objective": "loglik",
    "num_samples": 20,
    "unnormalised": False,
    "evaluate_num_samples": 512,
    "evaluate_batch_size": 8,
    "train_fast": False,
    "evaluate_fast": True,
}
class mydict(dict):
    """Dictionary whose entries can also be read as attributes.

    Attribute access first looks the name up as a key; only when the key is
    absent does it fall back to the regular attribute lookup. A key therefore
    takes precedence over a `dict` method of the same name.
    """

    def __getattribute__(self, key):
        # Not stored as a key: defer to normal attribute resolution, which
        # raises `AttributeError` for unknown names.
        if key not in self:
            return super().__getattribute__(key)
        return self[key]
        
# Wrap the plain dict so that settings can be read as attributes
# (e.g. `args.data`), which is how the cells below access them.
args = mydict(args)

We define below some global variables for the notebook.

In [9]:
# Look up the setup function registered for the selected data set and build
# the training, cross-validation and evaluation generators. The "fast" flags
# shrink the number of tasks to 2**6.
# NOTE(review): `gen_cv` is later invoked as `gen_cv()`, so it is presumably
# a factory rather than a generator instance — confirm against `exp`.
gen_train, gen_cv, gens_eval = exp.data[args.data]["setup"](
        args,
        config,
        num_tasks_train=2**6 if args.train_fast else 2**14,
        num_tasks_cv=2**6 if args.train_fast else 2**12,
        num_tasks_eval=2**6 if args.evaluate_fast else 2**12,
        device=device,
    )

# NP package train loop

In [10]:
# Training objective: `nps.loglik` with the configured number of samples.
objective = partial(
            nps.loglik,
            num_samples=args.num_samples,
            normalise=not args.unnormalised,
        )
# Cross-validation objective; configured identically to the training one.
objective_cv = partial(
            nps.loglik,
            num_samples=args.num_samples,
            normalise=not args.unnormalised,
        )
# Named evaluation objectives as (label, callable) pairs. Evaluation uses the
# dedicated `evaluate_*` sample count and batch size.
objectives_eval = [
            (
                "Loglik",
                partial(
                    nps.loglik,
                    num_samples=args.evaluate_num_samples,
                    batch_size=args.evaluate_batch_size,
                    normalise=not args.unnormalised,
                ),
            )
        ]

In [11]:
def train(state, model, opt, objective, gen, *, fix_noise):
    """Run one epoch of training and report the training log-likelihood.

    Args:
        state: Random state, passed to and returned by `objective`.
        model: Model handed to `objective`.
        opt: Optimiser that is stepped once per batch.
        objective: Callable returning `(state, values)` for a batch; the
            values are maximised.
        gen: Data generator whose `epoch()` yields batches with the keys
            `"contexts"`, `"xt"` and `"yt"`.
        fix_noise: Forwarded unchanged to `objective`.

    Returns:
        tuple: The updated state and the mean of the collected objective
            values minus 1.96 standard errors.
    """
    epoch_vals = []
    for batch in gen.epoch():
        state, obj = objective(
            state,
            model,
            batch["contexts"],
            batch["xt"],
            batch["yt"],
            fix_noise=fix_noise,
        )
        epoch_vals.append(B.to_numpy(obj))

        # The objective is maximised, so take a step on its negated mean.
        loss = -B.mean(obj)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

    epoch_vals = B.concat(*epoch_vals)
    out.kv("Loglik (T)", exp.with_err(epoch_vals, and_lower=True))
    lower_bound = (
        B.mean(epoch_vals) - 1.96 * B.std(epoch_vals) / B.sqrt(len(epoch_vals))
    )
    return state, lower_bound

def eval(state, model, objective, gen):
    """Evaluate the model over one epoch without tracking gradients.

    Args:
        state: Random state, passed to and returned by `objective`.
        model: Model handed to `objective`.
        objective: Callable returning `(state, values)` for a batch.
        gen: Data generator whose `epoch()` yields batches with the keys
            `"contexts"`, `"xt"` and `"yt"`, and optionally `"pred_logpdf"`
            and `"pred_logpdf_diag"`.

    Returns:
        tuple: The updated state and the mean of the collected objective
            values minus 1.96 standard errors.
    """
    with torch.no_grad():
        logliks, kl_full, kl_diag = [], [], []
        # Optional reference log-densities and the list collecting the
        # corresponding per-batch differences to the objective.
        references = (("pred_logpdf", kl_full), ("pred_logpdf_diag", kl_diag))

        for batch in gen.epoch():
            state, obj = objective(
                state,
                model,
                batch["contexts"],
                batch["xt"],
                batch["yt"],
            )

            # Save numbers.
            n = nps.num_data(batch["xt"], batch["yt"])
            logliks.append(B.to_numpy(obj))
            for key, store in references:
                if key in batch:
                    store.append(B.to_numpy(batch[key] / n - obj))

        # Report numbers.
        logliks = B.concat(*logliks)
        out.kv("Loglik (V)", exp.with_err(logliks, and_lower=True))
        for label, store in (("KL (full)", kl_full), ("KL (diag)", kl_diag)):
            if store:
                out.kv(label, exp.with_err(B.concat(*store), and_upper=True))

        # The objective does not return predictions, so nothing is plotted.

        return state, B.mean(logliks) - 1.96 * B.std(logliks) / B.sqrt(len(logliks))

In [12]:
# Set lab's global epsilon to the configured value.
B.epsilon = config['epsilon']

# Build the relational neural process from the package and move it to the
# selected device.
# NOTE(review): "dim_y" and "transform" are not part of the literal `config`
# defined in this notebook; presumably the data `setup` call adds them —
# confirm, otherwise these lookups raise `KeyError`.
model = nps.construct_rnp(
                dim_x=config["dim_x"],
                dim_yc=(1,) * config["dim_y"],
                dim_yt=config["dim_y"],
                dim_embedding=config["dim_embedding"],
                enc_same=config["enc_same"],
                num_dec_layers=config["num_layers"],
                width=config["width"],
                relational_width=config['relational_width'],
                num_relational_enc_layers=config['num_relational_layers'],
                likelihood="het",
                transform=config["transform"],
            )
model = model.to(device)

In [13]:
# NOTE(review): `best_eval_lik` is initialised here but never updated or read
# in this cell.
best_eval_lik = -np.inf

# Setup training loop.
opt = torch.optim.Adam(model.parameters(), args.rate)

# Set regularisation high for the first epoch (it is reset once `i > 0`).
original_epsilon = B.epsilon
B.epsilon = config["epsilon_start"]

for i in range(0, args.epochs):
    
    with out.Section(f"Epoch {i + 1}"):
        # Set regularisation to normal after the first epoch.
        if i > 0:
            B.epsilon = original_epsilon

        # Perform an epoch. The noise is only fixed during the first
        # `fix_noise_epochs` epochs, and only when `config["fix_noise"]` is
        # truthy.
        if config["fix_noise"] and i < config["fix_noise_epochs"]:
            fix_noise = 1e-4
        else:
            fix_noise = None
        state, _ = train(
            state,
            model,
            opt,
            objective,
            gen_train,
            fix_noise=fix_noise,
        )

        # The epoch is done. Now evaluate on a freshly created CV generator.
        # NOTE(review): `val` is not used after this call.
        state, val = eval(state, model, objective_cv, gen_cv())
        # Release cached GPU memory and run the garbage collector between
        # epochs.
        torch.cuda.empty_cache()
        gc.collect()

Epoch 1:
    Loglik (T):   -1.22006 +-    0.00481 (  -1.22488)
    Loglik (V):   -0.99503 +-    0.00976 (  -1.00479)
    KL (full):     0.68121 +-    0.00833 (   0.68954)
    KL (diag):     0.32053 +-    0.00776 (   0.32829)
Epoch 2:
    Loglik (T):   -0.90953 +-    0.00489 (  -0.91442)
    Loglik (V):   -0.82819 +-    0.01027 (  -0.83846)
    KL (full):     0.59246 +-    0.00811 (   0.60057)
    KL (diag):     0.20899 +-    0.00603 (   0.21502)
Epoch 3:
    Loglik (T):   -0.80329 +-    0.00502 (  -0.80831)
    Loglik (V):   -0.75287 +-    0.01044 (  -0.76331)
    KL (full):     0.51714 +-    0.00784 (   0.52498)
    KL (diag):     0.13367 +-    0.00437 (   0.13804)
Epoch 4:
    Loglik (T):   -0.78931 +-    0.00525 (  -0.79456)
    Loglik (V):   -0.74758 +-    0.01012 (  -0.75771)
    KL (full):     0.51185 +-    0.00751 (   0.51936)
    KL (diag):     0.12839 +-    0.00412 (   0.13251)
Epoch 5:
    Loglik (T):   -0.75829 +-    0.00516 (  -0.76346)
    Loglik (V):   -0.75133 +-    0.01

    Loglik (T):   -0.72756 +-    0.00536 (  -0.73292)
    Loglik (V):   -0.70797 +-    0.01045 (  -0.71841)
    KL (full):     0.47223 +-    0.00769 (   0.47992)
    KL (diag):     0.08877 +-    0.00333 (   0.09210)
Epoch 39:
    Loglik (T):   -0.71920 +-    0.00531 (  -0.72451)
    Loglik (V):   -0.69649 +-    0.01040 (  -0.70689)
    KL (full):     0.46075 +-    0.00757 (   0.46833)
    KL (diag):     0.07729 +-    0.00300 (   0.08029)
Epoch 40:
    Loglik (T):   -0.72155 +-    0.00535 (  -0.72689)
    Loglik (V):   -0.69557 +-    0.01050 (  -0.70607)
    KL (full):     0.45984 +-    0.00769 (   0.46752)
    KL (diag):     0.07637 +-    0.00304 (   0.07941)
Epoch 41:
    Loglik (T):   -0.71047 +-    0.00521 (  -0.71569)
    Loglik (V):   -0.69522 +-    0.01047 (  -0.70570)
    KL (full):     0.45949 +-    0.00765 (   0.46714)
    KL (diag):     0.07603 +-    0.00292 (   0.07895)
Epoch 42:
    Loglik (T):   -0.69956 +-    0.00527 (  -0.70483)
    Loglik (V):   -0.69263 +-    0.01059 (

    Loglik (T):   -0.69614 +-    0.00541 (  -0.70155)
    Loglik (V):   -0.68638 +-    0.01058 (  -0.69696)
    KL (full):     0.45064 +-    0.00772 (   0.45836)
    KL (diag):     0.06718 +-    0.00271 (   0.06989)
Epoch 76:
    Loglik (T):   -0.71482 +-    0.00541 (  -0.72023)
    Loglik (V):   -0.68723 +-    0.01059 (  -0.69782)
    KL (full):     0.45150 +-    0.00772 (   0.45922)
    KL (diag):     0.06803 +-    0.00282 (   0.07085)
Epoch 77:
    Loglik (T):   -0.68920 +-    0.00530 (  -0.69449)
    Loglik (V):   -0.68225 +-    0.01050 (  -0.69276)
    KL (full):     0.44652 +-    0.00763 (   0.45415)
    KL (diag):     0.06306 +-    0.00264 (   0.06570)
Epoch 78:
    Loglik (T):   -0.69099 +-    0.00536 (  -0.69635)
    Loglik (V):   -0.68031 +-    0.01053 (  -0.69084)
    KL (full):     0.44458 +-    0.00765 (   0.45223)
    KL (diag):     0.06111 +-    0.00256 (   0.06367)
Epoch 79:
    Loglik (T):   -0.70446 +-    0.00533 (  -0.70980)
    Loglik (V):   -0.68155 +-    0.01058 (

# NP package training loop + own class

### Encoder

In [8]:
class CNPDeterministicEncoder(nn.Module):
    """MLP encoder that summarises a context set as one mean-pooled vector.

    Args:
        sizes: Layer widths of the MLP. Consecutive pairs define the
            `nn.Linear` layers, so `sizes[0]` must equal the combined last
            dimension of `context_x` and `context_y`.
    """

    def __init__(self, sizes):
        super().__init__()
        self.linears = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
        )

    def forward(self, context_x, context_y):
        """
        Encode training set as one vector representation

        Args:
            context_x:  batch_size x set_size x feature_dim
            context_y:  batch_size x set_size x 1

        Returns:
            representation: batch_size x sizes[-1], the mean over the set
                dimension of the per-point MLP outputs.
        """
        encoder_input = torch.cat((context_x, context_y), dim=-1)

        # Flatten batch and set dimensions so the MLP sees one point per row.
        batch_size, set_size, _ = encoder_input.shape
        x = encoder_input.view(batch_size * set_size, -1)
        for linear in self.linears[:-1]:
            x = torch.relu(linear(x))
        x = self.linears[-1](x)

        # Mean pooling over the set dimension makes the representation
        # independent of the order of the context points.
        # NOTE: a disabled experiment that appended the context-set size to
        # the representation used to live here; it was unreachable and has
        # been removed.
        return x.view(batch_size, set_size, -1).mean(dim=1)

### Decoder

In [9]:
class CNPDeterministicDecoder(nn.Module):
    """MLP decoder producing a Gaussian prediction for every target input.

    Args:
        sizes: Layer widths of the MLP; consecutive pairs define the
            `nn.Linear` layers. The last width is split into chunks of size
            one that are unpacked as mean and raw scale.
    """

    def __init__(self, sizes):
        super().__init__()
        self.linears = nn.ModuleList()
        for n_in, n_out in zip(sizes[:-1], sizes[1:]):
            self.linears.append(nn.Linear(n_in, n_out))

    def forward(self, representation, target_x):
        """
        Combine the representation of the current training set with each
        target input and return the predictive distribution at that input
        (Gaussian with mean mu and scale sigma).

        Args:
            representation: batch_size x representation_size
            target_x: batch_size x set_size x d

        Returns:
            The `Normal` distribution together with its mean and scale.
        """
        batch_size, set_size, _ = target_x.shape

        # Give every target point its own copy of the representation.
        tiled = representation.unsqueeze(1).repeat([1, set_size, 1])
        decoder_input = torch.cat((tiled, target_x), dim=-1)

        hidden = decoder_input.view(batch_size * set_size, -1)
        for layer in self.linears[:-1]:
            hidden = torch.relu(layer(hidden))
        raw = self.linears[-1](hidden).view(batch_size, set_size, -1)

        mu, log_sigma = torch.split(raw, 1, dim=-1)
        # Softplus keeps the scale positive; the offset bounds it below by 0.01.
        sigma = 0.01 + 0.99 * torch.nn.functional.softplus(log_sigma)
        dist = torch.distributions.normal.Normal(loc=mu, scale=sigma)
        return dist, mu, sigma

### Relational Encoder

In [10]:
class RelationalEncoder(nn.Module):
    """MLP over target-minus-context differences, mean-pooled over the context.

    Args:
        sizes: Layer widths of the MLP. Consecutive pairs define the
            `nn.Linear` layers, so `sizes[0]` must equal the combined last
            dimension of `context_x` and `context_y`.
    """

    def __init__(self, sizes):
        super().__init__()
        self.linears = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
        )

    def forward(self, context_x, context_y, target_x):
        """
        Encode target point as relational representation with the context set.

        Args:
            context_x:  batch_size x set_size x feature_dim
            context_y:  batch_size x set_size x out_dim
            target_x:   batch_size x target_set_size x feature_dim

        Returns:
            encoded_target_x: batch_size x target_set_size x relational_dim
        """
        batch_size, set_size, _ = context_x.shape
        _, target_set_size, _ = target_x.shape
        out_dim = context_y.shape[-1]

        # Compute difference between target and context set
        # (we also concatenate y_i to the context, and 0 for the target).
        context_xp = torch.cat((context_x, context_y), dim=-1).unsqueeze(1)
        # Take dtype and device from `target_x` rather than from a global, so
        # the padding always matches the tensor it is concatenated with.
        target_pad = target_x.new_zeros(batch_size, target_set_size, out_dim)
        target_xp = torch.cat((target_x, target_pad), dim=-1).unsqueeze(2)

        # Broadcasting yields one difference per (target, context) pair:
        # batch_size x target_set_size x set_size x (feature_dim + out_dim).
        diff_x = target_xp - context_xp

        # One pair per row for the MLP.
        x = diff_x.reshape(-1, diff_x.shape[-1])
        for linear in self.linears[:-1]:
            x = torch.relu(linear(x))
        x = self.linears[-1](x)

        # Restore the pair layout and average over the context points.
        x = x.view(batch_size, target_set_size, set_size, x.shape[-1])
        encoded_target_x = x.mean(dim=2)

        return encoded_target_x

### RCNP Model

In [11]:
class RCNPDeterministicModel(nn.Module):
    """Relational CNP built from a relational encoder, an encoder and a decoder.

    Args:
        relational_sizes: Layer widths of the `RelationalEncoder` MLP.
        encoder_sizes: Layer widths of the `CNPDeterministicEncoder` MLP.
        decoder_sizes: Layer widths of the `CNPDeterministicDecoder` MLP.
    """

    def __init__(self, relational_sizes, encoder_sizes, decoder_sizes):
        super().__init__()
        self._relational_encoder = RelationalEncoder(relational_sizes)
        self._encoder = CNPDeterministicEncoder(encoder_sizes)
        self._decoder = CNPDeterministicDecoder(decoder_sizes)

    def forward(self, contexts, target_x, target_y=None):
        """Predict at `target_x` given the first context set.

        Args:
            contexts: Sequence of `(context_x, context_y)` pairs; only the
                first pair is used.
            target_x: Target inputs.
            target_y: Optional target outputs. When omitted, no
                log-probability is computed.

        Returns:
            tuple: `(log_p, mu, sigma)`, where `log_p` is `None` if
                `target_y` is `None`.
        """
        context_x, context_y = contexts[0]

        # The sub-modules expect batch x set x dim; the inputs are transposed
        # to get there (presumably they arrive as batch x dim x set — confirm
        # against the data generator).
        context_x = B.transpose(context_x)
        context_y = B.transpose(context_y)
        target_x = B.transpose(target_x)
        # `target_y` is optional, so only transpose it when it was supplied.
        if target_y is not None:
            target_y = B.transpose(target_y)

        # Describe each context point relative to the whole context set and
        # summarise the set as a single representation.
        encoded_context_x = self._relational_encoder(context_x, context_y, context_x)
        representation = self._encoder(encoded_context_x, context_y)

        # Describe each target point relative to the context set and decode.
        encoded_target_x = self._relational_encoder(context_x, context_y, target_x)
        dist, mu, sigma = self._decoder(representation, encoded_target_x)

        log_p = None if target_y is None else dist.log_prob(target_y)
        return log_p, mu, sigma

In [12]:
def train_rnp(state, model, opt, objective, gen, *, fix_noise):
    """Run one training epoch for a model that returns log-probabilities.

    The signature mirrors `train`; `objective` and `fix_noise` are accepted
    but not used, and `state` is returned unchanged.

    Args:
        state: Random state, returned as is.
        model: Callable `model(contexts, xt, yt)` whose first output is the
            log-probability of the targets.
        opt: Optimiser that is stepped once per batch.
        objective: Unused.
        gen: Data generator whose `epoch()` yields batches with the keys
            `'contexts'`, `'xt'` and `'yt'`.
        fix_noise: Unused.

    Returns:
        tuple: `state` and the mean of the collected objective values minus
            1.96 standard errors.
    """
    epoch_vals = []
    for batch in gen.epoch():
        xt, yt = batch['xt'], batch['yt']
        point_logp, _, _ = model(batch['contexts'], xt, yt)

        # Sum over dimension 1, then apply a log-mean-exp over a sample axis
        # of length one (hence the `log(1)` term).
        task_logp = torch.sum(point_logp, dim=1)
        task_logp = B.logsumexp(task_logp.reshape(1, -1), axis=0) - B.log(1)
        # Normalise by the number of data points.
        obj = task_logp / B.cast(torch.float64, num_data(xt, yt))

        epoch_vals.append(B.to_numpy(obj))

        # The objective is maximised, so take a step on its negated mean.
        loss = -B.mean(obj)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

    epoch_vals = B.concat(*epoch_vals)
    out.kv("Loglik (T)", exp.with_err(epoch_vals, and_lower=True))
    lower_bound = (
        B.mean(epoch_vals) - 1.96 * B.std(epoch_vals) / B.sqrt(len(epoch_vals))
    )
    return state, lower_bound


def eval_rnp(state, model, objective, gen):
    """Evaluate a log-probability model over one epoch without gradients.

    The signature mirrors `eval`; `objective` is accepted but not used, and
    `state` is returned unchanged.

    Args:
        state: Random state, returned as is.
        model: Callable `model(contexts, xt, yt)` whose first output is the
            log-probability of the targets.
        objective: Unused.
        gen: Data generator whose `epoch()` yields batches with the keys
            `'contexts'`, `'xt'` and `'yt'`, and optionally `"pred_logpdf"`
            and `"pred_logpdf_diag"`.

    Returns:
        tuple: `state` and the mean of the collected objective values minus
            1.96 standard errors.
    """
    with torch.no_grad():
        logliks, kl_full, kl_diag = [], [], []
        # Optional reference log-densities and the list collecting the
        # corresponding per-batch differences to the objective.
        references = (("pred_logpdf", kl_full), ("pred_logpdf_diag", kl_diag))

        for batch in gen.epoch():
            point_logp, _, _ = model(batch['contexts'], batch['xt'], batch['yt'])

            # Sum over dimension 1, then apply a log-mean-exp over a sample
            # axis of length one (hence the `log(1)` term).
            task_logp = torch.sum(point_logp, dim=1)
            task_logp = B.logsumexp(task_logp.reshape(1, -1), axis=0) - B.log(1)
            obj = task_logp / B.cast(
                torch.float64, num_data(batch['xt'], batch['yt'])
            )

            # Save numbers.
            n = nps.num_data(batch["xt"], batch["yt"])
            logliks.append(B.to_numpy(obj))
            for key, store in references:
                if key in batch:
                    store.append(B.to_numpy(batch[key] / n - obj))

        # Report numbers.
        logliks = B.concat(*logliks)
        out.kv("Loglik (V)", exp.with_err(logliks, and_lower=True))
        for label, store in (("KL (full)", kl_full), ("KL (diag)", kl_diag)):
            if store:
                out.kv(label, exp.with_err(B.concat(*store), and_upper=True))

        return state, B.mean(logliks) - 1.96 * B.std(logliks) / B.sqrt(len(logliks))

In [13]:
# Seed PyTorch so that the parameter initialisation below is reproducible.
torch.manual_seed(0)

# Sizes of the layers of the MLPs for the relational encoder, encoder and
# decoder. The final output layer of the decoder outputs `d_out` = 2 values
# per target location: the predictive mean and the raw value that is mapped
# to the predictive scale.
d_x, d_in, representation_size, relational_size, d_out = 1, 2, 128, 64, 2
# Relational encoder input: an x-difference concatenated with a y-value.
relational_sizes = [d_in, 128, 128, relational_size]
# Encoder input: relational encoding of a context point plus its y-value.
encoder_sizes = [relational_size + 1, 128, 128, 128, representation_size]
# Decoder input: context representation plus relational encoding of a target.
decoder_sizes = [representation_size + relational_size, 128, 128, d_out]

original_model = RCNPDeterministicModel(relational_sizes, encoder_sizes, decoder_sizes)
# Keep the parameters on the same device as the generated data, as is done
# for the package model; without this the model stays on the CPU.
original_model = original_model.to(device)

In [14]:
# NOTE(review): `best_eval_lik` is initialised here but never updated or read
# in this cell.
best_eval_lik = -np.inf

# Setup training loop.
opt = torch.optim.Adam(original_model.parameters(), args.rate)

# Set regularisation high for the first epoch (it is reset once `i > 0`).
original_epsilon = B.epsilon
B.epsilon = config["epsilon_start"]

for i in range(0, args.epochs):
    with out.Section(f"Epoch {i + 1}"):
        # Set regularisation to normal after the first epoch.
        if i > 0:
            B.epsilon = original_epsilon

        # Perform an epoch. `train_rnp` accepts `objective` and `fix_noise`
        # only for signature compatibility and does not use them.
        if config["fix_noise"] and i < config["fix_noise_epochs"]:
            fix_noise = 1e-4
        else:
            fix_noise = None
        state, _ = train_rnp(
            state,
            original_model,
            opt,
            objective,
            gen_train,
            fix_noise=fix_noise,
        )

        # The epoch is done. Now evaluate on a freshly created CV generator.
        # NOTE(review): `val` is not used after this call.
        state, val = eval_rnp(state, original_model, objective_cv, gen_cv())

Epoch 1:
    Loglik (T):   -1.19819 +-    0.00480 (  -1.20300)
    Loglik (V):   -1.03212 +-    0.00920 (  -1.04133)
    KL (full):     0.71577 +-    0.00786 (   0.72363)
    KL (diag):     0.34565 +-    0.00736 (   0.35301)
Epoch 2:
    Loglik (T):   -0.95546 +-    0.00486 (  -0.96032)
    Loglik (V):   -0.86583 +-    0.00967 (  -0.87550)
    KL (full):     0.62746 +-    0.00767 (   0.63513)
    KL (diag):     0.23437 +-    0.00639 (   0.24075)
Epoch 3:
    Loglik (T):   -0.83715 +-    0.00497 (  -0.84212)
    Loglik (V):   -0.82150 +-    0.01017 (  -0.83166)
    KL (full):     0.58313 +-    0.00790 (   0.59103)
    KL (diag):     0.19004 +-    0.00527 (   0.19531)
Epoch 4:
    Loglik (T):   -0.80517 +-    0.00496 (  -0.81013)
    Loglik (V):   -0.78249 +-    0.00997 (  -0.79246)
    KL (full):     0.54412 +-    0.00748 (   0.55161)
    KL (diag):     0.15103 +-    0.00476 (   0.15579)
Epoch 5:
    Loglik (T):   -0.78505 +-    0.00509 (  -0.79014)
    Loglik (V):   -0.77427 +-    0.00