In [None]:
from scipy.stats import median_abs_deviation
import pandas as pd
import scvi
import scanpy as sc
import os
import gseapy as gp
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.patheffects as PathEffects
from adjustText import adjust_text
import matplotlib.pyplot as plt
from pydeseq2.dds import DeseqDataSet
from pydeseq2.ds import DeseqStats
import seaborn as sns
import celltypist


In the first step, we load the raw read-count matrices for each sample, preprocess them to remove doublets and low-quality cells, and then combine them into a single AnnData (annotated data) object.

In [None]:
def outlier_removal(anndata, nmads: int = 5):
    """
    Remove outlier cells from an AnnData object using median absolute deviation (MAD) thresholds.

    Mitochondrial ("MT-") and ribosomal ("RPS"/"RPL") genes are flagged in ``var``, and scanpy QC
    metrics are computed in place. A cell is flagged as an outlier when any of
    ``log1p_total_counts``, ``log1p_n_genes_by_counts`` or ``pct_counts_in_top_20_genes`` lies
    strictly more than ``nmads`` MADs below or above that metric's median.

    Args:
        anndata: The input annotation data matrix. If it is an AnnData view, it is copied first so
            the QC columns can be written to a real object.
        nmads: Number of MADs from the median beyond which a cell counts as an outlier (default 5).
    Returns:
        annData: A new AnnData holding only the non-outlier cells, carrying the QC metric columns
            and an ``obs["outlier"]`` column.
    """
    # Writing var/obs columns on a view would implicitly convert it (with a warning); do it explicitly
    if anndata.is_view:
        anndata = anndata.copy()

    # annotate the mitochondrial genes
    anndata.var["mt"] = anndata.var_names.str.startswith("MT-")
    # annotate the ribosomal genes
    anndata.var["ribo"] = anndata.var_names.str.startswith(("RPS", "RPL"))
    # compute quality control metrics (log1p=True adds the log1p_* columns used below)
    sc.pp.calculate_qc_metrics(
        anndata, qc_vars=["mt", "ribo"], inplace=True, percent_top=[20], log1p=True
    )

    def is_outlier(adata, metric: str, nmads: int):
        """Flag cells whose ``metric`` lies strictly outside median +/- nmads * MAD."""
        values = adata.obs[metric]
        center = np.median(values)
        spread = nmads * median_abs_deviation(values)
        return (values < center - spread) | (center + spread < values)

    # annotate the outliers
    anndata.obs["outlier"] = (
        is_outlier(anndata, "log1p_total_counts", nmads)
        | is_outlier(anndata, "log1p_n_genes_by_counts", nmads)
        | is_outlier(anndata, "pct_counts_in_top_20_genes", nmads)
    )
    # filter out the outliers
    anndata = anndata[~anndata.obs["outlier"]].copy()

    return anndata


