In [None]:
# Read the experiment arguments.
# The argument string is taken from the NB_ARGS environment variable and split
# on whitespace into an argv-style list (e.g. NB_ARGS="--ns 8").
import os
import argparse

parser = argparse.ArgumentParser(description='Experiment')
parser.add_argument('--ns', type=int, default=4, help='number of virtual points')

# Use .get() so that a missing NB_ARGS falls back to the argparse defaults
# instead of raising KeyError; an empty string splits to an empty argv.
args = parser.parse_args(os.environ.get('NB_ARGS', '').split())
ns = args.ns
print("ns = ", ns)

In [None]:
import cpuinfo

In [None]:
from jax import config
import jax.numpy as jnp
import jax.random as jr
import optax as ox
from jaxtyping import install_import_hook
import matplotlib as mpl
import matplotlib.pyplot as plt
from scipy.stats import qmc
from numpy.random import Generator, PCG64

import numpy as np
import torch
import scipy
import copy
import timeit
import os

from constrained_gp import *

# Run JAX in 64-bit precision.
config.update("jax_enable_x64", True)

# Import gpjax through jaxtyping's import hook with beartype as the type checker.
with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx

# Root PRNG key; later cells split sub-keys from it.
key = jr.key(123)

# Colours of the active matplotlib property cycle.
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]

# Output switches.
print_title = False   # put the metrics string in the figure title
save_fig = True       # write every figure to <file_name>/*.pdf
save_samples = True   # write sampler output to <file_name>/*.npz

# Output directory, named after the number of virtual points.
file_name = "2d_1_"+str(ns)
# exist_ok=True makes the cell idempotent without a separate existence check
# (no check-then-create race, and a clear error if the path is a regular file).
os.makedirs(file_name, exist_ok=True)

info_strings = []       # one metrics row per method, collected for the summary
fig_size = (1.5, 1.5)

In [None]:
# %% create and plot data
n = 16  # number of measurements
d = 2   # input dimension
noise = 1e-3  # std of the Gaussian noise added to the observations; also the initial obs_stddev of the GP likelihood

key, subkey = jr.split(key)

# Plotting / sampling domain, one (min, max) pair per input axis.
xlim = (-5, 5)
ylim = (-5, 5)

# x = jr.uniform(key=key, minval=-5, maxval=5, shape=(2*n,)).reshape(-1, 2)
rng = Generator(PCG64(1)) #5
lh_sampler = qmc.LatinHypercube(d=d, rng=rng)
# Scale the unit-cube Latin hypercube sample per axis: column 0 to xlim and
# column 1 to ylim. Previously both columns were scaled with xlim, so a change
# of ylim would not have moved the training inputs (the virtual points already
# scale each axis separately). With xlim == ylim the values are unchanged.
domain_lower = np.array([xlim[0], ylim[0]])
domain_upper = np.array([xlim[1], ylim[1]])
t = domain_lower + lh_sampler.random(n=n) * (domain_upper - domain_lower)
# t1 = lh_sampler.random(n=int(0.5*n))*2-5
# t2 = lh_sampler.random(n=int(0.5*n))*2+3
# t = jnp.vstack([t1, t2])
t = jnp.array(t)

# t = jnp.array([[-4.8], [-4.0], [0.0], [2.0]])

t_dim = len(t)

# ground truth function
def f(t):
    """Evaluate the ground truth at the rows of ``t``.

    ``t`` is indexed as a 2-D array with one point per row; the result is the
    sine of the second column, so the function is constant along the first
    input coordinate. Returns one value per row.
    """
    second_coordinate = t[:, 1]
    return jnp.sin(second_coordinate)

# Noisy observations of the ground truth at the training inputs, as an (N, 1) column.
clean_observations = f(t)
observation_noise = jr.normal(subkey, shape=clean_observations.shape) * noise
f0t = (clean_observations + observation_noise).reshape(-1, 1)

# Evaluation grid: 50 points per axis, flattened into rows (x, y) of `u`.
xx = jnp.linspace(xlim[0], xlim[1], 50)
yy = jnp.linspace(ylim[0], ylim[1], 50)
xv, yv = (grid_axis.reshape(-1, 1) for grid_axis in jnp.meshgrid(xx, yy))
u = jnp.hstack([xv, yv])

# Exact function on the grid; its range fixes the colour scale of the mean plots.
f0u_exact = f(u)
vmin, vmax = f0u_exact.min(), f0u_exact.max()

