In [None]:
# Mount Google Drive so that later cells can read and write below /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# 包

In [None]:
# The `Bio` package imported below (SeqIO, PDB) is distributed on PyPI as `biopython`.
# `%pip` installs into the running kernel's environment and does not depend on automagic.
%pip install biopython

In [None]:
# Explicit %pip magic: works in multi-line cells and targets the running kernel's environment.
%pip install torch torch-geometric optuna tqdm scikit-learn esm umap-learn shap xgboost

In [None]:
# Reinstall pinned numpy/pandas builds into the kernel's environment.
# Restart the runtime afterwards if either package was already imported.
%pip uninstall -y numpy pandas
%pip install numpy==1.26.4 pandas==2.2.2

In [None]:
# List the packages installed in the running kernel's environment.
%pip list

In [None]:
# Print the version of the `python` executable found on the shell PATH.
!python --version

Python 3.12.11


# 数据集

## 阳性样本cd-hit

In [None]:
# -y answers the confirmation prompt, which cannot be answered interactively from a notebook cell.
!apt-get install -y cd-hit

In [None]:
# Count the FASTA header lines (one per record) in AFP.fasta.
!grep -c "^>" AFP.fasta

In [None]:
# Cluster the positive set with CD-HIT (-c: sequence identity threshold, -n: word length).
!cd-hit -i AFP.fasta -o AFP_clustered.fasta -c 0.8 -n 5

In [None]:
def convert_to_fasta(input_filename, output_filename):
    """Convert a text file with one sequence per line into a FASTA file.

    Every non-blank line becomes one record named ``Sequence_<k>``, where ``k``
    counts the written records starting at 1. Blank lines (for example a
    trailing empty line at the end of the file) are skipped, so no record with
    an empty sequence is emitted.

    Args:
        input_filename: Path of the text file to read.
        output_filename: Path of the FASTA file to write (overwritten).
    """
    with open(input_filename, 'r') as file:
        # Strip surrounding whitespace and drop lines that are empty afterwards.
        sequences = [line.strip() for line in file if line.strip()]

    with open(output_filename, 'w') as file:
        for i, sequence in enumerate(sequences, start=1):
            file.write(f">Sequence_{i}\n{sequence}\n")

# Convert the one-sequence-per-line file AFP_2.txt into AFP_2.fasta.
convert_to_fasta('AFP_2.txt', 'AFP_2.fasta')

In [None]:
# Count the FASTA records in AFP.fasta and how many distinct sequences they hold.
with open("AFP.fasta", "r") as file:
    sequences = [seq for seq in file.read().split('>') if seq.strip()]

# Compare residues only: drop each record's header line and join wrapped
# sequence lines. Comparing whole records would count identical sequences with
# different headers as unique.
unique_sequences = {
    "".join(record.splitlines()[1:]).strip() for record in sequences
}

print(f"Total sequences: {len(sequences)}")
print(f"Unique sequences: {len(unique_sequences)}")

In [None]:
def filter_sequences(input_file, output_file, max_length=100):
    """Copy the FASTA records whose sequence has at most ``max_length`` residues.

    Kept records are written verbatim (header line followed by the original
    sequence lines). A header that is not followed by any line is dropped.
    Lines that appear before the first header (for example a leading blank
    line) are ignored instead of raising ``UnboundLocalError``.

    Args:
        input_file: Path of the FASTA file to read.
        output_file: Path of the filtered FASTA file to write (overwritten).
        max_length: Largest sequence length (newlines excluded) that is kept.
    """
    def _flush(outfile, header, lines):
        # Length is measured on the joined sequence without newline characters.
        if header is None or not lines:
            return
        sequence_data = ''.join(lines)
        if len(sequence_data.replace('\n', '')) <= max_length:
            outfile.write(header + sequence_data)

    with open(input_file, 'r') as infile, open(output_file, 'w') as outfile:
        sequence_header = None
        sequence_lines = []
        for line in infile:
            if line.startswith('>'):
                # A new header closes the previous record.
                _flush(outfile, sequence_header, sequence_lines)
                sequence_header = line
                sequence_lines = []
            elif sequence_header is not None:
                sequence_lines.append(line)

        # The last record has no following header to trigger the flush.
        _flush(outfile, sequence_header, sequence_lines)

# Keep only the clustered positive sequences that pass the default length limit.
input_filename = 'AFP_clustered.fasta'
output_filename = 'AFP_CD-hit.fasta'
filter_sequences(input_filename, output_filename)

In [None]:
def renumber_sequences(input_file, output_file):
    """Rewrite a FASTA file with headers replaced by ``>Sequence_1``, ``>Sequence_2``, ...

    Lines that do not start with ``>`` are copied unchanged.

    Args:
        input_file: Path of the file to read.
        output_file: Path of the file to write (overwritten).
    """
    with open(input_file, 'r') as src, open(output_file, 'w') as dst:
        seen_headers = 0
        for text in src:
            if not text.startswith('>'):
                dst.write(text)
                continue
            seen_headers += 1
            dst.write(f'>Sequence_{seen_headers}\n')

# Give the filtered positive records consecutive Sequence_<n> headers.
input_filename = 'AFP_CD-hit.fasta'  
output_filename = 'AFP_renumbered.fasta'  
renumber_sequences(input_filename, output_filename)

## 阴性样本cd-hit

In [None]:
def convert_to_fasta(input_filename, output_filename):
    """Convert a text file with one sequence per line into a FASTA file.

    Every non-blank line becomes one record named ``Sequence_<k>``, where ``k``
    counts the written records starting at 1. Blank lines (for example a
    trailing empty line at the end of the file) are skipped, so no record with
    an empty sequence is emitted.

    Args:
        input_filename: Path of the text file to read.
        output_filename: Path of the FASTA file to write (overwritten).
    """
    with open(input_filename, 'r') as file:
        # Strip surrounding whitespace and drop lines that are empty afterwards.
        sequences = [line.strip() for line in file if line.strip()]

    with open(output_filename, 'w') as file:
        for i, sequence in enumerate(sequences, start=1):
            file.write(f">Sequence_{i}\n{sequence}\n")

# Convert the one-sequence-per-line file Non_AFP.txt into Non_AFP.fasta.
convert_to_fasta('Non_AFP.txt', 'Non_AFP.fasta')

In [None]:
# Cluster the negative set with CD-HIT (-c: sequence identity threshold, -n: word length).
!cd-hit -i Non_AFP.fasta -o Non_AFP_clustered.fasta -c 0.8 -n 5

In [None]:
def filter_sequences(input_file, output_file, max_length=100):
    """Copy the FASTA records whose sequence has at most ``max_length`` residues.

    Kept records are written verbatim (header line followed by the original
    sequence lines). A header that is not followed by any line is dropped.
    Lines that appear before the first header (for example a leading blank
    line) are ignored instead of raising ``UnboundLocalError``.

    Args:
        input_file: Path of the FASTA file to read.
        output_file: Path of the filtered FASTA file to write (overwritten).
        max_length: Largest sequence length (newlines excluded) that is kept.
    """
    def _flush(outfile, header, lines):
        # Length is measured on the joined sequence without newline characters.
        if header is None or not lines:
            return
        sequence_data = ''.join(lines)
        if len(sequence_data.replace('\n', '')) <= max_length:
            outfile.write(header + sequence_data)

    with open(input_file, 'r') as infile, open(output_file, 'w') as outfile:
        sequence_header = None
        sequence_lines = []
        for line in infile:
            if line.startswith('>'):
                # A new header closes the previous record.
                _flush(outfile, sequence_header, sequence_lines)
                sequence_header = line
                sequence_lines = []
            elif sequence_header is not None:
                sequence_lines.append(line)

        # The last record has no following header to trigger the flush.
        _flush(outfile, sequence_header, sequence_lines)

# Keep only the clustered negative sequences that pass the default length limit.
input_filename = 'Non_AFP_clustered.fasta'
output_filename = 'Non_AFP_CD-hit.fasta'
filter_sequences(input_filename, output_filename)

In [None]:
def renumber_sequences(input_file, output_file):
    """Rewrite a FASTA file with headers replaced by ``>Sequence_1``, ``>Sequence_2``, ...

    Lines that do not start with ``>`` are copied unchanged.

    Args:
        input_file: Path of the file to read.
        output_file: Path of the file to write (overwritten).
    """
    with open(input_file, 'r') as src, open(output_file, 'w') as dst:
        seen_headers = 0
        for text in src:
            if not text.startswith('>'):
                dst.write(text)
                continue
            seen_headers += 1
            dst.write(f'>Sequence_{seen_headers}\n')

# Give the filtered negative records consecutive Sequence_<n> headers.
input_filename = 'Non_AFP_CD-hit.fasta' 
output_filename = 'Non_AFP_renumbered.fasta' 
renumber_sequences(input_filename, output_filename)

## 划分数据集

In [None]:
# The `Bio` package imported below (SeqIO) is distributed on PyPI as `biopython`.
# `%pip` installs into the running kernel's environment and does not depend on automagic.
%pip install biopython

In [None]:
from Bio import SeqIO
import pandas as pd
from imblearn.under_sampling import RandomUnderSampler
from sklearn.model_selection import train_test_split

def load_sequences(file_path, label):
    """Read all records of a FASTA file and pair each with the same label.

    Args:
        file_path: Path of the FASTA file.
        label: Value stored once per sequence in the returned label list.

    Returns:
        ``(sequences, labels)``: the sequences as strings and a label list of
        the same length.
    """
    sequences = [str(record.seq) for record in SeqIO.parse(file_path, "fasta")]
    labels = [label] * len(sequences)
    return sequences, labels

def balance_and_split(sequences, labels):
    """Undersample to balanced classes, split 80/20 and save both parts as CSV.

    Class counts are printed before and after undersampling. The split uses a
    fixed seed (42) and is not stratified. The results are written to
    ``train_dataset.csv`` and ``test_dataset.csv`` in the working directory.

    Args:
        sequences: Sequence strings.
        labels: Class labels, one per sequence.

    Returns:
        ``(train_df, test_df)``, each with the columns ``sequence`` and ``label``.
    """
    df = pd.DataFrame({'sequence': sequences, 'label': labels})

    print(f"原始数据总量: {df.shape[0]}")
    print(f"各类样本数量：\n{df['label'].value_counts()}")

    # Random undersampling of the larger class
    sampler = RandomUnderSampler(random_state=42)
    X_res, y_res = sampler.fit_resample(df[['sequence']], df['label'])

    print(f"欠采样后数据总量: {X_res.shape[0]}")
    print(f"欠采样后各类样本数量：\n{pd.Series(y_res).value_counts()}")

    # Train/test split
    X_train, X_test, y_train, y_test = train_test_split(
        X_res, y_res, test_size=0.2, random_state=42
    )

    print(f"训练集数量: {X_train.shape[0]}")
    print(f"测试集数量: {X_test.shape[0]}")

    def _with_label(features, target):
        # Rebuild a frame with the sequence column and attach the matching labels.
        frame = pd.DataFrame(features, columns=['sequence'])
        frame['label'] = target
        return frame

    train_df = _with_label(X_train, y_train)
    test_df = _with_label(X_test, y_test)

    train_df.to_csv('train_dataset.csv', index=False)
    test_df.to_csv('test_dataset.csv', index=False)

    return train_df, test_df

# Load the positives (label 1) and negatives (label 0) from the renumbered FASTA files.
pos_sequences, pos_labels = load_sequences('AFP_renumbered.fasta', 1)
neg_sequences, neg_labels = load_sequences('Non_AFP_renumbered.fasta', 0)

print(f"阳性样本数量: {len(pos_sequences)}")
print(f"阴性样本数量: {len(neg_sequences)}")

# Concatenate both classes, keeping sequences and labels in the same order.
all_sequences = pos_sequences + neg_sequences
all_labels = pos_labels + neg_labels

# Balance the classes, split into train/test and write both CSV files.
train_df, test_df = balance_and_split(all_sequences, all_labels)

## 去除非20个标准氨基酸

In [None]:
import pandas as pd

def replace_non_standard_amino_acids(seq):
    """Replace the residue codes B, Z, X, J, U and O with standard amino acids.

    The substitutions are B->D, Z->E, X->A, J->L, U->C and O->K. Only
    upper-case letters are affected; every other character is kept.

    Args:
        seq: Sequence string.

    Returns:
        The sequence with the six codes substituted.
    """
    substitution_table = str.maketrans(
        {'B': 'D', 'Z': 'E', 'X': 'A', 'J': 'L', 'U': 'C', 'O': 'K'}
    )
    return seq.translate(substitution_table)

def load_and_preprocess(file_path, output_path):
    """Load a dataset CSV, clean its ``sequence`` column and save the result.

    Args:
        file_path: CSV file that contains a ``sequence`` column.
        output_path: Destination CSV, written without the index.

    Returns:
        The DataFrame with the cleaned ``sequence`` column.
    """
    dataset = pd.read_csv(file_path)
    cleaned_sequences = dataset['sequence'].apply(replace_non_standard_amino_acids)
    dataset['sequence'] = cleaned_sequences
    dataset.to_csv(output_path, index=False)
    return dataset

# Clean both splits and write them to the processed_* CSV files.
train_df = load_and_preprocess('train_dataset.csv', 'processed_train_dataset.csv')
test_df = load_and_preprocess('test_dataset.csv', 'processed_test_dataset.csv')

In [None]:
def renumber_sequences(input_file, output_file):
    """Rewrite a file with ``>`` header lines replaced by ``>Sequence_1``, ``>Sequence_2``, ...

    Lines that do not start with ``>`` are copied unchanged.

    Args:
        input_file: Path of the file to read.
        output_file: Path of the file to write (overwritten).
    """
    with open(input_file, 'r') as src, open(output_file, 'w') as dst:
        seen_headers = 0
        for text in src:
            if not text.startswith('>'):
                dst.write(text)
                continue
            seen_headers += 1
            dst.write(f'>Sequence_{seen_headers}\n')

# NOTE(review): the input here is a CSV file, not FASTA. renumber_sequences only
# rewrites lines starting with '>', so this presumably just copies the file — confirm
# that this is the intended way to produce ColabFold_test_dataset.csv.
input_filename = 'processed_test_dataset.csv'
output_filename = 'ColabFold_test_dataset.csv'
renumber_sequences(input_filename, output_filename)

## 分别划分成训练集阳性，训练集阴性，测试集阳性，测试集阴性

In [None]:
import pandas as pd

def to_fasta(df, file_name):
    """Write the ``sequence`` column of a DataFrame as FASTA.

    The DataFrame index label of each row is used as the record header.

    Args:
        df: DataFrame with a ``sequence`` column.
        file_name: Path of the FASTA file to write (overwritten).
    """
    with open(file_name, 'w') as f:
        f.writelines(
            f">{row_label}\n{residues}\n"
            for row_label, residues in zip(df.index, df['sequence'])
        )

def process_and_save(data_path, output_prefix):
    """Split a labelled dataset CSV into positive and negative FASTA files.

    Rows with ``label == 1`` go to ``<output_prefix>_pos.fasta`` and rows with
    ``label == 0`` to ``<output_prefix>_neg.fasta``. The row index of the CSV is
    kept as the record header.

    Args:
        data_path: CSV file with ``sequence`` and ``label`` columns.
        output_prefix: Prefix of the two FASTA files.
    """
    df = pd.read_csv(data_path)

    for label_value, suffix in ((1, 'pos'), (0, 'neg')):
        subset = df[df['label'] == label_value]
        to_fasta(subset, f'{output_prefix}_{suffix}.fasta')

# Processed CSV files produced by the preprocessing step.
train_path = 'processed_train_dataset.csv'
test_path = 'processed_test_dataset.csv'

# Write Colab_train_{pos,neg}.fasta and Colab_test_{pos,neg}.fasta.
process_and_save(train_path, 'Colab_train')
process_and_save(test_path, 'Colab_test')

# 结构信息
*   加载pdb文件
*   使用pdb文件提取特征

## 加载pdb文件

In [None]:
import zipfile
import os


# Folder with the result archives and the folder that receives the extracted PDB files.
zip_folder_path = '/content/drive/MyDrive/AFP_work/result/test_neg'
extract_folder_path = '/content/drive/MyDrive/AFP_work/pdb/test_neg'  # extraction target

os.makedirs(extract_folder_path, exist_ok=True)

file_count = 0

for filename in os.listdir(zip_folder_path):
    if not filename.endswith('.zip'):
        continue
    zip_path = os.path.join(zip_folder_path, filename)
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        for pdb_file in zip_ref.namelist():
            # 'unrelaxed_rank_001' contains 'relaxed_rank_001' as a substring, so
            # unrelaxed models have to be excluded explicitly. Otherwise they are
            # extracted here only to be deleted again below.
            is_relaxed_top_model = (
                pdb_file.endswith('.pdb')
                and 'relaxed_rank_001' in pdb_file
                and 'unrelaxed' not in os.path.basename(pdb_file)
            )
            if is_relaxed_top_model:
                zip_ref.extract(pdb_file, extract_folder_path)
                print(f'解压缩完成：{pdb_file} 从 {filename} 到 {extract_folder_path}')
                file_count += 1

print(f'总共解压了 {file_count} 个文件')

pdb_files = [f for f in os.listdir(extract_folder_path) if f.endswith('.pdb')]
print(f'解压后的文件夹中共有 {len(pdb_files)} 个 PDB 文件')

# Remove unrelaxed PDB files that may be left over from earlier runs.
for pdb_file in pdb_files:
    if 'unrelaxed' in pdb_file:
        file_path = os.path.join(extract_folder_path, pdb_file)
        os.remove(file_path)
        print(f'已删除文件：{file_path}')

remaining_pdb_files = [f for f in os.listdir(extract_folder_path) if f.endswith('.pdb')]
print(f'删除 unrelaxed 文件后，文件夹中共有 {len(remaining_pdb_files)} 个 PDB 文件')


In [None]:
from Bio.PDB import PDBParser, PPBuilder

# Shared Biopython helpers: a quiet PDB parser and a polypeptide builder.
parser = PDBParser(QUIET=True)
ppb = PPBuilder()

def extract_sequence_from_pdb(pdb_path):
    """Return the amino-acid sequence of a PDB file.

    The sequences of all polypeptides built from the structure are
    concatenated in the order in which the builder yields them.

    Args:
        pdb_path: Path of the PDB file.

    Returns:
        The concatenated sequence as a string (empty if no peptide is found).
    """
    structure = parser.get_structure('', pdb_path)
    return "".join(str(pp.get_sequence()) for pp in ppb.build_peptides(structure))

