# 1. Define functions

# In [11]:
import pandas as pd
import os
import numpy as np
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# 基础配置
# ============================================================================
MIN_SEQUENCE_LENGTH = 5   # shortest sequence length (inclusive) kept by the length filter
MAX_SEQUENCE_LENGTH = 50  # longest sequence length (inclusive) kept by the length filter

# ============================================================================
# General utility functions
# ============================================================================
def filter_peptides_detailed(peptides_data, sequence_col='Sequence', min_length=MIN_SEQUENCE_LENGTH, max_length=MAX_SEQUENCE_LENGTH):
    """Filter peptide sequences in three steps and report per-step statistics.

    Steps, applied in order:
      1. Keep rows whose sequence is a string made only of the 20 natural
         amino-acid letters (upper or lower case). Missing or non-string
         values are counted as non-natural instead of raising.
      2. Keep rows whose length lies in ``[min_length, max_length]``; rows
         that are too short are counted first, too-long rows second.
      3. Drop case-insensitive duplicates, keeping the first occurrence.

    Surviving sequences are upper-cased. The input frame is not modified and
    the surviving rows keep their original index labels.

    Args:
        peptides_data: DataFrame containing a column named ``sequence_col``.
        sequence_col: Name of the column holding the sequences.
        min_length: Minimum accepted sequence length (inclusive).
        max_length: Maximum accepted sequence length (inclusive).

    Returns:
        ``(filtered_data, stats)`` where ``stats`` is a dict of plain ``int``
        counters: ``step0_original``, ``step1_after_natural_aa``,
        ``step2_after_length``, ``step3_final``, ``removed_non_natural_aa``,
        ``removed_too_short``, ``removed_too_long``, ``removed_duplicates``.
    """
    natural_aa = frozenset('ACDEFGHIKLMNPQRSTVWYacdefghiklmnpqrstvwy')

    # Initialise the statistics
    stats = {
        'step0_original': len(peptides_data),
        'step1_after_natural_aa': 0,
        'step2_after_length': 0,
        'step3_final': 0,
        'removed_non_natural_aa': 0,
        'removed_too_short': 0,
        'removed_too_long': 0,
        'removed_duplicates': 0
    }

    print(f"         {'┄' * 60}")
    print(f"         🔍 Detailed filtering process:")
    print(f"         {'┄' * 60}")
    print(f"            Step 0 - Original sequences: {stats['step0_original']:,}")

    # Step 1: natural amino acids only. The mask is an explicit bool array so
    # that an empty input still selects rows (not columns), and NaN / numeric
    # cells are rejected rather than iterated.
    natural_mask = np.array(
        [isinstance(seq, str) and natural_aa.issuperset(seq) for seq in peptides_data[sequence_col]],
        dtype=bool,
    )
    stats['removed_non_natural_aa'] = int(len(peptides_data) - natural_mask.sum())
    filtered_data = peptides_data.loc[natural_mask].copy()
    stats['step1_after_natural_aa'] = len(filtered_data)
    print(f"            Step 1 - After natural AA filter: {stats['step1_after_natural_aa']:,} (removed: {stats['removed_non_natural_aa']:,})")

    # Step 2: length window. Lengths live in a side array rather than a
    # temporary column, so caller columns are never overwritten.
    lengths = np.array([len(seq) for seq in filtered_data[sequence_col]], dtype=int)
    long_enough = lengths >= min_length
    short_enough = lengths <= max_length
    stats['removed_too_short'] = int((~long_enough).sum())
    # Only rows that survived the "too short" check can count as too long.
    stats['removed_too_long'] = int((long_enough & ~short_enough).sum())
    filtered_data = filtered_data.loc[long_enough & short_enough]

    stats['step2_after_length'] = len(filtered_data)
    print(f"            Step 2 - After length filter ({min_length}-{max_length}): {stats['step2_after_length']:,}")
    print(f"                   ├─ Too short (<{min_length}): {stats['removed_too_short']:,}")
    print(f"                   └─ Too long (>{max_length}): {stats['removed_too_long']:,}")

    # Step 3: case-insensitive de-duplication, first occurrence wins
    upper_sequences = filtered_data[sequence_col].map(str.upper)
    first_occurrence = ~upper_sequences.duplicated(keep='first').to_numpy()
    stats['removed_duplicates'] = int((~first_occurrence).sum())
    filtered_data = filtered_data.loc[first_occurrence].copy()
    stats['step3_final'] = len(filtered_data)
    print(f"            Step 3 - After deduplication: {stats['step3_final']:,} (removed duplicates: {stats['removed_duplicates']:,})")

    # Store the upper-cased sequences (positional assignment avoids index alignment)
    filtered_data[sequence_col] = upper_sequences.to_numpy()[first_occurrence]

    # Sanity check: original minus everything removed must equal the final count
    total_removed = stats['removed_non_natural_aa'] + stats['removed_too_short'] + stats['removed_too_long'] + stats['removed_duplicates']
    expected_final = stats['step0_original'] - total_removed
    print(f"            ✅ Verification: Expected {expected_final:,}, Got {stats['step3_final']:,}")
    print(f"         {'┄' * 60}")

    return filtered_data, stats

def parse_fasta_to_df(fasta_file, verbose=True):
    """Parse a FASTA file into a DataFrame with columns ``Id`` and ``Sequence``.

    The file is decoded strictly, trying ``utf-8-sig`` (UTF-8 with an optional
    BOM), then ``cp1252``, then ``latin-1``. ``latin-1`` maps every byte, so
    the last attempt cannot fail on decoding. Nothing is silently dropped:
    a byte that is invalid UTF-8 triggers the next encoding instead.

    Each header line (starting with ``>``) opens a record whose ``Id`` is the
    header text without the ``>``; the following non-blank lines are
    concatenated into ``Sequence``. Records with an empty header or no
    sequence lines are skipped.

    Args:
        fasta_file: Path of the FASTA file to read.
        verbose: When true, print the file name and a short summary.

    Returns:
        DataFrame with one row per record, or an empty DataFrame with the
        columns ``Id`` and ``Sequence`` when no record was found.
    """
    if verbose:
        print(f"         📄 Parsing FASTA file: {os.path.basename(fasta_file)}")

    # Most specific encoding first; the loop only advances on a real decode error.
    content = None
    for encoding in ('utf-8-sig', 'cp1252', 'latin-1'):
        try:
            with open(fasta_file, 'r', encoding=encoding) as file:
                content = file.read()
            break
        except (UnicodeDecodeError, UnicodeError):
            continue

    records = []
    if content is not None:
        current_id = None
        current_seq = []

        for line in content.split('\n'):
            line = line.strip()
            if not line:
                continue

            if line.startswith('>'):
                # A new header closes the record collected so far.
                if current_id and current_seq:
                    records.append({'Id': current_id, 'Sequence': ''.join(current_seq)})
                current_id = line[1:]
                current_seq = []
            else:
                current_seq.append(line)

        # Flush the last record, which has no following header to close it.
        if current_id and current_seq:
            records.append({'Id': current_id, 'Sequence': ''.join(current_seq)})

    if verbose and records:
        lengths = [len(rec['Sequence']) for rec in records]
        print(f"            ✓ {len(records)} records, length range: {min(lengths)}-{max(lengths)}")

    return pd.DataFrame(records) if records else pd.DataFrame(columns=['Id', 'Sequence'])

def process_dataset_detailed(file_path, sequence_col, id_col=None, source_name="", file_type="csv"):
    """Load one raw dataset, filter it, and normalise its columns.

    The file is read according to ``file_type`` (``"csv"``, ``"excel"`` or
    ``"fasta"``), passed through ``filter_peptides_detailed`` and reduced to
    the columns ``Id``, ``Sequence``, ``Source`` and ``Length``. When
    ``id_col`` is not given or not present, ids of the form
    ``"<source_name>_<n>"`` (n starting at 1) are generated.

    Returns:
        ``(filtered_data, stats)`` on success. ``(None, None)`` when the file
        does not exist, the sequence column is missing, or any exception is
        raised while processing (the error is printed, not re-raised).
    """
    if not os.path.exists(file_path):
        print(f"      ❌ File {file_path} does not exist")
        return None, None

    rule = '─' * 70

    try:
        print(f"      {rule}")
        print(f"      📊 Processing: {source_name}")
        print(f"      {rule}")

        # Load the raw table
        if file_type == "fasta":
            raw = parse_fasta_to_df(file_path, verbose=True)
        elif file_type == "excel":
            raw = pd.read_excel(file_path)
        elif file_type == "csv":
            raw = pd.read_csv(file_path)
        else:
            raise ValueError("Unsupported file type")

        if sequence_col not in raw.columns:
            print(f"      ❌ Sequence column '{sequence_col}' not found")
            return None, None

        # Run the detailed filter
        kept, stats = filter_peptides_detailed(raw, sequence_col=sequence_col)

        # Normalise to the columns Id / Sequence
        if not id_col or id_col not in kept.columns:
            # No usable id column: number the rows per source
            tidy = kept[[sequence_col]].copy()
            tidy = tidy.rename(columns={sequence_col: 'Sequence'})
            tidy['Id'] = [f"{source_name}_{n}" for n in range(1, len(tidy) + 1)]
            tidy = tidy[['Id', 'Sequence']]
        else:
            tidy = kept[[id_col, sequence_col]].copy()
            tidy = tidy.rename(columns={id_col: 'Id', sequence_col: 'Sequence'})

        tidy['Source'] = source_name
        tidy['Length'] = tidy['Sequence'].apply(len)

        print(f"      ✅ {source_name} completed: {stats['step0_original']:,} → {stats['step3_final']:,} sequences")
        print(f"      {rule}")

        return tidy, stats

    except Exception as e:
        # Per-dataset boundary: report the failure and let the caller continue
        print(f"      ❌ Error processing {source_name}: {str(e)}")
        print(f"      {rule}")
        return None, None

