In [2]:
# 导入所需的库
import datetime
import os
import sys
import warnings
warnings.filterwarnings("ignore")

import torch
import pytorch_lightning as pl
from pytorch_lightning.trainer import Trainer
import pytorch_lightning.callbacks as plc
from model_interface import MInterface
from data_interface import DInterface
from src.tools.logger import SetupCallback, BestCheckpointCallback, BackupCodeCallback, TempFileCleanupCallback
from shutil import ignore_patterns
import pytorch_lightning.loggers as plog
import pickle

# Environment variables and working directory.
# NOTE(review): an empty CURL_CA_BUNDLE is a known workaround that, in some
# `requests` versions, disables TLS certificate verification for downloads.
# Confirm it is still needed before keeping it.
os.environ['CURL_CA_BUNDLE'] = ''
# sys.path[0] is the empty string when no script directory is available
# (e.g. interactive interpreters), and os.chdir('') raises; only change
# directory when there is a real path to change to.
if sys.path[0]:
    os.chdir(sys.path[0])


In [3]:
# Run configuration. Keys are listed in the order the rest of the notebook
# (and DInterface(**args) / SetupCallback(config=args)) receives them.
args = dict(
    res_dir='./results',
    ex_name='debug',
    check_val_every_n_epoch=1,
    dataset='PTM',
    model_name='MeTokenMax',  # model name here
    lr=1e-4,
    lr_scheduler='onecycle',
    offline=1,
    seed=114514,
    batch_size=16,
    num_workers=16,
    pad=1024,
    min_length=40,
    path='./data_inference',  # data path here (folder with the .pdb inputs)
    with_null_ptm=1,
    epoch=2,
    augment_eps=0.0,
    module_type=0,
    loss_type='uni',
    dis='uniform',
    weight_type=0,
    gamma=2.0,
    final_tau=1e-4,
    pretrain=0,
    test_only=1,
    # One list of residue indices per protein; used to build the `ptm` mask
    # during data preparation and forwarded to the model's hparams.
    inference_pos=[[4, 6], [35, 56], [79, 114]],
    ckpt_from_deepspeed=0,
    # NOTE(review): absolute, machine-specific checkpoint path.
    ckpt_path="/tancheng/caozx/caozx/ProteinInvBench/results/baseline_metoken_222/checkpoints/best.ckpt",
    gpus=[0],
    strategy='auto',
    wandb_offline=1,
)

In [4]:
import json
from re import I
from Bio import PDB
import os
import h5py
import numpy as np

def extract_pdb_info_from_folder(folder_path, output_json_path):
    """Extract chain-A sequence and backbone coordinates from each PDB file in a folder.

    Every ``*.pdb`` file in ``folder_path`` is parsed in sorted filename order.
    The processing is:

    * Only chain ``A`` of the file's first model is read.
    * The one-letter code of each amino-acid residue is appended to ``seq``.
    * The coordinates of the residue's N, C, CA and O atoms are appended to
      the matching list in ``coords_chain_A``.
    * Files without a chain ``A`` yield an empty sequence and empty
      coordinate lists.

    The records are written to ``output_json_path`` as a JSON list of
    ``{"id", "seq", "coords_chain_A"}`` objects. ``id`` is the filename
    without its ``.pdb`` extension.

    Args:
        folder_path: Directory containing the input ``.pdb`` files.
        output_json_path: Path of the JSON file to write (overwritten).

    Note:
        A residue missing one of the backbone atoms adds nothing to that
        atom's list. The coordinate lists can therefore end up shorter than
        ``seq`` and misaligned with it. NOTE(review): confirm downstream code
        tolerates this, or pad missing atoms explicitly.
    """
    # seq1 maps unknown three-letter names to 'X' instead of raising. The old
    # Polypeptide.three_to_one raised KeyError on non-standard residues that
    # is_aa() accepts, and newer Biopython releases no longer provide it.
    from Bio.SeqUtils import seq1

    parser = PDB.PDBParser(QUIET=True)
    # Backbone atom name -> output key. Insertion order fixes the key order in the JSON.
    atom_to_key = {'N': 'N_chain_A', 'C': 'C_chain_A', 'CA': 'CA_chain_A', 'O': 'O_chain_A'}
    protein_list = []

    # Sort so the output order is deterministic. apply_ptm_indices pairs each
    # protein with its index list purely by position, and os.listdir()
    # returns entries in arbitrary order.
    for filename in sorted(os.listdir(folder_path)):
        if not filename.endswith('.pdb'):
            continue
        pdb_id = os.path.splitext(filename)[0]
        structure = parser.get_structure(pdb_id, os.path.join(folder_path, filename))

        seq = ''
        coords_chain_A = {key: [] for key in atom_to_key.values()}

        # Use only the first model. Iterating every model (e.g. an NMR
        # ensemble) would concatenate the same chain several times.
        first_model = next(iter(structure), None)
        if first_model is not None and 'A' in first_model:
            for residue in first_model['A']:
                if not PDB.is_aa(residue):
                    continue
                seq += seq1(residue.resname)
                for atom in residue:
                    key = atom_to_key.get(atom.id)
                    if key is not None:
                        coords_chain_A[key].append(atom.coord.tolist())

        protein_list.append({
            "id": pdb_id,
            "seq": seq,
            "coords_chain_A": coords_chain_A,
        })

    with open(output_json_path, 'w') as json_file:
        json.dump(protein_list, json_file, indent=4)

