In [1]:
# Import packages
import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt
import csv
import yaml
import re
import seaborn as sns
import scipy.cluster.hierarchy as sch

In [2]:
# Read the pre-processed single-cell dataset from disk into `adata`.
# NOTE(review): this is an absolute, machine-specific path; consider deriving it
# from a configurable data directory so the notebook runs on other machines.
processed_h5ad_path = "/Users/aumchampaneri/Databases/Triple/Hs_Nor-CKD-AKF_scRNA_processed.h5ad"
adata = sc.read_h5ad(processed_h5ad_path)

In [3]:
# Build the gene lookup from the CSV: first column -> second column, header row skipped.
gene_dict = {}
with open('complement_gene_dictionary.csv', newline='') as csv_handle:
    csv_rows = csv.reader(csv_handle)
    next(csv_rows)  # discard the header line
    gene_dict.update((record[0], record[1]) for record in csv_rows)

# Parallel lists: the dictionary's keys and its values, in insertion order.
# NOTE(review): despite its name, `gene_dict_keys` holds the dictionary *values*.
gene_dict_names = list(gene_dict)
gene_dict_keys = list(gene_dict.values())

# Some short complement symbols are replaced by suffixed aliases to fix plotting
# errors. The substitutions are applied one after another, in this order.
symbol_aliases = (
    (r'\bC2\b', 'C2_ENSG00000166278'),
    (r'\bC3\b', 'C3_ENSG00000125730'),
    (r'\bC6\b', 'C6_ENSG00000039537'),
    (r'\bC7\b', 'C7_ENSG00000112936'),
    (r'\bC9\b', 'C9_ENSG00000113600'),
)
for symbol_pattern, alias in symbol_aliases:
    gene_dict_names = [re.sub(symbol_pattern, alias, name) for name in gene_dict_names]

# Read the group -> list-of-cell-types mapping from YAML.
with open("Tissue Type Dictionary.yaml", "r") as yaml_handle:
    cell_type_group = yaml.safe_load(yaml_handle)

# Label every cell with its group; cells whose type appears in no group stay 'Other'.
adata.obs['cell_type_group'] = 'Other'
for group, cell_types in cell_type_group.items():
    belongs_to_group = adata.obs['cell_type'].isin(cell_types)
    adata.obs.loc[belongs_to_group, 'cell_type_group'] = group

## Test differential expression of complement genes in different disease states

