In [1]:
import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_default_device", jax.devices()[1])

import numpy as np
import matplotlib.pyplot as plt
import jax
from tqdm.auto import tqdm
plt.style.use("ggplot")

from importlib import reload
import KernelTools
reload(KernelTools)
from KernelTools import *
from EquationModel import OperatorModel,SplitOperatorPDEModel,OperatorPDEModel
from evaluation_metrics import compute_results    
from data_utils import MinMaxScaler
from evaluation_metrics import get_nrmse

from Kernels import log1pexp,inv_log1pexp
from Kernels import (
    get_centered_scaled_poly_kernel,
    get_anisotropic_gaussianRBF,
    fit_kernel_params
)
from EquationModel import CholInducedRKHS, CholOperatorModel, OperatorPDEModel
from functools import partial

import Optimizers
import importlib
importlib.reload(Optimizers)
from Optimizers import CholeskyLM,SVD_LM

# import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import scipy.io
from scipy.interpolate import griddata
from scipy.spatial import distance
from matplotlib import cm
import time
from mpl_toolkits.mplot3d import Axes3D
# from pyDOE import lhs
# #    import sobol_seq
import os


In [9]:
from lineax import CG
import lineax

x,sol = jax.scipy.sparse.linalg.cg(lambda x:jnp.linspace(1,100,10)*x,jnp.ones(10),jnp.zeros(10),maxiter = 1)

: 

In [4]:
def get_data_rand_coll(n_coll, n_obs,seed, data_for_pinn = False): 
    '''
    n_coll (int) : Integer less than 101 x 256.
    n_obs (int)    : Integet less than 101 x 256.
    seed (int)     : Seed to choose data point set.
    '''     
    # Generate data
    data = scipy.io.loadmat('/home/juanfelipe/Desktop/research/keql/examples/burgers/data/burgers.mat')
    # t
    t = jnp.real(data['t'].flatten()[:,None])
    # # Scale t
    # scaler_t = MinMaxScaler()
    # t = scaler_t.fit_transform(t)
    # x
    x = np.real(data['x'].flatten()[:,None])
    # # Scale x
    # scaler_x = MinMaxScaler()
    # x = scaler_x.fit_transform(x)
    # u true values
    Exact = np.real(data['usol'])

    # Fine meshgrid
    T, X = np.meshgrid(t,x)

    # Fine pairs (t,x)
    X_star = np.hstack((T.flatten()[:,None], X.flatten()[:,None]))
    # Fine u values
    u_star = Exact.flatten()[:,None]
    
    # Triples at collocation point set
    N_all = n_coll
    triplets_fine = np.hstack([X_star,u_star])
    triplets_all = jax.random.choice(key = jax.random.PRNGKey(0), a = triplets_fine, shape = (N_all,), replace=False)
    
    # Collocation point set
    tx_all = triplets_all[:,:2]


    N_obs = n_obs
    triplets_obs = jax.random.choice(key = jax.random.PRNGKey(seed), a = triplets_fine, shape = (N_obs,), replace=False)
    # triplets_obs = triplets_all[idx_obs,:] # Choose data point set from full point set
    # Data point set
    tx_obs = triplets_obs[:,:2]
    u_obs = triplets_obs[:,-1]

    u_star = triplets_fine[:,-1]

    # Invert them to be ready for PINNSR
    if data_for_pinn:
        tx_train = tx_train.at[:,[1,0]].set(tx_train[:,[0,1]])

        tx_val = tx_val.at[:,[1,0]].set(tx_val[:,[0,1]])

        tx_all = tx_all.at[:,[1,0]].set(tx_all[:,[0,1]])

        X_star = X_star.at[:,[1,0]].set(X_star[:,[0,1]])

        triplets_fine = triplets_fine.at[:,[1,0]].set(triplets_fine[:,[0,1]])
    
    return tx_obs, u_obs, tx_all, u_star, X_star

In [78]:
n_coll_t = 30
n_coll_x=30
n_obs = 500
run = 100

tx_obs, u_obs, tx_all, u_star, X_star = (
    get_data_rand_coll(n_coll = 1000,n_obs = n_obs,seed=run)
)
tx_all = jnp.vstack([tx_all,jnp.vstack([jnp.zeros(30),jnp.linspace(-8,8,30)]).T])

# # if using val
# tx_obs = jnp.vstack([tx_train,tx_val])
# u_obs  = jnp.concatenate([u_train,u_val]).flatten()    

# Run 1_5 step method

u_operators = (eval_k,)
feature_operators = (eval_k,dx_k,dxx_k)

# Choose u kernel
def param_ani_gaussian_RBF(x,y,params):
    lengthscales = log1pexp(params)
    return get_anisotropic_gaussianRBF(1.,jnp.diag(lengthscales))(x,y)

fitted_params,ml_value = fit_kernel_params(param_ani_gaussian_RBF,tx_obs,u_obs,jnp.zeros(2))
ML_lengthscales = log1pexp(fitted_params)
print(1/(jnp.sqrt(ML_lengthscales)))
k_u = get_anisotropic_gaussianRBF(1.,jnp.diag(jnp.array([1.,1.])))

# RKHS class for u
u_model = CholInducedRKHS(
    tx_all,
    u_operators,
    k_u,
    nugget_size = 1e-8
    )
u_params_init = u_model.get_fitted_params(tx_obs,u_obs)

grid_features_init = (
    (u_model.evaluate_operators(feature_operators,tx_all,u_params_init))
    .reshape(
            len(tx_all),
            len(feature_operators),
            order = 'F'
        )
)

# Choose kernel for P
k_P_u_part = get_centered_scaled_poly_kernel(2,grid_features_init,c=1.,scaling = 'diagonal')
@vectorize_kfunc
def k_P(x,y):
    return k_P_u_part(x[2:],y[2:])
P_model = OperatorModel(k_P)

[1.75730243 0.7683696 ]


In [79]:
from itertools import combinations_with_replacement

def get_monomial_feature_list(dim,degree):
    return sum([list(combinations_with_replacement(range(dim),r = d)) for d in range(1,degree+1)],[])

def get_polynomial_feature_map(monomial_feature_list):
    def poly_features(X):
        return jnp.array([jnp.ones(len(X))]+[jnp.prod(X[:,inds],axis=1) for inds in monomials]).T
    return poly_features

In [80]:
class PolyFeatures():
    def __init__(
        self,
        degree,
        processed_dimension,
        preprocessor = lambda x:x,
        data_to_scale = None,
    ):
        self.degree = degree
        self.input_dimension = processed_dimension
        self.preprocessor = preprocessor
        self.monomials = get_monomial_feature_list(processed_dimension,degree = degree)
        self.feature_map = get_polynomial_feature_map(self.monomials)
        self.feature_dim = len(self.monomials)+1

    def predict(self,features,params):
        return self.feature_map(self.preprocessor(features))@params
    
    def rkhs_mat(self,X):
        return jnp.identity(self.feature_dim)


class OperatorModel():
    def __init__(
        self,
        kernel,
        nugget_size = 1e-7
    ):
        self.kernel_function = kernel
        self.nugget_size = nugget_size

    def predict(self,input_data,params):
        K = self.kernel_function(input_data,input_data)
        return K@params
        
    def predict_new(self,X,anchors,params):
        return self.kernel_function(X,anchors)@params
    
    def fit_params(self,X,y,nugget = 1e-8):
        K = self.kernel_function(X,X)
        return jnp.linalg.solve(K + nugget * diagpart(K),y)
    
    def rkhs_mat(self,X):
        return self.kernel_function(X,X)


In [81]:
PolyModel = PolyFeatures(degree = 2,processed_dimension=3,preprocessor = lambda x:x[:,2:])

# Equation model that has u and P object
EqnModel = SplitOperatorPDEModel(
    PolyModel,
    (u_model,),
    (tx_obs,),
    (u_obs,),
    (tx_all,),
    feature_operators,
    rhs_operator=dt_k,
    datafit_weight = 100,
    num_P_operator_params=PolyModel.feature_dim
)

# Optimize - LM
params_init = jnp.hstack([u_params_init,jnp.zeros(PolyModel.feature_dim)])
params,convergence_data = CholeskyLM(
    params_init.copy(),
    EqnModel,
    beta = 1e-11,
    max_iter = 501,
    init_alpha=0.1,
    line_search_increase_ratio=1.4,
    print_every = 100
)
p_adjusted,refine_convergence_data = SVD_LM(params,EqnModel,1e-3,100,print_every = 10,overall_regularization=1e-13)
# u_params
u_sol = p_adjusted[:u_model.num_params]
# u_true 
u_true = u_star.flatten()
# get error
error_u_field = get_nrmse(u_true, u_model.point_evaluate(X_star,u_sol))

  0%|          | 0/501 [00:00<?, ?it/s]

