# Data Preprocessing

In [12]:
# -*- coding: utf-8 -*-

# ==============================================================================
# SOTA数据预处理工作流 for SpaGLaM (已修复多进程问题)
#
# 核心策略:
# 1. 加载预先合并好的AnnData缓存文件。
# 2. **拆分-计算-合并**: 遍历每个样本(sample_id)，独立计算其内部的空间
#    邻接图，然后将所有图合并成一个全局的稀疏矩阵。
# 3. 将所有独立的图像(.png)和基因语句(.txt)文件打包成高效的.tar分片。
# ==============================================================================

import os
import sys
import time
import scanpy as sc
import pandas as pd
import numpy as np
import webdataset as wds
from tqdm.auto import tqdm
from concurrent.futures import ProcessPoolExecutor
from scipy.sparse import lil_matrix
from itertools import repeat # 导入 repeat 函数

# --- 1. Global settings ---

# --- Paths (adjust to your environment) ---
BASE_DIR = "/cwStorage/nodecw_group/jijh"
# Pre-merged AnnData cache that this pipeline loads as its starting point.
EXISTING_CACHE_PATH = os.path.join(BASE_DIR, "hest_sentences_human_all/cache/adata_preprocessed_canonical_v2.h5ad")
# Root directories containing the per-sample "<sample>_tiles" and
# "<sample>_sentences_hvg" folders used when rebuilding file paths.
IMAGE_BASE_DIR = os.path.join(BASE_DIR, "hest_output")
SENTENCE_BASE_DIR = os.path.join(BASE_DIR, "hest_sentences_human_all")

# Outputs: the master AnnData (with spatial graph) and the shard directory.
OUTPUT_DIR = os.path.join(BASE_DIR, "spaglam_sota_data")
FINAL_ADATA_PATH = os.path.join(OUTPUT_DIR, "master_adata_with_graph.h5ad")
SHARDS_OUTPUT_PATH = os.path.join(OUTPUT_DIR, "webdataset_shards")

# --- Parameters ---
N_NEIGHBORS = 6  # n_neighbors passed to sc.pp.neighbors for the spatial graph
SAMPLES_PER_SHARD = 10000  # maximum number of spots written to one .tar shard
NUM_WORKERS_PACKAGING = 16  # size of the process pool that writes shards

# --- 辅助函数 ---
def format_time(seconds: float) -> str:
    """Render a duration in seconds as "<minutes>分 <seconds>秒".

    Minutes are shown as an integer, leftover seconds with two decimals.
    """
    whole_minutes = seconds // 60
    leftover = seconds % 60
    return "{}分 {:.2f}秒".format(int(whole_minutes), leftover)

# --- 核心功能函数 (已修改) ---

def add_spatial_graph_to_adata_no_squidpy(adata: sc.AnnData, n_neighbors: int) -> sc.AnnData:
    """Compute a per-sample spatial neighbour graph and store it in ``adata.obsp``.

    Spots from different samples share no coordinate frame, so the graph is
    built independently for every ``obs['sample_id']`` value with
    ``sc.pp.neighbors(use_rep='spatial')``. The per-sample connectivity
    matrices are then scattered into one global CSR matrix, so spots of
    different samples are never connected.

    Args:
        adata: AnnData with ``obs['sample_id']`` and ``obsm['spatial']``.
        n_neighbors: Forwarded to ``sc.pp.neighbors``.

    Returns:
        The same ``adata`` object (modified in place) with
        ``obsp['spatial_connectivities']`` set. If that key already exists,
        ``adata`` is returned unchanged.
    """
    from scipy.sparse import coo_matrix

    print("\n--- 步骤 1: 计算并添加空间邻接图 (无Squidpy替代方案) ---")
    start_time = time.time()
    if 'spatial_connectivities' in adata.obsp:
        print("✅ 空间邻接图 'spatial_connectivities' 已存在，跳过计算。")
        return adata
    unique_samples = adata.obs['sample_id'].unique()
    print(f"检测到 {len(unique_samples)} 个独立样本。将逐个计算空间图...")

    # Collect COO triplets per sample and build the global matrix once at the
    # end; element-wise insertion into a lil_matrix is far slower.
    row_parts, col_parts, data_parts = [], [], []
    for sample_id in tqdm(unique_samples, desc="处理每个样本"):
        sample_mask = (adata.obs['sample_id'] == sample_id).to_numpy()
        global_indices = np.flatnonzero(sample_mask)
        if global_indices.size < 2:
            # A lone spot has no neighbours and sc.pp.neighbors cannot run on it.
            continue
        # Explicit copy: writing .obsp on a view would trigger the same copy implicitly.
        adata_sample = adata[sample_mask].copy()
        sc.pp.neighbors(adata_sample, n_neighbors=n_neighbors, use_rep='spatial', key_added='spatial')
        # tocoo() keeps row/col/data aligned entry-for-entry; .nonzero() skips
        # explicitly stored zeros and could desynchronise indices from .data.
        local_conn = adata_sample.obsp['spatial_connectivities'].tocoo()
        row_parts.append(global_indices[local_conn.row])
        col_parts.append(global_indices[local_conn.col])
        data_parts.append(local_conn.data)

    n_obs = adata.n_obs
    if data_parts:
        rows = np.concatenate(row_parts)
        cols = np.concatenate(col_parts)
        data = np.concatenate(data_parts).astype(np.float32, copy=False)
    else:
        rows = cols = np.empty(0, dtype=np.intp)
        data = np.empty(0, dtype=np.float32)
    global_conn_matrix = coo_matrix((data, (rows, cols)), shape=(n_obs, n_obs), dtype=np.float32).tocsr()
    global_conn_matrix.eliminate_zeros()
    adata.obsp['spatial_connectivities'] = global_conn_matrix
    print("\n✅ 所有样本的空间邻接图已计算并合并！")
    end_time = time.time()
    print(f"🕒 耗时: {format_time(end_time - start_time)}")
    return adata

# ==============================================================================
# 【【【关键修改点】】】
# 将 process_chunk 函数定义在全局作用域中，使其成为一个顶层函数
# ==============================================================================
def _process_chunk_worker(args):
    """Write one WebDataset shard for a chunk of spot IDs (process-pool entry point).

    ``args`` is a single tuple ``(spot_ids_chunk, chunk_index, all_spot_info,
    output_pattern, samples_per_shard)``, so the function can be used with
    ``executor.map``. ``samples_per_shard`` is accepted but unused.

    Each spot is written as ``{"__key__": spot_id, "png": <image bytes>,
    "txt": <sentence bytes>}`` to ``output_pattern % chunk_index``. Spots with
    no (or an empty) entry in ``all_spot_info``, or whose image or sentence
    file does not exist, are skipped without notice.

    Returns:
        ``chunk_index``, so the caller can track progress.
    """
    spot_ids_chunk, chunk_index, all_spot_info, output_pattern, _unused_samples_per_shard = args

    with wds.TarWriter(output_pattern % chunk_index) as tar_sink:
        for key in spot_ids_chunk:
            entry = all_spot_info.get(key)
            if entry:
                png_file = entry['image_path']
                txt_file = entry['sentence_path']
                if os.path.exists(png_file) and os.path.exists(txt_file):
                    with open(png_file, "rb") as handle:
                        png_bytes = handle.read()
                    with open(txt_file, "rb") as handle:
                        txt_bytes = handle.read()
                    tar_sink.write({"__key__": key, "png": png_bytes, "txt": txt_bytes})
    return chunk_index

