# ST-MIL 验证 Notebook

本 Notebook 用于验证按 sCellST 思路重写后的流程（细胞→spot 映射 + 负二项（NB）损失），
其中使用 LazySlide 生成细胞嵌入，使用 CellFM 进行基因表达预测。

说明：
- 请先修改下面的路径。
- LazySlide 细胞分割是可选项；你也可以直接提供细胞 CSV。
- 每一步都可以单独测试，建议从上往下依次运行。


In [None]:
from pathlib import Path
import sys

# Base directory that every ROOT-anchored path below hangs off. It is the
# parent of the current working directory, so the notebook is presumably
# started from a sub-folder of the repository -- TODO confirm the layout.
ROOT = Path('..').resolve()

# Make the in-repo sources importable. The membership check keeps sys.path
# free of duplicate entries when this cell is executed more than once.
_src_dir = str(ROOT / 'src')
if _src_dir not in sys.path:
    sys.path.append(_src_dir)

# ====== Edit the paths below before running ======
wsi_path = Path('/path/to/slide.tif')
h5ad_path = Path('/path/to/data.h5ad')
cell_csv = Path('/path/to/cells.csv')
# NOTE(review): these two paths are relative, so they resolve against the
# notebook's working directory rather than ROOT -- confirm that is intended.
cell_patch_h5 = Path('data/cell_images/sample_cell_patches.h5')
cell_emb_h5 = Path('data/cell_embeddings/sample_cell_emb.h5')

# Spot radius in pixels (a radius, not a diameter). Set it here when the h5ad
# carries no spot-size information; presumably 0 means "not provided" --
# verify against load_h5ad.
spot_radius_px = 0

# Model name handed to the LazySlide embedding step.
lazyslide_model = 'resnet50'

# CellFM settings.
gene_vocab_path = ROOT / 'assets' / 'cellfm' / 'gene_info.csv'
cellfm_checkpoint = '/path/to/CellFM_80M_weight.pt'
use_mock_cellfm = True  # can be switched to False on the server


## 1) 可选：LazySlide 细胞分割导出 CSV
将 RUN_LAZYSLIDE_SEG 设为 True 才会执行。


In [None]:
from st_pipeline.data.lazyslide_cells_to_csv import segment_cells_to_csv

# Opt-in switch: segmentation only runs when this is set to True.
RUN_LAZYSLIDE_SEG = False
if RUN_LAZYSLIDE_SEG:
    # Imported here so the cell stays cheap when segmentation is skipped.
    import torch

    segment_cells_to_csv(
        wsi_path=wsi_path,
        output_csv=cell_csv,
        model='instanseg',
        tile_px=512,
        stride_px=512,
        mpp=0.5,
        # Fall back to CPU when no GPU is visible, the same way the
        # forward-check cell picks its device.
        device='cuda' if torch.cuda.is_available() else 'cpu',
        batch_size=4,
        num_workers=0,
    )


## 2) 导出细胞 patch（如果还没有）


In [None]:
from st_pipeline.data.cell_patch_export import load_cell_coords_csv, export_cell_patches

# Export only when the patch file is missing, as the section heading states,
# so running the notebook top to bottom does not redo the export every time.
# Delete the file to force a fresh export.
if cell_patch_h5.exists():
    print('已存在细胞 patch 文件，跳过导出:', cell_patch_h5)
else:
    # Create the output directory up front; whether the exporter does this
    # itself is not visible from here.
    cell_patch_h5.parent.mkdir(parents=True, exist_ok=True)
    coords = load_cell_coords_csv(cell_csv)
    export_cell_patches(
        wsi_path=wsi_path,
        cell_coords=coords,
        output_h5=cell_patch_h5,
        patch_size=72,
    )


## 3) 使用 LazySlide 提取细胞嵌入


In [None]:
from st_pipeline.data.cell_embed_lazyslide import EmbedConfig, embed_cells

# Local import: torch is only needed in this cell to probe for a GPU.
import torch

