# Using transformer embeddings as a data representation

Now that we've generated our embeddings for SMILES strings, we want to combine them with embeddings for protein structures. For that, I decided to go with ESM2: since it works directly with protein sequences rather than 3D structures, it is quite lightweight, yet it has been shown to capture structural information about proteins.

This notebook doesn't contain many interesting results, as it's just me preprocessing the data which we will later feed into a neural network.

NOTE: This notebook was also run through Colab on an A100. While embedding our sequences with ESM2's smallest model is doable on a CPU in about an hour, the A100 gets the entire job done — with the larger 150M-parameter model used below (`esm2_t30_150M_UR50D`) — in under a second.

In [1]:
# Colab setup: fetch the project repository and install the extra libraries imported below
!git clone https://github.com/Orbliss/Cheminformatics_molecule_property_project
!pip install fair-esm rdkit biopandas

# Work from the repository root so the relative ./data paths used below resolve
%cd /content/Cheminformatics_molecule_property_project

Cloning into 'Cheminformatics_molecule_property_project'...
remote: Enumerating objects: 184, done.[K
remote: Counting objects: 100% (1/1), done.[K
remote: Total 184 (delta 0), reused 0 (delta 0), pack-reused 183 (from 1)[K
Receiving objects: 100% (184/184), 1.14 MiB | 21.97 MiB/s, done.
Resolving deltas: 100% (96/96), done.
Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl.metadata (37 kB)
Collecting rdkit
  Downloading rdkit-2025.9.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (4.1 kB)
Collecting biopandas
  Downloading biopandas-0.5.1-py3-none-any.whl.metadata (1.6 kB)
Collecting mmtf-python==1.1.3 (from biopandas)
  Downloading mmtf_python-1.1.3-py2.py3-none-any.whl.metadata (1.2 kB)
Collecting looseversion==1.1.2 (from biopandas)
  Downloading looseversion-1.1.2-py3-none-any.whl.metadata (4.6 kB)
Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00

In [2]:
import sys
sys.path.append("..")

import os
import json
from pathlib import Path

import torch
import numpy as np
import esm
from rdkit import Chem
from biopandas.pdb import PandasPdb
from transformers import PretrainedConfig, PreTrainedModel, AutoTokenizer

from src.transformer_classes import CuteSmileyBERT, CuteSmileyBERTConfig, SMILESTokenizer
from src.download_dataset import download_datasets, extract_files, delete_files

Data directory: /content/Cheminformatics_molecule_property_project/data


In [3]:
# Fetch the raw archives, unpack them under ./data, then remove the archives,
# strictly in that order (each step depends on the previous one)
for step in (download_datasets, extract_files, delete_files):
    step()

Téléchargement de pdbbind_v2015.tar.gz...


pdbbind_v2015.tar.gz: 100%|██████████| 1.91G/1.91G [01:31<00:00, 22.5MB/s]


pdbbind_v2015.tar.gz téléchargé et enregistré sous /content/Cheminformatics_molecule_property_project/data/pdbbind_v2015.tar.gz
Téléchargement de sider.csv.gz...


sider.csv.gz: 100%|██████████| 33.9k/33.9k [00:00<00:00, 249kB/s]


sider.csv.gz téléchargé et enregistré sous /content/Cheminformatics_molecule_property_project/data/sider.csv.gz
Extraction de pdbbind_v2015.tar.gz...


Extracting pdbbind_v2015.tar.gz: 100%|██████████| 64861/64861 [00:58<00:00, 1106.00it/s]


pdbbind_v2015.tar.gz extrait dans /content/Cheminformatics_molecule_property_project/data/pdbbind_v2015
Extraction de sider.csv.gz...


Extracting sider.csv.gz: 184kB [00:00, 110MB/s]                     


sider.csv.gz extrait vers /content/Cheminformatics_molecule_property_project/data/sider.csv


In [4]:
# Loading the model checkpoint from Hugging Face Hub
REPO = "marcosbolanos/cutesmileybert-4.8m"

# We're defining the tokenizer locally for now
# Hugging Face needed standardized definitions, no time to implement
VOCAB_PATH = "./data/vocab.json"
# JSON is UTF-8 by spec; don't depend on the platform's default text encoding
with open(VOCAB_PATH, "r", encoding="utf-8") as f:
    vocab = json.load(f)
# Reverse lookup built by swapping keys and values
# (presumably token id -> token string; confirm against vocab.json)
inv_vocab = {v: k for k, v in vocab.items()}
tokenizer = SMILESTokenizer(vocab, inv_vocab)

# This is the model config, loaded from the Hugging Face Repo
config = CuteSmileyBERTConfig.from_pretrained(REPO)
# And this loads the model's weights
ligand_embedder = CuteSmileyBERT.from_pretrained(REPO, config=config)
# We only run inference: make sure dropout and other train-time behaviour is off
ligand_embedder.eval()

# This is going to be our function to embed a single smiles string
def get_smiles_pooled_embeddings(smiles: str):
    """Embed one SMILES string with CuteSmileyBERT and mean-pool over its tokens.

    The string is tokenized with the module-level ``tokenizer`` and the resulting
    ``input_ids`` are run through ``ligand_embedder`` without gradient tracking.
    The per-token embeddings are averaged along dim 1 (assumed to be the token
    axis of a (batch, seq, hidden) tensor -- TODO confirm against CuteSmileyBERT).
    Every token returned by the tokenizer takes part in the mean.

    Returns:
        The pooled embedding tensor (printed later in this notebook as (1, 256)).
    """
    token_ids = tokenizer(smiles, return_tensors="pt")["input_ids"]
    with torch.no_grad():
        token_embeddings = ligand_embedder(token_ids, return_embeddings=True)
    # Collapse the token axis into one global representation of the molecule
    return token_embeddings.mean(dim=1)

config.json:   0%|          | 0.00/276 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/19.2M [00:00<?, ?B/s]

In [6]:
# Now we load the test/train datasets, just like last time
DATA_DIR = "./data"
PDBBIND_DIR = Path(DATA_DIR) / "v2015"
INTERIM_DIR = Path(DATA_DIR) / "interim"
DATASET_PATH = INTERIM_DIR / "reg_preprocessed_1.npz"

# The preprocessed split decides which complexes go to train vs test; every later
# step walks these ID arrays and loads each complex's files individually
data = np.load(DATASET_PATH)
train_ids, test_ids = data["train_ids"], data["test_ids"]

In [7]:
# Here, we're creating the loop to embed all of our train and test IDs
def get_smiles_embeddings_list(pdb_ids: list[str], *, return_missing: bool = False):
    """Embed the ligand of each PDBbind complex with CuteSmileyBERT.

    For every ID, reads ``<PDBBIND_DIR>/<id>/<id>_ligand.mol2`` with RDKit
    (unsanitized, hydrogens kept), then removes explicit hydrogens, writes a
    canonical non-isomeric SMILES string, and embeds it with
    ``get_smiles_pooled_embeddings``.

    Args:
        pdb_ids: PDBbind complex IDs, processed in order.
        return_missing: if True, also return the IDs that were skipped.

    Returns:
        The list of pooled embeddings, one per successfully processed ID, or
        ``(embeddings, missing_ids)`` when ``return_missing`` is True.

    Note:
        IDs whose ligand file is missing or fails to load are skipped with a
        printed message. The result is therefore index-aligned with ``pdb_ids``
        (and with labels / pocket embeddings) only when nothing was skipped.
        Use ``return_missing=True`` to detect that case.
    """
    embeddings = []
    missing_ids = []
    for pdb_id in pdb_ids:
        ligand_mol2_path = Path(PDBBIND_DIR, pdb_id, pdb_id + "_ligand.mol2")
        # Make sure the file actually exists, otherwise skip
        if not ligand_mol2_path.exists():
            print(f'molecule {pdb_id} file not found')
            missing_ids.append(pdb_id)
            continue
        # RDKit's reader is declared with a str filename, so pass one explicitly
        mol = Chem.MolFromMol2File(str(ligand_mol2_path), sanitize=False, removeHs=False)
        # Again, we skip it if the molecule didn't load
        if mol is None:
            print(f'molecule {pdb_id} didnt load successfully')
            missing_ids.append(pdb_id)
            continue
        # Removing explicit hydrogens to match the format our model was trained on
        mol = Chem.RemoveHs(mol, updateExplicitCount=True)
        smiles = Chem.MolToSmiles(mol,
                                  canonical=True,
                                  isomericSmiles=False,
                                  kekuleSmiles=False,
                                  allHsExplicit=False)
        embeddings.append(get_smiles_pooled_embeddings(smiles))
    if return_missing:
        return embeddings, missing_ids
    return embeddings

In [8]:
# Now we need embeddings for the protein pockets
# For this, we're going to iterate through our dataset and get the sequences
# PDB files name residues with three letters; ESM2 wants one-letter codes.
# Covers the 20 standard amino acids plus selenocysteine (SEC -> U) and
# pyrrolysine (PYL -> O). Residues absent from this table are mapped to 'X' later.
aa_map = dict(zip(
    ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLU', 'GLN', 'GLY', 'HIS', 'ILE', 'LEU',
     'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL', 'SEC', 'PYL'],
    'ARNDCEQGHILKMFPSTWYVUO',
))

# This function gives us the sequence string for a given PDB ID
def get_pocket_sequence(pdb_id: str) -> str | None:
    """Return the one-letter amino-acid sequence of a complex's binding pocket.

    Reads ``<PDBBIND_DIR>/<id>/<id>_pocket.pdb`` with biopandas and keeps one
    ATOM row per residue. Residues are ordered by chain, residue number and
    insertion code, and each three-letter name is mapped through ``aa_map``
    (unknown names become 'X').

    Returns:
        The pocket sequence, or None (after printing a warning) when the pocket
        file is missing or contains no ATOM records.
    """
    pocket_pdb_path = Path(PDBBIND_DIR, pdb_id, pdb_id + "_pocket.pdb")
    if not pocket_pdb_path.exists():
        print(f'Warning: couldnt find pocket for complex {pdb_id}')
        return None
    ppdb = PandasPdb().read_pdb(str(pocket_pdb_path))
    df = ppdb.df['ATOM']
    # An empty sequence would give ESM2 nothing to pool over (NaN embedding)
    if df.empty:
        print(f'Warning: pocket for complex {pdb_id} has no ATOM records')
        return None
    # A residue is identified by chain + number + insertion code. Deduplicating on
    # the number alone would merge same-numbered residues from different chains.
    residue_keys = ['chain_id', 'residue_number', 'insertion']
    residues = df.drop_duplicates(subset=residue_keys, keep='first')
    residues = residues.sort_values(by=residue_keys)
    return ''.join(aa_map.get(res, 'X') for res in residues['residue_name'])

sequence = get_pocket_sequence(test_ids[0])
print(sequence)

PYIELKLAGRWPVKVFIHNHKRYSAGERIVDIIATD


In [9]:
# Now we load ESM2 thanks to the python package
# esm2_t30_150M_UR50D is the 30-layer, 150M-parameter ESM2 checkpoint (downloaded on first use)
model, alphabet = esm.pretrained.esm2_t30_150M_UR50D()
# Turns a list of (label, sequence) pairs into (labels, strings, token tensor) for the model
batch_converter = alphabet.get_batch_converter()
# Inference only; as the cell's last expression this also displays the architecture
model.eval()

Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t30_150M_UR50D.pt" to /root/.cache/torch/hub/checkpoints/esm2_t30_150M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t30_150M_UR50D-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm2_t30_150M_UR50D-contact-regression.pt


ESM2(
  (embed_tokens): Embedding(33, 640, padding_idx=1)
  (layers): ModuleList(
    (0-29): 30 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=640, out_features=640, bias=True)
        (v_proj): Linear(in_features=640, out_features=640, bias=True)
        (q_proj): Linear(in_features=640, out_features=640, bias=True)
        (out_proj): Linear(in_features=640, out_features=640, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=640, out_features=2560, bias=True)
      (fc2): Linear(in_features=2560, out_features=640, bias=True)
      (final_layer_norm): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=600, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((640,), eps=1e-05, elementw

In [10]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# torch.cuda.get_device_name() raises when no GPU is available, which would
# defeat the CPU fallback above, so only query it when CUDA is actually used
print(torch.cuda.get_device_name() if device == 'cuda' else 'cpu')
model = model.to(device)

NVIDIA A100-SXM4-80GB


In [11]:
# This helper will get pocket sequences and put them in the required format
# ESM2 takes inputs as a list of ('name', SEQUENCE) tuples
def get_esm_inputs_for_pdb_ids(pdb_ids:list[str]) -> tuple[list[tuple[str, str]], list[str]]:
    """Build ESM2 batch-converter inputs for the pockets of ``pdb_ids``.

    Returns:
        ``(esm_inputs, missing_ids)``: ``esm_inputs`` holds one
        ``(str(pdb_id), sequence)`` pair per ID whose pocket sequence could be
        read, in input order. ``missing_ids`` lists the IDs for which
        ``get_pocket_sequence`` returned None.
    """
    missing_ids = []
    esm_inputs = []
    for pdb_id in pdb_ids:
        seq = get_pocket_sequence(pdb_id)
        # Identity check: None is a singleton and `==` can be overloaded
        if seq is None:
            missing_ids.append(pdb_id)
            continue
        esm_inputs.append((str(pdb_id), seq))
    return esm_inputs, missing_ids

# Generate embeddings for a given list of PDB IDs
def get_esm_embedding_list(pdb_ids:list[str], *, batch_size: int | None = None):
    """Compute mean-pooled final-layer ESM2 embeddings for the pockets of ``pdb_ids``.

    Args:
        pdb_ids: PDBbind complex IDs, processed in order.
        batch_size: number of sequences per forward pass. ``None`` (default)
            runs every sequence in a single pass, as before. Smaller values
            bound memory use on CPUs / smaller GPUs.

    Returns:
        ``(sequence_embeddings, missing_ids)``: one 1-D tensor per pocket that
        could be read (left on ``device``), and the IDs whose pocket was missing.

    Raises:
        ValueError: if ``batch_size`` is given and smaller than 1.
    """
    print("Preprocessing data for ESM2...")
    esm_inputs, missing_ids = get_esm_inputs_for_pdb_ids(pdb_ids)
    if not esm_inputs:
        # Nothing to embed; the batch converter cannot build an empty batch
        print("Done !")
        return [], missing_ids
    if batch_size is None:
        batch_size = len(esm_inputs)
    elif batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    # Read the final layer index from the model instead of hard-coding it (30 for t30)
    last_layer = model.num_layers

    print("Computing embeddings...")
    sequence_embeddings = []
    for start in range(0, len(esm_inputs), batch_size):
        batch = esm_inputs[start : start + batch_size]
        # This is the tokenizer used by ESM2
        _, _, batch_tokens = batch_converter(batch)
        batch_tokens = batch_tokens.to(device)

        # Last-layer token embeddings; return_contacts=False skips contact prediction
        with torch.no_grad():
            results = model(batch_tokens, repr_layers=[last_layer], return_contacts=False)
        token_embeddings = results["representations"][last_layer]

        # Mean pool each sequence over its residues only: position 0 is the BOS
        # token, and anything past len(seq) + 1 is EOS/padding
        for i, (_, seq) in enumerate(batch):
            sequence_embeddings.append(token_embeddings[i, 1 : len(seq) + 1].mean(0))
    print("Done !")
    return sequence_embeddings, missing_ids

train_esm_embeddings, missing_train_ids = get_esm_embedding_list(train_ids)
test_esm_embeddings, missing_test_ids = get_esm_embedding_list(test_ids)

print(len(train_esm_embeddings))
print(len(test_esm_embeddings))
print(train_esm_embeddings[0].shape)
print(f'Missing IDs: train: {len(missing_train_ids)}, test: {len(missing_test_ids)}')

Preprocessing data for ESM2...
Computing embeddings...
Done !
Preprocessing data for ESM2...
Computing embeddings...
Done !
3509
195
torch.Size([640])
Missing IDs: train: 0, test: 0


In [17]:
def reshape_rows_to_col(esm_embeddings):
    """Return each embedding reshaped to a 2-D ``(1, D)`` array.

    Despite the name, every result is a single *row*: ``reshape(1, -1)``
    flattens the input and adds a leading axis of size 1. The later
    ``torch.cat(..., dim=1)`` with the SMILES embeddings relies on this shape.
    """
    reshaped = []
    for embedding in esm_embeddings:
        reshaped.append(embedding.reshape(1, -1))
    return reshaped

# Give every ESM2 embedding a leading axis of size 1 to match the (1, 256) SMILES embeddings
train_esm_embeddings_reshaped, test_esm_embeddings_reshaped = map(
    reshape_rows_to_col, (train_esm_embeddings, test_esm_embeddings)
)

In [13]:
# Embed the ligand of every train/test complex with CuteSmileyBERT.
# No IDs are filtered out here: this relies on the ESM2 step having found a pocket
# for every ID. get_smiles_embeddings_list silently skips ligands it cannot load,
# and a skip would shift every later row out of alignment with the ESM2
# embeddings and the labels. So fail loudly if any ID was dropped.
train_smiles_embeddings = get_smiles_embeddings_list(train_ids)
test_smiles_embeddings = get_smiles_embeddings_list(test_ids)
if (len(train_smiles_embeddings) != len(train_ids)
        or len(test_smiles_embeddings) != len(test_ids)):
    raise ValueError(
        "Some ligands could not be embedded; SMILES embeddings would be "
        "misaligned with the ESM2 embeddings and labels"
    )

print(len(train_smiles_embeddings))
print(len(test_smiles_embeddings))
print(train_smiles_embeddings[0].shape)



3509
195
torch.Size([1, 256])


In [37]:
def concatenate_embeddings(esm_embeddings, smiles_embeddings):
    """Concatenate paired pocket and ligand embeddings along the feature axis.

    Args:
        esm_embeddings: sequence of 2-D tensors, one per complex (in this
            notebook the reshaped (1, 640) ESM2 pocket embeddings).
        smiles_embeddings: sequence of 2-D tensors aligned index-by-index with
            ``esm_embeddings`` (here the (1, 256) SMILES embeddings).

    Returns:
        A list of CPU tensors, each ``torch.cat((esm, smiles), dim=1)``.

    Raises:
        ValueError: if the two sequences have different lengths.
    """
    if len(esm_embeddings) != len(smiles_embeddings):
        raise ValueError('Warning: both embedding lists arent the same size')
    # Move both halves to the CPU: either may have been computed on the GPU,
    # and torch.cat needs all of its inputs on the same device
    return [
        torch.cat((esm.cpu(), smiles.cpu()), dim=1)
        for esm, smiles in zip(esm_embeddings, smiles_embeddings)
    ]


X_train = concatenate_embeddings(train_esm_embeddings_reshaped, train_smiles_embeddings)
X_test = concatenate_embeddings(test_esm_embeddings_reshaped, test_smiles_embeddings)
# Each row is [640-d ESM2 pocket | 256-d SMILES ligand] -> (1, 896).
# (The previous code printed the undefined name `train_embeddings`.)
print(X_train[0].shape)

torch.Size([1, 896])


In [39]:
# Labels come from the same preprocessed split file as the IDs
y_train, y_test = data['y_train'], data['y_test']

# Sanity check by eye: features and labels should line up one-to-one per split
for part in (X_train, y_train, X_test, y_test):
    print(len(part))

3509
3509
195
195


In [41]:
OUTPUT_PATH = "./data/interim/reg_embeddings.npz"
# Positional arrays are stored under numpy's default keys, in this order:
# arr_0 = X_train, arr_1 = X_test, arr_2 = y_train, arr_3 = y_test
arrays_to_save = (X_train, X_test, y_train, y_test)
np.savez(OUTPUT_PATH, *arrays_to_save)