def create_webdataset_shards(adata: sc.AnnData, output_pattern: str):
    """Pack every spot's .png tile and .txt gene sentence into WebDataset .tar shards.

    If ``adata.obs`` lacks ``image_path`` or ``sentence_path``, both columns are
    rebuilt from the spot ID and written into ``adata.obs`` in place:
    ``IMAGE_BASE_DIR/<prefix>_tiles/<spot>.png`` and
    ``SENTENCE_BASE_DIR/<prefix>_sentences_hvg/<spot>.txt``, where ``<prefix>``
    is the spot ID up to its first underscore.

    Spots with a missing path are dropped. The rest are split into chunks of
    ``SAMPLES_PER_SHARD``, and each chunk is written to
    ``output_pattern % chunk_index`` by ``_process_chunk_worker`` in a pool of
    ``NUM_WORKERS_PACKAGING`` processes.

    Args:
        adata: AnnData whose ``obs_names`` are the spot IDs used as sample keys.
        output_pattern: printf-style shard path with one integer field,
            e.g. ``".../dataset-%06d.tar"``.
    """
    print("\n--- 步骤 2: 将原始数据打包成 WebDataset 分片 (.tar) ---")
    start_time = time.time()

    if 'image_path' not in adata.obs.columns or 'sentence_path' not in adata.obs.columns:
        print("警告: AnnData中缺少 'image_path' 或 'sentence_path'。正在尝试重新构建...")
        spot_ids = list(adata.obs_names)
        prefixes = [sid.split('_')[0] for sid in spot_ids]
        # Assign positionally (not obs.join) so that a single pre-existing
        # column is overwritten instead of raising an "overlap" error.
        adata.obs['image_path'] = [
            os.path.join(IMAGE_BASE_DIR, f"{prefix}_tiles", f"{sid}.png")
            for sid, prefix in zip(spot_ids, prefixes)
        ]
        adata.obs['sentence_path'] = [
            os.path.join(SENTENCE_BASE_DIR, f"{prefix}_sentences_hvg", f"{sid}.txt")
            for sid, prefix in zip(spot_ids, prefixes)
        ]

    all_spot_info = adata.obs[['image_path', 'sentence_path']].dropna().to_dict('index')
    all_spot_ids = list(all_spot_info)

    total_files = len(all_spot_ids)
    print(f"准备打包 {total_files} 个 spots 的数据...")

    output_dir = os.path.dirname(output_pattern)
    if output_dir:  # os.makedirs('') raises; a bare filename means the current directory
        os.makedirs(output_dir, exist_ok=True)

    # Give each task only its own chunk's paths. Passing the full dict (one
    # entry per spot) would re-pickle it for every task sent to the pool.
    tasks = []
    for chunk_index, start in enumerate(range(0, total_files, SAMPLES_PER_SHARD)):
        chunk = all_spot_ids[start:start + SAMPLES_PER_SHARD]
        chunk_info = {spot_id: all_spot_info[spot_id] for spot_id in chunk}
        tasks.append((chunk, chunk_index, chunk_info, output_pattern, SAMPLES_PER_SHARD))

    print(f"启动 {NUM_WORKERS_PACKAGING} 个工作进程进行打包...")
    with ProcessPoolExecutor(max_workers=NUM_WORKERS_PACKAGING) as executor:
        results = list(tqdm(executor.map(_process_chunk_worker, tasks), total=len(tasks), desc="打包 .tar 分片"))

    end_time = time.time()
    print(f"✅ 数据打包完成！成功处理了 {len(results)} 个分片。")
    print(f"🕒 耗时: {format_time(end_time - start_time)}")

In [8]:


# Load the pre-merged AnnData cache into the notebook-global `master_adata`.
master_adata = sc.read_h5ad(EXISTING_CACHE_PATH)
print(f"✅ 加载成功！AnnData 维度: {master_adata.shape}")



✅ 加载成功！AnnData 维度: (997054, 30148)


In [9]:
# Add the spatial graph using the squidpy-free implementation.
# NOTE: the function mutates and returns the object it was given, so
# `master_adata_with_graph` and `master_adata` refer to the same AnnData.
master_adata_with_graph = add_spatial_graph_to_adata_no_squidpy(master_adata, n_neighbors=N_NEIGHBORS)



--- 步骤 1: 计算并添加空间邻接图 (无Squidpy替代方案) ---
检测到 505 个独立样本。将逐个计算空间图...


处理每个样本:   0%|          | 0/505 [00:00<?, ?it/s]


✅ 所有样本的空间邻接图已计算并合并！
🕒 耗时: 2分 41.06秒


In [None]:

# Persist the graph-augmented AnnData to FINAL_ADATA_PATH (creating the output
# directory if needed). Only what is in `.obs` at this moment is saved.
print(f"\n正在保存带有图结构的最终主 AnnData 文件至: {FINAL_ADATA_PATH}")
os.makedirs(os.path.dirname(FINAL_ADATA_PATH), exist_ok=True)
master_adata_with_graph.write_h5ad(FINAL_ADATA_PATH)
print("✅ 保存成功！")




正在保存带有图结构的最终主 AnnData 文件至: /cwStorage/nodecw_group/jijh/spaglam_sota_data/master_adata_with_graph.h5ad
✅ 保存成功！

--- 步骤 2: 将原始数据打包成 WebDataset 分片 (.tar) ---
正在重建文件路径...
准备打包 997054 个 spots 的数据...


打包 .tar 分片:   0%|          | 0/100 [00:00<?, ?it/s]

AttributeError: Can't get local object 'create_webdataset_shards.<locals>.process_chunk'

In [13]:
# Pack all spots into shards named dataset-000000.tar, dataset-000001.tar, ...
# NOTE(review): create_webdataset_shards may add image_path/sentence_path to
# `.obs` in memory only; FINAL_ADATA_PATH was written earlier without them.
shard_pattern = os.path.join(SHARDS_OUTPUT_PATH, "dataset-%06d.tar")
create_webdataset_shards(master_adata_with_graph, shard_pattern)

print("\n\n🎉🎉🎉 恭喜！SOTA 数据预处理全部完成！🎉🎉🎉")
print("您现在拥有：")
print(f"1. 一个包含所有元数据和空间图的中央索引文件: {FINAL_ADATA_PATH}")
print(f"2. 一组用于高效训练的数据仓库分片文件位于: {SHARDS_OUTPUT_PATH}")


--- 步骤 2: 将原始数据打包成 WebDataset 分片 (.tar) ---
准备打包 997054 个 spots 的数据...
启动 16 个工作进程进行打包...


打包 .tar 分片:   0%|          | 0/100 [00:00<?, ?it/s]

✅ 数据打包完成！成功处理了 100 个分片。
🕒 耗时: 2分 33.69秒


🎉🎉🎉 恭喜！SOTA 数据预处理全部完成！🎉🎉🎉
您现在拥有：
1. 一个包含所有元数据和空间图的中央索引文件: /cwStorage/nodecw_group/jijh/spaglam_sota_data/master_adata_with_graph.h5ad
2. 一组用于高效训练的数据仓库分片文件位于: /cwStorage/nodecw_group/jijh/spaglam_sota_data/webdataset_shards


# Check

In [14]:
# -*- coding: utf-8 -*-

# ==============================================================================
# SpaGLaM SOTA 数据集验证脚本
# 
# 用途:
#   - 验证预处理流程是否成功。
#   - 检查主 AnnData 文件和 WebDataset 分片的完整性与一致性。
#   - 在开始昂贵的模型训练前，进行一次快速的健全性检查 (Sanity Check)。
# ==============================================================================

import os
import sys
import random
import scanpy as sc
import webdataset as wds
import numpy as np
from PIL import Image
import io

# --- 1. Configuration (adjust to your environment) ---

# --- Paths ---
BASE_DIR = "/cwStorage/nodecw_group/jijh"
OUTPUT_DIR = os.path.join(BASE_DIR, "spaglam_sota_data")
# NOTE(review): this is the V1 output file; the run below reports it lacks the
# 'image_path' column that inspect_master_adata requires.
FINAL_ADATA_PATH = os.path.join(OUTPUT_DIR, "master_adata_with_graph.h5ad")
SHARDS_OUTPUT_PATH = os.path.join(OUTPUT_DIR, "webdataset_shards")

# --- Validation parameters ---
# Number of spots randomly sampled for the end-to-end consistency check.
NUM_SAMPLES_TO_CHECK = 5

# --- 辅助函数 ---
def print_header(title):
    """Print *title* upper-cased between two 80-character '=' rules, after a blank line."""
    rule = "=" * 80
    print()
    print(rule)
    print(f"  {title.upper()}")
    print(rule)

