In [2]:
import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)
from jax import jit,grad,hessian,jacfwd,jacrev
import numpy as np
import matplotlib.pyplot as plt
import jax
from tqdm.auto import tqdm
plt.style.use("ggplot")

from importlib import reload
import KernelTools
reload(KernelTools)
from KernelTools import *
from EquationModel import InducedRKHS,OperatorModel,CholOperatorModel
from parabolic_data_utils import (
    build_burgers_data,build_tx_grid,
    build_tx_grid_chebyshev,setup_problem_data
)
from plotting import plot_input_data,plot_compare_error
from evaluation_metrics import compute_results    
from data_utils import MinMaxScaler
from evaluation_metrics import get_nrmse

from Kernels import log1pexp,inv_log1pexp
from Kernels import (
    get_centered_scaled_poly_kernel,
    get_anisotropic_gaussianRBF,
    fit_kernel_params
)
from EquationModel import CholInducedRKHS, CholOperatorModel, OperatorPDEModel
from functools import partial

import LM_Solve
import importlib
importlib.reload(LM_Solve)
from LM_Solve import LevenbergMarquadtMinimize,adaptive_refine_solution

In [3]:
# import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import scipy.io
from scipy.interpolate import griddata
from scipy.spatial import distance
from matplotlib import cm
import time
from mpl_toolkits.mplot3d import Axes3D
from pyDOE import lhs
#    import sobol_seq
import os
# from sklearn.preprocessing import MinMaxScaler

In [4]:
def sparsifyDynamics(Theta,dXdt,lamb,n,n_iter=10):
    """Sparse regression by sequentially thresholded least squares (SINDy-style).

    Starts from the least-squares solution of ``Theta @ Xi ~= dXdt``. Then, for
    ``n_iter`` rounds, it zeroes every coefficient whose magnitude is below
    ``lamb`` and refits each of the first ``n`` columns of ``Xi`` using only
    the surviving terms.

    Parameters
    ----------
    Theta : array_like, shape (m, p)
        Library of p candidate terms evaluated at m samples.
    dXdt : array_like, shape (m, k)
        Target derivatives, one column per state dimension.
    lamb : float
        Sparsity threshold; coefficients with ``|c| < lamb`` are set to zero.
    n : int
        Number of columns of ``dXdt`` (state dimensions) to refit.
    n_iter : int, default 10
        Number of threshold-and-refit rounds.

    Returns
    -------
    Xi : numpy.ndarray, shape (p, k)
        Sparse coefficient matrix.
    """
    Theta = np.asarray(Theta)
    dXdt = np.asarray(dXdt)
    # Initial guess: dense least-squares fit
    Xi = np.linalg.lstsq(Theta,dXdt,rcond=None)[0]

    for _ in range(n_iter):
        smallinds = np.abs(Xi) < lamb  # find small coefficients...
        Xi[smallinds] = 0              # ...and threshold them
        for ind in range(n):  # n is the state dimension
            biginds = ~smallinds[:,ind]
            if not biginds.any():
                # Every term in this column was thresholded away; it stays zero
                # (and lstsq on an empty column slice is avoided).
                continue
            # Regress onto the remaining terms to find the sparse Xi column
            Xi[biginds,ind] = np.linalg.lstsq(
                Theta[:,biginds],dXdt[:,ind],rcond=None)[0]
    # Return only after all rounds (previously this was inside the loop,
    # so only a single round was ever performed).
    return Xi

We compare the `1.5 step`, `PINNSR`, and `SINDy` methods using the following point sets:

- Collocation point set: randomly selected from the full space–time grid.
- Data point set: randomly selected from the collocation point set.

Each method is evaluated by the error (computed with `get_nrmse`) of the recovered $u$ on the full reference data.

In this notebook we call the two functions `get_error_1_5` and `get_error_SINDy`; the `PINNSR` results are loaded later from saved data, because `PINNSR` runs in a different Python environment.

