In [1]:
import pandas as pd
import os
from sklearn.model_selection import train_test_split
from rdkit import Chem

In [2]:
# for file in os.listdir('ademt_data'):
#     if file.endswith('.csv'):
#         df = pd.read_csv(f'ademt_data/{file}')
#         print(df.head())
        
#         if len(set(list(df['Label']))) == 2:
#             os.system(f'cp ademt_data/{file} ./ademt_data/cla/')
#         else:
#             os.system(f'cp ademt_data/{file} ./ademt_data/reg/')


In [3]:
import re
from rdkit import Chem

def correct_smiles_errors(error_log: str) -> dict:
    """
    Pull the unparseable SMILES strings out of an RDKit error log and try to repair them.

    Every string quoted in a ``Failed parsing SMILES '...'`` message is run
    through a fixed sequence of textual rewrites. The rewritten string is then
    handed to RDKit (without sanitization) to see whether it can be read.

    Args:
        error_log: Log text that contains RDKit SMILES parsing errors.

    Returns:
        A dict whose keys are the original failing SMILES strings and whose
        values are the rewritten strings. A value is None when RDKit still
        cannot read the string after all rewrites.
    """
    # Literal substitutions applied first, in exactly this order.
    leading_literals = (
        ('[N+H](O)[O-]', '[N+](=O)[O-]'),
        ('[N+H2]', '[NH2+]'),
        ('[N+H3]', '[NH3+]'),
    )

    repaired = {}

    for failed in re.findall(r"Failed parsing SMILES '(.*?)'", error_log):
        # Stays None unless RDKit manages to read the rewritten string.
        repaired[failed] = None

        candidate = failed
        for old, new in leading_literals:
            candidate = candidate.replace(old, new)

        # Drop the explicit hydrogen (and optional single-digit count) from
        # bracket atoms written as [XH] / [XHn].
        candidate = re.sub(r'\[([A-Za-z]+)H[0-9]?\]', r'[\1]', candidate)

        # Replace a parenthesised "[N+H](O)[O-]" group by the bare nitro form.
        candidate = candidate.replace('([N+H](O)[O-])', '[N+](=O)[O-]')

        # Parse without sanitization first so that as many strings as possible are read.
        mol = Chem.MolFromSmiles(candidate, sanitize=False)
        if mol is None:
            continue

        try:
            # Stricter chemistry check; its outcome does not change the result.
            Chem.SanitizeMol(mol)
        except Exception:
            # A readable string is still accepted as a repair even when
            # sanitization rejects it.
            pass
        repaired[failed] = candidate

    return repaired

In [4]:
import pandas as pd
import os
import re
from rdkit import Chem
from rdkit import rdBase
from functools import reduce
import logging

# Silence RDKit's own error logging so that failures can be caught and handled by this code instead.
rdBase.DisableLog('rdApp.error')

# Configure logging: INFO level with a "timestamp - level - message" format.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger used by the functions defined in this cell.
logger = logging.getLogger(__name__)

# --- 核心辅助函数 (带有诊断功能) ---

def smiles_to_inchikey(smiles_list):
    """Convert an iterable of SMILES strings into a list of InChIKeys.

    The result has one entry per input entry, in the same order. An entry is
    None when the input is missing (None/NaN), when RDKit cannot parse the
    SMILES, or when RDKit raises while parsing or generating the key.
    """
    def _key_for(entry):
        # Missing values never reach RDKit.
        if pd.isna(entry) or entry is None:
            return None
        try:
            parsed = Chem.MolFromSmiles(entry)
            return Chem.MolToInchiKey(parsed) if parsed is not None else None
        except Exception:
            return None

    return [_key_for(entry) for entry in smiles_list]


# Upper-case two-letter sequences that are rewritten to bracketed element symbols.
# NOTE(review): some of these sequences (e.g. 'SN', 'IN') are also two adjacent
# organic-subset atoms in ordinary SMILES; confirm the rewrite is wanted for this data.
_ELEMENT_TYPO_REPLACEMENTS = (
    ('IN', '[In]'),
    ('SN', '[Sn]'),
    ('PB', '[Pb]'),
    ('AS', '[As]'),
    ('SB', '[Sb]'),
    ('BI', '[Bi]'),
)

# "[X+Hn]" spellings (charge before hydrogen count) and their "[XHn+]" equivalents.
_CHARGE_HYDROGEN_REPLACEMENTS = (
    (r'\[N\+H\]', '[NH+]'),
    (r'\[N\+H2\]', '[NH2+]'),
    (r'\[N\+H3\]', '[NH3+]'),
    (r'\[n\+H\]', '[nH+]'),
    (r'\[O\+H\]', '[OH+]'),
    (r'\[S\+H\]', '[SH+]'),
)

_BRACKET_ATOM_SPLIT = re.compile(r'(\[[^\[\]]*\])')
_REPEATED_SIGN_IN_BRACKET = re.compile(r'([+\-])\1+(\d*)')


def _repeated_sign_to_count(match):
    """Turn a run such as '++' into '+2'; keep an explicit trailing number ('++2' -> '+2')."""
    sign, digits = match.group(1), match.group(2)
    if digits:
        return sign + digits
    return f"{sign}{len(match.group(0))}"


def _normalize_repeated_signs(smiles):
    """Remove runs of repeated '+'/'-' without changing the charge written in bracket atoms.

    Inside a bracket atom a run of n signs is rewritten as the sign followed by
    n (``[Fe++]`` -> ``[Fe+2]``). Outside bracket atoms a run is collapsed to a
    single character (``CC--C`` -> ``CC-C``).
    """
    pieces = _BRACKET_ATOM_SPLIT.split(smiles)
    for index, piece in enumerate(pieces):
        # re.split with one capturing group puts the bracket atoms at odd indices.
        if index % 2 == 1:
            pieces[index] = _REPEATED_SIGN_IN_BRACKET.sub(_repeated_sign_to_count, piece)
        else:
            pieces[index] = re.sub(r'([+\-])\1+', r'\1', piece)
    return ''.join(pieces)