def print_status(message, success=True):
    """Print *message* after a success marker, or a failure marker if *success* is falsy."""
    marker = {True: "✅ [成功]", False: "❌ [失败]"}[bool(success)]
    print(f"{marker} {message}")

def print_info(message):
    """Print *message* with an informational prefix."""
    info_prefix = "ℹ️ [信息]"
    print(f"{info_prefix} {message}")


# --- 验证函数 ---

def check_file_existence():
    """Verify that the master AnnData file and at least one .tar shard exist.

    Returns:
        A tuple ``(all_exist, random_shard_path)``. ``all_exist`` is True only
        if both checks pass. ``random_shard_path`` is a randomly chosen
        ``.tar`` path inside SHARDS_OUTPUT_PATH, or None when there is none.
    """
    print_header("1. 文件存在性检查")

    adata_present = os.path.exists(FINAL_ADATA_PATH)
    if adata_present:
        print_status(f"主 AnnData 文件存在: {FINAL_ADATA_PATH}")
    else:
        print_status(f"主 AnnData 文件缺失: {FINAL_ADATA_PATH}", success=False)

    random_shard_path = None
    shards_present = False
    if not os.path.isdir(SHARDS_OUTPUT_PATH):
        print_status(f"WebDataset 分片目录缺失: {SHARDS_OUTPUT_PATH}", success=False)
    else:
        tar_names = [name for name in os.listdir(SHARDS_OUTPUT_PATH) if name.endswith('.tar')]
        if not tar_names:
            print_status("WebDataset 分片目录存在，但其中没有 .tar 文件。", success=False)
        else:
            print_status(f"WebDataset 分片目录存在，并找到 {len(tar_names)} 个 .tar 文件。")
            # Pick one shard at random for the structural spot-check later on.
            random_shard_path = os.path.join(SHARDS_OUTPUT_PATH, random.choice(tar_names))
            shards_present = True

    return adata_present and shards_present, random_shard_path


def inspect_master_adata():
    """Load the master AnnData file and confirm its required fields are present.

    Requires ``obs`` columns ``sample_id``, ``image_path`` and ``sentence_path``,
    ``obsm['spatial']`` and ``obsp['spatial_connectivities']``.

    Returns:
        The loaded AnnData on success. On any failure, including an exception
        while loading, prints a failure message and returns None.
    """
    print_header("2. 主 AnnData 文件完整性检查")
    try:
        adata = sc.read_h5ad(FINAL_ADATA_PATH)
        print_status("成功加载 AnnData 文件。")
        print_info(f"AnnData 维度 (spots, genes): {adata.n_obs} x {adata.n_vars}")

        # Report only the first missing column, in the fixed order below.
        missing_col = next(
            (name for name in ('sample_id', 'image_path', 'sentence_path') if name not in adata.obs.columns),
            None,
        )
        if missing_col is not None:
            print_status(f"'.obs' 中缺少关键列: '{missing_col}'", success=False)
            return None
        print_status("'.obs' 中的关键列均存在。")

        if 'spatial' not in adata.obsm:
            print_status("'.obsm' 中缺少空间坐标 'spatial'", success=False)
            return None
        print_status("空间坐标 '.obsm['spatial']' 存在。")

        # The spatial graph is the most important artefact of preprocessing.
        if 'spatial_connectivities' not in adata.obsp:
            print_status("'.obsp' 中缺少空间邻接图 'spatial_connectivities'", success=False)
            return None
        print_status("空间邻接图 '.obsp['spatial_connectivities']' 存在。")

        return adata
    except Exception as err:
        print_status(f"加载或检查 AnnData 文件时出错: {err}", success=False)
        return None


def inspect_tar_shard(shard_path):
    """Spot-check the structure of a single WebDataset .tar shard.

    For at most the first 3 samples, verifies the following:
      - the ``__key__``, ``png`` and ``txt`` fields are present;
      - the PNG bytes decode with PIL (converted to RGB);
      - the text is valid UTF-8.

    Args:
        shard_path: Path to a .tar shard, or a falsy value when none exists.

    Returns:
        True if at least one sample was read and every checked sample passed.
        Otherwise False, after printing a failure message (also on any
        exception raised while reading).
    """
    from itertools import islice

    print_header("3. WebDataset 分片结构检查")
    if not shard_path:
        print_status("没有可供检查的分片文件。", success=False)
        return False

    expected_keys = {'__key__', 'png', 'txt'}
    try:
        print_info(f"正在抽样检查分片: {os.path.basename(shard_path)}")
        dataset = wds.WebDataset(shard_path)

        sample_count = 0
        # islice stops after 3 samples without pulling a 4th from the shard.
        for sample in islice(dataset, 3):
            if not expected_keys.issubset(sample.keys()):
                # .get(): '__key__' may itself be the missing field; indexing it
                # would raise KeyError and hide the real problem.
                print_status(f"样本 {sample.get('__key__')} 缺少关键键。期望: {expected_keys}, 实际: {sample.keys()}", success=False)
                return False

            # Image.open is lazy; convert() forces a full decode. The context
            # manager releases the image afterwards.
            with Image.open(io.BytesIO(sample['png'])) as img:
                img.convert("RGB")
            sample['txt'].decode('utf-8')
            sample_count += 1

        if sample_count > 0:
            print_status(f"成功检查了 {sample_count} 个样本，结构正确。")
            return True
        print_status("分片文件为空或无法读取样本。", success=False)
        return False

    except Exception as e:
        print_status(f"检查 .tar 分片时出错: {e}", success=False)
        return False


def perform_end_to_end_check(adata, tar_urls):
    """Cross-check randomly sampled spots between the AnnData index and the shards.

    Steps:
      1. Sample up to ``NUM_SAMPLES_TO_CHECK`` spot IDs from ``adata.obs_names``.
      2. Stream the shards in ``tar_urls``, keeping only samples whose
         ``__key__`` is one of those IDs.
      3. For each match, check that the spot's row of
         ``obsp['spatial_connectivities']`` can be read and that its PNG and
         text decode.
      4. Stop scanning as soon as every sampled spot has been found.

    Args:
        adata: AnnData with ``obsp['spatial_connectivities']``.
        tar_urls: List of shard paths/URLs accepted by ``wds.WebDataset``.

    Returns:
        True if every sampled spot was found and passed. Otherwise False,
        after printing a failure message.
    """
    # Imported here: this script's top-level imports do not include tqdm.
    from tqdm.auto import tqdm

    print_header("4. 端到端一致性检查")

    try:
        n_to_check = min(NUM_SAMPLES_TO_CHECK, adata.n_obs)  # random.sample raises if k > population
        if n_to_check == 0:
            # Otherwise an empty AnnData would "pass" with 0/0 spots found.
            print_status("AnnData 中没有可供检查的 spot。", success=False)
            return False
        spot_ids_to_check = random.sample(adata.obs_names.tolist(), k=n_to_check)
        print_info(f"将随机抽样检查以下 {len(spot_ids_to_check)} 个 spots: {spot_ids_to_check}")

        wanted = set(spot_ids_to_check)
        dataset = wds.WebDataset(tar_urls).select(lambda x: x['__key__'] in wanted)

        found_ids = set()
        for sample in tqdm(dataset, total=len(wanted), desc="端到端检查"):
            spot_id = sample['__key__']
            if spot_id in found_ids:
                continue  # duplicate keys must not inflate the found count
            print(f"\n--- 正在检查 Spot: {spot_id} ---")

            # 1. Neighbour lookup in the AnnData graph.
            try:
                idx = adata.obs_names.get_loc(spot_id)
                neighbors_indices = adata.obsp['spatial_connectivities'][idx].indices
                neighbor_ids = adata.obs_names[neighbors_indices].tolist()
                print_status(f"AnnData: 找到 {len(neighbor_ids)} 个邻居。示例: {neighbor_ids[:3]}")
            except Exception as e:
                print_status(f"AnnData: 查找邻居失败: {e}", success=False)
                return False

            # 2. Decode the payload stored in the shard.
            try:
                img = Image.open(io.BytesIO(sample['png']))
                sentence = sample['txt'].decode('utf-8')
                print_status("WebDataset: 图像和文本数据成功解码。")
                print_info(f"图像尺寸: {img.size}, 句子预览: '{sentence[:50]}...'")
            except Exception as e:
                print_status(f"WebDataset: 解码数据失败: {e}", success=False)
                return False

            found_ids.add(spot_id)
            if len(found_ids) == len(wanted):
                break  # all targets found; skip scanning the remaining shards

        found_count = len(found_ids)
        if found_count == len(wanted):
            print_status(f"\n所有 {found_count} 个抽样检查的 spot 均通过了端到端一致性验证！")
            return True
        print_status(f"只找到了 {found_count}/{len(wanted)} 个抽样样本。数据可能不完整或key不匹配。", success=False)
        return False

    except Exception as e:
        print_status(f"端到端检查过程中发生严重错误: {e}", success=False)
        return False


