In [None]:
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc
import matplotlib.pyplot as plt

# Load the input dataset from an .h5ad file via scanpy; later cells read
# `adata.obs` (per-cell annotations) and `adata.X` (expression matrix) from it.
# NOTE(review): this is an absolute, machine-specific path — the notebook only
# runs where this file exists; consider a configurable data directory.
adata = sc.read('/home/zwye/simulator/merfish/hypothalamic_preoptic.h5ad')

在进行三维空间转录组数据模拟时，输入数据需要包含以下三个组成部分：(1) x、y、z 空间坐标；(2) 细胞类型（cell type）；(3) 区域标签（region）。

In [None]:
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.neural_network import MLPClassifier

def region_simulation(coords, regions, nx, ny, nz):
    """
    Extrapolate per-point region labels onto a regular 3D grid.

    An MLP classifier is trained to map standardized (x, y, z) coordinates to
    region labels, and is then evaluated on an ``nx * ny * nz`` grid spanning
    the axis-aligned bounding box of ``coords``.

    Parameters
    ----------
    coords : array-like of shape (n_points, 3)
        x, y, z coordinates of the observed points.
    regions : array-like of shape (n_points,)
        Region label of each observed point (classifier target).
    nx, ny, nz : int
        Number of grid points along the x, y and z axes.

    Returns
    -------
    region_volume : ndarray of shape (nx, ny, nz)
        Predicted region label for every grid point;
        ``region_volume[i, j, k]`` corresponds to ``(xs[i], ys[j], zs[k])``.

    Raises
    ------
    ValueError
        If ``coords`` is not a 2D array with exactly three columns.
    """
    # Work on a plain float ndarray so the scaler is fitted and later applied
    # to the same kind of input (the grid below is also a plain ndarray).
    coords = np.asarray(coords, dtype=float)
    if coords.ndim != 2 or coords.shape[1] != 3:
        raise ValueError(
            f"coords must have shape (n_points, 3), got {coords.shape}"
        )

    # ---- Fit the coordinate -> region classifier ----
    scaler = StandardScaler()
    coords_scaled = scaler.fit_transform(coords)

    clf = MLPClassifier(
        hidden_layer_sizes=(128, 64, 32),
        activation='relu',
        solver='adam',
        max_iter=400,
        random_state=42,
    )
    clf.fit(coords_scaled, regions)

    # ---- Build a regular grid over the bounding box of the data ----
    xmin, ymin, zmin = coords.min(axis=0)
    xmax, ymax, zmax = coords.max(axis=0)

    xs = np.linspace(xmin, xmax, nx)
    ys = np.linspace(ymin, ymax, ny)
    zs = np.linspace(zmin, zmax, nz)

    # indexing="ij" keeps the axis order (x, y, z), so the flattened
    # predictions can be reshaped straight back to (nx, ny, nz).
    Xg, Yg, Zg = np.meshgrid(xs, ys, zs, indexing="ij")
    grid_points = np.c_[Xg.ravel(), Yg.ravel(), Zg.ravel()]

    # ---- Predict a region for every grid point ----
    # The grid must go through the scaler fitted on the training coordinates.
    grid_scaled = scaler.transform(grid_points)
    pred = clf.predict(grid_scaled)

    return pred.reshape(nx, ny, nz)

In [None]:
import numpy as np

def cell_type_simulation(
    region_volume,
    adata,
    region_key="STAIR_int",
    celltype_key="cell_type",
    all_cell_types=None,
    seed=42,
    dtype=np.int16
):
    """
    Sample a cell type for every voxel from the empirical
    region -> cell type distribution observed in ``adata.obs``.

    Parameters
    ----------
    region_volume : ndarray
        Region label of each voxel.
    adata : AnnData
        Reference data; only ``adata.obs`` is read, to count the cell types
        found in each region.
    region_key : str
        Column of ``adata.obs`` holding the region label.
    celltype_key : str
        Column of ``adata.obs`` holding the cell type.
    all_cell_types : list[str], optional
        Cell types defining the index order of the probability vectors.
        Defaults to the sorted unique values of ``adata.obs[celltype_key]``.
    seed : int
        Seed of the random generator used for sampling.
    dtype : numpy dtype
        Storage type of the returned volume (default ``int16``).

    Returns
    -------
    celltype_volume : ndarray
        Same shape as ``region_volume``; holds the sampled cell type index,
        or -1 where the voxel's region has no distribution in ``adata``.
    celltype_to_idx : dict
        cell type -> integer index.
    idx_to_celltype : dict
        integer index -> cell type.
    region_to_celltype_prob : dict
        region -> probability vector aligned with ``all_cell_types``.
    """
    obs = adata.obs

    # ---- Cell type vocabulary and its two lookup tables ----
    if all_cell_types is None:
        all_cell_types = sorted(obs[celltype_key].unique())

    celltype_to_idx = {}
    for position, name in enumerate(all_cell_types):
        celltype_to_idx[name] = position
    idx_to_celltype = {
        position: name for name, position in celltype_to_idx.items()
    }

    # ---- Empirical cell type distribution of each region ----
    region_to_celltype_prob = {}
    for label in obs[region_key].unique():
        in_region = obs[region_key] == label
        tally = obs.loc[in_region, celltype_key].value_counts()
        tally = tally.reindex(all_cell_types, fill_value=0)

        total = tally.sum()
        if total != 0:
            # Regions without any cell of the listed types get no entry.
            region_to_celltype_prob[label] = tally.values / total

    # ---- Output volume; -1 marks voxels that are never sampled ----
    celltype_volume = np.full(region_volume.shape, -1, dtype=dtype)

    rng = np.random.default_rng(seed)
    n_types = len(all_cell_types)

    # ---- Draw all voxels of a region in a single call ----
    for label, probs in region_to_celltype_prob.items():
        voxel_mask = region_volume == label
        voxel_count = voxel_mask.sum()

        if voxel_count != 0:
            celltype_volume[voxel_mask] = rng.choice(
                n_types, size=voxel_count, p=probs
            )

    return (
        celltype_volume,
        celltype_to_idx,
        idx_to_celltype,
        region_to_celltype_prob,
    )


