# This is an example of reproducing training results
Prior to this step, please execute `download_weights.py` to download the weights for all pre-trained models.

In [2]:
# import packages

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # <-- add this at the top

from torch.utils.data import DataLoader
import readData
import shutil
from utils import *
import torch
from torch import nn
from model import LABind
from func_help import setALlSeed,get_std_opt
from tqdm import tqdm
import pickle as pkl
from sklearn.model_selection import KFold
import gc

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# config
DEVICE = torch.device('cuda:0')
root_path = getRootPath()
dataset = 'LigBind' # DS1:LigBind, DS2:GPSite, DS3:Unseen

nn_config = {
    # dataset 
    'train_file': f'{root_path}/{dataset}/label/train/train.fa',
    'test_file': f'{root_path}/{dataset}/label/test/test.fa',
    'valid_file': f'{root_path}/{dataset}/label/picking.fa',
    'proj_dir': f'{root_path}/{dataset}',
    'lig_dict': pkl.load(open(f'{root_path}/tools/ligand.pkl', 'rb')),
    'pdb_class':'source', # source or omegafold or esmfold
    'dssp_max_repr': np.load(f'{root_path}/tools/dssp_max_repr.npy'),
    'dssp_min_repr': np.load(f'{root_path}/tools/dssp_min_repr.npy'),
    'ankh_max_repr': np.load(f'{root_path}/tools/ankh_max_repr.npy'),
    'ankh_min_repr': np.load(f'{root_path}/tools/ankh_min_repr.npy'),
    'esm2_max_repr': np.load(f'{root_path}/tools/esm2_max_repr.npy'),
    'esm2_min_repr': np.load(f'{root_path}/tools/esm2_min_repr.npy'),
    'ion_max_repr': np.load(f'{root_path}/tools/ion_max_repr.npy'),
    'ion_min_repr': np.load(f'{root_path}/tools/ion_min_repr.npy'),
    # model parameters
    
    'rfeat_dim':2580,
    'ligand_dim':768, 
    'hidden_dim':256, 
    'heads':4, 
    'augment_eps':0.05, 
    'rbf_num':8, 
    'top_k':30, 
    'attn_drop':0.1, 
    'dropout':0.1, 
    'num_layers':4, 
    'lr':0.0004, 
    
    # training parameters 
    # You can modify it according to the actual situation. 
    # Since it involves mapping the entire protein, it will consume a large amount of GPU memory.
    'batch_size':1,
    'max_patience':10,
    'device_ids':[0]
}
pretrain_path = { # Please modify 
    'esmfold_path': '../tools/esmfold_v1', # esmfold path
    'esm2_path': '../tools/esm2', 
    'ankh_path': '../tools/ankh-large/', # ankh path
    'molformer_path': '../tools/MoLFormer-XL-both-10pct/', # molformer path
    'model_path':f'{root_path}/model/LigBind/' # based on Unseen
}

In [4]:
print(nn_config.keys())

dict_keys(['train_file', 'test_file', 'valid_file', 'proj_dir', 'lig_dict', 'pdb_class', 'dssp_max_repr', 'dssp_min_repr', 'ankh_max_repr', 'ankh_min_repr', 'esm2_max_repr', 'esm2_min_repr', 'ion_max_repr', 'ion_min_repr', 'rfeat_dim', 'ligand_dim', 'hidden_dim', 'heads', 'augment_eps', 'rbf_num', 'top_k', 'attn_drop', 'dropout', 'num_layers', 'lr', 'batch_size', 'max_patience', 'device_ids'])


### Download dataset
Download the file from https://zenodo.org/records/13938443 and place it in the root directory.


### Retrieve the features

#### ESM2

In [22]:
import torch
import numpy as np
import os
from tqdm import tqdm
from Bio import SeqIO
from transformers import AutoTokenizer, AutoModel 
from Bio import SeqIO
# 1️⃣ Delete all tensors and models
del model
gc.collect()

# 2️⃣ Empty PyTorch cache
torch.cuda.empty_cache()

# 3️⃣ Optional: reset CUDA memory stats
torch.cuda.reset_peak_memory_stats()
MODEL_DIR = "../tools/esm2"   # <-- directory containing HF files
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 4
FASTA_ROOT = "fasta"             # root directory containing subdirs of FASTA files
SAVE_ROOT = "embeddings"         # root output folder
out_path = f"{root_path}/{dataset}/esm/"
fasta_path = f"{root_path}/{dataset}/fasta/"

os.makedirs(out_path, exist_ok=True)

tokenizer = AutoTokenizer.from_pretrained(pretrain_path['esm2_path'])
model     = AutoModel.from_pretrained(pretrain_path['esm2_path'])
model.to(DEVICE)

model.eval()
model.gradient_checkpointing_enable()
print("Model loaded on", DEVICE)
for file_class in os.listdir(fasta_path):
    class_path = os.path.join(fasta_path, file_class)

    for fasta_file in os.listdir(class_path):
        fasta_path_full = os.path.join(class_path, fasta_file)

        sequences = list(SeqIO.parse(fasta_path_full, "fasta"))

        print(f"\nProcessing {fasta_file} ({len(sequences)} sequences)")

        for record in tqdm(sequences):
            save_path = os.path.join(out_path, f"{record.id}.npy")

            # Skip if already processed
            if os.path.exists(save_path):
                continue

            seq = str(record.seq)

            # Tokenize for HF model
            encoded = tokenizer(
                seq,
                return_tensors="pt",
                padding=False,
                add_special_tokens=True
            )

            encoded = {k: v.to(DEVICE) for k, v in encoded.items()}

            # Run model on GPU
            with torch.no_grad():
                output = model(**encoded)
                hidden = output.last_hidden_state.cpu()      # shape: (1, L, D)
                embedding = hidden[0, 1:len(seq)+1].numpy()

            # Save
            np.save(save_path, embedding)
            torch.cuda.empty_cache()

        

