In [2]:

import numpy as np
from jax import grad, jit, vmap
import jax.numpy as jnp



In [10]:
@jit
def rbf_kernel_single(x1, x2, params):
    """Squared-exponential (RBF) kernel evaluated between two single points.

    Parameters
    ----------
    x1, x2 : array_like
        The two points; each is flattened to 1-D, so scalars and (d,) / (d, 1)
        shaped inputs are accepted alike.
    params : array_like
        Hyperparameters: ``params[0]`` is the length scale and ``params[1]``
        is ``sigma_f`` (the kernel is scaled by ``sigma_f**2``). Any further
        entries are ignored here.

    Returns
    -------
    jax.Array
        ``sigma_f**2 * exp(-||x1 - x2||**2 / (2 * length_scale**2))``, squeezed.
    """
    length_scale = params[0]
    sigma_f = params[1]
    diff = jnp.ravel(x1) - jnp.ravel(x2)
    sq_dist = jnp.sum(diff**2)
    return jnp.squeeze(sigma_f**2 * jnp.exp(-0.5 / length_scale**2 * sq_dist))

def k_ff_jax(x, y, params):
    """Pointwise covariance of f = L u for a second-order linear ODE operator.

    With u ~ GP(0, k_rbf) (``rbf_kernel_single``) and the linear operator
    L = m d^2/dt^2 + b d/dt + k (presumably a mass-spring-damper
    m u'' + b u' + k u = f -- confirm against the model), this returns
    cov(f(x), f(y)) = L_x L_y k_rbf(x, y). Because k_rbf depends only on
    x - y, d/dx = -d/dy when applied to it, so the odd cross terms cancel and

        L_x L_y k_rbf = m^2 d4k/dx2dy2 + 2 m k d2k/dy2 + b^2 d2k/dxdy + k^2 k_rbf.

    Parameters
    ----------
    x, y : scalar or single-element array
        One input point each (e.g. a row of shape (1,) when batched over a
        column vector). Both are squeezed to scalars before differentiating.
    params : array_like
        ``[l, sigma_f, m, b, k]``: the first two are read by
        ``rbf_kernel_single``; the last three are the operator coefficients.

    Returns
    -------
    jax.Array
        The covariance value for this (x, y) pair.

    Raises
    ------
    ValueError
        If ``x`` or ``y`` is not a scalar after squeezing. The derivatives
        below are taken with respect to a single input dimension, and
        ``jax.grad`` cannot differentiate a non-scalar intermediate.
    """
    x, y = jnp.squeeze(x), jnp.squeeze(y)
    if jnp.ndim(x) != 0 or jnp.ndim(y) != 0:
        raise ValueError(
            "k_ff_jax expects one scalar input per point; got shapes "
            f"{jnp.shape(x)} and {jnp.shape(y)} after squeezing"
        )
    m, b, k = params[2], params[3], params[4]

    # Derivative operators of the base kernel (argnums 0 -> x, 1 -> y).
    # These are functions; keep their names distinct from the values below.
    d2_dy2 = grad(grad(rbf_kernel_single, argnums=1), argnums=1)
    d4_dx2dy2 = grad(grad(d2_dy2, argnums=0), argnums=0)
    d2_dxdy = grad(grad(rbf_kernel_single, argnums=1), argnums=0)

    k_xxyy = d4_dx2dy2(x, y, params)  # d^4 k / dx^2 dy^2
    k_yy = d2_dy2(x, y, params)       # d^2 k / dy^2 (equals d^2 k / dx^2 for this kernel)
    k_xy = d2_dxdy(x, y, params)      # d^2 k / dx dy
    k_base = rbf_kernel_single(x, y, params)
    return m**2 * k_xxyy + 2*m*k*k_yy + b**2 * k_xy + k**2 * k_base

# Batch the pointwise kernel into a Gram-matrix function: the inner vmap maps
# over rows of `y`, the outer over rows of `x`, so result[i, j] = k_ff(x[i], y[j]);
# `params` is shared (in_axes None). `k_ff_jax` is rebound to the batched
# version, so keep the pointwise one reachable as `k_ff_jax_pointwise`.
k_ff_jax_pointwise = k_ff_jax
k_ff_jax = jit(vmap(vmap(k_ff_jax_pointwise, (None, 0, None)), (0, None, None)))

