In [4]:
import os
import sys
import time
import warnings
from functools import partial

import experiment as exp
import lab as B
import neuralprocesses.torch as nps
import numpy as np
import torch
import wbml.out as out
from matrix.util import ToDenseWarning
from wbml.experiment import WorkingDirectory

device = "cpu"

# One seed for every source of randomness so that Restart & Run All reproduces
# the same run. `state` only drives randomness that is passed through
# `objective`; `nps.construct_rnp` takes no random state, so its parameter
# initialisation presumably uses torch's global RNG, which is seeded here too.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
state = B.create_random_state(torch.float32, seed=SEED)

In [12]:
# Command-line style options for this run (normally produced by argparse).
args = {
    "dim_x": 1,
    "dim_y": 1,
    "data": "eq",
    "batch_size": 16,
    "epochs": 100,
    "rate": 3e-4,
    "objective": "loglik",
    "num_samples": 20,
    "unnormalised": False,
    "evaluate_num_samples": 512,
    "evaluate_batch_size": 8,
    "train_fast": False,
    "evaluate_fast": True,
}


class mydict(dict):
    """A ``dict`` whose keys can also be read as attributes (``args.rate``).

    Lets the experiment helpers use ``args.<name>`` as they would with an
    ``argparse.Namespace``. Only attributes that normal lookup cannot find are
    resolved from the dict, so dict methods (``keys``, ``get``, ...) are never
    shadowed by an option of the same name.
    """

    def __getattr__(self, key):
        # Invoked only after regular attribute lookup has failed.
        try:
            return self[key]
        except KeyError:
            # AttributeError (not KeyError) keeps hasattr/getattr(default),
            # copy and pickle working as for a normal object.
            raise AttributeError(key) from None


args = mydict(args)

# Model/training hyperparameters.
config = {
    "default": {
        "epochs": None,
        "rate": None,
        "also_ar": False,
    },
    "epsilon": 1e-8,
    "epsilon_start": 1e-2,
    "cholesky_retry_factor": 1e6,
    "fix_noise": None,
    "fix_noise_epochs": 3,
    "width": 256,
    "dim_embedding": 256,
    "relational_width": 64,
    "dim_relational_embeddings": 128,
    "enc_same": False,
    "num_heads": 8,
    "num_layers": 6,
    "num_relational_layers": 3,
    "unet_channels": (64,) * 6,
    "unet_strides": (1,) + (2,) * 5,
    "conv_channels": 64,
    "encoder_scales": None,
    "fullconvgnp_kernel_factor": 2,
    "mean_diff": 0,
    # Performance of the ConvGNP is sensitive to this parameter. Moreover, it
    # doesn't make sense to set it to a value higher than the width of the last
    # hidden layer of the CNN architecture. We therefore set it to 64.
    "num_basis_functions": 64,
    # Keys read by the model-construction cell. Defined here so that cell does
    # not depend on side effects of the data setup call.
    # NOTE(review): the data setup presumably overwrites these — confirm.
    "dim_x": args.dim_x,
    "dim_y": args.dim_y,
    "transform": None,
}

In [22]:
# Task counts are kept deliberately small for quick interactive runs. The
# full-size settings (2**14 train, 2**12 CV/eval, selected through
# `args.train_fast` / `args.evaluate_fast`) are currently bypassed.
NUM_TASKS_QUICK = 2**6

# NOTE(review): later cells read config["dim_y"] / config["transform"];
# presumably this setup call fills them in — confirm.
setup_data = exp.data[args.data]["setup"]
gen_train, gen_cv, gens_eval = setup_data(
    args,
    config,
    num_tasks_train=NUM_TASKS_QUICK,
    num_tasks_cv=NUM_TASKS_QUICK,
    num_tasks_eval=NUM_TASKS_QUICK,
    device=device,
)

In [15]:
# Global regularisation constant used by `lab` for this run.
B.epsilon = config["epsilon"]

