In [1]:
from collections import namedtuple
from io import StringIO
from typing import Iterator, Tuple

import ase
import ase.io
from ase.atoms import Atoms
import jax
import jax.numpy as jnp
import jraph
import matplotlib.pyplot as plt
import matscipy.neighbours
import numpy as np


import logging
import os
import zipfile
from functools import cache
from typing import List
from urllib.request import urlopen

import e3nn_jax as e3nn

In [2]:
def download_url(url: str, root: str) -> str:
    """Download if file does not exist in root already. Returns path to file."""
    filename = url.rpartition("/")[2]
    file_path = os.path.join(root, filename)

    try:
        from tqdm import tqdm

        progress = True
    except ImportError:
        progress = False

    data = urlopen(url)
    chunk_size = 1024
    total_size = int(data.info()["Content-Length"].strip())

    if os.path.exists(file_path):
        if os.path.getsize(file_path) == total_size:
            logging.info(f"Using downloaded and verified file: {file_path}")
            return file_path

    logging.info(f"Downloading {url} to {file_path}")

    with open(file_path, "wb") as f:
        if progress:
            with tqdm(total=total_size) as pbar:
                while True:
                    chunk = data.read(chunk_size)
                    if not chunk:
                        break
                    f.write(chunk)
                    pbar.update(chunk_size)
        else:
            while True:
                chunk = data.read(chunk_size)
                if not chunk:
                    break
                f.write(chunk)

    return file_path


def extract_zip(path: str, root: str):
    """Extract zip if content does not exist in root already."""
    logging.info(f"Extracting {path} to {root}...")
    with zipfile.ZipFile(path, "r") as f:
        for name in f.namelist():
            if name.endswith("/"):
                logging.info(f"Skip directory {name}")
                continue
            out_path = os.path.join(root, name)
            file_size = f.getinfo(name).file_size
            if os.path.exists(out_path) and os.path.getsize(out_path) == file_size:
                logging.info(f"Skip existing file {name}")
                continue
            logging.info(f"Extracting {name} to {root}...")
            f.extract(name, root)


def read_sdf(f):
    while True:
        name = f.readline()
        if not name:
            break

        f.readline()
        f.readline()

        L1 = f.readline().split()
        try:
            natoms = int(L1[0])
        except IndexError:
            print(L1)
            break

        positions = []
        symbols = []
        for _ in range(natoms):
            line = f.readline()
            x, y, z, symbol = line.split()[:4]
            symbols.append(symbol)
            positions.append([float(x), float(y), float(z)])

        yield Atoms(symbols=symbols, positions=positions)

        while True:
            line = f.readline()
            if line.startswith("$$$$"):
                break


@cache
def load_qm9(root: str) -> List[Atoms]:
    raw_url = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/molnet_publish/qm9.zip"

    if not os.path.exists(root):
        os.makedirs(root)

    path = download_url(raw_url, root)
    extract_zip(path, root)

    with open(os.path.join(root, "gdb9.sdf")) as f:
        return list(read_sdf(f))

In [28]:
NodesInfo = namedtuple("NodesInfo", ["positions", "atomic_numbers"])


def ase_atoms_to_jraph_graph(atoms: ase.Atoms, cutoff: float) -> jraph.GraphsTuple:
    receivers, senders = matscipy.neighbours.neighbour_list(
        quantities="ij",
        positions=atoms.positions,
        cutoff=cutoff,
        cell=np.eye(3)
    )

    return jraph.GraphsTuple(
        nodes=NodesInfo(atoms.positions, atoms.numbers),
        edges=None,
        globals=None,
        senders=senders,
        receivers=receivers,
        n_node=np.array([len(atoms)]),
        n_edge=np.array([len(senders)]),
    )


In [29]:
molecules = load_qm9("qm9")[:4]

In [5]:
molecules

[Atoms(symbols='CH4', pbc=False),
 Atoms(symbols='NH3', pbc=False),
 Atoms(symbols='OH2', pbc=False),
 Atoms(symbols='C2H2', pbc=False)]

