In [1]:
import jax
import jax.numpy as jnp
import jaxkdtree

In [2]:
# Generate a reproducible batch of random point clouds: 32 clouds of
# 1000 points each, every point having 3 standard-normal coordinates.
POS_SEED = 0
POS_SHAPE = (32, 1000, 3)  # (batch, points per cloud, coordinates per point)
pos = jax.random.normal(jax.random.PRNGKey(POS_SEED), POS_SHAPE)

In [3]:
# Unbatched reference: run kNN (k=8, max_radius=1.0 — presumably a search
# radius cutoff, see the jaxkdtree docs) separately on the first two clouds.
res_single_0, res_single_1 = (
    jaxkdtree.kNN(cloud, k=8, max_radius=1.0) for cloud in (pos[0], pos[1])
)

In [4]:
# Batched kNN: vmap maps over axis 0 of `pos` (one cloud per batch entry)
# while k=8 and max_radius=1.0 are passed through unmapped (in_axes=None).
batched_knn = jax.vmap(jaxkdtree.kNN, in_axes=(0, None, None))
res_batch = batched_knn(pos, 8, 1.0)

In [5]:
# Check that batched and unbatched calls give the same neighbour indices.
# kNN returns integer indices, so compare exactly: jnp.allclose's relative
# tolerance (atol + rtol*|b|) reaches 1.0 once indices grow to ~1e5, at which
# point off-by-one neighbours would wrongly compare as "close".
batch_matches_single = (
    jnp.array_equal(res_batch[0], res_single_0),
    jnp.array_equal(res_batch[1], res_single_1),
)
# Fail loudly on a fresh Run-All instead of just displaying False.
assert all(batch_matches_single), "batched kNN disagrees with unbatched kNN"
batch_matches_single

(Array(True, dtype=bool), Array(True, dtype=bool))

In [6]:
# Make sure jit works AND agrees with the eager batched result.
# k and max_radius (positional args 1 and 2) are marked static so they stay
# concrete Python values inside the traced function.
jitted_batched_knn = jax.jit(
    jax.vmap(jaxkdtree.kNN, in_axes=(0, None, None)),
    static_argnums=(1, 2),
)
res_batch_jit = jitted_batched_knn(pos, 8, 1.0)
# Integer neighbour indices, so require exact equality with the eager result.
assert jnp.array_equal(res_batch_jit, res_batch), "jitted kNN disagrees with eager batched kNN"
res_batch_jit

Array([[[  0, 396,  34, ..., 791, 142, 571],
        [  1, 680, 569, ..., 284, 134, 325],
        [  2, 899, 789, ..., 927, 394, 393],
        ...,
        [997, 478,   6, ..., 506, 498, 998],
        [998, 498, 123, ..., 505,  61, 501],
        [999, 894, 222, ..., 444,  61, 248]],

       [[  0, 850, 313, ..., 211, 847, 416],
        [  1, 160, 562, ..., 674, 644, 321],
        [  2, 798, 453, ..., 794, 198, 826],
        ...,
        [997, 498, 942, ..., 248, 125, 503],
        [998, 498, 505, ..., 996, 123, 254],
        [999,  61, 499, ..., 889, 124,  54]],

       [[  0, 793, 447, ..., 341,  77, 574],
        [  1, 739,  84, ..., 904, 451,  77],
        [  2, 829, 817, ..., 115, 102, 825],
        ...,
        [997, 248, 498, ..., 893, 505,  61],
        [998,  30, 509, ..., 498, 503, 253],
        [999, 994, 890, ..., 992, 124, 973]],

       ...,

       [[  0, 961, 480, ..., 479, 370, 377],
        [  1, 270, 542, ..., 679, 569, 680],
        [  2, 818, 408, ..., 815, 932, 928