In [None]:
import os
import random
from Bio.PDB import PDBParser, PPBuilder
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
from Bio.Seq import Seq

# Paths of the four PDB folders
train_pos_path = '/content/drive/MyDrive/AFP_work/pdb/train_pos'
test_pos_path = '/content/drive/MyDrive/AFP_work/pdb/test_pos'
train_neg_path = '/content/drive/MyDrive/AFP_work/pdb/train_neg'
test_neg_path = '/content/drive/MyDrive/AFP_work/pdb/test_neg'

# Number of files ending in '.pdb' in each folder
train_pdb_count = len([f for f in os.listdir(train_pos_path) if f.endswith('.pdb')])
test_pdb_count = len([f for f in os.listdir(test_pos_path) if f.endswith('.pdb')])
train_neg_pdb_count = len([f for f in os.listdir(train_neg_path) if f.endswith('.pdb')])
test_neg_pdb_count = len([f for f in os.listdir(test_neg_path) if f.endswith('.pdb')])

print(f'训练集中的 PDB 文件数量（阳性）：{train_pdb_count}')
print(f'测试集中的 PDB 文件数量（阳性）：{test_pdb_count}')
print(f'训练集中的 PDB 文件数量（阴性）：{train_neg_pdb_count}')
print(f'测试集中的 PDB 文件数量（阴性）：{test_neg_pdb_count}')

## 使用pdb文件提取特征

In [None]:
import os
import time
import json
from Bio.PDB import PDBParser
import numpy as np

# Input folders with the PDB files and output folders for the JSON feature files.
train_pos_folder_path = '/content/drive/MyDrive/AFP_work/pdb/train_pos'
train_neg_folder_path = '/content/drive/MyDrive/AFP_work/pdb/train_neg'
test_pos_folder_path = '/content/drive/MyDrive/AFP_work/pdb/test_pos'
test_neg_folder_path = '/content/drive/MyDrive/AFP_work/pdb/test_neg'
output_train_pos_folder_path = '/content/drive/MyDrive/AFP_work/pdb_features/train_pos'
output_train_neg_folder_path = '/content/drive/MyDrive/AFP_work/pdb_features/train_neg'
output_test_pos_folder_path = '/content/drive/MyDrive/AFP_work/pdb_features/test_pos'
output_test_neg_folder_path = '/content/drive/MyDrive/AFP_work/pdb_features/test_neg'

# Parent folder of the four output folders
output_folder_path = '/content/drive/MyDrive/AFP_work/pdb_features'

os.makedirs(output_train_pos_folder_path, exist_ok=True)
os.makedirs(output_train_neg_folder_path, exist_ok=True)
os.makedirs(output_test_pos_folder_path, exist_ok=True)
os.makedirs(output_test_neg_folder_path, exist_ok=True)

# Collect the names of all PDB files in the training and test folders
train_pos_pdb_files = [f for f in os.listdir(train_pos_folder_path) if f.endswith('.pdb')]
train_neg_pdb_files = [f for f in os.listdir(train_neg_folder_path) if f.endswith('.pdb')]
test_pos_pdb_files = [f for f in os.listdir(test_pos_folder_path) if f.endswith('.pdb')]
test_neg_pdb_files = [f for f in os.listdir(test_neg_folder_path) if f.endswith('.pdb')]

# Parser used to read the structures from the PDB files
parser = PDBParser(QUIET=True)

# Start of the timed section; the elapsed time is printed at the end of the cell.
start_time = time.time()

# Build contact-graph features from one PDB file
def process_pdb_file(pdb_path, contact_threshold=10.0):
    """Extract C-alpha coordinates and contact-graph edge features from a PDB file.

    Two residues are in contact when their CA atoms are closer than
    ``contact_threshold``. Pairs with coincident CA atoms (distance 0) are
    skipped entirely, because no direction is defined for them. This keeps
    ``edges``, ``directions`` and ``rotations`` aligned element by element.

    Args:
        pdb_path: Path of the PDB file.
        contact_threshold: Exclusive upper bound on the CA-CA distance of a
            contact, in the coordinate units of the file (defaults to 10.0).

    Returns:
        ``(positions, edges, directions, rotations)`` where ``positions`` is a
        float64 array with one CA coordinate per residue that has a CA atom,
        ``edges`` is a list of ``[i, j]`` index pairs with ``i < j``,
        ``directions`` holds the unit vector from ``i`` to ``j`` for each edge,
        and ``rotations`` holds ``arctan2(dy, dx)`` of that vector as a float.
    """
    structure = parser.get_structure('', pdb_path)
    residues = [residue for residue in structure.get_residues() if 'CA' in residue]
    num_residues = len(residues)

    # Node features: CA coordinates
    positions = np.array([residue['CA'].get_coord() for residue in residues], dtype=np.float64)
    edges = []
    directions = []
    rotations = []

    # Contact map and per-edge features
    for i in range(num_residues):
        for j in range(i + 1, num_residues):
            direction = positions[j] - positions[i]
            distance = np.linalg.norm(direction)
            # Written as a positive range check so that NaN distances are skipped too.
            if not (0 < distance < contact_threshold):
                continue
            edges.append([i, j])
            directions.append(direction / distance)
            rotations.append(float(np.arctan2(direction[1], direction[0])))

    return positions, edges, directions, rotations

def _export_features(pdb_files, source_folder, output_folder, label):
    """Compute the features of each PDB file and save them as one JSON file each.

    Args:
        pdb_files: File names (not paths) inside ``source_folder``.
        source_folder: Folder that contains the PDB files.
        output_folder: Folder that receives ``<pdb name>_features.json``.
        label: Class label stored in every JSON file (1 positive, 0 negative).
    """
    for pdb_file in pdb_files:
        pdb_path = os.path.join(source_folder, pdb_file)
        positions, edges, directions, rotations = process_pdb_file(pdb_path)
        features = {
            "node_features": positions.tolist(),
            "edge_features": {
                "edges": edges,
                "directions": [d.tolist() for d in directions],
                "rotations": [float(rot) for rot in rotations]
            },
            "label": label
        }
        output_file_path = os.path.join(output_folder, f'{os.path.splitext(pdb_file)[0]}_features.json')
        with open(output_file_path, 'w') as output_file:
            json.dump(features, output_file)


# Process the training PDB files and attach the labels
print(f'训练集中的 PDB 文件数量 (阳性): {len(train_pos_pdb_files)}')
print(f'训练集中的 PDB 文件数量 (阴性): {len(train_neg_pdb_files)}')

_export_features(train_pos_pdb_files, train_pos_folder_path, output_train_pos_folder_path, 1)
# The negative training features belong in the train_neg output folder; the
# previous label-based conditional sent them to the test_neg folder.
_export_features(train_neg_pdb_files, train_neg_folder_path, output_train_neg_folder_path, 0)

# Process the test PDB files and attach the labels
print(f'测试集中的 PDB 文件数量 (阳性): {len(test_pos_pdb_files)}')
print(f'测试集中的 PDB 文件数量 (阴性): {len(test_neg_pdb_files)}')

_export_features(test_pos_pdb_files, test_pos_folder_path, output_test_pos_folder_path, 1)
_export_features(test_neg_pdb_files, test_neg_folder_path, output_test_neg_folder_path, 0)

# 输出保存的 JSON 文件数量
# json_files = [f for f in os.listdir(output_folder_path) if f.endswith('.json')]
# print(f'保存的 JSON 文件数量: {len(json_files)}')

# 检查每个保存的 JSON 文件中的维度信息
# for json_file in json_files:
#     json_path = os.path.join(output_folder_path, json_file)
#     with open(json_path, 'r') as file:
#         data = json.load(file)
#         print(f'文件: {json_file}')
#         print(f'节点特征数量: {len(data["node_features"])}, 位置特征维度: {len(data["node_features"][0])}')
#         print(f'边特征数量: {len(data["edge_features"]["edges"])}, 方向特征数量: {len(data["edge_features"]["directions"])}, 旋转特征数量: {len(data["edge_features"]["rotations"])}')


# Count the JSON feature files saved for the negative training set
json_files = [f for f in os.listdir(output_train_neg_folder_path) if f.endswith('.json')]
print(f'output_train_neg_folder_path 保存的 JSON 文件数量: {len(json_files)}')

# Report the feature dimensions stored in each JSON file
for json_file in json_files:
    json_path = os.path.join(output_train_neg_folder_path, json_file)
    with open(json_path, 'r') as file:
        data = json.load(file)
    node_features = data["node_features"]
    # A structure without CA atoms yields an empty node list; do not index [0] then.
    position_dim = len(node_features[0]) if node_features else 0
    print(f'文件: {json_file}')
    print(f'节点特征数量: {len(node_features)}, 位置特征维度: {position_dim}')
    print(f'边特征数量: {len(data["edge_features"]["edges"])}, 方向特征数量: {len(data["edge_features"]["directions"])}, 旋转特征数量: {len(data["edge_features"]["rotations"])}')



end_time = time.time()
print(f"总处理时间: {end_time - start_time:.2f} 秒")


# ESM-C

In [None]:
# Explicit %pip magic: installs into the running kernel's environment.
%pip install esm

In [None]:
import os
from Bio import SeqIO
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig
import torch
import numpy as np
from tqdm import tqdm 

# FASTA input paths (only the test-positive file is active; the others are commented out)
#train_pos_seq_file = '/content/drive/MyDrive/AFP_work/seq/Colab_train_pos.fasta'
#train_neg_seq_file = '/content/drive/MyDrive/AFP_work/seq/Colab_train_neg.fasta'
test_pos_seq_file = '/content/drive/MyDrive/AFP_work/seq/Colab_test_pos.fasta'
#test_neg_seq_file = '/content/drive/MyDrive/AFP_work/seq/Colab_test_neg.fasta'

# NOTE(review): this folder sits directly under MyDrive, while the inspection cells
# further down read '/content/drive/MyDrive/AFP_work/esmc_600_test_pos' — confirm
# which location is intended.
output_feature_path = '/content/drive/MyDrive/esmc_600_test_pos'

os.makedirs(output_feature_path, exist_ok=True)

def read_fasta_sequences(fasta_file):
    """Read a FASTA file into a list of ``(record id, sequence)`` tuples.

    Spaces and newline characters are removed from every sequence string.

    Args:
        fasta_file: Path of the FASTA file.

    Returns:
        List of ``(seq_id, sequence)`` tuples in file order.
    """
    return [
        (record.id, str(record.seq).replace(" ", "").replace("\n", ""))
        for record in SeqIO.parse(fasta_file, "fasta")
    ]

def save_features(features, output_dir):
    """Save the logits and embeddings of every sequence as ``.npy`` files.

    For each sequence id two files are written into ``output_dir``:
    ``<seq_id>_logits.npy`` and ``<seq_id>_embeddings.npy``.

    Args:
        features: Mapping ``seq_id -> {'logits': ..., 'embeddings': ...}``.
        output_dir: Existing folder that receives the files.
    """
    for seq_id, feature_dict in features.items():
        # Look up both entries first so a missing key fails before anything is written.
        arrays = {kind: feature_dict[kind] for kind in ('logits', 'embeddings')}
        for kind, array in arrays.items():
            np.save(os.path.join(output_dir, f"{seq_id}_{kind}.npy"), array)

def _feature_to_numpy(value):
    """Convert one model output (tensor, array or track container) to a NumPy array."""
    # The logits come back wrapped in a track container. With
    # LogitsConfig(sequence=True) the tensor is held in its ``sequence`` field;
    # without unwrapping, the container itself is stored as a pickled object.
    if not isinstance(value, (torch.Tensor, np.ndarray)) and hasattr(value, 'sequence'):
        value = value.sequence

    if isinstance(value, torch.Tensor):
        value = value.detach()
        # NumPy has no bfloat16 dtype, so upcast before leaving torch.
        if value.dtype == torch.bfloat16:
            value = value.to(torch.float32)
        return value.cpu().numpy()
    if isinstance(value, np.ndarray):
        return value
    # Anything else: fall back to NumPy's own conversion.
    return np.array(value)


def extract_features_individual(client, sequences):
    """Run the model on each sequence and collect logits and embeddings.

    Sequences are processed one at a time. A sequence that raises an exception
    is reported with a printed message and skipped, so one failure does not
    abort the whole run.

    Args:
        client: Model object providing ``encode`` and ``logits``.
        sequences: Iterable of ``(seq_id, sequence)`` tuples.

    Returns:
        Mapping ``seq_id -> {'logits': ndarray, 'embeddings': ndarray}`` for
        the sequences that were processed successfully.
    """
    features = {}
    for seq_id, seq in tqdm(sequences, desc="提取特征"):
        try:
            # Wrap the sequence and encode it into the model's input tensors
            protein = ESMProtein(sequence=seq)
            protein_tensor = client.encode(protein)

            # Request sequence logits together with the embeddings
            logits_output = client.logits(
                protein_tensor,
                LogitsConfig(sequence=True, return_embeddings=True)
            )

            features[seq_id] = {
                'logits': _feature_to_numpy(logits_output.logits),
                'embeddings': _feature_to_numpy(logits_output.embeddings)
            }
        except Exception as e:
            print(f"处理序列 {seq_id} 时出错: {e}")
            continue
    return features

# Initialise the ESMC model on the GPU when available, otherwise on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
client = ESMC.from_pretrained("esmc_600m").to(device)
client.eval()  # switch to evaluation mode

# FASTA files to process in this run (only the test-positive file is active)
fasta_files = [
    #train_pos_seq_file,
    #train_neg_seq_file,
    test_pos_seq_file,
    #test_neg_seq_file
]

for fasta_file in fasta_files:
    print(f"正在处理文件: {fasta_file}")

    # Read the sequences
    sequences = read_fasta_sequences(fasta_file)
    print(f"序列数量: {len(sequences)}")

    # Extract the features, one sequence at a time
    features = extract_features_individual(client, sequences)

    # Save the features
    # NOTE(review): every FASTA file is saved into the same output_feature_path, so
    # enabling several files at once would mix their outputs and overwrite files
    # that share a sequence id — confirm before doing so.
    save_features(features, output_feature_path)

    print(f"特征已保存: {fasta_file}\n")

# 处理序列结构信息

## 处理序列信息

In [None]:
# 2. 导入必要的库
import os
import numpy as np

# 3. Feature output folders to inspect
feature_dirs = [
    '/content/drive/MyDrive/AFP_work/esmc_600_train_pos',
    '/content/drive/MyDrive/AFP_work/esmc_600_train_neg',
    '/content/drive/MyDrive/AFP_work/esmc_600_test_pos',  # please confirm that this path is correct
    '/content/drive/MyDrive/AFP_work/esmc_600_test_neg'
    # add any further folders here
]

# 4. Print the shapes of the combined feature files
def inspect_combined_files(feature_dirs):
    """Print the shapes of ``combined_logits.npy`` and ``combined_embeddings.npy``.

    Folders that do not exist are reported and skipped. After the embeddings of
    a folder were loaded successfully, the shape of their first entry is
    printed as well. That extra line is omitted when the file is missing,
    unreadable or empty, so a previous folder's array is never reused.

    Args:
        feature_dirs (List[str]): Feature folder paths.
    """
    for feature_dir in feature_dirs:
        print(f"正在处理文件夹: {feature_dir}")

        # Skip folders that do not exist
        if not os.path.isdir(feature_dir):
            print(f"文件夹 '{feature_dir}' 不存在，跳过。\n")
            continue

        logits_path = os.path.join(feature_dir, 'combined_logits.npy')
        embeddings_path = os.path.join(feature_dir, 'combined_embeddings.npy')

        # NOTE(review): allow_pickle=True unpickles arbitrary objects; only use it
        # on files produced by this project.
        if os.path.isfile(logits_path):
            try:
                combined_logits = np.load(logits_path, allow_pickle=True)
                print(f" - {os.path.basename(logits_path)} 的形状: {combined_logits.shape}")
            except Exception as e:
                print(f" - 加载 {os.path.basename(logits_path)} 时出错: {e}")
        else:
            print(f" - {os.path.basename(logits_path)} 不存在。")

        # Reset per folder so the sample check below only sees this folder's data.
        combined_embeddings = None
        if os.path.isfile(embeddings_path):
            try:
                combined_embeddings = np.load(embeddings_path, allow_pickle=True)
                print(f" - {os.path.basename(embeddings_path)} 的形状: {combined_embeddings.shape}\n")
            except Exception as e:
                combined_embeddings = None
                print(f" - 加载 {os.path.basename(embeddings_path)} 时出错: {e}\n")
        else:
            print(f" - {os.path.basename(embeddings_path)} 不存在。\n")

        # Shape of the first entry: a 1-D shape means one vector per sequence.
        if (combined_embeddings is not None
                and getattr(combined_embeddings, 'ndim', 0) > 0
                and len(combined_embeddings) > 0):
            sample = combined_embeddings[0]
            print("数据级别：", np.shape(sample))

# 5. Run the inspection for all feature folders
inspect_combined_files(feature_dirs)


In [None]:
import os
import glob

# Feature folders whose per-sequence logits/embeddings files are counted below.
feature_dirs = [
    '/content/drive/MyDrive/AFP_work/esmc_600_train_pos',
    '/content/drive/MyDrive/AFP_work/esmc_600_train_neg',
    '/content/drive/MyDrive/AFP_work/esmc_600_test_pos',
    '/content/drive/MyDrive/AFP_work/esmc_600_test_neg'
]

