# Load Packages

In [None]:
import spatialdata as sd
from spatialdata_io import xenium
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import scanpy as sc
import squidpy as sq


from pathlib import Path
from spatialdata import SpatialData, read_zarr

In [None]:
# Global scanpy configuration: hint-level logging and publication-style figures.
sc.settings.verbosity = 3  # 3 = also show 'hint' messages
sc.settings.set_figure_params(
    facecolor="white",
    fontsize=16,
    dpi=200,          # on-screen figure resolution
    dpi_save=400,     # resolution used when saving figures
    vector_friendly=True,
)

# Load Xenium data (batch-corrected AnnData)

In [173]:
import anndata as ad

# Preprocessed Xenium AnnData (the file name indicates UMAP + batch correction were done upstream).
adata = ad.read_h5ad("adata_JMT_ID8_umap_batch_corrected.h5ad")

In [None]:
# Quick look at the loaded AnnData: dimensions plus obs/var/obsm/obsp keys.
adata

In [None]:
import pandas as pd

tcol = "Treatment"

# Cells whose Treatment label is the literal string "None" (case/whitespace-insensitive)
# are treated as out-of-bounds. Missing values become "nan" via astype(str) and are kept.
is_none = adata.obs[tcol].astype(str).str.strip().str.lower().eq("none")
n_none = int(is_none.sum())
print(f"Cells with Treatment == 'None': {n_none} / {adata.n_obs}")

# remove the out of bound cells
adata = adata[~is_none].copy()

# Drop now-empty categories so "None" no longer appears in legends/groupbys.
# (pd.api.types.is_categorical_dtype is deprecated since pandas 2.1; check the dtype instead.)
if isinstance(adata.obs[tcol].dtype, pd.CategoricalDtype):
    adata.obs[tcol] = adata.obs[tcol].cat.remove_unused_categories()

print(f"Remaining cells: {adata.n_obs}")

In [176]:
# Formatting files and preprocessing: parameters shared by the clustering steps below.
clustering_params = dict(
    normalization_target_sum=100,
    min_counts_x_cell=40,
    min_genes_x_cell=15,
    scale=False,
    clustering_alg='louvain',
    resolutions=[0.2, 0.5, 1.1, 1.5],  # one '<clustering_alg>_<res>' obs key per value
    n_neighbors=15,
    umap_min_dist=0.1,
    n_pcs=0,
)

In [None]:
# Precompute a dendrogram on the scVI latent space for each clustering resolution,
# keyed '<clustering_alg>_<resolution>', so dot plots grouped by those keys can reuse it.
for resolution in clustering_params['resolutions']:
    groupby_key = f"{clustering_params['clustering_alg']}_{resolution}"
    sc.tl.dendrogram(adata, groupby=groupby_key, use_rep='X_scVI')

# Basic Broad Annotations

In [44]:
# Canonical markers per cell type, used for the cluster-level dot plot below.
marker_sets = {
    'Tumor epithelial':  ['Krt8', 'Krt18', 'Krt19', 'Krt7', 'Muc1', 'mCherry-tdTomato'],
    'Fibroblasts':       ['Col1a1', 'Col3a1', 'Acta2', 'Lrrc15'],
    'Endothelial':       ['Cdh5', 'Pecam1', 'Plvap', 'Flt1', 'Vegfa', 'Kdr'],
    'Macrophages':       ['C1qa', 'C1qb', 'C1qc', 'Cd68', 'Aif1', 'Csf1r', 'Marco', 'Mrc1', 'Trem2'],
    'T cells':           ['Cd3d', 'Cd3e', 'Cd4', 'Cd8a', 'Trac', 'Tcf7'],
    'Tregs':             ['Foxp3', 'Il2ra', 'Tigit', 'Ctla4'],
    'Exhausted T cells': ['Tox', 'Havcr2', 'Tigit'],
    'NK cells':          ['Nkg7', 'Klrb1c', 'Prf1', 'Gzmb', 'Ifng'],
    'B cells':           ['Cd79b', 'Ms4a1'],
    'Plasma cells':      ['Igkc', 'Iglc3', 'Xbp1'],
    'Dendritic cells':   ['Clec9a', 'Xcr1', 'Irf8', 'Cd86'],
    'gD T cells':        ['Trac', 'Il7r'],
    'pDCs':              ['Irf7'],
}

# Tabular view of the marker panel: one row per cell type, markers comma-joined.
marker_summary = pd.DataFrame({
    "Cell type": list(marker_sets.keys()),
    "Markers": [", ".join(genes) for genes in marker_sets.values()],
})

In [None]:
# Marker dot plot per louvain_2 cluster; each gene is min-max scaled across groups.
sc.pl.dotplot(
    adata,
    var_names=marker_sets,
    groupby="louvain_2",
    figsize=(10, 10),
    dendrogram=True,
    standard_scale="var",
    swap_axes=False,
)

In [177]:
# Broad compartment for each louvain_2 cluster (cluster ids are the category labels).
broad_clusters = {
    'Cancer Cells':   ["15", "21", "30", "38"],
    'Stromal Cells':  ["12", "19", "35", "37", "32", "11", "31", "9", "22", "34"],
    'Myeloid Cells':  ["27", "36", "33", "28", "18", "16", "13", "4",
                       "17", "3", "20", "26", "10", "25", "23"],
    'T and NK Cells': ["29", "7", "1"],
    'B Cells':        ["6", "5", "14", "24", "8", "0"],
    'Plasma Cells':   ["2"],
}

# Start from the cluster ids and widen the categorical so broad labels can be assigned.
# Clusters not listed above keep their louvain_2 id as their 'annotated' value.
adata.obs['annotated'] = adata.obs['louvain_2'].cat.add_categories([
    'Cancer Cells', 'Stromal Cells', 'T and NK Cells', 'B Cells',
    'Myeloid Cells', 'Plasma Cells', 'to_discard',
])
for broad_label, cluster_ids in broad_clusters.items():
    adata.obs.loc[adata.obs['louvain_2'].isin(cluster_ids), 'annotated'] = broad_label

# Drop cells labelled 'to_discard' (none are assigned above) and keep 'annotated' categorical.
adata = adata[~adata.obs['annotated'].isin(['to_discard'])].copy()
adata.obs['annotated'] = adata.obs['annotated'].astype('category')

In [None]:
plt.rcParams['figure.facecolor'] = 'white'

# Broad annotation next to a handful of lineage markers on the global UMAP.
sc.pl.umap(
    adata,
    color=["annotated", "Muc1", "Cd8a", "C1qc", "Acta2", "Col1a1", "Pparg"],
    ncols=3,
    size=1,
    cmap="Reds",
    frameon=False,
    legend_loc="on data",
    legend_fontsize=8,
    legend_fontoutline=1,
    show=True,
)

In [None]:
# Dendrogram of the broad annotation on the scVI embedding (used by dendrogram=True plots).
sc.tl.dendrogram(adata, groupby="annotated", use_rep="X_scVI")

In [None]:
# Reduced marker panel (no Muc1 / mCherry reporter) for the broad-annotation dot plot.
marker_sets = {
    'Tumor epithelial':  ['Krt8', 'Krt18', 'Krt19', 'Krt7'],
    'Fibroblasts':       ['Col1a1', 'Col3a1', 'Acta2', 'Lrrc15'],
    'Endothelial':       ['Cdh5', 'Pecam1', 'Plvap', 'Flt1', 'Vegfa', 'Kdr'],
    'Macrophages':       ['C1qa', 'C1qb', 'C1qc', 'Cd68', 'Aif1', 'Csf1r', 'Marco', 'Mrc1', 'Trem2'],
    'T cells':           ['Cd3d', 'Cd3e', 'Cd4', 'Cd8a', 'Trac', 'Tcf7'],
    'Tregs':             ['Foxp3', 'Il2ra', 'Tigit', 'Ctla4'],
    'Exhausted T cells': ['Tox', 'Havcr2', 'Tigit'],
    'NK cells':          ['Nkg7', 'Klrb1c', 'Prf1', 'Gzmb', 'Ifng'],
    'B cells':           ['Cd79b', 'Ms4a1'],
    'Plasma cells':      ['Igkc', 'Iglc3', 'Xbp1'],
    'Dendritic cells':   ['Clec9a', 'Xcr1', 'Irf8', 'Cd86'],
    'gD T cells':        ['Trac', 'Il7r'],
    'pDCs':              ['Irf7'],
}

sc.pl.dotplot(
    adata,
    var_names=marker_sets,
    groupby="annotated",
    figsize=(10, 3),
    dendrogram=True,
    standard_scale="var",
    swap_axes=False,
)

# Find juxtaposed cells (doublets) 

### Initial pass at separating confident from juxtaposed cells (later refined per cell type).

In [None]:
import numpy as np, pandas as pd
from pandas.api.types import CategoricalDtype
from scipy import sparse

def _ensure_cat(s, cats):
    if not isinstance(s.dtype, CategoricalDtype):
        s = s.astype('category')
    add = [c for c in cats if c not in s.cat.categories]
    return s.cat.add_categories(add) if add else s

def _robust_z(X, eps=1e-9):
    med = np.median(X, axis=0)
    mad = np.median(np.abs(X - med), axis=0)
    return (X - med) / (1.4826 * (mad + eps))

def _dense_slice(X, cols):
    return X[:, cols].toarray() if sparse.issparse(X) else np.asarray(X[:, cols], dtype=float)

def detect_juxta_quantile(
    adata,
    marker_sets: dict,
    layer: str | None = None,          
    z_thresh: float = 2.7,
    q: float = 0.99,                   
    thresholds_override: dict | None = None,  
    allowed_pairs: set | None = None,  
    required_any_of: set | None = None, 
    require_neighbor_support: bool = True,
    neighbors_obsp_key: str = "connectivities",
    min_neighbor_support: int = 2,
    write_to_obs: bool = True,
    obs_pair_key: str = "juxtaposed_pair",
    obs_flag_key: str = "qc_flag",
):
    # Anchors and mapping
    anchor_map = {
        "Epithelial": ["Tumor epithelial","Epithelial","Cancer Cells"],
        "Endothelial": ["Endothelial"],
        "Fibroblast": ["Fibroblasts","Stromal Cells","CAF","Pericytes"],
        "Myeloid": ["Macrophages","Dendritic cells","Monocytes","Neutrophils","Myeloid Cells"],
        "T/NK": ["T cells","NK cells","T and NK Cells","TNK"],
        "B/Plasma": ["B cells","Plasma cells","B Cells","Plasma Cells"],
    }
    marker_sets = {k:[g for g in v if g not in noisy] for k,v in marker_sets.items()}

    var_lower = {g.lower(): i for i, g in enumerate(adata.var_names)}
    def idx_for(names):
        idx=[]
        for g in names:
            j = var_lower.get(g.lower()); 
            if j is not None: idx.append(j)
        return sorted(set(idx))

    anchor_to_idx = {}
    for anc, keys in anchor_map.items():
        genes=[]
        for key in keys:
            if key in marker_sets:
                genes += marker_sets[key]
        anchor_to_idx[anc] = idx_for(genes)

    anchors = list(anchor_to_idx.keys())
    all_idx = sorted({j for L in anchor_to_idx.values() for j in L})
    if not all_idx:
        raise ValueError("No marker genes found for any anchor in var_names.")

    X = adata.layers[layer] if layer is not None else adata.X
    Xsel = _dense_slice(X, all_idx)
    Zsel = _robust_z(Xsel)
    idx_map = {gi:k for k,gi in enumerate(all_idx)}

    counts = {}
    for anc in anchors:
        cols = [idx_map[i] for i in anchor_to_idx[anc]]
        if not cols:
            counts[anc] = np.zeros(adata.n_obs, dtype=int)
            continue
        c = (Zsel[:, cols] > z_thresh).sum(axis=1).astype(int)
        counts[anc] = c
    counts_df = pd.DataFrame(counts, index=adata.obs_names)

    th = counts_df.quantile(q).round().astype(int).to_dict()
    if thresholds_override:
        for k,v in thresholds_override.items():
            if k in th: th[k] = int(v)

    pos = pd.DataFrame({a: (counts_df[a].to_numpy() >= th[a]) for a in anchors}, index=adata.obs_names)
    npos = pos.sum(axis=1).to_numpy()

    cand = np.where(npos == 2)[0]
    pair = np.array([None]*adata.n_obs, dtype=object)
    pos_bool = {a: pos[a].to_numpy() for a in anchors}

    if allowed_pairs is None:
        allowed_pairs = {
            ("Epithelial","Myeloid"),
            ("Epithelial","T/NK"),
            ("Epithelial","Endothelial"),
            ("Endothelial","Myeloid"),
            ("Endothelial","Fibroblast"),
            ("Fibroblast","Myeloid"),
            ("Endothelial","T/NK"),
            ("Myeloid", "T/NK")
        }
    allowed_pairs = {tuple(sorted(p)) for p in allowed_pairs}

    for i in cand:
        on_anc = [a for a in anchors if pos_bool[a][i]]
        A, B = sorted(on_anc)
        if required_any_of and not (A in required_any_of or B in required_any_of):
            continue
        if (A,B) in allowed_pairs:
            pair[i] = f"{A}|{B}"

    is_juxta = pd.Series([p is not None for p in pair], index=adata.obs_names)

    # neighbor support
    if require_neighbor_support and neighbors_obsp_key in adata.obsp:
        G = adata.obsp[neighbors_obsp_key]; 
        G = G.tocsr() if sparse.issparse(G) else G
        keep = np.ones(adata.n_obs, dtype=bool)
        for i in np.where(is_juxta.to_numpy())[0]:
            A,B = pair[i].split("|")
            nbrs = G.indices[G.indptr[i]:G.indptr[i+1]] if sparse.issparse(G) else np.where(G[i] > 0)[0]
            okA = pos_bool[A][nbrs].sum() >= min_neighbor_support
            okB = pos_bool[B][nbrs].sum() >= min_neighbor_support
            if not (okA and okB):
                keep[i] = False
        drop = (~keep) & is_juxta.to_numpy()
        is_juxta.iloc[drop] = False
        for j in np.where(drop)[0]: pair[j] = None

    out = pd.concat([
        counts_df.add_prefix("count_"),
        pos.add_prefix("pos_"),
        pd.Series(pair, index=adata.obs_names, dtype="string", name="pair"),
        pd.Series(is_juxta.values, index=adata.obs_names, name="is_juxtaposed")
    ], axis=1)
    out.attrs["anchor_thresholds"] = th

    if write_to_obs:
        adata.obs[obs_pair_key] = pd.Series(pd.Categorical([pd.NA]*adata.n_obs), index=adata.obs_names)
        new_pairs = sorted(x for x in out["pair"].dropna().unique())
        if new_pairs:
            adata.obs[obs_pair_key] = adata.obs[obs_pair_key].cat.add_categories(new_pairs)
            m = out["pair"].notna().to_numpy()
            adata.obs.loc[m, obs_pair_key] = out.loc[m, "pair"].astype("string").to_numpy()
            adata.obs[obs_pair_key] = adata.obs[obs_pair_key].cat.remove_unused_categories()

        if obs_flag_key not in adata.obs:
            adata.obs[obs_flag_key] = pd.Series("confident", index=adata.obs_names, dtype="category")
        elif not isinstance(adata.obs[obs_flag_key].dtype, CategoricalDtype):
            adata.obs[obs_flag_key] = adata.obs[obs_flag_key].astype("category")
        if "juxtaposed" not in adata.obs[obs_flag_key].cat.categories:
            adata.obs[obs_flag_key] = adata.obs[obs_flag_key].cat.add_categories(["juxtaposed"])

        was_juxta = adata.obs[obs_flag_key].astype(str).eq("juxtaposed").to_numpy()
        now_juxta = out["is_juxtaposed"].to_numpy()
        clear_mask = was_juxta & (~now_juxta)
        adata.obs.loc[clear_mask, obs_flag_key] = "confident"
        set_mask = (~was_juxta) & now_juxta & adata.obs[obs_flag_key].astype(str).eq("confident").to_numpy()
        adata.obs.loc[set_mask, obs_flag_key] = "juxtaposed"
        adata.obs[obs_flag_key] = adata.obs[obs_flag_key].cat.remove_unused_categories()

    return out


In [8]:
# Marker panel for detect_juxta_quantile. Keys are matched against its anchor_map
# names (e.g. 'TNK' feeds the T/NK anchor).
marker_sets = {
    'Tumor epithelial': ['Krt8', 'Krt18', 'Krt19', 'Krt7', 'Muc1'],
    'Fibroblasts':      ['Col1a1', 'Col3a1', 'Acta2', 'Lrrc15'],
    'Endothelial':      ['Cdh5', 'Pecam1', 'Plvap'],
    'Macrophages':      ['C1qa', 'C1qb', 'Arg1', 'Mertk', 'Marco', 'Mrc1', 'Trem2'],
    'TNK':              ['Cd4', 'Cd8a', 'Nkg7', 'Klrb1c', 'Prf1', 'Gzmb'],
    'B cells':          ['Cd79b', 'Ms4a1'],
    'Plasma cells':     ['Igkc', 'Iglc3', 'Xbp1'],
    'Dendritic cells':  ['Clec9a', 'Xcr1', 'Irf8', 'Cd86'],
}


In [None]:
# Minimum number of marker hits (robust z > z_thresh) needed to call a cell positive
# for each lineage anchor. These override the quantile-derived defaults.
thr = {
    "Epithelial": 3,
    "Endothelial": 3,
    "Fibroblast": 3,
    "Myeloid": 6,
    "T/NK": 3,
    "B/Plasma": 4,
}

out = detect_juxta_quantile(
    adata,
    marker_sets=marker_sets,
    layer=None,
    z_thresh=1.5,
    q=0.95,
    thresholds_override=thr,
    allowed_pairs=None,  # use the function's default lineage pairs
    # required_any_of={"Epithelial", "Endothelial"},
    require_neighbor_support=False,
    neighbors_obsp_key="connectivities",
    min_neighbor_support=2,
    write_to_obs=True,
)

print(out.attrs["anchor_thresholds"])
adata.obs[['juxtaposed_pair', 'qc_flag']].value_counts(dropna=False).head(20)


### Functions to find juxtaposed cells on a per-cell-type basis (and to archive old juxtaposed tags)

In [None]:
import re
import numpy as np
import pandas as pd
import scipy.sparse as sp
from pandas.api.types import CategoricalDtype

def _split_pair(pair_str: str):
    toks = [t.strip() for t in re.split(r"\s*[-–—|~]\s*", str(pair_str).strip()) if t.strip()]
    return (toks[0], toks[1]) if len(toks) == 2 else (str(pair_str), None)

