In [1]:
import os
import pandas as pd
from sklearn.model_selection import KFold
from Bio import pairwise2

In [2]:
# 假设数据集格式：每行包含蛋白质ID、变异信息、ΔΔG值
def load_dataset(file_path):
    """加载并预处理数据集"""
    df = pd.read_csv(file_path)
    # 确保包含蛋白质ID列（如'protein_id'）和序列列（如'sequence'）
    if 'pdb_id' not in df.columns or 'wt_seq' not in df.columns:
        raise ValueError("数据集需包含'protein_id'和'sequence'列")
    ddg_counts = df.groupby('mut_seq')['ddg'].transform('count')
    df = df[ddg_counts == 1]
    return df

def calculate_sequence_identity(seq1, seq2):
    """计算两条序列的identity相似度"""
    # 简化版：使用Bio库的全局比对
    alignment = pairwise2.align.globalxx(seq1, seq2, score_only=True)
    identity = alignment / max(len(seq1), len(seq2)) * 100
    return identity

def group_proteins_by_similarity(proteins, threshold=25):
    """根据序列相似度将蛋白质分组，确保同一组内相似度>threshold"""
    groups = []
    remaining = proteins.copy()
    
    while remaining:
        # 取出第一个蛋白质作为基准
        ref_protein = remaining.pop(0)
        ref_seq = ref_protein['wt_seq']
        group = [ref_protein]
        
        # 查找相似蛋白质
        for protein in remaining.copy():
            seq_id = calculate_sequence_identity(ref_seq, protein['wt_seq'])
            if seq_id > threshold:
                group.append(protein)
                remaining.remove(protein)
        
        groups.append(group)
    return groups

def split_proteins_into_folds(protein_groups, n_splits=5):
    """将蛋白质组分配到k折中，确保组内蛋白质在同一折"""
    # 先将每个蛋白质组视为一个"块"
    blocks = protein_groups
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=123)  # 与论文一致的随机种子
    
    # 生成折分配
    fold_assignments = {}
    for fold_idx, (train_idx, val_idx) in enumerate(kf.split(blocks)):
        # 训练折：合并多个蛋白质组
        train_proteins = []
        for idx in train_idx:
            train_proteins.extend(blocks[idx])
        
        # 验证折
        val_proteins = []
        for idx in val_idx:
            val_proteins.extend(blocks[idx])
        
        fold_assignments[fold_idx] = {
            'train': train_proteins,
            'val': val_proteins
        }
    return fold_assignments

# def apply_thermodynamic_reversibility(df):
#     """应用热力学可逆性：添加反向变异（A→B → B→A，ΔΔG取反）"""
#     reverses = []
#     for _, row in df.iterrows():
#         # 假设变异格式为"WT:POS:MUT"，如"A:123:G"
#         wt, pos, mut = row['mutation'].split(':')
#         reverse_mutation = f"{mut}:{pos}:{wt}"
#         reverse_ddg = -row['ddg']
#         reverses.append({
#             'protein_id': row['protein_id'],
#             'sequence': row['sequence'],
#             'mutation': reverse_mutation,
#             'ddg': reverse_ddg
#         })
#     
#     # 合并原始数据和反向变异
#     reversed_df = pd.DataFrame(reverses)
#     return pd.concat([df, reversed_df], ignore_index=True)

def prepare_fold_data(fold_assignments, original_df, output_dir=None):
    """为每个折准备训练/验证数据"""
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    all_fold_data = []
    for fold_idx, data in fold_assignments.items():
        train_proteins = {p['pdb_id'] for p in data['train']}
        val_proteins = {p['pdb_id'] for p in data['val']}
        
        # 筛选训练/验证数据
        train_df = original_df[original_df['pdb_id'].isin(train_proteins)].copy()
        val_df = original_df[original_df['pdb_id'].isin(val_proteins)].copy()
        
        # 计算pdb_id的数量
        train_pdbs = len(train_df['pdb_id'].unique())
        val_pdbs = len(val_df['pdb_id'].unique())
        # 应用热力学可逆性增强
        # train_df = apply_thermodynamic_reversibility(train_df)
        
        # 保存数据
        train_path = os.path.join(output_dir, f'fold_{fold_idx}_train.csv')
        val_path = os.path.join(output_dir, f'fold_{fold_idx}_val.csv')
        train_df.to_csv(train_path, index=False)
        val_df.to_csv(val_path, index=False)
        
        all_fold_data.append({
            'fold': fold_idx,
            'train_size': len(train_df),
            'train_pdbs': train_pdbs,
            'val_size': len(val_df),
            'val_pdbs': val_pdbs
        })
    return all_fold_data

