In [None]:
import pandas as pd
import anndata as ad
import matplotlib.pyplot as plt
import seaborn as sns
import polars as pl
import mudata as mu
from scanpy import read_h5ad
import os

In [None]:
from scanpy import read_h5ad
import os
# ---- Load the clustered object with cell-type annotations ----
# Directory holding the clustering outputs.
output_dir = 'Intermediate_Files/Clustering/'

# NOTE(review): the file carries a .h5mu (MuData) extension but is read with scanpy's
# read_h5ad and then used as a single AnnData below — confirm the file is AnnData-formatted.
adata_i_filtered = read_h5ad(os.path.join(output_dir, "PBMC_iso_AutoZI_clustered_celltypes_reannotated_AutoZILatent.h5mu"))

In [None]:
# Quick sanity check: print the first 10 feature names to confirm expected "SYMBOL:ENSG:ENST" formatting.
# This stays best-effort (a failure must not stop the notebook), but the failure is
# reported instead of being silently swallowed, so a malformed object is noticed here.
try:
    print("rna.var_names head:", adata_i_filtered.var_names[:10].tolist())
except Exception as exc:
    print(f"[warn] could not preview var_names: {type(exc).__name__}: {exc}")

In [None]:
# Check structure of adata and identify unnecessary parts
# Only raw data, cell-type labels, and color assignments for cell-types are needed
adata_i_filtered

In [None]:
## Slim the object down to free memory: drop clustering sweeps and derived data

# Keys scheduled for removal, grouped by the AnnData slot they live in.
obs_to_drop = [
    '0.04_log_AutoZI', '0.06_log_AutoZI', '0.1_log_AutoZI', '0.14_log_AutoZI', '0.16_log_AutoZI',
    '0.2_log_AutoZI', '0.24_log_AutoZI', '0.26_log_AutoZI', '0.3_log_AutoZI',
    'TCell_Combined', 'NKCell_Combined', 'Monocyte_Combined',
    'BCell_Combined', 'MK_Combined'
]
# NOTE(review): '0.34_log_AutoZI' / '0.36_log_AutoZI' are listed for .uns only, not .obs — confirm.
uns_to_drop = [
    '0.04_log_AutoZI', '0.06_log_AutoZI', '0.06_log_AutoZI_colors', '0.14_log_AutoZI',
    '0.16_log_AutoZI', '0.1_log_AutoZI', '0.24_log_AutoZI', '0.26_log_AutoZI', '0.2_log_AutoZI',
    '0.34_log_AutoZI', '0.36_log_AutoZI', '0.3_log_AutoZI',
]
obsm_to_drop = ['X_pca', 'X_umap']
layers_to_drop = ['denoised', 'log_denoised']

# .obs: drop only the columns that are actually present (absent names are skipped).
obs_columns_present = [name for name in obs_to_drop if name in adata_i_filtered.obs.columns]
adata_i_filtered.obs.drop(columns=obs_columns_present, inplace=True)

# .uns / .obsm / .layers: same "delete if present" rule for each mapping-like slot.
for slot, unwanted_keys in (
    (adata_i_filtered.uns, uns_to_drop),
    (adata_i_filtered.obsm, obsm_to_drop),
    (adata_i_filtered.layers, layers_to_drop),
):
    for key in unwanted_keys:
        if key in slot:
            del slot[key]

print(adata_i_filtered)

In [None]:
# Load in unique counts matrices that have undergone filtering from Step 0
# Both files are tab-separated; later cells treat the first column as the cell identifier
# and every remaining column as one transcript.

pbmc1 = pl.read_csv(
    "InitialFiltering/PBMC_patient0_JUNE_16_2025_bambu_quant_PBMC1_combined_uniqueCounts_transcript.filtered_transposed_expression_matrix.txt",
    separator="\t"
)
pbmc2 = pl.read_csv(
    "InitialFiltering/PBMC_patient0_JUNE_16_2025_bambu_quant_PBMC2_combined_uniqueCounts_transcript.filtered_transposed_expression_matrix.txt",
    separator="\t"
)

In [None]:
import polars as pl
from collections import Counter, defaultdict

## ENSG -> symbol (from adata.var_names: Symbol:ENSG:ENST)
# Tally, for every ENSG id, how often each symbol appears next to it in var_names.
# Names that do not split into exactly three ':'-separated tokens are ignored.
ensg_to_symbol = defaultdict(Counter)
for feature_name in adata_i_filtered.var_names:
    tokens = str(feature_name).split(":")
    if len(tokens) != 3:
        continue
    sym, ensg = tokens[0], tokens[1]
    if ensg.startswith("ENSG"):
        ensg_to_symbol[ensg][sym] += 1

# The "primary" symbol of a gene is the one seen most often for its ENSG id.
ensg_primary_symbol = {
    gene_id: tally.most_common(1)[0][0]
    for gene_id, tally in ensg_to_symbol.items()
}

## marker ENSGs (from adata.var_names using your marker symbols)
marker_set = {
    "CD3D","CD3E","CD3G",  # General T cell markers
    "CD8A","CD8B","GATA3","KLRB1","CCL5",    # Effector CD8 T cell markers
    "CD4","IL2RA","AHR","TNF",   # Effector CD4 T cell markers
    "CCR7","SELL","TCF7",   # Memory T cell markers
    "ITGAE","LEF1","CTLA4","IL7R","CD27",   # Transition T cell markers
    "GZMB","KLRF1","NCAM1","ITGAM","IL2RB",   # Natural Killer cell markers
    "CD22","CD79A","MS4A1","CD19",   # B cell markers
    "FCGR2A","CLEC7A","CD33","LILRB4",   # Monocyte-derived cell markers
    "GP1BA","MPL","ITGA2B"   # Megakaryocyte markers
}

## Collect the ENSG id of every three-token var_name whose symbol is a marker.
marker_ensg = {
    tokens[1]
    for tokens in (str(name).split(":") for name in adata_i_filtered.var_names)
    if len(tokens) == 3 and tokens[0] in marker_set and tokens[1].startswith("ENSG")
}
print(f"[info] marker ENSGs resolved: {len(marker_ensg)}")

## Find ENSGs that have at least one BambuTx in PBMC1 or PBMC2
def ensgs_with_bambu(df: pl.DataFrame) -> set[str]:
    """Return the ENSG ids that appear on the right of a 'BambuTx...|ENSG...' column.

    The first column is skipped; columns without a '|' are ignored.
    """
    genes = set()
    for column in df.columns[1:]:
        transcript_id, sep, gene_id = column.partition("|")
        # `sep` is empty when the name has no '|' at all.
        if sep and transcript_id.startswith("BambuTx") and gene_id.startswith("ENSG"):
            genes.add(gene_id)
    return genes

# Genes carrying a novel Bambu transcript in either sample.
ensg_with_novel_1 = ensgs_with_bambu(pbmc1)
ensg_with_novel_2 = ensgs_with_bambu(pbmc2)
ensg_with_novel_any = ensg_with_novel_1.union(ensg_with_novel_2)
print(f"[info] ENSGs with ≥1 novel (BambuTx) transcript across PBMC1/2: {len(ensg_with_novel_any)}")

In [None]:
# Print shape, head and column names of both matrices to confirm structure and naming.
for caption, frame in (("PBMC1 shape:", pbmc1), ("PBMC2 shape:", pbmc2)):
    print(caption, frame.shape)

for head_caption, columns_caption, frame in (
    ("\nPBMC1 head:", "\nPBMC1 column names:", pbmc1),
    ("\nPBMC2 head:", "\nPBMC2 column names:", pbmc2),
):
    print(head_caption)
    print(frame.head())

    print(columns_caption)
    print(frame.columns)

In [None]:
## Build keep lists and rename maps; KEEP RULES (first matching rule wins):
#  1) Keep ALL Bambu (novel) columns.
#  2) Keep any column whose right token ENSG is in marker_ensg.
#  3) Keep any column whose right token ENSG is in ensg_with_novel_any (ALL transcripts of genes with a novel).
def build_keep_and_rename(df: pl.DataFrame):
    """Classify the transcript columns of `df` and prepare symbol-suffixed names.

    Returns (keep, kept_marker, kept_bambu, kept_all_from_novel, rename_map):
    `keep` starts with the first (cell id) column followed by every kept column in
    original order; the three `kept_*` lists hold the columns kept by each rule;
    `rename_map` maps a kept column to "left|right|SYMBOL" when a primary symbol
    is known for its right-hand token.
    """
    header = df.columns
    keep = [header[0]]
    buckets = {"bambu": [], "marker": [], "novel_gene": []}
    rename_map = {}

    for column in header[1:]:
        # Text before the first '|' is the transcript token, the rest is the gene token.
        left, _, right = str(column).partition("|")

        if "Bambu" in left or "Bambu" in right:
            bucket = "bambu"
            has_symbol = right.startswith("ENSG") and right in ensg_primary_symbol
        elif right in marker_ensg:
            bucket = "marker"
            has_symbol = right in ensg_primary_symbol
        elif right in ensg_with_novel_any:
            bucket = "novel_gene"
            has_symbol = right in ensg_primary_symbol
        else:
            continue  # matches no keep rule

        keep.append(column)
        buckets[bucket].append(column)
        if has_symbol:
            rename_map[column] = f"{left}|{right}|{ensg_primary_symbol[right]}"

    return keep, buckets["marker"], buckets["bambu"], buckets["novel_gene"], rename_map

In [None]:

# Classify the columns of each sample; keepN[0] is the cell-id column, hence the "-1" in the totals.
keep1, kept_marker1, kept_bambu1, kept_all_from_novel1, rename1 = build_keep_and_rename(pbmc1)
keep2, kept_marker2, kept_bambu2, kept_all_from_novel2, rename2 = build_keep_and_rename(pbmc2)

print(f"[info] PBMC1 kept: marker={len(kept_marker1)}  bambu={len(kept_bambu1)}  all_from_novel={len(kept_all_from_novel1)}  total={len(keep1)-1}")
print(f"[info] PBMC2 kept: marker={len(kept_marker2)}  bambu={len(kept_bambu2)}  all_from_novel={len(kept_all_from_novel2)}  total={len(keep2)-1}")

In [None]:
# Apply keep-list to each matrix, and rename columns to append the primary SYMBOL
# The rename step is skipped when the rename map is empty.
pbmc1_filt = pbmc1.select(keep1).rename(rename1) if rename1 else pbmc1.select(keep1)
pbmc2_filt = pbmc2.select(keep2).rename(rename2) if rename2 else pbmc2.select(keep2)

In [None]:
# Harmonize for plotting/concat
def harmonize_columns(df_left: pl.DataFrame, df_right: pl.DataFrame):
    """Give both frames the same feature columns, in the same (sorted) order.

    The first column of `df_left` is taken as the cell/ID column and must exist
    under the same name in `df_right`. Feature columns missing from one frame are
    added as all-zero columns. Returns the two re-indexed frames (left, right).
    """
    id_col = df_left.columns[0]  # Assume first column is the cell/ID column
    union_cols = sorted(set(df_left.columns[1:]) | set(df_right.columns[1:]))

    def reindex(df: pl.DataFrame) -> pl.DataFrame:
        # Snapshot the header once as a set: testing against `df.columns` directly
        # is a list scan per column, i.e. quadratic in the number of transcripts.
        present = set(df.columns)
        missing = [c for c in union_cols if c not in present]
        if missing:
            df = df.hstack([pl.Series(c, [0]*df.height) for c in missing])
        return df.select([id_col] + union_cols)

    return reindex(df_left), reindex(df_right)

pbmc1_filt, pbmc2_filt = harmonize_columns(pbmc1_filt, pbmc2_filt)

# After this, pbmc1_filt and pbmc2_filt should have identical feature columns (same names, same order), enabling safe concat/compare.

In [None]:
# Quick previews to ensure matching formatting
# Each preview lists at most 8 column names from pbmc1_filt; the second '|'-token is the gene id.
print("\n[preview] PBMC1 first 8 kept marker columns:")
print([c for c in pbmc1_filt.columns[1:] if "|ENSG" in c and c.split('|')[1] in marker_ensg][:8])

print("\n[preview] PBMC1 first 8 kept from genes-with-novel (non-Bambu):")
print([c for c in pbmc1_filt.columns[1:]
       if "|ENSG" in c and c.split('|')[1] in ensg_with_novel_any and not c.startswith("BambuTx")][:8])

