In [None]:
import io
import tempfile
from anndata import AnnData
import muon as mu
import numpy as np
import requests
import os
import scanpy as sc
import scvi
import seaborn as sns
import torch
import pandas as pd
import sys
import scrublet as scr
import skimage
import pybiomart
from bioservices import BioMart
import rdata
import matplotlib.pyplot as plt
from adjustText import adjust_text
from scipy.stats import beta
import tqdm
import time
import gc
import polars as pl
import pyarrow

In [None]:
import re

# Path to the GTF annotation file to scan
gtf_file = "extended_annotations.gtf"

# Compiled once up front instead of being re-parsed for every record.
transcript_id_rx = re.compile(r'transcript_id "([^"]+)"')

# Print "<transcript_id>\t<chrom>:<start>-<end>(<strand>)" for every transcript
# record whose transcript_id starts with "BambuTx".
with open(gtf_file) as f:
    for line in f:
        if line.startswith("#"):
            continue
        # maxsplit=8 yields at most 9 fields, so a stray tab inside the attribute
        # column stays in `attrs` and the 9-way unpack below cannot raise.
        fields = line.rstrip("\r\n").split("\t", 8)
        if len(fields) < 9:
            continue

        chrom, source, feature, start, end, score, strand, frame, attrs = fields

        # Only keep transcript features
        if feature != "transcript":
            continue

        # Extract transcript_id; records without one are skipped
        m = transcript_id_rx.search(attrs)
        if not m:
            continue
        tid = m.group(1)

        # Only BambuTx transcripts
        if tid.startswith("BambuTx"):
            print(f"{tid}\t{chrom}:{start}-{end}({strand})")

In [None]:
import re
import pandas as pd

# ---------- target list ----------
# BambuTx transcript IDs that are looked up in the loaded isoform object further
# down in this cell. Order is irrelevant: the list is turned into a set below.
target = [
    "BambuTx21","BambuTx24","BambuTx27","BambuTx137","BambuTx7","BambuTx25","BambuTx26",
    "BambuTx32","BambuTx33","BambuTx30","BambuTx31","BambuTx138","BambuTx9","BambuTx151",
    "BambuTx36","BambuTx139","BambuTx38","BambuTx39","BambuTx40","BambuTx41","BambuTx42",
    "BambuTx37","BambuTx43","BambuTx44","BambuTx50","BambuTx47","BambuTx48","BambuTx49",
    "BambuTx134","BambuTx51","BambuTx52","BambuTx54","BambuTx56","BambuTx12","BambuTx152",
    "BambuTx57","BambuTx1","BambuTx59","BambuTx58","BambuTx143","BambuTx144","BambuTx63",
    "BambuTx61","BambuTx62","BambuTx161","BambuTx3","BambuTx67","BambuTx153","BambuTx71",
    "BambuTx154","BambuTx155","BambuTx5","BambuTx75","BambuTx77","BambuTx80","BambuTx81",
    "BambuTx146","BambuTx162","BambuTx83","BambuTx89","BambuTx157","BambuTx90","BambuTx163",
    "BambuTx147","BambuTx158","BambuTx99","BambuTx100","BambuTx105","BambuTx93","BambuTx148",
    "BambuTx165","BambuTx109","BambuTx110","BambuTx111","BambuTx115","BambuTx11","BambuTx166",
    "BambuTx113","BambuTx119","BambuTx120","BambuTx117","BambuTx13","BambuTx127","BambuTx128",
    "BambuTx149","BambuTx159","BambuTx129","BambuTx130","BambuTx131","BambuTx132","BambuTx136",
    "BambuTx18","BambuTx135","BambuTx150","BambuTx14","BambuTx15","BambuTx16","BambuTx19",
    "BambuTx106","BambuTx107","BambuTx145","BambuTx82","BambuTx95","BambuTx94","BambuTx84",
    "BambuTx22","BambuTx70","BambuTx160","BambuTx141","BambuTx86","BambuTx88","BambuTx103",
    "BambuTx122","BambuTx72","BambuTx121","BambuTx28","BambuTx29","BambuTx156","BambuTx66",
    "BambuTx79","BambuTx76","BambuTx97","BambuTx101","BambuTx102","BambuTx74","BambuTx123",
    "BambuTx17","BambuTx73","BambuTx85","BambuTx2","BambuTx142","BambuTx108","BambuTx164",
    "BambuTx87","BambuTx20","BambuTx92","BambuTx112","BambuTx34","BambuTx140","BambuTx45",
    "BambuTx46","BambuTx53","BambuTx60","BambuTx55","BambuTx91","BambuTx116","BambuTx114",
    "BambuTx10","BambuTx8","BambuTx23","BambuTx64","BambuTx69","BambuTx35","BambuTx65",
    "BambuTx133","BambuTx118","BambuTx6","BambuTx126","BambuTx68","BambuTx124","BambuTx98",
    "BambuTx78","BambuTx104","BambuTx125","BambuTx4","BambuTx96"
]
# Set form: collapses any repeated ID and allows set arithmetic against the IDs found.
target_set = set(target)

# Matches a complete "BambuTx<digits>" token; the \b anchors require a word
# boundary on both sides, so the pattern does not fire inside a longer word.
bt_rx = re.compile(r"\bBambuTx\d+\b")

def _series_to_strings(s: pd.Series):
    """Coerce any dtype (including categorical) to strings safely."""
    return s.dropna().astype(str)

def collect_bambutx_from_anndata_like(adata_like, label=""):
    """Collect BambuTx IDs from the indexes and text columns of an AnnData-like object.

    Parameters
    ----------
    adata_like
        Object that may expose ``var_names``, ``obs_names``, ``var`` and ``obs``;
        any attribute that is missing (or ``None``) is skipped.
    label
        Prefix used when naming the places where hits were found.

    Returns
    -------
    (found, sources)
        ``found`` is the set of distinct IDs matched by ``bt_rx``; ``sources`` maps
        ``"<label>.<attribute>[.<column>]"`` to the number of distinct IDs found there.
    """
    found = set()
    sources = {}

    def _scan(source_key, text):
        # Record the distinct IDs in `text`, but only when there is at least one.
        ids_here = set(bt_rx.findall(text))
        if ids_here:
            found.update(ids_here)
            sources[source_key] = len(ids_here)

    # 1) index-like attributes
    for attr_name in ("var_names", "obs_names"):
        index_like = getattr(adata_like, attr_name, None)
        if index_like is None:
            continue
        _scan(f"{label}.{attr_name}", "\n".join(str(entry) for entry in index_like))

    # 2) .var / .obs dataframes: only object/unicode/bytes, string and categorical columns
    for frame_name in ("var", "obs"):
        frame = getattr(adata_like, frame_name, None)
        if frame is None:
            continue
        for column in frame.columns:
            series = frame[column]
            looks_textual = (
                series.dtype.kind in ("O", "U", "S")
                or str(series.dtype).startswith(("string", "category"))
            )
            if not looks_textual:
                continue
            _scan(f"{label}.{frame_name}.{column}", "\n".join(_series_to_strings(series)))

    return found, sources

def collect_from_object(obj, label):
    """Collect BambuTx IDs from an AnnData-like or MuData-like object.

    An object with a ``mod`` attribute is treated as a container: every entry of
    ``obj.mod`` is scanned under the label ``"<label>.mod[<name>]"``. Any other
    object is scanned directly under ``label``.

    Returns ``(all_found, all_sources)``: the union of IDs over everything scanned
    and the merged per-source hit counts.
    """
    if hasattr(obj, "mod"):  # MuData-style container
        to_scan = [(ad, f"{label}.mod[{mod_name}]") for mod_name, ad in obj.mod.items()]
    else:  # single AnnData-style object
        to_scan = [(obj, label)]

    all_found, all_sources = set(), {}
    for component, component_label in to_scan:
        ids, per_source = collect_bambutx_from_anndata_like(component, label=component_label)
        all_found |= ids
        all_sources.update(per_source)
    return all_found, all_sources

# ---------- run over the isoform-level dataset ----------
# NOTE(review): `adata_i_filtered` is created by the loading cell that appears
# further down in this notebook; on a fresh kernel that cell has to run first.
present_i, sources_i = collect_from_object(adata_i_filtered, "iso")

# Split the targets into found / not found, ordered by the numeric suffix of the ID.
missing = sorted(target_set - present_i, key=lambda x: int(x.replace("BambuTx","")))
present = sorted(target_set & present_i, key=lambda x: int(x.replace("BambuTx","")))

print(f"Found in iso object : {len(present_i)} unique BambuTx (from {len(sources_i)} sources)")
print(f"Of your {len(target_set)} targets: present={len(present)}, missing={len(missing)}\n")

if sources_i:
    print("Top iso sources (up to 10):")
    # Rank by hit count so the listing matches its "Top" heading; sorted() is
    # stable, so sources with equal counts keep their discovery order.
    for k, n_hits in sorted(sources_i.items(), key=lambda kv: kv[1], reverse=True)[:10]:
        print(f"  {k}: {n_hits} hits")

print("\nMissing list:")
for tx in missing:
    print(tx)

In [None]:
from scanpy import read_h5ad

output_dir = 'Intermediate_Files/Clustering/'

# Gene-level and isoform-level objects saved by the clustering step.
# NOTE(review): both files carry an ".h5mu" extension but are read with
# read_h5ad — confirm they were written as plain AnnData files.
gene_object_path = os.path.join(output_dir, "PBMC_gene_AutoZI_clustered_celltypes_reannotated_AutoZILatent.h5mu")
iso_object_path = os.path.join(output_dir, "PBMC_iso_AutoZI_clustered_celltypes_reannotated_AutoZILatent.h5mu")

adata_g_filtered = read_h5ad(gene_object_path)
adata_i_filtered = read_h5ad(iso_object_path)

In [None]:
# Make the "log_denoised" layer the working matrix (.X) of both objects.
# This overwrites whatever .X held after loading.
for _adata in (adata_g_filtered, adata_i_filtered):
    _adata.X = _adata.layers["log_denoised"]

In [None]:
# Function to find matching genes in var_names (combined_ID format)
def find_matching_genes(prefixes, gene_list):
    """Return the entries of ``gene_list`` that start with any of ``prefixes``.

    Parameters
    ----------
    prefixes
        Iterable of string prefixes (e.g. ``"CD3D:"``). It is materialised once,
        so a generator or other one-shot iterable is applied to every entry.
    gene_list
        Iterable of string IDs to filter.

    Returns
    -------
    list
        Matching entries, in their original order. Empty when ``prefixes`` is empty.
    """
    # str.startswith accepts a tuple of alternatives, which replaces the per-gene
    # inner loop and avoids re-iterating `prefixes` for every gene.
    prefix_tuple = tuple(prefixes)
    return [gene for gene in gene_list if gene.startswith(prefix_tuple)]

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# ---- Define marker genes
# Group name -> list of ID prefixes handed to find_matching_genes().
# Every prefix ends with ":" so that it matches only the field before the first
# ":" in an ID; without the terminator a prefix such as "CD4" would also match
# "CD40:...". "IL2RA" and "CTLA4" previously lacked the terminator.
marker_genes = {
    "T cells": ["CD3D:", "CD3E:", "CD3G:"],
    "CD8 Effector": ["CD8A:", "CD8B:", "GATA3:", "KLRB1:", "CCL5:"],
    "CD4 Effector": ["CD4:", "IL2RA:", "GATA3:", "AHR:", "TNF:"],
    "Memory": ["CCR7:", "SELL:", "TCF7:"],
    "Transition": ["ITGAE:", "LEF1:", "TCF7:", "IL7R:", "GATA3:", "IL2RA:", "CTLA4:", "CD27:"],
    "Natural Killer": ["GZMB:", "KLRF1:", "NCAM1:", "ITGAM:", "IL2RB:"],
    "B cell": ["CD22:", "CD79A:", "MS4A1:", "CD19:"],
    "Monocyte-derived": ["FCGR2A:", "CLEC7A:", "CD33:", "LILRB4:"],
    "Megakaryocyte": ["ITGA2B:", "MPL:", "GP1BA:"]
}

In [None]:
### Full cell-type marker isoform dotplot (not abbreviated)

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable

## Config
# Keys of `marker_genes` to plot; their order here is the group order used for sorting below.
marker_groups_to_include = ["T cells", "Natural Killer", "B cell", "Monocyte-derived", "Megakaryocyte"]
# Labels looked up in obs["gen_cell_type"], in display order.
celltype_order = ["T cells", "NK cells", "B cells", "Monocyte-derived", "Megakaryocytes"]

## Build isoform → group mapping & collect isoforms present
gene_to_group = {}
matched_isoforms = []
for group, genes in marker_genes.items():
    if group not in marker_groups_to_include:
        continue
    matches = find_matching_genes(genes, adata_i_filtered.var_names)  # isoform IDs like "GENE:ENSG:ENST"
    matched_isoforms.extend(matches)
    for iso in matches:
        gene_to_group[iso] = group

if not matched_isoforms:
    raise ValueError("No marker isoforms found for the requested groups.")

# Deduplicate while preserving order: an isoform matched through two groups would
# otherwise yield duplicated columns and non-unique plot categories further down.
matched_isoforms = list(dict.fromkeys(matched_isoforms))

## Extract Expression matrix (cells × isoforms)
X = adata_i_filtered[:, matched_isoforms].X
if hasattr(X, "toarray"):
    X = X.toarray()  # densify sparse input
iso_expr = pd.DataFrame(
    X,
    index=adata_i_filtered.obs["gen_cell_type"].astype(str),
    columns=matched_isoforms
)

## Per-cluster metrics on isoforms
# mean expression per cluster
avg_expr_iso = iso_expr.groupby(iso_expr.index, observed=True).mean()
# % expressing (>0) per cluster
# NOTE(review): .X was set to the "log_denoised" layer in an earlier cell; if those
# values are strictly positive, every cell counts as expressing — confirm the layer.
pct_expr_iso = (iso_expr > 0).astype(int).groupby(iso_expr.index, observed=True).mean() * 100

# keep clusters in desired order
present = [ct for ct in celltype_order if ct in avg_expr_iso.index]
if not present:
    # Fail loudly instead of drawing and saving an empty figure.
    raise ValueError("None of the requested cell types in celltype_order are present in gen_cell_type.")
avg_expr_iso = avg_expr_iso.reindex(present)
pct_expr_iso = pct_expr_iso.reindex(present)

