In [1]:
import torch
import torch.nn as nn
import numpy as np
import shap
import lime
import lime.lime_tabular
from rdkit import Chem
from rdkit.Chem import Draw
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple, Optional
import pickle
import os
from torch_geometric.data import Data
from tqdm import tqdm
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing
from sklearn.metrics.pairwise import cosine_similarity

class GraphDiscriminator(nn.Module):
    """Graph encoder: node MLP, three GCN layers, mean pooling, projection head.

    Reimplementation of the original discriminator architecture.

    Args:
        node_dim: Width of the node feature vector fed to ``node_encoder``
            (``x_cat`` and ``x_phys`` concatenated along the last axis).
        edge_dim: Width of the edge feature vector; sizes ``edge_encoder``.
        hidden_dim: Width of the hidden layers.
        output_dim: Width of the embedding returned by ``forward``.
    """

    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128, output_dim: int = 128):
        super().__init__()

        # Feature encoding
        self.node_encoder = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # NOTE: forward() never feeds edge features into the GCN layers. The
        # module is kept so that state dicts containing ``edge_encoder.*``
        # keys still load with ``load_state_dict``.
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Graph convolution layers
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, output_dim)

        # Projection head
        self.projection = nn.Sequential(
            nn.Linear(output_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, data):
        """Encode a batch of graphs into projected, pooled embeddings.

        Args:
            data: Batch object exposing ``x_cat``, ``x_phys``, ``edge_index``
                and ``batch`` (presumably a ``torch_geometric`` batch; confirm
                against the caller).

        Returns:
            The output of the projection head applied to the mean-pooled
            node features; its last dimension is ``output_dim``.
        """
        x = torch.cat([data.x_cat.float(), data.x_phys], dim=-1)
        edge_index = data.edge_index
        batch = data.batch

        # Initial feature encoding
        x = self.node_encoder(x)

        # Graph convolutions. torch.relu is called directly because no
        # ``torch.nn.functional`` alias is imported in this file. Edge
        # attributes are not encoded here: the convolutions below only
        # receive ``x`` and ``edge_index``, so the result would be discarded.
        x = torch.relu(self.conv1(x, edge_index))
        x = torch.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)

        # Global pooling
        x = global_mean_pool(x, batch)

        # Projection
        return self.projection(x)

