In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from rdkit import Chem
from rdkit.Chem import AllChem, DataStructs, Draw
from rdkit.Chem.Draw import MolDraw2DCairo
from sklearn.metrics.pairwise import cosine_similarity
import pickle
import os
import random
from tqdm import tqdm
import io
from PIL import Image
import logging
import hashlib

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def compute_ecfp(mol, radius=2, nBits=2048):
    """Compute ECFP (Morgan) fingerprint for a molecule"""
    if mol is None:
        return None
    return AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)

def compute_rdkfp(mol, minPath=1, maxPath=7, nBits=2048):
    """Compute RDKit topological fingerprint"""
    if mol is None:
        return None
    return AllChem.RDKFingerprint(mol, minPath=minPath, maxPath=maxPath, fpSize=nBits)

def calculate_fingerprint_similarity(mol1, mol2, fp_type="ecfp"):
    """Calculate fingerprint similarity between two molecules"""
    if mol1 is None or mol2 is None:
        return 0.0
    
    try:
        if fp_type.lower() == "ecfp":
            fp1 = compute_ecfp(mol1)
            fp2 = compute_ecfp(mol2)
        elif fp_type.lower() == "rdkfp":
            fp1 = compute_rdkfp(mol1)
            fp2 = compute_rdkfp(mol2)
        else:
            raise ValueError(f"Unsupported fingerprint type: {fp_type}")
        
        if fp1 is None or fp2 is None:
            return 0.0
            
        return DataStructs.TanimotoSimilarity(fp1, fp2)
    except Exception as e:
        logger.error(f"Error calculating fingerprint similarity: {e}")
        return 0.0

def mol_to_img(mol, molSize=(300, 200), kekulize=True):
    """Convert RDKit molecule to image"""
    if mol is None:
        return None
    
    try:
        # Prepare the molecule for drawing
        mc = Chem.Mol(mol.ToBinary())
        if kekulize:
            try:
                Chem.Kekulize(mc)
            except:
                mc = Chem.Mol(mol.ToBinary())
        
        # Use Draw.MolToImage as a more compatible alternative
        image = Draw.MolToImage(mc, size=molSize)
        return image
    except Exception as e:
        logger.error(f"Error converting molecule to image: {e}")
        try:
            # Fallback to simpler drawing method
            image = Draw.MolToImage(mol, size=molSize)
            return image
        except:
            return None

def load_embeddings(filepath):
    """Load embeddings from pickle file"""
    try:
        with open(filepath, 'rb') as f:
            data = pickle.load(f)
        return data
    except Exception as e:
        logger.error(f"Error loading embeddings from {filepath}: {e}")
        return None

def extract_smiles_from_data(data_obj, field_name='molecule_data'):
    """Extract SMILES strings from embedding data object"""
    smiles_list = []
    
    if field_name in data_obj and isinstance(data_obj[field_name], list):
        for mol_data in data_obj[field_name]:
            if isinstance(mol_data, dict) and 'smiles' in mol_data:
                smiles_list.append(mol_data['smiles'])
            elif hasattr(mol_data, 'smiles'):
                smiles_list.append(mol_data.smiles)
    
    return smiles_list

def find_nearest_neighbors(query_idx, embeddings, num_neighbors=6, exclude_self=True):
    """Find nearest neighbors for a query molecule in embedding space"""
    query_embedding = embeddings[query_idx].reshape(1, -1)
    
    # Calculate cosine similarities
    similarities = cosine_similarity(query_embedding, embeddings)[0]
    
    # Get indices of nearest neighbors
    if exclude_self:
        # Exclude the query molecule itself (which would be the closest with similarity=1)
        neighbor_indices = np.argsort(similarities)[::-1][1:num_neighbors+1]
    else:
        neighbor_indices = np.argsort(similarities)[::-1][:num_neighbors]
    
    return neighbor_indices, similarities[neighbor_indices]

def read_smiles_file(file_path):
    """Read SMILES strings from a file"""
    smiles_list = []
    with open(file_path, 'r') as f:
        for line in f:
            smiles = line.strip()
            if smiles:
                smiles_list.append(smiles)
    return smiles_list

