In [13]:
from reservoirtaming.models.generic import GenericEchoState
from reservoirtaming.layers.reservoirs import RandomReservoir
from reservoirtaming.layers.activation import leaky_erf

from reservoirtaming.training.training import train
from reservoirtaming.data.KS import KS
from jax import random


from jax.scipy.linalg import cho_factor, cho_solve
from jax.lax import scan
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from flax.core import unfreeze
sns.set()

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Simulate the Kuramoto-Sivashinsky (KS) trajectory used for training and testing
# (setup follows Jonathan's experiments).
L = 22 / (2 * np.pi)  # size parameter for KS; presumably the periodic domain length is 2*pi*L = 22 — TODO confirm in KS
N = 100  # number of spatial discretization points
dt = 0.25  # time step
N_train = 10000  # training samples (time steps)
N_test = 2000  # test samples (time steps)
N_init = 1000  # initial transient steps to discard
N_total = N_init + N_train + N_test

# Every N_* above is a step count, so the whole total must be converted to
# simulated time with dt (N_init is not a time value on its own).
tend = N_total * dt

np.random.seed(1)  # presumably seeds the KS initial condition — TODO confirm KS draws from np.random
dns = KS(L=L, N=N, dt=dt, tend=tend)
dns.simulate()

# The next cell slices along time without checking length; fail loudly here instead.
assert dns.uu.shape[0] >= N_total, (dns.uu.shape, N_total)

In [3]:
# Split the scaled trajectory along time (axis 0) into
# [transient | train | test | remainder].
# Dividing by sqrt(N) makes each snapshot's squared 2-norm equal to the mean of u^2
# over the grid points, so its magnitude does not grow with resolution — presumably
# the motivation for the scaling (TODO: confirm against the reference setup).
_, u_train, u_test, _ = np.split(
    dns.uu / np.sqrt(N),
    [N_init, N_init + N_train, N_init + N_train + N_test],
    axis=0,
)
# np.split silently returns shorter pieces when the trajectory is too short.
assert u_train.shape[0] == N_train, u_train.shape
assert u_test.shape[0] == N_test, u_test.shape

n_input = u_train.shape[1]

In [4]:
# Reservoir width chosen so that reservoir units + input features = 3996 + 100 = 4096
# readout features (cf. the X.shape output further down).
n_reservoir = 3996
model = GenericEchoState(
    RandomReservoir,
    leaky_erf,
    n_reservoir=n_reservoir,
    # Positional hyperparameters forwarded to RandomReservoir — TODO name/document them there.
    reservoir_args=(0.4, 0.9, 0.1),
    # Presumably forwarded to leaky_erf — TODO confirm meaning.
    act_fn_args=(1.0,),
)
key = random.PRNGKey(42)

In [5]:
# Initialise on one input snapshot, then split the variable collections:
# 'params' goes to `params`, everything else stays in `state`.
state, params = model.init(key, u_train[0]).pop('params')

In [6]:
%%time
state, reservoir_states = train(model, state, params, u_train)

CPU times: user 1.87 s, sys: 65.5 ms, total: 1.93 s
Wall time: 1.95 s


In [9]:
%%time
X = jnp.concatenate([reservoir_states[:-1], u_train[:-1]], axis=1)
y = u_train[1:]

CPU times: user 0 ns, sys: 19.2 ms, total: 19.2 ms
Wall time: 20.6 ms


In [8]:
# Expected: (N_train - 1, reservoir units + input features) = (9999, 4096)
jnp.shape(X)

(9999, 4096)

In [11]:
alpha = 0.01  # ridge (Tikhonov) regularisation strength added to the Gram diagonal

In [16]:
%%time
c, low = cho_factor(jnp.dot(X.T, X) + alpha * jnp.eye(X.shape[1]))

CPU times: user 2.55 ms, sys: 0 ns, total: 2.55 ms
Wall time: 2.19 ms


In [21]:
%%time
X = jnp.concatenate([reservoir_states[:-1], u_train[:-1]], axis=1)
y = u_train[1:]

c, low = cho_factor(jnp.dot(X.T, X) + alpha * jnp.eye(X.shape[1]))
W_out = cho_solve((c, low), jnp.dot(X.T, y))

CPU times: user 4.97 ms, sys: 139 µs, total: 5.11 ms
Wall time: 3.97 ms


In [None]:
def fit_readout(reservoir_states, inputs, alpha: float = 1e-2):
    """Fit a ridge-regression readout mapping [r_t, u_t] -> u_{t+1}.

    Args:
        reservoir_states: 2-D array of reservoir states, one row per time step.
        inputs: 2-D array of inputs with the same number of rows as ``reservoir_states``.
        alpha: ridge regularisation strength added to the Gram-matrix diagonal.

    Returns:
        W_out of shape ``(n_state_features + n_input_features, n_input_features)``,
        solving ``(X^T X + alpha*I) W_out = X^T y`` with
        ``X = [reservoir_states[:-1], inputs[:-1]]`` and ``y = inputs[1:]``.
    """
    # Pair the features at step t with the input at step t+1.
    X = jnp.concatenate([reservoir_states[:-1], inputs[:-1]], axis=1)
    y = inputs[1:]
    # X^T X + alpha*I is positive definite for alpha > 0, so Cholesky applies.
    gram = jnp.dot(X.T, X) + alpha * jnp.eye(X.shape[1], dtype=X.dtype)
    c, low = cho_factor(gram)
    return cho_solve((c, low), jnp.dot(X.T, y))

In [None]:
%%time


X = output[:-1]
y = u_train[1:]

# we use cholesky decomp to solve the problem; fast, efficient and stable
# This seems much slower than what jonathan reports in the paper; we'll have to compare implementations
# It is *much* faster the second time we run it though; maybe an issue with jit?
c, low = cho_factor(jnp.dot(X.T, X) + alpha * jnp.eye(X.shape[1]))
W_out = cho_solve((c, low), jnp.dot(X.T, y))