def _canon(name: str, markers: dict):
    n = re.sub(r"\s+", " ", str(name)).strip()         
    if n in markers:                                   
        return n
    lower_keys = {k.lower(): k for k in markers.keys()}
    if n.lower() in lower_keys:
        return lower_keys[n.lower()]
    key = _alias.get(n.lower())
    if key in markers:
        return key
    alt = (n[:-1] if n.endswith("s") else n + "s")
    if alt in markers:
        return alt
    if alt.lower() in lower_keys:
        return lower_keys[alt.lower()]
    return None

_alias = {
    # Cancer / epithelial
    "epithelial": "Cancer Cells", "cancer": "Cancer Cells", "cancer cells": "Cancer Cells",

    # B cells
    "b cell": "B Cells", "b cells": "B Cells", "b-cells": "B Cells",

    # T/NK
    "t/nk": "T and NK Cells", "tnk": "T and NK Cells", "t and nk cells": "T and NK Cells",

    # DCs
    "dc": "DCs", "dcs": "DCs",

    # Myeloid
    "myeloid": "Macrophages",
    "macrophage": "Macrophages", "macrophages": "Macrophages",
    "monocyte": "Monocyte", "monocytes": "Monocyte",
    "neutrophil": "Neutrophils", "neutrophils": "Neutrophils",

    # Endothelial
    "endothelium": "Endothelial Cells", "endothelial": "Endothelial Cells",
    "endothelial cells": "Endothelial Cells", "endothelial cell": "Endothelial Cells",
    "endothelial cells ": "Endothelial Cells",

    # Fibroblast
    "fibroblast": "Fibroblast", "fibroblasts": "Fibroblast",
}

_display_map = {
    "Cancer Cells":    "Cancer Cells",
    "Macrophages":     "Macrophages",
    "Monocyte":        "Monocytes",
    "DCs":             "Dendritic Cells",
    "Neutrophils":     "Neutrophils",
    "B Cells":         "B Cells",
    "T and NK Cells":  "T and NK Cells",
    "Endothelial Cells":     "Endothelial Cells",
    "Fibroblast":      "Fibroblasts",
}

_BROAD_LEVELS = ["Cancer Cells","Myeloid Cells","Stromal Cells","T and NK Cells","B Cells","Plasma Cells","Unlabeled"]
_broad_from_display = {
    "Cancer Cells":    "Cancer Cells",
    "Macrophages":     "Myeloid Cells",
    "Monocytes":       "Myeloid Cells",
    "Dendritic Cells": "Myeloid Cells",
    "Neutrophils":     "Myeloid Cells",
    "T and NK Cells":  "T and NK Cells",
    "B Cells":         "B Cells",
    "Endothelial Cells":     "Stromal Cells",
    "Fibroblasts":     "Stromal Cells",
}

def normalize_pairs_in_place(adata, pair_col="juxtaposed_pair", markers=None):
    """Rewrite two-lineage labels in ``adata.obs[pair_col]`` as '<Display A> - <Display B>'.

    Each label is split on '-', '–', '—', '|' or '~'. Both halves are resolved with
    ``_canon`` against the lineage markers and mapped through ``_display_map``.
    Missing values, labels without exactly two parts, and labels with an
    unresolvable half are kept unchanged. Does nothing if ``pair_col`` is absent.

    Parameters
    ----------
    adata : AnnData
    pair_col : str
        obs column holding pair labels. It is replaced by the mapped values.
    markers : dict, optional
        Lineage name -> marker genes used to resolve names. When None, falls back to
        a module-level ``lineage_markers``. NOTE(review): no ``lineage_markers``
        definition is visible before this cell. Confirm where it is meant to come from.

    Raises
    ------
    NameError
        If a two-part label needs resolving but neither ``markers`` nor a global
        ``lineage_markers`` is available.
    """
    if pair_col not in adata.obs:
        return
    s = adata.obs[pair_col].astype("string")

    # Resolve once. The error is raised lazily, only when a label actually needs it,
    # which matches when the old global lookup would have failed.
    lut = markers if markers is not None else globals().get("lineage_markers")

    def _norm_one(p):
        if pd.isna(p):
            return p
        parts = [t.strip() for t in re.split(r"\s*[-–—|~]\s*", p) if t.strip()]
        if len(parts) != 2:
            return p  # leave triads etc. to the triad logic
        if lut is None:
            raise NameError(
                "normalize_pairs_in_place needs lineage markers: pass markers=... "
                "or define a global `lineage_markers`."
            )
        A, B = parts
        a = _canon(A, lut)
        b = _canon(B, lut)
        if not a or not b:
            return p
        return f"{_display_map.get(a, a)} - {_display_map.get(b, b)}"

    adata.obs[pair_col] = s.map(_norm_one)

QC_LEVELS = ["confident","doublet","juxtaposed","ambiguous","low_quality","to_discard"]

def _ensure_cat(series, cats=None):
    if not isinstance(series.dtype, CategoricalDtype):
        series = series.astype("category")
    if cats:
        add = [c for c in cats if c not in series.cat.categories]
        if add:
            series = series.cat.add_categories(add)
    return series

def _canon(name: str, markers: dict):
    n = str(name).strip()
    for cand in (n, n.title(), n.upper()):
        if cand in markers:
            return cand
    key = _alias.get(n.lower())
    return key if key in markers else None

def _split_pair(pair_str: str):
    toks = [t.strip() for t in re.split(r"\s*[-~|]\s*", str(pair_str).strip()) if t.strip()]
    return (toks[0], toks[1]) if len(toks) == 2 else (str(pair_str), None)

def _sorted_pair_label(pair_str: str) -> str:
    a, b = _split_pair(pair_str)
    if b is None:
        return str(pair_str).strip()
    return " - ".join(sorted([a, b], key=lambda x: x.lower()))

def _present_vars(var_names, genes):
    lut = {g.lower(): g for g in var_names}
    keep = [lut[g.lower()] for g in genes if g.lower() in lut]
    miss = [g for g in genes if g.lower() not in lut]
    return keep, miss

def _counts_per_cell(adata, genes, threshold=0.0, use_raw=None):
    """Count, per cell, how many of ``genes`` have expression above ``threshold``.

    ``use_raw=None`` means "use ``adata.raw`` if it exists". Genes are matched
    case-insensitively via ``_present_vars``. If none match, every count is 0.
    Returns an int Series indexed by ``adata.obs_names``.
    """
    if use_raw is None:
        use_raw = adata.raw is not None
    from_raw = use_raw and adata.raw is not None
    source = adata.raw if from_raw else adata
    keep, _ = _present_vars(source.var_names, genes)
    if not keep:
        return pd.Series(0, index=adata.obs_names, dtype=int)
    X = source[:, keep].X
    if sp.issparse(X):
        n_hits = (X > threshold).astype(np.int8).sum(axis=1).A1
    else:
        n_hits = (np.asarray(X) > threshold).sum(axis=1)
    return pd.Series(n_hits.astype(int), index=adata.obs_names)

def _safe(name: str) -> str:
    return re.sub(r"\W+", "_", name)

def archive_legacy_juxtaposition(adata, pair_col="juxtaposed_pair", qc_flag_col="qc_flag", suffix="_tentative"):
    """
    Move old/automatic juxtaposition outputs aside before curated calls are made.

    For ``pair_col``, ``"juxta_call"`` and ``"juxta_primary"``: if the column exists
    and ``<col><suffix>`` does not, it is copied there as a nullable ``string`` column.
    An existing archive is never overwritten. Cells with ``qc_flag == 'juxtaposed'``
    and a non-missing archived pair are demoted to ``'confident'``. The three curated
    columns are then reset to all-missing ``string`` columns. If
    ``uns['juxta_params']`` exists, it is appended to ``uns['juxta_archive']`` and
    removed.
    """
    for col in [pair_col, "juxta_call", "juxta_primary"]:
        archived_col = f"{col}{suffix}"
        if col in adata.obs and archived_col not in adata.obs:
            # astype("string") handles categorical and plain columns alike.
            adata.obs[archived_col] = adata.obs[col].astype("string")

    if qc_flag_col in adata.obs:
        adata.obs[qc_flag_col] = _ensure_cat(adata.obs[qc_flag_col], QC_LEVELS)
        tent = adata.obs.get(f"{pair_col}{suffix}", pd.Series(pd.NA, index=adata.obs_names))
        demote = adata.obs[qc_flag_col].astype(str).eq("juxtaposed") & tent.notna()
        adata.obs.loc[demote, qc_flag_col] = "confident"

    adata.obs[pair_col]         = pd.Series(pd.NA, index=adata.obs_names, dtype="string")
    adata.obs["juxta_call"]     = pd.Series(pd.NA, index=adata.obs_names, dtype="string")
    adata.obs["juxta_primary"]  = pd.Series(pd.NA, index=adata.obs_names, dtype="string")

    if "juxta_params" in adata.uns:
        adata.uns.setdefault("juxta_archive", []).append(adata.uns["juxta_params"])
        del adata.uns["juxta_params"]

def juxta_by_gene_counts(
    adata,
    markers: dict,
    pair_col: str = "juxtaposed_pair",
    qc_flag_col: str = "qc_flag",
    subtype_col: str = "subtype",
    initial_broad_col: str = "initial_broad",   # updated when only one lineage is supported
    cluster_key: str = "louvain",
    min_genes: int = 2,
    expr_threshold: float = 0.0,
    use_raw: bool | None = None,
):
    """
    Re-check juxtaposed cells by counting expressed marker genes of both lineages.

    Applies to every cell with ``qc_flag == 'juxtaposed'`` and a pair label in
    ``pair_col``. The label is split into lineages A and B, and each is resolved to a
    key of ``markers`` via ``_canon``. A lineage is supported when >= ``min_genes``
    of its markers have expression > ``expr_threshold``. Policy:
      both    -> juxta_call = pair label with halves sorted case-insensitively,
                 qc stays 'juxtaposed', subtype = same label
      A_only  -> juxta_call = display(A), qc='ambiguous', subtype=display(A), initial_broad set via broad map
      B_only  -> juxta_call = display(B), qc='ambiguous', subtype=display(B), initial_broad set via broad map
      neither -> juxta_call='neither', qc='to_discard', subtype='to_discard' (initial_broad unchanged)
    ``juxta_primary`` records the lineage with more expressed markers (A on ties).
    Per-lineage counts go to ``juxta_<lineage>_n``.

    Parameters
    ----------
    use_raw : bool or None
        Read expression from ``adata.raw``. None means "use raw if present". The
        same choice drives both the counts and the missing-gene report.

    Returns
    -------
    (pandas.DataFrame, dict)
        The frame has per (pair, cluster, juxta_call) counts ``n``, cluster totals
        ``N`` and ``frac``. The dict maps each lineage to its marker genes not found
        in the chosen var space. The frame is empty when no cell was evaluated, and
        the dict is ``{}`` when there are no juxtaposed cells.
    """
    if qc_flag_col not in adata.obs:
        adata.obs[qc_flag_col] = pd.Series("confident", index=adata.obs_names, dtype="category")
    adata.obs[qc_flag_col] = _ensure_cat(adata.obs[qc_flag_col], QC_LEVELS)

    if pair_col not in adata.obs:
        adata.obs[pair_col] = pd.Series(pd.NA, index=adata.obs_names, dtype="string")
    else:
        adata.obs[pair_col] = adata.obs[pair_col].astype("string")

    if subtype_col not in adata.obs:
        adata.obs[subtype_col] = pd.Series(pd.Categorical([pd.NA]*adata.n_obs), index=adata.obs_names)
    elif not isinstance(adata.obs[subtype_col].dtype, CategoricalDtype):
        adata.obs[subtype_col] = adata.obs[subtype_col].astype("category")

    if initial_broad_col not in adata.obs:
        adata.obs[initial_broad_col] = pd.Series("Unlabeled", index=adata.obs_names, dtype="category")
    adata.obs[initial_broad_col] = _ensure_cat(adata.obs[initial_broad_col], _BROAD_LEVELS)

    for col in ["juxta_call","juxta_primary"]:
        if col not in adata.obs:
            adata.obs[col] = pd.Series(pd.NA, index=adata.obs_names, dtype="string")
        else:
            adata.obs[col] = adata.obs[col].astype("string")

    jmask = adata.obs[qc_flag_col].astype(str).eq("juxtaposed") & adata.obs[pair_col].notna()
    if jmask.sum() == 0:
        print("[INFO] No juxtaposed cells found.")
        return pd.DataFrame(), {}

    # Resolve each distinct pair label into two marker-set keys.
    pairs = adata.obs.loc[jmask, pair_col].astype(str).unique().tolist()
    used, bad_pairs, parsed_pairs = set(), [], []
    for p in pairs:
        A_raw, B_raw = _split_pair(p)
        A = _canon(A_raw, markers); B = _canon(B_raw, markers)
        if not A or not B:
            bad_pairs.append(p); continue
        used.update([A, B]); parsed_pairs.append((p, A, B))
    if bad_pairs:
        print("[WARN] Missing marker sets for:", ", ".join(bad_pairs))

    # Resolve use_raw once (same default as _counts_per_cell: raw whenever it exists).
    # Otherwise use_raw=None would report missing genes against var_names while
    # counting against .raw.
    if use_raw is None:
        use_raw = adata.raw is not None

    missing_genes, counts = {}, {}
    vnames = adata.raw.var_names if (use_raw and adata.raw is not None) else adata.var_names
    for lin in used:
        _, miss = _present_vars(vnames, markers[lin])
        missing_genes[lin] = miss
        counts[lin] = _counts_per_cell(adata, markers[lin], expr_threshold, use_raw)

    rows = []
    for pair_str, A, B in parsed_pairs:
        pmask = jmask & adata.obs[pair_col].astype(str).eq(pair_str)
        if pmask.sum() == 0:
            continue

        a_ct = counts[A].loc[pmask.index][pmask]
        b_ct = counts[B].loc[pmask.index][pmask]
        Apos, Bpos = (a_ct >= min_genes), (b_ct >= min_genes)

        both, A_only, B_only, neither = (Apos & Bpos), (Apos & ~Bpos), (Bpos & ~Apos), (~Apos & ~Bpos)

        # Dominant lineage per cell (A wins ties).
        primary = np.where(a_ct.values >= b_ct.values, A, B)
        adata.obs.loc[pmask, "juxta_primary"] = primary

        pair_sorted = _sorted_pair_label(pair_str)
        needed = [pair_sorted, _display_map.get(A, A), _display_map.get(B, B), "to_discard"]
        adata.obs[subtype_col] = _ensure_cat(adata.obs[subtype_col], needed)

        # BOTH: keep qc 'juxtaposed', set subtype to pair label
        if both.any():
            adata.obs.loc[pmask & both, "juxta_call"] = pair_sorted
            adata.obs.loc[pmask & both, subtype_col]  = pair_sorted

        def _update_initial_broad(mask, display_label):
            # Only display labels with a known broad compartment update initial_broad.
            broad = _broad_from_display.get(display_label)
            if broad is not None:
                adata.obs[initial_broad_col] = _ensure_cat(adata.obs[initial_broad_col], _BROAD_LEVELS)
                adata.obs.loc[pmask & mask, initial_broad_col] = broad

        # SINGLE -> ambiguous + set initial_broad
        if A_only.any():
            labA = _display_map.get(A, A)
            adata.obs.loc[pmask & A_only, "juxta_call"] = labA
            adata.obs.loc[pmask & A_only, subtype_col]  = labA
            adata.obs.loc[pmask & A_only, qc_flag_col]  = "ambiguous"
            _update_initial_broad(A_only, labA)

        if B_only.any():
            labB = _display_map.get(B, B)
            adata.obs.loc[pmask & B_only, "juxta_call"] = labB
            adata.obs.loc[pmask & B_only, subtype_col]  = labB
            adata.obs.loc[pmask & B_only, qc_flag_col]  = "ambiguous"
            _update_initial_broad(B_only, labB)

        # NEITHER -> to_discard (leave initial_broad unchanged)
        if neither.any():
            adata.obs.loc[pmask & neither, "juxta_call"] = "neither"
            adata.obs.loc[pmask & neither, subtype_col]  = "to_discard"
            adata.obs.loc[pmask & neither, qc_flag_col]  = "to_discard"

        adata.obs.loc[pmask, f"juxta_{_safe(A)}_n"] = a_ct.values
        adata.obs.loc[pmask, f"juxta_{_safe(B)}_n"] = b_ct.values

        rows.append(pd.DataFrame({
            "pair":    pair_sorted,
            "cluster": adata.obs.loc[pmask, cluster_key].astype(str).values,
            "juxta_call": adata.obs.loc[pmask, "juxta_call"].astype(str).values
        }))

    if rows:
        df = pd.concat(rows, axis=0, ignore_index=True)
        summary = df.value_counts(["pair","cluster","juxta_call"]).rename("n").reset_index()
        totals  = df.value_counts(["pair","cluster"]).rename("N").reset_index()
        out = summary.merge(totals, on=["pair","cluster"])
        out["frac"] = out["n"] / out["N"]
        adata.uns["juxta_params"] = dict(min_genes=min_genes, expr_threshold=expr_threshold, used_lineages=sorted(used))
        if isinstance(adata.obs[subtype_col].dtype, CategoricalDtype):
            adata.obs[subtype_col] = adata.obs[subtype_col].cat.remove_unused_categories()
        adata.obs[qc_flag_col] = _ensure_cat(adata.obs[qc_flag_col], QC_LEVELS)
        if isinstance(adata.obs[initial_broad_col].dtype, CategoricalDtype):
            adata.obs[initial_broad_col] = adata.obs[initial_broad_col].cat.remove_unused_categories()
        return out.sort_values(["pair","cluster","juxta_call"]).reset_index(drop=True), missing_genes
    else:
        return pd.DataFrame(), missing_genes


# Save AnnDatas for Annotation

### Create the obs columns needed for downstream annotation

In [178]:
from pandas.api.types import CategoricalDtype

# Allowed qc_flag values (list order = category order).
QC_LEVELS = ["confident","doublet","juxtaposed","ambiguous","low_quality","to_discard"]

def ensure_cat(s, cats):
    """Return ``s`` as categorical with every label in ``cats`` available as a category."""
    if not isinstance(s.dtype, CategoricalDtype):
        s = s.astype('category')
    to_add = [c for c in cats if c not in s.cat.categories]
    return s.cat.add_categories(to_add) if to_add else s

# Seed the broad label from the cluster-level annotation, or "Unlabeled" if absent.
# (obs.get('annotated', 'Unlabeled') would return a plain str here, which has no .astype.)
if 'annotated' in adata.obs:
    adata.obs['initial_broad'] = adata.obs['annotated'].astype('category')
else:
    adata.obs['initial_broad'] = pd.Series('Unlabeled', index=adata.obs_names, dtype='category')