Iteration 0, loss = 0.002082, Jres = 0.01379, alpha = 0.08333, improvement_ratio = 0.9981
Iteration 1, loss = 0.001558, Jres = 0.008866, alpha = 0.06944, improvement_ratio = 1.0
Iteration 2, loss = 0.001265, Jres = 0.005981, alpha = 0.05787, improvement_ratio = 1.001
Iteration 3, loss = 0.001095, Jres = 0.004096, alpha = 0.04823, improvement_ratio = 1.0
Iteration 4, loss = 0.0009908, Jres = 0.002872, alpha = 0.04019, improvement_ratio = 1.0
Iteration 5, loss = 0.0009213, Jres = 0.002075, alpha = 0.03349, improvement_ratio = 1.001
Iteration 100, loss = 3.245e-05, Jres = 1.964e-06, alpha = 8.333e-07, improvement_ratio = 1.003
Iteration 200, loss = 2.413e-05, Jres = 4.27e-07, alpha = 8.333e-07, improvement_ratio = 1.001
Iteration 300, loss = 2.096e-05, Jres = 1.969e-07, alpha = 8.333e-07, improvement_ratio = 1.0
Iteration 400, loss = 1.917e-05, Jres = 1.375e-07, alpha = 8.333e-07, improvement_ratio = 1.0
Iteration 500, loss = 1.799e-05, Jres = 1.101e-07, alpha = 8.333e-07, improvement_rat

  0%|          | 0/100 [00:00<?, ?it/s]

Iteration 0, loss = 1.79543293061942e-05
Iteration 10, loss = 1.795317258556761e-05
Iteration 20, loss = 1.791252347926775e-05
Iteration 30, loss = 1.6750273488753976e-05
Iteration 40, loss = 1.1170969972577893e-05
Iteration 50, loss = 7.243154683124377e-06
Iteration 60, loss = 5.786375468752498e-06
Iteration 60 Step Failed
Iteration 61 Step Failed
Iteration 62 Step Failed
Iteration 63 Step Failed
Iteration 65 Step Failed
Iteration 66 Step Failed
Iteration 68 Step Failed
Iteration 70, loss = 5.786352896709283e-06
Iteration 70 Step Failed
Iteration 71 Step Failed
Iteration 73 Step Failed
Iteration 74 Step Failed
Iteration 75 Step Failed
Iteration 76 Step Failed
Iteration 77 Step Failed
Iteration 78 Step Failed
Iteration 79 Step Failed
Iteration 80, loss = 5.786352766802076e-06
Iteration 80 Step Failed
Iteration 81 Step Failed
Iteration 84 Step Failed
Iteration 85 Step Failed
Iteration 86 Step Failed
Iteration 87 Step Failed
Iteration 88 Step Failed
Iteration 89 Step Failed
Iteration 90,

In [86]:
print(error_u_field)

0.016826812744712125


In [83]:
PolyModel.monomials

[(0,), (1,), (2,), (0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2)]

In [87]:
P_coeffs = EqnModel.get_P_params(p_adjusted)

threshold = 3e-2
var_names = ['(u)','(ux)','(uxx)']
eqn = f"ut' = {P_coeffs[0]:.3f} "
for val,term in zip(P_coeffs[1:],PolyModel.monomials):
    if jnp.abs(val)>=threshold:
        equation_addition = f"+ {val:.3f} {''.join([var_names[i] for i in term])} "
        eqn = eqn + equation_addition
print(eqn)

ut' = -0.000 + 0.032 (ux) + 0.097 (uxx) + -1.096 (u)(ux) + -0.059 (ux)(uxx) 


In [97]:
from pysindy.optimizers.stlsq import STLSQ

In [125]:
final_features = EqnModel.single_eqn_features(u_model,u_sol,tx_all)
poly_final = PolyModel.feature_map(final_features)
target_final = EqnModel.apply_rhs_op_single(u_model,u_sol,tx_all)

stlsq = STLSQ(threshold = 1e-2,alpha = 0,max_iter = 100)
stlsq.fit(poly_final,target_final)

poly_coeffs = stlsq.coef_[0]


threshold = 1e-4
var_names = ['(u)','(ux)','(uxx)']
eqn = f"ut' = {poly_coeffs[0]:.3f} "
for val,term in zip(poly_coeffs[1:],PolyModel.monomials):
    if jnp.abs(val)>=threshold:
        equation_addition = f"+ {val:.3f} {''.join([var_names[i] for i in term])} "
        eqn = eqn + equation_addition
print(eqn)

ut' = 0.000 + 1.068 (uxx) + -0.101 (u)(uxx) + 0.214 (ux)(uxx) + -1.011 (uxx)(uxx) 


In [108]:
stlsq.coef_[1:]

array([], shape=(0, 10), dtype=float64)