# Tutorial 5: gene function prediction

In [None]:
import pandas as pd
import scanpy as sc
import numpy as np
import torch.multiprocessing as mp
import pickle
import os
from argparse import Namespace
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from collections import defaultdict, OrderedDict
from tqdm import tqdm
import gc
import csv

from omics.constants import *
import os
from omics.constants import *
import random
import h5py
from step2_unified_get_gene_emb import process_single_gpu_gene_embeddings
from step3_gene_mapping import convert_gene_names_to_ids
from step4_merge_emb import merge_embedding_csvs, merge_gene_lists

## Get gene list

### Peek gene list from benchmarks

In [None]:
import pickle
import pandas as pd
import numpy as np
import os

def extract_genes_from_benchmark_results(result_files):
    """Walk benchmark result pickles and report the size of each entry.

    Every file is unpickled and, for each key/value pair of the loaded
    mapping, the key and ``len(value)`` are printed. No genes are collected
    yet, so the returned set is always empty. Any error while reading or
    summarising a file is printed and the next file is processed.

    Args:
        result_files: Iterable of paths to pickle files.

    Returns:
        An empty ``set``.
    """
    collected_genes = set()

    for pickle_path in result_files:
        print(f"处理文件: {pickle_path}")

        try:
            with open(pickle_path, 'rb') as handle:
                per_term_results = pickle.load(handle)

            # One entry per disease / biological-process term.
            # NOTE(review): the values are presumably per-embedding result
            # tables; the gene identifiers may live deeper in the structure.
            for term_name, term_table in per_term_results.items():
                print(f"  - {term_name}: {len(term_table)} 个嵌入方法")

        except Exception as err:
            print(f"读取 {pickle_path} 时出错: {err}")

    return collected_genes

# Benchmark result pickles to inspect (paths are relative to the working directory).
result_files = [
    "../gene-embedding-benchmarks/results/gene_level/go_all_holdout_results.pkl",
    "../gene-embedding-benchmarks/results/gene_level/go_all_holdout_results_after_2020.pkl", 
    "../gene-embedding-benchmarks/results/gene_level/omim_holdout_results.pkl"
]

# Take a first look at how each pickle file is laid out.
def inspect_pickle_structure(file_path):
    """Print a structural summary of a pickle file and return its contents.

    For a dict, the number of keys and the first three keys are printed,
    followed by the type of the first value. If that value is a DataFrame its
    shape, columns, leading index labels and head are printed as well. An
    empty dict is reported without inspecting a (non-existent) first value.

    Args:
        file_path: Path to the pickle file.

    Returns:
        The unpickled object, unchanged.

    Raises:
        OSError: If the file cannot be opened.
        Exception: Whatever ``pickle.load`` raises for an unreadable file.
    """
    # NOTE: pickle.load can execute arbitrary code; only open trusted files.
    with open(file_path, 'rb') as f:
        data = pickle.load(f)

    print(f"\n=== {file_path} 结构分析 ===")
    print(f"数据类型: {type(data)}")

    if isinstance(data, dict):
        print(f"字典键的数量: {len(data)}")
        print("前几个键:", list(data.keys())[:3])

        # An empty dict has no first entry; indexing it used to raise IndexError.
        if data:
            first_key = next(iter(data))
            first_value = data[first_key]
            print(f"\n第一个值 ({first_key}) 的类型: {type(first_value)}")

            if isinstance(first_value, pd.DataFrame):
                print(f"DataFrame形状: {first_value.shape}")
                print(f"列名: {first_value.columns.tolist()}")
                print(f"索引前几个: {first_value.index[:3].tolist()}")
                print(f"示例数据:\n{first_value.head()}")

    return data

# Inspect the structure of every result file that is present on disk.
for file_path in result_files:
    if not os.path.exists(file_path):
        print(f"文件不存在: {file_path}")
        continue

    # A file that exists may still fail to unpickle; report it and move on.
    try:
        data = inspect_pickle_structure(file_path)
    except Exception as e:
        print(f"无法读取 {file_path}: {e}")

### Get gene list from benchmarks

In [None]:
import pickle
import pandas as pd
import os
import warnings

