In [1]:
import polars as pl
# Preprocessed BindingDB export; every later cell derives from this frame.
BINDINGDB_PARQUET = '../data/BindingDB_predprocessed/BindingDB_v0.parquet'
df = pl.read_parquet(BINDINGDB_PARQUET)

In [2]:
# Keep only ligand, IC50 label and target sequence, then drop any row with a
# null in one of those three columns (NOTE(review): NaN floats are not nulls
# in polars and would survive this filter — confirm none exist upstream).
dataic50 = df.select(["Ligand SMILES", "IC50 (nM)", "BindingDB Target Chain Sequence"]).drop_nulls()

In [3]:
# Preview the filtered frame (polars renders a truncated head/tail view).
dataic50 

Ligand SMILES,IC50 (nM),BindingDB Target Chain Sequence
str,f64,str
"""NS(=O)(=O)c1ccc(Nc2cc(OCC3CCCC…",29000.0,"""MSGRPRTTSFAESCKPVQQPSAFGSMKVSR…"
"""NS(=O)(=O)c1ccc(Nc2cc(OC3CCCCC…",190.0,"""MSGRPRTTSFAESCKPVQQPSAFGSMKVSR…"
"""NS(=O)(=O)c1ccc(Nc2cc(NC3CCCCC…",970.0,"""MSGRPRTTSFAESCKPVQQPSAFGSMKVSR…"
"""CCN(CC)c1cc(Nc2ccc(cc2)S(N)(=O…",11000.0,"""MSGRPRTTSFAESCKPVQQPSAFGSMKVSR…"
"""N[C@H]1CC[C@@H](CC1)Nc1cc(Nc2c…",780.0,"""MSGRPRTTSFAESCKPVQQPSAFGSMKVSR…"
…,…,…
"""O[C@@H]1CCCN(C1)C(=O)c1cccc(c1…",90.0,"""MSSWIRWHGPAMARLWGFCWLVVGFWRAAF…"
"""O[C@H]1CCCN(C1)C(=O)c1cccc(c1)…",118.0,"""MSSWIRWHGPAMARLWGFCWLVVGFWRAAF…"
"""COc1nc2ccc(Br)cc2cc1[C@@H](c1c…",1600.0,"""MPVRRGHVAPQNTFLDTIIRKFEGQSRKFI…"
"""COc1ccc(cc1)N(C)c1nc(C)nc2[nH]…",2600.0,"""CVSASPSTLARLVSRSAMPAGSSTAWNTAF…"


In [6]:
# Materialise the two text columns as plain Python lists for the tokenizers;
# both lists stay row-aligned with `dataic50`.
ligand_smiles, target_chain_sequence = (
    dataic50[column].to_list()
    for column in ('Ligand SMILES', 'BindingDB Target Chain Sequence')
)

