In [2]:
from jax import numpy as jnp
from jax import jit, grad, vmap
@jit
def rbf_kernel_single_x(x: float, y: float, params: list) -> float:
    """Evaluate the RBF kernel k(x, y) for a single pair of points.

    Computes ``sigma_f_sq * exp(-||x - y||^2 / (2 * l_x**2))``.

    Args:
        x: First point; a scalar, or an array of coordinates.
        y: Second point, with the same shape as ``x``.
        params: Sequence whose first two entries are the length scale
            ``l_x`` and the signal variance ``sigma_f_sq``.

    Returns:
        The scalar kernel value.
    """
    l_x, sigma_f_sq = params[0], params[1]
    # Square each difference before summing. Squaring the summed difference
    # gives the same result for scalars, but for vector inputs it would let
    # components of opposite sign cancel.
    sqdist = jnp.sum((x - y) ** 2)
    return sigma_f_sq * jnp.exp(-0.5 / l_x**2 * sqdist)
@jit
def rbf_kernel_single_t(t: float, s: float, l_t: float) -> float:
    """Evaluate the unit-variance RBF kernel k(t, s) for a single pair.

    Computes ``exp(-||t - s||^2 / (2 * l_t**2))``.

    Args:
        t: First point; a scalar, or an array of coordinates.
        s: Second point, with the same shape as ``t``.
        l_t: Length scale.

    Returns:
        The scalar kernel value.
    """
    # Square each difference before summing. Squaring the summed difference
    # gives the same result for scalars, but for vector inputs it would let
    # components of opposite sign cancel.
    sqdist = jnp.sum((t - s) ** 2)
    value = jnp.exp(-0.5 / l_t**2 * sqdist)
    return value


The next step is to implement the kernel for the wave equation. I will use the same kernel as in the paper, i.e. the plain RBF kernel. Our linear operator is given by:
$$\mathcal{L}_x^\alpha = \frac{\partial^2}{\partial x^2} - \frac{1}{c^2} \frac{\partial^2}{\partial t^2}$$
We need 4 different parts for the kernel:
1. The kernel without the linear operator $K(x,y,t,s)_{uu}$
2. The kernel with the operator applied once, to the second arguments $(y, s)$: $K_{uf} = \mathcal{L}_{y}^\alpha K_{uu}$
3. The kernel with the operator applied once, to the first arguments $(x, t)$: $K_{fu} = \mathcal{L}_x^\alpha K_{uu}$
4. The kernel with the operator applied to both: $K_{ff} = \mathcal{L}_x^\alpha \mathcal{L}_y^\alpha K_{uu}$
The first kernel is just the normal RBF kernel, with $\gamma_x = \frac{1}{2 l_x^2}$ and $\gamma_t = \frac{1}{2 l_t^2}$:
\begin{align}
K(x,y,t,s)_{uu} = \sigma_f^2 e^{-\gamma_x(x-y)^2 - \gamma_t(t-s)^2}
\end{align}
For the second we use the linear operator once with respect to y and s.
\begin{align}
K_{uf} = \mathcal{L}_{y}^\alpha K_{uu} =  \frac{\partial^2}{\partial y^2} K_{uu} -\frac{1}{c^2} \frac{\partial^2}{\partial s^2} K_{uu}
\end{align}
Same for the third kernel, but with respect to x and t.
\begin{align}
K_{fu} = \mathcal{L}_{x}^\alpha K_{uu} =  \frac{\partial^2}{\partial x^2} K_{uu} -\frac{1}{c^2} \frac{\partial^2}{\partial t^2} K_{uu}
\end{align}
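And for the last kernel we apply the linear operator twice. Because $K_{uu}$ depends only on $x-y$ and $t-s$, the two mixed terms are equal, $\frac{\partial^2}{\partial x^2} \frac{\partial^2}{\partial s^2} K_{uu} = \frac{\partial^2}{\partial t^2} \frac{\partial^2}{\partial y^2} K_{uu}$, which gives the last equality below.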
\begin{align}
K_{ff} = \mathcal{L}_{x}^\alpha \mathcal{L}_{y}^\alpha K_{uu} = \frac{\partial^2}{\partial x^2} \frac{\partial^2}{\partial y^2} K_{uu} +\frac{1}{c^4} \frac{\partial^2}{\partial t^2} \frac{\partial^2}{\partial s^2} K_{uu} -\frac{1}{c^2} \frac{\partial^2}{\partial t^2} \frac{\partial^2}{\partial y^2} K_{uu} -\frac{1}{c^2} \frac{\partial^2}{\partial x^2} \frac{\partial^2}{\partial s^2} K_{uu} = \frac{\partial^2}{\partial x^2} \frac{\partial^2}{\partial y^2} K_{uu} +\frac{1}{c^4} \frac{\partial^2}{\partial t^2} \frac{\partial^2}{\partial s^2} K_{uu} -\frac{2}{c^2} \frac{\partial^2}{\partial t^2} \frac{\partial^2}{\partial y^2} K_{uu} 
\end{align}
The implementation was very hard at the start, so I had to split the starting kernel into two parts, one in $(x, y)$ and one in $(t, s)$. The $\sigma_f^2$ factor is only included once, in the x kernel. Since the two kernels are multiplied together anyway, this makes no difference.

\begin{align*}
\frac{\partial^4 k}{\partial t^2\partial y^2} = 4{\gamma}_\text{t}{\gamma}_\text{x}\cdot\left(2{\gamma}_\text{x}\cdot\left(y-x\right)^2-1\right)\left(2{\gamma}_\text{t}\cdot\left(t-s\right)^2-1\right)\mathrm{e}^{-{\gamma}_\text{t}\cdot\left(t-s\right)^2-{\gamma}_\text{x}\cdot\left(y-x\right)^2}
\end{align*}
I computed both parts separately so that I can use vectorisation with vmap. I did not use autograd for this. I could probably do the whole thing this way, but for the time being it is fine. (I don't know whether it is possible to use vmap with the whole function.)

In [7]:
@jit
def k_x_welle(x: float, y: float, params):
    """Return the x/y factor of the mixed derivative term k_ttyy.

    Evaluates ``2*gamma_x * (2*gamma_x*(x - y)**2 - 1) * k_x(x, y)`` with
    ``gamma_x = 0.5 / params[0]**2`` and ``k_x = rbf_kernel_single_x``.
    It is kept as its own scalar function so that it can be vmapped
    independently of the t/s factor.
    """
    length_scale = params[0]
    gamma = 0.5 / length_scale**2
    squared_diff = (x - y)**2
    # Polynomial prefactor produced by differentiating the Gaussian twice.
    prefactor = 2*gamma*squared_diff - 1
    kernel_value = rbf_kernel_single_x(x, y, params)
    return prefactor * kernel_value * 2*gamma
@jit
def k_t_welle(t: float, s: float, params):
    """Return the t/s factor of the mixed derivative term k_ttyy.

    Evaluates ``2*gamma_t * (2*gamma_t*(t - s)**2 - 1) * k_t(t, s)`` with
    ``gamma_t = 0.5 / params[2]**2`` and ``k_t = rbf_kernel_single_t``.
    It is kept as its own scalar function so that it can be vmapped
    independently of the x/y factor.
    """
    length_scale = params[2]
    gamma = 0.5 / length_scale**2
    squared_diff = (t - s)**2
    # Polynomial prefactor produced by differentiating the Gaussian twice.
    prefactor = 2*gamma*squared_diff - 1
    kernel_value = rbf_kernel_single_t(t, s, length_scale)
    return prefactor * kernel_value * 2*gamma
@jit
def k_ff(X, Y, T, S, params):
    """Build the k_ff block of the block matrix K (operator applied on both sides).

    k_ff = d^2/dx^2 d^2/dy^2 K_uu
           - 2/c^2 d^2/dt^2 d^2/dy^2 K_uu
           + 1/c^4 d^2/dt^2 d^2/ds^2 K_uu

    params = [l_x, sigma_f_sq, l_t, c]
    """
    wave_speed = params[-1]
    length_t = params[2]
    # The x kernel only needs the leading entries, so drop c.
    kernel_params = params[:-1]

    def pairwise(fn):
        # Outer map over the first argument, inner map over the second;
        # the third argument (the kernel parameters) is shared by all pairs.
        return vmap(vmap(fn, (None, 0, None)), (0, None, None))

    def second_derivative(fn, argnum):
        # Differentiate twice with respect to the same positional argument.
        return grad(grad(fn, argnums=argnum), argnums=argnum)

    # grad works on scalar inputs, so flatten everything to 1-D and let vmap
    # feed the scalar kernels one pair at a time.
    x_flat, y_flat = X.flatten(), Y.flatten()
    t_flat, s_flat = T.flatten(), S.flatten()

    gram_x = pairwise(rbf_kernel_single_x)
    gram_t = pairwise(rbf_kernel_single_t)

    # d^2/dx^2 d^2/dy^2 K_uu: differentiate the x kernel, multiply by the plain t kernel.
    d4_x = second_derivative(second_derivative(rbf_kernel_single_x, 1), 0)
    k_xxyy = pairwise(d4_x)(x_flat, y_flat, kernel_params) * gram_t(t_flat, s_flat, length_t)

    # d^2/dt^2 d^2/dy^2 K_uu: product of the two closed-form second-derivative factors.
    x_factor = pairwise(k_x_welle)(x_flat, y_flat, kernel_params)
    t_factor = pairwise(k_t_welle)(t_flat, s_flat, kernel_params)
    k_ttyy = x_factor * t_factor

    # d^2/dt^2 d^2/ds^2 K_uu: differentiate the t kernel, multiply by the plain x kernel.
    d4_t = second_derivative(second_derivative(rbf_kernel_single_t, 1), 0)
    k_ttss = pairwise(d4_t)(t_flat, s_flat, kernel_params[2]) * gram_x(x_flat, y_flat, kernel_params)

    return k_xxyy - 2/wave_speed**2 * k_ttyy + 1/wave_speed**4 * k_ttss


@jit
def k_uf(X, Y, T, S, params):
    """Build the k_uf block: the operator applied once, to the (Y, S) arguments.

    k_uf = d^2/dy^2 K_uu - 1/c^2 d^2/ds^2 K_uu

    params[-1] is read as c and params[2] as l_t; params[:-1] is what the
    x kernel receives.
    """
    wave_speed = params[-1]
    length_t = params[2]
    kernel_params = params[:-1]

    def pairwise(fn):
        # Outer map over the first argument, inner map over the second;
        # the third argument (the kernel parameters) is shared by all pairs.
        return vmap(vmap(fn, (None, 0, None)), (0, None, None))

    def second_derivative_wrt_second(fn):
        # Differentiate twice with respect to the second positional argument.
        return grad(grad(fn, argnums=1), argnums=1)

    x_flat, y_flat = X.flatten(), Y.flatten()
    t_flat, s_flat = T.flatten(), S.flatten()

    gram_x = pairwise(rbf_kernel_single_x)
    gram_t = pairwise(rbf_kernel_single_t)

    # d^2/dy^2 K_uu: differentiate the x kernel, multiply by the plain t kernel.
    d2_y = pairwise(second_derivative_wrt_second(rbf_kernel_single_x))
    k_yy = d2_y(x_flat, y_flat, kernel_params) * gram_t(t_flat, s_flat, length_t)

    # d^2/ds^2 K_uu: differentiate the t kernel, multiply by the plain x kernel.
    d2_s = pairwise(second_derivative_wrt_second(rbf_kernel_single_t))
    k_ss = d2_s(t_flat, s_flat, kernel_params[2]) * gram_x(x_flat, y_flat, kernel_params)

    return k_yy - 1/wave_speed**2 * k_ss
    