# Cleanup
del model
gc.collect()
print("\n✓ Done extracting ESM embeddings!")


Loading checkpoint shards: 100%|██████████████████| 2/2 [00:00<00:00,  2.04it/s]
Some weights of EsmModel were not initialized from the model checkpoint at ../tools/esm2 and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded on cuda

Processing CO3.fa (21 sequences)


100%|███████████████████████████████████████| 21/21 [00:00<00:00, 106120.94it/s]



Processing SO4.fa (31 sequences)


100%|███████████████████████████████████████| 31/31 [00:00<00:00, 132948.29it/s]



Processing MG.fa (665 sequences)


 83%|███████████████████████████████▌      | 553/665 [00:00<00:00, 26351.10it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 44.00 MiB. GPU 0 has a total capacty of 15.69 GiB of which 10.12 MiB is free. Including non-PyTorch memory, this process has 15.49 GiB memory in use. Of the allocated memory 15.14 GiB is allocated by PyTorch, and 79.01 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [40]:

from tqdm import tqdm

esm_dir = "/virtual/zengzix4/LABind_ESM/LigBind/esm3B"    # <-- CHANGE THIS
save_file = "esm2_repr_stats.npy"
save_dir = "/virtual/zengzix4/LABind_ESM/tools/"
# Use None so we can initialize on first file

all_min = None
all_max = None

for f in os.listdir(esm_dir):
    if not f.endswith(".npy"):
        continue

    fpath = os.path.join(esm_dir, f)

    try:
        arr = np.load(fpath, allow_pickle=True)
    except Exception as e:
        print("Skipping bad file:", f, "Error:", e)
        continue

    # Convert object arrays to real numpy if possible
    if isinstance(arr, np.ndarray) and arr.dtype == object:
        try:
            arr = np.vstack(arr)  # many ESM embeddings come as list of arrays
        except:
            print("Skipping malformed object array:", f)
            continue

    # Now check shape
    if arr.ndim != 2:
        print("Skipping wrong shape", f, arr.shape)
        continue

    # Compute per-dimension min/max
    fmin = arr.min(axis=0)
    fmax = arr.max(axis=0)

    if all_min is None:
        all_min = fmin
        all_max = fmax
    else:
        all_min = np.minimum(all_min, fmin)
        all_max = np.maximum(all_max, fmax)

print("FINAL ESM MIN SHAPE:", all_min.shape)
print("FINAL ESM MAX SHAPE:", all_max.shape)
np.save(os.path.join(save_dir, "esm2_min_repr.npy"), all_min)
np.save(os.path.join(save_dir, "esm2_max_repr.npy"), all_max)


Skipping bad file: 6s9fB.npy Error: Failed to interpret file '/virtual/zengzix4/LABind_ESM/LigBind/esm3B/6s9fB.npy' as a pickle
Skipping bad file: 6ptkA.npy Error: cannot reshape array of size 1138656 into shape (471,2560)
FINAL ESM MIN SHAPE: (2560,)
FINAL ESM MAX SHAPE: (2560,)


#### Ankh

In [3]:
import sys
print(sys.executable)

/virtual/zengzix4/miniconda/envs/LABind/bin/python


In [None]:
# get ankh features
from transformers import AutoTokenizer, T5EncoderModel 
from Bio import SeqIO
tokenizer = AutoTokenizer.from_pretrained(pretrain_path['ankh_path'])
model     = T5EncoderModel.from_pretrained(pretrain_path['ankh_path'])
model.to(DEVICE)
model.eval()
out_path = f"{root_path}/{dataset}/ankh/"
makeDir(out_path)
# 使用biopython读取fasta文件
fasta_path = f"{root_path}/{dataset}/fasta/"
for file_class in os.listdir(fasta_path):
    for fasta_base_file in os.listdir(f"{fasta_path}/{file_class}"):
        fasta_file = f"{fasta_path}/{file_class}/{fasta_base_file}"
        sequences = SeqIO.parse(fasta_file, "fasta")
        for record in tqdm(sequences):
            if os.path.exists(out_path+f'{record.id}.npy'):
                continue
            ids = tokenizer.batch_encode_plus([list(record.seq)], add_special_tokens=True, padding=True, is_split_into_words=True, return_tensors="pt")
            input_ids = ids['input_ids'].to(DEVICE)
            attention_mask = ids['attention_mask'].to(DEVICE)
            with torch.no_grad():
                embedding_repr = model(input_ids=input_ids,attention_mask=attention_mask)
                emb = embedding_repr.last_hidden_state[0,:len(record.seq)].cpu().numpy()
                np.save(out_path+f'{record.id}.npy',emb)
del model
gc.collect()
torch.cuda.empty_cache()

#### DSSP

In [29]:
from Bio.PDB import PDBParser
from Bio.PDB.DSSP import DSSP

mapSS = {' ':[0,0,0,0,0,0,0,0,0],
        '-':[1,0,0,0,0,0,0,0,0],
        'H':[0,1,0,0,0,0,0,0,0],
        'B':[0,0,1,0,0,0,0,0,0],
        'E':[0,0,0,1,0,0,0,0,0],
        'G':[0,0,0,0,1,0,0,0,0],
        'I':[0,0,0,0,0,1,0,0,0],
        'P':[0,0,0,0,0,0,1,0,0],
        'T':[0,0,0,0,0,0,0,1,0],
        'S':[0,0,0,0,0,0,0,0,1]}
p = PDBParser(QUIET=True)
pdb_path = f"{root_path}/{dataset}/pdb/"
dssp_path = "../tools/mkdssp"
pdb_class = nn_config['pdb_class']
makeDir(f"{root_path}/{dataset}/{pdb_class}_dssp/")
test_files = os.listdir(pdb_path)
pdb_files = []

for dirpath, dirnames, filenames in os.walk(pdb_path):
    for f in filenames:
        if f.endswith('.pdb'):
            pdb_files.append(os.path.join(dirpath, f)) 
            
for pdb_file_name in tqdm(pdb_files, desc='DSSP running',ncols=80,unit='proteins'):
    pdb_file = pdb_file_name
    save_file = pdb_file.replace('.pdb','.npy').replace('pdb',f'{pdb_class}_dssp')
    if os.path.exists(save_file):
        continue
    structure = p.get_structure("tmp", pdb_file)
    model = structure[0]
    try:
        dssp = DSSP(model, pdb_file, dssp=dssp_path)
        keys = list(dssp.keys())
    except:
        keys = []
    res_np = []
    for chain in model:
        for residue in chain:
            res_key = (chain.id,(' ', residue.id[1], residue.id[2]))
            if res_key in keys:
                tuple_dssp = dssp[res_key]
                res_np.append(mapSS[tuple_dssp[2]] + list(tuple_dssp[3:]))
            else:
                res_np.append(np.zeros(20))
    os.makedirs(os.path.dirname(save_file), exist_ok=True)
    np.save(save_file, np.array(res_np))

DSSP running: 100%|███████████████| 11121/11121 [1:22:02<00:00,  2.26proteins/s]


#### Position

In [None]:
from Bio.PDB.ResidueDepth import get_surface
from scipy.spatial import cKDTree

pdb_path = f"{root_path}/{dataset}/pdb/"
msms_path = "../tools/msms"
pdb_class = nn_config['pdb_class']
makeDir(f"{root_path}/{dataset}/{pdb_class}_pos/")
pdb_files = []

for dirpath, dirnames, filenames in os.walk(pdb_path):
    for f in filenames:
        if f.endswith('.pdb'):
            pdb_files.append(os.path.join(dirpath, f)) 

for pdb_file in tqdm(pdb_files,desc='MSMS running',ncols=80,unit='proteins'):
    save_file = pdb_file.replace('.pdb','.npy').replace('pdb',f'{pdb_class}_pos')
    if os.path.exists(save_file):
        continue
    parser = PDBParser(QUIET=True)
    X = []
    chain_atom = ['N', 'CA', 'C', 'O']
    model = parser.get_structure('model', pdb_file)[0]
    chain = next(model.get_chains())
    try:
        surf = get_surface(chain,MSMS=msms_path)
        surf_tree = cKDTree(surf)
    except:
        surf = np.empty(0)
    for residue in chain:
        line = []
        atoms_coord = np.array([atom.get_coord() for atom in residue])
        if surf.size == 0:
            dist, _ = surf_tree.query(atoms_coord)
            closest_atom = np.argmin(dist)
            closest_pos = atoms_coord[closest_atom]
        else:
            closest_pos = atoms_coord[-1]
        atoms = list(residue.get_atoms())
        try:
            ca_pos = residue['CA'].get_coord()
        except KeyError:
            print(residue)
            continue

        pos_s = 0
        un_s = 0
        for atom in atoms:
            if atom.name in chain_atom:
                line.append(atom.get_coord())
            else:
                pos_s += calMass(atom,True)
                un_s += calMass(atom,False)
        # 此处line应该等于4
        if len(line) != 4:
            line = line + [list(ca_pos)]*(4-len(line))
        if un_s == 0:
            R_pos = ca_pos
        else:
            R_pos = pos_s / un_s
        line.append(R_pos)  
        line.append(closest_pos) # 加入最近点的残基信息
        X.append(line) 
    
    os.makedirs(os.path.dirname(save_file), exist_ok=True)

    np.save(save_file, X)

### Train

In [8]:
def valid(model, valid_list, fold_idx):
    model.to(DEVICE)
    model.eval()
    valid_data = readData.readData(
        name_list=valid_list, 
        proj_dir=nn_config['proj_dir'], 
        lig_dict=nn_config['lig_dict'],
        true_file=nn_config['train_file'], mode='train') # If 5-fold cross-validation is not used, it needs to be changed to valid_file.
    valid_loader = DataLoader(valid_data, batch_size=nn_config['batch_size'],shuffle=True, collate_fn=valid_data.collate_fn, num_workers=5)
    all_y_score = []
    all_y_true = []
    with torch.no_grad():
        for rfeat, ligand, xyz,  mask, y_true in valid_loader:
            tensors = [rfeat, ligand, xyz,  mask, y_true]
            tensors = [tensor.to(DEVICE) for tensor in tensors]
            rfeat, ligand, xyz, mask, y_true = tensors
            logits = model(rfeat, ligand, xyz,  mask).sigmoid() # [N]
            logits = torch.masked_select(logits, mask==1)
            y_true = torch.masked_select(y_true, mask==1)
            all_y_score.extend(logits.cpu().detach().numpy())
            all_y_true.extend(y_true.cpu().detach().numpy())
        # 通过aupr数值进行早停
        aupr_value = average_precision_score(all_y_true, all_y_score)
    return aupr_value

def train(train_list,valid_list=None,model=None,epochs=50,fold_idx=None):
    model.to(DEVICE)
    train_data = readData.readData(
        name_list=train_list, 
        proj_dir=nn_config['proj_dir'], 
        lig_dict=nn_config['lig_dict'],
        true_file=nn_config['train_file'], mode='train')
    train_loader = DataLoader(train_data, batch_size=nn_config['batch_size'],shuffle=True, collate_fn=train_data.collate_fn, num_workers=5)
    loss_fn = nn.BCELoss(reduction='none')
    optimizer = get_std_opt(len(train_list),nn_config['batch_size'], model.parameters(), nn_config['hidden_dim'], nn_config['lr'])
    v_max_aupr = 0
    patience = 0
    t_mccs = []
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model,device_ids=nn_config['device_ids'])
    train_losses = []
    for epoch in range(epochs):
        all_loss = 0
        all_cnt = 0
        model.train()
        for rfeat, ligand, xyz,  mask, y_true in tqdm(train_loader):
            tensors = [rfeat, ligand, xyz,  mask, y_true]
            tensors = [tensor.to(DEVICE) for tensor in tensors]
            rfeat, ligand, xyz, mask, y_true = tensors
            optimizer.zero_grad()
            logits = model(rfeat, ligand, xyz, mask).sigmoid() # [N]
            # 计算所有离子的loss
            loss = loss_fn(logits, y_true) * mask
            loss = loss.sum() / mask.sum()
            all_loss += loss.item()
            all_cnt += 1
            loss.backward()
            # NEW
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        train_losses.append(all_loss / all_cnt)
        # 根据验证集的aupr进行早停
        if valid_list is not None:
            v_aupr = valid(model,valid_list, fold_idx)
            t_mccs.append(v_aupr)
            print(f'Epoch {epoch} Loss: {all_loss / all_cnt}', f'Epoch Valid {epoch} AUPR: {v_aupr}')
            if v_aupr > v_max_aupr:
                v_max_aupr = v_aupr
                patience = 0
                torch.save(model.state_dict(), f'{root_path}/Output/{dataset}_5fold/fold{fold_idx}.ckpt')
            else:
                patience += 1
            if patience >= nn_config['max_patience']:
                break


