In [1]:
import jax
import os
os.environ["CUDA_VISIBLE_DEVICES"]="4"
jax.config.update('jax_enable_x64',True)
# jax.config.update('jax_default_device',jax.devices()[4])

from jax.random import key
from scipy.integrate import solve_ivp
from tqdm.auto import tqdm
from exp.expdata import LorenzExp
import jax.numpy as jnp
import matplotlib.pyplot as plt
from exp.metrics import coeff_metrics, data_metrics
plt.style.use("ggplot")

from jsindy.sindy_model import JSINDyModel
from jsindy.util import get_collocation_points_weights
from jsindy.trajectory_model import DataAdaptedRKHSInterpolant,CholDataAdaptedRKHSInterpolant
from jsindy.dynamics_model import FeatureLinearModel, PolyLib
from jsindy.optim import AlternatingActiveSetLMSolver, LMSettings
from jsindy.optim.solvers.alt_active_set_lm_solver import pySindySparsifier
from pysindy import STLSQ,SSR,MIOSR
from jsindy.kernels import ConstantKernel, ScalarMaternKernel
import pickle
from pathlib import Path


In [2]:
x0 = jnp.array([-8, 8, 27.])
dt = 0.01
t0=0
t1=10.1
n_colloc = 505

expdata = LorenzExp(
    dt = dt,
    initial_state=x0,
    feature_names=['x','y','z'],
    t0=t0,
    t1=t1,
    n_colloc=n_colloc
)

tEndL = jnp.arange(4.0, 11.0, 1.0)
epsL = jnp.arange(0.025, 0.401, 0.025)

t_true = expdata.t_true
X_true = expdata.x_true

cutoff = 1
signal_power = jnp.std(X_true)
n_colloc = 500


In [3]:
tend = 5
noise_ratio = 0.1
rkey = jax.random.key(12038)
t_end_idx = int(tend // dt)
X_train = X_true[:t_end_idx]
t_train = t_true[:t_end_idx]

t_colloc, w_colloc = get_collocation_points_weights(t_train,n_colloc)

eps = noise_ratio*signal_power


noise = eps*jax.random.normal(rkey, X_train.shape)

X_train = X_train + noise

kernel = (
	ConstantKernel(variance = 5.)
	+ScalarMaternKernel(p = 5,variance = 10., lengthscale=3,min_lengthscale=0.05)
)   
trajectory_model = CholDataAdaptedRKHSInterpolant(kernel=kernel)
dynamics_model = FeatureLinearModel(
	reg_scaling = 1.,
	feature_map=PolyLib(degree=2,include_bias=False)
)
optsettings = LMSettings(
	max_iter = 1000,
	no_tqdm=True,
	min_alpha = 1e-16,
	init_alpha = 5.,
	print_every = 100,
	show_progress = True,
)
data_weight =  1.
colloc_weight = 1e5

pysindy_opt = STLSQ(threshold = 0.2,alpha = 0.05)
sparsifier = pySindySparsifier(
	pysindy_opt
	)


optimizer = AlternatingActiveSetLMSolver(
		beta_reg=1e-3,
		solver_settings=optsettings,
		fixed_colloc_weight=colloc_weight,
		fixed_data_weight=data_weight,
		sparsifier = sparsifier
		)

model = JSINDyModel(
	trajectory_model=trajectory_model,
	dynamics_model=dynamics_model,
	optimizer=optimizer,
	feature_names=expdata.feature_names
)

model.fit(t_train, X_train,t_colloc=t_colloc)

{'show_progress': True, 'sigma2_est': Array(1.60652631, dtype=float64), 'data_weight': 1.0, 'colloc_weight': 100000.0}
Warm Start
Iteration 0, loss = 8.306e+04, gradnorm = 4.266e+07, alpha = 4.167, improvement_ratio = 0.9939
Iteration 1, loss = 1.356e+04, gradnorm = 2.44e+07, alpha = 3.472, improvement_ratio = 0.849
Iteration 2, loss = 1.23e+03, gradnorm = 1.691e+07, alpha = 2.894, improvement_ratio = 0.9968
Iteration 3, loss = 1.195e+03, gradnorm = 5.717e+05, alpha = 2.894, improvement_ratio = 0.776
Iteration 4, loss = 1.185e+03, gradnorm = 1.898e+05, alpha = 2.411, improvement_ratio = 0.8191
Iteration 5, loss = 1.183e+03, gradnorm = 5.526e+04, alpha = 2.411, improvement_ratio = 0.6373
Line Search Failed!
Final Iteration Results
Iteration 76, loss = 1.179e+03, gradnorm = 0.004093, alpha = 3.325e+05, improvement_ratio = -inf
Model after smooth warm start
(x)' = -9.735 x + 10.557 y + -0.163 z + -0.008 x^2 + 0.017 x y + -0.003 x z + -0.015 y^2 + -0.023 y z + 0.005 z^2
(y)' = 27.693 x + -

In [7]:
metrics = {}

metrics["coeff_mets"] = coeff_metrics(
	coeff_est=model.theta,
	coeff_true=expdata.true_coeff.T[1:]
)
metrics["theta"] = model.theta
metrics['noise_ratio'] = noise_ratio
metrics['t_end'] = tend

In [10]:
metrics

{'coeff_mets': {'precision': 1.0,
  'recall': 1.0,
  'f1': 1.0,
  'coeff_rel_l2': 0.004344839773282527,
  'coeff_rmse': 0.02636388127885087,
  'coeff_mae': 0.008701504177452404},
 'theta': Array([[-9.92647428, 27.88702444,  0.        ],
        [10.00982559, -1.01935785,  0.        ],
        [ 0.        ,  0.        , -2.65983303],
        [ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  1.0052542 ],
        [ 0.        , -0.99283195,  0.        ],
        [ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ]], dtype=float64),
 'noise_ratio': 0.1,
 't_end': 5}