def _apply_smiles_repairs(smiles):
    """Apply every textual repair strategy, in order, and return the rewritten string."""
    repaired = smiles

    # Strategy A: upper-case element typos, only when not adjacent to letters/brackets
    # on the left and not followed by a letter, ']' or a charge sign.
    for typo, correction in _ELEMENT_TYPO_REPLACEMENTS:
        pattern = r'(?<![A-Za-z\[\]])' + re.escape(typo) + r'(?![A-Za-z\]\+\-])'
        repaired = re.sub(pattern, correction, repaired)

    # Strategy B: "[N+H]" style -> "[NH+]" style.
    for pattern, replacement in _CHARGE_HYDROGEN_REPLACEMENTS:
        repaired = re.sub(pattern, replacement, repaired)

    # Strategy C: drop a hydrogen count written after the charge of a bracket atom.
    repaired = re.sub(r'\[([A-Za-z@\*]+[+\-]\d*)H\d*\]', r'[\1]', repaired)

    # Strategy D: balance round brackets by padding the short side.
    missing_close = repaired.count('(') - repaired.count(')')
    if missing_close > 0:
        repaired += ')' * missing_close
    elif missing_close < 0:
        repaired = '(' * (-missing_close) + repaired

    # Strategy E: balance square brackets the same way.
    missing_square = repaired.count('[') - repaired.count(']')
    if missing_square > 0:
        repaired += ']' * missing_square
    elif missing_square < 0:
        repaired = '[' * (-missing_square) + repaired

    # Strategy G: repeated charge signs. Runs inside bracket atoms keep their
    # meaning ("++" is rewritten as "+2") instead of being collapsed to one sign.
    return _normalize_repeated_signs(repaired)


def diagnose_and_fix_smiles(smiles: str) -> dict:
    """
    Diagnose a single SMILES string and try to repair it if RDKit rejects it.

    The stripped input is parsed first. If that fails, a fixed sequence of
    textual repairs is applied and the rewritten string is parsed again.

    Args:
        smiles: The SMILES string to check. Non-string or blank input is
            reported as unfixable without calling RDKit.

    Returns:
        A dict ``{'status': ..., 'smiles': ..., 'error': ...}`` where

        * ``status`` is ``'valid'`` (parsed without changes), ``'fixed'``
          (parsed after the rewrites changed the string) or ``'unfixable'``;
        * ``smiles`` is the usable (stripped or rewritten) string, or None when
          unfixable;
        * ``error`` is a message when unfixable, otherwise None.
    """
    if not isinstance(smiles, str) or not smiles.strip():
        return {'status': 'unfixable', 'smiles': None, 'error': 'Input is not a valid string or is empty.'}

    smiles = smiles.strip()

    # 1. Try the original string. An exception message is remembered so it can
    #    be reported if the repaired string fails with an exception as well.
    original_error = None
    try:
        if Chem.MolFromSmiles(smiles, sanitize=True) is not None:
            return {'status': 'valid', 'smiles': smiles, 'error': None}
    except Exception as e:
        original_error = str(e)

    # 2. Apply the repair strategies.
    smiles_fixed = _apply_smiles_repairs(smiles)

    # 3. Try the repaired string.
    try:
        mol = Chem.MolFromSmiles(smiles_fixed, sanitize=True)
    except Exception as e:
        first_failure = original_error if original_error is not None else 'Unknown'
        return {'status': 'unfixable', 'smiles': None, 'error': f"Original: {first_failure}, After fix: {e}"}

    if mol is not None:
        if smiles_fixed != smiles:
            return {'status': 'fixed', 'smiles': smiles_fixed, 'error': None}
        return {'status': 'valid', 'smiles': smiles, 'error': None}

    # Every repair attempt still left a string RDKit cannot read.
    return {'status': 'unfixable', 'smiles': None, 'error': 'All repair strategies failed.'}


def validate_dataframe(df, filename):
    """Check that a loaded frame has the minimal structure the pipeline needs.

    Args:
        df: The DataFrame read from ``filename``.
        filename: File name, used only in the warning messages.

    Returns:
        True when the frame is non-empty and has both a 'smiles' and a 'Label'
        column. Otherwise a warning is logged for the first failed check and
        False is returned.
    """
    if df.empty:
        logger.warning(f"文件 {filename} 是空的")
        return False

    # Columns are checked in this order; only the first missing one is reported.
    for required in ('smiles', 'Label'):
        if required not in df.columns:
            logger.warning(f"文件 {filename} 缺少 '{required}' 列")
            return False

    return True


def _unfixable_report_path(output_path):
    """Return the path of the unfixable-SMILES report that accompanies ``output_path``.

    Only the extension of the final path component is replaced, so directory
    names containing '.csv' are left alone and an output path without a
    '.csv' extension can never collide with the report path.
    """
    root, _extension = os.path.splitext(output_path)
    return f"{root}_unfixable_report.csv"


