## Creating the training data set

`jax.numpy` is used almost exactly like the standard `numpy` package, with the caveat that JAX arrays (`jax.Array`) are immutable, meaning that no in-place changes can be made. For creating training data, however, this should not be an issue.

In [None]:
import jax.numpy as jnp
from jax import random

We will model a fourth-order polynomial (Himmelblau's function scaled by 1/800) together with its gradient:

In [None]:
def fun(x, noise=0.0, key = random.PRNGKey(0)):
    """Evaluate the scaled fourth-order test polynomial at every row of ``x``.

    Args:
        x: array of shape ``(n, 2)``; column 0 is the first coordinate and
            column 1 the second.
        noise: scale factor applied to standard-normal noise that is added to
            each function value (``0.0`` yields the exact value).
        key: JAX PRNG key used to draw the noise. The default key is built
            once when the function is defined, so calls relying on it always
            draw the same noise sample.

    Returns:
        Array of shape ``(n,)`` with the (optionally noisy) function values.
    """
    x1, x2 = x[:, 0], x[:, 1]
    first_term = (x1**2 + x2 - 11) ** 2 / 800.0
    second_term = (x1 + x2**2 - 7) ** 2 / 800.0
    # One noise draw per input row, scaled by the requested noise level.
    jitter = random.normal(key, (len(x),), dtype=jnp.float32) * noise
    return first_term + second_term + jitter

def grad(x, noise=0.0, key = random.PRNGKey(0)):
    """Evaluate the analytic gradient of the scaled test polynomial.

    Args:
        x: array of shape ``(n, 2)``; column 0 is the first coordinate and
            column 1 the second.
        noise: scale factor applied to standard-normal noise that is added to
            every gradient component (``0.0`` yields the exact gradient).
        key: JAX PRNG key used to draw the noise. The default key is built
            once when the function is defined, so calls relying on it always
            draw the same noise sample.

    Returns:
        Array of shape ``(n, 2)``; column 0 is the derivative with respect to
        the first coordinate, column 1 with respect to the second.
    """
    x1, x2 = x[:, 0], x[:, 1]
    # The two inner expressions that are squared in the function itself.
    inner_a = x1**2 + x2 - 11
    inner_b = x1 + x2**2 - 7
    d_dx1 = 4 * inner_a * x1 + 2 * inner_b
    d_dx2 = 2 * inner_a + 4 * inner_b * x2
    # Independent noise for every gradient entry, scaled by the noise level.
    jitter = random.normal(key, x.shape, dtype=jnp.float32) * noise
    return jnp.stack((d_dx1, d_dx2), axis=1) / 800.0 + jitter

To define the training data we first need to define the boundaries from which the data points are chosen. Then, random points are drawn from this interval. `random.split` creates a new subkey from the previous key so that each draw from the uniform distribution uses a fresh pseudo-random sample.

In [None]:
# Lower and upper bound of the interval the data points are drawn from;
# the same interval is used for both input coordinates.
bounds = jnp.array([-5.0, 5.0])

# Number of function observations and number of gradient observations
num_f_vals = 1
num_d_vals = 50

# Initial seed for the pseudo-random key generation
seed = 0

# Sample the training features uniformly from the interval above.
# Every draw gets its own subkey; `key` is carried forward and split again,
# so the order of the statements below determines the sampled values.
key, subkey = random.split(random.PRNGKey(seed))
x_func = random.uniform(subkey, (num_f_vals, 2), minval=bounds[0], maxval=bounds[1])
key, subkey = random.split(key)
x_der = random.uniform(subkey, (num_d_vals,2), minval=bounds[0], maxval=bounds[1])

# Noise scale applied to the standard-normal perturbation of the training labels
noise = 0.1
# Noisy function values at x_func, shape (num_f_vals,)
key, subkey = random.split(key)
y_func = fun(x_func,noise, subkey)
# Noisy gradients at x_der, shape (num_d_vals, 2)
key, subkey = random.split(key)
y_der = grad(x_der, noise, subkey)

For training, the GPR framework expects a tuple of arrays `X_split`, which contains a set of points where the function is sampled and a set of points where the gradient is sampled. Each array in `X_split` has shape `(n_samples_i, N)`. `X_split` must be ordered as follows: the first array holds the data points for the function observations and the second array holds the data points for the gradient observations. `Y_train` should just be an array of shape `(n_samples_function + n_samples_gradient,)`.

In [None]:
# Training inputs: points with function observations first, points with
# gradient observations second.
X_split = [x_func, x_der]

# Training labels, one column vector per entry of X_split. The gradient labels
# are flattened row by row, so the two components of each gradient stay adjacent.
# The columns are stacked into a single flat vector when the model is trained.
Y_train = tuple(labels.reshape(-1, 1) for labels in (y_func, y_der))

## Before BayesOpt

In [None]:
from jaxgp.kernels import RBF

# Kernel object handed to the regression model and to the Bayesian optimizer.
kernel = RBF()
# An RBF kernel has two parameters by default; these are their starting values.
init_kernel_params = jnp.array([2.0, 2.0])

In [None]:
from jaxgp.regression import ExactGPR
from jaxgp.utils import Logger

# NOTE(review): Logger presumably records information during training -- confirm against jaxgp.utils.
logger = Logger()
# Exact GP regressor built from the kernel, its initial parameters and the same
# noise scale that was used to perturb the training labels.
model = ExactGPR(kernel, init_kernel_params, noise, logger=logger)

In [None]:
# Y_train is a tuple of column vectors; stack them (function values first, then
# gradient entries) and flatten to the 1-D label vector passed to training.
model.train(X_split, jnp.vstack(Y_train).reshape(-1))

In [None]:
# Regular 101 x 101 lattice over the sampling interval, flattened to shape
# (101**2, 2). Rows are ordered with the first coordinate varying slowest,
# i.e. row i * 101 + j holds the point (axis[i], axis[j]).
predict_grid = jnp.stack(
    jnp.meshgrid(jnp.linspace(*bounds, 101), jnp.linspace(*bounds, 101), indexing="ij"),
    axis=-1,
).reshape(-1, 2)

# model.eval yields the mean prediction and the accompanying uncertainty
# (kept in `stds`) for every lattice point
means, stds = model.eval(predict_grid)

In [None]:
import matplotlib.pyplot as plt

# Bring the flat model output back onto the 101 x 101 evaluation lattice.
# predict_grid is ordered with the first coordinate varying slowest, so after
# the reshape entry [i, j] belongs to the point (axis[i], axis[j]).
means = means.reshape(101,101)
stds = stds.reshape(101,101)
true = fun(predict_grid).reshape(101,101)

fig, ax = plt.subplots(2, 2, figsize=(10,7))

# indexing="ij" gives mesh[0][i, j] = axis[i] and mesh[1][i, j] = axis[j], which
# matches the layout of the reshaped fields above. With the default "xy"
# indexing every contour panel would be drawn transposed and would no longer
# line up with the scattered training points.
mesh = jnp.meshgrid(jnp.linspace(*bounds, 101), jnp.linspace(*bounds, 101), indexing="ij")

# Top row: prediction and ground truth share contour levels and colour limits.
im1 = ax[0,0].contourf(*mesh, means, levels=12, vmin=-0.15, vmax=1.2)
im2 = ax[0,1].contourf(*mesh, true, levels=12, vmin=-0.15, vmax=1.2)
# Bottom row: model uncertainty and absolute prediction error.
im3 = ax[1,0].contourf(*mesh, stds)
im4 = ax[1,1].contourf(*mesh, jnp.abs(means-true))
# im4 = ax[1,1].contourf(*mesh, jnp.greater(true, means-stds)*jnp.less(true, means+stds))

# Overlay the training locations on both bottom panels. Draw order per panel:
# current gradient points (pink), initial gradient points (red), function
# points (orange). At this stage X_split[1] is x_der, so red covers pink.
for panel in (ax[1,1], ax[1,0]):
    panel.scatter(X_split[1][:,0], X_split[1][:,1], c="pink", marker="x", label="der pos")
    panel.scatter(x_der[:,0], x_der[:,1], c="r", marker="x", label="der pos")
    panel.scatter(x_func[:,0], x_func[:,1], c="orange", marker="+", label="fun pos")

# One colour bar and one title per panel, in row-major panel order.
for image, panel, title in zip(
    (im1, im2, im3, im4),
    (ax[0,0], ax[0,1], ax[1,0], ax[1,1]),
    ("prediction", "true function", "std", "abs dif"),
):
    plt.colorbar(image, ax=panel)
    panel.set_title(title)

In [None]:
# Point-wise absolute error of the prediction on the evaluation lattice.
abs_err = jnp.abs(means - true)

# Mean absolute error and mean squared error before Bayesian optimisation.
mae_before = jnp.mean(abs_err)
mse_before = jnp.mean(abs_err**2)

print(f"{mae_before=}, {mse_before=}")

## BayesOpt

In [None]:
# NOTE(review): the wildcard import hides where UpperConfidenceBound and
# ExactBayesOpt (and the commented-out MaximumVariance) come from; presumably
# all are exported by jaxgp.bayesopt -- confirm and consider importing by name.
from jaxgp.bayesopt import *

# NOTE(review): `rand` and `bayesopt_bounds` are not used in this cell --
# confirm whether other code still needs them.
rand = 3
bayesopt_bounds = jnp.array([[-5.0,-5.0],[5.0,5.0]])
# Alternative: evaluate new gradient observations with label noise
# eval_func = lambda x: grad(x, noise, key).reshape(-1,1)
# Noise-free gradient at the queried points, flattened into a single column
eval_func = lambda x: grad(x).reshape(-1,1)
# Second argument of the acquisition function below
# NOTE(review): presumably weights exploration against exploitation -- confirm.
explore_param = 2
# Candidate points for the acquisition function: a 101 x 101 lattice over
# `bounds`, flattened to shape (101**2, 2)
grid = jnp.array(jnp.meshgrid(jnp.linspace(*bounds, 101), jnp.linspace(*bounds, 101))).T.reshape(101**2, 2)

acqui_fun = UpperConfidenceBound(grid, explore_param)
# acqui_fun = MaximumVariance(grid)

# Optimizer seeded with the current training data, the kernel, the acquisition
# function and the callable used to evaluate new observations
bayesopt = ExactBayesOpt(X_split, Y_train, kernel, acquisition_func=acqui_fun, eval_func=eval_func)

In [None]:
# Run the optimizer.
# NOTE(review): 50 is presumably the number of acquisition steps -- confirm against ExactBayesOpt.
bayesopt(50)

In [None]:
# Replace the training data with the data set held by the optimizer after its run.
X_split, Y_train = bayesopt.X_split, bayesopt.Y_train

## After BayesOpt

In [None]:
from jaxgp.regression import ExactGPR
from jaxgp.utils import Logger

# Fresh logger and model, so the refit does not reuse the earlier model object.
logger = Logger()
# Same kernel, initial kernel parameters and noise scale as before the optimisation.
model = ExactGPR(kernel, init_kernel_params, noise, logger=logger)

In [None]:
# Stack the label blocks and flatten them to the 1-D label vector passed to training.
# NOTE(review): assumes bayesopt.Y_train keeps the tuple-of-column-vectors layout -- confirm.
model.train(X_split, jnp.vstack(Y_train).reshape(-1))

In [None]:
# Regular 101 x 101 lattice over the sampling interval, flattened to shape
# (101**2, 2). Rows are ordered with the first coordinate varying slowest,
# i.e. row i * 101 + j holds the point (axis[i], axis[j]).
predict_grid = jnp.stack(
    jnp.meshgrid(jnp.linspace(*bounds, 101), jnp.linspace(*bounds, 101), indexing="ij"),
    axis=-1,
).reshape(-1, 2)

# model.eval yields the mean prediction and the accompanying uncertainty
# (kept in `stds`) for every lattice point
means, stds = model.eval(predict_grid)

In [None]:
import matplotlib.pyplot as plt

# Bring the flat model output back onto the 101 x 101 evaluation lattice.
# predict_grid is ordered with the first coordinate varying slowest, so after
# the reshape entry [i, j] belongs to the point (axis[i], axis[j]).
means = means.reshape(101,101)
stds = stds.reshape(101,101)
true = fun(predict_grid).reshape(101,101)

fig, ax = plt.subplots(2, 2, figsize=(10,7))

# indexing="ij" gives mesh[0][i, j] = axis[i] and mesh[1][i, j] = axis[j], which
# matches the layout of the reshaped fields above. With the default "xy"
# indexing every contour panel would be drawn transposed and would no longer
# line up with the scattered training points.
mesh = jnp.meshgrid(jnp.linspace(*bounds, 101), jnp.linspace(*bounds, 101), indexing="ij")

# Top row: prediction and ground truth share contour levels and colour limits.
im1 = ax[0,0].contourf(*mesh, means, levels=12, vmin=-0.15, vmax=1.2)
im2 = ax[0,1].contourf(*mesh, true, levels=12, vmin=-0.15, vmax=1.2)
# Bottom row: model uncertainty and absolute prediction error.
im3 = ax[1,0].contourf(*mesh, stds)
im4 = ax[1,1].contourf(*mesh, jnp.abs(means-true))
# im4 = ax[1,1].contourf(*mesh, jnp.greater(true, means-stds)*jnp.less(true, means+stds))

# Overlay the training locations on both bottom panels. Draw order per panel:
# current gradient points from X_split[1] (pink), initial gradient points (red),
# function points (orange). Pink markers left visible are gradient points that
# are in X_split[1] but not in the initial x_der.
for panel in (ax[1,1], ax[1,0]):
    panel.scatter(X_split[1][:,0], X_split[1][:,1], c="pink", marker="x", label="der pos")
    panel.scatter(x_der[:,0], x_der[:,1], c="r", marker="x", label="der pos")
    panel.scatter(x_func[:,0], x_func[:,1], c="orange", marker="+", label="fun pos")

# One colour bar and one title per panel, in row-major panel order.
for image, panel, title in zip(
    (im1, im2, im3, im4),
    (ax[0,0], ax[0,1], ax[1,0], ax[1,1]),
    ("prediction", "true function", "std", "abs dif"),
):
    plt.colorbar(image, ax=panel)
    panel.set_title(title)

In [None]:
# Point-wise absolute error of the prediction on the evaluation lattice.
abs_err = jnp.abs(means - true)

# Mean absolute error and mean squared error after Bayesian optimisation.
mae_after = jnp.mean(abs_err)
mse_after = jnp.mean(abs_err**2)

print(f"{mae_after=}, {mse_after=}")