In [None]:
setALlSeed(11)
# 5-fold cross-validation
kf = KFold(n_splits=5, shuffle=True, random_state = 42)
data_list = readDataList(f'{root_path}/{dataset}/label/train/train.fa',skew=1)
makeDir(f'{root_path}/Output/{dataset}_5fold/')
fold_idx = 0
torch.cuda.empty_cache()
gc.collect()
for train_idx, valid_idx in kf.split(data_list):
    train_list = [data_list[i] for i in train_idx]
    valid_list = [data_list[j] for j in valid_idx]
    model = LABind(
    rfeat_dim=nn_config['rfeat_dim'], ligand_dim=nn_config['ligand_dim'], hidden_dim=nn_config['hidden_dim'], heads=nn_config['heads'], augment_eps=nn_config['augment_eps'], 
    rbf_num=nn_config['rbf_num'],top_k=nn_config['top_k'], attn_drop=nn_config['attn_drop'], dropout=nn_config['dropout'], num_layers=nn_config['num_layers'])
    train(train_list,valid_list,model,epochs=70,fold_idx=fold_idx)
    fold_idx += 1

In [None]:
from sklearn.model_selection import train_test_split

all_data = readDataList(f'{root_path}/{dataset}/label/train/train.fa', skew=1) 
torch.cuda.empty_cache()
gc.collect()

