In [1]:
import torch
import torch.nn as nn
import numpy as np
import shap
from rdkit import Chem
from rdkit.Chem import Draw
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple
import umap
import pickle
import os
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool
from tqdm import tqdm

class GraphDiscriminator(nn.Module):
    """Reimplementation of original discriminator architecture"""
    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128, output_dim: int = 128):
        super().__init__()
        
        # Feature encoding
        self.node_encoder = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Graph convolution layers
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, output_dim)
        
        # Projection head
        self.projection = nn.Sequential(
            nn.Linear(output_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
        
    def forward(self, data):
        x = torch.cat([data.x_cat.float(), data.x_phys], dim=-1)
        edge_index = data.edge_index
        edge_attr = data.edge_attr.float()
        batch = data.batch
        
        # Initial feature encoding
        x = self.node_encoder(x)
        edge_attr = self.edge_encoder(edge_attr)
        
        # Graph convolutions
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        
        # Global pooling
        x = global_mean_pool(x, batch)
        
        # Projection
        x = self.projection(x)
        
        return x

class MolecularInterpreter:
    """Interpreter for molecular embeddings and encoder"""
    
    def __init__(self, encoder_path: str, embedding_path: str):
        """
        Initialize interpreter with saved model and embeddings
        Args:
            encoder_path: Path to saved encoder
            embedding_path: Path to saved embeddings
        """
        # Load saved checkpoint
        print(f"Loading encoder from {encoder_path}")
        checkpoint = torch.load(encoder_path, map_location='cpu')
        
        # Get model configuration
        print("Initializing encoder architecture...")
        try:
            model_info = checkpoint.get('model_info', {})
            node_dim = model_info.get('node_dim')
            edge_dim = model_info.get('edge_dim')
            
            # If dimensions not in model_info, try to infer from saved weights
            if node_dim is None or edge_dim is None:
                # Get first layer's weight shape
                first_layer = next(iter(checkpoint['encoder_state_dict'].items()))
                if 'node_encoder.0.weight' in checkpoint['encoder_state_dict']:
                    node_dim = checkpoint['encoder_state_dict']['node_encoder.0.weight'].shape[1]
                else:
                    raise ValueError("Could not determine node_dim from saved weights")
                
                if 'edge_encoder.0.weight' in checkpoint['encoder_state_dict']:
                    edge_dim = checkpoint['encoder_state_dict']['edge_encoder.0.weight'].shape[1]
                else:
                    raise ValueError("Could not determine edge_dim from saved weights")
                
            print(f"Model dimensions: node_dim={node_dim}, edge_dim={edge_dim}")
            
            # Initialize model
            self.encoder = GraphDiscriminator(
                node_dim=node_dim,
                edge_dim=edge_dim,
                hidden_dim=128,
                output_dim=128
            )
            
            # Load state dict
            self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
            self.encoder.eval()
            print("Encoder loaded successfully")
            
        except Exception as e:
            print(f"Error initializing encoder: {str(e)}")
            raise
            
        # Load embeddings
        print(f"\nLoading embeddings from {embedding_path}")
        try:
            with open(embedding_path, 'rb') as f:
                data = pickle.load(f)
                self.embeddings = data['embeddings']
                self.graphs = data['labels']
                
            print(f"Loaded {len(self.embeddings)} embeddings of dimension {self.embeddings.shape[1]}")
            
        except Exception as e:
            print(f"Error loading embeddings: {str(e)}")
            raise
            
        # Create output directory
        os.makedirs('molecular_analysis', exist_ok=True)

    def compute_atom_importance(self, graph_data) -> np.ndarray:
        """
        Compute SHAP values for atoms in a molecule
        Args:
            graph_data: Molecular graph data
        Returns:
            Array of SHAP values per atom
        """
        # Convert input to tensor if needed
        x = torch.tensor(graph_data.x) if not isinstance(graph_data.x, torch.Tensor) else graph_data.x
        
        def model_fn(features):
            with torch.no_grad():
                # Create a new graph data object with the modified features
                new_data = Data(
                    x=features,
                    edge_index=graph_data.edge_index,
                    edge_attr=graph_data.edge_attr
                )
                return self.encoder(new_data)
        
        # Initialize SHAP explainer
        background = torch.zeros_like(x)  # Use zero background
        explainer = shap.GradientExplainer(model_fn, background)
        shap_values = explainer.shap_values(x)
        
        # Aggregate SHAP values across features
        if isinstance(shap_values, list):
            shap_values = np.array(shap_values).mean(axis=0)
        atom_importance = np.abs(shap_values).mean(axis=1)
        
        return atom_importance
        
    def visualize_atom_importance(self, smiles: str, importance_values: np.ndarray,
                                save_path: str) -> None:
        """
        Visualize atom importance on molecular structure
        Args:
            smiles: SMILES string of molecule
            importance_values: SHAP values per atom
            save_path: Path to save visualization
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                print(f"Failed to parse SMILES: {smiles}")
                return
                
            # Ensure we have the right number of values
            if len(importance_values) != mol.GetNumAtoms():
                print(f"Mismatch in number of atoms: {len(importance_values)} values for {mol.GetNumAtoms()} atoms")
                return
            
            # Normalize importance values to [0,1]
            norm_values = (importance_values - importance_values.min()) / \
                         (importance_values.max() - importance_values.min() + 1e-9)
            
            # Create atom colors (red = important, blue = less important)
            atom_colors = {
                i: (1.0, 1.0 - v, 1.0 - v) 
                for i, v in enumerate(norm_values)
            }
            
            # Draw molecule
            img = Draw.MolToImage(
                mol,
                highlightAtoms=list(range(mol.GetNumAtoms())),
                highlightAtomColors=atom_colors,
                size=(400, 400)
            )
            img.save(save_path)
            
        except Exception as e:
            print(f"Error visualizing molecule: {str(e)}")

    def analyze_embedding_space(self) -> None:
        """Analyze and visualize embedding space"""
        try:
            # Reduce dimensionality
            reducer = umap.UMAP(n_components=2, random_state=42)
            embedding_2d = reducer.fit_transform(self.embeddings)
            
            # Plot embedding space
            plt.figure(figsize=(12, 8))
            scatter = plt.scatter(
                embedding_2d[:, 0],
                embedding_2d[:, 1],
                c=np.arange(len(embedding_2d)),
                cmap='viridis',
                alpha=0.6
            )
            plt.colorbar(scatter, label='Molecule Index')
            plt.title('Molecular Embedding Space (UMAP)')
            plt.savefig('molecular_analysis/embedding_space.png')
            plt.close()
            
            return embedding_2d
            
        except Exception as e:
            print(f"Error analyzing embedding space: {str(e)}")
            return None

    def analyze_feature_patterns(self) -> Dict:
        """Analyze patterns in how features influence embeddings"""
        try:
            # Get SHAP values for a subset of molecules
            n_samples = min(100, len(self.graphs))
            all_shap_values = []
            
            for i in tqdm(range(n_samples), desc="Computing SHAP values"):
                shap_values = self.compute_atom_importance(self.graphs[i])
                all_shap_values.append(shap_values)
                
            all_shap_values = np.array(all_shap_values)
            
            # Analyze feature patterns
            mean_importance = all_shap_values.mean(axis=0)
            std_importance = all_shap_values.std(axis=0)
            
            # Plot feature importance distribution
            plt.figure(figsize=(12, 6))
            sns.boxplot(data=all_shap_values)
            plt.title('Distribution of Atom Importance Across Molecules')
            plt.xlabel('Atom Index')
            plt.ylabel('SHAP Value')
            plt.xticks(rotation=45)
            plt.tight_layout()
            plt.savefig('molecular_analysis/feature_importance_distribution.png')
            plt.close()
            
            return {
                'mean_importance': mean_importance,
                'std_importance': std_importance,
                'all_shap_values': all_shap_values
            }
            
        except Exception as e:
            print(f"Error analyzing feature patterns: {str(e)}")
            return None

def main():
    """Run molecular interpretation analysis"""
    try:
        print("Starting molecular interpretation analysis...")
        
        # Check if files exist
        encoder_path = './checkpoints/encoders/best_encoder.pt'
        embedding_path = './embeddings/final_embeddings.pkl'
        
        if not os.path.exists(encoder_path):
            raise FileNotFoundError(f"Encoder file not found: {encoder_path}")
        if not os.path.exists(embedding_path):
            raise FileNotFoundError(f"Embedding file not found: {embedding_path}")
        
        # Initialize interpreter
        interpreter = MolecularInterpreter(encoder_path, embedding_path)
        
        # Run analysis steps
        print("\n1. Analyzing embedding space...")
        embedding_2d = interpreter.analyze_embedding_space()
        if embedding_2d is not None:
            print("   Embedding space visualization saved")
        
        print("\n2. Analyzing feature patterns...")
        feature_patterns = interpreter.analyze_feature_patterns()
        if feature_patterns is not None:
            print("   Feature patterns analysis completed")
        
        print("\nAnalysis complete! Results saved in 'molecular_analysis' directory")
        
    except Exception as e:
        print(f"Error during analysis: {str(e)}")
        raise

if __name__ == "__main__":
    main()

Starting molecular interpretation analysis...
Loading encoder from ./checkpoints/encoders/best_encoder.pt
Initializing encoder architecture...
Model dimensions: node_dim=10, edge_dim=5
Encoder loaded successfully

Loading embeddings from ./embeddings/final_embeddings.pkl
Loaded 49693 embeddings of dimension 128

1. Analyzing embedding space...


  warn(


   Embedding space visualization saved

2. Analyzing feature patterns...


Computing SHAP values:   0%|                                                                   | 0/100 [00:00<?, ?it/s]

Error analyzing feature patterns: 'tuple' object has no attribute 'x'

Analysis complete! Results saved in 'molecular_analysis' directory



