In [2]:
import numpy as np

import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
import optax


import pickle as pkl

import itertools
import more_itertools as mit

import os

from alive_progress import alive_bar
import gc

parallel_scan = jax.lax.associative_scan

In [3]:
def binary_operator_diag(element_i, element_j):
    """Combine two steps of a diagonal linear recurrence into one.

    Each element is a pair ``(a, bu)``.  Reading the pair as the affine map
    ``x -> a * x + bu``, the returned pair is the map obtained by applying
    ``element_i`` first and ``element_j`` second.

    Args:
        element_i: ``(a_i, bu_i)`` pair for the earlier step.
        element_j: ``(a_j, bu_j)`` pair for the later step.

    Returns:
        The pair ``(a_j * a_i, a_j * bu_i + bu_j)``.
    """
    earlier_gain, earlier_drive = element_i
    later_gain, later_drive = element_j

    merged_gain = later_gain * earlier_gain
    # The earlier drive is scaled by the later gain before the later drive is added.
    merged_drive = later_gain * earlier_drive + later_drive
    return merged_gain, merged_drive


def init_lru_parameters(N, H, r_min = 0.0, r_max = 1, max_phase = 6.28, rng = None):
    """Draw a random initial parameter set for one LRU layer.

    The diagonal recurrence eigenvalues ``Lambda`` are complex, with squared
    magnitude uniform between ``r_min**2`` and ``r_max**2`` and phase uniform
    in ``[0, max_phase]``.  They are stored through the logs ``nu_log`` and
    ``theta_log`` so that ``Lambda = exp(-exp(nu_log) + 1j*exp(theta_log))``.

    Args:
        N: state dimension (number of eigenvalues).
        H: model dimension (width of each input/output vector).
        r_min: inner radius of the ring the eigenvalues are drawn from.
        r_max: outer radius of the ring.  The formula takes
            ``log(-0.5*log(.))`` of a value between ``r_min**2`` and
            ``r_max**2``, so radii outside ``0 <= r_min <= r_max <= 1``
            produce NaN entries.
        max_phase: upper bound on the eigenvalue phase, in radians.
        rng: optional random source exposing ``uniform(size=...)`` and
            ``normal(size=...)`` (for example ``np.random.default_rng(seed)``).
            When ``None`` the global ``np.random`` state is used, which is
            the historical behaviour.

    Returns:
        The tuple ``(nu_log, theta_log, B_re, B_im, C_re, C_im, D, gamma_log)``
        with shapes ``(N,), (N,), (N, H), (N, H), (H, N), (H, N), (H,), (N,)``.
    """
    # Fall back to NumPy's global RNG so existing callers (and np.random.seed)
    # see exactly the same stream of draws as before.
    sampler = np.random if rng is None else rng

    # Draw order (u1, u2, B_re, B_im, C_re, C_im, D) is fixed so that a seeded
    # source always yields the same parameters.
    u1 = sampler.uniform(size=(N,))
    u2 = sampler.uniform(size=(N,))

    # Squared radius is uniform in [r_min**2, r_max**2]; the phase is max_phase*u2.
    nu_log = np.log(-0.5*np.log(u1*(r_max**2-r_min**2) + r_min**2))
    theta_log = np.log(max_phase*u2)

    # Glorot-style scaled input/output projection matrices.
    B_re = sampler.normal(size=(N,H))/np.sqrt(2*H)
    B_im = sampler.normal(size=(N,H))/np.sqrt(2*H)
    C_re = sampler.normal(size=(H,N))/np.sqrt(N)
    C_im = sampler.normal(size=(H,N))/np.sqrt(N)
    D = sampler.normal(size=(H,))

    # Normalization factor gamma = sqrt(1 - |Lambda|^2), stored as its log.
    diag_lambda = np.exp(-np.exp(nu_log) + 1j*np.exp(theta_log))
    gamma_log = np.log(np.sqrt(1-np.abs(diag_lambda)**2))

    return nu_log, theta_log, B_re, B_im, C_re, C_im, D, gamma_log