train_list, valid_list = train_test_split( ## no given validation in LigBind
    all_data,
    test_size=0.1,       # 10% validation
    random_state=42,
    shuffle=True
)
makeDir(f'{root_path}/Output/{dataset}/')
model = LABind(
rfeat_dim=nn_config['rfeat_dim'], ligand_dim=nn_config['ligand_dim'], hidden_dim=nn_config['hidden_dim'], heads=nn_config['heads'], augment_eps=nn_config['augment_eps'], 
rbf_num=nn_config['rbf_num'],top_k=nn_config['top_k'], attn_drop=nn_config['attn_drop'], dropout=nn_config['dropout'], num_layers=nn_config['num_layers'])
train(train_list, valid_list, model, epochs=50, fold_idx=0)

100%|███████████████████████████████████████| 7452/7452 [06:18<00:00, 19.69it/s]


Epoch 0 Loss: 0.10398311041835676 Epoch Valid 0 AUPR: 0.35387015061608573


100%|███████████████████████████████████████| 7452/7452 [06:19<00:00, 19.65it/s]


Epoch 1 Loss: 0.0786849736162089 Epoch Valid 1 AUPR: 0.40323699306164873


100%|███████████████████████████████████████| 7452/7452 [06:19<00:00, 19.63it/s]


Epoch 2 Loss: 0.07507787457394156 Epoch Valid 2 AUPR: 0.4754686796474941


  4%|█▌                                      | 286/7452 [00:14<06:04, 19.67it/s]

In [6]:
# Determine the best threshold for MCC based on the validation set.
model_path = f'{root_path}/Output/{dataset}_5fold/' # if 5-fold cross-validation, {dataset}_5fold
print(model_path)
print(nn_config['pdb_class'])

models = []
for fold in range(5): # if 5-fold cross-validation, set to 5
    state_dict = torch.load(model_path + 'fold%s.ckpt'%fold,'cuda:0')
    model = LABind(
        rfeat_dim=nn_config['rfeat_dim'], ligand_dim=nn_config['ligand_dim'], hidden_dim=nn_config['hidden_dim'], heads=nn_config['heads'], augment_eps=nn_config['augment_eps'], 
        rbf_num=nn_config['rbf_num'],top_k=nn_config['top_k'], attn_drop=nn_config['attn_drop'], dropout=nn_config['dropout'], num_layers=nn_config['num_layers']).to(DEVICE)
    model = nn.DataParallel(model,device_ids=nn_config['device_ids'])
    model.load_state_dict(state_dict)
    model.eval()
    models.append(model)
    