print("\n[preview] PBMC1 first 8 Bambu columns:")
print([c for c in pbmc1_filt.columns[1:] if c.startswith("BambuTx")][:8])

In [None]:
# Save
# Create the output directory first so a fresh checkout can run this cell
# (the plots directory further down is created the same way).
os.makedirs("Intermediate_Files/Unique_Count_Analyses", exist_ok=True)
pbmc1_filt.write_csv("Intermediate_Files/Unique_Count_Analyses/PBMC1_uniquecounts_filtered_marker_Bambu_withSymbols.txt", separator="\t")
pbmc2_filt.write_csv("Intermediate_Files/Unique_Count_Analyses/PBMC2_uniquecounts_filtered_marker_Bambu_withSymbols.txt", separator="\t")

In [None]:
# Free up memory from unneeded objects
# After this cell the unfiltered matrices are gone: any earlier cell that reads
# pbmc1 / pbmc2 needs the load cell to be re-run first.
import gc
del pbmc1
del pbmc2
gc.collect()

In [None]:
import polars as pl

# Reload the filtered matrices from the files written above (replaces the in-memory frames).
pbmc1_filt = pl.read_csv(
    "Intermediate_Files/Unique_Count_Analyses/PBMC1_uniquecounts_filtered_marker_Bambu_withSymbols.txt",
    separator="\t"
)

pbmc2_filt = pl.read_csv(
    "Intermediate_Files/Unique_Count_Analyses/PBMC2_uniquecounts_filtered_marker_Bambu_withSymbols.txt",
    separator="\t"
)

# Check to confirm structure and contents after round-trip save/load.
print(pbmc1_filt.head())
print(pbmc2_filt.head())

In [None]:
# Inputs: pbmc1_filt, pbmc2_filt (Polars DataFrames), already harmonized as shown.
import re
import numpy as np
import polars as pl
import matplotlib.pyplot as plt
import seaborn as sns

## Pick out BambuTx columns: the token before the first '|' must start with "BambuTx"
def get_bambu_cols(df: pl.DataFrame) -> list[str]:
    """Return the columns of `df` (first column excluded) that name a BambuTx transcript."""
    return [
        column
        for column in df.columns[1:]  # skip first col (CellID)
        if str(column).partition("|")[0].startswith("BambuTx")
    ]

# BambuTx columns seen in each dataset, and their sorted union
bambu_cols_1 = set(get_bambu_cols(pbmc1_filt))
bambu_cols_2 = set(get_bambu_cols(pbmc2_filt))
bambu_cols   = sorted(bambu_cols_1.union(bambu_cols_2))

print(f"[info] BambuTx columns: PBMC1={len(bambu_cols_1)}  PBMC2={len(bambu_cols_2)}  union={len(bambu_cols)}")

# Nothing downstream makes sense without at least one novel transcript.
if len(bambu_cols) == 0:
    raise ValueError("No BambuTx columns found in the filtered unique-count matrices.")

## Ensure missing Bambu columns exist as zeros
def ensure_cols(df: pl.DataFrame, cols: list[str]) -> pl.DataFrame:
    """Return `df` reduced to its first column plus `cols`, in that order.

    Any name in `cols` that `df` lacks is added as an all-zero column first.
    """
    # Set snapshot of the header: a membership test against `df.columns`
    # is a list scan per requested column (quadratic overall).
    present = set(df.columns)
    missing = [c for c in cols if c not in present]
    if missing:
        df = df.hstack([pl.Series(c, [0]*df.height) for c in missing])
    return df.select([df.columns[0]] + cols)  # keep CellID + desired columns

pbmc1b = ensure_cols(pbmc1_filt, bambu_cols)
pbmc2b = ensure_cols(pbmc2_filt, bambu_cols)

## Cast the selected Bambu columns to 64-bit integers (non-strict cast)
def to_numeric(df: pl.DataFrame, cols: list[str]) -> pl.DataFrame:
    """Return `df` with every column in `cols` cast to Int64 using strict=False."""
    as_int64 = pl.col(cols).cast(pl.Int64, strict=False)
    return df.with_columns(as_int64)
pbmc1b = to_numeric(pbmc1b, bambu_cols)
pbmc2b = to_numeric(pbmc2b, bambu_cols)

In [None]:
# New isoform cellular prevalence from unique-counts (PBMC1 + PBMC2)
## For each transcript, count the cells with a value > 0 per dataset, then add the datasets
def per_transcript_cell_counts(df: pl.DataFrame, cols: list[str]) -> dict[str, int]:
    """Map each column in `cols` to the number of rows of `df` where it is > 0."""
    # Summing a boolean column counts its True entries.
    presence_sums = [pl.col(name).gt(0).sum().alias(name) for name in cols]
    return df.select(presence_sums).to_dicts()[0]

counts1 = per_transcript_cell_counts(pbmc1b, bambu_cols)
counts2 = per_transcript_cell_counts(pbmc2b, bambu_cols)

## Per-transcript total across both datasets, aligned with bambu_cols
counts_total = np.array(
    [int(counts1.get(name, 0)) + int(counts2.get(name, 0)) for name in bambu_cols],
    dtype=int,
)

# Summary stats: Prevalence of novel isoforms across cells.
n_transcripts = counts_total.size
median_cells  = int(np.median(counts_total))
min_cells, max_cells = int(counts_total.min()), int(counts_total.max())
print(f"[info] Novel transcripts (BambuTx): {n_transcripts}")
print(f"[info] Cells per transcript: median={median_cells}, range={min_cells}–{max_cells}")

In [None]:
# --- All novel isoforms (BambuTx): number of cells expressing each, sorted descending ---

# Display label: transcript token plus SYMBOL when present, otherwise the gene token
def pretty_label(colname: str) -> str:
    """Turn 'TX|GENE|SYMBOL' into 'TX | SYMBOL' (or 'TX | GENE', or just 'TX').

    The third token wins when it is non-empty, then the second; with neither,
    only the first token is returned.
    """
    tokens = str(colname).split("|")
    transcript = tokens[0]
    symbol = tokens[2] if len(tokens) > 2 else ""
    gene = tokens[1] if len(tokens) > 1 else ""
    suffix = symbol or gene
    if not suffix:
        return transcript
    return f"{transcript} | {suffix}"

labels = bambu_cols
disp_labels = list(map(pretty_label, labels))
counts_total = np.array(
    [int(counts1.get(name, 0)) + int(counts2.get(name, 0)) for name in labels],
    dtype=int,
)

# Prevalence table, most widely expressed transcript first, with a fresh 0..n-1 index.
df_all = pd.DataFrame({
    "Transcript": labels,
    "Display": disp_labels,
    "CellsExpressing": counts_total
})
df_all = df_all.sort_values(by="CellsExpressing", ascending=False)
df_all = df_all.reset_index(drop=True)

print(f"[info] Novel transcripts (BambuTx): {len(df_all)}")
print(df_all.head(10).to_string(index=False))

In [None]:
## Bar Plot: Plot ALL transcripts (horizontal bar chart), most → least

# Figure height grows with the number of rows, bounded to the range 4–18 inches.
rows = len(df_all)
fig_h = max(4, min(0.28 * rows, 18))  # scale height; cap to 18 inches
plt.figure(figsize=(8, fig_h))
ax = sns.barplot(
    data=df_all,
    x="CellsExpressing",
    y="Display",
    color="#377eb8",
    edgecolor="black"
)

# Bar-end labels (compact): the count is written just right of each bar,
# offset by the larger of 5 units or 1% of the bar length.
for p in ax.patches:
    w = p.get_width()
    y = p.get_y() + p.get_height()/2
    ax.text(w + max(5, 0.01*w), y, f"{int(w):,}", va="center", ha="left", fontsize=8)

ax.set_xlabel("Cells Expressing Transcript")
ax.set_ylabel("")
ax.set_title("Novel Isoforms (BambuTx): Cells Expressing Each (PBMC1+PBMC2)")

# Aesthetics
sns.despine(left=True, bottom=True)
ax.tick_params(axis='y', labelsize=8)
ax.tick_params(axis='x', labelsize=9)
plt.tight_layout()
# NOTE(review): the Paper_Figs directory is not created anywhere in the visible cells — confirm it exists.
plt.savefig("Intermediate_Files/Paper_Figs/uniquecounts_bambu_transcripts_cells_expressed_sorted_all.pdf",
            dpi=600, transparent=True, bbox_inches="tight")
plt.show()

# Optional: also save the table
df_all.to_csv("Intermediate_Files/Unique_Count_Analyses/bambu_transcripts_cells_expressed_sorted_all.tsv",
              sep="\t", index=False)

In [None]:
# --- Stacked bar for one gene (CD8A): total counts by cell type ---
# --- Stacked bar for CD8A: TOTAL COUNTS by cell type (from counts layer) ---
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from scipy import sparse

adata = adata_i_filtered  # your filtered AnnData (alias, not a copy)
gene_symbol_target = "CD8A"
# Ensembl gene ids tried when the symbol itself cannot be matched.
ensembl_fallbacks = {"CD8A": ["ENSG00000153563"]}  # add more if you wish

# --------------- Utility: get counts matrix ----------------
# Prefer the "counts" layer; fall back to .X when that layer is absent.
X = adata.layers["counts"] if "counts" in adata.layers else adata.X
is_sparse = sparse.issparse(X)

# --------------- Resolve features that belong to the gene ---------------
def _series_ci_equals(s: pd.Series, target: str) -> pd.Series:
    """Element-wise, case-insensitive equality of `s` (as strings) with `target`."""
    return s.astype(str).str.upper() == target.upper()

def _series_ci_contains(s: pd.Series, target: str) -> pd.Series:
    """Element-wise, case-insensitive literal substring test of `target` in `s`."""
    return s.astype(str).str.upper().str.contains(target.upper(), regex=False)

def _var_name_token_mask(var_names, target: str) -> np.ndarray:
    """Boolean mask of names having `target` as one whole ':'- or '|'-delimited token."""
    wanted = target.upper()
    return np.array(
        [wanted in str(name).upper().replace("|", ":").split(":") for name in var_names],
        dtype=bool,
    )

def find_feature_indices_for_gene(adata, gene_symbol: str) -> np.ndarray:
    """Return the positional indices of the features in `adata` that belong to `gene_symbol`.

    Strategies are tried in order and the first one with any hit wins:
      1. var_names equal to the symbol (case-insensitive);
      2. exact match in a var column that looks like a gene-symbol column;
      3. exact match of an id from the module-level `ensembl_fallbacks` in a var
         column that looks like a gene-id column;
      4. the symbol as a whole delimited token of a var_name
         (e.g. "CD4:ENSG...:ENST..." matches "CD4" but "CD40:..." does not);
      5. substring match in the symbol-like columns;
      6. substring match in var_names.
    The result may be empty. Steps 5 and 6 are heuristics and can return features
    of other genes whose name merely contains the symbol.
    """
    var = adata.var

    # 1) exact match in var_names
    exact_idx = np.flatnonzero(pd.Series(adata.var_names).str.upper().values == gene_symbol.upper())
    if exact_idx.size > 0:
        return exact_idx

    # 2) exact match in common gene symbol columns
    symbol_cols = [c for c in var.columns if "symbol" in c.lower() or "gene_name" in c.lower() or c.lower() == "gene"]
    for col in symbol_cols:
        m = _series_ci_equals(var[col], gene_symbol)
        if m.any():
            return np.flatnonzero(m.values)

    # 3) Ensembl fallbacks (if provided); the candidate columns do not depend on
    #    the id being tried, so they are collected once.
    ensg_cols = [c for c in var.columns if "ensembl" in c.lower() or "gene_id" in c.lower()]
    for ens in ensembl_fallbacks.get(gene_symbol, []):
        for col in ensg_cols:
            m = _series_ci_equals(var[col], ens)
            if m.any():
                return np.flatnonzero(m.values)

    # 4) whole-token match in var_names: precise for isoform-level names and,
    #    unlike a substring search, does not confuse CD4 with CD40 / CD44.
    token_idx = np.flatnonzero(_var_name_token_mask(adata.var_names, gene_symbol))
    if token_idx.size > 0:
        return token_idx

    # 5) heuristic substring search in gene_name-like fields (e.g. "CD8A-201")
    for col in symbol_cols:
        m = _series_ci_contains(var[col], gene_symbol)
        if m.any():
            return np.flatnonzero(m.values)

    # 6) last resort: substring search in var_names
    sub_idx = np.flatnonzero(pd.Series(adata.var_names).str.upper().str.contains(gene_symbol.upper(), regex=False).values)
    return sub_idx  # may be empty

