In [1]:
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit.Chem import AllChem, DataStructs
from rdkit.Chem.rdMolDescriptors import GetMorganFingerprintAsBitVect
from rdkit.Chem.Fingerprints import FingerprintMols
from rdkit.Chem.Draw import MolsToGridImage
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import seaborn as sns
from matplotlib.gridspec import GridSpec
import pandas as pd
from scipy.spatial.distance import pdist, squareform

# Path to your embedding files
MANUAL_EMBEDDINGS_PATH = './embeddings/final_embeddings_molecules_20250310_143038.pkl'
GAN_EMBEDDINGS_PATH = './embeddings/final_embeddings_molecules_20250310_142842.pkl'

def load_embeddings(file_path):
    """Load embeddings and extract SMILES information"""
    print(f"Loading embeddings from: {file_path}")
    
    with open(file_path, 'rb') as f:
        data = pickle.load(f)
    
    embeddings = data['embeddings']
    
    # Extract SMILES from various possible locations in the saved data
    smiles_list = []
    
    # Option 1: Direct smiles_list field
    if 'smiles_list' in data:
        smiles_list = data['smiles_list']
    # Option 2: Extract from molecule_data
    elif 'molecule_data' in data:
        for mol_data in data['molecule_data']:
            if isinstance(mol_data, dict) and 'smiles' in mol_data:
                smiles_list.append(mol_data['smiles'])
            else:
                # If no SMILES found, add empty string as placeholder
                smiles_list.append("")
    
    print(f"Loaded {len(embeddings)} embeddings and {len(smiles_list)} SMILES strings")
    
    # Filter out entries with empty SMILES
    valid_indices = [i for i, smiles in enumerate(smiles_list) if smiles]
    if len(valid_indices) < len(smiles_list):
        print(f"Found {len(valid_indices)} valid SMILES strings out of {len(smiles_list)}")
        valid_embeddings = embeddings[valid_indices]
        valid_smiles = [smiles_list[i] for i in valid_indices]
        return valid_embeddings, valid_smiles
    
    return embeddings, smiles_list

def generate_fingerprints(smiles_list, fp_type='ecfp', radius=2, bits=2048):
    """Generate fingerprints for a list of SMILES strings"""
    fps = []
    valid_mols = []
    valid_smiles = []
    valid_indices = []
    
    for i, smiles in enumerate(tqdm(smiles_list, desc=f"Generating {fp_type} fingerprints")):
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                continue
                
            if fp_type.lower() == 'ecfp':
                # Generate Extended Connectivity Fingerprint (Morgan)
                fp = GetMorganFingerprintAsBitVect(mol, radius, nBits=bits)
            elif fp_type.lower() == 'rdkfp':
                # Generate RDKit fingerprint
                fp = FingerprintMols.FingerprintMol(mol, minPath=1, maxPath=7, fpSize=bits)
            else:
                raise ValueError(f"Unsupported fingerprint type: {fp_type}")
                
            # Convert to numpy array
            arr = np.zeros((bits,))
            DataStructs.ConvertToNumpyArray(fp, arr)
            
            fps.append(arr)
            valid_mols.append(mol)
            valid_smiles.append(smiles)
            valid_indices.append(i)
            
        except Exception as e:
            print(f"Error processing molecule {i} ({smiles}): {e}")
    
    return np.array(fps), valid_mols, valid_smiles, valid_indices

def calculate_fp_similarities(fp_array):
    """Calculate pairwise fingerprint similarities using Tanimoto coefficient"""
    n = len(fp_array)
    similarities = np.zeros((n, n))
    
    for i in tqdm(range(n), desc="Calculating fingerprint similarities"):
        for j in range(i, n):
            # Calculate Tanimoto similarity
            intersection = np.sum(fp_array[i] * fp_array[j])
            union = np.sum(fp_array[i]) + np.sum(fp_array[j]) - intersection
            
            # Avoid division by zero
            if union > 0:
                sim = intersection / union
            else:
                sim = 0.0
                
            similarities[i, j] = sim
            similarities[j, i] = sim  # Symmetric matrix
            
    return similarities

def calculate_embedding_distances(embeddings):
    """Calculate pairwise distances between embeddings using cosine distance"""
    # Normalize embeddings
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    norms[norms == 0] = 1.0  # Avoid division by zero
    embeddings_norm = embeddings / norms
    
    # Calculate cosine similarity
    cos_sim = np.dot(embeddings_norm, embeddings_norm.T)
    
    # Convert to cosine distance (1 - similarity)
    distances = 1 - cos_sim
    
    return distances

