In [2]:
import jax.numpy as jnp
from jax import jit, vmap, grad, random
import numpy as np
import matplotlib.pyplot as plt
from sklearn.gaussian_process.kernels import RBF
from sklearn.metrics.pairwise import rbf_kernel

In [3]:
# Toy inputs: three 2-D spatial points, a scaled copy of them, and matching time stamps.
X = np.array([[1., 2.], [3., 4.], [5., 6.]])
x_bar = 0.3 * X
T = np.arange(1, 4).reshape(-1, 1)
t_bar = T  # same array object as T, not a copy
# Rows of (x, y, t) for the JAX kernel, which expects space and time stacked together.
full = np.column_stack((X, T))
full2 = np.column_stack((x_bar, t_bar))
def kernel_(X1, X2, T1, T2, l=1.0, sigma_f=1.0, l_t=1):
    """Separable squared-exponential space-time kernel (NumPy reference implementation).

    k((x, t), (x', t')) = sigma_f**2 * exp(-|x - x'|^2 / (2 l^2) - |t - t'|^2 / (2 l_t^2))

    Args:
        X1: (n, d) spatial inputs.
        X2: (m, d) spatial inputs.
        T1: (n, p) time inputs; must be 2-D (use a column vector for scalar time).
        T2: (m, p) time inputs.
        l: spatial length scale, shared by all spatial dimensions.
        sigma_f: signal standard deviation; the kernel is scaled by sigma_f**2.
        l_t: temporal length scale.

    Returns:
        (n, m) kernel matrix.
    """
    def _sq_dists(A, B):
        # Pairwise squared distances via ||a||^2 + ||b||^2 - 2 a.b. The expansion can
        # round to tiny negative values for (near-)coincident points, which would push
        # the kernel above sigma_f**2, so clamp at zero.
        d = np.sum(A**2, 1).reshape(-1, 1) + np.sum(B**2, 1) - 2 * np.dot(A, B.T)
        return np.maximum(d, 0.0)

    sqdist = _sq_dists(X1, X2)
    sqdist_t = _sq_dists(T1, T2)
    return sigma_f**2 * np.exp(-0.5 / l**2 * sqdist - 0.5 / l_t**2 * sqdist_t)


In [4]:
def single_rbf_kernel_space(x1, x2, l):
    """Squared-exponential kernel exp(-||x1 - x2||^2 / (2 l^2)) for one pair of points.

    The squared distance is summed directly rather than squaring ``jnp.linalg.norm``:
    the norm's sqrt has an undefined derivative at zero, so ``grad`` would return NaN
    whenever ``x1 == x2`` (e.g. on the diagonal of a Gram matrix).

    Args:
        x1: point (any shape, same as ``x2``).
        x2: point.
        l: length scale.

    Returns:
        Scalar kernel value.
    """
    sq_dist = jnp.sum((x1 - x2) ** 2)
    return jnp.exp(-sq_dist / (2 * l**2))
def single_rbf_kernel_time(t1, t2, l):
    """Squared-exponential kernel exp(-|t1 - t2|^2 / (2 l^2)) for one pair of time values.

    The squared distance is summed directly rather than squaring ``jnp.linalg.norm``:
    the norm's sqrt has an undefined derivative at zero, so ``grad`` would return NaN
    whenever ``t1 == t2``.

    Args:
        t1: time value (scalar or array, same shape as ``t2``).
        t2: time value.
        l: temporal length scale.

    Returns:
        Scalar kernel value.
    """
    sq_dist = jnp.sum((t1 - t2) ** 2)
    return jnp.exp(-sq_dist / (2 * l**2))


In [5]:
def _pairwise(fn):
    """Lift a pointwise kernel ``fn(a, b, l)`` to a Gram-matrix function of ``(A, B, l)``.

    The inner vmap maps over rows of ``B``, the outer over rows of ``A``; the length
    scale ``l`` is broadcast rather than mapped.
    """
    over_rows_of_b = vmap(fn, in_axes=(None, 0, None))
    return vmap(over_rows_of_b, in_axes=(0, None, None))


