In [1]:
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

In [2]:
@partial(jax.jit, static_argnames=("num_edges",))
def naive(pos, n_node, cutoff: float, num_edges: int, active=None):
    """Build the per-graph radius graph of a batch from a dense distance matrix.

    Args:
        pos: node positions of all graphs stacked along the first axis
            (``pos.shape[0]`` is the total node count; the norm is taken over
            the last axis).
        n_node: number of nodes of each graph, in stacking order.
        cutoff: two nodes of the same graph are connected when
            ``0 < dist <= cutoff``. The strict lower bound is what removes
            self-pairs; it also excludes distinct nodes at zero distance.
        num_edges: static length of the returned index arrays.
        active: optional per-graph flags. A graph is kept when its flag is
            non-zero / ``True``; ``None`` keeps every graph.

    Returns:
        ``(senders, receivers, n_edge)``: the row and column indices of the
        connected pairs as returned by ``jnp.nonzero`` (padded with the total
        node count up to ``num_edges`` entries), and the number of connected
        pairs of each graph as an ``int32`` array.
    """
    num_nodes = pos.shape[0]
    dist = jnp.linalg.norm(pos[None, :, :] - pos[:, None, :], axis=-1)

    # Graph index of every node, e.g. n_node=[2, 3] -> [0, 0, 1, 1, 1].
    graph_of_node = jnp.repeat(
        jnp.arange(len(n_node)), n_node, total_repeat_length=num_nodes
    )
    same_graph = graph_of_node[None, :] == graph_of_node[:, None]
    adj = (0 < dist) & (dist <= cutoff) & same_graph

    if active is not None:
        # Turn the flags into booleans before masking: `adj & active` with an
        # integer `active` is a bitwise AND, so a truthy flag such as 2 would
        # wipe the graph out (1 & 2 == 0).
        node_active = jnp.repeat(
            jnp.asarray(active) != 0, n_node, total_repeat_length=num_nodes
        )
        # Broadcasts over columns; since pairs never cross graphs, masking one
        # endpoint is enough to drop every pair of an inactive graph.
        adj = adj & node_active

    senders, receivers = jnp.nonzero(adj, size=num_edges, fill_value=num_nodes)
    n_edge_per_node = jnp.sum(adj, axis=1)
    n_edge = jnp.zeros(len(n_node), jnp.int32).at[graph_of_node].add(n_edge_per_node)
    return senders, receivers, n_edge


@partial(jax.jit, static_argnames=("num_edges",))
def better(pos, n_node, cutoff: float, num_edges: int, active=None):
    """Build the per-graph radius graph of a batch with nested loops.

    Every unordered pair of nodes inside a graph is visited once and, when the
    two nodes are within ``cutoff`` of each other, both directed edges are
    appended to the output buffers.

    Args:
        pos: node positions of all graphs stacked along the first axis.
        n_node: number of nodes of each graph, in stacking order.
        cutoff: a pair is connected when its distance is ``<= cutoff``.
        num_edges: static length of the returned index arrays.
        active: optional per-graph flags; a graph whose flag is falsy is
            skipped. ``None`` processes every graph.

    Returns:
        ``(senders, receivers, n_edge)``: edge endpoints in visiting order,
        with unused slots holding the total node count, and the number of
        edges appended for each graph.
    """
    num_nodes = pos.shape[0]
    # offsets[g] : offsets[g + 1] is the node range of graph g.
    offsets = jnp.concatenate([jnp.array([0]), jnp.cumsum(n_node)])

    def scan_graph(state, g):
        first, last = offsets[g], offsets[g + 1]

        def visit_node(a, outer_state):
            def visit_pair(b, inner_state):
                n_written, src, dst = inner_state
                within = jnp.linalg.norm(pos[a] - pos[b]) <= cutoff

                # Write a -> b unconditionally; the slot is only kept if the
                # cursor advances, otherwise the next write overwrites it.
                src = src.at[n_written].set(a)
                dst = dst.at[n_written].set(b)
                n_written = n_written + within

                # Same trick for the reverse edge b -> a.
                src = src.at[n_written].set(b)
                dst = dst.at[n_written].set(a)
                n_written = n_written + within

                return n_written, src, dst

            # Only partners after `a` in the same graph: each pair is seen once.
            return lax.fori_loop(a + 1, last, visit_pair, outer_state)

        # An inactive graph gets an empty outer range.
        stop = last
        if active is not None:
            stop = jnp.where(active[g], last, first)

        n_after, src, dst = lax.fori_loop(first, stop, visit_node, state)
        return (n_after, src, dst), n_after - state[0]

    # Start with every slot set to the padding value.
    init_senders = jnp.full((num_edges,), num_nodes, jnp.int32)
    init_receivers = jnp.full((num_edges,), num_nodes, jnp.int32)
    (total, senders, receivers), n_edge = lax.scan(
        scan_graph, (0, init_senders, init_receivers), jnp.arange(len(n_node))
    )

    # The slot at the cursor may still hold a rejected pair; restore the padding.
    senders = senders.at[total].set(num_nodes)
    receivers = receivers.at[total].set(num_nodes)
    return senders, receivers, n_edge

In [3]:
# Check that `naive` and `better` return the same edges on a small random batch.

pos = jax.random.normal(jax.random.PRNGKey(0), (14, 3))
n_node = jnp.array([3, 4, 3, 4])
active = jnp.array([1, 0, 0, 1])
cutoff = 2.5
num_edges = 50

# `better` emits edges in visiting order, `naive` in row-major order of the
# adjacency matrix. Sort by (sender, receiver) so the two can be compared;
# lexsort treats the last key as the primary one.
senders, receivers, n_edge = better(pos, n_node, cutoff, num_edges, active)
order = jnp.lexsort((receivers, senders))
senders, receivers = senders[order], receivers[order]

senders2, receivers2, n_edge2 = naive(pos, n_node, cutoff, num_edges, active)

np.testing.assert_array_equal(n_edge, n_edge2)
np.testing.assert_array_equal(receivers, receivers2)
np.testing.assert_array_equal(senders, senders2)

n_edge

Array([4, 0, 0, 2], dtype=int32, weak_type=True)