In [1]:
from pathlib import Path
from data import load_molecules_split, pad_molecule_data, augment_hop_edges_folded_type
import os
import sys
from tqdm import tqdm
from data import MoleculeData, BONDS

In [2]:
# Padding sizes used for every molecule below: N_MAX is passed to
# pad_molecule_data and langevin_sampling as the node budget.
N_MAX = 29
# Previously tried edge budget, kept for reference:
# E_MAX = 56
# N_MAX * (N_MAX - 1) is the number of ordered pairs of distinct nodes, i.e.
# enough slots for a fully connected directed graph on N_MAX nodes.
E_MAX = N_MAX * (N_MAX - 1)

In [3]:
# Resolve the data directory relative to the directory the notebook runs from.
absolute_path = Path.cwd().resolve()
load_dir = absolute_path / "preprocessed_data"
train_data = load_molecules_split(load_dir / "train", num_mols=2048)


def _augment_and_pad(mol):
    """Add hop edges to one molecule, then pad it to (N_MAX, E_MAX)."""
    with_hop_edges = augment_hop_edges_folded_type(mol, num_bond_types=len(BONDS))
    return pad_molecule_data(with_hop_edges, N_MAX, E_MAX)


# Same order as train_data, so index i refers to the same molecule in both lists.
padded_train_data = [
    (_augment_and_pad(mol), smi, block)
    for mol, smi, block in tqdm(train_data, desc="Padding molecules")
]


Loading split=train: 100%|██████████| 2048/2048 [00:08<00:00, 239.71it/s]
Padding molecules: 100%|██████████| 2048/2048 [00:34<00:00, 58.77it/s] 


In [4]:
from rdkit import Chem
from rdkit.Chem import AllChem, rdDistGeom
from rdkit.Geometry import Point3D
import py3Dmol
import numpy as np
import jax.numpy as jnp

In [5]:
def viz_molblock(molblock: str, width: int = 400, height: int = 400):
    """Show a MolBlock string in an inline py3Dmol viewer using stick style.

    Args:
        molblock: MolBlock text, handed to py3Dmol with format "mol".
        width: viewer width passed to py3Dmol.view.
        height: viewer height passed to py3Dmol.view.
    """
    stick_style = {"stick": {}}
    view = py3Dmol.view(width=width, height=height)
    view.addModel(molblock, "mol")
    view.setStyle(stick_style)
    view.zoomTo()
    view.show()

def viz_from_molblock_and_coords(
    molblock: str,
    coords,  # (N, 3) JAX/NumPy array
    width: int = 400,
    height: int = 400,
):
    """Render a molecule whose topology comes from `molblock` and whose 3D
    positions come from `coords`.

    Args:
        molblock: MolBlock text used as the atom/bond template (hydrogens kept).
        coords: array-like of shape (N, 3), one row per atom in the MolBlock.
        width: viewer width passed to py3Dmol.view.
        height: viewer height passed to py3Dmol.view.

    Raises:
        ValueError: if the MolBlock cannot be parsed, or if `coords` does not
            have shape (number of atoms in the MolBlock, 3).
    """
    # Parse the template without stripping explicit hydrogens.
    template = Chem.MolFromMolBlock(molblock, removeHs=False, sanitize=True)
    if template is None:
        raise ValueError("MolFromMolBlock failed. MolBlock may be invalid or incompatible.")

    xyz = np.asarray(coords, dtype=np.float64)
    atom_count = template.GetNumAtoms()
    expected_shape = (atom_count, 3)
    if xyz.shape != expected_shape:
        raise ValueError(f"coords shape {xyz.shape} does not match mol atoms {expected_shape}.")

    # Build a fresh conformer holding the supplied positions.
    replacement = Chem.Conformer(atom_count)
    for atom_idx, (x, y, z) in enumerate(xyz):
        replacement.SetAtomPosition(atom_idx, Point3D(float(x), float(y), float(z)))

    # Drop whatever conformers the MolBlock carried and install the new one.
    template.RemoveAllConformers()
    template.AddConformer(replacement, assignId=True)

    # py3Dmol consumes MolBlock text, so serialise the updated molecule again.
    updated_block = Chem.MolToMolBlock(template, confId=0)
    view = py3Dmol.view(width=width, height=height)
    view.addModel(updated_block, "mol")
    view.setStyle({"stick": {}})
    view.zoomTo()
    view.show()

In [22]:
import random