fig, ax = plot_2d_function_with_points(
    xx, yy, f0u_exact.reshape(50, 50), pts=t, figsize=fig_size
)
if save_fig:
    plt.savefig(f"{file_name}/groundtruth.pdf", bbox_inches="tight")

## Build unconstrained GP model

In [None]:
# Training data for the unconstrained GP.
D = gpx.Dataset(X=t, y=f0t)

# Prior: constant mean with an RBF kernel (initial hyperparameters).
meanf = gpx.mean_functions.Constant(0.0)
kernel = gpx.kernels.RBF(lengthscale=1.0, variance=1.0)#1.3
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)

# Gaussian likelihood, initialised at the noise level used to generate the data.
likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n, obs_stddev=noise) #1e-3
posterior = prior * likelihood


def negative_mll(model, data):
    """Objective to minimise: the negated conjugate marginal log-likelihood."""
    return -gpx.objectives.conjugate_mll(model, data)


opt_posterior, history = gpx.fit(
    model=posterior,
    objective=negative_mll,
    optim=ox.sgd(0.01),
    train_data=D,
    num_iters=20000,
)

fitted_kernel = opt_posterior.prior.kernel
fitted_noise_std = opt_posterior.likelihood.obs_stddev.value

print("optimized lengthscale:", fitted_kernel.lengthscale.value)
print("optimized variance:", fitted_kernel.variance.value)
print("optimized mean:", opt_posterior.prior.mean_function.constant.value)
print("optimized noise std:", fitted_noise_std)

# Short aliases for the fitted hyperparameters used by the later cells.
l = fitted_kernel.lengthscale.value
sigma_square = fitted_kernel.variance.value
sigma = jnp.sqrt(sigma_square)
fun_noise_var = fitted_noise_std**2
# Diagonal observation-noise covariance, converted to a NumPy array.
fun_noise_matrix = np.asarray(jnp.eye(len(t)) * fun_noise_var)

In [None]:
# Predictive distribution of the fitted, unconstrained GP on the evaluation grid.
unconstrained_gp = opt_posterior.predict(u, train_data=D)

# Evaluate the predictive standard deviation once; it was previously
# recomputed for every use below.
unconstrained_std = unconstrained_gp.stddev()

fig, ax = plot_2d_function_with_points(xx, yy, unconstrained_gp.mean.reshape(50, 50), pts=t, vmin=vmin, vmax=vmax, figsize=fig_size)
if save_fig:
    plt.savefig(file_name + "/unconstrained_mean.pdf", bbox_inches="tight")

# Four standard deviations, i.e. the full width of a +/- 2 sigma band. Its range
# is reused as the colour scale of the interval-width plots in later cells.
std_colorbar_min = 4*unconstrained_std.min()
std_colorbar_max = 4*unconstrained_std.max()

fig, ax = plot_2d_function_with_points(xx, yy, 4*unconstrained_std.reshape(50, 50), pts=t, vmin=std_colorbar_min, vmax=std_colorbar_max, figsize=fig_size)
# Metrics row: MSE of the mean against the exact function, mean band width.
info_string = "%.2e %.2e" % (mean_squared_error(f0u_exact, unconstrained_gp.mean), 4*unconstrained_std.mean())
info_strings.append(info_string)
print(info_string)
if print_title:
    ax.set_title(info_string)
if save_fig:
    plt.savefig(file_name + "/unconstrained_ci.pdf", bbox_inches="tight")

## Virtual points

In [None]:
# Virtual points: a scrambled Sobol design with `ns` points on the plotting domain.
# Sobol.random_base2 can only draw 2**m points. For any other `ns` the previous
# int(log2(ns)) rounding silently produced fewer than `ns` rows, so `s` disagreed
# with `s_dim` (and with everything sized from it, e.g. f1s_prior_mean).
# Fail early with a clear message instead.
if ns < 1 or ns & (ns - 1):
    raise ValueError("--ns must be a power of two for base-2 Sobol sampling, got %d" % ns)
s_dim = ns

# # LHC
# s_dim = 4
# rng = Generator(PCG64(2025))
# lh_sampler = qmc.LatinHypercube(d=d, rng=rng)
# s = xlim[0] + lh_sampler.random(n=s_dim)*(xlim[1]-xlim[0])
# # sort s
# # s = np.sort(s, axis=0)
# s = jnp.array(s)
# # s = jnp.array([-2.5, 2.5]).reshape(-1, 1)

