In [46]:
import jax
import os
os.environ["CUDA_VISIBLE_DEVICES"]="4"
jax.config.update('jax_enable_x64',True)
# jax.config.update('jax_default_device',jax.devices()[4])

from jax.random import key
from scipy.integrate import solve_ivp
from tqdm.auto import tqdm
from exp.expdata import LorenzExp
import jax.numpy as jnp
import matplotlib.pyplot as plt
from exp.metrics import coeff_metrics, data_metrics
plt.style.use("ggplot")

from jsindy.sindy_model import JSINDyModel
from jsindy.util import get_collocation_points_weights
from jsindy.trajectory_model import DataAdaptedRKHSInterpolant,CholDataAdaptedRKHSInterpolant
from jsindy.dynamics_model import FeatureLinearModel, PolyLib
from jsindy.optim import AlternatingActiveSetLMSolver, LMSettings
from jsindy.optim.solvers.alt_active_set_lm_solver import pySindySparsifier
from pysindy import STLSQ,SSR,MIOSR
from jsindy.kernels import ConstantKernel, ScalarMaternKernel
import pickle
from pathlib import Path
import pysindy as ps

In [47]:
x0 = jnp.array([-8, 8, 27.])
dt = 0.01
t0=0
t1=10.1
n_colloc = 505

expdata = LorenzExp(
    dt = dt,
    initial_state=x0,
    feature_names=['x','y','z'],
    t0=t0,
    t1=t1,
    n_colloc=n_colloc
)

tEndL = jnp.arange(4.0, 11.0, 1.0)
epsL = jnp.arange(0.025, 0.401, 0.025)

t_true = expdata.t_true
X_true = expdata.x_true

cutoff = 1
signal_power = jnp.std(X_true)
n_colloc = 500


