In [1]:
import os
import json
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from rdkit import Chem
from rdkit.Chem import Descriptors as rdDescriptors
from rdkit.Chem.Scaffolds import MurckoScaffold
from sklearn.manifold import TSNE
from collections import Counter
from sklearn.metrics.pairwise import euclidean_distances
from PIL import Image
from io import BytesIO
import math
from datetime import datetime

# Helper functions: loading embeddings and SMILES, computing molecular properties, and plotting
def load_latest_embeddings(checkpoint_dir):
    """Load the embeddings stored in the latest ``embeddings_*.pt`` file.

    "Latest" is decided by file name. Names are compared in natural order,
    i.e. runs of digits compare as integers, so ``embeddings_10.pt`` is
    considered newer than ``embeddings_9.pt`` (a plain string sort would
    pick ``embeddings_9.pt``).

    Args:
        checkpoint_dir: Directory that is scanned (non-recursively) for files
            whose names start with ``embeddings_`` and end with ``.pt``.

    Returns:
        ``data[0]`` of the object loaded from the chosen file, moved to the
        CPU and converted to a NumPy array.

    Raises:
        FileNotFoundError: If no matching file exists in ``checkpoint_dir``.
    """
    import re  # local import: the file's import header does not provide it

    def natural_key(name):
        # re.split with a capturing group alternates text, digits, text, ...
        # so odd positions are always digit runs and can be compared as ints.
        parts = re.split(r'([0-9]+)', name)
        key = [int(part) if pos % 2 else part for pos, part in enumerate(parts)]
        # The raw name breaks ties such as "embeddings_01.pt" vs "embeddings_1.pt".
        return key, name

    embedding_files = [
        f for f in os.listdir(checkpoint_dir)
        if f.startswith('embeddings_') and f.endswith('.pt')
    ]
    if not embedding_files:
        raise FileNotFoundError(f"No embeddings file found in {checkpoint_dir}")

    latest_embeddings_file = max(embedding_files, key=natural_key)
    embeddings_path = os.path.join(checkpoint_dir, latest_embeddings_file)

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from trusted sources; if the file
    # is known to hold plain tensors, consider passing weights_only=True.
    data = torch.load(embeddings_path, map_location='cpu')
    embeddings = data[0].cpu().numpy()
    return embeddings

def load_smiles(data_path, max_mols=500):
    """Read up to ``max_mols`` SMILES strings from a text file, one per line.

    Lines are stripped of surrounding whitespace and blank lines are skipped;
    skipped lines do not count towards the limit.

    Args:
        data_path: Path of the text file to read.
        max_mols: Maximum number of SMILES to return. ``0`` yields an empty
            table; ``None`` reads the whole file.

    Returns:
        A ``pandas.DataFrame`` with a single ``'smiles'`` column, in file order.
    """
    smiles_list = []
    with open(data_path, 'r') as f:
        for line in f:
            # Check the limit before appending so that max_mols=0 really
            # returns nothing instead of the first molecule.
            if max_mols is not None and len(smiles_list) >= max_mols:
                break
            s = line.strip()
            if s:
                smiles_list.append(s)
    return pd.DataFrame({'smiles': smiles_list})

def compute_mol_properties(mol):
    """Return ``(logp, mol_weight, tpsa, complexity)`` for an RDKit molecule.

    The first three values come straight from ``rdDescriptors.MolLogP``,
    ``rdDescriptors.ExactMolWt`` and ``rdDescriptors.TPSA``. ``complexity`` is
    the number of rings reported by ``Chem.GetSymmSSSR`` plus the number of
    atoms that have more than two neighbours.
    """
    n_rings = len(Chem.GetSymmSSSR(mol))

    # Count branching atoms, i.e. atoms bonded to three or more neighbours.
    n_branch_atoms = 0
    for atom in mol.GetAtoms():
        if len(atom.GetNeighbors()) > 2:
            n_branch_atoms += 1

    return (
        rdDescriptors.MolLogP(mol),
        rdDescriptors.ExactMolWt(mol),
        rdDescriptors.TPSA(mol),
        n_rings + n_branch_atoms,
    )

def compute_scaffold(smiles):
    """Return the Murcko scaffold of ``smiles`` as a SMILES string.

    Returns ``None`` when ``smiles`` cannot be parsed by RDKit or when the
    scaffold object returned by ``MurckoScaffold.GetScaffoldForMol`` is falsy.
    """
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is None:
        return None

    core = MurckoScaffold.GetScaffoldForMol(parsed)
    if not core:
        return None
    return Chem.MolToSmiles(core)

