In [1]:
import jax.numpy as jnp
from jax import jit, vmap, grad, random
import numpy as np
import matplotlib.pyplot as plt

In [30]:
# Small hand-made fixture: three points in 2-D space (X), each with one time
# stamp (T).  The "bar" arrays form the second argument set for the kernels:
# spatial coordinates scaled by 0.3, time stamps reused as-is.
X = np.arange(1, 7).reshape(3, 2)
T = np.arange(1, 4)[:, np.newaxis]
x_bar = 0.3 * X
t_bar = T  # same object as T, not a copy

# Stacked (x, y, t) layout, one row per point.
full = np.concatenate([X, T], axis=1)
full2 = np.concatenate([x_bar, t_bar], axis=1)
def _pairwise_sq_dists(A, B):
    """Squared Euclidean distances between the rows of A (n, d) and B (m, d).

    Uses the expansion |a|^2 + |b|^2 - 2 a.b.  For (nearly) coincident rows
    that expansion can round to a tiny negative number, which would push the
    kernel slightly above its maximum; clamp such values to zero.

    Returns:
        (n, m) array of non-negative squared distances.
    """
    sq = np.sum(A**2, 1).reshape(-1, 1) + np.sum(B**2, 1) - 2 * np.dot(A, B.T)
    return np.maximum(sq, 0.0)


def kernel_(X1, X2,T1,T2, l=1.0, sigma_f=1.0,l_t = 1):
    """NumPy reference space-time RBF kernel (product of space and time factors).

    For every pair of rows computes

        sigma_f**2 * exp(-|x1 - x2|**2 / (2 l**2) - |t1 - t2|**2 / (2 l_t**2))

    Args:
        X1: (n, d) array of spatial coordinates.
        X2: (m, d) array of spatial coordinates.
        T1: (n, k) array of time coordinates (a column vector for scalar time).
        T2: (m, k) array of time coordinates.
        l: spatial length scale.
        sigma_f: signal amplitude; the output is scaled by ``sigma_f**2``.
        l_t: temporal length scale.

    Returns:
        (n, m) array of kernel values, each in [0, sigma_f**2].
    """
    sqdist = _pairwise_sq_dists(X1, X2)
    sqdist_t = _pairwise_sq_dists(T1, T2)
    return sigma_f**2 * np.exp(-0.5 / l**2 * sqdist - 0.5 / l_t**2 * sqdist_t)


In [31]:
def _squared_distance(a, b):
    """Squared Euclidean distance ``sum((a - b)**2)`` over all elements.

    Deliberately not ``jnp.linalg.norm(a - b)**2``: the norm contains a sqrt
    whose derivative is infinite at 0, so differentiating the kernel at
    coincident points (a == b) yields NaN (0 * inf).  The plain sum of
    squares has the same value and a finite (zero) gradient there.
    """
    diff = a - b
    return jnp.sum(diff * diff)


def single_rbf_kernel_space(x1, x2, l):
    """RBF kernel for one pair of points: exp(-|x1 - x2|^2 / (2 l^2)).

    Args:
        x1, x2: points of equal shape (vectors or scalars).
        l: length scale.
    """
    return jnp.exp(-_squared_distance(x1, x2) / (2 * l**2))


def single_rbf_kernel_time(t1, t2, l):
    """RBF kernel for one pair of time coordinates: exp(-|t1 - t2|^2 / (2 l^2)).

    Same formula as ``single_rbf_kernel_space``; delegates to it.
    """
    return single_rbf_kernel_space(t1, t2, l)