In [None]:
def generate_combined_anndata(folder):
    """
    Load every raw count matrix in ``folder`` and remove doublets, outliers and low-quality cells.

    Each ``.csv`` file is read with ``sc.read_csv`` and transposed, so the CSV columns become the
    observations (cells). For each file:
      1. A working copy is filtered (genes detected in >= 10 cells, cells with >= 200 genes),
         reduced to the top 2000 highly variable genes (seurat_v3), and used to train scVI and
         then SOLO for doublet calling.
      2. Cells that SOLO labels "doublet" with a doublet-minus-singlet score above 1 are removed
         from the unfiltered matrix.
      3. The sample name (second "_"-separated field of the filename) is stored in ``obs["Sample"]``.
      4. MAD-based outliers are removed with ``outlier_removal``.

    Parameters:
    - folder (str): Folder path containing raw read count matrix files from the different samples.

    Returns:
    - anndata_list (list): One filtered AnnData object per CSV file, in sorted filename order.
    """
    anndata_list = []
    # os.listdir order is arbitrary; sort so the sample order (and the concatenated object) is reproducible
    for filename in sorted(os.listdir(folder)):
        # Check the extension before parsing the name: non-CSV entries without "_" would raise IndexError
        if not filename.endswith(".csv"):
            continue
        sample_name = filename.split("_")[1]
        print(filename, sample_name[-3:])
        path = os.path.join(folder, filename)

        # Load once and make gene and cell names unique up front. The doublet calls (indexed by cell
        # name) then match the matrix they are applied to, and the kept matrix has unique gene names.
        raw = sc.read_csv(path).T
        raw.var_names_make_unique()
        raw.obs_names_make_unique()
        print("initial number of cells:", raw.n_obs)

        # DOUBLET REMOVAL, trained on a filtered working copy (`raw` itself is left unfiltered)
        working = raw.copy()
        sc.pp.filter_genes(working, min_cells=10)
        sc.pp.filter_cells(working, min_genes=200)
        # Subset to the top 2000 highly variable genes
        sc.pp.highly_variable_genes(working, flavor="seurat_v3", n_top_genes=2000, subset=True)
        # Train the scVI model, then SOLO on top of it for doublet prediction
        scvi.model.SCVI.setup_anndata(working)
        vae = scvi.model.SCVI(working)
        vae.train()
        solo = scvi.external.SOLO.from_scvi_model(vae)
        solo.train()
        doublet_prediction_df = solo.predict()
        doublet_prediction_df["prediction"] = solo.predict(soft=False)
        # Keep only confident doublets: predicted "doublet" and doublet - singlet score greater than 1
        doublet_prediction_df["dif"] = doublet_prediction_df.doublet - doublet_prediction_df.singlet
        doublets = doublet_prediction_df[
            (doublet_prediction_df.prediction == "doublet") & (doublet_prediction_df.dif > 1)
        ]
        raw.obs["doublet"] = raw.obs_names.isin(doublets.index)
        # .copy() so the in-place QC annotations below act on a real AnnData, not a view
        anndata = raw[~raw.obs["doublet"]].copy()
        print("Cells after doublet removal:", anndata.n_obs)

        # ASSIGN SAMPLE NAME
        anndata.obs["Sample"] = sample_name

        # FILTER OUTLIERS
        anndata = outlier_removal(anndata)
        print("Cells after qc filtering removal:", anndata.n_obs)

        anndata_list.append(anndata)
        print(f" successfully loaded {folder}/{filename}")

    return anndata_list

# Generate the combined AnnData object.
# index_unique="-" appends each object's position in the list ("-0", "-1", ...) to its cell names.
# Cell barcodes may repeat across samples, and without this the concatenated obs_names can be duplicated.
combined_anndata = sc.concat(generate_combined_anndata("../Covid-19-sc"), index_unique="-")

In [None]:
# ADDITIONAL FILTERING TO REMOVE LOW QUALITY CELLS
# Keep cells that pass all three thresholds at once:
#   - at most 10% of counts from mitochondrial genes
#   - at most 10% of counts from ribosomal genes
#   - fewer than 15000 total counts
combined_anndata = combined_anndata[
    (combined_anndata.obs["pct_counts_mt"] <= 10)
    & (combined_anndata.obs["pct_counts_ribo"] <= 10)
    & (combined_anndata.obs["total_counts"] < 15000)
].copy()
# Then drop cells with fewer than 200 detected genes (in place)
sc.pp.filter_cells(combined_anndata, min_genes=200)

In [None]:
# PLOTTING THE QC METRICS
# The return values of these plotting calls were never used (and p2 was overwritten), so they are
# not stored.
# Distribution of total counts per cell
sns.displot(combined_anndata.obs["total_counts"], bins=100, kde=False)
# Per-cell percentage of counts from mitochondrial and ribosomal genes
sc.pl.violin(combined_anndata, "pct_counts_mt")
sc.pl.violin(combined_anndata, "pct_counts_ribo")
# Total counts vs number of detected genes, coloured by mitochondrial percentage
sc.pl.scatter(combined_anndata, "total_counts", "n_genes_by_counts", color="pct_counts_mt")

Next, we annotate the cell types using CellTypist.

In [None]:
# Store the raw counts in the "raw_counts" layer: X is normalised in place below, and the
# differential expression analysis later needs the counts.
combined_anndata.layers["raw_counts"] = combined_anndata.X.copy()

# CellTypist requires log1p-normalised data with each cell scaled to a library size of 1e4
sc.pp.normalize_total(combined_anndata, target_sum=1e4)
sc.pp.log1p(combined_anndata)

# Predict the cell types with the lung COVID-19 model, using majority voting
cell_type_predictions = celltypist.annotate(combined_anndata, model='Lethal_COVID19_Lung.pkl', majority_voting=True)
# to_adata() inserts the prediction columns (including "majority_voting", used by later cells) into
# the AnnData and returns it. Rebind combined_anndata to the returned object explicitly, instead of
# relying on it being the same object that was passed to annotate().
predictions_adata = cell_type_predictions.to_adata()
combined_anndata = predictions_adata