# Sobol
sobol_sampler = qmc.Sobol(d=d, scramble=True, seed=0)
# Exact integer log2 of a power of two (no floating-point rounding).
s = sobol_sampler.random_base2(m=int(ns).bit_length() - 1)
# Map the unit-square sample onto the domain, one axis at a time.
s[:,0] = xlim[0] + s[:,0]*(xlim[1]-xlim[0])
s[:,1] = ylim[0] + s[:,1]*(ylim[1]-ylim[0])
s = jnp.array(s)
# s = jnp.array([-2.5, 2.5]).reshape(-1, 1)

# Diagonal noise covariance for the virtual (derivative) observations, as NumPy.
grad_noise_var = 5e-3
grad_noise_matrix = jnp.eye(len(s)) * grad_noise_var
grad_noise_matrix = np.asarray(grad_noise_matrix)

In [None]:
# Exact function with the measurement points (pts) and virtual points (vpts) overlaid.
ground_truth_grid = f0u_exact.reshape(50, 50)
fig, ax = plot_2d_function_with_points(
    xx, yy, ground_truth_grid,
    pts=t, vpts=s, xlim=xlim, ylim=ylim, figsize=fig_size,
)
if save_fig:
    plt.savefig(f"{file_name}/virtual_points.pdf", bbox_inches="tight")

## Prepare matrices

In [None]:
# All four kernels share the hyperparameters fitted for the unconstrained GP.
# NOTE(review): RBFDX1Y / RBFXDY1 / RBFDX1DY1 come from constrained_gp; the names
# suggest derivatives of the RBF kernel w.r.t. the first input coordinate of
# the first / second / both arguments -- confirm against that module.
rbf_hyperparameters = dict(lengthscale=l, variance=sigma_square)
kernel_xy = RBF(**rbf_hyperparameters)
kernel_dx1y = RBFDX1Y(**rbf_hyperparameters)
kernel_xdy1 = RBFXDY1(**rbf_hyperparameters)
kernel_dx1dy1 = RBFDX1DY1(**rbf_hyperparameters)

# Covariance blocks between measurement points t and virtual points s (NumPy arrays).
K00tt = np.asarray(kernel_xy.gram(t).to_dense())
K10st = np.asarray(kernel_dx1y.cross_covariance(s, t))
K01ts = np.asarray(kernel_xdy1.cross_covariance(t, s))
K11ss = np.asarray(kernel_dx1dy1.gram(s).to_dense())

# Prior means: fitted mean function at t and on the grid u, zero at the virtual points.
f0t_prior_mean = opt_posterior.prior.mean_function(t)
f1s_prior_mean = jnp.zeros(s_dim)
f0u_prior_mean = opt_posterior.prior.mean_function(u).flatten()

# Covariance blocks involving the evaluation grid u.
K00tu = np.asarray(kernel_xy.cross_covariance(t, u))
K10su = np.asarray(kernel_dx1y.cross_covariance(s, u))
K00uu = np.asarray(kernel_xy.gram(u).to_dense())

In [None]:
# Sampling budget shared by every sampler below: (warm-up count, sample count).
no_of_warmups, no_of_samples = 1_000, 50_000

## Truncated Gibbs

In [None]:
# Run gibbs_truncated_sample with a fixed NumPy seed and time the call.
np.random.seed(0)

start_time = timeit.default_timer()
samples = gibbs_truncated_sample(
    K01ts, K11ss, K10st, K00tt,
    fun_noise_matrix, grad_noise_matrix,
    f0t_prior_mean, f1s_prior_mean, f0t,
    no_of_warmups, no_of_samples,
)
end_time = timeit.default_timer()
time_elapsed = end_time - start_time

samples.plot_ci()
if save_samples:
    # Store the raw chain together with its ESS and the wall-clock time.
    np.savez(
        f"{file_name}/truncated_gibbs_samples.npz",
        samples=samples.samples,
        ess=samples.compute_ess(),
        time=time_elapsed,
    )

In [None]:
# Push the truncated-Gibbs samples through the derivative-enhanced GP to get
# samples of the function on the evaluation grid.
f0u_samples = draw_samples_with_derivative_enhanced_gp(K00tt, K01ts, K10st, K11ss, K00tu, K10su, K00uu, fun_noise_var, grad_noise_var, f0t_prior_mean, f1s_prior_mean, f0u_prior_mean, f0t, samples.samples)

f0u_samples.geometry = cuqi.geometry.Image2D((50, 50))