# Count the logits and embeddings files per feature folder and check that they match.
def count_logit_embedding_files(feature_dirs):
    """Compare the ``*_logits.npy`` and ``*_embeddings.npy`` files of each folder.

    Folders that do not exist are skipped. For every other folder the folder
    path is printed, followed by either a "counts match" message or both
    counts and the ids present on only one side. An id is the file name
    without its ``_logits.npy`` / ``_embeddings.npy`` suffix, so ids that
    contain underscores are kept intact.

    Args:
        feature_dirs: Feature folder paths.
    """
    logits_suffix = '_logits.npy'
    embeddings_suffix = '_embeddings.npy'

    def _ids(paths, suffix):
        # Strip only the known suffix; the rest of the file name is the id.
        return {os.path.basename(path)[:-len(suffix)] for path in paths}

    for feature_dir in feature_dirs:

        if not os.path.isdir(feature_dir):
            continue

        # Name the folder so the result lines below can be attributed to it.
        print(f"文件夹: {feature_dir}")

        logits_files = sorted(glob.glob(os.path.join(feature_dir, '*' + logits_suffix)))
        embeddings_files = sorted(glob.glob(os.path.join(feature_dir, '*' + embeddings_suffix)))

        num_logits = len(logits_files)
        num_embeddings = len(embeddings_files)

        if num_logits == num_embeddings:
            print(f"数量一致。\n")
        else:
            print(f"数量不一致。")
            print(f"- logits 文件数量: {num_logits}")
            print(f"- embeddings 文件数量: {num_embeddings}")

            # Ids that are present on one side only
            logits_indices = _ids(logits_files, logits_suffix)
            embeddings_indices = _ids(embeddings_files, embeddings_suffix)

            missing_in_embeddings = logits_indices - embeddings_indices
            missing_in_logits = embeddings_indices - logits_indices

            if missing_in_embeddings:
                print(f"在 embeddings 文件夹中缺失: {sorted(missing_in_embeddings)}")
            if missing_in_logits:
                print(f"在 logits 文件夹中缺失: {sorted(missing_in_logits)}")
            print()

# Run the consistency check for all feature folders.
count_logit_embedding_files(feature_dirs)


In [None]:
# Print shape and first rows of the combined feature files in each folder
def inspect_combined_files(feature_dirs):
    """Print shape and first five rows of the combined logits/embeddings files.

    For every folder, ``combined_logits.npy`` and ``combined_embeddings.npy``
    are loaded if present. Missing files and load errors are reported with a
    printed message and do not stop the loop.

    Args:
        feature_dirs (List[str]): Feature folder paths.
    """
    for folder in feature_dirs:
        print(f"正在检查文件夹: {folder}")

        logits_file = os.path.join(folder, 'combined_logits.npy')
        embeddings_file = os.path.join(folder, 'combined_embeddings.npy')
        logits_name = os.path.basename(logits_file)
        embeddings_name = os.path.basename(embeddings_file)

        # Combined logits
        if not os.path.isfile(logits_file):
            print(f" {logits_name} 不存在。")
        else:
            try:
                logits_array = np.load(logits_file, allow_pickle=True)
                print(f"  {logits_name} 的形状: {logits_array.shape}")
                print(f"  {logits_name} 的前5行数据：\n{logits_array[:5]}\n")
            except Exception as e:
                print(f"  加载 {logits_name} 时出错: {e}")

        # Combined embeddings
        if not os.path.isfile(embeddings_file):
            print(f"{embeddings_name} 不存在。\n")
        else:
            try:
                embeddings_array = np.load(embeddings_file, allow_pickle=True)
                print(f" {folder} {embeddings_name} 的形状: {embeddings_array.shape}")
                print(f" {folder} {embeddings_name} 的前5行数据：\n{embeddings_array[:5]}\n")
            except Exception as e:
                print(f"加载 {embeddings_name} 时出错: {e}\n")
            
# Run the inspection for all feature folders.
inspect_combined_files(feature_dirs)

正在检查文件夹: /content/drive/MyDrive/AFP_work/esmc_600_train_pos
  📊 combined_logits.npy 的形状: (1200, 1)
  📄 combined_logits.npy 的前5行数据：
