In [1]:
# pip install tensorflow[and-cuda]
# !pip install numpy==1.24.3 --force-reinstall

In [2]:
import numpy as np
print(np.__version__)

2.0.2


In [3]:
import torch
import pickle
from rdkit import Chem
from rdkit.Chem import Draw
import shap
import numpy as np
import torch.nn.functional as F
from torch_geometric.data import DataLoader
import torch.nn as nn
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing
from sklearn.metrics.pairwise import cosine_similarity
from typing import Dict, List, Tuple, Optional
from torch_geometric.data import Data
from rdkit.Chem import RemoveHs
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from rdkit import RDLogger
# Suppress RDKit warnings
RDLogger.DisableLog('rdApp.warning')
import tensorflow as tensorflow
import traceback
import os
from datetime import datetime
from rdkit.Chem import rdDepictor
import matplotlib.pyplot as plt
from rdkit.Chem import AllChem, Draw, rdDepictor
from matplotlib.colors import LinearSegmentedColormap

class GraphDiscriminator(nn.Module):
    """Reimplementation of original discriminator architecture"""
    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128, output_dim: int = 128):
        super().__init__()
        
        # Feature encoding
        self.node_encoder = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Graph convolution layers
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, output_dim)
        
        # Projection head
        self.projection = nn.Sequential(
            nn.Linear(output_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
        
    def forward(self, data):
        x = torch.cat([data.x_cat.float(), data.x_phys], dim=-1)
        edge_index = data.edge_index
        edge_attr = data.edge_attr.float()
        batch = data.batch
        
        # Initial feature encoding
        x = self.node_encoder(x)
        edge_attr = self.edge_encoder(edge_attr)
        
        # Graph convolutions
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        
        # Global pooling
        x = global_mean_pool(x, batch)
        
        # Projection
        x = self.projection(x)
        
        return x

# Load Encoder
def load_encoder(model_path, device='cpu'):
    """Load trained encoder"""
    checkpoint = torch.load(model_path, map_location=device)
    encoder = GraphDiscriminator(
        node_dim=checkpoint['model_info'].get('node_dim'),
        edge_dim=checkpoint['model_info'].get('edge_dim'),
        hidden_dim=checkpoint['model_info'].get('hidden_dim', 128),
        output_dim=checkpoint['model_info'].get('output_dim', 128)
    )
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    encoder.eval()
    return encoder.to(device)

# Load Embeddings
def load_embeddings(filepath):
    """Load embeddings and labels"""
    with open(filepath, 'rb') as f:
        data = pickle.load(f)
    return data['embeddings'], data['labels']

# Paths from your saved model
encoder_path = './checkpoints/encoders/final_encoder_20250216_111050.pt'
embedding_path = './embeddings/final_embeddings_20250216_111005.pkl'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load encoder and embeddings
encoder = load_encoder(encoder_path, device)
embeddings, graph_data = load_embeddings(embedding_path)


  checkpoint = torch.load(model_path, map_location=device)


In [4]:
class MolecularFeatureExtractor:
    def __init__(self):
        self.atom_list = list(range(1, 119))
        self.chirality_list = [
            Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
            Chem.rdchem.ChiralType.CHI_OTHER
        ]
        self.bond_list = [
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE, 
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC
        ]
        self.bonddir_list = [
            Chem.rdchem.BondDir.NONE,
            Chem.rdchem.BondDir.ENDUPRIGHT,
            Chem.rdchem.BondDir.ENDDOWNRIGHT
        ]

    def calc_atom_features(self, atom: Chem.Atom) -> Tuple[list, list]:
        """Calculate atom features with better error handling"""
        try:
            # Basic features
            atom_feat = [
                self.atom_list.index(atom.GetAtomicNum()),
                self.chirality_list.index(atom.GetChiralTag())
            ]

            # Physical features with error handling
            phys_feat = []
            
            # Molecular weight contribution
            try:
                contrib_mw = Descriptors.ExactMolWt(Chem.MolFromSmiles(f'[{atom.GetSymbol()}]'))
                phys_feat.append(contrib_mw)
            except:
                phys_feat.append(0.0)
                
            # LogP contribution    
            try:
                contrib_logp = Descriptors.MolLogP(Chem.MolFromSmiles(f'[{atom.GetSymbol()}]'))
                phys_feat.append(contrib_logp)
            except:
                phys_feat.append(0.0)
                
            # Add other physical properties
            phys_feat.extend([
                atom.GetFormalCharge(),
                int(atom.GetHybridization()),
                int(atom.GetIsAromatic()),
                atom.GetTotalNumHs(),
                atom.GetTotalValence(),
                atom.GetDegree()
            ])
            
            return atom_feat, phys_feat
            
        except Exception as e:
            print(f"Error calculating atom features: {e}")
            return [0, 0], [0.0] * 9

    def get_atom_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract atom features for the whole molecule"""
        atom_feats = []
        phys_feats = []
        
        if mol is None:
            return torch.tensor([[0, 0]], dtype=torch.long), torch.tensor([[0.0] * 9], dtype=torch.float)
            
        for atom in mol.GetAtoms():
            atom_feat, phys_feat = self.calc_atom_features(atom)
            atom_feats.append(atom_feat)
            phys_feats.append(phys_feat)

        x = torch.tensor(atom_feats, dtype=torch.long)
        phys = torch.tensor(phys_feats, dtype=torch.float)
        
        return x, phys
    
    def remove_unbonded_hydrogens(mol):
        params = Chem.RemoveHsParameters()
        params.removeDegreeZero = True
        mol = Chem.RemoveHs(mol, params)
        return mol


    def get_bond_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract bond features with better error handling"""
        if mol is None:
            return torch.tensor([[0], [0]], dtype=torch.long), torch.tensor([[0.0] * 5], dtype=torch.float)
            
        row, col, edge_feat = [], [], []
        
        for bond in mol.GetBonds():
            try:
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                
                # Add edges in both directions
                row += [start, end]
                col += [end, start]
                
                # Bond features
                bond_type = self.bond_list.index(bond.GetBondType())
                bond_dir = self.bonddir_list.index(bond.GetBondDir())
                
                # Calculate additional properties
                feat = [
                    bond_type,
                    bond_dir,
                    int(bond.GetIsConjugated()),
                    int(self._is_rotatable(bond)),
                    self._get_bond_length(mol, start, end)
                ]
                
                edge_feat.extend([feat, feat])
                
            except Exception as e:
                print(f"Error processing bond: {e}")
                continue

        if not row:  # If no valid bonds were processed
            return torch.tensor([[0], [0]], dtype=torch.long), torch.tensor([[0.0] * 5], dtype=torch.float)

        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(edge_feat, dtype=torch.float)
        
        return edge_index, edge_attr

    def _is_rotatable(self, bond: Chem.Bond) -> bool:
        """Check if bond is rotatable"""
        return (bond.GetBondType() == Chem.rdchem.BondType.SINGLE and 
                not bond.IsInRing() and
                len(bond.GetBeginAtom().GetNeighbors()) > 1 and
                len(bond.GetEndAtom().GetNeighbors()) > 1)

    def _get_bond_length(self, mol: Chem.Mol, start: int, end: int) -> float:
        """Get bond length with error handling"""
        try:
            conf = mol.GetConformer()
            if conf.Is3D():
                return Chem.rdMolTransforms.GetBondLength(conf, start, end)
        except:
            pass
        return 0.0

    def process_molecule(self, smiles: str) -> Data:
        """Process SMILES string to graph data"""
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                print(f"Invalid SMILES: {smiles}")
                return None  # Skip invalid molecules
            mol = RemoveHs(mol)

            # Add explicit hydrogens
            mol = Chem.AddHs(mol, addCoords=True)

            # Sanitize molecule
            Chem.SanitizeMol(mol)

            # Check if the molecule has atoms
            if mol.GetNumAtoms() == 0:
                print("Molecule has no atoms, skipping.")
                return None

            # Generate 3D coordinates
            if not mol.GetNumConformers():
                status = AllChem.EmbedMolecule(mol, AllChem.ETKDG())
                if status != 0:
                    print("Failed to generate 3D conformer")
                    return None  # Skip failed molecules

                # Try MMFF or UFF optimization
                try:
                    AllChem.MMFFOptimizeMolecule(mol)
                except:
                    AllChem.UFFOptimizeMolecule(mol)

            # Extract features
            x_cat, x_phys = self.get_atom_features(mol)
            edge_index, edge_attr = self.get_bond_features(mol)

            return Data(
                x_cat=x_cat, 
                x_phys=x_phys,
                edge_index=edge_index, 
                edge_attr=edge_attr,
                num_nodes=x_cat.size(0)
            )

        except Exception as e:
            print(f"Error processing molecule {smiles}: {e}")
            return None


In [5]:
import shap
import torch
import numpy as np
from torch_geometric.data import Batch, Data
from typing import List, Tuple
from matplotlib.colors import LinearSegmentedColormap

class GraphModelWrapper:
    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.model.eval()
        self.original_x_cat = None
        self.original_x_phys = None
        self.batch = None
        
    def __call__(self, features):
        """
        Custom call method to handle graph data
        features: feature matrix (numpy array)
        """
        with torch.no_grad():
            if isinstance(features, np.ndarray):
                features = torch.tensor(features, dtype=torch.float).to(self.device)
            
            # Get original shapes
            num_nodes = self.original_x_cat.size(0)
            num_cat_features = self.original_x_cat.size(1)
            
            # Reshape features to match original dimensions
            x_cat = features.reshape(num_nodes, num_cat_features).to(torch.long)
            
            # Create a new batch with modified features
            new_data = Data(
                x_cat=x_cat,
                x_phys=self.original_x_phys,
                edge_index=self.batch.edge_index,
                edge_attr=self.batch.edge_attr,
                batch=self.batch.batch if hasattr(self.batch, 'batch') else None
            )
            
            # Get model output
            outputs = self.model(new_data)
            return outputs.cpu().numpy()

class ModifiedGraphWrapper:
    def __init__(self, model, device, batch):
        self.model = model
        self.device = device
        self.original_batch = batch
        self.model.eval()
        self.num_nodes = batch.x_cat.shape[0]
        self.num_features = batch.x_cat.shape[1]
        self.num_phys_features = batch.x_phys.shape[1]
        
    def __call__(self, x):
        with torch.no_grad():
            try:
                # Convert input to tensor
                x = torch.tensor(x, dtype=torch.float).to(self.device)
                
                # Reshape x to match the expected input shape
                if len(x.shape) == 1:
                    x = x.reshape(1, self.num_nodes, self.num_features)
                else:
                    x = x.reshape(-1, self.num_nodes, self.num_features)
                
                print(f"Processing batch of size {x.shape[0]}")
                
                all_results = []
                for idx in range(x.shape[0]):
                    # Extract categorical features
                    x_cat = x[idx].to(torch.long)
                    
                    # Ensure proper dimensions for x_cat and x_phys
                    if len(x_cat.shape) == 2:
                        x_cat = x_cat.unsqueeze(0)
                    x_phys = self.original_batch.x_phys.unsqueeze(0)
                    
                    # Create consistent batch dimension
                    batch_idx = torch.zeros(self.num_nodes, dtype=torch.long, device=self.device)
                    
                    # Create data object for this sample
                    new_data = Data(
                        x_cat=x_cat.squeeze(0),
                        x_phys=x_phys.squeeze(0),
                        edge_index=self.original_batch.edge_index,
                        edge_attr=self.original_batch.edge_attr,
                        batch=batch_idx,
                        num_nodes=self.num_nodes
                    ).to(self.device)
                    
                    # Get node features
                    node_features = torch.cat([new_data.x_cat.float(), new_data.x_phys], dim=-1)
                    x_encoded = self.model.node_encoder(node_features)
                    
                    # Get intermediate representations
                    x1 = F.relu(self.model.conv1(x_encoded, new_data.edge_index))
                    x2 = F.relu(self.model.conv2(x1, new_data.edge_index))
                    x3 = self.model.conv3(x2, new_data.edge_index)
                    
                    # Combine representations from different layers
                    combined_features = torch.stack([x1, x2, x3], dim=0)
                    node_embeddings = torch.mean(combined_features, dim=0)
                    
                    # Compute node importance
                    node_importance = torch.norm(node_embeddings, dim=1).cpu().numpy()
                    all_results.append(node_importance)
                
                result = np.array(all_results)
                print(f"Result shape: {result.shape}")
                return result
                
            except Exception as e:
                print(f"Error in model wrapper: {e}")
                print(f"Debug info:")
                print(f"x shape: {x.shape}")
                if 'node_features' in locals():
                    print(f"node_features shape: {node_features.shape}")
                if 'node_embeddings' in locals():
                    print(f"node_embeddings shape: {node_embeddings.shape}")
                raise

def explain_graph_model(model, graph_data, device, num_samples=100):
    """Generate SHAP explanations for graph neural network"""
    batch = Batch.from_data_list([graph_data]).to(device)
    
    # Initialize wrapper
    model_wrapper = ModifiedGraphWrapper(model, device, batch)
    
    # Create background data
    background = batch.x_cat.cpu().numpy().astype(float)
    background_flat = background.reshape(-1)
    
    # Generate background samples
    n_background = 50
    background_samples = []
    for _ in range(n_background):
        perturbed = background_flat.copy()
        noise = np.random.normal(0, 0.3, perturbed.shape)
        perturbed = np.clip(perturbed + noise, 0, None)
        background_samples.append(perturbed)
    
    background_matrix = np.stack(background_samples)
    print(f"Background matrix shape: {background_matrix.shape}")
    
    try:
        # Test run
        test_output = model_wrapper(background_matrix[0:1])
        print(f"Test output shape: {test_output.shape}")
        
        # Initialize explainer
        explainer = shap.KernelExplainer(
            model_wrapper,
            background_matrix,
            link="identity",
            feature_perturbation="interventional"
        )
        
        # Calculate SHAP values
        shap_values = explainer.shap_values(
            background_flat,
            nsamples=200,
            l1_reg="num_features(10)",
            silent=True
        )
        
        if isinstance(shap_values, list):
            shap_values = np.array(shap_values)
        
        print(f"Raw SHAP values shape: {shap_values.shape}")
        
        # Process SHAP values to get node importance
        # Sum absolute SHAP values across all features for each node
        num_nodes = batch.x_cat.shape[0]
        node_importance = np.zeros(num_nodes)
        
        # Aggregate SHAP values per node
        features_per_node = batch.x_cat.shape[1]  # number of features per node
        for i in range(num_nodes):
            start_idx = i * features_per_node
            end_idx = (i + 1) * features_per_node
            node_importance[i] = np.abs(shap_values[start_idx:end_idx]).sum()
        
        # Normalize importance scores
        node_importance = (node_importance - node_importance.min()) / \
                         (node_importance.max() - node_importance.min() + 1e-10)
        
        print("\nNode importance statistics:")
        print(f"Shape: {node_importance.shape}")
        print(f"Range: {node_importance.min():.6f} to {node_importance.max():.6f}")
        print(f"Mean: {node_importance.mean():.6f}")
        print(f"Std: {node_importance.std():.6f}")
        
        # Print top 5 most important nodes
        sorted_indices = np.argsort(node_importance)[::-1]
        print("\nTop 5 most important nodes:")
        for i in range(5):
            idx = sorted_indices[i]
            print(f"Node {idx}: {node_importance[idx]:.6f}")
        
        return node_importance, shap_values
        
    except Exception as e:
        print(f"Error in SHAP calculation: {e}")
        print(f"Debug info:")
        print(f"- Background matrix shape: {background_matrix.shape}")
        print(f"- Number of nodes: {batch.x_cat.shape[0]}")
        print(f"- Feature dimension: {batch.x_cat.shape[1]}")
        if 'shap_values' in locals():
            print(f"- SHAP values shape: {shap_values.shape}")
        raise

        
def visualize_molecule_with_importance(smiles, node_importance, save_path=None):
    """
    Create a figure with both molecular visualization and bar plot of importance scores
    """
    import matplotlib.pyplot as plt
    from rdkit import Chem
    from rdkit.Chem import AllChem, Draw
    from matplotlib.colors import LinearSegmentedColormap
    import numpy as np
    
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        print(f"Could not parse SMILES: {smiles}")
        return None
        
    num_atoms = mol.GetNumAtoms()
    
    # Create figure with two subplots
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 14), 
                                  gridspec_kw={'height_ratios': [1.5, 1]})
    
    # 1. Molecular Visualization
    atom_colors = {}
    atom_info = []
    
    # Ensure importance scores are normalized to [0,1]
    normalized_scores = (node_importance - node_importance.min()) / \
                       (node_importance.max() - node_importance.min())
    
    # Custom color map for better visibility
    def get_color(score):
        """Generate color based on importance score with more distinct gradients"""
        if score > 0.8:
            return (0.8, 0.0, 0.0)  # Deep red
        elif score > 0.6:
            return (1.0, 0.2, 0.2)  # Medium-deep red
        elif score > 0.4:
            return (1.0, 0.5, 0.5)  # Medium red
        elif score > 0.3:
            return (1.0, 0.7, 0.7)  # Light-medium red
        elif score > 0.2:
            return (1.0, 0.85, 0.85)  # Light red
        else:
            return (1.0, 0.95, 0.95)  # Very light red/almost white
    
    # Generate 2D coordinates if they don't exist
    if mol.GetNumConformers() == 0:
        AllChem.Compute2DCoords(mol)
    
    # Create highlighting details
    highlight_atoms = []
    highlight_colors = {}
    atom_labels = {}
    
    for i in range(num_atoms):
        atom = mol.GetAtomWithIdx(i)
        importance = normalized_scores[i]
        atom_info.append((i, atom.GetSymbol(), importance))
        
        # Add to highlighting
        highlight_atoms.append(i)
        highlight_colors[i] = get_color(importance)
        
        # Create detailed atom label with symbol and index
        atom_labels[i] = f"{atom.GetSymbol()}{i}"
    
    # Draw the molecule with enhanced settings
    d2d = Draw.MolDraw2DCairo(800, 800)
    d2d.drawOptions().addAtomIndices = True
    d2d.drawOptions().additionalAtomLabelPadding = 0.25
    d2d.drawOptions().bondLineWidth = 2.0
    d2d.drawOptions().padding = 0.2
    
    # Convert highlight colors to the format RDKit expects
    highlight_colors_map = {}
    for idx, color in highlight_colors.items():
        highlight_colors_map[idx] = color
    
    # Draw molecule
    d2d.DrawMolecule(
        mol,
        highlightAtoms=highlight_atoms,
        highlightAtomColors=highlight_colors_map
    )
    d2d.FinishDrawing()
    
    # Convert to PIL Image
    import io
    from PIL import Image
    png = d2d.GetDrawingText()
    img = Image.open(io.BytesIO(png))
    
    # Show molecule in first subplot
    ax1.imshow(img)
    ax1.axis('off')
    ax1.set_title('Molecular Structure with SHAP Importance\nRed intensity indicates importance', 
                  pad=20, fontsize=14)
    
    # 2. Bar Plot
    # Sort atoms by importance
    atom_info.sort(key=lambda x: x[2], reverse=True)
    
    # Create bar plot
    indices, symbols, scores = zip(*atom_info)
    atom_labels = [f"{symbols[i]}{indices[i]}" for i in range(len(indices))]
    
    bars = ax2.bar(range(len(scores)), scores)
    
    # Color bars using same gradient
    for idx, bar in enumerate(bars):
        bar.set_facecolor(get_color(scores[idx]))
    
    # Customize bar plot
    ax2.set_xticks(range(len(scores)))
    ax2.set_xticklabels(atom_labels, rotation=45, ha='right', fontsize=10)
    ax2.set_ylabel('SHAP Importance Score', fontsize=12)
    ax2.set_title('Atom-wise SHAP Importance Scores', pad=20, fontsize=14)
    ax2.set_ylim(0, 1.05)
    
    # Add grid for easier reading
    ax2.grid(True, axis='y', linestyle='--', alpha=0.7)
    
    # Add legend for color interpretation
    legend_text = (
        'Color Legend:\n'
        '  Deep Red: Very high importance (>0.8)\n'
        '  Medium-deep Red: High importance (0.6-0.8)\n'
        '  Medium Red: Medium importance (0.4-0.6)\n'
        '  Light-medium Red: Low-medium importance (0.3-0.4)\n'
        '  Light Red: Low importance (0.2-0.3)\n'
        '  Very Light Red: Very low importance (<0.2)'
    )
    ax2.text(0.02, 0.98, legend_text, 
             transform=ax2.transAxes, 
             fontsize=8, 
             verticalalignment='top',
             bbox=dict(facecolor='white', alpha=0.8, edgecolor='none'))
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
    
    return fig