# Constructor arguments shared by the GNP and the RNP.
shared_model_kwargs = dict(
    dim_x=config["dim_x"],
    dim_yc=(1,) * config["dim_y"],
    dim_yt=config["dim_y"],
    dim_embedding=config["dim_embedding"],
    enc_same=config["enc_same"],
    num_dec_layers=config["num_layers"],
    width=config["width"],
    likelihood="het",
    transform=config["transform"],
)

# Relational NP = shared arguments + relational-encoder settings.
# For the plain GNP baseline, use `nps.construct_gnp(**shared_model_kwargs)`.
model = nps.construct_rnp(
    relational_width=config["relational_width"],
    num_relational_enc_layers=config["num_relational_layers"],
    **shared_model_kwargs,
)

In [16]:
# Every objective below is `nps.loglik`; they differ only in sample count and
# evaluation batch size.
loglik_common = {"normalise": not args.unnormalised}

# Training and cross-validation use identical settings.
objective = partial(nps.loglik, num_samples=args.num_samples, **loglik_common)
objective_cv = partial(nps.loglik, num_samples=args.num_samples, **loglik_common)

# Final evaluation: more samples, processed in smaller batches.
objectives_eval = [
    (
        "Loglik",
        partial(
            nps.loglik,
            num_samples=args.evaluate_num_samples,
            batch_size=args.evaluate_batch_size,
            **loglik_common,
        ),
    )
]

In [17]:
def train(state, model, opt, objective, gen, *, fix_noise):
    """Run one optimisation pass over the batches of ``gen.epoch()``.

    Args:
        state: random state threaded through ``objective``.
        model: model whose parameters ``opt`` updates.
        opt: torch optimiser.
        objective: callable returning ``(state, loglik)`` for a batch.
        gen: data generator exposing ``epoch()``.
        fix_noise: forwarded unchanged to ``objective``.

    Returns:
        The updated state and the mean training log-likelihood minus 1.96
        standard errors.
    """
    batch_logliks = []
    for batch in gen.epoch():
        state, obj = objective(
            state,
            model,
            batch["contexts"],
            batch["xt"],
            batch["yt"],
            fix_noise=fix_noise,
        )
        batch_logliks.append(B.to_numpy(obj))

        # `objective` is a log-likelihood (higher is better): minimise its negation.
        loss = -B.mean(obj)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

    logliks = B.concat(*batch_logliks)
    out.kv("Loglik (T)", exp.with_err(logliks, and_lower=True))
    lower_bound = B.mean(logliks) - 1.96 * B.std(logliks) / B.sqrt(len(logliks))
    return state, lower_bound

In [18]:
def eval(state, model, objective, gen):
    """Evaluate ``model`` on every batch of ``gen.epoch()`` without gradients.

    Logs the log-likelihood ("Loglik (V)"). When batches carry reference
    log-densities ``pred_logpdf`` / ``pred_logpdf_diag``, it also logs the gap
    ``reference / n - obj`` as "KL (full)" / "KL (diag)".

    Returns:
        The updated state and the mean log-likelihood minus 1.96 standard
        errors.
    """
    # (batch key, label) pairs, reported in this order.
    references = (("pred_logpdf", "KL (full)"), ("pred_logpdf_diag", "KL (diag)"))

    with torch.no_grad():
        batch_logliks = []
        gaps = {key: [] for key, _ in references}
        for batch in gen.epoch():
            state, obj = objective(
                state,
                model,
                batch["contexts"],
                batch["xt"],
                batch["yt"],
            )

            # Scale the reference log-density by the number of target data points.
            n = nps.num_data(batch["xt"], batch["yt"])
            batch_logliks.append(B.to_numpy(obj))
            for key, _ in references:
                if key in batch:
                    gaps[key].append(B.to_numpy(batch[key] / n - obj))

        logliks = B.concat(*batch_logliks)
        out.kv("Loglik (V)", exp.with_err(logliks, and_lower=True))
        for key, label in references:
            if gaps[key]:
                out.kv(label, exp.with_err(B.concat(*gaps[key]), and_upper=True))

        return state, B.mean(logliks) - 1.96 * B.std(logliks) / B.sqrt(len(logliks))

