# Downstream analyses healthy

Hélène Reich 05/08/2024

- EEC subclusters and DGE
- ISC, TA and Progenitors subclusters and DGE
- Paneth, Paneth-Goblet and Progenitors subclusters and DGE

In [None]:
# General
import scipy as sci
import numpy as np
import pandas as pd
import logging
import time
import pickle
from itertools import chain
import h5py
import scipy.sparse as sparse
import anndata as ad
import gc
import scipy.stats as stats
import torch

# Plotting
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.colors as mcolors
from matplotlib import rcParams
from matplotlib.pyplot import rc_context
from matplotlib import cm
import seaborn as sb

# Analysis
import scanpy as sc
import scanpy.external as sce
import scvi


In [None]:
# Warnings
import warnings
# Silence every warning for the rest of the session; use action='once' to see each one a single time.
warnings.filterwarnings(action='ignore')

## setup matplotlib

In [None]:
# Settings

## Directory
# Figures and cache are written below the project base directory.
base_dir = '/mnt/hdd/Notebooks/Gut_project/'
sc.settings.figdir = f'{base_dir}Figures'
sc.settings.cachedir = f'{base_dir}Cache'

## Scanpy settings
sc.settings.verbosity = 3
sc.logging.print_header()
sc.logging.print_versions()

# Re-apply the warning filter so this cell can be run on its own.
import warnings
warnings.filterwarnings(action='ignore')

In [None]:
%run utils.ipynb

In [None]:
# Colormap used by later plotting cells.
# NOTE(review): load_RdOrYl_cmap_settings is not defined in this file; presumably it is
# provided by the `%run utils.ipynb` cell above -- confirm.
mymap = load_RdOrYl_cmap_settings(transparent=False)

## Functions

In [None]:
def get_diff_exprs_DElegate(
    adata=None,
    groupby=None,  # groups/conditions to test (e.g. stage, genotype, ...)
    groups_restrict=None,  # obs column holding the cell types / clusters
    restrict_to=None,  # category of `groups_restrict` the test is restricted to, e.g. Beta
    layer='raw_counts',
    method="edger",
    filter_ambient_genes=False,
    rank_genes_groups_key=None,  # rank_genes_groups key with markers for groups_restrict
    get_marker=False,  # run rank_genes_groups to identify markers
    min_gene_score=0,  # min score a cluster marker should have to be cluster-specific
    min_cluster_size=100,
    min_frac_cells=0.05,
    sample_key=None,  # key for samples/replicates
    plot=True,
    return_results='dict',  # or 'top_table'
    n_core=64,
    max_memory=4,
):
    """Run a DElegate differential expression test between two `groupby` categories.

    The function works on a copy of `adata`, optionally restricts the cells to
    one category of `groups_restrict`, drops genes detected in too few cells
    (and optionally ambient genes), and calls `run_DElegate_findDE` to compare
    the first two categories of `groupby`.

    Parameters
    ----------
    adata
        AnnData object; it is copied and not modified.
    groupby
        Categorical `.obs` column whose first two categories are compared.
        If more than two categories are present only the first two are used.
    groups_restrict
        Categorical `.obs` column used to select cells (and, with
        `get_marker`, to compute marker genes).
    restrict_to
        Category of `groups_restrict` to keep. Requires `groups_restrict`.
    layer
        Layer name forwarded to `run_DElegate_findDE`.
    method
        DE method forwarded to `run_DElegate_findDE`.
    filter_ambient_genes
        If True, remove genes flagged in `adata.var['is_ambient']`. When a
        rank_genes_groups result is available, ambient genes that are markers
        of `restrict_to` (score > `min_gene_score`) are kept.
    rank_genes_groups_key
        Key in `adata.uns` holding an existing rank_genes_groups result.
    get_marker
        If True and no `rank_genes_groups_key` is given, run
        `sc.tl.rank_genes_groups` on `groups_restrict`.
    min_gene_score
        Minimum rank_genes_groups score for a gene to count as a marker.
    min_cluster_size
        Minimum number of cells required in `restrict_to`.
    min_frac_cells
        Genes must be detected in at least this fraction of the tested cells.
    sample_key
        `.obs` column forwarded as the replicate column.
    plot
        Forwarded to `DElegate_to_results` (only used for `return_results='dict'`).
    return_results
        'dict' returns the results dictionary built by `DElegate_to_results`;
        'top_table' returns the table returned by `run_DElegate_findDE`.
    n_core, max_memory
        Forwarded to `run_DElegate_findDE`.

    Raises
    ------
    ValueError
        For an unknown `return_results`, for `restrict_to` without
        `groups_restrict`, if `restrict_to` has fewer than `min_cluster_size`
        cells, or if `groupby` has fewer than two categories in the tested cells.
    """
    # Validate cheap arguments before copying data or running any analysis.
    if return_results not in ('dict', 'top_table'):
        raise ValueError("return_results must be 'dict' or 'top_table', got " + repr(return_results) + '.')
    if restrict_to is not None and groups_restrict is None:
        raise ValueError('restrict_to requires groups_restrict to be set.')

    # copy adata (rank_genes_groups below writes into .uns)
    adata_temp = adata.copy()

    # create results dict and add parameters
    results = dict()
    results['method'] = 'DElegate_pseudobulk_' + method
    results['groupby'] = groupby
    results['groupby_categories'] = []
    results['groups_restrict'] = groups_restrict
    if groups_restrict is not None:
        results['groups_restrict_categories'] = list(adata_temp.obs[groups_restrict].cat.categories)
    results['restrict_to'] = restrict_to if restrict_to is not None else ''
    results['layer'] = layer
    results['min_cluster_size'] = min_cluster_size
    results['min_frac_cells'] = min_frac_cells

    # check if cluster of interest (restrict_to) has enough cells
    if restrict_to is not None:
        if adata_temp.obs[groups_restrict].value_counts()[restrict_to] < min_cluster_size:
            raise ValueError('Group has less than ' + str(min_cluster_size) + ' cells.')

    # compute markers for groups_restrict if requested and no precomputed key is given
    if rank_genes_groups_key is None and get_marker:
        sc.tl.rank_genes_groups(adata_temp, groupby=groups_restrict)
        rank_genes_groups_key = 'rank_genes_groups'

    # subset adata to group provided in restrict_to
    if restrict_to is None:
        adata_temp_test = adata_temp.copy()
    else:
        adata_temp_test = adata_temp[adata_temp.obs[groups_restrict].isin([restrict_to])].copy()

    groupby_categories = list(adata_temp_test.obs[groupby].cat.categories)
    if len(groupby_categories) < 2:
        raise ValueError(
            "Need at least two categories in '" + str(groupby) + "' to compare, found "
            + str(groupby_categories) + '.'
        )
    results['groupby_categories'] = groupby_categories

    # colours are only used for plotting; fall back to the defaults of
    # DElegate_to_results when none are stored for this grouping
    default_colors = ['#1f77b4', '#ff7f0e']
    groupby_colors = list(adata_temp_test.uns.get(groupby + '_colors', default_colors))
    if len(groupby_colors) < 2:
        groupby_colors = default_colors
    results['groupby_colors'] = groupby_colors

    # filter genes expressed in few cells
    sc.pp.filter_genes(adata_temp_test, min_cells=adata_temp_test.shape[0]*min_frac_cells)

    # filter ambient genes
    if filter_ambient_genes:
        # `== True` (not truthiness) so that missing / non-boolean flags count as not ambient
        ambient_mask = (adata_temp.var['is_ambient'] == True).to_numpy()  # noqa: E712
        ambi_genes = list(adata_temp.var_names[ambient_mask])
        if rank_genes_groups_key is None:
            ambi_genes_remove = ambi_genes
            ambi_genes_keep = None
        else:
            ranking = adata_temp.uns[rank_genes_groups_key]
            marker_genes = set(ranking['names'][restrict_to][ranking['scores'][restrict_to] > min_gene_score])
            # keep ambient genes that are markers of the tested group
            ambi_genes_remove = [gene for gene in ambi_genes if gene not in marker_genes]
            ambi_genes_keep = [gene for gene in ambi_genes if gene in marker_genes]

        adata_temp_test = adata_temp_test[:, ~adata_temp_test.var_names.isin(ambi_genes_remove)]
        print('\nRemoving ambient genes from analysis: ', ambi_genes_remove)
        results['ambient_genes_removed'] = ambi_genes_remove
        if ambi_genes_keep is not None:
            print('\nKeeping group-specific ambient genes: ', set(ambi_genes_keep), '\n')
            results['ambient_genes_kept'] = ambi_genes_keep

    results['background_genes'] = list(adata_temp_test.var_names)

    results['n_genes'] = adata_temp_test.shape[1]
    results['n_cells'] = adata_temp_test.shape[0]

    # run the DE test
    print('\nRunning DElegate...')
    top_table = run_DElegate_findDE(adata_temp_test,
                                    layer=layer,
                                    group_column=groupby,
                                    replicate_column=sample_key,
                                    compare=[groupby_categories[0], groupby_categories[1]],
                                    method=method,  # honour the requested method
                                    order_results=True,
                                    verbosity=1,
                                    n_core=n_core,
                                    max_memory=max_memory)

    if return_results == 'dict':
        # convert results
        print('\nConverting results...')
        results = DElegate_to_results(top_table,
                                      results_dict=results,
                                      ident_1=groupby_categories[0],
                                      ident_2=groupby_categories[1],
                                      ident_1_color=groupby_colors[0],
                                      ident_2_color=groupby_colors[1],
                                      plot=plot,
                                      plot_logfc_limit=10,
                                      log_pvals_adj_limit=300,
                                      z_logfc_cut_off=0.5,
                                      z_pval_cut_off=0.25)

    # release the (potentially large) temporary copies
    del adata_temp
    del adata_temp_test

    gc.collect()

    if return_results == 'dict':
        return results
    return top_table
    
    




##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################


    
    
    
def _zscore_cut_off(values, z_cut_off, ndigits, fallback):
    """Return the smallest value whose z-score exceeds `z_cut_off`, rounded to `ndigits`.

    `fallback` is returned when no value qualifies (e.g. empty input, constant
    values or NaNs, for which the z-scores are undefined).
    """
    values = np.asarray(values, dtype='float64')
    if values.size == 0:
        return fallback
    selected = values[stats.zscore(values) > z_cut_off]
    if selected.size == 0:
        return fallback
    return round(float(selected.min()), ndigits)


def _build_de_table(names, logfc, logexprs, pvals_adj, log_pvals_adj, logfc_limit):
    """Assemble one per-gene result table, sorted by fold change (ties by adjusted p-value)."""
    table = pd.DataFrame(data={'names': names,
                               'logfc': logfc,
                               'logexprs': logexprs,
                               'pvals_adj': pvals_adj,
                               'log_pvals_adj': log_pvals_adj,
                               'logfc_limit': logfc_limit})
    return table.sort_values(by=['logfc', 'pvals_adj'], ascending=True)


def _plot_distribution(values, xlabel):
    """Draw a 100-bin histogram with KDE of `values` on the current axes."""
    if hasattr(sb, 'distplot'):
        axis = sb.distplot(values, kde=True, bins=100)
    else:
        # distplot is deprecated in seaborn; histplot with density scaling is its replacement
        axis = sb.histplot(values, kde=True, stat='density', bins=100)
    axis.set_xlabel(xlabel)


def _scatter_de_layers(ax, table, significant, x_col, y_col):
    """Draw all genes (black outline, grey fill) and overlay the significant ones in colour."""
    sb.scatterplot(y=y_col, x=x_col, color='#000000', s=20, linewidth=0, data=table, ax=ax)
    sb.scatterplot(y=y_col, x=x_col, color='#cccccc', s=10, linewidth=0, data=table, ax=ax)
    if not significant.any():
        # nothing to highlight; an empty colour list would make the scatter call fail
        return
    hits = table.loc[significant]
    # white underlay so the semi-transparent colours are not mixed with the grey layer
    sb.scatterplot(y=hits[y_col], x=hits[x_col], color='#ffffff', s=10, alpha=1, linewidth=0, ax=ax)
    sb.scatterplot(y=hits[y_col], x=hits[x_col], c=hits['color'].tolist(), s=10, alpha=0.5,
                   linewidth=0, ax=ax)


def _plot_de_results(table, ident_1, ident_2, ident_1_color, ident_2_color, logfc_cut_off, pval_cut_off):
    """Show the diagnostic plots for one result table and add a 'color' column to it."""
    logfc = table['logfc']
    log_pvals = table['log_pvals_adj']
    logfc_limit = table['logfc_limit']
    fig_size = (7, 6)

    # --- distribution of fold changes ---------------------------------------------------------
    n_diff_logfc = int((logfc.abs() > logfc_cut_off).sum())
    n_up_logfc = int((logfc > logfc_cut_off).sum())
    n_down_logfc = int((logfc < -logfc_cut_off).sum())

    with rc_context({'figure.figsize': (8, 2)}):
        _plot_distribution(logfc, '$log_2$ Fold Change')
        plt.axvline(logfc_cut_off, 0, 1)
        plt.axvline(-logfc_cut_off, 0, 1)
        plt.annotate('Down-regulated\n' + str(n_down_logfc), xy=(0.02, 0.92), xycoords='axes fraction', va="top", ha="left")
        plt.annotate('Up-regulated\n' + str(n_up_logfc), xy=(0.98, 0.92), xycoords='axes fraction', va="top", ha="right")
        plt.title(label='$log_2$ Fold Change (' + str(n_diff_logfc) + ' genes passing threshold of ' + str(logfc_cut_off) + ')', fontweight='bold')
        plt.show()

    # --- distribution of adjusted p-values ----------------------------------------------------
    n_diff_pval = int((log_pvals.abs() > pval_cut_off).sum())

    with rc_context({'figure.figsize': (8, 2)}):
        _plot_distribution(log_pvals, '$-log_{10}$ Adjusted p-Value')
        plt.axvline(pval_cut_off, 0, 1)
        plt.title(label='$-log_{10}$ Adjusted p-Value (' + str(n_diff_pval) + ' genes passing threshold of ' + str(pval_cut_off) + ')', fontweight='bold')
        plt.show()

    # --- shared quantities for the three scatter plots ----------------------------------------
    min_logfc = logfc_cut_off
    max_pval = 10**-pval_cut_off
    significant = (logfc.abs() >= min_logfc) & (table['pvals_adj'] <= max_pval)

    # colour significant genes by the group in which they are higher
    table['color'] = '#000000'
    table.loc[significant & (logfc > 0), 'color'] = ident_1_color
    table.loc[significant & (logfc < 0), 'color'] = ident_2_color

    passes_pval = log_pvals.abs() > pval_cut_off
    n_diff = int((passes_pval & (logfc.abs() > logfc_cut_off)).sum())
    n_up = int((passes_pval & (logfc > logfc_cut_off)).sum())
    n_down = int((passes_pval & (logfc < -logfc_cut_off)).sum())

    title = 'Combined p-value & fold change threshold\n('  + str(n_diff) + ' genes passing thresholds of ' + str(logfc_cut_off) + ' and ' + str(pval_cut_off) + ')'

    # --- volcano plot: fold change vs. adjusted p-value ---------------------------------------
    _, ax = plt.subplots(1, figsize=fig_size)
    # extend the axes a little to leave room for annotations
    ax.set_ylim((-1, log_pvals.max() * 1.13))
    ax.set_xlim((logfc_limit.min() * 1.1, logfc_limit.max() * 1.1))

    _scatter_de_layers(ax, table, significant, x_col='logfc_limit', y_col='log_pvals_adj')

    ax.annotate('Down-regulated\n' + str(n_down), xy=(0.02, 0.98), xycoords='axes fraction', va="top", ha="left")
    ax.annotate('Up-regulated\n' + str(n_up), xy=(0.98, 0.98), xycoords='axes fraction', va="top", ha="right")
    ax.annotate(str(ident_2), xy=(0.02, 0.02), xycoords='axes fraction', va="bottom", ha="left")
    ax.annotate(str(ident_1), xy=(0.98, 0.02), xycoords='axes fraction', va="bottom", ha="right")

    ax.axvline(min_logfc, 0, 1, color='#666666', lw=1).set_linestyle("--")
    ax.axvline(-min_logfc, 0, 1, color='#666666', lw=1).set_linestyle("--")
    ax.axhline(-np.log10(max_pval), 0, 1, color='#666666', lw=1).set_linestyle("--")

    ax.set_title(title, fontweight='bold')
    ax.set_ylabel('$-log_{10}$ Adjusted p-Value')
    ax.set_xlabel('$log_2$ Fold Change')
    plt.show()

    # --- expression vs. adjusted p-value ------------------------------------------------------
    _, ax = plt.subplots(1, figsize=fig_size)
    ax.set_ylim((-1, log_pvals.max() * 1.13))

    _scatter_de_layers(ax, table, significant, x_col='logexprs', y_col='log_pvals_adj')

    ax.annotate('Up-regulated\n' + str(n_up) + '\nDown-regulated\n' + str(n_down), xy=(0.98, 0.98), xycoords='axes fraction', va="top", ha="right")

    ax.axhline(-np.log10(max_pval), 0, 1, color='#666666', lw=1).set_linestyle("--")

    ax.set_title(title, fontweight='bold')
    ax.set_ylabel('$-log_{10}$ Adjusted p-Value')
    ax.set_xlabel('$log_2$ Expression')
    plt.show()

    # --- MA-style plot: expression vs. fold change --------------------------------------------
    _, ax = plt.subplots(1, figsize=fig_size)
    ax.set_ylim((logfc_limit.min() * 1.1, logfc_limit.max() * 1.1))

    _scatter_de_layers(ax, table, significant, x_col='logexprs', y_col='logfc_limit')

    ax.annotate(str(ident_1), xy=(0.02, 0.98), xycoords='axes fraction', va="top", ha="left")
    ax.annotate('Up-regulated\n' + str(n_up), xy=(0.98, 0.98), xycoords='axes fraction', va="top", ha="right")
    ax.annotate(str(ident_2), xy=(0.02, 0.02), xycoords='axes fraction', va="bottom", ha="left")
    ax.annotate('Down-regulated\n' + str(n_down), xy=(0.98, 0.02), xycoords='axes fraction', va="bottom", ha="right")

    ax.axhline(min_logfc, 0, 1, color='#666666', lw=1).set_linestyle("--")
    ax.axhline(-min_logfc, 0, 1, color='#666666', lw=1).set_linestyle("--")

    ax.set_title(title, fontweight='bold')
    ax.set_ylabel('$log_2$ Fold Change')
    ax.set_xlabel('$log_2$ Expression')
    plt.show()


def DElegate_to_results(results_table,
                     results_dict=None,
                     ident_1=None,
                     ident_2=None,
                     ident_1_color='#1f77b4',
                     ident_2_color='#ff7f0e',
                     plot=True,
                     plot_logfc_limit = 10,
                     log_pvals_adj_limit = 300,
                     z_logfc_cut_off=0.5,
                     z_pval_cut_off=0.25
):
    """Convert a DE result table into a results dictionary and optionally plot it.

    Parameters
    ----------
    results_table
        Table with the columns 'feature', 'log_fc', 'ave_expr' and 'padj'.
    results_dict
        Dictionary to which the results are added in place; a new dictionary is
        created when None.
    ident_1, ident_2
        Keys under which the two per-gene tables are stored. The table for
        `ident_1` uses the fold changes as given, the one for `ident_2` the
        negated fold changes.
    ident_1_color, ident_2_color
        Colours for significant genes with positive / negative fold change.
    plot
        If True, show the diagnostic plots; this also adds a 'color' column to
        the `ident_1` table.
    plot_logfc_limit
        Absolute fold change at which the 'logfc_limit' column is clipped.
    log_pvals_adj_limit
        Upper limit for the -log10 adjusted p-values.
    z_logfc_cut_off, z_pval_cut_off
        z-score thresholds used to derive the fold change and p-value cut-offs.

    Returns
    -------
    dict
        `results_dict` (or a new dict) with the two tables and the entries
        'logfc_cut_off' and 'pval_cut_off' (the latter on the -log10 scale and
        never more permissive than an adjusted p-value of 0.05).
    """
    # a fresh dict per call: a dict() default would be shared between calls
    results = {} if results_dict is None else results_dict

    names = list(results_table['feature'])
    logfc = np.array(results_table['log_fc'], dtype='float64')
    logexprs = np.array(results_table['ave_expr'], dtype='float64')
    pvals_adj = np.array(results_table['padj'], dtype='float64')

    # -log10 p-values, capped so that p == 0 (-> inf) stays plottable
    log_pvals_adj = -np.log10(pvals_adj)
    log_pvals_adj[log_pvals_adj > log_pvals_adj_limit] = log_pvals_adj_limit

    # fold changes clipped symmetrically for plotting
    logfc_limit = logfc.copy()
    logfc_limit[logfc_limit > plot_logfc_limit] = plot_logfc_limit
    logfc_limit[logfc_limit < -plot_logfc_limit] = -plot_logfc_limit

    results[ident_1] = _build_de_table(names, logfc, logexprs, pvals_adj, log_pvals_adj, logfc_limit)
    results[ident_2] = _build_de_table(names, -logfc, logexprs, pvals_adj, log_pvals_adj, -logfc_limit)

    # data-driven cut-offs with fixed fallbacks when none can be derived (e.g. all p-values == 1)
    default_pval_cut_off = -np.log10(0.05)
    logfc_cut_off = _zscore_cut_off(results[ident_1]['logfc'].abs(), z_logfc_cut_off, 1, 0.5)
    pval_cut_off = _zscore_cut_off(results[ident_1]['log_pvals_adj'], z_pval_cut_off, 0, default_pval_cut_off)

    # never be more permissive than an adjusted p-value of 0.05
    if pval_cut_off < default_pval_cut_off:
        pval_cut_off = default_pval_cut_off

    results['logfc_cut_off'] = logfc_cut_off
    results['pval_cut_off'] = pval_cut_off

    if plot:
        _plot_de_results(results[ident_1], ident_1, ident_2, ident_1_color, ident_2_color,
                         logfc_cut_off, pval_cut_off)

    return results
    




