# Machine Learning Model for Binding Affinity using BindingDB


#### Install dependencies and import modules

In [1]:
# Install dependencies
!pip install -q torch fair-esm transformers aiondata pymilvus &> /dev/null

# Standard library imports
import hashlib
import os
from pathlib import Path
from typing import Iterable

# Third-party imports for numerical operations and machine learning
import joblib
import numpy as np
import torch
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType

# Third-party imports for deep learning and specific models
from transformers import AutoTokenizer, AutoModel

# Imports from specific libraries used in cheminformatics and bioinformatics
from rdkit import RDLogger, Chem
import esm
from aiondata import BindingAffinity

# Root directory of the aiondata cache; override it with the AIONDATA_CACHE environment variable
aiondata_path = Path(os.getenv("AIONDATA_CACHE", "~/.aiondata")).expanduser()

# Widths of the stored embedding vectors (ligand embeddings, protein embeddings)
ligand_dim, protein_dim = 768, 480

# Number of rows to sample for a quick run; set to None to use every example
test_only_n_examples = 2000

#### Load the previously saved model and open the embedding cache

In [2]:
# Load the pre-trained model if one was saved. When it is missing, bind `model` to None
# so the training cell can detect that a new model has to be fitted.
# NOTE(review): joblib.load unpickles the file; only load model files you trust.
try:
    model = joblib.load(aiondata_path / "models" / "binding_affinity.joblib")
except FileNotFoundError:
    model = None
    print("Model not found. Please train the model first.")

# Open (or create) the local PyMilvus database that caches embeddings
client = MilvusClient(str(aiondata_path / "embeddings.db"))

ligand_schema = CollectionSchema(fields=[
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
    FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=ligand_dim),
    # Standard InChIKeys are 27 characters long
    FieldSchema(name="inchikey", dtype=DataType.VARCHAR, max_length=27),
])
protein_schema = CollectionSchema(fields=[
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
    FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=protein_dim),
    # Protein sequences can be longer than 1,028 residues; use Milvus' maximum VARCHAR
    # length so inserting a long sequence does not fail.
    FieldSchema(name="sequence", dtype=DataType.VARCHAR, max_length=65535),
])

# A collection created from a custom schema is only indexed and loaded when index
# parameters are supplied. Without them, later `client.query` calls fail because the
# collection is not loaded.
for collection_name, collection_schema in (("ligands", ligand_schema), ("proteins", protein_schema)):
    if client.has_collection(collection_name):
        continue
    index_params = client.prepare_index_params()
    index_params.add_index(field_name="vector", index_type="AUTOINDEX", metric_type="COSINE")
    client.create_collection(collection_name, schema=collection_schema, index_params=index_params)


#### Load the BindingAffinity dataset (BindingDB) and prepare it for machine learning

In [3]:

# Load BindingDB into a Polars DataFrame
df = BindingAffinity().to_df()

# Drop rows that are missing any of the columns the model needs
df = df.drop_nulls(subset=["SMILES", "Sequence", "Binds"])

# For quick runs, work on a reproducible random subset. Cap the sample size at the
# number of available rows: Polars' `sample` draws without replacement by default and
# raises when `n` exceeds the frame height.
if test_only_n_examples:
    df = df.sample(n=min(test_only_n_examples, df.height), shuffle=True, seed=18)

# Model inputs (ligand SMILES, protein sequence) and the binary binding label
ligands = df["SMILES"]
target_sequence = df["Sequence"]
affinity = df["Binds"]

# Suppress RDKit warnings and errors (e.g. when parsing SMILES later on)
RDLogger.DisableLog("rdApp.*")

#### Create Protein Embeddings using ESM

In [4]:
# Load the ESM-2 model (12 layers, 35M parameters)
esm_model, esm_alphabet = esm.pretrained.esm2_t12_35M_UR50D()

# Switch to evaluation mode and move to the GPU when one is available
esm_model = esm_model.eval()
if torch.cuda.is_available():
    esm_model = esm_model.cuda()


