In [1]:
import torch
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from typing import Tuple
import numpy as np
from torch_geometric.data import Data
import random
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from rdkit import RDLogger
from rdkit.Chem import RemoveHs
from datetime import datetime

import torch
import numpy as np
from torch_geometric.data import DataLoader
import os
import json
from tqdm import tqdm
import pickle
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing
from typing import Tuple, List, Optional
import copy
from dataclasses import dataclass

# Suppress RDKit warnings
RDLogger.DisableLog('rdApp.warning')
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')



In [2]:
class MolecularFeatureExtractor:
    def __init__(self):
        self.atom_list = list(range(1, 119))
        self.chirality_list = [
            Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
            Chem.rdchem.ChiralType.CHI_OTHER
        ]
        self.bond_list = [
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE, 
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC
        ]
        self.bonddir_list = [
            Chem.rdchem.BondDir.NONE,
            Chem.rdchem.BondDir.ENDUPRIGHT,
            Chem.rdchem.BondDir.ENDDOWNRIGHT
        ]

    def calc_atom_features(self, atom: Chem.Atom) -> Tuple[list, list]:
        """Calculate atom features with better error handling"""
        try:
            # Basic features
            atom_feat = [
                self.atom_list.index(atom.GetAtomicNum()),
                self.chirality_list.index(atom.GetChiralTag())
            ]

            # Physical features with error handling
            phys_feat = []
            
            # Molecular weight contribution
            try:
                contrib_mw = Descriptors.ExactMolWt(Chem.MolFromSmiles(f'[{atom.GetSymbol()}]'))
                phys_feat.append(contrib_mw)
            except:
                phys_feat.append(0.0)
                
            # LogP contribution    
            try:
                contrib_logp = Descriptors.MolLogP(Chem.MolFromSmiles(f'[{atom.GetSymbol()}]'))
                phys_feat.append(contrib_logp)
            except:
                phys_feat.append(0.0)
                
            # Add other physical properties
            phys_feat.extend([
                atom.GetFormalCharge(),
                int(atom.GetHybridization()),
                int(atom.GetIsAromatic()),
                atom.GetTotalNumHs(),
                atom.GetTotalValence(),
                atom.GetDegree()
            ])
            
            return atom_feat, phys_feat
            
        except Exception as e:
            print(f"Error calculating atom features: {e}")
            return [0, 0], [0.0] * 9

    def get_atom_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract atom features for the whole molecule"""
        atom_feats = []
        phys_feats = []
        
        if mol is None:
            return torch.tensor([[0, 0]], dtype=torch.long), torch.tensor([[0.0] * 9], dtype=torch.float)
            
        for atom in mol.GetAtoms():
            atom_feat, phys_feat = self.calc_atom_features(atom)
            atom_feats.append(atom_feat)
            phys_feats.append(phys_feat)

        x = torch.tensor(atom_feats, dtype=torch.long)
        phys = torch.tensor(phys_feats, dtype=torch.float)
        
        return x, phys
    
    def remove_unbonded_hydrogens(mol):
        params = Chem.RemoveHsParameters()
        params.removeDegreeZero = True
        mol = Chem.RemoveHs(mol, params)
        return mol


    def get_bond_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract bond features with better error handling"""
        if mol is None:
            return torch.tensor([[0], [0]], dtype=torch.long), torch.tensor([[0.0] * 5], dtype=torch.float)
            
        row, col, edge_feat = [], [], []
        
        for bond in mol.GetBonds():
            try:
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                
                # Add edges in both directions
                row += [start, end]
                col += [end, start]
                
                # Bond features
                bond_type = self.bond_list.index(bond.GetBondType())
                bond_dir = self.bonddir_list.index(bond.GetBondDir())
                
                # Calculate additional properties
                feat = [
                    bond_type,
                    bond_dir,
                    int(bond.GetIsConjugated()),
                    int(self._is_rotatable(bond)),
                    self._get_bond_length(mol, start, end)
                ]
                
                edge_feat.extend([feat, feat])
                
            except Exception as e:
                print(f"Error processing bond: {e}")
                continue

        if not row:  # If no valid bonds were processed
            return torch.tensor([[0], [0]], dtype=torch.long), torch.tensor([[0.0] * 5], dtype=torch.float)

        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(edge_feat, dtype=torch.float)
        
        return edge_index, edge_attr

    def _is_rotatable(self, bond: Chem.Bond) -> bool:
        """Check if bond is rotatable"""
        return (bond.GetBondType() == Chem.rdchem.BondType.SINGLE and 
                not bond.IsInRing() and
                len(bond.GetBeginAtom().GetNeighbors()) > 1 and
                len(bond.GetEndAtom().GetNeighbors()) > 1)

    def _get_bond_length(self, mol: Chem.Mol, start: int, end: int) -> float:
        """Get bond length with error handling"""
        try:
            conf = mol.GetConformer()
            if conf.Is3D():
                return Chem.rdMolTransforms.GetBondLength(conf, start, end)
        except:
            pass
        return 0.0

    def process_molecule(self, smiles: str) -> Data:
        """Process SMILES string to graph data"""
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                print(f"Invalid SMILES: {smiles}")
                return None  # Skip invalid molecules
            mol = RemoveHs(mol)

            # Add explicit hydrogens
            mol = Chem.AddHs(mol, addCoords=True)

            # Sanitize molecule
            Chem.SanitizeMol(mol)

            # Check if the molecule has atoms
            if mol.GetNumAtoms() == 0:
                print("Molecule has no atoms, skipping.")
                return None

            # Generate 3D coordinates
            if not mol.GetNumConformers():
                status = AllChem.EmbedMolecule(mol, AllChem.ETKDG())
                if status != 0:
                    print("Failed to generate 3D conformer")
                    return None  # Skip failed molecules

                # Try MMFF or UFF optimization
                try:
                    AllChem.MMFFOptimizeMolecule(mol)
                except:
                    AllChem.UFFOptimizeMolecule(mol)

            # Extract features
            x_cat, x_phys = self.get_atom_features(mol)
            edge_index, edge_attr = self.get_bond_features(mol)

            # Create data object with SMILES
            data = Data(
                x_cat=x_cat, 
                x_phys=x_phys,
                edge_index=edge_index, 
                edge_attr=edge_attr,
                num_nodes=x_cat.size(0)
            )

            # Store the SMILES as an attribute
            data.smiles = smiles

            return data

        except Exception as e:
            print(f"Error processing molecule {smiles}: {e}")
            return None