# Empty (all-missing) categorical columns, filled later during per-compartment annotation.
if 'subtype' not in adata.obs:
    adata.obs['subtype'] = pd.Series(pd.Categorical([pd.NA]*adata.n_obs), index=adata.obs_names)
if 'subtype_granular' not in adata.obs:
    adata.obs['subtype_granular'] = pd.Series(pd.Categorical([pd.NA]*adata.n_obs), index=adata.obs_names)

# Keep existing QC flags (e.g. from detect_juxta_quantile); otherwise start every cell as 'confident'.
adata.obs['qc_flag'] = adata.obs.get('qc_flag', pd.Series('confident', index=adata.obs_names, dtype='category'))
adata.obs['qc_flag'] = ensure_cat(adata.obs['qc_flag'], QC_LEVELS)

adata.obs['juxtaposed_pair'] = adata.obs.get('juxtaposed_pair', pd.Series(pd.NA, index=adata.obs_names, dtype='string'))
adata.obs['annot_source']    = adata.obs.get('annot_source',    pd.Series(pd.NA, index=adata.obs_names, dtype='string'))


In [179]:
# Sub-clustering below builds neighbour graphs on the scVI latent space. Fail early if it is missing.
if "X_scVI" not in adata.obsm:
    raise ValueError("obsm['X_scVI'] not found in `adata`. Please compute or specify a different use_rep.")

In [11]:
# Output file for each broad compartment that gets its own sub-clustered AnnData.
fname_map = {
    "Cancer Cells":   "adata_JMT_cancer_to_annot.h5ad",
    "Stromal Cells":  "adata_JMT_stromal_to_annot.h5ad",
    "T and NK Cells": "adata_JMT_tnk_to_annot.h5ad",
    "B Cells":        "adata_JMT_bcells_to_annot.h5ad",
    "Myeloid Cells":  "adata_JMT_myeloid_to_annot.h5ad",
    "Plasma Cells":   "adata_JMT_plasma_to_annot.h5ad",
}

# Compartments to process, in the same order as fname_map.
groups = list(fname_map)

In [17]:
from pandas.api.types import is_string_dtype, is_categorical_dtype

def sanitize_nullable_strings_inplace(adata):
    """Replace pandas nullable-string data in ``adata.obs``/``adata.var`` with plain objects.

    - String-dtype columns become ``object`` columns with unchanged values.
    - Categorical columns whose categories use a non-object string dtype (e.g.
      ``string``) are rebuilt with ``object`` categories from the same codes. This
      preserves category order, the ``ordered`` flag and missing values.
    Categoricals that already have ``object`` categories are left untouched, so
    their category order (and therefore plot colour order) is kept.

    Returns the same ``adata`` for chaining.
    """
    for attr in ("obs", "var"):
        df = getattr(adata, attr)
        for col in df.columns:
            s = df[col]
            if pd.api.types.is_string_dtype(s.dtype):
                df[col] = s.astype(object)
            elif isinstance(s.dtype, pd.CategoricalDtype):
                cats = s.cat.categories
                if pd.api.types.is_string_dtype(cats.dtype) and not pd.api.types.is_object_dtype(cats.dtype):
                    # Rebuild from codes instead of astype(object).astype("category"),
                    # which would re-sort the categories and drop unused ones.
                    df[col] = pd.Series(
                        pd.Categorical.from_codes(
                            s.cat.codes.to_numpy(),
                            categories=pd.Index(cats.astype(object), dtype=object),
                            ordered=s.cat.ordered,
                        ),
                        index=s.index,
                        name=s.name,
                    )
    return adata


def process_subset(adata_full, group_label, use_rep="X_scVI",
                   louvain_resolutions=(0.5, 1.0, 1.5, 2.0, 2.5),
                   umap_min_dist=0.3, random_state=0):
    """Subset one broad compartment, re-embed and re-cluster it, and save it for annotation.

    Steps:
    1. Select cells with ``obs["annotated"] == group_label``.
    2. Build a kNN graph on ``obsm[use_rep]`` and compute a UMAP.
    3. Run Louvain at each resolution, storing ``louvain_r<res>`` columns.
    4. Copy the highest-resolution clustering to ``obs["louvain"]``.
    5. Run Wilcoxon marker ranking and a dendrogram on ``louvain``.
    6. Sanitize nullable strings and write the subset to
       ``fname_map[group_label]`` (or a name derived from the label).

    Returns
    -------
    (AnnData, str) or (None, None)
        The processed subset and its output path, or ``(None, None)`` if no cell matches.
    """
    # NA-safe boolean mask. Comparing a nullable "string" Series gives <NA> for missing
    # labels, and an object array containing <NA> cannot be used to index AnnData.
    mask = (
        adata_full.obs["annotated"].astype("string").eq(group_label)
        .fillna(False)
        .to_numpy(dtype=bool)
    )
    n = int(mask.sum())
    if n == 0:
        print(f"[skip] {group_label}: no cells found.")
        return None, None
    sub = adata_full[mask].copy()
    print(f"[{group_label}] n={n}")
    sc.pp.neighbors(sub, use_rep=use_rep)
    sc.tl.umap(sub, min_dist=umap_min_dist, random_state=random_state)

    for res in louvain_resolutions:
        sc.tl.louvain(sub, resolution=float(res), key_added=f"louvain_r{res}", random_state=random_state)

    # The highest resolution becomes the working clustering for annotation.
    primary_key = f"louvain_r{max(louvain_resolutions)}"
    sub.obs["louvain"] = sub.obs[primary_key].copy()
    if isinstance(sub.obs["louvain"].dtype, pd.CategoricalDtype):
        sub.obs["louvain"] = sub.obs["louvain"].cat.remove_unused_categories()

    sc.tl.rank_genes_groups(sub, groupby="louvain", method="wilcoxon")
    sc.tl.dendrogram(sub, groupby="louvain", use_rep=use_rep)

    sanitize_nullable_strings_inplace(sub)

    out = fname_map.get(group_label, f"adata_{group_label.replace(' ','_').lower()}_to_annot.h5ad")
    sub.write_h5ad(out)
    print(f"[{group_label}] saved -> {out}")
    return sub, out


In [None]:
# Process every broad compartment and keep the subset plus its output path.
results = {}
for group in groups:
    subset, path = process_subset(adata, group)
    results[group] = {"adata": subset, "file": path}

print("\nSummary:")
for group, info in results.items():
    if info["adata"] is None:
        print(f" - {group}: skipped (no cells)")
        continue
    print(f" - {group}: {info['adata'].n_obs} cells -> {info['file']}")

# Annotate Myeloid Cells

In [None]:
# Show which file the myeloid subset was written to (it is reloaded in the next cell).
fname_map["Myeloid Cells"]

In [205]:
# Reload the myeloid sub-clustering written by process_subset.
myeloid_adata = sc.read_h5ad(filename=fname_map["Myeloid Cells"])

In [None]:
plt.rcParams['figure.facecolor'] = 'white'

# Myeloid subset UMAP: clusters and treatment, followed by individual marker genes.
sc.pl.umap(
    myeloid_adata,
    color=[
        "louvain", "Treatment",
        "Csf1r", "C1qb", "Marco", "Cd79b", "Plvap", "Cd8a",
        "Clec9a", "Xcr1", "Irf8", "Arg1", "Spp1", "Mrc1", "Folr2", "Trem2", "Il10",
        "Il1b", "Krt18", "Krt19", "Tgfb1", "Il6", "Ccr7", "Acta2", "Col1a1",
        "Hpgd",
    ],
    ncols=4,
    size=1,
    cmap="Reds",
    vmax="p99",
    frameon=False,
    legend_loc="on data",
    legend_fontsize=8,
    legend_fontoutline=1,
    show=True,
)

In [None]:
import numpy as np
import pandas as pd
import scanpy as sc
from pandas.api.types import CategoricalDtype

# Score the myeloid subset against four core lineage signatures, call a lineage
# per cell and per Louvain cluster, and plot the signature genes per cluster.
A = myeloid_adata
# Cluster column for the per-cluster summary and the dotplot; process_subset
# writes its primary Louvain clustering to obs["louvain"].
louvain_key = "louvain"

mono_core   = ["Ly6c2","Ccr2","Sell","Cd14","Cxcr4","Cx3cr1"]
macro_core  = ["C1qa","C1qb","C1qc","Mertk","Mrc1","Trem2","Folr2","Marco","Lyve1",
               "Igf1","Lgals3","Lrp1","Cd68","Mpeg1","Csf1r","Sirpa"]
dc_core     = ["Xcr1","Clec9a","Batf3","Wdfy4","Itgax","Clec10a","Sirpa"]  # cDC1 + general DC cues
neutro_core = ["Csf3r","Cxcr2","Pglyrp1","Sell"]  # tight neutrophil set from your panel
# NOTE: "Sell" (mono + neutro) and "Sirpa" (macro + DC) each appear in two sets,
# which narrows the score gap between those lineages.

def keep_present(glist):
    """Return the genes of ``glist`` that are in ``A.var_names``, in input order."""
    return [g for g in glist if g in A.var_names]

mono_core   = keep_present(mono_core)
macro_core  = keep_present(macro_core)
dc_core     = keep_present(dc_core)
neutro_core = keep_present(neutro_core)

print("mono_core:", mono_core)
print("macro_core:", macro_core)
print("dc_core:", dc_core)
print("neutro_core:", neutro_core)

# Score each signature. score_cols is built here because the calls below index
# A.obs with it and must see exactly this cell's score columns. Empty sets are
# skipped because sc.tl.score_genes raises on an empty gene list.
score_cols = []
for score_name, genes in [
    ("score_mono_core",   mono_core),
    ("score_macro_core",  macro_core),
    ("score_dc_core",     dc_core),
    ("score_neutro_core", neutro_core),
]:
    if len(genes) == 0:
        print(f"[WARN] No panel genes for {score_name}; skipping.")
        continue
    sc.tl.score_genes(A, genes, score_name=score_name)
    score_cols.append(score_name)

def call_per_cell(row, margin=0.1):
    """Return the top-scoring lineage for one cell.

    Missing score columns count as -inf. If the lead over the runner-up is
    below ``margin`` the label gets a " (weak)" suffix.
    """
    scores = {
        "Monocyte-like":   row.get("score_mono_core",   -np.inf),
        "Macrophage-like": row.get("score_macro_core",  -np.inf),
        "DC-like":         row.get("score_dc_core",     -np.inf),
        "Neutrophil-like": row.get("score_neutro_core", -np.inf),
    }
    # rank by score
    items = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
    if items[0][1] - items[1][1] < margin:
        return items[0][0] + " (weak)"
    return items[0][0]

# Per-cell call (margin 0.1).
A.obs["myeloid_lineage_call"] = pd.Categorical(
    [call_per_cell(row) for _, row in A.obs[score_cols].iterrows()]
)

# Per-cluster mean scores; observed=True keeps only clusters present in A.
g = (
    A.obs
    .groupby(louvain_key, observed=True)[score_cols]
    .mean()
    .assign(delta_mono_macro=lambda df: df.get("score_mono_core", 0) - df.get("score_macro_core", 0))
)

def cluster_call(row, thr=0.05):
    """Per-cluster lineage call on mean scores; " (weak)" if the lead is below ``thr``."""
    s = {
        "Monocyte-like":   row.get("score_mono_core",   -np.inf),
        "Macrophage-like": row.get("score_macro_core",  -np.inf),
        "DC-like":         row.get("score_dc_core",     -np.inf),
        "Neutrophil-like": row.get("score_neutro_core", -np.inf),
    }
    top = sorted(s.items(), key=lambda kv: kv[1], reverse=True)
    return top[0][0] if (top[0][1] - top[1][1]) >= thr else top[0][0] + " (weak)"

g["predicted_lineage"] = g.apply(cluster_call, axis=1)

print("\n=== Cluster summary (means) ===")
display_cols = [c for c in ["predicted_lineage","score_mono_core","score_macro_core","score_dc_core","score_neutro_core","delta_mono_macro"] if c in g.columns]
print(g[display_cols].sort_index().to_string())

cluster_map = g["predicted_lineage"].to_dict()
print("\nCluster → lineage call:")
for k, v in cluster_map.items():
    print(f"  {k}: {v}")

# Dotplot of whichever signature sets still have panel genes.
var_sets = {}
if mono_core:   var_sets["Mono core"]   = mono_core
if macro_core:  var_sets["Mac core"]    = macro_core
if dc_core:     var_sets["DC core"]     = dc_core
if neutro_core: var_sets["Neutro core"] = neutro_core

if var_sets:
    sc.pl.dotplot(
        A, var_names=var_sets, groupby=louvain_key,
        standard_scale="var", swap_axes=True, show=False
    )

In [238]:
# Helper defined outside this view; judging by its name it sets aside any earlier
# juxtaposition annotations before this subset is re-annotated — confirm.
archive_legacy_juxtaposition(myeloid_adata)

In [None]:
# Assign myeloid subtypes by Louvain cluster, then flag clusters that look like
# spatially juxtaposed mixtures. Any cluster not listed below (e.g. 12, 14, 18,
# 20, 24, which are flagged as juxtaposed) keeps its cluster id in `subtype`.
myeloid_adata.obs['subtype'] = myeloid_adata.obs['louvain'].cat.add_categories(['Cancer Cells',
                                                                 'Dendritic Cells',
                                                                'Neutrophils',
                                                                'Macrophages',
                                                                'Monocytes',
                                                                 'to_discard'])

# Applied in order, so a cluster listed twice ends up with the later label.
# NOTE(review): "29" is listed as Cancer Cells and as to_discard; the effective
# label is to_discard — confirm which was intended.
subtype_assignments = [
    (["33", "29"], 'Cancer Cells'),
    (["15", "30"], 'Monocytes'),
    (["1"], 'Dendritic Cells'),
    (["26"], 'Neutrophils'),
    (["7", "10", "2", "27", "13",
      "19", "23", "21", "0", "8",
      "25", "11", "9", "6", "4", "32",
      "16", "3", "31", "5", "28"], 'Macrophages'),
    (["17", "22", "29"], 'to_discard'),
]
for clusters, label in subtype_assignments:
    myeloid_adata.obs.loc[myeloid_adata.obs['louvain'].isin(clusters), 'subtype'] = label

# A categorical qc_flag must already contain "juxtaposed", otherwise the .loc
# assignment in mark_pair raises (the T/NK cell adds the category the same way).
if "qc_flag" in myeloid_adata.obs and isinstance(myeloid_adata.obs["qc_flag"].dtype, pd.CategoricalDtype):
    if "juxtaposed" not in myeloid_adata.obs["qc_flag"].cat.categories:
        myeloid_adata.obs["qc_flag"] = myeloid_adata.obs["qc_flag"].cat.add_categories(["juxtaposed"])

# Plain string column so arbitrary pair labels can be written.
if "juxtaposed_pair" in myeloid_adata.obs:
    myeloid_adata.obs["juxtaposed_pair"] = myeloid_adata.obs["juxtaposed_pair"].astype("string")
else:
    myeloid_adata.obs["juxtaposed_pair"] = pd.Series(pd.NA, index=myeloid_adata.obs_names, dtype="string")
lv = myeloid_adata.obs["louvain"].astype(str)

def mark_pair(tags, pair, _adata=myeloid_adata, _labels=lv):
    """Flag cells in clusters ``tags`` as juxtaposed and record the ``pair`` label.

    The target object and cluster labels are bound when the function is defined,
    so later cells that rebind the global ``lv`` cannot change what it marks.
    """
    mask = _labels.isin(tags)
    _adata.obs.loc[mask, "qc_flag"] = "juxtaposed"
    _adata.obs.loc[mask, "juxtaposed_pair"] = pair

mark_pair(["14","18"], "Cancer Cells - Macrophages")
# NOTE(review): 19/23 are also labelled Macrophages above.
mark_pair(["19","23"], "B Cells - Macrophages")
mark_pair(["12","20"], "T and NK Cells - Macrophages")
mark_pair(["24"],      "Endothelial Cells - Monocyte")
normalize_pairs_in_place(myeloid_adata, "juxtaposed_pair")


In [None]:
from pandas.api.types import CategoricalDtype

# juxta_by_gene_counts works on a plain string pair column, not a categorical.
_pair_dtype = myeloid_adata.obs.get("juxtaposed_pair", pd.Series([], dtype="string")).dtype
if isinstance(_pair_dtype, CategoricalDtype):
    myeloid_adata.obs["juxtaposed_pair"] = myeloid_adata.obs["juxtaposed_pair"].astype("string")

# Marker genes per lineage, used to decide which lineages a juxtaposed cluster mixes.
lineage_markers = {
    "Cancer Cells":      "Epcam Krt8 Krt18 Krt19 Muc1 Krt7".split(),
    "Macrophages":       "Csf1r Mrc1 Mpeg1 Cd68 Mertk C1qb C1qa C1qc Trem2 Folr2 Marco Igf1 Lgals3 Lrp1 Sirpa".split(),
    "Monocyte":          "Ly6c2 Ccr2 Sell Cd14 Cxcr4 Cx3cr1".split(),
    "B Cells":           "Ms4a1 Cd79b Igkc Iglc3".split(),
    "DCs":               "Xcr1 Clec9a Batf3 Wdfy4 Itgax Clec10a Sirpa".split(),
    "T and NK Cells":    "Cd3d Cd3e Cd8a Nkg7 Trac Cd4".split(),
    "Tregs":             "Cd4 Ctla4 Foxp3 Tigit".split(),
    "Neutrophils":       "Csf3r Cxcr2 Pglyrp1 Sell".split(),
    "Endothelial Cells": "Pecam1 Kdr Cdh5 Plvap".split(),
    "Fibroblast":        "Col1a1 Col1a2 Col3a1 Col5a1 Col5a2 Col6a1 Pdgfra Pdgfrb Postn Fn1 Thy1 Vim".split(),
}

summary, missing = juxta_by_gene_counts(
    myeloid_adata,
    lineage_markers,
    pair_col="juxtaposed_pair",
    qc_flag_col="qc_flag",
    subtype_col="subtype",
    cluster_key="louvain",
    min_genes=3,
    expr_threshold=0.1,
    use_raw=True,
)

print(summary)
nonempty_missing = [(lin, miss) for lin, miss in missing.items() if len(miss) > 0]
if nonempty_missing:
    print("Missing genes (ignored):")
    for lin, miss in nonempty_missing:
        print(f"  {lin}: {', '.join(miss)}")


In [245]:
import pandas as pd
from pandas.api.types import CategoricalDtype

# qc_flag levels that are always made available as categories on the master object.
QC_LEVELS = ["confident","doublet","juxtaposed","ambiguous","low_quality","to_discard"]


