In [1]:
# IPython autoreload extension in mode 2: imported modules are reloaded
# before each cell runs, so edits to local .py files are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import scanpy as sc
import squidpy as sq
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns



In [41]:
# Load the AnnData object from disk; the path is relative to the kernel's
# working directory.
adata_path = '../data/spatial_single_cell_KS_adata.h5ad'
adata = sc.read_h5ad(adata_path)

In [42]:
from __future__ import annotations

import warnings
from typing import Iterable, Optional, Tuple, List

import numpy as np
import pandas as pd
import scipy.sparse as sp
from sklearn.preprocessing import normalize

import squidpy as sq


def _get_X(adata, layer: Optional[str] = None, use_raw: bool = False):
    if use_raw:
        if adata.raw is None:
            raise ValueError("use_raw=True but adata.raw is None.")
        return adata.raw.X
    if layer is None:
        return adata.X
    if layer not in adata.layers:
        raise KeyError(f"layer='{layer}' not in adata.layers")
    return adata.layers[layer]


def _dense(X):
    return X.toarray() if sp.issparse(X) else np.asarray(X)


def compute_cluster_centroids(
    adata,
    cluster_key: str,
    *,
    layer: Optional[str] = None,
    use_raw: bool = False,
    gene_subset: Optional[Iterable[str]] = None,
    min_cells_per_cluster: int = 20,
    l2_normalize: bool = True,
) -> Tuple[np.ndarray, List[str], np.ndarray]:
    """
    Reference-free cluster programs: mean expression per cluster.

    Parameters:
      - cluster_key: column of adata.obs holding the cluster labels.
      - layer / use_raw: which matrix to average (see _get_X).
      - gene_subset: optional gene names to restrict to; all must exist in
        the var_names matching the chosen matrix.
      - min_cells_per_cluster: clusters with fewer cells are skipped.
      - l2_normalize: scale every centroid to unit L2 norm.

    Cells with a missing cluster label are ignored, so they never form a
    spurious cluster named after the stringified missing value.

    Returns (centroids, cluster_names, gene_idx) where centroids has one
    row per retained cluster (sorted by name) and gene_idx holds the column
    positions that were used.

    Raises:
      - KeyError if cluster_key or any requested gene is missing.
      - ValueError if no cluster reaches min_cells_per_cluster.
    """
    if cluster_key not in adata.obs:
        raise KeyError(f"cluster_key='{cluster_key}' not found in adata.obs")

    X = _get_X(adata, layer=layer, use_raw=use_raw)

    # Gene names must come from the same object as the matrix.
    var_names = adata.raw.var_names if use_raw else adata.var_names

    if gene_subset is None:
        gene_idx = np.arange(X.shape[1], dtype=int)
    else:
        gene_subset = list(gene_subset)
        missing = [g for g in gene_subset if g not in var_names]
        if missing:
            raise KeyError(f"{len(missing)} genes missing, e.g. {missing[:5]}")
        gene_idx = var_names.get_indexer(gene_subset).astype(int)

    labels = adata.obs[cluster_key]
    # astype(str) would turn a missing label into a literal string; remember
    # which cells are really labelled before stringifying.
    labelled = labels.notna().to_numpy()
    clusters = labels.astype(str).to_numpy()

    centroids = []
    names = []
    for c in sorted(pd.unique(clusters[labelled])):
        idx = np.flatnonzero(labelled & (clusters == c))
        if idx.size < min_cells_per_cluster:
            continue
        # Flatten so sparse (matrix-like) and dense inputs both give 1-D means.
        mu = _dense(X[idx, :][:, gene_idx]).mean(axis=0).ravel()
        centroids.append(mu)
        names.append(c)

    if len(centroids) == 0:
        raise ValueError("No clusters passed min_cells_per_cluster.")

    C = np.vstack(centroids)
    if l2_normalize:
        C = normalize(C, norm="l2", axis=1)
    return C, names, gene_idx


