In [1]:
import numpy as np
import numpy.random as npr
import torch
import torch.nn as nn
import torch.optim as optim
import collections
import matplotlib.pyplot as plt
import gc
import datetime
import os
import sys
import time
import warnings
from functools import partial
import experiment as exp
import lab as B
import wbml.out as out
from matrix.util import ToDenseWarning
from wbml.experiment import WorkingDirectory
import neuralprocesses.torch as nps
from neuralprocesses.numdata import num_data

os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

In [2]:
class mydict(dict):
    """Dictionary whose entries can also be read with attribute syntax.

    Keys win over ordinary attributes: ``d.name`` returns ``d["name"]``
    whenever ``"name"`` is a key (even if it shadows a ``dict`` method), and
    only otherwise falls back to the regular attribute lookup, which raises
    ``AttributeError`` for unknown names.
    """

    def __getattribute__(self, key):
        # Anything that is not a key is resolved by the normal machinery, so
        # dict methods stay reachable unless a key of the same name exists.
        if key not in self:
            return super().__getattribute__(key)
        return self[key]
        
        
def get_dir(args, suffix, observe, seed, model_name, dim_x, dim_y):
    """Create the `WorkingDirectory` for one (data, model, objective, seed) run.

    The path components are, in order: `args.root`, the optional
    `args.subdir`, `args.data`, an ``x{dim_x}_y{dim_y}`` component (only if
    `args` has a `dim_x` attribute), `model_name`, `args.arch` (only if
    present), `args.objective` and the seed.

    Args:
        args: Settings object read via attributes (`root`, `subdir`, `data`,
            `objective`, optionally `dim_x` and `arch`).
        suffix (str): Inserted into the log and diff file names.
        observe: Forwarded unchanged as `WorkingDirectory(observe=...)`.
        seed: Seed of the run; converted to a string for the path.
        model_name (str): Name of the model.
        dim_x: Input dimensionality used in the path component.
        dim_y: Output dimensionality used in the path component.

    Returns:
        WorkingDirectory: The constructed working directory.
    """
    parts = list(args.root)
    parts.extend(args.subdir or ())
    parts.append(args.data)
    if hasattr(args, "dim_x"):
        parts.append(f"x{dim_x}_y{dim_y}")
    parts.append(model_name)
    if hasattr(args, "arch"):
        parts.append(args.arch)
    parts.append(args.objective)
    parts.append(str(seed))

    return WorkingDirectory(
        *parts,
        log=f"log{suffix}.txt",
        diff=f"diff{suffix}.txt",
        observe=observe,
    )

        
def construct_model(name):
    """Build the neural process called `name` from the module-level `config`.

    Supported names and what they map to:

    * ``"cnp"``  -- `nps.construct_gnp` with ``likelihood="het"``
    * ``"gnp"``  -- `nps.construct_gnp` with ``likelihood="lowrank"`` and
      ``num_basis_functions`` taken from `config`
    * ``"rcnp"`` -- `nps.construct_rnp` with ``likelihood="het"``
    * ``"rgnp"`` -- `nps.construct_rnp` with ``likelihood="lowrank"``

    NOTE(review): `config["dim_x"]`, `config["dim_y"]` and
    `config["transform"]` are read here but are not part of the `config`
    literal in this notebook; presumably the data setup call fills them in
    before this function runs -- verify.

    Args:
        name (str): Model name, one of the four listed above.

    Returns:
        The model object returned by the `nps` constructor.

    Raises:
        ValueError: If `name` is not a supported model name.
    """
    # name -> (constructor family, likelihood)
    variants = {
        "cnp": ("gnp", "het"),
        "gnp": ("gnp", "lowrank"),
        "rcnp": ("rnp", "het"),
        "rgnp": ("rnp", "lowrank"),
    }
    # Validate first: previously an unknown name fell through every branch and
    # failed with an opaque UnboundLocalError at the return statement.
    if name not in variants:
        raise ValueError(
            f"Unknown model name {name!r}; expected one of {sorted(variants)}."
        )
    family, likelihood = variants[name]

    # Keyword arguments shared by all four variants.
    kwargs = dict(
        dim_x=config["dim_x"],
        dim_yc=(1,) * config["dim_y"],
        dim_yt=config["dim_y"],
        dim_embedding=config["dim_embedding"],
        enc_same=config["enc_same"],
        num_dec_layers=config["num_layers"],
        width=config["width"],
        likelihood=likelihood,
        transform=config["transform"],
    )

    if family == "rnp":
        kwargs["relational_width"] = config["relational_width"]
        kwargs["num_relational_enc_layers"] = config["num_relational_layers"]
        return nps.construct_rnp(**kwargs)

    # Only the low-rank GNP passes an explicit number of basis functions.
    if likelihood == "lowrank":
        kwargs["num_basis_functions"] = config["num_basis_functions"]
    return nps.construct_gnp(**kwargs)



