In [1]:
import jraph
import jax
import jax.numpy as jnp
import e3nn_jax as e3nn

  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)


In [19]:
# Undirected graph: CH4 (node 0 is the carbon, nodes 1-4 are hydrogens)
ch4 = jraph.GraphsTuple(
    nodes=jnp.array([4, 1, 1, 1, 1]),  # electrons to be shared
    edges=jnp.array([0, 0, 0, 0]),  # bonds (initially empty)
    globals=None,
    senders=jnp.array([0, 0, 0, 0]),
    receivers=jnp.array([1, 2, 3, 4]),
    n_node=jnp.array([5]),
    n_edge=jnp.array([4]),
)
# Backward-compatible alias: this molecule was originally bound to the
# misleading name `ch3`, which other cells still reference.
ch3 = ch4

In [2]:
def allocate(graph, per_edge):
    """Move ``per_edge[e]`` electrons from both endpoints of every edge ``e`` onto it.

    The edge counter grows by ``per_edge[e]``. The sender node and the
    receiver node each lose ``per_edge[e]``. Repeated node indices accumulate,
    because ``.at[...].add`` is a scatter-add.

    Args:
        graph: a ``jraph.GraphsTuple``-like tuple exposing ``nodes``, ``edges``,
            ``senders``, ``receivers`` and ``_replace``.
        per_edge: one value per edge, added to ``graph.edges``.

    Returns:
        A new graph with updated ``nodes`` and ``edges``; all other fields unchanged.
    """
    new_edges = graph.edges + per_edge
    taken = -per_edge
    new_nodes = (
        graph.nodes
        .at[graph.senders].add(taken)
        .at[graph.receivers].add(taken)
    )
    return graph._replace(nodes=new_nodes, edges=new_edges)

In [21]:
def allocation_imposed(graph):
    """Return how many electrons each edge is forced to receive on the next step.

    Two rules are combined with an element-wise maximum:

    1. Forced bond. A node is "forced" when it still has electrons (> 0) and
       exactly one of its incident edges leads to a neighbour that also still
       has electrons. Every edge touching a forced node is assigned
       ``min(valence[sender], valence[receiver])``. Edges towards exhausted
       neighbours therefore get 0 or a negative value, which rule 2 clips away.
    2. Empty bond. An edge holding fewer than one electron is topped up to
       one, i.e. it gets ``max(0, 1 - edges)``.

    Returns:
        One entry per edge. It is never negative, since rule 2 is always >= 0.
    """
    remaining = graph.nodes
    src, dst = graph.senders, graph.receivers

    src_active = remaining[src] > 0
    dst_active = remaining[dst] > 0

    # Per node: how many incident edges lead to a neighbour with electrons left.
    active_neighbours = (
        jnp.zeros(remaining.shape[0], dtype=jnp.int32)
        .at[src].add(dst_active)
        .at[dst].add(src_active)
    )
    forced = (active_neighbours == 1) & (remaining > 0)

    touches_forced = forced[src] | forced[dst]
    shareable = jnp.minimum(remaining[src], remaining[dst])
    forced_alloc = jnp.where(touches_forced, shareable, 0)

    top_up_alloc = jnp.maximum(0, 1 - graph.edges)

    return jnp.maximum(forced_alloc, top_up_alloc)


def allocate_while_you_have_to(graph):
    """Apply imposed electron allocations until no edge is imposed any more.

    Starting from ``graph``, this repeatedly computes ``allocation_imposed``
    and applies it with ``allocate``. It keeps going while at least one edge
    has a strictly positive imposed allocation. The loop is a
    ``jax.lax.while_loop``.

    The current imposed allocation is carried in the loop state next to the
    graph. ``body`` computes it once per iteration and ``cond`` only reads it.
    The original version recomputed it in both functions; ``cond`` and
    ``body`` are separate computations, so that work could not be shared.
    The sequence of graphs produced is unchanged.

    Args:
        graph: a ``jraph.GraphsTuple``-like tuple (``nodes``, ``edges``,
            ``senders``, ``receivers``) holding the electrons still to share
            per node and the electrons already placed per edge.

    Returns:
        The graph after the last allocation step.
    """
    def cond(state):
        _, alloc = state
        return jnp.any(alloc > 0)

    def body(state):
        current, alloc = state
        updated = allocate(current, alloc)
        return updated, allocation_imposed(updated)

    init = (graph, allocation_imposed(graph))
    final_graph, _ = jax.lax.while_loop(cond, body, init)
    return final_graph

