In [18]:
# extract_smiles_from_sdf.py
from rdkit import Chem
import pandas as pd
from pathlib import Path

sdf_dir = Path("/data1/home/jw1017/GLAS_Dock/Data/ready_data/ligand_sdf_files")  # folder containing .sdf
records = []
for sdf_path in sdf_dir.glob("*.sdf"):
    suppl = Chem.SDMolSupplier(str(sdf_path), sanitize=True)
    mol = next((m for m in suppl if m is not None), None)
    if mol:
        smiles = Chem.MolToSmiles(mol)
        ligand_id = sdf_path.stem
        records.append({"ligand_id": ligand_id, "smiles": smiles})

df = pd.DataFrame(records)
df.to_csv("ligand_smiles.csv", index=False)
print(f"Extracted {len(df)} ligands.")


Extracted 5623 ligands.


In [21]:
# chemberta_embed.py
import torch
from transformers import AutoTokenizer, AutoModel
import pandas as pd
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tok = AutoTokenizer.from_pretrained("/data1/home/jw1017/GLAS_Dock/Data/embedding_models/chemberta_77M_MLM", local_files_only=True)
model = AutoModel.from_pretrained("/data1/home/jw1017/GLAS_Dock/Data/embedding_models/chemberta_77M_MLM", local_files_only=True).to(device).eval()

df = pd.read_csv("ligand_smiles.csv")
embeddings, ids = [], []

with torch.no_grad():
    for i, row in tqdm(df.iterrows(), total=len(df)):
        smi = row["smiles"]
        enc = tok(smi, return_tensors="pt", truncation=True, padding=True).to(device)
        out = model(**enc).last_hidden_state.mean(1)  # mean-pool tokens
        embeddings.append(out.cpu())
        ids.append('_'.join(row["ligand_id"].split('_')[:-1]))

emb = torch.cat(embeddings, dim=0)
torch.save({"ligand_ids": ids, "embeddings": emb}, "ligand_chemberta.pt")
print(f"Saved {len(ids)} embeddings, dim={emb.shape[1]}")

Some weights of RobertaModel were not initialized from the model checkpoint at /data1/home/jw1017/GLAS_Dock/Data/embedding_models/chemberta_77M_MLM and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 5623/5623 [00:15<00:00, 357.82it/s]


Saved 5623 embeddings, dim=384


In [22]:
graph = torch.load("ligand_chemberta.pt")

In [None]:
graph

{'ligand_ids': ['5uio_LG3_B_201',
  '6flu_DTK_B_401',
  '4jhh_H35_A_201',
  '7V17_OQO',
  '7KLX_WOV',
  '4nir_6DH_A_308',
  '5nh9_XLS_D_504',
  '8GEZ_ZMB',
  '8AO6_ANP',
  '6y4f_GLU_A_202',
  '8QBC_GEE',
  '7RGX_D9J',
  '4mig_G3F_B_802',
  '7SLZ_L6U',
  '5eo7_SFU_B_403',
  '5tch_MLI_C_301',
  '7WN5_JGL',
  '7A9E_R4W',
  '7HG3_X0S',
  '4ma8_Z80_C_301',
  '4p8c_Y22_B_502',
  '2pvy_ACP_A_802',
  '4a8p_PAO_C_1340',
  '8FO5_Y4Q',
  '2uyv_TLA_A_1277',
  '8CSD_C5P',
  '8BSA_DLY',
  '7F10_0BJ',
  '2zut_A2G_B_4002',
  '5c8w_PCG_D_302',
  '2dv0_ZST_A_600',
  '1LPZ_CMB',
  '7F1G_0QQ',
  '7BCH_UR4',
  '1z6k_OAA_A_274',
  '7N4Q_YD7',
  '2cjz_PTR_A_1537',
  '5je1_TOF_B_302',
  '7V11_OTL',
  '6o04_DNA_D_603',
  '6rgg_K3Q_B_401',
  '8V4G_CDP',
  '7UR3_N6M',
  '1T46_STI',
  '2pin_LEG_A_501',
  '4al9_GLA_C_1123',
  '4x36_CHT_A_410',
  '8BIN_QUF',
  '7SCP_8S6',
  '5eo7_SFU_C_403',
  '7V12_ORF',
  '2r4j_13P_A_1968',
  '4nkw_PLO_A_601',
  '5SGV_KI8',
  '7FDT_4I1',
  '7O0N_CDP',
  '1lzz_IHG_A_830',
  '5bkd_

: 

In [7]:
graph['embeddings'].shape

torch.Size([5623, 384])