Goal. Build a UMAP embedding from QC-filtered raw counts, then use it as a common reference map for downstream interpretation.

What we check. Visualize dataset structure by Patient_ID, batchID, Tissue_type (N/T for Normal or Tumor sample), and progressively deeper annotation levels (e.g., clTopLevel, clMidwayPr, cl295v11SubFull) to confirm biological consistency and to detect potential batch effects. Malignant epithelial cells are expected to form separate "islands" on the UMAP because of tumor heterogeneity.

What we produce. A set of UMAP figures (saved as PNG if enabled) plus pathway-focused gene overlays: continuous activity maps for ligand and receptor sets, and RGB mixes of the top-expressed genes of signaling pathways that are essential in organoid cultures (EGF, FGF, WNT, WNT-RSPO, BMP, TGFb). BMP inhibitors (NOG, GREM1, GREM2) are highlighted separately.

In [None]:
import numpy as np
import pandas as pd
import scanpy as sc
from scipy import sparse
from scipy.sparse import issparse
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, to_rgb, to_hex
from matplotlib.lines import Line2D
from pathlib import Path
import warnings
from typing import Iterable, Optional, Tuple, Dict, List


In [None]:
# -----------------------------
# Global settings
# -----------------------------

# One seed for NumPy here; the same value is forwarded to PCA / neighbors /
# UMAP further down so the whole notebook is reproducible.
RANDOM_STATE = 123
np.random.seed(RANDOM_STATE)

# Scanpy: verbosity 0=errors, 1=warnings, 2=info, 3=hints.
sc.settings.verbosity = 2
sc.settings.set_figure_params(dpi=100, frameon=True, fontsize=10)

# Matplotlib: compact inline figures, high-resolution tightly-cropped exports.
plt.rcParams["figure.figsize"] = (6, 4)
plt.rcParams["axes.grid"] = False
plt.rcParams["savefig.dpi"] = 300
plt.rcParams["savefig.bbox"] = "tight"

# Show each distinct user/runtime warning only once to keep output readable.
warnings.filterwarnings("once", category=UserWarning)
warnings.filterwarnings("once", category=RuntimeWarning)


In [None]:
## Input / Output paths and embedding parameters

# -----------------------------
# Project directories (edit PROJECT_DIR; everything else derives from it)
# -----------------------------
PROJECT_DIR = Path("")
QC_OUT_DIR = PROJECT_DIR / "outputs"  # where the QC notebook saved its .h5ad files

# -----------------------------
# Input AnnData files
# -----------------------------
ADATA_PATH = QC_OUT_DIR / "GSE178341_filtered.h5ad"
ADATA_NORMAL_PATH = QC_OUT_DIR / "GSE178341_normal.h5ad"
ADATA_TUMOR_PATH = QC_OUT_DIR / "GSE178341_tumor.h5ad"

# -----------------------------
# Output directories for this notebook (created if missing)
# -----------------------------
OUT_DIR = PROJECT_DIR / "outputs_umap"
FIG_DIR = OUT_DIR / "figures"

for _out_dir in (OUT_DIR, FIG_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)

# True: write PNGs to FIG_DIR (and show); False: only show inline.
SAVE_FIGS = True


# -----------------------------
# Embedding parameters
# -----------------------------
N_TOP_GENES = 4000
N_PCS = 50
N_NEIGHBORS = 30
MIN_DIST = 0.3              # None -> keep Scanpy's default min_dist
UMAP_RANDOM_STATE = RANDOM_STATE


# -----------------------------
# Quick sanity print
# -----------------------------
print("Input AnnData:")
for _label, _path in (
    ("  full   :", ADATA_PATH),
    ("  normal :", ADATA_NORMAL_PATH),
    ("  tumor  :", ADATA_TUMOR_PATH),
):
    print(_label, _path)
print()
print("Output dir:", OUT_DIR.resolve())
print("Save figures:", SAVE_FIGS)
print()
print("Embedding parameters:")
for _param_name, _param_value in (
    ("N_TOP_GENES", N_TOP_GENES),
    ("N_PCS", N_PCS),
    ("N_NEIGHBORS", N_NEIGHBORS),
    ("MIN_DIST", MIN_DIST),
    ("RANDOM_STATE", UMAP_RANDOM_STATE),
):
    # Names are left-padded to 14 columns so the "=" signs line up.
    print(f"  {_param_name:<14}= {_param_value}")


In [None]:
## Load QC-filtered AnnData and sanity checks

# -----------------------------
# Load the QC-filtered AnnData
# -----------------------------
adata = sc.read_h5ad(ADATA_PATH)

print("AnnData loaded.")
print(f"Shape: {adata.n_obs:,} cells × {adata.n_vars:,} genes")


# -----------------------------
# Required metadata columns: fail fast, listing every missing one at once
# -----------------------------
required_obs_cols = [
    "Patient_ID",
    "Tissue_type",
    "batchID",        # adjust if the batch column has a different name
    "clTopLevel",
]

missing_cols = [name for name in required_obs_cols if name not in adata.obs.columns]
if missing_cols:
    bullet_list = "\n".join(f"  - {name}" for name in missing_cols)
    raise KeyError("Missing required columns in adata.obs:\n" + bullet_list)

print("\nRequired obs columns present:")
print("\n".join(f"  ✓ {name}" for name in required_obs_cols))


# -----------------------------
# Optional: report deeper annotation levels when the dataset has them
# -----------------------------
candidate_ann_cols = ["clMidwayPr", "cl295v11SubFull"]

present_extra = [name for name in candidate_ann_cols if name in adata.obs.columns]
if not present_extra:
    print("\nNo additional mid/deep annotation columns detected.")
else:
    print("\nAdditional annotation columns detected:")
    for name in present_extra:
        print(f"  - {name}")


# -----------------------------
# Existing embedding slots
# -----------------------------
print("\nobsm keys:", list(adata.obsm.keys()))

if "X_umap" not in adata.obsm:
    print("✓ No UMAP embedding found (will be computed later).")
else:
    print("UMAP embedding already present in adata.obsm['X_umap'].")
    print("   This notebook will recompute UMAP unless you explicitly skip that step.")


In [None]:
## Dataset-specific: column mapping for plotting (EDIT HERE for another dataset)

# -----------------------------
# Core metadata columns in adata.obs: patient, tissue (N/T), batch/sample.
# Change the batch entry if your dataset uses a different batch/sample column.
# -----------------------------
PATIENT_COL, TISSUE_COL, BATCH_COL = "Patient_ID", "Tissue_type", "batchID"

# -----------------------------
# Annotation levels drawn on the UMAP, ordered coarse -> fine
# -----------------------------
ANNOT_LEVELS = [
    "clTopLevel",
    "clMidwayPr",
    "cl295v11SubFull",
]


# -----------------------------
# Compact summaries (top 10)
# -----------------------------
def _print_top_counts(col, top=10):
    """Print a header and display the `top` most frequent values of adata.obs[col].

    Reads the notebook-level ``adata``. NaN is counted as its own value
    (``dropna=False``). If the column does not exist, a warning line is
    printed instead. Always returns None.
    NOTE(review): relies on ``display`` being available (IPython/Jupyter).
    """
    if col in adata.obs.columns:
        counts = adata.obs[col].value_counts(dropna=False)
        n_shown = min(top, len(counts))
        print(f"\nValue counts: {col} (top {n_shown} / {len(counts)})")
        display(counts.head(top))
        return
    print(f"\n[WARN] Column '{col}' not found in adata.obs")

# Top-10 value counts for each core metadata column.
for _meta_col in (PATIENT_COL, TISSUE_COL, BATCH_COL):
    _print_top_counts(_meta_col, top=10)

