In [17]:
import numpy as np
import jax.numpy as jnp
from jax import random
from flax import linen as nn
from dataclasses import dataclass

import flax
import jax

# Record library versions so the timings below can be reproduced.
print('Jax version: {}, Flax version: {}'.format(jax.__version__, flax.__version__))

Jax version: 0.2.9, Flax version: 0.3.0


In [18]:
def hadamard(normalized=True, dtype=jnp.float32):
    """Return an initializer that builds a Sylvester-ordered Hadamard matrix.

    The returned ``init(key, shape, dtype=dtype)`` follows the JAX/Flax
    initializer calling convention, so it can be passed where an initializer
    is expected or called by hand. ``key`` is ignored because the matrix is
    deterministic.

    Args:
        normalized: If True, scale the matrix by ``1 / sqrt(n)`` so that it is
            orthonormal (``H @ H.T == I``).
        dtype: Default dtype of the generated matrix.

    Returns:
        ``init(key, shape, dtype=dtype)`` which returns an ``(n, n)`` array
        with ``n = shape[0]``; any further entries of ``shape`` are ignored.

    Raises:
        ValueError: (from ``init``) if ``shape[0]`` is not a positive power
            of two.
    """

    def init(key, shape, dtype=dtype):
        n = int(shape[0])
        # Exact integer power-of-two test; comparing 2 ** log2(n) with n
        # depends on floating-point rounding.
        if n < 1 or n & (n - 1):
            raise ValueError("shape must be a positive integer and a power of 2.")
        lg2 = n.bit_length() - 1

        # Sylvester construction: H_{2k} = [[H_k, H_k], [H_k, -H_k]].
        # A 2-D (1, 1) seed makes n == 1 return a matrix like every other n.
        H = jnp.ones((1, 1), dtype=dtype)
        for _ in range(lg2):
            H = jnp.vstack([jnp.hstack([H, H]), jnp.hstack([H, -H])])

        if normalized:
            # A Python float is weakly typed in JAX, so H keeps `dtype`
            # (an np.float64 scalar would promote it when x64 is enabled).
            H = H * 2.0 ** (-lg2 / 2)
        return H

    return init

In [19]:
class HadamardTransformFlax(nn.Module):
    """Flax module that multiplies its input by a fixed Hadamard matrix.

    Attributes:
        n_hadamard: Size of the Hadamard matrix; the ``hadamard`` initializer
            requires it to be a positive power of two.
    """
    n_hadamard: int

    def setup(self):
        # H is assigned as a plain attribute rather than created via
        # self.param, so it is not a trainable parameter.
        make_matrix = hadamard()
        self.H = make_matrix(None, (self.n_hadamard,))

    def __call__(self, X):
        # Contract the last axis of X against the first axis of H.
        transform = self.H
        return jnp.dot(X, transform)

In [20]:
@dataclass
class HadamardTransform:
    """Plain (non-Flax) callable that multiplies its input by a Hadamard matrix.

    The matrix is built once in ``__post_init__`` and reused on every call;
    calls are not jit-compiled.

    Attributes:
        n_hadamard: Size of the Hadamard matrix; the ``hadamard`` initializer
            requires it to be a positive power of two.
    """
    n_hadamard: int

    def __post_init__(self):
        init_fn = hadamard()
        self.H = init_fn(None, (self.n_hadamard,))

    def __call__(self, X):
        matrix = self.H
        return jnp.dot(X, matrix)

In [30]:
# Test data: a single row of 2**10 standard-normal samples from a fixed seed.
n = 1 << 10
print(n)
key = random.PRNGKey(42)
X = random.normal(key, shape=(1, n))

1024


In [31]:
# Materialise the normalised Hadamard matrix once; `key` is accepted only for
# initializer-signature compatibility and is not used by `hadamard`.
H = hadamard(normalized=True)(key, shape=(n,))

In [47]:
# Baseline: jit a bare matmul, with H captured from the enclosing scope.
forward = jax.jit(lambda inputs: jnp.dot(inputs, H))

In [48]:
%%timeit
# Time the jitted baseline; block_until_ready() waits for JAX's asynchronously
# dispatched result so the timing covers the computation, not just dispatch.
X_transformed = forward(X).block_until_ready()

30.8 µs ± 85.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [49]:
# Build the Flax wrapper. setup() assigns H directly instead of calling
# self.param, so `params` is not expected to contain trainable arrays.
model = HadamardTransformFlax(n_hadamard=n)
params = model.init(key, X)

In [50]:
# jit the Flax apply (rebinds `forward`; model and params are closed over).
forward = jax.jit(lambda inputs: model.apply(params, inputs))

In [51]:
%%timeit
# Time the jitted Flax module with the same input and synchronisation as the
# baseline so the two numbers are comparable.
X_transformed = forward(X).block_until_ready()

48.8 µs ± 41.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [17]:
# Non-Flax, non-jitted variant (note: this rebinds `model`).
model = HadamardTransform(n_hadamard=n)

In [18]:
%%timeit
X_transformed = model(X[None, :])

242 µs ± 215 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [None]:
class 