In [8]:
from rdkit import Chem

In [9]:
suppl = Chem.SDMolSupplier("../data/raw/gdb9.sdf")

mol = suppl[0]

print(mol.GetNumAtoms())

for atom in mol.GetAtoms():
    print(atom.GetSymbol(), atom.GetHybridization(), atom.GetDegree())

for bond in mol.GetBonds():
    print(bond.GetBeginAtomIdx(), 
          bond.GetEndAtomIdx(),
          bond.GetBondType())

for i in range(mol.GetNumAtoms()):
    conf = mol.GetConformer()
    print(conf.GetAtomPosition(i))  # 3D coordinates

1
C SP3 0
<rdkit.Geometry.rdGeometry.Point3D object at 0x16830bac0>


In [10]:
conf = mol.GetConformer()
x = conf.GetAtomPosition(0)
x.x, x.y, x.z

(-0.0127, 1.0858, 0.008)

In [11]:
x.DirectionVector(conf.GetAtomPosition(1)).x

[20:48:56] 

****
Range Error
atomId
Violation occurred on line 35 in file /Users/runner/work/rdkit-pypi/rdkit-pypi/build/temp.macosx-11.0-arm64-cpython-312/rdkit/Code/GraphMol/Conformer.cpp
Failed Expression: 1 < 1
****



RuntimeError: Range Error
	atomId
	Violation occurred on line 35 in file Code/GraphMol/Conformer.cpp
	Failed Expression: 1 < 1
	RDKIT: 2025.09.3
	BOOST: 1_85


In [13]:
import numpy as np

data = np.load('../data/processed/qm9_dense.npz')

In [17]:
data["node_mask"].shape

(131970, 29)

In [18]:
from rdkit import Chem
from mgd.dataset.encoding import encode_molecule
mol = Chem.MolFromSmiles("CCO")
encode_molecule(mol, feature_style="flat").keys()

dict_keys(['nodes', 'edges', 'node_mask', 'pair_mask'])

In [56]:
%load_ext autoreload
%autoreload 2
import mgd
import jax, numpy as np
from mgd.dataset.dataloader import GraphBatchLoader as Loader
splits = dict(np.load("../data/processed/qm9_splits.npz"))
data = dict(np.load("../data/processed/qm9_dense.npz"))
loader = Loader(data, indices=splits["train"], batch_size=64, key=jax.random.PRNGKey(0))
batch = next(iter(loader))
# batch["nodes"].shape

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [57]:
batch.hybrid

Array([[3, 2, 2, ..., 0, 0, 0],
       [3, 2, 2, ..., 0, 0, 0],
       [3, 3, 3, ..., 0, 0, 0],
       ...,
       [3, 3, 3, ..., 0, 0, 0],
       [3, 3, 3, ..., 0, 0, 0],
       [3, 3, 2, ..., 0, 0, 0]], dtype=int32)

In [58]:
import jax
import jax.numpy as jnp
from mgd.model.embeddings import NodeEmbedder, EdgeEmbedder, GraphEmbedder
from mgd.dataset.encoding import ATOM_VOCAB_SIZE, HYBRID_VOCAB_SIZE

model = NodeEmbedder(atom_vocab=ATOM_VOCAB_SIZE, hybrid_vocab=HYBRID_VOCAB_SIZE, atom_embed_dim=8, hybrid_embed_dim=4, cont_embed_dim=8, hidden_dim=32)
variables = model.init(
    jax.random.PRNGKey(0),
    jnp.zeros((2, 29), dtype=jnp.int32),
    jnp.zeros((2, 29), dtype=jnp.int32),
    jnp.zeros((2, 29, 4)),
)
out = model.apply(
    variables,
    jnp.zeros((2, 29), dtype=jnp.int32),
    jnp.zeros((2, 29), dtype=jnp.int32),
    jnp.zeros((2, 29, 4)),
)
out.shape

(2, 29, 32)

In [59]:
from mgd.dataset.encoding import BOND_VOCAB_SIZE

model = EdgeEmbedder(edge_vocab=BOND_VOCAB_SIZE, edge_embed_dim=16, hidden_dim=32)
variables = model.init(
    jax.random.PRNGKey(0),
    jnp.zeros((2, 29, 29), dtype=jnp.int32),
)
out = model.apply(
    variables,
    jnp.zeros((2, 29, 29), dtype=jnp.int32),
)
out.shape

(2, 29, 29, 32)

