In this notebook we implement the unoptimized structured transform; i.e., we build functions that produce a Hadamard matrix and an orthonormal matrix.

In [1]:
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn
from reservoirtaming.layers.activation import leaky_erf
import numpy as np
from typing import Tuple, Callable

from flax.core import unfreeze
from flax.traverse_util import flatten_dict
from flax.linen.initializers import zeros, normal

In [2]:
# Fake test data: one standard-normal sample with 4096 features,
# drawn from a fixed PRNG key.
key = random.PRNGKey(42)
X = random.normal(key, (1, 4096))

# Diagonal matrix

In [3]:
class Diagonal(nn.Module):
    """Element-wise multiplication by a Rademacher-initialised row vector.

    The parameter is stored under the name ``'kernel'`` with shape
    ``(1, X.shape[1])``; multiplying it with ``X`` broadcasts over the first
    axis, which is equivalent to applying a diagonal matrix to the features.
    """

    @nn.compact
    def __call__(self, X):
        # One entry per feature column of the 2-D input.
        n_features = X.shape[1]
        signs = self.param('kernel', random.rademacher, (1, n_features))
        return signs * X

In [4]:
# Instantiate the diagonal layer and initialise its parameters from a fixed key.
# X supplies the input shape from which the kernel shape is derived.
model = Diagonal()
key = random.PRNGKey(42)
params = model.init(key, X)

In [5]:
# Inspect the initialised parameters of the diagonal layer.
params

FrozenDict({
    params: {
        kernel: DeviceArray([[-1,  1, -1, ...,  1, -1,  1]], dtype=int32),
    },
})

In [6]:
# Apply the diagonal layer to X using the initialised parameters.
model.apply(params, X)

DeviceArray([[ 0.82862306, -1.8391167 ,  0.23136322, ...,  2.1180675 ,
              -1.9316142 , -0.5232188 ]], dtype=float32)

# Hadamard 

Let's start with a Hadamard initializer.

In [7]:
def hadamard(normalized=True, dtype=jnp.float32):
    """Create an initializer that returns a Hadamard matrix (Sylvester construction).

    Args:
        normalized: if True, the matrix is scaled by ``2 ** (-log2(n) / 2)``
            (i.e. ``1 / sqrt(n)``) so that its rows are orthonormal; if False
            the entries are +1 / -1.
        dtype: default dtype of the generated matrix.

    Returns:
        A function ``init(key, shape, dtype=dtype)`` that can be used as a
        kernel initializer. ``key`` is ignored because the matrix is
        deterministic. ``shape[0]`` is the order ``n`` of the matrix; ``shape``
        may be ``(n,)`` or ``(n, n)``. The result always has shape ``(n, n)``.

    Raises:
        AssertionError: if ``n`` is not a positive power of two, or if a
            multi-dimensional ``shape`` is not square.
    """
    def init(key, shape, dtype=dtype):
        n = int(shape[0])

        # Input validation. Use exact integer arithmetic: a floating-point
        # round trip through log2 can accept sizes that are not powers of two
        # and would then silently build a matrix of the wrong order.
        assert n >= 1 and n & (n - 1) == 0, "shape must be a positive integer and a power of 2."
        # A Hadamard matrix is square, so any further dimension must equal n.
        assert all(int(dim) == n for dim in shape[1:]), "shape must be (n,) or (n, n)."
        lg2 = n.bit_length() - 1  # exact log2 of a power of two

        # Logic: start from the 1x1 matrix [[1]] (2-D, so that n == 1 also
        # yields a matrix) and double it lg2 times: H -> [[H, H], [H, -H]].
        H = jnp.ones((1, 1), dtype=dtype)
        for _ in range(lg2):
            H = jnp.vstack([jnp.hstack([H, H]), jnp.hstack([H, -H])])

        if normalized:
            H = 2 ** (-lg2 / 2) * H
        return H
    return init

In [8]:
# Example: build the 8x8 normalized Hadamard matrix.
# The initializer does not use the key; it is passed only to match the
# initializer call signature.
key = random.PRNGKey(42)
hadamard()(key, (8, ))

