In [1]:
from reservoirtaming.models.generic import GenericEchoState
from reservoirtaming.layers.reservoirs import RandomReservoir
from reservoirtaming.layers.activation import leaky_erf

from reservoirtaming.training.training import train
from reservoirtaming.data.KS import KS
from jax import random
from jax.lax import scan
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from flax.core import unfreeze
sns.set()

%load_ext autoreload
%autoreload 2

In [2]:
# Set up the KS dataset (presumably Kuramoto-Sivashinsky); setup mirrors Jonathan's.
L = 22 / (2 * np.pi)  # domain-length parameter; presumably KS uses a 2*pi*L = 22 domain - TODO confirm
N = 100  # spatial resolution (presumably the number of grid points)
dt = 0.25  # time step of the simulation
N_train = 10000  # number of training samples (time steps)
N_test = 2000  # number of test samples (time steps)
N_init = 1000  # number of initial (transient) samples discarded before training

# N_init, N_train and N_test are all sample counts (they index dns.uu row-wise
# when splitting), so the total simulated time is their sum times dt.
# Adding N_init unscaled would simulate N_init * (1/dt - 1) extra steps that
# are only discarded afterwards.
tend = (N_init + N_train + N_test) * dt

np.random.seed(1)  # fix NumPy's global RNG before simulating, for reproducibility
dns = KS(L=L, N=N, dt=dt, tend=tend)
dns.simulate()

# Fail loudly if the solver produced fewer samples than the split needs.
assert dns.uu.shape[0] >= N_init + N_train + N_test, dns.uu.shape

In [3]:
# Prepare the train and test matrices.
# TODO: document why the data is scaled by 1/sqrt(N).
# Split along the time axis (axis 0): the first N_init samples are discarded
# as transient, followed by N_train training and N_test test samples; anything
# after that is discarded as well.
_, u_train, u_test, _ = np.split(
    dns.uu / np.sqrt(N),
    [N_init, N_init + N_train, N_init + N_train + N_test],
    axis=0,
)

# np.split silently returns short (or empty) pieces when the input has fewer
# rows than the split points assume, so check the lengths explicitly.
assert u_train.shape[0] == N_train, u_train.shape
assert u_test.shape[0] == N_test, u_test.shape

n_input = u_train.shape[1]  # input dimension fed to the reservoir

In [4]:
# Fixed PRNG key so that model initialisation (model.init) is reproducible.
key = random.PRNGKey(42)

# Echo-state model built from a RandomReservoir with 3996 units and a
# leaky_erf activation.
# TODO: document what each entry of reservoir_args and act_fn_args controls
# (presumably positional arguments forwarded to RandomReservoir / leaky_erf).
model = GenericEchoState(
    RandomReservoir,
    leaky_erf,
    n_reservoir=3996,
    reservoir_args=(0.4, 0.9, 0.1),
    act_fn_args=(1.0,),
)

In [5]:
# Initialise the model from a single input sample, then separate the 'params'
# collection from the remaining variable collections (kept as `state`).
# NOTE(review): the tuple-returning `.pop` is FrozenDict behaviour; newer flax
# versions may return a plain dict from `init`, in which case use
# `flax.core.pop(variables, 'params')` instead - confirm against installed flax.
init_variables = model.init(key, u_train[0])
state, params = init_variables.pop('params')
del init_variables  # the combined collection is not needed once split

In [7]:
%%time
state, reservoir_states = train(model, state, params, u_train)

CPU times: user 1.87 s, sys: 164 µs, total: 1.87 s
Wall time: 1.82 s
