# Tutorial 3: Generate context-aware cell embeddings

In [None]:
import torch.multiprocessing as mp
import pickle
import os
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from collections import defaultdict, OrderedDict
from tqdm import tqdm
from argparse import Namespace
import gc

from omics.constants import *
import os
import torch.utils.data as data
from copy import deepcopy
import random
from sklearn.preprocessing import LabelEncoder
from omics.constants import *
import h5py
from pytorch_lightning.callbacks import (EarlyStopping, LearningRateMonitor,
                                         ModelCheckpoint, Callback)
from math import exp
import random
import h5py
import hashlib
import anndata as ad
from omics.Pretrain.omics.pretrain_fold2 import Omics
# from omics.Finetune.finetune_fold import Omics
from pytorch_lightning import LightningModule, Trainer, seed_everything
from argparse import ArgumentParser
from collections import defaultdict
from scipy.spatial import cKDTree
from step0_spot_input_gen import load_spot_data_with_full_info
from step2_uncond_cell_emb_gen import run_multi_gpu_processing_with_filtered_data

## Initialize args and fix seed

In [None]:
# Expose GPUs 0 and 1 to this process. Keep this list in step with the
# `num_gpus` value in the settings below.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

# The inference settings are hard-coded here, mirroring the command-line
# script, so the notebook never has to parse CLI arguments itself.
_inference_settings = {
    "ckpt_path": "/media/dang/Omics/data/ckpts/Omics/merfish_1_merfish_2025_09_10_11_58_10/epoch=2-step=5055.ckpt",
    "dataset_name": "merfish",
    "target_dataset": "merfish",  # adjust to your own use case
    "gene_pct": 100,
    "use_multi_gpu": True,
    "num_gpus": 2,
    "radius": 20,
    "max_points": 20,
    "linear_hidden_dim": '256',  # stored as a string; set to match your model configuration
    "config": "/media/dang/Omics/omics/configs/bert_config_5-12.json",
    "seed": 42,
    "f": None,
    "split_slice": None,
    "target_width": 6724,
    "target_height": 5885,
    "x_min": None,
    "x_max": None,
    "y_min": None,
    "y_max": None,
    "sample_ratio": 1.0,
    "merge_threshold": 0.5,
    "percentile": 70,
}
args = Namespace(**_inference_settings)

# Echo the key settings so the run can be identified from the cell output.
print('ckpt_path:', args.ckpt_path)
print(f"Dataset: {args.dataset_name}, gene_pct: {args.gene_pct}")
print(f"Use multi-GPU: {args.use_multi_gpu}, num_gpus: {args.num_gpus}")

# Fix the random seed through pytorch_lightning's seed_everything helper.
seed_everything(args.seed)

## Generate unconditional cell embeddings

In [None]:
# Path of the cached spot-level input (HDF5). The file name is assembled from
# the dataset name, a literal 42, gene_pct, radius and max_points.
# NOTE(review): the 42 presumably mirrors the seed used when the cache was
# written -- confirm against the step-0 script before swapping in args.seed.
global_h5_path = f'/media/dang/Omics/omics/baseline/cached_spot_data/spot_input_spot_all_{args.dataset_name}_42_{args.gene_pct}_{args.radius}_{args.max_points}.h5'

# Only the multi-GPU path is implemented for filtered data.
if args.use_multi_gpu:
    print("Using Multi-GPU processing with H5 split")
    cell_embeddings = run_multi_gpu_processing_with_filtered_data(args, global_h5_path, num_gpus=args.num_gpus)
else:
    print("Single-GPU processing not implemented for filtered data yet")
    # Bind the result explicitly: otherwise the name is unbound on a fresh
    # kernel, or silently keeps a stale value from an earlier multi-GPU run.
    # A single-GPU version could be derived by simplifying the multi-GPU one.
    cell_embeddings = None

## Generate conditional cell embeddings

### Match spot embeddings to cells

In [None]:
from step3_cell_matching import match_cells_by_coordinates

In [None]:
# Settings for the cell-matching step.
# NOTE(review): this rebinds `args`, replacing the inference configuration
# defined earlier in the notebook; re-running earlier cells afterwards will
# see this smaller namespace.
# NOTE(review): data_name is None, so the f-strings that interpolate it render
# the literal text "None" in the file names -- confirm this matches the files
# on disk.
args = Namespace(dataset_id="9494", output_dir="AD_cell_embeddings", data_name=None)

# Input locations: the AD AnnData holding all samples, and the cell-embedding
# AnnData written for this dataset id.
adata_path = '/media/dang/Omics/data/spot_level/AD/2022-09-16-Hu-AD-stardist-scaled.h5ad'
adata_emb_path = f'/media/dang/Omics/omics/baseline/imputation/cell_embeddings_filtered/spot_input_spot_all_AD_2766g_m{args.dataset_id}_42_100_20_20{args.data_name}_cells.h5ad'

print("=== 读取Cell Embeddings AnnData ===")

# Load the full AnnData and keep only the rows whose `sample` column matches
# the requested mouse.
adata_cells = ad.read_h5ad(adata_path)
sample_label = f'ADmouse_{args.dataset_id}'
adata_9494 = adata_cells[adata_cells.obs['sample'] == sample_label]

# Load the cell embeddings.
adata_emb = ad.read_h5ad(adata_emb_path)

# Match the two tables by coordinates. Per the original author's note, the
# first AnnData carries orig_index, x, y, etc., and the embedding AnnData
# carries center_x, center_y, etc.
matched_adata = match_cells_by_coordinates(adata_9494, adata_emb, tolerance=20)

# Cast the id columns to str before saving so they are written with a
# consistent string dtype.
for id_column in ('cell_orig_id', 'cell_total_id'):
    matched_adata.obs[id_column] = matched_adata.obs[id_column].astype(str)

# Make sure the output directory exists, then persist the matching result.
# NOTE(review): if args.data_name is None the file name contains the literal
# text "None" -- confirm that is intended.
os.makedirs(args.output_dir, exist_ok=True)
matched_adata.write_h5ad(f'{args.output_dir}/AD_cell_matching_m{args.dataset_id}{args.data_name}.h5ad')

# Summarise the outcome: rows whose cell_orig_id is not 'unmatched' count as
# successfully matched.
is_matched = matched_adata.obs['cell_orig_id'] != 'unmatched'
matched_count = is_matched.sum()
total_cells = len(matched_adata.obs)
preview_columns = ['center_x', 'center_y', 'cell_orig_id', 'cell_total_id']

print(f"成功匹配的细胞数量: {matched_count}")
print(f"总细胞数量: {total_cells}")
print(f"匹配率: {matched_count / total_cells:.2%}")
print("\n匹配结果示例:")
print(matched_adata.obs[preview_columns].head())

# Show a few of the successfully matched cells, if there are any.
matched_cells = matched_adata.obs[is_matched]
if len(matched_cells):
    print("\n成功匹配的细胞示例:")
    print(matched_cells[preview_columns].head())

### Train conditional aggregation and generate cell embeddings

In [None]:
# Import the step-4 entry point from the local step4_cond_cell_emb_gen module.
from step4_cond_cell_emb_gen import run_train_and_cell_emb_gen

# Run it with no arguments.
# NOTE(review): going by its name and the section heading, this trains the
# conditional aggregation and then generates cell embeddings; its inputs and
# configuration are not visible in this notebook -- see the step-4 module.
run_train_and_cell_emb_gen()