In [None]:
import os
import jax
# Enable 64-bit floats in JAX right after importing it (JAX defaults to float32).
jax.config.update('jax_enable_x64',True)
from pathlib import Path

# Project modules: JSINDy model components, solver/sparsifier, and experiment/metric helpers.
from jsindy.sindy_model import JSINDyModel
from jsindy.trajectory_model import DataAdaptedRKHSInterpolant,CholDataAdaptedRKHSInterpolant
from jsindy.dynamics_model import FeatureLinearModel
from jsindy.optim import AlternatingActiveSetLMSolver, LMSettings
from jsindy.optim.solvers.alt_active_set_lm_solver import pySindySparsifier
from pysindy import STLSQ,SSR,MIOSR
from exp.expdata import ExpData, LorenzExp
from exp.metrics import coeff_metrics, data_metrics
import pickle
import jax.numpy as jnp
import time
import matplotlib.pyplot as plt
from pysindy import EnsembleOptimizer
# Global matplotlib style for every figure in the notebook.
plt.style.use('ggplot')
import matplotlib as mpl
# Uncomment to typeset figure text with LaTeX (needs a working TeX installation).
# mpl.rcParams.update({"text.usetex":True})
# NOTE(review): os, pickle, time, ExpData, DataAdaptedRKHSInterpolant, SSR and MIOSR are not
# referenced by any active code line in this notebook — candidates for removal.


In [None]:
# Experiment setup: noisy observations of the Lorenz system on [t0, t1).
exp_data = LorenzExp
# Only this initial condition reaches the experiment (an earlier draft assigned
# [-5, 0, 5] here and immediately overwrote it).
initial_state = jnp.array([-8, 8, 27.])
dt = 0.025  # observation spacing, passed to the experiment as dt_train
noise_var = 2.
# sigma^2 - var
true_sigma2 = noise_var
t0 = 0
t1 = 10.1
# Number of points jnp.arange(t0, t1, dt) produces; not read by any later visible cell.
n_train = len(jnp.arange(t0, t1, dt))

n_colloc = 505
expdata = exp_data(
    initial_state=initial_state,
    t0=t0,
    t1=t1,
    # NOTE(review): presumably the resolution of the reference solution — confirm in LorenzExp.
    dt=0.01,
    dt_train=dt,
    # sqrt of the variance, i.e. the noise standard deviation.
    noise=jnp.sqrt(true_sigma2),
    seed=29,
    n_colloc=n_colloc,
    one_rkey=True,
    feature_names=['x', 'y', 'z'],
)

In [None]:
# Sanity check before fitting: one figure per state component, showing the noisy
# training observations against the true solution on a fine grid.
t_grid = jnp.linspace(t0, t1, 1000)
true_states = jax.vmap(expdata.system_sol.evaluate)(t_grid)
for state_idx in range(3):
    fig, ax = plt.subplots(figsize=(12, 4))
    ax.scatter(expdata.t_train, expdata.x_train[:, state_idx], label='observations', s=9)
    ax.plot(t_grid, true_states[:, state_idx], label='truth', c='black')
    ax.legend()
    plt.show()


In [None]:
# Fit JSINDy: an RKHS trajectory model (the `Chol` data-adapted variant) jointly with a
# feature-linear dynamics model, solved by the alternating active-set LM solver with an
# STLSQ sparsifier.
# Other sparsifiers tried here: a bagged EnsembleOptimizer over STLSQ(threshold=0.5,
# alpha=0.1) with n_models=100, and MIOSR(target_sparsity=7, alpha=0.1).
trajectory_model = CholDataAdaptedRKHSInterpolant()
dynamics_model = FeatureLinearModel(reg_scaling=1.)

lm_settings = LMSettings(
    max_iter=1000,
    show_progress=True,
    no_tqdm=False,
    min_alpha=1e-16,
    init_alpha=5.,
    print_every=100,
)

# Fixed (not adapted) weights for the data-fit and collocation terms.
data_weight = 1.
colloc_weight = 1e5

sparsifier = pySindySparsifier(STLSQ(threshold=0.5, alpha=0.01))

jsindy_optimizer = AlternatingActiveSetLMSolver(
    beta_reg=1e-3,
    solver_settings=lm_settings,
    fixed_colloc_weight=colloc_weight,
    fixed_data_weight=data_weight,
    sparsifier=sparsifier,
)

model = JSINDyModel(
    trajectory_model=trajectory_model,
    dynamics_model=dynamics_model,
    optimizer=jsindy_optimizer,
    feature_names=expdata.feature_names,
)
model.fit(
    expdata.t_train,
    expdata.x_train,
    t_colloc=expdata.t_colloc,
    w_colloc=expdata.w_colloc,
)

# Score the fit: estimated vs. true coefficients, and model.predict on the true states
# vs. expdata.x_dot (presumably the true derivatives — confirm in exp.metrics).
metrics = {
    "coeff_mets": coeff_metrics(coeff_est=model.theta.T, coeff_true=expdata.true_coeff),
    "data_mets": data_metrics(pred_sim=model.predict(expdata.x_true), true=expdata.x_dot),
    "model_params": model.params,
}


