## Creating the training data set

`jax.numpy` has almost the same usage as the standard `numpy` package, with the caveat that JAX arrays (`jax.numpy.ndarray`) are immutable, meaning that no in-place changes can be made. For creating training data, however, this is not an issue.

In [None]:
import jax.numpy as jnp
from jax import random

We will model a simple sine function; its derivative, the cosine, provides the gradient observations:

In [None]:
# noisy observations of the true function sin(x)
def sin(x, noise=0.0, key=random.PRNGKey(0)):
    """Evaluate ``sin(x)`` elementwise and add scaled Gaussian noise.

    One standard-normal float32 sample per entry of ``x`` is drawn with
    ``key`` and multiplied by ``noise`` before being added; with the default
    ``noise=0.0`` that term is multiplied by zero.
    """
    clean = jnp.sin(x)
    perturbation = random.normal(key, x.shape, dtype=jnp.float32) * noise
    return clean + perturbation

def cos(x, noise=0.0, key=random.PRNGKey(0)):
    """Evaluate ``cos(x)`` elementwise and add scaled Gaussian noise.

    The noise term is ``noise`` times standard-normal float32 samples drawn
    with ``key``, one per entry of ``x``.
    """
    jitter = noise * random.normal(key, x.shape, dtype=jnp.float32)
    return jnp.cos(x) + jitter

To define the training data we first need to define the boundaries of the interval from which the data points are chosen. Then, random points are chosen in this interval. `random.split` derives a new key and a new subkey from the previous key, so that every draw from the uniform distribution is a fresh pseudo-random sample.

In [None]:
# Interval bounds [lower, upper] from which the data points are drawn
bounds = jnp.array([0.0, 2*jnp.pi])

# Shapes of the sample arrays: how many function observations and how many
# derivative observations should be drawn
num_f_vals = (1,)
num_d_vals = (3,)

# initial seed for the pseudo-random key generation
seed = 0

# Every random draw below uses its own subkey; `key` is split again before
# each draw, so the order of these statements determines the samples.
# Sample the training inputs uniformly from the interval above.
key, subkey = random.split(random.PRNGKey(seed))
x_func = random.uniform(subkey, num_f_vals, minval=bounds[0], maxval=bounds[1])
key, subkey = random.split(key)
x_der = random.uniform(subkey, num_d_vals, minval=bounds[0], maxval=bounds[1])

# Scale of the Gaussian noise added to the training labels
noise = 0.1
# Function labels come from the noisy sin, derivative labels from the noisy cos.
key, subkey = random.split(key)
y_func = sin(x_func,noise, subkey)
key, subkey = random.split(key)
y_der = cos(x_der, noise, subkey)

The GPR framework needs as input for training a tuple of arrays `X_split`, which contains a set of points where the function is sampled and a set of points where the gradient is sampled. Each array in `X_split` has shape `(n_samples_i, N)`. `X_split` should be ordered as follows: the first array holds the data points of the function observations and the second array holds the data points of the gradient observations. `Y_train` should just be an array of shape `(n_samples_function + n_samples_gradient,)`.

In [None]:
# Reshape every sample vector into a column so the arrays have the form
# (n_samples_i, N), here with N = 1.
X_split = [points.reshape(-1, 1) for points in (x_func, x_der)]

# Labels in the same order as X_split: function values first, then derivative
# values (a flat alternative would be jnp.hstack((y_func, y_der))).
Y_train = tuple(labels.reshape(-1, 1) for labels in (y_func, y_der))

### Defining the Kernel and its initial parameters

The kernels can be found in `jaxgp.kernels`. Currently implemented are `RBF`, `Linear`, and `Periodic` kernels. When in doubt about which kernel to use, go with an `RBF` kernel.

In [None]:
from jaxgp.kernels import RBF

# Kernel object that is handed to the regression model.
kernel = RBF()
# By default an RBF kernel has 2 parameters; both start at 2.0 here.
# NOTE(review): what each entry stands for is defined by jaxgp's RBF - confirm there.
init_kernel_params = jnp.array([2.0, 2.0])

In [None]:
from jaxgp.regression import ExactGPR

# Build the regression model from the kernel, its initial parameters and the
# label noise level.
# NOTE(review): "L-BFGS-B" is presumably the name of the optimizer used during
# training - confirm against jaxgp.regression.ExactGPR.
model = ExactGPR(kernel, init_kernel_params, noise, "L-BFGS-B")
# Stack the label columns on top of each other and flatten them to one 1-D vector.
model.train(X_split, jnp.vstack(Y_train).reshape(-1))