def prepare_data(manual_embeddings, gan_embeddings, mol_info):
    """Align the inputs to a common length and annotate each molecule.

    All three inputs are truncated to the length of the shortest one. The
    returned table gains the columns ``LogP``, ``MolWeight``, ``TPSA``,
    ``Complexity`` and ``Scaffold``. Rows whose SMILES cannot be parsed get
    ``NaN`` for the numeric columns and the string ``'None'`` as scaffold; a
    missing or empty scaffold is also stored as ``'None'``.

    Returns:
        ``(mol_info, manual_embeddings, gan_embeddings)`` after truncation.
    """
    property_names = ('LogP', 'MolWeight', 'TPSA', 'Complexity')

    # Truncate everything to the shortest input so rows stay aligned.
    n_common = min(len(mol_info), manual_embeddings.shape[0], gan_embeddings.shape[0])
    mol_info = mol_info.iloc[:n_common].reset_index(drop=True)
    manual_embeddings = manual_embeddings[:n_common]
    gan_embeddings = gan_embeddings[:n_common]

    columns = {name: [] for name in property_names}
    columns['Scaffold'] = []

    for smi in mol_info['smiles']:
        mol = Chem.MolFromSmiles(smi)
        if mol:
            values = compute_mol_properties(mol)
            scaffold = compute_scaffold(smi) or 'None'
        else:
            # Unparsable SMILES: keep the row but mark every property as missing.
            values = (np.nan,) * len(property_names)
            scaffold = 'None'

        for name, value in zip(property_names, values):
            columns[name].append(value)
        columns['Scaffold'].append(scaffold)

    for name, column_values in columns.items():
        mol_info[name] = column_values

    return mol_info, manual_embeddings, gan_embeddings

def mol_to_img(mol):
    """Render an RDKit molecule as a 200x200 PIL image; ``None`` maps to ``None``."""
    from rdkit.Chem import Draw

    if mol is not None:
        drawer = Draw.MolDraw2DCairo(200, 200)
        drawer.DrawMolecule(mol)
        drawer.FinishDrawing()
        # The drawer hands back the encoded image bytes; let PIL decode them.
        return Image.open(BytesIO(drawer.GetDrawingText()))
    return None

def _nearest_neighbor_indices(distance_row, query_idx, k):
    """Return the indices of the ``k`` smallest distances, never ``query_idx`` itself."""
    query_idx = query_idx % len(distance_row)  # normalise negative indices
    order = np.argsort(distance_row, kind='stable')
    # Drop the query explicitly: with duplicate embeddings another molecule can
    # tie at distance 0, so "skip the first entry" is not a reliable self-filter.
    return order[order != query_idx][:k]


def _draw_smiles(ax, smiles):
    """Draw the molecule for ``smiles`` on ``ax`` (or a placeholder) and hide the axis."""
    img = mol_to_img(Chem.MolFromSmiles(smiles))
    if img is not None:
        ax.imshow(img)
    else:
        # mol_to_img returns None for unparsable SMILES; imshow(None) would raise.
        ax.text(0.5, 0.5, "invalid SMILES", ha='center', va='center',
                transform=ax.transAxes)
    ax.axis('off')