In [3]:
# Model / numerical settings. Besides being read in this notebook (e.g. by
# `construct_model` and for `B.epsilon`), the whole dict is handed to the data
# setup routine, so the key order is kept exactly as it was.
config = {
    "default": {
        "epochs": None,
        "rate": None,
        "also_ar": False,
    },
    "epsilon": 1e-8,
    "epsilon_start": 1e-2,
    "cholesky_retry_factor": 1e6,
    "fix_noise": None,
    "fix_noise_epochs": 3,
    "width": 256,
    "dim_embedding": 256,
    "relational_width": 64,
    "dim_relational_embeddings": 128,
    "enc_same": False,
    "num_heads": 8,
    "num_layers": 6,
    "num_relational_layers": 3,
    "unet_channels": (64,) * 6,
    "unet_strides": (1,) + (2,) * 5,
    "conv_channels": 64,
    "encoder_scales": None,
    "fullconvgnp_kernel_factor": 2,
    "mean_diff": 0,
    "num_basis_functions": 64,
}

# Run settings; wrapped into a `mydict` later so they can be read as attributes.
args = {
    "dim_x": 1,
    "dim_y": 1,
    "data": "eq",
    "batch_size": 16,
    "epochs": 100,
    "rate": 3e-4,
    "objective": "loglik",
    "num_samples": 20,
    "unnormalised": False,
    "evaluate_num_samples": 512,
    "evaluate_batch_size": 8,
    "evaluate_num_plots": 5,
    "train_fast": True,
    "evaluate_fast": False,
    "seed": 1,
    "root": ["_experiments"],
    "subdir": None,
}

In [4]:
# Passed to `get_dir` when the per-model working directories are opened.
suffix = "_evaluate"
observe = True

# Wrap the plain dict so that settings can be read as `args.<key>`.
args = mydict(args)

B.epsilon = config["epsilon"]

# Evaluation is pinned to the CPU. To use a GPU when one is available, build
# the device as: torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cpu")
B.set_global_device(device)
print(device)

cpu


In [5]:
# data generation
def get_gen_eval(args, config):
    """Run the data setup for `args.data` and return its evaluation generators.

    The setup routine is looked up as ``exp.data[args.data]["setup"]`` and is
    called with `args`, `config`, the task counts below and the module-level
    `device`. It returns three items; only the third is kept.

    Args:
        args: Settings object; `data`, `train_fast` and `evaluate_fast` are
            read here, and the whole object is forwarded to the setup routine.
        config (dict): Forwarded to the setup routine.

    Returns:
        The third item returned by the setup routine (called as
        ``gens_eval()`` by the evaluation cells).
    """
    setup = exp.data[args.data]["setup"]

    # The "fast" switches shrink every task count to 2**6.
    n_train = 2**6 if args.train_fast else 2**14
    n_cv = 2**6 if args.train_fast else 2**12
    n_eval = 2**6 if args.evaluate_fast else 2**12

    _gen_train, _gen_cv, gens_eval = setup(
        args,
        config,
        num_tasks_train=n_train,
        num_tasks_cv=n_cv,
        num_tasks_eval=n_eval,
        device=device,
    )
    return gens_eval