# --- 主执行流程 ---
def main():
    """Run every dataset check in order, stopping at the first failing stage."""
    files_ok, random_shard = check_file_existence()
    if not files_ok:
        print("\n❌ 基础文件缺失，验证终止。请先成功运行预处理脚本。")
        return

    adata = inspect_master_adata()
    if adata is None:
        print("\n❌ 主 AnnData 文件存在问题，验证终止。")
        return

    if not inspect_tar_shard(random_shard):
        print("\n❌ WebDataset 分片文件存在问题，验证终止。")
        return

    # Every shard is passed to the end-to-end check.
    shard_urls = [
        os.path.join(SHARDS_OUTPUT_PATH, name)
        for name in os.listdir(SHARDS_OUTPUT_PATH)
        if name.endswith('.tar')
    ]

    if not perform_end_to_end_check(adata, shard_urls):
        print("\n❌ 端到端一致性检查失败，请检查数据生成逻辑。")
        return

    stars = "*" * 25
    print("\n\n" + stars + "  数据集验证通过  " + stars)
    print("✅ 您的数据集已准备就绪，可以用于SpaGLaM模型训练！")
    print("*" * 80)


# Run the validation when executed as a script. This also fires when the cell is
# run in a Jupyter kernel, as the output below shows.
if __name__ == "__main__":
    main()


  1. 文件存在性检查
✅ [成功] 主 AnnData 文件存在: /cwStorage/nodecw_group/jijh/spaglam_sota_data/master_adata_with_graph.h5ad
✅ [成功] WebDataset 分片目录存在，并找到 100 个 .tar 文件。

  2. 主 ANNDATA 文件完整性检查
✅ [成功] 成功加载 AnnData 文件。
ℹ️ [信息] AnnData 维度 (spots, genes): 997054 x 30148
❌ [失败] '.obs' 中缺少关键列: 'image_path'

❌ 主 AnnData 文件存在问题，验证终止。


In [16]:
# Inspect the in-memory obs table.
# NOTE(review): image_path/sentence_path appear here, presumably because
# create_webdataset_shards added them to this same object after it had already
# been saved to FINAL_ADATA_PATH. That would explain why the saved file lacks them.
master_adata.obs

Unnamed: 0,in_tissue,pxl_col_in_fullres,pxl_row_in_fullres,array_col,array_row,n_counts,n_genes_by_counts,log1p_n_genes_by_counts,total_counts,log1p_total_counts,...,log1p_total_counts_mito,pct_counts_mito,sample_id,pxl_row_in_fullres_old,pxl_col_in_fullres_old,total_counts_mt,pct_counts_mt,n_genes,image_path,sentence_path
TENX158_000x017,True,7727.328901,1043.862888,17.0,0.0,9334.0,980,6.908755,9056.0,9.141526,...,0.000000,0.000000,TENX158,,,0.0,0.000000,980,/cwStorage/nodecw_group/jijh/hest_output/TENX1...,/cwStorage/nodecw_group/jijh/hest_sentences_hu...
TENX158_000x018,True,8092.620792,1043.862888,18.0,0.0,2166.0,447,6.120297,2079.0,7.681099,...,0.000000,0.000000,TENX158,,,0.0,0.000000,447,/cwStorage/nodecw_group/jijh/hest_output/TENX1...,/cwStorage/nodecw_group/jijh/hest_sentences_hu...
TENX158_000x020,True,8823.204575,1043.862888,20.0,0.0,5275.0,866,6.778785,5081.0,8.570924,...,0.000000,0.000000,TENX158,,,0.0,0.000000,866,/cwStorage/nodecw_group/jijh/hest_output/TENX1...,/cwStorage/nodecw_group/jijh/hest_sentences_hu...
TENX158_000x021,True,9188.496466,1043.862888,21.0,0.0,5838.0,1268,7.163947,5674.0,8.672315,...,0.000000,0.000000,TENX158,,,0.0,0.000000,1268,/cwStorage/nodecw_group/jijh/hest_output/TENX1...,/cwStorage/nodecw_group/jijh/hest_sentences_hu...
TENX158_001x016,True,7362.037009,1409.154779,16.0,1.0,24396.0,1840,7.554335,23763.0,10.102215,...,0.000000,0.000000,TENX158,,,0.0,0.000000,1840,/cwStorage/nodecw_group/jijh/hest_output/TENX1...,/cwStorage/nodecw_group/jijh/hest_sentences_hu...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
MISC1_TTGTGTTTCCCGAAAG-1,True,6663.460818,8760.178788,59.0,51.0,,1456,7.312553,2394.0,7.803027,...,5.755742,12.872906,MISC1,,,315.0,13.157895,1456,/cwStorage/nodecw_group/jijh/hest_output/MISC1...,/cwStorage/nodecw_group/jijh/hest_sentences_hu...
MISC1_TTGTTGTGTGTCAAGA-1,True,7922.034775,6375.541625,77.0,31.0,,1587,7.401842,2813.0,7.965893,...,6.165418,16.493055,MISC1,,,475.0,16.885887,1587,/cwStorage/nodecw_group/jijh/hest_output/MISC1...,/cwStorage/nodecw_group/jijh/hest_sentences_hu...
MISC1_TTGTTTCACATCCAGG-1,True,5486.172801,9589.796448,42.0,58.0,,1476,7.331060,2577.0,7.878913,...,5.771441,12.121212,MISC1,,,320.0,12.417540,1476,/cwStorage/nodecw_group/jijh/hest_output/MISC1...,/cwStorage/nodecw_group/jijh/hest_sentences_hu...
MISC1_TTGTTTCATTAGTCTA-1,True,4657.957578,9823.488712,30.0,60.0,,1241,7.154615,2021.0,7.632401,...,5.572154,12.699951,MISC1,,,262.0,12.963879,1241,/cwStorage/nodecw_group/jijh/hest_output/MISC1...,/cwStorage/nodecw_group/jijh/hest_sentences_hu...


# 改正之后

In [17]:
# -*- coding: utf-8 -*-

# ==============================================================================
# SpaGLaM SOTA 数据预处理工作流 (V2 - 已修复路径和多进程问题)
#
# 核心策略:
# 1. 加载预先合并好的AnnData缓存文件。
# 2. **准备主AnnData**:
#    a. 根据坐标和样本ID，重建并添加 image_path 和 sentence_path 列。
#    b. 通过“拆分-计算-合并”策略计算并添加空间邻接图。
# 3. 将所有独立的图像和基因语句文件打包成高效的.tar分片。
# ==============================================================================

import os
import sys
import time
import scanpy as sc
import pandas as pd
import numpy as np
import webdataset as wds
from tqdm.auto import tqdm
from concurrent.futures import ProcessPoolExecutor
from scipy.sparse import lil_matrix

# --- 1. Global settings ---

# --- Paths (adjust to your environment) ---
BASE_DIR = "/cwStorage/nodecw_group/jijh"
# Pre-merged AnnData cache that this pipeline loads as its starting point.
EXISTING_CACHE_PATH = os.path.join(BASE_DIR, "hest_cache/cache/adata_preprocessed_canonical_v2.h5ad")
# Root directories containing the per-sample "<sample_id>_tiles" and
# "<sample_id>_sentences_hvg" folders referenced by add_file_paths.
IMAGE_BASE_DIR = os.path.join(BASE_DIR, "hest_output")
SENTENCE_BASE_DIR = os.path.join(BASE_DIR, "hest_sentences_human_all")