In [45]:
def differential_expression_heatmap(
        adata, gene_dict, groupby="cell_type", group1="Reference", group2="CKD",
        save_path=None, title_fontsize=16, tick_fontsize=12, annot_fontsize=10,
        colorbar_position=[0.85, 0.75, 0.03, 0.2], figsize=(15, 12), dendrogram_ratio=(0.05, 0.2)
):
    """
    Generates a clustered heatmap showing log fold change (logFC) in gene expression
    between two groups across cell types.

    For every value of `groupby`, the mean expression of each gene is computed
    separately for the `group1` and `group2` cells (as labelled in
    ``adata.obs["diseasetype"]``) and ``log2((mean2 + p) / (mean1 + p))`` is
    plotted, where ``p`` is a small pseudocount. Genes (heatmap rows) are ordered
    by Ward hierarchical clustering; cell types (columns) are not clustered.

    Parameters
    ----------
    adata : AnnData
        Single-cell RNA-seq data. ``adata.var_names`` are matched against the
        values of `gene_dict`, and ``adata.obs`` must contain `groupby` and
        ``"diseasetype"`` columns.
    gene_dict : dict
        Dictionary mapping Gene Names to Ensembl IDs.
    groupby : str, optional
        Column in `adata.obs` defining cell types (default: "cell_type").
    group1 : str, optional
        Reference group (default: "Reference").
    group2 : str, optional
        Experimental/disease group (default: "CKD").
    save_path : str, optional
        Path to save the heatmap figure (default: None, the figure is shown instead).
    title_fontsize : int, optional
        Font size of the heatmap title (default: 16).
    tick_fontsize : int, optional
        Font size of the tick labels (default: 12).
    annot_fontsize : int, optional
        Font size of heatmap annotations (default: 10).
    colorbar_position : list, optional
        Position of the colorbar as [x, y, width, height] (default: [0.85, 0.75, 0.03, 0.2]).
    figsize : tuple, optional
        Size of the heatmap figure (default: (15, 12)).
    dendrogram_ratio : tuple, optional
        Ratio of space allocated to the dendrograms (default: (0.05, 0.2)).

    Returns
    -------
    matplotlib.figure.Figure
        The heatmap figure.

    Raises
    ------
    ValueError
        If none of the genes are present in `adata`, if either group has no
        cells, or if no `groupby` value has cells in both groups.
    """
    # Invert the mapping (Ensembl ID -> gene name) and keep only IDs present in the data
    ensembl_to_gene = {v: k for k, v in gene_dict.items()}
    valid_ensembl_ids = [gene_id for gene_id in ensembl_to_gene if gene_id in adata.var_names]

    if not valid_ensembl_ids:
        raise ValueError("None of the specified genes are found in the dataset.")

    # Extract expression data for selected genes, with the grouping columns attached
    expression_data = adata[:, valid_ensembl_ids].to_df()
    expression_data = expression_data.join(adata.obs[[groupby, "diseasetype"]])
    disease_labels = expression_data["diseasetype"]

    # Check if both groups exist in the dataset
    if group1 not in disease_labels.values or group2 not in disease_labels.values:
        raise ValueError(f"One or both groups ({group1}, {group2}) not found in the dataset.")

    # Compute mean expression per cell type for both groups.
    # observed=True keeps unused categories of a categorical `groupby` column from
    # producing empty (all-NaN) rows.
    group1_mean = (
        expression_data.loc[disease_labels == group1]
        .groupby(groupby, observed=True)[valid_ensembl_ids].mean()
    )
    group2_mean = (
        expression_data.loc[disease_labels == group2]
        .groupby(groupby, observed=True)[valid_ensembl_ids].mean()
    )

    # Compute log fold change (logFC) with an adaptive pseudocount.
    # NOTE(review): when the smallest expression value is 0 this resolves to 1e-6,
    # which strongly inflates fold changes of lowly expressed genes - confirm intended.
    pseudocount = max(expression_data[valid_ensembl_ids].min().min(), 1e-6)
    logFC_data = np.log2((group2_mean + pseudocount) / (group1_mean + pseudocount))

    # A cell type present in only one group has no defined fold change: the
    # index-aligned division leaves its whole row NaN, which would make the
    # hierarchical clustering below fail. Drop such rows.
    logFC_data = logFC_data.dropna(how="all")
    if logFC_data.empty:
        raise ValueError(f"No '{groupby}' value has cells in both {group1} and {group2}.")

    # Rename columns to gene names (mapping-based, so it cannot be misaligned)
    logFC_data = logFC_data.rename(columns=ensembl_to_gene)
    logFC_data.index.name = "Cell Type"

    # Cluster genes using hierarchical clustering; linkage needs at least two genes
    heatmap_data = logFC_data.T
    cluster_genes = heatmap_data.shape[0] > 1
    row_linkage = sch.linkage(heatmap_data, method="ward") if cluster_genes else None

    # Plot heatmap with larger box size and adjusted dendrogram ratio
    g = sns.clustermap(
        heatmap_data,
        cmap="bwr", center=0, linewidths=0.5, annot=True, fmt=".2f",
        row_cluster=cluster_genes, col_cluster=False, row_linkage=row_linkage,
        figsize=figsize, annot_kws={"size": annot_fontsize},
        xticklabels=True, yticklabels=True,
        dendrogram_ratio=dendrogram_ratio,  # Reduce dendrogram size to expand heatmap area
        cbar_pos=colorbar_position  # Reposition colorbar
    )

    # Set title with adjustable font size
    g.ax_heatmap.set_title(f"Log Fold Change (logFC) in Gene Expression: {group2} vs {group1}", fontsize=title_fontsize)

    # Rotate x-axis labels and set font sizes
    plt.setp(g.ax_heatmap.xaxis.get_majorticklabels(), rotation=45, ha="right", fontsize=tick_fontsize)
    plt.setp(g.ax_heatmap.yaxis.get_majorticklabels(), fontsize=tick_fontsize)

    # Save or show plot (save through the grid's own figure, not the implicit current one)
    if save_path:
        g.fig.savefig(save_path, bbox_inches="tight", dpi=300)
    else:
        plt.show()

    return g.fig


