# Load Packages

In [None]:
# followed https://squidpy.readthedocs.io/en/stable/notebooks/tutorials/tutorial_xenium.html

import numpy as np
import pandas as pd
import os

import matplotlib.pyplot as plt
import seaborn as sns

import scanpy as sc
import squidpy as sq

In [None]:
import matplotlib.pyplot as plt

# Prefer Arial for figures, but list fallbacks that are commonly available
# (DejaVu Sans ships with matplotlib) so machines without Arial do not emit
# "findfont: Font family 'Arial' not found" warnings on every figure.
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = ['Arial', 'Helvetica', 'DejaVu Sans']

In [None]:
# Scanpy logging verbosity 3 ("hint" level).
sc.settings.verbosity = 3

_figure_params = dict(
    dpi=200,             # resolution of figures shown in the notebook
    dpi_save=400,        # resolution used when saving figures
    fontsize=16,         # base font size
    facecolor="white",
    vector_friendly=True,
)
sc.settings.set_figure_params(**_figure_params)

# Read AnnData

In [None]:
# Location of the cleaned AnnData file. Set the JMT_H5AD environment variable to
# point elsewhere (e.g. on another machine) without editing this cell.
H5AD_PATH = os.environ.get(
    "JMT_H5AD", "/Users/bhavyasingh/Downloads/JMT_cleaned_confident_only.h5ad"
)
adata = sc.read_h5ad(H5AD_PATH)

In [None]:
# Display the AnnData repr to sanity-check what was loaded.
adata

# Overall AnnData Analysis

## Export gene list

In [None]:
# Write the panel's gene IDs (indexed by var_names) to a CSV for reference.
panel_gene_ids = adata.var["gene_ids"]
panel_gene_ids.to_csv("Xenium_panel.csv")

## Color scheme

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import pandas as pd
from matplotlib.colors import ListedColormap
import seaborn as sns

