In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from time import time
from apply_dilations_shifts import apply_both_delays_3d_cyclic, apply_dilations_shifts_3d_jax

# apply_delay function without for loops or np.argmin

In [2]:
def apply_dilations_shifts_3d_no_argmin(
    S, dilations, shifts, max_dilation=1., max_shift=0., shift_before_dilation=True, n_concat=1,
):
    m, p, n_total = S.shape
    n = n_total // n_concat
    S_4d = jnp.moveaxis(jnp.array(jnp.split(S, n_concat, axis=-1)), source=0, destination=2)
    max_delay_time = (1 + max_shift) * max_dilation - 1
    max_delay_samples = np.ceil(max_delay_time * n).astype("int")
    t_extended = jnp.arange(n+2*max_delay_samples) - max_delay_samples
    t = jnp.arange(n)
    T = jnp.array([t] * m * p * n_concat).reshape(S_4d.shape)
    dilations_newaxis = dilations[:, :, jnp.newaxis, jnp.newaxis]
    shifts_newaxis = shifts[:, :, jnp.newaxis, jnp.newaxis] * n
    if shift_before_dilation:
        T_ds = (T - shifts_newaxis) * dilations_newaxis
    else:
        T_ds = T * dilations_newaxis - shifts_newaxis
    # ind = jnp.rint(T_ds).astype(int) + max_delay_samples
    # ind = jnp.clip(ind, 1, n+2*max_delay_samples-2)
    # T_ref = t_extended[ind]
    T_ref = jnp.clip(jnp.rint(T_ds).astype(int), -max_delay_samples+1, n+max_delay_samples-2)
    ind = jnp.clip(T_ref + max_delay_samples, 1, n+2*max_delay_samples-2)
    T_ds_clipped = jnp.clip(T_ds, -max_delay_samples+1, n+max_delay_samples-2)
    signs = jnp.sign(T_ds_clipped - T_ref).astype(int)
    S_extended = jnp.concatenate([S_4d[:, :, :, n-max_delay_samples:], S_4d, S_4d[:, :, :, :max_delay_samples]], axis=-1)
    S_extended_ind = jnp.take_along_axis(S_extended, ind, axis=-1)
    t_extended_ind = t_extended[ind]
    slopes = (jnp.take_along_axis(S_extended, ind+signs, axis=-1) - S_extended_ind) / (
        t_extended[ind + signs] - t_extended_ind)
    intercepts = S_extended_ind - slopes * t_extended_ind
    return slopes * T_ds + intercepts

We use 3 different apply_delay functions:
- apply_both_delays_3d_cyclic which uses for loops
- apply_dilations_shifts_3d_jax which vectorizes the for loops but needs to call np.argmin
- apply_dilations_shifts_3d_no_argmin which vectorizes the for loops but doesn't need to call np.argmin

# parameters

In [3]:
m = 10
p = 5
n_concat = 3
n = 600
max_dilation = 1.15
max_shift = 0.05
shift_before_dilation = False
random_state = 12
rng = np.random.RandomState(random_state)
functions = [apply_both_delays_3d_cyclic, apply_dilations_shifts_3d_jax, apply_dilations_shifts_3d_no_argmin]
messages = ["For loops : ", "Vectorize : ", "Vectorize without np.argmin : "]

# apply_delay time without jit

In [4]:
def run_expe(
    function, m=m, p=p, n_concat=n_concat, n=n, max_dilation=max_dilation, max_shift=max_shift,
    shift_before_dilation=shift_before_dilation, rng=rng, message=None,
):
    S_list_3d = rng.randn(m, p, n_concat * n)
    dilations = rng.uniform(low=1/max_dilation, high=max_dilation, size=(m, p))
    shifts = rng.uniform(low=-max_shift, high=max_shift, size=(m, p))
    if message is not None:
        print(message)
    start = time()
    _ = function(S_list_3d, dilations, shifts, max_dilation, max_shift, shift_before_dilation, n_concat)
    return time() - start

In [5]:
# first call to functions
for i, function in enumerate(functions):
    print(run_expe(function, message=messages[i]))

For loops : 
0.8774631023406982
Vectorize : 
0.6061549186706543
Vectorize without np.argmin : 
1.1288089752197266


In [6]:
# second call to functions (faster)
for i, function in enumerate(functions):
    print(run_expe(function, message=messages[i]))

For loops : 
0.2554759979248047
Vectorize : 
0.030426025390625
Vectorize without np.argmin : 
0.03172707557678223


# jit

In [7]:
apply_both_delays_3d_cyclic_jit = jax.jit(apply_both_delays_3d_cyclic, static_argnums=(3, 4, 5, 6))
print(run_expe(apply_both_delays_3d_cyclic_jit, message="Jit..."))

Jit...
19.75782084465027


In [8]:
apply_dilations_shifts_3d_jax_jit = jax.jit(apply_dilations_shifts_3d_jax, static_argnums=(3, 4, 5, 6))
print(run_expe(apply_dilations_shifts_3d_jax_jit, message="Jit..."))

Jit...
0.3099629878997803


In [9]:
apply_dilations_shifts_3d_no_argmin_jit = jax.jit(apply_dilations_shifts_3d_no_argmin, static_argnums=(3, 4, 5, 6))
print(run_expe(apply_dilations_shifts_3d_no_argmin_jit, message="Jit..."))

Jit...
0.2597219944000244


# apply_delay time after jit

In [10]:
functions_jit = [apply_both_delays_3d_cyclic_jit, apply_dilations_shifts_3d_jax_jit, apply_dilations_shifts_3d_no_argmin_jit]
for i, function in enumerate(functions_jit):
    print(run_expe(function, message=messages[i]))

For loops : 
0.007430076599121094
Vectorize : 
0.004225015640258789
Vectorize without np.argmin : 
0.003993034362792969
