In [None]:
import os
import re

import mols2grid
import natsort
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from rdkit.Chem import AllChem
from rdkit.Chem import MolFromSmiles
from rdkit.Chem import PandasTools
from rdkit.Chem import SDMolSupplier
from rdkit.DataStructs import TanimotoSimilarity

In [None]:
# Group molecule names by cluster and find each cluster's center molecule

def get_cluster_dicts(df):
    """Group molecule names by cluster and locate each cluster's center.

    Parameters
    ----------
    df : pandas.DataFrame
        Indexed by molecule name, with a ``'Cluster'`` column (cluster label)
        and a ``'Center'`` column in which the value ``'Yes'`` flags a center.

    Returns
    -------
    cluster_dict : dict
        Cluster label -> list of member names in row order, each name listed
        at most once per cluster.
    cluster_centers : dict
        Cluster label -> name of the row flagged ``'Yes'`` in ``'Center'``
        (if several rows of a cluster are flagged, the last one wins).
    """
    members_by_cluster = {}
    center_by_cluster = {}
    for name, label, center_flag in zip(df.index, df['Cluster'], df['Center']):
        members = members_by_cluster.setdefault(label, [])
        # The index may contain repeated names; keep only the first occurrence
        if name not in members:
            members.append(name)
        if center_flag == 'Yes':
            center_by_cluster[label] = name
    return members_by_cluster, center_by_cluster

# Function for getting the top-scoring hit in each cluster

def get_clusters_top_hits(cluster_dict):
    """Return the first-listed member of every cluster, without repeats.

    Parameters
    ----------
    cluster_dict : dict
        Cluster label -> ordered list of member names, as built by
        ``get_cluster_dicts``. Rows are presumably ordered best-scoring first,
        which makes the first member the cluster's top hit.

    Returns
    -------
    list
        One name per cluster, in cluster order, with duplicates dropped.
    """
    # dict.fromkeys keeps insertion order while discarding repeated names
    leaders = (members[0] for members in cluster_dict.values())
    return list(dict.fromkeys(leaders))

# Function for getting up to the top 5-scoring hits in each cluster

def get_clusters_top5_hits(cluster_dict, n=5):
    """Collect up to the first ``n`` members (default 5) of every cluster.

    Parameters
    ----------
    cluster_dict : dict
        Cluster label -> ordered list of member names, as built by
        ``get_cluster_dicts`` (presumably best-scoring first).
    n : int, optional
        Maximum number of members taken from each cluster. Defaults to 5.

    Returns
    -------
    list
        Names in cluster order. A name already collected from an earlier
        cluster is not added again.
    """
    top_hits = []
    for members in cluster_dict.values():
        # Slicing covers clusters smaller than n without a separate branch
        for name in members[:n]:
            if name not in top_hits:
                top_hits.append(name)
    return top_hits

# Get average intra-cluster similarity
# Calculates one average for each clustering level using all Tanimoto comparisons

def get_avg_similarity(df):
    """Average each molecule's similarity score to its own cluster center.

    Duplicate index entries are dropped (first kept). For every cluster with a
    row flagged as center, each member (the center included) is compared with
    the center using Morgan fingerprints (radius 2). The per-pair score is
    ``1 / (1 + (1 - T))``, where ``T`` is the Tanimoto similarity. For T in
    [0, 1] this score lies in [0.5, 1]. Clusters without a center are skipped.

    Parameters
    ----------
    df : pandas.DataFrame
        Indexed by molecule name, with ``'Cluster'``, ``'Center'`` and
        ``'SMILES'`` columns.

    Returns
    -------
    float
        Mean score over all comparisons. ``np.mean`` of an empty list gives
        ``nan`` (with a RuntimeWarning) when no cluster has a center.
    """
    unique_df = df[~df.index.duplicated(keep='first')]
    fp_generator = AllChem.GetMorganGenerator(radius=2)
    members_by_cluster, center_by_cluster = get_cluster_dicts(unique_df)

    def fingerprint(name):
        return fp_generator.GetFingerprint(MolFromSmiles(unique_df['SMILES'][name]))

    scores = []
    for cluster, center in center_by_cluster.items():
        center_fp = fingerprint(center)
        for member in members_by_cluster[cluster]:
            distance = 1 - TanimotoSimilarity(center_fp, fingerprint(member))
            scores.append(1 / (1 + distance))

    avg_similarity = np.mean(scores)

    print(f"Average Similarity to Cluster Center = {avg_similarity}")
    print()

    return avg_similarity

In [None]:
# Scan the working directory for clustering CSVs. For each one, record its cluster
# count and average intra-cluster similarity (used for plotting later).
# The cluster count is read from the leading digits of the second "_"-separated
# field of the file name (presumably named like "<prefix>_<N>clustered....csv").

numbers_of_clusters = []
avg_similarities = []

