In [2]:
from rdkit import Chem

In [3]:
suppl = Chem.SDMolSupplier("../data/raw/gdb9.sdf")

mol = suppl[0]

print(mol.GetNumAtoms())

for atom in mol.GetAtoms():
    print(atom.GetSymbol(), atom.GetHybridization(), atom.GetDegree())

for bond in mol.GetBonds():
    print(bond.GetBeginAtomIdx(), 
          bond.GetEndAtomIdx(),
          bond.GetBondType())

for i in range(mol.GetNumAtoms()):
    conf = mol.GetConformer()
    print(conf.GetAtomPosition(i))  # 3D coordinates

1
C SP3 0
<rdkit.Geometry.rdGeometry.Point3D object at 0x11dada2c0>


In [38]:
conf = mol.GetConformer()
x = conf.GetAtomPosition(0)
x.x, x.y, x.z

(0.5995, 0.0, 1.0)

In [48]:
x.DirectionVector(conf.GetAtomPosition(1)).x

-1.0

In [59]:
import numpy as np

data = np.load('../data/processed/qm9_dense.npz')

In [54]:
data["nodes"].shape, data["edges"].shape

((131970, 29, 14), (131970, 29, 29, 5))

In [57]:
data['nodes'][0]

array([[0.  , 1.  , 0.  , 0.  , 0.  , 2.55, 0.  , 0.  , 1.  , 0.  , 0.  ,
        1.  , 0.  , 0.  ],
       [1.  , 0.  , 0.  , 0.  , 0.  , 2.2 , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.25, 0.  , 0.  ],
       [1.  , 0.  , 0.  , 0.  , 0.  , 2.2 , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.25, 0.  , 0.  ],
       [1.  , 0.  , 0.  , 0.  , 0.  , 2.2 , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.25, 0.  , 0.  ],
       [1.  , 0.  , 0.  , 0.  , 0.  , 2.2 , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.25, 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0. 

In [21]:
from rdkit import Chem
from mgd.dataset.encoding import encode_molecule
mol = Chem.MolFromSmiles("CCO")
encode_molecule(mol, feature_style="flat").keys()

dict_keys(['nodes', 'edges', 'node_mask', 'pair_mask', 'bond_mask'])

In [4]:
%load_ext autoreload
%autoreload 2
import mgd
import jax, numpy as np
from mgd.dataset.dataloader import GraphBatchLoader as Loader
splits = dict(np.load("../data/processed/qm9_splits.npz"))
data = dict(np.load("../data/processed/qm9_dense.npz"))
loader = Loader(data, indices=splits["train"], batch_size=64, key=jax.random.PRNGKey(0))
batch = next(iter(loader))
# batch["nodes"].shape

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
batch.hybrid

Array([[3, 2, 2, ..., 0, 0, 0],
       [3, 2, 2, ..., 0, 0, 0],
       [3, 3, 3, ..., 0, 0, 0],
       ...,
       [3, 3, 3, ..., 0, 0, 0],
       [3, 3, 3, ..., 0, 0, 0],
       [3, 3, 2, ..., 0, 0, 0]], dtype=int32)

In [10]:
import jax
import jax.numpy as jnp
from mgd.model.embeddings import NodeEmbedder, EdgeEmbedder
from mgd.dataset.encoding import ATOM_VOCAB_SIZE, HYBRID_VOCAB_SIZE

model = NodeEmbedder(atom_vocab=ATOM_VOCAB_SIZE, hybrid_vocab=HYBRID_VOCAB_SIZE, atom_dim=8, hybrid_dim=4, cont_dim=8, hidden_dim=32)
variables = model.init(
    jax.random.PRNGKey(0),
    jnp.zeros((2, 29), dtype=jnp.int32),
    jnp.zeros((2, 29), dtype=jnp.int32),
    jnp.zeros((2, 29, 4)),
)
out = model.apply(
    variables,
    jnp.zeros((2, 29), dtype=jnp.int32),
    jnp.zeros((2, 29), dtype=jnp.int32),
    jnp.zeros((2, 29, 4)),
)
out.shape

(2, 29, 32)

In [11]:
from mgd.dataset.encoding import BOND_VOCAB_SIZE

model = EdgeEmbedder(edge_vocab=BOND_VOCAB_SIZE, edge_dim=16, hidden_dim=32)
variables = model.init(
    jax.random.PRNGKey(0),
    jnp.zeros((2, 29, 29), dtype=jnp.int32),
)
out = model.apply(
    variables,
    jnp.zeros((2, 29, 29), dtype=jnp.int32),
)
out.shape

(2, 29, 29, 32)

In [18]:
from mgd.model.utils import MLP, aggregate_node_edge

node_mask = jnp.ones((2, 29))
node_mask = node_mask.at[:, 5:].set(0)
bond_mask = jnp.ones((2, 29, 29))
bond_mask = (jnp.arange(29) <= 6).astype("float32")[:, None] * (jnp.arange(29) <= 6).astype("float32")[None, :]
bond_mask = jnp.repeat(bond_mask[None, ...], 2, 0)

print(node_mask.shape, bond_mask.shape)

ni = jax.random.normal(jax.random.PRNGKey(0), (2, 29, 16)) * node_mask[..., None]
nj = jax.random.normal(jax.random.PRNGKey(0), (2, 29, 16)) * node_mask[..., None]
eij = jax.random.normal(jax.random.PRNGKey(0), (2, 29, 29, 8)) * bond_mask[..., None]

(2, 29) (2, 29, 29)


In [20]:
from mgd.model.gnn_layers import MessagePassingLayer

model = MessagePassingLayer(node_dim=ni.shape[-1] , edge_dim=eij.shape[-1], mess_dim=8)
params = model.init(jax.random.PRNGKey(0), ni, eij, node_mask=node_mask, pair_mask=bond_mask)
nodes, edges = model.apply(params, ni, eij, node_mask=node_mask, pair_mask=bond_mask)

nodes[0, :, 0]

Array([ 1.0010208 ,  0.977972  , -0.05762177,  0.84743077,  0.4376264 ,
        0.        , -0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ], dtype=float32)

In [21]:
import jax, numpy as np
from mgd.dataset.dataloader import GraphBatchLoader
splits = dict(np.load("../data/processed/qm9_splits.npz"))
data = dict(np.load("../data/processed/qm9_dense.npz"))
loader = GraphBatchLoader(data, indices=splits["train"], batch_size=64, key=jax.random.PRNGKey(0))
batch = next(iter(loader))
None

In [22]:
from mgd.model.backbone import MPNNBackbone
import jax.numpy as jnp

atom_dim = 8
hybrid_dim = 8
cont_dim = 8

node_dim = 32
edge_dim = 16
mess_dim = 32
time_dim = node_dim

batch_size = batch.atom_type.shape[0]
times = jnp.ones((batch_size,))

model = MPNNBackbone(atom_dim, hybrid_dim, cont_dim, node_dim, edge_dim, mess_dim, time_dim)
params = model.init(jax.random.PRNGKey(0), batch, times)
nodes, edges = model.apply(params, batch, times)
nodes.shape, edges.shape

((64, 29, 32), (64, 29, 29, 16))

In [24]:
from mgd.model.denoiser import MPNNDenoiser
model = MPNNDenoiser(atom_dim, hybrid_dim, cont_dim, node_dim, edge_dim, mess_dim, time_dim)
params = model.init(jax.random.PRNGKey(0), batch, times)
nodes, edges = model.apply(params, batch, times)
nodes.shape, edges.shape

((64, 29, 32), (64, 29, 29, 16))