def read_pickle_with_compatibility(file_path):
    """Load a pickle file, falling back through several read strategies.

    Strategies are tried in order and the first success is returned:

    1. plain ``pickle.load``;
    2. ``pickle.load`` with all warnings suppressed;
    3. ``pandas.read_pickle``;
    4. ``pickle.load`` followed by rebuilding every DataFrame-like value of a
       dict (anything exposing ``values`` and ``columns``) into a fresh
       ``pd.DataFrame``. Values that cannot be rebuilt are dropped.

    A short message is printed for every strategy that fails.

    Args:
        file_path: Path to the pickle file.

    Returns:
        The loaded object, or ``None`` when every strategy fails.
    """
    # NOTE: unpickling can execute arbitrary code; only use on trusted files.
    # The module-level ``pickle`` import is used throughout. A function-local
    # ``import pickle`` would make the name local to the whole function and
    # turn the first two attempts into guaranteed UnboundLocalErrors.

    # Strategy 1: standard pickle read.
    try:
        with open(file_path, 'rb') as f:
            return pickle.load(f)
    except Exception as e1:
        print(f"    标准读取失败: {str(e1)[:100]}...")

    # Strategy 2: same read with warnings (e.g. pandas version notices) silenced.
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            with open(file_path, 'rb') as f:
                return pickle.load(f)
    except Exception as e2:
        print(f"    忽略警告读取失败: {str(e2)[:100]}...")

    # Strategy 3: let pandas read the file.
    try:
        return pd.read_pickle(file_path)
    except Exception as e3:
        print(f"    pandas读取失败: {str(e3)[:100]}...")

    # Strategy 4: degraded mode - load, then rebuild DataFrame-like values.
    try:
        with open(file_path, 'rb') as f:
            data = pickle.load(f)

        if not isinstance(data, dict):
            return data

        fixed_data = {}
        for key, value in data.items():
            if not (hasattr(value, 'values') and hasattr(value, 'columns')):
                fixed_data[key] = value
                continue

            try:
                frame_kwargs = {'data': value.values, 'columns': value.columns}
                if hasattr(value, 'index'):
                    frame_kwargs['index'] = value.index
                fixed_data[key] = pd.DataFrame(**frame_kwargs)
            except Exception:
                # Best effort: skip entries that cannot be rebuilt.
                continue
        return fixed_data

    except Exception as e4:
        print(f"    重构读取失败: {str(e4)[:100]}...")
        return None

def extract_genes_from_dataframe(df):
    """Return the set of gene identifiers in a DataFrame's ``gene`` column.

    Missing values (``None``, ``NaN``, ``pd.NA``) are dropped before the
    column is converted to strings, so they can never appear in the result as
    text such as ``'<NA>'``. Empty strings and the literal strings ``'nan'``
    and ``'None'`` are filtered out as well.

    Args:
        df: Object to inspect; anything that is not a DataFrame with a
            ``gene`` column yields an empty set.

    Returns:
        A ``set`` of gene identifiers as strings. An empty set is returned
        (and a message printed) if extraction raises.
    """
    try:
        if not (isinstance(df, pd.DataFrame) and 'gene' in df.columns):
            return set()

        # Drop real missing values first, then normalise everything to str.
        gene_names = df['gene'].dropna().astype(str).tolist()
        # Still reject blanks and already-stringified missing markers.
        return {g for g in gene_names if g and g not in ('nan', 'None')}
    except Exception as e:
        print(f"      提取基因时出错: {e}")
        return set()