def _ensure_parent_dir(path):
    """Create the directory that will contain ``path`` when it does not exist yet."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def process_and_merge_datasets_with_diagnostics(input_dir, output_path, files_to_skip=None, report_unfixable=True):
    """
    Load, repair, diagnose, clean and merge the per-task CSV files of a directory.

    Each ``*.csv`` file in ``input_dir`` is expected to hold a 'smiles' (or
    'SMILES') column and a 'Label' column. Every SMILES is passed through
    ``diagnose_and_fix_smiles``; rows that are valid or fixed get an InChIKey
    and their 'Label' column is renamed to the task name (the file name up to
    its first dot). The per-task frames are outer-merged on
    ``['smiles', 'Inchikey']`` and written to ``output_path``.

    Args:
        input_dir: Directory that contains the task CSV files.
        output_path: Path of the merged CSV to write.
        files_to_skip: Task names to skip; compared case-insensitively after
            stripping whitespace. None means nothing is skipped.
        report_unfixable: When True, unfixable SMILES are collected and written
            to ``<output stem>_unfixable_report.csv`` next to ``output_path``.

    Returns:
        None. Results are written to disk and progress is reported via ``logger``.
    """
    if files_to_skip is None:
        files_to_skip = []

    skip_list_lower = [item.lower().strip() for item in files_to_skip]
    all_processed_dfs = []
    unfixable_report = []
    processing_stats = {
        'total_files': 0,
        'processed_files': 0,
        'skipped_files': 0,
        'error_files': 0,
        'total_molecules': 0,
        'valid_molecules': 0,
        'fixed_molecules': 0,
        'unfixable_molecules': 0
    }

    logger.info(f"开始处理目录 '{input_dir}' ...")

    if not os.path.exists(input_dir):
        logger.error(f"输入目录不存在: {input_dir}")
        return

    csv_files = [f for f in os.listdir(input_dir) if f.endswith('.csv')]
    processing_stats['total_files'] = len(csv_files)

    for filename in csv_files:
        # The task name is everything before the first dot of the file name.
        task_name_original = filename.split('.')[0]
        task_name_normalized = task_name_original.strip().lower()

        if task_name_normalized in skip_list_lower:
            logger.info(f"跳过文件: {filename}")
            processing_stats['skipped_files'] += 1
            continue

        logger.info(f"正在处理: {filename}")
        file_path = os.path.join(input_dir, filename)

        try:
            df = pd.read_csv(file_path)

            # Normalise the SMILES column name.
            if 'SMILES' in df.columns:
                df.rename(columns={'SMILES': 'smiles'}, inplace=True)

            if not validate_dataframe(df, filename):
                processing_stats['error_files'] += 1
                continue

            original_count = len(df)
            processing_stats['total_molecules'] += original_count

            # Diagnose/repair every SMILES and unpack the result dicts into columns.
            logger.info(f"  诊断和修复 {original_count} 个分子...")
            diagnostics = df['smiles'].apply(diagnose_and_fix_smiles)

            df['smiles_fixed'] = diagnostics.apply(lambda x: x['smiles'])
            df['status'] = diagnostics.apply(lambda x: x['status'])
            df['error_msg'] = diagnostics.apply(lambda x: x['error'])

            status_counts = df['status'].value_counts()
            processing_stats['valid_molecules'] += status_counts.get('valid', 0)
            processing_stats['fixed_molecules'] += status_counts.get('fixed', 0)
            processing_stats['unfixable_molecules'] += status_counts.get('unfixable', 0)

            logger.info(f"  结果: {status_counts.get('valid', 0)} 有效, {status_counts.get('fixed', 0)} 已修复, {status_counts.get('unfixable', 0)} 无法修复")

            # Remember every unfixable SMILES for the report.
            if report_unfixable:
                unfixable_df = df[df['status'] == 'unfixable']
                for original, message in zip(unfixable_df['smiles'], unfixable_df['error_msg']):
                    unfixable_report.append({
                        'file': filename,
                        'original_smiles': original,
                        'error': message
                    })

            # Keep only the valid and the repaired molecules.
            cleaned_df = df[df['status'].isin(['valid', 'fixed'])].copy()

            if cleaned_df.empty:
                logger.warning(f"  文件 {filename} 清洗后没有有效数据")
                continue

            logger.info(f"  生成 {len(cleaned_df)} 个分子的InChIKey...")
            cleaned_df['Inchikey'] = smiles_to_inchikey(cleaned_df['smiles_fixed'])

            # Rows without an InChIKey cannot be merged and are dropped.
            cleaned_df = cleaned_df.dropna(subset=['Inchikey'])

            if cleaned_df.empty:
                logger.warning(f"  文件 {filename} 生成InChIKey后没有有效数据")
                continue

            # Final per-task frame: smiles, Inchikey and the label under the task name.
            cleaned_df.rename(columns={'Label': task_name_original}, inplace=True)
            final_df = cleaned_df[['smiles_fixed', 'Inchikey', task_name_original]].copy()
            final_df.rename(columns={'smiles_fixed': 'smiles'}, inplace=True)

            all_processed_dfs.append(final_df)
            processing_stats['processed_files'] += 1

            logger.info(f"  成功处理 {len(final_df)} 个分子")

        except Exception as e:
            logger.error(f"处理文件 {filename} 时发生严重错误: {e}")
            processing_stats['error_files'] += 1

    # Merge the per-task frames.
    if all_processed_dfs:
        logger.info("开始合并所有已处理的数据集...")
        try:
            # NOTE(review): rows are not de-duplicated before merging, so a key that
            # occurs several times within one task yields several merged rows;
            # the "unique molecules" count logged below assumes this does not happen.
            merged_df = reduce(
                lambda left, right: pd.merge(left, right, on=['smiles', 'Inchikey'], how='outer'),
                all_processed_dfs
            )

            _ensure_parent_dir(output_path)

            merged_df.to_csv(output_path, index=False)
            logger.info(f"最终合并的数据已保存至: {output_path}")
            logger.info(f"合并后的数据集包含 {len(merged_df)} 个唯一分子")

        except Exception as e:
            logger.error(f"合并数据时发生错误: {e}")
    else:
        logger.warning("没有可合并的数据")

    # Processing statistics.
    logger.info("处理统计:")
    logger.info(f"  总文件数: {processing_stats['total_files']}")
    logger.info(f"  成功处理: {processing_stats['processed_files']}")
    logger.info(f"  跳过文件: {processing_stats['skipped_files']}")
    logger.info(f"  错误文件: {processing_stats['error_files']}")
    logger.info(f"  总分子数: {processing_stats['total_molecules']}")
    logger.info(f"  有效分子: {processing_stats['valid_molecules']}")
    logger.info(f"  修复分子: {processing_stats['fixed_molecules']}")
    logger.info(f"  无法修复: {processing_stats['unfixable_molecules']}")

    # Write the report of unfixable SMILES.
    if unfixable_report:
        logger.info(f"发现 {len(unfixable_report)} 个无法修复的SMILES")
        report_df = pd.DataFrame(unfixable_report)

        report_path = _unfixable_report_path(output_path)
        # The output directory may not exist yet when nothing was merged.
        _ensure_parent_dir(report_path)
        report_df.to_csv(report_path, index=False)
        logger.info(f"无法修复的SMILES报告已保存至: {report_path}")

        # NOTE(review): samples are only logged when the report has at most 20
        # entries, and then only the first 10; confirm this threshold is intended.
        if len(unfixable_report) <= 20:
            logger.info("无法修复的SMILES样例:")
            for i, record in enumerate(unfixable_report[:10]):
                # The original value may be NaN or another non-string, so it is
                # converted before slicing.
                logger.info(f"  {i+1}. 文件: {record['file']}, SMILES: {str(record['original_smiles'])[:50]}...")


# --- 测试和执行 ---
if __name__ == '__main__':
    # Smoke-test diagnose_and_fix_smiles on a few hand-picked strings and log each outcome.
    logger.info("测试诊断功能...")
    test_cases = [
        'ClN12([C-]3CC(OC(=O)C(O)(c4ccccc4)c4ccccc4)C[C-]1CC3)[C-2]CC[C-2]2',
        'CCO',  # simple, valid SMILES
        'C[N+H3]Cl-',  # notation that needs repair
        'Invalid_SMILES',  # not a SMILES string
        '',  # empty string
    ]
    
    for test_smiles in test_cases:
        result = diagnose_and_fix_smiles(test_smiles)
        logger.info(f"原始: '{test_smiles}' -> {result['status']} -> '{result['smiles']}'")
        # The error entry is only logged when the diagnosis produced one.
        if result['error']:
            logger.info(f"  错误: {result['error']}")
    
    logger.info("-" * 50)
    
    # Configuration of the main run: input directory, merged output file and
    # the task names (file names without extension) to skip.
    INPUT_REG_DIR = 'ademt_data/wash_reg'
    OUTPUT_REG_FILE = 'ademt_data/reg1.csv'
    SKIP_FILES = ['logD', 'logP', 'logS', 'pka_acidic', 'pka_basic']

    # Run the main processing pipeline (the call is active; comment it out to skip the run).
    process_and_merge_datasets_with_diagnostics(INPUT_REG_DIR, OUTPUT_REG_FILE, SKIP_FILES, report_unfixable=True)

2025-08-15 22:09:21,818 - INFO - 测试诊断功能...
2025-08-15 22:09:21,823 - INFO - 原始: 'ClN12([C-]3CC(OC(=O)C(O)(c4ccccc4)c4ccccc4)C[C-]1CC3)[C-2]CC[C-2]2' -> unfixable -> 'None'
2025-08-15 22:09:21,823 - INFO -   错误: All repair strategies failed.
2025-08-15 22:09:21,823 - INFO - 原始: 'CCO' -> valid -> 'CCO'
2025-08-15 22:09:21,824 - INFO - 原始: 'C[N+H3]Cl-' -> unfixable -> 'None'
2025-08-15 22:09:21,824 - INFO -   错误: All repair strategies failed.
2025-08-15 22:09:21,824 - INFO - 原始: 'Invalid_SMILES' -> unfixable -> 'None'
2025-08-15 22:09:21,825 - INFO -   错误: All repair strategies failed.
2025-08-15 22:09:21,825 - INFO - 原始: '' -> unfixable -> 'None'
2025-08-15 22:09:21,825 - INFO -   错误: Input is not a valid string or is empty.
2025-08-15 22:09:21,826 - INFO - --------------------------------------------------
2025-08-15 22:09:21,826 - INFO - 开始处理目录 'ademt_data/wash_reg' ...
2025-08-15 22:09:21,826 - INFO - 跳过文件: logp.csv
2025-08-15 22:09:21,826 - INFO - 跳过文件: pka_acidic.csv
2025-08-15 22:0

In [5]:
# import pandas as pd
# import os
# import re
# from rdkit import Chem
# from rdkit import rdBase
# from functools import reduce

# # 禁用RDKit的详细错误日志，以便我们自己捕获和处理
# rdBase.DisableLog('rdApp.error')

# # --- 核心辅助函数 (带有诊断功能) ---

# def smiles_to_inchikey(smiles_list):
#     """将SMILES列表转换为InChIKey列表。"""
#     # ... (代码与之前相同) ...
#     inchikeys = []
#     for smiles in smiles_list:
#         if pd.isna(smiles):
#             inchikeys.append(None)
#             continue
#         mol = Chem.MolFromSmiles(smiles)
#         if mol is not None:
#             inchikey = Chem.MolToInchiKey(mol)
#             inchikeys.append(inchikey)
#         else:
#             inchikeys.append(None)
#     return inchikeys


# def diagnose_and_fix_smiles(smiles: str) -> dict:
#     """
#     诊断并尝试修复单个SMILES字符串。