# Smoke test on three time points.
x = np.linspace(0, 10, 3).reshape(-1, 1)
y = np.linspace(0, 10, 3).reshape(-1, 1)
params = np.array([1.0, 1.0, 1.0, 1.0, 1.0])  # [l, sigma_f, m, b, k]
K_ff = k_ff_jax(x, y, params)
# At x == y (s = sigma_f^2): d4k/dx2dy2 = 3 s/l^4, d2k/dy2 = -s/l^2, d2k/dxdy = s/l^2,
# so the diagonal is s * (3 m^2/l^4 - 2 m k/l^2 + b^2/l^2 + k^2) = 3 for unit params.
assert np.allclose(np.diag(K_ff), 3.0)
K_ff

(1,) (1,)


Array([[3.0000000e+00, 1.8745066e-03, 1.8328910e-18],
       [1.8745066e-03, 3.0000000e+00, 1.8745066e-03],
       [1.8328910e-18, 1.8745066e-03, 3.0000000e+00]], dtype=float32)

In [16]:
# Two-column inputs: (x, t) pairs and (y, s) pairs, 15 points each.
x = np.linspace(-5, 5, 15)
t = np.linspace(0, 1, 15)
y = np.linspace(-2, 3, 15)
s = np.linspace(0, 2, 15)
xt = np.column_stack((x, t))  # shape (15, 2)
ys = np.column_stack((y, s))  # shape (15, 2)
# NOTE(review): the k_ff_jax defined in this notebook reads params[2..4] (five
# entries) and differentiates with respect to a single scalar input, whereas
# `params` here has four entries and `xt`/`ys` have two columns -- confirm
# which kernel version these inputs are meant for.
params = [1, 1, 1, 1]

In [17]:
# Compare each reference kernel with its JAX counterpart on the (x, t) / (y, s) inputs.
# NOTE(review): k_ff, k_uf, k_fu, k_uf_jax and k_fu_jax are not defined in any
# visible cell, and the execution counts are out of order (In [17] appears before
# In [6]/In [7]). This cell relies on leftover kernel state -- verify under
# Restart & Run All.
# NOTE(review): the saved True/False/False output (and the Traced<...> lines,
# presumably printed from inside the compared functions) appears to predate the
# current k_ff_jax, which needs scalar inputs and a 5-entry params -- treat as stale.
print(np.allclose(k_ff(xt,ys,params),k_ff_jax(xt,ys,params)))
print(np.allclose(k_uf(xt,ys,params),k_uf_jax(xt,ys,params)))
print(np.allclose(k_fu(xt,ys,params),k_fu_jax(xt,ys,params)))

Traced<ShapedArray(float32[2])>with<BatchTrace(level=2/0)> with
  val = Traced<ShapedArray(float32[15,2])>with<DynamicJaxprTrace(level=1/0)>
  batch_dim = 0
True
Traced<ShapedArray(float32[])>with<BatchTrace(level=3/0)> with
  val = Traced<ShapedArray(float32[15])>with<BatchTrace(level=2/0)> with
    val = Traced<ShapedArray(float32[15,15])>with<DynamicJaxprTrace(level=1/0)>
    batch_dim = 0
  batch_dim = 0
Traced<ShapedArray(float32[15,15])>with<DynamicJaxprTrace(level=1/0)>
False
False


In [6]:
# Print the reference k_uf and the JAX k_uf on the same inputs, one after the other.
# NOTE(review): k_uf and k_uf_jax are not defined in the visible cells, and the
# saved 2x2 output cannot come from the 15-row `xt` built above -- it is from an
# earlier kernel state; rerun before trusting it.
print(k_uf(xt,ys,params))
print(k_uf_jax(xt,ys,params))

[[-1.2096275e-02  5.1417255e-15]
 [-2.5165852e-20 -5.4946918e-02]]
Traced<ShapedArray(float32[2,2])>with<DynamicJaxprTrace(level=1/0)>
[[-9.9980973e-02 -1.0283450e-13]
 [-6.6662131e-10 -2.4625501e-01]]


In [7]:
# Print the reference k_ff and the JAX k_ff one after the other; the extra '\n'
# argument leaves a blank line after each matrix.
# NOTE(review): k_ff is not defined in the visible cells, and the saved 2x2 output
# cannot come from the 15-row `xt` built above -- rerun before trusting it.
print(k_ff(xt,ys,params),'\n')
print(k_ff_jax(xt,ys,params),'\n')

Traced<ShapedArray(float32[2])>with<BatchTrace(level=2/0)> with
  val = Traced<ShapedArray(float32[2,2])>with<DynamicJaxprTrace(level=1/0)>
  batch_dim = 0
[[ 5.4434085e-01  5.7107426e-12]
 [ 2.9275787e-08 -5.7459497e-01]] 

[[ 5.4434079e-01  5.7107426e-12]
 [ 2.9275787e-08 -5.7459497e-01]] 