def extract_genes_from_fold_data(possible_files=None):
    """Collect every gene that appears in the benchmark fold pickles.

    Each file is loaded with ``read_pickle_with_compatibility``. A file must
    contain a dict; each of its values is passed to
    ``extract_genes_from_dataframe`` and the resulting genes are pooled.

    Args:
        possible_files: Optional iterable of pickle paths to scan. When
            omitted, the default GO, GO-after-2020 and OMIM fold/holdout
            files under ``../gene-embedding-benchmarks/bin/gene_level`` are
            used.

    Returns:
        A tuple ``(all_genes, successful_files, failed_files)`` where
        ``all_genes`` is a set of gene identifiers, ``successful_files`` is
        the list of paths that contributed genes, and ``failed_files`` is a
        list of ``(path, reason)`` tuples.
    """
    if possible_files is None:
        possible_files = [
            "../gene-embedding-benchmarks/bin/gene_level/GO/go_cv_fold1_dict.pkl",
            "../gene-embedding-benchmarks/bin/gene_level/GO/go_cv_fold2_dict.pkl",
            "../gene-embedding-benchmarks/bin/gene_level/GO/go_cv_fold3_dict.pkl",
            "../gene-embedding-benchmarks/bin/gene_level/GO/go_holdout_dict.pkl",
            "../gene-embedding-benchmarks/bin/gene_level/GO_after_2020/go_cv_fold1_dict_after_2020.pkl",
            "../gene-embedding-benchmarks/bin/gene_level/GO_after_2020/go_cv_fold2_dict_after_2020.pkl",
            "../gene-embedding-benchmarks/bin/gene_level/GO_after_2020/go_cv_fold3_dict_after_2020.pkl",
            "../gene-embedding-benchmarks/bin/gene_level/GO_after_2020/go_holdout_dict_after_2020.pkl",
            "../gene-embedding-benchmarks/bin/gene_level/OMIM/omim_cv_fold1_dict.pkl",
            "../gene-embedding-benchmarks/bin/gene_level/OMIM/omim_cv_fold2_dict.pkl",
            "../gene-embedding-benchmarks/bin/gene_level/OMIM/omim_cv_fold3_dict.pkl",
            "../gene-embedding-benchmarks/bin/gene_level/OMIM/omim_holdout_dict.pkl",
        ]

    all_genes = set()
    successful_files = []
    failed_files = []

    for file_path in possible_files:
        if not os.path.exists(file_path):
            print(f"文件不存在: {file_path}")
            failed_files.append((file_path, "文件不存在"))
            continue

        print(f"\n=== {file_path} ===")

        # Use the tolerant reader; it returns None when nothing worked.
        data = read_pickle_with_compatibility(file_path)
        if data is None:
            print("所有读取方法都失败")
            failed_files.append((file_path, "读取失败"))
            continue

        print(f"数据类型: {type(data)}")

        if not isinstance(data, dict):
            print(f"数据不是字典格式: {type(data)}")
            failed_files.append((file_path, "数据格式错误"))
            continue

        print(f"疾病/过程数量: {len(data)}")

        # Describe the first entry so the layout is visible in the log.
        if data:
            first_key = next(iter(data))
            first_value = data[first_key]
            print(f"第一个条目 ({first_key}) 类型: {type(first_value)}")

            if isinstance(first_value, pd.DataFrame):
                print(f"DataFrame列: {first_value.columns.tolist()}")
                print(f"DataFrame形状: {first_value.shape}")
                if 'gene' in first_value.columns:
                    print(f"包含基因列！前几个基因: {first_value['gene'].head().tolist()}")

        # Pool the genes of every term in this file.
        file_genes = set()
        successful_terms = 0
        for df in data.values():
            genes_in_term = extract_genes_from_dataframe(df)
            if genes_in_term:
                file_genes.update(genes_in_term)
                successful_terms += 1

        if file_genes:
            print(f"从该文件提取到 {len(file_genes)} 个基因 (来自 {successful_terms} 个条目)")
            all_genes.update(file_genes)
            successful_files.append(file_path)
        else:
            print("该文件中未找到基因信息")
            failed_files.append((file_path, "无基因信息"))

    return all_genes, successful_files, failed_files

# Pull the gene set out of the fold files.
print("正在从fold数据中提取基因...")
benchmark_genes, successful_files, failed_files = extract_genes_from_fold_data()

if not benchmark_genes:
    print("未能提取到基因信息，可能需要检查其他文件位置")

    # Explain why every candidate file was rejected.
    print("\n失败详情:")
    for file_path, reason in failed_files:
        print(f"  {os.path.basename(file_path)}: {reason}")
else:
    print("\n=== 提取结果总结 ===")
    print(f"总共找到 {len(benchmark_genes)} 个独特基因")
    print(f"成功处理的文件: {len(successful_files)}")
    print(f"失败的文件: {len(failed_files)}")

    # Show a handful of genes as a sanity check.
    ordered_genes = sorted(benchmark_genes)
    sample_genes = ordered_genes[:10]
    print(f"示例基因: {sample_genes}")

    # Persist the full list, one gene per line, in sorted order.
    with open('benchmark_gene_list_complete.txt', 'w') as f:
        f.writelines(f"{gene}\n" for gene in ordered_genes)

    print("\n基因列表已保存到 benchmark_gene_list_complete.txt")

    # List the contributing files and tally them by source: a file counts as
    # OMIM when its name contains "omim", otherwise as GO.
    print("\n=== 成功提取来源 ===")
    successful_names = [os.path.basename(path) for path in successful_files]
    for filename in successful_names:
        print(f"  ✓ {filename}")
    omim_files = sum('omim' in name.lower() for name in successful_names)
    go_files = len(successful_names) - omim_files

    print(f"\nGO文件: {go_files}, OMIM文件: {omim_files}")

    # Report the files that could not be used.
    if failed_files:
        print("\n=== 失败的文件 ===")
        for file_path, reason in failed_files:
            print(f"  ✗ {os.path.basename(file_path)}: {reason}")

### Convert gene ID into gene name

In [None]:
import mygene
import pandas as pd