# Evaluation objective: `nps.loglik` with the evaluation sample count and
# batch size bound; normalisation is on unless `args.unnormalised` is set.
objective_eval = partial(
    nps.loglik,
    num_samples=args.evaluate_num_samples,
    batch_size=args.evaluate_batch_size,
    normalise=not args.unnormalised,
)

In [121]:
def eval(state, model, objective, gen):
    """Evaluate `model` on every batch of `gen` and report averaged metrics.

    NOTE: this function shadows the built-in `eval`; the name is kept because
    the evaluation cells call it.

    Args:
        state: Random state; threaded through `objective` and returned.
        model: Model passed to `objective`.
        objective: Callable invoked as
            ``objective(state, model, contexts, xt, yt)`` and returning
            ``(state, values)``.
        gen: Object whose ``epoch()`` yields batches with the keys
            ``"contexts"``, ``"xt"``, ``"yt"`` and, optionally,
            ``"pred_logpdf"`` and ``"pred_logpdf_diag"``.

    Returns:
        tuple: ``(state, mean objective, mean KL (full), mean KL (diag))``.
        A mean is ``nan`` when no values were collected for it, e.g. when the
        batches carry no reference log-pdf.
    """

    def _mean(chunks):
        # Never call `B.concat` without arguments: the KL lists stay empty
        # when the batches have no reference log-pdfs.
        return B.mean(B.concat(*chunks)) if chunks else float("nan")

    with torch.no_grad():
        vals, kls, kls_diag = [], [], []
        for batch in gen.epoch():
            state, obj = objective(
                state,
                model,
                batch["contexts"],
                batch["xt"],
                batch["yt"],
            )

            # The reference log-pdfs are divided by the number of data points
            # before the objective is subtracted.
            n = nps.num_data(batch["xt"], batch["yt"])

            vals.append(B.to_numpy(obj))
            if "pred_logpdf" in batch:
                kls.append(B.to_numpy(batch["pred_logpdf"] / n - obj))
            if "pred_logpdf_diag" in batch:
                kls_diag.append(B.to_numpy(batch["pred_logpdf_diag"] / n - obj))

        # Report numbers.
        val_mean = _mean(vals)
        kl_mean = _mean(kls)
        kl_diag_mean = _mean(kls_diag)

        out.kv("Loglik (V)", val_mean)
        if kls:
            out.kv("KL (full)", kl_mean)
        if kls_diag:
            out.kv("KL (diag)", kl_diag_mean)

        return state, val_mean, kl_mean, kl_diag_mean

# Trivial baseline (Gaussian fitted to the context set)

