In [None]:
import os

# Pin the GPU before JAX (or anything that may import it, such as exp.expdata) is loaded:
# CUDA_VISIBLE_DEVICES only takes effect if it is set before JAX initialises its GPU
# backend, so assigning it in a later cell, after jnp operations have run, is ignored.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

from exp.expdata import RosslerExp
import matplotlib.pyplot as plt
import numpy as np
from scipy.fft import fft, fftfreq
from scipy.signal import find_peaks
from scipy.integrate import solve_ivp
import jax
import jax.numpy as jnp

# JAX defaults to 32-bit floats; enable float64 for the fitting / ODE work below.
jax.config.update('jax_enable_x64', True)

plt.style.use('ggplot')
plt.rcParams['font.family'] = 'serif'

savefigs = False

In [None]:
# Experiment configuration: noise level, training sampling step and time horizon
# for the Rossler data set (semantics of each argument as defined by RosslerExp).
noise = 0.0
dt_train = 0.3
t1 = 200.

expdata = RosslerExp(
    t0=0.0,
    t1=t1,
    noise=noise,
    dt_train=dt_train,
    n_colloc=500,
    feature_names=['x', 'y', 'z'],
)

In [None]:
# Print the ground-truth ("True model") equations that generated the data.
expdata.print()

In [None]:
x_true = expdata.x_true
t_true = expdata.t_true
x_train = expdata.x_train
t_train = expdata.t_train
t_colloc = expdata.t_colloc

# Keep only the first and last `cutoff` training samples and drop the middle of the
# trajectory. Later cells integrate the learned model across this gap, between
# indices cutoff-1 and cutoff of the concatenated arrays.
cutoff = 150
assert 2 * cutoff < len(t_train), (
    f"cutoff={cutoff} leaves no gap in {len(t_train)} training samples: "
    "the head and tail slices would overlap and duplicate points"
)

x_train = jnp.concat([x_train[:cutoff], x_train[-cutoff:]])
t_train = jnp.concat([t_train[:cutoff], t_train[-cutoff:]])

fig, axs = plt.subplots(3, 1, figsize=(12, 5), sharex=True)

y_labels = ["x", "y", "z"]
for idx, ax in enumerate(axs):
    ax.plot(t_true, x_true[:, idx], c='black')
    ax.scatter(t_train, x_train[:, idx], zorder=2, facecolors='black', edgecolors='red',
               marker='.', s=150, lw=1)
    ax.set_ylabel(y_labels[idx])
# Axes share x, so only the bottom panel needs an x label.
axs[-1].set_xlabel("t")

plt.suptitle("Rossler System")
plt.tight_layout()

# Model Learning

In [None]:
import os
import jax
# NOTE: CUDA_VISIBLE_DEVICES is deliberately not assigned here. It only takes effect if
# set before JAX initialises its GPU backend, which has already happened by this point
# (jnp operations run in the data cell above), so assigning it here would be a no-op.
# Select the GPU at the very top of the notebook, before JAX is imported.
# jax_enable_x64 is already enabled in the first cell and is not re-applied here.

from jsindy.sindy_model import JSINDyModel
from jsindy.trajectory_model import CholDataAdaptedRKHSInterpolant
from jsindy.dynamics_model import FeatureLinearModel, PolyLib
from jsindy.optim import AlternatingActiveSetLMSolver, LMSettings, AnnealedAlternatingActiveSetLMSolver
from jsindy.optim.solvers.alt_active_set_lm_solver import pySindySparsifier
from pysindy import STLSQ
from jsindy.kernels import ConstantKernel, ScalarMaternKernel

In [None]:

# Trajectory kernel. The hand-tuned Constant + Matern kernel is kept for experimentation
# but is disabled; with kernel=None, CholDataAdaptedRKHSInterpolant presumably uses its
# own default kernel (confirm in jsindy). Set use_custom_kernel = True to try it.
use_custom_kernel = False
if use_custom_kernel:
    kernel = (
        ConstantKernel(variance=5.)
        + ScalarMaternKernel(p=5, variance=10., lengthscale=3, min_lengthscale=0.05)
    )
else:
    kernel = None

trajectory_model = CholDataAdaptedRKHSInterpolant(kernel=kernel)

# Dynamics are linear in a degree-2 polynomial feature library
# (the Rossler vector field is at most quadratic in the state).
dynamics_model = FeatureLinearModel(
    reg_scaling=1.,
    feature_map=PolyLib(degree=2),
)

# Levenberg-Marquardt inner-solver settings (alpha values presumably control the
# LM damping; confirm in jsindy.optim.LMSettings).
optsettings = LMSettings(
    max_iter=2000,
    atol_gradnorm=1e-8,
    show_progress=True,
    no_tqdm=False,
    min_alpha=1e-16,
    init_alpha=5.,
)

# Relative weights of the data-fit and collocation terms of the objective.
data_weight = 1.
colloc_weight = 1e4

# STLSQ sparsification: coefficients below `thresh` are zeroed; `alpha` is its ridge penalty.
thresh = 0.05
alpha = 0.01
sparsifier = pySindySparsifier(STLSQ(threshold=thresh, alpha=alpha))

optimizer = AlternatingActiveSetLMSolver(
    beta_reg=1e-1,
    solver_settings=optsettings,
    fixed_colloc_weight=colloc_weight,
    fixed_data_weight=data_weight,
    sparsifier=sparsifier,
)

model = JSINDyModel(
    trajectory_model=trajectory_model,
    dynamics_model=dynamics_model,
    optimizer=optimizer,
    feature_names=['x', 'y', 'z'],
)

