 Utils

In [7]:
# General
import scipy as sci
import numpy as np
import pandas as pd
import logging
import time
import pickle
from itertools import chain
import h5py
import scipy.sparse as sparse
import anndata as ad
import scipy.stats as stats
import gc

# Plotting
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import rcParams
from matplotlib import cm
from matplotlib import colors
from matplotlib.pyplot import rc_context
import seaborn as sb
#from plotnine import *
from adjustText import adjust_text
#import pegasus as pg

# Analysis
import scanpy as sc

#import snapatac2 as snap
import pysam



#



In [None]:
def setup_R(R_path):
    """Point rpy2 at an R installation and switch on the R <-> Python converters.

    Sets the ``R_HOME`` environment variable to ``R_path``, raises the rpy2
    callback logger level to ERROR, activates the ``pandas2ri`` and
    ``anndata2ri`` converters and loads the ``rpy2.ipython`` extension.

    NOTE(review): ``rpy2``, ``pandas2ri`` and ``anndata2ri`` are not imported in
    the visible part of this file - presumably they are imported in another
    cell; confirm before calling.
    NOTE(review): the ``%load_ext`` line is IPython magic syntax, so this
    definition is only valid inside an IPython/Jupyter session.
    """
    import os
    os.environ['R_HOME'] = R_path
    ## R settings

    ### Ignore R warning messages
    #### Note: this can be commented out to get more verbose R output
    rpy2.rinterface_lib.callbacks.logger.setLevel(logging.ERROR)

    ### Automatically convert rpy2 outputs to pandas dataframes
    pandas2ri.activate()
    anndata2ri.activate()
    %load_ext rpy2.ipython

## read files

In [None]:
def read_h5_to_adata(path):
    """Read the 10x HDF5 file at ``path`` with ``sc.read_10x_h5`` and return the result."""
    return sc.read_10x_h5(path)

In [None]:
def read_h5_to_mudata(path):
    """Read the 10x HDF5 file at ``path`` with ``muon.read_10x_h5`` and return the result.

    ``muon`` is imported locally because the file's top-level imports do not
    bind the name ``mu``; without this the call raises ``NameError`` unless
    another cell happened to import it first.
    """
    import muon as mu

    return mu.read_10x_h5(path)

In [None]:
def read_h5ad_to_adata(path):
    """Read the file at ``path`` with ``sc.read`` and return the result."""
    return sc.read(path)

In [None]:
def read_h5mu_to_mudata(path):
    """Read the file at ``path`` with ``muon.read`` and return the result.

    ``muon`` is imported locally because the file's top-level imports do not
    bind the name ``mu``; without this the call raises ``NameError`` unless
    another cell happened to import it first.
    """
    import muon as mu

    return mu.read(path)

## add metadata

In [None]:
def read_excel_metadata(path, ix_col=None):
    """Load an Excel sheet into a DataFrame.

    ``ix_col`` is forwarded to ``pd.read_excel`` as ``index_col``; the
    resulting DataFrame is returned unchanged.
    """
    return pd.read_excel(path, index_col=ix_col)

In [None]:
def add_metadata(metadata_df, adata):
    """Broadcast one sample's metadata row onto every observation of ``adata``.

    The sample id is read from the first entry of ``adata.obs['sample']``. The
    first row of ``metadata_df`` whose ``'Link_id'`` equals that id is looked
    up and each of its columns (``'Link_id'`` included) is written to
    ``adata.obs`` as a constant column.

    Parameters
    ----------
    metadata_df : pandas.DataFrame
        Table with a ``'Link_id'`` column.
    adata : object exposing an ``obs`` DataFrame with a ``'sample'`` column.

    Returns
    -------
    The same ``adata`` object (modified in place).

    Raises
    ------
    IndexError
        If no row of ``metadata_df`` matches the sample id.
    """
    # .iloc[0] is explicitly positional; obs['sample'][0] relied on the
    # deprecated positional fallback of Series.__getitem__ for non-integer indexes.
    sample_id = adata.obs['sample'].iloc[0]

    matches = metadata_df.loc[metadata_df['Link_id'].values == sample_id]
    if matches.empty:
        raise IndexError(f"No metadata row with Link_id == {sample_id!r}")

    for col, value in matches.iloc[0].items():
        adata.obs[col] = value

    return adata
    

## ambient detection

In [None]:
def set_ambient_threshold(adata, threshold=0.0005, lower_limit=0.0001, upper_limit=0.002, bins=60, kde=True):
    """Plot the distribution of ``adata.var['ambient_genes_values']`` around a threshold.

    Only values strictly between ``lower_limit`` and ``upper_limit`` are drawn.
    A vertical line marks ``threshold`` and the title reports how many genes
    (over the whole column, not just the plotted window) exceed it.

    The function only plots; it does not write to ``adata``. Per the original
    author's note, the ``is_ambient`` flag is already set in R.

    Parameters
    ----------
    adata : object with ``obs['sample']`` and ``var['ambient_genes_values']``.
    threshold : x position of the marker line / cut-off used for the gene count.
    lower_limit, upper_limit : exclusive bounds of the plotted value window.
    bins : number of histogram bins.
    kde : overlay a kernel density estimate.
    """
    # Explicit positional access; Series[0] on a label index is deprecated.
    sample = adata.obs['sample'].iloc[0]

    values = adata.var['ambient_genes_values']
    in_window = values[(values > lower_limit) & (values < upper_limit)]
    n_above = len(values[values > threshold])

    with rc_context({'figure.figsize': (8, 3)}):
        # sb.distplot is deprecated. histplot reproduces it: distplot
        # normalised the histogram whenever a KDE was drawn and evaluated the
        # KDE 3 bandwidths past the data (cut=3).
        sb.histplot(
            in_window,
            kde=kde,
            bins=bins,
            stat='density' if kde else 'count',
            kde_kws={'cut': 3},
        )
        plt.axvline(threshold, 0, 1)
        plt.title(label=f'Ambient Genes Threshold {sample} (' + str(n_above) + ' Genes)', fontweight='bold')
        plt.show()

In [None]:
def _plot_ranked_droplets(rna_obs, sample, overlay_values, overlay_color, overlay_label, overlay_ylabel):
    """Draw the log-log ranked-droplet plot with one overlay series on a twin y-axis, then show and close it."""
    fig, ax1 = plt.subplots()
    ax1.scatter(x=rna_obs['n_counts_rank'], y=rna_obs['n_counts'], s=1, alpha=0.2, c='black', label='Total UMI Counts')
    ax1.scatter(x=rna_obs['n_counts_rank'], y=rna_obs['n_genes'], s=1, alpha=0.2, c='tab:green', label='Gene Counts')
    ax1.set(xscale='log', yscale='log')
    ax1.set_ylabel('Total UMI/Gene Counts')
    ax1.set_xlabel('Ranked Droplets')
    ax1.set_title(sample)

    ax2 = ax1.twinx()
    ax2.scatter(x=rna_obs['n_counts_rank'], y=overlay_values, s=1, alpha=0.2, c=overlay_color, label=overlay_label)
    ax2.set_ylabel(overlay_ylabel)
    ax2.set_title(sample)

    fig.legend(loc='center left', fontsize='xx-small', bbox_to_anchor=(0.2, 0.35))

    plt.show()
    plt.close()


def prefilter_barcodes_mdata(mdata, barcodes=None, plot=True):
    """Show pre-filter diagnostic plots for the RNA modality of ``mdata``.

    With ``plot=True`` four figures are produced from ``mdata.mod['rna'].obs``:
    a joint histogram of ``log_counts`` vs ``log_genes``, two ranked-droplet
    plots (overlaid with ``mt_frac`` in percent and with ``log_cell_probs``
    respectively) and a histogram of the non-NaN ``log_cell_probs``.

    Parameters
    ----------
    mdata : object with ``obs['sample']`` and ``mod['rna'].obs`` holding the
        columns named above plus ``n_counts``, ``n_genes`` and ``n_counts_rank``.
    barcodes : accepted for interface compatibility; not used by this function.
    plot : when False nothing is drawn.

    Returns
    -------
    None. ``mdata`` is not modified.
    """
    # Explicit positional access; Series[0] on a label index is deprecated.
    sample = mdata.obs['sample'].iloc[0]
    if not plot:
        return

    rna_obs = mdata.mod['rna'].obs

    sb_plot = sb.jointplot(
        data=rna_obs,
        x="log_counts",
        y="log_genes",
        kind="hist", bins=100, cmap="rocket_r", color="#f69c73", space=0
    )
    sb_plot.fig.suptitle(f'{sample}')
    sb_plot.fig.tight_layout()
    # leave head-room so the suptitle does not overlap the marginal axis
    sb_plot.fig.subplots_adjust(top=0.95)

    _plot_ranked_droplets(
        rna_obs, sample,
        overlay_values=rna_obs['mt_frac'] * 100,
        overlay_color='tab:red',
        overlay_label='% Mito. Counts',
        overlay_ylabel='%',
    )

    cell_probs_key = 'log_cell_probs'
    _plot_ranked_droplets(
        rna_obs, sample,
        overlay_values=rna_obs[cell_probs_key],
        overlay_color='tab:blue',
        overlay_label='Log Cell Probabilities',
        overlay_ylabel='Cell Probabilities',
    )

    # NaN entries are dropped before the histogram is drawn
    sb.histplot(rna_obs[cell_probs_key].dropna(), kde=True, bins=60)
    plt.title(label=f'Log Cell Probabilities of {sample}', fontweight='bold')

    plt.show()
    plt.close()

## plot settings

#### plot titles
sb.jointplot:

sb.jointplot()
plt.suptitle()

umap:
fig = sc.pl.umap(return_fig=True)
ax = fig.axes[0]
ax.legend_.set_title()

histplot:
sb.histplot()
plt.title()

scatterplot:
p = sc.pl.scatter()
plt.suptitle()

### plot settings

In [None]:
#%matplotlib inline
def load_RdOrYl_cmap_settings(fig_height=6, fig_width =6, dpi = 150, save_dpi =300, transparent = True, vector_fr=False,fr_on=True):
    """Apply the notebook's plotting defaults and return a grey + YlOrRd colormap.

    Parameters
    ----------
    fig_height, fig_width : default figure size in inches.
    dpi, save_dpi : forwarded to ``sc.set_figure_params`` as ``dpi`` / ``dpi_save``.
    transparent : forwarded to ``sc.set_figure_params``.
    vector_fr : forwarded as ``vector_friendly``.
    fr_on : forwarded as ``frameon``.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        Built from one ``Greys_r`` entry followed by 150 samples of ``YlOrRd``.

    Side effects: mutates the global matplotlib ``rcParams`` (figure size,
    grid, ticks, font, PDF font type) and scanpy's figure settings.
    """
    from matplotlib import colors

    sc.set_figure_params(scanpy=True, frameon=fr_on, vector_friendly=vector_fr, transparent=transparent, dpi=dpi, dpi_save=save_dpi)

    # Set the size AFTER set_figure_params: with scanpy=True that call applies
    # scanpy's own rcParams defaults, which would discard a size set earlier.
    # matplotlib's figure.figsize is ordered (width, height).
    rcParams['figure.figsize'] = (fig_width, fig_height)

    ## Grid & Ticks
    rcParams['grid.alpha'] = 0
    rcParams['xtick.bottom'] = True
    rcParams['ytick.left'] = True

    plt.rcParams.update({
        "text.usetex": False,
        "font.family": "serif",
        "font.serif": "NewComputerModern10",  # Computer Modern Roman fontsize 10
    })

    ## Embed font (TrueType) in PDF output
    plt.rc('pdf', fonttype=42)

    ## Define new default settings
    # NOTE(review): this only rebinds the name on the pyplot module to the live
    # rcParams object - confirm it has the intended "new defaults" effect.
    plt.rcParamsDefault = plt.rcParams

    # Color maps: a single grey for the lowest value, then the YlOrRd ramp
    ramp = plt.cm.YlOrRd(np.linspace(0.05, 1, 150))
    grey = plt.cm.Greys_r(np.linspace(0.8, 0.9, 1))
    mymap = colors.LinearSegmentedColormap.from_list('my_colormap', np.vstack([grey, ramp]))
    return mymap

In [None]:
%matplotlib inline
def load_YellowGnBl_cmap_settings(fig_height, fig_width):
    """Apply the notebook's plotting defaults and return a grey + YlGnBu colormap.

    Parameters
    ----------
    fig_height, fig_width : default figure size in inches.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        Built from one ``Greys_r`` entry followed by 150 samples of ``YlGnBu``.

    Side effects: mutates the global matplotlib ``rcParams`` (figure size,
    grid, ticks, font, PDF font type) and scanpy's figure settings.
    """
    from matplotlib import colors

    sc.set_figure_params(scanpy=True, frameon=False, vector_friendly=False, transparent=True, dpi=150, dpi_save=300)

    # Set the size AFTER set_figure_params: with scanpy=True that call applies
    # scanpy's own rcParams defaults, which would discard a size set earlier.
    # matplotlib's figure.figsize is ordered (width, height).
    rcParams['figure.figsize'] = (fig_width, fig_height)

    ## Grid & Ticks
    rcParams['grid.alpha'] = 0
    rcParams['xtick.bottom'] = True
    rcParams['ytick.left'] = True

    plt.rcParams.update({
        "text.usetex": False,
        "font.family": "serif",
        "font.serif": "NewComputerModern10",  # Computer Modern Roman fontsize 10
    })

    ## Embed font (TrueType) in PDF output
    plt.rc('pdf', fonttype=42)

    ## Define new default settings
    # NOTE(review): this only rebinds the name on the pyplot module to the live
    # rcParams object - confirm it has the intended "new defaults" effect.
    plt.rcParamsDefault = plt.rcParams

    # Color maps: a single grey for the lowest value, then the YlGnBu ramp
    ramp = plt.cm.YlGnBu(np.linspace(0.05, 0.9, 150))
    grey = plt.cm.Greys_r(np.linspace(0.9, 1, 1))
    mymap = colors.LinearSegmentedColormap.from_list('my_colormap', np.vstack([grey, ramp]))
    return mymap


In [None]:
%matplotlib inline
def load_BluePurple_cmap_settings(fig_height, fig_width):
    """Apply the notebook's plotting defaults and return a grey + BuPu colormap.

    Parameters
    ----------
    fig_height, fig_width : default figure size in inches.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        Built from one ``Greys_r`` entry followed by 150 samples of ``BuPu``.

    Side effects: mutates the global matplotlib ``rcParams`` (figure size,
    grid, ticks, font, PDF font type) and scanpy's figure settings.
    """
    from matplotlib import colors

    sc.set_figure_params(scanpy=True, frameon=False, vector_friendly=False, transparent=True, dpi=150, dpi_save=300)

    # Set the size AFTER set_figure_params: with scanpy=True that call applies
    # scanpy's own rcParams defaults, which would discard a size set earlier.
    # matplotlib's figure.figsize is ordered (width, height).
    rcParams['figure.figsize'] = (fig_width, fig_height)

    ## Grid & Ticks
    rcParams['grid.alpha'] = 0
    rcParams['xtick.bottom'] = True
    rcParams['ytick.left'] = True

    plt.rcParams.update({
        "text.usetex": False,
        "font.family": "serif",
        "font.serif": "NewComputerModern10",  # Computer Modern Roman fontsize 10
    })

    ## Embed font (TrueType) in PDF output
    plt.rc('pdf', fonttype=42)

    ## Define new default settings
    # NOTE(review): this only rebinds the name on the pyplot module to the live
    # rcParams object - confirm it has the intended "new defaults" effect.
    plt.rcParamsDefault = plt.rcParams

    # Color maps: a single grey for the lowest value, then the BuPu ramp
    ramp = plt.cm.BuPu(np.linspace(0.05, 0.9, 150))
    grey = plt.cm.Greys_r(np.linspace(0.9, 1, 1))
    mymap = colors.LinearSegmentedColormap.from_list('my_colormap', np.vstack([grey, ramp]))
    return mymap

