In [1]:
import io
import pickle
from functools import partial

import jax
import jax.numpy as jnp
import pandas as pd
import yaml
from jax import vmap, jit

In [2]:
def squared_num(a):
    """Return ``a`` raised to the power of two.

    Accepts anything that supports the ``**`` operator with an integer
    exponent (plain Python numbers as well as JAX arrays).
    """
    return pow(a, 2)


def square_and_select(ind, array):
    """Look up ``array[ind]`` and return the square of that entry.

    Args:
        ind: Index (or anything ``array`` accepts inside ``[]``).
        array: Indexable container holding the values.

    Returns:
        ``array[ind] ** 2``.
    """
    picked = array[ind]
    return pow(picked, 2)


def ex_ante_select(indices, array_to_select):
    """Gather the requested entries first, then square them element-wise.

    Args:
        indices: Integer positions to read from ``array_to_select``.
        array_to_select: Array holding the candidate values. ``jnp.take`` is
            called without an ``axis`` argument.

    Returns:
        The result of mapping ``squared_num`` over the leading axis of the
        gathered values.
    """
    gathered = jnp.take(array_to_select, indices)
    square_each = vmap(squared_num, in_axes=0)
    return square_each(gathered)


def select_on_the_go(indices, array_to_select):
    """Index and square inside the vectorised function, one index at a time.

    Args:
        indices: Positions to look up; mapped over their leading axis.
        array_to_select: Array holding the candidate values; handed unchanged
            to every mapped call.

    Returns:
        The stacked results of ``square_and_select`` for each index.
    """
    # in_axes=(0, None): vectorise over `indices` only, do not map the array.
    lookup_and_square = vmap(square_and_select, in_axes=(0, None))
    return lookup_and_square(indices, array_to_select)


def select_dict_with_loop(input_array):
    """Map every row index of ``input_array`` to the sum of that row.

    Args:
        input_array: Two-dimensional array; rows are addressed as
            ``input_array[i, :]``.

    Returns:
        A dict whose keys are the integer row indices ``0 .. shape[0] - 1``
        (in that order) and whose values are the corresponding row sums.
    """
    n_rows = input_array.shape[0]
    return {row: jnp.sum(input_array[row, :]) for row in range(n_rows)}


def append_array(input_array):
    """Collect the sum of every row of ``input_array`` in a 1-D array.

    The result starts as an empty array and is grown by one element per row
    with ``jnp.append``.

    Args:
        input_array: Two-dimensional array; rows are addressed as
            ``input_array[i, :]``.

    Returns:
        One-dimensional array with ``input_array.shape[0]`` entries, the i-th
        being the sum of row ``i``.
    """
    row_sums = jnp.array([])
    n_rows = input_array.shape[0]
    for row in range(n_rows):
        row_total = jnp.sum(input_array[row, :])
        row_sums = jnp.append(row_sums, row_total)
    return row_sums

In [3]:
# Inputs for the first benchmark: a two-element lookup table and
# 2 * size_ind indices into it (size_ind zeros followed by size_ind ones).
size_ind = 100_000_000
to_select_test = jnp.array([2, 4])
# zeros/ones are created with jnp's default dtype and cast to integer indices
# after being joined.
indices_test = jnp.append(jnp.zeros(size_ind), jnp.ones(size_ind)).astype(int)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [4]:
# jit-wrapped versions of the two selection strategies that are timed below.
jit_exante, jit_expost = (jit(fn) for fn in (ex_ante_select, select_on_the_go))

In [5]:
jit_exante(indices_test, to_select_test)
%timeit jit_exante(indices_test, to_select_test)

224 ms ± 10.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
jit_expost(indices_test, to_select_test)
%timeit jit_expost(indices_test, to_select_test)

221 ms ± 14.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
# Inputs for the second benchmark: a (size_loop, size_sum) array whose rows
# are summed one at a time by the two functions compared below.
size_loop, size_sum = 50, 10_000
multi_dim = jnp.empty(shape=(size_loop, size_sum))

# jit-wrapped versions of the dict-building and the append-based variants.
jit_dict_select, jit_append = jit(select_dict_with_loop), jit(append_array)

In [8]:
jit_dict_select(multi_dim)
%timeit jit_dict_select(multi_dim)

275 µs ± 4.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [9]:
jit_append(multi_dim)
%timeit jit_append(multi_dim)

215 µs ± 7.21 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [14]:
# Row lookup that includes an index outside the size_loop (50) rows: -99 is
# out of range even when counted from the end. No error is raised; the
# recorded output shows that row filled with NaN.
jnp.take(multi_dim, jnp.array([10, 20, -99]), axis=0)

Array([[ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [nan, nan, nan, ..., nan, nan, nan]], dtype=float32)