In [30]:
molecules[3].positions

array([[ 0.5995,  0.    ,  1.    ],
       [-0.5995,  0.    ,  1.    ],
       [-1.6616,  0.    ,  1.    ],
       [ 1.6616,  0.    ,  1.    ]])

In [30]:
molecule_graphs = list(map(lambda m: ase_atoms_to_jraph_graph(m, 5.), molecules))

In [34]:
molecule_graphs

[GraphsTuple(nodes=NodesInfo(positions=array([[-1.2700e-02,  1.0858e+00,  8.0000e-03],
        [ 2.2000e-03, -6.0000e-03,  2.0000e-03],
        [ 1.0117e+00,  1.4638e+00,  3.0000e-04],
        [-5.4080e-01,  1.4475e+00, -8.7660e-01],
        [-5.2380e-01,  1.4379e+00,  9.0640e-01]]), atomic_numbers=array([6, 1, 1, 1, 1])), edges=None, receivers=array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4],
       dtype=int32), senders=array([1, 2, 3, 4, 0, 2, 3, 4, 0, 1, 3, 4, 0, 1, 2, 4, 0, 1, 2, 3],
       dtype=int32), globals=None, n_node=array([5]), n_edge=array([20])),
 GraphsTuple(nodes=NodesInfo(positions=array([[-0.0404,  1.0241,  0.0626],
        [ 0.0173,  0.0125, -0.0274],
        [ 0.9158,  1.3587, -0.0288],
        [-0.5203,  1.3435, -0.7755]]), atomic_numbers=array([7, 1, 1, 1])), edges=None, receivers=array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], dtype=int32), senders=array([1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2], dtype=int32), globals=None, n_node=array([4]), n_edge=ar

In [31]:
batched = jraph.batch(molecule_graphs)

In [36]:
batched