def merge_subset_back(
    master,
    subset,
    cols=("subtype","subtype_granular","qc_flag","juxtaposed_pair","juxta_call"),
    cat_cols=("subtype","subtype_granular","qc_flag"),
    str_cols=("juxtaposed_pair","juxta_call"),
):
    """Copy annotation columns from a re-annotated subset back onto ``master``.

    Parameters
    ----------
    master : AnnData-like
        Needs ``obs`` (DataFrame), ``obs_names`` and ``n_obs``. Modified in place
        (columns created / converted, overlapping rows overwritten) and returned.
    subset : AnnData-like
        Source of the annotations. Its ``obs`` is NOT modified: missing columns
        and dtype conversions are applied to a private copy.
    cols : columns to transfer for the cells present in both objects.
    cat_cols : columns created as (empty) categoricals when missing.
    str_cols : columns transferred as pandas ``"string"`` dtype; every other
        column in ``cols`` is transferred as a categorical.

    Returns
    -------
    master
    """
    idx = master.obs_names.intersection(subset.obs_names)
    if len(idx) == 0:
        print("[WARN] No overlapping cells between master and subset.")
        return master

    # Private copy so the caller's subset keeps its columns and dtypes.
    sub_obs = subset.obs.copy()

    def _empty_col(index, categorical):
        # All-missing column of the right flavour for `index`.
        if categorical:
            return pd.Series(pd.Categorical([pd.NA] * len(index)), index=index)
        return pd.Series(pd.NA, index=index, dtype="string")

    for col in cols:
        if col not in sub_obs:
            sub_obs[col] = _empty_col(sub_obs.index, col in cat_cols)
        if col not in master.obs:
            master.obs[col] = _empty_col(master.obs_names, col in cat_cols)

    if "qc_flag" in cat_cols and "qc_flag" in master.obs:
        if not isinstance(master.obs["qc_flag"].dtype, CategoricalDtype):
            master.obs["qc_flag"] = master.obs["qc_flag"].astype("category")
        need = [c for c in QC_LEVELS if c not in master.obs["qc_flag"].cat.categories]
        if need:
            master.obs["qc_flag"] = master.obs["qc_flag"].cat.add_categories(need)

    for col in cols:
        if col in str_cols:
            master.obs[col] = master.obs[col].astype("string")
            master.obs.loc[idx, col] = sub_obs.loc[idx, col].astype("string")
            continue

        if not isinstance(master.obs[col].dtype, CategoricalDtype):
            master.obs[col] = master.obs[col].astype("category")
        src = sub_obs[col]
        if not isinstance(src.dtype, CategoricalDtype):
            src = src.astype("string").astype("category")

        # Master must know every label the subset uses before the dtype cast below,
        # otherwise unknown labels would silently become NaN.
        known = master.obs[col].cat.categories
        new_cats = [c for c in src.cat.categories if c not in known]
        if new_cats:
            master.obs[col] = master.obs[col].cat.add_categories(new_cats)

        master.obs.loc[idx, col] = src.loc[idx].astype(master.obs[col].dtype)

    return master


In [243]:
# Write the myeloid subtype / QC annotations back onto the full object.
adata = merge_subset_back(master=adata, subset=myeloid_adata)

# Annotate Macrophages

In [25]:
# Macrophage-only copy of the myeloid subset, re-clustered in the next cells.
is_macrophage = myeloid_adata.obs['subtype'].isin(["Macrophages"])
mac_adata = myeloid_adata[is_macrophage].copy()

In [None]:
import inspect

louvain_resolutions = [0.5, 1.0, 1.5, 2.0, 2.5]

# One kNN graph + UMAP on the scVI latent space, shared by all resolutions.
sc.pp.neighbors(mac_adata, use_rep='X_scVI')
sc.tl.umap(mac_adata, min_dist=0.3, random_state=0)

# sc.tl.dendrogram's output-key argument name differs between scanpy versions.
_dend_key_arg = "key" if "key" in inspect.signature(sc.tl.dendrogram).parameters else "key_added"

rgg_keys = {}
for res in louvain_resolutions:
    lv_key = f"louvain_r{res}"
    sc.tl.louvain(mac_adata, resolution=float(res), key_added=lv_key, random_state=0)
    mac_adata.obs[lv_key] = mac_adata.obs[lv_key].astype('category').cat.remove_unused_categories()

    # Per-resolution marker ranking and cluster dendrogram under their own keys.
    rgg_keys[res] = f"rgg_{lv_key}"
    sc.tl.rank_genes_groups(mac_adata, groupby=lv_key, method="wilcoxon", key_added=rgg_keys[res])
    sc.tl.dendrogram(mac_adata, groupby=lv_key, use_rep="X_scVI",
                     **{_dend_key_arg: f"dendrogram_{lv_key}"})

print("Done. Rank-genes keys:", rgg_keys)


In [None]:
# Add one finer Louvain resolution; rgg_keys from the previous cell is extended.
louvain_resolutions = [3.0]
_dend_key_arg = "key" if "key" in inspect.signature(sc.tl.dendrogram).parameters else "key_added"

for res in louvain_resolutions:
    lv_key = f"louvain_r{res}"
    sc.tl.louvain(mac_adata, resolution=float(res), key_added=lv_key, random_state=0)
    mac_adata.obs[lv_key] = mac_adata.obs[lv_key].astype('category').cat.remove_unused_categories()

    rgg_keys[res] = f"rgg_{lv_key}"
    sc.tl.rank_genes_groups(mac_adata, groupby=lv_key, method="wilcoxon", key_added=rgg_keys[res])
    sc.tl.dendrogram(mac_adata, groupby=lv_key, use_rep="X_scVI",
                     **{_dend_key_arg: f"dendrogram_{lv_key}"})

print("Done. Rank-genes keys:", rgg_keys)


In [None]:
# Inspect the macrophage subset's expression matrix (shown as the cell output).
mac_matrix = mac_adata.X
mac_matrix

In [82]:
# Marker panels for macrophage sub-states (dotplot groups; keys are display labels).
# NOTE(review): some labels do not match their genes ('M2 Selenop' has no Selenop,
# 'M1 S100a8' has no S100a8, 'Ifit3' lists monocyte genes) — confirm.
Mac_annotations = {
    'Cycling M':  'Mki67 Top2a Stmn1'.split(),
    'Clearing M': 'Ccl5 Cd3d'.split(),
    'Il1b':       'Ifng Il1b Il1a Il1rn Cxcl2 Ccl3 Ccl4 Areg Tnf'.split(),
    'M2 Col1a1':  'Col1a1 Col1a2 Col3a1 Il33'.split(),
    'M2 Mmp9':    'Ckb Mmp9'.split(),
    'MHC2':       'H2-Ab1'.split(),
    'M2 MARCO':   'Marco Spp1'.split(),
    'M2 Selenop': 'Ccl3'.split(),
    'M2 Cxcl10':  'Cxcl10 Isg15 Irf7 Cxcl9 Cd274'.split(),
    'M1 S100a8':  'Vcan Il1b'.split(),
    'Complement': 'C1qa C1qb C1qc Arg1'.split(),
    'Trem2':      'Trem2 Cd9 Mrc1 Il10'.split(),
    'Ifit3':      'Ccl6 Ly6c2 Ccr2'.split(),
    'Folr2':      'Folr2'.split(),
    'Mertk':      'Mertk Axl'.split(),
    'Spp1':       'Spp1 Vegfa Thbs1'.split(),
}

# Functional gene modules (dotplot groups).
modules = {
    "Cycling":              "Mki67 Top2a Stmn1".split(),
    "MHCII_AP":             "H2-Ab1".split(),
    "Complement_TRM":       "C1qa C1qb C1qc".split(),
    "Efferocytosis_Scav":   "Trem2 Mertk Axl Marco Folr2 Mrc1 Hmox1 Igf1".split(),
    "Checkpoint_IL10_TGFb": "Cd274 Il10 Tgfb1 Ido1 Arg1".split(),
    "ProInflammatory":      "Il1b Il6 Tnf Cd86 Cd80 Cd40 Cybb Cd14".split(),
    "TypeI_IFN":            "Cxcl10 Cxcl9 Isg15 Irf7".split(),
    "Monocyte_like":        "Ly6c2 Ccr2 Ccl6".split(),
    "ECM_Remodel":          "Col1a1 Col1a2 Col3a1 Mmp9 Il33".split(),
    "PPARg_Lipid":          "Pparg Mrc1 Mertk Marco".split(),
    "Tcell_Contam":         "Cd3d Cd8a Cd4".split(),
}


In [None]:
# Macrophage functional-state signatures scored with sc.tl.score_genes below.
mac_state_dict = {
    "Immunosuppressive Macrophages":
        "Arg1 Cd274 Folr2 Hmox1 Il10 Igf1 Mertk Mrc1 Tgfb1 Trem2 Il4".split(),
    "Pro-inflammatory Macrophages":
        "Cd14 Cd40 Cd80 Cd86 Cybb Epas1 Il1b Il6 Ccr7".split(),
    "Myeloid Regulatory (Immunosuppression)":
        "Arg1 Cd274 Il4 Il10".split(),
    "Myeloid Regulatory (Type I Inflammation)":
        "Il1b Il6 Tnf".split(),
    'Il12_response': (
        "Cxcl10 Ccl12 Iigp1 Serpina3g Cxcl9 Gbp7 Ifi47 "
        "Fam26f Pnp Serpina3f Ifi203 Irf1 Irgm1 Ccl2 Ifi204 "
        "Ifi211 Gbp2 Irf7 Stat1 Igtp Themis2 Ifit2 Gbp5 "
        "Zbp1 Socs1 Eif2ak2 Ifit1 Irgm2"
    ).split(),
}

In [None]:
def present_genes(mac_adata, genes, use_raw=None):
    """Match ``genes`` case-insensitively against the object's gene names.

    Parameters
    ----------
    mac_adata : AnnData-like
        Needs ``var_names``; ``raw`` (may be None) is consulted when raw matching applies.
    genes : iterable of str
        Requested gene symbols, any case.
    use_raw : bool or None
        Which gene namespace to match against. None (default) follows the scoring
        cells in this notebook, which pass ``use_raw = mac_adata.raw is not None``
        to sc.tl.score_genes: match against ``raw.var_names`` when ``raw`` is set,
        otherwise against ``var_names``.

    Returns
    -------
    (keep, missing)
        keep: matched names spelled as in the object, in input order, each once.
        missing: requested genes without a match, as given.
    """
    raw = getattr(mac_adata, "raw", None)
    if use_raw is None:
        use_raw = raw is not None
    names = raw.var_names if (use_raw and raw is not None) else mac_adata.var_names

    lower_map = {g.lower(): g for g in names}
    keep, missing = [], []
    for g in genes:
        hit = lower_map.get(g.lower())
        if hit is None:
            missing.append(g)
        elif hit not in keep:
            keep.append(hit)
    return keep, missing

# Score every macrophage state signature, remember the score columns and
# report genes that are not on the panel.
use_raw = mac_adata.raw is not None
score_cols, missing_report = [], {}

for label, genes in mac_state_dict.items():
    keep, missing = present_genes(mac_adata, genes)
    if missing:
        missing_report[label] = missing
    if not keep:
        print(f"[WARN] No genes present for {label}; skipping.")
        continue
    col_name = f"{label}_score"
    sc.tl.score_genes(mac_adata, gene_list=keep, ctrl_size=50, score_name=col_name,
                      use_raw=use_raw, random_state=0)
    score_cols.append(col_name)

if missing_report:
    report_lines = ["Missing genes (ignored):"]
    report_lines += [f"  {lbl}: {', '.join(gl)}" for lbl, gl in missing_report.items()]
    print("\n".join(report_lines))

In [None]:
# Recompute the macrophage state scores (same settings as the previous cell)
# and show them on the UMAP.
use_raw = mac_adata.raw is not None
score_cols = []
missing_report = {}

for label, genes in mac_state_dict.items():
    keep, missing = present_genes(mac_adata, genes)
    if missing:
        missing_report[label] = missing
    if len(keep) > 0:
        score_name = f"{label}_score"
        sc.tl.score_genes(
            mac_adata, gene_list=keep, ctrl_size=50,
            score_name=score_name, use_raw=use_raw, random_state=0,
        )
        score_cols.append(score_name)
    else:
        print(f"[WARN] No genes present for {label}; skipping.")

if missing_report:
    print("Missing genes (ignored):")
    for k, v in missing_report.items():
        print(f"  {k}: {', '.join(v)}")

umap_kwargs = dict(ncols=3, size=2, color_map="coolwarm", vmax="p99.5", vmin=None, frameon=False)
sc.pl.umap(mac_adata, color=score_cols, **umap_kwargs)


In [None]:
# Macrophage sub-state marker panels per Louvain (res 1.5) cluster.
sc.pl.dotplot(
    mac_adata,
    var_names=Mac_annotations,
    groupby="louvain_r1.5",
    standard_scale='var',
    dendrogram=True,
    cmap="Reds",
    swap_axes=False,
)

In [None]:
# Functional gene modules per Louvain (res 1.5) cluster.
dotplot_opts = dict(groupby="louvain_r1.5", standard_scale='var',
                    dendrogram=True, cmap="Reds", swap_axes=False)
sc.pl.dotplot(mac_adata, var_names=modules, **dotplot_opts)

In [None]:
# Clusterings at every resolution, QC / juxtaposition flags and key markers.
mac_umap_keys = [
    "louvain_r1.0", "louvain_r1.5", "louvain_r2.0", "louvain_r2.5", "louvain_r3.0",
    "qc_flag", "juxtaposed_pair",
    'Treatment',
    'Il1b', 'Ifng',
    "Ccr2", "Spp1", "Il1rn",
    'Mki67', 'Top2a',
    'Il18bp', 'Il18',
    'Arg1', 'Cxcl10', 'Cxcl9', 'Ccl5', 'Il4', 'Il10', 'Ido1', 'Tgfb1', 'Marco',
    'Col1a1',
    'Axl', 'Mertk', 'Tgfbr2', "Krt18",
    'Folr2', 'Mrc1', 'Trem2',
]
sc.pl.umap(
    mac_adata,
    color=mac_umap_keys,
    size=2,
    legend_loc="on data",
    legend_fontsize=8,
    cmap="Reds",
    legend_fontoutline=1,
    ncols=4,
    show=True,
    frameon=False,
)

In [None]:
print("Overlap cells master↔mac:", len(adata.obs_names.intersection(mac_adata.obs_names)))

gran_cats = [
    'Folr2 High Immunosuppressive Macrophages',
    'Folr2 Low Immunosuppressive Macrophages',
    'Folr2- Myeloid Regulatory Immunosuppressive Macrophages',
    'Pro-inflammatory Macrophages',
    'Type 1 Inflammation Myeloid Regulatory Macrophages',
    'Spp1 Macrophages',
    'to_discard',
]
# Start all-missing so any cluster not listed below stays NA.
mac_adata.obs['subtype_granular'] = pd.Series(
    pd.Categorical([pd.NA] * mac_adata.n_obs, categories=gran_cats),
    index=mac_adata.obs_names,
)

lv = mac_adata.obs['louvain_r1.5'].astype(str)
granular_by_clusters = [
    (["13", "5"], 'Type 1 Inflammation Myeloid Regulatory Macrophages'),
    (["16"], 'Folr2 High Immunosuppressive Macrophages'),
    (["8"], 'Folr2 Low Immunosuppressive Macrophages'),
    # NOTE(review): an earlier comment said cluster 14 has "too few cells, must
    # discard", yet it is labelled 'Spp1 Macrophages', not 'to_discard' — confirm.
    (["14"], 'Spp1 Macrophages'),
    (["7", "6", "19"], 'Folr2- Myeloid Regulatory Immunosuppressive Macrophages'),
    (["12", "17", "9", "11", "3", "2", "0", "15", "18", "1", "10", "4"], 'Pro-inflammatory Macrophages'),
]
for clusters, label in granular_by_clusters:
    mac_adata.obs.loc[lv.isin(clusters), 'subtype_granular'] = label

In [None]:
A = adata

# A cell is dropped when any of these annotation columns reads "to_discard".
discard_cols = ["initial_broad", "subtype", "subtype_granular", "qc_flag"]
flags = [
    A.obs[c].astype("string").eq("to_discard").fillna(False).to_numpy().astype(bool)
    for c in discard_cols
    if c in A.obs
]
mask_discard = np.logical_or.reduce(flags) if flags else np.zeros(A.n_obs, dtype=bool)

print(f"[filter] dropping {int(mask_discard.sum())} / {A.n_obs} cells (to_discard in any of {discard_cols})")

A = A[~mask_discard].copy()

# Drop now-empty categories and stale colour palettes for the filtered columns.
for c in discard_cols:
    if c in A.obs and isinstance(A.obs[c].dtype, CategoricalDtype):
        A.obs[c] = A.obs[c].cat.remove_unused_categories()
    A.uns.pop(f"{c}_colors", None)

adata = A


In [None]:
# Subtype labels (stripped, lower-cased) that belong to a non-myeloid broad lineage.
# NOTE: the "TNK" key can never match because labels are lower-cased before lookup.
bmap_norm = {
    "cancer cells":        "Cancer Cells",
    "endothelial":         "Stromal Cells",
    "endothelial cells":   "Stromal Cells",
    "fibroblast":          "Stromal Cells",
    "fibroblasts":         "Stromal Cells",
    "b cells":             "B Cells",
    "b cell":              "B Cells",
    "t and nk cells":      "T and NK Cells",
    "TNK":                 "T and NK Cells",
    "tnk":                 "T and NK Cells",
    "t/nk":                "T and NK Cells",
}

broad_levels = ["Cancer Cells","Myeloid Cells","Stromal Cells","T and NK Cells","B Cells","Plasma Cells","Unlabeled"]
broad_col = "initial_broad"

# Make sure initial_broad exists, is categorical and offers every broad level.
if broad_col not in adata.obs:
    adata.obs[broad_col] = pd.Series("Unlabeled", index=adata.obs_names, dtype="category")
if not isinstance(adata.obs[broad_col].dtype, CategoricalDtype):
    adata.obs[broad_col] = adata.obs[broad_col].astype("category")
need = [lvl for lvl in broad_levels if lvl not in adata.obs[broad_col].cat.categories]
if need:
    adata.obs[broad_col] = adata.obs[broad_col].cat.add_categories(need)

# Broad label implied by each cell's subtype (NaN when the subtype is unmapped).
subtype_raw = adata.obs.get("subtype", pd.Series(pd.NA, index=adata.obs_names, dtype="string"))
target_broad = subtype_raw.astype("string").str.strip().str.lower().map(bmap_norm)