##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################


def dot_plot_DElegate(
    adata,
    results_dict = None,
    keys = None,
    layer = 'sct_logcounts',
    cmap='RdBu_r'
):
    """Dot plot of selected differentially expressed genes, coloured by log2 fold change.

    Parameters
    ----------
    adata
        AnnData object holding the expression data; it is not modified.
    results_dict
        Results dictionary as returned by `DElegate_to_results`. The entries
        'groupby', 'groupby_categories', 'groups_restrict', 'restrict_to',
        'logfc_cut_off', 'pval_cut_off' and the table stored under the first
        groupby category are read.
    keys
        Genes to plot. Only genes passing both cut-offs are shown; duplicates
        are dropped. Genes with negative fold change are drawn first so that
        the group brackets match the genes (the given order is kept within
        each group).
    layer
        Layer copied to `.X` before plotting; None keeps `.X`.
    cmap
        Colormap for the fold changes.

    Returns
    -------
    None
        The plot is shown as a side effect. If none of `keys` passes the
        cut-offs a message is printed and nothing is plotted.

    Raises
    ------
    ValueError
        If `results_dict` or `keys` is None.
    """
    if results_dict is None or keys is None:
        raise ValueError('Both results_dict and keys must be provided.')

    categories = list(results_dict['groupby_categories'])
    results = results_dict[categories[0]]

    # genes passing both cut-offs, with their fold change
    passing = (abs(results['logfc']) >= results_dict['logfc_cut_off']) & (results['log_pvals_adj'] >= results_dict['pval_cut_off'])
    logfc_of = dict(zip(results.loc[passing, 'names'], results.loc[passing, 'logfc']))

    # filter keys (dict.fromkeys drops duplicates but keeps the given order)
    keys = [key for key in dict.fromkeys(keys) if key in logfc_of]
    if not keys:
        print('None of the requested genes pass the differential expression cut-offs; nothing to plot.')
        return None

    # negative fold changes first, so the brackets below cover the right genes
    down_keys = [key for key in keys if logfc_of[key] < 0]
    up_keys = [key for key in keys if not logfc_of[key] < 0]
    keys = down_keys + up_keys

    # brackets: genes with negative fold change are higher in the second category
    var_group_positions = []
    var_group_labels = []
    if down_keys:
        var_group_positions.append((0, len(down_keys) - 1))
        var_group_labels.append(categories[1])
    if up_keys:
        var_group_positions.append((len(down_keys), len(keys) - 1))
        var_group_labels.append(categories[0])

    # subset adata to the group provided in restrict_to; '' (as stored by
    # get_diff_exprs_DElegate) and None both mean "no restriction"
    restrict_to = results_dict['restrict_to']
    groups_restrict = results_dict['groups_restrict']
    if not restrict_to:
        adata_temp_test = adata.copy()
    else:
        adata_temp_test = adata[adata.obs[groups_restrict].isin([restrict_to])].copy()

    # set selected layer to .X
    if layer is not None:
        adata_temp_test.X = adata_temp_test.layers[layer].copy()

    # keep only the tested genes
    adata_temp_test = adata_temp_test[:, results['names'].tolist()]

    ## colors: one row per groupby category, fold change negated for the second one
    index = pd.Index(categories, name='groupby')
    key_logfc = pd.Series([logfc_of[key] for key in keys], index=keys, dtype='float64')
    color_df = pd.DataFrame([key_logfc, -key_logfc], index=index)
    color_df.columns.name = 'names'
    limit = abs(color_df).max().max()

    ## plot
    sc.pl.DotPlot(adata_temp_test,
                  var_names=keys,
                  groupby=results_dict['groupby'],
                  dot_color_df=color_df,
                  var_group_positions=var_group_positions,
                  var_group_labels=var_group_labels,
                  vmin=-limit,
                  vmax=limit,
                  cmap=cmap).style(color_on='square',
                                   dot_edge_lw=1,
                                   grid=True,
                                   dot_edge_color=None).legend(colorbar_title='log$_2$ Fold Change').show()

    del adata_temp_test
    del results
    gc.collect()


##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################


def run_DElegate_findDE(adata, 
                        layer = None, 
                        group_column = None, 
                        replicate_column = None, 
                        compare = "each_vs_rest", 
                        method = "edger", 
                        order_results = True, 
                        verbosity = 1, 
                        n_core = 64, 
                        max_memory = 4):
    '''
    Run DElegate::findDE (R, called through rpy2) on the counts of an AnnData object and
    return the result table converted with the pandas2ri converter.

    adata: AnnData object; counts are read from .X or `layer`, cell metadata from .obs
    layer: layer holding the counts to test. Default = None -> use .X
    group_column: column of adata.obs holding the groups to compare. None -> NULL is handed to findDE.
                  Spaces and hyphens in its values are replaced by underscores before testing,
                  so group names given in `compare` have to be written with underscores.
    replicate_column: column of adata.obs holding the replicate labels (sanitized like group_column).
                      None -> NULL is handed to findDE.
    compare: see below
    method: DE method handed to findDE
    order_results: handed to findDE
    verbosity: handed to findDE
    n_core: number of workers registered for BiocParallel (MulticoreParam) and future (multicore plan)
    max_memory: limit for future.globals.maxSize in GiB (multiplied by 1024^3 in R)

    There are multiple ways the group comparisons can be specified based on the compare parameter. 
    The default, 'each_vs_rest', does multiple comparisons, one per group vs all remaining cells. 
    'all_vs_all', also does multiple comparisons, covering all group pairs. 
    If compare is set to a length two character vector, e.g. c('T-cells', 'B-cells'), one comparison between those two groups is done. 
    A Python list passed as compare is flattened with unlist() on the R side before findDE is called.

    NOTE: the global R environment is emptied (rm(list = ls())) when the function finishes,
    also when the analysis fails.
    NOTE(review): the count matrix and adata.obs are assigned to R without a local converter,
    so this presumably relies on an rpy2 conversion (e.g. anndata2ri/pandas2ri) being
    activated elsewhere in the session - confirm.
    '''
    
    import gc

    import rpy2.robjects as ro
    # pandas2ri is needed for the result conversion below; import it here so the
    # function does not depend on a notebook-level import.
    from rpy2.robjects import pandas2ri

       
    print('DE analysis with DElegate:')
    try:
        # load packages
        ro.globalenv['n_core'] = n_core
        ro.globalenv['max_memory'] = max_memory
        ro.r('''
        # Packages
        library(DElegate)
        library(Seurat)

        # Parallelization
        library(BiocParallel)
        register(MulticoreParam(n_core, progressbar = TRUE))

        library(future)
        plan("multicore", workers = n_core)
        options(future.globals.maxSize = max_memory * 1024^3)
        plan()
        ''')
        
        # transfer data & parameters
        # Unset columns are defined as NULL so the findDE call below never refers
        # to a variable that does not exist in R.
        if group_column is not None:
            ro.globalenv['group_column'] = group_column
        else:
            ro.r('group_column <- NULL')
        if replicate_column is not None:
            ro.globalenv['replicate_column'] = replicate_column
        else:
            ro.r('replicate_column <- NULL')
        ro.globalenv['compare'] = compare
        if isinstance(compare, list):
            # a Python list arrives as an R list; findDE gets a plain vector
            ro.r('''
            compare <- unlist(compare)
            ''')
        
        ro.globalenv['method'] = method
        ro.globalenv['order_results'] = order_results
        ro.globalenv['verbosity'] = verbosity
        
        print('\tTransfer data...')
        # counts are transposed before the transfer (genes become rows)
        if layer is None:
            print('\tUsing adata.X for differntial expression analysis...')
            ro.globalenv['counts'] = adata.X.T#.toarray()
        else:
            print('\tUsing layer \'', layer,'\' for differntial expression analysis...')
            ro.globalenv['counts'] = adata.layers[layer].T#.toarray()
        ro.globalenv['meta_data'] = adata.obs
        ro.globalenv['obs_names'] = adata.obs_names
        ro.globalenv['var_names'] = adata.var_names
            
        # generate seurat object
        ro.r('''
        rownames(counts) <- var_names
        colnames(counts) <- obs_names
        seurat <- CreateSeuratObject(counts = counts, meta.data = meta_data)
        ''')
        
        # run DElegate
        print('\tPerform differential gene expression analysis with method:', method,'...')
        
        # replace characters
        if group_column is not None:
            print('\tFixing characters in group_column:', group_column,'...')
            ro.r('''
            seurat@meta.data[group_column] <- gsub("[ -]", "_", get(group_column, seurat@meta.data))
            ''')
            
        if replicate_column is not None:
            print('\tFixing characters in replicate_column:', replicate_column,'...')
            ro.r('''
            seurat@meta.data[replicate_column] <- gsub("[ -]", "_", get(replicate_column, seurat@meta.data))
            ''')
        # run
        # replicate_column is forwarded (it used to be hard-coded to NULL, which
        # silently ignored the argument).
        print('\tRunning', method,'...')
        ro.r('''
        de_res <- findDE(seurat,
                         meta_data = NULL,
                         group_column = group_column,
                         replicate_column = replicate_column,
                         compare = compare,
                         method = method,
                         order_results = order_results,
                         verbosity = verbosity)
        ''')
        
        # transfer data
        print('\tTransfer data...')
        
        # convert results
        print('\tConvert results...')
        results = ro.globalenv['de_res']
        
        with (ro.default_converter + pandas2ri.converter).context():
            results = ro.conversion.get_conversion().rpy2py(results)
    finally:
        # delete: free the (potentially large) objects on the R side even if a step above failed
        print('\tClean up...')
        ro.r('''
        rm(list = ls())
        gc()
        ''')

        gc.collect()
    
    print('Done.')
    return results

def vulcano_plot_edger(results_dict=None, genes=[], annotate_top=True, n_top=10, title=None, min_logfc = 0.5, max_pval = 10**-2, group_order = (0,1), y_max_ext_factor=1.2, x_ext_factor=0.3, x_max_ext_factor=1.2, x_min_ext_factor=1.2, fig_size=(7,6), save=None):
    '''
    Draw a volcano plot (log2 fold change vs. -log10 adjusted p-value) for one DE result table.

    results_dict: dict with the keys 'groupby_categories' (group names), 'restrict_to' (used in the
                  default title) and one result table per group name. The table of group
                  group_order[0] is plotted; it has to provide the columns 'names', 'logfc',
                  'logfc_limit', 'pvals_adj', 'log_pvals_adj' and 'color'.
    genes: extra genes to label (the default list is never modified)
    annotate_top: currently unused
    n_top: number of genes taken from each of the automatic label selections
    title: plot title. Default = None -> built from 'restrict_to' and the two group names
    min_logfc: minimal absolute log fold change for a gene to be highlighted / counted
    max_pval: maximal adjusted p-value for a gene to be highlighted / counted
    group_order: indices into 'groupby_categories'; first = plotted group (right side), second = left side
    y_max_ext_factor: factor applied to the largest 'log_pvals_adj' to get the upper y limit
    x_ext_factor: currently unused (the limits derived from it were always overwritten)
    x_max_ext_factor, x_min_ext_factor: factors applied to max / min of 'logfc_limit' to get the x limits
    fig_size: figure size handed to plt.subplots
    save: path handed to savefig. Default = None -> figure is not saved

    Side effect: 'log_pvals_adj' of the plotted table is capped at 300 in place.
    '''
    categories = results_dict['groupby_categories']
    group_right = categories[group_order[0]]
    group_left = categories[group_order[1]]
    results = results_dict[group_right]

    # One definition of "significant" for highlighted points, labels and the
    # gene counts in the corner annotations, so the numbers match the plot.
    significant = (results['logfc'].abs() >= min_logfc) & (results['pvals_adj'] <= max_pval)
    is_up = significant & (results['logfc'] > 0)
    is_down = significant & (results['logfc'] < 0)
    n_up = int(is_up.sum())
    n_down = int(is_down.sum())

    def _top_named(names):
        # first n_top entries, skipping identifiers starting with 'ENSSSC'
        # (presumably Ensembl IDs of genes without a symbol)
        return [gene for gene in names if not gene.startswith('ENSSSC')][0:n_top]

    # Label candidates: user genes, both ends of the table in its given order,
    # and the most significant genes overall and among logfc > 0.
    candidates = (list(genes)
                  + _top_named(results['names'])
                  + _top_named(results.sort_values(by=['log_pvals_adj'], ascending=False)['names'])
                  + _top_named(results['names'][::-1])
                  + _top_named(results.loc[results['logfc'] > 0, :].sort_values(by=['log_pvals_adj'], ascending=False)['names']))
    # drop duplicates but keep a reproducible order
    candidates = list(dict.fromkeys(candidates))

    up_names = set(results.loc[is_up, 'names'])
    down_names = set(results.loc[is_down, 'names'])
    genes_up = [gene for gene in candidates if gene in up_names]
    genes_down = [gene for gene in candidates if gene in down_names]

    # Cap the y values (e.g. infinite values) so the axis stays finite; .loc
    # writes reliably where chained indexing may act on a copy.
    results.loc[results['log_pvals_adj'] > 300, 'log_pvals_adj'] = 300

    fig, ax = plt.subplots(1, figsize=fig_size)

    # Make x & y axis longer to make gene name plotting easier
    y_max = results['log_pvals_adj'].max() * y_max_ext_factor
    x_max = results['logfc_limit'].max() * x_max_ext_factor
    x_min = results['logfc_limit'].min() * x_min_ext_factor

    ax.set_ylim((-1, y_max))
    ax.set_xlim((x_min, x_max))

    # Scatter plot: all genes as grey dots with a black outline
    sb.scatterplot(y='log_pvals_adj', x='logfc_limit',
                   color='#000000', s=20,
                   linewidth=0,
                   data=results, ax=ax)
    sb.scatterplot(y='log_pvals_adj', x='logfc_limit',
                   color='#cccccc', s=10,
                   linewidth=0,
                   data=results, ax=ax)

    # Significant genes: white base first so the semi-transparent colour is not
    # mixed with the grey dot underneath
    if significant.any():
        y = results.loc[significant, 'log_pvals_adj']
        x = results.loc[significant, 'logfc_limit']
        c = results.loc[significant, 'color']

        sb.scatterplot(y=y, x=x, color='#ffffff', s=10, alpha=1,
                       linewidth=0,
                       ax=ax)
        sb.scatterplot(y=y, x=x, c=c, s=10, alpha=0.5,
                       linewidth=0,
                       ax=ax)

    # annotation
    ax.annotate(group_left + '\n' + str(n_down) + ' genes', xy=(0.02, 0.98), xycoords='axes fraction',
                va="top", ha="left")
    ax.annotate(group_right + '\n' + str(n_up) + ' genes', xy=(0.98, 0.98), xycoords='axes fraction',
                va="top", ha="right")

    # Lines
    ax.axvline(min_logfc, color='#666666', lw=1, linestyle='--')
    ax.axvline(-min_logfc, color='#666666', lw=1, linestyle='--')
    ax.axhline(-np.log10(max_pval), color='#666666', lw=1, linestyle='--')

    # Labels
    x_lim = ax.get_xlim()
    for gene_set, direction, ha in [(genes_down, -1, 'right'), (genes_up, 1, 'left')]:
        labels = []
        for gene in gene_set:
            if gene.startswith('ENSSSC'):
                continue
            row = results.loc[results['names'] == gene]
            x = float(row['logfc_limit'].iloc[0])
            y = float(row['log_pvals_adj'].iloc[0])
            labels.append(ax.text(x, y, gene, color='#000000', fontsize=8))
        # temporarily restrict the x axis to this side of the plot (from +-0.5 to
        # the outer limit) while the labels of that side are arranged
        outer = x_lim[0] if direction < 0 else x_lim[1]
        ax.set_xlim(sorted([outer, 0.5 * direction]))
        if labels:
            adjust_text(labels, expand_points=(1.5,1.5), expand_text=(2,2), expand_objects=(2,2), force_text=(0.75, 0.5), force_points=(0.75, 1), force_objects=(1, 0.5), ha=ha, precision=0.00001, lim=5000, autoalign='y', arrowprops=dict(arrowstyle="-",  color='k',  lw=0.5), ax=ax)
    ax.set_xlim(x_lim)

    # title & axis labels
    if title is None:
        title = 'Differential Gene Expression in ' + results_dict['restrict_to'] + ' Cells\n' + group_right + ' vs ' + group_left
    ax.set_title(title)
    ax.set_ylabel('$-log_{10}$ Adjusted p-Value')
    ax.set_xlabel('$log_2$ Fold Change')

    if save is not None:
        fig.savefig(save)




##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################



def run_DElegate_findMarkers(adata, 
                        layer = None, 
                        group_column = None, 
                        replicate_column = None, 
                        method = "edger", 
                        min_rate = 0.05,
                        min_fc = 1,
                        verbosity = 1, 
                        n_core = 64, 
                        max_memory = 4):
    '''
    Run DElegate::FindAllMarkers2 (R, called through rpy2) on the counts of an AnnData object
    and return the marker table converted with the pandas2ri converter.

    adata: AnnData object; counts are read from .X or `layer`, cell metadata from .obs
    layer: layer holding the counts to test. Default = None -> use .X
    group_column: column of adata.obs holding the groups to find markers for. None -> NULL is
                  handed to FindAllMarkers2. Spaces and hyphens in its values are replaced by
                  underscores before testing.
    replicate_column: column of adata.obs holding the replicate labels (sanitized like group_column).
                      None -> NULL is handed to FindAllMarkers2.
    method: DE method handed to FindAllMarkers2
    min_rate: handed to FindAllMarkers2
    min_fc: handed to FindAllMarkers2
    verbosity: handed to FindAllMarkers2
    n_core: number of workers registered for BiocParallel (MulticoreParam) and future (multicore plan)
    max_memory: limit for future.globals.maxSize in GiB (multiplied by 1024^3 in R)

    NOTE: the global R environment is emptied (rm(list = ls())) when the function finishes,
    also when the analysis fails.
    NOTE(review): the count matrix and adata.obs are assigned to R without a local converter,
    so this presumably relies on an rpy2 conversion (e.g. anndata2ri/pandas2ri) being
    activated elsewhere in the session - confirm.
    '''
    
    import gc

    import rpy2.robjects as ro
    # pandas2ri is needed for the result conversion below; import it here so the
    # function does not depend on a notebook-level import.
    from rpy2.robjects import pandas2ri

       
    print('DE analysis with DElegate:')
    try:
        # load packages
        ro.globalenv['n_core'] = n_core
        ro.globalenv['max_memory'] = max_memory
        ro.r('''
        # Packages
        library(DElegate)
        library(Seurat)

        # Parallelization
        library(BiocParallel)
        register(MulticoreParam(n_core, progressbar = TRUE))

        library(future)
        plan("multicore", workers = n_core)
        options(future.globals.maxSize = max_memory * 1024^3)
        plan()
        ''')
        
        # transfer data & parameters
        # Unset columns are defined as NULL so the FindAllMarkers2 call below never
        # refers to a variable that does not exist in R.
        if group_column is not None:
            ro.globalenv['group_column'] = group_column
        else:
            ro.r('group_column <- NULL')
        if replicate_column is not None:
            ro.globalenv['replicate_column'] = replicate_column
        else:
            ro.r('replicate_column <- NULL')
        
        ro.globalenv['method'] = method
        ro.globalenv['min_rate'] = min_rate
        ro.globalenv['min_fc'] = min_fc
        ro.globalenv['verbosity'] = verbosity
        
        print('\tTransfer data...')
        # counts are transposed before the transfer (genes become rows)
        if layer is None:
            print('\tUsing adata.X for differntial expression analysis...')
            ro.globalenv['counts'] = adata.X.T#.toarray()
        else:
            print('\tUsing layer \'', layer,'\' for differntial expression analysis...')
            ro.globalenv['counts'] = adata.layers[layer].T#.toarray()
        ro.globalenv['meta_data'] = adata.obs
        ro.globalenv['obs_names'] = adata.obs_names
        ro.globalenv['var_names'] = adata.var_names
            
        # generate seurat object
        ro.r('''
        rownames(counts) <- var_names
        colnames(counts) <- obs_names
        seurat <- CreateSeuratObject(counts = counts, meta.data = meta_data)
        ''')
        
        # run DElegate
        print('\tPerform differential gene expression analysis with method:', method,'...')
        
        # replace characters
        if group_column is not None:
            print('\tFixing characters in group_column:', group_column,'...')
            ro.r('''
            seurat@meta.data[group_column] <- gsub("[ -]", "_", get(group_column, seurat@meta.data))
            ''')
            
        if replicate_column is not None:
            print('\tFixing characters in replicate_column:', replicate_column,'...')
            ro.r('''
            seurat@meta.data[replicate_column] <- gsub("[ -]", "_", get(replicate_column, seurat@meta.data))
            ''')
        # run
        # replicate_column is forwarded (it used to be hard-coded to NULL, which
        # silently ignored the argument).
        print('\tRunning', method,'...')
        ro.r('''
        de_res <- FindAllMarkers2(seurat,
                         meta_data = NULL,
                         group_column = group_column,
                         replicate_column = replicate_column,
                         method = method,
                         min_rate = min_rate,
                         min_fc = min_fc,
                         verbosity = verbosity)
        ''')
        
        # transfer data
        print('\tTransfer data...')
        
        # convert results
        print('\tConvert results...')
        results = ro.globalenv['de_res']
        
        with (ro.default_converter + pandas2ri.converter).context():
            results = ro.conversion.get_conversion().rpy2py(results)
    finally:
        # delete: free the (potentially large) objects on the R side even if a step above failed
        print('\tClean up...')
        ro.r('''
        rm(list = ls())
        gc()
        ''')

        gc.collect()
    
    print('Done.')
    return results