In [None]:
def plot_embedding_density_kde(
    adata,
    groupby = 'batch',
    basis = 'umap',
    size = 10,
    kde_levels = 10,
    kde_linewidths = 0.5,
    kde_bw_adjust = 0.75,
    kde_thresh = 0.2,
    n_cols = 3,  # specify the desired number of columns
    save = False, cmap_kde = ''
):
    """Plot one embedding panel per category of ``groupby`` with density colouring and KDE contours.

    Every panel shows all cells in light grey, the cells of one category
    coloured by ``adata.obs[groupby + '_density']`` and KDE contour lines of
    that category's embedding coordinates.

    Parameters
    ----------
    adata : AnnData-like object with ``obsm['X_' + basis]``, a categorical
        ``obs[groupby]`` and ``uns[groupby + '_colors']``.
    groupby : categorical ``obs`` column defining the panels.
    basis : embedding name (``'X_' + basis`` is read from ``obsm``).
    size : marker size of the scatter points.
    kde_levels, kde_linewidths, kde_bw_adjust, kde_thresh : forwarded to ``sb.kdeplot``.
    n_cols : number of subplot columns; the row count is derived from it.
    save : when True the figure is written as PDF into ``sc.settings.figdir``.
    cmap_kde : colormap for the contour lines; ``''`` or ``None`` lets seaborn
        choose its default colouring.

    Side effects: if the density column is absent it is computed with
    ``sc.tl.embedding_density`` and stored in ``adata.obs``.
    """
    from matplotlib.colors import LinearSegmentedColormap

    # Compute densities only when they are not available yet
    density_key = groupby + '_density'
    if density_key not in adata.obs:
        sc.tl.embedding_density(adata, basis=basis, groupby=groupby, key_added=density_key)

    categories = adata.obs[groupby].cat.categories
    n_rows = int(np.ceil(len(categories) / n_cols))

    # squeeze=False always returns a 2-D array, so flatten() also works for a
    # 1x1 grid (plt.subplots would otherwise hand back a bare Axes).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, n_rows * 4), squeeze=False)
    axes = axes.flatten()

    # Loop-invariant data, extracted once instead of re-subsetting adata per panel
    coords = np.asarray(adata.obsm['X_' + basis])
    group_values = adata.obs[groupby].values
    densities = np.asarray(adata.obs[density_key])
    palette = adata.uns[groupby + '_colors']

    kde_kwargs = dict(levels=kde_levels, linewidths=kde_linewidths, bw_adjust=kde_bw_adjust, thresh=kde_thresh)
    # The default '' is not a valid colormap name; only forward a real value.
    if cmap_kde is not None and cmap_kde != '':
        kde_kwargs['cmap'] = cmap_kde

    for i, sample in enumerate(categories):
        ax = axes[i]

        # Background: every cell in light grey
        sb.scatterplot(x=coords[:, 0], y=coords[:, 1], color='#cccccc', s=size, linewidth=0, alpha=0.2, ax=ax)

        # Foreground: cells of this category coloured by density, plus contours
        in_group = np.asarray(group_values == sample)
        gx = coords[in_group, 0]
        gy = coords[in_group, 1]
        cmap = LinearSegmentedColormap.from_list("Custom", ['#bebebe', palette[i]], N=100)
        sb.scatterplot(x=gx, y=gy, c=densities[in_group], s=size, linewidth=0, alpha=0.5, cmap=cmap, ax=ax)
        sb.kdeplot(x=gx, y=gy, ax=ax, **kde_kwargs)

        # Strip frame, labels and ticks; keep only the title
        ax.set_frame_on(False)
        ax.set(xlabel='')
        ax.set(xticklabels=[])
        ax.set(ylabel='')
        ax.set(yticklabels=[])
        ax.tick_params(bottom=False, left=False)
        ax.set_title(sample)

    # Remove any extra empty subplots
    for extra_ax in axes[len(categories):]:
        fig.delaxes(extra_ax)

    plt.tight_layout()
    if save:
        plt.savefig(str(sc.settings.figdir) + "/embedding-density-kde_" + basis + "_" + groupby + ".pdf")
    plt.show()


In [None]:
def plot_composition(adata, 
x_key=None, 
y_key=None, 
x_labels = None,
y_labels = None,
y_colors = None,
title=None,                     
width = 0.85,       # the width of the bars: can also be len(x) sequence
x_rotation = 0,
y_lim_offset = 2.5,
x_lim_offset = 0.45,
figsize= (6, 4)):
    """Draw a stacked bar chart of the percentage of each ``y_key`` level within each ``x_key`` level.

    Parameters
    ----------
    adata : AnnData-like object; ``obs[x_key]`` and ``obs[y_key]`` are read.
    x_key, y_key : ``obs`` columns for the bars and for the stacked segments.
    x_labels, y_labels : levels to show; default to the categories of the
        respective column.
    y_colors : segment colours; default to ``adata.uns[y_key + '_colors']``.
    title : plot title; defaults to ``'<y_key> by <x_key>'``.
    width : bar width.
    x_rotation : rotation of the x tick labels.
    y_lim_offset, x_lim_offset : padding applied to the axis limits.
    figsize : figure size used while plotting.

    Returns
    -------
    pandas.DataFrame
        Column ``'x_labels'`` plus one percentage column per y label.
    """
    with rc_context({'figure.figsize': figsize}):
        # ``is None`` (not ``== None``) so array-like arguments are accepted
        if x_labels is None:
            x_labels = list(adata.obs[x_key].cat.categories)

        if y_labels is None:
            y_labels = list(adata.obs[y_key].cat.categories)

        if y_colors is None:
            y_colors = list(adata.uns[y_key + '_colors'])

        # Count the y levels once per x level instead of once per table cell
        counts_per_x = [
            adata.obs[y_key][adata.obs[x_key] == x_label].value_counts()
            for x_label in x_labels
        ]

        dic = {'x_labels': x_labels}
        for y_label in y_labels:
            # .get(..., 0): a level absent from a group contributes 0 %
            dic[y_label] = [
                counts.get(y_label, 0) / counts.sum() * 100
                for counts in counts_per_x
            ]

        df = pd.DataFrame(data=dic)

        ax = df.plot(x='x_labels', kind='bar', stacked=True, width=width, edgecolor='0', linewidth=0.5, color=y_colors)

        ax.set_ylabel('%')
        ax.set_xlabel('')
        if title is None:
            ax.set_title(y_key + ' by ' + x_key)
        else:
            ax.set_title(title)

        ax.axes.set_xticklabels(labels=x_labels, rotation=x_rotation)
        ax.legend(bbox_to_anchor=(1, .5), loc='center left', edgecolor='1')

        plt.ylim([-y_lim_offset, 100 + y_lim_offset])
        plt.xlim([-1 + x_lim_offset, len(x_labels) - x_lim_offset])

        plt.show()

    return df


## gex QC

In [None]:
def qc_metrics(adata, ambient=True, plot=True, counts_per_gene=True, keep_dense=True,
               mt_prefix='mt-', rp_prefixes=('Rps', 'Rpl')):
    """\
    Calculate per-cell (and optionally per-gene) QC metrics in place.

    Written to ``adata.obs``: ``n_counts``, ``log_counts``, ``n_counts_rank``,
    ``n_genes``, ``log_genes``, ``mt_frac``, ``rp_frac`` and, with
    ``ambient=True``, ``ambi_frac``. With ``counts_per_gene=True`` the per-gene
    total is written to ``adata.var['n_counts']``.

    adata: object with ``X``, ``obs``, ``var`` and ``var_names``. A sparse ``X`` is densified.
    ambient: also compute the fraction of counts in genes flagged by
        ``adata.var['is_ambient']``. Boolean flags and their string form
        ('True'/'TRUE', e.g. a categorical of strings) are both accepted.
    plot: draw the log_counts/log_genes joint histogram and the ranked-droplet plot.
    counts_per_gene: compute ``adata.var['n_counts']``.
    keep_dense: when False, ``X`` is converted to a CSR matrix before returning.
    mt_prefix: gene-name prefix identifying mitochondrial genes
        (NOTE(review): the default 'mt-' looks like mouse-style symbols - confirm for other species).
    rp_prefixes: gene-name prefixes identifying ribosomal protein genes.

    Returns None.
    """
    if sci.sparse.issparse(adata.X):
        adata.X = adata.X.toarray()

    if counts_per_gene:
        # counts per gene
        adata.var['n_counts'] = adata.X.sum(0)

    # counts per cell
    adata.obs['n_counts'] = adata.X.sum(1)
    # log counts per cell
    adata.obs['log_counts'] = np.log(adata.obs['n_counts'])
    # rank by counts (rank 1 = highest count; ties broken by order of appearance)
    adata.obs['n_counts_rank'] = adata.obs['n_counts'].rank(method='first', ascending=False)
    # genes per cell
    adata.obs['n_genes'] = (adata.X > 0).sum(1)
    # log genes per cell
    adata.obs['log_genes'] = np.log(adata.obs['n_genes'])

    # fraction of mitochondrial counts
    mt_gene_mask = np.array([gene.startswith(mt_prefix) for gene in adata.var_names], dtype=bool)
    adata.obs['mt_frac'] = adata.X[:, mt_gene_mask].sum(1) / adata.obs['n_counts']

    # fraction of ribosomal protein counts
    rp_gene_mask = np.array([gene.startswith(tuple(rp_prefixes)) for gene in adata.var_names], dtype=bool)
    adata.obs['rp_frac'] = adata.X[:, rp_gene_mask].sum(1) / adata.obs['n_counts']

    if ambient:
        flags = adata.var['is_ambient']
        # ``flags == True`` alone is all-False when the flag is stored as the
        # strings 'True'/'False', so the string form is matched as well.
        ambient_mask = np.asarray(flags == True) | np.asarray(flags.astype(str).str.lower() == 'true')
        adata.obs['ambi_frac'] = adata.X[:, ambient_mask].sum(1) / adata.obs['n_counts']

    if plot:
        try:
            sb.jointplot(
                data=adata.obs,
                x="log_counts",
                y="log_genes",
                kind="hist", bins=100, cmap="rocket_r", color="#f69c73", space=0
            )
        except AttributeError:
            # fallback without the colour arguments, kept from the original
            # (presumably for seaborn versions that reject them)
            sb.jointplot(
                data=adata.obs,
                x="log_counts",
                y="log_genes",
                kind="hist", bins=100, space=0
            )

        fig, ax1 = plt.subplots()
        ax1.scatter(x=adata.obs['n_counts_rank'], y=adata.obs['n_counts'], s=1, alpha=0.2, c='black', label='Total UMI Counts')
        ax1.scatter(x=adata.obs['n_counts_rank'], y=adata.obs['n_genes'], s=1, alpha=0.2, c='tab:green', label='Gene Counts')
        ax1.set(xscale='log', yscale='log')
        ax1.set_ylabel('Total UMI/Gene Counts')
        ax1.set_xlabel('Ranked Droplets')

        ax2 = ax1.twinx()
        ax2.scatter(x=adata.obs['n_counts_rank'], y=adata.obs['mt_frac']*100, s=1, alpha=0.2, c='tab:red', label='% Mito. Counts')
        ax2.set_ylabel('%')

        fig.legend(loc='center left', fontsize='xx-small', bbox_to_anchor=(0.2, 0.35))

    if not keep_dense:
        adata.X = sci.sparse.csr_matrix(adata.X)

In [None]:
def get_umap_leiden(adata, resolution=0.5, exclude_highly_expressed=False):
    """Attach a UMAP embedding and Leiden clustering to ``adata``.

    All preprocessing runs on a copy, so ``adata.X`` is left untouched. Only
    ``adata.obsm['X_umap']`` and ``adata.obs['leiden']`` are written.

    resolution: forwarded to ``sc.tl.leiden``.
    exclude_highly_expressed: forwarded to ``sc.pp.normalize_total``.
    """
    # Work on a scratch copy: total-count normalisation (no log transform),
    # PCA, correlation-based neighbour graph, then clustering and embedding.
    scratch = adata.copy()
    sc.pp.normalize_total(
        scratch,
        target_sum=1e4,
        exclude_highly_expressed=exclude_highly_expressed,
        key_added='size_factor',
    )
    sc.pp.pca(scratch)
    sc.pp.neighbors(scratch, metric='correlation')
    sc.tl.leiden(scratch, resolution=resolution)
    sc.tl.umap(scratch)

    # Copy only the results back onto the caller's object
    adata.obsm['X_umap'] = scratch.obsm['X_umap'].copy()
    adata.obs['leiden'] = scratch.obs['leiden'].copy()

In [None]:
def filter_genes(adata,threshold = 20):
    """Drop genes detected in fewer than ``threshold`` cells, in place.

    The gene count is printed before and after the call to
    ``sc.pp.filter_genes(adata, min_cells=threshold)``.
    """
    print(f"Total number of genes: {adata.n_vars}")

    # Any positive threshold also removes genes with zero counts
    sc.pp.filter_genes(adata, min_cells=threshold)

    print(f"Number of genes after cell filter: {adata.n_vars}")


In [None]:
def qc_filter_mdata(mdata, adata, modality=None, qc_filter=None): 
    """\
    Subset one modality of ``mdata`` by a boolean filter and report how many cells were removed.

    mdata: mdata object containing different modalities.
    adata: accepted for interface compatibility; the passed value is not used
        (the filtered modality is returned instead).
    modality: modality to filter (key of ``mdata.mod``).
    qc_filter: array of booleans. E.g. genes_filter = adata.obs['n_genes'] > min_genes or (genes_filter & counts_filter)

    Returns the filtered modality (also stored back in ``mdata.mod[modality]``),
    or None when ``modality`` or ``qc_filter`` is missing.
    """
    if modality is None:
        print("Specify modality ('rna' or 'atac').")
        return
    if qc_filter is None:
        print('Specify QC filter.')
        return

    n_before = mdata.mod[modality].n_obs
    mdata.mod[modality] = mdata.mod[modality][qc_filter]
    filtered = mdata.mod[modality]

    n_removed = n_before - filtered.n_obs
    # an empty modality would otherwise raise ZeroDivisionError
    pct = n_removed / n_before * 100 if n_before else 0.0
    print(f'Filtered out {n_removed:d} cells ({pct:.1f} %).')
    print(f'Number of cells after filter: {filtered.n_obs:d}')
    return filtered

## ATAC QC

In [None]:
def atac_qc_metrics(atac, mdata, n_tss=2000, plot=True, gtf_path=None):
    """\
    Calculate ATAC QC metrics in place.

    Written to ``atac``: ``var['n_counts']`` and the ``obs`` columns
    ``n_counts_ATAC``, ``log_counts_ATAC``, ``n_counts_rank_ATAC``,
    ``n_peaks_ATAC`` and ``log_peaks_ATAC``. The nucleosome signal is computed
    with ``ac.tl.nucleosome_signal`` and TSS enrichment with
    ``ac.tl.tss_enrichment``.

    atac: adata object containing atac data. i.e. mdata.mod['atac']. A sparse ``X`` is densified.
    mdata: Corresponding mdata object; ``mdata.mod['rna']`` supplies the gene counts
        used to pick the TSS features.
    n_tss: number of top-count genes whose positions are used as TSS features
        (also forwarded to ``ac.tl.tss_enrichment``).
    plot: draw the histograms, the fragment-size histogram and the TSS enrichment plot.
    gtf_path: path to the annotation GTF file; required.

    Returns None (also when ``gtf_path`` is missing, after printing a hint).
    """
    import muon as mu
    from muon import atac as ac

    if gtf_path is None:
        print('Provide path to annotation GTF file!')
        return

    if sci.sparse.issparse(atac.X):
        print('X is sparse. Densifying X...')
        atac.X = atac.X.toarray()

    # counts per peak
    atac.var['n_counts'] = atac.X.sum(0)
    # counts per cell
    atac.obs['n_counts_ATAC'] = list(atac.X.sum(1))
    # log counts per cell (+1 keeps empty cells finite)
    atac.obs['log_counts_ATAC'] = np.log(atac.obs['n_counts_ATAC'] + 1)
    # rank by counts
    atac.obs['n_counts_rank_ATAC'] = atac.obs['n_counts_ATAC'].rank(method='first', ascending=False)
    # peaks per cell
    atac.obs['n_peaks_ATAC'] = (atac.X > 0).sum(1)
    # log peaks per cell
    atac.obs['log_peaks_ATAC'] = np.log(atac.obs['n_peaks_ATAC'] + 1)

    if plot:
        mu.pl.histogram(atac, ['n_counts_ATAC', 'n_peaks_ATAC'])

    # nucleosome signal
    ac.tl.nucleosome_signal(atac, n=1e6)

    if plot:
        # plot nucleosome signal
        # NOTE(review): the region is hard-coded to the first 20 Mb of chromosome '1'
        ac.pl.fragment_histogram(atac, region='1:1-20000000')
        mu.pl.histogram(atac, "nucleosome_signal", kde=True)

    # TSS enrichment around the n_tss genes with the highest RNA counts
    features = get_top_feature_pos_from_gtf(mdata.mod['rna'], gtf_path=gtf_path, n_top=n_tss)
    print(len(features))
    tss = ac.tl.tss_enrichment(mdata, features=features, n_tss=n_tss)  # by default, features=ac.tl.get_gene_annotation_from_rna(mdata)
    if plot:
        # plot TSS enrichment
        ac.pl.tss_enrichment(tss)

