This notebook shows how to run and compare different reservoirs.

**Timings are currently off due to not having jitted everything yet.**

In [10]:
from jacho.layers.reservoirs import RandomReservoir, StructuredTransform, FastStructuredTransform, SparseReservoir
from jacho.models.generic import GenericEchoState
from jacho.layers.output import Residual
from jacho.training.training import ridge
from jacho.data.KS import KS

from jax import random
import numpy as np
import jax.numpy as jnp
from jax import jit
from flax import linen as nn
from jax import vmap
from functools import partial

import matplotlib.pyplot as plt

# Root JAX PRNG key (fixed seed); later cells pass it to the state and parameter initialisers.
key = random.PRNGKey(42)

# IPython autoreload in mode 2: re-import edited modules before each cell is executed.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Making data

In [2]:
# Setting up our dataset; similar to Jonathan's
L = 22 / (2 * np.pi)  # length
N = 100  # number of spatial grid points
dt = 0.25  # time discretization step
N_train = 10000  # number of training time steps
N_test = 2000  # number of test time steps
N_init = 1000  # number of initial time steps to discard

# All three counts are in time steps, so sum them first and only then convert
# to simulation time. (Adding N_init directly to a time mixes units and
# simulates far more steps than the split ever uses.)
n_steps = N_init + N_train + N_test
tend = n_steps * dt

np.random.seed(1)
dns = KS(L=L, N=N, dt=dt, tend=tend)
dns.simulate()

# The train/test split indexes the first axis of dns.uu up to n_steps;
# fail loudly here rather than silently ending up with a short test set.
assert dns.uu.shape[0] >= n_steps, (
    f"simulation produced {dns.uu.shape[0]} time steps, need at least {n_steps}"
)

In [3]:
# Build the train and test arrays.
# Inputs must be laid out as [time_steps, samples, spatial_points];
# for the training set here that is [10000, 1, 100].
u = np.expand_dims(dns.uu, axis=1)

# Cut points along the time axis: discarded start | train | test | leftover.
split_at = [N_init, N_init + N_train, N_init + N_train + N_test]
u_scaled = u / np.sqrt(N)
_, u_train, u_test, _ = np.split(u_scaled, split_at, axis=0)

# Random reservoir

First one is a fully connected random reservoir:

In [20]:
# Configuration of the random-reservoir model
n_models = 10  # passed as parallel_reservoirs when the state is initialised
n_reservoir = 100

# Reservoir arguments, ordered as (input_scale, reservoir_scale, bias_scale).
input_scale, reservoir_scale, bias_scale = 0.4, 0.9, 0.4
reservoir_args = (input_scale, reservoir_scale, bias_scale)

# Output size follows the last axis of the training data.
n_out = u_train.shape[-1]
norm_factor = 1.1 * jnp.sqrt(n_out / n_reservoir)
output_layer_args = (norm_factor,)

model = GenericEchoState(
    n_reservoir,
    RandomReservoir,
    reservoir_args,
    n_out,
    Residual,
    output_layer_args,
)

Simply set parallel_reservoirs to the number you want:

In [21]:
# One initial state per parallel reservoir.
state = model.initialize_state(
    key,
    n_reservoir,
    parallel_reservoirs=n_models,
)

First axis is reservoir axis:

In [22]:
state.shape

(10, 1, 100)

To run the models in parallel, we initialize the model using vmap:

In [23]:
# Map model.init over the leading axis of `state` only; the key and the sample
# input are shared, and out_axes=None returns a single un-batched parameter set.
init_shared = vmap(model.init, in_axes=(None, 0, None), out_axes=None)
params = init_shared(key, state, u_train[0])

Note that all the reservoirs use the same kernel:

In [24]:
params["params"]["reservoir"]["Dense_0"]["kernel"].shape

(100, 100)

If you want each model to have its own weights, so that you truly train multiple models, also parallelize over the keys (one key per model) and remove `out_axes=None`:

In [26]:
# One PRNG key per model, mapped together with the per-model state, so every
# model is initialised from its own key; the sample input stays shared.
model_keys = random.split(key, n_models)
init_per_model = vmap(model.init, in_axes=(0, 0, None))
params = init_per_model(model_keys, state, u_train[0])

In [27]:
params["params"]["reservoir"]["Dense_0"]["kernel"].shape

(10, 100, 100)

As you can see, the kernels are all different:

In [28]:
params["params"]["reservoir"]["Dense_0"]["kernel"][:, 0, 0]

DeviceArray([-0.53143334,  0.45951787, -0.17088753, -0.2543605 ,
             -0.66390693,  0.5153455 , -0.1227669 , -0.15841977,
              0.2611812 , -0.02392638], dtype=float32)

To run the reservoir we again use vmap. For example, to run multiple reservoirs on the same data:

In [29]:
# Bind the run_reservoir method once, then map it over the leading axis of the
# parameters and the state while every model reads the same u_train.
run_reservoir = partial(model.apply, method=model.run_reservoir)
run_all_models = vmap(run_reservoir, in_axes=(0, 0, None))
_, intermediate_states = run_all_models(params, state, u_train)

In [30]:
intermediate_states.shape

(10, 10000, 1, 100)

We can also run different data through the same reservoir. Let's first duplicate the data:

In [36]:
# `params` was last assigned the per-model weights (leading axis of size
# n_models). This example shares one reservoir across all data copies, so
# rebuild the single, un-batched parameter set first.
params = vmap(model.init, in_axes=(None, 0, None), out_axes=None)(key, state, u_train[0])

# Stack n_models copies of the training data along a new leading axis.
u_train_multiple_data = jnp.tile(u_train, (n_models, 1, 1, 1))

# Keep the shape as the final expression so the notebook actually displays it.
u_train_multiple_data.shape

(10, 10000, 1, 100)

Now simply vmap over the leading axes of the state and the data, keeping the parameters shared:

In [43]:
# Parameters are shared (None); state and data are mapped over their leading axis.
run_reservoir = partial(model.apply, method=model.run_reservoir)
run_all_data = vmap(run_reservoir, in_axes=(None, 0, 0))
_, intermediate_states = run_all_data(params, state, u_train_multiple_data)