# Precomputation

# Cell 1: Imports

In [1]:
import os
import random
import glob
import datetime
import time
from PIL import Image
import numpy as np
import math
# import subprocess # 通常在 notebook 中不直接运行外部脚本
# import shlex # 同上

import torch
import torch.nn as nn
# import torch.nn.functional as F # 如果GigaPath模型内部需要，否则可能不需要直接用
# import torch.optim as optim # 进行预计算时不需要优化器
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast

import torchvision.transforms as transforms
# from torchvision.utils import make_grid, save_image # 如果不进行可视化，则不需要

import timm # GigaPath 需要的库

try:
    from tqdm.notebook import tqdm as notebook_tqdm
    IS_NOTEBOOK = True
    print("侦测到 Notebook 环境。")
except ImportError:
    from tqdm.auto import tqdm as notebook_tqdm # Fallback for non-notebook
    IS_NOTEBOOK = False
    print("未完全侦测到 Notebook 环境。使用标准 tqdm。")

侦测到 Notebook 环境。


## Validate the graph data

In [3]:
import torch
import os
import glob
import platform

# --- 前置条件 ---
# 请确保你的 TrainingConfig 类已定义，并且名为 `config` 的实例已创建。
# 如果还没有，请将 TrainingConfig 类的定义粘贴到此单元格的顶部，并实例化它，例如：
#
class TrainingConfig:
    # --- Data Paths ---
    LATENT_DIR = "/cwStorage/nodecw_group/jijh/hest_output_latents_bf16"
    GRAPH_DATA_DIR = "/cwStorage/nodecw_group/jijh/hest_graph_data_pca50_knn6"
    # ... (你的其他配置项) ...
#
config = TrainingConfig()  # 确保 config 实例存在
# # 确保 config 实例存在
if 'config' not in locals() or not isinstance(config, TrainingConfig):
    print("警告: 'config' 未定义或类型不正确。请先定义并实例化 TrainingConfig 类。")
    # 为防止脚本出错，这里可以创建一个虚拟的 config，但最好是用户提供真实的
    class DummyConfig:
        GRAPH_DATA_DIR = "YOUR_GRAPH_DATA_DIR_HERE" # 用户需要替换
        LATENT_DIR = "YOUR_LATENT_DIR_HERE"       # 用户需要替换
    if 'config' not in locals(): config = DummyConfig() # 仅为示例，实际应由用户提供 TrainingConfig
#
# --- 用户配置检查部分 ---
# 1. 指定要检查的图文件的索引 (从列出的文件中选择)
#    如果你想覆盖自动选择，请直接设置 SPECIFIC_GRAPH_FILE_PATH_OVERRIDE
FILE_INDEX_TO_CHECK = 0  # 默认检查第0个 (第一个) 文件
SPECIFIC_GRAPH_FILE_PATH_OVERRIDE = None # 若要手动指定完整路径，请取消注释并修改，例如: "/path/to/your/specific_graph.pt"
# 2. 是否简略显示每个路径的检查结果 (只显示统计和错误)
CONCISE_OUTPUT_PER_PATH = False # 设置为 True 以获得更简洁的逐路径输出
# --- 用户配置结束 ---

# --- 主要检查逻辑 ---
print(f"--- 开始检查图文件中的 'latent_paths' (单单元格版本) ---")
print(f"操作系统: {platform.system()} {platform.release()}, 当前工作目录: {os.getcwd()}")

selected_graph_file_to_check = None
graph_files_found = []