# Which of the requested annotation levels exist in this dataset.
print("\nAnnotation columns present:")
for c in ANNOT_LEVELS:
    is_present = "YES" if c in adata.obs.columns else "NO"
    print(f"  - {c}: {is_present}")


In [None]:
## Prep for embedding (matrix types + minimal stabilization)

# -----------------------------
# 1) Ensure CSR + float32
# -----------------------------
# One predictable layout for every downstream step: row-compressed sparse,
# float32 values. Dense input is converted to CSR.
if sparse.issparse(adata.X):
    if not isinstance(adata.X, sparse.csr_matrix):
        adata.X = adata.X.tocsr()
    if adata.X.dtype != np.float32:
        adata.X = adata.X.astype(np.float32)
else:
    adata.X = sparse.csr_matrix(adata.X.astype(np.float32))

print(f"X ready for embedding: {type(adata.X).__name__}, dtype={adata.X.dtype}")

# -----------------------------
# 2) Minimal stabilization filters
# -----------------------------
# Drop cells with zero total counts and genes detected in fewer than 3 cells.
n0_cells, n0_genes = adata.n_obs, adata.n_vars

sc.pp.filter_cells(adata, min_counts=1)
sc.pp.filter_genes(adata, min_cells=3)

# Report "before -> after" for cells and genes.
print("After basic filters:")
print(f"  cells: {n0_cells:,} -> {adata.n_obs:,}")
print(f"  genes: {n0_genes:,} -> {adata.n_vars:,}")

# -----------------------------
# 3) Drop zero-variance genes (safety guard)
# -----------------------------
# Var = E[x^2] - E[x]^2. Means are accumulated in float64 because this
# subtraction cancels catastrophically in float32 for highly expressed genes.
if sparse.issparse(adata.X):
    means = np.asarray(adata.X.mean(axis=0, dtype=np.float64)).ravel()
    ex2 = np.asarray(adata.X.power(2).mean(axis=0, dtype=np.float64)).ravel()
    vars_ = ex2 - means**2
else:
    means = adata.X.mean(axis=0, dtype=np.float64)
    vars_ = adata.X.var(axis=0, dtype=np.float64)

nonzero_var = vars_ > 0
n_removed = int((~nonzero_var).sum())

if n_removed > 0:
    adata = adata[:, nonzero_var].copy()
    print(f"Removed {n_removed:,} zero-variance genes.")
else:
    print("No zero-variance genes detected.")

# -----------------------------
# 4) IMPORTANT: adata.X is still raw counts here
# -----------------------------
# adata.X is normalized/log-transformed later, for the embedding only.
# Gene overlays must therefore read raw counts (the overlay helpers in this
# notebook use adata.raw) and apply TP10K/log1p on the fly.
print("\nNote: adata.X is treated as RAW COUNTS at this stage.")


In [None]:
# -----------------------------
# Freeze raw expression for gene overlays
# -----------------------------
# State at this point:
# - adata.X holds raw counts (after QC and the minimal filters above);
# - the embedding cells below replace adata.X with log-normalized HVGs;
# - adata.raw keeps the FULL gene set so overlays can read any gene.

# Convenience raw-counts layer, created only if none exists yet.
# NOTE: layers are subset together with adata at the HVG step, so after it
# this layer covers HVGs only; adata.raw remains the full-gene source.
if "counts" not in adata.layers:
    adata.layers["counts"] = adata.X.copy()

# Snapshot of the full matrix, taken BEFORE HVG subsetting.
adata.raw = adata.copy()

print(f"Raw expression frozen in adata.raw: {adata.raw.n_obs:,} cells × {adata.raw.n_vars:,} genes")


In [None]:
## HVG selection (simple, single method)

# -----------------------------
# Highly variable genes (Seurat v3 flavor, computed on the raw counts in adata.X)
# -----------------------------
sc.pp.highly_variable_genes(
    adata,
    n_top_genes=int(N_TOP_GENES),
    flavor="seurat_v3",
    batch_key=None,
    inplace=True,
)

if "highly_variable" not in adata.var.columns:
    raise RuntimeError("HVG selection failed: adata.var['highly_variable'] not found.")

hvg_mask = adata.var["highly_variable"]
n_hvg = int(hvg_mask.sum())
print(f"HVG selected: {n_hvg:,} genes (target: {int(N_TOP_GENES):,})")

if n_hvg == 0:
    raise RuntimeError("No HVGs selected - check input matrix and preprocessing.")

# -----------------------------
# Keep only the HVGs for the embedding (adata.raw still has every gene)
# -----------------------------
adata = adata[:, hvg_mask].copy()
print(f"After HVG subset: {adata.n_obs:,} cells × {adata.n_vars:,} HVGs")


In [None]:
## Normalize + log1p for embedding (HVGs only)

# -----------------------------
# TP10K library-size normalization followed by log1p, in place on adata.X
# -----------------------------
# NOTE(review): this runs after the HVG subset, so each cell's library size
# here is its HVG-only total. The overlay helpers instead normalize by the
# full-gene total from adata.raw. Confirm this difference is intended.
sc.pp.normalize_total(adata, target_sum=1e4, inplace=True)
sc.pp.log1p(adata)

for _msg in (
    "Normalization and log1p completed.",
    "adata.X now contains log-normalized HVG expression (for embedding only).",
):
    print(_msg)


In [None]:
## PCA -> neighbors -> UMAP

# -----------------------------
# PCA on the log-normalized HVG matrix (ARPACK solver, seeded)
# -----------------------------
sc.tl.pca(adata, n_comps=int(N_PCS), svd_solver="arpack", random_state=RANDOM_STATE)
print(f"PCA computed with {int(N_PCS)} components.")

# -----------------------------
# kNN graph on the first N_PCS components
# -----------------------------
sc.pp.neighbors(
    adata, n_neighbors=int(N_NEIGHBORS), n_pcs=int(N_PCS), random_state=RANDOM_STATE
)
print(f"Neighbors graph computed (n_neighbors={int(N_NEIGHBORS)}, n_pcs={int(N_PCS)}).")

# -----------------------------
# UMAP: min_dist is only forwarded when set, otherwise Scanpy's default applies
# -----------------------------
umap_kwargs = dict(random_state=int(UMAP_RANDOM_STATE))
if MIN_DIST is not None:
    umap_kwargs.update(min_dist=float(MIN_DIST))

sc.tl.umap(adata, **umap_kwargs)

for _line in (
    "UMAP embedding computed.",
    "UMAP coordinates available in adata.obsm['X_umap'].",
):
    print(_line)

# -----------------------------
# Optional: persist AnnData with the embedding
# -----------------------------
# Usually unnecessary (the embedding is quick to recompute); uncomment to save.
#
# OUT_UMAP_PATH = OUT_DIR / "GSE178341_umap.h5ad"
# adata.write(OUT_UMAP_PATH)
# print(f"AnnData with UMAP saved to: {OUT_UMAP_PATH.resolve()}")


In [None]:
## UMAP sanity check: patient / batch / tissue

# -----------------------------
# Sanity plots
# -----------------------------
# Colour the global embedding by patient, batch and tissue so that
# patient-, batch- or tissue-driven separation can be spotted. BATCH_COL is
# part of the list because a batch check is the point of this section.
# dict.fromkeys de-duplicates while keeping order, in case two mappings
# point at the same column.
umap_cols = list(dict.fromkeys([PATIENT_COL, BATCH_COL, TISSUE_COL]))

