## scRNA-seq (10x) preprocessing with Scanpy (+ optional HashSolo)
Generic pipeline to prepare 10x GEX data for QC and filtering, with optional HTO demultiplexing.

### References
- [Scanpy: clustering basics](https://scanpy.readthedocs.io/en/stable/tutorials/basics/clustering.html)
- [scVI-tools docs](https://docs.scvi-tools.org/en/stable/)
- [Scanpy docs](https://scanpy.readthedocs.io/en/stable/)
- [HashSolo paper](https://www.cell.com/cell-systems/fulltext/S2405-4712(20)30195-2)

### Configuration Parameters
Define library-specific parameters before running the pipeline.


In [None]:
# ==================== CONFIGURATION ====================
# Library-specific settings. Edit the values below before running the
# remaining cells; nothing else in the notebook needs to change per library.

# --- Library identity (copied into adata.obs when the data is loaded) ---
LIBRARY_NAME = "PBS-CFA"  # library label
BATCH_NAME = "Mistri"     # batch label, kept for later merging
SEX = "female"            # e.g. "male", "female", "unknown"

# --- Output naming: prepended to every plot and output file name ---
OUTPUT_PREFIX = "PBS-CFA"

# --- HashSolo demultiplexing ---
# True  -> demultiplex with HashSolo (requires HTO data)
# False -> GEX-only analysis
USE_HASHSOLO = True

# --- Input location: 10x folder, relative to the 'inputs' directory ---
INPUT_DATA_PATH = "filtered_feature_bc_matrix"

# ========================================================

# Echo the active settings so they are recorded in the notebook output.
print(
    "\n".join(
        [
            "Configuration loaded:",
            f"  Library: {LIBRARY_NAME}",
            f"  Batch: {BATCH_NAME}",
            f"  Sex: {SEX}",
            f"  Output prefix: {OUTPUT_PREFIX}",
            f"  Use HashSolo: {USE_HASHSOLO}",
            f"  Input data: inputs/{INPUT_DATA_PATH}",
        ]
    )
)


### Pipeline Overview
- Configure library parameters (batch, sex, output prefix, hashsolo option)
- Import packages
- Load 10x data (GEX + optional HTO)
- Build `AnnData` with GEX; attach HTO to obsm (if available)
- Run HashSolo demultiplexing (if enabled)
- Add mouse column
- Perform QC filtering
- Save QC-filtered data
- Perform clustering and visualization (for contamination detection)
- Save final clustered data with raw counts preserved

In [None]:
# Imports
import scanpy as sc
import scvi
import pandas as pd
from pandas.api.types import CategoricalDtype
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
from pathlib import Path
import anndata as ad
import re
from mpl_toolkits.axes_grid1 import make_axes_locatable
import scanpy.external as sce
import sys, pkg_resources, datetime

print("Libraries imported successfully")

# Directory layout: raw data is read from inputs/, everything produced by
# the notebook goes under outputs/ (figures split by format).
data_dir = Path('inputs')
results_dir = Path('outputs')
png_dir = results_dir / "png"
pdf_dir = results_dir / "pdf"

# Create the output tree up front; existing directories are left untouched.
for _out_dir in (results_dir, png_dir, pdf_dir):
    _out_dir.mkdir(parents=True, exist_ok=True)

# Default figure resolution and size for scanpy plots
sc.set_figure_params(dpi=300, figsize=(6, 4))

In [None]:
# Read the 10x matrix, keeping every feature type so HTO rows are not dropped
print(f"Loading data from: {INPUT_DATA_PATH}")

adata_full = sc.read_10x_mtx(
    data_dir / INPUT_DATA_PATH,
    var_names='gene_symbols',
    cache=False,
    gex_only=False,
)

# Attach the library-level metadata from the configuration cell to every cell
for _obs_key, _obs_value in (
    ('library', LIBRARY_NAME),
    ('batch', BATCH_NAME),
    ('sex', SEX),
):
    adata_full.obs[_obs_key] = _obs_value
adata_full.obs_names_make_unique()

# Report which feature types the matrix contains
print(f"\nFeature types found:")
print(adata_full.var['feature_types'].value_counts())

# Gene-expression features become the working AnnData object
gex_mask = adata_full.var['feature_types'] == 'Gene Expression'
adata = adata_full[:, gex_mask].copy()
print(f"\nGEX data shape: {adata.shape}")

# Antibody Capture features (HTOs), when present, are stored alongside the
# GEX data as a dense array in obsm plus the feature names in uns.
has_hto = 'Antibody Capture' in adata_full.var['feature_types'].values

if not has_hto:
    print("No HTO/Antibody Capture data found")
    if USE_HASHSOLO:
        print("WARNING: USE_HASHSOLO is True but no HTO data available. HashSolo will be skipped.")
else:
    hto_subset = adata_full[:, adata_full.var['feature_types'] == 'Antibody Capture']
    if hasattr(hto_subset.X, "toarray"):
        adata.obsm['HTO'] = hto_subset.X.toarray()  # sparse -> dense
    else:
        adata.obsm['HTO'] = hto_subset.X
    adata.uns['HTO_features'] = hto_subset.var_names.to_list()
    print(f"HTO data shape: {adata.obsm['HTO'].shape}")
    print(f"HTO features: {adata.uns['HTO_features']}")

print(f"\n‚úÖ Data loaded successfully with {adata.n_obs} cells and {adata.n_vars} genes")

In [None]:
# Sanity-check the freshly loaded object: count matrix, HTO block, metadata
print(f"\n=== Data Inspection ===")

# Count matrix container type and dimensions
counts_matrix = adata.X
print(f"\nadata.X type: {type(counts_matrix)}")
print(f"adata.X shape: {counts_matrix.shape}")

# HTO block (only present when Antibody Capture features were loaded)
if 'HTO' not in adata.obsm:
    print("\nNo HTO data in adata.obsm")
else:
    print(f"\nHTO shape: {adata.obsm['HTO'].shape}")
    print(f"HTO features: {adata.uns.get('HTO_features', 'Not available')}")

# Per-cell metadata
obs_table = adata.obs
print(f"\nadata.obs shape: {obs_table.shape}")
print("Metadata columns (first 5 rows):")
print(obs_table.head())

### Demultiplex HTO with HashSolo (Optional)
Run HashSolo demultiplexing if USE_HASHSOLO is True and HTO data is available.

In [None]:
# Demultiplex with HashSolo when requested and possible; otherwise fill the
# output columns with placeholders so later cells can rely on them existing.
if not (USE_HASHSOLO and has_hto):
    if not USE_HASHSOLO:
        print("\n‚è≠Ô∏è  HashSolo skipped (USE_HASHSOLO is False)")
    else:
        print("\n‚è≠Ô∏è  HashSolo skipped (No HTO data available)")

    # Placeholder columns for consistency with the demultiplexed case
    adata.obs['Classification'] = 'Not_demultiplexed'
    adata.obs['most_likely_hypothesis'] = 'N/A'
else:
    print(f"\n=== Running HashSolo Demultiplexing ===")

    # HTO counts and feature names stored when the data was loaded
    hto_features = adata.uns['HTO_features']
    hto_counts = adata.obsm['HTO']
    if hasattr(hto_counts, 'toarray'):
        dense_hto = hto_counts.toarray()
    else:
        dense_hto = hto_counts

    # Cells x HTO table, indexed like adata.obs
    hto_df = pd.DataFrame(dense_hto, index=adata.obs_names, columns=hto_features)

    # HashSolo reads its counts from obs columns, so copy each HTO there
    hto_columns = []
    for feature in hto_df.columns:
        obs_key = f"HTO_{feature}"
        adata.obs[obs_key] = hto_df.loc[adata.obs_names, feature].values
        hto_columns.append(obs_key)

    # Run HashSolo (results are written into adata.obs)
    print(f"Running HashSolo...")
    sce.pp.hashsolo(adata, cell_hashing_columns=hto_columns, inplace=True)

    # Summarize the assignments
    for summary_key in ('Classification', 'most_likely_hypothesis'):
        print(f"\n{summary_key}:")
        print(adata.obs[summary_key].value_counts())

    print("\n‚úÖ HashSolo demultiplexing completed")

### Add Mouse Column
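Create a mouse identifier column by copying the HashSolo `Classification` labels (hashtag assignments, plus `Doublet`/`Negative`). If HashSolo was skipped, every cell receives the `Not_demultiplexed` placeholder instead. This allows tracking of individual mice, which is especially useful when merging multiple libraries.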


In [None]:
# The mouse identifier mirrors the 'Classification' column (an independent
# copy, so later edits to one column do not affect the other).
adata.obs['mouse'] = adata.obs['Classification'].copy()
mouse_ids = adata.obs['mouse']

print(f"\n‚úÖ Added 'mouse' column")
print(f"Unique mouse IDs:")
print(mouse_ids.unique())
print(f"\nMouse ID counts:")
print(mouse_ids.value_counts())

### Cell Cycle Scoring (Mouse Gene Lists)
Compute cell cycle scores using mouse gene lists on normalized copies of the data.

In [None]:
# Cell cycle scoring
# Marker gene symbols (mouse capitalisation) for S phase and G2/M phase.
s_genes = [
    'Mcm5', 'Pcna', 'Tyms', 'Fen1', 'Mcm2', 'Mcm4', 'Rrm1', 'Ung', 'Gins2', 'Mcm6',
    'Cdca7', 'Dtl', 'Prim1', 'Uhrf1', 'Hells', 'Rfc2', 'Rpa2', 'Nasp', 'Rad51ap1',
    'Gmnn', 'Wdr76', 'Slbp', 'Ccne2', 'Ubr7', 'Pold3', 'Msh2', 'Atad2', 'Rad51',
    'Rrm2', 'Cdc45', 'Cdc6', 'Exo1', 'Tipin', 'Dscc1', 'Blm', 'Casp8ap2', 'Usp1',
    'Clspn', 'Pola1', 'Chaf1b', 'Brip1', 'E2f8'
]

g2m_genes = [
    'Hmgb2', 'Cdk1', 'Nusap1', 'Ube2c', 'Birc5', 'Tpx2', 'Top2a', 'Ndc80', 'Cks2',
    'Nuf2', 'Cks1b', 'Mki67', 'Tmpo', 'Cenpf', 'Tacc3', 'Fam64a', 'Smc4', 'Ccnb2',
    'Ckap2l', 'Ckap2', 'Aurkb', 'Bub1', 'Kif11', 'Anp32e', 'Tubb4b', 'Gtse1',
    'Kif20b', 'Hjurp', 'Cdca3', 'Cdc20', 'Ttk', 'Cdc25c', 'Kif2c', 'Rangap1',
    'Ncapd2', 'Dlgap5', 'Cdca2', 'Cdca8', 'Ect2', 'Kif23', 'Hmmr', 'Aurka',
    'Psrc1', 'Anln', 'Lbr', 'Ckap5', 'Cenpe', 'Ctcf', 'Nek2', 'G2e3', 'Gas2l3',
    'Cbx5', 'Cenpa'
]

# Scoring runs on a throwaway normalized + log-transformed copy, so the
# counts in `adata.X` are left untouched.
adata_temp = adata.copy()
sc.pp.normalize_total(adata_temp, target_sum=1e4)
sc.pp.log1p(adata_temp)
sc.tl.score_genes_cell_cycle(adata_temp, s_genes=s_genes, g2m_genes=g2m_genes)

# Transfer only the three result columns back onto the main object
for score_key in ('S_score', 'G2M_score', 'phase'):
    adata.obs[score_key] = adata_temp.obs[score_key]

print("\n‚úÖ Cell cycle scoring completed")

### Compute QC Metrics
Calculate QC metrics (MT%, ribo%, HB%) for the library.

In [None]:
# Compute QC metrics
# Flag gene families by symbol prefix (mouse nomenclature: mt-, Rps/Rpl, Hb*).
adata.var["mt"] = adata.var_names.str.startswith("mt-")  # mitochondrial genes (mouse: mt-)
adata.var["ribo"] = adata.var_names.str.startswith(("Rps", "Rpl"))  # ribosomal genes
# Hemoglobin genes. The negative lookahead excludes Hbegf (heparin-binding
# EGF-like growth factor): it starts with "Hbe" and would otherwise be counted
# as hemoglobin, inflating pct_counts_hb.
adata.var["hb"] = adata.var_names.str.contains(r"^Hb(?!egf)[abdefgh]")

# Adds per-cell totals and the percentage of counts in each flagged family
# (pct_counts_mt / pct_counts_ribo / pct_counts_hb) to adata.obs.
sc.pp.calculate_qc_metrics(
    adata, qc_vars=["mt", "ribo", "hb"], inplace=True, log1p=False
)

print(f"\n‚úÖ QC metrics computed")
print(f"Data shape: {adata.shape}")
print(f"Library: {LIBRARY_NAME}")
print(f"Batch: {BATCH_NAME}")
print(f"Sex: {SEX}")

### Prefiltered QC Plots
Visualize QC metrics before any filtering to assess data quality.

In [None]:
# QC violin plots (before any filtering)
sc.pl.violin(
    adata,
    ["n_genes_by_counts", "total_counts", "pct_counts_mt", "pct_counts_ribo", "pct_counts_hb"],
    jitter=0.4,
    multi_panel=True,
    show=False
)

# Grab the figure and write it to disk BEFORE plt.show(). Backends such as the
# Jupyter inline backend release the figure on show(); plt.gcf() would then
# hand back a new, empty figure and the saved files would be blank.
fig = plt.gcf()
png_path = png_dir / f"{OUTPUT_PREFIX}_prefiltered_qc_violin.png"
pdf_path = pdf_dir / f"{OUTPUT_PREFIX}_prefiltered_qc_violin.pdf"

try:
    fig.savefig(png_path, dpi=600, bbox_inches='tight')
    fig.savefig(pdf_path, dpi=600, bbox_inches='tight')
    print(f"Violin plot saved as:\n- {png_path}\n- {pdf_path}")
except Exception as e:
    # Keep the original exception attached as the cause
    raise OSError(f"Failed to save violin plots: {e}") from e

# Display the violin plots in the notebook, then release the figure
plt.show()
plt.close(fig)

# QC scatter plots (standard colorbars)
fig, axes = plt.subplots(1, 2, figsize=(18, 7), dpi=600)

def simple_scatter(adata, x, y, color, ax, title):
    """Draw one obs-vs-obs scatter on `ax`, coloured by `color`, and set its title.

    `x`, `y` and `color` are passed straight to `sc.pl.scatter`; `ax` is the
    matplotlib axes to draw on and `title` its title.
    """
    sc.pl.scatter(
        adata,
        x=x,
        y=y,
        color=color,
        ax=ax,
        size=10,
        alpha=1.0,
        show=False,
        color_map='viridis',
        legend_loc='right margin',
        legend_fontoutline=2,
        legend_fontsize='small'
    )
    ax.set_title(title)

# Adjust layout to provide more space for colorbars
plt.tight_layout()
plt.subplots_adjust(right=0.80)

# First scatter
simple_scatter(
    adata,
    x='total_counts',
    y='n_genes_by_counts',
    color='pct_counts_mt',
    ax=axes[0],
    title="n_genes_by_counts vs total_counts"
)

# Second scatter
simple_scatter(
    adata,
    x='pct_counts_mt',
    y='pct_counts_ribo',
    color='pct_counts_hb',
    ax=axes[1],
    title="pct_counts_mt vs pct_counts_ribo"
)

# Save scatter plots (again before displaying)
scatter_png_path = png_dir / f"{OUTPUT_PREFIX}_prefiltered_qc_scatter.png"
scatter_pdf_path = pdf_dir / f"{OUTPUT_PREFIX}_prefiltered_qc_scatter.pdf"

try:
    fig.savefig(scatter_png_path, dpi=600, bbox_inches='tight')
    fig.savefig(scatter_pdf_path, dpi=600, bbox_inches='tight')
    print(f"Scatter plots saved as:\n- {scatter_png_path}\n- {scatter_pdf_path}")
except Exception as e:
    raise OSError(f"Failed to save scatter plots: {e}") from e

plt.show()
plt.close(fig)

### Intermediate Filtering and QC
Apply intermediate filtering thresholds and visualize QC metrics.

In [None]:
# Intermediate filters
cutoffs = {"min_genes": 1000, "max_counts": 40000, "max_pct_mt": 20, "min_pct_ribo": 0.5}
# Human-readable summary of the thresholds, printed under each figure
cutoff_text = (
    f"Cutoffs: n_genes_by_counts > {cutoffs['min_genes']} & total_counts < {cutoffs['max_counts']} & "
    f"pct_counts_mt < {cutoffs['max_pct_mt']} & pct_counts_ribo > {cutoffs['min_pct_ribo']}"
)

# Keep cells passing every threshold (all comparisons are strict)
adata_filtered = adata[
    (adata.obs["n_genes_by_counts"] > cutoffs["min_genes"]) &
    (adata.obs["total_counts"] < cutoffs["max_counts"]) &
    (adata.obs["pct_counts_mt"] < cutoffs["max_pct_mt"]) &
    (adata.obs["pct_counts_ribo"] > cutoffs["min_pct_ribo"])
].copy()
print(f"Filtered down to {adata_filtered.n_obs} cells (from {adata.n_obs}).")

# Violin
sc.pl.violin(
    adata_filtered,
    ["n_genes_by_counts", "total_counts", "pct_counts_mt", "pct_counts_ribo", "pct_counts_hb"],
    jitter=0.4,
    multi_panel=True,
    show=False,
)

# Annotate and save BEFORE plt.show(). Backends such as the Jupyter inline
# backend release the figure on show(); plt.gcf() would then return a new,
# empty figure, so the saved file would hold only the cutoff text.
fig = plt.gcf()
fig.subplots_adjust(bottom=0.15)
fig.text(0.5, 0.01, cutoff_text, ha='center', fontsize=10, style='italic')

violin_png_path = png_dir / f"{OUTPUT_PREFIX}_intermediate_filter_qc_violin.png"
violin_pdf_path = pdf_dir / f"{OUTPUT_PREFIX}_intermediate_filter_qc_violin.pdf"

try:
    fig.savefig(violin_png_path, dpi=600, bbox_inches='tight')
    fig.savefig(violin_pdf_path, dpi=600, bbox_inches='tight')
    print(f"Violin plots saved as:\n- {violin_png_path}\n- {violin_pdf_path}")
except Exception as e:
    # Keep the original exception attached as the cause
    raise OSError(f"Failed to save violin plots: {e}") from e

# Display the violin plots in the notebook, then release the figure
plt.show()
plt.close(fig)

# Scatter plots (standard colorbars)
fig, axes = plt.subplots(1, 2, figsize=(18, 7), dpi=600)

def simple_scatter(adata, x, y, color, ax, title):
    """Draw one obs-vs-obs scatter on `ax`, coloured by `color`, and set its title.

    `x`, `y` and `color` are passed straight to `sc.pl.scatter`; `ax` is the
    matplotlib axes to draw on and `title` its title.
    """
    sc.pl.scatter(
        adata,
        x=x,
        y=y,
        color=color,
        ax=ax,
        size=10,
        alpha=1.0,
        show=False,
        color_map='viridis',
        legend_loc='right margin',
        legend_fontoutline=2,
        legend_fontsize='small'
    )
    ax.set_title(title)

# Adjust layout to provide more space for colorbars
plt.tight_layout()
plt.subplots_adjust(right=0.80, bottom=0.15)

# First scatter
simple_scatter(
    adata_filtered,
    x='total_counts',
    y='n_genes_by_counts',
    color='pct_counts_mt',
    ax=axes[0],
    title="n_genes_by_counts vs total_counts"
)

# Second scatter
simple_scatter(
    adata_filtered,
    x='pct_counts_mt',
    y='pct_counts_ribo',
    color='pct_counts_hb',
    ax=axes[1],
    title="pct_counts_mt vs pct_counts_ribo"
)

# Add cutoff text
fig.text(0.5, 0.01, cutoff_text, ha='center', fontsize=10, style='italic')

# Save scatter plots (before displaying)
scatter_png_path = png_dir / f"{OUTPUT_PREFIX}_intermediate_filter_qc_scatter.png"
scatter_pdf_path = pdf_dir / f"{OUTPUT_PREFIX}_intermediate_filter_qc_scatter.pdf"

try:
    fig.savefig(scatter_png_path, dpi=600, bbox_inches='tight')
    fig.savefig(scatter_pdf_path, dpi=600, bbox_inches='tight')
    print(f"Scatter plots saved as:\n- {scatter_png_path}\n- {scatter_pdf_path}")
except Exception as e:
    raise OSError(f"Failed to save scatter plots: {e}") from e

plt.show()
plt.close(fig)

### Final Filtering and QC
Apply final filtering thresholds and visualize QC metrics.

In [None]:
# Final filters
cutoffs = {"min_genes": 2000, "max_counts": 30000, "max_pct_mt": 6, "min_pct_ribo": 1.0}
# Human-readable summary of the thresholds, printed under each figure
cutoff_text = (
    f"Cutoffs: n_genes_by_counts > {cutoffs['min_genes']} & total_counts < {cutoffs['max_counts']} & "
    f"pct_counts_mt < {cutoffs['max_pct_mt']} & pct_counts_ribo > {cutoffs['min_pct_ribo']}"
)

# Re-filter from the unfiltered `adata` (not from the intermediate result);
# this replaces `adata_filtered` with the final selection.
adata_filtered = adata[
    (adata.obs["n_genes_by_counts"] > cutoffs["min_genes"]) &
    (adata.obs["total_counts"] < cutoffs["max_counts"]) &
    (adata.obs["pct_counts_mt"] < cutoffs["max_pct_mt"]) &
    (adata.obs["pct_counts_ribo"] > cutoffs["min_pct_ribo"])
].copy()
print(f"Filtered down to {adata_filtered.n_obs} cells (from {adata.n_obs}).")

# Violin
sc.pl.violin(
    adata_filtered,
    ["n_genes_by_counts", "total_counts", "pct_counts_mt", "pct_counts_ribo", "pct_counts_hb"],
    jitter=0.4,
    multi_panel=True,
    show=False,
)

# Annotate and save BEFORE plt.show(). Backends such as the Jupyter inline
# backend release the figure on show(); plt.gcf() would then return a new,
# empty figure, so the saved file would hold only the cutoff text.
fig = plt.gcf()
fig.subplots_adjust(bottom=0.15)
fig.text(0.5, 0.01, cutoff_text, ha='center', fontsize=10, style='italic')

violin_png_path = png_dir / f"{OUTPUT_PREFIX}_final_filter_qc_violin.png"
violin_pdf_path = pdf_dir / f"{OUTPUT_PREFIX}_final_filter_qc_violin.pdf"

try:
    fig.savefig(violin_png_path, dpi=600, bbox_inches='tight')
    fig.savefig(violin_pdf_path, dpi=600, bbox_inches='tight')
    print(f"Violin plots saved as:\n- {violin_png_path}\n- {violin_pdf_path}")
except Exception as e:
    # Keep the original exception attached as the cause
    raise OSError(f"Failed to save violin plots: {e}") from e

# Display the violin plots in the notebook, then release the figure
plt.show()
plt.close(fig)

# Scatter plots (standard colorbars)
fig, axes = plt.subplots(1, 2, figsize=(18, 7), dpi=600)

def simple_scatter(adata, x, y, color, ax, title):
    """Draw one obs-vs-obs scatter on `ax`, coloured by `color`, and set its title.

    `x`, `y` and `color` are passed straight to `sc.pl.scatter`; `ax` is the
    matplotlib axes to draw on and `title` its title.
    """
    sc.pl.scatter(
        adata,
        x=x,
        y=y,
        color=color,
        ax=ax,
        size=10,
        alpha=1.0,
        show=False,
        color_map='viridis',
        legend_loc='right margin',
        legend_fontoutline=2,
        legend_fontsize='small'
    )
    ax.set_title(title)

# Adjust layout to provide more space for colorbars
plt.tight_layout()
plt.subplots_adjust(right=0.80, bottom=0.15)

# First scatter
simple_scatter(
    adata_filtered,
    x='total_counts',
    y='n_genes_by_counts',
    color='pct_counts_mt',
    ax=axes[0],
    title="n_genes_by_counts vs total_counts"
)

# Second scatter
simple_scatter(
    adata_filtered,
    x='pct_counts_mt',
    y='pct_counts_ribo',
    color='pct_counts_hb',
    ax=axes[1],
    title="pct_counts_mt vs pct_counts_ribo"
)

# Add cutoff text
fig.text(0.5, 0.01, cutoff_text, ha='center', fontsize=10, style='italic')

# Save scatter plots (before displaying)
scatter_png_path = png_dir / f"{OUTPUT_PREFIX}_final_filter_qc_scatter.png"
scatter_pdf_path = pdf_dir / f"{OUTPUT_PREFIX}_final_filter_qc_scatter.pdf"

try:
    fig.savefig(scatter_png_path, dpi=600, bbox_inches='tight')
    fig.savefig(scatter_pdf_path, dpi=600, bbox_inches='tight')
    print(f"Scatter plots saved as:\n- {scatter_png_path}\n- {scatter_pdf_path}")
except Exception as e:
    raise OSError(f"Failed to save scatter plots: {e}") from e

plt.show()
plt.close(fig)

### HTO Classification Summary (if applicable)
Review HTO demultiplexing assignments after final filtering (only if HashSolo was run).

In [None]:
# HTO assignment summary (post-filter) - only if HashSolo was run
if USE_HASHSOLO and has_hto:
    print("\n--- HTO classification (filtered) ---")
    print(adata_filtered.obs['Classification'].value_counts())
    print(f"\nUnique classifications: {adata_filtered.obs['Classification'].unique()}")
    print(f"\nmost_likely_hypothesis:")
    print(adata_filtered.obs['most_likely_hypothesis'].value_counts())

    # QC by HTO class - violin plots with the individual cells overlaid
    qc_df = adata_filtered.obs[['n_genes_by_counts', 'total_counts', 'Classification']].copy()

    fig, axes = plt.subplots(1, 2, figsize=(12, 5), dpi=300)

    for ax, metric in zip(axes, ('n_genes_by_counts', 'total_counts')):
        # `density_norm` replaces the deprecated `scale` argument; `bw_adjust`
        # and `legend=False`, already used here, need the same seaborn release.
        sns.violinplot(data=qc_df, x='Classification', y=metric, ax=ax,
                       inner='quartile', hue='Classification', legend=False,
                       cut=0, density_norm='count', bw_adjust=0.5)
        sns.stripplot(data=qc_df, x='Classification', y=metric, ax=ax,
                      color='black', alpha=0.4, size=1.5, jitter=True)
        ax.set_title(f"{metric} by Classification")
        ax.tick_params(axis='x', rotation=90)

    plt.tight_layout()
    fig.savefig(pdf_dir / f"{OUTPUT_PREFIX}_post_demux_qc_violin.pdf", bbox_inches='tight')
    fig.savefig(png_dir / f"{OUTPUT_PREFIX}_post_demux_qc_violin.png", bbox_inches='tight')
    plt.show()
    plt.close(fig)  # release the figure before building the next one

    print("Saved post-demux QC violin plots")

    # CLR normalize HTO data for ridgeplot
    print("\n--- CLR normalizing HTO data for visualization ---")
    hto_features = adata_filtered.uns['HTO_features']
    hto_columns = [f"HTO_{hto}" for hto in hto_features]

    # Extract HTO counts matrix (cells x HTOs) from the obs columns
    hto_matrix = adata_filtered.obs[hto_columns].values

    # CLR normalization: log(x / geometric_mean(x)), computed per cell
    # (axis=1) across its HTOs. Pseudocount of 1 avoids log(0).
    hto_matrix_pseudo = hto_matrix + 1
    geometric_mean = np.exp(np.mean(np.log(hto_matrix_pseudo), axis=1, keepdims=True))
    hto_clr = np.log(hto_matrix_pseudo / geometric_mean)

    # Add CLR-normalized HTO columns
    for i, hto in enumerate(hto_features):
        adata_filtered.obs[f"HTO_{hto}_CLR"] = hto_clr[:, i]

    print("‚úÖ CLR normalization completed")

    # Ridgeplot of CLR-normalized HTO expression
    print("\n--- Creating HTO ridgeplot (CLR-normalized) ---")

    # Get all classifications sorted (excluding Not_demultiplexed and Negative)
    all_classifications = sorted([c for c in adata_filtered.obs['Classification'].unique()
                                 if c not in ['Not_demultiplexed', 'Negative']])

    # Two-column grid sized from the actual number of HTOs. Six HTOs give the
    # former fixed 3x2 layout at the same figure size; other counts no longer
    # overflow the grid or leave empty panels.
    n_htos = len(hto_features)
    n_grid_cols = 2
    n_grid_rows = max(1, (n_htos + n_grid_cols - 1) // n_grid_cols)
    fig, axes = plt.subplots(n_grid_rows, n_grid_cols,
                             figsize=(14, 10 * n_grid_rows / 3), dpi=300,
                             squeeze=False)  # always a 2-D array of axes
    axes = axes.flatten()  # Flatten to 1D for easier iteration

    for i, hto in enumerate(hto_features):
        ax = axes[i]
        hto_col_clr = f"HTO_{hto}_CLR"

        # Plot all classifications (singlets + doublets, excluding Negative)
        for classification in all_classifications:
            subset = adata_filtered.obs[adata_filtered.obs['Classification'] == classification]
            if subset.empty:
                continue
            # Doublets are drawn in red; everything else uses the default cycle
            color = 'red' if classification == 'Doublet' else None
            sns.kdeplot(data=subset[hto_col_clr].values, ax=ax, label=classification,
                        fill=True, alpha=0.5, linewidth=1.5, color=color)

        ax.set_xlabel(f'{hto} CLR-normalized counts', fontsize=10)
        ax.set_ylabel('Density', fontsize=10)
        ax.set_title(f'Distribution of {hto} (CLR)', fontsize=11, fontweight='bold')
        ax.legend(loc='upper right', fontsize=7)
        # CLR value 0 corresponds to the cell's geometric-mean HTO level
        ax.axvline(x=0, color='gray', linestyle='--', alpha=0.5, linewidth=1)

    # Hide any panel left over when the number of HTOs is odd
    for unused_ax in axes[n_htos:]:
        unused_ax.set_visible(False)

    plt.tight_layout()

    ridge_png_path = png_dir / f"{OUTPUT_PREFIX}_hto_ridgeplot_clr.png"
    ridge_pdf_path = pdf_dir / f"{OUTPUT_PREFIX}_hto_ridgeplot_clr.pdf"

    fig.savefig(ridge_png_path, dpi=600, bbox_inches='tight')
    fig.savefig(ridge_pdf_path, dpi=600, bbox_inches='tight')
    print(f"HTO ridgeplot saved as:\n- {ridge_png_path}\n- {ridge_pdf_path}")

    plt.show()
    plt.close(fig)

else:
    print("\n‚è≠Ô∏è  Skipping HTO classification summary (HashSolo not run)")

In [None]:
# Bar chart of cell counts per singlet HTO classification - only if HashSolo was run
if not (USE_HASHSOLO and has_hto):
    print("\n‚è≠Ô∏è  Skipping HTO barplot (HashSolo not run)")
else:
    # Cells per classification label after final filtering
    classification_counts = adata_filtered.obs['Classification'].value_counts()

    # Drop the non-singlet labels, then order the remaining ones alphabetically
    non_singlet_labels = ('Doublet', 'Negative', 'Not_demultiplexed')
    all_classifications = [
        label for label in classification_counts.index.tolist()
        if label not in non_singlet_labels
    ]
    singlet_order = sorted(all_classifications)
    ordered_counts = classification_counts.reindex(singlet_order)

    # One bar per singlet label
    fig, ax = plt.subplots(figsize=(8, 5), dpi=300)
    bar_positions = range(len(ordered_counts))
    bars = ax.bar(bar_positions, ordered_counts.values, color='#1f77b4')

    # Axis labels, title and tick labels
    ax.set_xlabel('HTO Classification', fontsize=12, fontweight='bold')
    ax.set_ylabel('Cell Count', fontsize=12, fontweight='bold')
    ax.set_title('Cell Counts per HTO Classification (Singlets Only)', fontsize=14, fontweight='bold')
    ax.set_xticks(bar_positions)
    ax.set_xticklabels(ordered_counts.index, rotation=45, ha='right')

    # Write each count just above its bar
    for bar, count in zip(bars, ordered_counts.values):
        bar_center = bar.get_x() + bar.get_width()/2.
        ax.text(bar_center, bar.get_height(),
                f'{int(count)}',
                ha='center', va='bottom', fontsize=9)

    # Light horizontal grid drawn behind the bars
    ax.grid(axis='y', alpha=0.3, linestyle='--')
    ax.set_axisbelow(True)

    plt.tight_layout()

    # Write the figure to disk, then display it
    barplot_png_path = png_dir / f"{OUTPUT_PREFIX}_hto_classification_barplot.png"
    barplot_pdf_path = pdf_dir / f"{OUTPUT_PREFIX}_hto_classification_barplot.pdf"

    try:
        fig.savefig(barplot_png_path, dpi=600, bbox_inches='tight')
        fig.savefig(barplot_pdf_path, dpi=600, bbox_inches='tight')
        print(f"Barplot saved as:\n- {barplot_png_path}\n- {barplot_pdf_path}")
    except Exception as e:
        raise OSError(f"Failed to save barplot: {e}")

    plt.show()
    plt.close(fig)

    # Text summary of the same numbers (singlets only)
    print("\n--- Cell counts per HTO classification (singlets only) ---")
    for label, count in ordered_counts.items():
        print(f"{label}: {count}")

### Save QC-Filtered Data
Save the QC-filtered AnnData object before clustering. If HashSolo was run, keep only singlets. This is the clean, quality-controlled dataset with raw counts.

In [None]:
# Choose which filtered dataset to save
if not (USE_HASHSOLO and has_hto):
    # HashSolo was not run: the QC-filtered cells are the final set
    adata_final = adata_filtered.copy()

    output_file = results_dir / f"{OUTPUT_PREFIX}_adata_qc_filtered.h5ad"
    adata_final.write(output_file)
    print(f"\n‚úÖ Saved QC-filtered AnnData with {adata_final.n_obs} cells to:\n{output_file}")
else:
    # HashSolo was run: every label other than Doublet/Negative is a singlet
    all_classifications = adata_filtered.obs['Classification'].unique()
    singlet_labels = [
        label for label in all_classifications
        if label not in ['Doublet', 'Negative']
    ]

    print(f"Singlet labels: {singlet_labels}")

    # Boolean mask over cells, then an independent copy of the selection
    is_singlet = adata_filtered.obs["Classification"].isin(singlet_labels)
    adata_final = adata_filtered[is_singlet].copy()
    print(f"\nFiltered from {adata_filtered.n_obs} to {adata_final.n_obs} singlet cells")

    output_file = results_dir / f"{OUTPUT_PREFIX}_adata_qc_singlets.h5ad"
    adata_final.write(output_file)
    print(f"\n‚úÖ Saved singlets-only AnnData with {adata_final.n_obs} cells to:\n{output_file}")

## Clustering and Visualization for Contamination Detection

### Purpose
Perform quick clustering and visualization to identify and filter out contaminating cells (non-microglia cells). After reviewing the clusters and marker gene expression, you can select only microglia cells for downstream scVI integration.

**Important**: Raw counts are preserved in `adata.layers['counts']` while normalized data is used for visualization only. This ensures compatibility with scVI integration in the next step.


### Store Raw Counts and Normalize for Visualization
Preserve raw counts in a layer, then normalize and log-transform the data for clustering and visualization.


In [None]:
# Keep an untouched copy of the count matrix in a layer (for later scVI use)
adata_final.layers['counts'] = adata_final.X.copy()
print(f"‚úì Stored raw counts in adata.layers['counts']")

# Normalize each cell to 10,000 total counts, then log1p-transform.
# Both steps overwrite adata_final.X; the 'counts' layer is not modified.
sc.pp.normalize_total(adata_final, target_sum=1e4)
sc.pp.log1p(adata_final)
print(f"‚úì Normalized to 10,000 counts per cell")
print(f"‚úì Log-transformed (log1p)")

# Report the largest value in the transformed matrix as a quick check
normalized_matrix = adata_final.X
if hasattr(normalized_matrix, 'max'):
    max_val = normalized_matrix.max()
else:
    max_val = normalized_matrix.data.max()
print(f"\n‚úì Normalized X matrix max value: {max_val:.2f}")
print(f"‚úì Raw counts preserved in adata.layers['counts']")
print(f"‚úì X matrix will be used for clustering and visualization only")


### Highly Variable Genes, PCA, and Clustering
Identify highly variable genes, regress out the cell cycle scores, scale the data, compute PCA, construct the neighbor graph and UMAP, and perform Leiden clustering at multiple resolutions.


In [None]:
# Configuration
n_top_genes = 3000   # number of highly variable genes to select
n_pcs = 50           # principal components to compute and use for neighbors
n_neighbors = 30     # neighborhood size for the kNN graph

# Clustering resolutions to test
clustering_resolutions = [0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 2.2, 2.4, 2.6, 2.8, 3.0]

print("=" * 60)
print("HIGHLY VARIABLE GENES AND PCA")
print("=" * 60)

# Identify highly variable genes. The 'seurat_v3' flavor is run on the count
# data kept in the 'counts' layer, not on the log-normalized X.
sc.pp.highly_variable_genes(adata_final, n_top_genes=n_top_genes, flavor='seurat_v3', layer='counts')
print(f"‚úì Identified {adata_final.var['highly_variable'].sum()} highly variable genes")

# Regress out cell cycle effects (operates on X, across all genes)
print(f"\n‚úì Regressing out cell cycle effects (S_score, G2M_score)...")
sc.pp.regress_out(adata_final, ['S_score', 'G2M_score'])
print(f"‚úì Cell cycle effects regressed out")

# Scale data to unit variance and zero mean, clipping values at 10
print(f"‚úì Scaling data...")
sc.pp.scale(adata_final, max_value=10)
print(f"‚úì Data scaled (max_value=10)")

# Compute PCA on scaled data using HVGs.
# `mask_var` supersedes the deprecated `use_highly_variable=True`; it is
# available in the scanpy release that the `flavor="igraph"` Leiden call
# below already requires.
sc.tl.pca(adata_final, n_comps=n_pcs, mask_var='highly_variable')
print(f"‚úì Computed PCA: {adata_final.obsm['X_pca'].shape}")

print("\n" + "=" * 60)
print("NEIGHBOR GRAPH AND UMAP")
print("=" * 60)

# Compute neighbors
sc.pp.neighbors(adata_final, n_neighbors=n_neighbors, n_pcs=n_pcs)
print(f"‚úì Computed neighbor graph (k={n_neighbors}, n_pcs={n_pcs})")

# Compute UMAP
sc.tl.umap(adata_final)
print(f"‚úì Computed UMAP: {adata_final.obsm['X_umap'].shape}")

print("\n" + "=" * 60)
print("LEIDEN CLUSTERING")
print("=" * 60)

# Perform Leiden clustering at multiple resolutions; each result is stored in
# its own obs column named leiden_r<resolution>.
for res in clustering_resolutions:
    key = f'leiden_r{res}'
    sc.tl.leiden(adata_final, resolution=res, key_added=key, flavor="igraph",
                 n_iterations=2, directed=False)
    n_clusters = len(adata_final.obs[key].cat.categories)
    print(f'‚úì Resolution {res:4.1f} -> {n_clusters:3d} clusters (adata.obs["{key}"])')

print("\n" + "=" * 60)

### Visualization: Multi-Resolution Clustering Overview
Visualize clustering results across all resolutions to identify optimal granularity.


In [None]:
# Used to convert palette colours to hex strings
import matplotlib.colors as mcolors

# ============================================================
# USER CONFIGURATION
# ============================================================
point_size = 20         # Size of points in UMAP
n_cols_grid = 7         # Number of columns in grid layout
plot_name = 'leiden_clustering_overview'
color_palette = 'husl'  # Color palette for clusters (options: 'husl', 'tab20', 'Set3', etc.)

# Resolutions to visualize (already defined earlier)
# clustering_resolutions = [0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 2.2, 2.4, 2.6, 2.8, 3.0]

# ============================================================
# GENERATE MULTI-RESOLUTION PLOT
# ============================================================

# Verify UMAP exists
if 'X_umap' not in adata_final.obsm:
    raise ValueError("X_umap not found. Run clustering section first.")

# Every requested resolution must already have a Leiden column in obs
clustering_keys = [f'leiden_r{res}' for res in clustering_resolutions]
missing_keys = [key for key in clustering_keys if key not in adata_final.obs]
if missing_keys:
    raise ValueError(f"Missing clustering results: {missing_keys}")

# Setup grid layout (ceil division gives enough rows for all resolutions).
# squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so the
# flatten() below works for any n_cols_grid / number of resolutions.
n_rows = (len(clustering_resolutions) + n_cols_grid - 1) // n_cols_grid
fig, axes = plt.subplots(n_rows, n_cols_grid,
                         figsize=(3 * n_cols_grid, 3 * n_rows),
                         sharex=True, sharey=True, dpi=300,
                         squeeze=False)
axes = axes.flatten()

# Plot each resolution
for i, res in enumerate(clustering_resolutions):
    key = f'leiden_r{res}'
    n_clusters = len(adata_final.obs[key].cat.categories)

    # One colour per cluster, stored in uns under the '<key>_colors' entry
    custom_palette = sns.color_palette(color_palette, n_clusters)
    adata_final.uns[f'{key}_colors'] = [mcolors.to_hex(c) for c in custom_palette]

    sc.pl.umap(
        adata_final,
        color=key,
        ax=axes[i],
        show=False,
        title=f'Resolution {res} ({n_clusters} clusters)',
        legend_loc='on data',
        frameon=False,
        size=point_size,
        palette=adata_final.uns[f'{key}_colors']
    )

# Hide unused subplots
for unused_ax in axes[len(clustering_resolutions):]:
    unused_ax.set_visible(False)

plt.suptitle("Leiden Clustering at Multiple Resolutions", y=1.02, fontsize=18)
plt.tight_layout()

# Save plots (before displaying)
fig.savefig(png_dir / f"{OUTPUT_PREFIX}_{plot_name}.png", dpi=300, bbox_inches='tight')
fig.savefig(pdf_dir / f"{OUTPUT_PREFIX}_{plot_name}.pdf", bbox_inches='tight')
print(f"‚úì Saved {plot_name}")

plt.show()
plt.close(fig)  # close this figure explicitly rather than "the current one"

### Selected Resolution Clustering
Visualize clustering results at a specific resolution for detailed inspection. Choose your preferred resolution based on the multi-resolution overview above.

In [None]:
# Ensure inline plotting in Jupyter
%matplotlib inline

# ============================================================
# USER CONFIGURATION - SINGLE RESOLUTION
# ============================================================
selected_resolution = 1.0  # Resolution to plot; must match an existing 'leiden_r<res>' column
point_size = 40            # Size of points in UMAP
plot_width = 5             # Figure width (in inches)
plot_height = 5            # Figure height (in inches)
color_palette = 'husl'     # Seaborn palette name for cluster colors (e.g. 'husl', 'tab20', 'Set3')
show_legend_on_data = True # True: cluster labels drawn on the UMAP; False: legend in the right margin

# ============================================================
# VERIFY CLUSTERING EXISTS
# ============================================================

cluster_key = f'leiden_r{selected_resolution}'

if cluster_key not in adata_final.obs:
    # List the column names as-is. Parsing the suffix with float() would raise
    # an unrelated error for any 'leiden_r*' column with a non-numeric suffix
    # and hide the message the user actually needs.
    available_leiden = [col for col in adata_final.obs.columns if col.startswith('leiden_r')]
    raise ValueError(
        f"Clustering column '{cluster_key}' not found.\n"
        f"Available resolutions: {', '.join(available_leiden)}"
    )

if 'X_umap' not in adata_final.obsm:
    raise ValueError("X_umap not found. Run clustering section first.")

# Get cluster information
n_clusters = len(adata_final.obs[cluster_key].cat.categories)
print(f"‚úì Plotting clustering at resolution {selected_resolution}")
print(f"  Clusters: {n_clusters}")
print(f"  Total cells: {adata_final.n_obs:,}")

# ============================================================
# GENERATE SINGLE RESOLUTION PLOT
# ============================================================

# Only needed to convert the seaborn RGB tuples into hex strings
import matplotlib.colors as mcolors

# One color per cluster, stored as hex strings under '<cluster_key>_colors'
custom_palette = sns.color_palette(color_palette, n_clusters)
adata_final.uns[f'{cluster_key}_colors'] = [mcolors.to_hex(c) for c in custom_palette]

# Global matplotlib/seaborn settings; they stay in effect for later cells too
plt.rcParams['figure.dpi'] = 300
plt.rcParams['font.family'] = 'Arial'
sns.set_style('white')

# Create figure
fig, ax = plt.subplots(figsize=(plot_width, plot_height), dpi=300)

# Plot UMAP
sc.pl.umap(
    adata_final,
    color=cluster_key,
    ax=ax,
    show=False,
    title=f'Leiden Clustering (Resolution {selected_resolution}, n={n_clusters} clusters)',
    legend_loc='on data' if show_legend_on_data else 'right margin',
    frameon=False,
    size=point_size,
    palette=adata_final.uns[f'{cluster_key}_colors']
)

# Address the figure explicitly rather than relying on pyplot's "current figure"
fig.tight_layout()

# Save plots with resolution in filename
plot_name = f'leiden_clustering_r{selected_resolution}'
png_path = png_dir / f'{OUTPUT_PREFIX}_{plot_name}.png'
pdf_path = pdf_dir / f'{OUTPUT_PREFIX}_{plot_name}.pdf'

fig.savefig(png_path, dpi=300, bbox_inches='tight')
fig.savefig(pdf_path, bbox_inches='tight')

print(f'\n‚úì Saved: {plot_name}.png/.pdf')
print(f'  PNG: {png_path}')
print(f'  PDF: {pdf_path}')

plt.show()
plt.close(fig)

print(f'\n‚úì Visualization complete')
print('=' * 60)

### Visualization: QC Metrics on UMAP

In [None]:
# ============================================================
# USER CONFIGURATION - QC METRICS
# ============================================================
point_size = 20         # Size of points in UMAP
n_cols_grid = 3         # Number of columns in grid layout
plot_name = 'qc_metrics_umap'

# QC variables to visualize (one UMAP panel each)
qc_vars = ['n_genes_by_counts', 'total_counts', 'pct_counts_mt', 
           'pct_counts_ribo', 'pct_counts_hb']

# ============================================================
# GENERATE QC METRICS PLOT
# ============================================================

# Verify UMAP exists
if 'X_umap' not in adata_final.obsm:
    raise ValueError("X_umap not found. Run clustering section first.")

# Keep only variables that can actually be colored by (an obs column or a
# var name). A missing one would otherwise abort the figure halfway through
# with an error raised from inside the plotting call.
qc_vars_available = [var for var in qc_vars
                     if var in adata_final.obs.columns or var in adata_final.var_names]
qc_vars_missing = [var for var in qc_vars if var not in qc_vars_available]
if qc_vars_missing:
    print(f"Warning: QC variables not found and skipped: {', '.join(qc_vars_missing)}")
if not qc_vars_available:
    raise ValueError("None of the requested QC variables were found in adata_final.")

# Grid layout: ceiling division gives enough rows for every variable
n_rows = (len(qc_vars_available) + n_cols_grid - 1) // n_cols_grid
# squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, where
# plt.subplots would otherwise return a bare Axes that has no .flatten()
fig, axes = plt.subplots(n_rows, n_cols_grid, 
                         figsize=(3 * n_cols_grid, 3 * n_rows), 
                         sharex=True, sharey=True, dpi=300,
                         squeeze=False)
axes = axes.flatten()

print("‚úì Visualizing QC metrics on UMAP...")

# Plot each QC variable
for i, var in enumerate(qc_vars_available):
    sc.pl.umap(adata_final, color=var, ax=axes[i], show=False, 
               frameon=False, size=point_size)

# Hide the trailing grid cells that have no variable to show
for j in range(len(qc_vars_available), len(axes)):
    axes[j].set_visible(False)

# Address the figure explicitly rather than relying on pyplot's "current figure"
fig.suptitle("QC Metrics", y=1.02, fontsize=18)
fig.tight_layout()

# Save plots
fig.savefig(png_dir / f"{OUTPUT_PREFIX}_{plot_name}.png", dpi=300, bbox_inches='tight')
fig.savefig(pdf_dir / f"{OUTPUT_PREFIX}_{plot_name}.pdf", bbox_inches='tight')
print(f"‚úì Saved {plot_name}")

plt.show()
plt.close(fig)

In [None]:
# Ensure inline plotting in Jupyter
%matplotlib inline

# ============================================================
# USER CONFIGURATION - MARKER GENES
# ============================================================
# Settings for the marker-gene UMAP grid: one panel per gene listed in
# marker_genes_dict below.

# Plot dimensions and style
point_size = 10         # Size of points in UMAP
n_cols_grid = 8         # Number of columns in grid layout
plot_width = 15         # Width of the whole figure (in inches)
plot_height = 8        # Height of the whole figure (in inches)
plot_name = 'marker_genes_umap'

# Colormap options:
# - Custom 'grey_plasma': the lowest 20% of the color range is a light-grey
#   gradient, the remainder follows matplotlib's reversed plasma colormap
# - Any standard matplotlib colormap name ('viridis', 'plasma', 'magma', 'Blues', 'Reds', ...)
use_custom_colormap = True  # Set to False to use the standard colormap below
standard_colormap = 'viridis'  # Used if use_custom_colormap = False

# Marker genes for different cell types (mouse gene symbols).
# Keys are display labels; the dict order determines the panel order.
marker_genes_dict = {
    # General markers (pan-immune, proliferation, myeloid)
    'General': [
        'Ptprc',  # CD45 - all immune cells
        'Mki67',  # Proliferation marker
        'Itgam'   # CD11b - myeloid cells
    ],
    # Microglia markers
    'Microglia': [
        'Cx3cr1',  # Fractalkine receptor
        'P2ry12',  # Purinergic receptor
        'Tmem119', # Microglia-specific
        'Trem2'    # Homeostatic microglia
    ],
    # Macrophage / border-associated macrophage (BAM) markers
    'Macrophages/BAMs': [
        'Mrc1',    # CD206
        'Ccr2',    # CCR2
        'Ly6c2',
        'Ms4a7',
        'Pf4'    
    ],
    # T cell markers
    'T cells': [
        'Cd3e',    # CD3 epsilon
        'Cd3d',    # CD3 delta
        'Cd4',     # CD4+ T cells
        'Cd8b1'    # CD8 beta
    ],
    # B cell markers
    'B cells': [
        'Cd79a',   # B cell receptor component
        'Ms4a1',   # CD20
        'Cd19',    # B cell marker
        'Cd79b'    # B cell receptor component
    ],
    # NK cell markers
    'NK cells': [
        'Ncr1',    # NKp46
        'Nkg7',    # NK granule protein
        'Klrb1c',  # NK1.1
        'Klrd1'    # CD94
    ],
    # Neuron markers
    'Neurons': [
        'Rbfox3',  # NeuN
        'Tubb3',   # Beta-3 tubulin
        'Syt1',    # Synaptotagmin
        'Map2'     # Microtubule-associated protein 2
    ],
    # Astrocyte markers
    'Astrocytes': [
        'Gfap',    # Glial fibrillary acidic protein
        'Aldh1l1', # Aldehyde dehydrogenase
        'Aqp4',    # Aquaporin 4
        'S100b'    # Astrocyte marker
    ],
    # Oligodendrocyte markers
    'Oligodendrocytes': [
        'Mbp',     # Myelin basic protein
        'Mog',     # Myelin oligodendrocyte glycoprotein
        'Plp1',    # Proteolipid protein 1
        'Cnp'      # 2',3'-cyclic nucleotide 3'-phosphodiesterase
    ],
    # Endothelial markers
    'Endothelial': [
        'Pecam1',  # CD31
        'Vwf',     # von Willebrand factor
        'Cdh5',    # VE-cadherin
        'Flt1'     # VEGFR1
    ]
}

# ============================================================
# CHOOSE COLORMAP (custom grey-plasma or a standard one)
# ============================================================
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.pyplot as plt
import numpy as np

if not use_custom_colormap:
    # A plain matplotlib colormap name is passed through unchanged
    colormap = standard_colormap
    print(f'‚úì Using standard colormap: {colormap}')
else:
    # Sample the reversed plasma colormap at n_colors evenly spaced points
    n_colors = 256
    plasma_r = plt.cm.plasma_r
    plasma_colors = plasma_r(np.linspace(0, 1, n_colors))

    # The lowest 20% of the range is replaced by a grey ramp
    grey_to_plasma_transition = 0.2
    n_grey_colors = int(n_colors * grey_to_plasma_transition)

    # Ramp from light grey up to the plasma color where the two parts meet
    light_grey = np.array([0.85, 0.85, 0.85, 1.0])
    transition_color = plasma_colors[n_grey_colors]
    grey_gradient = np.linspace(light_grey, transition_color, n_grey_colors)

    # Grey ramp first, then the untouched upper part of plasma_r
    custom_colors = np.vstack([grey_gradient, plasma_colors[n_grey_colors:]])
    colormap = LinearSegmentedColormap.from_list('grey_plasma', custom_colors)

    print('‚úì Created custom grey-plasma colormap for gene expression')

# ============================================================
# VERIFY DATA AND LAYERS
# ============================================================
# Sanity checks before plotting expression: the UMAP must exist and
# adata_final.X must not look like raw counts.

# Verify UMAP exists
if 'X_umap' not in adata_final.obsm:
    raise ValueError("X_umap not found. Run clustering section first.")

# Check data normalization status
print("\n" + "=" * 60)
print("DATA LAYER VERIFICATION")
print("=" * 60)
# Dense numpy arrays and scipy sparse matrices both provide .max()
max_val = adata_final.X.max()
print(f"adata_final.X max value: {max_val:.2f}")

# Heuristic only: a maximum above this value is treated as raw counts.
# The error message reports the same threshold that is actually tested.
raw_counts_threshold = 50
if max_val > raw_counts_threshold:
    raise ValueError(
        f"adata_final.X appears to contain raw counts (max={max_val:.0f}). "
        f"Expected normalized/log-transformed data (max <= {raw_counts_threshold}). "
        "Please run the normalization step first."
    )
print(f"‚úì adata_final.X contains normalized, log-transformed data (max={max_val:.2f})")

# Check available layers
print(f"\nAvailable layers: {list(adata_final.layers.keys())}")
if 'counts' in adata_final.layers:
    print("‚úì Raw counts preserved in adata.layers['counts']")

# Specify which data to use for plotting
use_layer = None  # None = use adata.X (normalized, log-transformed)
print(f"\n‚úì Will use adata_final.X for gene expression visualization")
print(f"  (normalized, log-transformed data)")
print("=" * 60 + "\n")

# ============================================================
# VERIFY GENES
# ============================================================
# Split the configured markers into those present in adata_final.var_names
# and those that are not, and report the result per cell type.

# Flatten all marker genes into a single list (dict order is preserved)
all_marker_genes = [gene for genes in marker_genes_dict.values() for gene in genes]

# Check which genes are available in the dataset
available_genes = [g for g in all_marker_genes if g in adata_final.var_names]
missing_genes = [g for g in all_marker_genes if g not in adata_final.var_names]

# Report markers by cell type. Every cell type is listed, including those
# whose markers are all missing, so a fully absent category cannot go unnoticed.
print("=" * 60)
print("MARKER GENE AVAILABILITY")
print("=" * 60)
available_by_type = {}
for cell_type, genes in marker_genes_dict.items():
    available = [g for g in genes if g in adata_final.var_names]
    missing_type = [g for g in genes if g not in adata_final.var_names]
    if available:
        # Only cell types with at least one usable marker are recorded here
        available_by_type[cell_type] = available
    print(f"\n{cell_type}:")
    print(f"  Available: {', '.join(available) if available else 'none'}")
    if missing_type:
        print(f"  Missing: {', '.join(missing_type)}")

print(f"\n{'=' * 60}")
print(f"SUMMARY:")
print(f"  Total marker genes defined: {len(all_marker_genes)}")
print(f"  Available in dataset: {len(available_genes)}")
print(f"  Missing from dataset: {len(missing_genes)}")
if missing_genes:
    print(f"  Missing genes: {', '.join(missing_genes)}")
print(f"{'=' * 60}\n")

# Nothing to plot if no configured marker exists in the dataset
if not available_genes:
    raise ValueError('No valid marker genes found in dataset for plotting.')

# ============================================================
# CONFIGURE PLOT STYLE
# ============================================================

# Global matplotlib/seaborn settings; they stay in effect for later cells too
plt.rcParams['figure.dpi'] = 300
plt.rcParams['font.family'] = 'Arial'
sns.set_style('white')

# ============================================================
# GENERATE MARKER GENES UMAP
# ============================================================

print("‚úì Generating marker gene UMAP plots...")

# Grid layout: ceiling division gives enough rows for every available gene
n_genes = len(available_genes)
n_rows = (n_genes + n_cols_grid - 1) // n_cols_grid

# The figure size comes from plot_width/plot_height (it is not derived from
# the grid). squeeze=False keeps `axes` a 2-D array even for a 1x1 grid,
# where plt.subplots would otherwise return a bare Axes without .flatten().
fig, axes = plt.subplots(n_rows, n_cols_grid, 
                         figsize=(plot_width, plot_height),
                         sharex=True, sharey=True, dpi=300,
                         squeeze=False)
axes = axes.flatten()

# Plot each marker gene using normalized data
for i, gene in enumerate(available_genes):
    sc.pl.umap(adata_final, 
               color=gene, 
               ax=axes[i], 
               show=False, 
               title=gene, 
               frameon=False, 
               cmap=colormap,  # Custom grey-plasma or standard colormap
               size=point_size,
               layer=use_layer,  # None = use adata.X (normalized)
               vmin=0)  # Start colormap at 0

# Hide the trailing grid cells that have no gene to show
for j in range(len(available_genes), len(axes)):
    axes[j].set_visible(False)

colormap_name = 'Grey-Plasma' if use_custom_colormap else standard_colormap.title()
# Address the figure explicitly rather than relying on pyplot's "current figure"
fig.suptitle(f"Cell Type Marker Genes Expression (Normalized, {colormap_name})", 
             y=1.005, fontsize=22, fontweight='bold')
fig.tight_layout()

# Save plots
png_path = png_dir / f"{OUTPUT_PREFIX}_{plot_name}.png"
pdf_path = pdf_dir / f"{OUTPUT_PREFIX}_{plot_name}.pdf"
fig.savefig(png_path, dpi=300, bbox_inches='tight')
fig.savefig(pdf_path, bbox_inches='tight')
print(f"‚úì Saved: {plot_name}.png/.pdf")
print(f"  PNG: {png_path}")
print(f"  PDF: {pdf_path}")

plt.show()
plt.close(fig)

print(f"\n‚úì Visualization complete - {len(available_genes)} marker genes plotted")
print(f"  Using normalized, log-transformed data from adata_final.X")
print(f"  Colormap: {colormap_name}")
print("=" * 60)

### Marker Gene Dotplot
Generate a comprehensive dotplot showing marker gene expression across all clusters at a selected resolution. The dotplot uses a dendrogram to organize clusters by similarity and displays both mean expression levels and the percentage of cells expressing each gene.

In [None]:
# Ensure inline plotting in Jupyter
%matplotlib inline

# ============================================================
# USER CONFIGURATION - DOTPLOT
# ============================================================
selected_res = 1.0  # Leiden resolution whose clusters are shown in the dotplot
plot_width = 5      # Figure width (in inches)
plot_height = 10    # Figure height (in inches)

# ============================================================
# VERIFY DATA AND LAYERS
# ============================================================

# The clustering for the chosen resolution must already be in adata.obs
cluster_key = f'leiden_r{selected_res}'
if cluster_key not in adata_final.obs:
    available_leiden = [col for col in adata_final.obs.columns if col.startswith('leiden_r')]
    raise ValueError(
        f"Clustering column '{cluster_key}' not found in adata.obs.\n"
        f"Available resolutions: {', '.join(available_leiden)}"
    )

print(f"‚úì Using clustering: {cluster_key}")
n_clusters = len(adata_final.obs[cluster_key].cat.categories)
print(f"  Number of clusters: {n_clusters}")

# Largest value in adata.X, used below as a raw-counts heuristic
if hasattr(adata_final.X, 'max'):
    max_val = adata_final.X.max()
else:
    max_val = adata_final.X.data.max()
print(f"\nadata_final.X max value: {max_val:.2f}")

# A maximum above 50 is treated as un-normalized data
if max_val > 50:
    raise ValueError(
        f"adata_final.X appears to contain raw counts (max={max_val:.0f}). "
        "Expected normalized/log-transformed data."
    )

# layer=None makes the dotplot read expression from adata.X
use_layer = None
print("‚úì Using adata_final.X for gene expression (normalized, log-transformed)")

# ============================================================
# PREPARE GENE LIST
# ============================================================

# All genes from marker_genes_dict (defined in the marker-gene cell),
# flattened in dictionary order
genes = [gene for gene_list in marker_genes_dict.values() for gene in gene_list]

# Single pass: sort each gene into "present in the dataset" or "absent"
available_genes = []
missing_genes = []
for gene in genes:
    if gene in adata_final.var_names:
        available_genes.append(gene)
    else:
        missing_genes.append(gene)

print(f"\n{'=' * 60}")
print(f"GENE AVAILABILITY FOR DOTPLOT")
print(f"{'=' * 60}")
print(f"  Total genes in list: {len(genes)}")
print(f"  Available in dataset: {len(available_genes)}")
print(f"  Missing from dataset: {len(missing_genes)}")

if missing_genes:
    print(f"\n  Missing genes: {', '.join(missing_genes)}")

# Nothing to plot if none of the configured genes exist in the dataset
if not available_genes:
    raise ValueError('No valid genes for dotplot.')

print(f"{'=' * 60}\n")

# ============================================================
# CREATE CUSTOM GREY-PLASMA COLORMAP
# ============================================================
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.pyplot as plt
import numpy as np

# Sample the reversed plasma colormap at n_colors evenly spaced points
n_colors = 256
plasma_r = plt.cm.plasma_r
plasma_colors = plasma_r(np.linspace(0, 1, n_colors))

# The lowest 20% of the range is replaced by a grey ramp
grey_to_plasma_transition = 0.2
n_grey_colors = int(n_colors * grey_to_plasma_transition)

# Ramp from light grey up to the plasma color where the two parts meet
light_grey = np.array([0.85, 0.85, 0.85, 1.0])
transition_color = plasma_colors[n_grey_colors]
grey_gradient = np.linspace(light_grey, transition_color, n_grey_colors)

# Grey ramp first, then the untouched upper part of plasma_r
custom_colors = np.vstack([grey_gradient, plasma_colors[n_grey_colors:]])
dotplot_colormap = LinearSegmentedColormap.from_list('grey_plasma', custom_colors)

print('‚úì Created custom grey-plasma colormap for dotplot')

# Global matplotlib/seaborn settings; they stay in effect for later cells too
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['figure.dpi'] = 300
sns.set_style('white')

# ============================================================
# COMPUTE DENDROGRAM FOR CURRENT CLUSTERING
# ============================================================
# The dendrogram is recomputed for the selected clustering because the
# dotplot below is drawn with dendrogram=True for the same groupby key.
print(f'\n‚úì Computing dendrogram for {cluster_key}...')
sc.tl.dendrogram(adata_final, groupby=cluster_key)
print('‚úì Dendrogram computed')

# ============================================================
# GENERATE DOTPLOT
# ============================================================

print(f'\n‚úì Generating dotplot...')

# The dotplot is drawn into this pre-created axes so the figure size and dpi
# are controlled here; return_fig=True returns the plot object instead of
# showing it, which allows the styling tweaks below.
fig, ax = plt.subplots(figsize=(plot_width, plot_height), dpi=300)
dotplot = sc.pl.dotplot(
    adata_final,
    var_names=available_genes,
    groupby=cluster_key,
    layer=use_layer,
    dendrogram=True,
    ax=ax,
    return_fig=True,
    dot_max=0.8,
    dot_min=0.05,
    colorbar_title='Mean expression\nin group',
    size_title='% of cells\nexpressing gene',
    cmap=dotplot_colormap,  # Use custom grey-plasma colormap
    swap_axes=True,
    var_group_rotation=90
)

# Remove dot borders (line width 0) for a cleaner look
# NOTE(review): presumably get_axes() is what actually renders the plot into
# the figure - confirm against the scanpy DotPlot documentation.
main_ax = dotplot.get_axes()['mainplot_ax']
for collection in main_ax.collections:
    collection.set_linewidths(0)  # Remove dot borders for cleaner look

# Force straight (unrotated, centered) x tick labels with increased font size
main_ax.tick_params(axis='x', labelsize=12, rotation=0)
for label in main_ax.get_xticklabels():
    label.set_rotation(0)
    label.set_ha('center')

# Ensure ticks, tick labels and the bottom/left spines are visible
main_ax.tick_params(axis='both', bottom=True, left=True, labelbottom=True, 
                    labelleft=True, length=4, direction='out')
main_ax.spines['bottom'].set_visible(True)
main_ax.spines['left'].set_visible(True)

plt.suptitle(f'Gene Expression Dotplot (Normalized, r={selected_res})', 
             y=1.02, fontsize=18, fontweight='bold')
plt.tight_layout()

# Save plots (resolution is part of the file name)
png_path = png_dir / f'{OUTPUT_PREFIX}_marker_dotplot_r{selected_res}.png'
pdf_path = pdf_dir / f'{OUTPUT_PREFIX}_marker_dotplot_r{selected_res}.pdf'
fig.savefig(png_path, dpi=300, bbox_inches='tight')
fig.savefig(pdf_path, bbox_inches='tight')
print(f'\n‚úì Saved: marker_dotplot_r{selected_res}.png/.pdf')
print(f'  PNG: {png_path}')
print(f'  PDF: {pdf_path}')

plt.show()
plt.close()

# Final summary of what was plotted
print(f'\n‚úì Dotplot complete')
print(f'  Resolution: {selected_res}')
print(f'  Clusters: {n_clusters}')
print(f'  Genes plotted: {len(available_genes)}')
print('=' * 60)

### Filter Out Contaminating Cells
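Remove non-microglia cell clusters based on marker gene expression analysis. After reviewing the clustering and marker gene plots above, specify which clusters represent contaminating cell types in `clusters_to_remove`; cells in those clusters are dropped and all remaining cells are kept.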

In [None]:
# Ensure inline plotting in Jupyter
%matplotlib inline

# ============================================================
# USER CONFIGURATION - CELL FILTERING
# ============================================================
# Resolution whose cluster labels are used to identify cells to remove
filter_resolution = 1.0

# Clusters to REMOVE (contaminating cells), chosen from the marker gene analysis
# Example: clusters_to_remove = ['0', '5', '12']
clusters_to_remove = ['11']

# NOTE: Leave empty [] to skip filtering (keep all cells)

# Cluster labels are compared as strings, so IDs given as integers are
# normalized here instead of failing later in ', '.join(...) and never
# matching any label.
clusters_to_remove = [str(cluster) for cluster in clusters_to_remove]

# ============================================================
# VERIFY CLUSTERING AND EXECUTE FILTERING
# ============================================================

if clusters_to_remove:
    # Verify clustering column exists
    cluster_key = f'leiden_r{filter_resolution}'
    if cluster_key not in adata_final.obs:
        available_leiden = [col for col in adata_final.obs.columns if col.startswith('leiden_r')]
        raise ValueError(
            f"Clustering column '{cluster_key}' not found.\n"
            f"Available resolutions: {', '.join(available_leiden)}"
        )

    # Reject cluster IDs that do not exist at this resolution. Without this
    # check a mistyped ID removes nothing, and the unfiltered data would be
    # carried forward as if it had been filtered.
    available_clusters = [str(c) for c in adata_final.obs[cluster_key].cat.categories]
    unknown_clusters = [c for c in clusters_to_remove if c not in available_clusters]
    if unknown_clusters:
        raise ValueError(
            f"Clusters not found in '{cluster_key}': {', '.join(unknown_clusters)}.\n"
            f"Available clusters: {', '.join(available_clusters)}"
        )
    
    print("=" * 60)
    print("CELL FILTERING")
    print("=" * 60)
    print(f"Using clustering: {cluster_key}")
    print(f"Clusters to REMOVE: {', '.join(clusters_to_remove)}")
    
    # Get cell counts before filtering
    n_cells_before = adata_final.n_obs
    cluster_counts_before = adata_final.obs[cluster_key].value_counts().sort_index()
    
    print(f"\nBefore filtering:")
    print(f"  Total cells: {n_cells_before:,}")
    print(f"  Total clusters: {len(cluster_counts_before)}")
    print(f"\n  Cluster distribution:")
    for cluster, count in cluster_counts_before.items():
        marker = " ‚ùå REMOVE" if str(cluster) in clusters_to_remove else ""
        print(f"    Cluster {cluster}: {count:,} cells{marker}")
    
    # Perform filtering - keep cells NOT in clusters_to_remove.
    # .copy() makes the result an independent object rather than a view.
    mask = ~adata_final.obs[cluster_key].isin(clusters_to_remove)
    adata_filtered_microglia = adata_final[mask].copy()
    
    # Get cell counts after filtering
    n_cells_after = adata_filtered_microglia.n_obs
    n_cells_removed = n_cells_before - n_cells_after
    cluster_counts_after = adata_filtered_microglia.obs[cluster_key].value_counts().sort_index()
    
    print(f"\nAfter filtering:")
    print(f"  Total cells: {n_cells_after:,}")
    print(f"  Cells removed: {n_cells_removed:,} ({100*n_cells_removed/n_cells_before:.1f}%)")
    print(f"  Remaining clusters: {len(cluster_counts_after)}")
    print(f"\n  Remaining cluster distribution:")
    for cluster, count in cluster_counts_after.items():
        print(f"    Cluster {cluster}: {count:,} cells")
    
    print("\n" + "=" * 60)
    print("‚úì Filtering complete")
    print("=" * 60 + "\n")
    
else:
    print("=" * 60)
    print("CELL FILTERING")
    print("=" * 60)
    print("‚è≠Ô∏è  No clusters specified for removal (clusters_to_remove is empty)")
    print("   Keeping all cells in adata_final")
    print("=" * 60 + "\n")
    
    # No filtering - copy so later steps never modify adata_final
    adata_filtered_microglia = adata_final.copy()

# ============================================================
# VISUALIZE FILTERED DATA
# ============================================================
# Side-by-side UMAPs of the data before and after cluster removal.
# Skipped entirely when no clusters were selected for removal.

if not clusters_to_remove:
    print("‚è≠Ô∏è  Skipping visualization (no filtering performed)")
else:
    print("‚úì Generating confirmation UMAP...")

    point_size = 40
    fig, axes = plt.subplots(1, 2, figsize=(16, 6), dpi=300)

    # Left panel: all cells; right panel: cells kept after filtering.
    # Both panels are colored by the clustering used for the filter.
    comparison_panels = (
        (adata_final, 'Before Filtering'),
        (adata_filtered_microglia, 'After Filtering'),
    )
    for panel_ax, (panel_adata, panel_label) in zip(axes, comparison_panels):
        sc.pl.umap(
            panel_adata,
            color=cluster_key,
            ax=panel_ax,
            show=False,
            title=f'{panel_label} (n={panel_adata.n_obs:,})',
            frameon=False,
            size=point_size,
            legend_loc='on data',
        )

    removed_text = ', '.join(clusters_to_remove)
    plt.suptitle(
        f'Cell Filtering Confirmation (Removed clusters: {removed_text})',
        y=1.02,
        fontsize=18,
        fontweight='bold',
    )
    plt.tight_layout()

    # Write the comparison figure as PNG and PDF
    png_path = png_dir / f'{OUTPUT_PREFIX}_filtering_confirmation.png'
    pdf_path = pdf_dir / f'{OUTPUT_PREFIX}_filtering_confirmation.pdf'
    fig.savefig(png_path, dpi=300, bbox_inches='tight')
    fig.savefig(pdf_path, bbox_inches='tight')
    print(f'\n‚úì Saved filtering confirmation plot')
    print(f'  PNG: {png_path}')
    print(f'  PDF: {pdf_path}')

    plt.show()
    plt.close()

    print(f'\n‚úì Visualization complete')
    print('=' * 60)

# ============================================================
# VERIFY FILTERED DATA
# ============================================================
# Print the dimensions, data layers and key metadata of the filtered object.

summary_rule = "=" * 60

print("\n" + summary_rule)
print("FILTERED DATA SUMMARY")
print(summary_rule)
print(f"Final cell count: {adata_filtered_microglia.n_obs:,}")
print(f"Gene count: {adata_filtered_microglia.n_vars:,}")

print("\nData layers:")
print("  - adata.X: Normalized, log-transformed")
if 'counts' in adata_filtered_microglia.layers:
    print("  - adata.layers['counts']: Raw counts (for scVI)")

# For each metadata column, show the first unique value
print("\nMetadata columns:")
for meta_col in ('library', 'batch', 'mouse', 'sex'):
    first_value = adata_filtered_microglia.obs[meta_col].unique()[0]
    print(f"  - {meta_col}: {first_value}")
print(summary_rule)

### Save Filtered Microglia Data for Integration
Prepare the filtered microglia dataset for downstream scVI integration by:
1. Restoring raw counts to `adata.X` (from `adata.layers['counts']`)
2. Removing preprocessing artifacts (UMAP, PCA, clustering results)
3. Preserving essential metadata (library, batch, sex, mouse)
4. Saving a clean dataset ready for integration with other samples

In [None]:
# ============================================================
# PREPARE FILTERED DATA FOR INTEGRATION
# ============================================================

print("=" * 60)
print("PREPARING FILTERED DATA FOR INTEGRATION")
print("=" * 60)

# Create a copy for final save
adata_microglia_clean = adata_filtered_microglia.copy()

# ============================================================
# RESTORE RAW COUNTS TO adata.X
# ============================================================

print("\n1. Restoring raw counts to adata.X...")

if 'counts' in adata_microglia_clean.layers:
    # Restore raw counts from layer
    adata_microglia_clean.X = adata_microglia_clean.layers['counts'].copy()
    print(f"   ‚úì Restored raw counts from adata.layers['counts']")
    
    # Verify raw counts
    max_val = adata_microglia_clean.X.max() if hasattr(adata_microglia_clean.X, 'max') else adata_microglia_clean.X.data.max()
    print(f"   ‚úì adata.X max value: {max_val:.0f} (raw counts)")
    
    if max_val < 50:
        print(f"   ‚ö†Ô∏è  Warning: Max value seems low for raw counts. Please verify.")
else:
    raise ValueError("adata.layers['counts'] not found. Cannot restore raw counts!")

# ============================================================
# REMOVE ALL LAYERS (raw counts now in adata.X)
# ============================================================

print("\n2. Removing all layers...")

layer_names = list(adata_microglia_clean.layers.keys())
for layer_name in layer_names:
    del adata_microglia_clean.layers[layer_name]
    
if layer_names:
    print(f"   ‚úì Removed layers: {', '.join(layer_names)}")
else:
    print(f"   ‚úì No layers found")

# ============================================================
# REMOVE ALL OBSM (dimensionality reduction, etc.)
# ============================================================

print("\n3. Removing all obsm entries...")

obsm_keys = list(adata_microglia_clean.obsm.keys())
for key in obsm_keys:
    del adata_microglia_clean.obsm[key]
    
if obsm_keys:
    print(f"   ‚úì Removed obsm entries: {', '.join(obsm_keys)}")
else:
    print(f"   ‚úì No obsm entries found")

# ============================================================
# REMOVE ALL VARM (PCA loadings, etc.)
# ============================================================

print("\n4. Removing all varm entries...")

varm_keys = list(adata_microglia_clean.varm.keys())
for key in varm_keys:
    del adata_microglia_clean.varm[key]
    
if varm_keys:
    print(f"   ‚úì Removed varm entries: {', '.join(varm_keys)}")
else:
    print(f"   ‚úì No varm entries found")

# ============================================================
# REMOVE CLUSTERING AND HVG FROM OBS/VAR
# ============================================================

print("\n5. Removing clustering and HVG annotations...")

artifacts_removed = []

# Remove clustering columns from obs
leiden_cols = [col for col in adata_microglia_clean.obs.columns if col.startswith('leiden_r')]
for col in leiden_cols:
    del adata_microglia_clean.obs[col]
if leiden_cols:
    artifacts_removed.append(f'{len(leiden_cols)} leiden columns from obs')

# Remove HVG-related columns from var
hvg_cols = ['highly_variable', 'means', 'dispersions', 'dispersions_norm', 
            'highly_variable_rank', 'highly_variable_nbatches']
for col in hvg_cols:
    if col in adata_microglia_clean.var.columns:
        del adata_microglia_clean.var[col]
        artifacts_removed.append(f'{col} from var')

if artifacts_removed:
    print(f"   ‚úì Removed: {', '.join(artifacts_removed)}")
else:
    print(f"   ‚úì No clustering/HVG columns found")

# ============================================================
# CLEAN UP UNS (keep only essential metadata)
# ============================================================

print("\n6. Cleaning uns dictionary...")

# Keep only essential uns keys (if any)
essential_uns = []  # Add any essential keys here if needed

# Build the removal list up front so uns is not mutated while being iterated.
uns_keys_to_remove = [key for key in adata_microglia_clean.uns.keys()
                       if key not in essential_uns]

for key in uns_keys_to_remove:
    del adata_microglia_clean.uns[key]

if uns_keys_to_remove:
    # Keys are stringified for display so a non-string key cannot break the
    # join after the deletions have already happened.
    removed_names = ', '.join(map(str, uns_keys_to_remove))
    print(f"   ✓ Removed {len(uns_keys_to_remove)} uns entries: {removed_names}")
else:
    print("   ✓ No uns entries to remove")

# ============================================================
# REMOVE OBSP AND VARP (neighbor graphs, etc.)
# ============================================================

print("\n7. Removing obsp and varp entries...")

# Remove obsp (neighbor connectivities, distances).
# Keys are snapshotted first so the mapping is not mutated during iteration.
obsp_keys = list(adata_microglia_clean.obsp.keys())
for key in obsp_keys:
    del adata_microglia_clean.obsp[key]
if obsp_keys:
    print(f"   ✓ Removed obsp entries: {', '.join(obsp_keys)}")
else:
    print("   ✓ No obsp entries found")

# Remove varp (same snapshot-then-delete pattern as obsp above).
varp_keys = list(adata_microglia_clean.varp.keys())
for key in varp_keys:
    del adata_microglia_clean.varp[key]
if varp_keys:
    print(f"   ✓ Removed varp entries: {', '.join(varp_keys)}")
else:
    print("   ✓ No varp entries found")

# ============================================================
# VERIFY ESSENTIAL METADATA
# ============================================================

print("\n8. Verifying essential metadata columns...")

# Columns that must be present in obs before the object is saved.
required_metadata = ['library', 'batch', 'sex', 'mouse']
missing_metadata = [col for col in required_metadata if col not in adata_microglia_clean.obs.columns]

# Fail loudly rather than save an object that lacks a required column.
if missing_metadata:
    raise ValueError(f"Missing required metadata columns: {', '.join(missing_metadata)}")

print("   ✓ All required metadata columns present:")
for col in required_metadata:
    unique_vals = adata_microglia_clean.obs[col].unique()
    # Show the value itself when the column is constant; otherwise only
    # report how many distinct values it holds.
    if len(unique_vals) == 1:
        print(f"     - {col}: {unique_vals[0]}")
    else:
        print(f"     - {col}: {len(unique_vals)} unique values")

# ============================================================
# VERIFY QC METRICS ARE PRESERVED
# ============================================================

print("\n9. Verifying QC metrics...")

# QC columns looked for in obs; any subset being present counts as success.
qc_columns = ['n_genes_by_counts', 'total_counts', 'pct_counts_mt',
              'pct_counts_ribo', 'pct_counts_hb']
present_qc = [col for col in qc_columns if col in adata_microglia_clean.obs.columns]

# This step only warns; it never aborts the cell.
if present_qc:
    print(f"   ✓ QC metrics preserved: {', '.join(present_qc)}")
else:
    print("   ⚠️  Warning: No QC metrics found in obs")

# ============================================================
# VERIFY CELL CYCLE SCORES ARE PRESERVED
# ============================================================

print("\n10. Verifying cell cycle scores...")

# Cell cycle columns looked for in obs; any subset being present counts as success.
cc_columns = ['S_score', 'G2M_score', 'phase']
present_cc = [col for col in cc_columns if col in adata_microglia_clean.obs.columns]

# This step only warns; it never aborts the cell.
if present_cc:
    print(f"   ✓ Cell cycle scores preserved: {', '.join(present_cc)}")
else:
    print("   ⚠️  Warning: No cell cycle scores found in obs")

# ============================================================
# FINAL DATA SUMMARY
# ============================================================

# Print a read-only overview of the cleaned object: dimensions, the X
# matrix, the remaining keys of every annotation slot, and the obs/var
# column names. Nothing is modified here.
_summary_rule = "=" * 60

print("\n" + _summary_rule)
print("CLEAN DATA SUMMARY")
print(_summary_rule)
print(f"Cells: {adata_microglia_clean.n_obs:,}")
print(f"Genes: {adata_microglia_clean.n_vars:,}")

_summary_matrix = adata_microglia_clean.X
print("\nadata.X:")
print(f"  Type: {type(_summary_matrix).__name__}")
print(f"  Shape: {_summary_matrix.shape}")
print("  Data: Raw counts ONLY (for scVI integration)")

# One row per annotation slot: (heading, mapping, expected state).
_summary_slots = (
    ("\nLayers", adata_microglia_clean.layers, "should be empty"),
    ("Obsm keys", adata_microglia_clean.obsm, "should be empty"),
    ("Varm keys", adata_microglia_clean.varm, "should be empty"),
    ("Obsp keys", adata_microglia_clean.obsp, "should be empty"),
    ("Varp keys", adata_microglia_clean.varp, "should be empty"),
    ("Uns keys", adata_microglia_clean.uns, "should be minimal/empty"),
)
for _heading, _mapping, _expectation in _summary_slots:
    print(f"{_heading}: {list(_mapping.keys())} ({_expectation})")

_summary_obs_columns = list(adata_microglia_clean.obs.columns)
print(f"\nMetadata (obs) columns: {len(_summary_obs_columns)}")
print(f"  Key columns: {', '.join(required_metadata)}")
print(f"  All obs columns: {_summary_obs_columns}")

_summary_var_columns = list(adata_microglia_clean.var.columns)
print(f"\nGene metadata (var) columns: {len(_summary_var_columns)}")
print(f"  All var columns: {_summary_var_columns}")

print(_summary_rule)

# ============================================================
# SAVE CLEAN MICROGLIA DATA
# ============================================================

print("\n" + "=" * 60)
print("SAVING CLEAN MICROGLIA DATA")
print("=" * 60)

# Generate filename from the configured output prefix.
# NOTE(review): results_dir is defined outside this section; the `/` join
# implies a pathlib-style path object - confirm against its definition.
output_filename = f"{OUTPUT_PREFIX}_QC_filtered_microglia_only.h5ad"
output_path = results_dir / output_filename

# Save the data
adata_microglia_clean.write(output_path)

print("\n✓ Saved filtered microglia data:")
print(f"  File: {output_path}")
print(f"  Cells: {adata_microglia_clean.n_obs:,}")
print(f"  Genes: {adata_microglia_clean.n_vars:,}")
print("  Data: Raw counts ONLY in adata.X")
print("  Ready for: scVI integration with other samples")

print("\n" + "=" * 60)
print("✓ DATA PREPARATION COMPLETE")
print("=" * 60)

# ============================================================
# QUICK VERIFICATION CHECK
# ============================================================

# Each entry pairs a checklist label with the result of an actual check on
# the cleaned object, so a mark printed here cannot contradict the warnings
# emitted by the QC / cell cycle verification steps earlier in this cell.
_checklist_obs_cols = adata_microglia_clean.obs.columns
_checklist_var_cols = adata_microglia_clean.var.columns
verification_checks = [
    ("All layers removed", len(adata_microglia_clean.layers) == 0),
    ("No UMAP/PCA (obsm empty)", len(adata_microglia_clean.obsm) == 0),
    ("No neighbor graphs (obsp empty)", len(adata_microglia_clean.obsp) == 0),
    ("No HVG annotations", 'highly_variable' not in _checklist_var_cols),
    ("No clustering results",
     not any(col.startswith('leiden_r') for col in _checklist_obs_cols)),
    (f"Metadata columns preserved: {', '.join(required_metadata)}",
     all(col in _checklist_obs_cols for col in required_metadata)),
    # present_qc / present_cc are the lists built by steps 9 and 10 above.
    ("QC metrics preserved", bool(present_qc)),
    ("Cell cycle scores preserved", bool(present_cc)),
]
all_checks_passed = all(passed for _, passed in verification_checks)

print("\n📋 Quick Verification Checklist:")
# NOTE(review): the contents of X are not re-inspected in this section; this
# line restates what the pipeline is expected to have set up - presumably in
# an earlier step outside this section.
print("  ✓ Raw counts in adata.X")
for label, passed in verification_checks:
    print(f"  {'✓' if passed else '✗'} {label}")
# Only claim readiness when every check above actually passed.
print(f"  {'✓' if all_checks_passed else '✗'} Ready for integration")

print("\n🎯 Next Steps:")
print("  1. Process other samples using this same pipeline")
print("  2. Load all filtered samples")
print("  3. Concatenate datasets")
print("  4. Run scVI integration")
print("  5. Perform downstream analysis")

print("\n💾 File size should be minimal (only raw counts + metadata)")
print("=" * 60)

### Session Info and Outputs
Record environment, key package versions, saved outputs, and timestamp.


In [None]:
# Session info
# Prints the environment label, Python version, versions of the key
# packages, and a completion timestamp.
# importlib.metadata (stdlib) replaces the deprecated pkg_resources API.
from importlib import metadata as importlib_metadata

venv_name = "scRNAseq-scVI"  # adjust if needed
print(f"Venv: {venv_name}")
print(f"Python: {sys.version.split()[0]}")
print("Packages:")
for pkg in ['anndata', 'scanpy', 'scvi-tools', 'mudata', 'muon']:
    # A package that is not installed is reported instead of raising, so the
    # remaining versions and the timestamp are still printed.
    try:
        pkg_version = importlib_metadata.version(pkg)
    except importlib_metadata.PackageNotFoundError:
        pkg_version = "not installed"
    print(f"{pkg}: {pkg_version}")
print(f"Completed: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
