In [1]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
import torch.optim as optim
from sklearn.cluster import Birch
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import pairwise_distances
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
import esm
from descriptastorus.descriptors import rdNormalizedDescriptors
import warnings
warnings.filterwarnings('ignore')


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Read the joined ChEMBL/UniProt table and drop every row that lacks either a
# SMILES string or a protein sequence.
df_complete = (
    pd.read_csv("datasets/chembl_uniprot_joined.tsv", sep="\t")
    .dropna(subset=['Smiles', 'Sequence'])
)
print(f"Number of complete entries: {len(df_complete)}")

Number of complete entries: 2724335


In [3]:
import torch
from esm.models.esmc import ESMC
from esm.sdk.api import (
    ESM3InferenceClient,
    ESMProtein,
    LogitsConfig,
    LogitsOutput,
    ProteinType,
)
# Cell 3: ESM-C protein embeddings
# (this cell only embeds proteins; molecular descriptors are computed in later cells)

# Ask the logits call to also return embeddings; hidden states are not requested.
EMBEDDING_CONFIG = LogitsConfig(
    sequence=True, return_embeddings=True, return_hidden_states=False
)

print("Setting up ESM-C model...")
# Use the GPU when one is visible, otherwise fall back to the CPU instead of
# crashing inside `.to("cuda")`.
EMBED_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
client = ESMC.from_pretrained("esmc_600m").to(EMBED_DEVICE)
def embed_sequence(model: ESM3InferenceClient, sequence: str) -> np.ndarray:
    """Embed one protein sequence and mean-pool it into a single vector.

    Args:
        model: Object exposing ``encode`` and ``logits``; the notebook passes
            the ESM-C model loaded above.
        sequence: Protein sequence as a plain string.

    Returns:
        A NumPy array on the CPU holding the mean of ``output.embeddings``
        over dimension 1 (the original annotation claimed ``torch.Tensor``,
        but ``.numpy()`` is what is actually returned).
    """
    protein = ESMProtein(sequence=sequence)
    protein_tensor = model.encode(protein)
    output = model.logits(protein_tensor, EMBEDDING_CONFIG)
    # output.embeddings: shape [1, seq_len, 1152]
    # NOTE(review): the mean runs over every position the model returns;
    # presumably that includes any special tokens added by `encode` — confirm
    # whether they should be excluded from the pooling.
    pooled = output.embeddings.mean(dim=1).squeeze(0)  # shape [1152]
    return pooled.detach().cpu().numpy()

# Get unique proteins
proteins_unique = df_complete['Sequence'].unique()
print(f"Unique proteins to embed: {len(proteins_unique):,}")

# Announce the work before it starts (this message used to be printed only
# after the loop had already finished).
print("Generating protein embeddings...")

protein_embed_dict = {}
failed_sequences = []

# `proteins_unique` has no duplicates and the dict starts empty, so no
# "already embedded" check is needed inside the loop.
for sequence in tqdm(proteins_unique, total=len(proteins_unique), desc="Embedding proteins"):
    try:
        protein_embed_dict[sequence] = embed_sequence(client, sequence)
    except Exception as e:
        # Best effort: report the failure and keep going. The sequence will be
        # missing from `protein_embed_dict`, so downstream code must tolerate that.
        failed_sequences.append(sequence)
        print(f"Error embedding {sequence}: {e}")

print(
    f"Embedded {len(protein_embed_dict):,} of {len(proteins_unique):,} proteins "
    f"({len(failed_sequences):,} failed)"
)



Setting up ESM-C model...


Fetching 4 files: 100%|██████████| 4/4 [00:00<00:00, 3862.16it/s]


Unique proteins to embed: 6,336


Embedding proteins: 100%|██████████| 6336/6336 [02:01<00:00, 52.10it/s]

Generating protein embeddings...





In [4]:
# Convert embeddings dict to a pickle file
import pickle
with open("datasets/protein_chembl_embeddings.pkl", "wb") as embeddings_file:
    # Persist the sequence -> embedding dictionary so later sessions can skip
    # the embedding loop.
    pickle.dump(protein_embed_dict, embeddings_file)

