In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from itertools import product
from ktest.utils_matplotlib import *
from ktest.tester import Ktest
from ktest.utils_courtine import *
import matplotlib as mpl


In [None]:
def colorFader(c1,c2,mix=0):
    """Linearly blend two matplotlib colours.

    Parameters
    ----------
    c1, c2 : matplotlib colour specification
        Start and end colours (names such as ``'xkcd:tan'``, hex strings, ...).
    mix : float
        Blend weight: 0 gives ``c1``, 1 gives ``c2``.

    Returns
    -------
    str
        The blended colour as a hex string.
    """
    start, end = (np.array(mpl.colors.to_rgb(spec)) for spec in (c1, c2))
    blended = start * (1 - mix) + end * mix
    return mpl.colors.to_hex(blended)

def center_in_input_space(self,center_by=None,center_non_zero_only=False):
    """Shift the expression data of ``self`` block by block so that block means match a target.

    Parameters
    ----------
    self : object
        Object exposing ``get_data`` / ``get_metadata`` as called below
        (presumably a ``ktest`` tester instance -- TODO confirm).
    center_by : {'replicate', '#-replicate_+label'}
        * ``'replicate'``: every replicate block is shifted so that its
          column means equal the column means of all data pooled.
        * ``'#-replicate_+label'``: every replicate block is shifted so that
          its column means equal those of the label containing the replicate.
    center_non_zero_only : bool
        If True, means are computed over non-zero entries only and only the
        non-zero entries are shifted (exact zeros stay zero).

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        The concatenated centred data, and the rows of the label metadata
        matching its index, in the same order.

    Raises
    ------
    ValueError
        If ``center_by`` is not one of the supported values.
    KeyError
        If a row of the centred data is absent from the metadata.
    """
    def _mean(block):
        # Column-wise mean, optionally ignoring exact zeros.
        return block[block != 0].mean() if center_non_zero_only else block.mean()

    def _shift(block, old_mean, new_mean):
        # Returns a new object: get_data may hand back frames owned by `self`,
        # so they must not be modified in place.
        shifted = block - old_mean + new_mean
        return block.mask(block != 0, shifted) if center_non_zero_only else shifted

    if center_by == 'replicate':
        data_r = self.get_data(condition='replicate',dataframe=True,in_dict=True)
        all_data_ = self.get_data(condition='replicate',dataframe=True,in_dict=False)
        dfc = _mean(all_data_)
        data_ = {k: _shift(v, _mean(v), dfc) for k, v in data_r.items()}
    elif center_by == '#-replicate_+label':
        data_r = self.get_data(condition='replicate',dataframe=True)
        data_l = self.get_data(condition='label',dataframe=True)
        mean_l = {k: _mean(v) for k, v in data_l.items()}
        mean_r = {k: _mean(v) for k, v in data_r.items()}
        metadata_l = self.get_metadata(condition='label',in_dict=True)
        # set() accepts whatever .unique() returns (a numpy ndarray has no .to_list()).
        r_in_l = {k: set(v['replicate'].unique()) for k, v in metadata_l.items()}
        data_ = {}
        for l, ml in mean_l.items():
            for r, dr in data_r.items():
                if r in r_in_l[l]:
                    # If a replicate appears under several labels, the last label wins.
                    data_[r] = _shift(dr, mean_r[r], ml)
    else:
        raise ValueError(
            f"center_by must be 'replicate' or '#-replicate_+label', got {center_by!r}")
    new_df = pd.concat(data_.values())
    new_meta = self.get_metadata(condition='label',in_dict=False)
    positions = new_meta.index.get_indexer(new_df.index)
    if (positions == -1).any():
        # get_indexer marks unknown labels with -1; iloc[-1] would silently pick the last row.
        raise KeyError('some rows of the centered data are missing from the metadata')
    new_meta = new_meta.iloc[positions].copy()
    return(new_df,new_meta)



