In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import jax
import jax.numpy as jnp
from time import time
import matplotlib.pyplot as plt
from apply_dilations_shifts import apply_both_delays_3d_cyclic, apply_dilations_shifts_3d_jax
from lbfgsb_loss_and_grad import loss

# apply_delay function without for loops or np.argmin

In [3]:
def apply_dilations_shifts_3d_no_argmin(
    S, dilations, shifts, max_dilation=1., max_shift=0., shift_before_dilation=True, n_concat=1,
):
    m, p, n_total = S.shape
    n = n_total // n_concat
    S_4d = jnp.moveaxis(jnp.array(jnp.split(S, n_concat, axis=-1)), source=0, destination=2)
    max_delay_time = (1 + max_shift) * max_dilation - 1
    max_delay_samples = np.ceil(max_delay_time * n).astype("int")
    t_extended = jnp.arange(n+2*max_delay_samples) - max_delay_samples
    t = jnp.arange(n)
    T = jnp.array([t] * m * p * n_concat).reshape(S_4d.shape)
    dilations_newaxis = dilations[:, :, jnp.newaxis, jnp.newaxis]
    shifts_newaxis = shifts[:, :, jnp.newaxis, jnp.newaxis] * n
    if shift_before_dilation:
        T_ds = (T - shifts_newaxis) * dilations_newaxis
    else:
        T_ds = T * dilations_newaxis - shifts_newaxis
    # ind = jnp.rint(T_ds).astype(int) + max_delay_samples
    # ind = jnp.clip(ind, 1, n+2*max_delay_samples-2)
    # T_ref = t_extended[ind]
    T_ref = jnp.clip(jnp.rint(T_ds).astype(int), -max_delay_samples+1, n+max_delay_samples-2)
    ind = jnp.clip(T_ref + max_delay_samples, 1, n+2*max_delay_samples-2)
    T_ds_clipped = jnp.clip(T_ds, -max_delay_samples+1, n+max_delay_samples-2)
    # jnp.copysign allows to avoid the case sign==0
    signs = jnp.copysign(1, jnp.sign(T_ds_clipped - T_ref)).astype(int)
    S_extended = jnp.concatenate([S_4d[:, :, :, n-max_delay_samples:], S_4d, S_4d[:, :, :, :max_delay_samples]], axis=-1)
    S_extended_ind = jnp.take_along_axis(S_extended, ind, axis=-1)
    t_extended_ind = t_extended[ind]
    slopes = (jnp.take_along_axis(S_extended, ind+signs, axis=-1) - S_extended_ind) / (
        t_extended[ind + signs] - t_extended_ind)
    intercepts = S_extended_ind - slopes * t_extended_ind
    S_ds = slopes * T_ds + intercepts
    S_ds = S_ds.reshape((m, p, -1))
    return S_ds

We use 3 different apply_delay functions:
- apply_both_delays_3d_cyclic which uses for loops
- apply_dilations_shifts_3d_jax which vectorizes the for loops but needs to call np.argmin
- apply_dilations_shifts_3d_no_argmin which vectorizes the for loops but doesn't need to call np.argmin

# parameters

In [4]:
m = 10
p = 5
n_concat = 3
n = 600
max_dilation = 1.15
max_shift = 0.05
shift_before_dilation = False
random_state = 12
rng = np.random.RandomState(random_state)
functions = [apply_both_delays_3d_cyclic, apply_dilations_shifts_3d_jax, apply_dilations_shifts_3d_no_argmin]
messages = ["For loops : ", "Vectorize : ", "Vectorize without np.argmin : "]

# apply_delay time without jit

In [5]:
def run_expe(
    function, m=m, p=p, n_concat=n_concat, n=n, max_dilation=max_dilation, max_shift=max_shift,
    shift_before_dilation=shift_before_dilation, rng=rng, message=None,
):
    S_list_3d = rng.randn(m, p, n_concat * n)
    dilations = rng.uniform(low=1/max_dilation, high=max_dilation, size=(m, p))
    shifts = rng.uniform(low=-max_shift, high=max_shift, size=(m, p))
    if message is not None:
        print(message)
    start = time()
    _ = function(S_list_3d, dilations, shifts, max_dilation, max_shift, shift_before_dilation, n_concat)
    return time() - start

In [6]:
# first call to functions
for i, function in enumerate(functions):
    print(run_expe(function, message=messages[i]))

For loops : 
1.2799110412597656
Vectorize : 
0.6003360748291016
Vectorize without np.argmin : 
1.024425983428955


In [7]:
# second call to functions (faster)
for i, function in enumerate(functions):
    print(run_expe(function, message=messages[i]))

For loops : 
0.3252229690551758
Vectorize : 
0.03591418266296387
Vectorize without np.argmin : 
0.03495621681213379


# jit