for col in umap_cols:
    if col not in adata.obs.columns:
        print(f"[WARN] Column '{col}' not found in adata.obs - skipping.")
        continue

    # With SAVE_FIGS the figure is kept open (show=False), saved, then shown.
    sc.pl.umap(
        adata,
        color=col,
        frameon=True,
        show=not SAVE_FIGS,
    )

    if SAVE_FIGS:
        fname = f"umap_{col}.png"
        plt.savefig(FIG_DIR / fname)
        plt.show()

print("UMAP sanity check completed.")
print("Inspect plots for strong patient / batch / tissue-driven separation.")


In [None]:
## UMAP: cell-type annotations on the full dataset (reference view)

# Plan: the coarse level first (labels drawn on the clusters), then every
# deeper level from ANNOT_LEVELS with a side legend. Each entry carries its
# own "missing column" message.
_annotation_plots = [
    (
        "clTopLevel",
        {"legend_loc": "on data"},
        "[WARN] 'clTopLevel' not found in adata.obs - skipping.",
    )
] + [
    (level, {}, f"[WARN] Annotation column '{level}' not found - skipping.")
    for level in ANNOT_LEVELS
    if level != "clTopLevel"
]

for col, extra_kwargs, missing_msg in _annotation_plots:
    if col not in adata.obs.columns:
        print(missing_msg)
        continue

    sc.pl.umap(adata, color=col, frameon=True, show=not SAVE_FIGS, **extra_kwargs)

    if SAVE_FIGS:
        plt.savefig(FIG_DIR / f"umap_{col}.png")
        plt.show()

print("UMAP annotation overview completed.")
print("These plots serve as a reference map of cell identities on the global embedding.")


In [None]:
## Subset views: Normal vs Tumor on the same embedding

# -----------------------------
# Tissue-specific views (no new files, same embedding)
# -----------------------------
# adata[...] slices of the global object: they keep the global X_umap
# coordinates, so Normal and Tumor panels share one coordinate system.
if TISSUE_COL not in adata.obs.columns:
    raise KeyError(f"Tissue column '{TISSUE_COL}' not found in adata.obs")

adata_N, adata_T = (adata[adata.obs[TISSUE_COL] == code] for code in ("N", "T"))

print(f"Normal cells (N): {adata_N.n_obs:,}")
print(f"Tumor  cells (T): {adata_T.n_obs:,}")

# -----------------------------
# Helper: UMAP plot with explicit title
# -----------------------------
def _umap_with_title(ad, color, title):
    """Plot ``ad`` on its UMAP coloured by ``color``, with an explicit title.

    The title is handed to ``sc.pl.umap`` itself, so it appears in both
    modes. Setting it afterwards with ``plt.title`` does not work when
    SAVE_FIGS is False: scanpy has already shown the figure, and the title
    lands on a fresh, empty figure instead.

    When SAVE_FIGS is True the figure is written to FIG_DIR as
    ``umap_<title>_<color>.png``. Spaces in the title become "_", and "/"
    is replaced in both title and colour so it cannot create a subdirectory
    path. The figure is then shown.
    """
    sc.pl.umap(
        ad,
        color=color,
        frameon=True,
        title=title,
        show=not SAVE_FIGS,
    )
    if SAVE_FIGS:
        safe = str(color).replace("/", "_")
        safe_title = str(title).replace(" ", "_").replace("/", "_")
        fname = f"umap_{safe_title}_{safe}.png"
        plt.savefig(FIG_DIR / fname)
        plt.show()

# -----------------------------
# UMAP: Normal vs Tumor, coarse annotation first, then deeper levels
# -----------------------------
# Every annotation level is drawn once per tissue, Normal before Tumor.
_tissue_views = (("Normal tissue (N)", adata_N), ("Tumor tissue (T)", adata_T))

if "clTopLevel" not in adata.obs.columns:
    print("[WARN] 'clTopLevel' not found - skipping coarse annotation.")
else:
    for tissue_label, tissue_view in _tissue_views:
        _umap_with_title(
            tissue_view,
            color="clTopLevel",
            title=f"{tissue_label} - clTopLevel",
        )

# Optional deeper levels: silently skipped when absent from this dataset.
for col in ANNOT_LEVELS:
    if col == "clTopLevel" or col not in adata.obs.columns:
        continue
    for tissue_label, tissue_view in _tissue_views:
        _umap_with_title(tissue_view, color=col, title=f"{tissue_label} - {col}")

print("Normal vs Tumor UMAP views completed (shared embedding, labeled).")


In [None]:
## (Optional) Separate embeddings: UMAP within Normal and within Tumor

# NOTE:
# This step is OPTIONAL and should be used only if you want to explore
# structure WITHIN Normal or Tumor separately.
# IMPORTANT: these embeddings live in different coordinate systems
# and MUST NOT be directly compared to the global UMAP.

def compute_within_condition_umap(
    adata_in,
    label,
    n_top_genes=N_TOP_GENES,
    n_pcs=N_PCS,
    n_neighbors=N_NEIGHBORS,
    min_dist=MIN_DIST,
    random_state=RANDOM_STATE,
    use_raw=True,
):
    """
    Compute a standalone UMAP embedding for a subset of cells.

    Runs the same recipe as the global embedding (seurat_v3 HVGs -> TP10K ->
    log1p -> PCA -> kNN -> UMAP), fitted on ``adata_in``'s cells only. The
    resulting coordinates are NOT comparable with the global UMAP.

    Parameters
    ----------
    adata_in : AnnData
        Cells to embed, e.g. ``adata_N`` / ``adata_T``. Not modified.
    label : str
        Used only in the progress message.
    n_top_genes, n_pcs, n_neighbors, min_dist, random_state
        Same meaning as the notebook-level embedding parameters;
        ``min_dist=None`` keeps Scanpy's default.
    use_raw : bool, default True
        Start from ``adata_in.raw`` (full gene set, raw counts) when it is
        present. The subsets built in this notebook are slices of the global
        object, whose ``.X`` already holds log-normalized HVG expression.
        Re-running seurat_v3 HVG selection (which expects counts) and
        normalize/log1p on that matrix would transform the data twice and
        restrict the gene search to the global HVGs. Pass False to use
        ``adata_in.X`` as-is.

    Returns
    -------
    AnnData
        A new object with its own HVG subset, PCA, neighbors graph and UMAP.
    """
    if use_raw and adata_in.raw is not None:
        # Rebuild a counts-based AnnData from the frozen full-gene matrix.
        # NOTE(review): Raw.to_adata() is expected to carry over obs (needed
        # for colouring by annotations) - confirm for the installed anndata.
        ad = adata_in.raw.to_adata()
    else:
        ad = adata_in.copy()

    # -----------------------------
    # HVG selection (standalone, on counts)
    # -----------------------------
    sc.pp.highly_variable_genes(
        ad,
        flavor="seurat_v3",
        n_top_genes=int(n_top_genes),
        batch_key=None,
        inplace=True,
    )
    ad = ad[:, ad.var["highly_variable"]].copy()

    # -----------------------------
    # Normalize + log1p
    # -----------------------------
    sc.pp.normalize_total(ad, target_sum=1e4)
    sc.pp.log1p(ad)

    # -----------------------------
    # PCA -> neighbors -> UMAP
    # -----------------------------
    sc.tl.pca(
        ad,
        n_comps=int(n_pcs),
        svd_solver="arpack",
        random_state=int(random_state),
    )

    sc.pp.neighbors(
        ad,
        n_neighbors=int(n_neighbors),
        n_pcs=int(n_pcs),
        random_state=int(random_state),
    )

    umap_kwargs = {"random_state": int(random_state)}
    if min_dist is not None:
        umap_kwargs["min_dist"] = float(min_dist)

    sc.tl.umap(ad, **umap_kwargs)

    print(f"Within-condition UMAP computed for: {label} ({ad.n_obs:,} cells)")
    return ad