In [7]:
def create_batches(data, batch_size):
    """Yield consecutive slices of ``data`` holding ``batch_size`` items each.

    The final slice may be shorter. ``data`` must support ``len()`` and
    slicing (list, tuple, str, ...). Each yielded batch has the same type as
    ``data``.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1 (raised when iteration
            starts, since this is a generator). Without the check a negative
            size would silently yield nothing, so the caller would write no
            output without any error.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]

In [8]:
# Choose the compute device once: CUDA when a GPU is visible, else the CPU.
# This cell only selects it; models and tokenized inputs are moved onto it
# in later cells.
import torch
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
from transformers import BertModel, BertTokenizerFast

# Ligand encoder: tokenizer and BERT weights from the same checkpoint, with
# the model placed on `device`. The final expression displays the module.
checkpoint = 'unikei/bert-base-smiles'
model_smiles = BertModel.from_pretrained(checkpoint)
tokenizer_smiles = BertTokenizerFast.from_pretrained(checkpoint)
model_smiles.to(device)

In [9]:
from transformers import AutoModel, AutoTokenizer

# Protein encoder: ESM-2 checkpoint loaded through the Auto* classes and
# placed on `device`. The "pooler ... newly initialized" warning this prints
# does not affect the embeddings, which are built from `last_hidden_state`,
# not from the pooler output.
esm_checkpoint = "facebook/esm2_t6_8M_UR50D"
tokenizer_protein = AutoTokenizer.from_pretrained(esm_checkpoint)
model_protein = AutoModel.from_pretrained(esm_checkpoint)
model_protein.to(device)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


EsmModel(
  (embeddings): EsmEmbeddings(
    (word_embeddings): Embedding(33, 320, padding_idx=1)
    (dropout): Dropout(p=0.0, inplace=False)
    (position_embeddings): Embedding(1026, 320, padding_idx=1)
  )
  (encoder): EsmEncoder(
    (layer): ModuleList(
      (0-5): 6 x EsmLayer(
        (attention): EsmAttention(
          (self): EsmSelfAttention(
            (query): Linear(in_features=320, out_features=320, bias=True)
            (key): Linear(in_features=320, out_features=320, bias=True)
            (value): Linear(in_features=320, out_features=320, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (rotary_embeddings): RotaryEmbedding()
          )
          (output): EsmSelfOutput(
            (dense): Linear(in_features=320, out_features=320, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (LayerNorm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        )
        (intermediate): EsmIntermediate(
    

In [10]:
import torch
def generate_embeddings_bert_based_models(batch, tokenizer, model, device):
    """Embed a batch of strings with the hidden state of the first token.

    The batch is tokenized with padding and truncation and run without
    gradients. For every row, the last-layer vector at position 0 (the
    leading special token) is returned as a list of floats.
    """
    encoded = tokenizer(batch, return_tensors='pt', padding=True, truncation=True)
    encoded = encoded.to(device)
    with torch.no_grad():
        hidden_states = model(**encoded).last_hidden_state
    first_token_vectors = hidden_states[:, 0, :]
    return first_token_vectors.cpu().tolist()

def generate_embeddings(batch, tokenizer, model,device, max_length=None):
    """Mean-pool last-layer hidden states over each string's real tokens.

    The batch is tokenized with padding and truncation and run without
    gradients. For every row, the hidden states are averaged over the
    positions that are neither padding nor the leading/trailing special
    token. The result is one list of floats per input string.

    Pooling uses ``attention_mask`` rather than a fixed ``[:, 1:-1]`` slice.
    With ``padding=True`` a fixed slice would average pad-token states into
    every shorter row and keep that row's trailing special token, so a
    sequence's embedding would depend on which other sequences share its
    batch.

    NOTE(review): this assumes the tokenizer adds exactly one special token
    at the start and one at the end of each row (the same assumption the
    original ``1:-1`` slice made). Confirm this if you swap tokenizers.

    Args:
        batch: list of input strings.
        tokenizer: callable returning a mapping with ``attention_mask`` that
            supports ``.to(device)`` (e.g. a Hugging Face tokenizer).
        model: module whose output exposes ``last_hidden_state``.
        device: device the tokenized inputs are moved to.
        max_length: optional truncation length forwarded to the tokenizer.
            Use it for tokenizers that define no ``model_max_length``, where
            ``truncation=True`` alone does nothing. ``None`` keeps the
            tokenizer default.

    Returns:
        list[list[float]]: one pooled vector per string. A row with no real
        tokens yields a zero vector instead of NaN.
    """
    tokenizer_kwargs = {'return_tensors': 'pt', 'padding': True, 'truncation': True}
    if max_length is not None:
        tokenizer_kwargs['max_length'] = max_length
    inputs = tokenizer(batch, **tokenizer_kwargs).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    hidden = outputs.last_hidden_state

    attention_mask = inputs['attention_mask']
    # Index of each row's last non-padding token (its trailing special token).
    last_real = (attention_mask.sum(dim=1) - 1).long()
    token_mask = attention_mask.clone().to(hidden.dtype)
    token_mask[:, 0] = 0  # leading special token
    token_mask[torch.arange(token_mask.size(0), device=token_mask.device), last_real] = 0
    token_mask = token_mask.unsqueeze(-1)

    summed = (hidden * token_mask).sum(dim=1)
    counts = token_mask.sum(dim=1).clamp(min=1)
    embeddings = (summed / counts).cpu().numpy().tolist()
    return embeddings

In [11]:
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from tqdm import tqdm

# Arrow schema shared by every embeddings file: the input string itself is
# the row ID, and `encoding` holds its embedding as a list of float64.
schema = pa.schema([
    pa.field('ID', pa.string()),
    pa.field('encoding', pa.list_(pa.float64())),
])

# Universal function to write embeddings to Parquet
def write_embeddings_to_parquet(data, batch_size, parquet_file, generate_embedding_func, tokenizer, model, device, deduplicate=False):
    """Embed ``data`` batch by batch and stream the results into a Parquet file.

    Each output row holds one input string (column ``ID``) and its embedding
    (column ``encoding``), typed by the module-level ``schema``.

    Args:
        data: sliceable sequence of input strings (e.g. SMILES or protein
            sequences).
        batch_size: number of items per call to ``generate_embedding_func``.
        parquet_file: destination path (or writable file object). When it is
            a path, missing parent directories are created.
        generate_embedding_func: callable ``(batch, tokenizer, model, device)``
            returning one embedding (list of floats) per item of ``batch``.
        tokenizer, model, device: forwarded unchanged to
            ``generate_embedding_func``.
        deduplicate: if True, embed each distinct input only once, keeping
            first-occurrence order. Defaults to False, which keeps one output
            row per input.
    """
    import os

    if deduplicate:
        data = list(dict.fromkeys(data))

    # Create the output directory up front so a long run doesn't die on its first write.
    if isinstance(parquet_file, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(parquet_file))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    num_items = len(data)
    total_batches = (num_items + batch_size - 1) // batch_size

    # The context manager closes the writer (writing the Parquet footer) even
    # if embedding fails mid-run, so batches already written stay readable.
    # Opening it before the loop also turns empty input into a valid empty file.
    with pq.ParquetWriter(parquet_file, schema) as writer:
        for batch in tqdm(create_batches(data, batch_size), total=total_batches, desc="Processing batches"):
            embeddings = generate_embedding_func(batch, tokenizer, model, device)
            table = pa.Table.from_pydict({'ID': list(batch), 'encoding': embeddings}, schema=schema)
            writer.write_table(table)

            # Clear CUDA cache to free up memory
            torch.cuda.empty_cache()

In [12]:
# Items per forward pass, and one output Parquet file per encoder.
batch_size = 8
parquet_file_smiles, parquet_file_protein = (
    '../data/embeddings/ligand_embeddings.parquet',
    '../data/embeddings/protein_embeddings.parquet',
)

In [None]:
# Embed every ligand SMILES with the BERT encoder and stream the vectors to disk.
write_embeddings_to_parquet(
    data=ligand_smiles,
    batch_size=batch_size,
    parquet_file=parquet_file_smiles,
    generate_embedding_func=generate_embeddings,
    tokenizer=tokenizer_smiles,
    model=model_smiles,
    device=device,
)

In [13]:
# Embed every target chain sequence with the ESM-2 encoder and stream the vectors to disk.
write_embeddings_to_parquet(
    data=target_chain_sequence,
    batch_size=batch_size,
    parquet_file=parquet_file_protein,
    generate_embedding_func=generate_embeddings,
    tokenizer=tokenizer_protein,
    model=model_protein,
    device=device,
)

Processing batches:   0%|          | 0/210975 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Processing batches: 100%|██████████| 210975/210975 [4:07:26<00:00, 14.21it/s]   