# Pick one training molecule uniformly at random. padded_train_data was built
# from train_data in the same order, so one index addresses both versions.
rand_idx = random.randrange(len(train_data))
# rand_idx = 0
sample_mol, smiles, molblock = train_data[rand_idx]
padded_sample_mol, p_smiles, p_molblock = padded_train_data[rand_idx]
# Guard against the two lists drifting out of alignment.
assert (smiles, molblock) == (p_smiles, p_molblock)
print(smiles)

C[C@@H](C#N)[C@@H](O)[C@H]1CN1


In [23]:
# Compare the number of atoms stored for the training sample with the atom
# counts implied by its SMILES, without and with explicit hydrogens.
training_n = sample_mol.pos.shape[0]
smiles_mol = Chem.MolFromSmiles(smiles)
n_smiles_heavy = smiles_mol.GetNumAtoms()
n_smiles_withH = Chem.AddHs(smiles_mol).GetNumAtoms()

print(training_n, n_smiles_heavy, n_smiles_withH)

19 9 19


In [24]:
# Show the sampled molecule exactly as stored in its MolBlock.
viz_molblock(molblock)

In [25]:
# Coordinates attached to the unpadded training sample (one row per atom).
true_coords = sample_mol.pos
print(true_coords.shape)

(19, 3)


In [26]:
# Same MolBlock topology, but drawn with the dataset coordinates; presumably a
# visual check that `pos` rows follow the MolBlock atom order.
viz_from_molblock_and_coords(molblock, true_coords)

In [27]:
def trim_conformer(coords_vmax3, node_mask_vmax):
    """Drop padded rows from a padded coordinate array.

    Args:
        coords_vmax3: array-like of shape (V_max, 3); JAX or NumPy.
        node_mask_vmax: array-like of shape (V_max,), bool or 0/1; truthy
            entries mark the rows to keep.

    Returns:
        NumPy float array of shape (V, 3) containing only the kept rows.
    """
    padded_xyz = np.asarray(coords_vmax3, dtype=float)
    keep_rows = np.asarray(node_mask_vmax).astype(bool)
    return padded_xyz[keep_rows]

# Recover the real-atom coordinates from the padded sample via its node mask.
padded_true_sample_coords = trim_conformer(padded_sample_mol.pos, padded_sample_mol.node_mask)

In [28]:
# Draw the un-padded coordinates on the padded sample's MolBlock; the picture
# should match the unpadded rendering if padding/trimming round-trips.
viz_from_molblock_and_coords(p_molblock, padded_true_sample_coords)

In [13]:
from flax import nnx
from model import DistanceScoreModel
import orbax.checkpoint as ocp

# Checkpoint location, relative to the notebook's working directory.
checkpoint_dir = absolute_path / 'qm9_diffusion_checkpoint'

ckpt_state = checkpoint_dir / "state"
checkpointer = ocp.StandardCheckpointer()

# RNG container reused later by langevin_sampling; declares a default stream
# plus named `params` and `sampling` streams.
rngs = nnx.Rngs(0, params=1, sampling=2)
# Restore the checkpoint back to its `nnx.State` structure - need an abstract reference.
# eval_shape builds the model abstractly; the result is split into its graph
# definition and an abstract state that serves as the restore target.
abstract_model = nnx.eval_shape(lambda: DistanceScoreModel(rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)

state_restored = checkpointer.restore(ckpt_state, abstract_state)
print('NNX State restored')
# Recombine the graph definition with the restored state into a usable model.
model = nnx.merge(graphdef, state_restored)



NNX State restored


In [29]:
from langevin_sampling import langevin_sampling, convert_dist_score_to_pos_score

# Noise schedule: 40 geometrically spaced sigma levels from 0.5 down to 0.02.
sigmas = jnp.geomspace(0.5, 0.02, num=40)
# Sample coordinates for the padded molecule with the restored model; the
# result is trimmed with the node mask before use, so it is presumably
# padded to N_MAX rows.
generated_conformation_padded = langevin_sampling(model, padded_sample_mol, sigmas, rngs, N_MAX)

Sigma levels:   0%|          | 0/40 [00:00<?, ?it/s]

sigma 0: 0.5




N_max: 29
E_max: 812
edge_index shape: (2, 812)
edge_mask shape: (812,)
distances shape: (812,)
dist_scores shape: (812,)
dist_scores stats: min/mean/max = -4.891007900238037 0.15930768847465515 4.91745662689209
t=0 pos_rms=0.445 score_rms=5.133 drift_rms=0.007701 noise_rms=0.05697
t=1 pos_rms=0.4565 score_rms=5.042 drift_rms=0.007564 noise_rms=0.05464
t=2 pos_rms=0.452 score_rms=5.174 drift_rms=0.007762 noise_rms=0.05508


Sigma levels:   2%|▎         | 1/40 [00:07<04:58,  7.64s/it]

sigma 1: 0.4603880047798157


Sigma levels:   5%|▌         | 2/40 [00:15<04:45,  7.51s/it]

sigma 2: 0.42391523718833923


Sigma levels:   8%|▊         | 3/40 [00:22<04:34,  7.43s/it]

sigma 3: 0.3903310298919678


Sigma levels:  10%|█         | 4/40 [00:29<04:28,  7.46s/it]

sigma 4: 0.3594096302986145


Sigma levels:  12%|█▎        | 5/40 [00:37<04:23,  7.53s/it]

sigma 5: 0.3309347331523895


Sigma levels:  15%|█▌        | 6/40 [00:44<04:14,  7.48s/it]

sigma 6: 0.3047189712524414


Sigma levels:  18%|█▊        | 7/40 [00:52<04:08,  7.52s/it]

sigma 7: 0.2805781066417694


Sigma levels:  20%|██        | 8/40 [01:00<04:02,  7.59s/it]

sigma 8: 0.2583498954772949


Sigma levels:  22%|██▎       | 9/40 [01:07<03:56,  7.63s/it]

sigma 9: 0.23788265883922577


Sigma levels:  25%|██▌       | 10/40 [01:15<03:49,  7.66s/it]

sigma 10: 0.2190372198820114


Sigma levels:  28%|██▊       | 11/40 [01:23<03:42,  7.68s/it]

sigma 11: 0.20168429613113403


Sigma levels:  30%|███       | 12/40 [01:31<03:35,  7.69s/it]

sigma 12: 0.18570634722709656


Sigma levels:  32%|███▎      | 13/40 [01:38<03:27,  7.68s/it]

sigma 13: 0.17099426686763763


Sigma levels:  35%|███▌      | 14/40 [01:46<03:19,  7.69s/it]

sigma 14: 0.1574474573135376


Sigma levels:  38%|███▊      | 15/40 [01:54<03:12,  7.70s/it]

sigma 15: 0.14497433602809906


Sigma levels:  40%|████      | 16/40 [02:02<03:06,  7.75s/it]

sigma 16: 0.1334889978170395


Sigma levels:  42%|████▎     | 17/40 [02:09<02:58,  7.75s/it]

sigma 17: 0.12291380763053894


Sigma levels:  45%|████▌     | 18/40 [02:17<02:50,  7.74s/it]

sigma 18: 0.1131763905286789


Sigma levels:  48%|████▊     | 19/40 [02:25<02:42,  7.76s/it]

sigma 19: 0.10421022772789001


Sigma levels:  50%|█████     | 20/40 [02:33<02:35,  7.77s/it]

sigma 20: 0.09595444053411484


Sigma levels:  52%|█████▎    | 21/40 [02:40<02:27,  7.77s/it]

sigma 21: 0.08835271000862122


Sigma levels:  55%|█████▌    | 22/40 [02:48<02:20,  7.82s/it]

sigma 22: 0.08135322481393814


Sigma levels:  57%|█████▊    | 23/40 [02:56<02:12,  7.78s/it]

sigma 23: 0.07490837574005127


Sigma levels:  60%|██████    | 24/40 [03:04<02:03,  7.75s/it]

sigma 24: 0.06897389143705368


Sigma levels:  62%|██████▎   | 25/40 [03:12<01:57,  7.85s/it]

sigma 25: 0.06350962817668915


Sigma levels:  65%|██████▌   | 26/40 [03:19<01:48,  7.77s/it]

sigma 26: 0.058478280901908875


Sigma levels:  68%|██████▊   | 27/40 [03:27<01:40,  7.72s/it]

sigma 27: 0.05384530499577522


Sigma levels:  70%|███████   | 28/40 [03:35<01:32,  7.68s/it]

sigma 28: 0.04957953840494156


Sigma levels:  72%|███████▎  | 29/40 [03:42<01:24,  7.67s/it]

sigma 29: 0.04565192386507988


Sigma levels:  75%|███████▌  | 30/40 [03:50<01:16,  7.68s/it]

sigma 30: 0.042035166174173355


Sigma levels:  78%|███████▊  | 31/40 [03:58<01:09,  7.69s/it]

sigma 31: 0.03870515897870064


Sigma levels:  80%|████████  | 32/40 [04:05<01:01,  7.70s/it]

sigma 32: 0.03563882037997246


Sigma levels:  82%|████████▎ | 33/40 [04:13<00:53,  7.71s/it]

sigma 33: 0.03281533718109131


Sigma levels:  85%|████████▌ | 34/40 [04:21<00:46,  7.72s/it]

sigma 34: 0.030215641483664513


Sigma levels:  88%|████████▊ | 35/40 [04:29<00:38,  7.71s/it]

sigma 35: 0.027821894735097885


Sigma levels:  90%|█████████ | 36/40 [04:36<00:30,  7.72s/it]

sigma 36: 0.025617795065045357


Sigma levels:  92%|█████████▎| 37/40 [04:44<00:23,  7.73s/it]

sigma 37: 0.023588282987475395


Sigma levels:  95%|█████████▌| 38/40 [04:52<00:15,  7.73s/it]

sigma 38: 0.021719571202993393


Sigma levels:  98%|█████████▊| 39/40 [05:00<00:07,  7.74s/it]

sigma 39: 0.01999887451529503


Sigma levels: 100%|██████████| 40/40 [05:07<00:00,  7.69s/it]


In [30]:

# Draw the sampled conformer: strip padding rows first, then overlay the
# coordinates on the molecule's MolBlock topology.
viz_from_molblock_and_coords(p_molblock, 
                             trim_conformer(generated_conformation_padded, padded_sample_mol.node_mask))

In [17]:
def mmff_relax_from_coords(
    molblock: str,
    coords,                  # (N, 3) array-like
    max_iters: int = 200,
    nonbonded_thresh: float = 100.0,  # keep default-ish unless you know you need tighter
):
    """Relax supplied coordinates with an MMFF force field.

    The molecule topology is taken from `molblock`; its conformer is replaced
    by `coords`, and that conformer is minimised in place. A single force
    field object is used for the "before" energy, the minimisation and the
    "after" energy, so all three always use the same MMFF variant and
    non-bonded threshold.

    Args:
        molblock: MolBlock text used as the atom/bond template (hydrogens kept).
        coords: array-like of shape (N, 3), one row per atom in the MolBlock.
        max_iters: maximum number of minimiser iterations.
        nonbonded_thresh: non-bonded threshold forwarded to the force field.

    Returns:
        dict with keys:
            "mol": the RDKit molecule holding the relaxed conformer (id 0),
            "coords_opt": (N, 3) float64 NumPy array of relaxed positions,
            "status": int returned by the minimiser (0 = converged,
                1 = more iterations needed),
            "energy_before": force-field energy of the input coordinates,
            "energy_after": force-field energy after minimisation.

    Raises:
        ValueError: if the MolBlock cannot be parsed, `coords` has the wrong
            shape, or MMFF properties / the force field cannot be built.
    """
    # Reconstruct molecule template with explicit H preserved
    mol = Chem.MolFromMolBlock(molblock, removeHs=False, sanitize=True)
    if mol is None:
        raise ValueError("MolFromMolBlock failed.")

    coords = np.asarray(coords, dtype=np.float64)
    n = mol.GetNumAtoms()
    if coords.shape != (n, 3):
        raise ValueError(f"coords shape {coords.shape} != {(n, 3)}")

    # Attach coords as conformer 0
    conf = Chem.Conformer(n)
    for i in range(n):
        x, y, z = coords[i]
        conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z)))

    mol.RemoveAllConformers()
    mol.AddConformer(conf, assignId=True)

    # Build MMFF properties (prefer MMFF94s, fall back to MMFF94)
    props = AllChem.MMFFGetMoleculeProperties(mol, mmffVariant="MMFF94s")
    if props is None:
        props = AllChem.MMFFGetMoleculeProperties(mol, mmffVariant="MMFF94")
    if props is None:
        raise ValueError("MMFF properties could not be constructed for this molecule.")

    # One force field for everything: previously the energies were evaluated
    # with `props` while the optimisation was hard-coded to MMFF94s, which
    # could disagree after the fallback above.
    ff = AllChem.MMFFGetMoleculeForceField(mol, props, confId=0, nonBondedThresh=nonbonded_thresh)
    if ff is None:
        # Fail with a clear message instead of AttributeError on None.
        raise ValueError("MMFF force field could not be constructed for this molecule.")
    ff.Initialize()

    e0 = float(ff.CalcEnergy())

    # Minimise in place; the force field operates on the molecule's conformer.
    status = ff.Minimize(maxIts=max_iters)

    e1 = float(ff.CalcEnergy())

    # Extract relaxed coords
    coords_opt = np.array(mol.GetConformer(0).GetPositions(), dtype=np.float64)

    return {
        "mol": mol,
        "coords_opt": coords_opt,
        "status": int(status),
        "energy_before": e0,
        "energy_after": e1,
    }