In [4]:
def valence(graph):
    """Run the allocation to completion and grade each graph in the batch.

    Status codes, one per graph (segments given by ``n_node`` / ``n_edge``):
        0 = valence OK (no mistakes and no electrons left on any node)
        1 = Algorithm not finished (no mistakes, but some node still has
            electrons left to share)
        2 = Impossible to satisfy valence (some node went below zero
            electrons, or some edge ended with fewer than one electron)
    """
    settled = allocate_while_you_have_to(graph)

    negative_nodes = settled.nodes < 0
    empty_edges = settled.edges < 1
    leftover_nodes = settled.nodes > 0

    n_errors = (
        e3nn.scatter_sum(negative_nodes, nel=settled.n_node)
        + e3nn.scatter_sum(empty_edges, nel=settled.n_edge)
    )
    n_leftover = e3nn.scatter_sum(leftover_nodes, nel=settled.n_node)

    ok_or_unfinished = jnp.where(n_leftover == 0, 0, 1)
    return jnp.where(n_errors == 0, ok_or_unfinished, 2)

In [15]:
# CH4 with mixed edge directions: three bonds point hydrogen -> carbon, one carbon -> hydrogen.
# The allocation treats senders and receivers symmetrically, so the result should match the other CH4 graph.

ch4_flipped = jraph.GraphsTuple(
    nodes=jnp.array([4, 1, 1, 1, 1]),  # electrons to be shared
    edges=jnp.array([0, 0, 0, 0]),  # bonds (initially empty)
    globals=None,
    senders=jnp.array([1, 0, 3, 4]),
    receivers=jnp.array([0, 2, 0, 0]),
    n_node=jnp.array([5]),
    n_edge=jnp.array([4]),
)
print(ch4_flipped.nodes)
ch4_flipped_done = allocate_while_you_have_to(ch4_flipped)
print(ch4_flipped_done.nodes)
valence(ch4_flipped_done)

[4 1 1 1 1]
[0 0 0 0 0]


Array([0], dtype=int32, weak_type=True)

In [16]:
# Undirected graph: H-C-C-C-C-C-C-H (8 atoms, 7 bonds)
chain_nodes = jnp.array([1, 4, 4, 4, 4, 4, 4, 1])  # electrons to be shared
chain_senders = jnp.array([0, 1, 2, 3, 4, 5, 6])
chain_receivers = jnp.array([1, 2, 3, 4, 5, 6, 7])
chain = jraph.GraphsTuple(
    nodes=chain_nodes,
    edges=jnp.array([0, 0, 0, 0, 0, 0, 0]),  # bonds (no electrons at initial state)
    globals=None,
    senders=chain_senders,
    receivers=chain_receivers,
    # Derived from the arrays: the hard-coded value was 9, but the chain has 8 atoms.
    n_node=jnp.array([chain_nodes.shape[0]]),
    n_edge=jnp.array([chain_senders.shape[0]]),
)
print(chain.nodes)
chain_done = allocate_while_you_have_to(chain)
print(chain_done.nodes)

[1 4 4 4 4 4 4 1]
[0 0 0 0 0 0 0 0]


In [17]:
# Undirected graph: H-C-C-C-C-C-C-H (8 atoms, 7 bonds)
chain_nodes = jnp.array([1, 4, 4, 4, 4, 4, 4, 1])  # electrons to be shared
chain_senders = jnp.array([0, 1, 2, 3, 4, 5, 6])
chain = jraph.GraphsTuple(
    nodes=chain_nodes,
    edges=jnp.array([0, 0, 0, 0, 0, 0, 0]),  # bonds (no electrons at initial state)
    globals=None,
    senders=chain_senders,
    receivers=jnp.array([1, 2, 3, 4, 5, 6, 7]),
    # Derived from the arrays: the hard-coded value was 9, but the chain has 8 atoms,
    # so the per-graph segment sums in `valence` were computed over a wrong size.
    n_node=jnp.array([chain_nodes.shape[0]]),
    n_edge=jnp.array([chain_senders.shape[0]]),
)
valence(chain)

Array([0], dtype=int32, weak_type=True)

In [18]:
# NOTE(review): the NameError recorded below is stale. This cell ran as In [18],
# before the cell that defines `ch3` (In [19]). Running the notebook top to bottom
# defines `ch3` first.
valence(ch3)

NameError: name 'ch3' is not defined

In [20]:
# Undirected graph: square of four trivalent atoms.
# Both alternating bond patterns (2,1,2,1) and (1,2,1,2) satisfy every valence.
# After the first step, every node still has two neighbours with electrons left.
# No bond is forced, so the greedy allocation stops early and `valence` reports 1 (not finished).
square = jraph.GraphsTuple(
    nodes=jnp.array([3, 3, 3, 3]),  # electrons to be shared
    edges=jnp.array([0, 0, 0, 0]),  # bonds (no electrons at initial state)
    globals=None,
    senders=jnp.array([0, 1, 2, 3]),
    receivers=jnp.array([1, 2, 3, 0]),
    n_node=jnp.array([4]),
    n_edge=jnp.array([4]),
)
print(square.nodes, square.edges)
square_done = allocate_while_you_have_to(square)
print(square_done.nodes, square_done.edges)
valence(square_done)

[3 3 3 3] [0 0 0 0]
[1 1 1 1] [1 1 1 1]


Array([1], dtype=int32, weak_type=True)