# Run with the scib-pipeline-R4.0 conda environment


# AUTOMATED DATA QC

In [None]:
# Import dependencies
%matplotlib inline
import os
import scanpy as sc
import seaborn as sns
import collections
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import h5py
import anndata

# Print the date and time of this run (provenance for the saved outputs).
import datetime
e = datetime.datetime.now()
print(f"Current date and time = {e}")

# Set the working directory; all relative paths below resolve against it.
wdir = "/mnt/da8aa2c4-0136-465b-87a2-d12a59afec55/akurjan/analysis/notebooks"
os.chdir(wdir)

# Output folder structure (relative to wdir).
RESULTS_FOLDERNAME = "foetal/results/QC"
FIGURES_FOLDERNAME = "foetal/figures/QC"

# exist_ok=True makes folder creation idempotent without a separate
# exists() check (which could race with another process creating it).
os.makedirs(RESULTS_FOLDERNAME, exist_ok=True)
os.makedirs(FIGURES_FOLDERNAME, exist_ok=True)

# Scanpy saves its own figures into the figures folder.
sc.settings.figdir = FIGURES_FOLDERNAME

# Other settings
sc.settings.verbosity = 3  # verbosity: errors (0), warnings (1), info (2), hints (3)
sc.logging.print_versions()
sc.set_figure_params(dpi=150, fontsize=10, dpi_save=600)

In [None]:
def savesvg(fname: str, fig, folder: str=FIGURES_FOLDERNAME) -> None:
    """
    Write a matplotlib figure to ``<folder>/<fname>`` as an SVG (vector) image.

    Args:
        fname: output file name; the format is forced to SVG regardless of extension.
        fig: the matplotlib Figure to save.
        folder: destination directory; defaults to the value FIGURES_FOLDERNAME
            had when this function was defined.
    """
    target_path = os.path.join(folder, fname)
    fig.savefig(target_path, format='svg')

# DATA DESCRIPTION

#### Files: snRNA-seq data aligned with CellRanger `count` (v7.0.0); ambient RNA removed with CellBender (v0.2.0).

# LOADING h5ad FILES

In [None]:
def load_adatafiles(folder_path, common_text):
    """
    Load every ``.h5ad`` file in a folder whose name contains ``common_text``.

    Files are read in sorted file-name order, so the order of the returned
    dictionary (and of every table or plot built from it) is reproducible;
    ``os.listdir`` on its own returns entries in arbitrary order.

    Args:
        folder_path: folder containing the ``.h5ad`` files.
        common_text: substring that a file name must contain to be loaded.

    Returns:
        A dict mapping a sample key to the loaded AnnData. The key is the first
        two ``_``-separated fields of the file name (e.g. ``"DEV15984_Ach"``
        for ``"DEV15984_Ach_unfiltered.h5ad"``). If two files map to the same
        key, the later one (in sorted order) replaces the earlier and a warning
        is printed.
    """
    adata_dict = {}

    for file in sorted(os.listdir(folder_path)):
        if not (file.endswith('.h5ad') and common_text in file):
            continue
        key = "_".join(file.split('_')[:2])
        if key in adata_dict:
            print(f"Warning: {file} maps to existing key {key!r}; it replaces the earlier file.")
        adata_dict[key] = sc.read_h5ad(os.path.join(folder_path, file))

    return adata_dict

In [None]:
# Load every .h5ad file in the QC results folder whose name contains '_unfiltered'.
adata_dict = load_adatafiles(folder_path=RESULTS_FOLDERNAME, common_text='_unfiltered')
adata_dict

In [None]:
# Per-sample summary table of cell/gene counts. Later cells add one pair of
# columns per filtering step. Built in one go because DataFrame.append was
# removed in pandas 2.0.
df = pd.DataFrame(
    {
        'Unfiltered Cells': [adata.n_obs for adata in adata_dict.values()],
        'Unfiltered Genes': [adata.n_vars for adata in adata_dict.values()],
    },
    index=pd.Index(list(adata_dict.keys()), name='Sample'),
)
df

In [None]:
def convert_uint_to_int(adata_dict):
    """
    Cast unsigned-integer matrices in a dict of AnnData objects to signed/float types.

    For every AnnData value (other values are skipped), modified in place:
      * ``X``: uint32 -> float32, uint64 -> float64;
      * each entry of ``layers``: uint32 -> int32, uint64 -> int64.
    A message is printed for every conversion.

    Args:
        adata_dict: mapping of sample key -> AnnData.
    """
    for value in adata_dict.values():
        if not isinstance(value, anndata.AnnData):
            continue

        # AnnData allows X to be None; there is nothing to convert then.
        if value.X is not None:
            if value.X.dtype == 'uint32':
                value.X = value.X.astype('float32')
                print("Converted X from uint32 to float32.")
            elif value.X.dtype == 'uint64':
                value.X = value.X.astype('float64')
                print("Converted X from uint64 to float64.")

        # Snapshot the keys so the mapping is not iterated while being reassigned.
        for layer_key in list(value.layers.keys()):
            layer_val = value.layers[layer_key]
            if layer_val.dtype == 'uint32':
                value.layers[layer_key] = layer_val.astype('int32')
                print(f"Converted layer {layer_key} from uint32 to int32.")
            elif layer_val.dtype == 'uint64':
                # numpy has no 36-bit integer types; 64-bit is the matching width.
                value.layers[layer_key] = layer_val.astype('int64')
                print(f"Converted layer {layer_key} from uint64 to int64.")

In [None]:
convert_uint_to_int(adata_dict=adata_dict)  # in-place dtype casts for every sample

In [None]:
# Basic per-sample filters. Each filtered result is materialised with .copy()
# so later cells modify real AnnData objects rather than views of the parent.
for key, adata in adata_dict.items():
    print(f"--- {key} ---")
    print(f"Total number of genes: {adata.n_vars}")
    # Keep genes detected in at least 20 cells (this also drops all-zero genes).
    sc.pp.filter_genes(adata, min_cells=20)
    print(f"Number of genes after gene filter: {adata.n_vars}")

    print(f"Total number of cells: {adata.n_obs}")
    # Keep only cells with more than 200 total UMI counts.
    cell_totals = np.asarray(adata.X.sum(axis=1)).reshape(-1)
    adata = adata[cell_totals > 200]
    print(f"Number of cells after 200 UMI minimum filter: {adata.n_obs}")

    # Removing cells can leave genes with no counts at all; drop them.
    gene_totals = np.asarray(adata.X.sum(axis=0)).reshape(-1)
    adata = adata[:, gene_totals > 0].copy()
    print(f"Number of genes after 0 umi count filter: {adata.n_vars}")
    print(f"Number of cells after 0 umi count filter: {adata.n_obs}")

    # Replacing the value of an existing key is safe while iterating .items().
    adata_dict[key] = adata


In [None]:
# Record each sample's size after the basic filters as two new summary columns.
for sample, adata in adata_dict.items():
    n_cells, n_genes = adata.n_obs, adata.n_vars
    df.loc[sample, 'Basic Filtered Cells'] = n_cells
    df.loc[sample, 'Basic Filtered Genes'] = n_genes

df

# MAD-BASED THRESHOLD QC AND FILTERING
Approach based on: https://www.sc-best-practices.org/preprocessing_visualization/quality_control.html

In [None]:
# Flag gene families by gene symbol, then let scanpy compute per-cell and
# per-gene QC metrics (including percentages for each flagged family).
for key, adata in adata_dict.items():
    adata.var["mt"] = adata.var['gene_name'].str.startswith("MT-")  # mitochondrial
    adata.var["ribo"] = adata.var['gene_name'].str.startswith(("RPS", "RPL"))  # ribosomal
    adata.var["hb"] = adata.var['gene_name'].str.contains("^HB[^(P)]")  # hemoglobin

    sc.pp.calculate_qc_metrics(
        adata,
        qc_vars=["mt", "ribo", "hb"],
        inplace=True,
        percent_top=[20],
        log1p=True,
    )

##### At the gene level (stored in var):
- **total_counts:** sum of counts for a gene across all cells
- **mean_counts:** mean expression of a gene over all cells
- **n_cells_by_counts:** number of cells with non-zero counts for a gene
- **pct_dropout_by_counts:** percentage of cells in which the gene is not detected

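##### At the cell level (stored in obs):
- **total_counts:** total number of counts in a cell
- **total_counts_mt:** total number of counts from mitochondrial genes in a cell (analogous `total_counts_ribo` and `total_counts_hb` are also added)
- **n_genes_by_counts:** number of genes with non-zero counts in a cell
- **pct_counts_mt:** percentage of a cell's total counts that come from mitochondrial genes
- **log1p_total_counts**, **log1p_n_genes_by_counts**, **pct_counts_in_top_20_genes:** also added (because of `log1p=True` and `percent_top=[20]`) and used by the MAD-based outlier detection below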

In [None]:
def QC_plots(adata_dict, filename: str):
    """
    Plot per-sample QC distributions (one row per sample) and save them as SVG.

    Columns: histogram of ``total_counts``, histogram of ``n_genes_by_counts``,
    and a violin of ``pct_counts_mt``. Each row is titled with the first value of
    the sample's ``obs['sampletype']``.

    Args:
        adata_dict: mapping of sample key -> AnnData carrying those ``obs`` columns.
        filename: SVG file name, written into the figures folder via ``savesvg``.
    """
    n_plots = len(adata_dict)
    if n_plots == 0:
        print("No samples to plot.")
        return
    # squeeze=False keeps `axes` 2-D (rows x 3) even when there is a single sample.
    figure, axes = plt.subplots(nrows=n_plots, ncols=3, figsize=(16, 6 * n_plots), squeeze=False)

    for i, adata in enumerate(adata_dict.values()):
        # Positional lookup; plain [0] on a string-indexed Series is deprecated.
        sample_label = str(adata.obs["sampletype"].iloc[0])

        # Column 1: total counts per cell.
        sns.histplot(adata.obs["total_counts"], bins=100, kde=False, ax=axes[i, 0])
        axes[i, 0].set_title("Total Counts for " + sample_label)
        axes[i, 0].set_xlabel('Total Counts')
        axes[i, 0].set_ylabel('N cells')

        # Column 2: genes detected per cell.
        sns.histplot(adata.obs["n_genes_by_counts"], bins=100, kde=False, ax=axes[i, 1])
        axes[i, 1].set_title("Genes by Counts for " + sample_label)
        axes[i, 1].set_xlabel('N genes')
        axes[i, 1].set_ylabel('N cells')

        # Column 3: percent mitochondrial counts. show=False keeps scanpy from
        # calling plt.show() itself (its default follows sc.settings.autoshow).
        sc.pl.violin(adata, keys="pct_counts_mt", ax=axes[i, 2],
                     ylabel='Percent Mitochondrial Counts', show=False)
        axes[i, 2].set_title("% Mitochondrial Counts for " + sample_label)

    # Save before showing: some backends (e.g. inline) close the figure on show().
    savesvg(filename, figure)
    plt.show()

In [None]:
QC_plots(adata_dict, filename='QC_plots1_minimal_filters.svg')  # after basic filters

In [None]:
def QC_plots2(adata_dict, filename: str):
    """
    Draw a grid of QC violins (one row per sample, one column per metric) and save it as SVG.

    Metrics: total_counts, n_genes_by_counts, pct_counts_mt, pct_counts_hb,
    pct_counts_ribo. The first panel of each row is labelled with the first
    value of the sample's ``obs['sampletype']``.

    Args:
        adata_dict: mapping of sample key -> AnnData carrying those ``obs`` columns.
        filename: SVG file name, written into the figures folder via ``savesvg``.
    """
    features = ['total_counts', 'n_genes_by_counts', 'pct_counts_mt', 'pct_counts_hb', 'pct_counts_ribo']
    n_samples = len(adata_dict)
    if n_samples == 0:
        print("No samples to plot.")
        return
    # squeeze=False keeps `axes` 2-D even when there is a single sample.
    fig, axes = plt.subplots(n_samples, len(features), figsize=(15, 3 * n_samples), squeeze=False)

    for i, adata in enumerate(adata_dict.values()):
        for j, feature in enumerate(features):
            sc.pl.violin(adata, feature, ax=axes[i, j], show=False, jitter=0.4)
        # Positional lookup; plain [0] on a string-indexed Series is deprecated.
        row_title = f"{adata.obs['sampletype'].iloc[0]}"
        axes[i, 0].set_ylabel(row_title, fontsize=10)

    fig.tight_layout()
    savesvg(filename, fig)

In [None]:
QC_plots2(adata_dict, filename='QC_plots2_minimal_filtering.svg')  # after basic filters

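Automatic thresholding via the MAD (median absolute deviation). A cell is marked as an outlier if its log1p total counts, log1p number of detected genes, or percentage of counts in its top 20 genes lies more than 5 MADs from the median (a relatively permissive filtering strategy). `pct_counts_mt` is thresholded more strictly at 3 MADs, and cells with more than 10 % mitochondrial counts are additionally flagged. Cells flagged by either check are removed.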

In [None]:
def is_outlier(adata, metric: str, nmads: int):
    """
    Flag cells whose ``obs[metric]`` lies more than ``nmads`` MADs from the median.

    MAD is the median absolute deviation, median(|x - median(x)|), unscaled
    (same as ``scipy.stats.median_abs_deviation`` with its default scale).
    ``pandas.Series.mad`` is NOT used: it was the *mean* absolute deviation
    around the mean and was removed in pandas 2.0.

    Args:
        adata: object with an ``obs`` DataFrame (e.g. AnnData).
        metric: name of the numeric ``obs`` column to test.
        nmads: number of MADs on either side of the median that count as normal.

    Returns:
        Boolean Series aligned with ``adata.obs``; True marks an outlier. If the
        MAD is 0, every value different from the median is flagged.
    """
    M = adata.obs[metric]
    median = np.median(M)
    mad = np.median(np.abs(M - median))
    outlier = (M < median - nmads * mad) | (M > median + nmads * mad)
    return outlier

def outlier_check(adata_dict):
    """
    Add MAD-based outlier flags to each AnnData's ``obs`` in place and print their counts.

    * ``obs['outlier']``: more than 5 MADs from the median in log1p_total_counts,
      log1p_n_genes_by_counts or pct_counts_in_top_20_genes.
    * ``obs['mt_outlier']``: more than 3 MADs from the median in pct_counts_mt,
      or pct_counts_mt above 10.

    Args:
        adata_dict: mapping of sample key -> AnnData with those QC columns and
            a ``sampletype`` column in ``obs``.
    """
    for adata in adata_dict.values():
        sample_label = adata.obs['sampletype'].iloc[0]

        adata.obs["outlier"] = (
            is_outlier(adata, "log1p_total_counts", 5)
            | is_outlier(adata, "log1p_n_genes_by_counts", 5)
            | is_outlier(adata, "pct_counts_in_top_20_genes", 5)
        )
        print(f"5 MAD Outliers in {sample_label}:\n{adata.obs.outlier.value_counts()}")

        adata.obs["mt_outlier"] = is_outlier(adata, "pct_counts_mt", 3) | (adata.obs["pct_counts_mt"] > 10)
        print(f"Mitochondrial outliers (3 MAD or >10%) in {sample_label}:\n{adata.obs.mt_outlier.value_counts()}")
        
def outlier_removal(adata_dict):
    """
    Flag outliers (via ``outlier_check``) and return copies without the flagged cells.

    The input AnnData objects gain ``obs['outlier']`` and ``obs['mt_outlier']``
    columns in place; the returned dict holds new objects containing only the
    cells that carry neither flag.

    Args:
        adata_dict: mapping of sample key -> AnnData.

    Returns:
        A new dict with the same keys mapping to the filtered AnnData copies.
    """
    outlier_check(adata_dict)
    filtered_adatafiles = {}
    for key, adata in adata_dict.items():
        print(f"Total number of cells: {adata.n_obs}")
        keep = ~(adata.obs["outlier"] | adata.obs["mt_outlier"])
        # One copy is enough to detach the result from the unfiltered object.
        filtered = adata[keep].copy()
        print(f"Number of cells after filtering of low quality cells: {filtered.n_obs}")
        filtered_adatafiles[key] = filtered
    return filtered_adatafiles
        

In [None]:
# MAD-based filtering; returns new (copied) objects keyed like adata_dict.
filtered_adatafiles = outlier_removal(adata_dict=adata_dict)
filtered_adatafiles

In [None]:
# Record each sample's size after MAD-based filtering as two new summary columns.
for sample, adata in filtered_adatafiles.items():
    n_cells, n_genes = adata.n_obs, adata.n_vars
    df.loc[sample, 'MAD Filtered Cells'] = n_cells
    df.loc[sample, 'MAD Filtered Genes'] = n_genes

df

# FILE SAVING

In [None]:
def savefile(filtered_adatafiles):
    """
    Write each AnnData to RESULTS_FOLDERNAME as ``<sampletype>_filtered.h5ad``.

    The file name comes from the first value of ``obs['sampletype']``, not from
    the dict key, so two entries sharing a sampletype would overwrite each other.

    Args:
        filtered_adatafiles: mapping of sample key -> AnnData.
    """
    for adata in filtered_adatafiles.values():
        # Positional lookup; plain [0] on a string-indexed Series is deprecated.
        sample_name = adata.obs['sampletype'].iloc[0]
        filename = f"{sample_name}_filtered.h5ad"
        filepath = os.path.join(RESULTS_FOLDERNAME, filename)
        adata.write(filepath)
        print(f"Saved file {filename} to {RESULTS_FOLDERNAME}.")

In [None]:
savefile(filtered_adatafiles=filtered_adatafiles)  # one <sampletype>_filtered.h5ad per sample

# DOUBLET REMOVAL

In [None]:
import anndata2ri
import logging
import rpy2.rinterface_lib.callbacks as rcb
import rpy2.robjects as robjects

rcb.logger.setLevel(logging.ERROR)
robjects.pandas2ri.activate()
anndata2ri.activate()

#Loading the rpy2 extension enables cell/line magic to be used
%load_ext rpy2.ipython

In [None]:
# Rebuild the sample dictionary as an OrderedDict, keeping the current sample order.
filtered_adatafiles = collections.OrderedDict(filtered_adatafiles.items())

In [None]:
def doublet_removal(new_dict):
    """
    Run scDblFinder (through rpy2) on every sample and keep only predicted singlets.

    For each entry the count matrix is transposed (SingleCellExperiment expects
    features x cells), scDblFinder is run with ``set.seed(123)``, and its class
    call is stored in ``obs['scDblFinder_class']`` as plain strings. The entry
    is then replaced by a copy containing only cells labelled ``"singlet"``.

    Args:
        new_dict: mutable mapping of sample key -> AnnData; modified in place.
    """
    # Load the R packages once instead of on every sample.
    robjects.r('library(SingleCellExperiment)')
    robjects.r('library("scDblFinder")')

    for key, adata in new_dict.items():
        robjects.globalenv['temp_counts'] = adata.X.T
        # Re-seed for every sample so each result does not depend on processing order.
        robjects.r('set.seed(123)')
        robjects.r('sce <- SingleCellExperiment(list(counts=temp_counts))')
        robjects.r('print(sce)')
        robjects.r('sce <- scDblFinder(sce)')

        # Fetch the factor as character labels. Comparing converted factor codes
        # with an integer depends on the level order and on which rpy2 converter
        # is active (pandas2ri can return a Categorical of labels instead).
        droplet_class = np.array(
            [str(label) for label in robjects.r('as.character(sce$scDblFinder.class)')]
        )

        adata.obs["scDblFinder_class"] = droplet_class
        print(f'Singlets and Doublets in {key}:')
        print(adata.obs.scDblFinder_class.value_counts())

        print(f"Total cells in {key}: {adata.n_obs}")

        # Replacing the value of an existing key is safe while iterating .items().
        new_dict[key] = adata[adata.obs["scDblFinder_class"] == "singlet"].copy()
        print(f"Number of cells after filtering of doublets:{new_dict[key].n_obs}")

In [None]:
doublet_removal(new_dict=filtered_adatafiles)  # filters filtered_adatafiles in place

In [None]:
# Record each sample's size after doublet removal as two new summary columns.
for sample, adata in filtered_adatafiles.items():
    n_cells, n_genes = adata.n_obs, adata.n_vars
    df.loc[sample, 'Doublet Removed Cells'] = n_cells
    df.loc[sample, 'Doublet Removed Genes'] = n_genes

df

In [None]:
# QC overview after MAD filtering and doublet removal.
QC_plots(filtered_adatafiles, filename='QC_plots1_filtered.svg')
QC_plots2(filtered_adatafiles, filename='QC_plots2_filtered.svg')

# ADDITIONAL FILTERING

In [None]:
# Per-sample minimum n_genes_by_counts, chosen by hand from the QC plots.
# Insertion order is kept as-is because it determines the order of the
# manually filtered samples.
min_ngenes_threshold = dict(
    OMB1556_Ach=400,
    DEV15984_Ach=300,
    DEV16135DEV16171_Ach=400,
    OMB1250_Quad=200,
    OMB0785_Ach=400,
    DEV15984_Quad=500,
    DEV16127_Ach=500,
    DEV16136_Ach=200,
    DEV16134_Ach=400,
    DEV15985_Quad=400,
    DEV16134_Quad=300,
    DEV16569_Ach=300,
    DEV16136_Quad=500,
    OMB1266_Quad=300,
    DEV16127_Quad=300,
    DEV16135DEV16171_Quad=400,
    DEV16569_Quad=400,
    DEV15983_Ach=400,
    DEV15985_Ach=400,
)


def mfilter_adatafiles_by_counts(filtered_adatafiles, min_ngenes_threshold):
    """
    Drop cells below a per-sample minimum number of detected genes.

    Args:
        filtered_adatafiles: mapping of sample key -> AnnData with
            ``obs['n_genes_by_counts']`` and ``obs['sampletype']``.
        min_ngenes_threshold: mapping of sample key -> minimum n_genes_by_counts
            (inclusive) a cell needs in order to be kept.

    Returns:
        A new dict, ordered like ``min_ngenes_threshold``, with filtered AnnData
        copies. Samples that have no threshold are NOT included (a message is
        printed for each), and thresholds for unknown samples are reported.
    """
    mfiltered_adatafiles = {}
    for adata_name, min_ngenes in min_ngenes_threshold.items():
        if adata_name not in filtered_adatafiles:
            print(f"Anndata object {adata_name} not found in filtered_adatafiles")
            continue
        adata = filtered_adatafiles[adata_name]
        # .copy() so the result does not keep the unfiltered parent alive as a view.
        adata_filtered = adata[adata.obs['n_genes_by_counts'] >= min_ngenes, :].copy()
        print(f"Filtered out {adata.n_obs - adata_filtered.n_obs} cells in {adata.obs['sampletype'].iloc[0]}")
        mfiltered_adatafiles[adata_name] = adata_filtered

    # Make silently dropped samples visible.
    for adata_name in filtered_adatafiles:
        if adata_name not in min_ngenes_threshold:
            print(f"No n_genes_by_counts threshold for {adata_name}; it is excluded from the output")

    return mfiltered_adatafiles

In [None]:
# Apply the hand-picked per-sample n_genes_by_counts minimums.
mfiltered_adatafiles = mfilter_adatafiles_by_counts(
    filtered_adatafiles=filtered_adatafiles,
    min_ngenes_threshold=min_ngenes_threshold,
)
mfiltered_adatafiles

In [None]:
# Record each sample's size after the manual thresholds as two new summary columns.
for sample, adata in mfiltered_adatafiles.items():
    n_cells, n_genes = adata.n_obs, adata.n_vars
    df.loc[sample, 'Manually Filtered Cells'] = n_cells
    df.loc[sample, 'Manually Filtered Genes'] = n_genes

df

In [None]:
# QC overview after all filtering steps.
QC_plots(mfiltered_adatafiles, filename='QC_plots_fully_filtered.svg')
QC_plots2(mfiltered_adatafiles, filename='QC_plots2_fully_filtered.svg')

# CONCATENATION

In [None]:
# """Version 1 of concatenation, seems inferior to the concat() offered by anndata:"""

# def concatenate_adatafiles(filtered_adatafiles):
#     """
#     Concatenates all anndata objects within a dictionary.

#     Args:
#         filtered_adatafiles: A dictionary object containing the anndata objects to be concatenated.

#     Returns:
#         A concatenated anndata object.
#     """

#     # Separate the first object from the dictionary
#     adata_list = list(filtered_adatafiles.values())
#     adata = adata_list[0]
    
#     # Concatenate the remaining objects in the dictionary
#     adata_concat = adata.concatenate(adata_list, join='outer', index_unique=None)
#     del adata, adata_list
    
#     return adata_concat

In [None]:
def concat_filtered_adatafiles(filtered_adatafiles, index_unique=None):
    """
    Concatenate a dict of AnnData objects along cells, keeping the union of genes.

    Args:
        filtered_adatafiles: mapping of sample key -> AnnData.
        index_unique: forwarded to ``anndata.concat``. ``None`` (default) keeps
            the original cell barcodes. A separator string such as ``"-"``
            appends ``<sep><sample key>`` to every barcode, which makes them
            unique across samples.

    Returns:
        The concatenated AnnData (outer join on genes).
    """
    # Passing the mapping itself lets anndata use its keys as `keys` (they are
    # only used when index_unique is set).
    adata_concat = anndata.concat(
        filtered_adatafiles,
        join='outer',
        index_unique=index_unique,
    )

    if not adata_concat.obs_names.is_unique:
        print("Warning: cell barcodes are not unique across samples; "
              "pass index_unique='-' to make them unique.")

    return adata_concat

In [None]:
# Merge all fully filtered samples into one AnnData.
adata_concat = concat_filtered_adatafiles(filtered_adatafiles=mfiltered_adatafiles)
adata_concat

In [None]:
# Make 'age' categorical (create_violin_plots reads obs['age'].cat) and show
# how many cells each sample contributes.
adata_concat.obs['age'] = adata_concat.obs['age'].astype(pd.CategoricalDtype())
adata_concat.obs['sampletype'].value_counts()

In [None]:
def create_violin_plots(adata_concat, filename: str):
    """
    Plot one violin panel per QC metric across samples and save it as SVG.

    Each row shows one metric grouped by ``obs['sampletype']``. The x tick labels
    are coloured by the sample's ``obs['age']`` category, and a matching legend
    is added. Only the bottom row shows (rotated) sample labels.

    Args:
        adata_concat: AnnData whose ``obs`` has ``sampletype``, a categorical
            ``age`` column (its ``.cat`` accessor is used) and the metric columns.
        filename: SVG file name, written into the figures folder via ``savesvg``.
    """
    # QC metrics to plot, one row each ('pct_counts_hb' and 'pct_counts_ribo'
    # can be added here).
    parameters = ['total_counts', 'n_genes_by_counts', 'pct_counts_mt']

    # squeeze=False + ravel() keeps `axs` an indexable 1-D array even for one metric.
    fig, axs = plt.subplots(nrows=len(parameters), figsize=(12, len(parameters) * 4), squeeze=False)
    axs = axs.ravel()

    # One colour per age category.
    age_categories = adata_concat.obs['age'].cat.categories
    age_colors = sns.color_palette('dark', len(age_categories))
    age_palette = dict(zip(age_categories, age_colors))

    # Sample name -> age category (if a sample had several ages, the last cell wins).
    sample_to_age = dict(zip(adata_concat.obs['sampletype'].values, adata_concat.obs['age']))

    # One legend entry per age category, in order of first appearance.
    legend_handles = []
    legend_labels = []

    for i, param in enumerate(parameters):
        ax = sc.pl.violin(adata_concat, keys=[param], groupby='sampletype', ax=axs[i], show=False, color='white')
        axs[i].set_xlabel('')
        axs[i].set_title(param, fontsize=12)
        if i == len(parameters) - 1:
            axs[i].tick_params(axis='x', labelrotation=90, labelsize=8)
        else:
            axs[i].tick_params(axis='x', which='both', labelbottom=False, labeltop=False)

        # Colour each sample's tick label by its age category.
        for label in ax.get_xticklabels():
            age_category = sample_to_age[label.get_text()]
            color = age_palette[age_category]
            label.set_color(color)
            if age_category not in legend_labels:
                legend_handles.append(plt.Rectangle((0, 0), 1, 1, color=color))
                legend_labels.append(age_category)

    # Vertical legend for the age categories.
    fig.legend(legend_handles, legend_labels, loc='center left')
    fig.subplots_adjust(hspace=0.2)

    # Save before showing: some backends (e.g. inline) close the figure on show().
    savesvg(filename, fig)
    plt.show()

In [None]:
create_violin_plots(adata_concat, filename='concat_violins_filtered.svg')  # per-sample QC violins

In [None]:
# Peek at the top-left corner of the matrix (first 5 cells x first 5 genes);
# the previous [1:5, 1:5] slice skipped the first cell and gene.
print(adata_concat.X[:5, :5])
# Keep a copy of the current X as the 'counts' layer.
adata_concat.layers['counts'] = adata_concat.X.copy()

In [None]:
concat_path = os.path.join(RESULTS_FOLDERNAME, 'adata_concat_filtered.h5ad')
adata_concat.write(concat_path)  # merged, fully filtered dataset

In [None]:
# Cast the summary counts to integers. The nullable "Int64" dtype keeps samples
# that are missing from a later step (NaN) instead of failing the cast, as
# astype(int) would.
df = df.astype("Int64")
df

In [None]:
table_path = os.path.join(RESULTS_FOLDERNAME, 'filtering_cells_genes_table.csv')
df.to_csv(table_path, index=True)  # per-sample cell/gene counts after each step