# Outputs: the master AnnData (with paths and spatial graph) and the shard directory.
OUTPUT_DIR = os.path.join(BASE_DIR, "spaglam_sota_data")
FINAL_ADATA_PATH = os.path.join(OUTPUT_DIR, "master_adata_with_graph_and_paths.h5ad") # new file name
SHARDS_OUTPUT_PATH = os.path.join(OUTPUT_DIR, "webdataset_shards")

# --- Parameters ---
N_NEIGHBORS = 6  # n_neighbors passed to sc.pp.neighbors for the spatial graph
SAMPLES_PER_SHARD = 10000  # maximum number of spots written to one .tar shard
NUM_WORKERS_PACKAGING = 16  # size of the process pool that writes shards

# --- 辅助函数 ---
def format_time(seconds: float) -> str:
    """Render a duration in seconds as "<minutes>分 <seconds>秒".

    Minutes are shown as an integer, leftover seconds with two decimals.
    """
    whole_minutes = seconds // 60
    leftover = seconds % 60
    return "{}分 {:.2f}秒".format(int(whole_minutes), leftover)

def print_header(title):
    """Print *title* upper-cased between two 80-character '=' rules, after a blank line."""
    rule = "=" * 80
    print()
    print(rule)
    print(f"  {title.upper()}")
    print(rule)

# --- 核心功能函数 (已修改和重构) ---

def prepare_master_adata(adata: sc.AnnData, n_neighbors: int) -> sc.AnnData:
    """Finalize a freshly loaded AnnData for export.

    Prints a section header, attaches file paths with ``add_file_paths``, then
    the spatial graph with ``add_spatial_graph``, and returns the result.
    """
    print_header("准备主 AnnData 文件")
    with_paths = add_file_paths(adata)
    return add_spatial_graph(with_paths, n_neighbors)

def add_file_paths(adata: sc.AnnData) -> sc.AnnData:
    """Rebuild and attach ``image_path`` / ``sentence_path`` columns to ``adata.obs``.

    Every spot's files share the base name ``{sample_id}_{y}_{x}``, where ``x``
    and ``y`` are columns 0 and 1 of ``obsm['spatial']`` rounded to int with
    Python's ``round``:
      - image:    ``IMAGE_BASE_DIR/{sample_id}_tiles/{base}.png``
      - sentence: ``SENTENCE_BASE_DIR/{sample_id}_sentences_hvg/{base}.txt``

    If both columns already exist and the first ``image_path`` is a string
    pointing to an existing file, nothing is rebuilt. Only that first path is
    checked.

    Args:
        adata: AnnData with ``obs['sample_id']`` and ``obsm['spatial']``.

    Returns:
        The same ``adata`` object, modified in place.
    """
    print("--> 步骤 1.1: 重建并添加文件路径...")
    start_time = time.time()

    if 'image_path' in adata.obs.columns and 'sentence_path' in adata.obs.columns:
        # Guard against empty obs (no .iloc[0]) and NaN paths
        # (os.path.exists raises TypeError on a float).
        first_path = adata.obs['image_path'].iloc[0] if adata.n_obs else None
        if isinstance(first_path, str) and os.path.exists(first_path):
            print("✅ 文件路径列已存在且有效，跳过重建。")
            return adata
        else:
            print("⚠️ 检测到无效的文件路径，将重新构建...")

    # Read coordinates straight from obsm. The previous version staged them in
    # temporary obs columns named spatial_x/spatial_y, which overwrote and then
    # dropped any existing columns with those names. A plain loop over zipped
    # columns is also much faster than a row-wise DataFrame.apply.
    coords = np.asarray(adata.obsm['spatial'])
    image_paths, sentence_paths = [], []
    for sample_id, x, y in zip(adata.obs['sample_id'], coords[:, 0], coords[:, 1]):
        # Python round() (half-to-even) keeps file names identical to the
        # original per-row implementation.
        filename_base = f"{sample_id}_{int(round(float(y)))}_{int(round(float(x)))}"
        image_paths.append(os.path.join(IMAGE_BASE_DIR, f"{sample_id}_tiles", f"{filename_base}.png"))
        sentence_paths.append(os.path.join(SENTENCE_BASE_DIR, f"{sample_id}_sentences_hvg", f"{filename_base}.txt"))

    # Positional assignment: robust to duplicate obs_names, unlike index alignment.
    adata.obs['image_path'] = image_paths
    adata.obs['sentence_path'] = sentence_paths

    print(f"✅ 文件路径已成功添加到 '.obs' 中。")
    end_time = time.time()
    print(f"🕒 耗时: {format_time(end_time - start_time)}")
    return adata


def add_spatial_graph(adata: sc.AnnData, n_neighbors: int) -> sc.AnnData:
    """Compute a per-sample spatial neighbour graph and store it in ``adata.obsp`` (no squidpy).

    Spots from different samples share no coordinate frame, so the graph is
    built independently for every ``obs['sample_id']`` value with
    ``sc.pp.neighbors(use_rep='spatial')``. The per-sample connectivity
    matrices are then scattered into one global CSR matrix, so spots of
    different samples are never connected.

    Args:
        adata: AnnData with ``obs['sample_id']`` and ``obsm['spatial']``.
        n_neighbors: Forwarded to ``sc.pp.neighbors``.

    Returns:
        The same ``adata`` object (modified in place) with
        ``obsp['spatial_connectivities']`` set. If that key already exists,
        ``adata`` is returned unchanged.
    """
    from scipy.sparse import coo_matrix

    print("--> 步骤 1.2: 计算并添加空间邻接图...")
    start_time = time.time()
    if 'spatial_connectivities' in adata.obsp:
        print("✅ 空间邻接图 'spatial_connectivities' 已存在，跳过计算。")
        return adata

    unique_samples = adata.obs['sample_id'].unique()
    print(f"检测到 {len(unique_samples)} 个独立样本，将逐个计算空间图...")

    # Collect COO triplets per sample and build the global matrix once at the
    # end; element-wise insertion into a lil_matrix is far slower.
    row_parts, col_parts, data_parts = [], [], []
    for sample_id in tqdm(unique_samples, desc="处理每个样本的图"):
        sample_mask = (adata.obs['sample_id'] == sample_id).to_numpy()
        global_indices = np.flatnonzero(sample_mask)
        if global_indices.size < 2:
            # A lone spot has no neighbours and sc.pp.neighbors cannot run on it.
            continue
        # Explicit copy: writing .obsp on a view would trigger the same copy implicitly.
        adata_sample = adata[sample_mask].copy()

        sc.pp.neighbors(adata_sample, n_neighbors=n_neighbors, use_rep='spatial', key_added='spatial')

        # tocoo() keeps row/col/data aligned entry-for-entry; .nonzero() skips
        # explicitly stored zeros and could desynchronise indices from .data.
        local_conn = adata_sample.obsp['spatial_connectivities'].tocoo()
        row_parts.append(global_indices[local_conn.row])
        col_parts.append(global_indices[local_conn.col])
        data_parts.append(local_conn.data)

    n_obs = adata.n_obs
    if data_parts:
        rows = np.concatenate(row_parts)
        cols = np.concatenate(col_parts)
        data = np.concatenate(data_parts).astype(np.float32, copy=False)
    else:
        rows = cols = np.empty(0, dtype=np.intp)
        data = np.empty(0, dtype=np.float32)
    global_conn_matrix = coo_matrix((data, (rows, cols)), shape=(n_obs, n_obs), dtype=np.float32).tocsr()
    global_conn_matrix.eliminate_zeros()
    adata.obsp['spatial_connectivities'] = global_conn_matrix

    print("✅ 所有样本的空间邻接图已计算并合并！")
    end_time = time.time()
    print(f"🕒 耗时: {format_time(end_time - start_time)}")
    return adata

