In [1]:
import sys
sys.path.insert(0, "..")
from gp3.inference import Vanilla
from gp3.utils import data as sim
from gp3.utils.structure import kron_list
from gp3.kernels import RBF
from gp3.utils.transforms import softplus
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
import plotly.graph_objs as go
from plotly import tools
from IPython.display import display
init_notebook_mode(connected=True)
import warnings
warnings.filterwarnings('ignore')
from tqdm import trange
from ipywidgets import IntProgress
import numpy as np

In [2]:
X = sim.sim_X_equispaced(D = 2, N_dim = 30, lower=0, upper=100)
f = sim.sim_f(X, RBF(20., 1., 0.0), mu = 0.)
y = f + np.random.normal(size = len(f))

In [3]:
trace_func = go.Scatter3d(x = X[:,0], y = X[:,1], z=f, mode = 'markers', marker=dict(size = 2,))
trace_draws = go.Scatter3d(x = X[:,0], y = X[:,1], z=y, mode = 'markers', marker=dict(size = 2,))
fig = tools.make_subplots(rows=1, cols=2, specs=[[{'is_3d': True}, {'is_3d': True}]])
fig.append_trace(trace_func, 1, 1)
fig.append_trace(trace_draws, 1, 2)
iplot(fig)

This is the format of your plot grid:
[ (1,1) scene1 ]  [ (1,2) scene2 ]



In [None]:
import scipy
import collections
from tqdm import trange
from gp3.utils.structure import kron_list


for ls in [10., 20., 30., 40., 50., 60., 70.]:
    
    kern = RBF(ls, 1.)
    X_dims = [np.expand_dims(np.unique(X[:, i]), 1) for i in range(X.shape[1])]
    K = kron_list([kern.eval(kern.params, X_d) for X_d in X_dims])
    K_root = np.real(scipy.linalg.sqrtm(K))

    res_ls = collections.defaultdict(list)

    for p in [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.]:
        print(p)
        for i in trange(100):

            f = np.dot(K_root, np.random.normal(size = X.shape[0]))
            y = f + np.random.normal(size = len(f))

            X_part, y_part = sim.rand_partial_grid(X, y, p)
            X_full, y_full, obs_idx, imag_idx = sim.fill_grid(X_part, y_part)

            inf = Vanilla(X, y_part, kernel = RBF(20., 1.), obs_idx = obs_idx, noise = 1.)
            f_pred = inf.predict_mean()
            var_exact = inf.variance_exact()
            var_est = inf.variance_slow(100)[0]

            err = np.abs(f_pred - f)
            err_bound = 1.96*np.sqrt(var_exact)
            var_err = np.linalg.norm(var_exact - var_est)
            check = np.sum(err < err_bound)/X.shape[0]

            res)[p].append((np.sum(err), np.sum(var_exact), check, var_err))

  0%|          | 0/100 [00:00<?, ?it/s]

0.2


100%|██████████| 100/100 [00:18<00:00,  5.50it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

0.3


100%|██████████| 100/100 [00:23<00:00,  4.26it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

0.4


100%|██████████| 100/100 [04:09<00:00,  2.50s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

0.5


100%|██████████| 100/100 [07:03<00:00,  4.23s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

0.6


100%|██████████| 100/100 [07:54<00:00,  4.75s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

0.7


100%|██████████| 100/100 [01:11<00:00,  1.39it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

0.8


 97%|█████████▋| 97/100 [00:53<00:01,  1.82it/s]

In [67]:
for p, v in res.items():
    errs, var_norms, checks, var_err = list(zip(*v))
    print(f"""\n {p}
          f error: {np.mean(errs)}
          variance norm: {np.mean(var_norms)}
          interval check: {np.mean(checks)}
          relative variance error: {np.mean(var_err)/np.mean(var_norms)}""")


 0.2
          f error: 11.551899526833573
          variance norm: 105.08058610765225
          interval check: 0.0
          relative variance error: 0.029106962141967653

 0.3
          f error: 8.716517626359568
          variance norm: 76.28333645736065
          interval check: 0.0
          relative variance error: 0.037130758899664135

 0.4
          f error: 7.817668329251535
          variance norm: 60.930759352030776
          interval check: 0.0
          relative variance error: 0.04764088678185614

 0.5
          f error: 7.135309463110272
          variance norm: 51.049288899370076
          interval check: 0.0
          relative variance error: 0.056500926507082376

 0.6
          f error: 6.623481676334311
          variance norm: 44.13118098204127
          interval check: 0.0
          relative variance error: 0.06487778615140252

 0.7
          f error: 6.12780187314863
          variance norm: 38.986458050081694
          interval check: 0.0
          relative var