In this notebook we implement the unoptimized structured transform; i.e. we build a hadamard and orthonormal matrix function.

In [1]:
import jax
import jax.numpy as jnp
from jax import random

Let's start with a hadamard matrix:

In [2]:
def hadamard(n, normalized=True, dtype=int):
    if n < 1:
        lg2 = 0
    else:
        lg2 = jnp.log2(n)
    assert 2 ** lg2 == n, "n must be a positive integer and a power of 2."
    
    H = jnp.ones((1, ), dtype=dtype)
    for i in jnp.arange(lg2):
        H = jnp.vstack([jnp.hstack([H, H]), jnp.hstack([H, -H])])
        
    if normalized:
        H = 2**(-lg2 / 2) * H
        
    return H

In [3]:
hadamard(2)

DeviceArray([[ 0.7071067,  0.7071067],
             [ 0.7071067, -0.7071067]], dtype=float32)

In [4]:
def init_structured_transform(key, n):
    diagonal_keys = random.split(key, 3)
    D1, D2, D3 = [random.rademacher(key, (n, 1)) for key in diagonal_keys]
    H = hadamard(n)
    return H, D1, D2, D3

In [5]:
key = random.PRNGKey(42)
n = 4096

H, D1, D2, D3 = init_structured_transform(key, n)

In [6]:
# Making some fake data
X = random.normal(key, (n, ))

In [7]:
%%timeit
z = jnp.dot(H, D3 * X)
z = jnp.dot(H, D2 * z)
z = jnp.dot(H, D1 * z)

The slowest run took 45.98 times longer than the fastest. This could mean that an intermediate result is being cached.
23.3 ms ± 16 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
%%time
z = jax.jit(lambda X: jnp.linalg.multi_dot([H, D1 * H , D2 * H, D3 * X]))

CPU times: user 71 µs, sys: 28 µs, total: 99 µs
Wall time: 110 µs


In [9]:
%%time
z(X).block_until_ready()

CPU times: user 125 ms, sys: 125 ms, total: 250 ms
Wall time: 278 ms


DeviceArray([[-0.29267862, -0.64960045, -0.08171995, ...,  0.74812335,
               0.68226695, -0.18480723],
             [ 0.02463144,  0.05466538,  0.00687742, ..., -0.06295658,
              -0.05741588,  0.01555277],
             [ 0.72903943,  1.6180907 ,  0.20355886, ..., -1.8635184 ,
              -1.6994754 ,  0.46033984],
             ...,
             [-1.0537311 , -2.338742  , -0.29421806, ...,  2.6934826 ,
               2.4563704 , -0.6653619 ],
             [ 1.5551366 ,  3.4515965 ,  0.43421656, ..., -3.9751225 ,
              -3.625196  ,  0.98196113],
             [-0.2790252 , -0.6192892 , -0.07790655, ...,  0.71321386,
               0.65043354, -0.17618288]], dtype=float32)

In [10]:
%%timeit
z(X)

37.3 ms ± 88 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [11]:
@jax.jit
def structured_update(f, H, D1, D2, D3, bias, input_scale, res_scale, leak_rate, reservoir_state, inputs):
    X = jnp.concatenate([res_scale * reservoir_state, input_scale * inputs])
    z = sqrt(p) / sigma * jnp.linalg.multi_dot([H, D1 * H , D2 * H, D3 * X]) + bias
    new_state = (1.0 - leak_rate) * reservoir_state + leak_rate * f(z) / jnp.sqrt(n_reservoir)
    
    save_state = jnp.concatenate([new_state, renorm_factor * inputs])