In [31]:
# Strip padding from the sampled conformer, then relax it with MMFF and report
# the minimiser status together with the energy before and after relaxation.
trimmed_generated_coords = trim_conformer(generated_conformation_padded, padded_sample_mol.node_mask)
out = mmff_relax_from_coords(molblock, trimmed_generated_coords, max_iters=200)
print("MMFF status:", out["status"])
print("Energy:", out["energy_before"], "->", out["energy_after"])

MMFF status: 0
Energy: 40.798873774036025 -> 34.43387249514519


In [32]:
# Draw the MMFF-relaxed version of the sampled conformer.
viz_from_molblock_and_coords(molblock, out["coords_opt"])

In [None]:
# Print the SMILES of three randomly chosen training molecules, checking that
# the padded and unpadded lists agree for each pick.
# NOTE(review): this loop rebinds rand_idx, sample_mol, smiles, molblock,
# padded_sample_mol, p_smiles and p_molblock at notebook scope. After running
# it, generated_conformation_padded no longer corresponds to `molblock` -
# re-run the sampling cell before re-running the visualisation/relaxation cells.
for _ in range(3):
    rand_idx = random.randint(0, len(train_data) - 1)
    sample_mol, smiles, molblock = train_data[rand_idx]
    padded_sample_mol, p_smiles, p_molblock = padded_train_data[rand_idx]
    assert smiles == p_smiles
    assert molblock == p_molblock
    print(smiles)