idxs = find_feature_indices_for_gene(adata, gene_symbol_target)
if idxs.size == 0:
    # Build a small diagnostic table of look-alike features for the error message.
    # The var columns are taken positionally (to_numpy): assigning the var-indexed
    # Series directly would align it against this table's 0..n-1 index and yield NaN.
    _blank = np.full(adata.n_vars, "", dtype=object)
    candidates = (
        pd.DataFrame({"var_name": adata.var_names})
        .assign(
            gene_name=(adata.var["gene_name"].astype(str).to_numpy()
                       if "gene_name" in adata.var.columns else _blank),
            symbol=(adata.var["symbol"].astype(str).to_numpy()
                    if "symbol" in adata.var.columns else _blank),
        )
    )
    # Case-insensitive, literal (non-regex) match on every text column.
    _needle = gene_symbol_target.upper()
    mask_cand = (
        candidates["var_name"].astype(str).str.upper().str.contains(_needle, regex=False, na=False)
        | candidates["gene_name"].str.upper().str.contains(_needle, regex=False, na=False)
        | candidates["symbol"].str.upper().str.contains(_needle, regex=False, na=False)
    )
    hits = candidates.loc[mask_cand].head(25)
    raise ValueError(
        f"Could not resolve features for '{gene_symbol_target}'. "
        f"Top candidates (showing up to 25):\n{hits}"
    )

# Build per-transcript counts
# label each matching feature
var = adata.var.copy()
# Candidate var columns for a human-readable transcript label, in order of preference.
label_cols = [c for c in ["transcript_id", "tx_id", "transcript", "isoform_id",
                          "feature_id", "name", "id"] if c in var.columns]
def _label_for(j):
    """Label for feature position j: first usable label column value, else the var_name."""
    # prefer transcript_id if available; else var_name
    for c in label_cols:
        v = str(var.iloc[j][c])
        # Skips empty strings and the text "None"; other placeholders such as "nan" are accepted.
        if v and v != "None":
            return v
    return str(adata.var_names[j])

tx_indices = np.asarray(idxs)  # from your finder
tx_labels  = [ _label_for(j) for j in tx_indices ]

# grab counts for just those transcripts
# (X / is_sparse are recomputed here exactly as at the top of the cell)
X = adata.layers["counts"] if "counts" in adata.layers else adata.X
from scipy import sparse
is_sparse = sparse.issparse(X)

if is_sparse:
    sub = X[:, tx_indices].toarray()     # (cells × transcripts)
else:
    sub = np.asarray(X[:, tx_indices])   # (cells × transcripts)

# dataframe with cell types
if "gen_cell_type" not in adata.obs.columns:
    raise KeyError("adata.obs['gen_cell_type'] is required for grouping.")
cell_types = adata.obs["gen_cell_type"].astype(str).to_numpy()

# aggregate TOTAL counts by (transcript, cell_type)
import pandas as pd
# NOTE(review): tx_labels become column names here; presumably they are unique — verify.
df = pd.DataFrame(sub, columns=tx_labels)
df["gen_cell_type"] = cell_types
by_tx_ct = df.groupby("gen_cell_type", observed=True).sum(numeric_only=True)  # sums per CT
# Now we want rows=transcripts, cols=cell types, so transpose:
wide = by_tx_ct.T  # index=transcript, columns=cell types

# sort transcripts by total counts (descending)
wide["__Total__"] = wide.sum(axis=1)
wide = wide.sort_values("__Total__", ascending=False)

# (optional) show only top N transcripts
TOP_N = 30
wide = wide.head(TOP_N)

# Colors & column order
# `cell_type_colors` is not defined in the visible cells; it is used only if it
# already exists as a dict in the notebook namespace.
if 'cell_type_colors' in globals() and isinstance(cell_type_colors, dict):
    ct_order = [ct for ct in cell_type_colors if ct in wide.columns]
    # add any extra CTs not in the palette at the end
    ct_order += [ct for ct in wide.columns if ct not in ct_order and ct != "__Total__"]
    colors = [cell_type_colors.get(ct, "#999999") for ct in ct_order]
else:
    ct_order = [c for c in wide.columns if c != "__Total__"]
    colors = None

# final matrix for plotting
plot_mat = wide[ct_order]
totals = wide["__Total__"].to_numpy()
tx_names = wide.index.tolist()

In [None]:
### Stacked bar plot (one per transcript) ---------------
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import numpy as np

# One horizontal bar per transcript; figure height scales with the row count (4–18 inches).
rows = plot_mat.shape[0]
fig_h = max(4, min(0.35 * rows, 18))
fig, ax = plt.subplots(figsize=(10, fig_h))

y = np.arange(rows)
# Running left edge of the next segment of each stacked bar.
left = np.zeros(rows, dtype=float)
handles = []

for i, ct in enumerate(ct_order):
    vals = plot_mat[ct].to_numpy()
    c = (colors[i] if colors is not None else None)
    bar = ax.barh(y, vals, left=left, color=c, edgecolor="black", label=ct)
    left = left + vals
    # Explicit legend patches are only built when a palette was supplied.
    if colors is not None:
        handles.append(Patch(facecolor=c, edgecolor="black", label=ct))

ax.set_yticks(y)
ax.set_yticklabels(tx_names)
ax.invert_yaxis()
ax.set_xlabel("Total counts")
ax.set_ylabel("")
ax.set_title(f"{gene_symbol_target}: transcript-level total counts by cell type (counts layer)")

# write total at ends
for i, t in enumerate(totals):
    ax.text(t + max(5, 0.01*t), i, f"{int(t):,}", va="center", ha="left", fontsize=8)

# legend
if colors is not None:
    ax.legend(handles=handles, title="Cell type", bbox_to_anchor=(1.01, 1), loc="upper left", frameon=True)
else:
    ax.legend(title="Cell type", bbox_to_anchor=(1.01, 1), loc="upper left", frameon=True)

# clean look
for spine in ("left", "right", "top"):
    ax.spines[spine].set_visible(False)

plt.tight_layout()
plt.savefig("Intermediate_Files/Paper_Figs/CD8A_transcripts_total_counts_by_celltype_countsLayer.pdf",
            dpi=600, transparent=True, bbox_inches="tight")
plt.show()

In [None]:
### CD8A transcripts: TOTAL UNIQUE COUNTS by cell type (PBMC1+PBMC2)
# Cell types are pulled from adata_i_filtered.obs['gen_cell_type']
import polars as pl
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import seaborn as sns

## CONFIG
ct_order = ["TCells", "NK Cells", "BCells", "Monocyte-derived", "Megakaryocyte"]
target_symbol = "CD8A"

def get_colors(order):
    """Return one color per cell type in `order`.

    With a `cell_type_colors` mapping in the notebook namespace, each type is looked
    up by its exact name, then by its name normalized to end in a single 's', and
    finally falls back to grey. Without that mapping, a seaborn "tab10" palette of
    matching length is returned.
    """
    if 'cell_type_colors' not in globals():
        return sns.color_palette("tab10", n_colors=len(order))
    resolved = []
    for ct in order:
        plural_fallback = cell_type_colors.get(ct.rstrip('s') + 's', "#999999")
        resolved.append(cell_type_colors.get(ct, plural_fallback))
    return resolved
colors = get_colors(ct_order)

## FUNCTIONS
def parse_tokens(colname: str):
    """Split 'TX|GENE|SYMBOL' into (tx, gene, symbol); absent tokens become ''."""
    # Two blank pads guarantee three tokens; anything past the third is ignored.
    padded = str(colname).split("|") + ["", ""]
    return padded[0], padded[1], padded[2]

def pretty_label(colname: str) -> str:
    """Display label: 'TX | SYMBOL' if a symbol exists, else 'TX | GENE', else 'TX'."""
    left, ensg, symbol = parse_tokens(colname)
    if symbol:
        return f"{left} | {symbol}"
    if ensg:
        return f"{left} | {ensg}"
    return left

def select_cols_cd8a(df: pl.DataFrame) -> list[str]:
    """Return the columns of `df` (first column excluded) whose SYMBOL token equals `target_symbol`."""
    return [
        column
        for column in df.columns[1:]  # skip CellID
        if parse_tokens(column)[2] == target_symbol
    ]

def ensure_cols(df: pl.DataFrame, cols: list[str]) -> pl.DataFrame:
    """Return `df` reduced to its id column (cast to string) plus `cols`, in that order.

    Names in `cols` that `df` lacks are added as all-zero Int64 columns first.
    """
    id_col = df.columns[0]
    df = df.with_columns(pl.col(id_col).cast(pl.Utf8))
    present = set(df.columns)  # one snapshot instead of a list scan per column
    missing = [c for c in cols if c not in present]
    if missing:
        # eager=True is required: without it pl.repeat returns an expression,
        # which cannot be stacked onto the frame as a column.
        df = df.hstack([
            pl.repeat(0, df.height, dtype=pl.Int64, eager=True).alias(c)
            for c in missing
        ])
    return df.select([id_col] + cols)

# CELL-TYPE TABLE FROM AnnData
# Pull gen_cell_type from adata_i_filtered to a small Polars table for joining
# Note: this rewrites the obs index of adata_i_filtered in place (as strings).
adata_i_filtered.obs.index = adata_i_filtered.obs.index.astype(str)
obs_ct = adata_i_filtered.obs["gen_cell_type"]
ct_table = pl.DataFrame({
    "CellID": obs_ct.index.to_series().astype(str).tolist(),
    "cell_type": obs_ct.values.tolist()
})

# Expect pbmc1_filt / pbmc2_filt in memory (first col = CellID, other cols = unique-counts per transcript)
cols1 = select_cols_cd8a(pbmc1_filt)
cols2 = select_cols_cd8a(pbmc2_filt)
target_cols = sorted(set(cols1) | set(cols2))
if not target_cols:
    # Show up to 20 look-alike column names from PBMC1 before failing.
    sample_hits = [c for c in pbmc1_filt.columns[1:] if "CD8" in str(c)][:20]
    print("[debug] No CD8A transcripts found. Example CD8-like columns:", sample_hits)
    raise ValueError("No transcripts with SYMBOL 'CD8A' found. Expect columns like 'ENST...|ENSG...|CD8A'.")

pbmc1_sel = ensure_cols(pbmc1_filt, target_cols)
pbmc2_sel = ensure_cols(pbmc2_filt, target_cols)

def attach_celltype(df: pl.DataFrame, ct_table: pl.DataFrame) -> pl.DataFrame:
    """Add a `cell_type` column to `df` by joining `ct_table` on the cell id.

    The first column of `df` is renamed to "CellID" and cast to string; rows
    without a cell type in `ct_table` are dropped.
    """
    renamed = df.rename({df.columns[0]: "CellID"})
    keyed = renamed.with_columns(pl.col("CellID").cast(pl.Utf8))
    annotated = keyed.join(ct_table, on="CellID", how="left")
    return annotated.filter(pl.col("cell_type").is_not_null())

pbmc1c = attach_celltype(pbmc1_sel, ct_table)
pbmc2c = attach_celltype(pbmc2_sel, ct_table)

def per_tx_ct_sum(df: pl.DataFrame, cols: list[str]) -> pl.DataFrame:
    """Total the values (counts, not presence) of `cols` per (Transcript, cell_type).

    `df` needs "CellID" and "cell_type" columns; the result has the columns
    Transcript, cell_type and total_unique_counts.
    """
    melted = df.unpivot(
        index=["CellID", "cell_type"],
        on=cols,
        variable_name="Transcript",
        value_name="val"
    )
    per_group = melted.group_by(["Transcript", "cell_type"])
    return per_group.agg(pl.col("val").sum().alias("total_unique_counts"))

sum1 = per_tx_ct_sum(pbmc1c, target_cols)
sum2 = per_tx_ct_sum(pbmc2c, target_cols)
# Stack both samples and total again so each (Transcript, cell_type) pair appears once.
sum_all = (
    pl.concat([sum1, sum2])
    .group_by(["Transcript", "cell_type"])
    .agg(pl.col("total_unique_counts").sum().alias("total_unique_counts"))
)

# Pivot -> rows=Transcript, cols=cell_type (TOTAL UNIQUE COUNTS)
# `on=` is the current name of the pivot column argument; `columns=` is the
# deprecated spelling in the polars versions that provide `unpivot`.
wide = (
    sum_all
    .pivot(on="cell_type", index="Transcript", values="total_unique_counts")
    .fill_null(0)
)
# Ensure all ct columns present & ordered
# NOTE(review): cell types that occur in the data but are not in ct_order are
# dropped by the select below and therefore excluded from "Total" — confirm intended.
for ct in ct_order:
    if ct not in wide.columns:
        wide = wide.with_columns(pl.lit(0).alias(ct))
