# Single-Cell RNA-seq Analysis with Scanpy & scvi-tools

This notebook provides a comprehensive workflow for analyzing single-cell RNA-seq data using Scanpy and scvi-tools. The pipeline loads multiple .h5ad files, concatenates them, and applies scVI for batch correction and covariate adjustment.

**Key Features**:
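- Loading and concatenation of the .h5ad files listed in `files_to_load` from the inputs folder
- Batch correction using scVI (configurable via `BATCH_KEY`; set to 'library' in this notebook)
- Categorical covariates: optional, set via `CATEGORICAL_COVARIATES` (empty by default)
- Continuous covariates: optional, set via `CONTINUOUS_COVARIATES` (empty by default), e.g. cell cycle scores ('S_score', 'G2M_score')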
- Support for RNA and HTO data
- Raw data preservation for reproducibility
- Automated visualization generation
- Generic design for reuse across different datasets

**Outputs**:
- Integrated AnnData objects (HVG subset and full genes)
- Training loss visualizations
- Comprehensive data structure documentation

## Setup and Configuration

**Purpose**: Import required libraries and configure analysis parameters. Set up directory structure and plotting preferences.

In [None]:
# === CONFIGURATION PARAMETERS ===
# Dataset-specific knobs; adjust these before running the rest of the notebook.

# --- Directories ---
# Folder that holds the input .h5ad files.
INPUT_DIR = 'inputs'
# Folder that receives every generated output file.
OUTPUT_DIR = 'outputs'

# --- Analysis parameters ---
# How many highly variable genes to keep.
N_TOP_GENES = 3_000
# Upper bound on training epochs.
# NOTE(review): the original comment named totalVI, but the visible setup code uses scvi.model.SCVI - confirm.
MAX_EPOCHS = 400
# Training mini-batch size (tune to the available memory).
BATCH_SIZE = 256
# Number of epochs to wait before early stopping.
EARLY_STOPPING_PATIENCE = 20

# --- HTO features ---
# Placeholder; intended to be extracted from the data.
HTO_FEATURES = None

# === JUPYTER NOTEBOOK CONFIGURATION ===
# The two lines below are IPython magics: they only work inside a
# Jupyter/IPython kernel and are not valid plain-Python syntax.
# Select matplotlib's inline backend for this notebook.
%matplotlib inline
# Set the inline backend's figure format to 'retina'.
%config InlineBackend.figure_format = 'retina'

# === LIBRARY IMPORTS ===
import anndata as ad
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc
import scvi
import mudata as md
import torch
import muon
import sys, pkg_resources, datetime
from pathlib import Path
from IPython.display import display, HTML

# === DIRECTORY SETUP ===
# Wrap the configured folder names in Path objects once; later cells reuse both.
input_dir, output_dir = Path(INPUT_DIR), Path(OUTPUT_DIR)
# save_plot() writes into one sub-folder per image format, so create both up front.
for _fmt in ("png", "pdf"):
    (output_dir / _fmt).mkdir(parents=True, exist_ok=True)

# === SCANPY CONFIGURATION ===
# Default resolution and size for figures produced through scanpy.
sc.set_figure_params(dpi=300, figsize=(6, 4))

# === PLOTTING UTILITIES ===
def configure_plot_style():
    """Apply the notebook-wide matplotlib/seaborn look.

    Sets the figure DPI to 100 and the font family to Arial through
    ``plt.rcParams``, then switches seaborn to its 'white' style.
    """
    plt.rcParams.update({
        'figure.dpi': 100,
        'font.family': 'Arial',
    })
    sns.set_style('white')

def show_inline_plot(fig=None):
    """Display a plot inline in the Jupyter notebook.

    Args:
        fig: Figure to display. When ``None`` the current pyplot figure is
            shown via ``plt.show()``; otherwise ``fig`` is passed to
            ``display``.
    """
    # Lay out the figure that is actually going to be shown. The previous
    # unconditional plt.tight_layout() always acted on the *current* pyplot
    # figure, which is not necessarily the `fig` passed in.
    if fig is None:
        plt.tight_layout()
    else:
        fig.tight_layout()

    # NOTE(review): the opening and closing <div> are emitted as separate
    # display() calls; whether the notebook front-end treats them as one
    # wrapper around the figure (i.e. whether centring takes effect) should be
    # confirmed.
    display(HTML('<div style="display: flex; justify-content: center;">'))
    if fig is None:
        plt.show()
    else:
        display(fig)
    display(HTML('</div>'))

def save_plot(name):
    """Write the current pyplot figure to ``output_dir`` as PNG and PDF, then close it.

    Args:
        name: File stem; the figure is written to ``png/<name>.png`` and
            ``pdf/<name>.pdf`` below ``output_dir``.
    """
    plt.tight_layout()
    # Only the PNG output is given an explicit resolution.
    per_format_kwargs = {'png': {'dpi': 300}, 'pdf': {}}
    for fmt, extra in per_format_kwargs.items():
        target = output_dir / fmt / f"{name}.{fmt}"
        plt.savefig(target, bbox_inches='tight', **extra)
    print(f"Saved: {name}.png/.pdf")
    plt.close()

## Data Loading

**Purpose**: Load single-cell datasets from input directory and add batch metadata for downstream integration.

In [None]:
# === LOAD SPECIFIC .H5AD FILES ===
print("=== Loading Individual Datasets ===\n")

# (file name, experimental group) for each dataset, in concatenation order.
files_to_load = [
    ('PBS-CFA_QC_filtered_microglia_only.h5ad', 'PBS-CFA'),
    ('PEP15-CFA_QC_filtered_microglia_only.h5ad', 'PEP15-CFA')
]

# Load each file as a separate AnnData object.
adata1 = sc.read_h5ad(input_dir / files_to_load[0][0])
print(f"adata1 (PBS-CFA): {adata1.shape[0]:,} cells, {adata1.shape[1]:,} genes")

adata2 = sc.read_h5ad(input_dir / files_to_load[1][0])
print(f"adata2 (PEP15-CFA): {adata2.shape[0]:,} cells, {adata2.shape[1]:,} genes")

# === CONCATENATE ALL DATASETS ===
print(f"\n=== Joining Datasets ===")
adata_list = [adata1, adata2]
adata = ad.concat(adata_list, join='outer', merge='same', index_unique='_')

print(f"Total cells: {adata.shape[0]:,}")
print(f"Total genes: {adata.shape[1]:,}")

# === CREATE GROUP COLUMN ===
print(f"\n=== Creating Group Column ===")

# Assign the group positionally: one label per cell, repeated for as many
# cells as the corresponding input dataset contributed. This relies on
# ad.concat stacking the inputs in list order.
# The previous approach mapped only the *first* unique 'library' (or 'mouse')
# value of each dataset to its group, so any dataset with several distinct
# values (e.g. multiple mice) ended up with NaN groups for all the others.
adata.obs['group'] = [
    group
    for (_filename, group), part in zip(files_to_load, adata_list)
    for _cell in range(part.shape[0])
]

print(f"Group assignments:")
print(adata.obs['group'].value_counts())

# Display metadata summary for whichever annotation columns are present.
print(f"\n=== Metadata Summary ===")
for column in ('batch', 'library', 'mouse', 'sex'):
    if column in adata.obs.columns:
        print(f"\n{column.capitalize()} distribution:")
        print(adata.obs[column].value_counts())

print("\n✅ Data loading and group assignment complete!")

## Remove Outlier Sample

**Purpose**: Remove cells from the outlier sample 'HTO_PBS-4' from the mouse column. This sample showed abnormal characteristics and should be excluded from downstream analysis.

In [None]:
# === REMOVE OUTLIER SAMPLE ===
print("\n=== Removing Outlier Sample ===")

if 'mouse' not in adata.obs.columns:
    # Without the 'mouse' annotation the outlier cannot be identified;
    # adata is left untouched.
    print("⚠️ Warning: 'mouse' column not found in adata.obs")
else:
    # Remember the size before filtering so the removal can be reported.
    n_cells_before = adata.shape[0]

    # How many cells carry the outlier label?
    n_pbs4_cells = (adata.obs['mouse'] == 'HTO_PBS-4').sum()
    print(f"\nCells with HTO_PBS-4: {n_pbs4_cells:,}")

    # Keep everything that is not HTO_PBS-4; .copy() turns the view into a
    # standalone AnnData object.
    keep_mask = adata.obs['mouse'] != 'HTO_PBS-4'
    adata = adata[keep_mask].copy()

    n_cells_after = adata.shape[0]
    print(f"\nCells before filtering: {n_cells_before:,}")
    print(f"Cells after filtering: {n_cells_after:,}")
    print(f"Cells removed: {n_cells_before - n_cells_after:,}")

    # Re-print the per-column distributions for the filtered data.
    print(f"\n=== Updated Metadata Summary ===")
    for column in ('group', 'batch', 'library'):
        if column in adata.obs.columns:
            print(f"\n{column.capitalize()} distribution:")
            print(adata.obs[column].value_counts())

    # 'mouse' is known to exist in this branch, so no presence check is needed.
    print(f"\nMouse distribution:")
    print(adata.obs['mouse'].value_counts())

    print("\n✅ Outlier sample removed!")

## Set Categorical Order for Metadata Columns

**Purpose**: Define the display order for the categorical metadata columns (group, library, batch, mouse, sex) to ensure consistent visualization and analysis ordering throughout the workflow.

**Why this matters**: By setting ordered categories, plots and tables will display samples in a logical, consistent order rather than alphabetically or randomly.

In [None]:
# === SET CATEGORICAL ORDER FOR KEY COLUMNS ===
print("\n--- Setting Categorical Order ---")


def _set_category_order(column, order, label):
    """Cast ``adata.obs[column]`` to an ordered categorical following ``order``.

    Does nothing when the column is absent. Values that occur in the data but
    are not listed in ``order`` are appended after the listed categories and
    reported, because ``pd.Categorical`` would otherwise silently replace
    them with NaN.

    Args:
        column: Name of the ``adata.obs`` column to convert.
        order: Desired category order (list of values).
        label: Human-readable column label used in the printed messages.
    """
    if column not in adata.obs.columns:
        return

    values = adata.obs[column]
    unexpected = [v for v in values.dropna().unique() if v not in order]
    if unexpected:
        print(f"⚠️  {label}: values missing from the predefined order were kept and appended: {unexpected}")

    adata.obs[column] = pd.Categorical(
        values,
        categories=list(order) + unexpected,
        ordered=True
    )
    print(f"✅ {label} order set: {order}")


# Set group order
group_order = ['PBS-CFA', 'PEP15-CFA']
_set_category_order('group', group_order, 'Group')

# Set library order (explicit)
library_order = ['PBS-CFA', 'PEP15-CFA']
_set_category_order('library', library_order, 'Library')

# Set batch order (explicit)
batch_order = ['Mistri']
_set_category_order('batch', batch_order, 'Batch')

# Set mouse order (explicit)
mouse_order = [
    'HTO_PBS-1', 'HTO_PBS-2', 'HTO_PBS-3', 'HTO_PBS-5', 'HTO_PBS-6',  # PBS mice
    'HTO_PEP15-1', 'HTO_PEP15-2', 'HTO_PEP15-3', 'HTO_PEP15-4', 'HTO_PEP15-5', 'HTO_PEP15-6'  # PEP15 mice
]
_set_category_order('mouse', mouse_order, 'Mouse')

# Set sex order (explicit)
sex_order = ['female']
_set_category_order('sex', sex_order, 'Sex')

## Dataset Structure Inspection

**Purpose**: Examine dataset structure to verify data types, modalities, and metadata before integration.

In [None]:
# === DATASET STRUCTURE INSPECTION ===
# Read-only overview of the expression matrix, extra modalities and metadata.
print("Inspecting dataset structure...")

# === RNA DATA INSPECTION ===
print("\n--- RNA Expression Data ---")
print(f"Type: {type(adata.X)}")
print(f"Shape: {adata.X.shape}")
if hasattr(adata.X, 'nnz'):
    # Only matrices exposing `nnz` get a sparsity figure:
    # the share of entries that are not stored as non-zero.
    n_rows, n_cols = adata.X.shape
    sparsity = (1 - adata.X.nnz / (n_rows * n_cols)) * 100
    print(f"Sparsity: {sparsity:.2f}%")

# === MULTIMODAL DATA INSPECTION ===
print("\n--- Additional Modalities ---")
for modality in ['HTO', 'ADT', 'protein']:
    if modality not in adata.obsm:
        print(f"{modality}: Not found")
        continue
    print(f"{modality}: {adata.obsm[modality].shape}")
    # Feature names, when stored, live in .uns under '<modality>_features'.
    features_key = f'{modality}_features'
    if features_key in adata.uns:
        print(f"  Features: {adata.uns[features_key]}")

# === METADATA INSPECTION ===
print("\n--- Metadata Structure ---")
print(f"Metadata shape: {adata.obs.shape}")
print(f"Metadata columns: {len(adata.obs.columns)}")

# Preview the first ten metadata column names.
print("\nSample metadata columns:")
sample_cols = adata.obs.columns[:10].tolist()
print(sample_cols)

# === CATEGORICAL VARIABLES INSPECTION ===
print("\n--- Categorical Variables ---")
categorical_cols = adata.obs.select_dtypes(include=['category', 'object']).columns
# Limit the listing to the first five categorical/object columns.
for col in categorical_cols[:5]:
    unique_vals = adata.obs[col].unique()
    n_unique = len(unique_vals)
    print(f"{col}: {n_unique} unique values")
    if n_unique > 10:
        # Too many values to list in full; show a short sample instead.
        print(f"  Sample values: {unique_vals[:5]}...")
    else:
        print(f"  Values: {unique_vals}")

# === CHECK FOR REQUIRED QC COLUMNS ===
required_columns = ['pct_counts_mt', 'S_score', 'G2M_score']
print("\n--- Required QC Columns Check ---")
for col in required_columns:
    print(
        f"✅ {col}: Present"
        if col in adata.obs.columns
        else f"❌ {col}: Missing - will need to compute"
    )

# === CHECK FOR KEY METADATA COLUMNS ===
print("\n--- Key Metadata Columns ---")
for col in ['batch', 'library', 'mouse']:
    if col not in adata.obs.columns:
        print(f"❌ {col}: Not found")
        continue
    unique_count = adata.obs[col].nunique()
    print(f"✅ {col}: {unique_count} unique values")
    if unique_count <= 20:
        # Few enough values to print the full per-value cell counts.
        print(f"  Values: {adata.obs[col].value_counts().to_dict()}")

## Data Preparation and Metadata Verification

**Purpose**: Ensure cell identifiers are unique and check that the metadata columns used downstream (group, batch, library, mouse, and the cell cycle scores S_score/G2M_score) are present.

In [None]:
# === MAKE CELL IDENTIFIERS UNIQUE ===
adata.obs_names_make_unique()
assert adata.obs_names.is_unique, "Cell identifiers are not unique!"
print(f"✅ Cell identifiers are unique: {adata.shape[0]:,} cells")

# === VERIFY REQUIRED COLUMNS FOR SCVI ===
print("\n--- Verifying Required Columns ---")


def _sorted_observed(column):
    """Return the sorted, non-missing unique values of ``adata.obs[column]``.

    Missing values are dropped first: sorting a mix of strings and NaN raises
    a TypeError, and NaN appears whenever a value could not be matched to a
    category.
    """
    return sorted(adata.obs[column].dropna().unique().tolist())


# Check for group column
if 'group' in adata.obs.columns:
    observed_groups = _sorted_observed('group')
    print(f"✅ 'group' column: {adata.obs['group'].nunique()} unique groups")
    print(f"   Groups: {observed_groups}")
    print(f"   Group distribution:")
    for group in observed_groups:
        count = (adata.obs['group'] == group).sum()
        print(f"     - {group}: {count:,} cells")
    # Surface cells without a group instead of letting them go unnoticed.
    n_unassigned = int(adata.obs['group'].isna().sum())
    if n_unassigned:
        print(f"   ⚠️  {n_unassigned:,} cells have no group assigned")
else:
    print("❌ 'group' column not found")

# Check for batch column
if 'batch' in adata.obs.columns:
    print(f"\n✅ 'batch' column: {adata.obs['batch'].nunique()} unique batches")
    print(f"   Batches: {_sorted_observed('batch')}")
else:
    print("\n⚠️  'batch' column not found - will use 'library' as batch key for scVI")

# Check for library column
if 'library' in adata.obs.columns:
    print(f"\n✅ 'library' column: {adata.obs['library'].nunique()} unique libraries")
    print(f"   Libraries: {_sorted_observed('library')}")
else:
    print("\n❌ 'library' column not found")

# Check for mouse column
if 'mouse' in adata.obs.columns:
    print(f"\n✅ 'mouse' column: {adata.obs['mouse'].nunique()} unique mice")
else:
    print("\n⚠️  'mouse' column not found")

# Check for cell cycle scores (only the first message starts with a blank line)
for prefix, score_col in (("\n", 'S_score'), ("", 'G2M_score')):
    if score_col in adata.obs.columns:
        print(f"{prefix}✅ '{score_col}' column found")
    else:
        print(f"{prefix}❌ '{score_col}' column not found")

# === SUMMARY ===
print(f"\n=== Summary ===")
print(f"✅ Ready for scVI modeling with {adata.shape[0]:,} cells and {adata.shape[1]:,} genes")

# Determine batch key for scVI
# NOTE(review): the scVI setup cell further down defines its own BATCH_KEY
# ('library') and empty covariate lists, so the summary printed here may not
# describe the model that is actually configured - confirm which is intended.
batch_key = 'batch' if 'batch' in adata.obs.columns else 'library'
print(f"✅ Batch correction will use '{batch_key}' column")

# Determine categorical covariate
if 'library' in adata.obs.columns and batch_key != 'library':
    print(f"✅ Categorical covariate: 'library'")
elif 'group' in adata.obs.columns:
    print(f"✅ Categorical covariate: 'group'")

# Check for continuous covariates
if 'S_score' in adata.obs.columns and 'G2M_score' in adata.obs.columns:
    print(f"✅ Continuous covariates: 'S_score' and 'G2M_score'")

## Cell Count Visualization

**Purpose**: Generate publication-quality barplots showing cell counts per group, library, and mouse for quality assessment.


In [None]:
# === CELL COUNT VISUALIZATION ===
print("=== Cell Count Visualization ===\n")

# Configure plot style
configure_plot_style()


def _draw_cell_count_barplot(column, label, color, figsize, annot_fontsize):
    """Draw an annotated count barplot of ``adata.obs[column]`` on a new figure.

    Args:
        column: ``adata.obs`` column to count.
        label: Display name used in the title and x-axis label.
        color: Bar colour.
        figsize: Figure size passed to ``plt.figure``.
        annot_fontsize: Font size of the per-bar count annotations.
    """
    plt.figure(figsize=figsize)

    # Follow the column's category order when it is categorical.
    series = adata.obs[column]
    order = series.cat.categories if hasattr(series, 'cat') else None
    ax = sns.countplot(data=adata.obs, x=column, color=color, order=order)

    # Customize plot appearance
    plt.title(f'Cell Counts per {label}', fontsize=16, fontweight='bold', pad=20)
    plt.xlabel(label, fontsize=14, fontweight='bold')
    plt.ylabel('Number of Cells', fontsize=14, fontweight='bold')
    plt.xticks(rotation=45, ha='right', fontsize=12)
    plt.grid(axis='y', alpha=0.3, linestyle='--')

    # Annotate bars with exact cell counts
    for patch in ax.patches:
        height = patch.get_height()
        # NaN is the only value not equal to itself. Skip such bars, since
        # int(NaN) raises. NOTE(review): believed to occur for categories
        # without any cells in some seaborn versions - confirm.
        if height != height:
            continue
        ax.annotate(f'{int(height):,}',
                    (patch.get_x() + patch.get_width() / 2., height),
                    ha='center', va='bottom', fontsize=annot_fontsize, fontweight='bold',
                    color='black', xytext=(0, 8), textcoords='offset points')

    plt.tight_layout()


def _show_and_save_cell_counts(column, label, color, figsize, annot_fontsize):
    """Show the count barplot inline, then write it as ``cell_counts_<column>``.

    The figure is drawn twice on purpose: once for ``plt.show()`` and once for
    ``save_plot()``, which closes the figure it saves.
    """
    _draw_cell_count_barplot(column, label, color, figsize, annot_fontsize)
    plt.show()

    _draw_cell_count_barplot(column, label, color, figsize, annot_fontsize)
    save_plot(f'cell_counts_{column}')


# === PLOT: GROUP COUNTS ===
if 'group' in adata.obs.columns:
    group_counts = adata.obs['group'].value_counts()

    print(f"--- Group cell counts ---")
    print(group_counts)

    _show_and_save_cell_counts('group', 'Group', '#2E86AB', (10, 6), 10)

    print(f"\n✅ Group barplot saved as: cell_counts_group.png/.pdf")
    print(f"Total cells visualized: {adata.shape[0]:,}")
    print(f"Number of groups: {len(group_counts)}")

# === PLOT: LIBRARY COUNTS ===
if 'library' in adata.obs.columns:
    library_counts = adata.obs['library'].value_counts()

    print(f"\n--- Library cell counts ---")
    print(library_counts)

    _show_and_save_cell_counts('library', 'Library', '#FF6B35', (12, 6), 10)

    print(f"\n✅ Library barplot saved as: cell_counts_library.png/.pdf")
    print(f"Number of libraries: {len(library_counts)}")

# === PLOT: MOUSE COUNTS ===
if 'mouse' in adata.obs.columns:
    mouse_counts = adata.obs['mouse'].value_counts()

    print(f"\n--- Mouse cell counts ---")
    print(mouse_counts)

    # Many bars on this plot, hence the wider figure and smaller annotations.
    _show_and_save_cell_counts('mouse', 'Mouse', '#4ECDC4', (14, 6), 9)

    print(f"\n✅ Mouse barplot saved as: cell_counts_mouse.png/.pdf")
    print(f"Number of mice: {len(mouse_counts)}")

## Quality Control (Optional)

**Purpose**: This section can be used to visualize QC metrics by batch and filter out problematic samples if needed. You can skip to the next section if no additional filtering is required.

In [None]:
# === QUALITY CONTROL WITH VISUALIZATIONS ===
print("=== Quality Control Metrics ===\n")

# Configure plot style
configure_plot_style()


def _qc_figure_width(groupby_var):
    """Return the figure width for a grouping column (wider for per-mouse plots).

    The comparison previously used 'mouse_ID', a name the grouping loops never
    produce, so the wider layout was never applied.
    """
    return 14 if groupby_var == 'mouse' else 10


def _draw_qc_violins(groupby_var, metrics):
    """Draw one violin panel per QC metric, grouped by ``groupby_var``, on a new figure."""
    n_metrics = len(metrics)
    fig, axes = plt.subplots(n_metrics, 1, figsize=(_qc_figure_width(groupby_var), 5 * n_metrics))

    # With a single metric plt.subplots returns one Axes rather than an array.
    if n_metrics == 1:
        axes = [axes]

    for ax, metric in zip(axes, metrics):
        sc.pl.violin(
            adata,
            keys=metric,
            groupby=groupby_var,
            rotation=90,
            ax=ax,
            show=False
        )
        ax.set_title(f'{metric} by {groupby_var}', fontsize=12, fontweight='bold')
        ax.set_xlabel(groupby_var, fontsize=10)
        ax.set_ylabel(metric, fontsize=10)
        ax.grid(axis='y', alpha=0.3, linestyle='--')

    plt.tight_layout()


def _draw_phase_bars(groupby_var, phase_percentages):
    """Draw a stacked bar chart of cell cycle phase percentages on a new figure."""
    fig, ax = plt.subplots(figsize=(_qc_figure_width(groupby_var), 6))
    phase_percentages.plot(kind='bar', stacked=True, ax=ax, color=['#1f77b4', '#ff7f0e', '#2ca02c'])
    ax.set_title(f'Cell Cycle Phase Distribution by {groupby_var}', fontsize=14, fontweight='bold')
    ax.set_xlabel(groupby_var, fontsize=12)
    ax.set_ylabel('Percentage of Cells', fontsize=12)
    ax.legend(title='Phase', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.xticks(rotation=90, ha='right')
    plt.grid(axis='y', alpha=0.3, linestyle='--')
    plt.tight_layout()


# Determine which QC columns are available
qc_columns = [
    col
    for col in ['n_genes_by_counts', 'total_counts', 'pct_counts_mt', 'pct_counts_ribo']
    if col in adata.obs.columns
]

if not qc_columns:
    print("\n⚠️  No standard QC columns found. Skipping QC plots.")
else:
    print(f"--- Generating QC Plots for: {qc_columns} ---\n")

    # Generate QC plots for each grouping variable
    for groupby_var in ['group', 'batch', 'library', 'mouse']:
        if groupby_var not in adata.obs.columns:
            print(f"⚠️  '{groupby_var}' column not found. Skipping.")
            continue

        print(f"--- Cell Counts by {groupby_var} ---")
        print(adata.obs[groupby_var].value_counts().sort_index())
        print()

        # Drawn twice on purpose: once for inline display, once for
        # save_plot(), which closes the figure it saves.
        _draw_qc_violins(groupby_var, qc_columns)
        plt.show()

        _draw_qc_violins(groupby_var, qc_columns)
        save_plot(f"qc_metrics_by_{groupby_var}")

        print(f"✅ QC plots saved as: qc_metrics_by_{groupby_var}.png/.pdf\n")

# === CELL CYCLE DISTRIBUTION ===
if 'phase' in adata.obs.columns:
    print("--- Cell Cycle Distribution ---")

    for groupby_var in ['group', 'batch', 'library', 'mouse']:
        if groupby_var not in adata.obs.columns:
            continue

        print(f"\n{groupby_var.capitalize()} - Cell Cycle Phase Distribution:")

        # Row-normalised crosstab: percentage of each phase within every group.
        phase_counts = pd.crosstab(adata.obs[groupby_var], adata.obs['phase'], normalize='index') * 100
        print(phase_counts.round(2))

        # Show inline, then redraw for saving (same reason as above).
        _draw_phase_bars(groupby_var, phase_counts)
        plt.show()

        _draw_phase_bars(groupby_var, phase_counts)
        save_plot(f"cell_cycle_distribution_by_{groupby_var}")

        print(f"✅ Cell cycle plot saved as: cell_cycle_distribution_by_{groupby_var}.png/.pdf")

# === OPTIONAL FILTERING ===
# Nothing is filtered here; the lines below only print example snippets.
print("\n--- Optional Filtering ---")
print("To remove problematic samples, uncomment and modify:")
print("# Remove by group: adata = adata[~adata.obs['group'].isin(['NI'])].copy()")
print("# Remove by batch: adata = adata[~adata.obs['batch'].isin(['batch1'])].copy()")
print("# Remove by library: adata = adata[~adata.obs['library'].isin(['APP-PS1-1'])].copy()")
print("# Remove by mouse: adata = adata[~adata.obs['mouse'].isin(['mouse1'])].copy()")

print(f"\n✅ Quality control complete")
print(f"Dataset ready for modeling: {adata.shape[0]:,} cells, {adata.shape[1]:,} genes")

## Data Preparation and scVI Setup

**Purpose**: Prepare data for scVI modeling by creating separate datasets for training (HVG subset) and analysis (full genes), selecting highly variable genes, and configuring scVI with batch and covariate correction.

**Key Steps**:
- Create `adata_hvg`: HVG subset for efficient scVI training
- Create `adata_full`: Full gene dataset for downstream analysis
- Select the top `N_TOP_GENES` (3,000 by default) highly variable genes using the Seurat v3 method
- Store raw counts in `.layers['raw_counts']` for scVI modeling
- Setup scVI with batch and covariate correction

**scVI Configuration**:
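- **Batch key**: `BATCH_KEY` (set to `library` in the code below) - corrects for batch effects across libraries
- **Categorical covariates**: `CATEGORICAL_COVARIATES` (empty by default; e.g. `['sex']`) - accounts for additional categorical effects
- **Continuous covariates**: `CONTINUOUS_COVARIATES` (empty by default; e.g. `['S_score', 'G2M_score']` to correct for cell cycle effects)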

This configuration allows scVI to model and correct for technical variation while preserving biological variation of interest.

In [None]:
# === BACKUP RAW COUNTS ===
# Store raw RNA counts (with all genes) before HVG selection
print("Preparing data for scVI modeling...")
raw_counts_full = adata.X.copy()

# === CREATE SEPARATE COPIES FOR HVG AND FULL GENE DATASETS ===
adata_hvg = adata.copy()  # This will be subsetted to HVGs for training
adata_full = adata.copy()  # This preserves all genes for downstream analysis
adata_full.layers['raw_counts'] = adata.X.copy()  # Store raw counts

print(f"✅ Created data copies:")
print(f"  - adata_hvg: Will be subset to {N_TOP_GENES} HVGs for training")
print(f"  - adata_full: Preserves all {adata_full.shape[1]:,} genes")

# === STORE ORIGINAL VAR NAMES ===
original_var_names = adata_hvg.var_names.copy()

# === HIGHLY VARIABLE GENE SELECTION ===
print(f"\nSelecting {N_TOP_GENES} highly variable genes...")
sc.pp.highly_variable_genes(
    adata_hvg,
    n_top_genes=N_TOP_GENES,
    flavor="seurat_v3",
    batch_key='library',  # Batch-specific HVG selection
    subset=True
)

print(f"✅ HVG selection complete: {adata_hvg.shape[1]} genes selected")

# === RESTORE RAW COUNTS FOR HVG-SELECTED GENES ===
hvg_indices = [original_var_names.get_loc(gene) for gene in adata_hvg.var_names]
adata_hvg.layers['raw_counts'] = raw_counts_full[:, hvg_indices]

print(f"✅ Raw counts layer created for HVG subset")
print(f"  - Shape: {adata_hvg.layers['raw_counts'].shape}")

# === SCVI CONFIGURATION ===
# **EDIT THESE VARIABLES TO CHANGE SCVI SETUP**
# BATCH_KEY names the obs column used for batch correction; the two lists name
# optional obs columns passed to scVI as covariates (leave empty for none).
BATCH_KEY = 'library'
CATEGORICAL_COVARIATES = []  # e.g. ['sex']
CONTINUOUS_COVARIATES = []  # e.g. ['S_score', 'G2M_score']

# === SETUP SCVI MODEL WITH BATCH AND COVARIATE CORRECTION ===
# Report the chosen configuration before registering the data with scVI.
batch_column = adata_hvg.obs[BATCH_KEY]
print("\n--- Setting up scVI model ---")
print(f"  - Batch correction: Enabled (batch_key='{BATCH_KEY}')")
print(f"  - Number of batches: {batch_column.nunique()}")
print(f"  - Batches: {sorted(batch_column.unique().tolist())}")
print(f"  - Categorical covariates: {CATEGORICAL_COVARIATES if CATEGORICAL_COVARIATES else 'Not defined'}")
print(f"  - Continuous covariates: {CONTINUOUS_COVARIATES if CONTINUOUS_COVARIATES else 'Not defined'}")
print("  - Using raw counts from layer: 'raw_counts'")

# Mandatory arguments first; a covariate argument is added only when its list
# is non-empty, so an empty list is never forwarded to setup_anndata.
setup_kwargs = dict(adata=adata_hvg, layer='raw_counts', batch_key=BATCH_KEY)
optional_covariates = (
    ('categorical_covariate_keys', CATEGORICAL_COVARIATES),
    ('continuous_covariate_keys', CONTINUOUS_COVARIATES),
)
for option_name, covariate_keys in optional_covariates:
    if covariate_keys:
        setup_kwargs[option_name] = covariate_keys

# Register adata_hvg with scVI using the assembled arguments.
scvi.model.SCVI.setup_anndata(**setup_kwargs)

# === DATA STRUCTURE CONFIRMATION ===
# Print the shapes of the training object (adata_hvg) and the analysis object
# (adata_full), a metadata report, and the active scVI configuration.
print("\n=== Data Structure Confirmation ===")
print("✅ Data objects ready for scVI:")
print(f"  - adata_hvg (for training): {adata_hvg.shape}")
print(f"    • Genes: {adata_hvg.shape[1]:,} HVGs")
print(f"    • Cells: {adata_hvg.shape[0]:,}")
print(f"    • Raw counts layer: {adata_hvg.layers['raw_counts'].shape}")
print(f"    • Batches: {adata_hvg.obs[BATCH_KEY].nunique()}")
print(f"  - adata_full (for analysis): {adata_full.shape}")
print(f"    • Genes: {adata_full.shape[1]:,} (all genes)")
print(f"    • Cells: {adata_full.shape[0]:,}")
print(f"    • Raw counts layer: {adata_full.layers['raw_counts'].shape}")

print("\n✅ Metadata preservation:")
print(f"  - Cells: {adata_hvg.shape[0]:,}")
print(f"  - Metadata columns: {adata_hvg.obs.shape[1]}")

# Report the key columns that are really present instead of asserting a fixed
# list: a dataset lacking one of them would otherwise be described incorrectly.
expected_key_columns = ['batch', 'library', 'group', 'mouse']
present_key_columns = [c for c in expected_key_columns if c in adata_hvg.obs.columns]
missing_key_columns = [c for c in expected_key_columns if c not in adata_hvg.obs.columns]
print(f"  - Key columns: {', '.join(present_key_columns) if present_key_columns else 'none found'}")
if missing_key_columns:
    print(f"  - ⚠️ Missing key columns: {', '.join(missing_key_columns)}")

print("\n✅ scVI setup complete - ready for training!")
print(f"Note: Batch correction applied for '{BATCH_KEY}' ({adata_hvg.obs[BATCH_KEY].nunique()} batches).")
if CATEGORICAL_COVARIATES:
    print(f"      Categorical covariates: {', '.join(CATEGORICAL_COVARIATES)}")
else:
    print("      Categorical covariates: Not defined")
if CONTINUOUS_COVARIATES:
    print(f"      Continuous covariates: {', '.join(CONTINUOUS_COVARIATES)}")
else:
    print("      Continuous covariates: Not defined")

## scVI Model Training

**Purpose**: Train scVI model for covariate correction and generate latent representations for downstream analysis.

In [None]:
# === DEVICE CONFIGURATION ===
# Choose the training backend: MPS is probed first, CUDA second, and the CPU
# is the fallback. `accelerator` and `devices` are consumed by model.train().
if torch.backends.mps.is_available():
    accelerator, devices = "mps", 1
    device_message = "MPS (Apple Silicon GPU) available. Using MPS for training."
elif torch.cuda.is_available():
    accelerator, devices = "gpu", 1
    device_message = "CUDA GPU available. Using GPU for training."
else:
    accelerator, devices = "cpu", "auto"
    device_message = "No GPU available. Using CPU for training."
print(device_message)

# === MODEL INITIALIZATION ===
# Build an scVI model on the HVG subset with 20 latent dimensions and 2 layers.
print("\nInitializing scVI model...")
model = scvi.model.SCVI(adata_hvg, n_latent=20, n_layers=2)

# === TRAINING CONFIGURATION ===
# Echo the settings that are about to be passed to model.train().
n_training_cells, n_training_genes = adata_hvg.shape
reported_settings = (
    ("Max epochs", MAX_EPOCHS),
    ("Batch size", BATCH_SIZE),
    ("Accelerator", accelerator),
    ("Early stopping patience", EARLY_STOPPING_PATIENCE),
)
print("Training configuration:")
for setting_name, setting_value in reported_settings:
    print(f"  - {setting_name}: {setting_value}")
print(f"  - Training on: {n_training_cells:,} cells, {n_training_genes:,} HVGs")

# === MODEL TRAINING ===
# Early stopping is enabled, so training may finish before MAX_EPOCHS.
print("\nStarting scVI training...")
train_options = dict(
    max_epochs=MAX_EPOCHS,
    batch_size=BATCH_SIZE,
    accelerator=accelerator,
    devices=devices,
    early_stopping=True,
    early_stopping_patience=EARLY_STOPPING_PATIENCE,
)
model.train(**train_options)

# === TRAINING VISUALIZATION ===
# Draw the training and validation ELBO curves recorded in model.history on a
# single figure, then display and save it through the notebook's plot helpers.
configure_plot_style()
plt.figure(figsize=(10, 6))
elbo_curves = (
    ('elbo_train', 'Training ELBO'),
    ('elbo_validation', 'Validation ELBO'),
)
for history_key, curve_label in elbo_curves:
    # Each history entry is indexed again by its own name to get the values.
    plt.plot(model.history[history_key][history_key], label=curve_label, alpha=0.8)
plt.xlabel("Epoch", fontsize=12)
plt.ylabel("ELBO Loss", fontsize=12)
plt.title("scVI Training Progress", fontsize=14, fontweight='bold')
plt.legend()
plt.grid(True, alpha=0.3)
# NOTE(review): the figure is shown before it is saved. If show_inline_plot()
# closes the current figure, save_plot() may write an empty image - confirm
# against the helper definitions.
show_inline_plot()
save_plot("scvi_training_elbo_loss")

# === EXTRACT LATENT REPRESENTATIONS ===
# Pull the learned embedding out of the trained model and attach it to both
# the HVG object (used for training) and the full-gene object (used downstream).
latent_key = "X_scVI"
print("\nExtracting latent representations...")

# The embedding is copied into adata_full purely by row position, so both
# objects must list the same cells in the same order. Fail loudly instead of
# silently attaching embeddings to the wrong cells.
if not adata_full.obs_names.equals(adata_hvg.obs_names):
    raise ValueError(
        "adata_full and adata_hvg do not contain the same cells in the same "
        "order; cannot copy the scVI latent representation by position."
    )

# Get latent representation
# (presumably one row per cell of adata_hvg, in its obs order - TODO confirm)
latent_representation = model.get_latent_representation()

# Store in both HVG and full gene objects
adata_hvg.obsm[latent_key] = latent_representation
adata_full.obsm[latent_key] = latent_representation

print("Latent representation stored in:")
print(f"  - adata_hvg.obsm['{latent_key}']: {latent_representation.shape}")
print(f"  - adata_full.obsm['{latent_key}']: {latent_representation.shape}")

# === TRAINING SUMMARY ===
# Report the last recorded ELBO values, the embedding size and the correction
# settings that were active during training.
print("\n=== Training Complete ===")

# Only index into the history when the training curve has at least one entry.
train_elbo_history = model.history['elbo_train']['elbo_train']
if len(train_elbo_history) == 0:
    print("Training history not available (possible early convergence)")
else:
    validation_elbo_history = model.history['elbo_validation']['elbo_validation']
    print(f"Final training ELBO: {train_elbo_history.iloc[-1]:.2f}")
    print(f"Final validation ELBO: {validation_elbo_history.iloc[-1]:.2f}")

n_latent_cells, n_latent_dims = latent_representation.shape[0], latent_representation.shape[1]
print(f"Latent dimensions: {n_latent_dims}")
print(f"Total cells: {n_latent_cells:,}")
print(f"Batch correction applied for: '{BATCH_KEY}' ({adata_hvg.obs[BATCH_KEY].nunique()} batches)")
# Covariate lines are printed only for non-empty covariate lists.
covariate_report = (
    ("Categorical covariates", CATEGORICAL_COVARIATES),
    ("Continuous covariates", CONTINUOUS_COVARIATES),
)
for covariate_label, covariate_names in covariate_report:
    if covariate_names:
        print(f"{covariate_label}: {', '.join(covariate_names)}")

## Post-Training Verification

**Purpose**: Verify that the scVI training was successful, latent representations are properly stored, and all metadata columns are preserved for downstream analysis.

In [None]:
# === VERIFY DATA TRANSFER ===
# Show the shape of the stored embedding on each object.
print("\n=== Verifying Latent Space Transfer ===")
print("✅ Latent representation successfully stored:")
for object_name, adata_object in (("adata_hvg", adata_hvg), ("adata_full", adata_full)):
    print(f"  - {object_name}.obsm['X_scVI']: {adata_object.obsm['X_scVI'].shape}")

# === VERIFY METADATA AVAILABILITY ===
# Report, for each column of interest, whether it exists in adata_full.obs and
# how many distinct values it holds.
print("\n=== Metadata Verification ===")
important_columns = ['batch', 'library', 'group', 'mouse', 'phase']
if 'S_score' in adata_full.obs.columns:
    # The score/QC columns are only looked for when 'S_score' is present.
    important_columns += ['S_score', 'G2M_score', 'pct_counts_mt']

for col in important_columns:
    if col not in adata_full.obs.columns:
        print(f"❌ {col}: Not found")
        continue
    unique_count = adata_full.obs[col].nunique()
    print(f"✅ {col}: {unique_count} unique values")

# === SUMMARY ===
print("\n=== Ready for Downstream Analysis ===")
print("✅ scVI training complete")
print("✅ Latent space available in both adata_hvg and adata_full")
print(f"✅ Use adata_full for downstream analysis (has all {adata_full.shape[1]:,} genes)")
print(f"✅ Batch correction applied for '{BATCH_KEY}' ({adata_full.obs[BATCH_KEY].nunique()} batches)")
# Covariate lines are printed only for non-empty covariate lists.
for covariate_label, covariate_names in (
    ("Categorical covariates", CATEGORICAL_COVARIATES),
    ("Continuous covariates", CONTINUOUS_COVARIATES),
):
    if covariate_names:
        print(f"✅ {covariate_label}: {', '.join(covariate_names)}")

### Data Object Summary

**Purpose**: Verify data structure and confirm both adata objects are ready for downstream analysis.

In [None]:
# === DATA OBJECT SUMMARY ===
# Side-by-side overview of the training object and the analysis object.
print("=== Data Object Summary ===\n")

hvg_cell_count, hvg_gene_count = adata_hvg.shape[0], adata_hvg.shape[1]
print("--- adata_hvg (for training) ---")
print(f"  Cells: {hvg_cell_count:,} | Genes: {hvg_gene_count:,} HVGs")
print(f"  Latent space: {'X_scVI' in adata_hvg.obsm}")

# For the full object, show the embedding shape when present and a
# placeholder text otherwise.
if 'X_scVI' in adata_full.obsm:
    full_latent_summary = adata_full.obsm['X_scVI'].shape
else:
    full_latent_summary = 'Not found'
full_cell_count, full_gene_count = adata_full.shape[0], adata_full.shape[1]
print("\n--- adata_full (for analysis) ---")
print(f"  Cells: {full_cell_count:,} | Genes: {full_gene_count:,} (all genes)")
print(f"  Latent space: {full_latent_summary}")
print(f"  Raw counts: {adata_full.layers['raw_counts'].shape}")

print("\n--- Recommendation ---")
print("✅ Use adata_full for all downstream analysis (clustering, UMAP, DE, visualization)")

print("\n✅ Data verification complete - ready for downstream analysis!")

## Save Integrated Data

**Purpose**: Save the integrated AnnData object (full gene set, including the scVI latent representation) for downstream analysis.

In [None]:
# === SAVE INTEGRATED DATA ===
# Write the full-gene object to <output_dir>/<output_filename>.h5ad (gzip
# compressed) and print a short description of what was written.
print("Saving integrated AnnData object...")

# This obs column is dropped in place before writing; the printed message
# gives h5ad incompatibility as the reason.
if 'most_likely_hypothesis' in adata_full.obs.columns:
    print("Removing 'most_likely_hypothesis' column (not compatible with h5ad format)")
    del adata_full.obs['most_likely_hypothesis']

# Only the full-gene object is saved.
output_filename = "integrated_library"
output_path = output_dir / f"{output_filename}.h5ad"
adata_full.write_h5ad(output_path, compression="gzip")

# Size on disk, converted from bytes to MB.
file_size = output_path.stat().st_size / (1024**2)
saved_cells, saved_genes = adata_full.shape[0], adata_full.shape[1]
print(f"\n✅ Saved: {output_filename}.h5ad ({file_size:.1f} MB)")
print(f"   - {saved_cells:,} cells, {saved_genes:,} genes")
print(f"   - Latent dims: {adata_full.obsm['X_scVI'].shape[1]}")
print(f"   - Batch key: '{BATCH_KEY}' ({adata_full.obs[BATCH_KEY].nunique()} batches)")
# Covariate lines are printed only for non-empty covariate lists.
for covariate_label, covariate_names in (
    ("Categorical", CATEGORICAL_COVARIATES),
    ("Continuous", CONTINUOUS_COVARIATES),
):
    if covariate_names:
        print(f"   - {covariate_label}: {', '.join(covariate_names)}")

## Analysis Summary

**Purpose**: Document analysis completion, environment details, and output files for reproducibility.

In [None]:
# === ANALYSIS COMPLETION SUMMARY ===
# Record the interpreter version, completion time and installed package
# versions so the run can be reproduced.
from importlib import metadata as importlib_metadata


def installed_version(distribution_name):
    """Return the installed version of a distribution, or None if absent.

    Args:
        distribution_name: Distribution name as used by pip (e.g. 'scvi-tools').

    Returns:
        The version string, or None when no such distribution is installed.
    """
    try:
        return importlib_metadata.version(distribution_name)
    except importlib_metadata.PackageNotFoundError:
        # Only "not installed" is treated as expected; any other error (and
        # KeyboardInterrupt) propagates instead of being silently swallowed.
        return None


print("=== SINGLE-CELL Integration ANALYSIS COMPLETE ===\n")

# === ENVIRONMENT INFORMATION ===
print("--- Environment Details ---")
print(f"Python version: {sys.version.split()[0]}")
print(f"Analysis completed: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

# === PACKAGE VERSIONS ===
print("\n--- Package Versions ---")
packages = ['anndata', 'scanpy', 'scvi-tools', 'mudata', 'muon', 'torch']
for pkg in packages:
    pkg_version = installed_version(pkg)
    print(f"{pkg}: {pkg_version if pkg_version is not None else 'Not available'}")