def process_coords_chain_A(group, coords_chain_A):
    """Store each backbone-coordinate list as a float32 dataset in ``group``.

    Args:
        group: Object exposing ``create_dataset(name, data=...)`` (an HDF5 group here).
        coords_chain_A: Mapping of dataset name -> list of coordinates. Empty
            lists are stored too, as empty float32 arrays.
    """
    for name, coords in coords_chain_A.items():
        group.create_dataset(name, data=np.array(coords, dtype=np.float32))

def apply_ptm_indices(input_json_path, output_json_path, ptm_indices):
    """Add a per-residue binary ``ptm`` mask to every protein record in a JSON file.

    Protein ``k`` in the input list uses ``ptm_indices[k]``:

    * If that list contains ``-1``, every residue is marked with 1.
    * Otherwise only the listed indices that fall inside ``[0, len(seq))``
      are marked. Out-of-range indices are silently skipped.

    Args:
        input_json_path: JSON list of records, each with a ``"seq"`` string.
        output_json_path: Where to write the updated list (may equal the input path).
        ptm_indices: Sequence of index lists, one per protein, matched by position.

    Raises:
        IndexError: if ``ptm_indices`` has fewer entries than there are proteins.
    """
    with open(input_json_path, 'r') as source:
        records = json.load(source)

    for position, record in enumerate(records):
        n_residues = len(record["seq"])
        wanted = ptm_indices[position]

        if -1 in wanted:
            record["ptm"] = [1] * n_residues
            continue

        mask = [0] * n_residues
        for site in wanted:
            if 0 <= site < n_residues:
                mask[site] = 1
        record["ptm"] = mask

    with open(output_json_path, 'w') as sink:
        json.dump(records, sink, indent=4)

def dict_to_hdf5(group, item):
    """Write the entries of one protein record into an HDF5 group.

    The key ``'coords_chain_A'`` becomes a subgroup filled by
    ``process_coords_chain_A``. Other values are stored as datasets:

    * lists are stored as arrays;
    * ints and floats as 1-element arrays;
    * strings as 1-element variable-length string arrays.

    Raises:
        ValueError: for any other value type.
    """
    for key, value in item.items():
        if key == 'coords_chain_A':
            process_coords_chain_A(group.create_group(key), value)
            continue

        if isinstance(value, list):
            payload = np.array(value)
        elif isinstance(value, (int, float)):
            payload = np.array([value])
        elif isinstance(value, str):
            payload = np.array([value], dtype=h5py.special_dtype(vlen=str))
        else:
            raise ValueError(f"不支持的数据类型: {type(value)}")

        group.create_dataset(key, data=payload)

def json_to_hdf5(json_filepath, hdf5_filepath):
    """Convert a JSON list of protein records into an HDF5 file.

    Each record is written by ``dict_to_hdf5`` into its own top-level group,
    named after the record's position in the list (``"0"``, ``"1"``, ...).
    The HDF5 file is overwritten.
    """
    with open(json_filepath, 'r') as source:
        records = json.load(source)

    with h5py.File(hdf5_filepath, 'w') as sink:
        for index, record in enumerate(records):
            dict_to_hdf5(sink.create_group(str(index)), record)
            
# Example usage: build the inference dataset from the PDB files in args["path"].
# NOTE: despite its name, pdb_file_path is the *folder* holding the .pdb files.
pdb_file_path = args["path"]
output_json_path = os.path.join(args["path"], 'predict.json')
output_path = os.path.join(args["path"], "predict.hdf5")
# One index list per protein, paired by position with the records in
# predict.json. A list containing -1 marks every residue.
predict_indices = args["inference_pos"]
extract_pdb_info_from_folder(pdb_file_path, output_json_path)
# Reads and rewrites the same JSON file in place, adding the "ptm" mask.
apply_ptm_indices(output_json_path, output_json_path, predict_indices)
json_to_hdf5(output_json_path, output_path)
print("Data Prepared")

Data Prepared