In [4]:
import pickle
# Read the pickled sequence -> embedding dictionary back into memory.
# pickle.load can execute arbitrary code, so only load files you produced yourself.
with open("datasets/protein_chembl_embeddings.pkl", "rb") as embeddings_file:
    protein_embed_dict = pickle.load(embeddings_file)

# Cell result: how many embeddings were loaded.
len(protein_embed_dict)

6336

In [None]:
# Cell 4: RDKit molecular descriptors (Parallelized)
from multiprocessing import Pool, cpu_count
from functools import partial
import pickle

def process_smiles_batch(smiles_batch):
    """Compute normalized RDKit 2D descriptors for every SMILES in a batch.

    One descriptor generator is created per call and reused for the whole
    batch. SMILES for which the generator returns None, returns at most one
    value, or raises are left out of the result (exceptions are printed).

    Args:
        smiles_batch: Iterable of SMILES strings.

    Returns:
        dict mapping each successfully processed SMILES to a float32 array
        built from the generator output without its first element.
    """
    generator = rdNormalizedDescriptors.RDKit2DNormalized()
    collected = {}

    for smiles in smiles_batch:
        try:
            row = generator.process(smiles=smiles)
            if row is None or len(row) <= 1:
                continue
            # The first element is dropped — presumably a status flag rather
            # than a descriptor; confirm against the descriptastorus docs.
            collected[smiles] = np.array(row[1:], dtype=np.float32)
        except Exception as e:
            print(f"Error processing {smiles}: {e}")

    return collected

print("Computing molecular descriptors with multiprocessing...")

# Get unique molecules
unique_molecules = df_complete['Smiles'].unique()
print(f"Unique molecules: {len(unique_molecules):,}")

# Ask for 25 worker processes, but never more than the machine actually has
# (the old comment said 20 while the code used 25).
n_cores = min(25, cpu_count())
print(f"Using {n_cores} cores for parallel processing")

# Split molecules into contiguous batches, roughly one per worker
batch_size = len(unique_molecules) // n_cores + 1
molecule_batches = [
    unique_molecules[i:i + batch_size]
    for i in range(0, len(unique_molecules), batch_size)
]

print(f"Split into {len(molecule_batches)} batches of ~{batch_size} molecules each")

# Process batches in parallel and merge each result as soon as it arrives,
# so the per-batch dictionaries are not all held in a list at once.
molecule_features = {}
with Pool(n_cores) as pool:
    # imap yields results in submission order; tqdm advances once per finished batch
    for batch_result in tqdm(
        pool.imap(process_smiles_batch, molecule_batches),
        total=len(molecule_batches),
        desc="Processing batches",
    ):
        molecule_features.update(batch_result)

# Save the molecular features to a pickle file
with open("datasets/rdkit_chembl_descriptors.pkl", "wb") as f:
    pickle.dump(molecule_features, f)

print(f"Generated features for {len(molecule_features)} molecules")


Computing molecular descriptors...
Unique molecules: 1,017,257


