In [2]:
from functools import partial
from importlib import reload

import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)
from jax import jit,grad,hessian,jacfwd,jacrev
from jax import value_and_grad
import numpy as np
import matplotlib.pyplot as plt
import jax
from tqdm.auto import tqdm
plt.style.use("ggplot")

import KernelTools
reload(KernelTools)
from KernelTools import *
from EquationModel import InducedRKHS,OperatorModel,CholOperatorModel
from parabolic_data_utils import (
    build_burgers_data,build_tx_grid,
    build_tx_grid_chebyshev,setup_problem_data
)
from plotting import plot_input_data,plot_compare_error
from evaluation_metrics import compute_results


from Kernels import (
    get_centered_scaled_poly_kernel,
    get_anisotropic_gaussianRBF,
    fit_kernel_params,
    setup_matern,log1pexp,inv_log1pexp,get_gaussianRBF
)

In [3]:
# Kernel families built by `setup_matern` (the integer argument's meaning is
# defined in Kernels.setup_matern). `base_matern_family` is not used in the
# visible cells. Previously tried alternative for the base family: get_gaussianRBF.
base_matern_family = setup_matern(20)
base_kernel_family = setup_matern(3)


def param_kernel(x, y, params):
    """Evaluate ``base_kernel_family(1.)`` on diagonally rescaled inputs.

    The unconstrained ``params`` are mapped through ``log1pexp`` (presumably a
    softplus, i.e. positive output) to per-dimension weights ``w``. Both inputs
    are multiplied by ``diag(sqrt(w))`` before the base kernel is applied.

    Because ``w`` multiplies the inputs, ``w`` behaves like an inverse squared
    length scale. The effective length scale is ``1 / sqrt(w)``, which is the
    quantity the fitting cell prints.
    """
    positive_weights = log1pexp(params)
    scaling = jnp.diag(jnp.sqrt(positive_weights))
    kernel = base_kernel_family(1.)
    return kernel(scaling @ x, scaling @ y)



# NOTE(review): `tx_obs` and `u_obs` are not defined in any visible cell
# (presumably produced by setup_problem_data); the saved run of this cell was
# also interrupted (KeyboardInterrupt). Confirm before Restart & Run All.
init_params = jnp.zeros(2)
fitted_params = fit_kernel_params(param_kernel, tx_obs, u_obs, init_params)

# param_kernel scales inputs by sqrt(w), so the effective length scales are 1/sqrt(w).
ML_lengthscales = log1pexp(fitted_params)
print(1 / jnp.sqrt(ML_lengthscales))

KeyboardInterrupt: 

In [None]:
from functools import partial

from jax import value_and_grad


def fit_kernel_params(parametrized_kernel, X, y, init_params, nugget=1e-7):
    """Fit kernel parameters by minimising the negative log marginal likelihood.

    NOTE(review): this redefines the ``fit_kernel_params`` imported from
    ``Kernels`` in the first cell; once this cell runs, later cells use this version.

    Parameters
    ----------
    parametrized_kernel : callable
        ``parametrized_kernel(x, y, params)``; ``params`` is frozen with
        ``functools.partial`` and the result is vectorised with ``vectorize_kfunc``.
    X : array
        Observation locations, passed to the vectorised kernel as-is.
    y : array
        Observed values, shape ``(n,)`` or ``(n, m)``. For a matrix, each of
        the ``m`` columns is treated as an independent output sharing the kernel.
    init_params : array
        Starting point for the optimiser.
    nugget : float
        ``nugget * diagpart(K)`` is added to the Gram matrix for stability
        (presumably ``diagpart`` returns the diagonal of ``K`` as a matrix).

    Returns
    -------
    The optimised parameters (``result.params`` of ``GradientDescent.run``).
    """
    # Number of independent output columns; each contributes one log|K| term.
    n_outputs = 1 if jnp.ndim(y) == 1 else jnp.shape(y)[1]

    @jit
    @value_and_grad
    def marginal_like(params):
        vmapped_kfunc = vectorize_kfunc(partial(parametrized_kernel, params=params))
        K = vmapped_kfunc(X, X)
        K = K + nugget * diagpart(K)
        # Solve K @ alpha = y rather than forming inv(K) explicitly: same value
        # in exact arithmetic, but better conditioned and cheaper.
        alpha = jnp.linalg.solve(K, y)
        # Equals y^T K^{-1} y for a vector y (summed over columns for a matrix y).
        # Always a 0-d scalar, which value_and_grad requires; the old
        # `y.T @ inv(K) @ y` gave a (1, 1) array for column-vector y.
        data_fit = jnp.sum(y * alpha)
        logdet = jnp.linalg.slogdet(K).logabsdet
        return 0.5 * data_fit + 0.5 * n_outputs * logdet

    solver = GradientDescent(marginal_like, value_and_grad=True, jit=True, tol=1e-5)
    result = solver.run(init_params)
    return result.params


