# Single-Cell Transcriptomic Analysis

In [None]:

import scanpy as sc
import os
import gseapy as gp
import numpy as np
import matplotlib.patheffects as PathEffects
from adjustText import adjust_text
import matplotlib.pyplot as plt
from pydeseq2.dds import DeseqDataSet
from pydeseq2.ds import DeseqStats
import decoupler as dc
import pandas as pd
import seaborn as sns
from matplotlib.patches import Patch


##### In the first step, we access the read count matrices for each of the replicates and then combine them into a single annotation data object.

In [None]:
def generate_combined_anndata(folder):
    """
    Load every ``.csv`` count matrix in a folder and combine them into one AnnData object.

    Each ``.csv`` file is read with ``sc.read_csv`` and transposed, and the
    per-file objects are then merged with ``sc.concat``. Files are visited in
    sorted name order so the order of the concatenated observations is
    reproducible (``os.listdir`` returns entries in arbitrary order).

    Parameters:
        - folder (str): Folder path containing read count matrix files from the different samples.

    Returns:
        - combined_anndata (adata): A combined Annotation Data object generated from count matrices.

    Raises:
        - ValueError: If the folder does not contain any ``.csv`` file.
    """

    anndata_list = []
    # Iterate through the folder entries in a deterministic order
    for filename in sorted(os.listdir(folder)):
        # Print the filename being accessed
        print(filename)

        # Only the .csv count matrices are loaded; any other entry is skipped
        if not filename.endswith(".csv"):
            continue

        # Load the matrix and transpose it
        # (presumably the files store genes as rows and cells as columns — TODO confirm)
        anndata = sc.read_csv(os.path.join(folder, filename)).T

        # Append the AnnData object to the list
        anndata_list.append(anndata)

        print(f" successfully loaded {folder}/{filename}")

    # Fail with a clear message instead of an opaque concatenation error
    if not anndata_list:
        raise ValueError(f"No .csv count matrices were found in '{folder}'")

    # Generate and return the combined AnnData object
    combined_anndata = sc.concat(anndata_list)
    return combined_anndata

# Build the combined AnnData object from every .csv count matrix in the GSE171524_RAW folder.
# The path is relative to the notebook's working directory.
combined_anndata = generate_combined_anndata("../Single-Cell-Transcriptomics/GSE171524_RAW")

##### We load the metadata file containing the cell information and use it for cell type annotation of the cells in the combined AnnData object.

In [None]:
# Load the tab-separated metadata file, using its first column as the index
metadata=pd.read_csv("../Single-Cell-Transcriptomics/GSE171524_lung_metaData.txt",sep="\t", index_col=0)

# Annotate the annotation data object using the information from the metadata file.
# The merge is a left join on the index, so every cell in the AnnData object is kept;
# cells whose index is absent from the metadata get NaN in the metadata columns.
combined_anndata.obs = combined_anndata.obs.merge(metadata, left_index=True, right_index=True, how='left')

##### We now separate the AT2 cells for further investigation.

In [None]:
# Cell types (values of the "cell_type_fine" annotation) to keep for the analysis
cell_types= ["AT2"]

# Keep only these cell types and filter out the rest
combined_anndata=combined_anndata[combined_anndata.obs["cell_type_fine"].isin(cell_types)]

##### We check the total number of AT2 Cells in the combined Annotation data object

In [None]:
# Non-null entry counts of every annotation column, grouped by fine cell type
combined_anndata.obs.groupby("cell_type_fine").count()

##### We check the cell counts for each biosample used in the study.

In [None]:
# Non-null entry counts of every annotation column, grouped by fine cell type and biosample
combined_anndata.obs.groupby(["cell_type_fine", "biosample_id"]).count()

##### We check the total number of AT2 Cells in the combined Annotation data object coming from the SARS-CoV-2 patients and Control samples


In [None]:
# Non-null entry counts of every annotation column, grouped by disease label
combined_anndata.obs.groupby("disease__ontology_label").count()

##### Next, we carry out differential expression analysis through pseudobulking and visualize the results


In [None]:
# Generate the pseudo-bulk profiles: counts are summed (mode='sum') per
# biosample and per 'cell_type_intermediate' group, with thresholds of
# min_cells=25 and min_counts=2500 passed to decoupler.
# NOTE(review): the cells were filtered on 'cell_type_fine' above but are grouped
# here on 'cell_type_intermediate' — confirm both annotations agree for AT2 cells.
pseudobulk_data = dc.get_pseudobulk(
    combined_anndata,
    sample_col='biosample_id',
    groups_col='cell_type_intermediate',
    mode='sum',
    min_cells=25,
    min_counts=2500
)