def plot_similarity_vs_distance(embedding_distances, fp_similarities, fp_type, title=None, save_path=None):
    """Plot fingerprint similarity vs. embedding distance ranking percentage"""
    n = embedding_distances.shape[0]
    
    # Convert distances to rankings for each molecule
    flat_distances = []
    flat_similarities = []
    
    for i in range(n):
        # Get distances from molecule i to all others
        distances = embedding_distances[i]
        
        # Create ranking (0 to 1)
        sorted_indices = np.argsort(distances)
        rankings = np.linspace(0, 1, n)
        
        # Skip self-comparison
        for j, idx in enumerate(sorted_indices):
            if i != idx:  # Skip self
                percentile = rankings[j]
                flat_distances.append(percentile)
                flat_similarities.append(fp_similarities[i, idx])
    
    # Create bins for distance rankings
    bins = np.linspace(0, 1, 11)  # 10 bins
    bin_centers = (bins[:-1] + bins[1:]) / 2
    
    # Calculate mean and std for each bin
    bin_means = []
    bin_stds = []
    
    for i in range(len(bins) - 1):
        mask = (np.array(flat_distances) >= bins[i]) & (np.array(flat_distances) < bins[i+1])
        if np.sum(mask) > 0:
            bin_means.append(np.mean(np.array(flat_similarities)[mask]))
            bin_stds.append(np.std(np.array(flat_similarities)[mask]))
        else:
            bin_means.append(0)
            bin_stds.append(0)
    
    # Create plot
    plt.figure(figsize=(8, 6))
    plt.errorbar(bin_centers, bin_means, yerr=bin_stds, fmt='o-', capsize=4, label=fp_type)
    
    plt.xlabel('Distance Ranking Percentage', fontsize=14)
    plt.ylabel(f'{fp_type} Similarity', fontsize=14)
    plt.ylim(0, 1)
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=12)
    
    if title:
        plt.title(title, fontsize=16)
    else:
        plt.title(f'{fp_type} Similarity vs. Distance Ranking', fontsize=16)
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
    else:
        plt.show()
    
    return bin_centers, bin_means, bin_stds

def plot_similarity_distribution(fp_similarities_ecfp, fp_similarities_rdkfp, query_idx, title=None, save_path=None):
    """Plot distribution of ECFP and RDKFP similarities for a query molecule"""
    # Get similarities to query molecule
    ecfp_sims = fp_similarities_ecfp[query_idx]
    rdkfp_sims = fp_similarities_rdkfp[query_idx]
    
    plt.figure(figsize=(8, 6))
    
    plt.hist(ecfp_sims, bins=30, alpha=0.7, label='ECFP', color='red', density=True)
    plt.hist(rdkfp_sims, bins=30, alpha=0.7, label='RDKFP', color='green', density=True)
    
    plt.xlabel('FP Similarity', fontsize=14)
    plt.ylabel('Fraction', fontsize=14)
    plt.legend(fontsize=12)
    
    if title:
        plt.title(title, fontsize=16)
    else:
        plt.title('Distribution of Fingerprint Similarities', fontsize=16)
    
    plt.grid(True, alpha=0.3)
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
    else:
        plt.show()
    
    return ecfp_sims, rdkfp_sims

def visualize_closest_molecules(embeddings, valid_mols, valid_smiles, 
                              fp_similarities_ecfp, fp_similarities_rdkfp,
                              query_idx, n_closest=9, title=None, save_path=None):
    """Visualize a query molecule and its closest neighbors in embedding space"""
    # Calculate cosine similarities
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    embeddings_norm = embeddings / norms
    
    query_embedding = embeddings_norm[query_idx].reshape(1, -1)
    similarities = np.dot(embeddings_norm, query_embedding.T).flatten()
    
    # Get indices of most similar molecules (excluding the query itself)
    sim_indices = np.argsort(similarities)[::-1]
    closest_indices = [idx for idx in sim_indices if idx != query_idx][:n_closest]
    
    # Get fingerprint similarities
    ecfp_sims = fp_similarities_ecfp[query_idx, closest_indices]
    rdkfp_sims = fp_similarities_rdkfp[query_idx, closest_indices]
    
    # Create molecule objects for visualization
    query_mol = valid_mols[query_idx]
    closest_mols = [valid_mols[idx] for idx in closest_indices]
    
    # Prepare legend texts
    legends = []
    legends.append(f"Query\n{valid_smiles[query_idx]}")
    
    for i in range(n_closest):
        legends.append(f"RDKFP: {rdkfp_sims[i]:.3f}\nECFP: {ecfp_sims[i]:.3f}\n{valid_smiles[closest_indices[i]]}")
    
    # Create grid image
    mols_to_draw = [query_mol] + closest_mols
    img = MolsToGridImage(mols_to_draw, molsPerRow=5, subImgSize=(200, 200),
                        legends=legends)
    
    if title:
        img.save(save_path)
    
    return img