In [3]:
# 1. 加载数据集
dataset_path = './dataset/cdna/mutations/cdna_processed.csv'
df = load_dataset(dataset_path)
df

Unnamed: 0.1,Unnamed: 0,pdb_id,ddg,mut_info,mut_seq,fr1,pos1,to1,mut_info1,mut_info2,fr2,pos2,to2,wt_seq,is_stable,seq
0,120420,1f0m,-0.002386,F26W:M35L,SFNTVDEWLEAIKMGQYKESFANAGWTSFDVVSQLMMEDILRVGVT...,F,26,W,F26W,M35L,M,35.0,L,SFNTVDEWLEAIKMGQYKESFANAGFTSFDVVSQMMMEDILRVGVT...,True,SFNTVDEWLEAIKMGQYKESFANAGFTSFDVVSQMMMEDILRVGVT...
1,120084,1f0m,2.231979,Y17F:H50F,SFNTVDEWLEAIKMGQFKESFANAGFTSFDVVSQMMMEDILRVGVT...,Y,17,F,Y17F,H50F,H,50.0,F,SFNTVDEWLEAIKMGQYKESFANAGFTSFDVVSQMMMEDILRVGVT...,False,SFNTVDEWLEAIKMGQYKESFANAGFTSFDVVSQMMMEDILRVGVT...
2,120923,1k1v,-0.012141,S10T:E13Q,LTDEELVTMTVRQLNQHLRGLSKEEIIQLKQRRRTLKNRGY,S,10,T,S10T,E13Q,E,13.0,Q,LTDEELVTMSVRELNQHLRGLSKEEIIQLKQRRRTLKNRGY,True,LTDEELVTMSVRELNQHLRGLSKEEIIQLKQRRRTLKNRGY
3,121126,1k1v,0.412506,S10P:E13L,LTDEELVTMPVRLLNQHLRGLSKEEIIQLKQRRRTLKNRGY,S,10,P,S10P,E13L,E,13.0,L,LTDEELVTMSVRELNQHLRGLSKEEIIQLKQRRRTLKNRGY,False,LTDEELVTMSVRELNQHLRGLSKEEIIQLKQRRRTLKNRGY
4,121714,1lp1,-0.220713,Q23C:Q52C,KFNKELSVAGREIVTLPNLNDPCKKAFIFSLWDDPSQSANLLAEAK...,Q,23,C,Q23C,Q52C,Q,52.0,C,KFNKELSVAGREIVTLPNLNDPQKKAFIFSLWDDPSQSANLLAEAK...,True,KFNKELSVAGREIVTLPNLNDPQKKAFIFSLWDDPSQSANLLAEAK...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6017,217749,5z2s,-0.106517,Q6L:T32C,AVTGSLTALLLRAFEKDRFPGIAAREELARECGLPESRIQIWFQNRR,Q,6,L,Q6L,T32C,T,32.0,C,AVTGSQTALLLRAFEKDRFPGIAAREELARETGLPESRIQIWFQNRR,True,AVTGSQTALLLRAFEKDRFPGIAAREELARETGLPESRIQIWFQNRR
6018,217779,5z2s,1.814489,Q6Y:T32S,AVTGSYTALLLRAFEKDRFPGIAAREELARESGLPESRIQIWFQNRR,Q,6,Y,Q6Y,T32S,T,32.0,S,AVTGSQTALLLRAFEKDRFPGIAAREELARETGLPESRIQIWFQNRR,False,AVTGSQTALLLRAFEKDRFPGIAAREELARETGLPESRIQIWFQNRR
6019,217776,5z2s,2.074259,Q6Y:T32H,AVTGSYTALLLRAFEKDRFPGIAAREELAREHGLPESRIQIWFQNRR,Q,6,Y,Q6Y,T32H,T,32.0,H,AVTGSQTALLLRAFEKDRFPGIAAREELARETGLPESRIQIWFQNRR,False,AVTGSQTALLLRAFEKDRFPGIAAREELARETGLPESRIQIWFQNRR
6020,217989,6ews,-0.757058,K1C:Y44C,CNAAQIVDEALNQGITLFVADNRLQYETSRDNIPEELLNEWKYCRQ...,K,1,C,K1C,Y44C,Y,44.0,C,KNAAQIVDEALNQGITLFVADNRLQYETSRDNIPEELLNEWKYYRQ...,True,KNAAQIVDEALNQGITLFVADNRLQYETSRDNIPEELLNEWKYYRQ...


