In [1]:
import torch
import pickle
from rdkit import Chem
from rdkit.Chem import Draw
import shap
import numpy as np
import torch.nn.functional as F
from torch_geometric.data import DataLoader
import torch.nn as nn
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing
from sklearn.metrics.pairwise import cosine_similarity
from typing import Dict, List, Tuple, Optional
from torch_geometric.data import Data
from rdkit.Chem import RemoveHs
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from rdkit import RDLogger
# Suppress RDKit warnings
RDLogger.DisableLog('rdApp.warning')
import tensorflow as tensorflow
import traceback
import os
from datetime import datetime
from rdkit.Chem import rdDepictor
import matplotlib.pyplot as plt
from rdkit.Chem import AllChem, Draw, rdDepictor
from matplotlib.colors import LinearSegmentedColormap
from random import Random

import lime
import lime.lime_tabular
from rdkit import Chem
from rdkit.Chem import AllChem
from torch_geometric.data import Data
from typing import List, Tuple, Dict


class GraphDiscriminator(nn.Module):
    """Reimplementation of original discriminator architecture.

    Data flow of ``forward``: concatenated node features -> node MLP ->
    three ``GCNConv`` layers (ReLU after the first two) -> per-graph mean
    pooling -> projection MLP.

    Args:
        node_dim: Width of the concatenated ``x_cat`` + ``x_phys`` node features.
        edge_dim: Width of the edge attribute rows (input size of ``edge_encoder``).
        hidden_dim: Width of the hidden layers.
        output_dim: Width of the returned graph embedding.
    """
    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128, output_dim: int = 128):
        super().__init__()

        # Feature encoding
        self.node_encoder = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # The convolutions below are called with (x, edge_index) only, so the
        # output of this encoder never reaches the embedding. The module is
        # still registered so the parameter names of the state dict do not
        # change (load_encoder loads a saved state dict into this class).
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Graph convolution layers
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, output_dim)

        # Projection head
        self.projection = nn.Sequential(
            nn.Linear(output_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, data):
        """Embed every graph contained in ``data``.

        Args:
            data: Object exposing ``x_cat``, ``x_phys``, ``edge_index`` and
                ``batch`` attributes (a torch_geometric ``Data``/``Batch``
                in this notebook).

        Returns:
            Tensor with one projected embedding row per graph.
        """
        # Categorical indices are cast to float so they can be concatenated
        # with the physical features along the feature axis.
        x = torch.cat([data.x_cat.float(), data.x_phys], dim=-1)
        edge_index = data.edge_index

        # Initial feature encoding. Edge attributes are intentionally not
        # encoded here: the result was previously computed and discarded.
        x = self.node_encoder(x)

        # Graph convolutions
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)

        # Global pooling
        x = global_mean_pool(x, data.batch)

        # Projection
        return self.projection(x)

# Load Encoder
def load_encoder(model_path, device='cpu'):
    """Load a trained encoder from a checkpoint file.

    The checkpoint must be a mapping with a ``'model_info'`` entry (layer
    sizes) and an ``'encoder_state_dict'`` entry (weights).

    Args:
        model_path: Path of the checkpoint file.
        device: Device the weights are mapped to and the encoder is moved to.

    Returns:
        The ``GraphDiscriminator`` in eval mode, placed on ``device``.

    Raises:
        KeyError: If ``'model_info'`` or ``'encoder_state_dict'`` is absent.
        ValueError: If ``'node_dim'`` or ``'edge_dim'`` is missing from
            ``'model_info'``; these have no sensible default.
    """
    # SECURITY: torch.load may unpickle arbitrary objects depending on the
    # torch version / weights_only setting - only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    model_info = checkpoint['model_info']

    # Fail early with a readable message instead of handing None to nn.Linear.
    missing = [key for key in ('node_dim', 'edge_dim') if model_info.get(key) is None]
    if missing:
        raise ValueError(
            f"Checkpoint {model_path!r} has no {', '.join(missing)} in 'model_info'"
        )

    encoder = GraphDiscriminator(
        node_dim=model_info['node_dim'],
        edge_dim=model_info['edge_dim'],
        hidden_dim=model_info.get('hidden_dim', 128),
        output_dim=model_info.get('output_dim', 128)
    )
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    encoder.eval()
    return encoder.to(device)

# Load Embeddings
def load_embeddings(filepath):
    """Read a pickled mapping and return its embeddings and labels.

    Args:
        filepath: Path of a pickle file holding a mapping with the keys
            ``'embeddings'`` and ``'labels'``.

    Returns:
        Tuple ``(embeddings, labels)`` exactly as stored in the file.
    """
    # NOTE: pickle executes code on load - only open files you trust.
    with open(filepath, mode='rb') as handle:
        payload = pickle.load(handle)

    stored_embeddings = payload['embeddings']
    stored_labels = payload['labels']
    return stored_embeddings, stored_labels

# Paths from your saved model.
# Both are relative to the notebook's working directory; the timestamps in
# the file names identify one specific saved encoder / embedding file.
encoder_path = './checkpoints/encoders/final_encoder_20250216_111050.pt'
embedding_path = './embeddings/final_embeddings_20250216_111005.pkl'
# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load encoder and embeddings.
# `encoder` comes back in eval mode on `device` (see load_encoder).
encoder = load_encoder(encoder_path, device)
# `graph_data` receives the 'labels' entry of the pickle (see load_embeddings).
embeddings, graph_data = load_embeddings(embedding_path)


  checkpoint = torch.load(model_path, map_location=device)