def combined_visualization(manual_embeddings, manual_smiles, gan_embeddings, gan_smiles, 
                         output_dir='visualization_results'):
    """Generate combined visualizations for comparing embedding methods"""
    os.makedirs(output_dir, exist_ok=True)
    
    # Find common SMILES in both embeddings
    common_smiles = list(set(manual_smiles).intersection(set(gan_smiles)))
    print(f"Found {len(common_smiles)} common SMILES strings")
    
    if len(common_smiles) == 0:
        print("No common molecules found between the two embeddings!")
        return None
    
    # Limit to a manageable subset if needed
    if len(common_smiles) > 1000:
        print(f"Using first 1000 common molecules for analysis")
        common_smiles = common_smiles[:1000]
    
    # Generate fingerprints
    print("Generating fingerprints...")
    ecfp_fps, valid_mols, valid_smiles, valid_indices = generate_fingerprints(common_smiles, fp_type='ecfp')
    
    if len(valid_smiles) == 0:
        print("No valid molecules found after ECFP fingerprint generation!")
        return None
        
    rdkfp_fps, _, _, _ = generate_fingerprints(valid_smiles, fp_type='rdkfp')
    
    if len(rdkfp_fps) == 0:
        print("No valid molecules found after RDKFP fingerprint generation!")
        return None
    
    # Calculate fingerprint similarities
    print("Calculating fingerprint similarities...")
    if len(ecfp_fps) > 0:
        ecfp_similarities = calculate_fp_similarities(ecfp_fps)
    else:
        print("No valid ECFP fingerprints generated!")
        return None
        
    if len(rdkfp_fps) > 0:
        rdkfp_similarities = calculate_fp_similarities(rdkfp_fps)
    else:
        print("No valid RDKFP fingerprints generated!")
        return None
    
    # Make sure all valid_smiles exist in both embedding sets
    valid_filtered_smiles = []
    valid_filtered_mols = []
    
    for i, smiles in enumerate(valid_smiles):
        if smiles in manual_smiles and smiles in gan_smiles:
            valid_filtered_smiles.append(smiles)
            valid_filtered_mols.append(valid_mols[i])
    
    if len(valid_filtered_smiles) == 0:
        print("No valid molecules found in both embedding sets!")
        return None
    
    print(f"Using {len(valid_filtered_smiles)} molecules found in both embedding sets")
    
    # Instead of indexing the similarity matrices, recalculate them for the filtered molecules
    ecfp_filtered_fps, _, _, _ = generate_fingerprints(valid_filtered_smiles, fp_type='ecfp')
    rdkfp_filtered_fps, _, _, _ = generate_fingerprints(valid_filtered_smiles, fp_type='rdkfp')
    
    if len(ecfp_filtered_fps) > 0 and len(rdkfp_filtered_fps) > 0:
        ecfp_similarities = calculate_fp_similarities(ecfp_filtered_fps)
        rdkfp_similarities = calculate_fp_similarities(rdkfp_filtered_fps)
    else:
        print("Failed to generate fingerprints for filtered molecules!")
        return None
    
    # Extract manual embeddings for valid filtered smiles
    manual_indices = [manual_smiles.index(smiles) for smiles in valid_filtered_smiles]
    filtered_manual_embeddings = manual_embeddings[manual_indices]
    manual_distances = calculate_embedding_distances(filtered_manual_embeddings)
    
    # Extract GAN embeddings for valid filtered smiles
    gan_indices = [gan_smiles.index(smiles) for smiles in valid_filtered_smiles]
    filtered_gan_embeddings = gan_embeddings[gan_indices]
    gan_distances = calculate_embedding_distances(filtered_gan_embeddings)
    
    # Choose a query molecule (e.g., the first one)
    if len(valid_filtered_smiles) > 0:
        query_idx = 0
        query_smiles = valid_filtered_smiles[query_idx]
        print(f"Using query molecule: {query_smiles}")
    else:
        print("No valid molecules found for visualization!")
        return None
    
    # Plot similarity vs distance for both embeddings - ECFP
    print("Plotting similarity vs distance curves...")
    plot_similarity_vs_distance(
        manual_distances, 
        ecfp_similarities, 
        'ECFP',
        title='Manual CL: ECFP Similarity vs. Distance',
        save_path=f"{output_dir}/manual_ecfp_similarity_vs_distance.png"
    )
    
    plot_similarity_vs_distance(
        gan_distances, 
        ecfp_similarities, 
        'ECFP',
        title='GAN CL: ECFP Similarity vs. Distance',
        save_path=f"{output_dir}/gan_ecfp_similarity_vs_distance.png"
    )
    
    # Plot similarity vs distance for both embeddings - RDKFP
    plot_similarity_vs_distance(
        manual_distances, 
        rdkfp_similarities, 
        'RDKFP',
        title='Manual CL: RDKFP Similarity vs. Distance',
        save_path=f"{output_dir}/manual_rdkfp_similarity_vs_distance.png"
    )
    
    plot_similarity_vs_distance(
        gan_distances, 
        rdkfp_similarities, 
        'RDKFP',
        title='GAN CL: RDKFP Similarity vs. Distance',
        save_path=f"{output_dir}/gan_rdkfp_similarity_vs_distance.png"
    )
    
    # Plot similarity distributions
    plot_similarity_distribution(
        ecfp_similarities,
        rdkfp_similarities,
        query_idx,
        title='Manual CL: FP Similarity Distribution',
        save_path=f"{output_dir}/manual_fp_distribution.png"
    )
    
    plot_similarity_distribution(
        ecfp_similarities,
        rdkfp_similarities,
        query_idx,
        title='GAN CL: FP Similarity Distribution',
        save_path=f"{output_dir}/gan_fp_distribution.png"
    )
    
    # Visualize closest molecules
    print("Creating molecule visualizations...")
    
    # For Manual CL
    manual_mol_img = visualize_closest_molecules(
        filtered_manual_embeddings,
        valid_mols,
        valid_smiles,
        ecfp_similarities,
        rdkfp_similarities,
        query_idx,
        title='Manual CL: Query molecule and closest neighbors',
        save_path=f"{output_dir}/manual_closest_molecules.png"
    )
    
    # For GAN CL
    gan_mol_img = visualize_closest_molecules(
        filtered_gan_embeddings,
        valid_mols,
        valid_smiles,
        ecfp_similarities,
        rdkfp_similarities,
        query_idx,
        title='GAN CL: Query molecule and closest neighbors',
        save_path=f"{output_dir}/gan_closest_molecules.png"
    )
    
    # Create comparison figure
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(20, 12))
    
    # Manual vs GAN for ECFP
    manual_bin_centers, manual_ecfp_means, manual_ecfp_stds = plot_similarity_vs_distance(
        manual_distances, ecfp_similarities, 'ECFP', save_path=None)
    
    gan_bin_centers, gan_ecfp_means, gan_ecfp_stds = plot_similarity_vs_distance(
        gan_distances, ecfp_similarities, 'ECFP', save_path=None)
    
    ax1.errorbar(manual_bin_centers, manual_ecfp_means, yerr=manual_ecfp_stds, 
                 fmt='o-', capsize=4, label='Manual CL', color='blue')
    ax1.errorbar(gan_bin_centers, gan_ecfp_means, yerr=gan_ecfp_stds, 
                 fmt='s-', capsize=4, label='GAN CL', color='red')
    
    ax1.set_xlabel('Distance Ranking Percentage')
    ax1.set_ylabel('ECFP Similarity')
    ax1.set_ylim(0, 1)
    ax1.grid(True, alpha=0.3)
    ax1.set_title('ECFP Similarity vs. Distance Ranking')
    ax1.legend()
    
    # Manual vs GAN for RDKFP
    manual_bin_centers, manual_rdkfp_means, manual_rdkfp_stds = plot_similarity_vs_distance(
        manual_distances, rdkfp_similarities, 'RDKFP', save_path=None)
    
    gan_bin_centers, gan_rdkfp_means, gan_rdkfp_stds = plot_similarity_vs_distance(
        gan_distances, rdkfp_similarities, 'RDKFP', save_path=None)
    
    ax2.errorbar(manual_bin_centers, manual_rdkfp_means, yerr=manual_rdkfp_stds, 
                 fmt='o-', capsize=4, label='Manual CL', color='blue')
    ax2.errorbar(gan_bin_centers, gan_rdkfp_means, yerr=gan_rdkfp_stds, 
                 fmt='s-', capsize=4, label='GAN CL', color='red')
    
    ax2.set_xlabel('Distance Ranking Percentage')
    ax2.set_ylabel('RDKFP Similarity')
    ax2.set_ylim(0, 1)
    ax2.grid(True, alpha=0.3)
    ax2.set_title('RDKFP Similarity vs. Distance Ranking')
    ax2.legend()
    
    # Load molecule images and display
    from PIL import Image
    
    manual_img_path = f"{output_dir}/manual_closest_molecules.png"
    gan_img_path = f"{output_dir}/gan_closest_molecules.png"
    
    if os.path.exists(manual_img_path) and os.path.exists(gan_img_path):
        manual_mol_img = Image.open(manual_img_path)
        gan_mol_img = Image.open(gan_img_path)
        
        ax3.imshow(np.array(manual_mol_img))
        ax3.set_title("Manual CL: Query molecule and closest neighbors")
        ax3.axis('off')
        
        ax4.imshow(np.array(gan_mol_img))
        ax4.set_title("GAN CL: Query molecule and closest neighbors")
        ax4.axis('off')
    
    plt.tight_layout()
    plt.savefig(f"{output_dir}/combined_comparison.png", dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"All visualizations saved to {output_dir}")
    
    return {
        'valid_mols': valid_mols,
        'valid_smiles': valid_smiles,
        'ecfp_similarities': ecfp_similarities,
        'rdkfp_similarities': rdkfp_similarities
    }

