In [1]:
from rdkit import Chem
from rdkit.Chem import AllChem, Draw, Descriptors
import numpy as np
import shap
import lime
import torch
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt
from rdkit.Chem.Draw import IPythonConsole

import torch
import torch.nn as nn
import numpy as np
import shap
import lime
import lime.lime_tabular
from rdkit import Chem
from rdkit.Chem import Draw
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple, Optional
import pickle
import os
from torch_geometric.data import Data
from tqdm import tqdm
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing

class GraphDiscriminator(nn.Module):
    """Reimplementation of original discriminator architecture"""
    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128, output_dim: int = 128):
        super().__init__()
        
        # Feature encoding
        self.node_encoder = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Graph convolution layers
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, output_dim)
        
        # Projection head
        self.projection = nn.Sequential(
            nn.Linear(output_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
        
    def forward(self, data):
        x = torch.cat([data.x_cat.float(), data.x_phys], dim=-1)
        edge_index = data.edge_index
        edge_attr = data.edge_attr.float()
        batch = data.batch
        
        # Initial feature encoding
        x = self.node_encoder(x)
        edge_attr = self.edge_encoder(edge_attr)
        
        # Graph convolutions
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        
        # Global pooling
        x = global_mean_pool(x, batch)
        
        # Projection
        x = self.projection(x)
        
        return x

def debug_data(embedding_path: str):
    """Debug the saved data structure"""
    with open(embedding_path, 'rb') as f:
        data = pickle.load(f)
        print("\nData structure:")
        print("Keys:", data.keys())
        print("\nEmbeddings shape:", data['embeddings'].shape)
        print("\nFirst graph data:")
        print(data['labels'][0])
        return data

class ChemicalXAIAnalyzer:
    """Chemical interpretability analyzer for molecular embeddings"""
    
    def __init__(self, embedding_path: str):
        """Initialize analyzer with saved embeddings and molecular data"""
        print("Loading data...")
        
        # Load and debug data
        self.data = debug_data(embedding_path)
        self.embeddings = self.data['embeddings']
        self.graphs = self.data['labels']
        
        # Create output directories
        os.makedirs('chemical_analysis', exist_ok=True)
        
        # Initialize chemical feature extractors
        self.initialize_feature_extractors()
        
    def initialize_feature_extractors(self):
        """Initialize chemical feature extraction"""
        # Define important chemical features
        self.atom_features = [
            'Atomic_num',
            'Formal_charge',
            'Hybridization',
            'Aromatic',
            'Num_Hs',
            'Valence'
        ]
        
        self.bond_features = [
            'Bond_type',
            'Conjugated',
            'In_ring'
        ]
        
        # Define functional groups to track
        self.functional_groups = {
            'Alcohol': '[OH]',
            'Amine': '[NH2]',
            'Carboxyl': '[COOH]',
            'Carbonyl': '[C=O]',
            'Aromatic': 'a',
            'Halogen': '[F,Cl,Br,I]'
        }
        
    def _extract_graph_features(self, graph_data) -> Dict:
        """Extract features from graph data"""
        features = {}
        
        # Process the graph data which is a tuple of component tuples
        for item in graph_data:
            if isinstance(item, tuple):
                key, value = item
                features[key] = value
                
        return features
        
    def analyze_molecular_features(self, molecule_idx: int) -> Dict:
        """Analyze chemical features of a molecule"""
        # Get molecular data
        graph_data = self.graphs[molecule_idx]
        features = self._extract_graph_features(graph_data)
        
        # Get embedding
        embedding = self.embeddings[molecule_idx]
        
        # Extract atom features
        atom_features = []
        if 'x_cat' in features and 'x_phys' in features:
            x_cat = features['x_cat']
            x_phys = features['x_phys']
            
            for i in range(x_cat.shape[0]):
                atom_feat = {
                    'Categorical': x_cat[i].numpy(),
                    'Physical': x_phys[i].numpy()
                }
                atom_features.append(atom_feat)
                
        # Extract bond features
        bond_features = []
        if 'edge_index' in features and 'edge_attr' in features:
            edge_index = features['edge_index']
            edge_attr = features['edge_attr']
            
            for i in range(edge_attr.shape[0]):
                bond_feat = {
                    'atoms': (edge_index[0, i].item(), edge_index[1, i].item()),
                    'attributes': edge_attr[i].numpy()
                }
                bond_features.append(bond_feat)
                
        # Analyze feature importance
        importance = self._analyze_importance(
            atom_features,
            bond_features,
            embedding
        )
        
        return {
            'atom_features': atom_features,
            'bond_features': bond_features,
            'importance': importance
        }
        
    def _analyze_importance(self, atom_features: List, 
                          bond_features: List,
                          embedding: np.ndarray) -> Dict:
        """Analyze feature importance using SHAP"""
        # Convert features to matrix form
        atom_matrix = np.array([
            np.concatenate([f['Categorical'], f['Physical']])
            for f in atom_features
        ])
        
        # Get background data
        n_background = min(100, len(self.embeddings))
        background_indices = np.random.choice(
            len(self.embeddings), 
            n_background, 
            replace=False
        )
        background = self.embeddings[background_indices]
        
        # Create feature explanation
        importance = {
            'atoms': {},
            'bonds': {}
        }
        
        # Analyze atom importance
        for i, atom in enumerate(atom_features):
            feat_vec = np.concatenate([atom['Categorical'], atom['Physical']])
            importance['atoms'][i] = np.abs(feat_vec).mean()
            
        # Analyze bond importance
        for i, bond in enumerate(bond_features):
            importance['bonds'][i] = np.abs(bond['attributes']).mean()
            
        return importance
        
    def visualize_importance(self, results: Dict, molecule_idx: int):
        """Create visualizations of feature importance"""
        # 1. Atom importance plot
        plt.figure(figsize=(10, 6))
        atoms = list(results['importance']['atoms'].keys())
        atom_imp = list(results['importance']['atoms'].values())
        
        plt.bar(atoms, atom_imp)
        plt.title('Atom Importance')
        plt.xlabel('Atom Index')
        plt.ylabel('Importance Score')
        plt.tight_layout()
        plt.savefig(f'chemical_analysis/molecule_{molecule_idx}_atom_importance.png')
        plt.close()
        
        # 2. Bond importance plot
        plt.figure(figsize=(10, 6))
        bonds = list(results['importance']['bonds'].keys())
        bond_imp = list(results['importance']['bonds'].values())
        
        plt.bar(bonds, bond_imp)
        plt.title('Bond Importance')
        plt.xlabel('Bond Index')
        plt.ylabel('Importance Score')
        plt.tight_layout()
        plt.savefig(f'chemical_analysis/molecule_{molecule_idx}_bond_importance.png')
        plt.close()

def main():
    """Run chemical interpretability analysis"""
    try:
        analyzer = ChemicalXAIAnalyzer(
            embedding_path='./embeddings/final_embeddings_20250216_111005.pkl'
        )
        
        # Analyze first few molecules
        results = {}
        for idx in range(3):
            print(f"\nAnalyzing molecule {idx}...")
            results[idx] = analyzer.analyze_molecular_features(idx)
            
            # Create visualizations
            analyzer.visualize_importance(results[idx], idx)
            
            # Print summary
            print("\nFeature Importance Summary:")
            print("\nTop 5 Important Atoms:")
            atom_imp = results[idx]['importance']['atoms']
            for atom_idx, imp in sorted(
                atom_imp.items(), 
                key=lambda x: x[1], 
                reverse=True
            )[:5]:
                atom_feat = results[idx]['atom_features'][atom_idx]
                print(f"Atom {atom_idx}: {imp:.4f}")
                print(f"  Categorical features: {atom_feat['Categorical']}")
                print(f"  Physical features: {atom_feat['Physical']}")
                
            print("\nTop 5 Important Bonds:")
            bond_imp = results[idx]['importance']['bonds']
            for bond_idx, imp in sorted(
                bond_imp.items(), 
                key=lambda x: x[1], 
                reverse=True
            )[:5]:
                bond_feat = results[idx]['bond_features'][bond_idx]
                print(f"Bond {bond_idx}: {imp:.4f}")
                print(f"  Between atoms: {bond_feat['atoms']}")
                print(f"  Attributes: {bond_feat['attributes']}")
                
        print("\nAnalysis complete. Results saved in chemical_analysis/")
        return results
        
    except Exception as e:
        print(f"Error during analysis: {str(e)}")
        import traceback
        traceback.print_exc()
        return None

if __name__ == "__main__":
    results = main()



Loading data...

Data structure:
Keys: dict_keys(['embeddings', 'labels'])

Embeddings shape: (41, 128)

First graph data:
('edge_index', tensor([[   0,    1,    1,  ..., 1418, 1394, 1419],
        [   1,    0,    2,  ..., 1394, 1419, 1394]]))

Analyzing molecule 0...

Feature Importance Summary:

Top 5 Important Atoms:

Top 5 Important Bonds:

Analyzing molecule 1...

Feature Importance Summary:

Top 5 Important Atoms:

Top 5 Important Bonds:

Analyzing molecule 2...

Feature Importance Summary:

Top 5 Important Atoms:

Top 5 Important Bonds:

Analysis complete. Results saved in chemical_analysis/