In [None]:
def get_feature_pos_from_gtf(gtf_path=None, random=True):
    """Return strand-oriented gene positions from a GTF annotation file.

    The file is read as headerless tab-separated text, skipping its first five
    lines, and reduced to rows whose third column is ``'gene'``. For ``'+'``
    genes the result holds (column 4, column 5) as (Start, End); for ``'-'``
    genes the two are swapped. Plus-strand rows come first, and with
    ``random=True`` the rows are then shuffled with a fixed seed.

    Returns a DataFrame with columns ``Chromosome``, ``Start``, ``End``.
    """
    gtf = pd.read_csv(gtf_path, header=None, skiprows=5, sep='\t')
    genes = gtf.loc[gtf.iloc[:, 2] == 'gene', :]

    def _positions(strand, column_order):
        # select one strand and pick chromosome/start/end in the given order
        subset = genes.loc[genes.iloc[:, 6] == strand, :].iloc[:, column_order]
        subset.columns = ['Chromosome', 'Start', 'End']
        return subset

    plus_strand = _positions('+', [0, 3, 4])
    # minus strand: start and end are switched
    minus_strand = _positions('-', [0, 4, 3])

    features = pd.concat([plus_strand, minus_strand], ignore_index=True)

    if not random:
        return features
    # fixed seed keeps the shuffled order reproducible
    return features.sample(frac=1, random_state=420)

In [None]:
def get_top_feature_pos_from_gtf(gex_adata, gtf_path=None, n_top=2000):
    """Return strand-oriented positions of the ``n_top`` genes with the highest counts.

    The GTF file is read as headerless tab-separated text, skipping its first
    five lines, and reduced to rows whose third column is ``'gene'``. The first
    double-quoted value of the attribute column is taken as gene id and joined
    against ``gex_adata.var['gene_ids']`` to attach ``n_counts``.

    Parameters
    ----------
    gex_adata : object whose ``var`` DataFrame has ``gene_ids`` and ``n_counts`` columns.
    gtf_path : path to the annotation file.
    n_top : number of genes (by descending ``n_counts``) to keep.

    Returns
    -------
    pandas.DataFrame
        Columns ``Chromosome``, ``Start``, ``End``; plus-strand genes first.
        For minus-strand genes Start and End are swapped.
    """
    # load annotation file and reduce to genes; copy so the new column is
    # added to an independent frame rather than to a slice of the raw table
    annotation = pd.read_csv(gtf_path, header=None, skiprows=5, sep='\t')
    annotation = annotation.loc[annotation.iloc[:, 2] == 'gene', :].copy()

    # gene id = first quoted value in the attribute column.
    # (The original also assigned ``annotation.set_index = annotation[9]``,
    # which only overwrote the method; the merge below needs no index.)
    annotation[9] = annotation[8].str.split('"', expand=True)[1]

    # add counts, sort by them and keep the n_top genes
    annotation = annotation.merge(gex_adata.var.loc[:, ['gene_ids', 'n_counts']], left_on=9, right_on='gene_ids')
    annotation = annotation.sort_values(by=['n_counts'], ascending=False)
    annotation = annotation.iloc[0:n_top, :]

    # positions of + features
    features_p = annotation.loc[annotation.iloc[:, 6] == '+', :].iloc[:, [0, 3, 4]]
    features_p.columns = ['Chromosome', 'Start', 'End']

    # positions of - features, start and end switched
    features_m = annotation.loc[annotation.iloc[:, 6] == '-', :].iloc[:, [0, 4, 3]]
    features_m.columns = ['Chromosome', 'Start', 'End']

    return pd.concat([features_p, features_m], ignore_index=True)

In [None]:
def signac_qc_metrics_list(mdata_filtered, 
                      cr_path=None,
                      sample=None,
                      aggregated=False,
                      species='Hsapiens',
                      genome='GRCh38',
                      ensembl_release='v111',
                      plot=True
                     ):
    """\
    Calculate Signac/Seurat ATAC QC metrics for every entry of ``mdata_filtered``.

    mdata_filtered: mapping (iterated with ``.items()``) of sample/folder name -> mdata object;
        the 'atac' modality of each entry is processed and updated in place.
    cr_path: path prefix of the CellRanger results. Fragments and per-barcode metrics are read from
        ``cr_path + sample + '/outs/'`` or, if ``aggregated``, from ``cr_path + '/outs/'``.
    sample: overwritten with the current key of ``mdata_filtered`` inside the loop; the passed value is not used.
    aggregated: aggregated CellRanger results (one shared ``outs`` folder).
    species: used to build the BSgenome / EnsDb R package names, e.g. 'Hsapiens'.
    genome: one of 'GRCh38', 'GRCm38', 'hg19', 'hg38', 'mm10'; anything else prints a message and returns None.
    ensembl_release: used to build the EnsDb R package name, e.g. 'v111'.
    plot: plot results (mean TSS enrichment profile, violin plots and three joint KDE plots per sample).

    Returns ``mdata_filtered``. Per sample, the QC columns listed in ``columns`` plus ``log_nCount_atac`` and
    ``log_nFeature_atac`` are written to ``atac.obs`` and the TSS position-enrichment matrix to
    ``atac.uns['tss_matrix']``.

    NOTE(review): ``ro`` and ``pandas2ri`` are not imported in the visible part of this file - presumably
    rpy2 is imported in another cell; confirm before calling.
    """
    
    # Load packages and annotations
    if genome not in ['GRCh38','GRCm38','hg19','hg38','mm10']:
        print('Unkown genome. Genome has to be one of the following: GRCh38, GRCm38, hg19, hg38, mm10')
        return None
    
    # sequence-level style is derived from the genome name prefix
    if genome.startswith('GRC'):
        seqStyle = 'NCBI'
    elif genome.startswith('hg') or genome.startswith('mm'):
        seqStyle = 'UCSC'
    
    #sample = atac.obs['sample'][0]

    # package names are handed to R as strings and loaded there with character.only=TRUE
    print('Loading package BSgenome.' + species + '.' + seqStyle + '.' + genome)
    ro.globalenv['bs_genome'] = 'BSgenome.' + species + '.' + seqStyle + '.' + genome
    
    print('Loading package EnsDb.' + species + '.' + ensembl_release)
    ro.globalenv['ensdb'] = 'EnsDb.' + species + '.' + ensembl_release
    
    # genome name -> name of the blacklist object looked up in R with get()
    bl_dict = {'hg19':'blacklist_hg19',
               'hg38':'blacklist_hg38_unified',
               'GRCh38':'blacklist_hg38_unified',
               'mm10':'blacklist_mm10',
               'GRCm38':'blacklist_mm10'
              }
    
    print('Using Blacklist: ' + bl_dict[genome])
    ro.globalenv['blacklist_name'] = bl_dict[genome]
    ro.globalenv['genome'] = genome
    #workaround for blacklist to be renamed:
    # NOTE(review): this overrides the style derived above, so the R code below always takes the
    # "NCBI" renaming branch, also for hg*/mm* genomes - confirm this is intended.
    seqStyle = 'NCBI'
    ro.globalenv['seqStyle'] = seqStyle
    
    
    ro.r('''
    library(Signac)
    library(Seurat)

    library(bs_genome,character.only=TRUE)
    library(ensdb,character.only=TRUE)
    library(stringr)

    # extract gene annotations from EnsDb
    annotations <- GetGRangesFromEnsDb(ensdb = get(ensdb))

    # rename blacklist
    blacklist = get(blacklist_name)
    if(seqStyle == "NCBI"){
        blacklist <- renameSeqlevels(blacklist, value = str_replace(str_replace(seqlevels(blacklist), pattern = "chr", replacement = ""), pattern = "M", replacement = "MT"))
    }
    ''')
    for folder_name, mdata in mdata_filtered.items():  
        # the dictionary key doubles as sample name and as folder name below cr_path
        sample = folder_name
        print('analysing sample: ' + sample)
        atac = mdata.mod['atac'] 
        
        # Prepare data
        if aggregated:
            frag_path = cr_path + '/outs/atac_fragments.tsv.gz'
            outs_path = cr_path + '/outs/'
        else:
            frag_path = cr_path + sample + '/outs/atac_fragments.tsv.gz'
            outs_path = cr_path + sample + '/outs/'
        
        print(frag_path)
            
        # hand the data over to the R global environment; the matrix is transposed before the transfer
        ro.globalenv['count_mat'] = atac.X.T
        ro.globalenv['obs_names'] = atac.obs_names
        ro.globalenv['var_names'] = atac.var_names
        ro.globalenv['obs'] = atac.obs
        ro.globalenv['var'] = atac.var
        ro.globalenv['fragments'] = frag_path
        ro.globalenv['outs_path'] = outs_path
        
        # Generate Seurat object
        ro.r('''
        rownames(count_mat) <- var_names
        colnames(count_mat) <- obs_names
        chrom_assay <- CreateChromatinAssay(counts = count_mat, sep = c(":", "-"), fragments = fragments, annotation = annotations)
        seurat <- CreateSeuratObject(counts = chrom_assay, assay = "atac")
        ''')
        
        # Calculate QC metrics
        ro.r('''
        ## claculate FRiP
        #frag_counts <- CountFragments(fragments = fragments)
        #rownames(frag_counts) <- frag_counts$CB
        #frag_counts$fraction_reads_in_peaks <- frag_counts$frequency_count / frag_counts$reads_count
        #seurat <- AddMetaData(seurat, metadata = frag_counts[rownames(seurat@meta.data), -1, drop = FALSE])
        
        # claculate FRiP and fraction of mitochondrial reads
        per_barcode_metrics <- read.csv(paste0(outs_path, "per_barcode_metrics.csv"))
        rownames(per_barcode_metrics) <- per_barcode_metrics$barcode
        per_barcode_metrics <- per_barcode_metrics[per_barcode_metrics$is_cell == 1,]
        colnames(per_barcode_metrics) <- paste0("cr_",colnames(per_barcode_metrics))
        per_barcode_metrics$cr_fraction_fragments_in_peaks = per_barcode_metrics$cr_atac_peak_region_fragments / per_barcode_metrics$cr_atac_fragments
        per_barcode_metrics$cr_fraction_tss_fragments = per_barcode_metrics$cr_atac_TSS_fragments / per_barcode_metrics$cr_atac_fragments
        per_barcode_metrics$cr_fraction_reads_in_mito = per_barcode_metrics$cr_atac_mitochondrial_reads / per_barcode_metrics$cr_atac_raw_reads
        seurat <- AddMetaData(seurat, metadata = per_barcode_metrics[rownames(seurat@meta.data),22:ncol(per_barcode_metrics)])
        
        # calculate fraction of cut-site (i.e. counts) in blacklist
        seurat$fraction_counts_in_blacklist <- FractionCountsInRegion(
        object = seurat,
        assay = "atac",
        regions = blacklist)
        
        # compute nucleosome signal score per cell
        seurat <- NucleosomeSignal(object = seurat)

        # compute TSS enrichment score per cell
        seurat <- TSSEnrichment(object = seurat, fast = FALSE)
        ''')
        
        # TO DO: Get fragment sizes
        
        # Get results
        ro.r('''
        meta_data <- seurat@meta.data
        tss_matrix <- GetAssayData(object = seurat, assay = "atac", slot = "positionEnrichment")[["TSS"]]
        ''')
        atac.uns['tss_matrix'] = ro.globalenv['tss_matrix']
        atac_meta_data = ro.globalenv['meta_data']

        # explicit conversion to pandas; an AttributeError is ignored and the value is used as
        # retrieved (presumably already converted by an activated converter - TODO confirm)
        try:
            with (ro.default_converter + pandas2ri.converter).context():
                atac_meta_data = ro.conversion.get_conversion().rpy2py(atac_meta_data)
        except AttributeError:
            pass
            
        # Seurat meta-data columns copied into atac.obs
        columns = ['nCount_atac', 'nFeature_atac', 'cr_atac_raw_reads',
        'cr_atac_unmapped_reads', 'cr_atac_lowmapq', 'cr_atac_dup_reads',
        'cr_atac_chimeric_reads', 'cr_atac_mitochondrial_reads',
        'cr_atac_fragments', 'cr_atac_TSS_fragments',
        'cr_atac_peak_region_fragments', 'cr_atac_peak_region_cutsites',
        'cr_fraction_fragments_in_peaks', 'cr_fraction_reads_in_mito',
        'fraction_counts_in_blacklist', 'nucleosome_signal',
        'nucleosome_percentile', 'TSS.enrichment', 'TSS.percentile', 'cr_fraction_tss_fragments']

        # rows are aligned on the barcode names of atac
        atac.obs.loc[:,columns] = atac_meta_data.loc[atac.obs_names,columns]
        
        atac.obs['log_nCount_atac'] = np.log(atac.obs['nCount_atac'])
        atac.obs['log_nFeature_atac'] = np.log(atac.obs['nFeature_atac'])
        
        if plot:
            
            from matplotlib.colors import LinearSegmentedColormap
            from itertools import chain

            # Set the number of rows and columns for subplots
            #n_rows = n_rows  # specify the desired number of rows
            n_cols = 1 #n_cols  # specify the desired number of columns
            n_rows = 1 #int(np.ceil(len(adata.obs[groupby].cat.categories)/n_cols))

            # Calculate the total figure size for square subplots
            figsize = (n_cols * 4, n_rows * 4)

            # Create a figure with subplots
            fig, ax = plt.subplots(n_rows, n_cols, figsize=figsize)  # specify the figure size

            # Mean TSS enrichment profile: column means of the TSS matrix against positions -1000..1000
            # NOTE(review): the x range assumes the matrix has 2001 columns - confirm.
            x = list(range(-1000,1001,1))
            y = list(chain.from_iterable(atac.uns['tss_matrix'].mean(axis=0).A))
            sb.lineplot(x=x, y=y, c='black', linewidth=0.5, alpha=1, ax=ax)

            # Customize plot appearance
            ax.set(xlabel='Distance to TSS')
            ax.set(ylabel='Mean TSS Enrichment Score')
            ax.set_title(f'TSS Enrichnemt {sample}')

            plt.tight_layout()  # Adjust subplot layout

            plt.show()  # Display the plot
            plt.close()
            
            ################################################################
            
            # one violin + strip plot per QC metric
            keys=['log_nCount_atac', 
            'cr_fraction_fragments_in_peaks', 'cr_fraction_reads_in_mito',
            'fraction_counts_in_blacklist', 'nucleosome_signal',
            'TSS.enrichment']

            n_cols = 6

            # Set the number of rows and keys for subplots
            #n_rows = n_rows  # specify the desired number of rows
            n_cols = n_cols  # specify the desired number of keys
            n_rows = int(np.ceil(len(keys)/n_cols))

            # Calculate the total figure size for square subplots
            figsize = (n_cols * 2.5, n_rows * 4)

            # Create a figure with subplots
            fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize)  # specify the figure size

            # Flatten the axes array to a 1D array
            axes = axes.flatten()

            # Loop over the QC metrics
            for i, key in enumerate(keys):
                # Create violin and strip plot on the current subplot
                ax = axes[i]
                y = atac.obs[key]
                sb.violinplot(y=y, color='#f69c73', inner='quart', linewidth=1, linecolor='black', ax=ax)
                sb.stripplot(y=y, color='black', size=0.75, alpha=0.5, jitter=0.25,  ax=ax)

                ## Customize plot appearance
                ax.set(xlabel='')
                ax.set(xticklabels=[])
                ax.tick_params(bottom=False)

            # Remove any extra empty subplots
            for i in range(len(keys), n_rows * n_cols):
                fig.delaxes(axes[i])

            plt.tight_layout()  # Adjust subplot layout

            plt.show()  # Display the plot
            plt.close()
            
            #############################################################
            
            # joint KDE: log counts vs TSS enrichment (grey scatter, filled contours, thin contour lines)
            p = sb.jointplot(x=atac.obs['log_nCount_atac'], y=atac.obs['TSS.enrichment'], n_levels=15, thresh=0.05, kind="kde", space=0, fill=True, cmap="rocket_r", color="#f69c73")
            p.plot_joint(sb.scatterplot, c='#cccccc', s=10, linewidth=0, alpha=0.5)
            p.plot_joint(sb.kdeplot, levels=15, fill=True, cmap='rocket_r', thresh=0.05, alpha=0.75)
            p.plot_joint(sb.kdeplot, levels=15, fill=False, c='black', thresh=0.05, alpha=0.5, linewidths=0.25)
            plt.suptitle(sample)
            plt.show()
            plt.close()
            
            #############################################################
            
            # joint KDE: log counts vs nucleosome signal
            p = sb.jointplot(x=atac.obs['log_nCount_atac'], y=atac.obs['nucleosome_signal'], n_levels=15, thresh=0.05, kind="kde", space=0, fill=True, cmap="rocket_r", color="#f69c73")
            p.plot_joint(sb.scatterplot, c='#cccccc', s=10, linewidth=0, alpha=0.5)
            p.plot_joint(sb.kdeplot, levels=15, fill=True, cmap='rocket_r', thresh=0.05, alpha=0.75)
            p.plot_joint(sb.kdeplot, levels=15, fill=False, c='black', thresh=0.05, alpha=0.5, linewidths=0.25)
            plt.suptitle(sample)
            plt.show()
            plt.close()
            #############################################################
            
            # joint KDE: TSS enrichment vs nucleosome signal
            p = sb.jointplot(x=atac.obs['TSS.enrichment'], y=atac.obs['nucleosome_signal'], n_levels=15, thresh=0.05, kind="kde", space=0, fill=True, cmap="rocket_r", color="#f69c73")
            p.plot_joint(sb.scatterplot, c='#cccccc', s=10, linewidth=0, alpha=0.5)
            p.plot_joint(sb.kdeplot, levels=15, fill=True, cmap='rocket_r', thresh=0.05, alpha=0.75)
            p.plot_joint(sb.kdeplot, levels=15, fill=False, c='black', thresh=0.05, alpha=0.5, linewidths=0.25)
            plt.suptitle(sample)
            plt.show()
            plt.close()

    return mdata_filtered

