In [6]:
# --- Inputs ---------------------------------------------------------------
# SDF file holding the candidate structures.
sdf_name = '5HT1A_filter.sdf'
# CSV of predictions; process_molecules reads its 'Score' and 'Ligand ID' columns.
inp_csv = '5HT1A-predict_intra.csv'
# SDF property whose value is matched against the CSV 'Ligand ID' values.
name_in_sdf = 'hit_id'

# --- Selection criterion --------------------------------------------------
# Rows are kept when Score is strictly greater than this value.
threshold = 0.5

# --- Outputs --------------------------------------------------------------
# Destination SDF for the molecules that pass the filter.
sdf_out_name = '5HT1A_filter_final.sdf'
# Destination image for the score-distribution plot.
figure_out_name = '5HT1A_filter_final.png'

import pandas as pd
import matplotlib.pyplot as plt
from rdkit import Chem
import seaborn as sns
import os

# Filter SDF molecules by predicted score and plot the score distribution
def process_molecules(inp_csv, sdf_name, name_in_sdf, threshold, sdf_out_name, figure_out_name):
    """Filter an SDF file by prediction score and plot the score distribution.

    Rows of ``inp_csv`` whose ``Score`` is strictly greater than ``threshold``
    are selected. Every readable molecule in ``sdf_name`` whose
    ``name_in_sdf`` property equals one of the selected ``Ligand ID`` values
    is written to ``sdf_out_name``. A histogram of all scores, with the
    selected scores overlaid and the threshold marked, is saved to
    ``figure_out_name``. Progress is reported with ``print``.

    Args:
        inp_csv: Path of the CSV file; must contain the columns ``Score``
            and ``Ligand ID``.
        sdf_name: Path of the SDF file to read molecules from.
        name_in_sdf: Name of the SDF property holding each molecule's ID.
        threshold: Score cut-off; only rows with ``Score > threshold`` pass.
        sdf_out_name: Path of the SDF file to write the matches to.
        figure_out_name: Path of the image file for the distribution plot.

    Returns:
        A tuple ``(filtered_df, filtered_mols)``: the CSV rows that passed
        the threshold, and the list of RDKit molecules written to
        ``sdf_out_name``.

    Raises:
        KeyError: If ``inp_csv`` lacks the ``Score`` or ``Ligand ID`` column.
    """
    # Read the CSV file with predictions
    df = pd.read_csv(inp_csv)

    # Fail early, naming the file, instead of a bare KeyError further down.
    missing_columns = {'Score', 'Ligand ID'} - set(df.columns)
    if missing_columns:
        raise KeyError(
            f"{inp_csv} is missing required column(s): {sorted(missing_columns)}"
        )

    # Display basic info about input data
    print(f"Total molecules in CSV: {len(df)}")
    print(f"Threshold for filtering: {threshold}")

    # Filter molecules based on threshold
    filtered_df = df[df['Score'] > threshold]
    print(f"Molecules with Score > {threshold}: {len(filtered_df)}")

    # SDF properties are read back as strings, so compare IDs as strings;
    # otherwise numeric IDs parsed by pandas would never match. A set gives
    # constant-time membership tests inside the SDF loop.
    ligand_ids = set(filtered_df['Ligand ID'].dropna().astype(str))
    print(f"Number of unique ligand IDs to extract: {len(ligand_ids)}")

    # Read the SDF file
    suppl = Chem.SDMolSupplier(sdf_name)

    # Collect the molecules whose ID property is among the selected IDs
    filtered_mols = []
    found_ids = set()
    for mol in suppl:
        # The supplier yields None for records it could not parse; records
        # without the ID property cannot be matched either.
        if mol is None or not mol.HasProp(name_in_sdf):
            continue
        mol_id = mol.GetProp(name_in_sdf)
        if mol_id in ligand_ids:
            filtered_mols.append(mol)
            found_ids.add(mol_id)

    print(f"Extracted {len(filtered_mols)} molecules from SDF file")

    # Make a gap between selected IDs and extracted molecules visible.
    unmatched_ids = ligand_ids - found_ids
    if unmatched_ids:
        print(
            f"Warning: {len(unmatched_ids)} selected ligand IDs "
            f"were not found in {sdf_name}"
        )

    # Write filtered molecules to a new SDF file; always close the writer so
    # the output is flushed even if writing one molecule fails.
    writer = Chem.SDWriter(sdf_out_name)
    try:
        for mol in filtered_mols:
            writer.write(mol)
    finally:
        writer.close()
    print(f"Filtered molecules written to {sdf_out_name}")

    # Create distribution plot of the Score column; release the figure even
    # if plotting or saving raises.
    fig = plt.figure(figsize=(10, 6))
    try:
        # Plot the full distribution
        sns.histplot(df['Score'], kde=True, color='blue', alpha=0.5, label='All molecules')

        # Highlight the filtered portion
        sns.histplot(filtered_df['Score'], kde=True, color='red', alpha=0.5, label='Filtered molecules')

        # Add a vertical line at the threshold
        plt.axvline(x=threshold, color='black', linestyle='--', label=f'Threshold = {threshold}')

        plt.title('Distribution of Score Values')
        plt.xlabel('Score')
        plt.ylabel('Frequency')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Save the plot
        plt.savefig(figure_out_name, dpi=300)
    finally:
        plt.close(fig)
    print(f"Score distribution plot saved as '{figure_out_name}'")

    return filtered_df, filtered_mols

# Run the processing function with the provided parameters
if __name__ == "__main__":
    # Arguments are passed positionally, in the order of the
    # process_molecules signature, using the settings defined above.
    filtered_df, filtered_mols = process_molecules(
        inp_csv,
        sdf_name,
        name_in_sdf,
        threshold,
        sdf_out_name,
        figure_out_name,
    )

    # Descriptive statistics for the scores of the rows that passed the filter
    print("\nSummary statistics of the Score column:")
    print(filtered_df['Score'].describe())

Total molecules in CSV: 8705
Threshold for filtering: 0.5
Molecules with Score > 0.5: 566
Number of unique ligand IDs to extract: 534
Extracted 484 molecules from SDF file
Filtered molecules written to 5HT1A_filter_final.sdf
Score distribution plot saved as '5HT1A_filter_final.png'

Summary statistics of the Score column:
count    566.000000
mean       0.841664
std        0.151785
min        0.502410
25%        0.726271
50%        0.891912
75%        0.981411
max        0.999977
Name: Score, dtype: float64