def create_comprehensive_color_scheme(adata):
    """
    Build per-annotation colour dictionaries for columns of ``adata.obs``.

    The returned scheme contains:
    1. Distinct base colours for the broad cell types (``initial_broad``).
    2. Darker/lighter shades of the parent broad colour for every regular
       ``subtype`` / ``subtype_granular`` label.
    3. Blended colours for "Juxtaposed A - B[ - C]" labels, mixing the broad
       colours of the components.
    4. Colours for ``juxta_call``, fixed colours for ``qc_flag`` and
       ``Treatment``, and a Set2 palette for ``Region_ID`` when that column
       exists.

    Parameters
    ----------
    adata : AnnData
        Must provide ``subtype``, ``subtype_granular`` and ``juxta_call``
        columns in ``.obs``; ``Region_ID`` is optional.

    Returns
    -------
    dict[str, dict]
        Maps an annotation name to a ``{label: colour}`` dict.
    """

    broad_base_colors = {
        'Cancer Cells': "#CA2225D2",      # Red
        'T and NK Cells': "#2B6E9BC1",    # Blue
        'B Cells': "#268A1FBB",           # Green
        'Myeloid Cells': "#BA6612BC",     # Orange
        'Stromal Cells': "#551E90B5",     # Purple
        'Plasma Cells': "#D18483C1",      # Light Pink
    }

    def get_broad_category(cell_type):
        """Map a label to a broad category by substring match; 'Other' if nothing matches."""
        # Rules are checked in order and the first match wins.
        if any(cancer in cell_type for cancer in ['Cancer', 'EMT', 'Stem-Like', 'Proliferating', 'IFN-Responsive', 'Angiogenesis', 'Metabolically', 'Transitional']):
            return 'Cancer Cells'
        elif any(t in cell_type for t in ['CD8', 'CD4', 'T cells', 'NK', 'TNK', 'Tregs']):
            return 'T and NK Cells'
        elif 'B Cells' in cell_type or cell_type.startswith('B '):
            return 'B Cells'
        elif any(m in cell_type for m in ['Macrophages', 'Macs', 'Monocytes', 'Dendritic', 'Neutrophils', 'Myeloid']):
            return 'Myeloid Cells'
        elif any(s in cell_type for s in ['Fibroblasts', 'Endothelial', 'Stromal', 'MyoCAFs', 'Adipocytes']):
            return 'Stromal Cells'
        elif 'Plasma' in cell_type:
            return 'Plasma Cells'
        else:
            return 'Other'

    def generate_color_variants(base_color, n_variants, darken_factor=0.3, lighten_factor=0.3):
        """
        Return ``n_variants`` hex shades of ``base_color``.

        The first ``n_variants // 2`` shades are progressively darker (scaled
        down to ``1 - darken_factor``), the middle one is the base colour, and
        the rest are progressively blended toward white (by at most
        ``lighten_factor``). A single variant returns ``base_color`` unchanged
        (alpha included); otherwise the alpha channel is dropped.
        """
        base_rgb = mcolors.to_rgb(base_color)
        variants = []

        if n_variants == 1:
            return [base_color]

        for i in range(n_variants):
            if i < n_variants // 2:
                factor = 1 - (darken_factor * (n_variants // 2 - i) / (n_variants // 2))
                variant = tuple(c * factor for c in base_rgb)
            elif i == n_variants // 2:
                variant = base_rgb
            else:
                factor = lighten_factor * (i - n_variants // 2) / (n_variants // 2)
                variant = tuple(min(1.0, c + factor * (1 - c)) for c in base_rgb)

            variants.append(mcolors.to_hex(variant))

        return variants

    def mix_colors(color1, color2, ratio=0.5):
        """Linearly blend two colours in RGB: ``ratio`` * color1 + (1 - ratio) * color2."""
        rgb1 = np.array(mcolors.to_rgb(color1))
        rgb2 = np.array(mcolors.to_rgb(color2))
        mixed = ratio * rgb1 + (1 - ratio) * rgb2
        return mcolors.to_hex(mixed)

    def get_juxtaposed_color(cell_type):
        """
        Blend the broad colours of a "Juxtaposed A - B[ - C]" label.

        Returns None if the label does not contain 'Juxtaposed', and grey
        (#808080) if it has an unsupported number of components or any
        component maps to 'Other'.
        """
        if 'Juxtaposed' not in cell_type:
            return None

        components = cell_type.replace('Juxtaposed ', '').split(' - ')

        if len(components) == 2:
            broad1 = get_broad_category(components[0])
            broad2 = get_broad_category(components[1])

            if broad1 in broad_base_colors and broad2 in broad_base_colors:
                return mix_colors(broad_base_colors[broad1], broad_base_colors[broad2])

        elif len(components) == 3:
            broad1 = get_broad_category(components[0])
            broad2 = get_broad_category(components[1])
            broad3 = get_broad_category(components[2])

            if all(b in broad_base_colors for b in [broad1, broad2, broad3]):
                temp_mix = mix_colors(broad_base_colors[broad1], broad_base_colors[broad2])
                # 0.67 * (A+B)/2 + 0.33 * C  ~= equal thirds of A, B and C.
                return mix_colors(temp_mix, broad_base_colors[broad3], ratio=0.67)

        return '#808080'

    def assign_hierarchical_colors(column, report=False):
        """Colour one label column: shades of the broad colour for regular labels, blends for juxtaposed ones."""
        labels_by_broad = {}
        for label in adata.obs[column].unique():
            if pd.notna(label):
                labels_by_broad.setdefault(get_broad_category(label), []).append(label)

        level_colors = {}
        for broad_cat, labels in labels_by_broad.items():
            regular = [lab for lab in labels if 'Juxtaposed' not in lab]
            juxta = [lab for lab in labels if 'Juxtaposed' in lab]

            # Regular labels under 'Other' get no entry here (grey downstream).
            if regular and broad_cat in broad_base_colors:
                n_regular = len(regular)
                if report:
                    if n_regular <= 3:
                        range_info = "(standard range)"
                    elif n_regular <= 6:
                        range_info = "(medium range)"
                    elif n_regular <= 10:
                        range_info = "(large range)"
                    else:
                        range_info = "(extra large range)"
                    print(f"  {broad_cat}: {n_regular} regular subtypes {range_info}")

                variants = generate_color_variants(broad_base_colors[broad_cat], n_regular)
                for label, color in zip(sorted(regular), variants):
                    level_colors[label] = color

            for label in juxta:
                mixed_color = get_juxtaposed_color(label)
                if mixed_color:
                    level_colors[label] = mixed_color
        return level_colors

    color_schemes = {}

    color_schemes['initial_broad'] = broad_base_colors.copy()

    print("Creating subtype color scheme...")
    subtype_colors = assign_hierarchical_colors('subtype', report=True)
    color_schemes['subtype'] = subtype_colors

    print("Creating subtype_granular color scheme...")
    color_schemes['subtype_granular'] = assign_hierarchical_colors('subtype_granular')

    print("Creating juxta_call color scheme...")
    juxta_call_colors = {}

    for juxta_call in adata.obs['juxta_call'].unique():
        if pd.isna(juxta_call):
            continue
        if juxta_call in subtype_colors:
            juxta_call_colors[juxta_call] = subtype_colors[juxta_call]
        elif 'Juxtaposed' not in juxta_call and ' - ' in juxta_call:
            juxtaposed_name = f'Juxtaposed {juxta_call}'
            if juxtaposed_name in subtype_colors:
                juxta_call_colors[juxta_call] = subtype_colors[juxtaposed_name]
            else:
                # get_juxtaposed_color only handles names carrying the
                # 'Juxtaposed ' prefix (it returns None otherwise), so blend
                # from the prefixed form rather than the raw call.
                juxta_call_colors[juxta_call] = get_juxtaposed_color(juxtaposed_name) or '#808080'
        else:
            juxta_call_colors[juxta_call] = '#808080'

    color_schemes['juxta_call'] = juxta_call_colors

    qc_colors = {
        'confident': '#D3D3D3',      # Light grey
        'juxtaposed': '#FFA500',     # Orange
        'ambiguous': '#FFD700',      # Gold
        'doublet': '#DC143C',        # Crimson
        'low_quality': '#8B4513',    # Saddle brown
        'to_discard': '#000000'      # Black
    }
    color_schemes['qc_flag'] = qc_colors

    treatment_colors = {
        "Treated": "#F44646",   # Red
        "Control": "#6E6EE5",   # Periwinkle blue
    }
    color_schemes['Treatment'] = treatment_colors

    if 'Region_ID' in adata.obs.columns:
        # Drop NaN before sorting: sorting a mix of labels and NaN raises TypeError.
        regions = sorted(adata.obs['Region_ID'].dropna().unique())
        # Use a colorblind-friendly palette
        palette = sns.color_palette("Set2", len(regions))
        color_schemes['Region_ID'] = {
            region: mcolors.to_hex(color) for region, color in zip(regions, palette)
        }

    return color_schemes

def apply_color_schemes(adata, color_schemes):
    """
    Store each colour scheme as ``adata.uns['<annotation>_colors']``.

    Scanpy pairs ``uns['<key>_colors']`` positionally with the categories of
    ``obs[key]``. A non-categorical column is converted by scanpy into a
    categorical with *sorted* categories, so colours are laid out in that same
    sorted order (NaN excluded). Using ``unique()`` order would misalign labels
    and colours.

    Schemes whose annotation is not an ``obs`` column are skipped. Labels
    missing from a scheme get grey (#808080).

    Returns
    -------
    AnnData
        The same object, modified in place, for chaining.
    """
    for annotation, colors in color_schemes.items():
        if annotation not in adata.obs.columns:
            continue

        column = adata.obs[annotation]
        if isinstance(column.dtype, pd.CategoricalDtype):
            categories = column.cat.categories
        else:
            # Same category order scanpy will produce when it sanitizes the column.
            categories = pd.Categorical(column).categories

        color_list = [colors.get(value, '#808080') for value in categories]
        adata.uns[f'{annotation}_colors'] = color_list
        print(f"Applied {len(color_list)} colors for {annotation}")

    return adata

def plot_color_preview(color_schemes, save_path=None):
    """
    Draw a swatch strip for every colour scheme and return the figure.

    One subplot row per scheme. Within a scheme, regular labels come first and
    "Juxtaposed" labels last. Each bar is annotated with its colour code in
    white or black depending on swatch brightness.

    Parameters
    ----------
    color_schemes : dict[str, dict]
        ``{scheme name: {label: colour}}`` as produced by
        ``create_comprehensive_color_scheme``.
    save_path : str or path-like, optional
        If given, the figure is also saved there at 300 dpi.

    Returns
    -------
    matplotlib.figure.Figure
    """
    n_schemes = len(color_schemes)
    fig, axes = plt.subplots(n_schemes, 1, figsize=(15, 3 * n_schemes))
    if n_schemes == 1:
        axes = [axes]

    for ax, (scheme_name, colors) in zip(axes, color_schemes.items()):
        # str() so non-string labels (e.g. numeric Region_IDs) work with the
        # 'Juxtaposed' substring test and as tick labels.
        items = [(str(label), color) for label, color in colors.items()]
        regular_items = [(k, v) for k, v in items if 'Juxtaposed' not in k]
        juxta_items = [(k, v) for k, v in items if 'Juxtaposed' in k]
        sorted_items = regular_items + juxta_items

        y_pos = np.arange(len(sorted_items))
        colors_list = [color for _, color in sorted_items]
        labels = [label for label, _ in sorted_items]

        ax.barh(y_pos, [1] * len(sorted_items), color=colors_list, height=0.8)

        ax.set_yticks(y_pos)
        ax.set_yticklabels(labels, fontsize=8)
        ax.set_xlabel('Color')
        ax.set_title(f'{scheme_name.replace("_", " ").title()} Color Scheme', fontweight='bold')
        ax.set_xlim(0, 1)
        ax.set_xticks([])

        for i, color in enumerate(colors_list):
            # Dark swatches (RGB sum < 1.5) get white text, light ones black.
            ax.text(0.5, i, color, ha='center', va='center', fontsize=6,
                    color='white' if sum(mcolors.to_rgb(color)) < 1.5 else 'black')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Color preview saved to {save_path}")

    plt.show()

    return fig

def setup_all_colors(adata, plot_preview=True):
    """
    Build colour schemes for ``adata``, store them in ``adata.uns`` and optionally preview them.

    Parameters
    ----------
    adata : AnnData
        Object whose ``.obs`` annotations are coloured.
    plot_preview : bool, default True
        Also draw the swatch overview with ``plot_color_preview``.

    Returns
    -------
    tuple
        ``(adata, color_schemes)``: the updated AnnData and the
        ``{annotation: {label: colour}}`` dictionaries.
    """
    print("Creating comprehensive color scheme...")
    schemes = create_comprehensive_color_scheme(adata)

    print("Applying color schemes to adata...")
    adata = apply_color_schemes(adata, schemes)

    if plot_preview:
        print("Creating color preview plot...")
        plot_color_preview(schemes)

    for message in ("Color scheme setup complete!",
                    f"Applied colors for: {list(schemes.keys())}"):
        print(message)

    return adata, schemes

In [None]:
# Build every colour scheme, write them to adata.uns['<column>_colors'] and show the swatch preview.
adata, color_schemes = setup_all_colors(adata, plot_preview=True)

# General Figures

## UMAPs

In [None]:
# Compact text sizes for the stacked overview UMAPs (the legend font size is
# overridden per call below).
plt.rcParams.update({
    "legend.fontsize": 8,
    "axes.titlesize": 12,
    "axes.labelsize": 8,
    "xtick.labelsize": 8,
    "ytick.labelsize": 8,
})

# (obs column, panel title) pairs, drawn top-to-bottom in a single column.
umap_panels = [
    ("initial_broad", "Cell Types in Control and Treated\nOmentum Metastases"),
    ("subtype", "Cell Subtypes"),
    ("subtype_granular", "Granular Subtypes"),
    ("Region_ID", "Samples"),
    ("Treatment", "Treatment"),
]
panel_keys = [key for key, _ in umap_panels]
panel_titles = [title for _, title in umap_panels]

sc.pl.umap(
    adata,
    color=panel_keys,
    title=panel_titles,
    show=False,
    legend_loc="right margin",
    legend_fontsize=10,
    legend_fontoutline=0,
    size=2,
    frameon=False,
    ncols=1,
)


In [None]:
# Every panel shows both sections side by side: Control on the left, Treated on the right.
sample_layout_title = "Control           Treated"

sq.pl.spatial_scatter(
    adata,
    color=['Treatment', "Region_ID", "initial_broad"],
    title=[sample_layout_title] * 3,
    library_id="spatial",
    shape=None,
    size=0.2,
    vmax=1.14,
    vcenter=0,
    frameon=False,
    wspace=0.5,
    ncols=3,
    figsize=(10, 10)
)

## Abundance plots

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick

# Inputs for the abundance plots in this cell. adata_sub is an alias of adata,
# not a copy; group_col is the obs column whose categories form the stacked bars.
adata_sub = adata
group_col = "initial_broad"

def first_present(df, cols):
    """Return the first name in ``cols`` that is a column of ``df``, or None if none is."""
    return next((name for name in cols if name in df.columns), None)

def style_axes(ax, y_is_percent=False, legend=True, legend_title="Subtype", palette=None, labels=None):
    """
    Apply the shared minimal style to a bar-chart Axes.

    Turns the grid off, hides the top/right spines, thickens the left/bottom
    spines and draws outward ticks. Optionally formats the y axis as a
    percentage of 1.0, and either places the legend outside the plot on the
    right or removes it.

    ``palette`` and ``labels`` are accepted for call-site compatibility but
    are not used.
    """
    ax.grid(False)
    for spine in ("top", "right"):
        ax.spines[spine].set_visible(False)
    for spine in ("left", "bottom"):
        ax.spines[spine].set_linewidth(1.1)
    ax.tick_params(axis="both", which="both", direction="out", length=4, width=1, labelsize=10)
    if y_is_percent:
        ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
    if legend:
        ax.legend(title=legend_title, bbox_to_anchor=(1.02, 1), loc="upper left",
                  frameon=False, fontsize=9, title_fontsize=9, handlelength=1.2, handletextpad=0.4)
    else:
        # get_legend() is None when no legend was drawn; ax.legend_.remove()
        # would raise AttributeError in that case.
        existing = ax.get_legend()
        if existing is not None:
            existing.remove()

def get_palette(adata_sub, group_col, cats):
    """
    Return the stored colours for ``group_col``, one per entry of ``cats``, or None.

    Reads ``adata_sub.uns['<group_col>_colors']`` and returns its first
    ``len(cats)`` entries in order. Returns None when the key is absent or
    holds fewer colours than there are categories.
    """
    stored = adata_sub.uns[f"{group_col}_colors"] if f"{group_col}_colors" in adata_sub.uns else None
    if stored is None or len(stored) < len(cats):
        return None
    return list(stored[:len(cats)])

obs = adata_sub.obs.copy()
# Replicate (x-axis) column: first conventional name present, else Region_ID.
sample_col    = first_present(obs, ["sample","sample_id","orig.ident","library","batch","donor","patient"]) or "Region_ID"
# Optional condition column; the by-condition plots are skipped when None.
condition_col = first_present(obs, ["Treatment","condition","group","stim","status"])

if group_col not in obs.columns:
    raise ValueError(f"'{group_col}' not in .obs")

if sample_col not in obs.columns:
    # No replicate column at all: treat every cell as a single sample.
    obs[sample_col] = "all"

# Category order must follow adata.uns[f"{group_col}_colors"] so bars and colours line up.
# (dtype isinstance check replaces pd.api.types.is_categorical_dtype, deprecated in pandas 2.1.)
if isinstance(obs[group_col].dtype, pd.CategoricalDtype):
    cats = list(obs[group_col].cat.categories)
else:
    cats = sorted(obs[group_col].dropna().unique().tolist())
palette = get_palette(adata_sub, group_col, cats)
legend_title = "Cell type"

groupers = [sample_col, group_col] + ([condition_col] if condition_col else [])
counts = obs.groupby(groupers, observed=True).size().reset_index(name="n")

# aggfunc="sum": counts has one row per (sample, group[, condition]), so a sample
# spanning several conditions (e.g. the "all" fallback) must be summed, not
# averaged (pivot_table's default aggfunc is "mean").
# reindex(fill_value=0): categories absent from the data become zero columns
# instead of NaN, keeping the stacked bars aligned with `palette`.
pivot_sample = (
    counts.pivot_table(index=sample_col, columns=group_col, values="n",
                       aggfunc="sum", fill_value=0, observed=True)
    .reindex(columns=cats, fill_value=0)
)

# Absolute per-sample
fig, ax = plt.subplots(figsize=(4, 4))
pivot_sample.plot(kind="bar", stacked=True, ax=ax, color=palette, edgecolor="white", linewidth=0.5, width=0.85)
ax.set_ylabel("Number of cells", fontsize=12)
ax.set_xlabel("Control     Treated", fontsize=12)
ax.set_title("Absolute abundance of Cell Types", fontsize=12)
style_axes(ax, y_is_percent=False, legend=True, legend_title=legend_title)
fig.tight_layout()
plt.show()

# Relative per-sample (samples with zero cells stay at 0 instead of NaN)
prop_sample = pivot_sample.div(pivot_sample.sum(axis=1).replace(0, np.nan), axis=0).fillna(0)
fig, ax = plt.subplots(figsize=(4, 4))
prop_sample.plot(kind="bar", stacked=True, ax=ax, color=palette, edgecolor="white", linewidth=0.5, width=0.85)
ax.set_ylabel("Relative abundance", fontsize=11)
ax.set_xlabel("")
ax.set_title(f"Relative abundance per sample ({group_col})", fontsize=12)
style_axes(ax, y_is_percent=True, legend=True, legend_title=legend_title)
fig.tight_layout()
plt.show()

if condition_col:
    pivot_cond = (
        counts.groupby([condition_col, group_col], observed=True)["n"].sum()
        .unstack(fill_value=0)
        .reindex(columns=cats, fill_value=0)
    )

    # Absolute by condition
    fig, ax = plt.subplots(figsize=(1, 4))
    pivot_cond.plot(kind="bar", stacked=True, ax=ax, color=palette, edgecolor="white", linewidth=0.5, width=0.7)
    ax.set_ylabel("Number of cells", fontsize=11)
    ax.set_xlabel("")
    ax.set_title("Absolute abundance by treatment", fontsize=12)
    style_axes(ax, y_is_percent=False, legend=True, legend_title=legend_title)
    fig.tight_layout()
    plt.show()

    # Relative by condition
    prop_cond = pivot_cond.div(pivot_cond.sum(axis=1).replace(0, np.nan), axis=0).fillna(0)
    fig, ax = plt.subplots(figsize=(1, 4))
    prop_cond.plot(kind="bar", stacked=True, ax=ax, color=palette, edgecolor="white", linewidth=0.5, width=0.7)
    ax.set_ylabel("Relative abundance", fontsize=11)
    ax.set_xlabel("")
    ax.set_title("Relative abundance by treatment", fontsize=12)
    style_axes(ax, y_is_percent=True, legend=True, legend_title=legend_title)
    fig.tight_layout()
    plt.show()

In [None]:
# Inputs for the granular abundance plots in this cell. adata_sub is an alias of
# adata, not a copy; group_col is the obs column whose categories form the stacked bars.
adata_sub = adata
group_col = "subtype_granular"

def first_present(df, cols):
    """Return the first name in ``cols`` that is a column of ``df``, or None if none is."""
    return next((name for name in cols if name in df.columns), None)

def style_axes(ax, y_is_percent=False, legend=True, legend_title="Subtype", palette=None, labels=None):
    """
    Apply the shared minimal style to a bar-chart Axes.

    Turns the grid off, hides the top/right spines, thickens the left/bottom
    spines and draws outward ticks. Optionally formats the y axis as a
    percentage of 1.0, and either places the legend outside the plot on the
    right or removes it.

    ``palette`` and ``labels`` are accepted for call-site compatibility but
    are not used.
    """
    ax.grid(False)
    for spine in ("top", "right"):
        ax.spines[spine].set_visible(False)
    for spine in ("left", "bottom"):
        ax.spines[spine].set_linewidth(1.1)
    ax.tick_params(axis="both", which="both", direction="out", length=4, width=1, labelsize=10)
    if y_is_percent:
        ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
    if legend:
        ax.legend(title=legend_title, bbox_to_anchor=(1.02, 1), loc="upper left",
                  frameon=False, fontsize=9, title_fontsize=9, handlelength=1.2, handletextpad=0.4)
    else:
        # get_legend() is None when no legend was drawn; ax.legend_.remove()
        # would raise AttributeError in that case.
        existing = ax.get_legend()
        if existing is not None:
            existing.remove()

def get_palette(adata_sub, group_col, cats):
    """
    Return the stored colours for ``group_col``, one per entry of ``cats``, or None.

    Reads ``adata_sub.uns['<group_col>_colors']`` and returns its first
    ``len(cats)`` entries in order. Returns None when the key is absent or
    holds fewer colours than there are categories.
    """
    stored = adata_sub.uns[f"{group_col}_colors"] if f"{group_col}_colors" in adata_sub.uns else None
    if stored is None or len(stored) < len(cats):
        return None
    return list(stored[:len(cats)])

obs = adata_sub.obs.copy()
# Replicate (x-axis) column: first conventional name present, else Region_ID.
sample_col    = first_present(obs, ["sample","sample_id","orig.ident","library","batch","donor","patient"]) or "Region_ID"
# Optional condition column; the by-condition plots are skipped when None.
condition_col = first_present(obs, ["Treatment","condition","group","stim","status"])

if group_col not in obs.columns:
    raise ValueError(f"'{group_col}' not in .obs")

if sample_col not in obs.columns:
    # No replicate column at all: treat every cell as a single sample.
    obs[sample_col] = "all"

# Category order must follow adata.uns[f"{group_col}_colors"] so bars and colours line up.
# (dtype isinstance check replaces pd.api.types.is_categorical_dtype, deprecated in pandas 2.1.)
if isinstance(obs[group_col].dtype, pd.CategoricalDtype):
    cats = list(obs[group_col].cat.categories)
else:
    cats = sorted(obs[group_col].dropna().unique().tolist())
palette = get_palette(adata_sub, group_col, cats)

groupers = [sample_col, group_col] + ([condition_col] if condition_col else [])
counts = obs.groupby(groupers, observed=True).size().reset_index(name="n")

# aggfunc="sum": a sample spanning several conditions must be summed, not
# averaged (pivot_table defaults to "mean"). reindex(fill_value=0) keeps empty
# categories as zero columns so bar colours stay aligned with `palette`.
pivot_sample = (
    counts.pivot_table(index=sample_col, columns=group_col, values="n",
                       aggfunc="sum", fill_value=0, observed=True)
    .reindex(columns=cats, fill_value=0)
)

# Absolute per-sample
fig, ax = plt.subplots(figsize=(4, 4))
pivot_sample.plot(kind="bar", stacked=True, ax=ax, color=palette, edgecolor="white", linewidth=0.5, width=0.85)
ax.set_ylabel("Number of cells", fontsize=11)
ax.set_xlabel("Control     Treated", fontsize=11)
ax.set_title("Absolute abundance of Cell Types", fontsize=12)
style_axes(ax, y_is_percent=False, legend=True, legend_title="Subtype")
fig.tight_layout()
plt.show()

# Relative per-sample (samples with zero cells stay at 0 instead of NaN)
prop_sample = pivot_sample.div(pivot_sample.sum(axis=1).replace(0, np.nan), axis=0).fillna(0)
fig, ax = plt.subplots(figsize=(4, 4))
prop_sample.plot(kind="bar", stacked=True, ax=ax, color=palette, edgecolor="white", linewidth=0.5, width=0.85)
ax.set_ylabel("Relative abundance", fontsize=11)
ax.set_xlabel("")
ax.set_title(f"Relative abundance per sample ({group_col})", fontsize=12)
style_axes(ax, y_is_percent=True, legend=True, legend_title="Subtype")
fig.tight_layout()
plt.show()

if condition_col:
    pivot_cond = (
        counts.groupby([condition_col, group_col], observed=True)["n"].sum()
        .unstack(fill_value=0)
        .reindex(columns=cats, fill_value=0)
    )

    # Absolute by condition
    fig, ax = plt.subplots(figsize=(1, 4))
    pivot_cond.plot(kind="bar", stacked=True, ax=ax, color=palette, edgecolor="white", linewidth=0.5, width=0.7)
    ax.set_ylabel("Number of cells", fontsize=11)
    ax.set_xlabel("")
    ax.set_title("Absolute abundance by treatment", fontsize=12)
    style_axes(ax, y_is_percent=False, legend=True, legend_title="Subtype")
    fig.tight_layout()
    plt.show()

    # Relative by condition
    prop_cond = pivot_cond.div(pivot_cond.sum(axis=1).replace(0, np.nan), axis=0).fillna(0)
    fig, ax = plt.subplots(figsize=(1, 4))
    prop_cond.plot(kind="bar", stacked=True, ax=ax, color=palette, edgecolor="white", linewidth=0.5, width=0.7)
    ax.set_ylabel("Relative abundance", fontsize=11)
    ax.set_xlabel("")
    ax.set_title("Relative abundance by treatment", fontsize=12)
    style_axes(ax, y_is_percent=True, legend=True, legend_title="Subtype")
    fig.tight_layout()
    plt.show()

## Main cell type dotplot

In [None]:
# Lineage marker genes; each dict key labels a bracket above its gene columns.
marker_set_nice = {
    'B cells': ['Cd79b', 'Ms4a1'],
    'Plasma cells': ['Igkc', 'Iglc3', 'Xbp1'],
    'Myeloid Cells': ['C1qa', 'C1qb', 'C1qc', 'Cd68', 'Aif1', 'Csf1r', 'Marco', 'Mrc1', 'Trem2', 'Clec9a', 'Xcr1', 'Cd86'],
    'T cells': ['Cd3d', 'Cd3e', 'Cd4', 'Cd8a', 'Trac', 'Tcf7'],
    'Tregs': ['Foxp3', 'Il2ra', 'Tigit', 'Ctla4'],
    'Exhausted T cells': ['Tox', 'Tigit'],
    'NK cells': ['Nkg7', 'Klrb1c', 'Prf1', 'Gzmb', 'Ifng'],
    'Tumor epithelial': ['Krt8', 'Krt18', 'Krt19', 'Krt7'],
    'Fibroblasts': ['Col1a1', 'Col3a1', 'Acta2', "Lrrc15"],
    'Endothelial': ['Cdh5', 'Pecam1', 'Plvap', 'Flt1', 'Kdr'],
}

broad_dotplot_kwargs = dict(
    groupby="initial_broad",
    standard_scale="var",   # scale each gene to 0-1 across the groups
    dendrogram=True,
    swap_axes=False,
    figsize=(13, 2),
)
sc.pl.dotplot(adata, var_names=marker_set_nice, **broad_dotplot_kwargs)

# Create Mac Data

In [None]:
# Myeloid compartment first, then its "Macrophages" subtype; both are independent copies.
is_myeloid = adata.obs['initial_broad'].isin(["Myeloid Cells"])
myeloid_adata = adata[is_myeloid].copy()

is_macrophage = myeloid_adata.obs['subtype'].isin(["Macrophages"])
mac_adata = myeloid_adata[is_macrophage].copy()

In [None]:
# neighbors/UMAP once
# Recompute the neighbour graph for the macrophage subset from obsm['X_scVI']
# (presumably the scVI latent space), then a UMAP with a fixed random_state for
# a reproducible layout.
sc.pp.neighbors(mac_adata, use_rep='X_scVI')
sc.tl.umap(mac_adata, min_dist=0.3, random_state=0)

# Macrophage Figures

## Create dicts

In [None]:
# Marker-gene sets for macrophage state scores. Each key becomes a score label
# (the scoring cell writes obs columns named f"{key}_score"). Some genes appear
# in more than one set.
mac_state_dict = {
    "Immunosuppressive Macrophages": [
        "Arg1", "Cd274", "Folr2", "Hmox1", "Il10", "Igf1", 
        "Mertk", "Mrc1", "Tgfb1", "Trem2", "Il4"
    ],
    "Pro-inflammatory Macrophages": [
        "Cd14", "Cd40", "Cd80", "Cd86", "Cybb", "Epas1", 
        "Il1b", "Il6", "Ccr7"
    ],
    "Myeloid Regulatory (Immunosuppression)": [
        "Arg1", "Cd274", "Il4", "Il10"
    ],
    "Myeloid Regulatory (Type I Inflammation)": [ # Il1b subtype
        "Il1b", "Il6", "Tnf"
    ],
    # Larger IL-12 response gene set; genes missing from the panel are reported and skipped when scoring.
    'Il12_response' : ['Cxcl10','Ccl12', 'Iigp1', 'Serpina3g', 'Cxcl9', 'Gbp7', 'Ifi47',
                 'Fam26f', 'Pnp', 'Serpina3f', 'Ifi203', 'Irf1', 'Irgm1', 'Ccl2', 'Ifi204',
                 'Ifi211', 'Gbp2', 'Irf7', 'Stat1', 'Igtp', 'Themis2', 'Ifit2', 'Gbp5',
                 'Zbp1', 'Socs1', 'Eif2ak2', 'Ifit1','Irgm2']
}

In [None]:
# Variant of mac_state_dict with a condensed five-gene IL-12 set and a
# space-separated key. The scoring loop in this notebook iterates
# mac_state_dict, not this dict.
mac_state_dict_2 = {
    "Immunosuppressive Macrophages": [
        "Arg1", "Cd274", "Folr2", "Hmox1", "Il10", "Igf1", 
        "Mertk", "Mrc1", "Tgfb1", "Trem2", "Il4"
    ],
    "Pro-inflammatory Macrophages": [
        "Cd14", "Cd40", "Cd80", "Cd86", "Cybb", "Epas1", 
        "Il1b", "Il6", "Ccr7"
    ],
    "Myeloid Regulatory (Immunosuppression)": [
        "Arg1", "Cd274", "Il4", "Il10"
    ],
    "Myeloid Regulatory (Type I Inflammation)": [ # Il1b subtype
        "Il1b", "Il6", "Tnf"
    ],
    'Il12 response' : ['Cxcl10', 'Cxcl9', 'Ccl2','Irf7', 'Stat1']
}

In [None]:
import matplotlib as mpl  # required: mpl.rcParams is used below
import matplotlib.pyplot as plt
import scanpy as sc

# Grid and a closed frame for this dot plot. These rcParams are global and
# stay in effect for later figures until reset.
mpl.rcParams["axes.grid"] = True
mpl.rcParams["axes.spines.top"] = True
mpl.rcParams["axes.spines.right"] = True

# Macrophage marker modules I-VI (bracket labels above the genes), one row per
# granular macrophage subtype, each gene scaled 0-1 across subtypes.
dp = sc.pl.dotplot(
    mac_adata,
    var_names={
        "I": ["Folr2","Igf1", "Mrc1", "Tgfb1", "Trem2", "Il10",  "Il4"],
        "II": ["Arg1", "Cd274"],
        'III' : ["Il1b", "Tnf"],
        "IV": ['Cxcl10', 'Cxcl9', 'Stat1',  'Ccl2'],
        "V": ["Cd14", "Cd40", "Cd80", "Cd86", "Cybb", "Epas1"],
        "VI": ['Il18', 'Ccl5', "Il18bp", "Ctsd", "Il15ra","Aif1", "Tlr1", "Vcam1",
               "Echs1", "H2-Ab1", 'H2-K1', 'B2m', 'Tap1', 'Psmb8', 'Psmb9']
    },
    groupby="subtype_granular",
    standard_scale="var",
    cmap="Reds",
    swap_axes=False,
    dendrogram=True,
    return_fig=False
)




## Scores

In [None]:
def present_genes(mac_adata, genes):
    """
    Split ``genes`` into names found in ``mac_adata.var_names`` (case-insensitive) and names not found.

    Returns
    -------
    tuple[list, list]
        ``(keep, missing)``. ``keep`` holds matches spelled as in
        ``var_names``; ``missing`` holds the requested names with no match.
        Input order is preserved in both.
    """
    by_lower = {}
    for name in mac_adata.var_names:
        by_lower[name.lower()] = name  # later duplicates (by case) win

    keep = [by_lower[g.lower()] for g in genes if g.lower() in by_lower]
    missing = [g for g in genes if g.lower() not in by_lower]
    return keep, missing

# Score from .raw when it exists. Look genes up in that same matrix so a gene
# present in .var but absent from .raw is reported as missing instead of making
# score_genes fail.
use_raw = mac_adata.raw is not None
gene_source = mac_adata.raw if use_raw else mac_adata
score_cols, missing_report = [], {}

for label, genes in mac_state_dict.items():
    keep, missing = present_genes(gene_source, genes)
    if missing:
        missing_report[label] = missing
    if len(keep) == 0:
        print(f"[WARN] No genes present for {label}; skipping.")
        continue
    # Writes mac_adata.obs[f"{label}_score"]; fixed random_state for reproducible control genes.
    sc.tl.score_genes(
        mac_adata,
        gene_list=keep,
        ctrl_size=50,
        score_name=f"{label}_score",
        use_raw=use_raw,
        random_state=0
    )
    score_cols.append(f"{label}_score")

if missing_report:
    print("Missing genes (ignored):")
    for k, v in missing_report.items():
        print(f"  {k}: {', '.join(v)}")

## Colors and UMAPs

In [None]:
import matplotlib as mpl
mpl.rcParams.update({
    "figure.dpi": 300,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.titlesize": 20,
    "axes.labelsize": 14,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11,
    "xtick.direction": "out",
    "ytick.direction": "out",
    "axes.grid": False,  # ensure no grid globally
})

# One panel per gene, each titled with its gene name. The title list must have
# one entry per colour key, otherwise scanpy warns and falls back for the rest.
genes_to_plot = ["Cd80", "Cd40", "Il1b", "Folr2", "Cxcl9", "Arg1"]

sc.pl.umap(
    mac_adata,
    color=genes_to_plot,
    show=False,
    legend_loc="right margin",
    legend_fontsize=10,
    legend_fontoutline=0,
    cmap="bwr",         # used for continuous only
    title=genes_to_plot,
    vmax="p99.5",       # clip each colour scale at the 99.5th percentile
    size=2,
    frameon=False,
    ncols=1,
)


In [None]:
# Publication styling: no grid, open top/right frame, outward ticks.
umap_rc = {
    "figure.dpi": 300,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.titlesize": 20,
    "axes.labelsize": 14,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11,
    "xtick.direction": "out",
    "ytick.direction": "out",
    "axes.grid": False,
}
plt.rcParams.update(umap_rc)

umap_kwargs = dict(
    show=False,
    legend_loc="right margin",
    legend_fontsize=10,
    legend_fontoutline=0,
    cmap="coolwarm",    # only affects continuous panels
    vmax="p99.5",
    size=2,
    frameon=False,
    ncols=1,
)
sc.pl.umap(
    mac_adata,
    color=["subtype_granular", "Treatment"],
    title=["Macrophages", "Treatment"],
    **umap_kwargs,
)


In [None]:
# 1) Copy score columns under display-friendly names, recording which aliases
#    were actually created.
rename_map = {
    "Immunosuppressive Macrophages_score": "Immunosuppressive (Mφ) score",
    "Pro-inflammatory Macrophages_score":  "Inflammatory (Mφ) score",
    "Il12_response_score":                 "IL-12 response"
}
available_aliases = []
for old, new in rename_map.items():
    if old in mac_adata.obs:
        mac_adata.obs[new] = mac_adata.obs[old].values
        available_aliases.append(new)
    else:
        print(f"Missing: {old}")

# 2) Plot only the aliases that exist; passing a missing obs key would raise.
if available_aliases:
    sc.pl.stacked_violin(
        mac_adata,
        available_aliases,
        groupby="Treatment",
        swap_axes=True,
        dendrogram=True,
        cmap="coolwarm",   # or your warm custom cmap
        show=False,
        figsize=(3, 2),
    )


In [None]:
# Folr2 in macrophages on the bwr colormap centred at 0 (Control left, Treated right).
folr2_panel = "Folr2"

sq.pl.spatial_scatter(
    mac_adata,
    color=[folr2_panel],
    title=[f"{folr2_panel}\n\nControl           Treated"],
    library_id="spatial",
    shape=None,
    palette="Set2",
    cmap="bwr",
    vcenter=0,
    size=0.2,
    frameon=False,
    wspace=0.5,
    ncols=3,
    figsize=(10, 10),
)

In [None]:
import matplotlib.colors as colors  # `colors` is used below and must be in scope

# Diverging colormap: bwr's blue end -> white at the midpoint -> bwr's red end.
grey_bwr = colors.LinearSegmentedColormap.from_list(
    "grey_bwr",
    [(0.0, plt.cm.bwr(0.0)),
     (0.5, "#ffffff"),
     (1.0, plt.cm.bwr(1.0))],
    N=256
)

# Shared colour limits, passed to both the spatial plot and the standalone
# colorbar so the colorbar describes exactly the mapping used in the plot.
immsup_vmin, immsup_vcenter, immsup_vmax = -0.25, 0, 0.4

sq.pl.spatial_scatter(
    mac_adata,
    library_id="spatial",
    shape=None,
    color=['Immunosuppressive Macrophages_score'],
    palette="Set2",
    wspace=0.5,
    ncols=3,
    size=0.2,
    frameon=False,
    title=["Immunosuppressive Macrophage Score\n\nControl           Treated"],
    cmap=grey_bwr,
    vmin=immsup_vmin,
    vcenter=immsup_vcenter,
    vmax=immsup_vmax,
    figsize=(10, 10)
)

norm = colors.TwoSlopeNorm(vmin=immsup_vmin, vcenter=immsup_vcenter, vmax=immsup_vmax)
sm = plt.cm.ScalarMappable(cmap=grey_bwr, norm=norm)
sm.set_array([])

# Standalone vertical colorbar for the panel above.
fig_cb, ax_cb = plt.subplots(figsize=(0.5, 2.0), dpi=300)
cbar = plt.colorbar(sm, cax=ax_cb, orientation='vertical')
cbar.ax.tick_params(labelsize=15, length=3, width=0.8)
# pyplot.show() does not take a figure argument, so show normally and then
# close only the colorbar figure.
plt.show()
plt.close(fig_cb)
plt.close(fig_cb)



In [None]:
import matplotlib.colors as colors  # `colors` is used below and must be in scope

# Shared colour limits for the plot and its standalone colorbar, so they match.
il12_vmin, il12_vcenter, il12_vmax = -0.4, 0, 0.5

sq.pl.spatial_scatter(
    mac_adata,
    library_id="spatial",
    shape=None,
    color=['Il12_response_score'],
    palette="Set2",
    wspace=0.5,
    ncols=3,
    size=0.2,
    frameon=False,
    title=["Response to Il12\n\nControl           Treated"],
    cmap=grey_bwr,
    vmin=il12_vmin,
    vcenter=il12_vcenter,
    vmax=il12_vmax,
    figsize=(10, 10)
)

norm = colors.TwoSlopeNorm(vmin=il12_vmin, vcenter=il12_vcenter, vmax=il12_vmax)
sm = plt.cm.ScalarMappable(cmap=grey_bwr, norm=norm)
sm.set_array([])

# Standalone vertical colorbar for the panel above.
fig_cb, ax_cb = plt.subplots(figsize=(0.5, 2.0), dpi=300)
cbar = plt.colorbar(sm, cax=ax_cb, orientation='vertical')
cbar.ax.tick_params(labelsize=15, length=3, width=0.8)
# pyplot.show() does not take a figure argument; close only the colorbar figure afterwards.
plt.show()
plt.close(fig_cb)


In [None]:
# obs/var key -> display name. Every title adds the Control/Treated layout label.
panel_names = {
    'Il12_response_score': "Response to Il12",
    'Arg1': "Arg1",
    'Cxcl9': "Cxcl9",
}

sq.pl.spatial_scatter(
    mac_adata,
    color=list(panel_names),
    title=[f"{name}\n\nControl           Treated" for name in panel_names.values()],
    library_id="spatial",
    shape=None,
    palette="Set2",
    cmap="bwr",
    vcenter=0,
    size=0.2,
    frameon=False,
    wspace=0.5,
    ncols=3,
    figsize=(10, 10),
)

In [None]:
import matplotlib.colors as colors  # `colors` is used below and must be in scope

# Shared linear limits for the Arg1 plot and its standalone colorbar.
arg1_vmin, arg1_vmax = 0, 1.67

sq.pl.spatial_scatter(
    mac_adata,
    library_id="spatial",
    shape=None,
    color=['Arg1'],
    palette="Set2",
    wspace=0.5,
    ncols=3,
    size=0.2,
    frameon=False,
    title=["Arg1\n\nControl           Treated"],
    cmap="Reds",
    vmin=arg1_vmin,
    vmax=arg1_vmax,
    figsize=(10, 10)
)

# Linear Normalize, not TwoSlopeNorm: "Reds" is sequential and the plot above
# uses a linear scale, so the colorbar must use the same mapping.
norm = colors.Normalize(vmin=arg1_vmin, vmax=arg1_vmax)
sm = plt.cm.ScalarMappable(cmap="Reds", norm=norm)
sm.set_array([])

# Standalone vertical colorbar for the panel above.
fig_cb, ax_cb = plt.subplots(figsize=(0.5, 2.0), dpi=300)
cbar = plt.colorbar(sm, cax=ax_cb, orientation='vertical')
cbar.ax.tick_params(labelsize=15, length=3, width=0.8)

# pyplot.show() does not take a figure argument; close only the colorbar figure afterwards.
plt.show()
plt.close(fig_cb)


In [None]:
# Sequential "Blues" on a plain linear scale, matching the Arg1/"Reds" panel.
# vcenter=0 is omitted: for values that are presumably >= 0, it makes matplotlib's
# TwoSlopeNorm expand the range symmetrically around 0, so cells with zero
# expression would be drawn mid-blue instead of at the light end.
sq.pl.spatial_scatter(
    mac_adata,
    library_id="spatial",
    shape=None,
    color=['Cxcl9'],
    palette="Set2",
    wspace=0.5,
    ncols=3,
    size=0.2,
    frameon=False,
    title=["Cxcl9\n\nControl           Treated"],
    cmap="Blues",
    figsize=(10, 10)
)

In [None]:
order = [
    "Folr2- Arg+ Macs",
    "Folr2 High Macs",
    "Folr2 Low Macs",
    "Il1b+ Macs",
    "Cxcl9+ Macs"]

granular = mac_adata.obs["subtype_granular"].astype("category")
# Fail loudly rather than lose labels: set_categories turns values outside
# `order` into NaN without warning.
unexpected = set(granular.dropna().unique()) - set(order)
if unexpected:
    raise ValueError(f"subtype_granular labels missing from `order`: {sorted(unexpected)}")
# set_categories (unlike reorder_categories) tolerates unused categories
# inherited from the parent object and ordered categories with no cells.
mac_adata.obs["subtype_granular"] = granular.cat.set_categories(order, ordered=True)

# Scanpy/squidpy pair uns['subtype_granular_colors'] with categories by position,
# so rebuild the list in the new order from the notebook-level colour schemes
# (grey for any label without a colour).
granular_palette = color_schemes.get("subtype_granular", {})
mac_adata.uns["subtype_granular_colors"] = [granular_palette.get(c, "#808080") for c in order]

sq.pl.spatial_scatter(
    mac_adata,
    library_id="spatial",
    shape=None,
    color='subtype_granular',
    wspace=0.5,
    ncols=3,
    size=0.2,
    frameon=False,
    title=["Macrophage Subsets\n\nControl           Treated"],
    cmap="bwr",
    vcenter=0,
    figsize=(10, 10)
)





## Abundance plots

In [None]:
adata_sub = mac_adata
group_col = "subtype_granular"


# ---------- data prep ----------
obs = adata_sub.obs.copy()
sample_col    = "Region_ID"
condition_col = "Treatment"

if group_col not in obs.columns:
    raise ValueError(f"'{group_col}' not in .obs")

# Category order + palette (order must match adata_sub.uns[f"{group_col}_colors"]).
# (dtype isinstance check replaces pd.api.types.is_categorical_dtype, deprecated in pandas 2.1.)
if isinstance(obs[group_col].dtype, pd.CategoricalDtype):
    cats = list(obs[group_col].cat.categories)
else:
    cats = sorted(obs[group_col].dropna().unique().tolist())
palette = get_palette(adata_sub, group_col, cats)

# Counts table (per-sample, optionally per-condition)
groupers = [sample_col, group_col] + ([condition_col] if condition_col else [])
counts = obs.groupby(groupers, observed=True).size().reset_index(name="n")

# aggfunc="sum": a sample spanning several conditions must be summed, not
# averaged (pivot_table defaults to "mean"). reindex(fill_value=0) keeps empty
# categories as zero columns so bar colours stay aligned with `palette`.
pivot_sample = (
    counts.pivot_table(index=sample_col, columns=group_col, values="n",
                       aggfunc="sum", fill_value=0, observed=True)
    .reindex(columns=cats, fill_value=0)
)

# Absolute per-sample
fig, ax = plt.subplots(figsize=(4, 3))
pivot_sample.plot(kind="bar", stacked=True, ax=ax, color=palette, edgecolor="white", linewidth=0.5, width=0.85)
ax.set_ylabel("Number of cells", fontsize=11)
ax.set_xlabel("Control     Treated", fontsize=11)
ax.set_title("Absolute abundance of Macrophages", fontsize=12)
style_axes(ax, y_is_percent=False, legend=True, legend_title="Subtype")
fig.tight_layout()
plt.show()

# Relative per-sample (samples with zero cells stay at 0 instead of NaN)
prop_sample = pivot_sample.div(pivot_sample.sum(axis=1).replace(0, np.nan), axis=0).fillna(0)
fig, ax = plt.subplots(figsize=(4, 3))
prop_sample.plot(kind="bar", stacked=True, ax=ax, color=palette, edgecolor="white", linewidth=0.5, width=0.85)
ax.set_ylabel("Relative abundance", fontsize=11)
ax.set_xlabel("")
ax.set_title(f"Relative abundance per sample ({group_col})", fontsize=12)
style_axes(ax, y_is_percent=True, legend=True, legend_title="Subtype")
fig.tight_layout()
plt.show()

if condition_col:
    pivot_cond = (
        counts.groupby([condition_col, group_col], observed=True)["n"].sum()
        .unstack(fill_value=0)
        .reindex(columns=cats, fill_value=0)
    )

    # Absolute by condition
    fig, ax = plt.subplots(figsize=(1, 4))
    pivot_cond.plot(kind="bar", stacked=True, ax=ax, color=palette, edgecolor="white", linewidth=0.5, width=0.7)
    ax.set_ylabel("Number of cells", fontsize=11)
    ax.set_xlabel("")
    ax.set_title(f"Absolute abundance by condition ({group_col})", fontsize=12)
    style_axes(ax, y_is_percent=False, legend=True, legend_title="Subtype")
    fig.tight_layout()
    plt.show()

    # Relative by condition
    prop_cond = pivot_cond.div(pivot_cond.sum(axis=1).replace(0, np.nan), axis=0).fillna(0)
    fig, ax = plt.subplots(figsize=(1, 4))
    prop_cond.plot(kind="bar", stacked=True, ax=ax, color=palette, edgecolor="white", linewidth=0.5, width=0.7)
    ax.set_ylabel("Relative abundance", fontsize=11)
    ax.set_xlabel("")
    ax.set_title(f"Relative abundance by condition ({group_col})", fontsize=12)
    style_axes(ax, y_is_percent=True, legend=True, legend_title="Subtype")
    fig.tight_layout()
    plt.show()



# T and NK

## Setup

In [None]:
# Independent copy (not a view) of the cells whose initial_broad label is "T and NK Cells".
tnk_adata = adata[adata.obs['initial_broad'].isin(["T and NK Cells"])].copy()

In [None]:
# Neighbour graph for the T/NK subset from obsm['X_scVI'] (presumably the scVI
# latent space), then a UMAP with a fixed random_state for a reproducible layout.
sc.pp.neighbors(tnk_adata, use_rep='X_scVI')
sc.tl.umap(tnk_adata, min_dist=0.3, random_state=0)

## % Gzmb positive

In [None]:
import numpy as np
import pandas as pd
from scipy.stats import ttest_ind

def summarize_marker_positivity(
    adata,
    gene_candidates,
    group_key,
    scope_key=None,
    pos_threshold=0.0,
    label=None,
    to_csv_path=None,
    do_ttests=False,
    treatment_key="Treatment",
    replicate_key="Region_ID",
    return_rep_df=True
):
    """Summarize marker positivity (counts, fractions, mean expression) per group.

    A cell is called positive when its expression of the selected gene is
    strictly greater than ``pos_threshold``. Statistics are reported for every
    value of ``group_key`` plus a pooled ``"ALL"`` pseudo-group, within the
    ``"All cells"`` scope and, when ``scope_key`` is given, within each value
    of that column (``scope_kind == "Subtype"``).

    Parameters
    ----------
    adata : AnnData-like
        Needs ``.obs`` (a DataFrame); expression is read via ``pick_gene`` and
        ``get_expr``.
    gene_candidates : list of str
        Candidate gene names; the first one found by ``pick_gene`` is used.
    group_key : str
        ``adata.obs`` column defining the groups.
    scope_key : str, optional
        Extra ``adata.obs`` column to additionally stratify by.
    pos_threshold : float
        Positivity cutoff (strict ``>``).
    label : str, optional
        Prefix for the +/- label column and the default CSV file name;
        defaults to the gene name.
    to_csv_path : str, optional
        Output path for the summary table; defaults to
        ``"<label>_positivity_summary.csv"``.
    do_ttests : bool
        If True, compute the positive fraction per replicate and run a Welch
        t-test between two treatment arms for every (scope, group).
    treatment_key, replicate_key : str
        ``adata.obs`` columns used by the t-tests.
    return_rep_df : bool
        With ``do_ttests``, also return the per-replicate table.

    Returns
    -------
    pandas.DataFrame or tuple
        ``summary`` if ``do_ttests`` is False, otherwise
        ``(summary, ttest_df, rep_pos)`` or ``(summary, ttest_df)``.

    Raises
    ------
    KeyError
        If a required ``adata.obs`` column is missing.
    """
    gene = pick_gene(adata, gene_candidates)
    expr = np.asarray(get_expr(adata, gene), dtype=float)
    pos = expr > float(pos_threshold)

    obs = adata.obs  # read-only below, no copy needed
    if group_key not in obs.columns:
        raise KeyError(f"Column '{group_key}' not found in adata.obs")

    if do_ttests:
        for col in (treatment_key, replicate_key):
            if col not in obs.columns:
                raise KeyError(f"Column '{col}' not found in adata.obs")

    if scope_key is not None and scope_key not in obs.columns:
        raise KeyError(f"Column '{scope_key}' not found in adata.obs")

    extra_cols = [treatment_key, replicate_key] if do_ttests else []

    def _scoped_frame(scope_kind, scope_values):
        # One row per cell, tagged with the scope it is summarised in.
        frame = pd.DataFrame({
            group_key:    obs[group_key].astype(str).values,
            "scope_kind": scope_kind,
            "scope":      scope_values,
            "expr":       expr,
            "POS":        pos,
        })
        for col in extra_cols:
            frame[col] = obs[col].astype(str).values
        return frame

    frames = [_scoped_frame("All cells", "ALL")]
    if scope_key is not None:
        frames.append(_scoped_frame("Subtype", obs[scope_key].astype(str).values))
    df_scoped = pd.concat(frames, ignore_index=True)

    scope_keys = ["scope_kind", "scope"]
    group_keys = scope_keys + [group_key]

    def _count(keys, name, frame=None):
        # Number of cells per combination of `keys`, in a column called `name`.
        frame = df_scoped if frame is None else frame
        return frame.groupby(keys, as_index=False).agg(**{name: ("expr", "size")})

    def _mean(keys, name):
        # Mean expression per combination of `keys`, in a column called `name`.
        return df_scoped.groupby(keys, as_index=False).agg(**{name: ("expr", "mean")})

    def _with_all(by_group, by_scope):
        # Append scope-level rows as the pooled pseudo-group "ALL".
        return pd.concat([by_group, by_scope.assign(**{group_key: "ALL"})], ignore_index=True)

    counts_all = _with_all(
        _count(group_keys + ["POS"], "n_cells"),
        _count(scope_keys + ["POS"], "n_cells"),
    )
    scope_totals = _count(scope_keys, "scope_total_n")
    # The "ALL" pseudo-group's total is the scope total. Without these rows the
    # left merge below leaves group_total_n / frac_in_group_% NaN for "ALL".
    group_totals = _with_all(
        _count(group_keys, "group_total_n"),
        scope_totals.rename(columns={"scope_total_n": "group_total_n"}),
    )
    means_pos_all = _with_all(
        _mean(group_keys + ["POS"], "mean_expr"),
        _mean(scope_keys + ["POS"], "mean_expr"),
    )
    means_scope_all = _mean(scope_keys, "mean_expr_scope_all")
    means_group_all = _with_all(
        _mean(group_keys, "mean_expr_group_all"),
        means_scope_all.rename(columns={"mean_expr_scope_all": "mean_expr_group_all"}),
    )

    out = (counts_all
           .merge(group_totals,    on=group_keys,           how="left")
           .merge(scope_totals,    on=scope_keys,           how="left")
           .merge(means_pos_all,   on=group_keys + ["POS"], how="left")
           .merge(means_group_all, on=group_keys,           how="left")
           .merge(means_scope_all, on=scope_keys,           how="left")
    )

    base_label = (label or gene)
    pos_col = f"{base_label}_label"

    out[pos_col] = pd.Categorical(
        np.where(out["POS"], f"{base_label}+", f"{base_label}-"),
        categories=[f"{base_label}+", f"{base_label}-"],
        ordered=True,
    )
    out["frac_in_group_%"] = (out["n_cells"] / out["group_total_n"] * 100).round(2)
    out["frac_in_scope_%"] = (out["n_cells"] / out["scope_total_n"] * 100).round(2)

    out = (
        out[["scope_kind", "scope", group_key, pos_col, "n_cells",
             "group_total_n", "scope_total_n",
             "frac_in_group_%", "frac_in_scope_%",
             "mean_expr", "mean_expr_group_all", "mean_expr_scope_all"]]
        .sort_values(by=["scope_kind", "scope", group_key, pos_col])
        .reset_index(drop=True)
    )

    if to_csv_path is None:
        safe = "".join(c if c.isalnum() or c in ("-", "_") else "_" for c in (label or gene))
        to_csv_path = f"{safe}_positivity_summary.csv"
    out.to_csv(to_csv_path, index=False)

    if not do_ttests:
        print(f"Positivity summary for `{gene}` (threshold > {pos_threshold}) — saved to: {to_csv_path}")
        return out

    # ---- per-replicate positive fractions + Welch t-tests ----
    dims = scope_keys + [group_key, replicate_key, treatment_key]
    rep_pos = _count(dims, "group_total_n").merge(
        _count(dims, "n_pos", frame=df_scoped[df_scoped["POS"]]),
        on=dims,
        how="left",
    )
    # Replicates without any positive cell have no n_pos row -> 0 positives.
    rep_pos["n_pos"] = rep_pos["n_pos"].fillna(0).astype(float)
    rep_pos["frac_in_group"] = rep_pos["n_pos"] / rep_pos["group_total_n"]

    rows = []
    for (scope_kind_i, scope_i, group_i), grp in rep_pos.groupby(group_keys):
        treatments = sorted(grp[treatment_key].unique())

        if {"Control", "Treated"}.issubset(treatments):
            g1, g2 = "Control", "Treated"
        elif len(treatments) == 2:
            g1, g2 = treatments
        else:
            # Not a two-arm comparison: no test for this subgroup.
            continue

        vals1 = grp.loc[grp[treatment_key] == g1, "frac_in_group"].to_numpy()
        vals2 = grp.loc[grp[treatment_key] == g2, "frac_in_group"].to_numpy()

        if len(vals1) < 2 or len(vals2) < 2:
            # Welch's test needs a sample variance in each arm (>= 2 replicates).
            # Report NaN explicitly instead of going through scipy's
            # degenerate-variance path.
            t_stat = p_val = np.nan
        else:
            res = ttest_ind(vals1, vals2, equal_var=False)
            t_stat, p_val = res.statistic, res.pvalue

        rows.append({
            "scope_kind": scope_kind_i,
            "scope": scope_i,
            group_key: group_i,
            "treatment_A": g1,
            "treatment_B": g2,
            "n_A": len(vals1),
            "n_B": len(vals2),
            "mean_A_frac": np.mean(vals1) if len(vals1) else np.nan,
            "mean_B_frac": np.mean(vals2) if len(vals2) else np.nan,
            "median_A_frac": np.median(vals1) if len(vals1) else np.nan,
            "median_B_frac": np.median(vals2) if len(vals2) else np.nan,
            "t_stat": t_stat,
            "p_value": p_val,
        })

    ttest_df = pd.DataFrame(rows)

    print(f"Positivity summary for `{gene}` (threshold > {pos_threshold}) — saved to: {to_csv_path}")
    if return_rep_df:
        print("Returned (summary, ttest_df, rep_pos) with Control vs Treated t-tests per subgroup.")
        return out, ttest_df, rep_pos
    print("Returned (summary, ttest_df) with Control vs Treated t-tests per subgroup.")
    return out, ttest_df


In [None]:
granzyme_genes = ["Gzmb"]

# Gzmb+ fractions per T/NK subtype, with per-replicate treatment t-tests.
granzyme_summary, granzyme_ttests, granzyme_rep = summarize_marker_positivity(
    adata=tnk_adata,
    gene_candidates=granzyme_genes,
    group_key="subtype_granular",
    scope_key=None,
    pos_threshold=0,
    label="Granzyme",
    to_csv_path="granzyme_positivity_by_celltype.csv",
    do_ttests=True,
    treatment_key="Treatment",
    replicate_key="Region_ID",
    return_rep_df=True,
)

# Per-replicate Gzmb+ fractions for one subtype (the values its t-test compares).
cell = "CD8 Activated T cells"
dots = granzyme_rep.loc[
    (granzyme_rep["scope_kind"] == "All cells")
    & (granzyme_rep["scope"] == "ALL")
    & (granzyme_rep["subtype_granular"] == cell),
    ["Region_ID", "Treatment", "frac_in_group"],
]

print(f"\n{dots}")


## UMAPs

In [None]:
# Overview UMAPs of the T/NK subset: subtype, sample of origin, treatment arm.
sc.pl.umap(
    tnk_adata,
    color=["subtype_granular", "Region_ID", "Treatment"],
    title=["T and NK Cell Subtypes", "Samples", "Treatment"],
    ncols=1,
    size=4,
    frameon=False,
    legend_loc="right margin",
    legend_fontsize=10,
    legend_fontoutline=0,
    show=False,
)

In [None]:
# Larger base font for the marker-gene UMAP grid. set_figure_params resets
# every argument that is not passed (e.g. dpi back to 80, dpi_save to 150), so
# the notebook-wide settings from the first settings cell are repeated here.
sc.set_figure_params(
    dpi=200,
    dpi_save=400,
    fontsize=25,  # bumps title & text sizes
    facecolor="white",
    vector_friendly=True,
)
sc.pl.umap(
    tnk_adata,
    color=["Cd8a", "Cd4", "Foxp3", "Ccr7", "Gzmb", "Klrb1c"],
    show=False,
    legend_loc="right margin",
    legend_fontsize=10,
    legend_fontoutline=0,
    size=4,
    vmax='p99',  # clip the colour scale at each gene's 99th percentile
    cmap='Reds',
    frameon=False,
    ncols=3)

## Dotplot

In [None]:
# Marker panel for the T/NK dot plot, keyed by the population each gene set marks.
tnk_dict = {
    "CD8 T Cells": ["Cd8a"],
    "Proliferating T": ["Mki67", "Top2a"],
    "Activated T Cells": ["Ifng", "Gzmb", "Cd38", "Prf1", "Ccl5"],
    "NK Cells": ["Nkg7", "Klrk1", "Klrb1c"],
    "Memory-like T": ["Ccr7", "Tcf7", "Sell", "Il7r", "Cd44"],
    "CD4 T Cells": ["Cd4"],
    "Tregs": ["Foxp3", "Il2ra", "Ctla4", "Tnfrsf18"],
}

# Dendrogram of subtypes on the scVI space, computed first so that
# dendrogram=True in the dot plot can pick it up.
sc.tl.dendrogram(tnk_adata, groupby="subtype", use_rep="X_scVI")
sc.pl.dotplot(
    tnk_adata,
    var_names=tnk_dict,
    groupby="subtype",
    dendrogram=True,
    standard_scale="var",
    var_group_rotation=45,
    swap_axes=False,
    figsize=(9, 2),
)

In [None]:
# Memory-associated markers only; the subtype dendrogram is computed first so
# that dendrogram=True in the dot plot can pick it up.
sc.tl.dendrogram(tnk_adata, groupby="subtype", use_rep="X_scVI")
memory_markers = ["Tcf7", "Cd44", "Il7r", "Sell"]
sc.pl.dotplot(
    tnk_adata,
    var_names=memory_markers,
    groupby="subtype",
    dendrogram=True,
    standard_scale="var",
    swap_axes=False,
    figsize=(4, 2.5),
)

## Abundance plots

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick

adata_sub = tnk_adata
group_col = "subtype_granular"

obs = adata_sub.obs.copy()
sample_col    = "Region_ID"
condition_col = "Treatment"

if group_col not in obs.columns:
    raise ValueError(f"'{group_col}' not in .obs")

if sample_col not in obs.columns:
    # No sample annotation: put every cell into a single pseudo-sample.
    obs[sample_col] = "all"

# Stacking order: categorical order if available, otherwise sorted labels.
# (isinstance(..., pd.CategoricalDtype) replaces the deprecated
# pd.api.types.is_categorical_dtype.)
if isinstance(obs[group_col].dtype, pd.CategoricalDtype):
    cats = list(obs[group_col].cat.categories)
else:
    cats = sorted(obs[group_col].dropna().unique().tolist())
palette = get_palette(adata_sub, group_col, cats)

groupers = [sample_col, group_col] + ([condition_col] if condition_col else [])
counts = obs.groupby(groupers, observed=True).size().reset_index(name="n")

# Cells per (sample, subtype).
# - aggfunc="sum": `counts` may hold several rows per (sample, subtype), one
#   per condition, and pivot_table's default ("mean") would average them.
# - Subtypes with no cells become zero columns rather than NaN.
pivot_sample = (
    counts.pivot_table(
        index=sample_col, columns=group_col, values="n",
        aggfunc="sum", fill_value=0, observed=True,
    )
    .reindex(columns=cats, fill_value=0)
)

# Absolute per-sample
# NOTE(review): the x-axis label assumes exactly two samples ordered
# Control then Treated -- confirm against pivot_sample.index.
fig, ax = plt.subplots(figsize=(2, 4))
pivot_sample.plot(kind="bar", stacked=True, ax=ax, color=palette, edgecolor="white", linewidth=0.5, width=0.85)
ax.set_ylabel("Number of cells", fontsize=12)
ax.set_xlabel("Control     Treated", fontsize=12)
ax.set_title("Absolute Abundance", fontsize=12)
style_axes(ax, y_is_percent=False, legend=True, legend_title="Subtype")
fig.tight_layout()
plt.show()

# Relative per-sample: each bar sums to 1; empty samples are reported as 0.
prop_sample = pivot_sample.div(pivot_sample.sum(axis=1).replace(0, np.nan), axis=0).fillna(0)
fig, ax = plt.subplots(figsize=(2, 4))
prop_sample.plot(kind="bar", stacked=True, ax=ax, color=palette, edgecolor="white", linewidth=0.5, width=0.85)
ax.set_ylabel("Relative Abundance", fontsize=12)
ax.set_xlabel("Control     Treated", fontsize=12)
ax.set_title("Relative Abundance", fontsize=12)
style_axes(ax, y_is_percent=True, legend=True, legend_title="Subtype")
fig.tight_layout()
plt.show()

if condition_col:
    # Pool counts by condition. Only observed (condition, subtype) pairs are
    # grouped; missing subtypes are restored as zero columns aligned with `cats`.
    pivot_cond = (
        counts.groupby([condition_col, group_col], observed=True)["n"]
        .sum()
        .unstack(fill_value=0)
        .reindex(columns=cats, fill_value=0)
    )

    # Absolute by condition
    fig, ax = plt.subplots(figsize=(1, 4))
    pivot_cond.plot(kind="bar", stacked=True, ax=ax, color=palette, edgecolor="white", linewidth=0.5, width=0.7)
    ax.set_ylabel("Number of cells", fontsize=11)
    ax.set_xlabel("")
    ax.set_title("Absolute abundance by condition (subtype_granular)", fontsize=12)
    style_axes(ax, y_is_percent=False, legend=True, legend_title="Subtype")
    fig.tight_layout()
    plt.show()

    # Relative by condition
    prop_cond = pivot_cond.div(pivot_cond.sum(axis=1).replace(0, np.nan), axis=0).fillna(0)
    fig, ax = plt.subplots(figsize=(1, 4))
    prop_cond.plot(kind="bar", stacked=True, ax=ax, color=palette, edgecolor="white", linewidth=0.5, width=0.7)
    ax.set_ylabel("Relative abundance", fontsize=11)
    ax.set_xlabel("")
    ax.set_title("Relative abundance by condition (subtype_granular)", fontsize=12)
    style_axes(ax, y_is_percent=True, legend=True, legend_title="Subtype")
    fig.tight_layout()
    plt.show()



## Spatial scatters

In [None]:
# Base font size for the spatial plots. set_figure_params resets every
# argument that is not passed (e.g. dpi back to 80, dpi_save to 150), so the
# notebook-wide settings from the first settings cell are repeated here.
sc.set_figure_params(
    dpi=200,
    dpi_save=400,
    fontsize=16,  # title & text sizes
    facecolor="white",
    vector_friendly=True,
)
sq.pl.spatial_scatter(
    tnk_adata,
    library_id="spatial",
    shape=None,
    color='subtype_granular',
    wspace=0.5,
    ncols=3,
    size=0.2,
    frameon=False,
    title=["T and NK Cell Subsets\n\nControl           Treated"],
    cmap="bwr",
    vcenter=0,
    figsize=(10, 10)
)





In [None]:
import matplotlib.colors as mcolors

# Symmetric colour limit shared by the Gzmb spatial plot and its colorbar.
gzmb_vlim = 1.53

# Diverging colormap: bwr's blue and red ends with a pale cream (#f6f3c3)
# midpoint instead of bwr's white midpoint.
grey_bwr = mcolors.LinearSegmentedColormap.from_list(
    "grey_bwr",
    [(0.0, plt.cm.bwr(0.0)),
     (0.5, "#f6f3c3"),         # pale cream at the center (value 0)
     (1.0, plt.cm.bwr(1.0))],
    N=256
)


sq.pl.spatial_scatter(
    tnk_adata,
    library_id="spatial",
    shape=None,
    color=['Gzmb'],
    palette="Set2",
    wspace=0.5,
    ncols=3,
    size=0.2,
    frameon=False,
    vmax=gzmb_vlim,
    title=["Gzmb\n\nControl           Treated"],
    cmap=grey_bwr,
    vcenter=0,
    figsize=(10, 10)
)

# Standalone colorbar centred at 0 with vmin mirrored to -gzmb_vlim.
# NOTE(review): confirm this matches the range squidpy derives for the scatter,
# which only receives vmax and vcenter.
norm = mcolors.TwoSlopeNorm(vmin=-gzmb_vlim, vcenter=0, vmax=gzmb_vlim)
sm = plt.cm.ScalarMappable(cmap=grey_bwr, norm=norm)
sm.set_array([])

fig_cb, ax_cb = plt.subplots(figsize=(0.5, 2.0), dpi=300)
cbar = plt.colorbar(sm, cax=ax_cb, orientation='vertical')
cbar.ax.tick_params(labelsize=15, length=3, width=0.8)
# pyplot.show() accepts no figure argument (only the keyword `block`); it
# displays every open figure, including fig_cb.
plt.show()
plt.close(fig_cb)


## Effector memory

In [None]:
import numpy as np, pandas as pd
import scipy.sparse as sp

# ---------- helpers ----------
def _to1d(x):
    return x.toarray().ravel() if sp.issparse(x) else np.asarray(x).ravel()

def pick_gene(adata, candidates):
    """Return the first candidate gene present in ``adata`` (case-insensitive).

    Candidates are tried in order. For each candidate, ``adata.raw.var_names``
    (when ``adata.raw`` is not None) is searched before ``adata.var_names``.
    Within each pool an exact match wins over a case-insensitive one, and the
    name is returned with that pool's own capitalisation.

    Parameters
    ----------
    adata : AnnData-like
        Needs ``.raw`` (may be None) and ``.var_names``.
    candidates : str or iterable of str
        A single gene name, or candidate names in priority order.

    Raises
    ------
    KeyError
        If no candidate is found in any pool.
    """
    if isinstance(candidates, str):
        # A bare string would otherwise be iterated character by character.
        candidates = [candidates]
    pools = []
    if adata.raw is not None:
        pools.append(pd.Index(adata.raw.var_names))
    pools.append(pd.Index(adata.var_names))
    lookups = [(pool, {g.lower(): g for g in pool}) for pool in pools]
    for cand in candidates:
        key = cand.lower()
        for pool, lowered in lookups:
            if cand in pool:
                return cand
            if key in lowered:
                return lowered[key]
    raise KeyError(f"None of {candidates} found in var/raw.var names")

def get_expr(adata, gene):
    """Return one gene's expression across all cells as a 1-D array.

    ``adata.raw`` is used when it exists and contains ``gene``; otherwise the
    value comes from the main matrix (``adata[:, gene].X``).
    """
    use_raw = adata.raw is not None and gene in adata.raw.var_names
    source = adata.raw if use_raw else adata
    return _to1d(source[:, gene].X)



# Juxtaposition

## Setup

In [None]:
# Full dataset (confident + juxtaposed cells) used by the Juxtaposition section.
adata_full = sc.read_h5ad(
    "/Users/bhavyasingh/Downloads/JMT_cleaned_full_dataset.h5ad"
)

In [None]:
# Keep the original granular labels before collapsing them.
adata_full.obs["subtype_granular_og"] = adata_full.obs["subtype_granular"].copy()

# Collapse selected cancer-like subtypes into a single "Cancer Cells" label in
# subtype_granular (the untouched labels remain in subtype_granular_og).
col = "subtype_granular"
targets = [
    "Stem-Like Cancer Cells",
    "EMT Cancer Cells",
    "Transitional De-differentiated Cells",
    "Metabolically Reprogrammed Proliferating Cancer Cells",
    "Angiogenesis-Associated Cancer Cells",
    "IFN-Responsive Epithelial Cancer Cells",
    "Proliferating Epithelial Cancer Cells"
]
# isinstance(..., pd.CategoricalDtype) replaces the deprecated
# pd.api.types.is_categorical_dtype.
is_cat = isinstance(adata_full.obs[col].dtype, pd.CategoricalDtype)
if is_cat and "Cancer Cells" not in adata_full.obs[col].cat.categories:
    # A categorical column only accepts values that are already categories.
    adata_full.obs[col] = adata_full.obs[col].cat.add_categories(["Cancer Cells"])
adata_full.obs.loc[adata_full.obs[col].isin(targets), col] = "Cancer Cells"
if is_cat:
    # Drop the now-empty categories of the merged subtypes.
    adata_full.obs[col] = adata_full.obs[col].cat.remove_unused_categories()

In [None]:
# Paerform the renames and removal
adata_full = rename_and_filter_subtypes_fast(
    adata_full,
    col="subtype_granular",
)



In [None]:
# Move cells whose `subtype` is "Dendritic Cells" (compared after stripping
# surrounding whitespace) into the "Myeloid Cells" broad class.
m = adata_full.obs["subtype"].astype(str).str.strip().eq("Dendritic Cells")
# isinstance(..., pd.CategoricalDtype) replaces the deprecated
# pd.api.types.is_categorical_dtype; a categorical column needs the new
# category registered before it can be assigned.
if (
    isinstance(adata_full.obs["initial_broad"].dtype, pd.CategoricalDtype)
    and "Myeloid Cells" not in adata_full.obs["initial_broad"].cat.categories
):
    adata_full.obs["initial_broad"] = adata_full.obs["initial_broad"].cat.add_categories(["Myeloid Cells"])
adata_full.obs.loc[m, "initial_broad"] = "Myeloid Cells"

In [None]:
# Build the colour schemes for adata_full (plot_preview=True presumably also
# draws a preview of them -- see setup_all_colors).
adata_full, color_schemes = setup_all_colors(
    adata_full,
    plot_preview=True,
)

## UMAPs

In [None]:
# Broad cell classes on the confident-only object, for comparison with the
# full (confident + juxtaposed) dataset below.
sc.pl.umap(
    adata,
    color=["initial_broad"],
    title=["Confident Cells"],
    ncols=1,
    size=2,
    frameon=False,
    legend_loc="right margin",
    legend_fontsize=10,
    legend_fontoutline=0,
    show=False,
)

In [None]:
# Full dataset: broad classes, juxtaposition QC flag and call, granular subtypes.
sc.pl.umap(
    adata_full,
    color=["initial_broad", "qc_flag", "juxta_call", "subtype_granular"],
    title=["Confident and \nJuxtaposed Cells", "Juxtaposition Flag", "Juxtaposition Call", "Granular Subtypes"],
    ncols=1,
    size=2,
    frameon=False,
    legend_loc="right margin",
    legend_fontsize=10,
    legend_fontoutline=0,
    show=False,
)