valid_list = readDataList(f'{root_path}/{dataset}/label/picking.fa',skew=1)
valid_data = readData.readData(
    name_list=valid_list, 
    proj_dir=nn_config['proj_dir'], 
    lig_dict=nn_config['lig_dict'],
    true_file=f'{root_path}/{dataset}/label/picking.fa', mode='train')
# 打印长度
valid_loader = DataLoader(valid_data, batch_size=nn_config['batch_size'], collate_fn=valid_data.collate_fn)
print(f'valid data length: {len(valid_data)}')
all_y_score = []
all_y_true = []
with torch.no_grad():
    for rfeat, ligand, xyz,  mask, y_true in valid_loader:
        tensors = [rfeat, ligand, xyz,  mask, y_true]
        tensors = [tensor.to(DEVICE) for tensor in tensors]
        rfeat, ligand, xyz, mask, y_true = tensors
        
        logits = [model(rfeat, ligand, xyz, mask).sigmoid() for model in models]
        logits = torch.stack(logits,0).mean(0)
        
        logits = torch.masked_select(logits, mask==1)
        y_true = torch.masked_select(y_true, mask==1)
        all_y_score.extend(logits.cpu().detach().numpy())
        all_y_true.extend(y_true.cpu().detach().numpy())

best_threshold,best_mcc,best_pred = getBestThreshold(all_y_true, all_y_score)
appendText(f'{model_path}/Best_Threshold.txt',f'{best_threshold} {best_mcc}\n')

/virtual/zengzix4/LABind_ESM/Output/LigBind_5fold/
source


