In [277]:
import sys
sys.path.append('../')
import gp

from jax.config import config; config.update("jax_enable_x64", True)
import jax
import jax.numpy as np
import jax.scipy as scp
from jax import jit
from scipy.special import gamma, kv

from gp.metric import euclid_distance
from gp.utils import pairwise

@jit
def K_0p5(x1, x2, l, nu):
    """Closed-form kernel used by ``matern`` when ``nu == 0.5``: ``exp(-d / l)``.

    ``d`` is the result of ``pairwise(euclid_distance, square=False)`` applied
    to ``(x1, x2)`` (presumably the pairwise Euclidean distance matrix --
    confirm against ``gp.utils.pairwise``). ``l`` is the length scale.
    ``nu`` is not used here; it is only part of the signature so that all
    kernel helpers can be called the same way.
    """
    pairwise_dist = pairwise(euclid_distance, square=False)
    d = pairwise_dist(x1, x2)
    return np.exp(-d / l)

@jit
def K_1p5(x1, x2, l, nu):
    """Closed-form kernel used by ``matern`` when ``nu == 1.5``.

    With ``r = sqrt(3) * d / l`` the result is ``(1 + r) * exp(-r)``, where
    ``d`` comes from ``pairwise(euclid_distance, square=False)`` (presumably
    pairwise Euclidean distances -- confirm against ``gp.utils.pairwise``).
    ``nu`` is unused and kept only for a uniform helper signature.
    """
    pairwise_dist = pairwise(euclid_distance, square=False)
    r = pairwise_dist(x1, x2) / l * np.sqrt(3)
    decay = np.exp(-r)
    return (1. + r) * decay

@jit
def K_2p5(x1, x2, l, nu):
    """Closed-form kernel used by ``matern`` when ``nu == 2.5``.

    With ``r = sqrt(5) * d / l`` the result is
    ``(1 + r + r**2 / 3) * exp(-r)``, where ``d`` comes from
    ``pairwise(euclid_distance, square=False)`` (presumably pairwise
    Euclidean distances -- confirm against ``gp.utils.pairwise``).
    ``nu`` is unused and kept only for a uniform helper signature.
    """
    pairwise_dist = pairwise(euclid_distance, square=False)
    r = pairwise_dist(x1, x2) / l * np.sqrt(5)
    polynomial = 1. + r + r ** 2 / 3.0
    return polynomial * np.exp(-r)

@jit
def K_inf(x1, x2, l, nu):
    """Closed-form kernel used by ``matern`` when ``nu == inf``.

    Returns ``exp(-d2 / 2 / l**2)`` where ``d2`` comes from
    ``pairwise(euclid_distance, square=True)`` (presumably squared pairwise
    Euclidean distances -- confirm against ``gp.utils.pairwise``).
    ``nu`` is unused and kept only for a uniform helper signature.
    """
    sq_dist = pairwise(euclid_distance, square=True)(x1, x2)
    return np.exp(-sq_dist / 2.0 / l**2)

def K_other(x1, x2, l, nu):
    """General-``nu`` kernel used by ``matern`` for all non-special ``nu``.

    Computes ``2**(1 - nu) / Gamma(nu) * t**nu * kv(nu, t)`` with
    ``t = sqrt(2 * nu) * d / l``. ``d`` comes from
    ``pairwise(euclid_distance, square=False)`` (presumably pairwise
    Euclidean distances -- confirm against ``gp.utils.pairwise``).

    Unlike the closed-form helpers this function is not wrapped in ``jit``;
    it calls ``scipy.special.kv`` directly.
    """
    scaled = pairwise(euclid_distance, square=False)(x1, x2) / l
    # Exact zeros are swapped for machine epsilon before reaching kv
    # (presumably to avoid a nan from the singularity of kv at zero).
    tiny = np.finfo(float).eps
    scaled = np.where(scaled == 0, tiny, scaled)
    arg = np.sqrt(2 * nu) * scaled
    # Gamma(nu) is obtained as exp(gammaln(nu)).
    prefactor = (2 ** (1. - nu)) / np.exp(scp.special.gammaln(nu))
    return prefactor * arg**nu * kv(nu, arg)

def matern(x, y, l=1., nu=1.5):
    """Evaluate the Matern kernel between ``x`` and ``y``.

    Dispatches on ``nu``: 0.5, 1.5, 2.5 and ``inf`` are routed to the
    dedicated helpers ``K_0p5``, ``K_1p5``, ``K_2p5`` and ``K_inf``; any
    other value is handled by ``K_other``.

    Parameters
    ----------
    x, y : inputs forwarded unchanged to the selected helper.
    l : length scale, default ``1.``.
    nu : smoothness parameter, default ``1.5``.
    """
    # Checked in order; the first matching nu wins.
    special_cases = (
        (0.5, K_0p5),
        (1.5, K_1p5),
        (2.5, K_2p5),
        (np.inf, K_inf),
    )
    for special_nu, kernel in special_cases:
        if nu == special_nu:
            return kernel(x, y, l, nu)
    return K_other(x, y, l, nu)

In [278]:
from sklearn.gaussian_process import kernels

In [279]:
# Two small inputs with 3 columns each: `a` has 5 rows, `b` has 10 rows.
a = np.arange(0,15,1).reshape(-1,3)
b = np.arange(0,15,0.5).reshape(-1,3)
# nu = 5.5 is not one of the values special-cased in matern(), so the
# comparison below exercises the K_other branch.
nu = 5.5
l=3.
# scikit-learn Matern kernel built with the same length scale and nu.
KM = kernels.Matern(length_scale=l,nu=nu)

In [280]:
# Element-wise ratio of scikit-learn's Matern kernel to matern(); a matrix
# of ones means both implementations agree on these inputs.
KM(a,b)/matern(a,b,nu=nu,l=l)

DeviceArray([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float64)

In [281]:
%%timeit
# Baseline timing: scikit-learn's Matern kernel on the inputs a and b.
KM(a,b)

600 µs ± 28.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [282]:
%%timeit
matern(a,b,nu=nu)

38.9 ms ± 224 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