def improved_plot_molecular_neighbors(mol_info, manual_embeddings, gan_embeddings, save_dir, example_indices=(0, 10, 20), k=3):
    """
    Visualize differences between manual and GAN embeddings.

    One figure row is drawn per entry of ``example_indices``: the query
    molecule sits in the centre (titled with its SMILES string), its ``k``
    nearest neighbours in the manual embedding space on the left and its ``k``
    nearest neighbours in the GAN embedding space on the right. Neighbours are
    ranked by Euclidean distance and carry no titles. If fewer than ``k``
    neighbours exist, the remaining panels stay blank.

    Args:
        mol_info: Table with a ``'smiles'`` column, row-aligned with both
            embedding arrays.
        manual_embeddings: Embedding matrix for the manual augmentation.
        gan_embeddings: Embedding matrix for the GAN augmentation.
        save_dir: Existing directory the PNG is written to.
        example_indices: Positional row indices of the query molecules.
        k: Number of neighbours shown per side.

    Returns:
        The file name (not the full path) of the saved PNG, which embeds a
        timestamp.

    Raises:
        ValueError: If ``example_indices`` is empty.
    """
    example_indices = list(example_indices)
    if not example_indices:
        raise ValueError("example_indices must contain at least one molecule index")

    # Pairwise distances within each embedding space
    dist_manual = euclidean_distances(manual_embeddings, manual_embeddings)
    dist_gan = euclidean_distances(gan_embeddings, gan_embeddings)

    # Generate a timestamp for the filename
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    num_rows = len(example_indices)
    cols_per_row = 2*k + 1  # Manual neighbors + Center + GAN neighbors

    # squeeze=False keeps the axes array 2D even for a single row or column.
    fig, all_axes = plt.subplots(num_rows, cols_per_row,
                                 figsize=(16, 3.5*num_rows),  # Reduced height to decrease row spacing
                                 squeeze=False,
                                 gridspec_kw={'width_ratios': [1]*k + [1.5] + [1]*k})
    try:
        # Hide every frame up front so panels without a neighbour stay blank.
        for ax in all_axes.flat:
            ax.axis('off')

        for axes, mol_idx in zip(all_axes, example_indices):
            chosen_smi = mol_info['smiles'].iloc[mol_idx]

            neighbors_manual_idx = _nearest_neighbor_indices(dist_manual[mol_idx], mol_idx, k)
            neighbors_gan_idx = _nearest_neighbor_indices(dist_gan[mol_idx], mol_idx, k)

            # Manual neighbors (left side)
            for ax, n_idx in zip(axes[:k], neighbors_manual_idx):
                _draw_smiles(ax, mol_info['smiles'].iloc[n_idx])

            # Query molecule (center), labelled with its full SMILES string
            ax_center = axes[k]
            _draw_smiles(ax_center, chosen_smi)
            ax_center.set_title(f"{chosen_smi}", fontsize=10, fontweight='bold', wrap=True)

            # Vertical lines to separate the centre from both neighbour sections
            ax_center.axvline(x=-10, color='black', linestyle='-', alpha=0.5)
            ax_center.axvline(x=210, color='black', linestyle='-', alpha=0.5)

            # GAN neighbors (right side)
            for ax, n_idx in zip(axes[k + 1:], neighbors_gan_idx):
                _draw_smiles(ax, mol_info['smiles'].iloc[n_idx])

        # Section labels above the left and right halves
        fig.text(0.25, 0.98, "MANUAL AUGMENTATION NEIGHBORS",
                 ha="center", fontsize=14, fontweight='bold')
        fig.text(0.75, 0.98, "GAN AUGMENTATION NEIGHBORS",
                 ha="center", fontsize=14, fontweight='bold')

        fig.tight_layout(rect=[0, 0, 1, 0.97])
        fig.subplots_adjust(hspace=0.3)  # Reduce vertical space between rows

        filename = f"improved_molecular_neighbors_{timestamp}.png"
        fig.savefig(os.path.join(save_dir, filename), dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

    print(f"Saved improved neighbor comparison as {filename}")
    return filename

def main(
    manual_checkpoint_dir="D:/PhD/Chapter3.1_GAN_Explainbility/GAN_CL_XAI/Manual/checkpoints",
    gan_checkpoint_dir="D:/PhD/Chapter3.1_GAN_Explainbility/GAN_CL_XAI/GAN/checkpoints",
    data_path="D:/PhD/Chapter3/Unsupervised_GAN_Code/pubchem-10m-clean_test.txt",
    save_dir="D:/PhD/Chapter3.1_GAN_Explainbility/GAN_CL_XAI/analysis_results",
    max_mols=300,
    k=3,
):
    """Generate the manual-vs-GAN nearest-neighbour comparison figure.

    Loads the latest embeddings from both checkpoint directories, reads up to
    ``max_mols`` SMILES from ``data_path``, annotates them and saves the
    neighbour plot into ``save_dir`` (created if missing). Calling ``main()``
    without arguments uses the original project paths.

    Args:
        manual_checkpoint_dir: Directory holding the manual-augmentation embeddings.
        gan_checkpoint_dir: Directory holding the GAN-augmentation embeddings.
        data_path: Text file with one SMILES string per line.
        save_dir: Output directory for the figure.
        max_mols: Maximum number of molecules read from ``data_path``.
        k: Number of neighbours shown on each side of the query molecule.

    Raises:
        ValueError: If no molecule is left after aligning SMILES and embeddings.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Load embeddings
    manual_embeddings = load_latest_embeddings(manual_checkpoint_dir)
    gan_embeddings = load_latest_embeddings(gan_checkpoint_dir)

    # Load molecular data from file
    mol_info = load_smiles(data_path, max_mols=max_mols)

    # Prepare data
    mol_info, manual_embeddings, gan_embeddings = prepare_data(manual_embeddings, gan_embeddings, mol_info)

    # Fail early with a clear message instead of an index error inside plotting.
    if mol_info.empty:
        raise ValueError(f"No molecules available for plotting (data file: {data_path})")

    # Visualize molecular neighbors with improved formatting
    example_indices = [0, 10, 20] if len(mol_info) > 20 else [0]
    improved_plot_molecular_neighbors(mol_info, manual_embeddings, gan_embeddings, save_dir,
                                      example_indices=example_indices, k=k)

    print("Improved molecular visualization generated successfully.")

# Entry point: run the analysis only when this code is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

  data = torch.load(embeddings_path, map_location='cpu')


Saved improved neighbor comparison as improved_molecular_neighbors_20250310_211932.png
Improved molecular visualization generated successfully.