def merge_datasets_with_priority_detailed(datasets_list, priority_order, dataset_type=""):
    """Merge datasets and drop duplicate sequences, preferring earlier sources.

    ``None`` entries in ``datasets_list`` are ignored. The remaining frames
    are concatenated, ordered by the position of their ``Source`` value in
    ``priority_order`` and de-duplicated on ``Sequence`` with the first
    (highest-priority) row kept. Sources missing from ``priority_order`` get
    no priority and sort last. The sort is stable, so rows of the same source
    keep the order they had in their input frame.

    Args:
        datasets_list: Iterable of DataFrames (or ``None``) that have at
            least the columns ``Sequence`` and ``Source``.
        priority_order: Source names, highest priority first.
        dataset_type: Label used only in the printed progress messages.

    Returns:
        The merged DataFrame with a fresh 0..n-1 index and a recomputed
        ``Length`` column, or an empty DataFrame with the columns ``Id``,
        ``Sequence``, ``Source`` and ``Length`` when nothing valid was passed.
    """
    valid_datasets = [ds for ds in datasets_list if ds is not None]

    if not valid_datasets:
        return pd.DataFrame(columns=['Id', 'Sequence', 'Source', 'Length'])

    print(f"\n   {'─' * 70}")
    print(f"   🔗 Merging {dataset_type} datasets")
    print(f"   {'─' * 70}")

    # Totals before merging
    total_before = sum(len(ds) for ds in valid_datasets)
    print(f"      📈 Total sequences before merge: {total_before:,}")

    # Contribution of each dataset
    for i, ds in enumerate(valid_datasets):
        source = ds['Source'].iloc[0] if len(ds) > 0 else f"Dataset_{i+1}"
        print(f"         └─ {source}: {len(ds):,} sequences")

    combined_df = pd.concat(valid_datasets, ignore_index=True)
    combined_df['Length'] = combined_df['Sequence'].apply(len)

    # Order by source priority. A stable sort is required: the default
    # quicksort may reorder rows that share the same priority.
    source_priority = {source: rank for rank, source in enumerate(priority_order, start=1)}
    combined_df['Priority'] = combined_df['Source'].map(source_priority)
    combined_df = combined_df.sort_values('Priority', kind='mergesort')

    # De-duplicate, keeping the highest-priority occurrence
    before_final_dedup = len(combined_df)
    final_dataset = combined_df.drop_duplicates(subset=['Sequence'], keep='first')
    final_dataset = final_dataset.drop('Priority', axis=1).reset_index(drop=True)

    final_dedup_removed = before_final_dedup - len(final_dataset)
    print(f"      📉 After final deduplication: {len(final_dataset):,} sequences")
    print(f"      🗑️ Final duplicates removed: {final_dedup_removed:,}")
    print(f"   {'─' * 70}")

    return final_dataset

def save_datasets_to_csv(datasets_dict, output_dir, log_dir="2_Log/2.1_Training set_and_test_set_processing"):
    """Write the processed datasets to CSV files and record a JSON log.

    Only the keys ``external_avp``, ``external_non_avp``, ``internal_avp``
    and ``internal_non_avp`` are looked up in ``datasets_dict``; each is
    written to a fixed file name inside ``output_dir``. Missing, ``None`` or
    empty datasets are reported and skipped. Saved columns are restricted to
    ``Id, Sequence, Source, Length, Label, Type`` (in that order), keeping
    only those present in the dataset.

    Args:
        datasets_dict: Mapping from dataset key to DataFrame (or ``None``).
        output_dir: Directory for the CSV files; created if needed.
        log_dir: Directory for ``step1_dataset_processing_log.json``; created
            if needed. Defaults to the location the pipeline has always used.

    Returns:
        Dict mapping each saved dataset key to the path of its CSV file.
    """
    import json  # local import: only the log writer below needs it

    print(f"\n💾 SAVING PROCESSED DATASETS:")
    print("=" * 80)

    # Make sure both target directories exist
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # Dataset key -> output file name
    filename_mapping = {
        'external_avp': 'Initial_TR_AVP.csv',
        'external_non_avp': 'Initial_TR_non_AVP.csv',
        'internal_avp': 'Initial_TS_AVP.csv',
        'internal_non_avp': 'Initial_TS_non_AVP.csv'
    }
    # Preferred column order in the saved files
    columns_order = ['Id', 'Sequence', 'Source', 'Length', 'Label', 'Type']

    saved_files = {}
    processing_log = []

    for key, filename in filename_mapping.items():
        title = key.replace('_', ' ').title()
        dataset = datasets_dict.get(key)

        if dataset is None:
            print(f"   ❌ {title}: Dataset not found")
            continue
        if len(dataset) == 0:
            print(f"   ⚠️  {title}: No data to save")
            continue

        # Keep only the known columns that are actually present
        available_columns = [col for col in columns_order if col in dataset.columns]
        dataset_to_save = dataset[available_columns].copy()

        file_path = os.path.join(output_dir, filename)
        dataset_to_save.to_csv(file_path, index=False)
        saved_files[key] = file_path

        processing_log.append({
            'dataset': title,
            'sequences': len(dataset),
            'file_path': file_path,
            'columns': list(dataset_to_save.columns)
        })

        print(f"   ✅ {title}: {len(dataset):,} sequences")
        print(f"      └─ Saved to: {file_path}")
        print(f"      └─ Columns: {list(dataset_to_save.columns)}")

    # The log is only written when at least one dataset was saved
    if processing_log:
        log_data = {
            'step': 'Step 1 - Dataset Processing',
            'total_files_saved': len(saved_files),
            'datasets': processing_log
        }

        log_file = os.path.join(log_dir, "step1_dataset_processing_log.json")
        with open(log_file, 'w', encoding='utf-8') as f:
            json.dump(log_data, f, indent=2)
        print(f"\n   📄 Processing log saved to: {log_file}")

    print(f"\n   📁 All files saved to: {output_dir}")
    print(f"   📊 Total files saved: {len(saved_files)}")

    return saved_files

# ============================================================================
# 第一步：处理内部和外部数据集（带详细统计）
# ============================================================================

