# XeniumPR1_broad: Data Merge, Cross-Dataset Splits, and Embedding Reuse

This notebook builds a **combined benchmark dataset** named `XeniumPR1_broad` by merging two prepared sources:

- **XeniumPR1**
- **broad**

It consolidates their `adata/` and `patches/`, creates **two complementary train/test splits**, prepares **gene lists**, and reuses previously extracted **vision embeddings** (Part B below) so they don’t have to be recomputed.

---

## Part A — Merge data & create cross-dataset splits

**Inputs**
- `/project/simmons_hts/kxu/hest/eval/data/XeniumPR1/` (note: the example run below reads from `.../data/XeniumPR1_segger` instead)
- `/project/simmons_hts/kxu/hest/eval/data/broad/`

**Outputs**
- `/project/simmons_hts/kxu/hest/eval/data/XeniumPR1_broad/`

**What happens**
1. **Merge `adata/` and `patches/`**  
   Copies (or symlinks) samples from both datasets. If a sample ID collides,
   the **broad** sample is renamed to `broad_<id>` to avoid overwriting.

2. **Create two splits (cross-dataset test)**  
   - `split 0`: `train_0.csv` = XeniumPR1 IDs, `test_0.csv` = broad IDs  
   - `split 1`: `train_1.csv` = broad IDs,    `test_1.csv` = XeniumPR1 IDs  

3. **Genes**
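   - `all_genes.json` = set intersection of `XeniumPR1/all_genes.json` and `broad/all_genes.json` (the merged dataset keeps the standard `all_genes.json` name).
   - `var_50genes.json` = top-50 variable genes (defaults `k=50`, `criteria="var"`) via `hest.utils.get_k_genes`, computed **on the intersection genes only**; `get_k_genes` additionally applies its `min_cells_pct=0.10` filter.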

---

## Output Folder

`/project/simmons_hts/kxu/hest/eval/data/XeniumPR1_broad/`
    
---

## Assumptions & prerequisites

- Both source datasets already follow the standard layout:
  - `adata/<sample>.h5ad`, `patches/<sample>.h5`, optional `patches/vis/<sample>.png`, and `all_genes.json`.
- `scanpy` and `pandas` are available; `hest.utils.get_k_genes` is importable.
- The script logs any missing files but proceeds with those available.

---

## Quickstart

```python
create_XeniumPR1_broad(
    xenium_dir="/project/simmons_hts/kxu/hest/eval/data/XeniumPR1",
    broad_dir="/project/simmons_hts/kxu/hest/eval/data/broad",
    out_dir="/project/simmons_hts/kxu/hest/eval/data/XeniumPR1_broad",
    k=50,
    criteria="var",      # or "mean"
    symlink=False        # set True to symlink instead of copy
)
```

In [9]:
from __future__ import annotations
import os, json, shutil
from pathlib import Path
from typing import Dict, List, Tuple
import scanpy as sc

def _discover_dataset(ds_dir: Path) -> Dict[str, Dict[str, Path]]:
    """
    Discover samples in a dataset folder with structure:
      ds_dir/
        adata/<sid>.h5ad
        patches/<sid>.h5
        patches/vis/<sid>.png (optional)
        all_genes.json
    Returns: {sid: {"adata": Path, "patch": Path|None, "vis": Path|None}}
    """
    adata_dir = ds_dir / "adata"
    patch_dir = ds_dir / "patches"
    vis_dir   = patch_dir / "vis"
    out: Dict[str, Dict[str, Path]] = {}
    for p in sorted(adata_dir.glob("*.h5ad")):
        sid = p.stem
        patch = patch_dir / f"{sid}.h5"
        vis = vis_dir / f"{sid}.png"
        out[sid] = {
            "adata": p,
            "patch": patch if patch.exists() else None,
            "vis":   vis if vis.exists()   else None
        }
    return out

def _copy_or_link(src: Path|None, dst: Path, symlink: bool, missing: List[Tuple[str,str,str]]):
    if src is None or not src.exists():
        missing.append((dst.stem, dst.parent.name, str(src) if src else "<none>"))
        return
    dst.parent.mkdir(parents=True, exist_ok=True)
    if dst.exists():
        dst.unlink()
    if symlink:
        try:
            os.symlink(src, dst)
        except FileExistsError:
            pass
    else:
        shutil.copy(src, dst)

def _read_all_genes(path: Path) -> List[str]:
    with open(path, "r") as f:
        obj = json.load(f)
    return obj["genes"]

def create_XeniumPR1_broad(
    xenium_dir: str|Path,
    broad_dir: str|Path,
    out_dir: str|Path,
    k: int = 50,
    criteria: str = "var",
    symlink: bool = False,
):
    """
    Build XeniumPR1_broad by merging two prepared datasets.

    Steps:
      - Merge adata/ and patches/ (and patches/vis if present) into ``out_dir``.
      - Create splits/train_0.csv (XeniumPR1 IDs), test_0.csv (broad IDs)
        and splits/train_1.csv (broad IDs), test_1.csv (XeniumPR1 IDs).
      - Save ``all_genes.json`` = intersection of the two sources' all_genes.json.
      - Compute ``{criteria}_{k}genes.json`` (``var_50genes.json`` with the
        defaults) using HEST ``get_k_genes`` on the *intersection genes* only.

    Args:
      xenium_dir: XeniumPR1 dataset root (adata/, patches/, all_genes.json).
      broad_dir:  broad dataset root with the same layout.
      out_dir:    destination root; sub-folders are created as needed.
      k:          number of genes to select.
      criteria:   selection criterion passed to ``get_k_genes`` ("var" or "mean");
                  also used as the prefix of the output gene file.
      symlink:    symlink source files instead of copying them.

    Raises:
      ValueError: if the two sources have no genes in common.

    Notes:
      - If a broad sample ID collides with a XeniumPR1 sample ID, the broad sample
        is written as 'broad_<id>' and the splits use the remapped name.
        Collisions are decided from the two sources only (not from files already
        present in ``out_dir``), so re-running into the same ``out_dir`` yields
        the same names.
    """
    import pandas as pd
    from hest.utils import get_k_genes  # uses min_cells_pct=0.10 by default

    xenium_dir = Path(xenium_dir)
    broad_dir  = Path(broad_dir)
    out_dir    = Path(out_dir)
    adata_out   = out_dir / "adata"
    patches_out = out_dir / "patches"
    splits_dir  = out_dir / "splits"

    adata_out.mkdir(parents=True, exist_ok=True)
    (patches_out / "vis").mkdir(parents=True, exist_ok=True)
    splits_dir.mkdir(parents=True, exist_ok=True)

    # --- Discover samples (dicts are ordered by sample ID)
    xen = _discover_dataset(xenium_dir)
    brd = _discover_dataset(broad_dir)
    final_xen_ids = list(xen)

    # --- Resolve broad IDs. Only clashes with XeniumPR1 IDs (or with an already
    #     assigned broad ID) count; treating leftovers in out_dir/adata as clashes
    #     would rename every broad sample on a re-run.
    taken = set(final_xen_ids)
    id_map_broad: Dict[str, str] = {}
    for sid in brd:
        new_id = sid
        while new_id in taken:  # loop also covers a pre-existing 'broad_<id>'
            new_id = f"broad_{new_id}"
        taken.add(new_id)
        id_map_broad[sid] = new_id
    final_brd_ids = list(id_map_broad.values())

    renamed = {old: new for old, new in id_map_broad.items() if old != new}
    if renamed:
        print(f"[INFO] Renamed {len(renamed)} colliding broad sample(s): {renamed}")

    # --- Copy/symlink files to merged folder
    missing: List[Tuple[str, str, str]] = []
    sources = (
        (xen, {sid: sid for sid in xen}),  # XeniumPR1 files keep their IDs
        (brd, id_map_broad),               # broad files (with potential rename)
    )
    for files, id_map in sources:
        for sid, new_id in id_map.items():
            entry = files[sid]
            _copy_or_link(entry["adata"], adata_out / f"{new_id}.h5ad", symlink, missing)
            _copy_or_link(entry["patch"], patches_out / f"{new_id}.h5", symlink, missing)
            _copy_or_link(entry["vis"],   patches_out / "vis" / f"{new_id}.png", symlink, missing)

    if missing:
        print("[WARN] Some files were missing at source (copied what was available):")
        for sid, kind, path in missing:
            print(f"  - {sid:>20} [{kind}]  {path}")

    # --- Splits (two-way)
    # split 0: train = XeniumPR1 IDs, test = broad IDs
    # split 1: train = broad IDs,    test = XeniumPR1 IDs
    split_columns = ["sample_id", "patches_path", "expr_path"]

    def _split_frame(ids: List[str]):
        # Relative paths from bench_data_root (benchmark.py joins with bench_data_root).
        # Explicit columns keep the CSV header even when `ids` is empty.
        rows = [
            {
                "sample_id": sid,
                "patches_path": f"patches/{sid}.h5",
                "expr_path":    f"adata/{sid}.h5ad",
            }
            for sid in ids
        ]
        return pd.DataFrame(rows, columns=split_columns)

    splits = {
        "train_0.csv": final_xen_ids,
        "test_0.csv":  final_brd_ids,
        "train_1.csv": final_brd_ids,
        "test_1.csv":  final_xen_ids,
    }
    for fname, ids in splits.items():
        _split_frame(ids).to_csv(splits_dir / fname, index=False)

    print(f"[INFO] Wrote splits (with columns sample_id, patches_path, expr_path) → {splits_dir}")

    # --- Genes
    # 1) all_genes.json = intersection(all_genes.json from both datasets)
    all_genes_x = _read_all_genes(xenium_dir / "all_genes.json")
    all_genes_b = _read_all_genes(broad_dir  / "all_genes.json")
    common = sorted(set(all_genes_x).intersection(all_genes_b))
    if not common:
        raise ValueError(
            f"No genes shared between {xenium_dir / 'all_genes.json'} and "
            f"{broad_dir / 'all_genes.json'}; cannot select top-{k} genes."
        )
    with open(out_dir / "all_genes.json", "w") as f:
        json.dump({"genes": common}, f)
    print(f"[INFO] Wrote common genes (|G|={len(common)}) → {out_dir/'all_genes.json'}")

    # 2) {criteria}_{k}genes.json via HEST get_k_genes, restricted to the
    #    *intersection*: each AnnData is subset to `common` before selection.
    adatas: List[sc.AnnData] = []
    for sid in final_xen_ids + final_brd_ids:
        ad = sc.read_h5ad(adata_out / f"{sid}.h5ad")
        keep = [g for g in common if g in ad.var_names]
        adatas.append(ad[:, keep].copy())

    # File is named after the criterion actually used (var_50genes.json by default).
    sel_path = out_dir / f"{criteria}_{k}genes.json"
    # get_k_genes will also apply its min_cells_pct filter (default 0.10)
    _ = get_k_genes(adatas, k=k, criteria=criteria, save_dir=str(sel_path))
    print(f"[INFO] Wrote top-{k} {criteria} genes → {sel_path}")

    print(f"✅ Merged dataset created at: {out_dir}")

In [12]:
# --- Example call ---
# The XeniumPR1 source used here is the `XeniumPR1_segger` folder
# (not the plain `XeniumPR1` folder mentioned in the introduction).
_example_config = dict(
    xenium_dir='/project/simmons_hts/kxu/hest/eval/data/XeniumPR1_segger',
    broad_dir='/project/simmons_hts/kxu/hest/eval/data/broad',
    out_dir='/project/simmons_hts/kxu/hest/eval/data/XeniumPR1_broad',
    k=50,
    criteria='var',
    symlink=False,
)
create_XeniumPR1_broad(**_example_config)

[INFO] Wrote splits (with columns sample_id, patches_path, expr_path) → /project/simmons_hts/kxu/hest/eval/data/XeniumPR1_broad/splits
[INFO] Wrote common genes (|G|=312) → /project/simmons_hts/kxu/hest/eval/data/XeniumPR1_broad/all_genes.json


  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is

min_cells is  110.0
min_cells is  88.0
min_cells is  162.0
min_cells is  187.0
min_cells is  208.0
min_cells is  453.0
min_cells is  472.0
min_cells is  469.0
min_cells is  480.0
min_cells is  424.0
min_cells is  282.0
min_cells is  465.0
min_cells is  182.0
min_cells is  226.0
min_cells is  296.0
min_cells is  757.0
min_cells is  826.0
min_cells is  695.0
min_cells is  936.0
min_cells is  678.0
min_cells is  438.0
min_cells is  1002.0
[32m14:09:46[0m | [1mINFO[0m | [1mFound 159 common genes[0m


  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  utils.warn_names_duplicates("obs")
  if not is_categorical_dtype(df_full[k]):
  utils.warn_names_duplicates("obs")


[32m14:09:47[0m | [1mINFO[0m | [1mselected genes ['BANK1', 'CCL24', 'CCR1', 'CD19', 'CD22', 'CD247', 'CD38', 'CD3D', 'CD3G', 'CD5', 'CD6', 'CD69', 'CD79A', 'CD83', 'CDKN2B', 'CTLA4', 'CXCL1', 'CXCR4', 'DERL3', 'EPCAM', 'ETS1', 'FCER1G', 'G0S2', 'HNF4A', 'IKZF1', 'IKZF3', 'IL12RB1', 'IL18R1', 'IL2RG', 'IL7R', 'IRF8', 'KLRB1', 'MS4A1', 'MYLK', 'OLFM4', 'PCNA', 'PLAUR', 'PLN', 'PPP1R1B', 'PRDM1', 'PTPN22', 'RRM2', 'SELENBP1', 'SLC12A2', 'SPP1', 'SYP', 'TCF7', 'TYMS', 'UCHL1', 'XBP1'][0m
[INFO] Wrote top-50 var genes → /project/simmons_hts/kxu/hest/eval/data/XeniumPR1_broad/var_50genes.json
✅ Merged dataset created at: /project/simmons_hts/kxu/hest/eval/data/XeniumPR1_broad


## Part B — Copy previously extracted XeniumPR1 & broad embeddings

In [13]:
from __future__ import annotations
import shutil
from pathlib import Path
from typing import List, Set

def _list_models(root: Path) -> Set[str]:
    if not root.exists():
        return set()
    return {p.name for p in root.iterdir() if p.is_dir()}

def _copy2(src: Path, dst: Path):
    dst.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(src, dst)

def merge_embeddings_to_xeniumpr1_broad(
    emb_xenium: str | Path = "/project/simmons_hts/kxu/hest/eval/ST_data_emb/XeniumPR1_segger",
    emb_broad: str | Path  = "/project/simmons_hts/kxu/hest/eval/ST_data_emb/broad",
    emb_out: str | Path    = "/project/simmons_hts/kxu/hest/eval/ST_data_emb/XeniumPR1_broad",
    merged_adata_dir: str | Path | None = "/project/simmons_hts/kxu/hest/eval/data/XeniumPR1_broad/adata",
    rename_prefix_for_broad: str = "broad_",   # only applied on filename collisions
) -> None:
    """
    Create /XeniumPR1_broad/ embeddings by unioning model folders from XeniumPR1 and broad.
    - Structure: <emb_out>/<model_name>/*.h5
    - On filename collision, broad files are renamed to 'broad_<sample>.h5'.
    - Optional consistency check: ensure each adata sample has an embedding per model.
    """
    emb_xenium = Path(emb_xenium)
    emb_broad  = Path(emb_broad)
    emb_out    = Path(emb_out)

    # 1) Create destination root
    emb_out.mkdir(parents=True, exist_ok=True)

    # 2) Union of model directories
    models = _list_models(emb_xenium) | _list_models(emb_broad)
    if not models:
        print("[WARN] No model folders found in sources.")
        return

    print(f"[INFO] Models to merge: {sorted(models)}")

    # 3) Copy embeddings model-by-model
    for model in sorted(models):
        src_x = emb_xenium / model
        src_b = emb_broad  / model
        dst_m = emb_out    / model
        dst_m.mkdir(parents=True, exist_ok=True)

        copied_x = copied_b = 0

        # Copy XeniumPR1 embeddings (verbatim)
        if src_x.exists():
            for f in sorted(src_x.glob("*.h5")):
                _copy2(f, dst_m / f.name)
                copied_x += 1

        # Copy broad embeddings (rename on collision)
        if src_b.exists():
            for f in sorted(src_b.glob("*.h5")):
                target = dst_m / f.name
                if target.exists():
                    target = dst_m / f"{rename_prefix_for_broad}{f.stem}.h5"
                _copy2(f, target)
                copied_b += 1

        print(f"[INFO] {model}: copied {copied_x} (XeniumPR1) + {copied_b} (broad) → {dst_m}")

    # 4) Optional: verify coverage vs merged adata IDs
    if merged_adata_dir:
        adir = Path(merged_adata_dir)
        if adir.exists():
            adata_ids = sorted(p.stem for p in adir.glob("*.h5ad"))
            print(f"[CHECK] adata samples in XeniumPR1_broad: {len(adata_ids)}")
            for model in sorted(models):
                dst_m = emb_out / model
                if not dst_m.exists():
                    print(f"[CHECK][{model}] no embeddings folder, skipping check.")
                    continue
                emb_ids = {p.stem for p in dst_m.glob("*.h5")}
                missing = [sid for sid in adata_ids if sid not in emb_ids]
                if missing:
                    print(f"[WARN][{model}] {len(missing)} samples missing embeddings. First few: {missing[:10]}")
                else:
                    print(f"[OK][{model}] embeddings present for all adata samples.")
        else:
            print(f"[INFO] Skipping coverage check; adata dir not found: {adir}")

    print(f"✅ Finished merging embeddings into: {emb_out}")

# ---- Run it (default source/destination paths spelled out explicitly) ----
merge_embeddings_to_xeniumpr1_broad(
    emb_xenium="/project/simmons_hts/kxu/hest/eval/ST_data_emb/XeniumPR1_segger",
    emb_broad="/project/simmons_hts/kxu/hest/eval/ST_data_emb/broad",
    emb_out="/project/simmons_hts/kxu/hest/eval/ST_data_emb/XeniumPR1_broad",
    merged_adata_dir="/project/simmons_hts/kxu/hest/eval/data/XeniumPR1_broad/adata",
    rename_prefix_for_broad="broad_",
)


[INFO] Models to merge: ['conch_v1', 'ctranspath', 'gigapath', 'h0_mini', 'hibou_large', 'hoptimus0', 'kaiko_base_8', 'phikon', 'phikon_v2', 'resnet50', 'uni_v1', 'virchow', 'virchow2']
[INFO] conch_v1: copied 15 (XeniumPR1) + 7 (broad) → /project/simmons_hts/kxu/hest/eval/ST_data_emb/XeniumPR1_broad/conch_v1
[INFO] ctranspath: copied 15 (XeniumPR1) + 7 (broad) → /project/simmons_hts/kxu/hest/eval/ST_data_emb/XeniumPR1_broad/ctranspath
[INFO] gigapath: copied 15 (XeniumPR1) + 7 (broad) → /project/simmons_hts/kxu/hest/eval/ST_data_emb/XeniumPR1_broad/gigapath
[INFO] h0_mini: copied 15 (XeniumPR1) + 7 (broad) → /project/simmons_hts/kxu/hest/eval/ST_data_emb/XeniumPR1_broad/h0_mini
[INFO] hibou_large: copied 15 (XeniumPR1) + 7 (broad) → /project/simmons_hts/kxu/hest/eval/ST_data_emb/XeniumPR1_broad/hibou_large
[INFO] hoptimus0: copied 15 (XeniumPR1) + 7 (broad) → /project/simmons_hts/kxu/hest/eval/ST_data_emb/XeniumPR1_broad/hoptimus0
[INFO] kaiko_base_8: copied 15 (XeniumPR1) + 7 (broad)