Let's implement the fast walsh hadamrd transform in jax.

In [11]:
import jax.numpy as jnp
from jax import random
from reservoirtaming.layers.utils import HadamardTransform, hadamard
import jax
from jax.ops import index_update
import numpy as np
from functools import reduce

In [2]:
# Making test data
n = 2 ** 12
print(n)
key = random.PRNGKey(42)
X = random.normal(key, (1, n))

4096


In [3]:
H = hadamard(normalized=False)(key, (n, ))

In [4]:
forward_baseline = jax.jit(lambda X: jnp.dot(X, H))

In [5]:
# Triggeri jit
forward_baseline(X)

DeviceArray([[-56.67905 , -47.176163, -45.646194, ...,  49.567734,
              -37.2996  , 133.21384 ]], dtype=float32)

In [6]:
%%timeit
forward_baseline(X).block_until_ready()

168 µs ± 1.07 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [7]:
X_baseline = forward_baseline(X)

In [8]:
X_baseline

Buffer([[-56.67905 , -47.176163, -45.646194, ...,  49.567734, -37.2996  ,
         133.21384 ]], dtype=float32)

# V1: simple for loops 

In [44]:
def fasthadamardtransform(X):
    h = 1
    n = X.shape[-1]
    while h < n:
        for i in jnp.arange(n, step=h * 2):
            for j in jnp.arange(i, i + h):
                x = X[j]
                y = X[j + h]
                X = index_update(X, j, x + y)
                X = index_update(X, j + h, x - y)
        h *= 2
    return X

In [45]:
X_work = X.copy().squeeze()

In [46]:
%%time
X_new = fasthadamardtransform(X_work)

CPU times: user 48 s, sys: 8.8 s, total: 56.8 s
Wall time: 37 s


In [47]:
jnp.allclose(X_new, X_baseline)

DeviceArray(False, dtype=bool)

In [48]:
jnp.max(jnp.abs(X_new - X_baseline))

DeviceArray(2.2888184e-05, dtype=float32)

# V2; vmapping update

In [49]:
@jax.jit
def single_update(X, h, i, j):
    x = X[j]
    y = X[j + h]
    X = index_update(X, j, x + y)
    X = index_update(X, j + h, x - y)
    return X

In [50]:
def fasthadamardtransform(X):
    h = 1
    n = X.shape[-1]
    while h < n:
        for i in jnp.arange(n, step=h * 2):
            for j in jnp.arange(i, i + h):
                X = single_update(X, h, i ,j)
        h *= 2
    return X

In [51]:
X_work = X.copy().squeeze()

In [52]:
X_new = fasthadamardtransform(X_work)

In [53]:
%%time
X_new = fasthadamardtransform(X_work)

CPU times: user 2.38 s, sys: 681 ms, total: 3.06 s
Wall time: 1.29 s


In [54]:
jnp.max(jnp.abs(X_new - X_baseline))

DeviceArray(2.2888184e-05, dtype=float32)

# V3; for i loops

In [55]:
@jax.jit
def single_update(X, h, i, j):
    x = X[j]
    y = X[j + h]
    X = index_update(X, j, x + y)
    X = index_update(X, j + h, x - y)
    return X

In [56]:
def fasthadamardtransform(X):
    h = 1
    n = X.shape[-1]
    while h < n:
        for i in np.arange(n, step=h * 2):
            X = jax.lax.fori_loop(i, i+h, lambda idx, x: single_update(x, h, i, idx), X)
        h *= 2
    return X

In [57]:
X_new = fasthadamardtransform(X_work)

In [None]:
%%time
X_new = fasthadamardtransform(X_work)

In [23]:
jnp.max(jnp.abs(X_new - X_baseline))

DeviceArray(3.8146973e-05, dtype=float32)

DeviceArray([-2.5212405, -2.4014864,  1.2293223, ...,  2.1180675,
              1.9316142, -0.5232188], dtype=float32)

In [29]:
h = 1
i = 0
X = jax.lax.fori_loop(i, i+h, lambda idx, x: single_update(x, h, i, idx), X)

In [31]:
%%timeit
h = 1
i = 0
jax.lax.fori_loop(i, i+h, lambda idx, x: single_update(x, h, i, idx), X)

30.8 ms ± 894 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


# V4; preclaculting h and doing it the other way

In [104]:
@jax.jit
def single_update(X, h, i, j):
    x = X[j]
    y = X[j + h]
    X = index_update(X, j, x + y)
    X = index_update(X, j + h, x - y)
    return X