In [None]:
def load_cell_cycle_genes(adata, genome='auto',
                          kegg_path='/mnt/ssd/Resources/KEGG_mmu_Cell_Cycle.txt',
                          regev_path='/mnt/ssd/Resources/regev_cell_cycle_genes.txt',
                          macosko_path='/mnt/ssd/Resources/Macosko_cell_cycle_genes.txt'):
    '''
    Load the cell cycle gene lists (KEGG, Regev lab, Macosko et al.) for the genome of adata.

    adata: object providing .var_names and, for genome='auto', the column .var['genome']
    genome: 'Homo_sapiens', 'Mus_musculus' or 'Sus_scrofa' (case-insensitive).
            Default = 'auto' -> first two '_'-separated parts of the first entry of adata.var['genome']
    kegg_path: table whose first column holds the KEGG cell cycle genes
    regev_path: text file with one gene per line; the first 43 entries are used as S phase
                genes, the remaining ones as G2/M genes
    macosko_path: tab-separated table with the columns 'S', 'G2.M', 'M', 'M.G1' and 'IG1.S'

    For human and pig the Regev and Macosko lists are restricted to genes present in
    adata.var_names; for mouse the gene names are converted to capitalized symbols instead
    (no filtering). The KEGG list is used as is for every genome.

    Returns the tuple
    (all_cc_genes, s_genes_regev, g2m_genes_regev, cc_genes_regev, cc_genes_macosko,
     s_genes_macosko, g2m_genes_macosko, m_genes_macosko, mg1_genes_macosko, g1s_genes_macosko).

    Raises ValueError for any other genome.
    '''
    # Load cell cycle genes

    ## KEGG cell cycle genes
    cc_kegg = pd.read_table(kegg_path).iloc[:, 0].tolist()

    ## Cell cycle genes Regev lab (Tirosh et al. 2016, DOI: 10.1126/science.aad0501)
    with open(regev_path) as handle:
        cc_genes_regev = [line.strip() for line in handle]

    if genome == 'auto':
        genome = '_'.join(adata.var.loc[:, 'genome'].iloc[0].split('_')[0:2])

    print('Genome is', genome)

    species = genome.lower()
    if species not in ('homo_sapiens', 'mus_musculus', 'sus_scrofa'):
        raise ValueError(
            "Unsupported genome '" + str(genome) + "'; expected 'Homo_sapiens', 'Mus_musculus' or 'Sus_scrofa'."
        )

    ## Cell cycle genes Macosko et al. 2015, https://doi.org/10.1016/j.cell.2015.05.002
    macosko_table = pd.read_table(macosko_path, delimiter='\t')

    if species == 'mus_musculus':
        # mouse: convert the symbols to capitalized form, no filtering by var_names
        def _convert(gene_list):
            return [gene.lower().capitalize() for gene in gene_list]

        s_genes_regev = _convert(cc_genes_regev[:43])
        g2m_genes_regev = _convert(cc_genes_regev[43:])
    else:
        # human / pig: keep only the genes present in the data (in var_names order)
        var_names = adata.var_names

        def _convert(gene_list):
            return list(var_names[np.isin(var_names, gene_list)])

        # these two stay index objects (not lists), as before
        s_genes_regev = var_names[np.isin(var_names, cc_genes_regev[:43])]
        g2m_genes_regev = var_names[np.isin(var_names, cc_genes_regev[43:])]

    cc_genes_regev = _convert(cc_genes_regev)

    s_genes_macosko = _convert(list(macosko_table['S'].dropna()))
    g2m_genes_macosko = _convert(list(macosko_table['G2.M'].dropna()))
    m_genes_macosko = _convert(list(macosko_table['M'].dropna()))
    mg1_genes_macosko = _convert(list(macosko_table['M.G1'].dropna()))
    g1s_genes_macosko = _convert(list(macosko_table['IG1.S'].dropna()))

    cc_genes_macosko = s_genes_macosko + g2m_genes_macosko + m_genes_macosko + mg1_genes_macosko + g1s_genes_macosko

    ## Combine all
    all_cc_genes = list(set(cc_kegg + cc_genes_regev + cc_genes_macosko))

    return all_cc_genes, s_genes_regev, g2m_genes_regev, cc_genes_regev, cc_genes_macosko, s_genes_macosko, g2m_genes_macosko, m_genes_macosko, mg1_genes_macosko, g1s_genes_macosko


###################################################################################################################
###################################################################################################################
###################################################################################################################
    


def plot_composition(adata, 
x_key=None, 
y_key=None, 
x_labels = None,
y_labels = None,
y_colors = None,
width = 0.85,       # the width of the bars: can also be len(x) sequence
x_rotation = 0,
y_lim_offset = 2.5,
x_lim_offset = 0.45,
figsize= (6, 4),
save=None):
    '''
    Plot the composition of adata.obs[y_key] within each group of adata.obs[x_key] as stacked
    bars (in %) and return the plotted table.

    adata: object providing .obs (and .uns when y_colors is not given)
    x_key: column of adata.obs defining the bars
    y_key: column of adata.obs defining the stacked segments
    x_labels: groups of x_key to plot. Default = None -> categories of adata.obs[x_key]
    y_labels: groups of y_key to plot. Default = None -> categories of adata.obs[y_key]
    y_colors: segment colors. Default = None -> adata.uns[y_key + '_colors']
    width: bar width handed to DataFrame.plot
    x_rotation: rotation of the x tick labels
    y_lim_offset: margin added below 0 and above 100 on the y axis
    x_lim_offset: controls the margin left and right of the outer bars
    figsize: figure size, set through rc_context for this plot only
    save: path handed to savefig. Default = None -> figure is not saved

    Returns a DataFrame with the column 'x_labels' and one percentage column per y label.
    '''
    with rc_context({'figure.figsize': figsize}): #rcParams['figure.figsize']=(6,4)
        # `is None` also works when array-like labels / colors are passed
        if x_labels is None:
            x_labels = list(adata.obs[x_key].cat.categories)
        
        if y_labels is None:
            y_labels = list(adata.obs[y_key].cat.categories)
        
        if y_colors is None:
            y_colors = list(adata.uns[y_key + '_colors'])
            
        # percentage of every y group within each x group, counted once per bar
        percentages = []
        for x_label in x_labels:
            counts = adata.obs[y_key][adata.obs[x_key]==x_label].value_counts()
            percentages.append(counts/counts.sum()*100)

        dic = {'x_labels':x_labels}
        for y_label in y_labels:
            dic[y_label] = [percent[y_label] for percent in percentages]
        
        df = pd.DataFrame(data = dic)

        ax = df.plot(x='x_labels', kind='bar', stacked=True, width=width, edgecolor='0', linewidth=0.5, color=y_colors)

        ax.set_ylabel('%')
        ax.set_xlabel('')
        ax.set_title(y_key + ' by ' + x_key)
        ax.set_xticklabels(labels=x_labels, rotation=x_rotation)
        ax.legend(bbox_to_anchor=(1, .5),loc='center left', edgecolor='1')

        ax.set_ylim([-y_lim_offset,100+y_lim_offset])
        ax.set_xlim([-1+x_lim_offset,len(x_labels)-x_lim_offset])

        # save before show(): once the figure has been shown it may already be
        # closed, and savefig would then write an empty canvas
        if save is not None:
            ax.get_figure().savefig(save)

        plt.show()
        
    return df





##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################
def round_to_5(x, base=5):
    '''
    Round x to a multiple of base (default 5): x/base is rounded with the
    built-in round() and scaled back by base.
    '''
    n_multiples = round(x / base)
    return base * n_multiples

def get_go_terms(results_dict=None, min_score=None, max_pval=0.01, min_logfc = 0.5, organism='hsapiens', selection_string=None, plot_all=True, plot_top=True, n_top=20, plot_select=True, n_select=20, cmap_up=mymap, cmap_down=mymap, cmap_dot='flare', width_factor=30, hight_factor=15, aspect_nominator=3, size_denominator=1.5):
    # GO terms upregulated
    cat = results_dict['groupby_categories'][0]
    results = results_dict[cat]
    results = results.sort_values(by=['log_pvals_adj'], ascending=False)
    
    if min_score == None:
        up_genes = list(results.loc[(results['logfc'] > 0) & (results['pvals_adj'] <= max_pval) & (abs(results['logfc']) >= min_logfc),'names'])
    else:
        up_genes = list(results.loc[(results['logfc'] > 0) & (abs(results['scores']) >= min_score) & (results['pvals_adj'] <= max_pval) & (abs(results['logfc']) >= min_logfc),'names'])
    print("Top 10 Upregulated in ", cat, ": ", up_genes[:10])
    print("Number of Upregulated Genes in ", cat, ": ", len(up_genes))
    up_enrich=sc.queries.enrich(up_genes, org=organism, gprofiler_kwargs={'no_evidences':False})
    up_enrich['-log_adj_p'] = -np.log10(up_enrich.loc[:,'p_value'])
    up_enrich['term'] = '(' + up_enrich['native'] + ') ' + up_enrich['name']
    up_enrich = up_enrich.assign(term_cat=pd.Categorical(up_enrich['term'], categories=up_enrich['term'][::-1]))
    up_enrich = up_enrich.loc[up_enrich['source']!='TF',:]
    up_enrich = up_enrich.loc[up_enrich['source']!='HPA',:]
    up_enrich = up_enrich.loc[up_enrich['source']!='GO:CC',:]
    up_enrich = up_enrich.loc[up_enrich['source']!='GO:MF',:]
    
    if plot_all:
        if up_enrich.size > 0:
           (ggplot(up_enrich.iloc[0:50,:])
            + aes(y='-log_adj_p', x='term_cat', fill='-log_adj_p')
            + geom_col(color='black')
            + scale_fill_continuous(cmap_name=cmap_up)
            + coord_flip()
            + labs(y='-log10 Adjusted p-Value', x='', title=cat + ' - ' + results_dict['restrict_to'])
            + theme_linedraw()
            + theme(aspect_ratio=2/1)
           ).draw(return_ggplot=False)

        up_enrich_kegg = up_enrich.loc[up_enrich['source']=='KEGG',:]
        if up_enrich_kegg.size > 0:
           (ggplot(up_enrich_kegg.iloc[0:50,:])
            + aes(y='-log_adj_p', x='term_cat', fill='-log_adj_p')
            + geom_col(color='black')
            + scale_fill_continuous(cmap_name=cmap_up)
            + coord_flip()
            + labs(y='-log10 Adjusted p-Value', x='', title=cat + ' - ' + results_dict['restrict_to'] + ' - KEGG')
            + theme_linedraw()
            + theme(aspect_ratio=2/1)
           ).draw(return_ggplot=False)

        up_enrich_reac = up_enrich.loc[up_enrich['source']=='REAC',:]
        if up_enrich_reac.size > 0:
           (ggplot(up_enrich_reac.iloc[0:50,:])
            + aes(y='-log_adj_p', x='term_cat', fill='-log_adj_p')
            + geom_col(color='black')
            + scale_fill_continuous(cmap_name=cmap_up)
            + coord_flip()
            + labs(y='-log10 Adjusted p-Value', x='', title=cat + ' - ' + results_dict['restrict_to'] + ' - Reactome')
            + theme_linedraw()
            + theme(aspect_ratio=2/1)
           ).draw(return_ggplot=False)

    # GO-Terms downregulated
    cat = results_dict['groupby_categories'][1]
    results = results_dict[cat]
    results = results.sort_values(by=['log_pvals_adj'], ascending=False)
    if min_score == None:
        down_genes = list(results.loc[(results['logfc'] > 0) & (results['pvals_adj'] <= max_pval) & (abs(results['logfc']) >= min_logfc),'names'])
    else:
        down_genes = list(results.loc[(results['logfc'] > 0) & (abs(results['scores']) >= min_score) & (results['pvals_adj'] <= max_pval) & (abs(results['logfc']) >= min_logfc),'names'])
    print("Top 10 Upregulated in ", cat, ": ", down_genes[:10])
    print("Number of Upregulated Genes in ", cat, ": ", len(down_genes))
    down_enrich=sc.queries.enrich(down_genes, org=organism, gprofiler_kwargs={'no_evidences':False})
    down_enrich['-log_adj_p'] = -np.log10(down_enrich.loc[:,'p_value'])
    down_enrich['term'] = '(' + down_enrich['native'] + ') ' + down_enrich['name']
    down_enrich = down_enrich.assign(term_cat=pd.Categorical(down_enrich['term'], categories=down_enrich['term'][::-1]))
    down_enrich = down_enrich.loc[down_enrich['source']!='TF',:]
    down_enrich = down_enrich.loc[down_enrich['source']!='HPA',:]
    down_enrich = down_enrich.loc[down_enrich['source']!='GO:CC',:]
    down_enrich = down_enrich.loc[down_enrich['source']!='GO:MF',:]
    
    if plot_all:
        if down_enrich.size > 0:
           (ggplot(down_enrich.iloc[0:50,:])
            + aes(y='-log_adj_p', x='term_cat', fill='-log_adj_p')
            + geom_col(color='black')
            + scale_fill_continuous(cmap_name=cmap_down)
            + coord_flip()
            + labs(y='-log10 Adjusted p-Value', x='', title=cat + ' - ' + results_dict['restrict_to'])
            + theme_linedraw()
            + theme(aspect_ratio=2/1)
           ).draw(return_ggplot=False)

        down_enrich_kegg = down_enrich.loc[down_enrich['source']=='KEGG',:]
        if down_enrich_kegg.size > 0:
           (ggplot(down_enrich_kegg.iloc[0:50,:])
            + aes(y='-log_adj_p', x='term_cat', fill='-log_adj_p')
            + geom_col(color='black')
            + scale_fill_continuous(cmap_name=cmap_down)
            + coord_flip()
            + labs(y='-log10 Adjusted p-Value', x='', title=cat + ' - ' + results_dict['restrict_to'] + ' - KEGG')
            + theme_linedraw()
            + theme(aspect_ratio=2/1)
           ).draw(return_ggplot=False)

        down_enrich_reac = down_enrich.loc[down_enrich['source']=='REAC',:]
        if down_enrich_reac.size > 0:
           (ggplot(down_enrich_reac.iloc[0:50,:])
            + aes(y='-log_adj_p', x='term_cat', fill='-log_adj_p')
            + geom_col(color='black')
            + scale_fill_continuous(cmap_name=cmap_down)
            + coord_flip()
            + labs(y='-log10 Adjusted p-Value', x='', title=cat + ' - ' + results_dict['restrict_to'] + ' - Reactome')
            + theme_linedraw()
            + theme(aspect_ratio=2/1)
           ).draw(return_ggplot=False)
        
    
    if plot_top:
        # Filter terms by key words
        down_enrich_select = down_enrich.iloc[0:n_top,:]
        up_enrich_select = up_enrich.iloc[0:n_top,:]

        # Mark cluster
        up_enrich_select['cluster'] = results_dict['groupby_categories'][0]
        down_enrich_select['cluster'] = results_dict['groupby_categories'][1]

        # Sort data frames by p-value
        up_enrich_select['sort'] = -np.log10(up_enrich_select.loc[:,'p_value'])
        up_enrich_select = up_enrich_select.sort_values(by=['sort'], ascending=False)

        down_enrich_select['sort'] = np.log10(down_enrich_select.loc[:,'p_value']) #reverse order
        down_enrich_select = down_enrich_select.sort_values(by=['sort'], ascending=True)

        # Join data frames
        joined_enrich_select = pd.concat([up_enrich_select.iloc[0:20,:],down_enrich_select.iloc[0:20,:]])

        # Calc gene ratio
        joined_enrich_select['ratio'] = joined_enrich_select.loc[:,'intersection_size']/joined_enrich_select.loc[:,'query_size']

        # Change order of terms in plot
        joined_enrich_select.index = [n for n in range(0,len(joined_enrich_select.index),1)]

        for term in joined_enrich_select['term'][joined_enrich_select['term'].duplicated()]:
            joined_enrich_select.loc[joined_enrich_select['term']==term,'sort'] = joined_enrich_select.iloc[abs(joined_enrich_select.loc[joined_enrich_select['term']==term,'sort']).idxmax(),:]['sort']

        joined_enrich_select = joined_enrich_select.sort_values(by=['sort'], ascending=True)
        joined_enrich_select['term'] = pd.Categorical(joined_enrich_select['term'], categories=joined_enrich_select['term'].drop_duplicates())

        # Plot
        if joined_enrich_select.size > 0:
            n_terms = len(joined_enrich_select.drop_duplicates(subset=['term']))
            

            # Define parameters
            cmap_dot = 'flare'
            width = 9
            height = len(joined_enrich_select) / 5
            fig_size = (width, height)
            x_ext = 0.65
            y_ext = 1
            y_ext_factor = 100
            size_factor = 18
            colorbar = True

            # Create the figure and axes
            fig, ax = plt.subplots(1, figsize=fig_size)

            # Set values for scatterplot
            x = joined_enrich_select['cluster']
            y = [name[0].upper() + name[1:] for name in joined_enrich_select['name']]
            size = joined_enrich_select['ratio'] * len(joined_enrich_select) * size_factor
            color = joined_enrich_select['-log_adj_p']

            # Adjust plot limits and normalize colormap
            x_max = 1 + x_ext
            x_min = 0 - x_ext
            ax.set_xlim((x_min, x_max))
            y_max = len(x)
            y_min = 0 - y_ext
            ax.set_ylim((y_min, y_max))
            vmin, vmax = color.min(), color.max()
            vcenter = (vmin + vmax) / 2
            normalize = mcolors.TwoSlopeNorm(vcenter=vcenter, vmin=vmin, vmax=vmax)

            # Set scatterplot properties
            kwargs = {'edgecolor': "black", 'linewidth': 0.75, 'linestyle': '-'}

            # Create scatterplot
            sb.scatterplot(x=x, y=y, c=color, s=size, cmap=cmap_dot, **kwargs)

            # Set title and axis labels
            title = 'Top Terms'
            ax.set_title(title, fontweight='bold')
            ax.set_ylabel('')
            ax.set_xlabel('')
            plt.yticks(fontsize=12)
            plt.xticks(fontsize=12)

            # Create size legend
            min_size = size.min()
            max_size = size.max()
            size_ticks = np.linspace(round(size.max(), 0), round(size.min(), 0), 5)[::1].astype(int)
            l1 = plt.scatter([], [], s=round_to_5(size_ticks[0]), color='gray', edgecolor='black', linewidth=0.75)
            l2 = plt.scatter([], [], s=round_to_5(size_ticks[1]), color='gray', edgecolor='black', linewidth=0.75)
            l3 = plt.scatter([], [], s=round_to_5(size_ticks[2]), color='gray', edgecolor='black', linewidth=0.75)
            l4 = plt.scatter([], [], s=round_to_5(size_ticks[3]), color='gray', edgecolor='black', linewidth=0.75)
            l5 = plt.scatter([], [], s=round_to_5(size_ticks[4]), color='gray', edgecolor='black', linewidth=0.75)
            labels = [str(int(i)) for i in size_ticks]

            # Create size legend
            leg = ax.legend([l1, l2, l3, l4, l5],
                            labels,
                            ncol=1,
                            frameon=False,
                            fontsize=12,
                            loc='lower left',
                            bbox_to_anchor=(1.1, -0.50),
                            handlelength=0.75,
                            handleheight=1.2,
                            title='Gene Ratio',
                            scatterpoints=1,
                            facecolor='black')

            # Customize legend title position
            ax.get_legend().get_title().set_rotation(90)
            ax.get_legend().get_title().set_position((110, -350))

            # Add colorbar
            if colorbar:
                min_tick = round(min(color) + 0.5, 0)
                max_tick = round(max(color) - 0.5, 0)
                ticks = range(int(min_tick), int(max_tick) + 1, int(round((max_tick - min_tick) / 4, 0)))
                scalarmappaple = cm.ScalarMappable(norm=normalize, cmap=cmap_dot)
                scalarmappaple.set_array(color)
                bar = fig.colorbar(scalarmappaple, ax=ax, shrink=0.25, pad=0.1, aspect=10, ticks=ticks, location='right', fraction=0.15)
                bar.set_ticklabels([str(tick) + '  ' for tick in list(ticks)], fontsize=12)
                bar.set_label(label='-log$_{10}$(Adjusted p-Value)')

            # Adjust the left margin to increase white space
            plt.subplots_adjust(left=0.65)  # Increase or decrease the value as needed

            plt.show() 


        pd.set_option('display.max_colwidth', None)
        display(pd.concat([up_enrich_select.iloc[:,[0,2,16,19,8,14,5]],down_enrich_select.iloc[:,[0,2,16,19,8,14,5]]]))
    
    
    
    # Filter terms by key words
    down_enrich_select = down_enrich[down_enrich['name'].str.contains(selection_string, na=False, case=False)]
    up_enrich_select = up_enrich[up_enrich['name'].str.contains(selection_string, na=False, case=False)]

    # Mark cluster
    up_enrich_select['cluster'] = results_dict['groupby_categories'][0]
    down_enrich_select['cluster'] = results_dict['groupby_categories'][1]

    # Sort data frames by p-value
    up_enrich_select['sort'] = -np.log10(up_enrich_select.loc[:,'p_value'])
    up_enrich_select = up_enrich_select.sort_values(by=['sort'], ascending=False)

    down_enrich_select['sort'] = np.log10(down_enrich_select.loc[:,'p_value']) #reverse order
    down_enrich_select = down_enrich_select.sort_values(by=['sort'], ascending=True)

    # Join data frames
    joined_enrich_select = pd.concat([up_enrich_select.iloc[0:n_select,:],down_enrich_select.iloc[0:n_select,:]])

    # Calc gene ratio
    joined_enrich_select['ratio'] = joined_enrich_select.loc[:,'intersection_size']/joined_enrich_select.loc[:,'query_size']

    # Change order of terms in plot
    joined_enrich_select.index = [n for n in range(0,len(joined_enrich_select.index),1)]

    for term in joined_enrich_select['term'][joined_enrich_select['term'].duplicated()]:
        joined_enrich_select.loc[joined_enrich_select['term']==term,'sort'] = joined_enrich_select.iloc[abs(joined_enrich_select.loc[joined_enrich_select['term']==term,'sort']).idxmax(),:]['sort']

    joined_enrich_select = joined_enrich_select.sort_values(by=['sort'], ascending=True)
    joined_enrich_select['term'] = pd.Categorical(joined_enrich_select['term'], categories=joined_enrich_select['term'].drop_duplicates())
    
    if plot_select:
        # Plot
        if joined_enrich_select.size > 0:
            n_terms = len(joined_enrich_select.drop_duplicates(subset=['term']))
                        # Define parameters
            cmap_dot = 'flare'
            width = 9
            height = len(joined_enrich_select) / 5
            fig_size = (width, height)
            x_ext = 0.65
            y_ext = 1
            y_ext_factor = 100
            size_factor = 18
            colorbar = True

            # Create the figure and axes
            fig, ax = plt.subplots(1, figsize=fig_size)

            # Set values for scatterplot
            x = joined_enrich_select['cluster']
            y = [name[0].upper() + name[1:] for name in joined_enrich_select['name']]
            size = joined_enrich_select['ratio'] * len(joined_enrich_select) * size_factor
            color = joined_enrich_select['-log_adj_p']

            # Adjust plot limits and normalize colormap
            x_max = 1 + x_ext
            x_min = 0 - x_ext
            ax.set_xlim((x_min, x_max))
            y_max = len(x)
            y_min = 0 - y_ext
            ax.set_ylim((y_min, y_max))
            vmin, vmax = color.min(), color.max()
            vcenter = (vmin + vmax) / 2
            normalize = mcolors.TwoSlopeNorm(vcenter=vcenter, vmin=vmin, vmax=vmax)

            # Set scatterplot properties
            kwargs = {'edgecolor': "black", 'linewidth': 0.75, 'linestyle': '-'}

            # Create scatterplot
            sb.scatterplot(x=x, y=y, c=color, s=size, cmap=cmap_dot, **kwargs)

            # Set title and axis labels
            title = 'Selected Terms'
            ax.set_title(title, fontweight='bold')
            ax.set_ylabel('')
            ax.set_xlabel('')
            plt.yticks(fontsize=12)
            plt.xticks(fontsize=12)

            # Create size legend
            min_size = size.min()
            max_size = size.max()
            size_ticks = np.linspace(round(size.max(), 0), round(size.min(), 0), 5)[::1].astype(int)
            l1 = plt.scatter([], [], s=round_to_5(size_ticks[0]), color='gray', edgecolor='black', linewidth=0.75)
            l2 = plt.scatter([], [], s=round_to_5(size_ticks[1]), color='gray', edgecolor='black', linewidth=0.75)
            l3 = plt.scatter([], [], s=round_to_5(size_ticks[2]), color='gray', edgecolor='black', linewidth=0.75)
            l4 = plt.scatter([], [], s=round_to_5(size_ticks[3]), color='gray', edgecolor='black', linewidth=0.75)
            l5 = plt.scatter([], [], s=round_to_5(size_ticks[4]), color='gray', edgecolor='black', linewidth=0.75)
            labels = [str(int(i)) for i in size_ticks]

            # Create size legend
            leg = ax.legend([l1, l2, l3, l4, l5],
                            labels,
                            ncol=1,
                            frameon=False,
                            fontsize=12,
                            loc='lower left',
                            bbox_to_anchor=(1.1, -0.50),
                            handlelength=0.75,
                            handleheight=1.2,
                            title='Gene Ratio',
                            scatterpoints=1,
                            facecolor='black')

            # Customize legend title position
            ax.get_legend().get_title().set_rotation(90)
            ax.get_legend().get_title().set_position((110, -350))

            # Add colorbar
            if colorbar:
                min_tick = round(min(color) + 0.5, 0)
                max_tick = round(max(color) - 0.5, 0)
                ticks = range(int(min_tick), int(max_tick) + 1, int(round((max_tick - min_tick) / 4, 0)))
                scalarmappaple = cm.ScalarMappable(norm=normalize, cmap=cmap_dot)
                scalarmappaple.set_array(color)
                bar = fig.colorbar(scalarmappaple, ax=ax, shrink=0.25, pad=0.1, aspect=10, ticks=ticks, location='right', fraction=0.15)
                bar.set_ticklabels([str(tick) + '  ' for tick in list(ticks)], fontsize=12)
                bar.set_label(label='-log$_{10}$(Adjusted p-Value)')

            # Adjust the left margin to increase white space
            plt.subplots_adjust(left=0.65)  # Increase or decrease the value as needed

            plt.show() 

        display(pd.concat([up_enrich_select.iloc[:,[0,2,16,19,8,14,5]],down_enrich_select.iloc[:,[0,2,16,19,8,14,5]]]))
    
    # add enrichment to results
    enrichment_results = dict()
    enrichment_results['term_selection_string'] = selection_string
    enrichment_results['organism_enrichment_analysis'] = organism
    enrichment_results['enriched_terms_' + results_dict['groupby_categories'][0]] = up_enrich
    enrichment_results['enriched_terms_' + results_dict['groupby_categories'][1]] = down_enrich
    enrichment_results['selected_terms_' + results_dict['groupby_categories'][0]] = up_enrich_select
    enrichment_results['selected_terms_' + results_dict['groupby_categories'][1]] = down_enrich_select
    results_dict['enrichment_analysis'] = enrichment_results
    
    return results_dict