def step1_process_datasets_detailed():
    """第一步：详细处理并统计内部和外部数据集"""
    print("🚀 STEP 1: DETAILED PROCESSING OF INTERNAL AND EXTERNAL DATASETS")
    print("=" * 80)
    
    # ✅ 添加：创建日志目录
    log_dir = "2_Log/2.1_Training set_and_test_set_processing"
    os.makedirs(log_dir, exist_ok=True)
    
    # ========== 处理内部AVP数据集 ==========
    print("\n📊 PROCESSING INTERNAL AVP DATASETS")
    print("=" * 80)
    
    # 配置所有数据源
    avp_configs = [
        {'file_path': '1_Data/Raw_data/AVPdataset/dravp_antiviral_peptides.xlsx', 'sequence_col': 'Sequence', 'id_col': 'DRAVP_ID', 'source_name': 'DRAVP', 'file_type': 'excel'},
        {'file_path': '1_Data/Raw_data/AVPdataset/AVPdb_data.xls', 'sequence_col': 'Sequence', 'id_col': 'Id', 'source_name': 'AVPdb', 'file_type': 'excel'},
        {'file_path': '1_Data/Raw_data/AVPdataset/ACovPepDB_Data_Entirety.csv', 'sequence_col': 'Sequence', 'id_col': 'ACovPid', 'source_name': 'ACovPepDB', 'file_type': 'csv'},
        {'file_path': '1_Data/Raw_data/AVPdataset/HIPdb_data.xls', 'sequence_col': 'SEQUENCE', 'id_col': 'ID', 'source_name': 'HIPdb', 'file_type': 'excel'}
    ]
    
    amp_configs = [
        {'file_path': '1_Data/Raw_data/AMPdataset/CAMP.xlsx', 'sequence_col': 'Seqence', 'id_col': 'Camp_ID', 'source_name': 'CAMP', 'file_type': 'excel'},
        {'file_path': '1_Data/Raw_data/AMPdataset/dbaasp_peptides.xlsx', 'sequence_col': 'SEQUENCE', 'id_col': 'ID', 'source_name': 'DBAASP', 'file_type': 'excel'},
        {'file_path': '1_Data/Raw_data/AMPdataset/dramp_general_avps.xlsx', 'sequence_col': 'Sequence', 'id_col': 'DRAMP_ID', 'source_name': 'DRAMP', 'file_type': 'excel'},
        {'file_path': '1_Data/Raw_data/AMPdataset/dbAMP_AntiHIV_2024.fasta', 'sequence_col': 'Sequence', 'id_col': 'Id', 'source_name': 'dbAMP_AntiHIV', 'file_type': 'fasta'},
        {'file_path': '1_Data/Raw_data/AMPdataset/dbAMP_Antiviral_2024.fasta', 'sequence_col': 'Sequence', 'id_col': 'Id', 'source_name': 'dbAMP_Antiviral', 'file_type': 'fasta'}
    ]
    
    other_configs = [
        {'file_path': '1_Data/Raw_data/Peptipedia_Antiviral.fasta', 'sequence_col': 'Sequence', 'id_col': 'Id', 'source_name': 'Peptipedia', 'file_type': 'fasta'}
    ]
    
    # 处理所有内部AVP数据源
    all_internal_avp_datasets = []
    internal_avp_detailed_stats = {}
    
    print("   🔹 Processing individual internal AVP datasets:")
    for config in avp_configs + amp_configs + other_configs:
        result = process_dataset_detailed(**config)
        if result and result[0] is not None:
            dataset, stats = result
            all_internal_avp_datasets.append(dataset)
            internal_avp_detailed_stats[config['source_name']] = stats
    
    # 合并内部AVP数据集
    priority_order = ['DRAVP', 'AVPdb', 'ACovPepDB', 'HIPdb', 'CAMP', 'DBAASP', 'DRAMP', 'dbAMP_AntiHIV', 'dbAMP_Antiviral', 'Peptipedia']
    internal_avp_dataset = merge_datasets_with_priority_detailed(all_internal_avp_datasets, priority_order, "internal AVP")
    internal_avp_dataset['Label'] = 1
    internal_avp_dataset['Type'] = 'AVP'
    
    # ========== 处理内部non_AVP数据集 ==========
    print(f"\n📊 PROCESSING INTERNAL non_AVP DATASETS")
    print("=" * 80)

    file_path = '1_Data/Raw_data/non_AVP/uniprotkb_13124.fasta'
    print("   🔹 Processing UniProt non_AVP dataset:")

    result = process_dataset_detailed(file_path, 'Sequence', 'Id', 'UniProt', 'fasta')
    if result[0] is not None:
        internal_non_avp_dataset, non_avp_detailed_stats = result
        
        # 处理UniProt ID格式
        def process_uniprot_id(uniprot_id):
            """
            处理UniProt ID，只保留前两个竖杠的部分
            例如: sp|A5A616|YXXX_HUMAN -> sp|A5A616|
            """
            if pd.isna(uniprot_id):
                return uniprot_id
            
            id_str = str(uniprot_id).strip()
            
            # 查找所有竖杠的位置
            pipe_positions = [i for i, char in enumerate(id_str) if char == '|']
            
            if len(pipe_positions) >= 2:
                # 保留到第二个竖杠之后（包含第二个竖杠）
                return id_str[:pipe_positions[1] + 1]
            elif len(pipe_positions) == 1:
                # 如果只有一个竖杠，保留到第一个竖杠之后
                return id_str[:pipe_positions[0] + 1]
            else:
                # 如果没有竖杠，返回原ID
                return id_str
        
        # 应用ID处理
        original_ids = internal_non_avp_dataset['Id'].copy()
        internal_non_avp_dataset['Id'] = internal_non_avp_dataset['Id'].apply(process_uniprot_id)
        
        # 显示ID处理示例
        print(f"      📝 ID format examples:")
        for i in range(min(5, len(original_ids))):
            old_id = original_ids.iloc[i]
            new_id = internal_non_avp_dataset['Id'].iloc[i]
            print(f"         {old_id} → {new_id}")
        
        print(f"      ✅ Processed {len(internal_non_avp_dataset)} UniProt IDs to short format")
        
        internal_non_avp_dataset['Label'] = 0
        internal_non_avp_dataset['Type'] = 'non_AVP'
    else:
        print("      ❌ Failed to process non_AVP dataset")
        return None
    
    # ========== 处理外部数据集 ==========
    print(f"\n📊 PROCESSING EXTERNAL DATASETS")
    print("=" * 80)
    
    external_files = {
        'AVP': {
            'Stack-AVP-TR_pos': '1_Data/Raw_data/External_dataset/Stack-AVP-TR_pos.fasta'
        },
        'non_AVP': {
            'Stack-AVP-TR_neg': '1_Data/Raw_data/External_dataset/Stack-AVP-TR_neg.fasta'
        },
        'mixed': {
            'AVP-HNCL_non-AMP': '1_Data/Raw_data/External_dataset/AVP-HNCL_non-AMP TR dataset.txt',
            'AVP-HNCL_non-AVP': '1_Data/Raw_data/External_dataset/AVP-HNCL_non-AVP TR dataset.txt'
        }
    }
    
    external_individual_datasets = {'AVP': [], 'non_AVP': []}
    external_detailed_stats = {}
    
    # 处理纯AVP文件
    print("   🔹 Processing external AVP datasets:")
    for name, file_path in external_files['AVP'].items():
        if not os.path.exists(file_path):
            print(f"      ⚠️ File not found: {name}")
            continue
        
        try:
            result = process_dataset_detailed(file_path, 'Sequence', 'Id', name, 'fasta')
            if result and result[0] is not None:
                filtered_df, stats = result
                filtered_df['Label'] = 1  # AVP
                external_individual_datasets['AVP'].append(filtered_df)
                external_detailed_stats[name] = {**stats, 'type': 'AVP'}
            
        except Exception as e:
            print(f"      ❌ Error processing {name}: {str(e)}")
            continue
    
    # 处理纯non_AVP文件
    print("   🔹 Processing external non_AVP datasets:")
    for name, file_path in external_files['non_AVP'].items():
        if not os.path.exists(file_path):
            print(f"      ⚠️ File not found: {name}")
            continue
        
        try:
            result = process_dataset_detailed(file_path, 'Sequence', 'Id', name, 'fasta')
            if result and result[0] is not None:
                filtered_df, stats = result
                filtered_df['Label'] = 0  # non_AVP
                external_individual_datasets['non_AVP'].append(filtered_df)
                external_detailed_stats[name] = {**stats, 'type': 'non_AVP'}
            
        except Exception as e:
            print(f"      ❌ Error processing {name}: {str(e)}")
            continue
    
    # 处理混合标签的txt文件
    print("   🔹 Processing mixed-label datasets:")
    for name, file_path in external_files['mixed'].items():
        if not os.path.exists(file_path):
            print(f"      ⚠️ File not found: {name}")
            continue
        
        try:
            sequences_data = []
            encodings = ['utf-8', 'latin-1', 'cp1252', 'gbk']
            
            print(f"      {'─' * 70}")
            print(f"      📊 Processing mixed-label dataset: {name}")
            print(f"      {'─' * 70}")
            
            for encoding in encodings:
                try:
                    with open(file_path, 'r', encoding=encoding) as f:
                        content = f.read()
                    
                    # 首先检查是否是FASTA格式
                    if '>' in content:
                        # FASTA格式处理
                        current_id = None
                        current_seq = []
                        
                        for line in content.split('\n'):
                            line = line.strip()
                            if not line:
                                continue
                                
                            if line.startswith('>'):
                                # 保存前一个序列
                                if current_id and current_seq:
                                    sequence = ''.join(current_seq)
                                    # 从ID中推断标签
                                    header_lower = current_id.lower()
                                    if any(pos_indicator in header_lower for pos_indicator in ['pos']):
                                        label = 1  # AVP
                                    elif any(neg_indicator in header_lower for neg_indicator in ['neg']):
                                        label = 0  # non_AVP
                                    else:
                                        # 根据文件名推断
                                        if 'non-AMP' in name or 'non-AVP' in name:
                                            label = 0  # non_AVP
                                        else:
                                            label = 1  # AVP (默认)
                                    
                                    sequences_data.append({
                                        'Id': current_id,
                                        'Sequence': sequence,
                                        'Label': label
                                    })
                                
                                current_id = line[1:]  # 去掉 '>'
                                current_seq = []
                            else:
                                current_seq.append(line)
                        
                        # 保存最后一个序列
                        if current_id and current_seq:
                            sequence = ''.join(current_seq)
                            # 从ID中推断标签
                            header_lower = current_id.lower()
                            if any(pos_indicator in header_lower for pos_indicator in ['pos']):
                                label = 1  # AVP
                            elif any(neg_indicator in header_lower for neg_indicator in ['neg']):
                                label = 0  # non_AVP
                            else:
                                # 根据文件名推断
                                if 'non-AMP' in name or 'non-AVP' in name:
                                    label = 0  # non_AVP
                                else:
                                    label = 1  # AVP (默认)
                            
                            sequences_data.append({
                                'Id': current_id,
                                'Sequence': sequence,
                                'Label': label
                            })
                    
                    else:
                        # 非FASTA格式，按行处理
                        for line_num, line in enumerate(content.split('\n'), 1):
                            line = line.strip()
                            if not line or line.startswith('#'):
                                continue
                            
                            # 尝试不同的分隔符
                            parts = None
                            for sep in ['\t', ' ', ',', ';']:
                                if sep in line:
                                    parts = [p.strip() for p in line.split(sep) if p.strip()]
                                    break
                            
                            if parts is None:
                                parts = [line.strip()]
                            
                            if len(parts) == 0:
                                continue
                            
                            # 解析序列和标签
                            sequence = None
                            label = None
                            
                            for part in parts:
                                # 检查是否为标签
                                if part.lower() in ['pos', 'positive', '1']:
                                    label = 1  # AVP
                                elif part.lower() in ['neg', 'negative', '0']:
                                    label = 0  # non_AVP
                                else:
                                    # 检查是否为序列（包含氨基酸字符）
                                    if len(part) > 3 and all(c.upper() in 'ACDEFGHIKLMNPQRSTVWY' for c in part if c.isalpha()):
                                        sequence = part
                            
                            # 如果没有找到明确的序列，使用第一个非标签部分
                            if sequence is None:
                                for part in parts:
                                    if part.lower() not in ['pos', 'negative', 'positive', 'neg', '1', '0']:
                                        if len(part) > 3:  # 序列长度至少为4
                                            sequence = part
                                            break
                            
                            # 如果没有找到标签，根据文件名推断
                            if label is None:
                                if 'non-AMP' in name or 'non-AVP' in name:
                                    label = 0  # non_AVP
                                else:
                                    label = 1  # AVP (默认)
                            
                            if sequence:
                                sequences_data.append({
                                    'Id': f"{name}_{line_num}",
                                    'Sequence': sequence,
                                    'Label': label
                                })
                    
                    break
                except UnicodeDecodeError:
                    continue
            
            if sequences_data:
                df = pd.DataFrame(sequences_data)
                original_count = len(df)
                
                # 统计原始标签分布
                original_avp_count = len(df[df['Label'] == 1])
                original_non_avp_count = len(df[df['Label'] == 0])
                
                print(f"         📊 Original data: {original_count} total sequences")
                print(f"            ├─ AVP (positive): {original_avp_count}")
                print(f"            └─ non_AVP (negative): {original_non_avp_count}")
                print(f"         {'┄' * 50}")
                
                # 应用详细过滤
                filtered_df, detailed_stats = filter_peptides_detailed(df, sequence_col='Sequence')
                
                if len(filtered_df) > 0:
                    filtered_df['Source'] = name
                    filtered_df['Length'] = filtered_df['Sequence'].apply(len)
                    
                    # 按标签分组
                    avp_df = filtered_df[filtered_df['Label'] == 1].copy()
                    non_avp_df = filtered_df[filtered_df['Label'] == 0].copy()
                    
                    print(f"         {'┄' * 50}")
                    print(f"         📊 After all filtering: {len(filtered_df)} total sequences")
                    print(f"            ├─ AVP: {len(avp_df)} (lost: {original_avp_count - len(avp_df)})")
                    print(f"            └─ non_AVP: {len(non_avp_df)} (lost: {original_non_avp_count - len(non_avp_df)})")
                    
                    if len(avp_df) > 0:
                        external_individual_datasets['AVP'].append(avp_df)
                    if len(non_avp_df) > 0:
                        external_individual_datasets['non_AVP'].append(non_avp_df)
                    
                    external_detailed_stats[name] = {
                        **detailed_stats,
                        'original_avp_count': original_avp_count,
                        'original_non_avp_count': original_non_avp_count,
                        'filtered_avp_count': len(avp_df),
                        'filtered_non_avp_count': len(non_avp_df)
                    }
                else:
                    print(f"         ⚠️ No sequences passed filtering")
            else:
                print(f"         ⚠️ No sequences found in file")
            
            print(f"      {'─' * 70}")
            
        except Exception as e:
            print(f"      ❌ Error processing {name}: {str(e)}")
            print(f"      {'─' * 70}")
            continue
    
    # 合并外部数据集
    external_avp_dataset = merge_datasets_with_priority_detailed(
        external_individual_datasets['AVP'], 
        ['Stack-AVP-TR_pos', 'AVP-HNCL_non-AMP', 'AVP-HNCL_non-AVP'], 
        "external AVP"
    )
    if len(external_avp_dataset) > 0:
        external_avp_dataset['Type'] = 'AVP'
    
    external_non_avp_dataset = merge_datasets_with_priority_detailed(
        external_individual_datasets['non_AVP'],
        ['Stack-AVP-TR_neg', 'AVP-HNCL_non-AMP', 'AVP-HNCL_non-AVP'],
        "external non_AVP"
    )
    if len(external_non_avp_dataset) > 0:
        external_non_avp_dataset['Type'] = 'non_AVP'
    
    # ========== 保存处理后的数据集 ==========
    datasets_to_save = {
        'internal_avp': internal_avp_dataset,
        'internal_non_avp': internal_non_avp_dataset,
        'external_avp': external_avp_dataset,
        'external_non_avp': external_non_avp_dataset
    }
    
    output_dir = "1_Data/Processed_data_set/Initial_merged_data_set"
    saved_files = save_datasets_to_csv(datasets_to_save, output_dir)
    
    # ========== 打印详细统计汇总 ==========
    print(f"\n{'═' * 80}")
    print(f"📋 COMPREHENSIVE STATISTICS SUMMARY")
    print(f"{'═' * 80}")
    
    print(f"\n🔹 INTERNAL DATASETS DETAILED BREAKDOWN:")
    print(f"{'─' * 80}")
    print(f"   📊 Individual AVP dataset filtering details:")
    for source, stats in internal_avp_detailed_stats.items():
        print(f"   {'┌' + '─' * 65 + '┐'}")
        print(f"   │  {source:<61s}  │")
        print(f"   {'├' + '─' * 65 + '┤'}")
        print(f"   │  Original: {stats['step0_original']:>8,} sequences{' ' * (65 - len(f'Original: {stats['step0_original']:,} sequences') - 2)}│")
        print(f"   │  Non-natural AA removed: {stats['removed_non_natural_aa']:>8,}{' ' * (65 - len(f'Non-natural AA removed: {stats['removed_non_natural_aa']:,}') - 2)}│")
        print(f"   │  Too short removed: {stats['removed_too_short']:>8,}{' ' * (65 - len(f'Too short removed: {stats['removed_too_short']:,}') - 2)}│")
        print(f"   │  Too long removed: {stats['removed_too_long']:>8,}{' ' * (65 - len(f'Too long removed: {stats['removed_too_long']:,}') - 2)}│")
        print(f"   │  Duplicates removed: {stats['removed_duplicates']:>8,}{' ' * (65 - len(f'Duplicates removed: {stats['removed_duplicates']:,}') - 2)}│")
        print(f"   │  Final: {stats['step3_final']:>8,} sequences{' ' * (65 - len(f'Final: {stats['step3_final']:,} sequences') - 2)}│")
        print(f"   {'└' + '─' * 65 + '┘'}")
        print()
    
    print(f"{'─' * 80}")
    print(f"   📊 Internal non_AVP (UniProt) filtering details:")
    stats = non_avp_detailed_stats
    print(f"   {'┌' + '─' * 65 + '┐'}")
    print(f"   │  {'UniProt non_AVP':<61s}  │")
    print(f"   {'├' + '─' * 65 + '┤'}")
    print(f"   │  Original: {stats['step0_original']:>8,} sequences{' ' * (65 - len(f'Original: {stats['step0_original']:,} sequences') - 2)}│")
    print(f"   │  Non-natural AA removed: {stats['removed_non_natural_aa']:>8,}{' ' * (65 - len(f'Non-natural AA removed: {stats['removed_non_natural_aa']:,}') - 2)}│")
    print(f"   │  Too short removed: {stats['removed_too_short']:>8,}{' ' * (65 - len(f'Too short removed: {stats['removed_too_short']:,}') - 2)}│")
    print(f"   │  Too long removed: {stats['removed_too_long']:>8,}{' ' * (65 - len(f'Too long removed: {stats['removed_too_long']:,}') - 2)}│")
    print(f"   │  Duplicates removed: {stats['removed_duplicates']:>8,}{' ' * (65 - len(f'Duplicates removed: {stats['removed_duplicates']:,}') - 2)}│")
    print(f"   │  Final: {stats['step3_final']:>8,} sequences{' ' * (65 - len(f'Final: {stats['step3_final']:,} sequences') - 2)}│")
    print(f"   {'└' + '─' * 65 + '┘'}")
    
    print(f"\n🔹 EXTERNAL DATASETS DETAILED BREAKDOWN:")
    print(f"{'─' * 80}")
    for source, stats in external_detailed_stats.items():
        print(f"   {'┌' + '─' * 75 + '┐'}")
        print(f"   │  📊 {source:<69s}  │")
        print(f"   {'├' + '─' * 75 + '┤'}")
        
        if 'original_avp_count' in stats:
            # 混合标签数据集
            orig_text = f"Original: {stats['step0_original']:,} (AVP: {stats['original_avp_count']}, non_AVP: {stats['original_non_avp_count']})"
            print(f"   │  {orig_text:<73s}  │")
            print(f"   │  Non-natural AA removed: {stats['removed_non_natural_aa']:>8,}{' ' * (75 - len(f'Non-natural AA removed: {stats['removed_non_natural_aa']:,}') - 2)}│")
            print(f"   │  Too short removed: {stats['removed_too_short']:>8,}{' ' * (75 - len(f'Too short removed: {stats['removed_too_short']:,}') - 2)}│")
            print(f"   │  Too long removed: {stats['removed_too_long']:>8,}{' ' * (75 - len(f'Too long removed: {stats['removed_too_long']:,}') - 2)}│")
            print(f"   │  Duplicates removed: {stats['removed_duplicates']:>8,}{' ' * (75 - len(f'Duplicates removed: {stats['removed_duplicates']:,}') - 2)}│")
            final_text = f"Final: {stats['step3_final']:,} (AVP: {stats['filtered_avp_count']}, non_AVP: {stats['filtered_non_avp_count']})"
            print(f"   │  {final_text:<73s}  │")
        else:
            # 单一标签数据集
            orig_text = f"Original: {stats['step0_original']:,} sequences ({stats.get('type', 'Unknown')})"
            print(f"   │  {orig_text:<73s}  │")
            print(f"   │  Non-natural AA removed: {stats['removed_non_natural_aa']:>8,}{' ' * (75 - len(f'Non-natural AA removed: {stats['removed_non_natural_aa']:,}') - 2)}│")
            print(f"   │  Too short removed: {stats['removed_too_short']:>8,}{' ' * (75 - len(f'Too short removed: {stats['removed_too_short']:,}') - 2)}│")
            print(f"   │  Too long removed: {stats['removed_too_long']:>8,}{' ' * (75 - len(f'Too long removed: {stats['removed_too_long']:,}') - 2)}│")
            print(f"   │  Duplicates removed: {stats['removed_duplicates']:>8,}{' ' * (75 - len(f'Duplicates removed: {stats['removed_duplicates']:,}') - 2)}│")
            print(f"   │  Final: {stats['step3_final']:>8,} sequences{' ' * (75 - len(f'Final: {stats['step3_final']:,} sequences') - 2)}│")
        
        print(f"   {'└' + '─' * 75 + '┘'}")
        print()
    
    print(f"{'═' * 80}")
    print(f"🔹 FINAL DATASET SUMMARY:")
    print(f"{'═' * 80}")
    print(f"   📈 External AVP: {len(external_avp_dataset):,} sequences (Initial_TR_AVP.csv)")
    print(f"   📈 External non_AVP: {len(external_non_avp_dataset):,} sequences (Initial_TR_non_AVP.csv)")
    print(f"   📈 Internal AVP: {len(internal_avp_dataset):,} sequences (Initial_TS_AVP.csv)")
    print(f"   📈 Internal non_AVP: {len(internal_non_avp_dataset):,} sequences (Initial_TS_non_AVP.csv)") 

    print(f"{'─' * 80}")
    grand_total = len(internal_avp_dataset) + len(internal_non_avp_dataset) + len(external_avp_dataset) + len(external_non_avp_dataset)
    print(f"   📊 Grand Total: {grand_total:,} sequences")
    print(f"{'═' * 80}")
    
    # ✅ 添加：保存Step 1完整摘要日志
    import json
    
    overall_log = {
        'step': 'Complete Step 1 Processing',
        'internal_avp_stats': internal_avp_detailed_stats,
        'internal_non_avp_stats': non_avp_detailed_stats,
        'external_stats': external_detailed_stats,
        'final_summary': {
            'internal_avp_count': len(internal_avp_dataset),
            'internal_non_avp_count': len(internal_non_avp_dataset),
            'external_avp_count': len(external_avp_dataset),
            'external_non_avp_count': len(external_non_avp_dataset),
            'grand_total': grand_total
        },
        'saved_files': saved_files
    }
    
    overall_log_file = os.path.join(log_dir, "step1_complete_processing_summary.json")
    with open(overall_log_file, 'w') as f:
        json.dump(overall_log, f, indent=2, default=str)
    print(f"\n📄 Complete Step 1 summary saved to: {overall_log_file}")
    
    return {
        'internal_avp': internal_avp_dataset,
        'internal_non_avp': internal_non_avp_dataset,
        'external_avp': external_avp_dataset,
        'external_non_avp': external_non_avp_dataset,
        'internal_avp_stats': internal_avp_detailed_stats,
        'internal_non_avp_stats': non_avp_detailed_stats,
        'external_stats': external_detailed_stats,
        'saved_files': saved_files
    }