# Only cells currently labelled "Myeloid Cells" are moved.
fix_mask = target_broad.notna() & adata.obs[broad_col].astype(str).eq("Myeloid Cells")
n_fix = int(fix_mask.sum())

new_cats = [c for c in target_broad[fix_mask].dropna().unique() if c not in adata.obs[broad_col].cat.categories]
if new_cats:
    adata.obs[broad_col] = adata.obs[broad_col].cat.add_categories(list(new_cats))

if n_fix:
    adata.obs.loc[fix_mask, broad_col] = target_broad[fix_mask].astype("string").values

adata.obs[broad_col] = adata.obs[broad_col].cat.remove_unused_categories()
adata.uns.pop("initial_broad_colors", None)  # palette no longer matches the categories

print(f"[initial_broad] updated {n_fix} cells from 'Myeloid Cells' based on subtype.")


# T and NK Cells

In [47]:
# T / NK subset of the full object.
is_tnk = adata.obs['initial_broad'].isin(["T and NK Cells"])
tnk_adata = adata[is_tnk].copy()

In [None]:
louvain_resolutions = [0.5, 1.0, 1.5, 2.0, 2.5]

# One kNN graph + UMAP on the scVI latent space, shared by all resolutions.
sc.pp.neighbors(tnk_adata, use_rep='X_scVI')
sc.tl.umap(tnk_adata, min_dist=0.3, random_state=0)

# sc.tl.dendrogram's output-key argument name differs between scanpy versions.
_dend_key_arg = "key" if "key" in inspect.signature(sc.tl.dendrogram).parameters else "key_added"

rgg_keys = {}
for res in louvain_resolutions:
    lv_key = f"louvain_r{res}"
    sc.tl.louvain(tnk_adata, resolution=float(res), key_added=lv_key, random_state=0)
    tnk_adata.obs[lv_key] = tnk_adata.obs[lv_key].astype('category').cat.remove_unused_categories()

    rgg_keys[res] = f"rgg_{lv_key}"
    sc.tl.rank_genes_groups(tnk_adata, groupby=lv_key, method="wilcoxon", key_added=rgg_keys[res])
    sc.tl.dendrogram(tnk_adata, groupby=lv_key, use_rep="X_scVI",
                     **{_dend_key_arg: f"dendrogram_{lv_key}"})

print("Done. Rank-genes keys:", rgg_keys)


In [None]:
# Clusterings, treatment and T / NK state markers (Col1a1 / Krt18 as contamination checks).
tnk_umap_keys = [
    "louvain_r0.5", "louvain_r1.0", "louvain_r1.5", "louvain_r2.0", 'Cd3d',
    'Treatment', 'Col1a1',
    'Cd8a', 'Cd4', 'Foxp3',
    "Ccr7", "Tcf7", "Sell", "Tigit", "Ctla4", "Gzmb", "Ifng",
    'Nkg7', "Cd69", "Cd38", "Tnf", "Prf1", "Krt18",
]
sc.pl.umap(
    tnk_adata,
    color=tnk_umap_keys,
    size=2,
    legend_loc="on data",
    legend_fontsize=8,
    cmap="Reds",
    vmax="p99.9",
    legend_fontoutline=1,
    ncols=4,
    show=True,
    frameon=False,
)

In [None]:
# T / NK state signatures, scored per cell and shown on the UMAP.
tnk_dict = {
    "CD8 T Cells": ["Cd8a"],
    "CD4 T Cells": ["Cd4"],
    "Naive/Memory-like T": ["Ccr7", "Tcf7", "Sell", "Il7r", "Cd44", "Itgae"],
    "Regulatory T Cells (Treg)": ["Foxp3", "Il2ra", "Ctla4", "Tnfrsf18"],
    'Activation Treg': ['Il10'],
    "Cytotoxic T Cells": ["Prf1", "Gzmb", "Nkg7", "Ccl5"],
    "Exhausted/Dysfunctional T": ["Tigit", "Ctla4", "Havcr2", "Tox", "Foxp3"],
    "Th1": ["Tbx21", "Ifng", "Cxcr3", "Stat4", "Il18r1", "Il18rap"],
    "Th2": ["Gata3", "Il4", "Il5"],
    "Activation T Cell": ['Cd69', 'Il2ra', 'Cd38', 'Ifng', 'Gzmb', 'Prf1'],
    'ISGs': ['Isg15'],
    "Tissue-Resident Memory T (Trm)": ["Itgae", "Cxcr6", "Bhlhe41"],
    "Proliferating T": ["Mki67", "Top2a"],
    "NK Cells": ["Nkg7", "Klrk1", "Klrb1c", "Prf1", "Gzmb"]
}

# Restrict each signature to panel genes (keep_from_panel is defined elsewhere in the notebook).
gene_set = {k: keep_from_panel(v) for k, v in tnk_dict.items()}

for name, genes in gene_set.items():
    if len(genes) == 0:
        # sc.tl.score_genes raises on an empty gene list; skip instead of aborting the cell.
        print(f"[WARN] No panel genes for {name}; skipping.")
        continue
    sc.tl.score_genes(tnk_adata, genes, score_name=f"{name}_score")

# Plot only scores that exist, so a skipped signature does not break the figure.
tnk_plot_scores = [
    c for c in [
        "CD8 T Cells_score", "CD4 T Cells_score", "Naive/Memory-like T_score", "Regulatory T Cells (Treg)_score",
        "Activation T Cell_score", "Exhausted/Dysfunctional T_score", "ISGs_score", "NK Cells_score",
        "Tissue-Resident Memory T (Trm)_score",
    ]
    if c in tnk_adata.obs
]
sc.pl.umap(tnk_adata, color=tnk_plot_scores, vmax='p99', cmap="coolwarm", vcenter=0)


In [254]:
# Annotate the T / NK subset by Louvain (res 2.0) cluster and flag juxtaposed clusters.
lv = tnk_adata.obs["louvain_r2.0"].astype(str)

# qc_flag must offer "juxtaposed" before mark_pair writes it into the categorical.
qc = tnk_adata.obs["qc_flag"].astype("category")
if "juxtaposed" not in qc.cat.categories:
    qc = qc.cat.add_categories(["juxtaposed"])
tnk_adata.obs["qc_flag"] = qc

# subtype starts from the cluster ids; every cluster listed below is relabelled.
tnk_adata.obs['subtype'] = tnk_adata.obs['louvain_r2.0'].cat.add_categories(['CD8 T cells',
                                                                                'CD4 T cells',
                                                                                "CD4 Naive T cells",
                                                                                "Tregs",
                                                                                "CD8 Memory T cells",
                                                                                "CD8 Activated T cells",
                                                                                'NK Cells',
                                                                                'Dendritic Cells',
                                                                                'to_discard'])

# initial_broad keeps each cell's existing broad label (this subset was selected as
# "T and NK Cells"); building it from the cluster ids would replace that label
# with the cluster number for every cell that is not moved.
ib = tnk_adata.obs["initial_broad"]
if not isinstance(ib.dtype, pd.CategoricalDtype):
    ib = ib.astype("category")
if "Myeloid Cells" not in ib.cat.categories:
    ib = ib.cat.add_categories(["Myeloid Cells"])
tnk_adata.obs["initial_broad"] = ib
# Cluster 11 is annotated Dendritic Cells below (cluster 1 is Tregs), so it is the
# cluster moved to the myeloid broad class.
# NOTE(review): an earlier version of this cell moved cluster "1" here — confirm "11" is intended.
tnk_adata.obs.loc[lv.isin(["11"]), 'initial_broad'] = 'Myeloid Cells'

subtype_assignments = [
    (["17"], 'NK Cells'),
    (["11"], 'Dendritic Cells'),
    (["1"], 'Tregs'),
    (["5"], 'CD4 T cells'),
    (["6", "3"], 'CD8 Memory T cells'),
    (["0", "2", "4", "7", "9", "10", "12", "13", "14",
      "15", "16", "18", "20"], 'CD8 Activated T cells'),
    (["19", "8"], 'to_discard'),  # contam
]
for clusters, label in subtype_assignments:
    tnk_adata.obs.loc[lv.isin(clusters), 'subtype'] = label

tnk_adata.obs["juxtaposed_pair"] = tnk_adata.obs["juxtaposed_pair"].astype("string")

def mark_pair(tags, pair, _adata=tnk_adata, _labels=lv):
    """Flag cells in clusters ``tags`` as juxtaposed and record the ``pair`` label.

    Target object and cluster labels are bound at definition time, so later cells
    that rebind the global ``lv`` do not change what this function marks.
    """
    mask = _labels.isin(tags)
    _adata.obs.loc[mask, "qc_flag"] = "juxtaposed"
    _adata.obs.loc[mask, "juxtaposed_pair"] = pair

mark_pair(["15"], "Endothelial Cells - T and NK Cells")




In [None]:
# Resolve the juxtaposed T / NK clusters with the lineage marker sets defined earlier.
juxta_kwargs = dict(
    pair_col="juxtaposed_pair",
    qc_flag_col="qc_flag",
    subtype_col="subtype",
    cluster_key="louvain_r2.0",
    min_genes=3,
    expr_threshold=0.1,
    use_raw=True,
)
summary, missing = juxta_by_gene_counts(tnk_adata, lineage_markers, **juxta_kwargs)

print(summary)
nonempty_missing = [(lin, miss) for lin, miss in missing.items() if len(miss) > 0]
if nonempty_missing:
    print("Missing genes (ignored):")
    for lin, miss in nonempty_missing:
        print(f"  {lin}: {', '.join(miss)}")


In [260]:
# Write the T / NK subtype / QC annotations back onto the full object.
adata = merge_subset_back(master=adata, subset=tnk_adata)

# Stromal Cells

In [55]:
# Reload the stromal subset saved by process_subset.
stromal_path = fname_map["Stromal Cells"]
stromal_adata = sc.read_h5ad(stromal_path)

In [None]:
louvain_resolutions = [0.5, 1.0, 1.5, 2.0, 2.5]

# One kNN graph + UMAP on the scVI latent space, shared by all resolutions.
sc.pp.neighbors(stromal_adata, use_rep='X_scVI')
sc.tl.umap(stromal_adata, min_dist=0.3, random_state=0)

# sc.tl.dendrogram's output-key argument name differs between scanpy versions.
_dend_key_arg = "key" if "key" in inspect.signature(sc.tl.dendrogram).parameters else "key_added"

rgg_keys = {}
for res in louvain_resolutions:
    lv_key = f"louvain_r{res}"
    sc.tl.louvain(stromal_adata, resolution=float(res), key_added=lv_key, random_state=0)
    stromal_adata.obs[lv_key] = stromal_adata.obs[lv_key].astype('category').cat.remove_unused_categories()

    rgg_keys[res] = f"rgg_{lv_key}"
    sc.tl.rank_genes_groups(stromal_adata, groupby=lv_key, method="wilcoxon", key_added=rgg_keys[res])
    sc.tl.dendrogram(stromal_adata, groupby=lv_key, use_rep="X_scVI",
                     **{_dend_key_arg: f"dendrogram_{lv_key}"})

print("Done. Rank-genes keys:", rgg_keys)


In [None]:
def stromal_dict_from_panel(panel_genes):
    """Build a stromal marker dictionary restricted to genes present on the panel.

    Each marker is given as an ordered group of aliases (human / mouse symbols and
    common alternative names). For every group, the first alias found on the panel
    (case-insensitive) is used, spelled as on the panel.

    Parameters
    ----------
    panel_genes : iterable of str
        Gene names available on the panel (e.g. ``adata.var_names``).

    Returns
    -------
    dict[str, list[str]]
        Cell type -> panel gene names, in marker order, without duplicates.
        Cell types with no panel hits map to an empty list.
    """
    # Upper-case lookup. Sorting makes the winner deterministic if two panel names
    # differ only by case.
    P = {g.upper(): g for g in sorted(panel_genes)}

    # Tuples (not sets) so alias priority is fixed: set iteration order for strings
    # depends on the per-process hash seed and could pick different aliases per kernel.
    G = {
      "Endothelial Cells": [
        ("PECAM1","CD31","PECAM-1","Pecam1"),
        ("FLT1","VEGFR1","Flt1"),
        ("KDR","VEGFR2","FLK1","Kdr"),
        ("PLVAP","Plvap"),
        ("VWF","Vwf"), ("CD34","Cd34"), ("CLDN5","Cldn5"), ("ICAM2","Icam2"), ("ESAM","Esam"),
      ],
      "Pericytes": [
        ("PDGFRB","Pdgfrb"), ("RGS5","Rgs5"), ("CSPG4","NG2","Cspg4"), ("MCAM","CD146","Mcam"),
        ("ACTA2","Acta2"), ("MYH11","Myh11"), ("TAGLN","Tagln"),
      ],
      "Fibroblasts": [
        ("LUM","Lum"), ("DCN","Dcn"), ("COL1A1","Col1a1"), ("COL1A2","Col1a2"), ("COL3A1","Col3a1"),
        ("PDGFRA","Pdgfra"), ("THY1","CD90","Thy1"),
      ],
      "iCAF": [
        ("IL6","Il6"), ("CXCL12","Cxcl12"), ("CXCL1","Cxcl1"), ("CXCL2","Cxcl2"),
        ("FAP","Fap"), ("C3",), ("C7",), ("LIF","Lif"), ("PTGDS","Ptgds"), ("PDPN","Pdpn"),
      ],
      "myCAF": [
        ("ACTA2","Acta2"), ("TAGLN","Tagln"), ("MYH11","Myh11"), ("FAP", "Fap"), ("TGFB", "Tgfb1"),
        ("COL10A1","Col10a1"), ("COL11A1","Col11a1"), ("MYLK","Mylk"),
      ],
      "ASCs": [
        ("APOD","Apod"), ("CFD","Cfd"), ("MGP","Mgp"), ("APOE","Apoe"), ("ADH1B","Adh1b"),
        ("LEPTIN", "Lep"), ("FABP4", "Fabp4"), ("ADIPOQ", "Adipoq"), ("PPARG", "Pparg"),
        ("SLC7A10", "Slc7a10"),
      ],
      "apCAF": [
        ("HLA-DRA","HLA-DR","H2-Aa"), ("HLA-DRB1","H2-Ab1"), ("CD74","Cd74"), ("CIITA","Ciita"),
      ],
    }

    out = {k: [] for k in G}
    for cell_type, alias_groups in G.items():
        for aliases in alias_groups:
            hit = next((P[a.upper()] for a in aliases if a.upper() in P), None)
            if hit and hit not in out[cell_type]:
                out[cell_type].append(hit)
    return out

# Marker dictionary restricted to genes actually on the panel.
panel_genes = set(adata.var_names.tolist())
stromal_markers = stromal_dict_from_panel(panel_genes=panel_genes)
stromal_markers

In [None]:
res = 1.0
lv  = f"louvain_r{res}"
rgg = f"rgg_{lv}"  # rank_genes_groups key for this resolution (not used by the plot)

# Stromal marker panels per Louvain cluster at this resolution.
dotplot_kwargs = dict(
    var_names=stromal_markers,
    groupby=lv,
    standard_scale="var",
    dendrogram=True,
    swap_axes=False,
    figsize=(12, 7),
)
sc.pl.dotplot(stromal_adata, **dotplot_kwargs)

In [262]:
archive_legacy_juxtaposition(stromal_adata)

# Annotate the stromal subset by Louvain (res 1.0) cluster and flag juxtaposed clusters.
stromal_clusters = stromal_adata.obs['louvain_r1.0']
lv = stromal_clusters.astype(str)

# subtype starts from the cluster ids; clusters 0-17 other than 10 are relabelled
# below (10 is flagged as a Cancer Cells - Endothelial Cells juxtaposition).
stromal_adata.obs['subtype'] = stromal_clusters.cat.add_categories(['Fibroblasts',
                                                                    'MyoCAFs',
                                                                    'iCAFs',
                                                                    'Endothelial Cells',
                                                                    'Adipocytes',
                                                                    'Macrophages',
                                                                    'to_discard'])

# initial_broad keeps each cell's existing broad label and only clusters 10 / 17
# are moved; building it from the cluster ids would replace every other cell's
# broad label with its cluster number.
if "initial_broad" in stromal_adata.obs:
    broad = stromal_adata.obs["initial_broad"].astype("category")
else:
    broad = pd.Series(pd.Categorical([pd.NA] * stromal_adata.n_obs), index=stromal_adata.obs_names)
extra = [c for c in ['Cancer Cells', 'Myeloid Cells', 'to_discard'] if c not in broad.cat.categories]
if extra:
    broad = broad.cat.add_categories(extra)
stromal_adata.obs['initial_broad'] = broad

# subtype_granular is only set for cluster 17; every other cell stays missing.
# (merge_subset_back copies this column onto the master object by default, so
# cluster ids here would leak into the master annotation.)
stromal_adata.obs['subtype_granular'] = pd.Series(
    pd.Categorical([pd.NA] * stromal_adata.n_obs,
                   categories=['Folr2+ Immunosuppressive Macrophages', 'to_discard']),
    index=stromal_adata.obs_names,
)

stromal_adata.obs.loc[lv.isin(["10"]), 'initial_broad'] = 'Cancer Cells'
stromal_adata.obs.loc[lv.isin(["17"]), 'initial_broad'] = 'Myeloid Cells'
stromal_adata.obs.loc[lv.isin(["17"]), 'subtype'] = 'Macrophages'
stromal_adata.obs.loc[lv.isin(["17"]), 'subtype_granular'] = 'Folr2+ Immunosuppressive Macrophages'
stromal_adata.obs.loc[lv.isin(["13"]), 'subtype'] = 'Adipocytes'
stromal_adata.obs.loc[lv.isin(["16", "1", "0", "9", "14"]), 'subtype'] = 'to_discard'
stromal_adata.obs.loc[lv.isin(["5", "6", "7"]), 'subtype'] = 'Fibroblasts'
stromal_adata.obs.loc[lv.isin(["12"]), 'subtype'] = 'MyoCAFs'
stromal_adata.obs.loc[lv.isin(["8", "3", "11", "2", "4", "15"]), 'subtype'] = 'Endothelial Cells'

# A categorical qc_flag must offer "juxtaposed" before mark_pair writes it.
if "qc_flag" in stromal_adata.obs and isinstance(stromal_adata.obs["qc_flag"].dtype, pd.CategoricalDtype):
    if "juxtaposed" not in stromal_adata.obs["qc_flag"].cat.categories:
        stromal_adata.obs["qc_flag"] = stromal_adata.obs["qc_flag"].cat.add_categories(["juxtaposed"])

stromal_adata.obs["juxtaposed_pair"] = stromal_adata.obs["juxtaposed_pair"].astype("string")