def create_protein_embedding(sequence: str) -> np.ndarray:
    """Return a fixed-length embedding for one protein sequence, using the Milvus cache.

    The embedding is the mean, over token positions, of the 12th (last) ESM-2 layer's
    representations. Each new result is stored in the ``proteins`` collection under the
    exact sequence string, so a given sequence is embedded at most once per database.

    Args:
        sequence: Protein sequence string.

    Returns:
        1-D float32 NumPy array (expected length ``protein_dim``).
    """
    # Return the cached embedding if this exact sequence was embedded before
    cached = client.query(
        collection_name="proteins",
        filter=f"sequence == '{sequence}'",
        output_fields=["sequence", "vector"],
    )
    if cached and cached[0]["sequence"] == sequence:
        return np.asarray(cached[0]["vector"], dtype=np.float32)

    # Tokenize as a batch of one sequence
    # NOTE(review): unlike ESM's batch converter, `alphabet.encode` appears not to add
    # BOS/EOS tokens. This is left as-is because changing it would make new vectors
    # inconsistent with cached ones and with the saved model. Confirm before changing.
    tokens = torch.tensor([esm_alphabet.encode(sequence)])
    if torch.cuda.is_available():
        tokens = tokens.cuda()

    with torch.no_grad():
        results = esm_model(tokens, repr_layers=[12])  # request the last layer's representations

    # Drop the batch dimension, move to CPU, then average across sequence positions
    embedding = results["representations"][12].squeeze(0).cpu().mean(dim=0).numpy()

    # Derive a stable 63-bit primary key from the sequence. A counter that restarts at 0
    # every session would collide with rows already persisted in embeddings.db.
    digest = hashlib.sha256(sequence.encode("utf-8")).digest()
    entity_id = int.from_bytes(digest[:8], "big") >> 1
    client.insert(collection_name="proteins", data=[{"id": entity_id, "sequence": sequence, "vector": embedding}])

    return embedding


def create_protein_embedding_generator(sequences: Iterable[str]):
    """Yield one protein embedding per input sequence, showing a progress bar."""
    for sequence in tqdm(sequences, desc="Generating protein embeddings", unit=" proteins"):
        yield create_protein_embedding(sequence)


# Embed every protein sequence in the (possibly sampled) dataset
X_proteins = list(create_protein_embedding_generator(target_sequence))


Generating protein embeddings:   0%|          | 0/2000 [00:00<?, ? proteins/s]

#### Create Ligand Embeddings using ChemBERTa

In [5]:
# Load the ChemBERTa model and tokenizer
chemberta_tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
chemberta_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
chemberta_model.eval()

# ChemBERTa requires inputs of at most 512 tokens
CHEMBERTA_MAX_TOKENS = 512


def create_ligand_embedding(smiles: str) -> np.ndarray:
    """Return a fixed-length embedding for one SMILES string, using the Milvus cache.

    The embedding is the mean of ChemBERTa's last hidden state over all tokens.
    Results are cached in the ``ligands`` collection keyed by InChIKey, so different
    SMILES spellings of the same molecule share one cache entry.

    SMILES that RDKit cannot parse (or that yield no InChIKey) are still embedded,
    but they are neither looked up in nor written to the cache.

    Args:
        smiles: Ligand SMILES string.

    Returns:
        1-D float32 NumPy array (expected length ``ligand_dim``).
    """
    # RDKit returns None for SMILES it cannot parse; MolToInchiKey(None) would raise
    mol = Chem.MolFromSmiles(smiles)
    inchi_key = Chem.MolToInchiKey(mol) if mol is not None else ""

    if inchi_key:
        cached = client.query(
            collection_name="ligands",
            filter=f"inchikey == '{inchi_key}'",
            output_fields=["inchikey", "vector"],
        )
        if cached and cached[0]["inchikey"] == inchi_key:
            return np.asarray(cached[0]["vector"], dtype=np.float32)

    # Keep the original 512-character cut so new vectors match those already cached.
    # Token-level truncation additionally guarantees that the tokenized input, special
    # tokens included, never exceeds the model limit.
    if len(smiles) > CHEMBERTA_MAX_TOKENS:
        smiles = smiles[:CHEMBERTA_MAX_TOKENS]

    inputs = chemberta_tokenizer(
        smiles, return_tensors="pt", truncation=True, max_length=CHEMBERTA_MAX_TOKENS
    )

    with torch.no_grad():
        outputs = chemberta_model(**inputs)

    # Average the last hidden state across tokens to get a single vector
    embedding = outputs.last_hidden_state.mean(dim=1).squeeze(0).numpy()

    if inchi_key:
        # Stable 63-bit primary key derived from the InChIKey. This avoids collisions
        # with rows persisted in earlier sessions.
        digest = hashlib.sha256(inchi_key.encode("utf-8")).digest()
        entity_id = int.from_bytes(digest[:8], "big") >> 1
        client.insert(collection_name="ligands", data=[{"id": entity_id, "inchikey": inchi_key, "vector": embedding}])

    return embedding


def create_ligand_embedding_generator(smiles: Iterable[str]):
    """Yield one ligand embedding per input SMILES string, showing a progress bar."""
    for smile in tqdm(smiles, desc="Generating ligand embeddings", unit=" ligand"):
        yield create_ligand_embedding(smile)