In [None]:
# Display the per-replicate annotation table of the pseudobulk profiles
pseudobulk_data.obs

In [None]:
# Number of infected and control pseudobulk replicates
# (non-null entry counts per annotation column, grouped by disease label)
pseudobulk_data.obs.groupby("disease__ontology_label").count()

In [None]:
# Carry out DEA using DESeq2 with the disease label as the only design factor
dds = DeseqDataSet(
    adata=pseudobulk_data,
    design_factors=['disease__ontology_label'],
)

# Generate and get the DEA analysis results
dds.deseq2()
# The contrast compares 'COVID-19' against 'normal'.
# NOTE(review): the factor is spelled with hyphens here but with underscores in
# design_factors above; this presumably relies on the installed pydeseq2 version
# rewriting underscores in design-factor names to hyphens — confirm before upgrading.
stat_ressults = DeseqStats(dds, contrast=('disease--ontology-label', 'COVID-19', 'normal'))
stat_ressults.summary()
# Results table used by all downstream cells
dea_results  = stat_ressults.results_df

In [None]:
# The genes of interest we want to investigate, grouped into three lists
antiviral_zap=["ZC3HAV1"]
apobecs=['AICDA', 'APOBEC1', 'APOBEC2', 'APOBEC3A', 'APOBEC3B', 'APOBEC3C', 'APOBEC3D', 'APOBEC3F', 'APOBEC3G', 'APOBEC3H', 'APOBEC4']
adars=['ADAR', 'ADARB1', 'ADARB2']
# Combined list; its order is the row order used by the heatmaps below
gene_list=antiviral_zap+apobecs+adars

# Make directory to store the single cell analysis images (no error if it already exists)
os.makedirs('../Results/sc_transcriptomic_analysis/', exist_ok=True)

In [None]:
# Save the DESeq2 result rows of the genes of interest to a CSV file.
# Genes of interest that are absent from the results index are simply not written.
dea_results[dea_results.index.isin(gene_list)].to_csv("../Results/sc_transcriptomic_analysis/DESEQ2_results_genes_of_interest.csv")

In [None]:
# Plot the DEA results volcano plots highlighting the genes of interests