In [None]:
def parametrized_kernel(x, y, params):
    """Evaluate ``base_kernel_family(1.)`` after scaling both inputs by ``diag(log1pexp(params))``.

    Unlike ``param_kernel``, which scales by ``sqrt(log1pexp(params))``, the
    weights here multiply the inputs directly. The effective length scale is
    therefore ``1 / log1pexp(params)``.
    NOTE(review): confirm which of the two parametrisations is intended.
    """
    scale_matrix = jnp.diag(log1pexp(params))
    base_kernel = base_kernel_family(1.)
    return base_kernel(scale_matrix @ x, scale_matrix @ y)

from functools import partial

from jax import value_and_grad

# Globals read by `marginal_like`. Because it is jit-compiled, their values are
# baked in at the first trace: re-run this cell after changing any of them.
nugget = 1e-8
X = tx_obs
y = u_obs


@jit
@value_and_grad
def marginal_like(params):
    """Return ``(value, gradient)`` of the negative log marginal likelihood.

    The value is ``0.5 * y^T K^{-1} y + 0.5 * log|K|``, up to an additive
    constant. ``K`` is the ``parametrized_kernel`` Gram matrix on the global
    ``X``, plus ``nugget * diagpart(K)``. For a matrix ``y`` of shape ``(n, m)``,
    the columns are treated as independent outputs.
    """
    vmapped_kfunc = vectorize_kfunc(partial(parametrized_kernel, params=params))
    K = vmapped_kfunc(X, X)
    K = K + nugget * diagpart(K)
    # Solve instead of forming inv(K): better conditioned and cheaper.
    alpha = jnp.linalg.solve(K, y)
    n_outputs = 1 if jnp.ndim(y) == 1 else jnp.shape(y)[1]
    # jnp.sum(...) is always 0-d, as value_and_grad requires; `y.T @ inv(K) @ y`
    # gave a (1, 1) array for a column-vector y.
    data_fit = jnp.sum(y * alpha)
    return 0.5 * data_fit + 0.5 * n_outputs * jnp.linalg.slogdet(K).logabsdet

In [None]:
# Fixed-step gradient descent on `marginal_like`, recording the objective per step.
LEARNING_RATE = 1e-4
N_STEPS = 1000

params = jnp.array([0., 0.])
func_vals = []
for _ in range(N_STEPS):
    val, gradval = marginal_like(params)
    func_vals.append(val)
    params = params - LEARNING_RATE * gradval

In [None]:
# Probe the objective and its gradient at a single point.
# NOTE(review): this overwrites `params` from the descent loop, and the cells
# below read `params` — confirm which value they are meant to use.
params = 5.0 * jnp.ones(2)
val, gradval = marginal_like(params)

In [None]:
example_X = jnp.array([[0., 0.]])


def get_mat(params):
    """Return the sum of all entries of the ``parametrized_kernel`` Gram matrix on the global ``X``.

    ``example_X`` (defined above) is not used here.
    """
    kernel_with_params = partial(parametrized_kernel, params=params)
    gram = vectorize_kfunc(kernel_with_params)(X, X)
    return jnp.sum(gram)

In [None]:
# Two-argument kernel k(x, y) with the current global `params` frozen in.
k = partial(
    parametrized_kernel,
    params=params,
)

In [None]:
# Gradient of k with respect to its first argument, at x = (0, 0), y = (0, 1).
dk_dx = jax.grad(k, argnums=0)
dk_dx(jnp.array([0., 0.]), jnp.array([0., 1.]))

In [None]:
example_X = jnp.array([[0., 0.]])


def param_func(params):
    """Evaluate ``parametrized_kernel`` at a nearly coincident pair of points, for the given ``params``."""
    # The tiny offset keeps the two points distinct (presumably to avoid
    # differentiating the kernel exactly at zero distance). It relies on
    # float64, enabled via jax_enable_x64 in the first cell; in float32,
    # 1e-100 underflows to 0.
    offset = 1e-100
    anchor = example_X[0]
    return parametrized_kernel(anchor, anchor + offset, params=params)


jax.grad(param_func)(params)

In [None]:
from functools import partial

# Gram matrix of `param_kernel` on the observation points, with params fixed at 3.
# NOTE(review): this displays the full n x n matrix; consider `.shape` or a slice.
probe_kernel = partial(param_kernel, params=3 * jnp.ones(2))
vectorize_kfunc(probe_kernel)(tx_obs, tx_obs)