class EmbeddingAnalyzer:
    """Analyzer for learned molecular embeddings.

    Loads a pickled embedding matrix and explains, per molecule, which
    embedding dimensions drive cosine similarity to that molecule, using
    SHAP (on a random-forest surrogate) and LIME. Plots are written under
    ``xai_analysis/shap`` and ``xai_analysis/lime`` in the working directory.
    """

    def __init__(self, embedding_path: str):
        """Initialize analyzer with saved embeddings.

        Args:
            embedding_path: Path to a pickle file containing a dict with the
                keys ``'embeddings'`` and ``'labels'``.

        Raises:
            ValueError: If the stored embeddings are not a 2-D matrix.
        """
        print("Loading embeddings...")
        # SECURITY: pickle.load can execute arbitrary code while loading.
        # Only point this at files produced by a trusted pipeline.
        with open(embedding_path, 'rb') as f:
            data = pickle.load(f)

        embeddings = data['embeddings']
        if isinstance(embeddings, torch.Tensor):
            # Detach and move to host memory so NumPy / scikit-learn can use it.
            embeddings = embeddings.detach().cpu().numpy()
        self.embeddings = np.asarray(embeddings)  # (n_molecules, embedding_dim)
        if self.embeddings.ndim != 2:
            raise ValueError(
                "Expected a 2-D embedding matrix, got shape "
                f"{self.embeddings.shape}"
            )

        # Stored under the 'labels' key; the original comment called this the
        # "original graph data" -- TODO confirm what the pipeline saves here.
        self.graphs = data['labels']

        # Device used by _initialize_encoder when rebuilding the encoder.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        print(f"Loaded {len(self.embeddings)} molecules")
        print(f"Embedding dimension: {self.embeddings.shape[1]}")

        # Create output directories
        os.makedirs('xai_analysis/shap', exist_ok=True)
        os.makedirs('xai_analysis/lime', exist_ok=True)

    def analyze_shap(self, molecule_idx: int, n_background: int = 100,
                     random_state: Optional[int] = None) -> Dict:
        """
        Analyze embedding dimensions using SHAP
        Args:
            molecule_idx: Index of molecule to analyze
            n_background: Number of background samples for SHAP
            random_state: Optional seed for the background sampling and the
                surrogate forest; ``None`` keeps the unseeded behaviour.
        Returns:
            Dict with ``'shap_values'`` (explainer output for the target),
            ``'feature_importance'`` (mean absolute SHAP value per dimension)
            and ``'similarity_scores'`` (cosine similarity of the target to
            every sampled background embedding).
        """
        print(f"\nAnalyzing molecule {molecule_idx} with SHAP...")

        # Get target embedding
        target_embedding = self.embeddings[molecule_idx]

        # Create background dataset (sampled without replacement, capped at
        # the number of available molecules).
        sampler = np.random if random_state is None else np.random.RandomState(random_state)
        background_indices = sampler.choice(
            len(self.embeddings),
            min(n_background, len(self.embeddings)),
            replace=False
        )
        background_data = self.embeddings[background_indices]

        # Surrogate model: learns to predict similarity-to-target from the
        # raw embedding so that a tree explainer can attribute it.
        from sklearn.ensemble import RandomForestRegressor
        similarity_model = RandomForestRegressor(
            n_estimators=100, random_state=random_state
        )

        # Calculate cosine similarities for training
        similarities = cosine_similarity([target_embedding], background_data)[0]

        # Train model to predict similarity from embeddings
        similarity_model.fit(background_data, similarities)

        # Initialize SHAP explainer
        explainer = shap.TreeExplainer(similarity_model)
        shap_values = explainer.shap_values(
            target_embedding.reshape(1, -1),
            check_additivity=False
        )

        # Create visualizations
        self._plot_shap_summary(
            shap_values,
            target_embedding,
            f'xai_analysis/shap/molecule_{molecule_idx}'
        )

        return {
            'shap_values': shap_values,
            'feature_importance': np.abs(shap_values).mean(axis=0),
            'similarity_scores': similarities
        }

    def _initialize_encoder(self, checkpoint: Dict) -> nn.Module:
        """Initialize encoder with saved weights.

        Args:
            checkpoint: Dict with ``'model_info'`` (providing ``'node_dim'``
                and ``'edge_dim'``) and ``'encoder_state_dict'``.

        Returns:
            A ``GraphDiscriminator`` on ``self.device`` with the saved
            weights loaded.

        Raises:
            ValueError: If ``node_dim`` or ``edge_dim`` is missing.
        """
        # Get model info
        model_info = checkpoint.get('model_info', {})
        node_dim = model_info.get('node_dim')
        edge_dim = model_info.get('edge_dim')
        if node_dim is None or edge_dim is None:
            # Fail here with a clear message instead of inside nn.Linear.
            raise ValueError(
                "checkpoint['model_info'] must provide 'node_dim' and 'edge_dim'"
            )

        # Initialize model (using the GraphDiscriminator class)
        encoder = GraphDiscriminator(
            node_dim=node_dim,
            edge_dim=edge_dim,
            hidden_dim=128,
            output_dim=128
        ).to(self.device)

        # Load weights
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        return encoder

    def analyze_lime(self, molecule_idx: int, n_samples: int = 1000) -> Dict:
        """
        Analyze embedding using LIME
        Args:
            molecule_idx: Index of molecule to analyze
            n_samples: Number of samples for LIME
        Returns:
            Dict with ``'explanation'`` (the LIME explanation object) and
            ``'feature_weights'`` (its ``as_list()`` pairs as a dict).
        """
        print(f"\nAnalyzing molecule {molecule_idx} with LIME...")

        # Get target embedding
        target_embedding = self.embeddings[molecule_idx]

        # Initialize LIME explainer
        explainer = lime.lime_tabular.LimeTabularExplainer(
            self.embeddings,
            feature_names=[f'dim_{i}' for i in range(self.embeddings.shape[1])],
            mode='regression'
        )

        # Define prediction function for similarity
        def similarity_predictor(x):
            # Flatten (n, 1) -> (n,): one similarity score per perturbed row.
            return cosine_similarity(x, target_embedding.reshape(1, -1)).ravel()

        # Get LIME explanation
        explanation = explainer.explain_instance(
            target_embedding,
            similarity_predictor,
            num_features=20,
            num_samples=n_samples
        )

        # Create visualization
        self._plot_lime_explanation(
            explanation,
            f'xai_analysis/lime/molecule_{molecule_idx}'
        )

        return {
            'explanation': explanation,
            'feature_weights': dict(explanation.as_list())
        }

    def _plot_shap_summary(self, shap_values: np.ndarray,
                          target_embedding: np.ndarray, save_path: str):
        """Create SHAP summary plot and a top-20 importance bar chart.

        Args:
            shap_values: SHAP values for the target embedding.
            target_embedding: The embedding the SHAP values explain.
            save_path: Path prefix; ``_summary.png`` and ``_importance.png``
                are appended for the two output files.
        """
        plt.figure(figsize=(12, 8))

        # Create summary plot
        shap.summary_plot(
            shap_values,
            target_embedding.reshape(1, -1),
            feature_names=[f'dim_{i}' for i in range(self.embeddings.shape[1])],
            show=False
        )

        plt.title('SHAP Values for Embedding Dimensions')
        plt.tight_layout()
        plt.savefig(f'{save_path}_summary.png')
        plt.close()

        # Create bar plot of feature importance
        plt.figure(figsize=(12, 6))
        importance = np.abs(shap_values).mean(axis=0)

        # Sort by importance
        idx = np.argsort(importance)[-20:]  # Top 20 dimensions
        plt.barh(
            [f'dim_{i}' for i in idx],
            importance[idx]
        )
        plt.title('Top 20 Important Embedding Dimensions')
        plt.xlabel('Mean |SHAP Value|')
        plt.tight_layout()
        plt.savefig(f'{save_path}_importance.png')
        plt.close()

    def _plot_lime_explanation(self, explanation, save_path: str):
        """Create LIME explanation plot.

        Args:
            explanation: LIME explanation object to render.
            save_path: Path prefix; ``_explanation.png`` is appended.
        """
        # as_pyplot_figure() builds its own figure, so that figure is the one
        # resized, saved and closed. Opening a separate figure beforehand
        # would leave an empty figure behind on every call.
        fig = explanation.as_pyplot_figure()
        fig.set_size_inches(12, 8)
        fig.gca().set_title('LIME Feature Weights')
        fig.tight_layout()
        fig.savefig(f'{save_path}_explanation.png')
        plt.close(fig)

    def analyze_molecule(self, molecule_idx: int) -> Dict:
        """Run complete analysis for a molecule.

        Args:
            molecule_idx: Index of molecule to analyze.

        Returns:
            Dict with the ``analyze_shap`` result under ``'shap'`` and the
            ``analyze_lime`` result under ``'lime'``.
        """
        results = {}

        # SHAP analysis
        results['shap'] = self.analyze_shap(molecule_idx)

        # Print top SHAP features (ascending: strongest dimension last)
        importance = results['shap']['feature_importance']
        top_dims = np.argsort(importance)[-5:]
        print("\nTop 5 Important Dimensions (SHAP):")
        for dim in top_dims:
            print(f"Dimension {dim}: {importance[dim]:.4f}")

        # LIME analysis
        results['lime'] = self.analyze_lime(molecule_idx)

        # Print top LIME features (ascending by absolute weight)
        weights = results['lime']['feature_weights']
        print("\nTop 5 Important Features (LIME):")
        for feature, weight in sorted(weights.items(), key=lambda x: abs(x[1]))[-5:]:
            print(f"{feature}: {weight:.4f}")

        return results