In [2]:
class MolecularFeatureExtractor:
    """Turn SMILES strings / RDKit molecules into graph tensors.

    Atom features are split into a categorical part (``x_cat``) and a
    physical part (``x_phys``); bonds become ``edge_index`` plus
    ``edge_attr``. Failures are reported with ``print`` and replaced by
    placeholder rows (atoms), skipped (bonds) or turned into ``None``
    (whole molecules).
    """

    # Row widths of the feature matrices built below. Placeholder rows used
    # on failure must be as wide as real rows, otherwise the per-atom lists
    # become ragged and cannot be converted to a tensor.
    NUM_CAT_FEATURES = 2    # atomic-number index, chirality index
    NUM_PHYS_FEATURES = 8   # mass, logP, charge, hybridization, aromatic, Hs, valence, degree
    NUM_BOND_FEATURES = 5   # type, direction, conjugated, rotatable, length

    def __init__(self):
        self.atom_list = list(range(1, 119))
        self.chirality_list = [
            Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
            Chem.rdchem.ChiralType.CHI_OTHER
        ]
        self.bond_list = [
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE,
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC
        ]
        self.bonddir_list = [
            Chem.rdchem.BondDir.NONE,
            Chem.rdchem.BondDir.ENDUPRIGHT,
            Chem.rdchem.BondDir.ENDDOWNRIGHT
        ]
        # symbol -> (exact mass, logP) of the isolated element; filled lazily.
        self._element_descriptor_cache = {}

    def _element_descriptors(self, symbol: str) -> Tuple[float, float]:
        """Return ``(exact mass, logP)`` of the single-atom molecule ``[symbol]``.

        The values depend only on the element symbol, so they are computed
        once per symbol instead of twice per atom. A descriptor that cannot
        be computed is reported as ``0.0``.
        """
        cache = self.__dict__.setdefault('_element_descriptor_cache', {})
        if symbol not in cache:
            try:
                probe = Chem.MolFromSmiles(f'[{symbol}]')
            except Exception:
                probe = None

            try:
                mass = Descriptors.ExactMolWt(probe)
            except Exception:
                mass = 0.0

            try:
                logp = Descriptors.MolLogP(probe)
            except Exception:
                logp = 0.0

            cache[symbol] = (mass, logp)
        return cache[symbol]

    def calc_atom_features(self, atom: Chem.Atom) -> Tuple[list, list]:
        """Calculate the categorical and physical feature rows of one atom.

        Returns:
            ``(atom_feat, phys_feat)`` with ``NUM_CAT_FEATURES`` and
            ``NUM_PHYS_FEATURES`` entries. If anything fails (for example an
            atomic number or chiral tag that is not in the lookup lists) the
            error is printed and all-zero rows of the same widths are returned.
        """
        try:
            # Basic features: positions in the lookup lists, not raw values.
            atom_feat = [
                self.atom_list.index(atom.GetAtomicNum()),
                self.chirality_list.index(atom.GetChiralTag())
            ]

            # Element-level descriptors followed by per-atom properties.
            mass, logp = self._element_descriptors(atom.GetSymbol())
            phys_feat = [
                mass,
                logp,
                atom.GetFormalCharge(),
                int(atom.GetHybridization()),
                int(atom.GetIsAromatic()),
                atom.GetTotalNumHs(),
                atom.GetTotalValence(),
                atom.GetDegree()
            ]

            return atom_feat, phys_feat

        except Exception as e:
            print(f"Error calculating atom features: {e}")
            return [0] * self.NUM_CAT_FEATURES, [0.0] * self.NUM_PHYS_FEATURES

    def get_atom_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract atom features for the whole molecule.

        Returns:
            ``(x_cat, x_phys)``: a long tensor and a float tensor with one row
            per atom. ``mol is None`` yields a single all-zero row for each.
        """
        if mol is None:
            return (
                torch.tensor([[0] * self.NUM_CAT_FEATURES], dtype=torch.long),
                torch.tensor([[0.0] * self.NUM_PHYS_FEATURES], dtype=torch.float),
            )

        atom_feats = []
        phys_feats = []
        for atom in mol.GetAtoms():
            atom_feat, phys_feat = self.calc_atom_features(atom)
            atom_feats.append(atom_feat)
            phys_feats.append(phys_feat)

        # reshape keeps the column count even for a molecule without atoms.
        x = torch.tensor(atom_feats, dtype=torch.long).reshape(-1, self.NUM_CAT_FEATURES)
        phys = torch.tensor(phys_feats, dtype=torch.float).reshape(-1, self.NUM_PHYS_FEATURES)

        return x, phys

    @staticmethod
    def remove_unbonded_hydrogens(mol):
        """Return ``mol`` after ``Chem.RemoveHs`` with ``removeDegreeZero`` enabled.

        Declared static because it takes no ``self``; it can be called on the
        class or on an instance.
        """
        params = Chem.RemoveHsParameters()
        params.removeDegreeZero = True
        mol = Chem.RemoveHs(mol, params)
        return mol

    def get_bond_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract connectivity and bond features.

        Every bond contributes two directed edges that share one feature row.
        A bond whose features cannot be computed (for example a bond type or
        direction missing from the lookup lists) is printed and left out of
        both outputs.

        Returns:
            ``(edge_index, edge_attr)``. When ``mol`` is ``None`` or no bond
            could be processed, a placeholder is returned instead: one 0 -> 0
            edge with an all-zero feature row.
        """
        placeholder = (
            torch.tensor([[0], [0]], dtype=torch.long),
            torch.tensor([[0.0] * self.NUM_BOND_FEATURES], dtype=torch.float),
        )
        if mol is None:
            return placeholder

        row, col, edge_feat = [], [], []

        for bond in mol.GetBonds():
            try:
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()

                # Build the complete feature row first; only a bond that got
                # this far may be added to the edge lists, so edge_index and
                # edge_attr always have matching lengths.
                feat = [
                    self.bond_list.index(bond.GetBondType()),
                    self.bonddir_list.index(bond.GetBondDir()),
                    int(bond.GetIsConjugated()),
                    int(self._is_rotatable(bond)),
                    self._get_bond_length(mol, start, end)
                ]
            except Exception as e:
                print(f"Error processing bond: {e}")
                continue

            # Add edges in both directions
            row += [start, end]
            col += [end, start]
            edge_feat.extend([feat, feat])

        if not row:  # If no valid bonds were processed
            return placeholder

        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(edge_feat, dtype=torch.float)

        return edge_index, edge_attr

    def _is_rotatable(self, bond: Chem.Bond) -> bool:
        """Check if bond is rotatable: single, acyclic, both ends have >1 neighbor."""
        return (bond.GetBondType() == Chem.rdchem.BondType.SINGLE and
                not bond.IsInRing() and
                len(bond.GetBeginAtom().GetNeighbors()) > 1 and
                len(bond.GetEndAtom().GetNeighbors()) > 1)

    def _get_bond_length(self, mol: Chem.Mol, start: int, end: int) -> float:
        """Return the bond length from the molecule's 3D conformer, else ``0.0``.

        ``0.0`` is returned when the conformer lookup or the length
        computation fails, or when the conformer is not 3D.
        """
        try:
            conf = mol.GetConformer()
            if conf.Is3D():
                return Chem.rdMolTransforms.GetBondLength(conf, start, end)
        except Exception:
            # Best effort: a missing length is encoded as 0.0.
            pass
        return 0.0

    def process_molecule(self, smiles: str) -> Optional[Data]:
        """Process SMILES string to graph data.

        Returns:
            A ``Data`` object with ``x_cat``, ``x_phys``, ``edge_index``,
            ``edge_attr`` and ``num_nodes``, or ``None`` when the SMILES is
            invalid, has no atoms, cannot be embedded in 3D, or any other
            step raises (the reason is printed).
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                print(f"Invalid SMILES: {smiles}")
                return None  # Skip invalid molecules
            mol = RemoveHs(mol)

            # Add explicit hydrogens
            mol = Chem.AddHs(mol, addCoords=True)

            # Sanitize molecule
            Chem.SanitizeMol(mol)

            # Check if the molecule has atoms
            if mol.GetNumAtoms() == 0:
                print("Molecule has no atoms, skipping.")
                return None

            # Generate 3D coordinates
            if not mol.GetNumConformers():
                status = AllChem.EmbedMolecule(mol, AllChem.ETKDG())
                if status != 0:
                    print("Failed to generate 3D conformer")
                    return None  # Skip failed molecules

                # Try MMFF, fall back to UFF if MMFF raises.
                # NOTE(review): MMFFOptimizeMolecule appears to signal missing
                # parameters through its return code rather than an exception,
                # so the fallback may rarely trigger - confirm against RDKit docs.
                try:
                    AllChem.MMFFOptimizeMolecule(mol)
                except Exception:
                    AllChem.UFFOptimizeMolecule(mol)

            # Extract features
            x_cat, x_phys = self.get_atom_features(mol)
            edge_index, edge_attr = self.get_bond_features(mol)

            return Data(
                x_cat=x_cat,
                x_phys=x_phys,
                edge_index=edge_index,
                edge_attr=edge_attr,
                num_nodes=x_cat.size(0)
            )

        except Exception as e:
            print(f"Error processing molecule {smiles}: {e}")
            return None


In [3]:
import shap
import torch
import numpy as np
from torch_geometric.data import Batch, Data
from typing import List, Tuple
from matplotlib.colors import LinearSegmentedColormap

class GraphModelWrapper:
    """Adapter that lets array-based callers evaluate the graph model.

    Only the categorical node features are taken from the call argument;
    physical features, connectivity and edge attributes come from state the
    caller has to assign beforehand:

    * ``original_x_cat``  - reference categorical features (defines the shape)
    * ``original_x_phys`` - physical features passed through unchanged
    * ``batch``           - object providing ``edge_index``, ``edge_attr``
      and optionally ``batch``
    """

    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.model.eval()
        # Filled in by the caller before the first __call__.
        self.original_x_cat = None
        self.original_x_phys = None
        self.batch = None

    def __call__(self, features):
        """
        Custom call method to handle graph data
        features: feature matrix (numpy array, tensor or nested list) holding
            exactly one value per entry of ``original_x_cat``

        Returns the model output as a numpy array.

        Raises:
            RuntimeError: if the wrapper state has not been assigned yet.
        """
        if self.original_x_cat is None or self.original_x_phys is None or self.batch is None:
            raise RuntimeError(
                "GraphModelWrapper is not configured: assign original_x_cat, "
                "original_x_phys and batch before calling it"
            )

        with torch.no_grad():
            # Accept anything array-like, not just numpy arrays.
            if not isinstance(features, torch.Tensor):
                features = torch.tensor(np.asarray(features), dtype=torch.float)
            features = features.to(self.device)

            # Get original shapes
            num_nodes = self.original_x_cat.size(0)
            num_cat_features = self.original_x_cat.size(1)

            # Reshape features to match original dimensions; the cast to
            # long truncates any fractional perturbation.
            x_cat = features.reshape(num_nodes, num_cat_features).to(torch.long)

            # Create a new batch with modified features
            new_data = Data(
                x_cat=x_cat,
                x_phys=self.original_x_phys,
                edge_index=self.batch.edge_index,
                edge_attr=self.batch.edge_attr,
                batch=self.batch.batch if hasattr(self.batch, 'batch') else None
            )

            # Get model output
            outputs = self.model(new_data)
            return outputs.cpu().numpy()

class ModifiedGraphWrapper:
    """Map perturbed categorical node features to per-node importance scores.

    For every sample the score of a node is the L2 norm of the mean of its
    three graph-convolution outputs (ReLU applied to the first two). The
    model's pooling and projection head are not used. Physical features and
    connectivity are taken from the batch given at construction time.
    """

    def __init__(self, model, device, batch):
        self.model = model
        self.device = device
        self.original_batch = batch
        self.model.eval()
        self.num_nodes = batch.x_cat.shape[0]
        self.num_features = batch.x_cat.shape[1]
        self.num_phys_features = batch.x_phys.shape[1]

    def __call__(self, x):
        """Score one or several perturbed feature matrices.

        Args:
            x: Array-like holding ``num_nodes * num_features`` values per
                sample, flat or already shaped.

        Returns:
            numpy array with one row per sample and one column per node.

        Raises:
            Whatever the conversion or the model raises; shape information is
            printed first to help debugging.
        """
        # Defined up front so the error handler can report them safely.
        node_features = None
        node_embeddings = None

        with torch.no_grad():
            try:
                # Convert input to tensor
                if not isinstance(x, torch.Tensor):
                    x = torch.tensor(x, dtype=torch.float)
                x = x.to(device=self.device, dtype=torch.float)

                # One (num_nodes, num_features) matrix per sample.
                x = x.reshape(-1, self.num_nodes, self.num_features)

                print(f"Processing batch of size {x.shape[0]}")

                # Identical for every sample, so moved to the device once.
                x_phys = self.original_batch.x_phys.to(self.device)
                edge_index = self.original_batch.edge_index.to(self.device)

                all_results = []
                for idx in range(x.shape[0]):
                    # The cast to long truncates fractional perturbations.
                    x_cat = x[idx].to(torch.long)

                    # Get node features
                    node_features = torch.cat([x_cat.float(), x_phys], dim=-1)
                    x_encoded = self.model.node_encoder(node_features)

                    # Get intermediate representations
                    x1 = F.relu(self.model.conv1(x_encoded, edge_index))
                    x2 = F.relu(self.model.conv2(x1, edge_index))
                    x3 = self.model.conv3(x2, edge_index)

                    # Combine representations from different layers; stacking
                    # requires all three layers to have the same width.
                    combined_features = torch.stack([x1, x2, x3], dim=0)
                    node_embeddings = torch.mean(combined_features, dim=0)

                    # Compute node importance
                    node_importance = torch.norm(node_embeddings, dim=1).cpu().numpy()
                    all_results.append(node_importance)

                result = np.array(all_results)
                print(f"Result shape: {result.shape}")
                return result

            except Exception as e:
                print(f"Error in model wrapper: {e}")
                print(f"Debug info:")
                # x may still be the raw input if the conversion failed.
                print(f"x shape: {getattr(x, 'shape', None)}")
                if node_features is not None:
                    print(f"node_features shape: {node_features.shape}")
                if node_embeddings is not None:
                    print(f"node_embeddings shape: {node_embeddings.shape}")
                raise

In [4]:
class MolecularLIME:
    """LIME explainer for molecular graphs with improved type handling"""
    
    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.feature_names = [
            'AtomicNum', 'Valence', 'Degree', 
            'IsAromatic', 'Charge', 'NumHs',
            'InRing', 'IsHydrogen'
        ]
        # Define feature ranges for each feature type
        self.feature_ranges = {
            'AtomicNum': list(range(1, 119)),
            'Valence': list(range(7)),
            'Degree': list(range(7)),
            'IsAromatic': [0, 1],
            'Charge': [-1, 0, 1],
            'NumHs': list(range(5)),
            'InRing': [0, 1],
            'IsHydrogen': [0, 1]
        }
        
    def get_valid_states(self, atom: Chem.Atom, mol: Chem.Mol) -> List[Tuple[int, int, int]]:
        """Calculate valid states for an atom considering its environment.

        Candidate (charge, hydrogen count) combinations are generated for
        nitrogen, oxygen and carbon and kept when the hydrogen count is in
        range and the octet is not exceeded. Other elements yield no states.
        Progress is reported with ``print``.

        Args:
            atom: Atom whose states are enumerated.
            mol: Molecule the atom belongs to (not used by the computation).

        Returns:
            List of tuples (charge, bond_count, hydrogen_count)
        """
        atomic_num = atom.GetAtomicNum()
        atom_idx = atom.GetIdx()

        # Contribution of each bond type to the bond-order sum; any other
        # bond type counts as a bond but adds no order.
        order_by_type = {
            Chem.rdchem.BondType.SINGLE: 1,
            Chem.rdchem.BondType.DOUBLE: 2,
            Chem.rdchem.BondType.TRIPLE: 3,
            Chem.rdchem.BondType.AROMATIC: 1.5,
        }
        bonds = list(atom.GetBonds())
        bond_count = len(bonds)

        # Round bond sum to nearest integer
        bond_sum = round(sum(order_by_type.get(bond.GetBondType(), 0) for bond in bonds))
        print(f"\nAtom {atom.GetSymbol()}{atom_idx} analysis:")
        print(f"Current bond sum: {bond_sum}")
        print(f"Number of bonds: {bond_count}")

        # atomic number -> (valence electrons, max hydrogens,
        #                   [(charge, hydrogens when the atom has no bonds), ...])
        # Candidate order is the order in which states are returned.
        # NOTE(review): the cation entries for N and O give fewer hydrogens
        # than a protonated species would have - kept as in the original
        # tables; confirm the intended chemistry.
        element_rules = {
            7: (5, 3, [(0, 3), (1, 2), (-1, 2)]),   # Nitrogen: neutral, +1, -1
            8: (6, 2, [(0, 2), (-1, 1), (1, 1)]),   # Oxygen: neutral, -1, +1 (rare)
            6: (4, 4, [(0, 4), (1, 3), (-1, 3)]),   # Carbon: neutral, cation, anion
        }

        valid_states = []
        rule = element_rules.get(atomic_num)
        if rule is not None:
            valence_electrons, max_h, candidates = rule
            for charge, h_without_bonds in candidates:
                h_count = h_without_bonds - bond_count
                # Electrons in the valence shell: the atom's own (adjusted for
                # charge) plus ONE per hydrogen and one per unit of bond order.
                # Counting two per hydrogen rejected ordinary neutral states
                # such as CH4 (4 + 2*4 = 12 > 8).
                total_electrons = valence_electrons - charge + h_count + bond_sum
                if 0 <= h_count <= max_h and total_electrons <= 8:
                    valid_states.append((charge, bond_count, h_count))

        # Print valid states for debugging
        print(f"\nValid states for {atom.GetSymbol()}{atom_idx}:")
        for charge, bonds_in_state, h_count in valid_states:
            print(f"  Charge={charge}, Bonds={bonds_in_state}, H={h_count}")

        return valid_states

    def generate_neighborhood(self, mol: Chem.Mol, atom_idx: int, n_samples: int = 50) -> Tuple[List[Chem.Mol], np.ndarray]:
        """Generate chemically valid perturbations with proper state handling"""
        neighborhoods = []
        feature_vectors = []
        
        # Get the target atom
        target_atom = mol.GetAtomWithIdx(atom_idx)
        atomic_num = int(target_atom.GetAtomicNum())
        
        print(f"\nGenerating neighborhood for atom {target_atom.GetSymbol()}{atom_idx}")
        print(f"Initial state: Charge={target_atom.GetFormalCharge()}, "
              f"Valence={target_atom.GetTotalValence()}, "
              f"Explicit Hs={target_atom.GetNumExplicitHs()}, "
              f"Implicit Hs={target_atom.GetNumImplicitHs()}, "
              f"Aromatic={target_atom.GetIsAromatic()}")
        
        # Get valid states as tuples
        valid_states = self.get_valid_states(target_atom, mol)
        
        if not valid_states:
            print(f"No valid states found for {target_atom.GetSymbol()}{atom_idx}")
            return [], np.array([])
            
        for i in range(n_samples):
            try:
                # Create a copy of the molecule
                mol_copy = Chem.Mol(mol)
                atom_copy = mol_copy.GetAtomWithIdx(atom_idx)
                
                # Select a random valid state
                state_idx = i % len(valid_states)
                charge, h_count, bonds, aromatic, valence, degree = valid_states[state_idx]
                
                # Apply the state changes
                atom_copy.SetFormalCharge(int(charge))
                atom_copy.SetNumExplicitHs(int(h_count))
                atom_copy.SetIsAromatic(aromatic)
                
                # Create feature vector
                features = [
                    int(atomic_num),
                    int(valence),
                    int(degree),
                    int(aromatic),
                    int(charge),
                    int(h_count),
                    int(target_atom.IsInRing()),
                    int(atomic_num == 1)
                ]
                
                try:
                    Chem.SanitizeMol(mol_copy)
                    neighborhoods.append(mol_copy)
                    feature_vectors.append(features)
                    print(f"Sample {i}: Created state - Charge={charge}, H={h_count}, "
                          f"Bonds={bonds}, Aromatic={aromatic}, Valence={valence}, "
                          f"Degree={degree}")
                except Exception as e:
                    print(f"Sample {i}: Failed sanitization: {str(e)}")
                    continue
                    
            except Exception as e:
                print(f"Sample {i}: Failed with error: {str(e)}")
                continue
        
        print(f"\nGenerated {len(neighborhoods)} valid neighbors")
        return neighborhoods, np.array(feature_vectors)
        neighborhoods = []
        feature_vectors = []
        
        # Get the target atom
        target_atom = mol.GetAtomWithIdx(atom_idx)
        atomic_num = int(target_atom.GetAtomicNum())
        
        print(f"\nGenerating neighborhood for atom {target_atom.GetSymbol()}{atom_idx}")
        print(f"Initial state: Charge={target_atom.GetFormalCharge()}, "
              f"Valence={target_atom.GetTotalValence()}, "
              f"Explicit Hs={target_atom.GetNumExplicitHs()}, "
              f"Implicit Hs={target_atom.GetNumImplicitHs()}, "
              f"Aromatic={target_atom.GetIsAromatic()}")
        
        # Get valid states with all feature variations
        valid_states = self.get_valid_states(target_atom, mol)
        
        if not valid_states:
            print(f"No valid states found for {target_atom.GetSymbol()}{atom_idx}")
            return [], np.array([])
            
        for i in range(n_samples):
            try:
                # Create a copy of the molecule
                mol_copy = Chem.Mol(mol)
                atom_copy = mol_copy.GetAtomWithIdx(atom_idx)
                
                # Select a random valid state
                state_idx = i % len(valid_states)
                state = valid_states[state_idx]
                
                # Apply the state changes
                atom_copy.SetFormalCharge(int(state['charge']))
                atom_copy.SetNumExplicitHs(int(state['h_count']))
                atom_copy.SetIsAromatic(state['aromatic'])
                
                # Create feature vector with all variations
                features = [
                    int(atomic_num),
                    int(state['valence']),
                    int(state['degree']),
                    int(state['aromatic']),
                    int(state['charge']),
                    int(state['h_count']),
                    int(target_atom.IsInRing()),
                    int(atomic_num == 1)
                ]
                
                try:
                    Chem.SanitizeMol(mol_copy)
                    neighborhoods.append(mol_copy)
                    feature_vectors.append(features)
                    print(f"Sample {i}: Created state - Charge: {state['charge']}, "
                          f"H: {state['h_count']}, Aromatic: {state['aromatic']}, "
                          f"Valence: {state['valence']}, Degree: {state['degree']}")
                except Exception as e:
                    print(f"Sample {i}: Failed sanitization: {str(e)}")
                    continue
                    
            except Exception as e:
                print(f"Sample {i}: Failed with error: {str(e)}")
                continue
        
        print(f"\nGenerated {len(neighborhoods)} valid neighbors")
        return neighborhoods, np.array(feature_vectors)
        neighborhoods = []
        feature_vectors = []
        
        # Get the target atom
        target_atom = mol.GetAtomWithIdx(atom_idx)
        atomic_num = int(target_atom.GetAtomicNum())
        
        print(f"\nGenerating neighborhood for atom {target_atom.GetSymbol()}{atom_idx}")
        print(f"Initial state: Charge={target_atom.GetFormalCharge()}, "
              f"Valence={target_atom.GetTotalValence()}, "
              f"Explicit Hs={target_atom.GetNumExplicitHs()}, "
              f"Implicit Hs={target_atom.GetNumImplicitHs()}, "
              f"Aromatic={target_atom.GetIsAromatic()}")
        
        # Get valid states considering current bonding
        valid_states = self.get_valid_states(target_atom, mol)
        
        if not valid_states:
            print(f"No valid states found for {target_atom.GetSymbol()}{atom_idx}")
            return [], np.array([])
            
        for i in range(n_samples):
            try:
                # Create a copy of the molecule
                mol_copy = Chem.Mol(mol)
                atom_copy = mol_copy.GetAtomWithIdx(atom_idx)
                
                # Select a random valid state
                state_idx = i % len(valid_states)
                new_charge, new_bonds, new_h = valid_states[state_idx]
                
                # Modify atom properties
                atom_copy.SetFormalCharge(int(new_charge))
                atom_copy.SetNumExplicitHs(int(new_h))
                atom_copy.SetIsAromatic(bool(np.random.random() < 0.2 and atomic_num in [6, 7]))
                
                # Create feature vector
                features = [
                    int(atomic_num),
                    int(new_bonds + new_h),  # Total valence
                    int(atom_copy.GetDegree()),
                    int(atom_copy.GetIsAromatic()),
                    int(new_charge),
                    int(new_h),
                    int(atom_copy.IsInRing()),
                    int(atomic_num == 1)
                ]
                
                try:
                    Chem.SanitizeMol(mol_copy)
                    neighborhoods.append(mol_copy)
                    feature_vectors.append(features)
                    print(f"Sample {i}: Successfully created state (charge={new_charge}, bonds={new_bonds}, H={new_h}, "
                          f"aromatic={atom_copy.GetIsAromatic()})")
                except Exception as e:
                    print(f"Sample {i}: Failed sanitization: {str(e)}")
                    continue
                    
            except Exception as e:
                print(f"Sample {i}: Failed with error: {str(e)}")
                continue
        
        print(f"\nGenerated {len(neighborhoods)} valid neighbors")
        return neighborhoods, np.array(feature_vectors)
        neighborhoods = []
        feature_vectors = []
        
        # Get the target atom
        target_atom = mol.GetAtomWithIdx(atom_idx)
        atomic_num = int(target_atom.GetAtomicNum())
        
        print(f"\nGenerating neighborhood for atom {target_atom.GetSymbol()}{atom_idx}")
        print(f"Initial state: Charge={target_atom.GetFormalCharge()}, "
              f"Valence={target_atom.GetTotalValence()}, "
              f"Explicit Hs={target_atom.GetNumExplicitHs()}, "
              f"Implicit Hs={target_atom.GetNumImplicitHs()}")
        
        # Get valid states considering current bonding
        valid_states = self.get_valid_states(target_atom, mol)
        
        if not valid_states:
            print(f"No valid states found for {target_atom.GetSymbol()}{atom_idx}")
            return [], np.array([])
            
        for i in range(n_samples):
            try:
                # Create a copy of the molecule
                mol_copy = Chem.Mol(mol)
                atom_copy = mol_copy.GetAtomWithIdx(atom_idx)
                
                # Select a random valid state
                state_idx = i % len(valid_states)
                new_charge, new_bonds, new_h = valid_states[state_idx]
                
                # Apply the new state
                atom_copy.SetFormalCharge(int(new_charge))
                atom_copy.SetNumExplicitHs(int(new_h))
                
                # Create feature vector
                features = [
                    int(atomic_num),
                    int(new_bonds + new_h),  # Total valence
                    int(atom_copy.GetDegree()),
                    int(atom_copy.GetIsAromatic()),
                    int(new_charge),
                    int(new_h),
                    int(atom_copy.IsInRing()),
                    int(atomic_num == 1)
                ]
                
                try:
                    Chem.SanitizeMol(mol_copy)
                    neighborhoods.append(mol_copy)
                    feature_vectors.append(features)
                    print(f"Sample {i}: Successfully created state (charge={new_charge}, H={new_h})")
                except Exception as e:
                    print(f"Sample {i}: Failed sanitization: {str(e)}")
                    continue
                    
            except Exception as e:
                print(f"Sample {i}: Failed with error: {str(e)}")
                continue
        
        print(f"\nGenerated {len(neighborhoods)} valid neighbors")
        return neighborhoods, np.array(feature_vectors)
        neighborhoods = []
        feature_vectors = []
        
        # Get the target atom
        target_atom = mol.GetAtomWithIdx(atom_idx)
        atomic_num = int(target_atom.GetAtomicNum())
        
        print(f"\nGenerating neighborhood for atom {target_atom.GetSymbol()}{atom_idx}")
        print(f"Initial state: Charge={target_atom.GetFormalCharge()}, "
              f"Valence={target_atom.GetTotalValence()}, "
              f"Explicit Hs={target_atom.GetNumExplicitHs()}, "
              f"Implicit Hs={target_atom.GetNumImplicitHs()}")
        
        # Get valid states considering current bonding
        valid_states = self.get_valid_states(target_atom, mol)
        
        if not valid_states:
            print(f"No valid states found for {target_atom.GetSymbol()}{atom_idx}")
            return [], np.array([])
            
        for i in range(n_samples):
            try:
                # Create a copy of the molecule
                mol_copy = Chem.Mol(mol)
                atom_copy = mol_copy.GetAtomWithIdx(atom_idx)
                
                # Select a random valid state
                new_charge, new_bonds, new_h = valid_states[i % len(valid_states)]
                
                # Apply the new state
                atom_copy.SetFormalCharge(int(new_charge))
                atom_copy.SetNumExplicitHs(int(new_h))
                
                # Create feature vector
                features = [
                    int(atomic_num),
                    int(new_bonds + new_h),  # Total valence
                    int(atom_copy.GetDegree()),
                    int(atom_copy.GetIsAromatic()),
                    int(new_charge),
                    int(new_h),
                    int(atom_copy.IsInRing()),
                    int(atomic_num == 1)
                ]
                
                try:
                    Chem.SanitizeMol(mol_copy)
                    neighborhoods.append(mol_copy)
                    feature_vectors.append(features)
                    print(f"Sample {i}: Successfully created state (charge={new_charge}, bonds={new_bonds}, H={new_h})")
                except Exception as e:
                    print(f"Sample {i}: Failed sanitization: {str(e)}")
                    continue
                    
            except Exception as e:
                print(f"Sample {i}: Failed with error: {str(e)}")
                continue
        
        print(f"\nGenerated {len(neighborhoods)} valid neighbors")
        return neighborhoods, np.array(feature_vectors)
        neighborhoods = []
        feature_vectors = []
        
        # Get the target atom
        target_atom = mol.GetAtomWithIdx(atom_idx)
        atomic_num = int(target_atom.GetAtomicNum())
        
        print(f"\nGenerating neighborhood for atom {target_atom.GetSymbol()}{atom_idx}")
        print(f"Initial state: Charge={target_atom.GetFormalCharge()}, "
              f"Valence={target_atom.GetTotalValence()}, "
              f"Explicit Hs={target_atom.GetNumExplicitHs()}, "
              f"Implicit Hs={target_atom.GetNumImplicitHs()}")
        
        # Calculate current electron count
        current_electrons = (atomic_num +  # Number of protons
                           target_atom.GetFormalCharge() +  # Formal charge
                           target_atom.GetNumRadicalElectrons())  # Radical electrons
        
        # Define valid states based on atom type
        valid_states = []
        if atomic_num == 7:  # Nitrogen
            # States: (charge, num_bonds, num_hydrogens)
            valid_states = [
                (1, 4, 0),   # NH4+
                (0, 3, 0),   # NR3
                (-1, 2, 0),  # NR2-
                (1, 3, 1),   # NHR3+
                (0, 2, 1),   # NHR2
                (1, 2, 2),   # NH2R2+
                (0, 1, 2),   # NH2R
            ]
        elif atomic_num == 8:  # Oxygen
            valid_states = [
                (0, 2, 0),   # OR2
                (-1, 1, 0),  # OR-
                (0, 1, 1),   # OHR
                (-1, 0, 1),  # OH-
            ]
        elif atomic_num == 6:  # Carbon
            valid_states = [
                (0, 4, 0),   # CR4
                (1, 4, 0),   # CR4+
                (-1, 3, 0),  # CR3-
                (0, 3, 1),   # CHR3
                (0, 2, 2),   # CH2R2
            ]
            
        print(f"Valid states for {target_atom.GetSymbol()}: {valid_states}")
        
        for i in range(n_samples):
            try:
                # Create a copy of the molecule
                mol_copy = Chem.Mol(mol)
                atom_copy = mol_copy.GetAtomWithIdx(atom_idx)
                
                if valid_states:
                    # Choose a random valid state using numpy instead of random
                    state_idx = np.random.randint(len(valid_states))
                    new_charge, new_bonds, new_h = valid_states[state_idx]
                    
                    # Apply the new state
                    atom_copy.SetFormalCharge(int(new_charge))
                    atom_copy.SetNumExplicitHs(int(new_h))
                    
                    # Create feature vector
                    features = [
                        int(atomic_num),
                        int(new_bonds + new_h),  # Total valence
                        int(atom_copy.GetDegree()),
                        int(atom_copy.GetIsAromatic()),
                        int(new_charge),
                        int(new_h),
                        int(atom_copy.IsInRing()),
                        int(atomic_num == 1)
                    ]
                    
                    try:
                        Chem.SanitizeMol(mol_copy)
                        neighborhoods.append(mol_copy)
                        feature_vectors.append(features)
                        print(f"Sample {i}: Successfully created state (charge={new_charge}, H={new_h})")
                    except Exception as e:
                        print(f"Sample {i}: Failed sanitization: {str(e)}")
                        continue
                        
            except Exception as e:
                print(f"Sample {i}: Failed with error: {str(e)}")
                continue
        
        # Add variations of the original state if needed
        if len(neighborhoods) < 5:
            print("Adding variations of original state")
            original_state = (
                int(target_atom.GetFormalCharge()),
                int(target_atom.GetTotalValence() - target_atom.GetTotalNumHs()),
                int(target_atom.GetTotalNumHs())
            )
            
            # Find nearby valid states
            nearby_states = [s for s in valid_states if s != original_state][:5]
            
            for state in nearby_states:
                new_charge, new_bonds, new_h = state
                features = [
                    int(atomic_num),
                    int(new_bonds + new_h),
                    int(target_atom.GetDegree()),
                    int(target_atom.GetIsAromatic()),
                    int(new_charge),
                    int(new_h),
                    int(target_atom.IsInRing()),
                    int(atomic_num == 1)
                ]
                neighborhoods.append(mol)
                feature_vectors.append(features)
        
        print(f"\nGenerated {len(neighborhoods)} valid neighbors")
        return neighborhoods, np.array(feature_vectors)
        neighborhoods = []
        feature_vectors = []
        
        # Get the target atom
        target_atom = mol.GetAtomWithIdx(atom_idx)
        atomic_num = int(target_atom.GetAtomicNum())  # Convert to Python int
        
        print(f"\nGenerating neighborhood for atom {target_atom.GetSymbol()}{atom_idx}")
        print(f"Initial state: Charge={target_atom.GetFormalCharge()}, "
              f"Aromatic={target_atom.GetIsAromatic()}, "
              f"Valence={target_atom.GetTotalValence()}")
        
        # Define allowed charge ranges based on atom type
        charge_ranges = {
            6: [-1, 0, 1],     # Carbon
            7: [-1, 0, 1],     # Nitrogen
            8: [-1, 0, 1],     # Oxygen
            9: [-1, 0],        # Fluorine
            15: [-1, 0, 1],    # Phosphorus
            16: [-1, 0, 1],    # Sulfur
        }
        allowed_charges = charge_ranges.get(atomic_num, [0])
        
        for i in range(n_samples):
            try:
                # Create a copy of the molecule
                mol_copy = Chem.Mol(mol)
                atom_copy = mol_copy.GetAtomWithIdx(atom_idx)
                
                # Initialize feature vector with current state
                features = [
                    int(atomic_num),
                    int(atom_copy.GetTotalValence()),
                    int(atom_copy.GetDegree()),
                    int(atom_copy.GetIsAromatic()),
                    int(atom_copy.GetFormalCharge()),
                    int(atom_copy.GetTotalNumHs()),
                    int(atom_copy.IsInRing()),
                    int(atomic_num == 1)
                ]
                
                modifications_made = False
                
                # 1. Modify charge
                if np.random.random() < 0.4:
                    current_charge = int(atom_copy.GetFormalCharge())
                    possible_charges = [c for c in allowed_charges if c != current_charge]
                    if possible_charges:
                        new_charge = int(np.random.choice(possible_charges))  # Convert to Python int
                        atom_copy.SetFormalCharge(new_charge)
                        features[4] = new_charge
                        modifications_made = True
                        print(f"Sample {i}: Changed charge from {current_charge} to {new_charge}")
                
                # 2. Modify hydrogens if applicable
                if np.random.random() < 0.3 and atomic_num not in [9, 17, 35, 53]:  # Skip halogens
                    current_h = int(atom_copy.GetTotalNumHs())
                    max_h = 3 if atomic_num == 6 else 2  # Reduced max H to prevent valence issues
                    possible_h = list(range(max_h + 1))
                    if current_h in possible_h:
                        possible_h.remove(current_h)
                    if possible_h:
                        new_h = int(np.random.choice(possible_h))  # Convert to Python int
                        atom_copy.SetNumExplicitHs(new_h)
                        features[5] = new_h
                        modifications_made = True
                        print(f"Sample {i}: Changed hydrogens from {current_h} to {new_h}")
                
                # Only add to neighborhood if modifications were made and molecule is valid
                if modifications_made:
                    try:
                        Chem.SanitizeMol(mol_copy)
                        # Additional valence check
                        if all(a.GetTotalValence() <= 4 for a in mol_copy.GetAtoms() if a.GetAtomicNum() == 6) and \
                           all(a.GetTotalValence() <= 3 for a in mol_copy.GetAtoms() if a.GetAtomicNum() == 7) and \
                           all(a.GetTotalValence() <= 2 for a in mol_copy.GetAtoms() if a.GetAtomicNum() == 8):
                            neighborhoods.append(mol_copy)
                            feature_vectors.append(features)
                            print(f"Sample {i}: Successfully added to neighborhood")
                    except Exception as e:
                        print(f"Sample {i}: Failed sanitization: {str(e)}")
                        continue
                
            except Exception as e:
                print(f"Sample {i}: Failed with error: {str(e)}")
                continue
        
        # Add variations of the original molecule if neighborhood is too small
        if len(neighborhoods) < 10:
            print("Generating additional variations of original molecule")
            orig_features = [
                int(atomic_num),
                int(target_atom.GetTotalValence()),
                int(target_atom.GetDegree()),
                int(target_atom.GetIsAromatic()),
                int(target_atom.GetFormalCharge()),
                int(target_atom.GetTotalNumHs()),
                int(target_atom.IsInRing()),
                int(atomic_num == 1)
            ]
            
            # Add original molecule with small perturbations to features
            for _ in range(10 - len(neighborhoods)):
                perturbed_features = orig_features.copy()
                perturbed_features[4] += int(np.random.choice([-1, 1]))  # Perturb charge
                neighborhoods.append(mol)
                feature_vectors.append(perturbed_features)
        
        print(f"\nGenerated {len(neighborhoods)} valid neighbors")
        return neighborhoods, np.array(feature_vectors)
        neighborhoods = []
        feature_vectors = []
        
        # Get the target atom
        target_atom = mol.GetAtomWithIdx(atom_idx)
        atomic_num = target_atom.GetAtomicNum()
        
        print(f"\nGenerating neighborhood for atom {target_atom.GetSymbol()}{atom_idx}")
        print(f"Initial state: Charge={target_atom.GetFormalCharge()}, "
              f"Aromatic={target_atom.GetIsAromatic()}, "
              f"Valence={target_atom.GetTotalValence()}")
        
        for i in range(n_samples):
            try:
                # Create a copy of the molecule
                mol_copy = Chem.Mol(mol)
                atom_copy = mol_copy.GetAtomWithIdx(atom_idx)
                
                # Initialize feature vector with current state
                features = [
                    atomic_num,
                    atom_copy.GetTotalValence(),
                    atom_copy.GetDegree(),
                    int(atom_copy.GetIsAromatic()),
                    atom_copy.GetFormalCharge(),
                    atom_copy.GetTotalNumHs(),
                    int(atom_copy.IsInRing()),
                    int(atomic_num == 1)
                ]
                
                # Modify properties with tracking
                modifications_made = False
                
                # 1. Modify charge
                if np.random.random() < 0.4:
                    current_charge = atom_copy.GetFormalCharge()
                    possible_charges = [-1, 0, 1]
                    possible_charges.remove(current_charge)
                    new_charge = np.random.choice(possible_charges)
                    atom_copy.SetFormalCharge(new_charge)
                    features[4] = new_charge
                    modifications_made = True
                    print(f"Sample {i}: Changed charge from {current_charge} to {new_charge}")
                
                # 2. Modify hydrogens if allowed
                if np.random.random() < 0.3 and atomic_num not in [9, 17, 35, 53]:  # Skip halogens
                    current_h = atom_copy.GetTotalNumHs()
                    max_h = 4 if atomic_num == 6 else 3  # Max H for C is 4, others 3
                    possible_h = list(range(max_h + 1))
                    if current_h in possible_h:
                        possible_h.remove(current_h)
                    if possible_h:
                        new_h = np.random.choice(possible_h)
                        atom_copy.SetNumExplicitHs(new_h)
                        features[5] = new_h
                        modifications_made = True
                        print(f"Sample {i}: Changed hydrogens from {current_h} to {new_h}")
                
                # Only add to neighborhood if modifications were made
                if modifications_made:
                    try:
                        Chem.SanitizeMol(mol_copy)
                        neighborhoods.append(mol_copy)
                        feature_vectors.append(features)
                        print(f"Sample {i}: Successfully added to neighborhood")
                    except Exception as e:
                        print(f"Sample {i}: Failed sanitization: {str(e)}")
                        continue
                
            except Exception as e:
                print(f"Sample {i}: Failed with error: {str(e)}")
                continue
        
        # Ensure we have at least one sample by including original
        if not neighborhoods:
            print("No valid perturbations generated, using original molecule")
            neighborhoods.append(mol)
            feature_vectors.append([
                atomic_num,
                target_atom.GetTotalValence(),
                target_atom.GetDegree(),
                int(target_atom.GetIsAromatic()),
                target_atom.GetFormalCharge(),
                target_atom.GetTotalNumHs(),
                int(target_atom.IsInRing()),
                int(atomic_num == 1)
            ])
        
        print(f"\nGenerated {len(neighborhoods)} valid neighbors")
        return neighborhoods, np.array(feature_vectors)
        
    def extract_features(self, mol: Chem.Mol, target_atom_idx: int) -> np.ndarray:
        """Build the interpretable descriptor vector for a single atom.

        Args:
            mol: Molecule that contains the atom of interest.
            target_atom_idx: Index of that atom within ``mol``.

        Returns:
            A length-8 ``int32`` array holding, in order: atomic number, total
            valence, degree, aromatic flag, formal charge, explicit hydrogen
            count, in-ring flag and is-hydrogen flag.
        """
        atom = mol.GetAtomWithIdx(target_atom_idx)
        raw_descriptors = (
            atom.GetAtomicNum(),
            atom.GetTotalValence(),
            atom.GetDegree(),
            atom.GetIsAromatic(),        # bool -> 0/1 below
            atom.GetFormalCharge(),
            atom.GetNumExplicitHs(),
            atom.IsInRing(),             # bool -> 0/1 below
            atom.GetAtomicNum() == 1,    # is-hydrogen flag
        )
        # Coerce every entry to a plain Python int before packing the array.
        return np.array([int(value) for value in raw_descriptors], dtype=np.int32)
    
    def explain_atom(self, 
                    mol: Chem.Mol, 
                    atom_idx: int, 
                    n_samples: int = 50,
                    kernel_width: float = 0.25) -> Optional[Dict]:
        """Estimate local feature importance for one atom from a perturbation neighborhood.

        A neighborhood of perturbed molecules is generated around ``atom_idx``,
        the model is queried on every neighbor that converts to graph data, and
        each interpretable feature is scored by the absolute Pearson
        correlation between its values and the predictions across the
        neighborhood. Scores are then rescaled so that the largest equals 1.0;
        when no feature correlates at all they are left at 0.0.

        Args:
            mol: Molecule containing the atom to explain.
            atom_idx: Index of the target atom within ``mol``.
            n_samples: Number of perturbations requested from
                ``generate_neighborhood``.
            kernel_width: Kept for interface compatibility; the
                correlation-based scoring below does not use it.

        Returns:
            ``None`` when fewer than three neighbors or predictions are
            available, when predictions cannot be matched to neighbors, or
            when an exception occurs (its traceback is printed). Otherwise a
            dict with:

            * ``'local_importance'``: feature index -> score in [0, 1].
            * ``'local_prediction'``: prediction for the first usable neighbor
              (``predictions[0]``), not a separate evaluation of ``mol``.
            * ``'feature_names'``: ``self.feature_names``.
            * ``'perturbed_predictions'``: all neighborhood predictions.
            * ``'feature_variations'``: feature name -> sorted unique values.
        """
        try:
            # Generate neighborhood focusing on target atom
            neighborhood, feature_vectors = self.generate_neighborhood(mol, atom_idx, n_samples)
            
            if len(neighborhood) < 3:
                print("Insufficient neighborhood generated")
                return None
            
            print(f"\nProcessing neighborhood of size {len(neighborhood)}")
            
            # Convert neighbors to graph data and drop the feature rows of any
            # neighbor that failed to convert, so that row k of the feature
            # matrix always describes the molecule behind prediction k.
            neighborhood_data = [self._mol_to_graph_data(m) for m in neighborhood]
            converted = np.array([d is not None for d in neighborhood_data], dtype=bool)
            feature_vectors = np.asarray(feature_vectors)[converted]
            predictions = self._batch_predict([d for d in neighborhood_data if d is not None])
            
            if len(predictions) < 3:
                print("Insufficient valid predictions")
                return None
            
            if len(predictions) != len(feature_vectors):
                # Without a one-to-one pairing the correlations are meaningless.
                print("Prediction count does not match neighborhood size")
                return None
            
            print(f"Generated {len(predictions)} predictions")
            print(f"Prediction range: [{np.min(predictions):.4f}, {np.max(predictions):.4f}]")
            
            # Normalize predictions: min-max when there is real spread,
            # z-score otherwise to avoid dividing by a near-zero range.
            pred_range = np.max(predictions) - np.min(predictions)
            if pred_range < 1e-6:
                print("Warning: Very small prediction range")
                normalized_preds = (predictions - np.mean(predictions)) / (np.std(predictions) + 1e-10)
            else:
                normalized_preds = (predictions - np.min(predictions)) / pred_range
            
            importance_scores = {}
            feature_variations = {}
            
            print("\nFeature variations and importance:")
            for i, feat_name in enumerate(self.feature_names):
                feature_values = feature_vectors[:, i]
                unique_values = np.unique(feature_values)
                value_range = np.max(feature_values) - np.min(feature_values)
                
                print(f"\n{feat_name}:")
                print(f"  Unique values: {unique_values}")
                print(f"  Value range: {value_range:.4f}")
                
                if len(unique_values) > 1:
                    # Normalize feature values
                    norm_features = (feature_values - np.min(feature_values)) / (value_range + 1e-10)
                    
                    if np.std(norm_features) > 1e-6:
                        correlation = np.corrcoef(norm_features, normalized_preds)[0, 1]
                        if not np.isnan(correlation):
                            importance_scores[i] = float(abs(correlation))
                            print(f"  Correlation: {correlation:.4f}")
                        else:
                            # Happens when the predictions are constant.
                            importance_scores[i] = 0.0
                            print("  No correlation (nan)")
                    else:
                        importance_scores[i] = 0.0
                        print("  Insufficient variation")
                else:
                    importance_scores[i] = 0.0
                    print("  No variation")
                
                feature_variations[feat_name] = unique_values.tolist()
            
            # Rescale so the strongest feature scores 1.0. Skip the rescale
            # when every score is zero: dividing by a zero maximum would raise
            # and turn a legitimate "nothing matters" result into None.
            max_importance = max(importance_scores.values(), default=0.0)
            if max_importance > 0:
                importance_scores = {k: v / max_importance for k, v in importance_scores.items()}
            
            return {
                'local_importance': importance_scores,
                'local_prediction': float(predictions[0]),
                'feature_names': self.feature_names,
                'perturbed_predictions': predictions.tolist(),
                'feature_variations': feature_variations
            }
            
        except Exception as e:
            print(f"Error in explain_atom: {str(e)}")
            traceback.print_exc()
            return None

    def _batch_predict(self, graph_data_list: List[Data]) -> np.ndarray:
        """Get model predictions for a batch of molecules with added noise for variation"""
        try:
            # Filter out None values
            graph_data_list = [data for data in graph_data_list if data is not None]
            
            if not graph_data_list:
                return np.array([])
            
            # Create batch
            batch = Batch.from_data_list(graph_data_list).to(self.device)
            
            # Get predictions
            with torch.no_grad():
                self.model.eval()
                outputs = self.model(batch)
                
                # Convert to numpy and take norm as importance score
                predictions = torch.norm(outputs, dim=1).cpu().numpy()
                
                # Add small random variations to break ties (1% noise)
                noise = np.random.normal(0, 0.01 * np.std(predictions) + 1e-10, predictions.shape)
                predictions = predictions + noise
                
            return predictions
            
        except Exception as e:
            print(f"Error in _batch_predict: {str(e)}")
            return np.array([])

    def analyze_local_behavior(self, mol: Chem.Mol, atom_idx: int) -> Optional[Dict]:
        """Summarise how predictions vary over a perturbed neighborhood of ``mol``.

        ``generate_neighborhood`` is asked for 50 samples. Only neighbors that
        convert to graph data are kept, so every prediction stays aligned with
        the molecule (and feature row) it was computed from. The features of
        atom ``atom_idx`` are extracted once per kept neighbor and, for each
        feature, the absolute Pearson correlation between the feature's change
        relative to ``mol`` and the predictions is reported.

        Args:
            mol: Molecule whose neighborhood is analysed.
            atom_idx: Index of the atom whose features are tracked.

        Returns:
            ``None`` on failure (a message is printed), otherwise a dict with
            ``'importance_scores'`` (feature name -> |correlation|, ``0.0``
            when either the feature or the predictions do not vary),
            ``'prediction_stats'`` (mean/std/min/max of the predictions) and
            ``'perturbed_predictions'`` (the predictions as a list).
        """
        try:
            original_features = np.asarray(self.extract_features(mol, atom_idx), dtype=float)

            # Generate neighborhood with reduced size for better stability
            num_samples = 50
            neighborhood = self.generate_neighborhood(mol, num_samples)

            if not neighborhood:
                print("No valid neighborhood generated")
                return None

            # Keep molecules and their graphs paired: dropping failed conversions
            # from only one side would misalign features and predictions.
            valid_mols = []
            valid_graphs = []
            for candidate in neighborhood:
                graph = self._mol_to_graph_data(candidate)
                if graph is not None:
                    valid_mols.append(candidate)
                    valid_graphs.append(graph)

            predictions = self._batch_predict(valid_graphs)

            if len(predictions) == 0:
                print("No valid predictions generated")
                return None
            if len(predictions) != len(valid_mols):
                print("Prediction count does not match the number of valid neighbors")
                return None

            pred_stats = {
                'mean': float(np.mean(predictions)),
                'std': float(np.std(predictions)),
                'min': float(np.min(predictions)),
                'max': float(np.max(predictions))
            }

            # Extract each neighbor's features once (rows = neighbors, cols = features).
            feature_matrix = np.array(
                [self.extract_features(m, atom_idx) for m in valid_mols], dtype=float
            )
            feature_changes = feature_matrix - original_features
            predictions_vary = np.std(predictions) > 0

            importance_scores = {}
            for feat_name, changes in zip(self.feature_names, feature_changes.T):
                # A correlation is undefined (NaN) when either side is constant.
                if predictions_vary and np.std(changes) > 0:
                    correlation = np.corrcoef(changes, predictions)[0, 1]
                    importance_scores[feat_name] = float(abs(correlation))
                else:
                    importance_scores[feat_name] = 0.0

            return {
                'importance_scores': importance_scores,
                'prediction_stats': pred_stats,
                'perturbed_predictions': predictions.tolist()
            }

        except Exception as e:
            print(f"Error in analyze_local_behavior: {str(e)}")
            return None

    def _mol_to_graph_data(self, mol: Chem.Mol) -> Data:
        """Build a PyTorch Geometric ``Data`` graph for ``mol``.

        A fresh ``MolecularFeatureExtractor`` supplies the atom features
        (``x_cat``, ``x_phys``) and the bond features (``edge_index``,
        ``edge_attr``); ``num_nodes`` is the first dimension of ``x_cat``.

        Returns:
            The assembled ``Data`` object, or ``None`` when any step raises
            (the error message is printed).
        """
        try:
            extractor = MolecularFeatureExtractor()
            atom_cat, atom_phys = extractor.get_atom_features(mol)
            bond_index, bond_attr = extractor.get_bond_features(mol)

            graph_fields = {
                'x_cat': atom_cat,
                'x_phys': atom_phys,
                'edge_index': bond_index,
                'edge_attr': bond_attr,
                'num_nodes': atom_cat.size(0),
            }
            return Data(**graph_fields)

        except Exception as e:
            print(f"Error in _mol_to_graph_data: {str(e)}")
            return None

    def visualize_explanation(self, 
                            mol: Chem.Mol,
                            atom_idx: int,
                            explanation: Dict,
                            save_path: str = None) -> None:
        """Plot one atom's explanation: feature importances and prediction spread.

        The top panel is a horizontal bar chart of ``explanation['local_importance']``
        (indexed by feature position) sorted by absolute value; the bottom panel
        is a histogram of ``explanation['perturbed_predictions']`` with the
        original prediction (``explanation['local_prediction']``) as a dashed line.

        Args:
            mol: Molecule the atom belongs to (used for the atom symbol in the title).
            atom_idx: Index of the explained atom.
            explanation: Dict as returned by ``explain_atom``. ``None`` (a failed
                explanation) is reported and skipped instead of raising.
            save_path: If given, the figure is written there; otherwise it is shown.

        The figure is always closed afterwards so repeated calls in a loop do
        not accumulate open figures.
        """
        if explanation is None:
            print(f"No valid explanation for atom {atom_idx}")
            return

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))

        # 1. Feature Importance Plot
        importances = explanation['local_importance']
        features = explanation['feature_names']
        scores = [importances[i] for i in range(len(features))]

        # Sort by absolute importance (plain ints so they index lists and dicts alike)
        order = [int(i) for i in np.argsort(np.abs(scores))]
        pos = np.arange(len(features))

        ax1.barh(pos, [scores[i] for i in order])
        ax1.set_yticks(pos)
        ax1.set_yticklabels([features[i] for i in order])
        ax1.set_xlabel('Importance Score')
        ax1.set_title(f'Local Feature Importance for Atom {mol.GetAtomWithIdx(atom_idx).GetSymbol()}{atom_idx}')

        # 2. Neighborhood Plot
        perturbed_preds = explanation['perturbed_predictions']
        ax2.hist(perturbed_preds, bins=30, alpha=0.8)
        ax2.axvline(explanation['local_prediction'], color='r', linestyle='--',
                   label=f'Original Prediction: {explanation["local_prediction"]:.3f}')
        ax2.set_title('Distribution of Predictions in Local Neighborhood')
        ax2.set_xlabel('Model Prediction')
        ax2.set_ylabel('Count')
        ax2.legend()

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        else:
            plt.show()

        plt.close(fig)
            
    def _predict_proba(self, features: np.ndarray) -> np.ndarray:
        """Turn a flat atom-feature vector into a two-class pseudo-probability.

        ``features`` is reshaped to rows of ``len(self.feature_names)`` values,
        one row per atom. Each row creates an atom from column 0 (atomic
        number), column 4 (formal charge) and column 5 (explicit hydrogens);
        no bonds are added. The molecule is converted with
        ``_mol_to_graph_data`` and scored with ``_batch_predict``.

        The score ``s`` is clamped at 0 and mapped to ``p = s / (1 + s)``.
        The previous min-max scaling over a single prediction always gave 0.

        Returns:
            ``np.array([1 - p, p])``, or ``np.zeros(2)`` when the conversion or
            prediction fails or any step raises (the error message is printed).
        """
        try:
            feature_size = len(self.feature_names)
            atom_rows = np.asarray(features).reshape(-1, feature_size)

            # Chem.Mol is read-only and has no AddAtom; atoms must be added to an RWMol.
            editable = Chem.RWMol()
            for atom_features in atom_rows:
                new_idx = editable.AddAtom(Chem.Atom(int(atom_features[0])))  # AtomicNum
                atom = editable.GetAtomWithIdx(new_idx)
                atom.SetFormalCharge(int(atom_features[4]))  # Charge
                atom.SetNoImplicit(True)
                atom.SetNumExplicitHs(int(atom_features[5]))  # NumHs

            mol = editable.GetMol()
            # Compute valence information without sanitising, so feature
            # extraction can query it on this hand-built molecule.
            mol.UpdatePropertyCache(strict=False)

            # Convert to graph data
            graph_data = self._mol_to_graph_data(mol)

            if graph_data is None:
                return np.zeros(2)  # Return zero probabilities if conversion fails

            # Get prediction
            prediction = self._batch_predict([graph_data])

            if len(prediction) == 0:
                return np.zeros(2)

            # Monotone squashing of the non-negative score into [0, 1).
            score = max(float(prediction[0]), 0.0)
            prob = score / (1.0 + score)

            # Return both class probabilities (for binary classification)
            return np.array([1.0 - prob, prob])

        except Exception as e:
            print(f"Error in _predict_proba: {str(e)}")
            return np.zeros(2)

    def _create_perturbed_features(self, original_features: np.ndarray, n_samples: int = 100) -> np.ndarray:
        """Create perturbed versions of the original features"""
        perturbed_features = []
        n_features = len(original_features)

        for _ in range(n_samples):
            # Copy original features
            features = original_features.copy()

            # Randomly modify some features
            n_modifications = np.random.randint(1, max(2, n_features // 3))
            indices = np.random.choice(n_features, n_modifications, replace=False)

            for idx in indices:
                if idx == 0:  # AtomicNum
                    features[idx] = np.random.choice(self.feature_ranges['AtomicNum'])
                elif idx == 4:  # Charge
                    features[idx] += np.random.choice([-1, 1])
                elif idx == 5:  # NumHs
                    features[idx] = max(0, features[idx] + np.random.choice([-1, 1]))
                else:
                    features[idx] = float(not bool(features[idx]))  # Toggle boolean features

            perturbed_features.append(features)

        return np.array(perturbed_features)

    def _compute_feature_importance(self, 
                                  original_features: np.ndarray,
                                  perturbed_features: np.ndarray,
                                  predictions: np.ndarray) -> Dict[str, float]:
        """Compute feature importance based on perturbed predictions"""
        importance_scores = {}

        # Get original prediction
        original_pred = self._predict_proba(original_features)[1]

        # Compute importance for each feature
        for i, feature_name in enumerate(self.feature_names):
            # Get predictions where this feature was modified
            modified_mask = perturbed_features[:, i] != original_features[i]
            if not np.any(modified_mask):
                importance_scores[feature_name] = 0.0
                continue

            modified_preds = predictions[modified_mask]

            # Compute importance as mean absolute difference in predictions
            importance = np.mean(np.abs(modified_preds - original_pred))
            importance_scores[feature_name] = importance

        return importance_scores

In [5]:
def analyze_with_lime(model, dataset, device, idx=1, important_atoms=None):
    """
    Analyze molecule using LIME for local interpretability

    Args:
        model: Model handed to ``MolecularLIME``.
        dataset: Indexable collection whose items carry a ``smiles`` attribute.
        device: Device handed to ``MolecularLIME``.
        idx: Position in ``dataset`` of the molecule to analyse.
        important_atoms: Atom indices to explain. Defaults to ``[2, 9, 13, 15]``
            (N2, N9, C13, O15 from the SHAP analysis).

    Returns:
        Dict mapping atom index -> explanation for every atom that could be
        explained. Atoms outside the molecule and atoms whose explanation
        failed are reported and skipped. ``None`` if the SMILES cannot be
        parsed or an unexpected error occurs (the traceback is printed).
    """
    try:
        # Get molecule data
        mol_data = dataset[idx]
        smiles = mol_data.smiles
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            print(f"Error in LIME analysis: could not parse SMILES {smiles!r}")
            return None

        # Create LIME explainer
        lime_explainer = MolecularLIME(model, device)

        if important_atoms is None:
            important_atoms = [2, 9, 13, 15]  # N2, N9, C13, O15 from SHAP analysis

        n_atoms = mol.GetNumAtoms()

        # Analyze each important atom
        explanations = {}
        for atom_idx in important_atoms:
            if not 0 <= atom_idx < n_atoms:
                print(f"Skipping atom index {atom_idx}: molecule has only {n_atoms} atoms")
                continue

            print(f"\nAnalyzing atom {mol.GetAtomWithIdx(atom_idx).GetSymbol()}{atom_idx}...")

            # Get LIME explanation
            explanation = lime_explainer.explain_atom(
                mol,
                atom_idx,
                n_samples=50,
                kernel_width=0.25
            )

            # One failed atom must not discard the explanations already collected.
            if explanation is None:
                print(f"Failed to generate explanation for atom {atom_idx}")
                continue

            # Save explanation
            explanations[atom_idx] = explanation

            # Visualize explanation
            lime_explainer.visualize_explanation(
                mol,
                atom_idx,
                explanation,
                save_path=f'lime_explanation_atom_{atom_idx}.png'
            )

            # Print local feature importance
            print("\nLocal feature importance:")
            for feat_idx, importance in explanation['local_importance'].items():
                print(f"{explanation['feature_names'][feat_idx]}: {importance:.4f}")

        return explanations

    except Exception as e:
        print(f"Error in LIME analysis: {str(e)}")
        traceback.print_exc()
        return None

    
def visualize_lime_results(atom_idx: int, explanation: Dict, save_path: str = None):
    """Draw a sorted horizontal bar chart of one atom's LIME importances.

    Scores are read from ``explanation['local_importance']`` by feature
    position (missing positions count as ``0.0``) and labelled with
    ``explanation['feature_names']``. The chart is saved to ``save_path`` when
    one is given, otherwise shown; the figure is closed in both cases.
    A ``None`` explanation is reported and nothing is drawn.
    """
    if explanation is None:
        print(f"No valid explanation for atom {atom_idx}")
        return

    # One score per feature name, in feature order.
    per_feature = [
        explanation['local_importance'].get(position, 0.0)
        for position in range(len(explanation['feature_names']))
    ]

    plt.figure(figsize=(12, 6))

    # Ascending |score| puts the largest bar at the highest y position.
    ranked = sorted(zip(explanation['feature_names'], per_feature),
                    key=lambda pair: abs(pair[1]))
    labels = [label for label, _ in ranked]
    values = [value for _, value in ranked]

    bar_colors = []
    for value in values:
        bar_colors.append('#1f77b4' if value > 0 else '#d62728')
    rows = np.arange(len(labels))

    plt.barh(rows, values, align='center', color=bar_colors, alpha=0.8)
    plt.yticks(rows, labels)

    # Annotate only bars that are visibly non-zero.
    for row, value in enumerate(values):
        if not abs(value) > 0.01:
            continue
        plt.text(value + np.sign(value)*0.01, row, f'{value:.3f}',
                 va='center', fontsize=10)

    plt.title(f'LIME Feature Importance for Atom {atom_idx}')
    plt.xlabel('Importance Score')
    plt.grid(True, axis='x', linestyle='--', alpha=0.3)
    plt.tight_layout()

    if not save_path:
        plt.show()
    else:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Saved plot to {save_path}")

    plt.close()

def run_lime_analysis(dataset, model=None, device=None, mol_idx=1, important_atoms=None):
    """Explain selected atoms of one dataset molecule and save one plot per atom.

    Args:
        dataset: Indexable collection whose items carry a ``smiles`` attribute.
        model: Model handed to ``MolecularLIME``. When omitted, the notebook
            global ``encoder`` is used (the previous, implicit behaviour).
        device: Device handed to ``MolecularLIME``. When omitted, the notebook
            global ``device`` is used.
        mol_idx: Position in ``dataset`` of the molecule to analyse.
        important_atoms: Atom indices to explain. Defaults to ``[2, 9, 13, 15]``.

    Returns:
        Dict with ``'molecule_smiles'``, ``'analysis_timestamp'``,
        ``'important_atoms'`` (the requested indices) and ``'results'``
        (atom index -> explanation for every atom that could be explained).
        ``None`` if the SMILES cannot be parsed or an unexpected error occurs
        (the traceback is printed).

    Plots are written to a new ``lime_explanations_<timestamp>`` directory.
    """
    try:
        print("\nStarting LIME Analysis...")

        # Fall back to the notebook globals only when nothing was passed in, so
        # existing ``run_lime_analysis(dataset)`` calls keep working.
        if model is None:
            if 'encoder' not in globals():
                raise NameError("name 'encoder' is not defined; pass model=... explicitly")
            model = globals()['encoder']
        if device is None:
            if 'device' not in globals():
                raise NameError("name 'device' is not defined; pass device=... explicitly")
            device = globals()['device']

        if important_atoms is None:
            important_atoms = [2, 9, 13, 15]
        important_atoms = list(important_atoms)

        # Get molecule from dataset
        mol_data = dataset[mol_idx]
        smiles = mol_data.smiles
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            print(f"\nError in LIME analysis: could not parse SMILES {smiles!r}")
            return None

        # Create the output directory only once there is a molecule to analyse.
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        save_dir = f'lime_explanations_{timestamp}'
        os.makedirs(save_dir, exist_ok=True)

        # Initialize LIME explainer
        lime_explainer = MolecularLIME(
            model=model,
            device=device
        )

        n_atoms = mol.GetNumAtoms()
        results = {}

        for atom_idx in important_atoms:
            # An index outside this molecule must not abort the other atoms.
            if not 0 <= atom_idx < n_atoms:
                print(f"Skipping atom index {atom_idx}: molecule has only {n_atoms} atoms")
                continue

            atom_label = f"{mol.GetAtomWithIdx(atom_idx).GetSymbol()}{atom_idx}"
            print(f"\nAnalyzing atom {atom_label}...")

            # Get LIME explanation
            explanation = lime_explainer.explain_atom(
                mol=mol,
                atom_idx=atom_idx,
                n_samples=50
            )

            if explanation is None:
                print(f"Failed to generate explanation for atom {atom_idx}")
                continue

            # Store results
            results[atom_idx] = explanation

            # Create visualization
            save_path = os.path.join(save_dir, f'lime_importance_atom_{atom_idx}.png')
            visualize_lime_results(atom_idx, explanation, save_path)

            # Print results
            print(f"\nResults for atom {atom_label}:")
            print("\nTop 5 most important features:")
            importance_pairs = [(name, explanation['local_importance'].get(i, 0.0))
                              for i, name in enumerate(explanation['feature_names'])]
            sorted_features = sorted(importance_pairs, key=lambda x: abs(x[1]), reverse=True)[:5]
            for feat_name, importance in sorted_features:
                print(f"  {feat_name}: {importance:.4f}")

        return {
            'molecule_smiles': smiles,
            'analysis_timestamp': timestamp,
            'important_atoms': important_atoms,
            'results': results
        }

    except Exception as e:
        print(f"\nError in LIME analysis: {str(e)}")
        traceback.print_exc()
        return None

In [6]:
# Load every SMILES string from the input file and convert it to a graph,
# keeping the SMILES on the graph so later analysis can rebuild the molecule.
print("Starting data loading...")
extractor = MolecularFeatureExtractor()
# The location can be overridden with the SMILES_FILE environment variable so the
# notebook also runs on machines without this exact directory layout.
smiles_file = os.environ.get(
    "SMILES_FILE",
    "D:\\PhD\\Chapter3\\Unsupervised_GAN_Code\\pubchem-41-clean.txt",
)

dataset = []
failed_smiles = []

with open(smiles_file, 'r') as f:
    for line in f:
        smiles = line.strip()
        if not smiles:
            # Blank lines are not molecules; do not count them as failures.
            continue
        data = extractor.process_molecule(smiles)
        if data is not None:
            # Attach the source SMILES to the graph object
            data.smiles = smiles
            dataset.append(data)
        else:
            failed_smiles.append(smiles)

print(f"1. Loaded dataset with {len(dataset)} graphs.")
print(f"2. Failed SMILES count: {len(failed_smiles)}")

# Make sure to import needed libraries
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
import traceback

os.makedirs('molecule_explanation', exist_ok=True)
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

# Defined up front so later cells can test it even when the analysis is skipped.
lime_results = None

if not dataset:
    # Nothing to analyse; running LIME would only fail on the dataset lookup.
    print("No valid graphs generated.")
else:
    try:
        print("\nStarting molecule analysis...")

        # Run LIME analysis
        lime_results = run_lime_analysis(dataset)

        if lime_results is not None:
            print("\nAnalysis Summary:")
            print(f"Analyzed molecule: {lime_results['molecule_smiles']}")
            print(f"Number of atoms analyzed: {len(lime_results['important_atoms'])}")
            print("\nResults saved in directory with timestamp:", 
                  lime_results['analysis_timestamp'])

        else:
            print("\nAnalysis failed!")

    except Exception as e:
        print(f"\nError in main execution: {str(e)}")
        traceback.print_exc()

Starting data loading...
1. Loaded dataset with 41 graphs.
2. Failed SMILES count: 0

Starting molecule analysis...

Starting LIME Analysis...

Analyzing atom N2...

Generating neighborhood for atom N2
Initial state: Charge=1, Valence=4, Explicit Hs=1, Implicit Hs=0, Aromatic=False

Atom N2 analysis:
Current bond sum: 3
Number of bonds: 3

Valid states for N2:
  Charge=0, Bonds=3, H=0
Sample 0: Failed with error: not enough values to unpack (expected 6, got 3)
Sample 1: Failed with error: not enough values to unpack (expected 6, got 3)
Sample 2: Failed with error: not enough values to unpack (expected 6, got 3)
Sample 3: Failed with error: not enough values to unpack (expected 6, got 3)
Sample 4: Failed with error: not enough values to unpack (expected 6, got 3)
Sample 5: Failed with error: not enough values to unpack (expected 6, got 3)
Sample 6: Failed with error: not enough values to unpack (expected 6, got 3)
Sample 7: Failed with error: not enough values to unpack (expected 6, got