In [1]:
# pip install tensorflow[and-cuda]
# !pip install numpy==1.24.3 --force-reinstall

In [2]:
# Sanity check: show which NumPy version this kernel actually imports
# (the pinned reinstall in the previous cell is commented out).
import numpy as np
print(np.__version__)

2.0.2


In [3]:
import torch
import pickle
from rdkit import Chem
from rdkit.Chem import Draw
import shap
import numpy as np
import torch.nn.functional as F
from torch_geometric.data import DataLoader
import torch.nn as nn
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing
from sklearn.metrics.pairwise import cosine_similarity
from typing import Dict, List, Tuple, Optional
from torch_geometric.data import Data
from rdkit.Chem import RemoveHs
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from rdkit import RDLogger
# Suppress RDKit warnings
RDLogger.DisableLog('rdApp.warning')
import tensorflow as tensorflow
import traceback
import os
from datetime import datetime
from rdkit.Chem import rdDepictor
import matplotlib.pyplot as plt
from rdkit.Chem import AllChem, Draw, rdDepictor
from matplotlib.colors import LinearSegmentedColormap

class GraphDiscriminator(nn.Module):
    """Graph encoder: node MLP -> three GCN layers -> mean pooling -> projection head.

    The sub-module attribute names (``node_encoder``, ``edge_encoder``,
    ``conv1``..``conv3``, ``projection``) determine the state-dict keys, so
    they are the contract with any saved checkpoint.

    Args:
        node_dim: width of the concatenated ``x_cat`` + ``x_phys`` node features.
        edge_dim: width of ``edge_attr``.
        hidden_dim: width of the encoders and of the first two GCN layers.
        output_dim: width of the last GCN layer and of the returned embedding.
    """

    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128, output_dim: int = 128):
        super().__init__()

        # Feature encoders (Linear -> ReLU -> Linear)
        self.node_encoder = self._two_layer_mlp(node_dim, hidden_dim, hidden_dim)
        self.edge_encoder = self._two_layer_mlp(edge_dim, hidden_dim, hidden_dim)

        # Message passing stack
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, output_dim)

        # Projection head applied to the pooled graph vector
        self.projection = self._two_layer_mlp(output_dim, hidden_dim, output_dim)

    @staticmethod
    def _two_layer_mlp(in_dim, mid_dim, out_dim):
        """Build the Linear/ReLU/Linear stack shared by both encoders and the head."""
        return nn.Sequential(
            nn.Linear(in_dim, mid_dim),
            nn.ReLU(),
            nn.Linear(mid_dim, out_dim),
        )

    def forward(self, data):
        """Embed a (batched) graph.

        Reads ``x_cat``, ``x_phys``, ``edge_index``, ``edge_attr`` and ``batch``
        from ``data`` and returns one projected vector per graph.
        """
        node_inputs = torch.cat([data.x_cat.float(), data.x_phys], dim=-1)
        edge_index = data.edge_index
        edge_inputs = data.edge_attr.float()
        batch = data.batch

        hidden = self.node_encoder(node_inputs)
        # The encoded edge features are computed but never handed to the GCN
        # layers below; the call is kept so the forward pass stays unchanged.
        self.edge_encoder(edge_inputs)

        # Two ReLU-activated GCN layers followed by a linear one
        for conv in (self.conv1, self.conv2):
            hidden = F.relu(conv(hidden, edge_index))
        hidden = self.conv3(hidden, edge_index)

        # Mean over the nodes of each graph, then the projection head
        pooled = global_mean_pool(hidden, batch)
        return self.projection(pooled)

