In [1]:
import jax
import jax.numpy as jnp
import jaxkdtree

In [2]:
# Random test data: a batch of 12 independent point clouds, each holding
# 1000 points in 3-D, drawn from a standard normal with a fixed PRNG key
# so every run of the notebook sees the same coordinates.
pos = jax.random.normal(key=jax.random.PRNGKey(0), shape=(12, 1000, 3))

## Batching

Check that `jax.vmap` over a batch of point clouds returns the same neighbours as calling `kNN` on each cloud separately.

In [3]:
# Reference results: k=8 neighbour search run separately (no vmap) on the
# first two point clouds of the batch.
res_single_0, res_single_1 = (
    jaxkdtree.kNN(pos[cloud_idx], k=8, max_radius=100.0) for cloud_idx in (0, 1)
)

In [4]:
# The same k=8 search, vmapped over the leading (batch) axis of `pos`;
# k and max_radius are broadcast unchanged to every cloud (in_axes=None).
res_batch = jax.vmap(
    jaxkdtree.kNN,
    in_axes=(0, None, None),
)(pos, 8, 100.0)

In [5]:
# Check that batched and unbatched searches give the same neighbour indices.
# The results are integer indices (they are compared against argsort output
# further down), so compare them exactly instead of with a float tolerance.
jnp.array_equal(res_batch[0], res_single_0), jnp.array_equal(res_batch[1], res_single_1)

(Array(True, dtype=bool), Array(True, dtype=bool))

## JIT

Check that the batched search can be compiled with `jax.jit`, with `k` and `max_radius` passed as static arguments.

In [6]:
# Make sure jit works: compile the batched search with k and max_radius passed
# as static arguments (positions 1 and 2), then check that it reproduces the
# eager batched result from above. Comparing the values also forces the
# computation to complete, so runtime failures show up in this cell.
res_jit = jax.jit(
    jax.vmap(jaxkdtree.kNN, in_axes=(0, None, None)),
    static_argnums=(1, 2),
)(pos, 8, 100.0)
assert jnp.array_equal(res_jit, res_batch), "jitted and eager batched kNN disagree"

## Compare to pairwise-distance calculation

Make sure the kd-tree returns the same neighbour indices as a brute-force search over the full matrix of pairwise (squared) distances.

In [7]:
def pairwise_distances(point_cloud):
    """Compute squared Euclidean distances between all pairs of points.

    Args:
        point_cloud: array of shape (..., N, D) holding N points in D dimensions.
            Any leading dimensions are treated as batch dimensions, so a plain
            (N, D) cloud works as well as a batch of clouds.

    Returns:
        Array of shape (..., N, N) whose [..., i, j] entry is ||p_i - p_j||^2.
        No square root is taken: squared distances rank neighbours identically
        (sqrt is monotonic) and are cheaper to compute.
    """
    # Broadcast (..., N, 1, D) - (..., 1, N, D) -> (..., N, N, D) displacement vectors.
    dr = point_cloud[..., :, None, :] - point_cloud[..., None, :, :]
    return jnp.sum(dr**2, axis=-1)

# Brute-force reference: full squared-distance matrix for every cloud, then the
# indices of the 8 closest points per row. argsort is ascending, so column 0 is
# the query point itself (distance 0); a match below therefore implies the
# kd-tree result also lists each point as its own nearest neighbour.
distance_matrices = jax.vmap(pairwise_distances)(pos)
dist_results_indices = jnp.argsort(distance_matrices, axis=-1)[..., :8]

# Neighbour indices are integers: compare exactly.
jnp.array_equal(res_batch, dist_results_indices)

Array(True, dtype=bool)

## Different $k$

In [8]:
# Finding k=50 nearest neighbors, batched. Fresh `_k50` names are used so the
# k=8 results from the cells above are not overwritten; otherwise re-running
# those cells after this one would compare arrays of mismatched shapes.
res_batch_k50 = jax.vmap(jaxkdtree.kNN, in_axes=(0, None, None))(pos, 50, 100.0)

# Using the pairwise-distance matrix computed above
dist_results_indices_k50 = jnp.argsort(distance_matrices, axis=-1)[..., :50]

# Neighbour indices are integers: compare exactly.
jnp.array_equal(res_batch_k50, dist_results_indices_k50)

Array(True, dtype=bool)