In [None]:
def signac_qc_metrics(atac, 
                      cr_path=None,
                      sample=None,
                      aggregated=False,
                      species='Hsapiens',
                      genome='GRCh38',
                      ensembl_release='v111',
                      plot=True
                     ):
    """\
    Calculate per-cell ATAC QC metrics with Signac (through rpy2) and add them to ``atac.obs``.

    The count matrix, obs and var of ``atac`` are pushed into the R global
    environment, a Seurat object with a ChromatinAssay is created, the QC
    metrics are computed in R and then copied back into ``atac``.

    atac: adata object containing atac data. i.e. mdata.mod['atac']
    cr_path: Path to CellRanger results (cr_count or cr_agg). For non-aggregated
        runs the path is built by plain concatenation (cr_path + sample + '/outs/'),
        so cr_path has to end with a path separator.
    sample: sample name. Required unless ``aggregated`` is True; for aggregated
        runs it is only used in log messages and plot titles.
    aggregated: Aggregated CellRanger results. If True, files are read from
        cr_path + '/outs/'.
    species: 'Hsapiens' or 'Mmusculus'
    genome: one of 'GRCh38', 'GRCm38', 'hg19', 'hg38', 'mm10'
    ensembl_release: suffix of the EnsDb package that is loaded, e.g. 'v111'
    plot: plot results

    Returns ``atac`` with the QC columns, 'log_nCount_atac' and
    'log_nFeature_atac' added to ``.obs`` and the TSS position-enrichment
    matrix stored in ``.uns['tss_matrix']``. Returns None (after printing a
    message) if ``genome`` is not one of the supported values.

    Raises TypeError if ``sample`` is None while ``aggregated`` is False.

    Relies on ``ro`` (rpy2.robjects) and ``pandas2ri`` being available at module level.
    """
    
    # Load packages and annotations
    if genome not in ['GRCh38','GRCm38','hg19','hg38','mm10']:
        print('Unkown genome. Genome has to be one of the following: GRCh38, GRCm38, hg19, hg38, mm10')
        return None
    
    # The sample name is part of the input paths for non-aggregated runs, so fail
    # early with a clear message instead of a str + None concatenation error.
    if sample is None and not aggregated:
        raise TypeError('sample must be provided when aggregated is False')
    # Label used for log messages and plot titles only
    sample_label = 'aggregated' if sample is None else sample
    
    # genome was validated above, so it either starts with 'GRC' or with 'hg'/'mm'
    seqStyle = 'NCBI' if genome.startswith('GRC') else 'UCSC'
    
    print('analysing sample: ' + sample_label)
    print('Loading package BSgenome.' + species + '.' + seqStyle + '.' + genome)
    ro.globalenv['bs_genome'] = 'BSgenome.' + species + '.' + seqStyle + '.' + genome
    
    print('Loading package EnsDb.' + species + '.' + ensembl_release)
    ro.globalenv['ensdb'] = 'EnsDb.' + species + '.' + ensembl_release
    
    bl_dict = {'hg19':'blacklist_hg19',
               'hg38':'blacklist_hg38_unified',
               'GRCh38':'blacklist_hg38_unified',
               'mm10':'blacklist_mm10',
               'GRCm38':'blacklist_mm10'
              }
    
    print('Using Blacklist: ' + bl_dict[genome])
    ro.globalenv['blacklist_name'] = bl_dict[genome]
    ro.globalenv['genome'] = genome
    #workaround for blacklist to be renamed:
    # (the R code below only renames the blacklist seqlevels when seqStyle is "NCBI",
    #  so the style is forced here regardless of the genome)
    seqStyle = 'NCBI'
    ro.globalenv['seqStyle'] = seqStyle
    
    
    ro.r('''
    library(Signac)
    library(Seurat)

    library(bs_genome,character.only=TRUE)
    library(ensdb,character.only=TRUE)
    library(stringr)

    # extract gene annotations from EnsDb
    annotations <- GetGRangesFromEnsDb(ensdb = get(ensdb))

    # rename blacklist
    blacklist = get(blacklist_name)
    if(seqStyle == "NCBI"){
        blacklist <- renameSeqlevels(blacklist, value = str_replace(str_replace(seqlevels(blacklist), pattern = "chr", replacement = ""), pattern = "M", replacement = "MT"))
    }
    ''')
    
    # Prepare data
    if aggregated:
        frag_path = cr_path + '/outs/atac_fragments.tsv.gz'
        outs_path = cr_path + '/outs/'
    else:
        frag_path = cr_path + sample + '/outs/atac_fragments.tsv.gz'
        outs_path = cr_path + sample + '/outs/'
    
    print(frag_path)
        
    # R expects features x cells, hence the transpose
    ro.globalenv['count_mat'] = atac.X.T
    ro.globalenv['obs_names'] = atac.obs_names
    ro.globalenv['var_names'] = atac.var_names
    ro.globalenv['obs'] = atac.obs
    ro.globalenv['var'] = atac.var
    ro.globalenv['fragments'] = frag_path
    ro.globalenv['outs_path'] = outs_path
    
    # Generate Seurat object
    ro.r('''
    rownames(count_mat) <- var_names
    colnames(count_mat) <- obs_names
    chrom_assay <- CreateChromatinAssay(counts = count_mat, sep = c(":", "-"), fragments = fragments, annotation = annotations)
    seurat <- CreateSeuratObject(counts = chrom_assay, assay = "atac")
    ''')
    
    # Calculate QC metrics
    ro.r('''
    ## claculate FRiP
    #frag_counts <- CountFragments(fragments = fragments)
    #rownames(frag_counts) <- frag_counts$CB
    #frag_counts$fraction_reads_in_peaks <- frag_counts$frequency_count / frag_counts$reads_count
    #seurat <- AddMetaData(seurat, metadata = frag_counts[rownames(seurat@meta.data), -1, drop = FALSE])
    
    # claculate FRiP and fraction of mitochondrial reads
    per_barcode_metrics <- read.csv(paste0(outs_path, "per_barcode_metrics.csv"))
    rownames(per_barcode_metrics) <- per_barcode_metrics$barcode
    per_barcode_metrics <- per_barcode_metrics[per_barcode_metrics$is_cell == 1,]
    colnames(per_barcode_metrics) <- paste0("cr_",colnames(per_barcode_metrics))
    per_barcode_metrics$cr_fraction_fragments_in_peaks = per_barcode_metrics$cr_atac_peak_region_fragments / per_barcode_metrics$cr_atac_fragments
    per_barcode_metrics$cr_fraction_tss_fragments = per_barcode_metrics$cr_atac_TSS_fragments / per_barcode_metrics$cr_atac_fragments
    per_barcode_metrics$cr_fraction_reads_in_mito = per_barcode_metrics$cr_atac_mitochondrial_reads / per_barcode_metrics$cr_atac_raw_reads
    seurat <- AddMetaData(seurat, metadata = per_barcode_metrics[rownames(seurat@meta.data),22:ncol(per_barcode_metrics)])
    
    # calculate fraction of cut-site (i.e. counts) in blacklist
    seurat$fraction_counts_in_blacklist <- FractionCountsInRegion(
    object = seurat,
    assay = "atac",
    regions = blacklist)
    
    # compute nucleosome signal score per cell
    seurat <- NucleosomeSignal(object = seurat)

    # compute TSS enrichment score per cell
    seurat <- TSSEnrichment(object = seurat, fast = FALSE)
    ''')
    
    # TO DO: Get fragment sizes
    
    # Get results
    ro.r('''
    meta_data <- seurat@meta.data
    tss_matrix <- GetAssayData(object = seurat, assay = "atac", slot = "positionEnrichment")[["TSS"]]
    ''')
    atac.uns['tss_matrix'] = ro.globalenv['tss_matrix']
    atac_meta_data = ro.globalenv['meta_data']

    # Best-effort explicit conversion to pandas. An AttributeError is ignored on
    # purpose; in that case the object is used as returned by rpy2
    # (presumably already a pandas DataFrame when automatic conversion is active).
    try:
        with (ro.default_converter + pandas2ri.converter).context():
            atac_meta_data = ro.conversion.get_conversion().rpy2py(atac_meta_data)
    except AttributeError:
        pass
        
    columns = ['nCount_atac', 'nFeature_atac', 'cr_atac_raw_reads',
       'cr_atac_unmapped_reads', 'cr_atac_lowmapq', 'cr_atac_dup_reads',
       'cr_atac_chimeric_reads', 'cr_atac_mitochondrial_reads',
       'cr_atac_fragments', 'cr_atac_TSS_fragments',
       'cr_atac_peak_region_fragments', 'cr_atac_peak_region_cutsites',
       'cr_fraction_fragments_in_peaks', 'cr_fraction_reads_in_mito',
       'fraction_counts_in_blacklist', 'nucleosome_signal',
       'nucleosome_percentile', 'TSS.enrichment', 'TSS.percentile', 'cr_fraction_tss_fragments']

    # Align the R metadata to the cell order of atac before copying it over
    atac.obs.loc[:,columns] = atac_meta_data.loc[atac.obs_names,columns]
    
    atac.obs['log_nCount_atac'] = np.log(atac.obs['nCount_atac'])
    atac.obs['log_nFeature_atac'] = np.log(atac.obs['nFeature_atac'])
    
    if plot:
        
        from itertools import chain

        # ---- Mean TSS enrichment profile -------------------------------
        n_cols = 1
        n_rows = 1

        # Calculate the total figure size for square subplots
        figsize = (n_cols * 4, n_rows * 4)

        # Create a figure with subplots
        fig, ax = plt.subplots(n_rows, n_cols, figsize=figsize)  # specify the figure size

        # One value per position in the window of -1000..1000 around the TSS
        x = list(range(-1000,1001,1))
        y = list(chain.from_iterable(atac.uns['tss_matrix'].mean(axis=0).A))
        sb.lineplot(x=x, y=y, c='black', linewidth=0.5, alpha=1, ax=ax)

        # Customize plot appearance
        ax.set(xlabel='Distance to TSS')
        ax.set(ylabel='Mean TSS Enrichment Score')
        ax.set_title(f'TSS Enrichnemt {sample_label}')

        plt.tight_layout()  # Adjust subplot layout

        plt.show()  # Display the plot
        plt.close()
        
        # ---- Violin plots of the individual QC metrics -----------------
        
        keys=['log_nCount_atac', 
       'cr_fraction_fragments_in_peaks', 'cr_fraction_reads_in_mito',
       'fraction_counts_in_blacklist', 'nucleosome_signal',
       'TSS.enrichment']

        n_cols = 6
        n_rows = int(np.ceil(len(keys)/n_cols))

        # Calculate the total figure size for the subplots
        figsize = (n_cols * 2.5, n_rows * 4)

        # Create a figure with subplots
        fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize)  # specify the figure size

        # Flatten the axes array to a 1D array
        axes = axes.flatten()

        # One violin + strip plot per QC metric
        for i, key in enumerate(keys):
            ax = axes[i]
            y = atac.obs[key]
            sb.violinplot(y=y, color='#f69c73', inner='quart', linewidth=1, linecolor='black', ax=ax)
            sb.stripplot(y=y, color='black', size=0.75, alpha=0.5, jitter=0.25,  ax=ax)

            ## Customize plot appearance
            ax.set(xlabel='')
            ax.set(xticklabels=[])
            ax.tick_params(bottom=False)

        # Remove any extra empty subplots
        for i in range(len(keys), n_rows * n_cols):
            fig.delaxes(axes[i])

        plt.tight_layout()  # Adjust subplot layout

        plt.show()  # Display the plot
        plt.close()
        
        # ---- Pairwise joint distributions of the main QC metrics -------
        
        joint_pairs = [('log_nCount_atac', 'TSS.enrichment'),
                       ('log_nCount_atac', 'nucleosome_signal'),
                       ('TSS.enrichment', 'nucleosome_signal')]
        
        for x_key, y_key in joint_pairs:
            p = sb.jointplot(x=atac.obs[x_key], y=atac.obs[y_key], n_levels=15, thresh=0.05, kind="kde", space=0, fill=True, cmap="rocket_r", color="#f69c73")
            # Overlay the single cells, a filled KDE and thin contour lines
            p.plot_joint(sb.scatterplot, c='#cccccc', s=10, linewidth=0, alpha=0.5)
            p.plot_joint(sb.kdeplot, levels=15, fill=True, cmap='rocket_r', thresh=0.05, alpha=0.75)
            p.plot_joint(sb.kdeplot, levels=15, fill=False, c='black', thresh=0.05, alpha=0.5, linewidths=0.25)
            plt.suptitle(sample_label)
            plt.show()
            plt.close()

    
    return atac

## general

### sparsify

In [None]:
def sparsify_mdata(mdata, modalities='all'):
    """
    Convert dense ``.X`` matrices of the selected modalities to CSR, in place.

    A dense matrix is only converted when less than 2/3 of its entries are
    non-zero; otherwise it is left dense and its density is reported.
    A status line is printed for every modality. Nothing is returned.

    mdata: object with a ``mod`` mapping (modality name -> object with ``.X``)
    modalities: 'all' or list of modalities, e.g. ['rna','atac']
    """
    selected = mdata.mod.keys() if modalities == 'all' else modalities

    for name in selected:
        modality = mdata.mod[name]

        # Nothing to do for matrices that are sparse already
        if sci.sparse.issparse(modality.X):
            print('Modality', name, 'already sparse...')
            continue

        # Fraction of non-zero entries
        n_entries = modality.X.shape[0] * modality.X.shape[1]
        density = sum(np.count_nonzero(modality.X, axis=0)) / n_entries

        if density < 2/3:
            print('Sparsify modality', name)
            modality.X = sci.sparse.csr_matrix(modality.X)
        else:
            print('Modality', name, 'is stored dense. Density is ', density*100, ' %.')
           
###################################################################################################################
###################################################################################################################
###################################################################################################################
   
   
   
def sparsify_all_layers(adata):
    """
    Convert dense ``.X`` and dense layers of ``adata`` to CSR, in place.

    A dense matrix is only converted when less than 2/3 of its entries are
    non-zero; otherwise it is left dense and its density is reported.
    A status line is printed for ``.X`` and for every layer. Nothing is returned.

    adata: object with ``.X``, ``.shape`` and a ``layers`` mapping
    """
    # --- main matrix ---
    if sci.sparse.issparse(adata.X):
        print('.X already spase...')
    else:
        nonzero_x = sum(np.count_nonzero(adata.X, axis=0))
        density = nonzero_x / (adata.shape[0] * adata.X.shape[1])
        if density < 2/3:
            print('Sparsify .X...')
            adata.X = sci.sparse.csr_matrix(adata.X)
        else:
            print('.X is stored dense. Density is ', density*100, ' %.')

    # --- layers (iterate over a snapshot of the keys, values are replaced below) ---
    for layer_name in list(adata.layers):
        matrix = adata.layers[layer_name]
        if sci.sparse.issparse(matrix):
            print('Layer', layer_name, 'already sparse...')
            continue

        nonzero_layer = sum(np.count_nonzero(matrix, axis=0))
        density = nonzero_layer / (adata.shape[0] * adata.X.shape[1])
        if density < 2/3:
            print('Sparsify ', layer_name)
            adata.layers[layer_name] = sci.sparse.csr_matrix(matrix)
        else:
            print('Layer', layer_name,' is stored dense. Density is ', density*100, ' %.')
           
           