In [None]:
def gene_expression_simulation(
    region_volume,
    celltype_volume,
    adata,
    idx_to_celltype,
    region_key="STAIR_int",
    celltype_key="cell_type",
    seed=42,
    return_flat=False
):
    """
    Simulate per-voxel gene expression by copying the expression of a real
    cell drawn at random among the cells sharing the voxel's
    (region, cell type) combination.

    Parameters
    ----------
    region_volume : ndarray (nx, ny, nz)
        Region label of each voxel.
    celltype_volume : ndarray (nx, ny, nz)
        Cell type index (int) of each voxel.
    adata : AnnData
        Reference data; ``adata.X`` supplies the expression rows and
        ``adata.obs`` the region / cell type annotation of each cell.
    idx_to_celltype : dict
        cell type index -> cell type name.
    region_key : str
        Column of ``adata.obs`` holding the region label.
    celltype_key : str
        Column of ``adata.obs`` holding the cell type.
    seed : int
        Seed of the random generator used for sampling.
    return_flat : bool
        If True, return an (n_voxels, n_genes) matrix instead of a 4D volume.

    Returns
    -------
    expr_volume : ndarray
        (nx, ny, nz, n_genes), or (n_voxels, n_genes) when ``return_flat``.
        Voxels whose (region, cell type) combination has no cell in ``adata``
        are left as all-zero rows.

    Raises
    ------
    ValueError
        If ``region_volume`` is not 3D or the two volumes differ in shape.
    """
    region_volume = np.asarray(region_volume)
    celltype_volume = np.asarray(celltype_volume)
    if region_volume.ndim != 3:
        raise ValueError(
            f"region_volume must be 3D (nx, ny, nz), got shape {region_volume.shape}"
        )
    if celltype_volume.shape != region_volume.shape:
        raise ValueError(
            "region_volume and celltype_volume must have the same shape, got "
            f"{region_volume.shape} and {celltype_volume.shape}"
        )

    rng = np.random.default_rng(seed)

    # Resolve the expression matrix and its storage kind once, not per voxel.
    X = adata.X
    has_toarray = hasattr(X, "toarray")

    nx, ny, nz = region_volume.shape
    n_voxels = nx * ny * nz
    n_genes = X.shape[1]

    # ---- Flatten the voxel grids ----
    regions_flat = region_volume.ravel()
    celltypes_flat = celltype_volume.ravel()

    # ---- Output matrix (rows of unmatched voxels stay zero) ----
    expr_flat = np.zeros((n_voxels, n_genes), dtype=X.dtype)

    # ---- Build the (region, cell type index) -> cell row indices lookup ----
    # Each per-region and per-cell-type mask is computed once and reused,
    # instead of re-scanning adata.obs for every (cell type, region) pair.
    obs_regions = adata.obs[region_key]
    obs_celltypes = adata.obs[celltype_key]
    region_masks = [
        (region, obs_regions == region) for region in np.unique(regions_flat)
    ]

    group_to_cells = {}
    for ct_idx, ct_name in idx_to_celltype.items():
        celltype_mask = obs_celltypes == ct_name
        for region, region_mask in region_masks:
            cell_indices = np.where(region_mask & celltype_mask)[0]
            if len(cell_indices) > 0:
                group_to_cells[(region, ct_idx)] = cell_indices

    # ---- Sample one real cell per voxel ----
    # One rng.choice call per matched voxel, in voxel order, so results for a
    # given seed are reproducible.
    for i, key in enumerate(zip(regions_flat, celltypes_flat)):
        candidates = group_to_cells.get(key)
        if candidates is None:
            # No real cell has this (region, cell type) combination.
            continue

        row = X[rng.choice(candidates)]
        expr_flat[i, :] = row.toarray().ravel() if has_toarray else row

    if return_flat:
        return expr_flat
    return expr_flat.reshape(nx, ny, nz, n_genes)


In [None]:
def visualization