In [23]:
start = 0

# Best CV score so far, and an in-memory copy of the matching weights. There is
# no working directory configured here, so nothing is written to disk.
best_eval_lik = -np.inf
best_weights = None

# Setup training loop.
opt = torch.optim.Adam(model.parameters(), args.rate)

# Set regularisation high for the first epoch.
original_epsilon = B.epsilon
B.epsilon = config["epsilon_start"]

try:
    for i in range(start, args.epochs):
        with out.Section(f"Epoch {i + 1}"):
            # Set regularisation to normal after the first epoch.
            if i > 0:
                B.epsilon = original_epsilon

            # Optionally fix the noise for the first `fix_noise_epochs` epochs.
            if config["fix_noise"] and i < config["fix_noise_epochs"]:
                fix_noise = 1e-4
            else:
                fix_noise = None
            state, _ = train(
                state,
                model,
                opt,
                objective,
                gen_train,
                fix_noise=fix_noise,
            )

            # The epoch is done. Now evaluate; `gen_cv` is called every epoch
            # to obtain the CV generator.
            state, val = eval(state, model, objective_cv, gen_cv())

            # Keep the weights with the best CV score (the lower confidence
            # bound returned by `eval`).
            if val > best_eval_lik:
                out.out("New best model!")
                best_eval_lik = val
                best_weights = {
                    name: tensor.detach().clone()
                    for name, tensor in model.state_dict().items()
                }
finally:
    # Restore normal regularisation even if training is interrupted (e.g.
    # KeyboardInterrupt) or fewer than two epochs run. Otherwise the inflated
    # `epsilon_start` would leak into every later cell.
    B.epsilon = original_epsilon

# To restore the best model: `model.load_state_dict(best_weights)`.

Epoch 1:
    Loglik (T):   -1.36545 +-    0.05478 (  -1.42023)
    Loglik (V):   -1.38369 +-    0.04967 (  -1.43336)
    KL (full):     1.05762 +-    0.04882 (   1.10644)
    KL (diag):     0.63488 +-    0.06960 (   0.70448)
Epoch 2:
    Loglik (T):   -1.43593 +-    0.06178 (  -1.49771)
    Loglik (V):   -1.38485 +-    0.04948 (  -1.43433)
    KL (full):     1.05878 +-    0.04881 (   1.10759)
    KL (diag):     0.63604 +-    0.06946 (   0.70550)
Epoch 3:
    Loglik (T):   -1.38601 +-    0.05773 (  -1.44375)
    Loglik (V):   -1.39722 +-    0.05429 (  -1.45152)
    KL (full):     1.07115 +-    0.05304 (   1.12419)
    KL (diag):     0.64841 +-    0.07175 (   0.72017)
Epoch 4:
    Loglik (T):   -1.38060 +-    0.05902 (  -1.43962)
    Loglik (V):   -1.39947 +-    0.05385 (  -1.45333)
    KL (full):     1.07340 +-    0.05272 (   1.12612)
    KL (diag):     0.65066 +-    0.07150 (   0.72217)
Epoch 5:
    Loglik (T):   -1.37597 +-    0.05067 (  -1.42663)
    Loglik (V):   -1.39989 +-    0.05

    Loglik (V):   -1.38331 +-    0.05267 (  -1.43598)
    KL (full):     1.05724 +-    0.05211 (   1.10935)
    KL (diag):     0.63450 +-    0.07062 (   0.70512)
Epoch 39:
    Loglik (T):   -1.34827 +-    0.05595 (  -1.40422)
    Loglik (V):   -1.37987 +-    0.04791 (  -1.42779)
    KL (full):     1.05380 +-    0.04793 (   1.10173)
    KL (diag):     0.63106 +-    0.06889 (   0.69995)
Epoch 40:


[E thread_pool.cpp:109] Exception in thread pool task: mutex lock failed: Invalid argument
[E thread_pool.cpp:109] Exception in thread pool task: mutex lock failed: Invalid argument


KeyboardInterrupt: 