In [1]:
import jax
import jax.numpy as jnp
from flax import struct

In [2]:
@struct.dataclass
class ActivationState:
    values: jnp.ndarray
    toggled: jnp.ndarray
    activation_counts: jnp.ndarray

@struct.dataclass
class Network:
    node_indices: jnp.ndarray
    node_types: jnp.ndarray
    edges: jnp.ndarray
    senders: jnp.ndarray
    receivers: jnp.ndarray

    @property
    def n_nodes(self) -> int:
        return len(self.nodes.index)

    @property
    def input_nodes(self) -> int:
        return jnp.where(self.nodes.type == 1)[0]

    @property
    def n_inputs(self) -> int:
        return len(self.input_nodes)

    def get_required_activations(self, max_nodes: int) -> jnp.ndarray:
        """
        Returns an array of size max_nodes where the ``n-th`` element is the number of 
        required activations for the node ``n`` to fire. 
        """
        return jnp.bincount(self.receivers[self.receivers >= 0], minlength=max_nodes)


def init(
    senders: jnp.ndarray,
    receivers: jnp.ndarray,
    inputs: jnp.ndarray,
    max_nodes: int = 100,
) -> tuple[ActivationState, Network]:

    senders = (
        (jnp.ones(max_nodes, dtype=jnp.int32) * -1).at[: len(senders)].set(senders)
    )
    receivers = (
        (jnp.ones(max_nodes, dtype=jnp.int32) * -1).at[: len(receivers)].set(receivers)
    )

    activations = jnp.zeros(max_nodes).at[: len(inputs)].set(inputs)
    activated_nodes = jnp.int32(activations > 0)
    activation_counts = jnp.zeros(max_nodes, dtype=jnp.int32)

    return (
        ActivationState(
            values=activations,
            toggled=activated_nodes,
            activation_counts=activation_counts,
        ),
        Network(
            node_indices=jnp.arange(max_nodes, dtype=jnp.int32),
            node_types=jnp.zeros(max_nodes, dtype=jnp.int32),
            edges=jnp.ones(max_nodes),
            senders=senders,
            receivers=receivers,
        ),
    )


def pad(
    array: jnp.ndarray, dtype=jnp.float32, max_nodes: int = 100, fill_value: int = 0
) -> jnp.ndarray:
    return (jnp.ones(max_nodes, dtype=dtype) * fill_value).at[: len(array)].set(array)

In [3]:
max_nodes = 100
initial_senders = jnp.array([0, 1, 2, 4])
initial_receivers = jnp.array([4, 4, 3, 3])
inputs = jnp.array([0.5, 0.8, 0.2])

activation_state, net = init(initial_senders, initial_receivers, inputs, max_nodes)
print(jax.tree_map(lambda x: x.shape, activation_state))
print(jax.tree_map(lambda x: x.shape, net))

ActivationState(values=(100,), toggled=(100,), activation_counts=(100,))
Network(node_indices=(100,), node_types=(100,), edges=(100,), senders=(100,), receivers=(100,))


In [4]:
def get_active_connections(
    activation_state: ActivationState, network: Network
) -> tuple[jnp.ndarray, jnp.ndarray]:
    active_senders_indices = jnp.where(activation_state.toggled[network.senders] > 0)[0]

    active_senders = jnp.take(network.senders, active_senders_indices, axis=0)
    active_receivers = jnp.take(network.receivers, active_senders_indices, axis=0)

    return pad(active_senders,dtype=jnp.int32, fill_value=-1), pad(active_receivers,dtype=jnp.int32, fill_value=-1)


senders, receivers = get_active_connections(activation_state, net)
senders, receivers

(Array([ 0,  1,  2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],      dtype=int32),
 Array([ 4,  4,  3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],      dtype=int32))

In [5]:
def add_activations(
    senders: jnp.ndarray,
    receivers: jnp.ndarray,
    activation_state: ActivationState,
) -> ActivationState:
    """
    For given sender nodes, iteratively computes the activation
    of receiver nodes while carrying the global activation state.
    """

    def add_single_activation(activation_state: jnp.ndarray, x: tuple) -> jnp.ndarray:
        def _update_activation_state(val: tuple):
            """
            Adds the activation of a sender to a receiver's value and 
            increments its activation count, then deactivates the sender node.

            Note: the deactivation of the sender nodes will only be effective at the 
            end of the iteration (at the next step when computing which nodes should fire).
            """
            activation_state, sender, receiver = val
            values = activation_state.values
            activation_counts = activation_state.activation_counts

            values = values.at[receiver].add(values[sender])
            activation_counts = activation_counts.at[receiver].add(1)
            toggled = activation_state.toggled.at[sender].set(0)
            return (
                activation_state.replace(
                    values=values,
                    activation_counts=activation_counts,
                    toggled=toggled,
                ),
                None,
            )

        def _bypass(val: tuple):
            """ Bypasses the update for a given node. """
            activation_state, _, _ = val
            return (activation_state, None)

        sender, receiver = x

        # nodes with activation -1 are not enabled and should not fire
        return jax.lax.cond(
            sender == -1,
            _bypass,
            _update_activation_state,
            operand=(activation_state, sender, receiver),
        )

    activation_state, _ = jax.lax.scan(
        add_single_activation,
        activation_state,
        jnp.stack((senders, receivers), axis=1),
    )
    return activation_state

activation_state = add_activations(senders, receivers, activation_state)
activation_state

ActivationState(values=Array([0.5, 0.8, 0.2, 0.2, 1.3, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
       0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
       0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
       0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
       0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
       0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
       0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
       0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ], dtype=float32), toggled=Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32), activation_counts=Array([0, 0, 0, 1, 2, 0, 0, 

In [6]:
def toggle_receivers(activation_state: ActivationState, net: Network) -> ActivationState:
    """
    Returns an array of size ``max_neurons`` indicating which nodes have received
    all necessary activations and should fire at the next step.
    """
    activated_nodes = jnp.int32(
        (activation_state.activation_counts > 0)
        & (
            activation_state.activation_counts
            == net.get_required_activations(max_nodes)
        )
    )
    return activation_state.replace(toggled=activated_nodes)

activation_state = toggle_receivers(activation_state, net)
activation_state.toggled

Array([0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)

In [7]:
senders = jnp.arange(max_nodes) * activation_state.toggled
receivers

Array([ 4,  4,  3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],      dtype=int32)

In [8]:
senders, receivers = get_active_connections(activation_state, net)
add_activations(senders, receivers, activation_state)

ActivationState(values=Array([0.5, 0.8, 0.2, 1.5, 1.3, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
       0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
       0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
       0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
       0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
       0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
       0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
       0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ], dtype=float32), toggled=Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32), activation_counts=Array([0, 0, 0, 2, 2, 0, 0, 

In [9]:
def termination_fn(x: jnp.ndarray) -> bool: 
    return jnp.all(x == 0)

def body_fn(val: tuple):
    activation_state, net = val
    senders, receivers = get_active_connections(activation_state, net)
    activation_state = add_activations(senders, receivers, activation_state)
    activation_state = toggle_receivers(activation_state, net)

    return activation_state, net

jax.lax.while_loop(termination_fn, body_fn, (activation_state, net))

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
The error occurred while tracing the function body_fn at C:\Users\ryanp\AppData\Local\Temp\ipykernel_35224\3719556624.py:4 for while_loop. This concrete value was not available in Python because it depends on the values of the arguments val[0].toggled and val[1].senders.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError