In [1]:
import jax
import jax.numpy as jnp
import jax.random as random
from jax.numpy.fft import fft, ifft
import matplotlib.pyplot as plt
import numpy as np
import time

In [46]:
@jax.jit
def circ_mult(w, x):
    """Apply the shifted-``w`` operator to ``x`` with FFTs (no dense matrix).

    For each column ``x[:, c]`` this returns ``out[k, c] = sum_j w[(j - k) % n] * x[j, c]``,
    i.e. the same values as ``C @ x`` where row ``i`` of ``C`` is ``jnp.roll(w, i)``.

    Args:
        w: 1-D array of length ``n``.
        x: array whose first axis has length ``n`` (e.g. ``(n, N)``); each column is
           transformed independently. A 1-D ``x`` of length ``n`` also works.

    Returns:
        The real part of the result, same shape as ``x``. The imaginary part is
        dropped, so ``w`` and ``x`` are presumably real-valued.
    """
    w_hat = fft(w)              # spectrum of the generating vector
    x_rows = x.T                # move the length-n axis last so the FFTs run along it
    product = ifft(x_rows) * w_hat
    return jnp.real(fft(product).T)

def _roll_by(col, ind):
    """Cyclically shift ``col`` by ``ind`` positions (jnp.roll with no axis: flatten, roll, reshape)."""
    return jnp.roll(col, ind)


# Batched roll: ``col`` is shared across the batch (in_axes None), one roll per entry
# of ``ind`` (in_axes 0); the batch axis is inserted at position 1 of the output, so a
# (1, n) ``col`` with n shifts yields shape (1, n, n) and ``[0][i] == roll(col, i)``.
circ_vmap = jax.vmap(_roll_by, in_axes=(None, 0), out_axes=1)

@jax.jit
def reg_mult(w, x):
    """Dense baseline for ``circ_mult``: materialize the shifted matrix, then matmul.

    Row ``i`` of the matrix is ``jnp.roll(w, i)``, i.e. ``C[i, j] = w[(j - i) % n]``.

    Args:
        w: array with ``n`` elements (flattened via ``reshape(1, -1)``).
        x: array whose first axis has length ``n`` (e.g. ``(n, N)``).

    Returns:
        ``C @ x``.
    """
    # Size the matrix from ``w`` itself instead of the notebook-global ``D_Z``:
    # under jit a global is baked in at trace time (no retrace if only D_Z changes),
    # and any mismatch between D_Z and len(w) would silently build the wrong matrix.
    n = w.size  # static Python int under jit
    C = circ_vmap(w.reshape(1, -1), jnp.arange(n))[0]
    return C @ x


In [63]:
D_Z = 500  # length of w / size of the dense matrix
N = 1      # number of vectors in x
# Warm-up inputs have the same shapes as the timed inputs so the timed call hits warm
# caches. (They also use the same PRNG keys, so their values are identical.)
w_warm = jax.random.normal(random.PRNGKey(0), (D_Z,))
x_warm = jax.random.normal(random.PRNGKey(1), (N, D_Z))
w = jax.random.normal(random.PRNGKey(0), (D_Z,))
x = jax.random.normal(random.PRNGKey(1), (N, D_Z))

C_warm = circ_vmap(w_warm.reshape(1, -1), jnp.arange(D_Z))[0]
for _rep in range(3):
    (C_warm @ x_warm.T).block_until_ready()

# JAX dispatches asynchronously: without block_until_ready() the timer would only
# measure how long it takes to enqueue the work, not the computation itself.
time_start = time.perf_counter()
C = circ_vmap(w.reshape(1, -1), jnp.arange(D_Z))[0]
jnp.matrix_transpose(C @ x.T).block_until_ready()  # to get x @ C'
time_end = time.perf_counter()
print(f"Time: {time_end - time_start:.4f}")
# Z = (N, D_Z)
# W = (D_Z, D_Z)
# WZ = (Z^T @ W^T)^T


Time: 0.0134


In [64]:
# circ_mult already returns a real array, so no extra jnp.real is wrapped around it
# (that would add an extra op to the timed region).
for _rep in range(3):
    circ_mult(w_warm, x_warm.T).block_until_ready()  # compile + warm caches
time_start = time.perf_counter()
# block_until_ready(): JAX dispatch is async, so wait for the result before stopping the clock.
circ_mult(w, x.T).block_until_ready()
time_end = time.perf_counter()
print(f"Time: {time_end - time_start:.4f}")

Time: 0.0002


In [72]:
# Jacobian of circ_mult w.r.t. its first argument w (jax.jacobian defaults to argnums=0).
# The transformed function is not wrapped in jax.jit, so per-call tracing overhead is
# part of the measured time.
circ_grad = jax.jacobian(circ_mult)
for _rep in range(3):
    circ_grad(w_warm, x_warm.T).block_until_ready()
time_start = time.perf_counter()
# Wait for the async result; otherwise only dispatch time is measured.
res = circ_grad(w, x.T).block_until_ready()
time_end = time.perf_counter()
print(f"Time: {time_end - time_start:.4f}. Res shape: {res.shape}")

Time: 0.0071. Res shape: (500, 1, 500)


In [70]:
# Jacobian of reg_mult w.r.t. its first argument w (jax.jacobian defaults to argnums=0).
# The transformed function is not wrapped in jax.jit, so per-call tracing overhead is
# part of the measured time.
mult_grad = jax.jacobian(reg_mult)
for _rep in range(3):
    mult_grad(w_warm, x_warm.T).block_until_ready()
time_start = time.perf_counter()
# Wait for the async result; otherwise only dispatch time is measured.
res = mult_grad(w, x.T).block_until_ready()
time_end = time.perf_counter()
print(f"Time: {time_end - time_start:.4f}. Res shape: {res.shape}")

Time: 0.5629. Res shape: (500, 1, 500)


In [49]:
# Dense-path output shape: (D_Z, N) for the current inputs.
# NOTE(review): the stored output (100, 100) was produced under a different D_Z/N.
jnp.shape(reg_mult(w, x.T))

(100, 100)

In [50]:
# Sanity check: the FFT path must agree with the dense matmul, otherwise the timing
# comparison is meaningless. Loose tolerance: float32 FFT vs a dense matmul that may
# run at reduced precision on some accelerators.
circ_out = circ_mult(w, x.T)
dense_out = reg_mult(w, x.T)
rel_err = float(jnp.linalg.norm(circ_out - dense_out) / jnp.linalg.norm(dense_out))
assert rel_err < 1e-2, f"FFT and dense results disagree (relative error {rel_err:.2e})"
circ_out.shape

(100, 100)

In [32]:
# Dense matrix built in the timing cell: (D_Z, D_Z).
jnp.shape(C)

(100, 100)