# Using transformers to predict binding affinity

So far, we've tried methods using XGBoost on hand-engineered featurizations. We then showed how a transformer can be trained on masked SMILES strings to (hopefully) build meaningful representations of them. Now, we want to combine these rich molecule embeddings with embeddings of the protein pockets, for which we will use Meta's ESM2, which has been shown to capture structural information despite being trained only on sequences. The principles behind ESM2's training are very similar to how CuteSmileyBERT was trained, only applied at a far larger scale.

Our goal here is to produce a neural network that takes in the embeddings of both the ligand and the pocket and outputs a prediction for pKd. Our baseline will consist of concatenating both input vectors and feeding them to a simple Multi-Layer Perceptron (MLP). This will be extremely helpful in telling us whether our embedding models encoded any relevant information, and whether deep learning is a suitable solution.

In [2]:
import sys
sys.path.append("..")

import os
import json
from pathlib import Path

import torch
import numpy as np
import esm
from rdkit import Chem
from biopandas.pdb import PandasPdb
from transformers import PretrainedConfig, PreTrainedModel, AutoTokenizer

from src.transformer_classes import CuteSmileyBERT, CuteSmileyBERTConfig, SMILESTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Hugging Face Hub repository hosting the pretrained CuteSmileyBERT checkpoint
REPO = "marcosbolanos/cutesmileybert-4.8m"

# The tokenizer is still built from a local vocabulary file:
# publishing it on the Hub would need a standardized definition (not done yet)
VOCAB_PATH = "../data/vocab.json"
with open(VOCAB_PATH, "r") as f:
    vocab = json.load(f)
# Inverse mapping: every vocab value points back to its key
inv_vocab = dict(zip(vocab.values(), vocab.keys()))
tokenizer = SMILESTokenizer(vocab, inv_vocab)

# Pull the model configuration from the Hub repository first...
config = CuteSmileyBERTConfig.from_pretrained(REPO)
# ...then the weights, instantiated with that configuration
ligand_embedder = CuteSmileyBERT.from_pretrained(REPO, config=config)

# Embedding helper for a single SMILES string
def get_smiles_pooled_embeddings(smiles: str): 
    """Return a mean-pooled embedding for one SMILES string.

    The string is tokenized with the module-level ``tokenizer`` and passed
    through ``ligand_embedder`` with ``return_embeddings=True``. The resulting
    per-token embeddings are averaged over dimension 1 (presumably the token
    axis of a (batch, tokens, features) tensor -- TODO confirm) to obtain one
    global representation of the molecule.
    """
    token_ids = tokenizer(smiles, return_tensors="pt")["input_ids"]

    # Pure inference: gradients are not needed
    with torch.no_grad():
        token_embeddings = ligand_embedder(token_ids, return_embeddings=True)

    # Collapse the per-token vectors into a single global vector
    return token_embeddings.mean(dim=1)

In [4]:
# Locations of the raw PDBbind files and of the preprocessed dataset
DATA_DIR = "../data"
PDBBIND_DIR = Path(DATA_DIR) / "v2015"
INTERIM_DIR = Path(DATA_DIR) / "interim"
DATASET_PATH = INTERIM_DIR / "reg_preprocessed_1.npz"

data = np.load(DATASET_PATH)
# Only the train/test ID lists are needed here: every complex is loaded
# from its own file and embedded individually later on
train_ids, test_ids = data["train_ids"], data["test_ids"]

In [5]:
# Loop that embeds the ligand of every requested PDB ID
def get_smiles_embeddings_list(pdb_ids: list[str]): 
    """Embed the ligand of each complex listed in ``pdb_ids``.

    For every ID the ligand ``.mol2`` file under ``PDBBIND_DIR`` is read with
    RDKit, explicit hydrogens are removed, a canonical non-isomeric SMILES is
    generated and passed to ``get_smiles_pooled_embeddings``.

    Args:
        pdb_ids: PDB IDs whose ligands should be embedded.

    Returns:
        A list of pooled ligand embeddings, in the order of ``pdb_ids``.
        IDs whose file is missing or whose molecule fails to load are skipped
        (a message is printed), so the result can be SHORTER than ``pdb_ids``.
        Callers pairing these embeddings with other per-ID data must compare
        lengths / account for the skipped IDs.
    """
    embeddings = []
    for pdb_id in pdb_ids:
        ligand_mol2_path = Path(PDBBIND_DIR, pdb_id, pdb_id + "_ligand.mol2")
        # Make sure the file actually exists, otherwise skip
        if not os.path.exists(ligand_mol2_path):
            print(f'molecule {pdb_id} file not found')
            continue
        # RDKit's file readers take the file name as a string; a pathlib.Path
        # is not guaranteed to be accepted, so convert explicitly.
        mol = Chem.MolFromMol2File(str(ligand_mol2_path), sanitize=False, removeHs=False)
        # Again, we skip it if the molecule didn't load
        if mol is None:
            print(f'molecule {pdb_id} didnt load successfully')
            continue
        # Removing explicit hydrogens to match the format our model was trained on
        mol = Chem.RemoveHs(mol, updateExplicitCount=True)
        smiles = Chem.MolToSmiles(mol,
                                  canonical=True,
                                  isomericSmiles=False,
                                  kekuleSmiles=False,
                                  allHsExplicit=False)
        embeddings.append(get_smiles_pooled_embeddings(smiles))
    return embeddings