In [4]:
# Genes-of-interest lists for plotting.
# - Use the BioMart query for ideas.
# - `gene_dict_names` holds all the complement genes.
#
# NOTE(review): 'CR3', 'CR4', 'CD35', 'MASP3' and 'LC3B' look like aliases rather
# than official gene symbols - confirm they exist in the gene dictionary, since
# names that are not found are filtered out downstream.

complement_receptors = [
    'CR1',
    'CR2',
    'CR3',
    'CR4',
    'C3AR1',
    'C5AR1',
    'C5AR2',
]

complement_regulatory_proteins = [
    'CD46',
    'CD55',
    'CD59',
    'SERPING1',
    'CLU',
    'VTN',
    'PLG',
    'CD35',
    'THBD',
    'VWF',
]

complement_alternative_pathway = [
    'CFB',
    'CFD',
    'CFH',
    'CFHR1',
    'CFHR2',
    'CFHR3',
    'CFHR4',
    'CFHR5',
    'CFI',
]

# The C2/C3/C6/C7/C9 entries use the same suffixed aliases as `gene_dict_names`.
core_complement = [
    'C1QA',
    'C1QB',
    'C1QC',
    'C1R',
    'C1S',
    'C2_ENSG00000166278',
    'C3_ENSG00000125730',
    'C4A',
    'C4B',
    'C5',
    'C6_ENSG00000039537',
    'C7_ENSG00000112936',
    'C8A',
    'C8B',
    'C8G',
    'C9_ENSG00000113600',
]

lectin_pathway = [
    'MBL2',
    'MASP1',
    'MASP2',
    'MASP3',
    'FCN1',
    'FCN2',
    'FCN3',
]

complosome_core = [
    'C3_ENSG00000125730',
    'C5',
    'C3AR1',
    'C5AR1',
]

intracellular_activation = [
    'C3AR1',
    'C5AR1',
    'C5AR2',
]

metabolism_autophagy = [
    'ATG5',
    'ATG7',
    'ATG12',
    'ATG16L1',
    'BECN1',
    'LC3B',
    'ULK1',
]

# NOTE(review): the variable name is misspelled ("mitochondrial"); it is kept
# unchanged because other cells may reference it.
mitochondiral_response = [
    'NLRP3',
    'CASP1',
    'CASP4',
    'CASP5',
    'CASP8',
    'CASP9',
    'CASP12',
]

