In [1]:
from time import time

import jax
import jax.numpy as jnp
import numpy as np

# Enable float64 arrays in JAX (the default is float32). This should run at
# startup, before any JAX array is created. `jax.config.update` replaces the
# deprecated `from jax.config import config` submodule import.
jax.config.update("jax_enable_x64", True)

In [2]:
# Problem size: m views, p components per view, n samples per component.
m, p, n = 5, 2, 600
noise = 1.
shared_delays = False

rng = jax.random.PRNGKey(0)
X_list = jax.random.normal(rng, (m, p, n))
# NOTE(review): `rng` has already been consumed by `normal` above before being
# split here; JAX recommends splitting before use. Left as-is so the recorded
# outputs of this notebook stay reproducible.
rng = jax.random.split(rng)[0]
W_list = jax.random.normal(rng, (m * p ** 2,))



In [3]:
def _logcosh(X):
    """Elementwise, overflow-safe ``log(cosh(X)) + log(2)``.

    Uses the identity ``log(2 * cosh(x)) = |x| + log1p(exp(-2 * |x|))``, which
    never exponentiates a positive number. Note the result is ``log(cosh(X))``
    shifted by the constant ``log(2)``; the shift does not affect gradients.
    """
    Y = jnp.abs(X)
    return Y + jnp.log1p(jnp.exp(-2 * Y))


def loss_function(W_list, X_list, noise):
    """Scalar objective for jointly unmixing ``m`` views.

    Parameters
    ----------
    W_list : array of shape (k,), with k >= m * p**2
        Flat parameter vector. Only the first ``m * p**2`` entries are used,
        reshaped into ``m`` matrices of shape ``(p, p)``; trailing entries are
        ignored.
    X_list : array of shape (m, p, n)
        One ``(p, n)`` data matrix per view.
    noise : float
        Weight of the disagreement term, which enters as ``1 / (2 * noise)``
        (presumably a noise variance -- TODO confirm).

    Returns
    -------
    Scalar array equal to::

        p * mean(_logcosh(Y_avg))
        - sum_i log|det W_i|
        + p / (2 * noise) * sum_i mean((Y_i - Y_avg) ** 2)

    where ``Y_i = W_i @ X_i`` and ``Y_avg`` is the mean of the ``Y_i``.
    """
    m, p, n = X_list.shape
    W_mats = W_list[:m * p ** 2].reshape((m, p, p))
    # Batched matmul instead of a Python loop: the loop would be unrolled m
    # times when traced under jax.jit. Y_list[i] = W_mats[i] @ X_list[i].
    Y_list = W_mats @ X_list
    Y_avg = jnp.mean(Y_list, axis=0)
    loss = jnp.mean(_logcosh(Y_avg)) * p
    # slogdet broadcasts over the leading (m,) axis; [1] is log|det W_i|.
    loss -= jnp.sum(jnp.linalg.slogdet(W_mats)[1])
    # Per-view mean squared deviation from the average, summed over views.
    per_view_mse = jnp.mean((Y_list - Y_avg) ** 2, axis=(1, 2))
    loss += p / (2 * noise) * jnp.sum(per_view_mse)
    return loss

In [4]:
# Objective at the random initial parameters (shown as the cell output).
loss_function(W_list=W_list, X_list=X_list, noise=noise)

DeviceArray(11.76744027, dtype=float64)

In [5]:
# Jitted (loss, d loss / d W_list); argnums=0 differentiates w.r.t. W_list only.
val_and_grad = jax.jit(jax.value_and_grad(loss_function, argnums=0))

In [6]:
# First call: slow because it includes JIT compilation.
t0 = time()
loss_value, grad_value = val_and_grad(W_list, X_list, noise)
# JAX dispatches work asynchronously: block on both outputs so the timing
# covers the computation itself, not only its dispatch.
loss_value.block_until_ready()
grad_value.block_until_ready()
print(f"time : {time() - t0}")

time : 2.534545421600342