# -----------------------------
# Run separate embeddings (only if needed)
# -----------------------------
RUN_SEPARATE_UMAPS = False  # set True if you want within-condition UMAPs

if not RUN_SEPARATE_UMAPS:
    print("Separate within-condition UMAPs skipped (RUN_SEPARATE_UMAPS=False).")
else:
    adata_N_umap = compute_within_condition_umap(adata_N, label="Normal (N)")
    adata_T_umap = compute_within_condition_umap(adata_T, label="Tumor (T)")

    # Example visualization (coarse annotation): one panel per condition.
    for sub_ad, panel_title, out_name in (
        (
            adata_N_umap,
            "Normal tissue (N) - within-condition UMAP",
            "umap_within_normal_clTopLevel.png",
        ),
        (
            adata_T_umap,
            "Tumor tissue (T) - within-condition UMAP",
            "umap_within_tumor_clTopLevel.png",
        ),
    ):
        if "clTopLevel" not in sub_ad.obs.columns:
            continue
        sc.pl.umap(
            sub_ad,
            color="clTopLevel",
            frameon=True,
            title=panel_title,
            show=not SAVE_FIGS,
        )
        if SAVE_FIGS:
            plt.savefig(FIG_DIR / out_name)
            plt.show()


In [None]:
## Gene sets: pathways (ligands / receptors / inhibitors) essential for organoid cultivation.

# -----------------------------
# Define pathway gene sets
# -----------------------------
# NOTE:
# All gene symbols are expected to be HGNC-style (uppercase).
# This block is the single point of truth for pathway definitions.
#
# Layout: PATHWAYS[pathway][role] -> list of gene symbols, where role is
#   "L" = ligands, "R" = receptors, "I" = inhibitors (only BMP defines "I").
# Symbols absent from the dataset are listed as "missing" by
# check_pathway_genes below.

PATHWAYS = {
    "EGF": {
        "L": [
            "EGF", "HBEGF", "AREG", "EREG", "EPGN",
            "TGFA", "BTC",
            "NRG1", "NRG2", "NRG3", "NRG4",
        ],
        "R": ["EGFR", "ERBB2", "ERBB3", "ERBB4"],
    },
    "FGF": {
        # FGF1..FGF10 and FGF16..FGF23 (range end is exclusive).
        "L": [f"FGF{i}" for i in range(1, 11)] + [f"FGF{i}" for i in range(16, 24)],
        # NOTE(review): "HSPG" does not look like a single gene symbol (perhaps
        # a heparan-sulfate-proteoglycan placeholder, e.g. HSPG2); it will
        # likely be reported as missing - confirm intent.
        "R": ["FGFR1", "FGFR2", "FGFR3", "FGFR4", "HSPG"],
    },
    "WNT": {
        # WNT1..WNT11 plus the lettered/extra paralogs.
        "L": (
            [f"WNT{i}" for i in range(1, 12)]
            + ["WNT2B", "WNT3A", "WNT5B", "WNT7B", "WNT8B", "WNT9B", "WNT10B", "WNT16"]
        ),
        # FZD1..FZD10 and the LRP5/6 co-receptors.
        "R": [f"FZD{i}" for i in range(1, 11)] + ["LRP5", "LRP6"],
    },
    "RSPO": {
        "L": ["RSPO1", "RSPO2", "RSPO3", "RSPO4"],
        "R": ["LGR4", "LGR5", "LGR6"],
    },
    "BMP": {
        # BMP2..BMP15 plus explicitly listed names.
        # NOTE(review): some generated names (e.g. BMP8, BMP9, BMP11-BMP14) and
        # "BMP3B" are presumably aliases rather than current HGNC symbols, so
        # expect them among the missing genes - confirm.
        "L": (
            [f"BMP{i}" for i in range(2, 16)]
            + ["BMP3B", "BMP8A", "BMP8B", "GDF2", "GDF5", "GDF6", "GDF7"]
        ),
        "R": ["ACVR1", "BMPR1A", "BMPR1B", "BMPR2", "ACVR2A", "ACVR2B"],
        "I": ["NOG", "GREM1", "GREM2"],  # inhibitors
    },
    "TGFB": {
        "L": ["TGFB1", "TGFB2", "TGFB3"],
        "R": ["TGFBR1", "TGFBR2", "TGFBR3"],
    },
}

# Summary: each pathway with the roles it defines.
print("Defined signaling pathways:")
for k, v in PATHWAYS.items():
    parts = ", ".join(v.keys())
    print(f"  - {k}: {parts}")


# -----------------------------
# Optional: check gene presence in the dataset
# -----------------------------
def check_pathway_genes(adata, pathways, verbose=True):
    """
    Report which pathway genes exist in the full (pre-HVG) gene set.

    Both the requested symbols and ``adata.raw.var_names`` are upper-cased
    before comparison, so matching is case-insensitive.

    Parameters
    ----------
    adata : AnnData
        Must have ``adata.raw`` set.
    pathways : dict
        ``{pathway: {role: [gene, ...]}}``, e.g. PATHWAYS.
    verbose : bool
        Print a per-role summary and the missing symbols.

    Returns
    -------
    dict
        ``{pathway: {role: {"found": [...], "missing": [...]}}}``. Symbols are
        upper-cased and keep their input order.

    Raises
    ------
    RuntimeError
        If ``adata.raw`` is None.
    """
    if adata.raw is None:
        raise RuntimeError("adata.raw is None - raw expression was not saved before HVG.")

    # Set lookup: one hash probe per requested gene.
    available = set(adata.raw.var_names.str.upper())
    report = {}

    for pname, blocks in pathways.items():
        report[pname] = role_report = {}
        for role, genes in blocks.items():
            wanted = [g.upper() for g in genes]
            found = [g for g in wanted if g in available]
            missing = [g for g in wanted if g not in available]
            role_report[role] = {"found": found, "missing": missing}

            if not verbose:
                continue
            print(f"[{pname}:{role}] {len(found)}/{len(wanted)} genes present")
            if missing:
                print("   missing:", ", ".join(missing))

    return report


# One-off sanity check of gene availability. The per-role found/missing
# report is kept in pathway_gene_presence for later inspection.
pathway_gene_presence = check_pathway_genes(adata, PATHWAYS, verbose=True)
print("Pathway gene presence check completed.")


In [None]:
# ============================================================
# Utility: gene overlays on UMAP (continuous + RGB mix)
# ============================================================
# IMPORTANT DESIGN CHOICE:
# - ALL overlays read expression from adata.raw (full gene set)
# - adata.X may contain HVG-only log-normalized data for embedding
# - adata.raw must be set BEFORE HVG selection
#
# Scaling for overlays:
#   TP10K -> log1p -> p99 clip (robust, interpretable)
#
# These helpers DO NOT mutate adata.
# ============================================================


# -----------------------------
# Internal helpers
# -----------------------------
def _get_matrix_and_varnames(adata, layer: str):
    """
    Return (X, var_names) for expression source.
    For overlays, 'raw' and 'counts' are treated identically.
    """
    layer = str(layer)

    if layer in ("raw", "counts"):
        if adata.raw is None:
            raise RuntimeError("adata.raw is None - full expression was not saved.")
        return adata.raw.X, adata.raw.var_names

    if layer == "X":
        return adata.X, adata.var_names

    raise ValueError(
        f"Unsupported layer='{layer}'. "
        f"Use 'raw' (recommended) or 'X'."
    )


