In [3]:
import jax.numpy as jnp
import jax
import os
# Switch JAX from its 32-bit default to 64-bit (double precision) arrays.
# NOTE(review): this flag is meant to be set at startup, before any arrays are
# created, so keep it directly after `import jax`.
jax.config.update('jax_enable_x64',True)

from jsindy.trajectory_model import DataAdaptedRKHSInterpolant
from jsindy.sindy_model import JSINDyModel
from jsindy.dynamics_model import FeatureLinearModel
from jsindy.optim import LMSolver, AlternatingActiveSetLMSolver
from jsindy.optim import AlternatingActiveSetLMSolver, LMSettings


from equinox import tree_pprint
import matplotlib.pyplot as plt

from exp.expdata import ExpData, LorenzExp
from exp.metrics import coeff_metrics, data_metrics
import pickle
import jax.numpy as jnp

import time
from equinox import tree_pprint
from jsindy.kernels import softplus_inverse

In [4]:
# Experiment-level knobs consumed by the setup cell below.

# Experiment class (not an instance); it is called with keyword arguments later.
exp_data = LorenzExp

# Noise variance; the setup cell passes its square root as the `noise` argument.
noise_var = 16.0

# Spacing handed to the experiment as `dt_train`.
dt = 0.08

In [5]:
# ---------------------------------------------------------------------------
# Experiment data
# ---------------------------------------------------------------------------
# Time window and starting point for the experiment.
t0 = 0
t1 = 10.1
initial_state = jnp.array([0.37719066, -0.39631459, 16.92126795])

# Noise variance (sigma^2); its square root is what gets passed as `noise`.
true_sigma2 = noise_var

# Length of the dt-spaced grid on [t0, t1).
n_train = len(jnp.arange(t0, t1, dt))

# Value forwarded to the experiment as `n_colloc`.
n_colloc = 600

expdata = exp_data(
    initial_state=initial_state,
    t0=t0,
    t1=t1,
    dt=0.01,  # presumably the fine step of the reference trajectory - TODO confirm
    dt_train=dt,
    noise=jnp.sqrt(true_sigma2),
    seed=32,
    n_colloc=n_colloc,
    one_rkey=True,
    feature_names=['x', 'y', 'z'],
)

# ---------------------------------------------------------------------------
# Model components
# ---------------------------------------------------------------------------
trajectory_model = DataAdaptedRKHSInterpolant()
dynamics_model = FeatureLinearModel()

# Settings object handed to the solver as `solver_settings`.
optsettings = LMSettings(
    max_iter=2000,
    min_alpha=1e-15,
    max_alpha=1e8,
    init_alpha=100.0,
)
optimizer = AlternatingActiveSetLMSolver(
    beta_reg=0.001,
    solver_settings=optsettings,
)

# Assemble the full model from the three components above.
model = JSINDyModel(
    trajectory_model=trajectory_model,
    dynamics_model=dynamics_model,
    optimizer=optimizer,
    feature_names=expdata.feature_names,
)


In [6]:
# Fit on the training samples; the third positional argument is the
# experiment's `t_colloc` array.
model.fit(expdata.t_train, expdata.x_train, expdata.t_colloc)

# Score the fit in two ways and collect both results in one dict:
#   * "coeff_mets": transposed fitted coefficients vs. the experiment's true ones
#   * "data_mets":  model.predict on the true states vs. expdata.x_dot
metrics = {
    "coeff_mets": coeff_metrics(
        coeff_est=model.theta.T,
        coeff_true=expdata.true_coeff,
    ),
    "data_mets": data_metrics(
        pred_sim=model.predict(expdata.x_true),
        true=expdata.x_dot,
    ),
}

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

{'show_progress': True, 'sigma2_est': Array(10.65914167, dtype=float64), 'data_weight': Array(0.09372825, dtype=float64), 'colloc_weight': Array(9.37282521, dtype=float64)}
Warm Start


  0%|          | 0/2000 [00:00<?, ?it/s]

Iteration 0, loss = 2.092e+07, gradnorm = 8.691e+09, alpha = 100.0, improvement_ratio = 0.7961
Iteration 1, loss = 1.4e+07, gradnorm = 1.925e+10, alpha = 100.0, improvement_ratio = 0.3305
Iteration 2, loss = 6.65e+06, gradnorm = 5.021e+10, alpha = 100.0, improvement_ratio = 0.5251
Iteration 3, loss = 8.849e+05, gradnorm = 3.975e+10, alpha = 83.33, improvement_ratio = 0.8669
Iteration 4, loss = 5.456e+05, gradnorm = 1.147e+10, alpha = 632.8, improvement_ratio = 0.3836
Iteration 5, loss = 1.629e+04, gradnorm = 5.185e+09, alpha = 527.3, improvement_ratio = 0.9706
Iteration 200, loss = 37.98, gradnorm = 7.634e+05, alpha = 254.3, improvement_ratio = 0.7596
Iteration 400, loss = 29.53, gradnorm = 5.215e+05, alpha = 19.81, improvement_ratio = 0.7816
Iteration 600, loss = 24.55, gradnorm = 5.329e+05, alpha = 16.51, improvement_ratio = 0.7485
Iteration 800, loss = 22.92, gradnorm = 5.41e+05, alpha = 5.528, improvement_ratio = 0.7575
Iteration 1000, loss = 22.34, gradnorm = 1.796e+05, alpha = 3.

In [7]:
# Print the fitted equations, one line per state variable.
model.print()

(x)' = 4.948 1 + -26.554 x + 26.019 y
(y)' = -0.769 1 + 31.817 x + -1.304 x z
(z)' = 77.606 1 + 1.141 y + -7.159 z + 13.706 x^2


In [8]:
# Bare expression: display the metrics dict built in the fit cell as this cell's output.
metrics

{'coeff_mets': {'precision': 0.5,
  'recall': 0.7142857142857143,
  'f1': 0.5882352941176471,
  'coeff_rel_l2': 2.6162234036721284,
  'coeff_rmse': 15.060232028327443,
  'coeff_mae': 4.711856375524261},
 'data_mets': {'mse': Array(457623.76148525, dtype=float64),
  'rmse': np.float64(676.4789438594881),
  'mae': np.float64(291.29082313613156),
  'max_abs_error': np.float64(4056.3058809207914),
  'normalized_mse': Array(117.6306395, dtype=float64),
  'relative_l2_error': np.float64(10.844036832737183)}}