rbf_space = _pairwise(single_rbf_kernel_space)
rbf_time = _pairwise(single_rbf_kernel_time)
def kernel(XT, XT_bar, params):
    """Space-time kernel for two spatial dimensions and one time dimension.

    The kernel is the product of a squared-exponential kernel in space and one in
    time. Both spatial dimensions share a single (isotropic) length scale, and the
    time dimension has its own length scale.

    Args:
        XT (array): (n, 3) inputs; columns 0-1 are the spatial coordinates, column 2 is time.
        XT_bar (array): (m, 3) inputs with the same column layout as ``XT``.
        params (sequence): ``[l_x, sigma_f_sq, l_t]``, in that order:
            - l_x: spatial length scale.
            - sigma_f_sq: signal variance, which multiplies the kernel directly.
            - l_t: temporal length scale.

    Returns:
        array: (n, m) kernel matrix.
    """
    assert XT.shape[1] == 3, "XT must have 3 columns (x, y, t)"
    assert XT_bar.shape[1] == 3, "XT_bar must have 3 columns (x, y, t)"
    X, T = XT[:, 0:2], XT[:, 2]
    X_bar, T_bar = XT_bar[:, 0:2], XT_bar[:, 2]
    return params[1] * rbf_space(X, X_bar, params[0]) * rbf_time(T, T_bar, params[2])

In [6]:
K_jax = kernel(full, full2, [1, 1, 1])
K_numpy = kernel_(X, x_bar, T, t_bar, 1, 1, 1)
print(K_jax)
print(K_numpy)
# The JAX result is float32 by default while the NumPy reference is float64, so check
# agreement with a relative tolerance instead of by eye.
np.testing.assert_allclose(K_jax, K_numpy, rtol=1e-5)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


(3, 2)
[[2.9375774e-01 4.3823501e-01 1.1706804e-01]
 [4.8935901e-05 2.1874933e-03 1.7509703e-02]
 [1.0060374e-12 1.3475183e-09 3.2319863e-07]]
[[2.93757700e-01 4.38234992e-01 1.17068037e-01]
 [4.89358647e-05 2.18749112e-03 1.75097047e-02]
 [1.00603928e-12 1.34751996e-09 3.23198226e-07]]


In [7]:
def single_rbf(x, x_bar, sigma):
    """Squared-exponential kernel between two points, using only their first two coordinates.

    Args:
        x: point whose entries 0 and 1 are used as (x, y).
        x_bar: point whose entries 0 and 1 are used as (x_bar, y_bar).
        sigma: length scale; the kernel is exp(-r^2 / (2 * sigma^2)).

    Returns:
        Scalar kernel value.
    """
    dx = x[0] - x_bar[0]
    dy = x[1] - x_bar[1]
    return jnp.exp(-(dx**2 + dy**2) / (2 * sigma**2))
# Gram-matrix versions: the inner vmap runs over rows of the second argument, the
# outer vmap over rows of the first; the length scale is broadcast.
single = jit(vmap(vmap(single_rbf, in_axes=(None, 0, None)), in_axes=(0, None, None)))
single2 = jit(vmap(vmap(single_rbf_kernel_space, in_axes=(None, 0, None)), in_axes=(0, None, None)))
print(single(X, x_bar, 1))
print(single2(X, x_bar, 1))
# sklearn's rbf_kernel computes exp(-gamma * ||x - x'||^2), so a length scale of 1
# corresponds to gamma = 1 / (2 * 1**2) = 0.5.
gamma = 1 / (2 * 1**2)
print(rbf_kernel(X, x_bar, gamma=gamma))


[[2.9375774e-01 7.2252738e-01 8.6502230e-01]
 [8.0681588e-05 2.1874921e-03 2.8868619e-02]
 [7.4336804e-12 2.2216864e-09 3.2319863e-07]]
[[2.9375774e-01 7.2252738e-01 8.6502230e-01]
 [8.0681661e-05 2.1874933e-03 2.8868619e-02]
 [7.4336665e-12 2.2216822e-09 3.2319863e-07]]
[[2.93757700e-01 7.22527354e-01 8.65022293e-01]
 [8.06816011e-05 2.18749112e-03 2.88686225e-02]
 [7.43368067e-12 2.22168482e-09 3.23198226e-07]]


In [8]:
%timeit single
%timeit vmap(vmap(jit(single_rbf),(None,0,None)), (0,None,None))

