In [1]:
import torch
import pickle
from rdkit import Chem
from rdkit.Chem import Draw
import shap
import numpy as np
import torch.nn.functional as F
from torch_geometric.data import DataLoader
import torch.nn as nn
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing
from sklearn.metrics.pairwise import cosine_similarity
from typing import Dict, List, Tuple, Optional
from torch_geometric.data import Data, Batch
from rdkit.Chem import RemoveHs
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from rdkit import RDLogger
# Suppress RDKit warnings
RDLogger.DisableLog('rdApp.warning')
import tensorflow as tensorflow
import traceback
import os
from datetime import datetime
from rdkit.Chem import rdDepictor
import matplotlib.pyplot as plt
from rdkit.Chem import AllChem, Draw, rdDepictor
from matplotlib.colors import LinearSegmentedColormap

class GraphDiscriminator(nn.Module):
    """Reimplementation of original discriminator architecture"""
    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128, output_dim: int = 128):
        super().__init__()
        
        # Feature encoding
        self.node_encoder = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Graph convolution layers
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, output_dim)
        
        # Projection head
        self.projection = nn.Sequential(
            nn.Linear(output_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
        
    def forward(self, data):
        x = torch.cat([data.x_cat.float(), data.x_phys], dim=-1)
        edge_index = data.edge_index
        edge_attr = data.edge_attr.float()
        batch = data.batch
        
        # Initial feature encoding
        x = self.node_encoder(x)
        edge_attr = self.edge_encoder(edge_attr)
        
        # Graph convolutions
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        
        # Global pooling
        x = global_mean_pool(x, batch)
        
        # Projection
        x = self.projection(x)
        
        return x

# Load Encoder
def load_encoder(model_path, device='cpu'):
    """Load trained encoder"""
    checkpoint = torch.load(model_path, map_location=device)
    encoder = GraphDiscriminator(
        node_dim=checkpoint['model_info'].get('node_dim'),
        edge_dim=checkpoint['model_info'].get('edge_dim'),
        hidden_dim=checkpoint['model_info'].get('hidden_dim', 128),
        output_dim=checkpoint['model_info'].get('output_dim', 128)
    )
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    encoder.eval()
    return encoder.to(device)

# Load Embeddings
def load_embeddings(filepath):
    """Load embeddings and labels"""
    with open(filepath, 'rb') as f:
        data = pickle.load(f)
    return data['embeddings'], data['labels']

# Paths from your saved model
encoder_path = './checkpoints/encoders/final_encoder_20250216_111050.pt'
# embedding_path = './embeddings/final_embeddings_20250216_111005.pkl'
embedding_path = './embeddings/final_embeddings_fixed.pkl'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load encoder and embeddings
encoder = load_encoder(encoder_path, device)
embeddings, graph_data = load_embeddings(embedding_path)

  checkpoint = torch.load(model_path, map_location=device)


In [2]:
class MolecularFeatureExtractor:
    def __init__(self):
        self.atom_list = list(range(1, 119))
        self.chirality_list = [
            Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
            Chem.rdchem.ChiralType.CHI_OTHER
        ]
        self.bond_list = [
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE, 
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC
        ]
        self.bonddir_list = [
            Chem.rdchem.BondDir.NONE,
            Chem.rdchem.BondDir.ENDUPRIGHT,
            Chem.rdchem.BondDir.ENDDOWNRIGHT
        ]

    def calc_atom_features(self, atom: Chem.Atom) -> Tuple[list, list]:
        """Calculate atom features with better error handling"""
        try:
            # Basic features
            atom_feat = [
                self.atom_list.index(atom.GetAtomicNum()),
                self.chirality_list.index(atom.GetChiralTag())
            ]

            # Physical features with error handling
            phys_feat = []
            
            # Molecular weight contribution
            try:
                contrib_mw = Descriptors.ExactMolWt(Chem.MolFromSmiles(f'[{atom.GetSymbol()}]'))
                phys_feat.append(contrib_mw)
            except:
                phys_feat.append(0.0)
                
            # LogP contribution    
            try:
                contrib_logp = Descriptors.MolLogP(Chem.MolFromSmiles(f'[{atom.GetSymbol()}]'))
                phys_feat.append(contrib_logp)
            except:
                phys_feat.append(0.0)
                
            # Add other physical properties
            phys_feat.extend([
                atom.GetFormalCharge(),
                int(atom.GetHybridization()),
                int(atom.GetIsAromatic()),
                atom.GetTotalNumHs(),
                atom.GetTotalValence(),
                atom.GetDegree()
            ])
            
            return atom_feat, phys_feat
            
        except Exception as e:
            print(f"Error calculating atom features: {e}")
            return [0, 0], [0.0] * 9

    def get_atom_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract atom features for the whole molecule"""
        atom_feats = []
        phys_feats = []
        
        if mol is None:
            return torch.tensor([[0, 0]], dtype=torch.long), torch.tensor([[0.0] * 9], dtype=torch.float)
            
        for atom in mol.GetAtoms():
            atom_feat, phys_feat = self.calc_atom_features(atom)
            atom_feats.append(atom_feat)
            phys_feats.append(phys_feat)

        x = torch.tensor(atom_feats, dtype=torch.long)
        phys = torch.tensor(phys_feats, dtype=torch.float)
        
        return x, phys
    
    def remove_unbonded_hydrogens(mol):
        params = Chem.RemoveHsParameters()
        params.removeDegreeZero = True
        mol = Chem.RemoveHs(mol, params)
        return mol


    def get_bond_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract bond features with better error handling"""
        if mol is None:
            return torch.tensor([[0], [0]], dtype=torch.long), torch.tensor([[0.0] * 5], dtype=torch.float)
            
        row, col, edge_feat = [], [], []
        
        for bond in mol.GetBonds():
            try:
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                
                # Add edges in both directions
                row += [start, end]
                col += [end, start]
                
                # Bond features
                bond_type = self.bond_list.index(bond.GetBondType())
                bond_dir = self.bonddir_list.index(bond.GetBondDir())
                
                # Calculate additional properties
                feat = [
                    bond_type,
                    bond_dir,
                    int(bond.GetIsConjugated()),
                    int(self._is_rotatable(bond)),
                    self._get_bond_length(mol, start, end)
                ]
                
                edge_feat.extend([feat, feat])
                
            except Exception as e:
                print(f"Error processing bond: {e}")
                continue

        if not row:  # If no valid bonds were processed
            return torch.tensor([[0], [0]], dtype=torch.long), torch.tensor([[0.0] * 5], dtype=torch.float)

        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(edge_feat, dtype=torch.float)
        
        return edge_index, edge_attr

    def _is_rotatable(self, bond: Chem.Bond) -> bool:
        """Check if bond is rotatable"""
        return (bond.GetBondType() == Chem.rdchem.BondType.SINGLE and 
                not bond.IsInRing() and
                len(bond.GetBeginAtom().GetNeighbors()) > 1 and
                len(bond.GetEndAtom().GetNeighbors()) > 1)

    def _get_bond_length(self, mol: Chem.Mol, start: int, end: int) -> float:
        """Get bond length with error handling"""
        try:
            conf = mol.GetConformer()
            if conf.Is3D():
                return Chem.rdMolTransforms.GetBondLength(conf, start, end)
        except:
            pass
        return 0.0

    def process_molecule(self, smiles: str) -> Data:
        """Process SMILES string to graph data"""
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                print(f"Invalid SMILES: {smiles}")
                return None  # Skip invalid molecules
            mol = RemoveHs(mol)

            # Add explicit hydrogens
            mol = Chem.AddHs(mol, addCoords=True)

            # Sanitize molecule
            Chem.SanitizeMol(mol)

            # Check if the molecule has atoms
            if mol.GetNumAtoms() == 0:
                print("Molecule has no atoms, skipping.")
                return None

            # Generate 3D coordinates
            if not mol.GetNumConformers():
                status = AllChem.EmbedMolecule(mol, AllChem.ETKDG())
                if status != 0:
                    print("Failed to generate 3D conformer")
                    return None  # Skip failed molecules

                # Try MMFF or UFF optimization
                try:
                    AllChem.MMFFOptimizeMolecule(mol)
                except:
                    AllChem.UFFOptimizeMolecule(mol)

            # Extract features
            x_cat, x_phys = self.get_atom_features(mol)
            edge_index, edge_attr = self.get_bond_features(mol)

            return Data(
                x_cat=x_cat, 
                x_phys=x_phys,
                edge_index=edge_index, 
                edge_attr=edge_attr,
                num_nodes=x_cat.size(0)
            )

        except Exception as e:
            print(f"Error processing molecule {smiles}: {e}")
            return None


In [3]:
class MolecularLIME:
    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.model.eval()
        
        self.atom_list = list(range(1, 119))
        self.chirality_list = [
            Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
            Chem.rdchem.ChiralType.CHI_OTHER
        ]
        self.bond_list = [
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE, 
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC
        ]
        self.bonddir_list = [
            Chem.rdchem.BondDir.NONE,
            Chem.rdchem.BondDir.ENDUPRIGHT,
            Chem.rdchem.BondDir.ENDDOWNRIGHT
        ]        

        # Define valence states
        self.valence_states = {
            'N': {
                'neutral': {'max_valence': 3, 'charges': [0]},
                'charged': {'max_valence': 4, 'charges': [1]}
            },
            'O': {
                'neutral': {'max_valence': 2, 'charges': [0]},
                'charged': {'max_valence': 1, 'charges': [-1]}
            },
            'C': {
                'neutral': {'max_valence': 4, 'charges': [0]},
                'charged': {'max_valence': 3, 'charges': [-1, 1]}
            }
        }

        # Feature names aligned with the model's expected input
        self.feature_names = [
            'AtomicNum',
            'ChiralTag',  # Added this as the model expects it
            'FormalCharge',
            'Hybridization',
            'IsAromatic',
            'NumHs',
            'Valence',
            'Degree',
            'InRing'
        ]
        
        # Updated feature ranges
        self.feature_ranges = {
            'AtomicNum': list(range(1, 119)),
            'ChiralTag': list(range(4)),  # Four possible chiral tags
            'FormalCharge': [-2, -1, 0, 1, 2],
            'Hybridization': list(range(1, 7)),
            'IsAromatic': [0, 1],
            'NumHs': list(range(5)),
            'Valence': list(range(1, 7)),
            'Degree': list(range(5)),
            'InRing': [0, 1]
        }

    def generate_perturbations(self, mol: Chem.Mol, atom_idx: int, n_samples: int = 100) -> List[Tuple[np.ndarray, np.ndarray]]:
        """Generate more diverse perturbations to capture all feature importances"""
        try:
            atom = mol.GetAtomWithIdx(atom_idx)
            symbol = atom.GetSymbol()
            total_bonds = int(sum(bond.GetBondTypeAsDouble() for bond in atom.GetBonds()))

            perturbations = []
            base_atomic_feat, base_phys_feat = self.extract_features(atom)
            perturbations.append((base_atomic_feat, base_phys_feat))

            # Define valid changes for each atom type
            atom_changes = {
                'N': {
                    'atomic_nums': [6, 7, 15],  # C, N, P
                    'charges': [0, 1],
                    'hybridizations': [2, 3],  # sp2, sp3
                    'is_aromatic': [True, False]
                },
                'O': {
                    'atomic_nums': [7, 8, 16],  # N, O, S
                    'charges': [-1, 0],
                    'hybridizations': [2, 3],
                    'is_aromatic': [True, False]
                },
                'C': {
                    'atomic_nums': [6, 7],  # C, N
                    'charges': [0],
                    'hybridizations': [2, 3],
                    'is_aromatic': [True, False]
                }
            }

            # Get valid changes for current atom
            changes = atom_changes.get(symbol, {
                'atomic_nums': [atom.GetAtomicNum()],
                'charges': [atom.GetFormalCharge()],
                'hybridizations': [atom.GetHybridization()],
                'is_aromatic': [atom.GetIsAromatic()]
            })

            # Generate diverse perturbations
            for atomic_num in changes['atomic_nums']:
                for charge in changes['charges']:
                    for hybrid in changes['hybridizations']:
                        for is_arom in changes['is_aromatic']:
                            try:
                                mol_copy = Chem.Mol(mol)
                                atom_copy = mol_copy.GetAtomWithIdx(atom_idx)

                                # Calculate max hydrogens based on valence
                                max_valence = {
                                    6: 4,  # C
                                    7: 4 if charge == 1 else 3,  # N
                                    8: 2,  # O
                                    15: 5,  # P
                                    16: 6   # S
                                }.get(atomic_num, 4)

                                available_valence = max_valence - total_bonds
                                h_range = range(max(0, min(3, available_valence)) + 1)

                                for h_count in h_range:
                                    try:
                                        mol_temp = Chem.Mol(mol_copy)
                                        atom_temp = mol_temp.GetAtomWithIdx(atom_idx)

                                        atom_temp.SetAtomicNum(atomic_num)
                                        atom_temp.SetFormalCharge(charge)
                                        atom_temp.SetHybridization(Chem.HybridizationType(hybrid))
                                        atom_temp.SetIsAromatic(is_arom)
                                        atom_temp.SetNumExplicitHs(h_count)

                                        try:
                                            Chem.SanitizeMol(mol_temp)
                                            atomic_feat, phys_feat = self.extract_features(atom_temp)

                                            if not (np.array_equal(atomic_feat, base_atomic_feat) and 
                                                   np.array_equal(phys_feat, base_phys_feat)):
                                                perturbations.append((atomic_feat, phys_feat))

                                        except:
                                            continue

                                    except:
                                        continue

                            except:
                                continue

            # Ensure enough diverse perturbations
            if len(perturbations) < 2:
                return []

            # Limit number of perturbations if too many
            if len(perturbations) > n_samples:
                indices = np.random.choice(len(perturbations), n_samples, replace=False)
                perturbations = [perturbations[i] for i in indices]

            return perturbations

        except Exception as e:
            print(f"Error in generate_perturbations: {str(e)}")
            return []


    def _validate_atom_state(self, atom: Chem.Atom, total_bond_order: float) -> bool:
        """Validate atom state with strict chemical rules"""
        symbol = atom.GetSymbol()
        charge = atom.GetFormalCharge()
        num_h = atom.GetTotalNumHs()
        total_valence = total_bond_order + num_h

        # Valence rules for common atoms
        valence_rules = {
            'C': {
                0: {'max': 4, 'min': 2},
                1: {'max': 3, 'min': 2},
                -1: {'max': 3, 'min': 2}
            },
            'N': {
                0: {'max': 3, 'min': 1},
                1: {'max': 4, 'min': 2},
                -1: {'max': 2, 'min': 1}
            },
            'O': {
                0: {'max': 2, 'min': 1},
                -1: {'max': 1, 'min': 0},
                1: {'max': 3, 'min': 2}
            }
        }

        if symbol in valence_rules:
            rules = valence_rules[symbol].get(charge)
            if rules:
                if total_valence < rules['min'] or total_valence > rules['max']:
                    return False

                # Additional chemical environment checks
                if symbol == 'N':
                    if charge == 1 and total_valence < 3:  # NH3+ minimum
                        return False
                elif symbol == 'O':
                    if charge == -1 and total_valence > 1:  # O- maximum
                        return False

        return True    
    
    def is_valid_valence(self, atom: Chem.Atom) -> bool:
        """Check if atom has valid valence"""
        symbol = atom.GetSymbol()
        charge = atom.GetFormalCharge()
        explicit_valence = atom.GetExplicitValence()

        # Maximum allowed valences for common atoms
        max_valences = {
            'C': {0: 4, 1: 3, -1: 3},
            'N': {0: 3, 1: 4, -1: 2},
            'O': {0: 2, -1: 1, 1: 3},
            'P': {0: 5, -1: 4, 1: 4},
            'S': {0: 6, -1: 5, 1: 5}
        }

        if symbol in max_valences:
            max_val = max_valences[symbol].get(charge, 4)
            return explicit_valence <= max_val

        return True  # Default to True for uncommon atoms    
    
    def is_valid_perturbation(self, features: np.ndarray) -> bool:
        """Check if perturbation is chemically valid"""
        atomic_num = int(features[0])
        chiral_tag = int(features[1])
        formal_charge = int(features[2])
        num_hs = int(features[5])
        valence = int(features[6])
        degree = int(features[7])
        
        # Basic validation rules
        if atomic_num <= 0 or atomic_num > 118:
            return False
            
        if chiral_tag not in range(4):
            return False
            
        if abs(formal_charge) > 2:
            return False
            
        if num_hs + degree > valence:
            return False
            
        # Element-specific rules
        if atomic_num == 6:  # Carbon
            if valence > 4:
                return False
        elif atomic_num == 7:  # Nitrogen
            if valence > 3 and formal_charge != 1:
                return False
        elif atomic_num == 8:  # Oxygen
            if valence > 2 and formal_charge != -1:
                return False
                
        return True        

    def is_valid_state(self, atom: Chem.Atom, charge: int, h_count: int) -> bool:
        """Check if a given state is valid for an atom"""
        symbol = atom.GetSymbol()
        bonded_valence = sum(bond.GetBondTypeAsDouble() for bond in atom.GetBonds())
        total_valence = bonded_valence + h_count
        
        if symbol == 'N':
            if charge == 1:  # NH4+ like
                return total_valence <= 4
            else:  # Neutral N
                return total_valence <= 3
        elif symbol == 'O':
            if charge == -1:  # O-
                return total_valence <= 1
            else:  # Neutral O
                return total_valence <= 2
        elif symbol == 'C':
            return total_valence <= 4
        else:
            return total_valence <= self.get_max_valence(symbol)
            
    def get_max_valence(self, symbol: str) -> int:
        """Get maximum valence for an atom"""
        common_valences = {
            'H': 1, 'C': 4, 'N': 3, 'O': 2, 'F': 1,
            'P': 5, 'S': 6, 'Cl': 1, 'Br': 1, 'I': 1
        }
        return common_valences.get(symbol, 4)  # Default to 4 for unknown atoms

    def get_valid_states(self, atom: Chem.Atom) -> List[Dict[str, int]]:
        """Get valid states for an atom based on its current environment"""
        symbol = atom.GetSymbol()
        current_charge = atom.GetFormalCharge()
        current_h = atom.GetTotalNumHs()
        bonds = list(atom.GetBonds())
        bonded_valence = sum(bond.GetBondTypeAsDouble() for bond in bonds)
        
        # Get available atom states
        atom_states = self.atom_valence_states.get(symbol, {})
        valid_states = []
        
        # Always include current state if it's chemically valid
        current_state = {'charge': current_charge, 'h': current_h}
        if self.is_valid_state(atom, current_state['charge'], current_state['h']):
            valid_states.append(current_state)
        
        if symbol == 'N':
            if current_charge == 1:  # Currently NH+ or NH2+
                # Can remove a proton to go neutral
                if current_h > 0:
                    valid_states.append({'charge': 1, 'h': current_h - 1})
            else:  # Neutral N
                # Can protonate if valence allows
                if bonded_valence + current_h <= 3:
                    valid_states.append({'charge': 0, 'h': current_h})
                    
        elif symbol == 'O':
            if current_charge == -1:  # O-
                # Can be protonated to OH
                valid_states.extend([
                    {'charge': -1, 'h': 0},
                    {'charge': 0, 'h': 1} if bonded_valence < 2 else {'charge': -1, 'h': 0}
                ])
            else:  # Neutral O
                if bonded_valence + current_h <= 2:
                    valid_states.append({'charge': 0, 'h': current_h})
                    
        elif symbol == 'C':
            # Carbon can have max 4 bonds total
            if bonded_valence + current_h <= 4:
                valid_states.append({'charge': 0, 'h': current_h})
                if current_h > 0:  # Can remove H if present
                    valid_states.append({'charge': 0, 'h': current_h - 1})
                    
        # Remove duplicates while preserving order
        seen = set()
        unique_states = []
        for state in valid_states:
            state_tuple = (state['charge'], state['h'])
            if state_tuple not in seen:
                seen.add(state_tuple)
                unique_states.append(state)
                
        print(f"Generated {len(unique_states)} valid states for {symbol}{atom.GetIdx()}: {unique_states}")
        return unique_states
        
    def extract_features(self, atom: Chem.Atom) -> Tuple[np.ndarray, np.ndarray]:
        """Extract atom features with correct array handling"""
        atomic_features = []
        physical_features = []

        try:
            # Atomic features
            atomic_features = [
                atom.GetAtomicNum(),
                int(atom.GetChiralTag()),
                atom.GetFormalCharge(),
                int(atom.GetHybridization()),
                int(atom.GetIsAromatic()),
                atom.GetTotalNumHs(),
                atom.GetTotalValence()
            ]

            # Physical features
            mol = Chem.MolFromSmiles(f'[{atom.GetSymbol()}]')
            mol_weight = Descriptors.ExactMolWt(mol) if mol else 0.0
            log_p = Descriptors.MolLogP(mol) if mol else 0.0

            physical_features = [
                mol_weight,
                log_p,
                atom.GetDegree(),
                int(atom.IsInRing()),
                atom.GetTotalNumHs(),
                atom.GetImplicitValence(),
                atom.GetExplicitValence()
            ]

        except Exception as e:
            print(f"Error extracting features: {str(e)}")
            # Return default values if extraction fails
            atomic_features = [0] * 7
            physical_features = [0.0] * 7

        return np.array(atomic_features, dtype=np.float32), np.array(physical_features, dtype=np.float32)
        
    def get_valid_states(self, atom: Chem.Atom) -> List[Dict[str, int]]:
        """Get valid states for an atom based on current environment"""
        symbol = atom.GetSymbol()
        current_charge = atom.GetFormalCharge()
        current_h = atom.GetTotalNumHs()
        bonded_valence = sum(bond.GetBondTypeAsDouble() for bond in atom.GetBonds())
        
        states = []
        atom_info = self.valence_states.get(symbol, {})
        
        if not atom_info:
            return [{'charge': current_charge, 'h': current_h}]
            
        # Get current valence state (neutral or charged)
        state_type = 'charged' if current_charge != 0 else 'neutral'
        state_info = atom_info[state_type]
        
        # Calculate max hydrogens allowed
        max_h = max(0, state_info['max_valence'] - bonded_valence)
        
        # Add current state first
        current = {'charge': current_charge, 'h': current_h}
        if current not in states:
            states.append(current)
            
        # Add valid alternatives
        if symbol == 'N' and current_charge == 1:  # NH+ or NH2+
            if current_h > 0:
                states.append({'charge': 1, 'h': current_h - 1})
        elif symbol == 'O' and current_charge == -1:  # O-
            states.append({'charge': -1, 'h': 0})
            if bonded_valence < 2:
                states.append({'charge': 0, 'h': 1})
        elif symbol == 'C':
            if current_h > 0:
                states.append({'charge': current_charge, 'h': current_h - 1})
                
        return states    
    
    def get_valid_modifications(self, atom: Chem.Atom) -> Dict:
        """Get valid modifications for an atom based on its type"""
        symbol = atom.GetSymbol()
        if symbol not in self.valid_states:
            return {'charges': [0], 'max_hydrogens': 0}
        return self.valid_states[symbol]

    def extract_atom_features(self, atom: Chem.Atom) -> Tuple[List[int], List[float]]:
        """Extract categorical and physical features from an atom to match model dimensions"""
        # Categorical features (2 dimensions)
        x_cat = [
            int(atom.GetAtomicNum()),
            int(atom.GetChiralTag())
        ]
        
        # Physical features (7 dimensions)
        x_phys = [
            float(atom.GetFormalCharge()),
            float(atom.GetHybridization()),
            float(atom.GetIsAromatic()),
            float(atom.GetTotalNumHs()),
            float(atom.GetTotalValence()),
            float(atom.GetDegree()),
            float(atom.IsInRing())
        ]
        
        return x_cat, x_phys

    def get_connected_atoms(self, atom: Chem.Atom) -> List[Chem.Atom]:
        """Get atoms connected to the target atom"""
        connected = []
        for bond in atom.GetBonds():
            other_atom = bond.GetOtherAtom(atom)
            connected.append((other_atom, bond.GetBondType()))
        return connected

    def generate_neighborhood(self, mol: Chem.Mol, atom_idx: int, n_samples: int = 50) -> List[Chem.Mol]:
        """Generate valid molecule perturbations"""
        neighborhoods = []
        target_atom = mol.GetAtomWithIdx(atom_idx)
        
        print(f"\nGenerating neighborhood for {target_atom.GetSymbol()}{atom_idx}")
        print(f"Current state: charge={target_atom.GetFormalCharge()}, H={target_atom.GetTotalNumHs()}")
        
        # Get valid states
        valid_states = self.get_valid_states(target_atom)
        print(f"Found {len(valid_states)} valid states:", valid_states)
        
        # Always include original
        neighborhoods.append(Chem.Mol(mol))
        
        # Try each valid state
        for state in valid_states:
            if (state['charge'] == target_atom.GetFormalCharge() and 
                state['h'] == target_atom.GetTotalNumHs()):
                continue
                
            try:
                mol_copy = Chem.Mol(mol)
                atom_copy = mol_copy.GetAtomWithIdx(atom_idx)
                
                atom_copy.SetFormalCharge(state['charge'])
                atom_copy.SetNumExplicitHs(state['h'])
                
                try:
                    # Sanitize with careful options
                    Chem.SanitizeMol(mol_copy, 
                                   sanitizeOps=Chem.SANITIZE_ADJUSTHS | 
                                             Chem.SANITIZE_PROPERTIES)
                    neighborhoods.append(mol_copy)
                    print(f"Added valid state: charge={state['charge']}, H={state['h']}")
                except Exception as e:
                    print(f"Failed sanitization: {str(e)}")
                    continue
                    
            except Exception as e:
                print(f"Failed to apply state: {str(e)}")
                continue
                
        return neighborhoods

    def _mol_to_graph_data(self, mol: Chem.Mol) -> Data:
        """Convert molecule to graph data with correct feature dimensions"""
        try:
            # Get original features
            atom_feats = []
            phys_feats = []

            for atom in mol.GetAtoms():
                # Categorical features (2D)
                atom_feat = [
                    self.atom_list.index(atom.GetAtomicNum()),
                    self.chirality_list.index(atom.GetChiralTag())
                ]

                # Physical features (8D)
                try:
                    contrib_mw = Descriptors.ExactMolWt(Chem.MolFromSmiles(f'[{atom.GetSymbol()}]'))
                    contrib_logp = Descriptors.MolLogP(Chem.MolFromSmiles(f'[{atom.GetSymbol()}]'))
                except:
                    contrib_mw = 0.0
                    contrib_logp = 0.0

                phys_feat = [
                    contrib_mw,
                    contrib_logp,
                    atom.GetFormalCharge(),
                    int(atom.GetHybridization()),
                    int(atom.GetIsAromatic()),
                    atom.GetTotalNumHs(),
                    atom.GetTotalValence(),
                    atom.GetDegree()
                ]

                atom_feats.append(atom_feat)
                phys_feats.append(phys_feat)

            x_cat = torch.tensor(atom_feats, dtype=torch.long)
            x_phys = torch.tensor(phys_feats, dtype=torch.float)

            # Get bond features (5D)
            edge_index = []
            edge_attr = []

            for bond in mol.GetBonds():
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()

                # Add edges in both directions
                edge_index += [[start, end], [end, start]]

                # Bond features
                try:
                    bond_type = self.bond_list.index(bond.GetBondType())
                    bond_dir = self.bonddir_list.index(bond.GetBondDir())
                except:
                    bond_type = 0
                    bond_dir = 0

                feat = [
                    bond_type,
                    bond_dir,
                    int(bond.GetIsConjugated()),
                    int(bond.IsInRing()),
                    int(bond.GetIsAromatic())
                ]
                edge_attr += [feat, feat]

            if not edge_index:  # Handle molecules with no bonds
                edge_index = [[0, 0]]
                edge_attr = [[0, 0, 0, 0, 0]]

            edge_index = torch.tensor(edge_index, dtype=torch.long).t()
            edge_attr = torch.tensor(edge_attr, dtype=torch.float)

            return Data(
                x_cat=x_cat,
                x_phys=x_phys,
                edge_index=edge_index,
                edge_attr=edge_attr,
                num_nodes=x_cat.size(0)
            )

        except Exception as e:
            print(f"Error in _mol_to_graph_data: {str(e)}")
            return None

    def _batch_predict(self, graph_data_list: List[Data]) -> np.ndarray:
        """Get model predictions with correct input handling"""
        if not graph_data_list:
            return np.array([])

        try:
            batch = Batch.from_data_list(graph_data_list).to(self.device)

            with torch.no_grad():
                outputs = self.model(batch)
                # Use L2 norm of embeddings as prediction
                predictions = torch.norm(outputs, dim=1).cpu().numpy()

            return predictions

        except Exception as e:
            print(f"Error in batch prediction: {str(e)}")
            traceback.print_exc()  # Add this for more detailed error info
            return np.array([])

    def explain_atom(self, mol: Chem.Mol, atom_idx: int, n_samples: int = 100) -> Dict:
        """Generate LIME explanation with improved perturbation handling"""
        try:
            # Generate perturbations
            perturbations = self.generate_perturbations(mol, atom_idx, n_samples)
            if len(perturbations) < 2:
                print(f"Could not generate explanation for atom {atom_idx}")
                return None

            # Get model predictions
            predictions = []
            valid_perturbations = []

            for atomic_feat, phys_feat in perturbations:
                try:
                    mol_copy = Chem.Mol(mol)
                    atom_copy = mol_copy.GetAtomWithIdx(atom_idx)

                    # Apply features
                    self.apply_features_to_atom(atom_copy, atomic_feat, phys_feat)

                    # Convert to graph data
                    graph_data = self._mol_to_graph_data(mol_copy)
                    if graph_data is not None:
                        batch = Batch.from_data_list([graph_data]).to(self.device)

                        with torch.no_grad():
                            output = self.model(batch)
                            pred = torch.norm(output, dim=1).cpu().numpy()[0]

                        predictions.append(pred)
                        valid_perturbations.append(atomic_feat)

                except Exception as e:
                    continue

            if len(predictions) < 2:
                print(f"Could not get valid predictions for atom {atom_idx}")
                return None

            # Calculate importance scores
            importance_scores = self._calculate_importance(
                valid_perturbations,
                np.array(predictions),
                valid_perturbations[0]
            )

            return {
                'local_importance': importance_scores,
                'local_prediction': float(predictions[0]),
                'feature_names': self.feature_names,
                'perturbed_predictions': predictions
            }

        except Exception as e:
            print(f"Error in explain_atom: {str(e)}")
            return None
    
    def apply_features_to_atom(self, atom: Chem.Atom, atomic_feat: np.ndarray, phys_feat: np.ndarray) -> None:
        """Apply features to atom with proper type conversion"""
        try:
            # Apply atomic features
            atom.SetAtomicNum(int(atomic_feat[0]))
            atom.SetChiralTag(Chem.ChiralType.values[int(atomic_feat[1])])
            atom.SetFormalCharge(int(atomic_feat[2]))
            atom.SetHybridization(Chem.HybridizationType(int(atomic_feat[3])))
            atom.SetIsAromatic(bool(atomic_feat[4]))
            atom.SetNumExplicitHs(int(atomic_feat[5]))

        except Exception as e:
            raise ValueError(f"Error applying features: {str(e)}")

    def _calculate_importance(self, features: List[np.ndarray], 
                            predictions: np.ndarray,
                            base_features: np.ndarray) -> Dict[str, float]:
        """Calculate feature importance with improved diversity handling"""
        if len(predictions) < 2:
            return {i: 0.0 for i in range(len(self.feature_names))}

        try:
            features_array = np.array(features)
            base_pred = predictions[0]
            pred_changes = predictions[1:] - base_pred

            # Initialize importance scores
            importance_scores = {}

            for i, feat_name in enumerate(self.feature_names):
                try:
                    feat_values = features_array[:, i]
                    if len(np.unique(feat_values)) < 2:
                        importance_scores[i] = 0.0
                        continue

                    feat_changes = features_array[1:, i] - features_array[0, i]

                    # Calculate multiple metrics

                    # 1. Correlation importance
                    correlation = np.corrcoef(feat_changes, pred_changes)[0, 1]
                    if np.isnan(correlation):
                        correlation = 0.0

                    # 2. Effect size importance
                    effect_size = np.mean(np.abs(pred_changes[feat_changes != 0])) / (np.std(predictions) + 1e-6)

                    # 3. Feature uniqueness
                    unique_ratio = len(np.unique(feat_values)) / len(feat_values)

                    # 4. Predictive power
                    pred_power = np.abs(np.mean(pred_changes[feat_changes != 0])) / (np.mean(np.abs(pred_changes)) + 1e-6)

                    # Combine metrics
                    importance = (
                        0.3 * correlation +
                        0.3 * np.tanh(effect_size) +
                        0.2 * unique_ratio +
                        0.2 * pred_power
                    )

                    importance_scores[i] = np.clip(importance, -1.0, 1.0)

                except Exception as e:
                    print(f"Error calculating importance for feature {feat_name}: {str(e)}")
                    importance_scores[i] = 0.0

            # Normalize scores
            abs_scores = np.abs(list(importance_scores.values()))
            max_abs = max(abs_scores) if len(abs_scores) > 0 else 1.0

            if max_abs > 0:
                importance_scores = {k: v/max_abs for k, v in importance_scores.items()}

            return importance_scores

        except Exception as e:
            print(f"Error in _calculate_importance: {str(e)}")
            return {i: 0.0 for i in range(len(self.feature_names))}

def visualize_lime_results(symbol:str, atom_idx: int, explanation: Dict, save_path: str = None):
    """Improved visualization of LIME results"""
    if explanation is None:
        print(f"No valid explanation for atom {atom_idx}")
        return
        
    plt.figure(figsize=(12, 6))
    
    # Get importance scores
    importance_scores = []
    feature_names = explanation['feature_names']
    for i in range(len(feature_names)):
        score = explanation['local_importance'].get(i, 0.0)
        importance_scores.append(score)
    
    # Sort by absolute importance
    feat_importance = list(zip(feature_names, importance_scores))
    feat_importance.sort(key=lambda x: abs(x[1]), reverse=True)
    
    features, scores = zip(*feat_importance)
    
    # Create bar plot with improved styling
    y_pos = np.arange(len(features))
    
    # Custom colormap for importance scores
    colors = []
    for score in scores:
        if score > 0:
            colors.append('#2ecc71')  # Green for positive
        elif score < 0:
            colors.append('#e74c3c')  # Red for negative
        else:
            colors.append('#95a5a6')  # Gray for zero
    
    bars = plt.barh(y_pos, scores, align='center', color=colors, alpha=0.8)
    plt.yticks(y_pos, features)
    
    # Add value labels with improved formatting
    for i, bar in enumerate(bars):
        width = bar.get_width()
        if abs(width) >= 0.01:  # Only show non-zero values
            label_pos = width + np.sign(width) * 0.01
            plt.text(label_pos, bar.get_y() + bar.get_height()/2,
                    f'{width:.3f}',
                    va='center',
                    ha='left' if width >= 0 else 'right',
                    fontsize=10)
    
    plt.title(f'LIME Feature Importance for Node {symbol} {atom_idx}',
              fontsize=12, pad=20)
    plt.xlabel('Importance Score', fontsize=11)
    
    # Add grid and style improvements
    plt.grid(True, axis='x', linestyle='--', alpha=0.3)
    plt.axvline(x=0, color='black', linewidth=0.5, alpha=0.5)
    
    # Adjust layout and save
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def run_lime_analysis(dataset, encoder, device):
    """Run LIME analysis on a molecule"""
    save_dir = f'Lime_Explanations\lime_explanations_{timestamp}'
    os.makedirs(save_dir, exist_ok=True)
    
    # Get molecule from dataset
    mol_data = dataset[1]  
    smiles = mol_data.smiles
    mol = Chem.MolFromSmiles(smiles)
    
    lime_explainer = MolecularLIME(encoder, device)
    important_atoms = [2, 9, 13, 15]  # Example atoms to analyze
    results = {}
    
    for atom_idx in important_atoms:
        atom = mol.GetAtomWithIdx(atom_idx)
        print(f"\nAnalyzing atom {atom.GetSymbol()}{atom_idx}...")
        
        symbol = atom.GetSymbol()
        print(f"the symbol is :",symbol)
        
        explanation = lime_explainer.explain_atom(mol, atom_idx, n_samples=50)
        
        if explanation is None:
            print(f"Could not generate explanation for atom {atom_idx}")
            continue
            
        results[atom_idx] = explanation
        
        save_path = os.path.join(save_dir, f'lime_importance_atom_{atom_idx}.png')
        visualize_lime_results(symbol,atom_idx, explanation, save_path)
        
        # Print top 5 important features
        importance_pairs = [(name, explanation['local_importance'].get(i, 0.0))
                          for i, name in enumerate(explanation['feature_names'])]
        sorted_features = sorted(importance_pairs, key=lambda x: abs(x[1]), reverse=True)[:5]
        
        print("\nTop 5 most important features:")
        for feat_name, importance in sorted_features:
            print(f"  {feat_name}: {importance:.4f}")
    
    return {
        'molecule_smiles': smiles,
        'important_atoms': important_atoms,
        'results': results
    }

In [4]:
    if __name__ == "__main__":


        print("Starting data loading...")
        extractor = MolecularFeatureExtractor()
        smiles_file = "D:\\PhD\\Chapter3\\Unsupervised_GAN_Code\\pubchem-41-clean.txt"

        dataset = []
        failed_smiles = []

        # Modified data loading to store SMILES
        with open(smiles_file, 'r') as f:
            for line in f:
                smiles = line.strip()
                data = extractor.process_molecule(smiles)
                if data is not None:
                    # Add SMILES as an attribute to the Data object
                    data.smiles = smiles  # Add this line
                    dataset.append(data)
                else:
                    failed_smiles.append(smiles)

        print(f"1. Loaded dataset with {len(dataset)} graphs.")
        print(f"2. Failed SMILES count: {len(failed_smiles)}")

        if not dataset:
            print("No valid graphs generated.")

        # Make sure to import needed libraries
        from rdkit import Chem
        from rdkit.Chem import AllChem, Draw
        import traceback

        os.makedirs('molecule_explanation', exist_ok=True)
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

        try:
            print("\nStarting molecule analysis...")    

            results = run_lime_analysis(dataset, encoder, device)

            if results is not None:
                print("\nAnalysis Summary:")
                print(f"Analyzed results:", results)
            else:
                print("\nAnalysis failed!")

        except Exception as e:
            print(f"\nError in main execution: {str(e)}")
            import traceback
            traceback.print_exc()    

Starting data loading...
1. Loaded dataset with 41 graphs.
2. Failed SMILES count: 0

Starting molecule analysis...

Analyzing atom N2...
the symbol is : N
Error calculating importance for feature Degree: index 7 is out of bounds for axis 1 with size 7
Error calculating importance for feature InRing: index 8 is out of bounds for axis 1 with size 7


[22:05:39] non-ring atom 2 marked aromatic
[22:05:39] non-ring atom 2 marked aromatic
[22:05:39] non-ring atom 2 marked aromatic
[22:05:39] non-ring atom 2 marked aromatic
[22:05:39] non-ring atom 2 marked aromatic
[22:05:39] non-ring atom 2 marked aromatic
[22:05:39] Explicit valence for atom # 2 C, 4, is greater than permitted
[22:05:39] non-ring atom 2 marked aromatic
[22:05:39] non-ring atom 2 marked aromatic
[22:05:39] Explicit valence for atom # 2 C, 4, is greater than permitted
[22:05:39] non-ring atom 2 marked aromatic
[22:05:39] non-ring atom 2 marked aromatic
[22:05:39] non-ring atom 2 marked aromatic
[22:05:39] non-ring atom 2 marked aromatic
[22:05:39] non-ring atom 2 marked aromatic
[22:05:39] non-ring atom 2 marked aromatic
[22:05:39] non-ring atom 2 marked aromatic
[22:05:39] non-ring atom 2 marked aromatic
[22:05:39] non-ring atom 2 marked aromatic
[22:05:39] non-ring atom 2 marked aromatic
[22:05:39] non-ring atom 2 marked aromatic
[22:05:39] non-ring atom 2 marked aro


Top 5 most important features:
  AtomicNum: 1.0000
  Hybridization: 0.7032
  NumHs: 0.6431
  Valence: 0.6431
  FormalCharge: 0.5240

Analyzing atom N9...
the symbol is : N
Error calculating importance for feature Degree: index 7 is out of bounds for axis 1 with size 7
Error calculating importance for feature InRing: index 8 is out of bounds for axis 1 with size 7


[22:05:39] non-ring atom 9 marked aromatic
[22:05:39] non-ring atom 9 marked aromatic
[22:05:39] non-ring atom 9 marked aromatic
[22:05:39] non-ring atom 9 marked aromatic
[22:05:39] non-ring atom 9 marked aromatic
[22:05:39] non-ring atom 9 marked aromatic
[22:05:39] non-ring atom 9 marked aromatic
[22:05:39] non-ring atom 9 marked aromatic
[22:05:39] non-ring atom 9 marked aromatic
[22:05:39] Explicit valence for atom # 9 C, 4, is greater than permitted
[22:05:39] non-ring atom 9 marked aromatic
[22:05:39] non-ring atom 9 marked aromatic
[22:05:39] non-ring atom 9 marked aromatic
[22:05:39] Explicit valence for atom # 9 C, 4, is greater than permitted
[22:05:39] non-ring atom 9 marked aromatic
[22:05:39] non-ring atom 9 marked aromatic
[22:05:39] non-ring atom 9 marked aromatic
[22:05:39] non-ring atom 9 marked aromatic
[22:05:39] non-ring atom 9 marked aromatic
[22:05:39] non-ring atom 9 marked aromatic
[22:05:39] non-ring atom 9 marked aromatic
[22:05:39] non-ring atom 9 marked aro


Top 5 most important features:
  AtomicNum: 1.0000
  NumHs: 0.5965
  Valence: 0.5965
  Hybridization: 0.5809
  FormalCharge: 0.4878

Analyzing atom C13...
the symbol is : C
Could not generate explanation for atom 13
Could not generate explanation for atom 13

Analyzing atom O15...
the symbol is : O
Error calculating importance for feature Degree: index 7 is out of bounds for axis 1 with size 7
Error calculating importance for feature InRing: index 8 is out of bounds for axis 1 with size 7


[22:05:40] non-ring atom 13 marked aromatic
[22:05:40] non-ring atom 13 marked aromatic
[22:05:40] non-ring atom 13 marked aromatic
[22:05:40] Explicit valence for atom # 13 N, 4, is greater than permitted
[22:05:40] non-ring atom 13 marked aromatic
[22:05:40] Explicit valence for atom # 13 N, 4, is greater than permitted
[22:05:40] non-ring atom 15 marked aromatic
[22:05:40] non-ring atom 15 marked aromatic
[22:05:40] non-ring atom 15 marked aromatic
[22:05:40] Explicit valence for atom # 15 N, 3, is greater than permitted
[22:05:40] non-ring atom 15 marked aromatic
[22:05:40] non-ring atom 15 marked aromatic
[22:05:40] non-ring atom 15 marked aromatic
[22:05:40] Explicit valence for atom # 15 N, 3, is greater than permitted
[22:05:40] non-ring atom 15 marked aromatic
[22:05:40] non-ring atom 15 marked aromatic
[22:05:40] non-ring atom 15 marked aromatic
[22:05:40] non-ring atom 15 marked aromatic
[22:05:40] non-ring atom 15 marked aromatic
[22:05:40] non-ring atom 15 marked aromatic



Top 5 most important features:
  Hybridization: 1.0000
  AtomicNum: 0.8959
  NumHs: 0.7275
  Valence: 0.7275
  FormalCharge: 0.4091

Analysis Summary:
Analyzed results: {'molecule_smiles': 'CC[NH+](CC)C1CCC([NH2+]C2CC2)(C(=O)[O-])C1', 'important_atoms': [2, 9, 13, 15], 'results': {2: {'local_importance': {0: np.float64(1.0), 1: np.float64(0.0), 2: np.float64(0.5239678200699016), 3: np.float64(0.7032025401845149), 4: np.float64(0.0), 5: np.float64(0.6431066937670791), 6: np.float64(0.6431066937670791), 7: np.float64(0.0), 8: np.float64(0.0)}, 'local_prediction': 4.899072647094727, 'feature_names': ['AtomicNum', 'ChiralTag', 'FormalCharge', 'Hybridization', 'IsAromatic', 'NumHs', 'Valence', 'Degree', 'InRing'], 'perturbed_predictions': [np.float32(4.8990726), np.float32(4.8066216), np.float32(4.822563), np.float32(4.8066216), np.float32(4.822563), np.float32(4.830337), np.float32(4.830337), np.float32(4.865579), np.float32(4.865579), np.float32(4.884666), np.float32(4.884666), np.float3