def print_atom_mapping(smiles, node_importance):
    """Print mapping between atom indices and types with their importance scores"""
    from rdkit import Chem
    from rdkit.Chem import AllChem
    import numpy as np
    
    if node_importance is None:
        print("Error: No importance scores available")
        return
        
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            print(f"Error: Could not parse SMILES: {smiles}")
            return
            
        print("\nAtom Mapping:")
        print("Index | Atom | SHAP Score | Normalized Score")
        print("-" * 50)
        
        # Convert to numpy array if needed
        if not isinstance(node_importance, np.ndarray):
            node_importance = np.array(node_importance)
        
        # Normalize importance scores
        normalized_scores = (node_importance - node_importance.min()) / \
                          (node_importance.max() - node_importance.min())
        
        # Create mapping with both raw and normalized scores
        atom_info = []
        for i in range(mol.GetNumAtoms()):
            atom = mol.GetAtomWithIdx(i)
            raw_score = node_importance[i]
            norm_score = normalized_scores[i]
            atom_info.append((i, atom.GetSymbol(), raw_score, norm_score))
        
        # Sort by normalized importance
        atom_info.sort(key=lambda x: x[3], reverse=True)
        
        # Print mapping
        for idx, symbol, raw_score, norm_score in atom_info:
            print(f"{idx:5d} | {symbol:4s} | {raw_score:10.4f} | {norm_score:.4f}")
            
        # Print additional information
        print("\nImportance Score Statistics:")
        print(f"Minimum: {node_importance.min():.4f}")
        print(f"Maximum: {node_importance.max():.4f}")
        print(f"Mean: {node_importance.mean():.4f}")
        print(f"Std Dev: {node_importance.std():.4f}")
        
        # Print top N most important atoms
        N = 5
        print(f"\nTop {N} Most Important Atoms:")
        for idx, symbol, raw_score, norm_score in atom_info[:N]:
            print(f"Atom {symbol}{idx}: {raw_score:.4f} (normalized: {norm_score:.4f})")
            
    except Exception as e:
        print(f"Error in atom mapping: {str(e)}")
        import traceback
        traceback.print_exc()
        

