In [4]:
import jax.numpy as jnp
import jax
import os

# Enable float64 first: JAX recommends setting this flag at startup, before
# any arrays are created (the solver outputs in this notebook are float64).
jax.config.update('jax_enable_x64', True)

# Pin computations to accelerator index 1 on this machine. When fewer devices
# are available (e.g. a CPU-only or single-GPU host) fall back to the first
# device instead of crashing with an IndexError.
DEVICE_INDEX = 1
available_devices = jax.devices()
if len(available_devices) > DEVICE_INDEX:
    selected_device = available_devices[DEVICE_INDEX]
else:
    selected_device = available_devices[0]
    print(
        f"Only {len(available_devices)} JAX device(s) found; "
        f"using {selected_device} instead of index {DEVICE_INDEX}."
    )
jax.config.update('jax_default_device', selected_device)
# Do NOT assign to `jax.default_device`: it is JAX's context manager
# (`with jax.default_device(dev): ...`), and rebinding it to a Device object
# breaks every later use of it. The config update above already sets the
# process-wide default device.

from jsindy.trajectory_model import DataAdaptedRKHSInterpolant
from jsindy.sindy_model import JSINDyModel
from jsindy.dynamics_model import FeatureLinearModel
from jsindy.optim import LMSolver, AlternatingActiveSetLMSolver
from jsindy.optim import AlternatingActiveSetLMSolver, LMSettings


from equinox import tree_pprint
import matplotlib.pyplot as plt

from exp.expdata import ExpData, LorenzExp
from exp.metrics import coeff_metrics, data_metrics
import pickle
import jax.numpy as jnp

import time
from equinox import tree_pprint
from jsindy.kernels import softplus_inverse

In [5]:
# Ground-truth observation-noise variance (sigma^2; its square root is later
# passed as `noise`) and the sampling interval of the training data (later
# passed as `dt_train`).
noise_var, dt = 20.0, 0.08

# Experiment class that generates the data; instantiated in the next cell.
exp_data = LorenzExp

In [6]:
# Initial condition of the reference trajectory.
initial_state = jnp.array([0.37719066, -0.39631459, 16.92126795])
# Ground-truth noise variance (sigma^2); its square root is passed as `noise`.
true_sigma2 = noise_var
# Time window of the experiment.
t0 = 0
t1 = 10.1
# Number of points on a `dt`-spaced grid over [t0, t1).
n_train = len(jnp.arange(t0, t1, dt))

# Value for the experiment's own `dt` argument. It is kept out of the call and
# named so it is not confused with the module-level `dt` (the training
# interval, passed as `dt_train`).
# NOTE(review): presumably the resolution of the dense reference data
# (`x_true`, `x_dot`) — confirm in exp.expdata.
dt_fine = 0.01
# Number of collocation points requested from the experiment; the resulting
# `expdata.t_colloc` is what `model.fit` receives.
n_colloc = 600
# Seed passed to the experiment; assumed to control the noise draw — TODO confirm.
data_seed = 32

expdata = exp_data(
    initial_state=initial_state,
    t0=t0,
    t1=t1,
    dt=dt_fine,
    dt_train=dt,
    noise=jnp.sqrt(true_sigma2),
    seed=data_seed,
    n_colloc=n_colloc,
    one_rkey=True,
    feature_names=['x', 'y', 'z'],
)

# Model components: trajectory model, feature-linear dynamics model, and the
# alternating active-set LM optimizer. `init_alpha`/`min_alpha`/`max_alpha`
# configure the `alpha` value reported in the solver's iteration log.
trajectory_model = DataAdaptedRKHSInterpolant()
dynamics_model = FeatureLinearModel()
optsettings = LMSettings(
    max_iter=2000,
    min_alpha=1e-15,
    max_alpha=1e8,
    init_alpha=100.0,
)
# NOTE(review): `beta_reg` is presumably a regularization weight — confirm in jsindy.optim.
optimizer = AlternatingActiveSetLMSolver(beta_reg=0.001, solver_settings=optsettings)

model = JSINDyModel(
    trajectory_model=trajectory_model,
    dynamics_model=dynamics_model,
    optimizer=optimizer,
    feature_names=expdata.feature_names,
)


In [7]:
# Fit on the noisy training samples, using the experiment's collocation times.
model.fit(expdata.t_train, expdata.x_train, expdata.t_colloc)

# Score the fit in one dict: coefficient recovery against the true
# coefficients, then the model's predictions at `expdata.x_true` against
# `expdata.x_dot` (presumably the true derivatives). Entries are evaluated in
# the order listed.
metrics = {
    "coeff_mets": coeff_metrics(
        coeff_est=model.theta.T,
        coeff_true=expdata.true_coeff,
    ),
    "data_mets": data_metrics(
        pred_sim=model.predict(expdata.x_true),
        true=expdata.x_dot,
    ),
}

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

{'show_progress': True, 'sigma2_est': Array(13.39557038, dtype=float64), 'data_weight': Array(0.07459586, dtype=float64), 'colloc_weight': Array(7.45958562, dtype=float64)}
Warm Start


  0%|          | 0/2000 [00:00<?, ?it/s]

Iteration 0, loss = 1.672e+07, gradnorm = 5.925e+09, alpha = 100.0, improvement_ratio = 0.7703
Iteration 1, loss = 1.323e+07, gradnorm = 1.452e+10, alpha = 100.0, improvement_ratio = 0.2086
Iteration 2, loss = 2.474e+06, gradnorm = 2.343e+10, alpha = 83.33, improvement_ratio = 0.813
Iteration 3, loss = 2.268e+06, gradnorm = 1.569e+10, alpha = 506.3, improvement_ratio = 0.08302
Iteration 4, loss = 3.44e+05, gradnorm = 1.87e+10, alpha = 421.9, improvement_ratio = 0.8485
Iteration 5, loss = 3.167e+04, gradnorm = 5.156e+09, alpha = 351.6, improvement_ratio = 0.9084
Iteration 200, loss = 32.21, gradnorm = 7.66e+05, alpha = 211.9, improvement_ratio = 0.7643
Iteration 400, loss = 24.37, gradnorm = 4.542e+05, alpha = 85.17, improvement_ratio = 0.7805
Iteration 600, loss = 20.4, gradnorm = 7.22e+05, alpha = 11.46, improvement_ratio = 0.5529
Iteration 800, loss = 18.34, gradnorm = 1.739e+05, alpha = 7.96, improvement_ratio = 0.7409
Iteration 1000, loss = 17.84, gradnorm = 1.401e+05, alpha = 2.22

In [8]:
# Print the identified equations, one line per state variable.
model.print()

(x)' = 5.346 1 + -11.808 x + 11.434 y
(y)' = 24.703 x + -0.922 x z
(z)' = 3.561 1 + 0.541 x + -2.882 z + 0.906 x y


In [9]:
# Display the coefficient and data metrics computed after fitting.
metrics

{'coeff_mets': {'precision': 0.6666666666666666,
  'recall': 0.8571428571428571,
  'f1': 0.75,
  'coeff_rel_l2': 0.24322808953656205,
  'coeff_rmse': 1.4001371056791037,
  'coeff_mae': 0.5791753443036523},
 'data_mets': {'mse': Array(62.75643304, dtype=float64),
  'rmse': np.float64(7.921895798605845),
  'mae': np.float64(5.035284033792181),
  'max_abs_error': np.float64(50.92665078943514),
  'normalized_mse': Array(0.01613133, dtype=float64),
  'relative_l2_error': np.float64(0.12698891902691825)}}