def create_comparison_visualization(query_smiles, manual_embeddings_path, gan_embeddings_path, 
                                   manual_smiles_list=None, gan_smiles_list=None,
                                   manual_embeddings=None, gan_embeddings=None,
                                   num_neighbors=6, output_dir='./comparisons'):
    """Create visualization comparing nearest neighbors in manual and GAN embedding spaces"""
    os.makedirs(output_dir, exist_ok=True)
    
    # Load embeddings if not provided
    if manual_embeddings is None or gan_embeddings is None or manual_smiles_list is None or gan_smiles_list is None:
        logger.info(f"Loading manual embeddings from {manual_embeddings_path}")
        manual_data = load_embeddings(manual_embeddings_path)
        
        logger.info(f"Loading GAN embeddings from {gan_embeddings_path}")
        gan_data = load_embeddings(gan_embeddings_path)
        
        if manual_data is None or gan_data is None:
            logger.error("Failed to load embeddings. Exiting.")
            return
        
        # Extract embeddings
        manual_embeddings = manual_data['embeddings']
        gan_embeddings = gan_data['embeddings']
        
        # Extract SMILES
        manual_smiles_list = extract_smiles_from_data(manual_data)
        gan_smiles_list = extract_smiles_from_data(gan_data)
    
    # Get indices of query molecule in both datasets
    if query_smiles not in manual_smiles_list:
        logger.error(f"Query SMILES '{query_smiles}' not found in manual embeddings.")
        return None
    
    if query_smiles not in gan_smiles_list:
        logger.error(f"Query SMILES '{query_smiles}' not found in GAN embeddings.")
        return None
        
    manual_query_idx = manual_smiles_list.index(query_smiles)
    gan_query_idx = gan_smiles_list.index(query_smiles)
    
    # Convert query SMILES to RDKit molecule
    query_mol = Chem.MolFromSmiles(query_smiles)
    if query_mol is None:
        logger.error(f"Failed to convert query SMILES {query_smiles} to molecule.")
        return None
    
    # Find nearest neighbors in both embedding spaces
    manual_nn_indices, manual_similarities = find_nearest_neighbors(
        manual_query_idx, manual_embeddings, num_neighbors
    )
    
    gan_nn_indices, gan_similarities = find_nearest_neighbors(
        gan_query_idx, gan_embeddings, num_neighbors
    )
    
    # Convert all SMILES to molecules
    manual_mols = [Chem.MolFromSmiles(smiles) for smiles in manual_smiles_list]
    gan_mols = [Chem.MolFromSmiles(smiles) for smiles in gan_smiles_list]
    
    # Calculate fingerprint similarities
    manual_ecfp_sims = [calculate_fingerprint_similarity(query_mol, manual_mols[idx], "ecfp") 
                      for idx in manual_nn_indices]
    manual_rdkfp_sims = [calculate_fingerprint_similarity(query_mol, manual_mols[idx], "rdkfp") 
                       for idx in manual_nn_indices]
    
    gan_ecfp_sims = [calculate_fingerprint_similarity(query_mol, gan_mols[idx], "ecfp") 
                   for idx in gan_nn_indices]
    gan_rdkfp_sims = [calculate_fingerprint_similarity(query_mol, gan_mols[idx], "rdkfp") 
                    for idx in gan_nn_indices]
    
    # Create visualization
    fig = plt.figure(figsize=(15, 10))
    plt.suptitle(f"Comparison of Nearest Neighbors (Manual vs. GAN Embeddings)", fontsize=16)
    
    # Set up grid layout with 3 rows (1 for query, 2 for neighbors)
    rows = 3
    cols = 4
    gs = GridSpec(rows, cols, figure=fig)
    
    # Plot query molecule at the top center
    ax_query = fig.add_subplot(gs[0, 1:3])
    query_img = mol_to_img(query_mol)
    if query_img:
        ax_query.imshow(query_img)
    ax_query.set_title(f"Query molecule\n{query_smiles[:20]}...", fontsize=12)
    ax_query.axis('off')
    
    # Add "Manual side" and "GAN side" labels
    ax_manual = fig.add_subplot(gs[1, 0])
    ax_manual.text(0.5, 0.5, "Manual\nAugmentation", 
                 ha='center', va='center', fontsize=12, fontweight='bold')
    ax_manual.axis('off')
    
    ax_gan = fig.add_subplot(gs[1, 3])
    ax_gan.text(0.5, 0.5, "GAN\nAugmentation", 
              ha='center', va='center', fontsize=12, fontweight='bold')
    ax_gan.axis('off')
    
    # Calculate grid positions for neighbors (3 rows, 3 columns for each side)
    num_per_row = 3
    
    # Plot manual nearest neighbors
    for i, idx in enumerate(manual_nn_indices):
        row = (i // num_per_row) + 1
        col = (i % num_per_row)
        ax = fig.add_subplot(gs[row, col])
        mol = manual_mols[idx]
        if mol is not None:
            mol_img = mol_to_img(mol)
            if mol_img:
                ax.imshow(mol_img)
            ax.set_title(f"Nbr{i+1}: {idx}\nRDKFP: {manual_rdkfp_sims[i]:.3f}\nECFP: {manual_ecfp_sims[i]:.3f}")
        else:
            ax.text(0.5, 0.5, "Invalid Molecule", ha="center", va="center")
        ax.axis('off')
    
    # Plot GAN nearest neighbors
    for i, idx in enumerate(gan_nn_indices):
        row = (i // num_per_row) + 1
        col = (i % num_per_row) + 1  # Offset by 1 to put in right half
        ax = fig.add_subplot(gs[row, col])
        mol = gan_mols[idx]
        if mol is not None:
            mol_img = mol_to_img(mol)
            if mol_img:
                ax.imshow(mol_img)
            ax.set_title(f"Nbr{i+1}: {idx}\nRDKFP: {gan_rdkfp_sims[i]:.3f}\nECFP: {gan_ecfp_sims[i]:.3f}")
        else:
            ax.text(0.5, 0.5, "Invalid Molecule", ha="center", va="center")
        ax.axis('off')
    
    # Adjust layout and save
    plt.tight_layout()
    plt.subplots_adjust(top=0.9)
    
    # Generate a unique filename using hash
    filename = f"comparison_{hashlib.md5(query_smiles.encode()).hexdigest()[:10]}.png"
    output_path = os.path.join(output_dir, filename)
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()
    
    logger.info(f"Comparison visualization saved to {output_path}")
    return output_path

def main():
    # Define file paths
    manual_embeddings_path = './embeddings/manual_embeddings_molecules_20250310_110848.pkl'
    gan_embeddings_path = './embeddings/final_embeddings_molecules_20250309_110249.pkl'
    output_dir = './molecule_comparisons'
    
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    
    # Load embeddings
    logger.info("Loading manual embeddings...")
    manual_data = load_embeddings(manual_embeddings_path)
    
    logger.info("Loading GAN embeddings...")
    gan_data = load_embeddings(gan_embeddings_path)
    
    # Extract SMILES
    manual_smiles_list = extract_smiles_from_data(manual_data)
    gan_smiles_list = extract_smiles_from_data(gan_data)
    
    # Find common SMILES between the two datasets
    common_smiles = list(set(manual_smiles_list) & set(gan_smiles_list))
    logger.info(f"Found {len(common_smiles)} molecules in common between manual and GAN embeddings")
    
    if not common_smiles:
        logger.error("No common molecules found. Cannot proceed.")
        return
    
    # Select 2 random molecules from common SMILES
    num_molecules = min(2, len(common_smiles))
    query_smiles_list = random.sample(common_smiles, num_molecules)
    
    # Extract embeddings
    manual_embeddings = manual_data['embeddings']
    gan_embeddings = gan_data['embeddings']
    
    for i, query_smiles in enumerate(query_smiles_list):
        logger.info(f"Processing query {i+1}/{num_molecules}: {query_smiles}")
        output_path = create_comparison_visualization(
            query_smiles, 
            manual_embeddings_path, 
            gan_embeddings_path,
            manual_smiles_list=manual_smiles_list,
            gan_smiles_list=gan_smiles_list,
            manual_embeddings=manual_embeddings,
            gan_embeddings=gan_embeddings,
            num_neighbors=6, 
            output_dir=output_dir
        )
        if output_path:
            logger.info(f"Visualization saved to {output_path}")
        else:
            logger.error(f"Failed to create visualization for {query_smiles}")

if __name__ == "__main__":
    main()

2025-03-10 11:56:36,035 - __main__ - INFO - Loading manual embeddings...
2025-03-10 11:56:36,044 - __main__ - INFO - Loading GAN embeddings...
2025-03-10 11:56:39,903 - __main__ - INFO - Found 0 molecules in common between manual and GAN embeddings
2025-03-10 11:56:39,906 - __main__ - ERROR - No common molecules found. Cannot proceed.