In [21]:
from IPython.display import display, HTML
import ipywidgets as widgets

def _mol_from_molblock_with_coords(molblock: str, coords_xyz: np.ndarray) -> Chem.Mol:
    """Parse `molblock` and replace its conformer with `coords_xyz`.

    Args:
        molblock: MolBlock text used as the atom/bond template (hydrogens kept).
        coords_xyz: array-like of shape (N, 3), one row per atom.

    Returns:
        RDKit Mol carrying exactly one conformer built from `coords_xyz`.

    Raises:
        ValueError: if parsing fails or the coordinate shape is not (N, 3).
    """
    parsed = Chem.MolFromMolBlock(molblock, removeHs=False, sanitize=True)
    if parsed is None:
        raise ValueError("MolFromMolBlock failed. MolBlock may be invalid.")

    xyz = np.asarray(coords_xyz, dtype=np.float64)
    atom_count = parsed.GetNumAtoms()
    expected_shape = (atom_count, 3)
    if xyz.shape != expected_shape:
        raise ValueError(f"coords shape {xyz.shape} != {expected_shape} for this molblock.")

    # Fill a new conformer row by row, then swap it in for the original one(s).
    new_conf = Chem.Conformer(atom_count)
    for atom_idx, (x, y, z) in enumerate(xyz):
        new_conf.SetAtomPosition(atom_idx, Point3D(float(x), float(y), float(z)))

    parsed.RemoveAllConformers()
    parsed.AddConformer(new_conf, assignId=True)
    return parsed