In [105]:
def fasthadamardtransform(X):
    n = X.shape[-1]
    for h in h_range:
        for i in jnp.arange(n, step=h * 2):
            for j in jnp.arange(i, i + h):
                X = single_update(X, h, i ,j)
    return X

In [106]:
X_work = X.copy().squeeze()

In [107]:
X_new = fasthadamardtransform(X_work)

In [108]:
%%time
X_new = fasthadamardtransform(X_work)

CPU times: user 1.97 s, sys: 750 ms, total: 2.72 s
Wall time: 1.12 s


In [110]:
jnp.max(jnp.abs(X_new - X_baseline))

DeviceArray(210.97363, dtype=float32)

In [77]:
h = 1
n = 2048
while h < n:
    print(h)
    h*=2

1
2
4
8
16
32
64
128
256
512
1024


In [85]:
2 ** jnp.arange(jnp.log2(n), 1)

DeviceArray([], dtype=float32)

In [94]:
jnp.arange(jnp.log2(n)-1, 0, -1)

DeviceArray([10.,  9.,  8.,  7.,  6.,  5.,  4.,  3.,  2.,  1.], dtype=float32)

In [95]:
2 ** 10

1024

In [112]:
X.shape

(2048,)

In [113]:
x, y = jnp.split(X, 2)
jnp.concatenate([x+y, x-y], axis=-1)

In [114]:
x.shape

(1024,)

In [115]:
y.shape

(1024,)

DeviceArray([ 0.25474286,  2.594915  ,  1.4464111 , ..., -2.2168927 ,
              1.0952634 ,  1.3360806 ], dtype=float32)

# Vmapping i and j

In [None]:
@jax.jit
def single_update(X, h, i, j):
    x = X[j]
    y = X[j + h]
    X = index_update(X, j, x + y)
    X = index_update(X, j + h, x - y)
    return X

In [10]:
X = random.normal(key, (1024, ))

In [25]:
h = 1
n = X.shape[-1]
idx = []
while h < n:
    for i in jnp.arange(n, step=h * 2):
        for j in jnp.arange(i, i + h):
            idx.append([j, j+h])
    h *= 2

32

# Cutting in two:

In [27]:
X

DeviceArray([-0.02862089,  1.5240539 , -1.0556508 , ..., -2.4457607 ,
             -0.2306908 , -0.1957571 ], dtype=float32)

In [29]:
x, y = jnp.split(X, 2)

In [31]:
.shape

(1024,)

In [46]:
def hadamard_update(X):
    x, y = jnp.split(X, 2)
    return jnp.concatenate([x + y, x - y], axis=-1)

In [35]:
jnp.split(X, 1)

[DeviceArray([-0.02862089,  1.5240539 , -1.0556508 , ..., -2.4457607 ,
              -0.2306908 , -0.1957571 ], dtype=float32)]

In [41]:
h_range = 2 ** jnp.arange(jnp.log2(X.shape[-1]))
print(h_range)

[  1.   2.   4.   8.  16.  32.  64. 128. 256. 512.]


In [None]:
for h in h_range:
    

In [44]:
jnp.stack(jnp.split(X, 2), axis=0).shape

(2, 512)

In [47]:
hadamard_update(X)

DeviceArray([-0.29711854,  0.16825843, -0.32572258, ...,  2.7445283 ,
              0.25652558,  0.9220723 ], dtype=float32)

DeviceArray([[ 0.2505168 ,  1.2786218 , -0.4245987 , ..., -1.5220637 ,
               1.5918701 , -0.12556547],
             [-0.87318814, -2.6005206 ,  2.8852375 , ...,  2.0790575 ,
               1.7609035 ,  0.8252154 ]], dtype=float32)

In [None]:
for h in h_range:
    jax.vmap(hadamard_update, in_axes=0)(jnp.stack(jnp.split(X, 2), axis=0))

# Actually making it work

In [None]:
def hadamard_transform_torch(u, normalize=False):
    """Multiply H_n @ u where H_n is the Hadamard matrix of dimension n x n.
    n must be a power of 2.
    Parameters:
        u: Tensor of shape (..., n)
        normalize: if True, divide the result by 2^{m/2} where m = log_2(n).
    Returns:
        product: Tensor of shape (..., n)
    """
    batch_size, n = u.shape
    m = int(np.log2(n))
    assert n == 1 << m, 'n must be a power of 2'
    x = u[..., np.newaxis]
    for d in range(m)[::-1]:
        x = torch.cat((x[..., ::2, :] + x[..., 1::2, :], x[..., ::2, :] - x[..., 1::2, :]), dim=-1)
    return x.squeeze(-2) / 2**(m / 2) if normalize else x.squeeze(-2)