In [None]:
def differential_expression_heatmap(
    adata, gene_dict, groupby="cell_type", group1="Reference", group2="CKD",
    save_path=None, title_fontsize=16, tick_fontsize=12, annot_fontsize=10,
    colorbar_position=[0.85, 0.75, 0.03, 0.2], figsize=(15, 12), dendrogram_ratio=(0.05, 0.2),
    cell_types=None, min_cells_per_group=10, gene_list=None, debug=False
):
    """
    Generates a clustered heatmap showing log fold change (logFC) in gene expression
    between two groups across cell types.

    For every retained value of `groupby`, the mean expression of each gene is
    computed separately for the `group1` and `group2` cells (as labelled in
    ``adata.obs["diseasetype"]``) and ``log2((mean2 + 1) / (mean1 + 1))`` is
    plotted. Genes (heatmap rows) are ordered by Ward hierarchical clustering.

    Parameters
    ----------
    adata : AnnData
        Single-cell RNA-seq data. ``adata.var_names`` are matched against the
        values of `gene_dict`, and ``adata.obs`` must contain `groupby` and
        ``"diseasetype"`` columns.
    gene_dict : dict
        Dictionary mapping Gene Names to Ensembl IDs.
    groupby : str, optional
        Column in `adata.obs` defining cell types (default: "cell_type").
    group1 : str, optional
        Reference group (default: "Reference").
    group2 : str, optional
        Experimental/disease group (default: "CKD").
    save_path : str, optional
        Path to save the heatmap figure (default: None, the figure is shown
        instead). When the path ends in .pdf, .png or .jpg, the logFC table is
        also written next to it as an .xlsx file.
    title_fontsize : int, optional
        Font size of the heatmap title (default: 16).
    tick_fontsize : int, optional
        Font size of the tick labels (default: 12).
    annot_fontsize : int, optional
        Font size of heatmap annotations (default: 10).
    colorbar_position : list, optional
        Position of the colorbar as [x, y, width, height] (default: [0.85, 0.75, 0.03, 0.2]).
    figsize : tuple, optional
        Size of the heatmap figure (default: (15, 12)).
    dendrogram_ratio : tuple, optional
        Ratio of space allocated to the dendrograms (default: (0.05, 0.2)).
    cell_types : list, optional
        List of cell types to include (default: None, includes all).
    min_cells_per_group : int, optional
        Minimum number of cells required per group for a cell type (default: 10).
    gene_list : list, optional
        List of gene names to include (default: None, includes all genes in gene_dict).
    debug : bool, optional
        If True, print where the Excel export was written (default: False).

    Returns
    -------
    matplotlib.figure.Figure
        The heatmap figure.

    Raises
    ------
    ValueError
        If no genes are selected or found in `adata`, if either group has no
        cells after filtering, or if no cell type passes the cell-count filter.
    """
    # Filter gene_dict if gene_list is provided
    if gene_list:
        wanted_genes = set(gene_list)
        gene_dict = {k: v for k, v in gene_dict.items() if k in wanted_genes}

    if not gene_dict:
        raise ValueError("No genes selected for analysis.")

    # Select the cells of interest with one boolean mask. Indexing AnnData yields
    # a view, so the full dataset does not need to be copied first.
    keep_cells = adata.obs["diseasetype"].isin([group1, group2])
    if cell_types:
        keep_cells = keep_cells & adata.obs[groupby].isin(cell_types)
    adata_subset = adata[keep_cells.to_numpy()]

    # Invert the mapping (Ensembl ID -> gene name); iterating the inverted dict
    # also de-duplicates Ensembl IDs shared by several names.
    ensembl_to_gene = {v: k for k, v in gene_dict.items()}
    valid_ensembl_ids = [gene_id for gene_id in ensembl_to_gene if gene_id in adata_subset.var_names]

    if not valid_ensembl_ids:
        raise ValueError("None of the specified genes are found in the dataset.")

    # Extract expression data for selected genes
    expression_data = adata_subset[:, valid_ensembl_ids].to_df()
    expression_data = expression_data.join(adata_subset.obs[[groupby, "diseasetype"]])

    # Both groups must have at least one cell; otherwise the per-group lookups
    # below would fail with an opaque KeyError.
    present_groups = set(expression_data["diseasetype"].unique())
    if group1 not in present_groups or group2 not in present_groups:
        raise ValueError(f"One or both groups ({group1}, {group2}) not found in the dataset.")

    # Filter cell types with too few cells
    if min_cells_per_group > 0:
        cell_counts = (
            expression_data.groupby([groupby, "diseasetype"], observed=True)
            .size()
            .unstack(fill_value=0)
        )
        valid_cell_types = cell_counts[
            (cell_counts[group1] >= min_cells_per_group) &
            (cell_counts[group2] >= min_cells_per_group)
        ].index.tolist()

        if not valid_cell_types:
            raise ValueError(f"No cell types with at least {min_cells_per_group} cells per group found.")

        expression_data = expression_data[expression_data[groupby].isin(valid_cell_types)]

    # Compute mean expression per cell type for both groups
    is_group1 = expression_data["diseasetype"] == group1
    is_group2 = expression_data["diseasetype"] == group2
    group1_mean = expression_data.loc[is_group1].groupby(groupby, observed=True)[valid_ensembl_ids].mean()
    group2_mean = expression_data.loc[is_group2].groupby(groupby, observed=True)[valid_ensembl_ids].mean()

    # Compute log fold change (logFC) with fixed pseudocount for stability
    pseudocount = 1.0  # Fixed pseudocount for more stable fold changes
    logFC_data = np.log2((group2_mean + pseudocount) / (group1_mean + pseudocount))

    # With the cell-count filter disabled, a cell type seen in only one group
    # yields an all-NaN row that would break the clustering below; drop it.
    logFC_data = logFC_data.dropna(how="all")
    if logFC_data.empty:
        raise ValueError(f"No '{groupby}' value has cells in both {group1} and {group2}.")

    # Rename columns to gene names (mapping-based, so it cannot be misaligned)
    logFC_data = logFC_data.rename(columns=ensembl_to_gene)
    logFC_data.index.name = "Cell Type"

    # Cluster genes using hierarchical clustering; linkage needs at least two genes
    heatmap_data = logFC_data.T
    cluster_genes = heatmap_data.shape[0] > 1
    row_linkage = sch.linkage(heatmap_data, method="ward") if cluster_genes else None

    # Plot heatmap with larger box size and adjusted dendrogram ratio
    g = sns.clustermap(
        heatmap_data,
        cmap="bwr", center=0, linewidths=0.5, annot=True, fmt=".2f",
        row_cluster=cluster_genes, col_cluster=False, row_linkage=row_linkage,
        figsize=figsize, annot_kws={"size": annot_fontsize},
        xticklabels=True, yticklabels=True,
        dendrogram_ratio=dendrogram_ratio,  # Reduce dendrogram size to expand heatmap area
        cbar_pos=colorbar_position  # Reposition colorbar
    )

    # Set title with adjustable font size
    g.ax_heatmap.set_title(f"Log Fold Change (logFC) in Gene Expression: {group2} vs {group1}", fontsize=title_fontsize)

    # Rotate x-axis labels and set font sizes
    plt.setp(g.ax_heatmap.xaxis.get_majorticklabels(), rotation=45, ha="right", fontsize=tick_fontsize)
    plt.setp(g.ax_heatmap.yaxis.get_majorticklabels(), fontsize=tick_fontsize)

    # Save or show plot
    if save_path:
        g.fig.savefig(save_path, bbox_inches="tight", dpi=300)

        # Save the data as Excel file alongside recognised image formats
        save_path_text = str(save_path)
        if save_path_text.endswith(('.pdf', '.png', '.jpg')):
            excel_path = save_path_text.rsplit('.', 1)[0] + '.xlsx'
            logFC_data.to_excel(excel_path)
            if debug:
                print(f"Data saved to {excel_path}")
    else:
        plt.show()

    return g.fig