## Sort isoforms by (group_rank, gene, isoform)
isoform_to_gene = {iso: iso.split(":")[0] for iso in avg_expr_iso.columns}
meta = pd.DataFrame({
    "isoform": avg_expr_iso.columns,
    "gene": [isoform_to_gene[i] for i in avg_expr_iso.columns],
    "group": [gene_to_group.get(i, "") for i in avg_expr_iso.columns],
})
# Groups outside marker_groups_to_include get rank 999 and sort last.
meta["group_rank"] = meta["group"].apply(lambda g: marker_groups_to_include.index(g) if g in marker_groups_to_include else 999)
meta = meta.sort_values(["group_rank", "gene", "isoform"])
ordered_isoforms = meta["isoform"].tolist()

# reorder matrices
avg_expr_iso = avg_expr_iso[ordered_isoforms]
pct_expr_iso = pct_expr_iso[ordered_isoforms]

## Mean-center per isoform (white = average across clusters for that isoform)
avg_centered = avg_expr_iso - avg_expr_iso.mean(axis=0)

# short x labels: GENE:last 6 of transcript id (if present)
def short_label(iso):
    """Shorten a three-field ``"A:B:C"`` ID to ``"A:<last 6 characters of C>"``.

    IDs that do not split into exactly three ``":"``-separated fields are
    returned unchanged.
    """
    pieces = iso.split(":")
    if len(pieces) != 3:
        return iso
    return f"{pieces[0]}:{pieces[2][-6:]}"
xlabels = [short_label(i) for i in ordered_isoforms]

## Long format for plotting
avg_long = avg_centered.reset_index(names="cell_type").melt(id_vars="cell_type", var_name="iso", value_name="avg_c")
pct_long = pct_expr_iso.reset_index(names="cell_type").melt(id_vars="cell_type", var_name="iso", value_name="pct")
plot_long = avg_long.merge(pct_long, on=["cell_type", "iso"])

# Enforce display order through ordered categoricals, then sort the rows to match.
plot_long["cell_type"] = pd.Categorical(plot_long["cell_type"], categories=present, ordered=True)
plot_long["iso"] = pd.Categorical(plot_long["iso"], categories=ordered_isoforms, ordered=True)
plot_long = plot_long.sort_values(["cell_type", "iso"])

## Create Dot plot
fig, ax = plt.subplots(figsize=(max(12, 0.42*len(ordered_isoforms)), max(2.6, 0.45*len(present))))
size_min, size_max = 18, 380

# Symmetric colour range around 0 from the 99th percentile of |value|.
# Only finite values are used and an empty or all-zero matrix falls back to 1.0,
# so the norm never receives NaN or a zero-width range.
vals = avg_centered.values.astype(float)
finite_abs = np.abs(vals[np.isfinite(vals)])
rng = float(np.percentile(finite_abs, 99)) if finite_abs.size else 0.0
if rng == 0:
    rng = 1.0
norm = TwoSlopeNorm(vcenter=0, vmin=-rng, vmax=rng)

sns.scatterplot(
    data=plot_long,
    x="iso", y="cell_type",
    hue="avg_c", size="pct",
    palette="coolwarm", hue_norm=norm,
    sizes=(size_min, size_max),
    edgecolor="black", linewidth=0.3,
    legend=False, ax=ax
)

# Padding for big dots; no horizontal gutters.
# NOTE(review): row order follows the category order of `present`; if its first
# entry is meant to be the top row, confirm whether ax.invert_yaxis() is needed.
ax.margins(y=0.12, x=0)
ax.set_xlim(-0.5, len(ordered_isoforms) - 0.5)

# Pin one tick per isoform before replacing the labels, so the short labels are
# guaranteed to line up one-to-one with the categories.
ax.set_xticks(range(len(ordered_isoforms)))
ax.set_xticklabels(xlabels, rotation=45, ha="right", fontsize=7)

# Right side: colorbar + % expressing legend
divider = make_axes_locatable(ax)

# Colorbar
cax = divider.append_axes("right", size="3%", pad=0.18)
sm = plt.cm.ScalarMappable(cmap="coolwarm", norm=norm)
sm.set_array([])
cbar = fig.colorbar(sm, cax=cax)
cbar.set_label("Mean-centered expression", rotation=90, labelpad=8)

# Size legend to the right of the colorbar
lax = divider.append_axes("right", size="10%", pad=0.28)
lax.set_axis_off(); lax.set_xlim(0, 1); lax.set_ylim(0, 1)
levels = [5, 25, 50, 75, 100]  # keep consistent across plots
lax.text(0.5, 0.98, "% expressing", ha="center", va="top", fontsize=10)
ypos = np.linspace(0.80, 0.20, len(levels))
for y, lv in zip(ypos, levels):
    sz = np.interp(lv, [0, 100], [size_min, size_max])  # legend independent of data range
    # scatter sizes are areas (points^2) while `ms` is a diameter, hence the sqrt
    lax.plot([0.35], [y], 'o', ms=np.sqrt(sz), markeredgecolor='black', markerfacecolor='lightgray')
    lax.text(0.60, y, f"{lv}%", va="center", ha="left", fontsize=10)

# Cosmetics
ax.set_xlabel(""); ax.set_ylabel("")
# Address the main axes explicitly: plt.yticks() acts on whichever axes is
# current, which after the append_axes calls above is no longer `ax`.
ax.tick_params(axis="y", labelrotation=0)

outpath = "Intermediate_Files/Paper_Figs/Markers/celltype_marker_isoform_dotplot.pdf"
os.makedirs(os.path.dirname(outpath), exist_ok=True)  # savefig does not create directories
plt.tight_layout(rect=(0.02, 0, 0.86, 1))  # slimmer left, reserve right panel
fig.savefig(outpath, dpi=600, transparent=True, bbox_inches="tight")
plt.show()

In [None]:
### Abbreviated cell-type marker dotplot with heatmap (Fig 3c)

import os
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable

## Config
# NOTE(review): marker_groups_to_include is assigned here but not referenced again
# in the visible part of this cell; the transcripts are selected via desired_transcripts.
marker_groups_to_include = ["T cells", "Natural Killer", "B cell", "Monocyte-derived", "Megakaryocyte"]
# Labels looked up in obs["gen_cell_type"], in display order.
celltype_order = ["T cells", "NK cells", "B cells", "Monocyte-derived", "Megakaryocytes"]

## Only include these transcripts, in this exact order
# Gene symbol -> transcript IDs to show for that gene.
# Per the original note, these are the 3 most highly expressed transcripts (by raw
# counts) plus any marker isoforms for unexpected groups; some genes list fewer
# or more than three entries.
desired_transcripts = {
    "CD3D": ["ENST00000300692", "ENST00000392884", "ENST00000526561"], # T cell marker
    "CD3E": ["ENST00000361763", "ENST00000528600", "ENST00000526146"], # T cell marker
    "CD3G": ["ENST00000532917", "ENST00000527777", "ENST00000292144", "ENST00000392883"], # T cell marker
    "GZMB": ["ENST00000216341", "ENST00000415355", "ENST00000554242", "ENST00000526004"], # NK cell marker
    "IL2RB": ["ENST00000698890", "ENST00000698902", "ENST00000216223"], # NK cell marker
    "NCAM1": ["ENST00000316851", "ENST00000533073", "ENST00000621128"], # NK cell marker
    "ITGAM": ["ENST00000648685", "ENST00000544665", "ENST00000561838"], # NK cell marker
    "KLRF1": ["ENST00000617889", "ENST00000617793", "ENST00000279545"], # NK cell marker
    "CD19": ["ENST00000538922", "ENST00000324662", "ENST00000565089"], # B cell marker
    "CD22": ["ENST00000085219", "ENST00000596492", "ENST00000536635"], # B cell marker
    "CD79A": ["ENST00000221972", "ENST00000444740", "ENST00000597454"], # B cell marker
    "MS4A1": ["ENST00000389939", "ENST00000345732", "ENST00000674194"], # B cell marker
    "CD33": ["ENST00000262262", "ENST00000600557", "ENST00000598473"], # Monocyte-derived cell marker
    "CLEC7A": ["ENST00000304084", "ENST00000310002", "ENST00000534609"], # Monocyte-derived cell marker
    "FCGR2A": ["ENST00000699279", "ENST00000271450"], # Monocyte-derived cell marker
    "LILRB4": ["ENST00000695418"], # Monocyte-derived cell marker
    "GP1BA": ["ENST00000329125"], # Megakaryocyte marker
    "ITGA2B": ["ENST00000262407", "ENST00000648408", "ENST00000587295"], # Megakaryocyte marker
    "MPL": ["ENST00000372470"], # Megakaryocyte marker
}

## Build whitelist from adata varnames (expects "GENE:ENSG:ENST")
varnames = pd.Index(adata_i_filtered.var_names.astype(str))

def find_iso_ids(gene, enst):
    """Return the var names that start with ``"<gene>:"`` and end with ``":<enst>"``.

    The middle field of the ID is not constrained. Reads the module-level ``varnames``.
    """
    gene_matches = varnames.str.startswith(f"{gene}:")
    transcript_matches = varnames.str.endswith(f":{enst}")
    return varnames[gene_matches & transcript_matches].tolist()

# Walk the requested transcripts in their configured order, keeping every hit
# and remembering the requests that matched nothing.
whitelist_isoforms, missing = [], []
for gene, ensts in desired_transcripts.items():
    for enst in ensts:
        hits = find_iso_ids(gene, enst)
        if not hits:
            missing.append(f"{gene}:{enst}")
            continue
        whitelist_isoforms.extend(hits)  # several hits keep their var_names order

if not whitelist_isoforms:
    raise ValueError("None of the requested transcripts were found in adata_i_filtered.var_names.")

if missing:
    print("Warning — these requested transcripts were not found and will be skipped:")
    for m in missing:
        print("  ", m)

matched_isoforms = whitelist_isoforms



## Extract Expression matrix (cells × isoforms)
X = adata_i_filtered[:, matched_isoforms].X
if hasattr(X, "toarray"):
    X = X.toarray()  # densify sparse input

iso_expr = pd.DataFrame(
    X,
    index=adata_i_filtered.obs["gen_cell_type"].astype(str),
    columns=matched_isoforms
)



## Per-cluster metrics on isoforms
# mean expression per cluster, and percentage of cells with a value > 0
# NOTE(review): .X was set to the "log_denoised" layer in an earlier cell; if those
# values are strictly positive, every cell counts as expressing — confirm the layer.
avg_expr_iso = iso_expr.groupby(iso_expr.index, observed=True).mean()
pct_expr_iso = (iso_expr > 0).astype(int).groupby(iso_expr.index, observed=True).mean() * 100

# Keep clusters in desired order
present = [ct for ct in celltype_order if ct in avg_expr_iso.index]
if not present:
    # Fail loudly instead of drawing and saving an empty figure.
    raise ValueError("None of the requested cell types in celltype_order are present in gen_cell_type.")
avg_expr_iso = avg_expr_iso.reindex(present)
pct_expr_iso = pct_expr_iso.reindex(present)

# Fixed isoform order = whitelist order present in matrix
ordered_isoforms = [i for i in matched_isoforms if i in avg_expr_iso.columns]

# Reorder matrices
avg_expr_iso = avg_expr_iso[ordered_isoforms]
pct_expr_iso = pct_expr_iso[ordered_isoforms]



## Mean-center per isoform (white = average across clusters for that isoform)
avg_centered = avg_expr_iso - avg_expr_iso.mean(axis=0)

# Short x labels: GENE:last 6 of transcript id (if present)
def short_label(iso):
    """Shorten a three-field ``"A:B:C"`` ID to ``"A:<last 6 characters of C>"``.

    IDs that do not split into exactly three ``":"``-separated fields are
    returned unchanged.
    """
    fields = iso.split(":")
    if len(fields) != 3:
        return iso
    gene, transcript = fields[0], fields[2]
    return f"{gene}:{transcript[-6:]}"

xlabels = [short_label(i) for i in ordered_isoforms]



## Long format for plotting
avg_long = avg_centered.reset_index(names="cell_type").melt(id_vars="cell_type", var_name="iso", value_name="avg_c")
pct_long = pct_expr_iso.reset_index(names="cell_type").melt(id_vars="cell_type", var_name="iso", value_name="pct")
plot_long = avg_long.merge(pct_long, on=["cell_type", "iso"])

# Enforce display order through ordered categoricals, then sort the rows to match.
plot_long["cell_type"] = pd.Categorical(plot_long["cell_type"], categories=present, ordered=True)
plot_long["iso"] = pd.Categorical(plot_long["iso"], categories=ordered_isoforms, ordered=True)
plot_long = plot_long.sort_values(["cell_type", "iso"])



## Create Dot plot
fig, ax = plt.subplots(
    figsize=(max(12, 0.42 * len(ordered_isoforms)), max(2.6, 0.45 * len(present)))
)
size_min, size_max = 18, 380

# Robust symmetric color range around 0 (99th percentile of |value| over finite
# entries), falling back to 1.0 when nothing finite and non-zero is available.
vals = avg_centered.values.astype(float)
if np.all((~np.isfinite(vals)) | (vals == 0)):
    rng = 1.0
else:
    rng = float(np.nanpercentile(np.abs(vals[np.isfinite(vals)]), 99))
    if not np.isfinite(rng) or rng == 0:
        rng = 1.0

norm = TwoSlopeNorm(vcenter=0, vmin=-rng, vmax=rng)

sns.scatterplot(
    data=plot_long,
    x="iso", y="cell_type",
    hue="avg_c", size="pct",
    palette="coolwarm", hue_norm=norm,
    sizes=(size_min, size_max),
    edgecolor="black", linewidth=0.3,
    legend=False, ax=ax
)

# Layout tweaks: padding for big dots, no horizontal gutters.
# NOTE(review): row order follows the category order of `present`; if its first
# entry is meant to be the top row, confirm whether ax.invert_yaxis() is needed.
ax.margins(y=0.12, x=0)
ax.set_xlim(-0.5, len(ordered_isoforms) - 0.5)

# Swap tick labels to the short ones (same order as categories); rotation and
# alignment are set here once.
ax.set_xticks(range(len(ordered_isoforms)))
ax.set_xticklabels(xlabels, rotation=90, ha="right", fontsize=7)



## Right side: colorbar + % expressing legend
divider = make_axes_locatable(ax)

# Colorbar
cax = divider.append_axes("right", size="3%", pad=0.18)
sm = plt.cm.ScalarMappable(cmap="coolwarm", norm=norm)
sm.set_array([])
cbar = fig.colorbar(sm, cax=cax)
cbar.set_label("Mean-centered expression", rotation=90, labelpad=8)