# Run the detailed Step 1 processing and report the outcome.
print("🚀 Starting Step 1: Detailed Dataset Processing and Statistics")
print("=" * 80)

try:
    step1_detailed_results = step1_process_datasets_detailed()

    # A falsy result is reported as a failure; anything else is a success.
    if not step1_detailed_results:
        print("❌ Step 1 failed")
    else:
        print("\n✅ Step 1 completed successfully with detailed statistics!")
        print("   📁 Results stored in step1_detailed_results variable")
        print("   💾 CSV files saved in: 1_Data/Processed_data_set/Initial_merged_data_set/")
        print("   📊 Ready for Step 2: Overlap analysis")

except Exception as e:
    # Top-level boundary of the notebook cell: show the error and its traceback
    # instead of aborting the cell.
    print(f"❌ Error during Step 1: {e}")
    import traceback
    traceback.print_exc()

print("\n🎯 Step 1 Complete - Detailed filtering statistics generated for all datasets!")

🚀 Starting Step 1: Detailed Dataset Processing and Statistics
🚀 STEP 1: DETAILED PROCESSING OF INTERNAL AND EXTERNAL DATASETS

📊 PROCESSING INTERNAL AVP DATASETS
   🔹 Processing individual internal AVP datasets:
      ──────────────────────────────────────────────────────────────────────
      📊 Processing: DRAVP
      ──────────────────────────────────────────────────────────────────────
         ┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄
         🔍 Detailed filtering process:
         ┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄
            Step 0 - Original sequences: 1,986
            Step 1 - After natural AA filter: 1,756 (removed: 230)
            Step 2 - After length filter (5-50): 1,714
                   ├─ Too short (<5): 17
                   └─ Too long (>50): 25
            Step 3 - After deduplication: 1,608 (removed duplicates: 106)
            ✅ Verification: Expected 1,608, Got 1,608
         ┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄

# 2.数据集独立化

In [12]:
# ============================================================================
# 第二步：重叠分析和数据集独立化
# ============================================================================

def step2_overlap_analysis_enhanced(step1_results):
    """Run Step 2: drop label-conflicting sequences and make TS independent of TR.

    Sequences are compared case-insensitively (after upper-casing). The function:

    1. records every sequence that occurs in datasets with different expected
       labels (1 for the AVP datasets, 0 for the non_AVP datasets) and, when
       any exist, writes them to ``Overlap_data.csv``;
    2. removes those conflicting sequences from all four datasets;
    3. removes from each TS dataset the sequences already present in the TR
       dataset of the same label (the TR datasets are kept as cleaned);
    4. writes every non-empty final dataset to ``<name>.csv``;
    5. checks every pair of final datasets for shared sequences;
    6. prints a report and writes a JSON statistics log.

    Args:
        step1_results: Mapping holding DataFrames under the keys
            'internal_avp', 'internal_non_avp', 'external_avp' and
            'external_non_avp'. Each frame is read through its 'Sequence' and
            'Label' columns; 'Source' and 'Id' are used when present.

    Returns:
        dict with the keys 'final_datasets', 'label_conflicts',
        'conflict_sequences', 'removal_stats', 'saved_files',
        'verification_passed' and 'final_stats'.
    """
    print("\n🚀 STEP 2: ENHANCED OVERLAP ANALYSIS AND DATASET INDEPENDENCE")
    print("=" * 80)

    # Work on copies so the Step 1 results are left untouched.
    initial_custom_avp = step1_results['internal_avp'].copy()
    initial_custom_non_avp = step1_results['internal_non_avp'].copy()
    initial_tr_avp = step1_results['external_avp'].copy()
    initial_tr_non_avp = step1_results['external_non_avp'].copy()

    initial_sequence_total = (
        len(initial_custom_avp) + len(initial_custom_non_avp)
        + len(initial_tr_avp) + len(initial_tr_non_avp)
    )

    print(f"\n📊 INITIAL DATASET OVERVIEW:")
    print(f"   📈 Initial_TR_AVP: {len(initial_tr_avp):,} sequences")
    print(f"   📈 Initial_TR_non_AVP: {len(initial_tr_non_avp):,} sequences")
    print(f"   📈 Initial_TS_AVP: {len(initial_custom_avp):,} sequences")
    print(f"   📈 Initial_TS_non_AVP: {len(initial_custom_non_avp):,} sequences")

    print(f"   📊 Total: {initial_sequence_total:,} sequences")

    # Output directory for the final CSV files.
    output_dir = "1_Data/Processed_data_set/Final_merged_data_set"
    os.makedirs(output_dir, exist_ok=True)

    # Log directory for the JSON statistics.
    log_dir = "2_Log/2.1_Training set_and_test_set_processing"
    os.makedirs(log_dir, exist_ok=True)

    # ========== Phase 1: label conflict analysis ==========
    print(f"\n🔍 PHASE 1: LABEL CONFLICT ANALYSIS")
    print("=" * 80)

    # Dataset registry; the expected label is implied by the dataset itself.
    datasets = {
        'Initial_TS_AVP': {'data': initial_custom_avp, 'expected_label': 1},
        'Initial_TS_non_AVP': {'data': initial_custom_non_avp, 'expected_label': 0},
        'Initial_TR_AVP': {'data': initial_tr_avp, 'expected_label': 1},
        'Initial_TR_non_AVP': {'data': initial_tr_non_avp, 'expected_label': 0}
    }

    # Map each upper-cased sequence to the list of places it occurs in.
    sequence_sources = {}

    for dataset_name, dataset_info in datasets.items():
        data = dataset_info['data']
        expected_label = dataset_info['expected_label']

        for _, row in data.iterrows():
            sequence = row['Sequence'].upper()
            sequence_sources.setdefault(sequence, []).append({
                'dataset': dataset_name,
                'expected_label': expected_label,
                'actual_label': row['Label'],
                'source': row['Source'] if 'Source' in row else 'Unknown',
                'id': row['Id'] if 'Id' in row else 'Unknown'
            })

    label_conflicts = []
    conflict_sequences = set()

    print(f"\n   🔎 Analyzing label conflicts across datasets...")

    # Largest number of occurrences among sequences seen more than once. It is
    # the same for every conflict record, so compute it once here instead of
    # rescanning all sequences for each conflict.
    max_occurrences = max(
        (len(occurrences) for occurrences in sequence_sources.values() if len(occurrences) > 1),
        default=0,
    )

    for sequence, sources in sequence_sources.items():
        if len(sources) < 2:
            continue  # seen only once: cannot conflict
        if len({source['expected_label'] for source in sources}) < 2:
            continue  # all occurrences agree on the label

        conflict_sequences.add(sequence)

        conflict_record = {
            'Sequence': sequence,
            'Length': len(sequence),
            'Conflict_Type': 'Label_Mismatch',
            'Dataset_Count': len(sources)
        }

        # One group of columns per occurrence.
        for position, source in enumerate(sources, start=1):
            conflict_record[f'Dataset_{position}'] = source['dataset']
            conflict_record[f'Label_{position}'] = source['expected_label']
            conflict_record[f'Source_{position}'] = source['source']
            conflict_record[f'ID_{position}'] = source['id']

        # Pad with empty strings so every record carries the same columns.
        for position in range(len(sources) + 1, max_occurrences + 1):
            conflict_record[f'Dataset_{position}'] = ''
            conflict_record[f'Label_{position}'] = ''
            conflict_record[f'Source_{position}'] = ''
            conflict_record[f'ID_{position}'] = ''

        label_conflicts.append(conflict_record)

    print(f"   ⚠️  Found {len(conflict_sequences):,} sequences with label conflicts")
    print(f"   📊 Total conflict records: {len(label_conflicts):,}")

    # Persist the conflict table only when there is something to report.
    if label_conflicts:
        conflicts_df = pd.DataFrame(label_conflicts)
        conflicts_file = os.path.join(output_dir, "Overlap_data.csv")
        conflicts_df.to_csv(conflicts_file, index=False)
        print(f"   ✅ Label conflicts saved to: {conflicts_file}")
    else:
        print(f"   🎉 No label conflicts found!")

    # ========== Phase 2: remove conflicting sequences ==========
    print(f"\n🧹 PHASE 2: REMOVING CONFLICT SEQUENCES FROM DATASETS")
    print("=" * 80)

    cleaned_datasets = {}
    removal_stats = {}

    for dataset_name, dataset_info in datasets.items():
        original_data = dataset_info['data']
        original_count = len(original_data)

        is_conflict = original_data['Sequence'].str.upper().isin(conflict_sequences)
        cleaned_data = original_data[~is_conflict].copy()
        cleaned_count = len(cleaned_data)
        removed_count = original_count - cleaned_count

        cleaned_datasets[dataset_name] = cleaned_data
        removal_stats[dataset_name] = {
            'original': original_count,
            'cleaned': cleaned_count,
            'removed': removed_count,
            'removal_rate': (removed_count / original_count * 100) if original_count > 0 else 0
        }

        print(f"   📊 {dataset_name}:")
        print(f"      Original: {original_count:,} sequences")
        print(f"      Cleaned: {cleaned_count:,} sequences")
        print(f"      Removed: {removed_count:,} sequences ({removal_stats[dataset_name]['removal_rate']:.2f}%)")

    # ========== Phase 3: remove TS sequences duplicated in TR ==========
    print(f"\n🔄 PHASE 3: REMOVING DUPLICATE SEQUENCES")
    print("=" * 80)

    cleaned_tr_avp = cleaned_datasets['Initial_TR_AVP']
    cleaned_tr_non_avp = cleaned_datasets['Initial_TR_non_AVP']
    cleaned_custom_avp = cleaned_datasets['Initial_TS_AVP']
    cleaned_custom_non_avp = cleaned_datasets['Initial_TS_non_AVP']

    tr_avp_sequences = set(cleaned_tr_avp['Sequence'].str.upper())
    tr_non_avp_sequences = set(cleaned_tr_non_avp['Sequence'].str.upper())

    print(f"   🔎 Comparing datasets for duplicates...")

    # TS_AVP: sequences of Initial_TS_AVP that are not in Initial_TR_AVP.
    custom_avp_sequences = set(cleaned_custom_avp['Sequence'].str.upper())
    ts_avp_sequences = custom_avp_sequences - tr_avp_sequences
    ts_avp_data = cleaned_custom_avp[
        cleaned_custom_avp['Sequence'].str.upper().isin(ts_avp_sequences)
    ].copy()

    # TS_non_AVP: sequences of Initial_TS_non_AVP that are not in Initial_TR_non_AVP.
    custom_non_avp_sequences = set(cleaned_custom_non_avp['Sequence'].str.upper())
    ts_non_avp_sequences = custom_non_avp_sequences - tr_non_avp_sequences
    ts_non_avp_data = cleaned_custom_non_avp[
        cleaned_custom_non_avp['Sequence'].str.upper().isin(ts_non_avp_sequences)
    ].copy()

    # The TR datasets stay as they are after conflict removal.
    tr_avp_data = cleaned_tr_avp.copy()
    tr_non_avp_data = cleaned_tr_non_avp.copy()

    avp_overlap = len(custom_avp_sequences & tr_avp_sequences)
    non_avp_overlap = len(custom_non_avp_sequences & tr_non_avp_sequences)

    print(f"\n   📈 Duplicate removal results:")
    print(f"      AVP datasets overlap: {avp_overlap:,} sequences")
    print(f"      non_AVP datasets overlap: {non_avp_overlap:,} sequences")
    print(f"      TS_AVP (unique custom): {len(ts_avp_data):,} sequences")
    print(f"      TS_non_AVP (unique custom): {len(ts_non_avp_data):,} sequences")
    print(f"      TR_AVP (external): {len(tr_avp_data):,} sequences")
    print(f"      TR_non_AVP (external): {len(tr_non_avp_data):,} sequences")

    # ========== Phase 4: save the final datasets ==========
    print(f"\n💾 PHASE 4: SAVING FINAL DATASETS")
    print("=" * 80)

    final_datasets = {
        'TS_AVP': ts_avp_data,
        'TS_non_AVP': ts_non_avp_data,
        'TR_AVP': tr_avp_data,
        'TR_non_AVP': tr_non_avp_data
    }

    saved_files = {}
    columns_order = ['Id', 'Sequence', 'Source', 'Length', 'Label', 'Type']

    for dataset_name, data in final_datasets.items():
        if len(data) == 0:
            print(f"   ⚠️  {dataset_name}: No data to save")
            continue

        # Keep the preferred column order, limited to the columns that exist.
        available_columns = [col for col in columns_order if col in data.columns]
        data_to_save = data[available_columns].copy()

        filename = f"{dataset_name}.csv"
        file_path = os.path.join(output_dir, filename)
        data_to_save.to_csv(file_path, index=False)
        saved_files[dataset_name] = file_path

        print(f"   ✅ {dataset_name}: {len(data):,} sequences → {filename}")

    # ========== Phase 5: independence verification ==========
    print(f"\n✅ PHASE 5: INDEPENDENCE VERIFICATION")
    print("=" * 80)

    final_sequences = {
        name: set(data['Sequence'].str.upper()) if len(data) > 0 else set()
        for name, data in final_datasets.items()
    }

    verification_passed = True
    overlaps_found = []

    # Check each unordered pair of final datasets exactly once.
    dataset_names = list(final_sequences.keys())
    for i, name1 in enumerate(dataset_names):
        for name2 in dataset_names[i + 1:]:
            overlap = final_sequences[name1] & final_sequences[name2]
            if overlap:
                verification_passed = False
                overlaps_found.append((name1, name2, len(overlap)))
                print(f"   ❌ OVERLAP FOUND: {name1} ↔ {name2}: {len(overlap):,} sequences")
            else:
                print(f"   ✅ {name1} ↔ {name2}: Independent")

    if verification_passed:
        print(f"\n   🎉 ALL DATASETS ARE COMPLETELY INDEPENDENT!")
    else:
        print(f"\n   ⚠️  WARNING: {len(overlaps_found)} overlaps found between final datasets")

    # ========== Final report ==========
    print(f"\n📋 COMPREHENSIVE FINAL REPORT")
    print("=" * 80)

    print(f"\n🔹 LABEL CONFLICT ANALYSIS:")
    print(f"   ⚠️  Conflicting sequences: {len(conflict_sequences):,}")
    print(f"   📄 Overlap data file: Overlap_data.csv")

    print(f"\n🔹 DATASET TRANSFORMATION:")
    for name, stats in removal_stats.items():
        print(f"   {name:25s}: {stats['original']:>6,} → {stats['cleaned']:>6,} (-{stats['removed']:,})")

    print(f"\n🔹 FINAL INDEPENDENT DATASETS:")
    total_final = 0
    for name, data in final_datasets.items():
        count = len(data)
        total_final += count
        print(f"   {name:15s}: {count:>6,} sequences")

    print(f"   {'─' * 40}")
    print(f"   {'Total Final':15s}: {total_final:>6,} sequences")

    print(f"\n🔹 DATA REDUCTION SUMMARY:")
    total_initial = sum(len(dataset['data']) for dataset in datasets.values())
    reduction = total_initial - total_final
    reduction_rate = (reduction / total_initial * 100) if total_initial > 0 else 0
    print(f"   Initial total: {total_initial:,} sequences")
    print(f"   Final total: {total_final:,} sequences")
    print(f"   Reduction: {reduction:,} sequences ({reduction_rate:.2f}%)")

    final_stats = {
        'label_conflicts': {
            'conflict_sequences_count': len(conflict_sequences),
            'conflict_records_count': len(label_conflicts)
        },
        'removal_stats': removal_stats,
        'final_datasets': {name: len(data) for name, data in final_datasets.items()},
        'independence_verification': {
            'passed': verification_passed,
            'overlaps_found': overlaps_found
        },
        'summary': {
            'initial_total': total_initial,
            'final_total': total_final,
            'reduction': reduction,
            'reduction_rate': reduction_rate
        }
    }

    import json

    # The statistics log goes to the log directory, not next to the CSV files.
    stats_file = os.path.join(log_dir, "step2_overlap_analysis_statistics.json")

    detailed_log = {
        'step': 'Step 2 - Overlap Analysis and Dataset Independence',
        'statistics': final_stats,
        'saved_files': saved_files,
        'verification_passed': verification_passed
    }

    with open(stats_file, 'w') as f:
        # default=str stringifies any value json cannot serialise natively.
        json.dump(detailed_log, f, indent=2, default=str)
    print(f"\n   📄 Statistics saved to: {stats_file}")

    return {
        'final_datasets': final_datasets,
        'label_conflicts': label_conflicts,
        'conflict_sequences': conflict_sequences,
        'removal_stats': removal_stats,
        'saved_files': saved_files,
        'verification_passed': verification_passed,
        'final_stats': final_stats
    }

