In [None]:
import jax.numpy as jnp
import jax.scipy as jsp
from jax import random
import jax

from jaxgp.covar import sparse_covariance_matrix
from jaxgp.utils import *
from jaxgp.kernels import RBF

In [None]:
def fun(x, noise=0.0, key=random.PRNGKey(0)):
    '''Himmelblau's function scaled by 1/800, optionally with additive Gaussian noise.

    Parameters
    ----------
    x : ndarray
        Array of shape (n_samples, n_features); only columns 0 and 1 are used.
    noise : float
        Scale of the standard-normal noise added to every output (0.0 -> noiseless).
    key : jax.random.PRNGKey
        Key used to draw the noise.

    Returns
    -------
    ndarray
        Array of shape (n_samples,) with the (noisy) function values.
    '''
    first, second = x[:, 0], x[:, 1]
    term_a = first**2 + second - 11
    term_b = first + second**2 - 7
    perturbation = random.normal(key, (len(x),), dtype=jnp.float32) * noise
    # Each squared term is divided separately to keep the exact floating-point operation order.
    return term_a**2 / 800.0 + term_b**2 / 800.0 + perturbation

def grad(x, noise=0.0, key=random.PRNGKey(0)):
    '''Analytic gradient of the scaled Himmelblau function, optionally with additive Gaussian noise.

    Parameters
    ----------
    x : ndarray
        Array of shape (n_samples, 2); column 0 is the first coordinate, column 1 the second.
    noise : float
        Scale of the standard-normal noise added to every gradient entry (0.0 -> noiseless).
    key : jax.random.PRNGKey
        Key used to draw the noise (one draw per entry of x).

    Returns
    -------
    ndarray
        Array of shape (n_samples, 2): the partial derivatives w.r.t. each coordinate, divided by 800.
    '''
    first, second = x[:, 0], x[:, 1]
    term_a = first**2 + second - 11
    term_b = first + second**2 - 7
    d_first = 4 * term_a * first + 2 * term_b
    d_second = 2 * term_a + 4 * term_b * second
    perturbation = random.normal(key, x.shape, dtype=jnp.float32) * noise
    # stack along axis=1 -> (n_samples, 2), one row per sample
    return jnp.stack((d_first, d_second), axis=1) / 800.0 + perturbation

In [None]:
def sparse_kernelNegativeLogLikelyhood(kernel_params, X_split, Y_data, X_ref, noise, kernel, return_components=True):
    '''Negative log likelihood for sparse GPR (PPA).

    Y_data ~ N(0,[id*s**2 + K_MN.T@K_MM**(-1)@K_MN]), which is the same as for the Nystrom
    approximation. kernel_params is the first argument so that this function can be
    minimized w.r.t. those variables.

    Parameters
    ----------
    kernel_params : ndarray
        kernel parameters. The function can be optimized w.r.t. these parameters.
    X_split : list[ndarray]
        List of ndarrays: [function_evals(n_samples_f, n_features), dx1_evals(n_samples_dx1, n_features), ..., dxn_features_evals(n_samples_dxn_features, n_features)]
    Y_data : ndarray
        ndarray of shape (n_samples,) s.t. n_samples = sum(n_samples_i) in X_split. Labels corresponding to the samples in X_split.
    X_ref : ndarray
        ndarray of shape (n_referencepoints, n_features). Reference points onto which the whole input dataset is projected.
    noise : Union[ndarray, float]
        Either a scalar or an ndarray of shape (len(X_split),). If scalar, the same value is added along the diagonal.
        Otherwise each value is added to the corresponding diagonal block coming from X_split.
        Passing an ndarray is not supported yet!
    kernel : derived class from BaseKernel
        Kernel that describes the covariance between input points.
    return_components : bool, default True
        If True (the historical behavior), return the intermediate terms for diagnostics.
        If False, return the scalar negative log likelihood. When calling through
        ``jax.jit`` and passing this argument explicitly, mark it static
        (``static_argnames=("return_components",)``) because it selects the output structure.

    Returns
    -------
    tuple or scalar
        If return_components: (logdet_fitc, logdet_K_ref, logdet_K_inv, fit_1, fit_2, covar_module).
        Otherwise: the negative log likelihood estimate for PPA.
    '''
    covar_module = sparse_covariance_matrix(X_split, Y_data, X_ref, noise, kernel, kernel_params)

    # Log-determinants. For a triangular Cholesky factor L, log|L L^T| = 2 * sum(log(diag(L))).
    # Assumes k_ref is the Cholesky factor of K_MM and k_inv the factor of
    # A = K_MM + K_MN Lambda^{-1} K_NM -- TODO confirm in sparse_covariance_matrix.
    logdet_K_ref = 2*jnp.sum(jnp.log(jnp.diag(covar_module.k_ref)))
    logdet_K_inv = 2*jnp.sum(jnp.log(jnp.diag(covar_module.k_inv)))
    # covar_module.diag is used as the diagonal Lambda, so log|Lambda| = sum(log(diag)).
    logdet_fitc = jnp.sum(jnp.log(covar_module.diag))

    # Data-fit term via the Woodbury identity: y^T C^{-1} y = y^T Lambda^{-1} y - b^T A^{-1} b,
    # with b = covar_module.proj_labs.
    fit_1 = Y_data@(Y_data / covar_module.diag)
    # lower=False: cho_solve treats k_inv as an upper-triangular Cholesky factor.
    fit_2 = covar_module.proj_labs@jsp.linalg.cho_solve((covar_module.k_inv, False), covar_module.proj_labs)

    if return_components:
        # Diagnostic output; kept as the default because callers index this tuple.
        return logdet_fitc, logdet_K_ref, logdet_K_inv, fit_1, fit_2, covar_module

    fit = fit_1 - fit_2
    # Matrix determinant lemma: log|C| = log|Lambda| + log|A| - log|K_MM|,
    # so the reference-point logdet is subtracted (it was previously added).
    return 0.5*(logdet_fitc + logdet_K_inv - logdet_K_ref + fit + len(Y_data)*jnp.log(2*jnp.pi))