In [4]:
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
import torch
import numpy as np

class MolecularFeatureExtractor:
    """Robust feature extractor for molecule processing"""
    def __init__(self):
        pass

    def process_molecule(self, smiles: str):
        """
        Convert SMILES to graph data with robust feature extraction
        
        Args:
        - smiles (str): SMILES representation of the molecule
        
        Returns:
        - Processed molecule graph data
        """
        try:
            # Convert SMILES to RDKit molecule
            mol = Chem.MolFromSmiles(smiles)
            
            if mol is None:
                print(f"Invalid SMILES: {smiles}")
                return None
            
            # Add hydrogens
            mol = Chem.AddHs(mol)
            
            # Generate 3D coordinates (optional but helpful)
            try:
                AllChem.EmbedMolecule(mol, randomSeed=42)
            except:
                pass
            
            # Extract features
            x_cat = self._get_atom_categorical_features(mol)
            x_phys = self._get_atom_physical_features(mol)
            edge_index = self._get_edge_index(mol)
            edge_attr = self._get_edge_attributes(mol)
            
            # Create a simple data object mimicking PyG Data
            class SimpleData:
                def __init__(self, x_cat, x_phys, edge_index, edge_attr):
                    self.x_cat = torch.tensor(x_cat, dtype=torch.long)
                    self.x_phys = torch.tensor(x_phys, dtype=torch.float)
                    self.edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
                    self.edge_attr = torch.tensor(edge_attr, dtype=torch.float)
                    self.num_nodes = len(x_cat)
            
            return SimpleData(
                x_cat=x_cat,
                x_phys=x_phys,
                edge_index=edge_index,
                edge_attr=edge_attr
            )
        except Exception as e:
            print(f"Error processing molecule: {e}")
            import traceback
            traceback.print_exc()
            return None

    def _get_atom_categorical_features(self, mol):
        """
        Extract categorical features for atoms
        
        Features:
        - Atomic number
        - Chirality (simplified)
        """
        chirality_mapping = {
            Chem.rdchem.ChiralType.CHI_UNSPECIFIED: 0,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW: 1,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW: 2,
            Chem.rdchem.ChiralType.CHI_OTHER: 3
        }
        
        return [
            [
                atom.GetAtomicNum(),
                chirality_mapping.get(atom.GetChiralTag(), 0)
            ] 
            for atom in mol.GetAtoms()
        ]

    def _get_atom_physical_features(self, mol):
        """
        Extract physical features for atoms
        
        Features:
        - Formal charge
        - Hybridization
        - Aromaticity
        - Total number of hydrogens
        - Valence
        - Degree
        - Mass
        """
        return [
            [
                atom.GetFormalCharge(),
                int(atom.GetHybridization()),
                int(atom.GetIsAromatic()),
                atom.GetTotalNumHs(),
                atom.GetTotalValence(),
                atom.GetDegree(),
                atom.GetMass()
            ] 
            for atom in mol.GetAtoms()
        ]

    def _get_edge_index(self, mol):
        """
        Get edge indices for the molecular graph
        
        Returns bidirectional edges to represent an undirected graph
        """
        edges = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edges.append([i, j])
            edges.append([j, i])  # Add reverse edge for undirected graph
        return edges

    def _get_edge_attributes(self, mol):
        """
        Extract edge attributes for bonds
        
        Features:
        - Bond type
        - Is conjugated
        - Bond stereo
        """
        # Bond type mapping
        bond_type_mapping = {
            Chem.rdchem.BondType.SINGLE: 0,
            Chem.rdchem.BondType.DOUBLE: 1,
            Chem.rdchem.BondType.TRIPLE: 2,
            Chem.rdchem.BondType.AROMATIC: 3,
            Chem.rdchem.BondType.UNSPECIFIED: 4
        }
        
        # Bond stereo mapping
        bond_stereo_mapping = {
            Chem.rdchem.BondStereo.STEREONONE: 0,
            Chem.rdchem.BondStereo.STEREOANY: 1,
            Chem.rdchem.BondStereo.STEREOZ: 2,
            Chem.rdchem.BondStereo.STEREOE: 3,
            Chem.rdchem.BondStereo.STEREOCIS: 4,
            Chem.rdchem.BondStereo.STEREOTRANS: 5
        }
        
        edge_attrs = []
        for bond in mol.GetBonds():
            # Get bond type
            bond_type = bond_type_mapping.get(bond.GetBondType(), 4)
            
            # Check if conjugated
            is_conjugated = int(bond.GetIsConjugated())
            
            # Get bond stereo
            bond_stereo = bond_stereo_mapping.get(bond.GetStereo(), 0)
            
            # Create feature vector for each edge (both directions)
            feat = [
                bond_type,
                is_conjugated,
                bond_stereo
            ]
            edge_attrs.append(feat)
            edge_attrs.append(feat)  # Add for reverse edge
        
        return edge_attrs