# Load Encoder
def load_encoder(model_path, device='cpu'):
    """Load a trained encoder from a checkpoint file and put it in eval mode.

    The checkpoint must be a mapping with a ``'model_info'`` dict (providing
    ``node_dim`` and ``edge_dim``; ``hidden_dim`` / ``output_dim`` default to
    128) and an ``'encoder_state_dict'`` entry.

    Args:
        model_path: path to the checkpoint written by ``torch.save``.
        device: device the weights are mapped to and the model is moved to.

    Returns:
        The ``GraphDiscriminator`` with loaded weights, in eval mode, on ``device``.

    Raises:
        ValueError: if ``node_dim`` or ``edge_dim`` is missing from ``model_info``.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)
    model_info = checkpoint['model_info']

    # Without this check a missing dimension reaches nn.Linear as None and
    # fails with an error that does not mention the checkpoint.
    missing = [key for key in ('node_dim', 'edge_dim') if model_info.get(key) is None]
    if missing:
        raise ValueError(
            f"Checkpoint {model_path!r} is missing required model_info entries: {missing}"
        )

    encoder = GraphDiscriminator(
        node_dim=model_info['node_dim'],
        edge_dim=model_info['edge_dim'],
        hidden_dim=model_info.get('hidden_dim', 128),
        output_dim=model_info.get('output_dim', 128)
    )
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    encoder.eval()
    return encoder.to(device)

# Load Embeddings
def load_embeddings(filepath):
    """Read a pickled dict and return its ``'embeddings'`` and ``'labels'`` entries.

    NOTE(security): the file is unpickled, so it must come from a trusted source.
    """
    with open(filepath, 'rb') as handle:
        payload = pickle.load(handle)
    stored_embeddings = payload['embeddings']
    stored_labels = payload['labels']
    return stored_embeddings, stored_labels

# Locations of the saved artifacts (relative to the notebook's working
# directory) and the device everything is loaded onto.
encoder_path = './checkpoints/encoders/final_encoder_20250216_111050.pt'
embedding_path = './embeddings/final_embeddings_20250216_111005.pkl'
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the encoder (eval mode, on `device`) and the stored embeddings.
# load_embeddings returns the pickle's 'embeddings' and 'labels' entries; the
# second one is bound to the name `graph_data` here.
encoder = load_encoder(encoder_path, device)
embeddings, graph_data = load_embeddings(embedding_path)


  checkpoint = torch.load(model_path, map_location=device)


In [4]:
class MolecularFeatureExtractor:
    """Turn SMILES strings / RDKit molecules into graph tensors.

    Per atom it produces two categorical indices (element, chirality) and
    eight physical features; per bond it produces five features, emitted once
    for each direction of the bond.
    """

    # Widths of the rows built below. The fallback rows must use the same
    # widths as the regular ones, otherwise torch.tensor() is handed ragged
    # lists and the tensors no longer match the encoder's input size.
    NUM_CAT_FEATURES = 2
    NUM_PHYS_FEATURES = 8
    NUM_BOND_FEATURES = 5

    def __init__(self):
        # Lookup tables: a feature value is the position of the RDKit value
        # in the corresponding list.
        self.atom_list = list(range(1, 119))
        self.chirality_list = [
            Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
            Chem.rdchem.ChiralType.CHI_OTHER
        ]
        self.bond_list = [
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE,
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC
        ]
        self.bonddir_list = [
            Chem.rdchem.BondDir.NONE,
            Chem.rdchem.BondDir.ENDUPRIGHT,
            Chem.rdchem.BondDir.ENDDOWNRIGHT
        ]

    def calc_atom_features(self, atom: Chem.Atom) -> Tuple[list, list]:
        """Compute the categorical and physical feature rows for one atom.

        Returns:
            ``(atom_feat, phys_feat)``: two element/chirality indices and
            eight physical values. If anything fails (for example an atomic
            number or chiral tag that is not in the lookup tables) an
            all-zero pair of the same widths is returned and the error is
            printed.
        """
        try:
            # Basic features
            atom_feat = [
                self.atom_list.index(atom.GetAtomicNum()),
                self.chirality_list.index(atom.GetChiralTag())
            ]

            # Single-atom molecule of this element, parsed once and shared by
            # both descriptor calls below.
            try:
                element_mol = Chem.MolFromSmiles(f'[{atom.GetSymbol()}]')
            except Exception:
                element_mol = None

            phys_feat = []

            # Molecular weight contribution (0.0 when it cannot be computed)
            try:
                phys_feat.append(Descriptors.ExactMolWt(element_mol))
            except Exception:
                phys_feat.append(0.0)

            # LogP contribution (0.0 when it cannot be computed)
            try:
                phys_feat.append(Descriptors.MolLogP(element_mol))
            except Exception:
                phys_feat.append(0.0)

            # Remaining physical properties read directly from the atom
            phys_feat.extend([
                atom.GetFormalCharge(),
                int(atom.GetHybridization()),
                int(atom.GetIsAromatic()),
                atom.GetTotalNumHs(),
                atom.GetTotalValence(),
                atom.GetDegree()
            ])

            return atom_feat, phys_feat

        except Exception as e:
            print(f"Error calculating atom features: {e}")
            # Same widths as the regular rows (the previous 9-wide fallback
            # did not match the 8 values built above).
            return [0] * self.NUM_CAT_FEATURES, [0.0] * self.NUM_PHYS_FEATURES

    def get_atom_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Stack the per-atom features of ``mol`` into two tensors.

        Returns:
            ``(x_cat, x_phys)``: a long tensor of categorical indices and a
            float tensor of physical features, one row per atom. For
            ``mol is None`` a single all-zero row is returned for each.
        """
        if mol is None:
            return (
                torch.tensor([[0] * self.NUM_CAT_FEATURES], dtype=torch.long),
                torch.tensor([[0.0] * self.NUM_PHYS_FEATURES], dtype=torch.float),
            )

        atom_feats = []
        phys_feats = []
        for atom in mol.GetAtoms():
            atom_feat, phys_feat = self.calc_atom_features(atom)
            atom_feats.append(atom_feat)
            phys_feats.append(phys_feat)

        x = torch.tensor(atom_feats, dtype=torch.long)
        phys = torch.tensor(phys_feats, dtype=torch.float)

        return x, phys

    @staticmethod
    def remove_unbonded_hydrogens(mol):
        """Return ``mol`` after ``Chem.RemoveHs`` with ``removeDegreeZero`` enabled.

        Declared static because it takes no ``self``; without the decorator
        it could not be called on an instance.
        """
        params = Chem.RemoveHsParameters()
        params.removeDegreeZero = True
        mol = Chem.RemoveHs(mol, params)
        return mol

    def get_bond_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build the edge index and edge attributes of ``mol``.

        Every bond contributes two directed edges carrying the same five
        features (type index, direction index, conjugation flag, rotatable
        flag, bond length). A bond whose type or direction is not in the
        lookup tables is reported and skipped.

        Returns:
            ``(edge_index, edge_attr)``. When ``mol`` is ``None`` or no bond
            could be processed, a single ``0 -> 0`` edge with all-zero
            attributes is returned.
        """
        if mol is None:
            return (
                torch.tensor([[0], [0]], dtype=torch.long),
                torch.tensor([[0.0] * self.NUM_BOND_FEATURES], dtype=torch.float),
            )

        row, col, edge_feat = [], [], []

        for bond in mol.GetBonds():
            try:
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()

                # Look the indices up first so that a bond that cannot be
                # encoded is skipped before any edge is recorded for it.
                bond_type = self.bond_list.index(bond.GetBondType())
                bond_dir = self.bonddir_list.index(bond.GetBondDir())

                feat = [
                    bond_type,
                    bond_dir,
                    int(bond.GetIsConjugated()),
                    int(self._is_rotatable(bond)),
                    self._get_bond_length(mol, start, end)
                ]

                # Add edges in both directions, each with the same features
                row += [start, end]
                col += [end, start]
                edge_feat.extend([feat, feat])

            except Exception as e:
                print(f"Error processing bond: {e}")
                continue

        if not row:  # If no valid bonds were processed
            return (
                torch.tensor([[0], [0]], dtype=torch.long),
                torch.tensor([[0.0] * self.NUM_BOND_FEATURES], dtype=torch.float),
            )

        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(edge_feat, dtype=torch.float)

        return edge_index, edge_attr

    def _is_rotatable(self, bond: Chem.Bond) -> bool:
        """Return True for a non-ring single bond whose two atoms each have more than one neighbour."""
        return (bond.GetBondType() == Chem.rdchem.BondType.SINGLE and
                not bond.IsInRing() and
                len(bond.GetBeginAtom().GetNeighbors()) > 1 and
                len(bond.GetEndAtom().GetNeighbors()) > 1)

    def _get_bond_length(self, mol: Chem.Mol, start: int, end: int) -> float:
        """Return the bond length from the molecule's 3D conformer, or 0.0 if unavailable."""
        try:
            conf = mol.GetConformer()
            if conf.Is3D():
                return Chem.rdMolTransforms.GetBondLength(conf, start, end)
        except Exception:
            # Best effort: any failure here (e.g. no conformer) means
            # "length unknown", encoded as 0.0.
            pass
        return 0.0

    def process_molecule(self, smiles: str) -> Optional[Data]:
        """Convert a SMILES string into a graph ``Data`` object.

        The molecule is parsed, hydrogens are made explicit, a 3D conformer
        is embedded and optimised, and atom/bond features are extracted.

        Returns:
            A ``Data`` object with ``x_cat``, ``x_phys``, ``edge_index``,
            ``edge_attr`` and ``num_nodes``, or ``None`` when the SMILES is
            invalid, has no atoms, cannot be embedded, or processing fails
            (the reason is printed).
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                print(f"Invalid SMILES: {smiles}")
                return None  # Skip invalid molecules
            mol = RemoveHs(mol)

            # Add explicit hydrogens
            mol = Chem.AddHs(mol, addCoords=True)

            # Sanitize molecule
            Chem.SanitizeMol(mol)

            # Check if the molecule has atoms
            if mol.GetNumAtoms() == 0:
                print("Molecule has no atoms, skipping.")
                return None

            # Generate 3D coordinates
            if not mol.GetNumConformers():
                status = AllChem.EmbedMolecule(mol, AllChem.ETKDG())
                if status != 0:
                    print("Failed to generate 3D conformer")
                    return None  # Skip failed molecules

                # Try MMFF first; if it raises, fall back to UFF
                try:
                    AllChem.MMFFOptimizeMolecule(mol)
                except Exception:
                    AllChem.UFFOptimizeMolecule(mol)

            # Extract features
            x_cat, x_phys = self.get_atom_features(mol)
            edge_index, edge_attr = self.get_bond_features(mol)

            return Data(
                x_cat=x_cat,
                x_phys=x_phys,
                edge_index=edge_index,
                edge_attr=edge_attr,
                num_nodes=x_cat.size(0)
            )

        except Exception as e:
            print(f"Error processing molecule {smiles}: {e}")
            return None


In [5]:
import shap
import torch
import numpy as np
from torch_geometric.data import Batch, Data
from typing import List, Tuple
from matplotlib.colors import LinearSegmentedColormap

class GraphModelWrapper:
    """Callable adapter that feeds a flat categorical feature array to a graph model.

    ``original_x_cat``, ``original_x_phys`` and ``batch`` start as ``None``
    and must be assigned by the caller before the wrapper is invoked.
    """

    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.model.eval()
        # Reference graph; filled in by the caller after construction.
        self.original_x_cat = None
        self.original_x_phys = None
        self.batch = None

    def __call__(self, features):
        """Run the model on the stored graph with ``features`` as its ``x_cat``.

        Args:
            features: categorical node features (numpy arrays are converted
                to a float tensor on ``self.device``); reshaped to the shape
                of ``original_x_cat`` and cast to long.

        Returns:
            The model output as a numpy array.
        """
        with torch.no_grad():
            if isinstance(features, np.ndarray):
                features = torch.tensor(features, dtype=torch.float).to(self.device)

            # Target shape comes from the reference categorical features
            n_nodes = self.original_x_cat.size(0)
            n_cat = self.original_x_cat.size(1)
            replaced_x_cat = features.reshape(n_nodes, n_cat).to(torch.long)

            # Same graph as the reference batch, only x_cat is swapped out
            graph = Data(
                x_cat=replaced_x_cat,
                x_phys=self.original_x_phys,
                edge_index=self.batch.edge_index,
                edge_attr=self.batch.edge_attr,
                batch=getattr(self.batch, 'batch', None)
            )

            return self.model(graph).cpu().numpy()

class ModifiedGraphWrapper:
    """Callable that maps flattened ``x_cat`` samples to per-node scores.

    Built around one reference batch: only the categorical node features are
    replaced per sample, while ``x_phys``, ``edge_index`` and ``edge_attr``
    always come from that batch. The score of a node is the norm of the mean
    of its three GCN-layer representations.
    """

    def __init__(self, model, device, batch):
        self.model = model
        self.device = device
        self.original_batch = batch
        self.model.eval()
        # Shapes used to unflatten the incoming sample vectors
        self.num_nodes = batch.x_cat.shape[0]
        self.num_features = batch.x_cat.shape[1]
        self.num_phys_features = batch.x_phys.shape[1]

    def __call__(self, x):
        """Score every sample in ``x``.

        Args:
            x: a single flattened sample or a stack of them; reshaped to
                ``(samples, num_nodes, num_features)``.

        Returns:
            A numpy array with one row of node scores per sample.
        """
        # Explicit sentinels so the error handler can report whatever had
        # been computed when the failure happened.
        node_features = None
        node_embeddings = None

        with torch.no_grad():
            try:
                # Convert input to tensor
                x = torch.tensor(x, dtype=torch.float).to(self.device)

                # A lone 1-D sample becomes a batch of one
                leading = 1 if len(x.shape) == 1 else -1
                x = x.reshape(leading, self.num_nodes, self.num_features)

                print(f"Processing batch of size {x.shape[0]}")

                all_results = []
                for sample in range(x.shape[0]):
                    # Categorical features of this sample, as integer indices
                    sample_x_cat = x[sample].to(torch.long)
                    reference_x_phys = self.original_batch.x_phys

                    # All nodes belong to graph 0
                    batch_idx = torch.zeros(self.num_nodes, dtype=torch.long, device=self.device)

                    new_data = Data(
                        x_cat=sample_x_cat,
                        x_phys=reference_x_phys,
                        edge_index=self.original_batch.edge_index,
                        edge_attr=self.original_batch.edge_attr,
                        batch=batch_idx,
                        num_nodes=self.num_nodes
                    ).to(self.device)

                    # Encode the node features
                    node_features = torch.cat([new_data.x_cat.float(), new_data.x_phys], dim=-1)
                    encoded = self.model.node_encoder(node_features)

                    # Representation after each of the three GCN layers
                    layer1 = F.relu(self.model.conv1(encoded, new_data.edge_index))
                    layer2 = F.relu(self.model.conv2(layer1, new_data.edge_index))
                    layer3 = self.model.conv3(layer2, new_data.edge_index)

                    # Average the layers, then take the per-node norm
                    node_embeddings = torch.mean(torch.stack([layer1, layer2, layer3], dim=0), dim=0)
                    all_results.append(torch.norm(node_embeddings, dim=1).cpu().numpy())

                result = np.array(all_results)
                print(f"Result shape: {result.shape}")
                return result

            except Exception as e:
                print(f"Error in model wrapper: {e}")
                print(f"Debug info:")
                print(f"x shape: {x.shape}")
                if node_features is not None:
                    print(f"node_features shape: {node_features.shape}")
                if node_embeddings is not None:
                    print(f"node_embeddings shape: {node_embeddings.shape}")
                raise

def _aggregate_node_importance(shap_values, num_nodes, features_per_node):
    """Sum absolute SHAP values per node.

    ``shap_values`` must have the flattened node features on its first axis
    (node-major, ``features_per_node`` entries per node); any further axes
    (e.g. one column per model output) are summed over as well.
    """
    magnitudes = np.abs(np.asarray(shap_values, dtype=float))
    expected = num_nodes * features_per_node
    if magnitudes.ndim == 0 or magnitudes.shape[0] != expected:
        raise ValueError(
            f"Unexpected SHAP values shape {magnitudes.shape}: "
            f"first axis should have {expected} entries"
        )
    return magnitudes.reshape(num_nodes, -1).sum(axis=1)


def explain_graph_model(model, graph_data, device, num_samples=100):
    """Generate SHAP explanations for graph neural network.

    A ``ModifiedGraphWrapper`` around ``model`` is explained with
    ``shap.KernelExplainer``. The background set is 50 noisy copies of the
    graph's flattened ``x_cat``; the explained instance is the unperturbed
    ``x_cat``.

    Args:
        model: encoder exposing ``node_encoder`` and ``conv1``..``conv3``.
        graph_data: a single graph with ``x_cat``, ``x_phys``, ``edge_index``
            and ``edge_attr``.
        device: device the batch is moved to.
        num_samples: currently not used; the explainer runs with a fixed
            budget of 200 samples. Kept so existing callers keep working.

    Returns:
        ``(node_importance, shap_values)``: min-max normalised per-node
        scores, and the raw SHAP values with the flattened features on the
        first axis.

    Raises:
        Any exception from the wrapper or the explainer is re-raised after
        debug information has been printed.
    """
    batch = Batch.from_data_list([graph_data]).to(device)

    # Initialize wrapper
    model_wrapper = ModifiedGraphWrapper(model, device, batch)

    num_nodes = batch.x_cat.shape[0]
    features_per_node = batch.x_cat.shape[1]

    # Create background data
    background = batch.x_cat.cpu().numpy().astype(float)
    background_flat = background.reshape(-1)

    # Generate background samples: Gaussian noise, clipped at zero from below
    n_background = 50
    background_samples = []
    for _ in range(n_background):
        perturbed = background_flat.copy()
        noise = np.random.normal(0, 0.3, perturbed.shape)
        perturbed = np.clip(perturbed + noise, 0, None)
        background_samples.append(perturbed)

    background_matrix = np.stack(background_samples)
    print(f"Background matrix shape: {background_matrix.shape}")

    shap_values = None
    try:
        # Test run
        test_output = model_wrapper(background_matrix[0:1])
        print(f"Test output shape: {test_output.shape}")

        # Initialize explainer
        explainer = shap.KernelExplainer(
            model_wrapper,
            background_matrix,
            link="identity",
            feature_perturbation="interventional"
        )

        # Calculate SHAP values
        shap_values = explainer.shap_values(
            background_flat,
            nsamples=200,
            l1_reg="num_features(10)",
            silent=True
        )

        # SHAP releases that return one array per model output give a list
        # here. Stack it along the last axis so the features stay on axis 0,
        # the layout the per-node aggregation relies on (np.array would put
        # the outputs first and the node slices would pick the wrong rows).
        if isinstance(shap_values, list):
            shap_values = np.stack([np.asarray(v) for v in shap_values], axis=-1)

        print(f"Raw SHAP values shape: {shap_values.shape}")

        # Sum absolute SHAP values across all features (and outputs) per node
        node_importance = _aggregate_node_importance(
            shap_values, num_nodes, features_per_node
        )

        # Normalize importance scores to [0, 1]
        node_importance = (node_importance - node_importance.min()) / \
                         (node_importance.max() - node_importance.min() + 1e-10)

        print("\nNode importance statistics:")
        print(f"Shape: {node_importance.shape}")
        print(f"Range: {node_importance.min():.6f} to {node_importance.max():.6f}")
        print(f"Mean: {node_importance.mean():.6f}")
        print(f"Std: {node_importance.std():.6f}")

        # Print the most important nodes (fewer than 5 for very small graphs)
        sorted_indices = np.argsort(node_importance)[::-1]
        print("\nTop 5 most important nodes:")
        for idx in sorted_indices[:5]:
            print(f"Node {idx}: {node_importance[idx]:.6f}")

        return node_importance, shap_values

    except Exception as e:
        print(f"Error in SHAP calculation: {e}")
        print(f"Debug info:")
        print(f"- Background matrix shape: {background_matrix.shape}")
        print(f"- Number of nodes: {batch.x_cat.shape[0]}")
        print(f"- Feature dimension: {batch.x_cat.shape[1]}")
        if shap_values is not None and hasattr(shap_values, 'shape'):
            print(f"- SHAP values shape: {shap_values.shape}")
        raise

        
def visualize_molecule_with_importance(smiles, node_importance, save_path=None):
    """
    Draw the molecule with atoms tinted by importance above a bar chart of the
    same scores sorted in descending order.

    Scores are used directly as colour intensities (0 = white, 1 = red).
    When ``save_path`` is given the figure is written there and closed.
    Returns the matplotlib figure.
    """
    import matplotlib.pyplot as plt
    from rdkit import Chem
    from rdkit.Chem import AllChem, Draw
    from matplotlib.colors import LinearSegmentedColormap

    def tint(score):
        # White-to-red gradient shared by the drawing and the bars
        return (1.0, 1.0 - score, 1.0 - score)

    mol = Chem.MolFromSmiles(smiles)
    num_atoms = mol.GetNumAtoms()

    # Two stacked panels: structure on top, bar chart below
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12),
                                  gridspec_kw={'height_ratios': [1.5, 1]})

    # --- Panel 1: molecular structure ---
    atom_info = [
        (atom_idx, mol.GetAtomWithIdx(atom_idx).GetSymbol(), node_importance[atom_idx])
        for atom_idx in range(num_atoms)
    ]
    atom_colors = {atom_idx: tint(score) for atom_idx, _, score in atom_info}

    # Generate 2D coordinates
    AllChem.Compute2DCoords(mol)

    img = Draw.MolToImage(
        mol,
        highlightAtoms=list(range(num_atoms)),
        highlightColor=None,
        highlightAtomColors=atom_colors,
        size=(400, 400)
    )

    ax1.imshow(img)
    ax1.axis('off')
    ax1.set_title('Molecular Structure with SHAP Importance')

    # --- Panel 2: bar chart, most important atom first ---
    ranked = sorted(atom_info, key=lambda entry: entry[2], reverse=True)
    indices, symbols, scores = zip(*ranked)
    atom_labels = [f"{symbol}{atom_idx}" for atom_idx, symbol, _ in ranked]

    positions = range(len(scores))
    bars = ax2.bar(positions, scores)
    for bar, score in zip(bars, scores):
        bar.set_facecolor(tint(score))

    ax2.set_xticks(positions)
    ax2.set_xticklabels(atom_labels, rotation=45, ha='right')
    ax2.set_ylabel('SHAP Importance Score')
    ax2.set_title('Atom-wise SHAP Importance Scores')

    # Horizontal grid lines for easier reading
    ax2.grid(True, axis='y', linestyle='--', alpha=0.7)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()

    return fig
        
def visualize_molecule_importance(smiles, node_importance, save_path=None):
    """Draw a molecule with its atoms coloured by importance.

    The first ``num_atoms`` scores are rescaled between their 10th and 90th
    percentile, passed through an exponential to widen the gaps, and mapped
    to three colours (white / light red / deep red). Bonds whose two atoms
    average above 0.5 are highlighted too. The top five atoms are printed.

    Args:
        smiles: SMILES string of the molecule to draw.
        node_importance: one score per node; only the first ``num_atoms``
            entries (the atoms of the parsed SMILES) are used.
        save_path: optional image path. When given, the image is saved there
            and a colour-scale legend is saved next to it with ``_legend``
            appended to the file stem.

    Returns:
        The rendered image, or ``None`` when fewer scores than atoms were
        supplied.

    Raises:
        ValueError: if ``smiles`` cannot be parsed.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles}")
    num_atoms = mol.GetNumAtoms()

    # Accept lists as well as arrays; the arithmetic below needs an array.
    node_importance = np.asarray(node_importance, dtype=float)

    # Ensure we have the right number of importance scores
    if len(node_importance) < num_atoms:
        print(f"Warning: got {len(node_importance)} importance scores for "
              f"{num_atoms} atoms; nothing drawn")
        return None

    # Get importance scores for atoms
    imp_scores = node_importance[:num_atoms]

    # Use percentile-based normalization
    min_val = np.percentile(imp_scores, 10)
    max_val = np.percentile(imp_scores, 90)

    if max_val > min_val:
        normalized_scores = (imp_scores - min_val) / (max_val - min_val)
        normalized_scores = np.clip(normalized_scores, 0, 1)
    else:
        normalized_scores = np.zeros_like(imp_scores)

    # Apply non-linear scaling to enhance differences
    normalized_scores = np.exp(4 * normalized_scores) / np.exp(4)

    # Create atom colors
    atom_colors = {}
    importance_info = []

    for i in range(num_atoms):
        atom = mol.GetAtomWithIdx(i)
        score = normalized_scores[i]
        raw_score = node_importance[i]

        if score < 0.3:
            color = (1.0, 1.0, 1.0)  # White for low importance
        elif score < 0.6:
            color = (1.0, 0.7, 0.7)  # Light red for medium importance
        else:
            color = (1.0, 0.0, 0.0)  # Deep red for high importance

        atom_colors[i] = color
        importance_info.append((i, atom.GetSymbol(), raw_score, score))

    # Sort and print atom importance
    importance_info.sort(key=lambda x: x[2], reverse=True)
    print("\nAtom Importance Ranking (Top 5):")
    for idx, symbol, raw_score, norm_score in importance_info[:5]:
        print(f"Atom {idx} ({symbol}): raw importance = {raw_score:.6f}, "
              f"normalized = {norm_score:.4f}")

    # Generate 2D coordinates
    AllChem.Compute2DCoords(mol)

    # Create list of bonds to highlight
    bonds = []
    for bond in mol.GetBonds():
        begin_idx = bond.GetBeginAtomIdx()
        end_idx = bond.GetEndAtomIdx()
        # Use average of connected atoms' importance
        bond_importance = (normalized_scores[begin_idx] + normalized_scores[end_idx]) / 2
        if bond_importance > 0.5:  # Only highlight significant bonds
            bonds.append(bond.GetIdx())

    # Draw molecule with enhanced visualization
    img = Draw.MolToImage(
        mol,
        highlightAtoms=list(range(num_atoms)),
        highlightColor=None,
        highlightAtomColors=atom_colors,
        highlightBonds=bonds,  # List of bond indices
        size=(800, 800),
        highlightRadius=0.5
    )

    if save_path:
        img.save(save_path)

        # Derive the legend path from the file stem. A plain
        # str.replace('.png', ...) leaves a non-.png path unchanged, so the
        # legend would overwrite the image just saved.
        stem, extension = os.path.splitext(save_path)
        legend_path = f"{stem}_legend{extension or '.png'}"

        # Create legend with enhanced gradient
        fig, ax = plt.subplots(figsize=(8, 1))
        gradient = np.linspace(0, 1, 256).reshape(1, -1)

        # Enhanced colormap
        colors = [(1, 1, 1), (1, 0.5, 0.5), (1, 0, 0)]
        cmap = LinearSegmentedColormap.from_list("custom_enhanced_red", colors, N=256)

        ax.imshow(gradient, aspect='auto', cmap=cmap)
        ax.set_xticks([0, 128, 255])
        ax.set_xticklabels(['Low', 'Medium', 'High'])
        ax.set_yticks([])
        ax.set_title('Atom Importance Scale')
        fig.savefig(legend_path, bbox_inches='tight', dpi=300)
        plt.close(fig)

    return img

    
