# Enrichment Analysis for scRNA-seq Data

## Import required libraries

In [1]:
from pathlib import Path

import gseapy as gp  # enrichr / prerank used in the enrichment cells
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns

## Configure Environment


In [2]:
# Reproducibility: fix NumPy's global random seed
np.random.seed(42)

# Scanpy: verbose logging (level 3) plus default figure resolution/size
sc.settings.verbosity = 3
sc.settings.set_figure_params(dpi=100, figsize=(8, 8))

# Project naming -- fill these in. FULL_PROJ_NAME is the filename prefix of
# the annotated input file and of the enrichment output file.
PROJ_NAME = ""
PROJ_DESCRIPTION = ""
FULL_PROJ_NAME = "_".join((PROJ_NAME, PROJ_DESCRIPTION))

# Filesystem layout
PROJECT_DIR = Path("/path/to/project")
OUTPUT_DIR = PROJECT_DIR.joinpath("output")

## Data Loading

In [None]:
# Read the annotated AnnData object written under OUTPUT_DIR
print("Loading annotated data...")
adata = sc.read_h5ad(OUTPUT_DIR.joinpath(f"{FULL_PROJ_NAME}_annotated.h5ad"))
# n_obs = cells (rows), n_vars = genes (columns)
print(f"Data shape: {adata.n_obs} cells and {adata.n_vars} genes")

## DEG Analysis

In [None]:
print("\nFinding differentially expressed genes...")
# Wilcoxon rank-sum test per 'celltype' group; results are stored in
# adata.uns['rank_genes_wilcox'] and read back by get_significant_degs
sc.tl.rank_genes_groups(
    adata, groupby='celltype', method='wilcoxon', key_added='rank_genes_wilcox'
)

In [5]:
# Get significant up-regulated DEGs as per-group lists of gene names
def get_significant_degs(adata, min_logfc=0.8, max_pval=0.05,
                         groupby='celltype', key='rank_genes_wilcox'):
    """Collect significant up-regulated DEGs for each group.

    Parameters
    ----------
    adata : AnnData
        Annotated data on which ``sc.tl.rank_genes_groups`` has already been
        run with ``key_added=key``.
    min_logfc : float, default 0.8
        Keep genes whose ``logfoldchanges`` is >= this value. With a positive
        threshold only genes higher in the group than in its reference are kept.
    max_pval : float, default 0.05
        Keep genes whose ``pvals_adj`` is strictly below this value.
    groupby : str, default 'celltype'
        ``adata.obs`` column whose unique values are iterated as groups;
        should match the ``groupby`` used when the ranking was computed.
    key : str, default 'rank_genes_wilcox'
        ``adata.uns`` key holding the ranking results.

    Returns
    -------
    dict
        Group label (as found in ``adata.obs[groupby]``) -> list of gene
        names passing both filters, in the order returned by
        ``sc.get.rank_genes_groups_df``.
    """
    degs_dict = {}

    for group in adata.obs[groupby].unique():
        # Ranking results are looked up by group name; str() lets non-string
        # labels be found and is a no-op for ordinary string labels.
        degs = sc.get.rank_genes_groups_df(adata, group=str(group), key=key)

        # Filter for effect size (direction) and significance
        is_sig = (
            (degs['logfoldchanges'] >= min_logfc)
            & (degs['pvals_adj'] < max_pval)
        )
        degs_dict[group] = degs.loc[is_sig, 'names'].tolist()

    return degs_dict



In [None]:
# Significant up-regulated genes per cell type (thresholds: get_significant_degs defaults)
print("\nExtracting significant DEGs...")
deg_dict = get_significant_degs(adata)

# One summary line per cell type
for celltype, genes in deg_dict.items():
    print("{}: {} significant genes".format(celltype, len(genes)))

## GO Enrichment Analysis

In [None]:
print("\nPerforming GO enrichment analysis...")
go_results = {}