In [5]:
def get_error_1_5(n_obs, n_coll, run, data_path=None, verbose=True):
    """Run the 1.5-step kernel method on the Burgers data and return the u-field error.

    ``n_coll`` points are drawn without replacement from the full (t, x) grid
    as the collocation set. The key is fixed (``PRNGKey(0)``), so this set is
    identical across runs. ``n_obs`` of those points are then drawn without
    replacement (key ``PRNGKey(run)``) as the observed-data set.

    Parameters
    ----------
    n_obs : int
        Number of observed data points (subset of the collocation set).
    n_coll : int
        Number of collocation points.
    run : int
        Seed selecting which observation points are drawn.
    data_path : str, optional
        Path to ``burgers.mat``. If omitted, the ``BURGERS_DATA_PATH``
        environment variable is used, falling back to the original local path.
    verbose : bool, default True
        Print the collocation-set shape and the fitted kernel quantity
        ``1/sqrt(log1pexp(fitted_params))``.

    Returns
    -------
    error_u_field
        ``get_nrmse`` of the recovered u against the reference u on the full grid.

    Raises
    ------
    ValueError
        If ``n_obs > n_coll``.
    """
    if n_obs > n_coll:
        raise ValueError(
            f"n_obs ({n_obs}) cannot exceed n_coll ({n_coll}): data points are "
            "sampled without replacement from the collocation points."
        )
    if data_path is None:
        data_path = os.environ.get(
            'BURGERS_DATA_PATH',
            '/home/juanfelipe/Desktop/research/keql/examples/burgers/data/burgers.mat',
        )

    # ---- Load reference data and min-max scale t and x independently ----
    data = scipy.io.loadmat(data_path)
    scaler_t = MinMaxScaler()
    t = scaler_t.fit_transform(np.real(data['t'].flatten()[:, None]))
    scaler_x = MinMaxScaler()
    x = scaler_x.fit_transform(np.real(data['x'].flatten()[:, None]))
    # Reference u values
    Exact = np.real(data['usol'])

    # Fine meshgrid; meshgrid(t, x) yields arrays of shape (len(x), len(t)).
    # NOTE(review): assumes usol is stored with that same layout so the
    # flattened entries line up — confirm against the data file.
    T, X = np.meshgrid(t, x)
    X_star = np.hstack((T.flatten()[:, None], X.flatten()[:, None]))  # (t, x) pairs
    u_star = Exact.flatten()[:, None]

    # ---- Sample collocation and data point sets ----
    triplets_fine = np.hstack([X_star, u_star])  # rows: (t, x, u)
    triplets_all = jax.random.choice(
        key=jax.random.PRNGKey(0), a=triplets_fine, shape=(n_coll,), replace=False
    )
    tx_all = triplets_all[:, :2]  # collocation point set
    if verbose:
        print(tx_all.shape)

    triplets_obs = jax.random.choice(
        key=jax.random.PRNGKey(run), a=triplets_all, shape=(n_obs,), replace=False
    )
    tx_obs = triplets_obs[:, :2]  # data point set
    u_obs = triplets_obs[:, -1]

    # ---- 1.5-step method ----
    u_operators = (eval_k,)
    feature_operators = (eval_k, dx_k, dxx_k)

    # Kernel for u: anisotropic Gaussian RBF whose diagonal parameters are
    # log1pexp(params); params are fitted on the observed data.
    def param_ani_gaussian_RBF(x, y, params):
        lengthscales = log1pexp(params)
        return get_anisotropic_gaussianRBF(1., jnp.diag(lengthscales))(x, y)

    fitted_params, _ = fit_kernel_params(param_ani_gaussian_RBF, tx_obs, u_obs, jnp.zeros(2))
    ML_lengthscales = log1pexp(fitted_params)
    if verbose:
        print(1 / (jnp.sqrt(ML_lengthscales)))
    k_u = get_anisotropic_gaussianRBF(1., jnp.diag(ML_lengthscales))

    # RKHS model for u, built on the collocation points
    u_model = CholInducedRKHS(
        tx_all,
        u_operators,
        k_u,
        nugget_size=1e-8,
    )
    u_params_init = u_model.get_fitted_params(tx_obs, u_obs)

    # Features of the initial u fit at the collocation points, one column per
    # entry of feature_operators
    grid_features_init = (
        u_model.evaluate_operators(feature_operators, tx_all, u_params_init)
        .reshape(len(tx_all), len(feature_operators), order='F')
    )

    # Kernel for P: degree-2 centered/scaled polynomial kernel applied to
    # entries 2: of its input (presumably the u features after (t, x) — verify).
    k_P_u_part = get_centered_scaled_poly_kernel(2, grid_features_init, c=1., scaling='diagonal')

    @vectorize_kfunc
    def k_P(x, y):
        return k_P_u_part(x[2:], y[2:])

    P_model = OperatorModel(k_P)

    # Joint equation model coupling u and P, with dt_k as the right-hand side
    EqnModel = OperatorPDEModel(
        P_model,
        (u_model,),
        (tx_obs,),
        (u_obs,),
        (tx_all,),
        feature_operators,
        rhs_operator=dt_k,
        datafit_weight=10,
    )

    # ---- Optimize: Levenberg-Marquardt, then adaptive refinement ----
    params_init = jnp.hstack([u_params_init, jnp.zeros(len(grid_features_init))])
    params, convergence_data = LevenbergMarquadtMinimize(
        params_init.copy(),
        EqnModel,
        beta=1e-11,
        max_iter=501,
        init_alpha=0.1,
        line_search_increase_ratio=1.4,
        print_every=100,
    )
    p_adjusted, refine_convergence_data = adaptive_refine_solution(params, EqnModel, 1e-3, 500)

    # The first u_model.num_params entries are the u parameters
    u_sol = p_adjusted[:u_model.num_params]
    u_true = u_star.flatten()
    error_u_field = get_nrmse(u_true, u_model.point_evaluate(X_star, u_sol))

    return error_u_field

