In [1]:
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from jsindy.trajectory_model import DataAdaptedRKHSInterpolant
from jsindy.sindy_model import JSINDyModel
from jsindy.dynamics_model import FeatureLinearModel
from jsindy.optim import AlternatingActiveSetLMSolver, LMSettings
from jsindy.util import get_collocation_points

from data.lorenz import solve_lorenz
from exp import LorenzExp
import numpy as np

In [2]:
# Experiment reference object (provides e.g. `true_coeff`); the state
# variables are labelled x, y, z.
expdata = LorenzExp(feature_names=list("xyz"))

In [3]:
# Fixed initial condition for the reference Lorenz trajectory.
x0 = jnp.array([0.37719066, -0.39631459, 16.92126795])
lorenz_sol = solve_lorenz(initial_state=x0)
dt = 0.05
t_end = 10.0
# Pad the stop value by half a step so that t_end itself is included despite
# floating-point rounding, without adding samples past t_end (a fixed +0.01
# pad overshoots t_end as soon as dt < 0.01).
t_train = jnp.arange(0.0, t_end + dt / 2, dt)

# Observation-noise variance; 0.0 makes the training data noise-free.
true_sigma2 = 0.0
noise_seed = 32

x_vals = jax.vmap(lorenz_sol.evaluate)(t_train)
# Draw noise with exactly the shape of the clean states instead of a
# hard-coded (len(t_train), 3), so the cell stays correct if the state
# dimension changes. For the 3-state system this is the same draw as before.
noise = jax.random.normal(jax.random.PRNGKey(noise_seed), x_vals.shape)
x_train = x_vals + jnp.sqrt(true_sigma2) * noise

In [4]:
# Collocation times derived from the training times (600 requested).
n_colloc = 600
t_colloc = get_collocation_points(t_train, n_colloc)

In [5]:
# The two components handed to JSINDyModel: a trajectory model and a
# dynamics model, both with default construction.
traj_model = DataAdaptedRKHSInterpolant()
dynamics_model = FeatureLinearModel()

# Solver settings; these flags presumably silence progress bars / progress
# logging during fitting (TODO confirm against jsindy.optim).
optsettings = LMSettings(
    no_tqdm=True,
    show_progress=False,
)

In [6]:
# Alternating active-set LM optimizer; beta_reg is the regularization weight
# passed to the solver.
optimizer = AlternatingActiveSetLMSolver(
    beta_reg=0.001,
    solver_settings=optsettings,
)

model = JSINDyModel(
    traj_model,
    dynamics_model,
    optimizer=optimizer,
    feature_names=expdata.feature_names,
)

In [7]:
# Fit the model to the training samples, with t_colloc as the collocation times.
model.fit(t_train, x_train, t_colloc)

In [8]:
# Print the equations identified by the fitted model.
model.print()

(x)' = -10.000 x + 10.000 y
(y)' = 28.000 x + -1.000 y + -1.000 x z
(z)' = -2.667 z + 1.000 x y


In [9]:
# Print the ground-truth equations in the same format, for direct comparison
# with the identified ones.
model.print(theta=expdata.true_coeff.T)

(x)' = -10.000 x + 10.000 y
(y)' = 28.000 x + -1.000 y + -1.000 x z
(z)' = -2.667 z + 1.000 x y


In [11]:
# Raw learned coefficients, transposed so that each row corresponds to one
# state equation (matching the printed equations).
model.theta.T

Array([[  0.        , -10.00000005,   9.99999972,   0.        ,
          0.        ,   0.        ,   0.        ,   0.        ,
          0.        ,   0.        ],
       [  0.        ,  27.99999912,  -0.99999948,   0.        ,
          0.        ,   0.        ,  -0.99999998,   0.        ,
          0.        ,   0.        ],
       [  0.        ,   0.        ,   0.        ,  -2.66666674,
          0.        ,   0.99999999,   0.        ,   0.        ,
          0.        ,   0.        ]], dtype=float64)

In [8]:
# NOTE: this cell was executed out of order and previously rebound `expdata`
# to `LorenzExp()`, silently discarding the feature_names given when the
# object was first created. Construct it identically here so re-running
# this cell in any order leaves `expdata` in the same state.
expdata = LorenzExp(feature_names=['x', 'y', 'z'])

In [16]:
# Relative (Frobenius-norm) error of the learned coefficients against the truth.
coeff_residual = expdata.true_coeff - model.theta.T
true_coeff_norm = np.linalg.norm(expdata.true_coeff)
np.linalg.norm(coeff_residual) / true_coeff_norm

np.float64(3.386755451713351e-08)

In [21]:
# Model prediction on the reference states; its shape is compared with
# expdata.x_dot in the next cell.
predicted = model.predict(expdata.x_true)
predicted.shape

(1000, 3)

In [22]:
# Shape of the reference derivative data, to check against the prediction.
np.shape(expdata.x_dot)

(1000, 3)