In [1]:
import jax
import jax.numpy as jnp

### Force CPU usage

Making the algorithm work on GPU is left for future work.

In [2]:
%env CUDA_VISIBLE_DEVICES=""

env: CUDA_VISIBLE_DEVICES=""


Jax runs in `float32` by default, which is not enough precision for our purposes. We force it to `float64` instead.

In [3]:
jax.config.update("jax_enable_x64", True)

In [4]:
from variables import VariableBlock, Variable1D, VariablePartition
from function_basis import HatFunctions
from kernels import MultivariateKernel, GaussianKernel
from constraints import GPConstraints, NoConstraints, GaussianProcess
from additive_gp import MaxModeAdditiveGaussianProcess

### Example in 2D

We rely on the illustration p2 of the paper [1].

[1] López-Lopera, A., Bachoc, F. and Roustant, O., 2022.
    *High-dimensional additive Gaussian processes under monotonicity constraints.*
    Advances in Neural Information Processing Systems, 35, pp.8041-8053.

In [5]:
from variables import isotropic_block  # helper function to create a block of variables

block = isotropic_block('xy', (0, 1), 5)  # hypercube in [0, 1]^2 with 5 subdivisions in each direction.
variable_partition = VariablePartition([block])
hat_functions = HatFunctions(max_value=1.0)
gaussian_kernel = GaussianKernel(length_scale=1.)
no_constraints = NoConstraints()
monoblock_gp = MaxModeAdditiveGaussianProcess(variable_partition, hat_functions, gaussian_kernel, no_constraints)

2023-07-21 17:53:30.723504: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:268] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [13]:
x_train = jnp.array([[0.5, 0.], [0.5, 0.5], [0.5, 1.], [0., 0.5], [1., 0.5]])

def y_ground_truth(x):
  return 4 * (x[:,0] - 0.5)**2 + 2 * x[:,1]

y_train = y_ground_truth(x_train)
y_train

Array([0., 1., 2., 2., 2.], dtype=float64)

In [7]:
additive_fun = monoblock_gp.fit(x_train, y_train)

In [18]:
structured_pred = additive_fun.predict(x_train)
structured_pred

StructuredPrediction(y_pred=Array([8.90676110e-04, 1.04011452e+00, 1.99582057e+00, 1.97832722e+00,
       1.97832722e+00], dtype=float64), y_pred_per_block=[Array([8.90676110e-04, 1.04011452e+00, 1.99582057e+00, 1.97832722e+00,
       1.97832722e+00], dtype=float64)])

In [19]:
y_pred = structured_pred.y_pred

In [23]:
# mean squared error
abs_err = jnp.abs(y_pred - y_train)
rel_err = jnp.where(y_train == 0, abs_err, abs_err / jnp.abs(y_train))
print(f"Mean absolute error: {jnp.mean(abs_err):.2f}")
print(f"Mean relative error: {jnp.mean(rel_err):.2f}%")

Mean absolute error: 0.02
Mean relative error: 0.01%


In [24]:
import pandas as pd
df = pd.DataFrame({'abs': abs_err, 'rel': rel_err})
df

Unnamed: 0,abs,rel
0,0.000891,0.000891
1,0.040115,0.040115
2,0.004179,0.00209
3,0.021673,0.010836
4,0.021673,0.010836