# Size legend to the right of the colorbar
lax = divider.append_axes("right", size="10%", pad=0.28)
lax.set_axis_off(); lax.set_xlim(0, 1); lax.set_ylim(0, 1)
levels = [5, 25, 50, 75, 100]  # keep consistent across plots
lax.text(0.5, 0.98, "% expressing", ha="center", va="top", fontsize=10)
ypos = np.linspace(0.80, 0.20, len(levels))
for y, lv in zip(ypos, levels):
    sz = np.interp(lv, [0, 100], [size_min, size_max])  # legend independent of data range
    # scatter sizes are areas (points^2) while `ms` is a diameter, hence the sqrt
    lax.plot([0.35], [y], 'o', ms=np.sqrt(sz), markeredgecolor='black', markerfacecolor='lightgray')
    lax.text(0.60, y, f"{lv}%", va="center", ha="left", fontsize=10)

# Cosmetics
ax.set_xlabel(""); ax.set_ylabel("")
# Address the main axes explicitly: plt.yticks() acts on whichever axes is
# current, which after the append_axes calls above is no longer `ax`.
ax.tick_params(axis="y", labelrotation=0)



## Save / show
outpath = "Intermediate_Files/Paper_Figs/Markers/celltype_marker_isoform_dotplot_abbreviated.pdf"
os.makedirs(os.path.dirname(outpath), exist_ok=True)
plt.tight_layout(rect=(0.02, 0, 0.86, 1))  # slimmer left, reserve right panel
fig.savefig(outpath, dpi=600, transparent=True, bbox_inches="tight")
plt.show()
print(f"Saved to: {outpath}")

In [None]:
### Full sub-cell-type marker isoform dotplot (not abbreviated)

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable

# --- Config ---
# Keys of `marker_genes` to plot (list order = group rank used for sorting) and
# the obs["sub_cell_type"] labels to show, in display order.
marker_groups_to_include = ["CD4 Effector", "CD8 Effector", "Memory", "Transition"]
celltype_order = [
    "Effector CD4 T cells", "Effector CD8 T cells", "Memory T cells", "Effector-Memory Transition T cells"
]

# --- Build isoform → group mapping & collect isoforms present ---
# Groups are visited in `marker_genes` order; an isoform matched by several
# groups ends up assigned to the last one visited.
gene_to_group = {}
matched_isoforms = []
for group, genes in marker_genes.items():
    if group in marker_groups_to_include:
        matches = find_matching_genes(genes, adata_i_filtered.var_names)  # isoform IDs like "GENE:ENSG:ENST"
        matched_isoforms += matches
        gene_to_group.update(dict.fromkeys(matches, group))

if not matched_isoforms:
    raise ValueError("No marker isoforms found for the requested groups.")

# Deduplicate isoforms while preserving order
matched_isoforms = list(dict.fromkeys(matched_isoforms))

# --- Expression matrix (cells × isoforms) ---
X = adata_i_filtered[:, matched_isoforms].X
if hasattr(X, "toarray"):
    X = X.toarray()  # densify sparse input

# IMPORTANT: subtypes live in sub_cell_type (matches celltype_order labels)
subtype_labels = adata_i_filtered.obs["sub_cell_type"].astype(str)
iso_expr = pd.DataFrame(X, index=subtype_labels, columns=matched_isoforms)

# --- Per-cluster metrics on isoforms ---
# mean value per subtype, and percentage of cells with a value > 0
avg_expr_iso = iso_expr.groupby(iso_expr.index, observed=True).mean()
pct_expr_iso = iso_expr.gt(0).astype(int).groupby(iso_expr.index, observed=True).mean().mul(100)

# --- keep clusters in desired order ---
present = [ct for ct in celltype_order if ct in avg_expr_iso.index]
if not present:
    raise ValueError("None of the requested subtypes in celltype_order are present in sub_cell_type.")
avg_expr_iso = avg_expr_iso.reindex(present)
pct_expr_iso = pct_expr_iso.reindex(present)

# --- Sort isoforms by (group_rank, gene, isoform) ---
isoform_to_gene = {iso: iso.split(":")[0] for iso in avg_expr_iso.columns}
meta = pd.DataFrame({
    "isoform": avg_expr_iso.columns,
    "gene": [isoform_to_gene[i] for i in avg_expr_iso.columns],
    "group": [gene_to_group.get(i, "") for i in avg_expr_iso.columns],
})
# Position of each group in marker_groups_to_include; unknown groups sort last (999).
rank_of_group = {name: position for position, name in enumerate(marker_groups_to_include)}
meta["group_rank"] = meta["group"].map(lambda g: rank_of_group.get(g, 999))
meta = meta.sort_values(["group_rank", "gene", "isoform"])

# Build ordered list from what actually exists (unique, first occurrence wins)
ordered_isoforms = list(dict.fromkeys(meta["isoform"].astype(str)))

# Reorder matrices (safe even if some groups overlap)
avg_expr_iso = avg_expr_iso[ordered_isoforms]
pct_expr_iso = pct_expr_iso[ordered_isoforms]

# --- Mean-center per isoform (white = average across clusters for that isoform) ---
per_isoform_mean = avg_expr_iso.mean(axis=0)
avg_centered = avg_expr_iso.sub(per_isoform_mean, axis="columns")

# short x labels: GENE:last 6 of transcript id (if present)
def short_label(iso):
    """Shorten a three-field ``"A:B:C"`` ID to ``"A:<last 6 characters of C>"``.

    IDs that do not split into exactly three ``":"``-separated fields are
    returned unchanged.
    """
    segments = iso.split(":")
    is_three_part = len(segments) == 3
    return f"{segments[0]}:{segments[-1][-6:]}" if is_three_part else iso
xlabels = [short_label(i) for i in ordered_isoforms]

# --- Long format for plotting ---
avg_long = avg_centered.reset_index(names="cell_type").melt(
    id_vars="cell_type", var_name="iso", value_name="avg_c"
)
pct_long = pct_expr_iso.reset_index(names="cell_type").melt(
    id_vars="cell_type", var_name="iso", value_name="pct"
)
plot_long = avg_long.merge(pct_long, on=["cell_type", "iso"], how="inner")

# Ensure types/uniqueness BEFORE categoricals (prevents empty merges & warnings)
plot_long["iso"] = plot_long["iso"].astype(str)
ordered_isoforms = pd.Index([str(i) for i in ordered_isoforms]).unique().tolist()

# Enforce display order
plot_long["cell_type"] = pd.Categorical(plot_long["cell_type"], categories=present, ordered=True)
plot_long["iso"]       = pd.Categorical(plot_long["iso"], categories=ordered_isoforms, ordered=True)

# IMPORTANT: drop rows with categories not present (avoids empty means)
plot_long = plot_long.dropna(subset=["cell_type", "iso"]).sort_values(["cell_type", "iso"])

# --- Dot plot ---
fig, ax = plt.subplots(figsize=(max(12, 0.42*len(ordered_isoforms)), max(2.5, 0.45*len(present))))
size_min, size_max = 18, 380

# Symmetric colour range around 0 from the 99th percentile of |value|.
# Only finite values are used and an empty or all-zero matrix falls back to 1.0,
# so the norm never receives NaN or a zero-width range.
vals = avg_centered.values.astype(float)
finite_abs = np.abs(vals[np.isfinite(vals)])
rng = float(np.percentile(finite_abs, 99)) if finite_abs.size else 0.0
if rng == 0:
    rng = 1.0  # fallback
norm = TwoSlopeNorm(vcenter=0, vmin=-rng, vmax=rng)

sns.scatterplot(
    data=plot_long,
    x="iso", y="cell_type",
    hue="avg_c", size="pct",
    palette="coolwarm", hue_norm=norm,
    sizes=(size_min, size_max),
    edgecolor="black", linewidth=0.3,
    legend=False, ax=ax
)

# layout: padding for big dots, no horizontal gutters
ax.margins(y=0.12, x=0)
ax.set_xlim(-0.5, len(ordered_isoforms) - 0.5)

# x ticks: set ticks first, then labels (short form)
ax.set_xticks(np.arange(len(ordered_isoforms)))
ax.set_xticklabels(xlabels, rotation=90, ha="right", fontsize=7)

# --- Right side: colorbar + % expressing legend ---
divider = make_axes_locatable(ax)

cax = divider.append_axes("right", size="1%", pad=0.0)
sm = plt.cm.ScalarMappable(cmap="coolwarm", norm=norm)
sm.set_array([])
cbar = fig.colorbar(sm, cax=cax)
cbar.set_label("Mean-centered expression", rotation=90, labelpad=8)

lax = divider.append_axes("right", size="10%", pad=0.28)
lax.set_axis_off(); lax.set_xlim(0, 1); lax.set_ylim(0, 1)
levels = [5, 25, 50, 75, 100]  # fixed legend scale
lax.text(0.5, 0.98, "% expressing", ha="center", va="top", fontsize=10)
ypos = np.linspace(0.80, 0.20, len(levels))
for y, lv in zip(ypos, levels):
    sz = np.interp(lv, [0, 100], [size_min, size_max])
    # scatter sizes are areas (points^2) while `ms` is a diameter, hence the sqrt
    lax.plot([0.35], [y], 'o', ms=np.sqrt(sz), markeredgecolor='black', markerfacecolor='lightgray')
    lax.text(0.60, y, f"{lv}%", va="center", ha="left", fontsize=10)

# Cosmetics
ax.set_xlabel(""); ax.set_ylabel("")
# Address the main axes explicitly: plt.yticks() acts on whichever axes is
# current, which after the append_axes calls above is no longer `ax`.
ax.tick_params(axis="y", labelrotation=0)

plt.tight_layout(rect=(0, 0, 0.88, 1))
out = "Intermediate_Files/Paper_Figs/Markers/subcelltype_marker_isoform_dotplot.pdf"
os.makedirs(os.path.dirname(out), exist_ok=True)  # savefig does not create directories
fig.savefig(out, dpi=600, transparent=True, bbox_inches="tight")
plt.show()
print(f"[DONE] Saved: {out}")

In [None]:
### Dotplots for all isoforms in general cell-type marker genes (Supplemental Figure S3)

import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable

## Config
# Keys of `marker_genes` to plot (one figure each) and the obs["gen_cell_type"]
# labels to show, in display order.
marker_groups_to_include = ["T cells", "Natural Killer", "B cell", "Monocyte-derived", "Megakaryocyte"]
celltype_order           = ["T cells", "NK cells", "B cells", "Monocyte-derived", "Megakaryocytes"]
OUTDIR = "Intermediate_Files/Paper_Figs/Markers/dotplots_by_group"
os.makedirs(OUTDIR, exist_ok=True)

## Collect isoforms present & group mapping
gene_to_group = {}
matched_isoforms = []
for group, genes in marker_genes.items():
    if group not in marker_groups_to_include:
        continue
    m = find_matching_genes(genes, adata_i_filtered.var_names)  # isoforms like "GENE:ENSG:ENST"
    matched_isoforms.extend(m)
    for iso in m:
        gene_to_group[iso] = group

if not matched_isoforms:
    raise ValueError("No marker isoforms found for the requested groups.")

# Deduplicate while preserving order: an isoform matched through two groups would
# otherwise yield duplicated columns and non-unique plot categories further down.
matched_isoforms = list(dict.fromkeys(matched_isoforms))

## Extract Expression matrix (cells × isoforms)
X = adata_i_filtered[:, matched_isoforms].X
if hasattr(X, "toarray"):
    X = X.toarray()  # densify sparse input

iso_expr = pd.DataFrame(
    X,
    index=adata_i_filtered.obs["gen_cell_type"].astype(str),
    columns=matched_isoforms
)

# Per-cluster metrics: mean value, and percentage of cells with a value > 0
# NOTE(review): .X was set to the "log_denoised" layer in an earlier cell; if those
# values are strictly positive, every cell counts as expressing — confirm the layer.
avg_expr_all = iso_expr.groupby(iso_expr.index, observed=True).mean()
pct_expr_all = (iso_expr > 0).astype(int).groupby(iso_expr.index, observed=True).mean() * 100

# Keep clusters in desired order
present = [ct for ct in celltype_order if ct in avg_expr_all.index]
if not present:
    # Fail loudly instead of drawing and saving empty figures.
    raise ValueError("None of the requested cell types in celltype_order are present in gen_cell_type.")
avg_expr_all = avg_expr_all.reindex(present)
pct_expr_all = pct_expr_all.reindex(present)

# Short label helper: "GENE:...:ENST000123" -> "GENE:000123"
def short_label(iso):
    """Shorten a three-field ``"A:B:C"`` ID to ``"A:<last 6 characters of C>"``.

    IDs that do not split into exactly three ``":"``-separated fields are
    returned unchanged.
    """
    tokens = iso.split(":")
    if len(tokens) != 3:
        return iso
    head, tail = tokens[0], tokens[2]
    return f"{head}:{tail[-6:]}"

