In [2]:
import jax
import os
os.environ["CUDA_VISIBLE_DEVICES"]="4"
jax.config.update('jax_enable_x64',True)
# jax.config.update('jax_default_device',jax.devices()[4])

from jax.random import key
from scipy.integrate import solve_ivp
from tqdm.auto import tqdm
from exp.expdata import LorenzExp
import jax.numpy as jnp
import matplotlib.pyplot as plt
from exp.metrics import coeff_metrics, data_metrics
plt.style.use("ggplot")

from jsindy.sindy_model import JSINDyModel
from jsindy.util import get_collocation_points_weights
from jsindy.trajectory_model import DataAdaptedRKHSInterpolant,CholDataAdaptedRKHSInterpolant
from jsindy.dynamics_model import FeatureLinearModel, PolyLib
from jsindy.optim import AlternatingActiveSetLMSolver, LMSettings
from jsindy.optim.solvers.alt_active_set_lm_solver import pySindySparsifier
from pysindy import STLSQ,SSR,MIOSR
from jsindy.kernels import ConstantKernel, ScalarMaternKernel
import pickle
from pathlib import Path


In [3]:
x0 = jnp.array([-8, 8, 27.])
dt = 0.01
t0=0
t1=10.1
n_colloc = 505

expdata = LorenzExp(
    dt = dt,
    initial_state=x0,
    feature_names=['x','y','z'],
    t0=t0,
    t1=t1,
    n_colloc=n_colloc
)

tEndL = jnp.arange(4.0, 11.0, 1.0)
epsL = jnp.arange(0.025, 0.401, 0.025)

t_true = expdata.t_true
X_true = expdata.x_true

cutoff = 1
signal_power = jnp.std(X_true)
n_colloc = 500