# 200 evenly spaced points between the two interval bounds, passed to the
# model as a column of shape (200, 1).
predict_grid = jnp.linspace(*bounds, 200)
means, stds = model.eval(predict_grid.reshape(-1,1))

In [None]:
import matplotlib.pyplot as plt

# Flatten the model output so it lines up with the 1-D prediction grid.
means, stds = means.reshape(-1), stds.reshape(-1)

# Predicted mean with a band of one standard deviation on either side.
plt.plot(predict_grid, means, label="prediction")
plt.fill_between(predict_grid, means - stds, means + stds, alpha=0.5)

# Noise-free reference curve.
plt.plot(predict_grid, sin(predict_grid), c="gray", ls="--", label="true function")

# Training data: function observations as points, derivative positions as
# vertical lines. Only the first line is labelled so the legend gets one entry.
plt.scatter(x_func, y_func, c="r", label="function eval")
for i, x in enumerate(X_split[1]):
    plt.axvline(
        x, c="r", lw=0.8, ls="--",
        **({"label": "deriv positions"} if i == 0 else {}),
    )

plt.grid()
plt.legend()

In [None]:
from jaxgp.bayesopt import *

# NOTE(review): `rand` is not used by any of the visible cells.
rand = 3
# Redefine the bounds as an array of shape (2, 1): first row lower bound,
# second row upper bound.
bounds = jnp.array([[0.0],[2*jnp.pi]])
# Function handed to the optimizer for new observations: the noisy cosine.
# NOTE(review): `key` is looked up when the lambda is called and is never split
# here, so calls with inputs of the same shape add the same noise draw unless
# `key` is reassigned in between.
eval_func = lambda x: cos(x, noise, key)
# NOTE(review): presumably weights exploration in the acquisition function -
# confirm against jaxgp.bayesopt.UpperConfidenceBound.
explore_param = 5
# 200 evenly spaced candidate points between the lower and the upper bound.
grid = jnp.linspace(bounds[0], bounds[1], 200)

acqui_fun = UpperConfidenceBound(grid, explore_param)
# Alternative acquisition function, currently disabled:
# acqui_fun = MaximumVariance(bounds, 100)

# The optimizer starts from the existing training data (X_split, Y_train).
bayesopt = ExactBayesOpt(X_split, Y_train, kernel, acquisition_func=acqui_fun, eval_func=eval_func)

In [None]:
# Call the optimizer object with num_iters=5.
# NOTE(review): presumably this runs five acquisition/evaluation rounds and
# extends the data stored on the object - confirm in jaxgp.bayesopt.
bayesopt(num_iters=5)

In [None]:
# Read the training data back from the optimizer object, replacing the
# previous X_split / Y_train.
X_split = bayesopt.X_split
# Stack the stored label arrays on top of each other and flatten them to one
# 1-D vector.
Y_train = jnp.vstack(bayesopt.Y_train).reshape(-1)

In [None]:
from jaxgp.regression import ExactGPR

# Fit a fresh model on the data taken from the optimizer. Only the kernel is
# passed here.
# NOTE(review): the remaining constructor arguments are presumably left at
# their defaults - confirm against jaxgp.regression.ExactGPR.
model = ExactGPR(kernel)
model.train(X_split, Y_train)

# `bounds` has shape (2, 1) at this point, so unpacking it gives two length-1
# arrays as start and stop and the grid is not reshaped before evaluation.
predict_grid = jnp.linspace(*bounds, 200)
means, stds = model.eval(predict_grid)

In [None]:
import matplotlib.pyplot as plt

# Flatten the model output to 1-D vectors for plotting.
means = means.reshape(-1)
stds = stds.reshape(-1)

# Predicted mean with a band of one standard deviation on either side; the
# grid is flattened for the band.
plt.plot(predict_grid, means, label="prediction")
plt.fill_between(predict_grid.reshape(-1), means - stds, means + stds, alpha=0.5)

# Noise-free reference curve.
plt.plot(predict_grid, sin(predict_grid), c="gray", ls="--", label="true function")

# Function observations as points, derivative positions as vertical lines.
plt.scatter(x_func, y_func, c="r", label="function eval")
for i, x in enumerate(X_split[1]):
    if i > 0:
        # Later lines stay unlabelled so the legend shows a single entry.
        plt.axvline(x, c="r", lw=0.8, ls="--")
        continue
    plt.axvline(x, c="r", lw=0.8, ls="--", label="deriv positions")

plt.grid()
plt.legend()