def _viewer_from_mol(mol: Chem.Mol, width: int, height: int, style=None) -> py3Dmol.view:
    """Build (but do not show) a py3Dmol viewer for conformer 0 of `mol`.

    Args:
        mol: RDKit molecule; its conformer 0 is serialised to a MolBlock.
        width: viewer width passed to py3Dmol.view.
        height: viewer height passed to py3Dmol.view.
        style: py3Dmol style dict; stick style is used when None.

    Returns:
        The configured py3Dmol viewer.
    """
    chosen_style = {"stick": {}} if style is None else style

    viewer = py3Dmol.view(width=width, height=height)
    viewer.addModel(Chem.MolToMolBlock(mol, confId=0), "mol")
    viewer.setStyle(chosen_style)
    viewer.zoomTo()
    return viewer

def _unpad_coords(coords_padded: np.ndarray, atom_mask: np.ndarray | None, num_atoms: int | None):
    """
    coords_padded: (N_MAX, 3)
    atom_mask: (N_MAX,) bool/int or None
    num_atoms: int or None
    Returns coords: (N,3)
    """
    coords_padded = np.asarray(coords_padded)

    if num_atoms is None:
        if atom_mask is None:
            raise ValueError("Need either atom_mask or num_atoms to unpad coords.")
        atom_mask = np.asarray(atom_mask).astype(bool)
        idx = np.nonzero(atom_mask)[0]
        if idx.size == 0:
            raise ValueError("atom_mask selects 0 atoms.")
        num_atoms = int(idx.size)

    return coords_padded[:num_atoms, :]