try:
    if 'config' not in locals():
        raise NameError("错误: 'config' 对象未定义。请确保 TrainingConfig 类已定义并已创建其实例 'config'。")

    graph_data_directory = config.GRAPH_DATA_DIR
    expected_latent_base_dir = config.LATENT_DIR
    print(f"图数据目录 (来自config): {graph_data_directory}")
    print(f"Latent基准目录 (来自config): {expected_latent_base_dir}")

    if SPECIFIC_GRAPH_FILE_PATH_OVERRIDE:
        selected_graph_file_to_check = SPECIFIC_GRAPH_FILE_PATH_OVERRIDE
        if not os.path.exists(selected_graph_file_to_check):
             print(f"⚠️ 警告: 手动指定的图文件 '{selected_graph_file_to_check}' 不存在。")
    elif not os.path.isdir(graph_data_directory):
        print(f"❌ 错误: 图数据目录 '{graph_data_directory}' 不存在或不是一个目录。")
    else:
        search_pattern = os.path.join(graph_data_directory, "*.pt") # 你可以修改为更精确的模式如 "*_graph.pt"
        graph_files_found = sorted(glob.glob(search_pattern))
        if graph_files_found:
            print(f"\n在 '{graph_data_directory}' 中找到以下图文件 (模式: '{search_pattern}'):")
            for i, f_path in enumerate(graph_files_found):
                print(f"  [{i}] {os.path.basename(f_path)}")
            
            if 0 <= FILE_INDEX_TO_CHECK < len(graph_files_found):
                selected_graph_file_to_check = graph_files_found[FILE_INDEX_TO_CHECK]
                print(f"\n✅ 自动选择索引为 {FILE_INDEX_TO_CHECK} 的文件进行检查: '{os.path.basename(selected_graph_file_to_check)}'")
                print(f"   完整路径: {selected_graph_file_to_check}")
                print(f"   (提示: 可修改本单元格顶部的 FILE_INDEX_TO_CHECK 或 SPECIFIC_GRAPH_FILE_PATH_OVERRIDE 来选择其他文件)")
            else:
                print(f"❌ 错误: 指定的 FILE_INDEX_TO_CHECK ({FILE_INDEX_TO_CHECK}) 超出范围 (0-{len(graph_files_found)-1})。")
        else:
            print(f"  ⚠️ 在 '{graph_data_directory}' 中未找到匹配的图文件 (模式: '{search_pattern}')。")

    if not selected_graph_file_to_check or not os.path.exists(selected_graph_file_to_check):
        print(f"\n❌ 未能确定有效的图文件进行检查。脚本终止。")
    else:
        print("-" * 50)
        print(f"🚀 开始详细检查文件: {selected_graph_file_to_check}")
        graph_data = torch.load(selected_graph_file_to_check, map_location='cpu')
        print(f"✅ 成功加载图数据。")

        if not hasattr(graph_data, 'latent_paths'):
            print("\n❌ 错误: 图数据中未找到 'latent_paths' 属性。")
            keys_available = list(graph_data.keys) if hasattr(graph_data, 'keys') and callable(graph_data.keys) else [d for d in dir(graph_data) if not d.startswith('_')]
            print(f"  图对象中可用的keys/属性 (部分): {keys_available[:15]}")
        else:
            latent_paths_list = graph_data.latent_paths
            print(f"\n➡️ 属性 'latent_paths' 存在。")

            if latent_paths_list is None:
                print("  ⚠️ 值为 None。")
            elif isinstance(latent_paths_list, (list, tuple)) and len(latent_paths_list) > 0:
                num_paths = len(latent_paths_list)
                num_nodes = getattr(graph_data, 'num_nodes', 'N/A (未找到属性)')
                print(f"  包含 {num_paths} 个路径条目。图中节点数 (num_nodes): {num_nodes}.")
                if isinstance(num_nodes, int) and num_paths != num_nodes:
                    print(f"  ⚠️ 警告: 路径数 ({num_paths}) 与节点数 ({num_nodes}) 不匹配!")

                found_count, not_found_count, none_path_count = 0, 0, 0
                
                # 为了简洁，只展示前5条和后5条的详细信息（如果路径总数较多）
                indices_to_show_detail = set(list(range(min(num_paths, 5))) + list(range(max(0, num_paths - 5), num_paths)))

                for i, path_in_graph in enumerate(latent_paths_list):
                    is_detailed_view = (i in indices_to_show_detail) or (not CONCISE_OUTPUT_PER_PATH)

                    if path_in_graph is None:
                        none_path_count += 1
                        if is_detailed_view: print(f"  路径 {i+1}/{num_paths}: None")
                        continue # None 路径直接跳过后续文件检查，计入not_found

                    path_str = str(path_in_graph)
                    exists_directly = os.path.exists(path_str)
                    exists_with_base = False
                    
                    if not exists_directly and expected_latent_base_dir and not os.path.isabs(path_str):
                        resolved_path_with_base = os.path.join(os.path.normpath(expected_latent_base_dir), os.path.normpath(path_str))
                        exists_with_base = os.path.exists(resolved_path_with_base)
                    
                    if exists_directly or exists_with_base:
                        found_count += 1
                        if is_detailed_view and not CONCISE_OUTPUT_PER_PATH:
                             status_detail = f"(绝对路径: '{os.path.abspath(path_str)}')" if exists_directly else f"(基于基准目录 '{expected_latent_base_dir}' 解析为 '{os.path.abspath(resolved_path_with_base)}')"
                             print(f"  路径 {i+1}/{num_paths}: \"{path_str}\" ✅ 找到 {status_detail}")
                    else:
                        not_found_count += 1
                        if is_detailed_view:
                             print(f"  路径 {i+1}/{num_paths}: \"{path_str}\" ❌ 未找到")
                             if not CONCISE_OUTPUT_PER_PATH :
                                 if os.path.isabs(path_str):
                                     print(f"    提示: 这是一个绝对路径但不存在。")
                                 else:
                                     print(f"    提示: 这是一个相对路径。当前工作目录 '{os.getcwd()}'。")
                                     if expected_latent_base_dir: print(f"    尝试的基准目录拼接路径 '{resolved_path_with_base if 'resolved_path_with_base' in locals() else 'N/A'}' (也不存在)。")
                                     else: print(f"    未提供基准目录，若为相对路径请检查。")

                # 实际未找到文件的是 not_found_count (不含None的初始计数) + none_path_count (None也算未找到)
                actual_not_found_files = not_found_count + none_path_count

                print("-" * 30)
                print(f"  📊 路径检查统计:")
                print(f"    总路径条目数: {num_paths}")
                print(f"    值为 None 的路径数: {none_path_count}")
                print(f"    有效并找到文件的路径数: {found_count}")
                print(f"    未找到对应文件的路径数 (包含值为None的路径): {actual_not_found_files}")
                if (found_count + actual_not_found_files) != num_paths:
                     print(f"    ⚠️ 统计总和 ({found_count + actual_not_found_files}) 与总路径数 ({num_paths}) 不符，请检查逻辑。")


            elif isinstance(latent_paths_list, (list, tuple)) and len(latent_paths_list) == 0:
                print("  ⚠️ 'latent_paths' 是一个空列表/元组。")
            else:
                print(f"  ❓ 'latent_paths' 类型为 {type(latent_paths_list)}，非预期格式。内容预览: {str(latent_paths_list)[:200]}")

except NameError as ne: # 特别捕捉 config 未定义的错误
    print(ne)
    print("请确保在此单元格顶部或之前的单元格中正确定义 TrainingConfig 类并创建其实例 'config'。")
except FileNotFoundError as fnfe:
    print(f"❌ 文件未找到错误: {fnfe}")
except Exception as e:
    print(f"❌ 处理过程中发生错误: {e}")
    import traceback
    traceback.print_exc()

print(f"\n--- 检查脚本结束 ---")

--- 开始检查图文件中的 'latent_paths' (单单元格版本) ---
操作系统: Linux 3.10.0-1160.el7.x86_64, 当前工作目录: /public/home/jijh
图数据目录 (来自config): /cwStorage/nodecw_group/jijh/hest_graph_data_pca50_knn6
Latent基准目录 (来自config): /cwStorage/nodecw_group/jijh/hest_output_latents_bf16

