In [18]:
import jax
import jax.numpy as jnp
from jax import jacobian, grad, hessian, vmap

In [19]:
def get_rbf(gamma):
    """Build a Gaussian (RBF) kernel with the given width parameter.

    Args:
        gamma: Scale factor applied to the squared distance in the exponent.

    Returns:
        A function ``kernel(x, y)`` computing
        ``exp(-gamma * sum((x - y)**2))``.
    """
    def kernel(x, y):
        # Squared Euclidean distance, summed over every entry of x - y.
        diff = x - y
        sq_dist = jnp.sum(diff**2)
        return jnp.exp(-gamma * sq_dist)

    return kernel

# For simpler autodiff tests
def poly():
    """Return a simple polynomial test function of two vector arguments.

    Two candidates are defined; only ``f1`` is returned. ``f`` is kept as an
    alternative test function but is currently unused.

    Returns:
        ``f1(x, y) = dot(x**4, y**4)``.
    """
    def f(x, y):
        # Separable cubic: sum_i (x_i^3 + 2 * y_i^3). Not returned.
        cubic_terms = x**3 + 2*y**3
        return jnp.sum(cubic_terms)

    def f1(x, y):
        # Coupled quartic: sum_i x_i^4 * y_i^4.
        x_quartic = x**4
        y_quartic = y**4
        return jnp.dot(x_quartic, y_quartic)

    return f1

def get_jac(f, argnums):
    """Build a function that evaluates the Jacobian of ``f`` at a point.

    Args:
        f: Function of two arguments ``(x, y)``.
        argnums: Which positional argument(s) of ``f`` to differentiate with
            respect to; forwarded to ``jax.jacobian``.

    Returns:
        A function ``jac_fn(x, y)`` returning the Jacobian of ``f`` with
        respect to the selected argument(s), evaluated at ``(x, y)``.
    """
    # The inner function must not be called `jacobian`: that name would shadow
    # the imported jax transform inside its own body and the call would recurse
    # into itself (failing on the unexpected `argnums` keyword). Referring to
    # `jax.jacobian` explicitly also guards against any later shadowing.
    def jac_fn(x, y):
        return jax.jacobian(f, argnums=argnums)(x, y)

    return jac_fn

def get_lap(f, argnums):
    """Build a function that evaluates the Laplacian of ``f`` in one argument.

    The Laplacian is computed as the trace of the Hessian of ``f`` taken with
    respect to the argument selected by ``argnums``.

    Args:
        f: Scalar-valued function of two arguments ``(x, y)``.
        argnums: Which positional argument of ``f`` to differentiate with
            respect to; forwarded to ``hessian``.

    Returns:
        A function ``laplacian(x, y)`` returning ``trace(hessian(f)(x, y))``.
    """
    def laplacian(x, y):
        hess_fn = hessian(f, argnums=argnums)
        hess_at_point = hess_fn(x, y)
        # Sum of the pure second derivatives = trace of the Hessian.
        return jnp.trace(hess_at_point)

    return laplacian

In [20]:
# RBF kernel k(x, y) = exp(-0.5 * ||x - y||^2) used throughout the cells below.
k = get_rbf(gamma=0.5)

Now define the Laplacian and bi-Laplacian functions. To ensure these functions work on batches of points (we want to evaluate kernel *matrices*), we can use vmap to vectorize them. In practice, we want to evaluate $k(X_I, X_B)$, where $X_I: (n, d)$ and $X_B: (m, d)$ are the $d$-dimensional interior and boundary points, arranged in rows. The following functions should take $X_I$ and $X_B$ and return a matrix $f(X_I, X_B)$ satisfying $f(X_I, X_B)_{ij} = f((X_I)_i, (X_B)_j)$.


Use vmap twice. The inner vectorization (in_axes = (None, 0)) makes sure f works on arguments of the form $(x, X_B)$, i.e. a single point paired with a batch of points, and the outer one (in_axes = (0, None)) allows the first argument to also be a matrix. For instance, using vmap on $f(X_I, x)$ with in_axes = (0, None) means evaluating $f((X_I)_{i}, x)$ for each row $i$ of $X_I$.

In [22]:
# Laplacian of the kernel in its first input x, then the Laplacian of that
# result in the second input y (the "bi-Laplacian").
lap_x = get_lap(k, 0)
bilap = get_lap(lap_x, 1)

# Batch both functions so that every row of the first argument is paired with
# every row of the second: the inner vmap sweeps the rows of the second
# argument for a fixed first argument, and the outer vmap sweeps the rows of
# the first argument.
vec_lapx, vec_bilap = (
    vmap(vmap(fn, in_axes=(None, 0)), in_axes=(0, None))
    for fn in (lap_x, bilap)
)

In [23]:
key = jax.random.key(1039)
# Never reuse a PRNG key for two draws: JAX random functions are deterministic
# in the key, so sampling X_I and X_B from the same key yields correlated
# points. Split the key into one independent subkey per sample set.
key_I, key_B = jax.random.split(key)

I = 200  # number of interior points
B = 100  # number of boundary points

# Interior and boundary points in 2-D, one point per row.
X_I = jax.random.uniform(key_I, shape=(I, 2))
X_B = jax.random.uniform(key_B, shape=(B, 2))

# Sanity check: the batched Laplacian returns one entry per (interior, boundary) pair.
print(vec_lapx(X_I, X_B).shape)

(200, 100)