In [None]:
# Fit on the gapped training data, then print the learned equations followed by the
# ground-truth ones for comparison.
model.fit(t_train, x_train, t_colloc)

for heading, equations in (("Learned Model", model), ("\nTrue model", expdata)):
    print(heading)
    equations.print()

In [None]:
# Re-display learned then true equations, separated by a single blank line.
for position, printable in enumerate((model, expdata)):
    if position:
        print()
    printable.print()

In [None]:
# Evaluate the fitted model's state estimate on the ground-truth time grid t_true,
# so it can be plotted against x_true below.
x_pred= model.predict_state(t_true)

In [None]:
# x_pred was evaluated on t_true, so the plot below needs one predicted state per time.
assert x_pred.shape[0] == len(t_true), (
    f"expected {len(t_true)} predicted states, got shape {x_pred.shape}"
)
x_pred.shape

In [None]:
fig, axs = plt.subplots(3, 1, figsize=(12, 5), sharex=True)

y_labels = ["x", "y", "z"]
for idx, ax in enumerate(axs):
    ax.plot(t_true, x_true[:, idx], c='black', label="true")
    ax.scatter(t_train, x_train[:, idx], zorder=2, facecolors='red', edgecolors='black',
               marker='.', s=150, lw=2, label="training data")
    # x_pred is the output of model.predict_state on t_true.
    ax.plot(t_true, x_pred[:, idx], label="trajectory model")
    ax.set_ylabel(y_labels[idx])
# Axes share x, so only the bottom panel needs an x label; one legend suffices.
axs[-1].set_xlabel("t")
axs[0].legend()

plt.suptitle("Rossler System")
plt.tight_layout()

## Simulating the learned model across the missing middle segment

In [None]:
import diffrax

# Compile the learned model's prediction once so repeated calls from the ODE solver are cheap.
jit_ross_pred = jax.jit(model.predict)


def model_ross_system(t, x, args):
    """diffrax vector field wrapping the learned model.

    The time ``t`` and extra ``args`` are ignored (the learned system is treated as
    autonomous); the state ``x`` is passed straight to the jitted ``model.predict``.
    """
    del t, args
    return jit_ross_pred(x)

def simulate_sol(y0, system, t0=expdata.t0, t1=expdata.t1, dt=expdata.dt,
                 t_eval=expdata.t_true, args=None):
    """Integrate ``system`` from ``t0`` to ``t1`` with diffrax's Tsit5 and sample it at ``t_eval``.

    The default arguments are read from ``expdata`` once, when this cell runs.

    Parameters
    ----------
    y0 : initial state at ``t0``.
    system : vector field with the diffrax signature ``system(t, y, args)``.
    t0, t1 : integration interval.
    dt : initial step ``dt0``. No ``stepsize_controller`` is passed, so diffrax steps
        with a constant size of ``dt``.
    t_eval : times at which the dense solution is evaluated (presumably within [t0, t1]).
    args : forwarded unchanged to ``system``.

    Returns
    -------
    ``sol.evaluate(t)`` stacked over ``t_eval`` (leading axis indexes ``t_eval``).
    """
    term = diffrax.ODETerm(system)
    solver = diffrax.Tsit5()

    # Dense output so the solution can be evaluated at arbitrary times afterwards.
    save_at = diffrax.SaveAt(dense=True)

    # Budget 10x the number of fixed-size steps the interval needs. The max(1, ...) guard
    # stops very short intervals (t1 - t0 < dt / 10) from truncating to max_steps=0,
    # which would leave no room for even a single solver step.
    max_steps = max(1, int(10 * (t1 - t0) / dt))

    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0=t0,
        t1=t1,
        dt0=dt,
        y0=y0,
        args=args,
        saveat=save_at,
        max_steps=max_steps,
    )

    return jax.vmap(sol.evaluate)(t_eval)

In [None]:
# Dense evaluation grid spanning the gap: from the last retained sample before the
# removed middle segment (index cutoff-1) to the first one after it (index cutoff).
n_gap_eval = 101
t_eval = jnp.linspace(t_train[cutoff - 1], t_train[cutoff], n_gap_eval)

In [None]:
# Integrate the learned model across the gap, starting from the last observed state before it.
out = simulate_sol(
    x_train[cutoff - 1],
    system=model_ross_system,
    t0=t_train[cutoff - 1],
    t1=t_train[cutoff],
    t_eval=t_eval,
)

In [None]:
# simulate_sol vmaps sol.evaluate over t_eval, so there is one state per evaluation time.
assert out.shape[0] == len(t_eval), (
    f"expected {len(t_eval)} simulated states, got shape {out.shape}"
)
out.shape

In [None]:
fig, axs = plt.subplots(3, 1, figsize=(12, 5), sharex=True)

y_labels = ["x", "y", "z"]
for idx, ax in enumerate(axs):
    ax.plot(t_true, x_true[:, idx], c='black', label="true")
    ax.scatter(t_train, x_train[:, idx], zorder=2, facecolors='red', edgecolors='black',
               marker='.', s=150, lw=2, label="training data")
    # x_pred comes from model.predict_state; it is NOT the integrated dynamics.
    ax.plot(t_true, x_pred[:, idx], label="trajectory model")
    # `out` is the learned dynamics integrated across the gap by simulate_sol.
    ax.plot(t_eval, out[:, idx], label="simulated")
    ax.set_ylabel(y_labels[idx])
# Axes share x, so only the bottom panel needs an x label; one legend suffices.
axs[-1].set_xlabel("t")
axs[0].legend()

plt.suptitle("Rossler System")
plt.tight_layout()