def compute_secondary_program(
    adata,
    cluster_key: str,
    *,
    layer: Optional[str] = None,
    use_raw: bool = False,
    gene_subset: Optional[Iterable[str]] = None,
    min_cells_per_cluster: int = 20,
    prefix: str = "spill",
) -> None:
    """
    Score every cell against the L2-normalized cluster centroids (cosine
    similarity) and record its own cluster plus the best *other* cluster.

    Adds to adata.obs (always all five columns):
      - {prefix}_primary_cluster: the cell's label from cluster_key, as str
      - {prefix}_primary_strength: similarity to its own cluster centroid
      - {prefix}_secondary_cluster: best-matching other cluster (Categorical)
      - {prefix}_secondary_similarity: similarity to that secondary cluster
      - {prefix}_secondary_strength (sec_sim - prim_sim)

    Numeric columns are NaN, and the secondary cluster is missing, for
    cells whose own cluster did not pass min_cells_per_cluster and whenever
    fewer than two clusters have a centroid.
    """
    C, cluster_names, gene_idx = compute_cluster_centroids(
        adata,
        cluster_key=cluster_key,
        layer=layer,
        use_raw=use_raw,
        gene_subset=gene_subset,
        min_cells_per_cluster=min_cells_per_cluster,
        l2_normalize=True,
    )

    n_obs = adata.n_obs
    rows = np.arange(n_obs)

    # Same matrix and gene columns the centroids were built from. Sparse
    # input stays sparse: row normalization and the product with the small
    # dense centroid matrix do not need a dense copy of all cells.
    expr = _get_X(adata, layer=layer, use_raw=use_raw)[:, gene_idx]
    if not sp.issparse(expr):
        expr = np.asarray(expr)
    expr = normalize(expr, norm="l2", axis=1)

    sims = np.asarray(expr @ C.T)  # (n_cells, n_clusters)

    col_index = {c: j for j, c in enumerate(cluster_names)}
    primary = adata.obs[cluster_key].astype(str).values
    prim_idx = np.array([col_index.get(c, -1) for c in primary], dtype=int)
    ok = prim_idx >= 0  # cells whose own cluster has a centroid

    prim_sim = np.full(n_obs, np.nan, dtype=float)
    prim_sim[ok] = sims[rows[ok], prim_idx[ok]]

    # Defaults describe "no secondary program available".
    sec_cluster = np.full(n_obs, None, dtype=object)
    sec_sim = np.full(n_obs, np.nan, dtype=float)
    sec_strength = np.full(n_obs, np.nan, dtype=float)

    # A secondary cluster only exists when at least two centroids do.
    if len(cluster_names) >= 2:
        masked = sims.astype(float, copy=True)
        # Exclude each cell's own cluster from the arg-max ...
        masked[rows[ok], prim_idx[ok]] = -np.inf
        # ... and exclude cells without a valid primary cluster entirely.
        masked[~ok, :] = -np.inf

        sec_idx = np.argmax(masked, axis=1)
        best = masked[rows, sec_idx]

        # -inf is only an internal sentinel; it must not reach adata.obs.
        has_sec = np.isfinite(best) & ok
        sec_cluster[has_sec] = np.array(cluster_names, dtype=object)[sec_idx[has_sec]]
        sec_sim[has_sec] = best[has_sec]
        sec_strength[has_sec] = sec_sim[has_sec] - prim_sim[has_sec]

    adata.obs[f"{prefix}_primary_cluster"] = adata.obs[cluster_key].astype(str)
    adata.obs[f"{prefix}_primary_strength"] = prim_sim
    adata.obs[f"{prefix}_secondary_cluster"] = pd.Categorical(sec_cluster, categories=cluster_names)
    adata.obs[f"{prefix}_secondary_similarity"] = sec_sim
    adata.obs[f"{prefix}_secondary_strength"] = sec_strength


def _row_normalize_sparse(W: sp.spmatrix) -> sp.csr_matrix:
    W = W.tocsr()
    rs = np.asarray(W.sum(axis=1)).ravel()
    rs[rs == 0] = 1.0
    inv = 1.0 / rs
    Dinv = sp.diags(inv)
    return (Dinv @ W).tocsr()


def _one_hot_csr(codes, n_cols: int) -> sp.csr_matrix:
    """One-hot encode integer codes as a (len(codes), n_cols) CSR matrix; negative codes give all-zero rows."""
    codes = np.asarray(codes)
    keep = np.flatnonzero(codes >= 0)
    return sp.csr_matrix(
        (np.ones(keep.size), (keep, codes[keep])),
        shape=(codes.size, n_cols),
    )