在 '/cwStorage/nodecw_group/jijh/hest_graph_data_pca50_knn6' 中找到以下图文件 (模式: '/cwStorage/nodecw_group/jijh/hest_graph_data_pca50_knn6/*.pt'):
  [0] MEND123_graph.pt
  [1] MEND124_graph.pt
  [2] MEND129_graph.pt
  [3] MEND130_graph.pt
  [4] MEND131_graph.pt
  [5] MEND42_graph.pt
  [6] MEND43_graph.pt
  [7] MEND44_graph.pt
  [8] MEND46_graph.pt
  [9] MEND50_graph.pt
  [10] MEND53_graph.pt
  [11] MEND55_graph.pt
  [12] MEND63_graph.pt
  [13] MEND64_graph.pt
  [14] MEND65_graph.pt
  [15] MEND66_graph.pt
  [16] MEND67_graph.pt
  [17] MEND68_graph.pt
  [18] MEND71_graph.pt
  [19] MEND72_graph.pt
  [20] MEND73_graph.pt
  [21] MEND74_graph.pt
  [22] MEND75_graph.pt
  [23] MEND76_graph.pt
  [24] MEND77_graph.pt
  [25] MEND78_graph.pt
  [26] MISC1

  graph_data = torch.load(selected_graph_file_to_check, map_location='cpu')


✅ 成功加载图数据。

➡️ 属性 'latent_paths' 存在。
  包含 1490 个路径条目。图中节点数 (num_nodes): 1490.
  路径 1/1490: "/cwStorage/nodecw_group/jijh/hest_output_latents_bf16/MEND123_tiles/MEND123_20022_24778.pt" ✅ 找到 (绝对路径: '/cwStorage/nodecw_group/jijh/hest_output_latents_bf16/MEND123_tiles/MEND123_20022_24778.pt')
  路径 2/1490: "/cwStorage/nodecw_group/jijh/hest_output_latents_bf16/MEND123_tiles/MEND123_23562_30299.pt" ✅ 找到 (绝对路径: '/cwStorage/nodecw_group/jijh/hest_output_latents_bf16/MEND123_tiles/MEND123_23562_30299.pt')
  路径 3/1490: "/cwStorage/nodecw_group/jijh/hest_output_latents_bf16/MEND123_tiles/MEND123_19517_18095.pt" ✅ 找到 (绝对路径: '/cwStorage/nodecw_group/jijh/hest_output_latents_bf16/MEND123_tiles/MEND123_19517_18095.pt')
  路径 4/1490: "/cwStorage/nodecw_group/jijh/hest_output_latents_bf16/MEND123_tiles/MEND123_19517_20420.pt" ✅ 找到 (绝对路径: '/cwStorage/nodecw_group/jijh/hest_output_latents_bf16/MEND123_tiles/MEND123_19517_20420.pt')
  路径 5/1490: "/cwStorage/nodecw_group/jijh/hest_output_latents_bf16/MEND12

# Cell 2: Configuration

In [7]:
class TrainingConfig:
    # --- Hardware & Precision ---
    GPU_IDS = [1] # 示例：根据你的GPU调整
    PRIMARY_GPU_ID = GPU_IDS[0] if GPU_IDS else 0
    NUM_GPUS = len(GPU_IDS) if torch.cuda.is_available() and GPU_IDS else 0
    PRIMARY_DEVICE_NAME = f"cuda:{PRIMARY_GPU_ID}" if NUM_GPUS > 0 else "cpu"
    MIXED_PRECISION_TYPE = "bf16" # "bf16", "fp16", or "no" for fp32 (影响 autocast 的 dtype)

    # --- Data ---
    ROOT_DIR = "/cwStorage/nodecw_group/jijh/hest_output" # 你的图像根目录
    IMAGE_EXTS = ['.png', '.jpg', '.jpeg']
    PRECOMPUTE_MAX_SAMPLES = None # None 表示使用所有找到的图像

    # --- GigaPath Specific ---
    GIGAPATH_MODEL_NAME = 'hf_hub:prov-gigapath/prov-gigapath'
    GIGAPATH_CACHE_DIR = '/home1/jijh/diffusion_project/huggingface_repo/prov-gigapath' # GigaPath模型缓存目录
    GIGAPATH_INPUT_SIZE_RESIZE = 256
    GIGAPATH_INPUT_SIZE_CROP = 224
    GIGAPATH_NORMALIZATION_MEAN = (0.485, 0.456, 0.406)
    GIGAPATH_NORMALIZATION_STD = (0.229, 0.224, 0.225)

    # --- Feature Pre-computation (GigaPath) ---
    FEATURE_DIR = "/cwStorage/nodecw_group/jijh/hest_output_gigapath_features_bf16_notebook" # 特征保存目录 (修改了名称以区分)
    PRECOMPUTE_BATCH_SIZE = 128 # 预计算时的批处理大小
    FEATURE_SAVE_PRECISION = torch.bfloat16 # 目标保存精度

    NUM_WORKERS = 16 # DataLoader 的 num_workers
    # ... 其他你可能需要的配置 ...


# CELL 3: GigaPath 模型加载函数 和 Dataset 定义

In [5]:
def get_gigapath_model(config: TrainingConfig, device):
    """加载预训练的 GigaPath 模型"""
    print(f"正在从 {config.GIGAPATH_MODEL_NAME} 加载 GigaPath 模型...")
    model = timm.create_model(
        model_name=config.GIGAPATH_MODEL_NAME,
        pretrained=True,
        cache_dir=config.GIGAPATH_CACHE_DIR
    )
    model = model.to(device)
    model.eval() # 设置为评估模式
    print("GigaPath 模型加载完毕并已移至设备。")
    return model

