In [1]:
import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_default_device", jax.devices()[1])

import numpy as np
import matplotlib.pyplot as plt
import jax
from tqdm.auto import tqdm
plt.style.use("ggplot")

from importlib import reload
import KernelTools
reload(KernelTools)
from KernelTools import *
from EquationModel import OperatorModel,SplitOperatorPDEModel,OperatorPDEModel,InducedOperatorModel
from evaluation_metrics import compute_results    
from data_utils import MinMaxScaler
from evaluation_metrics import get_nrmse

from Kernels import log1pexp,inv_log1pexp
from Kernels import (
    get_centered_scaled_poly_kernel,
    get_anisotropic_gaussianRBF,
    fit_kernel_params
)
from EquationModel import CholInducedRKHS, CholOperatorModel, OperatorPDEModel
from functools import partial

import Optimizers
import importlib
importlib.reload(Optimizers)
from Optimizers import CholeskyLM,SVD_LM

In [2]:
# import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import scipy.io
from scipy.interpolate import griddata
from scipy.spatial import distance
from matplotlib import cm
import time
from mpl_toolkits.mplot3d import Axes3D
# from pyDOE import lhs
# #    import sobol_seq
import os


In [3]:
def get_data_rand_coll(n_coll, n_obs,seed, data_for_pinn = False): 
    '''
    n_coll (int) : Integer less than 101 x 256.
    n_obs (int)    : Integet less than n_coll.
    seed (int)     : Seed to choose data point set.
    '''     
    # Generate data
    data = scipy.io.loadmat('/home/juanfelipe/Desktop/research/keql/examples/burgers/data/burgers.mat')
    # t
    t = jnp.real(data['t'].flatten()[:,None])
    # # Scale t
    # scaler_t = MinMaxScaler()
    # t = scaler_t.fit_transform(t)
    # x
    x = np.real(data['x'].flatten()[:,None])
    # # Scale x
    # scaler_x = MinMaxScaler()
    # x = scaler_x.fit_transform(x)
    # u true values
    Exact = np.real(data['usol'])

    # Fine meshgrid
    T, X = np.meshgrid(t,x)

    # Fine pairs (t,x)
    X_star = np.hstack((T.flatten()[:,None], X.flatten()[:,None]))
    # Fine u values
    u_star = Exact.flatten()[:,None]
    
    # Triples at collocation point set
    N_all = n_coll
    triplets_fine = np.hstack([X_star,u_star])
    triplets_all = jax.random.choice(key = jax.random.PRNGKey(0), a = triplets_fine, shape = (N_all,), replace=False)
    
    # Collocation point set
    tx_all = triplets_all[:,:2]


    N_obs = n_obs
    triplets_obs = jax.random.choice(key = jax.random.PRNGKey(seed), a = triplets_fine, shape = (N_obs,), replace=False)
    # triplets_obs = triplets_all[idx_obs,:] # Choose data point set from collocation point set
    # Data point set
    tx_obs = triplets_obs[:,:2]
    u_obs = triplets_obs[:,-1]

    u_star = triplets_fine[:,-1]

    # Invert them to be ready for PINNSR
    if data_for_pinn:
        tx_train = tx_train.at[:,[1,0]].set(tx_train[:,[0,1]])

        tx_val = tx_val.at[:,[1,0]].set(tx_val[:,[0,1]])

        tx_all = tx_all.at[:,[1,0]].set(tx_all[:,[0,1]])

        X_star = X_star.at[:,[1,0]].set(X_star[:,[0,1]])

        triplets_fine = triplets_fine.at[:,[1,0]].set(triplets_fine[:,[0,1]])
    
    return tx_obs, u_obs, tx_all, u_star, X_star

In [4]:
# n_coll_t = 30
# n_coll_x=30
n_obs = 50
run = 100

n_coll = 2500

tx_obs, u_obs, tx_all, u_star, X_star = (
    get_data_rand_coll(n_coll = n_coll,n_obs = n_obs,seed=run)
)
tx_all = jnp.vstack([tx_all,jnp.vstack([jnp.zeros(30),jnp.linspace(-8,8,30)]).T])