wide = wide.select(["Transcript"] + ct_order)

# To pandas for labeling / sorting
wide_pd = wide.to_pandas().set_index("Transcript")
wide_pd["Display"] = [pretty_label(t) for t in wide_pd.index]
wide_pd["Total"] = wide_pd[ct_order].sum(axis=1)
wide_pd = wide_pd.sort_values("Total", ascending=False)

In [None]:
### Bar plot: stacked unique counts per transcript by cell type
# Figure height scales with the number of transcripts, bounded to 4–18 inches.
rows = len(wide_pd)
fig_h = max(4, min(0.35 * rows, 18))
fig, ax = plt.subplots(figsize=(10, fig_h))

y = np.arange(rows, dtype=float)
# Running left edge of the next stacked segment for each bar.
left = np.zeros(rows, dtype=float)

for color, ct in zip(colors, ct_order):
    vals = wide_pd[ct].to_numpy()
    ax.barh(y, vals, left=left, color=color, edgecolor="black", label=ct)
    left += vals

ax.set_yticks(y)
ax.set_yticklabels(wide_pd["Display"].values)
ax.invert_yaxis()
ax.set_xlabel("Total unique counts")
ax.set_ylabel("")
ax.set_title("CD8A transcripts: total unique counts by cell type (PBMC1 + PBMC2)")

# total labels at bar ends
for i, total in enumerate(wide_pd["Total"].to_numpy()):
    ax.text(total + max(5, 0.01*total), i, f"{int(total):,}", va="center", ha="left", fontsize=8)

legend_handles = [Patch(facecolor=col, edgecolor="black", label=ct) for col, ct in zip(colors, ct_order)]
ax.legend(handles=legend_handles, title="Cell type", bbox_to_anchor=(1.01, 1), loc="upper left", frameon=True)

sns.despine(left=True, bottom=True)
plt.tight_layout()
plt.savefig("Intermediate_Files/Unique_Count_Analyses/CD8A_uniquecounts_transcripts_stacked_counts.pdf",
            dpi=600, transparent=True, bbox_inches="tight")
plt.show()

# Also export the numeric table of counts
# reset_index turns the "Transcript" index back into a regular column.
out_counts = wide_pd.loc[:, ["Display"] + ct_order + ["Total"]].reset_index(drop=False).rename(columns={"index": "Transcript"})
out_counts.to_csv(
    "Intermediate_Files/Unique_Count_Analyses/CD8A_uniquecounts_transcripts_counts_table.tsv",
    sep="\t", index=False
)

In [None]:
# Bulk-by-cell-type bar charts of UNIQUE COUNTS per transcript, split by gene 
# One PDF per gene (all its transcripts in one panel), plus a combined multi-page PDF.

import os
import re
import unicodedata
import numpy as np
import polars as pl
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

## CONFIG / UTILS 
# Output locations: a directory for per-gene PDFs and the path of the combined PDF.
OUT_DIR_PER_GENE = "Intermediate_Files/Unique_Count_Analyses/plots/bulk_by_celltype_per_gene"
OUT_COMBINED_PDF = "Intermediate_Files/Unique_Count_Analyses/plots/bulk_by_celltype_all_genes.pdf"
# Creates the per-gene directory (and its parents) if needed.
os.makedirs(OUT_DIR_PER_GENE, exist_ok=True)

# Safe filename
def safe_name(s: str) -> str:
    s = str(s)
    s = unicodedata.normalize("NFKD", s)
    s = re.sub(r"[^\w\-(). ]+", "_", s)
    s = re.sub(r"\s+", "_", s).strip("_")
    return s[:200]  # avoid super long names

# Column names look like "LEFT|GENEID|SYMBOL?" (LEFT is a BambuTx or ENST token)
def parse_tokens(colname: str):
    """Split a column name on '|' into (transcript, gene_token, symbol); absent tokens are ''."""
    # Two blank pads guarantee three tokens; anything past the third is ignored.
    padded = str(colname).split("|") + ["", ""]
    return padded[0], padded[1], padded[2]

def gene_display(colname: str) -> str:
    """Gene-level label: 'SYMBOL (GENEID)' if a symbol exists, else GENEID, else the transcript token."""
    left, gene_token, symbol = parse_tokens(colname)
    if symbol:
        return f"{symbol} ({gene_token})"
    return gene_token or left  # fallback

def tx_display(colname: str) -> str:
    """Transcript-level label: the transcript token, with the symbol on a second line if known."""
    left, _, symbol = parse_tokens(colname)
    if symbol:
        return f"{left}\n{symbol}"
    return left

## PREP DATA
# 1) Identify all transcript columns (skip the first 'CellID' column)
def transcript_columns(df: pl.DataFrame) -> list[str]:
    """Return every column name of `df` except the first one."""
    return list(df.columns[1:])  # you've already filtered these matrices

# Ensure both matrices share the same columns (fill missing with 0)
def ensure_cols(df: pl.DataFrame, cols: list[str]) -> pl.DataFrame:
    """Return `df` reduced to its first column plus `cols`, in that order.

    Names in `cols` that `df` lacks are added as all-zero columns first.
    """
    id_col = df.columns[0]
    # `cols` is the full transcript set here, so test membership against a set:
    # scanning the `df.columns` list once per column would be quadratic.
    present = set(df.columns)
    missing = [c for c in cols if c not in present]
    if missing:
        df = df.hstack([pl.Series(c, [0]*df.height) for c in missing])
    return df.select([id_col] + cols)

cols1 = transcript_columns(pbmc1_filt)
cols2 = transcript_columns(pbmc2_filt)
all_cols = sorted(set(cols1) | set(cols2))
# Note: pbmc1b / pbmc2b are rebound here and now hold ALL transcript columns.
pbmc1b = ensure_cols(pbmc1_filt, all_cols)
pbmc2b = ensure_cols(pbmc2_filt, all_cols)

# 2) Attach cell type to each cell row from adata_i_filtered
celltype_map = adata_i_filtered.obs["gen_cell_type"].astype(str).to_dict()

def attach_celltype(df: pl.DataFrame) -> pl.DataFrame:
    """Add a `cell_type` column to `df` from `celltype_map`, dropping unannotated cells.

    The first column of `df` is renamed to "CellID". Both sides of the join key
    are forced to string so the join does not depend on how the ids were parsed
    from the text file or stored in the obs index.
    """
    df = (
        df.rename({df.columns[0]: "CellID"})
          .with_columns(pl.col("CellID").cast(pl.Utf8))
    )
    ct_df = pl.DataFrame(
        {
            "CellID": [str(cell_id) for cell_id in celltype_map],
            "cell_type": list(celltype_map.values()),
        },
        # Explicit schema keeps the key a string column even if the map is empty.
        schema={"CellID": pl.Utf8, "cell_type": pl.Utf8},
    )
    return df.join(ct_df, on="CellID", how="left").filter(pl.col("cell_type").is_not_null())

pbmc1c = attach_celltype(pbmc1b)
pbmc2c = attach_celltype(pbmc2b)

# 3) Sum UNIQUE COUNTS per transcript × cell type (PBMC1 + PBMC2)
def per_tx_ct_unique_counts(df: pl.DataFrame, cols: list[str]) -> pl.DataFrame:
    """Sum the per-cell values of ``cols`` for every (transcript, cell type) pair.

    ``df`` is reshaped to long form (one row per cell and transcript) and then
    grouped. The result has columns ``Transcript``, ``cell_type`` and
    ``unique_counts``.
    """
    summed = pl.col("uc").sum().alias("unique_counts")
    melted = df.unpivot(
        on=cols,
        index=["CellID", "cell_type"],
        variable_name="Transcript",
        value_name="uc",  # per-cell unique counts
    )
    return melted.group_by(["Transcript", "cell_type"]).agg(summed)

# Aggregate each sample separately, then add the two samples together.
counts1, counts2 = (per_tx_ct_unique_counts(matrix, all_cols) for matrix in (pbmc1c, pbmc2c))
counts_all = pl.concat([counts1, counts2]).group_by(["Transcript", "cell_type"]).agg(
    pl.col("unique_counts").sum().alias("unique_counts")
)

# The plotting code below works on pandas.
df_counts = counts_all.to_pandas()

## Cell-type order
# Fixed plotting order. Categories with no cells are kept and drawn as zero.
# (A data-derived ct_order used to be computed just before this line and was
# then unconditionally overwritten by this list; that dead code is removed.)
ct_order = ["TCells", "NK Cells", "BCells", "Monocyte-derived", "Megakaryocyte"]

# The per-gene grid below is reindexed onto ct_order, which silently drops any
# label that is in the data but not in the list. Report such labels so that a
# spelling mismatch does not quietly remove a whole cell type from the plots.
_unplotted_cell_types = sorted(set(df_counts["cell_type"]) - set(ct_order))
if _unplotted_cell_types:
    print(f"[warn] cell types present in data but not in ct_order (not plotted): {_unplotted_cell_types}")

# Colors for that order: reuse cell_type_colors when the notebook defines it
# (grey for labels it lacks), otherwise take the first tab10 colors.
if 'cell_type_colors' in globals():
    palette = {ct: cell_type_colors.get(ct, "#999999") for ct in ct_order}
else:
    pal = sns.color_palette("tab10", n_colors=len(ct_order))
    palette = {ct: pal[i] for i, ct in enumerate(ct_order)}

# Derive the page-level gene label and the per-bar transcript label.
tx_to_gene = dict(zip(all_cols, map(gene_display, all_cols)))
df_counts = df_counts.assign(
    Gene=df_counts["Transcript"].map(tx_to_gene),
    TranscriptLabel=df_counts["Transcript"].map(tx_display),
)

# Total per gene. The totals are not used for ordering, but the index of this
# Series supplies the list of genes that is sorted below.
gene_totals = df_counts.groupby("Gene")["unique_counts"].sum()

# --- Robust alphabetical ordering by preferred gene symbol ---
# Sort by:
#   1) non-Bambu before BambuGene (optional: keep if you prefer this behavior)
#   2) normalized, case-insensitive "symbol" (text before the '(' )
#   3) natural sort for any digits
import unicodedata

def _normalize_ascii(s: str) -> str:
    s = unicodedata.normalize("NFKD", s)
    s = s.encode("ascii", "ignore").decode("ascii")
    return s.strip().lower()

def _natural_parts(s: str):
    # Split into text and integer chunks for natural ordering (e.g., gene2 < gene10)
    parts = re.split(r"(\d+)", s)
    return [int(p) if p.isdigit() else p for p in parts]

def _symbol_for_sort(g: str) -> str:
    """Return the normalized sort text for a gene label.

    For ``"SYMBOL (ID)"`` labels this is the part before the first ``(``;
    for anything else it is the whole label.
    """
    label = str(g)
    hit = re.match(r"\s*([^(]+)\s*\(", label)
    return _normalize_ascii(hit.group(1) if hit else label)

def _sort_key(g: str):
    """Sort key for gene labels: BambuGene labels last, then natural order of the symbol."""
    sym = _symbol_for_sort(g)
    # False sorts before True, so labels that do not start with "bambugene" come first.
    return (str(g).lower().startswith("bambugene"), _natural_parts(sym), sym)

# Page order for the PDFs.
genes_sorted = sorted(gene_totals.index, key=_sort_key)

# (optional) quick sanity peek
print("First 20, sorted:", genes_sorted[:20])

## PLOTTING (PER GENE)

from matplotlib.backends.backend_pdf import PdfPages
# Paths of the per-gene PDFs written by the loop below (reported at the end).
saved_pages = []