In [5]:
tend = 5
noise_ratio = 0.4
rkey = jax.random.key(12038)
t_end_idx = int(tend // dt)
X_train = X_true[:t_end_idx]
t_train = t_true[:t_end_idx]

t_colloc, w_colloc = get_collocation_points_weights(t_train,n_colloc)

eps = noise_ratio*signal_power


noise = eps*jax.random.normal(rkey, X_train.shape)

X_train = X_train + noise

kernel = (
	ConstantKernel(variance = 5.)
	+ScalarMaternKernel(p = 5,variance = 10., lengthscale=3,min_lengthscale=0.05)
)   
trajectory_model = CholDataAdaptedRKHSInterpolant(kernel=kernel)
dynamics_model = FeatureLinearModel(
	reg_scaling = 1.,
	feature_map=PolyLib(degree=2)
)
optsettings = LMSettings(
	max_iter = 1000,
	no_tqdm=True,
	min_alpha = 1e-16,
	init_alpha = 5.,
	print_every = 100,
	show_progress = True,
)
data_weight =  100#1/(eps**2)
colloc_weight = 1e5

pysindy_opt = STLSQ(threshold = 0.4,alpha = 0.2)
sparsifier = pySindySparsifier(
	pysindy_opt
	)


optimizer = AlternatingActiveSetLMSolver(
		beta_reg=1e-4,
		solver_settings=optsettings,
		fixed_colloc_weight=colloc_weight,
		fixed_data_weight=data_weight,
		sparsifier = sparsifier
		)

model = JSINDyModel(
	trajectory_model=trajectory_model,
	dynamics_model=dynamics_model,
	optimizer=optimizer,
	feature_names=expdata.feature_names
)

model.fit(t_train, X_train,t_colloc=t_colloc)

{'show_progress': True, 'sigma2_est': Array(25.69798702, dtype=float64), 'data_weight': 100, 'colloc_weight': 100000.0}
Warm Start
Iteration 0, loss = 3.348e+07, gradnorm = 7.64e+07, alpha = 5.0, improvement_ratio = 0.5046
Iteration 1, loss = 3.727e+06, gradnorm = 1.849e+08, alpha = 4.167, improvement_ratio = 0.9416
Iteration 2, loss = 2.06e+06, gradnorm = 6.491e+07, alpha = 3.472, improvement_ratio = 0.905
Iteration 3, loss = 2.033e+06, gradnorm = 4.585e+07, alpha = 1.216e+03, improvement_ratio = 0.1536
Iteration 4, loss = 1.945e+06, gradnorm = 3.604e+07, alpha = 100.0, improvement_ratio = 0.5906
Iteration 5, loss = 1.929e+06, gradnorm = 1.153e+07, alpha = 2.563e+03, improvement_ratio = 0.2567
Iteration 100, loss = 1.883e+06, gradnorm = 1.334e+05, alpha = 405.0, improvement_ratio = 0.09176
Iteration 200, loss = 1.882e+06, gradnorm = 2.428e+05, alpha = 5.307, improvement_ratio = 0.4787
Line Search Failed!
Final Iteration Results
Iteration 270, loss = 1.879e+06, gradnorm = 0.1385, alpha

In [6]:
model.print()

(x)' = -4.002 1 + -9.906 x + 10.387 y
(y)' = 2.113 1 + 27.158 x + -1.125 y + -0.960 x z
(z)' = -1.472 1 + -2.581 z + 1.018 x y


In [95]:
n = 500
t = jnp.linspace(0,5,500)
xdot = model.traj_model.derivative(t,model.z)
X = model.predict_state(t,model.z)


A = model.dynamics_model.feature_map(X)
normalizers = jnp.linalg.norm(A,axis=0)
A = A/normalizers

In [117]:
import pysindy as ps
opt = ps.STLSQ(threshold=100.,alpha = 0.)
opt.fit(A,xdot)

theta = opt.coef_.T/normalizers[:,None]
# theta = opt.coef_.T
model.print(theta)

(x)' = -9.905 x + 10.233 y
(y)' = 27.084 x + -1.026 y + -0.958 x z
(z)' = -2.636 z + 1.018 x y


In [125]:
def compute_bic(A, B, mask):
	"""
	Compute the Bayesian Information Criterion (BIC) for the least squares fit |A x - B|^2,
	using only the columns of A selected by the boolean mask for each target.

	Parameters
	----------
	A : array-like, shape (n_samples, n_features)
		The design matrix.
	B : array-like, shape (n_samples, n_targets)
		The target matrix.
	mask : array-like, shape (n_features, n_targets)
		Boolean mask indicating which columns of A to use for each target.

	Returns
	-------
	bic : float
		The BIC value for the masked model.
	"""
	n, n_features = A.shape
	_, n_targets = B.shape
	bic = 0.0
	for i in range(n_targets):
		mask_i = mask[:, i]
		A_masked = A[:, mask_i]
		B_i = B[:, i]
		# Solve least squares for this target
		X_hat, resid, rank, s = jnp.linalg.lstsq(A_masked, B_i, rcond=None)
		if resid.size == 0:
			RSS = jnp.sum((A_masked @ X_hat - B_i) ** 2)
		else:
			RSS = jnp.sum(resid)
		k = A_masked.shape[1]
		# BIC for this target
		bic += n * jnp.log(RSS / n) + k * jnp.log(n)
	return bic

In [148]:
mask = (theta!=0)
initial_bic = compute_bic(A,xdot,mask)

active_entries = jnp.where(mask!=0)
removal_bic_vals = []
for i,j in zip(*active_entries):
	mod_mask = mask.at[i,j].set(False)
	removal_bic_vals.append(compute_bic(A,xdot,mod_mask))

inactive_entries = jnp.where(mask==0)
addition_bic_vals = []
for i,j in zip(*inactive_entries):
	mod_mask = mask.at[i,j].set(True)
	addition_bic_vals.append(compute_bic(A,xdot,mod_mask))

In [149]:
initial_bic

Array(1398.06082333, dtype=float64)

In [150]:
addition_bic_vals

[Array(-2536.09207331, dtype=float64),
 Array(-1313.91870381, dtype=float64),
 Array(193.86202266, dtype=float64),
 Array(1403.60529016, dtype=float64),
 Array(1394.74464929, dtype=float64),
 Array(378.21688324, dtype=float64),
 Array(442.10720563, dtype=float64),
 Array(1175.62109848, dtype=float64),
 Array(1156.99174836, dtype=float64),
 Array(1042.58219352, dtype=float64),
 Array(1234.19416599, dtype=float64),
 Array(1206.19122905, dtype=float64),
 Array(1403.11019332, dtype=float64),
 Array(1395.46178158, dtype=float64),
 Array(1222.01165383, dtype=float64),
 Array(1189.5745105, dtype=float64),
 Array(1220.57835236, dtype=float64),
 Array(1404.07021544, dtype=float64),
 Array(1402.6272446, dtype=float64),
 Array(1404.21244596, dtype=float64),
 Array(875.4626167, dtype=float64),
 Array(904.76570516, dtype=float64),
 Array(496.36047039, dtype=float64)]

In [147]:
mask

Array([[False, False, False],
       [ True,  True, False],
       [ True,  True, False],
       [False, False,  True],
       [False, False, False],
       [False, False,  True],
       [False,  True, False],
       [False, False, False],
       [False, False, False],
       [False, False, False]], dtype=bool)

In [66]:
normalizers[:,None]*opt.coef_.T

Array([[-2.00038488e+03,  1.12066567e+03, -6.63157518e+02],
       [-3.22109141e+05,  8.83205530e+05, -5.14116645e+02],
       [ 4.32292400e+05, -4.62839424e+04, -7.36372294e+02],
       [-1.74593980e+02, -1.75516043e+03, -8.21794129e+05],
       [ 4.52420677e+03,  7.40463816e+03, -1.51603333e+03],
       [-6.42732345e+02, -2.84350932e+03,  5.64678247e+06],
       [-6.56864375e+03, -3.05323977e+07,  2.05271232e+04],
       [-2.17009786e+02,  4.35982568e+03, -0.00000000e+00],
       [ 0.00000000e+00, -1.73121578e+04,  1.46347930e+04],
       [-1.68820627e+04, -4.63694400e+04,  5.66447113e+04]],      dtype=float64)

In [29]:
x,resid,_,_ = jnp.linalg.lstsq(A[:,1:],A[:,0])

In [35]:
jnp.sqrt(resid)/jnp.linalg.norm(A[:,0])

Array([0.09067546], dtype=float64)

In [30]:
metrics = {}

metrics["coeff_mets"] = coeff_metrics(
	coeff_est=model.theta,
	coeff_true=expdata.true_coeff.T
)
metrics["theta"] = model.theta
metrics['noise_ratio'] = noise_ratio
metrics['t_end'] = tend

In [31]:
metrics

{'coeff_mets': {'precision': 0.7,
  'recall': 1.0,
  'f1': 0.8235294117647058,
  'coeff_rel_l2': 0.15387476680622142,
  'coeff_rmse': 0.8857766841141267,
  'coeff_mae': 0.30595960121342514},
 'theta': Array([[-4.00192846,  2.11280467, -1.47178993],
        [-9.90613514, 27.15824938,  0.        ],
        [10.38740817, -1.12531337,  0.        ],
        [ 0.        ,  0.        , -2.58052028],
        [ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  1.01803583],
        [ 0.        , -0.96025426,  0.        ],
        [ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ]], dtype=float64),
 'noise_ratio': 0.4,
 't_end': 5}