In [1]:
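# Requires the umap-learn package (it provides the `umap` module imported below).
# To install it into this kernel's environment, run: %pip install umap-learn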

In [2]:
import os
import pickle
from typing import Dict, List, Optional, Tuple

import lime
import lime.lime_tabular
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import shap
import torch
import torch.nn as nn
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from torch_geometric.data import Data, DataLoader
from tqdm import tqdm


In [3]:
class MolecularInterpreter:
    """Interpreter for molecular embeddings and encoder

    Loads a saved encoder checkpoint together with pickled embeddings and
    provides SHAP-based atom attribution, an embedding-space plot, a summary
    of attribution patterns, and pairwise molecule comparison. Figures are
    written below ``output_dir``.
    """

    # Fallback directory for instances created without running __init__.
    output_dir = 'molecular_analysis'

    def __init__(self, encoder_path: str, embedding_path: str,
                 output_dir: str = 'molecular_analysis'):
        """
        Initialize interpreter with saved model and embeddings
        Args:
            encoder_path: Path to saved encoder checkpoint; must contain
                'encoder_state_dict' and may contain 'model_info'
            embedding_path: Path to pickled dict with 'embeddings' and 'labels'
            output_dir: Directory that receives the generated figures
        """
        # map_location='cpu' lets a checkpoint written on a GPU machine load on
        # a CPU-only host; the encoder below is constructed on the CPU anyway.
        checkpoint = torch.load(encoder_path, map_location='cpu')

        # Get model configuration from checkpoint (each size falls back to 128)
        model_info = checkpoint.get('model_info', {})
        node_dim = model_info.get('node_dim', 128)
        edge_dim = model_info.get('edge_dim', 128)
        hidden_dim = model_info.get('hidden_dim', 128)
        output_dim = model_info.get('output_dim', 128)

        # Initialize encoder model
        self.encoder = GraphDiscriminator(
            node_dim=node_dim,
            edge_dim=edge_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim
        )

        # Load state dict and switch to inference mode
        self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
        self.encoder.eval()

        # Load embeddings
        # SECURITY: pickle.load can execute arbitrary code contained in the
        # file; only point embedding_path at files you created or trust.
        print(f"Loading embeddings from {embedding_path}")
        with open(embedding_path, 'rb') as f:
            data = pickle.load(f)
        self.embeddings = data['embeddings']
        # NOTE(review): the 'labels' entry is used as the list of molecular
        # graphs by the methods below -- confirm that is what the file stores.
        self.graphs = data['labels']

        print(f"Loaded {len(self.embeddings)} embeddings of dimension {self.embeddings.shape[1]}")

        # Create output directory
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

    def compute_atom_importance(self, graph_data) -> np.ndarray:
        """
        Compute SHAP values for atoms in a molecule
        Args:
            graph_data: Molecular graph data exposing x, edge_index and edge_attr
        Returns:
            Array with one value per row of graph_data.x (mean absolute SHAP
            value across that row's features)
        """
        # Convert input to tensor if needed
        x = torch.tensor(graph_data.x) if not isinstance(graph_data.x, torch.Tensor) else graph_data.x

        def model_fn(features):
            with torch.no_grad():
                # Create a new graph data object with the modified features
                new_data = Data(
                    x=features,
                    edge_index=graph_data.edge_index,
                    edge_attr=graph_data.edge_attr
                )
                return self.encoder(new_data)

        # NOTE(review): shap.GradientExplainer is documented for differentiable
        # framework models (e.g. a torch.nn.Module) and treats the first input
        # axis as independent samples. Here it receives a plain closure that
        # runs under torch.no_grad() and whose first axis is the atom axis of
        # one graph -- confirm this works with the installed shap version.
        background = torch.zeros_like(x)  # Use zero background
        explainer = shap.GradientExplainer(model_fn, background)
        shap_values = explainer.shap_values(x)

        # A list holds one array per model output; average them first, then
        # aggregate SHAP magnitudes across features.
        if isinstance(shap_values, list):
            shap_values = np.array(shap_values).mean(axis=0)
        atom_importance = np.abs(shap_values).mean(axis=1)

        return atom_importance

    def compare_molecules(self, idx1: int, idx2: int) -> Dict:
        """
        Compare two stored molecules
        Args:
            idx1: Index of the first molecule in self.embeddings / self.graphs
            idx2: Index of the second molecule
        Returns:
            Dict with
                'distance': Euclidean distance between the two embeddings
                'shap_correlation': Pearson correlation between the two atom
                    importance profiles (NaN when it is undefined)
        Raises:
            IndexError: if either index is out of range
        """
        emb1 = np.asarray(self.embeddings[idx1], dtype=float).ravel()
        emb2 = np.asarray(self.embeddings[idx2], dtype=float).ravel()
        distance = float(np.linalg.norm(emb1 - emb2))

        profile1 = self.compute_atom_importance(self.graphs[idx1])
        profile2 = self.compute_atom_importance(self.graphs[idx2])

        return {
            'distance': distance,
            'shap_correlation': self._profile_correlation(profile1, profile2)
        }

    @staticmethod
    def _profile_correlation(values_a, values_b) -> float:
        """Pearson correlation of two 1-D profiles over their common prefix.

        Atom indices of different molecules are not aligned, so this is only a
        coarse similarity of the importance profiles. Returns NaN when fewer
        than two positions overlap or when either profile is constant.
        """
        a = np.asarray(values_a, dtype=float).ravel()
        b = np.asarray(values_b, dtype=float).ravel()
        n_common = min(len(a), len(b))
        if n_common < 2:
            return float('nan')
        a, b = a[:n_common], b[:n_common]
        # corrcoef divides by the standard deviations; avoid 0/0 warnings.
        if a.std() == 0 or b.std() == 0:
            return float('nan')
        return float(np.corrcoef(a, b)[0, 1])

    @staticmethod
    def _stack_ragged(rows: List[np.ndarray]) -> np.ndarray:
        """Stack 1-D arrays of differing length into one NaN-padded 2-D array."""
        flat = [np.asarray(row, dtype=float).ravel() for row in rows]
        width = max((len(row) for row in flat), default=0)
        stacked = np.full((len(flat), width), np.nan)
        for i, row in enumerate(flat):
            stacked[i, :len(row)] = row
        return stacked

    def visualize_atom_importance(self, smiles: str, importance_values: np.ndarray,
                                save_path: str) -> None:
        """
        Visualize atom importance on molecular structure
        Args:
            smiles: SMILES string of molecule
            importance_values: SHAP values per atom (one per atom of the parsed molecule)
            save_path: Path to save visualization
        Errors are reported with print() and do not propagate.
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                print(f"Failed to parse SMILES: {smiles}")
                return

            # Accept lists as well as arrays; float64 entries convert cleanly
            # to the plain floats the colour tuples need.
            importance_values = np.asarray(importance_values, dtype=float).ravel()

            # Ensure we have the right number of values
            if len(importance_values) != mol.GetNumAtoms():
                print(f"Mismatch in number of atoms: {len(importance_values)} values for {mol.GetNumAtoms()} atoms")
                return

            # Normalize importance values to [0,1]; the epsilon keeps the
            # division defined when all values are equal.
            norm_values = (importance_values - importance_values.min()) / \
                         (importance_values.max() - importance_values.min() + 1e-9)

            # Create atom colors as RGB tuples (white = least important,
            # red = most important)
            atom_colors = {
                i: (1.0, 1.0 - v, 1.0 - v)
                for i, v in enumerate(norm_values)
            }

            # Draw molecule
            # NOTE(review): confirm that Draw.MolToImage honours
            # highlightAtomColors in the installed RDKit; if it is ignored,
            # every atom gets the default highlight colour and rdMolDraw2D
            # should be used instead.
            img = Draw.MolToImage(
                mol,
                highlightAtoms=list(range(mol.GetNumAtoms())),
                highlightAtomColors=atom_colors,
                size=(400, 400)
            )
            img.save(save_path)

        except Exception as e:
            print(f"Error visualizing molecule: {str(e)}")

    def analyze_embedding_space(self) -> Optional[np.ndarray]:
        """
        Analyze and visualize embedding space
        Projects self.embeddings to 2-D with UMAP and saves a scatter plot as
        'embedding_space.png' in the output directory.
        Returns:
            The 2-D projection, or None if any step failed (the error is printed)
        """
        try:
            # Reduce dimensionality
            reducer = umap.UMAP(n_components=2, random_state=42)
            embedding_2d = reducer.fit_transform(self.embeddings)

            # Plot embedding space, coloured by position in the embedding array
            plt.figure(figsize=(12, 8))
            scatter = plt.scatter(
                embedding_2d[:, 0],
                embedding_2d[:, 1],
                c=np.arange(len(embedding_2d)),
                cmap='viridis',
                alpha=0.6
            )
            plt.colorbar(scatter, label='Molecule Index')
            plt.title('Molecular Embedding Space (UMAP)')
            plt.savefig(os.path.join(self.output_dir, 'embedding_space.png'))
            plt.close()

            return embedding_2d

        except Exception as e:
            print(f"Error analyzing embedding space: {str(e)}")
            return None

    def analyze_feature_patterns(self, max_molecules: int = 100) -> Optional[Dict]:
        """
        Analyze patterns in how features influence embeddings
        Args:
            max_molecules: Upper bound on how many of the first molecules are analysed
        Returns:
            Dict with 'mean_importance' and 'std_importance' (per atom index,
            ignoring molecules that have no atom at that index) and
            'all_shap_values' (molecules x atom index, NaN-padded where a
            molecule has fewer atoms), or None if there is nothing to analyse
            or a step failed (the error is printed)
        """
        try:
            # Get SHAP values for a subset of molecules
            n_samples = min(max_molecules, len(self.graphs))
            if n_samples <= 0:
                print("Error analyzing feature patterns: no molecules available")
                return None

            per_molecule = []
            for i in tqdm(range(n_samples), desc="Computing SHAP values"):
                per_molecule.append(self.compute_atom_importance(self.graphs[i]))

            # Molecules differ in atom count, so the per-molecule vectors are
            # ragged; pad with NaN rather than relying on np.array().
            all_shap_values = self._stack_ragged(per_molecule)

            # Analyze feature patterns, skipping the NaN padding
            mean_importance = np.nanmean(all_shap_values, axis=0)
            std_importance = np.nanstd(all_shap_values, axis=0)

            # Plot feature importance distribution
            plt.figure(figsize=(12, 6))
            sns.boxplot(data=all_shap_values)
            plt.title('Distribution of Atom Importance Across Molecules')
            plt.xlabel('Atom Index')
            plt.ylabel('SHAP Value')
            plt.xticks(rotation=45)
            plt.tight_layout()
            plt.savefig(os.path.join(self.output_dir, 'feature_importance_distribution.png'))
            plt.close()

            return {
                'mean_importance': mean_importance,
                'std_importance': std_importance,
                'all_shap_values': all_shap_values
            }

        except Exception as e:
            print(f"Error analyzing feature patterns: {str(e)}")
            return None
        
import torch
import torch.nn as nn
import numpy as np
import shap
from rdkit import Chem
from rdkit.Chem import Draw
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple
import umap
import pickle
import os
from torch_geometric.data import Data
from tqdm import tqdm

# Import your GraphDiscriminator class
from models import GraphDiscriminator  # Make sure this import matches your model file

def load_data(embedding_path: str) -> Tuple[np.ndarray, List]:
    """Read a pickled results file and return its embeddings and labels.

    Args:
        embedding_path: Path to a pickle file holding a mapping with the keys
            'embeddings' and 'labels'
    Returns:
        Tuple ``(embeddings, labels)`` taken from that mapping
    """
    with open(embedding_path, 'rb') as handle:
        payload = pickle.load(handle)
    embeddings = payload['embeddings']
    labels = payload['labels']
    return embeddings, labels

def main():
    """Drive the full interpretation workflow and report progress on stdout.

    Steps: build the interpreter from the saved encoder and embeddings, plot
    the embedding space, summarise feature patterns, then compare up to three
    molecule pairs. A failure inside one pair comparison is printed and the
    loop continues; any other failure is printed and re-raised.
    """
    try:
        # Build the interpreter from the saved artefacts
        interp = MolecularInterpreter(
            encoder_path='./checkpoints/encoders/best_encoder.pt',
            embedding_path='./embeddings/final_embeddings.pkl'
        )

        print("1. Analyzing embedding space...")
        projection = interp.analyze_embedding_space()
        if projection is not None:
            print("   Embedding space visualization saved.")

        print("\n2. Analyzing feature patterns...")
        patterns = interp.analyze_feature_patterns()
        if patterns is not None:
            n_features = len(patterns['mean_importance'])
            print("   Feature patterns analysis completed.")
            print(f"   Analyzed {n_features} features")

        print("\n3. Comparing example molecule pairs...")
        n_pairs = min(3, len(interp.graphs))
        for pair_no, first in enumerate(range(n_pairs), start=1):
            # Partner is the molecule 10 positions later, clamped to the last one
            second = min(first + 10, len(interp.graphs) - 1)
            try:
                result = interp.compare_molecules(first, second)
                print(f"\nMolecule Pair {pair_no}:")
                print(f"- Embedding Distance: {result['distance']:.4f}")
                print(f"- SHAP Correlation: {result['shap_correlation']:.4f}")
            except Exception as pair_error:
                print(f"Error comparing molecules {first} and {second}: {str(pair_error)}")

        print("\nAnalysis complete! Results saved in 'molecular_analysis' directory")

    except Exception as run_error:
        print(f"Error during analysis: {str(run_error)}")
        raise

# Entry point: run the whole analysis when this code is executed directly
# (as a script, or by running this cell) rather than imported as a module.
if __name__ == "__main__":
    main()

ModuleNotFoundError: No module named 'models'