# One figure per gene: grouped bars (x = transcript, hue = cell type).
# Each figure is saved as its own PDF and also appended to the combined PDF.
with PdfPages(OUT_COMBINED_PDF) as bigpdf:
    for gene in genes_sorted:
        sub = df_counts[df_counts["Gene"] == gene].copy()
        if sub.empty:
            continue

        # transcript order by descending unique counts
        tx_order = (
            sub.groupby("TranscriptLabel")["unique_counts"]
               .sum().sort_values(ascending=False).index.tolist()
        )
        if not tx_order:
            continue

        # --- COMPLETE TO FULL GRID (TranscriptLabel × cell_type) with zeros ---
        sub_grid = (
            sub.groupby(["TranscriptLabel", "cell_type"], as_index=False)["unique_counts"]
               .sum()
        )

        # Build full index across ALL requested categories (even if zero).
        # Rows whose cell_type is not in ct_order are dropped by this reindex.
        full_idx = pd.MultiIndex.from_product([tx_order, ct_order],
                                              names=["TranscriptLabel", "cell_type"])
        sub_grid = (
            sub_grid.set_index(["TranscriptLabel", "cell_type"])
                    .reindex(full_idx, fill_value=0)
                    .reset_index()
        )

        # Categorical ordering for seaborn
        sub_grid["TranscriptLabel"] = pd.Categorical(sub_grid["TranscriptLabel"],
                                                     categories=tx_order, ordered=True)
        sub_grid["cell_type"] = pd.Categorical(sub_grid["cell_type"],
                                               categories=ct_order, ordered=True)

        # figure width scales with # transcripts (clamped to 6..18); height is fixed
        n_tx = len(tx_order)
        fig_w = max(6, min(0.4 * n_tx + 2, 18))
        fig_h = 4
        fig, ax = plt.subplots(figsize=(fig_w, fig_h))

        # sub_grid holds exactly one row per (transcript, cell type), so the
        # sum estimator simply reproduces each value.
        sns.barplot(
            data=sub_grid,
            x="TranscriptLabel", y="unique_counts", hue="cell_type",
            order=tx_order, hue_order=ct_order,
            palette=palette, edgecolor="black",
            estimator=sum, errorbar=None, ax=ax  # explicit identity-like behavior
        )

        ax.set_xlabel("Transcripts")
        ax.set_ylabel("Unique Counts (sum over cells)")
        ax.set_title(gene)

        # Rotate + right-align x labels
        for lbl in ax.get_xticklabels():
            lbl.set_rotation(45)
            lbl.set_horizontalalignment("right")

        sns.despine(ax=ax, top=True, right=True)
        ax.legend(title="Cell type", bbox_to_anchor=(1.01, 1), loc="upper left", frameon=True)
        plt.tight_layout()

        # Save individual page (vector) and append to combined
        out_path = os.path.join(OUT_DIR_PER_GENE, f"{safe_name(gene)}.pdf")
        fig.savefig(out_path, dpi=600, transparent=True, bbox_inches="tight")
        saved_pages.append(out_path)
        bigpdf.savefig(fig, bbox_inches="tight")
        # Close each figure so open figures do not accumulate across genes.
        plt.close(fig)

print(f"[info] Wrote {len(saved_pages)} per-gene PDFs to {OUT_DIR_PER_GENE}")
print(f"[info] Combined multi-page PDF: {OUT_COMBINED_PDF}")

In [None]:
### Bulk-by-cell-type bar charts of RAW COUNTS per transcript, split by gene
# One PDF per gene (all its transcripts in one panel), plus a combined multi-page PDF.

import os, re, unicodedata
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from scipy import sparse
import polars as pl
from collections import Counter, defaultdict

## CONFIG / UTILS

# Output locations for this cell: a directory of per-gene PDFs and one combined PDF.
# NOTE: these names are reassigned by each plotting cell in the notebook.
OUT_DIR_PER_GENE = "Intermediate_Files/Unique_Count_Analyses/plots/bulk_by_celltype_per_gene_RAWCOUNTS"
OUT_COMBINED_PDF = "Intermediate_Files/Unique_Count_Analyses/plots/bulk_by_celltype_all_genes_RAWCOUNTS.pdf"
os.makedirs(OUT_DIR_PER_GENE, exist_ok=True)

# In this cell SCALE_PER only selects the y-axis label; no code here rescales the counts.
SCALE_PER = None  # keep None for raw totals

def safe_name(s: str) -> str:
    """Turn an arbitrary label into a file-name stem of at most 200 characters.

    The text is NFKD-normalized. Runs of characters other than word
    characters, ``-``, ``(``, ``)``, ``.`` and space become ``_``, whitespace
    runs become ``_``, and leading/trailing underscores are stripped.
    """
    text = unicodedata.normalize("NFKD", str(s))
    text = re.sub(r"[^\w\-(). ]+", "_", text)
    return re.sub(r"\s+", "_", text).strip("_")[:200]

# Fixed cell-type order for the stacked bars and the legend.
ct_order = ["TCells", "NK Cells", "BCells", "Monocyte-derived", "Megakaryocyte"]

# Colors: tab10 fallback unless the notebook already defines cell_type_colors
# (labels missing from that mapping are drawn grey).
if 'cell_type_colors' not in globals():
    pal = sns.color_palette("tab10", n_colors=len(ct_order))
    palette = dict(zip(ct_order, pal))
else:
    palette = {ct: cell_type_colors.get(ct, "#999999") for ct in ct_order}

## MARKER & NOVEL GENE SETS

# ENSG -> symbol map
# Tally, per ENSG id, how often each symbol appears in three-field
# "SYMBOL:ENSG:<third>" var names; the most frequent symbol becomes primary.
ensg_to_symbol = defaultdict(Counter)
for raw_name in adata_i_filtered.var_names:
    fields = str(raw_name).split(":")
    if len(fields) != 3:
        continue
    sym, ensg = fields[0], fields[1]
    if ensg.startswith("ENSG"):
        ensg_to_symbol[ensg][sym] += 1
ensg_primary_symbol = {
    gene_id_key: tally.most_common(1)[0][0]
    for gene_id_key, tally in ensg_to_symbol.items()
}

# marker symbols
# Marker gene symbols, grouped by the population they were chosen for.
marker_set = {
    "CD3D","CD3E","CD3G",                     # T cell markers
    "CD8A","CD8B","GATA3","KLRB1","CCL5",     # CD8 Effector T cell markers
    "CD4","IL2RA","AHR","TNF",                # CD4 Effector T cell markers
    "CCR7","SELL","TCF7",                     # Memory T cell markers
    "ITGAE","LEF1","CTLA4","IL7R","CD27",     # Transition T cell markers
    "GZMB","KLRF1","NCAM1","ITGAM","IL2RB",   # NK Cell markers
    "CD22","CD79A","MS4A1","CD19",            # B cell markers
    "FCGR2A","CLEC7A","CD33","LILRB4",        # Monocyte-derived cell markers
    "GP1BA","MPL","ITGA2B",                   # Megakaryocyte markers
}

# ENSG ids of every three-field var name whose symbol field is a marker.
marker_ensg = {
    fields[1]
    for fields in (str(raw_name).split(":") for raw_name in adata_i_filtered.var_names)
    if len(fields) == 3 and fields[0] in marker_set and fields[1].startswith("ENSG")
}
print(f"[info] marker ENSGs resolved: {len(marker_ensg)}")

# find ENSGs with ≥1 novel (BambuTx) transcript
def ensgs_with_bambu(df: "pl.DataFrame") -> set[str]:
    """Collect the ENSG ids that own at least one novel (``BambuTx``) transcript column.

    Column names are expected to look like ``TX|GENE`` or ``TX|GENE|SYMBOL``.
    The first column of ``df`` (the cell-ID column) is skipped, as are names
    without a ``|``.

    Returns:
        The set of bare gene ids (no ``|SYMBOL`` tail) that start with
        ``ENSG`` and have a ``BambuTx`` transcript.
    """
    out = set()
    for c in df.columns[1:]:
        tx_token, sep, rest = c.partition("|")
        if not sep:
            continue
        # Keep only the gene field. With a trailing "|SYMBOL" the previous
        # code stored "ENSG...|SYMBOL", which can never equal the bare ENSG
        # ids that the keep-filter compares against.
        gene_token = rest.split("|", 1)[0]
        if tx_token.startswith("BambuTx") and gene_token.startswith("ENSG"):
            out.add(gene_token)
    return out

# Genes with a novel transcript in either sample matrix (pbmc1 / pbmc2 are defined in an earlier cell).
ensg_with_novel_any = ensgs_with_bambu(pbmc1) | ensgs_with_bambu(pbmc2)
print(f"[info] ENSGs with ≥1 novel transcript: {len(ensg_with_novel_any)}")

# Genes to plot: marker genes plus genes with at least one novel transcript.
keep_ensgs = marker_ensg | ensg_with_novel_any

## DATA PREP

# Use the "counts" layer when it exists, otherwise fall back to .X.
adata = adata_i_filtered
X = adata.layers["counts"] if "counts" in adata.layers else adata.X
X = X.tocsr() if sparse.issparse(X) else np.asarray(X)

# Encode each cell's label as its position in ct_order. Labels that are not in
# ct_order get code -1 and are excluded through `valid`.
ct_series = adata.obs["gen_cell_type"].astype(str)
ct_codes = pd.Categorical(ct_series, categories=ct_order, ordered=True).codes
valid = ct_codes >= 0
n_cells, n_vars, n_ct = adata.n_obs, adata.n_vars, len(ct_order)

# G is a (cells × cell types) indicator matrix: G[i, j] = 1 when cell i has cell type j.
rows = np.flatnonzero(valid)
cols = ct_codes[valid]
data = np.ones(rows.shape[0], dtype=np.int8)
G = sparse.csr_matrix((data, (rows, cols)), shape=(n_cells, n_ct))

# Sum counts over the cells of each type; counts_fc ends up as (features × cell types).
if sparse.issparse(X):
    X_csr = X.tocsr()
    counts_fc = (G.T @ X_csr).T.toarray()
else:
    counts_fc = X.T @ G.toarray()

# Per-feature metadata used to label and group the bars.
var = adata.var.copy()
var_names = pd.Index(adata.var_names).astype(str)

# gene_id (ENSG): prefer a var column, else take the second ":"-separated field of the var name
if "gene_id" in var.columns:
    gene_id = var["gene_id"].astype(str).values
elif "ensembl_gene_id" in var.columns:
    gene_id = var["ensembl_gene_id"].astype(str).values
else:
    def _gene_from_varname(v):
        p = str(v).split(":")
        # Names without a second field fall back to the whole name.
        return p[1] if len(p) >= 2 else v
    gene_id = np.array([_gene_from_varname(v) for v in var_names], dtype=object)

# gene_symbol: prefer a var column, else take the first ":"-separated field of the var name
if "gene_name" in var.columns:
    gene_symbol = var["gene_name"].astype(str).values
elif "symbol" in var.columns:
    gene_symbol = var["symbol"].astype(str).values
else:
    def _sym_from_varname(v):
        p = str(v).split(":")
        return p[0] if len(p) >= 1 else ""
    gene_symbol = np.array([_sym_from_varname(v) for v in var_names], dtype=object)

# transcript label: first usable value among the candidate var columns, else the var name
label_cols = [c for c in ["transcript_id","tx_id","transcript","isoform_id","feature_id","name","id"] if c in var.columns]
def _tx_label(i):
    for c in label_cols:
        v = str(var.iloc[i][c])
        # Skips empty strings and the text "None"; other placeholders (e.g. "nan") are not filtered here.
        if v and v != "None": return v
    return str(var_names[i])
tx_label = np.array([_tx_label(i) for i in range(n_vars)], dtype=object)

# Tick label: transcript label with the symbol on a second line when a symbol exists.
tx_label_display = np.where(gene_symbol != "",
                            [f"{tx_label[i]}\n{gene_symbol[i]}" for i in range(n_vars)],
                            tx_label)

# Page label: "SYMBOL (GENE_ID)" when a symbol exists, else the gene id.
# NOTE: this array reuses the name `gene_display`, which other cells define as a function.
gene_display = np.where(gene_symbol != "",
                        [f"{gene_symbol[i]} ({gene_id[i]})" if gene_id[i] else gene_symbol[i] for i in range(n_vars)],
                        gene_id)

# Wide table: one row per feature, one column per cell type, label columns in front.
counts_df = pd.DataFrame(counts_fc, columns=ct_order)
for position, (col_name, col_values) in enumerate(
    [("TranscriptLabel", tx_label_display), ("Gene", gene_display), ("ENSG", gene_id)]
):
    counts_df.insert(position, col_name, col_values.tolist())

# Long table: one row per (feature, cell type).
id_columns = ["TranscriptLabel", "Gene", "ENSG"]
df_counts = counts_df.melt(id_vars=id_columns, var_name="cell_type", value_name="raw_counts")

# filter to only marker or novel ENSGs
df_counts = df_counts.loc[df_counts["ENSG"].isin(keep_ensgs)]

## SORTING