In [23]:
def hadamard_transform(u):
    """Multiply H_n @ u where H_n is the Hadamard matrix of dimension n x n.
    n must be a power of 2.
    Parameters:
        u: Tensor of shape (..., n)
        normalize: if True, divide the result by 2^{m/2} where m = log_2(n).
    Returns:
        product: Tensor of shape (..., n)
    """
    batch_size, n = u.shape
    m = int(jnp.log2(n))
    assert n == 1 << m, 'n must be a power of 2'
    x = u[..., jnp.newaxis]
    for d in jnp.arange(m)[::-1]:
        print(x.shape)
        x = jnp.concatenate((x[..., ::2, :] + x[..., 1::2, :], x[..., ::2, :] - x[..., 1::2, :]), axis=-1)
    return x.squeeze(-2)

In [31]:
%%time
X_new = hadamard_transform(X)

(1, 4096, 1)
(1, 2048, 2)
(1, 1024, 4)
(1, 512, 8)
(1, 256, 16)
(1, 128, 32)
(1, 64, 64)
(1, 32, 128)
(1, 16, 256)
(1, 8, 512)
(1, 4, 1024)
(1, 2, 2048)
CPU times: user 43.9 ms, sys: 28.1 ms, total: 72 ms
Wall time: 56.4 ms


In [51]:
jnp.max(jnp.abs(X_new - X_baseline))

DeviceArray(3.8146973e-05, dtype=float32)

In [54]:
m_max = jnp.log2(X.shape[-1])
m_range = (2 ** jnp.arange(m_max)).astype(int)
z = X

In [55]:
for m in m_range:
    z = z.reshape(1, -1, m)
    z = jnp.concatenate((z[:, ::2, :] + z[:, 1::2, :], z[:, ::2, :] - z[:, 1::2, :]), axis=-1)
    z = z.reshape(1, -1)

In [56]:
z 

DeviceArray([[-56.67905 , -47.176163, -45.6462  , ...,  49.56773 ,
              -37.299606, 133.21384 ]], dtype=float32)

In [58]:
jnp.max(jnp.abs(z - X_baseline))

DeviceArray(3.8146973e-05, dtype=float32)

In [171]:
def hadamard_update(X, m):
    power = np.power(2, m).astype(int)
    X = X.reshape(1, -1, power)
    X = jnp.concatenate((X[:, ::2, :] + X[:, 1::2, :], X[:, ::2, :] - X[:, 1::2, :]), axis=-1)
    return X.reshape(1, -1)

In [143]:
@jax.jit
def hadamard_transform(X):
    m_max = np.log2(X.shape[-1])
    for m in np.arange(m_max):
        X = hadamard_update(X, m)
    return X

In [144]:
%%timeit
hadamard_transform(X)

57.5 µs ± 70.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


Okay not bad... Can we jit it?

In [187]:
update = jax.jit(hadamard_update, static_argnums=(1, ))

In [188]:
# triggering compilation
update(X, 3)

DeviceArray([[-0.5576862 , -3.1554017 , -0.22818078, ..., -2.748053  ,
              -1.6223387 ,  0.9490731 ]], dtype=float32)

In [189]:
%%timeit
update(X, 3)

31.3 µs ± 133 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [102]:
m_max = jnp.log2(X.shape[-1])

In [None]:
jax.lax.fori_loop(0, int(m_max), lambda m, X: hadamard_update(X, m), X)

In [72]:
m_range

DeviceArray([   1,    2,    4,    8,   16,   32,   64,  128,  256,  512,
             1024, 2048], dtype=int32)

In [117]:
from functools import reduce

In [190]:
# once to run all jit updates
reduce(lambda X, idx: update(X, idx), np.arange(m_max), X)

DeviceArray([[-56.67905 , -47.176163, -45.6462  , ...,  49.56773 ,
              -37.299606, 133.21384 ]], dtype=float32)

In [191]:
%%timeit
reduce(lambda X, idx: update(X, idx), np.arange(m_max), X)

1 ms ± 2.77 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [124]:
z = jnp.expand_dims(X, -1)