def visualize_molecule_importance(smiles, node_importance, save_path=None):
    mol = Chem.MolFromSmiles(smiles)
    num_atoms = mol.GetNumAtoms()
    
    # Ensure we have the right number of importance scores
    if len(node_importance) >= num_atoms:
        # Get importance scores for atoms
        imp_scores = node_importance[:num_atoms]
        
        # Use percentile-based normalization
        min_val = np.percentile(imp_scores, 10)
        max_val = np.percentile(imp_scores, 90)
        
        if max_val > min_val:
            normalized_scores = (imp_scores - min_val) / (max_val - min_val)
            normalized_scores = np.clip(normalized_scores, 0, 1)
        else:
            normalized_scores = np.zeros_like(imp_scores)
        
        # Apply non-linear scaling to enhance differences
#         normalized_scores = np.power(normalized_scores, 0.5)
        normalized_scores = np.exp(4 * normalized_scores) / np.exp(4)
        
        # Create atom colors
        atom_colors = {}
        importance_info = []
        
        for i in range(num_atoms):
            atom = mol.GetAtomWithIdx(i)
            score = normalized_scores[i]
            raw_score = node_importance[i]
            
            if score < 0.3:
                color = (1.0, 1.0, 1.0)  # White for low importance
            elif score < 0.6:
                color = (1.0, 0.7, 0.7)  # Light red for medium importance
            else:
                color = (1.0, 0.0, 0.0)  # Deep red for high importance
            
            atom_colors[i] = color
            