In [None]:
# Print the fitted JSINDy model (presumably its identified equations).
model.print()

In [None]:
# Display the solver settings on the fitted model's optimizer (presumably the LMSettings
# passed as `solver_settings` when the optimizer was built).
model.optimizer.solver_settings

In [None]:
# Plain-text representation of the fitted model object.
print(model)

In [None]:
# Print the experiment's reference system (presumably the true Lorenz equations) for comparison.
expdata.print()

In [None]:
# NOTE(review): repeats the earlier `model.print()` cell, placed directly after
# `expdata.print()` for side-by-side comparison.
model.print()

In [None]:
# Compare JSINDy's state estimates with the observations and the true trajectory,
# one panel per state component, then save the figure to figures/lorenz_prelim.pdf.
t_grid = jnp.linspace(0, 10, 500)
state_preds = model.predict_state(t_grid)
true_states = jax.vmap(expdata.system_sol.evaluate)(t_grid)

fig, axes = plt.subplots(3, 1, figsize=(12, 8))
for dim, ax in enumerate(axes):
    ax.scatter(expdata.t_train, expdata.x_train[:, dim], label='Observations', s=9)
    ax.plot(t_grid, true_states[:, dim], label='True Trajectory', c='black')
    ax.plot(t_grid, state_preds[:, dim], label='State Estimates', c='blue', alpha=0.3, lw=5)
    if dim == 0:
        ax.legend()
    ax.set_ylabel(f"$x_{dim+1}(t)$")
axes[-1].set_xlabel("$t$")

# savefig does not create missing directories; without this a fresh checkout fails here.
fig_path = Path("figures") / "lorenz_prelim.pdf"
fig_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(fig_path)

In [None]:
# Baseline: standard pySINDy on the same noisy data, with derivatives from smoothed
# finite differences and the same STLSQ settings as the JSINDy sparsifier.
import numpy as np
import pysindy as ps

from pysindy import SmoothedFiniteDifference
from pysindy import STLSQ

t_train = np.array(expdata.t_train)
x_train = np.array(expdata.x_train)
feature_library = ps.PolynomialLibrary(degree=2)

# Plain STLSQ is what is fit. A bagged EnsembleOptimizer(STLSQ(...), bagging=True,
# n_models=200) used to be built here and was immediately overwritten, so it is dropped.
stlsq_optimizer = STLSQ(threshold=0.5, alpha=0.01)
ps_model = ps.SINDy(
    differentiation_method=SmoothedFiniteDifference(),
    feature_library=feature_library,
    optimizer=stlsq_optimizer,
    feature_names=["x", "y", "z"],
)
ps_model.fit(x_train, t=t_train)
ps_model.print()
print()
expdata.print()

In [None]:
# Rich display of the fitted pySINDy baseline estimator.
ps_model

In [None]:
# Prints the raw HTML markup string from the library's private `_repr_html_` hook.
# NOTE(review): `display(ps_model.feature_library)` would render it instead of dumping markup.
print(ps_model.feature_library._repr_html_())

In [None]:
# Custom-library terms in the style of pySINDy's weak-form example: u, u^2 and the
# pairwise product u*v. Unlike the earlier PolynomialLibrary(degree=2), there is no
# constant term. The name functions build each term's label by string concatenation,
# e.g. "x", "xx", "xy".
# NOTE(review): neither list is passed to any library below; the weak-form cell uses
# the polynomial `feature_library` instead.
library_functions = [
    lambda u: u,
    lambda u: u * u,
    lambda u, v: u * v,
]
library_function_names = [
    lambda u: u,
    lambda u: u + u,
    lambda u, v: u + v,
]


In [None]:
# Weak-form SINDy baseline: wrap the degree-2 polynomial library in a WeakPDELibrary
# on the uniform training grid (K=500), and sparsify with a bagged ensemble of 200
# STLSQ fits. An SR3 (l0, normalized columns) alternative was tried previously.
ode_lib = ps.WeakPDELibrary(
    function_library=feature_library,
    spatiotemporal_grid=t_train,
    is_uniform=True,
    K=500,
)
weak_optimizer = EnsembleOptimizer(
    STLSQ(threshold=0.3, alpha=0.005),
    bagging=True,
    n_models=200,
)
# Bagging draws random subsamples, so fix the RNG to make the printed model reproducible
# across Restart & Run All (same seed as the experiment data).
# NOTE(review): assumes EnsembleOptimizer samples via numpy's global RNG — confirm for the
# installed pysindy version.
np.random.seed(29)
weak_ps_model = ps.SINDy(feature_library=ode_lib, optimizer=weak_optimizer)
weak_ps_model.fit(x_train, t=t_train)
weak_ps_model.print()


In [None]:
# Display the JSINDy model's identified coefficients.
# NOTE(review): the metrics cell uses `model.theta.T`; confirm which orientation `coefficients()` returns.
model.coefficients()