def forward_LRU(lru_parameters, input_sequence):
    """Apply the LRU layer to a whole sequence with a parallel scan.

    Args:
        lru_parameters: the 8-tuple ``(nu_log, theta_log, B_re, B_im, C_re,
            C_im, D, gamma_log)``.
        input_sequence: array whose leading axis is the time axis; every
            time step ``u`` is multiplied by the input matrix ``B``.
            (Presumably shape ``(T, H)`` -- confirm against the caller.)

    Returns:
        One output per time step, ``real(C @ x_k) + D * u_k``, where ``x_k``
        is the hidden state after step ``k``.
    """
    nu_log, theta_log, B_re, B_im, C_re, C_im, D, gamma_log = lru_parameters

    # Rebuild the complex quantities from their real-valued parametrisation.
    log_modulus = -jnp.exp(nu_log)
    phase = jnp.exp(theta_log)
    Lambda = jnp.exp(log_modulus + 1j*phase)
    gamma = jnp.expand_dims(jnp.exp(gamma_log), axis=-1)
    B_norm = (B_re + 1j*B_im) * gamma
    C = C_re + 1j*C_im

    # One copy of the eigenvalues per time step, as the scan expects.
    num_steps = input_sequence.shape[0]
    Lambda_elements = jnp.repeat(Lambda[jnp.newaxis], num_steps, axis=0)

    def project_input(u):
        return B_norm @ u

    Bu_elements = jax.vmap(project_input)(input_sequence)

    # Second component of the scan result holds every hidden state x_k.
    _, inner_states = parallel_scan(
        binary_operator_diag, (Lambda_elements, Bu_elements)
    )

    def read_out(x, u):
        return (C @ x).real + D * u

    return jax.vmap(read_out)(inner_states, input_sequence)

In [4]:
# Build a small LRU (state dimension 10, model dimension 10).
LRU = init_lru_parameters(10, 10)

# Random test input: 100 time steps of 10 features each.
# NOTE(review): no seed is set, so this cell gives different values on every run.
input_sequence = np.random.normal(size=(100, 10))

# Smoke test of the forward pass; the result is discarded (not the last
# expression in the cell, so nothing is displayed).
forward_LRU(LRU, input_sequence)

def loss_fn(lru_parameters, input_sequence, target_sequence):
    """Mean squared error between the LRU output and a target sequence.

    Args:
        lru_parameters: parameter tuple passed through to ``forward_LRU``.
        input_sequence: sequence fed to the model.
        target_sequence: values the model output is compared against.

    Returns:
        Scalar mean of the squared element-wise differences.
    """
    residual = forward_LRU(lru_parameters, input_sequence) - target_sequence
    squared_error = residual**2
    return jnp.mean(squared_error)


# Compiled gradient of the loss with respect to the parameter tuple (argument 0).
grad_fn = jit(grad(loss_fn, argnums=0))

# The input sequence is also used as the target, i.e. the model is asked to
# reproduce its input.
grads = grad_fn(LRU, input_sequence, input_sequence)


In [5]:
def print_shapes(tup):
    """Print the type and ``.shape`` of every element of ``tup``, one per line.

    Args:
        tup: iterable of objects that expose a ``shape`` attribute.

    Returns:
        None; the report is written to standard output.
    """
    for position, element in enumerate(tup):
        element_type = type(element)
        element_shape = element.shape
        print(f"Element {position} has type {element_type} and shape {element_shape}")

# The gradient tuple should mirror the parameter tuple element by element.
print_shapes(grads)

# Scale every leaf of the gradient tuple by 0.01; as the last expression in
# the cell, the full scaled arrays are displayed.
jax.tree.map(lambda x: x*0.01, grads)

Element 0 has type <class 'jaxlib.xla_extension.ArrayImpl'> and shape (10,)
Element 1 has type <class 'jaxlib.xla_extension.ArrayImpl'> and shape (10,)
Element 2 has type <class 'jaxlib.xla_extension.ArrayImpl'> and shape (10, 10)
Element 3 has type <class 'jaxlib.xla_extension.ArrayImpl'> and shape (10, 10)
Element 4 has type <class 'jaxlib.xla_extension.ArrayImpl'> and shape (10, 10)
Element 5 has type <class 'jaxlib.xla_extension.ArrayImpl'> and shape (10, 10)
Element 6 has type <class 'jaxlib.xla_extension.ArrayImpl'> and shape (10,)
Element 7 has type <class 'jaxlib.xla_extension.ArrayImpl'> and shape (10,)