## Create One dot plot per marker group
for grp in marker_groups_to_include:
    iso_in_grp = [iso for iso in avg_expr_all.columns if gene_to_group.get(iso) == grp]
    if not iso_in_grp:
        print(f"[{grp}] no isoforms found, skipping.")
        continue

    # order columns by (gene, isoform)
    iso_to_gene = {iso: iso.split(":")[0] for iso in iso_in_grp}
    iso_in_grp = sorted(iso_in_grp, key=lambda x: (iso_to_gene[x], x))

    # subset & center (per isoform, across the displayed cell types)
    avg_expr = avg_expr_all[iso_in_grp]
    pct_expr = pct_expr_all[iso_in_grp]
    avg_centered = avg_expr - avg_expr.mean(axis=0)

    # long format
    avg_long = avg_centered.reset_index(names="cell_type").melt(id_vars="cell_type", var_name="iso", value_name="avg_c")
    pct_long = pct_expr.reset_index(names="cell_type").melt(id_vars="cell_type", var_name="iso", value_name="pct")
    plot_long = avg_long.merge(pct_long, on=["cell_type", "iso"])
    plot_long["cell_type"] = pd.Categorical(plot_long["cell_type"], categories=present, ordered=True)
    plot_long["iso"]       = pd.Categorical(plot_long["iso"], categories=iso_in_grp, ordered=True)
    plot_long = plot_long.sort_values(["cell_type", "iso"])

    # color scaling (per-group), symmetric around 0: the largest finite |value|,
    # falling back to 1.0 when nothing finite and non-zero is available
    vals = avg_centered.values.astype(float)
    finite_abs = np.abs(vals[np.isfinite(vals)])
    rng = float(finite_abs.max()) if finite_abs.size else 0.0
    if rng == 0:
        rng = 1.0
    norm = TwoSlopeNorm(vcenter=0, vmin=-rng, vmax=rng)

    # figure size proportional to #isoforms
    fig, ax = plt.subplots(figsize=(max(10.0, 0.55*len(iso_in_grp)), max(3.0, 0.55*len(present))))
    size_min, size_max = 18, 380

    sns.scatterplot(
        data=plot_long,
        x="iso", y="cell_type",
        hue="avg_c", size="pct",
        palette="coolwarm", hue_norm=norm,
        sizes=(size_min, size_max),
        edgecolor="black", linewidth=0.3,
        legend=False, ax=ax
    )

    # Layout: no x-gutters, vertical padding for big dots.
    # NOTE(review): row order follows the category order of `present`; if its first
    # entry is meant to be the top row, confirm whether ax.invert_yaxis() is needed.
    ax.margins(y=0.12, x=0)
    ax.set_xlim(-0.5, len(iso_in_grp) - 0.5)

    # Fixed ticks + vertical (90°) labels; labels are thinned to at most
    # `max_visible` by blanking all but every `step`-th one
    ax.set_xticks(np.arange(len(iso_in_grp)))
    xt = [short_label(i) for i in iso_in_grp]
    max_visible = 90
    step = max(1, int(np.ceil(len(xt) / max_visible)))
    xt = [lab if (k % step == 0) else "" for k, lab in enumerate(xt)]
    ax.set_xticklabels(xt, rotation=90, ha="right", fontsize=6)

    ax.set_xlabel(""); ax.set_ylabel("")

    # Right-side panels: colorbar + size legend
    divider = make_axes_locatable(ax)

    # Colorbar
    cax = divider.append_axes("right", size="3%", pad=0.1)
    sm = plt.cm.ScalarMappable(cmap="coolwarm", norm=norm); sm.set_array([])
    cbar = fig.colorbar(sm, cax=cax)
    cbar.set_label("Mean-centered expression", rotation=270, labelpad=9)

    # Size legend
    lax = divider.append_axes("right", size="10%", pad=0.70)
    lax.set_axis_off(); lax.set_xlim(0, 1); lax.set_ylim(0, 1)
    levels = [5, 25, 50, 75, 100]  # fixed legend scale
    lax.text(0.5, 0.99, "% expressing", ha="center", va="top", fontsize=9)
    ypos = np.linspace(0.80, 0.20, len(levels))
    for y, lv in zip(ypos, levels):
        sz = np.interp(lv, [0, 100], [size_min, size_max])  # legend independent of data
        # scatter sizes are areas (points^2) while `ms` is a diameter, hence the sqrt
        lax.plot([0.35], [y], 'o', ms=np.sqrt(sz), markeredgecolor='black', markerfacecolor='lightgray')
        lax.text(0.60, y, f"{lv}%", va="center", ha="left", fontsize=9)

    # Reserve margin for side panels; keep left tight
    fig.tight_layout(rect=(0.02, 0, 0.80, 1))

    out = os.path.join(OUTDIR, f"celltype_marker_isoform_dotplot__{grp.replace(' ', '_')}.pdf")
    fig.savefig(out, dpi=600, transparent=True, bbox_inches="tight")
    plt.show()
    # Release the figure so figures do not accumulate across loop iterations.
    plt.close(fig)
    print(f"Saved: {out}")

In [None]:
### Dotplots for all isoforms in sub-T cell marker genes (Supplemental Figure S6a-d)

import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable

## Helper functions
def find_matching_genes(prefixes, gene_list):
    """Return the entries of ``gene_list`` that start with any of ``prefixes``.

    The order of ``gene_list`` is preserved and each entry appears at most
    once per occurrence in ``gene_list``, even if it matches several prefixes.
    """
    prefix_tuple = tuple(prefixes)  # str.startswith accepts a tuple of candidates
    matches = []
    for name in gene_list:
        if name.startswith(prefix_tuple):
            matches.append(name)
    return matches

# Marker gene prefixes per marker group. A trailing ":" makes the prefix match
# the whole gene-symbol field of a variable name; entries written without it
# are fixed up by the normalization step below.
marker_genes = {
    "T cells": [
        "CD3D:", "CD3E:", "CD3G:",
    ],
    "CD8 Effector": [
        "CD8A:", "CD8B:", "GATA3:", "KLRB1:", "CCL5:",
    ],
    "CD4 Effector": [
        "CD4:", "IL2RA:", "GATA3:", "AHR:", "TNF:",
    ],
    "Memory": [
        "CCR7:", "SELL:", "TCF7:",
    ],
    "Transition": [
        "ITGAE:", "LEF1:", "TCF7:", "IL7R:", "GATA3:", "IL2RA", "CTLA4", "CD27:",
    ],
}

# Same mapping with every prefix guaranteed to end in ":".
normalized_marker_genes = {
    group_name: [sym if sym[-1:] == ":" else sym + ":" for sym in symbols]
    for group_name, symbols in marker_genes.items()
}

## Config (T cell subgroups)
# Marker groups to draw, one dot plot per group, in this order.
marker_groups_to_include = ["T cells", "CD4 Effector", "CD8 Effector", "Memory", "Transition"]
# Row (y-axis) order of the T-cell subtypes. Every label must be unique: the list
# is later used as the categories of a pandas Categorical, which rejects duplicates.
tcell_order = [
    "Effector CD4 T cells",
    "Effector CD8 T cells",
    "Memory T cells",
    "Effector-Memory Transition T cells",
]
# Output directory for the per-group PDFs (created if missing).
OUTDIR = "Intermediate_Files/Paper_Figs/Markers/dotplots_by_Tcell_subgroups"
os.makedirs(OUTDIR, exist_ok=True)

## Collect union of isoforms and build matrix
matched_isoforms = []
for grp in marker_groups_to_include:
    matched_isoforms += find_matching_genes(normalized_marker_genes[grp], adata_i_filtered.var_names)
# drop repeats (a gene may belong to several groups) while keeping first-seen order
matched_isoforms = list(dict.fromkeys(matched_isoforms))

if len(matched_isoforms) == 0:
    raise ValueError("No marker isoforms found for the requested groups.")

# Expression values for the matched isoforms, densified when the matrix exposes .toarray()
X = adata_i_filtered[:, matched_isoforms].X
X = X.toarray() if hasattr(X, "toarray") else X

# One row per cell, labelled by its sub_cell_type so rows can be grouped by subtype
iso_expr = pd.DataFrame(
    data=X,
    index=adata_i_filtered.obs["sub_cell_type"].astype(str),
    columns=matched_isoforms,
)

# Per-subtype metrics: mean expression and percent of cells with expression > 0
avg_expr_all = iso_expr.groupby(iso_expr.index, observed=True).mean()
pct_expr_all = iso_expr.gt(0).astype(int).groupby(iso_expr.index, observed=True).mean().mul(100)

# Keep only the requested subtypes that occur in the data, in the requested order
present = [label for label in tcell_order if label in avg_expr_all.index]
if len(present) == 0:
    raise ValueError("None of tcell_order present in sub_cell_type.")
avg_expr_all, pct_expr_all = (
    avg_expr_all.reindex(present),
    pct_expr_all.reindex(present),
)

# Short label helper
def short_label(iso):
    """Shorten a three-field ``a:b:c`` identifier to ``a:<last 6 chars of c>``.

    Identifiers that do not split into exactly three ":"-separated fields are
    returned unchanged.
    """
    fields = iso.split(":")
    if len(fields) != 3:
        return iso
    return f"{fields[0]}:{fields[2][-6:]}"

## One dot plot per marker group (forced genes, all isoforms)
# Draw and save one dot plot per marker group: x = every isoform of the group's
# genes, y = T-cell subtype, color = mean-centered average expression,
# dot size = percent of cells with expression > 0.
for grp in marker_groups_to_include:
    prefixes = normalized_marker_genes.get(grp, [])
    if not prefixes:
        print(f"[{grp}] no marker list; skipping panel.")
        continue

    # find isoforms for this group directly from prefixes
    grp_iso_all = [c for c in avg_expr_all.columns if any(c.startswith(p) for p in prefixes)]

    # collect ALL isoforms for each listed gene, in the exact gene order provided
    iso_in_grp = []
    for prefix in prefixes:
        gene_sym = prefix[:-1]  # strip the trailing ":" to get the bare gene symbol
        g_isos = [iso for iso in grp_iso_all if iso.split(":")[0] == gene_sym]
        if not g_isos:
            # warn only; the gene simply contributes no columns to this panel
            print(f"[WARN] {grp}: no isoforms found for required gene '{gene_sym}'")
        # stable order by transcript id token if present
        g_isos = sorted(g_isos, key=lambda s: s.split(":")[-1])
        iso_in_grp.extend(g_isos)

    # dedupe while preserving order (prevents "Categorical categories must be unique")
    iso_in_grp = list(dict.fromkeys(iso_in_grp))
    if not iso_in_grp:
        print(f"[{grp}] no isoforms available for any required genes; skipping.")
        continue

    # subset & center: subtract each isoform's mean across the plotted subtypes
    avg_expr = avg_expr_all[iso_in_grp]
    pct_expr = pct_expr_all[iso_in_grp]
    avg_centered = avg_expr - avg_expr.mean(axis=0)

    # long format: one row per (cell_type, iso) carrying both metrics
    avg_long = avg_centered.reset_index(names="cell_type").melt(id_vars="cell_type", var_name="iso", value_name="avg_c")
    pct_long = pct_expr.reset_index(names="cell_type").melt(id_vars="cell_type", var_name="iso", value_name="pct")
    plot_long = avg_long.merge(pct_long, on=["cell_type", "iso"])

    # enforce top→bottom order; NO invert_yaxis later
    # (requires the labels in `present` and `iso_in_grp` to be unique)
    plot_long["cell_type"] = pd.Categorical(plot_long["cell_type"], categories=present, ordered=True)
    plot_long["iso"]       = pd.Categorical(plot_long["iso"], categories=iso_in_grp, ordered=True)
    plot_long = plot_long.sort_values(["cell_type", "iso"])

    # symmetric color range: the 100th percentile is the maximum absolute centered
    # value (no outlier clipping); a range of exactly 0 falls back to 1.0
    rng = float(np.nanpercentile(np.abs(avg_centered.values), 100)) or 1.0
    norm = TwoSlopeNorm(vcenter=0, vmin=-rng, vmax=rng)

    # figure size proportional to isoforms (width) and subtypes (height), with minimums
    fig, ax = plt.subplots(figsize=(max(10.0, 0.65 * len(iso_in_grp)), max(2.0, 0.6 * len(present))))
    size_min, size_max = 18, 380  # dot-size range passed to seaborn `sizes`

    sns.scatterplot(
        data=plot_long,
        x="iso", y="cell_type",
        hue="avg_c", size="pct",
        palette="coolwarm", hue_norm=norm,
        sizes=(size_min, size_max),
        edgecolor="black", linewidth=0.3,
        legend=False, ax=ax
    )

    # layout: margins are requested first, then the x-limits are set explicitly
    ax.margins(y=0.15, x=0.3)
    ax.set_xlim(-0.5, len(iso_in_grp) - 0.5)

    # x tick labels: one tick per isoform; when there are more than max_visible
    # isoforms only every `step`-th label is kept, the rest are blanked
    ax.set_xticks(np.arange(len(iso_in_grp)))
    xt = [short_label(i) for i in iso_in_grp]
    max_visible = 90
    step = max(1, int(np.ceil(len(xt) / max_visible)))
    xt = [lab if (k % step == 0) else "" for k, lab in enumerate(xt)]
    ax.set_xticklabels(xt, rotation=90, ha="right", fontsize=6)

    ax.set_xlabel(""); ax.set_ylabel("")

    # right-side colorbar + size legend, both appended to the main axes
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="3%", pad=0.1)
    # the mappable only carries cmap + norm for the colorbar; it holds no data
    sm = plt.cm.ScalarMappable(cmap="coolwarm", norm=norm); sm.set_array([])
    cbar = fig.colorbar(sm, cax=cax)
    cbar.set_label("Mean-centered expression", rotation=90, labelpad=9)

    # hand-drawn "% expressing" legend on a fixed 0-100 scale
    lax = divider.append_axes("right", size="10%", pad=0.70)
    lax.set_axis_off(); lax.set_xlim(0, 1); lax.set_ylim(0, 1)
    levels = [5, 25, 50, 75, 100]
    lax.text(0.5, 0.99, "% expressing", ha="center", va="top", fontsize=9)
    ypos = np.linspace(0.80, 0.20, len(levels))
    for y, lv in zip(ypos, levels):
        # legend sizes interpolate over a fixed 0-100 range, independent of the data
        sz = np.interp(lv, [0, 100], [size_min, size_max])
        # sqrt converts the interpolated size into the marker-size value given to `ms`
        lax.plot([0.35], [y], 'o', ms=np.sqrt(sz), markeredgecolor='black', markerfacecolor='lightgray')
        lax.text(0.60, y, f"{lv}%", va="center", ha="left", fontsize=9)

    plt.tight_layout(rect=(0.02, 0, 0.80, 1))
    # one PDF per marker group; spaces in the group name become underscores
    out = os.path.join(OUTDIR, f"Tcell_subtype_isoform_dotplot__{grp.replace(' ', '_')}.pdf")
    fig.savefig(out, dpi=600, transparent=True, bbox_inches="tight")
    plt.show()
    print(f"Saved: {out}")

In [None]:
### Abbreviated isoform Dot plots for sub-T cell isoform markers (Figure 5c)
## Transition markers are excluded in this plot

import os
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable

## Config (T cell subtypes)
# Row (y-axis) order of the T-cell subtypes. Labels must be unique because the
# list is later used as the categories of a pandas Categorical.
celltype_order = [
    "Effector CD4 T cells",
    "Effector CD8 T cells",
    "Memory T cells",
    "Effector-Memory Transition T cells",
]

# Left-to-right gene order on the x-axis. Genes are looked up in
# `desired_transcripts` in this order; a gene absent from that dict adds nothing.
gene_display_order = ["GATA3","CD4","IL2RA","AHR","TNF","CD8A","CD8B","CCL5","KLRB1","CCR7","SELL","TCF7"]

