In [1]:
import jax
import jax.numpy as jnp
import jax.random as random
from jax.numpy.fft import fft, ifft, fftn, ifftn
import matplotlib.pyplot as plt
import numpy as np
import time

In [2]:
@jax.jit
def circ_mult(w,x):
    return jnp.real(fft(ifft(x.T)*fft(w)).T)

circ_vmap = jax.vmap(lambda col, ind: jnp.roll(col, ind), in_axes=(None,0), out_axes=1)

@jax.jit
def reg_mult(w,x):
    C = circ_vmap(w.reshape(1,-1), jnp.arange(D_Z))[0]
    return C @ x


In [3]:
w = jnp.array([1,2,3,4,5])
print(w.shape)
C = circ_vmap(w, jnp.arange(5))
C

(5,)


Array([[1, 5, 4, 3, 2],
       [2, 1, 5, 4, 3],
       [3, 2, 1, 5, 4],
       [4, 3, 2, 1, 5],
       [5, 4, 3, 2, 1]], dtype=int32)

In [4]:
D_Z = 500
N = 1
w_warm = jax.random.normal(random.PRNGKey(0), (D_Z,))
x_warm = jax.random.normal(random.PRNGKey(1), (N,D_Z))
w = jax.random.normal(random.PRNGKey(0), (D_Z,))
x = jax.random.normal(random.PRNGKey(1), (N,D_Z))
C_warm = circ_vmap(w_warm.reshape(1,-1), jnp.arange(D_Z))[0]
C_warm @ x_warm.T
C_warm @ x_warm.T
C_warm @ x_warm.T
time_start = time.time()
C = circ_vmap(w.reshape(1,-1), jnp.arange(D_Z))[0]
jax.block_until_ready(jnp.matrix_transpose(C @ x.T)) # to get x @ C'
time_end = time.time()
print(f"Time: {time_end - time_start:.4f}")
# Z = (N, D_Z)
# W = (D_Z, D_Z)
# WZ = (Z^T @ W^T)^T


Time: 0.0121


In [5]:
jnp.real(circ_mult(w_warm, x_warm.T))
jnp.real(circ_mult(w_warm, x_warm.T))
jnp.real(circ_mult(w_warm, x_warm.T))
time_start = time.time()
jnp.real(jax.block_until_ready(circ_mult(w, x.T)))
time_end = time.time()
print(f"Time: {time_end - time_start:.4f}")

Time: 0.0002


In [6]:
circ_grad = jax.jacobian(circ_mult)
circ_grad(w_warm, x_warm.T)
circ_grad(w_warm, x_warm.T)
circ_grad(w_warm, x_warm.T)
time_start = time.time()
res = circ_grad(w, x.T)
time_end = time.time()
print(f"Time: {time_end - time_start:.4f}. Res shape: {res.shape}")

Time: 0.0041. Res shape: (500, 1, 500)


In [7]:
mult_grad = jax.jacobian(reg_mult)
mult_grad(w_warm, x_warm.T)
mult_grad(w_warm, x_warm.T)
mult_grad(w_warm, x_warm.T)
time_start = time.time()
res = mult_grad(w, x.T)
time_end = time.time()
print(f"Time: {time_end - time_start:.4f}. Res shape: {res.shape}")

Time: 0.5250. Res shape: (500, 1, 500)


In [8]:
reg_mult(w, x.T)