In [113]:
tend = 8
noise_ratio = 0.2
rkey = jax.random.key(12038)
t_end_idx = int(tend // dt)
X_train = X_true[:t_end_idx]
t_train = t_true[:t_end_idx]

t_colloc, w_colloc = get_collocation_points_weights(t_train,n_colloc)

eps = noise_ratio*signal_power


noise = eps*jax.random.normal(rkey, X_train.shape)

X_train = X_train + noise

kernel = (
	ConstantKernel(variance = 5.)
	+ScalarMaternKernel(p = 5,variance = 10., lengthscale=3,min_lengthscale=0.05)
)   
trajectory_model = CholDataAdaptedRKHSInterpolant(kernel=kernel)
dynamics_model = FeatureLinearModel(
	reg_scaling = 1.,
	feature_map=PolyLib(degree=2)
)
optsettings = LMSettings(
	max_iter = 1000,
	no_tqdm=True,
	min_alpha = 1e-16,
	init_alpha = 5.,
	print_every = 100,
	show_progress = True,
)
data_weight =  1
colloc_weight = 1e5

# pysindy_opt = STLSQ(threshold = 0.3,alpha = 1.)
pysindy_opt = ps.SR3(
    regularizer = 'L0',
    reg_weight_lam=3.,
    relax_coeff_nu=0.1,
    unbias=False,
    max_iter = 200
)
pysindy_opt = ps.SR3(
    regularizer = 'L1',
    reg_weight_lam=5.,
    relax_coeff_nu=0.05,
    unbias=False,
    max_iter = 1000,
)

sparsifier = pySindySparsifier(
	pysindy_opt
	)


optimizer = AlternatingActiveSetLMSolver(
		beta_reg=1e-4,
		solver_settings=optsettings,
		fixed_colloc_weight=colloc_weight,
		fixed_data_weight=data_weight,
		sparsifier = sparsifier
		)

model = JSINDyModel(
	trajectory_model=trajectory_model,
	dynamics_model=dynamics_model,
	optimizer=optimizer,
	feature_names=expdata.feature_names
)

model.fit(t_train, X_train,t_colloc=t_colloc)

{'show_progress': True, 'sigma2_est': Array(6.47258904, dtype=float64), 'data_weight': 1, 'colloc_weight': 100000.0}
Warm Start
Iteration 0, loss = 8.721e+05, gradnorm = 3.644e+07, alpha = 4.167, improvement_ratio = 0.9725
Iteration 1, loss = 9.542e+04, gradnorm = 7.592e+07, alpha = 3.472, improvement_ratio = 0.8987
Iteration 2, loss = 4.819e+04, gradnorm = 2.109e+07, alpha = 3.472, improvement_ratio = 0.5384
Iteration 3, loss = 8.49e+03, gradnorm = 3.383e+07, alpha = 2.894, improvement_ratio = 0.98
Iteration 4, loss = 7.822e+03, gradnorm = 1.559e+06, alpha = 2.411, improvement_ratio = 0.8095
Iteration 5, loss = 7.7e+03, gradnorm = 6.686e+05, alpha = 2.411, improvement_ratio = 0.735
Line Search Failed!
Final Iteration Results
Iteration 91, loss = 7.64e+03, gradnorm = 0.06084, alpha = 3.99e+05, improvement_ratio = -inf
Model after smooth warm start
(x)' = -11.947 1 + -11.670 x + 12.066 y + 1.282 z + 0.226 x^2 + -0.175 x y + 0.052 x z + 0.019 y^2 + -0.072 y z + -0.039 z^2
(y)' = 4.436 1 

In [114]:
model.print()

(x)' = -9.892 x + 10.027 y
(y)' = 27.992 x + -1.086 y + -0.994 x z
(z)' = -2.672 z + 1.014 x y


In [118]:
error = jnp.linalg.norm(model.theta - expdata.true_coeff.T)/jnp.linalg.norm(expdata.true_coeff.T)

In [119]:
jnp.log(error)

Array(-5.40351443, dtype=float64)

In [50]:
import pysindy as ps

n = 500
t = jnp.linspace(0,5,500)
xdot = model.traj_model.derivative(t,model.z)
X = model.predict_state(t,model.z)

features = model.dynamics_model.feature_map(X)

In [112]:
import pysindy as ps
pysindy_opt = ps.SR3(
    regularizer = 'L1',
    reg_weight_lam=5.,
    relax_coeff_nu=0.05,
    unbias=False,
    max_iter = 1000,
)
pysindy_opt.fit(features,xdot)
model.print(pysindy_opt.coef_.T)

(x)' = -9.536 x + 9.701 y
(y)' = 27.671 x + -0.848 y + -0.741 x z
(z)' = -2.464 z + 0.768 x y


In [None]:
import pysindy as ps

n = 500
t = jnp.linspace(0,5,500)
xdot = model.traj_model.derivative(t,model.z)
X = model.predict_state(t,model.z)

features = model.dynamics_model.feature_map(X)
normalizers = jnp.linalg.norm(features,axis=0)
A = features/normalizers

xdot_norm = jnp.linalg.norm(xdot,axis=0)
B = xdot/xdot_norm

In [191]:
opt = ps.STLSQ(threshold = 100,alpha =1e-5,normalize_columns=True)
opt.fit(features,xdot)
theta = opt.coef_.T
# theta = opt.coef_.T
model.print(theta)

(x)' = -9.905 x + 10.233 y
(y)' = 27.084 x + -1.026 y + -0.958 x z
(z)' = -2.636 z + 1.018 x y


In [None]:
opt = ps.STLSQ(threshold=0.1,alpha = 0.)
# opt = ps.SSR(alpha = 0.00001,criteria = 'model_residual')

opt.fit(A,B)
theta = xdot_norm*opt.coef_.T/normalizers[:,None]
# theta = opt.coef_.T
model.print(theta)

(x)' = -9.905 x + 10.233 y
(y)' = 24.624 x + -0.910 x z
(z)' = -2.636 z + 1.018 x y


In [157]:
jnp.linalg.norm(B[:,0])

Array(1., dtype=float64)

In [153]:
opt.coef_.T/normalizers[:,None]

Array([[ 0.        ,  0.        ,  0.        ],
       [-9.90542068, 27.08435191,  0.        ],
       [10.23281304, -1.026233  ,  0.        ],
       [ 0.        ,  0.        , -2.6356656 ],
       [ 0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  1.01809876],
       [ 0.        , -0.95838057,  0.        ],
       [ 0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ]], dtype=float64)

In [29]:
x,resid,_,_ = jnp.linalg.lstsq(A[:,1:],A[:,0])

In [35]:
jnp.sqrt(resid)/jnp.linalg.norm(A[:,0])

Array([0.09067546], dtype=float64)

In [30]:
metrics = {}

metrics["coeff_mets"] = coeff_metrics(
	coeff_est=model.theta,
	coeff_true=expdata.true_coeff.T
)
metrics["theta"] = model.theta
metrics['noise_ratio'] = noise_ratio
metrics['t_end'] = tend

In [31]:
metrics

{'coeff_mets': {'precision': 0.7,
  'recall': 1.0,
  'f1': 0.8235294117647058,
  'coeff_rel_l2': 0.15387476680622142,
  'coeff_rmse': 0.8857766841141267,
  'coeff_mae': 0.30595960121342514},
 'theta': Array([[-4.00192846,  2.11280467, -1.47178993],
        [-9.90613514, 27.15824938,  0.        ],
        [10.38740817, -1.12531337,  0.        ],
        [ 0.        ,  0.        , -2.58052028],
        [ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  1.01803583],
        [ 0.        , -0.96025426,  0.        ],
        [ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ]], dtype=float64),
 'noise_ratio': 0.4,
 't_end': 5}