In [30]:
from torch_geometric.utils import to_dense_adj
from torch_geometric.datasets import QM9 
from tqdm import tqdm
import torch
import numpy as np

In [4]:
# Root directory of the PyG QM9 dataset.
path = '/mnt/mntsdb/genai/10_623_final_project/dataset/QM9'
dataset = QM9(root=path)

In [32]:

def one_hot_vector(val, lst):
    """Encode ``val`` as a one-hot list of booleans over the options in ``lst``.

    Values not found in ``lst`` are mapped to its last entry, so the final
    option acts as an "other/unknown" bucket.

    Args:
        val: value to encode.
        lst: ordered, non-empty sequence of allowed options.

    Returns:
        A list with one bool per option, ``True`` where the option equals
        ``val`` (or the last option, when ``val`` is not in ``lst``).

    Raises:
        ValueError: if ``lst`` is empty.
    """
    if len(lst) == 0:
        raise ValueError("one_hot_vector needs at least one option in `lst`")
    if val not in lst:
        val = lst[-1]
    # A concrete list (not a lazy, single-use ``map``) can be reused,
    # measured with len(), and indexed.
    return [option == val for option in lst]



def get_node_features(
    mol,
) -> np.ndarray:
    """Build a one-hot node-feature matrix for a PyG-style molecule graph.

    Each row concatenates three one-hot segments built with ``one_hot_vector``
    (values outside a segment's option list fall into its last option):

    1. atomic number over ``[5, 6, 7, 8, 9, 15, 16, 17, 35, 53, 999]`` (11 cols).
       Hydrogen (1) is not in the list, so H atoms land in the ``999`` bucket.
    2. number of edges whose source is the atom, over ``[0, 1, 2, 3, 4, 5]``
       (6 cols).
    3. number of those neighbours with atomic number 1, over
       ``[0, 1, 2, 3, 4]`` (5 cols).

    Args:
        mol: object exposing an ``edge_index`` tensor of shape (2, E),
            ``num_nodes`` and ``z`` (per-atom atomic numbers).

    Returns:
        ``np.ndarray`` of shape ``(num_nodes, 22)`` and dtype float32.
    """
    atomic_numbers = [5, 6, 7, 8, 9, 15, 16, 17, 35, 53, 999]
    degrees = [0, 1, 2, 3, 4, 5]
    h_counts = [0, 1, 2, 3, 4]
    n_features = len(atomic_numbers) + len(degrees) + len(h_counts)

    N = mol.num_nodes
    if N == 0:
        # np.stack([]) raises; an atom-less molecule gets an empty (0, D) matrix.
        return np.zeros((0, n_features), dtype=np.float32)

    row, col = mol.edge_index
    feats = []
    for i in range(N):
        attrs: list[bool] = []
        # 1) atomic number
        attrs += one_hot_vector(int(mol.z[i].item()), atomic_numbers)
        # neighbours j of i: targets of directed edges i -> j
        nbrs = col[row == i]
        # 2) degree (count of outgoing edges; equals the bond count when
        #    edge_index stores both directions of every bond)
        attrs += one_hot_vector(int(nbrs.size(0)), degrees)
        # 3) H-neighbour count (atomic number == 1)
        h_count = int((mol.z[nbrs] == 1).sum().item())
        attrs += one_hot_vector(h_count, h_counts)

        feats.append(np.array(attrs, dtype=np.float32))

    # stack per-atom rows into a single (N, D) array
    return np.stack(feats, axis=0)


_atomic_number_to_symbol = {
    1: 'H',
    6: 'C',
    7: 'N',
    8: 'O',
    9: 'F',
}

def get_qm9_features(mol):
    """Collect the per-molecule arrays used downstream from a PyG QM9 graph.

    Args:
        mol: object exposing ``edge_index``, ``num_nodes``, ``pos`` and ``z``.

    Returns:
        Tuple ``(node_features, adj, dist, pos, symbols)``:
          - node_features: ``get_node_features(mol)``.
          - adj: dense (num_nodes, num_nodes) adjacency built from ``edge_index``.
          - dist: (num_nodes, num_nodes) pairwise Euclidean distances of ``pos``.
          - pos: atom coordinates as a NumPy array.
          - symbols: list of element symbols; atomic numbers missing from
            ``_atomic_number_to_symbol`` become the placeholder string ``'999'``.
    """
    N = mol.num_nodes
    # Pin the matrix size to num_nodes. Otherwise to_dense_adj infers it from
    # the largest index in edge_index, which under-sizes the matrix when the
    # highest-indexed atom(s) have no edges (or the graph has no edges at all).
    adj = to_dense_adj(mol.edge_index, max_num_nodes=N)[0]

    pos = mol.pos
    dist = torch.cdist(pos, pos)  # p=2 by default -> Euclidean distances

    # atomic symbols (mol.z holds atomic numbers as ints)
    symbols = [_atomic_number_to_symbol.get(z, '999') for z in mol.z.tolist()]

    node_features = get_node_features(mol)

    return node_features, adj.cpu().numpy(), dist.cpu().numpy(), pos.cpu().numpy(), symbols

