In [2]:
import jax.numpy as jnp
from jax import random
import jax.scipy as jsp
import jax

from jaxgp.utils import CovMatrixDD, CovMatrixFD, CovMatrixFF, _build_xT_Ainv_x
from jaxgp.kernels import RBF
from jaxgp import covar

import matplotlib.pyplot as plt

In [3]:
def fun(x, noise=0.0, key = random.PRNGKey(0)):
    return (x[:,0]**2 + x[:,1] - 11)**2 / 800.0 + (x[:,0] + x[:,1]**2 -7)**2 / 800.0 + random.normal(key,(len(x),), dtype=jnp.float32)*noise

def grad(x, noise=0.0, key = random.PRNGKey(0)):
    dx1 = 4 * (x[:,0]**2 + x[:,1] - 11) * x[:,0] + 2 * (x[:,0] + x[:,1]**2 -7)
    dx2 = 2 * (x[:,0]**2 + x[:,1] - 11) + 4 * (x[:,0] + x[:,1]**2 -7) * x[:,1]
    return jnp.vstack((dx1, dx2)).T / 800.0 + random.normal(key,x.shape, dtype=jnp.float32)*noise



In [4]:
# Interval bounds from which to choose the data points
bounds = jnp.array([-5.0, 5.0])

# How many function and derivative observations should be chosen
num_f_vals = 10
num_d_vals = 100
num_ref_points = (num_d_vals + num_f_vals) // 5
dim = 2

# initial seed for the pseudo random key generation
seed = 0

# create new keys and randomly sample the above interval for training features
key, subkey = random.split(random.PRNGKey(seed))
x_func = random.uniform(subkey, (num_f_vals, dim), minval=bounds[0], maxval=bounds[1])
key, subkey = random.split(key)
x_der = random.uniform(subkey, (num_d_vals, dim), minval=bounds[0], maxval=bounds[1])
key, subkey = random.split(key)
X_ref = random.uniform(subkey, (num_ref_points, dim), minval=bounds[0], maxval=bounds[1])

# noise with which to sample the training labels
noise = 0.1
key, subkey = random.split(key)
y_func = fun(x_func,noise, subkey)
key, subkey = random.split(key)
y_der = grad(x_der, noise, subkey)


kernel = RBF()
# an RBF kernel has per default 2 parameters
init_kernel_params = jnp.array([1.0, 1.0])

In [5]:
X_split = (x_func, x_der)

In [6]:
def sparse_covariance_matrix(X_split, X_ref, kernel, params):
    KF = CovMatrixFF(X_ref, X_split[0], kernel, params)
    KD = CovMatrixFD(X_ref, X_split[1], kernel, params)
    
    K_MN = jnp.hstack((KF,KD))

    K_ref = CovMatrixFF(X_ref, X_ref, kernel, params)

    return K_MN.T@jsp.linalg.cho_solve(jsp.linalg.cho_factor(K_ref), K_MN)

def full_covariance_matrix(X_split, kernel, params):
    KF = CovMatrixFF(X_split[0], X_split[0], kernel, params)
    KD = CovMatrixFD(X_split[0], X_split[1], kernel, params)
    KDD = CovMatrixDD(X_split[1], X_split[1], kernel, params)

    K_NN = jnp.vstack((jnp.hstack((KF,KD)), 
                       jnp.hstack((KD.T,KDD))))

    return K_NN

In [7]:
def fitc_diag(X_split, X_ref, kernel, params):
    KF = CovMatrixFF(X_ref, X_split[0], kernel, params)
    KD = CovMatrixFD(X_ref, X_split[1], kernel, params)
    
    K_MN = jnp.hstack((KF,KD))

    K_ref = CovMatrixFF(X_ref, X_ref, kernel, params)
    func = jax.vmap(lambda v: kernel.eval(v, v, params), in_axes=(0))(X_split[0])
    der = jnp.ravel(jax.vmap(lambda v: jnp.diag(kernel.jac(v, v, params)), in_axes=(0))(X_split[1]))
    full_diag = jnp.hstack((func, der))
    sparse_diag = _build_xT_Ainv_x(K_ref, K_MN.T)
    L = (full_diag - sparse_diag)

    return L

In [8]:
full = full_covariance_matrix(X_split, kernel, init_kernel_params)
sparse = sparse_covariance_matrix(X_split, X_ref, kernel, init_kernel_params)

fitc = fitc_diag(X_split, X_ref, kernel, init_kernel_params)

In [9]:
print(full.shape, sparse.shape, fitc.shape)

(210, 210) (210, 210) (210,)


In [10]:
print(jnp.allclose(jnp.diag(full-sparse), fitc))

True


In [26]:
_test = lambda A,x: jsp.linalg.cho_solve((A,False), x)
test = jax.jit(_test)    

In [27]:
key = jax.random.PRNGKey(0)
_A = jax.random.uniform(key, shape=(10, 10))
A = (_A + _A.T) / 2 + jnp.eye(*_A.shape)
A_cho,_ = jsp.linalg.cho_factor(A)
x = jax.random.uniform(key, shape=(10,))

In [28]:
test(A_cho,x)

DeviceArray([-0.7015892 ,  1.0962088 , -0.67026514, -0.07697594,
              0.38710043, -0.18255392, -0.00859932,  0.11759639,
              1.5120746 , -0.62241197], dtype=float32)

In [25]:
type(A_cho)

tuple