# ... 多进程工作函数 _process_chunk_worker 保持不变 ...
def _process_chunk_worker(args):
    """Write one WebDataset shard for a chunk of spot IDs (process-pool entry point).

    ``args`` is a single tuple ``(spot_ids_chunk, chunk_index, all_spot_info,
    output_pattern)``, so the function can be used with ``executor.map``.

    Each spot is written as ``{"__key__": spot_id, "png": <image bytes>,
    "txt": <sentence bytes>}`` to ``output_pattern % chunk_index``. Spots with
    no (or an empty) entry in ``all_spot_info``, or whose image or sentence
    file does not exist, are skipped without notice.

    Returns:
        ``chunk_index``, so the caller can track progress.
    """
    spot_ids_chunk, chunk_index, all_spot_info, output_pattern = args

    with wds.TarWriter(output_pattern % chunk_index) as tar_sink:
        for key in spot_ids_chunk:
            entry = all_spot_info.get(key)
            if entry:
                png_file = entry['image_path']
                txt_file = entry['sentence_path']
                if os.path.exists(png_file) and os.path.exists(txt_file):
                    with open(png_file, "rb") as handle:
                        png_bytes = handle.read()
                    with open(txt_file, "rb") as handle:
                        txt_bytes = handle.read()
                    tar_sink.write({"__key__": key, "png": png_bytes, "txt": txt_bytes})
    return chunk_index

def create_webdataset_shards(adata: sc.AnnData, output_pattern: str):
    """Pack every spot's image and gene-sentence files into WebDataset .tar shards.

    Expects ``adata.obs`` to already contain ``image_path`` and
    ``sentence_path``. Spots with a missing value in either column are dropped.
    The rest are split into chunks of ``SAMPLES_PER_SHARD``, and each chunk is
    written to ``output_pattern % chunk_index`` by ``_process_chunk_worker`` in
    a pool of ``NUM_WORKERS_PACKAGING`` processes.

    Args:
        adata: AnnData whose ``obs_names`` are the spot IDs used as sample keys.
        output_pattern: printf-style shard path with one integer field,
            e.g. ``".../dataset-%06d.tar"``.
    """
    print_header("2. 将原始数据打包成 WebDataset 分片")
    start_time = time.time()

    # Use the path columns already stored in the AnnData.
    all_spot_info = adata.obs[['image_path', 'sentence_path']].dropna().to_dict('index')
    all_spot_ids = list(all_spot_info)

    print(f"准备打包 {len(all_spot_ids)} 个 spots 的数据...")
    output_dir = os.path.dirname(output_pattern)
    if output_dir:  # os.makedirs('') raises; a bare filename means the current directory
        os.makedirs(output_dir, exist_ok=True)

    # Give each task only its own chunk's paths. Passing the full dict (one
    # entry per spot) would re-pickle it for every task sent to the pool.
    tasks = []
    for chunk_index, start in enumerate(range(0, len(all_spot_ids), SAMPLES_PER_SHARD)):
        chunk = all_spot_ids[start:start + SAMPLES_PER_SHARD]
        chunk_info = {spot_id: all_spot_info[spot_id] for spot_id in chunk}
        tasks.append((chunk, chunk_index, chunk_info, output_pattern))

    print(f"启动 {NUM_WORKERS_PACKAGING} 个工作进程进行打包...")
    with ProcessPoolExecutor(max_workers=NUM_WORKERS_PACKAGING) as executor:
        results = list(tqdm(executor.map(_process_chunk_worker, tasks), total=len(tasks), desc="打包 .tar 分片"))

    end_time = time.time()
    print(f"✅ 数据打包完成！成功处理了 {len(results)} 个分片。")
    print(f"🕒 耗时: {format_time(end_time - start_time)}")

In [18]:
# Load the cache, attach file paths + spatial graph, and save the master AnnData.
print("🚀 开始SOTA数据预处理工作流 (V2 - 已修复路径和多进程问题)...")

print(f"\n正在从缓存加载主 AnnData 文件: {EXISTING_CACHE_PATH}")
if not os.path.exists(EXISTING_CACHE_PATH):
    print(f"❌ 错误: 缓存文件未找到！请确认路径 '{EXISTING_CACHE_PATH}' 是否正确。")
    # Raise rather than sys.exit(1). Inside a Jupyter kernel sys.exit only
    # surfaces a terse SystemExit. A FileNotFoundError gives a real traceback
    # and still ends a plain script with a non-zero status.
    raise FileNotFoundError(EXISTING_CACHE_PATH)

master_adata = sc.read_h5ad(EXISTING_CACHE_PATH)
print(f"✅ 加载成功！AnnData 维度: {master_adata.shape}")

# Core step: prepare the master AnnData (paths + graph) in one call.
final_master_adata = prepare_master_adata(master_adata, n_neighbors=N_NEIGHBORS)

print(f"\n正在保存最终的主 AnnData 文件至: {FINAL_ADATA_PATH}")
os.makedirs(os.path.dirname(FINAL_ADATA_PATH), exist_ok=True)
final_master_adata.write_h5ad(FINAL_ADATA_PATH)
print("✅ 保存成功！")


🚀 开始SOTA数据预处理工作流 (V2 - 已修复路径和多进程问题)...

正在从缓存加载主 AnnData 文件: /cwStorage/nodecw_group/jijh/hest_cache/cache/adata_preprocessed_canonical_v2.h5ad
✅ 加载成功！AnnData 维度: (997054, 30148)

  准备主 ANNDATA 文件
--> 步骤 1.1: 重建并添加文件路径...
✅ 文件路径已成功添加到 '.obs' 中。
🕒 耗时: 0分 20.62秒
--> 步骤 1.2: 计算并添加空间邻接图...
检测到 505 个独立样本，将逐个计算空间图...


处理每个样本的图:   0%|          | 0/505 [00:00<?, ?it/s]

✅ 所有样本的空间邻接图已计算并合并！
🕒 耗时: 2分 16.59秒

正在保存最终的主 AnnData 文件至: /cwStorage/nodecw_group/jijh/spaglam_sota_data/master_adata_with_graph_and_paths.h5ad
✅ 保存成功！


In [19]:

# Pack the files referenced by final_master_adata.obs into shards named
# dataset-000000.tar, dataset-000001.tar, ...
# NOTE(review): this writes into the same SHARDS_OUTPUT_PATH as the V1 run and
# does not delete existing files, so stale higher-numbered shards would remain
# if this run produced fewer shards.
shard_pattern = os.path.join(SHARDS_OUTPUT_PATH, "dataset-%06d.tar")
create_webdataset_shards(final_master_adata, shard_pattern)

print("\n\n🎉🎉🎉 恭喜！SOTA 数据预处理全部完成！🎉🎉🎉")
print("您现在拥有：")
print(f"1. 一个包含所有元数据、文件路径和空间图的中央索引文件: {FINAL_ADATA_PATH}")
print(f"2. 一组用于高效训练的数据仓库分片文件位于: {SHARDS_OUTPUT_PATH}")


  2. 将原始数据打包成 WEBDATASET 分片
准备打包 997054 个 spots 的数据...
启动 16 个工作进程进行打包...


打包 .tar 分片:   0%|          | 0/100 [00:00<?, ?it/s]

✅ 数据打包完成！成功处理了 100 个分片。
🕒 耗时: 12分 43.16秒


🎉🎉🎉 恭喜！SOTA 数据预处理全部完成！🎉🎉🎉
您现在拥有：
1. 一个包含所有元数据、文件路径和空间图的中央索引文件: /cwStorage/nodecw_group/jijh/spaglam_sota_data/master_adata_with_graph_and_paths.h5ad
2. 一组用于高效训练的数据仓库分片文件位于: /cwStorage/nodecw_group/jijh/spaglam_sota_data/webdataset_shards


In [None]:
# -*- coding: utf-8 -*-

# ==============================================================================
# SpaGLaM SOTA 数据集验证脚本
# 
# 用途:
#   - 验证预处理流程是否成功。
#   - 检查主 AnnData 文件和 WebDataset 分片的完整性与一致性。
#   - 在开始昂贵的模型训练前，进行一次快速的健全性检查 (Sanity Check)。
# ==============================================================================