In [127]:
f = lambda z, _: jnp.concatenate((z[:, ::2, :] + z[:, 1::2, :], z[:, ::2, :] - z[:, 1::2, :]), axis=-1)

In [148]:
%%time
reduce(f, np.arange(m_max), z)

CPU times: user 52.5 ms, sys: 600 µs, total: 53.1 ms
Wall time: 39.1 ms


DeviceArray([[[-56.67905 , -47.176163, -45.6462  , ...,  49.56773 ,
               -37.299606, 133.21384 ]]], dtype=float32)

In [192]:
@jax.jit
def hadamard_transform(X):
    m_max = np.log2(X.shape[-1])
    z = jnp.expand_dims(X, -1)
    f = lambda z, _: jnp.concatenate((z[:, ::2, :] + z[:, 1::2, :], z[:, ::2, :] - z[:, 1::2, :]), axis=-1)
    z = reduce(f, np.arange(m_max), z)
    return z.squeeze()

In [193]:
# triggering jit
hadamard_transform(X)

DeviceArray([-56.67905 , -47.176163, -45.6462  , ...,  49.56773 ,
             -37.299606, 133.21384 ], dtype=float32)

In [194]:
%%timeit
hadamard_transform(X).block_until_ready()

58.8 µs ± 94.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [195]:
from functools import partial

In [196]:
@jax.jit
def hadamard_transform(X):
    @partial(jax.jit, static_argnums=(1, ))
    def hadamard_update(X, m):
        power = np.power(2, m).astype(int)
        X = X.reshape(1, -1, power)
        X = jnp.concatenate((X[:, ::2, :] + X[:, 1::2, :], X[:, ::2, :] - X[:, 1::2, :]), axis=-1)
        return X.reshape(1, -1)
    
    m_max = np.log2(X.shape[-1])
    X = reduce(lambda X, idx: hadamard_update(X, idx), np.arange(m_max), X)
    return X

In [197]:
# triggering jit
hadamard_transform(X)

DeviceArray([[-56.67905 , -47.176163, -45.6462  , ...,  49.56773 ,
              -37.299606, 133.21384 ]], dtype=float32)

In [198]:
%%timeit
hadamard_transform(X)

57.8 µs ± 51.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


# tweaking super fast solution

In [82]:
from functools import partial

In [58]:
@jax.jit
def hadamard_transform(X):
    m_max = np.log2(X.shape[-1])
    z = jnp.expand_dims(X, -1)
    f = lambda z, _: jnp.concatenate((z[:, ::2, :] + z[:, 1::2, :], z[:, ::2, :] - z[:, 1::2, :]), axis=-1)
    z = reduce(f, np.arange(m_max), z)
    return z.squeeze(-2)

In [87]:
# triggering jit
X_transformed = hadamard_transform(X)

In [66]:
jnp.allclose(X_transformed, X_baseline, atol=1e-4)

DeviceArray(True, dtype=bool)

In [67]:
%%timeit
hadamard_transform(X).block_until_ready()

59.4 µs ± 867 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [88]:
@jax.jit
def hadamard_transform(X):
    @partial(jax.jit, static_argnums=(1, ))
    def update(z, m):
        x = z[:, ::2, :]
        y = z[:, 1::2, :]
        return jnp.concatenate((x+y, x-y), axis=-1)

    m_max = np.log2(X.shape[-1])
    z = jnp.expand_dims(X, -1)
    z = reduce(update, np.arange(m_max), z)
    return z.squeeze(-2)

In [89]:
# triggering jit
X_transformed = hadamard_transform(X)

In [90]:
%%timeit
hadamard_transform(X).block_until_ready()

58.7 µs ± 157 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [None]:
jnp.allclose(X_transformed, X_baseline, atol=1e-4)

If python is row major, its the last axis which is easiest to read 

In [106]:
@jax.jit
def hadamard_transform(X):
    def update(z, m):
        x = z[:, ::2, :]
        y = z[:, 1::2, :]
        return jnp.concatenate((x+y, x-y), axis=-1)

    m_max = np.log2(X.shape[-1])
    z = jnp.expand_dims(X, -1)
    z = reduce(update, np.arange(m_max), z)
    return z.squeeze(-2)

In [107]:
# triggering jit
X_transformed = hadamard_transform(X)

In [101]:
%%timeit
hadamard_transform(X).block_until_ready()

55.9 µs ± 98.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