In [9]:
def eval_trivial(state, gen, min_context_size=5):
    """Evaluate a trivial baseline on every batch of `gen`.

    For each batch element a univariate normal is built from the mean and the
    standard deviation (taken over the last axis) of
    ``batch["contexts"][0][1]``, and the mean log-probability of that
    element's flattened targets ``batch["yt"][i]`` is used as the objective.

    Args:
        state: Returned unchanged; accepted so the call mirrors `eval`.
        gen: Object whose ``epoch()`` yields batches with the keys
            ``"contexts"``, ``"xt"``, ``"yt"`` and, optionally,
            ``"pred_logpdf"`` and ``"pred_logpdf_diag"``.
        min_context_size (int, optional): Batches whose context tensor has
            ``shape[2] <= min_context_size`` are skipped. Defaults to 5.

    Returns:
        tuple: ``(state, mean objective, mean KL (full), mean KL (diag))``.
        A mean is ``nan`` when no values were collected for it (all batches
        skipped, or no reference log-pdf in the batches).
    """

    def _mean(chunks):
        # Never call `B.concat` without arguments: a list stays empty when
        # every batch is skipped or the reference log-pdfs are missing.
        return B.mean(B.concat(*chunks)) if chunks else float("nan")

    with torch.no_grad():
        vals, kls, kls_diag = [], [], []
        for batch in gen.epoch():
            # assumes this tensor is (batch, dim_y, num_context) -- TODO confirm
            yc = batch["contexts"][0][1]
            if yc.shape[2] <= min_context_size:
                continue

            # Squeeze only the trailing axis: an argument-less squeeze() would
            # also drop the batch axis when the batch has a single element,
            # which breaks len(m) below.
            m = torch.mean(yc, dim=-1).squeeze(-1)
            std = torch.std(yc, dim=-1).squeeze(-1)

            batch_size = len(m)

            n = nps.num_data(batch["xt"], batch["yt"])

            # The reference log-pdfs are optional; check once per batch.
            has_full = "pred_logpdf" in batch
            has_diag = "pred_logpdf_diag" in batch

            val = torch.zeros(batch_size)
            kl = torch.zeros(batch_size)
            kl_diag = torch.zeros(batch_size)
            for i in range(batch_size):
                dist = torch.distributions.normal.Normal(m[i], std[i])
                obj = torch.mean(dist.log_prob(batch["yt"][i].reshape(-1)))
                val[i] = obj

                if has_full:
                    kl[i] = batch["pred_logpdf"][i] / n[i] - obj
                if has_diag:
                    kl_diag[i] = batch["pred_logpdf_diag"][i] / n[i] - obj

            vals.append(B.to_numpy(val))
            if has_full:
                kls.append(B.to_numpy(kl))
            if has_diag:
                kls_diag.append(B.to_numpy(kl_diag))

        # Report numbers.
        val_mean = _mean(vals)
        kl_mean = _mean(kls)
        kl_diag_mean = _mean(kls_diag)

        out.kv("Loglik (V)", val_mean)
        if kls:
            out.kv("KL (full)", kl_mean)
        if kls_diag:
            out.kv("KL (diag)", kl_diag_mean)

        return state, val_mean, kl_mean, kl_diag_mean

In [14]:
# Evaluate the trivial baseline on the dim_x = 5 data.
num_replicates = 1
args["dim_x"] = 5
gens_eval = get_gen_eval(args, config)

with out.Section("Evaluate on Trivial"):
    # Axes of `results`: (dataset, metric, replicate) -- 3 datasets, 3 metrics.
    results = np.zeros((3, 3, num_replicates))
    for seed in range(1, num_replicates + 1):
        with out.Section(f"Seed = {seed}"):
            state = B.create_random_state(torch.float32, seed=seed)

            for i, (gen_name, gen) in enumerate(gens_eval()):
                with out.Section(gen_name.capitalize()):
                    state, loglik, kl_full, kl_diag = eval_trivial(state, gen)
                    results[i, :, seed - 1] = (loglik, kl_full, kl_diag)

    with out.Section(f"Statistics over {num_replicates} replicates"):
        for i, (gen_name, _) in enumerate(gens_eval()):
            with out.Section(gen_name.capitalize()):
                for j, label in enumerate(("Loglik (V)", "KL (full)", "KL (diag)")):
                    out.kv(label, exp.with_err(results[i, j]))

Evaluate on Trivial:
    Seed = 1:
        Interpolation in training range:
            Loglik (V): -1.472
            KL (full):  0.0858
            KL (diag):  0.05125
        Interpolation beyond training range:
            Loglik (V): -1.472
            KL (full):  0.0858
            KL (diag):  0.05125
        Extrapolation beyond training range:
            Loglik (V): -1.476
            KL (full):  0.06804
            KL (diag):  0.03124
    Statistics over 1 replicates:
        Interpolation in training range:
            Loglik (V):   -1.47219 +-    0.00000
            KL (full):     0.08580 +-    0.00000
            KL (diag):     0.05125 +-    0.00000
        Interpolation beyond training range:
            Loglik (V):   -1.47219 +-    0.00000
            KL (full):     0.08580 +-    0.00000
            KL (diag):     0.05125 +-    0.00000
        Extrapolation beyond training range:
            Loglik (V):   -1.47550 +-    0.00000
            KL (full):     0.06804 +-    0.

# dim_x = 2

In [None]:
# Evaluate every model on the dim_x = 2 data, one checkpoint per seed.
model_list = ["cnp", "rcnp", "gnp", "rgnp"]
num_replicates = 5
args["dim_x"] = 2
gens_eval = get_gen_eval(args, config)