def _normalize_ascii(s: str) -> str:
    return unicodedata.normalize("NFKD", s).encode("ascii","ignore").decode("ascii").strip().lower()
def _nat_parts(s: str):
    parts = re.split(r"(\d+)", s)
    return [int(p) if p.isdigit() else p for p in parts]
def _sym_for_sort(g: str) -> str:
    """Return the normalized sort text: the part before the first ``(`` if any, else ``g``."""
    hit = re.match(r"\s*([^(]+)\s*\(", str(g))
    if hit:
        return _normalize_ascii(hit.group(1))
    return _normalize_ascii(g)
def _sort_key(g: str):
    """Sort key for gene labels: BambuGene labels last, then natural order of the symbol."""
    is_bambu = str(g).lower().startswith("bambugene")
    sym = _sym_for_sort(g)
    return (is_bambu, _nat_parts(sym), sym)

# Page order: every distinct gene label left after the keep-filter.
genes_sorted = sorted(df_counts["Gene"].dropna().unique().tolist(), key=_sort_key)

## PLOTTING

# Paths of the per-gene PDFs written by the loop below (reported at the end).
saved_pages = []
# Label only; the plotted values are the raw sums regardless of SCALE_PER.
ylab = "Total counts" if SCALE_PER is None else f"Counts per {SCALE_PER:,} cells"

# One figure per gene: stacked bars (x = transcript, one stack segment per cell type).
with PdfPages(OUT_COMBINED_PDF) as bigpdf:
    for gene in genes_sorted:
        sub = df_counts[df_counts["Gene"] == gene]
        if sub.empty:
            continue

        # Complete the (transcript × cell type) grid, filling absent pairs with zero.
        txs = sorted(sub["TranscriptLabel"].unique().tolist())
        full_idx = pd.MultiIndex.from_product([txs, ct_order], names=["TranscriptLabel","cell_type"])
        sub_grid = (sub.groupby(["TranscriptLabel","cell_type"], observed=True)["raw_counts"]
                        .sum()
                        .reindex(full_idx, fill_value=0)
                        .reset_index())

        # Transcripts ordered by descending total count.
        tx_order = (sub_grid.groupby("TranscriptLabel", observed=True)["raw_counts"]
                               .sum().sort_values(ascending=False).index.tolist())

        sub_grid["TranscriptLabel"] = pd.Categorical(sub_grid["TranscriptLabel"], categories=tx_order, ordered=True)
        sub_grid["cell_type"] = pd.Categorical(sub_grid["cell_type"], categories=ct_order, ordered=True)

        # Figure size grows with the number of transcripts, within fixed bounds.
        n_tx = len(tx_order)
        fig_w = max(6, min(0.4 * n_tx + 2, 18))
        fig_h = max(4, min(0.28 * n_tx + 2, 12))
        fig, ax = plt.subplots(figsize=(fig_w, fig_h), constrained_layout=True)

        # Stack one bar layer per cell type; `bottoms` carries the running height per transcript.
        bottoms = np.zeros(len(tx_order), dtype=float)
        for ct in ct_order:
            vals = sub_grid.loc[sub_grid["cell_type"] == ct, ["TranscriptLabel","raw_counts"]]
            vals = vals.set_index("TranscriptLabel").reindex(tx_order)["raw_counts"].to_numpy()
            ax.bar(tx_order, vals, bottom=bottoms,
                   color=palette.get(ct, "#999999"), edgecolor="black", label=ct)
            bottoms += vals

        ax.set_xlabel("Transcripts")
        ax.set_ylabel(ylab)
        ax.set_title(gene)

        ax.tick_params(axis="x", labelsize=8, rotation=40)
        for lbl in ax.get_xticklabels():
            lbl.set_horizontalalignment("right")

        sns.despine(ax=ax, top=True, right=True)
        ax.legend(title="Cell type", bbox_to_anchor=(1.01, 1), loc="upper left", frameon=True)

        # Save the individual page and append the same figure to the combined PDF.
        out_path = os.path.join(OUT_DIR_PER_GENE, f"{safe_name(gene)}.pdf")
        fig.savefig(out_path, dpi=600, transparent=True, bbox_inches="tight")
        saved_pages.append(out_path)
        bigpdf.savefig(fig, bbox_inches="tight")
        # Close each figure so open figures do not accumulate across genes.
        plt.close(fig)

print(f"[info] Wrote {len(saved_pages)} per-gene PDFs to {OUT_DIR_PER_GENE}")
print(f"[info] Combined multi-page PDF: {OUT_COMBINED_PDF}")

In [None]:
### Bulk-by-sub-cell-type bar charts of UNIQUE COUNTS per transcript, split by gene 
# One PDF per gene (all its transcripts in one panel), plus a combined multi-page PDF.

import os
import re
import unicodedata
import numpy as np
import polars as pl
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

## CONFIG / UTILS
# Output locations for this cell: a directory of per-gene PDFs and one combined PDF.
# NOTE: these names are reassigned by each plotting cell in the notebook.
OUT_DIR_PER_GENE = "Intermediate_Files/Unique_Count_Analyses/plots/bulk_by_SUBcelltype_per_gene"
OUT_COMBINED_PDF = "Intermediate_Files/Unique_Count_Analyses/plots/bulk_by_SUBcelltype_all_genes.pdf"
os.makedirs(OUT_DIR_PER_GENE, exist_ok=True)

# Desired sub-cell-type order (exact). Used below as the hue order and as the
# category list for the per-gene grid, so spellings must match the
# "sub_cell_type" labels exactly for a category to receive counts.
sub_ct_order = [
    "Effector CD4 TCells",
    "Cytotoxic TCells",
    "Memory TCells",
    "Effector-Memory Transition TCells",
    "NK Cells",
    "B Cells",
    "Monocyte-derived",
    "Megakaryocytes",
]

def safe_name(s: str) -> str:
    """Turn an arbitrary label into a file-name stem of at most 200 characters.

    The text is NFKD-normalized. Runs of characters other than word
    characters, ``-``, ``(``, ``)``, ``.`` and space become ``_``, whitespace
    runs become ``_``, and leading/trailing underscores are stripped.
    """
    text = unicodedata.normalize("NFKD", str(s))
    text = re.sub(r"[^\w\-(). ]+", "_", text)
    return re.sub(r"\s+", "_", text).strip("_")[:200]

def parse_tokens(colname: str):
    """Split a ``TX|GENE|SYMBOL`` column name into ``(left, gene_token, symbol)``.

    Missing or empty trailing fields come back as ``""``; fields after the
    third are ignored.
    """
    # Pad with two blanks so the second and third fields always exist.
    fields = str(colname).split("|") + ["", ""]
    return fields[0], fields[1], fields[2]

def gene_display(colname: str) -> str:
    """Build the per-gene page label for a transcript column name.

    Returns ``"SYMBOL (GENE)"`` when a symbol is present; otherwise the gene
    token, or the leading token when the gene token is empty.
    """
    tx_token, gene_tok, sym = parse_tokens(colname)
    if not sym:
        return gene_tok or tx_token
    return f"{sym} ({gene_tok})"

def tx_display(colname: str) -> str:
    """Build the x-axis tick label: the leading token, plus the symbol on a second line if present."""
    tx_token, _gene_tok, sym = parse_tokens(colname)
    if sym:
        return f"{tx_token}\n{sym}"
    return tx_token

# ============================ PREP DATA ============================

def transcript_columns(df: pl.DataFrame) -> list[str]:
    """Return every column name of ``df`` except the first (the cell-ID column)."""
    return list(df.columns[1:])

def ensure_cols(df: pl.DataFrame, cols: list[str]) -> pl.DataFrame:
    """Return ``df`` reduced to its first (ID) column followed by ``cols``.

    Columns named in ``cols`` that ``df`` does not have are first appended
    as all-zero columns, so that two matrices can share one column layout.

    Args:
        df: Matrix whose first column is the cell identifier.
        cols: Column names wanted in the result, in the wanted order.

    Returns:
        A frame with columns ``[id_col] + cols``.
    """
    id_col = df.columns[0]
    # Snapshot the names into a set once: testing `c not in df.columns` inside
    # the comprehension is a linear list scan per requested column, which
    # turns quadratic on wide transcript matrices.
    present = set(df.columns)
    missing = [c for c in cols if c not in present]
    if missing:
        zeros = [0] * df.height
        df = df.hstack([pl.Series(c, zeros) for c in missing])
    return df.select([id_col] + cols)

# Bring both sample matrices onto the sorted union of their transcript columns.
cols1, cols2 = transcript_columns(pbmc1_filt), transcript_columns(pbmc2_filt)
all_cols = sorted({*cols1, *cols2})
pbmc1b, pbmc2b = (ensure_cols(matrix, all_cols) for matrix in (pbmc1_filt, pbmc2_filt))

# Cell label lookup: obs index -> "sub_cell_type" (as str)
subtype_series = adata_i_filtered.obs["sub_cell_type"].astype(str)
subtype_map = subtype_series.to_dict()

def attach_subtype(df: pl.DataFrame) -> pl.DataFrame:
    """Add a ``sub_cell_type`` column from ``subtype_map`` and drop unlabeled cells.

    The first column of ``df`` is renamed to ``CellID`` and used as the join
    key. Rows whose ``CellID`` has no entry in ``subtype_map`` are removed.
    """
    lookup = pl.DataFrame(
        {
            "CellID": list(subtype_map),
            "sub_cell_type": list(subtype_map.values()),
        }
    )
    keyed = df.rename({df.columns[0]: "CellID"})
    labeled = keyed.join(lookup, on="CellID", how="left")
    return labeled.filter(pl.col("sub_cell_type").is_not_null())

# Label both aligned matrices; attach_subtype drops rows without a sub-cell-type entry.
pbmc1c = attach_subtype(pbmc1b)
pbmc2c = attach_subtype(pbmc2b)

# Sum UNIQUE COUNTS per transcript × sub-cell-type (PBMC1 + PBMC2)
def per_tx_subct_unique_counts(df: pl.DataFrame, cols: list[str]) -> pl.DataFrame:
    """Sum the per-cell values of ``cols`` for every (transcript, sub-cell-type) pair.

    ``df`` is reshaped to long form (one row per cell and transcript) and then
    grouped. The result has columns ``Transcript``, ``sub_cell_type`` and
    ``unique_counts``.
    """
    summed = pl.col("uc").sum().alias("unique_counts")
    melted = df.unpivot(
        on=cols,
        index=["CellID", "sub_cell_type"],
        variable_name="Transcript",
        value_name="uc",
    )
    return melted.group_by(["Transcript", "sub_cell_type"]).agg(summed)

# Aggregate each sample separately, then add the two samples together.
counts1, counts2 = (per_tx_subct_unique_counts(matrix, all_cols) for matrix in (pbmc1c, pbmc2c))
counts_all = pl.concat([counts1, counts2]).group_by(["Transcript", "sub_cell_type"]).agg(
    pl.col("unique_counts").sum().alias("unique_counts")
)

# The plotting code below works on pandas.
df_counts = counts_all.to_pandas()

# Fixed order; categories without data are kept and filled with 0 in the per-gene grid.
ct_order = list(sub_ct_order)

# Colors: tab20 fallback unless the notebook already defines sub_cell_type_colors
# (labels missing from that mapping are drawn grey).
if 'sub_cell_type_colors' not in globals():
    pal = sns.color_palette("tab20", n_colors=len(ct_order))
    palette = dict(zip(ct_order, pal))
else:
    palette = {ct: sub_cell_type_colors.get(ct, "#999999") for ct in ct_order}

# Derive the page-level gene label and the per-bar transcript label.
tx_to_gene = dict(zip(all_cols, map(gene_display, all_cols)))
df_counts = df_counts.assign(
    Gene=df_counts["Transcript"].map(tx_to_gene),
    TranscriptLabel=df_counts["Transcript"].map(tx_display),
)

# --- Robust alphabetical ordering by preferred gene symbol ---
def _normalize_ascii(s: str) -> str:
    s = unicodedata.normalize("NFKD", s)
    s = s.encode("ascii", "ignore").decode("ascii")
    return s.strip().lower()

def _natural_parts(s: str):
    parts = re.split(r"(\d+)", s)
    return [int(p) if p.isdigit() else p for p in parts]

def _symbol_for_sort(g: str) -> str:
    """Return the normalized sort text: the part before the first ``(`` if any, else the whole label."""
    label = str(g)
    hit = re.match(r"\s*([^(]+)\s*\(", label)
    return _normalize_ascii(hit.group(1) if hit else label)