def load_pvals_of_DE_method(dataset,de):
    """Return the ``p_val_adj`` column of the result file of DE method ``de`` on ``dataset``.

    The file read is ``/home/data/squair/sc_results/{dataset}-de_test={de}.csv``.
    Its row index is taken from column 7 when ``de`` contains 'DESeq2',
    and from column 1 otherwise.
    """
    index_position = 7 if 'DESeq2' in de else 1
    csv_path = '/home/data/squair/sc_results/' + f'{dataset}-de_test={de}.csv'
    return pd.read_csv(csv_path, index_col=index_position)['p_val_adj']

def load_pvals_of_bulk_DE_method(dataset,de):
    """Return the p-value column produced by a bulk DE method on ``dataset``.

    Reyfman2020 datasets are read from a single tab-separated file
    ``{dataset}_results.tsv.gz`` (``de`` is ignored there). Every other
    dataset is read from ``{dataset}-de_test={de}.rds`` with ``pd.read_csv``.
    The column names of that table are printed before the p-value column is selected.
    """
    root = '/home/data/squair/bulk_results/'
    if 'Reyfman2020' in dataset:
        return pd.read_csv(root + f'{dataset}_results.tsv.gz', sep='\t')['padj']

    if 'limma' in de:
        column = 'adj.P.Val'
    elif 'DESeq' in de:
        column = 'padj'
    else:
        # NOTE(review): 'PValue' looks like an unadjusted p-value, while the
        # other branches read adjusted ones -- confirm this is intended for edgeR.
        column = 'PValue'
    index_position = 6 if 'DESeq2' in de else 0
    # NOTE(review): the .rds extension is parsed as CSV here; presumably these
    # files were exported as text -- confirm.
    table = pd.read_csv(root + f'{dataset}-de_test={de}.rds', index_col=index_position)
    print(table.columns)
    return table[column]

def load_pvals_of_KFDA_method(dataset,de,path_courtine='/home/data/squair/ktest_results/'):
    """Load the BH-adjusted p-values and KFDA statistics of a ktest run.

    The result folder and file name are derived from ``dataset`` and from tags
    contained in ``de``. The tags are: truncation ``_t<k>``, and the kernel flags
    'kfzig', 'kgq20', 'kgq80' and 'klin'.

    Parameters
    ----------
    dataset : str
        Dataset name, e.g. 'Hagai2018_mouse-lps'.
    de : str
        Method identifier; its text after the last '_t' is the truncation.
    path_courtine : str
        Root directory of the ktest results (must end with a separator).

    Returns
    -------
    pandas.DataFrame or None
        Two columns renamed to ``de`` (p-values) and ``f'{de}_kfda'``.
        Returns None, after printing a message, when the result file
        (or its folder) does not exist.
    """
    t = de.split(sep='_t')[-1]
    model = 'standard_sn_' if dataset == 'Angelidis2019_alvmac' else 'ny_lmrandom_m499_basisw_sn_'
    if 'kgq20' in de:
        model = model.replace('_sn','_kgq0.2sn')
    if 'kgq80' in de:
        model = model.replace('_sn','_kgq0.8sn')
    if 'klin' in de:
        model = model.replace('_sn','_lin_sn')

    angelidis_or_reyfman = 'Angelidis' in dataset or 'Reyfman' in dataset
    folder = 'cbrlis_cbnz_' if angelidis_or_reyfman else 'cbris_cbnz_'
    spec = 'crlis_cnz_' if angelidis_or_reyfman else 'cris_cnz_'
    for kernel in ['kfzig','kgq20','kgq80','klin']:
        if kernel in de:
            folder += f'{kernel}_'
    path = f'{path_courtine}Courtine_KFDA_{folder}results_adj/'
    file = f'DEA_{model}_{dataset}{spec}label_univariate.csv'
    col = f'DEA_{model}_{dataset}{spec}label_t{t}'
    cols = [f'{col}_pvalBH',f'{col}_kfda']
    full_path = f'{path}{file}'
    # isfile (unlike os.listdir) does not raise when the folder itself is missing.
    if not os.path.isfile(full_path):
        print(f' not found : {file} \n in {path} ')
        return None
    df = pd.read_csv(full_path,index_col=0)
    df = df[cols]
    df.columns = [de,f'{de}_kfda']
    return(df)