for model_name in model_list:
    with out.Section(f"Evaluate on {model_name}"):
        # Axes of `results`: (dataset, metric, replicate) -- 3 datasets, 3 metrics.
        results = np.zeros((3, 3, num_replicates))
        for seed in range(1, num_replicates + 1):
            with out.Section(f"Seed = {seed}"):
                state = B.create_random_state(torch.float32, seed=seed)

                # Load the weights stored in this model/seed working directory.
                wd = get_dir(
                    args, suffix, observe, seed, model_name, args.dim_x, args.dim_y
                )
                model = construct_model(model_name).to(device)
                name = "model-best.torch"
                model.load_state_dict(
                    torch.load(wd.file(name), map_location=device)["weights"]
                )

                for i, (gen_name, gen) in enumerate(gens_eval()):
                    with out.Section(gen_name.capitalize()):
                        state, loglik, kl_full, kl_diag = eval(
                            state, model, objective_eval, gen
                        )
                        results[i, :, seed - 1] = (loglik, kl_full, kl_diag)

        with out.Section(f"Statistics over {num_replicates} replicates"):
            for i, (gen_name, _) in enumerate(gens_eval()):
                with out.Section(gen_name.capitalize()):
                    for j, label in enumerate(("Loglik (V)", "KL (full)", "KL (diag)")):
                        out.kv(label, exp.with_err(results[i, j]))

# dim_x = 3

In [22]:
# Evaluate every model on the dim_x = 3 data, one checkpoint per seed.
model_list = ["cnp", "rcnp", "gnp", "rgnp"]
num_replicates = 5
args["dim_x"] = 3
gens_eval = get_gen_eval(args, config)

for model_name in model_list:
    with out.Section(f"Evaluate on {model_name}"):
        # Axes of `results`: (dataset, metric, replicate) -- 3 datasets, 3 metrics.
        results = np.zeros((3, 3, num_replicates))
        for seed in range(1, num_replicates + 1):
            with out.Section(f"Seed = {seed}"):
                state = B.create_random_state(torch.float32, seed=seed)

                # Load the weights stored in this model/seed working directory.
                wd = get_dir(
                    args, suffix, observe, seed, model_name, args.dim_x, args.dim_y
                )
                model = construct_model(model_name).to(device)
                name = "model-best.torch"
                model.load_state_dict(
                    torch.load(wd.file(name), map_location=device)["weights"]
                )

                for i, (gen_name, gen) in enumerate(gens_eval()):
                    with out.Section(gen_name.capitalize()):
                        state, loglik, kl_full, kl_diag = eval(
                            state, model, objective_eval, gen
                        )
                        results[i, :, seed - 1] = (loglik, kl_full, kl_diag)

        with out.Section(f"Statistics over {num_replicates} replicates"):
            for i, (gen_name, _) in enumerate(gens_eval()):
                with out.Section(gen_name.capitalize()):
                    for j, label in enumerate(("Loglik (V)", "KL (full)", "KL (diag)")):
                        out.kv(label, exp.with_err(results[i, j]))

Evaluate on cnp:
    Seed = 1:
        Interpolation in training range:
            Loglik (V): -1.387
            KL (full):  0.2612
            KL (diag):  0.07966
        Interpolation beyond training range:
            Loglik (V): -1.496
            KL (full):  0.3706
            KL (diag):  0.1891
        Extrapolation beyond training range:
            Loglik (V): -1.502
            KL (full):  0.2696
            KL (diag):  0.05809
    Seed = 2:
        Interpolation in training range:
            Loglik (V): -1.386
            KL (full):  0.2598
            KL (diag):  0.07831
        Interpolation beyond training range:
            Loglik (V): -1.47
            KL (full):  0.3439
            KL (diag):  0.1625
        Extrapolation beyond training range:
            Loglik (V): -1.473
            KL (full):  0.2407
            KL (diag):  0.02915
    Seed = 3:
        Interpolation in training range:
            Loglik (V): -1.387
            KL (full):  0.2611
            KL 

            Loglik (V): -1.279
            KL (full):  0.1532
            KL (diag):  -0.02828
        Interpolation beyond training range:
            Loglik (V): -1.279
            KL (full):  0.1532
            KL (diag):  -0.02828
        Extrapolation beyond training range:
            Loglik (V): -1.444
            KL (full):  0.2124
            KL (diag):  8.616e-04
    Seed = 2:
        Interpolation in training range:
            Loglik (V): -1.285
            KL (full):  0.1591
            KL (diag):  -0.02244
        Interpolation beyond training range:
            Loglik (V): -1.285
            KL (full):  0.1591
            KL (diag):  -0.02244
        Extrapolation beyond training range:
            Loglik (V): -1.442
            KL (full):  0.2103
            KL (diag):  -1.306e-03
    Seed = 3:
        Interpolation in training range:
            Loglik (V): -1.289
            KL (full):  0.1631
            KL (diag):  -0.01842
        Interpolation beyond training rang

