# Integrated scRNA-seq Clustering and Visualization Analysis

## Overview
This notebook clusters and visualizes integrated single-cell RNA-seq data, starting from the batch-corrected AnnData object produced by the integration step.

The analysis pipeline includes:

- **Data Loading**: Load integrated AnnData object (with batch/library correction and cell cycle regression)
- **Clustering**: Leiden clustering at multiple resolutions using integrated latent representations (X_scVI)
- **Visualization**: UMAP projections colored by clusters, metadata, and feature expression
- **Quality Assessment**: Expression patterns across clusters and experimental groups
- **Marker Gene Expression**: Comprehensive marker gene analysis for cell type identification

The workflow is designed for flexibility and reusability. All parameters can be configured at the beginning of each section.

**Input**: `inputs/**.h5ad` (from integration step)

---

## Key Outputs

### Clustering Results
- Leiden clusters at multiple resolutions stored in `adata.obs` (e.g., `leiden_scVI_r0.4`, `leiden_scVI_r1.6`, etc.)
- UMAP coordinates for visualization
- Cell type identification based on marker gene expression

### Visualizations
- Multi-resolution clustering UMAPs
- Metadata-stratified UMAPs (by batch, library, mouse)
- Gene expression UMAPs and dotplots (normalized counts)
- Marker gene expression for: Macrophages, Microglia, DCs, B cells, T cells, Neurons, NK cells

### Processed Data
- Clustered AnnData objects saved as `.h5ad` files
- All plots exported as PNG (300 DPI) and PDF (vector text, axes and legends; UMAP scatter points are rasterized at 300 DPI) for publication and editing in Affinity Designer

---

## 1. Setup and Configuration

### Library Imports and Directory Structure
Import required libraries for single-cell analysis and configure input/output directories. The notebook assumes the following structure:
- `inputs/`: Contains processed single-cell data (`.h5ad` files)
- `outputs/`: Automatically created for results
  - `outputs/png/`: High-resolution PNG plots (300 DPI)
  - `outputs/pdf/`: PDF plots with vector text and axes for publication and editing (scatter points are rasterized)
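A PDF saved by matplotlib can mix vector and raster content. The minimal sketch below uses synthetic points (it is not scanpy's own code) and marks only the scatter collection as rasterized, so the axes and text stay vector while the `dpi` passed to `savefig` sets the resolution of the point layer. scanpy's scatter-based plots rasterize points in the same way when `vector_friendly=True`, which is the default in `sc.set_figure_params`.

```python
import tempfile
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
xy = rng.normal(size=(5000, 2))  # synthetic stand-in for UMAP coordinates

fig, ax = plt.subplots(figsize=(4, 4))
# Only the scatter collection is rasterized; axes, ticks and labels stay vector
ax.scatter(xy[:, 0], xy[:, 1], s=2, rasterized=True)
ax.set_xlabel("UMAP-1")
ax.set_ylabel("UMAP-2")

out = Path(tempfile.mkdtemp()) / "example.pdf"
fig.savefig(out, dpi=300, bbox_inches="tight")  # dpi applies to the rasterized layer
plt.close(fig)
print(out, out.stat().st_size, "bytes")
```

Rasterizing only the dense point cloud keeps file sizes manageable for tens of thousands of cells while leaving labels editable.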

In [None]:
# Core libraries
import numpy as np
import pandas as pd
from pathlib import Path
from IPython.display import display, HTML
import sys, pkg_resources, datetime

# Single-cell analysis
import anndata as ad
import scanpy as sc
import scipy
import scipy.sparse
import scipy.sparse.csgraph  # connected_components, used in the n_neighbors diagnostic

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# ============================================================
# DIRECTORY CONFIGURATION
# ============================================================
input_dir = Path('inputs')
output_dir = Path('outputs')
(output_dir / 'png').mkdir(parents=True, exist_ok=True)
(output_dir / 'pdf').mkdir(parents=True, exist_ok=True)  # PDF for vector editing

# ============================================================
# PLOTTING CONFIGURATION
# ============================================================
sc.set_figure_params(dpi=300, figsize=(6, 4))

def configure_plot_style():
    """Apply publication-quality plot styling with major/minor ticks."""
    # Font and text settings
    plt.rcParams['figure.dpi'] = 100
    plt.rcParams['font.family'] = 'Arial'
    plt.rcParams['font.size'] = 10
    
    # Line widths
    plt.rcParams['axes.linewidth'] = 0.8
    plt.rcParams['xtick.major.width'] = 0.8
    plt.rcParams['ytick.major.width'] = 0.8
    plt.rcParams['xtick.minor.width'] = 0.6
    plt.rcParams['ytick.minor.width'] = 0.6
    
    # Tick sizes
    plt.rcParams['xtick.major.size'] = 4
    plt.rcParams['ytick.major.size'] = 4
    plt.rcParams['xtick.minor.size'] = 2
    plt.rcParams['ytick.minor.size'] = 2
    
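    # Keep text editable in vector exports. Note: these settings do NOT disable
    # rasterization; scanpy rasterizes scatter points by default
    # (sc.set_figure_params(vector_friendly=True)).
    plt.rcParams['pdf.fonttype'] = 42  # Embed TrueType (Type 42) fonts so text stays editable
    plt.rcParams['svg.fonttype'] = 'none'  # Keep text as text in SVG exports
    plt.rcParams['path.simplify'] = False  # Don't simplify paths
    plt.rcParams['path.simplify_threshold'] = 0  # No simplification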
    
    sns.set_style('white', {'axes.spines.left': True, 'axes.spines.bottom': True,
                             'axes.spines.top': False, 'axes.spines.right': False})

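def show_inline_plot(fig=None):
    """Display a figure centered in the Jupyter output area."""
    import base64, io
    fig = plt.gcf() if fig is None else fig
    fig.tight_layout()
    # Render to PNG and wrap it in a single HTML element: separate display()
    # calls for an opening and closing <div> do not nest around the image
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    img = base64.b64encode(buf.getvalue()).decode('ascii')
    display(HTML('<div style="display: flex; justify-content: center;">'
                 f'<img src="data:image/png;base64,{img}"/></div>'))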

def save_plot(name, close=True, high_dpi=300):
    """
    Save current plot in both PNG (300 DPI) and PDF (vector axes, high-DPI raster points).
    
    Parameters
    ----------
    name : str
        Base filename without extension
    close : bool, default=True
        Whether to close the figure after saving
    high_dpi : int, default=300
        DPI for rasterized elements (300 recommended for publication quality, 
        600 for very high zoom)
    """
    plt.tight_layout()
    
    # Save PNG (raster, 300 DPI) - PNG is always raster
    png_path = output_dir / 'png' / f'{name}.png'
    plt.savefig(png_path, dpi=300, bbox_inches='tight')
    print(f'✓ Saved: {name}.png (300 DPI)')
    
    # Save PDF with HIGH DPI for rasterized points
    # This makes scatter points crisp even at high zoom
    # Axes, text, and legends remain vector
    pdf_path = output_dir / 'pdf' / f'{name}.pdf'
    plt.savefig(pdf_path, bbox_inches='tight', format='pdf', 
                facecolor='white', dpi=high_dpi)
    print(f'✓ Saved: {name}.pdf (vector axes/text, raster points at {high_dpi} DPI)')
    
    if close:
        plt.close()

# ============================================================
# CATEGORICAL CLUSTER PALETTE ("dittoSeq" names kept for compatibility)
# ============================================================
# Note: these hex codes differ from dittoSeq's default dittoColors().
# Subset used when there are at most 15 categories
dittoseq_15_best = [
    "#E5D2DD", "#53A85F", "#F1BB72", "#F3B1A0", "#D6E7A3", "#57C3F3", "#476D87",
    "#E95C59", "#E59CC4", "#AB3282", "#23452F", "#BD956A", "#8C549C", "#585658", "#161616"
]

# Full palette (42 colors; the last two are light grays)
dittoseq_full = [
    "#E5D2DD", "#53A85F", "#F1BB72", "#F3B1A0", "#D6E7A3", "#57C3F3", "#476D87",
    "#E95C59", "#E59CC4", "#AB3282", "#23452F", "#BD956A", "#8C549C", "#585658",
    "#9FA3A8", "#E0D4CA", "#5F3D69", "#C5DEBA", "#58A4C3", "#E4C755", "#F7F398",
    "#AA9A59", "#E63863", "#E39A35", "#C1E6F3", "#6778CA", "#91D0BE", "#B53E2B",
    "#712820", "#DCC1DD", "#CCE0F5", "#CCC9E6", "#625D9F", "#68A180", "#3A6963",
    "#968175", "#161616", "#FF9999", "#344F31", "#00BB61", "#E8E8E8", "#BFBFBF"
]

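def get_dittoseq_colors(n):
    """
    Return a list of n hex colors for n categories.
    Uses the 15-color subset for n <= 15, otherwise the full palette;
    colors repeat (cycle) if n exceeds the full palette length.
    """
    if n <= len(dittoseq_15_best):
        return dittoseq_15_best[:n]
    # Repeat the full palette as many times as needed, then trim to n
    n_repeats = n // len(dittoseq_full) + 1
    return (dittoseq_full * n_repeats)[:n]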

## 2. Load and Inspect Data

### Load Integrated AnnData Object
Load the integrated AnnData file from the previous integration step. This section:
- Loads the AnnData object containing integrated RNA expression data
- Lists the available data layers (e.g., raw counts, normalized data)
- Checks for the integration latent representation (X_scVI)
- **Removes all previous clustering results** (leiden columns, UMAP, color palettes)
- Converts metadata variables (batch, library, mouse, group) to categorical dtype

This ensures a clean slate for fresh clustering with your chosen parameters.

**Configuration**:
- Set `data_filename` to the `integrated_*.h5ad` file written by the integration step

In [None]:
# ============================================================
# USER CONFIGURATION
# ============================================================
# Integrated AnnData from the integration step (e.g., integrated_batch_library_cc.h5ad),
# containing batch/library-corrected data with scVI latent representations
data_filename = 'integrated_library.h5ad'  # Update to your file
latent_key = 'X_scVI'  # Integration method latent key

# Metadata columns to verify and convert to categorical
categorical_vars = ['batch', 'library', 'mouse', 'group']

# ============================================================
# LOAD DATA
# ============================================================
data_file = input_dir / data_filename
if not data_file.exists():
    raise FileNotFoundError(f'Input file not found: {data_file}')

# Load AnnData object
adata = sc.read_h5ad(data_file)
print(f'✓ Loaded AnnData from: {data_file}')
print(f'  Data shape: {adata.shape}')

# Make cell names unique if needed
if not adata.obs_names.is_unique:
    adata.obs_names_make_unique()
    print('✓ Made cell identifiers unique')

# ============================================================
# VERIFY DATA STRUCTURE
# ============================================================
print('\n' + '='*60)
print('DATA STRUCTURE')
print('='*60)
print(f'Data shape: {adata.n_obs:,} cells × {adata.n_vars:,} genes')
print(f'\nLayers: {list(adata.layers.keys())}')
print(f'Observations (metadata): {list(adata.obs.columns[:10])}...')

# Check latent representation
print('\n' + '='*60)
print('LATENT REPRESENTATIONS')
print('='*60)
if latent_key in adata.obsm:
    print(f'✓ {latent_key}: {adata.obsm[latent_key].shape}')
else:
    print(f'⚠ {latent_key} not found in adata.obsm')
    print(f'  Available keys: {list(adata.obsm.keys())}')
    # If X_scVI not found, we may need to run scVI integration first
    if len(adata.obsm.keys()) == 0:
        print('  No embeddings found - you may need to run integration first')

# ============================================================
# CLEAN PREVIOUS CLUSTERING RESULTS
# ============================================================
print('\n' + '='*60)
print('REMOVING PREVIOUS CLUSTERING')
print('='*60)

# Remove old clustering columns
clustering_cols = [col for col in adata.obs.columns if col.startswith('leiden')]
if clustering_cols:
    print(f'Removing {len(clustering_cols)} previous clustering columns:')
    for col in clustering_cols:
        print(f'  • {col}')
        del adata.obs[col]
    print(f'✓ Removed {len(clustering_cols)} clustering column(s)')
else:
    print('✓ No previous clustering columns found')

# Remove old UMAP if exists (will compute fresh)
if 'X_umap' in adata.obsm:
    print('Removing previous UMAP coordinates')
    del adata.obsm['X_umap']
    print('✓ Removed X_umap')
else:
    print('✓ No previous UMAP found')

# Clear any clustering-related color info in uns
color_keys = [key for key in adata.uns.keys() if '_colors' in key and 'leiden' in key]
if color_keys:
    for key in color_keys:
        del adata.uns[key]
    print(f'✓ Removed {len(color_keys)} color palette(s)')
else:
    print('✓ No clustering color palettes found')

# ============================================================
# PROCESS CATEGORICAL VARIABLES
# ============================================================
print('\n' + '='*60)
print('CATEGORICAL METADATA')
print('='*60)

# Columns whose categories and cell counts are listed below
# (columns missing from adata.obs are reported as not found)
vars_to_display = ['group', 'sex', 'library', 'batch']

for col in categorical_vars:
    if col in adata.obs:
        if adata.obs[col].dtype.name != 'category':
            adata.obs[col] = adata.obs[col].astype(str).astype('category')
        n_categories = len(adata.obs[col].cat.categories)
        print(f'✓ {col}: {n_categories} categories')
    else:
        print(f'⚠ {col}: Not found in data')

# Also check 'sex' if not in categorical_vars list
if 'sex' not in categorical_vars and 'sex' in adata.obs:
    if adata.obs['sex'].dtype.name != 'category':
        adata.obs['sex'] = adata.obs['sex'].astype(str).astype('category')

# Display unique values for specific variables
print('\n' + '='*60)
print('UNIQUE VALUES IN KEY METADATA COLUMNS')
print('='*60)

for col in vars_to_display:
    if col in adata.obs:
        # Ensure it's categorical
        if adata.obs[col].dtype.name != 'category':
            adata.obs[col] = adata.obs[col].astype(str).astype('category')
        
        unique_vals = adata.obs[col].cat.categories.tolist()
        n_cells_per_val = adata.obs[col].value_counts().to_dict()
        
        print(f'\n{col.upper()}:')
        for val in unique_vals:
            n_cells = n_cells_per_val.get(val, 0)
            print(f'  • {val}: {n_cells:,} cells')
    else:
        print(f'\n{col.upper()}: Not found in data')

print('\n' + '='*60)

### Normalize Gene Expression Data

Normalize raw count data for visualization and marker gene analysis. The integrated data (X_scVI) is used for clustering, but normalized expression values are needed for gene expression plots.

**Steps**:
1. Check if data is already normalized
2. Normalize counts (library size normalization to 10,000 counts per cell)
3. Log-transform (log1p)
4. Store in `adata.layers['normalized']` for downstream visualization (UMAP, dotplot)
5. Also store in `adata.X` for compatibility with other analyses
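For reference, `sc.pp.normalize_total(target_sum=1e4)` followed by `sc.pp.log1p` computes, for each cell, log(1 + 1e4 · x / Σx), using the natural log. The small NumPy sketch below reproduces that arithmetic on a dense toy matrix. scanpy itself also handles sparse input and extra options, and a cell with zero total counts would divide by zero here.

```python
import numpy as np

def normalize_log1p(counts: np.ndarray, target_sum: float = 1e4) -> np.ndarray:
    """Library-size normalize each row (cell) to target_sum, then apply log1p."""
    counts = np.asarray(counts, dtype=float)
    lib_size = counts.sum(axis=1, keepdims=True)  # total counts per cell
    scaled = counts / lib_size * target_sum       # each row now sums to target_sum
    return np.log1p(scaled)                       # natural log(1 + x)

toy = np.array([[1, 3, 0, 6],
                [10, 0, 10, 0]])
norm = normalize_log1p(toy)
print(np.round(np.expm1(norm).sum(axis=1)))  # [10000. 10000.]
```

Because log1p(0) = 0, zero counts stay zero, which keeps sparse matrices sparse after the transform.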

In [None]:
# ============================================================
# DATA NORMALIZATION
# ============================================================
print('='*60)
print('GENE EXPRESSION NORMALIZATION')
print('='*60)

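# Heuristic check on the value range: log1p-normalized data rarely exceeds ~20,
# while raw UMI counts usually exceed 50. Normalized but un-logged data can also
# exceed 50 and would then be normalized twice, so confirm the branch taken below.
max_val = adata.X.max() if hasattr(adata.X, 'max') else adata.X.data.max()
print(f'\nCurrent X matrix max value: {max_val:.2f}')

if max_val > 50:
    print('⚠️  Data appears to be raw counts (max > 50)')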
    print('   Performing normalization...')
    
    # Store raw counts if not already stored
    if 'counts' not in adata.layers:
        adata.layers['counts'] = adata.X.copy()
        print('✓ Stored raw counts in adata.layers["counts"]')
    
    # Normalize to 10,000 counts per cell and log-transform
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    print('✓ Normalized to 10,000 counts per cell')
    print('✓ Log-transformed (log1p)')
    
    # Store normalized data in "normalized" layer
    adata.layers['normalized'] = adata.X.copy()
    print('✓ Stored normalized data in adata.layers["normalized"]')
    
elif max_val > 20:
    print('⚠️  Data may be normalized but not log-transformed')
    print('   Applying log transformation...')
    
    sc.pp.log1p(adata)
    print('✓ Log-transformed (log1p)')
    
    # Store the log-transformed data in the "normalized" layer
    adata.layers['normalized'] = adata.X.copy()
    print('✓ Stored log-transformed data in adata.layers["normalized"]')
    
else:
    print('✓ Data appears to be already normalized and log-transformed')
    print('  (max value <= 20, typical for log-normalized data)')
    
    # Store in normalized layer if not already present
    if 'normalized' not in adata.layers:
        adata.layers['normalized'] = adata.X.copy()
        print('✓ Stored normalized data in adata.layers["normalized"]')

# Verify final state
final_max = adata.X.max() if hasattr(adata.X, 'max') else adata.X.data.max()
normalized_max = adata.layers['normalized'].max() if hasattr(adata.layers['normalized'], 'max') else adata.layers['normalized'].data.max()
print(f'\n✓ Final X matrix max value: {final_max:.2f}')
print(f'✓ Normalized layer max value: {normalized_max:.2f}')
print('✓ Normalized layer will be used for gene expression visualization')

print('\n' + '='*60)
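The max-value thresholds above are a heuristic: normalized but un-logged data can exceed 50 and would then be normalized a second time. A complementary check is to test whether a sample of the stored values is non-negative and integer-valued. That holds for raw UMI counts but not for normalized data. The sketch below shows this as a hypothetical helper, `looks_like_raw_counts`, which is not part of scanpy.

```python
import numpy as np
import scipy.sparse as sp

def looks_like_raw_counts(X, n_values: int = 10_000, seed: int = 0) -> bool:
    """Hypothetical helper: True if sampled stored values are non-negative integers."""
    values = X.data if sp.issparse(X) else np.asarray(X).ravel()
    if values.size == 0:
        return False
    rng = np.random.default_rng(seed)
    sample = rng.choice(values, size=min(n_values, values.size), replace=False)
    # Exact comparison: np.allclose would accept e.g. 5000.03 as "integer"
    return bool(np.all(sample >= 0) and np.all(sample == np.round(sample)))

counts = sp.random(100, 50, density=0.1, format="csr", random_state=0)
counts.data = np.ceil(counts.data * 20)   # integer-valued "raw counts"
lognorm = counts.copy()
lognorm.data = np.log1p(lognorm.data)     # non-integer after log1p
print(looks_like_raw_counts(counts), looks_like_raw_counts(lognorm))  # True False
```

In the notebook this could be called as `looks_like_raw_counts(adata.X)` before choosing a normalization branch.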



### Diagnostic: Optimal n_neighbors Selection

Before running clustering, it's helpful to test different `n_neighbors` values to find one that produces stable, interpretable UMAP structure.

**Guidelines**:
- **n_neighbors**: Controls local vs. global structure in the neighbor graph
  - Lower values (10-15): Emphasize local/fine structure, more clusters
  - Higher values (30-50): Emphasize global structure, broader patterns
  - Typical range: 15-30 for most scRNA-seq datasets
  
- **n_pcs**: For the scVI latent space (X_scVI), use `None` to keep all dimensions
  - The scVI latent space is already low-dimensional (its size is set by `n_latent` when the model is trained; the scvi-tools default is 10)
  - No further dimensionality reduction is needed

This cell generates UMAPs with different n_neighbors to help you choose the best value.
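The connected-component count reported by the diagnostic depends on `n_neighbors` and on how separated the populations are. The toy sketch below shows why: it builds a plain unweighted kNN graph (a simplification; `sc.pp.neighbors` builds a weighted UMAP-style graph) on two synthetic, well-separated blobs and counts components with union-find. `knn_graph_components` is a hypothetical helper.

```python
import numpy as np

def knn_graph_components(X: np.ndarray, k: int) -> int:
    """Count connected components of an undirected k-nearest-neighbor graph."""
    n = X.shape[0]
    # Brute-force pairwise Euclidean distances (fine for a few hundred points)
    dist = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)
    np.fill_diagonal(dist, np.inf)            # a point is not its own neighbor
    neighbors = np.argsort(dist, axis=1)[:, :k]

    # Union-find over the kNN edges; edge direction is ignored
    parent = list(range(n))

    def find(i: int) -> int:
        while parent[i] != i:
            parent[i] = parent[parent[i]]     # path halving
            i = parent[i]
        return i

    for i in range(n):
        for j in neighbors[i]:
            root_i, root_j = find(i), find(int(j))
            if root_i != root_j:
                parent[root_i] = root_j
    return len({find(i) for i in range(n)})

rng = np.random.default_rng(0)
# Two well-separated synthetic "cell populations" in a 10-D latent space
X = np.vstack([rng.normal(0, 1, size=(50, 10)),
               rng.normal(20, 1, size=(50, 10))])
for k in (15, 49, 60):
    print(f"k={k}: {knn_graph_components(X, k)} component(s)")
```

With 50 points per blob, k = 49 links every point only to its own blob (two components). k = 60 forces edges across blobs (one component). More than one component is therefore not automatically a problem: it can reflect genuinely distinct populations.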


In [None]:
# Ensure inline plotting
%matplotlib inline

# ============================================================
# USER CONFIGURATION
# ============================================================
latent_key = 'X_scVI'  # Latent representation to use
test_n_neighbors = [10, 15, 20, 30, 40, 50]  # Values to test
n_pcs = None  # Use all latent dimensions (recommended for scVI)

# Downsample for faster computation (set to None to use all cells)
downsample_n = 5000  # Use subset of cells for quick visualization

# ============================================================
# CHECK LATENT SPACE DIMENSIONS
# ============================================================
print('='*60)
print('LATENT SPACE INSPECTION')
print('='*60)

if latent_key not in adata.obsm:
    raise ValueError(f'{latent_key} not found in adata.obsm. Available: {list(adata.obsm.keys())}')

latent_dims = adata.obsm[latent_key].shape[1]
print(f'\n{latent_key} dimensions: {latent_dims}')
print(f'Recommendation: Use n_pcs=None to utilize all {latent_dims} dimensions')
print('  (the scVI latent space is already low-dimensional)')

# ============================================================
# DOWNSAMPLE DATA FOR FASTER TESTING (OPTIONAL)
# ============================================================
if downsample_n is not None and downsample_n < adata.n_obs:
    print(f'\nDownsampling to {downsample_n:,} cells for quick visualization...')
    np.random.seed(42)
    sample_indices = np.random.choice(adata.n_obs, downsample_n, replace=False)
    adata_test = adata[sample_indices].copy()
    print(f'✓ Using {adata_test.n_obs:,} cells for testing')
else:
    adata_test = adata.copy()
    print(f'\nUsing all {adata_test.n_obs:,} cells')

# ============================================================
# TEST DIFFERENT N_NEIGHBORS VALUES
# ============================================================
print('\n' + '='*60)
print('TESTING DIFFERENT N_NEIGHBORS VALUES')
print('='*60)

configure_plot_style()

# Create figure
n_cols = 3
n_rows = (len(test_n_neighbors) + n_cols - 1) // n_cols
fig, axes = plt.subplots(n_rows, n_cols, figsize=(4*n_cols, 4*n_rows))
axes = np.atleast_1d(axes).flatten()  # 1-D array of axes for any grid shape

# Store connectivity metrics
metrics = []

for i, n_neigh in enumerate(test_n_neighbors):
    print(f'\nTesting n_neighbors={n_neigh}...')
    
    # Compute neighbors and UMAP
    sc.pp.neighbors(adata_test, use_rep=latent_key, n_neighbors=n_neigh, n_pcs=n_pcs)
    sc.tl.umap(adata_test)
    
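    # Calculate connectivity metrics
    connectivities = adata_test.obsp['connectivities']
    n_components = scipy.sparse.csgraph.connected_components(connectivities, directed=False)[0]
    # Mean over all n x n entries (zeros included): it grows with n_neighbors and
    # shrinks with cell count, so compare it only across runs on the same cells
    mean_connectivity = connectivities.mean()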
    
    print(f'  Connected components: {n_components} (should be 1 for well-connected graph)')
    print(f'  Mean connectivity: {mean_connectivity:.4f}')
    
    metrics.append({
        'n_neighbors': n_neigh,
        'connected_components': n_components,
        'mean_connectivity': mean_connectivity
    })
    
    # Plot UMAP colored by group (or other metadata)
    color_by = 'group' if 'group' in adata_test.obs else 'library'
    
    sc.pl.umap(
        adata_test,
        color=color_by,
        ax=axes[i],
        show=False,
        title=f'n_neighbors={n_neigh}\n(components={n_components}, conn={mean_connectivity:.4f})',
        legend_loc='right margin' if i % n_cols == n_cols - 1 else None,
        frameon=False,
        size=20
    )

# Hide unused subplots
for j in range(len(test_n_neighbors), len(axes)):
    axes[j].set_visible(False)

plt.suptitle(f'UMAP with Different n_neighbors Values\n(colored by {color_by})', 
             y=1.00, fontsize=16, fontweight='bold')
plt.tight_layout()

save_plot('diagnostic_n_neighbors_comparison', close=False)
show_inline_plot()
plt.close()

# ============================================================
# SUMMARY AND RECOMMENDATIONS
# ============================================================
print('\n' + '='*60)
print('CONNECTIVITY METRICS SUMMARY')
print('='*60)

metrics_df = pd.DataFrame(metrics)
print(metrics_df.to_string(index=False))

print('\n' + '='*60)
print('RECOMMENDATIONS')
print('='*60)

# Find values with single connected component
good_values = metrics_df[metrics_df['connected_components'] == 1]['n_neighbors'].values

if len(good_values) > 0:
    print(f'\n✓ n_neighbors values with fully connected graph: {good_values.tolist()}')
    
    # Recommend middle value
    if len(good_values) >= 3:
        recommended = good_values[len(good_values)//2]
    else:
        recommended = good_values[0]
    
    print(f'\nüéØ RECOMMENDED: n_neighbors = {recommended}')
    print(f'   ‚Ä¢ Provides good balance of local and global structure')
    print(f'   ‚Ä¢ Graph is fully connected (1 component)')
    print(f'   ‚Ä¢ Typical for scRNA-seq datasets with ~{adata.n_obs:,} cells')
else:
    print('\n‚ö†Ô∏è  Warning: No values produced fully connected graph')
    print('   Consider using higher n_neighbors values')

print(f'\nüí° For n_pcs: Use None (all {latent_dims} scVI latent dimensions)')
print(f'   scVI already provides optimally reduced representation')

print('\n' + '='*60)

# Clean up test object
del adata_test


## 3. Clustering and Dimensionality Reduction

### Compute Neighbors and UMAP
Perform graph-based clustering and dimensionality reduction using the scVI latent space:
1. **Neighbor Graph**: Compute k-nearest neighbors in scVI latent space
2. **UMAP**: Project to 2D for visualization
3. **Leiden Clustering**: Identify cell clusters at multiple resolutions

Testing multiple resolutions helps identify the appropriate granularity for cell type identification. This overwrites any existing clustering results stored under the same keys.
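After the clustering loop, a compact table of the cluster count and smallest cluster size per resolution can make the choice easier than scanning printed lines. The sketch below assumes the `leiden_scVI_r{res}` column naming used in this notebook. `summarize_resolutions` is a hypothetical helper, and the toy table stands in for `adata.obs`.

```python
import pandas as pd

def summarize_resolutions(obs: pd.DataFrame, prefix: str = "leiden_scVI_r") -> pd.DataFrame:
    """Tabulate cluster count and smallest cluster size for each clustering column."""
    rows = []
    for col in obs.columns:
        if not col.startswith(prefix):
            continue
        sizes = obs[col].value_counts()
        sizes = sizes[sizes > 0]  # drop unused categories, if any
        rows.append({
            "resolution": float(col[len(prefix):]),
            "n_clusters": int(sizes.size),
            "smallest_cluster": int(sizes.min()),
        })
    return pd.DataFrame(rows).sort_values("resolution").reset_index(drop=True)

# Toy obs table standing in for adata.obs after the clustering loop
toy_obs = pd.DataFrame({
    "leiden_scVI_r0.4": pd.Categorical(["0", "0", "1", "1", "1", "0"]),
    "leiden_scVI_r1.2": pd.Categorical(["0", "2", "1", "1", "3", "0"]),
})
summary = summarize_resolutions(toy_obs)
print(summary)
```

With real data, `summarize_resolutions(adata.obs)` lists all tested resolutions. Very small clusters at high resolution are often worth checking before they are interpreted as distinct cell types.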

In [None]:
# ============================================================
# USER CONFIGURATION
# ============================================================
latent_key = 'X_scVI'     # Latent representation to use for clustering
n_neighbors = 30          # Number of neighbors for graph construction
n_pcs = None              # Number of PCs (None = use all dimensions in latent space)

# Clustering resolutions to test (lower = fewer clusters, higher = more clusters)
clustering_resolutions = [0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 2.2, 2.4, 2.6, 2.8, 3.0]

# ============================================================
# COMPUTE NEIGHBOR GRAPH
# ============================================================
print('='*60)
print('COMPUTING NEIGHBOR GRAPH AND UMAP')
print('='*60)

# Verify latent representation exists
if latent_key not in adata.obsm:
    raise ValueError(f'{latent_key} not found in adata.obsm. Available keys: {list(adata.obsm.keys())}')

# Clear outdated color metadata
for res in clustering_resolutions:
    color_key = f'leiden_scVI_r{res}_colors'
    if color_key in adata.uns:
        del adata.uns[color_key]

# Compute neighbors in latent space
print(f'Computing neighbors (k={n_neighbors}) using {latent_key}')
sc.pp.neighbors(adata, use_rep=latent_key, n_neighbors=n_neighbors, n_pcs=n_pcs)

# Compute UMAP embedding
print('Computing UMAP embedding...')
sc.tl.umap(adata)
print(f'✓ UMAP computed: {adata.obsm["X_umap"].shape}')

# ============================================================
# LEIDEN CLUSTERING AT MULTIPLE RESOLUTIONS
# ============================================================
print('\n' + '='*60)
print('LEIDEN CLUSTERING')
print('='*60)

for res in clustering_resolutions:
    key = f'leiden_scVI_r{res}'
    sc.tl.leiden(adata, resolution=res, key_added=key, flavor="igraph", 
                 n_iterations=2, directed=False)
    n_clusters = len(adata.obs[key].cat.categories)
    print(f'✓ Resolution {res:4.1f} -> {n_clusters:3d} clusters (adata.obs["{key}"])')

print('\n' + '='*60)

## 4. Visualize Clustering Results

### Multi-Resolution Clustering Overview
Generate a grid of UMAP plots showing clustering results at all tested resolutions. This helps you select the appropriate resolution for downstream cell type analyses.

In [None]:
%matplotlib inline

# ============================================================
# USER CONFIGURATION
# ============================================================
point_size = 10         # Size of points in UMAP
n_cols_grid = 3        # Number of columns in grid layout
plot_name = 'leiden_clustering_overview'
color_palette = 'husl' # Not used: clusters are colored with get_dittoseq_colors() below

# Resolutions to visualize (each must have been clustered in Section 3)
clustering_resolutions = [0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0]

# ============================================================
# GENERATE MULTI-RESOLUTION PLOT
# ============================================================

# Verify UMAP exists
if 'X_umap' not in adata.obsm:
    raise ValueError("X_umap not found. Run clustering section first.")

# Get clustering columns
clustering_keys = [f'leiden_scVI_r{res}' for res in clustering_resolutions]
missing_keys = [key for key in clustering_keys if key not in adata.obs]
if missing_keys:
    raise ValueError(f"Missing clustering results: {missing_keys}")

# Setup grid layout
n_rows = (len(clustering_resolutions) + n_cols_grid - 1) // n_cols_grid
fig, axes = plt.subplots(n_rows, n_cols_grid, 
                         figsize=(3 * n_cols_grid, 3.5 * n_rows), 
                         sharex=True, sharey=True)
axes = axes.flatten()

# Import for color palette conversion
import matplotlib.colors as mcolors

configure_plot_style()

# Plot each resolution
for i, res in enumerate(clustering_resolutions):
    key = f'leiden_scVI_r{res}'
    n_clusters = len(adata.obs[key].cat.categories)
    
    # Set up dittoSeq palette for publication quality
    custom_palette = get_dittoseq_colors(n_clusters)
    adata.uns[f'{key}_colors'] = custom_palette
    
    sc.pl.umap(
        adata,
        color=key,
        ax=axes[i],
        show=False,
        title=f'Resolution {res} ({n_clusters} clusters)',
        legend_loc='on data',
        frameon=False,
        size=point_size,
        palette=adata.uns[f'{key}_colors']
    )

# Hide unused subplots
for j in range(len(clustering_resolutions), len(axes)):
    axes[j].set_visible(False)

plt.suptitle("Leiden Clustering at Multiple Resolutions", y=1.02, fontsize=18)
plt.tight_layout()

save_plot(plot_name, close=False)
show_inline_plot()
plt.close()

### Selected Resolution Visualization
Generate a detailed UMAP plot for a single selected resolution. This provides a cleaner view for presentations and publications.

In [None]:
# ============================================================
# USER CONFIGURATION
# ============================================================
selected_resolution = 1.2  # Select clustering resolution to display
point_size = 15             # Size of points in UMAP
fig_width = 6              # Figure width
fig_height = 5              # Figure height

# ============================================================
# GENERATE SELECTED RESOLUTION PLOT
# ============================================================

# Verify clustering column and UMAP
cluster_key = f'leiden_scVI_r{selected_resolution}'
if cluster_key not in adata.obs:
    raise ValueError(f"Clustering column '{cluster_key}' not found in adata.obs")
if 'X_umap' not in adata.obsm:
    raise ValueError("X_umap not found in adata.obsm")

configure_plot_style()

# Set up dittoSeq-like palette
import matplotlib.colors as mcolors
n_clusters = len(adata.obs[cluster_key].cat.categories)

# Use dittoSeq palette
custom_palette = get_dittoseq_colors(n_clusters)
adata.uns[f'{cluster_key}_colors'] = custom_palette

# Set up plot
fig, ax = plt.subplots(figsize=(fig_width, fig_height), dpi=300)

# Plot UMAP
sc.pl.umap(
    adata,
    color=cluster_key,
    ax=ax,
    show=False,
    title=f'Leiden Clustering (Resolution {selected_resolution})',
    legend_loc='right margin',
    legend_fontsize=10,
    legend_fontweight='normal',
    size=point_size,
    palette=adata.uns[f'{cluster_key}_colors']
)

# Add cluster labels (semi-transparent white box), repelled to avoid overlaps
# like Seurat's repel=TRUE. Requires the adjustText package.
from adjustText import adjust_text

umap_coords = adata.obsm['X_umap']
clusters = adata.obs[cluster_key]

# Collect all labels for adjustText
texts = []

for cluster in clusters.cat.categories:
    # Get UMAP coordinates for cells in this cluster
    cluster_mask = clusters == cluster
    cluster_coords = umap_coords[cluster_mask]
    
    if len(cluster_coords) == 0:
        continue
    
    # Use the median position (as Seurat's LabelClusters does); it is less
    # affected by outlying cells than the mean
    label_x = np.median(cluster_coords[:, 0])
    label_y = np.median(cluster_coords[:, 1])
    
    # Add text label to adjustText list
    text = ax.text(label_x, label_y, str(cluster),
                   fontsize=14,
                   fontweight='bold',
                   ha='center',
                   va='center',
                   color='black',
                   bbox=dict(boxstyle='round,pad=0.3', facecolor='white',
                             alpha=0.7, edgecolor='none'))
    texts.append(text)

# Auto-adjust overlapping labels (like Seurat's repel=TRUE)
adjust_text(texts, 
            ax=ax,
            expand=(1.2, 1.2),
            arrowprops=dict(arrowstyle='->', color='black', lw=0.8))

# Add axis labels
ax.set_xlabel('UMAP-1', fontsize=12)
ax.set_ylabel('UMAP-2', fontsize=12)

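# sc.pl.umap removes tick locations, so restore automatic tick positions
# to show numerical labels (e.g., 2, 4, 6)
from matplotlib.ticker import AutoLocator, AutoMinorLocator
ax.xaxis.set_major_locator(AutoLocator())
ax.yaxis.set_major_locator(AutoLocator())
ax.xaxis.set_minor_locator(AutoMinorLocator())
ax.yaxis.set_minor_locator(AutoMinorLocator())

ax.tick_params(axis='both', which='major', 
               bottom=True, left=True,       # Show ticks
               labelbottom=True, labelleft=True,  # Show labels
               labelsize=10,
               length=4, width=0.8)

# Draw the minor ticks placed by AutoMinorLocator
ax.tick_params(axis='both', which='minor',
               bottom=True, left=True,
               length=2, width=0.6)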

plt.tight_layout()

save_plot(f'leiden_clustering_r{selected_resolution}', close=False)
show_inline_plot()
plt.close()

# Check cluster sizes to confirm the ordering
print("Cluster sizes:")
print(adata.obs[cluster_key].value_counts().sort_index())
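Cluster label positions can also be computed in one vectorized step with pandas. The sketch below uses the per-cluster median, which is less sensitive than the mean to a few cells lying far from the main body of a cluster. It uses synthetic coordinates. With the notebook's objects, the inputs would be `adata.obsm['X_umap']` and `adata.obs[cluster_key]`. `cluster_label_positions` is a hypothetical helper.

```python
import numpy as np
import pandas as pd

def cluster_label_positions(coords: np.ndarray, labels) -> pd.DataFrame:
    """Median 2-D embedding coordinate per cluster (one row per cluster)."""
    df = pd.DataFrame(coords[:, :2], columns=["x", "y"])
    df["cluster"] = np.asarray(labels)
    return df.groupby("cluster")[["x", "y"]].median()

coords = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [50.0, 50.0],  # cluster "0", one outlier
                   [10.0, 10.0], [11.0, 11.0], [12.0, 12.0]])          # cluster "1"
labels = ["0", "0", "0", "0", "1", "1", "1"]
pos = cluster_label_positions(coords, labels)
print(pos)
```

Here the outlier pulls the mean x of cluster "0" to 13.25, while the median stays at 1.5, inside the cluster's main body.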

## 7. Metadata-Stratified Visualizations

### Visualize Clusters by Metadata Groups
Generate UMAP plots stratified by experimental metadata (batch, library, mouse) to assess:
- Batch-specific patterns
- Library effects and integration quality
- Distribution across mice
- Cluster composition differences between experimental groups

This helps identify technical artifacts and validate biological interpretations.
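UMAPs can be complemented by a table of cluster composition. After integration, a cluster drawn almost entirely from one batch or library may point to residual technical effects or to a genuinely batch-specific population. The sketch below uses `pd.crosstab` with row normalization on a toy table. `cluster_composition` is a hypothetical helper, and in the notebook it would be called on `adata.obs`.

```python
import pandas as pd

def cluster_composition(obs: pd.DataFrame, cluster_key: str, group_key: str) -> pd.DataFrame:
    """Fraction of each cluster's cells coming from each level of group_key."""
    return pd.crosstab(obs[cluster_key], obs[group_key], normalize="index")

toy_obs = pd.DataFrame({
    "leiden_scVI_r1.2": ["0", "0", "0", "0", "1", "1", "1", "1"],
    "batch":            ["A", "B", "A", "B", "A", "A", "A", "A"],
})
comp = cluster_composition(toy_obs, "leiden_scVI_r1.2", "batch")
print(comp)
```

These fractions are most informative when compared with the overall proportions, for example `obs["batch"].value_counts(normalize=True)`, since unequal batch sizes shift every row.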

### QC Metrics by Cluster

Visualize quality control metrics across clusters to identify potential technical artifacts or low-quality clusters. This helps assess whether certain clusters are driven by technical factors rather than biological variation.
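Besides inspecting violins by eye, clusters with unusual QC values can be flagged numerically. The sketch below is a hypothetical rule of thumb, not a standard threshold. It flags clusters whose median metric lies more than `n_mads` median absolute deviations from the median of all cluster medians. A flagged cluster still needs biological review, since some cell types legitimately differ in gene counts or mitochondrial fraction. `flag_qc_outlier_clusters` is a hypothetical helper.

```python
import pandas as pd

def flag_qc_outlier_clusters(obs: pd.DataFrame, cluster_key: str, metric: str,
                             n_mads: float = 3.0) -> pd.Series:
    """Boolean Series per cluster: True if its median `metric` is an outlier."""
    med = obs.groupby(cluster_key, observed=True)[metric].median()
    center = med.median()
    mad = (med - center).abs().median()
    if mad == 0:
        return pd.Series(False, index=med.index)
    return (med - center).abs() > n_mads * mad

toy_obs = pd.DataFrame({
    "cluster": ["0", "0", "1", "1", "2", "2", "3", "3", "4", "4"],
    "pct_counts_mt": [1.5, 2.5, 3.0, 3.0, 2.0, 3.0, 3.0, 4.0, 20.0, 30.0],
})
flags = flag_qc_outlier_clusters(toy_obs, "cluster", "pct_counts_mt")
print(flags[flags].index.tolist())  # ['4']
```

In the notebook, the same call could be run with `cluster_key` and each entry of `qc_metrics` on `adata.obs`.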


In [None]:
%matplotlib inline

# ============================================================
# USER CONFIGURATION
# ============================================================
selected_res = 1.2  # Select resolution to analyze
plot_width = 30     # Width of QC plots
plot_height = 4     # Height of QC plots

# ============================================================
# QC METRICS VISUALIZATION BY CLUSTER
# ============================================================

cluster_key = f'leiden_scVI_r{selected_res}'

# Verify clustering exists
if cluster_key not in adata.obs:
    raise ValueError(f"Clustering column '{cluster_key}' not found. Run clustering first.")

# Configure plot style
configure_plot_style()

# Define QC metrics to visualize (check which ones exist in your data)
qc_metrics = []
qc_labels = []

# Check for common QC metric column names
if 'n_genes_by_counts' in adata.obs:
    qc_metrics.append('n_genes_by_counts')
    qc_labels.append('Number of Genes')
elif 'n_genes' in adata.obs:
    qc_metrics.append('n_genes')
    qc_labels.append('Number of Genes')

if 'total_counts' in adata.obs:
    qc_metrics.append('total_counts')
    qc_labels.append('Total Counts (nCount_RNA)')
elif 'n_counts' in adata.obs:
    qc_metrics.append('n_counts')
    qc_labels.append('Total Counts (nCount_RNA)')

if 'pct_counts_mt' in adata.obs:
    qc_metrics.append('pct_counts_mt')
    qc_labels.append('% Mitochondrial')

if 'doublet_score' in adata.obs:
    qc_metrics.append('doublet_score')
    qc_labels.append('Doublet Score')

# Cell cycle scores if available
if 'S_score' in adata.obs:
    qc_metrics.append('S_score')
    qc_labels.append('S Phase Score')
    
if 'G2M_score' in adata.obs:
    qc_metrics.append('G2M_score')
    qc_labels.append('G2/M Phase Score')

if not qc_metrics:
    print("No QC metrics found in adata.obs. Available columns:")
    print(list(adata.obs.columns))
else:
    print(f"Found {len(qc_metrics)} QC metrics to visualize")
    
    # Create violin plots for each QC metric
    n_metrics = len(qc_metrics)
    
    fig, axes = plt.subplots(1, n_metrics, figsize=(plot_width, plot_height))
    if n_metrics == 1:
        axes = [axes]
    
    for i, (metric, label) in enumerate(zip(qc_metrics, qc_labels)):
        # Create violin plot
        sc.pl.violin(
            adata,
            keys=metric,
            groupby=cluster_key,
            ax=axes[i],
            show=False,
            rotation=90
        )
        axes[i].set_title(label, fontsize=12, fontweight='bold')
        axes[i].set_xlabel('Cluster', fontsize=10)
        axes[i].set_ylabel(label, fontsize=10)
        axes[i].tick_params(axis='x', labelsize=8)
    
    plt.suptitle(f'QC Metrics by Cluster (Resolution {selected_res})', 
                 fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    
    # Save plot
    save_plot(f'qc_metrics_by_cluster_r{selected_res}', close=False)
    show_inline_plot()
    plt.close()
    
    # ============================================================
    # SUMMARY STATISTICS TABLE
    # ============================================================
    print('\n' + '='*80)
    print(f'QC METRICS SUMMARY BY CLUSTER (Resolution {selected_res})')
    print('='*80)
    
    # Calculate median values for each metric per cluster
    summary_data = []
    for cluster in adata.obs[cluster_key].cat.categories:
        cluster_mask = adata.obs[cluster_key] == cluster
        cluster_cells = cluster_mask.sum()
        
        row = {'Cluster': cluster, 'N_cells': cluster_cells}
        
        for metric, label in zip(qc_metrics, qc_labels):
            median_val = adata.obs.loc[cluster_mask, metric].median()
            row[label] = f'{median_val:.2f}'
        
        summary_data.append(row)
    
    summary_df = pd.DataFrame(summary_data)
    print(summary_df.to_string(index=False))
    
    # Save summary table
    summary_file = output_dir / f'qc_summary_by_cluster_r{selected_res}.csv'
    summary_df.to_csv(summary_file, index=False)
    print(f'\n✓ Saved summary table: {summary_file.name}')
    
    # ============================================================
    # IDENTIFY POTENTIAL LOW-QUALITY CLUSTERS
    # ============================================================
    print('\n' + '='*80)
    print('POTENTIAL QUALITY ISSUES')
    print('='*80)
    
    # Check for high mitochondrial percentage
    if 'pct_counts_mt' in adata.obs:
        high_mt_threshold = 10  # Adjust as needed
        for cluster in adata.obs[cluster_key].cat.categories:
            cluster_mask = adata.obs[cluster_key] == cluster
            median_mt = adata.obs.loc[cluster_mask, 'pct_counts_mt'].median()
            
            if median_mt > high_mt_threshold:
                print(f'⚠️  Cluster {cluster}: High mt% (median={median_mt:.2f}%)')
    
    # Check for low gene counts
    gene_metric = 'n_genes_by_counts' if 'n_genes_by_counts' in adata.obs else 'n_genes'
    if gene_metric in adata.obs:
        low_genes_threshold = 500  # Adjust as needed
        for cluster in adata.obs[cluster_key].cat.categories:
            cluster_mask = adata.obs[cluster_key] == cluster
            median_genes = adata.obs.loc[cluster_mask, gene_metric].median()
            
            if median_genes < low_genes_threshold:
                print(f'⚠️  Cluster {cluster}: Low gene count (median={median_genes:.0f} genes)')
    
    print('\nIf no warnings appear, all clusters pass basic QC thresholds.')
    print('='*80)


## Remove Cluster 3

Remove cells belonging to cluster 3 from the AnnData object for downstream analysis.
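Two pandas details explain how the removal cell is written. First, Leiden stores cluster IDs as string categories. Second, filtering rows does not shrink the category list. The sketch below shows both on a toy label series. The variable `cluster_to_remove` is a hypothetical name used only for illustration; the notebook cell hard-codes `'3'`.

```python
import pandas as pd

# Toy Leiden labels: cluster IDs are *string* categories, as scanpy stores them.
labels = pd.Series(pd.Categorical(["0", "1", "3", "3", "2", "1"]))
cluster_to_remove = "3"  # hypothetical name; the notebook hard-codes '3'

kept = labels[labels != cluster_to_remove]
print(list(kept.cat.categories))  # '3' is still listed although no cell carries it

kept = kept.cat.remove_unused_categories()
print(list(kept.cat.categories))  # now only the clusters that still have cells

# Comparing against the integer 3 matches no string category, so nothing would be removed.
print((labels != 3).all())
```

This is why the cell calls `remove_unused_categories()`. Without it, the empty cluster would still appear in legends, palettes, and category loops.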

In [None]:
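# Remove cluster 3 from the AnnData object.
# NOTE: `selected_resolution` is inherited from the single-resolution plotting cell above;
# make sure it matches the resolution at which cluster 3 was identified for removal.
cluster_key = f'leiden_scVI_r{selected_resolution}'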

# Check if cluster 3 exists
if '3' in adata.obs[cluster_key].cat.categories:
    # Get count before removal
    n_cells_before = adata.n_obs
    n_cluster3_cells = (adata.obs[cluster_key] == '3').sum()
    
    print(f"Before removal: {n_cells_before} total cells")
    print(f"Cluster 3 contains: {n_cluster3_cells} cells")
    
    # Filter out cluster 3 cells
    adata = adata[adata.obs[cluster_key] != '3'].copy()
    
    # Update categorical column to remove unused category
    adata.obs[cluster_key] = adata.obs[cluster_key].cat.remove_unused_categories()
    
    # Report results
    n_cells_after = adata.n_obs
    print(f"After removal: {n_cells_after} total cells")
    print(f"Removed {n_cells_before - n_cells_after} cells ({n_cluster3_cells} from cluster 3)")
    print(f"\nRemaining clusters: {sorted(adata.obs[cluster_key].cat.categories.tolist())}")
else:
    print("Cluster 3 not found in the data. No cells removed.")

## Re-cluster and Visualize Filtered Data (Cluster 3 Removed)

After removing cluster 3, re-run clustering and visualization steps on the filtered AnnData object. All plots will be labeled as "filtered" to distinguish them from the original analysis.

Steps:
1. **Recompute neighbors, UMAP, and Leiden clustering** at multiple resolutions
2. **Multi-resolution clustering visualization** overview
3. **Single-resolution visualization** for the selected resolution
4. **QC metrics by cluster** for the selected resolution
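The neighbor graph, UMAP, and Leiden clusters all have to be recomputed because removing cells changes each remaining cell's nearest neighbors in `X_scVI`. The sketch below is a brute-force k-nearest-neighbor search, meant only to show what that graph is built from. It is a simplification: `sc.pp.neighbors` uses an approximate search and converts the distances into UMAP-style connectivities by default, and neither step is reproduced here. The function name `knn_indices` and the toy latent matrix are illustrative assumptions.

```python
import numpy as np

def knn_indices(latent: np.ndarray, k: int) -> np.ndarray:
    """Return, for each cell, the indices of its k nearest other cells (Euclidean)."""
    # Pairwise squared Euclidean distances via broadcasting: (n, 1, d) - (1, n, d).
    diff = latent[:, None, :] - latent[None, :, :]
    dist2 = np.einsum("ijk,ijk->ij", diff, diff)
    np.fill_diagonal(dist2, np.inf)  # exclude each cell from its own neighbor list
    # Sort each row by distance and keep the k closest columns.
    return np.argsort(dist2, axis=1)[:, :k]

rng = np.random.default_rng(0)
# Stand-in for adata.obsm['X_scVI']: two well-separated groups of 5 cells in 10 dimensions.
latent = np.vstack([rng.normal(0, 0.1, (5, 10)), rng.normal(5, 0.1, (5, 10))])
nn = knn_indices(latent, k=3)
print(nn)
```

This approach needs O(n²) memory, so it is only practical for toy inputs. Real datasets should go through `sc.pp.neighbors`.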


In [None]:
%matplotlib inline

# ============================================================
# USER CONFIGURATION - FILTERED DATA CLUSTERING
# ============================================================
latent_key = 'X_scVI'     # Latent representation to use for clustering
n_neighbors = 30          # Number of neighbors for graph construction
n_pcs = None              # Number of PCs (None = use all dimensions in latent space)

# Clustering resolutions to test (lower = fewer clusters, higher = more clusters)
clustering_resolutions = [0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 2.2, 2.4, 2.6, 2.8, 3.0]

# ============================================================
# COMPUTE NEIGHBOR GRAPH (FILTERED DATA)
# ============================================================
print('='*60)
print('RECOMPUTING NEIGHBOR GRAPH AND UMAP (FILTERED DATA - CLUSTER 3 REMOVED)')
print('='*60)

# Verify latent representation exists
if latent_key not in adata.obsm:
    raise ValueError(f'{latent_key} not found in adata.obsm. Available keys: {list(adata.obsm.keys())}')

print(f'Working with {adata.n_obs:,} cells (after cluster 3 removal)')

# Clear outdated color metadata
for res in clustering_resolutions:
    color_key = f'leiden_scVI_r{res}_colors'
    if color_key in adata.uns:
        del adata.uns[color_key]

# Compute neighbors in latent space
print(f'Computing neighbors (k={n_neighbors}) using {latent_key}')
sc.pp.neighbors(adata, use_rep=latent_key, n_neighbors=n_neighbors, n_pcs=n_pcs)

# Compute UMAP embedding
print('Computing UMAP embedding...')
sc.tl.umap(adata)
print(f'✓ UMAP computed: {adata.obsm["X_umap"].shape}')

# ============================================================
# LEIDEN CLUSTERING AT MULTIPLE RESOLUTIONS (FILTERED DATA)
# ============================================================
print('\n' + '='*60)
print('LEIDEN CLUSTERING (FILTERED DATA)')
print('='*60)

for res in clustering_resolutions:
    key = f'leiden_scVI_r{res}'
    sc.tl.leiden(adata, resolution=res, key_added=key, flavor="igraph", 
                 n_iterations=2, directed=False)
    n_clusters = len(adata.obs[key].cat.categories)
    print(f'✓ Resolution {res:4.1f} -> {n_clusters:3d} clusters (adata.obs["{key}"])')

print('\n' + '='*60)


### Multi-Resolution Clustering Overview (Filtered Data)

Generate a grid of UMAP plots showing clustering results at all tested resolutions for the filtered dataset.
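The grid in the cell below is sized by ceiling division. The hypothetical helper below (not part of the notebook) spells out that arithmetic, along with the number of trailing empty panels that the cell hides.

```python
import math

def grid_layout(n_panels: int, n_cols: int) -> tuple[int, int]:
    """Rows needed for n_panels in n_cols columns, and how many trailing axes to hide."""
    n_rows = math.ceil(n_panels / n_cols)  # equivalent to (n_panels + n_cols - 1) // n_cols
    n_hidden = n_rows * n_cols - n_panels
    return n_rows, n_hidden

print(grid_layout(9, 3))   # the 9 plotted resolutions fill a 3 x 3 grid
print(grid_layout(14, 3))  # all 14 clustered resolutions need 5 rows, one panel hidden
```

When the grid can collapse to a single row or a single panel, `plt.subplots(..., squeeze=False)` always returns a 2-D array of axes. That keeps `axes.flatten()` safe in every case.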


In [None]:
%matplotlib inline

# ============================================================
# USER CONFIGURATION (FILTERED DATA)
# ============================================================
point_size = 10        # Size of points in UMAP
n_cols_grid = 3        # Number of columns in grid layout
plot_name = 'leiden_clustering_overview_filtered'
color_palette = 'husl' # Not used below: cluster colors come from get_dittoseq_colors()

# Resolutions to visualize (a subset of those clustered above; this overwrites clustering_resolutions)
clustering_resolutions = [0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0]

# ============================================================
# GENERATE MULTI-RESOLUTION PLOT (FILTERED DATA)
# ============================================================

# Verify UMAP exists
if 'X_umap' not in adata.obsm:
    raise ValueError("X_umap not found. Run clustering section first.")

# Get clustering columns
clustering_keys = [f'leiden_scVI_r{res}' for res in clustering_resolutions]
missing_keys = [key for key in clustering_keys if key not in adata.obs]
if missing_keys:
    raise ValueError(f"Missing clustering results: {missing_keys}")

# Setup grid layout
n_rows = (len(clustering_resolutions) + n_cols_grid - 1) // n_cols_grid
fig, axes = plt.subplots(n_rows, n_cols_grid,
                         figsize=(3 * n_cols_grid, 3.5 * n_rows),
                         sharex=True, sharey=True, squeeze=False)
axes = axes.flatten()  # squeeze=False guarantees a 2-D array, even for a single panel

# Import for color palette conversion
import matplotlib.colors as mcolors

configure_plot_style()

# Plot each resolution
for i, res in enumerate(clustering_resolutions):
    key = f'leiden_scVI_r{res}'
    n_clusters = len(adata.obs[key].cat.categories)
    
    # Set up dittoSeq palette for publication quality
    custom_palette = get_dittoseq_colors(n_clusters)
    adata.uns[f'{key}_colors'] = custom_palette
    
    sc.pl.umap(
        adata,
        color=key,
        ax=axes[i],
        show=False,
        title=f'Resolution {res} ({n_clusters} clusters)',
        legend_loc='on data',
        frameon=False,
        size=point_size,
        palette=adata.uns[f'{key}_colors']
    )

# Hide unused subplots
for j in range(len(clustering_resolutions), len(axes)):
    axes[j].set_visible(False)

plt.suptitle("Leiden Clustering at Multiple Resolutions (FILTERED DATA - Cluster 3 Removed)", 
             y=1.02, fontsize=18, fontweight='bold')
plt.tight_layout()

save_plot(plot_name, close=False)
show_inline_plot()
plt.close()


### Selected Resolution Visualization (Filtered Data)

Generate a detailed UMAP plot for a single selected resolution on the filtered dataset. This provides a cleaner view for presentations and publications.
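The cell below positions each cluster label at a per-cluster summary of the UMAP coordinates. Seurat's `LabelClusters` uses the median rather than the mean. The toy example below shows why this matters: a single stray cell drags the mean far from the visible body of the cluster, while the median barely moves. The coordinates are invented for illustration.

```python
import pandas as pd

# Toy UMAP coordinates: cluster "A" is compact except for one stray cell far away.
coords = pd.DataFrame({
    "x": [0.0, 0.1, -0.1, 0.05, 20.0],
    "y": [0.0, 0.1, -0.1, -0.05, 20.0],
    "cluster": ["A"] * 5,
})

mean_pos = coords.groupby("cluster")[["x", "y"]].mean()
median_pos = coords.groupby("cluster")[["x", "y"]].median()
print(mean_pos.loc["A"].tolist())    # pulled toward the stray cell
print(median_pos.loc["A"].tolist())  # stays on the compact body of the cluster
```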


In [None]:
# ============================================================
# USER CONFIGURATION (FILTERED DATA)
# ============================================================
selected_resolution = 1.2  # Select clustering resolution to display
point_size = 15             # Size of points in UMAP
fig_width = 6              # Figure width
fig_height = 5              # Figure height

# ============================================================
# GENERATE SELECTED RESOLUTION PLOT (FILTERED DATA)
# ============================================================

# Verify clustering column and UMAP
cluster_key = f'leiden_scVI_r{selected_resolution}'
if cluster_key not in adata.obs:
    raise ValueError(f"Clustering column '{cluster_key}' not found in adata.obs")
if 'X_umap' not in adata.obsm:
    raise ValueError("X_umap not found in adata.obsm")

configure_plot_style()

# Set up dittoSeq-like palette
import matplotlib.colors as mcolors
n_clusters = len(adata.obs[cluster_key].cat.categories)

# Use dittoSeq palette
custom_palette = get_dittoseq_colors(n_clusters)
adata.uns[f'{cluster_key}_colors'] = custom_palette

# Set up plot
fig, ax = plt.subplots(figsize=(fig_width, fig_height), dpi=300)

# Plot UMAP
sc.pl.umap(
    adata,
    color=cluster_key,
    ax=ax,
    show=False,
    title=f'Leiden Clustering (Resolution {selected_resolution}) - FILTERED DATA',
    legend_loc='right margin',
    legend_fontsize=10,
    legend_fontweight='normal',
    size=point_size,
    palette=adata.uns[f'{cluster_key}_colors']
)

# Add cluster labels (semi-transparent white box) and repel overlapping labels (like Seurat's repel=TRUE)
from adjustText import adjust_text

umap_coords = adata.obsm['X_umap']
clusters = adata.obs[cluster_key]

# Collect all labels for adjustText
texts = []

for cluster in clusters.cat.categories:
    # Get UMAP coordinates for cells in this cluster
    cluster_mask = clusters == cluster
    cluster_coords = umap_coords[cluster_mask]
    
    if len(cluster_coords) == 0:
        continue
    
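    # Place the label at the per-cluster median, as Seurat's LabelClusters does
    # (more robust than the mean when a few cells sit far from the cluster body)
    centroid_x = np.median(cluster_coords[:, 0])
    centroid_y = np.median(cluster_coords[:, 1])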
    
    # Add text label to adjustText list
    text = ax.text(centroid_x, centroid_y, str(cluster),
                   fontsize=14,
                   fontweight='bold',
                   ha='center',
                   va='center',
                   color='black',
                   bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.7, edgecolor='none'))
    texts.append(text)

# Auto-adjust overlapping labels (like Seurat's repel=TRUE)
adjust_text(texts, 
            ax=ax,
            expand=(1.2, 1.2),
            arrowprops=dict(arrowstyle='->', color='black', lw=0.8))

# Add axis labels
ax.set_xlabel('UMAP-1', fontsize=12)
ax.set_ylabel('UMAP-2', fontsize=12)

# Configure ticks to show numerical labels (e.g., 2, 4, 6)
ax.tick_params(axis='both', which='major', 
               bottom=True, left=True,       # Show ticks
               labelbottom=True, labelleft=True,  # Show labels
               labelsize=10,
               length=4, width=0.8)

# Ensure minor ticks are also visible if desired
ax.tick_params(axis='both', which='minor',
               bottom=True, left=True,
               length=2, width=0.6)

plt.tight_layout()

save_plot(f'leiden_clustering_r{selected_resolution}_filtered', close=False)
show_inline_plot()
plt.close()

# Check cluster sizes to confirm the ordering (filtered data)
print("Cluster sizes (FILTERED DATA - Cluster 3 Removed):")
print(adata.obs[cluster_key].value_counts().sort_index())


### QC Metrics by Cluster (Filtered Data)

Visualize quality control metrics across clusters in the filtered dataset to identify potential technical artifacts or low-quality clusters.
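The QC cell below repeats the metric-detection logic from the earlier QC cell line for line. One option is a shared helper that encodes the same column-priority rules once. `find_qc_metrics` and `QC_CANDIDATES` are hypothetical names, not part of the original notebook. They reproduce the labels and fallbacks used in both cells: `n_genes_by_counts` before `n_genes`, and `total_counts` before `n_counts`.

```python
# Hypothetical helper applying the same column-priority rules as both QC cells.
QC_CANDIDATES = [
    # (plot label, columns to try in priority order)
    ("Number of Genes", ["n_genes_by_counts", "n_genes"]),
    ("Total Counts (nCount_RNA)", ["total_counts", "n_counts"]),
    ("% Mitochondrial", ["pct_counts_mt"]),
    ("Doublet Score", ["doublet_score"]),
    ("S Phase Score", ["S_score"]),
    ("G2/M Phase Score", ["G2M_score"]),
]

def find_qc_metrics(obs_columns):
    """Return (metrics, labels) using the first available column for each QC label."""
    available = set(obs_columns)
    metrics, labels = [], []
    for label, options in QC_CANDIDATES:
        for col in options:
            if col in available:
                metrics.append(col)
                labels.append(label)
                break  # keep only the highest-priority column for this label
    return metrics, labels

metrics, labels = find_qc_metrics(
    ["n_genes", "n_genes_by_counts", "total_counts", "pct_counts_mt", "batch"]
)
print(metrics)
print(labels)
```

In the notebook, this would be called as `qc_metrics, qc_labels = find_qc_metrics(adata.obs.columns)`.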

In [None]:
%matplotlib inline

# ============================================================
# USER CONFIGURATION (FILTERED DATA)
# ============================================================
selected_res = 1.2  # Select resolution to analyze
plot_width = 30     # Width of QC plots
plot_height = 6     # Height of QC plots

# ============================================================
# QC METRICS VISUALIZATION BY CLUSTER (FILTERED DATA)
# ============================================================

cluster_key = f'leiden_scVI_r{selected_res}'

# Verify clustering exists
if cluster_key not in adata.obs:
    raise ValueError(f"Clustering column '{cluster_key}' not found. Run clustering first.")

# Configure plot style
configure_plot_style()

# Define QC metrics to visualize (check which ones exist in your data)
qc_metrics = []
qc_labels = []

# Check for common QC metric column names
if 'n_genes_by_counts' in adata.obs:
    qc_metrics.append('n_genes_by_counts')
    qc_labels.append('Number of Genes')
elif 'n_genes' in adata.obs:
    qc_metrics.append('n_genes')
    qc_labels.append('Number of Genes')

if 'total_counts' in adata.obs:
    qc_metrics.append('total_counts')
    qc_labels.append('Total Counts (nCount_RNA)')
elif 'n_counts' in adata.obs:
    qc_metrics.append('n_counts')
    qc_labels.append('Total Counts (nCount_RNA)')

if 'pct_counts_mt' in adata.obs:
    qc_metrics.append('pct_counts_mt')
    qc_labels.append('% Mitochondrial')

if 'doublet_score' in adata.obs:
    qc_metrics.append('doublet_score')
    qc_labels.append('Doublet Score')

# Cell cycle scores if available
if 'S_score' in adata.obs:
    qc_metrics.append('S_score')
    qc_labels.append('S Phase Score')
    
if 'G2M_score' in adata.obs:
    qc_metrics.append('G2M_score')
    qc_labels.append('G2/M Phase Score')

if not qc_metrics:
    print("No QC metrics found in adata.obs. Available columns:")
    print(list(adata.obs.columns))
else:
    print(f"Found {len(qc_metrics)} QC metrics to visualize (FILTERED DATA)")
    
    # Create violin plots for each QC metric
    n_metrics = len(qc_metrics)
    
    fig, axes = plt.subplots(1, n_metrics, figsize=(plot_width, plot_height))
    if n_metrics == 1:
        axes = [axes]
    
    for i, (metric, label) in enumerate(zip(qc_metrics, qc_labels)):
        # Create violin plot
        sc.pl.violin(
            adata,
            keys=metric,
            groupby=cluster_key,
            ax=axes[i],
            show=False,
            rotation=90
        )
        axes[i].set_title(label, fontsize=12, fontweight='bold')
        axes[i].set_xlabel('Cluster', fontsize=10)
        axes[i].set_ylabel(label, fontsize=10)
        axes[i].tick_params(axis='x', labelsize=8)
    
    plt.suptitle(f'QC Metrics by Cluster (Resolution {selected_res}) - FILTERED DATA - Cluster 3 Removed', 
                 fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    
    # Save plot with filtered suffix
    save_plot(f'qc_metrics_by_cluster_r{selected_res}_filtered', close=False)
    show_inline_plot()
    plt.close()
    
    # ============================================================
    # SUMMARY STATISTICS TABLE
    # ============================================================
    print('\n' + '='*80)
    print(f'QC METRICS SUMMARY BY CLUSTER (Resolution {selected_res}) - FILTERED DATA')
    print('='*80)
    
    # Calculate median values for each metric per cluster
    summary_data = []
    for cluster in adata.obs[cluster_key].cat.categories:
        cluster_mask = adata.obs[cluster_key] == cluster
        cluster_cells = cluster_mask.sum()
        
        row = {'Cluster': cluster, 'N_cells': cluster_cells}
        
        for metric, label in zip(qc_metrics, qc_labels):
            median_val = adata.obs.loc[cluster_mask, metric].median()
            row[label] = f'{median_val:.2f}'
        
        summary_data.append(row)
    
    summary_df = pd.DataFrame(summary_data)
    print(summary_df.to_string(index=False))
    
    # Save summary table with filtered suffix
    summary_file = output_dir / f'qc_summary_by_cluster_r{selected_res}_filtered.csv'
    summary_df.to_csv(summary_file, index=False)
    print(f'\n✓ Saved summary table: {summary_file.name}')
    
    # ============================================================
    # IDENTIFY POTENTIAL LOW-QUALITY CLUSTERS
    # ============================================================
    print('\n' + '='*80)
    print('POTENTIAL QUALITY ISSUES (FILTERED DATA)')
    print('='*80)
    
    # Check for high mitochondrial percentage
    if 'pct_counts_mt' in adata.obs:
        high_mt_threshold = 10  # Adjust as needed
        for cluster in adata.obs[cluster_key].cat.categories:
            cluster_mask = adata.obs[cluster_key] == cluster
            median_mt = adata.obs.loc[cluster_mask, 'pct_counts_mt'].median()
            
            if median_mt > high_mt_threshold:
                print(f'⚠️  Cluster {cluster}: High mt% (median={median_mt:.2f}%)')
    
    # Check for low gene counts
    gene_metric = 'n_genes_by_counts' if 'n_genes_by_counts' in adata.obs else 'n_genes'
    if gene_metric in adata.obs:
        low_genes_threshold = 500  # Adjust as needed
        for cluster in adata.obs[cluster_key].cat.categories:
            cluster_mask = adata.obs[cluster_key] == cluster
            median_genes = adata.obs.loc[cluster_mask, gene_metric].median()
            
            if median_genes < low_genes_threshold:
                print(f'⚠️  Cluster {cluster}: Low gene count (median={median_genes:.0f} genes)')
    
    print('\nIf no warnings appear, all clusters pass basic QC thresholds.')
    print('='*80)

### Metadata-Stratified UMAP Visualizations

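Generate UMAP plots split by mouse and by experimental group (the `mouse` and `group` columns of `adata.obs`) to assess integration quality and identify experimental group patterns. Cells are colored by cluster assignment so that consistency across metadata categories can be evaluated. A second set of plots downsamples every category to the size of the smallest one. This keeps differences in cell number from looking like differences in cluster composition.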
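The downsampling step draws the same number of cells from every category. The sketch below captures that idea in a reusable form. It swaps the notebook's global `np.random.seed` and `np.random.choice` for NumPy's `Generator` API. It also drops empty categories first: `value_counts()` on a categorical column also reports categories with zero cells, which would otherwise make the minimum zero. `downsample_equal` and the toy labels are hypothetical.

```python
import numpy as np
import pandas as pd

def downsample_equal(groups: pd.Series, seed: int = 42) -> np.ndarray:
    """Return sorted positional indices with the same number of cells from every non-empty group."""
    rng = np.random.default_rng(seed)
    counts = groups.value_counts()
    counts = counts[counts > 0]  # categorical value_counts also lists empty categories
    n_per_group = counts.min()
    values = groups.to_numpy()
    picked = [
        rng.choice(np.flatnonzero(values == g), size=n_per_group, replace=False)
        for g in counts.index
    ]
    return np.sort(np.concatenate(picked))

# Toy mouse labels with an unused category 'm4', as can happen after subsetting.
groups = pd.Series(pd.Categorical(["m1"] * 6 + ["m2"] * 3 + ["m3"] * 4,
                                  categories=["m1", "m2", "m3", "m4"]))
idx = downsample_equal(groups)
print(groups.iloc[idx].value_counts())
```

In the notebook, the indices could then be used as `adata[idx].copy()`.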

In [None]:
# Ensure inline plotting in Jupyter
%matplotlib inline

# Select resolution
selected_res = 1.2  # 👈 Change this to any desired resolution

# ============================================================
# USER CONFIGURATION - CONTROL PLOT PARAMETERS (FILTERED DATA)
# ============================================================
# Customize parameters for MOUSE plots
mouse_dot_size = 50
mouse_fig_width = 5
mouse_fig_height = 5.4

# Customize parameters for GROUP plots
group_dot_size = 10
group_fig_width = 5
group_fig_height = 5.4

# ============================================================
# SETUP (FILTERED DATA)
# ============================================================
cluster_key = f'leiden_scVI_r{selected_res}'
metadata_columns = ['mouse', 'group']  # Only mouse and group now

if cluster_key not in adata.obs:
    raise ValueError(f"Clustering column {cluster_key} not found in adata.obs")

for col in metadata_columns:
    if col not in adata.obs:
        print(f"Warning: Metadata column {col} not found in adata.obs - skipping")

configure_plot_style()

# Import for palette
import matplotlib.colors as mcolors

# Set up dittoSeq palette for publication quality
n_clusters = len(adata.obs[cluster_key].cat.categories)
custom_palette = get_dittoseq_colors(n_clusters)
adata.uns[f'{cluster_key}_colors'] = custom_palette

# ============================================================
# FULL-DATASET PLOTS (all filtered cells, no downsampling)
# ============================================================
print("Generating full dataset plots (FILTERED DATA - Cluster 3 Removed)...")
print(f"Using resolution: {selected_res}")

for meta_col in metadata_columns:
    if meta_col not in adata.obs:
        continue
    
    # Get custom parameters for this metadata type
    if meta_col == 'mouse':
        dot_size = mouse_dot_size
        fig_width = mouse_fig_width
        fig_height = mouse_fig_height
    elif meta_col == 'group':
        dot_size = group_dot_size
        fig_width = group_fig_width
        fig_height = group_fig_height
    
    meta_order = adata.obs[meta_col].cat.categories.tolist()
    n_items = len(meta_order)
    n_cols = min(4, n_items)  # Max 4 columns
    n_rows = (n_items + n_cols - 1) // n_cols  # Calculate rows needed
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_width * n_cols, fig_height * n_rows),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes.flatten()  # squeeze=False always returns a 2-D array, so this also covers n_items == 1
    
    for i, item in enumerate(meta_order):
        mask = adata.obs[meta_col] == item
        if mask.sum() > 0:
            sc.pl.umap(
                adata[mask],
                color=cluster_key,
                ax=axes[i],
                show=False,
                title=f'{item} ({mask.sum():,} cells)',
                legend_loc='on data',
                frameon=False,
                size=dot_size,
                palette=adata.uns[f'{cluster_key}_colors']
            )
        else:
            axes[i].set_visible(False)
            print(f'Warning: {meta_col} {item} not found')
    
    # Hide unused subplots
    if n_items > 1:
        for j in range(n_items, len(axes)):
            axes[j].set_visible(False)
    
    plt.suptitle(f'Clusters Split by {meta_col} - Resolution {selected_res} - FILTERED DATA - Cluster 3 Removed', y=1.02, fontsize=18)
    plt.tight_layout()
    plt.savefig(output_dir / 'png' / f'{meta_col}_umaps_filtered.png', dpi=300, bbox_inches='tight')
    plt.savefig(output_dir / 'pdf' / f'{meta_col}_umaps_filtered.pdf', dpi=300, bbox_inches='tight', format='pdf')
    print(f'Saved: {meta_col}_umaps_filtered.png/.pdf')
    plt.close()

print('\n✓ Completed full dataset metadata-stratified visualizations (FILTERED DATA)')

# ============================================================
# DOWNSAMPLED PLOTS (Equal Cells Per Group - FILTERED DATA)
# ============================================================
print("\n" + "="*60)
print("Generating downsampled plots with equal cells per group (FILTERED DATA - Cluster 3 Removed)...")
print(f"Using resolution: {selected_res}")
print("="*60)

for meta_col in metadata_columns:
    if meta_col not in adata.obs:
        continue
    
    # Get custom parameters for this metadata type
    if meta_col == 'mouse':
        dot_size = mouse_dot_size
        fig_width = mouse_fig_width
        fig_height = mouse_fig_height
    elif meta_col == 'group':
        dot_size = group_dot_size
        fig_width = group_fig_width
        fig_height = group_fig_height
    
    print(f"\nProcessing {meta_col}...")
    
    # Get category counts
    meta_order = adata.obs[meta_col].cat.categories.tolist()
    counts = adata.obs[meta_col].value_counts()
    min_cells = counts.min()
    
    print(f"  Cell counts per {meta_col}: {dict(counts)}")
    print(f"  Downsampling to: {min_cells:,} cells per group")
    
    # Downsample to equal number of cells per category
    np.random.seed(42)  # For reproducibility
    downsampled_indices = []
    for item in meta_order:
        item_indices = np.where(adata.obs[meta_col] == item)[0]
        sampled_indices = np.random.choice(item_indices, size=min_cells, replace=False)
        downsampled_indices.extend(sampled_indices)
    
    # Create downsampled AnnData
    adata_downsampled = adata[downsampled_indices].copy()
    
    # Generate UMAP plots with downsampled data
    n_items = len(meta_order)
    n_cols = min(6, n_items)
    n_rows = (n_items + n_cols - 1) // n_cols
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_width * n_cols, fig_height * n_rows),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes.flatten()  # squeeze=False always returns a 2-D array, so this also covers n_items == 1
    
    for i, item in enumerate(meta_order):
        mask = adata_downsampled.obs[meta_col] == item
        if mask.sum() > 0:
            sc.pl.umap(
                adata_downsampled[mask],
                color=cluster_key,
                ax=axes[i],
                show=False,
                title=f'{item} ({mask.sum():,} cells)',
                legend_loc='on data',
                frameon=False,
                size=dot_size * 2,  # Larger dots for downsampled
                palette=adata.uns[f'{cluster_key}_colors']
            )
        else:
            axes[i].set_visible(False)
            print(f'Warning: {meta_col} {item} not found in downsampled data')
    
    # Hide unused subplots
    if n_items > 1:
        for j in range(n_items, len(axes)):
            axes[j].set_visible(False)
    
    plt.suptitle(f'Clusters Split by {meta_col} - Downsampled (n={min_cells:,}/group) - Resolution {selected_res} - FILTERED DATA - Cluster 3 Removed', 
                 y=1.02, fontsize=18)
    plt.tight_layout()
    plt.savefig(output_dir / 'png' / f'{meta_col}_umaps_downsampled_filtered.png', dpi=300, bbox_inches='tight')
    plt.savefig(output_dir / 'pdf' / f'{meta_col}_umaps_downsampled_filtered.pdf', dpi=300, bbox_inches='tight', format='pdf')
    print(f'  Saved: {meta_col}_umaps_downsampled_filtered.png/.pdf')
    plt.close()

print('\n✓ Completed downsampled metadata-stratified visualizations (FILTERED DATA)')

## 8. Gene Expression Visualization

### Gene Expression Dotplot by Cluster

Visualize marker gene expression across clusters using dotplots. This visualization shows:
- **Dot color**: Mean expression level per cluster
- **Dot size**: Fraction of cells expressing the gene
- **Dendrogram**: Hierarchical relationship between clusters
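The two quantities encoded by each dot can be computed directly. With scanpy's defaults (`expression_cutoff=0.0`, `mean_only_expressed=False`), the color is the mean over all cells in the cluster, including zeros. The size is the fraction of cells with expression above 0. The sketch below computes both with pandas on an invented 6-cell matrix. In the notebook cell, `dot_min`/`dot_max` then clip the fraction range mapped to the smallest and largest dots.

```python
import pandas as pd

# Toy normalized expression: 6 cells x 2 genes, with invented cluster labels.
expr = pd.DataFrame({
    "Tmem119": [2.0, 0.0, 1.0, 0.0, 0.0, 0.0],
    "Cd68":    [0.0, 0.0, 0.0, 3.0, 1.0, 2.0],
})
clusters = pd.Series(["0", "0", "0", "1", "1", "1"], name="cluster")

mean_expr = expr.groupby(clusters).mean()         # dot color: mean over all cells, zeros included
frac_expr = (expr > 0).groupby(clusters).mean()   # dot size: fraction of cells with expression > 0
print(mean_expr)
print(frac_expr)
```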

In [None]:
# Ensure inline plotting in Jupyter
%matplotlib inline

# ============================================================
# USER CONFIGURATION
# ============================================================
# Select resolution and plot dimensions
selected_res = 1.2  # 👈 Change this to any desired resolution
plot_width = 6      # 👈 Change figure width (in inches)
plot_height = 15    # 👈 Change figure height (in inches)

# ============================================================
# VERIFY DATA AND LAYERS
# ============================================================

# Verify clustering column and layers
cluster_key = f'leiden_scVI_r{selected_res}'
if cluster_key not in adata.obs:
    raise ValueError(f"Clustering column {cluster_key} not found in adata.obs")

# Check if normalized layer exists, if not use X (which should be normalized)
use_layer = 'normalized' if 'normalized' in adata.layers else None
if use_layer is None:
    print("Note: Using adata.X for gene expression (no 'normalized' layer found)")

# Genes to visualize (mouse gene symbols)
# Microglial state markers followed by general immune cell markers
genes = [
    # Homeostatic Microglia (resting, surveillance state)
    'Tmem119', 'P2ry12', 'Cx3cr1', 'Sall1', 'Fcrls', 'Siglech', 'Hexb',

    # Activated/Pro-inflammatory Microglia (M1-like, immune-activated)
    'Il1b', 'Tnf', 'Cd68', 'Nos2', 'Cd86',

    # Anti-inflammatory/Repair Microglia (M2-like, tissue repair); Ym1 is annotated as Chil3
    'Arg1', 'Mrc1', 'Il10', 'Tgfb1', 'Chil3',

    # Disease-Associated Microglia (DAM, neurodegenerative contexts)
    'Trem2', 'Apoe', 'Cst7', 'Lpl', 'Tyrobp', 'Clec7a',

    # Aged Microglia (altered in aging brain)
    'Ccl2', 'C1qa', 'B2m',

    # Proliferative Microglia (dividing, injury/disease)
    'Mki67', 'Top2a', 'Cdk1', 'Ccna2', 'Birc5',

    # Interferon-Responsive Microglia (viral/interferon response)
    'Ifit1', 'Irf7', 'Stat1', 'Cxcl10', 'Isg15',

    # Phagocytic Microglia (enhanced phagocytosis)
    'Trem2', 'Cd68', 'Mertk', 'Axl', 'C1qa',

    # General Immune Cell Markers
    'Ptprc', 'H2-Aa', 'H2-Ab1',
    'Mrc1', 'Ccr2', 'Ly6g', 'Ly6c2', 'Ms4a7',
    'Cdk8', 'Cmss1', 'Lars2',
    'H2-Q7', 'H2-Aa', 'H2-Ab1', 'H2-Q4', 'Cd74',
    'Ifi27l2a', 'Arhgap15', 'Klf2', 'Cd52', 'Cd34'
]

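# Drop repeated entries (e.g. Trem2, Cd68, C1qa, Mrc1, H2-Aa, H2-Ab1 are listed more than once),
# keeping the first occurrence, then check which genes are present
genes = list(dict.fromkeys(genes))
available_genes = [g for g in genes if g in adata.var_names]
missing_genes = [g for g in genes if g not in adata.var_names]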

if missing_genes:
    print(f'Warning: Genes not found in adata.var_names: {", ".join(missing_genes)}')
if not available_genes:
    raise ValueError('No valid genes for plotting.')

# ============================================================
# CREATE CUSTOM GREY-PLASMA COLORMAP
# ============================================================
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

# Get plasma_r colormap
plasma_r = plt.cm.plasma_r
n_colors = 256
plasma_colors = plasma_r(np.linspace(0, 1, n_colors))

# Replace first 20% with grey gradient
grey_to_plasma_transition = 0.2
n_grey_colors = int(n_colors * grey_to_plasma_transition)

# Create a grey gradient from light grey to the plasma_r color at the transition point (index n_grey_colors)
light_grey = np.array([0.85, 0.85, 0.85, 1.0])
transition_color = plasma_colors[n_grey_colors]
grey_gradient = np.linspace(light_grey, transition_color, n_grey_colors)

# Combine grey gradient with plasma_r
custom_colors = np.vstack([grey_gradient, plasma_colors[n_grey_colors:]])
dotplot_colormap = LinearSegmentedColormap.from_list('grey_plasma', custom_colors)

print('✓ Created custom grey-plasma colormap for dotplot')

# Configure plot style
configure_plot_style()

# ============================================================
# COMPUTE DENDROGRAM FOR CURRENT CLUSTERING
# ============================================================
print(f'Computing dendrogram for {cluster_key}...')
sc.tl.dendrogram(adata, groupby=cluster_key)
print('✓ Dendrogram computed')

# ============================================================
# GENERATE DOTPLOT
# ============================================================

fig, ax = plt.subplots(figsize=(plot_width, plot_height), dpi=300)
dotplot = sc.pl.dotplot(
    adata,
    var_names=available_genes,
    groupby=cluster_key,
    layer=use_layer,
    dendrogram=True,
    ax=ax,
    return_fig=True,
    dot_max=0.8,
    dot_min=0.05,
    colorbar_title='Mean expression\nin group',
    size_title='% of cells\nexpressing gene',
    cmap=dotplot_colormap,  # Use custom grey-plasma colormap
    swap_axes=True,
    var_group_rotation=90
)

# Remove dot borders (use a small positive width such as 0.3 or 0.5 to keep thin outlines)
main_ax = dotplot.get_axes()['mainplot_ax']
for collection in main_ax.collections:
    collection.set_linewidths(0)

# Force straight cluster names with increased font size
main_ax.tick_params(axis='x', labelsize=12, rotation=0)
for label in main_ax.get_xticklabels():
    label.set_rotation(0)
    label.set_ha('center')

# Ensure ticks are visible on both axes
main_ax.tick_params(axis='both', bottom=True, left=True, labelbottom=True, labelleft=True, length=4, direction='out')
main_ax.spines['bottom'].set_visible(True)
main_ax.spines['left'].set_visible(True)

plt.suptitle(f'Gene Expression Dotplot (Normalized, Resolution {selected_res})', y=1.05, fontsize=18)
plt.subplots_adjust(bottom=0.2)
plt.tight_layout()

# Save with resolution in filename
plt.savefig(output_dir / 'png' / f'rna_dotplot_normalized_r{selected_res}.png', dpi=300, bbox_inches='tight')
plt.savefig(output_dir / 'pdf' / f'rna_dotplot_normalized_r{selected_res}.pdf', dpi=300, bbox_inches='tight', format='pdf')
print(f'Saved: rna_dotplot_normalized_r{selected_res}.png/.pdf')
show_inline_plot()
plt.close()
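The same grey-to-plasma colormap is constructed twice in this notebook (for the dotplot and for the gene-expression UMAPs). As a sketch, the construction can live in one helper; `make_grey_plasma_cmap` and its keyword arguments are hypothetical names, and the defaults simply mirror the values used in the cells (20% grey ramp, light grey at 0.85, `plasma_r`).

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap


def make_grey_plasma_cmap(transition=0.2, n_colors=256,
                          grey=(0.85, 0.85, 0.85, 1.0),
                          base='plasma_r', name='grey_plasma'):
    """Hypothetical helper: colormap whose low end fades from light grey into `base`.

    The first `transition` fraction of the base colormap is replaced by a linear
    ramp from `grey` to the base color at the cut point.
    """
    base_colors = plt.get_cmap(base)(np.linspace(0, 1, n_colors))
    n_grey = int(n_colors * transition)
    # Linear RGBA ramp from light grey to the base color at index n_grey
    grey_ramp = np.linspace(np.asarray(grey, dtype=float), base_colors[n_grey], n_grey)
    # Keep the remaining (1 - transition) part of the base colormap unchanged
    colors = np.vstack([grey_ramp, base_colors[n_grey:]])
    return LinearSegmentedColormap.from_list(name, colors)


cmap = make_grey_plasma_cmap()
```

With this helper, `dotplot_colormap = make_grey_plasma_cmap()` and `gene_colormap = make_grey_plasma_cmap()` would yield the same map as the inline code; a smaller `transition` (e.g. `0.1`) shortens the grey ramp.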

### Gene Expression UMAPs

Visualize expression patterns of marker genes on UMAP projections. This helps:
- Validate cluster identities
- Identify cell types based on canonical markers
- Assess quality of normalization and integration

Gene expression is visualized using the log1p-normalized `normalized` layer when it exists; otherwise `adata.X` is used.
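For reference, a log1p-normalized layer usually holds library-size-normalized counts followed by `log1p`, as produced by `sc.pp.normalize_total` and `sc.pp.log1p`. The NumPy sketch below mirrors that recipe, assuming a fixed `target_sum` of 10,000 counts per cell; the integration step may have used a different target (scanpy's default is the median library size), so treat that value as an assumption.

```python
import numpy as np


def log1p_normalize(counts, target_sum=1e4):
    """Sketch: scale each cell (row) to `target_sum` total counts, then apply log(1 + x).

    Cells with zero total counts are left as all-zero rows.
    """
    counts = np.asarray(counts, dtype=float)
    totals = counts.sum(axis=1, keepdims=True)
    # Per-cell scaling factor; 0 for empty cells to avoid division by zero
    scale = np.divide(target_sum, totals, out=np.zeros_like(totals), where=totals > 0)
    return np.log1p(counts * scale)


# Toy example: two cells x three genes (hypothetical counts)
toy_counts = np.array([[1, 3, 0],
                       [0, 0, 0]])
print(log1p_normalize(toy_counts, target_sum=4))
```

Because of the `log1p`, color scales on these UMAPs are compressed at the high end, which is why percentile limits such as `vmax='p99'` are useful.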

In [None]:
# Ensure inline plotting in Jupyter
%matplotlib inline

# Genes to visualize (mouse gene symbols)
# Classic marker genes for major cell types
genes = [
    # Homeostatic Microglia (resting, surveillance state)
    'Tmem119', 'P2ry12', 'Cx3cr1', 'Sall1', 'Fcrls', 'Siglech', 'Hexb',

    # Activated/Pro-inflammatory Microglia (M1-like, immune-activated)
    'Il1b', 'Tnf', 'Cd68', 'Nos2', 'Cd86',

    # Anti-inflammatory/Repair Microglia (M2-like, tissue repair)
    'Arg1', 'Mrc1', 'Il10', 'Tgfb1', 'Chil3',  # Chil3 = Ym1 (official mouse symbol)

    # Disease-Associated Microglia (DAM, neurodegenerative contexts)
    'Trem2', 'Apoe', 'Cst7', 'Lpl', 'Tyrobp', 'Clec7a',

    # Aged Microglia (altered in aging brain)
    'Ccl2', 'C1qa', 'B2m',

    # Proliferative Microglia (dividing, injury/disease)
    'Mki67', 'Top2a', 'Cdk1', 'Ccna2', 'Birc5',

    # Interferon-Responsive Microglia (viral/interferon response)
    'Ifit1', 'Irf7', 'Stat1', 'Cxcl10', 'Isg15',

    # Phagocytic Microglia (enhanced phagocytosis)
    'Trem2', 'Cd68', 'Mertk', 'Axl', 'C1qa',

    # General Immune Cell Markers
    'Ptprc','H2-Aa', 'H2-Ab1',  # Ptprc = CD45 (all immune cells)
    'Mrc1', 'Ccr2', 'Ly6g', 'Ly6c2', 'Ms4a7',
    'Cdk8', 'Cmss1', 'Lars2',
    'H2-Q7', 'H2-Aa', 'H2-Ab1', 'H2-Q4', 'Cd74'
]

# Verify layers and UMAP
use_layer = 'normalized' if 'normalized' in adata.layers else None
if use_layer is None:
    print("Note: Using adata.X for gene expression (no 'normalized' layer found)")
if 'X_umap' not in adata.obsm:
    raise ValueError("X_umap not found in adata.obsm")

# ============================================================
# CREATE CUSTOM GREY-PLASMA COLORMAP
# ============================================================
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

# Get plasma_r colormap
plasma_r = plt.cm.plasma_r
n_colors = 256
plasma_colors = plasma_r(np.linspace(0, 1, n_colors))

# Replace first 20% with grey gradient
grey_to_plasma_transition = 0.2
n_grey_colors = int(n_colors * grey_to_plasma_transition)

# Create grey gradient from light grey to the plasma_r color at the cut point
light_grey = np.array([0.85, 0.85, 0.85, 1.0])
transition_color = plasma_colors[n_grey_colors]
grey_gradient = np.linspace(light_grey, transition_color, n_grey_colors)

# Combine grey gradient with plasma_r
custom_colors = np.vstack([grey_gradient, plasma_colors[n_grey_colors:]])
gene_colormap = LinearSegmentedColormap.from_list('grey_plasma', custom_colors)

print('✓ Created custom grey-plasma colormap for gene expression')

# Configure plot style
configure_plot_style()

# ============================================================
# GENERATE GENE EXPRESSION UMAPs
# ============================================================

# Check available genes (drop repeated symbols so each gene gets one panel)
available_genes_norm = [g for g in dict.fromkeys(genes) if g in adata.var_names]
missing_genes_norm = [g for g in dict.fromkeys(genes) if g not in adata.var_names]

if missing_genes_norm:
    print(f'Warning: Genes not found in adata.var_names: {", ".join(missing_genes_norm)}')
if not available_genes_norm:
    raise ValueError('No valid genes for plotting.')

# Set up plot dimensions
n_cols = 7  # Adjust accordingly
n_rows = (len(available_genes_norm) + n_cols - 1) // n_cols  # Calculate rows needed
fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), sharex=True, sharey=True)
axes = axes.flatten()

for i, gene in enumerate(available_genes_norm):
    sc.pl.umap(
        adata,
        color=gene,
        layer=use_layer,
        ax=axes[i],
        show=False,
        legend_loc='none',
        cmap=gene_colormap,  # Use custom grey-plasma colormap
        frameon=False,
        size=10,
        vmin='p5',
        vmax='p99'
    )
    axes[i].set_title(gene, fontsize=20)  # Increased font size for legibility

for j in range(len(available_genes_norm), len(axes)):
    axes[j].set_visible(False)

plt.suptitle('Gene Expression UMAPs', y=1.05, fontsize=18)
plt.tight_layout()
plt.savefig(output_dir / 'png' / 'gene_expression_umaps.png', dpi=300, bbox_inches='tight')
plt.savefig(output_dir / 'pdf' / 'gene_expression_umaps.pdf', dpi=300, bbox_inches='tight', format='pdf')
print(f'Saved: gene_expression_umaps.png/.pdf')
show_inline_plot()
plt.close()
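Both gene-list cells above only report which symbols are missing. Hand-typed lists are prone to a few recurring slips: stray whitespace (e.g. `' Arhgap15'`), a capital `I` typed for a lowercase `l` (e.g. `Ifi27I2a` vs `Ifi27l2a`), and repeated entries. A sketch of a stricter check using only the standard library; `check_genes` is a hypothetical helper, and the 0.85 similarity cutoff is an arbitrary assumption. Aliases such as `Ym1` (official mouse symbol `Chil3`) are not caught by string similarity and still need a manual lookup.

```python
import difflib


def check_genes(genes, var_names, cutoff=0.85):
    """Hypothetical helper: split a gene list into (available, missing, suggestions).

    - strips surrounding whitespace from each symbol
    - drops repeated symbols while keeping first-seen order
    - for missing symbols, suggests a case-insensitive close match from var_names
    """
    var_set = set(var_names)
    lower_map = {v.lower(): v for v in var_names}
    available, missing = [], []
    for g in dict.fromkeys(s.strip() for s in genes):
        (available if g in var_set else missing).append(g)
    suggestions = {}
    for g in missing:
        hits = difflib.get_close_matches(g.lower(), list(lower_map), n=1, cutoff=cutoff)
        if hits:
            suggestions[g] = lower_map[hits[0]]
    return available, missing, suggestions


avail, miss, sugg = check_genes(
    ['Cd68', 'Trem2', 'Cd68', ' Arhgap15', 'Ifi27I2a', 'Ym1'],
    ['Cd68', 'Trem2', 'Arhgap15', 'Ifi27l2a', 'Chil3'],
)
print(avail, miss, sugg)
```

In the notebook, `check_genes(genes, adata.var_names)` could stand in for the two list comprehensions, with `suggestions` printed next to the missing-gene warning.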

## Export Cell Counts by Cluster and Metadata


In [None]:
# ============================================================
# EXPORT CELL COUNTS: CLUSTERS × METADATA
# ============================================================

# User configuration
selected_res = 1.2  # Match your clustering resolution
cluster_key = f'leiden_scVI_r{selected_res}'
groupby_col = 'group'  # Change to 'library', 'mouse_ID', or other metadata column

print('='*80)
print(f'EXPORTING CELL COUNTS PER CLUSTER AND {groupby_col.upper()} (Resolution {selected_res})')
print('='*80)

# Check if column exists
if groupby_col not in adata.obs:
    print(f'⚠ Column "{groupby_col}" not found in adata.obs')
    print(f'Available columns: {list(adata.obs.columns)}')
else:
    # Create crosstab: clusters (rows) × groupby column (columns)
    count_table = pd.crosstab(
        adata.obs[cluster_key], 
        adata.obs[groupby_col],
        margins=True,  # Add row and column totals
        margins_name='Total'
    )
    
    # Sort clusters numerically
    cluster_order = sorted([c for c in count_table.index if c != 'Total'], 
                           key=lambda x: int(x) if str(x).isdigit() else float('inf'))
    cluster_order.append('Total')
    count_table = count_table.reindex(cluster_order)
    
    # Display preview
    print('\nPreview of count table:')
    print(count_table.head(10))
    print(f'\nTable shape: {count_table.shape[0]} rows × {count_table.shape[1]} columns (including Total margins)')
    
    # Export to Excel
    excel_file = output_dir / f'cell_counts_cluster_by_{groupby_col}_r{selected_res}.xlsx'
    
    with pd.ExcelWriter(excel_file, engine='openpyxl') as writer:
        # Sheet 1: Raw counts
        count_table.to_excel(writer, sheet_name='Cell_Counts')
        
        # Sheet 2: Percentages (% of each group's cells that fall in each cluster)
        pct_by_group = pd.crosstab(
            adata.obs[cluster_key], 
            adata.obs[groupby_col],
            normalize='columns'  # Normalize by column (each group sums to 100%)
        ) * 100
        pct_by_group = pct_by_group.reindex([c for c in cluster_order if c != 'Total'])
        pct_by_group.to_excel(writer, sheet_name=f'Percent_by_{groupby_col}')
        
        # Sheet 3: Percentages (% of each cluster's cells that come from each group)
        pct_by_cluster = pd.crosstab(
            adata.obs[cluster_key], 
            adata.obs[groupby_col],
            normalize='index'  # Normalize by row (each cluster sums to 100%)
        ) * 100
        pct_by_cluster = pct_by_cluster.reindex([c for c in cluster_order if c != 'Total'])
        pct_by_cluster.to_excel(writer, sheet_name='Percent_by_Cluster')
    
    print(f'\n✓ Saved Excel file: {excel_file.name}')
    print(f'  Contains 3 sheets:')
    print(f'    1. Cell_Counts - Raw cell counts')
    print(f"    2. Percent_by_{groupby_col} - % of each group's cells in each cluster (columns sum to 100%)")
    print(f"    3. Percent_by_Cluster - % of each cluster's cells from each group (rows sum to 100%)")

print('='*80)

# ============================================================
# BARPLOT: GROUP COUNTS (Total cells per group)
# ============================================================
print('\n' + '='*80)
print('GENERATING BARPLOT: GROUP COUNTS')
print('='*80)

if 'group' in adata.obs:
    configure_plot_style()
    
    # Get total cell counts per group
    group_counts = adata.obs['group'].value_counts().sort_index()
    
    # Define publication-quality colors for groups
    group_colors = ['#1f77b4', '#ff7f0e']  # Blue and orange (publication quality)
    
    # Create figure
    fig, ax = plt.subplots(figsize=(8, 6))
    
    # Create bar plot with two colors
    bars = ax.bar(range(len(group_counts)), group_counts.values, 
                  color=group_colors[:len(group_counts)])
    
    # Set x-axis labels
    ax.set_xticks(range(len(group_counts)))
    ax.set_xticklabels(group_counts.index, rotation=0, ha='center')
    
    # Add value labels on top of bars
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{int(height):,}',
                ha='center', va='bottom', fontsize=11, fontweight='bold')
    
    ax.set_title(f'Cell Counts by Group (Resolution {selected_res})', 
                 fontsize=14, fontweight='bold', pad=15)
    ax.set_xlabel('Group', fontsize=12)
    ax.set_ylabel('Cell Count', fontsize=12)
    
    plt.tight_layout()
    plt.savefig(output_dir / 'png' / f'cell_counts_barplot_group_r{selected_res}.png', 
                dpi=300, bbox_inches='tight')
    plt.savefig(output_dir / 'pdf' / f'cell_counts_barplot_group_r{selected_res}.pdf', 
                dpi=300, bbox_inches='tight', format='pdf')
    print(f'✓ Saved barplot: cell_counts_barplot_group_r{selected_res}.png/.pdf')
    show_inline_plot()
    plt.close()

# ============================================================
# BARPLOT: MOUSE COUNTS (Total cells per mouse, colored by group)
# ============================================================
print('\n' + '='*80)
print('GENERATING BARPLOT: MOUSE COUNTS')
print('='*80)

if 'mouse' in adata.obs and 'group' in adata.obs:
    configure_plot_style()
    
    # Get total cell counts per mouse
    mouse_counts = adata.obs['mouse'].value_counts().sort_index()
    
    # Map each mouse to its group for coloring
    mouse_to_group = adata.obs.groupby('mouse', observed=True)['group'].first()
    
    # Define publication-quality colors matching group colors
    group_colors = ['#1f77b4', '#ff7f0e']  # Blue and orange
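    unique_groups = adata.obs['group'].astype('category').cat.categories.tolist()
    # Cycle through the palette if there are more groups than colors
    group_color_map = {group: group_colors[i % len(group_colors)] for i, group in enumerate(unique_groups)}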
    
    # Get colors for each mouse based on their group
    mouse_colors = [group_color_map.get(mouse_to_group.get(mouse, unique_groups[0])) for mouse in mouse_counts.index]
    
    # Create figure
    fig, ax = plt.subplots(figsize=(10, 6))
    
    # Create bar plot with group-based colors
    bars = ax.bar(range(len(mouse_counts)), mouse_counts.values, color=mouse_colors)
    
    # Set x-axis labels
    ax.set_xticks(range(len(mouse_counts)))
    ax.set_xticklabels(mouse_counts.index, rotation=45, ha='right')
    
    # Add value labels on top of bars
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{int(height):,}',
                ha='center', va='bottom', fontsize=10, fontweight='bold')
    
    ax.set_title(f'Cell Counts by Mouse (Resolution {selected_res})', 
                 fontsize=14, fontweight='bold', pad=15)
    ax.set_xlabel('Mouse', fontsize=12)
    ax.set_ylabel('Cell Count', fontsize=12)
    
    plt.tight_layout()
    plt.savefig(output_dir / 'png' / f'cell_counts_barplot_mouse_r{selected_res}.png', 
                dpi=300, bbox_inches='tight')
    plt.savefig(output_dir / 'pdf' / f'cell_counts_barplot_mouse_r{selected_res}.pdf', 
                dpi=300, bbox_inches='tight', format='pdf')
    print(f'✓ Saved barplot: cell_counts_barplot_mouse_r{selected_res}.png/.pdf')
    show_inline_plot()
    plt.close()

print('='*80)
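The two percentage sheets in the cell above normalize the same crosstab in different directions, which is easy to mix up when reading the Excel file. A toy example with hypothetical cells makes the difference concrete: `normalize='columns'` gives each group's distribution across clusters, while `normalize='index'` gives each cluster's group composition.

```python
import pandas as pd

# Toy metadata: 6 hypothetical cells, 3 clusters, 2 groups
obs = pd.DataFrame({
    'cluster': ['0', '0', '1', '1', '1', '2'],
    'group':   ['A', 'B', 'A', 'A', 'B', 'B'],
})

# Like sheet "Percent_by_<group>": each group's cells spread over clusters (columns sum to 100)
pct_by_group = pd.crosstab(obs['cluster'], obs['group'], normalize='columns') * 100

# Like sheet "Percent_by_Cluster": group composition of each cluster (rows sum to 100)
pct_by_cluster = pd.crosstab(obs['cluster'], obs['group'], normalize='index') * 100

print(pct_by_group.round(1))
print(pct_by_cluster.round(1))
```

Cluster 2 illustrates the contrast: it holds one third of group B's cells, yet it consists entirely of group B cells.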

## 9. Cluster Frequency Comparison Between Groups

### Overview
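This section compares cluster frequencies between two groups of a chosen metadata category, using individual mice as biological replicates. Cells are first restricted to a single batch.

**Features**:
- Batch, metadata category, the two groups, and the replicate column are set in the configuration block at the top of the cell
- Per-mouse cluster frequencies (% of each mouse's cells in each cluster)
- Statistical testing (Mann-Whitney U test) for each cluster, comparing per-mouse frequencies between groups
- Publication-quality plot (mean ± SEM) with significance markers (*, **, ***)
- Excel export with per-mouse counts and frequencies, group summaries, and exact p-values

**Statistical Approach**:
- **Mann-Whitney U Test**: Tests whether the per-mouse frequency of each cluster differs between groups. Using mice rather than cells as the unit of replication avoids inflating significance with thousands of non-independent cells.
- **Significance Levels**: `* p<0.05`, `** p<0.01`, `*** p<0.001`
- **Small Samples**: With few mice per group, the exact test has a floor on the smallest attainable p-value (0.1 for 3 vs 3 mice), so a non-significant result may reflect limited power rather than the absence of an effect.
- **Multiple Testing**: p-values are not adjusted; consider Bonferroni or FDR (Benjamini-Hochberg) correction when comparing many clusters.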
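The cell below reports unadjusted per-cluster p-values. One way to apply an FDR correction is the Benjamini-Hochberg procedure, sketched here in plain NumPy so that clusters with `NaN` p-values (too few mice to test) pass through untouched. `bh_adjust` is a hypothetical helper; recent SciPy releases (1.11 and later) also provide `scipy.stats.false_discovery_control` for the same procedure.

```python
import numpy as np


def bh_adjust(pvals):
    """Hypothetical helper: Benjamini-Hochberg adjusted p-values; NaNs stay NaN."""
    p = np.asarray(pvals, dtype=float)
    adjusted = np.full(p.shape, np.nan)
    valid = ~np.isnan(p)
    m = int(valid.sum())
    if m == 0:
        return adjusted
    pv = p[valid]
    order = np.argsort(pv)
    # p_(i) * m / i for the sorted p-values
    scaled = pv[order] * m / np.arange(1, m + 1)
    # Enforce monotonicity: running minimum from the largest p-value downward
    scaled = np.minimum.accumulate(scaled[::-1])[::-1]
    adj_valid = np.empty(m)
    adj_valid[order] = np.minimum(scaled, 1.0)
    adjusted[valid] = adj_valid
    return adjusted


# Hypothetical per-cluster p-values; the last cluster could not be tested
example = bh_adjust([0.01, 0.04, 0.03, 0.2, np.nan])
print(example)
```

After `stat_df` is built, `stat_df['p_value_adj'] = bh_adjust(stat_df['p_value'])` would add an adjusted column; significance stars could then be derived from it instead of the raw p-values.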

In [None]:
# Ensure inline plotting
%matplotlib inline

# Import scipy.stats for statistical tests
import scipy.stats as stats

# ============================================================
# CONFIGURATION
# ============================================================
selected_res = 1.2                    # Clustering resolution
filter_batch_value = 'Mistri'         # Batch to filter to
comparison_category = 'group'          # Category to compare (e.g., 'group')
group_1 = 'PBS-CFA'                    # First group
group_2 = 'PEP15-CFA'                  # Second group
replicate_column = 'mouse'             # Column containing biological replicates

# Plot colors (optional customization)
bar_colors = {group_1: '#3C5488', group_2: '#E64B35'}

# ============================================================
# FILTER DATA
# ============================================================
cluster_key = f'leiden_scVI_r{selected_res}'

# Filter to specific batch
batch_mask = adata.obs['batch'] == filter_batch_value
adata_filtered = adata[batch_mask].copy()

print("="*80)
print(f"CLUSTER FREQUENCY COMPARISON WITH BIOLOGICAL REPLICATES")
print("="*80)
print(f"\nBatch: {filter_batch_value}")
print(f"Comparing: {group_1} vs {group_2}")
print(f"Biological replicates: {replicate_column}")
print(f"Resolution: {selected_res}")

# Get masks for each group
mask_g1 = adata_filtered.obs[comparison_category] == group_1
mask_g2 = adata_filtered.obs[comparison_category] == group_2

# Get unique mice in each group
mice_g1 = adata_filtered.obs[mask_g1][replicate_column].unique()
mice_g2 = adata_filtered.obs[mask_g2][replicate_column].unique()

print(f"\n{group_1} mice: {list(mice_g1)} (n={len(mice_g1)})")
print(f"{group_2} mice: {list(mice_g2)} (n={len(mice_g2)})")
print("="*80)

# ============================================================
# STATISTICAL TEST SELECTION
# ============================================================
# Mann-Whitney U test (non-parametric) is used for all comparisons: it does not
# assume normally distributed frequencies and is robust to outliers. Note that with
# very few mice per group the exact test cannot produce small p-values
# (for 3 vs 3 mice, the smallest attainable two-sided p-value is 0.1).
statistical_test = "Mann-Whitney U test (non-parametric)"
print(f"\nStatistical test: {statistical_test}")
print(f"Rationale: no normality assumption, robust to outliers "
      f"(n={len(mice_g1)} vs n={len(mice_g2)} mice)")

# Get all clusters
all_clusters = adata_filtered.obs[cluster_key].cat.categories

# ============================================================
# CALCULATE PER-MOUSE FREQUENCIES
# ============================================================

# Store per-mouse data
per_mouse_data = []

# Loop over both groups, then over every mouse (biological replicate) in each group
for group_name, group_mask, group_mice in [(group_1, mask_g1, mice_g1),
                                            (group_2, mask_g2, mice_g2)]:
    for mouse in group_mice:
        mouse_mask = group_mask & (adata_filtered.obs[replicate_column] == mouse)
        n_cells_mouse = mouse_mask.sum()

        for cluster in all_clusters:
            cluster_count = ((adata_filtered.obs[cluster_key] == cluster) & mouse_mask).sum()
            cluster_freq = (cluster_count / n_cells_mouse * 100) if n_cells_mouse > 0 else 0

            per_mouse_data.append({
                'Group': group_name,
                'Mouse': mouse,
                'Cluster': cluster,
                'Cell_Count': cluster_count,
                'Total_Cells': n_cells_mouse,
                'Frequency_%': cluster_freq
            })

per_mouse_df = pd.DataFrame(per_mouse_data)

print("\n✓ Calculated per-mouse cluster frequencies")
print(f"  Data points: {len(per_mouse_df)} (clusters × mice)")
print(f"  Mean cells per mouse (Group 1): {per_mouse_df[per_mouse_df['Group']==group_1]['Total_Cells'].mean():.0f}")
print(f"  Mean cells per mouse (Group 2): {per_mouse_df[per_mouse_df['Group']==group_2]['Total_Cells'].mean():.0f}")

# ============================================================
# STATISTICAL TESTING (MANN-WHITNEY U TEST)
# ============================================================
print(f"\nPerforming {statistical_test}...")

stat_results = []

for cluster in all_clusters:
    # Get frequencies for this cluster from all mice in each group
    freq_g1 = per_mouse_df[(per_mouse_df['Group'] == group_1) & 
                            (per_mouse_df['Cluster'] == cluster)]['Frequency_%'].values
    freq_g2 = per_mouse_df[(per_mouse_df['Group'] == group_2) & 
                            (per_mouse_df['Cluster'] == cluster)]['Frequency_%'].values
    
    # Calculate mean and SEM (sample SD with ddof=1, divided by sqrt(n)) for each group
    mean_g1 = freq_g1.mean()
    sem_g1 = freq_g1.std(ddof=1) / np.sqrt(len(freq_g1))
    mean_g2 = freq_g2.mean()
    sem_g2 = freq_g2.std(ddof=1) / np.sqrt(len(freq_g2))
    
    # Perform Mann-Whitney U test (non-parametric)
    if len(freq_g1) > 1 and len(freq_g2) > 1:
        statistic, p_val = stats.mannwhitneyu(freq_g1, freq_g2, alternative='two-sided')
    else:
        p_val = np.nan
        statistic = np.nan
    
    # Determine significance
    if np.isnan(p_val):
        sig = 'NA'
    elif p_val < 0.001:
        sig = '***'
    elif p_val < 0.01:
        sig = '**'
    elif p_val < 0.05:
        sig = '*'
    else:
        sig = 'ns'
    
    stat_results.append({
        'Cluster': cluster,
        f'{group_1}_mean_%': mean_g1,
        f'{group_1}_SEM_%': sem_g1,
        f'{group_1}_n_mice': len(freq_g1),
        f'{group_2}_mean_%': mean_g2,
        f'{group_2}_SEM_%': sem_g2,
        f'{group_2}_n_mice': len(freq_g2),
        'p_value': p_val,
        'U_statistic': statistic,
        'significance': sig
    })
    
    print(f"Cluster {cluster}: {group_1}={mean_g1:.2f}±{sem_g1:.2f}%, "
          f"{group_2}={mean_g2:.2f}±{sem_g2:.2f}%, p={p_val:.4e}, {sig}")

stat_df = pd.DataFrame(stat_results)

# ============================================================
# EXPORT TO EXCEL
# ============================================================
timestamp = pd.Timestamp.now().strftime('%Y%m%d')
excel_file = f'cluster_freq_by_mouse_{filter_batch_value}_{group_1}_vs_{group_2}_r{selected_res}_{timestamp}.xlsx'
excel_path = output_dir / excel_file

with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
    # Sheet 1: Per-mouse data (long format)
    per_mouse_df.to_excel(writer, sheet_name='Per_Mouse_Data', index=False)
    
    # Sheet 2: Statistical comparison
    stat_df.to_excel(writer, sheet_name='Statistical_Comparison', index=False)
    
    # Sheet 3: Per-mouse cell counts pivoted (wide format)
    pivot_counts = per_mouse_df.pivot_table(
        index='Cluster',
        columns=['Group', 'Mouse'],
        values='Cell_Count',
        fill_value=0
    )
    pivot_counts.to_excel(writer, sheet_name='Cell_Counts_by_Mouse')
    
    # Sheet 4: Per-mouse frequencies pivoted (wide format)
    pivot_freq = per_mouse_df.pivot_table(
        index='Cluster', 
        columns=['Group', 'Mouse'], 
        values='Frequency_%', 
        fill_value=0
    )
    pivot_freq.to_excel(writer, sheet_name='Frequencies_by_Mouse')
    
    # Sheet 5: Summary with statistical test info
    pd.DataFrame({
        'Parameter': ['Batch', 'Group 1', 'Group 2', 'Replicate Column', 
                      f'{group_1}_n_mice', f'{group_2}_n_mice',
                      'Resolution', 'Statistical_Test', 'Significant_clusters', 'Date'],
        'Value': [filter_batch_value, group_1, group_2, replicate_column,
                  len(mice_g1), len(mice_g2), selected_res, 
                  statistical_test, sum(stat_df['p_value'] < 0.05),
                  pd.Timestamp.now().strftime('%Y-%m-%d %H:%M')]
    }).to_excel(writer, sheet_name='Summary', index=False)

print(f'\n✓ Saved: {excel_file}')

# ============================================================
# PLOT
# ============================================================
configure_plot_style()

n_clusters = len(all_clusters)
fig, ax = plt.subplots(figsize=(max(12, n_clusters*0.6), 7), dpi=300)

x = np.arange(n_clusters)
width = 0.35

# Use mean frequencies
mean_g1 = stat_df[f'{group_1}_mean_%'].values
mean_g2 = stat_df[f'{group_2}_mean_%'].values
sem_g1 = stat_df[f'{group_1}_SEM_%'].values
sem_g2 = stat_df[f'{group_2}_SEM_%'].values

# Create bars with error bars (SEM)
ax.bar(x - width/2, mean_g1, width, label=group_1, color=bar_colors[group_1], 
       alpha=0.85, edgecolor='black', linewidth=0.8, yerr=sem_g1, capsize=3)
ax.bar(x + width/2, mean_g2, width, label=group_2, color=bar_colors[group_2], 
       alpha=0.85, edgecolor='black', linewidth=0.8, yerr=sem_g2, capsize=3)

# Add significance markers
for i, row in stat_df.iterrows():
    if row['significance'] not in ['ns', 'NA']:
        y_pos = max(mean_g1[i] + sem_g1[i], mean_g2[i] + sem_g2[i]) + 1.0
        ax.plot([i - width/2, i + width/2], [y_pos, y_pos], lw=1.2, c='black')
        ax.text(i, y_pos + 0.3, row['significance'], ha='center', va='bottom', 
                fontsize=11, fontweight='bold')

ax.set_xlabel('Cluster', fontsize=14, fontweight='bold')
ax.set_ylabel('Frequency (% ± SEM)', fontsize=14, fontweight='bold')
ax.set_title(f'{group_1} vs {group_2} ({filter_batch_value} batch, r={selected_res})\n'
             f'Biological replicates: {len(mice_g1)} vs {len(mice_g2)} mice', 
             fontsize=16, fontweight='bold', pad=20)
ax.set_xticks(x)
ax.set_xticklabels(all_clusters)
ax.legend(title='Group', frameon=True, edgecolor='black')
ax.yaxis.grid(True, linestyle='--', alpha=0.3)
ax.set_axisbelow(True)
sns.despine()

if n_clusters > 15:
    ax.tick_params(axis='x', rotation=45, labelsize=10)

# Add caption with statistical test
fig.text(0.5, -0.02, 
         f"Statistical analysis: {statistical_test}. Cluster frequencies were calculated per mouse "
         f"and compared between {group_1} (n={len(mice_g1)}) and {group_2} (n={len(mice_g2)}) using "
         f"the Mann-Whitney U test. * p<0.05, ** p<0.01, *** p<0.001. "
         f"Error bars show SEM (standard error of the mean).",
         ha='center', fontsize=10, style='italic')

plt.tight_layout()
plot_file = f'cluster_freq_by_mouse_{filter_batch_value}_{group_1}_vs_{group_2}_r{selected_res}'
save_plot(plot_file, close=False)
show_inline_plot()
plt.close()

# Summary
n_sig = sum(stat_df['p_value'] < 0.05)
print(f'\n✓ {n_sig}/{n_clusters} clusters significantly different (unadjusted p<0.05)')
if n_sig > 0:
    print('\nSignificant clusters:')
    for _, row in stat_df[stat_df['p_value'] < 0.05].iterrows():
        print(f"  Cluster {row['Cluster']}: "
              f"{group_1}={row[f'{group_1}_mean_%']:.2f}±{row[f'{group_1}_SEM_%']:.2f}%, "
              f"{group_2}={row[f'{group_2}_mean_%']:.2f}±{row[f'{group_2}_SEM_%']:.2f}%, "
              f"p={row['p_value']:.4e} {row['significance']}")

print("\n" + "="*80)
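With only a few mice per group, the exact Mann-Whitney U test has a floor on the smallest p-value it can return: even complete separation of the groups gives a two-sided p of 2 / C(n1 + n2, n1), which is 2/20 = 0.1 for 3 vs 3 mice. The sketch below checks this with SciPy (the `method='exact'` argument assumes SciPy 1.7 or later); `min_two_sided_p` is a hypothetical helper and the frequencies are made up.

```python
from math import comb

from scipy import stats


def min_two_sided_p(n1, n2):
    """Hypothetical helper: smallest two-sided p-value of the exact Mann-Whitney U
    test for sample sizes n1 and n2 (no ties), reached at complete separation."""
    return min(1.0, 2 / comb(n1 + n2, n1))


# Hypothetical per-mouse frequencies (%) with complete separation, 3 vs 3 mice
freq_a = [1.0, 2.0, 3.0]
freq_b = [4.0, 5.0, 6.0]
res = stats.mannwhitneyu(freq_a, freq_b, alternative='two-sided', method='exact')
print(f"3 vs 3, complete separation: p = {res.pvalue:.3f}")

for n in (3, 4, 5, 6):
    print(f"{n} vs {n} mice: smallest possible two-sided p = {min_two_sided_p(n, n):.4f}")
```

When the floor is above 0.05, no cluster can be marked significant regardless of effect size, so the per-mouse means and SEMs are the more informative part of the output.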

## Marker Gene Identification

### Find Differentially Expressed Genes for Each Cluster

This section identifies marker genes (differentially expressed genes) for each cluster using scanpy's `rank_genes_groups` function, which is analogous to Seurat's `FindAllMarkers`: each cluster is compared against all remaining cells.

**Methods available:**
- **Wilcoxon rank-sum test** (non-parametric, robust for single-cell data) - Recommended
- **t-test** (parametric, assumes normal distribution)
- **Logreg** (logistic regression, multivariate; returns scores but no p-values, so p-value filters are skipped)

**Output:**
- Excel file with multiple sheets:
  - Top marker genes per cluster
  - Full ranked gene list
  - Summary statistics (log fold-change, p-values, adjusted p-values)
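The cell below converts the full expression matrix to a dense array before computing pct.1 and pct.2, which can require many gigabytes of memory for large datasets. A sketch that keeps sparse input sparse is shown here; `pct_expressing` is a hypothetical helper, and it assumes a non-negative expression threshold so that implicit zeros never count as expressed. Unlike the cell, it only returns clusters that contain at least one cell.

```python
import numpy as np
import scipy.sparse as sp


def pct_expressing(X, labels, threshold=0.0):
    """Hypothetical helper: {label: (pct_in, pct_out)} per-gene fractions of cells > threshold.

    Works for dense arrays and scipy sparse matrices without densifying.
    Assumes threshold >= 0, so implicit zeros in a sparse matrix are never 'expressed'.
    """
    labels = np.asarray(labels)
    if sp.issparse(X):
        expressed = sp.csr_matrix(X).copy()
        # Replace stored values by 1.0 (expressed) or 0.0 (not expressed)
        expressed.data = (expressed.data > threshold).astype(np.float64)
    else:
        expressed = (np.asarray(X) > threshold).astype(np.float64)
    result = {}
    for lab in np.unique(labels):
        in_mask = labels == lab
        pct_in = np.asarray(expressed[in_mask].mean(axis=0)).ravel()
        pct_out = np.asarray(expressed[~in_mask].mean(axis=0)).ravel()
        result[lab] = (pct_in, pct_out)
    return result


# Toy data: 4 cells x 3 genes, 2 clusters (hypothetical values)
X = np.array([[0, 1, 2], [0, 0, 3], [1, 0, 0], [0, 0, 5]])
labels = ['a', 'a', 'b', 'b']
dense = pct_expressing(X, labels)
sparse = pct_expressing(sp.csr_matrix(X), labels)
print(dense['a'])
```

In the notebook, `pct_expressing(expr_matrix, adata.obs[cluster_key].astype(str).values, expr_threshold)` would give arrays aligned with `adata.var_names`, ready to be mapped onto the marker table.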

In [None]:
# ============================================================
# MARKER GENE IDENTIFICATION - RESOLUTION 1.2 (with pct.1/pct.2)
# ============================================================

import pandas as pd
import numpy as np

# ============================================================
# CONFIGURATION
# ============================================================
selected_res = 1.2
cluster_key = f'leiden_scVI_r{selected_res}'  # Use resolution-based clustering
n_top_genes = 100                            # Number of top genes to export per cluster
method = 'wilcoxon'                         # Method: 'wilcoxon', 't-test', or 'logreg'
filter_by_fc = True                         # Filter by fold-change
min_fold_change = 0.2                       # Minimum log2 fold-change (scanpy logfoldchanges are log2)
max_pval = 0.05                             # Maximum adjusted p-value

# Additional Seurat-style filters (optional)
min_pct = 0.1                               # Minimum pct.1 (expression in cluster)
min_diff_pct = 0.1                          # Minimum difference between pct.1 and pct.2

# Expression threshold for calculating percentages
expr_threshold = 0.0                        # Count cells with expression > this value

print("="*80)
print(f"MARKER GENE IDENTIFICATION (Seurat-style)")
print("="*80)
print(f"Clustering: {cluster_key}")
print(f"Method: {method}")
print(f"Expression threshold for pct calculations: > {expr_threshold}")
print(f"\nFilters:")
print(f"  - avg_log2FC > {min_fold_change}")
print(f"  - p_val_adj < {max_pval}")
print(f"  - pct.1 > {min_pct}")
print(f"  - (pct.1 - pct.2) > {min_diff_pct}")
print("="*80)

# Verify clustering exists
if cluster_key not in adata.obs:
    raise ValueError(f"Clustering '{cluster_key}' not found in adata.obs")

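# ============================================================
# RUN DIFFERENTIAL EXPRESSION ANALYSIS
# ============================================================
print(f"\nRunning {method} test for all clusters...")

# Test on the log-normalized layer when available (the same data used for pct.1/pct.2);
# scanpy's logfoldchanges assume log1p-normalized input, which adata.X may not hold
de_layer = 'normalized' if 'normalized' in adata.layers else None

# Run rank_genes_groups
sc.tl.rank_genes_groups(
    adata,
    groupby=cluster_key,
    method=method,
    use_raw=False,
    layer=de_layer,
    n_genes=None,
    key_added=f'rank_genes_{cluster_key}',
    tie_correct=True
)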

print(f"✓ Differential expression analysis complete")

# ============================================================
# MANUALLY CALCULATE PCT.1 AND PCT.2
# ============================================================
print("\nCalculating pct.1 and pct.2 for all genes...")

# Get expression matrix
use_layer = 'normalized' if 'normalized' in adata.layers else None
if use_layer:
    expr_matrix = adata.layers[use_layer]
else:
    expr_matrix = adata.X

# Convert to dense if sparse
if hasattr(expr_matrix, 'toarray'):
    expr_matrix_dense = expr_matrix.toarray()
else:
    expr_matrix_dense = expr_matrix

# Get clusters
clusters = adata.obs[cluster_key].cat.categories
cluster_labels = adata.obs[cluster_key].values

# Calculate pct.1 and pct.2 for each cluster and gene
pct_dict = {}

for cluster in clusters:
    print(f"  Calculating percentages for {cluster}...")
    
    # Mask for cells in this cluster vs. others
    in_cluster = cluster_labels == cluster
    out_cluster = ~in_cluster
    
    # For each gene, calculate % cells expressing
    pct_in = (expr_matrix_dense[in_cluster, :] > expr_threshold).mean(axis=0)
    pct_out = (expr_matrix_dense[out_cluster, :] > expr_threshold).mean(axis=0)
    
    # Store as dictionary: {gene_name: (pct.1, pct.2)}
    pct_dict[cluster] = {
        'pct.1': dict(zip(adata.var_names, pct_in)),
        'pct.2': dict(zip(adata.var_names, pct_out))
    }

print("✓ Percentage calculations complete")

# ============================================================
# EXTRACT RESULTS WITH PCT.1 AND PCT.2
# ============================================================
print("\nExtracting marker genes with Seurat-style columns...")

result_key = f'rank_genes_{cluster_key}'
all_markers = []

for cluster in clusters:
    # Get results for this cluster
    cluster_markers = sc.get.rank_genes_groups_df(
        adata, 
        group=cluster, 
        key=result_key
    )
    
    # Add cluster column
    cluster_markers['cluster'] = cluster
    
    # Rename columns to Seurat style
    column_mapping = {
        'names': 'gene',
        'scores': 'score',
        'logfoldchanges': 'avg_log2FC',
        'pvals': 'p_val',
        'pvals_adj': 'p_val_adj'
    }
    
    cluster_markers = cluster_markers.rename(columns=column_mapping)
    
    # Add pct.1 and pct.2 manually
    cluster_markers['pct.1'] = cluster_markers['gene'].map(pct_dict[cluster]['pct.1'])
    cluster_markers['pct.2'] = cluster_markers['gene'].map(pct_dict[cluster]['pct.2'])
    
    all_markers.append(cluster_markers)

# Combine all clusters
all_markers_df = pd.concat(all_markers, ignore_index=True)

# Calculate pct difference
all_markers_df['pct_diff'] = all_markers_df['pct.1'] - all_markers_df['pct.2']

# Reorder columns for better readability
column_order = ['cluster', 'gene', 'avg_log2FC', 'pct.1', 'pct.2', 'pct_diff', 
                'p_val', 'p_val_adj', 'score']
# Only keep columns that exist
column_order = [col for col in column_order if col in all_markers_df.columns]
all_markers_df = all_markers_df[column_order]

print(f"✓ Extracted {len(all_markers_df):,} gene-cluster comparisons")
print(f"  Columns: {list(all_markers_df.columns)}")

# Show example to verify pct.1 and pct.2
print("\nExample rows (first cluster, top 3 genes):")
print(all_markers_df.head(3).to_string(index=False))

# ============================================================
# FILTER MARKERS (SEURAT-STYLE)
# ============================================================
print("\n" + "="*80)
print("APPLYING FILTERS")
print("="*80)

filtered_markers = all_markers_df.copy()
n_before = len(filtered_markers)

# Filter by log fold-change
if filter_by_fc and 'avg_log2FC' in filtered_markers.columns:
    filtered_markers = filtered_markers[filtered_markers['avg_log2FC'] > min_fold_change]
    print(f"  After avg_log2FC > {min_fold_change}: {len(filtered_markers):,} genes")

# Filter by adjusted p-value
if 'p_val_adj' in filtered_markers.columns:
    filtered_markers = filtered_markers[filtered_markers['p_val_adj'] < max_pval]
    print(f"  After p_val_adj < {max_pval}: {len(filtered_markers):,} genes")

# Filter by minimum pct.1
filtered_markers = filtered_markers[filtered_markers['pct.1'] > min_pct]
print(f"  After pct.1 > {min_pct}: {len(filtered_markers):,} genes")

# Filter by minimum pct difference
filtered_markers = filtered_markers[filtered_markers['pct_diff'] > min_diff_pct]
print(f"  After pct_diff > {min_diff_pct}: {len(filtered_markers):,} genes")

# Sort by cluster and p_val_adj
filtered_markers = filtered_markers.sort_values(['cluster', 'p_val_adj'])

print(f"\n✓ Total filtered: {n_before:,} → {len(filtered_markers):,} significant markers")

# Get top N genes per cluster
top_markers = filtered_markers.groupby('cluster').head(n_top_genes).reset_index(drop=True)
print(f"✓ Top {n_top_genes} genes per cluster: {len(top_markers):,} total")

# ============================================================
# SUMMARY STATISTICS
# ============================================================
print("\n" + "="*80)
print("MARKER GENES PER CLUSTER")
print("="*80)
for cluster in clusters:
    n_markers = (filtered_markers['cluster'] == cluster).sum()
    print(f"  {cluster}: {n_markers:,} significant markers")

# ============================================================
# EXPORT TO EXCEL
# ============================================================
timestamp = pd.Timestamp.now().strftime('%Y%m%d')
excel_file = f'marker_genes_{cluster_key}_{timestamp}.xlsx'
excel_path = output_dir / excel_file

print(f"\nExporting to Excel: {excel_file}")

with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
    # Sheet 1: Top N markers per cluster
    top_markers.to_excel(writer, sheet_name=f'Top_{n_top_genes}_per_cluster', index=False)
    
    # Sheet 2: All significant markers
    filtered_markers.to_excel(writer, sheet_name='All_Significant_Markers', index=False)
    
    # Sheet 3: All genes (unfiltered)
    all_markers_df.to_excel(writer, sheet_name='All_Genes_Unfiltered', index=False)
    
    # Sheet 4: Summary per cluster
    summary_list = []
    for cluster in clusters:
        cluster_data = filtered_markers[filtered_markers['cluster'] == cluster]
        summary_list.append({
            'cluster': cluster,
            'n_markers': len(cluster_data),
            'mean_log2FC': cluster_data['avg_log2FC'].mean(),
            'max_log2FC': cluster_data['avg_log2FC'].max(),
            'mean_pct.1': cluster_data['pct.1'].mean(),
            'mean_pct.2': cluster_data['pct.2'].mean(),
            'mean_pct_diff': cluster_data['pct_diff'].mean(),
            'min_p_val_adj': cluster_data['p_val_adj'].min()
        })
    
    summary_df = pd.DataFrame(summary_list)
    summary_df.to_excel(writer, sheet_name='Summary', index=False)
    
    # Sheet 5: Analysis parameters
    params_df = pd.DataFrame({
        'Parameter': ['Clustering', 'Resolution', 'Method', 'Min_avg_log2FC', 
                      'Max_p_val_adj', 'Min_pct.1', 'Min_pct_diff',
                      'Expr_threshold', 'Top_N_Genes', 'Total_Clusters', 
                      'Total_Significant_Markers', 'Date'],
        'Value': [cluster_key, selected_res, method, min_fold_change, 
                  max_pval, min_pct, min_diff_pct, expr_threshold,
                  n_top_genes, len(clusters), 
                  len(filtered_markers), pd.Timestamp.now().strftime('%Y-%m-%d %H:%M')]
    })
    params_df.to_excel(writer, sheet_name='Parameters', index=False)

print(f"✓ Saved: {excel_file}")

# ============================================================
# DISPLAY TOP MARKERS FOR EACH CLUSTER
# ============================================================
print("\n" + "="*80)
print("TOP 10 MARKER GENES PER CLUSTER (Seurat-style)")
print("="*80)

for cluster in clusters:
    cluster_top = top_markers[top_markers['cluster'] == cluster].head(10)
    if len(cluster_top) > 0:
        print(f"\n{cluster}:")
        for idx, row in cluster_top.iterrows():
            print(f"  {row['gene']:15s} | log2FC: {row['avg_log2FC']:5.2f} | "
                  f"pct.1: {row['pct.1']:.1%} | pct.2: {row['pct.2']:.1%} | "
                  f"p_adj: {row['p_val_adj']:.2e}")

print("\n" + "="*80)
print("✓ Marker gene identification complete!")
print("="*80)
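
The filters above keep a gene for a cluster only if it passes all four thresholds (log fold-change, adjusted p-value, `pct.1`, and `pct.1 - pct.2`). The sketch below recomputes `pct.1`/`pct.2` on a toy matrix and applies the same four filters, to make the definitions concrete. It assumes a dense cells x genes matrix and a strict `> threshold` comparison; the notebook's own percentage calculation (earlier in this cell, using `expr_threshold`) may differ in detail, and the helper names, default thresholds, and statistics below are hypothetical.

```python
import numpy as np
import pandas as pd

def pct_in_out(X, labels, genes, group, threshold=0.0):
    """pct.1: fraction of cells in `group` with expression > threshold;
    pct.2: the same fraction among all other cells."""
    X = np.asarray(X, dtype=float)          # dense cells x genes matrix (assumption)
    in_group = np.asarray(labels) == group
    expressed = X > threshold
    return pd.DataFrame({
        'gene': list(genes),
        'pct.1': expressed[in_group].mean(axis=0),
        'pct.2': expressed[~in_group].mean(axis=0),
    })

def seurat_like_filter(df, min_fc=0.25, max_padj=0.05, min_pct=0.1, min_diff_pct=0.1):
    """Keep genes passing all four thresholds (hypothetical defaults)."""
    keep = (
        (df['avg_log2FC'] > min_fc)
        & (df['p_val_adj'] < max_padj)
        & (df['pct.1'] > min_pct)
        & ((df['pct.1'] - df['pct.2']) > min_diff_pct)
    )
    return df[keep]

# Toy data: 4 cells x 2 genes in two clusters
X = [[1, 0], [2, 0], [0, 3], [0, 0]]
labels = ['A', 'A', 'B', 'B']
res = pct_in_out(X, labels, ['g1', 'g2'], group='A')
res['avg_log2FC'] = [2.0, -1.0]   # made-up statistics for illustration only
res['p_val_adj'] = [0.01, 0.5]
filtered = seurat_like_filter(res)
print(filtered)
```

Here `g1` is detected in every cell of cluster A and in no other cell, so it passes; `g2` fails on fold-change, adjusted p-value, and `pct.1`.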

### Create Custom Clustering: somen_clusters

This section creates a custom cluster annotation by:
1. Starting from a selected resolution (e.g., `leiden_scVI_r1.2`)
2. Merging specific clusters (e.g., combining clusters 9, 10, 13, and 5 into one group and clusters 1 and 4 into another)
3. Renaming clusters as C0, C1, C2... ordered by cell count (highest to lowest)
4. Visualizing the custom clusters on UMAP with a legend in the right margin

**Purpose**: Simplify cluster annotations by combining biologically similar clusters and creating a clean, ordered naming scheme for publication.
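
The merge-and-rename logic can be sketched with plain pandas, independent of AnnData. This is a minimal illustration under stated assumptions: `merge_and_rank_clusters` is a hypothetical helper (not defined in the notebook), and, unlike the cell below, it breaks ties in cell count by group name so that the C-numbering is deterministic.

```python
import pandas as pd

def merge_and_rank_clusters(labels, cluster_groups):
    """Merge the listed clusters, then name every resulting group C0, C1, ...
    in descending order of cell count (ties broken by group name)."""
    labels = pd.Series(labels).astype(int)          # Leiden labels are strings like '3'
    mapping = {}
    for gi, group in enumerate(cluster_groups):
        for cid in group:
            mapping[cid] = f'GROUP_M{gi}'           # merged groups
    for cid in labels.unique():
        mapping.setdefault(cid, f'GROUP_{cid}')     # every other cluster stays separate
    temp = labels.map(mapping)
    counts = temp.value_counts()
    order = sorted(counts.index, key=lambda g: (-counts[g], g))
    final = {gid: f'C{i}' for i, gid in enumerate(order)}
    return pd.Categorical(temp.map(final),
                          categories=[final[g] for g in order], ordered=True)

labels = ['0'] * 5 + ['1', '4', '4', '9', '5', '2']
result = merge_and_rank_clusters(labels, [[9, 10, 13, 5], [1, 4]])
print(pd.Series(result).value_counts(sort=False))
```

Clusters listed in a group but absent from the data (10 and 13 here) are simply ignored.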

In [None]:
# Ensure inline plotting
%matplotlib inline

# ============================================================
# USER CONFIGURATION
# ============================================================
source_resolution = 1.2
new_column_name = 'somen_clusters'

# Combine exactly these groups; all others (including 0) stay separate
cluster_groups = [
    [9, 10, 13, 5],  # Group A
    [1, 4],          # Group B
]

# Visualization settings
point_size = 10
fig_width = 6
fig_height = 4.5

# High-quality export settings
export_dpi = 600            # High DPI for crisp PNG points
export_bbox = 'tight'
export_pad = 0.1

# ============================================================
# CREATE SOMEN_CLUSTERS COLUMN
# ============================================================
print('='*60)
print('CREATING SOMEN_CLUSTERS ANNOTATION')
print('='*60)

source_key = f'leiden_scVI_r{source_resolution}'

if source_key not in adata.obs:
    raise ValueError(f"Source clustering '{source_key}' not found in adata.obs")

# clusters as integers
clusters = adata.obs[source_key].astype(int)
all_clusters = sorted(clusters.unique())
print(f'\nAvailable clusters in {source_key}: {all_clusters}')

# Build mapping: original cluster -> temp group id
clusters_in_groups = set(c for group in cluster_groups for c in group)
separate_clusters = [c for c in all_clusters if c not in clusters_in_groups]
print(f'\nClusters merged: {sorted(list(clusters_in_groups))}')
print(f'Clusters staying separate: {separate_clusters}')

cluster_mapping = {}
for group_idx, group in enumerate(cluster_groups):
    for cid in group:
        cluster_mapping[cid] = f'GROUP_M{group_idx}'  # merged groups labeled M0, M1,...

for cid in separate_clusters:
    cluster_mapping[cid] = f'GROUP_{cid}'  # each separate cluster gets its own id

# Apply mapping
adata.obs['_temp_group'] = clusters.map(cluster_mapping)

# Count, sort desc, and assign C0, C1, C2... by cell count
temp_counts = adata.obs['_temp_group'].value_counts()
sorted_groups = temp_counts.sort_values(ascending=False).index.tolist()
final_mapping = {gid: f'C{i}' for i, gid in enumerate(sorted_groups)}

print('\nFinal cluster assignments (by descending cell count):')
for gid in sorted_groups:
    print(f'  {gid} -> {final_mapping[gid]}: {temp_counts[gid]:,} cells')

adata.obs[new_column_name] = adata.obs['_temp_group'].map(final_mapping)
del adata.obs['_temp_group']

# Order categories as C0, C1, C2, ...
ordered_labels = [final_mapping[gid] for gid in sorted_groups]
adata.obs[new_column_name] = pd.Categorical(
    adata.obs[new_column_name],
    categories=ordered_labels,
    ordered=True
)

# Summary
final_counts = adata.obs[new_column_name].value_counts().reindex(ordered_labels)
print('\n' + '='*60)
print('FINAL SOMEN_CLUSTERS SUMMARY')
print('='*60)
print(f'Total clusters: {len(ordered_labels)}')
for lbl in ordered_labels:
    print(f'  {lbl}: {int(final_counts[lbl]):,} cells')
print(f'\n✓ Created new column: adata.obs["{new_column_name}"]')
print(f'  Categories (ordered): {ordered_labels}')

# ============================================================
# VISUALIZE SOMEN_CLUSTERS (high-quality export; no on-data labels)
# ============================================================
print('\n' + '='*60)
print('GENERATING UMAP VISUALIZATION (somen_clusters)')
print('='*60)

configure_plot_style()

# Palette for somen_clusters
n_clusters = len(adata.obs[new_column_name].cat.categories)
custom_palette = get_dittoseq_colors(n_clusters)
adata.uns[f'{new_column_name}_colors'] = custom_palette

# Figure
fig, ax = plt.subplots(figsize=(fig_width, fig_height), dpi=export_dpi)

# Plot with legend on the side (no labels on data)
sc.pl.umap(
    adata,
    color=new_column_name,
    ax=ax,
    show=False,
    title=f'Somen Clusters (Resolution {source_resolution})',
    legend_loc='right margin',
    legend_fontsize=9,
    legend_fontweight='normal',
    frameon=False,
    size=point_size,
    palette=adata.uns[f'{new_column_name}_colors']
)

# Axes styling
ax.set_xlabel('UMAP-1', fontsize=12)
ax.set_ylabel('UMAP-2', fontsize=12)
ax.tick_params(axis='both', which='major',
               bottom=True, left=True,
               labelbottom=True, labelleft=True,
               labelsize=10, length=4, width=0.8)
ax.tick_params(axis='both', which='minor',
               bottom=True, left=True, length=2, width=0.6)

plt.tight_layout()

# Save 600 DPI PNG and fully vector PDF
png_path = output_dir / 'png' / f'somen_clusters_umap_r{source_resolution}.png'
pdf_path = output_dir / 'pdf' / f'somen_clusters_umap_r{source_resolution}.pdf'

plt.savefig(png_path, dpi=export_dpi, bbox_inches=export_bbox, pad_inches=export_pad)
plt.savefig(pdf_path, bbox_inches=export_bbox, pad_inches=export_pad, format='pdf')

print(f'  Saved: {png_path.name} (DPI={export_dpi})')
print(f'  Saved: {pdf_path.name} (vector axes/text)')

show_inline_plot()
plt.close()

### UMAP Visualization: somen_clusters Split by Metadata

Generates UMAP plots of `somen_clusters` split by metadata category (`mouse` and `group`).

**Features:**
- Full dataset + downsampled plots
- High-quality export (300 DPI, PDF+PNG)
- Custom dimensions for mouse plots
- Publication-ready for Illustrator/Affinity Designer
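
Downsampling gives every category the same number of cells, so panels are visually comparable regardless of how many cells each mouse or group contributed. A minimal sketch of that step, assuming plain per-cell labels: `downsample_equal` is a hypothetical helper, and it uses NumPy's `default_rng` rather than the legacy `np.random.seed` used in the cell below, so the exact cells drawn will differ.

```python
import numpy as np
import pandas as pd

def downsample_equal(labels, seed=42):
    """Return sorted positional indices with the same number of cells per label."""
    labels = pd.Series(labels).reset_index(drop=True)
    counts = labels.value_counts()
    counts = counts[counts > 0]          # drop unused categorical levels (would force n=0)
    n = int(counts.min())
    rng = np.random.default_rng(seed)
    values = labels.to_numpy()
    picked = [
        rng.choice(np.flatnonzero(values == g), size=n, replace=False)
        for g in counts.index
    ]
    return np.sort(np.concatenate(picked))

labels = ['m1'] * 5 + ['m2'] * 3 + ['m3'] * 4
idx = downsample_equal(labels)
print(pd.Series(labels).iloc[idx].value_counts())
```

The returned indices can be used to subset an AnnData object positionally, e.g. `adata[idx]`.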

In [None]:
# Ensure inline plotting in Jupyter
%matplotlib inline

# ============================================================
# USER CONFIGURATION
# ============================================================
cluster_key = 'somen_clusters'  # Use the custom somen_clusters annotation instead of resolution-based clustering

# ============================================================
# USER CONFIGURATION - CONTROL PLOT PARAMETERS
# ============================================================
# Customize parameters for MOUSE plots
mouse_dot_size = 50
mouse_fig_width = 10
mouse_fig_height = 5.4

# Customize parameters for GROUP plots
group_dot_size = 10
group_fig_width = 8
group_fig_height = 5.4

# ============================================================
# SETUP
# ============================================================
metadata_columns = ['mouse', 'group']  # Only mouse and group now

if cluster_key not in adata.obs:
    raise ValueError(f"Clustering column {cluster_key} not found in adata.obs")

for col in metadata_columns:
    if col not in adata.obs:
        print(f"Warning: Metadata column {col} not found in adata.obs - skipping")

configure_plot_style()

# Import for palette
import matplotlib.colors as mcolors

# Set up dittoSeq palette for publication quality
n_clusters = len(adata.obs[cluster_key].cat.categories)
custom_palette = get_dittoseq_colors(n_clusters)
adata.uns[f'{cluster_key}_colors'] = custom_palette

# ============================================================
# ORIGINAL PLOTS (Full Dataset)
# ============================================================
print("Generating full dataset plots...")

for meta_col in metadata_columns:
    if meta_col not in adata.obs:
        continue
    
    # Get custom parameters for this metadata type
    if meta_col == 'mouse':
        dot_size = mouse_dot_size
        fig_width = mouse_fig_width
        fig_height = mouse_fig_height
    elif meta_col == 'group':
        dot_size = group_dot_size
        fig_width = group_fig_width
        fig_height = group_fig_height
    
    meta_order = adata.obs[meta_col].cat.categories.tolist()
    n_items = len(meta_order)
    n_cols = min(4, n_items)  # Max 4 columns
    n_rows = (n_items + n_cols - 1) // n_cols  # Calculate rows needed
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_width * n_cols, fig_height * n_rows),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes.flatten()  # always a flat array of Axes, even for a single panel
    
    for i, item in enumerate(meta_order):
        mask = adata.obs[meta_col] == item
        if mask.sum() > 0:
            sc.pl.umap(
                adata[mask],
                color=cluster_key,
                ax=axes[i],
                show=False,
                title=f'{item} ({mask.sum():,} cells)',
                legend_loc='right margin',
                frameon=False,
                size=dot_size,
                palette=adata.uns[f'{cluster_key}_colors']
            )
        else:
            axes[i].set_visible(False)
            print(f'Warning: {meta_col} {item} not found')
    
    # Hide unused subplots
    if n_items > 1:
        for j in range(n_items, len(axes)):
            axes[j].set_visible(False)
    
    plt.suptitle(f'Microglial Types by {meta_col}', y=1.02, fontsize=18)
    plt.tight_layout()
    plt.savefig(output_dir / 'png' / f'{cluster_key}_{meta_col}_umaps.png', dpi=300, bbox_inches='tight')
    plt.savefig(output_dir / 'pdf' / f'{cluster_key}_{meta_col}_umaps.pdf', dpi=300, bbox_inches='tight', format='pdf')
    print(f'Saved: {cluster_key}_{meta_col}_umaps.png/.pdf')
    plt.close()

print('\n✓ Completed full dataset metadata-stratified visualizations')

# ============================================================
# DOWNSAMPLED PLOTS (Equal Cells Per Group)
# ============================================================
print("\n" + "="*60)
print("Generating downsampled plots with equal cells per group...")
print("="*60)

for meta_col in metadata_columns:
    if meta_col not in adata.obs:
        continue
    
    # Get custom parameters for this metadata type
    if meta_col == 'mouse':
        dot_size = mouse_dot_size
        fig_width = mouse_fig_width
        fig_height = mouse_fig_height
    elif meta_col == 'group':
        dot_size = group_dot_size
        fig_width = group_fig_width
        fig_height = group_fig_height
    
    print(f"\nProcessing {meta_col}...")
    
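    # Get category counts (ignore unused categories, which would otherwise force min_cells to 0)
    counts = adata.obs[meta_col].value_counts()
    counts = counts[counts > 0]
    meta_order = [c for c in adata.obs[meta_col].cat.categories if c in counts.index]
    min_cells = counts.min()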
    
    print(f"  Cell counts per {meta_col}: {dict(counts)}")
    print(f"  Downsampling to: {min_cells:,} cells per group")
    
    # Downsample to equal number of cells per category
    np.random.seed(42)  # For reproducibility
    downsampled_indices = []
    for item in meta_order:
        item_indices = np.where(adata.obs[meta_col] == item)[0]
        sampled_indices = np.random.choice(item_indices, size=min_cells, replace=False)
        downsampled_indices.extend(sampled_indices)
    
    # Create downsampled AnnData
    adata_downsampled = adata[downsampled_indices].copy()
    
    # Generate UMAP plots with downsampled data
    n_items = len(meta_order)
    n_cols = min(6, n_items)
    n_rows = (n_items + n_cols - 1) // n_cols
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_width * n_cols, fig_height * n_rows),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes.flatten()  # always a flat array of Axes, even for a single panel
    
    for i, item in enumerate(meta_order):
        mask = adata_downsampled.obs[meta_col] == item
        if mask.sum() > 0:
            sc.pl.umap(
                adata_downsampled[mask],
                color=cluster_key,
                ax=axes[i],
                show=False,
                title=f'{item} ({mask.sum():,} cells)',
                legend_loc='right margin',
                frameon=False,
                size=dot_size * 2,  # Larger dots for downsampled
                palette=adata.uns[f'{cluster_key}_colors']
            )
        else:
            axes[i].set_visible(False)
            print(f'Warning: {meta_col} {item} not found in downsampled data')
    
    # Hide unused subplots
    if n_items > 1:
        for j in range(n_items, len(axes)):
            axes[j].set_visible(False)
    
    plt.suptitle(f'Microglial Types by {meta_col} - Downsampled (n={min_cells:,}/group)', 
                 y=1.02, fontsize=18)
    plt.tight_layout()
    plt.savefig(output_dir / 'png' / f'{cluster_key}_{meta_col}_umaps_downsampled.png', dpi=300, bbox_inches='tight')
    plt.savefig(output_dir / 'pdf' / f'{cluster_key}_{meta_col}_umaps_downsampled.pdf', dpi=300, bbox_inches='tight', format='pdf')
    print(f'  Saved: {cluster_key}_{meta_col}_umaps_downsampled.png/.pdf')
    plt.close()

print('\n✓ Completed downsampled metadata-stratified visualizations')

### Cluster Frequency Comparison

Compare the frequency of custom clusters (`somen_clusters`) between experimental groups using biological replicates. This analysis:
- Calculates per-mouse cluster frequencies
- Performs a two-sided Mann-Whitney U test for each cluster (no correction for multiple testing across clusters)
- Visualizes differences with statistical significance markers
- Exports detailed results to Excel
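
The per-mouse frequency table and the group mean ± SEM can be written compactly with `pd.crosstab`. The sketch below assumes a cell-level table with group, replicate, and cluster columns (toy data, hypothetical helper names). Note that the SEM is computed from the sample standard deviation (`ddof=1`): pandas uses this by default, whereas NumPy's `ndarray.std()` defaults to `ddof=0`. Clusters absent from a mouse are counted as 0%, as in the loop below.

```python
import numpy as np
import pandas as pd

def per_replicate_frequencies(obs, group_col, replicate_col, cluster_col):
    """Percent of each replicate's cells in each cluster (rows: group, replicate)."""
    counts = pd.crosstab([obs[group_col], obs[replicate_col]], obs[cluster_col])
    return counts.div(counts.sum(axis=1), axis=0) * 100

def group_mean_sem(freqs, group):
    """Mean and SEM across the replicates of one group (sample SD, ddof=1)."""
    sub = freqs.xs(group, level=0)
    return sub.mean(axis=0), sub.std(axis=0, ddof=1) / np.sqrt(len(sub))

obs = pd.DataFrame({
    'group':   ['G1'] * 5 + ['G2'] * 5,
    'mouse':   ['a', 'a', 'a', 'b', 'b', 'c', 'c', 'c', 'c', 'd'],
    'cluster': ['C0', 'C0', 'C1', 'C0', 'C1', 'C0', 'C1', 'C1', 'C1', 'C1'],
})
freqs = per_replicate_frequencies(obs, 'group', 'mouse', 'cluster')
mean_g1, sem_g1 = group_mean_sem(freqs, 'G1')
print(freqs.round(1))
print(mean_g1.round(2), sem_g1.round(2))
```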

In [None]:
# Ensure inline plotting
%matplotlib inline

# Import scipy.stats for statistical tests
import scipy.stats as stats

# ============================================================
# CONFIGURATION
# ============================================================
cluster_key = 'somen_clusters'        # Use custom cluster annotation
filter_batch_value = 'Mistri'         # Batch to filter to
comparison_category = 'group'          # Category to compare (e.g., 'group')
group_1 = 'PBS-CFA'                    # First group
group_2 = 'PEP15-CFA'                  # Second group
replicate_column = 'mouse'             # Column containing biological replicates

# Plot colors (optional customization)
bar_colors = {group_1: '#3C5488', group_2: '#E64B35'}

# ============================================================
# VERIFY AND FILTER DATA
# ============================================================

# Verify the cluster column exists
if cluster_key not in adata.obs:
    raise ValueError(f"Column '{cluster_key}' not found in adata.obs.")

# Filter to specific batch
batch_mask = adata.obs['batch'] == filter_batch_value
adata_filtered = adata[batch_mask].copy()

print("="*80)
print(f"CLUSTER FREQUENCY COMPARISON WITH BIOLOGICAL REPLICATES")
print(f"Using: {cluster_key}")
print("="*80)
print(f"\nBatch: {filter_batch_value}")
print(f"Comparing: {group_1} vs {group_2}")
print(f"Biological replicates: {replicate_column}")

# Get masks for each group
mask_g1 = adata_filtered.obs[comparison_category] == group_1
mask_g2 = adata_filtered.obs[comparison_category] == group_2

# Get unique mice in each group
mice_g1 = adata_filtered.obs[mask_g1][replicate_column].unique()
mice_g2 = adata_filtered.obs[mask_g2][replicate_column].unique()

print(f"\n{group_1} mice: {list(mice_g1)} (n={len(mice_g1)})")
print(f"{group_2} mice: {list(mice_g2)} (n={len(mice_g2)})")
print("="*80)

# Get all clusters
all_clusters = adata_filtered.obs[cluster_key].cat.categories

# ============================================================
# CALCULATE PER-MOUSE FREQUENCIES
# ============================================================

# Store per-mouse data
per_mouse_data = []

# Group 1 mice
for mouse in mice_g1:
    mouse_mask = mask_g1 & (adata_filtered.obs[replicate_column] == mouse)
    n_cells_mouse = mouse_mask.sum()
    
    for cluster in all_clusters:
        cluster_count = ((adata_filtered.obs[cluster_key] == cluster) & mouse_mask).sum()
        cluster_freq = (cluster_count / n_cells_mouse * 100) if n_cells_mouse > 0 else 0
        
        per_mouse_data.append({
            'Group': group_1,
            'Mouse': mouse,
            'Cluster': cluster,
            'Cell_Count': cluster_count,
            'Total_Cells': n_cells_mouse,
            'Frequency_%': cluster_freq
        })

# Group 2 mice
for mouse in mice_g2:
    mouse_mask = mask_g2 & (adata_filtered.obs[replicate_column] == mouse)
    n_cells_mouse = mouse_mask.sum()
    
    for cluster in all_clusters:
        cluster_count = ((adata_filtered.obs[cluster_key] == cluster) & mouse_mask).sum()
        cluster_freq = (cluster_count / n_cells_mouse * 100) if n_cells_mouse > 0 else 0
        
        per_mouse_data.append({
            'Group': group_2,
            'Mouse': mouse,
            'Cluster': cluster,
            'Cell_Count': cluster_count,
            'Total_Cells': n_cells_mouse,
            'Frequency_%': cluster_freq
        })

per_mouse_df = pd.DataFrame(per_mouse_data)

# ============================================================
# STATISTICAL TESTING (MANN-WHITNEY U TEST)
# ============================================================
# Compare cluster frequencies between groups using biological replicates

print("\nPerforming statistical tests (Mann-Whitney U test)...")

stat_results = []

for cluster in all_clusters:
    # Get frequencies for this cluster from all mice in each group
    freq_g1 = per_mouse_df[(per_mouse_df['Group'] == group_1) & 
                            (per_mouse_df['Cluster'] == cluster)]['Frequency_%'].values
    freq_g2 = per_mouse_df[(per_mouse_df['Group'] == group_2) & 
                            (per_mouse_df['Cluster'] == cluster)]['Frequency_%'].values
    
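    # Calculate mean and SEM for each group (SEM uses the sample SD, ddof=1;
    # NumPy's .std() defaults to ddof=0, which underestimates the SEM)
    mean_g1 = freq_g1.mean()
    sem_g1 = freq_g1.std(ddof=1) / np.sqrt(len(freq_g1)) if len(freq_g1) > 1 else np.nan
    mean_g2 = freq_g2.mean()
    sem_g2 = freq_g2.std(ddof=1) / np.sqrt(len(freq_g2)) if len(freq_g2) > 1 else np.nan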
    
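    # Mann-Whitney U test (non-parametric). Note: with very few mice per group the exact test
    # cannot produce small p-values (e.g., 3 vs 3 mice gives a minimum two-sided p of 0.1)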
    if len(freq_g1) > 1 and len(freq_g2) > 1:
        statistic, p_val = stats.mannwhitneyu(freq_g1, freq_g2, alternative='two-sided')
    else:
        p_val = np.nan
        statistic = np.nan
    
    # Determine significance
    if np.isnan(p_val):
        sig = 'NA'
    elif p_val < 0.001:
        sig = '***'
    elif p_val < 0.01:
        sig = '**'
    elif p_val < 0.05:
        sig = '*'
    else:
        sig = 'ns'
    
    stat_results.append({
        'Cluster': cluster,
        f'{group_1}_mean_%': mean_g1,
        f'{group_1}_SEM_%': sem_g1,
        f'{group_1}_n_mice': len(freq_g1),
        f'{group_2}_mean_%': mean_g2,
        f'{group_2}_SEM_%': sem_g2,
        f'{group_2}_n_mice': len(freq_g2),
        'p_value': p_val,
        'U_statistic': statistic,
        'significance': sig
    })
    
    print(f"Cluster {cluster}: {group_1}={mean_g1:.2f}±{sem_g1:.2f}%, "
          f"{group_2}={mean_g2:.2f}±{sem_g2:.2f}%, p={p_val:.4e}, {sig}")

stat_df = pd.DataFrame(stat_results)

# ============================================================
# EXPORT TO EXCEL
# ============================================================
timestamp = pd.Timestamp.now().strftime('%Y%m%d')
excel_file = f'{cluster_key}_freq_by_mouse_{filter_batch_value}_{group_1}_vs_{group_2}_{timestamp}.xlsx'
excel_path = output_dir / excel_file

with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
    # Sheet 1: Per-mouse data (long format)
    per_mouse_df.to_excel(writer, sheet_name='Per_Mouse_Data', index=False)
    
    # Sheet 2: Statistical comparison
    stat_df.to_excel(writer, sheet_name='Statistical_Comparison', index=False)
    
    # Sheets 3-4: Per-mouse counts and frequencies pivoted (wide format, easier to read)
    pivot_counts = per_mouse_df.pivot_table(
        index='Cluster', 
        columns=['Group', 'Mouse'], 
        values='Cell_Count', 
        fill_value=0
    )
    pivot_counts.to_excel(writer, sheet_name='Cell_Counts_by_Mouse')
    
    pivot_freq = per_mouse_df.pivot_table(
        index='Cluster', 
        columns=['Group', 'Mouse'], 
        values='Frequency_%', 
        fill_value=0
    )
    pivot_freq.to_excel(writer, sheet_name='Frequencies_by_Mouse')
    
    # Sheet 5: Summary
    pd.DataFrame({
        'Parameter': ['Cluster_Type', 'Batch', 'Group 1', 'Group 2', 'Replicate Column', 
                      f'{group_1}_n_mice', f'{group_2}_n_mice',
                      'Statistical Test', 'Significant_clusters', 'Date'],
        'Value': [cluster_key, filter_batch_value, group_1, group_2, replicate_column,
                  len(mice_g1), len(mice_g2),
                  'Mann-Whitney U test', sum(stat_df['p_value'] < 0.05),
                  pd.Timestamp.now().strftime('%Y-%m-%d %H:%M')]
    }).to_excel(writer, sheet_name='Summary', index=False)

print(f'\n✓ Saved: {excel_file}')

# ============================================================
# PLOT
# ============================================================
configure_plot_style()

n_clusters = len(all_clusters)
fig, ax = plt.subplots(figsize=(max(12, n_clusters*0.6), 7), dpi=300)

x = np.arange(n_clusters)
width = 0.35

# Use mean frequencies
mean_g1 = stat_df[f'{group_1}_mean_%'].values
mean_g2 = stat_df[f'{group_2}_mean_%'].values
sem_g1 = stat_df[f'{group_1}_SEM_%'].values
sem_g2 = stat_df[f'{group_2}_SEM_%'].values

# Create bars with error bars (SEM)
ax.bar(x - width/2, mean_g1, width, label=group_1, color=bar_colors[group_1], 
       alpha=0.85, edgecolor='black', linewidth=0.8, yerr=sem_g1, capsize=3)
ax.bar(x + width/2, mean_g2, width, label=group_2, color=bar_colors[group_2], 
       alpha=0.85, edgecolor='black', linewidth=0.8, yerr=sem_g2, capsize=3)

# Add significance markers
for i, row in stat_df.iterrows():
    if row['significance'] not in ['ns', 'NA']:
        y_pos = max(mean_g1[i] + sem_g1[i], mean_g2[i] + sem_g2[i]) + 1.0
        ax.plot([i - width/2, i + width/2], [y_pos, y_pos], lw=1.2, c='black')
        ax.text(i, y_pos + 0.3, row['significance'], ha='center', va='bottom', 
                fontsize=11, fontweight='bold')

ax.set_xlabel('Microglial Type', fontsize=14, fontweight='bold')
ax.set_ylabel('Frequency (% ± SEM)', fontsize=14, fontweight='bold')
ax.set_title(f'{group_1} vs {group_2} ({filter_batch_value} batch)\n'
             f'Microglial Types | Biological replicates: {len(mice_g1)} vs {len(mice_g2)} mice', 
             fontsize=16, fontweight='bold', pad=20)
ax.set_xticks(x)
ax.set_xticklabels(all_clusters, rotation=45, ha='right')
ax.legend(title='Group', frameon=True, edgecolor='black')
ax.yaxis.grid(True, linestyle='--', alpha=0.3)
ax.set_axisbelow(True)
sns.despine()

fig.text(0.5, -0.02, 
         f"Mann-Whitney U test comparing per-mouse cluster frequencies. * p<0.05, ** p<0.01, *** p<0.001.\n"
         f"n={len(mice_g1)} mice ({group_1}), n={len(mice_g2)} mice ({group_2}). Error bars show SEM.",
         ha='center', fontsize=10, style='italic')

plt.tight_layout()
plot_file = f'{cluster_key}_freq_by_mouse_{filter_batch_value}_{group_1}_vs_{group_2}'
save_plot(plot_file, close=False)
show_inline_plot()
plt.close()

# Summary
n_sig = sum(stat_df['p_value'] < 0.05)
print(f'\n✓ {n_sig}/{n_clusters} clusters significantly different (p<0.05)')
if n_sig > 0:
    print('\nSignificant clusters:')
    for _, row in stat_df[stat_df['p_value'] < 0.05].iterrows():
        print(f"  Cluster {row['Cluster']}: "
              f"{group_1}={row[f'{group_1}_mean_%']:.2f}±{row[f'{group_1}_SEM_%']:.2f}%, "
              f"{group_2}={row[f'{group_2}_mean_%']:.2f}±{row[f'{group_2}_SEM_%']:.2f}%, "
              f"p={row['p_value']:.4e} {row['significance']}")

print("\n" + "="*80)
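
Two caveats apply to the per-cluster tests above. First, assuming an exact Mann-Whitney test without ties, the smallest achievable two-sided p-value is 2 / C(n1 + n2, n1), so with 3 vs 3 mice no cluster can reach p < 0.05. Second, each cluster is tested separately, which inflates false positives. A minimal sketch of both checks follows; `bh_adjust` and `min_exact_mwu_p` are hypothetical helpers written here in plain NumPy (recent SciPy releases also provide `scipy.stats.false_discovery_control` for the Benjamini-Hochberg step).

```python
import math
import numpy as np

def bh_adjust(pvals):
    """Benjamini-Hochberg adjusted p-values; NaN entries are left as NaN."""
    p = np.asarray(pvals, dtype=float)
    adj = np.full(p.shape, np.nan)
    ok = ~np.isnan(p)
    q = p[ok]
    m = q.size
    if m == 0:
        return adj
    order = np.argsort(q)
    ranked = q[order] * m / np.arange(1, m + 1)
    # Enforce monotonicity, working down from the largest p-value
    ranked = np.minimum.accumulate(ranked[::-1])[::-1]
    out = np.empty(m)
    out[order] = np.minimum(ranked, 1.0)
    adj[ok] = out
    return adj

def min_exact_mwu_p(n1, n2):
    """Smallest two-sided p-value of an exact Mann-Whitney test without ties."""
    return min(1.0, 2 / math.comb(n1 + n2, n1))

print(bh_adjust([0.01, 0.04, 0.03, 0.2]))
print(min_exact_mwu_p(3, 3), min_exact_mwu_p(4, 4))
```

In the cell above, the adjustment could be stored alongside the raw values, e.g. `stat_df['p_value_adj'] = bh_adjust(stat_df['p_value'])`.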

### Gene Expression Dotplot - Somen Clusters

Generate a gene expression dotplot using the `somen_clusters` annotation. This visualization shows:
- **Mean expression** of marker genes across somen_clusters (color intensity)
- **Percent of cells** expressing each gene in each cluster (dot size)
- **Dendrogram** showing hierarchical relationships between clusters based on gene expression

The dotplot uses the same marker gene panel as previous analyses and is computed from normalized expression values (the `normalized` layer if present, otherwise `adata.X`). Clusters are ordered by the dendrogram according to their expression similarity.
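
The two quantities encoded in each dot can be computed directly. The sketch below assumes a dense expression matrix and roughly mirrors Scanpy's defaults as we understand them (mean over all cells in the group for colour, fraction of cells with expression > 0 for size); it ignores the `dot_min`/`dot_max` size clipping used in the cell below, and `dotplot_stats` is a hypothetical helper.

```python
import numpy as np
import pandas as pd

def dotplot_stats(X, labels, genes, threshold=0.0):
    """Per-group mean expression (dot colour) and fraction of cells with
    expression > threshold (dot size)."""
    df = pd.DataFrame(np.asarray(X, dtype=float), columns=list(genes))
    groups = np.asarray(labels)
    mean_expr = df.groupby(groups).mean()
    frac_expr = (df > threshold).groupby(groups).mean()
    return mean_expr, frac_expr

# Toy data: 4 cells x 2 genes in two clusters
X = [[0, 2], [1, 0], [0, 0], [3, 3]]
mean_expr, frac_expr = dotplot_stats(X, ['C0', 'C0', 'C1', 'C1'], ['Tmem119', 'P2ry12'])
print(mean_expr)
print(frac_expr)
```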

In [None]:
# Ensure inline plotting in Jupyter
%matplotlib inline

# ============================================================
# USER CONFIGURATION
# ============================================================
# Use somen_clusters (not resolution-based clustering)
cluster_key = 'somen_clusters'
plot_width = 6     # 👈 Change figure width (in inches)
plot_height = 15    # 👈 Change figure height (in inches)

# ============================================================
# VERIFY DATA AND LAYERS
# ============================================================

# Verify clustering column exists
if cluster_key not in adata.obs:
    raise ValueError(f"Clustering column '{cluster_key}' not found in adata.obs")

# Check if normalized layer exists, if not use X (which should be normalized)
use_layer = 'normalized' if 'normalized' in adata.layers else None
if use_layer is None:
    print("Note: Using adata.X for gene expression (no 'normalized' layer found)")

# Genes to visualize (mouse gene symbols)
# Marker genes for microglial states plus selected general immune markers
genes = [
    # Homeostatic Microglia (resting, surveillance state)
    'Tmem119', 'P2ry12', 'Cx3cr1', 'Sall1', 'Fcrls', 'Siglech', 'Hexb',

    # Activated/Pro-inflammatory Microglia (M1-like, immune-activated)
    'Il1b', 'Tnf', 'Cd68', 'Nos2', 'Cd86',

    # Anti-inflammatory/Repair Microglia (M2-like, tissue repair)
    'Arg1', 'Mrc1', 'Il10', 'Tgfb1', 'Chil3',  # Chil3 is the official gene symbol for Ym1

    # Disease-Associated Microglia (DAM, neurodegenerative contexts)
    'Trem2', 'Apoe', 'Cst7', 'Lpl', 'Tyrobp', 'Clec7a',

    # Aged Microglia (altered in aging brain)
    'Ccl2', 'C1qa', 'B2m',

    # Proliferative Microglia (dividing, injury/disease)
    'Mki67', 'Top2a', 'Cdk1', 'Ccna2', 'Birc5',

    # Interferon-Responsive Microglia (viral/interferon response)
    'Ifit1', 'Irf7', 'Stat1', 'Cxcl10', 'Isg15',

    # Phagocytic Microglia (enhanced phagocytosis)
    'Trem2', 'Cd68', 'Mertk', 'Axl', 'C1qa',

    # General immune and other cell-type markers
    'Ptprc','H2-Aa', 'H2-Ab1',
    'Mrc1', 'Ccr2', 'Ly6g', 'Ly6c2', 'Ms4a7',
    'Cdk8', 'Cmss1', 'Lars2',
    'H2-Q7', 'H2-Aa', 'H2-Ab1', 'H2-Q4', 'Cd74',
    'Ifi27l2a', 'Arhgap15', 'Klf2', 'Cd52', 'Cd34'
]

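# Remove duplicates while preserving order (several genes are listed in more than one category)
genes = list(dict.fromkeys(genes))

# Check available genes
available_genes = [g for g in genes if g in adata.var_names]
missing_genes = [g for g in genes if g not in adata.var_names]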

if missing_genes:
    print(f'Warning: Genes not found in adata.var_names: {", ".join(missing_genes)}')
if not available_genes:
    raise ValueError('No valid genes for plotting.')

# ============================================================
# CREATE CUSTOM GREY-PLASMA COLORMAP
# ============================================================
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

# Get plasma_r colormap
plasma_r = plt.cm.plasma_r
n_colors = 256
plasma_colors = plasma_r(np.linspace(0, 1, n_colors))

# Replace first 20% with grey gradient
grey_to_plasma_transition = 0.2
n_grey_colors = int(n_colors * grey_to_plasma_transition)

# Create grey gradient from light grey to first plasma color
light_grey = np.array([0.85, 0.85, 0.85, 1.0])
transition_color = plasma_colors[n_grey_colors]
grey_gradient = np.linspace(light_grey, transition_color, n_grey_colors)

# Combine grey gradient with plasma_r
custom_colors = np.vstack([grey_gradient, plasma_colors[n_grey_colors:]])
dotplot_colormap = LinearSegmentedColormap.from_list('grey_plasma', custom_colors)

print('✓ Created custom grey-plasma colormap for dotplot')

# Configure plot style
configure_plot_style()

# ============================================================
# COMPUTE DENDROGRAM FOR SOMEN_CLUSTERS
# ============================================================
print(f'Computing dendrogram for {cluster_key}...')
sc.tl.dendrogram(adata, groupby=cluster_key)
print('✓ Dendrogram computed')

# ============================================================
# GENERATE DOTPLOT
# ============================================================

fig, ax = plt.subplots(figsize=(plot_width, plot_height), dpi=300)
dotplot = sc.pl.dotplot(
    adata,
    var_names=available_genes,
    groupby=cluster_key,
    layer=use_layer,
    dendrogram=True,
    ax=ax,
    return_fig=True,
    dot_max=0.8,
    dot_min=0.05,
    colorbar_title='Mean expression\nin group',
    size_title='% of cells\nexpressing gene',
    cmap=dotplot_colormap,  # Use custom grey-plasma colormap
    swap_axes=True,
    var_group_rotation=90
)

# Remove dot borders (use e.g. 0.3-0.5 instead of 0 to keep thin borders)
main_ax = dotplot.get_axes()['mainplot_ax']
for collection in main_ax.collections:
    collection.set_linewidths(0)

# Force straight cluster names with increased font size
main_ax.tick_params(axis='x', labelsize=12, rotation=0)
for label in main_ax.get_xticklabels():
    label.set_rotation(0)
    label.set_ha('center')

# Ensure ticks are visible on both axes
main_ax.tick_params(axis='both', bottom=True, left=True, labelbottom=True, labelleft=True, length=4, direction='out')
main_ax.spines['bottom'].set_visible(True)
main_ax.spines['left'].set_visible(True)

plt.suptitle(f'Gene Expression Dotplot - Somen Clusters (Normalized)', y=1.05, fontsize=18)
plt.subplots_adjust(bottom=0.2)
plt.tight_layout()

# Save with somen_clusters in filename
plt.savefig(output_dir / 'png' / f'rna_dotplot_somen_clusters_normalized.png', dpi=300, bbox_inches='tight')
plt.savefig(output_dir / 'pdf' / f'rna_dotplot_somen_clusters_normalized.pdf', dpi=300, bbox_inches='tight', format='pdf')
print(f'Saved: rna_dotplot_somen_clusters_normalized.png/.pdf')
show_inline_plot()
plt.close()
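
The grey-to-plasma blend built in the cell above could be wrapped in a small reusable helper so that other plots (for example feature UMAPs) share the same grey low end. The sketch below is NumPy-only and assumes the base colormap is passed as an (N, 4) RGBA array such as `plt.cm.plasma_r(np.linspace(0, 1, 256))`; the function name `blend_low_end_with_grey` is hypothetical. Unlike the cell above, the ramp stops one step before the first retained color, so no color appears twice.

```python
import numpy as np


def blend_low_end_with_grey(base_colors, grey_fraction=0.2,
                            grey=(0.85, 0.85, 0.85, 1.0)):
    """Replace the lowest `grey_fraction` of an RGBA color table with a
    linear ramp from `grey` towards the first retained base color.

    base_colors: array of shape (N, 4) with RGBA values in [0, 1].
    Returns a new (N, 4) array; the input is not modified.
    """
    colors = np.asarray(base_colors, dtype=float)
    if colors.ndim != 2 or colors.shape[1] != 4:
        raise ValueError("base_colors must have shape (N, 4)")
    if not 0 <= grey_fraction < 1:
        raise ValueError("grey_fraction must be in [0, 1)")
    n_grey = int(len(colors) * grey_fraction)
    if n_grey == 0:
        return colors.copy()
    # endpoint=False: the ramp stops just before colors[n_grey], avoiding a repeat
    ramp = np.linspace(np.asarray(grey, dtype=float), colors[n_grey],
                       n_grey, endpoint=False)
    return np.vstack([ramp, colors[n_grey:]])
```

Usage with matplotlib would look like `LinearSegmentedColormap.from_list('grey_plasma', blend_low_end_with_grey(plt.cm.plasma_r(np.linspace(0, 1, 256))))`.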

### Cell Type Annotation from Somen Clusters

Convert `somen_clusters` (C0, C1, C2, etc.) to descriptive cell type annotations based on biological interpretation. This mapping assigns each cluster to a specific microglial cell type or functional state:

- **Homeostatic clusters**: C0, C1, C2 → Homeostatic-1 to Homeostatic-3 (resting, surveilling microglia)
- **Activated state**: C3 → Universal Activated Myeloid State
- **Functional states**: C4 → Phagocytic, C5 → Interferon Responsive, C6 → Intermediate State
- **Proliferating**: C8 → Proliferating
- **Unidentified**: C7, C9 → Unidentified-1 and Unidentified-2 (require further characterization)

This annotation facilitates downstream analysis and visualization by providing biologically meaningful labels for each cluster.
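
The annotation cell follows a common pattern: map cluster IDs to labels, fail loudly on anything unmapped, and impose a biological display order. A minimal pandas sketch of that pattern on toy labels is shown here; the helper name `annotate_clusters` is hypothetical and not part of the notebook.

```python
import pandas as pd


def annotate_clusters(cluster_labels, mapping, order):
    """Map cluster IDs to cell-type names and return an ordered Categorical.

    Raises KeyError if an observed cluster has no mapping or maps to a name
    missing from `order`, so no cell silently becomes NaN. Categories absent
    from the data are dropped; the rest keep the order given in `order`.
    """
    labels = pd.Series(cluster_labels, dtype="object")
    unmapped = sorted(set(labels) - set(mapping))
    if unmapped:
        raise KeyError(f"No cell-type mapping for clusters: {unmapped}")
    annotated = labels.map(mapping)
    observed = set(annotated)
    unordered = sorted(observed - set(order))
    if unordered:
        raise KeyError(f"Cell types missing from display order: {unordered}")
    present = [ct for ct in order if ct in observed]
    return pd.Categorical(annotated, categories=present, ordered=True)
```

In the notebook this would be called as `annotate_clusters(adata.obs['somen_clusters'], cluster_to_celltype, cell_type_order)`.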

In [None]:
# ============================================================
# CREATE CELL_TYPE COLUMN FROM SOMEN_CLUSTERS
# ============================================================
print('='*60)
print('CREATING CELL_TYPE ANNOTATION')
print('='*60)

# Verify somen_clusters exists
if 'somen_clusters' not in adata.obs:
    raise ValueError("Column 'somen_clusters' not found in adata.obs. Run somen_clusters creation first.")

# Mapping: somen_clusters -> cell_type
cluster_to_celltype = {
    'C0': 'Homeostatic-1',
    'C1': 'Homeostatic-2',
    'C2': 'Homeostatic-3',
    'C3': 'Universal Activated Myeloid State',
    'C4': 'Phagocytic',
    'C5': 'Interferon Responsive',
    'C6': 'Intermediate State',
    'C7': 'Unidentified-1',
    'C8': 'Proliferating',
    'C9': 'Unidentified-2'
}

# Get all somen_clusters
all_somen_clusters = adata.obs['somen_clusters'].cat.categories
print(f'\nAvailable somen_clusters: {list(all_somen_clusters)}')

# Check if all clusters have mappings
missing_mappings = [c for c in all_somen_clusters if c not in cluster_to_celltype]
if missing_mappings:
    raise ValueError(f"Missing cell_type mappings for clusters: {missing_mappings}. "
                    f"Please add mappings for all clusters before proceeding.")

# Create cell_type column by mapping
print('\nMapping somen_clusters to cell_type:')
adata.obs['cell_type'] = adata.obs['somen_clusters'].map(cluster_to_celltype)

# Verify no missing values
if adata.obs['cell_type'].isna().any():
    unmapped_clusters = adata.obs[adata.obs['cell_type'].isna()]['somen_clusters'].unique()
    raise ValueError(f"Some cells remain unmapped. Missing mappings for clusters: {list(unmapped_clusters)}")

for cluster in all_somen_clusters:
    if cluster in cluster_to_celltype:
        mask = adata.obs['somen_clusters'] == cluster
        n_cells = mask.sum()
        if n_cells > 0:
            print(f'  {cluster} → {cluster_to_celltype[cluster]}: {n_cells:,} cells')
        else:
            print(f'  {cluster} → {cluster_to_celltype[cluster]}: 0 cells (not found)')

# Convert to categorical with meaningful order
cell_type_order = [
    'Homeostatic-1', 'Homeostatic-2', 'Homeostatic-3',
    'Intermediate State',
    'Universal Activated Myeloid State',
    'Phagocytic',
    'Interferon Responsive',
    'Proliferating',
    'Unidentified-1', 'Unidentified-2'
]

# Only include cell types that actually exist in the data
existing_cell_types = [ct for ct in cell_type_order if ct in adata.obs['cell_type'].unique()]
adata.obs['cell_type'] = pd.Categorical(
    adata.obs['cell_type'],
    categories=existing_cell_types,
    ordered=True
)

# Summary
cell_type_counts = adata.obs['cell_type'].value_counts().sort_index()
print('\n' + '='*60)
print('CELL_TYPE SUMMARY')
print('='*60)
print(f'\nTotal cell types: {len(cell_type_counts)}')
print('\nCell counts per cell_type:')
for cell_type, count in cell_type_counts.items():
    print(f'  {cell_type}: {count:,} cells')

# Verify all cells are mapped
unmapped_count = adata.obs['cell_type'].isna().sum()
if unmapped_count > 0:
    raise ValueError(f"ERROR: {unmapped_count} cells remain unmapped!")
else:
    print(f'\n✓ All {adata.n_obs:,} cells successfully mapped to cell_type')

print('\n✓ Created new column: adata.obs["cell_type"]')
print(f'  Categories (ordered): {list(adata.obs["cell_type"].cat.categories)}')
print('='*60)

# ============================================================
# VISUALIZE CELL_TYPE ON UMAP
# ============================================================
print('\n' + '='*60)
print('GENERATING UMAP VISUALIZATION (cell_type)')
print('='*60)

# Visualization settings
point_size = 10
fig_width = 6
fig_height = 4.5
export_dpi = 600
export_bbox = 'tight'
export_pad = 0.1

configure_plot_style()

# Set up dittoSeq palette for cell types
n_cell_types = len(adata.obs['cell_type'].cat.categories)
custom_palette = get_dittoseq_colors(n_cell_types)
adata.uns['cell_type_colors'] = custom_palette

# Create figure with high DPI
fig, ax = plt.subplots(figsize=(fig_width, fig_height), dpi=export_dpi)

# Plot UMAP with legend on the side
sc.pl.umap(
    adata,
    color='cell_type',
    ax=ax,
    show=False,
    title='Cell Type Annotation',
    legend_loc='right margin',
    legend_fontsize=9,
    legend_fontweight='normal',
    frameon=False,
    size=point_size,
    palette=adata.uns['cell_type_colors']
)

# Add axis labels
ax.set_xlabel('UMAP-1', fontsize=12)
ax.set_ylabel('UMAP-2', fontsize=12)

# Configure ticks
ax.tick_params(axis='both', which='major',
               bottom=True, left=True,
               labelbottom=True, labelleft=True,
               labelsize=10, length=4, width=0.8)
ax.tick_params(axis='both', which='minor',
               bottom=True, left=True, length=2, width=0.6)

plt.tight_layout()

# Save high-quality plots
png_path = output_dir / 'png' / 'cell_type_umap.png'
pdf_path = output_dir / 'pdf' / 'cell_type_umap.pdf'

plt.savefig(png_path, dpi=export_dpi, bbox_inches=export_bbox, pad_inches=export_pad)
plt.savefig(pdf_path, bbox_inches=export_bbox, pad_inches=export_pad, format='pdf')

print(f'  Saved: {png_path.name} (DPI={export_dpi})')
print(f'  Saved: {pdf_path.name} (vector axes/text)')

show_inline_plot()
plt.close()

print('\n✓ Cell type visualization complete')
print('='*60)
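
Because `existing_cell_types` drops categories that are absent, assigning palette colors by position means a given cell type can change color between datasets or subsets. One way to keep colors stable is to define a fixed name-to-color dictionary and build the list that scanpy reads from `adata.uns['cell_type_colors']` (one color per category, in category order). A minimal sketch; the helper name and the hex values in the test are placeholders, not the notebook's dittoSeq colors.

```python
def colors_for_categories(categories, color_map):
    """Return a color list aligned to `categories` from a fixed name-to-color dict."""
    missing = [c for c in categories if c not in color_map]
    if missing:
        raise KeyError(f"No color defined for: {missing}")
    return [color_map[c] for c in categories]
```

It would be used as `adata.uns['cell_type_colors'] = colors_for_categories(adata.obs['cell_type'].cat.categories, fixed_colors)`, where `fixed_colors` is a dictionary you define once for all cell types.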

### Marker Gene Identification

Identify marker genes that are differentially expressed in each cell type compared to all other cell types. This analysis:

- Runs a one-vs-rest test for each cell type (Wilcoxon rank-sum by default; t-test and logistic regression are also supported)
- Identifies genes with significantly higher expression in a given cell type
- Ranks genes by the test score, then filters them by adjusted p-value, log2 fold change, and the fraction of expressing cells (pct.1/pct.2)
- Helps characterize the biological identity and function of each cell type

The results can be used to:
- Validate cell type annotations
- Identify novel markers for specific cell states
- Understand transcriptional programs underlying each cell type
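
By default, scanpy's `rank_genes_groups` reports Benjamini-Hochberg adjusted p-values (`corr_method='benjamini-hochberg'`); these become the `p_val_adj` column in the exported tables. The NumPy sketch below reproduces that adjustment for intuition only; the analysis itself relies on scanpy's implementation.

```python
import numpy as np


def benjamini_hochberg(pvals):
    """Benjamini-Hochberg FDR adjustment, returned in the original order."""
    p = np.asarray(pvals, dtype=float)
    n = p.size
    order = np.argsort(p)
    # Scale the k-th smallest p-value by n / k
    ranked = p[order] * n / np.arange(1, n + 1)
    # Enforce monotonicity from the largest p-value downwards
    ranked = np.minimum.accumulate(ranked[::-1])[::-1]
    adjusted = np.empty(n)
    adjusted[order] = np.minimum(ranked, 1.0)
    return adjusted
```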

In [None]:
# ============================================================
# MARKER GENE IDENTIFICATION - CELL_TYPE (with pct.1/pct.2)
# ============================================================

import pandas as pd
import numpy as np

# ============================================================
# CONFIGURATION
# ============================================================
cluster_key = 'cell_type'  # Use cell_type
n_top_genes = 100                            # Number of top genes to export per cluster
method = 'wilcoxon'                         # Method: 'wilcoxon', 't-test', or 'logreg'
filter_by_fc = True                         # Filter by fold-change
min_fold_change = 0.2                       # Minimum avg_log2FC (log2 scale)
max_pval = 0.05                             # Maximum adjusted p-value

# Additional Seurat-style filters (always applied)
min_pct = 0.1                               # Minimum pct.1 (expression in cluster)
min_diff_pct = 0.1                          # Minimum difference between pct.1 and pct.2

# Expression threshold for calculating percentages
expr_threshold = 0.0                        # Count cells with expression > this value

print("="*80)
print(f"MARKER GENE IDENTIFICATION BY CELL_TYPE (Seurat-style)")
print("="*80)
print(f"Clustering: {cluster_key}")
print(f"Method: {method}")
print(f"Expression threshold for pct calculations: > {expr_threshold}")
print(f"\nFilters:")
print(f"  - avg_log2FC > {min_fold_change}")
print(f"  - p_val_adj < {max_pval}")
print(f"  - pct.1 > {min_pct}")
print(f"  - (pct.1 - pct.2) > {min_diff_pct}")
print("="*80)

# Verify clustering exists
if cluster_key not in adata.obs:
    raise ValueError(f"Clustering '{cluster_key}' not found in adata.obs")

# ============================================================
# RUN DIFFERENTIAL EXPRESSION ANALYSIS
# ============================================================
print(f"\nRunning {method} test for all cell types...")

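# Run rank_genes_groups on adata.X (use_raw=False, no layer argument).
# NOTE: scanpy's logfoldchanges assume log1p-normalized values. If adata.X holds
# raw counts, pass layer=<log-normalized layer> instead. The pct.1/pct.2 step
# below uses the 'normalized' layer when present.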
sc.tl.rank_genes_groups(
    adata,
    groupby=cluster_key,
    method=method,
    use_raw=False,
    n_genes=None,
    key_added=f'rank_genes_{cluster_key}',
    tie_correct=True
)

print("✓ Differential expression analysis complete")

# ============================================================
# MANUALLY CALCULATE PCT.1 AND PCT.2
# ============================================================
print("\nCalculating pct.1 and pct.2 for all genes...")

# Get expression matrix
use_layer = 'normalized' if 'normalized' in adata.layers else None
if use_layer:
    expr_matrix = adata.layers[use_layer]
else:
    expr_matrix = adata.X

# Convert to dense if sparse
if hasattr(expr_matrix, 'toarray'):
    expr_matrix_dense = expr_matrix.toarray()
else:
    expr_matrix_dense = expr_matrix

# Get clusters
clusters = adata.obs[cluster_key].cat.categories
cluster_labels = adata.obs[cluster_key].values

# Calculate pct.1 and pct.2 for each cluster and gene
pct_dict = {}

for cluster in clusters:
    print(f"  Calculating percentages for {cluster}...")
    
    # Mask for cells in this cluster vs. others
    in_cluster = cluster_labels == cluster
    out_cluster = ~in_cluster
    
    # For each gene, calculate % cells expressing
    pct_in = (expr_matrix_dense[in_cluster, :] > expr_threshold).mean(axis=0)
    pct_out = (expr_matrix_dense[out_cluster, :] > expr_threshold).mean(axis=0)
    
    # Store per-gene lookups: {'pct.1': {gene: value}, 'pct.2': {gene: value}}
    pct_dict[cluster] = {
        'pct.1': dict(zip(adata.var_names, pct_in)),
        'pct.2': dict(zip(adata.var_names, pct_out))
    }

print("✓ Percentage calculations complete")

# ============================================================
# EXTRACT RESULTS WITH PCT.1 AND PCT.2
# ============================================================
print("\nExtracting marker genes with Seurat-style columns...")

result_key = f'rank_genes_{cluster_key}'
all_markers = []

for cluster in clusters:
    # Get results for this cluster
    cluster_markers = sc.get.rank_genes_groups_df(
        adata, 
        group=cluster, 
        key=result_key
    )
    
    # Add cluster column
    cluster_markers['cluster'] = cluster
    
    # Rename columns to Seurat style
    column_mapping = {
        'names': 'gene',
        'scores': 'score',
        'logfoldchanges': 'avg_log2FC',
        'pvals': 'p_val',
        'pvals_adj': 'p_val_adj'
    }
    
    cluster_markers = cluster_markers.rename(columns=column_mapping)
    
    # Add pct.1 and pct.2 manually
    cluster_markers['pct.1'] = cluster_markers['gene'].map(pct_dict[cluster]['pct.1'])
    cluster_markers['pct.2'] = cluster_markers['gene'].map(pct_dict[cluster]['pct.2'])
    
    all_markers.append(cluster_markers)

# Combine all clusters
all_markers_df = pd.concat(all_markers, ignore_index=True)

# Calculate pct difference
all_markers_df['pct_diff'] = all_markers_df['pct.1'] - all_markers_df['pct.2']

# Reorder columns for better readability
column_order = ['cluster', 'gene', 'avg_log2FC', 'pct.1', 'pct.2', 'pct_diff', 
                'p_val', 'p_val_adj', 'score']
# Only keep columns that exist
column_order = [col for col in column_order if col in all_markers_df.columns]
all_markers_df = all_markers_df[column_order]

print(f"✓ Extracted {len(all_markers_df):,} gene-cluster comparisons")
print(f"  Columns: {list(all_markers_df.columns)}")

# Show example to verify pct.1 and pct.2
print("\nExample rows (first cluster, top 3 genes):")
print(all_markers_df.head(3).to_string(index=False))

# ============================================================
# FILTER MARKERS (SEURAT-STYLE)
# ============================================================
print("\n" + "="*80)
print("APPLYING FILTERS")
print("="*80)

filtered_markers = all_markers_df.copy()
n_before = len(filtered_markers)

# Filter by log fold-change
if filter_by_fc and 'avg_log2FC' in filtered_markers.columns:
    filtered_markers = filtered_markers[filtered_markers['avg_log2FC'] > min_fold_change]
    print(f"  After avg_log2FC > {min_fold_change}: {len(filtered_markers):,} genes")

# Filter by adjusted p-value
if 'p_val_adj' in filtered_markers.columns:
    filtered_markers = filtered_markers[filtered_markers['p_val_adj'] < max_pval]
    print(f"  After p_val_adj < {max_pval}: {len(filtered_markers):,} genes")

# Filter by minimum pct.1
filtered_markers = filtered_markers[filtered_markers['pct.1'] > min_pct]
print(f"  After pct.1 > {min_pct}: {len(filtered_markers):,} genes")

# Filter by minimum pct difference
filtered_markers = filtered_markers[filtered_markers['pct_diff'] > min_diff_pct]
print(f"  After pct_diff > {min_diff_pct}: {len(filtered_markers):,} genes")

# Sort by cluster and p_val_adj
filtered_markers = filtered_markers.sort_values(['cluster', 'p_val_adj'])

print(f"\n✓ Total filtered: {n_before:,} → {len(filtered_markers):,} significant markers")

# Get top N genes per cluster
top_markers = filtered_markers.groupby('cluster').head(n_top_genes).reset_index(drop=True)
print(f"✓ Top {n_top_genes} genes per cluster: {len(top_markers):,} total")

# ============================================================
# SUMMARY STATISTICS
# ============================================================
print("\n" + "="*80)
print("MARKER GENES PER CELL_TYPE")
print("="*80)
for cluster in clusters:
    n_markers = (filtered_markers['cluster'] == cluster).sum()
    print(f"  {cluster}: {n_markers:,} significant markers")

# ============================================================
# EXPORT TO EXCEL
# ============================================================
timestamp = pd.Timestamp.now().strftime('%Y%m%d')
excel_file = f'marker_genes_{cluster_key}_{timestamp}.xlsx'
excel_path = output_dir / excel_file

print(f"\nExporting to Excel: {excel_file}")

with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
    # Sheet 1: Top N markers per cluster
    top_markers.to_excel(writer, sheet_name=f'Top_{n_top_genes}_per_cluster', index=False)
    
    # Sheet 2: All significant markers
    filtered_markers.to_excel(writer, sheet_name='All_Significant_Markers', index=False)
    
    # Sheet 3: All genes (unfiltered)
    all_markers_df.to_excel(writer, sheet_name='All_Genes_Unfiltered', index=False)
    
    # Sheet 4: Summary per cluster
    summary_list = []
    for cluster in clusters:
        cluster_data = filtered_markers[filtered_markers['cluster'] == cluster]
        summary_list.append({
            'cluster': cluster,
            'n_markers': len(cluster_data),
            'mean_log2FC': cluster_data['avg_log2FC'].mean(),
            'max_log2FC': cluster_data['avg_log2FC'].max(),
            'mean_pct.1': cluster_data['pct.1'].mean(),
            'mean_pct.2': cluster_data['pct.2'].mean(),
            'mean_pct_diff': cluster_data['pct_diff'].mean(),
            'min_p_val_adj': cluster_data['p_val_adj'].min()
        })
    
    summary_df = pd.DataFrame(summary_list)
    summary_df.to_excel(writer, sheet_name='Summary', index=False)
    
    # Sheet 5: Analysis parameters
    params_df = pd.DataFrame({
        'Parameter': ['Clustering', 'Method', 'Min_avg_log2FC', 
                      'Max_p_val_adj', 'Min_pct.1', 'Min_pct_diff',
                      'Expr_threshold', 'Top_N_Genes', 'Total_Clusters', 
                      'Total_Significant_Markers', 'Date'],
        'Value': [cluster_key, method, min_fold_change, 
                  max_pval, min_pct, min_diff_pct, expr_threshold,
                  n_top_genes, len(clusters), 
                  len(filtered_markers), pd.Timestamp.now().strftime('%Y-%m-%d %H:%M')]
    })
    params_df.to_excel(writer, sheet_name='Parameters', index=False)

print(f"✓ Saved: {excel_file}")

# ============================================================
# DISPLAY TOP MARKERS FOR EACH CLUSTER
# ============================================================
print("\n" + "="*80)
print("TOP 10 MARKER GENES PER CELL_TYPE (Seurat-style)")
print("="*80)

for cluster in clusters:
    cluster_top = top_markers[top_markers['cluster'] == cluster].head(10)
    if len(cluster_top) > 0:
        print(f"\n{cluster}:")
        for idx, row in cluster_top.iterrows():
            print(f"  {row['gene']:15s} | log2FC: {row['avg_log2FC']:5.2f} | "
                  f"pct.1: {row['pct.1']:.1%} | pct.2: {row['pct.2']:.1%} | "
                  f"p_adj: {row['p_val_adj']:.2e}")

print("\n" + "="*80)
print("✓ Marker gene identification complete!")
print("="*80)
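
The percentage step above converts the whole expression matrix to dense with `.toarray()`, which can take several gigabytes for large datasets. A sparse-aware alternative is sketched below, assuming the matrix is either a NumPy array or a SciPy sparse matrix; the helper name `fraction_expressing` is hypothetical. As another option, scanpy can compute these fractions itself via `sc.tl.rank_genes_groups(..., pts=True)`, which adds `pct_nz_group` and `pct_nz_reference` columns to `sc.get.rank_genes_groups_df`.

```python
import numpy as np
import scipy.sparse as sp


def fraction_expressing(X, mask, threshold=0.0):
    """Per-column fraction of rows selected by `mask` with values > threshold.

    Accepts a dense NumPy array or a SciPy sparse matrix and never converts
    the full matrix to dense. For sparse input `threshold` must be >= 0, so
    that implicit (unstored) zeros count as "not expressing".
    """
    mask = np.asarray(mask, dtype=bool)
    n_rows = int(mask.sum())
    if n_rows == 0:
        return np.zeros(X.shape[1])
    if sp.issparse(X):
        if threshold < 0:
            raise ValueError("threshold must be >= 0 for sparse input")
        # CSR supports boolean row selection; copy so X itself is untouched
        subset = sp.csr_matrix(X)[mask].copy()
        # Turn stored values into 0/1 indicators, then count per column
        subset.data = (subset.data > threshold).astype(np.int64)
        counts = np.asarray(subset.sum(axis=0)).ravel()
    else:
        counts = (np.asarray(X)[mask] > threshold).sum(axis=0)
    return counts / n_rows
```

Inside the cluster loop, `pct_in = fraction_expressing(expr_matrix, in_cluster, expr_threshold)` would replace the dense calculation, with `out_cluster` used the same way for `pct_out`.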

In [None]:
# ============================================================
# MARKER GENE IDENTIFICATION - SOMEN_CLUSTERS (with pct.1/pct.2)
# ============================================================

import pandas as pd
import numpy as np

# ============================================================
# CONFIGURATION
# ============================================================
cluster_key = 'somen_clusters'  # Use somen_clusters
n_top_genes = 100                            # Number of top genes to export per cluster
method = 'wilcoxon'                         # Method: 'wilcoxon', 't-test', or 'logreg'
filter_by_fc = True                         # Filter by fold-change
min_fold_change = 0.2                       # Minimum avg_log2FC (log2 scale)
max_pval = 0.05                             # Maximum adjusted p-value

# Additional Seurat-style filters (always applied)
min_pct = 0.1                               # Minimum pct.1 (expression in cluster)
min_diff_pct = 0.1                          # Minimum difference between pct.1 and pct.2

# Expression threshold for calculating percentages
expr_threshold = 0.0                        # Count cells with expression > this value

print("="*80)
print(f"MARKER GENE IDENTIFICATION BY SOMEN_CLUSTERS (Seurat-style)")
print("="*80)
print(f"Clustering: {cluster_key}")
print(f"Method: {method}")
print(f"Expression threshold for pct calculations: > {expr_threshold}")
print(f"\nFilters:")
print(f"  - avg_log2FC > {min_fold_change}")
print(f"  - p_val_adj < {max_pval}")
print(f"  - pct.1 > {min_pct}")
print(f"  - (pct.1 - pct.2) > {min_diff_pct}")
print("="*80)

# Verify clustering exists
if cluster_key not in adata.obs:
    raise ValueError(f"Clustering '{cluster_key}' not found in adata.obs")

# ============================================================
# RUN DIFFERENTIAL EXPRESSION ANALYSIS
# ============================================================
print(f"\nRunning {method} test for all somen_clusters...")

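# Run rank_genes_groups on adata.X (use_raw=False, no layer argument).
# NOTE: scanpy's logfoldchanges assume log1p-normalized values. If adata.X holds
# raw counts, pass layer=<log-normalized layer> instead. The pct.1/pct.2 step
# below uses the 'normalized' layer when present.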
sc.tl.rank_genes_groups(
    adata,
    groupby=cluster_key,
    method=method,
    use_raw=False,
    n_genes=None,
    key_added=f'rank_genes_{cluster_key}',
    tie_correct=True
)

print("✓ Differential expression analysis complete")

# ============================================================
# MANUALLY CALCULATE PCT.1 AND PCT.2
# ============================================================
print("\nCalculating pct.1 and pct.2 for all genes...")

# Get expression matrix
use_layer = 'normalized' if 'normalized' in adata.layers else None
if use_layer:
    expr_matrix = adata.layers[use_layer]
else:
    expr_matrix = adata.X

# Convert to dense if sparse
if hasattr(expr_matrix, 'toarray'):
    expr_matrix_dense = expr_matrix.toarray()
else:
    expr_matrix_dense = expr_matrix

# Get clusters
clusters = adata.obs[cluster_key].cat.categories
cluster_labels = adata.obs[cluster_key].values

# Calculate pct.1 and pct.2 for each cluster and gene
pct_dict = {}

for cluster in clusters:
    print(f"  Calculating percentages for {cluster}...")
    
    # Mask for cells in this cluster vs. others
    in_cluster = cluster_labels == cluster
    out_cluster = ~in_cluster
    
    # For each gene, calculate % cells expressing
    pct_in = (expr_matrix_dense[in_cluster, :] > expr_threshold).mean(axis=0)
    pct_out = (expr_matrix_dense[out_cluster, :] > expr_threshold).mean(axis=0)
    
    # Store per-gene lookups: {'pct.1': {gene: value}, 'pct.2': {gene: value}}
    pct_dict[cluster] = {
        'pct.1': dict(zip(adata.var_names, pct_in)),
        'pct.2': dict(zip(adata.var_names, pct_out))
    }

print("✓ Percentage calculations complete")

# ============================================================
# EXTRACT RESULTS WITH PCT.1 AND PCT.2
# ============================================================
print("\nExtracting marker genes with Seurat-style columns...")

result_key = f'rank_genes_{cluster_key}'
all_markers = []

for cluster in clusters:
    # Get results for this cluster
    cluster_markers = sc.get.rank_genes_groups_df(
        adata, 
        group=cluster, 
        key=result_key
    )
    
    # Add cluster column
    cluster_markers['cluster'] = cluster
    
    # Rename columns to Seurat style
    column_mapping = {
        'names': 'gene',
        'scores': 'score',
        'logfoldchanges': 'avg_log2FC',
        'pvals': 'p_val',
        'pvals_adj': 'p_val_adj'
    }
    
    cluster_markers = cluster_markers.rename(columns=column_mapping)
    
    # Add pct.1 and pct.2 manually
    cluster_markers['pct.1'] = cluster_markers['gene'].map(pct_dict[cluster]['pct.1'])
    cluster_markers['pct.2'] = cluster_markers['gene'].map(pct_dict[cluster]['pct.2'])
    
    all_markers.append(cluster_markers)

# Combine all clusters
all_markers_df = pd.concat(all_markers, ignore_index=True)

# Calculate pct difference
all_markers_df['pct_diff'] = all_markers_df['pct.1'] - all_markers_df['pct.2']

# Reorder columns for better readability
column_order = ['cluster', 'gene', 'avg_log2FC', 'pct.1', 'pct.2', 'pct_diff', 
                'p_val', 'p_val_adj', 'score']
# Only keep columns that exist
column_order = [col for col in column_order if col in all_markers_df.columns]
all_markers_df = all_markers_df[column_order]

print(f"✓ Extracted {len(all_markers_df):,} gene-cluster comparisons")
print(f"  Columns: {list(all_markers_df.columns)}")

# Show example to verify pct.1 and pct.2
print("\nExample rows (first cluster, top 3 genes):")
print(all_markers_df.head(3).to_string(index=False))

# ============================================================
# FILTER MARKERS (SEURAT-STYLE)
# ============================================================
print("\n" + "="*80)
print("APPLYING FILTERS")
print("="*80)

filtered_markers = all_markers_df.copy()
n_before = len(filtered_markers)

# Filter by log fold-change
if filter_by_fc and 'avg_log2FC' in filtered_markers.columns:
    filtered_markers = filtered_markers[filtered_markers['avg_log2FC'] > min_fold_change]
    print(f"  After avg_log2FC > {min_fold_change}: {len(filtered_markers):,} genes")

# Filter by adjusted p-value
if 'p_val_adj' in filtered_markers.columns:
    filtered_markers = filtered_markers[filtered_markers['p_val_adj'] < max_pval]
    print(f"  After p_val_adj < {max_pval}: {len(filtered_markers):,} genes")

# Filter by minimum pct.1
filtered_markers = filtered_markers[filtered_markers['pct.1'] > min_pct]
print(f"  After pct.1 > {min_pct}: {len(filtered_markers):,} genes")

# Filter by minimum pct difference
filtered_markers = filtered_markers[filtered_markers['pct_diff'] > min_diff_pct]
print(f"  After pct_diff > {min_diff_pct}: {len(filtered_markers):,} genes")

# Sort by cluster and p_val_adj
filtered_markers = filtered_markers.sort_values(['cluster', 'p_val_adj'])

print(f"\n✓ Total filtered: {n_before:,} → {len(filtered_markers):,} significant markers")

# Get top N genes per cluster
top_markers = filtered_markers.groupby('cluster').head(n_top_genes).reset_index(drop=True)
print(f"✓ Top {n_top_genes} genes per cluster: {len(top_markers):,} total")

# ============================================================
# SUMMARY STATISTICS
# ============================================================
print("\n" + "="*80)
print("MARKER GENES PER SOMEN_CLUSTER")
print("="*80)
for cluster in clusters:
    n_markers = (filtered_markers['cluster'] == cluster).sum()
    print(f"  {cluster}: {n_markers:,} significant markers")

# ============================================================
# EXPORT TO EXCEL
# ============================================================
timestamp = pd.Timestamp.now().strftime('%Y%m%d')
excel_file = f'marker_genes_{cluster_key}_{timestamp}.xlsx'
excel_path = output_dir / excel_file

print(f"\nExporting to Excel: {excel_file}")

with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
    # Sheet 1: Top N markers per cluster
    top_markers.to_excel(writer, sheet_name=f'Top_{n_top_genes}_per_cluster', index=False)
    
    # Sheet 2: All significant markers
    filtered_markers.to_excel(writer, sheet_name='All_Significant_Markers', index=False)
    
    # Sheet 3: All genes (unfiltered)
    all_markers_df.to_excel(writer, sheet_name='All_Genes_Unfiltered', index=False)
    
    # Sheet 4: Summary per cluster
    summary_list = []
    for cluster in clusters:
        cluster_data = filtered_markers[filtered_markers['cluster'] == cluster]
        summary_list.append({
            'cluster': cluster,
            'n_markers': len(cluster_data),
            'mean_log2FC': cluster_data['avg_log2FC'].mean(),
            'max_log2FC': cluster_data['avg_log2FC'].max(),
            'mean_pct.1': cluster_data['pct.1'].mean(),
            'mean_pct.2': cluster_data['pct.2'].mean(),
            'mean_pct_diff': cluster_data['pct_diff'].mean(),
            'min_p_val_adj': cluster_data['p_val_adj'].min()
        })
    
    summary_df = pd.DataFrame(summary_list)
    summary_df.to_excel(writer, sheet_name='Summary', index=False)
    
    # Sheet 5: Analysis parameters
    params_df = pd.DataFrame({
        'Parameter': ['Clustering', 'Method', 'Min_avg_log2FC', 
                      'Max_p_val_adj', 'Min_pct.1', 'Min_pct_diff',
                      'Expr_threshold', 'Top_N_Genes', 'Total_Clusters', 
                      'Total_Significant_Markers', 'Date'],
        'Value': [cluster_key, method, min_fold_change, 
                  max_pval, min_pct, min_diff_pct, expr_threshold,
                  n_top_genes, len(clusters), 
                  len(filtered_markers), pd.Timestamp.now().strftime('%Y-%m-%d %H:%M')]
    })
    params_df.to_excel(writer, sheet_name='Parameters', index=False)

print(f"✓ Saved: {excel_file}")

# ============================================================
# DISPLAY TOP MARKERS FOR EACH CLUSTER
# ============================================================
print("\n" + "="*80)
print("TOP 10 MARKER GENES PER SOMEN_CLUSTER (Seurat-style)")
print("="*80)

for cluster in clusters:
    cluster_top = top_markers[top_markers['cluster'] == cluster].head(10)
    if len(cluster_top) > 0:
        print(f"\n{cluster}:")
        for idx, row in cluster_top.iterrows():
            print(f"  {row['gene']:15s} | log2FC: {row['avg_log2FC']:5.2f} | "
                  f"pct.1: {row['pct.1']:.1%} | pct.2: {row['pct.2']:.1%} | "
                  f"p_adj: {row['p_val_adj']:.2e}")

print("\n" + "="*80)
print("✓ Marker gene identification complete!")
print("="*80)
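
The two marker-gene cells above differ only in `cluster_key`, so their filtering stage could be shared. A minimal pandas sketch of that stage as a function is given below. It assumes a table with the notebook's column names (`cluster`, `avg_log2FC`, `p_val_adj`, `pct.1`, `pct_diff`); the name `filter_markers` is hypothetical, and the default thresholds mirror the configuration values used above.

```python
import pandas as pd


def filter_markers(markers, min_log2fc=0.2, max_p_adj=0.05, min_pct=0.1,
                   min_diff_pct=0.1, top_n=None):
    """Apply the notebook's Seurat-style marker filters to a marker table.

    Keeps rows with avg_log2FC > min_log2fc, p_val_adj < max_p_adj,
    pct.1 > min_pct and pct_diff > min_diff_pct, sorted by cluster and then
    adjusted p-value. If top_n is given, keeps at most top_n rows per cluster.
    """
    keep = (
        (markers["avg_log2FC"] > min_log2fc)
        & (markers["p_val_adj"] < max_p_adj)
        & (markers["pct.1"] > min_pct)
        & (markers["pct_diff"] > min_diff_pct)
    )
    out = markers.loc[keep].sort_values(["cluster", "p_val_adj"])
    if top_n is not None:
        out = out.groupby("cluster", sort=False).head(top_n)
    return out.reset_index(drop=True)
```

With this helper, `filtered_markers = filter_markers(all_markers_df)` and `top_markers = filter_markers(all_markers_df, top_n=n_top_genes)` would replace the per-cell filtering code, although the step-by-step counts printed above would no longer appear.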

### Save clustered and annotated file

In [None]:
# Ensure inline plotting
%matplotlib inline

# ============================================================
# SAVE ANNOTATED ANNDATA OBJECT
# ============================================================

print("="*80)
print("SAVING ANNOTATED ANNDATA OBJECT")
print("="*80)

# Print available metadata columns
print(f"\nMetadata columns ({len(adata.obs.columns)}):")
for col in adata.obs.columns:
    print(f"  - {col}")

# Define output filename
output_filename = 'clustered_annotated_adata.h5ad'
output_path = output_dir / output_filename

print(f"\nTotal cells: {adata.n_obs:,}")
print(f"Total genes: {adata.n_vars:,}")
print(f"Saving to: {output_path}")

# Save the AnnData object
adata.write_h5ad(output_path, compression='gzip')

print(f"\n✓ Successfully saved: {output_filename}")
print(f"  File size: {output_path.stat().st_size / 1024**2:.2f} MB")
print("="*80)
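
`write_h5ad` stores string-like `obs` columns as categoricals, and object columns that mix types (for example strings and numbers) can cause the write to fail. A small pandas pre-save check is sketched below. It only reports suspicious columns and does not modify `adata`; the helper name `mixed_type_columns` is hypothetical.

```python
import pandas as pd


def mixed_type_columns(df):
    """Return {column: sorted type names} for object columns holding more than
    one non-null Python type."""
    report = {}
    for col in df.columns:
        if df[col].dtype == object:
            types = {type(v).__name__ for v in df[col].dropna()}
            if len(types) > 1:
                report[col] = sorted(types)
    return report
```

Running `print(mixed_type_columns(adata.obs))` before `adata.write_h5ad(...)` should print `{}` when no obs column mixes types.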

## Session Information

Document the computational environment and package versions used in this analysis for reproducibility.

**Key Information:**
- System and Python version
- Package versions for single-cell analysis
- Current dataset dimensions
- Analysis timestamp
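
Instead of importing each package to read `__version__`, you can call `importlib.metadata.version` from the standard library, which reads installed distribution metadata without importing the package. Distribution names can differ from import names (for example, `scvi` is distributed as `scvi-tools`), so the names listed below are assumptions to check against your environment.

```python
import sys
from importlib.metadata import version, PackageNotFoundError


def package_versions(dist_names):
    """Map each distribution name to its installed version, or 'not installed'."""
    versions = {}
    for name in dist_names:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = "not installed"
    return versions


dists = ["anndata", "scanpy", "numpy", "pandas", "scipy", "matplotlib",
         "seaborn", "scvi-tools", "igraph", "leidenalg"]
print(f"Python: {sys.version.split()[0]}")
for name, ver in package_versions(dists).items():
    print(f"  {name:12s} {ver}")
```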

In [None]:

# ============================================================
# SESSION INFORMATION
# ============================================================

import sys
import platform
import datetime

print('='*60)
print('SESSION INFORMATION')
print('='*60)

# System
print(f'\nSystem: {platform.system()} {platform.release()}')
print(f'Python: {sys.version.split()[0]}')

# Key packages
print('\nPackages:')
packages = ['anndata', 'scanpy', 'numpy', 'pandas', 'scipy', 
            'matplotlib', 'seaborn', 'scvi', 'igraph', 'leidenalg']

for pkg in packages:
    try:
        if pkg == 'scvi':
            import scvi
            p = scvi
        else:
            p = __import__(pkg)
        ver = p.__version__ if hasattr(p, '__version__') else 'unknown'
        print(f'  {pkg:12s} {ver}')
    except ImportError:
        print(f'  {pkg:12s} not installed')

# Dataset
if 'adata' in globals():
    print(f'\nDataset: {adata.n_obs:,} cells × {adata.n_vars:,} genes')

# Date
print(f'\nDate: {datetime.datetime.now().strftime("%Y-%m-%d %H:%M")}')
print('='*60)