In [11]:
# Index molecules by SMILES. Duplicate SMILES collapse into a single entry
# (the later molecule wins). NOTE(review): confirm dropping those is intended.
dataset_ss = {}
for mol in tqdm(dataset):
    dataset_ss[mol.smiles] = mol

  0%|                                                                                        | 0/130831 [00:00<?, ?it/s]

100%|████████████████████████████████████████████████████████████████████████| 130831/130831 [00:02<00:00, 56428.11it/s]


In [33]:
# SMILES -> (node_features, adj, dist, pos, symbols) from get_qm9_features.
features = {}
for ss, mol in tqdm(dataset_ss.items()):
    features[ss] = get_qm9_features(mol)

100%|██████████████████████████████████████████████████████████████████████████| 130729/130729 [02:38<00:00, 825.80it/s]


In [34]:
# Sanity check: display the stored feature tuple for methane (CH4).
features['[H]C([H])([H])[H]']

(array([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
         0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0.,
         0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0.,
         0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0.,
         0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0.,
         0., 1., 0., 0., 0., 0.]], dtype=float32),
 array([[0., 1., 1., 1., 1.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.]], dtype=float32),
 array([[0.       , 1.0919182, 1.0919425, 1.0918945, 1.0919341],
        [1.0919182, 0.       , 1.7830887, 1.783101 , 1.7831048],
        [1.0919425, 1.7830887, 0.       , 1.7831084, 1.7831008],
        [1.0918945, 1.783101 , 1.7831084, 0.       , 1.7831068],
        [1.0919341, 1

In [39]:
def dict_to_npz(data_dict: dict[str, tuple], npz_path: str):
    """Save a SMILES-keyed dict of feature tuples to a compressed ``.npz`` file.

    The archive holds six parallel 1-D arrays (``smiles``, ``feats``, ``adj``,
    ``dist``, ``pos``, ``symbols``); entry ``k`` of each belongs to the same
    molecule. All but ``smiles`` are object arrays, so reading them back
    requires ``np.load(npz_path, allow_pickle=True)``.

    Args:
        data_dict: maps SMILES -> ``(feats, adj, dist, pos, symbols)``.
        npz_path: destination path passed to ``np.savez_compressed``.
    """
    smiles_list = list(data_dict.keys())
    values = list(data_dict.values())
    if values:
        feats_list, adj_list, dist_list, pos_list, symbols_list = zip(*values)
    else:
        # zip(*[]) yields nothing to unpack; store empty arrays instead.
        feats_list = adj_list = dist_list = pos_list = symbols_list = ()

    def make_array(items):
        # np.array(items, dtype=object) builds a multi-dimensional array (or
        # raises) when per-molecule arrays share leading dimensions; filling a
        # pre-allocated 1-D object array keeps exactly one entry per molecule.
        arr = np.empty(len(items), dtype=object)
        for k, item in enumerate(items):
            arr[k] = item
        return arr

    smiles_arr = np.array(smiles_list, dtype='<U')  # or dtype=object if very long
    feats_arr  = make_array(feats_list)
    adj_arr    = make_array(adj_list)
    dist_arr   = make_array(dist_list)
    pos_arr    = make_array(pos_list)
    syms_arr   = make_array(symbols_list)

    np.savez_compressed(
        npz_path,
        smiles = smiles_arr,
        feats  = feats_arr,
        adj    = adj_arr,
        dist   = dist_arr,
        pos    = pos_arr,
        symbols= syms_arr,
    )


In [89]:
# Save the QM9-geometry features to one compressed archive. This uses its own
# name rather than re-binding `path`, which the QM9 loading cell uses as the
# dataset root (re-running that cell afterwards would otherwise load from here).
FEATURES_NPZ_PATH = "/mnt/mntsdb/genai/10_623_final_project/dataset/quantum_features.npz"
dict_to_npz(features, FEATURES_NPZ_PATH)

### Make Low-Quality Geometries (RDKit ETKDG embeddings, no force-field optimization)

In [82]:
from rdkit import Chem
from rdkit.Chem import AllChem
from concurrent.futures import ProcessPoolExecutor, as_completed

In [72]:
# Index of the hot entry in an edge_attr row -> RDKit bond type
# (0=single, 1=double, 2=triple, 3=aromatic).
_BOND_TYPE_MAP = dict(enumerate((
    Chem.BondType.SINGLE,
    Chem.BondType.DOUBLE,
    Chem.BondType.TRIPLE,
    Chem.BondType.AROMATIC,
)))

def embed_with_rdkit(
    data,
    symbols: list[str],
    random_seed: int = -1,
) -> tuple[np.ndarray, np.ndarray]:
    """Generate an RDKit ETKDG conformer for a molecule given as a graph.

    The molecule is rebuilt from ``symbols`` (one atom per entry, in order)
    and the bonds in ``data.edge_index`` / ``data.edge_attr``. It is then
    sanitized and embedded with ETKDG. No force-field optimization is run.

    Args:
        data: graph object exposing ``edge_index`` (2, E) and ``edge_attr``
            whose per-edge argmax indexes ``_BOND_TYPE_MAP``.
        symbols: element symbols, one per atom, indexed like ``edge_index``.
        random_seed: ETKDG seed. With the default ``-1``, the RDKit embedding
            parameters are left untouched. Any other value is written to
            ``params.randomSeed`` so the embedding is reproducible.

    Returns:
        ``(positions, dist)``: an (N, 3) float32 coordinate array and the
        (N, N) pairwise distance matrix, both NumPy arrays.

    Raises:
        ValueError: if ETKDG fails to embed the molecule. RDKit may also raise
            while building or sanitizing the molecule.
    """
    em = Chem.RWMol()
    for sym in symbols:
        em.AddAtom(Chem.Atom(sym))

    N = len(symbols)
    row, col = data.edge_index
    edge_attrs = data.edge_attr  # shape [num_edges*2, 4] for directed edges
    for i, j, attr in zip(row.tolist(), col.tolist(), edge_attrs.tolist()):
        # only add each undirected bond once
        if i < j:
            bond_idx = int(np.argmax(attr))     # 0=single,1=double,2=triple,3=aromatic
            em.AddBond(i, j, _BOND_TYPE_MAP[bond_idx])

    mol = em.GetMol()
    Chem.SanitizeMol(mol)

    params = AllChem.ETKDG()
    if random_seed != -1:
        params.randomSeed = random_seed
    conf_id = AllChem.EmbedMolecule(mol, params)
    if conf_id < 0:
        # EmbedMolecule reports failure with -1 instead of raising.
        raise ValueError(f"ETKDG embedding failed for a molecule with {N} atoms")
    # Force-field refinement is left disabled (presumably to keep these
    # geometries low quality):
    # AllChem.UFFOptimizeMolecule(mol)

    conf = mol.GetConformer(conf_id)
    coords = []
    for i in range(N):
        pt = conf.GetAtomPosition(i)
        coords.append((pt.x, pt.y, pt.z))

    positions = torch.tensor(coords, dtype=torch.float32)

    return positions.cpu().numpy(), torch.cdist(positions, positions).cpu().numpy()


In [87]:
# Embed at most MAX_MOLECULES molecules (in `features` order) with RDKit.
MAX_MOLECULES = 50_000  # named constant so the built-in max() is not shadowed

low_quality_features = {}
embed_failures = {}  # smiles -> "ExcType: message", for post-hoc inspection

for count, ss in enumerate(tqdm(features, total=min(MAX_MOLECULES, len(features)))):
    if count >= MAX_MOLECULES:
        break
    try:
        low_quality_features[ss] = embed_with_rdkit(dataset_ss[ss], features[ss][-1])
    except Exception as exc:
        # Best-effort: skip molecules RDKit cannot build, sanitize or embed,
        # but keep the reason instead of discarding it.
        embed_failures[ss] = f"{type(exc).__name__}: {exc}"

fails = len(embed_failures)
print(fails)


  0%|                                                                              | 27/130729 [00:00<08:10, 266.29it/s]

 38%|█████████████████████████████                                               | 49999/130729 [13:20<21:32, 62.46it/s]

5100





In [90]:
pos_path = "/mnt/mntsdb/genai/10_623_final_project/dataset/low_quality_positions_50K.npz"
dist_path = "/mnt/mntsdb/genai/10_623_final_project/dataset/low_quality_dist_matricies_50K.npz"

# Each value is the (positions, distance_matrix) pair returned by
# embed_with_rdkit. Split the pairs into two archives keyed by SMILES.
positions = {}
dists = {}
for ss, embedded in low_quality_features.items():
    positions[ss] = embedded[0]
    dists[ss] = embedded[1]

np.savez_compressed(pos_path, **positions)
np.savez_compressed(dist_path, **dists)