def convert_entrez_to_symbols(entrez_file):
    """Convert the Entrez IDs listed in a text file to gene symbols.

    The file is read one ID per line (blank lines ignored) and the IDs are
    sent to MyGene.info in batches via ``querymany``. A hit is accepted when
    it carries both ``symbol`` and ``entrezgene``; other hits are recorded as
    failures by their ``query`` value. If a whole batch raises, every ID of
    that batch is recorded as failed.

    Args:
        entrez_file: Path to a text file with one Entrez ID per line.

    Returns:
        A tuple ``(gene_mapping, failed_ids)``. ``gene_mapping`` maps the
        Entrez ID (as ``str``) to ``{'symbol': ..., 'name': ...}``;
        ``failed_ids`` is the list of IDs that could not be converted. Both
        are empty when the input file contains no IDs.
    """
    # Read the gene list.
    with open(entrez_file, 'r') as f:
        entrez_ids = [line.strip() for line in f if line.strip()]

    print(f"准备转换 {len(entrez_ids)} 个Entrez ID")

    # Nothing to convert: skip the client entirely. This also avoids the
    # division by zero in the success-rate line below.
    if not entrez_ids:
        return {}, []

    mg = mygene.MyGeneInfo()

    # Query in batches to stay within API limits.
    batch_size = 1000
    total_batches = (len(entrez_ids) - 1) // batch_size + 1
    gene_mapping = {}
    failed_ids = []

    batch_starts = range(0, len(entrez_ids), batch_size)
    for batch_num, start in enumerate(batch_starts, start=1):
        batch = entrez_ids[start:start + batch_size]

        print(f"处理批次 {batch_num}/{total_batches} ({len(batch)} 个基因)")

        try:
            results = mg.querymany(
                batch,
                scopes='entrezgene',
                fields='symbol,entrezgene,name',
                species='human',
                returnall=True
            )

            for result in results['out']:
                if 'symbol' in result and 'entrezgene' in result:
                    gene_mapping[str(result['entrezgene'])] = {
                        'symbol': result['symbol'],
                        'name': result.get('name', '')
                    }
                elif 'query' in result:
                    # No usable hit for this query ID.
                    failed_ids.append(result['query'])

        except Exception as e:
            print(f"批次 {batch_num} 转换失败: {e}")
            failed_ids.extend(batch)

    print(f"\n转换完成！")
    print(f"成功转换: {len(gene_mapping)} 个基因")
    print(f"转换失败: {len(failed_ids)} 个基因")
    print(f"转换成功率: {len(gene_mapping)/len(entrez_ids)*100:.1f}%")

    return gene_mapping, failed_ids

# Run the conversion on the gene list written by the extraction step.
entrez_list_path = 'benchmark_gene_list_complete.txt'
print("开始基因ID转换...")
gene_mapping, failed_ids = convert_entrez_to_symbols(entrez_list_path)

if gene_mapping:
    # Save the detailed mapping table.
    mapping_df = pd.DataFrame([
        {
            'entrez_id': entrez_id,
            'gene_symbol': info['symbol'],
            'gene_name': info['name']
        }
        for entrez_id, info in gene_mapping.items()
    ])

    mapping_df.to_csv('benchmark_gene_mapping.csv', index=False)
    print(f"详细映射表已保存到 benchmark_gene_mapping.csv")

    # Save the plain list of gene symbols, sorted, one per line.
    symbols = sorted(info['symbol'] for info in gene_mapping.values())
    with open('benchmark_gene_symbols.txt', 'w') as f:
        for symbol in symbols:
            f.write(f"{symbol}\n")

    print(f"基因符号列表已保存到 benchmark_gene_symbols.txt")

    # Show the first few conversions.
    print(f"\n示例转换结果:")
    for entrez_id, info in list(gene_mapping.items())[:10]:
        print(f"  {entrez_id} -> {info['symbol']} ({info['name'][:50]}...)")

    # Keep the failed IDs for follow-up.
    if failed_ids:
        with open('failed_entrez_ids.txt', 'w') as f:
            for failed_id in failed_ids:
                f.write(f"{failed_id}\n")
        print(f"\n{len(failed_ids)} 个失败的ID已保存到 failed_entrez_ids.txt")

# Summary report. The input count is read from the ID list itself, using the
# same non-blank-line rule as convert_entrez_to_symbols, rather than a
# hard-coded number that goes stale when the list changes.
with open(entrez_list_path, 'r') as f:
    original_id_count = sum(1 for line in f if line.strip())

print(f"\n=== 基准测试基因库总结 ===")
print(f"原始Entrez ID数量: {original_id_count}")
print(f"成功转换的基因符号: {len(gene_mapping)}")
print(f"最终可用基因数量: {len(gene_mapping)}")
# NOTE(review): the file counts below are still hard-coded - confirm they match the extraction step.
print(f"基因来源: GO (8个文件) + OMIM (4个文件)")

### Map into our dataset's gene list

In [None]:
# Load the processed MERFISH dataset from its .h5ad file.
merfish_adata = sc.read_h5ad('/media/dang/Omics/data/spot_level/bento/merfish_processed.h5ad')
# Bare expression: as the last line of a notebook cell this displays the object.
merfish_adata

In [None]:
# Load the CosMx spot-level table and collect the distinct values of its 'gene' column.
cosmx_data = pd.read_pickle('/media/dang/Omics/data/spot_level/CosMx/spot_dataframe.pkl')
CosMx_gene = np.unique(cosmx_data['gene'])