def analyze_specific_molecule(smiles, manual_embeddings, manual_smiles, 
                            gan_embeddings, gan_smiles, output_dir='molecule_analysis'):
    """Analyze a specific molecule and its neighborhood in both embedding spaces"""
    os.makedirs(output_dir, exist_ok=True)
    
    # First, verify the molecule is valid
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        print(f"Invalid SMILES: {smiles}")
        return None
    
    # Check if molecule exists in both embedding sets
    if smiles not in manual_smiles or smiles not in gan_smiles:
        print(f"Molecule {smiles} not found in both embedding sets")
        return None
    
    try:
        # Find indices
        manual_idx = manual_smiles.index(smiles)
        gan_idx = gan_smiles.index(smiles)
        
        # Find closest neighbors in each embedding space
        # For Manual embeddings
        manual_embedding = manual_embeddings[manual_idx].reshape(1, -1)
        manual_norms = np.linalg.norm(manual_embeddings, axis=1, keepdims=True)
        manual_norms[manual_norms == 0] = 1.0
        manual_embeddings_norm = manual_embeddings / manual_norms
        manual_similarities = np.dot(manual_embeddings_norm, manual_embedding.T).flatten()
        
        # For GAN embeddings
        gan_embedding = gan_embeddings[gan_idx].reshape(1, -1)
        gan_norms = np.linalg.norm(gan_embeddings, axis=1, keepdims=True)
        gan_norms[gan_norms == 0] = 1.0
        gan_embeddings_norm = gan_embeddings / gan_norms
        gan_similarities = np.dot(gan_embeddings_norm, gan_embedding.T).flatten()
        
        # Get top most similar molecules for each embedding (limiting to available molecules)
        n_neighbors = min(10, len(manual_similarities)-1, len(gan_similarities)-1)
        if n_neighbors <= 0:
            print(f"Not enough molecules for neighbor analysis")
            return None
        
        # Get top n_neighbors for manual embeddings
        manual_top_indices = np.argsort(manual_similarities)[::-1][:n_neighbors+1]
        manual_top_indices = [idx for idx in manual_top_indices if idx != manual_idx][:n_neighbors]
        
        # Get top n_neighbors for GAN embeddings
        gan_top_indices = np.argsort(gan_similarities)[::-1][:n_neighbors+1]
        gan_top_indices = [idx for idx in gan_top_indices if idx != gan_idx][:n_neighbors]
        
        # Filter for valid molecules
        manual_valid_indices = []
        manual_mols = []
        manual_neighbor_smiles = []
        
        for idx in manual_top_indices:
            if idx < len(manual_smiles):
                n_smiles = manual_smiles[idx]
                n_mol = Chem.MolFromSmiles(n_smiles)
                if n_mol is not None:
                    manual_valid_indices.append(idx)
                    manual_mols.append(n_mol)
                    manual_neighbor_smiles.append(n_smiles)
        
        gan_valid_indices = []
        gan_mols = []
        gan_neighbor_smiles = []
        
        for idx in gan_top_indices:
            if idx < len(gan_smiles):
                n_smiles = gan_smiles[idx]
                n_mol = Chem.MolFromSmiles(n_smiles)
                if n_mol is not None:
                    gan_valid_indices.append(idx)
                    gan_mols.append(n_mol)
                    gan_neighbor_smiles.append(n_smiles)
        
        # Create legends
        query_legend = f"Query:\n{smiles}"
        
        manual_legends = []
        for i, idx in enumerate(manual_valid_indices):
            if i < len(manual_valid_indices):
                sim = manual_similarities[idx]
                manual_legends.append(f"Sim: {sim:.3f}\n{manual_smiles[idx]}")
        
        gan_legends = []
        for i, idx in enumerate(gan_valid_indices):
            if i < len(gan_valid_indices):
                sim = gan_similarities[idx]
                gan_legends.append(f"Sim: {sim:.3f}\n{gan_smiles[idx]}")
        
        # Draw molecules if we have any valid ones
        if len(manual_mols) > 0 or len(gan_mols) > 0:
            # Query molecule
            query_img = MolsToGridImage([mol], molsPerRow=1, subImgSize=(200, 200),
                                    legends=[query_legend])
            query_img.save(f"{output_dir}/query_molecule.png")
            
            # Manual neighbors
            if len(manual_mols) > 0:
                manual_img = MolsToGridImage(manual_mols, molsPerRow=min(5, len(manual_mols)), 
                                          subImgSize=(200, 200), legends=manual_legends)
                manual_img.save(f"{output_dir}/manual_neighbors.png")
            
            # GAN neighbors
            if len(gan_mols) > 0:
                gan_img = MolsToGridImage(gan_mols, molsPerRow=min(5, len(gan_mols)), 
                                       subImgSize=(200, 200), legends=gan_legends)
                gan_img.save(f"{output_dir}/gan_neighbors.png")
            
            # Create combined visualization if we have the images
            if os.path.exists(f"{output_dir}/query_molecule.png"):
                try:
                    from PIL import Image
                    fig, axes = plt.subplots(1 + min(1, len(manual_mols) > 0) + min(1, len(gan_mols) > 0), 
                                          1, figsize=(10, 8))
                    
                    # Handle case when only one subplot is created
                    if not isinstance(axes, np.ndarray):
                        axes = [axes]
                    
                    # Query molecule
                    query_img_pil = Image.open(f"{output_dir}/query_molecule.png")
                    axes[0].imshow(np.array(query_img_pil))
                    axes[0].set_title(f"Query Molecule: {smiles}")
                    axes[0].axis('off')
                    
                    # Manual neighbors if available
                    subplot_idx = 1
                    if len(manual_mols) > 0 and subplot_idx < len(axes):
                        if os.path.exists(f"{output_dir}/manual_neighbors.png"):
                            manual_img_pil = Image.open(f"{output_dir}/manual_neighbors.png")
                            axes[subplot_idx].imshow(np.array(manual_img_pil))
                            axes[subplot_idx].set_title("Nearest Neighbors in Manual CL Embedding Space")
                            axes[subplot_idx].axis('off')
                            subplot_idx += 1
                    
                    # GAN neighbors if available
                    if len(gan_mols) > 0 and subplot_idx < len(axes):
                        if os.path.exists(f"{output_dir}/gan_neighbors.png"):
                            gan_img_pil = Image.open(f"{output_dir}/gan_neighbors.png")
                            axes[subplot_idx].imshow(np.array(gan_img_pil))
                            axes[subplot_idx].set_title("Nearest Neighbors in GAN CL Embedding Space")
                            axes[subplot_idx].axis('off')
                    
                    plt.tight_layout()
                    plt.savefig(f"{output_dir}/combined_neighbors.png", dpi=300, bbox_inches='tight')
                    plt.close()
                except Exception as e:
                    print(f"Error creating combined visualization: {e}")
            
            print(f"Molecule analysis saved to {output_dir}")
            
            return {
                'query_mol': mol,
                'manual_neighbors': manual_mols,
                'gan_neighbors': gan_mols,
                'manual_similarities': manual_similarities[manual_valid_indices] if len(manual_valid_indices) > 0 else [],
                'gan_similarities': gan_similarities[gan_valid_indices] if len(gan_valid_indices) > 0 else []
            }
        else:
            print("No valid neighbors found for visualization")
            return None
            
    except Exception as e:
        print(f"Error analyzing molecule {smiles}: {e}")
        return None