# Embedding configuration. The device falls back to CPU when no GPU is
# visible, the same way the forward-check cell picks its device.
cfg = EmbedConfig(
    model_name=lazyslide_model,
    model_path=None,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    batch_size=64,
    num_workers=4,
)
# Create the output directory up front; whether embed_cells does this itself
# is not visible from here.
cell_emb_h5.parent.mkdir(parents=True, exist_ok=True)
embed_cells(cell_patch_h5=cell_patch_h5, output_h5=cell_emb_h5, config=cfg)


## 4) 构建 MIL 数据集并检查形状


In [None]:
from st_pipeline.data.h5ad_loader import load_h5ad
from st_pipeline.data.mil_dataset import MilSpotDataset
from st_pipeline.constants import KEYS

# Load the expression data. 'HVG:1000' is passed through as the gene
# selection; presumably it means the top 1000 highly variable genes --
# verify against load_h5ad.
h5ad_options = dict(
    h5ad_path=h5ad_path,
    genes='HVG:1000',
    spot_radius_px=spot_radius_px,
    gene_vocab_path=gene_vocab_path,
)
data = load_h5ad(**h5ad_options)

# Build the MIL dataset from the loaded expression data and the cell
# embeddings. The radius comes from the loader's result rather than from the
# raw setting above.
dataset = MilSpotDataset(
    adata=data.adata,
    embedding_h5=cell_emb_h5,
    spot_radius_px=data.spot_radius_px,
    gene_ids=data.gene_ids,
)
n_spots = len(dataset)
print('spot 数量:', n_spots)

# Inspect the first item to sanity-check the tensor shapes.
sample = dataset[0]
x_shape = sample[KEYS.X].shape
print('X 形状:', x_shape)
y_bag_shape = sample[KEYS.Y_BAG].shape
print('Y_bag 形状:', y_bag_shape)

## 5) 模型前向验证（不训练）


In [None]:
import torch
from torch.utils.data import DataLoader
from st_pipeline.data.collate import mil_collate
from st_pipeline.data.gene_vocab import load_gene_vocab
from st_pipeline.model.morpho_cellfm_mil import MorphoCellfmMIL

# Pull a single small batch; this cell only checks that a forward pass runs.
loader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=mil_collate)
batch = next(iter(loader))

vocab_size = len(load_gene_vocab(gene_vocab_path))

model = MorphoCellfmMIL(
    # Take the feature size from the trailing axis. It equals shape[1] when
    # the collated X is 2-D and stays correct if the collate function adds a
    # leading batch axis.
    input_dim=batch[KEYS.X].shape[-1],
    n_genes=len(data.genes),
    cellfm_dim=1536,
    cellfm_layers=2,
    cellfm_heads=48,
    cellfm_checkpoint=cellfm_checkpoint,
    freeze_cellfm=True,
    use_mock=use_mock_cellfm,
    use_retention=True,
    vocab_size=vocab_size,
    dropout=0.1,
    aggregation='mean',
    dispersion='gene',
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# Forward check only, no training: switch off dropout so the pass is
# deterministic.
model.eval()
# Move tensors only; any non-tensor entry in the batch is left untouched
# instead of failing on a missing .to().
for k, v in batch.items():
    if torch.is_tensor(v):
        batch[k] = v.to(device)
# No gradients are needed for a shape check, so skip autograd bookkeeping.
with torch.no_grad():
    mu_bag, mu_inst = model(batch)
print('mu_bag 形状:', mu_bag.shape)
print('mu_inst 形状:', mu_inst.shape)


## 6) （可选）完整训练 CLI
当你准备好数据后，取消下面单元格中的注释再执行。


In [None]:
# Optional: launch the full training run through the CLI script. The whole
# snippet is commented out on purpose; uncomment it once the data is ready.
# NOTE(review): the script and config paths are relative, so they resolve
# against the notebook's working directory, while ROOT is defined as its
# parent ('..') -- confirm the working directory before running.
# import subprocess
# subprocess.run([
#     'python', 'src/st_pipeline/train/train_cli.py',
#     '--config', 'configs/st_mil.yaml',
# ], check=True)