In [60]:
from mgd.model.embeddings import GraphEmbedder

model = GraphEmbedder(atom_embed_dim=8, hybrid_embed_dim=4, cont_embed_dim=8, node_hidden_dim=32, edge_embed_dim=16, edge_hidden_dim=16,
    atom_vocab_dim=ATOM_VOCAB_SIZE, hybrid_vocab_dim=HYBRID_VOCAB_SIZE, edge_vocab_dim=BOND_VOCAB_SIZE)
variables = model.init(
    jax.random.PRNGKey(0),
    batch,
)
out = model.apply(
    variables,
    batch,
)
out.node.shape, out.edge.shape

((64, 29, 32), (64, 29, 29, 16))

In [61]:
from mgd.model.utils import MLP, aggregate_node_edge

node_mask = jnp.ones((2, 29))
node_mask = node_mask.at[:, 5:].set(0)
bond_mask = jnp.ones((2, 29, 29))
bond_mask = (jnp.arange(29) <= 6).astype("float32")[:, None] * (jnp.arange(29) <= 6).astype("float32")[None, :]
bond_mask = jnp.repeat(bond_mask[None, ...], 2, 0)

print(node_mask.shape, bond_mask.shape)

ni = jax.random.normal(jax.random.PRNGKey(0), (2, 29, 16)) * node_mask[..., None]
nj = jax.random.normal(jax.random.PRNGKey(0), (2, 29, 16)) * node_mask[..., None]
eij = jax.random.normal(jax.random.PRNGKey(0), (2, 29, 29, 8)) * bond_mask[..., None]

(2, 29) (2, 29, 29)


In [62]:
from mgd.model.gnn_layers import MessagePassingLayer

model = MessagePassingLayer(node_dim=ni.shape[-1] , edge_dim=eij.shape[-1], mess_dim=8)
params = model.init(jax.random.PRNGKey(0), ni, eij, node_mask=node_mask, pair_mask=bond_mask)
nodes, edges = model.apply(params, ni, eij, node_mask=node_mask, pair_mask=bond_mask)

nodes[0, :, 0]

Array([ 1.0010208 ,  0.977972  , -0.05762177,  0.84743077,  0.4376264 ,
        0.        , -0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ], dtype=float32)

In [63]:
import jax, numpy as np
from mgd.dataset.dataloader import GraphBatchLoader
splits = dict(np.load("../data/processed/qm9_splits.npz"))
data = dict(np.load("../data/processed/qm9_dense.npz"))
loader = GraphBatchLoader(data, indices=splits["train"], batch_size=64, key=jax.random.PRNGKey(0))
batch = next(iter(loader))
None

In [67]:
from mgd.model.backbone import MPNNBackbone
import jax.numpy as jnp

node_dim = 32
edge_dim = 16
mess_dim = 32
time_dim = node_dim

batch_size = batch.atom_type.shape[0]
times = jnp.ones((batch_size,))

batch_size = batch.node_mask.shape[0]

nodes = jax.random.normal(jax.random.PRNGKey(0), (batch_size, 29, node_dim))
edges = jax.random.normal(jax.random.PRNGKey(0), (batch_size, 29, 29, edge_dim))

model = MPNNBackbone(node_dim, edge_dim, mess_dim, time_dim)
params = model.init(jax.random.PRNGKey(0), nodes, edges, times, node_mask=batch.node_mask, pair_mask=batch.pair_mask)
nodes, edges = model.apply(params, nodes, edges, times, node_mask=batch.node_mask, pair_mask=batch.pair_mask)
nodes.shape, edges.shape

((64, 29, 32), (64, 29, 29, 16))

In [74]:
atom_dim = 16
hybrid_dim = 8
cont_dim = 16

model = GraphEmbedder(atom_dim, hybrid_dim, cont_dim, node_dim, edge_dim, mess_dim, time_dim)
params = model.init(jax.random.PRNGKey(0), batch)
x = model.apply(params, batch)
x.node.shape, x.edge.shape

((64, 29, 32), (64, 29, 29, 32))

In [9]:
from mgd.model.denoiser import MPNNDenoiser
model = MPNNDenoiser(atom_dim, hybrid_dim, cont_dim, node_dim, edge_dim, mess_dim, time_dim)
params = model.init(jax.random.PRNGKey(0), batch, times)
nodes, edges = model.apply(params, batch, times)
nodes.shape, edges.shape

((64, 29, 32), (64, 29, 29, 16))