def _match_genes_case_insensitive(var_names, genes: Iterable[str]):
    """
    Case-insensitive matching of requested genes.
    Returns (found_as_in_var_names, missing_requested).
    """
    genes = [str(g).strip() for g in genes]
    vmap = pd.Series(var_names.values, index=var_names.str.upper())

    found, missing = [], []
    for g in genes:
        hit = vmap.get(g.upper())
        if hit is None:
            missing.append(g)
        else:
            found.append(hit)
    return found, missing


def _require_umap(adata):
    if "X_umap" not in adata.obsm:
        raise RuntimeError("UMAP coordinates not found: adata.obsm['X_umap']")
    return adata.obsm["X_umap"][:, 0], adata.obsm["X_umap"][:, 1]


# ============================================================
# (A) Continuous overlay
# ============================================================
def umap_gene_expression_continuous(
    adata,
    genes: Iterable[str],
    layer: str = "raw",
    aggregate: str = "sum",            # sum | mean | max
    target_sum: float = 1e4,           # TP10K
    use_log1p: bool = True,
    clip_p99: bool = True,
    cmap: str = "cividis",
    size: float = 6,
    alpha: float = 0.9,
    bg_zero: Optional[str] = None,
    title: Optional[str] = None,
    colorbar_label: str = "log1p(TP10K)",
    save_path: Optional[Path] = None,
    show: bool = True,
):
    """
    Continuous UMAP overlay for one gene or an aggregated gene set.

    For each cell, the raw values of the matched genes are combined with
    ``aggregate``. The result is scaled to ``target_sum`` by the cell's
    library size (sum over ALL genes of the chosen source) and, by default,
    log1p-transformed. The colour scale spans 0 to the 99th percentile of
    the non-zero values (``clip_p99``), or to their maximum.

    Parameters
    ----------
    adata : AnnData
        Needs ``obsm['X_umap']``; for ``layer='raw'`` also ``adata.raw``.
    genes : iterable of str
        Symbols matched case-insensitively; unknown ones are printed and ignored.
    layer : {'raw', 'counts', 'X'}
        Expression source ('raw' and 'counts' both read ``adata.raw``).
    aggregate : {'sum', 'mean', 'max'}
        How several genes are combined per cell.
    bg_zero : str or None
        If given, cells with zero signal are drawn in this flat colour
        underneath the expressing cells. If None, every cell is coloured by
        the colormap, so zeros get its lowest colour.
    save_path : path-like or None
        Output file (400 dpi); parent directories are created.
    show : bool
        Show the figure, otherwise close it.

    Returns
    -------
    (fig, ax, raw)
        ``raw`` is the per-cell aggregated value before library-size scaling.

    Raises
    ------
    ValueError
        If ``aggregate`` is not recognised or none of ``genes`` is found.
    """
    if aggregate not in ("sum", "mean", "max"):
        raise ValueError(
            f"Unsupported aggregate='{aggregate}'. Use 'sum', 'mean' or 'max'."
        )

    X, var_names = _get_matrix_and_varnames(adata, layer)
    x_umap, y_umap = _require_umap(adata)

    found, missing = _match_genes_case_insensitive(var_names, genes)
    if missing:
        print("Missing genes:", ", ".join(missing))
    if not found:
        raise ValueError("None of the requested genes were found.")

    idxs = [var_names.get_loc(g) for g in found]

    # --- aggregate raw values of the selected genes per cell ---
    # For a single gene sum/mean/max coincide, so no special case is needed.
    if issparse(X):
        Xsub = X[:, idxs]
        if aggregate == "mean":
            raw = np.asarray(Xsub.mean(axis=1)).ravel()
        elif aggregate == "max":
            # Sparse .max(axis=1) returns a sparse matrix; np.asarray() on it
            # would yield a 0-d object array, so densify explicitly.
            raw = Xsub.max(axis=1).toarray().ravel()
        else:
            raw = np.asarray(Xsub.sum(axis=1)).ravel()
        lib = np.asarray(X.sum(axis=1)).ravel()
    else:
        Xd = np.asarray(X)
        Xsub = Xd[:, idxs]
        if aggregate == "mean":
            raw = Xsub.mean(axis=1)
        elif aggregate == "max":
            raw = Xsub.max(axis=1)
        else:
            raw = Xsub.sum(axis=1)
        lib = Xd.sum(axis=1)

    # Guard against empty cells (library size 0) before dividing.
    lib = np.clip(lib, 1.0, None)

    tp = raw * (target_sum / lib)
    vals = np.log1p(tp) if use_log1p else tp

    pos = vals > 0
    if clip_p99 and np.any(pos):
        vmax = float(np.percentile(vals[pos], 99))
    else:
        vmax = float(vals[pos].max()) if np.any(pos) else 1.0
    vmax = max(vmax, 1e-9)

    norm = Normalize(vmin=0.0, vmax=vmax, clip=True)
    colors = plt.get_cmap(cmap)(norm(vals))

    fig, ax = plt.subplots(figsize=(8.6, 7))

    if bg_zero is None:
        ax.scatter(x_umap, y_umap, s=size, c=colors, lw=0, alpha=alpha)
    else:
        # Zero-signal cells first in the flat background colour, then ONLY the
        # expressing cells on top. Colouring every cell with the colormap at
        # this point would paint over the background layer entirely.
        if np.any(~pos):
            ax.scatter(x_umap[~pos], y_umap[~pos], s=size, c=bg_zero, lw=0, alpha=alpha)
        if np.any(pos):
            ax.scatter(x_umap[pos], y_umap[pos], s=size, c=colors[pos], lw=0, alpha=alpha)
    ax.axis("off")

    auto_title = f"{', '.join(found)} - continuous overlay ({aggregate})"
    ax.set_title(title or auto_title)

    sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label(colorbar_label, rotation=90)

    if save_path is not None:
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=400, bbox_inches="tight")

    if show:
        plt.show()
    else:
        plt.close(fig)

    return fig, ax, raw