##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################

def export_results_edger(results_dict=None, path=None, file_base=None):
    """
    Write a differential-expression results dictionary to an Excel workbook.

    The output file name is assembled as
    ``path + <YYYY-MM-DD> + '_' + file_base + <method> + '_results_' +
    <restrict_to> + '_' + <category 0> + '-vs-' + <category 1> + '.xlsx'``
    and every space in the final string (directory part included) is
    replaced by ``-``.

    Sheets written, in this order:

    * ``Parameters`` - string representation of the run parameters found in
      ``results_dict`` (plus the enrichment parameters when present).
    * one sheet per entry of ``results_dict['groupby_categories']`` holding
      ``results_dict[<category>]`` sorted by ``logfc`` in descending order.
    * when ``results_dict`` has an ``'enrichment_analysis'`` entry: one sheet
      for every key of that dictionary after the first two.

    Parameters
    ----------
    results_dict : dict
        Results dictionary; must provide ``'method'``, ``'restrict_to'`` and
        ``'groupby_categories'`` as well as one DataFrame per category.
    path : str
        Prefix (typically a directory ending with a separator) of the output file.
    file_base : str
        Text inserted between the date stamp and the method name.

    Returns
    -------
    None
    """
    import datetime

    categories = results_dict['groupby_categories']
    date = datetime.datetime.now()

    method = results_dict['method'] + '_results_'
    suffix = results_dict['restrict_to'] + '_' + categories[0] + '-vs-' + categories[1]

    path = path + date.strftime('%Y-%m-%d') + '_' + file_base + method + suffix + '.xlsx'
    path = path.replace(' ','-')

    # Run parameters that are exported (only those present in results_dict)
    parameter_keys = ('method', 'cmd', 'groupby', 'groupby_categories', 'groups_restrict', 'groups_restrict_categories', 'restrict_to', 'layer', 'min_cluster_size', 'min_frac_cells', 'ambient_genes_removed', 'ambient_genes_kept', 'background_genes', 'n_genes', 'n_cells')
    parameter_df = pd.DataFrame.from_dict(
        {k: str(results_dict[k]) for k in parameter_keys if k in results_dict},
        orient='index',
    )

    has_enrichment = 'enrichment_analysis' in results_dict
    if has_enrichment:
        enrichment = results_dict['enrichment_analysis']
        enrichment_parameter_keys = ('term_selection_string', 'organism_enrichment_analysis')
        enrichment_parameter_df = pd.DataFrame.from_dict(
            {k: str(enrichment[k]) for k in enrichment_parameter_keys if k in enrichment},
            orient='index',
        )
        parameter_df = pd.concat([parameter_df, enrichment_parameter_df])

    # writing to Excel
    print('Exporting results to', path)

    # The context manager closes the writer even if writing one of the sheets fails
    with pd.ExcelWriter(path) as excel:

        # print parameters
        print('\tWriting parameters...')
        parameter_df.to_excel(excel, sheet_name='Parameters', index=True, header=False)

        for key in categories:
            print('\tWriting', key, '...')

            # replace invalid characters
            sheet_name = key.replace(' ','-')

            # write DataFrame to excel
            results_dict[key].sort_values(by=['logfc'], ascending=False).to_excel(excel, sheet_name=sheet_name, index=False, freeze_panes=(1,1))

        if has_enrichment:
            # The first two entries are skipped positionally; this relies on the
            # insertion order of the enrichment dictionary.
            # NOTE(review): presumably these are the two scalar parameters
            # exported above - confirm against the code that builds the dictionary.
            for key in list(enrichment.keys())[2:]:
                print('\tWriting', key, '...')

                # replace invalid characters
                sheet_name = key.replace(' ','-')

                # write DataFrame to excel
                enrichment[key].to_excel(excel, sheet_name=sheet_name, index=False, freeze_panes=(1,1))



##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################

def color_variant(hex_color, brightness_offset=90):
    """
    Return a lighter or darker variant of a color given as ``#rrggbb``.

    Adapted from
    https://chase-seibert.github.io/blog/2011/07/29/python-calculate-lighterdarker-rgb-colors.html

    Parameters
    ----------
    hex_color : str
        Color in 7-character hex notation, e.g. ``'#87c95f'``.
    brightness_offset : int
        Value added to each of the three channels; positive values lighten,
        negative values darken. Each channel is clamped to 0..255.

    Returns
    -------
    str
        The shifted color as a lowercase ``#rrggbb`` string.

    Raises
    ------
    ValueError
        If ``hex_color`` is not exactly 7 characters long.
    """
    if len(hex_color) != 7:
        raise ValueError("Passed %s into color_variant(), needs to be in #87c95f format." % hex_color)

    channels = [int(hex_color[start:start + 2], 16) + brightness_offset for start in (1, 3, 5)]
    # make sure new values are between 0 and 255
    channels = [min(255, max(0, value)) for value in channels]

    # Always emit two hex digits per channel; a plain hex() drops the leading
    # zero for values below 16 and would yield a malformed color string.
    return "#" + "".join(format(value, "02x") for value in channels)





##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################

def vulcano_plot_edger(results_dict=None, genes=None, annotate_top=True, n_top=10, title=None, min_logfc = 0.5, max_pval = 10**-2, group_order = (0,1), y_max_ext_factor=1.2, x_ext_factor=0.3, x_max_ext_factor=1.2, x_min_ext_factor=1.2, fig_size=(7,6), save=None):
    """
    Draw a volcano plot (log fold change vs. -log10 adjusted p-value).

    The table ``results_dict[results_dict['groupby_categories'][group_order[0]]]``
    is plotted. It must provide the columns ``names``, ``logfc``,
    ``logfc_limit``, ``pvals_adj``, ``log_pvals_adj`` and ``color``.

    Side effect: values of ``log_pvals_adj`` above 300 are capped to 300
    directly in the table stored in ``results_dict``.

    Parameters
    ----------
    results_dict : dict
        Results dictionary with ``'groupby_categories'``, ``'restrict_to'``
        and one results DataFrame per category.
    genes : list of str, optional
        Additional genes to label. Only genes passing the ``min_logfc`` /
        ``max_pval`` thresholds get a label.
    annotate_top : bool
        Currently not used by the function body.
    n_top : int
        Number of gene names taken from the top and from the bottom of the
        table as label candidates (names starting with ``'ENSSSC'`` are skipped).
        NOTE(review): this presumably expects the table to be sorted by fold
        change - confirm against the function producing the table.
    title : str, optional
        Plot title; a default title is generated when ``None``.
    min_logfc, max_pval : float
        Thresholds for highlighted genes and positions of the dashed lines.
    group_order : tuple of int
        Indices into ``'groupby_categories'``: (plotted/up group, down group).
    y_max_ext_factor, x_max_ext_factor, x_min_ext_factor : float
        Factors applied to the data maximum/minimum to get the axis limits.
    x_ext_factor : float
        Currently without effect on the final axis limits; kept for
        backward compatibility.
    fig_size : tuple
        Figure size passed to ``plt.subplots``.
    save : str, optional
        If given, the figure is saved to this path.
    """
    categories = results_dict['groupby_categories']
    group_up = categories[group_order[0]]
    group_down = categories[group_order[1]]
    results = results_dict[group_up]

    # Numbers shown in the corner annotations
    below_pval = results['pvals_adj'].abs() < max_pval
    n_up = int((below_pval & (results['logfc'] > min_logfc)).sum())
    n_down = int((below_pval & (results['logfc'] < -min_logfc)).sum())

    # Label candidates: user-supplied genes plus the first/last n_top named genes
    named_genes = [gene for gene in results['names'] if not gene.startswith('ENSSSC')]
    candidates = ([] if genes is None else list(genes)) + named_genes[:n_top] + named_genes[::-1][:n_top]
    # De-duplicate while keeping the order, so the label order is reproducible
    candidates = list(dict.fromkeys(candidates))

    # Genes passing both thresholds (these are highlighted and may be labelled)
    highlighted = (results['logfc'].abs() >= min_logfc) & (results['pvals_adj'] <= max_pval)
    names_up = set(results.loc[highlighted & (results['logfc'] > 0), 'names'])
    names_down = set(results.loc[highlighted & (results['logfc'] < 0), 'names'])
    genes_up = [gene for gene in candidates if gene in names_up]
    genes_down = [gene for gene in candidates if gene in names_down]

    # Cap extreme values. A single .loc assignment is used because chained
    # assignment (df[col][mask] = ...) does not update the DataFrame when
    # pandas Copy-on-Write is active.
    results.loc[results['log_pvals_adj'] > 300, 'log_pvals_adj'] = 300

    fig, ax = plt.subplots(1, figsize=fig_size)

    # Make x & y axis longer to make gene name plotting easier
    y_max = max(results['log_pvals_adj']) * y_max_ext_factor
    x_max = max(results['logfc_limit']) * x_max_ext_factor
    x_min = min(results['logfc_limit']) * x_min_ext_factor

    ax.set_ylim((-1, y_max))
    ax.set_xlim((x_min, x_max))

    # Scatter plot: all genes as grey dots on top of slightly larger black dots
    sb.scatterplot(y='log_pvals_adj', x='logfc_limit',
                   color='#000000', s=20,
                   linewidth=0,
                   data=results, ax=ax)
    sb.scatterplot(y='log_pvals_adj', x='logfc_limit',
                   color='#cccccc', s=10,
                   linewidth=0,
                   data=results, ax=ax)

    # Highlighted genes: white underlay first, then the semi-transparent color
    sig_y = results.loc[highlighted, 'log_pvals_adj']
    sig_x = results.loc[highlighted, 'logfc_limit']
    sig_c = results.loc[highlighted, 'color']

    sb.scatterplot(y=sig_y, x=sig_x, color='#ffffff', s=10, alpha=1,
                   linewidth=0,
                   ax=ax)
    sb.scatterplot(y=sig_y, x=sig_x, c=sig_c, s=10, alpha=0.5,
                   linewidth=0,
                   ax=ax)

    # annotation
    ax.annotate(group_down + '\n' + str(n_down) + ' genes', xy=(0.02, 0.98), xycoords='axes fraction',
                va="top", ha="left")
    ax.annotate(group_up + '\n' + str(n_up) + ' genes', xy=(0.98, 0.98), xycoords='axes fraction',
                va="top", ha="right")

    # Lines
    ax.axvline(min_logfc, 0, 1, color='#666666', lw=1, linestyle="--")
    ax.axvline(-min_logfc, 0, 1, color='#666666', lw=1, linestyle="--")
    ax.axhline(-np.log10(max_pval), 0, 1, color='#666666', lw=1, linestyle="--")

    # Labels
    x_lim = ax.get_xlim()
    for gene_set, direction, ha in [(genes_down, -1, 'right'), (genes_up, 1, 'left')]:
        labels = []
        for gene in gene_set:
            if gene.startswith('ENSSSC'):
                continue
            row = results.loc[results['names'] == gene]
            # Take the first matching row explicitly; float() on a Series is
            # deprecated and fails when a gene name occurs more than once.
            label_x = float(row['logfc_limit'].iloc[0])
            label_y = float(row['log_pvals_adj'].iloc[0])
            labels.append(ax.text(label_x, label_y, gene, color='#000000', fontsize=8))
        # Temporarily restrict the x-range to one half of the plot
        # ([x_min, -0.5] or [0.5, x_max]) while the labels are arranged
        ax.set_xlim(sorted([x_lim[::direction][1], 0.5*direction]))
        adjust_text(labels, expand_points=(1.5,1.5), expand_text=(2,2), expand_objects=(2,2), force_text=(0.75, 0.5), force_points=(0.75, 1), force_objects=(1, 0.5), ha=ha, precision=0.00001, lim=5000, autoalign='y', arrowprops=dict(arrowstyle="-",  color='k',  lw=0.5), ax=ax)
    ax.set_xlim(x_lim)

    # title & axis labels
    if title is None:
        title = 'Differential Gene Expression in ' + results_dict['restrict_to'] + ' Cells\n' + group_up + ' vs ' + group_down
    ax.set_title(title)
    ax.set_ylabel('$-log_{10}$ Adjusted p-Value')
    ax.set_xlabel('$log_2$ Fold Change')

    if save is not None:
        fig.savefig(save)



##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################

def ma_plot_edger(results_dict=None, genes=None, annotate_top=True, n_top=10, title=None, min_logfc = 0.5, max_pval = 10**-2, group_order = (0,1), x_max_ext_factor=1.2, y_ext_factor=0.3, y_max_ext_factor=1.25, y_min_ext_factor=1.25, fig_size=(7,6), colorbar=True):
    """
    Draw an MA plot (expression on x, log fold change on y).

    The table ``results_dict[results_dict['groupby_categories'][group_order[0]]]``
    is plotted. It must provide the columns ``names``, ``logfc``,
    ``logfc_limit``, ``logexprs``, ``pvals_adj``, ``log_pvals_adj`` and ``color``.

    Side effect: values of ``log_pvals_adj`` above 300 are capped to 300
    directly in the table stored in ``results_dict``.

    Parameters
    ----------
    results_dict : dict
        Results dictionary with ``'groupby_categories'``, ``'restrict_to'``
        and one results DataFrame per category.
    genes : list of str, optional
        Additional genes to label. Only genes passing the ``min_logfc`` /
        ``max_pval`` thresholds get a label.
    annotate_top, colorbar : bool
        Currently not used by the function body.
    n_top : int
        Number of gene names taken from the top and from the bottom of the
        table as label candidates (names starting with ``'ENSSSC'`` are skipped).
        NOTE(review): this presumably expects the table to be sorted by fold
        change - confirm against the function producing the table.
    title : str, optional
        Plot title; a default title is generated when ``None``.
    min_logfc, max_pval : float
        Thresholds for counted/highlighted genes; ``min_logfc`` also sets the
        position of the dashed horizontal lines.
    group_order : tuple of int
        Indices into ``'groupby_categories'``: (plotted/up group, down group).
    y_max_ext_factor, y_min_ext_factor : float
        Factors applied to the data maximum/minimum to get the y-axis limits.
    x_max_ext_factor, y_ext_factor : float
        Currently without effect on the axis limits; kept for backward
        compatibility.
    fig_size : tuple
        Figure size passed to ``plt.subplots``.
    """
    categories = results_dict['groupby_categories']
    group_up = categories[group_order[0]]
    group_down = categories[group_order[1]]
    results = results_dict[group_up]

    # Numbers shown in the corner annotations. Significant genes have an
    # adjusted p-value BELOW the cut-off (same rule as vulcano_plot_edger);
    # the comparison was previously inverted and counted the
    # non-significant genes instead.
    below_pval = results['pvals_adj'].abs() < max_pval
    n_up = int((below_pval & (results['logfc'] > min_logfc)).sum())
    n_down = int((below_pval & (results['logfc'] < -min_logfc)).sum())

    # Label candidates: user-supplied genes plus the first/last n_top named genes
    named_genes = [gene for gene in results['names'] if not gene.startswith('ENSSSC')]
    candidates = ([] if genes is None else list(genes)) + named_genes[:n_top] + named_genes[::-1][:n_top]
    # De-duplicate while keeping the order, so the label order is reproducible
    candidates = list(dict.fromkeys(candidates))

    # Genes passing both thresholds (these are highlighted and may be labelled)
    highlighted = (results['logfc'].abs() >= min_logfc) & (results['pvals_adj'] <= max_pval)
    names_up = set(results.loc[highlighted & (results['logfc'] > 0), 'names'])
    names_down = set(results.loc[highlighted & (results['logfc'] < 0), 'names'])
    genes_up = [gene for gene in candidates if gene in names_up]
    genes_down = [gene for gene in candidates if gene in names_down]

    # Cap extreme values. A single .loc assignment is used because chained
    # assignment (df[col][mask] = ...) does not update the DataFrame when
    # pandas Copy-on-Write is active.
    results.loc[results['log_pvals_adj'] > 300, 'log_pvals_adj'] = 300

    fig, ax = plt.subplots(1, figsize=fig_size)

    # Make y axis longer to make gene name plotting easier
    y_max = max(results['logfc_limit']) * y_max_ext_factor
    y_min = min(results['logfc_limit']) * y_min_ext_factor

    ax.set_ylim((y_min, y_max))

    # Scatter plot: all genes as grey dots on top of slightly larger black dots
    sb.scatterplot(y='logfc_limit', x='logexprs',
                   color='#000000', s=20,
                   linewidth=0,
                   data=results, ax=ax)
    sb.scatterplot(y='logfc_limit', x='logexprs',
                   color='#cccccc', s=10,
                   linewidth=0,
                   data=results, ax=ax)

    # Highlighted genes: white underlay first, then the semi-transparent color
    sig_y = results.loc[highlighted, 'logfc_limit']
    sig_x = results.loc[highlighted, 'logexprs']
    sig_c = results.loc[highlighted, 'color']

    sb.scatterplot(y=sig_y, x=sig_x, color='#ffffff', s=10, alpha=1,
                   linewidth=0,
                   ax=ax)
    sb.scatterplot(y=sig_y, x=sig_x, c=sig_c, s=10, alpha=0.5,
                   linewidth=0,
                   ax=ax)

    # annotation
    ax.annotate(group_down + '\n' + str(n_down) + ' genes', xy=(0.98, 0.02), xycoords='axes fraction',
                va="bottom", ha="right")
    ax.annotate(group_up + '\n' + str(n_up) + ' genes', xy=(0.98, 0.98), xycoords='axes fraction',
                va="top", ha="right")

    # Lines
    ax.axhline(min_logfc, 0, 1, color='#666666', lw=1, linestyle="--")
    ax.axhline(-min_logfc, 0, 1, color='#666666', lw=1, linestyle="--")

    # Labels
    y_lim = ax.get_ylim()
    for gene_set, direction, va in [(genes_down, -1, 'top'), (genes_up, 1, 'bottom')]:
        labels = []
        for gene in gene_set:
            if gene.startswith('ENSSSC'):
                continue
            row = results.loc[results['names'] == gene]
            # Take the first matching row explicitly; float() on a Series is
            # deprecated and fails when a gene name occurs more than once.
            label_x = float(row['logexprs'].iloc[0])
            label_y = float(row['logfc_limit'].iloc[0])
            labels.append(ax.text(label_x, label_y, gene, color='#000000', fontsize=8))
        # Temporarily restrict the y-range to one half of the plot
        # ([y_min, -0.5] or [0.5, y_max]) while the labels are arranged
        ax.set_ylim(sorted([y_lim[::direction][1], 0.5*direction]))
        adjust_text(labels, expand_points=(1.5,1.5), expand_text=(2,2), expand_objects=(2,2), force_text=(0.75, 0.5), force_points=(0.75, 1), force_objects=(1, 0.5), va=va, precision=0.00001, lim=5000, autoalign='x', arrowprops=dict(arrowstyle="-",  color='k',  lw=0.5), ax=ax)
    ax.set_ylim(y_lim)

    # title & axis labels
    if title is None:
        title = 'Differential Gene Expression in ' + results_dict['restrict_to'] + ' Cells\n' + group_up + ' vs ' + group_down
    ax.set_title(title)
    # NOTE(review): the label reads "-log2 Expression"; confirm the sign
    # against how the 'logexprs' column is computed.
    ax.set_xlabel('$-log_{2}$ Expression')
    ax.set_ylabel('$log_2$ Fold Change')