In [None]:
# Interval bounds from which to choose the data points
bounds = jnp.array([-5.0, 5.0])

# How many function and derivative observations should be chosen
num_f_vals = 1
num_d_vals = 1000

# initial seed for the pseudo random key generation
seed = 0

# create new keys and randomly sample the above interval for training features
key, subkey = random.split(random.PRNGKey(seed))
x_func = random.uniform(subkey, (num_f_vals, 2), minval=bounds[0], maxval=bounds[1])
key, subkey = random.split(key)
x_der = random.uniform(subkey, (num_d_vals, 2), minval=bounds[0], maxval=bounds[1])

# noise with which to sample the training labels
noise = 0.02
key, subkey = random.split(key)
y_func = fun(x_func, noise, subkey)
key, subkey = random.split(key)
y_der = grad(x_der, noise, subkey)

# Reference-point budget: one tenth of the training inputs
num_ref_points = (num_d_vals + num_f_vals) // 10
key, subkey = random.split(key)
# Option 1: a random subset of the training inputs
X_ref_rand = random.permutation(subkey, jnp.vstack((x_der, x_func)))[:num_ref_points]
# Option 2: an evenly spaced grid with grid_side**2 points, i.e. the largest perfect square
# not exceeding num_ref_points (floor, so the grid never exceeds the reference-point budget)
grid_side = int(num_ref_points ** 0.5)
grid_1d = jnp.linspace(*bounds, grid_side)
X_ref_even = jnp.array(jnp.meshgrid(grid_1d, grid_1d)).reshape(2, -1).T

X_split = [x_func, x_der]
# NOTE(review): Y_train has num_f_vals + 2*num_d_vals entries; the row-major reshape interleaves
# the derivative labels per point ([d/dx1, d/dx2] for each row of x_der). Confirm this matches
# the label layout sparse_covariance_matrix expects for X_split.
Y_train = jnp.hstack((y_func, y_der.reshape(-1)))

kernel = RBF()
# an RBF kernel has per default 2 parameters
# init_kernel_params = jnp.ones(2)*jnp.log(2)
init_kernel_params = jnp.array([1.0, 1.0])

In [None]:
# Reference points used by the sparse approximation.
# Alternative: X_ref = X_ref_rand (random subset of the training inputs).
X_ref = X_ref_even

# Wrap with jax.jit once instead of creating a new jit wrapper on every call.
# Re-run this cell after redefining sparse_kernelNegativeLogLikelyhood.
_sparse_nll_jit = jax.jit(sparse_kernelNegativeLogLikelyhood)

def lml(x):
    '''Evaluate the jitted sparse likelihood at kernel parameters ``x``,
    using the training data, reference points, noise and kernel defined in this notebook.'''
    return _sparse_nll_jit(x, X_split, Y_train, X_ref, noise, kernel)

In [None]:
# Scan the second kernel parameter (presumably the RBF length scale) over a narrow range
# and print fit_2, element [-2] of the tuple returned by lml.
ls = jnp.linspace(0.95, 0.96, 100)

for scale in ls:
    fit_2 = lml(jnp.array([1.0, scale]))[-2]
    # 5 decimals: the grid step (~1e-4) would make 3-decimal labels collide
    print(f"ls={scale:.5f} -> fit_2={fit_2:.3f}")

In [None]:
# Evaluate lml once at the initial kernel parameters and keep the full result for inspection.
res = lml(init_kernel_params)

In [None]:
# The last element of the result is the covariance module built by sparse_covariance_matrix.
covar = res[-1]

In [None]:
# Print the fraction of NaN entries in each component of the covariance module;
# a non-zero value identifies the component where the computation breaks down.
for component in covar:
    nan_fraction = jnp.isnan(component).mean()
    print(nan_fraction)