# Example usage:
disease = 'CKD'  # Choose the disease you want to plot (AKI, CKD)
subset = 'cell_type'  # Choose the groupby variable (cell_type, cell_type_group)

# For the filename, use a more descriptive name based on what genes you're selecting
gene_list_description = 'all_complement_genes'

# Select which gene list to use (gene_dict_names = all complement genes, or one
# of the pre-defined lists such as complement_receptors / core_complement).
# Those lists carry plot-safe aliases like 'C3_ENSG00000125730', whereas gene_dict
# is keyed by the names read from the CSV, so the '_ENSG...' suffix is stripped
# before the lookup; otherwise the aliased genes would be silently dropped.
gene_list = [re.sub(r'_ENSG\d+$', '', name) for name in gene_dict_names]
selected_genes = {gene: gene_dict[gene] for gene in gene_list if gene in gene_dict}

# Call the function with the selected gene dictionary
differential_expression_heatmap(
    adata,
    selected_genes,
    groupby=subset,
    group1="Reference",
    group2=disease,
    save_path=f"{disease}_{subset}_logFC-heatmap_{gene_list_description}.pdf",
    gene_list=gene_list,  # Pass the gene list explicitly
    min_cells_per_group=10,
    title_fontsize=14,
    tick_fontsize=8,
    annot_fontsize=6,
    colorbar_position=[0.98, 0.30, 0.01, 0.50],
    figsize=(12, 17),
    dendrogram_ratio=(0.05, 0.2)
)

