# Tutorial: BayesPrism Deconvolution Validation

Audience:
- Users validating `bayesprism-py` behavior before downstream analysis.

Prerequisites:
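- Python environment with `bayesprism` installed; run the notebook from inside the repository checkout, because the first cell searches upward from the working directory for `pyproject.toml` and `src/` and adds `src/` to `sys.path`.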
- Optional (only used in `real_subset` mode): local tutorial dataset `tutorial.dat/tutorial.gbm.rdata`; set `BAYESPRISM_TUTORIAL_RDATA` to the file's path.

Learning goals:
- Reproduce the deconvolution tutorial flow with Python APIs.
- Validate core invariants and deterministic synthetic regression checks.


## Outline

1. Environment and deterministic configuration
2. Data loading (`synthetic` default, optional `real_subset`)
3. Tutorial-style QC summaries and plot approximations
4. Preprocessing, optional marker selection, and `new_prism` construction
5. `run_prism`, result extraction, invariant checks, and JSON validation report


In [None]:
from __future__ import annotations

import json
import os
import random
import subprocess
import sys
import time
from pathlib import Path

import numpy as np
import pandas as pd

# One fixed seed feeds both the stdlib and the NumPy global RNGs so repeated
# notebook runs draw the same random numbers.
SEED = 20260220
np.random.seed(SEED)
random.seed(SEED)

# Per-step wall-clock durations in seconds, keyed by step name; the start
# timestamp is the reference point for the overall "total" entry.
TIMINGS: dict[str, float] = {}
NOTEBOOK_START = time.perf_counter()


def find_repo_root(start: Path) -> Path:
    """Return the nearest ancestor of *start* that looks like the repository root.

    *start* is resolved first; it and each of its parents are then inspected
    from the innermost directory outward. A directory qualifies when it
    contains both a ``pyproject.toml`` entry and a ``src`` entry.

    Raises:
        RuntimeError: if no directory up to the filesystem root qualifies.
    """
    origin = start.resolve()
    for candidate in (origin, *origin.parents):
        has_pyproject = (candidate / "pyproject.toml").exists()
        if has_pyproject and (candidate / "src").exists():
            return candidate
    raise RuntimeError("Could not locate repository root from current working directory")


# Locate the checkout from the current working directory and put its src/
# directory at the front of sys.path (once) so the in-repo package is found.
REPO_ROOT = find_repo_root(Path.cwd())
if not any(entry == str(REPO_ROOT / "src") for entry in sys.path):
    sys.path.insert(0, str(REPO_ROOT / "src"))

print(f"Repository root: {REPO_ROOT}")


In [None]:
def env_flag(name: str, default: str = "0") -> bool:
    """Read environment variable *name* as an on/off switch.

    The flag is on only when the value (or *default*, when the variable is
    unset) equals ``"1"`` after surrounding whitespace is stripped; any other
    text counts as off.
    """
    raw_value = os.environ.get(name, default)
    return raw_value.strip() == "1"


# Run configuration, read from BAYESPRISM_* environment variables so the
# notebook can be parameterised without editing cells.

# Data source: the bundled synthetic fixture (default) or a subset of the real
# tutorial dataset.
MODE = os.getenv("BAYESPRISM_NOTEBOOK_MODE", "synthetic").strip().lower()
if MODE not in {"synthetic", "real_subset"}:
    raise ValueError("BAYESPRISM_NOTEBOOK_MODE must be 'synthetic' or 'real_subset'")

# Feature switches; each is on only when its variable equals "1".
RUN_PLOTS = env_flag("BAYESPRISM_NOTEBOOK_RUN_PLOTS", "1")
RUN_DE = env_flag("BAYESPRISM_NOTEBOOK_RUN_DE", "0")
RUN_HEAVY = env_flag("BAYESPRISM_NOTEBOOK_RUN_HEAVY", "0")

# Subsampling limits applied when loading the real tutorial subset.
MAX_CELLS_PER_STATE = int(os.getenv("BAYESPRISM_NOTEBOOK_MAX_CELLS_PER_STATE", "30"))
N_GENE_SUBSET = int(os.getenv("BAYESPRISM_NOTEBOOK_N_GENE_SUBSET", "6000"))

# Location of the tutorial .rdata file. The default is the documented relative
# location (resolved against the working directory) rather than an absolute
# path inside one developer's home directory; set BAYESPRISM_TUTORIAL_RDATA to
# point at the file wherever it actually lives.
TUTORIAL_RDATA = Path(
    os.getenv(
        "BAYESPRISM_TUTORIAL_RDATA",
        "tutorial.dat/tutorial.gbm.rdata",
    )
)

# Short sampler/optimizer settings used for the default (fast) run.
GIBBS_CONTROL_FAST = {
    "chain_length": 80,
    "burn_in": 20,
    "thinning": 4,
    "seed": 77,
    "alpha": 1.0,
}
OPT_CONTROL_FAST = {"optimizer": "MAP", "maxit": 200, "n_cores": 1}

# Longer sampler settings, only executed when RUN_HEAVY is enabled.
GIBBS_CONTROL_HEAVY = {
    "chain_length": 1000,
    "burn_in": 500,
    "thinning": 2,
    "seed": 123,
    "alpha": 1.0,
}

# Echo the effective configuration so it is recorded in the notebook output.
config_preview = {
    "mode": MODE,
    "run_plots": RUN_PLOTS,
    "run_de": RUN_DE,
    "run_heavy": RUN_HEAVY,
    "max_cells_per_state": MAX_CELLS_PER_STATE,
    "n_gene_subset": N_GENE_SUBSET,
    "tutorial_rdata": str(TUTORIAL_RDATA),
}
print(json.dumps(config_preview, indent=2))


In [None]:
def ensure_optional_dependencies(required_modules: list[str]) -> None:
    """Make sure every module in *required_modules* can be imported.

    Each name is imported in turn. Names that raise ``ModuleNotFoundError``
    are collected and installed in a single ``pip install`` call run with the
    current interpreter. The same names are passed to pip, so this assumes
    each module name matches its distribution name.

    Args:
        required_modules: Importable module names to check.

    Raises:
        subprocess.CalledProcessError: if the ``pip install`` command fails.
    """
    import importlib

    missing: list[str] = []
    for module_name in required_modules:
        try:
            importlib.import_module(module_name)
        except ModuleNotFoundError:
            missing.append(module_name)

    if not missing:
        print("Optional dependencies already available.")
        return

    print(f"Installing missing optional dependencies: {missing}")
    subprocess.check_call([sys.executable, "-m", "pip", "install", *missing])
    # Packages installed while the interpreter is running may be invisible to
    # the import system until its finder caches are refreshed.
    importlib.invalidate_caches()


# Optional packages are only needed for the features that are switched on:
# matplotlib for the QC plots, rdata for reading the real tutorial file.
optional_modules: list[str] = [
    module_name
    for module_name, wanted in (
        ("matplotlib", RUN_PLOTS),
        ("rdata", MODE == "real_subset"),
    )
    if wanted
]
ensure_optional_dependencies(optional_modules)

if RUN_PLOTS:
    import matplotlib.pyplot as plt


In [None]:
from bayesprism import (
    cleanup_genes,
    get_exp,
    get_exp_stat,
    get_fraction,
    new_prism,
    run_prism,
    select_gene_type,
    select_marker,
)


def load_synthetic_fixture(repo_root: Path) -> dict[str, object]:
    """Load the small synthetic fixture stored under ``tests/data/fixtures/small``.

    Args:
        repo_root: Repository root containing the ``tests`` directory.

    Returns:
        A dict with:
        - ``"reference"``: DataFrame built from the ``reference`` array, indexed
          by ``reference_rows`` with ``genes`` as columns.
        - ``"mixture"``: DataFrame built from the ``mixture`` array, indexed by
          ``mixture_rows`` with ``genes`` as columns.
        - ``"cell_type_labels"`` / ``"cell_state_labels"``: label lists read
          from the metadata JSON and converted to ``str``.
        - ``"expected"``: the stored ``phi_*`` / ``theta_*`` arrays used by the
          regression checks.
        - ``"objects"``: names of the loaded objects, for display.

    Raises:
        FileNotFoundError: if the ``.npz`` or the metadata JSON is missing.
        KeyError: if an expected array or metadata key is absent.
    """
    fixture_dir = repo_root / "tests" / "data" / "fixtures" / "small"
    npz_path = fixture_dir / "small_fixture.npz"
    metadata_path = fixture_dir / "small_fixture_metadata.json"

    # NOTE(review): allow_pickle=True executes pickled payloads on load; it is
    # kept because the fixture may store object-dtype arrays, and is only
    # acceptable because the fixture ships with the repository. Confirm whether
    # it is still required.
    # The context manager closes the archive's file handle once every array has
    # been read into memory, including when a read fails.
    with np.load(npz_path, allow_pickle=True) as npz:
        metadata = json.loads(metadata_path.read_text(encoding="utf-8"))

        genes = npz["genes"].astype(str)
        reference = pd.DataFrame(
            npz["reference"],
            index=npz["reference_rows"].astype(str),
            columns=genes,
        )
        mixture = pd.DataFrame(
            npz["mixture"],
            index=npz["mixture_rows"].astype(str),
            columns=genes,
        )

        expected = {
            "phi_cell_state": npz["phi_cell_state"],
            "phi_cell_type": npz["phi_cell_type"],
            "theta_first_type": npz["theta_first_type"],
            "theta_final_type": npz["theta_final_type"],
        }

    labels = metadata["labels"]
    return {
        "reference": reference,
        "mixture": mixture,
        "cell_type_labels": [str(x) for x in labels["cell_type_label"]],
        "cell_state_labels": [str(x) for x in labels["cell_state_label"]],
        "expected": expected,
        "objects": ["reference", "mixture", "cell_type_labels", "cell_state_labels"],
    }


def load_real_tutorial_subset(
    rdata_path: Path,
    max_cells_per_state: int,
    n_gene_subset: int,
    seed: int,
) -> dict[str, object]:
    """Load the tutorial ``.rdata`` file and return a reproducible subset of it.

    Cells are capped per cell-state label, genes are restricted to those that
    appear in both ``sc.dat`` and ``bk.dat`` (and randomly capped), and
    reference cells whose counts over the kept genes sum to zero are dropped.

    Args:
        rdata_path: Path to the tutorial ``.rdata`` file.
        max_cells_per_state: Upper bound on cells kept for each distinct
            cell-state label.
        n_gene_subset: Upper bound on the number of shared genes kept.
        seed: Seed for the NumPy generator that drives both subsampling steps.

    Returns:
        A dict with the same keys as the synthetic loader: ``"reference"``
        and ``"mixture"`` DataFrames sharing the same gene columns, the
        per-cell ``"cell_type_labels"`` / ``"cell_state_labels"`` lists
        aligned with the reference rows, ``"expected"`` set to ``None``, and
        ``"objects"`` listing the R object names that were required.

    Raises:
        FileNotFoundError: if *rdata_path* does not exist.
        KeyError: if any required object is absent from the converted file.
        ValueError: if the two matrices share no genes, or no reference row
            with a non-zero sum remains after subsetting.
    """
    if not rdata_path.exists():
        raise FileNotFoundError(f"Tutorial file not found: {rdata_path}")

    # Imported lazily: rdata is an optional dependency needed only in this mode.
    import rdata

    converted = rdata.conversion.convert(rdata.parser.parse_file(str(rdata_path)))
    required = ["sc.dat", "bk.dat", "cell.type.labels", "cell.state.labels"]
    missing = [name for name in required if name not in converted]
    if missing:
        raise KeyError(f"Missing required objects in tutorial file: {missing}")

    # NOTE(review): the .dims/.coords/.data access below presumes the converted
    # matrices are xarray-style labelled arrays — confirm against rdata's output.
    sc_da = converted["sc.dat"]
    bk_da = converted["bk.dat"]

    cell_type_labels = np.asarray(converted["cell.type.labels"], dtype=str)
    cell_state_labels = np.asarray(converted["cell.state.labels"], dtype=str)

    # The first dimension is treated as cells/samples and the second as genes.
    sc_cells = np.asarray(sc_da.coords[sc_da.dims[0]].values, dtype=str)
    sc_genes = np.asarray(sc_da.coords[sc_da.dims[1]].values, dtype=str)
    bk_samples = np.asarray(bk_da.coords[bk_da.dims[0]].values, dtype=str)
    bk_genes = np.asarray(bk_da.coords[bk_da.dims[1]].values, dtype=str)

    # One generator serves both the cell and the gene sampling, so the order of
    # the draws below determines the resulting subset.
    rng = np.random.default_rng(seed)
    state_series = pd.Series(cell_state_labels, dtype=str)
    selected_cell_idx: list[np.ndarray] = []
    for state in pd.unique(state_series):
        idx = np.flatnonzero(state_series.to_numpy() == state)
        take = min(max_cells_per_state, idx.size)
        if take == idx.size:
            # Small states are kept whole without consuming random numbers.
            picked = idx
        else:
            picked = np.sort(rng.choice(idx, size=take, replace=False))
        selected_cell_idx.append(picked)

    # Sorted so the kept cells stay in their original row order.
    selected_idx = np.sort(np.concatenate(selected_cell_idx))

    # Shared genes keep the order in which they appear in sc.dat.
    bk_gene_set = set(bk_genes.tolist())
    shared_genes = np.asarray([g for g in sc_genes if g in bk_gene_set], dtype=str)
    if shared_genes.size == 0:
        raise ValueError("No shared genes between sc.dat and bk.dat")

    if shared_genes.size > n_gene_subset:
        sampled = np.sort(rng.choice(shared_genes.size, size=n_gene_subset, replace=False))
        shared_genes = shared_genes[sampled]

    # Map gene names to column positions in each matrix.
    # NOTE(review): assumes gene names are unique within each matrix; a
    # duplicated name would make .loc return extra positions — verify.
    sc_lookup = pd.Series(np.arange(sc_genes.size), index=sc_genes)
    bk_lookup = pd.Series(np.arange(bk_genes.size), index=bk_genes)
    sc_gene_idx = sc_lookup.loc[shared_genes].to_numpy(dtype=int)
    bk_gene_idx = bk_lookup.loc[shared_genes].to_numpy(dtype=int)

    reference = pd.DataFrame(
        sc_da.data[np.ix_(selected_idx, sc_gene_idx)],
        index=sc_cells[selected_idx],
        columns=shared_genes,
    )
    mixture = pd.DataFrame(
        bk_da.data[:, bk_gene_idx],
        index=bk_samples,
        columns=shared_genes,
    )

    # Drop cells with no counts over the kept genes, and shrink selected_idx in
    # step so the label arrays returned below stay aligned with the rows.
    keep_cells = reference.sum(axis=1) > 0
    if not bool(keep_cells.all()):
        kept_idx = np.flatnonzero(keep_cells.to_numpy())
        selected_idx = selected_idx[kept_idx]
        reference = reference.loc[keep_cells].copy()

    if reference.shape[0] == 0:
        raise ValueError("No non-zero reference rows remain after subsetting")

    return {
        "reference": reference,
        "mixture": mixture,
        "cell_type_labels": cell_type_labels[selected_idx].tolist(),
        "cell_state_labels": cell_state_labels[selected_idx].tolist(),
        "expected": None,
        "objects": required,
    }


## Step 1 - Load Data and Validate Orientation


In [None]:
# Load the configured dataset and record how long loading took.
_t0 = time.perf_counter()
if MODE == "synthetic":
    bundle = load_synthetic_fixture(REPO_ROOT)
else:
    bundle = load_real_tutorial_subset(
        rdata_path=TUTORIAL_RDATA,
        max_cells_per_state=MAX_CELLS_PER_STATE,
        n_gene_subset=N_GENE_SUBSET,
        seed=SEED,
    )
TIMINGS["load_data"] = round(time.perf_counter() - _t0, 3)

reference_raw = bundle["reference"]
mixture_raw = bundle["mixture"]
cell_type_labels = [str(x) for x in bundle["cell_type_labels"]]
cell_state_labels = [str(x) for x in bundle["cell_state_labels"]]
expected_synthetic = bundle["expected"]

# Orientation and sanity checks: one label per reference row, non-empty gene
# axes, overlapping genes, and finite values only.
assert isinstance(reference_raw, pd.DataFrame), "reference must be a DataFrame (cell-by-gene)"
assert isinstance(mixture_raw, pd.DataFrame), "mixture must be a DataFrame (sample-by-gene)"
assert reference_raw.shape[0] == len(cell_type_labels) == len(cell_state_labels)
assert reference_raw.shape[1] > 0 and mixture_raw.shape[1] > 0
# Build the membership set once; constructing it inside the comprehension
# would rebuild it for every reference gene.
_mixture_gene_set = set(mixture_raw.columns)
shared_genes = [g for g in reference_raw.columns if g in _mixture_gene_set]
assert len(shared_genes) > 0, "reference and mixture must have shared genes"
assert np.isfinite(reference_raw.to_numpy(dtype=float)).all()
assert np.isfinite(mixture_raw.to_numpy(dtype=float)).all()

print("Loaded objects:", bundle["objects"])
print("reference_raw shape (cells x genes):", reference_raw.shape)
print("mixture_raw shape (samples x genes):", mixture_raw.shape)
print(
    "unique cell types:",
    len(set(cell_type_labels)),
    "| unique cell states:",
    len(set(cell_state_labels)),
)
print("shared genes:", len(shared_genes))


## Step 2 - QC Label Checks


In [None]:
# Build each label series once and reuse it for the counts and the cross-tab.
_type_label_series = pd.Series(cell_type_labels, dtype=str)
_state_label_series = pd.Series(cell_state_labels, dtype=str)

cell_type_counts = _type_label_series.value_counts()
cell_state_counts = _state_label_series.value_counts()
# Rows are cell states, columns are cell types.
state_type_table = pd.crosstab(_state_label_series, _type_label_series)

print("Cell type counts (descending):")
display(cell_type_counts.to_frame("count"))

print("Cell state count summary:")
print(cell_state_counts.describe())

print("State x Type table (head):")
display(state_type_table.head(15))


## Step 3 - QC Plot Approximations (Optional)


In [None]:
if not RUN_PLOTS:
    print("RUN_PLOTS=0: skipping QC plots.")
else:
    # Per-gene mean expression on a log1p scale; the bulk matrix is first
    # aligned to the reference's gene order, with absent genes filled as 0.
    sc_mean = np.log1p(reference_raw.mean(axis=0).to_numpy(dtype=float))
    bulk_aligned = mixture_raw.reindex(columns=reference_raw.columns, fill_value=0)
    bulk_mean = np.log1p(bulk_aligned.mean(axis=0).to_numpy(dtype=float))

    # Specificity: normalise every cell to unit total, then take each gene's
    # largest share across cells. The 1e-12 floors avoid division by zero.
    ref_values = reference_raw.to_numpy(dtype=float)
    _cell_totals = np.clip(ref_values.sum(axis=1, keepdims=True), 1e-12, None)
    ref_row_norm = ref_values / _cell_totals
    _gene_totals = np.clip(ref_row_norm.sum(axis=0, keepdims=True), 1e-12, None)
    gene_specificity = (ref_row_norm / _gene_totals).max(axis=0)

    fig, axes = plt.subplots(1, 3, figsize=(18, 4))
    _ax_counts, _ax_specificity, _ax_bulk = axes

    # Panel 1: number of cells per cell type.
    cell_type_counts.sort_values(ascending=False).plot(kind="bar", ax=_ax_counts, color="#4C78A8")
    _ax_counts.set_title("Cell type counts")
    _ax_counts.set_ylabel("Cells")

    # Panel 2: expression level against specificity, one point per gene.
    _ax_specificity.scatter(sc_mean, gene_specificity, s=8, alpha=0.5, color="#F58518")
    _ax_specificity.set_title("Mean expression vs specificity")
    _ax_specificity.set_xlabel("log1p(mean sc expression)")
    _ax_specificity.set_ylabel("Specificity")

    # Panel 3: single-cell against bulk mean expression, one point per gene.
    _ax_bulk.scatter(sc_mean, bulk_mean, s=8, alpha=0.5, color="#54A24B")
    _ax_bulk.set_title("Bulk vs sc mean expression")
    _ax_bulk.set_xlabel("log1p(mean sc expression)")
    _ax_bulk.set_ylabel("log1p(mean bulk expression)")

    fig.tight_layout()


## Step 4 - Filter Genes and Optional Marker Selection


In [None]:
_t0 = time.perf_counter()
if MODE != "synthetic":
    # Real data: drop the listed gene groups and genes expressed in too few
    # cells, then keep protein-coding genes where any are annotated.
    reference_filtered = cleanup_genes(
        input=reference_raw,
        input_type="count.matrix",
        species="hs",
        gene_group=["Rb", "Mrp", "other_Rb", "chrM", "MALAT1", "chrX", "chrY"],
        exp_cells=5,
    )
    if reference_filtered.shape[1] == 0:
        raise ValueError("cleanup_genes removed all genes")

    reference_filtered_pc = select_gene_type(
        reference_filtered,
        gene_type=["protein_coding"],
    )
    if reference_filtered_pc.shape[1] == 0:
        print("No protein_coding genes matched; falling back to cleanup-filtered matrix.")
        reference_filtered_pc = reference_filtered.copy()
else:
    # The synthetic fixture is passed through untouched so the stored
    # regression values stay comparable.
    reference_filtered = reference_raw.copy()
    reference_filtered_pc = reference_raw.copy()
    print("Synthetic mode: using fixture reference directly (cleanup/select_gene_type bypassed).")

TIMINGS["filtering"] = round(time.perf_counter() - _t0, 3)
print("reference_filtered shape:", reference_filtered.shape)
print("reference_filtered_pc shape:", reference_filtered_pc.shape)


In [None]:
_t0 = time.perf_counter()
# Defaults used when marker selection is skipped or selects nothing.
diff_exp_stat = None
reference_markers = None
reference_for_prism = reference_filtered_pc

if not RUN_DE:
    print("RUN_DE=0: skipping get_exp_stat/select_marker section.")
else:
    # Differential expression only considers genes detected in more than
    # three cells of the raw reference.
    _detected_in_enough_cells = (reference_raw > 0).sum(axis=0) > 3
    de_input = reference_raw.loc[:, _detected_in_enough_cells]
    if de_input.shape[1] > 0:
        diff_exp_stat = get_exp_stat(
            sc_dat=de_input,
            cell_type_labels=cell_type_labels,
            cell_state_labels=cell_state_labels,
            pseudo_count=0.1,
            cell_count_cutoff=50,
            n_cores=1,
        )

        reference_markers = select_marker(
            sc_dat=reference_filtered_pc,
            stat=diff_exp_stat,
            pval_max=0.01,
            lfc_min=0.1,
        )

        if reference_markers.shape[1] == 0:
            print("No marker genes selected; using reference_filtered_pc.")
        else:
            reference_for_prism = reference_markers
            print("Selected marker genes:", reference_markers.shape[1])
    else:
        print("DE input has zero genes after prevalence filter; skipping marker selection.")

TIMINGS["marker_selection"] = round(time.perf_counter() - _t0, 3)
print("reference_for_prism shape:", reference_for_prism.shape)


## Step 5 - Construct Prism and Run Deconvolution


In [None]:
_t0 = time.perf_counter()
# The real tutorial data designates "tumor" as the key cell type; the
# synthetic fixture runs without a key.
prism_key = "tumor" if MODE == "real_subset" else None
if prism_key is not None and prism_key not in set(cell_type_labels):
    raise KeyError("'tumor' is required in real_subset mode cell_type_labels")

if MODE == "synthetic":
    outlier_cut, outlier_fraction = 1.0, 1.0
else:
    outlier_cut, outlier_fraction = 0.01, 0.1

# Arguments shared by the first attempt and the fallback attempt.
_prism_inputs = {
    "reference": reference_for_prism,
    "input_type": "count.matrix",
    "cell_type_labels": cell_type_labels,
    "cell_state_labels": cell_state_labels,
    "key": prism_key,
    "mixture": mixture_raw,
}

try:
    my_prism = new_prism(
        **_prism_inputs,
        outlier_cut=outlier_cut,
        outlier_fraction=outlier_fraction,
    )
except ValueError as exc:
    # Only the "positive sums" failure in real_subset mode is retried, with
    # both outlier thresholds set to 1.0; any other error propagates.
    _retryable = MODE == "real_subset" and "reference rows must have positive sums" in str(exc)
    if not _retryable:
        raise
    outlier_cut = 1.0
    outlier_fraction = 1.0
    print("Falling back to outlier_cut=1.0 and outlier_fraction=1.0 for stability.")
    my_prism = new_prism(
        **_prism_inputs,
        outlier_cut=outlier_cut,
        outlier_fraction=outlier_fraction,
    )
TIMINGS["new_prism"] = round(time.perf_counter() - _t0, 3)

print("prism_key:", prism_key)
print("outlier_cut:", outlier_cut, "| outlier_fraction:", outlier_fraction)
print("phi_cell_state shape:", my_prism.phi_cell_state.phi.shape)
print("phi_cell_type shape:", my_prism.phi_cell_type.phi.shape)
print("mixture shape in Prism:", my_prism.mixture.shape)


In [None]:
# Fast profile: always executed; its result feeds the validation step.
_t0 = time.perf_counter()
bp_res = run_prism(
    prism=my_prism,
    n_cores=1,
    update_gibbs=True,
    gibbs_control=GIBBS_CONTROL_FAST,
    opt_control=OPT_CONTROL_FAST,
)
TIMINGS["run_prism_fast"] = round(time.perf_counter() - _t0, 3)
print("Fast run completed.")

# Heavy profile: its settings are always printed, but it only runs on request.
_heavy_opt_control = {"optimizer": "MAP", "maxit": 100000, "n_cores": 1}
print("Heavy profile template (disabled by default):")
heavy_preview = {
    "gibbs_control": GIBBS_CONTROL_HEAVY,
    "opt_control": _heavy_opt_control,
}
print(json.dumps(heavy_preview, indent=2))

if not RUN_HEAVY:
    print("RUN_HEAVY=0: skipped heavy profile execution.")
else:
    _t_heavy = time.perf_counter()
    # The heavy result is discarded; only its duration is recorded.
    _ = run_prism(
        prism=my_prism,
        n_cores=1,
        update_gibbs=True,
        gibbs_control=GIBBS_CONTROL_HEAVY,
        opt_control=_heavy_opt_control,
    )
    TIMINGS["run_prism_heavy"] = round(time.perf_counter() - _t_heavy, 3)


## Step 6 - Extract Results, Check Invariants, Emit Validation Report


In [None]:
# Final cell-type fractions, plus the expression estimate for one target cell
# type ("tumor" when present, otherwise the first column of theta).
theta = get_fraction(bp=bp_res, which_theta="final", state_or_type="type")
target_cell_name = "tumor" if "tumor" in theta.columns else str(theta.columns[0])
z_target = get_exp(bp=bp_res, state_or_type="type", cell_name=target_cell_name)

print("theta head:")
display(theta.head())
print(f"z_target head for cell '{target_cell_name}':")
display(z_target.iloc[:5, :5])

# Invariants checked in every mode; small tolerances absorb float round-off.
_theta_values = theta.to_numpy(dtype=float)
_theta_row_sums = theta.sum(axis=1).to_numpy(dtype=float)
invariants = {
    "theta_non_negative": bool((_theta_values >= -1e-12).all()),
    "theta_rowsum_close": bool(np.allclose(_theta_row_sums, 1.0, atol=5e-3, rtol=0.0)),
    "z_non_negative": bool((z_target.to_numpy(dtype=float) >= -1e-12).all()),
    "shared_genes_non_empty": bool(len(shared_genes) > 0),
    "orientation_reference_cell_by_gene": bool(reference_raw.shape[0] == len(cell_type_labels)),
    "orientation_mixture_sample_by_gene": bool(mixture_raw.shape[1] == len(shared_genes)),
}

# Synthetic mode additionally compares against the values stored with the
# fixture: phi with a tight absolute tolerance, theta with looser tolerances.
synthetic_regression: dict[str, bool] | None
if MODE != "synthetic":
    synthetic_regression = None
else:
    assert expected_synthetic is not None
    # (report key, observed frame, key into expected values, atol, rtol)
    _comparisons = (
        ("phi_cell_state_match", my_prism.phi_cell_state.phi, "phi_cell_state", 1e-12, 0.0),
        ("phi_cell_type_match", my_prism.phi_cell_type.phi, "phi_cell_type", 1e-12, 0.0),
        (
            "theta_first_type_match",
            bp_res.posterior_initial_cell_type.theta,
            "theta_first_type",
            5e-3,
            5e-2,
        ),
        ("theta_final_type_match", theta, "theta_final_type", 5e-3, 5e-2),
    )
    synthetic_regression = {
        report_key: bool(
            np.allclose(
                observed.to_numpy(dtype=float),
                expected_synthetic[expected_key],
                atol=abs_tol,
                rtol=rel_tol,
            )
        )
        for report_key, observed, expected_key, abs_tol, rel_tol in _comparisons
    }

TIMINGS["total"] = round(time.perf_counter() - NOTEBOOK_START, 3)

# Every invariant and, when present, every regression comparison must pass.
checks = [*invariants.values(), *(synthetic_regression or {}).values()]

validation_report = {
    "status": "ok" if all(checks) else "failed",
    "mode": MODE,
    "key_used": "tumor" if MODE == "real_subset" else "none",
    "target_cell_name": target_cell_name,
    "shapes": {
        "reference_raw": list(reference_raw.shape),
        "mixture_raw": list(mixture_raw.shape),
        "reference_for_prism": list(reference_for_prism.shape),
        "theta": list(theta.shape),
        "z_target": list(z_target.shape),
    },
    "invariants": invariants,
    "synthetic_regression": synthetic_regression,
    "timing_seconds": TIMINGS,
}

# The report is printed on one line behind a fixed marker prefix.
marker = "VALIDATION_REPORT_JSON="
print(marker + json.dumps(validation_report, sort_keys=True))
if validation_report["status"] != "ok":
    raise AssertionError("Validation checks failed; inspect validation_report")