# Full ENST IDs to plot for each gene. The order inside each list is the display
# order within that gene; the order of the dict keys themselves is not used.
desired_transcripts = {
    "GATA3": ["ENST00000379328", "ENST00000346208", "ENST00000461472"],
    "IL2RA": ["ENST00000379959", "ENST00000649218", "ENST00000379954"],
    "CD4":   ["ENST00000011653", "ENST00000544344"],
    "AHR":   ["ENST00000242057", "ENST00000481944"],
    "TNF":   ["ENST00000449264"],
    "CD8A":  ["ENST00000283635", "ENST00000352580", "ENST00000699436"],
    "CD8B":  ["ENST00000390655", "ENST00000393761", "ENST00000393759", "ENST00000431506"],
    "CCL5":  ["ENST00000605140", "ENST00000651122"],
    "KLRB1": ["ENST00000229402"],
    "CCR7":  ["ENST00000579344", "ENST00000246657", "ENST00000578085"],
    "SELL":  ["ENST00000650983", "ENST00000236147", "ENST00000463108"],
    "TCF7":  ["ENST00000524342", "ENST00000395023", "ENST00000378560", "ENST00000520958"],
}

## Build whitelist from adata varnames (expects "GENE:ENSG:ENST")

# All variable names as strings; the lookups below expect "GENE:ENSG:ENST"
varnames = pd.Index(adata_i_filtered.var_names.astype(str))
whitelist_isoforms = []
missing = []

def find_iso_ids(gene, enst):
    """Return the names in ``varnames`` that start with ``gene:`` and end with ``:enst``.

    The middle (ENSG) field is not constrained.
    """
    gene_hit = varnames.str.startswith(f"{gene}:")
    enst_hit = varnames.str.endswith(f":{enst}")
    return varnames[gene_hit & enst_hit].tolist()

# Walk genes in display order, then transcripts in their listed order, so the
# whitelist comes out already sorted for plotting
for gene in gene_display_order:
    for enst in desired_transcripts.get(gene, []):
        hits = find_iso_ids(gene, enst)
        if not hits:
            missing.append(f"{gene}:{enst}")
            continue
        whitelist_isoforms += hits

if len(whitelist_isoforms) == 0:
    raise ValueError("None of the requested transcripts were found in adata_i_filtered.var_names.")

if len(missing) > 0:
    print("Warning — these requested transcripts were not found and will be skipped:")
    for m in missing:
        print("  ", m)

matched_isoforms = whitelist_isoforms

## Extract Expression matrix (cells × isoforms); INDEX by sub_cell_type
# Per-cell subtype labels as plain strings
ct_series = adata_i_filtered.obs["sub_cell_type"].astype(str).copy()

# Expression for the whitelisted isoforms, densified when the matrix exposes .toarray()
X = adata_i_filtered[:, matched_isoforms].X
X = X.toarray() if hasattr(X, "toarray") else X

# Rows are cells labelled by subtype, columns are isoforms
iso_expr = pd.DataFrame(data=X, index=ct_series.values, columns=matched_isoforms)


## Per-subcluster metrics on isoforms: mean expression and percent of cells > 0
avg_expr_iso = iso_expr.groupby(iso_expr.index, observed=True).mean()
pct_expr_iso = iso_expr.gt(0).astype(int).groupby(iso_expr.index, observed=True).mean().mul(100)

# Requested subtypes that actually occur, in the requested order
present = [label for label in celltype_order if label in avg_expr_iso.index]
if len(present) == 0:
    raise ValueError("None of the expected T-cell subclusters are present in sub_cell_type.")

# Fixed isoform order = whitelist order present in matrix (already in gene order)
ordered_isoforms = [iso_id for iso_id in matched_isoforms if iso_id in avg_expr_iso.columns]

# Rows in subtype order, columns in isoform order
avg_expr_iso = avg_expr_iso.reindex(present)[ordered_isoforms]
pct_expr_iso = pct_expr_iso.reindex(present)[ordered_isoforms]


## Mean-center per isoform (white = average across subclusters)
avg_centered = avg_expr_iso.sub(avg_expr_iso.mean(axis=0), axis="columns")

# Short x labels: GENE:last 6 of ENST
def short_label(iso):
    """Return ``GENE:<last 6 chars of the third field>`` for a three-field id.

    Ids that do not consist of exactly three ":"-separated fields are returned
    unchanged. A third field shorter than six characters is kept whole.
    """
    pieces = iso.split(":")
    if len(pieces) != 3:
        return iso
    symbol, transcript = pieces[0], pieces[2]
    return symbol + ":" + transcript[-6:]

# Tick labels, one per isoform, in plotting order
xlabels = list(map(short_label, ordered_isoforms))



## Long format for plotting: one row per (cell_type, iso) with both metrics
avg_long = (
    avg_centered
    .reset_index(names="cell_type")
    .melt(id_vars="cell_type", var_name="iso", value_name="avg_c")
)
pct_long = (
    pct_expr_iso
    .reset_index(names="cell_type")
    .melt(id_vars="cell_type", var_name="iso", value_name="pct")
)
plot_long = pd.merge(avg_long, pct_long, on=["cell_type", "iso"])
# Ordered categoricals pin the axis order to `present` / `ordered_isoforms`
plot_long["cell_type"] = pd.Categorical(plot_long["cell_type"], categories=present, ordered=True)
plot_long["iso"] = pd.Categorical(plot_long["iso"], categories=ordered_isoforms, ordered=True)
plot_long = plot_long.sort_values(by=["cell_type", "iso"])

## Create Dot plot
# More horizontal space per isoform: bump width multiplier from 0.42 -> 0.60
# Less vertical space between cell types: drop per-row height 0.45 -> 0.32 and shrink y-margins

# Figure size scales with the number of isoforms (width) and subtypes (height),
# with minimums of 12 x 2.4
fig, ax = plt.subplots(
    figsize=(max(12, 0.60 * len(ordered_isoforms)), max(2.4, 0.32 * len(present)))
)

size_min, size_max = 18, 380  # dot-size range passed to seaborn `sizes`

# Symmetric color range around 0: the 99th percentile of the finite absolute
# centered values, falling back to 1.0 when everything is zero / non-finite
vals = avg_centered.values.astype(float)
if np.all((~np.isfinite(vals)) | (vals == 0)):
    rng = 1.0
else:
    rng = float(np.nanpercentile(np.abs(vals[np.isfinite(vals)]), 99))
    if not np.isfinite(rng) or rng == 0:
        rng = 1.0
norm = TwoSlopeNorm(vcenter=0, vmin=-rng, vmax=rng)

# color = mean-centered average expression, size = percent of cells expressing
sns.scatterplot(
    data=plot_long,
    x="iso", y="cell_type",
    hue="avg_c", size="pct",
    palette="coolwarm", hue_norm=norm,
    sizes=(size_min, size_max),
    edgecolor="black", linewidth=0.3,
    legend=False, ax=ax
)

# margins are requested first, then both axis limits are set explicitly
ax.margins(y=0.03, x=0.02)                     
ax.set_xlim(-0.5, len(ordered_isoforms) - 0.5)
ax.set_ylim(-0.5, len(present) - 0.5)

# one tick per isoform, labelled with the shortened ids
ax.set_xticks(range(len(ordered_isoforms)))
ax.set_xticklabels(xlabels, rotation=90, ha="right", fontsize=7)
# re-apply the existing y tick labels only to change their font size
ax.set_yticklabels(ax.get_yticklabels(), fontsize=9)


## Right side: colorbar + % expressing legend
divider = make_axes_locatable(ax)

cax = divider.append_axes("right", size="3%", pad=0.12)  
# the mappable only carries cmap + norm for the colorbar; it holds no data
sm = plt.cm.ScalarMappable(cmap="coolwarm", norm=norm)
sm.set_array([])
cbar = fig.colorbar(sm, cax=cax)
cbar.set_label("Mean-centered expression", rotation=90, labelpad=8)

# hand-drawn "% expressing" legend on a fixed 0-100 scale
lax = divider.append_axes("right", size="9%", pad=0.20)   
lax.set_axis_off(); lax.set_xlim(0, 1); lax.set_ylim(0, 1)
levels = [5, 25, 50, 75, 100]
lax.text(0.5, 0.98, "% expressing", ha="center", va="top", fontsize=10)
ypos = np.linspace(0.80, 0.20, len(levels))
for y, lv in zip(ypos, levels):
    # legend sizes interpolate over a fixed 0-100 range, independent of the data
    sz = np.interp(lv, [0, 100], [size_min, size_max])
    # sqrt converts the interpolated size into the marker-size value given to `ms`
    lax.plot([0.32], [y], 'o', ms=np.sqrt(sz), markeredgecolor='black', markerfacecolor='lightgray')
    lax.text(0.55, y, f"{lv}%", va="center", ha="left", fontsize=10)

ax.set_xlabel(""); ax.set_ylabel("")
plt.yticks(rotation=0)


### Save / show
outpath = "Intermediate_Files/Paper_Figs/Markers/subcelltype_TCells_marker_isoform_dotplot_abbreviated.pdf"
os.makedirs(os.path.dirname(outpath), exist_ok=True)

# Give right-side guides room, but keep plot compact vertically
plt.tight_layout(rect=(0.02, 0.0, 0.90, 1.0))  
plt.savefig(outpath, dpi=600, transparent=True, bbox_inches="tight")
plt.show()
print(f"Saved to: {outpath}")

In [None]:
### Abbreviated isoform Dot plots for transition T cell isoform markers (Figure 5d)

import os
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable

## Config (T cell subtypes)
# Row (y-axis) order of the T-cell subtypes
celltype_order = [
    "Effector CD4 T cells",
    "Effector CD8 T cells",
    "Memory T cells",
    "Effector-Memory Transition T cells",
]

# Left-to-right gene order on the x-axis (transition markers only)
gene_display_order = [
                      "CTLA4","CD27","LEF1","ITGAE"]

# Full ENST IDs to plot for each gene, hard-coded here; the order inside each
# list is the display order within that gene.
desired_transcripts = {
    "CTLA4": ["ENST00000648405"],
    "CD27":  ["ENST00000266557", "ENST00000541233"],
    "LEF1":  ["ENST00000265165", "ENST00000379951", "ENST00000505328"],
    "ITGAE": ["ENST00000572179", "ENST00000570360", "ENST00000263087"],
}


## Build whitelist from adata varnames (expects "GENE:ENSG:ENST")
# All variable names as strings; the lookups below expect "GENE:ENSG:ENST"
varnames = pd.Index(adata_i_filtered.var_names.astype(str))
whitelist_isoforms = []
missing = []

def find_iso_ids(gene, enst):
    """List the names in ``varnames`` with prefix ``gene:`` and suffix ``:enst``.

    Whatever sits between the two (the ENSG field) is accepted as-is.
    """
    wanted = varnames.str.startswith(f"{gene}:") & varnames.str.endswith(f":{enst}")
    return list(varnames[wanted])

# Gene order first, then the listed transcript order within each gene
for gene in gene_display_order:
    for enst in desired_transcripts.get(gene, []):
        hits = find_iso_ids(gene, enst)
        if len(hits) > 0:
            whitelist_isoforms += hits
        else:
            missing += [f"{gene}:{enst}"]

if len(whitelist_isoforms) == 0:
    raise ValueError("None of the requested transcripts were found in adata_i_filtered.var_names.")

if len(missing) > 0:
    print("Warning — these requested transcripts were not found and will be skipped:")
    for m in missing:
        print("  ", m)

matched_isoforms = whitelist_isoforms



## Expression matrix (cells × isoforms); enforce Y order now
# Make sure sub_cell_type follows the specified order
# Requested subtypes that occur in obs["sub_cell_type"], in the requested order
present = list(filter(
    adata_i_filtered.obs["sub_cell_type"].unique().tolist().__contains__,
    celltype_order,
))
if len(present) == 0:
    raise ValueError("None of the expected T-cell subclusters are present in sub_cell_type.")

# Per-cell subtype labels as an ordered categorical restricted to `present`
ct_series = pd.Categorical(
    adata_i_filtered.obs["sub_cell_type"].astype(str),
    categories=present,
    ordered=True,
)

# Expression for the whitelisted isoforms, densified when the matrix exposes .toarray()
X = adata_i_filtered[:, matched_isoforms].X
X = X.toarray() if hasattr(X, "toarray") else X
iso_expr = pd.DataFrame(data=X, index=ct_series, columns=matched_isoforms)



## Per-subcluster metrics on isoforms: mean expression and percent of cells > 0
avg_expr_iso = iso_expr.groupby(iso_expr.index, observed=True).mean()
pct_expr_iso = iso_expr.gt(0).astype(int).groupby(iso_expr.index, observed=True).mean().mul(100)

# Fixed isoform order = whitelist order present in matrix (already GENE→ENST order)
ordered_isoforms = [iso_id for iso_id in matched_isoforms if iso_id in avg_expr_iso.columns]
# Rows in subtype order, columns in isoform order
avg_expr_iso, pct_expr_iso = (
    avg_expr_iso.loc[present, ordered_isoforms],
    pct_expr_iso.loc[present, ordered_isoforms],
)



## Mean-center per isoform (white = average across subclusters)
avg_centered = avg_expr_iso.sub(avg_expr_iso.mean(axis=0), axis="columns")

# Short x labels: GENE:last 6 of ENST
def short_label(iso):
    """Abbreviate ``GENE:ENSG:ENST`` to ``GENE:<last 6 chars of ENST>``.

    Anything that is not exactly three ":"-separated fields comes back as-is.
    """
    try:
        gene, _, enst = iso.split(":")
    except ValueError:
        # not exactly three fields
        return iso
    return ":".join((gene, enst[-6:]))

# Tick labels, one per isoform, in plotting order
xlabels = list(map(short_label, ordered_isoforms))

## Long format for plotting: one row per (cell_type, iso) with both metrics
avg_long = (
    avg_centered
    .reset_index(names="cell_type")
    .melt(id_vars="cell_type", var_name="iso", value_name="avg_c")
)
pct_long = (
    pct_expr_iso
    .reset_index(names="cell_type")
    .melt(id_vars="cell_type", var_name="iso", value_name="pct")
)
plot_long = pd.merge(avg_long, pct_long, on=["cell_type", "iso"])

# Ordered categoricals pin the axis order to `present` / `ordered_isoforms`
plot_long["cell_type"] = pd.Categorical(plot_long["cell_type"], categories=present, ordered=True)
plot_long["iso"] = pd.Categorical(plot_long["iso"], categories=ordered_isoforms, ordered=True)
plot_long = plot_long.sort_values(by=["cell_type", "iso"], kind="stable")