# The ranking cell below calls this generator as `create_embedding_generator`
create_embedding_generator = create_ligand_embedding_generator

# Embed every ligand in the (possibly sampled) dataset
X_ligands = list(create_ligand_embedding_generator(ligands))


Generating ligand embeddings:   0%|          | 0/2000 [00:00<?, ? ligand/s]

#### Train (or reuse) the model and evaluate it on held-out data

In [6]:
# Feature matrix: each row is a ligand embedding followed by its protein embedding
X = np.concatenate([X_ligands, X_proteins], axis=1)

# Binding labels as a NumPy array
y = affinity.to_numpy()

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=18)


def train_new_model(features, labels, n_estimators: int = 100, random_state: int = 18) -> RandomForestClassifier:
    """Fit and return a new Random Forest classifier.

    Args:
        features: Training feature matrix.
        labels: Training labels aligned with ``features``.
        n_estimators: Number of trees in the forest.
        random_state: Seed for reproducible training.

    Returns:
        The fitted ``RandomForestClassifier``.
    """
    forest = RandomForestClassifier(n_estimators=n_estimators, random_state=random_state)
    forest.fit(features, labels)
    return forest


# `model` is None (or unbound) when no saved model was found. Compare with `is None`
# rather than truthiness, because sklearn ensembles define `__len__`.
# NOTE(review): a model loaded from disk was trained on the split that existed when it
# was saved. If the data or `test_only_n_examples` changed since then, X_test may
# overlap that model's training data and inflate the score.
if globals().get("model") is None:
    print("Model not found. Training a new model.")
    model = train_new_model(X_train, y_train)

# Evaluate with ROC-AUC on the predicted probability of the positive class
y_pred = model.predict_proba(X_test)[:, 1]
roc_auc = roc_auc_score(y_test, y_pred)

print(f"ROC-AUC score: {roc_auc:.4f}")

ROC-AUC score: 0.8919


#### Rank a list of ligands by predicted binding probability

In [12]:
def rank_binding_affinity(list_of_ligands: Iterable[str], target_sequence: str):
    """Rank ligands by the model's predicted probability of binding a target protein.

    Args:
        list_of_ligands: SMILES strings to score. Any iterable is accepted. It is
            materialized once, so a generator is not exhausted before ranking.
        target_sequence: Sequence of the target protein.

    Returns:
        Tuple ``(ranked_ligands, ranked_scores)`` of NumPy arrays, sorted by descending
        predicted binding probability. Ligands with equal scores keep their input order.
        Both arrays are empty when no ligands are given.
    """
    ligand_list = list(list_of_ligands)
    if not ligand_list:
        return np.array([], dtype=str), np.array([], dtype=float)

    ligand_features = np.array(list(create_embedding_generator(ligand_list)))
    protein_features = create_protein_embedding(target_sequence)
    # Repeat the single protein embedding once per ligand row
    protein_features = np.tile(protein_features, (ligand_features.shape[0], 1))
    features = np.concatenate([ligand_features, protein_features], axis=1)
    scores = model.predict_proba(features)[:, 1]

    # A stable sort on the negated scores gives descending order while keeping ties in input order
    ranked_indices = np.argsort(-scores, kind="stable")
    ranked_ligands = np.array(ligand_list)[ranked_indices]
    ranked_scores = scores[ranked_indices]

    return ranked_ligands, ranked_scores


# Example: score a few small molecules against a short protein sequence. Distinct names
# keep the dataset's `ligands` / `target_sequence` variables from being overwritten.
example_ligands = ["CCO", "CCN", "CCC", "CCCC"]
example_target_sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEKAVQVKVKALPDAQFEVVHSLAKWKR"

ranked_ligands, ranked_scores = rank_binding_affinity(example_ligands, example_target_sequence)

for ligand, score in zip(ranked_ligands, ranked_scores):
    print(f"Ligand: {ligand}, Score: {score:.4f}")

Generating ligand embeddings:   0%|          | 0/4 [00:00<?, ? ligand/s]

Ligand: CCCC, Score: 0.5100
Ligand: CCC, Score: 0.4900
Ligand: CCO, Score: 0.4800
Ligand: CCN, Score: 0.4400


#### Save the model and close the embedding database


In [13]:
# Persist the classifier under the aiondata cache so the loading cell can reuse it next time
model_save_path = aiondata_path / "models"
model_save_path.mkdir(exist_ok=True, parents=True)
joblib.dump(model, model_save_path.joinpath("binding_affinity.joblib"))

# Release the client connection to the local embeddings.db
client.close()