In [1]:
import sys

import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import jraph
import numpy as np

sys.path.append("..")

import ase.build

from input_pipeline import ase_atoms_to_jraph_graph

In [2]:
# in the following code, graph is a jraph.GraphsTuple
# undirected graph
# nodes = valence electrons
# edges = bonds (0 initially)


def allocate(graph, per_edge):
    """
    For each edge `e`, allocate `per_edge[e]` electrons from each vertex of the edge to the edge itself.
    """
    e = graph.edges
    n = graph.nodes
    s, r = graph.senders, graph.receivers

    # update valence values per edge/node
    e = e + per_edge
    n = n.at[s].add(-per_edge)
    n = n.at[r].add(-per_edge)

    return graph._replace(edges=e, nodes=n)


def allocation_imposed(graph, edge_mask):
    """Computes the number of electrons to distribute per edge on the next allocation."""
    num_nodes = graph.nodes.shape[0]
    e = graph.edges
    n = graph.nodes
    s, r = graph.senders, graph.receivers
    m = edge_mask

    ## Situation 1: Share electrons between nodes that have electrons to share
    # a: how many neighbors have electrons to share
    a = jnp.zeros(num_nodes, dtype=jnp.int32)
    a = a.at[s].add(m & (n[r] > 0))
    a = a.at[r].add(m & (n[s] > 0))

    # a: only one neighbor has electrons to share AND I have electrons to share
    a = (a == 1) & (n > 0)

    # edge_alloc1: when one of the two sides NEED to share
    edge_alloc1 = jnp.where(a[s] | a[r], jnp.minimum(n[s], n[r]), 0)

    ## Situation 2: If edge has no electrons, allocate 1 electron to that edge
    edge_alloc2 = jnp.maximum(0, 1 - e)

    ## Situation 3: If edge_alloc has no electrons being distributed, but some node has electrons to give,
    # choose a conformation in which to distribute electrons.

    return jnp.where(m, jnp.maximum(edge_alloc1, edge_alloc2), 0)


def allocate_while_you_have_to(graph, edge_mask):
    def cond(graph):
        a = allocation_imposed(graph, edge_mask)
        return jnp.any(a > 0)

    def body(graph):
        a = allocation_imposed(graph, edge_mask)
        return allocate(graph, a)

    return jax.lax.while_loop(cond, body, graph)
    # while cond(graph):
    #     graph = body(graph)
    #     print(graph)
    # return graph


@jax.jit
def check_valence(graph, edge_mask=None):
    """
    Checks if the graph satisfies the valence rules.

    Args:
    graph: a jraph.GraphsTuple
        nodes: number of electrons that need to be shared
        senders, receivers: undirected edges representing bonds

    Returns:
    0 = valence OK
    1 = Algorithm not finished
    2 = Impossible to satisfy valence
    """
    if edge_mask is None:
        edge_mask = jnp.ones(len(graph.senders), dtype=jnp.bool_)

    graph = graph._replace(edges=jnp.zeros_like(graph.senders))

    graph = allocate_while_you_have_to(graph, edge_mask)

    edge_mistakes = edge_mask & (graph.edges < 1)
    node_mistakes = graph.nodes < 0
    node_incomplete = graph.nodes > 0

    m = e3nn.scatter_sum(node_mistakes, nel=graph.n_node)
    m += e3nn.scatter_sum(edge_mistakes, nel=graph.n_edge)
    i = e3nn.scatter_sum(node_incomplete, nel=graph.n_node)

    return jnp.where(m == 0, jnp.where(i == 0, 0, 1), 2)

In [3]:
# distance below which two atoms are for sure at least single bonded
bond_dist = {
    (6, 6): 1.50,
    (1, 6): 1.20,
    (8, 1): 1.00,
    # (1, 7): ?,
    # (7, 7): ?,
    # (8, 6): ?,
    # (6, 7): ?,
    # (8, 7): ?,
    # (9, 6): ?,
    # (1, 1): ?,
    # (9, 1): ?,
    # (8, 9): ?,
    # (8, 8): ?,
    # (9, 9): ?,
    # (9, 7): ?,
}

