In [1]:
import os

# Expose only GPU 0 to CUDA. This is done before `import jax` below so the
# restriction is already in place when JAX initializes its backends.
# NOTE(review): `str(0)` is simply the string "0".
os.environ["CUDA_VISIBLE_DEVICES"] = str(0)

import jax
import jax.numpy as jnp
from jax import random

# gpx building blocks: kernels, the Parameter container, a prior and a bijector.
# NOTE(review): jax, jnp, Sum and mpl are not used in the cells visible here.
from gpx.kernels import Linear, Matern52, Sum, Prod
from gpx.parameters import Parameter
from gpx.priors import NormalPrior
from gpx.bijectors import Softplus

import matplotlib as mpl
import matplotlib.pyplot as plt

# Dark matplotlib theme for any figures produced by this notebook.
plt.style.use("dark_background")

# Automatically reload edited modules (e.g. gpx) before each cell executes.
%load_ext autoreload
%autoreload 2

  from .nlopt_optimize import NLoptWrapper


## Operations among kernels 

Kernels can be combined through (element-wise) sums and products, and the result is again a kernel. As an example, consider the composite kernel

$$
k(\{\chi,\bar{\chi}\},\{\chi',\bar{\chi}'\}) = k_1 (\chi,\chi') + k_1 (\chi,\chi') \times k_2 (\bar{\chi},\bar{\chi}')
$$

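Here $\chi$ and $\bar{\chi}$ denote two groups of input features; in the first example below they are the first 4 and the last 6 columns of `X`. Because all features are passed to the kernel as a single array, each kernel must be told which columns it should use. The `active_dims` argument does exactly that.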

In [2]:
# 100 samples with 10 features, split into a group of 4 and a group of 6
# that the individual kernels select later through `active_dims`.
n_samples, n_features = 100, 10
X = random.normal(random.PRNGKey(2023), (n_samples, n_features))

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [3]:
# Sanity check: one row per sample and one column per feature.
assert X.shape == (n_samples, n_features)
X.shape

(100, 10)

In [4]:
# chi = features 0-3 feed the linear kernel; chi-bar = features 4-9 feed the Matern52 kernel.
K1 = Linear(active_dims=list(range(4)))
K2 = Matern52(active_dims=list(range(4, 10)))

# k = k1 + k1 * k2 -- the same K1 object appears in both terms.
K = K1 + K1 * K2

In [5]:
# Parameter tree mirroring the kernel structure Sum(K1, Prod(K1, K2)):
# both Linear components get an empty dict; only the Matern52 factor
# receives a trainable lengthscale (Softplus bijector, NormalPrior prior).
parameters = dict(
    kernel1={},
    kernel2=dict(
        kernel1={},
        kernel2=dict(
            lengthscale=Parameter(
                1.0,
                trainable=True,
                bijector=Softplus(),
                prior=NormalPrior(),
            ),
        ),
    ),
)

In [6]:
# Evaluate the composite kernel on every pair of rows of X (a 100 x 100 matrix here).
K.k(X, X, parameters)

Array([[ 8.2005132 , -0.19720892,  1.51426235, ...,  0.94999283,
         0.22494487, -2.93347995],
       [-0.19720892,  2.6757051 ,  0.72331013, ..., -0.06176633,
         0.91016731, -1.62965559],
       [ 1.51426235,  0.72331013,  7.61130429, ...,  0.78260093,
         1.31966987, -2.95024783],
       ...,
       [ 0.94999283, -0.06176633,  0.78260093, ...,  5.53811681,
         0.8657378 ,  1.73220853],
       [ 0.22494487,  0.91016731,  1.31966987, ...,  0.8657378 ,
         1.7362445 , -0.73706293],
       [-2.93347995, -1.62965559, -2.95024783, ...,  1.73220853,
        -0.73706293, 14.31534414]], dtype=float64)

In [7]:
# One row and one column per sample. Note: this re-evaluates the full kernel matrix.
K.k(X, X, parameters).shape

(100, 100)

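`Sum` and `Prod` are kernels in their own right, so they can be used wherever a regular kernel can, for example as an operand of another sum or product. The cell below builds the product first and then adds it to a linear kernel; the output is identical to the one above.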

In [8]:
# The product K2 is itself a kernel, so it can directly be an operand of `+`.
K1 = Linear(active_dims=list(range(4)))
K2 = Linear(active_dims=list(range(4))) * Matern52(active_dims=list(range(4, 10)))

K = K1 + K2

In [9]:
# Same parameter tree as before; the result matches the K1 + (K1 * K2) output above.
K.k(X, X, parameters)

Array([[ 8.2005132 , -0.19720892,  1.51426235, ...,  0.94999283,
         0.22494487, -2.93347995],
       [-0.19720892,  2.6757051 ,  0.72331013, ..., -0.06176633,
         0.91016731, -1.62965559],
       [ 1.51426235,  0.72331013,  7.61130429, ...,  0.78260093,
         1.31966987, -2.95024783],
       ...,
       [ 0.94999283, -0.06176633,  0.78260093, ...,  5.53811681,
         0.8657378 ,  1.73220853],
       [ 0.22494487,  0.91016731,  1.31966987, ...,  0.8657378 ,
         1.7362445 , -0.73706293],
       [-2.93347995, -1.62965559, -2.95024783, ...,  1.73220853,
        -0.73706293, 14.31534414]], dtype=float64)

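`active_dims` can also be passed to a `Sum` or `Prod` kernel; the selected columns are then used to filter the inputs of both components. In the next example the feature groups are swapped: the product of a linear and a Matern52 kernel acts on features 0–3, while a separate linear kernel acts on features 4–9.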

In [10]:
# active_dims set on the Prod itself: both factors only see features 0-3.
K1 = Linear(active_dims=list(range(4, 10)))
K2 = Prod(
    kernel1=Linear(),
    kernel2=Matern52(),
    active_dims=list(range(4)),
)

K = K1 + K2

In [11]:
# Feature groups are swapped relative to the earlier kernels, so these values differ from the outputs above.
K.k(X, X, parameters)

Array([[16.11183179, -5.59373677,  0.85675616, ..., -4.65580342,
        -4.81001637,  4.26845572],
       [-5.59373677,  6.21675112,  0.8275437 , ...,  4.72047191,
         4.0622014 , -1.06559384],
       [ 0.85675616,  0.8275437 ,  6.72746838, ...,  1.81414743,
         0.15439232, -2.02963621],
       ...,
       [-4.65580342,  4.72047191,  1.81414743, ..., 11.21352441,
         2.09075506,  0.40021425],
       [-4.81001637,  4.0622014 ,  0.15439232, ...,  2.09075506,
         6.45097892,  0.09222974],
       [ 4.26845572, -1.06559384, -2.02963621, ...,  0.40021425,
         0.09222974, 14.75973293]], dtype=float64)

This is equivalent to setting `active_dims` on each factor individually, as the identical output below shows:

In [12]:
# Same model as the Prod(..., active_dims=...) version, with the feature
# selection pushed down to each individual factor instead.
K1 = Linear(active_dims=list(range(4, 10)))
K21 = Linear(active_dims=list(range(4)))
K22 = Matern52(active_dims=list(range(4)))

K = K1 + K21 * K22

In [13]:
# Should reproduce the previous output exactly.
K.k(X, X, parameters)

Array([[16.11183179, -5.59373677,  0.85675616, ..., -4.65580342,
        -4.81001637,  4.26845572],
       [-5.59373677,  6.21675112,  0.8275437 , ...,  4.72047191,
         4.0622014 , -1.06559384],
       [ 0.85675616,  0.8275437 ,  6.72746838, ...,  1.81414743,
         0.15439232, -2.02963621],
       ...,
       [-4.65580342,  4.72047191,  1.81414743, ..., 11.21352441,
         2.09075506,  0.40021425],
       [-4.81001637,  4.0622014 ,  0.15439232, ...,  2.09075506,
         6.45097892,  0.09222974],
       [ 4.26845572, -1.06559384, -2.02963621, ...,  0.40021425,
         0.09222974, 14.75973293]], dtype=float64)