# Run 1_5 step method

u_operators = (eval_k,)
feature_operators = (eval_k,dx_k,dxx_k)

# Choose u kernel
def param_ani_gaussian_RBF(x,y,params):
    lengthscales = log1pexp(params)
    return get_anisotropic_gaussianRBF(1.,jnp.diag(lengthscales))(x,y)

fitted_params,ml_value = fit_kernel_params(param_ani_gaussian_RBF,tx_obs,u_obs,jnp.zeros(2))
ML_lengthscales = log1pexp(fitted_params)
print(1/(jnp.sqrt(ML_lengthscales)))
k_u = get_anisotropic_gaussianRBF(1.,jnp.diag(jnp.array([1.,1.])))

# RKHS class for u
u_model = CholInducedRKHS(
    tx_all,
    u_operators,
    k_u,
    nugget_size = 1e-8
    )
u_params_init = u_model.get_fitted_params(tx_obs,u_obs)

grid_features_init = (
    (u_model.evaluate_operators(feature_operators,tx_all,u_params_init))
    .reshape(
            len(tx_all),
            len(feature_operators),
            order = 'F'
        )
)
grid_features_init = jnp.hstack([tx_all,grid_features_init])
num_P_inducing = 500
P_inducing_points = jax.random.choice(jax.random.PRNGKey(13),grid_features_init,(num_P_inducing,))


# Choose kernel for P
k_P_u_part = get_centered_scaled_poly_kernel(2,grid_features_init[:,2:],c=1.,scaling = 'diagonal')

def k_P(x,y):
    return k_P_u_part(x[2:],y[2:])
P_model = InducedOperatorModel(P_inducing_points,k_P)

# Equation model that has u and P object
EqnModel = SplitOperatorPDEModel(
    P_model,
    (u_model,),
    (tx_obs,),
    (u_obs,),
    (tx_all,),
    feature_operators,
    rhs_operator=dt_k,
    datafit_weight = 100,
    num_P_operator_params=num_P_inducing
)
ut_init = EqnModel.apply_rhs_op_single(u_model,u_params_init,EqnModel.collocation_points[0])
P_params_init = P_model.get_fitted_params(grid_features_init,ut_init,lam = 1e-3)
params_init = jnp.hstack([u_params_init,P_params_init])

[2.03724345 2.01946685]


In [33]:
sketch = jax.random.normal(jax.random.PRNGKey(304),shape = (300,EqnModel.residual_dimension))/jnp.sqrt(300)

#@partial(jax.jit,static_argnames = 'size')

@jax.jit
def sketch_jac(params):
    primals,F_vjp = jax.vjp(EqnModel.F,params)
    SJ = jax.vmap(F_vjp)(sketch)[0]
    return SJ

In [6]:
from jax.scipy.linalg import solve
jit_valgrad = jax.jit(jax.value_and_grad(EqnModel.loss))

params = params_init
num_steps = 500
fvals_SLM = []
for i in tqdm(range(num_steps)):
    val,g = jit_valgrad(params)
    SJ = sketch_jac(params)
    params = params - 0.2 * solve(SJ.T@SJ + 1e-2 * jnp.identity(len(params_init)),g,assume_a = 'pos')
    if i%100==0:
        print(f"Iteration {i}: ",val)
    fvals_SLM.append(val)
for i in tqdm(range(num_steps)):
    val,g = jit_valgrad(params)
    SJ = sketch_jac(params)
    params = params - 0.2 * solve(SJ.T@SJ + 1e-3 * jnp.identity(len(params_init)),g,assume_a = 'pos')
    if i%100==0:
        print(f"Iteration {i}: ",val)
    fvals_SLM.append(val)

fvals_SLM = jnp.array(fvals_SLM)


  0%|          | 0/500 [00:00<?, ?it/s]