def _pick_per_row(M, rows, cols) -> np.ndarray:
    """Return the entries M[rows[k], cols[k]] of a sparse matrix as a flat float array."""
    if len(rows) == 0:
        return np.empty(0, dtype=float)
    picked = M[rows, cols]
    if sp.issparse(picked):
        picked = picked.toarray()
    return np.asarray(picked, dtype=float).ravel()


def neighborhood_consistency_squidpy(
    adata,
    *,
    cluster_key: str,
    spatial_key: str = "spatial",
    n_neighs: int = 30,
    coord_type: str = "generic",
    prefix: str = "spill",
    layer: Optional[str] = None,
    use_raw: bool = False,
    gene_subset: Optional[Iterable[str]] = None,
    min_cells_per_cluster: int = 20,
    n_permutations: int = 0,
    random_state: int = 0,
) -> None:
    """
    Builds Squidpy spatial graph and computes neighborhood-consistency spillover tests.

    Requires:
      - adata.obsm[spatial_key] exists

    Produces in adata.obs:
      - everything written by compute_secondary_program:
        {prefix}_primary_cluster, {prefix}_primary_strength,
        {prefix}_secondary_cluster, {prefix}_secondary_similarity,
        {prefix}_secondary_strength
      - {prefix}_neigh_secondary_enrichment: row-normalized share of a
        cell's spatial neighbors whose PRIMARY label equals the cell's
        secondary cluster
      - {prefix}_neigh_match_rate: row-normalized share of neighbors that
        have the same SECONDARY cluster as the cell
      - optional p-values if n_permutations>0:
        {prefix}_p_enrichment, {prefix}_p_match
        computed as (count(null >= observed) + 1) / (n_permutations + 1)

    Also stores the spatial graph via sq.gr.spatial_neighbors (read back
    from adata.obsp['spatial_connectivities']).

    Raises KeyError if spatial_key is missing from adata.obsm or the
    connectivities are not present after building the graph.
    """
    # Validate spatial coordinates exist
    if spatial_key not in adata.obsm:
        raise KeyError(f"spatial_key='{spatial_key}' not found in adata.obsm")

    # 1) Secondary program direction (reference-free)
    compute_secondary_program(
        adata,
        cluster_key=cluster_key,
        layer=layer,
        use_raw=use_raw,
        gene_subset=gene_subset,
        min_cells_per_cluster=min_cells_per_cluster,
        prefix=prefix,
    )

    # 2) Build spatial graph once and reuse
    sq.gr.spatial_neighbors(
        adata,
        spatial_key=spatial_key,
        n_neighs=n_neighs,
        coord_type=coord_type,
    )
    # Squidpy stores connectivities in adata.obsp['spatial_connectivities']
    if "spatial_connectivities" not in adata.obsp:
        raise KeyError("Expected adata.obsp['spatial_connectivities'] after sq.gr.spatial_neighbors.")
    W = adata.obsp["spatial_connectivities"].tocsr()
    Wn = _row_normalize_sparse(W)

    n_obs = adata.n_obs

    # 3) Neighbor enrichment: neighbors' PRIMARY labels == cell's secondary cluster
    primary = adata.obs[cluster_key].astype(str).values
    secondary = adata.obs[f"{prefix}_secondary_cluster"].astype(str).values

    prim_cats = pd.Categorical(primary)
    prim_codes = prim_cats.codes
    nC = len(prim_cats.categories)

    # Neighbor composition per cell in primary-label space: (n_cells, n_clusters)
    neigh_comp = Wn @ _one_hot_csr(prim_codes, nC)

    # Column of each cell's secondary label in primary-label space (-1 if absent).
    sec_cols = pd.Index(prim_cats.categories).get_indexer(secondary)
    ok_rows = np.flatnonzero(sec_cols >= 0)
    ok_cols = sec_cols[ok_rows]

    enrich = np.full(n_obs, np.nan, dtype=float)
    enrich[ok_rows] = _pick_per_row(neigh_comp, ok_rows, ok_cols)
    adata.obs[f"{prefix}_neigh_secondary_enrichment"] = enrich

    # 4) Neighbor match rate: neighbors share the same *secondary direction*
    # Filter out invalid secondary assignments (NaN, empty string)
    valid_secondary = np.asarray(
        (secondary != "nan") & (secondary != "") & pd.notna(secondary),
        dtype=bool,
    )
    sec_cats = pd.Categorical(secondary[valid_secondary])
    nS = len(sec_cats.categories)
    sec_codes_full = np.full(n_obs, -1, dtype=int)
    sec_codes_full[valid_secondary] = sec_cats.codes
    valid_idx = np.flatnonzero(sec_codes_full >= 0)
    valid_codes = sec_codes_full[valid_idx]

    match = np.full(n_obs, np.nan, dtype=float)
    if valid_idx.size > 0:
        # One fancy-indexing call instead of one sparse lookup per cell.
        neigh_sec_comp = Wn @ _one_hot_csr(sec_codes_full, nS)
        match[valid_idx] = _pick_per_row(neigh_sec_comp, valid_idx, valid_codes)
    adata.obs[f"{prefix}_neigh_match_rate"] = match

    # 5) Optional permutation p-values (costs O(B * nnz(W)) time).
    # Only running "null >= observed" counts are kept, so memory stays
    # O(n_cells) instead of O(B * n_cells).
    if n_permutations and n_permutations > 0:
        rng = np.random.default_rng(random_state)

        # enrichment null: shuffle primary labels
        observed_enrich = enrich[ok_rows]
        ge_enrich = np.zeros(ok_rows.size, dtype=np.int64)
        for _ in range(n_permutations):
            perm = prim_codes.copy()
            rng.shuffle(perm)
            null_vals = _pick_per_row(Wn @ _one_hot_csr(perm, nC), ok_rows, ok_cols)
            ge_enrich += null_vals >= observed_enrich

        p_enrich = np.full(n_obs, np.nan, dtype=float)
        finite_enrich = np.isfinite(observed_enrich)
        p_enrich[ok_rows[finite_enrich]] = (
            (ge_enrich[finite_enrich] + 1.0) / (n_permutations + 1.0)
        )
        adata.obs[f"{prefix}_p_enrichment"] = p_enrich

        # match null: shuffle secondary directions (only among valid cells)
        p_match = np.full(n_obs, np.nan, dtype=float)
        if nS > 0 and valid_secondary.sum() > 1:
            observed_match = match[valid_idx]
            ge_match = np.zeros(valid_idx.size, dtype=np.int64)
            for _ in range(n_permutations):
                perm_codes = valid_codes.copy()
                rng.shuffle(perm_codes)
                S_perm = sp.csr_matrix(
                    (np.ones(valid_idx.size), (valid_idx, perm_codes)),
                    shape=(n_obs, nS),
                )
                # Each cell is scored against its own permuted label.
                null_vals = _pick_per_row(Wn @ S_perm, valid_idx, perm_codes)
                ge_match += null_vals >= observed_match

            finite_match = np.isfinite(observed_match)
            p_match[valid_idx[finite_match]] = (
                (ge_match[finite_match] + 1.0) / (n_permutations + 1.0)
            )

        adata.obs[f"{prefix}_p_match"] = p_match


