In [1]:
# Enable 64-bit (double precision) floats in JAX. The flag is set immediately
# after `import jax`, before any arrays are created in this notebook.
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

# Project-local JSINDy components: trajectory model, dynamics model, optimizers, helpers.
from jsindy.trajectory_model import DataAdaptedRKHSInterpolant
from jsindy.sindy_model import JSINDyModel
from jsindy.dynamics_model import FeatureLinearModel
from jsindy.optim import LMSolver, AlternatingActiveSetLMSolver  # NOTE(review): LMSolver is not used in the cells shown
from jsindy.util import get_collocation_points

# Project-local Lorenz solver; its `.evaluate(t)` is used below to build training data.
from data.lorenz import solve_lorenz

In [2]:
# Synthesize noisy observations of a Lorenz trajectory.
# Tunable settings are named here instead of being buried as literals.
LORENZ_X0 = jnp.array([0.37719066, -0.39631459, 16.92126795])  # initial state passed to solve_lorenz
T_END = 10.0     # last training time; included in t_train
NOISE_SEED = 32  # PRNG seed for the observation noise

lorenz_sol = solve_lorenz(initial_state=LORENZ_X0)
dt = 0.05
# Stop half a step past T_END so the floating-point arange reliably includes T_END
# (same 201 samples as the original `arange(0, 10.01, dt)`).
t_train = jnp.arange(0, T_END + dt / 2, dt)

true_sigma2 = 0.5  # variance of the additive Gaussian observation noise
x_vals = jax.vmap(lorenz_sol.evaluate)(t_train)
# Derive the noise shape from the clean states rather than hard-coding (len(t_train), 3),
# so this stays correct if the state dimension ever changes.
x_train = x_vals + jnp.sqrt(true_sigma2) * jax.random.normal(
    jax.random.PRNGKey(NOISE_SEED), x_vals.shape
)

In [3]:
# Collocation points used by model.fit; the count was previously a bare literal.
# NOTE(review): presumably the number of points generated over the span of t_train
# — confirm against jsindy.util.get_collocation_points.
N_COLLOC = 600
t_colloc = get_collocation_points(t_train, N_COLLOC)

In [4]:
# The two model components, both constructed with their default settings:
# the trajectory interpolant and the feature-linear dynamics model.
traj_model, dynamics_model = DataAdaptedRKHSInterpolant(), FeatureLinearModel()

In [5]:
# Alternating active-set solver (the fit log labels it "Alternating Activeset Sparsifier").
# NOTE(review): beta_reg is presumably a regularization weight — confirm in jsindy.optim.
solver = AlternatingActiveSetLMSolver(beta_reg=0.001)
model = JSINDyModel(traj_model, dynamics_model, optimizer=solver)

In [6]:
# Fit the joint model to the noisy observations (t_train, x_train), using the points
# from get_collocation_points as t_colloc. The log below shows a "Warm Start" phase
# followed by the alternating active-set sparsifier.
model.fit(t_train, x_train, t_colloc)

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Warm Start


  0%|          | 0/501 [00:00<?, ?it/s]

Iteration 0, loss = 5.232e+04, gradnorm = 2.169e+09, alpha = 8.438, improvement_ratio = 0.9968
Iteration 1, loss = 2.522e+03, gradnorm = 2.673e+08, alpha = 10.55, improvement_ratio = 0.9652
Iteration 2, loss = 792.7, gradnorm = 2.155e+08, alpha = 8.789, improvement_ratio = 0.9501
Iteration 3, loss = 695.8, gradnorm = 5.31e+07, alpha = 10.99, improvement_ratio = 0.9556
Iteration 4, loss = 686.4, gradnorm = 5.613e+06, alpha = 13.73, improvement_ratio = 0.9189
Iteration 5, loss = 682.3, gradnorm = 1.566e+06, alpha = 17.17, improvement_ratio = 0.9552
Iteration 200, loss = 666.0, gradnorm = 19.21, alpha = 19.51, improvement_ratio = 1.0
Iteration 400, loss = 665.8, gradnorm = 5.372, alpha = 20.06, improvement_ratio = 1.0
Iteration 500, loss = 665.8, gradnorm = 3.485, alpha = 20.34, improvement_ratio = 1.0
Alternating Activeset Sparsifier
8 active coeffs changed
Active set stabilized


In [7]:
# Print the identified equations.
# NOTE(review): if data.lorenz uses the classic parameters (sigma=10, rho=28, beta=8/3)
# — confirm — the recovered coefficients are close to them, apart from a small
# spurious constant term in (x0)'.
model.print()

(x0)' = 0.292 1 + -9.937 x0 + 9.944 x1
(x1)' = 28.026 x0 + -0.993 x1 + -1.002 x0 x2
(x2)' = -2.659 x2 + 0.997 x0 x1