def plot_volcano(data, gene_list, logfc_threshold=.75, padj_threshold=0.05, adjust=True,
                 save_path='../Results/sc_transcriptomic_analysis/sc_transcriptomic_analysis_volcano_plot.png'):
    """
    Plot a volcano plot to visualize differential gene expression analysis results.

    Every gene is drawn at (log2 fold change, -log10 adjusted p-value). Genes from
    ``gene_list`` are overlaid with a distinct marker and labelled; a label gets a
    star when the gene's adjusted p-value is at or below ``padj_threshold``.
    The figure is saved to ``save_path`` and then shown.

    Parameters:
    - data (pandas.DataFrame): The data containing gene expression analysis results,
      indexed by gene, with 'padj' and 'log2FoldChange' columns.
    - gene_list (list): A list of genes of interest to be highlighted in the plot.
    - logfc_threshold (float, optional): The threshold for log2 fold change. Default is 0.75.
    - padj_threshold (float, optional): The threshold for adjusted p-value. Default is 0.05.
    - adjust (bool, optional): Whether to adjust the position of gene labels to avoid overlap. Default is True.
    - save_path (str, optional): Where the figure is written. Defaults to the
      volcano plot file inside '../Results/sc_transcriptomic_analysis/'.
    """
    plt.figure(figsize=(10, 10))

    # Filter points based on thresholds
    nonsig_points = data[(data['padj'] > padj_threshold) | (abs(data['log2FoldChange']) < logfc_threshold)]
    sig_points = data[(data['padj'] <= padj_threshold) & (abs(data['log2FoldChange']) >= logfc_threshold)]
    sig_pval_only = data[data['padj'] <= padj_threshold]

    # Plot non-significant points
    plt.scatter(nonsig_points['log2FoldChange'], -np.log10(nonsig_points['padj']), c='#DCDCDC', edgecolor='#E0E0E0', label='Not Statistically Significant', marker='o')

    # Plot significant points
    plt.scatter(sig_points['log2FoldChange'], -np.log10(sig_points['padj']), c='#A9A9A9', edgecolor='#E0E0E0', label='Statistically Significant', marker='o')

    # Plot genes in the filtered gene list with a different marker
    genes_of_interest = data[data.index.isin(gene_list)]
    plt.scatter(genes_of_interest['log2FoldChange'], -np.log10(genes_of_interest['padj']), c='b', edgecolor='none', label='Genes of Interest', marker='^')

    # Draw horizontal line for padj threshold
    plt.axhline(y=-np.log10(padj_threshold), color='#909090', linestyle='--', linewidth=1)

    # Draw vertical lines for log2 fold change thresholds
    plt.axvline(x=logfc_threshold, color='#909090', linestyle='--', linewidth=1)
    plt.axvline(x=-logfc_threshold, color='#909090', linestyle='--', linewidth=1)

    # Label points in the filtered gene list
    texts = []
    for gene in gene_list:
        gene_data = data[data.index == gene]
        if gene_data.empty:
            continue

        x = gene_data['log2FoldChange'].values[0]
        y = -np.log10(gene_data['padj'].values[0])

        # A gene with a missing padj/fold change (or padj == 0) has no finite
        # position on the plot; skip its label so that no text object with
        # NaN/inf coordinates is handed to adjust_text.
        if not (np.isfinite(x) and np.isfinite(y)):
            continue

        gene_label = gene + (r'$^{★}$' if gene in sig_pval_only.index else '')
        txt = plt.text(x, y, gene_label, size=12, ha='center', va='center')
        # White outline keeps the label readable on top of the markers
        txt.set_path_effects([PathEffects.withStroke(linewidth=3, foreground='w')])
        texts.append(txt)

    # Add legend for the statistically significant genes of interest which are highlighted with a '★' marker
    # (an empty scatter is used only to create the legend entry)
    plt.scatter([], [], c='black', marker="*",edgecolor='none', s=70, label=f'Statistically Significant Genes of \nInterest ($p_{{adj}}$ < {padj_threshold})')

    plt.legend(fontsize=12)
    plt.xlabel("$log_{2}$ fold change", size=15)
    plt.ylabel("-$log_{10}$ $p_{adj}$-value", size=15)

    if adjust:
        adjust_text(texts, arrowprops=dict(arrowstyle="-", color='k', lw=0.5))

    # Save the plot
    plt.savefig(save_path, dpi = 1000)
    print(f"The Volcano plot has been saved at '{save_path}'")

    plt.show()



# Plot the volcano plot of the DEA results, highlighting the genes of interest
# (default thresholds: |log2 fold change| >= 0.75 and padj <= 0.05)
plot_volcano(dea_results, gene_list)


In [None]:
# Carry out Geneset Enrichment Analysis of the Oxidative damage reponse geneset

def run_gene_set_enrichment_analysis(dea_results, oxidative_stress_gene_set):
    """
    Run a pre-ranked GSEA with GSEApy and save the enrichment plot of the
    WP_OXIDATIVE_DAMAGE_RESPONSE term.

    Genes are ranked by ``-ln(padj) * log2FoldChange``. The ranking score is
    written into ``dea_results`` as the column "-log_padj x lfc", so the
    DataFrame passed in is modified in place.

    Parameters:
        - dea_results(pandas dataframe): DESeq2 results with columns 'GeneSymbol', 'padj', and 'log2FoldChange'.
        - oxidative_stress_gene_set (str): Path to the Oxidative Damage response gene set file.
    """
    score_column = "-log_padj x lfc"
    gsea_plot_path = "../Results/sc_transcriptomic_analysis/sc_transcriptomic_analysis_gsea_plot.png"

    # Ranking metric: signed fold change weighted by the (natural-log) adjusted p-value
    dea_results[score_column] = -np.log(dea_results["padj"]) * dea_results["log2FoldChange"]

    # Keep only the gene symbol and its score, drop genes without a score,
    # order from highest to lowest score and discard the old index
    ranking = dea_results[['GeneSymbol', score_column]]
    ranking = ranking.dropna()
    ranking = ranking.sort_values(score_column, ascending=False)
    ranking = ranking.reset_index(drop=True)

    # Run the pre-ranked Gene Set Enrichment Analysis (fixed seed for reproducibility)
    prerank_result = gp.prerank(
        rnk=ranking,
        gene_sets=oxidative_stress_gene_set,
        seed=6,
        min_size=0,
        max_size=200,
    )

    # Draw the enrichment plot for the oxidative damage response term
    prerank_result.plot(terms="WP_OXIDATIVE_DAMAGE_RESPONSE", figsize=(10, 10))

    # Write the current figure to disk and report where it went
    plt.savefig(gsea_plot_path, dpi=1000)
    print(f"The GSEA plot has been saved at '{gsea_plot_path}'")
    