##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################

def generate_pseudobulk(
    adata,
    groupby, # groups/conditions to test (e.g stage, genotype, ...)
    sample_key, # key for samples/replicates
    identity, # cluster/cell type to test in
    identity_key, # key for clusters/cell types
    obs_to_keep=None,  # which metadata to keep, e.g. gender, age, etc.
    replicates_per_sample=3, # number of pseudoreplicates/sample
    min_cell_per_sample=30,
    aggr_method='sum'
):
    """
    Aggregate the cells of one cluster/cell type into pseudobulk replicates.

    The cells with ``adata.obs[identity_key] == identity`` are selected. For
    every sample (category of ``sample_key``) with more than
    ``min_cell_per_sample`` cells, the cells are shuffled with a fixed seed
    (the global ``random`` seed is reset to 12345 for every sample), split
    into ``replicates_per_sample`` chunks and the expression values in ``.X``
    of each chunk are aggregated per gene with ``aggr_method``.

    Parameters
    ----------
    adata : AnnData
        Input object; ``.X`` is the matrix that gets aggregated.
    groupby : str
        Not used inside this function. To carry the condition over to the
        result, include it in ``obs_to_keep``.
    sample_key : str
        Column in ``adata.obs`` identifying samples/replicates. The per-replicate
        aggregation groups by this column, which only reaches the intermediate
        table through ``obs_to_keep`` - so ``sample_key`` has to be listed there.
    identity : str
        Cluster/cell type to aggregate.
    identity_key : str
        Column in ``adata.obs`` holding the cluster/cell type labels.
    obs_to_keep : list of str, optional
        Columns of ``adata.obs`` to carry over (first value per replicate).
    replicates_per_sample : int
        Number of pseudo-replicates generated per sample.
    min_cell_per_sample : int
        Samples with this many cells or fewer are dropped.
    aggr_method : str
        Aggregation applied to every gene column (passed to pandas ``agg``).

    Returns
    -------
    AnnData
        One observation per pseudo-replicate, named ``sample_<sample>_<n>``;
        ``.obs`` holds the non-gene columns.
    """
    # TO DO:
    # * Plot/show different samples/group and number of cells in each pseudo-replicate
    import random

    obs_to_keep = [] if obs_to_keep is None else list(obs_to_keep)

    # subset adata to the given cell identity
    adata_identity = adata[adata.obs[identity_key] == identity].copy()

    # check which samples to keep according to the number of cells specified with min_cell_per_sample
    size_by_sample = adata_identity.obs.groupby([sample_key]).size()
    samples_to_drop = [sample for sample in size_by_sample.index if size_by_sample[sample] <= min_cell_per_sample]
    if len(samples_to_drop) > 0:
        print("Dropping the following samples:")
        print(samples_to_drop)

    df = pd.DataFrame(columns=[*adata_identity.var_names, *obs_to_keep])

    # specify how to aggregate: aggr_method for each gene, first value for the
    # metadata columns (identical for every replicate, so built only once)
    agg_dict = {gene: aggr_method for gene in adata_identity.var_names}
    for obs in obs_to_keep:
        agg_dict[obs] = "first"

    adata_identity.obs[sample_key] = adata_identity.obs[sample_key].astype("category")
    samples = adata_identity.obs[sample_key].cat.categories
    for sample_no, sample in enumerate(samples):
        print(f"\tProcessing sample {sample_no+1} out of {len(samples)}...", end="\r")
        if sample in samples_to_drop:
            continue

        adata_sample = adata_identity[adata_identity.obs[sample_key] == sample]

        # create replicates for each sample
        indices = list(adata_sample.obs_names)
        random.seed(12345)
        random.shuffle(indices)
        replicate_indices = np.array_split(np.array(indices), replicates_per_sample)

        for rep_no, rep_idx in enumerate(replicate_indices):
            adata_replicate = adata_sample[rep_idx]

            # Densify the expression matrix. ``.toarray()`` is used when
            # available; otherwise the data is taken as an array directly.
            # (The ``.A`` shortcut does not exist on plain ndarrays.)
            counts = adata_replicate.X
            counts = counts.toarray() if hasattr(counts, 'toarray') else np.asarray(counts)

            # create a df with all genes, sample and group info
            df_sample = pd.DataFrame(counts)
            df_sample.index = adata_replicate.obs_names
            df_sample.columns = adata_replicate.var_names
            df_sample = df_sample.join(adata_replicate.obs[obs_to_keep])

            # aggregate
            df_sample = df_sample.groupby(sample_key).agg(agg_dict)
            df_sample[sample_key] = sample
            df.loc[f"sample_{sample}_{rep_no}"] = df_sample.loc[sample]
    print("\n")

    # create AnnData object from the df
    adata_pseudobulk = sc.AnnData(
        df[adata_identity.var_names], obs=df.drop(columns=adata_identity.var_names)
    )
    return adata_pseudobulk




##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################

def get_diff_exprs_edgeR(
    adata=None, 
    groupby=None, # groups/conditions to test (e.g. stage, genotype, ...)
    groups_restrict=None, # obs key holding the cell types / clusters
    restrict_to=None, # cell type / cluster the test is restricted to, e.g. Beta
    layer='raw_counts', 
    filter_ambient_genes=True, 
    rank_genes_groups_key=None, # rank_genes_groups key with markers for groups_restrict
    get_marker=False, # run rank_genes_groups to identify markers
    min_gene_score=0, # min score a cluster marker should have to be cluster-specific
    min_cluster_size = 100, 
    min_frac_cells = 0.05,
    sample_key=None, # key for samples/replicates
    additional_variables=None,  # which metadata to keep, e.g. gender, age, etc.
    replicates_per_sample=3, # number of pseudoreplicates/sample
    min_cell_per_sample=30,
    aggr_method='sum',
    plot=True,
    return_results='dict' # or 'top_table'
):
    """Pseudobulk differential expression between two conditions with edgeR.

    The function copies ``adata``, optionally restricts it to one cluster,
    filters lowly expressed and ambient genes, builds pseudobulks with
    ``generate_pseudobulk``, runs ``run_edgeR`` and (for
    ``return_results='dict'``) post-processes the table with
    ``edgeR_to_results``.

    Only the first two categories of ``adata.obs[groupby]`` (after
    restricting to ``restrict_to``) are contrasted against each other.

    Parameters
    ----------
    adata
        AnnData object; it is copied and never modified.
    groupby
        Categorical ``obs`` column with the conditions to compare.
    groups_restrict
        ``obs`` column with the clusters / cell types.
    restrict_to
        Category of ``groups_restrict`` to test in. ``None`` uses all cells.
        NOTE(review): the unrestricted path also requires
        ``generate_pseudobulk`` to accept ``identity=None`` - confirm.
    layer
        Layer copied to ``.X`` before the analysis; ``None`` keeps ``.X``.
    filter_ambient_genes
        Drop genes flagged in ``adata.var['is_ambient']``.
    rank_genes_groups_key
        Key in ``adata.uns`` with ``rank_genes_groups`` results. When given,
        ambient genes that are markers of ``restrict_to`` (score above
        ``min_gene_score``) are kept.
    get_marker
        Run ``sc.tl.rank_genes_groups`` on ``groups_restrict`` when no
        ``rank_genes_groups_key`` is supplied.
    min_gene_score
        Minimal marker score for an ambient gene to be kept.
    min_cluster_size
        Minimal number of cells ``restrict_to`` must contain.
    min_frac_cells
        Genes expressed in fewer than this fraction of the tested cells are
        removed.
    sample_key
        ``obs`` column identifying samples / replicates.
    additional_variables
        Further ``obs`` columns that are kept in the pseudobulks and handed
        to ``run_edgeR``. ``None`` means no additional variables.
    replicates_per_sample, min_cell_per_sample, aggr_method
        Forwarded to ``generate_pseudobulk``.
    plot
        Forwarded to ``edgeR_to_results``.
    return_results
        ``'dict'`` returns the results dictionary, ``'top_table'`` the raw
        edgeR table.

    Raises
    ------
    ValueError
        If ``return_results`` is not a supported value, if ``restrict_to``
        has fewer than ``min_cluster_size`` cells, or if fewer than two
        ``groupby`` categories are available.
    """
    # fail before any expensive work instead of silently returning None at the end
    if return_results not in ('dict', 'top_table'):
        raise ValueError("return_results must be 'dict' or 'top_table', got " + repr(return_results))

    # never share (or mutate) a default list between calls
    additional_variables = [] if additional_variables is None else list(additional_variables)

    # copy adata
    adata_temp = adata.copy()
    
    # set selected layer to .X
    if layer is not None:
        adata_temp.X = adata_temp.layers[layer].copy()

    # create results dict and add parameters
    if groups_restrict is None:
        groups_restrict_categories = []
    else:
        groups_restrict_categories = list(adata_temp.obs[groups_restrict].cat.categories)
    results = {
        'method': 'edgeR_pseudobulks',
        'groupby': groupby,
        'groupby_categories': [],
        'groups_restrict': groups_restrict,
        'groups_restrict_categories': groups_restrict_categories,
        'restrict_to': restrict_to,
        'layer': layer,
        'min_cluster_size': min_cluster_size,
        'min_frac_cells': min_frac_cells,
    }
    
    # check if cluster of interest (restrict_to) has enough cells
    if restrict_to is not None:
        if adata_temp.obs[groups_restrict].value_counts()[restrict_to] < min_cluster_size:
            raise ValueError('Group has less than ' + str(min_cluster_size) + ' cells.') 
    
    # no marker key provided -> optionally compute markers for groups_restrict
    if rank_genes_groups_key is None and get_marker:
        sc.tl.rank_genes_groups(adata_temp, groupby=groups_restrict)
        rank_genes_groups_key = 'rank_genes_groups'
    
    # subset adata to group provided in restrict_to
    if restrict_to is None:
        adata_temp_test = adata_temp.copy()
    else:
        adata_temp_test = adata_temp[adata_temp.obs[groups_restrict].isin([restrict_to])].copy()
    
    groupby_categories = list(adata_temp_test.obs[groupby].cat.categories)
    if len(groupby_categories) < 2:
        raise ValueError('Need at least two categories in "' + str(groupby) + '" to test, found: ' + str(groupby_categories))
    results['groupby_categories'] = groupby_categories
    
    # fall back to the default colours of edgeR_to_results when none are stored
    default_colors = ['#1f77b4', '#ff7f0e']
    groupby_colors = list(adata_temp_test.uns.get(groupby + '_colors', default_colors))
    if len(groupby_colors) < 2:
        groupby_colors = default_colors
    results['groupby_colors'] = groupby_colors
    
    # filter genes expressed in few cells
    sc.pp.filter_genes(adata_temp_test, min_cells=adata_temp_test.shape[0]*min_frac_cells)
    
    # filter ambient genes
    if filter_ambient_genes:
        # `== True` on purpose: also works when the column is not of bool dtype
        is_ambient = (adata_temp.var['is_ambient'] == True).to_numpy()
        ambi_genes = list(adata_temp.var_names[is_ambient])
        # cluster-specific markers can only be spared when a cluster is given
        keep_markers = (rank_genes_groups_key is not None) and (restrict_to is not None)
        if keep_markers:
            ranking = adata_temp.uns[rank_genes_groups_key]
            marker_genes = list(ranking['names'][restrict_to][ranking['scores'][restrict_to] > min_gene_score])
            ambi_genes_remove = list(set(ambi_genes).difference(set(marker_genes)))
        else:
            ambi_genes_remove = ambi_genes
        adata_temp_test = adata_temp_test[:,~adata_temp_test.var_names.isin(ambi_genes_remove)]
        print('\nRemoving ambient genes from analysis: ', ambi_genes_remove)
        results['ambient_genes_removed'] = ambi_genes_remove
        if keep_markers:
            ambi_genes_kept = set(ambi_genes).difference(set(ambi_genes_remove))
            print('\nKeeping group-specific ambient genes: ', ambi_genes_kept,'\n')
            results['ambient_genes_kept'] = list(ambi_genes_kept)
    
    results['background_genes'] = list(adata_temp_test.var_names)
    
    # generate pseudobulk
    print('\nGenerating pseudobulks...')
    obs_to_keep = [key for key in [groupby, groups_restrict, sample_key] + additional_variables if key is not None]
    adata_temp_test_bulk = generate_pseudobulk(
        adata=adata_temp_test,
        groupby=groupby, # groups/conditions to test (e.g. stage, genotype, ...)
        sample_key=sample_key, # key for samples/replicates
        identity=restrict_to, # cluster/cell type to test in
        identity_key=groups_restrict, # key for clusters/cell types
        obs_to_keep=obs_to_keep,  # which metadata to keep, e.g. gender, age, etc.
        replicates_per_sample=replicates_per_sample, # number of pseudoreplicates/sample
        min_cell_per_sample=min_cell_per_sample,
        aggr_method=aggr_method
    )
    
    results['n_genes'] = adata_temp_test_bulk.shape[1]
    results['n_cells'] = adata_temp_test.shape[0]
    results['n_pseudobulks'] = adata_temp_test_bulk.shape[0]
    
    # run edgeR
    print('\nRunning edgeR...')
    top_table = run_edgeR(adata_temp_test_bulk, groupby=groupby, ident_1=groupby_categories[0], ident_2=groupby_categories[1], restrict_to=restrict_to, groups_restrict=groups_restrict, additional_variables=additional_variables)
    
    if return_results == 'dict':
        # convert results
        print('\nConverting results...')
        results = edgeR_to_results(top_table, 
                                   results_dict=results,
                                   ident_1=groupby_categories[0],
                                   ident_2=groupby_categories[1],
                                   ident_1_color=groupby_colors[0],
                                   ident_2_color=groupby_colors[1],
                                   plot=plot,
                                   plot_logfc_limit = 10,
                                   log_pvals_adj_limit = 300,
                                   z_logfc_cut_off=0.5,
                                   z_pval_cut_off=0.25)
    
    # release the (potentially large) temporary copies right away
    del adata_temp
    del adata_temp_test
    del adata_temp_test_bulk
    
    gc.collect()
    
    if return_results == 'dict':
        return results
    return top_table

##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################

def run_edgeR(adata, layer=None, groupby=None, ident_1=None, ident_2=None, restrict_to=None, groups_restrict=None, additional_variables=None, min_cells=None, n_core=64, max_memory=128):
    '''
    Test ``ident_1`` against ``ident_2`` with edgeR's quasi-likelihood F-test.

    adata: (pseudobulk) object that is assigned to the R variable ``adata``
        and read there via ``assay(adata, "X")`` and ``colData(adata)``.
        NOTE(review): this presumably relies on an AnnData -> 
        SingleCellExperiment converter (e.g. anndata2ri) being active - confirm.
    layer: accepted for signature compatibility, currently unused.
    groupby: colData column holding the conditions to compare.
    ident_1, ident_2: the two conditions; the contrast is ident_1 - ident_2.
    restrict_to: cluster the data was restricted to. If given, the group
        label is "<condition>.<cluster>", otherwise just the condition.
    groups_restrict: colData column holding the cluster labels (only read
        when restrict_to is given).
    additional_variables: colData columns added as covariates to the design.
    min_cells: accepted for signature compatibility, currently unused.
    n_core: workers registered for BiocParallel / future.
    max_memory: sets options(future.globals.maxSize) to max_memory * 1024^2 bytes.

    Returns the full edgeR ``topTags`` table (all genes) converted with the
    pandas2ri converter.

    Note: the R global environment is wiped (``rm(list = ls())``) when the
    function finishes, also when it fails.
    '''
    
    import gc
    import rpy2.robjects as ro
    # needed for the conversion of the result table below
    from rpy2.robjects import pandas2ri

    additional_variables = [] if additional_variables is None else [str(v) for v in additional_variables]
       
    print('DE analysis with edgeR:')
    try:
        # load packages
        ro.globalenv['n_core'] = n_core
        ro.globalenv['max_memory'] = max_memory
        ro.r('''
        # Packages
        library(edgeR)
        library(SingleCellExperiment)

        # Parallelization
        library(BiocParallel)
        register(MulticoreParam(n_core, progressbar = TRUE))

        library(future)
        plan("multicore", workers = n_core)
        options(future.globals.maxSize = max_memory * 1024^2)
        plan()

        # a previous call that failed half-way may have left these behind;
        # a stale restrict_to would switch on the per-cluster contrast below
        suppressWarnings(rm(list = c("restrict_to", "groups_restrict")))
        ''')
        # transfer data
        print('\tTransfer data...')
        ro.globalenv['adata'] = adata
        ro.globalenv['groupby'] = groupby
        ro.globalenv['ident_1'] = ident_1
        ro.globalenv['ident_2'] = ident_2
        # character vector (possibly of length 0) instead of a converted Python list / None
        ro.globalenv['additional_variables'] = ro.StrVector(additional_variables)
        if groups_restrict is not None:
            ro.globalenv['groups_restrict'] = groups_restrict
        if restrict_to is not None:
            ro.globalenv['restrict_to'] = restrict_to
            
         
        # perform analysis
        print('\tPerform differential gene expression analysis with edgeR...')
        ro.r('''
        # edgeR
        # replace characters
        colData(adata)[groupby] <- gsub("[ -]", "_", get(groupby,as.data.frame(colData(adata))))
        ident_1 <- gsub("[ -]", "_", ident_1)
        ident_2 <- gsub("[ -]", "_", ident_2)
        
        # create an edgeR object with counts and grouping factor
        y <- DGEList(assay(adata, "X"), group = get(groupby, as.data.frame(colData(adata))))
        
        # filter out genes with low counts
        keep <- filterByExpr(y)
        y <- y[keep, , keep.lib.sizes=FALSE]
        
        # normalize
        y <- calcNormFactors(y)
        
        # create a vector that is concatenation of condition and cell type that we will later use with contrasts
        if (exists('restrict_to')){
            group <- paste0(get(groupby, as.data.frame(colData(adata))), ".", get(groups_restrict, as.data.frame(colData(adata))))
        } else {
            group <- get(groupby, as.data.frame(colData(adata)))
        }
        
        
        # create a design matrix
        if (length(additional_variables) > 0){
            # the covariates only exist as colData columns, so they have to be
            # supplied via `data`; group comes first so that every group level
            # keeps its own column (needed for the contrast)
            design_data <- as.data.frame(colData(adata))
            design_data$group <- group
            additional_terms <- paste(unlist(additional_variables), collapse=" + ")
            design <- model.matrix(formula(paste("~ 0", "group", additional_terms, sep=" + ")), data = design_data)
        } else {
            design <- model.matrix(~ 0 + group)
        }
        # makeContrasts needs syntactically valid level names
        colnames(design) <- make.names(colnames(design))

        # estimate dispersion
        y <- estimateDisp(y, design = design)
        # fit the model
        fit <- glmQLFit(y, design)
        
        # make contrasts
        if (exists('restrict_to')){
            level_1 <- paste0("group", ident_1, ".", restrict_to)
            level_2 <- paste0("group", ident_2, ".", restrict_to)
        } else {
            level_1 <- paste0("group", ident_1)
            level_2 <- paste0("group", ident_2)
        }
        myContrast <- makeContrasts(contrasts = paste0(make.names(level_1), "-", make.names(level_2)), levels = design)
        qlf <- glmQLFTest(fit, contrast=myContrast)
        
        # get all of the DE genes and calculate Benjamini-Hochberg adjusted FDR
        results <- topTags(qlf, n = Inf)
        results <- results$table
        ''')   
        
        # transfer data
        print('\tTransfer data...')
        
        # fetch the table and convert it to pandas
        results = ro.globalenv['results']
        with (ro.default_converter + pandas2ri.converter).context():
            results = ro.conversion.get_conversion().rpy2py(results)
    finally:
        # always clean up, so that a failed run cannot leak variables into the next one
        ro.r('''
        rm(list = ls())
        gc()
        ''')

        gc.collect()
    
    return results