In [4]:
# 2. 提取唯一蛋白质及其序列
unique_proteins = df[['pdb_id', 'wt_seq']].drop_duplicates()
proteins_list = unique_proteins.to_dict('records')
proteins_list

[{'pdb_id': '1f0m',
  'wt_seq': 'SFNTVDEWLEAIKMGQYKESFANAGFTSFDVVSQMMMEDILRVGVTLAGHQKKILNSIQVMRAQMN'},
 {'pdb_id': '1k1v', 'wt_seq': 'LTDEELVTMSVRELNQHLRGLSKEEIIQLKQRRRTLKNRGY'},
 {'pdb_id': '1lp1',
  'wt_seq': 'KFNKELSVAGREIVTLPNLNDPQKKAFIFSLWDDPSQSANLLAEAKKLNDAQAPK'},
 {'pdb_id': '1qp2',
  'wt_seq': 'MVQRGSKVRILRPESYWFQDVGTVASVDQSGIKYPVIVRFEKVNYSGINTNNFAEDELVEVEA'},
 {'pdb_id': '1r69',
  'wt_seq': 'SISSRVKSKRIQLGLNQAELAQKVGTTQQSIEQLENGKTKRPRFLPELASALGVSVDWLLN'},
 {'pdb_id': '1tg0',
  'wt_seq': 'EVPFKVVAQFPYKSDYEDDLNFEKDQEIIVTSVEDAEWYFGEYQDSNGDVIEGIFPKSFVAVQG'},
 {'pdb_id': '1v1c',
  'wt_seq': 'FDIYVVTADYLPLGAEQDAITLREGQYVEVLDAAHPLRWLVRTKPTKSSPSRQGWVSPAYLDRRL'},
 {'pdb_id': '1wcl',
  'wt_seq': 'EAHAAIDTFTKYLDIDEDFATVLVEEGFSTLEELAYVPMKELLEIEGLDEPTVEALRERAKNALATIAQ'},
 {'pdb_id': '1yu5',
  'wt_seq': 'KLETFPLDVLVNTAAEDLPRGVDPSRKENHLSDEDFKAVFGMTRSAFANLPLWKQQNLKKEKGLF'},
 {'pdb_id': '2btt',
  'wt_seq': 'KDPKFEAAYDFPGSGSSSELPLKKGDIVFISRDEPSGWSLAKLLDGSKEGWVPTAYMTPYK'},
 {'pdb_id': '2d1u',
  

In [5]:
# 3. 按序列相似度分组
protein_groups = group_proteins_by_similarity(proteins_list, threshold=30)
len(protein_groups)

7

In [6]:
# 4. 划分五折
fold_assignments = split_proteins_into_folds(protein_groups, n_splits=5)
len(fold_assignments)

5

In [7]:
# 5. 准备各折数据并保存
output_dir = './dataset/cdna/mutations/fold_data'
fold_stats = prepare_fold_data(fold_assignments, df, output_dir)

# 打印各折统计信息
for stat in fold_stats:
    print(f"Fold {stat['fold']}: Train size={stat['train_size']}, Train PDBs={stat['train_pdbs']}, Val size={stat['val_size']}, Val PDBs={stat['val_pdbs']}")

Fold 0: Train size=5390, Train PDBs=63, Val size=544, Val PDBs=11
Fold 1: Train size=862, Train PDBs=15, Val size=5072, Val PDBs=59
Fold 2: Train size=5910, Train PDBs=72, Val size=24, Val PDBs=2
Fold 3: Train size=5850, Train PDBs=73, Val size=84, Val PDBs=1
Fold 4: Train size=5724, Train PDBs=73, Val size=210, Val PDBs=1