# Compute each summary once: the sample mean, the credible-interval bounds and
# the ESS were previously recomputed for every use, which is costly for
# no_of_samples draws on a 50 x 50 grid.
f0u_mean = f0u_samples.mean()
ci_bounds = f0u_samples.compute_ci()
ci_width = ci_bounds[1] - ci_bounds[0]
min_ess = samples.compute_ess().min()

fig, ax = plot_2d_function_with_points(xx, yy, f0u_mean.reshape(50, 50), pts=t, vpts=s, vmin=vmin, vmax=vmax, figsize=fig_size)
if save_fig:
    plt.savefig(file_name + "/truncated_gibbs_mean.pdf", bbox_inches="tight")

fig, ax = plot_2d_function_with_points(xx, yy, ci_width.reshape(50, 50), pts=t, vpts=s, vmin=std_colorbar_min, vmax=std_colorbar_max, figsize=fig_size)
# Metrics row: MSE, mean CI width, samples per effective sample, effective samples per second.
info_string = "%.2e %.2e %.2e %.2e" % (mean_squared_error(f0u_exact, f0u_mean), ci_width.mean(), no_of_samples/min_ess, min_ess/time_elapsed)
info_strings.append(info_string)
print(info_string)
if print_title:
    ax.set_title(info_string)
if save_fig:
    plt.savefig(file_name + "/truncated_gibbs_ci.pdf", bbox_inches="tight")

## Truncated NUTS

In [None]:
# Run nuts_truncated_sample_torch with fixed NumPy and torch seeds and time the call.
np.random.seed(0)
torch.manual_seed(0)

start_time = timeit.default_timer()
original_samples, samples = nuts_truncated_sample_torch(
    K01ts, K11ss, K10st, K00tt,
    fun_noise_matrix, grad_noise_matrix,
    f0t_prior_mean, f0t,
    no_of_warmups, no_of_samples,
)
end_time = timeit.default_timer()
time_elapsed = end_time - start_time

samples.plot_ci()

if save_samples:
    # The first returned sample set (original_samples) is the one stored, with its ESS.
    np.savez(
        f"{file_name}/truncated_nuts_samples.npz",
        samples=original_samples.samples,
        ess=original_samples.compute_ess(),
        time=time_elapsed,
    )

In [None]:
# Push the truncated-NUTS samples through the derivative-enhanced GP to get
# samples of the function on the evaluation grid.
f0u_samples = draw_samples_with_derivative_enhanced_gp(K00tt, K01ts, K10st, K11ss, K00tu, K10su, K00uu, fun_noise_var, grad_noise_var, f0t_prior_mean, f1s_prior_mean, f0u_prior_mean, f0t, samples.samples)

f0u_samples.geometry = cuqi.geometry.Image2D((50, 50))

# Compute the sample mean and the credible-interval bounds once; they were
# previously recomputed for every use, which is costly for no_of_samples draws
# on a 50 x 50 grid.
f0u_mean = f0u_samples.mean()
ci_bounds = f0u_samples.compute_ci()
ci_width = ci_bounds[1] - ci_bounds[0]

fig, ax = plot_2d_function_with_points(xx, yy, f0u_mean.reshape(50, 50), pts=t, vpts=s, vmin=vmin, vmax=vmax, figsize=fig_size)
if save_fig:
    plt.savefig(file_name + "/truncated_nuts_mean.pdf", bbox_inches="tight")

fig, ax = plot_2d_function_with_points(xx, yy, ci_width.reshape(50, 50), pts=t, vpts=s, vmin=std_colorbar_min, vmax=std_colorbar_max, figsize=fig_size)
# Metrics row: MSE, mean CI width, samples per effective sample, effective samples per second.
# NOTE(review): the third column uses the ESS of `original_samples` while the
# fourth uses the ESS of `samples` -- confirm that this mix is intended.
info_string = "%.2e %.2e %.2e %.2e" % (mean_squared_error(f0u_exact, f0u_mean), ci_width.mean(), no_of_samples/original_samples.compute_ess().min(), samples.compute_ess().min()/time_elapsed)
info_strings.append(info_string)
print(info_string)
if print_title:
    ax.set_title(info_string)
if save_fig:
    plt.savefig(file_name + "/truncated_nuts_ci.pdf", bbox_inches="tight")

## Nonlinear Gibbs

In [None]:
# Run gibbs_nonlinear_sample with a fixed NumPy seed and time the call.
np.random.seed(0)