In [75]:
from mgd.model.utils import GraphLatent

gl = GraphLatent(nodes, edges)
gl

GraphLatent(node=Array([[[ 1.5143207 ,  1.6690284 , -0.44365567, ...,  0.15939057,
         -1.0356276 , -0.5286144 ],
        [-0.38800678, -0.29652295, -0.35470158, ...,  0.42621624,
         -1.4029492 , -0.06212837],
        [ 0.23104167, -1.3397987 , -0.611107  , ...,  0.44700438,
          1.1223726 , -1.5242276 ],
        ...,
        [ 1.413042  , -0.2801561 ,  2.5108225 , ..., -0.0928999 ,
          1.9829464 , -0.23259743],
        [ 0.89922374,  1.153322  , -0.23241627, ..., -2.1279385 ,
          1.1735545 , -0.06079815],
        [-2.008366  ,  1.1100187 , -1.4253883 , ..., -0.58610904,
         -0.6570578 , -0.6969856 ]],

       [[-0.08230473, -0.35031515,  0.97767913, ...,  1.3962959 ,
         -1.3213358 , -0.5057471 ],
        [-0.6929476 , -1.7806402 ,  1.8180368 , ...,  1.9349356 ,
          0.13841248, -1.3011549 ],
        [-0.37808174, -1.9823798 , -0.2231345 , ...,  2.5127335 ,
          1.5951207 , -2.6224718 ],
        ...,
        [ 0.6647554 ,  0.66204375,  0

In [80]:
from mgd.model.diffusion_model import GraphDiffusionModel
from mgd.model.denoiser import MPNNDenoiser
from mgd.diffusion.schedules import cosine_beta_schedule
from mgd.dataset.dataloader import GraphBatchLoader
import jax, numpy as np

atom_dim = 32
hybrid_dim = 16
cont_dim = 16

node_dim = 32
edge_dim = 16
mess_dim = 32
time_dim = 32

batch_size = 64

n_timesteps = 1000

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)

splits = dict(np.load("../data/processed/qm9_splits.npz"))
data = dict(np.load("../data/processed/qm9_dense.npz"))
loader = GraphBatchLoader(data, indices=splits["train"], batch_size=batch_size, key=k1)
batch = next(iter(loader))

times = jax.random.randint(k2, (batch_size,), 0, n_timesteps+1)

embedder = GraphEmbedder(atom_dim, hybrid_dim, cont_dim, node_dim, edge_dim, mess_dim, time_dim)
denoiser = MPNNDenoiser(atom_dim, hybrid_dim, cont_dim, node_dim, edge_dim, mess_dim, time_dim)
schedule = cosine_beta_schedule(n_timesteps)
model = GraphDiffusionModel(embedder, denoiser, schedule)

params = model.init(k3, batch, times)

NameError: name 'latent_from_scalar' is not defined

In [76]:
from mgd.diffusion.schedules import cosine_beta_schedule

cosine_beta_schedule(1000)

DiffusionSchedule(betas=Array([4.12464142e-05, 4.61339951e-05, 5.10215759e-05, 5.57899475e-05,
       6.07967377e-05, 6.55651093e-05, 7.03334808e-05, 7.52806664e-05,
       8.01682472e-05, 8.50558281e-05, 8.98241997e-05, 9.47713852e-05,
       9.95397568e-05, 1.04546547e-04, 1.09255314e-04, 1.14142895e-04,
       1.19030476e-04, 1.23977661e-04, 1.28686428e-04, 1.33693218e-04,
       1.38521194e-04, 1.43289566e-04, 1.48355961e-04, 1.53005123e-04,
       1.58011913e-04, 1.62839890e-04, 1.67667866e-04, 1.72615051e-04,
       1.77443027e-04, 1.82330608e-04, 1.87158585e-04, 1.92165375e-04,
       1.96874142e-04, 2.01761723e-04, 2.06768513e-04, 2.11417675e-04,
       2.16484070e-04, 2.21312046e-04, 2.26259232e-04, 2.31146812e-04,
       2.36034393e-04, 2.40743160e-04, 2.45809555e-04, 2.50637531e-04,
       2.55405903e-04, 2.60412693e-04, 2.65359879e-04, 2.70128250e-04,
       2.75075436e-04, 2.80022621e-04, 2.84910202e-04, 2.89738178e-04,
       2.94625759e-04, 2.99632549e-04, 3.04460526e-04