#             # Enhanced color gradient
#             red = min(1.0, score * 1.2)
#             white = max(0.0, 1.0 - score * 1.5)
#             atom_colors[i] = (1.0, white, white)          
            
            importance_info.append((i, atom.GetSymbol(), raw_score, score))
        
        # Sort and print atom importance
        importance_info.sort(key=lambda x: x[2], reverse=True)
        print("\nAtom Importance Ranking (Top 5):")
        for idx, symbol, raw_score, norm_score in importance_info[:5]:
            print(f"Atom {idx} ({symbol}): raw importance = {raw_score:.6f}, "
                  f"normalized = {norm_score:.4f}")
        
        # Generate 2D coordinates
        AllChem.Compute2DCoords(mol)
        
        # Create list of bonds to highlight
        bonds = []
        for bond in mol.GetBonds():
            begin_idx = bond.GetBeginAtomIdx()
            end_idx = bond.GetEndAtomIdx()
            # Use average of connected atoms' importance
            bond_importance = (normalized_scores[begin_idx] + normalized_scores[end_idx]) / 2
            if bond_importance > 0.5:  # Only highlight significant bonds
                bonds.append(bond.GetIdx())
        
        # Draw molecule with enhanced visualization
        img = Draw.MolToImage(
            mol,
            highlightAtoms=list(range(num_atoms)),
            highlightColor=None,
            highlightAtomColors=atom_colors,
            highlightBonds=bonds,  # List of bond indices
            size=(800, 800),
            highlightRadius=0.5
        )
        
        if save_path:
            img.save(save_path)
            
            # Create legend with enhanced gradient
            fig, ax = plt.subplots(figsize=(8, 1))
            gradient = np.linspace(0, 1, 256).reshape(1, -1)
            
            # Enhanced colormap
            colors = [(1,1,1), (1,0.5,0.5), (1,0,0)]
            cmap = LinearSegmentedColormap.from_list("custom_enhanced_red", colors, N=256)
            
            ax.imshow(gradient, aspect='auto', cmap=cmap)
            ax.set_xticks([0, 128, 255])
            ax.set_xticklabels(['Low', 'Medium', 'High'])
            ax.set_yticks([])
            plt.title('Atom Importance Scale')
            plt.savefig(save_path.replace('.png', '_legend.png'), 
                       bbox_inches='tight', dpi=300)
            plt.close()
        
        return img

    