#     Returns:
#         一个包含诊断结果的字典:
#         {'status': 'valid'/'fixed'/'unfixable', 'smiles': str/None, 'error': str/None}
#     """
#     if not isinstance(smiles, str) or not smiles:
#         return {'status': 'unfixable', 'smiles': None, 'error': 'Input is not a valid string.'}

#     # 1. 尝试原始SMILES
#     try:
#         if Chem.MolFromSmiles(smiles, sanitize=True) is None:
#             return {'status': 'valid', 'smiles': smiles, 'error': None}
#     except Exception as e:
#         pass # 继续尝试修复

#     # 2. 应用一系列修复策略
#     smiles_fixed = smiles
    
#     # 策略 A: 处理常见的元素符号错误
#     element_typo_replacements = {'IN': '[In]'}
#     for typo, correction in element_typo_replacements.items():
#         smiles_fixed = re.sub(r'(^|[^a-zA-Z])' + re.escape(typo) + r'($|[^a-zA-Z])', r'\1' + correction + r'\2', smiles_fixed)

#     # 策略 B: 转换 [N+H] 风格为 [NH+] 风格
#     charge_hydrogen_format_replacements = {
#         '[N+H]': '[NH+]', '[N+H2]': '[NH2+]', '[N+H3]': '[NH3+]', '[n+H]': '[nH+]', '[O+H]': '[OH+]'
#     }
#     for old, new in charge_hydrogen_format_replacements.items():
#         smiles_fixed = smiles_fixed.replace(old, new)
    
#     # 策略 C: 移除其他带电原子中不必要的显式氢原子
#     smiles_fixed = re.sub(r'\[([A-Za-z@\*]+[+\-]\d+)H\d*\]', r'[\1]', smiles_fixed)

#     # 3. 尝试解析修复后的SMILES并捕获错误
#     try:
#         mol = Chem.MolFromSmiles(smiles_fixed, sanitize=True)
#         if mol is not None:
#             # 检查修复是否真的改变了字符串
#             if smiles_fixed != smiles:
#                 return {'status': 'fixed', 'smiles': smiles_fixed, 'error': None}
#             else:
#                 # 这种情况理论上不应该发生，因为原始的已经检查过了
#                 return {'status': 'valid', 'smiles': smiles, 'error': None}
#     except Exception as e:
#         # 如果修复后仍然失败，捕获错误信息
#         return {'status': 'unfixable', 'smiles': None, 'error': str(e)}