def main():
    # Test the feature extractor
    extractor = MolecularFeatureExtractor()
    
    # Test molecules
    test_smiles = [
        'CC(=O)Nc1ccc(N)cc1',  # Acetaminophen
        'c1ccccc1',             # Benzene
        'CC(=O)O'               # Acetic acid
    ]
    
    for smiles in test_smiles:
        print(f"\nProcessing: {smiles}")
        data = extractor.process_molecule(smiles)
        
        if data:
            print("Categorical Features (x_cat):")
            print(data.x_cat)
            print("\nPhysical Features (x_phys):")
            print(data.x_phys)
            print("\nEdge Index:")
            print(data.edge_index)
            print("\nEdge Attributes:")
            print(data.edge_attr)
            print(f"Number of Nodes: {data.num_nodes}")
        else:
            print("Failed to process molecule")

if __name__ == '__main__':
    main()


Processing: CC(=O)Nc1ccc(N)cc1
Categorical Features (x_cat):
tensor([[6, 0],
        [6, 0],
        [8, 0],
        [7, 0],
        [6, 0],
        [6, 0],
        [6, 0],
        [6, 0],
        [7, 0],
        [6, 0],
        [6, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0]])

Physical Features (x_phys):
tensor([[ 0.0000,  4.0000,  0.0000,  0.0000,  4.0000,  4.0000, 12.0110],
        [ 0.0000,  3.0000,  0.0000,  0.0000,  4.0000,  3.0000, 12.0110],
        [ 0.0000,  3.0000,  0.0000,  0.0000,  2.0000,  1.0000, 15.9990],
        [ 0.0000,  3.0000,  0.0000,  0.0000,  3.0000,  3.0000, 14.0070],
        [ 0.0000,  3.0000,  1.0000,  0.0000,  4.0000,  3.0000, 12.0110],
        [ 0.0000,  3.0000,  1.0000,  0.0000,  4.0000,  3.0000, 12.0110],
        [ 0.0000,  3.0000,  1.0000,  0.0000,  4.0000,  3.0000, 12.0110],
        [ 0.0000,  3.0000,  1.0000,  0.0000,  4.0000,  3.0000