# Add the biological condition, derived from the suffix of the sample name
combined_anndata.obs['Condition'] = combined_anndata.obs['Sample'].apply(lambda x: 'covid' if x.endswith('cov') else 'control')

We now subset the epithelial cells for further investigation.

In [None]:
# Epithelial cell types in the lung, as labels expected in the CellTypist "majority_voting" column
cell_types = ["AT1", "AT2", "Airway goblet", "Airway ciliated", "Airway basal", "Airway club", 'Airway mucous', "ECM-high epithelial", "Cycling epithelial"]
# Keep only these cell types and filter out the rest. .copy() materialises the subset, so the later
# in-place steps (UMAP, restoring raw counts into X) act on an independent object, not an AnnData view.
combined_anndata = combined_anndata[combined_anndata.obs["majority_voting"].isin(cell_types)].copy()

In [None]:
# PLOT THE UMAP
# sc.tl.umap requires a neighbourhood graph in .uns["neighbors"]. This notebook never calls
# sc.pp.neighbors, so build one when it is missing: PCA on the current X (log-normalised expression
# at this point), then the kNN graph.
if "neighbors" not in combined_anndata.uns:
    sc.pp.pca(combined_anndata)
    sc.pp.neighbors(combined_anndata)
# Compute the UMAP embedding
sc.tl.umap(combined_anndata)
# UMAP coloured by cell type
sc.pl.umap(combined_anndata, color=['majority_voting'], frameon=True, legend_fontsize=8)
# UMAP coloured by sample, then by biological condition
sc.pl.umap(combined_anndata, color="Sample", frameon=True)
sc.pl.umap(combined_anndata, color="Condition", frameon=True)

Next, we carry out differential expression analysis through pseudobulking and visualize the results


In [None]:
# Differential expression is run on raw counts, so put them back into X from the saved layer
combined_anndata.X = combined_anndata.layers["raw_counts"].copy()

# Genes of interest: the antiviral ZAP gene followed by the RNA-editing enzymes
# (AICDA, APOBEC* and ADAR* genes), in a fixed order used for the heatmap rows
antiviral_zap = ["ZC3HAV1"]
rna_editing_enzymes = [
    "AICDA",
    "APOBEC1", "APOBEC2",
    "APOBEC3A", "APOBEC3B", "APOBEC3C", "APOBEC3D", "APOBEC3F", "APOBEC3G", "APOBEC3H",
    "APOBEC4",
    "ADAR", "ADARB1", "ADARB2",
]
gene_list = [*antiviral_zap, *rna_editing_enzymes]