#     # 如果修复后没有变化且原始SMILES无效
#     return {'status': 'unfixable', 'smiles': None, 'error': 'Initial parsing failed and no fixes were applicable.'}


# # --- 主处理流程 ---

# def process_and_merge_datasets_with_diagnostics(input_dir, output_path, files_to_skip, report_unfixable=True):
#     """
#     加载、修复、诊断、清洗并合并CSV文件。
#     """
#     skip_list_lower = [item.lower().strip() for item in files_to_skip]
#     all_processed_dfs = []
#     unfixable_report = []

#     print(f"开始处理目录 '{input_dir}' ...")
    
#     for filename in os.listdir(input_dir):
#         if not filename.endswith('.csv'): continue
        
#         task_name_original = filename.split('.')[0]
#         task_name_normalized = task_name_original.strip().lower()

#         if task_name_normalized in skip_list_lower: continue
            
#         print(f"\n--- 正在处理: {filename} ---")
#         file_path = os.path.join(input_dir, filename)

#         try:
#             df = pd.read_csv(file_path)
#             if 'SMILES' in df.columns: df.rename(columns={'SMILES': 'smiles'}, inplace=True)
#             if 'smiles' not in df.columns or 'Label' not in df.columns: continue

#             # 应用诊断和修复函数
#             diagnostics = df['smiles'].apply(diagnose_and_fix_smiles)
#             df['smiles_fixed'] = diagnostics.apply(lambda x: x['smiles'])
#             df['status'] = diagnostics.apply(lambda x: x['status'])
#             df['error_msg'] = diagnostics.apply(lambda x: x['error'])

#             # 报告无法修复的SMILES
#             if report_unfixable:
#                 unfixable_df = df[df['status'] == 'unfixable']
#                 for _, row in unfixable_df.iterrows():
#                     unfixable_report.append({
#                         'file': filename,
#                         'original_smiles': row['smiles'],
#                         'error': row['error_msg']
#                     })

#             # 清洗数据
#             cleaned_df = df.dropna(subset=['smiles_fixed']).copy()
#             if cleaned_df.empty: continue

#             # 后续处理...
#             cleaned_df['Inchikey'] = smiles_to_inchikey(cleaned_df['smiles_fixed'])
#             cleaned_df.rename(columns={'Label': task_name_original}, inplace=True)
#             final_df = cleaned_df[['smiles_fixed', 'Inchikey', task_name_original]]
#             final_df.rename(columns={'smiles_fixed': 'smiles'}, inplace=True)
#             all_processed_dfs.append(final_df)

#         except Exception as e:
#             print(f"  -> 处理文件 {filename} 时发生严重错误: {e}")

#     # 合并数据...
#     if all_processed_dfs:
#         print("\n--- 开始合并所有已处理的数据集 ---")
#         merged_df = reduce(lambda left, right: pd.merge(left, right, on=['smiles', 'Inchikey'], how='outer'), all_processed_dfs)
#         merged_df.to_csv(output_path, index=False)
#         print(f"最终合并的数据已保存至: {output_path}")
#     else:
#         print("\n没有可合并的数据。")

#     # 打印无法修复的SMILES报告
#     if unfixable_report:
#         print("\n--- 无法修复的SMILES报告 ---")
#         report_df = pd.DataFrame(unfixable_report)
#         print(report_df.to_string())


# # --- 执行 ---
# if __name__ == '__main__':
#     # --- 测试新的诊断函数 ---
#     print("--- 测试诊断功能 ---")
#     test_smiles = 'BrN([C-2]C)([C-2]C)([C-2]C)[C-2]C'
#     result = diagnose_and_fix_smiles(test_smiles)
#     print(f"原始SMILES: {test_smiles}")
#     print(f"诊断结果: {result}")
#     print("-" * 20)
    
#     # --- 运行您的主流程 ---
#     INPUT_REG_DIR = 'ademt_data/wash_reg'
#     OUTPUT_REG_FILE = 'ademt_data/reg.csv'
#     # SKIP_FILES = ['logD', 'logP', 'logS', 'pka_acidic', 'pka_basic']
#     SKIP_FILES = []

#     process_and_merge_datasets_with_diagnostics(INPUT_REG_DIR, OUTPUT_REG_FILE, SKIP_FILES, report_unfixable=True)

In [4]:
import pandas as pd
import os
import re
from rdkit import Chem
from rdkit import rdBase
from functools import reduce
from sklearn.model_selection import train_test_split
import numpy as np

# Silence RDKit's own error logging so that failures can be caught and handled by this code instead.
rdBase.DisableLog('rdApp.error')

# --- 核心辅助函数 (带有诊断功能) ---

def smiles_to_inchikey(smiles_list):
    """Convert an iterable of SMILES strings into a list of InChIKeys.

    The result has one entry per input entry, in the same order. An entry is
    None when the input is missing (None/NaN) or when RDKit cannot parse the
    SMILES. Exceptions raised by RDKit are not caught here.
    """
    keys = []
    for candidate in smiles_list:
        # Missing values never reach RDKit.
        molecule = None if pd.isna(candidate) else Chem.MolFromSmiles(candidate)
        keys.append(None if molecule is None else Chem.MolToInchiKey(molecule))
    return keys