[[ForwardTrackData(sequence=tensor([[[-21.2500, -21.1250, -21.2500,  ..., -21.2500, -21.2500, -21.2500],
           [-26.6250, -26.6250, -26.6250,  ..., -26.6250, -26.6250, -26.6250],
           [-28.3750, -28.2500, -28.3750,  ..., -28.3750, -28.3750, -28.2500],
           ...,
           [-24.7500, -24.6250, -24.7500,  ..., -24.7500, -24.7500, -24.7500],
           [-21.1250, -21.0000, -21.1250,  ..., -21.1250, -21.1250, -21.1250],
           [-20.0000, -19.8750, -20.0000,  ..., -19.8750, -20.0000, -20.0000]]],
         device='cuda:0', dtype=torch.bfloat16), structure=None, secondary_structure=None, sasa=None, function=None)]
 [ForwardTrackData(sequence=tensor([[[-22.7500, -22.6250, -22.7500,  ..., -22.7500, -22.7500, -22.7500],
           [-26.3750, -26.2500, -26.3750,  ..., -26.2500, -26.3750, -26.3750],
           [-24.2500, -24.2500, -24.3750,  ..., -

## 处理结构信息

In [None]:
import os
import glob
import pandas as pd

# Target folder paths
train_pos_folder_path = '/content/drive/MyDrive/AFP_work/pdb/train_pos'
train_neg_folder_path = '/content/drive/MyDrive/AFP_work/pdb/train_neg'
test_pos_folder_path = '/content/drive/MyDrive/AFP_work/pdb/test_pos'
test_neg_folder_path = '/content/drive/MyDrive/AFP_work/pdb/test_neg'

# Keep the folder paths in a dict so they can be iterated
folder_paths = {
    'train_pos': train_pos_folder_path,
    'train_neg': train_neg_folder_path,
    'test_pos': test_pos_folder_path,
    'test_neg': test_neg_folder_path
}

# Collected statistics, one dict per folder
stats = []

# Count the PDB files of every folder
for folder_name, folder_path in folder_paths.items():
    print(f"正在处理文件夹: {folder_path}")

    # A missing folder is recorded with a text marker instead of a count
    if not os.path.isdir(folder_path):
        print(f"❌ 文件夹 '{folder_path}' 不存在。请检查路径是否正确。\n")
        stats.append({
            'folder': folder_name,
            'path': folder_path,
            'pdb_file_count': '文件夹不存在'
        })
        continue

    # Find the files with a lower-case or upper-case .pdb extension
    pdb_files = glob.glob(os.path.join(folder_path, '*.pdb')) + glob.glob(os.path.join(folder_path, '*.PDB'))

    # Number of .pdb files
    pdb_count = len(pdb_files)

    print(f"✅ 文件夹 '{folder_name}' 中共有 {pdb_count} 个 .pdb 文件。\n")

    # Record the result
    stats.append({
        'folder': folder_name,
        'path': folder_path,
        'pdb_file_count': pdb_count
    })

# Build the DataFrame
df_stats = pd.DataFrame(stats)

# Show the statistics
print("📊 各文件夹中 .pdb 文件的数量统计：")
print(df_stats)

# Save the statistics as CSV
csv_output_path = '/content/drive/MyDrive/AFP_work/pdb/pdb_file_counts.csv'
df_stats.to_csv(csv_output_path, index=False, encoding='utf-8-sig')

print(f"\n✅ 统计结果已保存到 '{csv_output_path}'。")

# Read the saved CSV back and show its content
df_loaded_stats = pd.read_csv(csv_output_path)
print("\n📄 统计结果 CSV 文件内容：")
print(df_loaded_stats)


正在处理文件夹: /content/drive/MyDrive/AFP_work/pdb/train_pos
✅ 文件夹 'train_pos' 中共有 1140 个 .pdb 文件。

正在处理文件夹: /content/drive/MyDrive/AFP_work/pdb/train_neg
✅ 文件夹 'train_neg' 中共有 1200 个 .pdb 文件。

正在处理文件夹: /content/drive/MyDrive/AFP_work/pdb/test_pos
✅ 文件夹 'test_pos' 中共有 367 个 .pdb 文件。

正在处理文件夹: /content/drive/MyDrive/AFP_work/pdb/test_neg
✅ 文件夹 'test_neg' 中共有 308 个 .pdb 文件。

📊 各文件夹中 .pdb 文件的数量统计：
      folder                                           path  pdb_file_count
0  train_pos  /content/drive/MyDrive/AFP_work/pdb/train_pos            1140
1  train_neg  /content/drive/MyDrive/AFP_work/pdb/train_neg            1200
2   test_pos   /content/drive/MyDrive/AFP_work/pdb/test_pos             367
3   test_neg   /content/drive/MyDrive/AFP_work/pdb/test_neg             308

✅ 统计结果已保存到 '/content/drive/MyDrive/AFP_work/pdb/pdb_file_counts.csv'。

📄 统计结果 CSV 文件内容：
      folder                                           path  pdb_file_count
0  train_pos  /content/drive/MyDrive/AFP_work/pdb/train_pos    

In [None]:
import os
import glob
import pandas as pd

# Target folder path
# NOTE(review): the variable is named train_pos_folder_path but points to the
# test_pos folder — confirm that test_pos is the intended folder.
train_pos_folder_path = '/content/drive/MyDrive/AFP_work/pdb/test_pos'

# Path of the output CSV file
output_csv_path = '/content/drive/MyDrive/AFP_work/pdb/test_pos_pdb_filenames.csv'

# Stop early when the folder does not exist
if not os.path.isdir(train_pos_folder_path):
    print(f"❌ 文件夹 '{train_pos_folder_path}' 不存在。请检查路径是否正确。")
else:
    # Find the files with a lower-case or upper-case .pdb extension
    pdb_files_lower = glob.glob(os.path.join(train_pos_folder_path, '*.pdb'))
    pdb_files_upper = glob.glob(os.path.join(train_pos_folder_path, '*.PDB'))
    pdb_files = pdb_files_lower + pdb_files_upper

    # Keep the file names only
    pdb_filenames = [os.path.basename(f) for f in pdb_files]

    # Report when no .pdb file was found
    if not pdb_filenames:
        print(f"❌ 在文件夹 '{train_pos_folder_path}' 中未找到任何 .pdb 文件。")
    else:
        # One-column DataFrame with the file names
        df = pd.DataFrame({'filename': pdb_filenames})

        # Save as CSV
        try:
            df.to_csv(output_csv_path, index=False, encoding='utf-8-sig')
            print(f"✅ 所有 .pdb 文件名已成功保存到 '{output_csv_path}'。")
            print(f"总共保存了 {len(pdb_filenames)} 个文件名。")
        except Exception as e:
            print(f"❌ 保存 CSV 文件时出错: {e}")


✅ 所有 .pdb 文件名已成功保存到 '/content/drive/MyDrive/pdb/test_pos_pdb_filenames.csv'。
总共保存了 308 个文件名。


In [None]:
import os
import glob
import re
import pandas as pd

def save_extracted_numbers_to_csv(folder_path, output_csv_path):
    """
    Extract the numeric prefix that precedes '_relaxed_rank_001' in every .pdb
    file name of a folder and save those numbers to a CSV file.

    The extension check is case-insensitive ('.pdb', '.PDB', '.Pdb', ...), each
    file is considered exactly once, and file names are visited in sorted order
    so the CSV content is reproducible between runs.

    Parameters:
        folder_path (str): Folder that contains the .pdb files.
        output_csv_path (str): Path of the CSV file to write. It gets a single
            column named 'extracted_number' holding integers.

    Returns:
        None. Progress and problems are reported with print(); nothing is
        written when the folder is missing or no file name matches.
    """
    # Bail out early when the folder does not exist.
    if not os.path.isdir(folder_path):
        print(f"❌ 文件夹 '{folder_path}' 不存在。请检查路径是否正确。")
        return

    # One case-insensitive pass over the directory. Hidden entries (leading '.')
    # are skipped, which is also what the previous '*.pdb' glob did.
    pdb_filenames = sorted(
        name for name in os.listdir(folder_path)
        if not name.startswith('.')
        and name.lower().endswith('.pdb')
        and os.path.isfile(os.path.join(folder_path, name))
    )

    # Expected name layout: <digits>_relaxed_rank_001_<anything>.pdb, for example
    # 2033_relaxed_rank_001_alphafold2_ptm_model_4_seed_000.pdb
    pattern = re.compile(r'^(\d+)_relaxed_rank_001.*\.pdb$', re.IGNORECASE)

    extracted_numbers = []
    for filename in pdb_filenames:
        match = pattern.match(filename)
        if match:
            extracted_numbers.append(int(match.group(1)))
        else:
            print(f"⚠️ 文件名不符合预期模式，无法提取数字: {filename}")

    # Nothing matched: report and skip the export.
    if not extracted_numbers:
        print(f"❌ 在文件夹 '{folder_path}' 中未找到符合模式的 .pdb 文件。")
        return

    df = pd.DataFrame({'extracted_number': extracted_numbers})

    # Best-effort write: a failure is reported rather than raised, as before.
    try:
        df.to_csv(output_csv_path, index=False, encoding='utf-8-sig')
        print(f"✅ 提取的数字已成功保存到 '{output_csv_path}'。")
        print(f"总共保存了 {len(extracted_numbers)} 个数字。")
    except Exception as e:
        print(f"❌ 保存 CSV 文件时出错: {e}")

# Folder to scan. NOTE(review): the variable is named train_pos_* but the path
# points at the test_pos folder.
train_pos_folder_path = '/content/drive/MyDrive/AFP_work/pdb/test_pos'

# Destination CSV for the extracted numeric prefixes.
output_csv_path = '/content/drive/MyDrive/AFP_work/pdb/test_pos_pdb_filenames_extracted.csv'

# Run the extraction for that folder.
save_extracted_numbers_to_csv(train_pos_folder_path, output_csv_path)


✅ 提取的数字已成功保存到 '/content/drive/MyDrive/pdb/test_pos_pdb_filenames_extracted.csv'。
总共保存了 308 个数字。


In [None]:
import os
import time
import json
from Bio.PDB import PDBParser
import numpy as np

# Input folders holding the PDB files of the train/test, positive/negative splits.
train_pos_folder_path = '/content/drive/MyDrive/AFP_work/pdb/train_pos'
train_neg_folder_path = '/content/drive/MyDrive/AFP_work/pdb/train_neg'
test_pos_folder_path = '/content/drive/MyDrive/AFP_work/pdb/test_pos'
test_neg_folder_path = '/content/drive/MyDrive/AFP_work/pdb/test_neg'

# Output folders, one per split, for the per-structure feature JSON files.
output_train_pos_folder_path = '/content/drive/MyDrive/AFP_work/pdb_features/train_pos'
output_train_neg_folder_path = '/content/drive/MyDrive/AFP_work/pdb_features/train_neg'
output_test_pos_folder_path = '/content/drive/MyDrive/AFP_work/pdb_features/test_pos'
output_test_neg_folder_path = '/content/drive/MyDrive/AFP_work/pdb_features/test_neg'

# Create the output folders when they do not exist yet.
os.makedirs(output_train_pos_folder_path, exist_ok=True)
os.makedirs(output_train_neg_folder_path, exist_ok=True)
os.makedirs(output_test_pos_folder_path, exist_ok=True)
os.makedirs(output_test_neg_folder_path, exist_ok=True)

# List the PDB file names of every split ('.pdb' or '.PDB' suffix only).
train_pos_pdb_files = [f for f in os.listdir(train_pos_folder_path) if f.endswith('.pdb') or f.endswith('.PDB')]
train_neg_pdb_files = [f for f in os.listdir(train_neg_folder_path) if f.endswith('.pdb') or f.endswith('.PDB')]
test_pos_pdb_files = [f for f in os.listdir(test_pos_folder_path) if f.endswith('.pdb') or f.endswith('.PDB')]
test_neg_pdb_files = [f for f in os.listdir(test_neg_folder_path) if f.endswith('.pdb') or f.endswith('.PDB')]

# Shared Biopython parser; process_pdb_file() reads this module-level name.
parser = PDBParser(QUIET=True)

# Wall-clock start, used for the total-time report at the end of the cell.
start_time = time.time()

# 优化 PDB 文件处理函数
def process_pdb_file(pdb_path, contact_cutoff=10.0):
    """
    Build C-alpha contact-graph features for one PDB file.

    Every residue that has a 'CA' atom becomes a node whose feature is the CA
    coordinate. Each residue pair (i < j) whose CA-CA distance is below
    `contact_cutoff` becomes an edge, described by the unit vector from i to j
    and by atan2(dy, dx) of that vector.

    The three edge lists always have the same length. Previously a pair of
    coincident CA atoms (distance 0) added an edge without a matching direction
    and rotation, which made the downstream loader discard the whole sample;
    such a pair now gets a zero direction vector and a rotation of 0.0.

    Parameters:
        pdb_path (str): Path of the PDB file to parse.
        contact_cutoff (float): Distance threshold for a contact, in the
            coordinate units of the file. Defaults to 10.0, the value that was
            hard-coded before.

    Returns:
        tuple: (positions, edges, directions, rotations) where positions is an
        (N, 3) float64 array, edges is a list of [i, j] pairs, directions is a
        list of length-3 arrays and rotations is a list of floats. Returns
        (None, None, None, None) when no CA atom is found or when any step
        raises; the problem is reported with print().
    """
    try:
        # `parser` is the module-level PDBParser created earlier in this cell.
        structure = parser.get_structure('', pdb_path)
        residues = [residue for residue in structure.get_residues() if 'CA' in residue]
        num_residues = len(residues)

        if num_residues == 0:
            print(f"⚠️ 文件 '{pdb_path}' 中没有找到 CA 原子。")
            return None, None, None, None

        positions = np.array([residue['CA'].get_coord() for residue in residues], dtype=np.float64)
        edges = []
        directions = []
        rotations = []

        for i in range(num_residues):
            for j in range(i + 1, num_residues):
                # The displacement is computed once; its norm is the CA-CA distance.
                direction = positions[j] - positions[i]
                distance = np.linalg.norm(direction)
                if distance >= contact_cutoff:
                    continue
                edges.append([i, j])
                if distance != 0:
                    directions.append(direction / distance)
                    rotations.append(float(np.arctan2(direction[1], direction[0])))
                else:
                    # Coincident atoms: no direction is defined, so store neutral
                    # values to keep the edge lists aligned.
                    directions.append(np.zeros(3, dtype=np.float64))
                    rotations.append(0.0)

        return positions, edges, directions, rotations
    except Exception as e:
        print(f"❌ 处理文件 '{pdb_path}' 时出错: {e}")
        return None, None, None, None

# 处理 PDB 文件并保存特征
def process_and_save(pdb_files, folder_path, output_folder_path, label):
    """
    Extract features for each PDB file of a folder and write one JSON per file.

    Parameters:
        pdb_files (list): File names (not paths) inside `folder_path`.
        folder_path (str): Folder that holds the PDB files.
        output_folder_path (str): Folder receiving '<stem>_features.json' files.
        label: Class label stored under the "label" key of every JSON.

    Files for which process_pdb_file() returns None are skipped, and a failed
    write is reported and skipped; only successful writes are counted.
    """
    print(f'处理文件夹: {folder_path}')
    print(f'文件数量: {len(pdb_files)}')

    saved = 0
    for name in pdb_files:
        coords, contact_pairs, unit_vectors, angles = process_pdb_file(os.path.join(folder_path, name))

        # A None result means the file could not be processed.
        if coords is None:
            continue

        edge_payload = {
            "edges": contact_pairs,
            "directions": [vec.tolist() for vec in unit_vectors],
            "rotations": [float(angle) for angle in angles],
        }
        payload = {
            "node_features": coords.tolist(),
            "edge_features": edge_payload,
            "label": label,
        }

        stem = os.path.splitext(name)[0]
        target_path = os.path.join(output_folder_path, f'{stem}_features.json')

        try:
            with open(target_path, 'w') as handle:
                json.dump(payload, handle)
        except Exception as e:
            print(f"❌ 保存文件 '{target_path}' 时出错: {e}")
            continue

        saved += 1
        # Progress line after every 100 successful writes.
        if saved % 100 == 0:
            print(f'✅ 已处理 {saved} 个文件。')

    print(f'✅ 完成处理 {saved} 个文件。')





# Training set, positive PDB files (label 1).
process_and_save(train_pos_pdb_files, train_pos_folder_path, output_train_pos_folder_path, label=1)
# Training set, negative PDB files (label 0).
process_and_save(train_neg_pdb_files, train_neg_folder_path, output_train_neg_folder_path, label=0)
# Test set, positive PDB files (label 1).
process_and_save(test_pos_pdb_files, test_pos_folder_path, output_test_pos_folder_path, label=1)
# Test set, negative PDB files (label 0).
process_and_save(test_neg_pdb_files, test_neg_folder_path, output_test_neg_folder_path, label=0)




# Report the wall-clock time elapsed since start_time was taken.
end_time = time.time()
print(f"总处理时间: {end_time - start_time:.2f} 秒")



# 输出保存的 JSON 文件数量
def count_json_files(output_folder_path):
    """Print how many '.json' entries a folder holds and return their names."""
    json_files = []
    for entry in os.listdir(output_folder_path):
        if entry.endswith('.json'):
            json_files.append(entry)
    total = len(json_files)
    print(f'文件夹 "{output_folder_path}" 中保存的 JSON 文件数量: {total}')
    return json_files

# Report the number of feature JSON files written for each split.
print("\n保存的 JSON 文件数量:")
count_json_files(output_train_pos_folder_path)
count_json_files(output_train_neg_folder_path)
count_json_files(output_test_pos_folder_path)
count_json_files(output_test_neg_folder_path)

# 检查每个保存的 JSON 文件中的维度信息
def check_json_dimensions(output_folder_path):
    """
    Print the feature sizes stored in up to five JSON files of a folder.

    For each inspected file the node count, the length of the first node
    feature and the lengths of the edge, direction and rotation lists are
    printed. Missing keys are treated as empty.
    """
    candidates = [entry for entry in os.listdir(output_folder_path) if entry.endswith('.json')]
    # Only the first five listed files are inspected.
    for name in candidates[:5]:
        with open(os.path.join(output_folder_path, name), 'r') as handle:
            payload = json.load(handle)
            nodes = payload.get("node_features", [])
            edge_block = payload.get("edge_features", {})
            pair_list = edge_block.get("edges", [])
            dir_list = edge_block.get("directions", [])
            rot_list = edge_block.get("rotations", [])
            print(f'文件: {name}')
            print(f'节点特征数量: {len(nodes)}, 位置特征维度: {len(nodes[0]) if nodes else 0}')
            print(f'边特征数量: {len(pair_list)}, 方向特征数量: {len(dir_list)}, 旋转特征数量: {len(rot_list)}\n')

# Spot-check the feature sizes of a few JSON files in each split.
print("\n检查部分 JSON 文件的维度信息:")
check_json_dimensions(output_train_pos_folder_path)
check_json_dimensions(output_train_neg_folder_path)
check_json_dimensions(output_test_pos_folder_path)
check_json_dimensions(output_test_neg_folder_path)


处理文件夹: /content/drive/MyDrive/pdb/train_pos
文件数量: 1200
✅ 已处理 100 个文件。
✅ 已处理 200 个文件。
✅ 已处理 300 个文件。
✅ 已处理 400 个文件。
✅ 已处理 500 个文件。
✅ 已处理 600 个文件。
✅ 已处理 700 个文件。
✅ 已处理 800 个文件。
✅ 已处理 900 个文件。
✅ 已处理 1000 个文件。
✅ 已处理 1100 个文件。
✅ 已处理 1200 个文件。
✅ 完成处理 1200 个文件。
处理文件夹: /content/drive/MyDrive/pdb/train_neg
文件数量: 1200
✅ 已处理 100 个文件。
✅ 已处理 200 个文件。
✅ 已处理 300 个文件。
✅ 已处理 400 个文件。
✅ 已处理 500 个文件。
✅ 已处理 600 个文件。
✅ 已处理 700 个文件。
✅ 已处理 800 个文件。
✅ 已处理 900 个文件。
✅ 已处理 1000 个文件。
✅ 已处理 1100 个文件。
✅ 已处理 1200 个文件。
✅ 完成处理 1200 个文件。
处理文件夹: /content/drive/MyDrive/pdb/test_pos
文件数量: 308
✅ 已处理 100 个文件。
✅ 已处理 200 个文件。
✅ 已处理 300 个文件。
✅ 完成处理 308 个文件。
处理文件夹: /content/drive/MyDrive/pdb/test_neg
文件数量: 308
✅ 已处理 100 个文件。
✅ 已处理 200 个文件。
✅ 已处理 300 个文件。
✅ 完成处理 308 个文件。
总处理时间: 947.43 秒

保存的 JSON 文件数量:
文件夹 "/content/drive/MyDrive/pdb_features/train_pos" 中保存的 JSON 文件数量: 1200
文件夹 "/content/drive/MyDrive/pdb_features/train_neg" 中保存的 JSON 文件数量: 1200
文件夹 "/content/drive/MyDrive/pdb_features/test_pos" 中保存的 JSON 文件数量: 308
文件夹 "/content

In [None]:
import os
import glob

# Feature output folders, one per split.
output_train_pos_folder_path = '/content/drive/MyDrive/AFP_work/pdb_features/train_pos'
output_train_neg_folder_path = '/content/drive/MyDrive/AFP_work/pdb_features/train_neg'
output_test_pos_folder_path = '/content/drive/MyDrive/AFP_work/pdb_features/test_pos'
output_test_neg_folder_path = '/content/drive/MyDrive/AFP_work/pdb_features/test_neg'

# Map split name -> folder so the folders can be walked in one loop.
folders = {
    'train_pos': output_train_pos_folder_path,
    'train_neg': output_train_neg_folder_path,
    'test_pos': output_test_pos_folder_path,
    'test_neg': output_test_neg_folder_path
}

# Count the JSON files of every split and print the result.
for folder_name, folder_path in folders.items():
    if os.path.isdir(folder_path):
        # Globs '*.json' and '*.JSON' only; other case variants are not matched.
        json_files = glob.glob(os.path.join(folder_path, '*.json')) + glob.glob(os.path.join(folder_path, '*.JSON'))
        count = len(json_files)
        print(f"文件夹 '{folder_name}' 中的 JSON 文件数量: {count}")
    else:
        print(f"❌ 文件夹 '{folder_name}' 不存在。请检查路径是否正确。")


文件夹 'train_pos' 中的 JSON 文件数量: 1200
文件夹 'train_neg' 中的 JSON 文件数量: 1200
文件夹 'test_pos' 中的 JSON 文件数量: 308
文件夹 'test_neg' 中的 JSON 文件数量: 308


In [None]:
import os
import glob
import json
import pandas as pd
from tqdm import tqdm
import concurrent.futures
import logging
import torch
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader

# Logging goes to a file on Drive, opened in append mode, at INFO level.
logging.basicConfig(filename='/content/drive/MyDrive/AFP_work/pdb_features/processing.log',
                    filemode='a',
                    format='%(asctime)s - %(levelname)s - %(message)s',
                    level=logging.INFO)

# Folders that hold the per-structure feature JSON files, one per split.
output_train_pos_folder_path = '/content/drive/MyDrive/AFP_work/pdb_features/train_pos'
output_train_neg_folder_path = '/content/drive/MyDrive/AFP_work/pdb_features/train_neg'
output_test_pos_folder_path = '/content/drive/MyDrive/AFP_work/pdb_features/test_pos'
output_test_neg_folder_path = '/content/drive/MyDrive/AFP_work/pdb_features/test_neg'

# Split name -> folder lookup used by the loading code below.
output_folders = {
    'train_pos': output_train_pos_folder_path,
    'train_neg': output_train_neg_folder_path,
    'test_pos': output_test_pos_folder_path,
    'test_neg': output_test_neg_folder_path
}

# Create the folders when they do not exist yet.
for folder in output_folders.values():
    os.makedirs(folder, exist_ok=True)

# 定义函数加载单个 JSON 文件
def load_single_json(json_file):
    """
    Parse one JSON file and return its content, or None on any failure.

    Success and failure are both recorded through the logging module; no
    exception escapes this function.
    """
    parsed = None
    try:
        with open(json_file, 'r') as handle:
            parsed = json.load(handle)
        logging.info(f"成功加载文件: {json_file}")
    except Exception as e:
        logging.error(f"加载文件 '{json_file}' 时出错: {e}")
        parsed = None
    return parsed

# 定义函数并行加载 JSON 文件
def load_json_files_parallel(folder_path, max_workers=8):
    """
    Load every '*.json' file of a folder with a thread pool.

    The returned samples follow the sorted order of the file names, so the
    result is identical from run to run. Previously samples were appended in
    thread-completion order over an unsorted glob, which shuffled the dataset
    differently on every execution.

    Parameters:
        folder_path (str): Folder that contains the JSON files.
        max_workers (int): Maximum number of worker threads.

    Returns:
        list: The successfully loaded samples; files that fail to load
        (load_single_json() returns None) are left out.
    """
    # NOTE(review): the order is lexicographic by file name ('10_...' sorts
    # before '2_...'). Confirm this matches the row order of any per-sample
    # arrays that are paired with these samples by index.
    json_files = sorted(glob.glob(os.path.join(folder_path, '*.json')))
    data = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Executor.map yields results in submission order, not completion order.
        results = executor.map(load_single_json, json_files)
        for result in tqdm(results, total=len(json_files), desc=f'Loading {os.path.basename(folder_path)}'):
            if result is not None:
                data.append(result)
    return data

# Start from empty train/test lists (both are reassigned below).
train_data = []
test_data = []

# Training set: positive samples first, then negative samples.
train_pos_data = load_json_files_parallel(output_folders['train_pos'])
train_neg_data = load_json_files_parallel(output_folders['train_neg'])
train_data = train_pos_data + train_neg_data

# Test set: positive samples first, then negative samples.
test_pos_data = load_json_files_parallel(output_folders['test_pos'])
test_neg_data = load_json_files_parallel(output_folders['test_neg'])
test_data = test_pos_data + test_neg_data

print(f"✅ 训练集总样本数: {len(train_data)}")
print(f"✅ 测试集总样本数: {len(test_data)}")

# Folder and file paths for the aggregated (one file per set) datasets.
aggregated_output_folder = '/content/drive/MyDrive/AFP_work/pdb_features/aggregated'
os.makedirs(aggregated_output_folder, exist_ok=True)

train_output_path = os.path.join(aggregated_output_folder, 'train_dataset.json')
test_output_path = os.path.join(aggregated_output_folder, 'test_dataset.json')

# Write the training set; a failure is printed, not raised.
try:
    with open(train_output_path, 'w') as f:
        json.dump(train_data, f)
    print(f"✅ 训练集已保存到 '{train_output_path}'。")
except Exception as e:
    print(f"❌ 保存训练集时出错: {e}")

# Write the test set; a failure is printed, not raised.
try:
    with open(test_output_path, 'w') as f:
        json.dump(test_data, f)
    print(f"✅ 测试集已保存到 '{test_output_path}'。")
except Exception as e:
    print(f"❌ 保存测试集时出错: {e}")

# 验证汇总结果
def load_aggregated_data(file_path):
    """
    Read an aggregated dataset from a JSON file.

    Parameters:
        file_path (str): Path of the aggregated JSON file.

    Returns:
        The parsed JSON content (a list of samples when the file was produced
        by this notebook). An empty list is returned, after printing a message,
        when the file is missing or cannot be loaded.
    """
    if os.path.isfile(file_path):
        try:
            with open(file_path, 'r') as handle:
                loaded = json.load(handle)
            print(f"✅ 成功加载 '{file_path}'，样本数: {len(loaded)}")
        except Exception as e:
            print(f"❌ 加载文件 '{file_path}' 时出错: {e}")
            return []
        return loaded

    print(f"❌ 文件 '{file_path}' 不存在。")
    return []

# Reload the aggregated training set from disk and show its first sample.
train_dataset = load_aggregated_data(train_output_path)
if train_dataset:
    print(f"训练集第一个样本内容:")
    print(json.dumps(train_dataset[0], indent=2))

# Reload the aggregated test set from disk and show its first sample.
test_dataset = load_aggregated_data(test_output_path)
if test_dataset:
    print(f"测试集第一个样本内容:")
    print(json.dumps(test_dataset[0], indent=2))

# 转换为 Pandas DataFrame（可选）
def json_to_dataframe(data):
    """
    Summarize aggregated samples as a Pandas DataFrame.

    Each sample contributes one row holding its label plus the per-axis mean
    and (population) standard deviation of its node coordinates.

    Parameters:
        data (list): Samples, each a dict with 'label' and 'node_features'
            (a list of rows whose first three entries are used as x, y, z).

    Returns:
        pd.DataFrame: Columns 'label', 'node_mean_x/y/z', 'node_std_x/y/z'.
        A sample without nodes gets NaN statistics instead of raising, and an
        empty input yields an empty frame that still has these columns.
    """
    # Imported locally because this cell's import list does not include numpy;
    # the function previously only worked when an earlier cell had defined `np`.
    import numpy as np

    axis_names = ('x', 'y', 'z')
    columns = (
        ['label']
        + [f'node_mean_{axis}' for axis in axis_names]
        + [f'node_std_{axis}' for axis in axis_names]
    )

    records = []
    for sample in data:
        record = {'label': sample['label']}
        coords = np.asarray(sample['node_features'], dtype=float)
        has_nodes = coords.ndim == 2 and coords.shape[0] > 0
        for axis_idx, axis in enumerate(axis_names):
            if has_nodes:
                record[f'node_mean_{axis}'] = coords[:, axis_idx].mean()
                record[f'node_std_{axis}'] = coords[:, axis_idx].std()
            else:
                # No coordinates to summarize for this sample.
                record[f'node_mean_{axis}'] = float('nan')
                record[f'node_std_{axis}'] = float('nan')
        records.append(record)

    # Passing `columns` fixes the column order and keeps it for empty input.
    return pd.DataFrame(records, columns=columns)

# Build the summary DataFrames for the training and test sets.
train_df = json_to_dataframe(train_dataset)
test_df = json_to_dataframe(test_dataset)

print("训练集 DataFrame 预览:")
print(train_df.head())

print("\n测试集 DataFrame 预览:")
print(test_df.head())

# Optional: persist both DataFrames as CSV next to the aggregated JSON files.
train_csv_path = os.path.join(aggregated_output_folder, 'train_dataset_dataframe.csv')
test_csv_path = os.path.join(aggregated_output_folder, 'test_dataset_dataframe.csv')

train_df.to_csv(train_csv_path, index=False, encoding='utf-8-sig')
test_df.to_csv(test_csv_path, index=False, encoding='utf-8-sig')

print(f"✅ 训练集 DataFrame 已保存到 '{train_csv_path}'。")
print(f"✅ 测试集 DataFrame 已保存到 '{test_csv_path}'。")

# 定义 PyTorch Geometric 数据集类（可选）
class PDBDataset(Dataset):
    """
    PyTorch Geometric dataset over a list of aggregated sample dicts.

    Each sample dict provides 'node_features', 'label' and 'edge_features'
    (with 'edges', 'directions' and 'rotations'). Samples are converted to
    `Data` objects lazily, one per `get` call.
    """

    def __init__(self, data_list):
        super(PDBDataset, self).__init__()
        self.data_list = data_list

    def len(self):
        """Return the number of samples."""
        return len(self.data_list)

    def get(self, idx):
        """
        Convert sample `idx` into a `Data` object.

        The edge attributes are the direction vector concatenated with the
        rotation value. A sample without edges yields an empty (2, 0)
        edge_index and an empty (0, 4) edge_attr; previously that case produced
        a wrongly shaped edge_index and failed in torch.cat.
        """
        sample = self.data_list[idx]
        edge_info = sample['edge_features']
        node_features = torch.tensor(sample['node_features'], dtype=torch.float)

        if len(edge_info['edges']) > 0:
            edge_index = torch.tensor(edge_info['edges'], dtype=torch.long).t().contiguous()
            edge_attr = torch.tensor(edge_info['directions'], dtype=torch.float)
            rotations = torch.tensor(edge_info['rotations'], dtype=torch.float).unsqueeze(1)
            # Merge direction and rotation into one attribute row per edge.
            edge_features = torch.cat([edge_attr, rotations], dim=1)
        else:
            # Same empty-graph shapes that load_struct_features() uses later in
            # this notebook.
            edge_index = torch.empty((2, 0), dtype=torch.long)
            edge_features = torch.empty((0, 4), dtype=torch.float)

        label = torch.tensor(sample['label'], dtype=torch.long)

        data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_features, y=label)
        return data

# Optional: wrap the aggregated samples as PyTorch Geometric datasets.
train_pyg_dataset = PDBDataset(train_dataset)
test_pyg_dataset = PDBDataset(test_dataset)

# Optional: batch loaders; only the training loader shuffles.
train_loader = DataLoader(train_pyg_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_pyg_dataset, batch_size=32, shuffle=False)

print(f"✅ PyTorch Geometric 训练集数据量: {len(train_pyg_dataset)}")
print(f"✅ PyTorch Geometric 测试集数据量: {len(test_pyg_dataset)}")


# 结合

In [None]:
import os
import numpy as np
import pandas as pd

import json
from tqdm import tqdm
import random
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.data import Data, Dataset, DataLoader

from torch_geometric.nn import (
    GATConv, SAGEConv, GINConv, Set2Set, global_mean_pool
)
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    classification_report, accuracy_score, precision_recall_fscore_support,
    matthews_corrcoef, roc_auc_score, confusion_matrix
)
import copy
import optuna
import matplotlib.pyplot as plt
import shap
from xgboost import XGBClassifier
from sklearn.manifold import TSNE
import seaborn as sns
from torch_geometric.loader import DataLoader


# Folders read by load_esmc_features(), one per split; each is expected to
# contain 'combined_logits.npy' and 'combined_embeddings.npy'.
esmc_folders = {
    'train_pos': '/content/drive/MyDrive/AFP_work/esmc_600_train_pos',
    'train_neg': '/content/drive/MyDrive/AFP_work/esmc_600_train_neg',
    'test_pos': '/content/drive/MyDrive/AFP_work/esmc_600_test_pos',
    'test_neg': '/content/drive/MyDrive/AFP_work/esmc_600_test_neg'
}

# Aggregated structure-feature JSON files (despite the name, the values are
# file paths, not folders).
struct_folders = {
    'train': '/content/drive/MyDrive/AFP_work/pdb_features/aggregated/train_dataset.json',
    'test': '/content/drive/MyDrive/AFP_work/pdb_features/aggregated/test_dataset.json'
}

# Output folder for the combined ESM-C + structure artifacts.
aggregated_output_folder = '/content/drive/MyDrive/AFP_work/esmc_struct_aggregated'
os.makedirs(aggregated_output_folder, exist_ok=True)

#==============================数据加载和预处理==============================
#***************************加载 ESM-C 特征***************************
def load_esmc_features(esmc_folder):
    """
    Load the ESM-C logits and embeddings saved in one folder.

    The folder must contain 'combined_logits.npy' and 'combined_embeddings.npy'.
    Each logits entry (or its first element, when the entry is itself an
    ndarray) must expose a `.sequence` tensor, which is mean-pooled over its
    three dimensions into a single scalar per sample.

    The class label is 1 when the folder's own name contains 'pos', else 0.
    Only the final path component is inspected: the previous substring test on
    the whole path labelled every sample positive whenever a parent directory
    happened to contain 'pos'.

    Parameters:
        esmc_folder (str): Folder holding the two .npy files.

    Returns:
        tuple: (logits_values, embeddings, labels) where logits_values is a
        float32 array of shape (N, 1), embeddings is the loaded embedding array
        cast to float32, and labels is an array of N identical labels.
    """
    logits_path = os.path.join(esmc_folder, 'combined_logits.npy')
    embeddings_path = os.path.join(esmc_folder, 'combined_embeddings.npy')
    # SECURITY NOTE: allow_pickle=True executes pickled code while loading.
    # Only point this function at .npy files you produced yourself.
    logits = np.load(logits_path, allow_pickle=True)
    embeddings = np.load(embeddings_path, allow_pickle=True)

    # Debug preview of the first sample; skipped when the arrays are empty so
    # an empty folder no longer fails with IndexError here.
    if len(logits) > 0:
        print(f"Logits[0] 类型: {type(logits[0])}, 值: {logits[0]}")
    if len(embeddings) > 0:
        print("Embeddings sample:", embeddings[0])

    logits_values = []
    for entry in logits:
        # An ndarray entry wraps the tracked object as its first element.
        forward_data = entry[0] if isinstance(entry, np.ndarray) else entry
        sequence_tensor = forward_data.sequence
        # Move to CPU as float32 so .item() works whatever device/dtype was saved.
        sequence_tensor = sequence_tensor.to(device='cpu', dtype=torch.float32)
        # Mean over all three dimensions collapses the tensor to one scalar.
        pooled_value = sequence_tensor.mean(dim=[0, 1, 2]).item()
        logits_values.append(pooled_value)

    logits_values = np.array(logits_values, dtype=np.float32).reshape(-1, 1)
    embeddings = embeddings.astype(np.float32)

    # normpath drops a trailing separator so basename is never empty.
    folder_name = os.path.basename(os.path.normpath(esmc_folder))
    label = 1 if 'pos' in folder_name else 0
    labels = np.full((logits_values.shape[0],), label)
    return logits_values, embeddings, labels
# Load the ESM-C features of every split.
train_pos_logits, train_pos_embeddings, train_pos_labels = load_esmc_features(esmc_folders['train_pos'])
train_neg_logits, train_neg_embeddings, train_neg_labels = load_esmc_features(esmc_folders['train_neg'])
test_pos_logits, test_pos_embeddings, test_pos_labels = load_esmc_features(esmc_folders['test_pos'])
test_neg_logits, test_neg_embeddings, test_neg_labels = load_esmc_features(esmc_folders['test_neg'])
# Stack positives before negatives for both the training and the test set.
train_logits = np.vstack((train_pos_logits, train_neg_logits))
train_embeddings = np.vstack((train_pos_embeddings, train_neg_embeddings))
train_labels = np.hstack((train_pos_labels, train_neg_labels))

test_logits = np.vstack((test_pos_logits, test_neg_logits))
test_embeddings = np.vstack((test_pos_embeddings, test_neg_embeddings))
test_labels = np.hstack((test_pos_labels, test_neg_labels))


#***************************2、加载结构特征***************************
def load_struct_features(json_path, sample_limit=5):
    """
    Turn an aggregated structure JSON file into a list of `Data` graphs.

    Parameters:
        json_path (str): Path of the aggregated JSON file (a list of samples).
        sample_limit (int): Number of leading samples (by position in the file)
            whose sizes are printed for inspection.

    Returns:
        list: One `Data` object per accepted sample. A sample is skipped, with
        an [ERROR] message, when a required key is missing or when its edge,
        direction and rotation lists differ in length.
    """
    with open(json_path, 'r') as handle:
        samples = json.load(handle)

    mandatory = ['node_features', 'edge_features', 'label']
    graphs = []
    for position, sample in enumerate(tqdm(samples, desc=f'加载结构特征 from {json_path}')):
        if any(key not in sample for key in mandatory):
            print(f"[ERROR] 样本缺少必要的键: {sample}")
            continue

        edge_block = sample['edge_features']
        pair_list = edge_block.get('edges', [])
        dir_list = edge_block.get('directions', [])
        rot_list = edge_block.get('rotations', [])
        num_edges = len(pair_list)

        if len(dir_list) != num_edges or len(rot_list) != num_edges:
            print(f"[ERROR] 边的数量与方向或旋转数量不匹配: {sample}")
            continue

        if num_edges > 0:
            edge_index = torch.tensor(pair_list, dtype=torch.long).t().contiguous()
            dir_tensor = torch.tensor(dir_list, dtype=torch.float)
            rot_tensor = torch.tensor(rot_list, dtype=torch.float).unsqueeze(1)
            # Direction and rotation are concatenated into one row per edge.
            edge_attr = torch.cat([dir_tensor, rot_tensor], dim=1)
        else:
            # Edge-free graph: keep well-formed empty tensors.
            edge_index = torch.empty((2, 0), dtype=torch.long)
            edge_attr = torch.empty((0, 4), dtype=torch.float)

        node_tensor = torch.tensor(sample['node_features'], dtype=torch.float)
        label_tensor = torch.tensor(sample['label'], dtype=torch.long)
        graphs.append(Data(x=node_tensor, edge_index=edge_index, edge_attr=edge_attr, y=label_tensor))

        # Size report for the first few samples only.
        if position < sample_limit:
            print(f"样本 {position+1}: 节点数量: {node_tensor.shape[0]}, 节点特征维度: {node_tensor.shape[1]}, 边数量: {num_edges}")
            if num_edges > 0:
                print(f"  边特征维度: {edge_attr.shape[1]}")
            print("-" * 50)

    node_dims = {graph.x.shape[1] for graph in graphs}
    edge_dims = {graph.edge_attr.shape[1] for graph in graphs if graph.edge_attr.shape[0] > 0}
    print(f"\n所有样本中唯一的节点特征维度: {node_dims}")
    print(f"所有样本中唯一的边特征维度: {edge_dims}")
    return graphs

# Build the graph lists for the training and test sets.
train_struct_data = load_struct_features(struct_folders['train'])
test_struct_data = load_struct_features(struct_folders['test'])

# Show the node and edge tensor shapes of up to three training graphs.
for i in range(min(3, len(train_struct_data))):
    data = train_struct_data[i]
    print(f"样本 {i+1} - 节点特征: {data.x.shape}, 边特征: {data.edge_attr.shape}")

##*************************** 特征标准化 ***************************
def normalize_features(train_data_list, test_data_list=None):
    """
    Standardize node and edge features using statistics of the training graphs.

    One StandardScaler is fitted on all training node features and another on
    all training edge features; both are then applied to the training graphs
    and, when given, to the test graphs. Graphs are modified in place.

    Training sets in which no graph has any edge are supported: the edge scaler
    is then left unfitted and edge attributes stay untouched, instead of
    failing inside np.concatenate on an empty list.

    Parameters:
        train_data_list (list): Graphs exposing `.x` and `.edge_attr` tensors.
        test_data_list (list, optional): Graphs scaled with the training
            statistics.

    Returns:
        tuple: (train_data_list, test_data_list), the same objects passed in.

    Raises:
        ValueError: If `train_data_list` is empty.
    """
    if not train_data_list:
        raise ValueError("normalize_features needs at least one training sample to fit the scalers")

    node_scaler = StandardScaler()
    edge_scaler = StandardScaler()

    node_blocks = [data.x.numpy() for data in train_data_list]
    edge_blocks = [data.edge_attr.numpy() for data in train_data_list if data.edge_attr.shape[0] > 0]

    node_scaler.fit(np.concatenate(node_blocks, axis=0))
    # Fit the edge scaler only when at least one training edge exists.
    edge_scaler_fitted = len(edge_blocks) > 0
    if edge_scaler_fitted:
        all_edge_features = np.concatenate(edge_blocks, axis=0)
        edge_scaler_fitted = all_edge_features.size > 0
        if edge_scaler_fitted:
            edge_scaler.fit(all_edge_features)

    def _scale_in_place(data_list):
        # Apply the fitted scalers to every graph of one list.
        for data in data_list:
            data.x = torch.tensor(node_scaler.transform(data.x.numpy()), dtype=torch.float)
            if edge_scaler_fitted and data.edge_attr.shape[0] > 0:
                data.edge_attr = torch.tensor(edge_scaler.transform(data.edge_attr.numpy()), dtype=torch.float)

    _scale_in_place(train_data_list)
    if test_data_list:
        _scale_in_place(test_data_list)
    return train_data_list, test_data_list

# Standardize both graph lists with statistics fitted on the training graphs.
train_struct_data, test_struct_data = normalize_features(train_struct_data, test_struct_data)

# #*************************** 整合 ESM-C 特征 ***************************
def integrate_features(data_list, embeddings, logits):
    """
    Append each sample's ESM-C embedding and pooled logit to all of its nodes.

    Graph i is paired with embeddings[i] and logits[i]. The embedding and the
    logit are joined into one vector, copied to every node of the graph and
    concatenated to the existing node features. Graphs are modified in place.

    Raises:
        ValueError: If the three inputs do not have the same length.
    """
    if not (len(data_list) == len(embeddings) and len(data_list) == len(logits)):
        raise ValueError(f"data_list, embeddings 和 logits 长度不匹配: {len(data_list)} vs {len(embeddings)} vs {len(logits)}")

    progress = tqdm(data_list, desc='整合 ESM-C embeddings 和 logits')
    for index, graph in enumerate(progress):
        emb_vec = torch.tensor(embeddings[index], dtype=torch.float)
        # Collapse the logit to a scalar, then restore a length-1 axis for cat.
        logit_scalar = torch.tensor(logits[index], dtype=torch.float).squeeze()
        sample_vec = torch.cat([emb_vec, logit_scalar.unsqueeze(0)], dim=0)
        # One identical copy of the sample-level vector per node.
        per_node = sample_vec.unsqueeze(0).repeat(graph.x.shape[0], 1)
        graph.x = torch.cat([graph.x, per_node], dim=1)
    return data_list

# Attach the ESM-C embedding and pooled logit to every node of each graph.
train_struct_data = integrate_features(train_struct_data, train_embeddings, train_logits)
test_struct_data = integrate_features(test_struct_data, test_embeddings, test_logits)

# Report the node feature width after integration.
print(f"训练集第一个样本的节点特征维度（整合后）: {train_struct_data[0].x.shape[1]}")  # 1156
print(f"测试集第一个样本的节点特征维度（整合后）: {test_struct_data[0].x.shape[1]}") # 1156


class ProteinDataset(Dataset):
    """Dataset wrapper around an in-memory list of ready-made graph objects."""

    def __init__(self, data_list):
        super().__init__()
        self.data_list = data_list

    def len(self):
        """Number of stored graphs."""
        n_graphs = len(self.data_list)
        return n_graphs

    def get(self, idx):
        """Return the stored graph at position `idx` unchanged."""
        graph = self.data_list[idx]
        return graph

# Wrap the integrated graph lists as datasets.
train_dataset = ProteinDataset(train_struct_data)
test_dataset = ProteinDataset(test_struct_data)

# Batch loaders with 4 worker processes and pinned memory; only the training
# loader shuffles.
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)


In [None]:
class DeepGATModel(nn.Module):
    """
    Stack of GATConv layers with edge features, mean-pool readout and a
    two-layer classification head.

    Edge attributes first go through a small MLP that maps them to
    `hidden_dim`; every GAT layer outputs `hidden_dim * num_heads` channels and
    is followed by batch norm, ELU and dropout.
    """

    def __init__(self, node_feature_dim, edge_feature_dim, hidden_dim, out_dim, num_heads=4, dropout=0.3, num_layers=3):
        super().__init__()
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        self.dropout = nn.Dropout(p=dropout)

        # MLP that lifts raw edge attributes to hidden_dim before the GAT layers.
        self.edge_preprocess = nn.Sequential(
            nn.Linear(edge_feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        concat_width = hidden_dim * num_heads
        for depth in range(num_layers):
            # Only the first layer sees the raw node features.
            layer_in = concat_width if depth > 0 else node_feature_dim
            gat_layer = GATConv(
                in_channels=layer_in,
                out_channels=hidden_dim,
                heads=num_heads,
                dropout=dropout,
                edge_dim=hidden_dim,  # width of the preprocessed edge features
                add_self_loops=True
            )
            self.convs.append(gat_layer)
            self.batch_norms.append(nn.BatchNorm1d(concat_width))

        # Graph-level readout is a plain mean over the nodes of each graph.
        self.readout = global_mean_pool
        self.fc1 = nn.Linear(concat_width, 256)
        self.fc2 = nn.Linear(256, out_dim)

    def _fc1_features(self, data):
        """Run the GAT stack, the readout and fc1; shared by both public paths."""
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        if edge_attr is not None:
            edge_attr = self.edge_preprocess(edge_attr)

        for gat_layer, norm in zip(self.convs, self.batch_norms):
            x = self.dropout(F.elu(norm(gat_layer(x, edge_index, edge_attr))))

        pooled = self.readout(x, batch)
        return self.fc1(pooled)

    def forward(self, data):
        """Return the class logits for a batch of graphs."""
        hidden = F.relu(self._fc1_features(data))
        return self.fc2(self.dropout(hidden))

    def get_last_layer_features(self, data):
        """Return the fc1 output (before ReLU) for a batch of graphs."""
        return self._fc1_features(data)

class GraphSAGEModel(nn.Module):
    """
    Stack of SAGEConv layers with mean-pool readout and a two-layer head.

    The stack always has a first and a last layer plus `num_layers - 2` middle
    layers, so it never has fewer than two layers. Each layer is followed by
    batch norm, ReLU and dropout.
    """

    def __init__(self, node_feature_dim, hidden_dim, out_dim, num_layers=3, dropout=0.5):
        super().__init__()
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        self.dropout = nn.Dropout(p=dropout)

        # Input width per layer: raw features first, hidden_dim for the middle
        # layers and for the closing layer.
        middle_count = max(num_layers - 2, 0)
        layer_inputs = [node_feature_dim] + [hidden_dim] * (middle_count + 1)
        for layer_in in layer_inputs:
            self.convs.append(SAGEConv(layer_in, hidden_dim))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        self.readout = global_mean_pool
        self.fc1 = nn.Linear(hidden_dim, 256)
        self.fc2 = nn.Linear(256, out_dim)

    def _fc1_features(self, data):
        """Run the SAGE stack, the readout and fc1; shared by both public paths."""
        x, edge_index, batch = data.x, data.edge_index, data.batch
        for sage_layer, norm in zip(self.convs, self.batch_norms):
            x = self.dropout(F.relu(norm(sage_layer(x, edge_index))))
        pooled = self.readout(x, batch)
        return self.fc1(pooled)

    def forward(self, data):
        """Return the class logits for a batch of graphs."""
        hidden = F.relu(self._fc1_features(data))
        return self.fc2(self.dropout(hidden))

    def get_last_layer_features(self, data):
        """Return the fc1 output (before ReLU) for a batch of graphs."""
        return self._fc1_features(data)

class GINModel(nn.Module):
    """GIN graph classifier.

    Architecture: ``num_layers`` blocks of ``GINConv -> BatchNorm1d -> ReLU -> Dropout``
    (each GINConv wraps a Linear-ReLU-Linear MLP), ``global_mean_pool`` over the
    nodes of each graph, then ``fc1`` (256 units) and ``fc2``.

    Args:
        node_feature_dim: Width of the input node features.
        hidden_dim: Output width of every convolution.
        out_dim: Number of outputs of the final linear layer.
        num_layers: Number of GINConv blocks.
        dropout: Dropout probability used after every block and before ``fc2``.
    """

    def __init__(self, node_feature_dim, hidden_dim, out_dim, num_layers=3, dropout=0.5):
        super().__init__()
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        self.dropout = nn.Dropout(p=dropout)

        for layer_idx in range(num_layers):
            # Only the first block consumes the raw node features.
            in_dim = hidden_dim if layer_idx > 0 else node_feature_dim
            mlp = nn.Sequential(
                nn.Linear(in_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
            self.convs.append(GINConv(mlp))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        self.readout = global_mean_pool
        self.fc1 = nn.Linear(hidden_dim, 256)
        self.fc2 = nn.Linear(256, out_dim)

    def _graph_embedding(self, data):
        """Shared encoder: conv stack, graph-level pooling and ``fc1`` (no activation)."""
        h, edge_index, batch = data.x, data.edge_index, data.batch
        for conv, bn in zip(self.convs, self.batch_norms):
            h = self.dropout(F.relu(bn(conv(h, edge_index))))
        return self.fc1(self.readout(h, batch))

    def forward(self, data):
        """Return the ``fc2`` outputs (one row per graph in the batch)."""
        hidden = F.relu(self._graph_embedding(data))
        return self.fc2(self.dropout(hidden))

    def get_last_layer_features(self, data):
        """Return the 256-d ``fc1`` output (before ReLU) for each graph."""
        return self._graph_embedding(data)

# 交叉注意力融合模块
class CrossAttentionFusion(nn.Module):
    """Fuse per-model graph embeddings with multi-head attention and classify.

    The embeddings produced by the base models are stacked along a new leading
    axis and attend to each other (query = key = value). The attended
    embeddings are averaged over the model axis, layer-normalised and passed
    to a linear classifier.

    Args:
        feature_dim: Width of every input embedding; must be divisible by
            ``num_heads`` (requirement of ``nn.MultiheadAttention``).
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        num_classes: Number of output classes. Defaults to 2, the value that
            used to be hard-coded.
    """

    def __init__(self, feature_dim, num_heads=4, dropout=0.1, num_classes=2):
        super(CrossAttentionFusion, self).__init__()
        # batch_first is left at its default (False), so the attention layer
        # expects inputs shaped [sequence, batch, embed].
        self.attention = nn.MultiheadAttention(embed_dim=feature_dim, num_heads=num_heads, dropout=dropout)
        self.norm = nn.LayerNorm(feature_dim)
        self.fc = nn.Linear(feature_dim, num_classes)  # final classification layer

    def forward(self, features_list):
        """Fuse the embeddings and return class logits.

        Args:
            features_list: Sequence of tensors, one per base model, each of
                shape ``[batch_size, feature_dim]``.

        Returns:
            Tensor of shape ``[batch_size, num_classes]``.
        """
        # Stack to [num_models, batch_size, feature_dim]: the model axis plays
        # the role of the sequence axis for the attention layer.
        feats = torch.stack(features_list, dim=0)
        attn_output, _ = self.attention(feats, feats, feats)
        # Fuse by averaging over the model axis -> [batch_size, feature_dim].
        fused_feats = attn_output.mean(dim=0)
        fused_feats = self.norm(fused_feats)
        return self.fc(fused_feats)

def train_model(model, train_loader, test_loader, criterion, optimizer, scheduler, device, num_epochs=50, patience=10, model_save_path='best_model.pth'):
    """Train ``model`` with early stopping and keep the best-scoring weights.

    After every epoch the model is scored with ``test`` on both loaders. When
    the accuracy on ``test_loader`` improves, the weights are snapshotted in
    memory and written to ``model_save_path``. Training stops once ``patience``
    consecutive epochs bring no improvement.

    Note that checkpoint selection uses ``test_loader``, so the returned
    accuracy is the best value seen on that same data.

    Returns:
        ``(best_test_acc, best_model_wts)``; the model is left loaded with
        ``best_model_wts``.
    """
    best_test_acc = 0
    best_model_wts = copy.deepcopy(model.state_dict())
    stale_epochs = 0

    for epoch in range(1, num_epochs + 1):
        model.train()
        running_loss = 0
        for batch in tqdm(train_loader, desc=f'训练 Epoch {epoch}/{num_epochs}'):
            batch = batch.to(device)
            optimizer.zero_grad()
            loss = criterion(model(batch), batch.y)
            loss.backward()
            optimizer.step()
            # Weight by graph count so the epoch loss is a per-sample mean.
            running_loss += loss.item() * batch.num_graphs
        avg_loss = running_loss / len(train_loader.dataset)
        scheduler.step()

        train_acc = test(model, train_loader, device)[0]
        test_acc = test(model, test_loader, device)[0]
        print(f"Epoch: {epoch:02d}, Loss: {avg_loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}")

        if not test_acc > best_test_acc:
            stale_epochs += 1
            if stale_epochs >= patience:
                print(f"早停：在第 {epoch} 轮训练后，无提升，停止训练。")
                break
            continue

        # New best: snapshot in memory and on disk.
        best_test_acc = test_acc
        best_model_wts = copy.deepcopy(model.state_dict())
        stale_epochs = 0
        torch.save(model.state_dict(), model_save_path)

    model.load_state_dict(best_model_wts)
    return best_test_acc, best_model_wts

def test(model, loader, device):
    """Evaluate ``model`` on ``loader`` and return ``(accuracy, trues, preds)``.

    Accuracy is the number of correct argmax predictions divided by
    ``len(loader.dataset)``; ``trues`` and ``preds`` are flat lists of labels.
    """
    model.eval()
    trues, preds = [], []
    n_correct = 0
    with torch.no_grad():
        for batch in tqdm(loader, desc='评估'):
            batch = batch.to(device)
            predicted = model(batch).argmax(dim=1)
            preds.extend(predicted.cpu().numpy())
            trues.extend(batch.y.cpu().numpy())
            n_correct += (predicted == batch.y).sum().item()
    return n_correct / len(loader.dataset), trues, preds

from sklearn.metrics import matthews_corrcoef, roc_auc_score, confusion_matrix

def detailed_test(model, loader, device, models=None):
    """Compute classification metrics for ``model`` on ``loader``.

    Args:
        model: Either a graph model called as ``model(data)`` or a
            ``CrossAttentionFusion`` module.
        loader: Loader yielding batches with a ``y`` label attribute.
        device: Device the batches are moved to.
        models: Required when ``model`` is a ``CrossAttentionFusion``: dict with
            the keys ``'DeepGATModel'``, ``'GraphSAGEModel'`` and ``'GINModel'``
            whose ``get_last_layer_features`` outputs feed the fusion module.
            These base models are switched to eval mode so that dropout and
            batch-norm do not perturb the extracted features.

    Returns:
        Dict with ``acc``, ``mcc``, ``auc``, ``sn`` (sensitivity), ``sp``
        (specificity), ``precision``, ``recall`` and ``f1``. ``auc`` is NaN when
        the labels contain a single class, because ROC AUC is undefined then.

    Raises:
        ValueError: If a fusion model is evaluated without ``models``.
    """
    is_fusion = isinstance(model, CrossAttentionFusion)
    base_names = ('DeepGATModel', 'GraphSAGEModel', 'GINModel')
    if is_fusion:
        if models is None:
            raise ValueError("models dictionary required for CrossAttentionFusion evaluation")
        for base_name in base_names:
            models[base_name].eval()

    model.eval()
    preds, trues, probs = [], [], []
    with torch.no_grad():
        for data in tqdm(loader, desc='详细评估'):
            data = data.to(device)
            if is_fusion:
                features_list = [models[base_name].get_last_layer_features(data) for base_name in base_names]
                out = model(features_list)
            else:
                out = model(data)  # plain models consume the batch directly
            probs.extend(F.softmax(out, dim=1)[:, 1].cpu().numpy())
            preds.extend(out.argmax(dim=1).cpu().numpy())
            trues.extend(data.y.cpu().numpy())

    acc = accuracy_score(trues, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(trues, preds, average='binary')
    mcc = matthews_corrcoef(trues, preds)
    # roc_auc_score cannot be computed from a single class.
    auc = roc_auc_score(trues, probs) if len(set(trues)) > 1 else float('nan')
    # labels=[0, 1] forces a 2x2 matrix even when only one class shows up,
    # so the four-way unpacking below cannot fail.
    tn, fp, fn, tp = confusion_matrix(trues, preds, labels=[0, 1]).ravel()
    sn = tp / (tp + fn) if (tp + fn) > 0 else 0
    sp = tn / (tn + fp) if (tn + fp) > 0 else 0
    metrics = {'acc': acc, 'mcc': mcc, 'auc': auc, 'sn': sn, 'sp': sp, 'precision': precision, 'recall': recall, 'f1': f1}
    return metrics

def optimize_model(model_class, train_loader, test_loader, device, model_params, n_trials=5):
    """Search hyper-parameters for one base model with Optuna.

    Every trial builds a fresh model, trains it with ``train_model`` (50 epochs,
    patience 10) and reports the best accuracy reached on ``test_loader``.

    Args:
        model_class: ``DeepGATModel``, ``GraphSAGEModel`` or ``GINModel``.
        train_loader: Loader used for training.
        test_loader: Loader used to score each trial.
        device: Device the models are moved to.
        model_params: Dict with ``node_feature_dim``, ``out_dim`` and, for
            ``DeepGATModel``, ``edge_feature_dim``.
        n_trials: Number of Optuna trials to run.

    Returns:
        Dict of the best hyper-parameters (``hidden_dim``, ``num_layers``,
        ``dropout``, ``lr`` and, for ``DeepGATModel``, ``num_heads``).

    Raises:
        ValueError: If ``model_class`` is not one of the supported classes.
    """
    if model_class not in (DeepGATModel, GraphSAGEModel, GINModel):
        # Without this check an unknown class would only fail later, inside
        # the first trial, with an unbound-variable error.
        raise ValueError(f"Unsupported model class: {model_class!r}")

    def objective(trial):
        # Hyper-parameter search space
        hidden_dim = trial.suggest_int('hidden_dim', 64, 512)
        num_layers = trial.suggest_int('num_layers', 2, 6)
        dropout = trial.suggest_float('dropout', 0.1, 0.5)
        # suggest_float(log=True) is the supported replacement for the
        # deprecated suggest_loguniform; it samples the same distribution.
        lr = trial.suggest_float('lr', 1e-4, 1e-2, log=True)

        # Build the model for the requested class
        if model_class is DeepGATModel:
            num_heads = trial.suggest_int('num_heads', 2, 16)
            model = DeepGATModel(
                node_feature_dim=model_params['node_feature_dim'],
                edge_feature_dim=model_params['edge_feature_dim'],
                hidden_dim=hidden_dim,
                out_dim=model_params['out_dim'],
                num_heads=num_heads,
                dropout=dropout,
                num_layers=num_layers
            ).to(device)
        else:
            # GraphSAGEModel and GINModel share the same constructor signature.
            model = model_class(
                node_feature_dim=model_params['node_feature_dim'],
                hidden_dim=hidden_dim,
                out_dim=model_params['out_dim'],
                num_layers=num_layers,
                dropout=dropout
            ).to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
        criterion = nn.CrossEntropyLoss()
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

        best_acc, _ = train_model(
            model, train_loader, test_loader, criterion, optimizer, scheduler,
            device, num_epochs=50, patience=10, model_save_path=f"best_{model_class.__name__}.pth"
        )
        return best_acc

    study = optuna.create_study(direction='maximize')
    # Honour the caller's n_trials; it used to be silently replaced by 5.
    study.optimize(objective, n_trials=n_trials)
    print(f"总试验次数: {len(study.trials)}")
    print(f"{model_class.__name__} 最佳超参数: {study.best_params}")
    for trial in study.trials:
        print(f"Trial {trial.number}: State={trial.state}, Value={trial.value}")
    return study.best_params

def explain_features(model, test_loader, device, output_folder):
    """Run three interpretability steps and save the resulting figures.

    1. SHAP bar plot for an ``XGBClassifier`` proxy fitted on the training
       ESM-C features (saved as ``shap_esmc_features.png``).
    2. GNNExplainer on node 0 of up to five test graphs (printed only).
    3. t-SNE scatter of the ``get_last_layer_features`` output on
       ``test_loader`` (saved as ``tsne_last_layer.png``).

    Args:
        model: Not used by the body; the explained network is taken from the
            notebook-level ``models['DeepGATModel']`` instead.
        test_loader: Loader used for the t-SNE features.
        device: Device the graphs are moved to.
        output_folder: Directory the two PNG files are written to.

    Returns:
        ``(results, models)``. Neither name is assigned inside this function,
        so both are looked up in the enclosing notebook namespace.

    NOTE(review): besides ``models`` and ``results`` the body reads the
    notebook-level names ``train_embeddings``, ``train_logits``,
    ``train_labels`` and ``test_struct_data``; it fails with NameError when the
    cells defining them have not been run.
    """
    # SHAP analysis of the ESM-C features
    print("正在进行 ESM-C 特征的 SHAP 分析...")
    # Embeddings and logits are concatenated column-wise into one feature matrix.
    esmc_features = np.hstack([train_embeddings, train_logits])
    labels = train_labels
    # The SHAP values describe this XGBoost proxy, not the GNN itself.
    proxy_model = XGBClassifier()
    proxy_model.fit(esmc_features, labels)
    explainer = shap.Explainer(proxy_model)
    shap_values = explainer(esmc_features)
    shap.summary_plot(shap_values, esmc_features, plot_type="bar", show=False)
    plt.title("ESM-C 特征重要性 (SHAP)")
    plt.tight_layout()
    plt.savefig(os.path.join(output_folder, "shap_esmc_features.png"))
    plt.close()
    print("SHAP 分析完成，结果已保存至 shap_esmc_features.png")

    # GNNExplainer analysis (using DeepGATModel as the example)
    print("正在进行 GNNExplainer 分析...")
    trained_model = models['DeepGATModel']
    # NOTE(review): this constructor/explain_node call style looks like the legacy
    # GNNExplainer interface; confirm it matches the GNNExplainer class imported in
    # this notebook, and that the explainer can call a model whose forward takes a
    # single `data` argument.
    explainer = GNNExplainer(trained_model, epochs=200, lr=0.01)
    for sample_idx in range(min(5, len(test_struct_data))):
        data = test_struct_data[sample_idx].to(device)
        node_idx = 0  # explain the first node
        node_feat_mask, edge_mask = explainer.explain_node(node_idx, data.x, data.edge_index, data.edge_attr)
        print(f"样本 {sample_idx+1} | 节点 0 特征重要性（前5个）: {node_feat_mask[:5]} | 边重要性（前5个）: {edge_mask[:5]}")
    print("GNNExplainer 分析完成")

    # t-SNE visualisation
    print("正在进行 t-SNE 可视化...")
    def get_last_layer_features(model, loader, device):
        """Collect ``model.get_last_layer_features`` outputs and labels over ``loader``."""
        model.eval()
        features = []
        labels = []
        with torch.no_grad():
            for data in loader:
                data = data.to(device)
                feat = model.get_last_layer_features(data)
                features.append(feat.cpu().numpy())
                labels.append(data.y.cpu().numpy())
        return np.vstack(features), np.hstack(labels)

    features, labels = get_last_layer_features(trained_model, test_loader, device)
    # Fixed random_state keeps the 2-D layout reproducible between runs.
    tsne = TSNE(n_components=2, random_state=42)
    features_2d = tsne.fit_transform(features)
    plt.figure(figsize=(8, 6))
    plt.scatter(features_2d[:, 0], features_2d[:, 1], c=labels, cmap='coolwarm', alpha=0.6)
    plt.title("t-SNE of Last Layer Features (DeepGATModel)")
    plt.colorbar(label='Class')
    plt.savefig(os.path.join(output_folder, "tsne_last_layer.png"))
    plt.close()
    print("t-SNE 可视化完成，结果已保存至 tsne_last_layer.png")

    # `results` and `models` are not local names; see the docstring.
    return results, models

def train_and_evaluate_models(train_loader, test_loader, device, output_folder):
    """Tune, train and evaluate the three base graph models.

    Steps, in order: Optuna search per model class, construction of one model
    per class from the best hyper-parameters, training with ``train_model``
    (checkpoints go to ``output_folder/best_<name>.pth``), evaluation with
    ``detailed_test``, and finally a reload of every checkpoint from disk.

    Returns:
        ``(results, models)``: metrics dict per model name and the trained
        models (in eval mode), both keyed by class name.
    """
    model_params = {"node_feature_dim": 1156, "edge_feature_dim": 4, "out_dim": 2}
    model_classes = (DeepGATModel, GraphSAGEModel, GINModel)

    # Hyper-parameter search for each base model
    best_params = {}
    for model_class in model_classes:
        print(f"优化 {model_class.__name__}...")
        best_params[model_class.__name__] = optimize_model(model_class, train_loader, test_loader, device, model_params, n_trials=10)

    # Instantiate each model with its tuned hyper-parameters
    models = {}
    for model_class in model_classes:
        name = model_class.__name__
        tuned = best_params[name]
        init_kwargs = {
            "node_feature_dim": model_params["node_feature_dim"],
            "out_dim": model_params["out_dim"],
            "hidden_dim": tuned["hidden_dim"],
            "num_layers": tuned["num_layers"],
            "dropout": tuned["dropout"],
        }
        if model_class is DeepGATModel:
            # Only the GAT variant consumes edge features and attention heads.
            init_kwargs["edge_feature_dim"] = model_params["edge_feature_dim"]
            init_kwargs["num_heads"] = tuned["num_heads"]
        models[name] = model_class(**init_kwargs).to(device)

    # Train and evaluate every model
    results = {}
    for name, model in models.items():
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=best_params[name]["lr"], weight_decay=5e-4)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        checkpoint_path = os.path.join(output_folder, f"best_{name}.pth")
        train_model(
            model, train_loader, test_loader, criterion, optimizer, scheduler, device,
            num_epochs=50, patience=10, model_save_path=checkpoint_path,
        )
        metrics = detailed_test(model, test_loader, device)
        results[name] = metrics
        print(f"{name} - Acc: {metrics['acc']:.4f}, MCC: {metrics['mcc']:.4f}, AUC: {metrics['auc']:.4f}")

    # Side-by-side comparison
    print("\n### 三个模型性能对比 ###")
    for name, metrics in results.items():
        print(f"{name}: Acc: {metrics['acc']:.4f}, MCC: {metrics['mcc']:.4f}, AUC: {metrics['auc']:.4f}, Precision: {metrics['precision']:.4f}, Recall: {metrics['recall']:.4f}, F1: {metrics['f1']:.4f}")

    # Reload the best checkpoints and freeze the models for inference
    for name, model in models.items():
        checkpoint_path = os.path.join(output_folder, f"best_{name}.pth")
        model.load_state_dict(torch.load(checkpoint_path))
        model.eval()

    return results, models

# 交叉注意力融合训练
# 在交叉注意力融合训练函数中修改
def _base_model_features(models, data, error_message):
    """Return ``[gat, sage, gin]`` fc1 features for ``data``; ValueError unless the GAT output is a 2-D tensor."""
    feat_gat = models['DeepGATModel'].get_last_layer_features(data)
    feat_sage = models['GraphSAGEModel'].get_last_layer_features(data)
    feat_gin = models['GINModel'].get_last_layer_features(data)
    if not (isinstance(feat_gat, torch.Tensor) and feat_gat.dim() == 2):
        raise ValueError(error_message)
    return [feat_gat, feat_sage, feat_gin]


def train_cross_attention_fusion(models, train_loader, test_loader, device, output_folder, num_epochs=50, patience=10):
    """Train a ``CrossAttentionFusion`` head on top of the frozen base models.

    The base models only provide features (computed under ``torch.no_grad``);
    just the fusion module is optimised. The weights with the best accuracy on
    ``test_loader`` are kept in memory and saved to
    ``output_folder/best_cross_attention_fusion.pth``; training stops after
    ``patience`` epochs without improvement.

    Args:
        models: Dict with the trained ``'DeepGATModel'``, ``'GraphSAGEModel'``
            and ``'GINModel'``. They are switched to eval mode here.
        train_loader: Loader used to fit the fusion module.
        test_loader: Loader used for checkpoint selection and final metrics.
        device: Device for the fusion module and the batches.
        output_folder: Directory for the fusion checkpoint.
        num_epochs: Maximum number of epochs.
        patience: Early-stopping patience in epochs.

    Returns:
        ``(fusion_module, fusion_metrics)`` where the metrics come from
        ``detailed_test`` on ``test_loader``.

    Raises:
        ValueError: If a base model does not return a 2-D feature tensor.
    """
    fusion_module = CrossAttentionFusion(feature_dim=256, num_heads=4, dropout=0.1).to(device)
    optimizer_fusion = torch.optim.Adam(fusion_module.parameters(), lr=0.001, weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss()
    scheduler_fusion = torch.optim.lr_scheduler.StepLR(optimizer_fusion, step_size=10, gamma=0.1)

    # The base models act as fixed feature extractors: eval mode keeps dropout
    # off and stops batch-norm statistics from drifting during fusion training.
    for base_name in ('DeepGATModel', 'GraphSAGEModel', 'GINModel'):
        models[base_name].eval()

    print("\n### 训练交叉注意力融合模块 ###")
    best_fusion_acc = 0
    best_fusion_wts = copy.deepcopy(fusion_module.state_dict())
    epochs_no_improve = 0

    for epoch in range(1, num_epochs + 1):
        fusion_module.train()
        total_loss = 0
        for data in tqdm(train_loader, desc=f'融合训练 Epoch {epoch}/{num_epochs}'):
            data = data.to(device)
            with torch.no_grad():
                features_list = _base_model_features(
                    models, data,
                    "Expected pure tensors from get_last_layer_features, got unexpected type or shape",
                )
            out = fusion_module(features_list)
            loss = criterion(out, data.y)
            optimizer_fusion.zero_grad()
            loss.backward()
            optimizer_fusion.step()
            total_loss += loss.item() * data.num_graphs
        avg_loss = total_loss / len(train_loader.dataset)
        scheduler_fusion.step()

        fusion_module.eval()
        preds, trues = [], []
        with torch.no_grad():
            for data in test_loader:
                data = data.to(device)
                features_list = _base_model_features(
                    models, data,
                    "Expected pure tensors from get_last_layer_features in test loop",
                )
                pred = fusion_module(features_list).argmax(dim=1)
                preds.extend(pred.cpu().numpy())
                trues.extend(data.y.cpu().numpy())
        test_acc = accuracy_score(trues, preds)
        print(f"Epoch: {epoch:02d}, Loss: {avg_loss:.4f}, Fusion Test Acc: {test_acc:.4f}")

        if test_acc > best_fusion_acc:
            best_fusion_acc = test_acc
            best_fusion_wts = copy.deepcopy(fusion_module.state_dict())
            epochs_no_improve = 0
            torch.save(fusion_module.state_dict(), os.path.join(output_folder, "best_cross_attention_fusion.pth"))
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"早停：在第 {epoch} 轮训练后，无提升，停止训练。")
                break

    fusion_module.load_state_dict(best_fusion_wts)
    # detailed_test needs the base models to rebuild the fusion inputs.
    fusion_metrics = detailed_test(fusion_module, test_loader, device, models=models)
    return fusion_module, fusion_metrics

In [None]:
# ### Run the experiment
if __name__ == "__main__":
    # Prefer the GPU when one is visible to PyTorch.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")

    # Tune, train and evaluate the three base models.
    results, models = train_and_evaluate_models(
        train_loader,
        test_loader,
        device,
        aggregated_output_folder,
    )

In [None]:
# Train the fusion head on top of the base models trained in the previous cell.
fusion_module, fusion_metrics = train_cross_attention_fusion(
    models,
    train_loader,
    test_loader,
    device,
    aggregated_output_folder,
)
# Record the fusion metrics next to the base-model results.
results["CrossAttentionFusion"] = fusion_metrics
print(f"CrossAttentionFusion - Acc: {fusion_metrics['acc']:.4f}, MCC: {fusion_metrics['mcc']:.4f}, AUC: {fusion_metrics['auc']:.4f}, Precision: {fusion_metrics['precision']:.4f}, Recall: {fusion_metrics['recall']:.4f}, F1: {fusion_metrics['f1']:.4f}")

# 单调GAT

In [None]:
import os
import numpy as np
import pandas as pd
import json
from tqdm import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.data import Data, Dataset, DataLoader
from torch_geometric.nn import GATConv, global_mean_pool
from torch_geometric.nn.inits import reset
from torch_geometric.explain import GNNExplainer
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    classification_report, accuracy_score, precision_recall_fscore_support,
    matthews_corrcoef, roc_auc_score, confusion_matrix
)
import copy
import optuna
import matplotlib.pyplot as plt
import shap
from xgboost import XGBClassifier
from sklearn.manifold import TSNE
import seaborn as sns

# Select the compute device (GPU when available, otherwise CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Path configuration: one ESM-C feature folder per split and class
esmc_folders = dict(
    train_pos='/content/drive/MyDrive/AFP_work/esmc_600_train_pos',
    train_neg='/content/drive/MyDrive/AFP_work/esmc_600_train_neg',
    test_pos='/content/drive/MyDrive/AFP_work/esmc_600_test_pos',
    test_neg='/content/drive/MyDrive/AFP_work/esmc_600_test_neg',
)

# Aggregated structural features, one JSON file per split
struct_folders = dict(
    train='/content/drive/MyDrive/AFP_work/pdb_features/aggregated/train_dataset.json',
    test='/content/drive/MyDrive/AFP_work/pdb_features/aggregated/test_dataset.json',
)

# Results directory, created on first run
output_folder = '/content/drive/MyDrive/AFP_work/deepgat_explain_results'
os.makedirs(output_folder, exist_ok=True)

# ### 数据加载和预处理

# **加载 ESM-C 特征**
def load_esmc_features(esmc_folder, save_path=None):
    """Load pooled ESM-C logits, embeddings and labels from one feature folder.

    The folder must contain ``combined_logits.npy`` (object array whose entries
    expose a ``sequence`` tensor) and ``combined_embeddings.npy``.

    Args:
        esmc_folder: Folder holding the two ``.npy`` files. The class label is
            derived from the folder's own name: 1 when it contains ``'pos'``,
            otherwise 0. Parent directories are ignored so that a path such as
            ``.../repository/..._neg`` is not mislabelled as positive.
        save_path: Unused; kept for backward compatibility.

    Returns:
        ``(logits_values, embeddings, labels)``: float32 array of shape
        ``(n, 1)`` with one mean-pooled logit per sample, the embeddings cast
        to float32, and an integer label array of shape ``(n,)``.
    """
    logits_path = os.path.join(esmc_folder, 'combined_logits.npy')
    embeddings_path = os.path.join(esmc_folder, 'combined_embeddings.npy')
    # SECURITY: allow_pickle=True unpickles arbitrary Python objects. It is
    # needed because the logits are stored as objects, but these files must
    # only come from a trusted source.
    logits = np.load(logits_path, allow_pickle=True)
    embeddings = np.load(embeddings_path, allow_pickle=True)

    # Debug output: structure of the first sample (skipped for empty arrays,
    # where indexing [0] would raise).
    if len(logits) > 0 and len(embeddings) > 0:
        print(f"Logits[0] 类型: {type(logits[0])}, 值: {logits[0]}")
        print(f"Embeddings[0] 形状: {embeddings[0].shape}, 样本: {embeddings[0][:5]}")

    # Pool the `sequence` tensor of every entry down to a single scalar.
    logits_values = []
    for entry in logits:
        # Entries may be wrapped in a one-element array.
        forward_data = entry[0] if isinstance(entry, np.ndarray) else entry
        sequence_tensor = forward_data.sequence.to(device='cpu', dtype=torch.float32)
        # The reduction over dims 0-2 assumes a 3-D tensor.
        logits_values.append(sequence_tensor.mean(dim=[0, 1, 2]).item())

    logits_values = np.array(logits_values, dtype=np.float32).reshape(-1, 1)
    embeddings = embeddings.astype(np.float32)

    # Label from the last path component only (normpath drops a trailing slash).
    folder_name = os.path.basename(os.path.normpath(esmc_folder))
    label = 1 if 'pos' in folder_name else 0
    labels = np.full((logits_values.shape[0],), label)
    return logits_values, embeddings, labels

# Load the ESM-C features of every split/class, in a fixed order
esmc_parts = {
    part: load_esmc_features(esmc_folders[part])
    for part in ('train_pos', 'train_neg', 'test_pos', 'test_neg')
}
train_pos_logits, train_pos_embeddings, train_pos_labels = esmc_parts['train_pos']
train_neg_logits, train_neg_embeddings, train_neg_labels = esmc_parts['train_neg']
test_pos_logits, test_pos_embeddings, test_pos_labels = esmc_parts['test_pos']
test_neg_logits, test_neg_embeddings, test_neg_labels = esmc_parts['test_neg']

# Stack positives first, then negatives, for each split
train_logits = np.vstack((train_pos_logits, train_neg_logits))
test_logits = np.vstack((test_pos_logits, test_neg_logits))
train_embeddings = np.vstack((train_pos_embeddings, train_neg_embeddings))
test_embeddings = np.vstack((test_pos_embeddings, test_neg_embeddings))
train_labels = np.hstack((train_pos_labels, train_neg_labels))
test_labels = np.hstack((test_pos_labels, test_neg_labels))

# Shape summary
print(f"训练集 logits 形状: {train_logits.shape}")
print(f"训练集 embeddings 形状: {train_embeddings.shape}")
print(f"训练集 labels 形状: {train_labels.shape}")
print(f"测试集 logits 形状: {test_logits.shape}")
print(f"测试集 embeddings 形状: {test_embeddings.shape}")
print(f"测试集 labels 形状: {test_labels.shape}")

# **加载结构特征**
def load_struct_features(json_path, sample_limit=5):
    """Build a list of PyG ``Data`` graphs from an aggregated JSON dataset.

    Each JSON sample needs the keys ``node_features``, ``edge_features`` and
    ``label``. ``edge_features`` may carry ``edges``, ``directions`` and
    ``rotations``; directions and rotations are concatenated column-wise into
    the edge attributes. Samples with missing keys or inconsistent edge counts
    are reported and skipped.

    Args:
        json_path: Path of the JSON file to read.
        sample_limit: Number of leading samples (by position in the file) whose
            sizes are printed.

    Returns:
        List of ``Data`` objects with ``x``, ``edge_index``, ``edge_attr``, ``y``.
    """
    with open(json_path, 'r') as f:
        samples = json.load(f)

    required_keys = ('node_features', 'edge_features', 'label')
    graphs = []
    for idx, sample in enumerate(tqdm(samples, desc=f'加载结构特征 from {json_path}')):
        if any(key not in sample for key in required_keys):
            print(f" 样本缺少必要的键: {sample}")
            continue

        edge_info = sample['edge_features']
        edge_list = edge_info.get('edges', [])
        direction_list = edge_info.get('directions', [])
        rotation_list = edge_info.get('rotations', [])
        num_edges = len(edge_list)
        if len(direction_list) != num_edges or len(rotation_list) != num_edges:
            print(f" 边的数量与方向或旋转数量不匹配: {sample}")
            continue

        if num_edges > 0:
            edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
            direction_t = torch.tensor(direction_list, dtype=torch.float)
            # One rotation value per edge becomes an extra column.
            rotation_t = torch.tensor(rotation_list, dtype=torch.float).unsqueeze(1)
            edge_attr = torch.cat([direction_t, rotation_t], dim=1)
        else:
            # Graph without edges: empty tensors with the expected widths.
            edge_index = torch.empty((2, 0), dtype=torch.long)
            edge_attr = torch.empty((0, 4), dtype=torch.float)

        x = torch.tensor(sample['node_features'], dtype=torch.float)
        y = torch.tensor(sample['label'], dtype=torch.long)
        graphs.append(Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y))

        if idx < sample_limit:
            num_nodes, node_feature_dim = x.shape[0], x.shape[1]
            print(f"样本 {idx+1}: 节点数量: {num_nodes}, 节点特征维度: {node_feature_dim}, 边数量: {num_edges}")
            if num_edges > 0:
                print(f"  边特征维度: {edge_attr.shape[1]}")
            print("-" * 50)

    node_dims = {graph.x.shape[1] for graph in graphs}
    edge_dims = {graph.edge_attr.shape[1] for graph in graphs if graph.edge_attr.shape[0] > 0}
    print(f"\n所有样本中唯一的节点特征维度: {node_dims}")
    print(f"所有样本中唯一的边特征维度: {edge_dims}")
    return graphs

# Load the structural graphs of both splits (train first, then test)
train_struct_data, test_struct_data = (
    load_struct_features(struct_folders[split]) for split in ('train', 'test')
)

# **特征标准化**
def normalize_features(train_data_list, test_data_list=None):
    """Standardise node and edge features in place with train-set statistics.

    Two ``StandardScaler`` objects are fitted on the training graphs only (one
    for node features, one for edge attributes) and then applied to the
    training graphs and, when given, the test graphs.

    Args:
        train_data_list: Graphs with ``x`` and ``edge_attr`` tensors; they are
            modified in place.
        test_data_list: Optional graphs to transform with the same scalers.

    Returns:
        ``(train_data_list, test_data_list)`` — the same list objects.

    Raises:
        ValueError: If ``train_data_list`` is empty.

    Notes:
        When no training graph has edges the edge scaler cannot be fitted; edge
        attributes are then left unscaled instead of raising.
    """
    if len(train_data_list) == 0:
        raise ValueError("train_data_list is empty; cannot fit the feature scalers")

    node_scaler = StandardScaler()
    edge_scaler = StandardScaler()

    node_scaler.fit(np.concatenate([data.x.numpy() for data in train_data_list], axis=0))

    # np.concatenate rejects an empty list, so graphs with edges are collected
    # first and the scaler is fitted only when there is something to fit.
    edge_blocks = [data.edge_attr.numpy() for data in train_data_list if data.edge_attr.shape[0] > 0]
    edge_scaler_fitted = len(edge_blocks) > 0
    if edge_scaler_fitted:
        edge_scaler.fit(np.concatenate(edge_blocks, axis=0))

    def _scale_in_place(data_list):
        """Apply the fitted scalers to every graph of ``data_list``."""
        for data in data_list:
            data.x = torch.tensor(node_scaler.transform(data.x.numpy()), dtype=torch.float)
            if edge_scaler_fitted and data.edge_attr.shape[0] > 0:
                data.edge_attr = torch.tensor(edge_scaler.transform(data.edge_attr.numpy()), dtype=torch.float)

    _scale_in_place(train_data_list)
    if test_data_list:
        _scale_in_place(test_data_list)
    return train_data_list, test_data_list

# Standardise in place; re-running this line scales already-scaled data again.
train_struct_data, test_struct_data = normalize_features(
    train_data_list=train_struct_data,
    test_data_list=test_struct_data,
)

# **整合 ESM-C 特征**
def integrate_features(data_list, embeddings, logits):
    """Append the per-sequence ESM-C vector to every node of each graph, in place.

    For graph ``i`` the embedding ``embeddings[i]`` and the scalar ``logits[i]``
    are concatenated into one vector, which is repeated for every node and
    appended to ``data.x`` as extra columns.

    Args:
        data_list: Graphs whose ``x`` tensor is extended (modified in place).
        embeddings: One embedding vector per graph.
        logits: One pooled logit per graph.

    Returns:
        The same ``data_list``.

    Raises:
        ValueError: If the three inputs do not have the same length.
    """
    n_graphs = len(data_list)
    if n_graphs != len(embeddings) or n_graphs != len(logits):
        raise ValueError(f"data_list, embeddings 和 logits 长度不匹配: {len(data_list)} vs {len(embeddings)} vs {len(logits)}")

    for i, graph in enumerate(tqdm(data_list, desc='整合 ESM-C embeddings 和 logits')):
        embedding_vec = torch.tensor(embeddings[i], dtype=torch.float)
        # squeeze + unsqueeze turns the one-element logit row into a 1-D tensor of length 1
        logit_vec = torch.tensor(logits[i], dtype=torch.float).squeeze().unsqueeze(0)
        sequence_vec = torch.cat([embedding_vec, logit_vec], dim=0)
        # Same sequence-level vector for every node of the graph
        per_node = sequence_vec.unsqueeze(0).repeat(graph.x.shape[0], 1)
        graph.x = torch.cat([graph.x, per_node], dim=1)
    return data_list

# Widen the node features in place; re-running this cell appends the columns again.
train_struct_data = integrate_features(
    data_list=train_struct_data, embeddings=train_embeddings, logits=train_logits
)
test_struct_data = integrate_features(
    data_list=test_struct_data, embeddings=test_embeddings, logits=test_logits
)

print(f"训练集第一个样本的节点特征维度（整合后）: {train_struct_data[0].x.shape[1]}")
print(f"测试集第一个样本的节点特征维度（整合后）: {test_struct_data[0].x.shape[1]}")

# **创建数据集和数据加载器**
class ProteinDataset(Dataset):
    """In-memory dataset that serves graphs from a pre-built Python list."""

    def __init__(self, data_list):
        super().__init__()
        # Graphs are kept by reference; nothing is copied or written to disk here.
        self.data_list = data_list

    def len(self):
        """Number of graphs in the dataset."""
        return len(self.data_list)

    def get(self, idx):
        """Return the graph stored at position ``idx``."""
        return self.data_list[idx]

batch_size = 32
# Options shared by both loaders; only shuffling differs between them.
loader_options = dict(batch_size=batch_size, num_workers=4, pin_memory=True)

train_dataset = ProteinDataset(train_struct_data)
test_dataset = ProteinDataset(test_struct_data)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

# ### 调试 DeepGATModel
class DeepGATModel(nn.Module):
    """Multi-layer GAT graph classifier with learned edge-feature preprocessing.

    Edge attributes go through a small MLP before being handed to every
    ``GATConv``. Each convolution is followed by BatchNorm1d, ELU and dropout;
    node states are mean-pooled per graph and classified by ``fc1`` (256 units)
    and ``fc2``.

    Args:
        node_feature_dim: Width of the input node features.
        edge_feature_dim: Width of the raw edge attributes.
        hidden_dim: Per-head output width of every GATConv.
        out_dim: Number of outputs of the final linear layer.
        num_heads: Attention heads per GATConv (head outputs are concatenated).
        dropout: Dropout probability (attention, between layers, before ``fc2``).
        num_layers: Number of GATConv layers.
    """

    def __init__(self, node_feature_dim, edge_feature_dim, hidden_dim, out_dim, num_heads=4, dropout=0.3, num_layers=3):
        super().__init__()
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        self.dropout = nn.Dropout(p=dropout)

        # MLP that lifts the raw edge attributes to hidden_dim
        self.edge_preprocess = nn.Sequential(
            nn.Linear(edge_feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # GAT stack: heads are concatenated, so later layers read hidden_dim * num_heads
        concat_dim = hidden_dim * num_heads
        for layer_idx in range(num_layers):
            gat_layer = GATConv(
                in_channels=concat_dim if layer_idx > 0 else node_feature_dim,
                out_channels=hidden_dim,
                heads=num_heads,
                dropout=dropout,
                edge_dim=hidden_dim,  # width of the preprocessed edge features
                add_self_loops=True
            )
            self.convs.append(gat_layer)
            self.batch_norms.append(nn.BatchNorm1d(concat_dim))

        # Pooling and classification head
        self.readout = global_mean_pool
        self.fc1 = nn.Linear(concat_dim, 256)
        self.fc2 = nn.Linear(256, out_dim)

    def _graph_embedding(self, data, print_shapes=False):
        """Shared encoder up to and including ``fc1``; optionally prints tensor shapes."""
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch

        if print_shapes:
            print(f"输入 - 节点特征形状: {x.shape}, 边特征形状: {edge_attr.shape if edge_attr is not None else '无'}, 边索引形状: {edge_index.shape}")

        # Edge attributes are only used when at least one edge carries them
        has_edge_features = edge_attr is not None and edge_attr.shape[0] > 0
        if has_edge_features:
            edge_attr = self.edge_preprocess(edge_attr)
            if print_shapes:
                print(f"预处理后边特征形状: {edge_attr.shape}")
        else:
            edge_attr = None
            if print_shapes:
                print("无边特征，使用默认边处理")

        for layer_no, (conv, bn) in enumerate(zip(self.convs, self.batch_norms), start=1):
            x = conv(x, edge_index, edge_attr)
            if print_shapes:
                print(f"GATConv层 {layer_no} 输出形状: {x.shape}")
            x = self.dropout(F.elu(bn(x)))

        x = self.readout(x, batch)
        if print_shapes:
            print(f"池化后特征形状: {x.shape}")

        x = self.fc1(x)
        if print_shapes:
            print(f"FC1输出形状: {x.shape}")
        return x

    def forward(self, data, print_shapes=False):
        """Return the ``fc2`` outputs; ``print_shapes=True`` logs intermediate shapes."""
        hidden = self._graph_embedding(data, print_shapes=print_shapes)
        out = self.fc2(self.dropout(F.relu(hidden)))
        if print_shapes:
            print(f"FC2输出形状: {out.shape}")  # expected [batch_size, out_dim]
        return out

    def get_last_layer_features(self, data):
        """Return the 256-d ``fc1`` output (before ReLU) for each graph."""
        return self._graph_embedding(data)

# ### 训练和调试 DeepGATModel
def train_deepgat(model, train_loader, test_loader, criterion, optimizer, scheduler, device, num_epochs=50, patience=10, model_save_path='best_deepgat.pth'):
    """Train a DeepGAT model with early stopping on the accuracy of ``test_loader``.

    Tensor shapes and the batch loss are printed for the very first batch of
    the first epoch only. Whenever the test accuracy improves, the weights are
    copied in memory and saved to ``model_save_path``; training stops after
    ``patience`` epochs without improvement.

    Returns:
        ``(best_test_acc, best_model_wts)``. ``best_model_wts`` is ``None`` when
        no epoch scored above 0, in which case the model keeps its last state.
    """
    best_test_acc = 0
    best_model_wts = None
    stale_epochs = 0

    for epoch in range(1, num_epochs + 1):
        model.train()
        running_loss = 0
        batches = tqdm(train_loader, desc=f'训练 DeepGAT Epoch {epoch}/{num_epochs}')
        for step, batch in enumerate(batches):
            # Verbose only for the first batch of the first epoch
            verbose = epoch == 1 and step == 0
            batch = batch.to(device)
            optimizer.zero_grad()
            loss = criterion(model(batch, print_shapes=verbose), batch.y)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * batch.num_graphs
            if verbose:
                print(f"批次损失: {loss.item():.4f}")

        avg_loss = running_loss / len(train_loader.dataset)
        scheduler.step()
        train_acc = test(model, train_loader, device)[0]
        test_acc = test(model, test_loader, device)[0]
        print(f"Epoch: {epoch:02d}, Loss: {avg_loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}")

        if not test_acc > best_test_acc:
            stale_epochs += 1
            if stale_epochs >= patience:
                print(f"早停：在第 {epoch} 轮训练后，无提升，停止训练。")
                break
            continue

        # New best: remember the weights and persist them
        best_test_acc = test_acc
        best_model_wts = copy.deepcopy(model.state_dict())
        stale_epochs = 0
        torch.save(best_model_wts, model_save_path)
        print(f"保存最佳模型，测试准确率: {test_acc:.4f}")

    # Restore the best weights when any were recorded
    if best_model_wts is None:
        print("警告：未找到最佳权重，使用当前模型状态。")
    else:
        model.load_state_dict(best_model_wts)
    return best_test_acc, best_model_wts

def test(model, loader, device):
    """Evaluate ``model`` on ``loader`` and report its accuracy.

    The model is switched to eval mode and every batch is scored without
    gradient tracking; the predicted class is the arg-max over dimension 1
    of the model output.

    Args:
        model: Network called as ``model(batch)``.
        loader: Loader of batches exposing ``.to(device)`` and ``.y``; its
            ``dataset`` length is the accuracy denominator.
        device: Device every batch is moved to.

    Returns:
        ``(accuracy, trues, preds)`` where ``trues`` and ``preds`` are flat
        lists of per-sample true and predicted labels.
    """
    predicted_labels = []
    true_labels = []
    num_correct = 0

    model.eval()
    progress = tqdm(loader, desc='评估 DeepGAT')
    with torch.no_grad():
        for batch in progress:
            batch = batch.to(device)
            batch_pred = model(batch).argmax(dim=1)
            num_correct += (batch_pred == batch.y).sum().item()
            predicted_labels.extend(batch_pred.cpu().numpy())
            true_labels.extend(batch.y.cpu().numpy())

    return num_correct / len(loader.dataset), true_labels, predicted_labels

def detailed_test(model, loader, device):
    """Evaluate ``model`` on ``loader`` and compute binary-classification metrics.

    Predictions are the arg-max over dimension 1 of the model output; the
    score used for ROC AUC is the softmax probability of column 1.

    Args:
        model: Network called as ``model(batch)``.
        loader: Loader of batches exposing ``.to(device)`` and ``.y``.
        device: Device every batch is moved to.

    Returns:
        Dict with keys ``acc``, ``mcc``, ``auc``, ``sn`` (sensitivity),
        ``sp`` (specificity), ``precision``, ``recall`` and ``f1``.
        ``auc`` is NaN when the true labels contain a single class.
    """
    model.eval()
    preds, trues, probs = [], [], []
    with torch.no_grad():
        for data in tqdm(loader, desc='详细评估 DeepGAT'):
            data = data.to(device)
            out = model(data)
            probs.extend(F.softmax(out, dim=1)[:, 1].cpu().numpy())
            preds.extend(out.argmax(dim=1).cpu().numpy())
            trues.extend(data.y.cpu().numpy())

    acc = accuracy_score(trues, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(trues, preds, average='binary')
    mcc = matthews_corrcoef(trues, preds)

    # ROC AUC is undefined when only one class is present in the ground truth
    # (roc_auc_score raises); report NaN instead of aborting the evaluation.
    if len({int(label) for label in trues}) > 1:
        auc = roc_auc_score(trues, probs)
    else:
        auc = float('nan')

    # Pin the label set so the matrix is always 2x2. Without ``labels`` it
    # collapses to 1x1 when only one class occurs, and the four-way unpack fails.
    tn, fp, fn, tp = confusion_matrix(trues, preds, labels=[0, 1]).ravel()
    sn = tp / (tp + fn) if (tp + fn) > 0 else 0
    sp = tn / (tn + fp) if (tn + fp) > 0 else 0
    metrics = {'acc': acc, 'mcc': mcc, 'auc': auc, 'sn': sn, 'sp': sp, 'precision': precision, 'recall': recall, 'f1': f1}
    return metrics

# ### 超参数优化 (Optuna)

def optimize_deepgat(train_loader, test_loader, device, n_trials=10):
    """Search DeepGATModel hyperparameters with Optuna.

    Every trial builds a fresh ``DeepGATModel``, trains it with
    ``train_deepgat`` (Adam, cross-entropy loss, StepLR) and reports the best
    accuracy reached on ``test_loader`` as the value to maximise.

    Args:
        train_loader: Training loader forwarded to ``train_deepgat``.
        test_loader: Loader whose accuracy is the optimisation objective.
        device: Device the model is moved to.
        n_trials: Number of Optuna trials to run.

    Returns:
        ``study.best_params``: dict with keys ``hidden_dim``, ``num_layers``,
        ``num_heads``, ``dropout`` and ``lr``.
    """
    def objective(trial):
        # Hyperparameter search space.
        hidden_dim = trial.suggest_int('hidden_dim', 64, 512)
        num_layers = trial.suggest_int('num_layers', 2, 6)
        num_heads = trial.suggest_int('num_heads', 2, 16)
        dropout = trial.suggest_float('dropout', 0.1, 0.5)
        # ``suggest_loguniform`` is deprecated in Optuna; ``suggest_float``
        # with ``log=True`` samples from the same log-uniform range.
        lr = trial.suggest_float('lr', 1e-4, 1e-2, log=True)

        model = DeepGATModel(
            node_feature_dim=1156,  # combined node feature dimension
            edge_feature_dim=4,     # edge feature dimension
            hidden_dim=hidden_dim,
            out_dim=2,              # binary classification
            num_heads=num_heads,
            dropout=dropout,
            num_layers=num_layers
        ).to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
        criterion = nn.CrossEntropyLoss()
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

        # NOTE(review): every trial writes to the same checkpoint path, so the
        # file left on disk belongs to the last trial that improved, not
        # necessarily to the best trial of the study.
        best_acc, _ = train_deepgat(
            model, train_loader, test_loader, criterion, optimizer, scheduler,
            device, num_epochs=50, patience=10, model_save_path='best_deepgat_optimized.pth'
        )
        return best_acc

    study = optuna.create_study(direction='maximize')
    study.optimize(objective, n_trials=n_trials)
    print(f"总试验次数: {len(study.trials)}")
    print(f"DeepGATModel 最佳超参数: {study.best_params}")
    for trial in study.trials:
        print(f"Trial {trial.number}: State={trial.state}, Value={trial.value}")
    return study.best_params

# ### 可解释性分析函数

def explain_deepgat(model, train_loader, test_loader, device, output_folder):
    """Run three interpretability analyses for a trained model and save the plots.

    1. Fits an ``XGBClassifier`` proxy on the horizontally stacked
       ``train_embeddings`` / ``train_logits`` arrays and saves a SHAP
       bar summary to ``shap_esmc_features.png``.
    2. Runs ``GNNExplainer`` on node 0 of up to five graphs taken from
       ``test_struct_data`` and prints the first five mask entries.
    3. Projects the output of ``model.get_last_layer_features`` for every
       batch of ``test_loader`` to 2-D with t-SNE and saves
       ``tsne_deepgat_features.png``.

    NOTE: ``train_embeddings``, ``train_logits``, ``train_labels`` and
    ``test_struct_data`` are read from the surrounding notebook scope rather
    than passed in, so they must already be defined when this is called.

    Args:
        model: Trained network to explain.
        train_loader: Kept for interface compatibility; not used in the body.
        test_loader: Loader whose batches feed the t-SNE projection.
        device: Device graphs and batches are moved to.
        output_folder: Directory the two PNG files are written to; created
            if it does not exist.
    """
    # Create the target directory up front so savefig cannot fail on a missing folder.
    os.makedirs(output_folder, exist_ok=True)

    # 1. SHAP analysis (ESM-C features)
    print("正在进行 ESM-C 特征的 SHAP 分析...")
    # Original author's note gives the stacked shape as [num_samples, 1153] — TODO confirm.
    esmc_features = np.hstack([train_embeddings, train_logits])
    proxy_model = XGBClassifier()
    proxy_model.fit(esmc_features, train_labels)
    shap_explainer = shap.Explainer(proxy_model)
    shap_values = shap_explainer(esmc_features)
    plt.figure(figsize=(10, 6))
    shap.summary_plot(shap_values, esmc_features, plot_type="bar", show=False)
    plt.title("ESM-C 特征重要性 (SHAP)")
    plt.tight_layout()
    plt.savefig(os.path.join(output_folder, "shap_esmc_features.png"))
    plt.close()
    print("SHAP 分析完成，结果已保存至 shap_esmc_features.png")

    # 2. GNNExplainer analysis (node and edge importance)
    print("正在进行 GNNExplainer 分析...")
    model.eval()
    # The original called ``reset(model)`` here, described as resetting the
    # model parameters. Doing that to a trained model would make both the
    # explanations and the t-SNE plot below describe re-initialised weights,
    # so the call is intentionally omitted.
    # NOTE(review): ``explain_node`` receives x / edge_index / edge_attr
    # separately, whereas the model is called elsewhere in this file with a
    # single batch object; this also looks like the legacy PyG explainer
    # interface — confirm against the installed torch-geometric version.
    gnn_explainer = GNNExplainer(model, epochs=200, lr=0.01)
    for sample_idx in range(min(5, len(test_struct_data))):  # first five test samples
        data = test_struct_data[sample_idx].to(device)
        node_idx = 0  # explain the first node of each graph
        node_feat_mask, edge_mask = gnn_explainer.explain_node(node_idx, data.x, data.edge_index, data.edge_attr)
        print(f"样本 {sample_idx+1} | 节点 {node_idx} 特征重要性（前5个）: {node_feat_mask[:5]} | 边重要性（前5个）: {edge_mask[:5] if edge_mask is not None else '无'}")
    print("GNNExplainer 分析完成")

    # 3. t-SNE visualisation of the last-layer features
    print("正在进行 t-SNE 可视化...")

    def get_last_layer_features(model, loader, device):
        """Collect last-layer features and labels for every batch of ``loader``."""
        model.eval()
        batch_features = []
        batch_labels = []
        with torch.no_grad():
            for data in loader:
                data = data.to(device)
                feat = model.get_last_layer_features(data)
                batch_features.append(feat.cpu().numpy())
                batch_labels.append(data.y.cpu().numpy())
        return np.vstack(batch_features), np.hstack(batch_labels)

    test_features, test_labels = get_last_layer_features(model, test_loader, device)
    tsne = TSNE(n_components=2, random_state=42)
    features_2d = tsne.fit_transform(test_features)
    plt.figure(figsize=(8, 6))
    plt.scatter(features_2d[:, 0], features_2d[:, 1], c=test_labels, cmap='coolwarm', alpha=0.6)
    plt.title("DeepGATModel 最后层特征 t-SNE 可视化")
    plt.colorbar(label='Class (0=Neg, 1=Pos)')
    plt.savefig(os.path.join(output_folder, "tsne_deepgat_features.png"))
    plt.close()
    print("t-SNE 可视化完成，结果已保存至 tsne_deepgat_features.png")

if __name__ == "__main__":
    # End-to-end driver: tune hyperparameters, retrain with the best set,
    # evaluate, persist weights and metrics, then run the explainability step.
    # Relies on ``device``, ``train_loader``, ``test_loader`` and
    # ``output_folder`` already being defined in the notebook scope.

    # train_deepgat writes its checkpoint into output_folder, so the
    # directory has to exist before training starts.
    os.makedirs(output_folder, exist_ok=True)

    # The original first built a DeepGATModel with fixed settings here and
    # then overwrote it without using it; that dead allocation is removed.
    print("开始优化 DeepGATModel 超参数...")
    best_params = optimize_deepgat(train_loader, test_loader, device, n_trials=10)

    # Build the final model from the best hyperparameters found.
    deepgat_model = DeepGATModel(
        node_feature_dim=1156,
        edge_feature_dim=4,
        hidden_dim=best_params['hidden_dim'],
        out_dim=2,
        num_heads=best_params['num_heads'],
        dropout=best_params['dropout'],
        num_layers=best_params['num_layers']
    ).to(device)

    # Same optimiser / loss / scheduler configuration as used during the search.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(deepgat_model.parameters(), lr=best_params['lr'], weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    best_acc, best_wts = train_deepgat(
        deepgat_model, train_loader, test_loader, criterion, optimizer, scheduler,
        device, num_epochs=50, patience=10, model_save_path=os.path.join(output_folder, 'best_deepgat_final.pth')
    )

    # Detailed evaluation on the test loader.
    metrics = detailed_test(deepgat_model, test_loader, device)
    print("\n### DeepGATModel 详细性能 ###")
    for metric, value in metrics.items():
        print(f"{metric}: {value:.4f}")

    # Persist weights and metrics. train_deepgat already wrote the same
    # checkpoint path; this is kept as an explicit final write.
    if best_wts is not None:
        torch.save(best_wts, os.path.join(output_folder, 'best_deepgat_final.pth'))
    with open(os.path.join(output_folder, 'deepgat_metrics.json'), 'w') as f:
        json.dump(metrics, f)

    # Debug aid: show true vs. predicted labels for the first test batch only.
    deepgat_model.eval()
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            out = deepgat_model(data)
            pred = out.argmax(dim=1)
            true = data.y
            print(f"预测样本 - 真实标签: {true[:5].cpu().numpy()}, 预测标签: {pred[:5].cpu().numpy()}")
            break

    # Explainability analysis.
    explain_deepgat(deepgat_model, train_loader, test_loader, device, output_folder)