In [1]:
import numpy as np
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from scipy.sparse.linalg import eigs as sparse_eigs

import sys
import os
sys.path.append('..')
from adjoint_esn.esn import ESN
from adjoint_esn.jax_esn import JAXESN

In [2]:
# Input normalization shared by both ESN constructors below: 2 rows x 3 columns
# (presumably one column per input dimension, dimension=3). NOTE(review): the
# meaning of each row is defined inside ESN/JAXESN — confirm there.
scale = np.array([
    [2.0, 3.0, 4.0],
    [8.0, 5.0, 3.0],
])
input_bias = 0.5  # bias value passed as input_bias to both ESN constructors

In [3]:
# NumPy / scipy-sparse ESN. Hyperparameters and seeds are fixed so that the
# JAX ESN built with the same settings can be compared like-for-like.
my_ESN = ESN(
    reservoir_size=1000,
    dimension=3,
    parameter_dimension=1,
    reservoir_connectivity=3,
    input_scaling=4.00943288,
    spectral_radius=0.13390513,
    leak_factor=1.0,
    input_bias=input_bias,
    input_normalization=scale,
    # Both entries are 1-D arrays of length parameter_dimension (previously the
    # second one was a 0-d array), consistent with the (2, 1) normalization
    # array given to JAXESN.
    parameter_normalization=[np.array([5.17175175]), np.array([5.51759002])],
    input_seeds=[0, 1, 2],
    reservoir_seeds=[3, 4],
)

Input normalization is changed, training must be done again.
Parameter normalization is changed, training must be done again.
Input scaling is set to 1, set it separately if necessary.
Input weights are rescaled with the new input scaling.
Spectral radius is set to 1, set it separately if necessary.
Reservoir weights are rescaled with the new spectral radius.


In [4]:
# JAX ESN with the same hyperparameters and seeds as the NumPy ESN;
# here the parameter normalization is a single (2, 1) jnp array.
my_JAXESN = JAXESN(
    reservoir_size=1000,
    dimension=3,
    parameter_dimension=1,
    reservoir_connectivity=3,
    input_scaling=4.00943288,
    spectral_radius=0.13390513,
    leak_factor=1.0,
    input_bias=input_bias,
    input_normalization=scale,
    parameter_normalization=jnp.array([[5.17175175], [5.51759002]]),
    input_seeds=[0, 1, 2],
    reservoir_seeds=[3, 4],
)

Input normalization is changed, training must be done again.
Parameter normalization is changed, training must be done again.
Input scaling is set to 1, set it separately if necessary.
Input weight generation 0.2573356628417969
Input weights are rescaled with the new input scaling.
Spectral radius is set to 1, set it separately if necessary.
Reservoir weight generation 0.3530547618865967
Reservoir weights are rescaled with the new spectral radius.


In [5]:
x_prev = np.ones(1000)
u = 0.1*np.ones(3)
p = 3.0
print("ESN sparse scipy step")
%timeit my_ESN.step(x_prev, u, p)
# regular np dot is much slower

x_prev = jnp.ones(1000)
u = 0.1*jnp.ones(3)
p = 3.0
print("JAX ESN step")
%timeit my_JAXESN.step(x_prev, u, p).block_until_ready()


x_prev = jnp.ones(1000)
u = 0.1*jnp.ones(3)
p = 3.0
step_jit = jit(my_JAXESN.step)
step_jit(x_prev, u, p)
print("Jitted JAX ESN step")
%timeit step_jit(x_prev,u,p).block_until_ready()


ESN sparse scipy step
67 µs ± 3.83 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
JAX ESN step
45.2 ms ± 4.45 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Jitted JAX ESN step
487 µs ± 39.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [7]:
x0 = np.ones(1000)
U = 0.1*np.ones((100,3))
P = 3.0*np.ones((100,1))
print("ESN sparse scipy open-loop")
%timeit my_ESN.open_loop(x0, U, P)

x0 = jnp.ones(1000)
U = 0.1*jnp.ones((100,3))
P = 3.0*jnp.ones((100,1))
print("Jitted JAX ESN open-loop")
my_JAXESN.open_loop(x0, U, P)
%timeit my_JAXESN.open_loop(x0, U, P).block_until_ready()

ESN sparse scipy open-loop
5.93 ms ± 295 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Jitted JAX ESN open-loop
381 ms ± 12.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# @TODO: jnp.einsum, oe.contract instead of double dot
def _objective_grad_wrt_state(x_row, b_out, W_out, W_out_res_T, weight):
    """dJ/dr for one time step: weight * (([x_row, b_out] @ W_out) @ W_out[:N_reservoir].T)."""
    x_aug = jnp.hstack((x_row, b_out))
    return weight * jnp.dot(jnp.dot(x_aug, W_out), W_out_res_T).T


def run_esn_grad_adj(my_ESN, X_pred_grad, N_t):
    """Gradient of a time-averaged ESN objective w.r.t. the parameters (adjoint method).

    Backward sweep over time steps i = N_t-1, ..., 1:
        dJ/dp += drdp(r_i)^T v_i
        v_{i-1} = jac(r_i)^T v_i + dJ/dr_{i-1}
    starting from v_{N_t-1} = dJ/dr_{N_t-1}. The per-step state derivative is
        dJ/dr_i = 1 / (2 N_t) * (([r_i, b_out] @ W_out) @ W_out[:N_reservoir].T).

    Args:
        my_ESN: ESN-like object providing ``W_out``, ``b_out``, ``N_reservoir``,
            ``N_param_dim`` and the methods ``drdp(r)`` and ``jac(r)``.
        X_pred_grad: reservoir states indexed as ``X_pred_grad[i, :]`` for time
            step i (presumably shape (>= N_t, N_reservoir) — TODO confirm).
        N_t: number of time steps in the time average; the sweep starts at row
            N_t - 1.

    Returns:
        Array of length ``my_ESN.N_param_dim`` holding dJ/dp
        (all zeros when N_t == 1, since there is no backward step).
    """
    # Loop-invariant quantities, hoisted out of the backward sweep.
    W_out = my_ESN.W_out
    W_out_res_T = my_ESN.W_out[: my_ESN.N_reservoir, :].T
    b_out = my_ESN.b_out
    # Same value as the original (1 / N_t) * 1 / 2 (halving is exact in floating point).
    weight = 1.0 / (2 * N_t)

    # Terminal condition of the adjoint variable.
    v = _objective_grad_wrt_state(X_pred_grad[N_t - 1, :], b_out, W_out, W_out_res_T, weight)
    dJ_dp_adj = jnp.zeros(my_ESN.N_param_dim)
    # Plain Python ints as indices: iterating a jnp.arange would create one
    # device scalar per step for no benefit.
    for i in range(N_t - 1, 0, -1):
        dJ_dp_adj += jnp.dot(my_ESN.drdp(X_pred_grad[i, :]).T, v)
        if i > 1:
            # v_{i-1} is only consumed by a later accumulation; at i == 1 it
            # would be computed and discarded, so that Jacobian product is skipped.
            dJ_dr = _objective_grad_wrt_state(
                X_pred_grad[i - 1, :], b_out, W_out, W_out_res_T, weight
            )
            v = jnp.dot(my_ESN.jac(X_pred_grad[i, :]).T, v) + dJ_dr
    return dJ_dp_adj