In [2]:
import numpy as np
from ogb.utils.features import atom_to_feature_vector, bond_to_feature_vector
from rdkit import Chem
import torch
from rdkit.Chem import AllChem

def mol2graph(mol: Chem.Mol) -> tuple:
    """
    Convert an RDKit molecule object to a graph representation.

    Args:
        mol (Chem.Mol): The RDKit molecule object.

    Returns:
        tuple: A tuple containing:
            - num_nodes (np.int16): The number of nodes in the graph.
            - edges (np.ndarray): The edges of the graph represented as an array of shape (2, num_edges).
            - node_features (np.ndarray): The node features represented as an array of shape (num_nodes, num_node_features).
            - edge_features (np.ndarray): The edge features represented as an array of shape (num_edges, num_edge_features).
    """
    # atoms
    atom_features_list = []
    for atom in mol.GetAtoms():
        atom_features_list.append(atom_to_feature_vector(atom))
    x = np.array(atom_features_list, dtype=np.int64)

    # bonds
    num_bond_features = 3  # bond type, bond stereo, is_conjugated
    if len(mol.GetBonds()) > 0:  # mol has bonds
        edges_list = []
        edge_features_list = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()

            edge_feature = bond_to_feature_vector(bond)

            # add edges in both directions
            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)

        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = np.array(edges_list, dtype=np.int64).T

        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = np.array(edge_features_list, dtype=np.int64)

    else:  # mol has no bonds
        edge_index = np.empty((2, 0), dtype=np.int64)
        edge_attr = np.empty((0, num_bond_features), dtype=np.int64)

    num_nodes = np.array(len(x), dtype=np.int16)
    edges = edge_index.T.astype(np.int16)
    edge_features = edge_attr.astype(np.int16)
    node_features = x.astype(np.int16)

    return num_nodes, edges, node_features, edge_features

In [3]:
smiles='CC(c1cc(Cc2c3cccc2)c3cc1)C(O)=O'
mol =Chem.MolFromSmiles(smiles)
num_nodes, edges, node_features, edge_features = mol2graph(mol)
node_features.shape

(18, 9)

In [3]:
NODE_FEATURES_OFFSET = 128
EDGE_FEATURES_OFFSET = 8
dtype = np.float32
NUM_CONFS=40

def floyd_warshall(A):
    n = A.shape[0]
    D = np.zeros((n,n), dtype=np.int16)
    
    for i in range(n):
        for j in range(n):
            if i == j:
                pass
            elif A[i,j] == 0:
                D[i,j] = 510
            else:
                D[i,j] = 1
    
    for k in range(n):
        for i in range(n):
            for j in range(n):
                old_dist = D[i,j]
                new_dist = D[i,k] + D[k,j]
                if new_dist < old_dist:
                    D[i,j] = new_dist
    return D

def preprocess_data(num_nodes, edges, node_feats, edge_feats):
    node_feats = node_feats + np.arange(1,node_feats.shape[-1]*NODE_FEATURES_OFFSET+1,
                                            NODE_FEATURES_OFFSET,dtype=np.int16)
    edge_feats = edge_feats + np.arange(1,edge_feats.shape[-1]*EDGE_FEATURES_OFFSET+1,
                                            EDGE_FEATURES_OFFSET,dtype=np.int16)
    
    A = np.zeros((num_nodes,num_nodes),dtype=np.int16)
    E = np.zeros((num_nodes,num_nodes,edge_feats.shape[-1]),dtype=np.int16)
    for k in range(edges.shape[0]):
        i,j = edges[k,0], edges[k,1]
        A[i,j] = 1
        E[i,j] = edge_feats[k]
    
    D = floyd_warshall(A)
    return node_feats, D, E

def coords2dist(mol):
    new_mol = Chem.AddHs(mol)
    res = AllChem.EmbedMultipleConfs(new_mol, numConfs=NUM_CONFS, numThreads=0)
    ### MMFF generates multiple conformations
    # res = AllChem.MMFFOptimizeMoleculeConfs(new_mol, numThreads=0)
    new_mol = Chem.RemoveHs(new_mol)
    index, _ = min(enumerate(res), key=lambda x: x[1])
    conf = new_mol.GetConformer(id=index)
    coords = conf.GetPositions()
    coords = coords[:new_mol.GetNumAtoms()].astype(dtype)
    coords = torch.tensor(coords)
    return torch.norm(coords.unsqueeze(-2) - coords.unsqueeze(-3), dim=-1)

In [14]:
node_features, distance_matrix, edge_features_matrix = preprocess_data(num_nodes, edges, node_features, edge_features)

In [5]:
dist_input = coords2dist(mol)

In [21]:
node_features.shape[-1]
edge_features.shape

(34, 3)

In [22]:
dist_mat

array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 8, 7, 6, 3, 2],
       [1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 2, 1],
       [2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 6, 5, 4, 3, 2],
       [3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 5, 4, 3, 2, 3],
       [4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 4, 3, 2, 1, 2],
       [5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 3, 2, 1, 2, 3],
       [6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 3, 2, 2, 3, 4],
       [7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 2, 1, 2, 4, 5],
       [8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 2, 3, 5, 6],
       [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 6, 7],
       [8, 7, 6, 5, 4, 3, 3, 2, 2, 1, 0, 1, 2, 5, 6],
       [7, 6, 5, 4, 3, 2, 2, 1, 2, 2, 1, 0, 1, 4, 5],
       [6, 5, 4, 3, 2, 1, 2, 2, 3, 3, 2, 1, 0, 3, 4],
       [3, 2, 3, 2, 1, 2, 3, 4, 5, 6, 5, 4, 3, 0, 1],
       [2, 1, 2, 3, 2, 3, 4, 5, 6, 7, 6, 5, 4, 1, 0]], dtype=int16)

In [27]:
edge_features.shape

(34, 3)