def neighborhood_consistency_by_sample(
    adata,
    *,
    sample_key: str = "sample_id",
    cluster_key: str,
    spatial_key: str = "spatial",
    n_neighs: int = 30,
    coord_type: str = "generic",
    prefix: str = "spill",
    layer: Optional[str] = None,
    use_raw: bool = False,
    gene_subset: Optional[Iterable[str]] = None,
    min_cells_per_cluster: int = 20,
    n_permutations: int = 0,
    random_state: int = 0,
    verbose: bool = True,
) -> None:
    """
    Run neighborhood_consistency_squidpy separately for each sample.

    This avoids building spatial graphs across sample boundaries,
    which would create spurious neighbors between unrelated tissues.

    Each sample is analysed on a copy of its cells and the results are
    written back by row position, so duplicated obs names across samples
    cannot misdirect them. The {prefix}_* output columns of adata.obs are
    reset at the start; cells of skipped or failed samples keep missing
    values. Samples with fewer than min_cells_per_cluster cells are
    skipped. A sample whose analysis raises is skipped with a
    RuntimeWarning (regardless of ``verbose``) and the run continues.

    ``verbose`` only controls the progress messages printed per sample.

    Raises KeyError if sample_key is not a column of adata.obs.
    """
    if sample_key not in adata.obs:
        raise KeyError(f"sample_key='{sample_key}' not found in adata.obs")

    sample_labels = adata.obs[sample_key]
    samples = sample_labels.unique()

    output_cols = [
        f"{prefix}_primary_cluster",
        f"{prefix}_primary_strength",
        f"{prefix}_secondary_cluster",
        f"{prefix}_secondary_similarity",
        f"{prefix}_secondary_strength",
        f"{prefix}_neigh_secondary_enrichment",
        f"{prefix}_neigh_match_rate",
    ]
    if n_permutations and n_permutations > 0:
        output_cols.extend([f"{prefix}_p_enrichment", f"{prefix}_p_match"])

    # Results are collected in plain numpy buffers (object for labels, float
    # otherwise) and assigned to adata.obs once at the end. The obs columns
    # are reset up front so stale values from an earlier run are not carried
    # into the per-sample copies.
    n_obs = adata.n_obs
    results = {}
    for col in output_cols:
        if col.endswith("_cluster"):
            results[col] = np.full(n_obs, None, dtype=object)
            adata.obs[col] = pd.Series([None] * n_obs, index=adata.obs.index, dtype=object)
        else:
            results[col] = np.full(n_obs, np.nan, dtype=float)
            adata.obs[col] = pd.Series(np.nan, index=adata.obs.index, dtype=float)

    for i, sample in enumerate(samples):
        if verbose:
            print(f"Processing sample {i+1}/{len(samples)}: {sample}")

        # Positional boolean mask; independent of obs index labels.
        mask = (sample_labels == sample).to_numpy(dtype=bool, na_value=False)
        n_cells = int(mask.sum())

        # Check the size before paying for the copy.
        if n_cells < min_cells_per_cluster:
            if verbose:
                print(f"  Skipping: only {n_cells} cells")
            continue

        adata_sub = adata[mask].copy()

        try:
            neighborhood_consistency_squidpy(
                adata_sub,
                cluster_key=cluster_key,
                spatial_key=spatial_key,
                n_neighs=n_neighs,
                coord_type=coord_type,
                prefix=prefix,
                layer=layer,
                use_raw=use_raw,
                gene_subset=gene_subset,
                min_cells_per_cluster=min_cells_per_cluster,
                n_permutations=n_permutations,
                random_state=random_state,
            )

            # Copy results back by position; np.asarray flattens Categoricals
            # so per-sample category sets never have to agree.
            for col in output_cols:
                if col in adata_sub.obs:
                    results[col][mask] = np.asarray(adata_sub.obs[col])

        except Exception as e:
            # Best effort per sample, but never silently.
            warnings.warn(
                f"Sample {sample!r} failed and was skipped: {e}",
                RuntimeWarning,
                stacklevel=2,
            )
            continue

    if verbose:
        print("Done.")

    # Label columns become Categorical only now, over the pooled values.
    for col, values in results.items():
        if col.endswith("_cluster"):
            adata.obs[col] = pd.Categorical(values)
        else:
            adata.obs[col] = values