In [None]:
def _report_benchmark_overlap(label, platform_genes, benchmark_upper):
    """Print and return the case-insensitive overlap of one platform with the benchmark.

    Args:
        label: Platform name used in the printed report.
        platform_genes: Gene names measured on the platform.
        benchmark_upper: Upper-cased benchmark gene symbols.

    Returns:
        A tuple ``(platform_upper, overlap, summary)`` holding the upper-cased
        platform genes, their intersection with the benchmark, and the summary
        dict stored in the results.
    """
    platform_upper = {g.upper() for g in platform_genes}
    overlap = platform_upper & benchmark_upper
    # The rate is relative to the benchmark size; report 0 for an empty benchmark.
    rate = len(overlap) / len(benchmark_upper) * 100 if benchmark_upper else 0.0

    print(f"\n=== {label} 匹配结果 ===")
    print(f"{label}基因: {len(platform_genes)}")
    print(f"与基准重叠: {len(overlap)}")
    print(f"重叠率: {rate:.1f}%")
    print(f"重叠基因示例: {sorted(overlap)[:10]}")

    summary = {
        'total': len(platform_genes),
        'overlap': len(overlap),
        'overlap_genes': sorted(overlap)
    }
    return platform_upper, overlap, summary


def extract_and_compare_genes():
    """Compare the merFISH and CosMx gene panels with the benchmark gene symbols.

    Benchmark symbols are read from ``benchmark_gene_symbols.txt``. merFISH
    genes come from the module-level ``merfish_adata.var_names`` (names
    starting with ``notarget`` are dropped); CosMx genes are the unique values
    of the ``gene`` column of the CosMx spot pickle. A source that cannot be
    read is reported and treated as empty. Matching is case-insensitive.

    Returns:
        ``None`` when neither dataset provides genes. Otherwise a dict with
        the keys ``'merfish'`` and/or ``'cosmx'`` (each holding ``total``,
        ``overlap`` and ``overlap_genes``) and, when both datasets are
        available, ``'comparison'`` (``merfish_cosmx_overlap``,
        ``three_way_overlap``, ``three_way_genes``).

    Raises:
        OSError: If ``benchmark_gene_symbols.txt`` cannot be opened.
    """
    # Read the benchmark gene symbols.
    with open('benchmark_gene_symbols.txt', 'r') as f:
        benchmark_genes = set(line.strip() for line in f if line.strip())

    print(f"基准测试基因总数: {len(benchmark_genes)}")

    # Source 1: the merfish_adata object loaded in an earlier cell.
    try:
        merfish_genes = [g for g in merfish_adata.var_names if not g.startswith('notarget')]
        print(f"merFISH基因数 (过滤后): {len(merfish_genes)}")
    except Exception:
        # Typically a NameError because the loading cell has not been run.
        print("请先加载merfish_adata")
        merfish_genes = []

    # Source 2: the CosMx spot table on disk.
    try:
        cosmx_data = pd.read_pickle('/media/dang/Omics/data/spot_level/CosMx/spot_dataframe.pkl')
        cosmx_genes = list(np.unique(cosmx_data['gene']))
        print(f"CosMx基因数: {len(cosmx_genes)}")
    except Exception:
        print("无法读取CosMx数据文件")
        cosmx_genes = []

    if not (merfish_genes or cosmx_genes):
        print("没有可用的基因数据进行比较")
        return None

    # Match on upper-cased symbols.
    benchmark_upper = {g.upper() for g in benchmark_genes}
    results = {}

    if merfish_genes:
        merfish_upper, merfish_overlap, results['merfish'] = _report_benchmark_overlap(
            'merFISH', merfish_genes, benchmark_upper
        )

    if cosmx_genes:
        cosmx_upper, cosmx_overlap, results['cosmx'] = _report_benchmark_overlap(
            'CosMx', cosmx_genes, benchmark_upper
        )

    # With both datasets available, compare them with each other as well.
    if merfish_genes and cosmx_genes:
        both_overlap = merfish_upper & cosmx_upper
        three_way_overlap = merfish_overlap & cosmx_overlap

        print(f"\n=== 数据集间比较 ===")
        print(f"merFISH与CosMx重叠: {len(both_overlap)}")
        print(f"三方重叠(都与基准匹配): {len(three_way_overlap)}")

        results['comparison'] = {
            'merfish_cosmx_overlap': len(both_overlap),
            'three_way_overlap': len(three_way_overlap),
            'three_way_genes': sorted(three_way_overlap)
        }

    return results

# Run the comparison; `results` is None when no gene data was available.
results = extract_and_compare_genes()

### Save matched genes

