In [21]:
import pandas as pd
import prody
import os
import sys
import MDAnalysis as mda
from functools import partial 
from multiprocessing import Pool
import numpy as np
import torch
import multiprocessing
from rdkit import Chem
from rdkit import RDLogger
from torch.utils.data import Dataset
from torch_geometric.data import HeteroData
from tqdm import tqdm
from scipy.spatial.transform import Rotation as R
RDLogger.DisableLog("rdApp.*")
# dir of current
from utils.fns import load_graph, save_graph
from dataset.protein_feature import get_protein_feature_mda
from dataset.loop_feature import get_loop_feature_strict, get_loop_feature_strict_rot

def get_filenames_no_ext(folder_path):
    """
    Collect the names of all regular files in a folder, without extensions.

    Only direct children of ``folder_path`` are considered; sub-directories
    are skipped. Just the last extension is removed (``b.tar.gz`` becomes
    ``b.tar``).

    Args:
        folder_path (str): Path of the folder to scan.
    Returns:
        list: File names with the extension stripped, in the order
            produced by ``os.listdir``.
    """
    return [
        os.path.splitext(entry)[0]
        for entry in os.listdir(folder_path)
        if os.path.isfile(os.path.join(folder_path, entry))
    ]


def get_loop_graph_v1(pdb_id, path):
    """
    Build a HeteroData graph for one loop from its pocket PDB file.

    Reads ``{path}/{pdb_id}_pocket_12A.pdb``, splits the chain named in
    ``pdb_id`` into the loop residues and the remaining residues,
    featurizes both parts with the project helpers and stores the results
    in a single ``HeteroData`` object.

    Args:
        pdb_id (str): Identifier made of exactly four underscore-separated
            fields, ``<pdbid>_<chain>_<first_resid>_<last_resid>``
            (e.g. ``1ABR_B_114_121``).
        path (str): Folder that contains the pocket PDB file.
    Returns:
        HeteroData: Graph with 'protein', 'protein_atom' and 'loop' node
            stores, 'p2p', 'pa2pa', 'l2l' and 'p2l' edge stores, and
            graph-level masks/flags. The molecule returned by
            ``pdb2rdmol`` is attached as ``data.mol``.
    Raises:
        ValueError: If ``pdb_id`` does not split into exactly four fields
            or the two residue numbers are not integers.
        AssertionError: If either loop parent mask has a different length
            than the protein node features.
        FileNotFoundError: If the pocket PDB file does not exist.

    NOTE(review): ``pdb2rdmol`` and ``get_repeat_node`` are neither
    imported nor defined in the visible part of this file - confirm they
    are in scope before running.
    """
    # Presumably keeps each worker process single-threaded so parallel
    # workers do not oversubscribe the CPU - TODO confirm.
    torch.set_num_threads(1)
    pocket_pdb = f'{path}/{pdb_id}_pocket_12A.pdb'
    # `pdbid` is not used below; only the chain and residue bounds are.
    pdbid, chain, res_num_src, res_num_dst = pdb_id.split('_')
    # Only referenced by the disabled assert below, but the int() calls
    # still reject non-numeric residue bounds.
    loop_len = int(res_num_dst) - int(res_num_src) + 1
    # get protein mol
    pocket_mol = mda.Universe(pocket_pdb)
    # Same chain, split into residues outside / inside the loop range.
    non_loop_mol = pocket_mol.select_atoms(f'chainid {chain} and not (resid {res_num_src}:{res_num_dst})')
    loop_mol = pocket_mol.select_atoms(f'chainid {chain} and (resid {res_num_src}:{res_num_dst})')
    # assert loop_len == len(loop_mol.residues), f'{pdb_id} loop length error'
    # Residue labels formatted as '<chainID>-<resid><icode>'.
    loop_res = [f'{res.atoms.chainIDs[0]}-{res.resid}{res.icode}' for res in loop_mol.residues]
    pocket_atom_mol = pdb2rdmol(pocket_pdb)
    # generate graph
    p_xyz, p_xyz_full, p_seq, p_node_s, p_node_v, p_edge_index, p_edge_s, p_edge_v, p_full_edge_s, p_node_name, p_node_type = get_protein_feature_mda(non_loop_mol)
    loop_xyz, pa_node_feature, pa_edge_index, pa_edge_feature, atom2nonloopres, nonloop_mask, loop_edge_index, loop_edge_feature, loop_cov_edge_mask, loop_idx_2_mol_idx, loop_bb_atom_mask, loop_parent_mask_forward, loop_parent_mask_reverse, seq_order_forward, seq_parent_forward, seq_order_reverse, seq_parent_reverse, forward_ok, reverse_ok = get_loop_feature_strict(pocket_atom_mol, nonloop_res=p_node_name, loop_res=loop_res)
    # Both parent masks must have one entry per protein node.
    assert len(p_node_s) == len(loop_parent_mask_forward), f'{pdb_id} loop parent mask error'
    assert len(p_node_s) == len(loop_parent_mask_reverse), f'{pdb_id} loop parent mask error'
    # to data
    data = HeteroData()
    # Graph-level attributes: flags, index maps and masks from the loop
    # featurizer. loop_cov_edge_mask and loop_idx_2_mol_idx are stored
    # exactly as returned (no tensor conversion).
    data.forward_ok = forward_ok
    data.reverse_ok = reverse_ok
    data.atom2nonloopres = torch.tensor(atom2nonloopres, dtype=torch.long)
    data.loop_parent_mask_forward = torch.tensor(loop_parent_mask_forward, dtype=torch.bool)
    data.loop_parent_mask_reverse = torch.tensor(loop_parent_mask_reverse, dtype=torch.bool)
    data.nonloop_mask = torch.tensor(nonloop_mask, dtype=torch.bool)
    data.loop_bb_atom_mask = torch.tensor(loop_bb_atom_mask, dtype=torch.bool)
    data.loop_cov_edge_mask = loop_cov_edge_mask
    data.loop_idx_2_mol_idx = loop_idx_2_mol_idx
    data.mol = pocket_atom_mol
    # protein residue
    # data['protein'].node_name = p_node_type
    data['protein'].node_s = p_node_s.to(torch.float32) 
    data['protein'].node_v = p_node_v.to(torch.float32)
    data['protein'].xyz = p_xyz.to(torch.float32) 
    data['protein'].xyz_full = p_xyz_full.to(torch.float32) 
    data['protein'].seq = p_seq.to(torch.int32)
    data['protein', 'p2p', 'protein'].edge_index = p_edge_index.to(torch.long)
    data['protein', 'p2p', 'protein'].edge_s = p_edge_s.to(torch.float32) 
    data['protein', 'p2p', 'protein'].full_edge_s = p_full_edge_s.to(torch.float32) 
    data['protein', 'p2p', 'protein'].edge_v = p_edge_v.to(torch.float32) 
    # protein atom
    data['protein_atom'].node_s = pa_node_feature.to(torch.float32) 
    data['protein_atom', 'pa2pa', 'protein_atom'].edge_index = pa_edge_index.to(torch.long)
    data['protein_atom', 'pa2pa', 'protein_atom'].edge_s = pa_edge_feature.to(torch.float32) 
    # # loop
    # data['loop'].node_name = loop_name
    # Placeholder loop node features: an all-zero column with one row per
    # True entry of loop_bb_atom_mask.
    data['loop'].node_s = torch.zeros((data.loop_bb_atom_mask.sum(), 1))
    data['loop'].xyz = loop_xyz.to(torch.float32)
    data['loop'].seq_parent_forward = seq_parent_forward.to(torch.long)
    data['loop'].seq_order_forward = seq_order_forward.to(torch.long)
    data['loop'].seq_parent_reverse = seq_parent_reverse.to(torch.long)
    data['loop'].seq_order_reverse = seq_order_reverse.to(torch.long)
    data['loop', 'l2l', 'loop'].edge_index = loop_edge_index.to(torch.long)
    data['loop', 'l2l', 'loop'].full_edge_s = loop_edge_feature.to(torch.float32)
    # # protein-loop
    # Edge index built by get_repeat_node from the protein and loop node
    # counts - presumably all protein/loop pairs; TODO confirm.
    data['protein', 'p2l', 'loop'].edge_index = torch.stack(
        get_repeat_node(p_xyz.shape[0], loop_xyz.shape[0]), dim=0)
    return data