###################################################################################################################
###################################################################################################################
###################################################################################################################
   

def sparsify_all_layers_mdata(mdata, modalities='all'):
    """
    Loop through the selected modalities and make dense ``.X`` and dense layers sparse.

    A dense matrix is converted to CSR, in place, only when less than 2/3 of
    its entries are non-zero; otherwise it is left dense and its density is
    reported. A status line is printed for ``.X`` and for every layer of every
    modality. Nothing is returned.

    mdata: object with a ``mod`` mapping (modality name -> object with ``.X`` and ``.layers``)
    modalities: 'all' or list of modalities, e.g. ['rna','atac']
    """
    if modalities=='all':
        modalities = mdata.mod.keys()

    for mod in modalities:
        adata = mdata.mod[mod]

        if not sci.sparse.issparse(adata.X):
            density = sum(np.count_nonzero(adata.X, axis=0))/(adata.X.shape[0]*adata.X.shape[1])
            if density < 2/3:
                print('Sparsify .X in modality', mod)
                adata.X = sci.sparse.csr_matrix(adata.X)
            else:
                print('Modality', mod, 'is stored dense. Density is ', density*100, ' %.')
        else:
            print('.X in modality', mod, 'already sparse...')

        # Iterate over a snapshot of the keys because values are replaced below
        for layer in list(adata.layers):
            matrix = adata.layers[layer]
            if not sci.sparse.issparse(matrix):
                density = sum(np.count_nonzero(matrix, axis=0))/(matrix.shape[0]*matrix.shape[1])
                if density < 2/3:
                    print('Sparsify ', layer, ' in modality', mod)
                    adata.layers[layer] = sci.sparse.csr_matrix(matrix)
                else:
                    # Dense layer that is too dense to be worth converting
                    print('Layer', layer, ' in modality', mod, ' is stored dense. Density is ', density*100, ' %.')
            else:
                # No density is computed for layers that are already sparse
                print('Layer', layer, ' in modality', mod, 'already sparse...')

## Doublet detection

#### def threshold

In [None]:
def threshold(
    clf,
    show=False,
    save=None,
    log10=True,
    log_p_grid=None,
    voter_grid=None,
    v_step=2,
    p_step=5,
):
    """Produce a plot showing number of cells called doublet across
       various thresholds
    Args:
        clf (BoostClassifier object): Fitted classifier; only its
            ``all_log_p_values_`` attribute is read (natural-log p-values,
            presumably shaped iterations x cells - the votes are averaged
            over axis 0).
        show (bool, optional): If True, runs plt.show()
        save (str, optional): If provided, the figure is saved to this
            filepath (as pdf).
        log10 (bool, optional): Use log 10 if true, natural log if false.
        log_p_grid (ndarray, optional): log p-value thresholds to use.
            Defaults to np.arange(-100, -1). log base decided by log10
        voter_grid (ndarray, optional): Voting thresholds to use. Defaults to
            np.arange(0.3, 1.0, 0.05).
        p_step (int, optional): number of xlabels to skip in plot
        v_step (int, optional): number of ylabels to skip in plot
    Returns:
        matplotlib figure
    """
    import warnings
    # Ignore numpy complaining about np.nan comparisons
    with np.errstate(invalid="ignore"):
        all_log_p_values_ = np.copy(clf.all_log_p_values_)
        if log10:
            # convert natural log to log10
            all_log_p_values_ /= np.log(10)
        if log_p_grid is None:
            log_p_grid = np.arange(-100, -1)
        if voter_grid is None:
            voter_grid = np.arange(0.3, 1.0, 0.05)
        doubs_per_t = np.zeros((len(voter_grid), len(log_p_grid)))

        # Masking does not depend on any threshold, so do it only once
        masked_log_p = np.ma.masked_invalid(all_log_p_values_)
        for j, log_p_cutoff in enumerate(log_p_grid):
            # Fraction of valid entries along axis 0 that pass the p-value cutoff;
            # independent of the voting threshold, hence computed outside the inner loop
            voting_average = np.mean(masked_log_p <= log_p_cutoff, axis=0)
            for i, vote_cutoff in enumerate(voter_grid):
                # Fully masked entries become NaN and are ignored by nansum
                labels = np.ma.filled((voting_average >= vote_cutoff).astype(float), np.nan)
                doubs_per_t[i, j] = np.nansum(labels)

    # Ignore warning for convergence plot
    with warnings.catch_warnings():
        warnings.filterwarnings(action="ignore", module="matplotlib", message="^tight_layout")

        f, ax = plt.subplots(1, 1, figsize=(4, 4), dpi=150)
        cax = ax.imshow(doubs_per_t, cmap="turbo", aspect="auto")
        ax.set_xticks(np.arange(len(log_p_grid))[::p_step])
        ax.set_xticklabels(np.around(log_p_grid, 1)[::p_step], rotation="vertical")
        ax.set_yticks(np.arange(len(voter_grid))[::v_step])
        ax.set_yticklabels(np.around(voter_grid, 2)[::v_step])
        cbar = f.colorbar(cax)
        cbar.set_label("Predicted Doublets")
        # Same truthiness test as used for the scaling above, so label and data agree
        if log10:
            ax.set_xlabel("Log10 p-value")
        else:
            ax.set_xlabel("Log p-value")
        ax.set_ylabel("Voting Threshold")
        ax.set_title("Threshold Diagnostics")

    if show is True:
        plt.show()
    if save:
        f.savefig(save, format="pdf", bbox_inches="tight")

    return f


#### def umap_plot

In [6]:
def umap_plot(
    raw_counts,
    labels,
    n_components=30,
    show=False,
    save=None,
    normalizer= None,
    random_state=None,
):
    """Produce a umap plot of the data with doublets in black.

        Count matrix is normalized and dimension reduced before plotting.

    Args:
        raw_counts (array-like): Count matrix, oriented cells by genes.
            Sparse input is densified (with a warning).
        labels (ndarray): predicted doublets from predict method; entries
            equal to 1 are drawn as doublets.
        n_components (int, optional): number of PCs to use prior to UMAP
        show (bool, optional): If True, runs plt.show()
        save (str, optional): filename for saved figure (written as pdf),
            figure not saved by default
        normalizer ((ndarray) -> ndarray, optional): Method to normalize
            raw_counts. Defaults to doubletdetection.plot.normalize_counts.
            Note: To use normalize_counts with its pseudocount parameter changed
            from the default 0.1 value to some positive float `new_var`, use:
            normalizer=lambda counts: doubletdetection.normalize_counts(counts,
            pseudocount=new_var)
        random_state (int, optional): If provided, passed to PCA and UMAP

    Returns:
        matplotlib figure
        ndarray: umap reduction
    """
    import warnings

    import doubletdetection
    import numpy as np
    import umap
    from sklearn.decomposition import PCA
    from sklearn.utils import check_array
    try:
        # Finite values are enforced by check_array's default. The explicit
        # keyword is not passed because it was renamed in newer scikit-learn
        # releases, and the resulting TypeError would be mistaken for the
        # sparse-input error handled below.
        raw_counts = check_array(raw_counts, accept_sparse=False, ensure_2d=True)
    except TypeError:  # Only catches sparse error. Non-finite & n_dims still raised.
        warnings.warn("Sparse raw_counts is automatically densified.")
        raw_counts = raw_counts.toarray()

    # Honor a user supplied normalizer; fall back to the package default
    if normalizer is None:
        normalizer = doubletdetection.plot.normalize_counts
    norm_counts = normalizer(raw_counts)

    reduced_counts = PCA(
        n_components=n_components, svd_solver="randomized", random_state=random_state
    ).fit_transform(norm_counts)
    umap_dr = umap.UMAP(random_state=random_state, min_dist=0.5).fit_transform(
        reduced_counts
    )
    # Ensure only looking at positively identified doublets
    doublets = labels == 1

    fig, axes = plt.subplots(1, 1, figsize=(4, 4), dpi=150)
    # All cells in grey first (no colormap: a fixed colour is used) ...
    axes.scatter(
        umap_dr[:, 0],
        umap_dr[:, 1],
        c="grey",
        s=1,
        label="predicted singlets",
    )
    # ... then the predicted doublets on top in black
    axes.scatter(
        umap_dr[:, 0][doublets],
        umap_dr[:, 1][doublets],
        s=3,
        c="black",
        label="predicted doublets",
    )
    axes.axis("off")
    axes.legend(frameon=False)
    axes.set_title(
        "{} doublets out of {} cells\n {}% cross-type doublet rate".format(
            np.sum(doublets),
            raw_counts.shape[0],
            np.round(100 * np.sum(doublets) / raw_counts.shape[0], 2),
        )
    )

    if show is True:
        plt.show()
    if isinstance(save, str):
        fig.savefig(save, format="pdf", bbox_inches="tight")

    return fig, umap_dr

#### def run_scDblFinder

In [None]:
def run_scDblFinder(adata, force_reload=False, layer=None, n_core=20, max_memory_gb=64):
    '''
    Detect doublets with scDblFinder (R, through rpy2) and annotate adata.obs in place.

    scDblFinder is run twice on a Seurat object built from the counts: once
    after standard log-normalization and once after SCTransform. A cell is
    flagged in 'sdf_doublets' if either run classifies it as 'doublet'.

    adata: adata object holding the counts; adata.obs is modified in place
    force_reload: Force transfer of count data to R. Without it, a matrix named
        'data_mat' that already exists in the R workspace is reused whenever
        its shape matches (the content is not compared).
    layer: layer to use as counts. Default = None -> use .X
    n_core: number of workers registered for BiocParallel, future and doParallel
    max_memory_gb: value (in GB) used for future.globals.maxSize

    Adds the columns 'scDblFinder.class', 'scDblFinder.score',
    'scDblFinder.class.sct', 'scDblFinder.score.sct' and 'sdf_doublets' to
    adata.obs. Returns None.
    '''
    
    import rpy2.robjects as ro

       
    print('Finding doublets with scDblFinder:')
    # load packages
    ro.globalenv['n_core'] = n_core
    ro.globalenv['max_memory'] = max_memory_gb #/n_core
    ro.r('''
    print(paste0("Cores: ", n_core))
    print(paste0("Memory: ", max_memory))
    ''')
    ro.r('''

    # Analysis
    library(Seurat)
    library(sctransform)
    library(scDblFinder)
    library(SingleCellExperiment)
    library(scater)
    library(pastecs)

    # Parallelization
    library(BiocParallel)
    register(MulticoreParam(n_core, progressbar = TRUE))

    library(future)
    plan(multicore, workers = n_core)
    options(future.globals.maxSize = max_memory * 1024^3) # for 50 Gb RAM
    plan()

    library(doParallel)
    registerDoParallel(n_core)
    ''')
    # transfer data
    print('\tTransfer data...')
    
    # Reuse a matrix that is already in the R workspace if it has the same shape
    load_data = True
    if ro.r('''exists('data_mat')''')[0] == 1:
        if ro.globalenv['data_mat'].shape == adata.X.T.shape:
            load_data = False
            print('\t\tFound data matrix of same shape. Skipping data transfer...')
        
    if force_reload:
        load_data = True
    
    if load_data:
        # R expects features x cells, hence the transpose
        if layer is None:
            ro.globalenv['data_mat'] = adata.X.T#.toarray()
        else:
            print('\tUsing layer \'', layer,'\'...')
            ro.globalenv['data_mat'] = adata.layers[layer].T#.toarray()
        ro.globalenv['obs_names'] = adata.obs_names
        ro.globalenv['var_names'] = adata.var_names
        
        ro.r('''
        rownames(data_mat) <- var_names
        colnames(data_mat) <- obs_names
        ''') 
    # standart preprocessing
    ## create Seurat object    
    ro.r('''
    seurat <- CreateSeuratObject(counts = data_mat, project = "0", min.cells = 0, min.features = 0)
    ''')   
    ## preprocessing
    print('\tDoublet Detection with standard normalization...')
    print('\t\tPreprocessing...')
    ro.r('''
    seurat <- NormalizeData(seurat, verbose = FALSE)
    seurat <- FindVariableFeatures(seurat, selection.method = "vst", nfeatures = 5000, verbose = FALSE)
    seurat <- ScaleData(seurat, verbose = FALSE)
    seurat <- RunPCA(seurat, npcs = 50, verbose = FALSE)
    seurat <- RunUMAP(seurat, reduction = "pca", dims = 1:50)
    seurat <- FindNeighbors(seurat, dims = 1:50, verbose = FALSE)
    seurat <- FindClusters(seurat, verbose = FALSE, resolution = 0.5)
    #print(DimPlot(seurat, label = TRUE))
    
    #Conversion to SingleCellExperiment
    sce <- as.SingleCellExperiment(seurat)
    ''')
    ## run scDblFinder
    print('\t\tRunning scDblFinder...')
    ro.r('''
    #scDblFinder
    colData(sce)$scoresDoubletDensity <- computeDoubletDensity(sce)
    sce <- scDblFinder(sce, clusters = FALSE) #, dbr=0.1)
    ''')
    
    
    ## get results   
    print('\t\tCollect results...')
    ro.r('''
    results <- colData(sce)[,c("scDblFinder.class", "scDblFinder.score")]    
    ''')
    
    # sct preprocessing
    ## create Seurat object    
    ro.r('''
    seurat <- CreateSeuratObject(counts = data_mat, project = "0", min.cells = 0, min.features = 0)
    ''')   
    ## preprocessing
    print('\tDoublet Detection with SCT normalization...')
    print('\t\tPreprocessing...')
    ro.r('''
    seurat <- SCTransform(seurat, verbose = FALSE)
    seurat <- RunPCA(seurat, npcs = 50, verbose = FALSE)
    seurat <- RunUMAP(seurat, reduction = "pca", dims = 1:50)
    seurat <- FindNeighbors(seurat, dims = 1:50, verbose = FALSE)
    seurat <- FindClusters(seurat, verbose = FALSE, resolution = 0.5)
    #print(DimPlot(seurat, label = TRUE))
    
    #Conversion to SingleCellExperiment
    sce <- as.SingleCellExperiment(seurat)
    ''')
    ## run scDblFinder
    print('\t\tRunning scDblFinder...')
    ro.r('''
    #scDblFinder
    colData(sce)$scoresDoubletDensity <- computeDoubletDensity(sce)
    sce <- scDblFinder(sce, clusters = FALSE) #, dbr=0.1)
    ''')
    
    
    ## get results   
    print('\t\tCollect results...')
    ro.r('''
    results <- cbind(results, colData(sce)[,c("scDblFinder.class", "scDblFinder.score")])
    colnames(results) <- c("scDblFinder.class", "scDblFinder.score", "scDblFinder.class.sct", "scDblFinder.score.sct")
    ''')
    print('\t\tAdd results to anndata...')
    results = ro.globalenv['results']
    
    # Drop the results of a previous run. Otherwise the merge below would suffix
    # the clashing columns with _x/_y and the lookups that follow would fail.
    result_columns = ["scDblFinder.class", "scDblFinder.score", "scDblFinder.class.sct", "scDblFinder.score.sct"]
    adata.obs = adata.obs.drop(columns=result_columns, errors='ignore')
    
    adata.obs = pd.merge(adata.obs,results, left_index=True, right_index=True)
    
    # A cell counts as doublet if either normalization strategy calls it one
    adata.obs['sdf_doublets'] = (
        (adata.obs['scDblFinder.class'] == 'doublet')
        | (adata.obs['scDblFinder.class.sct'] == 'doublet')
    )
    
    # Count explicitly instead of indexing value_counts(), which breaks when no
    # doublet was found; the rate is relative to all cells in adata.
    n_doublets = int(adata.obs['sdf_doublets'].sum())
    doublet_rate = n_doublets / adata.obs.shape[0] * 100
    
    print('\n\n------------------------------------------------------------------------------------\n------------------------------------------------------------------------------------' )
    print('\nscDblFinder doublet rate:', doublet_rate, '% (', n_doublets,' cells)' )
    
    #return adata


#### def run_DoubletFinder