In [7]:
# Second call: fast thanks to jit (the compiled function is reused).
t0 = time()
loss_value, grad_value = val_and_grad(W_list, X_list, noise)
# Wait for the asynchronously dispatched results before stopping the clock.
loss_value.block_until_ready()
grad_value.block_until_ready()
print(f"time : {time() - t0}")

time : 0.001519918441772461


In [8]:
def wrapper(W_list, X_list, noise):
    """Adapt the jitted ``val_and_grad`` to SciPy's ``fmin_l_bfgs_b`` API.

    Returns
    -------
    loss : float
        Objective value as a plain Python float.
    grad : numpy.ndarray of float64
        A writable float64 copy of the gradient. The explicit dtype keeps the
        SciPy side in double precision even when JAX runs in 32-bit mode.
    """
    loss, grad = val_and_grad(W_list, X_list, noise)
    # np.array (not np.asarray) copies: JAX buffers are exposed read-only.
    return float(loss), np.array(grad, dtype=np.float64)

In [9]:
from scipy.optimize import fmin_l_bfgs_b

In [10]:
%timeit fmin_l_bfgs_b(wrapper, W_list, args=(X_list, noise))

25.3 ms ± 587 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Add delays

In [11]:
# Random delays, one per (view, component), drawn uniformly in
# [-max_delay, max_delay).
max_delay = 20
rng = jax.random.split(rng)[0]
delays = jax.random.uniform(rng, (m, p), minval=-max_delay, maxval=max_delay)

In [12]:
def apply_continuous_delays(
    S_list,
    tau_list,
    shared_delays=False,
):
    """Circularly delay every signal by a (possibly fractional) amount.

    Each row ``S_list[i, j]`` is delayed by multiplying its discrete Fourier
    transform by ``exp(-2 * pi * 1j * delay * f)`` with ``f = fftfreq(n)``.
    For an integer delay ``d`` this equals ``np.roll(S_list[i, j], d)``.

    Parameters
    ----------
    S_list : array of shape (m, p, n)
        Real-valued signals: ``p`` rows for each of ``m`` views.
    tau_list : array of shape (m, p), or (m,) when ``shared_delays`` is True
        Delays in samples. With ``shared_delays`` all rows of view ``i`` use
        ``tau_list[i]``; otherwise row ``j`` of view ``i`` uses
        ``tau_list[i, j]``.
    shared_delays : bool
        Must be a Python bool (it selects the indexing scheme at trace time).

    Returns
    -------
    Array of shape (m, p, n): the real part of the delayed signals.
    """
    n = S_list.shape[-1]
    # Loop-invariant: every row has the same length, hence the same frequencies.
    freqs = jnp.fft.fftfreq(n)
    tau = jnp.asarray(tau_list)
    if shared_delays:
        # (m,) -> (m, 1): broadcast one delay over all rows of a view.
        tau = tau[:, None]
    # One batched FFT along the time axis instead of m * p separate FFTs and
    # scatter updates, which would all be unrolled under jax.jit.
    phase = jnp.exp(-2 * jnp.pi * 1j * tau[..., None] * freqs)
    delayed = jnp.fft.ifft(jnp.fft.fft(S_list, axis=-1) * phase, axis=-1)
    return jnp.real(delayed)

# Forward model with the drawn delays. At this point W_list is the flat
# parameter vector of shape (m * p**2,), so it must be reshaped into m (p, p)
# matrices first. Zipping the flat vector with X_list would pair each view
# with a single scalar weight and yield scaled copies of X instead of W @ X.
S_list = jnp.matmul(W_list.reshape((m, p, p)), X_list)
Y_list = apply_continuous_delays(S_list, -delays)