# Also modify the generate_explanations function to print more information:
def generate_explanations(model, dataset, device, idx=1):
    """
    Generate and visualize explanations for a specific molecule
    """
    # Get specific molecule
    graph_data = dataset[idx]
    
    print(f"Processing molecule {idx}")
    print(f"Graph data shape - x_cat: {graph_data.x_cat.shape}, x_phys: {graph_data.x_phys.shape}")
    
    try:
        # Generate SHAP explanations
        node_importance, shap_values = explain_graph_model(
            model,
            graph_data,
            device,
            num_samples=100
        )
        
        print(f"Generated node importance shape: {node_importance.shape}")

        # Get SMILES string
        if hasattr(graph_data, 'smiles'):
            smiles = graph_data.smiles
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            save_path = f"molecule_explanation/molecule_explanation_{timestamp}.png"
            
            # Create visualization
            fig = visualize_molecule_with_importance(
                smiles,
                node_importance,
                save_path=save_path
            )
            
            if fig is None:
                print(f"Warning: Could not generate visualization for SMILES: {smiles}")
                return node_importance, shap_values, None
            
            # Save the figure directly here
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            plt.close(fig)  # Close the figure to free memory
                
            return node_importance, shap_values, save_path
        else:
            print("Warning: No SMILES data found for visualization")
            return node_importance, shap_values, None
            
    except Exception as e:
        print(f"Error in generate_explanations: {str(e)}")
        traceback.print_exc()
        return None, None, None
    

def create_shap_style_plots(atom_info, node_importance, save_path=None):
    """
    Create SHAP-style visualizations for molecular importance scores
    
    Parameters:
    atom_info: List of tuples (idx, symbol, raw_score, norm_score)
    node_importance: Array of importance scores
    save_path: Path to save the visualization
    """
    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib.colors import LinearSegmentedColormap
    
    try:
        # Create a figure with multiple subplots
        fig = plt.figure(figsize=(15, 10))
        
        # Only use the first n values from node_importance where n is the number of atoms
        n_atoms = len(atom_info)
        node_importance_truncated = node_importance[:n_atoms]
        
        print(f"Using first {n_atoms} values from node_importance")
        
        # 1. Summary Plot (Top)
        ax1 = plt.subplot2grid((2, 1), (0, 0), rowspan=1)
        
        # Create data mapping
        data = list(zip(range(n_atoms), node_importance_truncated))
        sorted_data = sorted(data, key=lambda x: abs(x[1]), reverse=True)
        sorted_indices, sorted_values = zip(*sorted_data)
        
        # Create custom colormap (red for positive values)
        colors = [(1.0, 0.9, 0.9), (0.8, 0, 0)]
        custom_cmap = LinearSegmentedColormap.from_list("custom_red", colors)
        
        # Plot SHAP values
        y_pos = np.arange(len(sorted_values))
        max_abs_val = max(abs(min(sorted_values)), abs(max(sorted_values)))
        colors = [custom_cmap(abs(val)/max_abs_val) for val in sorted_values]
        
        bars = ax1.barh(y_pos, sorted_values, color=colors)
        
        # Add atom labels
        atom_labels = []
        for idx in sorted_indices:
            atom = atom_info[idx]  # Get atom info for this index
            atom_labels.append(f"{atom[1]}{atom[0]}")  # symbol + index
            
        ax1.set_yticks(y_pos)
        ax1.set_yticklabels(atom_labels)
        
        ax1.set_xlabel('SHAP value (impact on model output)')
        ax1.set_title('Impact of Each Atom on Model Prediction')
        
        # Add grid for readability
        ax1.grid(True, axis='x', linestyle='--', alpha=0.3)
        
        # 2. Force Plot (Bottom)
        ax2 = plt.subplot2grid((2, 1), (1, 0), rowspan=1)
        
        # Calculate base value (mean prediction)
        base_value = np.mean(node_importance_truncated)
        
        # Sort by absolute contribution
        contributions = node_importance_truncated - base_value
        sorted_by_contrib = sorted(zip(atom_info, contributions), 
                                 key=lambda x: abs(x[1]), reverse=True)
        
        # Create force plot
        cumsum = np.cumsum([0] + [x[1] for x in sorted_by_contrib])
        
        # Plot connecting lines
        for i in range(len(cumsum)-1):
            color = 'red' if sorted_by_contrib[i][1] >= 0 else 'blue'
            ax2.plot([cumsum[i], cumsum[i+1]], [1, 1], color=color, linewidth=2)
            
            # Add atom labels for significant contributions
            if abs(sorted_by_contrib[i][1]) > 0.1:  # Threshold for showing labels
                info, contrib = sorted_by_contrib[i]
                label = f"{info[1]}{info[0]}"  # symbol + index
                y_offset = 1.1 if i % 2 == 0 else 0.9
                ax2.annotate(f"{label}\n{contrib:.3f}", 
                            xy=(cumsum[i], 1),
                            xytext=(cumsum[i], y_offset),
                            ha='center', va='center',
                            arrowprops=dict(arrowstyle='->', color='gray'))
        
        # Add base value and final prediction
        ax2.text(base_value, 0.7, f'Base value\n{base_value:.3f}', 
                 ha='center', va='center')
        ax2.text(cumsum[-1], 0.7, f'Final prediction\n{cumsum[-1]:.3f}', 
                 ha='center', va='center')
        
        # Customize force plot
        ax2.set_ylim(0.5, 1.5)
        ax2.set_xlim(min(cumsum)-0.5, max(cumsum)+0.5)
        ax2.set_title('SHAP Force Plot: How Each Atom Contributes to the Prediction')
        ax2.set_xlabel('Model Prediction')
        ax2.axes.get_yaxis().set_visible(False)
        
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path.replace('.png', '_shap.png'), 
                        dpi=300, bbox_inches='tight')
            plt.close()
        
        return fig
        
    except Exception as e:
        print(f"Error in create_shap_style_plots: {str(e)}")
        print("Debug info:")
        print(f"Length of atom_info: {len(atom_info)}")
        print(f"Length of node_importance: {len(node_importance)}")
        if len(atom_info) > 0:
            print(f"Sample atom_info entry: {atom_info[0]}")
        raise
        
