# Recovering a 2D function from its gradient with a Gaussian Process Regression model

This notebook shows how to use this Gaussian process regression framework to formally integrate functions from derivative observations.

This example shows how to predict a 2D function from gradient observations with both a full GPR and a sparse GPR. The sparse GPR works by projecting the training data onto a lower-dimensional feature space in order to reduce the computational cost, which is dominated by matrix inversions.

## Creating the training data set

`jax.numpy` is used almost exactly like the standard `numpy` package, with the caveat that `jax.ndarray` is an immutable type, meaning that no in-place changes can be made. For creating training data, however, this should not be an issue.
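Where an in-place change would be used in NumPy, JAX offers the functional `.at[...]` update syntax, which returns a modified copy and leaves the original array untouched. A minimal sketch of both behaviors:

```python
import jax.numpy as jnp

x = jnp.zeros(3)

# NumPy-style item assignment is rejected because JAX arrays are immutable
inplace_failed = False
try:
    x[0] = 1.0
except TypeError:
    inplace_failed = True

# .at[...].set(...) returns an updated copy and leaves x unchanged
y = x.at[0].set(1.0)

print(inplace_failed)
print(x)
print(y)
```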

In [1]:
import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import jax.scipy as jsp
from jax.lax import Precision

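We will model a fourth-order polynomial, namely Himmelblau's function scaled by $1/800$. `fun` returns (optionally noisy) function values and `grad` returns its analytic gradient: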

In [2]:
def fun(x, noise=0.0, key=random.PRNGKey(0)):
    # Himmelblau's function scaled by 1/800 for points x of shape (n, 2), plus optional Gaussian noise
    return (x[:,0]**2 + x[:,1] - 11)**2 / 800.0 + (x[:,0] + x[:,1]**2 - 7)**2 / 800.0 + random.normal(key, (len(x),), dtype=jnp.float32)*noise

def grad(x, noise=0.0, key=random.PRNGKey(0)):
    # analytic gradient of fun, returned with shape (n, 2), plus optional Gaussian noise
    dx1 = 4 * (x[:,0]**2 + x[:,1] - 11) * x[:,0] + 2 * (x[:,0] + x[:,1]**2 - 7)
    dx2 = 2 * (x[:,0]**2 + x[:,1] - 11) + 4 * (x[:,0] + x[:,1]**2 - 7) * x[:,1]
    return jnp.vstack((dx1, dx2)).T / 800.0 + random.normal(key, x.shape, dtype=jnp.float32)*noise
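Since `grad` is written out by hand, it can be worth checking it against JAX's automatic differentiation. The sketch below re-declares the two functions from the cell above so that it runs on its own and compares the analytic gradient with `jax.vmap(jax.grad(...))` at a few random noise-free points; `f_single` and `x_test` are illustrative names, not part of the notebook.

```python
import jax
import jax.numpy as jnp
from jax import random

def fun(x, noise=0.0, key=random.PRNGKey(0)):
    return (x[:,0]**2 + x[:,1] - 11)**2 / 800.0 + (x[:,0] + x[:,1]**2 - 7)**2 / 800.0 + random.normal(key, (len(x),), dtype=jnp.float32)*noise

def grad(x, noise=0.0, key=random.PRNGKey(0)):
    dx1 = 4 * (x[:,0]**2 + x[:,1] - 11) * x[:,0] + 2 * (x[:,0] + x[:,1]**2 - 7)
    dx2 = 2 * (x[:,0]**2 + x[:,1] - 11) + 4 * (x[:,0] + x[:,1]**2 - 7) * x[:,1]
    return jnp.vstack((dx1, dx2)).T / 800.0 + random.normal(key, x.shape, dtype=jnp.float32)*noise

# fun works on a batch of shape (n, 2); wrap it so a single point (2,) maps to a scalar
def f_single(p):
    return fun(p[None, :])[0]

x_test = random.uniform(random.PRNGKey(1), (5, 2), minval=-5.0, maxval=5.0)

g_auto = jax.vmap(jax.grad(f_single))(x_test)  # autodiff gradient, shape (5, 2)
g_hand = grad(x_test)                          # hand-written gradient, shape (5, 2)

print(jnp.max(jnp.abs(g_auto - g_hand)))
```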



To define the training data, we first fix the interval from which the data points are drawn. Both coordinates of every point are then sampled uniformly from this interval, i.e. the points lie in the square $[-5, 5]^2$. `random.split` turns the current key into two new keys, so that every draw gets a fresh pseudo-random sample while the whole notebook remains reproducible from a single seed.

In [3]:
# Interval bounds from which to choose the data points
bounds = jnp.array([-5.0, 5.0])

# How many function and derivative observations should be chosen
num_f_vals = 1
num_d_vals = 100

# initial seed for the pseudo random key generation
seed = 0

# create new keys and randomly sample the above interval for training features
key, subkey = random.split(random.PRNGKey(seed))
x_func = random.uniform(subkey, (num_f_vals, 2), minval=bounds[0], maxval=bounds[1])
key, subkey = random.split(key)
x_der = random.uniform(subkey, (num_d_vals,2), minval=bounds[0], maxval=bounds[1])

# noise with which to sample the training labels
noise = 0.02
key, subkey = random.split(key)
y_func = fun(x_func,noise, subkey)
key, subkey = random.split(key)
y_der = grad(x_der, noise, subkey)
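The key handling in the cell above can be illustrated in isolation. JAX random functions are deterministic given a key, so reusing a key reproduces the same sample, while a key obtained from `random.split` yields different numbers. This is why the cell splits off a new subkey before every draw. A small sketch:

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)

# The same key always produces the same sample
a = random.uniform(key, (3,))
b = random.uniform(key, (3,))

# Splitting yields two new keys; sampling with the subkey gives different numbers
key, subkey = random.split(key)
c = random.uniform(subkey, (3,))

print(jnp.array_equal(a, b))  # True: identical keys, identical samples
print(jnp.array_equal(a, c))  # False: a different key was used
```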

For training, the GPR framework expects a sequence of arrays `X_split` (a list in the cell below) that holds the points where the function is sampled and the points where the gradient is sampled. Each array in `X_split` has shape `(n_samples_i, N)`, where `N` is the input dimension (here 2). The order matters: the first array contains the points of the function observations and the second array the points of the gradient observations. `Y_train` is a flat array of shape `(n_samples_function + N * n_samples_gradient,)`: first all function values, then the gradients, flattened point by point.

In [4]:
X_split = [x_func,x_der]

Y_train = jnp.hstack((y_func, y_der.reshape(-1)))
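To make the layout of `Y_train` concrete, the toy example below uses made-up labels for two function observations and two gradient observations in 2D. `reshape(-1)` flattens the gradient array row by row, so the two partial derivatives of each point end up next to each other, after all function values.

```python
import jax.numpy as jnp

# Hypothetical labels: 2 function values and the gradients at 2 points in 2D
y_func = jnp.array([10.0, 20.0])
y_der = jnp.array([[1.0, 2.0],    # (df/dx1, df/dx2) at the first gradient point
                   [3.0, 4.0]])   # (df/dx1, df/dx2) at the second gradient point

# Function values first, then the gradients flattened point by point
Y_train = jnp.hstack((y_func, y_der.reshape(-1)))
print(Y_train)
```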

### Defining the Kernel and its initial parameters

The kernels can be found in `jaxgp.kernels`. Currently implemented are the `RBF`, `Linear`, and `Periodic` kernels. When in doubt about which kernel to use, go with an `RBF` kernel.

In [5]:
from jaxgp.kernels import RBF

kernel = RBF()
# by default, an RBF kernel has 2 parameters
init_kernel_params = jnp.array([2.0, 2.0])
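Gradient observations can enter a GP model because the derivative of a Gaussian process is again a Gaussian process whose covariances are derivatives of the kernel: $\mathrm{Cov}(f(x), \partial f(x')/\partial x'_j) = \partial k(x, x')/\partial x'_j$ and $\mathrm{Cov}(\partial f(x)/\partial x_i, \partial f(x')/\partial x'_j) = \partial^2 k(x, x')/\partial x_i \partial x'_j$. The sketch below builds these blocks with JAX autodiff for a standalone RBF kernel. It does not use `jaxgp`; the parameterization `params = [amplitude, lengthscale]` and the names `rbf`, `cov_f_df` and `cov_df_df` are assumptions for illustration and need not match how `jaxgp.kernels.RBF` interprets its two parameters.

```python
import jax
import jax.numpy as jnp

def rbf(x1, x2, params):
    # Hypothetical parameterization: params = [amplitude, lengthscale]
    amplitude, lengthscale = params
    sq_dist = jnp.sum((x1 - x2) ** 2)
    return amplitude**2 * jnp.exp(-0.5 * sq_dist / lengthscale**2)

# Cov(f(x1), grad f(x2)): gradient of k with respect to x2, shape (d,)
cov_f_df = jax.grad(rbf, argnums=1)

# Cov(grad f(x1), grad f(x2)): entry [i, j] is d^2 k / (dx1_i dx2_j), shape (d, d)
cov_df_df = jax.jacfwd(jax.grad(rbf, argnums=0), argnums=1)

params = jnp.array([1.0, 2.0])
x1 = jnp.array([0.5, -1.0])
x2 = jnp.array([1.5, 0.0])

print(cov_f_df(x1, x2, params))
print(cov_df_df(x1, x2, params))

# For coinciding points the block reduces to amplitude**2 / lengthscale**2 * I, here 0.25 * I
print(cov_df_df(x1, x1, params))
```

Stacking these blocks for all pairs of observation points gives the joint covariance matrix for the combined label vector of function values and gradients.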

In [6]:
from jaxgp.regression import ExactGPR
from jaxgp.utils import Logger

logger = Logger()

## The sparse GPR model

### Training the sparse GPR model

The `SparseGPR` model can be found in `jaxgp.regression`. The idea of a sparse model is to project the training data onto a lower-dimensional space in order to save computational cost. This is done by projecting the full training set onto a set of reference points via the kernel.
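One standard way to realize such a projection is the subset-of-regressors (projected process) approximation, whose predictive mean is $\mu_* = K_{*M}\,(\sigma^2 K_{MM} + K_{MN} K_{NM})^{-1} K_{MN}\, y$ for $M$ reference points and $N$ training labels. The sketch below implements it for plain function observations with a hypothetical RBF helper (`rbf_matrix`, `sor_mean`, `exact_mean` are illustrative names). It only illustrates the idea and is not claimed to be how `jaxgp.regression.SparseGPR` works internally, which additionally handles gradient observations. As a sanity check, when the reference points coincide with the training points the sparse mean reduces to the exact GP mean.

```python
import jax.numpy as jnp

def rbf_matrix(a, b, lengthscale=1.0):
    # Hypothetical helper: RBF kernel matrix between point sets a (n, d) and b (m, d)
    sq_dist = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sq_dist / lengthscale**2)

def sor_mean(x_train, y_train, x_ref, x_test, noise_var):
    # Subset-of-regressors predictive mean; the costliest step is K_MN @ K_NM, O(M^2 N)
    K_mn = rbf_matrix(x_ref, x_train)
    K_mm = rbf_matrix(x_ref, x_ref)
    K_sm = rbf_matrix(x_test, x_ref)
    A = noise_var * K_mm + K_mn @ K_mn.T          # (M, M) system instead of (N, N)
    return K_sm @ jnp.linalg.solve(A, K_mn @ y_train)

def exact_mean(x_train, y_train, x_test, noise_var):
    # Full GP predictive mean for comparison, O(N^3)
    K = rbf_matrix(x_train, x_train) + noise_var * jnp.eye(len(x_train))
    return rbf_matrix(x_test, x_train) @ jnp.linalg.solve(K, y_train)

x_train = jnp.linspace(0.0, 4.0, 5)[:, None]
y_train = jnp.sin(x_train[:, 0])
x_test = jnp.array([[0.5], [2.5]])

# Using every training point as a reference point recovers the exact mean
print(sor_mean(x_train, y_train, x_train, x_test, 0.1))
print(exact_mean(x_train, y_train, x_test, 0.1))

# A genuinely sparse model with only 3 reference points
x_ref = jnp.array([[0.0], [2.0], [4.0]])
print(sor_mean(x_train, y_train, x_ref, x_test, 0.1))
```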

There are two common ways to choose the reference points:
 - choosing a subset of size $m<n$ from the existing datapoints
 - creating an even grid inside the bounds on which the model should be evaluated

Note that below, the evenly spaced reference grid has slightly fewer points than the random subset. This is done to obtain a regular square grid: its size is the largest perfect square not exceeding the number of reference points (here $7^2 = 49$ instead of 50).
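For $m$ reference points, the side length of such a grid is $\lfloor\sqrt{m}\rfloor$. Rounding $\sqrt{m}$ to the nearest integer gives the same result for $m = 50$ but can overshoot for other values. A quick check with Python's `math.isqrt`, which returns the integer floor of the square root; `even_grid_side` is a hypothetical helper name:

```python
import math

def even_grid_side(num_ref_points):
    # Largest n with n**2 <= num_ref_points
    return math.isqrt(num_ref_points)

for m in (50, 57, 64):
    n_floor = even_grid_side(m)
    n_round = round(math.sqrt(m))
    print(m, n_floor**2, n_round**2)
# prints: 50 49 49 / 57 49 64 / 64 64 64
```

For $m = 57$, rounding would produce a 64-point grid, more points than requested.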

Furthermore, we also create a larger set of gradient observations, since the predictive power of the sparse model is lower than that of the full model. Computationally this is not a problem, as shown further below.

In [13]:
from jaxgp.regression import SparseGPR

# How many function and derivative observations should be chosen
num_f_vals = 1
num_d_vals = 1000

# initial seed for the pseudo random key generation
seed = 0

# create new keys and randomly sample the above interval for training features
key, subkey = random.split(random.PRNGKey(seed))
x_func = random.uniform(subkey, (num_f_vals, 2), minval=bounds[0], maxval=bounds[1])
key, subkey = random.split(key)
x_der = random.uniform(subkey, (num_d_vals,2), minval=bounds[0], maxval=bounds[1])

# noise with which to sample the training labels
noise = 0.02
key, subkey = random.split(key)
y_func = fun(x_func,noise, subkey)
key, subkey = random.split(key)
y_der = grad(x_der, noise, subkey)

logger = Logger()

num_ref_points = (num_d_vals + num_f_vals) // 20
key, subkey = random.split(key)
X_ref_rand = random.permutation(subkey, jnp.vstack((x_der,x_func)))[:num_ref_points]
# evenly spaced grid with n_side**2 points, where n_side**2 is the largest perfect square <= num_ref_points
n_side = int(num_ref_points ** 0.5)
grid_1d = jnp.linspace(*bounds, n_side)
X_ref_even = jnp.array(jnp.meshgrid(grid_1d, grid_1d)).reshape(2, -1).T

model_rand = SparseGPR(kernel, init_kernel_params, noise, X_ref_rand, logger=None)
model_even = SparseGPR(kernel, init_kernel_params, noise, X_ref_even, logger=None)

In [14]:
X_split = [x_func,x_der]

Y_train = jnp.hstack((y_func, y_der.reshape(-1)))

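Looking at the time needed to train the sparse models, even with 10 times as many data points the computation is still twice as fast as for the full model. This is because the cost of the sparse model grows only linearly with the number of training labels $N$:
 - The full model needs $\mathcal{O}(N^3)$ flops to train and fit the model
 - The sparse model needs $\mathcal{O}(M^2N + M^3)$ flops to train and fit the model

where $M$ is the number of reference points. For $M \ll N$ the $M^2N$ term dominates, so ten times as many observations make the sparse model only about 10 times more expensive, whereas the full model would become about 1000 times more expensive.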
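Plugging this notebook's sizes into the two expressions gives a rough sense of the numbers. The sketch assumes that every gradient observation contributes `d = 2` labels, that the full model was trained on the first data set (1 function value and 100 gradients), and that the sparse model uses the 50 random reference points. Constant factors are ignored, so only the order of magnitude is meaningful.

```python
# Rough operation counts, ignoring constant factors (assumed sizes from this notebook)
d = 2                                  # input dimension = labels per gradient observation

n_full = 1 + d * 100                   # full model: 1 function value + 100 gradients
n_sparse = 1 + d * 1000                # sparse model: 1 function value + 1000 gradients
m = (1000 + 1) // 20                   # number of random reference points

cost_full = n_full**3                  # O(N^3)
cost_sparse = m**2 * n_sparse + m**3   # O(M^2 N + M^3)

print(n_full, n_sparse, m)
print(cost_full, cost_sparse, cost_full / cost_sparse)
```

Under these assumptions the two counts are of the same order (roughly 8.1 million versus 5.1 million), even though the sparse model is trained on ten times as many observations.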

In [15]:
model_rand.train(X_split, Y_train)
model_even.train(X_split, Y_train)

OptStep(params=DeviceArray([2., 2.], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(nan, dtype=float32, weak_type=True), success=False, status=2, iter_num=0))
OptStep(params=DeviceArray([0.00379112, 0.33245406], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(-1057.3768, dtype=float32, weak_type=True), success=True, status=0, iter_num=24))


In [20]:
import jax
from jaxgp.likelyhood import sparse_kernelNegativeLogLikelyhood

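# range and resolution of the kernel-parameter grid on which the sparse negative log-likelihood is evaluated
param_bounds = jnp.array([1e-3,10.0])
num_params = 20
# all num_params**2 parameter pairs, shape (num_params**2, 2)
params = jnp.array(jnp.meshgrid(jnp.linspace(*param_bounds, num_params),jnp.linspace(*param_bounds, num_params))).reshape(2,-1).T

# vectorize the likelihood over the rows of params; data, reference points, noise and kernel stay fixed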
mapping = jax.vmap(lambda x: jax.jit(sparse_kernelNegativeLogLikelyhood)(x, X_split, Y_train, X_ref_even, noise, kernel), in_axes=0)
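The cell above evaluates a negative log-likelihood on a $20 \times 20$ grid of kernel parameters by mapping a jitted function over the rows of `params` with `jax.vmap`. The same pattern can be tried out in isolation with an exact-GP negative log marginal likelihood on toy 1D data. The helper `exact_nll`, its parameterization `[amplitude, lengthscale]` and the data are hypothetical and only stand in for `sparse_kernelNegativeLogLikelyhood`, whose internals are not shown here.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve

# Toy 1D data (hypothetical)
x = jnp.linspace(-3.0, 3.0, 20)
y = jnp.sin(x)
noise = 0.1

def exact_nll(p):
    # Hypothetical parameterization p = [amplitude, lengthscale]
    amplitude, lengthscale = p
    sq_dist = (x[:, None] - x[None, :]) ** 2
    K = amplitude**2 * jnp.exp(-0.5 * sq_dist / lengthscale**2) + noise**2 * jnp.eye(len(x))
    L = jnp.linalg.cholesky(K)
    alpha = cho_solve((L, True), y)
    # 0.5 y^T K^-1 y + 0.5 log|K| + n/2 log(2 pi), using 0.5 log|K| = sum(log(diag(L)))
    return 0.5 * y @ alpha + jnp.sum(jnp.log(jnp.diag(L))) + 0.5 * len(x) * jnp.log(2 * jnp.pi)

# 20 x 20 parameter grid, flattened to shape (400, 2) as in the cell above
param_bounds = jnp.array([1e-3, 10.0])
grid_1d = jnp.linspace(*param_bounds, 20)
params = jnp.array(jnp.meshgrid(grid_1d, grid_1d)).reshape(2, -1).T

# jit the vectorized function once, then evaluate all 400 parameter pairs in one call
nll = jax.jit(jax.vmap(exact_nll))(params)
best = params[jnp.argmin(nll)]
print(nll.shape, best)
```

Jitting the vectorized function once, as here, avoids wrapping `jax.jit` inside the mapped lambda. Because the grid was built with `meshgrid` and flattened row by row, `nll.reshape(20, 20)` lines up with the meshgrid arrays for a contour plot.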

In [None]:
nlle = mapping(params)