# Explicitly add the GeneSymbol column (a copy of the index) to the DataFrame;
# it is read by the GSEA ranking and by the heatmap functions below
dea_results["GeneSymbol"]=dea_results.index

# Run the gene set enrichment analysis on the oxidative damage response gene set
run_gene_set_enrichment_analysis(dea_results, "../WP_OXIDATIVE_DAMAGE_RESPONSE.gmt")

In [None]:
# Plot the log fold changes heatmap for the genes of interest

def plot_lfd_heatmap(dea_results, gene_list):
    """
    Plot the Log2 Fold Changes heatmap for the Genes of Interest

    The heatmap has one row per entry of ``gene_list`` (in the given order) and a
    single 'log2FoldChange' column. Genes that are absent from ``dea_results`` are
    stored as NaN and masked in the plot. The figure is saved to
    '../Results/sc_transcriptomic_analysis/' and then shown.

    Parameters:
        - dea_results(pandas dataframe): DESeq2 results with columns 'GeneSymbol', 'padj', and 'log2FoldChange'.
        - gene_list (list): List of the Gene Symbols of the genes of interest.
    """
    # Build a gene -> log2 fold change lookup once instead of scanning the whole
    # table for every gene; for a repeated symbol the first row is used.
    first_rows = dea_results.drop_duplicates(subset="GeneSymbol", keep="first")
    lfc_by_gene = first_rows.set_index("GeneSymbol")["log2FoldChange"]

    # Create a DataFrame for the log2 fold changes. Missing genes become NaN and the
    # column is always float, even when none of the genes is present in the results.
    log2_fold_changes_df = pd.DataFrame(
        {"log2FoldChange": lfc_by_gene.reindex(gene_list).to_numpy(dtype=float)},
        index=gene_list,
    )

    # Plot the heatmap on a narrow, fixed-size figure (single column of cells)
    plt.figure(figsize=(0.75, 11))
    heatmap = sns.heatmap(
        log2_fold_changes_df,
        cmap="coolwarm",
        annot=True,
        fmt=".2f",
        cbar=False,
        annot_kws={"size": 10},
        mask=log2_fold_changes_df.isnull()  # Mask the missing values for better visualization
    )

    # Turn off x-axis ticks
    plt.xticks([])

    # Set the font size for y-axis ticks
    heatmap.yaxis.set_tick_params(labelsize=10)

    plt.tight_layout()

    # Save the heatmap
    plt.savefig('../Results/sc_transcriptomic_analysis/sc_transcriptomic_analysis_log2foldchange_heatmap.png', dpi=1000, bbox_inches='tight')
    print("The heatmap has been saved at '../Results/sc_transcriptomic_analysis/sc_transcriptomic_analysis_log2foldchange_heatmap.png'")

    plt.show()


# Plot and save the log2 fold change heatmap for the genes of interest
plot_lfd_heatmap(dea_results, gene_list)

In [None]:
# Plot the Base Mean heatmap for the genes of interest