In [6]:
# Protein pockets are embedded from their amino-acid sequence.
# The PDB files name residues with three-letter codes, while the sequence
# string needs one-letter codes: this table does the translation.
aa_map = {
    'ALA': 'A',
    'ARG': 'R',
    'ASN': 'N',
    'ASP': 'D',
    'CYS': 'C',
    'GLU': 'E',
    'GLN': 'Q',
    'GLY': 'G',
    'HIS': 'H',
    'ILE': 'I',
    'LEU': 'L',
    'LYS': 'K',
    'MET': 'M',
    'PHE': 'F',
    'PRO': 'P',
    'SER': 'S',
    'THR': 'T',
    'TRP': 'W',
    'TYR': 'Y',
    'VAL': 'V',
    # Rare amino acids: selenocysteine and pyrrolysine
    'SEC': 'U',
    'PYL': 'O',
}

# This function gives us the sequence string for a given PDB ID
def get_pocket_sequence(pdb_id: str) -> "str | None": 
    """Return the one-letter residue sequence of a complex's binding pocket.

    Reads ``<pdb_id>_pocket.pdb`` under ``PDBBIND_DIR``, keeps one ATOM record
    per residue, orders the residues by chain and residue number, and maps the
    three-letter residue names through ``aa_map`` (unknown names become 'X').

    Args:
        pdb_id: ID of the complex whose pocket file should be read.

    Returns:
        The pocket sequence, or ``None`` (after printing a warning) when the
        pocket file does not exist.
    """
    pocket_pdb_path = Path(PDBBIND_DIR, pdb_id, pdb_id + "_pocket.pdb")
    if not pocket_pdb_path.exists():
        print(f'Warning: couldnt find pocket for complex {pdb_id}')
        return None
    ppdb = PandasPdb().read_pdb(pocket_pdb_path)
    df = ppdb.df['ATOM']
    # A residue is identified by chain + number + insertion code.
    # De-duplicating on the residue number alone would merge residues that
    # share a number across chains (or differ only by insertion code) and
    # silently drop part of the pocket.
    residue_key = ['chain_id', 'residue_number', 'insertion']
    df = df.drop_duplicates(subset=residue_key, keep='first')
    df = df.sort_values(by=residue_key)
    seq = ''.join(aa_map.get(res, 'X') for res in df['residue_name'])
    return seq

# Sanity check: print the pocket sequence of the first test complex
sequence = get_pocket_sequence(test_ids[0])
print(sequence)

PYIELKLAGRWPVKVFIHNHKRYSAGERIVDIIATD


In [None]:
# Load the pretrained ESM2 checkpoint (esm2_t30_150M_UR50D) with the `esm` package
model, alphabet = esm.pretrained.esm2_t30_150M_UR50D()
# The batch converter is ESM's tokenizer for (name, sequence) pairs
batch_converter = alphabet.get_batch_converter()
# Switch the model to evaluation mode; as the last expression of the cell,
# this call also makes the notebook display the model's architecture
model.eval() 