atomic_numbers = jnp.array([1, 6, 7, 8, 9])
valence = jnp.array([1, 4, 3, 2, 1])

bond_dist_matrix = np.zeros((len(atomic_numbers), len(atomic_numbers)))
for (zi, zj), dist in bond_dist.items():
    if np.isin(zi, atomic_numbers) and np.isin(zj, atomic_numbers):
        si = np.searchsorted(atomic_numbers, zi)
        sj = np.searchsorted(atomic_numbers, zj)
        bond_dist_matrix[si, sj] = dist
        bond_dist_matrix[sj, si] = dist

bond_dist_matrix = jnp.asarray(bond_dist_matrix)

mol = ase.build.molecule("C6H6")
nn_cutoff = 5.0
graph = ase_atoms_to_jraph_graph(mol, atomic_numbers, nn_cutoff)
d = jnp.linalg.norm(
    graph.nodes.positions[graph.senders] - graph.nodes.positions[graph.receivers],
    axis=-1,
)
si = graph.nodes.species[graph.senders]
sj = graph.nodes.species[graph.receivers]

out = check_valence(
    graph._replace(nodes=valence[graph.nodes.species]),
    (d < bond_dist_matrix[si, sj]) & (graph.senders < graph.receivers),
)

out <= 1

Array([ True], dtype=bool, weak_type=True)

In [65]:
# CH4 (edges directed the opposite way)

graph = jraph.GraphsTuple(
    nodes=jnp.array([4, 1, 1, 1, 1]),  # electrons to be shared
    edges=jnp.array([0, 0, 0, 0]),  # bonds (initially empty)
    globals=None,
    senders=jnp.array([1, 0, 3, 4]),
    receivers=jnp.array([0, 2, 0, 0]),
    n_node=jnp.array([5]),
    n_edge=jnp.array([4]),
)
print(graph.nodes)
graph = allocate_while_you_have_to(graph)
print(graph.nodes)
valence(graph)

[4 1 1 1 1]
[0 0 0 0 0]


Array([0], dtype=int32, weak_type=True)

In [64]:
# Undirected graph: H-C-C-C-C-C-C-H
graph = jraph.GraphsTuple(
    nodes=jnp.array([1, 4, 4, 4, 4, 4, 4, 1]),  # electrons to be shared
    edges=jnp.array([0, 0, 0, 0, 0, 0, 0]),  # bonds (no electrons at initial state)
    globals=None,
    senders=jnp.array([0, 1, 2, 3, 4, 5, 6]),
    receivers=jnp.array([1, 2, 3, 4, 5, 6, 7]),
    n_node=jnp.array([9]),
    n_edge=jnp.array([7]),
)
print(graph.nodes)
graph = allocate_while_you_have_to(graph)
print(graph.nodes)

[1 4 4 4 4 4 4 1]
[0 0 0 0 0 0 0 0]


In [54]:
# Undirected graph: H-C-C-C-C-C-C-H
graph = jraph.GraphsTuple(
    nodes=jnp.array([1, 4, 4, 4, 4, 4, 4, 1]),  # electrons to be shared
    edges=jnp.array([0, 0, 0, 0, 0, 0, 0]),  # bonds (no electrons at initial state)
    globals=None,
    senders=jnp.array([0, 1, 2, 3, 4, 5, 6]),
    receivers=jnp.array([1, 2, 3, 4, 5, 6, 7]),
    n_node=jnp.array([9]),
    n_edge=jnp.array([7]),
)
valence(graph)

Array([0], dtype=int32, weak_type=True)