In [13]:
def loss_function_delays(W_delays, X_list, noise):
    """Scalar objective for jointly unmixing and re-delaying ``m`` views.

    Parameters
    ----------
    W_delays : array of shape (m * p**2 + m * p,)
        Flat parameter vector. The first ``m * p**2`` entries are reshaped into
        ``m`` matrices of shape ``(p, p)``; the remaining ``m * p`` entries are
        reshaped into an ``(m, p)`` array of delays (in samples).
    X_list : array of shape (m, p, n)
        One ``(p, n)`` data matrix per view.
    noise : float
        Weight of the disagreement term, which enters as ``1 / (2 * noise)``
        (presumably a noise variance -- TODO confirm).

    Returns
    -------
    Scalar array equal to::

        p * mean(_logcosh(Y_avg))
        - sum_i log|det W_i|
        + p / (2 * noise) * sum_i mean((Y_i - Y_avg) ** 2)

    where ``Y_i`` is ``W_i @ X_i`` with every row circularly shifted by the
    negated delay, and ``Y_avg`` is the mean of the ``Y_i``.
    """
    m, p, n = X_list.shape
    n_weights = m * p ** 2
    W_mats = W_delays[:n_weights].reshape((m, p, p))
    delays = W_delays[n_weights:].reshape((m, p))
    # Batched matmul instead of a Python loop (avoids m-fold unrolling under jit).
    S_list = W_mats @ X_list
    # Apply the negated delays to each row of each view.
    Y_list = apply_continuous_delays(S_list, -delays)
    Y_avg = jnp.mean(Y_list, axis=0)
    loss = jnp.mean(_logcosh(Y_avg)) * p
    # slogdet broadcasts over the leading (m,) axis; [1] is log|det W_i|.
    loss -= jnp.sum(jnp.linalg.slogdet(W_mats)[1])
    # Per-view mean squared deviation from the average, summed over views.
    per_view_mse = jnp.mean((Y_list - Y_avg) ** 2, axis=(1, 2))
    loss += p / (2 * noise) * jnp.sum(per_view_mse)
    return loss

In [14]:
# One flat parameter vector: m * p**2 matrix weights followed by m * p delays.
W_delays = jnp.concatenate((W_list.ravel(), delays.ravel()))
loss_function_delays(W_delays, X_list, noise)

DeviceArray(11.84313674, dtype=float64)

In [15]:
# Jitted (loss, d loss / d W_delays); gradient covers both weights and delays.
val_and_grad_delays = jax.jit(jax.value_and_grad(loss_function_delays, argnums=0))

In [16]:
# First call: slow because it includes JIT compilation.
t0 = time()
loss_value, grad_value = val_and_grad_delays(W_delays, X_list, noise)
# JAX dispatches work asynchronously: block on both outputs so the timing
# covers the computation itself, not only its dispatch.
loss_value.block_until_ready()
grad_value.block_until_ready()
print(f"time : {time() - t0}")

time : 4.427131175994873


In [17]:
# Second call: fast thanks to jit. This must time `val_and_grad_delays`, the
# function compiled in the previous cell, not the delay-free `val_and_grad`.
t0 = time()
loss_value, grad_value = val_and_grad_delays(W_delays, X_list, noise)
# Wait for the asynchronously dispatched results before stopping the clock.
loss_value.block_until_ready()
grad_value.block_until_ready()
print(f"time : {time() - t0}")

time : 0.0012564659118652344


In [18]:
def wrapper_delays(W_delays, X_list, noise):
    """Adapt the jitted ``val_and_grad_delays`` to SciPy's ``fmin_l_bfgs_b`` API.

    Returns
    -------
    loss : float
        Objective value as a plain Python float.
    grad : numpy.ndarray of float64
        A writable float64 copy of the gradient w.r.t. weights and delays.
        The explicit dtype keeps the SciPy side in double precision even when
        JAX runs in 32-bit mode.
    """
    loss, grad = val_and_grad_delays(W_delays, X_list, noise)
    # np.array (not np.asarray) copies: JAX buffers are exposed read-only.
    return float(loss), np.array(grad, dtype=np.float64)

In [19]:
%timeit fmin_l_bfgs_b(wrapper_delays, W_delays, args=(X_list, noise))

175 ms ± 2.61 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [20]:
from multiviewica_delay import multiviewica_delay

In [26]:
%timeit multiviewica_delay(X=X_list, noise=noise, init=np.array(W_list).reshape(m, p, p), max_delay=max_delay)

158 ms ± 34.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
