In [3]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from rdkit import Chem
from rdkit.Chem import AllChem, DataStructs, Draw
from sklearn.metrics.pairwise import cosine_similarity
import pickle
import os
import random
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def compute_ecfp(mol, radius=2, nBits=2048):
    """Compute ECFP (Morgan) fingerprint for a molecule"""
    if mol is None:
        return None
    return AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)

def compute_rdkfp(mol, minPath=1, maxPath=7, nBits=2048):
    """Compute RDKit topological fingerprint"""
    if mol is None:
        return None
    return AllChem.RDKFingerprint(mol, minPath=minPath, maxPath=maxPath, fpSize=nBits)

def calculate_fingerprint_similarity(mol1, mol2, fp_type="ecfp"):
    """Calculate fingerprint similarity between two molecules"""
    if mol1 is None or mol2 is None:
        return 0.0
    
    try:
        if fp_type.lower() == "ecfp":
            fp1 = compute_ecfp(mol1)
            fp2 = compute_ecfp(mol2)
        elif fp_type.lower() == "rdkfp":
            fp1 = compute_rdkfp(mol1)
            fp2 = compute_rdkfp(mol2)
        else:
            raise ValueError(f"Unsupported fingerprint type: {fp_type}")
        
        if fp1 is None or fp2 is None:
            return 0.0
            
        return DataStructs.TanimotoSimilarity(fp1, fp2)
    except Exception as e:
        logger.error(f"Error calculating fingerprint similarity: {e}")
        return 0.0

def mol_to_img(mol, molSize=(300, 200)):
    """Convert RDKit molecule to image using a simple approach"""
    if mol is None:
        return None
    
    try:
        # Use the simpler Draw.MolToImage function which is more reliable
        image = Draw.MolToImage(mol, size=molSize)
        return image
    except Exception as e:
        logger.error(f"Error converting molecule to image: {e}")
        return None

def load_embeddings(filepath):
    """Load embeddings from pickle file"""
    try:
        with open(filepath, 'rb') as f:
            data = pickle.load(f)
        return data
    except Exception as e:
        logger.error(f"Error loading embeddings from {filepath}: {e}")
        return None

def find_nearest_neighbors(query_idx, embeddings, num_neighbors=9, exclude_self=True):
    """Find nearest neighbors for a query molecule in embedding space"""
    query_embedding = embeddings[query_idx].reshape(1, -1)
    
    # Calculate cosine similarities
    similarities = cosine_similarity(query_embedding, embeddings)[0]
    
    # Get indices of nearest neighbors
    if exclude_self:
        # Exclude the query molecule itself (which would be the closest with similarity=1)
        neighbor_indices = np.argsort(similarities)[::-1][1:num_neighbors+1]
    else:
        neighbor_indices = np.argsort(similarities)[::-1][:num_neighbors]
    
    return neighbor_indices, similarities[neighbor_indices]