# ============================================================================
# 运行第二步分析
# ============================================================================

print("\n🚀 Starting Step 2: Enhanced Overlap Analysis and Dataset Independence")
print("=" * 80)

try:
    # Step 2 consumes the in-memory results of Step 1, so that cell must have
    # been run first in this session.
    if 'step1_detailed_results' not in locals():
        print("❌ Error: Step 1 results not found. Please run Step 1 first.")
    else:
        step2_results = step2_overlap_analysis_enhanced(step1_detailed_results)

        if step2_results:
            print("\n✅ Step 2 completed successfully!")
            print(f"   📁 Results stored in step2_results variable")

            if step2_results['verification_passed']:
                print(f"   🎉 All final datasets are completely independent!")
            else:
                print(f"   ⚠️  Warning: Some overlaps detected in final datasets")

            print(f"   💾 Final datasets saved to: 1_Data/Processed_data_set/Final_merged_data_set/")
            print(f"   📊 Files created:")
            for dataset_name in step2_results['saved_files']:
                print(f"      - {dataset_name}.csv")
            # Overlap_data.csv is only written when label conflicts were found.
            if step2_results['label_conflicts']:
                print(f"      - Overlap_data.csv (label conflicts)")
            # The statistics file lives in the log directory under this name.
            print(f"   📄 Statistics log: 2_Log/2.1_Training set_and_test_set_processing/step2_overlap_analysis_statistics.json")
        else:
            print("❌ Step 2 failed")

