In [1]:
from reservoirtaming.layers.reservoirs import StructuredTransform, FastStructuredTransform
from reservoirtaming.models.generic import GenericEchoState
from reservoirtaming.layers.output import Residual
from reservoirtaming.training.training import ridge


from reservoirtaming.data.KS import KS
from jax import random
from functools import partial
import numpy as np

import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve
from flax.core import unfreeze, freeze
import matplotlib.pyplot as plt
%load_ext autoreload
%autoreload 2

In [2]:
# Dataset configuration for the KS simulation (similar to Jonathan's setup).
L = 22 / (2 * np.pi)  # length parameter handed to KS
N = 100  # spatial discretization parameter handed to KS (presumably the number of grid points -- confirm against KS)
dt = 0.25  # time discretization step
N_train = 10000  # number of time steps kept for training
N_test = 2000  # number of time steps kept for testing
N_init = 1000  # number of initial time steps to discard
# tend is a time, while N_init / N_train / N_test are step counts (they are
# used as row indices when the trajectory is split), so every count --
# including N_init -- is converted to time with dt.
# NOTE(review): assumes KS returns at least tend / dt rows in dns.uu -- confirm.
tend = (N_init + N_train + N_test) * dt

# Seed NumPy's global RNG before building the simulation so that any
# np.random use inside KS is reproducible.
np.random.seed(1)
dns = KS(L=L, N=N, dt=dt, tend=tend)
dns.simulate()

In [3]:
# Build the train and test matrices from the simulated trajectory.
# A singleton axis is inserted at position 1 and the values are scaled by 1 / sqrt(N).
u = np.expand_dims(dns.uu, axis=1)
u_scaled = u / np.sqrt(N)

# Row indices that separate: discarded warm-up | train | test | discarded remainder.
split_points = [N_init, N_init + N_train, N_init + N_train + N_test]
_, u_train, u_test, _ = np.split(u_scaled, split_points, axis=0)

In [4]:
# Setting up the models.
n_reservoir = 3996
n_out = u_train.shape[-1]

norm_factor = 1.1 * jnp.sqrt(n_out / n_reservoir)

# Two echo-state models that differ only in the reservoir layer class, so the
# fast and the reference structured transform can be compared directly.
model_fast = GenericEchoState(n_reservoir, FastStructuredTransform, (n_out, ),
                  n_out, Residual, (norm_factor, ))

model_slow = GenericEchoState(n_reservoir, StructuredTransform, (n_out, ),
                  n_out, Residual, (norm_factor, ))

# A JAX PRNG key should be consumed only once; derive independent subkeys for
# the state and the parameters instead of passing the same key to both calls.
key = random.PRNGKey(42)
state_key, params_key = random.split(key)
state = model_fast.initialize_state(state_key, n_reservoir)
# params is initialised through model_fast; later cells apply the same params
# to model_slow as well.
params = model_fast.init(params_key, state, u_train[0])

In [5]:
%%time
# Run the fast model's reservoir over the training sequence, starting from
# `state`, and keep both the returned state and the intermediate states.
new_state, intermediate_states_fast = model_fast.apply(
    params,
    state,
    u_train,
    method=model_fast.run_reservoir,
)

CPU times: user 2.41 s, sys: 647 ms, total: 3.05 s
Wall time: 1.78 s


In [6]:
%%time
# Run the slow (reference) model's reservoir over the same training sequence,
# with the same params and initial state as the fast run.
new_state, intermediate_states_slow = model_slow.apply(
    params,
    state,
    u_train,
    method=model_slow.run_reservoir,
)

CPU times: user 9.98 s, sys: 520 ms, total: 10.5 s
Wall time: 11.8 s


In [7]:
# Check that the fast and slow reservoirs produced matching intermediate
# states (within jnp.allclose's default tolerances).
jnp.allclose(
    intermediate_states_fast,
    intermediate_states_slow,
)

DeviceArray(True, dtype=bool)