In [31]:
import jraph
import jax
import jax.numpy as jnp
import e3nn_jax as e3nn

In [32]:
# Undirected graph: CH4 -- node 0 is the centre, nodes 1-4 are its four neighbours.
# NOTE: the variable is called `ch3` although the graph described is CH4; the
# name is kept because a later cell calls `valence(ch3)`.
ch3 = jraph.GraphsTuple(
    nodes=jnp.array([4, 1, 1, 1, 1]),  # electrons each node still has to share
    edges=jnp.array([0, 0, 0, 0]),  # electrons placed on each bond (initially empty)
    globals=None,
    senders=jnp.array([0, 0, 0, 0]),  # every edge starts at node 0 ...
    receivers=jnp.array([1, 2, 3, 4]),  # ... and ends at one of nodes 1-4
    n_node=jnp.array([5]),  # a single graph with 5 nodes
    n_edge=jnp.array([4]),  # a single graph with 4 edges
)

In [33]:
def allocate(graph, per_edge):
    """Put `per_edge` electrons on the edges and charge both endpoints for them.

    Args:
        graph: object exposing `nodes`, `edges`, `senders`, `receivers` and a
            namedtuple-style `_replace` (the notebook passes a
            `jraph.GraphsTuple`).
        per_edge: amount added to `graph.edges`; the same amount is subtracted
            from the node at each end of every edge.

    Returns:
        A copy of `graph` with the updated `edges` and `nodes`.
    """
    remaining = graph.nodes
    # Both ends of an edge pay the full amount allocated to that edge; a node
    # that appears on several edges is charged once per edge.
    for endpoint in (graph.senders, graph.receivers):
        remaining = remaining.at[endpoint].add(-per_edge)
    return graph._replace(edges=graph.edges + per_edge, nodes=remaining)

In [34]:
def allocation_imposed(graph):
    """Return, for every edge, the number of electrons that must be allocated now.

    Two rules are evaluated and the larger demand wins per edge:

    1. A node that still has electrons and has exactly one incident edge whose
       other endpoint also still has electrons cannot share with anyone else.
       Every edge touching such a node is assigned the smaller of its two
       endpoints' remaining electrons.
    2. Every edge needs at least one electron, so an edge holding fewer than
       one is assigned the missing amount.

    Args:
        graph: object exposing `nodes`, `edges`, `senders` and `receivers`.

    Returns:
        An array with one non-negative entry per edge.
    """
    left = graph.nodes
    src, dst = graph.senders, graph.receivers
    src_left, dst_left = left[src], left[dst]

    # Rule 1: count, for each node, the incident edges whose far end still has
    # electrons to give (edges are undirected, so both directions are counted).
    open_partners = (
        jnp.zeros(left.shape[0], dtype=jnp.int32)
        .at[src].add(dst_left > 0)
        .at[dst].add(src_left > 0)
    )
    no_choice = (left > 0) & (open_partners == 1)
    rule_one = jnp.where(
        no_choice[src] | no_choice[dst],
        jnp.minimum(src_left, dst_left),
        0,
    )

    # Rule 2: top every edge up to at least one electron.
    rule_two = jnp.maximum(0, 1 - graph.edges)

    return jnp.maximum(rule_one, rule_two)


def allocate_while_you_have_to(graph):
    """Apply forced allocations repeatedly until none is left.

    Runs a `jax.lax.while_loop` whose carried state is the graph: while
    `allocation_imposed` demands a positive amount on at least one edge, that
    demand is applied with `allocate`.

    Args:
        graph: initial graph (the notebook passes a `jraph.GraphsTuple`).

    Returns:
        The graph after the loop has stopped.
    """

    def still_forced(state):
        # Keep going while any edge has a strictly positive forced allocation.
        return jnp.any(allocation_imposed(state) > 0)

    def apply_forced(state):
        return allocate(state, allocation_imposed(state))

    return jax.lax.while_loop(still_forced, apply_forced, graph)