def find_interesting_molecules(manual_embeddings, manual_smiles, gan_embeddings, gan_smiles, output_dir='interesting_molecules'):
    """Find molecules where manual and GAN embeddings differ significantly in their nearest neighbors"""
    os.makedirs(output_dir, exist_ok=True)
    
    # Find common molecules
    common_smiles = list(set(manual_smiles).intersection(set(gan_smiles)))
    print(f"Found {len(common_smiles)} common molecules")
    
    if len(common_smiles) == 0:
        print("No common molecules found!")
        return []
    
    # Keep only molecules that exist in both embedding sets
    filtered_smiles = []
    for smiles in common_smiles:
        try:
            # Make sure molecule exists in both embedding sets
            manual_idx = manual_smiles.index(smiles)
            gan_idx = gan_smiles.index(smiles)
            # Make sure the molecule is valid 
            mol = Chem.MolFromSmiles(smiles)
            if mol is not None:
                filtered_smiles.append(smiles)
        except (ValueError, IndexError):
            # Skip if not found in either set
            continue
    
    if len(filtered_smiles) == 0:
        print("No valid common molecules found!")
        return []
        
    print(f"Using {len(filtered_smiles)} valid common molecules")
    
    # Sample a subset for analysis if there are too many
    if len(filtered_smiles) > 100:
        sampled_indices = np.random.choice(len(filtered_smiles), 100, replace=False)
        sampled_smiles = [filtered_smiles[i] for i in sampled_indices]
        print(f"Sampling 100 molecules for analysis")
    else:
        sampled_smiles = filtered_smiles
    
    # Extract embeddings for sampled molecules
    manual_indices = []
    gan_indices = []
    final_sampled_smiles = []
    
    for smiles in sampled_smiles:
        try:
            manual_idx = manual_smiles.index(smiles)
            gan_idx = gan_smiles.index(smiles)
            manual_indices.append(manual_idx)
            gan_indices.append(gan_idx)
            final_sampled_smiles.append(smiles)
        except (ValueError, IndexError):
            # Skip if index not found
            continue
    
    if len(final_sampled_smiles) == 0:
        print("No valid molecules after indexing!")
        return []
        
    print(f"Analyzing {len(final_sampled_smiles)} molecules")
    
    sampled_manual_embeddings = manual_embeddings[manual_indices]
    sampled_gan_embeddings = gan_embeddings[gan_indices]
    
    # Calculate distances within each embedding space
    manual_distances = calculate_embedding_distances(sampled_manual_embeddings)
    gan_distances = calculate_embedding_distances(sampled_gan_embeddings)
    
    # For each molecule, find nearest neighbors in both spaces
    neighbor_divergence = []
    
    for i, smiles in enumerate(tqdm(final_sampled_smiles, desc="Analyzing neighbor divergence")):
        try:
            # Get top 10 neighbors in manual space (or fewer if not enough molecules)
            num_neighbors = min(10, len(final_sampled_smiles) - 1)
            if num_neighbors <= 0:
                continue
                
            # Get top neighbors in manual space
            manual_nearest = np.argsort(manual_distances[i])[:num_neighbors+1]  # Include self
            manual_nearest = [idx for idx in manual_nearest if idx != i][:num_neighbors]  # Remove self
            
            # Get top neighbors in gan space
            gan_nearest = np.argsort(gan_distances[i])[:num_neighbors+1]
            gan_nearest = [idx for idx in gan_nearest if idx != i][:num_neighbors]
            
            # Calculate overlap
            overlap = len(set(manual_nearest).intersection(set(gan_nearest)))
            overlap_ratio = overlap / num_neighbors if num_neighbors > 0 else 0
            
            # Measure divergence (1 - overlap ratio)
            divergence = 1 - overlap_ratio
            
            neighbor_divergence.append({
                'index': i,
                'smiles': smiles,
                'divergence': divergence,
                'manual_neighbors': [final_sampled_smiles[idx] for idx in manual_nearest],
                'gan_neighbors': [final_sampled_smiles[idx] for idx in gan_nearest]
            })
        except Exception as e:
            print(f"Error analyzing molecule {i} ({smiles}): {e}")
            continue
    
    if len(neighbor_divergence) == 0:
        print("No valid divergence analysis results!")
        return []
        
    # Sort by divergence
    neighbor_divergence.sort(key=lambda x: x['divergence'], reverse=True)
    
    # Save top most divergent molecules
    with open(f"{output_dir}/most_divergent_molecules.txt", 'w') as f:
        f.write("SMILES,Divergence\n")
        for mol in neighbor_divergence[:min(10, len(neighbor_divergence))]:
            f.write(f"{mol['smiles']},{mol['divergence']:.3f}\n")
    
    # Analyze top most divergent molecules
    max_to_analyze = min(3, len(neighbor_divergence))
    if max_to_analyze > 0:
        for i, mol_data in enumerate(neighbor_divergence[:max_to_analyze]):
            try:
                analyze_specific_molecule(
                    mol_data['smiles'],
                    manual_embeddings,
                    manual_smiles,
                    gan_embeddings,
                    gan_smiles,
                    output_dir=f"{output_dir}/divergent_molecule_{i+1}"
                )
            except Exception as e:
                print(f"Error analyzing divergent molecule {i+1}: {e}")
                continue
    
    print(f"Interesting molecule analysis saved to {output_dir}")
    
    return neighbor_divergence