except Exception as e:
    # Top-level boundary of the notebook cell: show the error and its traceback
    # instead of aborting the cell.
    print(f"❌ Error during Step 2: {str(e)}")
    import traceback
    traceback.print_exc()

print("\n🎯 Step 2 Complete - Final independent datasets ready for model training!")


🚀 Starting Step 2: Enhanced Overlap Analysis and Dataset Independence

🚀 STEP 2: ENHANCED OVERLAP ANALYSIS AND DATASET INDEPENDENCE

📊 INITIAL DATASET OVERVIEW:
   📈 Initial_TR_AVP: 3,412 sequences
   📈 Initial_TR_non_AVP: 4,049 sequences
   📈 Initial_TS_AVP: 4,993 sequences
   📈 Initial_TS_non_AVP: 9,402 sequences
   📊 Total: 21,856 sequences

🔍 PHASE 1: LABEL CONFLICT ANALYSIS

   🔎 Analyzing label conflicts across datasets...
   ⚠️  Found 521 sequences with label conflicts
   📊 Total conflict records: 521
   ✅ Label conflicts saved to: 1_Data/Processed_data_set/Final_merged_data_set/Overlap_data.csv

🧹 PHASE 2: REMOVING CONFLICT SEQUENCES FROM DATASETS
   📊 Initial_TS_AVP:
      Original: 4,993 sequences
      Cleaned: 4,474 sequences
      Removed: 519 sequences (10.39%)
   📊 Initial_TS_non_AVP:
      Original: 9,402 sequences
      Cleaned: 9,172 sequences
      Removed: 230 sequences (2.45%)
   📊 Initial_TR_AVP:
      Original: 3,412 sequences
      Cleaned: 3,091 sequences
    

# 3.csv to fasta

In [13]:
import pandas as pd
import numpy as np
import os
from collections import defaultdict

def check_and_fix_duplicate_ids(df):
    """Make the 'Id' column unique by suffixing repeated IDs with a number.

    IDs are compared as strings; a missing ID is treated as
    ``Unknown_<index label>``. The first occurrence of a repeated ID is kept,
    later occurrences get ``_1``, ``_2``, ... appended (inserted before a
    trailing ``|`` when the ID ends with one). A suffix that would reproduce an
    ID already present in the data, or one already handed out, is skipped, so
    the resulting IDs are unique.

    Parameters:
    df: DataFrame with an 'Id' column.

    Returns:
    df: the input frame itself when no ID is repeated; otherwise a copy whose
        'Id' column holds the fixed string IDs.
    duplicate_info: dict with 'total_duplicates' (rows carrying a repeated ID),
        'duplicate_groups' (distinct repeated IDs) and 'fixes_applied'
        (rows that were renamed).
    """
    print(f"🔍 Checking for duplicate IDs...")

    duplicate_info = {
        'total_duplicates': 0,
        'duplicate_groups': 0,
        'fixes_applied': 0
    }

    # Normalise every ID to a string once; only the 'Id' column is needed.
    normalized_ids = [
        str(raw_id) if pd.notna(raw_id) else f"Unknown_{idx}"
        for idx, raw_id in zip(df.index, df['Id'])
    ]

    # Pass 1: count how often each ID occurs.
    id_counts = {}
    for current_id in normalized_ids:
        id_counts[current_id] = id_counts.get(current_id, 0) + 1

    duplicate_ids = {id_val: count for id_val, count in id_counts.items() if count > 1}

    if not duplicate_ids:
        print(f"   ✅ No duplicate IDs found")
        return df, duplicate_info

    duplicate_info['duplicate_groups'] = len(duplicate_ids)
    duplicate_info['total_duplicates'] = sum(duplicate_ids.values())

    print(f"   ⚠️  Found {len(duplicate_ids)} duplicate ID groups affecting {sum(duplicate_ids.values())} sequences")

    print(f"   📋 Duplicate ID examples:")
    for dup_id, count in list(duplicate_ids.items())[:5]:
        print(f"      '{dup_id}': {count} occurrences")
    if len(duplicate_ids) > 5:
        print(f"      ... and {len(duplicate_ids) - 5} more")

    # Pass 2: rename the second and later occurrences of each repeated ID.
    taken_ids = set(id_counts)   # every ID that must not be produced again
    last_suffix = {}             # repeated ID -> last suffix number handed out
    fixed_ids = []

    for current_id in normalized_ids:
        if current_id not in duplicate_ids:
            fixed_ids.append(current_id)
            continue

        if current_id not in last_suffix:
            # First occurrence keeps the original ID.
            last_suffix[current_id] = 0
            fixed_ids.append(current_id)
            continue

        # A trailing '|' stays at the end; the suffix is inserted before it.
        if current_id.endswith('|'):
            stem, tail = current_id[:-1], '|'
        else:
            stem, tail = current_id, ''

        suffix_number = last_suffix[current_id] + 1
        candidate = f"{stem}_{suffix_number}{tail}"
        while candidate in taken_ids:
            # e.g. renaming a second 'A' to 'A_1' when 'A_1' already exists
            suffix_number += 1
            candidate = f"{stem}_{suffix_number}{tail}"

        last_suffix[current_id] = suffix_number
        taken_ids.add(candidate)
        fixed_ids.append(candidate)
        duplicate_info['fixes_applied'] += 1

    # Leave the caller's frame untouched.
    df = df.copy()
    df['Id'] = fixed_ids

    print(f"   ✅ Applied {duplicate_info['fixes_applied']} ID fixes")

    if duplicate_info['fixes_applied'] > 0:
        print(f"   📝 ID fix examples:")
        shown = 0
        # Compare string forms so only real renames are listed.
        for before, after in zip(normalized_ids, fixed_ids):
            if before != after:
                print(f"      {before} → {after}")
                shown += 1
                if shown == 5:
                    break
        if duplicate_info['fixes_applied'] > 5:
            print(f"      ... and {duplicate_info['fixes_applied'] - 5} more fixes")

    return df, duplicate_info