RuntimeError: Error(s) in loading state_dict for DataParallel:
	Missing key(s) in state_dict: "module.in_mlp.fc1.weight", "module.in_mlp.fc1.bias", "module.in_mlp.fc2.weight", "module.in_mlp.fc2.bias", "module.in_mlp.ln1.weight", "module.in_mlp.ln1.bias", "module.in_mlp.ln2.weight", "module.in_mlp.ln2.bias", "module.lig_mlp.fc1.0.weight", "module.lig_mlp.fc1.0.bias", "module.lig_mlp.fc1.1.weight", "module.lig_mlp.fc1.1.bias", "module.edge_feature.edge_emb.weight", "module.edge_feature.edge_emb.bias", "module.edge_feature.norm_edge.weight", "module.edge_feature.norm_edge.bias", "module.edge_feature.node_emb.weight", "module.edge_feature.node_emb.bias", "module.edge_feature.norm_node.weight", "module.edge_feature.norm_node.bias", "module.f_mlp.0.weight", "module.f_mlp.0.bias", "module.f_mlp.2.weight", "module.f_mlp.2.bias", "module.f_mlp.4.weight", "module.f_mlp.4.bias", "module.conv_layers.0.norm.0.weight", "module.conv_layers.0.norm.0.bias", "module.conv_layers.0.norm.1.weight", "module.conv_layers.0.norm.1.bias", "module.conv_layers.0.norm.2.weight", "module.conv_layers.0.norm.2.bias", "module.conv_layers.0.CrossAttn.query.weight", "module.conv_layers.0.CrossAttn.query.bias", "module.conv_layers.0.CrossAttn.key.weight", "module.conv_layers.0.CrossAttn.key.bias", "module.conv_layers.0.CrossAttn.value.weight", "module.conv_layers.0.CrossAttn.value.bias", "module.conv_layers.0.CrossAttn.out.weight", "module.conv_layers.0.CrossAttn.out.bias", "module.conv_layers.0.NeighAttn.W_Q.weight", "module.conv_layers.0.NeighAttn.W_K.weight", "module.conv_layers.0.NeighAttn.W_V.weight", "module.conv_layers.0.NeighAttn.W_O.weight", "module.conv_layers.0.dense.linear1.weight", "module.conv_layers.0.dense.linear1.bias", "module.conv_layers.0.dense.linear2.weight", "module.conv_layers.0.dense.linear2.bias", "module.conv_layers.0.edge_update.norm.weight", "module.conv_layers.0.edge_update.norm.bias", "module.conv_layers.0.edge_update.EdgeMLP.0.weight", "module.conv_layers.0.edge_update.EdgeMLP.0.bias", "module.conv_layers.0.edge_update.EdgeMLP.2.weight", "module.conv_layers.0.edge_update.EdgeMLP.2.bias", "module.conv_layers.0.context.ContextMLP.0.weight", "module.conv_layers.0.context.ContextMLP.0.bias", "module.conv_layers.0.context.ContextMLP.2.weight", "module.conv_layers.0.context.ContextMLP.2.bias", "module.conv_layers.0.context.norm.weight", "module.conv_layers.0.context.norm.bias", "module.conv_layers.1.norm.0.weight", "module.conv_layers.1.norm.0.bias", "module.conv_layers.1.norm.1.weight", "module.conv_layers.1.norm.1.bias", "module.conv_layers.1.norm.2.weight", "module.conv_layers.1.norm.2.bias", "module.conv_layers.1.CrossAttn.query.weight", "module.conv_layers.1.CrossAttn.query.bias", "module.conv_layers.1.CrossAttn.key.weight", "module.conv_layers.1.CrossAttn.key.bias", "module.conv_layers.1.CrossAttn.value.weight", "module.conv_layers.1.CrossAttn.value.bias", "module.conv_layers.1.CrossAttn.out.weight", "module.conv_layers.1.CrossAttn.out.bias", "module.conv_layers.1.NeighAttn.W_Q.weight", "module.conv_layers.1.NeighAttn.W_K.weight", "module.conv_layers.1.NeighAttn.W_V.weight", "module.conv_layers.1.NeighAttn.W_O.weight", "module.conv_layers.1.dense.linear1.weight", "module.conv_layers.1.dense.linear1.bias", "module.conv_layers.1.dense.linear2.weight", "module.conv_layers.1.dense.linear2.bias", "module.conv_layers.1.edge_update.norm.weight", "module.conv_layers.1.edge_update.norm.bias", "module.conv_layers.1.edge_update.EdgeMLP.0.weight", "module.conv_layers.1.edge_update.EdgeMLP.0.bias", "module.conv_layers.1.edge_update.EdgeMLP.2.weight", "module.conv_layers.1.edge_update.EdgeMLP.2.bias", "module.conv_layers.1.context.ContextMLP.0.weight", "module.conv_layers.1.context.ContextMLP.0.bias", "module.conv_layers.1.context.ContextMLP.2.weight", "module.conv_layers.1.context.ContextMLP.2.bias", "module.conv_layers.1.context.norm.weight", "module.conv_layers.1.context.norm.bias", "module.conv_layers.2.norm.0.weight", "module.conv_layers.2.norm.0.bias", "module.conv_layers.2.norm.1.weight", "module.conv_layers.2.norm.1.bias", "module.conv_layers.2.norm.2.weight", "module.conv_layers.2.norm.2.bias", "module.conv_layers.2.CrossAttn.query.weight", "module.conv_layers.2.CrossAttn.query.bias", "module.conv_layers.2.CrossAttn.key.weight", "module.conv_layers.2.CrossAttn.key.bias", "module.conv_layers.2.CrossAttn.value.weight", "module.conv_layers.2.CrossAttn.value.bias", "module.conv_layers.2.CrossAttn.out.weight", "module.conv_layers.2.CrossAttn.out.bias", "module.conv_layers.2.NeighAttn.W_Q.weight", "module.conv_layers.2.NeighAttn.W_K.weight", "module.conv_layers.2.NeighAttn.W_V.weight", "module.conv_layers.2.NeighAttn.W_O.weight", "module.conv_layers.2.dense.linear1.weight", "module.conv_layers.2.dense.linear1.bias", "module.conv_layers.2.dense.linear2.weight", "module.conv_layers.2.dense.linear2.bias", "module.conv_layers.2.edge_update.norm.weight", "module.conv_layers.2.edge_update.norm.bias", "module.conv_layers.2.edge_update.EdgeMLP.0.weight", "module.conv_layers.2.edge_update.EdgeMLP.0.bias", "module.conv_layers.2.edge_update.EdgeMLP.2.weight", "module.conv_layers.2.edge_update.EdgeMLP.2.bias", "module.conv_layers.2.context.ContextMLP.0.weight", "module.conv_layers.2.context.ContextMLP.0.bias", "module.conv_layers.2.context.ContextMLP.2.weight", "module.conv_layers.2.context.ContextMLP.2.bias", "module.conv_layers.2.context.norm.weight", "module.conv_layers.2.context.norm.bias", "module.conv_layers.3.norm.0.weight", "module.conv_layers.3.norm.0.bias", "module.conv_layers.3.norm.1.weight", "module.conv_layers.3.norm.1.bias", "module.conv_layers.3.norm.2.weight", "module.conv_layers.3.norm.2.bias", "module.conv_layers.3.CrossAttn.query.weight", "module.conv_layers.3.CrossAttn.query.bias", "module.conv_layers.3.CrossAttn.key.weight", "module.conv_layers.3.CrossAttn.key.bias", "module.conv_layers.3.CrossAttn.value.weight", "module.conv_layers.3.CrossAttn.value.bias", "module.conv_layers.3.CrossAttn.out.weight", "module.conv_layers.3.CrossAttn.out.bias", "module.conv_layers.3.NeighAttn.W_Q.weight", "module.conv_layers.3.NeighAttn.W_K.weight", "module.conv_layers.3.NeighAttn.W_V.weight", "module.conv_layers.3.NeighAttn.W_O.weight", "module.conv_layers.3.dense.linear1.weight", "module.conv_layers.3.dense.linear1.bias", "module.conv_layers.3.dense.linear2.weight", "module.conv_layers.3.dense.linear2.bias", "module.conv_layers.3.edge_update.norm.weight", "module.conv_layers.3.edge_update.norm.bias", "module.conv_layers.3.edge_update.EdgeMLP.0.weight", "module.conv_layers.3.edge_update.EdgeMLP.0.bias", "module.conv_layers.3.edge_update.EdgeMLP.2.weight", "module.conv_layers.3.edge_update.EdgeMLP.2.bias", "module.conv_layers.3.context.ContextMLP.0.weight", "module.conv_layers.3.context.ContextMLP.0.bias", "module.conv_layers.3.context.ContextMLP.2.weight", "module.conv_layers.3.context.ContextMLP.2.bias", "module.conv_layers.3.context.norm.weight", "module.conv_layers.3.context.norm.bias", "module.out_mlp.0.weight", "module.out_mlp.0.bias", "module.out_mlp.2.weight", "module.out_mlp.2.bias". 
	Unexpected key(s) in state_dict: "in_mlp.fc1.weight", "in_mlp.fc1.bias", "in_mlp.fc2.weight", "in_mlp.fc2.bias", "in_mlp.ln1.weight", "in_mlp.ln1.bias", "in_mlp.ln2.weight", "in_mlp.ln2.bias", "lig_mlp.fc1.0.weight", "lig_mlp.fc1.0.bias", "lig_mlp.fc1.1.weight", "lig_mlp.fc1.1.bias", "edge_feature.edge_emb.weight", "edge_feature.edge_emb.bias", "edge_feature.norm_edge.weight", "edge_feature.norm_edge.bias", "edge_feature.node_emb.weight", "edge_feature.node_emb.bias", "edge_feature.norm_node.weight", "edge_feature.norm_node.bias", "f_mlp.0.weight", "f_mlp.0.bias", "f_mlp.2.weight", "f_mlp.2.bias", "f_mlp.4.weight", "f_mlp.4.bias", "conv_layers.0.norm.0.weight", "conv_layers.0.norm.0.bias", "conv_layers.0.norm.1.weight", "conv_layers.0.norm.1.bias", "conv_layers.0.norm.2.weight", "conv_layers.0.norm.2.bias", "conv_layers.0.CrossAttn.query.weight", "conv_layers.0.CrossAttn.query.bias", "conv_layers.0.CrossAttn.key.weight", "conv_layers.0.CrossAttn.key.bias", "conv_layers.0.CrossAttn.value.weight", "conv_layers.0.CrossAttn.value.bias", "conv_layers.0.CrossAttn.out.weight", "conv_layers.0.CrossAttn.out.bias", "conv_layers.0.NeighAttn.W_Q.weight", "conv_layers.0.NeighAttn.W_K.weight", "conv_layers.0.NeighAttn.W_V.weight", "conv_layers.0.NeighAttn.W_O.weight", "conv_layers.0.dense.linear1.weight", "conv_layers.0.dense.linear1.bias", "conv_layers.0.dense.linear2.weight", "conv_layers.0.dense.linear2.bias", "conv_layers.0.edge_update.norm.weight", "conv_layers.0.edge_update.norm.bias", "conv_layers.0.edge_update.EdgeMLP.0.weight", "conv_layers.0.edge_update.EdgeMLP.0.bias", "conv_layers.0.edge_update.EdgeMLP.2.weight", "conv_layers.0.edge_update.EdgeMLP.2.bias", "conv_layers.0.context.ContextMLP.0.weight", "conv_layers.0.context.ContextMLP.0.bias", "conv_layers.0.context.ContextMLP.2.weight", "conv_layers.0.context.ContextMLP.2.bias", "conv_layers.0.context.norm.weight", "conv_layers.0.context.norm.bias", "conv_layers.1.norm.0.weight", "conv_layers.1.norm.0.bias", "conv_layers.1.norm.1.weight", "conv_layers.1.norm.1.bias", "conv_layers.1.norm.2.weight", "conv_layers.1.norm.2.bias", "conv_layers.1.CrossAttn.query.weight", "conv_layers.1.CrossAttn.query.bias", "conv_layers.1.CrossAttn.key.weight", "conv_layers.1.CrossAttn.key.bias", "conv_layers.1.CrossAttn.value.weight", "conv_layers.1.CrossAttn.value.bias", "conv_layers.1.CrossAttn.out.weight", "conv_layers.1.CrossAttn.out.bias", "conv_layers.1.NeighAttn.W_Q.weight", "conv_layers.1.NeighAttn.W_K.weight", "conv_layers.1.NeighAttn.W_V.weight", "conv_layers.1.NeighAttn.W_O.weight", "conv_layers.1.dense.linear1.weight", "conv_layers.1.dense.linear1.bias", "conv_layers.1.dense.linear2.weight", "conv_layers.1.dense.linear2.bias", "conv_layers.1.edge_update.norm.weight", "conv_layers.1.edge_update.norm.bias", "conv_layers.1.edge_update.EdgeMLP.0.weight", "conv_layers.1.edge_update.EdgeMLP.0.bias", "conv_layers.1.edge_update.EdgeMLP.2.weight", "conv_layers.1.edge_update.EdgeMLP.2.bias", "conv_layers.1.context.ContextMLP.0.weight", "conv_layers.1.context.ContextMLP.0.bias", "conv_layers.1.context.ContextMLP.2.weight", "conv_layers.1.context.ContextMLP.2.bias", "conv_layers.1.context.norm.weight", "conv_layers.1.context.norm.bias", "conv_layers.2.norm.0.weight", "conv_layers.2.norm.0.bias", "conv_layers.2.norm.1.weight", "conv_layers.2.norm.1.bias", "conv_layers.2.norm.2.weight", "conv_layers.2.norm.2.bias", "conv_layers.2.CrossAttn.query.weight", "conv_layers.2.CrossAttn.query.bias", "conv_layers.2.CrossAttn.key.weight", "conv_layers.2.CrossAttn.key.bias", "conv_layers.2.CrossAttn.value.weight", "conv_layers.2.CrossAttn.value.bias", "conv_layers.2.CrossAttn.out.weight", "conv_layers.2.CrossAttn.out.bias", "conv_layers.2.NeighAttn.W_Q.weight", "conv_layers.2.NeighAttn.W_K.weight", "conv_layers.2.NeighAttn.W_V.weight", "conv_layers.2.NeighAttn.W_O.weight", "conv_layers.2.dense.linear1.weight", "conv_layers.2.dense.linear1.bias", "conv_layers.2.dense.linear2.weight", "conv_layers.2.dense.linear2.bias", "conv_layers.2.edge_update.norm.weight", "conv_layers.2.edge_update.norm.bias", "conv_layers.2.edge_update.EdgeMLP.0.weight", "conv_layers.2.edge_update.EdgeMLP.0.bias", "conv_layers.2.edge_update.EdgeMLP.2.weight", "conv_layers.2.edge_update.EdgeMLP.2.bias", "conv_layers.2.context.ContextMLP.0.weight", "conv_layers.2.context.ContextMLP.0.bias", "conv_layers.2.context.ContextMLP.2.weight", "conv_layers.2.context.ContextMLP.2.bias", "conv_layers.2.context.norm.weight", "conv_layers.2.context.norm.bias", "conv_layers.3.norm.0.weight", "conv_layers.3.norm.0.bias", "conv_layers.3.norm.1.weight", "conv_layers.3.norm.1.bias", "conv_layers.3.norm.2.weight", "conv_layers.3.norm.2.bias", "conv_layers.3.CrossAttn.query.weight", "conv_layers.3.CrossAttn.query.bias", "conv_layers.3.CrossAttn.key.weight", "conv_layers.3.CrossAttn.key.bias", "conv_layers.3.CrossAttn.value.weight", "conv_layers.3.CrossAttn.value.bias", "conv_layers.3.CrossAttn.out.weight", "conv_layers.3.CrossAttn.out.bias", "conv_layers.3.NeighAttn.W_Q.weight", "conv_layers.3.NeighAttn.W_K.weight", "conv_layers.3.NeighAttn.W_V.weight", "conv_layers.3.NeighAttn.W_O.weight", "conv_layers.3.dense.linear1.weight", "conv_layers.3.dense.linear1.bias", "conv_layers.3.dense.linear2.weight", "conv_layers.3.dense.linear2.bias", "conv_layers.3.edge_update.norm.weight", "conv_layers.3.edge_update.norm.bias", "conv_layers.3.edge_update.EdgeMLP.0.weight", "conv_layers.3.edge_update.EdgeMLP.0.bias", "conv_layers.3.edge_update.EdgeMLP.2.weight", "conv_layers.3.edge_update.EdgeMLP.2.bias", "conv_layers.3.context.ContextMLP.0.weight", "conv_layers.3.context.ContextMLP.0.bias", "conv_layers.3.context.ContextMLP.2.weight", "conv_layers.3.context.ContextMLP.2.bias", "conv_layers.3.context.norm.weight", "conv_layers.3.context.norm.bias", "out_mlp.0.weight", "out_mlp.0.bias", "out_mlp.2.weight", "out_mlp.2.bias". 