start_time = timeit.default_timer()
original_samples, samples = gibbs_nonlinear_sample(
    K01ts, K11ss, K10st, K00tt,
    fun_noise_matrix, grad_noise_matrix,
    f0t_prior_mean, f1s_prior_mean, f0t,
    no_of_warmups, no_of_samples,
)
end_time = timeit.default_timer()
time_elapsed = end_time - start_time

samples.plot_ci()

if save_samples:
    # The first returned sample set (original_samples) is the one stored, with its ESS.
    np.savez(
        f"{file_name}/nonlinear_gibbs_samples.npz",
        samples=original_samples.samples,
        ess=original_samples.compute_ess(),
        time=time_elapsed,
    )

In [None]:
# Push the nonlinear-Gibbs samples through the derivative-enhanced GP to get
# samples of the function on the evaluation grid.
f0u_samples = draw_samples_with_derivative_enhanced_gp(K00tt, K01ts, K10st, K11ss, K00tu, K10su, K00uu, fun_noise_var, grad_noise_var, f0t_prior_mean, f1s_prior_mean, f0u_prior_mean, f0t, samples.samples)

f0u_samples.geometry = cuqi.geometry.Image2D((50, 50))

# Compute the sample mean and the credible-interval bounds once; they were
# previously recomputed for every use, which is costly for no_of_samples draws
# on a 50 x 50 grid.
f0u_mean = f0u_samples.mean()
ci_bounds = f0u_samples.compute_ci()
ci_width = ci_bounds[1] - ci_bounds[0]

fig, ax = plot_2d_function_with_points(xx, yy, f0u_mean.reshape(50, 50), pts=t, vpts=s, vmin=vmin, vmax=vmax, figsize=fig_size)
if save_fig:
    plt.savefig(file_name + "/nonlinear_gibbs_mean.pdf", bbox_inches="tight")

fig, ax = plot_2d_function_with_points(xx, yy, ci_width.reshape(50, 50), pts=t, vpts=s, vmin=std_colorbar_min, vmax=std_colorbar_max, figsize=fig_size)
# Metrics row: MSE, mean CI width, samples per effective sample, effective samples per second.
# NOTE(review): the third column uses the ESS of `original_samples` while the
# fourth uses the ESS of `samples` -- confirm that this mix is intended.
info_string = "%.2e %.2e %.2e %.2e" % (mean_squared_error(f0u_exact, f0u_mean), ci_width.mean(), no_of_samples/original_samples.compute_ess().min(), samples.compute_ess().min()/time_elapsed)
info_strings.append(info_string)
print(info_string)
if print_title:
    ax.set_title(info_string)
if save_fig:
    plt.savefig(file_name + "/nonlinear_gibbs_ci.pdf", bbox_inches="tight")

## Nonlinear NUTS

In [None]:
# Run nuts_nonlinear_sample_torch with fixed NumPy and torch seeds and time the call.
np.random.seed(0)
torch.manual_seed(0)

start_time = timeit.default_timer()
original_samples, samples = nuts_nonlinear_sample_torch(
    K01ts, K11ss, K10st, K00tt,
    fun_noise_matrix, grad_noise_matrix,
    f0t_prior_mean, f0t,
    no_of_warmups, no_of_samples,
)
end_time = timeit.default_timer()
time_elapsed = end_time - start_time

samples.plot_ci()

if save_samples:
    # The first returned sample set (original_samples) is the one stored, with its ESS.
    np.savez(
        f"{file_name}/nonlinear_nuts_samples.npz",
        samples=original_samples.samples,
        ess=original_samples.compute_ess(),
        time=time_elapsed,
    )

In [None]:
# Push the nonlinear-NUTS samples through the derivative-enhanced GP to get
# samples of the function on the evaluation grid.
f0u_samples = draw_samples_with_derivative_enhanced_gp(K00tt, K01ts, K10st, K11ss, K00tu, K10su, K00uu, fun_noise_var, grad_noise_var, f0t_prior_mean, f1s_prior_mean, f0u_prior_mean, f0t, samples.samples)

f0u_samples.geometry = cuqi.geometry.Image2D((50, 50))

# Compute the sample mean and the credible-interval bounds once; they were
# previously recomputed for every use, which is costly for no_of_samples draws
# on a 50 x 50 grid.
f0u_mean = f0u_samples.mean()
ci_bounds = f0u_samples.compute_ci()
ci_width = ci_bounds[1] - ci_bounds[0]

