# Machine Learning Model for Binding Affinity using BindingDB


#### Install dependencies and import modules

In [1]:
# Install dependencies
!pip install -q torch fair-esm transformers aiondata vectorizedb &> /dev/null

# Standard library imports
import os
from pathlib import Path

# Third-party imports for numerical operations and machine learning
import joblib
import numpy as np
import torch
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
from vectorizedb import Database

# Third-party imports for deep learning and specific models
from transformers import AutoTokenizer, AutoModel

# Imports from specific libraries used in cheminformatics and bioinformatics
from rdkit import RDLogger
import esm
from aiondata import BindingAffinity

aiondata_path = Path(os.environ.get("AIONDATA_CACHE", "~/.aiondata")).expanduser()
ligand_dim = 768
protein_dim = 480

# Only test on n examples, set to None to test on all examples
test_only_n_examples = 100

#### Optional: Load previously saved model and data

In [2]:
# Load the pre-trained model
try:
    model = joblib.load(aiondata_path / "models" / "binding_affinity.joblib")
except FileNotFoundError:
    print("Model not found. Please train the model first.")

# Load embeddings
try: 
    X_ligands = Database(aiondata_path / "embeddings" / "ligands", dim=ligand_dim)
    X_proteins = Database(aiondata_path / "embeddings" / "proteins", dim=protein_dim)
except FileNotFoundError:
    print("Embeddings not found. Please generate embeddings first.")



#### Load BindingAffinity and prepare for machine learning

In [3]:

# Load BindingDB into a Polars DataFrame
df = BindingAffinity().to_df()

# Filter out rows with missing SMILES, Sequence, or Binds values
df = df.drop_nulls(subset=["SMILES", "Sequence", "Binds"])

# For test purposes only use a subset of the data
if test_only_n_examples:
    df = df.sample(n=test_only_n_examples, shuffle=True, seed=18)

# Get the SMILES, Sequence, and Binds columns
ligands = df["SMILES"]
target_sequence = df["Sequence"]
affinity = df["Binds"]

# Suppress RDKit warnings and errors
RDLogger.DisableLog("rdApp.*")  

#### Create Protein Embeddings using ESM

In [4]:
protein_embedding_cache = {}

# Load ESM-35m model
model, alphabet = esm.pretrained.esm2_t12_35M_UR50D()

# Prepare model and move it to evaluation mode
model = model.eval()
if torch.cuda.is_available():
    model = model.cuda()

def create_protein_embedding(sequence: str, cache: dict):
    """Generate embeddings for a single protein sequence using a cache to store previous computations."""
    if sequence in cache:
        return cache[sequence]

    # Convert sequence to tokens
    tokens = torch.tensor([alphabet.encode(sequence)])
    if torch.cuda.is_available():
        tokens = tokens.cuda()

    with torch.no_grad():
        results = model(tokens, repr_layers=[12])  # Extract embeddings from the last layer

    # Extract embeddings and move to cpu
    embeddings = results["representations"][12].squeeze(0).cpu()

    # Reduce the embeddings to 1D by averaging across the sequence length
    embeddings_1d = embeddings.mean(dim=0)

    # Store in cache
    cache[sequence] = embeddings_1d
    
    return embeddings_1d

def create_embedding_generator(sequences: list[str]):
    """Generate embeddings for a list of protein sequences."""
    for sequence in tqdm(sequences, desc="Generating protein embeddings", unit=" proteins"):
        yield create_protein_embedding(sequence, protein_embedding_cache).numpy()

# Generate embeddings for all protein sequences
X_proteins = np.array(list(create_embedding_generator(target_sequence)), dtype=np.float32)


Generating protein embeddings:   0%|          | 0/100 [00:00<?, ? proteins/s]

#### Create Ligand Embeddings using ChemBERTa

In [5]:
ligand_embedding_cache = {}

# Load ChemBERTa model and tokenizer
chemberta_tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
chemberta_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
chemberta_model.eval()

def create_ligand_embedding(smiles: str, cache: dict):
    """Generate embeddings for a single SMILES string using a cache to store previous computations."""
    if smiles in cache:
        return cache[smiles]
    
    # Truncate SMILES string to 512 characters, required by ChemBERTa
    if len(smiles) > 512:
        smiles = smiles[:512]
    
    inputs = chemberta_tokenizer(smiles, return_tensors="pt")
    
    with torch.no_grad():
        outputs = chemberta_model(**inputs)
    
    # Take the mean of the last hidden state to get a single vector representation
    embedding = outputs.last_hidden_state.mean(dim=1).squeeze(0)
    
    # Store in cache
    cache[smiles] = embedding
    
    return embedding

def create_embedding_generator(smiles: list[str]):
    """Generate embeddings for a list of SMILES strings."""
    for smile in tqdm(smiles, desc="Generating ligand embeddings", unit=" ligand"):
        yield create_ligand_embedding(smile, ligand_embedding_cache).numpy()

# Generate embeddings for all ligands
X_ligands = np.array(list(create_embedding_generator(ligands)), dtype=np.float32)


Generating ligand embeddings:   0%|          | 0/100 [00:00<?, ? ligand/s]

#### Create the model and predict binding values

In [6]:
X = np.concatenate([X_ligands, X_proteins], axis=1)
# Make affinity into a numpy y array
y = affinity.to_numpy()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=18)

# Train a new Random Forest model
model = RandomForestClassifier(n_estimators=100, random_state=18)
model.fit(X_train, y_train)

# Evaluate the model ROC-AUC score
y_pred = model.predict_proba(X_test)[:, 1]
roc_auc = roc_auc_score(y_test, y_pred)

print(f"ROC-AUC score: {roc_auc:.4f}")

ROC-AUC score: 0.9121


#### Save the model and embeddings


In [8]:
# Save the model
model_save_path = aiondata_path / "models"
model_save_path.mkdir(parents=True, exist_ok=True)
joblib.dump(model, model_save_path / "binding_affinity.joblib")

# Save the embeddings
embeddings_save_path = aiondata_path / "embeddings"
embeddings_save_path.mkdir(parents=True, exist_ok=True)
ligands_db = Database(embeddings_save_path / "ligands", dim=ligand_dim)
proteins_db = Database(embeddings_save_path / "proteins", dim=protein_dim)
for key, value in ligand_embedding_cache.items():
    ligands_db[key] = value
for key, value in protein_embedding_cache.items():
    proteins_db[key] = value

torch.float32


ValueError: Vector must be of type np.float32