# To Reproduce the results for CSC413 Final Project:
Prior to this step, please execute `download_weights.py` to download the weights for all pre-trained models.

NOTE: THIS NOTEBOOK WAS ADAPTED FROM THE ORIGINAL LABIND NOTEBOOK FOR THIS PROJECT. MUCH OF THE CODE IS NOT ORIGINAL; IT WAS TAKEN DIRECTLY FROM LABIND AND REPURPOSED.  

In [None]:
# import packages

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # prevents dssp from outputing warnings

from torch.utils.data import DataLoader
import readData
import shutil
from utils import *
import torch
from torch import nn
from model import LABind
from func_help import setALlSeed,get_std_opt
from tqdm import tqdm
import pickle as pkl
from sklearn.model_selection import KFold
import gc

In [None]:
# config
# NOTE: the ESM2 feature cell further down rebinds DEVICE to a plain string ("cuda"/"cpu").
DEVICE = torch.device('cuda:0')
root_path = getRootPath()
dataset = 'LigBind' # DS1:LigBind, DS2:GPSite, DS3:Unseen

# Load the ligand dictionary through a context manager so the file handle is closed.
# NOTE(review): unpickling runs arbitrary code; only load a ligand.pkl from a trusted source.
with open(f'{root_path}/tools/ligand.pkl', 'rb') as _lig_file:
    _lig_dict = pkl.load(_lig_file)

nn_config = {
    # dataset 
    'train_file': f'{root_path}/{dataset}/label/train/train.fa',
    'test_file': f'{root_path}/{dataset}/label/test/test.fa',
    'valid_file': f'{root_path}/{dataset}/label/picking.fa',
    'proj_dir': f'{root_path}/{dataset}',
    'lig_dict': _lig_dict,
    'pdb_class':'source', # source or omegafold or esmfold
    # Per-dimension min/max arrays for each feature family
    # (presumably used for min-max normalisation by the data reader -- TODO confirm).
    'dssp_max_repr': np.load(f'{root_path}/tools/dssp_max_repr.npy'),
    'dssp_min_repr': np.load(f'{root_path}/tools/dssp_min_repr.npy'),
    'ankh_max_repr': np.load(f'{root_path}/tools/ankh_max_repr.npy'),
    'ankh_min_repr': np.load(f'{root_path}/tools/ankh_min_repr.npy'),
    'esm2_max_repr': np.load(f'{root_path}/tools/esm2_max_repr.npy'),
    'esm2_min_repr': np.load(f'{root_path}/tools/esm2_min_repr.npy'),
    'ion_max_repr': np.load(f'{root_path}/tools/ion_max_repr.npy'),
    'ion_min_repr': np.load(f'{root_path}/tools/ion_min_repr.npy'),
    # model parameters
    
    # presumably 2560 ESM2 dims + 20 DSSP dims -- TODO confirm against readData
    'rfeat_dim':2580,
    'ligand_dim':768, 
    'hidden_dim':256, 
    'heads':4, 
    'augment_eps':0.05, 
    'rbf_num':8, 
    'top_k':30, 
    'attn_drop':0.1, 
    'dropout':0.1, 
    'num_layers':4, 
    'lr':0.00002, # a lower learning rate may perform even better
    
    # training parameters 
    # You can modify it according to the actual situation. 
    # Since it involves mapping the entire protein, it will consume a large amount of GPU memory.
    'batch_size':1, # due to insufficient CUDA memory
    'max_patience':10, # epochs without validation improvement before training stops
    'device_ids':[0] # GPU ids handed to nn.DataParallel
}
pretrain_path = { # Please modify 
    'esmfold_path': '../tools/esmfold_v1', # esmfold path
    'esm2_path': '../tools/esm2', 
    'ankh_path': '../tools/ankh-large/', # ankh path
    'molformer_path': '../tools/MoLFormer-XL-both-10pct/', # molformer path
    'model_path':f'{root_path}/model/LigBind/' 
}

In [3]:
# Sanity check: list the configuration entries defined in the config cell.
print(nn_config.keys())

dict_keys(['train_file', 'test_file', 'valid_file', 'proj_dir', 'lig_dict', 'pdb_class', 'dssp_max_repr', 'dssp_min_repr', 'ankh_max_repr', 'ankh_min_repr', 'esm2_max_repr', 'esm2_min_repr', 'ion_max_repr', 'ion_min_repr', 'rfeat_dim', 'ligand_dim', 'hidden_dim', 'heads', 'augment_eps', 'rbf_num', 'top_k', 'attn_drop', 'dropout', 'num_layers', 'lr', 'batch_size', 'max_patience', 'device_ids'])


### Download dataset
Download the file from https://zenodo.org/records/13938443 and place it in the root directory.


### Retrieve the features

#### ESM2 
(This version needs more VRAM; see `esm_embed.ipynb` for the version that was actually used in the project.)

In [None]:
import numpy as np
from Bio import SeqIO
from transformers import AutoTokenizer, AutoModel 


# NOTE(review): MODEL_DIR, BATCH_SIZE, FASTA_ROOT and SAVE_ROOT are not used by this cell
# (the model is loaded from pretrain_path['esm2_path'] and sequences are embedded one at a
# time); they are kept in case other code reads them.
MODEL_DIR = "../tools/esm2"   # <-- directory containing HF files
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 4
FASTA_ROOT = "fasta"             # root directory containing subdirs of FASTA files
SAVE_ROOT = "embeddings"         # root output folder
out_path = f"{root_path}/{dataset}/esm/"
fasta_path = f"{root_path}/{dataset}/fasta/"

os.makedirs(out_path, exist_ok=True)

tokenizer = AutoTokenizer.from_pretrained(pretrain_path['esm2_path'])
model     = AutoModel.from_pretrained(pretrain_path['esm2_path'])
model.to(DEVICE)

model.eval()
model.gradient_checkpointing_enable()
print("Model loaded on", DEVICE)
for file_class in os.listdir(fasta_path):
    class_path = os.path.join(fasta_path, file_class)
    # Only descend into class sub-folders; a stray file here would make os.listdir raise.
    if not os.path.isdir(class_path):
        continue

    for fasta_file in os.listdir(class_path):
        fasta_path_full = os.path.join(class_path, fasta_file)

        sequences = list(SeqIO.parse(fasta_path_full, "fasta"))

        print(f"\nProcessing {fasta_file} ({len(sequences)} sequences)")

        for record in tqdm(sequences):
            save_path = os.path.join(out_path, f"{record.id}.npy")

            # Skip if already processed
            if os.path.exists(save_path):
                continue

            seq = str(record.seq)

            # Tokenize for HF model
            encoded = tokenizer(
                seq,
                return_tensors="pt",
                padding=False,
                add_special_tokens=True
            )

            encoded = {k: v.to(DEVICE) for k, v in encoded.items()}

            # Run model on GPU
            with torch.no_grad():
                output = model(**encoded)
            hidden = output.last_hidden_state.cpu()      # shape: (1, L, D)
            # Drop the leading special token and keep one row per residue.
            embedding = hidden[0, 1:len(seq)+1].numpy()

            # Save atomically: write to a temporary file, then rename. Because existing
            # files are skipped above, a run interrupted in the middle of a plain np.save
            # would leave a truncated .npy behind that is never regenerated.
            # The temporary name deliberately does not end in ".npy" so that later cells
            # scanning this folder for embeddings ignore leftovers.
            tmp_path = save_path + ".tmp"
            with open(tmp_path, "wb") as tmp_file:
                np.save(tmp_file, embedding)
            os.replace(tmp_path, save_path)
            torch.cuda.empty_cache()

        

# Cleanup
del model
gc.collect()
# Release cached GPU memory once the model is no longer referenced.
torch.cuda.empty_cache()
print("\n✓ Done extracting ESM embeddings!")


The following block generates the esm2_min_repr and esm2_max_repr files. 

In [None]:

from tqdm import tqdm

# NOTE(review): these are absolute, machine-specific paths; adjust them before re-running.
esm_dir = "/virtual/zengzix4/LABind_ESM/LigBind/esm3B"    
# NOTE(review): save_file is not used by this cell.
save_file = "esm2_repr_stats.npy"
save_dir = "/virtual/zengzix4/LABind_ESM/tools/"
# Use None so we can initialize on first file

all_min = None
all_max = None

for f in os.listdir(esm_dir):
    if not f.endswith(".npy"):
        continue

    fpath = os.path.join(esm_dir, f)

    try:
        # NOTE(review): allow_pickle=True can execute code from a malicious file;
        # only point this at embeddings you generated yourself.
        arr = np.load(fpath, allow_pickle=True)
    except Exception as e:
        print("Skipping bad file:", f, "Error:", e)
        continue

    # Convert object arrays to real numpy if possible
    if isinstance(arr, np.ndarray) and arr.dtype == object:
        try:
            arr = np.vstack(arr)  # many ESM embeddings come as list of arrays
        except Exception:
            # Still best-effort, but no longer swallows KeyboardInterrupt/SystemExit.
            print("Skipping malformed object array:", f)
            continue

    # Now check shape
    if arr.ndim != 2:
        print("Skipping wrong shape", f, arr.shape)
        continue

    # Compute per-dimension min/max
    fmin = arr.min(axis=0)
    fmax = arr.max(axis=0)

    if all_min is None:
        all_min = fmin
        all_max = fmax
    else:
        all_min = np.minimum(all_min, fmin)
        all_max = np.maximum(all_max, fmax)

# Without this guard, np.save would silently write None as the statistics.
if all_min is None:
    raise RuntimeError(f"No usable 2-D .npy embeddings found in {esm_dir}")

print("FINAL ESM MIN SHAPE:", all_min.shape)
print("FINAL ESM MAX SHAPE:", all_max.shape)

np.save(os.path.join(save_dir, "esm2_min_repr.npy"), all_min)
np.save(os.path.join(save_dir, "esm2_max_repr.npy"), all_max)


Skipping bad file: 6s9fB.npy Error: Failed to interpret file '/virtual/zengzix4/LABind_ESM/LigBind/esm3B/6s9fB.npy' as a pickle
Skipping bad file: 6ptkA.npy Error: cannot reshape array of size 1138656 into shape (471,2560)
FINAL ESM MIN SHAPE: (2560,)
FINAL ESM MAX SHAPE: (2560,)


Note: in our run, two FASTA sequences produced invalid embeddings (see the two "Skipping bad file" messages above: 6s9fB and 6ptkA), so they were manually removed from test.fa and from the corresponding class .fa files. 

#### Ankh (unused, kept for reference)

In [3]:
# Show which Python interpreter (and therefore which environment) the kernel is running.
import sys
print(sys.executable)

/virtual/zengzix4/miniconda/envs/LABind/bin/python


In [None]:
# get ankh features
# Embeds every sequence of every FASTA file under <dataset>/fasta/<class>/ with the Ankh
# encoder and stores one <record.id>.npy per sequence in <dataset>/ankh/.
from transformers import AutoTokenizer, T5EncoderModel 
from Bio import SeqIO
tokenizer = AutoTokenizer.from_pretrained(pretrain_path['ankh_path'])
model     = T5EncoderModel.from_pretrained(pretrain_path['ankh_path'])
model.to(DEVICE)
model.eval()
out_path = f"{root_path}/{dataset}/ankh/"
makeDir(out_path)
# Read the FASTA files with Biopython
fasta_path = f"{root_path}/{dataset}/fasta/"
for file_class in os.listdir(fasta_path):
    class_dir = f"{fasta_path}/{file_class}"
    for fasta_base_file in os.listdir(class_dir):
        fasta_file = f"{class_dir}/{fasta_base_file}"
        for record in tqdm(SeqIO.parse(fasta_file, "fasta")):
            npy_file = out_path+f'{record.id}.npy'
            # Embeddings that already exist on disk are not recomputed.
            if os.path.exists(npy_file):
                continue
            residues = list(record.seq)
            ids = tokenizer.batch_encode_plus(
                [residues],
                add_special_tokens=True,
                padding=True,
                is_split_into_words=True,
                return_tensors="pt",
            )
            with torch.no_grad():
                embedding_repr = model(
                    input_ids=ids['input_ids'].to(DEVICE),
                    attention_mask=ids['attention_mask'].to(DEVICE),
                )
            # Keep the first len(sequence) positions of the single batch item.
            emb = embedding_repr.last_hidden_state[0,:len(residues)].cpu().numpy()
            np.save(npy_file, emb)
del model
gc.collect()
torch.cuda.empty_cache()

#### DSSP

This was changed from the given notebook to parse through subfolders.

In [None]:
from Bio.PDB import PDBParser
from Bio.PDB.DSSP import DSSP

# One-hot encoding (9 slots) of the DSSP secondary-structure code; ' ' maps to all zeros.
mapSS = {' ':[0,0,0,0,0,0,0,0,0],
        '-':[1,0,0,0,0,0,0,0,0],
        'H':[0,1,0,0,0,0,0,0,0],
        'B':[0,0,1,0,0,0,0,0,0],
        'E':[0,0,0,1,0,0,0,0,0],
        'G':[0,0,0,0,1,0,0,0,0],
        'I':[0,0,0,0,0,1,0,0,0],
        'P':[0,0,0,0,0,0,1,0,0],
        'T':[0,0,0,0,0,0,0,1,0],
        'S':[0,0,0,0,0,0,0,0,1]}
p = PDBParser(QUIET=True)
pdb_path = f"{root_path}/{dataset}/pdb/"
dssp_path = "../tools/mkdssp"
pdb_class = nn_config['pdb_class']
# NOTE(review): this creates <dataset>/<pdb_class>_dssp/, but the loop below writes to
# <dataset>/pdb/<pdb_class>_dssp/<first sub-folder>/ -- confirm which location the data
# reader expects.
makeDir(f"{root_path}/{dataset}/{pdb_class}_dssp/")
test_files = os.listdir(pdb_path)  # not used further in this cell
pdb_files = []

# Collect every .pdb file below pdb_path, including those in sub-folders.
for dirpath, dirnames, filenames in os.walk(pdb_path):
    for f in filenames:
        if f.endswith('.pdb'):
            pdb_files.append(os.path.join(dirpath, f)) 
            
for pdb_file_name in tqdm(pdb_files, desc='DSSP running',ncols=80,unit='proteins'):
    # The first path component below pdb_path selects the output sub-folder.
    rel = os.path.relpath(pdb_file_name, pdb_path)
    first_fol = rel.split(os.sep)[0]

    save_dir = os.path.join(pdb_path, f"{pdb_class}_dssp", first_fol)
    os.makedirs(save_dir, exist_ok=True)
    filename = os.path.basename(pdb_file_name).replace('.pdb','.npy')
    save_file = os.path.join(save_dir, filename)
    if os.path.exists(save_file):
        continue
    structure = p.get_structure("tmp", pdb_file_name)
    model = structure[0]
    try:
        dssp = DSSP(model, pdb_file_name, dssp=dssp_path)
        # A set gives O(1) membership tests in the per-residue loop below.
        keys = set(dssp.keys())
    except Exception as err:
        # Best effort: the protein still gets a feature file, but every residue is all-zero.
        # Report it so failures are not mistaken for real features.
        tqdm.write(f"DSSP failed for {pdb_file_name}: {err}; writing all-zero features")
        dssp = None
        keys = set()
    res_np = []
    for chain in model:
        for residue in chain:
            # DSSP results are looked up with a blank hetero flag.
            res_key = (chain.id,(' ', residue.id[1], residue.id[2]))
            if res_key in keys:
                tuple_dssp = dssp[res_key]
                # 9 one-hot secondary-structure slots followed by the remaining DSSP values.
                res_np.append(mapSS[tuple_dssp[2]] + list(tuple_dssp[3:]))
            else:
                # Residues without a DSSP entry get a 20-dimensional zero vector.
                res_np.append(np.zeros(20))
    np.save(save_file, np.array(res_np))

DSSP running: 100%|███████████████| 11121/11121 [1:22:02<00:00,  2.26proteins/s]


#### Position
Same as DSSP. Mostly the same as the original notebook, but changed for DS1. 

In [None]:
from Bio.PDB import PDBParser
from Bio.PDB.ResidueDepth import get_surface
from scipy.spatial import cKDTree

pdb_path = f"{root_path}/{dataset}/pdb/"
msms_path = "../tools/msms"
pdb_class = nn_config['pdb_class']
# NOTE(review): this creates <dataset>/<pdb_class>_pos/, but the loop below writes to
# <dataset>/pdb/<pdb_class>_pos/<first sub-folder>/ -- confirm which location the data
# reader expects.
makeDir(f"{root_path}/{dataset}/{pdb_class}_pos/")
pdb_files = []

# Collect every .pdb file below pdb_path, including those in sub-folders.
for dirpath, dirnames, filenames in os.walk(pdb_path):
    for f in filenames:
        if f.endswith('.pdb'):
            pdb_files.append(os.path.join(dirpath, f)) 

# Loop invariants: one parser and the backbone atom names.
parser = PDBParser(QUIET=True)
chain_atom = ['N', 'CA', 'C', 'O']

for pdb_file in tqdm(pdb_files,desc='MSMS running',ncols=80,unit='proteins'):

    # Fix: derive the output path from the current file. The previous code used the
    # unrelated name `pdb_file_name`, so every protein resolved to the same output path.
    rel = os.path.relpath(pdb_file, pdb_path)
    first_fol = rel.split(os.sep)[0]

    save_dir = os.path.join(pdb_path, f"{pdb_class}_pos", first_fol)
    os.makedirs(save_dir, exist_ok=True)
    filename = os.path.basename(pdb_file).replace('.pdb','.npy')
    save_file = os.path.join(save_dir, filename)
    if os.path.exists(save_file):
        continue
    X = []
    model = parser.get_structure('model', pdb_file)[0]
    # Only the first chain of the first model is processed.
    chain = next(model.get_chains())
    try:
        surf = get_surface(chain,MSMS=msms_path)
        surf_tree = cKDTree(surf)
    except Exception as err:
        # Best effort: without a surface, fall back to each residue's last atom below.
        tqdm.write(f"MSMS surface failed for {pdb_file}: {err}; using last-atom fallback")
        surf = np.empty(0)
        surf_tree = None
    for residue in chain:
        line = []
        atoms_coord = np.array([atom.get_coord() for atom in residue])
        # Fix: the test was inverted -- the surface tree was queried only when the
        # surface was empty (i.e. when no tree had been built for this protein).
        if surf.size != 0:
            dist, _ = surf_tree.query(atoms_coord)
            closest_atom = np.argmin(dist)
            closest_pos = atoms_coord[closest_atom]
        else:
            closest_pos = atoms_coord[-1]
        atoms = list(residue.get_atoms())
        try:
            ca_pos = residue['CA'].get_coord()
        except KeyError:
            # Residues without a CA atom are reported and left out of the output.
            print(residue)
            continue

        pos_s = 0
        un_s = 0
        for atom in atoms:
            if atom.name in chain_atom:
                line.append(atom.get_coord())
            else:
                # Accumulate the two calMass results over the non-backbone atoms
                # (presumably mass-weighted position and plain mass -- TODO confirm in utils).
                pos_s += calMass(atom,True)
                un_s += calMass(atom,False)
        # line should hold 4 backbone entries here; pad missing ones with the CA position
        if len(line) != 4:
            line = line + [list(ca_pos)]*(4-len(line))
        if un_s == 0:
            # No non-backbone contribution: use the CA position instead.
            R_pos = ca_pos
        else:
            R_pos = pos_s / un_s
        line.append(R_pos)  
        line.append(closest_pos) # append the residue atom closest to the surface
        X.append(line) 

    np.save(save_file, X)

### Train

In [4]:
def valid(model, valid_list, fold_idx):
    """Return the average-precision (AUPR) score of `model` on `valid_list`.

    The model is moved to DEVICE and switched to eval mode. Scores and labels are
    collected only at positions where the batch mask equals 1. `fold_idx` is accepted
    but not used here.
    """
    model.to(DEVICE)
    model.eval()
    # If 5-fold cross-validation is not used, true_file needs to be changed to valid_file.
    eval_set = readData.readData(
        name_list=valid_list,
        proj_dir=nn_config['proj_dir'],
        lig_dict=nn_config['lig_dict'],
        true_file=nn_config['train_file'],
        mode='train',
    )
    eval_loader = DataLoader(
        eval_set,
        batch_size=nn_config['batch_size'],
        shuffle=True,
        collate_fn=eval_set.collate_fn,
        num_workers=5,
    )
    scores = []
    labels = []
    with torch.no_grad():
        for batch in eval_loader:
            rfeat, ligand, xyz, mask, y_true = (item.to(DEVICE) for item in batch)
            keep = mask == 1
            probs = model(rfeat, ligand, xyz, mask).sigmoid()  # [N]
            scores.extend(torch.masked_select(probs, keep).cpu().detach().numpy())
            labels.extend(torch.masked_select(y_true, keep).cpu().detach().numpy())
    # The caller uses this AUPR value for early stopping.
    return average_precision_score(labels, scores)

def train(train_list,valid_list=None,model=None,epochs=50,fold_idx=None,save_dir=None):
    """Train `model` on `train_list`, early-stopping on validation AUPR.

    Args:
        train_list: training entries passed to readData.readData as name_list.
        valid_list: validation entries; when None, no validation is run, nothing is
            printed per epoch and no checkpoint is written.
        model: the network to train; it is moved to DEVICE.
        epochs: maximum number of epochs.
        fold_idx: index used in the checkpoint file name (fold{fold_idx}.ckpt).
        save_dir: directory for the checkpoint. Defaults to
            {root_path}/Output/{dataset}_5fold/, the location that was previously
            hard-coded. The directory is created if it does not exist.

    The checkpoint is rewritten whenever the validation AUPR improves; training stops
    after nn_config['max_patience'] consecutive epochs without improvement.
    """
    if save_dir is None:
        save_dir = f'{root_path}/Output/{dataset}_5fold/'
    model.to(DEVICE)
    train_data = readData.readData(
        name_list=train_list, 
        proj_dir=nn_config['proj_dir'], 
        lig_dict=nn_config['lig_dict'],
        true_file=nn_config['train_file'], mode='train')
    train_loader = DataLoader(train_data, batch_size=nn_config['batch_size'],shuffle=True, collate_fn=train_data.collate_fn, num_workers=5)
    # Element-wise loss so that masked-out positions can be zeroed before averaging.
    loss_fn = nn.BCELoss(reduction='none')
    optimizer = get_std_opt(len(train_list),nn_config['batch_size'], model.parameters(), nn_config['hidden_dim'], nn_config['lr'])
    v_max_aupr = 0
    patience = 0
    # With DataParallel the saved state-dict keys carry a "module." prefix; on a single
    # GPU they do not, so loaders have to cope with both forms.
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model,device_ids=nn_config['device_ids'])
    for epoch in range(epochs):
        all_loss = 0
        all_cnt = 0
        model.train()
        for rfeat, ligand, xyz,  mask, y_true in tqdm(train_loader):
            tensors = [rfeat, ligand, xyz,  mask, y_true]
            tensors = [tensor.to(DEVICE) for tensor in tensors]
            rfeat, ligand, xyz, mask, y_true = tensors
            optimizer.zero_grad()
            logits = model(rfeat, ligand, xyz, mask).sigmoid() # [N]
            # Average the loss over the unmasked positions only.
            loss = loss_fn(logits, y_true) * mask
            loss = loss.sum() / mask.sum()
            all_loss += loss.item()
            all_cnt += 1
            loss.backward()
            # Clip the gradient norm to 1.0 before the optimizer step.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        # Early stopping based on the validation AUPR.
        if valid_list is not None:
            v_aupr = valid(model,valid_list, fold_idx)
            print(f'Epoch {epoch} Loss: {all_loss / all_cnt}', f'Epoch Valid {epoch} AUPR: {v_aupr}')
            if v_aupr > v_max_aupr:
                v_max_aupr = v_aupr
                patience = 0
                # Create the directory first so a missing folder cannot discard a trained epoch.
                os.makedirs(save_dir, exist_ok=True)
                torch.save(model.state_dict(), os.path.join(save_dir, f'fold{fold_idx}.ckpt'))
            else:
                patience += 1
            if patience >= nn_config['max_patience']:
                break


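The following block was used for the first model, with a learning rate of 0.0004. (Note that `nn_config['lr']` in the config cell above is currently set to 0.00002, so change it before re-running this block to reproduce the first model.)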

In [None]:
setALlSeed(11)
# 5-fold cross-validation: one freshly initialised LABind model is trained per fold.
kf = KFold(n_splits=5, shuffle=True, random_state = 42)
data_list = readDataList(f'{root_path}/{dataset}/label/train/train.fa',skew=1)
makeDir(f'{root_path}/Output/{dataset}_5fold/')
torch.cuda.empty_cache()
gc.collect()
# nn_config entries that are forwarded unchanged to the LABind constructor.
model_keys = (
    'rfeat_dim', 'ligand_dim', 'hidden_dim', 'heads', 'augment_eps',
    'rbf_num', 'top_k', 'attn_drop', 'dropout', 'num_layers',
)
fold_idx = 0
for train_idx, valid_idx in kf.split(data_list):
    train_list = [data_list[pos] for pos in train_idx]
    valid_list = [data_list[pos] for pos in valid_idx]
    model = LABind(**{key: nn_config[key] for key in model_keys})
    train(train_list, valid_list, model, epochs=70, fold_idx=fold_idx)
    fold_idx += 1

This block was used for the second model, and the learning rate was decreased to 0.00002

In [None]:
from sklearn.model_selection import train_test_split

all_data = readDataList(f'{root_path}/{dataset}/label/train/train.fa', skew=1) 
torch.cuda.empty_cache()
gc.collect()

train_list, valid_list = train_test_split( ## no given validation in LigBind
    all_data,
    test_size=0.1,       # 10% validation
    random_state=42,
    shuffle=True
)
# train() writes its checkpoint to Output/{dataset}_5fold/fold{fold_idx}.ckpt by default,
# whereas the threshold and test cells read Output/{dataset}/fold0.ckpt. Make sure the
# directory train() writes to exists, and copy the result to where it is read afterwards.
# NOTE: this overwrites fold0.ckpt of an earlier 5-fold run in Output/{dataset}_5fold/.
single_model_dir = f'{root_path}/Output/{dataset}/'
train_ckpt_dir = f'{root_path}/Output/{dataset}_5fold/'
makeDir(single_model_dir)
os.makedirs(train_ckpt_dir, exist_ok=True)
model = LABind(
rfeat_dim=nn_config['rfeat_dim'], ligand_dim=nn_config['ligand_dim'], hidden_dim=nn_config['hidden_dim'], heads=nn_config['heads'], augment_eps=nn_config['augment_eps'], 
rbf_num=nn_config['rbf_num'],top_k=nn_config['top_k'], attn_drop=nn_config['attn_drop'], dropout=nn_config['dropout'], num_layers=nn_config['num_layers'])
train(train_list, valid_list, model, epochs=50, fold_idx=0)

# Copy the best checkpoint to the directory the later cells load from.
trained_ckpt = os.path.join(train_ckpt_dir, 'fold0.ckpt')
if os.path.exists(trained_ckpt):
    shutil.copy(trained_ckpt, os.path.join(single_model_dir, 'fold0.ckpt'))

In [None]:
# Determine the best threshold for MCC based on the validation set.
# NOTE: valid_list is whatever the most recently run training cell left behind; after the
# 5-fold cell that is the validation split of the last fold only.

from collections import OrderedDict

model_path = f'{root_path}/Output/{dataset}/' # if 5-fold cross-validation, {dataset}_5fold
print(model_path)
print(nn_config['pdb_class'])

# Must match the number of fold*.ckpt files in model_path (same setting as the test cell).
num_folds = 1 # if 5-fold cross-validation, set to 5

models = []
for fold in range(num_folds):
    state_dict = torch.load(model_path + 'fold%s.ckpt'%fold,'cuda:0')
    model = LABind(
        rfeat_dim=nn_config['rfeat_dim'], ligand_dim=nn_config['ligand_dim'], hidden_dim=nn_config['hidden_dim'], heads=nn_config['heads'], augment_eps=nn_config['augment_eps'], 
        rbf_num=nn_config['rbf_num'],top_k=nn_config['top_k'], attn_drop=nn_config['attn_drop'], dropout=nn_config['dropout'], num_layers=nn_config['num_layers']).to(DEVICE)
    model = nn.DataParallel(model,device_ids=nn_config['device_ids'])

    # Training only wraps the model in DataParallel when more than one GPU is visible, so
    # checkpoints may or may not carry the "module." key prefix that the wrapper expects.
    # Normalise the keys per fold, as the test cell does.
    new_state_dict = OrderedDict(
        (k if k.startswith("module.") else "module." + k, v)
        for k, v in state_dict.items()
    )
    model.load_state_dict(new_state_dict)
    model.eval()
    models.append(model)
    
valid_data = readData.readData(
    name_list=valid_list, 
    proj_dir=nn_config['proj_dir'], 
    lig_dict=nn_config['lig_dict'],
    true_file=f'{root_path}/{dataset}/label/train/train.fa', mode='train')
# Print the size of the validation set
valid_loader = DataLoader(valid_data, batch_size=nn_config['batch_size'], collate_fn=valid_data.collate_fn)
print(f'valid data length: {len(valid_data)}')
all_y_score = []
all_y_true = []
with torch.no_grad():
    for rfeat, ligand, xyz,  mask, y_true in valid_loader:
        tensors = [rfeat, ligand, xyz,  mask, y_true]
        tensors = [tensor.to(DEVICE) for tensor in tensors]
        rfeat, ligand, xyz, mask, y_true = tensors
        
        # Ensemble: average the sigmoid outputs of all loaded fold models.
        logits = [model(rfeat, ligand, xyz, mask).sigmoid() for model in models]
        logits = torch.stack(logits,0).mean(0)
        
        # Keep only the positions where the mask equals 1.
        logits = torch.masked_select(logits, mask==1)
        y_true = torch.masked_select(y_true, mask==1)
        all_y_score.extend(logits.cpu().detach().numpy())
        all_y_true.extend(y_true.cpu().detach().numpy())

best_threshold,best_mcc,best_pred = getBestThreshold(all_y_true, all_y_score)
appendText(f'{model_path}/Best_Threshold.txt',f'{best_threshold} {best_mcc}\n')

### Test

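Mostly unchanged from the original code. Because the model is wrapped in `nn.DataParallel` before loading, the checkpoint's state-dict keys need to start with `module.`; a for loop was added that prepends the prefix to any key that lacks it. 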

In [None]:
import pandas as pd
from collections import OrderedDict

model_path = f'{root_path}/Output/{dataset}/' # if 5-fold cross-validation, {dataset}_5fold
print(model_path)
print(nn_config['pdb_class'])

models = []
for fold in range(1): # if 5-fold cross-validation, set to 5
    state_dict = torch.load(model_path + 'fold%s.ckpt'%fold,'cuda:0')
    model = LABind(
        rfeat_dim=nn_config['rfeat_dim'], ligand_dim=nn_config['ligand_dim'], hidden_dim=nn_config['hidden_dim'], heads=nn_config['heads'], augment_eps=nn_config['augment_eps'], 
        rbf_num=nn_config['rbf_num'],top_k=nn_config['top_k'], attn_drop=nn_config['attn_drop'], dropout=nn_config['dropout'], num_layers=nn_config['num_layers']).to(DEVICE)
    model = nn.DataParallel(model,device_ids=nn_config['device_ids'])
    
    # Use a fresh mapping per fold so keys from one checkpoint can never leak into the next.
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # if checkpoint lacks "module.", add it
        if not k.startswith("module."):
            new_state_dict["module." + k] = v
        else:
            new_state_dict[k] = v

    model.load_state_dict(new_state_dict)
    model.eval()
    models.append(model)

result_columns = ['ligand','Rec','SPE','Acc','Pre','F1','MCC','AUC','AUPR']
result_frames = []
test_label_dir = f'{root_path}/{dataset}/label/test/'
# Only *.fa label files are evaluated, in sorted order so the CSV rows are reproducible.
# NOTE(review): this includes test.fa if it sits in this folder -- confirm that is intended.
ligand_files = sorted(name for name in os.listdir(test_label_dir) if name.endswith('.fa'))
for ligand_file in ligand_files:
    ionic = ligand_file[:-len('.fa')]
    test_list = readDataList(f'{root_path}/{dataset}/label/test/{ionic}.fa',skew=1)
    test_data = readData.readData(
        name_list=test_list, 
        proj_dir=nn_config['proj_dir'], 
        lig_dict=nn_config['lig_dict'],
        true_file=f'{root_path}/{dataset}/label/test/{ionic}.fa', mode='test')
    # Print the size of the test set
    test_loader = DataLoader(test_data, batch_size=nn_config['batch_size'], collate_fn=test_data.collate_fn)
    print(f'{ionic} test data length: {len(test_data)}')
    all_y_score = []
    all_y_true = []
    with torch.no_grad():
        for rfeat, ligand, xyz,  mask, y_true in test_loader:
            # Batches without features are skipped.
            if rfeat is None:
                continue
            tensors = [rfeat, ligand, xyz,  mask, y_true]
            tensors = [tensor.to(DEVICE) for tensor in tensors]
            rfeat, ligand, xyz, mask, y_true = tensors
            
            # Ensemble: average the sigmoid outputs of all loaded models.
            logits = [model(rfeat, ligand, xyz, mask).sigmoid() for model in models]
            logits = torch.stack(logits,0).mean(0)
            
            # Keep only the positions where the mask equals 1.
            logits = torch.masked_select(logits, mask==1)
            y_true = torch.masked_select(y_true, mask==1)
            all_y_score.extend(logits.cpu().detach().numpy())
            all_y_true.extend(y_true.cpu().detach().numpy())
    if not all_y_true:
        # Nothing was evaluated for this ligand (every batch was skipped or the set is empty).
        print(f'{ionic}: no evaluable samples, skipped')
        continue
    # NOTE(review): the threshold written to Best_Threshold.txt is not passed here.
    data_dict = calEval(all_y_true, all_y_score) # please set best threshold.
    data_dict['ligand'] = ionic
    result_frames.append(pd.DataFrame(data_dict,index=[0]))

# Concatenate once with a fresh 0..n-1 index (previously every row carried index 0).
if result_frames:
    df = pd.concat(result_frames, ignore_index=True)
else:
    df = pd.DataFrame(columns=result_columns)
# Standard metric columns first, followed by any extra keys returned by calEval.
df = df.reindex(columns=result_columns + [c for c in df.columns if c not in result_columns])

csv_file = os.path.join(model_path, 'test.csv')
df.to_csv(csv_file, index=False)
print(f'Saved CSV to {csv_file}')
df

/virtual/zengzix4/LABind_ESM/Output/LigBind_5fold1/
source
[('5jiwA', 'CO3'), ('5gkdA', 'CO3'), ('5l6qA', 'CO3'), ('5mmyB', 'CO3'), ('5n6yA', 'CO3'), ('5nnmA', 'CO3'), ('6cj0A', 'CO3'), ('6gudA', 'CO3'), ('5z9oA', 'CO3'), ('6mbaB', 'CO3'), ('6a1iA', 'CO3'), ('6rn5A', 'CO3'), ('6nw9A', 'CO3'), ('6imeB', 'CO3'), ('6j9yA', 'CO3'), ('6u1kA', 'CO3'), ('6laaA', 'CO3'), ('7c67A', 'CO3'), ('6wb6A', 'CO3'), ('7k5kA', 'CO3'), ('6w4qA', 'CO3')]
CO3 test data length: 21
[('5h41A', 'SO4'), ('5um2A', 'SO4'), ('5b7dB', 'SO4'), ('5lwlA', 'SO4'), ('5vr6B', 'SO4'), ('5lhvA', 'SO4'), ('5h7eA', 'SO4'), ('5oe9C', 'SO4'), ('5xu6C', 'SO4'), ('6apxA', 'SO4'), ('5o4qK', 'SO4'), ('5ysqB', 'SO4'), ('6bt2B', 'SO4'), ('6f9xA', 'SO4'), ('6ggyA', 'SO4'), ('5yeqA', 'SO4'), ('6b1wA', 'SO4'), ('6gd6A', 'SO4'), ('6dvhA', 'SO4'), ('6igqA', 'SO4'), ('6erbB', 'SO4'), ('5zxlD', 'SO4'), ('6hzzB', 'SO4'), ('6j0lA', 'SO4'), ('6d9qA', 'SO4'), ('6inzA', 'SO4'), ('6nlrA', 'SO4'), ('6s0rB', 'SO4'), ('6reoA', 'SO4')]
SO4 test data 

Unnamed: 0,ligand,Rec,SPE,Acc,Pre,F1,MCC,AUC,AUPR
0,CO3,0.131313,0.994292,0.984454,0.209677,0.161491,0.158379,0.799849,0.124177
0,SO4,0.271889,0.989414,0.973907,0.361963,0.310526,0.30067,0.870803,0.230318
0,MG,0.284739,0.99584,0.98976,0.371204,0.322273,0.320039,0.860939,0.225905
0,CU,0.73444,0.993455,0.989692,0.623239,0.674286,0.671413,0.972044,0.692981
0,HEM,0.671489,0.946945,0.928324,0.478522,0.558816,0.529869,0.934397,0.575567
0,ATP,0.645908,0.968019,0.959559,0.352627,0.456198,0.458792,0.946948,0.466874
0,FE2,0.788618,0.995654,0.993053,0.697842,0.740458,0.738366,0.995163,0.833888
0,K,0.047619,0.99671,0.97796,0.225806,0.078652,0.095772,0.770969,0.126766
0,,0.080402,0.993394,0.969542,0.246154,0.121212,0.127973,0.738352,0.11522
0,MN,0.66838,0.993785,0.989849,0.568306,0.614294,0.611243,0.955471,0.622761