fig, ax = plot_2d_function_with_points(xx, yy, f0u_mean.reshape(50, 50), pts=t, vpts=s, vmin=vmin, vmax=vmax, figsize=fig_size)
if save_fig:
    plt.savefig(file_name + "/nonlinear_nuts_mean.pdf", bbox_inches="tight")

fig, ax = plot_2d_function_with_points(xx, yy, ci_width.reshape(50, 50), pts=t, vpts=s, vmin=std_colorbar_min, vmax=std_colorbar_max, figsize=fig_size)
# Metrics row: MSE, mean CI width, samples per effective sample, effective samples per second.
# NOTE(review): the third column uses the ESS of `original_samples` while the
# fourth uses the ESS of `samples` -- confirm that this mix is intended.
info_string = "%.2e %.2e %.2e %.2e" % (mean_squared_error(f0u_exact, f0u_mean), ci_width.mean(), no_of_samples/original_samples.compute_ess().min(), samples.compute_ess().min()/time_elapsed)
info_strings.append(info_string)
print(info_string)
if print_title:
    ax.set_title(info_string)
if save_fig:
    plt.savefig(file_name + "/nonlinear_nuts_ci.pdf", bbox_inches="tight")

## RTO

In [None]:
# Run rto_sample with a fixed NumPy seed and time the call.
np.random.seed(0)

start_time = timeit.default_timer()
samples = rto_sample(
    K01ts, K11ss, K10st, K00tt,
    fun_noise_matrix, grad_noise_matrix,
    f0t_prior_mean, f0t,
    no_of_warmups, no_of_samples,
)
end_time = timeit.default_timer()
time_elapsed = end_time - start_time

samples.plot_ci()

if save_samples:
    # Store the raw samples together with their ESS and the wall-clock time.
    np.savez(
        f"{file_name}/rto_samples.npz",
        samples=samples.samples,
        ess=samples.compute_ess(),
        time=time_elapsed,
    )

In [None]:
# Push the RTO samples through the derivative-enhanced GP to get samples of the
# function on the evaluation grid.
f0u_samples = draw_samples_with_derivative_enhanced_gp(K00tt, K01ts, K10st, K11ss, K00tu, K10su, K00uu, fun_noise_var, grad_noise_var, f0t_prior_mean, f1s_prior_mean, f0u_prior_mean, f0t, samples.samples)

f0u_samples.geometry = cuqi.geometry.Image2D((50, 50))

# Compute each summary once: the sample mean, the credible-interval bounds and
# the ESS were previously recomputed for every use, which is costly for
# no_of_samples draws on a 50 x 50 grid.
f0u_mean = f0u_samples.mean()
ci_bounds = f0u_samples.compute_ci()
ci_width = ci_bounds[1] - ci_bounds[0]
min_ess = samples.compute_ess().min()

fig, ax = plot_2d_function_with_points(xx, yy, f0u_mean.reshape(50, 50), pts=t, vpts=s, vmin=vmin, vmax=vmax, figsize=fig_size)
if save_fig:
    plt.savefig(file_name + "/rto_mean.pdf", bbox_inches="tight")

fig, ax = plot_2d_function_with_points(xx, yy, ci_width.reshape(50, 50), pts=t, vpts=s, vmin=std_colorbar_min, vmax=std_colorbar_max, figsize=fig_size)
# Metrics row: MSE, mean CI width, samples per effective sample, effective samples per second.
info_string = "%.2e %.2e %.2e %.2e" % (mean_squared_error(f0u_exact, f0u_mean), ci_width.mean(), no_of_samples/min_ess, min_ess/time_elapsed)
info_strings.append(info_string)
print(info_string)
if print_title:
    ax.set_title(info_string)
if save_fig:
    plt.savefig(file_name + "/rto_ci.pdf", bbox_inches="tight")

## Summary

In [None]:
# Echo every collected metrics row, one per line, in the order the methods ran.
for summary_row in info_strings:
    print(summary_row)

In [None]:
# save info_strings to file
# The handle is deliberately not called `f`: that name is the ground-truth
# function defined earlier in this notebook, and rebinding it to a (closed)
# file object would break any later re-run of the cells that call f(...).
with open(file_name + "/info.txt", "w") as info_file:
    # First line: CPU model, so the timing column can be put in context.
    info_file.write(cpuinfo.get_cpu_info()["brand_raw"] + "\n")
    # Column header of the sampler rows. The first row (unconstrained GP)
    # carries only the first two columns.
    info_file.write("# MSE CI IAC ESS/S\n")
    info_file.writelines(row + "\n" for row in info_strings)