def merge_datasets_to_both_formats(csv_files, output_base_path, random_seed=42):
    """Merge several CSV files into one shuffled dataset, saved as CSV and FASTA.

    Every input file that exists is read and tagged with a 'Dataset' column
    holding its 'type'. The frames are concatenated, repeated IDs are made
    unique via check_and_fix_duplicate_ids, the rows are shuffled, and the
    result is written to ``<output_base_path>.csv`` and
    ``<output_base_path>.fasta``. Each FASTA header joins the non-missing
    ID/Label/Type/Source/Length/Dataset values as ``key=value`` pairs with '|'.
    A JSON conversion log is written to the log directory.

    Parameters:
    csv_files: list of dict, each with 'path' (CSV file path) and 'type'
        (value stored in the 'Dataset' column)
    output_base_path: str, output path without extension
    random_seed: int, seed used for shuffling

    Returns:
    (csv_success, fasta_success): tuple of bool, one flag per output file;
        (False, False) when none of the input files could be read.
    """
    log_dir = "2_Log/2.1_Training set_and_test_set_processing"
    os.makedirs(log_dir, exist_ok=True)

    all_data = []

    print(f"📊 Reading CSV files:")
    print("-" * 50)

    for file_info in csv_files:
        csv_path = file_info['path']
        dataset_type = file_info['type']

        if not os.path.exists(csv_path):
            print(f"   ❌ File not found: {csv_path}")
            continue

        df = pd.read_csv(csv_path)
        print(f"   ✅ {os.path.basename(csv_path)}: {len(df):,} sequences")

        # Remember which input file each row came from.
        df['Dataset'] = dataset_type
        all_data.append(df)

    if not all_data:
        print("❌ No data found!")
        return False, False

    combined_df = pd.concat(all_data, ignore_index=True)

    print(f"\n📊 Combined dataset: {len(combined_df):,} sequences")

    # ========== Check and fix duplicate IDs ==========
    print(f"\n🔧 ID Duplicate Check and Fix:")
    print("-" * 50)

    fixed_df, duplicate_info = check_and_fix_duplicate_ids(combined_df)

    print(f"\n🔀 Shuffling data with random seed {random_seed}...")
    # NOTE: this also reseeds NumPy's global generator, as the original did;
    # the shuffle itself is driven by random_state below.
    np.random.seed(random_seed)
    shuffled_df = fixed_df.sample(frac=1, random_state=random_seed).reset_index(drop=True)

    csv_output_path = f"{output_base_path}.csv"
    fasta_output_path = f"{output_base_path}.fasta"

    # ========== Save the CSV file ==========
    print(f"\n💾 Saving CSV file: {csv_output_path}")

    desired_columns = ['Id', 'Sequence', 'Source', 'Length', 'Label', 'Type', 'Dataset']
    available_columns = [col for col in desired_columns if col in shuffled_df.columns]

    # Guard the write like the FASTA write below, so the returned flag
    # actually reflects whether the file was written.
    csv_success = True
    try:
        shuffled_df[available_columns].to_csv(csv_output_path, index=False)
    except Exception as e:
        print(f"❌ Error writing CSV file: {e}")
        csv_success = False

    # ========== Save the FASTA file ==========
    print(f"💾 Saving FASTA file: {fasta_output_path}")

    # (header key, column) pairs in header order, limited to existing columns.
    header_fields = [
        (key, column)
        for key, column in (
            ('ID', 'Id'),
            ('Label', 'Label'),
            ('Type', 'Type'),
            ('Source', 'Source'),
            ('Length', 'Length'),
            ('Dataset', 'Dataset'),
        )
        if column in shuffled_df.columns
    ]

    fasta_success = True
    try:
        # Explicit UTF-8 so the file does not depend on the platform locale
        # (to_csv above already writes UTF-8 by default).
        with open(fasta_output_path, 'w', encoding='utf-8') as fasta_file:
            for _, row in shuffled_df.iterrows():
                fasta_header = '|'.join(
                    f"{key}={row[column]}"
                    for key, column in header_fields
                    if pd.notna(row[column])
                )
                sequence = str(row['Sequence']) if pd.notna(row['Sequence']) else ''
                fasta_file.write(f">{fasta_header}\n{sequence}\n")

    except Exception as e:
        print(f"❌ Error writing FASTA file: {e}")
        fasta_success = False

    # ========== Statistics ==========
    print(f"\n📈 Final Statistics:")

    if duplicate_info['total_duplicates'] > 0:
        print(f"   🔧 ID Duplicate Processing:")
        print(f"      Original duplicates: {duplicate_info['total_duplicates']} sequences in {duplicate_info['duplicate_groups']} groups")
        print(f"      Fixes applied: {duplicate_info['fixes_applied']}")
        print(f"      Final unique IDs: {shuffled_df['Id'].nunique():,}")

    dataset_counts = shuffled_df['Dataset'].value_counts()
    print(f"   📊 By dataset:")
    for dataset, count in dataset_counts.items():
        print(f"      {dataset}: {count:,} sequences")

    # Label and Type are optional columns; their counts default to empty.
    label_counts_dict = {}
    if 'Label' in shuffled_df.columns:
        label_counts = shuffled_df['Label'].value_counts()
        label_counts_dict = label_counts.to_dict()
        print(f"   📊 By label:")
        for label, count in label_counts.items():
            print(f"      {label}: {count:,} sequences")

    type_counts_dict = {}
    if 'Type' in shuffled_df.columns:
        type_counts = shuffled_df['Type'].value_counts()
        type_counts_dict = type_counts.to_dict()
        print(f"   📊 By type:")
        for type_name, count in type_counts.items():
            print(f"      {type_name}: {count:,} sequences")

    print(f"   📊 Total: {len(shuffled_df):,} sequences")

    if csv_success:
        print(f"   ✅ CSV saved: {csv_output_path}")
    if fasta_success:
        print(f"   ✅ FASTA saved: {fasta_output_path}")

    # ========== Conversion log ==========
    import json

    conversion_log = {
        'step': 'Step 3 - CSV to FASTA Conversion',
        'output_base_path': output_base_path,
        'random_seed': random_seed,
        'csv_success': csv_success,
        'fasta_success': fasta_success,
        'duplicate_info': duplicate_info,
        'final_statistics': {
            'total_sequences': len(shuffled_df),
            'unique_ids': shuffled_df['Id'].nunique(),
            'dataset_counts': dataset_counts.to_dict(),
            'label_counts': label_counts_dict,
            'type_counts': type_counts_dict
        },
        'files_created': {
            'csv_file': csv_output_path if csv_success else None,
            'fasta_file': fasta_output_path if fasta_success else None
        }
    }

    # The log file name is derived from the output base name (e.g. "TR").
    dataset_name = os.path.basename(output_base_path)
    log_file = os.path.join(log_dir, f"step3_{dataset_name}_conversion_log.json")

    with open(log_file, 'w') as f:
        # default=str stringifies any value json cannot serialise natively.
        json.dump(conversion_log, f, indent=2, default=str)

    print(f"   📄 Conversion log saved to: {log_file}")

    return csv_success, fasta_success

# ============================================================================
# Merge the TR datasets (TR_AVP.csv + TR_non_AVP.csv → TR.csv + TR.fasta)
# ============================================================================

print("🚀 Converting TR datasets to CSV and FASTA formats")
print("=" * 70)

# Directory that holds the per-class input CSVs; the merged TR outputs are
# written under the same directory (see tr_output_base below).
base_dir = "1_Data/Processed_data_set/Final_merged_data_set"

# TR input configuration: one entry per source CSV, each tagged with a type
# label that is handed to the merge function alongside the path.
tr_csv_files = [
    {'path': os.path.join(base_dir, file_name), 'type': subset_type}
    for file_name, subset_type in (
        ("TR_AVP.csv", 'TR_AVP'),
        ("TR_non_AVP.csv", 'TR_non_AVP'),
    )
]

# Extension-less output stem passed to the merge function.
tr_output_base = os.path.join(base_dir, "TR")

# Run the TR conversion; the two flags report whether each output was written.
tr_csv_success, tr_fasta_success = merge_datasets_to_both_formats(
    tr_csv_files,
    tr_output_base,
    random_seed=42,
)

# ============================================================================
# Merge the TS datasets (TS_AVP.csv + TS_non_AVP.csv → TS.csv + TS.fasta)
# ============================================================================

print("\n🚀 Converting TS datasets to CSV and FASTA formats")
print("=" * 70)

# TS input configuration: one entry per source CSV, each tagged with a type
# label that is handed to the merge function alongside the path.
ts_csv_files = [
    {'path': os.path.join(base_dir, file_name), 'type': subset_type}
    for file_name, subset_type in (
        ("TS_AVP.csv", 'TS_AVP'),
        ("TS_non_AVP.csv", 'TS_non_AVP'),
    )
]

# Extension-less output stem passed to the merge function.
ts_output_base = os.path.join(base_dir, "TS")

# Run the TS conversion; the two flags report whether each output was written.
ts_csv_success, ts_fasta_success = merge_datasets_to_both_formats(
    ts_csv_files,
    ts_output_base,
    random_seed=42,
)

# ============================================================================
# Summary report
# ============================================================================

print("\n📋 CONVERSION SUMMARY")
print("=" * 70)

# ---- TR dataset results ----------------------------------------------------
print("📊 TR Dataset:")
if not tr_csv_success:
    print("   ❌ TR.csv: Failed to create")
else:
    tr_csv_path = os.path.join(base_dir, "TR.csv")
    if not os.path.exists(tr_csv_path):
        print("   ❌ TR.csv: File not found")
    else:
        # Re-read the file from disk so the counts describe what was saved.
        tr_df = pd.read_csv(tr_csv_path)
        total_rows = len(tr_df)
        unique_ids = tr_df['Id'].nunique()
        print(f"   ✅ TR.csv: {total_rows:,} sequences ({unique_ids:,} unique IDs)")

if not tr_fasta_success:
    print("   ❌ TR.fasta: Failed to create")
else:
    tr_fasta_path = os.path.join(base_dir, "TR.fasta")
    if not os.path.exists(tr_fasta_path):
        print("   ❌ TR.fasta: File not found")
    else:
        with open(tr_fasta_path, 'r') as f:
            lines = f.readlines()
        # Lines beginning with '>' are counted as one sequence each.
        seq_count = sum(1 for line in lines if line.startswith('>'))
        print(f"   ✅ TR.fasta: {seq_count:,} sequences")

# ---- TS dataset results ----------------------------------------------------
print("\n📊 TS Dataset:")
if not ts_csv_success:
    print("   ❌ TS.csv: Failed to create")
else:
    ts_csv_path = os.path.join(base_dir, "TS.csv")
    if not os.path.exists(ts_csv_path):
        print("   ❌ TS.csv: File not found")
    else:
        # Re-read the file from disk so the counts describe what was saved.
        ts_df = pd.read_csv(ts_csv_path)
        total_rows = len(ts_df)
        unique_ids = ts_df['Id'].nunique()
        print(f"   ✅ TS.csv: {total_rows:,} sequences ({unique_ids:,} unique IDs)")

if not ts_fasta_success:
    print("   ❌ TS.fasta: Failed to create")
else:
    ts_fasta_path = os.path.join(base_dir, "TS.fasta")
    if not os.path.exists(ts_fasta_path):
        print("   ❌ TS.fasta: File not found")
    else:
        with open(ts_fasta_path, 'r') as f:
            lines = f.readlines()
        # Lines beginning with '>' are counted as one sequence each.
        seq_count = sum(1 for line in lines if line.startswith('>'))
        print(f"   ✅ TS.fasta: {seq_count:,} sequences")

# ---- Closing notes ---------------------------------------------------------
print(f"\n📁 All files saved in: {base_dir}")
print("🎲 Random seed used: 42")

print("\n🎯 Conversion complete! All files generated with random seed 42.")
print("💡 Duplicate IDs have been automatically resolved with numeric suffixes (_1, _2, etc.).")

🚀 Converting TR datasets to CSV and FASTA formats
📊 Reading CSV files:
--------------------------------------------------
   ✅ TR_AVP.csv: 3,091 sequences
   ✅ TR_non_AVP.csv: 3,719 sequences

📊 Combined dataset: 6,810 sequences

🔧 ID Duplicate Check and Fix:
--------------------------------------------------
🔍 Checking for duplicate IDs...
   ⚠️  Found 76 duplicate ID groups affecting 152 sequences
   📋 Duplicate ID examples:
      'neg1786': 2 occurrences
      'neg1762': 2 occurrences
      'neg1761': 2 occurrences
      'neg1740': 2 occurrences
      'neg1703': 2 occurrences
      ... and 71 more
   ✅ Applied 76 ID fixes
   📝 ID fix examples:
      neg2124 → neg2124_1
      neg105 → neg105_1
      neg171 → neg171_1
      neg299 → neg299_1
      neg560 → neg560_1
      ... and 71 more fixes

🔀 Shuffling data with random seed 42...

💾 Saving CSV file: 1_Data/Processed_data_set/Final_merged_data_set/TR.csv
💾 Saving FASTA file: 1_Data/Processed_data_set/Final_merged_data_set/TR.fasta

