# Comparing other feature selection methods with triku
In this notebook we will compare the performance of triku with that of other feature selection methods.

The methods that will be compared will be the following:
* Select genes with highest variance.   
* Scanpy's `sc.pp.highly_variable_genes`: It is based on Seurat's `vst` method, so they should return similar results.
* scry `devianceFeatureSelection()`. This is the feature selection method used in Irizarry's GLM-PCA paper (https://doi.org/10.1186/s13059-019-1861-6). According to its description, it computes a deviance statistic for each row feature of count data, based on a multinomial null model that assumes each feature has a constant rate. Features with large deviance are likely to be informative, whereas uninformative, low-deviance features can be discarded to speed up downstream analyses and reduce the memory footprint. The `fam` parameter will be set to `binomial`, the default.
* M3Drop, which has two main functions:
    * M3Drop: the M3Drop model assumes that the proportion of zeros follows a Michaelis-Menten model. The Michaelis-Menten parameter $K$ is fitted first. Then, for each gene, its parameter $K_i$ is compared to $K$ using a $Z$-test, which returns the selected genes.
    * NBumi: the procedure is similar to the one above, although the model fitted is now a negative binomial; the selection of genes is also done using a $Z$-test.
* `BrenneckeGetVariableGenes` fits a function between CV$^2$ and mean expression. 
* Seurat's `FindVariableFeatures`
* SCTransform
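
As an illustration of the deviance idea described above for scry, the following is a minimal sketch of a per-gene binomial deviance under a constant-rate null. It is not scry's implementation: the function name `binomial_deviance` and the toy count matrix are assumptions made for this example only.

```python
import numpy as np

def binomial_deviance(counts):
    """Per-gene binomial deviance under a constant-rate null (illustrative sketch).

    counts: array of shape (n_cells, n_genes) with raw counts.
    """
    counts = np.asarray(counts, dtype=float)
    n = counts.sum(axis=1, keepdims=True)   # total counts per cell
    pi = counts.sum(axis=0) / n.sum()       # null model: one constant rate per gene
    expected = n * pi                       # expected counts under the null

    def obs_log_ratio(obs, exp):
        # obs * log(obs / exp), using the convention 0 * log(0) = 0
        out = np.zeros_like(obs)
        mask = obs > 0
        out[mask] = obs[mask] * np.log(obs[mask] / exp[mask])
        return out

    term_gene = obs_log_ratio(counts, expected)
    term_rest = obs_log_ratio(n - counts, n - expected)
    return 2 * (term_gene + term_rest).sum(axis=0)

# Toy data: gene 0 has the same rate in every cell, gene 1 is expressed in two cells only.
counts = np.array([[10, 0, 90],
                   [10, 0, 90],
                   [10, 50, 40],
                   [10, 50, 40]])
dev = binomial_deviance(counts)
```

In this toy matrix the constant-rate gene gets a deviance of essentially zero, while the gene that is expressed in only half of the cells gets a large one, which is the behaviour that makes the statistic usable for ranking.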

With the exception of scanpy and triku, the rest of the methods are implemented in R. We will use Jupyter's `%%R` magic command and `anndata2ri` to transform `AnnData` objects into `SingleCellExperiment` objects, and we will define functions that accept that AnnData and return the list of selected features. These functions have to be defined in the notebook and cannot be externalized.

M3Drop requires a normalization step, which will be done in-situ.

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
import triku as tk
import scanpy as sc
import pandas as pd
import numpy as np
import scipy.sparse as spr
import scipy.stats as sts
import os
import gc
from itertools import product
import pickle
import ray
import seaborn as sns
import itertools 

from IPython.display import display, HTML

from tqdm.notebook import tqdm

from bokeh.io import show, output_notebook, reset_output
from bokeh.plotting import figure
from bokeh.models import LinearColorMapper

import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.lines import Line2D

from sklearn.metrics import adjusted_rand_score as ARI
# NOTE: this is the adjusted mutual information score, aliased as NMI throughout the notebook
from sklearn.metrics import adjusted_mutual_info_score as NMI
from sklearn.metrics import silhouette_score, davies_bouldin_score
from sklearn.decomposition import PCA

reset_output()
output_notebook()

In [None]:
list_methods = ['triku', 'm3drop', 'nbumi', 'scanpy', 'seurat', 'sct', 'scry', 'std', 'brennecke',]
list_methods_all = list_methods + ['all', 'random']

palette = [
        '#e91e63',  # triku
        '#81c784',  # m3drop
        '#388e3c',  # nbumi
        '#90caf9',  # scanpy
        '#2196f3',  # seurat
        '#1565c0',  # sctransform
        '#ff9800',  # scry (colors are matched to list_methods by position)
        '#ff5722',  # std
        '#ffca28',  # brennecke
]

palette_all =  palette +  [ 
        '#A5B1C2',  # all
        '#4B6584',  # random
    ]

In [None]:
!python setup.py install

In [None]:
import sys, os
sys.path.insert(0, os.getcwd() + '/code')

from triku_nb_code.comparing_feat_sel import plot_max_var_x_dataset, plot_max_var_x_method, create_dict_UMAPs_datasets, \
get_max_diff_gene, plot_ARI_x_method, plot_ARI_x_dataset, biological_silhouette_ARI_table, plot_lab_org_comparison_scores, \
clustering_binary_search, compare_rankings, compare_values
from triku_nb_code.comparing_feat_sel import create_UMAP_adataset_libprep_org, plot_UMAPs_datasets, plot_XY, biological_silhouette_ARI_table
from triku_nb_code.palettes_and_cmaps import magma, bold_and_vivid, prism
from triku_nb_code.GOEA_figs import scatter_enrichr, barplot_ontologies_all

In [None]:
%matplotlib inline

In [None]:
import anndata2ri
anndata2ri.activate()
%load_ext rpy2.ipython

In [None]:
%%R
# Load all the R libraries we will be using in the notebook
library(M3Drop) # Depends on r-foreign (conda-forge) and Hmisc and reldist (install.packages)
library(scry) # If R < 4, launch commit 9f0fc819
library(Seurat)
library(sctransform)
library(dplyr)

In [None]:
os.makedirs(os.getcwd() + '/exports/comparisons/', exist_ok=True)
os.makedirs(os.getcwd() + '/figures/comparison_figs/png', exist_ok=True)
os.makedirs(os.getcwd() + '/figures/comparison_figs/pdf', exist_ok=True)
data_dir = os.getcwd() + '/data/'

In the following 3 cells we define the functions that obtain the most relevant features. Since some of the calls are to R, they have to be kept as separate cells. We also create the function `create_df_feature_ranking`, which creates two dataframes: one with the evaluation values (p-value, EMD distance, etc.) of each method, and a second one with the ranking of genes based on those values. These dataframes save us from calling the feature selection methods again each time we make a plot. `create_df_feature_ranking` is also kept as a cell because it makes some calls to R.
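
The conversion from values to rankings relies on a double argsort, and the direction depends on whether larger or smaller values are better. The sketch below shows the idea on toy arrays; the helper name `ranks_from_values` and the example numbers are assumptions for illustration and are not part of the notebook's code.

```python
import numpy as np

def ranks_from_values(values, higher_is_better=True):
    """Return the 0-based rank of each entry (0 = most relevant)."""
    values = np.asarray(values, dtype=float)
    order = values.argsort()      # indices that sort the values in ascending order
    if higher_is_better:
        order = order[::-1]       # put the largest value first
    return order.argsort()        # position of each entry within that ordering

scores = np.array([0.2, 1.5, 0.7])         # e.g. a distance or dispersion: larger is better
q_values = np.array([0.04, 0.50, 0.001])   # e.g. q-values: smaller is better

score_ranks = ranks_from_values(scores)
q_ranks = ranks_from_values(q_values, higher_is_better=False)
```

With ranks defined this way, selecting the top `n` features of any method reduces to keeping the entries whose rank is below `n`, regardless of the scale of the original values.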

In [None]:
%%R

run_scry <- function(sce){ #adata
    adata_ret = devianceFeatureSelection(sce, nkeep=dim(sce)[1], assay='X')
    return(adata_ret) #returns adata with stats on .var
} 


run_brennecke <- function(sce){ #df
    res_df <- BrenneckeGetVariableGenes(sce, suppress.plot=TRUE, fdr=100)
    return(res_df) # returns sorted df with genes and stats
}


run_M3Drop <- function(sce){
    norm <- M3DropConvertData(sce, is.counts=TRUE)
    DE_genes <- M3DropFeatureSelection(norm, suppress.plot=TRUE, mt_threshold=50)
    return(DE_genes) # returns sorted df with genes and stats
    
}

run_NBumi <- function(sce){
    count_mat <- NBumiConvertData(sce, is.counts=TRUE)
    DANB_fit <- NBumiFitModel(count_mat)
    NBDropFS <- NBumiFeatureSelectionCombinedDrop(DANB_fit, suppress.plot=TRUE, qval.thresh=10)
    return(NBDropFS)  # returns sorted df with genes and stats
    
}

run_seurat <- function(sce){ #adata
    sce <- FindVariableFeatures(sce, selection.method = "vst", nfeatures = dim(sce)[1])
    index <- c(1:dim(sce)[1])
    names <- VariableFeatures(sce)
    df_seurat <- HVFInfo(sce)[VariableFeatures(sce), ]
    return(df_seurat)
} 


run_sct <- function(sce){ #adata
    sce <- SCTransform(sce, verbose = FALSE)
    df_sct <- sce@assays$SCT@SCTModel.list$model1@feature.attributes
    return(df_sct)
} 


In [None]:
def run_scanpy(adata):
    adata_copy = adata.copy()
    if 'log1p' not in adata_copy.uns:
        sc.pp.log1p(adata_copy)
    # n_top_genes counts genes, so use n_vars (len(adata_copy) is the number of cells)
    ret = sc.pp.highly_variable_genes(adata_copy, n_top_genes=adata_copy.n_vars, inplace=False)
    df = pd.DataFrame(ret)
    df =  df.set_index(adata_copy.var_names)
    del adata_copy; gc.collect()
    return df # returns df with stats

def run_variable(adata):
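    if spr.issparse(adata.X):
        # Var[x] = E[x^2] - E[x]^2, computed without densifying the matrix
        std = adata.X.power(2).mean(0) - np.power(adata.X.mean(0), 2) 
        std = np.asarray(std).flatten()        
    else:
        std = adata.X.std(0)
        
    # NOTE: the sparse branch returns the variance and the dense branch the standard deviation.
    # Both give the same gene ranking, because the square root is monotonic.
    return std #returns vector with order as var_names 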

def run_triku(adata, seed, dist_conn='conn'):
    adata_copy = adata.copy()
    # We have seen that whitening the matrix (not only zero-centring) yielded the best results. This is not directly applicable 
    X = adata_copy.X.toarray() if spr.issparse(adata_copy.X) else adata_copy.X  # PCA below needs a dense array
    pca = PCA(n_components=30, whiten=True, svd_solver="auto", random_state=seed,).fit_transform(X)
        
    adata_copy.obsm['X_pca'] = pca
    sc.pp.neighbors(adata_copy, random_state=seed, metric='cosine', n_neighbors=int(len(adata_copy) ** 0.5))
    tk.tl.triku(adata_copy, n_windows=100, verbose='error', dist_conn=dist_conn)
    
    if dist_conn == 'conn':
        print('n_HVG: ', adata_copy.var['highly_variable'].sum())
    d = adata_copy.var['triku_distance'] #pd series with distance
    del adata_copy; gc.collect()
    return d

In [None]:
def create_df_feature_ranking(adatax, title_prefix, apply_log=False):
    """
    Create a dataframe with the ranking of features, and another one with the feature values. The adata must be the raw
    adata. From it we will create the adata_df required by some R methods.
    
    After each method is run, we fill the values dataframe with the metric that the method uses for feature selection, 
    and the ranks dataframe with the rankings based on those values (0, 1, 2, etc.). 
    We keep two separate dataframes because the values might be needed for other purposes. The ranks dataframe is useful
    because the values have different sort orders depending on the column (ascending for M3Drop and NBumi, descending for the rest).
    """
    
    adata = adatax.copy()
    sc.pp.filter_genes(adata, min_cells=1) 
    sc.pp.filter_cells(adata, min_genes=1)
    adata.layers['raw'] = adata.X.copy()
    
    if apply_log:
        sc.pp.log1p(adata)
    
    adata_df = pd.DataFrame(adata.X.T, index=adata.var_names, columns=adata.obs_names)
    adata_raw_df = pd.DataFrame(adata.layers['raw'].T, index=adata.var_names, columns=adata.obs_names)
    
    adata_short = sc.AnnData(X = adata.X[:,:]) # we have to create a clean adata because some columns break Rpush
    adata_short.var_names, adata_short.obs_names = adata.var_names[:], adata.obs_names[:]

    adata_seurat = adata_short.copy()
    adata_seurat.obs['nCount_RNA'] = adata_raw_df.values.sum(0).astype(int)
    adata_seurat.obs['nFeature_RNA'] = (adata_raw_df.values > 0).sum(0)

    %Rpush adata_seurat
    %Rpush adata_short
    %Rpush adata_df
    %Rpush adata_raw_df

    %R assay(adata_seurat, "raw") <- data.matrix(adata_raw_df)
    %R adata_seurat <- as.Seurat(adata_seurat, counts = 'X', data = 'raw')
    
    print('Outside R', adata.shape, adata_short.shape)
    d = %R  dim(adata_short)
    print('Inside R', d)
       
    if 'Group' in adata.obs:
        adata_groups = [i.replace('Group', '') for i in adata.obs['Group']]
        adata.obs['groupn'] = adata_groups
    
    index, columns = adata.var_names, ['triku', 'triku_dist', 'scanpy', 'std', 'scry', 'brennecke', 'm3drop', 'nbumi', 'seurat', 'sct']

    df_values, df_ranks = pd.DataFrame(index=index, columns=columns), pd.DataFrame(index=index, columns=columns)
    
    
    df_emd_distance = run_triku(adata, seed=0, dist_conn='conn')
    df_values.loc[df_emd_distance.index, f'triku'] = df_emd_distance.values

    
    df_emd_distance = run_triku(adata, seed=0, dist_conn='dist')
    df_values.loc[df_emd_distance.index, f'triku_dist'] = df_emd_distance.values
    
    
    scanpy_ret = run_scanpy(adata)
    df_values.loc[scanpy_ret.index, 'scanpy'] = scanpy_ret['dispersions_norm'].values
    assert len(df_values.index) == len(adata.var_names)
    
    std_ret = run_variable(adata)
    df_values.loc[:, 'std'] = std_ret
    assert len(df_values.index) == len(adata.var_names)
    
    scry_ret = %R run_scry(adata_short)
    df_values.loc[scry_ret.var.index, 'scry'] = scry_ret.var['binomial_deviance'].values
    assert len(df_values.index) == len(adata.var_names)
    
    brennecke_ret = %R run_brennecke(adata_df)
    df_values.loc[brennecke_ret.index, 'brennecke'] = brennecke_ret['effect.size'].values
    assert len(df_values.index) == len(adata.var_names)
    
    M3Drop_ret = %R run_M3Drop(adata_df)
    df_values.loc[M3Drop_ret.index, 'm3drop'] = M3Drop_ret['q.value'].values
    assert len(df_values.index) == len(adata.var_names)
    
    NBumi_ret = %R run_NBumi(adata_df)
    df_values.loc[NBumi_ret.index, 'nbumi'] = NBumi_ret['q.value'].values
    assert len(df_values.index) == len(adata.var_names)
    
    seurat_ret = %R run_seurat(adata_seurat)
    df_values.loc[seurat_ret.index, 'seurat'] = seurat_ret['variance.standardized'].values
    assert len(df_values.index) == len(adata.var_names)
    
    sct_ret = %R run_sct(adata_seurat)
    df_values.loc[sct_ret.index, 'sct'] = sct_ret['residual_variance'].values
    assert len(df_values.index) == len(adata.var_names)  

    
    # Fill df_ranks with a double argsort. M3Drop and NBumi are not reversed ([::-1]) because they are q-values (lower is better)
    for col in columns:
        df_ranks[col] = df_values[col].values.argsort()[::-1].argsort()
    for col in ['m3drop', 'nbumi']:
        df_ranks[col] = df_values[col].values.argsort().argsort() # double argsort to return the rank!
    
    df_ranks.to_csv(os.getcwd() + '/exports/comparisons/' + title_prefix + '_feature_ranks.csv')
    df_values.to_csv(os.getcwd() + '/exports/comparisons/' + title_prefix + '_feature_values.csv')
    print('df_ranks', df_ranks.shape)
    
    del adata; gc.collect()
    return df_values, df_ranks

# Random datasets
For this section we will use the random datasets generated with splatter.
To evaluate the performance of the feature selection methods, we will use two metrics, maximum deviation and ARI, explained below.

In [None]:
splatter_dir = os.getcwd() + '/data/splatter/'
list_deprobs = [0.0065, 0.008, 0.01, 0.016, 0.025, 0.05, 0.1, 0.3]

**THIS PROCESS TAKES ~ 4 HOURS!**

This cell also fails to load occasionally; running it again usually solves the problem.

In [None]:
import logging
triku_logger = tk.logg.triku_logger
triku_logger.setLevel(logging.WARNING)

In [None]:
for deprob in tqdm(list_deprobs[2:]):
    print(f'Deprob {deprob}')
    adata_deprob = sc.read(splatter_dir + f'/splatter_deprob_{deprob}.loom', sparse=False)
    print(f'Adata {deprob} loaded: {adata_deprob.X.shape}')
    df_values, df_ranks = create_df_feature_ranking(adata_deprob, f'scatter_{deprob}')

## ARI / NMI
ARI on random datasets measures the effectiveness of the feature selection. The random datasets were prepared with different probabilities of a gene being differentially expressed, so that we can compare the leiden clustering solution with the 9 populations. Triku can be run with different seeds, whereas the rest of the methods are deterministic. However, leiden clustering can be run with a seed in all cases. Therefore, we are going to run all processes with 10 seeds (although the deterministic processes will be run once).

To apply the ARI we need to run leiden so that it yields as many clusters as there are populations in the dataset. Since leiden is controlled by a resolution parameter, we need to adjust the resolution to match the number of clusters. To do that we are going to implement a binary search-like algorithm. We start with resolutions 0.3 and 2 (may change in the future). If either of those yields the desired number of clusters, we are done. Otherwise, we find the midpoint and run the clustering; if it yields the number of populations, we stop. Otherwise, we move the upper or lower resolution to the midpoint, so that the desired number of clusters stays between the two bounds. This algorithm will try at most 5 times (it gets to resolution differences of ~0.05, which is fair).
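
The search over the resolution can be sketched as follows. This is only an illustration of the idea, not the `clustering_binary_search` function imported above: `count_clusters` is a hypothetical callable that runs the clustering at a given resolution and returns the number of clusters, and it is assumed to be roughly non-decreasing in the resolution.

```python
def resolution_binary_search(count_clusters, n_target, min_res=0.3, max_res=2.0, max_depth=5):
    """Search for a resolution at which count_clusters(res) equals n_target."""
    # Check both ends of the interval first
    for res in (min_res, max_res):
        if count_clusters(res) == n_target:
            return res

    low, high = min_res, max_res
    mid = (low + high) / 2
    for _ in range(max_depth):
        mid = (low + high) / 2
        n_found = count_clusters(mid)
        if n_found == n_target:
            break
        if n_found < n_target:
            low = mid     # too few clusters: search higher resolutions
        else:
            high = mid    # too many clusters: search lower resolutions
    return mid            # an exact match, or the last midpoint tried

# Stand-in for a real clustering call: more clusters at higher resolution.
def mock_count_clusters(res):
    return int(res * 10)

res = resolution_binary_search(mock_count_clusters, n_target=9)
res_at_bound = resolution_binary_search(mock_count_clusters, n_target=20)
```

If no resolution gives the exact number of clusters within `max_depth` steps, this sketch returns the last midpoint, so the caller should still check how many clusters were actually obtained.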

To calculate the ARI, we need to load a dataset, select a number of features, and create the dataframe with seeds as rows (to see the variation on clustering / triku) and the methods as columns. Because creating each dataframe takes time (there are 70 cells to be filled), we will choose two datasets (DE = 0.01 and 0.025) and two numbers of features (100 and 500), which show good results in the previous sections.

In [None]:
save_dir = os.getcwd() + '/exports/comparisons/'

In [None]:
min_res, max_res, max_depth = 0.3, 2, 6

In [None]:
@ray.remote
def leiden_adata_NMI_ARI(deprobx):
    print(deprobx)
    adata_all = sc.read(splatter_dir + f'/splatter_deprob_{deprobx}.loom', sparse=False)
    sc.pp.subsample(adata_all, 0.4) # We shorten this to make the calculations not take 8 hours!
    sc.pp.filter_genes(adata_all, min_cells=1)
    sc.pp.filter_cells(adata_all, min_genes=1)
    sc.pp.log1p(adata_all)
    
    for n_features in [250, 500]:
        print(deprobx, n_features)
        if not os.path.exists(os.getcwd() + f'/exports/comparisons/NMI_scatter_{deprobx}_n_features_{n_features}.csv'):
            df_feature_ranks = pd.read_csv(os.getcwd() + '/exports/comparisons/' + f'scatter_{deprobx}' + '_feature_ranks.csv', index_col=0)

            list_methods = df_feature_ranks.columns.tolist() + ['all', 'random']

            df_NMI = pd.DataFrame(index=[f'seed_{i}' for i in range(10)], columns=list_methods)
            df_ARI = pd.DataFrame(index=[f'seed_{i}' for i in range(10)], columns=list_methods)

            for seed in range(10):
                print(deprobx, n_features, seed)
                for method in tqdm(list_methods):
                    if method == "all":
                        feats = df_feature_ranks[f'triku'].sort_values().index[:]
                    elif method == "random":
                        array_selection = np.array([False] * len(df_feature_ranks))
                        array_selection[np.random.choice(np.arange(len(df_feature_ranks)), n_features, replace=False)] = True
                        
                        feats = df_feature_ranks[f'triku'].sort_values().index[array_selection]
                    else:
                        feats = df_feature_ranks[method].sort_values().index[:n_features]
                    
                    adata_groups = [i.replace('Group', '') for i in adata_all.obs['Group']]
                    c_f, res = clustering_binary_search(adata_all, min_res, max_res, max_depth, seed, len(list(dict.fromkeys(adata_groups))), feats, apply_log=False)
                    NMS = NMI(c_f, adata_groups)
                    ARS = ARI(c_f, adata_groups)

                    df_NMI.loc[f'seed_{seed}', method] = NMS
                    df_ARI.loc[f'seed_{seed}', method] = ARS
            
            print(os.getcwd() + f'/exports/comparisons/NMI_scatter_{deprobx}_n_features_{n_features}.csv')
            df_NMI.to_csv(os.getcwd() + f'/exports/comparisons/NMI_scatter_{deprobx}_n_features_{n_features}.csv')
            df_ARI.to_csv(os.getcwd() + f'/exports/comparisons/ARI_scatter_{deprobx}_n_features_{n_features}.csv')

In [None]:
ray.init(num_cpus=min(os.cpu_count(), len(list_deprobs)), ignore_reinit_error=True)
ray_get = ray.get([leiden_adata_NMI_ARI.remote(deprobx) for deprobx in list_deprobs])
ray.shutdown()

In [None]:
help(plot_lab_org_comparison_scores)

#### Figure 3

In [None]:
for n_feats in ['250', '500']:
    list_files = [f'NMI_scatter_{deprob}_n_features_{n_feats}.csv' for deprob in list_deprobs[::-1]]
    fig = plot_lab_org_comparison_scores(f'NMI-{n_feats}', '', save_dir, [''], increasing=0, mode='ARI', list_files=list_files, 
                                       title=f'NMI on artificial datasets, {n_feats} features', 
                                       filename=f'NMI_{n_feats}-features', do_return=True, FS_methods=list_methods_all, palette=palette_all)
    _ = fig.axes[1].set_xticklabels([item.get_text().replace('scatter ', '') for item in fig.axes[1].get_xticklabels()], rotation=0, ha='center')
    for fmt in ["png", "pdf"]:
        fig.savefig(f"{os.getcwd()}/figures/comparison_figs/{fmt}/NMI_{n_feats}-features.{fmt}", bbox_inches="tight")

#### Figure S1

In [None]:
for n_feats in ['250', '500']:
    list_files = [f'ARI_scatter_{deprob}_n_features_{n_feats}.csv' for deprob in list_deprobs[::-1]]
    fig = plot_lab_org_comparison_scores(f'ARI-{n_feats}', '', save_dir, [''], increasing=0, mode='ARI', list_files=list_files, 
                                       title=f'ARI on artificial datasets, {n_feats} features', 
                                       filename=f'ARI_{n_feats}-features', do_return=True, FS_methods=list_methods_all, palette=palette_all)
    _ = fig.axes[1].set_xticklabels([item.get_text().replace('scatter ', '') for item in fig.axes[1].get_xticklabels()], rotation=0, ha='center')
    for fmt in ["png", "pdf"]:
        fig.savefig(f"{os.getcwd()}/figures/comparison_figs/{fmt}/ARI_{n_feats}-features.{fmt}", bbox_inches="tight")

With the lower number of features (250), scanpy performs best at lower DE probabilities (up to 0.025) but worse at the higher ones (0.1 or 0.3), where scry is the best method, mainly because the features that separate the smaller clusters are the ones with the highest expression, and those are the ones selected by scry. However, at smaller DE probabilities, the features that separate the dataset the most are the ones with mid expression levels, which are best picked by triku.

In [None]:
df_values = pd.read_csv(f'{save_dir}/scatter_0.01_feature_values.csv', index_col=0)
df_ranks = pd.read_csv(f'{save_dir}/scatter_0.01_feature_ranks.csv', index_col=0)

In [None]:
df_ranks

In [None]:
df_values

In [None]:
adata = sc.read(splatter_dir + f'/splatter_deprob_0.01.loom', sparse=False)
sc.pp.subsample(adata, 0.4) # We shorten this to make the calculations not take 8 hours!

adata.raw = adata
sc.pp.pca(adata)
sc.pp.neighbors(adata)
sc.tl.umap(adata)

In [None]:
df_bokeh = pd.DataFrame({'m': np.log10(adata.X.mean(0)), 
                         'z': df_values['triku'].loc[adata.var_names].values, 
                         'n': df_values.index.values})
                   
p = figure(tools="box_zoom,hover,reset", plot_height=600, plot_width=600, tooltips=[("Gene","@n")])
p.scatter('m', 'z', source=df_bokeh, alpha=0.7, line_color=None)
show(p)

#### Figure S2

In [None]:
triku_hvg = df_ranks['triku'].values < 250
scry_hvg = df_ranks['scry'].values < 250
std_hvg = df_ranks['std'].values < 250
scanpy_hvg = df_ranks['scanpy'].values < 250

fig, axs = plt.subplots(1, 4, figsize=(12, 3))
names = ['triku', 'scanpy', 'std', 'scry',]
hvgs = [triku_hvg, scanpy_hvg, std_hvg, scry_hvg]
colors = ['#e73f74', '#7f3c8d', '#11a579', '#3969ac']

for i in range(4):
    axs[i].scatter(np.log10(adata.X.mean(0))[~hvgs[i]][::5], 
                   df_values['triku'].loc[adata.var_names].values[~hvgs[i]][::5], c="#dedede", s=2, alpha=0.7)
    axs[i].scatter(np.log10(adata.X.mean(0))[hvgs[i]], 
                   df_values['triku'].loc[adata.var_names].values[hvgs[i]], c=colors[i], s=2, label=names[i])
    axs[i].legend()

fig.text(0.0, 0.5, 'Wasserstein distance', va='center', rotation='vertical')
fig.text(0.5, 0.0, 'log$_{10}$ mean expression', va='center', rotation='horizontal')
plt.tight_layout()
plt.savefig(os.getcwd() + '/figures/comparison_figs/pdf/barplots_scatter.pdf', format='pdf')


list_list_genes = [['Gene1118', 'Gene8599', 'Gene1513', 'Gene1479'],     # triku only
                   ['Gene6723', 'Gene6625', 'Gene9796', 'Gene935'],      # triku + scanpy
                   ['Gene12841', 'Gene10739', 'Gene6729', 'Gene12240'],  # scanpy
                   ['Gene9545', 'Gene4459', 'Gene383', 'Gene12455'],     # all
                   ['Gene1633', 'Gene10792', 'Gene2496', 'Gene12497']]   # std + scry

list_bar_colors = ['#94346E', '#E17C05', '#0F8554', '#1D6996']

for lg_idx, list_genes in enumerate(list_list_genes):
    fig, axs = plt.subplots(1, 5, figsize=(3*5, 3))
    
    hvg = np.isin(adata.var_names, list_genes)
    axs[0].scatter(np.log10(adata.X.mean(0))[~hvg][::5], 
                   df_values['triku'].loc[adata.var_names].values[~hvg][::5], c="#bcbcbc", s=2, alpha=0.7)
    
    for i in range(1, 5):
        gene = list_genes[i - 1]
        is_gene = np.isin(adata.var_names, gene)
        # Highlight the gene on the distance vs. expression scatter (once per gene)
        axs[0].scatter(np.log10(adata.X.mean(0))[is_gene], 
                       df_values['triku'].loc[adata.var_names].values[is_gene], c=list_bar_colors[i - 1], s=7)

        # One bar per group with the mean expression of the gene in that group
        for group in range(10):
            data_values = adata[adata.obs['Group'] == 'Group' + str(group + 1)].X[:, np.argwhere(adata.var_names == gene)[0]].flatten()
            axs[i].bar(group + 1, np.mean(data_values), color=list_bar_colors[i - 1])
            
    axs[0].set_ylabel('Wasserstein distance')
    axs[0].set_xlabel('log$_{10}$ mean expression')
    axs[1].set_ylabel('Mean group expression')
    axs[1].set_xlabel('Group')

    plt.tight_layout()

    plt.savefig(os.getcwd() + f'/figures/comparison_figs/pdf/barplots_{lg_idx}.pdf', format='pdf')

# Ding et al. / Mereu et al. datasets
Now that we have seen that triku outperforms other methods in artificial datasets, at least when there is intrinsic noisiness, we are going to apply similar metrics to biological datasets. We are first going to use Mereu's and Ding's human + mouse benchmarking datasets. These datasets will help us see biases in the performance of the methods, and will also act as a validation of the results from the original papers.

In this part, due to the large number of datasets, and also due to the heterogeneity of genes, we will not use different numbers of features. Instead, we will run triku with seed 0, and use the number of features that it selects by default as the number of features to select with the rest of the methods. This means that different datasets will have different numbers of features, although each dataset will have the same number of features across methods.
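
A minimal sketch of this selection scheme is shown below, using 0-based ranks as in the ranks dataframes (0 = most relevant). The helper `select_top_n`, the toy ranks, and the value of `n_triku` are assumptions made for illustration only.

```python
import numpy as np

def select_top_n(ranks_by_method, n_features):
    """Boolean mask of the n_features best-ranked genes for every method.

    ranks_by_method: dict mapping a method name to an array of 0-based ranks,
    one entry per gene.
    """
    return {method: np.asarray(ranks) < n_features
            for method, ranks in ranks_by_method.items()}

# Toy ranks for five genes; the method names are only labels here.
ranks = {
    'triku':  np.array([0, 3, 1, 4, 2]),
    'scanpy': np.array([4, 0, 2, 1, 3]),
}
n_triku = 2   # assumed number of features that triku selected automatically
masks = select_top_n(ranks, n_triku)
```

Every method ends up with exactly `n_triku` selected genes, so differences in the downstream scores cannot be attributed to the number of features used.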

The two main methods that we will use to evaluate the feature selection are NMI and Silhouette scores.
* NMI uses the assigned cell types from the paper (Mereu et al. use MatchSCore2 and Ding et al. use a custom algorithm) and applies the same binary search for resolution.
* Silhouette. It is used in two forms:
    * Apply the same resolution to all datasets and all methods using the binary search, and get the Silhouette from there.
    * Apply Silhouette to the benchmark-assigned cell types.

## Create feature ranking dataframes

**This process takes ~3 hours (> 12 with scmer)**

In [None]:
save_dir = os.getcwd() + '/exports/comparisons/'
mereu_dir = os.getcwd() + '/data/Mereu_2020/'
ding_dir = os.getcwd() + '/data/Ding_2020/'

In [None]:
for libprep in tqdm(['10X', 'CELseq2', 'ddSEQ', 'Dropseq', 'inDrop', 'QUARTZseq', 'SingleNuclei', 'SMARTseq2']):
    for org in ['human', 'mouse']:
        print(libprep, org)
        if os.path.exists(save_dir + f'mereu_{libprep}_{org}-log_feature_values.csv'):
            print(f'{libprep}, {org} exists!')
        else:
            if not os.path.exists(mereu_dir + f'{libprep}_{org}.h5ad'):
                print(libprep, org, 'is not available')
            else:
                adata_libprep = sc.read(mereu_dir + f'{libprep}_{org}.h5ad')
                create_df_feature_ranking(adata_libprep, f'mereu_{libprep}_{org}-log', apply_log=True)

## Calculate scores

In [None]:
@ray.remote
def run_ARI_silhouette_rem(lib_prep, org, seed, lab, adata_dir, save_dir):
    if os.path.exists(adata_dir + f'{lib_prep}_{org}.h5ad'):
        # The file name follows the file_root used below ({lab}_{lib_prep}_{org}-log)
        if os.path.exists(save_dir + f'{lab}_{lib_prep}_{org}-log_comparison-scores_seed-{seed}.csv'):
            print(f'{lib_prep}, {org}, {seed} exists!')
        else:
            adata = sc.read_h5ad(adata_dir + f'{lib_prep}_{org}.h5ad')
            print(adata)
            cell_type = 'cell_types' if 'cell_types' in adata.obs else 'CellType' # The cell type column is named differently across datasets
            df_rank = pd.read_csv(os.getcwd() + f'/exports/comparisons/{lab}_{lib_prep}_{org}-log_feature_ranks.csv', index_col=0)

            biological_silhouette_ARI_table(adata, df_rank, outdir=save_dir, file_root=f'{lab}_{lib_prep}_{org}-log', seed=seed, 
                                                        cell_types_col=cell_type, n_procs=1)   
    else:
        print(adata_dir + f'{lib_prep}_{org}.h5ad does not exist!')

In [None]:
# Mereu's datasets
save_dir = os.getcwd() + '/exports/comparisons/'
adata_dir = data_dir + 'Mereu_2020/'


lib_preps = ['SingleNuclei', 'Dropseq', 'inDrop', '10X', 'SMARTseq2', 'CELseq2', 'QUARTZseq', 'ddSEQ'] 
orgs = ['mouse', 'human'] 
result = list(product(*[lib_preps, orgs, range(5)]))

ray.init(ignore_reinit_error=True, num_cpus=min(len(result), os.cpu_count()))

list_id = [run_ARI_silhouette_rem.remote(lib_prep, org, seed, 'mereu', adata_dir, save_dir) for lib_prep, org, seed in result]
list_results = ray.get(list_id)

ray.shutdown()

In [None]:
# Ding's datasets
save_dir = os.getcwd() + '/exports/comparisons/'
adata_dir = data_dir + 'Ding_2020/'


lib_preps = ['10X', 'CELseq2', 'Dropseq', 'inDrop', 'sci-RNAseq', 'Seq-Well', 'SingleNuclei', 'SMARTseq2']
orgs = ['mouse', 'human'] 
result = list(product(*[lib_preps, orgs, range(5)]))

ray.init(ignore_reinit_error=True, num_cpus=min(len(result), os.cpu_count()))

list_id = [run_ARI_silhouette_rem.remote(lib_prep, org, seed, 'ding', adata_dir, save_dir) for lib_prep, org, seed in result]
list_results = ray.get(list_id)

ray.shutdown()

#### Figure 4A and 5A

In [None]:
for lab in ['ding', 'mereu']:
    save_dir = os.getcwd() + '/exports/comparisons/'
    fig = plot_lab_org_comparison_scores(lab, org='-log', read_dir=save_dir, variables=['NMI'], figsize=(16, 4), title=f'NMI on {lab} datasets (log)', 
                                  filename=f'{lab}-NMI-log', do_return=True, sort_values='descending', FS_methods=list_methods_all, palette=palette_all)
    _ = fig.axes[1].set_xticklabels([item.get_text().replace('-log', '') for item in fig.axes[1].get_xticklabels()])
    for fmt in ["png", "pdf"]:
        fig.savefig(f"{os.getcwd()}/figures/comparison_figs/{fmt}/{lab}-NMI-log.{fmt}", bbox_inches="tight")

#### Figure S3A and S4A

In [None]:
for lab in ['ding', 'mereu']:
    save_dir = os.getcwd() + '/exports/comparisons/'
    fig = plot_lab_org_comparison_scores(lab, org='-log', read_dir=save_dir, variables=['ARI'], figsize=(16, 4), title=f'ARI on {lab} datasets (log)', 
                                  filename=f'{lab}-ARI-log', do_return=True, sort_values='descending', FS_methods=list_methods_all, palette=palette_all)
    _ = fig.axes[1].set_xticklabels([item.get_text().replace('-log', '') for item in fig.axes[1].get_xticklabels()])
    for fmt in ["png", "pdf"]:
        fig.savefig(f"{os.getcwd()}/figures/comparison_figs/{fmt}/{lab}-ARI-log.{fmt}", bbox_inches="tight")

#### Figure 4B,5B

In [None]:
for lab in ['ding', 'mereu']:
    save_dir = os.getcwd() + '/exports/comparisons/'
    fig = plot_lab_org_comparison_scores(lab, org='-log', read_dir=save_dir, variables=['Sil_bench_all_hvg'], figsize=(16, 4), 
                                       title=f'Silhouette on {lab} datasets, cell types on selected features (log)',
                                       filename=f'{lab}-silhouette_selected features_celltypes-log', do_return=True, sort_values='descending', 
                                         FS_methods=list_methods_all, palette=palette_all)
    _ = fig.axes[1].set_xticklabels([item.get_text().replace('-log', '') for item in fig.axes[1].get_xticklabels()])
    for fmt in ["png", "pdf"]:
        fig.savefig(f"{os.getcwd()}/figures/comparison_figs/{fmt}/{lab}-silhouette_selected features_celltypes-log.{fmt}", bbox_inches="tight")

#### Figure S3B and S4B

In [None]:
for lab in ['ding', 'mereu']:
    save_dir = os.getcwd() + '/exports/comparisons/'
    fig = plot_lab_org_comparison_scores(lab, org='-log', read_dir=save_dir, variables=['Sil_leiden_all_hvg'], figsize=(16, 4), 
                                   title=f'Silhouette on {lab} datasets, leiden clusters on selected features (log)', 
                                  filename=f'{lab}-silhouette_selected features_leiden-log', do_return=True, sort_values='descending', 
                                  FS_methods=list_methods_all, palette=palette_all)
    _ = fig.axes[1].set_xticklabels([item.get_text().replace('-log', '') for item in fig.axes[1].get_xticklabels()])
    for fmt in ["png", "pdf"]:
        fig.savefig(f"{os.getcwd()}/figures/comparison_figs/{fmt}/{lab}-silhouette_selected features_leiden-log.{fmt}", bbox_inches="tight")

## Explainability of results
Although we see that triku has promising results, we were struck by how large the gap in scores is between std, scry and brennecke on one side and triku, nbumi and m3drop on the other. In this section we are going to run some checks to see if we can find out why the difference is so big.

### Overlap heatmaps
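
The heatmaps in this section report the Jaccard index between the top-ranked features of each pair of methods. As a minimal sketch of that computation, the toy example below uses a hypothetical rank table (the gene names, method names and the helper `jaccard_top_n` are made up for illustration, and rank 0 is assumed to be the best rank):

```python
import pandas as pd


def jaccard_top_n(df_ranks, method_a, method_b, n_top):
    """Jaccard index between the n_top best-ranked genes of two methods."""
    top_a = set(df_ranks[method_a].nsmallest(n_top).index)
    top_b = set(df_ranks[method_b].nsmallest(n_top).index)
    return len(top_a & top_b) / len(top_a | top_b)


# Hypothetical rank table: rows are genes, columns are FS methods.
toy_ranks = pd.DataFrame(
    {"method_a": [0, 1, 2, 3, 4], "method_b": [4, 0, 1, 2, 3]},
    index=["g1", "g2", "g3", "g4", "g5"],
)

# method_a selects {g1, g2, g3}, method_b selects {g2, g3, g4}:
# 2 shared genes out of 4 distinct ones.
overlap = jaccard_top_n(toy_ranks, "method_a", "method_b", n_top=3)
print(overlap)
```

A value of 1 means both methods select exactly the same genes, and 0 means they share none.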

In [None]:
def plot_heatmaps_jaccard(df_ranks, n_HVG, fig_save_dir='', title='', ax=None, lab='ding'):
    df_heatmap = pd.DataFrame(np.nan, index=df_ranks.columns, columns=df_ranks.columns)
    
    for row_idx, row in enumerate(df_ranks.columns):
        for col_idx, col in enumerate(df_ranks.columns):
            if row_idx >= col_idx:
                row_names = set(df_ranks.sort_values(by=row).index[:n_HVG].values)
                col_names = set(df_ranks.sort_values(by=col).index[:n_HVG].values)
                
                jaccard = len(row_names & col_names)/len(row_names | col_names)
                df_heatmap.loc[row, col] = jaccard
    
    h = sns.heatmap(df_heatmap, cbar=False, ax=ax, annot=True, fmt='.1g')
    h.set_title(title)
    
    plt.tight_layout()
    for fmt in ['png', 'pdf']:
        plt.savefig(f'{fig_save_dir}/{lab}_heatmap_overlap_features.{fmt}', bbox_inches='tight')
    

#### Figure 6

In [None]:
fig, axs = plt.subplots(1, 4, figsize=(4.3*4, 4.3))

list_files = ['ding_10X_mouse-log_comparison-scores', 'ding_Dropseq_human-log_comparison-scores',
              'ding_Seq-Well_human-log_comparison-scores', ]

for libprep_idx, libprep in enumerate(list_files):
    pre = '' if libprep_idx == 0 else '_'
    for method_idx, method in enumerate(list_methods_all):
        list_y = []
        for seed in range(5):
            df = pd.read_csv(os.getcwd() + f'/exports/comparisons/{libprep}_seed-{seed}.csv', index_col=0)
            list_y.append(df.loc['NMI', method])

        axs[0].bar(
            libprep_idx + (method_idx - len(list_methods_all) // 2) * 0.08,
            np.mean(list_y), width=0.08, yerr=np.std(list_y), color=palette_all[method_idx], 
            label=pre + method,
        )
        
axs[0].set_xticks([0, 1, 2])
axs[0].set_xticklabels(['10X mouse', 'Dropseq human', 'Seq-Well human'], rotation=45, ha='right')
axs[0].set_title('NMI on ding datasets')
axs[0].legend(ncol=2, handleheight=0.3, labelspacing=0.05, prop={'size': 8}, frameon=False)
axs[0].set(frame_on=False)

# Next are the heatmaps of gene overlap
df_ranks = pd.read_csv(os.getcwd() + '/exports/comparisons/ding_10X_mouse-log_feature_ranks.csv', index_col=0)
df_ranks = df_ranks[list_methods]
plot_heatmaps_jaccard(df_ranks, n_HVG=1623, fig_save_dir=os.getcwd() + '/figures/comparison_figs', 
                      title='10X mouse', ax=axs[1])

df_ranks = pd.read_csv(os.getcwd() + '/exports/comparisons/ding_Dropseq_human-log_feature_ranks.csv', index_col=0)
df_ranks = df_ranks[list_methods]
plot_heatmaps_jaccard(df_ranks, n_HVG=1294, fig_save_dir=os.getcwd() + '/figures/comparison_figs', 
                      title='Dropseq human', ax=axs[2])

df_ranks = pd.read_csv(os.getcwd() + '/exports/comparisons/ding_Seq-Well_human-log_feature_ranks.csv', index_col=0)
df_ranks = df_ranks[list_methods]
plot_heatmaps_jaccard(df_ranks, n_HVG=1256, fig_save_dir=os.getcwd() + '/figures/comparison_figs', 
                      title='Seq-Well human', ax=axs[3])

fig.savefig(f"{os.getcwd()}/figures/comparison_figs/pdf/ding_heatmap_overlap_features.pdf", bbox_inches="tight")

# Enrichment of ribosomal and mitochondrial genes
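
The functions in this section count how many of the selected features are ribosomal (names starting with `RPS` or `RPL`) or mitochondrial (names starting with `MT-`). The sketch below illustrates the same prefix-based count on a hypothetical rank table; the helper `percent_prefix_genes`, the gene list and the method name are assumptions made for the example:

```python
import pandas as pd


def percent_prefix_genes(df_ranks, method, n_top, prefixes):
    """Percentage of the n_top best-ranked genes whose name starts with a prefix."""
    top = df_ranks[method].nsmallest(n_top).index
    hits = [gene for gene in top if gene.upper().startswith(prefixes)]
    return 100 * len(hits) / len(top)


# Hypothetical rank table with a single method; rank 0 is the best.
toy_ranks = pd.DataFrame(
    {"method_a": [0, 1, 2, 3, 4, 5]},
    index=["RPS3", "Rpl5", "MT-CO1", "ACTB", "GAPDH", "CD3E"],
)

# Among the top 4 genes there are two ribosomal and one mitochondrial gene.
pct_rbp = percent_prefix_genes(toy_ranks, "method_a", 4, ("RPS", "RPL"))
pct_mt = percent_prefix_genes(toy_ranks, "method_a", 4, ("MT-",))
print(pct_rbp, pct_mt)
```

Upper-casing the gene name before matching lets the same prefixes work for human (`RPS3`) and mouse-style (`Rpl5`) symbols.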

In [None]:
def barplot_mt_rbp(lab, org, method, n_features=[100, 250, 500, 1000], mode=0):   
    fig, axs = plt.subplots(2, 1, figsize=(10, 6))
    
    for n_feature_idx, n_feature in enumerate(n_features):
        for FS_idx, FS in enumerate(list_methods):
            df = pd.read_csv(os.getcwd() + f'/exports/comparisons/{lab}_{method}_{org}-log_feature_ranks.csv', index_col=0)

            set_rbp = set([i for i in df.index if (i.upper().startswith('RPS')) | (i.upper().startswith('RPL'))])
            set_mt = set([i for i in df.index if (i.upper().startswith('MT-'))])
            
            set_FS = set(df.sort_values(by=FS).index.tolist()[:n_feature])
            
            # Horizontal offset so that the bars of the different methods sit side by side
            x_pos = n_feature_idx + (FS_idx - len(list_methods) // 2) / (len(list_methods) + 3)

            if mode == 0:
                # Percentage relative to the selected features
                axs[0].bar(x_pos, 100 * len(set_rbp & set_FS) / len(set_FS), width=0.1, color=palette_all[FS_idx])
                axs[1].bar(x_pos, 100 * len(set_mt & set_FS) / len(set_FS), width=0.1, color=palette_all[FS_idx])
            else:
                # Percentage relative to all ribosomal / all mitochondrial genes
                axs[0].bar(x_pos, 100 * len(set_rbp & set_FS) / len(set_rbp), width=0.1, color=palette_all[FS_idx])
                axs[1].bar(x_pos, 100 * len(set_mt & set_FS) / len(set_mt), width=0.1, color=palette_all[FS_idx])
                
    for ax in axs:
        ax.set_xticks(range(len(n_features)))
        ax.set_xticklabels(n_features)
    
    if mode == 0:
        axs[0].set_ylabel('% ribosomal genes\n(from selected features)')
        axs[1].set_ylabel('% mitochondrial genes\n(from selected features)')
    else:
        axs[0].set_ylabel('% ribosomal genes\n(from all ribosomal genes)')
        axs[1].set_ylabel('% mitochondrial genes\n(from all mitochondrial genes)')
        
    legend_elements = [mpl.lines.Line2D([0], [0], marker="o", color=palette[0], label='triku')] + [
        mpl.lines.Line2D(
            [0], [0], marker="o", color=palette_all[j], label=list_methods[j]
        )
        for j in range(1, len(list_methods))
    ]
    axs[0].legend(handles=legend_elements, bbox_to_anchor=(1.2, 0.9))
    
    
def heatmap_mt_rbp(labs, orgs, methods, n_features=500):        
    dict_info = {}
    
    for lab in labs:
        for org in orgs:
            for method in methods:
                for FS_idx, FS in enumerate(list_methods):
                    if not os.path.exists(os.getcwd() + f'/exports/comparisons/{lab}_{method}_{org}-log_feature_ranks.csv'):
                        continue
                        
                    df = pd.read_csv(os.getcwd() + f'/exports/comparisons/{lab}_{method}_{org}-log_feature_ranks.csv', index_col=0)

                    set_rbp = set([i for i in df.index if (i.upper().startswith('RPS')) | (i.upper().startswith('RPL'))])
                    set_mt = set([i for i in df.index if (i.upper().startswith('MT-'))])

                    set_FS = set(df.sort_values(by=FS).index.tolist()[:n_features])
                    
                    for opt in [f'{FS}_per_rbp_all_features', f'{FS}_per_mt_all_features',]:
                        if opt not in dict_info:
                            dict_info[opt] = []
                    
                    dict_info[f'{FS}_per_rbp_all_features'].append(100 * len(set_rbp & set_FS) / len(set_FS))
                    dict_info[f'{FS}_per_mt_all_features'].append(100 * len(set_mt & set_FS) / len(set_FS))
    
    df = pd.DataFrame(index=list_methods, columns=[
        'Percentage RBPs in selected features', 'Percentage MTs in selected features', ])
    
    for FS_idx, FS in enumerate(list_methods):
        df.iloc[FS_idx, 0] = '%.3f' % np.nanmean(dict_info[f'{FS}_per_rbp_all_features'])
        df.iloc[FS_idx, 1] = '%.3f' % np.nanmean(dict_info[f'{FS}_per_mt_all_features'])
                                
                                     
    return df.astype(float)

#### Table 1

In [None]:
df = heatmap_mt_rbp(['mereu'], ['human', 'mouse'], ['SingleNuclei', 'Dropseq', 'inDrop', '10X', 'SMARTseq2', 
                                              'CELseq2', 'QUARTZseq', 'sci-RNAseq', 'Seq-Well'], n_features=2000)

display(df)

In [None]:
df = heatmap_mt_rbp(['ding'], ['human', 'mouse'], ['SingleNuclei', 'Dropseq', 'inDrop', '10X', 'SMARTseq2', 
                                              'CELseq2', 'QUARTZseq', 'sci-RNAseq', 'Seq-Well'], n_features=2000)

display(df)

# Gene ontology analysis

One way to compare the FS methods is to run Enrichr on the features that each of them selects. If the ontology terms obtained with one method have better p-values or scores, its features are likely to be more representative of the dataset.
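
As a rough illustration of this comparison, the sketch below summarizes each method by the mean $-\log_{10}$ adjusted p-value of its best terms. The tables are hypothetical stand-ins for Enrichr results (only the `Term` and `Adjusted P-value` columns are mimicked), and the helper `mean_neg_log10_padj` is an assumption made for the example, not part of gseapy:

```python
import numpy as np
import pandas as pd


def mean_neg_log10_padj(df_enrichr, n_terms):
    """Mean -log10(adjusted p-value) of the n_terms most significant terms."""
    top = df_enrichr["Adjusted P-value"].nsmallest(n_terms)
    return float(np.mean(-np.log10(top)))


# Hypothetical enrichment tables, one per FS method.
toy_results = {
    "method_a": pd.DataFrame(
        {"Term": ["term 1", "term 2", "term 3"], "Adjusted P-value": [1e-6, 1e-4, 0.5]}
    ),
    "method_b": pd.DataFrame(
        {"Term": ["term 1", "term 2", "term 3"], "Adjusted P-value": [1e-2, 1e-2, 0.9]}
    ),
}

# Higher scores mean more significant top terms.
scores = {name: mean_neg_log10_padj(df, n_terms=2) for name, df in toy_results.items()}
print(scores)
```

Here `method_a` would be preferred, because its two best terms are several orders of magnitude more significant than those of `method_b`.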


In [None]:
import gseapy

In [None]:
os.makedirs(os.getcwd() + f'/exports/enrichr/', exist_ok=True)

In [None]:
# list_onto_mouse = ['KEGG_2019_Mouse', 'WikiPathways_2019_Mouse', 'GO_Biological_Process_2018', 'GO_Cellular_Component_2018', 
#                    'GO_Molecular_Function_2018',]

# list_onto_human = ['KEGG_2019_Human', 'WikiPathways_2019_Human', 'GO_Biological_Process_2018', 'GO_Cellular_Component_2018', 
#                    'GO_Molecular_Function_2018', ]


list_onto_mouse = ['GO_Biological_Process_2018']
list_onto_human = ['GO_Biological_Process_2018']

In [None]:
@ray.remote
def call_enrichr(lab, org, method, n_features, FS):
    if os.path.exists(os.getcwd() + f'/exports/enrichr/{lab}_{method}_{org}_{n_features}_{FS}.csv'):
        print(os.getcwd() + f'/exports/enrichr/{lab}_{method}_{org}_{n_features}_{FS}.csv EXISTS!')
        return None
    
    if not os.path.exists(os.getcwd() + f'/exports/comparisons/{lab}_{method}_{org}-log_feature_ranks.csv'):
        return None
    
    df_file = pd.read_csv(os.getcwd() + f'/exports/comparisons/{lab}_{method}_{org}-log_feature_ranks.csv', index_col=0)
    
    list_genes = df_file.sort_values(by=FS).index.tolist()[:n_features]
    list_onto = list_onto_mouse if org == 'mouse' else list_onto_human
    
    # Retry the Enrichr query up to 5 times, since the request can fail intermittently
    for n_trials in range(5):
        try:
            result_df = gseapy.enrichr(list_genes, list_onto, cutoff=1, organism=org).results
            result_df.to_csv(os.getcwd() + f'/exports/enrichr/{lab}_{method}_{org}_{n_features}_{FS}.csv', index=None)
            break
        except Exception:
            print(f'TRIAL {n_trials}')
            
        

    

In [None]:
list_comb = list(product(*[['ding', 'mereu'], 
                           ['human', 'mouse'], 
                           ['SingleNuclei', 'Dropseq', 'inDrop', '10X', 'SMARTseq2', 'CELseq2', 'QUARTZseq', 'sci-RNAseq', 'Seq-Well'], 
                           [100, 250, 500, 1000, 1250, 1500], 
                           list_methods]))


ray.init(ignore_reinit_error=True)

list_id = [call_enrichr.remote(lab, org, method, n_features, FS) for lab, org, method, n_features, FS in list_comb]
list_results = ray.get(list_id)

ray.shutdown()

In [None]:
enrichr_figs_dir = os.getcwd() + '/figures/enrichr_figs/'
os.makedirs(enrichr_figs_dir, exist_ok=True)

#### Figure 7

In [None]:
def scatter_enrichr(lab, org, method, n_features, list_FS, palette, n_ontologies=30, column_sort='Adjusted P-value', plot_type='bar', 
                    list_onto=['KEGG_2019_Mouse', 'WikiPathways_2019_Mouse', 'KEGG_2019_Human', 'WikiPathways_2019_Human',
                               'GO_Biological_Process_2018', 'GO_Cellular_Component_2018', 'GO_Molecular_Function_2018',], save='', ):
    
    fig, ax = plt.subplots(1, 1, figsize=(10, 3))
    
    dict_dfs = {}
    
    for n_feature_idx, n_feature in enumerate(n_features):
        for FS_idx, FS in enumerate(list_FS):
            df = pd.read_csv(os.getcwd() + f'/exports/enrichr/{lab}_{method}_{org}_{n_feature}_{FS}.csv')
            df = df[df['Gene_set'].isin(list_onto)]
            
            if column_sort == 'Adjusted P-value':
                df = df.sort_values(by=column_sort).iloc[:n_ontologies]
                y_vals = df[column_sort].values
                y_vals = - np.log10(y_vals)

            elif column_sort == 'Combined Score':
                df = df.sort_values(by=column_sort, ascending=False).iloc[:n_ontologies]
                y_vals = df[column_sort].values
            
            elif column_sort == 'division':
                table_vals = df['Overlap'].values
                df['divided'] = [int(i.split('/')[0]) / int(i.split('/')[1]) for i in table_vals]
                df = df.sort_values(by='divided', ascending=False).iloc[:n_ontologies]
                y_vals = df['divided'].values
                
            
            x_pos = n_feature_idx + (FS_idx - len(list_FS) // 2) / (len(list_FS) + 3)
            
            if plot_type == 'bar':
                plt.bar(x_pos , np.mean(y_vals), 
                        width = 0.1, yerr=np.std(y_vals), color=palette[FS_idx])
            elif plot_type == 'scatter':
                plt.scatter([x_pos] * len(y_vals), y_vals, c=palette[FS_idx], alpha=0.8)
            
            dict_dfs[f'{n_feature}_{FS}'] = df
    
    legend_elements = [
        mpl.lines.Line2D(
            [0], [0], marker="o", color=palette[j], label=list_FS[j]
        )
        for j in range(len(list_FS))
    ]
    ax.legend(handles=legend_elements, bbox_to_anchor=(1.2, 0.9))
    ax.set_xticks(range(len(n_features)))
    ax.set_xticklabels(n_features)
    ax.set_ylabel(column_sort)
    ax.set_xlabel('Number of features')
    
    plt.tight_layout()
    
    if save:
        plt.savefig(save)
    
    return dict_dfs

def barplot_ontologies_individual(df, axis=None, color="#ababab", column='Adjusted P-value', ascending=False, log=True, y_text=''):
    if axis is None:
        fig, axis = plt.subplots(1, 1, figsize=(10, 7))
    
    vals = df.sort_values(by=column, ascending=ascending)[column].values
    names = [i.split(' (')[0] for i in df.sort_values(by=column, ascending=ascending)['Term'].values]
    names = [i[: 42] + '...' if len(i) > 42 else i for i in names]

    if log:
        vals = - np.log10(df.sort_values(by=column, ascending=ascending)[column].values)
    
    if column == 'Adjusted P-value':
        if log:
            axis.plot(-np.log10([0.05, 0.05]), [-1.5, len(names) + 0.5], c="#ababab", alpha=0.8, linewidth=3, zorder=0)
        else:
            axis.plot([0.05, 0.05], [-1.5, len(names) + 0.5], c="#ababab", alpha=0.8, linewidth=3, zorder=0)
        
    axis.barh(range(len(df)), vals, color=color, zorder=5, alpha=0.7)
    
    for y in range(len(df)):
        axis.text(0.05 * np.max(axis.get_xlim()), y - 0.2, names[y], zorder=10, fontsize=12)
        
    axis.set_yticks([])
    axis.spines['right'].set_visible(False)
    axis.spines['top'].set_visible(False)

    x_text = column if not log else column + ' (log)'
    axis.set_xlabel(x_text)
    axis.set_ylabel(y_text)
    
    return axis

def barplot_ontologies_all(dict_dfs, n_features=1000, list_FSs=['triku', 'std', 'scry', 'scanpy', 'm3drop', 'nbumi'], 
                           list_colors=["#E73F74", "#11A579","#3969AC", "#7F3C8D", "#80BA5A","#E68310"], figsize=(17, 17), save=''):
    
    mpl.rcParams.update({'font.size':17})
    fig, axis = plt.subplots(3, 3, figsize=figsize)
    
    for i in range(len(list_FSs)):
        barplot_ontologies_individual(dict_dfs[f'{n_features}_{list_FSs[i]}'], axis=axis.ravel()[i], 
                                      color=list_colors[i], column='Adjusted P-value', ascending=False, log=True, y_text=list_FSs[i])
    
    plt.tight_layout()
    
    if save:
        plt.savefig(save)
        
    mpl.rcParams.update(mpl.rcParamsDefault)

In [None]:
lab, org, method, n_features = 'ding', 'human', 'Dropseq', [100, 250, 500, 1000, 1250, 1500]
list_dfs_ding_human_dropseq = []
for x in ['Adjusted P-value']:  # ['Combined Score', 'Adjusted P-value', 'division']:
    dict_df = scatter_enrichr(lab, org, method, n_features, list_FS = list_methods, palette=palette, n_ontologies=25, column_sort=x, plot_type='scatter', 
                    list_onto=[ 'GO_Biological_Process_2018',], save=enrichr_figs_dir + f'scatter_{lab}_{org}_{method}_{x}.pdf')
    list_dfs_ding_human_dropseq.append(dict_df)
    
barplot_ontologies_all(list_dfs_ding_human_dropseq[0], save=enrichr_figs_dir + f'barplots_{lab}_{org}_{method}_{x}.pdf', 
                      list_FSs=list_methods, list_colors=palette)

#### Figure S6

In [None]:
lab, org, method, n_features = 'ding', 'human', '10X', [100, 250, 500, 1000, 1250, 1500]
list_dfs_ding_human_dropseq = []
for x in ['Adjusted P-value']:  # ['Combined Score', 'Adjusted P-value', 'division']:
    dict_df = scatter_enrichr(lab, org, method, n_features, list_FS = list_methods, palette=palette, n_ontologies=25, column_sort=x, plot_type='scatter', 
                    list_onto=[ 'GO_Biological_Process_2018',], save=enrichr_figs_dir + f'scatter_{lab}_{org}_{method}_{x}.pdf')
    list_dfs_ding_human_dropseq.append(dict_df)
    
barplot_ontologies_all(list_dfs_ding_human_dropseq[0], save=enrichr_figs_dir + f'barplots_{lab}_{org}_{method}_{x}.pdf', 
                      list_FSs=list_methods, list_colors=palette)

# Distribution of expression per cluster
In this section we study how the expression of the HVG is distributed across clusters. We expect the features selected by better FS methods to be specific to fewer clusters, whereas *worse* features are expressed across more clusters. To study this effect, we normalize the expression of each gene so that it sums to one, and compute what percentage of the total expression falls within the most expressed cluster, the 2nd most expressed cluster, etc. We observe that HVG selected by triku tend to be expressed in fewer clusters than those selected by other FS methods.
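
The per-gene computation described above can be sketched with plain NumPy. The expression vectors and cluster labels below are hypothetical, and `expression_share_by_cluster` is an illustrative helper rather than the function used later in the notebook:

```python
import numpy as np


def expression_share_by_cluster(expression, clusters):
    """Fraction of a gene's total expression in each cluster, sorted from highest to lowest."""
    expression = np.asarray(expression, dtype=float)
    clusters = np.asarray(clusters)
    shares = np.array([expression[clusters == c].sum() for c in np.unique(clusters)])
    shares = shares / expression.sum()
    return np.sort(shares)[::-1]


# Six cells assigned to three clusters.
clusters = ["0", "0", "1", "1", "2", "2"]

# A marker-like gene concentrates its expression in one cluster...
marker_like = expression_share_by_cluster([8, 8, 1, 1, 1, 1], clusters)
# ...while a broadly expressed gene is spread evenly.
broad = expression_share_by_cluster([3, 3, 3, 3, 3, 3], clusters)

print(marker_like)
print(broad)
```

Averaging these sorted profiles over all genes selected by a method gives one curve per method: the steeper the curve, the more cluster-specific the selected features.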

In [None]:
mpl.rc('font', **{'size': 13})

In [None]:
def return_mean_per(matrix):
    # Returns the mean counts per gene, and the proportion of zeros
    n_reads_per_gene = matrix.sum(0).astype(int)
    n_zeros = (matrix == 0).sum(0)

    return n_reads_per_gene/matrix.shape[0], n_zeros/matrix.shape[0]

In [None]:
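def get_norm_exp_cluster(adata, gene):
    # Expression of the gene across cells, normalized to sum to one
    expression_vals = adata[:, gene].X.ravel()
    expression_vals = expression_vals / np.sum(expression_vals)
    # Fraction of the expression in each leiden cluster, sorted from highest to lowest
    exp_by_cluster = sorted([sum(expression_vals[adata.obs['leiden'] == str(i)]) for i 
                      in range(np.max(adata.obs['leiden'].astype(int)) + 1)])[::-1]
    return exp_by_cluster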

#### Figure S5

In [None]:
fig, axs = plt.subplots(4, 4, figsize=(12, 12))
combos = product(*[['mereu', 'ding'], ['human', 'mouse']])
datasets = ['10X', 'SMARTseq2', 'CELseq2', 'inDrop']

cutoff, res = 1750, 1.2

for col_idx, combo in enumerate(list(combos)):
    for row_idx, dataset in enumerate(list(datasets)):
        print('|||', combo, dataset, '|||')
        try:
            adata = sc.read(os.getcwd() + f'/data/{combo[0].capitalize()}_2020/{dataset}_{combo[1]}.h5ad')
            df_ranks = pd.read_csv(os.getcwd() + f'/exports/comparisons/{combo[0]}_{dataset}_{combo[1]}-log_feature_ranks.csv', index_col=0)
        except:
            print(f'Combo {combo} {dataset} does not exist!')
            axs[row_idx, col_idx].axis('off')
            continue
        
        label_pre = '' if col_idx == 0 and row_idx == 0 else '_'  # This is to make the legend only appear for the first panel
        
        
        combined_names = np.intersect1d(df_ranks.index.values, adata.var_names)
        adata = adata[:, combined_names]
        df_ranks = df_ranks.loc[combined_names]

        sc.pp.log1p(adata)
        sc.pp.pca(adata, random_state=seed)
        # n_neighbors sets the number of neighbors; knn is a boolean flag in scanpy
        sc.pp.neighbors(adata, random_state=seed, n_neighbors=int(0.5 * (len(adata) ** 0.5)), metric='cosine')
        sc.tl.umap(adata, random_state=seed)
        sc.tl.leiden(adata, resolution=res, random_state=seed)

        for col_rest_idx, col_rest in enumerate(list_methods_all):
            list_mean_exp, list_p_zeros = return_mean_per(adata.X)
            list_genes = adata.var_names

            if col_rest not in ['all', 'random']:
                list_genes = df_ranks[df_ranks[col_rest] < cutoff].index
            elif col_rest == 'all':
                list_genes = df_ranks.index
            elif col_rest == 'random':
                list_genes = np.random.choice(df_ranks.index, cutoff) 

            list_clust = []
            for gene in list_genes:
                exp_clust = get_norm_exp_cluster(adata, gene)
                list_clust.append(exp_clust)

            arr = np.array(list_clust)

            axs[row_idx, col_idx].plot(np.arange(len(exp_clust)), 100 * np.mean(arr, 0), color=palette_all[col_rest_idx], 
                                       alpha=1, label=col_rest, zorder=len(list_methods_all) - col_rest_idx)
            
        
        axs[row_idx, col_idx].set_xlim([0, min(len(exp_clust), 8)])
        axs[row_idx, col_idx].set_xticks(np.arange(min(len(exp_clust), 8) + 1))
        axs[row_idx, col_idx].set_xticklabels(np.arange(min(len(exp_clust), 8) + 1))
        
        if row_idx == 0:
            axs[row_idx, col_idx].set_title(' '.join(combo).capitalize())
        if col_idx == 0:
            axs[row_idx, col_idx].set_ylabel(dataset)
            
            
# add a big axis, hide frame
fig.add_subplot(111, frameon=False)
# hide tick and tick label of the big axis
plt.tick_params(labelcolor='none', which='both', top=False, bottom=False, left=False, right=False)
plt.xlabel("Cluster (most to least expressed)")

plt.ylabel("% of expression")
plt.gca().yaxis.set_label_position("right")

axs[0, 0].legend(ncol=5, frameon=False, bbox_to_anchor=(4, -3.9), labelspacing=0.3)

plt.tight_layout()
plt.savefig(os.getcwd() + f'/figures/comparison_figs/comparison_clusters_benchmarking.pdf')