ESM2(
  (embed_tokens): Embedding(33, 640, padding_idx=1)
  (layers): ModuleList(
    (0-29): 30 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=640, out_features=640, bias=True)
        (v_proj): Linear(in_features=640, out_features=640, bias=True)
        (q_proj): Linear(in_features=640, out_features=640, bias=True)
        (out_proj): Linear(in_features=640, out_features=640, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=640, out_features=2560, bias=True)
      (fc2): Linear(in_features=2560, out_features=640, bias=True)
      (final_layer_norm): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=600, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((640,), eps=1e-05, elementw

: 

In [None]:
# This helper will get pocket sequences and put them in the required format
# ESM2 takes inputs as a list of ('name', SEQUENCE) tuples
def get_esm_inputs_for_pdb_ids(pdb_ids:list[str]) -> tuple[list[tuple[str, str]], list[str]]: 
    """Collect the pocket sequences of ``pdb_ids`` in ESM2's input format.

    Args:
        pdb_ids: PDB IDs to look up with ``get_pocket_sequence``.

    Returns:
        A pair ``(esm_inputs, missing_ids)``: ``esm_inputs`` holds one
        ``(str(pdb_id), sequence)`` tuple per ID with a sequence, in input
        order; ``missing_ids`` lists the IDs for which
        ``get_pocket_sequence`` returned ``None``.
    """
    missing_ids = []
    esm_inputs = []
    for pdb_id in pdb_ids:
        seq = get_pocket_sequence(pdb_id)
        # Identity check: None is the "no pocket file" sentinel
        if seq is None:
            missing_ids.append(pdb_id)
            continue
        esm_inputs.append((str(pdb_id), seq))
    return esm_inputs, missing_ids

# Generate embeddings for a given list of PDB IDs
def get_esm_embedding_list(pdb_ids:list[str], batch_size: "int | None" = None):
    """Compute one mean-pooled ESM2 embedding per pocket sequence.

    Args:
        pdb_ids: PDB IDs whose pocket sequences should be embedded.
        batch_size: Maximum number of sequences per forward pass. ``None``
            (the default) sends every sequence through the model in a single
            batch; pass a positive integer to bound memory use on large
            ID lists.

    Returns:
        A pair ``(sequence_embeddings, missing_ids)``: the embeddings of the
        IDs that have a pocket sequence (in input order), and the IDs that
        had none.

    Raises:
        ValueError: If ``batch_size`` is given and smaller than 1.
    """
    esm_inputs, missing_ids = get_esm_inputs_for_pdb_ids(pdb_ids)
    # Nothing to embed: avoid handing an empty batch to the tokenizer
    if not esm_inputs:
        return [], missing_ids

    if batch_size is None:
        batch_size = len(esm_inputs)
    elif batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Index of the layer whose token representations are extracted
    repr_layer = 30

    sequence_embeddings = []
    for start in range(0, len(esm_inputs), batch_size):
        chunk = esm_inputs[start:start + batch_size]
        # This is the tokenizer used by ESM2
        _, _, batch_tokens = batch_converter(chunk)

        # return_contacts=False: only the token representations are requested
        with torch.no_grad():
            results = model(batch_tokens, repr_layers=[repr_layer], return_contacts=False)
        token_embeddings = results["representations"][repr_layer]

        # Mean pool over the real residues of each sequence: position 0 is
        # skipped (ESM's batch converter prepends a start token, per the ESM
        # README) and the slice stops before any trailing/padding positions
        for i, (_, seq) in enumerate(chunk):
            sequence_embeddings.append(token_embeddings[i, 1 : len(seq) + 1].mean(0))
    return sequence_embeddings, missing_ids

# Embed the pockets of both splits, keeping track of the IDs without a pocket file
train_esm_embeddings, missing_train_ids = get_esm_embedding_list(train_ids)
test_esm_embeddings, missing_test_ids = get_esm_embedding_list(test_ids)

# Quick report: number of embeddings per split, shape of one embedding,
# and how many IDs were dropped for lack of pocket data
print(len(train_esm_embeddings))
print(len(test_esm_embeddings))
print(train_esm_embeddings[0].shape)
print(f'Missing IDs: train: {len(missing_train_ids)}, test: {len(missing_test_ids)}')

In [None]:
# Now we know which IDs actually have protein pocket data
# We drop the IDs without a pocket, then get SMILES embeddings for the rest.
# `train_ids` / `test_ids` are NumPy arrays and the missing IDs are lists, so
# the `-` operator cannot be used here: filter by membership instead, which
# also preserves the original order of the IDs.
missing_train_set = set(missing_train_ids)
missing_test_set = set(missing_test_ids)
retained_train_ids = [pdb_id for pdb_id in train_ids if pdb_id not in missing_train_set]
retained_test_ids = [pdb_id for pdb_id in test_ids if pdb_id not in missing_test_set]

train_smiles_embeddings = get_smiles_embeddings_list(retained_train_ids)
test_smiles_embeddings = get_smiles_embeddings_list(retained_test_ids)

# These counts should match the number of ESM embeddings; a smaller number
# means some ligands were skipped while computing the SMILES embeddings
print(len(train_smiles_embeddings))
print(len(test_smiles_embeddings))
print(train_smiles_embeddings[0].shape)

In [None]:
def get_concatenated_embeddings(smiles_embeddings, esm_embeddings):
    