RDKit descriptors:   3%|▎         | 34582/1017257 [07:48<3:30:35, 77.77it/s] function application failed (fr_NH0->COc1ccc(C2C(C(N)=O)=C(C)Nc3nc(-c4cccc(Cl)c4)nn32)cc1OC)
Traceback (most recent call last):
  File "/home/nroethler/miniconda3/envs/gen_ca/lib/python3.10/site-packages/descriptastorus/descriptors/rdDescriptors.py", line 432, in applyFunc
    return functions[name](m)
  File "/home/nroethler/miniconda3/envs/gen_ca/lib/python3.10/site-packages/rdkit/Chem/Fragments.py", line 46, in <lambda>
    fn = lambda mol, countUnique=True, pattern=patt: _CountMatches(
  File "/home/nroethler/miniconda3/envs/gen_ca/lib/python3.10/site-packages/rdkit/Chem/Fragments.py", line 24, in _CountMatches
    return len(mol.GetSubstructMatches(patt, uniquify=unique))
  File "/home/nroethler/miniconda3/envs/gen_ca/lib/python3.10/site-packages/rdkit/Chem/Draw/IPythonConsole.py", line 304, in _GetSubstructMatches
    res = mol.__GetSubstructMatches(query, *args, **kwargs)
KeyboardInterrupt
Could not com

In [5]:
# Alternative: Using joblib for even better performance with RDKit
from joblib import Parallel, delayed

def process_single_smiles(smiles):
    """Compute normalized RDKit 2D descriptors for one SMILES string.

    Args:
        smiles: SMILES string to featurize.

    Returns:
        Tuple ``(smiles, features)`` where ``features`` is a float32 array
        built from the generator output without its first element, or None
        when the generator returned None, returned at most one value, or an
        exception was raised while processing (the exception is printed).
    """
    # NOTE(review): a new generator is constructed on every call, i.e. once
    # per molecule; reusing one per worker may be cheaper — confirm.
    generator = rdNormalizedDescriptors.RDKit2DNormalized()
    features = None
    try:
        row = generator.process(smiles=smiles)
        usable = row is not None and len(row) > 1
        if usable:
            features = np.array(row[1:], dtype=np.float32)
    except Exception as e:
        print(f"Error processing {smiles}: {e}")
    return smiles, features

print("Computing molecular descriptors with joblib...")

# Collect the distinct SMILES strings so each molecule is featurized once
unique_molecules = df_complete['Smiles'].unique()
print(f"Unique molecules: {len(unique_molecules):,}")

# Fan the per-molecule work out over 25 joblib workers, 100 tasks per dispatch
parallel = Parallel(n_jobs=25, verbose=1, batch_size=100)
results = parallel(
    delayed(process_single_smiles)(smiles) for smiles in unique_molecules
)

# Keep only the (smiles, features) pairs whose features were computed
molecule_features_joblib = dict(
    pair for pair in results if pair[1] is not None
)

# Persist the descriptor dictionary for later sessions
with open("datasets/rdkit_chembl_descriptors_joblib.pkl", "wb") as descriptors_file:
    pickle.dump(molecule_features_joblib, descriptors_file)

print(f"Generated features for {len(molecule_features_joblib)} molecules using joblib")

Computing molecular descriptors with joblib...
Unique molecules: 1,017,257
Unique molecules: 1,017,257


[Parallel(n_jobs=25)]: Using backend LokyBackend with 25 concurrent workers.
[Parallel(n_jobs=25)]: Done  50 tasks      | elapsed:    2.3s
[Parallel(n_jobs=25)]: Done  50 tasks      | elapsed:    2.3s
[Parallel(n_jobs=25)]: Done 15050 tasks      | elapsed:   17.1s
[Parallel(n_jobs=25)]: Done 15050 tasks      | elapsed:   17.1s
[Parallel(n_jobs=25)]: Done 40050 tasks      | elapsed:   39.5s
[Parallel(n_jobs=25)]: Done 40050 tasks      | elapsed:   39.5s
[Parallel(n_jobs=25)]: Done 75050 tasks      | elapsed:  1.2min
[Parallel(n_jobs=25)]: Done 75050 tasks      | elapsed:  1.2min
  return asanyarray(a).trace(offset=offset, axis1=axis1, axis2=axis2, dtype=dtype, out=out)
  Bn = An - res[n] * I
  return asanyarray(a).trace(offset=offset, axis1=axis1, axis2=axis2, dtype=dtype, out=out)
  Bn = An - res[n] * I
[Parallel(n_jobs=25)]: Done 120050 tasks      | elapsed:  1.9min
[Parallel(n_jobs=25)]: Done 120050 tasks      | elapsed:  1.9min
[Parallel(n_jobs=25)]: Done 175050 tasks      | elapsed

Generated features for 1017257 molecules using joblib


In [18]:
import numpy as np

def is_finite_vector(vec):
    """Return a truthy value when `vec` is not None and all its entries are finite."""
    return vec is not None and np.all(np.isfinite(vec))


def _concat_features(sequence, smiles):
    """Return protein embedding + molecule descriptors as one vector, or None if either is missing."""
    protein_vec = protein_embed_dict.get(sequence)
    molecule_vec = molecule_features_joblib.get(smiles)
    if protein_vec is None or molecule_vec is None:
        return None
    return np.concatenate((protein_vec, molecule_vec))


# Build the per-row feature vector. Sequences whose embedding failed are absent
# from protein_embed_dict, so a direct dict lookup would raise KeyError here;
# such rows get None and are removed by valid_mask below.
df_complete['concatenated_features'] = pd.Series(
    [
        _concat_features(sequence, smiles)
        for sequence, smiles in zip(df_complete['Sequence'], df_complete['Smiles'])
    ],
    index=df_complete.index,
    dtype=object,
)

# Filter df_complete to only rows with valid protein and molecule features.
# Every unique sequence / SMILES is checked for finiteness once, then rows are
# matched with vectorised membership tests instead of a row-wise apply.
finite_sequences = {
    sequence for sequence, vec in protein_embed_dict.items() if is_finite_vector(vec)
}
finite_smiles = {
    smiles for smiles, vec in molecule_features_joblib.items() if is_finite_vector(vec)
}
valid_mask = (
    df_complete['Sequence'].isin(finite_sequences)
    & df_complete['Smiles'].isin(finite_smiles)
)

print(f"Keeping {valid_mask.sum()} of {len(df_complete)} rows with valid, finite features.")
features = np.stack(
    df_complete.loc[valid_mask, 'concatenated_features'].to_numpy()
).astype('float32', copy=False)
features = np.ascontiguousarray(features)
# tofile writes the raw float32 buffer with no header, so despite the ".npy"
# extension this file cannot be opened with np.load; read it back with
# np.memmap / np.fromfile and an explicit dtype and shape.
features.tofile("./datasets/chembl_features.npy")

Keeping 2721823 of 2724335 rows with valid, finite features.


In [20]:
# Show the (n_rows, n_columns) shape of the stacked feature matrix.
features.shape

(2721823, 1352)

In [29]:
import os

# The feature file is a headerless float32 dump, so its shape is not stored in
# the file. Derive the row count from the file size instead of hard-coding it,
# and refuse to map a file whose size does not fit the expected row width.
FEATURES_PATH = "./datasets/chembl_features.npy"
FEATURE_DIM = 1352  # width of one concatenated protein + molecule feature row
_row_bytes = FEATURE_DIM * np.dtype('float32').itemsize
_file_bytes = os.path.getsize(FEATURES_PATH)
if _file_bytes == 0 or _file_bytes % _row_bytes != 0:
    raise ValueError(
        f"{FEATURES_PATH} is {_file_bytes} bytes, which is not a positive "
        f"multiple of {_row_bytes} bytes per row"
    )
features = np.memmap(
    FEATURES_PATH,
    dtype='float32',
    mode='r',
    shape=(_file_bytes // _row_bytes, FEATURE_DIM),
)

In [22]:
# Display the feature matrix; NumPy abbreviates the repr of large arrays.
features

memmap([[ 1.5787469e-03,  2.5417006e-03, -7.9095270e-03, ...,
          4.7035982e-08,  1.6663340e-01,  7.5553104e-02],
        [ 3.0751788e-05, -6.3050087e-03, -1.6719768e-02, ...,
          4.7035982e-08,  1.6663340e-01,  8.3625734e-01],
        [-3.3837259e-03,  2.1120692e-04, -1.0468062e-02, ...,
          4.7035982e-08,  1.6663340e-01,  7.0839033e-02],
        ...,
        [ 7.4451244e-03,  5.9425263e-03, -8.3871325e-03, ...,
          4.7035982e-08,  1.6663340e-01,  3.3313647e-01],
        [-3.6006721e-03, -1.9245783e-02, -1.9882960e-02, ...,
          4.7035982e-08,  1.6663340e-01,  3.3695695e-01],
        [ 2.6436313e-03,  3.0820039e-03, -1.6944237e-02, ...,
          4.7035982e-08,  1.6663340e-01,  4.6162996e-01]], dtype=float32)

In [12]:
import numpy as np

def is_finite_vector(vec):
    """Tell whether `vec` is present and free of NaN / infinite entries.

    Returns False for None; otherwise the result of ``np.all(np.isfinite(vec))``.
    """
    if vec is None:
        return False
    return np.all(np.isfinite(vec))

# Filter df_complete to only rows with valid protein and molecule features.
# Each unique sequence / SMILES is checked for finiteness once, and rows are
# then matched with vectorised membership tests; this yields the same mask as
# evaluating the four conditions row by row, without a Python call per row.
finite_sequences = {
    sequence for sequence, vec in protein_embed_dict.items() if is_finite_vector(vec)
}
finite_smiles = {
    smiles for smiles, vec in molecule_features_joblib.items() if is_finite_vector(vec)
}
valid_mask = (
    df_complete['Sequence'].isin(finite_sequences)
    & df_complete['Smiles'].isin(finite_smiles)
)

print(f"Keeping {valid_mask.sum()} of {len(df_complete)} rows with valid, finite features.")

Keeping 2721823 of 2724335 rows with valid, finite features.


In [14]:
# Inspect the rows rejected by valid_mask.
df_complete.loc[~valid_mask]

Unnamed: 0,Molecule ChEMBL ID,Molecule Name,Molecule Max Phase,Molecular Weight,#RO5 Violations,AlogP,Compound Key,Smiles,Standard Type,Standard Relation,...,Document Journal,Document Year,Cell ChEMBL ID,Properties,Action Type,Standard Text Value,Value,Entry,Sequence,concatenated_features
1616,CHEMBL51085,EBSELEN,3.0,274.18,,,SID856002,O=c1c2ccccc2[se]n1-c1ccccc1,Potency,'=',...,,,,,,,12.5893,P19838,MAEDDPYLGRPEQMFHLDPSLTHTIFNPEVFQPQMALPTDGPYLQI...,"[0.0023406143, -0.0050573843, -0.011684245, -0..."
2185,CHEMBL1592721,,,428.09,,,SID29215799,O=[N+]([O-])c1ccc(-c2[se]c3ccccc3c2I)cc1,Potency,'=',...,,,,,,,8.9125,P10636,MAEPRQEFEVMEDHAGTYGLGDRKDQGGYTMHQDQEGDTDAGLKES...,"[0.0061808834, 0.006355596, -0.009214631, -0.0..."
7534,CHEMBL1091971,,,416.09,,,SID50108395,O=[As](O)(O)c1ccc(Cc2ccc([As](=O)(O)O)cc2)cc1,Potency,'=',...,,,,,,,2.2387,P00352,MSSSGTPDLPVLLTDLKIQYTKIFINNEWHDSVSGKKFPVFNPATE...,"[-0.0033837259, 0.00021120692, -0.010468062, -..."
9883,CHEMBL4868581,,,404.38,,,5c,NS(=O)(=O)c1ccc([Se]CC(O)CSc2ncccn2)cc1,Ki,'=',...,Eur J Med Chem,2021.0,,TIME = 0.25 hr,INHIBITOR,,71.1000,P22748,MRMLLALLALSAARPSASAESHWCYEVQAESSNYPCLVPVKWGGNC...,"[-0.0009853028, 0.022307659, -0.0116948085, -0..."
9905,CHEMBL4868581,,,404.38,,,5c,NS(=O)(=O)c1ccc([Se]CC(O)CSc2ncccn2)cc1,Ki,'=',...,Eur J Med Chem,2021.0,,TIME = 0.25 hr,INHIBITOR,,31.0000,Q16790,MAPLCPSPWLPLLIPAPAPGLTVQLLLSLLLLVPVHPQRLPRMQED...,"[-0.006559426, 0.011227086, -0.021629887, -0.0..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2724563,CHEMBL4482933,,,424.98,,,5d,NS(=O)(=O)c1ccc(C[Te]c2ccc3ccccc3c2)cc1,Ki,'=',...,J Med Chem,2020.0,,,,,182.3000,P00918,MSHHWGYGKHNGPEHWHKDFPIAKGERQSPVDIDTHTAKYDPSLKP...,"[-0.01850301, 0.016053043, -0.0134952795, 0.00..."
2725590,CHEMBL1321154,ROXARSONE,,263.04,,,Roxarsone,O=[N+]([O-])c1cc([As](=O)(O)O)ccc1O,IC50,'=',...,ACS Med Chem Lett,2022.0,,TIME = 0.5 hr,INHIBITOR,,80.0000,Q9NRX4,MAVADLALIPDVDIDSDGVFKYVLIRVHSAPRSGAPAAESKEIVRG...,"[-0.0047203316, 0.0013114266, -0.00097053277, ..."
2726774,CHEMBL51085,EBSELEN,3.0,274.18,,,Ebselen,O=c1c2ccccc2[se]n1-c1ccccc1,IC50,'=',...,J Med Chem,2023.0,,TIME = 0.4167 hr,INHIBITOR,,0.3000,P04839,MGNWAVNEGLSIFVILVWLGLNVFLFVWYYRVYDIPPKFFYTRKLL...,"[-0.0012299576, 0.008150385, -0.007632038, -0...."
2727081,CHEMBL5193976,,,326.62,,,51,O=c1c2ccccc2[se]n1-c1cc(Cl)ccc1F,IC50,'=',...,J Med Chem,2022.0,,,INHIBITOR,,15.2400,P0DTD1,MESLVPGFNEKTHVQLSLPVLQVRDVLVRGFGDSVEEVLSEARQHL...,"[-0.002904211, 0.020237833, 0.00031897248, 0.0..."


In [26]:
# Independent copy of the rows accepted by valid_mask, so columns added later
# do not touch df_complete.
df_valid = df_complete.loc[valid_mask].copy()

In [30]:
# Cell: Clustering with faiss-gpu
import faiss
import numpy as np
import gc

from sympy import true
batch_size = 544_867  # rows sent to the GPU index per assignment query
n_points, d = features.shape
print(f"Number of points: {n_points}, Dimension: {d}")

# Labels are attached to df_valid by position, so the row counts must agree.
# Check this before the multi-minute training run rather than after it.
if len(df_valid) != n_points:
    raise ValueError(
        f"df_valid has {len(df_valid)} rows but the feature matrix has {n_points}"
    )

print("Clustering with faiss-gpu...")
# Set number of clusters
n_clusters = 100  # Between 10-300 clusters
print(f"Using {n_clusters} clusters")

# Initialize faiss KMeans (GPU). faiss.Kmeans(gpu=True) manages its own GPU
# resources, so no StandardGpuResources object is created here.
ngpu = faiss.get_num_gpus()
print(f"Visible GPUs: {ngpu}")
# NOTE(review): spherical=True is used on features that are not L2-normalised
# or standardised in this notebook, and the recorded run reports objective=0
# with imbalance=100 (almost every row in one cluster). Confirm the intended
# preprocessing and metric before relying on these labels.
kmeans = faiss.Kmeans(
    d=d,
    k=n_clusters,
    gpu=True,
    niter=255,
    verbose=True,
    seed=5,
    max_points_per_centroid=300000,
    spherical=True,
)

# Train KMeans
kmeans.train(features)

# Put the trained centroids into a flat L2 index and replicate it on the GPUs
cpu_centroids = faiss.IndexFlatL2(d)
cpu_centroids.add(kmeans.centroids)
gpu_centroids = faiss.index_cpu_to_all_gpus(cpu_centroids)

# Assign clusters batch by batch so only one slice of the features is
# materialised for the GPU at a time
labels = np.empty(n_points, dtype=np.int32)
for start in range(0, n_points, batch_size):
    end = min(start + batch_size, n_points)
    batch = np.ascontiguousarray(features[start:end])
    _, I = gpu_centroids.search(batch, 1)
    labels[start:end] = I.ravel()

df_valid['faiss_cluster'] = labels
print("Cluster assignment complete. Cluster counts:")
unique, counts = np.unique(labels, return_counts=True)
for u, c in zip(unique, counts):
    print(f"  Cluster {u}: {c} samples")
# Make a collapsed clustering visible at a glance
print(f"Largest cluster holds {counts.max() / n_points:.1%} of all samples")

# Release the trainer. `features` is kept on purpose so this cell (and later
# cells) can be re-run without reloading it.
del kmeans
gc.collect()

Number of points: 2721823, Dimension: 1352
Clustering with faiss-gpu...
Using 100 clusters
Clustering 2721823 points in 1352D to 100 clusters, redo 1 times, 255 iterations
  Preprocessing in 1.35 s
Clustering 2721823 points in 1352D to 100 clusters, redo 1 times, 255 iterations
  Preprocessing in 1.35 s
  Iteration 254 (260.49 s, search 141.10 s): objective=0 imbalance=100.000 nsplit=99       
  Iteration 254 (260.49 s, search 141.10 s): objective=0 imbalance=100.000 nsplit=99       
Cluster assignment complete. Cluster counts:
  Cluster 0: 570 samples
  Cluster 1: 344 samples
  Cluster 2: 116 samples
  Cluster 3: 571 samples
  Cluster 4: 48 samples
  Cluster 5: 24 samples
  Cluster 6: 250 samples
  Cluster 7: 368 samples
  Cluster 8: 319 samples
  Cluster 9: 46 samples
  Cluster 10: 24 samples
  Cluster 11: 321 samples
  Cluster 12: 273 samples
  Cluster 13: 250 samples
  Cluster 14: 250 samples
  Cluster 15: 321 samples
  Cluster 16: 343 samples
  Cluster 17: 23 samples
  Cluster 18:

8

In [27]:
# Attach the cluster label of each feature row to the matching df_valid row
# (assignment is positional) and report how many rows fell into each cluster.
df_valid['faiss_cluster'] = labels
print("Cluster assignment complete. Cluster counts:")
unique, counts = np.unique(labels, return_counts=True)
for u, c in zip(unique, counts):
    print(f"  Cluster {u}: {c} samples")

# Clean up GPU memory. Names that were already removed (for example by the
# clustering cell) are skipped instead of raising NameError, so this cell can
# be run more than once and in either order.
for _stale_name in ("features", "kmeans", "res"):
    globals().pop(_stale_name, None)
gc.collect()

Cluster assignment complete. Cluster counts:
  Cluster 5: 24 samples
  Cluster 13: 2713655 samples
  Cluster 31: 23 samples
  Cluster 35: 23 samples
  Cluster 38: 23 samples
  Cluster 44: 23 samples
  Cluster 49: 24 samples
  Cluster 55: 23 samples
  Cluster 70: 23 samples
  Cluster 75: 24 samples
  Cluster 78: 24 samples
  Cluster 83: 23 samples
  Cluster 85: 23 samples
  Cluster 88: 23 samples
  Cluster 101: 250 samples
  Cluster 102: 23 samples
  Cluster 104: 48 samples
  Cluster 107: 23 samples
  Cluster 112: 23 samples
  Cluster 113: 546 samples
  Cluster 116: 23 samples
  Cluster 118: 250 samples
  Cluster 121: 23 samples
  Cluster 125: 23 samples
  Cluster 128: 23 samples
  Cluster 129: 23 samples
  Cluster 137: 250 samples
  Cluster 138: 23 samples
  Cluster 139: 23 samples
  Cluster 143: 47 samples
  Cluster 144: 250 samples
  Cluster 145: 273 samples
  Cluster 154: 250 samples
  Cluster 155: 273 samples
  Cluster 158: 24 samples
  Cluster 159: 24 samples
  Cluster 160: 46 sam

1664

In [33]:
from cuml.manifold.umap import UMAP


ImportError: /home/nroethler/miniconda3/envs/gen_ca/lib/python3.10/site-packages/cuml/cluster/../../../../libcusolver.so.11: undefined symbol: cublasSetEnvironmentMode, version libcublas.so.12