for celltype, genes in deg_dict.items():
    # Nothing to test for cell types without significant genes
    if not genes:
        continue
    try:
        # Run GO enrichment against the Enrichr GO Biological Process 2021 library
        enr = gp.enrichr(
            gene_list=genes,
            organism='Mouse',
            gene_sets=['GO_Biological_Process_2021'],
            cutoff=0.05
        )
        go_results[celltype] = enr.results

        # Plot the 10 most significant GO terms. seaborn needs explicit x/y:
        # a bare DataFrame is treated as wide-form data (one bar per numeric
        # column), not one bar per term.
        if not enr.results.empty:
            top_terms = enr.results.sort_values('Adjusted P-value').head(10).copy()
            # Clip so an adjusted p-value of exactly 0 still gives a finite bar
            top_terms['-log10(Adjusted P-value)'] = -np.log10(
                top_terms['Adjusted P-value'].astype(float).clip(lower=np.finfo(float).tiny)
            )
            plt.figure(figsize=(10, 6))
            sns.barplot(data=top_terms, x='-log10(Adjusted P-value)', y='Term')
            plt.title(f'Top GO terms for {celltype}')
            plt.tight_layout()
            plt.show()

    except Exception as e:
        # Best-effort per cell type: report the failure and continue
        print(f"Error in GO analysis for {celltype}: {str(e)}")


## KEGG Pathway Analysis

In [None]:
print("\nPerforming KEGG pathway analysis...")
kegg_results = {}

for celltype, genes in deg_dict.items():
    # Skip cell types that have no significant genes to test
    if not genes:
        continue
    try:
        # Enrichment against the Enrichr KEGG 2019 (mouse) library
        kegg_enr = gp.enrichr(
            gene_list=genes,
            organism='Mouse',
            gene_sets=['KEGG_2019_Mouse'],
            cutoff=0.05,
        )
        kegg_results[celltype] = kegg_enr.results
    except Exception as e:
        # Best-effort per cell type: report the failure and continue
        print(f"Error in KEGG analysis for {celltype}: {e}")


In [None]:
# Plot combined KEGG results
print("\nPlotting KEGG pathway analysis results...")
# Combine pathways from all cell types. pd.concat raises on an empty list,
# so fall back to an empty table when no cell type produced results.
if kegg_results:
    all_kegg = pd.concat(
        [df.assign(celltype=ct) for ct, df in kegg_results.items()],
        ignore_index=True,  # per-cell-type tables each have their own 0..n index
    )
else:
    all_kegg = pd.DataFrame(
        columns=['celltype', 'Term', 'Adjusted P-value', 'Combined Score']
    )
# Keep significant pathways; .copy() so the column added below is written
# to an independent frame, not a view of the concatenated table
all_kegg = all_kegg[all_kegg['Adjusted P-value'] < 0.05].copy()
# Add -log10(Adjusted P-value) column; clip so p == 0 stays finite
all_kegg['-log10(Adjusted P-value)'] = -np.log10(
    all_kegg['Adjusted P-value'].astype(float).clip(lower=np.finfo(float).tiny)
)

if not all_kegg.empty:
    plt.figure(figsize=(12, 8))
    sns.scatterplot(
        data=all_kegg,
        x='celltype',
        y='Term',
        size='-log10(Adjusted P-value)',
        hue='Combined Score',
        sizes=(100, 400),  # Increased minimum size and adjusted range
    )
    plt.xticks(rotation=45)
    plt.title('KEGG Pathways Across Cell Types')
    plt.tight_layout()
    plt.show()

## GSEA

In [None]:
def prepare_ranked_genes(adata, celltype, groupby='celltype', rank_by='logfoldchanges'):
    """Prepare a pre-ranked gene list for GSEA (``gp.prerank``).

    Runs a Wilcoxon ``sc.tl.rank_genes_groups`` test of ``celltype`` against
    all remaining cells and stores it in ``adata.uns`` under
    ``f'gsea_{celltype}'``, so this function mutates ``adata``.

    Parameters
    ----------
    adata : AnnData
        Annotated data containing the ``groupby`` column in ``adata.obs``.
    celltype : str
        Group to rank; a value of ``adata.obs[groupby]``.
    groupby : str, default 'celltype'
        ``adata.obs`` column defining the groups.
    rank_by : str, default 'logfoldchanges'
        Column of ``sc.get.rank_genes_groups_df`` used as the ranking metric
        (e.g. ``'scores'`` for the Wilcoxon statistic).

    Returns
    -------
    pandas.Series
        Ranking metric indexed by upper-cased, unique gene symbols, sorted in
        descending order.
    """
    key = f'gsea_{celltype}'
    sc.tl.rank_genes_groups(
        adata,
        groupby=groupby,
        groups=[celltype],
        reference='rest',
        method='wilcoxon',
        key_added=key
    )

    ranked_genes = sc.get.rank_genes_groups_df(adata, group=celltype, key=key)

    # Remove rows with any missing value, then upper-case gene symbols
    # (presumably to match the symbol case of the gene-set library used
    # downstream -- confirm for the library in use)
    ranked_genes = ranked_genes.dropna().assign(
        names=lambda df: df['names'].str.upper()
    )

    # Drop duplicates only AFTER upper-casing: symbols that differ only in
    # case collapse to the same name, and a duplicated index would make the
    # ranking ambiguous. The first occurrence is kept.
    ranked_genes = ranked_genes.drop_duplicates(subset='names')

    ranked_list = pd.Series(
        ranked_genes[rank_by].values,
        index=ranked_genes['names'].values
    ).sort_values(ascending=False)

    return ranked_list