Iteration 0:  0.0015574354215858828
Iteration 100:  6.826753958606303e-05
Iteration 200:  3.986134243932362e-05
Iteration 300:  2.784689354185507e-05
Iteration 400:  2.1993528982141498e-05


  0%|          | 0/500 [00:00<?, ?it/s]

Iteration 0:  1.8600368297105102e-05
Iteration 100:  9.388248166332159e-06
Iteration 200:  6.833101337822985e-06
Iteration 300:  5.3798324176885835e-06
Iteration 400:  4.429673614901269e-06


In [7]:
len(params_init)

3030

In [21]:
jax.jvp(EqnModel.F,(params_init,),(params_init,))

(Array([-2.24255836e-10,  1.13307128e-07,  3.94358958e-08, ...,
        -2.77598301e-06, -2.59843180e-06, -2.43500809e-06], dtype=float64),
 Array([-1.46655320e-04, -1.24436270e-01, -1.18475954e-01, ...,
        -2.66793995e-06, -2.43691044e-06, -2.35040780e-06], dtype=float64))

In [29]:
# jax.jacfwd(EqnModel.F)(params_init)

In [50]:
fval,Jfunc = jax.linearize(EqnModel.F,params_init)
J = jax.lax.map(Jfunc,jnp.identity(len(params_init)), batch_size=1000).T

In [None]:
jax.vmap(Jfunc)(jnp.identity(len(params_init))[:500])

In [28]:
# fval,Jfunc = jax.linearize(EqnModel.F,params_init)
# J = jax.vmap(Jfunc)(jnp.identity(len(params_init)))

In [52]:
import jax
import lineax as lx

In [53]:
from jax.scipy.linalg import solve
jit_valgrad = jax.jit(jax.value_and_grad(EqnModel.loss))

params = params_init
val,g = jit_valgrad(params)
SJ = sketch_jac(params)
# step = 
    # params = params - solve(SJ.T@SJ + 1e-2 * jnp.identity(len(params_init)),g,assume_a = 'pos')
    # if i%100==0:
    #     print(f"Iteration {i}: ",val)
    # fvals_SLM.append(val)


In [35]:
alpha = 1e-2
U,sigma,Vt = jnp.linalg.svd(SJ,full_matrices = False)
D = sigma ** 2
precon = lambda x:Vt.T@(D/(1+D/alpha))*(Vt@x)
direct = jnp.linalg.inv(SJ.T@SJ + jnp.identity(len(params))*alpha)

beta = 1.
sv_rescale = (beta * D/(beta+D)).reshape(-1,1)
M = 1/beta * jnp.identity(len(params)) - (1/(beta**2)) * Vt.T@(sv_rescale*Vt)

In [47]:
J.shape

(3030, 2580)

In [76]:
alpha = 1e-4
precon = lambda x:Vt.T@((sigma/(sigma**2+alpha))*(Vt@x))
direct = jnp.linalg.inv(J.T@J + jnp.identity(len(params))*alpha)

#precon_op = lx.FunctionLinearOperator(precon,params_init,tags = lx.positive_semidefinite_tag)
precon_op = lx.MatrixLinearOperator(direct,tags = lx.positive_semidefinite_tag)

J_op = lx.JacobianLinearOperator(lambda x,args:EqnModel.F(x),params)
LM_hess = lx.TaggedLinearOperator(J_op.T@J_op  + alpha * lx.IdentityLinearOperator(params),lx.positive_semidefinite_tag)
solver = lx.CG(rtol=1e-10, atol=1e-11,stabilise_every = 1)

In [77]:
out = lx.linear_solve(LM_hess, g, solver,options = {'preconditioner':precon_op,'y0':direct@g})

