In [1]:
import jax
import jax.numpy as jnp
from jax import jit
from jax import random

In [2]:
def f(x, size=None, fill_value=0):
    """Return the strictly negative elements of ``x`` as a 1-D array.

    Args:
      x: array to filter. Boolean-mask indexing flattens it to 1-D.
      size: optional output length. When ``None`` (the default) the result has
        exactly as many elements as ``x`` has negatives. That length depends on
        the data, so ``jax.jit`` cannot trace it and raises
        ``NonConcreteBooleanIndexError``. When an int is given, the result
        always has ``size`` elements, which makes ``f`` usable under ``jit``
        provided ``size`` is marked static, e.g.
        ``jit(f, static_argnames='size')``.
      fill_value: value used for the trailing slots when ``x`` has fewer than
        ``size`` negatives. Ignored when ``size`` is ``None``.

    Returns:
      1-D array of the negative elements in row-major order. If ``size`` is
      smaller than the number of negatives, only the first ``size`` are kept.
    """
    mask = x < 0
    if size is None:
        return x[mask]
    # With a static size, nonzero pads its *indices* with 0, so the padded
    # slots would hold a copy of x's first element; overwrite them explicitly.
    picked = x[jnp.nonzero(mask, size=size)]
    num_valid = jnp.count_nonzero(mask)
    return jnp.where(jnp.arange(size) < num_valid, picked, fill_value)

In [11]:
key = random.PRNGKey(0)
samples = random.normal(key, shape=(1000,), dtype=jnp.float32)
jit_f = jit(f)
# Boolean-mask indexing gives an output whose length depends on the data, so
# tracing f under jit fails. That failure is what this cell demonstrates: catch
# that specific error and show it, so "Restart & Run All" continues past it.
try:
    jit_f(samples)
except jax.errors.NonConcreteBooleanIndexError as err:
    print(f"{type(err).__name__}: {err}")

NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[1000])

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError

In [14]:
from functools import partial
import jax.numpy as jnp
from jax import random, jit

@partial(jit, static_argnames=('num_objects', 'num_dims'))
def multi_object_model(T: int, num_objects: int, num_side_bins: int,
                       num_dims: int = 2, seed: int = 0) -> jnp.ndarray:
  """Sample uniform random positions for ``num_objects`` objects.

  Args:
    T: currently unused by the body; a traced (non-static) argument under jit.
    num_objects: number of objects (columns of the result). Static under jit
      because it fixes the output shape; each new value triggers a recompile.
    num_side_bins: currently unused by the body; traced under jit.
    num_dims: coordinates per object (rows of the result). Static for the same
      reason as ``num_objects``. Defaults to 2, the previously hard-coded value.
    seed: integer seed for ``random.PRNGKey``. Defaults to 0, so calls that
      omit it reproduce the previous fixed output.

  Returns:
    Array of shape ``(num_dims, num_objects)`` with values drawn from
    ``random.uniform``'s default ``[0, 1)`` range.
  """
  key = random.PRNGKey(seed)
  # Split before sampling (as the original did) so seed=0 reproduces the
  # original draws; the carried half of the split is not needed here.
  _, subkey = random.split(key)
  return random.uniform(subkey, shape=(num_dims, num_objects))

In [29]:
%%time
multi_object_model(3,3,2)

CPU times: user 816 µs, sys: 733 µs, total: 1.55 ms
Wall time: 1.12 ms


DeviceArray([[0.07239354, 0.02032685, 0.07718182],
             [0.87867916, 0.16457272, 0.1017133 ]], dtype=float32)