def create_gan_visualization(query_idx, gan_embeddings_path, num_neighbors=9, output_dir='./visualizations'):
    """Create visualization of nearest neighbors in GAN embedding space for a given query index"""
    os.makedirs(output_dir, exist_ok=True)
    
    # Load embeddings
    logger.info(f"Loading GAN embeddings from {gan_embeddings_path}")
    gan_data = load_embeddings(gan_embeddings_path)
    
    if gan_data is None:
        logger.error("Failed to load embeddings. Exiting.")
        return
    
    # Extract embeddings
    gan_embeddings = gan_data['embeddings']
    
    # Get SMILES from molecule_data
    gan_smiles_list = []
    if 'molecule_data' in gan_data:
        for mol_data in gan_data['molecule_data']:
            if isinstance(mol_data, dict) and 'smiles' in mol_data:
                gan_smiles_list.append(mol_data['smiles'])
    
    # Check if SMILES were found
    if not gan_smiles_list:
        # Try to get SMILES from other fields
        if hasattr(gan_data, 'smiles_list'):
            gan_smiles_list = gan_data.smiles_list
        elif 'smiles' in gan_data:
            gan_smiles_list = gan_data['smiles']
    
    # Check if we found any SMILES
    if not gan_smiles_list:
        logger.error("No SMILES found in the embeddings file")
        # Create empty SMILES list with same length as embeddings
        gan_smiles_list = [f"mol_{i}" for i in range(len(gan_embeddings))]
    
    # Validate query index
    if query_idx >= len(gan_embeddings):
        logger.error(f"Query index {query_idx} out of range (max: {len(gan_embeddings)-1})")
        return None
    
    # Print information
    logger.info(f"Embeddings shape: {gan_embeddings.shape}")
    logger.info(f"Number of SMILES: {len(gan_smiles_list)}")
    
    # Get query SMILES
    if query_idx < len(gan_smiles_list):
        query_smiles = gan_smiles_list[query_idx]
        query_mol = Chem.MolFromSmiles(query_smiles) if isinstance(query_smiles, str) else None
    else:
        query_smiles = f"mol_{query_idx}"
        query_mol = None
    
    # Find nearest neighbors
    neighbor_indices, similarities = find_nearest_neighbors(query_idx, gan_embeddings, num_neighbors)
    
    # Convert SMILES to molecules if possible
    mols = []
    for i, smiles in enumerate(gan_smiles_list):
        if i < len(gan_smiles_list) and isinstance(smiles, str):
            mol = Chem.MolFromSmiles(smiles)
            mols.append(mol)
        else:
            mols.append(None)
    
    # Calculate fingerprint similarities
    ecfp_sims = []
    rdkfp_sims = []
    for idx in neighbor_indices:
        if query_mol is not None and idx < len(mols) and mols[idx] is not None:
            ecfp_sims.append(calculate_fingerprint_similarity(query_mol, mols[idx], "ecfp"))
            rdkfp_sims.append(calculate_fingerprint_similarity(query_mol, mols[idx], "rdkfp"))
        else:
            ecfp_sims.append(0.0)
            rdkfp_sims.append(0.0)
    
    # Create visualization
    fig = plt.figure(figsize=(12, 10))
    plt.suptitle(f"GAN Embedding Nearest Neighbors for Query Molecule {query_idx}", fontsize=16)
    
    # Create a 4x3 grid (query + 9 neighbors)
    gs = GridSpec(4, 3, figure=fig)
    
    # Plot query molecule in the center of the first row
    ax_query = fig.add_subplot(gs[0, 1])
    
    if query_mol is not None:
        query_img = mol_to_img(query_mol)
        if query_img:
            ax_query.imshow(query_img)
        ax_query.set_title(f"Query molecule\n{query_smiles[:20]}...", fontsize=12)
    else:
        ax_query.text(0.5, 0.5, f"Query {query_idx}\nNo valid SMILES", ha="center", va="center")
        ax_query.set_title(f"Query index: {query_idx}", fontsize=12)
    ax_query.axis('off')
    
    # Plot neighbors in the remaining grid cells
    for i, idx in enumerate(neighbor_indices):
        row = (i // 3) + 1  # Start from second row
        col = i % 3
        ax = fig.add_subplot(gs[row, col])
        
        if idx < len(mols) and mols[idx] is not None:
            mol_img = mol_to_img(mols[idx])
            if mol_img:
                ax.imshow(mol_img)
            ax.set_title(f"Nbr{i+1} (idx: {idx})\nRDKFP: {rdkfp_sims[i]:.3f}\nECFP: {ecfp_sims[i]:.3f}")
        else:
            if idx < len(gan_smiles_list):
                smiles_text = gan_smiles_list[idx][:20] + "..." if len(gan_smiles_list[idx]) > 20 else gan_smiles_list[idx]
                ax.text(0.5, 0.5, f"Nbr{i+1} (idx: {idx})\n{smiles_text}", ha="center", va="center")
            else:
                ax.text(0.5, 0.5, f"Nbr{i+1} (idx: {idx})\nNo valid SMILES", ha="center", va="center")
        ax.axis('off')
    
    # Adjust layout and save
    plt.tight_layout()
    plt.subplots_adjust(top=0.9)
    
    # Generate a filename
    filename = f"gan_neighbors_query_{query_idx}.png"
    output_path = os.path.join(output_dir, filename)
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()
    
    logger.info(f"GAN visualization saved to {output_path}")
    return output_path

def main():
    # Define file paths
    gan_embeddings_path = './embeddings/final_embeddings_molecules_20250309_110249.pkl'
    output_dir = './molecule_visualizations'
    
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    
    # Load GAN embeddings just to get the dataset size
    gan_data = load_embeddings(gan_embeddings_path)
    if gan_data is None:
        logger.error("Failed to load GAN embeddings. Exiting.")
        return
    
    num_molecules = len(gan_data['embeddings'])
    logger.info(f"Found {num_molecules} molecules in GAN embeddings")
    
    # Use fixed query indices that are definitely in range
    query_indices = [0, 10]  # Use first and 11th molecule
    
    for i, query_idx in enumerate(query_indices):
        logger.info(f"Processing query {i+1}/{len(query_indices)}: index {query_idx}")
        output_path = create_gan_visualization(
            query_idx,
            gan_embeddings_path,
            num_neighbors=9,
            output_dir=output_dir
        )
        if output_path:
            logger.info(f"Visualization saved to {output_path}")
        else:
            logger.error(f"Failed to create visualization for query index {query_idx}")

if __name__ == "__main__":
    main()

2025-03-10 12:26:26,649 - __main__ - INFO - Found 9937 molecules in GAN embeddings
2025-03-10 12:26:26,651 - __main__ - INFO - Processing query 1/2: index 0
2025-03-10 12:26:26,652 - __main__ - INFO - Loading GAN embeddings from ./embeddings/final_embeddings_molecules_20250309_110249.pkl
2025-03-10 12:26:29,267 - __main__ - ERROR - No SMILES found in the embeddings file
2025-03-10 12:26:29,270 - __main__ - INFO - Embeddings shape: (9937, 128)
2025-03-10 12:26:29,270 - __main__ - INFO - Number of SMILES: 9937
[12:26:29] SMILES Parse Error: syntax error while parsing: mol_0
[12:26:29] SMILES Parse Error: check for mistakes around position 1:
[12:26:29] mol_0
[12:26:29] ^
[12:26:29] SMILES Parse Error: Failed parsing SMILES 'mol_0' for input: 'mol_0'
[12:26:29] SMILES Parse Error: syntax error while parsing: mol_0
[12:26:29] SMILES Parse Error: check for mistakes around position 1:
[12:26:29] mol_0
[12:26:29] ^
[12:26:29] SMILES Parse Error: Failed parsing SMILES 'mol_0' for input: 'mol_0

[12:26:29] SMILES Parse Error: Failed parsing SMILES 'mol_7130' for input: 'mol_7130'
[12:26:29] SMILES Parse Error: syntax error while parsing: mol_7131
[12:26:29] SMILES Parse Error: check for mistakes around position 1:
[12:26:29] mol_7131
[12:26:29] ^
[12:26:29] SMILES Parse Error: Failed parsing SMILES 'mol_7131' for input: 'mol_7131'
[12:26:29] SMILES Parse Error: syntax error while parsing: mol_7132
[12:26:29] SMILES Parse Error: check for mistakes around position 1:
[12:26:29] mol_7132
[12:26:29] ^
[12:26:29] SMILES Parse Error: Failed parsing SMILES 'mol_7132' for input: 'mol_7132'
[12:26:29] SMILES Parse Error: syntax error while parsing: mol_7133
[12:26:29] SMILES Parse Error: check for mistakes around position 1:
[12:26:29] mol_7133
[12:26:29] ^
[12:26:29] SMILES Parse Error: Failed parsing SMILES 'mol_7133' for input: 'mol_7133'
[12:26:29] SMILES Parse Error: syntax error while parsing: mol_7134
[12:26:29] SMILES Parse Error: check for mistakes around position 1:
[12:26:29

[12:26:29] ^
[12:26:29] SMILES Parse Error: Failed parsing SMILES 'mol_8401' for input: 'mol_8401'
[12:26:29] SMILES Parse Error: syntax error while parsing: mol_8402
[12:26:29] SMILES Parse Error: check for mistakes around position 1:
[12:26:29] mol_8402
[12:26:29] ^
[12:26:29] SMILES Parse Error: Failed parsing SMILES 'mol_8402' for input: 'mol_8402'
[12:26:29] SMILES Parse Error: syntax error while parsing: mol_8403
[12:26:29] SMILES Parse Error: check for mistakes around position 1:
[12:26:29] mol_8403
[12:26:29] ^
[12:26:29] SMILES Parse Error: Failed parsing SMILES 'mol_8403' for input: 'mol_8403'
[12:26:29] SMILES Parse Error: syntax error while parsing: mol_8404
[12:26:29] SMILES Parse Error: check for mistakes around position 1:
[12:26:29] mol_8404
[12:26:29] ^
[12:26:29] SMILES Parse Error: Failed parsing SMILES 'mol_8404' for input: 'mol_8404'
[12:26:29] SMILES Parse Error: syntax error while parsing: mol_8405
[12:26:29] SMILES Parse Error: check for mistakes around position

[12:26:29] SMILES Parse Error: Failed parsing SMILES 'mol_9416' for input: 'mol_9416'
[12:26:29] SMILES Parse Error: syntax error while parsing: mol_9417
[12:26:29] SMILES Parse Error: check for mistakes around position 1:
[12:26:29] mol_9417
[12:26:29] ^
[12:26:29] SMILES Parse Error: Failed parsing SMILES 'mol_9417' for input: 'mol_9417'
[12:26:29] SMILES Parse Error: syntax error while parsing: mol_9418
[12:26:29] SMILES Parse Error: check for mistakes around position 1:
[12:26:29] mol_9418
[12:26:29] ^
[12:26:29] SMILES Parse Error: Failed parsing SMILES 'mol_9418' for input: 'mol_9418'
[12:26:29] SMILES Parse Error: syntax error while parsing: mol_9419
[12:26:29] SMILES Parse Error: check for mistakes around position 1:
[12:26:29] mol_9419
[12:26:29] ^
[12:26:29] SMILES Parse Error: Failed parsing SMILES 'mol_9419' for input: 'mol_9419'
[12:26:29] SMILES Parse Error: syntax error while parsing: mol_9420
[12:26:29] SMILES Parse Error: check for mistakes around position 1:
[12:26:29

2025-03-10 12:26:30,522 - __main__ - INFO - GAN visualization saved to ./molecule_visualizations\gan_neighbors_query_0.png
2025-03-10 12:26:30,824 - __main__ - INFO - Visualization saved to ./molecule_visualizations\gan_neighbors_query_0.png
2025-03-10 12:26:30,826 - __main__ - INFO - Processing query 2/2: index 10
2025-03-10 12:26:30,827 - __main__ - INFO - Loading GAN embeddings from ./embeddings/final_embeddings_molecules_20250309_110249.pkl
2025-03-10 12:26:33,669 - __main__ - ERROR - No SMILES found in the embeddings file
2025-03-10 12:26:33,673 - __main__ - INFO - Embeddings shape: (9937, 128)
2025-03-10 12:26:33,674 - __main__ - INFO - Number of SMILES: 9937
[12:26:33] SMILES Parse Error: syntax error while parsing: mol_10
[12:26:33] SMILES Parse Error: check for mistakes around position 1:
[12:26:33] mol_10
[12:26:33] ^
[12:26:33] SMILES Parse Error: Failed parsing SMILES 'mol_10' for input: 'mol_10'
[12:26:33] SMILES Parse Error: syntax error while parsing: mol_0
[12:26:33] SM

[12:26:33] SMILES Parse Error: syntax error while parsing: mol_6571
[12:26:33] SMILES Parse Error: check for mistakes around position 1:
[12:26:33] mol_6571
[12:26:33] ^
[12:26:33] SMILES Parse Error: Failed parsing SMILES 'mol_6571' for input: 'mol_6571'
[12:26:33] SMILES Parse Error: syntax error while parsing: mol_6572
[12:26:33] SMILES Parse Error: check for mistakes around position 1:
[12:26:33] mol_6572
[12:26:33] ^
[12:26:33] SMILES Parse Error: Failed parsing SMILES 'mol_6572' for input: 'mol_6572'
[12:26:33] SMILES Parse Error: syntax error while parsing: mol_6573
[12:26:33] SMILES Parse Error: check for mistakes around position 1:
[12:26:33] mol_6573
[12:26:33] ^
[12:26:33] SMILES Parse Error: Failed parsing SMILES 'mol_6573' for input: 'mol_6573'
[12:26:33] SMILES Parse Error: syntax error while parsing: mol_6574
[12:26:33] SMILES Parse Error: check for mistakes around position 1:
[12:26:33] mol_6574
[12:26:33] ^
[12:26:33] SMILES Parse Error: Failed parsing SMILES 'mol_6574

[12:26:33] mol_7174
[12:26:33] ^
[12:26:33] SMILES Parse Error: Failed parsing SMILES 'mol_7174' for input: 'mol_7174'
[12:26:33] SMILES Parse Error: syntax error while parsing: mol_7175
[12:26:33] SMILES Parse Error: check for mistakes around position 1:
[12:26:33] mol_7175
[12:26:33] ^
[12:26:33] SMILES Parse Error: Failed parsing SMILES 'mol_7175' for input: 'mol_7175'
[12:26:33] SMILES Parse Error: syntax error while parsing: mol_7176
[12:26:33] SMILES Parse Error: check for mistakes around position 1:
[12:26:33] mol_7176
[12:26:33] ^
[12:26:33] SMILES Parse Error: Failed parsing SMILES 'mol_7176' for input: 'mol_7176'
[12:26:33] SMILES Parse Error: syntax error while parsing: mol_7177
[12:26:33] SMILES Parse Error: check for mistakes around position 1:
[12:26:33] mol_7177
[12:26:33] ^
[12:26:33] SMILES Parse Error: Failed parsing SMILES 'mol_7177' for input: 'mol_7177'
[12:26:33] SMILES Parse Error: syntax error while parsing: mol_7178
[12:26:33] SMILES Parse Error: check for mist

[12:26:33] mol_7868
[12:26:34] ^
[12:26:34] SMILES Parse Error: Failed parsing SMILES 'mol_7868' for input: 'mol_7868'
[12:26:34] SMILES Parse Error: syntax error while parsing: mol_7869
[12:26:34] SMILES Parse Error: check for mistakes around position 1:
[12:26:34] mol_7869
[12:26:34] ^
[12:26:34] SMILES Parse Error: Failed parsing SMILES 'mol_7869' for input: 'mol_7869'
[12:26:34] SMILES Parse Error: syntax error while parsing: mol_7870
[12:26:34] SMILES Parse Error: check for mistakes around position 1:
[12:26:34] mol_7870
[12:26:34] ^
[12:26:34] SMILES Parse Error: Failed parsing SMILES 'mol_7870' for input: 'mol_7870'
[12:26:34] SMILES Parse Error: syntax error while parsing: mol_7871
[12:26:34] SMILES Parse Error: check for mistakes around position 1:
[12:26:34] mol_7871
[12:26:34] ^
[12:26:34] SMILES Parse Error: Failed parsing SMILES 'mol_7871' for input: 'mol_7871'
[12:26:34] SMILES Parse Error: syntax error while parsing: mol_7872
[12:26:34] SMILES Parse Error: check for mist

[12:26:34] SMILES Parse Error: Failed parsing SMILES 'mol_8723' for input: 'mol_8723'
[12:26:34] SMILES Parse Error: syntax error while parsing: mol_8724
[12:26:34] SMILES Parse Error: check for mistakes around position 1:
[12:26:34] mol_8724
[12:26:34] ^
[12:26:34] SMILES Parse Error: Failed parsing SMILES 'mol_8724' for input: 'mol_8724'
[12:26:34] SMILES Parse Error: syntax error while parsing: mol_8725
[12:26:34] SMILES Parse Error: check for mistakes around position 1:
[12:26:34] mol_8725
[12:26:34] ^
[12:26:34] SMILES Parse Error: Failed parsing SMILES 'mol_8725' for input: 'mol_8725'
[12:26:34] SMILES Parse Error: syntax error while parsing: mol_8726
[12:26:34] SMILES Parse Error: check for mistakes around position 1:
[12:26:34] mol_8726
[12:26:34] ^
[12:26:34] SMILES Parse Error: Failed parsing SMILES 'mol_8726' for input: 'mol_8726'
[12:26:34] SMILES Parse Error: syntax error while parsing: mol_8727
[12:26:34] SMILES Parse Error: check for mistakes around position 1:
[12:26:34

2025-03-10 12:26:34,896 - __main__ - INFO - GAN visualization saved to ./molecule_visualizations\gan_neighbors_query_10.png
2025-03-10 12:26:35,224 - __main__ - INFO - Visualization saved to ./molecule_visualizations\gan_neighbors_query_10.png
