In [1]:
from rdkit.Chem.rdmolfiles import MolFromPDBFile
from rdkit.Chem.rdchem import Mol
from rdkit.Chem import AllChem

import numpy as np
import rdkit.Chem as Chem
from rdkit.Chem import AddHs, AssignStereochemistry, HybridizationType, ChiralType, BondStereo, MolFromMol2File
from rdkit.Chem.AllChem import ComputeGasteigerCharges
import os
import sys
sys.path.append("../../")
from src.data.pocket_utils import get_atom_coordinates, find_pocket_atoms_RDKit
from src.utils.constants import ES_THRESHOLD, METAL_OX_STATES, POCKET_THRESHOLD
from tqdm import tqdm 
from src.data.utils import pdb_to_rdkit_mol, mol2_to_rdkit_mol, get_vdw_radius, add_charges_to_molecule, get_node_features, get_edge_features, extract_charges_from_mol2, vdw_interactions

def get_edge_features(mol: Mol,
                    pairwise_function: callable,
                      ) -> tuple:
    
    """

    Extracts edge features from a given RDKit molecule object.
    Parameters:
        mol (rdkit.Chem.rdchem.Mol): RDKit molecule object.
        pocket_atom_indices (list): A list of ints containing the indices of the atoms in the pocket.
        pairwise_function (callable): A function that takes two atoms and returns edge features.

    Returns:
        edge_features (np.ndarray): A 2D array of shape (num_bonds, num_edge_features) containing edge features.
        edge_indices (np.ndarray): A 2D array of shape (num_bonds, 2) containing the indices of the atoms connected by each bond.
    
    """

    # Initialize a list to store edge features and indices
    edge_indices, edge_features = [], []

    #for every atom in the pocket, create an edge between atoms if they are within 4 angstroms of each other
    for i, atom1 in tqdm(enumerate(range(mol.GetNumAtoms()))):
    
        atom_i = mol.GetAtomWithIdx(atom1)

        for j, atom2 in tqdm(enumerate(range(mol.GetNumAtoms()))):
    
            atom_j = mol.GetAtomWithIdx(atom2)

            if j > i: #only consider the upper triangle of the matrix
    
                i_j_edge_features = pairwise_function(mol, atom1, atom2)

                if i_j_edge_features:

                    # Append edge indices to list, duplicating to account for both directions
                    edge_indices.append((atom1, atom2))
                    edge_indices.append((atom2, atom1))

                    # Append edge features to list, duplicating to account for both directions
                    edge_features.append(i_j_edge_features)
                    edge_features.append(i_j_edge_features)

    return np.array(edge_features, dtype='float64'), np.array(edge_indices, dtype='int64')

In [2]:
# Define test cases
test_cases = [("CCCCCCCCCC=C", 0, 10), # expected output: None
              ("C=C", 0, 1) # expected output: (0, 0, 0, 0, 0)
              ]

# Initialize RDKit molecules and conformers
molecules = []
for smiles, a1, a2 in test_cases:
    mol = Chem.MolFromSmiles(smiles)
    AllChem.EmbedMolecule(mol)
    molecules.append((mol, a1, a2))


[09:11:22] Molecule does not have explicit Hs. Consider calling AddHs()
[09:11:22] Molecule does not have explicit Hs. Consider calling AddHs()


In [3]:
#call edge features function with vdw_interactions function
for mol, a1, a2 in molecules:
    edge_features, edge_indices = get_edge_features(mol, vdw_interactions)
    print(edge_features, edge_indices)

11it [00:00, 861.80it/s]
11it [00:00, 1110.57it/s]
11it [00:00, 1160.22it/s]
11it [00:00, 1199.43it/s]
11it [00:00, 1638.40it/s]
11it [00:00, 1807.75it/s]
11it [00:00, 2243.49it/s]
11it [00:00, 3882.63it/s]
11it [00:00, 2996.32it/s]
11it [00:00, 4251.90it/s]
11it [00:00, 79137.81it/s]
11it [00:00, 109.06it/s]


[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]] [[ 0  1]
 [ 1  0]
 [ 0  2]
 [ 2  0]
 [ 0  3]
 [ 3  0]
 [ 1  2]
 [ 2  1]
 [ 1  3]
 [ 3  1]
 [ 2  3]
 [ 3  2]
 [ 2  4]
 [ 4  2]
 [ 2  5]
 [ 5  2]
 [ 3  4]
 [ 4  3]
 [ 3  5]
 [ 5  3]
 [ 4  5]
 [ 5  4]
 [ 4  6]


2it [00:00, 978.04it/s]
2it [00:00, 60787.01it/s]
2it [00:00, 406.58it/s]

[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]] [[0 1]
 [1 0]]