GraphsTuple(nodes=NodesInfo(positions=DeviceArray([[-1.2700e-02,  1.0858e+00,  8.0000e-03],
             [ 2.2000e-03, -6.0000e-03,  2.0000e-03],
             [ 1.0117e+00,  1.4638e+00,  3.0000e-04],
             [-5.4080e-01,  1.4475e+00, -8.7660e-01],
             [-5.2380e-01,  1.4379e+00,  9.0640e-01],
             [-4.0400e-02,  1.0241e+00,  6.2600e-02],
             [ 1.7300e-02,  1.2500e-02, -2.7400e-02],
             [ 9.1580e-01,  1.3587e+00, -2.8800e-02],
             [-5.2030e-01,  1.3435e+00, -7.7550e-01],
             [-3.4400e-02,  9.7750e-01,  7.6000e-03],
             [ 6.4800e-02,  2.0600e-02,  1.5000e-03],
             [ 8.7180e-01,  1.3008e+00,  7.0000e-04],
             [ 5.9950e-01,  0.0000e+00,  1.0000e+00],
             [-5.9950e-01,  0.0000e+00,  1.0000e+00],
             [-1.6616e+00,  0.0000e+00,  1.0000e+00],
             [ 1.6616e+00,  0.0000e+00,  1.0000e+00]], dtype=float32), atomic_numbers=DeviceArray([6, 1, 1, 1, 1, 7, 1, 1, 1, 8, 1, 1, 6, 6, 1, 1], dtyp

In [8]:
centers = e3nn.scatter_sum(batched.nodes.positions, nel=batched.n_node) / batched.n_node[:, jnp.newaxis]

In [9]:
batched.nodes.positions.shape

(16, 3)

In [11]:
jnp.repeat(centers, batched.n_node, axis=0)

DeviceArray([[-0.01267998,  1.0857999 ,  0.00802   ],
             [-0.01267998,  1.0857999 ,  0.00802   ],
             [-0.01267998,  1.0857999 ,  0.00802   ],
             [-0.01267998,  1.0857999 ,  0.00802   ],
             [-0.01267998,  1.0857999 ,  0.00802   ],
             [ 0.0931    ,  0.9347    , -0.192275  ],
             [ 0.0931    ,  0.9347    , -0.192275  ],
             [ 0.0931    ,  0.9347    , -0.192275  ],
             [ 0.0931    ,  0.9347    , -0.192275  ],
             [ 0.30073333,  0.7663    ,  0.00326667],
             [ 0.30073333,  0.7663    ,  0.00326667],
             [ 0.30073333,  0.7663    ,  0.00326667],
             [ 0.        ,  0.        ,  1.        ],
             [ 0.        ,  0.        ,  1.        ],
             [ 0.        ,  0.        ,  1.        ],
             [ 0.        ,  0.        ,  1.        ]], dtype=float32)

In [49]:
new_positions = batched.nodes.positions - jnp.repeat(centers, batched.n_node, axis=0)

In [50]:
e3nn.scatter_sum(new_positions, nel=batched.n_node) / batched.n_node[:, jnp.newaxis]

DeviceArray([[-2.3841858e-08,  4.7683717e-08,  1.1920929e-08],
             [ 0.0000000e+00, -1.4901161e-08,  0.0000000e+00],
             [ 1.9868216e-08, -1.9868216e-08, -1.5522043e-10],
             [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00]],            dtype=float32)

In [52]:
jnp.concatenate([jnp.array([0]), jnp.cumsum(batched.n_node[:-1])])

DeviceArray([ 0,  5,  9, 12], dtype=int32)

In [26]:
GlobalsInfo = namedtuple("GlobalsInfo", ["stop", "target_position", "target_atomic_number"])
mg = list(map(lambda t: t[0]._replace(globals=GlobalsInfo(False, t[1], t[1]*2)), zip(molecule_graphs, range(1, 1+len(molecule_graphs)))))

In [33]:
key = jax.random.PRNGKey(0)
graph_1 = jraph.GraphsTuple(nodes=jax.random.normal(key, (3, 64)),
                  edges=jax.random.normal(key, (5, 64)),
                  senders=jnp.array([0,0,1,1,2]),
                  receivers=jnp.array([1,2,0,2,1]),
                  n_node=jnp.array([3]),
                  n_edge=jnp.array([5]),
                  globals=jax.random.normal(key, (1, 64)))
graph_2 = jraph.GraphsTuple(nodes=jax.random.normal(key, (5, 64)),
                  edges=jax.random.normal(key, (10, 64)),
                  senders=jnp.array([0,0,1,1,2,2,3,3,4,4]),
                  receivers=jnp.array([1,2,0,2,1,0,2,1,3,2]),
                  n_node=jnp.array([5]),
                  n_edge=jnp.array([10]),
                  globals=jax.random.normal(key, (1, 64)))
batch_toy = jraph.batch([graph_1, graph_2])

In [35]:
batch_toy.globals[0]

DeviceArray([ 1.2799144 , -0.39865986, -0.5993886 , -0.7637496 ,
             -0.8983587 , -0.23934689, -1.1704856 , -1.5712125 ,
             -0.15792206,  1.8573927 , -0.04482319, -0.3219256 ,
              0.8962585 , -0.15877262, -0.43754193, -1.5164709 ,
             -1.3206071 ,  0.67375886, -0.12202369, -0.11324466,
              0.24248654,  1.120143  , -0.85038954, -1.4255388 ,
              0.44291094, -0.81818557, -1.6559185 ,  0.72194505,
              0.5261831 ,  0.6698593 ,  0.7937533 ,  0.1336232 ,
              0.27458718,  1.0185868 ,  0.48991668,  1.6080183 ,
             -0.6469497 ,  0.5501491 ,  0.10167892,  0.2767068 ,
             -1.3188282 , -1.5068529 ,  0.8500454 ,  0.5480819 ,
              0.6765775 , -0.16574337, -0.6841306 ,  0.26070496,
              1.038879  , -0.41015548, -1.2598792 ,  0.24818595,
              1.1780113 ,  1.2147156 ,  0.11252337,  0.93331707,
              0.658087  , -0.06883331,  0.3875602 , -0.22836444,
              0.06383018,