In [None]:
def save_matched_genes_for_embedding():
    """Write the benchmark-matched gene lists used for embedding generation.

    Benchmark symbols are read from ``benchmark_gene_symbols.txt``. merFISH
    genes (from the module-level ``merfish_adata``, minus names starting with
    ``notarget``) and CosMx genes (unique ``gene`` values of the CosMx spot
    pickle) are matched to them case-insensitively. Four files are written,
    each sorted with one upper-cased gene per line:
    ``merfish_benchmark_genes.txt``, ``cosmx_benchmark_genes.txt``,
    ``combined_benchmark_genes.txt`` (union) and
    ``three_way_overlap_genes.txt`` (intersection).

    Returns:
        A dict with the sorted lists under ``merfish_genes``, ``cosmx_genes``,
        ``combined_genes`` and ``three_way_genes``, or ``None`` if anything
        after reading the benchmark file raises (the error is printed).
    """
    # Benchmark symbols; a missing file is not caught here.
    with open('benchmark_gene_symbols.txt', 'r') as f:
        benchmark_genes = set(line.strip() for line in f if line.strip())

    try:
        # merFISH panel without the non-target probes.
        merfish_genes_filtered = [
            name for name in list(merfish_adata.var_names)
            if not name.startswith('notarget')
        ]

        # CosMx panel.
        cosmx_data = pd.read_pickle('/media/dang/Omics/data/spot_level/CosMx/spot_dataframe.pkl')
        cosmx_genes = list(np.unique(cosmx_data['gene']))

        # Case-insensitive overlaps with the benchmark.
        benchmark_upper = {g.upper() for g in benchmark_genes}
        merfish_overlap = {g.upper() for g in merfish_genes_filtered} & benchmark_upper
        cosmx_overlap = {g.upper() for g in cosmx_genes} & benchmark_upper

        # Union for the full benchmark, intersection for direct comparison.
        combined_genes = merfish_overlap | cosmx_overlap
        three_way_overlap = merfish_overlap & cosmx_overlap

        outputs = [
            ('merfish_benchmark_genes.txt', merfish_overlap),
            ('cosmx_benchmark_genes.txt', cosmx_overlap),
            ('combined_benchmark_genes.txt', combined_genes),
            ('three_way_overlap_genes.txt', three_way_overlap),
        ]
        for out_name, gene_set in outputs:
            with open(out_name, 'w') as f:
                f.writelines(f"{gene}\n" for gene in sorted(gene_set))

        print("=== 基因列表已保存 ===")
        for out_name, gene_set in outputs:
            print(f"✅ {out_name}: {len(gene_set)} 个基因")

        return {
            'merfish_genes': sorted(merfish_overlap),
            'cosmx_genes': sorted(cosmx_overlap),
            'combined_genes': sorted(combined_genes),
            'three_way_genes': sorted(three_way_overlap)
        }

    except Exception as e:
        print(f"处理基因数据时出错: {e}")
        return None

# Write the matched gene lists; `gene_lists` is None if the step failed.
gene_lists = save_matched_genes_for_embedding()

### Filter new gene list

In [None]:
def _term_to_dataframe(data):
    """Return a benchmark term's data as a DataFrame, or None if it cannot be converted."""
    if isinstance(data, pd.DataFrame):
        return data
    try:
        if hasattr(data, 'to_frame'):
            return data.to_frame()
        return pd.DataFrame(data)
    except Exception:
        return None