In [12]:
import logging
import numpy as np
import scipy.stats as stats
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.cluster.hierarchy as sch

# Set up logging
# Messages at INFO level and above are formatted as "<timestamp> - <level> - <message>".
# The heatmap function defined below reports its warnings and the Excel export
# location through the `logging` module.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def differential_expression_heatmap(
    adata, gene_dict, groupby="cell_type", group1="Reference", group2="CKD",
    save_path=None, title_fontsize=16, tick_fontsize=12, annot_fontsize=10,
    colorbar_position=[0.85, 0.75, 0.03, 0.2], figsize=(15, 12), dendrogram_ratio=(0.05, 0.2),
    cell_types=None, min_cells_per_group=10, gene_list=None, cmap="bwr", clustering_method="ward"
):
    """
    Generates a clustered heatmap showing log fold change (logFC) in gene expression
    between two groups across cell types with statistical significance testing.

    For every retained value of `groupby`, the median expression of each gene is
    computed separately for the `group1` and `group2` cells (as labelled in
    ``adata.obs["diseasetype"]``) and ``log2((median2 + p) / (median1 + p))`` is
    plotted, where ``p`` is the smallest non-zero expression value (at least
    1e-6). Each gene / cell-type pair is tested with a two-sided Mann-Whitney U
    test, and heatmap cells with p >= 0.05 are masked (left blank).

    Parameters
    ----------
    adata : AnnData
        Single-cell RNA-seq data. ``adata.var_names`` are matched against the
        values of `gene_dict`, and ``adata.obs`` must contain `groupby` and
        ``"diseasetype"`` columns.
    gene_dict : dict
        Dictionary mapping Gene Names to Ensembl IDs.
    groupby : str, optional
        Column in `adata.obs` defining cell types (default: "cell_type").
    group1 : str, optional
        Reference group (default: "Reference").
    group2 : str, optional
        Experimental/disease group (default: "CKD").
    save_path : str, optional
        Path to save the heatmap figure (default: None, the figure is shown
        instead). When the path ends in .pdf, .png or .jpg, the logFC and
        p-value tables are also written next to it as an .xlsx workbook.
    title_fontsize : int, optional
        Font size of the heatmap title (default: 16).
    tick_fontsize : int, optional
        Font size of the tick labels (default: 12).
    annot_fontsize : int, optional
        Font size of heatmap annotations (default: 10).
    colorbar_position : list, optional
        Position of the colorbar as [x, y, width, height] (default: [0.85, 0.75, 0.03, 0.2]).
    figsize : tuple, optional
        Size of the heatmap figure (default: (15, 12)).
    dendrogram_ratio : tuple, optional
        Ratio of space allocated to the dendrograms (default: (0.05, 0.2)).
    cell_types : list, optional
        List of cell types to include (default: None, includes all).
    min_cells_per_group : int, optional
        Minimum number of cells required per group for a cell type (default: 10).
    gene_list : list, optional
        List of gene names to include (default: None, includes all genes in gene_dict).
    cmap : str, optional
        Colormap passed to seaborn (default: "bwr").
    clustering_method : str, optional
        Linkage method used to cluster the genes (default: "ward").

    Returns
    -------
    matplotlib.figure.Figure
        The heatmap figure.

    Raises
    ------
    ValueError
        If no genes are selected or found in `adata`, if either group has no
        cells after filtering, or if no cell type passes the cell-count filter.

    Notes
    -----
    The p-values are not adjusted for multiple testing.
    """
    # Filter gene_dict if gene_list is provided
    if gene_list:
        wanted_genes = set(gene_list)
        gene_dict = {k: v for k, v in gene_dict.items() if k in wanted_genes}

    if not gene_dict:
        raise ValueError("No genes selected for analysis.")

    # Select the cells of interest with one boolean mask. Indexing AnnData yields
    # a view, so the full dataset does not need to be copied first.
    keep_cells = adata.obs["diseasetype"].isin([group1, group2])
    if cell_types:
        keep_cells = keep_cells & adata.obs[groupby].isin(cell_types)
    adata_subset = adata[keep_cells.to_numpy()]

    # Invert the mapping (Ensembl ID -> gene name); iterating the inverted dict
    # also de-duplicates Ensembl IDs shared by several names.
    ensembl_to_gene = {v: k for k, v in gene_dict.items()}
    valid_ensembl_ids = [gene_id for gene_id in ensembl_to_gene if gene_id in adata_subset.var_names]

    if not valid_ensembl_ids:
        raise ValueError("None of the specified genes are found in the dataset.")

    # Extract expression data
    expression_data = adata_subset[:, valid_ensembl_ids].to_df()
    expression_data = expression_data.join(adata_subset.obs[[groupby, "diseasetype"]])

    # Both groups must have at least one cell; otherwise the per-group lookups
    # below would fail with an opaque KeyError.
    present_groups = set(expression_data["diseasetype"].unique())
    if group1 not in present_groups or group2 not in present_groups:
        raise ValueError(f"One or both groups ({group1}, {group2}) not found in the dataset.")

    # Filter cell types with too few cells
    if min_cells_per_group > 0:
        cell_counts = (
            expression_data.groupby([groupby, "diseasetype"], observed=True)
            .size()
            .unstack(fill_value=0)
        )
        valid_cell_types = cell_counts[
            (cell_counts[group1] >= min_cells_per_group) &
            (cell_counts[group2] >= min_cells_per_group)
        ].index.tolist()

        if not valid_cell_types:
            raise ValueError(f"No cell types with at least {min_cells_per_group} cells per group found.")

        expression_data = expression_data[expression_data[groupby].isin(valid_cell_types)]

    # The values are deliberately NOT z-scored per cell type before the fold change.
    # Z-scores are centred on zero, so a log2 ratio of their medians is undefined
    # for negative values (the resulting NaNs used to be silently turned into 0).
    # The Mann-Whitney test below is rank-based and therefore unaffected by such a
    # monotone per-cell-type rescaling.
    is_group1 = expression_data["diseasetype"] == group1
    is_group2 = expression_data["diseasetype"] == group2

    # Use median for robust calculation
    group1_median = expression_data.loc[is_group1].groupby(groupby, observed=True)[valid_ensembl_ids].median()
    group2_median = expression_data.loc[is_group2].groupby(groupby, observed=True)[valid_ensembl_ids].median()

    # Adaptive pseudocount: smallest non-zero value, floored at 1e-6. The finite
    # check covers the all-zero case, where the minimum is NaN and
    # max(nan, 1e-6) would propagate NaN into every fold change.
    min_nonzero = expression_data[valid_ensembl_ids].replace(0, np.nan).min().min()
    pseudocount = max(min_nonzero, 1e-6) if np.isfinite(min_nonzero) else 1e-6

    # Compute log fold change
    logFC_data = np.log2((group2_median + pseudocount) / (group1_median + pseudocount))

    # A cell type seen in only one group (possible when the cell-count filter is
    # disabled) has no defined fold change; drop its all-NaN row.
    logFC_data = logFC_data.dropna(how="all")
    if logFC_data.empty:
        raise ValueError(f"No '{groupby}' value has cells in both {group1} and {group2}.")

    # Handle remaining problematic values BEFORE clustering
    logFC_data = logFC_data.replace([np.inf, -np.inf], np.nan).fillna(0)

    # Rename columns to gene names (mapping-based, so it cannot be misaligned)
    logFC_data = logFC_data.rename(columns=ensembl_to_gene)
    logFC_data.index.name = "Cell Type"

    # Statistical testing. The p-value table is created AFTER the rename so its
    # labels match logFC_data exactly; seaborn rejects a mask whose index differs
    # from the plotted data. Untested pairs keep p = 1.0.
    pvalues = pd.DataFrame(1.0, index=logFC_data.index, columns=logFC_data.columns, dtype=float)
    for cell_type in logFC_data.index:
        in_cell_type = expression_data[groupby] == cell_type
        group1_cells = expression_data.loc[in_cell_type & is_group1, valid_ensembl_ids]
        group2_cells = expression_data.loc[in_cell_type & is_group2, valid_ensembl_ids]

        # Minimum sample size for validity
        if len(group1_cells) < 3 or len(group2_cells) < 3:
            continue

        for gene_id in valid_ensembl_ids:
            # Use Mann-Whitney U test (non-parametric)
            try:
                _, pval = stats.mannwhitneyu(
                    group1_cells[gene_id], group2_cells[gene_id], alternative='two-sided'
                )
            except ValueError:
                # e.g. all values identical in older SciPy versions: keep p = 1.0
                continue
            if np.isfinite(pval):
                pvalues.loc[cell_type, ensembl_to_gene[gene_id]] = pval

    # Highlight statistically significant changes: everything else is masked out
    sig_mask = pvalues < 0.05

    # Cluster genes; linkage needs at least two genes, otherwise plot unclustered
    heatmap_data = logFC_data.T
    cluster_genes = heatmap_data.shape[0] > 1
    if cluster_genes:
        row_linkage = sch.linkage(heatmap_data, method=clustering_method)
    else:
        logging.warning("Fewer than two genes available. Creating heatmap without clustering.")
        row_linkage = None

    g = sns.clustermap(
        heatmap_data,
        cmap=cmap, center=0, linewidths=0.5,
        annot=True, fmt=".2f",
        row_cluster=cluster_genes, col_cluster=False, row_linkage=row_linkage,
        figsize=figsize, annot_kws={"size": annot_fontsize},
        xticklabels=True, yticklabels=True,
        dendrogram_ratio=dendrogram_ratio,
        cbar_pos=colorbar_position,
        mask=None if sig_mask.empty else (~sig_mask.T)
    )

    # Set title on the heatmap axes (plt.title would target whichever axes is current)
    g.ax_heatmap.set_title(f"Log Fold Change (logFC) in Gene Expression: {group2} vs {group1}", fontsize=title_fontsize)

    # Rotate x-axis labels and set font sizes
    plt.setp(g.ax_heatmap.xaxis.get_majorticklabels(), rotation=45, ha="right", fontsize=tick_fontsize)
    plt.setp(g.ax_heatmap.yaxis.get_majorticklabels(), fontsize=tick_fontsize)

    # Save or show plot
    if save_path:
        g.fig.savefig(save_path, bbox_inches="tight", dpi=300)

        # Save the data as Excel file with p-values
        save_path_text = str(save_path)
        if save_path_text.endswith(('.pdf', '.png', '.jpg')):
            excel_path = save_path_text.rsplit('.', 1)[0] + '.xlsx'
            with pd.ExcelWriter(excel_path) as writer:
                logFC_data.to_excel(writer, sheet_name='Log Fold Changes')
                pvalues.to_excel(writer, sheet_name='P-values')
            logging.info(f"Data and p-values saved to {excel_path}")
    else:
        plt.show()

    return g.fig

In [None]:
# Draw the CKD-vs-Reference heatmap per cell type using the compact layout
# (small fonts, tall figure, colorbar on the right edge). No save path is
# given, so the figure is shown rather than written to disk.
differential_expression_heatmap(
    adata,
    gene_dict,
    groupby="cell_type",
    group1="Reference",
    group2="CKD",
    title_fontsize=14,
    tick_fontsize=8,
    annot_fontsize=6,
    colorbar_position=[0.98, 0.30, 0.01, 0.50],
    figsize=(12, 17),
    dendrogram_ratio=(0.05, 0.2),
)