In [8]:
apply_both_delays_3d_cyclic_jit = jax.jit(apply_both_delays_3d_cyclic, static_argnums=(3, 4, 5, 6))
print(run_expe(apply_both_delays_3d_cyclic_jit, message="Jit..."))

Jit...
18.057190895080566


In [9]:
apply_dilations_shifts_3d_jax_jit = jax.jit(apply_dilations_shifts_3d_jax, static_argnums=(3, 4, 5, 6))
print(run_expe(apply_dilations_shifts_3d_jax_jit, message="Jit..."))

Jit...
0.2545020580291748


In [10]:
apply_dilations_shifts_3d_no_argmin_jit = jax.jit(apply_dilations_shifts_3d_no_argmin, static_argnums=(3, 4, 5, 6))
print(run_expe(apply_dilations_shifts_3d_no_argmin_jit, message="Jit..."))

Jit...
0.23122787475585938


# apply_delay time after jit

In [11]:
functions_jit = [apply_both_delays_3d_cyclic_jit, apply_dilations_shifts_3d_jax_jit, apply_dilations_shifts_3d_no_argmin_jit]
for i, function in enumerate(functions_jit):
    print(run_expe(function, message=messages[i]))

For loops : 
0.006395101547241211
Vectorize : 
0.0033071041107177734
Vectorize without np.argmin : 
0.002466917037963867


# time of the loss function

In [12]:
dilation_scale = rng.uniform(low=30, high=40, size=(m, p))
shift_scale = 60.
W_init = jnp.array(rng.randn(m, p, p))
A_init = jnp.ones((m, p)) * dilation_scale
B_init = jnp.zeros((m, p))
W_A_B_init = jnp.concatenate([jnp.ravel(W_init), jnp.ravel(A_init), jnp.ravel(B_init)])
X_list = rng.randn(m, p, n_concat * n)
noise_model = 1
number_of_filters_envelop = 1
filter_length_envelop = 10
number_of_filters_squarenorm_f = 0
filter_length_squarenorm_f = 3
use_envelop_term = True

args = (
    W_A_B_init,
    X_list,
    dilation_scale,
    shift_scale,
    max_shift,
    max_dilation,
    noise_model,
    number_of_filters_envelop,
    filter_length_envelop,
    number_of_filters_squarenorm_f,
    filter_length_squarenorm_f,
    use_envelop_term,
    n_concat,
)

In [13]:
def loss_1(*args):
    return loss(apply_delays_function=apply_both_delays_3d_cyclic, *args)

def loss_2(*args):
    return loss(apply_delays_function=apply_dilations_shifts_3d_jax, *args)

def loss_3(*args):
    return loss(apply_delays_function=apply_dilations_shifts_3d_no_argmin, *args)

In [14]:
print(loss_1(*args))
print(loss_2(*args))
print(loss_3(*args))

60.265755
60.265755
95.05128


The loss is different for apply_dilations_shifts_3d_no_argmin. Indeed, when we plot sources delayed by apply_dilations_shifts_3d_jax (or apply_both_delays_3d_cyclic) and apply_dilations_shifts_3d_no_argmin, we can see very little differences. But this is not important. 

In [15]:
val_and_grad_1 = jax.jit(jax.value_and_grad(loss_1), static_argnums=tuple(np.arange(3, 13)))
val_and_grad_2 = jax.jit(jax.value_and_grad(loss_2), static_argnums=tuple(np.arange(3, 13)))
val_and_grad_3 = jax.jit(jax.value_and_grad(loss_3), static_argnums=tuple(np.arange(3, 13)))

In [16]:
print("Jit...")
_ = val_and_grad_1(*args)

Jit...


In [17]:
print("Jit...")
_ = val_and_grad_2(*args)

Jit...


In [18]:
print("Jit...")
_ = val_and_grad_3(*args)

Jit...


In [39]:
print("For loops : ")
start = time()
_ = val_and_grad_1(*args)
print(time() - start)

print("Vectorize : ")
start = time()
_ = val_and_grad_2(*args)
print(time() - start)

print("Vectorize without np.argmin : ")
start = time()
_ = val_and_grad_3(*args)
print(time() - start)

For loops : 
0.02010512351989746
Vectorize : 
0.07315492630004883
Vectorize without np.argmin : 
0.023099899291992188


In [40]:
n_it = 100
times_1 = []
times_2 = []
times_3 = []
for i in range(n_it):
    start = time()
    _ = val_and_grad_1(*args)
    times_1.append(time() - start)
    
    start = time()
    _ = val_and_grad_2(*args)
    times_2.append(time() - start)
    
    start = time()
    _ = val_and_grad_3(*args)
    times_3.append(time() - start)

In [41]:
print(f"Time for loops : {np.mean(times_1)}")
print(f"Time vectorize : {np.mean(times_2)}")
print(f"Time vectorize without np.argmin : {np.mean(times_3)}")

Time for loops : 0.018464410305023195
Time vectorize : 0.07389827489852906
Time vectorize without np.argmin : 0.024706676006317138