##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################

def _zscore_cut_off(values, z_cut_off, ndigits):
    """Return the smallest value whose z-score exceeds ``z_cut_off``, rounded to ``ndigits``.

    NaN is returned when no value qualifies (e.g. constant input such as all
    adjusted p-values == 1, for which the z-score is undefined). Comparisons
    against NaN are always False, so downstream "passes the threshold" checks
    then select no gene instead of the function raising.
    """
    values = np.asarray(values, dtype='float64')
    if values.size == 0:
        return np.nan
    selected = values[stats.zscore(values) > z_cut_off]
    if selected.size == 0:
        return np.nan
    return round(selected.min(), ndigits)


def _de_table(names, logfc, logexprs, pvals_adj, log_pvals_adj, logfc_limit):
    """Build one per-gene table, sorted by fold change (ties by adjusted p-value)."""
    table = pd.DataFrame(data={'names': names, 'logfc': logfc, 'logexprs': logexprs,
                               'pvals_adj': pvals_adj, 'log_pvals_adj': log_pvals_adj,
                               'logfc_limit': logfc_limit})
    return table.sort_values(by=['logfc', 'pvals_adj'], ascending=True)


def _widen_limit(value, factor, upper):
    """Push an axis limit outwards by ``factor - 1`` of its magnitude, whatever its sign."""
    pad = abs(value) * (factor - 1)
    return value + pad if upper else value - pad


def _draw_de_scatter(table, significant, x, y, xlim=None, ylim=None, fig_size=(7, 6)):
    """Scatter all genes in grey and overlay the significant ones in their group colour."""
    fig, ax = plt.subplots(1, figsize=fig_size)
    if ylim is not None:
        ax.set_ylim(ylim)
    if xlim is not None:
        ax.set_xlim(xlim)

    # all genes: black outline (larger dot) with a grey fill on top
    sb.scatterplot(y=y, x=x, color='#000000', s=20, linewidth=0, data=table, ax=ax)
    sb.scatterplot(y=y, x=x, color='#cccccc', s=10, linewidth=0, data=table, ax=ax)

    hits = table.loc[significant]
    if not hits.empty:
        # white underlay so the half-transparent colour is not mixed with the grey dot
        sb.scatterplot(y=hits[y], x=hits[x], color='#ffffff', s=10, alpha=1, linewidth=0, ax=ax)
        sb.scatterplot(y=hits[y], x=hits[x], c=hits['color'], s=10, alpha=0.5, linewidth=0, ax=ax)
    return ax


def edgeR_to_results(top_table, 
                     results_dict=None,
                     ident_1=None,
                     ident_2=None,
                     ident_1_color='#1f77b4',
                     ident_2_color='#ff7f0e',
                     plot=True,
                     plot_logfc_limit = 10,
                     log_pvals_adj_limit = 300,
                     z_logfc_cut_off=0.5,
                     z_pval_cut_off=0.25
):
    """Convert an edgeR top table into the results dictionary and plot it.

    Parameters
    ----------
    top_table
        DataFrame indexed by gene name with the columns ``logFC``,
        ``logCPM`` and ``FDR``.
    results_dict
        Dictionary the results are written into (modified in place and
        returned). ``None`` starts a fresh dictionary.
    ident_1, ident_2
        Names of the two compared groups. ``results[ident_1]`` holds the
        table as reported, ``results[ident_2]`` the same table with the sign
        of the fold change flipped.
    ident_1_color, ident_2_color
        Colours for genes up in ``ident_1`` / ``ident_2`` in the plots.
    plot
        Draw the histograms and scatter plots. Plotting also adds a
        ``color`` column to ``results[ident_1]``.
    plot_logfc_limit
        Fold changes are clipped to +/- this value in ``logfc_limit``.
    log_pvals_adj_limit
        Upper bound for ``-log10`` adjusted p-values (caps FDR == 0).
    z_logfc_cut_off, z_pval_cut_off
        z-score thresholds used to derive ``results['logfc_cut_off']`` and
        ``results['pval_cut_off']`` (the latter on the ``-log10`` scale).
        A cut-off is NaN when no value exceeds the z-score threshold.

    Returns
    -------
    dict
        ``results_dict`` (or a new dict) with the keys ``ident_1``,
        ``ident_2``, ``logfc_cut_off`` and ``pval_cut_off`` added.
    """
    # a shared default dict would leak the tables of one call into the next
    results = {} if results_dict is None else results_dict

    names = list(top_table.index)
    logfc = np.array(top_table['logFC'], dtype='float64')
    logexprs = np.array(top_table['logCPM'], dtype='float64')
    pvals_adj = np.array(top_table['FDR'], dtype='float64')
    log_pvals_adj = -np.log10(pvals_adj)
    log_pvals_adj[log_pvals_adj > log_pvals_adj_limit] = log_pvals_adj_limit
    logfc_limit = logfc.copy()
    logfc_limit[logfc_limit > plot_logfc_limit] = plot_logfc_limit
    logfc_limit[logfc_limit < -plot_logfc_limit] = -plot_logfc_limit

    results[ident_1] = _de_table(names, logfc, logexprs, pvals_adj, log_pvals_adj, logfc_limit)
    results[ident_2] = _de_table(names, -logfc, logexprs, pvals_adj, log_pvals_adj, -logfc_limit)

    # read back through the dict so behaviour is unchanged if both idents are equal
    de_table = results[ident_1]

    # find cut offs
    # To DO:
    # * set pval_cut_off to 0.05 if larger cut off is found
    logfc_cut_off = _zscore_cut_off(de_table['logfc'].abs(), z_logfc_cut_off, 1)
    pval_cut_off = _zscore_cut_off(de_table['log_pvals_adj'], z_pval_cut_off, 0)
    if np.isnan(logfc_cut_off) or np.isnan(pval_cut_off):
        print('\nNo z-score based cut-off could be determined (logfc_cut_off=' + str(logfc_cut_off)
              + ', pval_cut_off=' + str(pval_cut_off) + '); no gene will pass the undefined threshold.')

    results['logfc_cut_off'] = logfc_cut_off
    results['pval_cut_off'] = pval_cut_off

    if not plot:
        return results

    abs_logfc = de_table['logfc'].abs()
    log_pvals = de_table['log_pvals_adj']

    #############################################################################################################
    # distribution of fold changes
    n_diff_logfc = (abs_logfc > logfc_cut_off).sum()
    n_up_logfc = (de_table['logfc'] > logfc_cut_off).sum()
    n_down_logfc = (de_table['logfc'] < -logfc_cut_off).sum()

    # NOTE(review): sb.distplot is deprecated in recent seaborn releases; kept to leave the figures unchanged
    with rc_context({'figure.figsize': (8, 2)}):
        sb.distplot(de_table['logfc'], kde=True, bins=100).set_xlabel('$log_2$ Fold Change')
        plt.axvline(logfc_cut_off, 0, 1)
        plt.axvline(-logfc_cut_off, 0, 1)
        plt.annotate('Down-regulated\n' + str(n_down_logfc), xy=(0.02, 0.92), xycoords='axes fraction', va="top", ha="left")
        plt.annotate('Up-regulated\n' + str(n_up_logfc), xy=(0.98, 0.92), xycoords='axes fraction', va="top", ha="right")
        plt.title(label='$log_2$ Fold Change (' + str(n_diff_logfc) + ' genes passing threshold of ' + str(logfc_cut_off) + ')', fontweight='bold')
        plt.show()

    #############################################################################################################
    # distribution of adjusted p-values
    n_diff_pval = (log_pvals.abs() > pval_cut_off).sum()

    with rc_context({'figure.figsize': (8, 2)}):
        sb.distplot(log_pvals, kde=True, bins=100).set_xlabel('$-log_{10}$ Adjusted p-Value')
        plt.axvline(pval_cut_off, 0, 1)
        plt.title(label='$-log_{10}$ Adjusted p-Value (' + str(n_diff_pval) + ' genes passing threshold of ' + str(pval_cut_off) + ')', fontweight='bold')
        plt.show()

    #############################################################################################################
    # shared by the three scatter plots
    min_logfc = logfc_cut_off
    max_pval = 10**-pval_cut_off
    significant = (abs_logfc >= min_logfc) & (de_table['pvals_adj'] <= max_pval)

    de_table['color'] = '#000000'
    de_table.loc[significant & (de_table['logfc'] > 0), 'color'] = ident_1_color
    de_table.loc[significant & (de_table['logfc'] < 0), 'color'] = ident_2_color

    passes_pval = log_pvals.abs() > pval_cut_off
    n_diff = (passes_pval & (abs_logfc > logfc_cut_off)).sum()
    n_up = (passes_pval & (de_table['logfc'] > logfc_cut_off)).sum()
    n_down = (passes_pval & (de_table['logfc'] < -logfc_cut_off)).sum()

    title = 'Combined p-value & fold change threshold\n('  + str(n_diff) + ' genes passing thresholds of ' + str(logfc_cut_off) + ' and ' + str(pval_cut_off) + ')'
    line_style = dict(color='#666666', lw=1, linestyle='--')

    # axes are made a bit longer than the data range
    pval_limits = (-1, log_pvals.max() * 1.13)
    logfc_limits = (_widen_limit(de_table['logfc_limit'].min(), 1.1, upper=False),
                    _widen_limit(de_table['logfc_limit'].max(), 1.1, upper=True))

    #############################################################################################################
    # volcano plot: fold change vs. adjusted p-value
    ax = _draw_de_scatter(de_table, significant, x='logfc_limit', y='log_pvals_adj',
                          xlim=logfc_limits, ylim=pval_limits)

    ax.annotate('Down-regulated\n' + str(n_down), xy=(0.02, 0.98), xycoords='axes fraction', va="top", ha="left")
    ax.annotate('Up-regulated\n' + str(n_up), xy=(0.98, 0.98), xycoords='axes fraction', va="top", ha="right")
    ax.annotate(str(ident_2), xy=(0.02, 0.02), xycoords='axes fraction', va="bottom", ha="left")
    ax.annotate(str(ident_1), xy=(0.98, 0.02), xycoords='axes fraction', va="bottom", ha="right")

    ax.axvline(min_logfc, 0, 1, **line_style)
    ax.axvline(-min_logfc, 0, 1, **line_style)
    ax.axhline(-np.log10(max_pval), 0, 1, **line_style)

    ax.set_title(title, fontweight='bold')
    ax.set_ylabel('$-log_{10}$ Adjusted p-Value')
    ax.set_xlabel('$log_2$ Fold Change')

    plt.show()

    #############################################################################################################
    # expression vs. adjusted p-value
    ax = _draw_de_scatter(de_table, significant, x='logexprs', y='log_pvals_adj',
                          ylim=pval_limits)

    ax.annotate('Up-regulated\n' + str(n_up) + '\nDown-regulated\n' + str(n_down), xy=(0.98, 0.98), xycoords='axes fraction', va="top", ha="right")

    ax.axhline(-np.log10(max_pval), 0, 1, **line_style)

    ax.set_title(title, fontweight='bold')
    ax.set_ylabel('$-log_{10}$ Adjusted p-Value')
    ax.set_xlabel('$log_2$ Expression')

    plt.show()

    #############################################################################################################
    # expression vs. fold change
    ax = _draw_de_scatter(de_table, significant, x='logexprs', y='logfc_limit',
                          ylim=logfc_limits)

    ax.annotate(str(ident_1), xy=(0.02, 0.98), xycoords='axes fraction', va="top", ha="left")
    ax.annotate('Up-regulated\n' + str(n_up), xy=(0.98, 0.98), xycoords='axes fraction', va="top", ha="right")
    ax.annotate(str(ident_2), xy=(0.02, 0.02), xycoords='axes fraction', va="bottom", ha="left")
    ax.annotate('Down-regulated\n' + str(n_down), xy=(0.98, 0.02), xycoords='axes fraction', va="bottom", ha="right")

    ax.axhline(min_logfc, 0, 1, **line_style)
    ax.axhline(-min_logfc, 0, 1, **line_style)

    ax.set_title(title, fontweight='bold')
    ax.set_ylabel('$log_2$ Fold Change')
    ax.set_xlabel('$log_2$ Expression')

    plt.show()

    return results
    




##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################


def dot_plot_edger(
    adata,
    results_dict = None,
    keys = None,
    layer = 'sct_logcounts',
    cmap='RdBu_r'
):
    """Dot plot of differentially expressed genes, coloured by log2 fold change.

    Parameters
    ----------
    adata
        AnnData object holding the cells to plot. It is not modified; only the
        required subset is copied.
    results_dict
        Result dictionary of the DE test. The entries read here are
        'groupby', 'groupby_categories' (the two compared group labels),
        'restrict_to', 'groups_restrict', 'logfc_cut_off', 'pval_cut_off' and,
        stored under the first group label, a table with the columns
        'names', 'logfc' and 'log_pvals_adj'.
    keys
        Genes to show. Genes that do not pass both cut-offs are dropped.
        ``None`` selects every gene passing the cut-offs.
    layer
        Layer used as ``.X`` for plotting; ``None`` keeps the current ``.X``.
    cmap
        Colormap used for the fold-change colouring.

    Returns
    -------
    None
        The figure is shown. If no requested gene passes the cut-offs a
        message is printed and nothing is drawn.
    """
    categories = results_dict['groupby_categories']
    results = results_dict[categories[0]]

    # genes passing both the fold-change and the adjusted p-value cut-off
    passing = (abs(results['logfc']) >= results_dict['logfc_cut_off']) & (results['log_pvals_adj'] >= results_dict['pval_cut_off'])
    de_names = results.loc[passing, 'names']
    de_lookup = set(de_names)

    # filter keys (do this before touching adata so nothing is copied needlessly)
    if keys is None:
        keys = list(de_names)
    keys = [key for key in keys if key in de_lookup]
    if not keys:
        print('dot_plot_edger: none of the requested genes pass the DE thresholds - nothing to plot.')
        return None

    # Put down-regulated genes first (caller order kept within each group): the
    # bracket positions below are computed under exactly this ordering.
    logfc_by_gene = dict(zip(results['names'], results['logfc']))
    down_keys = [key for key in keys if logfc_by_gene[key] < 0]
    up_keys = [key for key in keys if not logfc_by_gene[key] < 0]
    keys = down_keys + up_keys

    ## var group pos
    # logfc < 0 genes are bracketed with the second category, the others with
    # the first; an empty group gets no bracket at all.
    group_spans = (
        (categories[1], 0, len(down_keys)),
        (categories[0], len(down_keys), len(up_keys)),
    )
    var_group_positions = []
    var_group_labels = []
    for label, start, n_genes in group_spans:
        if n_genes > 0:
            var_group_positions.append((start, start + n_genes - 1))
            var_group_labels.append(label)

    ## colors
    # first category is coloured by logfc, second by -logfc (mirror image)
    index = pd.Index(categories, name='groupby')
    color_df = pd.DataFrame([results['logfc'],
                             -results['logfc']],
                           index=index).T
    color_df.index = results['names']
    color_df = color_df.T
    color_df = color_df.loc[:,keys]
    limit = abs(color_df).max().max()

    # subset adata to group provided in restrict_to
    restrict_to = results_dict['restrict_to']
    groups_restrict = results_dict['groups_restrict']
    if restrict_to is None:
        adata_subset = adata
    else:
        adata_subset = adata[adata.obs[groups_restrict].isin([restrict_to])]

    # copy only the cells and genes that are plotted instead of the whole object
    adata_subset = adata_subset[:, list(dict.fromkeys(keys))].copy()

    # set selected layer to .X
    if layer is not None:
        adata_subset.X = adata_subset.layers[layer].copy()

    ## plot
    sc.pl.DotPlot(adata_subset, 
                  var_names=keys, 
                  groupby=results_dict['groupby'], 
                  dot_color_df=color_df, 
                  var_group_positions=var_group_positions, 
                  var_group_labels=var_group_labels,
                  vmin=-limit, 
                  vmax=limit, 
                  cmap=cmap).style(color_on='square', 
                                   dot_edge_lw=1, 
                                   grid=True, 
                                   dot_edge_color=None).legend(colorbar_title='log$_2$ Fold Change').show()

    # free the temporary copy right away
    del adata_subset
    del results
    gc.collect()


##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################
##################################################################################################################################################################################



## load Adata

In [None]:
EECs = ['Goblet/EEC prog. (early)',
'K-cell (Gip+)',
'EC (mature)',
'EC (immature)',
'EEC (Peptide/immature)' ,
'L/I-cell (Glp1+/Cck+)' ,
'EEC prog. (mid)',
'EC prog. (late)',
'D-cell (Sst+)',
'EEC prog. (late/Peptide)',
'EC 2' ,
'X-cell (Ghrl+)']
Paneth = ['Paneth','Paneth prog.', 'Goblet-Paneth-like', 'Goblet-Paneth-like(cycling)']
Progenitors = ['Goblet/EEC prog. (early)','Paneth prog.', 'Tuft prog.','Tuft prog. 2']

In [None]:
adata = sc.read_h5ad('joint_diseased_healthy_with_layers_metadata_corrected_anno_updated.h5ad')

In [None]:
adata.obs['enrichment proportion'] =adata.obs['enrichment proportion'].astype('category')

In [None]:
adata.obs['enrichment proportion'].cat.categories

## healthy only

In [None]:
adata = adata[adata.obs['atlas'].isin(['reference'])].copy()

In [None]:
gc.collect()

In [None]:
adata.obs['cell_type_annotation_lv1'].value_counts()

## EECs

In [None]:
adata_EEC = adata[adata.obs['cell_type_annotation_lv1'].isin(EECs)].copy()

In [None]:
adata_EEC.X = adata_EEC.layers['sct_logcounts']

In [None]:
adata_EEC

In [None]:
del adata
gc.collect()

In [None]:
sc.pp.neighbors(adata_EEC, use_rep='X_scarches_emb')
sc.tl.leiden(adata_EEC, resolution=1.5)

In [None]:
sc.tl.paga(adata_EEC, groups='cell_type_annotation_lv1')

In [None]:
sc.pl.paga(adata_EEC,  fontsize=5, save = 'paga_healthy_EEC_subs.png')#, fontoutline=True, threshold=0.05, max_edge_width=3, min_edge_width=0.01, node_size_scale=3,

In [None]:
sc.tl.umap(adata_EEC, init_pos='paga')

In [None]:
sc.pl.umap(adata_EEC, color=['pretty name','cell_type_annotation_lv1','leiden'], size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=3, cmap=mymap,wspace=1, frameon= True)

### run DElegate

In [None]:
results = run_DElegate_findMarkers(adata_EEC, 
                        layer = 'raw_counts', 
                        group_column = 'leiden', 
                        replicate_column = None, 
                        method = "edger", 
                        min_rate = 0.05,
                        min_fc = 1,
                        verbosity = 1, 
                        n_core = 20, 
                        max_memory = 4)

In [None]:
# Show the first ten marker genes reported for every Leiden cluster on the UMAP.
# Iterate the categorical's categories instead of a set so that the plots
# always come out in the same cluster order.
for cluster in adata_EEC.obs['leiden'].cat.categories:
    print(cluster)
    # rows of the marker table are matched on 'group1' with spaces replaced by underscores
    top_markers = results.loc[results['group1'] == cluster.replace(" ", "_"), "feature"][0:10]
    sc.pl.umap(adata_EEC, color=list(top_markers), layer='log_dca_counts', size=10, add_outline=True, alpha=0.7, outline_width=(0.3, 0.0), ncols=5)

In [None]:
# Free memory, then run the DE test between the Leiden clusters of the EEC subset.
# NOTE(review): get_diff_exprs_DElegate defaults to layer='raw_counts' and the
# method here is edgeR; 'log_dca_counts' is presumably log-transformed - confirm
# that passing it as the test layer is intended.
gc.collect()
de_results = get_diff_exprs_DElegate(
    adata = adata_EEC, 
    groupby ='leiden', # groups/conditions to test (e.g. stage, sample, ...)
    #groups_restrict = 'all_cells', # restrict test to a given cell type or cluster
    #restrict_to = 'all', # identity of the cell type the test should be restricted to, e.g. Beta
    layer = 'log_dca_counts',
#     group_column = None, 
#     replicate_column = None, 
    method = "edger", 
    filter_ambient_genes = False, 
    rank_genes_groups_key = None, # rank_genes_groups key holding the markers for groups_restrict
    get_marker = False, # run rank_genes_groups to identify markers
    min_gene_score = 0, # min score a cluster marker should have to be cluster-specific
    min_cluster_size = 100, 
    min_frac_cells = 0.05,
    sample_key = 'sample', # key for samples/replicates
    #additional_variables=[],  # which metadata to keep, e.g. gender, age, etc.
    #replicates_per_sample=3, # number of pseudoreplicates/sample
    #min_cell_per_sample=30,
    #aggr_method='sum',
    plot = True,
    return_results = 'dict' # or 'top_table'
)

In [None]:
pd.set_option('display.max_colwidth', 20)
term_key_words = 'golgi|stress|ER|localization|autopha|cilium|cytoskeleton|Wnt|PCP|planar|polarity|mTOR|cell-cell|adhesion|junction|integrin|oxidative phos|mitochondria|electron|translation|ribosome|microtubule|signal|insulin|Insulin|IGF|cycle|mito|hormone|peptide|secretion|transcription'
results_enrichment = get_go_terms(results_dict=de_results, min_score=None, max_pval=10**-de_results['pval_cut_off'], min_logfc=de_results['logfc_cut_off'], selection_string=term_key_words, plot_select=True, plot_top=True, plot_all=False, n_select=20, width_factor=30)