def create_shap_waterfall(atom_info, node_importance, save_path=None):
    """
    Create a SHAP waterfall plot showing cumulative impact of atoms
    
    Parameters:
    atom_info: List of tuples (idx, symbol, raw_score, norm_score)
    node_importance: Array of importance scores
    save_path: Path to save the visualization
    """
    import matplotlib.pyplot as plt
    import numpy as np
    
    try:
        # Only use the first n values from node_importance where n is the number of atoms
        n_atoms = len(atom_info)
        node_importance_truncated = node_importance[:n_atoms]
        
        # Sort by absolute importance
        sorted_atoms = sorted(zip(atom_info, node_importance_truncated), 
                            key=lambda x: abs(x[1]), reverse=True)
        
        # Get top 10 contributors for clearer visualization
        top_n = min(10, len(sorted_atoms))
        sorted_atoms = sorted_atoms[:top_n]
        
        # Calculate cumulative sums
        values = [x[1] for x in sorted_atoms]
        cumsum = np.cumsum([0] + values)
        
        # Create figure
        fig, ax = plt.subplots(figsize=(12, 8))
        
        # Plot waterfall
        for i in range(len(values)):
            # Plot vertical line
            ax.plot([i, i], [cumsum[i], cumsum[i+1]], 
                    color='red' if values[i] >= 0 else 'blue',
                    linewidth=2)
            
            # Plot horizontal line
            if i < len(values)-1:
                ax.plot([i, i+1], [cumsum[i+1], cumsum[i+1]], 
                       color='gray', linestyle='--', alpha=0.5)
            
            # Add atom label
            info = sorted_atoms[i][0]
            label = f"{info[1]}{info[0]}"
            ax.annotate(f"{label}\n{values[i]:.3f}",
                       xy=(i, cumsum[i+1]),
                       xytext=(i, cumsum[i+1] + (0.05 if i % 2 == 0 else -0.05)),
                       ha='center', va='center')
        
        ax.set_title('SHAP Waterfall Plot: Cumulative Impact of Top Atoms')
        ax.set_xlabel('Atoms (ordered by importance)')
        ax.set_ylabel('Cumulative SHAP Value')
        
        # Add grid
        ax.grid(True, axis='y', linestyle='--', alpha=0.3)
        
        # Add explanation
        plt.figtext(0.02, 0.02,
                    'Red lines show positive contributions\n'
                    'Blue lines show negative contributions',
                    fontsize=8)
        
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path.replace('.png', '_waterfall.png'),
                        dpi=300, bbox_inches='tight')
            plt.close()
        
        return fig
        
    except Exception as e:
        print(f"Error in create_shap_waterfall: {str(e)}")
        print("Debug info:")
        print(f"Length of atom_info: {len(atom_info)}")
        print(f"Length of node_importance: {len(node_importance)}")
        if len(atom_info) > 0:
            print(f"Sample atom_info entry: {atom_info[0]}")
        raise        

def create_shap_plots(atom_info, node_importance, save_path=None):
    """
    Create SHAP-style visualizations with beeswarm and force plots
    """
    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib.colors import LinearSegmentedColormap
    
    try:
        # Create figure with subplots
        fig = plt.figure(figsize=(12, 8))
        
        # Only use the first n values from node_importance where n is the number of atoms
        n_atoms = len(atom_info)
        node_importance_truncated = node_importance[:n_atoms]
        
        # 1. Beeswarm-style Plot (Top)
        ax1 = plt.subplot2grid((2, 1), (0, 0), rowspan=1)
        
        # Sort features by absolute mean SHAP value
        feature_order = np.argsort(np.abs(node_importance_truncated))
        
        # Create color map similar to SHAP
        colors = ['#ff0051', '#fa5e4f', '#f57c47', '#f1933f', '#ec9637', '#e7992f',
                 '#e29b27', '#dcac20', '#d7be18', '#d2d011', '#c7d11b', '#bcd225',
                 '#b2d42f', '#a7d539', '#9dd644', '#92d74e', '#87d858', '#7dd963',
                 '#72da6d', '#67db77', '#5ddc82', '#52dd8c', '#47de96', '#3ddfb0',
                 '#32e0ba', '#27e1c5', '#1de2cf', '#12e3d9', '#07e4e4']
        cmap = LinearSegmentedColormap.from_list("shap", colors)
        
        # Normalize values for coloring
        max_val = np.max(np.abs(node_importance_truncated))
        norm_values = node_importance_truncated / max_val
        
        # Plot dots
        for i, idx in enumerate(feature_order):
            value = node_importance_truncated[idx]
            color = cmap(abs(value) / max_val)
            # Add slight random jitter for beeswarm effect
            jitter = np.random.normal(0, 0.02, 1)[0]
            ax1.scatter(value, i + jitter, c=[color], alpha=0.7)
            
        # Add feature labels
        feature_labels = [f"{atom_info[i][1]}{atom_info[i][0]}" for i in feature_order]
        ax1.set_yticks(range(len(feature_labels)))
        ax1.set_yticklabels(feature_labels)
        
        # Customize plot
        ax1.axvline(x=0, color='gray', linestyle='-', alpha=0.3)
        ax1.set_xlabel('SHAP value (impact on model output)')
        
        # Add grid
        ax1.grid(True, axis='x', linestyle='--', alpha=0.2)
        
        # 2. Force Plot (Bottom)
        ax2 = plt.subplot2grid((2, 1), (1, 0), rowspan=1)
        
        # Calculate base value and contributions
        base_value = np.mean(node_importance_truncated)
        contributions = node_importance_truncated - base_value
        sorted_contributions = sorted(zip(atom_info, contributions), 
                                   key=lambda x: abs(x[1]), reverse=True)
        
        # Create force plot
        cumsum = np.cumsum([0] + [x[1] for x in sorted_contributions])
        
        # Plot segments
        for i in range(len(cumsum)-1):
            contrib = sorted_contributions[i][1]
            color = '#ff0051' if contrib >= 0 else '#008bfb'  # SHAP's red/blue colors
            
            # Draw contribution segment
            ax2.plot([cumsum[i], cumsum[i+1]], [1, 1], 
                    color=color, linewidth=3, solid_capstyle='butt')
            
            # Add feature labels for significant contributions
            if abs(contrib) > 0.1:
                info = sorted_contributions[i][0]
                label = f"{info[1]}{info[0]}"
                y_offset = 1.1 if i % 2 == 0 else 0.9
                ax2.annotate(f"{label}\n{contrib:.3f}", 
                            xy=(cumsum[i], 1),
                            xytext=(cumsum[i], y_offset),
                            ha='center', va='center',
                            arrowprops=dict(arrowstyle='->', color='gray', alpha=0.5))
        
        # Add base and final values
        ax2.text(base_value, 0.7, f'E[f(x)]\n{base_value:.3f}', 
                ha='center', va='center', fontsize=10,
                bbox=dict(facecolor='white', edgecolor='gray', alpha=0.8))
        ax2.text(cumsum[-1], 0.7, f'f(x)\n{cumsum[-1]:.3f}', 
                ha='center', va='center', fontsize=10,
                bbox=dict(facecolor='white', edgecolor='gray', alpha=0.8))
        
        # Customize force plot
        ax2.set_ylim(0.5, 1.5)
        ax2.set_xlim(min(cumsum)-0.5, max(cumsum)+0.5)
        ax2.set_xlabel('Model output')
        ax2.axes.get_yaxis().set_visible(False)
        
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path.replace('.png', '_shap.png'), 
                       dpi=300, bbox_inches='tight')
            plt.close()
        
        return fig
        
    except Exception as e:
        print(f"Error in create_shap_style_plots: {str(e)}")
        print("Debug info:")
        print(f"Length of atom_info: {len(atom_info)}")
        print(f"Length of node_importance: {len(node_importance)}")
        if len(atom_info) > 0:
            print(f"Sample atom_info entry: {atom_info[0]}")
        raise        
        
        