def prepare_test_subsets_with_compatibility():
    """Create benchmark subsets restricted to the genes in our datasets.

    Target symbols are read from ``combined_benchmark_genes.txt`` and the
    Entrez-ID-to-symbol map from ``benchmark_gene_mapping.csv``; both are
    compared upper-cased. For each GO/OMIM fold or holdout pickle that exists,
    every term's table is rebuilt with the columns ``gene`` and ``result``
    (``result`` defaults to 1 when absent) and filtered to rows whose gene
    maps to a target symbol. Terms with at least one remaining row are written
    to ``test_subsets/<name>_subset.pkl`` using pickle protocol 2.

    A term that cannot be converted to a DataFrame is reported and skipped
    rather than aborting the whole file. A file that cannot be read is
    reported and skipped.

    Returns:
        A list with one dict per written subset file, holding ``file``,
        ``terms``, ``original_samples``, ``kept_samples`` and
        ``retention_rate``.
    """
    import pickle
    import os
    import pandas as pd

    # Target gene symbols, upper-cased; blank lines are ignored.
    with open('combined_benchmark_genes.txt', 'r') as f:
        target_genes = {line.strip().upper() for line in f if line.strip()}

    print(f"目标基因数量: {len(target_genes)}")

    # Entrez ID (as str) -> upper-cased gene symbol.
    gene_mapping = pd.read_csv('benchmark_gene_mapping.csv')
    entrez_to_symbol = dict(zip(gene_mapping['entrez_id'].astype(str),
                               gene_mapping['gene_symbol'].str.upper()))

    # Original benchmark files.
    benchmark_files = [
        "../gene-embedding-benchmarks/bin/gene_level/GO/go_cv_fold1_dict.pkl",
        "../gene-embedding-benchmarks/bin/gene_level/GO/go_cv_fold2_dict.pkl",
        "../gene-embedding-benchmarks/bin/gene_level/GO/go_cv_fold3_dict.pkl",
        "../gene-embedding-benchmarks/bin/gene_level/GO/go_holdout_dict.pkl",
        "../gene-embedding-benchmarks/bin/gene_level/OMIM/omim_cv_fold1_dict.pkl",
        "../gene-embedding-benchmarks/bin/gene_level/OMIM/omim_cv_fold2_dict.pkl",
        "../gene-embedding-benchmarks/bin/gene_level/OMIM/omim_cv_fold3_dict.pkl",
        "../gene-embedding-benchmarks/bin/gene_level/OMIM/omim_holdout_dict.pkl"
    ]

    os.makedirs('test_subsets', exist_ok=True)

    subset_stats = []

    for file_path in benchmark_files:
        if not os.path.exists(file_path):
            continue

        filename = os.path.basename(file_path)
        print(f"\n处理 {filename}...")

        # Read with pickle first, then fall back to pandas. pickle.load
        # detects the protocol itself, so a separate "protocol 2" retry would
        # just repeat the first attempt.
        # NOTE: unpickling can execute arbitrary code; only use trusted files.
        try:
            with open(file_path, 'rb') as f:
                original_data = pickle.load(f)
        except Exception as e1:
            print(f"直接读取失败: {e1}")
            try:
                original_data = pd.read_pickle(file_path)
            except Exception as e3:
                print(f"pd.read_pickle读取失败: {e3}")
                print(f"❌ 跳过 {filename} - 无法读取")
                continue

        try:
            subset_data = {}
            total_samples = 0
            kept_samples = 0

            for term, data in original_data.items():
                df = _term_to_dataframe(data)
                if df is None:
                    print(f"无法处理term {term}的数据格式")
                    continue

                if 'gene' not in df.columns:
                    continue

                total_samples += len(df)

                # Rebuild from raw values to sidestep pandas version issues.
                new_df = pd.DataFrame({
                    'gene': df['gene'].values,
                    'result': df['result'].values if 'result' in df.columns else [1] * len(df)
                })

                # Translate gene IDs to symbols, then keep only target genes.
                new_df['gene_symbol'] = new_df['gene'].astype(str).map(entrez_to_symbol)
                keep = new_df['gene_symbol'].isin(target_genes)
                subset_df = new_df.loc[keep, ['gene', 'result']].copy()

                if len(subset_df) > 0:
                    subset_data[term] = subset_df
                    kept_samples += len(subset_df)

            if subset_data:
                output_filename = filename.replace('.pkl', '_subset.pkl')
                output_path = os.path.join('test_subsets', output_filename)

                with open(output_path, 'wb') as f:
                    # Protocol 2 is readable by older Python versions.
                    pickle.dump(subset_data, f, protocol=2)

                subset_stats.append({
                    'file': output_filename,
                    'terms': len(subset_data),
                    'original_samples': total_samples,
                    'kept_samples': kept_samples,
                    'retention_rate': f"{kept_samples/total_samples*100:.1f}%" if total_samples > 0 else "0%"
                })

                print(f"✅ {output_filename}: {len(subset_data)} terms, {kept_samples}/{total_samples} samples")
            else:
                print(f"⚠️ {filename}: 没有匹配的数据")

        except Exception as e:
            print(f"❌ 处理 {filename} 的数据时出错: {e}")

    # Print the summary.
    print(f"\n=== 子集创建摘要 ===")
    print(f"成功创建 {len(subset_stats)} 个测试子集文件")

    for stat in subset_stats:
        print(f"  {stat['file']}: {stat['terms']} terms, {stat['retention_rate']} retention")

    return subset_stats

# Run the compatibility-aware subset creation.
subset_stats = prepare_test_subsets_with_compatibility()

## Generate gene embeddings for each dataset

In [None]:
# Inference configuration, hard-coded to mirror the command-line script so the
# notebook does not have to parse CLI arguments again.
args = Namespace(
    ckpt_path="/path/to/your/checkpoint.ckpt",  # TODO: replace with the real checkpoint path
    dataset_name="seqfish",  # adjust to the dataset being processed
    target_dataset="seqfish",
    gene_pct=100,
    linear_hidden_dim='256',
    radius=20,
    max_points=20,
    config='/media/dang/Omics/omics/configs/bert_config_5-12.json',
    seed=42,
    f=None,
    split_slice=None,
    target_width=6724,
    target_height=5885,
    x_min=None,
    x_max=None,
    y_min=None,
    y_max=None,
    sample_ratio=1.0,
    merge_threshold=0.5,
    percentile=70,
    gene_list_path="/media/dang/Omics/omics/gene_level/function/preprocess/cosmx_benchmark_genes.txt"
)

print('ckpt_path:', args.ckpt_path)
print(f"Dataset: {args.dataset_name}, radius: {args.radius}, max_points: {args.max_points}")