def classify_spillover_vs_hybrid(
    adata,
    *,
    prefix: str = "spill",
    strength_q: float = 0.90,
    enrich_q: float = 0.90,
    p_thresh: float = 0.05,
) -> pd.Series:
    """
    Labels:
      - spillover_likely: strong secondary + neighborhood enrichment (and p<=p_thresh if available)
      - hybrid_candidate: strong secondary but not enriched
      - clean: else

    "Strong" / "enriched" mean at or above the strength_q / enrich_q
    quantile (NaNs ignored) of {prefix}_secondary_strength and
    {prefix}_neigh_secondary_enrichment. Cells with NaN in a score fail the
    corresponding comparison. If either score is NaN for every cell, all
    cells are labelled clean.

    The result is stored in adata.obs['{prefix}_class'] and returned. It is
    a Categorical that always declares all three labels as categories, so
    the category set does not depend on which labels happen to occur.

    Raises KeyError if either score column is missing.
    """
    categories = ["clean", "hybrid_candidate", "spillover_likely"]
    class_col = f"{prefix}_class"
    strength_col = f"{prefix}_secondary_strength"
    enrich_col = f"{prefix}_neigh_secondary_enrichment"

    if strength_col not in adata.obs:
        raise KeyError(f"Column '{strength_col}' not found. Run neighborhood_consistency_squidpy first.")
    if enrich_col not in adata.obs:
        raise KeyError(f"Column '{enrich_col}' not found. Run neighborhood_consistency_squidpy first.")

    s = adata.obs[strength_col].astype(float).values
    e = adata.obs[enrich_col].astype(float).values

    # Quantiles are undefined when a score is NaN everywhere.
    if np.all(np.isnan(s)) or np.all(np.isnan(e)):
        adata.obs[class_col] = pd.Categorical(["clean"] * adata.n_obs, categories=categories)
        return adata.obs[class_col]

    s_thr = np.nanquantile(s, strength_q)
    e_thr = np.nanquantile(e, enrich_q)

    # Comparisons against NaN are False, so NaN scores never qualify.
    strong = s >= s_thr
    spill = strong & (e >= e_thr)
    if f"{prefix}_p_enrichment" in adata.obs:
        p = adata.obs[f"{prefix}_p_enrichment"].astype(float).values
        spill = spill & (p <= p_thresh)

    hybrid = strong & ~spill & ~np.isnan(s)

    out = np.array(["clean"] * adata.n_obs, dtype=object)
    out[spill] = "spillover_likely"
    out[hybrid] = "hybrid_candidate"
    adata.obs[class_col] = pd.Categorical(out, categories=categories)
    return adata.obs[class_col]