E1022 15:32:54.854709 1968624 pjrt_stream_executor_client.cc:3084] Execution of replica 0 failed: INTERNAL: CustomCall failed: CpuCallback error: Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/alexh/miniconda3/envs/keql/lib/python3.12/site-packages/ipykernel_launcher.py", line 18, in <module>
  File "/home/alexh/miniconda3/envs/keql/lib/python3.12/site-packages/traitlets/config/application.py", line 1075, in launch_instance
  File "/home/alexh/miniconda3/envs/keql/lib/python3.12/site-packages/ipykernel/kernelapp.py", line 739, in start
  File "/home/alexh/miniconda3/envs/keql/lib/python3.12/site-packages/tornado/platform/asyncio.py", line 205, in start
  File "/home/alexh/miniconda3/envs/keql/lib/python3.12/asyncio/base_events.py", line 641, in run_forever
  File "/home/alexh/miniconda3/envs/keql/lib/python3.12/asyncio/base_events.py", line 1987, in _run_once
  File "/home/alexh/m

EquinoxRuntimeError: Above is the stack outside of JIT. Below is the stack inside of JIT:
  File "/home/alexh/miniconda3/envs/keql/lib/python3.12/site-packages/equinox/internal/_primitive.py", line 148, in _wrapper
    out = rule(*args)
          ^^^^^^^^^^^
  File "/home/alexh/miniconda3/envs/keql/lib/python3.12/site-packages/lineax/_solve.py", line 103, in _linear_solve_impl
    solution, result, stats = result.error_if(
                              ^^^^^^^^^^^^^^^^
equinox.EquinoxRuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the
operator was not well-posed, and that the solver does not support this.

If you are trying solve a linear least-squares problem then you should pass
`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`
assumes that the operator is square and nonsingular.

If you *were* expecting this solver to work with this operator, then it may be because:

(a) the operator is singular, and your code has a bug; or

(b) the operator was nearly singular (i.e. it had a high condition number:
    `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from
    numerical instability issues; or

(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)
    that is does not actually satisfy.

-------------------

An error occurred during the runtime of your JAX program.

1) Setting the environment variable `EQX_ON_ERROR=breakpoint` is usually the most useful
way to debug such errors. This can be interacted with using most of the usual commands
for the Python debugger: `u` and `d` to move up and down frames, the name of a variable
to print its value, etc.

2) You may also like to try setting `JAX_DISABLE_JIT=1`. This will mean that you can
(mostly) inspect the state of your program as if it was normal Python.

3) See `https://docs.kidger.site/equinox/api/debug/` for more suggestions.


In [72]:
out.stats

{'max_steps': None, 'num_steps': Array(2, dtype=int64)}

In [63]:
jit_valgrad(params - out.value)[0]

Array(0.00032402, dtype=float64)

In [145]:
jit_valgrad(params)[0]

Array(0.00155744, dtype=float64)

In [146]:
jit_valgrad(params - precon_op.mv(g))

(Array(0.0015487, dtype=float64),
 Array([ 1.71995741e-04,  9.62767443e-04,  2.40964628e-05, ...,
         1.22681015e-09, -8.04242069e-10,  7.89538012e-10], dtype=float64))

In [140]:
jit_valgrad(params - precon_op.mv(g))

(Array(0.0015487, dtype=float64),
 Array([ 1.71995741e-04,  9.62767443e-04,  2.40964628e-05, ...,
         1.22681015e-09, -8.04242069e-10,  7.89538012e-10], dtype=float64))

In [65]:
jit_valgrad(params - g)

(Array(0.00152983, dtype=float64),
 Array([ 3.43140916e-04, -3.55518935e-03, -1.83517427e-05, ...,
         4.40956583e-09, -3.01421386e-09,  1.88361300e-09], dtype=float64))

In [None]:
import jax
import lineax as lx

key = jax.random.PRNGKey(0)
y = jax.random.normal(key, (10,))
D = jnp.linspace(1,10000,10)
def quadratic_fn(y, args):
  return jax.numpy.sum(D*(y - 1)**2)

gradient_fn = jax.grad(quadratic_fn)
hessian = lx.JacobianLinearOperator(gradient_fn, y, tags=lx.positive_semidefinite_tag)
solver = lx.CG(rtol=1e-6, atol=1e-6)
out = lx.linear_solve(hessian, gradient_fn(y, args=None), solver)
minimum = y - out.value