In [None]:
# Function to plot the volcano plot
def plot_volcano(dea_results, genes_of_interest, logfc_threshold=.75, padj_threshold=0.1, adjust=True):
    """
    Plot a volcano plot to visualize differential gene expression analysis results.

    Each gene is drawn at (log2FoldChange, -log10(padj)). Genes with padj <= padj_threshold and
    |log2FoldChange| >= logfc_threshold are drawn as "Statistically Significant". The genes of
    interest are overlaid as blue triangles and labelled by name.

    Parameters:
    - dea_results (dataframe): Differential expression results with 'log2FoldChange', 'padj'
      and 'gene_symbol' columns.
    - genes_of_interest (list): Genes to highlight and label in the plot.
    - logfc_threshold (float, optional): Threshold on the absolute log2 fold change. Default is 0.75.
    - padj_threshold (float, optional): The significance threshold for adjusted p-value. Default is 0.1.
    - adjust (bool, optional): Whether to adjust the position of gene labels to avoid overlap. Default is True.
    """
    plt.figure(figsize=(10, 10))

    # Split the genes by the significance thresholds
    abs_lfc = dea_results['log2FoldChange'].abs()
    nonsig_genes = dea_results[(dea_results['padj'] > padj_threshold) | (abs_lfc < logfc_threshold)]
    sig_genes = dea_results[(dea_results['padj'] <= padj_threshold) & (abs_lfc >= logfc_threshold)]

    # Plot non-significant genes
    plt.scatter(nonsig_genes['log2FoldChange'], -np.log10(nonsig_genes['padj']), c='#DCDCDC', edgecolor='#E0E0E0', label='Not Statistically Significant', marker='o')

    # Plot significant genes
    plt.scatter(sig_genes['log2FoldChange'], -np.log10(sig_genes['padj']), c='#A9A9A9', edgecolor='#E0E0E0', label='Statistically Significant', marker='o')

    # Plot the genes of interest with a different marker
    genes_of_interest_points = dea_results[dea_results['gene_symbol'].isin(genes_of_interest)]
    plt.scatter(genes_of_interest_points['log2FoldChange'], -np.log10(genes_of_interest_points['padj']), c='b', edgecolor='none', label='Genes of Interest', marker='^')

    # Horizontal line at the padj threshold, vertical lines at +/- the log2 fold change threshold
    plt.axhline(y=-np.log10(padj_threshold), color='#909090', linestyle='--', linewidth=1)
    plt.axvline(x=logfc_threshold, color='#909090', linestyle='--', linewidth=1)
    plt.axvline(x=-logfc_threshold, color='#909090', linestyle='--', linewidth=1)

    # Label the genes of interest
    texts = []
    for gene in genes_of_interest:
        gene_data = dea_results[dea_results['gene_symbol'] == gene]
        if gene_data.empty:
            continue
        x = gene_data['log2FoldChange'].values[0]
        y = -np.log10(gene_data['padj'].values[0])
        # padj can be NaN (or 0, giving inf); such a gene has no point on the plot, so skip its label
        # rather than placing text at a non-finite position
        if not (np.isfinite(x) and np.isfinite(y)):
            continue
        txt = plt.text(x, y, gene, size=20, ha='center', va='center')
        txt.set_path_effects([PathEffects.withStroke(linewidth=3, foreground='w')])
        texts.append(txt)

    plt.legend()
    plt.xlabel("$log_{2}$ fold change", size=15)
    plt.ylabel("-$log_{10}$ $p_{adj}$-value", size=15)

    # Adjust text to avoid overlap (nothing to adjust when no gene of interest was labelled)
    if adjust and texts:
        adjust_text(texts, arrowprops=dict(arrowstyle="-", color='k', lw=0.5))

    plt.xticks(size=12)
    plt.yticks(size=12)

    plt.show()

In [None]:

# Minimum number of cells a sample must contribute to a cell type to form a pseudobulk
minimum_cell_population_per_sample = 30
# Minimum number of pseudobulks per condition required to perform DGE analysis
minimum_sample_per_condition = 3
# GMT file and term name of the oxidative damage response pathway used for GSEA
gene_set = "WP_OXIDATIVE_DAMAGE_RESPONSE.v2023.2.Hs.gmt"
gsea_term = "WP_OXIDATIVE_DAMAGE_RESPONSE"

# log2 fold changes of the genes of interest: one list (ordered like gene_list) per studied cell type
lfc_data = []
# Cell types that qualified for pseudobulk DGE analysis (same order as lfc_data)
cell_studied = []