In [5]:
# Callback setup.
def load_callbacks(args):
    """Assemble the Lightning callbacks for a run.

    Args:
        args: Run configuration dict. Reads ``res_dir``, ``ex_name``,
            ``check_val_every_n_epoch`` and ``pretrain``. The whole dict is
            also passed to ``SetupCallback`` as ``config``.

    Returns:
        ``(callbacks, ckptdir)``. ``callbacks`` contains, in order:

        * a code backup;
        * a best-checkpoint saver monitoring ``val_f1`` in ``max`` mode;
        * run setup;
        * early stopping on ``val_f1``, with patience 20 when ``pretrain`` is
          truthy and 5 otherwise;
        * temp-file cleanup.

        ``ckptdir`` is ``<res_dir>/<ex_name>/checkpoints``.
    """
    logdir = str(os.path.join(args['res_dir'], args['ex_name']))
    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    monitored_metric = "val_f1"
    stop_mode = "max"

    # Copy the project source next to the logs, skipping data/output folders.
    skip = ignore_patterns('results*', 'pdb*', 'metadata*', 'vq_dataset*', 'bin*', 'data*', '__pycache__', 'info', 'lib', 'requirements', 'debug', 'wandb')
    backup_cb = BackupCodeCallback(os.path.dirname(args['res_dir']), logdir, ignore_patterns=skip)

    best_cb = BestCheckpointCallback(
        monitor=monitored_metric,
        filename='best-{epoch:02d}-{val_f1:.3f}',
        save_top_k=15,
        mode='max',
        save_last=True,
        dirpath=ckptdir,
        verbose=True,
        every_n_epochs=args['check_val_every_n_epoch'],
    )

    setup_cb = SetupCallback(
        now=datetime.datetime.now().strftime("%m-%dT%H-%M-%S"),
        logdir=logdir,
        ckptdir=ckptdir,
        cfgdir=cfgdir,
        config=args,
        argv_content=sys.argv + ["gpus: {}".format(torch.cuda.device_count())],
    )

    early_stop_cb = plc.EarlyStopping(
        monitor=monitored_metric,
        mode=stop_mode,
        patience=20 if args['pretrain'] else 5,
    )

    return [backup_cb, best_cb, setup_cb, early_stop_cb, TempFileCleanupCallback()], ckptdir

# Fix the global random seed (args['seed']) for reproducibility.
pl.seed_everything(args['seed'])

# Build the data module from the run config and prepare it for the predict stage.
data_module = DInterface(**args)
data_module.setup(stage="predict")
# NOTE(review): gpu_count is not used by any visible cell.
gpu_count = torch.cuda.device_count()

# Weights & Biases logger; offline mode follows args['wandb_offline'].
logger = plog.WandbLogger(project='PTM-MeToken', dir='./wandb/', name=args['ex_name'], offline=args['wandb_offline'], config=args)

# Build the callback list; ckptdir is <res_dir>/<ex_name>/checkpoints.
callbacks, ckptdir = load_callbacks(args)


Global seed set to 114514
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


In [6]:
# Trainer configuration for inference.
# NOTE(review): `gpus` and `resume_from_checkpoint` are PyTorch Lightning 1.x
# Trainer arguments that were removed in 2.0; confirm the installed version.
trainer_config = {
    # In the 'debug' experiment the GPU list is forced to [0].
    'gpus': args['gpus'] if args['ex_name'] != 'debug' else [0],
    'max_epochs': args['epoch'],
    'strategy': args['strategy'],
    'accelerator': 'gpu',
    'callbacks': callbacks,
    'logger': logger,
    'gradient_clip_val': 1.0,
    'resume_from_checkpoint': args['ckpt_path'] if args['test_only'] else None,
}

trainer = Trainer(**trainer_config)

# Load weights straight from args['ckpt_path']. The previous source,
# trainer_config["resume_from_checkpoint"], is None whenever test_only is 0.
model = MInterface.load_from_checkpoint(args['ckpt_path'], strict=False)
# Residue positions to predict, forwarded to the model through its hparams.
# NOTE(review): the original author flagged this line as problematic. Confirm
# that MInterface actually reads hparams["predict_indices"] at predict time.
model.hparams["predict_indices"] = args["inference_pos"]

# trainer.predict returns one output per batch. Keep every batch rather than
# only the first, so nothing is dropped when the data spans several batches.
# Assumes each batch output is a list of per-sample results, as the printed
# output suggests. TODO: confirm.
batch_outputs = trainer.predict(model, data_module)
result = [sample for batch in batch_outputs for sample in batch]
print(result)

# NOTE(review): absolute, machine-specific default; override via args['prediction_path'].
prediction_path = args.get('prediction_path', "/tancheng/caozx/caozx/ProteinInvBench/prediction_sample.pkl")
with open(prediction_path, "wb") as f:
    pickle.dump(result, f)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  1.83it/s]
[{4: 24, 6: 1}, {35: 24, 56: 1}, {79: 24, 114: 1}]