def plot_basemean_expression_heatmap(dea_results, gene_list):
    """
    Draw a single-column heatmap of DESeq2 baseMean values for the genes of interest.

    One row is drawn per entry of ``gene_list`` plus a final row holding the mean
    baseMean of the WP_OXIDATIVE_DAMAGE_RESPONSE genes found in ``dea_results``
    (the gene set is read from '../WP_OXIDATIVE_DAMAGE_RESPONSE.gmt'). Genes of
    interest missing from the results are NaN and masked. The figure is saved and shown.

    Parameters:
        - dea_results (pandas.DataFrame): DESeq2 results with columns 'GeneSymbol', 'baseMean'.
        - gene_list (list): List of gene symbols of interest.
    """
    gene_set_label = "WP_OXIDATIVE_\nDAMAGE_RESPONSE"

    def _first_basemean(symbol):
        # baseMean of the first row matching the symbol, NaN when there is none
        matches = dea_results.loc[dea_results["GeneSymbol"] == symbol, "baseMean"]
        if matches.empty:
            return np.nan
        return matches.values[0]

    # baseMean of each gene of interest, in gene_list order
    goi_basemeans = [_first_basemean(symbol) for symbol in gene_list]

    # Genes of the oxidative damage response gene set
    oxidative_damage_genes = gp.parser.read_gmt('../WP_OXIDATIVE_DAMAGE_RESPONSE.gmt')["WP_OXIDATIVE_DAMAGE_RESPONSE"]

    # baseMean of every gene-set member that appears in the results, then their mean
    known_symbols = dea_results["GeneSymbol"].values
    oxidative_basemeans = []
    for symbol in oxidative_damage_genes:
        if symbol in known_symbols:
            oxidative_basemeans.append(_first_basemean(symbol))
    oxidative_damage_mean = np.nanmean(oxidative_basemeans)

    # Table to plot: genes of interest followed by the gene-set mean
    row_values = goi_basemeans + [oxidative_damage_mean]
    row_labels = gene_list + [gene_set_label]
    goi_basemeans_df = pd.DataFrame(data=row_values, index=row_labels, columns=["baseMean"])

    # Narrow fixed-size figure for the single column of cells
    plt.figure(figsize=(0.75, 12.5))
    heatmap = sns.heatmap(
        goi_basemeans_df,
        cmap="viridis",
        annot=True,
        fmt=".2f",
        cbar=False,
        annot_kws={"size": 10},
        mask=goi_basemeans_df.isnull(),  # hide the missing genes
    )

    # Horizontal y tick labels, no x ticks
    heatmap.set_yticklabels(heatmap.get_yticklabels(), rotation=0, fontsize=10)
    plt.xticks([])

    plt.tight_layout()

    # Save with a tight bounding box, then report and display
    save_path = '../Results/sc_transcriptomic_analysis/sc_transcriptomic_analysis_basemean_heatmap.png'
    plt.savefig(save_path, dpi=1000, bbox_inches='tight', pad_inches=0.1)
    print(f"The heatmap has been saved at '{save_path}'")

    plt.show()

# Plot and save the baseMean heatmap for the genes of interest
plot_basemean_expression_heatmap(dea_results, gene_list)

In [None]:

def plot_heatmap_with_clustermap(dds, gene_list, metadata):
    """
    Plot a heatmap using the genes of interest for the different replicates.

    The heatmap shows log1p-transformed values of the 'normed_counts' layer for each
    gene in ``gene_list`` plus one extra row holding, per replicate, the mean
    normalized count of the WP_OXIDATIVE_DAMAGE_RESPONSE genes (read from
    '../WP_OXIDATIVE_DAMAGE_RESPONSE.gmt'). Replicates are annotated with a color
    bar (blue for control, red for COVID-19). The figure is saved and shown.

    Args:
        dds (object): The DESeq2 object; its 'normed_counts' layer is read.
        gene_list (list): The list of genes to be used in constructing the heatmap.
        metadata (DataFrame): The DataFrame containing metadata information of the replicates.
            It is indexed by replicate and its 'group' column is compared against
            "Control" and "COVID-19".

    Returns:
        DataFrame: The table that was plotted (rows: ``gene_list`` followed by the
        gene-set row; columns: replicates), log1p-transformed, with missing values
        replaced by 0 and the column names shortened by four characters.
    """
    # Subset the dataset to the genes of interest that are present in it
    dds_goi = dds[:, dds.var_names.isin(gene_list)]

    # Create a DataFrame for plotting the heatmap (genes as rows, replicates as columns)
    goi_norm_counts_df = pd.DataFrame(
        dds_goi.layers['normed_counts'].T, 
        index=dds_goi.var_names, 
        columns=dds_goi.obs_names
    )

    # Extract the oxidative damage response genes from the WP_OXIDATIVE_DAMAGE_RESPONSE gene set
    oxidative_damage_genes = gp.parser.read_gmt('../WP_OXIDATIVE_DAMAGE_RESPONSE.gmt')["WP_OXIDATIVE_DAMAGE_RESPONSE"]

    # Subset the dataset to include only the oxidative damage response genes
    oxidative_damage_response_genes_dds = dds[:, dds.var_names.isin(oxidative_damage_genes)]

    # Create a DataFrame of the normalized counts for the oxidative damage response genes
    oxidative_damage_genes_norm_counts_df = pd.DataFrame(
        oxidative_damage_response_genes_dds.layers['normed_counts'].T, 
        index=oxidative_damage_response_genes_dds.var_names, 
        columns=oxidative_damage_response_genes_dds.obs_names
    )

    # Compute, for each replicate (column), the mean normalized count over the gene-set genes
    oxidative_damage_genes_mean_norm_counts = oxidative_damage_genes_norm_counts_df.mean().tolist()

    # Wrap the per-replicate means in a one-row DataFrame and append it below the genes of interest
    oxidative_damage_genes_mean_norm_counts_df=pd.DataFrame(data=[oxidative_damage_genes_mean_norm_counts], index=["WP_OXIDATIVE_\nDAMAGE_RESPONSE"], columns=oxidative_damage_response_genes_dds.obs_names)

    goi_norm_counts_df=pd.concat([goi_norm_counts_df, oxidative_damage_genes_mean_norm_counts_df])

    # Log1p transform the data (the gene-set row is therefore the log1p of the mean)
    goi_norm_counts_df = np.log1p(goi_norm_counts_df)

    # Get the indexes of control and covid conditions
    # NOTE(review): the labels "Control" / "COVID-19" must match the values stored in
    # metadata['group']; replicates with any other label are dropped from the plot — confirm.
    indexes_condition_control = metadata[metadata['group'] == "Control"].index.tolist()
    indexes_condition_covid = metadata[metadata['group'] == "COVID-19"].index.tolist()
    indexes_order = indexes_condition_control + indexes_condition_covid


    # Keep only the control/covid replicates and put the rows in gene_list order;
    # genes of interest missing from the dataset become all-NaN rows, then 0
    cleaned_data = goi_norm_counts_df[indexes_order]
    cleaned_data = cleaned_data.reindex(gene_list+["WP_OXIDATIVE_\nDAMAGE_RESPONSE"])
    cleaned_data.fillna(0, inplace=True)
    # The columns are re-sorted alphabetically here, which replaces the
    # control-then-covid order established above
    cleaned_data = cleaned_data[sorted(cleaned_data.columns)]

    
    # Create a DataFrame to use as col_colors (same column order as cleaned_data)
    col_colors_df = pd.DataFrame(index=cleaned_data.columns)

    # highlight the control and the covid samples with different colors
    col_colors_df.loc[indexes_condition_control, 'color'] = 'blue'  # Color for control columns
    col_colors_df.loc[indexes_condition_covid, 'color'] = 'red'  # Color for covid columns

    # Drop the last four characters of every column name
    # (presumably the cell-type suffix appended during pseudobulking — TODO confirm).
    # This happens after the colors were assigned, so col_colors_df keeps the original names.
    cleaned_data.columns = cleaned_data.columns.str[:-4]

    # Generate heatmap with marked columns; no row or column clustering is performed.
    # NOTE(review): the colors are passed as a list wrapping the Series, and their labels no
    # longer equal the renamed columns — presumably they are matched by position; confirm.
    # NOTE(review): seaborn documents z_score as 0 (rows) or 1 (columns); -1 is not a
    # documented value — confirm the intended standardization axis.
    heatmap = sns.clustermap(cleaned_data, method='average', z_score=-1, cmap='RdYlBu_r', col_cluster=False, row_cluster=False, col_colors=[col_colors_df['color']],  figsize=(18, 8), cbar_pos=(0.1, .6, 0.03, 0.18) )
    heatmap.ax_heatmap.set_xticklabels(heatmap.ax_heatmap.get_xticklabels(), fontsize=12)
    heatmap.ax_heatmap.set_yticklabels(heatmap.ax_heatmap.get_yticklabels(), fontsize=12)
    heatmap.ax_heatmap.set_ylabel("")


    # Create the legend entries for the two condition colors
    legend = [Patch(color=color, label=label) for label, color in {'Control': 'blue', 'COVID-19': 'red'}.items()]

    # Add legend to the plot
    plt.legend(handles=legend, loc='upper left', bbox_to_anchor=(-.1, -0.2), frameon=False)

    # Save the plot
    plt.savefig('../Results/sc_transcriptomic_analysis/sc_transcriptomic_analysis_normalized_expression_heatmap.png', dpi=1000)
    print("Plot saved at '../Results/sc_transcriptomic_analysis/sc_transcriptomic_analysis_normalized_expression_heatmap.png'")
    # Show the plot
    plt.show()
    return cleaned_data

# Generate the metadata for the samples and their associated condition.
# This rebinds the name `metadata` (previously the cell-level metadata table)
# to the 'group' column of the pseudobulk replicates.
metadata = pseudobulk_data.obs[[ "group"]]

# Call the function to plot the heatmap (the returned table is displayed as the cell output)
plot_heatmap_with_clustermap(dds, gene_list, metadata)


In [None]:
# Display a summary of the DESeq2 dataset object
dds