# Integrated scRNA-seq Data Visualization Analysis

## Overview
This notebook produces visualizations of clustered and annotated single-cell RNA-seq data, starting from a pre-clustered and annotated AnnData object.

The analysis pipeline includes:

- **Data Loading**: Load clustered and annotated AnnData object with `somen_clusters` and `cell_type` columns
- **Visualization**: UMAP projections colored by clusters, metadata, and feature expression
- **Quality Assessment**: QC metrics (e.g. genes detected, total counts, mitochondrial %) per cell type, and cell-type distribution across mice and experimental groups

The workflow is designed for flexibility and reusability. All parameters can be configured at the beginning of each section.

**Input**: `inputs/clustered_annotated_adata.h5ad` (pre-clustered and annotated data)

---

## Key Outputs

### Visualizations
- Cluster visualization UMAPs using `somen_clusters` and `cell_type`
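- Metadata-stratified UMAPs (by mouse and experimental group; full dataset and downsampled to equal cells per group)
- Gene expression UMAPs and dotplots (log-normalized expression)
- All plots exported as PNG (400 DPI) and PDF (vector axes and text; dense scatter points may be rasterized at 400 DPI) for publication and editing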

---

## Setup and Configuration

### Library Imports and Directory Structure
Import required libraries for single-cell analysis and configure input/output directories. The notebook assumes the following structure:
- `inputs/`: Contains processed single-cell data (`.h5ad` files)
- `outputs/`: Automatically created for results
  - `outputs/png/`: High-resolution PNG plots (400 DPI)
  - `outputs/pdf/`: PDF plots with vector axes and text for publication and editing

In [None]:
# Core libraries
import numpy as np
import pandas as pd
from pathlib import Path
from IPython.display import display, HTML
import sys, pkg_resources, datetime

# Single-cell analysis
import anndata as ad
import scanpy as sc
import scipy
import scipy.sparse

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# ============================================================
# DIRECTORY CONFIGURATION
# ============================================================
# Data is read from ./inputs; figures go to ./outputs/png and ./outputs/pdf
input_dir, output_dir = Path('inputs'), Path('outputs')
for figure_format in ('png', 'pdf'):
    # exist_ok=True keeps re-running this cell harmless (PDF dir holds vector-editable output)
    (output_dir / figure_format).mkdir(parents=True, exist_ok=True)

# ============================================================
# PLOTTING CONFIGURATION
# ============================================================
# Scanpy-wide figure defaults; configure_plot_style() later sets figure.dpi back to 100
sc.set_figure_params(dpi=300, figsize=(6, 4))

def configure_plot_style():
    """Apply publication-quality plot styling (thin lines, visible major ticks, editable text).

    Mutates the global ``matplotlib.rcParams``; returns None.

    ``sns.set_style`` rewrites several rcParams itself. These include
    ``font.family`` / ``font.sans-serif`` and, for the 'white' style,
    ``xtick.bottom`` / ``ytick.left`` = False, which hides tick marks. The seaborn
    style is therefore applied FIRST so that the explicit settings below win.
    """
    # Base seaborn style: white background, spines only on the left and bottom
    sns.set_style('white', {'axes.spines.left': True, 'axes.spines.bottom': True,
                             'axes.spines.top': False, 'axes.spines.right': False})

    # Font and text settings: Arial when installed, otherwise the next available
    # sans-serif font (avoids "findfont: Font family 'Arial' not found" warnings)
    plt.rcParams['figure.dpi'] = 100
    plt.rcParams['font.family'] = 'sans-serif'
    plt.rcParams['font.sans-serif'] = ['Arial', 'Helvetica', 'Liberation Sans', 'DejaVu Sans']
    plt.rcParams['font.size'] = 10

    # Draw tick marks on the bottom/left axes (the 'white' style switches them off)
    plt.rcParams['xtick.bottom'] = True
    plt.rcParams['ytick.left'] = True

    # Line widths
    plt.rcParams['axes.linewidth'] = 0.8
    plt.rcParams['xtick.major.width'] = 0.8
    plt.rcParams['ytick.major.width'] = 0.8
    plt.rcParams['xtick.minor.width'] = 0.6
    plt.rcParams['ytick.minor.width'] = 0.6

    # Tick sizes (minor sizes only matter where minor ticks are enabled, e.g. via a minor locator)
    plt.rcParams['xtick.major.size'] = 4
    plt.rcParams['ytick.major.size'] = 4
    plt.rcParams['xtick.minor.size'] = 2
    plt.rcParams['ytick.minor.size'] = 2

    # Vector-export settings: embed TrueType fonts in PDFs (fonttype 42) and keep
    # SVG text as text so labels stay editable; no path simplification.
    # NOTE: these do not undo rasterization of artists drawn with rasterized=True.
    plt.rcParams['pdf.fonttype'] = 42
    plt.rcParams['svg.fonttype'] = 'none'
    plt.rcParams['path.simplify'] = False
    plt.rcParams['path.simplify_threshold'] = 0

def show_inline_plot(fig=None, display_dpi=100):
    """Display a figure horizontally centered in a Jupyter notebook.

    The figure is rendered to an in-memory PNG and embedded inside a single
    flexbox ``<div>``. Emitting the opening and closing tags through separate
    ``display()`` calls would put each in its own output area and leave the
    figure outside the div, so no centering would happen.

    Parameters
    ----------
    fig : matplotlib.figure.Figure, optional
        Figure to show. Defaults to the current figure (``plt.gcf()``). In that
        case the figure is closed after display, similar to ``plt.show()`` with
        the inline backend.
    display_dpi : int, default=100
        Resolution of the inline preview only; files written by save_plot are unaffected.
    """
    import base64
    import io

    close_after = fig is None
    if fig is None:
        fig = plt.gcf()
    fig.tight_layout()

    buffer = io.BytesIO()
    fig.savefig(buffer, format='png', dpi=display_dpi, bbox_inches='tight', facecolor='white')
    encoded = base64.b64encode(buffer.getvalue()).decode('ascii')

    # One HTML payload so the image really sits inside the centering container
    display(HTML(
        '<div style="display: flex; justify-content: center;">'
        f'<img src="data:image/png;base64,{encoded}" style="max-width: 100%;"/>'
        '</div>'
    ))

    if close_after:
        plt.close(fig)

def save_plot(name, close=True, high_dpi=400, is_umap=False):
    """
    Save the current figure as ``output_dir/png/<name>.png`` and ``output_dir/pdf/<name>.pdf``.

    Parameters
    ----------
    name : str
        Base filename without extension.
    close : bool, default=True
        Whether to close the current figure after saving.
    high_dpi : int, default=400
        Resolution of the PNG and of any rasterized artists (e.g. dense scatter
        points) embedded in the PDF; axes and text in the PDF remain vector.
    is_umap : bool, default=False
        Unused; accepted so existing calls that pass it keep working.

    Raises
    ------
    ValueError
        If ``high_dpi`` is not positive.
    """
    if high_dpi <= 0:
        raise ValueError(f'high_dpi must be positive, got {high_dpi}')

    plt.tight_layout()

    # Honour the caller's resolution (default 400 DPI)
    dpi_value = high_dpi

    # Save PNG (raster); white background to match the PDF
    png_path = output_dir / 'png' / f'{name}.png'
    plt.savefig(png_path, dpi=dpi_value, bbox_inches='tight', facecolor='white')
    print(f'✓ Saved: {name}.png ({dpi_value} DPI)')

    # Save PDF (vector axes/text; artists flagged rasterized are rendered at dpi_value)
    pdf_path = output_dir / 'pdf' / f'{name}.pdf'
    plt.savefig(pdf_path, bbox_inches='tight', format='pdf',
                facecolor='white', dpi=dpi_value)
    print(f'✓ Saved: {name}.pdf (vector axes/text, raster points at {dpi_value} DPI)')

    if close:
        plt.close()
# ============================================================
# dittoSeq PALETTE (Full set; always slice first n for categories)
# ============================================================
# 42 hex colors for categorical annotations. get_dittoseq_colors(n) takes the
# first n entries and starts over from the beginning when n exceeds 42.
# NOTE(review): dittoSeq's dittoColors() begins with the Okabe-Ito colors
# (#E69F00, #56B4E9, #009E73, ...), which this list does not -- the
# "canonical dittoSeq" label may be inaccurate; confirm the palette's source.
dittoseq_full = [
    "#E5D2DD", "#53A85F", "#F1BB72", "#F3B1A0", "#AB3282", "#57C3F3", "#476D87",
    "#E95C59", "#E59CC4","#D6E7A3", "#23452F", "#BD956A", "#8C549C", "#585658",
    "#9FA3A8", "#E0D4CA", "#5F3D69", "#C5DEBA", "#58A4C3", "#E4C755", "#F7F398",
    "#AA9A59", "#E63863", "#E39A35", "#C1E6F3", "#6778CA", "#91D0BE", "#B53E2B",
    "#712820", "#DCC1DD", "#CCE0F5", "#CCC9E6", "#625D9F", "#68A180", "#3A6963",
    "#968175", "#161616", "#FF9999", "#344F31", "#00BB61", "#E8E8E8", "#BFBFBF"
]

def get_dittoseq_colors(n, palette=None):
    """
    Return ``n`` colors, one per category.

    Takes the first ``n`` entries of ``palette`` (default: the module-level
    ``dittoseq_full`` list). If ``n`` exceeds the palette length, it cycles
    through the palette again from the start.

    Parameters
    ----------
    n : int
        Number of categories; must be >= 0.
    palette : sequence of str, optional
        Colors to draw from. ``None`` uses ``dittoseq_full``.

    Returns
    -------
    list of str
        Exactly ``n`` colors.

    Raises
    ------
    ValueError
        If ``n`` is negative, or ``n > 0`` and the palette is empty.
    """
    import itertools

    if n < 0:
        # A negative slice bound would silently return "all but the last |n|" colors
        raise ValueError(f'n must be non-negative, got {n}')
    base = list(dittoseq_full if palette is None else palette)
    if n and not base:
        raise ValueError('palette must contain at least one color')
    return list(itertools.islice(itertools.cycle(base), n))


## Load and Inspect Data

### Load Clustered and Annotated AnnData Object
Load the pre-clustered and annotated AnnData file. This section:
- Loads AnnData object containing clustered and annotated data
- Verifies presence of `somen_clusters` and `cell_type` columns
- Reports the available data layers
- Checks for UMAP coordinates
- Converts categorical metadata variables (batch, library, mouse, group, sex) to categorical dtype
- Reorders mouse column: HTO_PEP15-1, HTO_PEP15-2, ..., HTO_PBS-1, HTO_PBS-2, ...

**Configuration**:
- Use `clustered_annotated_adata.h5ad` from annotation step

In [None]:
# ============================================================
# USER CONFIGURATION
# ============================================================
# Clustered + annotated AnnData file, looked up inside `input_dir`
data_filename = 'clustered_annotated_adata.h5ad'

# Metadata columns that are cast to categorical further below (when present)
categorical_vars = ['batch', 'library', 'mouse', 'group']

# ============================================================
# LOAD DATA
# ============================================================
data_file = input_dir / data_filename
if not data_file.exists():
    # Fail early, showing the full path so a wrong working directory is obvious
    raise FileNotFoundError(f'Input file not found: {data_file}')

adata = sc.read_h5ad(data_file)
print(f'✓ Loaded AnnData from: {data_file}', f'  Data shape: {adata.shape}', sep='\n')

# Duplicate cell barcodes break label-based subsetting; de-duplicate them
if not adata.obs_names.is_unique:
    adata.obs_names_make_unique()
    print('✓ Made cell identifiers unique')

# ============================================================
# VERIFY DATA STRUCTURE
# ============================================================
print('\n' + '='*60, 'DATA STRUCTURE', '='*60, sep='\n')
print(f'Data shape: {adata.n_obs:,} cells × {adata.n_vars:,} genes')
print(f'\nLayers: {list(adata.layers.keys())}')
print(f'Observations (metadata): {list(adata.obs.columns[:15])}...')

# Every UMAP plot below needs precomputed coordinates
print('\n' + '='*60, 'UMAP COORDINATES', '='*60, sep='\n')
if 'X_umap' not in adata.obsm:
    raise ValueError('X_umap not found in adata.obsm. Please ensure the data includes UMAP coordinates.')
print(f'✓ X_umap found: {adata.obsm["X_umap"].shape}')

# ============================================================
# VERIFY REQUIRED COLUMNS
# ============================================================
print('\n' + '='*60)
print('VERIFYING REQUIRED COLUMNS')
print('='*60)

for required_col in ('somen_clusters', 'cell_type'):
    if required_col not in adata.obs:
        raise ValueError(f'{required_col} column not found in adata.obs')
    # Cast BEFORE touching `.cat`: the accessor raises AttributeError on
    # non-categorical columns (e.g. integer cluster IDs or plain strings).
    if adata.obs[required_col].dtype.name != 'category':
        adata.obs[required_col] = adata.obs[required_col].astype(str).astype('category')
    required_categories = list(adata.obs[required_col].cat.categories)
    print(f'✓ {required_col} found: {len(required_categories)} categories')
    print(f'  Categories: {required_categories}')

# ============================================================
# PROCESS CATEGORICAL VARIABLES
# ============================================================
print('\n' + '='*60)
print('CATEGORICAL METADATA')
print('='*60)

# Columns whose value breakdown is printed later; absent ones are reported as missing
vars_to_display = ['group', 'sex', 'library', 'batch', 'mouse']

# astype(str) turns missing values into the literal category 'nan'
for col in categorical_vars + ['somen_clusters', 'cell_type']:
    if col in adata.obs:
        if adata.obs[col].dtype.name != 'category':
            adata.obs[col] = adata.obs[col].astype(str).astype('category')
        n_categories = len(adata.obs[col].cat.categories)
        print(f'✓ {col}: {n_categories} categories')
    else:
        print(f'⚠ {col}: Not found in data')

# Also cast 'sex' when it is present but not listed in categorical_vars
if 'sex' not in categorical_vars and 'sex' in adata.obs:
    if adata.obs['sex'].dtype.name != 'category':
        adata.obs['sex'] = adata.obs['sex'].astype(str).astype('category')

# ============================================================
# REORDER MOUSE COLUMN
# ============================================================
import re


def _natural_sort_key(label):
    """Sort key comparing digit runs numerically, so 'HTO_PBS-2' sorts before 'HTO_PBS-10'."""
    return [int(chunk) if chunk.isdigit() else chunk for chunk in re.split(r'(\d+)', str(label))]


print('\n' + '='*60)
print('REORDERING MOUSE COLUMN')
print('='*60)

if 'mouse' in adata.obs:
    # Get current mouse categories
    current_mice = adata.obs['mouse'].cat.categories.tolist()
    print(f'Current mouse order: {current_mice}')

    # Separate HTO_PEP15 and HTO_PBS mice (a label matching both stays with PEP15,
    # so no category is listed twice)
    pep15_mice = sorted((m for m in current_mice if 'PEP15' in str(m)), key=_natural_sort_key)
    pbs_mice = sorted((m for m in current_mice if 'PBS' in str(m) and 'PEP15' not in str(m)),
                      key=_natural_sort_key)

    # reorder_categories requires every existing category: keep unmatched mice at the end
    matched_mice = set(pep15_mice) | set(pbs_mice)
    other_mice = [m for m in current_mice if m not in matched_mice]
    if other_mice:
        print(f'⚠ Mice matching neither PEP15 nor PBS kept at the end: {other_mice}')

    # New order: HTO_PEP15-1, HTO_PEP15-2, ..., HTO_PBS-1, HTO_PBS-2, ..., then any others
    new_order = pep15_mice + pbs_mice + other_mice

    adata.obs['mouse'] = adata.obs['mouse'].cat.reorder_categories(new_order)
    print(f'✓ Reordered mouse column: {new_order}')
else:
    print('⚠ mouse column not found in data')

# Print each display column's categories with the number of cells in each
print('\n' + '='*60, 'UNIQUE VALUES IN KEY METADATA COLUMNS', '='*60, sep='\n')

for col in vars_to_display:
    if col not in adata.obs:
        print(f'\n{col.upper()}: Not found in data')
        continue

    # Categorical dtype gives an ordered, stable list of values to report
    if adata.obs[col].dtype.name != 'category':
        adata.obs[col] = adata.obs[col].astype(str).astype('category')

    cells_per_value = adata.obs[col].value_counts().to_dict()
    print(f'\n{col.upper()}:')
    for value in adata.obs[col].cat.categories.tolist():
        print(f'  • {value}: {cells_per_value.get(value, 0):,} cells')

print('\n' + '='*60)

### Normalize Gene Expression Data

Normalize raw count data for visualization. Normalized expression values are needed for gene expression plots.

**Steps**:
1. Check if data is already normalized
2. Normalize counts (library size normalization to 10,000 counts per cell)
3. Log-transform (log1p)
4. Store in `adata.layers['normalized']` for downstream visualization (UMAP, dotplot)
5. Also store in `adata.X` for compatibility with other analyses

In [None]:
# ============================================================
# DATA NORMALIZATION
# ============================================================
print('='*60)
print('GENE EXPRESSION NORMALIZATION')
print('='*60)


def _matrix_max(mat):
    """Return the largest entry of a dense array or scipy sparse matrix as a float (0.0 if empty)."""
    if 0 in mat.shape:
        return 0.0
    return float(mat.max())


def _looks_like_raw_counts(mat, max_values=1_000_000):
    """Return True when the inspected values are all non-negative integers.

    Raw UMI counts are integer-valued; library-size-normalized or log-transformed
    values are not. Only the first ``max_values`` stored values are checked
    (``.data`` for sparse input, the flattened array for dense input) to keep
    this cheap on large matrices.
    """
    values = mat.data if scipy.sparse.issparse(mat) else np.asarray(mat).ravel()
    values = values[:max_values]
    if values.size == 0:
        return False
    return bool(np.all(values >= 0) and np.all(np.mod(values, 1) == 0))


max_val = _matrix_max(adata.X)
print(f'\nCurrent X matrix max value: {max_val:.2f}')

# Integer-valued data is raw counts whatever its maximum (shallow libraries can
# have max <= 50). Non-integer data with max > 50 is still normalized, since
# log1p(1e4-normalized) values stay far below that.
integer_valued = max_val > 0 and _looks_like_raw_counts(adata.X)

if max_val > 50 or integer_valued:
    if max_val > 50:
        print('⚠️  Data appears to be raw counts (max > 50)')
    else:
        print('⚠️  Data appears to be raw counts (integer-valued)')
    print('   Performing normalization...')

    # Store raw counts if not already stored
    if 'counts' not in adata.layers:
        adata.layers['counts'] = adata.X.copy()
        print('✓ Stored raw counts in adata.layers["counts"]')

    # Normalize to 10,000 counts per cell and log-transform
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    print('✓ Normalized to 10,000 counts per cell')
    print('✓ Log-transformed (log1p)')

    # Store normalized data in "normalized" layer
    adata.layers['normalized'] = adata.X.copy()
    print('✓ Stored normalized data in adata.layers["normalized"]')

elif max_val > 20:
    # Non-integer values in (20, 50]: presumably normalized but not yet log-transformed
    print('⚠️  Data may be normalized but not log-transformed')
    print('   Applying log transformation...')

    sc.pp.log1p(adata)
    print('✓ Log-transformed (log1p)')

    adata.layers['normalized'] = adata.X.copy()
    print('✓ Updated adata.layers["normalized"] with log-transformed data')

else:
    print('✓ Data appears to be already normalized and log-transformed')
    print('  (max value < 20, typical for log-normalized data)')

    # Store in normalized layer if not already present (an existing layer is kept as-is)
    if 'normalized' not in adata.layers:
        adata.layers['normalized'] = adata.X.copy()
        print('✓ Stored normalized data in adata.layers["normalized"]')

# Verify final state
final_max = _matrix_max(adata.X)
normalized_max = _matrix_max(adata.layers['normalized'])
print(f'\n✓ Final X matrix max value: {final_max:.2f}')
print(f'✓ Normalized layer max value: {normalized_max:.2f}')
print('✓ Normalized layer will be used for gene expression visualization')

print('\n' + '='*60)



### Cell Type Visualization
Generate a detailed UMAP plot colored by cell type. This provides a cleaner view for presentations and publications.

In [None]:
# ============================================================
# USER CONFIGURATION
# ============================================================
point_size = 15             # Size of points in UMAP
fig_width = 8              # Figure width
fig_height = 5             # Figure height

# ============================================================
# GENERATE CELL_TYPE PLOT
# ============================================================
import matplotlib.colors as mcolors
import matplotlib.ticker as mticker

# Verify cell_type column and UMAP
cluster_key = 'cell_type'
if cluster_key not in adata.obs:
    raise ValueError(f"Column '{cluster_key}' not found in adata.obs")
if 'X_umap' not in adata.obsm:
    raise ValueError("X_umap not found in adata.obsm")

configure_plot_style()

# Store the dittoSeq palette in adata.uns so later plots reuse the same cell-type colors
n_clusters = len(adata.obs[cluster_key].cat.categories)
custom_palette = get_dittoseq_colors(n_clusters)
adata.uns[f'{cluster_key}_colors'] = custom_palette

# Set up plot
fig, ax = plt.subplots(figsize=(fig_width, fig_height), dpi=400)

# Plot UMAP (no labels on data - only legend on right margin)
sc.pl.umap(
    adata,
    color=cluster_key,
    ax=ax,
    show=False,
    title='Cell Type',
    legend_loc='right margin',
    legend_fontsize=10,
    legend_fontweight='normal',
    size=point_size,
    palette=adata.uns[f'{cluster_key}_colors']
)

# Add axis labels
ax.set_xlabel('UMAP-1', fontsize=12)
ax.set_ylabel('UMAP-2', fontsize=12)

# scanpy's embedding plots remove the x/y tick locations, so tick_params alone
# cannot bring numeric labels back: reinstall automatic major/minor locators first.
ax.xaxis.set_major_locator(mticker.AutoLocator())
ax.yaxis.set_major_locator(mticker.AutoLocator())
ax.xaxis.set_minor_locator(mticker.AutoMinorLocator())
ax.yaxis.set_minor_locator(mticker.AutoMinorLocator())

# Show major ticks with numerical labels (e.g., 2, 4, 6)
ax.tick_params(axis='both', which='major',
               bottom=True, left=True,       # Show ticks
               labelbottom=True, labelleft=True,  # Show labels
               labelsize=10,
               length=4, width=0.8)

# Minor ticks (positions come from the AutoMinorLocator above)
ax.tick_params(axis='both', which='minor',
               bottom=True, left=True,
               length=2, width=0.6)

plt.tight_layout()

# Save (400 DPI via save_plot)
save_plot('cell_type_umap', close=False)
show_inline_plot()
plt.close()

# Check cell type counts to confirm the ordering
print("Cell type counts:")
print(adata.obs[cluster_key].value_counts().sort_index())

## Metadata-Stratified Visualizations

### Visualize Clusters by Metadata Groups
Generate UMAP plots stratified by experimental metadata (mouse and experimental group) to assess:
- Batch-specific patterns
- Library effects and integration quality
- Distribution across mice
- Cluster composition differences between experimental groups

This helps identify technical artifacts and validate biological interpretations.

### QC Metrics by Cluster

Visualize quality control metrics across clusters to identify potential technical artifacts or low-quality clusters. This helps assess whether certain clusters are driven by technical factors rather than biological variation.


In [None]:
%matplotlib inline

# ============================================================
# USER CONFIGURATION
# ============================================================
plot_width = 20     # Width of QC plots
plot_height = 6     # Height of QC plots

# ============================================================
# QC METRICS VISUALIZATION BY CELL TYPE
# ============================================================

cluster_key = 'cell_type'

# Verify column exists
if cluster_key not in adata.obs:
    raise ValueError(f"Column '{cluster_key}' not found. Ensure '{cluster_key}' exists in adata.obs")

# Configure plot style
configure_plot_style()

# Define QC metrics to visualize (check which ones exist in your data)
qc_metrics = []
qc_labels = []

# Check for common QC metric column names
if 'n_genes_by_counts' in adata.obs:
    qc_metrics.append('n_genes_by_counts')
    qc_labels.append('Number of Genes')
elif 'n_genes' in adata.obs:
    qc_metrics.append('n_genes')
    qc_labels.append('Number of Genes')

if 'total_counts' in adata.obs:
    qc_metrics.append('total_counts')
    qc_labels.append('Total Counts (nCount_RNA)')
elif 'n_counts' in adata.obs:
    qc_metrics.append('n_counts')
    qc_labels.append('Total Counts (nCount_RNA)')

if 'pct_counts_mt' in adata.obs:
    qc_metrics.append('pct_counts_mt')
    qc_labels.append('% Mitochondrial')

if 'doublet_score' in adata.obs:
    qc_metrics.append('doublet_score')
    qc_labels.append('Doublet Score')

# Cell cycle scores if available
if 'S_score' in adata.obs:
    qc_metrics.append('S_score')
    qc_labels.append('S Phase Score')
    
if 'G2M_score' in adata.obs:
    qc_metrics.append('G2M_score')
    qc_labels.append('G2/M Phase Score')

if not qc_metrics:
    print("No QC metrics found in adata.obs. Available columns:")
    print(list(adata.obs.columns))
else:
    print(f"Found {len(qc_metrics)} QC metrics to visualize")

    # One violin panel per QC metric; squeeze=False always returns a 2-D array,
    # so a single metric needs no special case
    n_metrics = len(qc_metrics)
    fig, axes = plt.subplots(1, n_metrics, figsize=(plot_width, plot_height), dpi=400, squeeze=False)
    axes = axes.ravel()

    for metric_ax, metric, label in zip(axes, qc_metrics, qc_labels):
        sc.pl.violin(
            adata,
            keys=metric,
            groupby=cluster_key,
            ax=metric_ax,
            show=False,
            rotation=90  # rotate x-axis labels 90 degrees (vertical)
        )
        metric_ax.set_title(label, fontsize=12, fontweight='bold')
        metric_ax.set_xlabel('Cell Type', fontsize=10)
        metric_ax.set_ylabel(label, fontsize=10)
        metric_ax.tick_params(axis='x', labelsize=12, rotation=90)  # Larger text size (12) and 90 degree rotation

    plt.suptitle('QC Metrics by Cell Type', fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()

    # Save plot (save_plot writes PNG and PDF)
    save_plot('qc_metrics_by_cell_type', close=False)
    show_inline_plot()
    plt.close()

    # ============================================================
    # SUMMARY STATISTICS TABLE
    # ============================================================
    print('\n' + '='*80)
    print('QC METRICS SUMMARY BY CELL TYPE')
    print('='*80)

    # One grouped pass over adata.obs instead of re-masking per cell type and metric.
    # observed=False keeps empty cell-type categories (0 cells, NaN medians).
    qc_groups = adata.obs.groupby(cluster_key, observed=False)
    cells_per_type = qc_groups.size()
    median_per_type = qc_groups[qc_metrics].median()

    summary_data = []
    for ct in adata.obs[cluster_key].cat.categories:
        row = {'Cell_Type': ct, 'N_cells': cells_per_type.loc[ct]}
        for metric, label in zip(qc_metrics, qc_labels):
            row[label] = f'{median_per_type.loc[ct, metric]:.2f}'
        summary_data.append(row)

    summary_df = pd.DataFrame(summary_data)
    print(summary_df.to_string(index=False))

    # Save summary table
    summary_file = output_dir / 'qc_summary_by_cell_type.csv'
    summary_df.to_csv(summary_file, index=False)
    print(f'\n✓ Saved summary table: {summary_file.name}')

    # ============================================================
    # IDENTIFY POTENTIAL QUALITY ISSUES
    # ============================================================
    print('\n' + '='*80)
    print('POTENTIAL QUALITY ISSUES')
    print('='*80)

    # Flag cell types with a high median mitochondrial percentage
    if 'pct_counts_mt' in qc_metrics:
        high_mt_threshold = 10  # Adjust as needed
        for ct, median_mt in median_per_type['pct_counts_mt'].items():
            if median_mt > high_mt_threshold:
                print(f'⚠️  {ct}: High mt% (median={median_mt:.2f}%)')

    # Flag cell types with a low median number of detected genes
    gene_metric = next((m for m in ('n_genes_by_counts', 'n_genes') if m in qc_metrics), None)
    if gene_metric is not None:
        low_genes_threshold = 500  # Adjust as needed
        for ct, median_genes in median_per_type[gene_metric].items():
            if median_genes < low_genes_threshold:
                print(f'⚠️  {ct}: Low gene count (median={median_genes:.0f} genes)')

    print('\nIf no warnings appear, all cell types pass basic QC thresholds.')
    print('='*80)

### Metadata-Stratified UMAP Visualizations

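Generate UMAP plots split by mouse and by experimental group to assess integration quality and identify experimental group patterns. Cells are colored by annotated cell type (`cell_type`) to evaluate consistency across metadata categories. Each split is produced for the full dataset and for a version downsampled to an equal number of cells per category (only the downsampled version is shown inline).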

In [None]:
# Ensure inline plotting in Jupyter
%matplotlib inline

# ============================================================
# USER CONFIGURATION
# ============================================================
# Customize parameters for MOUSE plots
mouse_dot_size = 20
mouse_fig_width = 5.2
mouse_fig_height = 3

# Customize parameters for GROUP plots
group_dot_size = 10
group_fig_width = 6.5
group_fig_height = 5

# ============================================================
# SETUP
# ============================================================
cluster_key = 'cell_type'
metadata_columns = ['mouse', 'group']  # Only mouse and group now

if cluster_key not in adata.obs:
    raise ValueError(f"Column {cluster_key} not found in adata.obs")

for col in metadata_columns:
    if col not in adata.obs:
        print(f"Warning: Metadata column {col} not found in adata.obs - skipping")

configure_plot_style()

# Import for palette
import matplotlib.colors as mcolors

# Set up dittoSeq palette for publication quality
n_clusters = len(adata.obs[cluster_key].cat.categories)
custom_palette = get_dittoseq_colors(n_clusters)
adata.uns[f'{cluster_key}_colors'] = custom_palette

# ============================================================
# ORIGINAL PLOTS (Full Dataset)
# ============================================================
print("Generating full dataset plots using cell_type...")

for meta_col in metadata_columns:
    if meta_col not in adata.obs:
        continue
    
    # Get custom parameters for this metadata type
    if meta_col == 'mouse':
        dot_size = mouse_dot_size
        fig_width = mouse_fig_width
        fig_height = mouse_fig_height
    elif meta_col == 'group':
        dot_size = group_dot_size
        fig_width = group_fig_width
        fig_height = group_fig_height
    
    meta_order = adata.obs[meta_col].cat.categories.tolist()
    n_items = len(meta_order)
    # Use 6 columns for mouse; otherwise keep up to 4 columns
    if meta_col == 'mouse':
        n_cols = 6
    else:
        n_cols = min(4, n_items)
    n_rows = (n_items + n_cols - 1) // n_cols  # Calculate rows needed
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_width * n_cols, fig_height * n_rows), sharex=True, sharey=True)
    if n_items == 1:
        axes = [axes]
    else:
        axes = axes.flatten() if n_items > 1 else [axes]
    
    for i, item in enumerate(meta_order):
        mask = adata.obs[meta_col] == item
        if mask.sum() > 0:
            # Show legend only on the last panel
            show_legend = (i == len(meta_order) - 1)
            sc.pl.umap(
                adata[mask],
                color=cluster_key,
                ax=axes[i],
                show=False,
                title=f'{item} ({mask.sum():,} cells)',
                legend_loc='right margin' if show_legend else None,
                frameon=False,
                size=dot_size,
                palette=adata.uns[f'{cluster_key}_colors']
            )
        else:
            axes[i].set_visible(False)
            print(f'Warning: {meta_col} {item} not found')
    
    # Hide unused subplots
    if n_items > 1:
        for j in range(n_items, len(axes)):
            axes[j].set_visible(False)
    
    plt.suptitle(f'Cell Types Split by {meta_col}', y=1.02, fontsize=18)
    plt.tight_layout()
    # Save at 400 DPI for crisp UMAP dots
    plt.savefig(output_dir / 'png' / f'{meta_col}_umaps.png', dpi=400, bbox_inches='tight')
    plt.savefig(output_dir / 'pdf' / f'{meta_col}_umaps.pdf', dpi=400, bbox_inches='tight', format='pdf')
    print(f'Saved: {meta_col}_umaps.png/.pdf')
    plt.close()

print('\n✓ Completed full dataset metadata-stratified visualizations')

# ============================================================
# DOWNSAMPLED PLOTS (Equal Cells Per Group)
# ============================================================
print("\n" + "="*60)
print("Generating downsampled plots with equal cells per group...")
print("="*60)

for meta_col in metadata_columns:
    if meta_col not in adata.obs:
        continue
    
    # Get custom parameters for this metadata type
    if meta_col == 'mouse':
        dot_size = mouse_dot_size
        fig_width = mouse_fig_width
        fig_height = mouse_fig_height
    elif meta_col == 'group':
        dot_size = group_dot_size
        fig_width = group_fig_width
        fig_height = group_fig_height
    
    print(f"\nProcessing {meta_col}...")
    
    # Get category counts
    meta_order = adata.obs[meta_col].cat.categories.tolist()
    counts = adata.obs[meta_col].value_counts()
    min_cells = counts.min()
    
    print(f"  Cell counts per {meta_col}: {dict(counts)}")
    print(f"  Downsampling to: {min_cells:,} cells per group")
    
    # Downsample to equal number of cells per category
    np.random.seed(42)  # For reproducibility
    downsampled_indices = []
    for item in meta_order:
        item_indices = np.where(adata.obs[meta_col] == item)[0]
        sampled_indices = np.random.choice(item_indices, size=min_cells, replace=False)
        downsampled_indices.extend(sampled_indices)
    
    # Create downsampled AnnData
    adata_downsampled = adata[downsampled_indices].copy()
    
    # Generate UMAP plots with downsampled data
    n_items = len(meta_order)
    # Use 6 columns for mouse; otherwise keep up to 6 columns
    if meta_col == 'mouse':
        n_cols = 6
    else:
        n_cols = min(6, n_items)
    n_rows = (n_items + n_cols - 1) // n_cols
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_width * n_cols, fig_height * n_rows), sharex=True, sharey=True)
    if n_items == 1:
        axes = [axes]
    else:
        axes = axes.flatten() if n_items > 1 else [axes]
    
    for i, item in enumerate(meta_order):
        mask = adata_downsampled.obs[meta_col] == item
        if mask.sum() > 0:
            # Show legend only on the last panel
            show_legend = (i == len(meta_order) - 1)
            sc.pl.umap(
                adata_downsampled[mask],
                color=cluster_key,
                ax=axes[i],
                show=False,
                title=f'{item} ({mask.sum():,} cells)',
                legend_loc='right margin' if show_legend else None,
                frameon=False,
                size=dot_size * 2,  # Larger dots for downsampled
                palette=adata.uns[f'{cluster_key}_colors']
            )
        else:
            axes[i].set_visible(False)
            print(f'Warning: {meta_col} {item} not found in downsampled data')
    
    # Hide unused subplots
    if n_items > 1:
        for j in range(n_items, len(axes)):
            axes[j].set_visible(False)
    
    plt.suptitle(f'Cell Types Split by {meta_col} - Downsampled (n={min_cells:,}/group)', 
                 y=1.02, fontsize=18)
    plt.tight_layout()
    # Save at 400 DPI and show inline (downsampled only)
    plt.savefig(output_dir / 'png' / f'{meta_col}_umaps_downsampled.png', dpi=400, bbox_inches='tight')
    plt.savefig(output_dir / 'pdf' / f'{meta_col}_umaps_downsampled.pdf', dpi=400, bbox_inches='tight', format='pdf')
    show_inline_plot()
    print(f'  Saved: {meta_col}_umaps_downsampled.png/.pdf')
    plt.close()

print('\n✓ Completed downsampled metadata-stratified visualizations')

## Gene Expression Visualization

### Gene Expression Dotplot by Cluster

Visualize marker gene expression across annotated cell types (`cell_type`) using dotplots. This visualization shows:
- **Dot color**: Mean expression level per cell type
- **Dot size**: Fraction of cells expressing the gene
- **Dendrogram**: Hierarchical relationship between cell types

In [None]:
# Ensure inline plotting in Jupyter
%matplotlib inline

# ============================================================
# USER CONFIGURATION
# ============================================================
# Select plot dimensions
plot_width = 5      # figure width (in inches)
plot_height = 18    # figure height (in inches)

# ============================================================
# VERIFY DATA AND LAYERS
# ============================================================
cluster_key = 'cell_type'
if cluster_key not in adata.obs:
    raise ValueError(f"Column {cluster_key} not found in adata.obs")

# Check if normalized layer exists, if not use X (which should be normalized)
use_layer = 'normalized' if 'normalized' in adata.layers else None
if use_layer is None:
    print("Note: Using adata.X for gene expression (no 'normalized' layer found)")

# Genes to visualize (mouse gene symbols)
# Classic marker genes for major cell types
genes = [
    # Homeostatic Microglia (resting, surveillance state)
    'Tmem119', 'P2ry12', 'Cx3cr1', 'Sall1', 'Fcrls', 'Siglech', 'Hexb',

    # Activated/Pro-inflammatory Microglia (M1-like, immune-activated)
    'Il1b', 'Tnf', 'Cd68', 'Nos2', 'Cd86',

    # Anti-inflammatory/Repair Microglia (M2-like, tissue repair)
    'Arg1', 'Mrc1', 'Il10', 'Tgfb1', 'Ym1',

    # Disease-Associated Microglia (DAM, neurodegenerative contexts)
    'Trem2', 'Apoe', 'Cst7', 'Lpl', 'Tyrobp', 'Clec7a',

    # Aged Microglia (altered in aging brain)
    'Ccl2', 'C1qa', 'B2m',

    # Proliferative Microglia (dividing, injury/disease)
    'Mki67', 'Top2a', 'Cdk1', 'Ccna2', 'Birc5',

    # Interferon-Responsive Microglia (viral/interferon response)
    'Ifit1', 'Irf7', 'Stat1', 'Cxcl10', 'Isg15',

    # Phagocytic Microglia (enhanced phagocytosis)
    'Trem2', 'Cd68', 'Mertk', 'Axl', 'C1qa',

    # General Immune Cell Marker
    'Ptprc','H2-Aa', 'H2-Ab1',
    'Mrc1', 'Ccr2', 'Ly6g', 'Ly6c2', 'Ms4a7',
    'Cdk8', 'Cmss1', 'Lars2',
    'H2-Q7', 'H2-Aa', 'H2-Ab1', 'H2-Q4', 'Cd74',
    'Ifi27I2a', ' Arhgap15', 'Klf2', 'Cd52', 'Cd34'
]

# Check available genes
available_genes = [g for g in genes if g in adata.var_names]
missing_genes = [g for g in genes if g not in adata.var_names]

if missing_genes:
    print(f'Warning: Genes not found in adata.var_names: {", ".join(missing_genes)}')
if not available_genes:
    raise ValueError('No valid genes for plotting.')

# ============================================================
# CREATE CUSTOM GREY-PLASMA COLORMAP
# ============================================================
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

# The low end of the scale fades from light grey into plasma_r, so near-zero
# mean expression reads as background rather than as a saturated hue.
n_colors = 256
grey_to_plasma_transition = 0.2   # fraction of the scale given to the grey ramp
plasma_r = plt.cm.plasma_r
plasma_colors = plasma_r(np.linspace(0, 1, n_colors))   # one RGBA row per sample

n_grey_colors = int(n_colors * grey_to_plasma_transition)
light_grey = np.array([0.85, 0.85, 0.85, 1.0])
transition_color = plasma_colors[n_grey_colors]          # first plasma color that is kept
grey_gradient = np.linspace(light_grey, transition_color, n_grey_colors)

# Grey ramp followed by the untouched upper part of plasma_r
custom_colors = np.concatenate((grey_gradient, plasma_colors[n_grey_colors:]), axis=0)
dotplot_colormap = LinearSegmentedColormap.from_list('grey_plasma', custom_colors)

print('✓ Created custom grey-plasma colormap for dotplot')

# Configure plot style
configure_plot_style()

# ============================================================
# COMPUTE DENDROGRAM FOR CURRENT GROUPING
# ============================================================
# The dotplot below uses dendrogram=True; computing it here ties it to the
# current categories of `cluster_key`.
print(f'Computing dendrogram for {cluster_key}...')
sc.tl.dendrogram(adata, groupby=cluster_key)
print('✓ Dendrogram computed')

# ============================================================
# GENERATE DOTPLOT
# ============================================================
fig, ax = plt.subplots(figsize=(plot_width, plot_height), dpi=400)
dotplot = sc.pl.dotplot(
    adata,
    var_names=available_genes,
    groupby=cluster_key,
    layer=use_layer,
    dendrogram=True,
    ax=ax,
    return_fig=True,
    dot_max=0.8,
    dot_min=0.05,
    colorbar_title='Mean expression\nin group',
    size_title='% of cells\nexpressing gene',
    cmap=dotplot_colormap,  # Use custom grey-plasma colormap
    swap_axes=True,
    var_group_rotation=90
)

# Make dot borders thinner
main_ax = dotplot.get_axes()['mainplot_ax']
for collection in main_ax.collections:
    collection.set_linewidths(0)

# X-axis labels 90 degrees and larger
main_ax.tick_params(axis='x', labelsize=12, rotation=90)
for label in main_ax.get_xticklabels():
    label.set_rotation(90)
    label.set_ha('right')

# Ensure ticks are visible on both axes
main_ax.tick_params(axis='both', bottom=True, left=True, labelbottom=True, labelleft=True, length=4, direction='out')
main_ax.spines['bottom'].set_visible(True)
main_ax.spines['left'].set_visible(True)

plt.suptitle('Gene Expression Dotplot by Cell Type (Normalized)', y=1.05, fontsize=18)
plt.subplots_adjust(bottom=0.2)
plt.tight_layout()

# Save at 400 DPI and show inline
plt.savefig(output_dir / 'png' / 'rna_dotplot_cell_type_normalized.png', dpi=400, bbox_inches='tight')
plt.savefig(output_dir / 'pdf' / 'rna_dotplot_cell_type_normalized.pdf', dpi=400, bbox_inches='tight', format='pdf')
print('Saved: rna_dotplot_cell_type_normalized.png/.pdf')
show_inline_plot()
plt.close()

### Dotplot: Selected Marker Genes (by Cell Type)
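This plot shows a curated subset of marker genes across `cell_type`. Because the dotplot uses `swap_axes=True`, the x-axis lists cell types (labels rotated 90° for readability) and the y-axis lists genes. Cell types keep their original category order (no dendrogram reordering).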

In [None]:
# Ensure inline plotting in Jupyter
%matplotlib inline

# ============================================================
# USER CONFIGURATION (Selected genes only)
# ============================================================
plot_width = 4.5     # figure width in inches
plot_height = 10     # figure height in inches
cluster_key = 'cell_type'

# Curated marker panel, grouped by state (edit as needed)
selected_genes = [
    'Tmem119', 'P2ry12', 'Cx3cr1', 'Sall1', 'Fcrls', 'Hexb',   # Homeostatic
    'Cd68', 'Clec7a', 'Lpl', 'Cst7', 'Trem2', 'Apoe',          # Activated/Phagocytic/Intermediate
    'Cdk8', 'Cmss1', 'Lars2',                                  # Universal Activated Myeloid state
    'Ifit1', 'Irf7', 'Stat1',                                  # Interferon
    'Mki67', 'Top2a', 'Ccna2',                                 # Proliferative
    'Ptprc', 'Itgam', 'Itgax',                                 # General immune
]

# ============================================================
# VERIFY DATA AND LAYERS
# ============================================================
if cluster_key not in adata.obs:
    raise ValueError(f"Column {cluster_key} not found in adata.obs")

# Freeze the current category order as an *ordered* categorical so cell types
# are drawn in exactly this order (no dendrogram reordering is requested below).
ct_labels = adata.obs[cluster_key]
if ct_labels.dtype.name != 'category':
    ct_labels = ct_labels.astype(str).astype('category')
adata.obs[cluster_key] = ct_labels.cat.set_categories(
    list(ct_labels.cat.categories), ordered=True
)

use_layer = None if 'normalized' not in adata.layers else 'normalized'
if not use_layer:
    print("Note: Using adata.X for gene expression (no 'normalized' layer found)")

# Split the panel into genes present in the data and genes that are absent
available_genes, missing_genes = [], []
for gene_name in selected_genes:
    (available_genes if gene_name in adata.var_names else missing_genes).append(gene_name)
if missing_genes:
    print(f'Warning: Genes not found in adata.var_names: {", ".join(missing_genes)}')
if not available_genes:
    raise ValueError('No valid genes in selected list for plotting.')

# ============================================================
# COLORMAP AND STYLE
# ============================================================
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

configure_plot_style()

# Grey→plasma colormap: the lowest 20% of the scale fades from light grey into
# plasma_r, so low values read as grey instead of bright yellow.
n_colors = 256
plasma_r = plt.cm.plasma_r
plasma_colors = plasma_r(np.linspace(0, 1, n_colors))
grey_to_plasma_transition = 0.2
n_grey_colors = int(n_colors * grey_to_plasma_transition)
light_grey = np.array([0.85, 0.85, 0.85, 1.0])
transition_color = plasma_colors[n_grey_colors]
grey_gradient = np.linspace(light_grey, transition_color, n_grey_colors)
custom_colors = np.concatenate([grey_gradient, plasma_colors[n_grey_colors:]], axis=0)
dotplot_colormap = LinearSegmentedColormap.from_list('grey_plasma', custom_colors)

# ============================================================
# DOTPLOT (no dendrogram; original order)
# ============================================================
# swap_axes=True places the cell types on the x-axis and the genes on the y-axis.
dotplot_options = {
    'var_names': available_genes,
    'groupby': cluster_key,
    'layer': use_layer,
    'dendrogram': False,        # keep the frozen category order
    'return_fig': True,
    'dot_max': 0.8,
    'dot_min': 0.05,
    'colorbar_title': 'Mean expression\nin group',
    'size_title': '% of cells\nexpressing gene',
    'cmap': dotplot_colormap,
    'swap_axes': True,
    'var_group_rotation': 90,
}
fig, ax = plt.subplots(figsize=(plot_width, plot_height), dpi=400)
dotplot = sc.pl.dotplot(adata, ax=ax, **dotplot_options)

# get_axes() renders the DotPlot and exposes its sub-axes
main_ax = dotplot.get_axes()['mainplot_ax']

# Drop the outline around every dot
for dot_collection in main_ax.collections:
    dot_collection.set_linewidths(0)

# Vertical, centred x-axis labels
main_ax.tick_params(axis='x', labelsize=12, rotation=90)
plt.setp(main_ax.get_xticklabels(), rotation=90, ha='center')

# Outward ticks on both axes, with the bottom/left spines shown
main_ax.tick_params(axis='both', bottom=True, left=True, labelbottom=True, labelleft=True, length=4, direction='out')
for side in ('bottom', 'left'):
    main_ax.spines[side].set_visible(True)

plt.suptitle('Selected Gene Dotplot by Cell Type (Normalized)', y=1.05, fontsize=16)
plt.subplots_adjust(bottom=0.22)
plt.tight_layout()

# Export PNG + PDF at 400 DPI, then display inline
plt.savefig(output_dir / 'png' / 'rna_dotplot_cell_type_selected_genes.png', dpi=400, bbox_inches='tight')
plt.savefig(output_dir / 'pdf' / 'rna_dotplot_cell_type_selected_genes.pdf', dpi=400, bbox_inches='tight', format='pdf')
print('Saved: rna_dotplot_cell_type_selected_genes.png/.pdf')
show_inline_plot()
plt.close()

### Gene Expression UMAPs

Visualize expression patterns of marker genes on UMAP projections. This helps:
- Validate cluster identities
- Identify cell types based on canonical markers
- Assess quality of normalization and integration

Gene expression is visualized using normalized (log1p) data.

In [None]:
# Ensure inline plotting in Jupyter
%matplotlib inline

# Genes to visualize (mouse gene symbols, must match adata.var_names exactly)
# Classic marker genes for major cell types. Some genes belong to several
# signatures; the list is de-duplicated below so each gene gets one panel.
genes = [
    # Homeostatic Microglia (resting, surveillance state)
    'Tmem119', 'P2ry12', 'Cx3cr1', 'Sall1', 'Fcrls', 'Siglech', 'Hexb',

    # Activated/Pro-inflammatory Microglia (M1-like, immune-activated)
    'Il1b', 'Tnf', 'Cd68', 'Nos2', 'Cd86',

    # Anti-inflammatory/Repair Microglia (M2-like, tissue repair)
    # 'Chil3' is the official mouse symbol for Ym1
    'Arg1', 'Mrc1', 'Il10', 'Tgfb1', 'Chil3',

    # Disease-Associated Microglia (DAM, neurodegenerative contexts)
    'Trem2', 'Apoe', 'Cst7', 'Lpl', 'Tyrobp', 'Clec7a',

    # Aged Microglia (altered in aging brain)
    'Ccl2', 'C1qa', 'B2m',

    # Proliferative Microglia (dividing, injury/disease)
    'Mki67', 'Top2a', 'Cdk1', 'Ccna2', 'Birc5',

    # Interferon-Responsive Microglia (viral/interferon response)
    'Ifit1', 'Irf7', 'Stat1', 'Cxcl10', 'Isg15',

    # Phagocytic Microglia (enhanced phagocytosis)
    'Trem2', 'Cd68', 'Mertk', 'Axl', 'C1qa',

    # General Immune Cell Markers (Ptprc = CD45, all immune cells)
    'Ptprc', 'H2-Aa', 'H2-Ab1',
    'Mrc1', 'Ccr2', 'Ly6g', 'Ly6c2', 'Ms4a7',
    'Cdk8', 'Cmss1', 'Lars2',
    'H2-Q7', 'H2-Aa', 'H2-Ab1', 'H2-Q4', 'Cd74',
]

# Strip stray whitespace and drop repeats, keeping first-occurrence order
genes = list(dict.fromkeys(g.strip() for g in genes))

# Verify layers and UMAP
use_layer = 'normalized' if 'normalized' in adata.layers else None
if use_layer is None:
    print("Note: Using adata.X for gene expression (no 'normalized' layer found)")
if 'X_umap' not in adata.obsm:
    raise ValueError("X_umap not found in adata.obsm")

# ============================================================
# CREATE CUSTOM GREY-PLASMA COLORMAP
# ============================================================
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

# Get plasma_r colormap
plasma_r = plt.cm.plasma_r
n_colors = 256
plasma_colors = plasma_r(np.linspace(0, 1, n_colors))

# Replace first 20% with grey gradient
grey_to_plasma_transition = 0.2
n_grey_colors = int(n_colors * grey_to_plasma_transition)

# Create grey gradient from light grey to first plasma color
light_grey = np.array([0.85, 0.85, 0.85, 1.0])
transition_color = plasma_colors[n_grey_colors]
grey_gradient = np.linspace(light_grey, transition_color, n_grey_colors)

# Combine grey gradient with plasma_r
custom_colors = np.vstack([grey_gradient, plasma_colors[n_grey_colors:]])
gene_colormap = LinearSegmentedColormap.from_list('grey_plasma', custom_colors)

print('✓ Created custom grey-plasma colormap for gene expression')

# Configure plot style
configure_plot_style()

# ============================================================
# GENERATE GENE EXPRESSION UMAPs
# ============================================================

# Check available genes
available_genes_norm = [g for g in genes if g in adata.var_names]
missing_genes_norm = [g for g in genes if g not in adata.var_names]

if missing_genes_norm:
    print(f'Warning: Genes not found in adata.var_names: {", ".join(missing_genes_norm)}')
if not available_genes_norm:
    raise ValueError('No valid genes for plotting.')

# Set up plot dimensions
n_cols = 7  # Adjust accordingly
n_rows = (len(available_genes_norm) + n_cols - 1) // n_cols  # Ceiling division
# squeeze=False always yields a 2-D array of Axes, so flatten() is safe for any
# n_cols (including a 1x1 grid).
fig, axes = plt.subplots(
    n_rows, n_cols,
    figsize=(5 * n_cols, 4 * n_rows),
    sharex=True, sharey=True,
    squeeze=False,
)
axes = axes.flatten()

for i, gene in enumerate(available_genes_norm):
    sc.pl.umap(
        adata,
        color=gene,
        layer=use_layer,
        ax=axes[i],
        show=False,
        legend_loc='none',
        cmap=gene_colormap,  # Use custom grey-plasma colormap
        frameon=False,
        size=20,
        vmin='p5',           # clip color scale to the 5th..99th percentile
        vmax='p99'
    )
    axes[i].set_title(gene, fontsize=20)  # Increased font size for legibility

# Hide grid cells that have no gene assigned
for j in range(len(available_genes_norm), len(axes)):
    axes[j].set_visible(False)

plt.suptitle('Gene Expression UMAPs', y=1.05, fontsize=18)
plt.tight_layout()
plt.savefig(output_dir / 'png' / 'rna_umap_normalized.png', dpi=300, bbox_inches='tight')
plt.savefig(output_dir / 'pdf' / 'rna_umap_normalized.pdf', dpi=300, bbox_inches='tight', format='pdf')
print('Saved: rna_umap_normalized.png/.pdf')
show_inline_plot()
plt.close()

### Gene Expression UMAPs (Selected Genes, 3 columns)
UMAPs of a curated set of marker genes (from `selected_genes`) using the normalized layer (if available). Layout uses 3 columns per row for readability.

In [None]:
# Ensure inline plotting in Jupyter
%matplotlib inline

# ============================================================
# USER CONFIGURATION (Selected genes; 3 columns layout)
# ============================================================
selected_genes = [
    # Homeostatic
    'Tmem119', 'P2ry12', 'Cx3cr1',
    'Sall1', 'Fcrls', 'Hexb',
    # Activated/Phagocytic/Intermediate
    'Cd68', 'Clec7a', 'Lpl',
    'Cst7', 'Trem2', 'Apoe',
    # Universal Activated Myeloid state
    'Cdk8', 'Cmss1', 'Lars2',
    # Interferon
    'Ifit1', 'Irf7', 'Stat1',
    # Proliferative
    'Mki67', 'Top2a', 'Ccna2',
    # General immune
    'Ptprc', 'Itgam', 'Itgax'
]
n_cols = 3  # columns per row
dot_size = 20
title_fontsize = 16

# Verify layer and UMAP
use_layer = 'normalized' if 'normalized' in adata.layers else None
if use_layer is None:
    print("Note: Using adata.X for gene expression (no 'normalized' layer found)")
if 'X_umap' not in adata.obsm:
    raise ValueError("X_umap not found in adata.obsm")

# Filter genes to those present
available_genes_sel = [g for g in selected_genes if g in adata.var_names]
missing_genes_sel = [g for g in selected_genes if g not in adata.var_names]
if missing_genes_sel:
    print(f'Warning: Genes not found in adata.var_names: {", ".join(missing_genes_sel)}')
if not available_genes_sel:
    raise ValueError('No valid genes in selected list for plotting.')

# Grey→plasma colormap (light grey for the lowest 20% of the scale, then
# plasma_r). Built here so this cell runs on its own instead of relying on an
# earlier cell having defined `gene_colormap`.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

_plasma_colors = plt.cm.plasma_r(np.linspace(0, 1, 256))
_n_grey = int(256 * 0.2)
_grey_ramp = np.linspace(np.array([0.85, 0.85, 0.85, 1.0]), _plasma_colors[_n_grey], _n_grey)
gene_colormap = LinearSegmentedColormap.from_list(
    'grey_plasma', np.vstack([_grey_ramp, _plasma_colors[_n_grey:]])
)

configure_plot_style()

fig_h_unit = 4
fig_w_unit = 5
n_rows = (len(available_genes_sel) + n_cols - 1) // n_cols
# squeeze=False always returns a 2-D array, so flattening works for any grid size
fig, axes = plt.subplots(
    n_rows, n_cols,
    figsize=(fig_w_unit * n_cols, fig_h_unit * n_rows),
    sharex=True, sharey=True,
    squeeze=False,
)
axes = axes.flatten()

for i, gene in enumerate(available_genes_sel):
    sc.pl.umap(
        adata,
        color=gene,
        layer=use_layer,
        ax=axes[i],
        show=False,
        legend_loc='none',
        cmap=gene_colormap,  # custom grey-plasma colormap
        frameon=False,
        size=dot_size,
        vmin='p5',           # clip color scale to the 5th..99th percentile
        vmax='p99'
    )
    axes[i].set_title(gene, fontsize=title_fontsize)

# Hide any unused axes
for j in range(len(available_genes_sel), len(axes)):
    axes[j].set_visible(False)

plt.suptitle('Gene Expression UMAPs (Selected Genes)', y=1.05, fontsize=18)
plt.tight_layout()
plt.savefig(output_dir / 'png' / 'rna_umap_selected_genes.png', dpi=400, bbox_inches='tight')
plt.savefig(output_dir / 'pdf' / 'rna_umap_selected_genes.pdf', dpi=400, bbox_inches='tight', format='pdf')
print('Saved: rna_umap_selected_genes.png/.pdf')
show_inline_plot()
plt.close()

### Violin plots: selected genes by cell_type and group (3 columns)

This section visualizes expression distributions of a curated gene set across `cell_type`, split by experimental groups (`PBS-CFA` vs `PEP15-CFA`). Each subplot shows:
- Side-by-side violins per group (one on each side of every cell type's tick) with consistent colors (PBS-CFA: #3C5488, PEP15-CFA: #E64B35)
- Median line (white) for each split
- Jittered points (subsampled when large) for per-cell values
- 3-column layout for readability

Data use the normalized layer when available; otherwise `adata.X`. Cell type order follows the original category order (no dendrogram reordering). Figures are saved at 400 DPI for publication-quality rasterized dots.

In [None]:
# Ensure inline plotting
%matplotlib inline

# ============================================================
# USER CONFIGURATION (Violin plots for selected genes; 3 columns)
# ============================================================
cluster_key = 'cell_type'          # Use cell_type (not somen_clusters or res)
group_column = 'group'             # Column with experimental groups
group_1 = 'PBS-CFA'                # First group
group_2 = 'PEP15-CFA'              # Second group

# Colors (matching previous bar plot)
group_colors = {group_1: '#3C5488', group_2: '#E64B35'}

# Selected genes (curate as needed)
selected_genes = [
    # Homeostatic
    'Tmem119', 'P2ry12', 'Cx3cr1',
    'Sall1', 'Fcrls', 'Hexb',
    # Activated/Phagocytic/Intermediate
    'Cd68', 'Clec7a', 'Lpl',
    'Cst7', 'Trem2', 'Apoe',
    # Universal Activated Myeloid state
    'Cdk8', 'Cmss1', 'Lars2',
    # Interferon
    'Ifit1', 'Irf7', 'Stat1',
    # Proliferative
    'Mki67', 'Top2a', 'Ccna2',
    # General immune
    'Ptprc', 'Itgam', 'Itgax'
]

# Plot settings
n_cols = 3                         # Number of columns in grid

# Control subplot size (inches) and x-axis label rotation here
subplot_width = 4.5                # Width of each subplot (inches)
subplot_height = 3.5               # Height of each subplot (inches)
x_tick_rotation = 45               # Choose 0, 45, or 90 degrees
# Diagonal labels are anchored at their right end so they finish under the
# tick; horizontal (0°) and vertical (90°) labels are centred on the tick.
x_tick_ha = 'right' if 0 < x_tick_rotation < 90 else 'center'

point_size = 0.2                   # Size of individual data points
violin_alpha = 0.6                 # Transparency of violins
max_points = 500                   # Max jittered points per violin (subsampled above this)
random_seed = 42                   # Seed for point subsampling and jitter

# ============================================================
# VERIFY DATA
# ============================================================

# Verify columns exist
if cluster_key not in adata.obs:
    raise ValueError(f"Column '{cluster_key}' not found in adata.obs")
if group_column not in adata.obs:
    raise ValueError(f"Column '{group_column}' not found in adata.obs")

# Ensure categorical for stable ordering
if adata.obs[cluster_key].dtype.name != 'category':
    adata.obs[cluster_key] = adata.obs[cluster_key].astype(str).astype('category')

# Check if normalized layer exists
use_layer = 'normalized' if 'normalized' in adata.layers else None
if use_layer is None:
    print("Note: Using adata.X for gene expression (no 'normalized' layer found)")
else:
    print(f"Using layer: {use_layer}")

# Check available genes
available_genes = [g for g in selected_genes if g in adata.var_names]
missing_genes = [g for g in selected_genes if g not in adata.var_names]

if missing_genes:
    print(f'\nWarning: Genes not found: {", ".join(missing_genes)}')
if not available_genes:
    raise ValueError('No valid genes for plotting.')

print(f'\nPlotting {len(available_genes)} genes: {", ".join(available_genes)}')

# ============================================================
# PREPARE DATA
# ============================================================
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import numpy as np

# Get clusters and groups
clusters = adata.obs[cluster_key].cat.categories
groups = [group_1, group_2]

print(f'\nCell types: {list(clusters)}')
print(f'Groups: {groups}')

# Configure plot style
configure_plot_style()

# One x slot per cell type; group_1's violin sits left of the slot centre and
# group_2's violin sits right of it.
ct_order = list(clusters)
n_ct = len(ct_order)
positions = np.arange(n_ct)
width = 0.4  # Width of each violin
x_offsets = {grp: (-width / 2 if grp == group_1 else width / 2) for grp in groups}

# Cell membership per (cell type, group) is the same for every gene, so the
# row indices are computed once here instead of re-filtering all cells for
# each gene. Combinations without any cells are omitted (and skipped below).
group_masks = {grp: (adata.obs[group_column] == grp).to_numpy() for grp in groups}
subset_rows = {}
for ct in ct_order:
    ct_mask = (adata.obs[cluster_key] == ct).to_numpy()
    for grp in groups:
        rows = np.flatnonzero(ct_mask & group_masks[grp])
        if rows.size:
            subset_rows[(ct, grp)] = rows

# Local generator: reproducible subsampling/jitter without reseeding NumPy's
# global random state (which other cells may rely on).
rng = np.random.default_rng(random_seed)

# ============================================================
# CREATE COMBINED VIOLIN PLOT GRID
# ============================================================

# Calculate grid dimensions
n_genes = len(available_genes)
n_rows = (n_genes + n_cols - 1) // n_cols  # Ceiling division

# squeeze=False always returns a 2-D array of Axes, so ravel() yields one Axes
# per grid cell for any grid size (including a single gene).
fig, axes = plt.subplots(
    n_rows, n_cols,
    figsize=(subplot_width * n_cols, subplot_height * n_rows),
    dpi=400,
    squeeze=False,
)
axes = axes.ravel()

print(f'\nCreating {n_rows} x {n_cols} grid for {n_genes} genes...')

# Plot each gene
for gene_idx, gene in enumerate(available_genes):
    print(f'  Plotting {gene}...')
    ax = axes[gene_idx]

    # Extract this gene's expression as a dense 1-D array (one value per cell)
    expr = adata[:, gene].layers[use_layer] if use_layer else adata[:, gene].X
    gene_expr = expr.toarray().ravel() if hasattr(expr, 'toarray') else np.asarray(expr).ravel()

    for i, ct in enumerate(ct_order):
        for grp in groups:
            rows = subset_rows.get((ct, grp))
            if rows is None:
                continue
            values = gene_expr[rows]
            x_center = i + x_offsets[grp]

            parts = ax.violinplot(
                [values],
                positions=[x_center],
                widths=width,
                showmeans=False,
                showextrema=False,
                showmedians=False
            )
            # Color violins
            for pc in parts['bodies']:
                pc.set_facecolor(group_colors[grp])
                pc.set_alpha(violin_alpha)
                pc.set_edgecolor('black')
                pc.set_linewidth(0.8)

            # Median line (white), drawn above the violin body
            median_val = np.median(values)
            ax.plot([x_center - width / 2.5, x_center + width / 2.5],
                    [median_val, median_val],
                    color='white', linewidth=2, zorder=10)

            # Individual jittered points, subsampled when the subset is large
            if values.size > max_points:
                shown = rng.choice(values, size=max_points, replace=False)
            else:
                shown = values
            x_jitter = rng.normal(x_center, width / 8, size=shown.size)
            ax.scatter(
                x_jitter,
                shown,
                s=point_size,
                color='black',
                alpha=0.3,
                zorder=5,
                rasterized=True
            )

    # Customize subplot
    ax.set_xticks(positions)
    ax.set_xticklabels(ct_order, fontsize=9, rotation=x_tick_rotation, ha=x_tick_ha)
    # Ensure x-axis ticks are visible
    ax.tick_params(axis='x', which='both', bottom=True, top=False, length=4, width=0.8)
    ax.xaxis.set_tick_params(direction='out')

    ax.set_xlabel('Cell Type', fontsize=10, fontweight='bold')
    ax.set_ylabel('Normalized Expression', fontsize=10, fontweight='bold')
    ax.set_title(gene, fontsize=12, fontweight='bold', pad=10)

    # Grid and spines
    ax.yaxis.grid(True, linestyle='--', alpha=0.3, zorder=0)
    ax.set_axisbelow(True)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

# Hide grid cells that have no gene (when n_genes is not a multiple of n_cols)
for unused_ax in axes[n_genes:]:
    unused_ax.set_visible(False)

# Figure-level legend (bottom), avoid blocking subplots
legend_elements = [
    Patch(facecolor=group_colors[group_1], alpha=violin_alpha, edgecolor='black', label=group_1),
    Patch(facecolor=group_colors[group_2], alpha=violin_alpha, edgecolor='black', label=group_2)
]
fig.legend(handles=legend_elements, loc='lower center', ncol=2, frameon=True,
           edgecolor='black', fontsize=10, bbox_to_anchor=(0.5, -0.02))

# Adjust layout to make room for bottom legend
plt.tight_layout(rect=[0, 0.05, 1, 0.98])

# Save combined plot
plot_name = f'{cluster_key}_violin_selected_genes_by_group'
save_plot(plot_name, close=False)  # saves PNG/PDF at 400 DPI
show_inline_plot()
plt.close()

print(f'\n✓ Saved: {plot_name}.png/.pdf')
print('='*60)

## Export Cell Counts: Metadata


In [None]:
# ============================================================
# EXPORT CELL COUNTS: CELL_TYPE × METADATA
# ============================================================

# User configuration
cluster_key = 'cell_type'
groupby_col = 'group'  # Change to 'library', 'mouse', or other metadata column

print('='*80)
print(f'EXPORTING CELL COUNTS PER CELL_TYPE AND {groupby_col.upper()}')
print('='*80)

# Check if column exists
if groupby_col not in adata.obs:
    print(f'⚠ Column "{groupby_col}" not found in adata.obs')
    print(f'Available columns: {list(adata.obs.columns)}')
elif cluster_key not in adata.obs:
    print(f'⚠ Column "{cluster_key}" not found in adata.obs')
else:
    cell_labels = adata.obs[cluster_key]
    meta_labels = adata.obs[groupby_col]

    # Plain counts: cell types (rows) × metadata values (columns). crosstab
    # drops categories that have no cells (dropna=True by default).
    raw_counts = pd.crosstab(cell_labels, meta_labels)

    # Row order: original category order if categorical, otherwise sorted labels
    if cell_labels.dtype.name == 'category':
        celltype_order = [c for c in cell_labels.cat.categories if c in raw_counts.index]
    else:
        celltype_order = sorted(raw_counts.index)
    raw_counts = raw_counts.reindex(celltype_order)

    # Same table with row and column totals; the 'Total' row is kept last
    count_table = pd.crosstab(
        cell_labels,
        meta_labels,
        margins=True,
        margins_name='Total'
    ).reindex(celltype_order + ['Total'])

    # % of each group's cells falling in each cell type (each column sums to 100)
    pct_by_group = raw_counts.div(raw_counts.sum(axis=0), axis=1) * 100
    # % of each cell type's cells coming from each group (each row sums to 100)
    pct_by_celltype = raw_counts.div(raw_counts.sum(axis=1), axis=0) * 100

    # Display preview
    print('\nPreview of count table:')
    print(count_table.head(10))
    # Dimensions exclude the 'Total' margin row/column
    print(f'\nTable shape: {raw_counts.shape[0]} cell types × {raw_counts.shape[1]} groups')

    # Export to Excel (sheet names are limited to 31 characters by Excel)
    excel_file = output_dir / f'cell_counts_{cluster_key}_by_{groupby_col}.xlsx'
    pct_group_sheet = f'Percent_by_{groupby_col}'[:31]

    with pd.ExcelWriter(excel_file, engine='openpyxl') as writer:
        # Sheet 1: Raw counts (with totals)
        count_table.to_excel(writer, sheet_name='Cell_Counts')
        # Sheet 2: Percentages (% of each group in each cell_type)
        pct_by_group.to_excel(writer, sheet_name=pct_group_sheet)
        # Sheet 3: Percentages (% of each cell_type from each group)
        pct_by_celltype.to_excel(writer, sheet_name='Percent_by_CellType')

    print(f'\n✓ Saved Excel file: {excel_file.name}')
    print('  Contains 3 sheets:')
    print('    1. Cell_Counts - Raw cell counts')
    print(f'    2. {pct_group_sheet} - % of each group in each cell_type')
    print('    3. Percent_by_CellType - % of each cell_type from each group')

print('='*80)

# ============================================================
# BARPLOT: GROUP COUNTS (Total cells per group)
# ============================================================
print('\n' + '='*80)
print('GENERATING BARPLOT: GROUP COUNTS')
print('='*80)

if 'group' in adata.obs:
    configure_plot_style()

    # Total cell counts per group, ordered by the group index
    group_counts = adata.obs['group'].value_counts().sort_index()

    # Define publication-quality colors for groups
    group_colors = ['#1f77b4', '#ff7f0e']  # Blue and orange

    fig, ax = plt.subplots(figsize=(8, 6), dpi=400)

    # One bar per group, colored in index order
    x_positions = range(len(group_counts))
    bars = ax.bar(x_positions, group_counts.values,
                  color=group_colors[:len(group_counts)])

    ax.set_xticks(x_positions)
    ax.set_xticklabels(group_counts.index, rotation=0, ha='center')

    # Annotate each bar with its count (thousands separator)
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2., height,
                f'{int(height):,}',
                ha='center', va='bottom', fontsize=11, fontweight='bold')

    ax.set_title('Cell Counts by Group', fontsize=14, fontweight='bold', pad=15)
    ax.set_xlabel('Group', fontsize=12)
    ax.set_ylabel('Cell Count', fontsize=12)

    plt.tight_layout()
    plt.savefig(output_dir / 'png' / 'cell_counts_barplot_group.png', dpi=400, bbox_inches='tight')
    plt.savefig(output_dir / 'pdf' / 'cell_counts_barplot_group.pdf', dpi=400, bbox_inches='tight', format='pdf')
    print('✓ Saved barplot: cell_counts_barplot_group.png/.pdf')
    # Display inline like the other plots in this notebook before closing
    show_inline_plot()
    plt.close()
else:
    print('⚠ Column "group" not found in adata.obs; skipping group barplot')

# ============================================================
# BARPLOT: MOUSE COUNTS (Total cells per mouse, colored by group)
# ============================================================
print('\n' + '='*80)
print('GENERATING BARPLOT: MOUSE COUNTS')
print('='*80)

if 'mouse' in adata.obs and 'group' in adata.obs:
    configure_plot_style()

    # Reorder mouse categories: PBS first (ascending), then PEP15 (ascending)
    if adata.obs['mouse'].dtype.name != 'category':
        adata.obs['mouse'] = adata.obs['mouse'].astype(str).astype('category')

    mice = list(adata.obs['mouse'].cat.categories)

    def parse_mouse(m):
        """Sort key for mouse labels, e.g. "HTO_PBS-2" -> ("PBS", 2).

        Labels that do not split on '-' into exactly two parts with an integer
        suffix fall back to (label, inf).
        """
        label = str(m)
        try:
            prefix, num = label.split('-')
            upper = prefix.upper()
            cond = 'PBS' if 'PBS' in upper else ('PEP15' if 'PEP15' in upper else prefix)
            return cond, int(num)
        except ValueError:  # wrong number of '-' parts, or non-integer suffix
            return label, float('inf')

    # Put every mouse in exactly one bucket so reorder_categories receives each
    # category once (a label mentioning both PBS and PEP15 counts as PBS).
    pbs = [m for m in mice if 'PBS' in str(m).upper()]
    pep15 = [m for m in mice if m not in pbs and 'PEP15' in str(m).upper()]
    others = [m for m in mice if m not in pbs and m not in pep15]

    pbs_sorted = sorted(pbs, key=parse_mouse)      # HTO_PBS-1, HTO_PBS-2, ...
    pep15_sorted = sorted(pep15, key=parse_mouse)  # HTO_PEP15-1, ..., HTO_PEP15-6

    new_order = pbs_sorted + pep15_sorted + others
    adata.obs['mouse'] = adata.obs['mouse'].cat.reorder_categories(new_order, ordered=True)

    # Cells per mouse in the new order; categories with no cells are dropped
    mouse_counts = adata.obs['mouse'].value_counts().reindex(adata.obs['mouse'].cat.categories)
    mouse_counts = mouse_counts[mouse_counts > 0]

    # Group of each observed mouse (first non-null value), used for coloring
    mouse_to_group = adata.obs.groupby('mouse', observed=True)['group'].first()

    # Define publication-quality colors matching group colors
    group_colors = ['#1f77b4', '#ff7f0e']  # Blue and orange
    group_values = adata.obs['group']
    if group_values.dtype.name == 'category':
        unique_groups = group_values.cat.categories.tolist()
    else:
        unique_groups = sorted(group_values.dropna().unique().tolist())
    # Cycle the palette so more than two groups does not raise IndexError
    group_color_map = {
        grp: group_colors[i % len(group_colors)] for i, grp in enumerate(unique_groups)
    }

    # Mice whose group cannot be determined are drawn in grey
    mouse_colors = [
        group_color_map.get(mouse_to_group.get(mouse), '#7f7f7f')
        for mouse in mouse_counts.index
    ]

    fig, ax = plt.subplots(figsize=(10, 6), dpi=400)

    # Create bar plot with group-based colors
    bars = ax.bar(range(len(mouse_counts)), mouse_counts.values, color=mouse_colors)

    ax.set_xticks(range(len(mouse_counts)))
    ax.set_xticklabels(mouse_counts.index, rotation=45, ha='right')

    # Annotate each bar with its count (thousands separator)
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2., height,
                f'{int(height):,}',
                ha='center', va='bottom', fontsize=10, fontweight='bold')

    ax.set_title('Cell Counts by Mouse', fontsize=14, fontweight='bold', pad=15)
    ax.set_xlabel('Mouse', fontsize=12)
    ax.set_ylabel('Cell Count', fontsize=12)

    plt.tight_layout()
    plt.savefig(output_dir / 'png' / 'cell_counts_barplot_mouse.png', dpi=400, bbox_inches='tight')
    plt.savefig(output_dir / 'pdf' / 'cell_counts_barplot_mouse.pdf', dpi=400, bbox_inches='tight', format='pdf')
    print('✓ Saved barplot: cell_counts_barplot_mouse.png/.pdf')
    show_inline_plot()
    plt.close()

print('='*80)

## Cluster Frequency Comparison Between Groups

### Overview
This section calculates and visualizes cluster frequencies between two selected groups within a chosen metadata category. 

**Features**:
- Configurable selection of the metadata category (set in the configuration block; e.g., batch, library, group)
- Selection of two groups to compare
- Statistical testing (Fisher's exact test) for each cluster
- Publication-quality plot with significance markers (*, **, ***)
- Excel export with frequencies and exact p-values

**Statistical Approach**:
- **Fisher's Exact Test**: Tests whether the proportion of cells in each cluster differs between groups
- **Significance Levels**: `* p<0.05`, `** p<0.01`, `*** p<0.001`
- **Multiple Testing**: Consider applying Bonferroni or FDR correction if comparing many clusters

In [None]:
# Ensure inline plotting
# (IPython/Jupyter magic: this line is only valid inside a notebook kernel.)
%matplotlib inline

# Import scipy.stats for statistical tests
# (stats.mannwhitneyu is used for the per-cell-type group comparison below)
import scipy.stats as stats

# ============================================================
# CONFIGURATION
# ============================================================
# Cells are first restricted to adata.obs['batch'] == filter_batch_value; the
# two groups are then selected from adata.obs[comparison_category], and each
# distinct value of adata.obs[replicate_column] is treated as one replicate.
filter_batch_value = 'Mistri'         # Batch to filter to
comparison_category = 'group'         # Category to compare (e.g., 'group')
group_1 = 'PBS-CFA'                   # First group
group_2 = 'PEP15-CFA'                 # Second group
replicate_column = 'mouse'            # Column containing biological replicates

# X-axis label rotation for plots: choose 0, 45, or 90
x_tick_rotation = 45
# Rotated labels are right-aligned so their ends sit under the ticks
x_align = 'center' if x_tick_rotation == 0 else 'right'

# Figure size (in inches) for the summary bar plot
fig_width = 10
fig_height = 10

# Plot colors (optional customization); keys are the two group labels above
bar_colors = {group_1: '#3C5488', group_2: '#E64B35'}

# ============================================================
# FILTER DATA
# ============================================================
cluster_key = 'cell_type'  # <-- use cell_type instead of resolution-based clusters

# Filter to specific batch (adata.obs['batch'] == filter_batch_value)
batch_mask = adata.obs['batch'] == filter_batch_value
adata_filtered = adata[batch_mask].copy()

# Fail fast with a clear message: an empty selection would otherwise flow
# through to all-NA statistics, empty plots and Excel sheets (or a cryptic
# KeyError further down).
if adata_filtered.n_obs == 0:
    raise ValueError(
        f"No cells with adata.obs['batch'] == {filter_batch_value!r}; "
        f"check filter_batch_value"
    )

print("="*80)
print(f"CELL TYPE FREQUENCY COMPARISON WITH BIOLOGICAL REPLICATES")
print("="*80)
print(f"\nBatch: {filter_batch_value}")
print(f"Comparing: {group_1} vs {group_2}")
print(f"Biological replicates: {replicate_column}")
print(f"Grouping: {cluster_key}")

# Boolean masks over adata_filtered.obs selecting each comparison group
mask_g1 = adata_filtered.obs[comparison_category] == group_1
mask_g2 = adata_filtered.obs[comparison_category] == group_2

# Distinct biological replicates (e.g. mice) present in each group
mice_g1 = adata_filtered.obs[mask_g1][replicate_column].unique()
mice_g2 = adata_filtered.obs[mask_g2][replicate_column].unique()

print(f"\n{group_1} mice: {list(mice_g1)} (n={len(mice_g1)})")
print(f"{group_2} mice: {list(mice_g2)} (n={len(mice_g2)})")
print("="*80)

# Both groups must contain cells in this batch for the comparison to mean anything
missing_groups = [
    label for label, mask in ((group_1, mask_g1), (group_2, mask_g2)) if not mask.any()
]
if missing_groups:
    raise ValueError(
        f"No cells for group(s) {missing_groups} in "
        f"adata.obs[{comparison_category!r}] within batch {filter_batch_value!r}"
    )

# ============================================================
# STATISTICAL TEST SELECTION
# ============================================================
# Every mouse contributes one frequency per cell type, so the sample size of
# each group is its number of mice. With such small samples (n < 30 per group)
# the rank-based Mann-Whitney U test is used: it does not assume normality.
# NOTE: the exact two-sided test has a p-value floor with very few mice
# (3 vs 3 mice: smallest possible p = 2 / C(6, 3) = 0.1), so p < 0.05 can be
# unreachable regardless of effect size.
statistical_test = "Mann-Whitney U test (non-parametric)"
n_g1_mice, n_g2_mice = len(mice_g1), len(mice_g2)
print(f"\nStatistical test: {statistical_test}")
print(f"Rationale: Suitable for small sample sizes (n={n_g1_mice} vs n={n_g2_mice}),")
print("           doesn't assume normality, more robust to outliers")

# Ensure the cell-type column is categorical; its categories are the ordered
# list of cell types that will be tested.
if not isinstance(adata_filtered.obs[cluster_key].dtype, pd.CategoricalDtype):
    adata_filtered.obs[cluster_key] = (
        adata_filtered.obs[cluster_key].astype(str).astype('category')
    )
all_clusters = adata_filtered.obs[cluster_key].cat.categories

# ============================================================
# CALCULATE PER-MOUSE FREQUENCIES
# ============================================================

# Column layout of the long-format table (one row per mouse x cell type)
per_mouse_columns = ['Group', 'Mouse', 'Cluster', 'Cell_Count', 'Total_Cells', 'Frequency_%']


def _per_mouse_records(obs, group_label, group_mask, mice):
    """Build per-mouse cell-type count/frequency records for one group.

    Parameters
    ----------
    obs : pandas.DataFrame
        Cell metadata; must contain the ``replicate_column`` and
        ``cluster_key`` columns.
    group_label : str
        Value written to the ``Group`` field of every record.
    group_mask : pandas.Series of bool
        Selects the cells belonging to this group (aligned with ``obs``).
    mice : iterable
        Replicate IDs to report for this group.

    Returns
    -------
    list of dict
        One record per (mouse, cell type in ``all_clusters``), including
        cell types with zero cells in that mouse. ``Frequency_%`` is the
        percentage of the mouse's in-group cells assigned to the cell type.
    """
    records = []
    for mouse in mice:
        mouse_mask = group_mask & (obs[replicate_column] == mouse)
        n_cells_mouse = mouse_mask.sum()
        # Tally every cell type for this mouse in one pass (rather than one
        # full scan of all cells per cell type); reindex so absent types are 0.
        counts = (
            obs.loc[mouse_mask, cluster_key]
            .value_counts()
            .reindex(all_clusters, fill_value=0)
        )
        for cluster, cluster_count in counts.items():
            cluster_freq = (cluster_count / n_cells_mouse * 100) if n_cells_mouse > 0 else 0
            records.append({
                'Group': group_label,
                'Mouse': mouse,
                'Cluster': cluster,
                'Cell_Count': cluster_count,
                'Total_Cells': n_cells_mouse,
                'Frequency_%': cluster_freq
            })
    return records


# Same computation for both groups (previously two copy-pasted loops)
per_mouse_data = (
    _per_mouse_records(adata_filtered.obs, group_1, mask_g1, mice_g1)
    + _per_mouse_records(adata_filtered.obs, group_2, mask_g2, mice_g2)
)

# Explicit columns keep the expected schema even if no records were produced
per_mouse_df = pd.DataFrame(per_mouse_data, columns=per_mouse_columns)

print("\n✓ Calculated per-mouse cell type frequencies")
print(f"  Data points: {len(per_mouse_df)} (cell types × mice)")
print(f"  Mean cells per mouse (Group 1): {per_mouse_df[per_mouse_df['Group']==group_1]['Total_Cells'].mean():.0f}")
print(f"  Mean cells per mouse (Group 2): {per_mouse_df[per_mouse_df['Group']==group_2]['Total_Cells'].mean():.0f}")

# ============================================================
# STATISTICAL TESTING (MANN-WHITNEY U TEST)
# ============================================================
print(f"\nPerforming {statistical_test}...")


def _mean_and_sem(values):
    """Return (mean, SEM) of one group's per-mouse frequencies.

    SEM = sample standard deviation (ddof=1) / sqrt(n). NumPy's ``std``
    defaults to the population SD (ddof=0), which understates the SEM,
    noticeably so with only a handful of mice per group.
    The mean is NaN when there are no values; the SEM is NaN with fewer than
    two values.
    """
    n_values = len(values)
    mean = values.mean() if n_values > 0 else np.nan
    sem = (values.std(ddof=1) / np.sqrt(n_values)) if n_values > 1 else np.nan
    return mean, sem


def _significance_label(p_value):
    """Map a p-value to '***' / '**' / '*' / 'ns', or 'NA' when it is NaN."""
    if np.isnan(p_value):
        return 'NA'
    if p_value < 0.001:
        return '***'
    if p_value < 0.01:
        return '**'
    if p_value < 0.05:
        return '*'
    return 'ns'


stat_results = []

for cluster in all_clusters:
    # Per-mouse frequencies of this cell type in each group (one value per mouse)
    is_type = per_mouse_df['Cluster'] == cluster
    freq_g1 = per_mouse_df.loc[is_type & (per_mouse_df['Group'] == group_1), 'Frequency_%'].values
    freq_g2 = per_mouse_df.loc[is_type & (per_mouse_df['Group'] == group_2), 'Frequency_%'].values

    mean_g1, sem_g1 = _mean_and_sem(freq_g1)
    mean_g2, sem_g2 = _mean_and_sem(freq_g2)

    # Two-sided Mann-Whitney U; requires at least two mice in each group
    statistic, p_val = np.nan, np.nan
    if len(freq_g1) > 1 and len(freq_g2) > 1:
        try:
            statistic, p_val = stats.mannwhitneyu(freq_g1, freq_g2, alternative='two-sided')
        except ValueError as err:
            # Older SciPy versions reject inputs where every value is identical
            # (e.g. a cell type absent from all mice); report it as untestable.
            print(f"  {cluster}: Mann-Whitney U not computable ({err})")

    sig = _significance_label(p_val)

    stat_results.append({
        'Cluster': cluster,
        f'{group_1}_mean_%': mean_g1,
        f'{group_1}_SEM_%': sem_g1,
        f'{group_1}_n_mice': len(freq_g1),
        f'{group_2}_mean_%': mean_g2,
        f'{group_2}_SEM_%': sem_g2,
        f'{group_2}_n_mice': len(freq_g2),
        'p_value': p_val,
        'U_statistic': statistic,
        'significance': sig
    })

    print(f"{cluster}: {group_1}={mean_g1:.2f}±{(0 if np.isnan(sem_g1) else sem_g1):.2f}%, "
          f"{group_2}={mean_g2:.2f}±{(0 if np.isnan(sem_g2) else sem_g2):.2f}%, p={p_val:.4e}, {sig}")

stat_df = pd.DataFrame(stat_results)

# ============================================================
# EXPORT TO EXCEL
# ============================================================
timestamp = pd.Timestamp.now().strftime('%Y%m%d')
excel_file = f'celltype_freq_by_mouse_{filter_batch_value}_{group_1}_vs_{group_2}_{timestamp}.xlsx'
excel_path = output_dir / excel_file


def _pivot_by_mouse(value_column):
    """Wide table of `value_column`: one row per cell type, one column per (group, mouse)."""
    return per_mouse_df.pivot_table(
        index='Cluster',
        columns=['Group', 'Mouse'],
        values=value_column,
        fill_value=0
    )


with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
    # Sheet 1: long format, one row per mouse x cell type
    per_mouse_df.to_excel(writer, sheet_name='Per_Mouse_Data', index=False)

    # Sheet 2: per-cell-type group means/SEMs and Mann-Whitney results
    stat_df.to_excel(writer, sheet_name='Statistical_Comparison', index=False)

    # Sheets 3-4: the per-mouse counts and frequencies in wide format
    _pivot_by_mouse('Cell_Count').to_excel(writer, sheet_name='Cell_Counts_by_Mouse')
    _pivot_by_mouse('Frequency_%').to_excel(writer, sheet_name='Frequencies_by_Mouse')

    # Sheet 5: run parameters and headline result
    run_info = [
        ('Batch', filter_batch_value),
        ('Group 1', group_1),
        ('Group 2', group_2),
        ('Replicate Column', replicate_column),
        (f'{group_1}_n_mice', len(mice_g1)),
        (f'{group_2}_n_mice', len(mice_g2)),
        ('Grouping', cluster_key),
        ('Statistical_Test', statistical_test),
        ('Significant_clusters', sum(stat_df['p_value'] < 0.05)),
        ('Date', pd.Timestamp.now().strftime('%Y-%m-%d %H:%M')),
    ]
    pd.DataFrame({
        'Parameter': [name for name, _ in run_info],
        'Value': [value for _, value in run_info],
    }).to_excel(writer, sheet_name='Summary', index=False)

print(f'\n✓ Saved: {excel_file}')

# ============================================================
# PLOT
# ============================================================
configure_plot_style()

n_clusters = len(all_clusters)
fig, ax = plt.subplots(figsize=(fig_width, fig_height), dpi=400)

# One slot per cell type holding two side-by-side bars
x = np.arange(n_clusters)
width = 0.35

# Per-cell-type group means and SEMs (NaN where not computable)
mean_g1, sem_g1 = stat_df[f'{group_1}_mean_%'].values, stat_df[f'{group_1}_SEM_%'].values
mean_g2, sem_g2 = stat_df[f'{group_2}_mean_%'].values, stat_df[f'{group_2}_SEM_%'].values

# Group 1 left of each tick, group 2 right of it; error bars show SEM
bar_specs = (
    (group_1, x - width/2, mean_g1, sem_g1),
    (group_2, x + width/2, mean_g2, sem_g2),
)
for label, positions, heights, errors in bar_specs:
    ax.bar(positions, heights, width, label=label, color=bar_colors[label],
           alpha=0.85, edgecolor='black', linewidth=0.8, yerr=errors, capsize=3)


def _bar_top(mean, sem):
    """Top of a bar including its error bar, treating NaN parts as 0."""
    return (0 if np.isnan(mean) else mean) + (0 if np.isnan(sem) else sem)


# Bracket plus asterisks above every significantly different cell type
for i, sig in enumerate(stat_df['significance']):
    if sig in ['ns', 'NA']:
        continue
    y_pos = max(_bar_top(mean_g1[i], sem_g1[i]), _bar_top(mean_g2[i], sem_g2[i])) + 1.0
    ax.plot([i - width/2, i + width/2], [y_pos, y_pos], lw=1.2, c='black')
    ax.text(i, y_pos + 0.3, sig, ha='center', va='bottom',
            fontsize=11, fontweight='bold')

ax.set_xlabel('Cell Type', fontsize=14, fontweight='bold')
ax.set_ylabel('Frequency (% ± SEM)', fontsize=14, fontweight='bold')
ax.set_title(f'{group_1} vs {group_2} ({filter_batch_value} batch)\n'
             f'Biological replicates: {len(mice_g1)} vs {len(mice_g2)} mice',
             fontsize=16, fontweight='bold', pad=20)
ax.set_xticks(x)
ax.set_xticklabels(all_clusters, rotation=x_tick_rotation, ha=x_align)
ax.legend(title='Group', frameon=True, edgecolor='black')
ax.yaxis.grid(True, linestyle='--', alpha=0.3)
ax.set_axisbelow(True)
sns.despine()

plt.tight_layout()
plot_file = f'celltype_freq_by_mouse_{filter_batch_value}_{group_1}_vs_{group_2}'
save_plot(plot_file, close=False)
show_inline_plot()
plt.close()

# Summary: list the cell types with raw p < 0.05
significant_mask = stat_df['p_value'] < 0.05
n_sig = significant_mask.sum()
print(f'\n✓ {n_sig}/{n_clusters} cell types significantly different (p<0.05)')


def _nan_to_zero(value):
    """Show a missing SEM as 0 in the printed summary."""
    return 0 if np.isnan(value) else value


if n_sig > 0:
    print('\nSignificant cell types:')
    for _, row in stat_df[significant_mask].iterrows():
        sem_1 = _nan_to_zero(row[f'{group_1}_SEM_%'])
        sem_2 = _nan_to_zero(row[f'{group_2}_SEM_%'])
        print(f"  {row['Cluster']}: "
              f"{group_1}={row[f'{group_1}_mean_%']:.2f}±{sem_1:.2f}%, "
              f"{group_2}={row[f'{group_2}_mean_%']:.2f}±{sem_2:.2f}%, "
              f"p={row['p_value']:.4e} {row['significance']}")

print("\n" + "="*80)

## Marker Gene Identification

### Find Differentially Expressed Genes for Each Cluster

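This section identifies marker genes (differentially expressed genes) for each cell type (`cell_type` grouping) using scanpy's `rank_genes_groups` function. Each cell type is compared against all other cells, analogous to Seurat's `FindAllMarkers`. Seurat-style `pct.1`/`pct.2` columns (fraction of cells expressing the gene inside/outside the cell type) are computed separately and added to the results.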

**Methods available:**
- **Wilcoxon rank-sum test** (non-parametric, robust for single-cell data) - Recommended
- **t-test** (parametric, assumes normal distribution)
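- **Logreg** (logistic regression, multivariate; produces scores only, without the p-values and fold changes that the downstream filters in this section rely on)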

**Output:**
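- Excel file with multiple sheets:
  - Top N marker genes per cell type
  - All significant markers (after fold-change, adjusted p-value, and pct filters)
  - Full, unfiltered gene list
  - Per-cell-type summary statistics (log2 fold-change, pct.1/pct.2, adjusted p-values)
  - Analysis parameters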

In [None]:
# ============================================================
# MARKER GENE IDENTIFICATION - CELL_TYPE (with pct.1/pct.2)
# ============================================================

import pandas as pd
import numpy as np

# ============================================================
# CONFIGURATION
# ============================================================
cluster_key = 'cell_type'               # Use cell_type grouping
n_top_genes = 100                       # Number of top genes to export per cell_type
method = 'wilcoxon'                     # Method: 'wilcoxon', 't-test', or 'logreg'
                                        # NOTE(review): 'logreg' presumably yields scores only;
                                        # the filtering/sorting below needs p_val_adj — confirm
filter_by_fc = True                     # Filter by fold-change
min_fold_change = 0.2                   # Minimum log fold-change (log scale)
max_pval = 0.05                         # Maximum adjusted p-value

# Additional Seurat-style filters (optional)
min_pct = 0.1                           # Minimum pct.1 (expression in cell_type)
min_diff_pct = 0.1                      # Minimum difference between pct.1 and pct.2

# Expression threshold for calculating percentages
expr_threshold = 0.0                    # Count cells with expression > this value

print("="*80)
print(f"MARKER GENE IDENTIFICATION (Seurat-style)")
print("="*80)
print(f"Grouping: {cluster_key}")
print(f"Method: {method}")
print(f"Expression threshold for pct calculations: > {expr_threshold}")
print(f"\nFilters:")
# Only advertise the fold-change filter when it will actually be applied
if filter_by_fc:
    print(f"  - avg_log2FC > {min_fold_change}")
else:
    print("  - avg_log2FC: not filtered (filter_by_fc=False)")
print(f"  - p_val_adj < {max_pval}")
print(f"  - pct.1 > {min_pct}")
print(f"  - (pct.1 - pct.2) > {min_diff_pct}")
print("="*80)

# The grouping column must be present in the cell metadata
if cluster_key not in adata.obs:
    raise ValueError(f"Grouping '{cluster_key}' not found in adata.obs")

# Grouping must be categorical; non-categorical labels are coerced through
# str so every label becomes a string category
if not isinstance(adata.obs[cluster_key].dtype, pd.CategoricalDtype):
    adata.obs[cluster_key] = adata.obs[cluster_key].astype(str).astype('category')

# ============================================================
# RUN DIFFERENTIAL EXPRESSION ANALYSIS
# ============================================================
print(f"\nRunning {method} test for all cell types...")

# Results are stored under key_added and read back below with
# sc.get.rank_genes_groups_df(key=...).
# NOTE(review): use_raw=False with no `layer` argument means the test runs on
# adata.X, while the pct.1/pct.2 step prefers the 'normalized' layer when it
# exists — confirm adata.X holds the log-normalized values scanpy expects.
rank_genes_kwargs = {
    'groupby': cluster_key,
    'method': method,
    'use_raw': False,
    'n_genes': None,
    'key_added': f'rank_genes_{cluster_key}',
    'tie_correct': True,
}
sc.tl.rank_genes_groups(adata, **rank_genes_kwargs)

print(f"✓ Differential expression analysis complete")

# ============================================================
# MANUALLY CALCULATE PCT.1 AND PCT.2
# ============================================================
print("\nCalculating pct.1 and pct.2 for all genes...")

# Get expression matrix: the 'normalized' layer when present, otherwise adata.X
use_layer = 'normalized' if 'normalized' in adata.layers else None
if use_layer:
    expr_matrix = adata.layers[use_layer]
else:
    expr_matrix = adata.X

# Binarize once: True where a cell (row) expresses a gene (column) above
# expr_threshold. Sparse input stays sparse (CSR for fast row selection)
# instead of densifying the whole cells x genes matrix, which can need many GB.
if scipy.sparse.issparse(expr_matrix):
    is_expressed = expr_matrix.tocsr() > expr_threshold
else:
    is_expressed = np.asarray(expr_matrix) > expr_threshold


def _column_sums(bool_matrix):
    """Per-gene number of True entries as a flat 1-D integer array."""
    # Sparse (and np.matrix) reductions return a 1 x n_genes matrix; flatten it
    return np.asarray(bool_matrix.sum(axis=0)).ravel()


def _fraction(counts, n_cells):
    """counts / n_cells, or all-NaN when the cell set is empty (n_cells == 0)."""
    if n_cells == 0:
        return np.full(counts.shape, np.nan)
    return counts / n_cells


n_cells_total = is_expressed.shape[0]
expressed_total = _column_sums(is_expressed)

# Get cell types
clusters = adata.obs[cluster_key].cat.categories
cluster_labels = adata.obs[cluster_key].values

# pct_dict[cell_type] = {'pct.1': {gene: frac}, 'pct.2': {gene: frac}}
pct_dict = {}

for cluster in clusters:
    print(f"  Calculating percentages for {cluster}...")

    # Cells in this cell type vs. all other cells
    in_cluster = np.asarray(cluster_labels == cluster)
    n_in = int(in_cluster.sum())

    # Out-group counts are total minus in-group counts, so only the in-group
    # rows are sliced per cell type (no copy of the complement per iteration)
    expressed_in = _column_sums(is_expressed[in_cluster])
    pct_in = _fraction(expressed_in, n_in)
    pct_out = _fraction(expressed_total - expressed_in, n_cells_total - n_in)

    pct_dict[cluster] = {
        'pct.1': dict(zip(adata.var_names, pct_in)),
        'pct.2': dict(zip(adata.var_names, pct_out))
    }

print("✓ Percentage calculations complete")

# ============================================================
# EXTRACT RESULTS WITH PCT.1 AND PCT.2
# ============================================================
print("\nExtracting marker genes with Seurat-style columns...")

result_key = f'rank_genes_{cluster_key}'

# scanpy column name -> Seurat FindAllMarkers-style column name
column_mapping = {
    'names': 'gene',
    'scores': 'score',
    'logfoldchanges': 'avg_log2FC',
    'pvals': 'p_val',
    'pvals_adj': 'p_val_adj'
}


def _markers_for(cell_type):
    """Ranked-gene table for one cell type with Seurat-style columns and pct.1/pct.2."""
    table = sc.get.rank_genes_groups_df(adata, group=cell_type, key=result_key)
    table['cluster'] = cell_type
    table = table.rename(columns=column_mapping)
    fractions = pct_dict[cell_type]
    table['pct.1'] = table['gene'].map(fractions['pct.1'])
    table['pct.2'] = table['gene'].map(fractions['pct.2'])
    return table


all_markers = [_markers_for(cell_type) for cell_type in clusters]

# Stack all cell types into one long table
all_markers_df = pd.concat(all_markers, ignore_index=True)

# Seurat-style specificity: expressing fraction inside minus outside the cell type
all_markers_df['pct_diff'] = all_markers_df['pct.1'] - all_markers_df['pct.2']

# Fixed, readable column order; columns scanpy did not produce are skipped
column_order = [
    col
    for col in ['cluster', 'gene', 'avg_log2FC', 'pct.1', 'pct.2', 'pct_diff',
                'p_val', 'p_val_adj', 'score']
    if col in all_markers_df.columns
]
all_markers_df = all_markers_df[column_order]

print(f"✓ Extracted {len(all_markers_df):,} gene-cluster comparisons")
print(f"  Columns: {list(all_markers_df.columns)}")

# Quick visual check that pct.1 and pct.2 were filled in
print("\nExample rows (first cell type, top 3 genes):")
print(all_markers_df.head(3).to_string(index=False))

# ============================================================
# FILTER MARKERS (SEURAT-STYLE)
# ============================================================
print("\n" + "="*80)
print("APPLYING FILTERS")
print("="*80)

filtered_markers = all_markers_df.copy()
n_before = len(filtered_markers)


def _keep_rows(df, keep, label):
    """Return df[keep] and report how many rows survive the named filter."""
    kept = df[keep]
    print(f"  After {label}: {len(kept):,} genes")
    return kept


# Threshold filters that depend on columns the DE method may not provide
if filter_by_fc and 'avg_log2FC' in filtered_markers.columns:
    filtered_markers = _keep_rows(
        filtered_markers, filtered_markers['avg_log2FC'] > min_fold_change,
        f"avg_log2FC > {min_fold_change}")
if 'p_val_adj' in filtered_markers.columns:
    filtered_markers = _keep_rows(
        filtered_markers, filtered_markers['p_val_adj'] < max_pval,
        f"p_val_adj < {max_pval}")

# Expression-percentage filters (always applied)
filtered_markers = _keep_rows(
    filtered_markers, filtered_markers['pct.1'] > min_pct, f"pct.1 > {min_pct}")
filtered_markers = _keep_rows(
    filtered_markers, filtered_markers['pct_diff'] > min_diff_pct, f"pct_diff > {min_diff_pct}")

# Most significant first within each cell type.
# NOTE(review): assumes a p_val_adj column exists (presumably absent for method='logreg').
filtered_markers = filtered_markers.sort_values(['cluster', 'p_val_adj'])

print(f"\n✓ Total filtered: {n_before:,} → {len(filtered_markers):,} significant markers")

# First n_top_genes rows of each cell type after the sort above
top_markers = filtered_markers.groupby('cluster').head(n_top_genes).reset_index(drop=True)
print(f"✓ Top {n_top_genes} genes per cell type: {len(top_markers):,} total")

# ============================================================
# SUMMARY STATISTICS
# ============================================================
print("\n" + "="*80)
print("MARKER GENES PER CELL TYPE")
print("="*80)

# Number of filtered markers per cell type; types with none report 0
markers_per_type = filtered_markers['cluster'].value_counts()
for cluster in clusters:
    print(f"  {cluster}: {markers_per_type.get(cluster, 0):,} significant markers")

# ============================================================
# EXPORT TO EXCEL
# ============================================================
timestamp = pd.Timestamp.now().strftime('%Y%m%d')
excel_file = f'marker_genes_{cluster_key}_{timestamp}.xlsx'
excel_path = output_dir / excel_file

print(f"\nExporting to Excel: {excel_file}")


def _marker_summary(cell_type):
    """One summary row describing the filtered markers of a single cell type."""
    subset = filtered_markers[filtered_markers['cluster'] == cell_type]
    return {
        'cell_type': cell_type,
        'n_markers': len(subset),
        'mean_log2FC': subset['avg_log2FC'].mean(),
        'max_log2FC': subset['avg_log2FC'].max(),
        'mean_pct.1': subset['pct.1'].mean(),
        'mean_pct.2': subset['pct.2'].mean(),
        'mean_pct_diff': subset['pct_diff'].mean(),
        'min_p_val_adj': subset['p_val_adj'].min()
    }


with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
    # Sheet 1: top N filtered markers per cell type
    top_markers.to_excel(writer, sheet_name=f'Top_{n_top_genes}_per_cell_type', index=False)

    # Sheet 2: every marker that passed the filters
    filtered_markers.to_excel(writer, sheet_name='All_Significant_Markers', index=False)

    # Sheet 3: complete, unfiltered results
    all_markers_df.to_excel(writer, sheet_name='All_Genes_Unfiltered', index=False)

    # Sheet 4: one summary row per cell type
    summary_df = pd.DataFrame([_marker_summary(cell_type) for cell_type in clusters])
    summary_df.to_excel(writer, sheet_name='Summary', index=False)

    # Sheet 5: analysis parameters
    params_df = pd.DataFrame({
        'Parameter': ['Grouping', 'Method', 'Min_avg_log2FC',
                      'Max_p_val_adj', 'Min_pct.1', 'Min_pct_diff',
                      'Expr_threshold', 'Top_N_Genes', 'Total_CellTypes',
                      'Total_Significant_Markers', 'Date'],
        'Value': [cluster_key, method, min_fold_change,
                  max_pval, min_pct, min_diff_pct, expr_threshold,
                  n_top_genes, len(clusters),
                  len(filtered_markers), pd.Timestamp.now().strftime('%Y-%m-%d %H:%M')]
    })
    params_df.to_excel(writer, sheet_name='Parameters', index=False)

print(f"✓ Saved: {excel_file}")

# ============================================================
# DISPLAY TOP MARKERS FOR EACH CELL TYPE
# ============================================================
print("\n" + "="*80)
print("TOP 10 MARKER GENES PER CELL TYPE (Seurat-style)")
print("="*80)

for cluster in clusters:
    cluster_top = top_markers[top_markers['cluster'] == cluster].head(10)
    # Cell types without any filtered marker are skipped entirely
    if cluster_top.empty:
        continue
    print(f"\n{cluster}:")
    for _, row in cluster_top.iterrows():
        print(f"  {row['gene']:15s} | log2FC: {row['avg_log2FC']:5.2f} | "
              f"pct.1: {row['pct.1']:.1%} | pct.2: {row['pct.2']:.1%} | "
              f"p_adj: {row['p_val_adj']:.2e}")

print("\n" + "="*80)
print("✓ Marker gene identification complete!")
print("="*80)

## Session Information

Document the computational environment and package versions used in this analysis for reproducibility.

**Key Information:**
- System and Python version
- Package versions for single-cell analysis
- Current dataset dimensions
- Analysis timestamp

In [None]:

# ============================================================
# SESSION INFORMATION
# ============================================================

import sys
import platform
import datetime

print('='*60)
print('SESSION INFORMATION')
print('='*60)

# Operating system and interpreter
print(f'\nSystem: {platform.system()} {platform.release()}')
print(f'Python: {sys.version.split()[0]}')

# Versions of key packages; packages that cannot be imported are skipped
print('\nPackages:')
packages = ['anndata', 'scanpy', 'numpy', 'pandas', 'scipy',
            'matplotlib', 'seaborn', 'scvi', 'igraph', 'leidenalg']

for pkg in packages:
    try:
        module = __import__(pkg)
        version = getattr(module, '__version__', 'unknown')
        print(f'  {pkg:12s} {version}')
    except ImportError:
        pass

# Dataset dimensions (only when an AnnData named `adata` is loaded)
if 'adata' in globals():
    print(f'\nDataset: {adata.n_obs:,} cells × {adata.n_vars:,} genes')

# Date
print(f'\nDate: {datetime.datetime.now().strftime("%Y-%m-%d %H:%M")}')
print('='*60)