# ============================================================
# (B) RGB mix overlay
# ============================================================
def umap_multi_gene_color_mix(
    adata,
    genes: Iterable[str],
    layer: str = "raw",
    thr_umi: float = 1.0,
    target_sum: float = 1e4,
    use_log_scale: bool = True,
    clip_p99: bool = True,
    size: float = 8,
    bg_color: str = "#d9d9d9",
    palette: Optional[Dict[str, str]] = None,
    alpha_zero: float = 1.0,
    alpha_pos_min: float = 0.4,
    alpha_pos_max: float = 0.9,
    title: Optional[str] = None,
    show_strength_bar: bool = True,
    strength_cmap: str = "Greys",
    strength_label: str = "TP10K (p99)",
    save_path: Optional[Path] = None,
    show: bool = True,
):
    """
    RGB-mix UMAP overlay for small gene sets (3-5 genes).

    Each gene gets a colour. A cell's colour is the expression-weighted
    average of the colours of the genes it expresses, and its opacity grows
    with the total signal. Expression is read from ``adata.raw`` by default.

    Parameters
    ----------
    adata : AnnData
        Needs ``obsm['X_umap']``; for ``layer='raw'`` also ``adata.raw``.
    genes : iterable of str
        Symbols matched case-insensitively; unknown ones are printed and ignored.
    thr_umi : float
        Per-gene threshold applied AFTER library-size scaling (i.e. in
        ``target_sum``-normalized units, not raw UMIs). Values below it are
        treated as 0.
    bg_color : str
        Any Matplotlib colour for cells without signal (named colours work).
    palette : dict or None
        Gene -> colour. Keys are matched exactly first, then case-insensitively;
        unmatched genes get grey. None uses a built-in 10-colour cycle.
    alpha_zero, alpha_pos_min, alpha_pos_max : float
        Opacity for signal-free cells, and the range used for expressing cells
        as a function of normalized signal strength.
    use_log_scale, clip_p99 : bool
        Strength = log1p(total) / log1p(t99) (or total / t99), where t99 is
        the 99th percentile (or maximum) of the positive totals.
    show_strength_bar : bool
        Add a 0-1 colourbar labelled ``strength_label``.
    save_path : path-like or None
        Output file (400 dpi); parent directories are created.
    show : bool
        Show the figure, otherwise close it.

    Returns
    -------
    (fig, ax, total)
        ``total`` is the per-cell sum of thresholded normalized expression.

    Raises
    ------
    ValueError
        If none of ``genes`` is found.
    """
    X, var_names = _get_matrix_and_varnames(adata, layer)
    x_umap, y_umap = _require_umap(adata)

    found, missing = _match_genes_case_insensitive(var_names, genes)
    if missing:
        print("Missing genes:", ", ".join(missing))
    if not found:
        raise ValueError("None of the requested genes were found.")

    if palette is None:
        tableau10 = [
            "#ff0000", "#1f77b4", "#2ca02c", "#ff7f0e", "#9467bd",
            "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf",
        ]
        palette = {g: tableau10[i % len(tableau10)] for i, g in enumerate(found)}
    else:
        # `found` is spelled as in var_names, which may differ in case from the
        # caller's palette keys: try the exact key, then a case-insensitive one.
        by_upper = {str(k).upper(): v for k, v in palette.items()}
        palette = {
            g: palette.get(g, by_upper.get(str(g).upper(), "#7f7f7f")) for g in found
        }

    idxs = [var_names.get_loc(g) for g in found]

    if issparse(X):
        expr = X[:, idxs].toarray()
        lib = np.asarray(X.sum(axis=1)).ravel()
    else:
        expr = np.asarray(X)[:, idxs]
        lib = np.asarray(X).sum(axis=1)

    # Guard against empty cells before dividing by library size.
    lib = np.clip(lib, 1.0, None)
    expr_tp = expr * (target_sum / lib[:, None])

    expr_thr = np.where(expr_tp >= thr_umi, expr_tp, 0.0)
    total = expr_thr.sum(axis=1)
    pos = total > 0

    with np.errstate(divide="ignore", invalid="ignore"):
        weights = expr_thr / total[:, None]
    weights[~np.isfinite(weights)] = 0.0

    if clip_p99 and np.any(pos):
        t99 = float(np.percentile(total[pos], 99))
    else:
        t99 = float(total[pos].max()) if np.any(pos) else 1.0
    t99 = max(t99, 1e-9)

    if use_log_scale:
        strength = np.zeros_like(total)
        strength[pos] = np.log1p(total[pos]) / np.log1p(t99)
    else:
        strength = total / t99
    strength = np.clip(strength, 0.0, 1.0)

    alphas = np.where(
        pos,
        alpha_pos_min + (alpha_pos_max - alpha_pos_min) * strength,
        alpha_zero,
    )

    # Per-cell RGBA built numerically. A fixed-width array of hex strings
    # would truncate longer colour specs such as "lightgray" on assignment,
    # and converting each cell with to_hex is slow for large datasets.
    base_rgb = np.array([to_rgb(palette[g]) for g in found])  # (n_genes, 3)
    rgba = np.empty((expr_thr.shape[0], 4))
    rgba[:, :3] = np.clip(weights @ base_rgb, 0.0, 1.0)
    rgba[~pos, :3] = to_rgb(bg_color)
    rgba[:, 3] = alphas

    fig, ax = plt.subplots(figsize=(8.6, 7))
    ax.scatter(x_umap, y_umap, s=size, c=rgba, lw=0)
    ax.axis("off")

    auto_title = f"{', '.join(found)} - RGB mix"
    ax.set_title(title or auto_title)

    handles = [
        Line2D([0], [0], marker="o", linestyle="", markersize=8,
               markerfacecolor=palette[g], markeredgecolor="none", label=g)
        for g in found
    ]
    handles.append(
        Line2D([0], [0], marker="o", linestyle="", markersize=8,
               markerfacecolor=bg_color, markeredgecolor="none", label="none")
    )
    # Legend placed just left of the axes.
    ax.legend(handles=handles, frameon=False, loc="center right", bbox_to_anchor=(-0.03, 0.5))

    if show_strength_bar:
        sm = plt.cm.ScalarMappable(norm=Normalize(vmin=0, vmax=1), cmap=strength_cmap)
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label(strength_label, rotation=90)

    if save_path is not None:
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=400, bbox_inches="tight")

    if show:
        plt.show()
    else:
        plt.close(fig)

    return fig, ax, total


In [None]:
# ============================================================
# Continuous overlays: pathway activity (L vs R separately)
# ============================================================
# Reads expression from adata.raw
# Plots ligands and receptors on separate continuous maps
# ============================================================

def _plot_block_continuous(adata, pathway_name, role, genes, label_suffix=""):
    """Draw one continuous overlay of the summed expression of a pathway role.

    Parameters
    ----------
    adata : AnnData
        Object to plot (global or a tissue subset); reads ``adata.raw``.
    pathway_name : str
        Pathway key, used in the title.
    role : str
        "L", "R" or "I"; mapped to Ligands / Receptors / Inhibitors in the
        title (other values are shown verbatim).
    genes : list of str
        Symbols for this role. An empty list draws nothing.
    label_suffix : str
        Appended to the title, e.g. " (N)".

    ``umap_gene_expression_continuous`` raises ValueError when none of the
    genes is present in the dataset. That case is reported and skipped so
    one absent gene set does not abort the loop over all pathways.
    """
    if not genes:
        return

    role_name = {"L": "Ligands", "R": "Receptors", "I": "Inhibitors"}.get(role, role)
    title = f"{pathway_name} - {role_name}{label_suffix}"

    try:
        umap_gene_expression_continuous(
            adata,
            genes=genes,
            layer="raw",
            aggregate="sum",
            cmap="cividis",
            size=2,
            alpha=0.9,
            bg_zero="#d9d9d9",
            title=title,
            colorbar_label="log1p(TP10K)",
            save_path=None,
            show=True,
        )
    except ValueError as err:
        print(f"[WARN] {title}: {err} - skipping.")


def _plot_pathway_lr(adata, pname, blocks, label_suffix=""):
    """Plot continuous overlays for each gene role of one pathway.

    Ligands ("L") and receptors ("R") are always attempted; a missing role key
    falls back to an empty list, which ``_plot_block_continuous`` skips.
    Inhibitors ("I") are plotted only for the BMP pathway, and only when
    ``blocks`` contains an "I" key.

    Parameters
    ----------
    adata : AnnData-like object forwarded to the plotting helper.
    pname : pathway name (a key of PATHWAYS).
    blocks : mapping from role code ("L", "R", optionally "I") to gene lists.
    label_suffix : text appended to every plot title.
    """
    roles_to_plot = ["L", "R"]
    if pname == "BMP" and "I" in blocks:
        roles_to_plot.append("I")

    for role_code in roles_to_plot:
        _plot_block_continuous(
            adata, pname, role_code, blocks.get(role_code, []), label_suffix
        )


# -----------------------------
# Global UMAP: every pathway on all cells
# -----------------------------
print("=== Continuous pathway overlays: GLOBAL UMAP (L vs R) ===")
for pname, blocks in PATHWAYS.items():
    _plot_pathway_lr(adata, pname, blocks, label_suffix=" (global)")