def _sort_key(g: str):
    """Sort key for gene labels: BambuGene labels last, then natural order of the symbol."""
    sym = _symbol_for_sort(g)
    # False sorts before True, so labels that do not start with "bambugene" come first.
    return (str(g).lower().startswith("bambugene"), _natural_parts(sym), sym)

# Page order: every distinct gene label in the aggregated table.
genes_sorted = sorted(df_counts["Gene"].unique(), key=_sort_key)

## PLOTTING (PER GENE)

# Paths of the per-gene PDFs written by the loop below (reported at the end).
saved_pages = []

# One figure per gene: grouped bars (x = transcript, hue = sub-cell-type).
# Each figure is saved as its own PDF and also appended to the combined PDF.
with PdfPages(OUT_COMBINED_PDF) as bigpdf:
    for gene in genes_sorted:
        sub = df_counts[df_counts["Gene"] == gene].copy()
        if sub.empty:
            continue

        # transcript order by descending unique counts
        tx_order = (
            sub.groupby("TranscriptLabel")["unique_counts"]
               .sum().sort_values(ascending=False).index.tolist()
        )
        if not tx_order:
            continue

        # complete to full grid (TranscriptLabel × sub_cell_type) with zeros;
        # rows whose sub_cell_type is not in ct_order are dropped by the reindex
        sub_grid = (
            sub.groupby(["TranscriptLabel", "sub_cell_type"], as_index=False)["unique_counts"]
               .sum()
        )
        full_idx = pd.MultiIndex.from_product([tx_order, ct_order],
                                              names=["TranscriptLabel", "sub_cell_type"])
        sub_grid = (
            sub_grid.set_index(["TranscriptLabel", "sub_cell_type"])
                    .reindex(full_idx, fill_value=0)
                    .reset_index()
        )

        # Categorical ordering for seaborn
        sub_grid["TranscriptLabel"] = pd.Categorical(sub_grid["TranscriptLabel"],
                                                     categories=tx_order, ordered=True)
        sub_grid["sub_cell_type"] = pd.Categorical(sub_grid["sub_cell_type"],
                                                   categories=ct_order, ordered=True)

        # figure width scales with # transcripts (clamped to 6..18); height is fixed
        n_tx = len(tx_order)
        fig_w = max(6, min(0.4 * n_tx + 2, 18))
        fig_h = 4
        fig, ax = plt.subplots(figsize=(fig_w, fig_h))

        # sub_grid holds exactly one row per (transcript, sub-cell-type), so the
        # sum estimator simply reproduces each value.
        sns.barplot(
            data=sub_grid,
            x="TranscriptLabel", y="unique_counts", hue="sub_cell_type",
            order=tx_order, hue_order=ct_order,
            palette=palette, edgecolor="black",
            estimator=sum, errorbar=None, ax=ax
        )

        ax.set_xlabel("Transcripts")
        ax.set_ylabel("Unique Counts (sum over cells)")
        ax.set_title(gene)

        # Rotate + right-align x labels
        for lbl in ax.get_xticklabels():
            lbl.set_rotation(45)
            lbl.set_horizontalalignment("right")

        sns.despine(ax=ax, top=True, right=True)
        ax.legend(title="Sub cell type", bbox_to_anchor=(1.01, 1), loc="upper left", frameon=True)
        plt.tight_layout()

        # Save the individual page and append the same figure to the combined PDF.
        out_path = os.path.join(OUT_DIR_PER_GENE, f"{safe_name(gene)}.pdf")
        fig.savefig(out_path, dpi=600, transparent=True, bbox_inches="tight")
        saved_pages.append(out_path)
        bigpdf.savefig(fig, bbox_inches="tight")
        # Close each figure so open figures do not accumulate across genes.
        plt.close(fig)

print(f"[info] Wrote {len(saved_pages)} per-gene PDFs to {OUT_DIR_PER_GENE}")
print(f"[info] Combined multi-page PDF: {OUT_COMBINED_PDF}")

In [None]:
### Bulk-by-sub-cell-type bar charts of RAW COUNTS per transcript, split by gene (FILTERED)
# Uses pbmc1_filt / pbmc2_filt to:
#   (1) detect ENSGs with >1 novel BambuTx, and
#   (2) backfill ENSG for Bambu-only features in adata.var.
# Pages are grouped by ENSG (GeneKey). Titles show SYMBOL (ENSG).
# Keeps ONLY: marker ENSGs OR ENSGs with >1 novel (BambuTx). RAW totals (no normalization).

import os, re, unicodedata
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from scipy import sparse
import polars as pl
from collections import Counter, defaultdict

## CONFIG / UTILS

# Output locations for this cell: a directory of per-gene PDFs and one combined PDF.
# NOTE: these names are reassigned by each plotting cell in the notebook.
OUT_DIR_PER_GENE = "Intermediate_Files/Unique_Count_Analyses/plots/bulk_by_subcelltype_per_gene_RAWCOUNTS"
OUT_COMBINED_PDF = "Intermediate_Files/Unique_Count_Analyses/plots/bulk_by_subcelltype_all_genes_RAWCOUNTS.pdf"
os.makedirs(OUT_DIR_PER_GENE, exist_ok=True)

# Used further down to choose the y-axis label.
SCALE_PER = None  # RAW totals

def safe_name(s: str) -> str:
    """Turn an arbitrary label into a file-name stem of at most 200 characters.

    The text is NFKD-normalized. Runs of characters other than word
    characters, ``-``, ``(``, ``)``, ``.`` and space become ``_``, whitespace
    runs become ``_``, and leading/trailing underscores are stripped.
    """
    text = unicodedata.normalize("NFKD", str(s))
    text = re.sub(r"[^\w\-(). ]+", "_", text)
    return re.sub(r"\s+", "_", text).strip("_")[:200]

# ---- Robust token parser: split on both '|' and ':' and detect by prefix ----
def split_tokens(name: str):
    """Split ``name`` on both ``|`` and ``:`` and return the non-empty, stripped tokens."""
    pieces = (chunk.strip() for chunk in str(name).replace("|", ":").split(":"))
    return [piece for piece in pieces if piece]

def parse_tokens_robust(name: str):
    """Return (tx_id, ensg, symbol_guess) by scanning tokens.

    * ``tx_id``: first token starting with ``BambuTx`` or ``ENST``; if there
      is none, the first token, or ``str(name)`` when there are no tokens.
    * ``ensg``: first token starting with ``ENSG``, else ``""``.
    * ``symbol_guess``: first token without a known id prefix that is at most
      20 characters long and contains no space, else ``""``.
    """
    id_prefixes = ("BambuTx", "ENST", "ENSG", "BambuGene")
    tokens = split_tokens(name)

    tx_id = next((t for t in tokens if t.startswith(("BambuTx", "ENST"))), "")
    ensg = next((t for t in tokens if t.startswith("ENSG")), "")
    symbol_guess = next(
        (
            t for t in tokens
            if not t.startswith(id_prefixes) and 1 <= len(t) <= 20 and " " not in t
        ),
        "",
    )

    if not tx_id:
        tx_id = tokens[0] if tokens else str(name)
    return tx_id, ensg, symbol_guess

def tx_display_from_name(name: str) -> str:
    """Build the tick label: the transcript id, plus the guessed symbol on a second line if any."""
    tx_id, _ensg, sym = parse_tokens_robust(name)
    if sym:
        return f"{tx_id}\n{sym}"
    return tx_id

def _ensg_core(x):
    x = str(x) if x is not None else ""
    x = x.split("|", 1)[0].strip()  # drop trailing |SYMBOL if any
    x = x.split(".", 1)[0].strip()  # drop version suffix like .13
    return x if x.startswith("ENSG") else ""

# normalize sub-cell-type (plug aliases if needed)
def norm_subct(s: str) -> str:
    """Normalize a sub-cell-type label: convert to str and trim surrounding whitespace."""
    label = str(s)
    return label.strip()

# palette and order for sub-cell-types
# Order and colors: follow sub_cell_type_colors when the notebook defines it;
# otherwise use the sorted labels found in the data with a tab20 palette.
if 'sub_cell_type_colors' not in globals():
    subct_order = sorted(adata_i_filtered.obs["sub_cell_type"].astype(str).unique())
    pal = sns.color_palette("tab20", n_colors=len(subct_order))
    palette = dict(zip(subct_order, pal))
else:
    subct_order = list(sub_cell_type_colors.keys())
    palette = {ct: sub_cell_type_colors.get(ct, "#999999") for ct in subct_order}

### MARKER & NOVEL GENE SETS

# ENSG -> symbol (best-effort) from var_names
# Tally, per ENSG token, how often each symbol-like token accompanies it in
# the var names; the most frequent symbol becomes the primary one.
_ID_PREFIXES = ("ENSG","ENST","BambuTx","BambuGene")
ensg_to_symbol = defaultdict(Counter)
for raw_name in adata_i_filtered.var_names:
    tokens = split_tokens(str(raw_name))
    ensg = next((t for t in tokens if t.startswith("ENSG")), "")
    sym = next((t for t in tokens if not t.startswith(_ID_PREFIXES)), "")
    if ensg and sym:
        ensg_to_symbol[ensg][sym] += 1
ensg_primary_symbol = {
    gene_id_key: tally.most_common(1)[0][0]
    for gene_id_key, tally in ensg_to_symbol.items()
}

# Marker gene symbols (same list as the cell-type level plots).
marker_set = {
    "CD3D","CD3E","CD3G",
    "CD8A","CD8B","GATA3","KLRB1","CCL5",
    "CD4","IL2RA","AHR","TNF",
    "CCR7","SELL","TCF7",
    "ITGAE","LEF1","CTLA4","IL7R","CD27",
    "GZMB","KLRF1","NCAM1","ITGAM","IL2RB",
    "CD22","CD79A","MS4A1","CD19",
    "FCGR2A","CLEC7A","CD33","LILRB4",
    "GP1BA","MPL","ITGA2B",
}

# Marker ENSGs, source 1: var names whose first symbol-like token is a marker.
marker_ensg = set()
for raw_name in adata_i_filtered.var_names:
    tokens = split_tokens(str(raw_name))
    ensg = next((t for t in tokens if t.startswith("ENSG")), "")
    sym = next(
        (t for t in tokens if not t.startswith(("ENSG","ENST","BambuTx","BambuGene"))),
        "",
    )
    if ensg and sym in marker_set:
        marker_ensg.add(ensg)

# Marker ENSGs, source 2: the gene_id / gene_name columns of .var, when both exist.
if {"gene_id","gene_name"} <= set(adata_i_filtered.var.columns):
    var_sub = adata_i_filtered.var[["gene_id","gene_name"]].astype(str)
    is_marker_row = var_sub["gene_name"].isin(marker_set)
    add_ensg = set(var_sub.loc[is_marker_row, "gene_id"])
    marker_ensg.update(e for e in add_ensg if str(e).startswith("ENSG"))

# Build BambuTx -> ENSG and count novels (>1) from pbmc1_filt & pbmc2_filt
def build_bambu_maps_and_counts(*dfs: "pl.DataFrame"):
    """Map novel transcripts to genes and count distinct novel transcripts per gene.

    Every column after the first of each frame is inspected. Only names of
    the form ``BambuTx...|ENSG...`` (optionally followed by ``|SYMBOL``)
    contribute; all other columns are ignored.

    Returns:
        ``(tx_to_ensg, ensg_ctr)`` where ``tx_to_ensg`` maps each BambuTx id
        to its bare ENSG id and ``ensg_ctr`` is a ``Counter`` giving, per
        ENSG id, the number of *distinct* BambuTx ids seen for it.
    """
    tx_to_ensg = {}
    ensg_ctr = Counter()
    counted = set()  # (BambuTx id, ENSG id) pairs already counted
    for df in dfs:
        for c in df.columns[1:]:
            if "|" not in c:
                continue
            # left = BambuTx### or ENST..., right = 'ENSG...|SYMBOL' or 'ENSG...'
            left, right = c.split("|", 1)
            if not left.startswith("BambuTx"):
                continue
            ensg = _ensg_core(right)
            if not ensg:
                continue
            tx_to_ensg[left] = ensg
            # The same transcript is usually a column in more than one frame.
            # Counting it once per frame would let a gene with a single novel
            # transcript pass a ">1 novel" threshold, so count each pair once.
            if (left, ensg) not in counted:
                counted.add((left, ensg))
                ensg_ctr[ensg] += 1
    return tx_to_ensg, ensg_ctr