for cell_type in cell_types:
    print(f"Cell Type: {cell_type}")
    # Pseudobulk sample names, their summed gene counts, and how many belong to each condition
    samples = []
    combined_count_2d = []
    no_of_control_pseudobulks = 0
    no_of_covid_pseudobulks = 0

    cell_subset = combined_anndata[combined_anndata.obs['majority_voting'] == cell_type]

    # One pseudobulk per sample (biological replicate) that has enough cells of this type
    for sample in cell_subset.obs['Sample'].unique():
        samp_cell_subset = cell_subset[cell_subset.obs['Sample'] == sample]
        n_cells = samp_cell_subset.n_obs
        if n_cells < minimum_cell_population_per_sample:
            print(f"{sample} does not satisfy minimum cell population criteria: {n_cells}")
            continue
        print(f"{sample} satisfies minimum cell population criteria: {n_cells}")
        # Sum the raw counts over cells. np.asarray(...).ravel() yields a 1-D array for dense X and for
        # sparse X (whose column sum is a 1 x n matrix). Summing in float64 and rounding keeps float
        # accumulation error from being truncated away by the integer conversion.
        summed_counts = samp_cell_subset.X.sum(axis=0, dtype=np.float64)
        total_counts = np.rint(np.asarray(summed_counts).ravel()).astype(int)
        samples.append(sample)
        combined_count_2d.append(total_counts)
        # Characters 3-5 of the sample name encode the condition ("ctr" = control, anything else = covid).
        # NOTE(review): obs["Condition"] is derived elsewhere with endswith("cov"); confirm both rules
        # agree for every sample name.
        if sample[3:6] == "ctr":
            no_of_control_pseudobulks += 1
        else:
            no_of_covid_pseudobulks += 1

    # DESeq2 needs enough replicates in both conditions
    if (
        no_of_control_pseudobulks < minimum_sample_per_condition
        or no_of_covid_pseudobulks < minimum_sample_per_condition
    ):
        print(
            f"Skipping {cell_type}: {no_of_control_pseudobulks} control and "
            f"{no_of_covid_pseudobulks} covid pseudobulks (need {minimum_sample_per_condition} each)"
        )
        continue

    # PERFORM DEA using DESeq2 on the pseudobulk counts
    counts_df = pd.DataFrame(combined_count_2d, index=samples, columns=combined_anndata.var_names)
    meta_data = pd.DataFrame([s[3:6] for s in samples], index=samples, columns=["Condition"])
    print("Counts DataFrame shape:", counts_df.shape)
    print("Metadata DataFrame shape:", meta_data.shape)

    # Keep genes with a total count of at least 5 across the pseudobulks
    genes_to_keep = counts_df.columns[counts_df.sum(axis=0) >= 5]
    counts_df = counts_df[genes_to_keep]

    dds = DeseqDataSet(
        counts=counts_df,
        metadata=meta_data,
        design_factors="Condition",
        n_cpus=4,
    )
    dds.deseq2()
    stat_results = DeseqStats(dds, contrast=('Condition', 'cov', 'ctr'))
    stat_results.summary()

    dea_results = stat_results.results_df
    dea_results["gene_symbol"] = dea_results.index

    # PLOT THE VOLCANO PLOT
    plot_volcano(dea_results, gene_list)

    # GENE SET ENRICHMENT ANALYSIS
    # Rank genes by -log10(padj) * log2FoldChange
    dea_results['Rank'] = -np.log10(dea_results["padj"]) * dea_results["log2FoldChange"]
    # padj is NaN for genes DESeq2 did not test (e.g. independent filtering), and padj == 0 gives inf.
    # Neither can be ranked, so drop them. The result is a fresh two-column frame (Gene, Rank),
    # which avoids writing into a slice of dea_results.
    ranking = (
        dea_results['Rank']
        .replace([np.inf, -np.inf], np.nan)
        .dropna()
        .rename_axis('Gene')
        .reset_index()
    )
    pre_res = gp.prerank(rnk=ranking, gene_sets=gene_set, seed=6, min_size=3)
    axs = pre_res.plot(terms=gsea_term, figsize=(10, 10))

    # RECORD THE LOG2 FOLD CHANGES OF THE GENES OF INTEREST, in gene_list order
    # (genes absent from the DESeq2 results become NaN)
    genes_of_interest_df = dea_results.reindex(gene_list)
    lfc_data.append(genes_of_interest_df["log2FoldChange"].to_list())
    cell_studied.append(cell_type)

# PLOT THE HEATMAP OF LOG2 FOLD CHANGES OF THE GENES OF INTEREST ACROSS THE CELL TYPES
# Built column by column so that "no qualifying cell type" still gives a valid (empty) frame
lfc_df = pd.DataFrame(dict(zip(cell_studied, lfc_data)), index=gene_list)
if lfc_df.empty:
    print("No cell type had enough pseudobulks per condition; skipping the heatmap.")
else:
    # Set the seaborn context before drawing so the smaller fonts apply to this figure
    sns.set_context("paper", font_scale=1)
    plt.figure(figsize=(4, 8))
    # center=0 puts the diverging colormap's neutral colour at "no change"
    ax = sns.heatmap(lfc_df, annot=True, cmap='coolwarm', center=0, linewidths=.5, fmt='.2f')
    ax.set_xticklabels(ax.get_xticklabels(), rotation=90, fontsize=12)
    ax.set_title("Log2 Fold Change (cov vs ctr) in Cell types", fontsize=20, fontweight='bold', pad=20)
    ax.set(xlabel="Cells", ylabel="Genes")
    plt.show()

We save the processed AnnData object.

In [None]:
# Persist the processed epithelial AnnData to an .h5ad file.
# Note: X holds the raw counts restored for the DGE step, not the log-normalised values.
combined_anndata.write_h5ad("processed_adata.h5ad")