### Test

In [7]:
import pandas as pd
from collections import OrderedDict

new_state_dict = OrderedDict()
model_path = f'{root_path}/Output/{dataset}_5fold1/' # if 5-fold cross-validation, {dataset}_5fold
print(model_path)
print(nn_config['pdb_class'])

models = []
for fold in range(5): # if 5-fold cross-validation, set to 5
    state_dict = torch.load(model_path + 'fold%s.ckpt'%fold,'cuda:0')
    model = LABind(
        rfeat_dim=nn_config['rfeat_dim'], ligand_dim=nn_config['ligand_dim'], hidden_dim=nn_config['hidden_dim'], heads=nn_config['heads'], augment_eps=nn_config['augment_eps'], 
        rbf_num=nn_config['rbf_num'],top_k=nn_config['top_k'], attn_drop=nn_config['attn_drop'], dropout=nn_config['dropout'], num_layers=nn_config['num_layers']).to(DEVICE)
    model = nn.DataParallel(model,device_ids=nn_config['device_ids'])
    
    for k, v in state_dict.items():
        # if checkpoint lacks "module.", add it
        if not k.startswith("module."):
            new_state_dict["module." + k] = v
        else:
            new_state_dict[k] = v

    model.load_state_dict(new_state_dict)
    model.eval()
    models.append(model)

