# Tutorial 4 : Attentive Neural Processes Variants - 1D GP Data

Last Update : 12 June 2019

**Aim**: 
- Investigating variants of [Attentive Neural Process] (ANP)
- Showing how to use the library to build more complicated models

**Nota Bene:**
- Most of the work below does not follow any published paper.
- Many more details about the framework and dataset can be found in [Tutorial 1 - Conditional Neural Process].

[Attentive Neural Process]: https://arxiv.org/abs/1901.05761
[Tutorial 1 - Conditional Neural Process]: Tutorial%201%20-%20Conditional%20Neural%20Process.ipynb

**Environment Hyperparameters:**

In [1]:
N_THREADS = 8
# Nota Bene : notebooks don't deallocate GPU memory
IS_FORCE_CPU = False # can also be set in the trainer

## Environment

In [2]:
cd ..

/master


In [3]:
%autosave 600
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# CENTER PLOTS
from IPython.display import HTML, display
display(HTML(""" <style> .output_png {display: table-cell; text-align: center; margin:auto; }
.prompt {display:none;}  </style>"""))

import os
if IS_FORCE_CPU:
    os.environ['CUDA_VISIBLE_DEVICES'] = ""

import sys
sys.path.append("notebooks")

import numpy as np
import matplotlib.pyplot as plt
import torch
torch.set_num_threads(N_THREADS)

Autosaving every 600 seconds


## Dataset

In this notebook I only look at the periodic and Matérn kernels, as these seem to be the hardest to learn. See [Tutorial 1 - Conditional Neural Process] for more details.

[Tutorial 1 - Conditional Neural Process]: Tutorial%201%20-%20Conditional%20Neural%20Process.ipynb

In [4]:
from sklearn.gaussian_process.kernels import RBF, Matern, ExpSineSquared, DotProduct, ConstantKernel
from ntbks_viz import plot_posterior_samples, plot_prior_samples, plot_dataset_samples

from ntbks_datasets import GPDataset

In [5]:
X_DIM = 1  # 1D spatial input
Y_DIM = 1  # 1D regression
N_POINTS = 128
N_SAMPLES = 100000 # this is a lot; fewer samples should also work

In [6]:
datasets = dict()
kwargs = dict(n_samples=N_SAMPLES, n_points=N_POINTS)
datasets["matern"] = GPDataset(kernel=1.0 * Matern(length_scale=1.0,
                                                   length_scale_bounds=(1e-1, 10.0),
                                                   nu=1.5),
                               **kwargs)
datasets["periodic"] = GPDataset(kernel=1.0 * ExpSineSquared(length_scale=1.0,
                                                             periodicity=3.0,
                                                             length_scale_bounds=(0.1, 10.0),
                                                             periodicity_bounds=(1.0, 10.0)),
                                 **kwargs)
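`GPDataset` comes from the local `ntbks_datasets` module and its internals are not shown here. As a rough mental model, the sketch below draws functions from a zero-mean GP prior by building the kernel matrix and multiplying its Cholesky factor by standard normal noise. It is written in plain NumPy (rather than calling the sklearn kernel objects) so that it runs standalone; the two kernel formulas match sklearn's `Matern(nu=1.5)` and `ExpSineSquared`. The helper name `sample_gp_prior`, the input range `(-2, 2)`, the random (sorted) inputs and the `1e-6` jitter are all assumptions, not necessarily what `GPDataset` does.

```python
import numpy as np

def matern32(X1, X2, length_scale=1.0):
    """Matern kernel with nu=1.5 (same formula as sklearn's Matern(nu=1.5))."""
    d = np.abs(X1[:, None, 0] - X2[None, :, 0])  # pairwise distances of 1D inputs
    r = np.sqrt(3.0) * d / length_scale
    return (1.0 + r) * np.exp(-r)

def exp_sine_squared(X1, X2, length_scale=1.0, periodicity=3.0):
    """Periodic kernel (same formula as sklearn's ExpSineSquared)."""
    d = np.abs(X1[:, None, 0] - X2[None, :, 0])
    return np.exp(-2.0 * np.sin(np.pi * d / periodicity) ** 2 / length_scale ** 2)

def sample_gp_prior(kernel, n_functions=3, n_points=128, x_range=(-2.0, 2.0), seed=0):
    """Hypothetical helper: draw functions f ~ GP(0, kernel) on sorted random inputs."""
    rng = np.random.default_rng(seed)
    X = np.sort(rng.uniform(*x_range, size=(n_points, 1)), axis=0)
    K = kernel(X, X) + 1e-6 * np.eye(n_points)  # jitter keeps the Cholesky stable
    L = np.linalg.cholesky(K)
    # if z ~ N(0, I) then L @ z ~ N(0, L L^T) = N(0, K)
    Y = (L @ rng.standard_normal((n_points, n_functions))).T
    return X, Y  # X: (n_points, 1), Y: (n_functions, n_points)

X, Y_matern = sample_gp_prior(matern32)
_, Y_periodic = sample_gp_prior(exp_sine_squared)
```

Plotting a few rows of `Y_matern` and `Y_periodic` against `X` gives a quick feel for why these two priors are harder to fit than a smooth RBF: Matérn-3/2 samples are rough, and periodic samples require extrapolating structure far from the context points.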

## Model

The general model architecture differs slightly from the paper to make it modular and easy to extend, but the two can be made equivalent with the right parameters. Refer to [Tutorial 1 - Conditional Neural Process] for an overview of the main parameters, or to the docstrings of `NeuralProcess` for all parameters.

[Tutorial 1 - Conditional Neural Process]: Tutorial%201%20-%20Conditional%20Neural%20Process.ipynb

In [7]:
from torch.distributions import Normal

from skssl.transformers import (NeuralProcessLoss, NeuralProcess, AttentiveNeuralProcess, 
                                SelfAttentionBlock, SinusoidalEncodings)
from skssl.predefined import MLP, merge_flat_input, get_uninitialized_mlp
from skssl.transformers.neuralproc.datasplit import context_target_split
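`SinusoidalEncodings`, imported above, is used by some variants below as the `XEncoder`. Its exact frequencies are not shown in this notebook; the sketch below assumes the Transformer-style scheme, where a scalar input is mapped to sines and cosines at geometrically spaced frequencies. The function name `sinusoidal_encoding` and the `dim`/`max_period` parameters are hypothetical and the skssl implementation may differ.

```python
import numpy as np

def sinusoidal_encoding(x, dim=16, max_period=10000.0):
    """Hypothetical sketch: map scalar inputs x to [sin(w_i x), cos(w_i x)] features."""
    if dim % 2 != 0:
        raise ValueError("dim must be even (one sin and one cos per frequency)")
    x = np.asarray(x, dtype=float).reshape(-1, 1)   # (n, 1)
    i = np.arange(dim // 2)                         # frequency index
    freqs = 1.0 / max_period ** (2.0 * i / dim)     # geometric frequencies, (dim/2,)
    angles = x * freqs                              # (n, dim/2) via broadcasting
    enc = np.empty((x.shape[0], dim))
    enc[:, 0::2] = np.sin(angles)                   # even columns: sine
    enc[:, 1::2] = np.cos(angles)                   # odd columns: cosine
    return enc

enc = sinusoidal_encoding(np.linspace(-2, 2, 5))
print(enc.shape)  # (5, 16)
```

The appeal of such features for 1D GP data is that dot products between encodings depend mostly on the distance between inputs, which suits the stationary kernels used here.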

In [8]:
R_DIM = 128
RANGE_CNTXT = (4, 50)  # context points will be sampled uniformly in this range
RANGE_EXTRA_TRGTS = (3, N_POINTS-RANGE_CNTXT[1])  # range for the number of extra target points
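The two ranges above are passed to `context_target_split`. A minimal sketch of what such a split might do for a single function, assuming both ranges are inclusive and that targets are the context points plus some extra points (the usual neural-process training convention), is shown below. `split_context_target` is a hypothetical name and this is not the library's implementation. Note that with these assumptions the largest possible target set is `50 + (128 - 50) = 128` points, which is why the upper bound of `RANGE_EXTRA_TRGTS` is `N_POINTS - RANGE_CNTXT[1]`.

```python
import numpy as np

def split_context_target(X, Y, range_cntxts=(4, 50), range_extra_trgts=(3, 78), rng=None):
    """Hypothetical sketch of a context/target split for one function.

    Assumes inclusive ranges and that the targets contain every context point
    plus n_extra additional points.
    """
    rng = np.random.default_rng() if rng is None else rng
    n_cntxt = int(rng.integers(range_cntxts[0], range_cntxts[1] + 1))
    n_extra = int(rng.integers(range_extra_trgts[0], range_extra_trgts[1] + 1))
    idx = rng.permutation(X.shape[0])      # shuffle the point indices
    cntxt_idx = idx[:n_cntxt]
    trgt_idx = idx[:n_cntxt + n_extra]     # superset of the context indices
    return X[cntxt_idx], Y[cntxt_idx], X[trgt_idx], Y[trgt_idx]

X = np.linspace(-2, 2, 128).reshape(-1, 1)
Y = np.sin(3 * X)
X_c, Y_c, X_t, Y_t = split_context_target(X, Y, rng=np.random.default_rng(0))
```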

In [9]:
def get_cntxt_trgt(*args):
    return context_target_split(*args, range_cntxts=RANGE_CNTXT, range_extra_trgts=RANGE_EXTRA_TRGTS)

def get_attn_bloc(attention):
    return lambda *args: SelfAttentionBlock(*args, attention=attention)

def init_models():
    return {"Conditional_ANP":AttentiveNeuralProcess(X_DIM, Y_DIM,
                                  r_dim=R_DIM,
                                  get_cntxt_trgt=get_cntxt_trgt,
                                  encoded_path="deterministic"),
     
     "Multihead_ANP":AttentiveNeuralProcess(X_DIM, Y_DIM,
                                  r_dim=R_DIM,
                                  get_cntxt_trgt=get_cntxt_trgt,
                                  attention="multihead", 
                                  encoded_path="both") ,
     
     "Sinusoidal_ANP":AttentiveNeuralProcess(X_DIM, Y_DIM,
                                  XEncoder=SinusoidalEncodings, # sinusoidal encodings
                                  r_dim=R_DIM,
                                  get_cntxt_trgt=get_cntxt_trgt,
                                  encoded_path="both") ,
     
     "SelfAttn_Enc_ANP":AttentiveNeuralProcess(X_DIM, Y_DIM,
                                  XYEncoder=SelfAttentionBlock,  # self attention encoder 
                                  r_dim=R_DIM,
                                  get_cntxt_trgt=get_cntxt_trgt,
                                  encoded_path="both") ,
     
     "SelfAttn_Dec_ANP":AttentiveNeuralProcess(X_DIM, Y_DIM,
                                  Decoder=SelfAttentionBlock,  # self attention decoder 
                                  r_dim=R_DIM,
                                  get_cntxt_trgt=get_cntxt_trgt,
                                  encoded_path="both"),
     
     "r128_Conditional_Transformer_Process":AttentiveNeuralProcess(X_DIM, Y_DIM,
                                  XEncoder=SinusoidalEncodings, # sinusoidal encodings
                                  XYEncoder=get_attn_bloc("transformer"),  # self attention encoder 
                                  Decoder=get_attn_bloc("transformer"),  # self attention decoder 
                                  r_dim=128,
                                  get_cntxt_trgt=get_cntxt_trgt,
                                  attention="transformer", 
                                  encoded_path="deterministic"),
     
     "r64_Conditional_Transformer_Process":AttentiveNeuralProcess(X_DIM, Y_DIM,
                                  XEncoder=SinusoidalEncodings, # sinusoidal encodings
                                  XYEncoder=get_attn_bloc("transformer"),  # self attention encoder 
                                  Decoder=get_attn_bloc("transformer"),  # self attention decoder 
                                  r_dim=64,
                                  get_cntxt_trgt=get_cntxt_trgt,
                                  attention="transformer", 
                                  encoded_path="deterministic")}

# initialize all models for each dataset
data_models = {name: (init_models(), data) 
               for name, data in datasets.items()}
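The common ingredient of these variants is the cross-attention in the deterministic path: instead of averaging the context representations, each target input acts as a query over the encoded context points. The sketch below is a single-head scaled dot-product version in NumPy, with no learned projections. The `"multihead"` and `"transformer"` options likely add learned projections, several heads and (for the transformer) extra residual/feedforward layers, but that is an assumption; skssl's exact implementation is not shown here.

```python
import numpy as np

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(q, k, v):
    """Scaled dot-product attention of target queries over the context.

    q: (n_trgt, d)      encoded target inputs (queries)
    k: (n_cntxt, d)     encoded context inputs (keys)
    v: (n_cntxt, r_dim) context representations (values)
    returns: (n_trgt, r_dim), one target-specific representation per query
    """
    scores = q @ k.T / np.sqrt(q.shape[-1])  # (n_trgt, n_cntxt) similarities
    weights = softmax(scores, axis=-1)       # each row sums to 1
    return weights @ v

rng = np.random.default_rng(0)
q, k, v = rng.normal(size=(7, 8)), rng.normal(size=(5, 8)), rng.normal(size=(5, 16))
r = cross_attention(q, k, v)
```

If all keys are identical, every target receives the plain mean of the context values, i.e. the aggregation of a standard (non-attentive) neural process; attention generalizes this by letting each target weight nearby context points more heavily.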



### N Param

Number of parameters (note that I did not tune these much; the count depends heavily on the representation size `r_dim`):

In [10]:
from utils.helpers import count_parameters

In [11]:
for k, (models, dataset) in data_models.items():
    for name, model in models.items():
        print("N Param for {}:".format(name), count_parameters(model))
    break  # the models are identical for every dataset, so only print the first

N Param for Conditional_ANP: 98498
N Param for Multihead_ANP: 246594
N Param for Sinusoidal_ANP: 176642
N Param for SelfAttn_Enc_ANP: 296770
N Param for SelfAttn_Dec_ANP: 306178
N Param for r128_Conditional_Transformer_Process: 417986
N Param for r64_Conditional_Transformer_Process: 106626
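The strong dependence on `r_dim` follows from simple arithmetic: a fully connected `r_dim -> r_dim` layer has `r_dim**2 + r_dim` parameters, so halving `r_dim` divides such layers by roughly four. This is consistent with the two transformer variants above, whose counts differ by a factor of about 3.9 (417986 vs 106626). `count_parameters` presumably sums `p.numel()` over the model's trainable `parameters()`, the usual PyTorch idiom, but that is an assumption. The helper `linear_params` below is hypothetical.

```python
def linear_params(n_in, n_out, bias=True):
    """Number of trainable parameters in a fully connected layer."""
    return n_in * n_out + (n_out if bias else 0)

for r_dim in (128, 64):
    print(r_dim, linear_params(r_dim, r_dim))
# 128 16512
# 64 4160
```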


## Training

In [12]:
from skorch.callbacks import ProgressBar, Checkpoint

from skssl.training import NeuralNetTransformer
from skssl.training.helpers import make_Xy_input

In [13]:
N_EPOCHS = 50
BATCH_SIZE = 64
is_RETRAIN = False  # if False, load the precomputed checkpoints instead of training
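With these settings, each epoch goes over all `N_SAMPLES` functions in batches of `BATCH_SIZE`, which explains the `max=1563` iterations shown by the progress bars in the logs. The count assumes the final partial batch is kept, which is the default for a PyTorch `DataLoader` (`drop_last=False`).

```python
import math

N_SAMPLES, BATCH_SIZE = 100000, 64
n_batches = math.ceil(N_SAMPLES / BATCH_SIZE)          # the final partial batch still counts
last_batch = N_SAMPLES - (n_batches - 1) * BATCH_SIZE  # size of that final batch
print(n_batches, last_batch)  # 1563 32
```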

In [None]:
for k,(models, dataset) in data_models.items():
    for name, neural_proc in models.items():
        
        print()
        print("--- {} {} ---".format("Training" if is_RETRAIN else "Loading", k + " " + name))
        print()

        # no validation split: contexts and targets are always resampled,
        # so the best training loss is used to select the checkpoint
        chckpt = Checkpoint(dirname="results/notebooks/neural_process/variants/{}_{}".format(name, k),
                            monitor='train_loss_best')

        model = NeuralNetTransformer(neural_proc, NeuralProcessLoss,
                                     max_epochs=N_EPOCHS,
                                     batch_size=BATCH_SIZE,
                                     train_split=None,  # don't use cross validation dev set
                                     lr=1e-3, 
                                     callbacks=[ProgressBar(), 
                                                chckpt], # checkpoint best model
                                    ) 

        if is_RETRAIN:
            _=model.fit(*make_Xy_input(dataset))

        model.initialize()
        model.load_params(checkpoint=chckpt)
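The logged `train_loss` values go negative for most models. This is expected if `NeuralProcessLoss` is (mainly) a Gaussian negative log-likelihood of the targets: a density can exceed 1 when the predicted standard deviation is small, so the NLL of a confident, accurate prediction is negative. Whether the library sums or averages over targets, and how it weights any KL term for the latent path, is not shown here; the sketch only illustrates the sign.

```python
import math

def gaussian_nll(y, mean, std):
    """Negative log-density of y under N(mean, std**2)."""
    return 0.5 * math.log(2 * math.pi) + math.log(std) + 0.5 * ((y - mean) / std) ** 2

# A perfect prediction with unit std still has a positive NLL ...
print(round(gaussian_nll(0.0, 0.0, 1.0), 4))   # 0.9189
# ... but a confident, accurate prediction has a negative NLL (density > 1)
print(round(gaussian_nll(0.0, 0.0, 0.05), 4))  # -2.0768
```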


--- Training matern Conditional_ANP ---

  epoch    train_loss    cp      dur
-------  ------------  ----  -------
      1      110.6675     +  24.0356
      2       67.8055     +  22.5172
      3       45.0732     +  24.3130
      4       28.6640     +  24.1081
      5       17.7251     +  23.1909
      6       13.6086     +  23.4674
      7       11.8871     +  24.1483
      8        6.2535     +  23.6491
      9       -1.4347     +  24.3582
     10        0.5787        23.3680
     11       -2.7678     +  24.1192
     12       -5.3852     +  24.2023
     13       -2.2614        23.8901
     14       -4.0328        24.0880
     15       -6.7724     +  24.8889
     16       -7.7036     +  23.7653
     17       -5.2691        24.1965
     18       -5.5044        24.2880
     19      -10.8397     +  24.1802
     20       -9.2933        24.2675
     21       -3.5328        24.6167
     22       -6.5789        24.4782
     23       -9.1712        24.0591
     24      -10.9053     +  24.2409
     25      -15.0206     +  23.4037
     26      -10.6347        24.0464
     27      -12.2225        24.2900
     28      -11.3536        24.2717
     29      -10.4341        24.2028
     30      -15.7696     +  23.9329
     31      -14.7930        23.2641
     32       -9.9557        24.1211
     33      -12.3793        23.9821
     34      -13.3565        23.7919
     35      -16.4905     +  24.2321
     36      -11.8722        24.2916
     37      -16.0788        24.0755
     38      -14.5133        24.2074
     39      -14.3210        24.2630
     40      -21.0464     +  24.5770
     41       -8.6901        23.7543
     42      -11.2082        23.9163
     43      -20.6863        24.5299
     44      -13.6770        23.8925
     45      -15.1384        24.1006
     46      -18.4053        24.6098
     47      -13.4181        24.3150
     48       -7.4181        20.7969
     49      -21.0448        19.3864
     50      -14.6531        19.9107
Re-initializing optimizer because the following parameters were re-set: .

--- Training matern Multihead_ANP ---

  epoch    train_loss    cp      dur
-------  ------------  ----  -------
      1      116.7513     +  34.9310
      2       57.0852     +  35.4088
      3       41.4957     +  37.3632
      4       21.6561     +  34.1133
      5       11.5849     +  36.0220
      6        3.5623     +  36.1994
      7        5.6875        33.9180
      8        2.4313     +  36.0488
      9       -1.3942     +  35.4657
     10       -2.4239     +  34.6456
     11       -5.0751     +  35.5672
     12      -16.1530     +  35.3553
     13      -10.9200        36.4944
     14      -10.4637        35.1180
     15      -16.0636        33.3517
     16       -9.0118        33.9274
     17      -14.3703        36.3440
     18      -14.4601        35.2683
     19      -18.1125     +  36.6442
     20      -18.6924     +  34.5978
     21      -16.6574        36.3270
     22      -11.8826        35.3585
     23       -3.3659        34.6500
     24      -16.0386        33.6697
     25      -24.2249     +  34.0245
     26      -14.4530        34.5671
     27      -24.3320     +  25.6568
     28      -15.9492        24.1487
     29      -20.5344        34.6021
     30      -22.5489        28.1505
     31      -23.0882        32.2769
     32      -26.1256     +  34.4306
     33      -18.9721        31.7942
     34      -23.0192        36.9205
     35      -25.6099        34.1246
     36      -23.2904        33.8177
     37      -22.7953        34.9656
     38      -27.7472     +  34.5135
     39      -25.1587        35.0312
     40      -28.7036     +  33.7812
     41      -19.3631        34.2507
     42      -18.4283        33.4258
     43      -31.5342     +  34.2192
     44      -25.9154        33.9564
     45      -24.6675        28.5437
     46      -27.7185        33.7379
     47      -28.7674        35.0210
     48      -20.6380        35.5625
     49       -8.2576        34.1881
     50      -18.6083        29.7624
Re-initializing optimizer because the following parameters were re-set: .

--- Training matern Sinusoidal_ANP ---

  epoch    train_loss    cp      dur
-------  ------------  ----  -------
      1      159.6312     +  20.8209
      2      141.5471     +  17.7596
      3      125.7075     +  28.3719
      4      113.7779     +  28.6217
      5      103.1448     +  23.0803
      6       95.6320     +  25.4044
      7       90.6016     +  27.7850
      8       86.3776     +  27.3399
      9       81.9410     +  27.9722
     10       78.4914     +  28.3914
     11       74.8854     +  28.4315
     12       69.5434     +  26.0980
     13       65.3606     +  27.5306
     14       61.2513     +  17.8255
     15       58.0892     +  18.0234
     16       54.3557     +  24.2821
     17       52.1971     +  28.2026
     18       50.2494     +  26.7825
     19       47.7437     +  24.5181
     20       45.3963     +  24.1727
     21       42.7992     +  28.4502
     22       40.8427     +  28.3170
     23       39.8969     +  29.7770
     24       39.0086     +  28.3975
     25       37.4931     +  28.3108
     26       37.9079        28.7457
     27       36.1982     +  28.0726
     28       36.4646        26.6071
     29       35.2775     +  28.1904
     30       34.3941     +  28.6070
     31       33.8003     +  27.2202
     32       33.0037     +  30.7669
     33       33.9328        28.5124
     34       32.9218     +  28.4133
     35       31.6325     +  23.3765
     36       31.9712        18.5412
     37       31.8027        23.1600
     38       30.1686     +  24.5353
     39       30.4343        28.3280
     40       29.0730     +  26.7610
     41       30.3546        27.2349
     42       30.1217        22.3447
     43       27.6129     +  27.6986
     44       26.1166     +  27.5410
     45       26.0092     +  21.5735
     46       24.1986     +  28.2383
     47       23.0425     +  28.4666
     48       24.3691        24.0909
     49       22.7693     +  27.5560
     50       22.8283        27.8129
Re-initializing optimizer because the following parameters were re-set: .

--- Training matern SelfAttn_Enc_ANP ---

  epoch    train_loss    cp      dur
-------  ------------  ----  -------
      1      126.1375     +  54.1309
      2       83.5206     +  53.4671
      3       55.8969     +  50.4148
      4       39.9782     +  53.7710
      5       26.9713     +  52.5069
      6       18.3674     +  49.1612
      7       20.9638        53.1991
      8       11.3401     +  54.6601
      9        9.3760     +  54.5127
     10        6.8282     +  54.5633
     11        4.8698     +  54.6556
     12       -6.2489     +  50.8660
     13       -1.8995        54.3236
     14       -3.2529        54.8034
     15       -4.2120        52.4130
     16       -5.8096        51.7303
     17       -1.4810        54.8312
     18       -9.1624     +  54.7175
     19      -14.2252     +  53.6867
     20      -11.4981        54.8318
     21      -11.7792        53.9984
     22      -11.7841        54.8664
     23      -11.7468        54.8406
     24      -14.1804        55.1074
     25      -23.5963     +  54.6781
     26      -13.7994        53.2568
     27      -21.1537        54.9991
     28      -13.5951        54.5991
     29      -19.1419        54.8400
     30      -21.1143        46.7707
     31      -21.2745        50.1792
     32      -25.5139     +  55.0877
     33      -19.3782        54.3282
     34      -22.3855        54.6731
     35      -23.8343        54.3201
     36      -23.1322        52.3340
     37      -23.4085        51.7946
     38      -27.4260     +  54.9977
     39      -25.8065        54.6240
     40      -30.6683     +  54.4976
     41      -18.1572        54.7181
     42      -17.7217        54.5603
     43      -31.0472     +  54.0961
     44      -26.9853        53.4661
     45      -22.5053        48.4335
     46      -28.6147        51.2805
     47      -31.2293     +  38.4637
     48      -22.9943        45.4231
     49      -27.7681        53.3605
     50      -25.9808        48.1505
Re-initializing optimizer because the following parameters were re-set: .

--- Training matern SelfAttn_Dec_ANP ---

  epoch    train_loss    cp      dur
-------  ------------  ----  -------
      1      112.0665     +  31.5927
      2       75.2561     +  38.9448
      3       55.6908     +  35.7008
      4       41.1267     +  31.9363
      5       28.3971     +  38.9840
      6       17.3724     +  38.0674
      7       13.7265     +  36.7587
      8        7.6493     +  38.5017
      9        2.0132     +  39.1428
     10       -1.9543     +  39.0485
     11       -4.9918     +  39.0723
     12      -15.1739     +  39.1251
     13      -13.1794        39.1768
     14      -12.5300        38.9954
     15      -19.4043     +  39.0876
     16      -19.4061     +  38.9931
     17      -20.0418     +  39.0799
     18      -20.9780     +  39.1711
     19      -23.4527     +  39.1208
     20      -24.5090     +  33.8209
     21      -23.6944        36.3450
     22      -24.4880        37.2015
     23      -24.7494     +  34.2209
     24      -28.1479     +  39.1205
     25      -34.6658     +  39.0760
     26      -25.7088        37.2834
     27      -33.7701        36.0817
     28      -25.2961        39.1538
     29      -30.4895        39.1618
     30      -33.5350        39.1250
     31      -34.6239        33.8588
     32      -35.0710     +  39.2153
     33      -32.9921        39.1623
     34      -32.1207        39.0810
     35      -37.6765     +  39.1995
     36      -36.1270        36.5237
     37      -36.0736        31.1075
     38      -39.9610     +  31.3029
     39      -37.2539        34.8065
     40      -40.8630     +  39.1343
     41      -31.5821        39.1482
     42      -32.6164        39.1136
     43      -43.5526     +  39.2253
     44      -38.0032        39.1675
     45      -37.1709        39.2560
     46      -41.2314        39.1270
     47      -43.0156        39.1056
     48      -34.3516        39.2681
     49      -39.7178        36.6531
     50      -36.5959        35.5118
Re-initializing optimizer because the following parameters were re-set: .

--- Training matern r128_Conditional_Transformer_Process ---

  epoch    train_loss    cp      dur
-------  ------------  ----  -------
      1      104.1417     +  39.0950
      2       63.4179     +  39.2263
      3       44.5059     +  39.3168
      4       33.2690     +  39.2335
      5       25.6908     +  39.3182
      6       19.2725     +  39.3711
      7       19.2676     +  35.2891
      8       13.9333     +  34.6192
      9       11.5319     +  34.6275
     10       10.0681     +  34.4844
     11        8.0120     +  34.5728
     12        0.6782     +  34.8975
     13        1.9664        39.4241
     14        1.6597        39.4994
     15       -2.0371     +  39.4886
     16       -4.6561     +  39.7708
     17       -5.6310     +  39.6090
     18       -5.8391     +  39.4785
     19       -6.3179     +  39.5659
     20       -8.7213     +  39.4738
     21       -7.8637        39.4427
     22       -6.3995        39.4223
     23       -9.3272     +  39.5018
     24      -11.4131     +  39.4632
     25      -16.8994     +  39.6767
     26       -9.3115        39.6558
     27      -15.7456        39.5606
     28      -10.5944        39.5375
     29      -13.7146        39.4896
     30      -17.5151     +  39.4909
     31      -17.1822        39.5546
     32      -17.6151     +  39.5619
     33      -17.2857        39.5343
     34      -17.0073        39.5968
     35      -22.7721     +  39.5349
     36      -19.8969        37.2178
     37      -19.4741        39.5832
     38      -24.2324     +  39.6303
     39      -21.8106        39.5746
     40      -25.2398     +  39.6703
     41      -17.5479        39.0771
     42      -16.2394        35.6350
     43      -25.7253     +  34.7910
     44      -23.6557        34.6575
     45      -19.7343        34.6523
     46      -26.6150     +  35.2731
     47      -26.5755        34.6914
     48      -20.2154        36.1948
     49      -24.6033        34.8502
     50      -23.1066        39.7006
Re-initializing optimizer because the following parameters were re-set: .

--- Training matern r64_Conditional_Transformer_Process ---

  epoch    train_loss    cp      dur
-------  ------------  ----  -------
      1      120.2381     +  38.2808
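In all of these logs the `cp` column shows `+` exactly on the epochs where `train_loss` reaches a new minimum, which suggests the checkpoint callback monitors the best training loss. That is my reading of the log, not something stated in this section. The hypothetical helper `checkpoint_marks` below is a minimal sketch that recomputes the marker from logged losses, using epochs 12 to 17 of the Matern `SelfAttn_Dec_ANP` run:

```python
def checkpoint_marks(losses):
    """Hypothetical helper: '+' for epochs whose loss is a new running minimum, '' otherwise.

    Assumes the checkpoint fires whenever the monitored (training) loss improves,
    which is inferred from the log rather than taken from the training code.
    """
    best = float("inf")
    marks = []
    for loss in losses:
        if loss < best:
            best = loss
            marks.append("+")
        else:
            marks.append("")
    return marks


# train_loss for epochs 12-17 of "matern SelfAttn_Dec_ANP", copied from the log
losses = [-15.1739, -13.1794, -12.5300, -19.4043, -19.4061, -20.0418]
print(checkpoint_marks(losses))  # ['+', '', '', '+', '+', '+']
```

This reproduces the log, including epoch 16, whose loss is only marginally lower than epoch 15 but still counts as an improvement.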

## Trained Prior

In [None]:
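EXTRAP_DISTANCE = 2  # extend the test range by 2 to the right to evaluate extrapolation
# `dataset` is the last dataset bound above; this assumes all datasets share the same min_max
INTERPOLATION_RANGE = dataset.min_max
EXTRAPOLATION_RANGE = (dataset.min_max[0], dataset.min_max[1] + EXTRAP_DISTANCE)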
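The two ranges above separate where the models were trained (`INTERPOLATION_RANGE`) from a wider test window that extends `EXTRAP_DISTANCE` to the right. A minimal sketch of how such a test grid splits into interpolation and extrapolation points; `make_test_grid` is a hypothetical helper and the `(-2, 2)` training range is a placeholder, not necessarily the value of `dataset.min_max`:

```python
import numpy as np


def make_test_grid(train_min_max, extrap_distance, n_points):
    """Hypothetical helper: evenly spaced inputs over the training range, extended to the right.

    Returns the grid and a boolean mask that is True where a point lies
    beyond the training range, i.e. where the model has to extrapolate.
    """
    lo, hi = train_min_max
    grid = np.linspace(lo, hi + extrap_distance, n_points)
    is_extrap = grid > hi
    return grid, is_extrap


# placeholder training range; the notebook reads it from dataset.min_max
grid, is_extrap = make_test_grid((-2.0, 2.0), extrap_distance=2, n_points=2 * 128)
print(grid.min(), grid.max(), int(is_extrap.sum()))
```

With these placeholder values roughly a third of the 256 test points fall in the extrapolation region, which is where the models are never trained.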

In [None]:
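# assumes each value of data_models is a (dict of trained models, dataset) pair;
# the original loop bound `neural_proc` twice, shadowing the outer variable
for k, (models, dataset) in data_models.items():
    for name, neural_proc in models.items():
        plot_prior_samples(neural_proc,
                           title="Trained Prior Samples : {} {}".format(name, k),
                           test_min_max=EXTRAPOLATION_RANGE,
                           train_min_max=INTERPOLATION_RANGE)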
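`plot_prior_samples` presumably draws functions from each trained model without any context. For comparison, a minimal NumPy sketch of draws from a ground-truth GP prior with a Matern kernel (ν = 1.5, which is scikit-learn's default for `Matern`). The helpers `matern32` and `sample_gp_prior`, the length scale of 0.5 and the input range are assumptions and need not match what `GPDataset` uses:

```python
import numpy as np


def matern32(x1, x2, length_scale=0.5):
    """Matern kernel with nu = 1.5 (unit variance) between two 1D arrays of inputs."""
    r = np.abs(x1[:, None] - x2[None, :]) / length_scale
    return (1.0 + np.sqrt(3.0) * r) * np.exp(-np.sqrt(3.0) * r)


def sample_gp_prior(x, n_samples=3, length_scale=0.5, jitter=1e-6, seed=0):
    """Draw zero-mean GP prior samples at inputs x via a Cholesky factorization."""
    rng = np.random.default_rng(seed)
    # small jitter on the diagonal keeps the Cholesky factorization numerically stable
    K = matern32(x, x, length_scale) + jitter * np.eye(len(x))
    L = np.linalg.cholesky(K)
    z = rng.standard_normal((len(x), n_samples))
    return L @ z  # shape (len(x), n_samples); each column is one function draw


x = np.linspace(-2.0, 4.0, 256)  # hypothetical extrapolation range
samples = sample_gp_prior(x, n_samples=3)
print(samples.shape)  # (256, 3)
```

Unlike a trained model, this prior is stationary, so its samples look statistically the same inside and outside the training range.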

## Posterior

In [None]:
N_CNTXT = 10
# assumes each value of data_models is a (dict of trained models, dataset) pair
for k, (models, dataset) in data_models.items():
    for name, neural_proc in models.items():
        plot_posterior_samples(dataset, neural_proc,
                               n_cntxt=N_CNTXT,
                               test_min_max=EXTRAPOLATION_RANGE,
                               n_points=2*N_POINTS,
                               title="Posterior Samples Conditioned on {} Context Points : {} {}".format(N_CNTXT, name, k))
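`plot_posterior_samples` conditions each model on `n_cntxt` points and predicts over the extended range. Because the data are GP samples, the exact GP posterior is a natural yardstick for these plots. A hedged NumPy sketch, assuming a Matern (ν = 1.5) kernel with length scale 0.5 and a small observation noise; `gp_posterior`, the stand-in target function and the way context points are drawn are illustrative choices, not what the library does internally:

```python
import numpy as np


def matern32(x1, x2, length_scale=0.5):
    """Matern kernel with nu = 1.5 (unit variance) between two 1D arrays of inputs."""
    r = np.abs(x1[:, None] - x2[None, :]) / length_scale
    return (1.0 + np.sqrt(3.0) * r) * np.exp(-np.sqrt(3.0) * r)


def gp_posterior(x_cntxt, y_cntxt, x_trgt, length_scale=0.5, noise=1e-4):
    """Exact GP posterior mean and std at x_trgt given noisy context observations."""
    K_cc = matern32(x_cntxt, x_cntxt, length_scale) + noise * np.eye(len(x_cntxt))
    K_tc = matern32(x_trgt, x_cntxt, length_scale)
    L = np.linalg.cholesky(K_cc)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_cntxt))  # K_cc^{-1} y
    mean = K_tc @ alpha
    v = np.linalg.solve(L, K_tc.T)
    var = 1.0 - np.sum(v ** 2, axis=0)  # the prior variance of this kernel is 1
    return mean, np.sqrt(np.clip(var, 0.0, None))


rng = np.random.default_rng(0)
x_trgt = np.linspace(-2.0, 4.0, 256)  # hypothetical extrapolation range
y_true = np.sin(3 * x_trgt)           # stand-in function, not a GP sample
# 10 context points drawn from inside the (hypothetical) training range [-2, 2]
idx = rng.choice(np.flatnonzero(x_trgt <= 2.0), size=10, replace=False)
mean, std = gp_posterior(x_trgt[idx], y_true[idx], x_trgt)
```

Near the context points the standard deviation collapses towards zero, while far to the right it returns to the prior value of 1; a well-calibrated model should show a similar pattern in the extrapolation region.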

In [None]:
N_CNTXT = 2
# assumes each value of data_models is a (dict of trained models, dataset) pair
for k, (models, dataset) in data_models.items():
    for name, neural_proc in models.items():
        plot_posterior_samples(dataset, neural_proc,
                               n_cntxt=N_CNTXT,
                               test_min_max=EXTRAPOLATION_RANGE,
                               n_points=2*N_POINTS,
                               title="Posterior Samples Conditioned on {} Context Points : {} {}".format(N_CNTXT, name, k))

In [None]:
N_CNTXT = 20
# assumes each value of data_models is a (dict of trained models, dataset) pair
for k, (models, dataset) in data_models.items():
    for name, neural_proc in models.items():
        plot_posterior_samples(dataset, neural_proc,
                               n_cntxt=N_CNTXT,
                               test_min_max=EXTRAPOLATION_RANGE,
                               n_points=2*N_POINTS,
                               title="Posterior Samples Conditioned on {} Context Points : {} {}".format(N_CNTXT, name, k))
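The three posterior cells condition on 2, 10 and 20 context points. For an exact GP with fixed hyperparameters, adding context points can never increase the posterior variance at any input, so the uncertainty of a well-trained model should shrink in a similar way inside the training range. A small check of that property with nested context sets, assuming an illustrative Matern kernel (ν = 1.5, length scale 0.5) and noise level; `gp_posterior_std` is a hypothetical helper:

```python
import numpy as np


def matern32(x1, x2, length_scale=0.5):
    """Matern kernel with nu = 1.5 (unit variance) between two 1D arrays of inputs."""
    r = np.abs(x1[:, None] - x2[None, :]) / length_scale
    return (1.0 + np.sqrt(3.0) * r) * np.exp(-np.sqrt(3.0) * r)


def gp_posterior_std(x_cntxt, x_trgt, length_scale=0.5, noise=1e-4):
    """Exact GP posterior std at x_trgt; it does not depend on the observed y values."""
    K_cc = matern32(x_cntxt, x_cntxt, length_scale) + noise * np.eye(len(x_cntxt))
    K_tc = matern32(x_trgt, x_cntxt, length_scale)
    v = np.linalg.solve(np.linalg.cholesky(K_cc), K_tc.T)
    return np.sqrt(np.clip(1.0 - np.sum(v ** 2, axis=0), 0.0, None))


rng = np.random.default_rng(0)
x_trgt = np.linspace(-2.0, 2.0, 128)  # hypothetical training range
# nested context sets: the first 2 indices are a subset of the first 10, which are a subset of the first 20
order = rng.permutation(len(x_trgt))
for n_cntxt in (2, 10, 20):
    std = gp_posterior_std(x_trgt[order[:n_cntxt]], x_trgt)
    print(n_cntxt, round(float(std.mean()), 3))
```

Note that the notebook cells draw a fresh context set for each `N_CNTXT`, so the sets there are not nested; the monotone decrease is only guaranteed pointwise for nested sets, although on average it should still be visible.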