In [None]:

print("\nPerforming GSEA analysis...")

# Initialize results dictionary: cell type -> gseapy res2d table
gsea_results = {}

# Run GSEA for each cell type
for celltype in adata.obs['celltype'].unique():
    try:
        # Prepare ranked gene list (runs a per-cell-type rank_genes_groups)
        ranked_genes = prepare_ranked_genes(adata, celltype)
        if ranked_genes is None or ranked_genes.empty:
            continue

        # Run preranked GSEA with adjusted parameters.
        # NOTE(review): with 100 permutations the smallest non-zero nominal
        # p-value is about 1/100, so FDR estimates are coarse; consider 1000
        # for final results.
        gs_res = gp.prerank(
            rnk=ranked_genes,
            gene_sets='Mouse_Gene_Atlas',
            min_size=10,
            max_size=1000,
            permutation_num=100,
            no_plot=True,  # Prevent automatic plotting
            outdir=None,
            seed=42
        )

        # Store results
        if not hasattr(gs_res, 'res2d'):
            continue
        gsea_results[celltype] = gs_res.res2d
        if gs_res.res2d.empty:
            continue

        # NES / FDR may come back as object dtype; coerce so sorting by
        # magnitude and '.3f' formatting operate on real numbers
        res = gs_res.res2d.copy()
        for col in ('NES', 'FDR q-val'):
            res[col] = pd.to_numeric(res[col], errors='coerce')

        # Top 10 pathways by absolute NES, so strongly negatively enriched
        # sets are shown alongside positively enriched ones
        top_pathways = res.sort_values(
            'NES', key=lambda s: s.abs(), ascending=False
        ).head(10)
        y_pos = np.arange(len(top_pathways))

        plt.figure(figsize=(12, 6))
        plt.barh(y_pos, top_pathways['NES'], align='center')
        plt.yticks(y_pos, top_pathways['Term'], fontsize=8)
        plt.gca().invert_yaxis()  # largest |NES| at the top
        plt.axvline(0, color='grey', linewidth=0.8)
        plt.xlabel('Normalized Enrichment Score (NES)')
        plt.title(f'Top GSEA pathways for {celltype}')
        for i, (nes, fdr) in enumerate(zip(top_pathways['NES'],
                                           top_pathways['FDR q-val'])):
            # Anchor the label at the bar tip, on the side the bar extends to
            plt.text(nes, i, f'FDR={fdr:.3f}', va='center',
                     ha='left' if nes >= 0 else 'right', fontsize=8)
        plt.tight_layout()
        plt.show()

    except Exception as e:
        # Best-effort per cell type: report the failure and continue
        print(f"Error in GSEA analysis for {celltype}: {str(e)}")



In [None]:
# Print summary of results
print("\nGSEA Analysis Summary:")
for celltype, results in gsea_results.items():
    if results.empty:
        continue
    # `results` is the full res2d table (every tested gene set), so its length
    # is not a count of enriched sets; also report how many pass FDR < 0.25,
    # the conventional GSEA significance cutoff.
    fdr = pd.to_numeric(results['FDR q-val'], errors='coerce')
    n_sig = int((fdr < 0.25).sum())
    print(f"\n{celltype}: Tested {len(results)} gene sets, "
          f"{n_sig} with FDR q-val < 0.25")
    print("Top 5 enriched pathways:")
    # NES may be object dtype; sort on its numeric value
    top5 = (
        results.assign(NES=pd.to_numeric(results['NES'], errors='coerce'))
        .sort_values('NES', ascending=False)
        .head(5)
    )
    print(top5[['Term', 'NES', 'FDR q-val']].to_string())

## Save Results

In [None]:
print("\nSaving enrichment results...")
output_file = OUTPUT_DIR / f"{FULL_PROJ_NAME}_enrichment.pkl"

# Bundle the DEG lists and all enrichment tables into a single pickle
enrichment_results = dict(
    DEGs=deg_dict,
    GO=go_results,
    KEGG=kegg_results,
    GSEA=gsea_results,
)
pd.to_pickle(enrichment_results, output_file)

print("Enrichment analysis complete!")