def viz_batch_gt_vs_sampled(
    molblocks: list[str],
    gt_coords_padded,          # (B, N_MAX, 3) JAX/NumPy
    sampled_coords_padded,     # (B, N_MAX, 3) JAX/NumPy
    atom_mask=None,            # (B, N_MAX) optional; 1/True for real atoms
    num_atoms=None,            # (B,) optional; integer per molecule
    max_items: int | None = 16,
    width: int = 350,
    height: int = 300,
    style=None,
    labels=("Reference", "Sampled"),
):
    """Display reference and sampled conformers side by side for a batch.

    Output is a vertical stack of rows; each row holds two viewers,
    [reference] [sampled], built from the same MolBlock with different
    coordinates.

    Args:
        molblocks: one MolBlock string per batch item.
        gt_coords_padded: padded reference coordinates, (B, N_MAX, 3).
        sampled_coords_padded: padded sampled coordinates, (B, N_MAX, 3).
        atom_mask: optional (B, N_MAX) mask of real atoms.
        num_atoms: optional (B,) atom counts. Supply this or `atom_mask`.
        max_items: cap on the number of rows shown; None shows the whole batch.
        width: width of each viewer.
        height: height of each viewer.
        style: py3Dmol style dict forwarded to each viewer.
        labels: two column captions (reference, sampled).

    Raises:
        ValueError: if the coordinate arrays' batch size differs from
            len(molblocks).
    """
    batch_size = len(molblocks)
    gt_all = np.asarray(gt_coords_padded)
    sampled_all = np.asarray(sampled_coords_padded)

    if not (gt_all.shape[0] == batch_size and sampled_all.shape[0] == batch_size):
        raise ValueError("Batch dimension mismatch between molblocks and coords arrays.")

    item_cap = batch_size if max_items is None else max_items
    n_shown = min(batch_size, int(item_cap))

    # Column captions go first, followed by one row per displayed molecule.
    left_title, right_title = labels[0], labels[1]
    stacked = [
        widgets.HBox(
            [
                widgets.HTML(f"<b>{left_title}</b>"),
                widgets.HTML(f"<b>{right_title}</b>"),
            ],
            layout=widgets.Layout(justify_content="space-around", width="100%")
        )
    ]

    for i in range(n_shown):
        # Per-item unpadding information (either may be None).
        n_i = None if num_atoms is None else int(num_atoms[i])
        mask_i = None if atom_mask is None else atom_mask[i]

        # Process reference and sampled coordinates stage by stage, in that order.
        pair_xyz = [
            _unpad_coords(padded, mask_i, n_i)
            for padded in (gt_all[i], sampled_all[i])
        ]
        pair_mols = [_mol_from_molblock_with_coords(molblocks[i], xyz) for xyz in pair_xyz]
        pair_views = [
            _viewer_from_mol(mol, width=width, height=height, style=style)
            for mol in pair_mols
        ]

        # NOTE(review): _make_html() is a private py3Dmol method, and whether
        # the HTML it returns renders inside widgets.HTML may depend on the
        # notebook front end - confirm in the target environment.
        html_gt, html_smp = [view._make_html() for view in pair_views]

        # Small index caption above the reference viewer; a blank spacer keeps
        # the sampled viewer vertically aligned with it.
        index_caption = widgets.HTML(f"<div style='font-size:12px; opacity:0.8;'>#{i}</div>")
        left_cell = widgets.VBox([index_caption, widgets.HTML(html_gt)])
        right_cell = widgets.VBox([widgets.HTML("&nbsp;"), widgets.HTML(html_smp)])

        stacked.append(
            widgets.HBox(
                [left_cell, right_cell],
                layout=widgets.Layout(justify_content="flex-start", align_items="flex-start", gap="12px")
            )
        )

    display(widgets.VBox(stacked, layout=widgets.Layout(gap="10px")))