In [6]:
def get_error_SINDy(n_obs, n_coll, run, data_path=None, verbose=True, recover_equation=False):
    """Fit u by kernel interpolation (the SINDy pipeline's first step) and return its error.

    Point sets are sampled exactly as in the 1.5-step experiment:
    - ``n_coll`` collocation points are drawn from the full grid with ``PRNGKey(0)``.
    - ``n_obs`` data points are drawn from those with ``PRNGKey(run)``.

    u is interpolated from the data points only. The returned error depends
    only on this interpolation, not on the sparse regression.

    Parameters
    ----------
    n_obs : int
        Number of observed data points (subset of the collocation set).
    n_coll : int
        Number of collocation points.
    run : int
        Seed selecting which observation points are drawn.
    data_path : str, optional
        Path to ``burgers.mat``. If omitted, the ``BURGERS_DATA_PATH``
        environment variable is used, falling back to the original local path.
    verbose : bool, default True
        Print the collocation-set shape.
    recover_equation : bool, default False
        Also run the SINDy regression ``u_t ~ a*u*u_x + b*u_xx`` on the
        collocation points (via ``sparsifyDynamics``) and print the result.

    Returns
    -------
    error_u_field
        ``get_nrmse`` of the interpolated u against the reference u on the full grid.

    Raises
    ------
    ValueError
        If ``n_obs > n_coll``.
    """
    if n_obs > n_coll:
        raise ValueError(
            f"n_obs ({n_obs}) cannot exceed n_coll ({n_coll}): data points are "
            "sampled without replacement from the collocation points."
        )
    if data_path is None:
        data_path = os.environ.get(
            'BURGERS_DATA_PATH',
            '/home/juanfelipe/Desktop/research/keql/examples/burgers/data/burgers.mat',
        )

    # ---- Load reference data and min-max scale t and x independently ----
    data = scipy.io.loadmat(data_path)
    scaler_t = MinMaxScaler()
    t = scaler_t.fit_transform(np.real(data['t'].flatten()[:, None]))
    scaler_x = MinMaxScaler()
    x = scaler_x.fit_transform(np.real(data['x'].flatten()[:, None]))
    # Reference u values
    Exact = np.real(data['usol'])

    # Fine meshgrid; meshgrid(t, x) yields arrays of shape (len(x), len(t)).
    # NOTE(review): assumes usol is stored with that same layout — confirm.
    T, X = np.meshgrid(t, x)
    X_star = np.hstack((T.flatten()[:, None], X.flatten()[:, None]))  # (t, x) pairs
    u_star = Exact.flatten()[:, None]

    # ---- Sample collocation and data point sets ----
    triplets_fine = np.hstack([X_star, u_star])  # rows: (t, x, u)
    triplets_all = jax.random.choice(
        key=jax.random.PRNGKey(0), a=triplets_fine, shape=(n_coll,), replace=False
    )
    tx_all = triplets_all[:, :2]  # collocation point set
    if verbose:
        print(tx_all.shape)

    triplets_obs = jax.random.choice(
        key=jax.random.PRNGKey(run), a=triplets_all, shape=(n_obs,), replace=False
    )
    tx_obs = triplets_obs[:, :2]  # data point set
    u_obs = triplets_obs[:, -1]

    # ---- Kernel interpolation of u from the data points ----
    u_operators = (eval_k,)

    # Anisotropic Gaussian RBF whose diagonal parameters are log1pexp(params)
    def param_ani_gaussian_RBF(x, y, params):
        lengthscales = log1pexp(params)
        return get_anisotropic_gaussianRBF(1., jnp.diag(lengthscales))(x, y)

    fitted_params, _ = fit_kernel_params(param_ani_gaussian_RBF, tx_obs, u_obs, jnp.zeros(2))
    k_u = get_anisotropic_gaussianRBF(1., jnp.diag(log1pexp(fitted_params)))

    # Unlike the 1.5-step model, the RKHS basis here is the data points only
    u_model = CholInducedRKHS(
        tx_obs,
        u_operators,
        k_u,
        nugget_size=1e-8,
    )
    u_params = u_model.get_fitted_params(tx_obs, u_obs)

    if recover_equation:
        # Columns follow feature_operators order: u, u_t, u_x, u_xx
        feature_operators = (eval_k, dt_k, dx_k, dxx_k)
        S = (
            u_model.evaluate_operators(feature_operators, tx_all, u_params)
            .reshape(len(tx_all), len(feature_operators), order='F')
        )
        U_t = np.asarray(S[:, 1])
        UU_x = np.asarray(S[:, 0] * S[:, 2])
        U_xx = np.asarray(S[:, 3])
        Theta = np.vstack([UU_x, U_xx]).T
        coeffs = sparsifyDynamics(Theta, U_t.reshape(-1, 1), lamb=1e-8, n=1)
        print(f'SINDy recovered equation: u_t = {round(coeffs[0][0],4)}uu_x + {round(coeffs[1][0],4)}u_xx')

    u_true = u_star.flatten()
    error_u_field = get_nrmse(u_true, u_model.point_evaluate(X_star, u_params))

    return error_u_field

In [8]:
# Kernel-interpolation error of u with 100 data points drawn from 800 collocation points (run seed 1)
res = get_error_SINDy(n_obs=100, n_coll=800, run=1)

(800, 2)


RuntimeError: jaxlib/gpu/solver_kernels.cc:45: operation gpusolverDnCreate(&handle) failed: cuSolver internal error

In [None]:
# Show the u-field error returned by get_error_SINDy in the previous cell
res