In [86]:
from jax import grad
import jax.numpy as jnp
import numpy as np
from sklearn.metrics.pairwise import rbf_kernel as rbf_kernel_sklearn

In [147]:
def rbf_kernel_jax(X1, X2, gamma=1.0, sigma_f=1.0):
    """Evaluate the RBF (squared-exponential) kernel between two point sets.

    Computes ``sigma_f**2 * exp(-gamma * ||X1[i] - X2[j]||**2)`` for every
    pair of rows.

    Args:
        X1: 2-D array of shape (n, d), one point per row.
        X2: 2-D array of shape (m, d), one point per row.
        gamma: factor multiplying the squared distance inside the exponential.
        sigma_f: amplitude; the kernel is scaled by ``sigma_f**2``.

    Returns:
        Array of shape (n, m) with the pairwise kernel values.
    """
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, evaluated for all pairs at once.
    sqdist = jnp.sum(X1**2, 1).reshape(-1, 1) + jnp.sum(X2**2, 1) - 2 * jnp.dot(X1, X2.T)
    # The expansion subtracts nearly equal numbers, so rounding can push a
    # distance slightly below zero, which would make the kernel exceed
    # sigma_f**2. A squared distance is never negative, so clamp it.
    sqdist = jnp.maximum(sqdist, 0.0)
    return sigma_f**2 * jnp.exp(-gamma * sqdist)

def rbf_kernel_single(x1, x2, gamma):
    """RBF kernel value ``exp(-gamma * ||x1 - x2||**2)`` for one pair of points.

    Args:
        x1: first point (scalar or array).
        x2: second point, broadcastable against ``x1``.
        gamma: factor multiplying the squared distance inside the exponential.

    Returns:
        Scalar (0-d) array with the kernel value.

    Note:
        A later cell redefines ``rbf_kernel_single`` with a ``params``
        argument; after that cell has run, this three-argument version is
        shadowed.
    """
    # The inner sum already reduces to a scalar, so no outer reduction is needed.
    return jnp.exp(-gamma * jnp.sum((x1 - x2)**2))
# Integer sample points 1, 2, 3 stored as (3, 1) column vectors.
x = np.array([[1], [2], [3]])
y = x.copy()


Now I implement the derivatives of the kernel.  
I want to do this for the Helmholtz equation:  
\begin{align}   
 \mathcal{L}_x ^\nu = \frac{\partial^2}{\partial x^2} - \nu^2 
\end{align}     
The first step is to apply the differential operator to the kernel function twice, once with respect to each argument, so that:
\begin{align}
K_{new}(x,x') = \mathcal{L}_x ^\nu \mathcal{L}_{x'} ^\nu K(x,x')
\end{align}   

\begin{align}   
    \mathcal{L}_x ^\nu \mathcal{L}_{x'} ^\nu k(x, x') = \left(\frac{\partial^2}{\partial x^2} - \nu^2\right)\left(\frac{\partial^2}{\partial x'^2} - \nu^2\right)k(x, x') 
\end{align} 
this then results in:
\begin{align*}
\mathcal{L}\mathcal{L}k(x, x') = & \frac{\partial^4}{\partial x^2 \partial x'^2} k(x, x') - \nu^2 \frac{\partial^2}{\partial x^2} k(x, x') - \nu^2 \frac{\partial^2}{\partial x'^2} k(x, x') + \nu^4 k(x, x')
\end{align*}
Because the RBF kernel depends only on the difference $x - x'$, its second derivatives with respect to $x$ and with respect to $x'$ are equal, so the two middle terms can be combined:
\begin{align*}  
\mathcal{L}\mathcal{L}k(x, x') = & \frac{\partial^4}{\partial x^2 \partial x'^2} k(x, x') - 2\nu^2 \frac{\partial^2}{\partial x^2} k(x, x') + \nu^4 k(x, x')
\end{align*}
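To make this clearer, I will call the fourth-derivative term $A$ and the second-derivative term $B$ (written below without its prefactor $-2\nu^2$, which the code applies inside `B_scratch` and `B_autograd`); the third term needs no further calculation.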

In the following I will try to implement this in code.  
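The expressions below are written for $\sigma_f = 1$; for the kernel $k(x, y) = \sigma_f^2 \mathrm{e}^{-\gamma (x - y)^2}$ used in the code, both are additionally multiplied by $\sigma_f^2$.  
For A we get: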
\begin{align*}
A_{ij} = (16{\gamma}^4\cdot\left(x_i-y_j\right)^4-48{\gamma}^3\cdot\left(x_i-y_j\right)^2+12{\gamma}^2)\mathrm{e}^{-{\gamma}\cdot\left(x_i-y_j\right)^2}
\end{align*}
For B we get:
\begin{align*}
B_{ij} = (4{\gamma}^2\cdot\left(x_i-y_j\right)^2-2{\gamma})\mathrm{e}^{-{\gamma}\cdot\left(x_i-y_j\right)^2}
\end{align*}


A has to be computed elementwise, just like the kernel itself. I will now do this once with the closed-form expressions and once with autograd.

In [205]:
def A_scratch(x,x_bar, hyperparameters):
    """Closed-form fourth-derivative term ``d^4 k / (dx^2 dx'^2)`` of the RBF kernel.

    Evaluates ``(16*gamma**4*r**4 - 48*gamma**3*r**2 + 12*gamma**2) * k(x_i, x_bar_j)``
    with ``r = x_i - x_bar_j`` for every pair of points, where ``k`` is
    ``rbf_kernel_jax``.

    Args:
        x: column vector of shape (n, 1) holding scalar points.
        x_bar: column vector of shape (m, 1) holding scalar points.
        hyperparameters: sequence whose first two entries are ``gamma`` and
            ``sigma_f``; any further entries are not used here.

    Returns:
        (n, m) array; entry ``[i, j]`` belongs to the pair ``(x[i], x_bar[j])``.

    Raises:
        ValueError: if ``x`` or ``x_bar`` is not an (n, 1) column vector.
    """
    gamma, sigma_f = hyperparameters[0], hyperparameters[1]

    x_col = np.asarray(x, dtype=float)
    x_bar_col = np.asarray(x_bar, dtype=float)
    for name, arr in (("x", x_col), ("x_bar", x_bar_col)):
        if arr.ndim != 2 or arr.shape[1] != 1:
            raise ValueError(f"{name} must have shape (n, 1), got {arr.shape}")

    kernel_values = rbf_kernel_jax(x, x_bar, gamma, sigma_f)

    # (n, 1) - (1, m) broadcasts to all pairwise differences at once. This
    # replaces the per-element Python loop, which stored a 1-element array
    # into a scalar slot.
    dist_sq = (x_col - x_bar_col.T) ** 2
    factor = 16*gamma**4*dist_sq**2 - 48*gamma**3*dist_sq + 12*gamma**2
    return factor * kernel_values
def B_scratch(x,x_bar, hyperparameters):
    """Closed-form second-derivative term ``-2 * nu**2 * d^2 k / dx^2`` of the RBF kernel.

    Evaluates ``-2*nu**2 * (4*gamma**2*r**2 - 2*gamma) * k(x_i, x_bar_j)`` with
    ``r = x_i - x_bar_j`` for every pair of points, where ``k`` is
    ``sigma_f**2`` times scikit-learn's ``rbf_kernel``.

    Args:
        x: column vector of shape (n, 1) holding scalar points.
        x_bar: column vector of shape (m, 1) holding scalar points.
        hyperparameters: sequence ``[gamma, sigma_f, nu, ...]``.

    Returns:
        (n, m) array; entry ``[i, j]`` belongs to the pair ``(x[i], x_bar[j])``.

    Raises:
        ValueError: if ``x`` or ``x_bar`` is not an (n, 1) column vector.
    """
    gamma, sigma_f = hyperparameters[0], hyperparameters[1]
    nu = hyperparameters[2]

    x_col = np.asarray(x, dtype=float)
    x_bar_col = np.asarray(x_bar, dtype=float)
    for name, arr in (("x", x_col), ("x_bar", x_bar_col)):
        if arr.ndim != 2 or arr.shape[1] != 1:
            raise ValueError(f"{name} must have shape (n, 1), got {arr.shape}")

    kernel_values = sigma_f**2*rbf_kernel_sklearn(x, x_bar, gamma=gamma)

    # (n, 1) - (1, m) broadcasts to all pairwise differences at once. This
    # replaces the per-element Python loop, which stored a 1-element array
    # into a scalar slot.
    dist_sq = (x_col - x_bar_col.T) ** 2
    factor = -2*nu**2*(4*gamma**2 * dist_sq - 2*gamma)
    return factor * kernel_values
def C_scratch(x,x_bar, hyperparameters):
    """Zeroth-order term ``nu**4 * k(x, x_bar)`` of the transformed kernel.

    ``k`` is ``sigma_f**2`` times scikit-learn's ``rbf_kernel`` evaluated with
    ``gamma``.

    Args:
        x: first set of points, passed to ``rbf_kernel`` unchanged.
        x_bar: second set of points, passed to ``rbf_kernel`` unchanged.
        hyperparameters: sequence ``[gamma, sigma_f, nu, ...]``.

    Returns:
        The scaled kernel matrix.
    """
    gamma = hyperparameters[0]
    sigma_f = hyperparameters[1]
    nu = hyperparameters[2]

    scaled_kernel = sigma_f**2 * rbf_kernel_sklearn(x, x_bar, gamma=gamma)
    return nu**4 * scaled_kernel

In [236]:
from jax import vmap, jit
from functools import partial
# Float sample points 1.0, 2.0, 3.0 as flat 1-D arrays, plus a module-level gamma.
x = np.arange(1.0, 4.0)
y = x.copy()
gamma = 1
@jit
def rbf_kernel_single(x1, x2, params):
    """RBF kernel value for one pair of points, compiled with ``jit``.

    Args:
        x1: first point (scalar or array).
        x2: second point, broadcastable against ``x1``.
        params: sequence whose first two entries are ``gamma`` and ``sigma_f``.

    Returns:
        Scalar ``sigma_f**2 * exp(-gamma * sum((x1 - x2)**2))``.
    """
    gamma = params[0]
    sigma_f = params[1]
    delta = x1 - x2
    sq_norm = jnp.sum(delta**2)
    return sigma_f**2*jnp.exp(-gamma * sq_norm)

#done with help of https://jejjohnson.github.io/research_notebook/content/notes/kernels/kernel_derivatives.html

def A_autograd(x,x_bar, hyperparameters):
    """Fourth-derivative term ``d^4 k / (dx^2 dx'^2)`` of the RBF kernel via autodiff.

    Only works for 1-D arrays of scalar points at the moment.

    Args:
        x: 1-D array of points used as the kernel's first argument.
        x_bar: 1-D array of points used as the kernel's second argument.
        hyperparameters: sequence whose first two entries (``gamma``,
            ``sigma_f``) are forwarded to ``rbf_kernel_single``.

    Returns:
        Array of shape (len(x), len(x_bar)); entry ``[i, j]`` is the derivative
        evaluated at ``(x[i], x_bar[j])``.
    """
    kernel_params = hyperparameters[:2]

    # Differentiate twice with respect to the first argument, then twice with
    # respect to the second argument.
    d2_dx1 = grad(grad(rbf_kernel_single, argnums=0), argnums=0)
    d4_dx1_dx2 = grad(grad(d2_dx1, argnums=1), argnums=1)

    # The inner vmap sweeps over x_bar and the outer one over x, so the scalar
    # derivative is evaluated on the full grid of pairs. The hyperparameters
    # are passed through unmapped.
    over_x_bar = vmap(d4_dx1_dx2, in_axes=(None, 0, None))
    over_all_pairs = vmap(over_x_bar, in_axes=(0, None, None))

    return over_all_pairs(x, x_bar, kernel_params)


def A_autograd_2(x,x_bar, hyperparameters):
    """Loop-based cross-check for ``A_autograd``; much slower than the vmapped version.

    Args:
        x: array of points used as the kernel's first argument, indexed along axis 0.
        x_bar: array of points used as the kernel's second argument, indexed along axis 0.
        hyperparameters: sequence whose first two entries (``gamma``,
            ``sigma_f``) are forwarded to ``rbf_kernel_single``.

    Returns:
        NumPy array of shape (n, m); entry ``[i, j]`` is the fourth derivative
        (twice in each argument) evaluated at ``(x[i], x_bar[j])``.
    """
    params_rbf = hyperparameters[:2]

    # The derivative function does not depend on the loop indices, so build it
    # once instead of re-creating the four-fold grad chain for every pair.
    fourth_derivative = grad(grad(grad(grad(rbf_kernel_single, argnums=1), argnums=1), argnums=0), argnums=0)

    n, m = x.shape[0], x_bar.shape[0]
    dk_ff = np.zeros((n, m))
    for i in range(n):
        for j in range(m):
            dk_ff[i, j] = fourth_derivative(x[i], x_bar[j], params_rbf)
    return dk_ff


In [237]:
hyperparameters = [2,2,4]  # [gamma, sigma_f, nu]

# Column vectors for the closed-form ("scratch") implementation.
x = np.linspace(0,1,4).reshape(-1,1)
y = np.linspace(0,1,4).reshape(-1,1)
# The same points as flat 1-D arrays for the autograd implementation.
x_1 = np.linspace(0,1,4)
y_1 = np.linspace(0,1,4)

# Evaluate each implementation once and reuse the result for printing and
# for the comparison, instead of computing every matrix twice.
A_scratch_values = A_scratch(x,y,hyperparameters)
A_autograd_values = A_autograd(x_1,y_1,hyperparameters)

print(A_scratch_values)
#print(A_autograd_2(x_1,y_1,hyperparameters))  # optional slow loop-based cross-check
print(A_autograd_values)
print(np.allclose(A_autograd_values,A_scratch_values))


[[ 192.         27.2053   -118.562744  -43.30729 ]
 [  27.2053    192.         27.2053   -118.56274 ]
 [-118.562744   27.2053    192.00003    27.205303]
 [ -43.30729  -118.56274    27.205303  192.      ]]
[[ 192.         27.205296 -118.56275   -43.307293]
 [  27.205296  192.         27.205296 -118.56276 ]
 [-118.56275    27.205296  192.         27.205307]
 [ -43.307293 -118.56276    27.205307  192.      ]]
True


Perfect! Both are the same.


In [238]:
def B_autograd(x,x_bar, hyperparameters):
    """Second-derivative term ``-2 * nu**2 * d^2 k / dx'^2`` of the RBF kernel via autodiff.

    Args:
        x: 1-D array of points used as the kernel's first argument.
        x_bar: 1-D array of points used as the kernel's second argument.
        hyperparameters: sequence ``[gamma, sigma_f, nu, ...]``; the first two
            entries are forwarded to ``rbf_kernel_single``.

    Returns:
        Array of shape (len(x), len(x_bar)); entry ``[i, j]`` belongs to the
        pair ``(x[i], x_bar[j])``.
    """
    params_rbf = hyperparameters[:2]
    nu = hyperparameters[2]

    # Only the second derivative with respect to the second argument is needed.
    second_derivative_x2 = grad(grad(rbf_kernel_single, argnums=1), argnums=1)
    # The inner vmap sweeps over x_bar and the outer one over x; the
    # hyperparameters are passed through unmapped.
    second_derivative_x2_vectorized = vmap(vmap(second_derivative_x2, in_axes=(None, 0, None)), in_axes=(0, None, None))

    return -2*nu**2*second_derivative_x2_vectorized(x, x_bar, params_rbf)
def C_vectorized(x,x_bar,hyperparameters):
    """Zeroth-order term ``nu**4 * k(x, x_bar)`` evaluated on all pairs with ``vmap``.

    Args:
        x: array of points used as the kernel's first argument, mapped over axis 0.
        x_bar: array of points used as the kernel's second argument, mapped over axis 0.
        hyperparameters: sequence ``[gamma, sigma_f, nu, ...]``; the first two
            entries are forwarded to ``rbf_kernel_single``.

    Returns:
        Array whose entry ``[i, j]`` is ``nu**4 * k(x[i], x_bar[j])``.
    """
    kernel_params = hyperparameters[:2]
    nu = hyperparameters[2]

    # The inner vmap sweeps over x_bar and the outer one over x.
    pairwise_kernel = vmap(
        vmap(rbf_kernel_single, in_axes=(None, 0, None)),
        in_axes=(0, None, None),
    )
    return nu**4*pairwise_kernel(x, x_bar, kernel_params)
# Compare the autograd terms with the closed-form ("scratch") reference.
hyperparameters = [1,3,4]  # [gamma, sigma_f, nu]

# Evaluate each implementation once and reuse the result for printing and
# for the comparison, instead of computing every matrix twice.
B_autograd_values = B_autograd(x_1,y_1,hyperparameters)
B_scratch_values = B_scratch(x,y,hyperparameters)
C_vectorized_values = C_vectorized(x_1,y_1,hyperparameters)
C_scratch_values = C_scratch(x,y,hyperparameters)

print("B - comparison")
print(B_autograd_values)
print(B_scratch_values)
print(np.allclose(B_autograd_values,B_scratch_values), "B","\n")
print("C - comparison")
print(C_vectorized_values)
print(C_scratch_values)
print(np.allclose(C_vectorized_values,C_scratch_values),"C")
    

B - comparison
[[ 576.        400.888      41.035553 -211.89856 ]
 [ 400.888     576.        400.888      41.035583]
 [  41.035553  400.888     576.        400.88806 ]
 [-211.89856    41.035583  400.88806   576.      ]]
[[ 576.          400.88801393   41.03554486 -211.89855811]
 [ 400.88801393  576.          400.88801393   41.03554486]
 [  41.03554486  400.88801393  576.          400.88801393]
 [-211.89855811   41.03554486  400.88801393  576.        ]]
True B 

C - comparison
[[2304.      2061.7097  1477.2797   847.59424]
 [2061.7097  2304.      2061.7097  1477.2797 ]
 [1477.2797  2061.7097  2304.      2061.71   ]
 [ 847.59424 1477.2797  2061.71    2304.     ]]
[[2304.         2061.70978594 1477.27961494  847.59423246]
 [2061.70978594 2304.         2061.70978594 1477.27961494]
 [1477.27961494 2061.70978594 2304.         2061.70978594]
 [ 847.59423246 1477.27961494 2061.70978594 2304.        ]]
True C


Again the same result, so the implementation should be correct.  
Now let's implement the kernel with the derivatives.  
First I want to compute the marginal log likelihood in order to optimize the hyperparameters $\gamma, \sigma_f$ and the inferred parameter $\nu$ of the Helmholtz equation.  
The marginal log likelihood is given by:
\begin{align*}
\log p(\mathbf{y} \mid \mathbf{X}, \gamma, \sigma_f, \nu) = & -\frac{1}{2}\mathbf{y}^T\left(K_{new} + \sigma_n^2\mathbf{I}\right)^{-1}\mathbf{y} - \frac{1}{2}\log\left|K_{new} + \sigma_n^2\mathbf{I}\right| - \frac{n}{2}\log 2\pi
\end{align*}

In [239]:
def create_derivative_matrix_scratch(x_train, x_test,noise, hyperparameters):
    """Assemble the transformed kernel matrix from the closed-form A, B and C terms.

    Args:
        x_train: points passed as the first argument of each term.
        x_test: points passed as the second argument of each term.
        noise: value added on the diagonal.
        hyperparameters: sequence forwarded unchanged to ``A_scratch``,
            ``B_scratch`` and ``C_scratch``.

    Returns:
        ``A + B + C + noise * I`` with ``I`` of size ``len(x_train)``.
    """
    # NOTE(review): the identity is square with side len(x_train), so this
    # presumably expects x_test to hold the same number of points - confirm.
    # NOTE(review): `noise` is added as-is, i.e. it seems to play the role of
    # sigma_n^2 in the marginal-likelihood formula - confirm.
    terms = (
        A_scratch(x_train, x_test, hyperparameters),
        B_scratch(x_train, x_test, hyperparameters),
        C_scratch(x_train, x_test, hyperparameters),
    )
    diagonal_noise = noise * np.eye(len(x_train))
    return terms[0] + terms[1] + terms[2] + diagonal_noise

def create_derivative_matrix_jax(x_train, x_test,noise, hyperparameters):
    """Assemble the transformed kernel matrix from the JAX-based A, B and C terms.

    Args:
        x_train: points passed as the first argument of each term.
        x_test: points passed as the second argument of each term.
        noise: value added on the diagonal.
        hyperparameters: sequence forwarded unchanged to ``A_autograd``,
            ``B_autograd`` and ``C_vectorized``.

    Returns:
        ``A + B + C + noise * I`` with ``I`` of size ``len(x_train)``.
    """
    # NOTE(review): the identity is square with side len(x_train), so this
    # presumably expects x_test to hold the same number of points - confirm.
    # NOTE(review): `noise` is added as-is, i.e. it seems to play the role of
    # sigma_n^2 in the marginal-likelihood formula - confirm.
    terms = (
        A_autograd(x_train, x_test, hyperparameters),
        B_autograd(x_train, x_test, hyperparameters),
        C_vectorized(x_train, x_test, hyperparameters),
    )
    diagonal_noise = noise * np.eye(len(x_train))
    return terms[0] + terms[1] + terms[2] + diagonal_noise

In [241]:
size = 10
x_1, y_1 = np.linspace(0,1,size), np.linspace(0,1,size)
hyperparameters = [1,3,4]
x,y = np.linspace(0,1,size).reshape(-1,1), np.linspace(0,1,size).reshape(-1,1)

%timeit create_derivative_matrix_jax(x_1,y_1,0.001,hyperparameters)
#%timeit create_derivative_matrix_scratch(x,y,0.001, hyperparameters)

31.3 ms ± 2.57 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