def main():
    # Create output directory
    os.makedirs('visualization_results', exist_ok=True)
    
    # Load embeddings with SMILES
    print("Loading embeddings...")
    manual_embeddings, manual_smiles = load_embeddings(MANUAL_EMBEDDINGS_PATH)
    if len(manual_smiles) == 0:
        print("Error: No valid SMILES found in manual embeddings!")
        return
    
    gan_embeddings, gan_smiles = load_embeddings(GAN_EMBEDDINGS_PATH)
    if len(gan_smiles) == 0:
        print("Error: No valid SMILES found in GAN embeddings!")
        return
    
    print(f"Loaded {len(manual_smiles)} manual embeddings and {len(gan_smiles)} GAN embeddings")
    
    # Generate combined visualizations
    print("Generating combined visualizations...")
    result_data = combined_visualization(manual_embeddings, manual_smiles, gan_embeddings, gan_smiles)
    
    # Find interesting molecules (where manual and GAN embeddings differ significantly)
    print("Finding molecules with interesting differences between embeddings...")
    divergent_molecules = find_interesting_molecules(manual_embeddings, manual_smiles, gan_embeddings, gan_smiles)
    
    # Print summary
    print("\nAnalysis complete!")
    print(f"Visualizations saved to 'visualization_results/'")
    print(f"Interesting molecules saved to 'interesting_molecules/'")
    
    # Suggest next steps
    print("\nSuggested next steps:")
    print("1. Review the similarity vs. distance plots to compare how well each embedding preserves chemical similarity")
    print("2. Examine the nearest neighbors of molecules to evaluate embedding quality")
    print("3. Investigate the most divergent molecules to understand where the embedding methods differ")