# Also modify the generate_explanations function to print more information:
def generate_explanations(model, dataset, device, idx=1, save_path=None):
    """
    Generate and visualize explanations for a specific molecule

    Args:
        model: encoder passed on to ``explain_graph_model``.
        dataset: indexable collection of graphs; ``dataset[idx]`` is explained.
        device: device used for the explanation.
        idx: position of the molecule in ``dataset``.
        save_path: where the rendered molecule is written. By default a file
            under ``molecule_explanation/`` named with the notebook-level
            ``timestamp`` if that variable exists, otherwise with the
            current time.

    Returns:
        ``(node_importance, shap_values, img)``; ``img`` is ``None`` when the
        graph carries no ``smiles`` attribute.
    """
    # Get specific molecule
    graph_data = dataset[idx]

    print(f"Processing molecule {idx}")
    print(f"Graph data shape - x_cat: {graph_data.x_cat.shape}, x_phys: {graph_data.x_phys.shape}")

    # Generate SHAP explanations
    node_importance, shap_values = explain_graph_model(
        model,
        graph_data,
        device,
        num_samples=100
    )

    print(f"Generated node importance shape: {node_importance.shape}")

    if not hasattr(graph_data, 'smiles'):
        print("Warning: No SMILES data found for visualization")
        return node_importance, shap_values, None

    print(f"Visualizing molecule with SMILES: {graph_data.smiles}")

    if save_path is None:
        # The function used to read a `timestamp` global that is only defined
        # in a later cell, so calling it earlier raised NameError. Reuse that
        # value when it exists and create one otherwise.
        stamp = globals().get('timestamp') or datetime.now().strftime('%Y%m%d_%H%M%S')
        save_path = f"molecule_explanation/molecule_explanation_{stamp}.png"

    # Saving fails if the target directory does not exist yet
    output_dir = os.path.dirname(save_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    img = visualize_molecule_importance(
        graph_data.smiles,
        node_importance,
        save_path=save_path
    )
    return node_importance, shap_values, img
    

In [6]:
# Build the dataset: one graph per SMILES line, with the SMILES string stored
# on each graph so the explanation can later be drawn on the molecule.
print("Starting data loading...")
extractor = MolecularFeatureExtractor()
# Defaults to the original local path; set the SMILES_FILE environment
# variable to run the notebook on another machine.
smiles_file = os.environ.get(
    "SMILES_FILE",
    "D:\\PhD\\Chapter3\\Unsupervised_GAN_Code\\pubchem-41-clean.txt",
)

dataset = []
failed_smiles = []

with open(smiles_file, 'r') as f:
    for line in f:
        smiles = line.strip()
        if not smiles:
            # Blank lines are not molecules; do not count them as failures
            continue
        data = extractor.process_molecule(smiles)
        if data is not None:
            # Add SMILES as an attribute to the Data object
            data.smiles = smiles
            dataset.append(data)
        else:
            failed_smiles.append(smiles)

print(f"1. Loaded dataset with {len(dataset)} graphs.")
print(f"2. Failed SMILES count: {len(failed_smiles)}")

# Make sure to import needed libraries
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
import traceback

os.makedirs('molecule_explanation', exist_ok=True)
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

if not dataset:
    # Nothing to explain; stop here instead of failing on dataset[1]
    print("No valid graphs generated.")
else:
    try:
        print("\nSelected molecule index: 1")
        node_importance, shap_values, img = generate_explanations(
            encoder,
            dataset,
            device,
            idx=1
        )

        if img is not None:
            image_path = f'molecule_explanation/molecule_explanation_{timestamp}.png'
            # The legend is written next to the image with "_legend" appended
            # to the file stem; report that path rather than a file that is
            # never created.
            legend_path = f'molecule_explanation/molecule_explanation_{timestamp}_legend.png'
            img.save(image_path)
            print(f"Visualization saved as '{image_path}'")
            print(f"Legend saved as '{legend_path}'")

    except Exception as e:
        print(f"\nError generating explanations: {e}")
        print("\nModel architecture:")
        print(encoder)
        traceback.print_exc()

Starting data loading...
1. Loaded dataset with 41 graphs.
2. Failed SMILES count: 0

Selected molecule index: 1
Processing molecule 1
Graph data shape - x_cat: torch.Size([42, 2]), x_phys: torch.Size([42, 8])
Background matrix shape: (50, 84)
Processing batch of size 1
Result shape: (1, 42)
Test output shape: (1, 42)
Processing batch of size 50
Result shape: (50, 42)
Processing batch of size 1
Result shape: (1, 42)
Processing batch of size 10000
Result shape: (10000, 42)
Raw SHAP values shape: (84, 42)

Node importance statistics:
Shape: (42,)
Range: 0.000000 to 1.000000
Mean: 0.210749
Std: 0.300511

Top 5 most important nodes:
Node 2: 1.000000
Node 14: 0.964057
Node 13: 0.951046
Node 9: 0.820947
Node 15: 0.710237
Generated node importance shape: (42,)
Visualizing molecule with SMILES: CC[NH+](CC)C1CCC([NH2+]C2CC2)(C(=O)[O-])C1

Atom Importance Ranking (Top 5):
Atom 2 (N): raw importance = 1.000000, normalized = 1.0000
Atom 14 (O): raw importance = 0.964057, normalized = 1.0000
Atom 1