In [2]:
from rdkit import Chem

In [3]:
suppl = Chem.SDMolSupplier("../data/raw/gdb9.sdf")

mol = suppl[0]

print(mol.GetNumAtoms())

for atom in mol.GetAtoms():
    print(atom.GetSymbol(), atom.GetHybridization(), atom.GetDegree())

for bond in mol.GetBonds():
    print(bond.GetBeginAtomIdx(), 
          bond.GetEndAtomIdx(),
          bond.GetBondType())

for i in range(mol.GetNumAtoms()):
    conf = mol.GetConformer()
    print(conf.GetAtomPosition(i))  # 3D coordinates

1
C SP3 0
<rdkit.Geometry.rdGeometry.Point3D object at 0x11dada2c0>


In [38]:
conf = mol.GetConformer()
x = conf.GetAtomPosition(0)
x.x, x.y, x.z

(0.5995, 0.0, 1.0)

In [48]:
x.DirectionVector(conf.GetAtomPosition(1)).x

-1.0

In [59]:
import numpy as np

data = np.load('../data/processed/qm9_dense.npz')

In [54]:
data["nodes"].shape, data["edges"].shape

((131970, 29, 14), (131970, 29, 29, 5))

In [57]:
data['nodes'][0]

array([[0.  , 1.  , 0.  , 0.  , 0.  , 2.55, 0.  , 0.  , 1.  , 0.  , 0.  ,
        1.  , 0.  , 0.  ],
       [1.  , 0.  , 0.  , 0.  , 0.  , 2.2 , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.25, 0.  , 0.  ],
       [1.  , 0.  , 0.  , 0.  , 0.  , 2.2 , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.25, 0.  , 0.  ],
       [1.  , 0.  , 0.  , 0.  , 0.  , 2.2 , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.25, 0.  , 0.  ],
       [1.  , 0.  , 0.  , 0.  , 0.  , 2.2 , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.25, 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0. 

In [21]:
from rdkit import Chem
from mgd.dataset.encoding import encode_molecule
mol = Chem.MolFromSmiles("CCO")
encode_molecule(mol, feature_style="flat").keys()

dict_keys(['nodes', 'edges', 'node_mask', 'pair_mask', 'bond_mask'])

In [12]:
%load_ext autoreload
%autoreload 2
import mgd
import jax, numpy as np
from mgd.dataset.dataloader import QM9Loader
splits = dict(np.load("../data/processed/qm9_splits.npz"))
data = dict(np.load("../data/processed/qm9_dense.npz"))
loader = QM9Loader(data, indices=splits["train"], batch_size=64, key=jax.random.PRNGKey(0))
batch = next(iter(loader))
# batch["nodes"].shape

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [17]:
batch['hybrid_ids'].shape

(64, 29)

In [11]:
batch.keys()

dict_keys(['node_mask', 'pair_mask', 'bond_mask', 'atom_one_hot', 'hybrid_one_hot', 'node_continuous', 'edge_one_hot'])

In [None]:
import jax
import jax.numpy as jnp
from mgd.model.embeddings import NodeEmbedder, EdgeEmbedder
from mgd.dataset.encoding import ATOM_VOCAB_SIZE, HYBRID_VOCAB_SIZE

model = NodeEmbedder(atom_vocab=ATOM_VOCAB_SIZE, hybrid_vocab=HYBRID_VOCAB_SIZE, atom_dim=8, hybrid_dim=4, cont_dim=8, hidden_dim=32)
variables = model.init(
    jax.random.PRNGKey(0),
    jnp.zeros((2, 29), dtype=jnp.int32),
    jnp.zeros((2, 29), dtype=jnp.int32),
    jnp.zeros((2, 29, 4)),
)
out = model.apply(
    variables,
    jnp.zeros((2, 29), dtype=jnp.int32),
    jnp.zeros((2, 29), dtype=jnp.int32),
    jnp.zeros((2, 29, 4)),
)
out.shape

(2, 29, 32)

In [24]:
from mgd.dataset.encoding import BOND_VOCAB_SIZE

model = EdgeEmbedder(edge_vocab=BOND_VOCAB_SIZE, edge_dim=16, hidden_dim=32)
variables = model.init(
    jax.random.PRNGKey(0),
    jnp.zeros((2, 29, 29), dtype=jnp.int32),
)
out = model.apply(
    variables,
    jnp.zeros((2, 29, 29), dtype=jnp.int32),
)
out.shape

(2, 29, 29, 32)