In [1]:
from rdkit.Chem.rdmolfiles import MolFromPDBFile
from rdkit.Chem.rdchem import Mol
from rdkit.Chem import AllChem

import numpy as np
import rdkit.Chem as Chem
from rdkit.Chem import AddHs, AssignStereochemistry, HybridizationType, ChiralType, BondStereo, MolFromMol2File
from rdkit.Chem.AllChem import ComputeGasteigerCharges
import os
import sys
sys.path.append("../../")
from src.data.pocket_utils import get_atom_coordinates, find_pocket_atoms_RDKit
from src.utils.constants import ES_THRESHOLD, METAL_OX_STATES, POCKET_THRESHOLD
from tqdm import tqdm 
from src.data.utils import pdb_to_rdkit_mol, mol2_to_rdkit_mol, get_vdw_radius, add_charges_to_molecule, get_node_features, get_edge_features, extract_charges_from_mol2, vdw_interactions

In [8]:
# Define test cases
test_cases = [("CCCCCCCCCC=C", 0, 10), # expected output: None
              ("C=C", 0, 1) # expected output: (0, 0, 0, 0, 0)
              ]

# Initialize RDKit molecules and conformers
molecules = []
for smiles, a1, a2 in test_cases:
    mol = Chem.MolFromSmiles(smiles)
    AllChem.EmbedMolecule(mol)
    #add hydrogens
    mol = Chem.AddHs(mol)
    molecules.append((mol, a1, a2))


[09:22:44] Molecule does not have explicit Hs. Consider calling AddHs()
[09:22:44] Molecule does not have explicit Hs. Consider calling AddHs()


In [15]:
def get_node_features(mol: Mol) -> np.ndarray:
    """
    Extracts node features from a given RDKit molecule object.
    Parameters:

        mol (rdkit.Chem.rdchem.Mol): RDKit molecule object. It should have a "protein_or_ligand" property for each atom

    Returns:
        node_features (np.ndarray): A 2D array of shape (num_atoms, num_node_features) containing node features:
            (protein_or_ligand_id, atomic_num, atomic_mass, aromatic_indicator, ring_indicator, hybridization,
        chirality, num_heteroatoms, degree, num_hydrogens, formal_charge, num_radical_electrons)
    """
    
    node_features  = []

    # Iterate over each atom in the molecule and calculate node features
    for atom in mol.GetAtoms():

        protein_or_ligand_id = [-1 if atom.GetProp('protein_or_ligand') == 'protein' else 1][0]
        
        # Calculate node features
        atomic_num = atom.GetAtomicNum()
        atomic_mass = atom.GetMass()
        aromatic_indicator = int(atom.GetIsAromatic())
        ring_indicator = int(atom.IsInRing())
        hybridization_tag = atom.GetHybridization()

        if hybridization_tag == HybridizationType.SP:
            hybridization = 1
        elif hybridization_tag == HybridizationType.SP2:
            hybridization = 2
        elif hybridization_tag == HybridizationType.SP3:
            hybridization = 3
        else:
            hybridization = 0

        chiral_tag = atom.GetChiralTag()

        if chiral_tag == ChiralType.CHI_TETRAHEDRAL_CW:
            chirality = 1

        elif chiral_tag == ChiralType.CHI_TETRAHEDRAL_CCW:
            chirality = -1
        else:
            chirality = 0

        num_heteroatoms = len([bond for bond in atom.GetBonds() if bond.GetOtherAtom(atom).GetAtomicNum() != atom.GetAtomicNum()])
        degree = atom.GetDegree()
        num_hydrogens = len([bond for bond in atom.GetBonds() if bond.GetOtherAtom(atom).GetAtomicNum() == 1])

        formal_charge = atom.GetFormalCharge()
        num_radical_electrons = atom.GetNumRadicalElectrons()

        # Append node features to list
        node_features.append((protein_or_ligand_id, atomic_num, atomic_mass, aromatic_indicator, ring_indicator, hybridization,
        chirality, num_heteroatoms, degree, num_hydrogens, formal_charge, num_radical_electrons))

    return np.array(node_features, dtype='float64')

In [16]:
def tag_atoms(mol, tag):
    for atom in mol.GetAtoms():
        atom.SetProp('protein_or_ligand', tag)

#tag all of the atoms to be ligands in teh test case
for mol, a1, a2 in molecules:
    tag_atoms(mol, 'ligand')

In [17]:
#call edge features function with vdw_interactions function
for mol, a1, a2 in molecules:
    node_features = get_node_features(mol)
    print(node_features)
    print("\n")

"""
Recall: the desired output is:

 (protein_or_ligand_id, atomic_num, atomic_mass, aromatic_indicator, ring_indicator, hybridization,
        chirality, num_heteroatoms, degree, num_hydrogens, formal_charge, num_radical_electrons)
        
"""

atom 0 features: protein_or_ligand_id: 1, atomic_num: 6, atomic_mass: 12.011, aromatic_indicator: 0, ring_indicator: 0, hybridization: 3, chirality: 0, num_heteroatoms: 3, degree: 4, num_hydrogens: 3, formal_charge: 0, num_radical_electrons: 0
atom 1 features: protein_or_ligand_id: 1, atomic_num: 6, atomic_mass: 12.011, aromatic_indicator: 0, ring_indicator: 0, hybridization: 3, chirality: 0, num_heteroatoms: 2, degree: 4, num_hydrogens: 2, formal_charge: 0, num_radical_electrons: 0
atom 2 features: protein_or_ligand_id: 1, atomic_num: 6, atomic_mass: 12.011, aromatic_indicator: 0, ring_indicator: 0, hybridization: 3, chirality: 0, num_heteroatoms: 2, degree: 4, num_hydrogens: 2, formal_charge: 0, num_radical_electrons: 0
atom 3 features: protein_or_ligand_id: 1, atomic_num: 6, atomic_mass: 12.011, aromatic_indicator: 0, ring_indicator: 0, hybridization: 3, chirality: 0, num_heteroatoms: 2, degree: 4, num_hydrogens: 2, formal_charge: 0, num_radical_electrons: 0
atom 4 features: protein

'\nRecall: the desired output is:\n\n (protein_or_ligand_id, atomic_num, atomic_mass, aromatic_indicator, ring_indicator, hybridization,\n        chirality, num_heteroatoms, degree, num_hydrogens, formal_charge, num_radical_electrons)\n\nExpected output:\n[1 6 12 0 0 0 3 0 1 1 3 0 0]\n\n [[ 1.  6. 12.  0.  0.  3.  0.  3.  4.  0.  0.  0]]\n\n'