def diagnose_and_fix_smiles(smiles: str) -> dict:
    """
    Diagnose a single SMILES string and try to repair it if RDKit rejects it.

    Returns:
        A dict ``{'status': ..., 'smiles': ..., 'error': ...}`` where ``status``
        is 'valid' (parsed unchanged), 'fixed' (parsed after rewriting) or
        'unfixable'; ``smiles`` is the usable string or None; ``error`` is a
        message when unfixable, otherwise None.
    """
    def _outcome(status, value, error):
        return {'status': status, 'smiles': value, 'error': error}

    if not (isinstance(smiles, str) and smiles):
        return _outcome('unfixable', None, 'Input is not a valid string.')

    # 1. Try the string as given; an exception here just leads on to the repairs.
    try:
        if Chem.MolFromSmiles(smiles, sanitize=True) is not None:
            return _outcome('valid', smiles, None)
    except Exception:
        pass

    # 2. Textual repairs, applied in a fixed order.
    candidate = smiles

    # Strategy A: upper-case element typos that are not part of a longer word.
    for typo, correction in (('IN', '[In]'),):
        candidate = re.sub(
            r'(^|[^a-zA-Z])' + re.escape(typo) + r'($|[^a-zA-Z])',
            r'\1' + correction + r'\2',
            candidate,
        )

    # Strategy B: "[N+H]" style -> "[NH+]" style.
    hydrogen_after_charge = (
        ('[N+H]', '[NH+]'),
        ('[N+H2]', '[NH2+]'),
        ('[N+H3]', '[NH3+]'),
        ('[n+H]', '[nH+]'),
        ('[O+H]', '[OH+]'),
    )
    for old, new in hydrogen_after_charge:
        candidate = candidate.replace(old, new)

    # Strategy C: drop a hydrogen count written after a numbered charge.
    candidate = re.sub(r'\[([A-Za-z@\*]+[+\-]\d+)H\d*\]', r'[\1]', candidate)

    # 3. Parse the rewritten string; an exception message becomes the error text.
    try:
        reparsed = Chem.MolFromSmiles(candidate, sanitize=True)
    except Exception as e:
        return _outcome('unfixable', None, str(e))

    if reparsed is None:
        # Nothing RDKit can read came out of the rewrites.
        return _outcome('unfixable', None, 'Initial parsing failed and no fixes were applicable.')

    if candidate == smiles:
        # Not expected in practice, because the unchanged string was already tried above.
        return _outcome('valid', smiles, None)
    return _outcome('fixed', candidate, None)


def split_dataset(df, task_name, test_size=0.1, val_size=0.1, random_state=42):
    """
    Split a task's data into train, validation and test sets.

    With the default sizes the split is 8:1:1. The test set is taken first;
    the validation set is then taken from the remainder, with its fraction
    rescaled so that ``val_size`` refers to the full data set.

    Args:
        df: DataFrame holding the task's data.
        task_name: Task name, used only in the warning message.
        test_size: Fraction of the data for the test set.
        val_size: Fraction of the data for the validation set.
        random_state: Random seed passed to both splits.

    Returns:
        (train_df, val_df, test_df). With fewer than 10 rows no split is done
        and a copy of ``df`` is returned together with two empty frames. When
        fewer than 5 rows remain after the test split, the remainder becomes
        the training set and the validation frame is empty.
    """
    sample_count = len(df)
    if sample_count < 10:
        # Too little data to split.
        print(f"  警告: {task_name} 数据量太少 ({sample_count} 样本)，不进行分割")
        return df.copy(), pd.DataFrame(), pd.DataFrame()

    # Test set first; no stratification is requested (the original notes these
    # are regression tasks).
    remainder, held_out = train_test_split(
        df,
        test_size=test_size,
        random_state=random_state,
        stratify=None
    )

    if len(remainder) < 5:
        # Too little left for a separate validation set.
        return remainder.copy(), pd.DataFrame(), held_out

    # val_size is relative to the full data set, so rescale it to the remainder.
    fitting, validation = train_test_split(
        remainder,
        test_size=val_size / (1 - test_size),
        random_state=random_state
    )

    return fitting, validation, held_out


