In [30]:
from torch_geometric.utils import to_dense_adj
from torch_geometric.datasets import QM9 
from tqdm import tqdm
import torch
import numpy as np

In [4]:
# Root directory handed to torch_geometric's QM9 dataset class.
# NOTE(review): machine-specific absolute path — consider a configurable DATA_DIR.
path = '/mnt/mntsdb/genai/10_623_final_project/dataset/QM9'
dataset = QM9(root=path)

In [32]:

def one_hot_vector(val, lst):
    """Encode ``val`` as a one-hot list of floats over the choices in ``lst``.

    A value that does not appear in ``lst`` is encoded as the last entry of
    ``lst``, so callers can reserve that slot as an "other"/unknown bucket.

    Args:
        val: The value to encode.
        lst: Ordered sequence of allowed values. Must be non-empty; an empty
            sequence raises IndexError.

    Returns:
        A list of ``len(lst)`` floats: 1.0 at the position matching ``val``
        (or at the last position if ``val`` is unknown), 0.0 elsewhere.
    """
    if val not in lst:
        val = lst[-1]
    # Build a concrete list (not a lazy ``map``): it can be measured with
    # len(), compared, and iterated more than once.
    return [float(choice == val) for choice in lst]



def get_node_features(
    mol,
) -> np.ndarray:
    """Build a per-atom one-hot feature matrix for a molecular graph.

    For every atom ``i`` three one-hot segments are concatenated:

      1. atomic number over ``atom_choices``; anything not listed (including
         hydrogen, z == 1) falls into the final 999 "other" slot;
      2. number of edges ``i -> j`` in ``mol.edge_index`` over 0..5 (larger
         counts fall into the last slot);
      3. how many of those neighbours have atomic number 1, over 0..4.

    Args:
        mol: graph exposing ``edge_index`` (a 2-row index tensor), ``z``
            (per-atom atomic numbers) and ``num_nodes`` — presumably a
            torch_geometric ``Data`` object; confirm against the caller.

    Returns:
        float32 numpy array of shape (num_nodes, 22). A molecule with no atoms
        yields an empty (0, 22) array instead of raising.
    """
    # Feature vocabularies; they do not depend on the atom, so build them once.
    atom_choices = [5, 6, 7, 8, 9, 15, 16, 17, 35, 53, 999]
    degree_choices = [0, 1, 2, 3, 4, 5]
    h_count_choices = [0, 1, 2, 3, 4]
    n_features = len(atom_choices) + len(degree_choices) + len(h_count_choices)

    row, col = mol.edge_index
    num_atoms = mol.num_nodes
    feats = []
    for i in range(num_atoms):
        attrs: list[float] = []
        # 1) atomic number
        attrs += one_hot_vector(int(mol.z[i].item()), atom_choices)
        # Neighbours j of i via directed edges (i -> j). If each bond is stored
        # in both directions this count equals the number of bonds.
        nbrs = col[row == i]
        # 2) degree
        deg = int(nbrs.size(0))
        attrs += one_hot_vector(deg, degree_choices)
        # 3) hydrogen-neighbour count (atomic number == 1)
        h_count = int((mol.z[nbrs] == 1).sum().item())
        attrs += one_hot_vector(h_count, h_count_choices)

        feats.append(np.array(attrs, dtype=np.float32))

    # np.stack refuses an empty list, so return a correctly shaped empty matrix.
    if not feats:
        return np.zeros((0, n_features), dtype=np.float32)
    # stack into a single (N, D) array
    return np.stack(feats, axis=0)


_atomic_number_to_symbol = {
    1: 'H',
    6: 'C',
    7: 'N',
    8: 'O',
    9: 'F',
}

def get_qm9_features(mol):
    """Collect all per-molecule arrays used downstream into a single tuple.

    Args:
        mol: molecular graph exposing ``edge_index``, ``pos``, ``z`` and
            ``num_nodes`` (presumably a torch_geometric ``Data`` from QM9).

    Returns:
        Tuple ``(node_features, adj, dist, pos, symbols)`` where N is
        ``mol.num_nodes``:
          node_features -- float32 one-hot matrix from ``get_node_features``
          adj     -- (N, N) dense adjacency matrix built from ``edge_index``
          dist    -- (N, N) pairwise Euclidean distances between rows of ``pos``
          pos     -- atom coordinates as a numpy array
          symbols -- element symbol per atom; atomic numbers missing from
                     ``_atomic_number_to_symbol`` become the string '999'
    """
    num_atoms = mol.num_nodes

    # 1) adjacency. Pin the size to the atom count: without max_num_nodes,
    #    to_dense_adj infers it from the largest index in edge_index, so atoms
    #    at the end that take part in no edge would be dropped and adj would no
    #    longer line up with node_features / dist / pos.
    adj = to_dense_adj(mol.edge_index, max_num_nodes=num_atoms)[0]

    # 2) pairwise distances between atom positions
    pos = mol.pos
    dist = torch.cdist(pos, pos)

    # 3) atomic symbols
    z_array = mol.z.tolist()       # list of ints
    symbols = [_atomic_number_to_symbol.get(z, '999') for z in z_array]

    # 4) one-hot node features
    node_features = get_node_features(mol)

    return node_features, adj.cpu().numpy(), dist.cpu().numpy(), pos.cpu().numpy(), symbols