18.1 ns ± 0.196 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)
57.9 µs ± 252 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [13]:
# Random test inputs for comparing the two implementations of k below.
# Seeded so that Restart & Run All reproduces the same arrays.
# NOTE(review): this rebinds X (previously the 3x2 array from the first data cell);
# cells further down that use X see this 10x2 array instead.
rng = np.random.default_rng(0)
X = rng.random((10, 2))
X_ = rng.random((10, 2))
@jit
def k(x1, x2):
    """Pairwise covariance between rows of x1 and rows of x2, computed by broadcasting.

    Columns 0 and 1 of each row are read as ``(x, t)``. With
    ``denom = 1 + 2 * (t + t')``, each entry is
    ``exp(-(x - x')**2 / (2 * denom)) / sqrt(denom)``.

    Args:
        x1: (n, >=2) array of ``(x, t)`` rows.
        x2: (m, >=2) array of ``(x, t)`` rows.

    Returns:
        (n, m) covariance matrix.
    """
    # (n, 1) columns from x1 broadcast against (1, m) rows from x2 -> (n, m).
    pos_a, time_a = x1[:, 0:1], x1[:, 1:2]
    pos_b, time_b = x2[:, 0][None, :], x2[:, 1][None, :]
    denom = 1 + 2 * (time_a + time_b)
    sq_diff = (pos_a - pos_b) ** 2
    return jnp.exp(-sq_diff / (2 * denom)) / jnp.sqrt(denom)
def _k_2_pointwise(x1, x2):
    """Covariance for a single pair of ``(x, t)`` points (same formula as ``k``).

    Args:
        x1: point whose entries 0 and 1 are ``(x, t)``.
        x2: point whose entries 0 and 1 are ``(x', t')``.

    Returns:
        Scalar ``exp(-(x - x')**2 / (2 * denom)) / sqrt(denom)`` with
        ``denom = 1 + 2 * (t + t')``.
    """
    s, t = x1[0], x1[1]
    s_, t_ = x2[0], x2[1]
    denom = 1 + (2 * (t + t_))
    return jnp.exp(-(s - s_) ** 2 / (2 * denom)) / jnp.sqrt(denom)


# Inner vmap over rows of x2, outer vmap over rows of x1 -> (n, m); compiled once as a whole.
k_2 = jit(vmap(vmap(_k_2_pointwise, (None, 0)), (0, None)))
# k(X, X) is symmetric, so comparing only on it cannot reveal an x1/x2 role swap in the
# vmap version; also compare on the distinct pair (X, X_).
same_sym = np.allclose(k(X, X), k_2(X, X))
same_cross = np.allclose(k(X, X_), k_2(X, X_))
print(same_sym and same_cross)
assert same_sym and same_cross, "k and k_2 disagree"

True


In [12]:
def k_uf(x, x_bar, params):
    """Mixed covariance k_uf = L_x' k_uu between u at ``x`` and f at ``x_bar``.

    ``k_uu = exp(-gamma * r^2)`` with ``r^2 = (x - x_bar)^2 + (y - y_bar)^2`` and
    ``gamma = 1 / (2 * l^2)``. The prefactor ``2*gamma*(2*gamma*r^2 - 2)`` is the 2-D
    Laplacian of that squared-exponential kernel.

    Args:
        x: point; entries 0 and 1 are used as (x, y).
        x_bar: point; entries 0 and 1 are used as (x_bar, y_bar).
        params: sequence whose first entry is the length scale ``l``; other entries
            are not used here.

    Returns:
        Scalar covariance value.
    """
    dx = x[0] - x_bar[0]
    dy = x[1] - x_bar[1]
    sq_dist = dx**2 + dy**2
    gamma = 1 / (2 * params[0]**2)
    # k_uu must use the same length scale as the prefactor (params[0]).
    k_uu = jnp.exp(-gamma * sq_dist)
    prefactor = 2 * gamma * (2 * gamma * sq_dist - 2)
    return prefactor * k_uu

# Gram-matrix version: inner vmap over rows of x_bar, outer vmap over rows of x -> (n, m).
k_uf = jit(vmap(vmap(k_uf, (None, 0, None)), (0, None, None)))
# Evaluate on the original 3x2 grid. `X` is rebound to a 10x2 random array by the test-array
# cell earlier in the notebook, so take the spatial columns of `full` (built from the
# original X) to keep this result independent of execution order.
print(k_uf(full[:, :2], x_bar, [1, 1]))

[[ 1.3219093e-01 -9.7541207e-01 -1.4791882e+00]
 [ 1.3594847e-03  2.2421792e-02  1.4694127e-01]
 [ 3.6610875e-10  8.4090829e-08  9.0140093e-06]]


In [14]:
%timeit k(X,X)
%timeit k_2(X,X)

7.77 µs ± 30.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
7.71 µs ± 37 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