### rank genes

In [None]:
adata_EEC

### scran counts (already log)

In [None]:
sc.tl.rank_genes_groups(adata_EEC, 'leiden', method='wilcoxon', layer = 'scran_counts', use_raw=False)

# to visualize the results

sc.pl.rank_genes_groups(adata_EEC)

In [None]:
sc.tl.dendrogram(adata_EEC, groupby='leiden')

In [None]:
del adata_EEC.raw

In [None]:
adata_EEC.X= adata_EEC.layers['scran_counts']

In [None]:
sc.pl.rank_genes_groups_dotplot(adata_EEC, n_genes=5, key="rank_genes_groups", groupby="leiden")


### sct counts

In [None]:
sc.tl.rank_genes_groups(adata_EEC, 'leiden', method='wilcoxon', layer = 'sct_logcounts', use_raw=False)

# to visualize the results

sc.pl.rank_genes_groups(adata_EEC)

In [None]:
adata_EEC.X= adata_EEC.layers['sct_logcounts']

In [None]:
sc.pl.rank_genes_groups_dotplot(adata_EEC, n_genes=5, key="rank_genes_groups", groupby="leiden")


### metadata

In [None]:
adata_EEC.obs['doublet_calls'] = adata_EEC.obs['doublet_calls'].astype('category')

In [None]:
# Sample eight positions of the custom colormap and store them as hex codes
# (alpha kept) under the key scanpy reads as the palette for 'doublet_calls'.
# NOTE(review): the sample points run from 0 to 2; if mymap is a matplotlib
# colormap, float inputs above 1 all map to its top colour - confirm intended.
palette_positions = np.linspace(0,2,8)
palette_hex = [mpl.colors.to_hex(rgba, keep_alpha=True) for rgba in mymap(palette_positions)]
adata_EEC.uns['doublet_calls_colors'] = np.array(palette_hex)

In [None]:
sc.pl.umap(adata_EEC, color=['Project','enriched','phase','kit','line','strain', 'doublet_calls', 'enrichment proportion'], size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=4, legend_fontsize=8, color_map=mymap,wspace = 0.4, save = 'EEC_subs_healthy_metadata.png', frameon= True)

In [None]:
plot_composition(adata_EEC, y_key='doublet_calls', x_key='leiden', x_rotation=90)

In [None]:
adata_EEC = adata_EEC[~adata_EEC.obs['leiden'].isin(['15','16'])].copy()

In [None]:
sc.pl.umap(adata_EEC, color=['leiden'], size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=4, legend_fontsize=8, color_map=mymap,wspace = 0.4)

In [None]:
sc.tl.rank_genes_groups(adata_EEC, 'leiden', method='wilcoxon', layer = 'sct_logcounts', use_raw=False)

In [None]:
adata_EEC.X= adata_EEC.layers['sct_logcounts']

In [None]:
sc.tl.dendrogram(adata_EEC, groupby='leiden')

In [None]:
sc.pl.rank_genes_groups_dotplot(adata_EEC, n_genes=5, key="rank_genes_groups", groupby="leiden")


### TFs

In [None]:
with open('TF_mouse_all.txt', 'r') as file:
    TFs = file.read().splitlines()

In [None]:
ranked_genes = adata_EEC.uns['rank_genes_groups']['names']

In [None]:
ranked_genes.dtype.names

In [None]:
ranked_genes

In [None]:
# Collect, per cluster, the transcription factors among its N top-ranked genes.
N = 10  # number of top-ranked genes inspected per cluster
tf_lookup = set(TFs)  # TFs is a plain list; a set makes each membership test O(1)
differentially_expressed_tfs = {}
for n, group in enumerate(ranked_genes.dtype.names):
    # only the first N entries can qualify, so slice instead of scanning every ranked gene;
    # dict.fromkeys drops repeated genes while keeping the rank order
    top_genes = ranked_genes[group][:N]
    differentially_expressed_tfs[n] = list(dict.fromkeys(gene for gene in top_genes if gene in tf_lookup))

print(differentially_expressed_tfs)

In [None]:
# Flatten the per-cluster lists into one list of unique TFs.
# dict.fromkeys removes duplicates while keeping first-seen order, so the gene
# order in the downstream plots is reproducible (a set has no stable order).
all_tfs = list(dict.fromkeys(gene for genes in differentially_expressed_tfs.values() for gene in genes))

In [None]:
np.max(adata_EEC.X)

In [None]:
#sc.pl.rank_genes_groups_dotplot(adata_EEC, var_names = all_tfs)
sc.pl.dotplot(adata_EEC, all_tfs, groupby='leiden',dendrogram=True, layer='sct_logcounts',use_raw=False)

In [None]:
sc.pl.umap(adata_EEC, color=['leiden'], size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=3, cmap=mymap,legend_loc='on data', frameon= True)

In [None]:
sc.pl.umap(adata_EEC, color=['cell_type_annotation_lv1','leiden', 'doublet_calls'], size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=3, cmap=mymap,wspace=0.6, frameon= True)

In [None]:
adata_EEC

In [None]:
sc.pl.umap(adata_EEC, color=all_tfs, size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=5, cmap=mymap, layer='log_dca_counts', save= 'umap_healthy_TFs_EEC.png', frameon= True)

In [None]:
sc.pl.umap(adata_EEC, color=['Ghrl','Sst','Gcg','Gip','Cck','Sct','Tac1','Tph1', 'Spdef','Reg4'],layer= 'log_dca_counts',size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=5, cmap=mymap, save = 'EEC_subs_healthy_EECs_hormones_expr.png', frameon= True)

# Fine annotation level 2 EEC

First, exclude the cells carrying a Goblet signature from the progenitors.
Then recluster and annotate the clusters at an appropriate resolution.

In [None]:
sc.tl.leiden(adata_EEC, resolution=1.5)

In [None]:
sc.tl.umap(adata_EEC, init_pos='paga')

In [None]:
adata_EEC.X =adata_EEC.layers['log_dca_counts']

In [None]:
sc.pl.umap(adata_EEC, color=['Spdef','Neurog3','leiden'], use_raw=False,size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=3, cmap=mymap)

In [None]:
sc.tl.leiden(adata_EEC, restrict_to=('leiden', ['7','5','12','13']), resolution=1.5, key_added='leiden_sub_goblet')

In [None]:
sc.pl.umap(adata_EEC, color=['Spdef','Neurog3','leiden_sub_goblet'], use_raw=False,size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=3, cmap=mymap)

In [None]:

# Violin plot of one gene across the goblet sub-clusters.
gene_of_interest = 'Spdef'

with rc_context({'figure.figsize': (6, 4)}):
    # plot the gene named above so that the figure and its title cannot diverge
    sc.pl.violin(adata_EEC, use_raw=False, keys=[gene_of_interest], groupby='leiden_sub_goblet', rotation=90, show=False)
    plt.title(f'Expression of {gene_of_interest} per Leiden cluster')
    plt.xlabel('Leiden Cluster')
    plt.ylabel('Expression Level')
    plt.show()

In [None]:
# Sub-clusters of the re-clustered groups 7/5/12/13 that are removed as goblet cells.
Goblet_clusters = ['7-5-12-13,0','7-5-12-13,3','7-5-12-13,4','7-5-12-13,5','7-5-12-13,7','7-5-12-13,10','7-5-12-13,11']
# .copy() turns the subset into a real AnnData rather than a view, as in the other
# subsetting cells; the next cells write neighbours/clusters into this object.
adata_EEC = adata_EEC[~adata_EEC.obs['leiden_sub_goblet'].isin(Goblet_clusters)].copy()
adata_EEC

In [None]:
sc.pp.neighbors(adata_EEC, use_rep='X_scarches_emb')
sc.tl.leiden(adata_EEC, resolution=1.5)

In [None]:
sc.tl.paga(adata_EEC, groups='cell_type_annotation_lv1')

In [None]:
sc.pl.paga(adata_EEC,  fontsize=4)#, fontoutline=True, threshold=0.05, max_edge_width=3, min_edge_width=0.01, node_size_scale=3,

In [None]:
sc.tl.umap(adata_EEC, init_pos='paga')

In [None]:
sc.pl.umap(adata_EEC, color=['pretty name','cell_type_annotation_lv1','leiden'], size=10, add_outline=True,legend_fontsize=9, alpha=1, outline_width=(0.3, 0.0), ncols=3, cmap=mymap,wspace=1, frameon=True)

In [None]:
sc.pl.umap(adata_EEC, color=['cell_type_annotation_lv1'], size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=3, cmap=mymap,wspace=0.6, save = 'EEC_subs_healthy_without_GC_anno.png', legend_fontsize = 9, frameon=True)


In [None]:
sc.pl.umap(adata_EEC, color=['leiden'], size=8, add_outline=True, alpha=1, outline_width=(0.3, 0.0), legend_loc='on data', save = 'EEC_subs_healthy_without_GC_leiden.png', legend_fontsize = 16, frameon=True)


In [None]:
del adata_EEC.raw
gc.collect()

In [None]:
sc.tl.rank_genes_groups(adata_EEC, 'leiden', method='wilcoxon', layer = 'sct_logcounts', use_raw=False)

# to visualize the results

sc.pl.rank_genes_groups(adata_EEC)

In [None]:
sc.tl.dendrogram(adata_EEC, groupby='leiden')

In [None]:
adata_EEC.X = adata_EEC.layers['sct_logcounts']

In [None]:
sc.pl.rank_genes_groups_dotplot(adata_EEC, n_genes=5, key="rank_genes_groups", groupby="leiden")


In [None]:
plot_composition(adata_EEC, y_key='pretty name', x_key='leiden', x_rotation=90)

In [None]:
sc.pl.umap(adata_EEC, color=['Project','enriched','phase','kit','line','strain', 'doublet_calls', 'enrichment proportion'], size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=4, legend_fontsize=8,cmap=mymap, wspace = 0.4, save = 'EEC_subs_without_GCX_metadata.png', frameon=True)

In [None]:
plot_composition(adata_EEC, y_key='enriched', x_key='leiden', x_rotation=90)

In [None]:
plot_composition(adata_EEC, y_key='line', x_key='leiden', x_rotation=90)

In [None]:
plot_composition(adata_EEC, y_key='strain', x_key='leiden', x_rotation=90)

In [None]:
plot_composition(adata_EEC, y_key='enrichment proportion', x_key='leiden', x_rotation=90)

In [None]:
plot_composition(adata_EEC, y_key='Project', x_key='leiden', x_rotation=90)

In [None]:
plot_composition(adata_EEC, y_key='kit', x_key='leiden', x_rotation=90)

## TFs

In [None]:
with open('TF_mouse_all.txt', 'r') as file:
    TFs = file.read().splitlines()

In [None]:
ranked_genes = adata_EEC.uns['rank_genes_groups']['names']

In [None]:
ranked_genes

In [None]:
# Collect, per cluster, the transcription factors among its N top-ranked genes.
N = 10  # number of top-ranked genes inspected per cluster
tf_lookup = set(TFs)  # TFs is a plain list; a set makes each membership test O(1)
differentially_expressed_tfs = {}
for n, group in enumerate(ranked_genes.dtype.names):
    # only the first N entries can qualify, so slice instead of scanning every ranked gene;
    # dict.fromkeys drops repeated genes while keeping the rank order
    top_genes = ranked_genes[group][:N]
    differentially_expressed_tfs[n] = list(dict.fromkeys(gene for gene in top_genes if gene in tf_lookup))

print(differentially_expressed_tfs)

In [None]:
# Flatten the per-cluster lists into one list of unique TFs.
# dict.fromkeys removes duplicates while keeping first-seen order, so the gene
# order in the downstream plots is reproducible (a set has no stable order).
all_tfs = list(dict.fromkeys(gene for genes in differentially_expressed_tfs.values() for gene in genes))

In [None]:
#sc.pl.rank_genes_groups_dotplot(adata_EEC, var_names = all_tfs)
sc.pl.dotplot(adata_EEC, all_tfs, groupby='leiden',dendrogram=True, layer='sct_logcounts',use_raw=False)

In [None]:
sc.pl.umap(adata_EEC, color=all_tfs, size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=6, cmap=mymap, layer='log_dca_counts', save= 'umap_healthy_TFs_withoutGC.png', frameon=True)

In [None]:
sc.pl.violin(adata_EEC,groupby='kit',keys=['Hmgn3', 'Glis3', 'Peg3'], rotation=90)

In [None]:
sc.pl.umap(adata_EEC, color=['Ghrl','Sst','Gcg','Gip','Pyy','Cck','Nts','Sct','Tac1','Tph1','Reg4'],layer= 'log_dca_counts',size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=6, cmap=mymap, save = 'EEC_subs_healthy_EECs_hormones_expr_wo_GC.png', frameon= True)

In [None]:
# NOTE(review): this cell and the following marker-gene cells plot `adata`, which
# was removed earlier with `del adata` and is only reloaded in the
# "Goblet and EECs" section below. Run top-to-bottom they raise NameError -
# confirm whether the full atlas or adata_EEC was meant.
sc.pl.umap(adata, color=['Gata4', 'Prrx1', 'Sp5', 'Nr1i3', 'Creb3l3','Creb3l2','Ada','Klf6', 'Bex1','Slc2a2','Reg1'],layer= 'log_dca_counts',size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=6, cmap=mymap)#, save = 'EEC_subs_healthy_EECs_hormones_expr_wo_GC.png', frameon= True)

In [None]:
markergenes = {}
markergenes['Stem'] = ['Lgr5', 'Ascl2', 'Slc12a2', 'Axin2', 'Olfm4', 'Gkn3']
markergenes['Enterocyte (Proximal)'] = ['Gsta1','Rbp2','Adh6a','Apoa4','Reg3a','Creb3l3','Cyp3a13','Cyp2d26','Ms4a10','Ace','Aldh1a1','Rdh7','H2-Q2', 'Hsd17b6','Gstm3','Gda','Apoc3','Gpd1','Fabp1','Slc5a1','Mme','Cox7a1','Gsta4','Lct','Khk','Mttp','Xdh','Sult1b1', 'Treh','Lpgat1','Dhrs1','Cyp2c66','Ephx2','Cyp2c65','Cyp3a25','Slc2a2','Ugdh','Gstm6','Retsat','Acsl5', 'Cyb5r3','Cyb5b','Ckmt1','Aldob','Ckb','Scp2','Prap1']
markergenes['Enterocyte (Distal)'] = ['Tmigd1','Fabp6','Slc51b','Slc51a','Mep1a','Fam151a','Naaladl1','Slc34a2','Plb1','Nudt4','Dpep1','Pmp22','Xpnpep2','Muc3','Neu1','Clec2h','Phgr1','2200002D01Rik','Prss30','Cubn','Plec','Fgf15','Crip1','Krt20','Dhcr24','Myo15b','Amn','Enpep','Anpep','Slc7a9','Ocm','Anxa2','Aoc1','Ceacam20','Arf6','Abcb1a','Xpnpep1','Vnn1','Cndp2','Nostrin','Slc13a1','Aspa','Maf','Myh14']
markergenes['Goblet'] = ['Agr2', 'Fcgbp', 'Tff3', 'Clca1', 'Zg16', 'Tpsg1', 'Muc2', 'Galnt12', 'Atoh1', 'Rep15', 'S100a6', 'Pdia5', 'Klk1', 'Pla2g10', 'Spdef', 'Lrrc26', 'Ccl9', 'Bace2', 'Bcas1', 'Slc12a8', 'Smim14', 'Tspan13', 'Txndc5', 'Creb3l4', 'C1galt1c1', 'Creb3l1', 'Qsox1', 'Guca2a', 'Scin', 'Ern2', 'AW112010', 'Fkbp11', 'Capn9', 'Stard3nl', 'Slc50a1', 'Sdf2l1', 'Galnt7', 'Hpd', 'Ttc39a', 'Tmed3', 'Pdia6', 'Uap1', 'Gcnt3', 'Tnfaip8', 'Dnajc10', 'Ergic1', 'Tsta3', 'Kdelr3', 'Foxa3', 'Tpd52', 'Tmed9', 'Spink4', 'Nans', 'Cmtm7', 'Creld2', 'Tm9sf3', 'Wars', 'Smim6', 'Manf', 'Oit1', 'Tram1', 'Kdelr2', 'Xbp1', 'Serp1', 'Guk1', 'Sh3bgrl3', 'Cmpk1', 'Tmsb10', 'Dap', 'Ostc', 'Ssr4', 'Sec61b', 'Pdia3', 'Gale', 'Klf4', 'Krtcap2', 'Arf4', 'Sep15', 'Ssr2', 'Ramp1', 'Calr', 'Ddost']
markergenes['Paneth'] = ['Defa17', 'Defa22', 'Mptx2', 'Ang4']
markergenes['Enteroendocrine'] = ['Chgb', 'Gfra3', 'Cck', 'Vwa5b2', 'Neurod1', 'Fev', 'Aplp1', 'Scgn', 'Neurog3', 'Resp18', 'Trp53i11', 'Bex2', 'Rph3al', 'Scg5', 'Pcsk1', 'Isl1', 'Maged1', 'Fabp5', 'Celf3', 'Pcsk1n', 'Fam183b', 'Prnp', 'Tac1', 'Gpx3', 'Cplx2', 'Nkx2-2', 'Olfm1', 'Vim', 'Rimbp2', 'Anxa6', 'Scg3', 'Ngfrap1', 'Insm1', 'Gng4', 'Pax6', 'Cnot6l', 'Cacna2d1', 'Tox3', 'Slc39a2', 'Riiad1']
markergenes['Tuft'] = ['Alox5ap', 'Lrmp', 'Hck', 'Avil', 'Rgs13', 'Ltc4s', 'Trpm5', 'Dclk1', 'Spib', 'Fyb', 'Ptpn6', 'Matk', 'Snrnp25', 'Sh2d7', 'Ly6g6f', 'Kctd12', '1810046K07Rik', 'Hpgds', 'Tuba1a', 'Pik3r5', 'Vav1', 'Tspan6', 'Skap2', 'Pygl', 'Ccdc109b', 'Ccdc28b', 'Plcg2', 'Ly6g6d', 'Alox5', 'Pou2f3', 'Gng13', 'Bmx', 'Ptpn18', 'Nebl', 'Limd2', 'Pea15a', 'Tmem176a', 'Smpx', 'Itpr2', 'Il13ra1', 'Siglecf', 'Ffar3', 'Rac2', 'Hmx2', 'Bpgm', 'Inpp5j', 'Ptgs1', 'Aldh2', 'Pik3cg', 'Cd24a', 'Ethe1', 'Inpp5d', 'Krt23', 'Gprc5c', 'Reep5', 'Csk', 'Bcl2l14', 'Tmem141', 'Coprs', 'Tmem176b', '1110007C09Rik', 'Ildr1', 'Galk1', 'Zfp428', 'Rgs2', 'Inpp5b', 'Gnai2', 'Pla2g4a', 'Acot7', 'Rbm38', 'Gga2', 'Myo1b', 'Adh1', 'Bub3', 'Sec14l1', 'Asah1', 'Ppp3ca', 'Agt', 'Gimap1', 'Krt18', 'Pim3', '2210016L21Rik', 'Tmem9', 'Lima1', 'Fam221a', 'Nt5c3', 'Atp2a3', 'Mlip', 'Vdac3', 'Ccdc23', 'Tmem45b', 'Cd47', 'Lect2', 'Pla2g16', 'Mocs2', 'Arpc5', 'Ndufaf3']

In [None]:
sc.pl.umap(adata, color=markergenes['Goblet'],layer= 'log_dca_counts',size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=6, cmap=mymap)#, save = 'EEC_subs_healthy_EECs_hormones_expr_wo_GC.png', frameon= True)

In [None]:
sc.pl.umap(adata, color=markergenes['Stem'],layer= 'log_dca_counts',size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=6, cmap=mymap)#, save = 'EEC_subs_healthy_EECs_hormones_expr_wo_GC.png', frameon= True)

In [None]:
sc.pl.umap(adata, color=markergenes['Enterocyte (Proximal)'],layer= 'log_dca_counts',size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=6, cmap=mymap)#, save = 'EEC_subs_healthy_EECs_hormones_expr_wo_GC.png', frameon= True)

In [None]:
sc.pl.umap(adata, color=markergenes['Enterocyte (Distal)'],layer= 'log_dca_counts',size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=6, cmap=mymap)#, save = 'EEC_subs_healthy_EECs_hormones_expr_wo_GC.png', frameon= True)

In [None]:
sc.pl.umap(adata, color=markergenes['Paneth'],layer= 'log_dca_counts',size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=6, cmap=mymap)#, save = 'EEC_subs_healthy_EECs_hormones_expr_wo_GC.png', frameon= True)

## Goblet and EECs

In [None]:
EECs = ['Goblet/EEC prog. (early)', 'Goblet', 'Goblet prog. (late)',
'K-cell (Gip+)',
'EC (mature)',
'EC (immature)',
'EEC (Peptide/immature)' ,
'L/I-cell (Glp1+/Cck+)' ,
'EEC prog. (mid)',
'EC prog. (late)',
'D-cell (Sst+)',
'EEC prog. (late/Peptide)',
'EC 2' ,
'X-cell (Ghrl+)']
Paneth = ['Paneth','Paneth prog.', 'Goblet-Paneth-like', 'Goblet-Paneth-like(cycling)']
Progenitors = ['Goblet/EEC prog. (early)','Paneth prog.', 'Tuft prog.','Tuft prog. 2']

In [None]:
adata = sc.read_h5ad('joint_diseased_healthy_with_layers_metadata_corrected_anno_updated.h5ad')

In [None]:
# The file was just reloaded, so repeat the categorical conversion that was applied
# after the first load; the .cat accessor only exists on categorical columns.
adata.obs['enrichment proportion'] = adata.obs['enrichment proportion'].astype('category')
adata.obs['enrichment proportion'].cat.categories