def names(de):
    """Map an internal DE-method identifier to its display label.

    The substitutions run in the listed order, and each one sees the output
    of the previous ones. For example, 'bulk_' -> 'bulk-' runs before
    'pseudobulk-' is stripped. The substrings 't4' and 't1' are removed
    wherever they occur.
    """
    substitutions = (
        ('sn_', ''),
        ('kfzig_', '-ZI-kernel'),
        ('kgq', '-G-kernel-bw=q'),
        ('q20_', 'q20'),
        ('q80_', 'q80'),
        ('klin_', '-linear-kernel'),
        ('cbis_', ''),
        ('cbisreplicate_', ''),
        ('kfda_', 'ktest'),
        (',mode?', '-'),
        (',test?', '-'),
        ('bulk_', 'bulk-'),
        ('t4', ''),
        ('t1', ''),
        ('pseudobulk-', ''),
    )
    label = de
    for old, new in substitutions:
        label = label.replace(old, new)
    return label

def filter_df(df,column,min_value=None,max_value=None):
    """Return a filtered copy of ``df`` whose ``column`` lies strictly between the given bounds.

    A bound of None is ignored. The comparisons are strict, so rows equal to a
    bound are dropped, and so are rows whose value is NaN once any bound is set.
    The column name is printed.
    """
    print(column)
    kept = df.copy()
    if min_value is not None:
        kept = kept.loc[kept[column] > min_value]
    if max_value is not None:
        kept = kept.loc[kept[column] < max_value]
    return kept

def get_accepted_by(df,de_methods,threshold=.05):
    """Keep the rows whose value is above ``threshold`` in every column of ``de_methods``.

    Each column is filtered in turn with ``filter_df(..., min_value=threshold)``.
    NOTE(review): because that bound is strict, a row exactly equal to
    ``threshold`` is kept neither here nor by ``get_rejected_by``.
    """
    accepted = df.copy()
    for method in de_methods:
        accepted = filter_df(accepted, method, min_value=threshold)
    return accepted
def get_rejected_by(df,de_methods,threshold=.05):
    """Keep the rows whose value is below ``threshold`` in every column of ``de_methods``.

    Each column is filtered in turn with ``filter_df(..., max_value=threshold)``.
    """
    rejected = df.copy()
    for method in de_methods:
        rejected = filter_df(rejected, method, max_value=threshold)
    return rejected


def get_pvals_of(dataset,de,ntop=None):
    """Return the column ``de`` of the global table ``dfs[dataset]``.

    If ``ntop`` is given, only the ``ntop`` smallest values are returned,
    sorted ascending. A message is printed and None is returned when the
    dataset or the column is missing.
    """
    if dataset not in dfs:
        print(f'dataset {dataset} not in dfs')
        return None
    table = dfs[dataset]
    if de not in table:
        print(f"DE method {de} not in dfs['{dataset}']")
        return None
    pvals = table[de]
    if ntop is None:
        return pvals
    return pvals.sort_values().iloc[:ntop]

def points_in_boxplot(df,ax,colors=None,vert=False):
    """Overlay the values of each column of ``df`` on ``ax`` as jittered points.

    Column ``i`` (0-based) is drawn around position ``i + 1`` with jitter
    drawn from ``np.random.normal(i + 1, 0.04)``. This draws from NumPy's
    global random state.

    Parameters
    ----------
    df : pandas.DataFrame
        One group of points per column.
    ax : object with a matplotlib-like ``scatter`` method
    colors : dict or None
        Its values are paired with the columns of ``df`` by position, so it is
        expected to be ordered like ``df.columns``. ``zip`` stops at the shorter
        of the two.
    vert : bool
        If True, groups are placed along x and values along y; otherwise the reverse.
    """
    n_rows = len(df)
    values, positions = [], []
    for i, column in enumerate(df.columns):
        values.append(df[column])
        positions.append(np.random.normal(i + 1, 0.04, n_rows))
    palette = list(colors.values()) if colors is not None else [None] * len(positions)
    for pos, val, color in zip(positions, values, palette):
        if vert:
            ax.scatter(pos, val, c=color, alpha=1)
        else:
            ax.scatter(val, pos, c=color, alpha=1)