Array([[ 4.59632969e+00],
       [-1.41770697e+01],
       [-1.21102448e+01],
       [ 1.31331644e+01],
       [-3.50147400e+01],
       [-8.32195663e+00],
       [ 4.55397301e+01],
       [-2.67613316e+01],
       [ 1.38328876e+01],
       [-4.97577286e+01],
       [ 1.30531139e+01],
       [ 1.91941700e+01],
       [-4.72279320e+01],
       [ 6.16238642e+00],
       [ 1.93773136e+01],
       [-9.37050247e+00],
       [-2.81911125e+01],
       [ 8.06140900e+00],
       [ 8.72234631e+00],
       [-1.39772310e+01],
       [ 2.52091484e+01],
       [-3.29163589e+01],
       [ 5.11376572e+00],
       [ 4.49327707e+00],
       [-2.97822895e+01],
       [ 1.39635248e+01],
       [ 2.20934982e+01],
       [ 1.06957781e+00],
       [ 8.10378361e+00],
       [-3.91817398e+01],
       [-9.13887215e+00],
       [-4.08212051e+01],
       [ 1.02809870e+00],
       [ 2.31028538e+01],
       [ 1.64405537e+01],
       [ 1.43364449e+01],
       [-6.41580534e+00],
       [-1.76827488e+01],
       [ 5.0

In [9]:
circ_mult(w, x.T)

Array([[ 4.59633446e+00],
       [-1.41770744e+01],
       [-1.21102438e+01],
       [ 1.31331625e+01],
       [-3.50147285e+01],
       [-8.32195663e+00],
       [ 4.55397453e+01],
       [-2.67613525e+01],
       [ 1.38328876e+01],
       [-4.97577477e+01],
       [ 1.30531197e+01],
       [ 1.91941643e+01],
       [-4.72279396e+01],
       [ 6.16239262e+00],
       [ 1.93773117e+01],
       [-9.37049675e+00],
       [-2.81911316e+01],
       [ 8.06141663e+00],
       [ 8.72235203e+00],
       [-1.39772453e+01],
       [ 2.52091389e+01],
       [-3.29163513e+01],
       [ 5.11376476e+00],
       [ 4.49327087e+00],
       [-2.97822952e+01],
       [ 1.39635162e+01],
       [ 2.20935059e+01],
       [ 1.06958675e+00],
       [ 8.10378075e+00],
       [-3.91817398e+01],
       [-9.13887024e+00],
       [-4.08212051e+01],
       [ 1.02809334e+00],
       [ 2.31028652e+01],
       [ 1.64405594e+01],
       [ 1.43364525e+01],
       [-6.41579485e+00],
       [-1.76827469e+01],
       [ 5.0

In [10]:
C.shape

(500, 500)

In [11]:
ifft(fft(w) * fft(x.T))[0,0]

Array(-0.2505963+2.384186e-09j, dtype=complex64)

# Computation of $Cx$ using DFT
Cheng et al claims that
$$Cx=\mathcal{F}^{-1}(\mathcal{F}(r) \circ \mathcal{F}(x))$$
However, this does not work (see below)
I have found a solution that does work though (based on https://web.mit.edu/18.06/www/Spring17/Circulant-Matrices.pdf)
$$Cx=\mathcal{F}(\mathcal{F}(r) \circ \mathcal{F}^{-1}(x))$$

In [12]:
R = jnp.array([[1,3,2],[2,1,3],[3,2,1]])
r = jnp.array([1,2,3])
C = circ_vmap(r, jnp.arange(3))
assert jnp.allclose(C, R)
x = jnp.array([[4,5,6], [7,8,9], [10,11,12], [13,14,15]])
print(R)
print(r)
print(x)

[[1 3 2]
 [2 1 3]
 [3 2 1]]
[1 2 3]
[[ 4  5  6]
 [ 7  8  9]
 [10 11 12]
 [13 14 15]]


In [13]:
print("Reference", (x @ R), " of shapes ", x.shape, R.shape)
#print(fft(r)[:, None] * fftn(x, axes=(0,)).T)
#print("Cheng et al.", jnp.real(ifftn(fft(r)[:, None] * fftn(x, axes=(0,)).T, axes=(1,)).T))
print("Cheng et al.", jnp.real(ifft(fft(r) * fft(x))))
print("Reference, transposed", (R.T @ x.T))
print("My fix", jnp.real(fft(fft(r) * ifft(x))))

Reference [[32 29 29]
 [50 47 47]
 [68 65 65]
 [86 83 83]]  of shapes  (4, 3) (3, 3)
Cheng et al. [[31. 31. 28.]
 [49. 49. 46.]
 [67. 67. 64.]
 [85. 85. 82.]]
Reference, transposed [[32 50 68 86]
 [29 47 65 83]
 [29 47 65 83]]
My fix [[32. 29. 29.]
 [50. 47. 47.]
 [68. 65. 65.]
 [86. 83. 83.]]


In [14]:
jnp.real(fft(fft(w) * ifft(x[0]))[:3])

TypeError: mul got incompatible shapes for broadcasting: (500,), (3,).

In [None]:
def fft_mult(r, x):
    return jnp.real(ifft(fft(r) * fft(x)))
fft_jac = jax.jacobian(lambda r, x: fft_mult(r, x))

In [None]:
fft_grad(r.astype(jnp.float32),x.astype(jnp.float32)[:,0])

Array([4., 6., 5.], dtype=float32)

In [15]:
jax.grad(lambda x: jnp.real(ifft(x)[0]))(fft(r) * fft(x[0]))[0] @ ifft(x[0])[0]

ValueError: matmul input operand 0 must have ndim at least 1, but it has ndim 0

In [16]:
jax.make_jaxpr(fft_jac)(r.astype(jnp.float32),x.astype(jnp.float32)[:,0])

NameError: name 'fft_jac' is not defined

# Multiple Circ-Mult

When the network is expanding (for example, when going from 13 input channels to 50 hidden-layer channels), a (13, 50) weight matrix is needed. Since this is not square, this cannot be a complete circulant matrix.
However, it is possible to construct $\lceil 50 / 13 \rceil = 4\ \times$ (13,13) circulant matrices, stack the corresponding projections to a (13,13*4) = (13,52) matrix and then only use the (13,50)'th values.

In [51]:
# Example: 3 channels -> 14 channels
#key = random.PRNGKey(0)

num_circ = jnp.ceil(14/3).astype(int)
x = random.normal(random.PRNGKey(123), (10,3))
w_vectors = jnp.array([random.normal(random.PRNGKey(i), (3,)) for i in range(num_circ)])
w_circs = jnp.array([circ_vmap(w_vectors[i], jnp.arange(3)) for i in range(num_circ)])
zs = jnp.array([x @ w_circs[i] for i in range(num_circ)])
ground_truth = jnp.concat(zs, axis=1)[:, :14]

@jax.jit
def expand_circ_mult(w,x): # w has (num_circ, D_X), x has (N, D_X)
    x_fft = ifft(x)
    x_fft = jnp.repeat(x_fft[:, None, :], w.shape[0], axis=1)
    return jnp.real(fft(fft(w) * x_fft)).reshape(x.shape[0], -1) # (N, num_circ * D_X)

print("Ground truth", ground_truth[-6])
print("FFT-based", expand_circ_mult(w_vectors, x)[:, :14][-6])


Ground truth [ 0.19185932  0.8242394  -4.2175813   0.3819505  -0.03369207 -0.1433691
  1.321313    0.46177176 -2.6862946   3.1982338  -3.0110075  -0.66095805
 -0.07485346  2.107948  ]
FFT-based [ 0.19185925  0.8242395  -4.2175817   0.38195047 -0.03369207 -0.1433691
  1.3213131   0.46177173 -2.6862946   3.1982338  -3.0110073  -0.6609582
 -0.07485355  2.107948  ]