In [49]:
def valence(graph):
    """
    Run the forced allocation on `graph` and report a status code.

    0 = valence OK
    1 = Algorithm not finished
    2 = Impossible to satisfy valence

    The counts below are reduced with `e3nn.scatter_sum` using `graph.n_node`
    and `graph.n_edge` as segment sizes, so the result has one code per entry
    of those arrays.
    """
    settled = allocate_while_you_have_to(graph)

    # Violations: a node that gave away more than it had, or an edge that
    # ended up with fewer than one electron.
    overdrawn_nodes = e3nn.scatter_sum(settled.nodes < 0, nel=settled.n_node)
    starved_edges = e3nn.scatter_sum(settled.edges < 1, nel=settled.n_edge)
    violations = overdrawn_nodes + starved_edges

    # Nodes that still hold unshared electrons mean the allocation is incomplete.
    leftover_nodes = e3nn.scatter_sum(settled.nodes > 0, nel=settled.n_node)

    return jnp.where(violations != 0, 2, jnp.where(leftover_nodes != 0, 1, 0))

In [50]:
# Undirected graph: H-C-C-C-C-C-C-H (a chain of 8 atoms joined by 7 bonds)
graph = jraph.GraphsTuple(
    nodes=jnp.array([1, 4, 4, 4, 4, 4, 4, 1]),  # electrons to be shared
    edges=jnp.array([0, 0, 0, 0, 0, 0, 0]),  # bonds (no electrons at initial state)
    globals=None,
    senders=jnp.array([0, 1, 2, 3, 4, 5, 6]),  # edge k joins node k ...
    receivers=jnp.array([1, 2, 3, 4, 5, 6, 7]),  # ... to node k + 1
    n_node=jnp.array([8]),  # a single graph; must equal the number of entries in `nodes`
    n_edge=jnp.array([7]),  # a single graph; must equal the number of entries in `edges`
)
print(graph.nodes)  # remaining electrons before allocation
graph = allocate_while_you_have_to(graph)
print(graph.nodes)  # remaining electrons once no allocation is forced any more

[1 4 4 4 4 4 4 1]
[0 0 0 0 0 0 0 0]


In [51]:
# Undirected graph: H-C-C-C-C-C-C-H (a chain of 8 atoms joined by 7 bonds)
graph = jraph.GraphsTuple(
    nodes=jnp.array([1, 4, 4, 4, 4, 4, 4, 1]),  # electrons to be shared
    edges=jnp.array([0, 0, 0, 0, 0, 0, 0]),  # bonds (no electrons at initial state)
    globals=None,
    senders=jnp.array([0, 1, 2, 3, 4, 5, 6]),  # edge k joins node k ...
    receivers=jnp.array([1, 2, 3, 4, 5, 6, 7]),  # ... to node k + 1
    # `valence` passes n_node / n_edge to `scatter_sum` as segment sizes, so
    # they must match the actual number of nodes (8) and edges (7).
    n_node=jnp.array([8]),
    n_edge=jnp.array([7]),
)
valence(graph)

Array([0], dtype=int32, weak_type=True)

In [52]:
# Status code for the CH4 graph stored in `ch3`; the meaning of each code is listed in the `valence` docstring.
valence(ch3)

Array([0], dtype=int32, weak_type=True)

In [53]:
# Undirected graph: square (4 nodes joined in a ring, 3 electrons each)
graph = jraph.GraphsTuple(
    nodes=jnp.array([3, 3, 3, 3]),  # electrons to be shared
    edges=jnp.array([0, 0, 0, 0]),  # bonds (no electrons at initial state)
    globals=None,
    senders=jnp.array([0, 1, 2, 3]),  # edge k joins node k ...
    receivers=jnp.array([1, 2, 3, 0]),  # ... to node (k + 1) mod 4, closing the ring
    n_node=jnp.array([4]),  # a single graph with 4 nodes
    n_edge=jnp.array([4]),  # a single graph with 4 edges
)
print(graph.nodes, graph.edges)  # state before any allocation
graph = allocate_while_you_have_to(graph)
print(graph.nodes, graph.edges)  # state once no allocation is forced any more
# `valence` receives the already-allocated graph here and runs the allocation
# loop again internally before computing the status code.
valence(graph)

[3 3 3 3] [0 0 0 0]
[1 1 1 1] [1 1 1 1]


Array([1], dtype=int32, weak_type=True)