### Prepare broad data for training in /project/simmons_hts/kxu/hest/eval/data/broad/



In [3]:
# --- Standard library ---
import os
import json
import shutil
from pathlib import Path
from typing import List, Dict, Tuple, Optional

# --- Third-party ---
import numpy as np
import pandas as pd
import scanpy as sc

# --- HEST ---
from hest import iter_hest
from hest.utils import get_k_genes
from hest.HESTData import create_splits

In [4]:


# ---------- helpers ----------

def _discover_samples_broad(broad_root: Path, ids: Optional[List[str]] = None) -> Dict[str, Dict[str, Path]]:
    """
    Return: {sample_id: {"adata": Path, "patch": Path|None, "vis": Path|None}}
    """
    samples = {}
    if ids is None:
        ids = sorted([p.name for p in Path(broad_root).iterdir() if p.is_dir()])

    for sid in ids:
        sdir = Path(broad_root) / sid
        if not sdir.is_dir():
            continue

        adata = sdir / "aligned_adata.h5ad"
        patches_dir = sdir / "patches"
        vis_dir = sdir / "patches_vis"

        patch_h5 = None
        if patches_dir.exists():
            cand = sorted(patches_dir.glob("*.h5"))
            if cand:
                exact = [c for c in cand if c.name == f"{sid}.h5"]
                patch_h5 = exact[0] if exact else cand[0]

        vis_png = None
        if vis_dir.exists():
            cand = sorted(vis_dir.glob("*.png"))
            if cand:
                exact = [c for c in cand if c.name == f"{sid}_patch_vis.png"]
                vis_png = exact[0] if exact else cand[0]

        if adata.exists():
            samples[sid] = {"adata": adata, "patch": patch_h5, "vis": vis_png}

    return samples


def _transfer(src: Optional[Path], dst: Path, label: str, symlink: bool, missing_list: list):
    if src is None or not Path(src).exists():
        missing_list.append((dst.stem, label, str(src) if src is not None else "<none>"))
        return
    dst.parent.mkdir(parents=True, exist_ok=True)
    if dst.exists():
        dst.unlink()
    if symlink:
        try:
            os.symlink(src, dst)
        except FileExistsError:
            pass
    else:
        shutil.copy(src, dst)


def write_var_k_genes_from_paths(
    adata_paths,
    k,
    criteria,
    var_out_path,
    all_genes_out_path=None,
    exclude_keywords=None,
):
    """
    Write the top-k gene list and the list of all common genes for a set of samples.

    Every AnnData file is loaded, HEST's ``get_k_genes()`` is called to select
    and save the top-k genes, and the intersection of ``var_names`` across all
    samples is written separately after dropping control-like entries.
    ``exclude_keywords`` is applied only to that intersection; the top-k
    selection is left entirely to ``get_k_genes()``.

    FutureWarnings are silenced only while this function runs; the caller's
    warning filters are restored on exit.

    Args:
        adata_paths (list[Path]): paths to aligned_adata.h5ad
        k (int): number of top variable genes
        criteria (str): selection criteria forwarded to ``get_k_genes``
        var_out_path (Path|str): output path for the top-k genes JSON
        all_genes_out_path (Path|str|None): output path for the common-genes
            JSON; defaults to ``all_genes.json`` next to ``var_out_path``
        exclude_keywords (list[str]|None): substrings; any common gene
            containing one of them is dropped

    Raises:
        ValueError: if ``adata_paths`` is empty.
    """
    import json
    import warnings

    import scanpy as sc
    from hest.utils import get_k_genes

    adata_paths = list(adata_paths)
    if not adata_paths:
        raise ValueError("adata_paths is empty; at least one AnnData file is required.")

    # Accept plain strings as well: `.parent` is needed below.
    var_out_path = Path(var_out_path)
    if all_genes_out_path is None:
        all_genes_out_path = var_out_path.parent / "all_genes.json"

    if exclude_keywords is None:
        exclude_keywords = ["NegControl", "Codeword", "Intergenic_Region", "Control", "BLANK"]

    # Keep the suppression local: catch_warnings() restores the previous filter
    # state on exit instead of leaving a process-wide filter installed.
    with warnings.catch_warnings():
        # suppress FutureWarnings raised from anndata (e.g. is_categorical_dtype)
        warnings.filterwarnings("ignore", category=FutureWarning, module="anndata")

        with warnings.catch_warnings():
            # while reading, silence FutureWarnings from any module
            warnings.simplefilter("ignore", category=FutureWarning)
            adata_list = [sc.read_h5ad(str(p)) for p in adata_paths]

        # Get top-k variable genes (get_k_genes writes them to var_out_path)
        get_k_genes(adata_list, k, criteria, save_dir=str(var_out_path))

    # Genes present in every sample
    common_genes = set.intersection(*(set(ad.var_names) for ad in adata_list))

    # Drop control probes / codewords and similar non-gene entries
    filtered_genes = sorted(
        gene for gene in common_genes
        if not any(kw in gene for kw in exclude_keywords)
    )

    with open(all_genes_out_path, "w") as f:
        json.dump({"genes": filtered_genes}, f)

    print(f"[INFO] Wrote {var_out_path} (top-{k}) and {all_genes_out_path} ({len(filtered_genes)} filtered common genes)")