# -----------------------------
# Normal vs Tumor: same embedding, cells split by tissue type
# (skipped silently when the Tissue_type column is absent)
# -----------------------------
if "Tissue_type" in adata.obs.columns:
    _tissue_labels = adata.obs["Tissue_type"]
    adata_N = adata[_tissue_labels == "N"].copy()
    adata_T = adata[_tissue_labels == "T"].copy()

    for _header, _subset, _suffix in (
        ("NORMAL (N)", adata_N, " (N)"),
        ("TUMOR (T)", adata_T, " (T)"),
    ):
        print(f"\n=== Continuous pathway overlays: {_header} ===")
        for pname, blocks in PATHWAYS.items():
            _plot_pathway_lr(_subset, pname, blocks, label_suffix=_suffix)

print("Continuous pathway overlays completed.")


In [None]:
# ============================================================
# RGB overlays: top-5 ligands and top-5 receptors (separately)
# ============================================================
# Goal:
#   Show "who carries the pathway" via RGB composition plots.
#
# Design:
#   - Work separately inside Normal (N) and Tumor (T) (same global embedding coords).
#   - For each pathway:
#       * pick top-5 ligands by TOTAL raw counts across the subset (adata.raw)
#       * pick top-5 receptors by TOTAL raw counts across the subset (adata.raw)
#       * plot RGB mix for L and RGB mix for R
#   - BMP inhibitors: fixed set ["NOG","GREM1","GREM2"] as RGB mix (no top-k)
#
# Requirements:
#   - adata.obsm["X_umap"] exists
#   - adata.raw exists (full gene set)
#   - helper functions defined earlier in this notebook exist:
#     umap_multi_gene_color_mix, _match_genes_case_insensitive
# ============================================================

def _topk_by_total_counts_from_raw(adata_sub, genes, k=5):
    """
    Rank ``genes`` by their total value in ``adata_sub.raw.X`` across all cells.

    Totals are column sums of ``adata_sub.raw.X`` (dense or scipy-sparse), so
    ranking uses the full raw gene set rather than only the genes in ``X``.
    (``raw.X`` is presumably raw counts - confirm upstream.)

    Parameters
    ----------
    adata_sub : AnnData
        Cell subset (e.g. Normal or Tumor) with ``.raw`` populated.
    genes : iterable of str or None
        Candidate gene symbols, matched case-insensitively against
        ``adata_sub.raw.var_names`` via ``_match_genes_case_insensitive``.
        Unmatched symbols are printed.
    k : int, default 5
        Maximum number of genes to return; ``k <= 0`` yields ``[]``.

    Returns
    -------
    list of str
        Up to ``k`` genes (spelled as in ``var_names``), highest total first.
        Ties keep the order returned by the matching helper, so the result
        is reproducible.

    Raises
    ------
    RuntimeError
        If ``adata_sub.raw`` is None.
    """
    if adata_sub.raw is None:
        raise RuntimeError("adata.raw is None - required for top-k ranking.")

    genes = list(genes) if genes is not None else []
    # k <= 0 asks for nothing; previously a negative k sliced order[:k] and
    # returned all but the last |k| genes.
    if not genes or k <= 0:
        return []

    X = adata_sub.raw.X
    var_names = adata_sub.raw.var_names

    found, missing = _match_genes_case_insensitive(var_names, genes)
    if missing:
        print("Missing genes:", ", ".join(missing))
    if not found:
        return []

    # NOTE(review): assumes raw var_names are unique; with duplicates get_loc
    # returns a slice/mask instead of an int position.
    idxs = [var_names.get_loc(g) for g in found]

    if issparse(X):
        totals = np.asarray(X[:, idxs].sum(axis=0)).ravel()
    else:
        totals = np.asarray(X)[:, idxs].sum(axis=0)

    # Python's sort is stable even with reverse=True, so equal totals (common
    # for several zero-count genes in a small subset) keep a deterministic
    # order. np.argsort's default kind is not stable, and [::-1] flips ties.
    order = sorted(range(len(found)), key=lambda i: totals[i], reverse=True)
    return [found[i] for i in order[:k]]


def _rgb_for_pathway_block(
    adata_sub,
    pname,
    role,
    genes,
    k=5,
    thr_tp10k=1.0,
    title_suffix="",
):
    """
    Draw one RGB-mix UMAP overlay for a single role of a pathway.

    - role 'L' or 'R': the top-``k`` genes of ``genes`` by total counts in
      ``adata_sub.raw`` are chosen via ``_topk_by_total_counts_from_raw``.
    - any other role (used for 'I', BMP inhibitors): ``genes`` is a fixed
      list; only genes present in ``adata_sub.raw.var_names`` are plotted,
      with no top-k ranking (``k`` is ignored).

    The figure comes from ``umap_multi_gene_color_mix`` using the "raw"
    layer. It is shown inline and not saved. Returns None. When no genes are
    available, a message is printed (except for an empty fixed list) and
    nothing is drawn.

    Raises
    ------
    RuntimeError
        If ``adata_sub.raw`` is None and genes were requested (both branches
        need the raw gene set).
    """
    role_name = {"L": "Ligands", "R": "Receptors", "I": "Inhibitors"}.get(role, role)

    if role in ("L", "R"):
        plot_genes = _topk_by_total_counts_from_raw(adata_sub, genes, k=k)
        if not plot_genes:
            print(f"[{pname}:{role}] No genes found for RGB overlay.")
            return
        title = f"{pname} - {role_name} (top {len(plot_genes)} by total counts){title_suffix}"
    else:
        # Fixed list (BMP inhibitors)
        if not genes:
            return
        # Fail with a clear message instead of AttributeError on None.var_names,
        # matching the explicit check in the top-k branch.
        if adata_sub.raw is None:
            raise RuntimeError("adata.raw is None - required for RGB overlays.")
        # Keep only present genes (avoid a hard crash if one is absent)
        found, missing = _match_genes_case_insensitive(adata_sub.raw.var_names, genes)
        if missing:
            print(f"[{pname}:I] Missing inhibitors:", ", ".join(missing))
        if not found:
            print(f"[{pname}:I] No inhibitor genes found.")
            return
        plot_genes = found
        title = f"{pname} - {role_name}{title_suffix}"

    # Single plotting call shared by both branches
    umap_multi_gene_color_mix(
        adata_sub,
        genes=plot_genes,
        layer="raw",                 # overlays always read adata.raw
        thr_umi=float(thr_tp10k),    # helper's threshold kwarg; caller passes TP10K units
        size=2,
        title=title,
        strength_label="TP10K (p99)",
        show=True,
        save_path=None,
    )


def _rgb_top5_lr_for_subset(
    adata_sub,
    subset_label,
    top_k=5,
    thr_tp10k=1.0,
    bmp_inhibitors=("NOG", "GREM1", "GREM2"),
):
    """
    Draw RGB-mix overlays for every pathway in PATHWAYS within one cell subset.

    For each pathway the top-``top_k`` ligands are plotted, then the
    top-``top_k`` receptors, ranked by total counts in ``adata_sub.raw``.
    For BMP only, and only when PATHWAYS["BMP"] has a non-empty "I" block,
    the fixed inhibitor list ``bmp_inhibitors`` is also plotted (no ranking).

    Parameters
    ----------
    adata_sub : AnnData
        Cell subset (e.g. Normal or Tumor) sharing the global UMAP coords.
    subset_label : str
        Label used in the printed header and in plot titles.
    top_k : int, default 5
        Number of ligands / receptors per pathway.
    thr_tp10k : float, default 1.0
        Forwarded as the overlay's expression threshold.
    bmp_inhibitors : iterable of str, default ("NOG", "GREM1", "GREM2")
        Inhibitor genes drawn for BMP. The PATHWAYS "I" block only gates
        whether this plot is drawn; its contents are not used.

    A subset with no cells (``n_obs == 0``) is reported and skipped.
    """
    print(f"\n=== RGB overlays: {subset_label} (top-{top_k} L and top-{top_k} R) ===")

    # Nothing to rank or plot on an empty subset (e.g. a dataset without T cells)
    if adata_sub.n_obs == 0:
        print(f"[{subset_label}] No cells in this subset - skipping RGB overlays.")
        return

    suffix = f" ({subset_label})"
    for pname, blocks in PATHWAYS.items():
        # Ligands, then receptors (top-k each)
        for role in ("L", "R"):
            _rgb_for_pathway_block(
                adata_sub,
                pname=pname,
                role=role,
                genes=blocks.get(role, []),
                k=top_k,
                thr_tp10k=thr_tp10k,
                title_suffix=suffix,
            )

        # BMP inhibitors (fixed list), drawn only if PATHWAYS declares them
        if pname == "BMP" and blocks.get("I", []):
            _rgb_for_pathway_block(
                adata_sub,
                pname=pname,
                role="I",
                genes=list(bmp_inhibitors),
                k=top_k,
                thr_tp10k=thr_tp10k,
                title_suffix=suffix,
            )