(Array([-0.00096473, -0.00146302, -0.00045664, -0.00051238, -0.00127714,
        -0.0003096 , -0.00059101, -0.00053974, -0.00029481, -0.00014714],      dtype=float32),
 Array([ 1.1198167e-03, -4.3983446e-04,  3.5606287e-04,  5.0764397e-04,
        -1.9497373e-03, -9.0920884e-04, -4.0615429e-04,  7.5074313e-05,
         1.1198573e-03, -3.3929609e-03], dtype=float32),
 Array([[ 1.3677563e-03,  3.5716846e-04, -1.0316291e-05,  5.7374791e-04,
          3.2095436e-04, -1.0929062e-03, -3.6343175e-04,  3.1418170e-04,
          2.0988332e-03,  5.4513825e-06],
        [ 1.4669453e-04, -3.6194659e-04, -8.3397620e-04, -1.0086268e-03,
          1.3167992e-03, -5.4091995e-04, -9.7825693e-04,  5.5544800e-04,
         -1.8531587e-03,  2.3180561e-05],
        [-5.9466000e-04, -1.9458104e-03,  1.2403500e-04, -9.4228907e-04,
         -5.6273997e-04, -2.1868628e-04, -2.0675105e-03,  6.1665237e-04,
          1.3954288e-04, -9.8187440e-05],
        [-7.6583168e-04,  1.5413455e-03,  9.4369985e-04,  8.6728123

In [9]:
# Define the optax optimiser (Adam with learning rate 1e-3) and its state
# for the LRU parameter tuple.
# NOTE(review): `optimzer` is a misspelling of "optimizer"; the name is kept
# because the training loop further down refers to it.
optimzer = optax.adam(1e-3)
opt_state = optimzer.init(LRU)

# Single trial update step. The returned state goes into `new_opt_state`,
# which is not referenced again in the visible cells; `opt_state` itself is
# left at its initial value.
updates, new_opt_state = optimzer.update(grads, opt_state) 

# Apply the updates to the parameters
new_LRU = optax.apply_updates(LRU, updates)

In [None]:
# Check that the updated parameter tuple keeps the same element shapes.
print_shapes(new_LRU)



Element 0 has type <class 'jaxlib.xla_extension.ArrayImpl'> and shape (10,)
Element 1 has type <class 'jaxlib.xla_extension.ArrayImpl'> and shape (10,)
Element 2 has type <class 'jaxlib.xla_extension.ArrayImpl'> and shape (10, 10)
Element 3 has type <class 'jaxlib.xla_extension.ArrayImpl'> and shape (10, 10)
Element 4 has type <class 'jaxlib.xla_extension.ArrayImpl'> and shape (10, 10)
Element 5 has type <class 'jaxlib.xla_extension.ArrayImpl'> and shape (10, 10)
Element 6 has type <class 'jaxlib.xla_extension.ArrayImpl'> and shape (10,)
Element 7 has type <class 'jaxlib.xla_extension.ArrayImpl'> and shape (10,)


In [11]:
# Setup the training loop: 1000 Adam steps with the input sequence as both
# input and target.
# NOTE(review): this cell overwrites the globals LRU and opt_state, so
# re-running it continues training instead of starting over.
for i in range(1000):
    grads = grad_fn(LRU, input_sequence, input_sequence)
    updates, opt_state = optimzer.update(grads, opt_state)
    new_LRU = optax.apply_updates(LRU, updates)
    LRU = new_LRU

    # Report every 100 iterations. The loss is evaluated on the parameters
    # after this iteration's update, through the non-jitted loss_fn.
    if i % 100 == 0:
        print(f"Loss at iteration {i} is {loss_fn(LRU, input_sequence, input_sequence)}")

Loss at iteration 0 is 2.6947569847106934
Loss at iteration 100 is 1.0703318119049072
Loss at iteration 200 is 0.41919204592704773
Loss at iteration 300 is 0.16260117292404175
Loss at iteration 400 is 0.07829469442367554
Loss at iteration 500 is 0.04505010321736336
Loss at iteration 600 is 0.029060449451208115
Loss at iteration 700 is 0.020274262875318527
Loss at iteration 800 is 0.014927276410162449
Loss at iteration 900 is 0.011423378251492977