def batch_get_loop_graph_v1(pdb_id_list, path):
    """
    Build HeteroData graphs for a batch of pdb ids, skipping failures.

    Each id is passed to ``get_loop_graph_v1``; the resulting graph is
    tagged with its id via ``data.pdb_id``. Any exception raised for an
    id is printed and that id is left out of the result, so the returned
    list may be shorter than the input.

    Args:
        pdb_id_list (list): List of pdb_id strings.
        path (str): Folder containing the pdb files.
    Returns:
        list: HeteroData objects for the ids that were processed
            successfully, in input order.
    """
    graphs = []
    for entry_id in pdb_id_list:
        try:
            graph = get_loop_graph_v1(entry_id, path)
            graph.pdb_id = entry_id
        except Exception as err:
            # Best effort: report the failing id and move on to the next.
            print(f"处理{entry_id}时出错：{err}")
            continue
        graphs.append(graph)
    return graphs


# Smoke test: build the graph for the first entry found in the data folder.
folp = './datas/cas15/graph'
# Scan the folder once so the printed id and the processed id are
# guaranteed to be the same entry.
entry_ids = get_filenames_no_ext(folp)
if not entry_ids:
    # Fail with a clear message instead of an IndexError on `[0]`.
    raise FileNotFoundError(f'no files found in {folp}')
print(entry_ids[0])
# NOTE(review): get_loop_graph_v1 opens f'{folp}/<id>_pocket_12A.pdb'. The
# recorded run failed with FileNotFoundError for
# './datas/cas15/graph/1ABR_B_114_121_pocket_12A.pdb', so `folp` may not be
# the folder that holds the pocket PDB files - confirm the intended path.
get_loop_graph_v1(entry_ids[0], folp)

# Batch variant (disabled): process every entry and skip the failing ones.
#print(batch_get_loop_graph_v1(entry_ids, folp))











1ABR_B_114_121


FileNotFoundError: [Errno 2] No such file or directory: './datas/cas15/graph/1ABR_B_114_121_pocket_12A.pdb'