In [None]:
def run_DoubletFinder(adata, force_reload=False, layer=None, n_core=20, max_memory_gb=64):
    '''
    adata: adata object to normalize
    layer: layer to use for normalization. Default = None -> use .X
    force_reload: Force transfer of count data to R
    '''
    
    import rpy2
    import rpy2.robjects as ro
    import gc

       
    print('Finding doublets with scDblFinder:')
    # load packages
    ro.globalenv['n_core'] = n_core
    ro.globalenv['max_memory'] = max_memory_gb #/n_core
    ro.r('''
    print(paste0("Cores: ", n_core))
    print(paste0("Memory: ", max_memory))
    ''')
    ro.r('''

    # Analysis
    library(Seurat)
    library(sctransform)
    library(DoubletFinder)
    library(SingleCellExperiment)
    library(scater)
    library(pastecs)

    # Parallelization
    library(BiocParallel)
    register(MulticoreParam(n_core, progressbar = TRUE))

    library(future)
    plan(multicore, workers = n_core)
    options(future.globals.maxSize = max_memory * 1024^3) # for 50 Gb RAM
    plan()

    library(doParallel)
    registerDoParallel(n_core)
    
    # Adaption of original funtion to omit plot (https://github.com/chris-mcginnis-ucsf/DoubletFinder/blob/5dfd96b06365d7843adf3a72ffb6a30f42c74a01/R/find.pK.R)
    find.pK.noPlot <- function(sweep.stats) {

      ## Implementation for data without ground-truth doublet classifications 
      '%ni%' <- Negate('%in%')
      if ("AUC" %ni% colnames(sweep.stats) == TRUE) {
        ## Initialize data structure for results storage
        bc.mvn <- as.data.frame(matrix(0L, nrow=length(unique(sweep.stats$pK)), ncol=5))
        colnames(bc.mvn) <- c("ParamID","pK","MeanBC","VarBC","BCmetric")
        bc.mvn$pK <- unique(sweep.stats$pK)
        bc.mvn$ParamID <- 1:nrow(bc.mvn)

        ## Compute bimodality coefficient mean, variance, and BCmvn across pN-pK sweep results
        x <- 0
        for (i in unique(bc.mvn$pK)) {
          x <- x + 1
          ind <- which(sweep.stats$pK == i)
          bc.mvn$MeanBC[x] <- mean(sweep.stats[ind, "BCreal"])
          bc.mvn$VarBC[x] <- sd(sweep.stats[ind, "BCreal"])^2
          bc.mvn$BCmetric[x] <- mean(sweep.stats[ind, "BCreal"])/(sd(sweep.stats[ind, "BCreal"])^2)
        }

        return(bc.mvn)

      }

      ## Implementation for data with ground-truth doublet classifications (e.g., MULTI-seq, CellHashing, Demuxlet, etc.)
      if ("AUC" %in% colnames(sweep.stats) == TRUE) {
        ## Initialize data structure for results storage
        bc.mvn <- as.data.frame(matrix(0L, nrow=length(unique(sweep.stats$pK)), ncol=6))
        colnames(bc.mvn) <- c("ParamID","pK","MeanAUC","MeanBC","VarBC","BCmetric")
        bc.mvn$pK <- unique(sweep.stats$pK)
        bc.mvn$ParamID <- 1:nrow(bc.mvn)

        ## Compute bimodality coefficient mean, variance, and BCmvn across pN-pK sweep results
        x <- 0
        for (i in unique(bc.mvn$pK)) {
          x <- x + 1
          ind <- which(sweep.stats$pK == i)
          bc.mvn$MeanAUC[x] <- mean(sweep.stats[ind, "AUC"])
          bc.mvn$MeanBC[x] <- mean(sweep.stats[ind, "BCreal"])
          bc.mvn$VarBC[x] <- sd(sweep.stats[ind, "BCreal"])^2
          bc.mvn$BCmetric[x] <- mean(sweep.stats[ind, "BCreal"])/(sd(sweep.stats[ind, "BCreal"])^2)
        }

        return(bc.mvn)

      }
    }
    
    ''')
    # transfer data
    print('\tTransfer data...')
    
    # check if data is in R workspace
    if ro.r('''exists('data_mat')''')[0] == 1:
        # check if data has same shape
        if ro.globalenv['data_mat'].shape == adata.X.T.shape:
            load_data = False
            print('\t\tFound data matrix of same shape. Skipping data transfer...')
        else:
            load_data = True
    else:
        load_data = True
        
    if force_reload:
        load_data = True
    
    if load_data:
        if layer is None:
            ro.globalenv['data_mat'] = adata.X.T#.toarray()
            ro.globalenv['obs_names'] = adata.obs_names
            ro.globalenv['var_names'] = adata.var_names
        else:
            print('\tUsing layer \'', layer,'\'...')
            ro.globalenv['data_mat'] = adata.layers[layer].T#.toarray()
            ro.globalenv['obs_names'] = adata.obs_names
            ro.globalenv['var_names'] = adata.var_names
        
        ro.r('''
        rownames(data_mat) <- var_names
        colnames(data_mat) <- obs_names
        ''') 
    # standart preprocessing
    ## create Seurat object    
    ro.r('''
    seurat <- CreateSeuratObject(counts = data_mat, project = "0", min.cells = 0, min.features = 0)
    ''')   
    ## preprocessing
    print('\tDoublet Detection with standard normalization...')
    print('\t\tPreprocessing...')
    ro.r('''
    seurat <- NormalizeData(seurat, verbose = FALSE)
    seurat <- FindVariableFeatures(seurat, selection.method = "vst", nfeatures = 5000, verbose = FALSE)
    seurat <- ScaleData(seurat, verbose = FALSE)
    seurat <- RunPCA(seurat, npcs = 50, verbose = FALSE)
    seurat <- RunUMAP(seurat, reduction = "pca", dims = 1:50)
    seurat <- FindNeighbors(seurat, dims = 1:50, verbose = FALSE)
    seurat <- FindClusters(seurat, verbose = FALSE, resolution = 0.5)
    #print(DimPlot(seurat, label = TRUE))
    ''')
    ## run DoubletFinder
    print('\t\tRunning DoubletFinder...')
    ro.r('''
    ## pK Identification (no ground-truth) ---------------------------------------------------------------------------------------
    sweep.res.list <- paramSweep(seurat, PCs = 1:50, num.cores = n_core, sct = FALSE)
    sweep.stats <- summarizeSweep(sweep.res.list, GT = FALSE)
    bcmvn <- find.pK.noPlot(sweep.stats)
    
    ## Homotypic Doublet Proportion Estimate -------------------------------------------------------------------------------------
    homotypic.prop <- modelHomotypic(seurat@meta.data$seurat_clusters)           ## ex: annotations <- seurat.list[[1]]@meta.data$ClusteringResults
    nExp_poi <- round(0.1*length(seurat@meta.data$seurat_clusters))  # I guess that doublet formation rate is higher than the ~7.5% estimated from 10x if doublets are present in input cell suspension -> set to 10%  ## Assuming 7.5% doublet formation rate - tailor for your dataset
    nExp_poi.adj <- round(nExp_poi*(1-homotypic.prop))
    
    ## Run DoubletFinder with varying classification stringencies ----------------------------------------------------------------
    seurat <- doubletFinder(seurat, 
                                              PCs = 1:50, 
                                              pN = 0.25, 
                                              pK = as.numeric(as.character(bcmvn$pK[which.max(bcmvn$BCmetric)])), 
                                              nExp = nExp_poi, 
                                              reuse.pANN = FALSE, 
                                              sct = FALSE)
    
    seurat <- doubletFinder(seurat, 
                                              PCs = 1:50, 
                                              pN = 0.25, 
                                              pK = as.numeric(as.character(bcmvn$pK[which.max(bcmvn$BCmetric)])), 
                                              nExp = nExp_poi.adj, 
                                              reuse.pANN = paste0("pANN_0.25_",as.character(bcmvn$pK[which.max(bcmvn$BCmetric)]),"_",nExp_poi), 
                                              sct = FALSE)
    ''')
       
    
    ## get results   
    print('\t\tCollect results...')
    ro.r('''
    results <- seurat@meta.data[,6:8]
    colnames(results) <- c("pANN","DF_classifications_1","DF_classifications_2")
    ''')
    
    # sct preprocessing
    ## create Seurat object    
    ro.r('''
    seurat <- CreateSeuratObject(counts = data_mat, project = "0", min.cells = 0, min.features = 0)
    ''')   
    ## preprocessing
    print('\tDoublet Detection with SCT normalization...')
    print('\t\tPreprocessing...')
    ro.r('''
    seurat <- SCTransform(seurat, verbose = FALSE)
    seurat <- RunPCA(seurat, npcs = 50, verbose = FALSE)
    seurat <- RunUMAP(seurat, reduction = "pca", dims = 1:50)
    seurat <- FindNeighbors(seurat, dims = 1:50, verbose = FALSE)
    seurat <- FindClusters(seurat, verbose = FALSE, resolution = 0.5)
    #print(DimPlot(seurat, label = TRUE))
    ''')
    ## run DoubletFinder
    print('\t\tRunning DoubletFinder...')
    ro.r('''
    ## pK Identification (no ground-truth) ---------------------------------------------------------------------------------------
    sweep.res.list <- paramSweep(seurat, PCs = 1:50, num.cores = n_core, sct = TRUE)
    sweep.stats <- summarizeSweep(sweep.res.list, GT = FALSE)
    bcmvn <- find.pK.noPlot(sweep.stats)
    
    ## Homotypic Doublet Proportion Estimate -------------------------------------------------------------------------------------
    homotypic.prop <- modelHomotypic(seurat@meta.data$seurat_clusters)           ## ex: annotations <- seurat.list[[1]]@meta.data$ClusteringResults
    nExp_poi <- round(0.1*length(seurat@meta.data$seurat_clusters))  # I guess that doublet formation rate is higher than the ~7.5% estimated from 10x if doublets are present in input cell suspension -> set to 10%  ## Assuming 7.5% doublet formation rate - tailor for your dataset
    nExp_poi.adj <- round(nExp_poi*(1-homotypic.prop))
    
    ## Run DoubletFinder with varying classification stringencies ----------------------------------------------------------------
    seurat <- doubletFinder(seurat, 
                                              PCs = 1:50, 
                                              pN = 0.25, 
                                              pK = as.numeric(as.character(bcmvn$pK[which.max(bcmvn$BCmetric)])), 
                                              nExp = nExp_poi, 
                                              reuse.pANN = FALSE, 
                                              sct = TRUE)
    
    seurat <- doubletFinder(seurat, 
                                              PCs = 1:50, 
                                              pN = 0.25, 
                                              pK = as.numeric(as.character(bcmvn$pK[which.max(bcmvn$BCmetric)])), 
                                              nExp = nExp_poi.adj, 
                                              reuse.pANN = paste0("pANN_0.25_",as.character(bcmvn$pK[which.max(bcmvn$BCmetric)]),"_",nExp_poi), 
                                              sct = TRUE)
    ''')
    
    
    ## get results   
    print('\t\tCollect results...')
    ro.r('''
    results <- cbind(results, seurat@meta.data[,8:10])
    colnames(results) <- c("pANN","DF_classifications_1","DF_classifications_2", "pANN.sct","DF_classifications_1.sct","DF_classifications_2.sct")
    ''')
    print('\t\tAdd results to anndata...')
    results = ro.globalenv['results']
    with (ro.default_converter + pandas2ri.converter).context():
        results = ro.conversion.get_conversion().rpy2py(results) 
    
    # check if results are already present in adata.obs and delete
    if 'pANN' in adata.obs.columns:
        del adata.obs[["pANN","DF_classifications_1","DF_classifications_2", "pANN.sct","DF_classifications_1.sct","DF_classifications_2.sct"]]
    
    adata.obs = pd.merge(adata.obs,results, left_index=True, right_index=True) 
    
    adata.obs.loc[:,'df_doublets'] = False
    adata.obs.loc[adata.obs.loc[:,'DF_classifications_1']=='Doublet','df_doublets'] = True
    adata.obs.loc[adata.obs.loc[:,'DF_classifications_1.sct']=='Doublet','df_doublets'] = True
    
    print('\n\n------------------------------------------------------------------------------------\n------------------------------------------------------------------------------------' )
    print('\nDoubletFinder doublet rate: ', adata.obs['df_doublets'].value_counts()[1]/adata.obs['sample'].value_counts()[0]*100, '% (',adata.obs['df_doublets'].value_counts()[1],' cells)' )
    
    #return adata   

#### def run_scDblFinder_ATAC

In [None]:
def run_scDblFinder_ATAC(adata, repeats_file='/mnt/ssd/Genomes/mm10/Repeats/AMULET_Exclusion_List_Regions/AMULET_exclusion_regions_noChr.bed', nfeatures=25, dbr=0.1, force_reload=False, layer=None, n_core=20, max_memory_gb=64):
    '''
    Score scATAC-seq doublets in R with scDblFinder and AMULET and merge the scores into adata.obs.

    The count matrix is pushed to the R global environment as `data_mat`
    (features x cells). scDblFinder runs on aggregated features, AMULET runs on
    the fragments file, both p-values are combined with `aggregation::fisher`
    and thresholded with `doubletThresholding`. The resulting columns
    (`amulet.*`, `atac.scDblFinder.p`, `atac.combined`, `atac.combined.score`,
    `atac.combined.class`) are merged into `adata.obs` in place.
    Nothing is returned.

    adata: AnnData object; `adata.uns['files']['fragments']` must hold the fragments file path
    repeats_file: Path to BED file with repeats and other exclusion regions e.g. '/mnt/ssd/Genomes/mm10/Repeats/AMULET_Exclusion_List_Regions/AMULET_exclusion_regions_noChr.bed'
    nfeatures: number of aggregated features for scDblFinder; raised by 5 and retried whenever scDblFinder errors
    dbr: doublet rate passed to scDblFinder and doubletThresholding
    force_reload: Force transfer of count data to R
    layer: layer holding the counts to transfer. Default = None -> use .X
    n_core: number of workers registered with BiocParallel, future and doParallel
    max_memory_gb: size limit in GB used for `future.globals.maxSize`
    '''

    import rpy2.robjects as ro
    # Imported locally so the result conversion below does not depend on a
    # module-level import made in another notebook cell.
    from rpy2.robjects import pandas2ri

    print('\nFinding scATAC-seq doublets with scDblFinder:')
    # load packages
    ro.globalenv['dbr'] = dbr
    ro.globalenv['nfeatures'] = nfeatures
    ro.globalenv['repeats_file'] = repeats_file
    ro.globalenv['n_core'] = n_core
    ro.globalenv['max_memory'] = max_memory_gb #/n_core
    ro.r('''
    print(paste0("Cores: ", n_core))
    print(paste0("Memory: ", max_memory))
    ''')
    ro.r('''

    # Analysis
    library(Seurat)
    library(sctransform)
    library(scDblFinder)
    library(SingleCellExperiment)
    library(scater)
    library(pastecs)
    library(GenomicRanges)

    # Parallelization
    library(BiocParallel)
    register(MulticoreParam(n_core, progressbar = TRUE))

    library(future)
    plan(multicore, workers = n_core)
    options(future.globals.maxSize = max_memory * 1024^3) # for 50 Gb RAM
    plan()

    library(doParallel)
    registerDoParallel(n_core)
    ''')
    # transfer data
    print('\tTransfer data...')

    # Reuse a matrix that is already in the R workspace when it has the same
    # shape. The check is skipped entirely when a reload is forced, so no
    # misleading "skipping" message is printed in that case.
    # NOTE(review): only the shape is compared, not the content or names.
    load_data = True
    if not force_reload and ro.r('''exists('data_mat')''')[0] == 1:
        if ro.globalenv['data_mat'].shape == adata.X.T.shape:
            load_data = False
            print('\t\tFound data matrix of same shape. Skipping data transfer...')

    if load_data:
        if layer is None:
            counts = adata.X
        else:
            print('\tUsing layer \'', layer,'\'...')
            counts = adata.layers[layer]

        # R expects features in rows and cells in columns.
        ro.globalenv['data_mat'] = counts.T#.toarray()
        ro.globalenv['obs_names'] = adata.obs_names
        ro.globalenv['var_names'] = adata.var_names

        ro.r('''
        rownames(data_mat) <- var_names
        colnames(data_mat) <- obs_names
        ''')

    # prepare exclusion list
    ro.r('''
    repeats <- read.delim(repeats_file, header=FALSE)
    repeats <- makeGRangesFromDataFrame(repeats, seqnames.field = "V1", start.field = "V2", end.field = "V3")
    #repeats <- GRanges("6", IRanges(1000,2000))
    
    otherChroms <- GRanges(c("M","chrM","MT","X","Y","chrX","chrY"),IRanges(1L,width=10^8))
    
    toExclude <- suppressWarnings(c(repeats, otherChroms))
    ''')

    # get fragments file path
    ro.globalenv['fragments'] = adata.uns['files']['fragments']

    # create SingleCellExperiment object
    ro.r('''
    sce <- SingleCellExperiment(assays=list(counts=data_mat))
    ''')

    ## run scDblFinder
    # NOTE(review): the R retry loop below has no upper bound; it keeps adding
    # 5 features for as long as scDblFinder raises an error.
    print('\tRunning scDblFinder...')
    ro.r('''
    #scDblFinder
    colData(sce)$scoresDoubletDensity <- computeDoubletDensity(sce)

    error <- 1
    while(error == 1){
        catch <- tryCatch(sce <- scDblFinder(sce, aggregateFeatures=TRUE, nfeatures=nfeatures, processing="normFeatures", dbr=dbr),
                error=function(error){
                warning(error)
                return(list(sce, catch = "FAILED"))
                })
        if (is.null(catch[["catch"]])) {
          error <- 0
          print(paste0("\t\tscDblFinder finished with ",nfeatures," features..."))
          rm(catch)
        } else if (catch[["catch"]] == "FAILED"){
            print(paste0("\t\tscDblFinder failed with ",nfeatures," features!"))
            nfeatures <- nfeatures + 5
        }
    }
    gc(verbose = TRUE, reset = FALSE, full = TRUE)
    ''')

    ## run AMULET
    print('\tRunning AMULET...')
    ro.r('''
    #AMULET
    results <- amulet(fragments, regionsToExclude=toExclude, fullInMemory=TRUE)#, BPPARAM=MulticoreParam(n_core))
    colnames(results) <- paste0('amulet.', colnames(results))
    gc(verbose = TRUE, reset = FALSE, full = TRUE)
    ''')

    ## get results
    print('\tCollect results...')
    ro.r('''
    results$atac.scDblFinder.p <- 1-colData(sce)[row.names(results), "scDblFinder.score"]
    results$atac.combined <- apply(results[,c("atac.scDblFinder.p", "amulet.p.value")], 1, FUN=function(x){
      x[x<0.001] <- 0.001 # prevent too much skew from very small or 0 p-values
      suppressWarnings(aggregation::fisher(x))
    })
    results$atac.combined.score <- -results$atac.combined + 1
    #results$atac.combined.score <- -results$amulet.p.value + 1
    
    results$atac.combined.class <- doubletThresholding(data.frame('score'=results$atac.combined.score), dbr=dbr)
    gc(verbose = TRUE, reset = FALSE, full = TRUE)
    ''')

    print('\tAdd results to anndata...')
    results = ro.globalenv['results']
    with (ro.default_converter + pandas2ri.converter).context():
        results = ro.conversion.get_conversion().rpy2py(results)

    # Drop every column left over from a previous run before merging.
    # `del df[<Index or list>]` raises in pandas, and leaving stale columns in
    # place would make the merge emit `_x`/`_y` duplicates.
    stale_columns = [col for col in results.columns if col in adata.obs.columns]

    # NOTE(review): this is an inner join on the barcode index; presumably every
    # cell in adata has a row in the AMULET output -- confirm, since assigning a
    # shorter table to adata.obs is rejected by AnnData.
    adata.obs = pd.merge(adata.obs.drop(columns=stale_columns), results, left_index=True, right_index=True)