In [88]:
# Undirected graph: square
graph = jraph.GraphsTuple(
    nodes=jnp.array([3, 3, 3, 3]),  # electrons to be shared
    edges=jnp.array([0, 0, 0, 0]),  # bonds (no electrons at initial state)
    globals=None,
    senders=jnp.array([0, 1, 2, 3]),
    receivers=jnp.array([1, 2, 3, 0]),
    n_node=jnp.array([4]),
    n_edge=jnp.array([4]),
)
print(graph.nodes, graph.edges)
graph = allocate_while_you_have_to(graph)
print(graph.nodes, graph.edges)
valence(graph)

# nodes: [1 1 1 1]
# edges: [1 1 1 1]
# need to break the cycle somehow (i.e. just transfer electrons between (0, 1) and (2, 3), but not (1, 2) and (3, 0))

[3 3 3 3] [0 0 0 0]
[False False False False]
[False False False False]
[False False False False]
[1 1 1 1] [1 1 1 1]
[False False False False]


Array([1], dtype=int32, weak_type=True)

In [95]:
# manually force a change
nodes = jnp.array([0, 0, 1, 1])
edges = jnp.array([2, 1, 1, 1])
graph = graph._replace(nodes=nodes, edges=edges)
graph = allocate_while_you_have_to(graph)
print(graph.nodes, graph.edges)
valence(graph)

[False False  True  True]
[False False  True  True]
[False False False False]
[0 0 0 0] [2 1 2 1]
[False False False False]


Array([0], dtype=int32, weak_type=True)

In [77]:
# benzene
graph = jraph.GraphsTuple(
    nodes=jnp.array([4, 4, 4, 4, 4, 4, 1, 1, 1, 1, 1, 1]),  # electrons to be shared
    edges=jnp.array(
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    ),  # bonds (no electrons at initial state)
    globals=None,
    senders=jnp.array([0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5]),
    receivers=jnp.array([1, 2, 3, 4, 5, 0, 6, 7, 8, 9, 10, 11]),
    n_node=jnp.array([12]),
    n_edge=jnp.array([12]),
)
print(graph.nodes, graph.edges)
graph = allocate_while_you_have_to(graph)
print(graph.nodes, graph.edges)
valence(graph)

graph.nodes

[4 4 4 4 4 4 1 1 1 1 1 1] [0 0 0 0 0 0 0 0 0 0 0 0]
[False False False False False False  True  True  True  True  True  True]
[False False False False False False  True  True  True  True  True  True]
[False False False False False False False False False False False False]
[1 1 1 1 1 1 0 0 0 0 0 0] [1 1 1 1 1 1 1 1 1 1 1 1]
[False False False False False False False False False False False False]


Array([1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], dtype=int32)

In [102]:
# furan - this DOES work
graph = jraph.GraphsTuple(
    nodes=jnp.array([4, 4, 4, 4, 2, 1, 1, 1, 1]),  # electrons to be shared
    edges=jnp.array(
        [0, 0, 0, 0, 0, 0, 0, 0, 0]
    ),  # bonds (no electrons at initial state)
    globals=None,
    senders=jnp.array([0, 1, 2, 3, 4, 0, 1, 2, 3]),
    receivers=jnp.array([1, 2, 3, 4, 0, 5, 6, 7, 8]),
    n_node=jnp.array([9]),
    n_edge=jnp.array([9]),
)
print(graph.nodes, graph.edges)
graph = allocate_while_you_have_to(graph)
print(graph.nodes, graph.edges)
valence(graph)

graph.nodes

