### Prepare XeniumPR1_segger data for training in /project/simmons_hts/kxu/hest/eval/data/XeniumPR1_segger/

In [19]:
# --- Standard library ---
import os
import json
import re
import shutil
from pathlib import Path
from typing import List, Dict, Tuple, Optional

# --- Third-party ---
import numpy as np
import pandas as pd
import scanpy as sc

# --- HEST ---
from hest import iter_hest
from hest.utils import get_k_genes
from hest.HESTData import create_splits


# ---------- helpers ----------
def _sanitize_tag(s: str, maxlen: int = 8) -> str:
    """Return an upper-case, ASCII-alphanumeric tag derived from ``s``.

    Every character outside ``A-Z``, ``a-z`` and ``0-9`` is dropped, the
    remainder is upper-cased and cut to at most ``maxlen`` characters.
    When nothing is left, the placeholder ``"R"`` is returned.
    """
    alnum_only = re.sub(r'[^A-Za-z0-9]', '', s)
    tag = alnum_only.upper()[:maxlen]
    if tag:
        return tag
    return "R"

def _extract_pr_number(path: Path) -> Optional[int]:
    """
    Look for a 'VisiumR<number>' pattern in the path (case-insensitive).

    The whole path string is searched, so the tag may appear in any path
    component; the first occurrence wins.

    Returns the run number as an int, or None when the pattern is absent.
    """
    # `\d+` (not `\d`): a single-digit capture would read 'VisiumR10' as run 1
    # and make its samples collide with those of 'VisiumR1'.
    m = re.search(r'VisiumR(\d+)', str(path), flags=re.IGNORECASE)
    return int(m.group(1)) if m else None

def _extract_slide_number(root: Path) -> Optional[str]:
    """
    Read a slide number out of the last component of ``root``.

    Two patterns are tried in order against the folder name:
      1. 'slide' optionally followed by '_' or '-' and then digits
         (matched on the lower-cased name);
      2. a standalone 'S<digits>' token, case-insensitive.

    Returns the digits as a string, or None when neither pattern matches.
    """
    folder_name = root.name
    attempts = (
        (r'slide[_\-]?(\d+)', folder_name.lower(), 0),
        (r'\bS(\d+)\b', folder_name, re.IGNORECASE),
    )
    for pattern, text, flags in attempts:
        hit = re.search(pattern, text, flags=flags)
        if hit is not None:
            return hit.group(1)
    return None

def _discover_samples_from_roots(
    roots: List[Path],
    ids: Optional[List[str]] = None,
) -> Dict[str, Dict[str, Path]]:
    """
    Collect sample folders from several root folders into one renamed map.

    Roots that do not exist or are not directories are ignored. With
    ``ids=None`` every immediate subfolder of every root is a candidate;
    otherwise only the ``<root>/<id>`` folders that exist are. Candidates
    without an ``aligned_adata.h5ad`` file are dropped.

    Each kept sample is stored under ``<prefix><folder name>``. The prefix is
    ``VisiumR<n>S<slide>`` when a run number can be read from the root path,
    and ``<sanitized root name>S<slide>`` otherwise.

    Returns:
        {new_id: {"adata": Path, "patch": Path or None, "vis": Path or None}}

    Raises:
        ValueError: if two candidates end up with the same renamed id.
    """

    def _choose(folder: Path, pattern: str, preferred: str) -> Optional[Path]:
        # The file named `preferred` wins; otherwise take the first match in
        # sorted order. None when the folder is absent or has no match.
        if not folder.exists():
            return None
        matches = sorted(folder.glob(pattern))
        for match in matches:
            if match.name == preferred:
                return match
        return matches[0] if matches else None

    usable_roots = sorted(
        (root for root in map(Path, roots) if root.exists() and root.is_dir()),
        key=str,
    )

    if ids is None:
        candidates = [
            (root, child.name)
            for root in usable_roots
            for child in sorted(
                (entry for entry in root.iterdir() if entry.is_dir()),
                key=lambda entry: entry.name,
            )
        ]
    else:
        candidates = [
            (root, wanted)
            for wanted in sorted(ids)
            for root in usable_roots
            if (root / wanted).is_dir()
        ]

    discovered: Dict[str, Dict[str, Path]] = {}
    for root, folder_name in candidates:
        sample_dir = root / folder_name
        adata_file = sample_dir / "aligned_adata.h5ad"
        if not adata_file.exists():
            continue

        # Optional assets: patch container and its visualisation image.
        patch_file = _choose(sample_dir / "patches", "*.h5", f"{folder_name}.h5")
        vis_file = _choose(
            sample_dir / "patches_vis", "*.png", f"{folder_name}_patch_vis.png"
        )

        # --- Naming rule ---
        run_number = _extract_pr_number(root)
        slide_tag = _extract_slide_number(root) or _sanitize_tag(root.name, 3)
        if run_number is None:
            # Root path carries no run tag: fall back to the sanitized root name.
            head = f"{_sanitize_tag(root.name)}S{slide_tag}"
        else:
            head = f"VisiumR{run_number}S{slide_tag}"

        renamed = f"{head}{folder_name}"
        if renamed in discovered:
            raise ValueError(
                f"Duplicate renamed sample id '{renamed}' (collision between roots for sid='{folder_name}')."
            )

        discovered[renamed] = {"adata": adata_file, "patch": patch_file, "vis": vis_file}

    return discovered



def _transfer(
    src: Optional[Path],
    dst: Path,
    label: str,
    symlink: bool,
    missing_list: list,
    overwrite: bool = True,
):
    """
    Copy or symlink ``src`` to ``dst``.

    Args:
        src: source file; None or a non-existent path is recorded as missing.
        dst: destination path; its parent directory is created if needed.
        label: short asset name used in the missing report and log lines.
        symlink: create a symlink when True, copy the file when False.
        missing_list: list that receives ``(dst.stem, label, src)`` tuples
            for sources that could not be found.
        overwrite: replace an existing destination (default, the historical
            behaviour of this helper). With False an existing destination is
            left untouched. The keyword is accepted because
            ``create_benchmark_data_multislide`` passes ``overwrite=False``.
    """
    if src is None or not Path(src).exists():
        missing_list.append((dst.stem, label, str(src) if src is not None else "<none>"))
        return
    dst.parent.mkdir(parents=True, exist_ok=True)

    # `exists()` follows symlinks and is False for a dangling link, so the
    # link itself has to be tested as well or it would never be cleaned up.
    if dst.exists() or dst.is_symlink():
        if dst.exists() and not overwrite:
            print(f"[SKIP] {label} exists: {dst} (use overwrite=True to replace)")
            return
        dst.unlink()

    if symlink:
        try:
            os.symlink(src, dst)
        except FileExistsError:
            # Only reachable if something recreated dst after the unlink above.
            print(f"[WARN] could not create symlink {dst}: destination already exists")
    else:
        shutil.copy(src, dst)


def write_var_k_genes_from_paths(
    adata_paths,
    k,
    criteria,
    var_out_path,
    all_genes_out_path=None,
    exclude_keywords=None,
    filtered_common_out_path=None,
    min_cells_pct: float = 0.10,
):
    """
    Load all adatas, call HEST's get_k_genes() for top-k genes,
    and also save:
      - all common genes (keyword-filtered, no expression threshold)
      - filtered common genes using min_cells_pct across each sample

    Args:
        adata_paths: iterable of paths to ``.h5ad`` files (at least one).
        k: number of genes requested from get_k_genes.
        criteria: selection criteria forwarded to get_k_genes (e.g. "var").
        var_out_path: forwarded to get_k_genes as ``save_dir``; its parent
            folder is also where the two default JSON files are written.
        all_genes_out_path: output JSON for the keyword-filtered common genes.
            Defaults to ``<parent of var_out_path>/all_genes.json``.
        exclude_keywords: substrings that disqualify a gene from the
            "all common" list. Defaults to the control/blank probe keywords.
        filtered_common_out_path: output JSON for the expression-filtered
            common genes. Defaults to
            ``<parent of var_out_path>/common_genes_<min_cells_pct>.json``.
        min_cells_pct: fraction of observations in which a gene must be
            detected, per sample, to enter the filtered list. A falsy value
            disables the threshold.

    Returns:
        (var_k_genes, all_common_genes, filtered_common_genes)

    Raises:
        ValueError: if ``adata_paths`` is empty.

    Notes:
        - 'all_common_genes' uses only keyword filtering.
        - 'filtered_common_genes' applies, for each AnnData, a minimum of
          ceil(min_cells_pct * n_obs) observations via scanpy's filter_genes,
          intersects the surviving genes across samples, and finally drops
          names containing BLANK/Control. This is meant to mirror the
          filtering inside get_k_genes (not verified here).
    """
    import json
    import warnings
    import numpy as np
    import scanpy as sc
    from hest.utils import get_k_genes

    adata_paths = list(adata_paths)
    if not adata_paths:
        # Fail early with a readable message instead of an IndexError below.
        raise ValueError("adata_paths is empty; at least one .h5ad file is required.")

    if exclude_keywords is None:
        exclude_keywords = ["NegControl", "Codeword", "Intergenic_Region", "Control", "BLANK"]

    # NOTE: this installs a process-wide filter (kept from the original code).
    warnings.filterwarnings("ignore", category=FutureWarning, module="anndata")

    # ---- Load all adatas
    adata_list = []
    for p in adata_paths:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=FutureWarning)
            ad = sc.read_h5ad(str(p))
        adata_list.append(ad)

    # ---- Top-k genes; var_out_path is handed to get_k_genes as save_dir
    var_k_genes = get_k_genes(adata_list, k, criteria, save_dir=str(var_out_path), min_cells_pct=min_cells_pct)

    # ---- ALL common genes (keyword-filtered only)
    common_genes = set(adata_list[0].var_names)
    for ad in adata_list[1:]:
        common_genes &= set(ad.var_names)

    all_common_genes = sorted(
        g for g in common_genes if not any(kw in g for kw in exclude_keywords)
    )

    # ---- Filtered common genes (expression threshold per sample, then intersect)
    filtered_sets = []
    for ad in adata_list:
        min_cells = int(np.ceil(min_cells_pct * ad.n_obs)) if min_cells_pct else 0
        if min_cells > 0:
            # inplace=False returns (keep_mask, counts) and leaves `ad`
            # untouched, so no full copy of the matrix is needed.
            keep_mask, _ = sc.pp.filter_genes(ad, min_cells=min_cells, inplace=False)
            filtered_sets.append(set(ad.var_names[keep_mask]))
        else:
            filtered_sets.append(set(ad.var_names))

    filtered_common = set.intersection(*filtered_sets)
    # remove BLANK/Control probes from the filtered list
    filtered_common_genes = sorted(
        g for g in filtered_common if "BLANK" not in g and "Control" not in g
    )

    # ---- Write JSONs
    out_dir = Path(var_out_path).parent
    if all_genes_out_path is None:
        all_genes_out_path = out_dir / "all_genes.json"
    with open(all_genes_out_path, "w") as f:
        json.dump({"genes": all_common_genes}, f)

    if filtered_common_out_path is None:
        # Name follows the threshold actually used (0.1 for the default), so a
        # non-default min_cells_pct is no longer saved under a "0.1" label.
        filtered_common_out_path = out_dir / f"common_genes_{min_cells_pct}.json"
    with open(filtered_common_out_path, "w") as f:
        json.dump({"genes": filtered_common_genes, "min_cells_pct": min_cells_pct}, f)

    print(
        f"[INFO] Wrote {var_out_path} (top-{k}, criteria={criteria}); "
        f"{all_genes_out_path} (all_common={len(all_common_genes)}); "
        f"{filtered_common_out_path} (filtered_common={len(filtered_common_genes)}, min_cells_pct={min_cells_pct})"
    )

    return var_k_genes, all_common_genes, filtered_common_genes


# ---------- main entry ----------

def create_benchmark_data_multislide(
    save_dir: str | Path,
    K: int,
    base_root: str | Path = "sftp://login1.molbiol.ox.ac.uk/ceph/project/simmons_hts/kxu/hest/xenium_data/XeniumPR1_segger",
    slide_subdirs: List[str] | tuple = ("slide1", "slide2"),
    ids: Optional[List[str]] = None,
    gene_k: int = 50,
    gene_criteria: str = "var",
    symlink: bool = False,
    seed: int = 0,
):
    """
    Assemble a HEST benchmark package from the slide folders under one run
    folder, discovering samples on disk instead of reading a metadata table.

    Input layout:
        <base_root>/<slide_subdir>/<sample_id>/aligned_adata.h5ad (+ optional
        patches/*.h5 and patches_vis/*.png)

    Output layout:
        <save_dir>/
          var_<gene_k>genes.json   (plus the JSONs written by
                                    write_var_k_genes_from_paths)
          splits/...
          patches/<id>.h5
          patches/vis/<id>.png
          adata/<id>.h5ad

    Args:
        save_dir: destination folder of the package (created if absent)
        K: fold count handed to HEST's create_splits
        base_root: folder that contains the slide subfolders
        slide_subdirs: names of the slide subfolders to scan
        ids: sample folder names to keep; None scans every subfolder
        gene_k: number of genes to select
        gene_criteria: criteria forwarded to get_k_genes (e.g. "var")
        symlink: link assets instead of copying them
        seed: seed of the RNG that shuffles ids inside each group before
            the splits are created

    Raises:
        ValueError: if no sample with an aligned_adata.h5ad file is found.
    """
    from hest.HESTData import create_splits

    out_root = Path(save_dir)
    out_root.mkdir(parents=True, exist_ok=True)

    # --- Step 1: slide roots and sample discovery ---
    run_root = Path(base_root)
    roots = [run_root / subdir for subdir in slide_subdirs]
    print(f"[INFO] Using slide roots: {roots}")

    samples = _discover_samples_from_roots(roots, ids=ids)
    if not samples:
        raise ValueError(
            f"No valid samples (with aligned_adata.h5ad) found under any of: {roots}."
        )
    sample_ids = sorted(samples)
    print(f"[INFO] Discovered {len(sample_ids)} samples: {sample_ids}")

    # --- Step 2: minimal metadata table used only for grouping ---
    # patient = text before the first "_" of the id (the whole id if there is
    # none); dataset_title = name of the run folder, "xenium" if that is empty.
    title = run_root.name or "xenium"
    meta = pd.DataFrame(
        {
            "id": sample_ids,
            "patient": [sample_id.partition("_")[0] for sample_id in sample_ids],
            "dataset_title": [title for _ in sample_ids],
        }
    )

    # --- Step 3: gene selection -> var_<gene_k>genes.json ---
    var_json = out_root / f"var_{gene_k}genes.json"
    write_var_k_genes_from_paths(
        [samples[sample_id]["adata"] for sample_id in sample_ids],
        gene_k,
        gene_criteria,
        var_json,
    )
    print(f"[INFO] Wrote {var_json}")

    # --- Step 4: splits, grouped by (dataset_title, patient) ---
    group = meta.groupby(["dataset_title", "patient"])["id"].agg(list).to_dict()

    # Seeded in-place shuffle of every group's id list.
    rng = np.random.RandomState(seed)
    for members in group.values():
        rng.shuffle(members)

    splits_dir = out_root / "splits"
    splits_dir.mkdir(parents=True, exist_ok=True)
    create_splits(str(splits_dir), group, K=K)
    print(f"[INFO] Wrote {K}-fold splits to {splits_dir}")

    # --- Step 5: copy or link the per-sample assets ---
    patches_dir = out_root / "patches"
    vis_dir = patches_dir / "vis"
    adata_dir = out_root / "adata"
    for folder in (patches_dir, vis_dir, adata_dir):
        folder.mkdir(exist_ok=True, parents=True)

    not_found: List[tuple] = []
    for sample_id in sample_ids:
        assets = samples[sample_id]
        targets = (
            ("patch", patches_dir / f"{sample_id}.h5"),
            ("vis", vis_dir / f"{sample_id}.png"),
            ("adata", adata_dir / f"{sample_id}.h5ad"),
        )
        for asset_key, target in targets:
            _transfer(assets.get(asset_key), target, asset_key, symlink, not_found, overwrite=False)

    if not_found:
        print("[WARN] Missing files:")
        for missing_id, missing_label, missing_src in not_found:
            print(f"  - {missing_id} [{missing_label}] → {missing_src}")

    print(f"✅ Benchmark dataset created at {out_root}")



In [6]:
# Build the benchmark package for the VisiumR1 run only
# (slide folders default to slide1 and slide2).
create_benchmark_data_multislide(
    base_root="/project/simmons_hts/kxu/hest/visium_data/VisiumR1",
    save_dir="/project/simmons_hts/kxu/hest/eval/data/VisiumR1",
    K=15,
    gene_criteria="var",
    gene_k=50,
    seed=0,         # seed of the within-group shuffle applied before splitting
    symlink=False,  # copy assets; True links them instead and saves disk space
)

[INFO] Using slide roots: [PosixPath('/project/simmons_hts/kxu/hest/visium_data/VisiumR1/slide1'), PosixPath('/project/simmons_hts/kxu/hest/visium_data/VisiumR1/slide2')]
[INFO] Discovered 4 samples: ['VisiumR1S1ROI1', 'VisiumR1S1ROI2', 'VisiumR1S1ROI3', 'VisiumR1S1ROI4']
min_cells is  173.0
min_cells is  163.0
min_cells is  175.0
min_cells is  134.0
[32m12:20:13[0m | [1mINFO[0m | [1mFound 643 common genes[0m
[32m12:20:14[0m | [1mINFO[0m | [1mselected genes ['APLP2', 'ARF4', 'ATP1B1', 'ATP6V0D1', 'C1QC', 'CAST', 'CCNL1', 'CLU', 'COX6A1', 'CREB3L1', 'CTSZ', 'DEFA5', 'ECE1', 'EFHD2', 'ETFB', 'FAM102A', 'FN1', 'FTH1', 'GADD45B', 'GNB1', 'HIST1H4A', 'HLA-DRA', 'HSPB6', 'IER3', 'IER5', 'IGFBP5', 'ISG20', 'JPT1', 'LGALS3', 'LYZ', 'MAP1B', 'MBOAT7', 'OGN', 'PEX26', 'PLA2G2A', 'PRR13', 'RAB5A', 'REG1A', 'RHOG', 'SAT1', 'SERPINB6', 'SMTN', 'STAT6', 'TCF7L2', 'TIMP3', 'TNFRSF1A', 'TNFRSF21', 'TNIP1', 'VAMP8', 'VASP'][0m


  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


[INFO] Wrote /project/simmons_hts/kxu/hest/eval/data/VisiumR1/var_50genes.json (top-50, criteria=var); /project/simmons_hts/kxu/hest/eval/data/VisiumR1/all_genes.json (all_common=17943); /project/simmons_hts/kxu/hest/eval/data/VisiumR1/common_genes_0.1.json (filtered_common=643, min_cells_pct=0.1)
[INFO] Wrote /project/simmons_hts/kxu/hest/eval/data/VisiumR1/var_50genes.json
K=15 doesnt match the number of patients, try to distribute the patients instead
Split 0/4
train set is  ['VisiumR1S1ROI2', 'VisiumR1S1ROI3', 'VisiumR1S1ROI4']

test set is  ['VisiumR1S1ROI1']

Split 1/4
train set is  ['VisiumR1S1ROI1', 'VisiumR1S1ROI3', 'VisiumR1S1ROI4']

test set is  ['VisiumR1S1ROI2']

Split 2/4
train set is  ['VisiumR1S1ROI1', 'VisiumR1S1ROI2', 'VisiumR1S1ROI4']

test set is  ['VisiumR1S1ROI3']

Split 3/4
train set is  ['VisiumR1S1ROI1', 'VisiumR1S1ROI2', 'VisiumR1S1ROI3']

test set is  ['VisiumR1S1ROI4']

[INFO] Wrote 15-fold splits to /project/simmons_hts/kxu/hest/eval/data/VisiumR1/splits
✅ 

In [7]:
# Build one benchmark package per Visium run, VisiumR1 through VisiumR6.
for i in range(1, 7):
    tag = f"VisiumR{i}"
    create_benchmark_data_multislide(
        base_root=f"/project/simmons_hts/kxu/hest/visium_data/{tag}",
        save_dir=f"/project/simmons_hts/kxu/hest/eval/data/{tag}",
        K=15,
        gene_criteria="var",
        gene_k=50,
        seed=0,         # seed of the within-group shuffle applied before splitting
        symlink=False,  # copy assets; True links them instead and saves disk space
    )

[INFO] Using slide roots: [PosixPath('/project/simmons_hts/kxu/hest/visium_data/VisiumR1/slide1'), PosixPath('/project/simmons_hts/kxu/hest/visium_data/VisiumR1/slide2')]
[INFO] Discovered 4 samples: ['VisiumR1S1ROI1', 'VisiumR1S1ROI2', 'VisiumR1S1ROI3', 'VisiumR1S1ROI4']
min_cells is  173.0
min_cells is  163.0
min_cells is  175.0
min_cells is  134.0
[32m12:21:35[0m | [1mINFO[0m | [1mFound 643 common genes[0m
[32m12:21:36[0m | [1mINFO[0m | [1mselected genes ['APLP2', 'ARF4', 'ATP1B1', 'ATP6V0D1', 'C1QC', 'CAST', 'CCNL1', 'CLU', 'COX6A1', 'CREB3L1', 'CTSZ', 'DEFA5', 'ECE1', 'EFHD2', 'ETFB', 'FAM102A', 'FN1', 'FTH1', 'GADD45B', 'GNB1', 'HIST1H4A', 'HLA-DRA', 'HSPB6', 'IER3', 'IER5', 'IGFBP5', 'ISG20', 'JPT1', 'LGALS3', 'LYZ', 'MAP1B', 'MBOAT7', 'OGN', 'PEX26', 'PLA2G2A', 'PRR13', 'RAB5A', 'REG1A', 'RHOG', 'SAT1', 'SERPINB6', 'SMTN', 'STAT6', 'TCF7L2', 'TIMP3', 'TNFRSF1A', 'TNFRSF21', 'TNIP1', 'VAMP8', 'VASP'][0m


  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


[INFO] Wrote /project/simmons_hts/kxu/hest/eval/data/VisiumR1/var_50genes.json (top-50, criteria=var); /project/simmons_hts/kxu/hest/eval/data/VisiumR1/all_genes.json (all_common=17943); /project/simmons_hts/kxu/hest/eval/data/VisiumR1/common_genes_0.1.json (filtered_common=643, min_cells_pct=0.1)
[INFO] Wrote /project/simmons_hts/kxu/hest/eval/data/VisiumR1/var_50genes.json
K=15 doesnt match the number of patients, try to distribute the patients instead
Split 0/4
train set is  ['VisiumR1S1ROI2', 'VisiumR1S1ROI3', 'VisiumR1S1ROI4']

test set is  ['VisiumR1S1ROI1']

Split 1/4
train set is  ['VisiumR1S1ROI1', 'VisiumR1S1ROI3', 'VisiumR1S1ROI4']

test set is  ['VisiumR1S1ROI2']

Split 2/4
train set is  ['VisiumR1S1ROI1', 'VisiumR1S1ROI2', 'VisiumR1S1ROI4']

test set is  ['VisiumR1S1ROI3']

Split 3/4
train set is  ['VisiumR1S1ROI1', 'VisiumR1S1ROI2', 'VisiumR1S1ROI3']

test set is  ['VisiumR1S1ROI4']

[INFO] Wrote 15-fold splits to /project/simmons_hts/kxu/hest/eval/data/VisiumR1/splits
✅ 

  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


[INFO] Wrote /project/simmons_hts/kxu/hest/eval/data/VisiumR2/var_50genes.json (top-50, criteria=var); /project/simmons_hts/kxu/hest/eval/data/VisiumR2/all_genes.json (all_common=17943); /project/simmons_hts/kxu/hest/eval/data/VisiumR2/common_genes_0.1.json (filtered_common=130, min_cells_pct=0.1)
[INFO] Wrote /project/simmons_hts/kxu/hest/eval/data/VisiumR2/var_50genes.json
K=15 doesnt match the number of patients, try to distribute the patients instead
Split 0/8
train set is  ['VisiumR2S1ROI2', 'VisiumR2S1ROI3', 'VisiumR2S1ROI4', 'VisiumR2S2ROI1', 'VisiumR2S2ROI2', 'VisiumR2S2ROI3', 'VisiumR2S2ROI4']

test set is  ['VisiumR2S1ROI1']

Split 1/8
train set is  ['VisiumR2S1ROI1', 'VisiumR2S1ROI3', 'VisiumR2S1ROI4', 'VisiumR2S2ROI1', 'VisiumR2S2ROI2', 'VisiumR2S2ROI3', 'VisiumR2S2ROI4']

test set is  ['VisiumR2S1ROI2']

Split 2/8
train set is  ['VisiumR2S1ROI1', 'VisiumR2S1ROI2', 'VisiumR2S1ROI4', 'VisiumR2S2ROI1', 'VisiumR2S2ROI2', 'VisiumR2S2ROI3', 'VisiumR2S2ROI4']

test set is  ['Visi

  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


[32m12:22:17[0m | [1mINFO[0m | [1mselected genes ['ADAMTS1', 'ADIRF', 'ALDH1B1', 'ATP8B1', 'BHLHE40', 'C1QB', 'C7', 'CAV1', 'CD68', 'CD9', 'CDC42EP5', 'CNN1', 'CRYAB', 'CTSC', 'CTSS', 'DSTN', 'FBP1', 'FXYD1', 'HAND2', 'HBA2', 'HIST1H2BN', 'HIST1H4A', 'HSPA1B', 'IER2', 'IGF1', 'IGLV3-1', 'ITGB2', 'LYZ', 'MAP1B', 'MFSD11', 'NUCKS1', 'OGN', 'PALLD', 'PGD', 'PGM5', 'PMP22', 'PRUNE2', 'PSMB10', 'RGS1', 'S100A8', 'S100A9', 'SDC4', 'SELENOP', 'SFRP4', 'SLMAP', 'SORBS1', 'ST6GALNAC6', 'SVIL', 'SYNM', 'TPSB2'][0m
[INFO] Wrote /project/simmons_hts/kxu/hest/eval/data/VisiumR3/var_50genes.json (top-50, criteria=var); /project/simmons_hts/kxu/hest/eval/data/VisiumR3/all_genes.json (all_common=17943); /project/simmons_hts/kxu/hest/eval/data/VisiumR3/common_genes_0.1.json (filtered_common=2388, min_cells_pct=0.1)
[INFO] Wrote /project/simmons_hts/kxu/hest/eval/data/VisiumR3/var_50genes.json
K=15 doesnt match the number of patients, try to distribute the patients instead
Split 0/4
train set is  

  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


[32m12:22:38[0m | [1mINFO[0m | [1mselected genes ['ATP5MC3', 'BAG3', 'CCT2', 'CLTB', 'CRABP2', 'CSTB', 'CTNNBIP1', 'CYB5R1', 'DNAJA4', 'DNAJB1', 'DST', 'EBNA1BP2', 'EIF2S3', 'EIF4A3', 'ERO1A', 'FLNB', 'FSCN1', 'GJA1', 'GLTP', 'HEBP2', 'HSPA1B', 'HSPH1', 'ID1', 'IGFBP3', 'ITGA3', 'KPNA2', 'KRT10', 'KTN1', 'MYL12B', 'MYO1B', 'NAA20', 'NDRG1', 'ODC1', 'PLS3', 'PNP', 'PRNP', 'PRXL2A', 'PSMB6', 'PTBP3', 'RAB10', 'S100A16', 'SCD', 'SF3B6', 'SLC25A39', 'SLC38A2', 'TRAP1', 'TSTA3', 'TUBA4A', 'TXN', 'TXNDC17'][0m
[INFO] Wrote /project/simmons_hts/kxu/hest/eval/data/VisiumR4/var_50genes.json (top-50, criteria=var); /project/simmons_hts/kxu/hest/eval/data/VisiumR4/all_genes.json (all_common=17943); /project/simmons_hts/kxu/hest/eval/data/VisiumR4/common_genes_0.1.json (filtered_common=3907, min_cells_pct=0.1)
[INFO] Wrote /project/simmons_hts/kxu/hest/eval/data/VisiumR4/var_50genes.json
K=15 doesnt match the number of patients, try to distribute the patients instead
Split 0/4
train set is  

  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


[32m12:23:05[0m | [1mINFO[0m | [1mselected genes ['ANPEP', 'APOL1', 'BAG3', 'BHLHE40', 'BLMH', 'BST2', 'C1QB', 'CD82', 'CD9', 'CFB', 'CLEC3B', 'COX7B', 'CSTB', 'CTNNBIP1', 'CYB5R1', 'DYNLL1', 'EIF3K', 'GM2A', 'HIST1H2BD', 'HIST1H4I', 'HIST2H2AB', 'HRAS', 'IFI27', 'IFI6', 'IFITM1', 'IGFBP3', 'IMPDH2', 'INSIG1', 'KPNA2', 'LMNB2', 'LY6E', 'MYL12B', 'NDRG1', 'NDUFB9', 'OAT', 'PML', 'PRXL2A', 'PSMA2', 'PSMB6', 'PSMB8', 'RAB3D', 'RUVBL2', 'S100A16', 'SELENOP', 'TAP1', 'TDP2', 'TRAP1', 'UBE2L6', 'WARS', 'ZC3H12A'][0m
[INFO] Wrote /project/simmons_hts/kxu/hest/eval/data/VisiumR5/var_50genes.json (top-50, criteria=var); /project/simmons_hts/kxu/hest/eval/data/VisiumR5/all_genes.json (all_common=17943); /project/simmons_hts/kxu/hest/eval/data/VisiumR5/common_genes_0.1.json (filtered_common=2951, min_cells_pct=0.1)
[INFO] Wrote /project/simmons_hts/kxu/hest/eval/data/VisiumR5/var_50genes.json
K=15 doesnt match the number of patients, try to distribute the patients instead
Split 0/7
train se

  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


[32m12:23:36[0m | [1mINFO[0m | [1mselected genes ['AMOTL1', 'ANXA1', 'APOL1', 'BCR', 'BICD2', 'BLMH', 'CA12', 'CD44', 'CDC42EP1', 'CDC42EP5', 'CLTB', 'CTNNBIP1', 'DBI', 'DYNLL1', 'EGR1', 'EMP1', 'FASN', 'FCHSD1', 'G0S2', 'GIPC1', 'GPC1', 'H2AFJ', 'IFI27', 'MFSD11', 'MINK1', 'MTSS1', 'MYO15B', 'NDRG1', 'NDRG2', 'NFIB', 'NIBAN2', 'NUPR1', 'PHGDH', 'PLIN3', 'PLS3', 'PRXL2A', 'RARG', 'RMND5A', 'ROBO1', 'SDC1', 'SNX21', 'SPTLC2', 'TCF7L2', 'THBD', 'TMEM134', 'TUBA4A', 'TWIST2', 'TXN', 'UPP1', 'VWF'][0m
[INFO] Wrote /project/simmons_hts/kxu/hest/eval/data/VisiumR6/var_50genes.json (top-50, criteria=var); /project/simmons_hts/kxu/hest/eval/data/VisiumR6/all_genes.json (all_common=17943); /project/simmons_hts/kxu/hest/eval/data/VisiumR6/common_genes_0.1.json (filtered_common=1996, min_cells_pct=0.1)
[INFO] Wrote /project/simmons_hts/kxu/hest/eval/data/VisiumR6/var_50genes.json
K=15 doesnt match the number of patients, try to distribute the patients instead
Split 0/8
train set is  ['Visiu

# create VisiumR folder containing all prime runs

In [8]:
# Load the HEST sample directory table (one row per sample, see the output below).
metadata = pd.read_csv("/project/simmons_hts/kxu/hest/hest_directory.csv")
metadata  # bare expression so the notebook displays the table

Unnamed: 0,sample_id,technology,panel,panel_name,patient_id,run_name,run_id,slide,slide_id,roi,...,alignment,directory_xenium_output,rds,alignment_note,crop_100_um,segmentation_target_pxl_size,num_patches_100um,num_patches_50um,num_patches_50um_0.25_um_px,num_patches_25um
0,XeniumPR1S1ROI1,10x Xenium,5k,,CAM006,RUNTrexBIO,PR1,1,43739.0,1.0,...,CAM006_Xenium5K_post_HnE_matrix.csv,/project/simmons_hts/shared/20_11_2024_xenium_...,/project/simmons_hts/jpark/1_project/1_objects...,,"{""type"": ""strip"", ""side"": ""right"", ""size"": 0.1...",,684.0,2727.0,2603.0,9950.0
1,XeniumPR1S1ROI2,10x Xenium,5k,,TIP877,RUNTrexBIO,PR1,1,43739.0,2.0,...,TIP877_Xenium5K_post_HnE_matrix.csv,/project/simmons_hts/shared/20_11_2024_xenium_...,/project/simmons_hts/jpark/1_project/1_objects...,,,,482.0,1886.0,1838.0,7130.0
2,XeniumPR1S1ROI3,10x Xenium,5k,,GI9389,RUNTrexBIO,PR1,1,43739.0,3.0,...,GI9389_Xenium5K_post_HnE_matrix.csv,/project/simmons_hts/shared/20_11_2024_xenium_...,/project/simmons_hts/jpark/1_project/1_objects...,,"{'type':'corner', 'corner':'top-left', 'width'...",,1168.0,4627.0,4502.0,17368.0
3,XeniumPR1S1ROI4,10x Xenium,5k,,GI9077,RUNTrexBIO,PR1,1,43739.0,4.0,...,GI9077_Xenium5K_post_HnE_matrix.csv,/project/simmons_hts/shared/20_11_2024_xenium_...,/project/simmons_hts/jpark/1_project/1_objects...,,"{'type':'corner', 'corner':'bottom-left', 'wid...",,1253.0,5010.0,4903.0,19360.0
4,XeniumPR1S1ROI5,10x Xenium,5k,,GI9612,RUNTrexBIO,PR1,1,43739.0,5.0,...,GI9612_Xenium5K_post_HnE_matrix.csv,/project/simmons_hts/shared/20_11_2024_xenium_...,/project/simmons_hts/jpark/1_project/1_objects...,,,,893.0,3520.0,3289.0,12449.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
115,XeniumR6S2ROI7,10x Xenium,480,,JR_50621_22,RUN6,R6,2,31265.0,7.0,...,XeniumR6S2ROI7_alignment_files/matrix.csv,/project/simmons_hts/shared/06_05_2024_xenium_...,,,,,4212.0,,,
116,XeniumR6S2ROI8,10x Xenium,480,,JR_18170_21,RUN6,R6,2,31265.0,8.0,...,XeniumR6S2ROI8_alignment_files/matrix.csv,/project/simmons_hts/shared/06_05_2024_xenium_...,,,,,2016.0,,,
117,XeniumR6S2ROI9,10x Xenium,480,,JR_8610_23,RUN6,R6,2,31265.0,9.0,...,XeniumR6S2ROI9_alignment_files/matrix.csv,/project/simmons_hts/shared/06_05_2024_xenium_...,,,,,3130.0,,,
118,XeniumR6S2ROI10,10x Xenium,480,,JR_20687_20,RUN6,R6,2,31265.0,10.0,...,XeniumR6S2ROI10_alignment_files/matrix.csv,/project/simmons_hts/shared/06_05_2024_xenium_...,,,,,665.0,,,


In [33]:
import re
import os
import json
import shutil
from pathlib import Path
from typing import List, Dict, Optional

import numpy as np
import pandas as pd
import scanpy as sc
from hest.utils import get_k_genes  # used by write_var_k_genes_from_paths; create_splits imported inside main


# ---------- helpers ----------

def _extract_slide_number_from_name(name: str) -> Optional[str]:
    """Return the slide number embedded in a folder name, as a string.

    'slide' followed by an optional '_' or '-' and digits is tried first
    (on the lower-cased name), then a standalone 'S<digits>' token in any
    case. None is returned when neither form is present.
    """
    if (explicit := re.search(r"slide[_\-]?(\d+)", name.lower())) is not None:
        return explicit.group(1)
    short_form = re.search(r"\bS(\d+)\b", name, flags=re.IGNORECASE)
    return None if short_form is None else short_form.group(1)

def _is_slide_like_folder(folder: Path) -> bool:
    """Tell whether ``folder`` is an existing directory named like a slide.

    The name qualifies when a slide number can be parsed from it, or when it
    merely starts with 'slide' (any case). Only the name is examined; the
    folder's contents are not inspected.
    """
    if not folder.is_dir():
        return False
    folder_name = folder.name
    has_slide_number = _extract_slide_number_from_name(folder_name) is not None
    return has_slide_number or folder_name.lower().startswith("slide")

def _expand_to_slide_paths(roots: List[Path]) -> List[Path]:
    """
    Turn a list of root folders into the list of slide folders to scan.

    For every root that exists and is a directory:
      * if it has slide-like subfolders, each of them is kept when it holds a
        subfolder with ``aligned_adata.h5ad`` or, failing that, when it holds
        a ``.h5ad`` file directly;
      * otherwise the root itself is kept when one of its subfolders holds
        ``aligned_adata.h5ad``;
      * otherwise every subfolder that has such a sample subfolder is kept.

    The result has no duplicates (compared by string) and preserves the order
    in which folders were found.
    """

    def _subfolders(folder: Path) -> List[Path]:
        # Immediate child directories, ordered by name.
        return sorted((c for c in folder.iterdir() if c.is_dir()), key=lambda c: c.name)

    def _holds_samples(folders) -> bool:
        # True when any of the given folders contains aligned_adata.h5ad.
        return any((f / "aligned_adata.h5ad").exists() for f in folders)

    found: List[Path] = []
    for root in roots:
        if not (root.exists() and root.is_dir()):
            continue

        children = _subfolders(root)
        slides = [child for child in children if _is_slide_like_folder(child)]

        if not slides:
            if _holds_samples(children):
                # The root is already a slide folder.
                found.append(root)
                continue
            # Fallback: any child that has sample subfolders counts as a slide.
            found.extend(child for child in children if _holds_samples(_subfolders(child)))
            continue

        for slide in slides:
            if _holds_samples(_subfolders(slide)):
                found.append(slide)
            elif any(
                entry.is_file() and entry.name.endswith(".h5ad")
                for entry in slide.iterdir()
            ):
                # Uncommon layout: .h5ad files sit directly in the slide folder.
                found.append(slide)

    # Drop duplicates, keeping the first occurrence of each path string.
    first_by_text = {}
    for path in found:
        first_by_text.setdefault(str(path), path)
    return list(first_by_text.values())


def _discover_samples_from_slide_paths(
    slide_paths: List[Path],
    ids: Optional[List[str]] = None,
) -> Dict[str, Dict[str, Path]]:
    """
    Discover samples under *slide* paths (each slide path should contain sample subfolders).

    Slide paths that are not existing directories are skipped, as are sample
    folders without ``aligned_adata.h5ad``. When ``ids`` is given, only
    sample folders whose name is in ``ids`` are considered.

    Returns:
        {new_id: {"adata": Path, "patch": Path|None, "vis": Path|None, "orig": orig_sid}}

    Naming:
        ``VisiumR{n}S{slide}{sample folder name}`` when a run number is found
        in the slide path, otherwise
        ``{sanitized slide folder name}S{slide}{sample folder name}``.

    Raises:
        ValueError: if two samples map to the same renamed id.
    """
    # Set membership instead of scanning the ids list once per folder.
    wanted = set(ids) if ids is not None else None

    collected = []
    for sp in slide_paths:
        sp = Path(sp)
        if not sp.is_dir():
            # Same tolerance as the roots-based discovery: ignore missing slides.
            continue
        # sample subfolders are immediate children of the slide path
        for p in sorted((d for d in sp.iterdir() if d.is_dir()), key=lambda d: d.name):
            if wanted is None or p.name in wanted:
                collected.append((sp, p.name))

    samples: Dict[str, Dict[str, Path]] = {}
    for slide_path, sid in collected:
        sdir = slide_path / sid
        adata = sdir / "aligned_adata.h5ad"
        if not adata.exists():
            continue

        # find patch .h5 (optional): exact '<sid>.h5' wins, else first in sorted order
        patch_h5 = None
        patches_dir = sdir / "patches"
        if patches_dir.exists():
            cands = sorted(patches_dir.glob("*.h5"))
            if cands:
                exact = [c for c in cands if c.name == f"{sid}.h5"]
                patch_h5 = exact[0] if exact else cands[0]

        # find vis png (optional): exact '<sid>_patch_vis.png' wins, else first in sorted order
        vis_png = None
        vis_dir = sdir / "patches_vis"
        if vis_dir.exists():
            cands = sorted(vis_dir.glob("*.png"))
            if cands:
                exact = [c for c in cands if c.name == f"{sid}_patch_vis.png"]
                vis_png = exact[0] if exact else cands[0]

        # The helper searches the full path string, which already covers every
        # ancestor component, so no separate walk over the parents is needed.
        pr_num = _extract_pr_number(slide_path)

        # slide number from slide path name (if can't find, use sanitized short tag)
        slide_num = _extract_slide_number_from_name(slide_path.name) or _sanitize_tag(slide_path.name, maxlen=3)

        if pr_num is not None:
            prefix = f"VisiumR{pr_num}S{slide_num}"
        else:
            # fallback if no run number is present anywhere in the path
            prefix = f"{_sanitize_tag(slide_path.name, maxlen=6)}S{slide_num}"

        new_id = f"{prefix}{sid}"
        if new_id in samples:
            raise ValueError(f"Duplicate renamed sample id '{new_id}' (collision for sid='{sid}').")

        samples[new_id] = {"adata": adata, "patch": patch_h5, "vis": vis_png, "orig": sid}

    return samples


def _transfer(src: Optional[Path], dst: Path, label: str, symlink: bool, missing_list: list, overwrite: bool = False):
    """
    Copy or symlink `src` -> `dst`. By default skip if dst exists (no-op).
    If overwrite=True, existing dst will be replaced.
    Records missing items into missing_list.
    """
    if src is None or not Path(src).exists():
        missing_list.append((dst.stem, label, str(src) if src is not None else "<none>"))
        return

    dst.parent.mkdir(parents=True, exist_ok=True)

    # If destination exists and we don't want to overwrite -> skip
    if dst.exists():
        # Optionally detect identical file (size + mtime) to avoid unnecessary replacement:
        try:
            src_stat = Path(src).stat()
            dst_stat = dst.stat()
            same_size = src_stat.st_size == dst_stat.st_size
            same_mtime = int(src_stat.st_mtime) == int(dst_stat.st_mtime)
        except Exception:
            same_size = False
            same_mtime = False

        if not overwrite and same_size and same_mtime:
            # identical (likely), skip quietly
            return
        if not overwrite:
            # dst exists but not identical (or we couldn't stat) — skip by default but warn
            print(f"[SKIP] {label} exists: {dst} (use overwrite=True to replace)")
            return

        # If overwrite requested: remove existing
        try:
            dst.unlink()
        except Exception as e:
            print(f"[WARN] could not remove existing {dst}: {e}")

    # perform transfer
    if symlink:
        try:
            os.symlink(src, dst)
        except FileExistsError:
            # already exists (race) — ignore
            pass
    else:
        shutil.copy2(src, dst)  # copy2 preserves mtime/metadata



def write_var_k_genes_from_paths(
    adata_paths,
    k,
    criteria,
    var_out_path,
    all_genes_out_path=None,
    exclude_keywords=None,
    filtered_common_out_path=None,
    min_cells_pct: float = 0.10,
):
    """
    Read every AnnData in `adata_paths`, pick the top-k genes via `get_k_genes`
    and write two additional JSON gene lists.

    Files:
        var_out_path: forwarded to `get_k_genes` as `save_dir`.
        all_genes_out_path: genes shared by all AnnData objects whose name contains none of
            `exclude_keywords`; defaults to `all_genes.json` next to `var_out_path`.
        filtered_common_out_path: genes that remain in every AnnData after
            `sc.pp.filter_genes(min_cells=ceil(min_cells_pct * n_obs))`, minus names containing
            "BLANK" or "Control"; defaults to `common_genes_<min_cells_pct>.json` next to `var_out_path`.

    Returns:
        (var_k_genes, all_common_genes, filtered_common_genes)
    """
    import json
    import warnings
    import numpy as np
    import scanpy as sc
    from hest.utils import get_k_genes

    if exclude_keywords is None:
        exclude_keywords = ["NegControl", "Codeword", "Intergenic_Region", "Control", "BLANK"]

    warnings.filterwarnings("ignore", category=FutureWarning, module="anndata")

    def _read_quietly(path):
        # silence FutureWarnings raised while the file is being read
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=FutureWarning)
            return sc.read_h5ad(str(path))

    adata_list = [_read_quietly(p) for p in adata_paths]

    var_k_genes = get_k_genes(adata_list, k, criteria, save_dir=str(var_out_path), min_cells_pct=min_cells_pct)

    # genes present in every sample, before any expression filtering
    per_sample_genes = [set(ad.var_names) for ad in adata_list]
    common_genes = set.intersection(*per_sample_genes) if per_sample_genes else set()
    all_common_genes = sorted(
        gene for gene in common_genes
        if not any(kw in gene for kw in exclude_keywords)
    )

    # genes that pass the min-cells filter in every sample (filter runs on a copy)
    filtered_sets = []
    for ad in adata_list:
        trimmed = ad[:, :].copy()
        threshold = int(np.ceil(min_cells_pct * trimmed.n_obs)) if min_cells_pct else 0
        if threshold > 0:
            sc.pp.filter_genes(trimmed, min_cells=threshold)
        filtered_sets.append(set(trimmed.var_names))

    if filtered_sets:
        filtered_common = set.intersection(*filtered_sets)
    else:
        filtered_common = set()
    filtered_common_genes = sorted(
        gene for gene in filtered_common
        if not ("BLANK" in gene or "Control" in gene)
    )

    if all_genes_out_path is None:
        all_genes_out_path = Path(var_out_path).parent / "all_genes.json"
    with open(all_genes_out_path, "w") as fh:
        json.dump({"genes": all_common_genes}, fh)

    if filtered_common_out_path is None:
        filtered_common_out_path = Path(var_out_path).parent / f"common_genes_{min_cells_pct}.json"
    with open(filtered_common_out_path, "w") as fh:
        json.dump({"genes": filtered_common_genes, "min_cells_pct": min_cells_pct}, fh)

    print(
        f"[INFO] Wrote {var_out_path} (top-{k}, criteria={criteria}); "
        f"{all_genes_out_path} (all_common={len(all_common_genes)}); "
        f"{filtered_common_out_path} (filtered_common={len(filtered_common_genes)}, min_cells_pct={min_cells_pct})"
    )

    return var_k_genes, all_common_genes, filtered_common_genes

# copy directly from other eval folder
def create_benchmark_from_eval_dirs(
    save_dir: str | Path,
    K: int,
    eval_dirs: List[str | Path],
    gene_k: int = 50,
    gene_criteria: str = "var",
    symlink: bool = False,
    seed: int = 0,
    metadata_csv: str = "/project/simmons_hts/kxu/hest/hest_directory.csv",
    dry_run: bool = False,
    exclude_ids: Optional[List[str]] = None
):
    """
    Build a merged benchmark package by copying (or symlinking) assets from one or more
    'eval' dataset folders that already contain:
        <eval_dir>/
            patches/
                *.h5
                vis/
                    *.png
            adata/
                *.h5ad

    Args:
        save_dir: destination directory to create merged dataset (will contain patches/, patches/vis/, adata/, splits/, var_*.json)
        K: number of folds (patient-level)
        eval_dirs: list of dataset root paths to copy from (e.g. XeniumPR2 eval folder)
        gene_k, gene_criteria: forwarded to get_k_genes
        symlink: if True, create symlinks instead of copying
        seed: RNG seed for deterministic fold assignment
        metadata_csv: CSV mapping sample_id -> patient_id
        dry_run: if True, only print planned actions without copying
    Returns:
        pd.DataFrame meta (columns: id, patient, dataset_title)
    """
    from hest.HESTData import create_splits

    save_dir = Path(save_dir)
    eval_dirs = [Path(x) for x in eval_dirs]
    # sanitise and check inputs
    existing = [d for d in eval_dirs if d.exists() and d.is_dir()]
    if not existing:
        raise ValueError(f"No valid eval_dirs found among: {eval_dirs}")
    print(f"[INFO] Using eval dirs: {existing}")

    # discover sample ids by scanning adata/ and patches/ for filenames
    discovered_ids = set()
    sample_sources = {}  # id -> dict(sources found)
    for d in existing:
        adata_dir = d / "adata"
        patches_dir = d / "patches"
        vis_dir = patches_dir / "vis"

        # adata
        if adata_dir.exists() and adata_dir.is_dir():
            for f in sorted(adata_dir.glob("*.h5ad")):
                sid = f.stem
                discovered_ids.add(sid)
                sample_sources.setdefault(sid, {}).setdefault("adata", []).append(f)

        # patches
        if patches_dir.exists() and patches_dir.is_dir():
            for f in sorted(patches_dir.glob("*.h5")):
                sid = f.stem
                discovered_ids.add(sid)
                sample_sources.setdefault(sid, {}).setdefault("patch", []).append(f)

            # vis images
            if vis_dir.exists() and vis_dir.is_dir():
                for f in sorted(vis_dir.glob("*.png")):
                    # allow vis file names like '<sid>_patch_vis.png' or '<sid>.png' or anything; map by stem heuristics
                    stem = f.stem
                    # normalize: if stem endswith '_patch_vis', strip it
                    stem_clean = re.sub(r"_?patch_vis$", "", stem, flags=re.IGNORECASE)
                    # sometimes vis is named '<sid>_patch_vis' or '<sid>'
                    sid = stem_clean
                    discovered_ids.add(sid)
                    sample_sources.setdefault(sid, {}).setdefault("vis", []).append(f)
                    
    # ---- Apply exclusion ----
    if exclude_ids:
        exclude_set = set(exclude_ids)
        before = len(discovered_ids)
        discovered_ids = [sid for sid in discovered_ids if sid not in exclude_set]

        missing_excludes = exclude_set - set(discovered_ids)
        if missing_excludes:
            print(f"[WARN] Some exclude_ids not found: {sorted(missing_excludes)}")

        removed = before - len(discovered_ids)
        print(f"[INFO] Excluded {removed} samples → remaining {len(discovered_ids)}")
        if removed > 0:
            for e in sorted(exclude_set & set(discovered_ids)):
                print(f"   - excluded: {e}")

    discovered_ids = sorted(discovered_ids)
    if not discovered_ids:
        raise ValueError("No samples discovered in provided eval_dirs (no *.h5ad or *.h5 files found).")
    print(f"[INFO] Discovered sample IDs ({len(discovered_ids)}): {discovered_ids}")

    # Prepare save_dir layout
    patches_out = save_dir / "patches"
    patches_vis_out = patches_out / "vis"
    adata_out = save_dir / "adata"
    for p in (patches_out, patches_vis_out, adata_out):
        if not dry_run:
            p.mkdir(parents=True, exist_ok=True)

    # Load metadata CSV mapping sample_id -> patient_id
    patient_map = {}
    meta_df_csv = None
    if Path(metadata_csv).exists():
        meta_df_csv = pd.read_csv(metadata_csv, dtype=str)
        if {"sample_id", "patient_id"}.issubset(meta_df_csv.columns):
            meta_df_csv["sample_id"] = meta_df_csv["sample_id"].astype(str).str.strip()
            meta_df_csv["patient_id"] = meta_df_csv["patient_id"].astype(str).str.strip()
            patient_map = dict(zip(meta_df_csv["sample_id"], meta_df_csv["patient_id"]))
            print(f"[INFO] Loaded {len(patient_map)} entries from {metadata_csv}")
        else:
            print(f"[WARN] metadata_csv missing columns 'sample_id'/'patient_id'; will fallback to automatic patient inference")
    else:
        print(f"[WARN] metadata_csv not found: {metadata_csv}; will fallback to automatic patient inference")

    # Copy / symlink files into save_dir using sample id as filename stem
    missing = []
    planned_actions = []
    for sid in discovered_ids:
        srcs = sample_sources.get(sid, {})
        # choose one adata: prefer first available
        adata_src = None
        if "adata" in srcs and srcs["adata"]:
            adata_src = srcs["adata"][0]
        # else fallback to none

        patch_src = None
        if "patch" in srcs and srcs["patch"]:
            patch_src = srcs["patch"][0]

        # vis: there may be multiple pngs per sample across eval_dirs — keep all but use a standardized name
        vis_srcs = srcs.get("vis", [])

        # plan copy/symlink
        if adata_src:
            dst = adata_out / f"{sid}.h5ad"
            planned_actions.append(("adata", adata_src, dst))
        else:
            # warn — adata missing for this sid
            missing.append((sid, "adata"))

        if patch_src:
            dst = patches_out / f"{sid}.h5"
            planned_actions.append(("patch", patch_src, dst))
        else:
            missing.append((sid, "patch"))

        # for vis, when multiple sources exist, copy each with a numeric suffix if needed
        for i, vs in enumerate(vis_srcs, start=1):
            # try base name '<sid>.png' then '<sid>_1.png', '<sid>_2.png'...
            if i == 1:
                dst = patches_vis_out / f"{sid}.png"
            else:
                dst = patches_vis_out / f"{sid}_{i}.png"
            planned_actions.append(("vis", vs, dst))

    # Show dry run summary
    print(f"[INFO] Planned actions: {len(planned_actions)} file operations; {len(missing)} missing types.")
    if dry_run:
        for act, src, dst in planned_actions[:200]:
            print(f"  - [{act}] {src} -> {dst}")
        if missing:
            print("[WARN] Missing items:")
            for sid, typ in missing[:50]:
                print(f"  - {sid}: missing {typ}")
        print("[INFO] dry_run=True → no files were copied.")
    else:
        # perform file ops
        for act, src, dst in planned_actions:
            try:
                _transfer(src, dst, act, symlink, [], overwrite=False)  # we pass temporary missing list per transfer
            except Exception as e:
                print(f"[ERROR] transferring {src} -> {dst}: {e}")

    # Build metadata DataFrame: use discovered sample IDs and patient mapping (full sample id)
    patient_ids = []
    unresolved = []
    for sid in discovered_ids:
        pid = patient_map.get(sid)
        if pid is None:
            # fallback: try a stem match where original source had 'orig' info: try to find sample with full stem in filenames
            # attempt to match any filename that contains sid as suffix: useful if CSV used 'XeniumPR1S1ROI1' but discovered was 'ROI1' etc.
            # we'll try simple heuristics:
            matched = None
            if meta_df_csv is not None:
                # try find any csv sample_id that endswith sid
                candidates = [s for s in meta_df_csv["sample_id"].values if str(s).endswith(str(sid))]
                if candidates:
                    matched = candidates[0]
                    pid = patient_map.get(matched)
            if pid is None:
                # fallback to using prefix before '_' or the sid itself as patient
                pid = sid.split("_")[0] if "_" in sid else sid
                unresolved.append(sid)
        patient_ids.append(pid)

    meta = pd.DataFrame({"id": discovered_ids, "patient": patient_ids, "dataset_title": ["XeniumPR"] * len(discovered_ids)})

    print(f"[INFO] Built metadata: {len(meta)} samples, {meta['patient'].nunique()} unique patients.")
    print(meta.head(20).to_string(index=False))

    # write var_k genes (requires adata files to be present in save_dir or accessible)
    adata_paths = [adata_out / f"{sid}.h5ad" for sid in discovered_ids]
    # If dry_run, don't run get_k_genes; just return meta

    var_json = save_dir / f"var_{gene_k}genes.json"
    write_var_k_genes_from_paths(adata_paths, gene_k, gene_criteria, var_json)
    print(f"[INFO] Wrote {var_json}")

    # patient-level splits
    group = meta.groupby(["dataset_title", "patient"])["id"].agg(list).to_dict()
    rng = np.random.RandomState(seed)
    for key, id_list in group.items():
        rng.shuffle(id_list)

    splits_dir = save_dir / "splits"
    splits_dir.mkdir(parents=True, exist_ok=True)
    create_splits(str(splits_dir), group, K=K)
    print(f"[INFO] Wrote {K}-fold patient-level splits to {splits_dir}")

    # final warnings about missing files
    if missing:
        print("[WARN] Some samples were missing adata/patch files (listing up to 50):")
        for sid, typ in missing[:50]:
            print(f"  - {sid}: missing {typ}")

    print(f"✅ Merged benchmark created at {save_dir}")
    return meta


In [29]:
def create_benchmark_from_eval_dirs(
    save_dir: str | Path,
    K: int,
    eval_dirs: List[str | Path],
    gene_k: int = 50,
    gene_criteria: str = "var",
    symlink: bool = False,
    seed: int = 0,
    metadata_csv: str = "/project/simmons_hts/kxu/hest/hest_directory.csv",
    exclude_ids: Optional[List[str]] = None
):
    """
    Build a merged benchmark package by copying (or symlinking) assets from one or more
    'eval' dataset folders that already contain:
        <eval_dir>/
            patches/
                *.h5
                vis/
                    *.png
            adata/
                *.h5ad

    Args:
        save_dir: destination directory to create merged dataset (will contain patches/, patches/vis/, adata/, splits/, var_*.json)
        K: number of folds (patient-level)
        eval_dirs: list of dataset root paths to copy from (e.g. XeniumPR2 eval folder)
        gene_k, gene_criteria: forwarded to get_k_genes
        symlink: if True, create symlinks instead of copying
        seed: RNG seed used to shuffle the sample order inside each patient group
        metadata_csv: CSV mapping sample_id -> patient_id
        exclude_ids: sample ids to drop after discovery; ids that were never discovered are reported
    Returns:
        pd.DataFrame meta (columns: id, patient, dataset_title)
    Raises:
        ValueError: if none of eval_dirs is an existing directory, or no sample remains after exclusion
    """
    from hest.HESTData import create_splits

    save_dir = Path(save_dir)
    eval_dirs = [Path(x) for x in eval_dirs]
    # sanitise and check inputs
    existing = [d for d in eval_dirs if d.is_dir()]
    if not existing:
        raise ValueError(f"No valid eval_dirs found among: {eval_dirs}")
    print(f"[INFO] Using eval dirs: {existing}")

    # ---- discover sample ids by scanning adata/, patches/ and patches/vis/ for filenames ----
    discovered_ids = set()
    sample_sources = {}  # sid -> {"adata": [paths], "patch": [paths], "vis": [paths]}

    def _register(sid, kind, path):
        discovered_ids.add(sid)
        sample_sources.setdefault(sid, {}).setdefault(kind, []).append(path)

    for d in existing:
        adata_dir = d / "adata"
        patches_dir = d / "patches"
        vis_dir = patches_dir / "vis"

        if adata_dir.is_dir():
            for f in sorted(adata_dir.glob("*.h5ad")):
                _register(f.stem, "adata", f)

        if patches_dir.is_dir():
            for f in sorted(patches_dir.glob("*.h5")):
                _register(f.stem, "patch", f)

        if vis_dir.is_dir():
            for f in sorted(vis_dir.glob("*.png")):
                # vis files are named '<sid>_patch_vis.png' or '<sid>.png'; strip the optional suffix
                sid = re.sub(r"_?patch_vis$", "", f.stem, flags=re.IGNORECASE)
                _register(sid, "vis", f)

    # ---- Apply exclusion ----
    if exclude_ids:
        exclude_set = set(exclude_ids)
        # Compute both reports BEFORE filtering; afterwards no excluded id is left to compare against.
        excluded = sorted(exclude_set & discovered_ids)
        missing_excludes = sorted(exclude_set - discovered_ids)
        discovered_ids = discovered_ids - exclude_set

        if missing_excludes:
            print(f"[WARN] Some exclude_ids not found: {missing_excludes}")
        print(f"[INFO] Excluded {len(excluded)} samples → remaining {len(discovered_ids)}")
        for e in excluded:
            print(f"   - excluded: {e}")

    discovered_ids = sorted(discovered_ids)
    if not discovered_ids:
        raise ValueError("No samples discovered in provided eval_dirs (no *.h5ad or *.h5 files found).")
    print(f"[INFO] Discovered sample IDs ({len(discovered_ids)}): {discovered_ids}")

    # ---- Prepare save_dir layout ----
    # Create the output folders explicitly so they exist even when no file of that kind is transferred.
    patches_out = save_dir / "patches"
    patches_vis_out = patches_out / "vis"
    adata_out = save_dir / "adata"
    for p in (patches_out, patches_vis_out, adata_out):
        p.mkdir(parents=True, exist_ok=True)

    # ---- Load metadata CSV mapping sample_id -> patient_id ----
    patient_map = {}
    meta_df_csv = None
    if Path(metadata_csv).exists():
        meta_df_csv = pd.read_csv(metadata_csv, dtype=str)
        if {"sample_id", "patient_id"}.issubset(meta_df_csv.columns):
            meta_df_csv["sample_id"] = meta_df_csv["sample_id"].astype(str).str.strip()
            meta_df_csv["patient_id"] = meta_df_csv["patient_id"].astype(str).str.strip()
            patient_map = dict(zip(meta_df_csv["sample_id"], meta_df_csv["patient_id"]))
            print(f"[INFO] Loaded {len(patient_map)} entries from {metadata_csv}")
        else:
            print(f"[WARN] metadata_csv missing columns 'sample_id'/'patient_id'; will fallback to automatic patient inference")
            # without the expected columns the suffix-matching fallback below cannot work either
            meta_df_csv = None
    else:
        print(f"[WARN] metadata_csv not found: {metadata_csv}; will fallback to automatic patient inference")

    # ---- Plan copy / symlink operations, using the sample id as filename stem ----
    missing = []          # (sid, kind) for samples lacking an adata or patch file
    planned_actions = []  # (kind, src, dst)
    for sid in discovered_ids:
        srcs = sample_sources.get(sid, {})

        # when a sample appears in several eval dirs, the first one (in eval_dirs order) wins
        for kind, out_dir, ext in (("adata", adata_out, ".h5ad"), ("patch", patches_out, ".h5")):
            candidates = srcs.get(kind, [])
            if candidates:
                planned_actions.append((kind, candidates[0], out_dir / f"{sid}{ext}"))
            else:
                missing.append((sid, kind))

        # vis: keep every png; the first becomes '<sid>.png', later ones '<sid>_2.png', '<sid>_3.png', ...
        for i, vs in enumerate(srcs.get("vis", []), start=1):
            name = f"{sid}.png" if i == 1 else f"{sid}_{i}.png"
            planned_actions.append(("vis", vs, patches_vis_out / name))

    print(f"[INFO] Planned actions: {len(planned_actions)} file operations; {len(missing)} missing types.")

    # ---- Perform file operations ----
    for act, src, dst in planned_actions:
        try:
            # sources were discovered by globbing, so the per-call missing list is a throwaway
            _transfer(src, dst, act, symlink, [])
        except Exception as e:
            # best effort: report and continue with the remaining files
            print(f"[ERROR] transferring {src} -> {dst}: {e}")

    # ---- Build metadata DataFrame: sample id -> patient id ----
    patient_ids = []
    unresolved = []
    for sid in discovered_ids:
        pid = patient_map.get(sid)
        if pid is None and meta_df_csv is not None:
            # the CSV may use a longer id (e.g. 'XeniumPR1S1ROI1') than the discovered one ('ROI1'):
            # take the first CSV sample_id that ends with the discovered id
            candidates = [s for s in meta_df_csv["sample_id"].values if str(s).endswith(str(sid))]
            if candidates:
                pid = patient_map.get(candidates[0])
        if pid is None:
            # last resort: prefix before the first '_' or the sid itself
            pid = sid.split("_")[0] if "_" in sid else sid
            unresolved.append(sid)
        patient_ids.append(pid)

    if unresolved:
        print(f"[WARN] No patient_id found in metadata for {len(unresolved)} samples; inferred from sample id: {unresolved}")

    meta = pd.DataFrame({"id": discovered_ids, "patient": patient_ids, "dataset_title": ["XeniumPR"] * len(discovered_ids)})

    print(f"[INFO] Built metadata: {len(meta)} samples, {meta['patient'].nunique()} unique patients.")
    print(meta.head(20).to_string(index=False))

    # write var_k genes (reads the adata files that were just placed in save_dir)
    # NOTE(review): ids discovered only through a patch/vis file have no adata here — confirm this cannot happen
    adata_paths = [adata_out / f"{sid}.h5ad" for sid in discovered_ids]
    var_json = save_dir / f"var_{gene_k}genes.json"
    write_var_k_genes_from_paths(adata_paths, gene_k, gene_criteria, var_json)
    print(f"[INFO] Wrote {var_json}")

    # patient-level splits
    group = meta.groupby(["dataset_title", "patient"])["id"].agg(list).to_dict()
    rng = np.random.RandomState(seed)
    for id_list in group.values():
        rng.shuffle(id_list)  # in-place shuffle of each patient's sample list

    splits_dir = save_dir / "splits"
    splits_dir.mkdir(parents=True, exist_ok=True)
    create_splits(str(splits_dir), group, K=K)
    print(f"[INFO] Wrote {K}-fold patient-level splits to {splits_dir}")

    # final warnings about missing files
    if missing:
        print("[WARN] Some samples were missing adata/patch files (listing up to 50):")
        for sid, typ in missing[:50]:
            print(f"  - {sid}: missing {typ}")

    print(f"✅ Merged benchmark created at {save_dir}")
    return meta


In [30]:
# Merge the six per-run Visium eval folders (VisiumR1 .. VisiumR6) into one benchmark.
meta = create_benchmark_from_eval_dirs(
    save_dir="/project/simmons_hts/kxu/hest/eval/data/VisiumR1-6",
    K=15,
    eval_dirs=[
        f"/project/simmons_hts/kxu/hest/eval/data/VisiumR{run}"
        for run in range(1, 7)
    ],
    gene_k=50,
    symlink=False,
    seed=0,
    metadata_csv="/project/simmons_hts/kxu/hest/visium_directory.csv",
)

[INFO] Using eval dirs: [PosixPath('/project/simmons_hts/kxu/hest/eval/data/VisiumR1'), PosixPath('/project/simmons_hts/kxu/hest/eval/data/VisiumR2'), PosixPath('/project/simmons_hts/kxu/hest/eval/data/VisiumR3'), PosixPath('/project/simmons_hts/kxu/hest/eval/data/VisiumR4'), PosixPath('/project/simmons_hts/kxu/hest/eval/data/VisiumR5'), PosixPath('/project/simmons_hts/kxu/hest/eval/data/VisiumR6')]
[INFO] Discovered sample IDs (35): ['VisiumR1S1ROI1', 'VisiumR1S1ROI2', 'VisiumR1S1ROI3', 'VisiumR1S1ROI4', 'VisiumR2S1ROI1', 'VisiumR2S1ROI2', 'VisiumR2S1ROI3', 'VisiumR2S1ROI4', 'VisiumR2S2ROI1', 'VisiumR2S2ROI2', 'VisiumR2S2ROI3', 'VisiumR2S2ROI4', 'VisiumR3S1ROI1', 'VisiumR3S1ROI2', 'VisiumR3S1ROI3', 'VisiumR3S1ROI4', 'VisiumR4S1ROI1', 'VisiumR4S1ROI2', 'VisiumR4S1ROI3', 'VisiumR4S1ROI4', 'VisiumR5S1ROI1', 'VisiumR5S1ROI2', 'VisiumR5S1ROI3', 'VisiumR5S1ROI4', 'VisiumR5S2ROI1', 'VisiumR5S2ROI2', 'VisiumR5S2ROI3', 'VisiumR6S1ROI1', 'VisiumR6S1ROI2', 'VisiumR6S1ROI3', 'VisiumR6S1ROI4', 'Vi

  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


[32m15:14:56[0m | [1mINFO[0m | [1mselected genes ['ATP5PD', 'B2M', 'CAPN1', 'CFD', 'CHCHD2', 'CNDP2', 'COL1A2', 'COX5B', 'COX8A', 'CTSD', 'CYC1', 'CYCS', 'EDF1', 'EEF1B2', 'EIF4A1', 'EIF4H', 'ENO1', 'FLNA', 'FTH1', 'GOLM1', 'GSN', 'GSTP1', 'HDGF', 'HLA-DMA', 'HNRNPA3', 'HSPB1', 'IGFBP7', 'IGKC', 'IL32', 'ITM2B', 'LASP1', 'LGALS3BP', 'LRP10', 'MT1X', 'MT2A', 'NCL', 'NDUFA13', 'NDUFB7', 'NDUFS6', 'P4HB', 'PABPC1', 'PNPLA2', 'POR', 'RAC1', 'TAGLN2', 'TMSB4X', 'TPM1', 'TPM2', 'UBC', 'VDAC1'][0m
[INFO] Wrote /project/simmons_hts/kxu/hest/eval/data/VisiumR1-6/var_50genes.json (top-50, criteria=var); /project/simmons_hts/kxu/hest/eval/data/VisiumR1-6/all_genes.json (all_common=17943); /project/simmons_hts/kxu/hest/eval/data/VisiumR1-6/common_genes_0.1.json (filtered_common=128, min_cells_pct=0.1)
[INFO] Wrote /project/simmons_hts/kxu/hest/eval/data/VisiumR1-6/var_50genes.json
K=15 doesnt match the number of patients, try to distribute the patients instead
Split 0/18
train set is  ['Visi

Split 16/18
train set is  ['VisiumR4S1ROI1', 'VisiumR6S2ROI3', 'VisiumR6S2ROI2', 'VisiumR5S1ROI2', 'VisiumR4S1ROI2', 'VisiumR5S1ROI4', 'VisiumR4S1ROI3', 'VisiumR2S1ROI1', 'VisiumR1S1ROI2', 'VisiumR1S1ROI1', 'VisiumR1S1ROI3', 'VisiumR2S1ROI2', 'VisiumR4S1ROI4', 'VisiumR6S1ROI1', 'VisiumR3S1ROI4', 'VisiumR6S2ROI1', 'VisiumR5S2ROI2', 'VisiumR6S1ROI4', 'VisiumR3S1ROI3', 'VisiumR3S1ROI1', 'VisiumR6S2ROI4', 'VisiumR6S1ROI3', 'VisiumR5S2ROI1', 'VisiumR3S1ROI2', 'VisiumR5S1ROI3', 'VisiumR6S1ROI2', 'VisiumR5S1ROI1', 'VisiumR2S1ROI3', 'VisiumR2S2ROI1', 'VisiumR5S2ROI3', 'VisiumR2S2ROI4', 'VisiumR1S1ROI4', 'VisiumR2S2ROI3']

test set is  ['VisiumR2S2ROI2' 'VisiumR2S1ROI4']

Split 17/18
train set is  ['VisiumR4S1ROI1', 'VisiumR6S2ROI3', 'VisiumR6S2ROI2', 'VisiumR5S1ROI2', 'VisiumR4S1ROI2', 'VisiumR5S1ROI4', 'VisiumR4S1ROI3', 'VisiumR2S1ROI1', 'VisiumR1S1ROI2', 'VisiumR1S1ROI1', 'VisiumR1S1ROI3', 'VisiumR2S1ROI2', 'VisiumR4S1ROI4', 'VisiumR6S1ROI1', 'VisiumR3S1ROI4', 'VisiumR6S2ROI1', 'VisiumR5S2RO

### add extra sets of variable genes

In [34]:
# Re-use the adata files of the merged benchmark to compute larger variable-gene sets.
save_dir = Path("/project/simmons_hts/kxu/hest/eval/data/VisiumR1-6")
adata_dir = save_dir / "adata"
adata_paths = sorted(p for p in adata_dir.glob("*.h5ad") if p.is_file())

criteria = "var"
ks = [1000, 3000]

for k in ks:
    var_out = save_dir / f"var_{k}genes.json"
    print(f"\n--- computing top-{k} ({criteria}) -> {var_out} ---")
    # min_cells_pct=0.01 leaves 3046 common genes
    write_var_k_genes_from_paths(
        adata_paths,
        k,
        criteria,
        var_out,
        min_cells_pct=0.01,
    )


--- computing top-1000 (var) -> /project/simmons_hts/kxu/hest/eval/data/VisiumR1-6/var_1000genes.json ---
min_cells is  18.0
min_cells is  17.0
min_cells is  18.0
min_cells is  14.0
min_cells is  18.0
min_cells is  17.0
min_cells is  18.0
min_cells is  14.0
min_cells is  30.0
min_cells is  46.0
min_cells is  46.0
min_cells is  47.0
min_cells is  28.0
min_cells is  39.0
min_cells is  28.0
min_cells is  40.0
min_cells is  29.0
min_cells is  43.0
min_cells is  39.0
min_cells is  39.0
min_cells is  39.0
min_cells is  47.0
min_cells is  50.0
min_cells is  43.0
min_cells is  44.0
min_cells is  23.0
min_cells is  42.0
min_cells is  26.0
min_cells is  41.0
min_cells is  24.0
min_cells is  20.0
min_cells is  25.0
min_cells is  36.0
min_cells is  36.0
min_cells is  43.0
min_cells is  31.0
min_cells is  29.0
min_cells is  37.0
min_cells is  49.0
min_cells is  18.0
min_cells is  17.0
min_cells is  18.0
min_cells is  14.0
[32m15:18:15[0m | [1mINFO[0m | [1mFound 3046 common genes[0m


  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


[32m15:18:50[0m | [1mINFO[0m | [1mselected genes ['A2M', 'AAMP', 'ABCC3', 'ABHD12', 'ABI1', 'ABLIM1', 'ABR', 'ABTB2', 'ACAA1', 'ACAA2', 'ACADM', 'ACAP1', 'ACO2', 'ACOT7', 'ACOX1', 'ACOX3', 'ACSL5', 'ACTG2', 'ACTN1', 'ACTN4', 'ACTR1A', 'ADAM15', 'ADD1', 'ADIPOR1', 'ADIRF', 'ADM', 'ADNP', 'AEBP1', 'AFDN', 'AGAP3', 'AGPAT3', 'AHCYL1', 'AKAP1', 'AKAP6', 'AKNA', 'AKT1S1', 'ALDH18A1', 'ALDH1B1', 'ALDH2', 'ALDH3A2', 'ALKBH5', 'AMOTL1', 'ANKLE2', 'ANP32B', 'ANPEP', 'ANXA1', 'ANXA2', 'ANXA7', 'AP2S1', 'APEX1', 'API5', 'APOBR', 'APOE', 'APOL1', 'APRT', 'AQP1', 'AQP3', 'ARF1', 'ARF5', 'ARHGAP45', 'ARHGDIB', 'ARHGEF10L', 'ARL2', 'ARL4A', 'ARL6IP5', 'ARPC1A', 'ARRDC4', 'ASAP2', 'ATOX1', 'ATP2A3', 'ATP2B4', 'ATP5F1A', 'ATP5F1B', 'ATP5F1D', 'ATP5F1E', 'ATP5MC3', 'ATP5ME', 'ATP5MF', 'ATP5PB', 'ATP5PF', 'ATP6V0D1', 'ATP6V1B2', 'ATP6V1E1', 'ATP6V1F', 'ATP8B1', 'ATXN7L3B', 'BCAM', 'BCAP31', 'BCL2L11', 'BCL2L2', 'BCL7C', 'BCL9L', 'BCR', 'BECN1', 'BGN', 'BICD2', 'BLMH', 'BNIP3L', 'BOK', 'BOP1', 'BRI3B

[INFO] Wrote /project/simmons_hts/kxu/hest/eval/data/VisiumR1-6/var_1000genes.json (top-1000, criteria=var); /project/simmons_hts/kxu/hest/eval/data/VisiumR1-6/all_genes.json (all_common=17943); /project/simmons_hts/kxu/hest/eval/data/VisiumR1-6/common_genes_0.01.json (filtered_common=3046, min_cells_pct=0.01)

--- computing top-3000 (var) -> /project/simmons_hts/kxu/hest/eval/data/VisiumR1-6/var_3000genes.json ---
min_cells is  18.0
min_cells is  17.0
min_cells is  18.0
min_cells is  14.0
min_cells is  18.0
min_cells is  17.0
min_cells is  18.0
min_cells is  14.0
min_cells is  30.0
min_cells is  46.0
min_cells is  46.0
min_cells is  47.0
min_cells is  28.0
min_cells is  39.0
min_cells is  28.0
min_cells is  40.0
min_cells is  29.0
min_cells is  43.0
min_cells is  39.0
min_cells is  39.0
min_cells is  39.0
min_cells is  47.0
min_cells is  50.0
min_cells is  43.0
min_cells is  44.0
min_cells is  23.0
min_cells is  42.0
min_cells is  26.0
min_cells is  41.0
min_cells is  24.0
min_cells i

  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


[32m15:19:22[0m | [1mINFO[0m | [1mselected genes ['A2M', 'AAMP', 'AARS', 'ABCA2', 'ABCC1', 'ABCC3', 'ABCD4', 'ABCF3', 'ABHD12', 'ABHD16A', 'ABHD17B', 'ABHD2', 'ABHD4', 'ABI1', 'ABL1', 'ABLIM1', 'ABLIM3', 'ABR', 'ABTB1', 'ABTB2', 'ACAA1', 'ACAA2', 'ACACA', 'ACACB', 'ACAD11', 'ACADM', 'ACADVL', 'ACAP1', 'ACAT1', 'ACBD6', 'ACIN1', 'ACLY', 'ACO2', 'ACOT7', 'ACOT9', 'ACOX1', 'ACOX3', 'ACP2', 'ACSL3', 'ACSL5', 'ACSS1', 'ACTA2', 'ACTB', 'ACTG2', 'ACTN1', 'ACTN4', 'ACTR1A', 'ACTR1B', 'ACTR3', 'ACVR1B', 'ADAM15', 'ADAM19', 'ADAM9', 'ADAP2', 'ADAR', 'ADARB1', 'ADCY6', 'ADD1', 'ADH5', 'ADI1', 'ADIPOR1', 'ADIRF', 'ADM', 'ADNP', 'ADSL', 'AEBP1', 'AFAP1', 'AFDN', 'AGAP1', 'AGAP3', 'AGAP9', 'AGO1', 'AGO2', 'AGPAT1', 'AGPAT3', 'AGPS', 'AGRN', 'AHCYL1', 'AHNAK', 'AHSA1', 'AIDA', 'AIG1', 'AIP', 'AK1', 'AKAP1', 'AKAP11', 'AKAP13', 'AKAP17A', 'AKAP6', 'AKNA', 'AKR1A1', 'AKT1', 'AKT1S1', 'AKT2', 'AKT3', 'ALAD', 'ALDH16A1', 'ALDH18A1', 'ALDH1B1', 'ALDH2', 'ALDH3A2', 'ALDH3B1', 'ALG11', 'ALG5', 'ALKBH5'

[INFO] Wrote /project/simmons_hts/kxu/hest/eval/data/VisiumR1-6/var_3000genes.json (top-3000, criteria=var); /project/simmons_hts/kxu/hest/eval/data/VisiumR1-6/all_genes.json (all_common=17943); /project/simmons_hts/kxu/hest/eval/data/VisiumR1-6/common_genes_0.01.json (filtered_common=3046, min_cells_pct=0.01)


### Create leave-one-patient-out cross-validation

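The call above produced 19 splits even though K=15 was requested. Redefine `create_splits` below so that it creates exactly K patient-level splits (one split per patient, i.e. leave-one-patient-out, when K is omitted).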

In [None]:
import os
from pathlib import Path
import numpy as np
import pandas as pd

def _split_manifest(sample_ids):
    """Build the manifest table written for one side (train or test) of a split.

    Args:
        sample_ids (list): sample identifiers, in the order they should appear.

    Returns:
        pandas.DataFrame: columns ``sample_id``, ``patches_path``
        (``patches/<id>.h5``) and ``expr_path`` (``adata/<id>.h5ad``).
    """
    return pd.DataFrame({
        "sample_id": sample_ids,
        # f-strings instead of `sid + ".h5"` so non-string ids do not raise TypeError
        "patches_path": [os.path.join("patches", f"{sid}.h5") for sid in sample_ids],
        "expr_path": [os.path.join("adata", f"{sid}.h5ad") for sid in sample_ids],
    })


def create_splits(dest_dir, splits, K=None):
    """
    Create K patient-level splits where no patient appears in both train and test
    for the same split.

    Patients are sorted and cut into K contiguous chunks of near-equal size.
    For split ``i``, the samples of chunk ``i`` are written to ``test_i.csv`` and
    the samples of every other chunk to ``train_i.csv``. With ``K == 1`` the single
    split therefore puts every patient in test and ``train_0.csv`` has no rows.

    Args:
        dest_dir (str or Path): directory to write train_i.csv and test_i.csv
            (created if missing).
        splits (dict): mapping patient_id -> list of sample_ids. Patient ids must be
            mutually sortable.
        K (int or None): number of splits to create. If None, defaults to number of
            patients (leave-one-patient-out). If K > number_of_patients, K is reduced
            to number_of_patients.

    Returns:
        dict: the patient-chunk splits used (mapping chunk_index -> list of patient_ids)

    Raises:
        ValueError: if ``splits`` is empty or ``K < 1``.
    """
    dest_dir = Path(dest_dir)
    dest_dir.mkdir(parents=True, exist_ok=True)

    # canonical ordering for determinism
    patients = sorted(splits)
    n_patients = len(patients)

    if n_patients == 0:
        raise ValueError("splits dict is empty (no patients).")

    if K is None:
        K = n_patients

    if K < 1:
        raise ValueError("K must be >= 1")

    # If requested K is greater than number of patients, reduce it (can't split patients finer)
    if K > n_patients:
        print(f"Requested K={K} > n_patients={n_patients}; reducing K -> {n_patients}")
        K = n_patients

    # Chunk patient *positions* rather than the ids themselves: wrapping the ids in
    # a NumPy array would turn equal-length tuple ids into a 2-D array and hand back
    # unhashable lists instead of the original keys.
    index_chunks = np.array_split(np.arange(n_patients), K)
    # chunk_index -> list(patient_ids)
    patient_splits = {
        i: [patients[j] for j in chunk.tolist()]
        for i, chunk in enumerate(index_chunks)
    }
    n_splits = len(patient_splits)

    # For each chunk: that chunk's patients -> TEST, other chunks' patients -> TRAIN
    for i, test_patients in patient_splits.items():
        test_ids = [s for p in test_patients for s in splits[p]]

        train_patients = [p for j, group in patient_splits.items() if j != i for p in group]
        train_ids = [s for p in train_patients for s in splits[p]]

        print(f"Split {i+1}/{n_splits}: {len(train_ids)} train samples, {len(test_ids)} test samples")
        print(f"  test patients: {test_patients}")
        print("")

        _split_manifest(train_ids).to_csv(dest_dir / f"train_{i}.csv", index=False)
        _split_manifest(test_ids).to_csv(dest_dir / f"test_{i}.csv", index=False)

    return patient_splits