if __name__ == "__main__":
    # When running in a Jupyter notebook, sys.argv might contain additional arguments
    # that would cause the argparse to fail, so we'll check if we're in IPython
    try:
        import IPython
        is_notebook = True
    except ImportError:
        is_notebook = False
    
    if is_notebook:
        # If we're in a notebook, just run the main function without parsing arguments
        main()
    else:
        # Otherwise, parse arguments normally
        import argparse
        
        parser = argparse.ArgumentParser(description="Visualize and compare molecular embeddings.")
        parser.add_argument("--manual", type=str, default=MANUAL_EMBEDDINGS_PATH, 
                          help="Path to manual augmentation embeddings file")
        parser.add_argument("--gan", type=str, default=GAN_EMBEDDINGS_PATH, 
                          help="Path to GAN augmentation embeddings file")
        parser.add_argument("--analyze", type=str, 
                          help="SMILES string of specific molecule to analyze")
        parser.add_argument("--output", type=str, default="visualization_results", 
                          help="Output directory for visualizations")
        
        args = parser.parse_args()
        
        # Update paths if provided
        MANUAL_EMBEDDINGS_PATH = args.manual
        GAN_EMBEDDINGS_PATH = args.gan
        
        if args.analyze:
            # Load embeddings
            manual_embeddings, manual_smiles = load_embeddings(MANUAL_EMBEDDINGS_PATH)
            gan_embeddings, gan_smiles = load_embeddings(GAN_EMBEDDINGS_PATH)
            
            # Analyze specific molecule
            print(f"Analyzing specific molecule: {args.analyze}")
            analyze_specific_molecule(args.analyze, manual_embeddings, manual_smiles, 
                                    gan_embeddings, gan_smiles, 
                                    output_dir=f"{args.output}/specific_molecule")
        else:
            # Run full analysis
            main()

