# Using transformers to predict binding affinity

So far, we've tried XGBoost on hand-engineered featurizations. We then showed how a transformer can be trained on masked SMILES strings to (hopefully) learn meaningful representations of molecules. Now we want to combine these rich molecule embeddings with embeddings of the protein pockets, for which we will use EvolutionaryScale's ESM3. The principles behind ESM3's training are very similar to those behind CuteSmileyBERT, only applied at a far larger scale.
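To make "trained on masked SMILES strings" concrete, here is a minimal, generic sketch of the masking step in masked language modelling. It is not necessarily CuteSmileyBERT's exact procedure: the token ids, `mask_token_id` and `pad_token_id` are hypothetical, and for simplicity the sketch always substitutes the mask token, whereas BERT-style training usually replaces only 80% of the selected positions with `[MASK]` (10% get a random token, 10% stay unchanged) and also skips special tokens.

```python
from typing import Optional, Tuple

import torch


def mask_tokens(
    input_ids: torch.Tensor,
    mask_token_id: int,
    pad_token_id: int,
    mask_prob: float = 0.15,
    generator: Optional[torch.Generator] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return (corrupted_ids, labels) for masked language modelling.

    Each non-padding position is selected with probability mask_prob and
    replaced by mask_token_id. Unselected positions get label -100, the
    default ignore_index of torch.nn.CrossEntropyLoss, so only the masked
    tokens contribute to the loss.
    """
    labels = input_ids.clone()
    # Padding positions are never selected for prediction
    candidates = input_ids != pad_token_id
    probs = torch.full(input_ids.shape, mask_prob)
    selected = torch.bernoulli(probs, generator=generator).bool() & candidates
    labels[~selected] = -100
    corrupted = input_ids.clone()
    corrupted[selected] = mask_token_id
    return corrupted, labels


# Hypothetical token ids: 0 = padding, 1 = [MASK]
g = torch.Generator().manual_seed(0)
ids = torch.tensor([[12, 7, 7, 9, 0, 0]])
corrupted, labels = mask_tokens(ids, mask_token_id=1, pad_token_id=0, generator=g)
```

The model then sees `corrupted` as input and is trained to recover the original tokens at the masked positions, which forces it to learn how SMILES tokens depend on their context.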

Our goal here is to build a neural network that takes the embeddings of both the ligand and the pocket as input and outputs a predicted pKd. Our baseline concatenates the two input vectors and feeds the result to a simple Multi-Layer Perceptron (MLP). This will tell us whether our embedding models encoded any information relevant to binding affinity, and whether deep learning is a suitable approach for this task.
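A minimal sketch of this concatenation baseline in PyTorch, assuming each ligand and each pocket is already represented by a single fixed-size vector. `ConcatMLP`, `train_step`, the layer sizes and the dropout rate are illustrative choices rather than a tuned architecture; `ligand_dim=256` matches the last dimension of the CuteSmileyBERT embeddings in this notebook, while `pocket_dim=1536` is only a placeholder to be replaced with the actual size of the ESM3 embeddings used.

```python
import torch
from torch import nn


class ConcatMLP(nn.Module):
    """Baseline regressor: concatenate ligand and pocket embeddings, predict pKd."""

    def __init__(self, ligand_dim: int, pocket_dim: int, hidden_dim: int = 512, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(ligand_dim + pocket_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1),  # single scalar output: predicted pKd
        )

    def forward(self, ligand_emb: torch.Tensor, pocket_emb: torch.Tensor) -> torch.Tensor:
        # (batch, ligand_dim) + (batch, pocket_dim) -> (batch, ligand_dim + pocket_dim)
        x = torch.cat([ligand_emb, pocket_emb], dim=-1)
        return self.net(x).squeeze(-1)  # (batch,)


def train_step(mlp: nn.Module, optimizer: torch.optim.Optimizer,
               ligand_emb: torch.Tensor, pocket_emb: torch.Tensor, pkd: torch.Tensor) -> float:
    """One gradient step with mean squared error on pKd; returns the batch loss."""
    mlp.train()
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(mlp(ligand_emb, pocket_emb), pkd)
    loss.backward()
    optimizer.step()
    return loss.item()


# Dummy batch of 8 complexes; pocket_dim=1536 is a placeholder, not a verified ESM3 size
mlp = ConcatMLP(ligand_dim=256, pocket_dim=1536)
optimizer = torch.optim.AdamW(mlp.parameters(), lr=1e-3)
batch_loss = train_step(mlp, optimizer, torch.randn(8, 256), torch.randn(8, 1536), torch.randn(8))
```

The random tensors are dummy data only. Concatenation is the simplest way to fuse the two inputs: the MLP has to learn any ligand-pocket interaction on its own, which is exactly what makes it a useful baseline for judging the embeddings.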

```python
import sys
sys.path.append("..")  # make the project's src/ package importable from this notebook

import json

import torch

from src.transformer_classes import CuteSmileyBERT, CuteSmileyBERTConfig, SMILESTokenizer
```



```python
# Load the model checkpoint from the Hugging Face Hub
REPO = "marcosbolanos/cutesmileybert-4.8m"

# The tokenizer is defined locally for now: publishing it on the Hub would
# require a standardized tokenizer definition, which is not implemented yet
VOCAB_PATH = "../data/vocab.json"
with open(VOCAB_PATH, "r") as f:
    vocab = json.load(f)
inv_vocab = {v: k for k, v in vocab.items()}  # token id -> token string
tokenizer = SMILESTokenizer(vocab, inv_vocab)

# Model config, loaded from the Hugging Face repo
config = CuteSmileyBERTConfig.from_pretrained(REPO)
# Model weights
model = CuteSmileyBERT.from_pretrained(REPO, config=config)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device).eval()

# Encode a test molecule (hexane); both branches now use the same SMILES string
smiles = "CCCCCC"
if callable(tokenizer):
    encoded = tokenizer(smiles, return_tensors="pt")
else:
    encoded = {"input_ids": torch.tensor([tokenizer.encode(smiles)])}
input_ids = encoded["input_ids"].to(device)

with torch.no_grad():
    # return_embeddings=True returns per-token embeddings: (batch, seq_len, hidden_dim)
    emb = model(input_ids, return_embeddings=True)
print(emb.shape)
```

```
torch.Size([1, 6, 256])
```
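The model returns one 256-dimensional vector per token (`[1, 6, 256]` for hexane), whereas the concatenation baseline needs a single fixed-size vector per molecule. One common option, assumed here rather than taken from the CuteSmileyBERT code, is masked mean pooling over the non-padding tokens; using a `[CLS]`-style token or max pooling are alternatives. `mean_pool` is a hypothetical helper, and for a single unpadded molecule the mask is simply all ones.

```python
import torch


def mean_pool(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average per-token embeddings over non-padding positions.

    token_embeddings: (batch, seq_len, hidden)
    attention_mask:   (batch, seq_len), 1 for real tokens, 0 for padding
    returns:          (batch, hidden)
    """
    mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)  # (batch, seq_len, 1)
    summed = (token_embeddings * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1.0)  # avoid division by zero for all-padding rows
    return summed / counts


# Stand-in for the notebook's `emb` tensor: one molecule, 6 tokens, hidden size 256
token_emb = torch.randn(1, 6, 256)
mask = torch.ones(1, 6, dtype=torch.long)
molecule_vec = mean_pool(token_emb, mask)  # shape (1, 256)
```

With padded batches, passing the real attention mask keeps padding tokens from diluting the average, so a molecule's vector does not depend on how long the other molecules in its batch are.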