def mark_pair(tags, pair, _adata=stromal_adata, _labels=lv):
    """Flag cells in clusters ``tags`` as juxtaposed and record the ``pair`` label.

    Target object and cluster labels are bound at definition time, so later cells
    that rebind the global ``lv`` do not change what this function marks.
    """
    mask = _labels.isin(tags)
    _adata.obs.loc[mask, "qc_flag"] = "juxtaposed"
    _adata.obs.loc[mask, "juxtaposed_pair"] = pair

mark_pair(["10"], "Cancer Cells - Endothelial Cells")
mark_pair(["0"], "Fibroblast - Macrophage")
mark_pair(["8"], "Endothelial Cells - T and NK Cells")
mark_pair(["6"], "Cancer Cells - Fibroblasts")
mark_pair(["1"], "Endothelial - Fibroblast - Macrophage")

### Juxtaposition resolver (handles pairs and, where applicable, triplets of cell types)

In [None]:
# NEW JUXTA FUNCTION TO HANDLE 3+ JUXTAPOSED CELL TYPES
def juxta_by_gene_counts(
    adata,
    markers: dict,
    pair_col: str = "juxtaposed_pair",
    qc_flag_col: str = "qc_flag",
    subtype_col: str = "subtype",
    initial_broad_col: str = "initial_broad",
    cluster_key: str = "louvain",
    min_genes: int = 2,
    expr_threshold: float = 0.0,
    use_raw: bool | None = None,
    display_map: dict | None = None,
    broad_map: dict | None = None,
    alias: dict | None = None,
    qc_levels: list | None = None,
    broad_levels: list | None = None,
):
    import re
    import numpy as np
    import pandas as pd
    import scipy.sparse as sp
    from pandas.api.types import CategoricalDtype

    if qc_levels is None:
        qc_levels = ["confident","doublet","juxtaposed","ambiguous","low_quality","to_discard"]
    if broad_levels is None:
        broad_levels = ["Cancer Cells","Myeloid Cells","Stromal Cells","T and NK Cells","B Cells","Plasma Cells","Unlabeled"]
    if display_map is None:
        display_map = {k: k for k in markers.keys()}
        display_map.update({
            "Macrophage":"Macrophages",
            "Monocyte":"Monocytes",
            "DCs":"Dendritic Cells",
            "TNK":"T and NK Cells",
            "B cells":"B Cells",
            "Endothelial Cells":"Endothelial Cells",
            "Fibroblast":"Fibroblasts",
        })
    if broad_map is None:
        broad_map = {
            "Cancer Cells":"Cancer Cells",
            "Macrophages":"Myeloid Cells",
            "Monocytes":"Myeloid Cells",
            "Dendritic Cells":"Myeloid Cells",
            "Neutrophils":"Myeloid Cells",
            "B cells":"B Cells",
            "T and NK Cells":"T and NK Cells",
            "Tregs":"T and NK Cells",
            "Endothelial":"Stromal Cells",
            "Fibroblasts":"Stromal Cells",
        }
    if alias is None:
        alias = {k.lower(): k for k in markers.keys()}
        alias.update({
            "cancer cells":"Cancer Cells","epithelial":"Cancer Cells","cancer":"Cancer Cells",
            "b cell":"B cells","b cells":"B cells","b-cells":"B cells",
            "tnk":"TNK","t/nk":"TNK","t and nk cells":"TNK",
            "dc":"DCs","dcs":"DCs",
            "macrophage":"Macrophage","macrophages":"Macrophage","myeloid":"Macrophage",
            "monocytes":"Monocyte","monocyte":"Monocyte",
            "neutrophils":"Neutrophils","neutrophil":"Neutrophils",
            "endothelium":"Endothelial Cells","endothelial cell":"Endothelial Cells"
            ,"endothelial cells":"Endothelial Cells","endothelial":"Endothelial Cells",
            "fibroblast":"Fibroblast","fibroblasts":"Fibroblast",
        })

    def _ensure_cat(series, cats=None):
        if not isinstance(series.dtype, CategoricalDtype):
            series = series.astype("category")
        if cats:
            add = [c for c in cats if c not in series.cat.categories]
            if add:
                series = series.cat.add_categories(add)
        return series

    def _safe(name: str) -> str:
        return re.sub(r"\W+", "_", str(name))

    def _canon(tok: str) -> str | None:
        t = str(tok).strip()
        for cand in (t, t.title(), t.upper()):
            if cand in markers:
                return cand
        return alias.get(t.lower())

    def _split_lineages(pair_str: str):
        # supports A-B, A|B|C, A ~ B 
        return [t.strip() for t in re.split(r"\s*[-~|]\s*", str(pair_str).strip()) if t.strip()]

    def _present_vars(vnames, genes):
        lut = {g.lower(): g for g in vnames}
        keep = [lut[g.lower()] for g in genes if g.lower() in lut]
        miss = [g for g in genes if g.lower() not in lut]
        return keep, miss

    def _counts_per_cell(adata, genes, threshold=0.0, use_raw=None):
        if use_raw is None:
            use_raw = adata.raw is not None
        vnames = adata.raw.var_names if (use_raw and adata.raw is not None) else adata.var_names
        keep, _ = _present_vars(vnames, genes)
        if not keep:
            return pd.Series(0, index=adata.obs_names, dtype=int)
        X = (adata.raw[:, keep].X if (use_raw and adata.raw is not None) else adata[:, keep].X)
        if sp.issparse(X):
            hits = (X > threshold).astype(np.int8).sum(axis=1).A1
        else:
            X = np.asarray(X)
            hits = (X > threshold).sum(axis=1)
        return pd.Series(hits.astype(int), index=adata.obs_names)

    if qc_flag_col not in adata.obs:
        adata.obs[qc_flag_col] = pd.Series("confident", index=adata.obs_names, dtype="category")
    adata.obs[qc_flag_col] = _ensure_cat(adata.obs[qc_flag_col], qc_levels)

    if pair_col not in adata.obs:
        adata.obs[pair_col] = pd.Series(pd.NA, index=adata.obs_names, dtype="string")
    else:
        adata.obs[pair_col] = adata.obs[pair_col].astype("string")

    if subtype_col not in adata.obs:
        adata.obs[subtype_col] = pd.Series(pd.Categorical([pd.NA]*adata.n_obs), index=adata.obs_names)
    elif not isinstance(adata.obs[subtype_col].dtype, CategoricalDtype):
        adata.obs[subtype_col] = adata.obs[subtype_col].astype("category")

    if initial_broad_col not in adata.obs:
        adata.obs[initial_broad_col] = pd.Series("Unlabeled", index=adata.obs_names, dtype="category")
    adata.obs[initial_broad_col] = _ensure_cat(adata.obs[initial_broad_col], broad_levels)

    for col in ["juxta_call","juxta_primary"]:
        if col not in adata.obs:
            adata.obs[col] = pd.Series(pd.NA, index=adata.obs_names, dtype="string")
        else:
            adata.obs[col] = adata.obs[col].astype("string")

    jmask = adata.obs[qc_flag_col].astype(str).eq("juxtaposed") & adata.obs[pair_col].notna()
    if jmask.sum() == 0:
        print("[INFO] No juxtaposed cells found.")
        return pd.DataFrame(), {}

    pairs_raw = adata.obs.loc[jmask, pair_col].astype(str).unique().tolist()
    parsed, used, bad = [], set(), []
    for p in pairs_raw:
        toks = _split_lineages(p)
        L = [c for c in (_canon(t) for t in toks) if c]
        if len(L) >= 2:
            parsed.append((p, L))
            used.update(L)
        else:
            bad.append(p)
    if bad:
        print("[WARN] Missing marker sets for:", ", ".join(bad))

    vnames = adata.raw.var_names if (use_raw and adata.raw is not None) else adata.var_names
    counts, missing_genes = {}, {}
    for lin in used:
        keep, miss = _present_vars(vnames, markers[lin])
        missing_genes[lin] = miss
        counts[lin] = _counts_per_cell(adata, markers[lin], expr_threshold, use_raw)

    rows = []
    for pair_str, L in parsed:
        pmask = jmask & adata.obs[pair_col].astype(str).eq(pair_str)
        if pmask.sum() == 0:
            continue
        idx = adata.obs.index[pmask]

        C = pd.DataFrame({lin: counts[lin].loc[idx].values for lin in L}, index=idx)
        POS = C.ge(min_genes)
        kpos = POS.sum(axis=1)

        adata.obs.loc[idx, "juxta_call"] = "neither"
        if isinstance(adata.obs[subtype_col].dtype, CategoricalDtype) and "to_discard" not in adata.obs[subtype_col].cat.categories:
            adata.obs[subtype_col] = adata.obs[subtype_col].cat.add_categories(["to_discard"])
        adata.obs.loc[idx, subtype_col] = "to_discard"
        adata.obs.loc[idx, qc_flag_col] = adata.obs.loc[idx, qc_flag_col].where(adata.obs.loc[idx, qc_flag_col].ne("juxtaposed"), "to_discard")
        for lin in L:
            adata.obs.loc[idx, f"juxta_{_safe(lin)}_n"] = C[lin].values

        m2 = kpos.ge(2)
        if m2.any():
            def _top2_labels(row):
                top = row.nlargest(2)
                keys = sorted(top.index.tolist()[:2], key=lambda k: display_map.get(k, k).lower())
                return " - ".join([display_map.get(keys[0], keys[0]), display_map.get(keys[1], keys[1])]), keys[0]
            pair_labels = []
            primaries = []
            for i, r in C.loc[m2].iterrows():
                lab, prim = _top2_labels(r)
                pair_labels.append(lab); primaries.append(prim)
            pair_labels = pd.Series(pair_labels, index=C.loc[m2].index)

            need_pairs = pd.unique(pair_labels.dropna().astype(str)).tolist()
            adata.obs[subtype_col] = _ensure_cat(adata.obs[subtype_col], need_pairs)

            id2 = m2.index[m2]
            adata.obs.loc[id2, "juxta_call"] = pair_labels.astype(str).to_numpy(object)
            adata.obs.loc[id2, subtype_col]  = pair_labels.astype(str).to_numpy(object)
            adata.obs.loc[id2, qc_flag_col]  = "juxtaposed"  # keep as juxtaposed
            adata.obs.loc[id2, "juxta_primary"] = np.array(primaries, dtype=object)

        m1 = kpos.eq(1)
        if m1.any():
            lin1 = POS.loc[m1].idxmax(axis=1)  
            disp = lin1.map(lambda x: display_map.get(x, x))

            need_single = pd.unique(disp.dropna().astype(str)).tolist()
            adata.obs[subtype_col] = _ensure_cat(adata.obs[subtype_col], need_single)

            id1 = m1.index[m1]
            adata.obs.loc[id1, "juxta_call"] = disp.astype(str).to_numpy(object)
            adata.obs.loc[id1, subtype_col]  = disp.astype(str).to_numpy(object)
            adata.obs.loc[id1, qc_flag_col]  = "ambiguous"
            adata.obs.loc[id1, "juxta_primary"] = lin1.astype(str).to_numpy(object)

            broad_vals = disp.map(lambda d: broad_map.get(d, pd.NA)).astype("string")
            need_b = [b for b in pd.unique(broad_vals.dropna()) if b not in adata.obs[initial_broad_col].cat.categories]
            if need_b:
                adata.obs[initial_broad_col] = adata.obs[initial_broad_col].cat.add_categories(need_b)
            adata.obs.loc[id1, initial_broad_col] = broad_vals.values

        idn = kpos.index[kpos.eq(0)]
        if len(idn):
            adata.obs.loc[idn, "juxta_call"] = "neither"
            adata.obs.loc[idn, subtype_col]  = "to_discard"
            adata.obs.loc[idn, qc_flag_col]  = "to_discard"

        rows.append(pd.DataFrame({
            "pair":    pair_str,
            "cluster": adata.obs.loc[idx, cluster_key].astype(str).values,
            "juxta_call": adata.obs.loc[idx, "juxta_call"].astype(str).values
        }))

    if rows:
        df = pd.concat(rows, axis=0, ignore_index=True)
        summary = df.value_counts(["pair","cluster","juxta_call"]).rename("n").reset_index()
        totals  = df.value_counts(["pair","cluster"]).rename("N").reset_index()
        out = summary.merge(totals, on=["pair","cluster"])
        out["frac"] = out["n"] / out["N"]

        adata.uns["juxta_params"] = dict(
            min_genes=min_genes, expr_threshold=expr_threshold,
            used_lineages=sorted(list(used))
        )

        if isinstance(adata.obs[subtype_col].dtype, CategoricalDtype):
            adata.obs[subtype_col] = adata.obs[subtype_col].cat.remove_unused_categories()
        if isinstance(adata.obs[initial_broad_col].dtype, CategoricalDtype):
            adata.obs[initial_broad_col] = adata.obs[initial_broad_col].cat.remove_unused_categories()
        adata.obs[qc_flag_col] = _ensure_cat(adata.obs[qc_flag_col], qc_levels)

        return out.sort_values(["pair","cluster","juxta_call"]).reset_index(drop=True), missing_genes
    else:
        return pd.DataFrame(), missing_genes


In [265]:
lineage_markers = {
    "Cancer Cells": ["Epcam","Krt8","Krt18","Krt19","Muc1","Krt7"],
    "Macrophage":   ["Csf1r","Mrc1","Mpeg1","Cd68","Mertk","C1qb","C1qa","C1qc","Trem2","Folr2","Marco","Igf1","Lgals3","Lrp1","Sirpa"],
    "Monocyte":     ["Ly6c2","Ccr2","Sell","Cd14","Cxcr4","Cx3cr1"],
    "B cells":      ["Ms4a1","Cd79b","Igkc","Iglc3"],
    "DCs":          ["Xcr1","Clec9a","Batf3","Wdfy4","Itgax","Clec10a","Sirpa"],
    "TNK":          ['Prf1', 'Nkg7', 'Gzmb', 'Cd8a', 'Trac', 'Cd3d'],
    "Tregs":        ["Cd4", "Ctla4", "Foxp3", "Tigit"],
    "Neutrophils":  ["Csf3r","Cxcr2","Pglyrp1","Sell"],
    "Endothelial Cells":  ["Pecam1","Kdr","Cdh5","Plvap"],
    "Fibroblast":   ["Col1a1","Col1a2","Col3a1","Col5a1","Col5a2","Col6a1","Pdgfra","Pdgfrb","Postn","Fn1","Thy1","Vim"],
}

# Resolve the stromal juxtaposed clusters; counts are taken from adata.raw when it exists.
summary, missing = juxta_by_gene_counts(
    stromal_adata,
    lineage_markers,
    pair_col="juxtaposed_pair",
    qc_flag_col="qc_flag",
    subtype_col="subtype",
    initial_broad_col="initial_broad",
    cluster_key="louvain_r1.0",
    min_genes=3,
    expr_threshold=0.1,
    use_raw=True
)

# Report the outcome the same way the cancer and B-cell subsets do.
print(summary)
if any(len(v) > 0 for v in missing.values()):
    print("Missing genes (ignored):")
    for lin, miss in missing.items():
        if miss:
            print(f"  {lin}: {', '.join(miss)}")


In [267]:
# Propagate the stromal-subset annotations back into the full object
# (merge_subset_back is defined earlier in the notebook, outside this view).
adata = merge_subset_back(adata, stromal_adata)

In [None]:
cols = [c for c in ("initial_broad", "subtype", "subtype_granular", "qc_flag") if c in adata.obs]
# A cell is dropped when ANY of the annotation columns reads exactly "to_discard"; missing values never match.
to_discard = (
    adata.obs[cols]
    .astype("string")
    .fillna("")
    .eq("to_discard")
    .any(axis=1)
)

n_before = adata.n_obs
adata = adata[~to_discard.to_numpy()].copy()
print(f"[filter] dropped {int(to_discard.sum())}/{n_before} cells (to_discard in any of {cols})")


In [297]:
import pandas as pd
from pandas.api.types import CategoricalDtype, is_extension_array_dtype

def sanitize_obs_for_write(adata):
    """Normalise ``adata.obs`` dtypes in place before writing to .h5ad.

    * nullable ``string`` columns are converted to plain ``object`` columns;
    * categorical columns whose categories have ``string`` dtype get their categories
      re-created as Python ``str`` values;
    * every categorical column drops unused categories.
    Returns None.
    """
    frame = adata.obs

    string_columns = [
        name for name in frame.columns
        if is_extension_array_dtype(frame[name].dtype) and str(frame[name].dtype) == "string"
    ]
    for name in string_columns:
        frame[name] = frame[name].astype(object)

    for name in frame.columns:
        series = frame[name]
        if not isinstance(series.dtype, CategoricalDtype):
            continue
        if str(series.cat.categories.dtype) == "string":
            series = series.cat.rename_categories(lambda value: str(value))
        frame[name] = series.cat.remove_unused_categories()

# Normalise obs dtypes in place, then checkpoint the annotated object to disk.
sanitize_obs_for_write(adata)
adata.write_h5ad("JMT_annot_v1.h5ad")

In [2]:
# Reload the annotated checkpoint; the sections below start from this saved state.
adata = sc.read_h5ad("JMT_annot_v1.h5ad")

# Cancer Cells

In [208]:
# Independent copy of the cells whose granular label is exactly "Cancer Cells", for sub-clustering.
cancer_adata = adata[adata.obs['subtype_granular'].isin(["Cancer Cells"])].copy()

In [None]:
import inspect
louvain_resolutions = [0.5, 1.0, 1.5, 2.0, 2.5]

# One kNN graph + UMAP on the scVI latent space, shared by every Louvain resolution below.
sc.pp.neighbors(cancer_adata, use_rep='X_scVI')
sc.tl.umap(cancer_adata, min_dist=0.3, random_state=0)

# Use whichever output-key kwarg (`key` or `key_added`) this scanpy's sc.tl.dendrogram accepts.
_dendro_key_arg = "key" if "key" in inspect.signature(sc.tl.dendrogram).parameters else "key_added"

rgg_keys = {}
for res in louvain_resolutions:
    lv_key = f"louvain_r{res}"
    sc.tl.louvain(cancer_adata, resolution=float(res), key_added=lv_key, random_state=0)
    cancer_adata.obs[lv_key] = cancer_adata.obs[lv_key].astype('category').cat.remove_unused_categories()

    # Wilcoxon marker ranking, stored under a per-resolution uns key.
    rkey = f"rgg_{lv_key}"
    sc.tl.rank_genes_groups(cancer_adata, groupby=lv_key, method="wilcoxon", key_added=rkey)
    rgg_keys[res] = rkey

    # Cluster dendrogram on X_scVI, also stored per resolution.
    dend_kwargs = {"groupby": lv_key, "use_rep": "X_scVI", _dendro_key_arg: f"dendrogram_{lv_key}"}
    sc.tl.dendrogram(cancer_adata, **dend_kwargs)