# ---------- main entry ----------

def create_benchmark_data_broad_autodiscover(
    save_dir: str | Path,
    K: int,
    broad_root: str | Path = "/project/simmons_hts/kxu/hest/xenium_data/broad",
    ids: Optional[List[str]] = None,
    gene_k: int = 50,
    gene_criteria: str = "var",
    symlink: bool = False,
    seed: int = 0,   # seeds the within-patient shuffle applied before create_splits
):
    """
    Build a HEST benchmark package from 'broad' without a metadata DF, using HEST's create_splits & get_k_genes.

    Samples are auto-discovered under ``broad_root``; the patient is inferred
    from the part of the sample id before the first underscore
    (e.g. 'UC6_I' -> 'UC6'), and all samples of one patient form one split group.

    Args:
        save_dir: output directory (created if missing).
        K: number of folds passed to ``create_splits``; must be between 2 and
            the number of patient groups.
        broad_root: directory containing one sub-directory per sample.
        ids: optional subset of sample ids; None discovers all of them.
        gene_k: number of top genes to select.
        gene_criteria: selection criteria forwarded to ``get_k_genes``.
        symlink: symlink assets into ``save_dir`` instead of copying them.
        seed: seed for shuffling the sample ids *inside* each patient group.
            It does not reorder the groups themselves.

    Raises:
        ValueError: if no sample with an aligned_adata.h5ad is found, or if
            ``K`` is not compatible with the number of patient groups.

    Output:
      <save_dir>/
        var_<gene_k>genes.json
        all_genes.json
        splits/...
        patches/<id>.h5
        patches/vis/<id>.png
        adata/<id>.h5ad
    """
    save_dir = Path(save_dir)
    broad_root = Path(broad_root)

    # 1) Discover samples
    samples = _discover_samples_broad(broad_root, ids=ids)
    if not samples:
        raise ValueError(f"No valid samples (with aligned_adata.h5ad) found under {broad_root}.")
    discovered_ids = sorted(samples.keys())
    print(f"[INFO] Discovered {len(discovered_ids)} samples: {discovered_ids}")

    # 2) Build a minimal "metadata" DF for splitting (patient from prefix; dataset_title='broad')
    #    e.g., 'UC6_I' -> patient='UC6'; ids without '_' are their own patient
    def _infer_patient(sid: str) -> str:
        return sid.split("_", 1)[0]

    meta = pd.DataFrame({
        "id": discovered_ids,
        "patient": [_infer_patient(s) for s in discovered_ids],
        "dataset_title": ["broad"] * len(discovered_ids),
    })

    #    Group by (dataset_title, patient) so a patient never straddles train and test.
    group = meta.groupby(["dataset_title", "patient"])["id"].agg(list).to_dict()

    # Fail fast, before loading every AnnData and writing files.
    # NOTE(review): assumes create_splits performs a K-fold over the groups
    # (so 2 <= K <= number of groups) — confirm against the installed HEST version.
    n_groups = len(group)
    if K < 2 or K > n_groups:
        raise ValueError(
            f"K={K} is not usable with {n_groups} patient group(s); "
            f"K must be between 2 and the number of groups."
        )

    save_dir.mkdir(parents=True, exist_ok=True)

    # 3) Compute var_k genes → var_<gene_k>genes.json (+ all_genes.json)
    adata_paths = [samples[sid]["adata"] for sid in discovered_ids]
    var_json = save_dir / f"var_{gene_k}genes.json"
    write_var_k_genes_from_paths(adata_paths, gene_k, gene_criteria, var_json)
    print(f"[INFO] Wrote {var_json}")

    # 4) K-fold splits using HEST's create_splits
    #    HEST expects a dict mapping groups to list of sample ids.
    #    Deterministically shuffle the ids inside each group. This only changes
    #    the order of samples within a patient; the order of the groups is
    #    left as produced by groupby.
    rng = np.random.RandomState(seed)
    for id_list in group.values():
        rng.shuffle(id_list)
    splits_dir = save_dir / "splits"
    splits_dir.mkdir(parents=True, exist_ok=True)
    create_splits(str(splits_dir), group, K=K)
    print(f"[INFO] Wrote {K}-fold splits to {splits_dir}")

    # 5) Copy/symlink assets
    (save_dir / "patches").mkdir(exist_ok=True, parents=True)
    (save_dir / "patches" / "vis").mkdir(exist_ok=True, parents=True)
    (save_dir / "adata").mkdir(exist_ok=True, parents=True)

    # Samples lacking a patch or vis file are still part of the splits above;
    # they are only reported here, so check the warning list before training.
    missing = []
    for sid in discovered_ids:
        info = samples[sid]
        _transfer(info.get("patch"), save_dir / "patches" / f"{sid}.h5", "patch", symlink, missing)
        _transfer(info.get("vis"),   save_dir / "patches" / "vis" / f"{sid}.png", "vis",   symlink, missing)
        _transfer(info.get("adata"), save_dir / "adata" / f"{sid}.h5ad", "adata", symlink, missing)

    if missing:
        print("[WARN] Missing files:")
        for sid, lbl, path in missing:
            print(f"  - {sid} [{lbl}] → {path}")

    print(f"✅ Benchmark dataset created at {save_dir}")