def process_and_split_datasets(input_dir, output_dir, files_to_skip=None, split_ratios=(0.8, 0.1, 0.1), random_state=42, report_unfixable=True):
    """
    加载、修复、诊断、清洗、分割并合并CSV文件。
    
    Args:
        input_dir: 输入目录路径
        output_dir: 输出目录路径
        files_to_skip: 要跳过的文件列表
        split_ratios: 分割比例 (train, val, test)
        random_state: 随机种子
        report_unfixable: 是否报告无法修复的SMILES
    """
    if files_to_skip is None:
        files_to_skip = []
    
    skip_list_lower = [item.lower().strip() for item in files_to_skip]
    
    # 存储所有任务的分割数据
    all_train_dfs = []
    all_val_dfs = []
    all_test_dfs = []
    
    unfixable_report = []
    split_summary = []

    print(f"开始处理目录 '{input_dir}' ...")
    print(f"分割比例: 训练集 {split_ratios[0]:.1%}, 验证集 {split_ratios[1]:.1%}, 测试集 {split_ratios[2]:.1%}")
    
    # 创建输出目录
    os.makedirs(output_dir, exist_ok=True)
    
    for filename in os.listdir(input_dir):
        if not filename.endswith('.csv'): 
            continue
        
        task_name_original = filename.split('.')[0]
        task_name_normalized = task_name_original.strip().lower()

        if task_name_normalized in skip_list_lower: 
            print(f"跳过文件: {filename}")
            continue
            
        print(f"\n--- 正在处理: {filename} ---")
        file_path = os.path.join(input_dir, filename)

        try:
            df = pd.read_csv(file_path)
            if 'SMILES' in df.columns: 
                df.rename(columns={'SMILES': 'smiles'}, inplace=True)
            if 'smiles' not in df.columns or 'Label' not in df.columns: 
                print(f"  跳过: 缺少必要的列 (smiles 或 Label)")
                continue

            print(f"  原始数据: {len(df)} 样本")

            # 应用诊断和修复函数
            diagnostics = df['smiles'].apply(diagnose_and_fix_smiles)
            df['smiles_fixed'] = diagnostics.apply(lambda x: x['smiles'])
            df['status'] = diagnostics.apply(lambda x: x['status'])
            df['error_msg'] = diagnostics.apply(lambda x: x['error'])

            # 报告无法修复的SMILES
            if report_unfixable:
                unfixable_df = df[df['status'] == 'unfixable']
                for _, row in unfixable_df.iterrows():
                    unfixable_report.append({
                        'file': filename,
                        'original_smiles': row['smiles'],
                        'error': row['error_msg']
                    })

            # 清洗数据
            cleaned_df = df.dropna(subset=['smiles_fixed']).copy()
            if cleaned_df.empty: 
                print(f"  清洗后无有效数据")
                continue

            print(f"  清洗后数据: {len(cleaned_df)} 样本")

            # 生成InChIKey和重命名
            cleaned_df['Inchikey'] = smiles_to_inchikey(cleaned_df['smiles_fixed'])
            cleaned_df.rename(columns={'Label': task_name_original}, inplace=True)
            final_df = cleaned_df[['smiles_fixed', 'Inchikey', task_name_original]].copy()
            final_df.rename(columns={'smiles_fixed': 'smiles'}, inplace=True)
            
            # 移除InChIKey生成失败的行
            final_df = final_df.dropna(subset=['Inchikey'])
            
            if final_df.empty:
                print(f"  生成InChIKey后无有效数据")
                continue

            print(f"  最终数据: {len(final_df)} 样本")

            # 分割数据集
            train_df, val_df, test_df = split_dataset(
                final_df, 
                task_name_original, 
                test_size=split_ratios[2], 
                val_size=split_ratios[1], 
                random_state=random_state
            )

            print(f"  分割结果: 训练集 {len(train_df)}, 验证集 {len(val_df)}, 测试集 {len(test_df)}")

            # 记录分割统计
            split_summary.append({
                'task': task_name_original,
                'total': len(final_df),
                'train': len(train_df),
                'val': len(val_df),
                'test': len(test_df),
                'train_pct': len(train_df) / len(final_df) * 100 if len(final_df) > 0 else 0,
                'val_pct': len(val_df) / len(final_df) * 100 if len(final_df) > 0 else 0,
                'test_pct': len(test_df) / len(final_df) * 100 if len(final_df) > 0 else 0
            })

            # 保存单个任务的分割结果
            task_output_dir = os.path.join(output_dir, 'individual_tasks', task_name_original)
            os.makedirs(task_output_dir, exist_ok=True)
            
            if not train_df.empty:
                train_df.to_csv(os.path.join(task_output_dir, 'train.csv'), index=False)
            if not val_df.empty:
                val_df.to_csv(os.path.join(task_output_dir, 'val.csv'), index=False)
            if not test_df.empty:
                test_df.to_csv(os.path.join(task_output_dir, 'test.csv'), index=False)

            # 添加到总的数据集列表
            if not train_df.empty:
                all_train_dfs.append(train_df)
            if not val_df.empty:
                all_val_dfs.append(val_df)
            if not test_df.empty:
                all_test_dfs.append(test_df)

        except Exception as e:
            print(f"  -> 处理文件 {filename} 时发生严重错误: {e}")

    # 合并所有任务的数据集
    print(f"\n--- 合并所有任务的数据集 ---")
    
    def merge_datasets(df_list, set_name):
        """合并数据集列表"""
        if not df_list:
            print(f"  {set_name}: 没有数据可合并")
            return pd.DataFrame()
        
        print(f"  {set_name}: 合并 {len(df_list)} 个任务的数据")
        merged_df = reduce(
            lambda left, right: pd.merge(left, right, on=['smiles', 'Inchikey'], how='outer'), 
            df_list
        )
        print(f"    合并后: {len(merged_df)} 个唯一分子")
        return merged_df

    # 合并训练集
    merged_train = merge_datasets(all_train_dfs, "训练集")
    if not merged_train.empty:
        train_output_path = os.path.join(output_dir, 'train.csv')
        merged_train.drop(['Inchikey'], axis=1, inplace=True, errors='ignore')  # 移除InChIKey列
        merged_train.to_csv(train_output_path, index=False)
        print(f"    保存至: {train_output_path}")

    # 合并验证集
    merged_val = merge_datasets(all_val_dfs, "验证集")
    if not merged_val.empty:
        val_output_path = os.path.join(output_dir, 'val.csv')
        merged_val.drop(['Inchikey'], axis=1, inplace=True, errors='ignore')  # 移除InChIKey列
        merged_val.to_csv(val_output_path, index=False)
        print(f"    保存至: {val_output_path}")

    # 合并测试集
    merged_test = merge_datasets(all_test_dfs, "测试集")
    if not merged_test.empty:
        test_output_path = os.path.join(output_dir, 'test.csv')
        merged_test.drop(['Inchikey'], axis=1, inplace=True, errors='ignore')  # 移除InChIKey列
        merged_test.to_csv(test_output_path, index=False)
        print(f"    保存至: {test_output_path}")

    # 保存分割统计报告
    if split_summary:
        summary_df = pd.DataFrame(split_summary)
        summary_path = os.path.join(output_dir, 'split_summary.csv')
        summary_df.to_csv(summary_path, index=False)
        
        print(f"\n--- 分割统计报告 ---")
        print(summary_df.to_string(index=False))
        print(f"详细统计已保存至: {summary_path}")

    # 合并所有数据（不分割）用于对比
    print(f"\n--- 生成完整合并数据集 ---")
    if all_train_dfs or all_val_dfs or all_test_dfs:
        # 收集所有任务的完整数据
        all_complete_dfs = []
        
        for filename in os.listdir(input_dir):
            if not filename.endswith('.csv'): 
                continue
            
            task_name_original = filename.split('.')[0]
            task_name_normalized = task_name_original.strip().lower()

            if task_name_normalized in skip_list_lower: 
                continue
            
            # 重新读取并处理（这次不分割）
            file_path = os.path.join(input_dir, filename)
            try:
                df = pd.read_csv(file_path)
                if 'SMILES' in df.columns: 
                    df.rename(columns={'SMILES': 'smiles'}, inplace=True)
                if 'smiles' not in df.columns or 'Label' not in df.columns: 
                    continue

                # 诊断和修复
                diagnostics = df['smiles'].apply(diagnose_and_fix_smiles)
                df['smiles_fixed'] = diagnostics.apply(lambda x: x['smiles'])
                cleaned_df = df.dropna(subset=['smiles_fixed']).copy()
                
                if cleaned_df.empty:
                    continue

                # 处理最终数据
                cleaned_df['Inchikey'] = smiles_to_inchikey(cleaned_df['smiles_fixed'])
                cleaned_df.rename(columns={'Label': task_name_original}, inplace=True)
                final_df = cleaned_df[['smiles_fixed', 'Inchikey', task_name_original]].copy()
                final_df.rename(columns={'smiles_fixed': 'smiles'}, inplace=True)
                final_df = final_df.dropna(subset=['Inchikey'])
                
                if not final_df.empty:
                    all_complete_dfs.append(final_df)
                    
            except Exception as e:
                continue

        if all_complete_dfs:
            complete_merged = reduce(
                lambda left, right: pd.merge(left, right, on=['smiles', 'Inchikey'], how='outer'), 
                all_complete_dfs
            )
            complete_path = os.path.join(output_dir, 'complete_dataset.csv')
            complete_merged.to_csv(complete_path, index=False)
            print(f"完整数据集: {len(complete_merged)} 个分子，保存至: {complete_path}")

    # 打印无法修复的SMILES报告
    if unfixable_report:
        print("\n--- 无法修复的SMILES报告 ---")
        report_df = pd.DataFrame(unfixable_report)
        unfixable_path = os.path.join(output_dir, 'unfixable_smiles_report.csv')
        report_df.to_csv(unfixable_path, index=False)
        print(f"发现 {len(unfixable_report)} 个无法修复的SMILES")
        print(f"详细报告已保存至: {unfixable_path}")
        
        if len(unfixable_report) <= 10:
            print(report_df.to_string(index=False))

    print(f"\n=== 处理完成 ===")
    print(f"所有结果保存在目录: {output_dir}")
    print(f"  - train.csv: 合并的训练集")
    print(f"  - val.csv: 合并的验证集") 
    print(f"  - test.csv: 合并的测试集")
    print(f"  - complete_dataset.csv: 完整合并数据集")
    print(f"  - individual_tasks/: 各任务的单独分割结果")
    print(f"  - split_summary.csv: 分割统计报告")