print("Done. Rank-genes keys:", rgg_keys)


In [None]:
res = 1.0
lv  = f"louvain_r{res}"
rgg = f"rgg_{lv}"

# Reuse the stored marker ranking when present; compute it otherwise.
if rgg not in cancer_adata.uns:
    sc.tl.rank_genes_groups(cancer_adata, groupby=lv, method="wilcoxon", key_added=rgg)

# The dendrogram must be built for the grouping that is plotted (`lv`) so dendrogram=True uses
# the X_scVI tree of these clusters rather than one for an unrelated resolution.
sc.tl.dendrogram(cancer_adata, groupby=lv, use_rep="X_scVI")
sc.pl.rank_genes_groups_dotplot(
    cancer_adata,
    key=rgg,
    groupby=lv,
    n_genes=20,
    standard_scale="var",
    dendrogram=True,
    swap_axes=False,
    show=False,
)


In [211]:
cancer_adata.obs["juxtaposed_pair"] = cancer_adata.obs["juxtaposed_pair"].astype("string")
lv = cancer_adata.obs["louvain_r1.0"].astype(str)

# qc_flag is expected to be categorical already; append "juxtaposed" to its categories if absent
# so the label can be assigned below.
_cancer_qc_levels = list(cancer_adata.obs["qc_flag"].cat.categories)
if "juxtaposed" not in _cancer_qc_levels:
    _cancer_qc_levels.append("juxtaposed")
cancer_adata.obs["qc_flag"] = cancer_adata.obs["qc_flag"].astype("category").cat.set_categories(_cancer_qc_levels)

def mark_pair(tags, pair):
    """Flag cells of the given louvain_r1.0 clusters as juxtaposed and record their lineage label."""
    selected = lv.isin(tags)
    cancer_adata.obs.loc[selected, "qc_flag"] = "juxtaposed"
    cancer_adata.obs.loc[selected, "juxtaposed_pair"] = pair

for _clusters, _pair in (
    (["3", "9"], "Cancer Cells - T and NK Cells"),
    (["8"], "Cancer Cells - Fibroblasts"),
    (["13"], "Cancer Cells - Endothelial Cells"),
):
    mark_pair(_clusters, _pair)

In [None]:
from pandas.api.types import CategoricalDtype
if isinstance(cancer_adata.obs.get("juxtaposed_pair", pd.Series([], dtype="string")).dtype, CategoricalDtype):
    cancer_adata.obs["juxtaposed_pair"] = cancer_adata.obs["juxtaposed_pair"].astype("string")

# The endothelial lineage is keyed "Endothelial Cells" (as in the stromal / B-cell marker dicts).
# The default alias table resolves "Endothelial Cells" in pair labels to that exact key, so an
# "Endothelial" key made the resolver look up a lineage that is not in this dict.
lineage_markers = {
    "Cancer Cells": ["Krt18","Krt19","Krt7","Muc1","Krt8", "Epcam"],
    "Macrophage":   ["Csf1r","Mrc1","Mpeg1","Cd68","Mertk","C1qb","C1qa","C1qc","Trem2","Folr2","Marco","Igf1","Lgals3","Lrp1","Sirpa"],
    "Monocyte":     ["Ly6c2","Ccr2","Sell","Cd14","Cxcr4","Cx3cr1"],
    "B cells":      ["Ms4a1","Cd79b","Igkc","Iglc3"],
    "DCs":          ["Xcr1","Clec9a","Batf3","Wdfy4","Itgax","Clec10a","Sirpa"],
    "T and NK Cells":          [ "Gzmb", "Cd8a","Prf1", "Nkg7", "Cd4", "Ctla4", "Foxp3", "Tigit"],
    "Tregs":        ["Cd4", "Ctla4", "Foxp3", "Tigit"],
    "Neutrophils":  ["Csf3r","Cxcr2","Pglyrp1","Sell"],
    "Endothelial Cells":  ["Pecam1","Kdr","Cdh5","Plvap"],
    "Fibroblast":   ["Col1a1","Col1a2","Col3a1","Col5a1","Col5a2","Col6a1","Pdgfra","Pdgfrb","Postn","Fn1","Thy1","Vim"],
}

summary, missing = juxta_by_gene_counts(
    cancer_adata,
    lineage_markers,
    pair_col="juxtaposed_pair",
    qc_flag_col="qc_flag",
    subtype_col="subtype",       # resolved labels are written into this column
    cluster_key="louvain_r1.0",
    min_genes=2,
    expr_threshold=0.1,
    use_raw=False                # count from cancer_adata.X, not cancer_adata.raw
)

print(summary)
if any(len(v)>0 for v in missing.values()):
    print("Missing genes (ignored):")
    for lin, miss in missing.items():
        if miss:
            print(f"  {lin}: {', '.join(miss)}")


In [219]:
# Category order matters for plotting, so the full level list is spelled out up front.
# NOTE(review): 'Quiescent De-differentiated Cancer Cells' is declared but never assigned below.
_cancer_subtype_levels = [
    'Proliferating Epithelial Cancer Cells',
    'Metabolically Reprogrammed Proliferating Cancer Cells',
    'Quiescent De-differentiated Cancer Cells',
    'Stem-Like Cancer Cells',
    'IFN-Responsive Epithelial Cancer Cells',
    'EMT Cancer Cells',
    'Transitional De-differentiated Cells',
    'Angiogenesis-Associated Cancer Cells',
    'Juxtaposed Cancer-TNK Cells',
    'to_discard',
]
cancer_adata.obs['subtype'] = cancer_adata.obs['louvain_r1.0'].cat.add_categories(_cancer_subtype_levels)

# NOTE(review): this overwrites `subtype` for every listed cluster, including 3/8/9/13 which the
# juxtaposition step may already have relabelled per cell — confirm the intended cell order.
_cancer_clusters_by_subtype = {
    'Proliferating Epithelial Cancer Cells': ["7", "2", "0"],
    'Metabolically Reprogrammed Proliferating Cancer Cells': ["12"],
    'Stem-Like Cancer Cells': ["1", "4", "6", "5"],
    'IFN-Responsive Epithelial Cancer Cells': ["11"],
    'EMT Cancer Cells': ["8"],
    'Transitional De-differentiated Cells': ["10"],
    'Angiogenesis-Associated Cancer Cells': ["13"],
    'Juxtaposed Cancer-TNK Cells': ["3", "9"],
}
for _label, _clusters in _cancer_clusters_by_subtype.items():
    cancer_adata.obs.loc[cancer_adata.obs['louvain_r1.0'].isin(_clusters), 'subtype'] = _label


In [None]:
# Marker dotplot per cancer subtype. `tumor_set` is a gene list defined earlier in the notebook
# (not visible here). The X_scVI dendrogram is computed for `subtype` so dendrogram=True has a matching tree.
sc.tl.dendrogram(cancer_adata, groupby="subtype", use_rep="X_scVI")
sc.pl.dotplot(cancer_adata, 
             tumor_set,
              groupby="subtype",  
              standard_scale="var",
              dendrogram=True)

In [246]:
# Propagate the cancer-subset annotations back into the full object
# (merge_subset_back is defined earlier in the notebook, outside this view).
adata = merge_subset_back(adata, cancer_adata)

# B and Plasma Cells Cleanup

In [249]:
# Independent copy of B and plasma cells (by granular label) for sub-clustering and juxtaposition cleanup.
b_adata = adata[adata.obs['subtype_granular'].isin(["B Cells", "Plasma Cells"])].copy()

In [None]:
import inspect
louvain_resolutions = [0.5, 1.0, 1.5, 2.0, 2.5]

# Shared kNN graph + UMAP on the scVI latent space for the B/plasma subset.
sc.pp.neighbors(b_adata, use_rep='X_scVI')
sc.tl.umap(b_adata, min_dist=0.3, random_state=0)

# Use whichever output-key kwarg (`key` or `key_added`) this scanpy's sc.tl.dendrogram accepts.
_dendro_key_arg = "key" if "key" in inspect.signature(sc.tl.dendrogram).parameters else "key_added"

rgg_keys = {}
for res in louvain_resolutions:
    lv_key = f"louvain_r{res}"
    sc.tl.louvain(b_adata, resolution=float(res), key_added=lv_key, random_state=0)
    b_adata.obs[lv_key] = b_adata.obs[lv_key].astype('category').cat.remove_unused_categories()

    # Wilcoxon marker ranking, stored under a per-resolution uns key.
    rkey = f"rgg_{lv_key}"
    sc.tl.rank_genes_groups(b_adata, groupby=lv_key, method="wilcoxon", key_added=rkey)
    rgg_keys[res] = rkey

    # Cluster dendrogram on X_scVI, also stored per resolution.
    dend_kwargs = {"groupby": lv_key, "use_rep": "X_scVI", _dendro_key_arg: f"dendrogram_{lv_key}"}
    sc.tl.dendrogram(b_adata, **dend_kwargs)

print("Done. Rank-genes keys:", rgg_keys)


In [None]:
plt.rcParams['figure.facecolor'] = 'white'
# UMAP of the B/plasma subset coloured by louvain_0.2 clusters, Cd79b / Irf4 expression,
# pred_celltype and Treatment. louvain_0.2 and pred_celltype are presumably inherited from the
# parent adata (this subset's own clustering uses louvain_r* keys) — verify.
sc.pl.umap(b_adata,
           color=["louvain_0.2", "Cd79b", "Irf4", "pred_celltype","Treatment"],
           size=1,
           #legend_loc="on data",
           legend_fontsize=8,
           cmap="Reds",
           legend_fontoutline=1,ncols=3, 
           show=True,frameon=False)

In [None]:
res = 0.2
lv  = f"louvain_{res}"
rgg = f"rgg_{lv}"

# Check the object that is ranked and plotted. Checking a different AnnData here skipped the
# computation whenever that other object happened to hold this key, leaving b_adata without it.
if rgg not in b_adata.uns:
    sc.tl.rank_genes_groups(b_adata, groupby=lv, method="wilcoxon", key_added=rgg)

# Dendrogram for the same grouping the dotplot uses.
sc.tl.dendrogram(b_adata, groupby=lv, use_rep="X_scVI")
sc.pl.rank_genes_groups_dotplot(
    b_adata,
    key=rgg,
    groupby=lv,
    n_genes=7,
    standard_scale="var",
    dendrogram=True,
    swap_axes=False,
    show=False,
)


In [282]:
b_adata.obs["juxtaposed_pair"] = b_adata.obs["juxtaposed_pair"].astype("string")
lv = b_adata.obs["louvain_0.2"].astype(str)

# qc_flag is expected to be categorical already; append "juxtaposed" to its categories if absent.
_b_qc_levels = list(b_adata.obs["qc_flag"].cat.categories)
if "juxtaposed" not in _b_qc_levels:
    _b_qc_levels.append("juxtaposed")
b_adata.obs["qc_flag"] = b_adata.obs["qc_flag"].astype("category").cat.set_categories(_b_qc_levels)

def mark_pair(tags, pair):
    """Flag cells of the given louvain_0.2 clusters as juxtaposed and record their lineage label."""
    selected = lv.isin(tags)
    b_adata.obs.loc[selected, "qc_flag"] = "juxtaposed"
    b_adata.obs.loc[selected, "juxtaposed_pair"] = pair

for _clusters, _pair in (
    (["4"], "B Cells - Fibroblasts"),
    (["3"], "B Cells - Endothelial Cells"),
    (["2"], "B Cells - T and NK Cells"),
    (["0"], "B Cells - Macrophages"),
):
    mark_pair(_clusters, _pair)

# Subtype starts as the cluster ID; clean B / plasma / discard clusters are relabelled below.
b_adata.obs['subtype'] = b_adata.obs['louvain_0.2'].cat.add_categories(['B Cells',
                                                                        'Plasma Cells',
                                                                        'to_discard'])
for _cluster, _label in (("1", 'B Cells'), ("5", 'Plasma Cells'), ("6", 'to_discard')):
    b_adata.obs.loc[b_adata.obs['louvain_0.2'].isin([_cluster]), 'subtype'] = _label


In [None]:
from pandas.api.types import CategoricalDtype

# Make sure juxtaposed_pair is a nullable string column rather than a categorical.
_pair_dtype = b_adata.obs.get("juxtaposed_pair", pd.Series([], dtype="string")).dtype
if isinstance(_pair_dtype, CategoricalDtype):
    b_adata.obs["juxtaposed_pair"] = b_adata.obs["juxtaposed_pair"].astype("string")

# Keys mirror the lineage names used in this subset's juxtaposed_pair labels.
lineage_markers = {
    "Cancer Cells": ["Krt18","Krt19","Krt7","Muc1","Krt8", "Epcam"],
    "Macrophages":   ["Csf1r","Mrc1","Mpeg1","Cd68","Mertk","C1qb","C1qa","C1qc","Trem2","Folr2","Marco","Igf1","Lgals3","Lrp1","Sirpa"],
    "Monocyte":     ["Ly6c2","Ccr2","Sell","Cd14","Cxcr4","Cx3cr1"],
    "B Cells":      ["Ms4a1","Cd79b","Igkc","Iglc3"],
    "DCs":          ["Xcr1","Clec9a","Batf3","Wdfy4","Itgax","Clec10a","Sirpa"],
    "T and NK Cells":          [ "Gzmb", "Cd8a","Prf1", "Nkg7", "Cd4", "Ctla4", "Foxp3", "Tigit"],
    "Tregs":        ["Cd4", "Ctla4", "Foxp3", "Tigit"],
    "Neutrophils":  ["Csf3r","Cxcr2","Pglyrp1","Sell"],
    "Endothelial Cells":  ["Pecam1","Kdr","Cdh5","Plvap"],
    "Fibroblasts":   ["Col1a1","Col1a2","Col3a1","Col5a1","Col5a2","Col6a1","Pdgfra","Pdgfrb","Postn","Fn1","Thy1","Vim"],
}

summary, missing = juxta_by_gene_counts(
    b_adata,
    lineage_markers,
    pair_col="juxtaposed_pair",
    qc_flag_col="qc_flag",
    subtype_col="subtype",
    cluster_key="louvain_0.2",
    min_genes=2,
    expr_threshold=0.1,
    use_raw=False,
)

print(summary)
_missing_nonempty = {lin: miss for lin, miss in missing.items() if miss}
if _missing_nonempty:
    print("Missing genes (ignored):")
    for lin, miss in _missing_nonempty.items():
        print(f"  {lin}: {', '.join(miss)}")


In [None]:
from pandas.api.types import CategoricalDtype

# Drop B/plasma cells whose qc_flag is "to_discard" (missing flags are kept).
qc = b_adata.obs.get("qc_flag")
if qc is None:
    drop_mask = np.zeros(b_adata.n_obs, dtype=bool)
else:
    drop_mask = qc.astype("string").fillna("").eq("to_discard").to_numpy(dtype=bool)

# Report against the object actually being filtered.
print(f"[filter] dropping {int(drop_mask.sum())} / {b_adata.n_obs} cells (qc_flag == 'to_discard')")

b_adata = b_adata[~drop_mask, :].copy()

for col in ("subtype", "qc_flag", "initial_broad", "subtype_granular"):
    if col in b_adata.obs and isinstance(b_adata.obs[col].dtype, CategoricalDtype):
        b_adata.obs[col] = b_adata.obs[col].cat.remove_unused_categories()

In [287]:
# Propagate the B/plasma-subset annotations back into the full object
# (merge_subset_back is defined earlier in the notebook, outside this view).
adata = merge_subset_back(adata, b_adata)

# Clean up

In [None]:
from pandas.api.types import is_categorical_dtype  # kept in case later cells use this name
from pandas.api.types import CategoricalDtype

# is_categorical_dtype is deprecated (pandas >= 2.1); test the dtype with isinstance instead.
for _col in ('subtype_granular', 'subtype'):
    if not isinstance(adata.obs[_col].dtype, CategoricalDtype):
        adata.obs[_col] = adata.obs[_col].astype('category')

# For cancer, B and plasma cells the refined `subtype` becomes the granular label.
mask = adata.obs['subtype_granular'].isin(['Cancer Cells', 'B Cells', 'Plasma Cells'])

needed = pd.Index(
    adata.obs.loc[mask, 'subtype'].astype(str).unique()
).difference(adata.obs['subtype_granular'].cat.categories)

if len(needed):
    adata.obs['subtype_granular'] = adata.obs['subtype_granular'].cat.add_categories(needed)

adata.obs.loc[mask, 'subtype_granular'] = adata.obs.loc[mask, 'subtype'].astype(str).values

adata.obs['subtype_granular'] = adata.obs['subtype_granular'].cat.remove_unused_categories()


In [300]:
def is_juxtaposed(cell_type):
    """True when `cell_type` is a non-missing label containing ' - ' (other than 'to_discard')."""
    if pd.isna(cell_type):
        return False
    label = str(cell_type)
    return ' - ' in label and cell_type != 'to_discard'

# Prefix multi-lineage labels ("A - B") with "Juxtaposed ". Labels that already carry the prefix
# are skipped so that re-running this cell does not produce "Juxtaposed Juxtaposed ...".
for col in ['subtype', 'subtype_granular']:
    adata.obs[col] = adata.obs[col].astype(str)  # NOTE: missing values become the string "nan"
    labels = adata.obs[col]
    mask_juxta = labels.apply(is_juxtaposed) & ~labels.str.startswith('Juxtaposed ')
    adata.obs.loc[mask_juxta, col] = 'Juxtaposed ' + adata.obs.loc[mask_juxta, col]

In [None]:
def set_categorical_order_juxta_last(adata, col_name):
    """
    Turn ``adata.obs[col_name]`` into an ordered categorical that lists regular cell-type labels
    first and juxtaposed/mixed labels (containing ' - ', 'Mixed' or 'Juxtaposed') last, each group
    sorted alphabetically, so juxtaposed cells are plotted last. Missing values are not made
    categories. Modifies ``adata`` in place and returns it.
    """
    observed = [str(v) for v in adata.obs[col_name].unique() if pd.notna(v)]

    def _looks_juxtaposed(label):
        return any(marker in label for marker in (' - ', 'Mixed', 'Juxtaposed'))

    juxta_cells = sorted(label for label in observed if _looks_juxtaposed(label))
    regular_cells = sorted(label for label in observed if not _looks_juxtaposed(label))

    adata.obs[col_name] = pd.Categorical(
        adata.obs[col_name],
        categories=regular_cells + juxta_cells,
        ordered=True,
    )

    print(f"Set categorical order for {col_name}: {len(regular_cells)} regular + {len(juxta_cells)} juxtaposed cell types")
    return adata