In [11]:
# Index molecules by SMILES string. Duplicate SMILES collapse into one entry
# (a later molecule overwrites an earlier one), so this dict can hold fewer
# items than the dataset.
dataset_ss = dict((m.smiles, m) for m in tqdm(dataset))

  0%|                                                                                        | 0/130831 [00:00<?, ?it/s]

100%|████████████████████████████████████████████████████████████████████████| 130831/130831 [00:02<00:00, 56428.11it/s]


In [33]:
# Compute the feature tuple for every unique SMILES.
features = dict(
    (smiles, get_qm9_features(molecule))
    for smiles, molecule in tqdm(dataset_ss.items())
)

100%|██████████████████████████████████████████████████████████████████████████| 130729/130729 [02:38<00:00, 825.80it/s]


In [34]:
# Spot-check one entry: methane, expected to contain 5 atoms (1 C + 4 H).
features['[H]C([H])([H])[H]']

(array([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
         0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0.,
         0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0.,
         0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0.,
         0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0.,
         0., 1., 0., 0., 0., 0.]], dtype=float32),
 array([[0., 1., 1., 1., 1.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.]], dtype=float32),
 array([[0.       , 1.0919182, 1.0919425, 1.0918945, 1.0919341],
        [1.0919182, 0.       , 1.7830887, 1.783101 , 1.7831048],
        [1.0919425, 1.7830887, 0.       , 1.7831084, 1.7831008],
        [1.0918945, 1.783101 , 1.7831084, 0.       , 1.7831068],
        [1.0919341, 1

In [39]:
def dict_to_npz(data_dict: dict[str, tuple], npz_path: str):
    """Save a SMILES-keyed feature dict to a compressed ``.npz`` archive.

    Each value must be a 5-tuple ``(feats, adj, dist, pos, symbols)``, as
    returned by ``get_qm9_features``. The archive holds one entry per field:

      smiles  -- unicode string array of the dict keys, in iteration order
      feats, adj, dist, pos, symbols -- 1-D object arrays with exactly one
                 element per molecule, aligned with ``smiles``

    The per-molecule values are stored as Python objects, so reading the
    archive back requires ``np.load(npz_path, allow_pickle=True)``.

    Args:
        data_dict: mapping from SMILES string to the 5-tuple described above.
        npz_path: destination file; ``np.savez_compressed`` appends ``.npz``
            if the name does not already end with it.
    """
    smiles_list = list(data_dict.keys())
    values = list(data_dict.values())
    # zip(*[]) yields nothing to unpack, so an empty dict gets empty columns.
    if values:
        feats_list, adj_list, dist_list, pos_list, symbols_list = zip(*values)
    else:
        feats_list = adj_list = dist_list = pos_list = symbols_list = ()

    def make_array(lst):
        # np.array(lst, dtype=object) guesses a shape from the contents: if all
        # items share a shape it builds an N-D array instead of a 1-D array of
        # per-molecule objects, and if only the leading dims agree it raises
        # "could not broadcast". Filling a preallocated 1-D object array item
        # by item always stores exactly one object per molecule.
        arr = np.empty(len(lst), dtype=object)
        for idx, item in enumerate(lst):
            arr[idx] = item
        return arr

    smiles_arr = np.array(smiles_list, dtype=str)
    feats_arr  = make_array(feats_list)
    adj_arr    = make_array(adj_list)
    dist_arr   = make_array(dist_list)
    pos_arr    = make_array(pos_list)
    syms_arr   = make_array(symbols_list)

    np.savez_compressed(
        npz_path,
        smiles = smiles_arr,
        feats  = feats_arr,
        adj    = adj_arr,
        dist   = dist_arr,
        pos    = pos_arr,
        symbols= syms_arr,
    )


In [41]:
path = "/mnt/mntsdb/genai/10_623_final_project/dataset/quantum_features.npz"
dict_to_npz(features, path)