In [1]:
import torch
import pickle
from rdkit import Chem
from rdkit.Chem import Draw
import shap
import numpy as np
import torch.nn.functional as F
from torch_geometric.data import DataLoader
import torch.nn as nn
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing
from sklearn.metrics.pairwise import cosine_similarity
from typing import Dict, List, Tuple, Optional
from torch_geometric.data import Data, Batch
from rdkit.Chem import RemoveHs
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from rdkit import RDLogger
# Suppress RDKit warnings
RDLogger.DisableLog('rdApp.warning')
import tensorflow as tensorflow
import traceback
import os
from datetime import datetime
from rdkit.Chem import rdDepictor
import matplotlib.pyplot as plt
from rdkit.Chem import AllChem, Draw, rdDepictor
from matplotlib.colors import LinearSegmentedColormap


class GraphDiscriminator(nn.Module):
    """Graph encoder rebuilt to match the original discriminator's layer layout.

    The module owns a node MLP, an edge MLP, three ``GCNConv`` layers and a
    projection MLP. ``forward`` returns one ``output_dim``-wide vector per
    graph in the batch.
    """

    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128, output_dim: int = 128):
        super().__init__()

        def two_layer_mlp(in_features: int, mid_features: int, out_features: int) -> nn.Sequential:
            # Linear -> ReLU -> Linear, the shape shared by all three MLPs below.
            return nn.Sequential(
                nn.Linear(in_features, mid_features),
                nn.ReLU(),
                nn.Linear(mid_features, out_features),
            )

        # Sub-modules are created in the same order as before so parameter
        # names and initialisation order stay the same.
        self.node_encoder = two_layer_mlp(node_dim, hidden_dim, hidden_dim)
        self.edge_encoder = two_layer_mlp(edge_dim, hidden_dim, hidden_dim)

        # Message-passing stack; only the last layer changes the width.
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, output_dim)

        # Head applied to the pooled graph vector.
        self.projection = two_layer_mlp(output_dim, hidden_dim, output_dim)

    def forward(self, data):
        """Encode a batch of graphs.

        Reads ``x_cat``, ``x_phys``, ``edge_index``, ``edge_attr`` and
        ``batch`` from ``data`` and returns the projected, mean-pooled node
        representation.
        """
        node_input = torch.cat([data.x_cat.float(), data.x_phys], dim=-1)
        connectivity = data.edge_index
        raw_edge_attr = data.edge_attr.float()
        graph_ids = data.batch

        hidden = self.node_encoder(node_input)
        # The encoded edge features are computed but not handed to the
        # convolutions below, which only receive node states and connectivity.
        _ = self.edge_encoder(raw_edge_attr)

        for layer in (self.conv1, self.conv2):
            hidden = F.relu(layer(hidden, connectivity))
        hidden = self.conv3(hidden, connectivity)

        pooled = global_mean_pool(hidden, graph_ids)
        return self.projection(pooled)

# Load Encoder
def load_encoder(model_path, device='cpu'):
    """Rebuild a GraphDiscriminator from a checkpoint file and return it in eval mode.

    The checkpoint must contain a ``'model_info'`` mapping (sizes) and an
    ``'encoder_state_dict'`` entry (weights). ``hidden_dim`` and
    ``output_dim`` fall back to 128 when absent from ``'model_info'``.
    """
    checkpoint = torch.load(model_path, map_location=device)
    info = checkpoint['model_info']

    constructor_kwargs = {
        'node_dim': info.get('node_dim'),
        'edge_dim': info.get('edge_dim'),
        'hidden_dim': info.get('hidden_dim', 128),
        'output_dim': info.get('output_dim', 128),
    }
    encoder = GraphDiscriminator(**constructor_kwargs)

    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    encoder.eval()
    encoder = encoder.to(device)
    return encoder

# Load Embeddings
def load_embeddings(filepath):
    """Read a pickled mapping and return its ``'embeddings'`` and ``'labels'`` entries as a pair."""
    # pickle can execute arbitrary code while loading; only open files you trust.
    with open(filepath, 'rb') as handle:
        payload = pickle.load(handle)
    return tuple(payload[key] for key in ('embeddings', 'labels'))

# Locations of the saved encoder checkpoint and the pickled embeddings,
# relative to the notebook's working directory.
encoder_path = './checkpoints/encoders/final_encoder_20250216_111050.pt'
embedding_path = './embeddings/final_embeddings_20250216_111005.pkl'
# Use the GPU when torch reports one, otherwise stay on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load encoder and embeddings.
# NOTE(review): load_embeddings returns the pickle's 'labels' entry as its
# second value, which is bound to `graph_data` here - confirm what it holds.
encoder = load_encoder(encoder_path, device)
embeddings, graph_data = load_embeddings(embedding_path)

  checkpoint = torch.load(model_path, map_location=device)