## Create Dot plot
# Figure size scales with the number of isoforms (width) and subtypes (height),
# with minimums of 8 x 2.4
fig, ax = plt.subplots(
    figsize=(max(8, 0.28 * len(ordered_isoforms)), max(2.4, 0.32 * len(present)))
)

size_min, size_max = 18, 380  # dot-size range passed to seaborn `sizes`

# Symmetric color range computed from THIS figure's centered values, so the cell
# does not depend on a `norm` left behind by a previously executed cell.
# 99th percentile of the finite absolute values; 1.0 when that is 0 or undefined.
vals = avg_centered.to_numpy(dtype=float)
finite_abs = np.abs(vals[np.isfinite(vals)])
rng = float(np.percentile(finite_abs, 99)) if finite_abs.size else 0.0
if not np.isfinite(rng) or rng == 0:
    rng = 1.0
norm = TwoSlopeNorm(vcenter=0, vmin=-rng, vmax=rng)

# color = mean-centered average expression, size = percent of cells expressing
sns.scatterplot(
    data=plot_long,
    x="iso", y="cell_type",
    hue="avg_c", size="pct",
    palette="coolwarm", hue_norm=norm,
    sizes=(size_min, size_max),
    edgecolor="black", linewidth=0.3,
    legend=False, ax=ax
)

# Lock axes and remove horizontal padding
ax.set_xlim(-0.5, len(ordered_isoforms) - 0.5)
ax.set_ylim(-0.5, len(present) - 0.5)
ax.margins(y=0.03, x=0.00)

# one tick per isoform, labelled with the shortened ids
ax.set_xticks(range(len(ordered_isoforms)))
ax.set_xticklabels(xlabels, rotation=90, ha="right", fontsize=7)
ax.set_yticklabels(present, fontsize=9)

# Clear auto labels from seaborn
ax.set_xlabel("")
ax.set_ylabel("")


## Side panels placed explicitly in figure coords

# Reserve space on the right for *both* panels
right_pad = 0.24
plt.tight_layout(rect=(0.06, 0.02, 1.0 - right_pad, 0.98))

# Get the axis position after tight_layout
fig.canvas.draw_idle()
pos = ax.get_position()

# Sizes / gaps (figure coordinates)
gap_plot_to_cbar = 0.010
gap_cbar_to_legend = 0.035
cbar_w = 0.018
# the legend column takes whatever is left of the reserved right-hand strip
legend_w = right_pad - cbar_w - gap_plot_to_cbar - gap_cbar_to_legend

# Colorbar axis, directly to the right of the main axes and of equal height
cax = fig.add_axes([pos.x1 + gap_plot_to_cbar, pos.y0, cbar_w, pos.height])

# % expressing legend axis to the RIGHT of the colorbar
lax = fig.add_axes([pos.x1 + gap_plot_to_cbar + cbar_w + gap_cbar_to_legend,
                    pos.y0, legend_w, pos.height])
lax.set_axis_off()
lax.set_xlim(0, 1); lax.set_ylim(0, 1)

# Colorbar (the mappable only carries cmap + norm; it holds no data)
sm = plt.cm.ScalarMappable(cmap="coolwarm", norm=norm)
sm.set_array([])
cbar = fig.colorbar(sm, cax=cax, orientation="vertical")
cbar.ax.tick_params(labelsize=8)
# keep the label compact so it doesn't intrude into the legend column
cbar.set_label("Mean-centered expression", rotation=90, labelpad=6)

# % expressing legend: sizes interpolate over a fixed 0-100 range and are then
# scaled down by legend_scale
levels = [5, 25, 50, 75, 100]
legend_scale = 0.60
s_vals = np.interp(levels, [0, 100], [size_min, size_max]) * legend_scale

lax.text(0.52, 0.965, "% expressing", ha="center", va="top", fontsize=9)
ypos = np.linspace(0.80, 0.20, len(levels))
lax.scatter([0.32]*len(levels), ypos, s=s_vals, facecolor="lightgray",
            edgecolor="black", linewidth=0.6, zorder=2)
for y, lv in zip(ypos, levels):
    lax.text(0.52, y, f"{lv}%", va="center", ha="left", fontsize=9)


## Save / show
outpath = "Intermediate_Files/Paper_Figs/Markers/subcelltype_transitionOnly_TCells_marker_isoform_dotplot_abbreviated.pdf"
os.makedirs(os.path.dirname(outpath), exist_ok=True)
fig.savefig(outpath, dpi=600, transparent=True, bbox_inches="tight")
plt.show()
print(f"Saved to: {outpath}")

In [None]:
### Individual Gene Dot plots for sub-T cell isoform markers (Supplemental Figure 6e)

import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable

## Helper functions & markers
def find_matching_genes(prefixes, gene_list):
    """Filter ``gene_list`` down to the names that begin with one of ``prefixes``.

    Input order is kept; a name is kept once per occurrence regardless of how
    many prefixes it matches.
    """
    def _has_marker_prefix(name):
        for prefix in prefixes:
            if name.startswith(prefix):
                return True
        return False

    return list(filter(_has_marker_prefix, gene_list))

# Marker gene prefixes per marker group. Keys must be unique: a repeated key in
# a dict literal silently overwrites the earlier entry, which would drop a
# whole marker group.
marker_genes = {
    "T cells": ["CD3D:", "CD3E:", "CD3G:"],
    "CD8 Effector": ["CD8A:", "CD8B:", "GATA3:", "KLRB1:", "CCL5:"],
    "CD4 Effector": ["CD4:", "IL2RA:", "GATA3:", "AHR:", "TNF:"],
    "Memory": ["CCR7:", "SELL:", "TCF7:"],
    "Transition": ["ITGAE:", "LEF1:", "TCF7:", "IL7R:", "GATA3:", "IL2RA", "CTLA4", "CD27:"],
}

# normalize prefixes (ensure trailing ":") so a prefix matches the whole
# gene-symbol field rather than any symbol that merely starts with it
normalized_marker_genes = {
    grp: [g if g.endswith(":") else f"{g}:" for g in lst]
    for grp, lst in marker_genes.items()
}

## Config (T cell subgroups)
# Marker groups to process, in this order
marker_groups_to_include = [
    "T cells",
    "CD4 Effector",
    "CD8 Effector",
    "Memory",
    "Transition",
]
# Row (y-axis) order of the T-cell subtypes
tcell_order = [
    "Effector CD4 T cells",
    "Effector CD8 T cells",
    "Memory T cells",
    "Effector-Memory Transition T cells",
]
# Output directory for the per-gene PDFs (created if missing)
OUTDIR = "Intermediate_Files/Paper_Figs/Markers/dotplots_by_Tcell_subgroups"
os.makedirs(OUTDIR, exist_ok=True)

## Collect union of isoforms and build matrix
matched_isoforms = []
for grp in marker_groups_to_include:
    matched_isoforms += find_matching_genes(normalized_marker_genes[grp], adata_i_filtered.var_names)
# drop repeats (a gene may belong to several groups) while keeping first-seen order
matched_isoforms = list(dict.fromkeys(matched_isoforms))

if len(matched_isoforms) == 0:
    raise ValueError("No marker isoforms found for the requested groups.")

# Expression values for the matched isoforms, densified when the matrix exposes .toarray()
X = adata_i_filtered[:, matched_isoforms].X
X = X.toarray() if hasattr(X, "toarray") else X

# One row per cell, labelled by its sub_cell_type so rows can be grouped by subtype
iso_expr = pd.DataFrame(
    data=X,
    index=adata_i_filtered.obs["sub_cell_type"].astype(str),
    columns=matched_isoforms,
)

# Per-subtype metrics: mean expression and percent of cells with expression > 0
avg_expr_all = iso_expr.groupby(iso_expr.index, observed=True).mean()
pct_expr_all = iso_expr.gt(0).astype(int).groupby(iso_expr.index, observed=True).mean().mul(100)

# Keep only the requested subtypes that occur in the data, in the requested order
present = [label for label in tcell_order if label in avg_expr_all.index]
if len(present) == 0:
    raise ValueError("None of tcell_order present in sub_cell_type.")
avg_expr_all, pct_expr_all = (
    avg_expr_all.reindex(present),
    pct_expr_all.reindex(present),
)

# Short label helper
def short_label(iso):
    """Shorten ``a:b:c`` to ``a:`` plus the last six characters of ``c``.

    Strings with any other number of ":"-separated fields pass through untouched.
    """
    if iso.count(":") != 2:
        return iso
    head, _, rest = iso.partition(":")
    tail = rest.rpartition(":")[2]
    return f"{head}:{tail[-6:]}"

## One dot plot per gene (all isoforms for that gene)
# Draw and save one dot plot per (marker group, gene): x = all isoforms of that
# gene, y = T-cell subtype, color = mean-centered average expression,
# dot size = percent of cells with expression > 0.
for grp in marker_groups_to_include:
    prefixes = normalized_marker_genes.get(grp, [])
    if not prefixes:
        print(f"[{grp}] no marker list; skipping panel.")
        continue

    # Loop through each gene prefix individually
    for prefix in prefixes:
        gene_sym = prefix[:-1]  # remove trailing ":"
        # get isoforms for this gene from the already matched isoforms
        g_isos = [iso for iso in avg_expr_all.columns if iso.split(":")[0] == gene_sym]

        if not g_isos:
            print(f"[WARN] No isoforms found for gene '{gene_sym}' in group '{grp}'")
            continue

        # stable order by transcript ID (the last ":"-separated field)
        g_isos = sorted(g_isos, key=lambda s: s.split(":")[-1])

        # subset & center: subtract each isoform's mean across the plotted subtypes
        avg_expr = avg_expr_all[g_isos]
        pct_expr = pct_expr_all[g_isos]
        avg_centered = avg_expr - avg_expr.mean(axis=0)

        # long format: one row per (cell_type, iso) carrying both metrics
        avg_long = avg_centered.reset_index(names="cell_type").melt(
            id_vars="cell_type", var_name="iso", value_name="avg_c"
        )
        pct_long = pct_expr.reset_index(names="cell_type").melt(
            id_vars="cell_type", var_name="iso", value_name="pct"
        )
        plot_long = avg_long.merge(pct_long, on=["cell_type", "iso"])

        # enforce order (requires the labels in `present` and `g_isos` to be unique)
        plot_long["cell_type"] = pd.Categorical(plot_long["cell_type"], categories=present, ordered=True)
        plot_long["iso"]       = pd.Categorical(plot_long["iso"], categories=g_isos, ordered=True)
        plot_long = plot_long.sort_values(["cell_type", "iso"])

        # symmetric color range: the 100th percentile is the maximum absolute centered
        # value (no outlier clipping); a range of exactly 0 falls back to 1.0
        rng = float(np.nanpercentile(np.abs(avg_centered.values), 100)) or 1.0
        norm = TwoSlopeNorm(vcenter=0, vmin=-rng, vmax=rng)

        # figure size proportional to isoforms (width) and subtypes (height), with minimums
        fig, ax = plt.subplots(
            figsize=(max(8.0, 0.65 * len(g_isos)), max(2.0, 0.6 * len(present)))
        )
        size_min, size_max = 18, 380  # dot-size range passed to seaborn `sizes`

        sns.scatterplot(
            data=plot_long,
            x="iso", y="cell_type",
            hue="avg_c", size="pct",
            palette="coolwarm", hue_norm=norm,
            sizes=(size_min, size_max),
            edgecolor="black", linewidth=0.3,
            legend=False, ax=ax
        )

        # layout: margins are requested first, then the x-limits are set explicitly
        ax.margins(y=0.15, x=0.3)
        ax.set_xlim(-0.5, len(g_isos) - 0.5)

        # x tick labels: one tick per isoform; when there are more than max_visible
        # isoforms only every `step`-th label is kept, the rest are blanked
        ax.set_xticks(np.arange(len(g_isos)))
        xt = [short_label(i) for i in g_isos]
        max_visible = 90
        step = max(1, int(np.ceil(len(xt) / max_visible)))
        xt = [lab if (k % step == 0) else "" for k, lab in enumerate(xt)]
        ax.set_xticklabels(xt, rotation=90, ha="right", fontsize=6)

        ax.set_xlabel(""); ax.set_ylabel("")

        # right-side colorbar + size legend, both appended to the main axes
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="3%", pad=0.1)
        # the mappable only carries cmap + norm for the colorbar; it holds no data
        sm = plt.cm.ScalarMappable(cmap="coolwarm", norm=norm); sm.set_array([])
        cbar = fig.colorbar(sm, cax=cax)
        cbar.set_label("Mean-centered expression", rotation=90, labelpad=9)

        # hand-drawn "% expressing" legend on a fixed 0-100 scale
        lax = divider.append_axes("right", size="10%", pad=0.70)
        lax.set_axis_off(); lax.set_xlim(0, 1); lax.set_ylim(0, 1)
        levels = [5, 25, 50, 75, 100]
        lax.text(0.5, 0.99, "% expressing", ha="center", va="top", fontsize=9)
        ypos = np.linspace(0.80, 0.20, len(levels))
        for y, lv in zip(ypos, levels):
            # legend sizes interpolate over a fixed 0-100 range, independent of the data
            sz = np.interp(lv, [0, 100], [size_min, size_max])
            # sqrt converts the interpolated size into the marker-size value given to `ms`
            lax.plot([0.35], [y], 'o', ms=np.sqrt(sz),
                     markeredgecolor='black', markerfacecolor='lightgray')
            lax.text(0.60, y, f"{lv}%", va="center", ha="left", fontsize=9)

        plt.tight_layout(rect=(0.02, 0, 0.80, 1))

        # Save with group+gene name, so a gene listed in several groups gets one file per group
        out = os.path.join(
            OUTDIR, f"Tcell_subtype_isoform_dotplot__{grp.replace(' ', '_')}__{gene_sym}.pdf"
        )
        fig.savefig(out, dpi=600, transparent=True, bbox_inches="tight")
        plt.close(fig)  # close to avoid too many open figures
        print(f"Saved: {out}")

In [None]:
### Abbreviated isoform Dot plots for general T cell markers (CD3D, CD3E, CD3G)
## across sub-T cell clusters (Figure 5e)

import os
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable

## Config (T cell subtypes)
# Row (y-axis) order of the T-cell subtypes
celltype_order = [
    "Effector CD4 T cells",
    "Effector CD8 T cells",
    "Memory T cells",
    "Effector-Memory Transition T cells",
]