print("Beginning Calculation of Cluster Similarities over Numbers of Clusters...")
print("----------------------------------------------------------------------")
print()
for fn in natsort.natsorted(os.listdir()):

    if not fn.endswith(".csv"):
        continue

    ## Get info on number of clusters for each .csv
    # Match leading digits rather than use str.rstrip("clustered"). rstrip strips a
    # *character set*, not a suffix, so e.g. "10clustered.csv" would be left untouched.
    fn_info = fn.split("_")
    count_match = re.match(r"\d+", fn_info[1]) if len(fn_info) > 1 else None
    if count_match is None:
        # An unrelated CSV in the directory: skip it instead of crashing
        print(f"Skipping {fn}: no cluster count found in file name")
        continue
    num = int(count_match.group())  # int so the plot gets a numeric x-axis

    print("-----------------------------")
    print(f"# of Clusters: {num}")
    print("-----------------------------")

    df = pd.read_csv(fn, index_col="Name")

    ### Get average Tanimoto similarity for the clustering threshold.
    ### Append both values only after the calculation succeeds so the lists stay aligned.
    avg_similarity = get_avg_similarity(df)
    numbers_of_clusters.append(num)
    avg_similarities.append(avg_similarity)

print("Completion of Average Similarity Calculations!!!")

In [None]:
# Plot average intra-cluster similarity against cluster count.
# Use explicit Figure/Axes objects, and end with plt.show() so the cell does not
# echo the repr of the last label call.
fig, ax = plt.subplots()
ax.plot(numbers_of_clusters, avg_similarities, linewidth=3)
ax.set_title("Intra-Cluster Similarity with Cluster Count")
ax.set_xlabel("Cluster Count")
ax.set_ylabel("Average Similarity to Cluster Center")
plt.show()

In [None]:
# Tabulate the average cluster similarity for each cluster count (every row shown)
counts = (
    pd.DataFrame({"Cluster Count": numbers_of_clusters, "Avg. Similarity": avg_similarities})
    .set_index("Cluster Count")
)
with pd.option_context('display.max_rows', None):
    print(counts)

In [None]:
# Load the clustering run chosen as optimal (replace the placeholder path below).
# Collect the first-listed member of each cluster as that cluster's top hit.
OPT_CLUSTERS_CSV = "{CSV_with_chosen_optimal_clusters}.csv"
opt_df = pd.read_csv(OPT_CLUSTERS_CSV, index_col="Name")
cluster_dict, cluster_centers = get_cluster_dicts(opt_df)
clusters_tophits = get_clusters_top_hits(cluster_dict)
len(clusters_tophits)

In [None]:
# Read in SDF used in clustering, keep one record per ID, select top hits,
# and output those to a new SDF.
# Note: after this cell, in_sdf is indexed by 'ID' (it is no longer a column).
in_sdf = PandasTools.LoadSDF("{original_SDF_used_to_perform_clustering}.sdf", removeHs=False)
in_sdf = in_sdf.set_index('ID')
# Build the duplicate mask from in_sdf itself so it aligns row-for-row with the frame
# being filtered. A mask from any other frame has a different length/order.
in_sdf = in_sdf[~in_sdf.index.duplicated(keep='first')]
slice_df = in_sdf.loc[in_sdf.index.isin(clusters_tophits)]
PandasTools.WriteSDF(slice_df, "tophits.sdf", properties=slice_df.columns)
print("Best of Clustered Molecules written to SDF!!!")

In [None]:
# Display 2D structures of the top hits written to tophits.sdf by the previous cell
mols2grid.display("tophits.sdf")

In [None]:
# This cell and beyond are used for extracting more molecules from the most interesting-looking clusters
# Happens AFTER visual inspection of tophits.sdf

# Input list of SDF molecule numbers (row positions in in_sdf, the SDF used for clustering)
# Must run cell #6 to run get_cluster_dicts and cell #7 to load "in_sdf"
# Obtains the clusters containing the compounds with these IDs
idxs = [21,28,31,49,117,194]
clusters = []
selected_df = in_sdf.iloc[idxs]
# Cell #7 moves 'ID' into the index via set_index('ID'). A freshly loaded SDF
# still carries it as a column, so handle both layouts.
if 'ID' in selected_df.columns:
    ids = selected_df['ID'].to_list()
else:
    ids = selected_df.index.to_list()
for ID in ids:
    for cluster, members in cluster_dict.items():
        if ID in members:
            clusters.append(cluster)

In [None]:
# Get the top n distinct hits from each of the chosen clusters
# Adjust N_TOP_HITS to change the max number of hits per cluster; default is 15.
# (The output file name below stays "top15_hits.sdf" because the display cell reads it.)
N_TOP_HITS = 15
IDs = []
# A cluster selected through several IDs is processed only once
for cluster in dict.fromkeys(clusters):
    members = opt_df.index[opt_df['Cluster'] == cluster]
    # Members stay in opt_df row order (presumably best-scoring first). Repeated names
    # are dropped before taking N_TOP_HITS so duplicate rows don't use up the quota.
    new_members = [name for name in dict.fromkeys(members) if name not in IDs]
    IDs.extend(new_members[:N_TOP_HITS])

# Read in SDF used in clustering, select top hits, and output those to a new SDF.
# Use a separate name so cell #7's 'ID'-indexed in_sdf is not overwritten.
full_sdf = PandasTools.LoadSDF("{original_SDF_used_to_perform_clustering}.sdf", removeHs=False)
full_sdf = full_sdf.drop_duplicates(subset='ID')
output_sdf = full_sdf.loc[full_sdf['ID'].isin(IDs)]
PandasTools.WriteSDF(output_sdf, "top15_hits.sdf", properties=output_sdf.columns)

In [None]:
# Display 2D structures of the top n hits written to top15_hits.sdf by the previous cell
mols2grid.display("top15_hits.sdf")