Loading embeddings...
Loading embeddings from: ./embeddings/final_embeddings_molecules_20250310_143038.pkl
Loaded 41 embeddings and 41 SMILES strings
Loading embeddings from: ./embeddings/final_embeddings_molecules_20250310_142842.pkl
Loaded 9 embeddings and 41 SMILES strings
Loaded 41 manual embeddings and 41 GAN embeddings
Generating combined visualizations...
Found 41 common SMILES strings
Generating fingerprints...


Generating ecfp fingerprints: 100%|██████████████████████████████████████████████████| 41/41 [00:00<00:00, 2001.61it/s]
Generating rdkfp fingerprints: 100%|█████████████████████████████████████████████████| 41/41 [00:00<00:00, 6832.20it/s]

Error processing molecule 0 (CC1CCC(O)(CCc2ccsc2)C1): 'bitsPerHash'
Error processing molecule 1 (N#Cc1ccccc1NC(=O)CCSc1nnc(C[NH+]2CCCC2)n1Cc1ccccc1): 'bitsPerHash'
Error processing molecule 2 (C=C(C)CN(CC)c1ccc(C(C)[NH3+])cc1Cl): 'bitsPerHash'
Error processing molecule 3 (COc1ccc2nc(NC(=O)Nc3cnccn3)sc2c1): 'bitsPerHash'
Error processing molecule 4 (COCC(C)C[NH2+]Cc1cccs1): 'bitsPerHash'
Error processing molecule 5 (Cn1ncc(c2ccc(CN3CC[NH2+]CC3)cc2)c1N): 'bitsPerHash'
Error processing molecule 6 (COC(=O)c1ccc(Oc2cnn(c3cc(Cl)cc(Cl)c3)c2c2ccc3cc(OC)ccc3c2)cc1): 'bitsPerHash'
Error processing molecule 7 (COc1ccc(C)cc1Nc1ncnc(N2CCN(c3ccccn3)CC2)c1[N+](=O)[O-]): 'bitsPerHash'
Error processing molecule 8 (O=C(NCc1ccc(F)cc1)N1CC=C(c2c[nH]c3ccccc23)CC1): 'bitsPerHash'
Error processing molecule 9 (CC1(CNC(=O)N2CCC(C)(C)S(=O)(=O)CC2)CCCc2ccccc21): 'bitsPerHash'
Error processing molecule 10 (CCCCCc1ccc(C(=O)Nc2ccccc2N)cc1): 'bitsPerHash'
Error processing molecule 11 (CC(C)OCCCNC(=O)C(NC(=O)c1c(F)cc




IndexError: index 13 is out of bounds for dimension 0 with size 9