# -----------------------------
# Run inside Normal (N) and Tumor (T); this cell requires the tissue column
# -----------------------------
if "Tissue_type" not in adata.obs.columns:
    raise RuntimeError("Tissue_type column not found in adata.obs")

_tissue_labels = adata.obs["Tissue_type"]
adata_N = adata[_tissue_labels == "N"].copy()
adata_T = adata[_tissue_labels == "T"].copy()

TOP_K = 5
THR_TP10K = 1.0  # overlay threshold, expressed in TP10K units

for _label, _subset in (("Normal (N)", adata_N), ("Tumor (T)", adata_T)):
    _rgb_top5_lr_for_subset(_subset, subset_label=_label, top_k=TOP_K, thr_tp10k=THR_TP10K)

print("\nRGB top-5 L/R overlays completed.")


In [None]:
# ============================================================
# Summary: what we produced + optional tables preview
# ============================================================
# Goal:
#   - Print a compact, reproducible summary of what this notebook generated.
#   - Optionally compute (but NOT save) small tables:
#       top-5 ligands / receptors per condition per pathway
#
# NOTE:
#   This cell does NOT write any files. It only prints and displays tables.
# ============================================================

def _compute_topk_lr_tables(
    adata_N,
    adata_T,
    pathways,
    top_k=5,
):
    """
    Build preview tables of the top ligands and top receptors per pathway.

    Ranking metric: total counts across cells in each subset's ``.raw``,
    computed by ``_topk_by_total_counts_from_raw``.

    Parameters
    ----------
    adata_N, adata_T : AnnData
        Normal and Tumor cell subsets, each with ``.raw`` populated.
    pathways : mapping
        Pathway name -> dict with optional "L" (ligands) and "R" (receptors)
        gene lists.
    top_k : int, default 5
        Number of genes listed per pathway and condition.

    Returns
    -------
    (df_L, df_R) : tuple of pandas.DataFrame
        df_L columns: Pathway, Normal_topL, Tumor_topL.
        df_R columns: Pathway, Normal_topR, Tumor_topR.
        Gene lists are comma-joined strings; rows are sorted by Pathway.
        If ``pathways`` is empty, both frames are empty but keep these columns.
    """
    cols_L = ["Pathway", "Normal_topL", "Tumor_topL"]
    cols_R = ["Pathway", "Normal_topR", "Tumor_topR"]
    rows_L, rows_R = [], []

    for pname, blocks in pathways.items():
        ligands = blocks.get("L", [])
        receptors = blocks.get("R", [])

        topL_N = _topk_by_total_counts_from_raw(adata_N, ligands, k=top_k)
        topL_T = _topk_by_total_counts_from_raw(adata_T, ligands, k=top_k)
        topR_N = _topk_by_total_counts_from_raw(adata_N, receptors, k=top_k)
        topR_T = _topk_by_total_counts_from_raw(adata_T, receptors, k=top_k)

        rows_L.append({"Pathway": pname, "Normal_topL": ", ".join(topL_N), "Tumor_topL": ", ".join(topL_T)})
        rows_R.append({"Pathway": pname, "Normal_topR": ", ".join(topR_N), "Tumor_topR": ", ".join(topR_T)})

    # Explicit columns keep sort_values("Pathway") valid when there are zero
    # rows; without them, an empty DataFrame has no "Pathway" column (KeyError).
    df_L = pd.DataFrame(rows_L, columns=cols_L).sort_values("Pathway").reset_index(drop=True)
    df_R = pd.DataFrame(rows_R, columns=cols_R).sort_values("Pathway").reset_index(drop=True)
    return df_L, df_R


# -----------------------------
# 1) What is available in memory
# -----------------------------
# Prints a console summary of the objects, obs columns and embedding in
# memory, then previews top-k ligand/receptor tables. No files are written.
print("=== Notebook summary (no file saving) ===\n")

print("Data objects in memory:")
# NOTE(review): "(HVG-only in X)" is a fixed label and is not checked here;
# confirm it still matches how adata.X was built upstream.
print(f"  adata (full):   {adata.n_obs:,} cells × {adata.n_vars:,} vars (HVG-only in X)")
print(f"  adata.raw:      {'present' if adata.raw is not None else 'MISSING'}")
if adata.raw is not None:
    print(f"    raw genes:    {adata.raw.n_vars:,}")

# Report presence of the configured patient / tissue / batch obs columns
print("\nKey columns (obs):")
for c in [PATIENT_COL, TISSUE_COL, BATCH_COL]:
    print(f"  - {c}: {'OK' if c in adata.obs.columns else 'MISSING'}")

print("\nUMAP embedding:")
print(f"  - X_umap present: {'X_umap' in adata.obsm}")

# Static list of the visual sections of this notebook (not derived from state)
print("\nGenerated visuals (inline in this run):")
print("  - Global UMAP sanity plots: patient / batch / tissue")
print("  - Annotation UMAPs: clTopLevel + deep levels if present")
print("  - Normal vs Tumor views on the same embedding")
# NOTE(review): the next printed line contains the Russian word "отдельно"
# ("separately"); consider replacing it with English in a code change.
print("  - Continuous pathway overlays (Ligands vs Receptors; BMP inhibitors отдельно)")
print("  - RGB overlays: top-5 ligands and top-5 receptors per pathway for N and T")
# NOTE(review): fixed message, not tied to any save flag; confirm it holds
# if PNG saving is enabled elsewhere in the notebook.
print("\n(No PNG/CSV were saved - this is notebook-only output.)")


# -----------------------------
# 2) Preview tables
# -----------------------------
print("\n=== Optional tables (preview only) ===")

if TISSUE_COL in adata.obs.columns:
    # Re-derive the N/T subsets via TISSUE_COL; this rebinds the global
    # adata_N / adata_T created in earlier cells.
    adata_N = adata[adata.obs[TISSUE_COL] == "N"].copy()
    adata_T = adata[adata.obs[TISSUE_COL] == "T"].copy()

    if adata.raw is None:
        # Ranking reads adata.raw, so tables cannot be built without it
        print("Cannot compute top-5 tables because adata.raw is missing.")
    else:
        TOP_K = 5  # rebinds the global TOP_K
        df_topL, df_topR = _compute_topk_lr_tables(adata_N, adata_T, PATHWAYS, top_k=TOP_K)

        # `display` is not imported in the visible code; presumably the
        # IPython/Jupyter builtin available in notebook sessions.
        print(f"\nTop-{TOP_K} ligands per pathway (by total raw counts):")
        display(df_topL)

        print(f"\nTop-{TOP_K} receptors per pathway (by total raw counts):")
        display(df_topR)
else:
    print(f"Cannot split N/T: '{TISSUE_COL}' column not found.")