# Left-to-right gene order on the x-axis (general T cell markers)
gene_display_order = ["CD3D","CD3E","CD3G"]

# Full ENST IDs to plot for each gene; the order inside each list is the
# display order within that gene.
desired_transcripts = {
    "CD3D": ["ENST00000300692", "ENST00000392884", "ENST00000526561", "ENST00000695667", "ENST00000695666"],
    "CD3E": ["ENST00000361763", "ENST00000528600", "ENST00000526146"],
    "CD3G": ["ENST00000532917", "ENST00000527777", "ENST00000292144", "ENST00000392883"],
}


## Build whitelist from adata varnames (expects "GENE:ENSG:ENST")
# All variable names as strings; the lookups below expect "GENE:ENSG:ENST"
varnames = pd.Index(adata_i_filtered.var_names.astype(str))
whitelist_isoforms = []
missing = []

def find_iso_ids(gene, enst):
    """Return every name in ``varnames`` of the form ``gene:<anything>:enst``.

    Matching is done on the ``gene:`` prefix and the ``:enst`` suffix only.
    """
    starts_with_gene = varnames.str.startswith(f"{gene}:")
    ends_with_enst = varnames.str.endswith(f":{enst}")
    selected = varnames[starts_with_gene & ends_with_enst]
    return selected.tolist()

# Gene order first, then the listed transcript order within each gene
for gene in gene_display_order:
    for enst in desired_transcripts.get(gene, []):
        hits = find_iso_ids(gene, enst)
        if not hits:
            missing.append(f"{gene}:{enst}")
            continue
        whitelist_isoforms += hits

if len(whitelist_isoforms) == 0:
    raise ValueError("None of the requested transcripts were found in adata_i_filtered.var_names.")

if len(missing) > 0:
    print("Warning — these requested transcripts were not found and will be skipped:")
    for m in missing:
        print("  ", m)

matched_isoforms = whitelist_isoforms


## Extract Expression matrix (cells × isoforms); INDEX by sub_cell_type
# Per-cell subtype labels as plain strings
ct_series = adata_i_filtered.obs["sub_cell_type"].astype(str).copy()

# Expression for the whitelisted isoforms, densified when the matrix exposes .toarray()
X = adata_i_filtered[:, matched_isoforms].X
X = X.toarray() if hasattr(X, "toarray") else X

# Rows are cells labelled by subtype, columns are isoforms
iso_expr = pd.DataFrame(data=X, index=ct_series.values, columns=matched_isoforms)


## Per-subcluster metrics on isoforms: mean expression and percent of cells > 0
avg_expr_iso = iso_expr.groupby(iso_expr.index, observed=True).mean()
pct_expr_iso = iso_expr.gt(0).astype(int).groupby(iso_expr.index, observed=True).mean().mul(100)

# Requested subtypes that actually occur, in the requested order
present = [label for label in celltype_order if label in avg_expr_iso.index]
if len(present) == 0:
    raise ValueError("None of the expected T-cell subclusters are present in sub_cell_type.")

# Fixed isoform order = whitelist order present in matrix (already in gene order)
ordered_isoforms = [iso_id for iso_id in matched_isoforms if iso_id in avg_expr_iso.columns]

# Rows in subtype order, columns in isoform order
avg_expr_iso = avg_expr_iso.reindex(present)[ordered_isoforms]
pct_expr_iso = pct_expr_iso.reindex(present)[ordered_isoforms]


## Mean-center per isoform (white = average across subclusters)
avg_centered = avg_expr_iso.sub(avg_expr_iso.mean(axis=0), axis="columns")

# Short x labels: GENE:last 6 of ENST
def short_label(iso):
    """Build a compact tick label ``GENE:<ENST tail>`` from ``GENE:ENSG:ENST``.

    The tail is the last six characters of the third field (the whole field
    when it is shorter). Ids without exactly three fields are returned as-is.
    """
    tokens = iso.split(":")
    is_three_part = len(tokens) == 3
    return "{}:{}".format(tokens[0], tokens[2][-6:]) if is_three_part else iso

# Tick labels, one per isoform, in plotting order
xlabels = list(map(short_label, ordered_isoforms))


## Long format for plotting: one row per (cell_type, iso) with both metrics
avg_long = (
    avg_centered
    .reset_index(names="cell_type")
    .melt(id_vars="cell_type", var_name="iso", value_name="avg_c")
)
pct_long = (
    pct_expr_iso
    .reset_index(names="cell_type")
    .melt(id_vars="cell_type", var_name="iso", value_name="pct")
)
plot_long = pd.merge(avg_long, pct_long, on=["cell_type", "iso"])
# Ordered categoricals pin the axis order to `present` / `ordered_isoforms`
plot_long["cell_type"] = pd.Categorical(plot_long["cell_type"], categories=present, ordered=True)
plot_long["iso"] = pd.Categorical(plot_long["iso"], categories=ordered_isoforms, ordered=True)
plot_long = plot_long.sort_values(by=["cell_type", "iso"])


## Dot plot
# Figure size scales with the number of isoforms (width) and subtypes (height),
# with minimums of 8 x 2.4
fig, ax = plt.subplots(
    figsize=(max(8, 0.28 * len(ordered_isoforms)),
             max(2.4, 0.32 * len(present)))
)

size_min, size_max = 18, 380  # dot-size range passed to seaborn `sizes`

# Symmetric color range computed from THIS figure's centered values, so the cell
# does not depend on a `norm` left behind by a previously executed cell.
# 99th percentile of the finite absolute values; 1.0 when that is 0 or undefined.
vals = avg_centered.to_numpy(dtype=float)
finite_abs = np.abs(vals[np.isfinite(vals)])
rng = float(np.percentile(finite_abs, 99)) if finite_abs.size else 0.0
if not np.isfinite(rng) or rng == 0:
    rng = 1.0
norm = TwoSlopeNorm(vcenter=0, vmin=-rng, vmax=rng)

# color = mean-centered average expression, size = percent of cells expressing
# (size_norm pins the size scale to 0-100 rather than to the observed range)
sns.scatterplot(
    data=plot_long,
    x="iso", y="cell_type",
    hue="avg_c", size="pct",
    palette="coolwarm", hue_norm=norm,
    sizes=(size_min, size_max),
    size_norm=(0, 100),
    edgecolor="black", linewidth=0.3,
    legend=False, ax=ax
)

# Lock axes and remove extra horizontal padding
ax.set_xlim(-0.5, len(ordered_isoforms) - 0.5)
ax.set_ylim(-0.5, len(present) - 0.5)
ax.margins(y=0.03, x=0.00)            # <- zero x-margins
ax.set_xticks(range(len(ordered_isoforms)))
ax.set_xticklabels(xlabels, rotation=90, ha="right", fontsize=7)
ax.set_yticklabels(present, fontsize=9)
ax.set_xlabel(""); ax.set_ylabel("")

# Right side formatting: reserve a strip of the figure for colorbar + legend
right_pad = 0.24
plt.tight_layout(rect=(0.06, 0.02, 1.0 - right_pad, 0.98))
fig.canvas.draw_idle()
pos = ax.get_position()

# Sizes / gaps in figure coordinates
gap_plot_to_cbar   = 0.010
gap_cbar_to_legend = 0.035
cbar_w   = 0.018
# the legend column takes whatever is left of the reserved right-hand strip
legend_w = right_pad - cbar_w - gap_plot_to_cbar - gap_cbar_to_legend

# Colorbar (the mappable only carries cmap + norm; it holds no data)
cax = fig.add_axes([pos.x1 + gap_plot_to_cbar, pos.y0, cbar_w, pos.height])
sm = plt.cm.ScalarMappable(cmap="coolwarm", norm=norm)
sm.set_array([])
cbar = fig.colorbar(sm, cax=cax, orientation="vertical")
cbar.ax.tick_params(labelsize=8)
cbar.set_label("Mean-centered expression", rotation=90, labelpad=6)

# % expressing legend: sizes interpolate over the same 0-100 range as size_norm
# and are then scaled by 0.60
lax = fig.add_axes([pos.x1 + gap_plot_to_cbar + cbar_w + gap_cbar_to_legend,
                    pos.y0, legend_w, pos.height])
lax.set_axis_off(); lax.set_xlim(0, 1); lax.set_ylim(0, 1)
levels = [5, 25, 50, 75, 100]
s_vals = np.interp(levels, [0, 100], [size_min, size_max]) * 0.60
lax.text(0.52, 0.965, "% expressing", ha="center", va="top", fontsize=9)
ypos = np.linspace(0.80, 0.20, len(levels))
lax.scatter([0.32]*len(levels), ypos, s=s_vals, facecolor="lightgray",
            edgecolor="black", linewidth=0.6, zorder=2)
for y, lv in zip(ypos, levels):
    lax.text(0.52, y, f"{lv}%", va="center", ha="left", fontsize=9)

# Save
outpath = "Intermediate_Files/Paper_Figs/Markers/subcelltype_TCells_marker_CD3_isoform_dotplot_abbreviated.pdf"
os.makedirs(os.path.dirname(outpath), exist_ok=True)
fig.savefig(outpath, dpi=600, transparent=True, bbox_inches="tight")
plt.show()
print(f"Saved to: {outpath}")

In [None]:
### Prep for heatmaps of novel isoforms only, split by known and new

## STRICT SPLIT: ENSG/BambuTx vs BambuGene/BambuTx
import re
import numpy as np
import pandas as pd
import scipy.sparse as sp

## Parse transcript->gene_id from your extended annotations (Bambu adds gene_id) -----
ext_gtf_path = "extended_annotations.gtf"  # your file with Bambu transcripts
# The GTF is read as 9 tab-separated text columns with no header row;
# lines starting with '#' are skipped.
ext = pd.read_csv(
    ext_gtf_path, sep="\t", comment="#", header=None, dtype=str,
    names=["chrom","source","feature","start","end","score","strand","frame","attribute"]
)
ext = ext[ext["feature"] == "transcript"].copy()
ext["transcript_id"] = ext["attribute"].str.extract(r'transcript_id "([^"]+)"', expand=False)
ext["gene_id"]       = ext["attribute"].str.extract(r'gene_id "([^"]+)"', expand=False)
# transcript_id -> gene_id; rows lacking either id are dropped and the first
# row wins when a transcript_id occurs more than once.
tx_to_gene_id = (
    ext.dropna(subset=["transcript_id","gene_id"])
       .drop_duplicates("transcript_id")
       .set_index("transcript_id")["gene_id"]
       .to_dict()
)

## Get X as CSR (or keep dense) and the transcript list
X = adata_i_filtered.X
# issparse() is True for both SciPy sparse matrices and sparse arrays;
# isspmatrix() is False for sparse arrays, which would send a csr_array into
# the dense np.asarray branch below.
is_sparse = sp.issparse(X)
if is_sparse:
    if X.format != "csr":
        X = X.tocsr()
else:
    X = np.asarray(X, dtype=np.float32, order="C")

transcripts = adata_i_filtered.var.index.astype(str).to_numpy()

# cell-type info
cell_types_cat = adata_i_filtered.obs["sub_cell_type"].astype("category")
ct_codes = cell_types_cat.cat.codes.to_numpy()        # integer code per cell
ct_names = cell_types_cat.cat.categories.to_numpy()   # code -> label
n_cells, n_iso = X.shape
n_ct = len(ct_names)

## Build *exact* masks
# NOTE(review): the label regexes applied further down expect var names of
# the form "<gene>:BambuTx<id>". If that is the actual format, neither the
# startswith("BambuTx") test nor the tx_to_gene_id lookup below can match,
# and these masks are not referenced again in the visible code — confirm
# whether they are still needed.
# (a) BambuTx isoforms only
# dtype=bool forces a plain boolean ndarray out of the nullable "string" accessor result.
is_bambu_tx = (
    pd.Series(transcripts, dtype="string").str.startswith("BambuTx").to_numpy(dtype=bool)
)

# (b) Map transcript -> gene_id, then gene class
gene_ids = pd.Series(transcripts).map(tx_to_gene_id)         # may be NaN for some
gene_ids = gene_ids.fillna("")
gene_is_ensg   = gene_ids.str.startswith("ENSG").to_numpy()
gene_is_bambu  = gene_ids.str.startswith("BambuGene").to_numpy()

# (c) Final, strict subsets
mask_ENSG_BambuTx    = is_bambu_tx & gene_is_ensg
mask_BambuGene_BambuTx = is_bambu_tx & gene_is_bambu

## Lightweight prevalence filter & column cap (adjust as desired)
MIN_CELLS_NZ = 20           # minimum number of cells with a nonzero value, over all cells
MIN_PREV_ANY = 0.01         # minimum nonzero fraction required in at least one cell type
MAX_COLS_PER_SUBSET = 300  # None for no cap (risk of huge PDFs)

def compute_global_nz(X):
    """Count, per column, how many rows hold a nonzero value.

    Parameters
    ----------
    X : scipy sparse matrix/array (any format) or 2-D numpy array
        Rows are observations, columns are features.

    Returns
    -------
    numpy.ndarray
        1-D integer array of length ``X.shape[1]``.

    Notes
    -----
    Explicitly stored zeros in a sparse input are NOT counted, so the sparse
    and dense paths agree on what "nonzero" means.
    """
    if sp.issparse(X):
        Xc = X.tocsr()
        # Column index of every stored entry whose value is actually nonzero.
        nonzero_cols = Xc.indices[Xc.data != 0]
        return np.bincount(nonzero_cols, minlength=Xc.shape[1]).astype(np.int64, copy=False)
    # dense, chunk to save RAM (avoids one full-size boolean temporary)
    B = 2000
    nz = np.zeros(X.shape[1], dtype=np.int64)
    for start in range(0, X.shape[0], B):
        nz += np.count_nonzero(X[start:start + B, :], axis=0)
    return nz

# Number of cells with a nonzero value, per isoform column, over all cells.
global_nz = compute_global_nz(X)

# Column indices that reach the MIN_CELLS_NZ threshold.
keep_cols_global = np.flatnonzero(global_nz >= MIN_CELLS_NZ)
if keep_cols_global.size == 0:
    # Nothing survives: still define both outputs (empty, same columns) so
    # later cells can test their length.
    print("[split] No isoforms pass global nonzero filter. Skipping.")
    df_known = pd.DataFrame(columns=["cell_type","clean_name","expression"])
    df_novel = pd.DataFrame(columns=["cell_type","clean_name","expression"])