In [34]:
# Lift the per-pair kernels to Gram-matrix form.  The inner vmap maps over the
# rows of the second argument, the outer vmap over the rows of the first; the
# length scale (third argument) is shared and never mapped.  Inputs with n and
# m rows therefore give an (n, m) matrix.
rbf_space = vmap(
    vmap(single_rbf_kernel_space, in_axes=(None, 0, None)),
    in_axes=(0, None, None),
)
rbf_time = vmap(
    vmap(single_rbf_kernel_time, in_axes=(None, 0, None)),
    in_axes=(0, None, None),
)
def kernel(XT,XT_bar,params):
    """Space-time RBF Gram matrix for two spatial dimensions and one time dimension.

    The kernel is the product of an isotropic spatial RBF (one length scale
    shared by both spatial coordinates) and a temporal RBF with its own
    length scale:

        k((x, t), (x', t')) = sigma_f_sq * exp(-|x - x'|^2 / (2 l_x^2))
                                         * exp(-(t - t')^2 / (2 l_t^2))

    Args:
        XT (array): (n, 3) batch of points; columns 0-1 are the spatial
            coordinates, column 2 is time.
        XT_bar (array): (m, 3) batch of points with the same layout as XT.
        params (sequence): [l_x, sigma_f_sq, l_t]: spatial length scale,
            signal variance (multiplies the kernel directly, i.e. it is
            already squared), and temporal length scale.

    Returns:
        array: (n, m) Gram matrix with entry [i, j] = k(XT[i], XT_bar[j]).

    Raises:
        AssertionError: (when assertions are enabled) if either input does
            not have exactly 3 columns.
    """
    assert XT.shape[1] == 3, "XT must have shape (n, 3): two spatial columns and one time column"
    assert XT_bar.shape[1] == 3, "XT_bar must have shape (m, 3): two spatial columns and one time column"
    # Index rather than unpack so longer parameter sequences keep working.
    l_x, sigma_f_sq, l_t = params[0], params[1], params[2]
    X, T = XT[:, 0:2], XT[:, 2]
    X_bar, T_bar = XT_bar[:, 0:2], XT_bar[:, 2]
    return sigma_f_sq * rbf_space(X, X_bar, l_x) * rbf_time(T, T_bar, l_t)

In [35]:
# Compare the JAX (vmap) kernel against the NumPy reference `kernel_` on the
# same fixture.  JAX computes in float32 by default and the exponents reach
# ~30 here, so use a float32-level relative tolerance.
K_jax = kernel(full, full2, [1, 1, 1])
K_ref = kernel_(X, x_bar, T, t_bar, 1, 1, 1)
print(K_jax)
print(K_ref)
np.testing.assert_allclose(np.asarray(K_jax), K_ref, rtol=1e-4)

[[2.9375774e-01 4.3823501e-01 1.1706804e-01]
 [4.8935901e-05 2.1874933e-03 1.7509703e-02]
 [1.0060374e-12 1.3475183e-09 3.2319863e-07]]
[[2.93757700e-01 4.38234992e-01 1.17068037e-01]
 [4.89358647e-05 2.18749112e-03 1.75097047e-02]
 [1.00603928e-12 1.34751996e-09 3.23198226e-07]]


In [13]:
# Test arrays: 10 random (x, t) points drawn uniformly from [0, 1).  Seeded so
# the comparison and timings below use the same data on every Restart & Run
# All.  Non-negative t keeps 1 + 2 * (t + t') >= 1 in the kernels below.
rng = np.random.default_rng(0)
X = rng.random((10, 2))
X_ = rng.random((10, 2))
@jit
def k(x1, x2):
    """Pairwise kernel computed by broadcasting (no vmap).

    Column 0 of each input is read as the coordinate x, column 1 as the
    time t.  For every pair of rows it returns

        exp(-(x - x')**2 / (2 * d)) / sqrt(d),   d = 1 + 2 * (t + t')

    as an (n, m) array for inputs with n and m rows.
    """
    # (n, 1) columns from the first argument, (1, m) rows from the second,
    # so every elementwise operation broadcasts to an (n, m) grid.
    x_col, t_col = x1[:, 0:1], x1[:, 1:2]
    x_row, t_row = x2[:, 0][None, :], x2[:, 1][None, :]
    d = 2 * (t_col + t_row) + 1
    return jnp.exp(-((x_col - x_row) ** 2) / (2 * d)) / jnp.sqrt(d)
@jit
def k_2(x1, x2):
    x, t = x1[0], x1[1]
    x_, t_ = x2[0], x2[1]
    denom = 1 + (2 * (t + t_))
    return jnp.exp(-(x - x_) ** 2 / (2 * denom)) / jnp.sqrt(denom)
k_2 = vmap(vmap(k_2, (None, 0)), (0, None))
k_2 = jit(k_2)
# Sanity check: the broadcasting kernel `k` and the vmapped `k_2` must agree;
# fail loudly instead of only printing the result.
kernels_agree = np.allclose(k(X, X), k_2(X, X))
print(kernels_agree)
assert kernels_agree, "k and k_2 disagree on the test points"

True


In [14]:
%timeit k(X,X)
%timeit k_2(X,X)

7.77 µs ± 30.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
7.71 µs ± 37 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