[4 4 4 4 2 1 1 1 1] [0 0 0 0 0 0 0 0 0]
GraphsTuple(nodes=Array([1, 1, 1, 1, 0, 0, 0, 0, 0], dtype=int32), edges=Array([1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32), receivers=Array([1, 2, 3, 4, 0, 5, 6, 7, 8], dtype=int32), senders=Array([0, 1, 2, 3, 4, 0, 1, 2, 3], dtype=int32), globals=None, n_node=Array([9], dtype=int32), n_edge=Array([9], dtype=int32))
GraphsTuple(nodes=Array([0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32), edges=Array([2, 1, 2, 1, 1, 1, 1, 1, 1], dtype=int32), receivers=Array([1, 2, 3, 4, 0, 5, 6, 7, 8], dtype=int32), senders=Array([0, 1, 2, 3, 4, 0, 1, 2, 3], dtype=int32), globals=None, n_node=Array([9], dtype=int32), n_edge=Array([9], dtype=int32))
[0 0 0 0 0 0 0 0 0] [2 1 2 1 1 1 1 1 1]


Array([0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)

In [100]:
# cyclopropane - this also works
graph = jraph.GraphsTuple(
    nodes=jnp.array([4, 4, 4, 1, 1, 1, 1, 1, 1]),  # electrons to be shared
    edges=jnp.array(
        [0, 0, 0, 0, 0, 0, 0, 0, 0]
    ),  # bonds (no electrons at initial state)
    globals=None,
    senders=jnp.array([0, 1, 2, 0, 0, 1, 1, 2, 2]),
    receivers=jnp.array(
        [
            1,
            2,
            0,
            3,
            4,
            5,
            6,
            7,
            8,
        ]
    ),
    n_node=jnp.array([9]),
    n_edge=jnp.array([9]),
)
print(graph.nodes, graph.edges)
graph = allocate_while_you_have_to(graph)
print(graph.nodes, graph.edges)
valence(graph)

graph.nodes

[4 4 4 1 1 1 1 1 1] [0 0 0 0 0 0 0 0 0]
[False False False  True  True  True  True  True  True]
[False False False  True  True  True  True  True  True]
[False False False False False False False False False]
[0 0 0 0 0 0 0 0 0] [1 1 1 1 1 1 1 1 1]
[False False False False False False False False False]


Array([0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)

In [6]:
# napthalene
graph = jraph.GraphsTuple(
    nodes=jnp.array(
        [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1]
    ),  # electrons to be shared
    edges=jnp.array(
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    ),  # bonds (no electrons at initial state)
    globals=None,
    senders=jnp.array([0, 1, 2, 3, 4, 5, 0, 6, 7, 8, 9, 1, 2, 3, 4, 6, 7, 8, 9]),
    receivers=jnp.array(
        [1, 2, 3, 4, 5, 0, 6, 7, 8, 9, 5, 10, 11, 12, 13, 14, 15, 16, 17]
    ),
    n_node=jnp.array([19]),
    n_edge=jnp.array([19]),
)
print(graph.nodes, graph.edges)
graph = allocate_while_you_have_to(graph)
print(graph.nodes, graph.edges)
valence(graph)

graph.nodes

[4 4 4 4 4 4 4 4 4 4 1 1 1 1 1 1 1 1] [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
GraphsTuple(nodes=Array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32), edges=Array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32), receivers=Array([ 1,  2,  3,  4,  5,  0,  6,  7,  8,  9,  5, 10, 11, 12, 13, 14, 15,
       16, 17], dtype=int32), senders=Array([0, 1, 2, 3, 4, 5, 0, 6, 7, 8, 9, 1, 2, 3, 4, 6, 7, 8, 9], dtype=int32), globals=None, n_node=Array([19], dtype=int32), n_edge=Array([19], dtype=int32))
[1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0] [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]


Array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)

In [3]:
import ase.build

import sys

sys.path.append("..")
from analyses.check_valence import check_valence
import analyses.analysis as analysis

In [11]:
mol = ase.build.molecule("C6H6")
pybel = analysis.construct_pybel_mol(mol)

In [12]:
for atom in pybel.atoms:
    print(atom.OBAtom.GetExplicitDegree(), atom.OBAtom.GetExplicitValence())

3 4
3 4
3 4
3 4
3 4
3 4
1 1
1 1
1 1
1 1
1 1
1 1