import os
import sys
import random
import scanpy as sc
import webdataset as wds
import numpy as np
from PIL import Image
import io

# --- 1. Configuration (adjust to your environment) ---

# --- Paths ---
BASE_DIR = "/cwStorage/nodecw_group/jijh"
OUTPUT_DIR = os.path.join(BASE_DIR, "spaglam_sota_data")
# Master AnnData written by the V2 pipeline (includes image/sentence paths).
FINAL_ADATA_PATH = os.path.join(OUTPUT_DIR, "master_adata_with_graph_and_paths.h5ad")
SHARDS_OUTPUT_PATH = os.path.join(OUTPUT_DIR, "webdataset_shards")

# --- Validation parameters ---
# Number of spots randomly sampled for the end-to-end consistency check.
NUM_SAMPLES_TO_CHECK = 5

# --- 辅助函数 ---
def print_header(title):
    """Print *title* upper-cased between two 80-character '=' rules, after a blank line."""
    rule = "=" * 80
    print()
    print(rule)
    print(f"  {title.upper()}")
    print(rule)

def print_status(message, success=True):
    """Print *message* after a success marker, or a failure marker if *success* is falsy."""
    marker = {True: "✅ [成功]", False: "❌ [失败]"}[bool(success)]
    print(f"{marker} {message}")

def print_info(message):
    """Print *message* with an informational prefix."""
    info_prefix = "ℹ️ [信息]"
    print(f"{info_prefix} {message}")


# --- 验证函数 ---

def check_file_existence():
    """Verify that the master AnnData file and at least one .tar shard exist.

    Returns:
        A tuple ``(all_exist, random_shard_path)``. ``all_exist`` is True only
        if both checks pass. ``random_shard_path`` is a randomly chosen
        ``.tar`` path inside SHARDS_OUTPUT_PATH, or None when there is none.
    """
    print_header("1. 文件存在性检查")

    adata_present = os.path.exists(FINAL_ADATA_PATH)
    if adata_present:
        print_status(f"主 AnnData 文件存在: {FINAL_ADATA_PATH}")
    else:
        print_status(f"主 AnnData 文件缺失: {FINAL_ADATA_PATH}", success=False)

    random_shard_path = None
    shards_present = False
    if not os.path.isdir(SHARDS_OUTPUT_PATH):
        print_status(f"WebDataset 分片目录缺失: {SHARDS_OUTPUT_PATH}", success=False)
    else:
        tar_names = [name for name in os.listdir(SHARDS_OUTPUT_PATH) if name.endswith('.tar')]
        if not tar_names:
            print_status("WebDataset 分片目录存在，但其中没有 .tar 文件。", success=False)
        else:
            print_status(f"WebDataset 分片目录存在，并找到 {len(tar_names)} 个 .tar 文件。")
            # Pick one shard at random for the structural spot-check later on.
            random_shard_path = os.path.join(SHARDS_OUTPUT_PATH, random.choice(tar_names))
            shards_present = True

    return adata_present and shards_present, random_shard_path


def inspect_master_adata():
    """Load the master AnnData file and confirm its required fields are present.

    Requires ``obs`` columns ``sample_id``, ``image_path`` and ``sentence_path``,
    ``obsm['spatial']`` and ``obsp['spatial_connectivities']``.

    Returns:
        The loaded AnnData on success. On any failure, including an exception
        while loading, prints a failure message and returns None.
    """
    print_header("2. 主 AnnData 文件完整性检查")
    try:
        adata = sc.read_h5ad(FINAL_ADATA_PATH)
        print_status("成功加载 AnnData 文件。")
        print_info(f"AnnData 维度 (spots, genes): {adata.n_obs} x {adata.n_vars}")

        # Report only the first missing column, in the fixed order below.
        missing_col = next(
            (name for name in ('sample_id', 'image_path', 'sentence_path') if name not in adata.obs.columns),
            None,
        )
        if missing_col is not None:
            print_status(f"'.obs' 中缺少关键列: '{missing_col}'", success=False)
            return None
        print_status("'.obs' 中的关键列均存在。")

        if 'spatial' not in adata.obsm:
            print_status("'.obsm' 中缺少空间坐标 'spatial'", success=False)
            return None
        print_status("空间坐标 '.obsm['spatial']' 存在。")

        # The spatial graph is the most important artefact of preprocessing.
        if 'spatial_connectivities' not in adata.obsp:
            print_status("'.obsp' 中缺少空间邻接图 'spatial_connectivities'", success=False)
            return None
        print_status("空间邻接图 '.obsp['spatial_connectivities']' 存在。")

        return adata
    except Exception as err:
        print_status(f"加载或检查 AnnData 文件时出错: {err}", success=False)
        return None


def inspect_tar_shard(shard_path):
    """Spot-check the structure of a single WebDataset .tar shard.

    For at most the first 3 samples, verifies the following:
      - the ``__key__``, ``png`` and ``txt`` fields are present;
      - the PNG bytes decode with PIL (converted to RGB);
      - the text is valid UTF-8.

    Args:
        shard_path: Path to a .tar shard, or a falsy value when none exists.

    Returns:
        True if at least one sample was read and every checked sample passed.
        Otherwise False, after printing a failure message (also on any
        exception raised while reading).
    """
    from itertools import islice

    print_header("3. WebDataset 分片结构检查")
    if not shard_path:
        print_status("没有可供检查的分片文件。", success=False)
        return False

    expected_keys = {'__key__', 'png', 'txt'}
    try:
        print_info(f"正在抽样检查分片: {os.path.basename(shard_path)}")
        dataset = wds.WebDataset(shard_path)

        sample_count = 0
        # islice stops after 3 samples without pulling a 4th from the shard.
        for sample in islice(dataset, 3):
            if not expected_keys.issubset(sample.keys()):
                # .get(): '__key__' may itself be the missing field; indexing it
                # would raise KeyError and hide the real problem.
                print_status(f"样本 {sample.get('__key__')} 缺少关键键。期望: {expected_keys}, 实际: {sample.keys()}", success=False)
                return False

            # Image.open is lazy; convert() forces a full decode. The context
            # manager releases the image afterwards.
            with Image.open(io.BytesIO(sample['png'])) as img:
                img.convert("RGB")
            sample['txt'].decode('utf-8')
            sample_count += 1

        if sample_count > 0:
            print_status(f"成功检查了 {sample_count} 个样本，结构正确。")
            return True
        print_status("分片文件为空或无法读取样本。", success=False)
        return False

    except Exception as e:
        print_status(f"检查 .tar 分片时出错: {e}", success=False)
        return False


