In [1]:
import torch
import torch.nn as nn
import numpy as np
import shap
import lime
import lime.lime_tabular
from rdkit import Chem
from rdkit.Chem import Draw
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple, Optional
import pickle
import os
from torch_geometric.data import Data
from tqdm import tqdm
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing

class GraphDiscriminator(nn.Module):
    """Reimplementation of original discriminator architecture"""
    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128, output_dim: int = 128):
        super().__init__()
        
        # Feature encoding
        self.node_encoder = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Graph convolution layers
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, output_dim)
        
        # Projection head
        self.projection = nn.Sequential(
            nn.Linear(output_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
        
    def forward(self, data):
        x = torch.cat([data.x_cat.float(), data.x_phys], dim=-1)
        edge_index = data.edge_index
        edge_attr = data.edge_attr.float()
        batch = data.batch
        
        # Initial feature encoding
        x = self.node_encoder(x)
        edge_attr = self.edge_encoder(edge_attr)
        
        # Graph convolutions
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        
        # Global pooling
        x = global_mean_pool(x, batch)
        
        # Projection
        x = self.projection(x)
        
        return x

class MolecularXAIAnalyzer:
    """Analyzer for molecular embeddings using SHAP and LIME"""
    
    def __init__(self, encoder_path: str, embedding_path: str):
        """Initialize analyzer with saved model and embeddings"""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        # Load encoder and embeddings
        self._load_model_and_data(encoder_path, embedding_path)
        
        # Create output directories
        os.makedirs('xai_analysis/shap', exist_ok=True)
        os.makedirs('xai_analysis/lime', exist_ok=True)
        
    def _load_model_and_data(self, encoder_path: str, embedding_path: str):
        """Load saved model and data"""
        print("Loading encoder and embeddings...")
        
        # Load encoder
        checkpoint = torch.load(encoder_path, map_location=self.device)
        self.encoder = self._initialize_encoder(checkpoint)
        self.encoder.eval()
        
        # Load embeddings and graphs
        with open(embedding_path, 'rb') as f:
            data = pickle.load(f)
            self.embeddings = data['embeddings']  # Shape: (N, embedding_dim)
            self.graph_data = data['labels']      # List of tuples with graph data
            
        print(f"Loaded {len(self.embeddings)} molecules")
        print(f"Embedding dimension: {self.embeddings.shape[1]}")
        
        # Parse graph structure
        self._parse_graph_structure()
        
    def _parse_graph_structure(self):
        """Parse graph structure from saved data"""
        print("\nParsing graph structure...")
        
        # Each graph is a tuple of (key, value) pairs
        self.graph_elements = {}
        for item in self.graph_data[0]:  # Look at first graph
            if isinstance(item, tuple):
                key, value = item
                self.graph_elements[key] = value
                print(f"Found element: {key}, type: {type(value)}")
                
        self.feature_names = self._get_feature_names()
        
    def _get_feature_names(self) -> List[str]:
        """Get feature names from graph structure"""
        feature_names = []
        
        # Add names for each feature based on graph structure
        for key in self.graph_elements.keys():
            if key in ['x_cat', 'x_phys']:
                value = self.graph_elements[key]
                if isinstance(value, torch.Tensor):
                    n_features = value.shape[1]
                    feature_names.extend([f"{key}_{i}" for i in range(n_features)])
                    
        return feature_names
     
    def _initialize_encoder(self, checkpoint: Dict) -> nn.Module:
        """Initialize encoder with saved weights"""
        # Get model info
        model_info = checkpoint.get('model_info', {})
        node_dim = model_info.get('node_dim')
        edge_dim = model_info.get('edge_dim')
        
        # Initialize model (using your GraphDiscriminator class)
        encoder = GraphDiscriminator(
            node_dim=node_dim,
            edge_dim=edge_dim,
            hidden_dim=128,
            output_dim=128
        ).to(self.device)
        
        # Load weights
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        return encoder
    
    def analyze_molecular_features_shap(self, molecule_idx: int) -> Dict:
        """Analyze molecular features using SHAP"""
        # Get molecule data
        molecule_data = self.graph_data[molecule_idx]
        molecule_embedding = self.embeddings[molecule_idx]
        
        # Create feature matrix
        features = self._create_feature_matrix(molecule_data)
        
        # Initialize SHAP explainer
        def model_fn(x):
            return self.encoder(self._features_to_data(x))
            
        explainer = shap.KernelExplainer(model_fn, features)
        shap_values = explainer.shap_values(features)
        
        # Create visualizations
        self._visualize_shap_results(shap_values, molecule_idx)
        
        return {
            'shap_values': shap_values,
            'feature_importance': np.abs(shap_values).mean(axis=0),
            'feature_names': self.feature_names
        }
        
    def analyze_embedding_lime(self, molecule_idx: int) -> Dict:
        """Analyze embeddings using LIME"""
        # Get molecule data
        molecule_data = self.graph_data[molecule_idx]
        molecule_embedding = self.embeddings[molecule_idx]
        
        # Create feature matrix
        features = self._create_feature_matrix(molecule_data)
        
        # Initialize LIME explainer
        explainer = lime.lime_tabular.LimeTabularExplainer(
            features,
            feature_names=self.feature_names,
            mode='regression'
        )
        
        def predict_fn(x):
            return self.encoder(self._features_to_data(x)).detach().numpy()
            
        # Get explanation
        exp = explainer.explain_instance(
            features[0],
            predict_fn,
            num_features=len(self.feature_names)
        )
        
        # Visualize results
        self._visualize_lime_results(exp, molecule_idx)
        
        return {
            'lime_explanation': exp,
            'feature_weights': dict(exp.as_list())
        }
        
    def _create_feature_matrix(self, molecule_data) -> np.ndarray:
        """Create feature matrix from molecule data"""
        features = []
        
        # Combine features from different elements
        for key in ['x_cat', 'x_phys']:
            if key in self.graph_elements:
                value = self.graph_elements[key]
                if isinstance(value, torch.Tensor):
                    features.append(value.numpy())
                    
        return np.concatenate(features, axis=1)
        
    def _features_to_data(self, features: np.ndarray) -> Data:
        """Convert feature matrix back to graph data"""
        # Split features back into original components
        start = 0
        data_dict = {}
        
        for key in ['x_cat', 'x_phys']:
            if key in self.graph_elements:
                value = self.graph_elements[key]
                if isinstance(value, torch.Tensor):
                    n_features = value.shape[1]
                    data_dict[key] = torch.tensor(
                        features[:, start:start+n_features],
                        dtype=value.dtype
                    )
                    start += n_features
                    
        # Add edge information
        data_dict['edge_index'] = self.graph_elements.get('edge_index')
        data_dict['edge_attr'] = self.graph_elements.get('edge_attr')
        
        return Data(**data_dict)
        
    def _visualize_shap_results(self, shap_values: np.ndarray, molecule_idx: int):
        """Create visualizations for SHAP analysis"""
        # Feature importance plot
        plt.figure(figsize=(12, 6))
        shap.summary_plot(
            shap_values,
            features=self._create_feature_matrix(self.graph_data[molecule_idx]),
            feature_names=self.feature_names,
            show=False
        )
        plt.tight_layout()
        plt.savefig(f'xai_analysis/shap/molecule_{molecule_idx}_importance.png')
        plt.close()
        
    def _visualize_lime_results(self, explanation, molecule_idx: int):
        """Create visualizations for LIME results"""
        plt.figure(figsize=(12, 6))
        explanation.as_pyplot_figure()
        plt.tight_layout()
        plt.savefig(f'xai_analysis/lime/molecule_{molecule_idx}_explanation.png')
        plt.close()

def debug_data_structure(embedding_path):
    """Debug the structure of saved embeddings"""
    with open(embedding_path, 'rb') as f:
        data = pickle.load(f)
        print("\nData keys:", data.keys())
        print("\nType of embeddings:", type(data['embeddings']))
        print("Shape of embeddings:", data['embeddings'].shape)
        print("\nType of labels:", type(data['labels']))
        print("First label:", data['labels'][0])
        print("Type of first label:", type(data['labels'][0]))
        if isinstance(data['labels'][0], tuple):
            print("\nTuple contents:")
            for i, item in enumerate(data['labels'][0]):
                print(f"Item {i}:", type(item))
                print(f"Value: {item}")
                
def main():
    """Run XAI analysis"""
    try:
        # Enable debugging output
        debug_data_structure('./embeddings/final_embeddings_20250216_111005.pkl')
        
        # Initialize analyzer
        analyzer = MolecularXAIAnalyzer(
            encoder_path='./checkpoints/encoders/final_encoder_20250216_111050.pt',
            embedding_path='./embeddings/final_embeddings_20250216_111005.pkl'
        )
        
        # Analyze first few molecules
        for idx in range(3):
            print(f"\nAnalyzing molecule {idx}...")
            
            # SHAP analysis
            print("Running SHAP analysis...")
            shap_results = analyzer.analyze_molecular_features_shap(idx)
            
            # Print top features
            print("\nTop 5 important features (SHAP):")
            importance = shap_results['feature_importance']
            feature_names = shap_results['feature_names']
            top_indices = np.argsort(importance)[-5:]
            for i in top_indices:
                print(f"{feature_names[i]}: {importance[i]:.4f}")
            
            # LIME analysis
            print("\nRunning LIME analysis...")
            lime_results = analyzer.analyze_embedding_lime(idx)
            
            print("\nTop 5 important features (LIME):")
            weights = lime_results['feature_weights']
            sorted_weights = sorted(weights.items(), key=lambda x: abs(x[1]))[-5:]
            for feature, weight in sorted_weights:
                print(f"{feature}: {weight:.4f}")
                
    except Exception as e:
        print(f"Error during analysis: {str(e)}")
        import traceback
        traceback.print_exc()
        
if __name__ == "__main__":
    main()


Data keys: dict_keys(['embeddings', 'labels'])

Type of embeddings: <class 'numpy.ndarray'>
Shape of embeddings: (41, 128)

Type of labels: <class 'list'>
First label: ('edge_index', tensor([[   0,    1,    1,  ..., 1418, 1394, 1419],
        [   1,    0,    2,  ..., 1394, 1419, 1394]]))
Type of first label: <class 'tuple'>

Tuple contents:
Item 0: <class 'str'>
Value: edge_index
Item 1: <class 'torch.Tensor'>
Value: tensor([[   0,    1,    1,  ..., 1418, 1394, 1419],
        [   1,    0,    2,  ..., 1394, 1419, 1394]])
Loading encoder and embeddings...
Loaded 41 molecules
Embedding dimension: 128

Parsing graph structure...

Analyzing molecule 0...
Running SHAP analysis...
Error during analysis: need at least one array to concatenate


Traceback (most recent call last):
  File "C:\Users\Malli\AppData\Local\Temp\ipykernel_83792\2094488411.py", line 304, in main
    shap_results = analyzer.analyze_molecular_features_shap(idx)
  File "C:\Users\Malli\AppData\Local\Temp\ipykernel_83792\2094488411.py", line 160, in analyze_molecular_features_shap
    features = self._create_feature_matrix(molecule_data)
  File "C:\Users\Malli\AppData\Local\Temp\ipykernel_83792\2094488411.py", line 223, in _create_feature_matrix
    return np.concatenate(features, axis=1)
ValueError: need at least one array to concatenate