def main():
    """Analyze the first three molecules of the saved embedding file.

    Returns:
        Dict mapping molecule index to its ``analyze_molecule`` result, or
        ``None`` when any step raised (the traceback is printed).
    """
    embedding_file = './embeddings/final_embeddings_20250216_111005.pkl'
    banner = '=' * 20

    try:
        analyzer = EmbeddingAnalyzer(embedding_path=embedding_file)

        # Analyze first few molecules
        collected = {}
        for molecule in (0, 1, 2):
            print(f"\n{banner} Analyzing Molecule {molecule} {banner}")
            collected[molecule] = analyzer.analyze_molecule(molecule)

        print("\nAnalysis complete. Results saved in xai_analysis/")
        return collected

    except Exception as err:
        # Top-level boundary: report the failure and signal it with None.
        print(f"Error during analysis: {err}")
        import traceback
        traceback.print_exc()
        return None

# Run the analysis only when this file is executed directly; `results` keeps
# main()'s return value (None if the analysis failed).
if __name__ == "__main__":
    results = main()

Loading embeddings...
Loaded 41 molecules
Embedding dimension: 128


Analyzing molecule 0 with SHAP...

Top 5 Important Dimensions (SHAP):
Dimension 118: 0.0356
Dimension 65: 0.0462
Dimension 60: 0.0543
Dimension 31: 0.0684
Dimension 43: 0.0702

Analyzing molecule 0 with LIME...

Top 5 Important Features (LIME):
dim_115 <= -0.02: 0.0286
dim_91 <= -0.10: 0.0311
dim_2 <= 0.01: 0.0353
dim_114 > 0.05: 0.0411
dim_100 <= -0.08: 0.0695


Analyzing molecule 1 with SHAP...

Top 5 Important Dimensions (SHAP):
Dimension 18: 0.0324
Dimension 92: 0.0351
Dimension 86: 0.0356
Dimension 111: 0.0374
Dimension 74: 0.0419

Analyzing molecule 1 with LIME...

Top 5 Important Features (LIME):
dim_41 > 0.01: 0.0228
dim_35 <= -0.03: 0.0257
dim_74 > 0.03: 0.0274
dim_92 <= -0.05: 0.0284
dim_115 <= -0.02: 0.0416


Analyzing molecule 2 with SHAP...

Top 5 Important Dimensions (SHAP):
Dimension 66: 0.0202
Dimension 24: 0.0219
Dimension 126: 0.0222
Dimension 54: 0.0726
Dimension 16: 0.0902

Analyzing molecule 2 wit

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>