# Build the cache path; cosmx_lung5_rep1 uses a differently named ("sampled") cache file.
# NOTE(review): the literal 42 in the file name presumably mirrors args.seed - confirm.
if args.dataset_name == 'cosmx_lung5_rep1':
    cache_path = f'/media/dang/Omics/omics/baseline/cached_spot_data/spot_input_spot_sampled_{args.dataset_name}_42_{args.gene_pct}_{args.radius}_{args.max_points}_0.1.h5'
else:
    cache_path = f'/media/dang/Omics/omics/baseline/cached_spot_data/spot_input_spot_all_{args.dataset_name}_42_{args.gene_pct}_{args.radius}_{args.max_points}.h5'

# Run the embedding step once. The cell previously repeated the parameter
# assignment (with identical values), the cache-path construction and this
# call a second time, which only re-ran the same job.
print("Processing gene-level embeddings from gene list -> CSV format")
emb_file, genelist_file = process_single_gpu_gene_embeddings(
    args, args.gene_list_path, args.target_dataset, cache_path
)

## Convert gene name list to gene ID list using benchmark gene mapping

In [None]:
# Parameters are hard-coded here so the notebook does not re-parse CLI arguments.
args = Namespace(
    genelist_file='/media/dang/Omics/omics/gene_level/function/preprocess/gene_embeddings_csv/merfish_gene100pct_Spotformer_genelist.txt',
    mapping_file="/media/dang/Omics/omics/gene_level/function/preprocess/benchmark_gene_mapping.csv",
    output_file='/media/dang/Omics/omics/gene_level/function/preprocess/gene_embeddings_csv/Spotformer_genelist2.txt'
)

# Without an explicit output file, derive one from the gene-list file name:
# drop the extension and a trailing '_genelist', then append '_geneid.txt'.
if args.output_file is None:
    base_name, _ = os.path.splitext(args.genelist_file)
    genelist_suffix = '_genelist'
    if base_name.endswith(genelist_suffix):
        base_name = base_name[:-len(genelist_suffix)]
    args.output_file = f"{base_name}_geneid.txt"

print("Converting gene names to IDs...")
for label, path in (
    ("Input", args.genelist_file),
    ("Mapping", args.mapping_file),
    ("Output", args.output_file),
):
    print(f"  {label}: {path}")

gene_ids, matched_count, unmatched_genes = convert_gene_names_to_ids(
    args.genelist_file, args.mapping_file, args.output_file
)

print(f"\nDone! Generated {len(gene_ids)} gene IDs with {matched_count} matches.")


## Merge gene embeddings from different datasets

In [None]:
# Parameters are hard-coded here so the notebook does not re-parse CLI arguments.
args = Namespace(
    emb_file1="/media/dang/Omics/omics/gene_level/function/preprocess/gene_embeddings_csv/cosmx_lung5_rep1_gene20pct_Spotformer_emb.csv",
    emb_file2="/media/dang/Omics/omics/gene_level/function/preprocess/gene_embeddings_csv/merfish_gene100pct_Spotformer_emb.csv",
    gene_file1="/media/dang/Omics/omics/gene_level/function/preprocess/gene_embeddings_csv/Spotformer_genelist.txt",
    gene_file2="/media/dang/Omics/omics/gene_level/function/preprocess/gene_embeddings_csv/Spotformer_genelist2.txt",
    output_dir="/media/dang/Omics/omics/gene_level/function/preprocess/gene_embeddings_csv/",
    output_prefix="merged"
)

# Output locations for the merged embedding table and gene list.
merged_emb_file = os.path.join(args.output_dir, f"{args.output_prefix}_Spotformer_emb.csv")
merged_gene_file = os.path.join(args.output_dir, f"{args.output_prefix}_Spotformer_genelist.txt")

section_rule = "=" * 60

print(section_rule)
print("MERGING EMBEDDING FILES")
print(section_rule)

# Merge the two embedding CSVs.
merged_embeddings = merge_embedding_csvs(
    args.emb_file1, args.emb_file2, merged_emb_file
)

print("\n" + section_rule)
print("MERGING GENE LIST FILES")
print(section_rule)

# Merge the two gene lists.
merged_genes = merge_gene_lists(
    args.gene_file1, args.gene_file2, merged_gene_file
)

print("\n" + section_rule)
print("SUMMARY")
print(section_rule)

# The merged gene list and embedding table must have the same length.
gene_count = len(merged_genes)
embedding_count = len(merged_embeddings)
if gene_count != embedding_count:
    print(f"✗ Consistency check failed: {gene_count} genes vs {embedding_count} embeddings")
else:
    print(f"✓ Consistency check passed: {gene_count} genes match {embedding_count} embeddings")

print("\nOutput files:")
print(f"  Merged embeddings: {merged_emb_file}")
print(f"  Merged gene list: {merged_gene_file}")