def create_shap_visualizations(atom_info, node_importance, save_path=None):
    """
    Create visualizations using SHAP's native plotting functions
    """
    import shap
    import numpy as np
    import pandas as pd
    
    try:
        # Get number of atoms
        n_atoms = len(atom_info)
        node_importance_truncated = node_importance[:n_atoms]
        
        # Create feature matrix X
        # Convert atom info to a pandas DataFrame
        feature_names = [f"{info[1]}{info[0]}" for info in atom_info]  # Atom symbol + index
        X = pd.DataFrame([info[2] for info in atom_info], columns=['value'])
        X.index = feature_names
        
        # Create SHAP values array
        shap_values = node_importance_truncated
        
        # Create visualizations
        
        # 1. Summary Plot
        plt.figure(figsize=(10, 6))
        shap.summary_plot(
            shap_values,
            X,
            plot_type="bar",
            show=False
        )
        if save_path:
            plt.savefig(save_path.replace('.png', '_summary.png'), 
                       dpi=300, bbox_inches='tight')
            plt.close()
        
        # 2. Waterfall Plot (for the most important atom)
        plt.figure(figsize=(10, 6))
        shap.plots.waterfall(
            shap.Explanation(
                values=shap_values,
                base_values=np.mean(shap_values),
                data=X['value'].values,
                feature_names=feature_names
            ),
            show=False
        )
        if save_path:
            plt.savefig(save_path.replace('.png', '_waterfall.png'),
                       dpi=300, bbox_inches='tight')
            plt.close()
        
        # 3. Force Plot
        force_plot = shap.force_plot(
            base_value=np.mean(shap_values),
            shap_values=shap_values,
            features=X['value'].values,
            feature_names=feature_names,
            matplotlib=True,
            show=False
        )
        if save_path:
            plt.savefig(save_path.replace('.png', '_force.png'),
                       dpi=300, bbox_inches='tight')
            plt.close()
        
        # 4. Beeswarm Plot
        plt.figure(figsize=(10, 6))
        shap.summary_plot(
            shap_values,
            X,
            plot_type="dot",
            show=False
        )
        if save_path:
            plt.savefig(save_path.replace('.png', '_beeswarm.png'),
                       dpi=300, bbox_inches='tight')
            plt.close()
        
        print(f"SHAP visualizations saved with prefix: {save_path}")
        
    except Exception as e:
        print(f"Error in create_shap_visualizations: {str(e)}")
        print("Debug info:")
        print(f"Length of atom_info: {len(atom_info)}")
        print(f"Length of node_importance: {len(node_importance)}")
        print(f"Feature names: {feature_names}")
        raise

def explain_molecule_with_shap(smiles, node_importance, shap_values, save_dir='shap_explanations'):
    """
    Create comprehensive SHAP visualizations for molecular analysis
    """
    import os
    import shap
    import numpy as np
    import pandas as pd
    from rdkit import Chem
    import matplotlib.pyplot as plt
    
    try:
        # Create save directory if it doesn't exist
        os.makedirs(save_dir, exist_ok=True)
        
        # Create molecule object
        mol = Chem.MolFromSmiles(smiles)
        n_atoms = mol.GetNumAtoms()
        
        # Truncate node_importance to match number of atoms
        node_importance = node_importance[:n_atoms]
        
        # Create feature matrix with multiple atomic properties
        features = []
        feature_names = []
        for i in range(n_atoms):
            atom = mol.GetAtomWithIdx(i)
            atom_features = [
                atom.GetAtomicNum(),        # Atomic number
                atom.GetTotalValence(),     # Valence
                atom.GetDegree(),           # Degree
                int(atom.GetIsAromatic()),  # Aromaticity
                atom.GetFormalCharge()      # Formal charge
            ]
            features.append(atom_features)
            symbol = atom.GetSymbol()
            feature_names.extend([
                f"{symbol}{i}_AtomicNum",
                f"{symbol}{i}_Valence",
                f"{symbol}{i}_Degree",
                f"{symbol}{i}_Aromatic",
                f"{symbol}{i}_Charge"
            ])
            
        # Convert to numpy arrays
        X = np.array(features)
        
        # Reshape SHAP values to match feature dimensions
        # Repeat the importance scores for each feature
        shap_values_matrix = np.tile(node_importance[:, np.newaxis], (1, 5)).reshape(-1)
        
        # Create a matrix of SHAP values
        shap_matrix = np.zeros((1, len(shap_values_matrix)))
        shap_matrix[0, :] = shap_values_matrix
        
        print(f"\nDebug - Shapes before visualization:")
        print(f"X shape: {X.shape}")
        print(f"Features per atom: {X.shape[1]}")
        print(f"SHAP matrix shape: {shap_matrix.shape}")
        print(f"Feature names length: {len(feature_names)}")
        
        # Create SHAP explanation object
        X_reshaped = X.reshape(1, -1)  # Reshape to 2D
        explanation = shap.Explanation(
            values=shap_matrix,
            base_values=np.array([np.mean(node_importance)]),
            data=X_reshaped,
            feature_names=feature_names
        )
        
        # Create visualizations
        
        # 1. Summary plot
        plt.figure(figsize=(12, 8))
        shap.summary_plot(
            shap_matrix,
            X_reshaped,
            feature_names=feature_names,
            show=False,
            plot_type="bar"
        )
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, 'summary_plot.png'), 
                    dpi=300, bbox_inches='tight')
        plt.close()
        
        # 2. Bar plot
        plt.figure(figsize=(12, 8))
        shap.plots.bar(explanation, show=False)
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, 'bar_plot.png'), 
                    dpi=300, bbox_inches='tight')
        plt.close()
        
        # 3. Waterfall plot
        plt.figure(figsize=(12, 8))
        shap.plots.waterfall(explanation[0], show=False)
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, 'waterfall_plot.png'), 
                    dpi=300, bbox_inches='tight')
        plt.close()
        
        # 4. Force plot
        force_plot = shap.force_plot(
            explanation.base_values[0],
            explanation.values[0],
            X_reshaped[0],
            feature_names=feature_names,
            matplotlib=True,
            show=False
        )
        plt.savefig(os.path.join(save_dir, 'force_plot.png'), 
                    dpi=300, bbox_inches='tight')
        plt.close()
        
        print(f"\nSHAP visualizations saved in directory: {save_dir}")
        
        return explanation
        
    except Exception as e:
        print(f"Error in explain_molecule_with_shap: {str(e)}")
        print("\nDebug information:")
        if 'X' in locals():
            print(f"X shape: {X.shape}")
        if 'shap_matrix' in locals():
            print(f"SHAP matrix shape: {shap_matrix.shape}")
        if 'feature_names' in locals():
            print(f"Number of features: {len(feature_names)}")
        raise