# NOTE: use pbmc1_filt and pbmc2_filt (not pbmc1 / pbmc2)
# Novel-transcript lookup and per-gene novel counts from the two filtered sample matrices.
txid_to_ensg, novel_ctr = build_bambu_maps_and_counts(pbmc1_filt, pbmc2_filt)
# Genes whose novel-transcript count exceeds 1.
ensg_with_gt1_novel = {e for e, k in novel_ctr.items() if k > 1}

# Genes to plot: marker genes plus genes with more than one novel transcript.
keep_ensgs = marker_ensg | ensg_with_gt1_novel
print(f"[info] marker ENSGs: {len(marker_ensg)} | ENSGs with >1 novel: {len(ensg_with_gt1_novel)} | total keep: {len(keep_ensgs)}")
print(f"[info] example BambuTx→ENSG: {list(txid_to_ensg.items())[:5]}")

### DATA PREP

# Use the "counts" layer when it exists, otherwise fall back to .X.
adata = adata_i_filtered
X = adata.layers["counts"] if "counts" in adata.layers else adata.X
X = X.tocsr() if sparse.issparse(X) else np.asarray(X)

# Encode each cell's (trimmed) label as its position in subct_order. Labels
# that are not in subct_order get code -1 and are excluded through `valid`.
subct_series = adata.obs["sub_cell_type"].astype(str).map(norm_subct)
codes = pd.Categorical(subct_series, categories=subct_order, ordered=True).codes
valid = codes >= 0
n_cells, n_vars, n_ct = adata.n_obs, adata.n_vars, len(subct_order)

# G is a (cells × sub-cell-types) indicator matrix: G[i, j] = 1 when cell i has sub-cell-type j.
rows = np.flatnonzero(valid)
cols = codes[valid]
data = np.ones(rows.shape[0], dtype=np.int8)
G = sparse.csr_matrix((data, (rows, cols)), shape=(n_cells, n_ct))

# Sum counts over the cells of each sub-cell-type.
if sparse.issparse(X):
    X_csr = X.tocsr()
    counts_fc = (G.T @ X_csr).T.toarray()   # (features × sub-CT)
else:
    counts_fc = X.T @ G.toarray()

## LABELS & BACKFILL

var = adata.var.copy()
var_names = pd.Index(adata.var_names).astype(str)

# Robust tokens from var_names. Each name is parsed once and the three fields
# are then split out, instead of running the parser three times per name.
parsed_names = [parse_tokens_robust(v) for v in var_names]
tx_ids_from_names = np.array([p[0] for p in parsed_names], dtype=object)
ensg_from_names   = np.array([p[1] for p in parsed_names], dtype=object)
sym_from_names    = np.array([p[2] for p in parsed_names], dtype=object)

# ENSG from var if present; else from names; then backfill from Bambu map; sanitize to core
if "gene_id" in var.columns:
    ensg_arr = var["gene_id"].astype(str).values
elif "ensembl_gene_id" in var.columns:
    ensg_arr = var["ensembl_gene_id"].astype(str).values
else:
    ensg_arr = ensg_from_names

# Keep a valid core ENSG as is; otherwise look the transcript id up in the
# BambuTx -> ENSG map built from the filtered matrices ("" when unknown).
# The transcript ids come from parse_tokens_robust, which already splits on
# ":", so a BambuTx id never carries a "Gene:Tx" prefix here and the former
# special case that re-split such ids could never trigger; it is removed.
ensg_filled = np.array(
    [
        _ensg_core(raw_id) or _ensg_core(txid_to_ensg.get(str(tx_id), ""))
        for raw_id, tx_id in zip(ensg_arr, tx_ids_from_names)
    ],
    dtype=object,
)

# Per-feature symbol source: a var column when available, else the symbol guessed from the name.
if "gene_name" in var.columns:
    gene_symbol = var["gene_name"].astype(str).values
elif "symbol" in var.columns:
    gene_symbol = var["symbol"].astype(str).values
else:
    gene_symbol = sym_from_names

tx_label = np.array(list(map(tx_display_from_name, var_names)), dtype=object)

# Page title per feature: "SYMBOL (ENSG)" when both parts exist, else whichever
# part exists, else the raw var name. The primary symbol recorded for the ENSG
# takes precedence over the per-feature symbol.
gene_display = []
for idx, raw_gid in enumerate(ensg_filled):
    gid = raw_gid if isinstance(raw_gid, str) else ""
    sym = ensg_primary_symbol.get(gid, "") or (str(gene_symbol[idx]) if gene_symbol is not None else "")
    if sym and gid:
        label = f"{sym} ({gid})"
    else:
        label = sym or gid or str(var_names[idx])
    gene_display.append(label)

## LONG TABLE & FILTER

# Wide table: one row per feature, one column per sub-cell-type, label columns in front.
counts_df = pd.DataFrame(counts_fc, columns=subct_order)
counts_df.insert(0, "Transcript", var_names.tolist())
counts_df.insert(1, "TxID", tx_ids_from_names.tolist())
counts_df.insert(2, "TranscriptLabel", tx_label.tolist())
counts_df.insert(3, "GeneKey", ensg_filled.tolist())       # ENSG (backfilled)
counts_df.insert(4, "GeneDisplay", gene_display)

# Long table: one row per (feature, sub-cell-type).
df_counts = counts_df.melt(
    id_vars=["Transcript","TxID","TranscriptLabel","GeneKey","GeneDisplay"],
    var_name="sub_cell_type",
    value_name="raw_counts"
)

# Effective keep = what you want AND what exists in matrix
present_gene_keys = {ek for ek in ensg_filled if isinstance(ek, str) and ek.startswith("ENSG")}
# Normalize the wanted ids once and use the normalized set on both sides.
# Subtracting the normalized effective set from the raw keep_ensgs would
# report every id carrying a version or "|SYMBOL" tail as absent even when
# its gene is present.
keep_ensgs_core = {_ensg_core(k) for k in keep_ensgs} - {""}
keep_ensgs_eff = keep_ensgs_core & present_gene_keys

missing_from_matrix = sorted(keep_ensgs_core - keep_ensgs_eff)
if missing_from_matrix:
    miss_disp = [f"{e} ({ensg_primary_symbol.get(e,'?')})" for e in missing_from_matrix[:20]]
    print(f"[info] {len(missing_from_matrix)} keep ENSGs absent from matrix; first few: {miss_disp}")

before = df_counts["GeneKey"].nunique()
df_counts = df_counts[df_counts["GeneKey"].isin(keep_ensgs_eff)].copy()
after = df_counts["GeneKey"].nunique()
print(f"[debug] unique GeneKey before filter: {before} | after effective filter: {after} (expect {len(keep_ensgs_eff)})")

## SORTING

def _normalize_ascii(s: str) -> str:
    s = str(s)
    return unicodedata.normalize("NFKD", s).encode("ascii","ignore").decode("ascii").strip().lower()
def _nat_parts(s: str):
    s = str(s); parts = re.split(r"(\d+)", s)
    return [int(p) if p.isdigit() else p for p in parts]
def _sym_for_sort(g: str) -> str:
    """Return the normalised sort text for a display label.

    When the label contains "(", the text before the first "(" is used;
    otherwise the whole label is used. Either way the result goes through
    `_normalize_ascii`.
    """
    hit = re.match(r"\s*([^(]+)\s*\(", str(g))
    if hit is None:
        return _normalize_ascii(g)
    return _normalize_ascii(hit.group(1))
def _sort_key(g: str):
    """Build the sort key for a gene display label.

    The key is a tuple of: whether the label starts with "bambugene"
    (case-insensitive; False sorts before True), the natural-sort chunks of the
    normalised symbol, and the normalised symbol itself.
    """
    symbol = _sym_for_sort(g)
    is_bambu = str(g).lower().startswith("bambugene")
    return (is_bambu, _nat_parts(symbol), symbol)

# GeneKey -> title text; if a key somehow carries several display strings the last one wins.
ensg_to_display = (df_counts[["GeneKey","GeneDisplay"]]
                   .drop_duplicates()
                   .set_index("GeneKey")["GeneDisplay"]
                   .to_dict())

# keep_ensgs_eff is a set, so its iteration order is arbitrary, and _sort_key only
# looks at the symbol part of the label. Genes that share a symbol would therefore
# tie and land in a run-dependent order; the gene key is appended as a final
# tie-breaker so the page order is reproducible.
gene_keys_sorted = sorted(
    keep_ensgs_eff,
    key=lambda k: (_sort_key(ensg_to_display.get(k, k)), str(k)),
)
print(f"[debug] pages to write: {len(gene_keys_sorted)} (should be {len(keep_ensgs_eff)})")

## PLOTTING

# Write one stacked-bar page per gene: x = transcripts (largest total first),
# stacks = sub-cell-types in subct_order. Each page is saved as its own PDF and
# also appended to the combined multi-page PDF.
saved_pages = []
ylab = "Total counts" if SCALE_PER is None else f"Counts per {SCALE_PER:,} cells"

# Split the long table by gene once; boolean-masking df_counts inside the loop
# would rescan every row for every gene.
rows_by_gene = {key: grp for key, grp in df_counts.groupby("GeneKey", sort=False, observed=True)}

with PdfPages(OUT_COMBINED_PDF) as bigpdf:
    for gene_key in gene_keys_sorted:
        title = ensg_to_display.get(gene_key, str(gene_key))
        sub = rows_by_gene.get(gene_key)
        if sub is None or sub.empty:
            continue

        # Full transcript x sub-cell-type grid so absent combinations plot as 0.
        txs = sorted(sub["TranscriptLabel"].unique().tolist())
        full_idx = pd.MultiIndex.from_product([txs, subct_order], names=["TranscriptLabel","sub_cell_type"])
        sub_grid = (sub.groupby(["TranscriptLabel","sub_cell_type"], observed=True)["raw_counts"]
                        .sum()
                        .reindex(full_idx, fill_value=0)
                        .reset_index())

        # Transcripts ordered by total counts, highest first; the stable sort keeps
        # tied transcripts in a reproducible order.
        tx_order = (sub.groupby("TranscriptLabel", observed=True)["raw_counts"]
                        .sum().sort_values(ascending=False, kind="mergesort").index.tolist())

        sub_grid["TranscriptLabel"]  = pd.Categorical(sub_grid["TranscriptLabel"],  categories=tx_order, ordered=True)
        sub_grid["sub_cell_type"]    = pd.Categorical(sub_grid["sub_cell_type"],    categories=subct_order, ordered=True)

        # Figure grows with the number of transcripts, clamped to 6-18 x 4-12 inches.
        n_tx = len(tx_order)
        fig_w = max(6, min(0.4 * n_tx + 2, 18))
        fig_h = max(4, min(0.28 * n_tx + 2, 12))
        fig, ax = plt.subplots(figsize=(fig_w, fig_h), constrained_layout=True)

        try:
            # manual stacked bars by sub-cell-type
            bottoms = np.zeros(len(tx_order), dtype=float)
            for ct in subct_order:
                vals = (sub_grid.loc[sub_grid["sub_cell_type"] == ct, ["TranscriptLabel","raw_counts"]]
                                .set_index("TranscriptLabel")
                                .reindex(tx_order)["raw_counts"].to_numpy())
                ax.bar(tx_order, vals, bottom=bottoms,
                       color=palette.get(ct, "#999999"), edgecolor="black", label=ct)
                bottoms += vals

            ax.set_xlabel("Transcripts")
            ax.set_ylabel(ylab)
            ax.set_title(title)

            ax.tick_params(axis="x", labelsize=8, rotation=40)
            for lbl in ax.get_xticklabels():
                lbl.set_horizontalalignment("right")

            sns.despine(ax=ax, top=True, right=True)
            ax.legend(title="Sub-cell type", bbox_to_anchor=(1.01, 1), loc="upper left", frameon=True)

            out_path = os.path.join(OUT_DIR_PER_GENE, f"{safe_name(title)}.pdf")
            fig.savefig(out_path, dpi=600, transparent=True, bbox_inches="tight")
            saved_pages.append(out_path)
            bigpdf.savefig(fig, bbox_inches="tight")
        finally:
            # Always release the figure, even if drawing or saving fails part-way.
            plt.close(fig)

print(f"[info] Wrote {len(saved_pages)} per-gene PDFs to {OUT_DIR_PER_GENE}")
print(f"[info] Combined multi-page PDF: {OUT_COMBINED_PDF}")