#### def run_SCDS

In [None]:
def run_SCDS(adata, force_reload=False, layer=None, n_core=20, max_memory_gb=64):
    '''
    Call doublets in R with scds (cxds/bcds hybrid score) and merge the calls into adata.obs.

    The scoring is run twice, once on a Seurat object with standard
    log-normalization and once after SCTransform. For each run the hybrid
    score is cut at a turning point of its density estimate; cells at or above
    the cut-off are labelled "doublet". The columns `hybrid_class`,
    `hybrid_score`, `hybrid_class_sct` and `hybrid_score_sct` are merged into
    `adata.obs` in place, and `scds_doublets` is set to True for cells called
    a doublet by either run. Nothing is returned.

    adata: AnnData object holding the counts
    force_reload: Force transfer of count data to R
    layer: layer holding the counts to transfer. Default = None -> use .X
    n_core: number of workers registered with BiocParallel, future and doParallel
    max_memory_gb: size limit in GB used for `future.globals.maxSize`
    '''

    import rpy2.robjects as ro

    # R snippets that are executed identically for both normalization runs.
    create_seurat_r = '''
    seurat <- CreateSeuratObject(counts = data_mat, project = "0", min.cells = 0, min.features = 0)
    '''
    scds_scoring_r = '''
    # SCDS
    sce <- cxds(sce, retRes = TRUE)
    sce <- bcds(sce, retRes = TRUE, verb=TRUE)
    sce <- cxds_bcds_hybrid(sce)
    
    dens <- density(sce$hybrid_score)
    min_idx <- match(-1, extract(turnpoints(dens$y, calc.proba = TRUE)))
    cut_off <- dens$x[min_idx[length(min_idx)]]
    
    #print(ggplot(as.data.frame(colData(sce)), aes(x=hybrid_score)) + geom_density() + geom_vline(xintercept = cut_off, linetype=2))
    
    sce$hybrid_class <- "doublet"
    sce[,sce$hybrid_score < cut_off]$hybrid_class <- "singlet"
    '''

    # The banner used to name scDblFinder although this function runs scds.
    print('Finding doublets with SCDS:')
    # load packages
    ro.globalenv['n_core'] = n_core
    ro.globalenv['max_memory'] = max_memory_gb #/n_core
    ro.r('''
    print(paste0("Cores: ", n_core))
    print(paste0("Memory: ", max_memory))
    ''')
    ro.r('''

    # Analysis
    library(Seurat)
    library(sctransform)
    library(scds)
    library(SingleCellExperiment)
    library(scater)
    library(pastecs)

    # Parallelization
    library(BiocParallel)
    register(MulticoreParam(n_core, progressbar = TRUE))

    library(future)
    plan(multicore, workers = n_core)
    options(future.globals.maxSize = max_memory * 1024^3) # for 50 Gb RAM
    plan()

    library(doParallel)
    registerDoParallel(n_core)
    ''')
    # transfer data
    print('\tTransfer data...')

    # Reuse a matrix that is already in the R workspace when it has the same
    # shape; skip the check entirely when a reload is forced.
    # NOTE(review): only the shape is compared, not the content or names.
    load_data = True
    if not force_reload and ro.r('''exists('data_mat')''')[0] == 1:
        if ro.globalenv['data_mat'].shape == adata.X.T.shape:
            load_data = False
            print('\t\tFound data matrix of same shape. Skipping data transfer...')

    if load_data:
        if layer is None:
            counts = adata.X
        else:
            print('\tUsing layer \'', layer,'\'...')
            counts = adata.layers[layer]

        # R expects features in rows and cells in columns.
        ro.globalenv['data_mat'] = counts.T#.toarray()
        ro.globalenv['obs_names'] = adata.obs_names
        ro.globalenv['var_names'] = adata.var_names

        ro.r('''
        rownames(data_mat) <- var_names
        colnames(data_mat) <- obs_names
        ''')

    # standard preprocessing
    ## create Seurat object
    ro.r(create_seurat_r)
    ## preprocessing
    print('\tDoublet Detection with standard normalization...')
    print('\t\tPreprocessing...')
    ro.r('''
    seurat <- NormalizeData(seurat, verbose = FALSE)
    seurat <- FindVariableFeatures(seurat, selection.method = "vst", nfeatures = 5000, verbose = FALSE)
    seurat <- ScaleData(seurat, verbose = FALSE)
    seurat <- RunPCA(seurat, npcs = 50, verbose = FALSE)
    seurat <- RunUMAP(seurat, reduction = "pca", dims = 1:50)
    seurat <- FindNeighbors(seurat, dims = 1:50, verbose = FALSE)
    seurat <- FindClusters(seurat, verbose = FALSE, resolution = 0.5)
    #print(DimPlot(seurat, label = TRUE))
    
    #Conversion to SingleCellExperiment
    sce <- as.SingleCellExperiment(seurat)
    ''')
    ## run SCDS
    print('\t\tRunning SCDS...')
    ro.r(scds_scoring_r)

    ## get results
    print('\t\tCollect results...')
    ro.r('''
    results <- colData(sce)[,c("hybrid_class", "hybrid_score")]    
    ''')

    # sct preprocessing
    ## create Seurat object
    ro.r(create_seurat_r)
    ## preprocessing
    print('\tDoublet Detection with SCT normalization...')
    print('\t\tPreprocessing...')
    ro.r('''
    seurat <- SCTransform(seurat, verbose = TRUE)
    seurat <- RunPCA(seurat, npcs = 50, verbose = FALSE)
    seurat <- RunUMAP(seurat, reduction = "pca", dims = 1:50)
    seurat <- FindNeighbors(seurat, dims = 1:50, verbose = FALSE)
    seurat <- FindClusters(seurat, verbose = FALSE, resolution = 0.5)
    #print(DimPlot(seurat, label = TRUE))
    
    #Conversion to SingleCellExperiment
    sce <- as.SingleCellExperiment(seurat)
    ''')
    ## run SCDS
    print('\t\tRunning SCDS...')
    ro.r(scds_scoring_r)

    ## get results
    print('\t\tCollect results...')
    ro.r('''
    results <- cbind(results, colData(sce)[,c("hybrid_class", "hybrid_score")])
    colnames(results) <- c("hybrid_class", "hybrid_score", "hybrid_class_sct", "hybrid_score_sct")
    ''')
    print('\t\tAdd results to anndata...')
    # NOTE(review): no explicit conversion is applied here; presumably a
    # globally activated rpy2 converter turns the R table into a pandas
    # DataFrame -- confirm.
    results = ro.globalenv['results']

    # Drop columns left over from a previous run before merging.
    # `del df[[...]]` with a list key raises in pandas, so use drop().
    result_columns = ["hybrid_class", "hybrid_score", "hybrid_class_sct", "hybrid_score_sct"]
    stale_columns = [col for col in result_columns if col in adata.obs.columns]

    adata.obs = pd.merge(adata.obs.drop(columns=stale_columns), results, left_index=True, right_index=True)

    # A cell counts as a doublet when either normalization run calls it one.
    is_doublet = (adata.obs['hybrid_class'] == 'doublet') | (adata.obs['hybrid_class_sct'] == 'doublet')
    adata.obs['scds_doublets'] = is_doublet.to_numpy()

    # Count via sum() so that "no doublets found" yields 0 instead of failing
    # on a missing value_counts() entry, and divide by the total cell number.
    n_doublets = int(adata.obs['scds_doublets'].sum())
    doublet_rate = n_doublets / adata.n_obs * 100

    print('\n\n------------------------------------------------------------------------------------\n------------------------------------------------------------------------------------' )
    print('\nScds doublet rate:', doublet_rate, '% (', n_doublets, ' cells)' )
 

## cell cycle genes

In [None]:
def load_cell_cycle_genes(adata, genome='mus musculus'):
    '''
    Load the KEGG, Regev (Tirosh et al. 2016) and Macosko et al. 2015 cell cycle gene lists.

    adata: AnnData object. Used to look up the genome when genome='auto'
        (first entry of adata.var['genome']) and, for human and pig, to keep
        only genes that are present in adata.var_names.
    genome: 'homo sapiens', 'mus musculus', 'sus scrofa' or 'auto'. Case is
        ignored and spaces and underscores are interchangeable, so the
        default 'mus musculus' selects the mouse lists.

    Returns a tuple
        (all_cc_genes, s_genes_regev, g2m_genes_regev, cc_genes_regev,
         cc_genes_macosko, s_genes_macosko, g2m_genes_macosko,
         m_genes_macosko, mg1_genes_macosko, g1s_genes_macosko).
    For mouse all entries are lists of capitalized symbols (first letter upper
    case, rest lower case). For human and pig, s_genes_regev and
    g2m_genes_regev are slices of adata.var_names and the remaining entries
    are lists.

    Raises ValueError for an unsupported genome.
    '''
    # The first 43 entries of the Regev file are treated as S-phase genes,
    # the remainder as G2/M genes.
    n_s_genes_regev = 43
    macosko_phase_columns = ('S', 'G2.M', 'M', 'M.G1', 'IG1.S')

    if genome == 'auto':
        # e.g. 'Mus_musculus_xyz' -> 'Mus_musculus'
        genome = '_'.join(adata.var.loc[:, 'genome'].iloc[0].split('_')[0:2])

    print('Genome is', genome)

    # Normalize the spelling so that 'mus musculus', 'Mus_musculus' and
    # 'mus_musculus' all select the same branch.
    species = genome.strip().lower().replace(' ', '_')
    if species not in ('homo_sapiens', 'mus_musculus', 'sus_scrofa'):
        raise ValueError(
            "Unsupported genome '{}'; expected 'homo sapiens', 'mus musculus', "
            "'sus scrofa' or 'auto'.".format(genome)
        )

    ## KEGG cell cycle genes
    cc_kegg = pd.read_table('/mnt/hdd/data/KEGG_mmu_Cell_Cycle.txt').iloc[:, 0].tolist()

    ## Cell cycle genes Regev lab (Tirosh et al. 2016, DOI: 10.1126/science.aad0501)
    with open('/mnt/hdd/data/regev_cell_cycle_genes.txt') as handle:
        regev_genes = [line.strip() for line in handle]

    if species == 'mus_musculus':
        ## Cell cycle genes Macosko et al. 2015, https://doi.org/10.1016/j.cell.2015.05.002
        macosko_table = pd.read_table('/mnt/hdd/data/Macosko_cell_cycle_genes.txt', delimiter='\t')

        def select(genes):
            # Convert symbols to capitalized form; no filtering against adata.
            return [gene.lower().capitalize() for gene in genes]

        s_genes_regev = select(regev_genes[:n_s_genes_regev])
        g2m_genes_regev = select(regev_genes[n_s_genes_regev:])
    else:
        ## Cell cycle genes Macosko et al. 2015, https://doi.org/10.1016/j.cell.2015.05.002
        macosko_table = pd.read_table('/mnt/ssd/Resources/Macosko_cell_cycle_genes.txt', delimiter='\t')

        def select(genes):
            # Keep only symbols present in adata, in adata.var_names order.
            return list(adata.var_names[np.isin(adata.var_names, genes)])

        # These two stay slices of adata.var_names (not lists), as before.
        s_genes_regev = adata.var_names[np.isin(adata.var_names, regev_genes[:n_s_genes_regev])]
        g2m_genes_regev = adata.var_names[np.isin(adata.var_names, regev_genes[n_s_genes_regev:])]

    cc_genes_regev = select(regev_genes)

    (s_genes_macosko,
     g2m_genes_macosko,
     m_genes_macosko,
     mg1_genes_macosko,
     g1s_genes_macosko) = [select(list(macosko_table[column].dropna())) for column in macosko_phase_columns]

    cc_genes_macosko = s_genes_macosko + g2m_genes_macosko + m_genes_macosko + mg1_genes_macosko + g1s_genes_macosko

    ## Combine all
    all_cc_genes = list(set(cc_kegg + cc_genes_regev + cc_genes_macosko))

    return all_cc_genes, s_genes_regev, g2m_genes_regev, cc_genes_regev, cc_genes_macosko, s_genes_macosko, g2m_genes_macosko, m_genes_macosko, mg1_genes_macosko, g1s_genes_macosko


## Normalization

In [None]:
import rpy2
import rpy2.robjects as ro
import gc