In [6]:
# First, let's modify your data loading code to store SMILES with each graph
print("Starting data loading...")
extractor = MolecularFeatureExtractor()
smiles_file = "D:\\PhD\\Chapter3\\Unsupervised_GAN_Code\\pubchem-41-clean.txt"

dataset = []
failed_smiles = []

# Modified data loading to store SMILES
with open(smiles_file, 'r') as f:
    for line in f:
        smiles = line.strip()
        data = extractor.process_molecule(smiles)
        if data is not None:
            # Add SMILES as an attribute to the Data object
            data.smiles = smiles  # Add this line
            dataset.append(data)
        else:
            failed_smiles.append(smiles)

print(f"1. Loaded dataset with {len(dataset)} graphs.")
print(f"2. Failed SMILES count: {len(failed_smiles)}")

if not dataset:
    print("No valid graphs generated.")
    
# Make sure to import needed libraries
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
import traceback

os.makedirs('molecule_explanation', exist_ok=True)
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

try:
    print("\nSelected molecule index: 1")
    node_importance, shap_values, save_path = generate_explanations(
        encoder,
        dataset,
        device,
        idx=1
    )
    
    if node_importance is not None and save_path is not None:
        # Get the SMILES for the selected molecule
        smiles = dataset[1].smiles
        
        # Print atom mapping
        print_atom_mapping(smiles, node_importance)
        
        print(f"Visualization saved as '{save_path}'")
        
        # Create molecule object from SMILES
        mol = Chem.MolFromSmiles(smiles)
        if mol is not None:
            # Prepare normalized scores first
            normalized_scores = ((node_importance - node_importance.min()) / 
                              (node_importance.max() - node_importance.min()))
            
            # Create atom_info with correct format
            atom_info = []
            for i in range(mol.GetNumAtoms()):
                atom = mol.GetAtomWithIdx(i)
                atom_info.append({
                    'index': i,
                    'symbol': atom.GetSymbol(),
                    'importance': node_importance[i],
                    'normalized': normalized_scores[i]
                })
            
            # Convert atom_info to the format expected by visualization functions
            viz_atom_info = [(info['index'], info['symbol'], info['importance'], info['normalized']) 
                            for info in atom_info]
            
            print("\nPreparing SHAP visualizations...")
            # Create SHAP-style visualizations
            shap_style_fig = create_shap_style_plots(viz_atom_info, node_importance, save_path)
            waterfall_fig = create_shap_waterfall(viz_atom_info, node_importance, save_path)
            shap_fig = create_shap_plots(viz_atom_info, node_importance, save_path)
            
            print(f"SHAP visualization saved as '{save_path.replace('.png', '_shap_style.png')}'")
            print(f"Waterfall plot saved as '{save_path.replace('.png', '_waterfall.png')}'")
            print(f"SHAP Style visualization saved as '{save_path.replace('.png', '_shap.png')}'")
            
            # Print some debug information
            print("\nVisualization Debug Info:")
            print(f"Number of atoms in molecule: {mol.GetNumAtoms()}")
            print(f"Length of node_importance: {len(node_importance)}")
            print(f"Length of atom_info: {len(atom_info)}")
            
        else:
            print(f"Warning: Could not parse SMILES: {smiles}")
    else:
        print("Warning: No importance scores or visualization path generated")
    
    
        
    if node_importance is not None:
        # Get the SMILES for the selected molecule
        smiles = dataset[1].smiles
        
        # Create save directory
        save_dir = 'molecule_explanation'
        os.makedirs(save_dir, exist_ok=True)
        
        # Generate SHAP visualizations
        X, shapley_values = explain_molecule_with_shap(
            smiles,
            node_importance,
            shap_values,
            save_dir=save_dir
        )
        
        print("\nVisualization Results:")
        print(f"- Number of atoms in molecule: {len(X)}")
        print(f"- Number of features per atom: {X.shape[1]}")
        print(f"- Shape of SHAP values: {shapley_values.shape}")
        print(f"- Visualizations saved in: {save_dir}")
        
    else:
        print("Warning: No importance scores generated")
    
except Exception as e:
    print(f"\nError generating explanations: {e}")
    print("\nModel architecture:")
    print(encoder)
    traceback.print_exc()

Starting data loading...
1. Loaded dataset with 41 graphs.
2. Failed SMILES count: 0

Selected molecule index: 1
Processing molecule 1
Graph data shape - x_cat: torch.Size([42, 2]), x_phys: torch.Size([42, 8])
Background matrix shape: (50, 84)
Processing batch of size 1
Result shape: (1, 42)
Test output shape: (1, 42)
Processing batch of size 50
Result shape: (50, 42)
Processing batch of size 1
Result shape: (1, 42)
Processing batch of size 10000
Result shape: (10000, 42)
Raw SHAP values shape: (84, 42)

Node importance statistics:
Shape: (42,)
Range: 0.000000 to 1.000000
Mean: 0.204713
Std: 0.303392

Top 5 most important nodes:
Node 13: 1.000000
Node 9: 0.973454
Node 2: 0.939662
Node 14: 0.934257
Node 15: 0.694645
Generated node importance shape: (42,)

Atom Mapping:
Index | Atom | SHAP Score | Normalized Score
--------------------------------------------------
   13 | C    |     1.0000 | 1.0000
    9 | N    |     0.9735 | 0.9735
    2 | N    |     0.9397 | 0.9397
   14 | O    |     0

Traceback (most recent call last):
  File "C:\Users\Malli\AppData\Local\Temp\ipykernel_58780\2853609985.py", line 107, in <module>
    X, shapley_values = explain_molecule_with_shap(
ValueError: not enough values to unpack (expected 2, got 1)