# dim_x = 4

In [36]:
# Evaluate every model on the dim_x = 4 data, one checkpoint per seed.
model_list = ["cnp", "rcnp", "gnp", "rgnp"]
num_replicates = 5
args["dim_x"] = 4
gens_eval = get_gen_eval(args, config)

for model_name in model_list:
    with out.Section(f"Evaluate on {model_name}"):
        # Axes of `results`: (dataset, metric, replicate) -- 3 datasets, 3 metrics.
        results = np.zeros((3, 3, num_replicates))
        for seed in range(1, num_replicates + 1):
            with out.Section(f"Seed = {seed}"):
                state = B.create_random_state(torch.float32, seed=seed)

                # Load the weights stored in this model/seed working directory.
                wd = get_dir(
                    args, suffix, observe, seed, model_name, args.dim_x, args.dim_y
                )
                model = construct_model(model_name).to(device)
                name = "model-best.torch"
                model.load_state_dict(
                    torch.load(wd.file(name), map_location=device)["weights"]
                )

                for i, (gen_name, gen) in enumerate(gens_eval()):
                    with out.Section(gen_name.capitalize()):
                        state, loglik, kl_full, kl_diag = eval(
                            state, model, objective_eval, gen
                        )
                        results[i, :, seed - 1] = (loglik, kl_full, kl_diag)

        with out.Section(f"Statistics over {num_replicates} replicates"):
            for i, (gen_name, _) in enumerate(gens_eval()):
                with out.Section(gen_name.capitalize()):
                    for j, label in enumerate(("Loglik (V)", "KL (full)", "KL (diag)")):
                        out.kv(label, exp.with_err(results[i, j]))

Evaluate on cnp:
    Seed = 1:
        Interpolation in training range:
            Loglik (V): -1.443
            KL (full):  0.1295
            KL (diag):  0.05153
        Interpolation beyond training range:
            Loglik (V): -1.445
            KL (full):  0.131
            KL (diag):  0.05304
        Extrapolation beyond training range:
            Loglik (V): -1.445
            KL (full):  0.08619
            KL (diag):  9.860e-04
    Seed = 2:
        Interpolation in training range:
            Loglik (V): -1.443
            KL (full):  0.1295
            KL (diag):  0.05153
        Interpolation beyond training range:
            Loglik (V): -1.446
            KL (full):  0.1318
            KL (diag):  0.05387
        Extrapolation beyond training range:
            Loglik (V): -1.445
            KL (full):  0.08607
            KL (diag):  8.623e-04
    Seed = 3:
        Interpolation in training range:
            Loglik (V): -1.443
            KL (full):  0.1295
       

            Loglik (V): -1.392
            KL (full):  0.07832
            KL (diag):  3.455e-04
        Interpolation beyond training range:
            Loglik (V): -1.392
            KL (full):  0.07832
            KL (diag):  3.455e-04
        Extrapolation beyond training range:
            Loglik (V): -1.443
            KL (full):  0.08361
            KL (diag):  -1.594e-03
    Seed = 2:
        Interpolation in training range:
            Loglik (V): -1.393
            KL (full):  0.07865
            KL (diag):  6.749e-04
        Interpolation beyond training range:
            Loglik (V): -1.393
            KL (full):  0.07865
            KL (diag):  6.749e-04
        Extrapolation beyond training range:
            Loglik (V): -1.442
            KL (full):  0.08283
            KL (diag):  -2.377e-03
    Seed = 3:
        Interpolation in training range:
            Loglik (V): -1.393
            KL (full):  0.07868
            KL (diag):  7.060e-04
        Interpolation beyond 