# Apply the "juxtaposed labels last" category ordering at both annotation levels.
adata = set_categorical_order_juxta_last(adata, 'subtype')
adata = set_categorical_order_juxta_last(adata, 'subtype_granular')

In [324]:
# Final filter: drop any cell still labelled 'to_discard' at either subtype level.
adata = adata[~((adata.obs['subtype'] == 'to_discard') | (adata.obs['subtype_granular'] == 'to_discard')), :].copy()

In [None]:
# Overview UMAP of the cleaned annotation: broad/subtype/granular labels, juxtaposition calls and QC flags.
sc.pl.umap(
    adata,
    color=["initial_broad", "subtype", "subtype_granular", "juxta_call", "qc_flag"],
    cmap="Reds",
    vmax="p99",
    size=1,
 #   legend_loc="on data",
    legend_fontsize=8,
    legend_fontoutline=1,
    ncols=1,
    frameon=False,
    show=True,
)

# Set colors

In [332]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import pandas as pd
from matplotlib.colors import ListedColormap
import seaborn as sns

def create_comprehensive_color_scheme(adata):
    """
    Build color dictionaries for the main cell annotations.

    The schemes produced are:
    1. Distinct base colors for broad cell types ('initial_broad').
    2. Darker/lighter shades of the matching broad color for the regular
       labels in 'subtype' and 'subtype_granular'.
    3. Blended colors for juxtaposed labels ("Juxtaposed A - B" or
       "Juxtaposed A - B - C"), also used for the unprefixed 'juxta_call'
       values ("A - B").
    4. Fixed colors for 'qc_flag' and 'Treatment', plus a Set2 palette for
       'Region_ID' when that column exists.

    Parameters
    ----------
    adata : AnnData
        ``adata.obs`` must contain 'subtype', 'subtype_granular' and
        'juxta_call'; 'Region_ID' is optional.

    Returns
    -------
    dict[str, dict]
        Annotation name -> {category value: hex color string}. Regular labels
        whose broad category cannot be resolved get no entry. Juxtaposed
        labels that cannot be resolved get grey '#808080'.
    """

    broad_base_colors = {
        'Cancer Cells': '#E31A1C',      # Red
        'T and NK Cells': '#1F78B4',    # Blue
        'B Cells': '#33A02C',           # Green
        'Myeloid Cells': '#FF7F00',     # Orange
        'Stromal Cells': '#6A3D9A',     # Purple
        'Plasma Cells': '#FB9A99',      # Light Pink
    }

    def get_broad_category(cell_type):
        """Map a label to a broad category by keyword.

        The first matching group wins, so order matters: for example, any label
        containing 'Proliferating' maps to 'Cancer Cells'.
        """
        if any(cancer in cell_type for cancer in ['Cancer', 'EMT', 'Stem-Like', 'Proliferating', 'IFN-Responsive', 'Angiogenesis', 'Metabolically', 'Transitional']):
            return 'Cancer Cells'
        elif any(t in cell_type for t in ['CD8', 'CD4', 'T cells', 'NK', 'TNK', 'Tregs']):
            return 'T and NK Cells'
        elif 'B Cells' in cell_type or cell_type.startswith('B '):
            return 'B Cells'
        elif any(m in cell_type for m in ['Macrophages', 'Monocytes', 'Dendritic', 'Neutrophils', 'Myeloid']):
            return 'Myeloid Cells'
        elif any(s in cell_type for s in ['Fibroblasts', 'Endothelial', 'Stromal', 'MyoCAFs', 'Adipocytes']):
            return 'Stromal Cells'
        elif 'Plasma' in cell_type:
            return 'Plasma Cells'
        else:
            return 'Other'

    def generate_color_variants(base_color, n_variants, darken_factor=0.3, lighten_factor=0.3):
        """Return ``n_variants`` hex shades of ``base_color``.

        Indices below n//2 are darkened (down to (1 - darken_factor) x base).
        Index n//2 is the base color. Higher indices are lightened toward white,
        by up to ``lighten_factor``. For n == 1, ``base_color`` is returned
        as given.
        """
        base_rgb = mcolors.to_rgb(base_color)
        variants = []

        if n_variants == 1:
            return [base_color]

        for i in range(n_variants):
            if i < n_variants // 2:
                factor = 1 - (darken_factor * (n_variants // 2 - i) / (n_variants // 2))
                variant = tuple(c * factor for c in base_rgb)
            elif i == n_variants // 2:
                variant = base_rgb
            else:
                factor = lighten_factor * (i - n_variants // 2) / (n_variants // 2)
                variant = tuple(min(1.0, c + factor * (1 - c)) for c in base_rgb)

            variants.append(mcolors.to_hex(variant))

        return variants

    def mix_colors(color1, color2, ratio=0.5):
        """Linear RGB blend: ratio * color1 + (1 - ratio) * color2, as hex."""
        rgb1 = np.array(mcolors.to_rgb(color1))
        rgb2 = np.array(mcolors.to_rgb(color2))
        mixed = ratio * rgb1 + (1 - ratio) * rgb2
        return mcolors.to_hex(mixed)

    def get_juxtaposed_color(cell_type):
        """Blend the broad colors of the components of a 'Juxtaposed A - B[ - C]' label.

        Returns None if the label has no 'Juxtaposed' prefix. Returns grey when
        there are not 2-3 components or a component's broad category is unknown.
        """
        if 'Juxtaposed' not in cell_type:
            return None

        components = cell_type.replace('Juxtaposed ', '').split(' - ')

        if len(components) == 2:
            broad1 = get_broad_category(components[0])
            broad2 = get_broad_category(components[1])

            if broad1 in broad_base_colors and broad2 in broad_base_colors:
                return mix_colors(broad_base_colors[broad1], broad_base_colors[broad2])

        elif len(components) == 3:
            broad1 = get_broad_category(components[0])
            broad2 = get_broad_category(components[1])
            broad3 = get_broad_category(components[2])

            if all(b in broad_base_colors for b in [broad1, broad2, broad3]):
                # ~equal thirds: (0.5/0.5 mix) * 0.67 + third * 0.33
                temp_mix = mix_colors(broad_base_colors[broad1], broad_base_colors[broad2])
                return mix_colors(temp_mix, broad_base_colors[broad3], ratio=0.67)

        return '#808080'

    def describe_range(n_regular):
        """Size label printed next to each broad group (informational only)."""
        if n_regular <= 3:
            return "(standard range)"
        if n_regular <= 6:
            return "(medium range)"
        if n_regular <= 10:
            return "(large range)"
        # generate_color_variants uses the same darken/lighten span for every n,
        # so big groups simply get closer-together shades (there is no HSV step).
        return "(extra large range)"

    def build_hierarchical_colors(column, verbose=False):
        """Color one annotation column.

        Regular labels get shades of their broad color, assigned in sorted
        order. Juxtaposed labels get blended colors.
        """
        colors = {}
        by_broad = {}
        for value in adata.obs[column].unique():
            if pd.notna(value):
                by_broad.setdefault(get_broad_category(value), []).append(value)

        for broad_cat, values in by_broad.items():
            regular = [v for v in values if 'Juxtaposed' not in v]
            juxta = [v for v in values if 'Juxtaposed' in v]

            if regular and broad_cat in broad_base_colors:
                if verbose:
                    print(f"  {broad_cat}: {len(regular)} regular subtypes {describe_range(len(regular))}")
                variants = generate_color_variants(broad_base_colors[broad_cat], len(regular))
                for label, color in zip(sorted(regular), variants):
                    colors[label] = color

            for label in juxta:
                mixed_color = get_juxtaposed_color(label)
                if mixed_color:
                    colors[label] = mixed_color

        return colors

    color_schemes = {}

    color_schemes['initial_broad'] = broad_base_colors.copy()

    print("Creating subtype color scheme...")
    subtype_colors = build_hierarchical_colors('subtype', verbose=True)
    color_schemes['subtype'] = subtype_colors

    print("Creating subtype_granular color scheme...")
    color_schemes['subtype_granular'] = build_hierarchical_colors('subtype_granular')

    print("Creating juxta_call color scheme...")
    juxta_call_colors = {}

    for juxta_call in adata.obs['juxta_call'].unique():
        if pd.isna(juxta_call):
            continue
        if juxta_call in subtype_colors:
            juxta_call_colors[juxta_call] = subtype_colors[juxta_call]
        elif 'Juxtaposed' not in juxta_call and ' - ' in juxta_call:
            juxtaposed_name = f'Juxtaposed {juxta_call}'
            if juxtaposed_name in subtype_colors:
                juxta_call_colors[juxta_call] = subtype_colors[juxtaposed_name]
            else:
                # get_juxtaposed_color only understands the prefixed form; passing
                # the bare juxta_call always returned None (-> grey) before.
                mixed_color = get_juxtaposed_color(juxtaposed_name)
                juxta_call_colors[juxta_call] = mixed_color if mixed_color else '#808080'
        else:
            juxta_call_colors[juxta_call] = '#808080'

    color_schemes['juxta_call'] = juxta_call_colors

    qc_colors = {
        'confident': '#D3D3D3',      # Light grey
        'juxtaposed': '#FFA500',     # Orange
        'ambiguous': '#FFD700',      # Gold
        'doublet': '#DC143C',        # Crimson
        'low_quality': '#8B4513',    # Saddle brown
        'to_discard': '#000000'      # Black
    }
    color_schemes['qc_flag'] = qc_colors

    treatment_colors = {
        'Treated': '#2E8B57',        # Sea green
        'Control': '#4682B4'         # Steel blue
    }
    color_schemes['Treatment'] = treatment_colors

    if 'Region_ID' in adata.obs.columns:
        # Drop NaN before sorting: sorting NaN together with strings raises TypeError.
        regions = sorted(r for r in adata.obs['Region_ID'].unique() if pd.notna(r))
        palette = sns.color_palette("Set2", len(regions)) if regions else []
        color_schemes['Region_ID'] = {
            region: mcolors.to_hex(color) for region, color in zip(regions, palette)
        }

    return color_schemes

def apply_color_schemes(adata, color_schemes):
    """
    Store each color scheme in ``adata.uns['<annotation>_colors']`` for scanpy.

    scanpy pairs ``<key>_colors`` with the column's categories by position,
    so the list is built in category order:

    * categorical columns use their existing ``.cat.categories``;
    * other columns use ``pd.Categorical(column).categories`` (sorted unique
      non-null values). This is the order a string column gets when it is
      converted to categorical. First-appearance order (``.unique()``) would
      misalign the colors and add an entry for NaN.

    Categories missing from a scheme get grey '#808080'. Schemes whose
    annotation is not a column of ``adata.obs`` are skipped.

    Parameters
    ----------
    adata : AnnData
    color_schemes : dict[str, dict]
        Annotation name -> {category value: color}.

    Returns
    -------
    AnnData
        The same ``adata``; ``adata.uns`` is modified in place.
    """
    for annotation, colors in color_schemes.items():
        if annotation not in adata.obs.columns:
            continue

        column = adata.obs[annotation]
        if isinstance(column.dtype, pd.CategoricalDtype):
            categories = column.cat.categories
        else:
            categories = pd.Categorical(column).categories

        color_list = [colors[value] if value in colors else '#808080' for value in categories]

        adata.uns[f'{annotation}_colors'] = color_list
        print(f"Applied {len(color_list)} colors for {annotation}")

    return adata

def plot_color_preview(color_schemes, save_path=None):
    """
    Draw one horizontal swatch panel per color scheme.

    Panel heights scale with the number of entries. Long schemes (e.g.
    'subtype_granular') therefore get readable labels instead of being
    squeezed into a fixed-height axis. Within each panel, entries without
    'Juxtaposed' in their label come first, then juxtaposed ones; the
    original order is kept within each group.

    Parameters
    ----------
    color_schemes : dict[str, dict]
        Annotation name -> {label: color}.
    save_path : str or path-like, optional
        If given, the figure is also saved there at 300 dpi.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # Panel height in inches: ~0.2 in per label plus room for title/axis.
    heights = [max(1.5, 0.2 * len(colors) + 0.8) for colors in color_schemes.values()]
    fig, axes = plt.subplots(
        len(heights), 1,
        figsize=(15, sum(heights)),
        gridspec_kw={'height_ratios': heights} if heights else None,
        squeeze=False,
    )

    for ax, (scheme_name, colors) in zip(axes[:, 0], color_schemes.items()):
        # Stable sort on a bool key: regular labels first, juxtaposed last.
        # str() keeps non-string keys (e.g. numeric Region_ID) from raising.
        ordered = sorted(colors.items(), key=lambda item: 'Juxtaposed' in str(item[0]))

        labels = [str(label) for label, _ in ordered]
        swatches = [color for _, color in ordered]
        y_pos = np.arange(len(ordered))

        ax.barh(y_pos, [1] * len(ordered), color=swatches, height=0.8)

        ax.set_yticks(y_pos)
        ax.set_yticklabels(labels, fontsize=8)
        ax.set_xlabel('Color')
        ax.set_title(f'{scheme_name.replace("_", " ").title()} Color Scheme', fontweight='bold')
        ax.set_xlim(0, 1)
        ax.set_xticks([])

        for i, color in enumerate(swatches):
            # White text on dark swatches, black on light ones.
            text_color = 'white' if sum(mcolors.to_rgb(color)) < 1.5 else 'black'
            ax.text(0.5, i, color, ha='center', va='center', fontsize=6, color=text_color)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Color preview saved to {save_path}")

    plt.show()

    return fig

def setup_all_colors(adata, plot_preview=True):
    """
    Build the color schemes for ``adata``, write them into ``adata.uns``, and
    optionally show a swatch preview.

    Parameters
    ----------
    adata : AnnData
    plot_preview : bool, default True
        When True, call ``plot_color_preview`` on the generated schemes.

    Returns
    -------
    tuple
        ``(adata, color_schemes)``: the AnnData returned by
        ``apply_color_schemes`` and the dict from
        ``create_comprehensive_color_scheme``.
    """
    print("Creating comprehensive color scheme...")
    schemes = create_comprehensive_color_scheme(adata)

    print("Applying color schemes to adata...")
    adata = apply_color_schemes(adata, schemes)

    if plot_preview:
        print("Creating color preview plot...")
        plot_color_preview(schemes)

    applied_keys = list(schemes.keys())
    print("Color scheme setup complete!")
    print(f"Applied colors for: {applied_keys}")

    return adata, schemes



In [None]:
# Build the palettes, store them as adata.uns['<key>_colors'] and show the swatch preview.
adata, color_schemes = setup_all_colors(plot_preview=True, adata=adata)

In [None]:
# Same overview after setup_all_colors(): categorical keys now pick up the
# palettes stored in adata.uns['<key>_colors'] (when their length matches
# the number of categories).
umap_keys = ["initial_broad", "subtype", "subtype_granular", "juxta_call", "qc_flag"]
sc.pl.umap(
    adata,
    color=umap_keys,
    ncols=1,
    frameon=False,
    size=1,
    legend_fontsize=8,
    legend_fontoutline=1,
    # cmap / vmax only affect continuous (non-categorical) keys.
    cmap="Reds",
    vmax="p99",
    show=True,
)

In [None]:
def create_color_reference(adata_confident):
    """
    Print the category -> hex color assignments for the main annotations.

    This covers 'initial_broad', 'subtype' and 'subtype_granular', for each
    annotation that is in ``adata_confident.obs`` and has an
    ``<annotation>_colors`` entry in ``.uns``. For each one, print a header
    and, for categorical columns, one "category  color" line per category in
    sorted order. Categories containing 'Juxtaposed' are listed in a
    separate trailing group.

    Colors are paired with categories by position. If the two lengths
    differ, a warning is printed and only the overlapping pairs are shown;
    the old code raised IndexError here.

    Parameters
    ----------
    adata_confident : AnnData
        Any object with ``.obs`` (DataFrame) and ``.uns`` (dict). Despite the
        name, it works for the full ``adata`` as well.
    """
    print("=" * 80)
    print("COLOR REFERENCE")
    print("=" * 80)

    for annotation in ['initial_broad', 'subtype', 'subtype_granular']:
        color_key = f'{annotation}_colors'
        if color_key not in adata_confident.uns or annotation not in adata_confident.obs.columns:
            continue

        print(f"\n{annotation.upper().replace('_', ' ')}:")
        print("-" * 50)

        column = adata_confident.obs[annotation]
        if not hasattr(column, 'cat'):
            continue

        categories = list(column.cat.categories)
        colors = list(adata_confident.uns[color_key])
        if len(colors) != len(categories):
            print(f"  WARNING: {len(categories)} categories but {len(colors)} colors; "
                  "showing positional pairs only")

        regular_cats = []
        juxta_cats = []
        for cat, color in zip(categories, colors):
            if 'Juxtaposed' in str(cat):
                juxta_cats.append((cat, color))
            else:
                regular_cats.append((cat, color))

        for cat, color in sorted(regular_cats):
            print(f"{cat:<45} {color}")

        if juxta_cats:
            print("  --- Juxtaposed Cells ---")
            for cat, color in sorted(juxta_cats):
                print(f"{cat:<45} {color}")

    print("\n" + "=" * 80)
    print("COPY FORMAT: Category Name → Hex Color")
    print("=" * 80)

# NOTE(review): `adata_confident` is not defined in any visible cell of this
# section, so on a fresh kernel this cell would stop with NameError. It is
# guarded here; the next cell produces the same report for `adata`.
if 'adata_confident' in globals():
    create_color_reference(adata_confident)

    print("\n\nSIMPLE COPY-PASTE FORMAT:")
    print("="*50)

    for annotation in ['initial_broad', 'subtype', 'subtype_granular']:
        if f'{annotation}_colors' in adata_confident.uns:
            print(f"\n{annotation}:")
            if hasattr(adata_confident.obs[annotation], 'cat'):
                categories = adata_confident.obs[annotation].cat.categories
                colors = adata_confident.uns[f'{annotation}_colors']

                for cat, color in zip(categories, colors):
                    print(f"{cat}: {color}")
else:
    print("`adata_confident` is not defined in this session; skipping its color reference.")

In [None]:
# Formatted reference, then a plain "category: color" listing in category order.
create_color_reference(adata)

print("\n\nSIMPLE COPY-PASTE FORMAT:")
print("="*50)

for annotation in ['initial_broad', 'subtype', 'subtype_granular']:
    color_key = f'{annotation}_colors'
    if color_key not in adata.uns:
        continue
    print(f"\n{annotation}:")

    column = adata.obs[annotation]
    if not hasattr(column, 'cat'):
        continue

    categories = column.cat.categories
    colors = adata.uns[color_key]
    for cat, color in zip(categories, colors):
        print(f"{cat}: {color}")