In [2]:
class MolecularFeatureExtractor:
    """Convert SMILES strings into ``torch_geometric`` ``Data`` graphs.

    Every atom contributes two categorical indices (element, chirality) and
    ``NUM_PHYS_FEATURES`` physical values. Every bond contributes
    ``NUM_BOND_FEATURES`` values and is stored once per direction.
    """

    # Width of the physical feature vector built by calc_atom_features.
    NUM_PHYS_FEATURES = 8
    # Width of the per-bond feature vector built by get_bond_features.
    NUM_BOND_FEATURES = 5

    def __init__(self):
        # Lookup tables; features store the *index* of a value in these lists.
        self.atom_list = list(range(1, 119))
        self.chirality_list = [
            Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
            Chem.rdchem.ChiralType.CHI_OTHER
        ]
        self.bond_list = [
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE,
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC
        ]
        self.bonddir_list = [
            Chem.rdchem.BondDir.NONE,
            Chem.rdchem.BondDir.ENDUPRIGHT,
            Chem.rdchem.BondDir.ENDDOWNRIGHT
        ]

    @staticmethod
    def _element_descriptor(atom, descriptor) -> float:
        """Evaluate ``descriptor`` on a one-atom molecule of ``atom``'s element.

        Best effort: any failure (unparsable symbol, descriptor error) yields 0.0.
        """
        try:
            return descriptor(Chem.MolFromSmiles(f'[{atom.GetSymbol()}]'))
        except Exception:
            return 0.0

    def calc_atom_features(self, atom: Chem.Atom) -> Tuple[list, list]:
        """Return ``(categorical, physical)`` feature lists for one atom.

        ``categorical`` holds the element index and chirality index;
        ``physical`` holds ``NUM_PHYS_FEATURES`` values. On any error a
        message is printed and all-zero lists of the same lengths are
        returned, so the rows of a molecule always have a uniform width.
        """
        try:
            atom_feat = [
                self.atom_list.index(atom.GetAtomicNum()),
                self.chirality_list.index(atom.GetChiralTag())
            ]

            phys_feat = [
                # Element-level descriptors, each falling back to 0.0 on failure.
                self._element_descriptor(atom, Descriptors.ExactMolWt),
                self._element_descriptor(atom, Descriptors.MolLogP),
                atom.GetFormalCharge(),
                int(atom.GetHybridization()),
                int(atom.GetIsAromatic()),
                atom.GetTotalNumHs(),
                atom.GetTotalValence(),
                atom.GetDegree()
            ]

            return atom_feat, phys_feat

        except Exception as e:
            print(f"Error calculating atom features: {e}")
            # Same width as the normal path; a different width would make
            # the per-molecule feature matrix ragged.
            return [0, 0], [0.0] * self.NUM_PHYS_FEATURES

    def get_atom_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build the categorical (long) and physical (float) node tensors for ``mol``.

        ``mol is None`` yields a single all-zero placeholder node.
        """
        if mol is None:
            return (torch.tensor([[0, 0]], dtype=torch.long),
                    torch.tensor([[0.0] * self.NUM_PHYS_FEATURES], dtype=torch.float))

        atom_feats = []
        phys_feats = []
        for atom in mol.GetAtoms():
            atom_feat, phys_feat = self.calc_atom_features(atom)
            atom_feats.append(atom_feat)
            phys_feats.append(phys_feat)

        x = torch.tensor(atom_feats, dtype=torch.long)
        phys = torch.tensor(phys_feats, dtype=torch.float)

        return x, phys

    @staticmethod
    def remove_unbonded_hydrogens(mol):
        """Return ``mol`` with hydrogens removed, including degree-zero ones.

        Declared static because it takes no ``self``; it can be called on the
        class or on an instance.
        """
        params = Chem.RemoveHsParameters()
        params.removeDegreeZero = True
        mol = Chem.RemoveHs(mol, params)
        return mol

    def get_bond_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build ``edge_index`` (2 x E, long) and ``edge_attr`` (E x 5, float).

        Each bond is emitted in both directions with identical features. When
        a bond's features cannot be encoded (for example a bond type or
        direction missing from the lookup lists) a message is printed and the
        bond is kept with an all-zero feature row, so that ``edge_index`` and
        ``edge_attr`` always describe the same edges. A ``None`` molecule or a
        molecule without bonds yields a single placeholder self-loop on node 0.
        """
        if mol is None:
            return (torch.tensor([[0], [0]], dtype=torch.long),
                    torch.tensor([[0.0] * self.NUM_BOND_FEATURES], dtype=torch.float))

        row, col, edge_feat = [], [], []

        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()

            # Encode first, register the edge afterwards: the connectivity and
            # the attribute rows must grow together.
            try:
                feat = [
                    self.bond_list.index(bond.GetBondType()),
                    self.bonddir_list.index(bond.GetBondDir()),
                    int(bond.GetIsConjugated()),
                    int(self._is_rotatable(bond)),
                    self._get_bond_length(mol, start, end)
                ]
            except Exception as e:
                print(f"Error processing bond: {e}")
                feat = [0.0] * self.NUM_BOND_FEATURES

            # Add edges in both directions
            row += [start, end]
            col += [end, start]
            edge_feat.extend([feat, feat])

        if not row:  # Molecule without bonds
            return (torch.tensor([[0], [0]], dtype=torch.long),
                    torch.tensor([[0.0] * self.NUM_BOND_FEATURES], dtype=torch.float))

        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(edge_feat, dtype=torch.float)

        return edge_index, edge_attr

    def _is_rotatable(self, bond: Chem.Bond) -> bool:
        """Return True for a non-ring single bond whose two atoms each have more than one neighbour."""
        return (bond.GetBondType() == Chem.rdchem.BondType.SINGLE and
                not bond.IsInRing() and
                len(bond.GetBeginAtom().GetNeighbors()) > 1 and
                len(bond.GetEndAtom().GetNeighbors()) > 1)

    def _get_bond_length(self, mol: Chem.Mol, start: int, end: int) -> float:
        """Return the bond length from the molecule's 3D conformer, or 0.0.

        Best effort: 0.0 is returned when there is no conformer, when the
        conformer is not 3D, or when the length cannot be computed.
        """
        # Imported here so the method does not rely on another import having
        # attached the submodule to ``Chem``.
        from rdkit.Chem import rdMolTransforms

        try:
            conf = mol.GetConformer()
            if conf.Is3D():
                return rdMolTransforms.GetBondLength(conf, start, end)
        except Exception:
            pass
        return 0.0

    def process_molecule(self, smiles: str) -> Data:
        """Parse ``smiles``, embed it in 3D and return its graph, or None on failure.

        Failures (invalid SMILES, empty molecule, failed embedding or any
        exception) are reported with a printed message and a ``None`` return.
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                print(f"Invalid SMILES: {smiles}")
                return None  # Skip invalid molecules
            mol = RemoveHs(mol)

            # Add explicit hydrogens
            mol = Chem.AddHs(mol, addCoords=True)

            # Sanitize molecule
            Chem.SanitizeMol(mol)

            # Check if the molecule has atoms
            if mol.GetNumAtoms() == 0:
                print("Molecule has no atoms, skipping.")
                return None

            # Generate 3D coordinates
            if not mol.GetNumConformers():
                status = AllChem.EmbedMolecule(mol, AllChem.ETKDG())
                if status != 0:
                    print("Failed to generate 3D conformer")
                    return None  # Skip failed molecules

                # Optimise with MMFF; fall back to UFF if MMFF raises.
                # NOTE(review): the MMFF return code is not inspected, so a
                # failure reported through the return value does not trigger
                # the UFF fallback - confirm whether that is intended.
                try:
                    AllChem.MMFFOptimizeMolecule(mol)
                except Exception:
                    AllChem.UFFOptimizeMolecule(mol)

            # Extract features
            x_cat, x_phys = self.get_atom_features(mol)
            edge_index, edge_attr = self.get_bond_features(mol)

            return Data(
                x_cat=x_cat,
                x_phys=x_phys,
                edge_index=edge_index,
                edge_attr=edge_attr,
                num_nodes=x_cat.size(0)
            )

        except Exception as e:
            print(f"Error processing molecule {smiles}: {e}")
            return None


In [3]:
import torch
from torch_geometric.data import Data, Batch
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from typing import List, Tuple, Dict
import matplotlib.pyplot as plt
from datetime import datetime
import os

class MolecularLIME:
    """Per-atom, perturbation-based explanation helper for a graph encoder.

    For a chosen atom the class builds a small "neighbourhood" of molecules
    that differ in that atom's formal charge / hydrogen count, runs every
    variant through the model, and scores each atom feature by its
    correlation with the model output (the L2 norm of the embedding).
    """

    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.model.eval()

        # Per-element bonding patterns with the hydrogen count and charge
        # associated with each pattern.
        self.atom_valence_states = {
            'N': {  # Nitrogen
                'single_bond': {'max_h': 3, 'charge': 0},  # sp3 N
                'double_bond': {'max_h': 2, 'charge': 0},  # sp2 N
                'triple_bond': {'max_h': 1, 'charge': 0},  # sp N
                'positively_charged': {'max_h': 4, 'charge': 1},  # NH4+
            },
            'O': {  # Oxygen
                'single_bond': {'max_h': 2, 'charge': 0},  # sp3 O
                'double_bond': {'max_h': 0, 'charge': 0},  # sp2 O
                'negatively_charged': {'max_h': 0, 'charge': -1},  # O-
            },
            'C': {  # Carbon
                'single_bond': {'max_h': 4, 'charge': 0},  # sp3 C
                'double_bond': {'max_h': 2, 'charge': 0},  # sp2 C
                'triple_bond': {'max_h': 1, 'charge': 0},  # sp C
            }
        }

        # Names of the ten entries produced by extract_features, in order.
        self.feature_names = [
            'AtomicNum',
            'FormalCharge',
            'Hybridization',
            'IsAromatic',
            'NumHs',
            'Valence',
            'Degree',
            'InRing',
            'Mass',
            'Extra'
        ]

        # Allowed charges and limits per element, served by
        # get_valid_modifications.
        self.valid_states = {
            'N': {  # Nitrogen
                'charges': [-1, 0, 1],
                'max_valence': 3,
                'max_hydrogens': 3
            },
            'C': {  # Carbon
                'charges': [-1, 0, 1],
                'max_valence': 4,
                'max_hydrogens': 4
            },
            'O': {  # Oxygen
                'charges': [-1, 0, 1],
                'max_valence': 2,
                'max_hydrogens': 2
            }
        }

    def is_valid_state(self, atom: Chem.Atom, charge: int, h_count: int) -> bool:
        """Return True if ``h_count`` hydrogens fit on ``atom`` at the given charge.

        The check compares (sum of bond orders + ``h_count``) with a per-element
        limit: 4/3 for N (charge +1 / otherwise), 1/2 for O (charge -1 /
        otherwise), 4 for C, and ``get_max_valence`` for everything else.
        """
        symbol = atom.GetSymbol()
        bonded_valence = sum(bond.GetBondTypeAsDouble() for bond in atom.GetBonds())
        total_valence = bonded_valence + h_count

        if symbol == 'N':
            limit = 4 if charge == 1 else 3
        elif symbol == 'O':
            limit = 1 if charge == -1 else 2
        elif symbol == 'C':
            limit = 4
        else:
            limit = self.get_max_valence(symbol)
        return total_valence <= limit

    def get_max_valence(self, symbol: str) -> int:
        """Return the valence limit used for ``symbol`` (4 for unlisted elements)."""
        common_valences = {
            'H': 1, 'C': 4, 'N': 3, 'O': 2, 'F': 1,
            'P': 5, 'S': 6, 'Cl': 1, 'Br': 1, 'I': 1
        }
        return common_valences.get(symbol, 4)  # Default to 4 for unknown atoms

    def extract_features(self, atom: Chem.Atom) -> np.ndarray:
        """Return the ten-entry float32 feature vector described by ``feature_names``.

        The last entry ('Extra') is always 0.0.
        """
        return np.array(
            [
                float(atom.GetAtomicNum()),
                float(atom.GetFormalCharge()),
                float(atom.GetHybridization()),
                float(atom.GetIsAromatic()),
                float(atom.GetTotalNumHs()),
                float(atom.GetTotalValence()),
                float(atom.GetDegree()),
                float(atom.IsInRing()),
                float(atom.GetMass()),
                0.0,
            ],
            dtype=np.float32,
        )

    def get_valid_states(self, atom: Chem.Atom) -> List[Dict[str, int]]:
        """List candidate ``{'charge', 'h'}`` states for ``atom``, current state first.

        Only N, O and C get alternatives; other elements return at most their
        current state. Duplicates are removed while keeping order, and the
        result is printed before being returned.
        """
        symbol = atom.GetSymbol()
        current_charge = atom.GetFormalCharge()
        current_h = atom.GetTotalNumHs()
        bonded_valence = sum(bond.GetBondTypeAsDouble() for bond in atom.GetBonds())

        valid_states = []

        # Always include current state if it's chemically valid
        if self.is_valid_state(atom, current_charge, current_h):
            valid_states.append({'charge': current_charge, 'h': current_h})

        if symbol == 'N':
            if current_charge == 1:
                # One hydrogen fewer, charge left at +1.
                # NOTE(review): removing a proton would normally also change
                # the charge; confirm that keeping +1 is intended.
                if current_h > 0:
                    valid_states.append({'charge': 1, 'h': current_h - 1})
            else:
                if bonded_valence + current_h <= 3:
                    valid_states.append({'charge': 0, 'h': current_h})

        elif symbol == 'O':
            if current_charge == -1:  # O-
                valid_states.append({'charge': -1, 'h': 0})
                # Protonation to OH is only offered while a valence slot is free.
                if bonded_valence < 2:
                    valid_states.append({'charge': 0, 'h': 1})
            else:
                if bonded_valence + current_h <= 2:
                    valid_states.append({'charge': 0, 'h': current_h})

        elif symbol == 'C':
            # Carbon can have max 4 bonds total
            if bonded_valence + current_h <= 4:
                valid_states.append({'charge': 0, 'h': current_h})
                if current_h > 0:  # Can remove H if present
                    valid_states.append({'charge': 0, 'h': current_h - 1})

        # Remove duplicates while preserving order
        seen = set()
        unique_states = []
        for state in valid_states:
            key = (state['charge'], state['h'])
            if key not in seen:
                seen.add(key)
                unique_states.append(state)

        print(f"Generated {len(unique_states)} valid states for {symbol}{atom.GetIdx()}: {unique_states}")
        return unique_states

    def get_valid_modifications(self, atom: Chem.Atom) -> Dict:
        """Return the ``valid_states`` entry for the atom's element.

        Elements without an entry get ``{'charges': [0], 'max_hydrogens': 0}``.
        """
        symbol = atom.GetSymbol()
        if symbol not in self.valid_states:
            return {'charges': [0], 'max_hydrogens': 0}
        return self.valid_states[symbol]

    def extract_atom_features(self, atom: Chem.Atom) -> Tuple[List[int], List[float]]:
        """Return ``(x_cat, x_phys)``: two categorical and seven physical values for ``atom``."""
        # Categorical features (2 dimensions)
        x_cat = [
            int(atom.GetAtomicNum()),
            int(atom.GetChiralTag())
        ]

        # Physical features (7 dimensions)
        x_phys = [
            float(atom.GetFormalCharge()),
            float(atom.GetHybridization()),
            float(atom.GetIsAromatic()),
            float(atom.GetTotalNumHs()),
            float(atom.GetTotalValence()),
            float(atom.GetDegree()),
            float(atom.IsInRing())
        ]

        return x_cat, x_phys

    def get_connected_atoms(self, atom: Chem.Atom) -> List[Tuple[Chem.Atom, Chem.rdchem.BondType]]:
        """Return ``(neighbour atom, bond type)`` pairs for every bond of ``atom``."""
        return [(bond.GetOtherAtom(atom), bond.GetBondType()) for bond in atom.GetBonds()]

    def generate_neighborhood(self, mol: Chem.Mol, atom_idx: int, n_samples: int = 50) -> List[Chem.Mol]:
        """Return the original molecule followed by its sanitizable variants.

        A variant is a copy of ``mol`` in which atom ``atom_idx`` has been set
        to one of the states from ``get_valid_states``. States equal to the
        current one are skipped, as are variants that cannot be built or
        sanitized. ``n_samples`` is accepted for interface compatibility but
        is not used: the neighbourhood size is set by the number of states.
        """
        neighborhoods = []
        target_atom = mol.GetAtomWithIdx(atom_idx)

        print(f"\nGenerating neighborhood for {target_atom.GetSymbol()}{atom_idx}")
        print(f"Current state: charge={target_atom.GetFormalCharge()}, H={target_atom.GetTotalNumHs()}")

        # Get valid states
        valid_states = self.get_valid_states(target_atom)
        print(f"Found {len(valid_states)} valid states:", valid_states)

        # Always include original
        neighborhoods.append(Chem.Mol(mol))

        for state in valid_states:
            if (state['charge'] == target_atom.GetFormalCharge() and
                    state['h'] == target_atom.GetTotalNumHs()):
                continue

            try:
                mol_copy = Chem.Mol(mol)
                atom_copy = mol_copy.GetAtomWithIdx(atom_idx)
                atom_copy.SetFormalCharge(state['charge'])
                atom_copy.SetNumExplicitHs(state['h'])
            except Exception as e:
                print(f"Failed to apply state: {str(e)}")
                continue

            try:
                # Restrict sanitization to hydrogen adjustment and property checks.
                Chem.SanitizeMol(mol_copy,
                                 sanitizeOps=Chem.SANITIZE_ADJUSTHS |
                                             Chem.SANITIZE_PROPERTIES)
            except Exception as e:
                print(f"Failed sanitization: {str(e)}")
                continue

            neighborhoods.append(mol_copy)
            print(f"Added valid state: charge={state['charge']}, H={state['h']}")

        return neighborhoods

    def _mol_to_graph_data(self, mol: Chem.Mol) -> Data:
        """Convert ``mol`` to a ``Data`` graph with ``x_cat``, ``x_phys``, ``edge_index``, ``edge_attr``.

        ``x_cat`` holds the atomic number only; ``x_phys`` holds the remaining
        nine entries of ``extract_features``. Bonds are stored in both
        directions with five features each; a molecule without bonds gets one
        placeholder self-loop on node 0.

        NOTE(review): this encoding differs from ``MolecularFeatureExtractor``
        in this notebook (raw atomic number and 1 + 9 columns here versus
        lookup indices and 2 + 8 columns there). Confirm which encoding the
        encoder was trained on.
        """
        # Get atom features
        x_cat = []
        phys_rows = []
        for atom in mol.GetAtoms():
            features = self.extract_features(atom)
            x_cat.append([int(features[0])])  # Atomic number only
            phys_rows.append(features[1:])  # Rest of features

        x_cat = torch.tensor(x_cat, dtype=torch.long)
        # Merge the per-atom arrays into a single ndarray first; handing
        # torch.tensor a list of ndarrays is slow and emits a warning.
        x_phys = torch.tensor(np.asarray(phys_rows, dtype=np.float32), dtype=torch.float)

        # Get bond features
        edge_index = []
        edge_attr = []
        for bond in mol.GetBonds():
            i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            edge_index += [[i, j], [j, i]]

            feat = [
                float(bond.GetBondTypeAsDouble()),
                float(bond.GetIsConjugated()),
                float(bond.IsInRing()),
                float(bond.GetIsAromatic()),
                float(bond.GetBondTypeAsDouble())
            ]
            edge_attr += [feat, feat]

        if not edge_index:
            edge_index = [[0, 0]]
            edge_attr = [[0, 0, 0, 0, 0]]

        edge_index = torch.tensor(edge_index, dtype=torch.long).t()
        edge_attr = torch.tensor(edge_attr, dtype=torch.float)

        return Data(
            x_cat=x_cat,
            x_phys=x_phys,
            edge_index=edge_index,
            edge_attr=edge_attr,
            num_nodes=x_cat.size(0)
        )

    def _batch_predict(self, graph_data_list: List[Data]) -> np.ndarray:
        """Return one scalar per graph: the L2 norm of the model output.

        An empty input list yields an empty array.
        """
        if not graph_data_list:
            return np.array([])

        batch = Batch.from_data_list(graph_data_list).to(self.device)

        with torch.no_grad():
            outputs = self.model(batch)
            predictions = torch.norm(outputs, dim=1).cpu().numpy()

        return predictions

    def explain_atom(self, mol: Chem.Mol, atom_idx: int, n_samples: int = 50) -> Dict:
        """Score the features of atom ``atom_idx`` against the model output.

        Returns ``None`` when fewer than two usable molecules are available.
        Otherwise returns a dict with:

        * ``'local_importance'`` - feature position -> correlation between
          that feature (over the neighbourhood) and the prediction; 0.0 for
          constant features or undefined correlations,
        * ``'local_prediction'`` - prediction for the first (original) molecule,
        * ``'feature_names'`` - the names matching the positions,
        * ``'perturbed_predictions'`` - all predictions as a list.

        With only two molecules in the neighbourhood every varying feature
        scores exactly +1, -1 or 0, so such scores carry a sign but no magnitude.
        """
        # Generate neighborhood
        neighborhood = self.generate_neighborhood(mol, atom_idx, n_samples)

        print(f"Generated {len(neighborhood)} molecules")

        # Extract features and predictions
        feature_list = []
        graph_data_list = []

        for idx, m in enumerate(neighborhood):
            try:
                data = self._mol_to_graph_data(m)
                if data is None:
                    print(f"Invalid graph data for molecule {idx}")
                    continue

                features = self.extract_features(m.GetAtomWithIdx(atom_idx))

                feature_list.append(features)
                graph_data_list.append(data)

            except Exception as e:
                print(f"Error processing molecule {idx}: {str(e)}")
                continue

        if len(graph_data_list) <= 1:
            print("Not enough valid molecules generated")
            return None

        # Get model predictions
        predictions = self._batch_predict(graph_data_list)
        if len(predictions) == 0:
            return None

        # Calculate feature importance
        features = np.stack(feature_list)
        importance_scores = {}

        for i in range(len(self.feature_names)):
            column = features[:, i]
            score = 0.0
            if np.std(column) > 0:
                correlation = np.corrcoef(column, predictions)[0, 1]
                if not np.isnan(correlation):
                    # Plain float keeps the result free of numpy scalar types.
                    score = float(correlation)
            importance_scores[i] = score

        return {
            'local_importance': importance_scores,
            'local_prediction': float(predictions[0]),
            'feature_names': self.feature_names,
            'perturbed_predictions': predictions.tolist()
        }

def visualize_lime_results(atom_idx: int, explanation: Dict, save_path: str = None):
    """Draw a horizontal bar chart of an atom's feature importances.

    Bars are ordered by absolute importance; positive scores are green and
    all others red. The figure is written to ``save_path`` when one is given
    and is closed in every case. A ``None`` explanation only prints a notice.
    """
    if explanation is None:
        print(f"No valid explanation for atom {atom_idx}")
        return

    fig, ax = plt.subplots(figsize=(10, 6))

    # Pair every feature name with its score (0.0 when absent), then rank
    # the pairs by magnitude, largest first.
    ranked = sorted(
        (
            (label, explanation['local_importance'].get(position, 0.0))
            for position, label in enumerate(explanation['feature_names'])
        ),
        key=lambda pair: abs(pair[1]),
        reverse=True,
    )
    labels, values = zip(*ranked)

    rows = np.arange(len(labels))
    bar_colors = []
    for value in values:
        bar_colors.append('#2ecc71' if value > 0 else '#e74c3c')

    ax.barh(rows, values, align='center', color=bar_colors, alpha=0.8)
    ax.set_yticks(rows)
    ax.set_yticklabels(labels)

    # Annotate bars whose magnitude is at least 0.01.
    for row, value in enumerate(values):
        if abs(value) >= 0.01:
            anchor = 'left' if value >= 0 else 'right'
            ax.text(value + np.sign(value)*0.01, row, f'{value:.3f}',
                    va='center', ha=anchor)

    ax.set_title(f'LIME Feature Importance for Atom {atom_idx}')
    ax.set_xlabel('Importance Score')
    ax.grid(True, axis='x', linestyle='--', alpha=0.3)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)

def run_lime_analysis(dataset, encoder, device, mol_index: int = 1, atom_indices=None):
    """Run the per-atom LIME analysis on one molecule of ``dataset``.

    Args:
        dataset: Sequence whose items expose a ``smiles`` attribute.
        encoder: Model handed to ``MolecularLIME``.
        device: Device handed to ``MolecularLIME``.
        mol_index: Position of the molecule to analyse (default 1, the
            previously hard-coded example).
        atom_indices: Atom indices to explain; defaults to ``[2, 9, 13, 15]``,
            the previously hard-coded example atoms.

    Returns:
        Dict with ``'molecule_smiles'``, ``'important_atoms'`` (the requested
        indices) and ``'results'`` (atom index -> explanation, only for atoms
        that could be explained).

    Raises:
        ValueError: If the selected SMILES cannot be parsed.
        IndexError: If ``mol_index`` is outside ``dataset``.

    Side effects:
        Creates a timestamped ``lime_explanations_*`` directory and writes one
        PNG per explained atom into it.
    """
    important_atoms = [2, 9, 13, 15] if atom_indices is None else list(atom_indices)

    # Resolve and validate the molecule before touching the file system.
    mol_data = dataset[mol_index]
    smiles = mol_data.smiles
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Could not parse SMILES at dataset index {mol_index}: {smiles}")

    save_dir = f'lime_explanations_{datetime.now().strftime("%Y%m%d_%H%M%S")}'
    os.makedirs(save_dir, exist_ok=True)

    lime_explainer = MolecularLIME(encoder, device)
    num_atoms = mol.GetNumAtoms()
    results = {}

    for atom_idx in important_atoms:
        # Skip indices that do not exist in this molecule instead of aborting
        # the whole analysis.
        if not 0 <= atom_idx < num_atoms:
            print(f"\nSkipping atom index {atom_idx}: molecule has {num_atoms} atoms")
            continue

        atom = mol.GetAtomWithIdx(atom_idx)
        print(f"\nAnalyzing atom {atom.GetSymbol()}{atom_idx}...")

        explanation = lime_explainer.explain_atom(mol, atom_idx, n_samples=50)

        if explanation is None:
            print(f"Could not generate explanation for atom {atom_idx}")
            continue

        results[atom_idx] = explanation

        save_path = os.path.join(save_dir, f'lime_importance_atom_{atom_idx}.png')
        visualize_lime_results(atom_idx, explanation, save_path)

        # Print top 5 important features
        importance_pairs = [(name, explanation['local_importance'].get(i, 0.0))
                            for i, name in enumerate(explanation['feature_names'])]
        sorted_features = sorted(importance_pairs, key=lambda x: abs(x[1]), reverse=True)[:5]

        print("\nTop 5 most important features:")
        for feat_name, importance in sorted_features:
            print(f"  {feat_name}: {importance:.4f}")

    return {
        'molecule_smiles': smiles,
        'important_atoms': important_atoms,
        'results': results
    }

In [4]:
    if __name__ == "__main__":

        print("Starting data loading...")
        extractor = MolecularFeatureExtractor()
        # NOTE: absolute, machine-specific path; adjust it before running elsewhere.
        smiles_file = "D:\\PhD\\Chapter3\\Unsupervised_GAN_Code\\pubchem-41-clean.txt"

        dataset = []
        failed_smiles = []

        # Build one graph per SMILES line and keep the SMILES on the graph so
        # the analysis can re-parse the molecule later.
        with open(smiles_file, 'r') as f:
            for line in f:
                smiles = line.strip()
                if not smiles:
                    # Blank lines are not molecules; do not count them as failures.
                    continue
                data = extractor.process_molecule(smiles)
                if data is not None:
                    data.smiles = smiles
                    dataset.append(data)
                else:
                    failed_smiles.append(smiles)

        print(f"1. Loaded dataset with {len(dataset)} graphs.")
        print(f"2. Failed SMILES count: {len(failed_smiles)}")

        os.makedirs('molecule_explanation', exist_ok=True)
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

        # Stays None when the analysis is skipped or fails.
        results = None

        if not dataset:
            # Nothing to analyse; stop here rather than failing inside the analysis.
            print("No valid graphs generated.")
        else:
            try:
                print("\nStarting molecule analysis...")

                results = run_lime_analysis(dataset, encoder, device)

                if results is not None:
                    print("\nAnalysis Summary:")
                    print("Analyzed results:", results)
                else:
                    print("\nAnalysis failed!")

            except Exception as e:
                print(f"\nError in main execution: {str(e)}")
                traceback.print_exc()

Starting data loading...
1. Loaded dataset with 41 graphs.
2. Failed SMILES count: 0

Starting molecule analysis...

Analyzing atom N2...

Generating neighborhood for N2
Current state: charge=1, H=1
Generated 2 valid states for N2: [{'charge': 1, 'h': 1}, {'charge': 1, 'h': 0}]
Found 2 valid states: [{'charge': 1, 'h': 1}, {'charge': 1, 'h': 0}]
Added valid state: charge=1, H=0
Generated 2 molecules


  x_phys = torch.tensor(x_phys, dtype=torch.float)



Top 5 most important features:
  NumHs: -1.0000
  Valence: -1.0000
  AtomicNum: 0.0000
  FormalCharge: 0.0000
  Hybridization: 0.0000

Analyzing atom N9...

Generating neighborhood for N9
Current state: charge=1, H=2
Generated 2 valid states for N9: [{'charge': 1, 'h': 2}, {'charge': 1, 'h': 1}]
Found 2 valid states: [{'charge': 1, 'h': 2}, {'charge': 1, 'h': 1}]
Added valid state: charge=1, H=1
Generated 2 molecules

Top 5 most important features:
  NumHs: -1.0000
  Valence: -1.0000
  AtomicNum: 0.0000
  FormalCharge: 0.0000
  Hybridization: 0.0000

Analyzing atom C13...

Generating neighborhood for C13
Current state: charge=0, H=0
Generated 1 valid states for C13: [{'charge': 0, 'h': 0}]
Found 1 valid states: [{'charge': 0, 'h': 0}]
Generated 1 molecules
Not enough valid molecules generated
Could not generate explanation for atom 13

Analyzing atom O15...

Generating neighborhood for O15
Current state: charge=-1, H=0
Generated 2 valid states for O15: [{'charge': -1, 'h': 0}, {'charg