### healthy only

In [None]:
adata = adata[adata.obs['atlas'].isin(['reference'])].copy()

In [None]:
gc.collect()

In [None]:
adata.obs['cell_type_annotation_lv1'].value_counts()

In [None]:
adata_EEC = adata[adata.obs['cell_type_annotation_lv1'].isin(EECs)].copy()

In [None]:
adata_EEC.X = adata_EEC.layers['sct_logcounts']

In [None]:
adata_EEC

In [None]:
del adata
gc.collect()

In [None]:
sc.pp.neighbors(adata_EEC, use_rep='X_scarches_emb')
sc.tl.leiden(adata_EEC, resolution=1.5)

In [None]:
sc.tl.paga(adata_EEC, groups='cell_type_annotation_lv1')

In [None]:
sc.pl.paga(adata_EEC,  fontsize=5, save = 'paga_healthy_EEC_GC_subs.png')#, fontoutline=True, threshold=0.05, max_edge_width=3, min_edge_width=0.01, node_size_scale=3,

In [None]:
sc.tl.umap(adata_EEC)#, init_pos='paga', min_dist =0.3)

In [None]:
sc.pl.umap(adata_EEC, color=['pretty name','cell_type_annotation_lv1','leiden','doublet_calls'], size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=2, cmap=mymap,wspace=1, legend_fontsize=9)

### rank genes

In [None]:
adata_EEC

### scran counts (already log)

In [None]:
sc.tl.rank_genes_groups(adata_EEC, 'leiden', method='wilcoxon', layer = 'scran_counts', use_raw=False)

# to visualize the results

sc.pl.rank_genes_groups(adata_EEC)

In [None]:
del adata_EEC.raw

In [None]:
adata_EEC.X= adata_EEC.layers['scran_counts']

In [None]:
sc.tl.dendrogram(adata_EEC, groupby='leiden')

In [None]:
sc.pl.rank_genes_groups_dotplot(adata_EEC, n_genes=5, key="rank_genes_groups", groupby="leiden")


### sct counts

In [None]:
sc.tl.rank_genes_groups(adata_EEC, 'leiden', method='wilcoxon', layer = 'sct_logcounts', use_raw=False)

# to visualize the results

sc.pl.rank_genes_groups(adata_EEC)

In [None]:
adata_EEC.X= adata_EEC.layers['sct_logcounts']

In [None]:
sc.tl.dendrogram(adata_EEC, groupby='leiden')

In [None]:
sc.pl.rank_genes_groups_dotplot(adata_EEC, n_genes=5, key="rank_genes_groups", groupby="leiden")


### metadata

In [None]:
adata_EEC.obs['doublet_calls'] = adata_EEC.obs['doublet_calls'].astype('category')

In [None]:
# Sample eight positions of the custom colormap and store them as hex codes
# (alpha kept) under the key scanpy reads as the palette for 'doublet_calls'.
# NOTE(review): the sample points run from 0 to 2; if mymap is a matplotlib
# colormap, float inputs above 1 all map to its top colour - confirm intended.
palette_positions = np.linspace(0,2,8)
palette_hex = [mpl.colors.to_hex(rgba, keep_alpha=True) for rgba in mymap(palette_positions)]
adata_EEC.uns['doublet_calls_colors'] = np.array(palette_hex)

In [None]:
sc.pl.umap(adata_EEC, color=['Project','enriched','phase','kit','line','strain', 'doublet_calls', 'enrichment proportion'], size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=4, legend_fontsize=8, color_map=mymap,wspace = 0.4, save = 'EEC_GC_subs_healthy_metadata.png')

In [None]:
plot_composition(adata_EEC, y_key='doublet_calls', x_key='leiden', x_rotation=90)

In [None]:
adata_EEC = adata_EEC[~adata_EEC.obs['leiden'].isin(['20'])].copy()

In [None]:
sc.pl.umap(adata_EEC, color=['leiden'], size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=4, legend_fontsize=8, color_map=mymap,wspace = 0.4)

In [None]:
sc.tl.rank_genes_groups(adata_EEC, 'leiden', method='wilcoxon', layer = 'sct_logcounts', use_raw=False)

In [None]:
adata_EEC.X= adata_EEC.layers['sct_logcounts']

In [None]:
sc.tl.dendrogram(adata_EEC, groupby='leiden')

In [None]:
sc.pl.rank_genes_groups_dotplot(adata_EEC, n_genes=5, key="rank_genes_groups", groupby="leiden")


### TFs

In [None]:
with open('TF_mouse_all.txt', 'r') as file:
    TFs = file.read().splitlines()

In [None]:
ranked_genes = adata_EEC.uns['rank_genes_groups']['names']

In [None]:
ranked_genes.dtype.names

In [None]:
ranked_genes

In [None]:
# Collect, per cluster, the transcription factors among its N top-ranked genes.
N = 15  # number of top-ranked genes inspected per cluster
tf_lookup = set(TFs)  # TFs is a plain list; a set makes each membership test O(1)
differentially_expressed_tfs = {}
for n, group in enumerate(ranked_genes.dtype.names):
    # only the first N entries can qualify, so slice instead of scanning every ranked gene;
    # dict.fromkeys drops repeated genes while keeping the rank order
    top_genes = ranked_genes[group][:N]
    differentially_expressed_tfs[n] = list(dict.fromkeys(gene for gene in top_genes if gene in tf_lookup))

print(differentially_expressed_tfs)

In [None]:
# Flatten the per-cluster lists into one list of unique TFs.
# dict.fromkeys removes duplicates while keeping first-seen order, so the gene
# order in the downstream plots is reproducible (a set has no stable order).
all_tfs = list(dict.fromkeys(gene for genes in differentially_expressed_tfs.values() for gene in genes))

In [None]:
np.max(adata_EEC.X)

In [None]:
#sc.pl.rank_genes_groups_dotplot(adata_EEC, var_names = all_tfs)
sc.pl.dotplot(adata_EEC, all_tfs, groupby='leiden',dendrogram=True, layer='sct_logcounts',use_raw=False)

In [None]:
sc.pl.umap(adata_EEC, color=['leiden'], size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=3, cmap=mymap,legend_loc='on data')

In [None]:
sc.pl.umap(adata_EEC, color=['cell_type_annotation_lv1','leiden', 'doublet_calls'], size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=3, cmap=mymap,wspace=0.6)

In [None]:
adata_EEC

In [None]:
sc.pl.umap(adata_EEC, color=all_tfs, size=5, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=5, cmap=mymap, layer='log_dca_counts', save= 'umap_healthy_TFs_EEC_GC.png')

In [None]:
# All gene symbols that begin with 'Mt' (case-sensitive prefix match).
# NOTE(review): if mitochondrial genes were the target, mouse symbols are
# usually prefixed 'mt-', which this filter does not match - confirm intent.
genes_m = list(filter(lambda symbol: str(symbol).startswith('Mt'), adata_EEC.var_names))

In [None]:
genes_m

In [None]:
sc.pl.umap(adata_EEC, color=['Ghrl','Sst','Gcg','Gip','Pyy','Cck','Nts','Sct','Tac1','Tph1','Reg4', 'Spdef','Muc2','Tff3','Lyz1'],layer= 'log_dca_counts',size=5, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=5, cmap=mymap, save = 'EEC_subs_healthy_EECs_GCs_hormones_expr.png')

In [None]:
plot_composition(adata_EEC, y_key='cell_type_annotation_lv1', x_key='sequencing machine', x_rotation=90,figsize= (6, 4))

## paneth subclusters

In [None]:
EECs = ['Goblet/EEC prog. (early)',
'K-cell (Gip+)',
'EC (mature)',
'EEC (Peptide/immature)' ,
'L/I-cell (Glp1+/Cck+)' ,
'EEC prog. (mid)',
'EC prog. (late)',
'D-cell (Sst+)',
'EEC prog. (late/Peptide)',
'EC 2' ,
'X-cell (Ghrl+)']
Paneth = ['Paneth','Paneth prog.', 'Goblet-Paneth-like', 'Goblet-Paneth-like(cycling)']
Progenitors = ['Goblet/EEC prog. (early)','Paneth prog.', 'Tuft prog.','Tuft prog. 2']

In [None]:
adata = sc.read_h5ad('joint_diseased_healthy_with_layers_metadata_corrected_anno_updated.h5ad')

In [None]:
adata = adata[adata.obs['atlas'].isin(['reference'])].copy()

In [None]:
gc.collect()

### metadata update

In [None]:
## add metadata
# Load the sample sheet (read_excel_metadata is not defined in this notebook section;
# presumably it comes from utils.ipynb).
metadata_df =read_excel_metadata(f'/mnt/hdd/data/metadata_mouse_gut.xlsx')
# Drop the rows whose kit is 'Multiome_ATAC_v1'
metadata_df.drop(metadata_df[metadata_df['kit'] == 'Multiome_ATAC_v1'].index, inplace=True)
#metadata_df.drop(metadata_df[metadata_df['condition'].isin(['Ctr','Ctr/WT'])].index, inplace=True)
# Use the folder name as the index so rows can be looked up by label
metadata_df.set_index('folder name', inplace=True)
# Remove the listed columns so they are not copied onto the cells
metadata_df.drop(['Sample Pooling - confounded with Project?','date','Project Name','Link_id','sample name','Cell Count [cells/µl]','Viable Cells [%]','Lib. Concentration [ng/µl]','Lib. Molarity [nM]','Average Lib. Size [bp]','cDNA Cycles','Lib. Cycles','10x Sample Index','Sequencing Depth [reads/cell]','exclusion, reason'], axis=1, inplace=True)

In [None]:
# Copy every metadata column onto the cells: each cell's 'sample' value is used
# as the row label in metadata_df (indexed by folder name).
for col in metadata_df.columns:
    try:
        adata.obs[col] = adata.obs['sample'].apply(lambda x: metadata_df.at[x, col])
    except KeyError as err:
        # a failed lookup leaves this column unassigned; only a message is printed
        print(f'no such key: {err} in col {col}')

In [None]:
adata.obs.drop(['sample number Minas'],axis=1,inplace=True)

In [None]:
adata.uns['cell_type_annotation_lv1' + '_colors'] =['#d0d0d0',  # ISC
 '#eebcbc',  # TA
 '#fee0d2',  # TA prox
 '#c67a84',  # early Enterocyte
 '#bb4353',  # Enterocyte
 '#eca4d0',  # Tuft prog.
 '#df65b0',  # Tuft prog. 2
 '#e7298a',  # Tuft
 '#e1f3bf',  # Goblet/EEC prog.
 '#d9edf7',  # EEC prog
 '#85c6e6',  # EEC prog. (late/Peptide)
 '#46a8d9',  # EEC (peptide/immature)
 '#339a98',  # X-cell (Ghrl+)
 '#368cbf',  # K-cell (Gip+)
 '#5a72dd',  # L/I-cell (Glp1+/Cck+)
 '#243dae',  # D-cell (Sst+)
 '#d0d1e6',  # EC prog.
 '#aa9dce',  # EC (imm.)
 '#594495',  # EC (mature)
 '#725dae',  # EC 2
 '#fec44f',  # Goblet prog.
 '#dd894e',  # Goblet
 '#7BB98F',  # Paneth prog.
 '#238b45',  # Paneth
 '#ac9470'   # unknown0
]

In [None]:
annotation_key = 'cell_type_annotation_lv1'

In [None]:
# Put the lv1 annotation categories into a fixed order.
# NOTE(review): 'TA (prox.))' ends with a doubled closing parenthesis - confirm
# that this is exactly the category name stored in the data.
adata.obs[annotation_key] = adata.obs[annotation_key].cat.reorder_categories(['ISC', 'TA', 'TA (prox.))', 'early Enterocyte', 'Enterocyte', 
'Tuft prog.', 'Tuft prog. 2', 'Tuft', 
'Goblet/EEC prog. (early)', 'EEC prog. (mid)', 'EEC prog. (late/Peptide)', 'EEC (Peptide/immature)', 'X-cell (Ghrl+)',  'K-cell (Gip+)', 'L/I-cell (Glp1+/Cck+)', 'D-cell (Sst+)',
'EC prog. (late)', 'EC (immature)', 'EC (mature)','EC 2', 
 'Goblet prog. (late)', 'Goblet',  'Paneth prog.', 'Paneth', 'unknown0' ])

In [None]:
adata.obs['cell_type_annotation_lv1'].value_counts()

### get normalized counts

In [None]:
# Read the 'sct_logcounts' and 'scran_counts' layers from the healthy h5ad file.
# Only these two layer elements are read from the open HDF5 file via read_elem.
from anndata._io.specs import read_elem
with h5py.File('/mnt/hdd/data/Healthy/adata_markedDoublets_normalized_initialAnno_rmDoublets_integrated_all_imputed.h5ad', 'r') as f:
    # Examples of reading other elements (kept for reference):
    #sample_column = f['obs/sample'][:]
    #n_counts_column = f['obs/n_counts'][:]
    #https://github.com/scverse/anndata/issues/436:
    #cell_types = read_elem(f["obs/celltype"])
    #umap = read_elem(f["obsm/X_umap"])
    logsct = read_elem(f["layers/sct_logcounts"])
    scran = read_elem(f["layers/scran_counts"])

In [None]:
adata.layers['sct_logcounts'] = logsct
adata.layers['scran_counts'] = scran

### subset to paneth

In [None]:
adata_Paneth = adata[adata.obs['cell_type_annotation_lv1'].isin(Paneth)].copy()

In [None]:
adata_Paneth.X = adata_Paneth.layers['sct_logcounts']

In [None]:
adata_Paneth

In [None]:
del adata
gc.collect()

In [None]:
sc.pp.neighbors(adata_Paneth, use_rep='X_scarches_emb')
sc.tl.leiden(adata_Paneth, resolution=0.4)

In [None]:
sc.tl.paga(adata_Paneth, groups='cell_type_annotation_lv1') #changed from leiden to anno key

In [None]:
sc.pl.paga(adata_Paneth,  fontsize=4,save='paga_paneth_subs.png')#, fontoutline=True, threshold=0.05, max_edge_width=3, min_edge_width=0.01, node_size_scale=3,

In [None]:
sc.tl.umap(adata_Paneth, init_pos='paga')

In [None]:
sc.pl.umap(adata_Paneth, color=['cell_type_annotation_lv1','Project','enriched','phase', 'enrichment proportion','kit','line','strain','leiden','pretty name'], size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=5, cmap=mymap, wspace = 0.6,save='paneth_subs_metadata_healthy.png')

In [None]:
sc.tl.rank_genes_groups(adata_Paneth, 'leiden', method='wilcoxon', layer = 'sct_logcounts', use_raw=False)

# to visualize the results

sc.pl.rank_genes_groups(adata_Paneth)

In [None]:
sc.tl.dendrogram(adata_Paneth, groupby='leiden')

In [None]:
adata_Paneth.X = adata_Paneth.layers['sct_logcounts']

In [None]:
del adata_Paneth.raw

In [None]:
sc.pl.rank_genes_groups_dotplot(adata_Paneth, n_genes=5, key="rank_genes_groups", groupby="leiden",save='dotplot_leiden_paneth_subs_healthy.png')


In [None]:
# Composition plots per leiden cluster (`x_key`), one cell per metadata column (`y_key`).
# `plot_composition` is a helper defined outside this part of the notebook
# (presumably loaded from utils.ipynb - confirm); only this first call saves a figure.
plot_composition(adata_Paneth, y_key='strain', x_key='leiden', x_rotation=90, save='comboplot_paneth_healthy_by_leiden.png')

In [None]:
plot_composition(adata_Paneth, y_key='enriched', x_key='leiden', x_rotation=90)

In [None]:
plot_composition(adata_Paneth, y_key='line', x_key='leiden', x_rotation=90)

In [None]:
# NOTE(review): same `strain` composition as the first call, here without `save`.
plot_composition(adata_Paneth, y_key='strain', x_key='leiden', x_rotation=90)

In [None]:
plot_composition(adata_Paneth, y_key='enrichment proportion', x_key='leiden', x_rotation=90)

In [None]:
plot_composition(adata_Paneth, y_key='Project', x_key='leiden', x_rotation=90)

In [None]:
plot_composition(adata_Paneth, y_key='kit', x_key='leiden', x_rotation=90)

In [None]:
Paneth_markers = [
    "Lyz1",    # Lysozyme
    "Lyz2",
    "Defa27",  # Denfensin a
    "Defa34",   # Defensin alpha 34
    "Defa38",   # Defensin alpha 38
    "Defa43",   # Defensin alpha 43
    "Tff3",    # 
    "Muc2", #Mucin
    "Rps5",    # 
    "Tcf4",    # Transcription Factor 4
    "Sox9",    # SRY-box transcription factor 9
    "Malat1",    #
    "Gfi1"     # Growth Factor Independent 1 Transcription Repressor
]


In [None]:
sc.pl.umap(adata_Paneth, color=['cell_type_annotation_lv1', 'doublet_calls'] + Paneth_markers, layer='log_dca_counts',size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=5, cmap=mymap,wspace=0.7,save = 'umap_paneth_markers_anno.png')

## Prog. vs mature

In [None]:
sc.tl.rank_genes_groups(adata_Paneth, 'cell_type_annotation_lv1', method='wilcoxon', layer = 'sct_logcounts', use_raw=False)

# to visualize the results

sc.pl.rank_genes_groups(adata_Paneth)

In [None]:
sc.pl.rank_genes_groups_dotplot(adata_Paneth, n_genes=15, key="rank_genes_groups", groupby="leiden",save = 'dotplot_paneth_meta_markers_prog_vs_mature.png')

In [None]:
Paneth_markers = [
    "Lyz1",    # Lysozyme
    "Defa27",  # Denfensin a
    "Defa34",   # Defensin alpha 34
    "Defa38",   # Defensin alpha 38
    "Defa43",   # Defensin alpha 43
    "Tff3",    # Goblet marker
    "Muc2", #Mucin goblet marker
    "Rps5",    # Ribosomal Protein S5
    "Prap1",        #Lipid-binding protein which promotes lipid absorption by facilitating MTTP-mediated lipid transfer (mainly triglycerides and phospholipids) and MTTP-mediated apoB lipoprotein assembly and secretion (By similarity). Protects the gastrointestinal epithelium from irradiation-induced apoptosis (By similarity). May play an important role in maintaining normal growth homeostasis in epithelial cells (PubMed:14583459). Involved in p53/TP53-dependent cell survival after DNA damage 
    "Olfm4",        # TA marker
    "Plac8",        #PLAC8 (Placenta Associated 8) is a Protein Coding gene. Diseases associated with PLAC8 include Epilepsy, Familial Adult Myoclonic, 2. Among its related pathways are Innate Immune System and Kidney development. Gene Ontology (GO) annotations related to this gene include chromatin binding.
    "Actg1",        #Actin Gamma 1, Actins are highly conserved proteins that are involved in various types of cell motility and are ubiquitously expressed in all eukaryotic cells.
    "Cyc1", # This gene encodes a subunit of the cytochrome bc1 complex, which plays an important role in the mitochondrial respiratory chain by transferring electrons from the Rieske iron-sulfur protein to cytochrome c.
    "Npm1", # Nucleophosmin 1, Involved in diverse cellular processes such as ribosome biogenesis, centrosome duplication, protein chaperoning, histone assembly, cell proliferation, and regulation of tumor suppressors p53/TP53 and ARF. Binds ribosome presumably to drive ribosome nuclear export.
    "Tcf4",    # Transcription Factor 4
    "Sox9",    # SRY-box transcription factor 9
    "Malat1",    #MALAT1 (Metastasis Associated Lung Adenocarcinoma Transcript 1) is an RNA Gene, This transcript is retained in the nucleus where it is thought to form molecular scaffolds for ribonucleoprotein complexes. It may act as a transcriptional regulator for numerous genes, including some genes involved in cancer metastasis and cell migration, and it is involved in cell cycle regulation.
    "Gfi1"     # Growth Factor Independent 1 Transcription Repressor
]


In [None]:
sc.pl.umap(adata_Paneth, color=['cell_type_annotation_lv1', 'doublet_calls'] + Paneth_markers, layer='log_dca_counts',size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=5, cmap=mymap,wspace=0.7,save = 'umap_paneth_markers_anno.png')

## TFs

In [None]:
with open('TF_mouse_all.txt', 'r') as file:
    TFs = file.read().splitlines()

In [None]:
ranked_genes = adata_Paneth.uns['rank_genes_groups']['names']

In [None]:
ranked_genes

In [None]:
# For every group of the marker test, collect the transcription factors found
# among its N top-ranked genes.
# Result: {group position (0, 1, ...): [TF names in rank order, no duplicates]}
N = 150  # number of top-ranked genes inspected per group
tf_lookup = set(TFs)  # set membership is O(1); `TFs` itself is a list
differentially_expressed_tfs = {}
for n, group in enumerate(ranked_genes.dtype.names):
    # only the first N ranks can qualify, so there is no need to walk all genes
    top_genes = ranked_genes[group][:N]
    # dict.fromkeys removes repeated names while keeping the rank order
    differentially_expressed_tfs[n] = list(
        dict.fromkeys(gene for gene in top_genes if gene in tf_lookup)
    )

print(differentially_expressed_tfs)

In [None]:
# Flatten the list of genes
all_tfs = [gene for genes in differentially_expressed_tfs.values() for gene in genes]
all_tfs = list(set(all_tfs))  # Remove duplicates

In [None]:
#sc.pl.rank_genes_groups_dotplot(adata_EEC, var_names = all_tfs)
sc.pl.dotplot(adata_Paneth, all_tfs, groupby='leiden',dendrogram=True, layer='sct_logcounts',use_raw=False)

In [None]:
sc.pl.umap(adata_Paneth, color=all_tfs, size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=5, cmap=mymap, layer='log_dca_counts', save= 'umap_healthy_TFs_Paneth.png')

### get ambient genes info