else:
    # keep only columns with at least MIN_CELLS_NZ nonzero cells; the *_k
    # arrays below are all indexed in this reduced column space
    transcripts_k = transcripts[keep_cols_global].astype(str)
    global_nz_k   = global_nz[keep_cols_global]
    Xk            = X[:, keep_cols_global]  # works for CSR or dense

    ## FUNCTIONS
    def agg_by_celltype_sparse(Xsub):
        """Aggregate a sparse cells x isoforms matrix per cell type.

        Reads ``ct_codes`` (integer cell-type code per row) and ``n_ct``
        (number of cell types) from the enclosing namespace.

        Parameters
        ----------
        Xsub : scipy sparse matrix/array
            One row per cell, in the same order as ``ct_codes``.

        Returns
        -------
        means : float32 ndarray, shape (n_ct, n_columns)
            Mean value per cell type and column.
        prev : float32 ndarray, shape (n_ct, n_columns)
            Fraction of cells with a nonzero value per cell type and column.
            Cell types without any cell keep all-zero rows in both arrays.
        """
        k = Xsub.shape[1]
        means = np.zeros((n_ct, k), dtype=np.float32)
        prev  = np.zeros((n_ct, k), dtype=np.float32)
        for i in range(n_ct):
            sel = (ct_codes == i)
            c = int(sel.sum())
            if c == 0:
                continue
            Xi = Xsub[sel, :].tocsr()
            means[i, :] = np.asarray(Xi.sum(axis=0)).ravel() / c
            # Count entries that are truly nonzero. Summing sign() would let
            # negative values cancel positive ones; this matches the dense
            # path, which counts `!= 0`.
            nz = np.bincount(Xi.indices[Xi.data != 0], minlength=k)
            prev[i, :] = nz / c
        return means, prev

    def agg_by_celltype_dense(Xsub, ct_codes, n_ct, batch_rows=2000):
        """Per-cell-type mean and nonzero fraction for a dense 2-D array.

        Rows belonging to each cell type are visited in slices of at most
        ``batch_rows`` rows; sums are accumulated in float64 and counts in
        int64 before being stored as float32.

        Parameters
        ----------
        Xsub : 2-D ndarray (cells x columns)
        ct_codes : 1-D integer array, one cell-type code per row of ``Xsub``
        n_ct : int, number of cell types (codes ``0 .. n_ct-1``)
        batch_rows : int, maximum rows gathered at once

        Returns
        -------
        (means, prev) : two float32 arrays of shape (n_ct, n_columns).
            Rows for cell types with no cells stay zero.
        """
        _, n_cols = Xsub.shape
        out_shape = (n_ct, n_cols)
        mean_by_type = np.zeros(out_shape, dtype=np.float32)
        frac_by_type = np.zeros(out_shape, dtype=np.float32)

        for code in range(n_ct):
            member_rows = np.flatnonzero(ct_codes == code)
            n_members = member_rows.size
            if n_members == 0:
                continue

            col_total = np.zeros(n_cols, dtype=np.float64)
            col_hits = np.zeros(n_cols, dtype=np.int64)
            for lo in range(0, n_members, batch_rows):
                chunk = Xsub[member_rows[lo:lo + batch_rows], :]
                col_total += chunk.sum(axis=0)
                col_hits += (chunk != 0).sum(axis=0)

            denom = max(float(n_members), 1)
            mean_by_type[code, :] = (col_total / denom).astype(np.float32)
            frac_by_type[code, :] = (col_hits / denom).astype(np.float32)

        return mean_by_type, frac_by_type

    def aggregate_subset_by_cols(col_idx):
        """Aggregate the chosen columns of ``Xk`` per cell type, then drop rare ones.

        ``col_idx`` indexes columns of ``Xk``. The sparse or dense aggregator
        is picked from the ``is_sparse`` flag. A column is kept only if its
        nonzero fraction reaches ``MIN_PREV_ANY`` in at least one cell type.

        Returns ``(means, prev, kept_col_idx)``; ``means`` and ``prev`` are
        ``None`` when the input is empty or no column passes the filter.
        """
        if col_idx.size == 0:
            return None, None, col_idx

        subset = Xk[:, col_idx]
        if not is_sparse:
            # pass positionally; avoids old 'batch' keyword collisions
            type_means, type_prev = agg_by_celltype_dense(subset, ct_codes, n_ct, 2000)
        else:
            type_means, type_prev = agg_by_celltype_sparse(subset)

        passes = type_prev.max(axis=0) >= MIN_PREV_ANY
        if np.any(passes):
            return type_means[:, passes], type_prev[:, passes], col_idx[passes]
        return None, None, np.array([], dtype=int)

    ## STRICT SPLIT BY LABEL PATTERN
    # We want ONLY:
    #   • known set:  "ENSG###########:BambuTx<id>"
    #   • novel set:  "BambuGene<id>:BambuTx<id>"
    # str.match anchors at the start of the label and the trailing `$` anchors
    # the end, so the whole label has to follow the pattern. Both masks are
    # aligned with transcripts_k (the globally filtered columns).
    tx_series = pd.Series(transcripts_k, dtype=str)
    mask_known_gene_btx = tx_series.str.match(r"ENSG\d{11}:BambuTx\d+$", na=False).to_numpy()
    mask_bambu_gene_btx = tx_series.str.match(r"BambuGene\d+:BambuTx\d+$", na=False).to_numpy()

    def choose_by_mask(mask_bool):
        """Return the column indices selected by ``mask_bool``, capped in number.

        When ``MAX_COLS_PER_SUBSET`` is ``None`` every selected column is
        returned. Otherwise at most that many columns are kept, preferring the
        ones with the highest ``global_nz_k`` (number of nonzero cells). The
        result is always in ascending index order.

        A cap of 0 yields an empty selection (a plain ``[-0:]`` slice would
        have returned every column).
        """
        idx = np.flatnonzero(mask_bool)
        if idx.size == 0 or MAX_COLS_PER_SUBSET is None:
            return idx
        cap = max(int(MAX_COLS_PER_SUBSET), 0)
        if cap == 0:
            return idx[:0]
        if cap >= idx.size:
            return idx  # nothing to trim; flatnonzero output is already ascending
        # rank by global prevalence; keep the `cap` most prevalent; then sort
        ranked = np.argsort(global_nz_k[idx])
        return np.sort(idx[ranked[-cap:]])

    # Column indices (into Xk / transcripts_k) for each label class, after the
    # MAX_COLS_PER_SUBSET cap.
    cols_known_gene_btx = choose_by_mask(mask_known_gene_btx)
    cols_bambu_gene_btx = choose_by_mask(mask_bambu_gene_btx)

    # Per-cell-type means and nonzero fractions for each subset. The first two
    # results are None when a subset is empty or fully filtered out; the third
    # holds the column indices that were kept.
    means_known, prev_known, idx_known = aggregate_subset_by_cols(cols_known_gene_btx)
    means_bambu, prev_bambu, idx_bambu = aggregate_subset_by_cols(cols_bambu_gene_btx)

    def to_long_df(means, col_idx):
        """Reshape a (cell types x isoforms) mean matrix into a long table.

        ``col_idx`` gives, for each column of ``means``, its position in
        ``transcripts_k``. The result has one row per (cell type, isoform)
        with columns ``cell_type``, ``clean_name`` and ``expression``; an
        empty frame with those columns is returned when there is nothing to
        reshape.

        ``clean_name`` is built as "<text before the first colon>:<full label>".
        NOTE(review): for labels that already start with the gene id this
        repeats the gene prefix — confirm that is the intended display name.
        """
        long_cols = ["cell_type", "clean_name", "expression"]
        nothing_to_reshape = means is None or col_idx.size == 0
        if nothing_to_reshape:
            return pd.DataFrame(columns=long_cols)

        labels = transcripts_k[col_idx]
        # text before the first colon of each label
        prefix = pd.Series(labels, dtype=str).str.split(":", n=1, expand=True).iloc[:, 0]
        column_names = np.where(prefix.notna(), prefix.values + ":" + labels, labels)

        wide = pd.DataFrame(means, index=ct_names, columns=column_names)
        long_table = wide.stack().reset_index()
        long_table.columns = long_cols
        return long_table

    # Long-format tables (one row per cell type x isoform) read by the heatmap
    # cells below; only the means are carried forward, the prevalence arrays
    # are not.
    df_known = to_long_df(means_known, idx_known)   # ENSG:BambuTx only
    df_novel = to_long_df(means_bambu, idx_bambu)   # BambuGene:BambuTx only
    print(f"[split] ENSG/BambuTx rows: {len(df_known)} | BambuGene/BambuTx rows: {len(df_novel)}")

In [None]:
### Heatmaps: Expression of novel isoforms from new and known genes across cell-types and sub-cell-types

## plot heatmap for novel isoforms
import os, numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Output folder for the heatmap PDFs; created up front so the later savefig
# calls do not fail on a missing directory.
FIG_DIR = "Intermediate_Files/Paper_Figs"
os.makedirs(FIG_DIR, exist_ok=True)

def plot_clustermap_with_coexpression(df_subset, filename_pdf, vmax, celltype_order=None):
    """Save an expression clustermap and an isoform co-expression heatmap.

    Parameters
    ----------
    df_subset : DataFrame or None
        Long table with columns ``cell_type``, ``clean_name``, ``expression``.
    filename_pdf : str
        Output path; its extension is stripped and ``_dendrogram.pdf`` /
        ``_coexpression.pdf`` are appended for the two files.
    vmax : number
        Colour limits of the clustermap are ``[-vmax, vmax]``, centred on 0.
    celltype_order : list of str, optional
        Row order. Cell types missing from the data are ignored, and cell
        types not listed are dropped from the plot.

    Returns
    -------
    None. Prints a note and returns early when the input is empty or every
    isoform column is zero.

    Notes
    -----
    Rows are cell types and columns are isoforms. Only columns are clustered;
    the co-expression heatmap reuses the clustered column order.
    """
    if df_subset is None or len(df_subset) == 0:
        print(f"[plot] {os.path.basename(filename_pdf)}: empty input — skipping.")
        return

    # cell type x isoform table of mean expression; absent pairs become 0
    per_pair = df_subset.groupby(["cell_type", "clean_name"], observed=True)["expression"]
    summary = per_pair.mean().unstack(fill_value=0.0)

    # optional row order (restricted to cell types that are present)
    if celltype_order:
        summary = summary.reindex([ct for ct in celltype_order if ct in summary.index])

    # all-zero isoform columns carry no signal and would give NaN correlations
    zero_cols = (summary == 0).all(axis=0)
    if zero_cols.any():
        summary = summary.loc[:, ~zero_cols]
        if summary.shape[1] == 0:
            print(f"[plot] {os.path.basename(filename_pdf)}: all-zero after filter — skipping.")
            return

    # alphabetical column order first, so clustering starts from a stable layout
    summary = summary.reindex(columns=sorted(summary.columns))

    # --- expression clustermap (columns clustered only when there are >= 2) ---
    n_columns = len(summary.columns)
    col_cluster = summary.shape[1] >= 2
    dendro_ratio = (0.1, 0.5) if col_cluster else (0.0, 0.0)
    g = sns.clustermap(
        summary,
        cmap="coolwarm", center=0, vmin=-vmax, vmax=vmax,
        linewidths=0.4, linecolor="gray",
        figsize=(max(n_columns * 0.3, 3.5), 5.5),
        row_cluster=False, col_cluster=col_cluster,
        dendrogram_ratio=dendro_ratio,
        cbar_kws={"label": ""},
        cbar_pos=(0.02, 0.52, 0.03, 0.2),
    )
    heat_ax = g.ax_heatmap
    heat_ax.set_xlabel("")
    heat_ax.set_ylabel("")
    heat_ax.set_xticklabels(heat_ax.get_xticklabels(), rotation=90)
    heat_ax.set_yticklabels(heat_ax.get_yticklabels(), rotation=0, ha="right")

    # --- same column order as the dendrogram leaves, for the correlation plot ---
    summary_ord = summary
    if col_cluster:
        leaf_names = [summary.columns[j] for j in g.dendrogram_col.reordered_ind]
        summary_ord = summary[leaf_names]

    # isoform-by-isoform correlation; undefined values are shown as 0
    corr = summary_ord.corr().replace([np.inf, -np.inf], np.nan).fillna(0.0)

    # --- write both PDFs ---
    base = os.path.splitext(filename_pdf)[0]
    g.savefig(base + "_dendrogram.pdf", dpi=600, transparent=True, bbox_inches="tight")

    coexp_fig, coexp_ax = plt.subplots(figsize=(max(len(summary_ord.columns) * 0.3, 3.5), 10))
    sns.heatmap(corr, cmap="coolwarm", vmin=-1, vmax=1, center=0,
                linewidths=0.4, linecolor="gray", ax=coexp_ax)
    plt.setp(coexp_ax.get_xticklabels(), rotation=90)
    plt.setp(coexp_ax.get_yticklabels(), rotation=0)
    coexp_fig.tight_layout()
    coexp_fig.savefig(base + "_coexpression.pdf", dpi=600, transparent=True, bbox_inches="tight")
    plt.close(coexp_fig)
    print(f"[plot] saved: {base}_dendrogram.pdf and {base}_coexpression.pdf")

In [None]:
## calls (each subset already enforced by your split step)
# Row order for both heatmaps.
celltype_order = [
    "Effector CD4 T cells", "Effector CD8 T cells", "Memory T cells",
    "Effector-Memory Transition T cells", "NK Cells", "BCells",
    "Monocyte-derived", "Megakaryocytes"
]


def _plot_subset_or_skip(frame, pdf_name, skip_note):
    """Plot one long-format subset into FIG_DIR, or print skip_note if it has no rows."""
    if not len(frame):
        print(skip_note)
        return
    plot_clustermap_with_coexpression(
        frame,
        os.path.join(FIG_DIR, pdf_name),
        vmax=2, celltype_order=celltype_order
    )


# ENSG:BambuTx (known gene / novel isoform)
_plot_subset_or_skip(
    df_known,
    "known_gene__BambuTx_subcelltype_heatmap_clustered_coexpression.pdf",
    "[plot] df_known empty — skipping known plots.",
)

# BambuGene:BambuTx (novel gene / novel isoform)
_plot_subset_or_skip(
    df_novel,
    "bambu_gene__BambuTx_subcelltype_heatmap_clustered_coexpression.pdf",
    "[plot] df_novel empty — skipping novel plots.",
)