class GigapathImageDataset(Dataset):
    """
    从子目录加载图像，并应用 GigaPath 特定的转换。
    返回图像张量和原始图像路径。
    """
    def __init__(self, root_dir, config: TrainingConfig, max_samples=None):
        self.root_dir = root_dir
        self.image_paths = []
        if not os.path.isdir(root_dir):
             raise FileNotFoundError(f"根目录未找到: {root_dir}")

        subfolders = [os.path.join(root_dir, d) for d in os.listdir(root_dir)
                      if os.path.isdir(os.path.join(root_dir, d))]
        if not subfolders:
            print(f"警告: 在 {root_dir} 中未找到子文件夹。将直接在 root_dir 中搜索。")
            subfolders = [root_dir]

        print(f"正在以下路径搜索图像: {subfolders}")
        all_paths = []
        # 在 Notebook 中，如果图像数量巨大，这里的 tqdm 可能会比较慢，是正常的
        for folder in notebook_tqdm(subfolders, desc="扫描图像文件夹"):
            for ext in config.IMAGE_EXTS:
                all_paths.extend(glob.glob(os.path.join(folder, f'*{ext}')))

        if not all_paths:
             raise FileNotFoundError(f"在 {root_dir} 及其子目录中未找到扩展名为 {config.IMAGE_EXTS} 的图像。")

        self.image_paths = all_paths
        print(f"找到 {len(self.image_paths)} 张图像。")

        if max_samples is not None and len(self.image_paths) > max_samples:
            print(f"从找到的 {len(self.image_paths)} 张图像中采样 {max_samples} 张。")
            self.image_paths = random.sample(self.image_paths, max_samples)
        print(f"GigapathImageDataset 初始化完毕，包含 {len(self.image_paths)} 张图像。")

        self.transform = transforms.Compose([
            transforms.Resize(config.GIGAPATH_INPUT_SIZE_RESIZE, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(config.GIGAPATH_INPUT_SIZE_CROP),
            transforms.ToTensor(),
            transforms.Normalize(mean=config.GIGAPATH_NORMALIZATION_MEAN, std=config.GIGAPATH_NORMALIZATION_STD)
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        try:
            image = Image.open(img_path).convert("RGB")
            image = self.transform(image)
            return image, img_path
        except Exception as e:
            print(f"加载图像 {img_path} 时出错: {e}")
            return None, None


# CELL 4: GigaPath 特征预计算和保存函数定义

In [6]:
def precompute_and_save_gigapath_features(gigapath_model, config: TrainingConfig, device):
    """
    使用 GigaPath 模型编码图像，并将特征保存到磁盘。
    """
    gigapath_model.eval()

    save_dtype = config.FEATURE_SAVE_PRECISION
    if save_dtype == torch.bfloat16 and not (device.type == 'cuda' and torch.cuda.is_bf16_supported()):
        print(f"警告: 配置要求 FEATURE_SAVE_PRECISION=bf16, 但在 {device} 上不受支持。回退到 float16。")
        save_dtype = torch.float16
    elif save_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
         print(f"警告: 不支持的 FEATURE_SAVE_PRECISION {save_dtype}。回退到 float16。")
         save_dtype = torch.float16
    if device.type == 'cpu' and save_dtype != torch.float32:
        print(f"警告: CPU模式下，特征将以 float32 保存，而非配置的 {save_dtype}。")
        save_dtype = torch.float32

    print(f"--- 开始 GigaPath 特征预计算 (保存为 {save_dtype}，镜像结构) ---")
    
    precompute_dataset = GigapathImageDataset(
        root_dir=config.ROOT_DIR,
        config=config,
        max_samples=config.PRECOMPUTE_MAX_SAMPLES
    )

    def precompute_collate_fn(batch):
        filtered_batch = [item for item in batch if item[0] is not None and item[1] is not None]
        if not filtered_batch: return None, None
        images = torch.stack([item[0] for item in filtered_batch])
        paths = [item[1] for item in filtered_batch]
        return images, paths

    precompute_dataloader = DataLoader(
        precompute_dataset, batch_size=config.PRECOMPUTE_BATCH_SIZE, shuffle=False,
        num_workers=config.NUM_WORKERS, pin_memory=(device.type == 'cuda'), collate_fn=precompute_collate_fn
    )

    feature_dir = config.FEATURE_DIR
    os.makedirs(feature_dir, exist_ok=True)
    print(f"特征将镜像结构保存在: {feature_dir}")

    count = 0
    amp_enabled = (device.type == 'cuda' and config.MIXED_PRECISION_TYPE != "no")
    if config.MIXED_PRECISION_TYPE == "bf16" and torch.cuda.is_bf16_supported():
        amp_dtype = torch.bfloat16
    elif config.MIXED_PRECISION_TYPE == "fp16":
        amp_dtype = torch.float16
    else:
        amp_dtype = torch.float32
        if amp_enabled and amp_dtype == torch.float32:
            amp_enabled = False

    start_time = time.time()
    progress_bar = notebook_tqdm(precompute_dataloader, desc="预计算 GigaPath 特征")

    with torch.no_grad():
        for images, img_paths in progress_bar:
            if images is None or not img_paths:
                continue
            
            images = images.to(device)

            with autocast(enabled=amp_enabled, dtype=amp_dtype):
                features = gigapath_model(images)
            
            features_cpu = features.detach().to('cpu', dtype=save_dtype)

            for i, img_path in enumerate(img_paths):
                try:
                    relative_img_path = os.path.relpath(img_path, config.ROOT_DIR)
                    relative_dir, img_filename = os.path.split(relative_img_path)
                    img_basename = os.path.splitext(img_filename)[0]
                    target_feature_subdir = os.path.join(config.FEATURE_DIR, relative_dir)
                    feature_filename = img_basename + ".pt"
                    feature_save_path = os.path.join(target_feature_subdir, feature_filename)
                    os.makedirs(target_feature_subdir, exist_ok=True)
                    
                    tensor_to_save = features_cpu[i].clone()
                    torch.save(tensor_to_save, feature_save_path)
                    count += 1
                except Exception as e:
                    print(f"\n处理/保存图像 {img_path} 的特征时出错: {e}")
            progress_bar.set_postfix({"已保存": count})

    total_time = time.time() - start_time
    print(f"--- GigaPath 特征预计算完成 ---")
    print(f"已保存 {count} 个特征 (格式: {save_dtype})，镜像结构于 {feature_dir}")
    print(f"总耗时: {str(datetime.timedelta(seconds=int(total_time)))}")


# %% CELL 5: 配置实例化、设备设置 和 GigaPath 模型加载

In [7]:
config = TrainingConfig()
print("TrainingConfig 已加载。")

# --- 设备设置 ---
if config.NUM_GPUS > 0 and torch.cuda.is_available():
    precompute_device = torch.device(config.PRIMARY_DEVICE_NAME)
    try:
        torch.cuda.set_device(precompute_device) # 尝试设置当前CUDA设备
        print(f"将使用主 GPU 进行预计算: {torch.cuda.get_device_name(precompute_device)}")
        if config.FEATURE_SAVE_PRECISION == torch.bfloat16:
            if torch.cuda.is_bf16_supported():
                print("此 GPU 支持 bf16 格式。")
            else:
                print("警告: 此 GPU 不支持 bf16。FEATURE_SAVE_PRECISION 将在函数内部回退。")
    except Exception as e:
        print(f"设置 GPU {config.PRIMARY_DEVICE_NAME} 失败: {e}. 回退到 CPU。")
        precompute_device = torch.device("cpu")
        config.NUM_GPUS = 0 # 更新配置以反映实际情况
        print("将使用 CPU 进行预计算。")
        if config.FEATURE_SAVE_PRECISION != torch.float32:
            print(f"警告: CPU 模式下，特征将以 float32 保存，而非配置的 {config.FEATURE_SAVE_PRECISION}。")
            config.FEATURE_SAVE_PRECISION = torch.float32
else:
    precompute_device = torch.device("cpu")
    print("将使用 CPU 进行预计算。")
    if config.FEATURE_SAVE_PRECISION != torch.float32:
        print(f"警告: CPU 模式下，特征将以 float32 保存，而非配置的 {config.FEATURE_SAVE_PRECISION}。")
        config.FEATURE_SAVE_PRECISION = torch.float32

print(f"预计算设备设置为: {precompute_device}")

# --- 加载 GigaPath 模型 ---
# 将模型加载放到一个 try-except 块中，以便更好地处理可能的错误
try:
    gigapath_model = get_gigapath_model(config, precompute_device)
except Exception as e:
    print(f"加载 GigaPath 模型时发生错误: {e}")
    gigapath_model = None # 明确设为 None，以便后续检查

TrainingConfig 已加载。
将使用主 GPU 进行预计算: NVIDIA H100 PCIe
此 GPU 支持 bf16 格式。
预计算设备设置为: cuda:0
正在从 hf_hub:prov-gigapath/prov-gigapath 加载 GigaPath 模型...
GigaPath 模型加载完毕并已移至设备。


# Cell 6: Main Execution

In [8]:
if gigapath_model: # 仅当模型成功加载时才执行
    # 可选：用户确认，防止意外覆盖
    # user_input = input(f"预计算会将特征保存到 {config.FEATURE_DIR}，可能覆盖现有文件。是否继续? (y/n): ").lower()
    # if user_input == 'y':
    #     precompute_and_save_gigapath_features(gigapath_model, config, precompute_device)
    # else:
    #     print("用户请求跳过特征预计算。")
    precompute_and_save_gigapath_features(gigapath_model, config, precompute_device) # 直接执行，如果需要确认，取消上面注释
else:
    print("GigaPath 模型未成功加载，跳过预计算步骤。")

--- 开始 GigaPath 特征预计算 (保存为 torch.bfloat16，镜像结构) ---
正在以下路径搜索图像: ['/cwStorage/nodecw_group/jijh/hest_output/TENX105_tiles', '/cwStorage/nodecw_group/jijh/hest_output/NCBI861_tiles', '/cwStorage/nodecw_group/jijh/hest_output/NCBI146_tiles', '/cwStorage/nodecw_group/jijh/hest_output/NCBI192_expr', '/cwStorage/nodecw_group/jijh/hest_output/NCBI563_tiles', '/cwStorage/nodecw_group/jijh/hest_output/SPA45_tiles', '/cwStorage/nodecw_group/jijh/hest_output/NCBI462_expr', '/cwStorage/nodecw_group/jijh/hest_output/MISC107_tiles', '/cwStorage/nodecw_group/jijh/hest_output/MISC58_tiles', '/cwStorage/nodecw_group/jijh/hest_output/NCBI434_tiles', '/cwStorage/nodecw_group/jijh/hest_output/TENX24_expr', '/cwStorage/nodecw_group/jijh/hest_output/MEND35_expr', '/cwStorage/nodecw_group/jijh/hest_output/NCBI296_expr', '/cwStorage/nodecw_group/jijh/hest_output/NCBI492_expr', '/cwStorage/nodecw_group/jijh/hest_output/SPA3_expr', '/cwStorage/nodecw_group/jijh/hest_output/TENX136_expr', '/cwStorage/nodecw_grou

扫描图像文件夹:   0%|          | 0/2458 [00:00<?, ?it/s]

找到 2104169 张图像。
GigapathImageDataset 初始化完毕，包含 2104169 张图像。
特征将镜像结构保存在: /cwStorage/nodecw_group/jijh/hest_output_gigapath_features_bf16_notebook


预计算 GigaPath 特征:   0%|          | 0/16439 [00:00<?, ?it/s]

  with autocast(enabled=amp_enabled, dtype=amp_dtype):


In [10]:
# Read the graph data from /cwStorage/nodecw_group/jijh/hest_graph_data_pca50_knn6/MEND123_graph.pt
# 这里的路径和文件名需要根据实际情况调整
graph_file_path = "/cwStorage/nodecw_group/jijh/hest_graph_data_pca50_knn6/MEND123_graph.pt"
if os.path.exists(graph_file_path):
    print(f"正在读取图数据文件: {graph_file_path}")
    graph_data = torch.load(graph_file_path, map_location='cpu')
    if hasattr(graph_data, 'latent_paths'):
        print(f"图数据中包含 'latent_paths' 属性，长度为 {len(graph_data.latent_paths)}。")
    else:
        print("图数据中未找到 'latent_paths' 属性。")
else:
    print(f"❌ 错误: 图数据文件 '{graph_file_path}' 不存在。请检查路径和文件名。")

正在读取图数据文件: /cwStorage/nodecw_group/jijh/hest_graph_data_pca50_knn6/MEND123_graph.pt
图数据中包含 'latent_paths' 属性，长度为 1490。


  graph_data = torch.load(graph_file_path, map_location='cpu')


In [11]:
graph_data

Data(x=[1490, 50], edge_index=[2, 7920], coords=[1490, 2], spot_ids=[1490], latent_paths=[1490], sample_id='MEND123')

# Run the train scripts

In [2]:
# Cell 1: Imports for CLIP Training
import os
import sys
import random
import glob
import datetime
import time
import subprocess
import shlex
import warnings
from pathlib import Path
# os.chdir("/home1/jijh/diffusion_project/ADiffusion") # If running script directly
import numpy as np
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
# from torch_geometric.loader import DataLoader as PyGDataLoader # Defined in script
from torch_geometric.nn import GPSConv, GATConv # For GraphConditioner

# Filter warnings
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)

# %matplotlib inline # If you plan to add inline plotting for something

## Training configuration

In [3]:
# Cell 2: Configuration Class for CLIP Training
class TrainingConfigCLIP:
    # --- Hardware & Precision ---
    GPU_IDS = [0, 1] # <-- SPECIFY GPU(s) FOR CLIP TRAINING
    PRIMARY_GPU_ID = GPU_IDS[0] if GPU_IDS else 0
    NUM_GPUS = len(GPU_IDS) if torch.cuda.is_available() and GPU_IDS else 0
    PRIMARY_DEVICE_NAME = f"cuda:{PRIMARY_GPU_ID}" if NUM_GPUS > 0 else "cpu"
    MIXED_PRECISION_TYPE = "bf16" # "bf16", "fp16", or "no"
    DDP_MASTER_PORT = 29501 # Choose a different unused port for CLIP DDP

    # --- Data Paths ---
    # Input GigaPath Features (from your GigaPath precomputation)
    GIGAPATH_FEATURE_DIR = "/cwStorage/nodecw_group/jijh/hest_output_gigapath_features_bf16_notebook"
    # Input Processed Graphs (from your graph preprocessing notebook)
    GRAPH_DATA_DIR = "/cwStorage/nodecw_group/jijh/hest_graph_data_pca50_knn6"
    # Base directory of original VAE latents (used to reconstruct GigaPath feature paths if necessary)
    # This is needed if graph_data.latent_paths store absolute paths to VAE latents
    ORIGINAL_LATENT_DIR_BASE_FOR_PATH_RECONSTRUCTION = "/cwStorage/nodecw_group/jijh/hest_output_latents_bf16"


    # --- GigaPath Feature Info ---
    GIGAPATH_FEATURE_DIM = 1536 # Or 768, depending on the GigaPath model used (e.g., prov-gigapath uses 1024 for CLS token)
                                # Check the output dimension of your GigaPath features

    # --- Graph Conditioner (Text/Context Encoder) ---
    # These parameters define the GraphConditioner architecture
    PCA_N_COMPS = 50 # Should match the PCA used for graph node features
    CONDITIONER_INPUT_DIM = PCA_N_COMPS
    CONDITIONER_HIDDEN_DIM = 256
    CONDITIONER_N_LAYERS = 4
    CONDITIONER_N_HEADS = 4
    CONDITIONER_ATTN_DROPOUT = 0.1
    # CONDITIONER_OUTPUT_DIM will be CLIP_EMBED_DIM (see below)

    # --- CLIP Model Specific ---
    CLIP_EMBED_DIM = 512 # Dimension of the common embedding space (e.g., 512, 768)
    IMAGE_ENCODER_MLP_LAYERS = [GIGAPATH_FEATURE_DIM, 1024, CLIP_EMBED_DIM] # Example: GigaPath_dim -> 1024 -> CLIP_dim
    # Logit scale for contrastive loss (learnable)
    CLIP_LOGIT_SCALE_INIT = np.log(1 / 0.07) # Standard CLIP initialization

    # --- Training Script Config ---
    CLIP_BATCH_SIZE_PER_GPU = 64 # Adjust based on GPU memory (more than diffusion usually)
    CLIP_NUM_EPOCHS = 50
    CLIP_LEARNING_RATE = 1e-4 # CLIP often uses smaller LRs
    CLIP_WEIGHT_DECAY = 0.01
    ACCUMULATION_STEPS = 2
    NUM_WORKERS = 16

    # --- 【新增】调试和复现性设置 ---
    # 通过设置此值 > 0 来激活调试模式，仅在指定数量的图上进行训练。
    # 设置为 0 或负数则使用全部图进行训练。
    DEBUG_NUM_GRAPHS = 10  # <-- 在这里设置你想用于调试的图的数量，例如 10
    
    # 用于确保随机选择的图和训练中的其他随机过程是可复现的。
    SEED = 42 # <-- 在这里设置你的随机种子
    # --- 结束新增 ---

    # --- Logging & Saving (Script) ---
    CHECKPOINT_DIR = "/cwStorage/nodecw_group/jijh/model_path/clip_graph_gigapath_v1"
    LOG_DIR = "/cwStorage/nodecw_group/jijh/training_log/clip_graph_gigapath_v1"
    CHECKPOINT_FILENAME_PREFIX = "clip_graph_gigapath"
    TRAIN_SCRIPT_PATH = "/home1/jijh/diffusion_project/ADiffusion/src/pipeline/train_clip_graph_gigapath_ddp.py" # Path to NEW DDP script
    SAVE_INTERVAL_EPOCHS = 5
    LOG_INTERVAL_STEPS = 50 # Log basic stats every N optimizer steps

    @classmethod
    def get_script_args(cls):
        args = [
            f"--gigapath_feature_dir={cls.GIGAPATH_FEATURE_DIR}",
            f"--graph_data_dir={cls.GRAPH_DATA_DIR}",
            f"--original_latent_dir_base={cls.ORIGINAL_LATENT_DIR_BASE_FOR_PATH_RECONSTRUCTION}",
            f"--checkpoint_dir={cls.CHECKPOINT_DIR}",
            f"--log_dir={cls.LOG_DIR}",

            f"--epochs={cls.CLIP_NUM_EPOCHS}",
            f"--batch_size_per_gpu={cls.CLIP_BATCH_SIZE_PER_GPU}",
            f"--lr={cls.CLIP_LEARNING_RATE}",
            f"--weight_decay={cls.CLIP_WEIGHT_DECAY}",
            f"--accumulation_steps={cls.ACCUMULATION_STEPS}",
            f"--mixed_precision={cls.MIXED_PRECISION_TYPE}",
            f"--num_workers={cls.NUM_WORKERS}",

            f"--gigapath_feature_dim={cls.GIGAPATH_FEATURE_DIM}",
            f"--pca_n_comps={cls.PCA_N_COMPS}",
            f"--conditioner_input_dim={cls.CONDITIONER_INPUT_DIM}",
            f"--conditioner_hidden_dim={cls.CONDITIONER_HIDDEN_DIM}",
            f"--conditioner_output_dim={cls.CLIP_EMBED_DIM}", # Conditioner outputs to CLIP embedding dim
            f"--conditioner_n_layers={cls.CONDITIONER_N_LAYERS}",
            f"--conditioner_n_heads={cls.CONDITIONER_N_HEADS}",
            f"--conditioner_attn_dropout={cls.CONDITIONER_ATTN_DROPOUT}",

            f"--clip_embed_dim={cls.CLIP_EMBED_DIM}",
            f"--image_encoder_mlp_layers={','.join(map(str, cls.IMAGE_ENCODER_MLP_LAYERS))}",
            f"--clip_logit_scale_init={cls.CLIP_LOGIT_SCALE_INIT}",

            f"--save_interval={cls.SAVE_INTERVAL_EPOCHS}",
            f"--log_interval={cls.LOG_INTERVAL_STEPS}",
            f"--checkpoint_filename_prefix={cls.CHECKPOINT_FILENAME_PREFIX}",
        ]

        # --- 【修改】添加新的调试参数到命令行 ---
        # 总是添加seed以保证复现性
        args.append(f"--seed={cls.SEED}")
        
        # 只有在DEBUG_NUM_GRAPHS > 0时才添加这个参数，这样更干净
        if cls.DEBUG_NUM_GRAPHS > 0:
            args.append(f"--debug_num_graphs={cls.DEBUG_NUM_GRAPHS}")
        # --- 结束修改 ---

        return args

# --- Instantiate config ---
config_clip = TrainingConfigCLIP()
os.makedirs(config_clip.GRAPH_DATA_DIR, exist_ok=True) # Should already exist
os.makedirs(config_clip.GIGAPATH_FEATURE_DIR, exist_ok=True) # Should already exist
os.makedirs(config_clip.CHECKPOINT_DIR, exist_ok=True)
os.makedirs(config_clip.LOG_DIR, exist_ok=True)

# print("CLIP TrainingConfig loaded.")
# if not os.path.exists(config_clip.GIGAPATH_FEATURE_DIR) or not glob.glob(os.path.join(config_clip.GIGAPATH_FEATURE_DIR, "**/*.pt"), recursive=True):
#     print(f"WARNING: GigaPath feature directory {config_clip.GIGAPATH_FEATURE_DIR} is empty or does not exist!")
# if not os.path.exists(config_clip.GRAPH_DATA_DIR) or not glob.glob(os.path.join(config_clip.GRAPH_DATA_DIR, "*.pt")):
#     print(f"WARNING: Graph data directory {config_clip.GRAPH_DATA_DIR} is empty or does not exist!")

## Clip model definition

In [4]:
# Cell 3: CLIP Model Definitions (Conceptual - for the script)

# --- Image Encoder (from GigaPath features) ---
class GigapathFeatureEncoder(nn.Module):
    def __init__(self, mlp_layers): # e.g., [gigapath_dim, 1024, clip_embed_dim]
        super().__init__()
        layers = []
        for i in range(len(mlp_layers) - 1):
            layers.append(nn.Linear(mlp_layers[i], mlp_layers[i+1]))
            if i < len(mlp_layers) - 2: # No ReLU after the last linear layer
                layers.append(nn.ReLU())
        self.projection = nn.Sequential(*layers)

    def forward(self, x):
        return self.projection(x)

# --- Graph/Text/Context Encoder (Your GraphConditioner) ---
# Re-use the GraphConditioner class from your diffusion script.
# Ensure its output_dim is config_clip.CLIP_EMBED_DIM.
# (Definition would be in the training script, not repeated here for brevity)
# class GraphConditioner(nn.Module): ... (as defined before)


# --- Full CLIP Model ---
class GraphGigapathCLIP(nn.Module):
    def __init__(self, gigapath_encoder, graph_conditioner, logit_scale_init):
        super().__init__()
        self.gigapath_encoder = gigapath_encoder
        self.graph_conditioner = graph_conditioner
        self.logit_scale = nn.Parameter(torch.ones([]) * logit_scale_init)

    def forward(self, gigapath_features, graph_batch, graph_node_indices):
        # gigapath_features: (batch_size, gigapath_feature_dim)
        # graph_batch: PyG Batch object
        # graph_node_indices: (batch_size,) global indices of target nodes in graph_batch

        image_embeddings = self.gigapath_encoder(gigapath_features)

        all_graph_node_embeddings = self.graph_conditioner(graph_batch)
        text_embeddings = all_graph_node_embeddings[graph_node_indices]

        # Normalized features
        image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
        text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)

        return image_embeddings, text_embeddings, self.logit_scale.exp()

## Test data loading for CLIP

In [5]:
# Cell 4: (Optional) Test data loading for CLIP

try:
    print("\n--- Testing Graph and GigaPath Feature Loading ---")
    example_graph_files = glob.glob(os.path.join(config_clip.GRAPH_DATA_DIR, "*_graph.pt"))
    if not example_graph_files:
        print("No graph files found in", config_clip.GRAPH_DATA_DIR)
    else:
        example_graph_path = random.choice(example_graph_files)
        print(f"Loading example graph: {example_graph_path}")
        example_graph = torch.load(example_graph_path, map_location='cpu')

        # Find a valid node with a latent_path
        valid_node_idx = -1
        original_vae_latent_path_for_node = None
        for idx, pth in enumerate(example_graph.latent_paths):
            if pth is not None: # Assuming pth is the full path to VAE latent
                 valid_node_idx = idx
                 original_vae_latent_path_for_node = pth
                 break

        if valid_node_idx == -1 or original_vae_latent_path_for_node is None:
            print("No valid latent_paths found in the example graph.")
        else:
            print(f"Found valid node {valid_node_idx} in graph.")
            print(f"  Original VAE latent path: {original_vae_latent_path_for_node}")

            # Construct the GigaPath feature path
            # This logic needs to be robust and match how features were saved
            relative_path_to_latent = os.path.relpath(
                original_vae_latent_path_for_node,
                config_clip.ORIGINAL_LATENT_DIR_BASE_FOR_PATH_RECONSTRUCTION
            )
            gigapath_feature_path = os.path.join(
                config_clip.GIGAPATH_FEATURE_DIR,
                relative_path_to_latent # This assumes the subfolder structure is identical
            )
            print(f"  Constructed GigaPath feature path: {gigapath_feature_path}")

            if os.path.exists(gigapath_feature_path):
                print("  GigaPath feature file EXISTS.")
                gp_feature = torch.load(gigapath_feature_path, map_location='cpu')
                print(f"  Loaded GigaPath feature: shape={gp_feature.shape}, dtype={gp_feature.dtype}")
                assert gp_feature.shape[0] == config_clip.GIGAPATH_FEATURE_DIM, \
                    f"Feature dim mismatch! Expected {config_clip.GIGAPATH_FEATURE_DIM}, got {gp_feature.shape[0]}"
            else:
                print(f"  GigaPath feature file NOT FOUND at: {gigapath_feature_path}")
                print(f"  Please check: ")
                print(f"    1. GigaPath features were precomputed to {config_clip.GIGAPATH_FEATURE_DIR}")
                print(f"    2. The relative path structure within GIGAPATH_FEATURE_DIR mirrors that of ORIGINAL_LATENT_DIR_BASE")
                print(f"    3. ORIGINAL_LATENT_DIR_BASE_FOR_PATH_RECONSTRUCTION ('{config_clip.ORIGINAL_LATENT_DIR_BASE_FOR_PATH_RECONSTRUCTION}') is correct.")


            # Test model instantiation (conceptual)
            print("\nInstantiating conceptual CLIP model parts...")
            dummy_gp_encoder = GigapathFeatureEncoder(config_clip.IMAGE_ENCODER_MLP_LAYERS)
            dummy_graph_conditioner = nn.Linear(config_clip.CONDITIONER_INPUT_DIM, config_clip.CLIP_EMBED_DIM) # Simplified
            print("Conceptual models instantiated.")

except Exception as e:
    print(f"Error during testing: {e}")
    import traceback
    traceback.print_exc()


--- Testing Graph and GigaPath Feature Loading ---
Loading example graph: /cwStorage/nodecw_group/jijh/hest_graph_data_pca50_knn6/MISC3_graph.pt
Found valid node 0 in graph.
  Original VAE latent path: /cwStorage/nodecw_group/jijh/hest_output_latents_bf16/MISC3_tiles/MISC3_8521_9716.pt
  Constructed GigaPath feature path: /cwStorage/nodecw_group/jijh/hest_output_gigapath_features_bf16_notebook/MISC3_tiles/MISC3_8521_9716.pt
  GigaPath feature file EXISTS.
  Loaded GigaPath feature: shape=torch.Size([1536]), dtype=torch.bfloat16

Instantiating conceptual CLIP model parts...
Conceptual models instantiated.


## Run the script

In [7]:
# Cell 5: Launch DDP Training Script for CLIP
print("\n--- Preparing to Launch DDP Training Script for CLIP ---")
config_clip = TrainingConfigCLIP() # Reload config

if not os.path.exists(config_clip.TRAIN_SCRIPT_PATH):
    print(f"Error: CLIP Training script '{config_clip.TRAIN_SCRIPT_PATH}' not found.")
    print("Please create this script (train_clip_graph_gigapath_ddp.py).")
else:
    if config_clip.NUM_GPUS == 0 or not config_clip.GPU_IDS:
        print("Error: No GPUs specified in config_clip.GPU_IDS or CUDA not available.")
    else:
        python_executable = sys.executable
        print(f"Using Python executable: {python_executable}")
        print(f"Attempting to use {config_clip.NUM_GPUS} GPUs with IDs: {config_clip.GPU_IDS}")

        modified_env = os.environ.copy()
        cuda_visible_devices = ",".join(map(str, config_clip.GPU_IDS))
        modified_env["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
        print(f"Setting CUDA_VISIBLE_DEVICES={cuda_visible_devices} for the subprocess.")

        cmd = [
            python_executable,
            "-m", "torch.distributed.run",
            f"--nproc_per_node={config_clip.NUM_GPUS}",
            f"--master_port={config_clip.DDP_MASTER_PORT}",
            config_clip.TRAIN_SCRIPT_PATH,
        ]
        cmd.extend(config_clip.get_script_args())

        print("\nLaunching command for CLIP training:")
        print(shlex.join(cmd))
        print("-" * 30 + "\nScript Output:\n" + "-" * 30)

        process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                                   text=True, bufsize=1, encoding='utf-8', errors='replace',
                                   env=modified_env)
        while True:
            output = process.stdout.readline()
            if output == '' and process.poll() is not None:
                break
            if output:
                try:
                    print(output.strip())
                except Exception as e_print:
                    print(f"[Encoding error in output: {e_print}]")
        rc = process.poll()
        print("-" * 30 + "\n--- CLIP Script Finished ---")
        if rc == 0:
            print("CLIP Training script finished successfully.")
        else:
            print(f"CLIP Training script exited with error code {rc}.")


--- Preparing to Launch DDP Training Script for CLIP ---
Using Python executable: /public/home/jijh/micromamba/envs/gpu_env/bin/python
Attempting to use 2 GPUs with IDs: [0, 1]
Setting CUDA_VISIBLE_DEVICES=0,1 for the subprocess.

Launching command for CLIP training:
/public/home/jijh/micromamba/envs/gpu_env/bin/python -m torch.distributed.run --nproc_per_node=2 --master_port=29501 /home1/jijh/diffusion_project/ADiffusion/src/pipeline/train_clip_graph_gigapath_ddp.py --gigapath_feature_dir=/cwStorage/nodecw_group/jijh/hest_output_gigapath_features_bf16_notebook --graph_data_dir=/cwStorage/nodecw_group/jijh/hest_graph_data_pca50_knn6 --original_latent_dir_base=/cwStorage/nodecw_group/jijh/hest_output_latents_bf16 --checkpoint_dir=/cwStorage/nodecw_group/jijh/model_path/clip_graph_gigapath_v1 --log_dir=/cwStorage/nodecw_group/jijh/training_log/clip_graph_gigapath_v1 --epochs=50 --batch_size_per_gpu=64 --lr=0.0001 --weight_decay=0.01 --accumulation_steps=2 --mixed_precision=bf16 --num_wo