DeviceArray([[ 0.35355338,  0.35355338,  0.35355338,  0.35355338,
               0.35355338,  0.35355338,  0.35355338,  0.35355338],
             [ 0.35355338, -0.35355338,  0.35355338, -0.35355338,
               0.35355338, -0.35355338,  0.35355338, -0.35355338],
             [ 0.35355338,  0.35355338, -0.35355338, -0.35355338,
               0.35355338,  0.35355338, -0.35355338, -0.35355338],
             [ 0.35355338, -0.35355338, -0.35355338,  0.35355338,
               0.35355338, -0.35355338, -0.35355338,  0.35355338],
             [ 0.35355338,  0.35355338,  0.35355338,  0.35355338,
              -0.35355338, -0.35355338, -0.35355338, -0.35355338],
             [ 0.35355338, -0.35355338,  0.35355338, -0.35355338,
              -0.35355338,  0.35355338, -0.35355338,  0.35355338],
             [ 0.35355338,  0.35355338, -0.35355338, -0.35355338,
              -0.35355338, -0.35355338,  0.35355338,  0.35355338],
             [ 0.35355338, -0.35355338, -0.35355338,  0.35355338,
   

In [9]:
class HadamardTransform(nn.Module):
    """Bias-free dense layer whose kernel is initialised by ``hadamard()``.

    Attributes are documented with comments below.
    """
    # Number of output features of the dense layer.
    n_hadamard: int

    @nn.compact
    def __call__(self, X):
        # Linear map with the default (normalized) Hadamard initializer.
        projection = nn.Dense(
            features=self.n_hadamard,
            use_bias=False,
            kernel_init=hadamard(),
        )
        return projection(X)

In [10]:
# Test input: one sample with 4096 features (a power of two, which the
# Hadamard initializer requires).
key = random.PRNGKey(42)
X = random.normal(key, (1, 4096))

In [11]:
# Instantiate the Hadamard transform with 4096 output features and
# initialise its parameters on X.
model = HadamardTransform(4096)
key = random.PRNGKey(42)
params = model.init(key, X)

In [12]:
# Show the initialised kernel of the Hadamard transform.
print(params)

FrozenDict({
    params: {
        Dense_0: {
            kernel: DeviceArray([[ 0.015625,  0.015625,  0.015625, ...,  0.015625,  0.015625,
                           0.015625],
                         [ 0.015625, -0.015625,  0.015625, ..., -0.015625,  0.015625,
                          -0.015625],
                         [ 0.015625,  0.015625, -0.015625, ...,  0.015625, -0.015625,
                          -0.015625],
                         ...,
                         [ 0.015625, -0.015625,  0.015625, ..., -0.015625,  0.015625,
                          -0.015625],
                         [ 0.015625,  0.015625, -0.015625, ...,  0.015625, -0.015625,
                          -0.015625],
                         [ 0.015625, -0.015625, -0.015625, ..., -0.015625, -0.015625,
                           0.015625]], dtype=float32),
        },
    },
})


In [13]:
# Apply the Hadamard transform to X.
model.apply(params, X)

DeviceArray([[-0.8856098, -0.7371275, -0.7132219, ...,  0.7744956,
              -0.5828059,  2.081466 ]], dtype=float32)

# Slow structured transform

In [14]:
class StructuredTransform(nn.Module):
    """Reservoir update built from random diagonal layers and Hadamard transforms.

    The scaled state and scaled inputs are concatenated, zero-padded to the
    next power of two, and passed ``n_layers`` times through
    ``hadamard(diagonal(X))``. Each repetition uses its own ``Diagonal``
    layer, while a single ``HadamardTransform`` is shared by all of them.
    The first ``n_reservoir`` columns of the result are rescaled, a bias is
    added, and ``activation_fn`` is applied.
    """
    # Size of the reservoir state (number of columns of `state`).
    n_reservoir: int
    # Number of input features (number of columns of `inputs`).
    n_input: int
    # Number of diagonal + Hadamard repetitions.
    n_layers: int = 3

    # Multiplier applied to `inputs` before concatenation.
    input_scale: float = 0.4
    # Multiplier applied to `state` before concatenation.
    res_scale: float = 0.9
    # Standard deviation of the normal initializer of the bias.
    bias_scale: float = 0.1

    # Called as activation_fn(z, state, *activation_fn_args).
    activation_fn: Callable = leaky_erf
    activation_fn_args: Tuple = (1.0, )

    def setup(self):
        """Compute the padded size and create the sub-layers and the bias."""
        # Padding: round n_input + n_reservoir up to the next power of two.
        # This uses integer bit arithmetic instead of float log2/ceil, so the
        # result is exact and needs no device computation for a static shape.
        n_total = int(self.n_input + self.n_reservoir)
        self.n_hadamard = 1 << max(n_total - 1, 0).bit_length()
        self.n_padding = self.n_hadamard - n_total
        # Single row of zeros, so `state` and `inputs` must also have one row.
        self.padding = jnp.zeros((1, self.n_padding))

        # Layers
        self.diagonal_layers = [Diagonal() for _ in range(self.n_layers)]
        self.hadamard = HadamardTransform(self.n_hadamard)
        self.bias = self.param('bias', normal(stddev=self.bias_scale), (self.n_reservoir, ))

    def __call__(self, state, inputs):
        """Compute the next reservoir state.

        Args:
            state: current state; concatenated along axis 1 with `inputs` and
                a (1, n_padding) block of zeros, so it needs a single row.
            inputs: current input, same number of rows as `state`.

        Returns:
            The result of ``activation_fn(z, state, *activation_fn_args)``,
            where ``z`` has ``n_reservoir`` columns.
        """
        X = jnp.concatenate([self.res_scale * state, self.input_scale * inputs, self.padding], axis=1)
        for diagonal in self.diagonal_layers:
            X = self.hadamard(diagonal(X))

        # TODO: check if self.n_hadamard is correct; comes from code from paper
        # NOTE(review): HadamardTransform uses the normalized Hadamard
        # initializer by default, so this extra division may over-scale the
        # pre-activation - confirm against the paper's reference code.
        z = X[:, :self.n_reservoir] / self.n_hadamard + self.bias
        z = self.activation_fn(z, state, *self.activation_fn_args)
        return z

    @staticmethod
    def initialize_state(rng, n_reservoir, init_fn=zeros):
        """Return an initial state of shape (1, n_reservoir) built by `init_fn(rng, shape)`."""
        return init_fn(rng, (1, n_reservoir))

In [15]:
# Test input for the structured transform: one sample with 100 features.
key = random.PRNGKey(42)
X = random.normal(key, (1, 100))

In [16]:
# Sizes: the input width is read from X; 3700 + 100 is padded up to 4096
# inside the model.
n_reservoir = 3700
n_input = X.shape[-1]

# Build the model, then create the initial state and the parameters from the
# same fixed key.
model = StructuredTransform(n_reservoir=n_reservoir, n_input=n_input)
key = random.PRNGKey(42)
state = model.initialize_state(key, n_reservoir)
params = model.init(key, state, X)

In [17]:
# Inspect the initialised parameters (bias, diagonal kernels, Hadamard kernel).
params

FrozenDict({
    params: {
        bias: DeviceArray([ 0.15158992, -0.19948767,  0.09550976, ..., -0.04397626,
                      0.01238424,  0.00847415], dtype=float32),
        diagonal_layers_0: {
            kernel: Buffer([[ 1,  1,  1, ..., -1, -1,  1]], dtype=int32),
        },
        hadamard: {
            Dense_0: {
                kernel: DeviceArray([[ 0.015625,  0.015625,  0.015625, ...,  0.015625,  0.015625,
                               0.015625],
                             [ 0.015625, -0.015625,  0.015625, ..., -0.015625,  0.015625,
                              -0.015625],
                             [ 0.015625,  0.015625, -0.015625, ...,  0.015625, -0.015625,
                              -0.015625],
                             ...,
                             [ 0.015625, -0.015625,  0.015625, ..., -0.015625,  0.015625,
                              -0.015625],
                             [ 0.015625,  0.015625, -0.015625, ...,  0.015625, -0.015625,
        

In [44]:
# Sanity check: the output keeps only the n_reservoir columns.
model.apply(params, state, X).shape

(1, 3700)