# --- Run ---
if __name__ == '__main__':
    # --- Quick check of the diagnosis helper on one sample SMILES ---
    print("--- 测试诊断功能 ---")
    test_smiles = 'BrN([C-2]C)([C-2]C)([C-2]C)[C-2]C'
    result = diagnose_and_fix_smiles(test_smiles)
    print(f"原始SMILES: {test_smiles}")
    print(f"诊断结果: {result}")
    print("-" * 50)

    # --- Configuration for the split-and-merge pipeline ---
    INPUT_REG_DIR = 'ademt_data/wash_reg'
    OUTPUT_DIR = 'ademt_data/reg_3'
    SKIP_FILES = []  # files to skip can be listed here, e.g. ['logD', 'logP']
    SPLIT_RATIOS = (0.8, 0.1, 0.1)  # train, validation, test
    RANDOM_STATE = 42  # random seed, so that the result is reproducible

    # Collect the keyword arguments in one place, then launch the pipeline.
    pipeline_options = {
        'input_dir': INPUT_REG_DIR,
        'output_dir': OUTPUT_DIR,
        'files_to_skip': SKIP_FILES,
        'split_ratios': SPLIT_RATIOS,
        'random_state': RANDOM_STATE,
        'report_unfixable': True,
    }
    process_and_split_datasets(**pipeline_options)

--- 测试诊断功能 ---
原始SMILES: BrN([C-2]C)([C-2]C)([C-2]C)[C-2]C
诊断结果: {'status': 'unfixable', 'smiles': None, 'error': 'Initial parsing failed and no fixes were applicable.'}
--------------------------------------------------
开始处理目录 'ademt_data/wash_reg' ...
分割比例: 训练集 80.0%, 验证集 10.0%, 测试集 10.0%

--- 正在处理: logp.csv ---
  原始数据: 12682 样本
  清洗后数据: 12635 样本
  最终数据: 12635 样本
  分割结果: 训练集 10107, 验证集 1264, 测试集 1264

--- 正在处理: pka_acidic.csv ---
  原始数据: 2750 样本
  清洗后数据: 2750 样本
  最终数据: 2750 样本
  分割结果: 训练集 2200, 验证集 275, 测试集 275

--- 正在处理: fu.csv ---
  -> 处理文件 fu.csv 时发生严重错误: 'utf-8' codec can't decode byte 0x83 in position 29: invalid start byte

--- 正在处理: cl-plasma.csv ---
  原始数据: 831 样本
  清洗后数据: 831 样本
  最终数据: 831 样本
  分割结果: 训练集 664, 验证集 83, 测试集 84

--- 正在处理: logs.csv ---
  原始数据: 4797 样本
  清洗后数据: 4797 样本
  最终数据: 4797 样本
  分割结果: 训练集 3837, 验证集 480, 测试集 480

--- 正在处理: logvdss.csv ---
  原始数据: 2440 样本
  清洗后数据: 2433 样本
  最终数据: 2433 样本
  分割结果: 训练集 1945, 验证集 244, 测试集 244

--- 正在处理: pka_basic.csv ---
  原始数

In [6]:
# Start by analysing how sparse the reg1 dataset is
import pandas as pd

df = pd.read_csv('/mnt/newdisk/fuli/ADMET/ademt_data/reg1.csv')
tasks = ['cl-plasma', 'logvdss', 't12', 'mdck', 'caco2', 'ppb', 'cl-int']
n_rows = len(df)  # total number of rows, reused by every ratio below

print('=== REG1 数据集分析 ===')
print(f'总样本数: {n_rows:,}')
print()

# Per task: number of non-missing values and the percentage of rows without one.
print('各任务数据统计:')
task_counts = {}
for task in tasks:
    n_valid = df[task].notna().sum()
    task_counts[task] = n_valid
    missing_rate = (n_rows - n_valid) / n_rows * 100
    print(f'  {task:12}: {n_valid:5,} 样本 ({missing_rate:5.1f}% 缺失)')

print()
print('建议权重分配 (基于样本稀疏程度):')
# Sum of the per-task counts; not referenced again within this cell.
total_samples = sum(task_counts.values())
for task, n_valid in task_counts.items():
    # The sparser the task, the larger its weight:
    # base weight 1.0, at most 4.0 extra.
    sparsity = 1.0 - (n_valid / n_rows)
    suggested_weight = 1.0 + sparsity * 4
    print(f'  {task:12}: 权重 {suggested_weight:.1f}')

=== REG1 数据集分析 ===
总样本数: 23,198

各任务数据统计:
  cl-plasma   :   852 样本 ( 96.3% 缺失)
  logvdss     : 2,448 样本 ( 89.4% 缺失)
  t12         : 3,441 样本 ( 85.2% 缺失)
  mdck        : 1,144 样本 ( 95.1% 缺失)
  caco2       : 6,512 样本 ( 71.9% 缺失)
  ppb         : 4,737 样本 ( 79.6% 缺失)
  cl-int      : 7,893 样本 ( 66.0% 缺失)

建议权重分配 (基于样本稀疏程度):
  cl-plasma   : 权重 4.9
  logvdss     : 权重 4.6
  t12         : 权重 4.4
  mdck        : 权重 4.8
  caco2       : 权重 3.9
  ppb         : 权重 4.2
  cl-int      : 权重 3.6