def filled_boxplot(df,ax,colors=None,alpha=.5,vert=False):
    """Draw ``df.boxplot`` on ``ax`` with translucent filled boxes and no outline.

    When ``colors`` is given, box ``i`` gets the ``i``-th value of ``colors``
    as both face and edge colour.
    """
    artists = df.boxplot(ax=ax,return_type='both',patch_artist=True,vert=vert)[1]
    palette = None if colors is None else list(colors.values())
    for idx, box in enumerate(artists['boxes']):
        if palette is not None:
            box.set(facecolor=palette[idx], edgecolor=palette[idx])
        box.set(alpha=alpha, fill=True, linewidth=0)

def contours_boxplot(df,ax,colors=None,lw=3,vert=False):
    """Draw ``df.boxplot`` on ``ax`` with unfilled boxes outlined at width ``lw``.

    When ``colors`` is given, box ``i`` and its caps and whiskers take the
    ``i``-th value of ``colors``. The ``idx // 2`` indexing assumes two caps
    and two whiskers per box, listed in box order.
    """
    artists = df.boxplot(ax=ax,return_type='both',patch_artist=True,vert=vert)[1]
    palette = None if colors is None else list(colors.values())
    for idx, box in enumerate(artists['boxes']):
        if palette is not None:
            box.set(edgecolor=palette[idx])
        box.set(alpha=1, fill=False, linewidth=lw)
    if palette is None:
        return
    for part in ('caps', 'whiskers'):
        for idx, line in enumerate(artists[part]):
            line.set(color=palette[idx // 2], linewidth=lw)
        
def custom_boxplot(df,colors=None,alpha=.5,lw=3,scatter=True,fig=None,ax=None,vert=True):
    """Draw filled boxes, coloured outlines and optionally the raw points of each column of ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        One box per column.
    colors : dict or None
        ``column -> colour`` table. When given, only the columns of ``df``
        present in ``colors`` are plotted, in the key order of ``colors``.
    alpha : float
        Transparency of the filled boxes.
    lw : float
        Line width of box outlines, caps and whiskers.
    scatter : bool
        Overlay the individual values as jittered points.
    fig, ax : matplotlib Figure / Axes or None
        Where to draw. With neither, a new 20x7 figure is created. With only
        ``fig``, its current axes are used. With only ``ax``, its figure is used.
    vert : bool
        Vertical boxes if True, horizontal otherwise.

    Returns
    -------
    (Figure, Axes)
    """
    # Previously a given `ax` was ignored when `fig` was None, and `ax=None`
    # was returned when only `fig` was given.
    if ax is None:
        if fig is None:
            fig,ax = plt.subplots(figsize=(20,7))
        else:
            ax = fig.gca()
    elif fig is None:
        fig = ax.get_figure()

    if colors is not None:
        # The boxes follow the key order of `colors`; uncoloured columns are dropped.
        colors = {s:c for s,c in colors.items() if s in df}
        df = df[list(colors.keys())]

    filled_boxplot(df=df,colors=colors,alpha=alpha,ax=ax,vert=vert)
    contours_boxplot(df=df,colors=colors,lw=lw,ax=ax,vert=vert)
    if scatter:
        points_in_boxplot(df=df,colors=colors,ax=ax,vert=vert)
    return(fig,ax)

    
def boxplot_of_top_de_genes(de_methods,metric="mean",ntop=200,fig=None,ax=None,colors=None,order=True):
    """Boxplot, per DE method, of the average ``metric`` of its ``ntop`` top genes across datasets.

    For each method and each dataset of the global ``datasets`` whose table in
    the global ``dfs`` has that method, the ``ntop`` genes with the smallest
    p-values are selected. The mean of their ``metric`` column gives one value
    per (method, dataset).

    If ``order`` is True, methods are sorted by decreasing mean over datasets.
    NOTE(review): when ``colors`` is given, ``custom_boxplot`` reorders the
    boxes by the key order of ``colors``, so ``order`` then has no visible effect.
    Boxes are labelled with ``names(method)``. Returns ``(fig, ax)``.
    """
    summary = {}
    for method in de_methods:
        per_dataset = {}
        for ds in datasets:
            if method not in dfs[ds]:
                continue
            genes = get_pvals_of(ds, method, ntop=ntop).index
            per_dataset[ds] = dfs[ds][metric].loc[genes].mean()
        summary[method] = per_dataset

    table = pd.DataFrame(summary)
    if order:
        de_methods = table.mean().sort_values(ascending=False).index
    for column in table.columns:
        table[names(column)] = table[column]
    labels = [names(method) for method in de_methods]
    return custom_boxplot(table[labels],vert=False,colors=colors,fig=fig,ax=ax)



In [3]:
# Global registry: dataset name -> per-gene results table (filled in a later cell).
dfs = {}
datasets = [
    'Reyfman2020_pneumo', 'Reyfman2020_alvmac',
    'Angelidis2019_alvmac', 'Angelidis2019_pneumo',
    'Hagai2018_mouse-pic', 'Hagai2018_rat-lps',
    'Hagai2018_rabbit-lps', 'Hagai2018_rat-pic',
    'Hagai2018_mouse-lps', 'Hagai2018_pig-lps',
    'CanoGamez2020_Memory-Th0', 'CanoGamez2020_Naive-Th2', 'CanoGamez2020_Naive-iTreg',
    'CanoGamez2020_Naive-Th0', 'CanoGamez2020_Naive-Th17',
    'CanoGamez2020_Memory-iTreg', 'CanoGamez2020_Memory-Th2',
    'CanoGamez2020_Memory-Th17',
]
nd = len(datasets)

# DE method identifiers. Each list is defined exactly once; these lists were
# previously re-declared, identically, at the end of this cell.
sc = ['MAST','wilcox', 'poisson', 'bimod', 'negbinom','LR','t']
bulk = ['bulk_limma,mode?voom', 'bulk_limma,mode?trend', 'bulk_DESeq2,test?Wald', 'bulk_edgeR,test?QLF', 'bulk_edgeR,test?LRT', 'bulk_DESeq2,test?LRT']
pseudobulk_no_batch = ['pseudobulk_limma,mode?voom,replicate?cells', 'pseudobulk_limma,mode?trend,replicate?cells', 'pseudobulk_DESeq2,test?Wald,replicate?cells', 'pseudobulk_edgeR,test?QLF,replicate?cells', 'pseudobulk_edgeR,test?LRT,replicate?cells', 'pseudobulk_DESeq2,test?LRT,replicate?cells']
pseudobulk = ['pseudobulk_limma,mode?voom', 'pseudobulk_limma,mode?trend', 'pseudobulk_DESeq2,test?Wald', 'pseudobulk_edgeR,test?QLF', 'pseudobulk_edgeR,test?LRT', 'pseudobulk_DESeq2,test?LRT']
courtine = sc+pseudobulk

kfdas = ['kfda_cbis_t4', 'kfda_cbis_kfzig_t4', 'kfda_cbis_klin_sn_t1']
kfda4 = ['kfda_cbisreplicate_sn_t4', 'kfda_cbisreplicate_kfzig_sn_t4', 'kfda_cbisreplicate_klin_sn_t1']

groups = {"AUCC based hierarchical clustering":bulk+pseudobulk+sc+kfda4,}
de_boxplot = kfdas+courtine
nm = len(de_boxplot)


In [None]:
# Colour (and line-width) tables keyed by dataset / DE method.
colors= {}
lws = {}

# One hue per study; each dataset name contains exactly one study name.
for dataset in datasets:
    if 'Hagai' in dataset:
        colors[dataset] = 'xkcd:periwinkle'
    elif 'Angelidis' in dataset:
        colors[dataset] = 'xkcd:azure'
    elif 'CanoGamez' in dataset:
        colors[dataset] = 'xkcd:light orange'
    elif 'Reyfman' in dataset:
        colors[dataset] = 'xkcd:brownish'

for de in kfdas:
    # Truncation parameter: the text after the last '_t' of the identifier.
    t = int(de.split(sep='_t')[-1])
    lws[de] = 3
    if 'cbis_kfzig' in de:
        colors[de] = colorFader('xkcd:pale pink','xkcd:bright pink',t/7)
    elif 'cbis_klin' in de:
        colors[de] = 'xkcd:tan'
    elif 'cbis_kgq' in de:
        q = int(de.split(sep='kgq')[1].split(sep="_")[0])
        colors[de] = colorFader('xkcd:faded green','xkcd:emerald',q/100)
    elif 'cbis' in de:
        # Fixed mid-tone of the same green scale used for the 'kgq' methods
        # (an earlier t-dependent colour here was immediately overwritten).
        colors[de] = colorFader('xkcd:faded green','xkcd:emerald',50/100)
    elif 'cbfs' in de:
        colors[de] = colorFader('xkcd:pale purple','xkcd:royal purple',t/7)
    else:
        colors[de] = colorFader('xkcd:light yellow','xkcd:burnt sienna',t/7)

for i,de in enumerate(sc):
    colors[de] = colorFader('xkcd:baby blue','xkcd:royal blue',i/len(sc))
    lws[de] = .5
for i,de in enumerate(pseudobulk):
    colors[de] = colorFader('xkcd:peach','xkcd:brick red',i/len(pseudobulk))
    lws[de] = .5
for i,de in enumerate(bulk):
    colors[de] = colorFader('xkcd:tomato','xkcd:lipstick red',i/len(bulk))
    lws[de] = .5

# Re-key by display label so plots can look colours up from tick texts.
colors = {names(de):color for de,color in colors.items()}
colors

In [None]:
path_var = '/home/data/squair/ktest_results/multivariate/'
# Sorted listing: when several files match a dataset, the one kept (the last
# match) no longer depends on the arbitrary order returned by os.listdir.
var_files = sorted(os.listdir(path_var))
for dataset in datasets:
    if dataset not in dfs:
        for file in var_files:
            if dataset in file and 'var.csv' in file and 'cb_replicate' not in file:
                print(file)
                dfs[dataset] = pd.read_csv(path_var+file,index_col=0)
missing = [name for name in datasets if name not in dfs]
if missing:
    # Fail early with an explicit message rather than a bare KeyError below.
    raise FileNotFoundError(f'no var.csv file found in {path_var} for: {missing}')

# Add the p-values of the other DE methods as extra columns.
for dataset in datasets:
    for de in courtine:
        if de not in dfs[dataset]:
            df = load_pvals_of_DE_method(dataset=dataset,de=de)
            if df is not None:
                dfs[dataset][de] = df.copy()
# Add the ktest (KFDA) results; the outer concat keeps genes missing on either side.
for dataset in datasets:
    for de in kfdas:
        if de not in dfs[dataset]:
            print(de)
            df = load_pvals_of_KFDA_method(dataset=dataset,de=de)
            if df is not None:
                dfs[dataset] = pd.concat([dfs[dataset],df],axis=1)
# Keep rows with 'ne' >= 3 (rows where 'ne' is NaN are dropped as well).
for dataset in datasets:
    print(dataset,end=' ')
    df = dfs[dataset]
    df = df[df['ne']>=3]
    dfs[dataset] = df
# Report datasets that still have rows with a missing 'nz' value.
for dataset in datasets:
    df = dfs[dataset]
    df = df[df['nz'].isna()]
    if len(df)>0:
        print('\n\n',dataset,len(df))
    else:
        print('####',dataset)

In [None]:
# List the DE methods that appear in the pairwise AUCC files
# (file names contain 'de1_<method>_de2_<method>_k<...>').
path = '/home/data/squair/AUCC/couples/'
de_methods = []
for file in os.listdir(path):
    if file == 'old_sn':
        continue
    de1 = file.split(sep='de1_')[1].split(sep='_de2')[0]
    de2 = file.split(sep='de2_')[1].split(sep='_k')[0]
    de_methods += [de1,de2]
# Sorted so the printed list is reproducible (set order is arbitrary).
de_methods = sorted(set(de_methods))
print(len(de_methods),de_methods)



In [None]:

# Main figure: AUCC dendrogram of the DE methods (top) and, below, boxplots of
# the mean expression / proportion of zeros of each method's top genes.
fig = plt.figure(figsize=(18,18), layout="constrained")
spec = fig.add_gridspec(2, 2)
ax0 = fig.add_subplot(spec[0, :])
ax1 = fig.add_subplot(spec[1, 0])
ax2 = fig.add_subplot(spec[1, 1])

# List the AUCC directory once instead of once per (de1, de2) pair.
aucc_files = os.listdir(path)
for group_name,de_methods in groups.items():
    # m[i, j]: AUCC between methods i and j at k=500, symmetrised with max;
    # pairs without a file (possibly including the diagonal -- TODO confirm) stay 0.
    m = np.zeros([len(de_methods),len(de_methods)],np.float64)
    for i,de1 in enumerate(de_methods):
        for j,de2 in enumerate(de_methods):
            for file in aucc_files:
                if f'de1_{de1}_de2_{de2}_k500' in file:
                    df = pd.read_csv(path+file,index_col=0)
                    aucc = df.mean(axis=1).values[0]
                    m[i,j] = np.max([aucc,m[j,i]])
                    m[j,i] = np.max([aucc,m[i,j]])
    ax = ax0
    d,fig,ax = plot_dendrogram_from_distance_matrix(1-m,de_methods,fig=fig,ax=ax)
    xticklabels = [names(tick.get_text()) for tick in ax.get_xticklabels()]
    ax.set_xticklabels(xticklabels,fontsize=25,rotation=90)
    for tick in ax.get_xticklabels():
        tick.set_color(colors[names(tick.get_text())])
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right', rotation_mode='anchor')
    # The axis shows distances (1 - AUCC); relabel the ticks as AUCC values.
    # ':g' avoids float noise such as 0.19999999999999996, and the unicode
    # minus used by matplotlib is normalised before parsing.
    distances = [float(tick.get_text().replace('\u2212', '-')) for tick in ax.get_yticklabels()]
    ax.set_yticklabels([f'{1 - dist:g}' for dist in distances],fontsize=20)
    for child in ax.collections:
        child.set_edgecolors('black')

ntops = [500]
for de_methods in [de_boxplot,]:
    for i,metric in enumerate(['mean','pz']):
        for ntop in ntops:
            ax = ax1 if metric == 'mean' else ax2
            boxplot_of_top_de_genes(de_methods,colors=colors,fig=fig,ax=ax,ntop=ntop,metric=metric)
            title = 'average gene expression' if metric == 'mean' else 'proportion of zeros'
            ax.set_xlabel(title,fontsize=25)
            xticklabels = [tick.get_text() for tick in ax.get_xticklabels()]
            ax.set_xticklabels(xticklabels,fontsize=20)
            if i == 1:
                # NOTE(review): hiding these labels assumes both panels list the
                # methods in the same order -- confirm.
                ax.set_yticklabels([])
            else:
                labels = [tick.get_text() for tick in ax.get_yticklabels()]
                ytickcolors = {names(label):colors[label] for label in labels}
                ax.set_yticklabels([names(label) for label in labels],fontsize=25)
                for tick in ax.get_yticklabels():
                    tick.set_color(ytickcolors[tick.get_text()])
# Save before showing, so the file is written independently of how the backend handles show().
fig.savefig('/home/figures/squair/main_squair.pdf',bbox_inches='tight')
plt.show()