In [2]:
import jax
import jax.numpy as jnp

### Force CPU usage

Making the algorithm work on GPU is left for future work.

In [3]:
%env CUDA_VISIBLE_DEVICES=""

env: CUDA_VISIBLE_DEVICES=""


JAX uses `float32` by default, which is not precise enough for our purposes, so we switch to `float64`.

In [4]:
jax.config.update("jax_enable_x64", True)
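The following check is our own illustration of why `float32` falls short here, not part of the original pipeline. Machine epsilon is about 1.2e-7 in `float32` versus 2.2e-16 in `float64`, and Gaussian kernel matrices built on nearby points are close to singular. When solving a linear system, the relative error can grow roughly like the condition number times machine epsilon, so a badly conditioned kernel matrix quickly exhausts single precision. The 5 points and the length scale of 1 are chosen to resemble the example below.

```python
import jax
import jax.numpy as jnp

# Must run at startup, before any array is created.
jax.config.update("jax_enable_x64", True)

# With x64 enabled, new floating-point arrays default to float64.
x = jnp.linspace(0.0, 1.0, 5)
print(x.dtype)  # float64

# Machine epsilon: gap between 1.0 and the next representable number.
print(jnp.finfo(jnp.float32).eps, jnp.finfo(jnp.float64).eps)

# A relative increment of 1e-8 is rounded away in float32 but kept in float64.
lost_in_f32 = jnp.float32(1.0) + jnp.float32(1e-8) == jnp.float32(1.0)
kept_in_f64 = jnp.float64(1.0) + jnp.float64(1e-8) != jnp.float64(1.0)
print(bool(lost_in_f32), bool(kept_in_f64))  # True True

# Gaussian kernel matrix (length scale 1) on 5 points of [0, 1]: all entries
# lie between exp(-1) and 1, so the matrix is close to singular.
K = jnp.exp(-((x[:, None] - x[None, :]) ** 2))
print("condition number:", jnp.linalg.cond(K))
```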

In [5]:
from variables import VariableBlock, Variable1D, VariablePartition
from function_basis import HatFunctions
from kernels import MultivariateKernel, GaussianKernel
from constraints import GPConstraints, NoConstraints, GaussianProcess
from additive_gp import MaxModeAdditiveGaussianProcess

### Example in 2D

We follow the illustration on page 2 of [1].

[1] López-Lopera, A., Bachoc, F. and Roustant, O., 2022.
    *High-dimensional additive Gaussian processes under monotonicity constraints.*
    Advances in Neural Information Processing Systems, 35, pp.8041-8053.

### Step by step

1) We start with a discretization of the domain.

In [6]:
from variables import isotropic_block  # helper function to create a block of variables

block = isotropic_block('xy', (0, 1), 5)  # hypercube in [0, 1]^2 with 5 subdivisions in each direction.



2) We build a trivial partition that consists of a single block.

In [7]:
partition = VariablePartition([block])
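Conceptually, an additive GP over a partition $\mathcal{P}$ of the variables models $f(x)=\sum_{B\in\mathcal{P}} f_B(x_B)$, where $x_B$ keeps only the variables of block $B$. The sketch below illustrates this decomposition with plain JAX. `additive_predict` and the representation of a partition as tuples of column indices are hypothetical and unrelated to the `VariablePartition` API. It also shows that the target used in this notebook, $4(x-1/2)^2+2y$, is exactly additive over the finer partition $\{\{x\},\{y\}\}$.

```python
import jax.numpy as jnp

def additive_predict(x, partition, block_functions):
    """Evaluate f(x) = sum over blocks B of f_B(x[:, B]).

    partition: list of tuples of column indices (hypothetical representation).
    block_functions: one callable per block, receiving only that block's columns.
    Returns the total prediction and the list of per-block contributions.
    """
    per_block = [f(x[:, list(block)]) for block, f in zip(partition, block_functions)]
    return sum(per_block), per_block

x = jnp.array([[0.5, 0.0], [0.5, 0.5], [0.5, 1.0], [0.0, 0.5], [1.0, 0.5]])

# Trivial partition {{x, y}}: a single block sees both variables.
single = [(0, 1)]
single_fns = [lambda z: 4 * (z[:, 0] - 0.5) ** 2 + 2 * z[:, 1]]

# Finer partition {{x}, {y}}: each function sees one column, so z[:, 0]
# is x in the first block and y in the second.
split = [(0,), (1,)]
split_fns = [lambda z: 4 * (z[:, 0] - 0.5) ** 2, lambda z: 2 * z[:, 0]]

y_single, _ = additive_predict(x, single, single_fns)
y_split, parts = additive_predict(x, split, split_fns)
print(y_single)  # [0. 1. 2. 2. 2.]
print(jnp.allclose(y_single, y_split))  # True
```

With a single block the sum has one term, which is the situation of this notebook.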

3) For the function basis, we use hat functions, which are simple and whose linear combinations can represent more complex functions.

In [8]:
basis = HatFunctions(max_value=1.0)
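A hat function centred at knot $t_j$ with spacing $h$ is $\varphi_j(t)=\max(0, 1-|t-t_j|/h)$. The hats sum to 1 at every point of the domain, and $\sum_j \xi_j\varphi_j$ is the piecewise-linear interpolant of the values $\xi_j$ at the knots. The sketch below assumes equally spaced knots and unit peaks. How `HatFunctions` places its knots, and what its `max_value` argument controls, is not shown in this notebook, and `hat_basis` is a hypothetical helper.

```python
import jax.numpy as jnp

def hat_basis(t, knots):
    """Matrix of hat functions evaluated at t, shape (len(t), len(knots)).

    Assumes equally spaced knots; each hat equals 1 on its own knot and
    vanishes from the neighbouring knots onwards.
    """
    h = knots[1] - knots[0]
    return jnp.maximum(0.0, 1.0 - jnp.abs(t[:, None] - knots[None, :]) / h)

knots = jnp.linspace(0.0, 1.0, 5)  # illustrative choice: 5 knots on [0, 1]
t = jnp.linspace(0.0, 1.0, 101)
Phi = hat_basis(t, knots)

# Partition of unity: at every t the hats sum to 1.
print(jnp.allclose(Phi.sum(axis=1), 1.0))  # True

# With coefficients g(knots), Phi @ coeffs is the piecewise-linear
# interpolant of g, which is exact at the knots.
g = lambda s: jnp.sin(3.0 * s)
coeffs = g(knots)
interp = Phi @ coeffs
print(jnp.allclose(hat_basis(knots, knots) @ coeffs, g(knots)))  # True
print("max interpolation error on [0, 1]:", jnp.max(jnp.abs(interp - g(t))))
```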

4) We use the Gaussian kernel, also known as the RBF (radial basis function) kernel, defined as:
$$k(x,y)=\exp{\left(-\frac{\|x-y\|^2_2}{l^2}\right)}$$
This multivariate kernel factorizes as a product of univariate kernels:
$$k(x,y)=\prod_{i=1}^d k(x_i,y_i)$$
where $k(x_i,y_i)=\exp{\left(-\frac{(x_i-y_i)^2}{l^2}\right)}$ is the univariate kernel for the $i$-th dimension and $l$ is the length scale.

In [9]:
kernel = GaussianKernel(length_scale=1.)
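The factorization follows from $\exp(a+b)=\exp(a)\exp(b)$ applied to $\|x-y\|_2^2=\sum_i (x_i-y_i)^2$. Here is a quick numerical check with hand-written kernel functions (hypothetical helpers, not the `GaussianKernel` class). A useful consequence is also checked: on a tensor-product grid whose points are ordered consistently, the 2D Gram matrix is the Kronecker product of the 1D Gram matrices.

```python
import jax
import jax.numpy as jnp

def gaussian_kernel(x, y, length_scale=1.0):
    """Multivariate Gaussian (RBF) kernel exp(-||x - y||^2 / l^2)."""
    return jnp.exp(-jnp.sum((x - y) ** 2) / length_scale**2)

def product_of_univariate(x, y, length_scale=1.0):
    """Product over dimensions of exp(-(x_i - y_i)^2 / l^2)."""
    return jnp.prod(jnp.exp(-((x - y) ** 2) / length_scale**2))

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.uniform(key_x, (2,))  # a random point of [0, 1]^2
y = jax.random.uniform(key_y, (2,))
print(jnp.allclose(gaussian_kernel(x, y), product_of_univariate(x, y)))  # True

# Tensor-product grid: row i * 5 + j holds the point (knots[i], knots[j]).
knots = jnp.linspace(0.0, 1.0, 5)
gx, gy = jnp.meshgrid(knots, knots, indexing="ij")
grid = jnp.stack([gx.ravel(), gy.ravel()], axis=1)

K1 = jnp.exp(-((knots[:, None] - knots[None, :]) ** 2))  # 1D Gram matrix
K2 = jax.vmap(lambda a: jax.vmap(lambda b: gaussian_kernel(a, b))(grid))(grid)
print(jnp.allclose(K2, jnp.kron(K1, K1)))  # True
```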


5) We consider a simple Gaussian process without any constraints on the learned functions.

In [10]:
constraints = NoConstraints()
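For contrast, consider what a constraint looks like with a hat basis. In 1D, $f(t)=\sum_j \xi_j\varphi_j(t)$ is linear between consecutive knots. It is therefore nondecreasing on the whole interval if and only if $\xi_1\le\xi_2\le\dots\le\xi_m$. A monotonicity constraint thus reduces to linear inequalities on the coefficients, whereas `NoConstraints` places no restriction on them. The helpers below are hypothetical illustrations and not part of the `constraints` module.

```python
import jax.numpy as jnp

def hat_basis(t, knots):
    h = knots[1] - knots[0]  # equally spaced knots assumed
    return jnp.maximum(0.0, 1.0 - jnp.abs(t[:, None] - knots[None, :]) / h)

def satisfies_monotonicity(coeffs):
    """Hypothetical check: a hat-basis function is nondecreasing iff its coefficients are."""
    return bool(jnp.all(jnp.diff(coeffs) >= 0))

knots = jnp.linspace(0.0, 1.0, 5)
t = jnp.linspace(0.0, 1.0, 201)

for coeffs in (jnp.array([0.0, 0.1, 0.1, 0.5, 1.0]),   # nondecreasing
               jnp.array([0.0, 0.4, 0.2, 0.5, 1.0])):  # has a dip
    f = hat_basis(t, knots) @ coeffs
    # Small tolerance to absorb floating-point rounding on flat segments.
    monotone_on_grid = bool(jnp.all(jnp.diff(f) >= -1e-6))
    print(satisfies_monotonicity(coeffs), monotone_on_grid)
# Expected output:
# True True
# False False
```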

6) We are now ready to go! The model is fully defined.

In [11]:
monoblock_gp = MaxModeAdditiveGaussianProcess(partition, basis, kernel, constraints)

We illustrate how to train it on a simple example below.

In [12]:
x_train = jnp.array([[0.5, 0.], [0.5, 0.5], [0.5, 1.], [0., 0.5], [1., 0.5]])

def y_ground_truth(x):
    return 4 * (x[:, 0] - 0.5)**2 + 2 * x[:, 1]

y_train = y_ground_truth(x_train)
y_train

Array([0., 1., 2., 2., 2.], dtype=float64)

We fit the model:

In [13]:
additive_fun = monoblock_gp.fit(x_train, y_train)
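To give an idea of what fitting involves, here is a minimal sketch of an unconstrained finite-dimensional GP on one 2D block. The function is written $f(x)=\sum_j \xi_j\varphi_j(x)$ over tensor-product hat functions, and the coefficients get the prior $\xi\sim\mathcal{N}(0,\Gamma)$, where $\Gamma$ is the Gaussian kernel evaluated at the knots. Conditioning on $y=\Phi\xi+\varepsilon$, with $\varepsilon\sim\mathcal{N}(0,\sigma^2 I)$, gives a Gaussian posterior. Its mean, which is also its mode since there are no constraints, is $\hat\xi=\Gamma\Phi^\top(\Phi\Gamma\Phi^\top+\sigma^2 I)^{-1}y$. The knot grid, the noise level $\sigma^2$ and all names are assumptions. This is not the code behind `MaxModeAdditiveGaussianProcess.fit`, and it does not reproduce the library's predictions.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def hat_basis(t, knots):
    h = knots[1] - knots[0]  # equally spaced knots assumed
    return jnp.maximum(0.0, 1.0 - jnp.abs(t[:, None] - knots[None, :]) / h)

knots = jnp.linspace(0.0, 1.0, 5)  # illustrative: 5 knots per variable
gx, gy = jnp.meshgrid(knots, knots, indexing="ij")
grid = jnp.stack([gx.ravel(), gy.ravel()], axis=1)  # 25 knots of the 2D block

def design_matrix(x):
    """Tensor-product hats: column i * 5 + j holds phi_i(x_1) * phi_j(x_2)."""
    px, py = hat_basis(x[:, 0], knots), hat_basis(x[:, 1], knots)
    return (px[:, :, None] * py[:, None, :]).reshape(x.shape[0], -1)

# Prior covariance of the coefficients: Gaussian kernel (length scale 1) on the knots.
sq_dists = jnp.sum((grid[:, None, :] - grid[None, :, :]) ** 2, axis=-1)
Gamma = jnp.exp(-sq_dists / 1.0**2)

# Same data as in the notebook, under new names to avoid overwriting it.
x_obs = jnp.array([[0.5, 0.], [0.5, 0.5], [0.5, 1.], [0., 0.5], [1., 0.5]])
y_obs = 4 * (x_obs[:, 0] - 0.5) ** 2 + 2 * x_obs[:, 1]

# Posterior mean of the coefficients (assumed small noise variance).
noise = 1e-8
Phi = design_matrix(x_obs)
A = Phi @ Gamma @ Phi.T + noise * jnp.eye(len(y_obs))
xi_hat = Gamma @ Phi.T @ jnp.linalg.solve(A, y_obs)

y_fit = design_matrix(x_obs) @ xi_hat  # nearly interpolates y_obs
x_new = jnp.array([[0.25, 0.25]])
print(y_fit, design_matrix(x_new) @ xi_hat)
```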

We use the fitted additive function to predict the labels of the training points.

In [14]:
structured_pred = additive_fun.predict(x_train)
structured_pred

StructuredPrediction(y_pred=Array([8.90676110e-04, 1.04011452e+00, 1.99582057e+00, 1.97832722e+00,
       1.97832722e+00], dtype=float64), y_pred_per_block=[Array([8.90676110e-04, 1.04011452e+00, 1.99582057e+00, 1.97832722e+00,
       1.97832722e+00], dtype=float64)])

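Only the final prediction `y_pred` is needed to compute the metrics. The `y_pred_per_block` field stores the prediction of each block separately (with a single block, it coincides with `y_pred`) and is mainly useful for debugging.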

In [15]:
y_pred = structured_pred.y_pred

Let's compute some metrics!

In [16]:
# absolute error, relative error and explained variance score
abs_err = jnp.abs(y_pred - y_train)
# relative error; falls back to the absolute error where the target is 0
rel_err = jnp.where(y_train == 0, abs_err, abs_err / jnp.abs(y_train))
evs = 1 - jnp.var(y_train - y_pred) / jnp.var(y_train)
print(f"Mean absolute error: {jnp.mean(abs_err):.2f}")
print(f"Mean relative error: {jnp.mean(rel_err)*100:.2f}%")
print(f"Explained variance score: {evs*100:.2f}%")

Mean absolute error: 0.02
Mean relative error: 1.30%
Explained variance score: 99.92%
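For reuse, the metrics can be gathered in a helper that also reports the mean squared error. `regression_metrics` is a hypothetical name, and the input values are copied from the printed prediction above. The `safe_denom` variant differs from the cell above on purpose: `jnp.where` evaluates both branches, so dividing by a zero target inside it yields `inf` in the unused branch. This is harmless for the value but turns gradients into NaN if the metric is ever differentiated.

```python
import jax.numpy as jnp

def regression_metrics(y_true, y_pred):
    """MAE, MSE, mean relative error and explained variance score.

    Where the target is 0, the relative error falls back to the absolute
    error (same convention as the cell above).
    """
    abs_err = jnp.abs(y_pred - y_true)
    safe_denom = jnp.where(y_true == 0, 1.0, jnp.abs(y_true))
    rel_err = abs_err / safe_denom
    return {
        "mae": float(jnp.mean(abs_err)),
        "mse": float(jnp.mean((y_pred - y_true) ** 2)),
        "mean_rel_err": float(jnp.mean(rel_err)),
        "evs": float(1 - jnp.var(y_true - y_pred) / jnp.var(y_true)),
    }

# Values copied from the printed output above (new names avoid overwriting y_pred).
y_true_ref = jnp.array([0.0, 1.0, 2.0, 2.0, 2.0])
y_pred_ref = jnp.array([8.90676110e-04, 1.04011452e+00, 1.99582057e+00,
                        1.97832722e+00, 1.97832722e+00])
for name, value in regression_metrics(y_true_ref, y_pred_ref).items():
    print(f"{name}: {value:.6f}")
```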


Element-wise errors:

In [17]:
import pandas as pd
df = pd.DataFrame({'absolute_error': abs_err, 'relative_error': rel_err}).T
df

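|                | 0        | 1        | 2        | 3        | 4        |
|----------------|----------|----------|----------|----------|----------|
| absolute_error | 0.000891 | 0.040115 | 0.004179 | 0.021673 | 0.021673 |
| relative_error | 0.000891 | 0.040115 | 0.00209  | 0.010836 | 0.010836 |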