In [11]:
# Build the benchmark package for the 'broad' cohort: discovers samples under
# broad_root and writes gene lists, K-fold splits and copied assets to save_dir.
create_benchmark_data_broad_autodiscover(
    save_dir="/project/simmons_hts/kxu/hest/eval/data/broad",
    K=5, # one fold per patient group (5 patient prefixes across the 7 discovered samples)
    broad_root="/project/simmons_hts/kxu/hest/xenium_data/broad",
    # ids=["UC1_I","UC2_I"],  # optionally limit to a subset
    gene_k=50,
    gene_criteria="var",
    symlink=False,            # set True to symlink assets instead of copying (saves disk space)
    seed=0                    # seeds the shuffle of sample ids within each patient group
)

[INFO] Discovered 7 samples: ['DC5', 'UC1_I', 'UC1_NI', 'UC6_I', 'UC6_NI', 'UC7_I', 'UC9_I']
min_cells is  757.0
min_cells is  826.0
min_cells is  695.0
min_cells is  936.0
min_cells is  678.0
min_cells is  438.0
min_cells is  1002.0
[32m00:07:41[0m | [1mINFO[0m | [1mFound 404 common genes[0m


  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


[32m00:07:41[0m | [1mINFO[0m | [1mselected genes ['BANK1', 'CA1', 'CA2', 'CA4', 'CCR7', 'CD19', 'CD22', 'CD3D', 'CD40LG', 'CD5', 'CD6', 'CDHR5', 'CEACAM1', 'CLCA4', 'CXCL13', 'CXCL5', 'CXCR4', 'DMBT1', 'DUOX2', 'EPCAM', 'FCER2', 'FCRLA', 'HHLA2', 'HLA-DRA', 'IGHG2', 'IGHG3', 'IGHM', 'IL23A', 'IL7R', 'KRT19', 'LEF1', 'MS4A1', 'MS4A12', 'NOS2', 'NXPE1', 'OLFM4', 'OSM', 'PENK', 'PKIB', 'PROK2', 'RAB3B', 'SDCBP2', 'SNAP25', 'SPIB', 'TCF7', 'TIGIT', 'TRAC', 'TRAT1', 'UCHL1', 'UGT2B17'][0m
[INFO] Wrote /project/simmons_hts/kxu/hest/eval/data/broad/var_50genes.json (top-50) and /project/simmons_hts/kxu/hest/eval/data/broad/all_genes.json (460 filtered common genes)
[INFO] Wrote /project/simmons_hts/kxu/hest/eval/data/broad/var_50genes.json
Split 0/5
train set is  ['UC1_NI', 'UC1_I', 'UC6_I', 'UC6_NI', 'UC7_I', 'UC9_I']

test set is  ['DC5']

Split 1/5
train set is  ['DC5', 'UC6_I', 'UC6_NI', 'UC7_I', 'UC9_I']

test set is  ['UC1_NI' 'UC1_I']

Split 2/5
train set is  ['DC5', 'UC1_NI', 'U