def perform_end_to_end_check(adata, tar_urls, num_samples=None):
    """Cross-check randomly sampled spots between the AnnData graph and the shards.

    For each sampled spot id:

    1. Look up its neighbours in ``adata.obsp['spatial_connectivities']``.
    2. Find the WebDataset sample with the same ``__key__``.
    3. Decode that sample's PNG image and UTF-8 sentence.

    Args:
        adata: AnnData object whose ``obs_names`` are the spot ids used as
            WebDataset keys (assumption inherited from the packaging step).
        tar_urls: Shard locations accepted by ``wds.WebDataset``; ``main``
            passes a list of local ``.tar`` paths.
        num_samples: Number of spots to sample. Defaults to the module-level
            ``NUM_SAMPLES_TO_CHECK`` and is clamped to the number of spots.

    Returns:
        True when every sampled spot was found in the shards and passed both
        checks; False otherwise. Errors are printed, not raised.
    """
    print_header("4. 端到端一致性检查")

    try:
        if num_samples is None:
            num_samples = NUM_SAMPLES_TO_CHECK
        all_spot_ids = adata.obs_names.tolist()
        # Clamp k so random.sample does not raise ValueError on small AnnData objects.
        spot_ids_to_check = random.sample(all_spot_ids, k=min(num_samples, len(all_spot_ids)))
        if not spot_ids_to_check:
            print_status("AnnData 中没有可供抽样的 spot。", success=False)
            return False
        print_info(f"将随机抽样检查以下 {len(spot_ids_to_check)} 个 spots: {spot_ids_to_check}")

        # Set membership keeps the per-sample filter O(1) while streaming every shard.
        wanted = set(spot_ids_to_check)
        dataset = wds.WebDataset(tar_urls).select(lambda x: x['__key__'] in wanted)

        # Track ids (not a counter) so a key duplicated across shards is counted once.
        found_ids = set()
        for sample in tqdm(dataset, total=len(wanted), desc="端到端检查"):
            spot_id = sample['__key__']
            if spot_id in found_ids:
                continue
            print(f"\n--- 正在检查 Spot: {spot_id} ---")

            # 1. Verify the neighbour information stored in AnnData.
            try:
                idx = adata.obs_names.get_loc(spot_id)
                row = adata.obsp['spatial_connectivities'][idx]
                # nonzero()[-1] yields column positions for CSR/CSC/LIL rows and dense rows alike;
                # `.indices` is only column positions for CSR and also includes explicit zeros.
                neighbors_indices = row.nonzero()[-1]
                neighbor_ids = adata.obs_names[neighbors_indices].tolist()
                print_status(f"AnnData: 找到 {len(neighbor_ids)} 个邻居。示例: {neighbor_ids[:3]}")
            except Exception as e:
                print_status(f"AnnData: 查找邻居失败: {e}", success=False)
                return False

            # 2. Verify that the WebDataset payloads decode.
            try:
                img = Image.open(io.BytesIO(sample['png']))
                sentence = sample['txt'].decode('utf-8')
                print_status("WebDataset: 图像和文本数据成功解码。")
                print_info(f"图像尺寸: {img.size}, 句子预览: '{sentence[:50]}...'")
            except Exception as e:
                print_status(f"WebDataset: 解码数据失败: {e}", success=False)
                return False

            found_ids.add(spot_id)
            # Stop once every target is found instead of streaming all remaining shards.
            if found_ids == wanted:
                break

        found_count = len(found_ids)
        if found_count == len(wanted):
            print_status(f"\n所有 {found_count} 个抽样检查的 spot 均通过了端到端一致性验证！")
            return True

        print_status(f"只找到了 {found_count}/{len(wanted)} 个抽样样本。数据可能不完整或key不匹配。", success=False)
        return False

    except Exception as e:
        print_status(f"端到端检查过程中发生严重错误: {e}", success=False)
        return False


# --- 主执行流程 ---
def main():
    """Run the dataset validation steps in order and print a final verdict.

    Each step prints its own diagnostics. The run stops at the first failing
    step:

      1. ``check_file_existence``: checks that the expected files exist and
         returns a shard path for step 3.
      2. ``inspect_master_adata``: loads the AnnData file and checks the
         required ``.obs`` columns, ``.obsm['spatial']`` and
         ``.obsp['spatial_connectivities']``.
      3. ``inspect_tar_shard``: inspects the structure of that one shard.
      4. ``perform_end_to_end_check``: cross-checks sampled spots between the
         AnnData graph and all shards under ``SHARDS_OUTPUT_PATH``.
    """
    files_ok, random_shard = check_file_existence()
    if not files_ok:
        print("\n❌ 基础文件缺失，验证终止。请先成功运行预处理脚本。")
        return

    adata = inspect_master_adata()
    if adata is None:
        print("\n❌ 主 AnnData 文件存在问题，验证终止。")
        return

    tar_ok = inspect_tar_shard(random_shard)
    if not tar_ok:
        print("\n❌ WebDataset 分片文件存在问题，验证终止。")
        return

    # Collect all shard paths. Sort them because os.listdir order is arbitrary,
    # and the shards should be streamed in a reproducible order.
    all_shards = sorted(
        os.path.join(SHARDS_OUTPUT_PATH, f)
        for f in os.listdir(SHARDS_OUTPUT_PATH)
        if f.endswith('.tar')
    )
    if not all_shards:
        print("\n❌ 未找到任何 .tar 分片文件，验证终止。")
        return

    e2e_ok = perform_end_to_end_check(adata, all_shards)
    if not e2e_ok:
        print("\n❌ 端到端一致性检查失败，请检查数据生成逻辑。")
        return

    print("\n\n" + "*"*25 + "  数据集验证通过  " + "*"*25)
    print("✅ 您的数据集已准备就绪，可以用于SpaGLaM模型训练！")
    print("*"*80)


# Entry point: run the validation when this code is executed as a script
# (or as a notebook cell, where __name__ is also "__main__"), not when it is
# imported.
if __name__ == "__main__":
    main()


  1. 文件存在性检查
✅ [成功] 主 AnnData 文件存在: /cwStorage/nodecw_group/jijh/spaglam_sota_data/master_adata_with_graph_and_paths.h5ad
✅ [成功] WebDataset 分片目录存在，并找到 100 个 .tar 文件。

  2. 主 ANNDATA 文件完整性检查
✅ [成功] 成功加载 AnnData 文件。
ℹ️ [信息] AnnData 维度 (spots, genes): 997054 x 30148
✅ [成功] '.obs' 中的关键列均存在。
✅ [成功] 空间坐标 '.obsm['spatial']' 存在。
✅ [成功] 空间邻接图 '.obsp['spatial_connectivities']' 存在。

  3. WEBDATASET 分片结构检查
ℹ️ [信息] 正在抽样检查分片: dataset-000099.tar
✅ [成功] 成功检查了 3 个样本，结构正确。

  4. 端到端一致性检查
ℹ️ [信息] 将随机抽样检查以下 5 个 spots: ['NCBI829_TCCCGGTCAGGAATTT-1', 'ZEN48_ACGTTAATGTCGAAGA-1', 'NCBI653_TGCAAACGTACTAGTT-1', 'ZEN46_ATATCGTTCCTCGAAC-1', 'MISC70_CAACTATATCGAATGC-1']


端到端检查:   0%|          | 0/5 [00:00<?, ?it/s]


--- 正在检查 Spot: NCBI653_TGCAAACGTACTAGTT-1 ---
✅ [成功] AnnData: 找到 4 个邻居。示例: ['NCBI653_CAAGTCGTTGAAATCT-1', 'NCBI653_GGCTGTCCTACTGCGG-1', 'NCBI653_TACTGCATGATTAAAT-1']
✅ [成功] WebDataset: 图像和文本数据成功解码。
ℹ️ [信息] 图像尺寸: (580, 580), 句子预览: 'CLU S100A6-1 SPARC-1 CMTM4 PLEKHB1-1 NUPR1-1 UQCRB...'

--- 正在检查 Spot: NCBI829_TCCCGGTCAGGAATTT-1 ---
✅ [成功] AnnData: 找到 5 个邻居。示例: ['NCBI829_ACGAGAACCCATCACG-1', 'NCBI829_CCAGCTACGCCTCATA-1', 'NCBI829_CGATATTAGCCGCAGG-1']
✅ [成功] WebDataset: 图像和文本数据成功解码。
ℹ️ [信息] 图像尺寸: (114, 114), 句子预览: 'ALB SAA1-1 SERPINA1-1 APOC1 MT2A-1 C3 SAA2-1 MT1G-...'

--- 正在检查 Spot: MISC70_CAACTATATCGAATGC-1 ---
✅ [成功] AnnData: 找到 6 个邻居。示例: ['MISC70_AGTTAAGCGGTCCCGG-1', 'MISC70_ATTTCATTATTTCGCG-1', 'MISC70_CATAGTACATTGAGAG-1']
✅ [成功] WebDataset: 图像和文本数据成功解码。
ℹ️ [信息] 图像尺寸: (362, 362), 句子预览: 'IGKC-1 SLC25A5-1 S100A6-1 LRATD1-1 ALDOA TPM2-1 IG...'

--- 正在检查 Spot: ZEN48_ACGTTAATGTCGAAGA-1 ---
✅ [成功] AnnData: 找到 5 个邻居。示例: ['ZEN48_AGGTTTCACACACCTT-1', 'ZEN48_CAAGGATCGCATGTTC-1', 'ZEN48_CCTAA

: 

: 

: 

: 

: 