In [5]:
import numpy as np
import jax.numpy as jnp
from jax import grad, jit, vmap
from kernel_oscillator import k_uu, k_ff_jax, k_uf_jax, rbf_kernel_single, k_fu_jax

In [33]:
@jit
def k_uf(x, y, params):
    m = 1
    b = params[2]
    k = params[3]
    gamma = 0.5 / params[0]**2
    #dt'^2
    k_yy = 2*gamma*(2*gamma * (x-y)**2 - 1) * rbf_kernel_single(x, y, params)
    #dt'
    k_y = 2*gamma*(x-y) * rbf_kernel_single(x, y, params)
    #no dev
    k_normal = rbf_kernel_single(x, y, params)
    return jnp.squeeze(m * k_yy + b * k_y + k * k_normal)
k_uf = vmap(vmap(k_uf, (None, 0, None)), (0, None, None))
#k_uf = jit(k_uf)

def k_fu(x,y,params):
    m = 1
    b = params[2]
    k = params[3]
    gamma = 0.5 / params[0]**2
    #dt^2
    k_xx = 2*gamma*(2*gamma * (x-y)**2 - 1) * rbf_kernel_single(x, y, params)
    #dt'
    k_x = -2*gamma*(x-y) * rbf_kernel_single(x, y, params)
    #no dev
    k_normal = rbf_kernel_single(x, y, params)
    return jnp.squeeze(m * k_xx + b * k_x + k * k_normal)
k_fu = vmap(vmap(k_fu, (None, 0, None)), (0, None, None))
#k_fu = jit(k_fu)

In [42]:
def k_ff(x,y,params):
    m = 1
    b = params[2]
    k = params[3]
    gamma = 0.5 / params[0]**2
    #dt^2 dt'^2
    dif = (x-y)
    k_xxyy = (16*gamma**4* dif**4 - 48*gamma**3*dif**2 + 12*gamma**2) * rbf_kernel_single(x,y,params)
    #dt'^2
    k_yy = 2*gamma*(2*gamma * (x-y)**2 - 1) * rbf_kernel_single(x, y, params)
    #dtdt'
    k_xy = (2*gamma - 4*gamma**2*dif**2) * rbf_kernel_single(x, y, params)
    #no div
    k_normal = rbf_kernel_single(x, y, params)
    return m**2 * k_xxyy + 2*m*k*k_yy + b**2 * k_xy + k**2 * k_normal
k_ff = vmap(vmap(k_ff, (None, 0, None)), (0, None, None))
k_ff = jit(k_ff)

In [44]:
x = np.linspace(0, 10, 100).reshape(-1, 1)
y = np.linspace(-30, 34, 100).reshape(-1, 1)
params = np.array([1, 1, 1, 1])
print(np.allclose(k_uf_jax(x, y, params), k_uf(x, y, params)))
print(np.allclose(k_fu_jax(x, y, params), k_fu(x, y, params)))
print(np.allclose(k_ff_jax(x, y, params), k_ff(x, y, params)))

True
True
False


In [54]:
test1 = vmap(rbf_kernel_single, (None, 0, None))
test2 = vmap(rbf_kernel_single, (0, None, None))
test3 = vmap(vmap(rbf_kernel_single, (None, 0, None)), (0, None, None))
test4 = vmap(vmap(rbf_kernel_single, (0, None, None)), (None, 0, None))
X = np.array([1,2,3,2]).reshape(-1,1)
Y = np.array([1.5,2.5,1.5,2.5]).reshape(-1,1)
params = np.array([1,1,1,1])
print(test1(X,Y,params))
print(test2(X,Y,params))
print(test3(X,Y,params))
print("--------")
print(test4(X,Y,params))
d = np.zeros((4,4))
for i in range(4):
    for j in range(4):
        d[i,j] = rbf_kernel_single(X[i],Y[j],params)
print(d)

[0.22313017 0.22313017 0.22313017 0.22313017]
[0.082085   0.60653067 0.082085   0.60653067]
[[0.8824969  0.32465246 0.8824969  0.32465246]
 [0.8824969  0.8824969  0.8824969  0.8824969 ]
 [0.32465246 0.8824969  0.32465246 0.8824969 ]
 [0.8824969  0.8824969  0.8824969  0.8824969 ]]
--------
[[0.8824969  0.8824969  0.32465246 0.8824969 ]
 [0.32465246 0.8824969  0.8824969  0.8824969 ]
 [0.8824969  0.8824969  0.32465246 0.8824969 ]
 [0.32465246 0.8824969  0.8824969  0.8824969 ]]
[[0.88249689 0.32465246 0.88249689 0.32465246]
 [0.88249689 0.88249689 0.88249689 0.88249689]
 [0.32465246 0.88249689 0.32465246 0.88249689]
 [0.88249689 0.88249689 0.88249689 0.88249689]]