# dim_x = 5

In [37]:
# Evaluate every model on the dim_x = 5 data, one checkpoint per seed.
model_list = ["cnp", "rcnp", "gnp", "rgnp"]
num_replicates = 5
args["dim_x"] = 5
gens_eval = get_gen_eval(args, config)

for model_name in model_list:
    with out.Section(f"Evaluate on {model_name}"):
        # Axes of `results`: (dataset, metric, replicate) -- 3 datasets, 3 metrics.
        results = np.zeros((3, 3, num_replicates))
        for seed in range(1, num_replicates + 1):
            with out.Section(f"Seed = {seed}"):
                state = B.create_random_state(torch.float32, seed=seed)

                # Load the weights stored in this model/seed working directory.
                wd = get_dir(
                    args, suffix, observe, seed, model_name, args.dim_x, args.dim_y
                )
                model = construct_model(model_name).to(device)
                name = "model-best.torch"
                model.load_state_dict(
                    torch.load(wd.file(name), map_location=device)["weights"]
                )

                for i, (gen_name, gen) in enumerate(gens_eval()):
                    with out.Section(gen_name.capitalize()):
                        state, loglik, kl_full, kl_diag = eval(
                            state, model, objective_eval, gen
                        )
                        results[i, :, seed - 1] = (loglik, kl_full, kl_diag)

        with out.Section(f"Statistics over {num_replicates} replicates"):
            for i, (gen_name, _) in enumerate(gens_eval()):
                with out.Section(gen_name.capitalize()):
                    for j, label in enumerate(("Loglik (V)", "KL (full)", "KL (diag)")):
                        out.kv(label, exp.with_err(results[i, j]))

Evaluate on cnp:
    Seed = 1:
        Interpolation in training range:
            Loglik (V): -1.444
            KL (full):  0.05697
            KL (diag):  0.02219
        Interpolation beyond training range:
            Loglik (V): -1.445
            KL (full):  0.05711
            KL (diag):  0.02233
        Extrapolation beyond training range:
            Loglik (V): -1.445
            KL (full):  0.03722
            KL (diag):  3.132e-04
    Seed = 2:
        Interpolation in training range:
            Loglik (V): -1.444
            KL (full):  0.05696
            KL (diag):  0.02218
        Interpolation beyond training range:
            Loglik (V): -1.445
            KL (full):  0.05721
            KL (diag):  0.02243
        Extrapolation beyond training range:
            Loglik (V): -1.445
            KL (full):  0.03713
            KL (diag):  2.242e-04
    Seed = 3:
        Interpolation in training range:
            Loglik (V): -1.444
            KL (full):  0.05697
 

            Loglik (V): -1.423
            KL (full):  0.03523
            KL (diag):  4.432e-04
        Interpolation beyond training range:
            Loglik (V): -1.423
            KL (full):  0.03523
            KL (diag):  4.432e-04
        Extrapolation beyond training range:
            Loglik (V): -1.444
            KL (full):  0.03657
            KL (diag):  -3.356e-04
    Seed = 2:
        Interpolation in training range:
            Loglik (V): -1.423
            KL (full):  0.03527
            KL (diag):  4.918e-04
        Interpolation beyond training range:
            Loglik (V): -1.423
            KL (full):  0.03527
            KL (diag):  4.918e-04
        Extrapolation beyond training range:
            Loglik (V): -1.444
            KL (full):  0.03657
            KL (diag):  -3.425e-04
    Seed = 3:
        Interpolation in training range:
            Loglik (V): -1.423
            KL (full):  0.03522
            KL (diag):  4.388e-04
        Interpolation beyond 