In [43]:
# Run the neighborhood-consistency analysis one sample at a time so that no
# spatial graph spans two samples. The results are written in place into
# adata.obs as "spill_*" columns (default prefix).
neighborhood_consistency_by_sample(
    adata,
    sample_key="sample_id",        # adata.obs column with sample IDs
    cluster_key="broad_cell_types",       # adata.obs column with the cell-type annotation
    spatial_key="spatial",
    n_neighs=30,
    n_permutations=0,            # 0 skips the permutation p-values
    verbose=True,
)

Processing sample 1/16: KS_TMA_1_0026870
Processing sample 2/16: KS_TMA_2_0026882
Processing sample 3/16: KS_TMA_3_0027198
Processing sample 4/16: KS_TMA_4_0026764
Processing sample 5/16: KS_TMA_5_0026776
Processing sample 6/16: KS_TMA_6_0027092
Processing sample 7/16: KS_TMA_7_0027079
Processing sample 8/16: KS_TMA_8_0027273
Processing sample 9/16: KS_TMA_9_0026831
Processing sample 10/16: KS_TMA_10_0026828
Processing sample 11/16: KS_TMA_11_0026930
Processing sample 12/16: KS_TMA_12_0026888
Processing sample 13/16: KS_TMA_13_0027077
Processing sample 14/16: KS_TMA_14_0027019
Processing sample 15/16: KS_TMA_15_0033811
Processing sample 16/16: KS_TMA_16_0033809
Done.


In [54]:
# Save the annotated object (including the spill_* columns added above) to a
# new file, leaving the input .h5ad untouched.
adata.write_h5ad('../data/spatial_single_cell_KS_adata_spillover.h5ad')

In [61]:
# Heuristic spillover estimate. A cell counts when its secondary program is
# more similar than its primary one (strength > 0) and both neighborhood
# scores exceed twice the chance level. Cells with NaN scores never count.
# NOTE(review): 16 is presumably the number of clusters -- confirm.
expected_by_chance = 1 / 16  # 0.0625

spillover_mask = (
    adata.obs["spill_secondary_strength"].gt(0)
    & adata.obs["spill_neigh_secondary_enrichment"].gt(expected_by_chance * 2)
    & adata.obs["spill_neigh_match_rate"].gt(expected_by_chance * 2)
)

n_spillover = spillover_mask.sum()
pct_spillover = 100 * n_spillover / adata.n_obs

print(f"Estimated spillover: {n_spillover:,} cells ({pct_spillover:.1f}%)")

Estimated spillover: 216,529 cells (3.7%)