def normalise_scran(adata, r = 0.5):
    '''
    Normalize adata.X with scran size factors and log-transform it (in place).

    A throw-away copy of adata is library-size normalized, log-transformed and
    clustered with Leiden; the clusters are passed to scran's
    calculateSumFactors together with the counts in adata.X. The size factors
    are stored in adata.obs['size_factors'] and plotted against
    adata.obs['n_counts'] and adata.obs['n_genes'].

    Layers written: 'raw_counts' and 'log_raw_counts' (only when missing) and
    'scran_counts' (log1p of the size-factor normalized counts, also left in
    adata.X). Nothing is returned.

    adata: AnnData object; adata.X is expected to hold raw counts
    r: resolution of the Leiden clustering used to group cells for scran
    '''
    print('Normalization with Scran:')
    print('\tPreprocess data...\n\t-----------------------------------\n ')
    # Coarse clustering on a copy; adata itself is not touched here.
    adata_pp = adata.copy()
    sc.pp.normalize_total(adata_pp)#, exclude_highly_expressed=True) #sc.pp.normalize_per_cell(adata_pp, counts_per_cell_after=1e6)
    sc.pp.log1p(adata_pp)
    sc.pp.pca(adata_pp)
    sc.pp.neighbors(adata_pp)
    sc.tl.leiden(adata_pp, key_added='groups', resolution=r)

    # scran estimates size factors from counts, so hand over the untouched
    # adata.X and take only the cluster labels from the preprocessed copy
    # (passing the log-normalized copy would distort the factors).
    ro.globalenv['data_mat'] = adata.X.T
    ro.globalenv['input_groups'] = adata_pp.obs['groups']

    # The copy is no longer needed; release it before the R call.
    del adata_pp

    print('\tCalculate size factors...')
    ro.r('library("scran")')
    # calculate size factors
    ro.r('''
    size_factors = calculateSumFactors(data_mat, clusters=input_groups, min.mean=0.1)
    ''')

    print('\tTransfer data...')
    # add to anndata.obs
    adata.obs['size_factors'] = ro.r['size_factors']

    print('\tPlot results...')
    # plot results
    rcParams['figure.figsize']=(5,5)
    sc.pl.scatter(adata, 'size_factors', 'n_counts')
    sc.pl.scatter(adata, 'size_factors', 'n_genes')

    sb.histplot(adata.obs['size_factors'], bins=100, kde=True)
    plt.show()

    print('\tAdd results to anndata...')
    #Keep the count data in a counts layer
    if 'raw_counts' not in adata.layers.keys():
        adata.layers['raw_counts'] = adata.X.copy()

    #Logarithmize raw counts
    if 'log_raw_counts' not in adata.layers.keys():
        adata.layers['log_raw_counts'] = sc.pp.log1p(adata.layers['raw_counts'], copy=True)

    #Normalize adata
    # Divide every cell (row) by its size factor. An in-place `/=` fails for
    # sparse or integer matrices, so build the scaled matrix explicitly and
    # keep a floating dtype.
    counts = adata.X
    float_dtype = counts.dtype if np.issubdtype(counts.dtype, np.floating) else np.float64
    size_factors = adata.obs['size_factors'].to_numpy(dtype=float_dtype)
    if sparse.issparse(counts):
        row_scaling = sparse.diags(1.0 / size_factors)
        adata.X = (row_scaling @ counts).asformat(counts.format)
    else:
        adata.X = counts / size_factors[:, None]
    sc.pp.log1p(adata)

    #Keep the normalized count data in a counts layer
    adata.layers['scran_counts'] = adata.X.copy()

    # delete
    print('\tClean up...')
    gc.collect()

In [None]:
def normalize_sct(adata, layer=None, results_to_X=None, min_cells=None, n_core=20, max_memory=64):
    '''
    Normalize counts with Seurat's SCTransform (vst.flavor "v2") and return an AnnData with the results.

    The counts are transferred to R, SCTransform is run on all genes
    (return.only.var.genes = FALSE) and the SCT assay is read back. The
    returned object is restricted to the genes present in both the input and
    the SCT result, in the order of the input, and carries
      - layers 'sct_counts', 'sct_logcounts' and 'sct_scale_data',
      - var columns 'sct.detection_rate', 'sct.gmean', 'sct.variance',
        'sct.residual_variance' and 'sct.variable',
      - var['highly_variable'] taken from 'sct.variable'.

    Side effects: underscores in the var_names of the passed-in adata are
    replaced by dashes, and the R global environment is wiped at the end
    (rm(list = ls())).

    adata: adata object to normalize
    layer: layer to use for normalization. Default = None -> use .X
    results_to_X: Set results layer to adata.X (e.g. 'sct_logcounts')
    min_cells: if given, genes are filtered with sc.pp.filter_genes(min_cells=min_cells)
    n_core: number of workers registered with BiocParallel, future and doParallel
    max_memory: size limit in GB used for `future.globals.maxSize`
    '''

    import rpy2.robjects as ro
    import gc

    sct_var_columns = ['sct.detection_rate', 'sct.gmean', 'sct.variance', 'sct.residual_variance', 'sct.variable']

    print('Normalization with SCT:')
    # load packages
    ro.globalenv['n_core'] = n_core
    ro.globalenv['max_memory'] = max_memory
    ro.r('''
    # Packages
    library(Seurat)
    library(sctransform)
    library(SingleCellExperiment)

    # Parallelization
    library(BiocParallel)
    register(MulticoreParam(n_core, progressbar = TRUE))

    library(future)
    plan(multisession, workers = n_core) #change from multicore hoping it would work better
    options(future.globals.maxSize = max_memory * 1024^3)
    options(future.globals.onReference = "error")

    library(doParallel)
    registerDoParallel(n_core)
    ''')
    # transfer data
    print('\tTransfer data...')
    if layer is None:
        counts = adata.X
    else:
        print('\tNormalizing layer \'', layer,'\'...')
        counts = adata.layers[layer]

    # R expects features in rows and cells in columns.
    ro.globalenv['data_mat'] = counts.T#.toarray()
    ro.globalenv['obs_names'] = adata.obs_names
    ro.globalenv['var_names'] = adata.var_names

    ro.r('''
    rownames(data_mat) <- var_names
    colnames(data_mat) <- obs_names
    seurat <- CreateSeuratObject(counts = data_mat, project = "0", min.cells = 0, min.features = 0)
    gc()
    ''')
    gc.collect()
    # perform sct
    print('\tPerform SCT...')
    ro.r('''
    # SCTransform
    seurat <- SCTransform(seurat, verbose = FALSE, return.only.var.genes = FALSE, variable.features.n = NULL, vst.flavor = "v2")
    gc()
    ''')
    gc.collect()
    # convert to singleCellExperiment
    print('\tConvert data...')
    ro.r('''
    # Add feature meta data (since Seurat v4 -> will be fixed?)
    var <- c('detection_rate','gmean', 'variance', 'residual_variance')
    seurat[["SCT"]]@meta.features <- SCTResults(seurat[["SCT"]], slot = "feature.attributes")[, var]
    seurat[["SCT"]]@meta.features$variable <- FALSE
    seurat[["SCT"]]@meta.features[VariableFeatures(seurat[["SCT"]] ), "variable"] <- TRUE
    colnames(seurat[["SCT"]]@meta.features) <- paste0("sct.", colnames(seurat[["SCT"]]@meta.features) )

    # Convert to SingleCellExperiment
    sce <- as.SingleCellExperiment(seurat)

    # Add feature meta data (since Seurat v4 -> will be fixed?)
    rowData(sce) <- seurat[["SCT"]]@meta.features

    # Rename and add layers
    SummarizedExperiment::assay(sce, i = 1) <- seurat[["SCT"]]@counts
    SummarizedExperiment::assay(sce, i = 2) <- seurat[["SCT"]]@data
    SummarizedExperiment::assay(sce, i = 3) <- seurat[["SCT"]]@scale.data
    #SummarizedExperiment::assay(sce, i = 4) <- seurat[["RNA"]]@counts
    SummarizedExperiment::assayNames(sce) <- c("sct_counts", "sct_logcounts", "sct_scale_data")#, "raw_counts")
    gc()
    ''')

    # transfer data
    print('\tTransfer data...')

    # Fetch the SCT result from R.
    # NOTE(review): presumably an active rpy2 converter turns the
    # SingleCellExperiment into an AnnData here -- confirm.
    adata_sct = ro.globalenv['sce']
    # Keep X (the first assay) as a named layer as well.
    adata_sct.layers['sct_counts'] = adata_sct.X.copy()

    gc.collect()

    # Harmonize var_names
    ## Replace underscores by dashes so the names can match those coming back from R
    adata.var_names = [var_name.replace('_', '-') for var_name in adata.var_names]
    sct_var_names = set(adata_sct.var_names)
    # Keep the gene order of the input. Building the list from a set
    # intersection would give a different gene order from run to run.
    shared_genes = [name for name in dict.fromkeys(adata.var_names) if name in sct_var_names]
    # Subset adata; take a real copy so the layers below are not written to a view.
    adata = adata[:, shared_genes].copy()
    adata_sct = adata_sct[:, shared_genes]

    # Add SCT data
    print('\tAdd results to anndata...')
    for layer_name in ('sct_counts', 'sct_logcounts', 'sct_scale_data'):
        adata.layers[layer_name] = adata_sct.layers[layer_name].copy()
    adata.var[sct_var_columns] = adata_sct.var[sct_var_columns].copy()

    if results_to_X is not None:
        print('\tSet',results_to_X,' anndata.X...')
        adata.X = adata.layers[results_to_X].copy()

    # Set HVGs
    print('\tSet HVGs...')
    adata.var['highly_variable'] = [bool(flag) for flag in adata_sct.var['sct.variable']]

    if min_cells is not None:
        # Filter genes detected in fewer than `min_cells` cells
        print('\tFilter genes...')
        sc.pp.filter_genes(adata, min_cells=min_cells)

    # delete: wipe the R workspace and free memory on both sides
    ro.r('''
    rm(list = ls())
    gc()
    ''')

    del adata_sct
    gc.collect()

    return adata


In [None]:
def import_sct(adata, sct_results_path=None, keep_raw_qc=False, set_adata_raw=True, min_cells=None):
    """Load precomputed SCT results from an RDS file and merge them into ``adata``.

    The RDS file is read in R via ``readRDS`` and pulled into Python through
    rpy2. NOTE(review): this presumes an rpy2 -> AnnData conversion (e.g.
    ``anndata2ri``) is already active so that the R object arrives with
    ``.X``, ``.layers`` and ``.var`` -- confirm against the caller's setup.

    Parameters
    ----------
    adata
        AnnData to annotate. Its ``var_names`` are modified in place
        (underscores replaced by dashes) before matching against the SCT
        object.
    sct_results_path
        Path to the RDS file. If ``None`` a hint is printed and ``None`` is
        returned.
    keep_raw_qc
        If True, copy the existing QC columns in ``adata.obs`` (``mt_frac``,
        ``rp_frac``, ``n_genes``, ``log_genes``, ``n_counts``,
        ``log_counts``) to ``*_raw`` columns before QC is recomputed.
    set_adata_raw
        If True, store the annotated object in ``adata.raw``.
    min_cells
        If not None, passed to ``sc.pp.filter_genes`` as ``min_cells``.

    Returns
    -------
    The annotated AnnData restricted to the genes shared with the SCT
    object, or ``None`` when ``sct_results_path`` is missing. The subset is a
    new object, so callers must use the return value.
    """
    import rpy2.robjects as ro
    import gc

    if sct_results_path is None:
        print('Specify \'sct_results_path\'...')
        return None

    print('Importing SCT results from',sct_results_path)
    print('\tReading file...')
    ro.globalenv['sct_results_path'] = sct_results_path
    ro.r('library(SingleCellExperiment)')
    ro.r('sct <- readRDS(sct_results_path)')

    print('\tTransfer data...')
    adata_sct = ro.r['sct']

    # Put X in a layer so it can be addressed by name when merging
    adata_sct.layers['sct_counts'] = adata_sct.X.copy()

    print('\tAdd SCT results to AnnData...')
    # Harmonize var_names
    ## Replace underscores with dashes
    adata.var_names = [var_name.replace('_', '-') for var_name in adata.var_names]

    # Shared genes, kept in adata's own order (a plain set intersection would
    # yield an arbitrary order that can differ between runs)
    sct_var_names = set(adata_sct.var_names)
    var_intersect = list(dict.fromkeys(
        var_name for var_name in adata.var_names if var_name in sct_var_names
    ))

    # Subset; copy so that the layer assignments below act on a real object
    # rather than on a view of the caller's AnnData
    adata = adata[:, var_intersect].copy()
    adata_sct = adata_sct[:, var_intersect]

    # Add SCT data
    for layer in ('sct_counts', 'sct_logcounts', 'sct_scale_data'):
        adata.layers[layer] = adata_sct.layers[layer].copy()
    sct_var_cols = ['sct.detection_rate', 'sct.gmean', 'sct.variance', 'sct.residual_variance', 'sct.variable']
    adata.var[sct_var_cols] = adata_sct.var[sct_var_cols]

    # The SCT object is no longer needed
    del adata_sct
    gc.collect()

    if keep_raw_qc:
        print('\tSave raw QC metrics...')
        # Keep the QC metrics computed on the raw counts
        for metric in ('mt_frac', 'rp_frac', 'n_genes', 'log_genes', 'n_counts', 'log_counts'):
            adata.obs[metric + '_raw'] = adata.obs[metric]

    if 'raw_counts' not in adata.layers:
        print('\tSave AnnData.X to AnnData.layers[\'raw_counts\']...')
        adata.layers['raw_counts'] = adata.X.copy()

    print('\tRecalculate QC metrics...')
    # Set normalized counts as X for QC metrics; copy so that whatever
    # qc_metrics does to X cannot leak into the stored layer
    adata.X = adata.layers['sct_counts'].copy()
    qc_metrics(adata, ambient=False, make_dense=True)

    print('\tSet SCT log counts as AnnData.X...')
    # Set log-normalized counts as X
    adata.X = adata.layers['sct_logcounts'].copy()

    print('\tSet highly variable genes from SCT...')
    # Use the HVG flags from SCT
    adata.var['highly_variable'] = (adata.var['sct.variable'] > 0).to_numpy()
    print('\n','\tNumber of highly variable genes: {:d}'.format(int(np.sum(adata.var['highly_variable']))))

    if set_adata_raw:
        print('\tStore full AnnData in AnnData.raw...')
        # Store the full data set in 'raw' as log-normalized data for statistical testing
        adata.raw = adata

    if min_cells is not None:
        print('\tFilter genes detected in less than',min_cells,'cells...')
        sc.pp.filter_genes(adata, min_cells=min_cells)
        print('Number of genes after filter: {:d}'.format(adata.n_vars))

    return adata


In [None]:
def normalize_tfidf(atac, hvg=False, hvg_min_mean=0.05, hvg_max_mean=1.5, hvg_min_disp=0.5, remove_1st_lsi=True):
    """TF-IDF normalize an ATAC AnnData in place and compute LSI.

    Steps: keep the current ``.X`` in ``.layers['atac_raw_counts']`` (only if
    that layer does not exist yet), run ``muon.atac.pp.tfidf``, optionally
    select variable features, store the object in ``.raw``, and run
    ``muon.atac.tl.lsi``.

    Parameters
    ----------
    atac
        AnnData holding the ATAC modality; modified in place.
    hvg
        If True, run and plot ``sc.pp.highly_variable_genes`` with the
        thresholds below.
    hvg_min_mean, hvg_max_mean, hvg_min_disp
        Passed to ``sc.pp.highly_variable_genes`` as ``min_mean``,
        ``max_mean`` and ``min_disp``.
    remove_1st_lsi
        If True, plot the first LSI component against the ``log_counts_ATAC``
        and ``log_peaks_ATAC`` columns of ``atac.obs`` (colored by
        ``n_peaks_ATAC`` / ``n_counts_ATAC``), then drop that component from
        ``obsm['X_lsi']``, ``varm['LSI']`` and ``uns['lsi']['stdev']``.

    Returns
    -------
    None
    """
    from muon import atac as ac
    print('Normalization with TF-IDF:')

    # Save original counts (copied so the layer does not share memory with X)
    if 'atac_raw_counts' not in atac.layers:
        print('\tSave AnnData.X to AnnData.layers[\'atac_raw_counts\']...')
        atac.layers['atac_raw_counts'] = atac.X.copy()

    # TF-IDF normalization
    print('\tTF-IDF normalization...')
    ac.pp.tfidf(atac, scale_factor=1e4, log_tf=False, log_idf=False, log_tfidf=True)

    if hvg:
        # Feature selection
        sc.pp.highly_variable_genes(atac, min_mean=hvg_min_mean, max_mean=hvg_max_mean, min_disp=hvg_min_disp)
        sc.pl.highly_variable_genes(atac)
        print('\t\tNumber of variable features: ', np.sum(atac.var.highly_variable))

    # Save to .raw
    print('\tSave to .raw...')
    atac.raw = atac

    # LSI
    print('\tLSI...')
    ac.tl.lsi(atac)

    if remove_1st_lsi:
        # 1st dimension is often associated with number of peaks/counts and should be removed

        # Plot the 1st LSI component against counts/peaks before dropping it
        lsi_dim1 = atac.obsm['X_lsi'][:,0]
        panels = (
            # (y column, color column, y-axis label)
            ('log_counts_ATAC', 'n_peaks_ATAC', 'Counts'),
            ('log_peaks_ATAC', 'n_counts_ATAC', 'Peaks'),
        )

        fig, axs = plt.subplots(1, 2, constrained_layout=True, figsize=(8, 4))
        for ax, (y_col, color_col, y_label) in zip(axs, panels):
            ax.scatter(lsi_dim1, atac.obs[y_col], s=2, alpha=0.2, c=atac.obs[color_col], cmap='rocket')
            ax.set_xlabel('LSI Dim. 1')
            ax.set_ylabel(y_label)
            # Pin the autoscaled limits
            ax.set_xlim(ax.get_xlim())
            ax.set_ylim(ax.get_ylim())

        # Remove 1st component
        atac.obsm['X_lsi'] = atac.obsm['X_lsi'][:,1:]
        atac.varm["LSI"] = atac.varm["LSI"][:,1:]
        atac.uns["lsi"]["stdev"] = atac.uns["lsi"]["stdev"][1:]