df = pd.DataFrame(columns=['ligand','Rec','SPE','Acc','Pre','F1','MCC','AUC','AUPR'])
for ionic in os.listdir(f'{root_path}/{dataset}/label/test/'):
    ionic = ionic.split('.')[0]
    test_list = readDataList(f'{root_path}/{dataset}/label/test/{ionic}.fa',skew=1)
    test_data = readData.readData(
        name_list=test_list, 
        proj_dir=nn_config['proj_dir'], 
        lig_dict=nn_config['lig_dict'],
        true_file=f'{root_path}/{dataset}/label/test/{ionic}.fa', mode='test')
    # 打印长度
    test_loader = DataLoader(test_data, batch_size=nn_config['batch_size'], collate_fn=test_data.collate_fn)
    print(f'{ionic} test data length: {len(test_data)}')
    all_y_score = []
    all_y_true = []
    with torch.no_grad():
        for rfeat, ligand, xyz,  mask, y_true in test_loader:
            if rfeat == None:
                continue
            tensors = [rfeat, ligand, xyz,  mask, y_true]
            tensors = [tensor.to(DEVICE) for tensor in tensors]
            rfeat, ligand, xyz, mask, y_true = tensors
            
            logits = [model(rfeat, ligand, xyz, mask).sigmoid() for model in models]
            logits = torch.stack(logits,0).mean(0)
            
            logits = torch.masked_select(logits, mask==1)
            y_true = torch.masked_select(y_true, mask==1)
            all_y_score.extend(logits.cpu().detach().numpy())
            all_y_true.extend(y_true.cpu().detach().numpy())
    data_dict = calEval(all_y_true, all_y_score) # please set best threshold.
    data_dict['ligand'] = ionic
    df = pd.concat([df,pd.DataFrame(data_dict,index=[0])])
df.to_csv(f'{model_path}test.csv',index=False)
df

/virtual/zengzix4/LABind_ESM/Output/LigBind_5fold1/
source
CO3 test data length: 21
SO4 test data length: 31
Corrupt or mismatched file: 6ptkA


AttributeError: 'NoneType' object has no attribute 'shape'

In [None]:
training done, no validation though
also 