# Notebook for creating the figures of the signature length experiment
In the experiment on the robustness of scoring methods to a limited number of signature genes, we computed the AUCROC and AUCPR between the scores and the true malignancy labels. This notebook gathers those results and creates the result heatmaps for the CRC, ESCC and LUAD datasets.

In [None]:
from IPython.display import display, HTML
# Widen the notebook container to 95% of the browser window.
_wide_container_css = "<style>.container { width:95% !important; }</style>"
display(HTML(_wide_container_css))

In [None]:
import glob
import os
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
import sys
sys.path.append('../..')
from data.constants import BASE_PATH_EXPERIMENTS

# Global plot style: Arial sans-serif at 14 pt; Type 42 (TrueType) fonts in PDF output.
_base_rc = {
    'font.size': 14,
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'pdf.fonttype': 42,
}
plt.rcParams.update(_base_rc)

### Global variables

In [None]:
# Scoring-method key (as used in the result column names) -> display name.
# The insertion order is also the canonical plotting order.
sc_name_map = {
    'adjusted_neighborhood_scoring': 'ANS',
    'seurat_scoring': 'Seurat',
    'seurat_ag_scoring': 'Seurat_AG',
    'seurat_lvg_scoring': 'Seurat_LVG',
    'scanpy_scoring': 'Scanpy',
    'jasmine_scoring_lh': 'Jasmine_LH',
    'jasmine_scoring_or': 'Jasmine_OR',
    'ucell_scoring': 'UCell',
}

# Per-method result files are named 'AUC_decreasing_log2fc_<method key>.csv'.
name_map = {f'AUC_decreasing_log2fc_{key}.csv': label for key, label in sc_name_map.items()}

# Display names in canonical order.
sc_names = list(sc_name_map.values())

# When True, the figure cells write their heatmaps to SVG files.
save = True

# Root folder of this experiment's per-dataset results.
base_path = os.path.join(BASE_PATH_EXPERIMENTS, 'signature_lengths_experiments')

### Helper functions

In [None]:
def get_full_data(list_fns):
    """Load per-method AUC CSV files and split them into an AUCROC and an AUCPR table.

    Each file must have a basename listed in ``name_map`` and a ``signature_length``
    column. An ``Unnamed: 0`` column (a saved pandas index), if present, is dropped.
    All files are concatenated column-wise on ``signature_length``.

    Parameters
    ----------
    list_fns : iterable of str
        Paths to the per-method AUC CSV files.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        ``(df_aucroc, df_aucpr)``, both indexed by signature length. Each has one
        column per scoring method, renamed to its display name via ``sc_name_map``.

    Raises
    ------
    ValueError
        If ``list_fns`` is empty (e.g. the glob pattern matched nothing).
    KeyError
        If a file basename is not in ``name_map``, or a column's method suffix is
        not in ``sc_name_map``.
    """
    list_fns = list(list_fns)
    if not list_fns:
        raise ValueError('get_full_data: no AUC files given; check the input path / glob pattern.')

    list_dfs = []
    for fn in list_fns:
        base = os.path.basename(fn)
        # Only accept files produced for a known scoring method.
        if base not in name_map:
            raise KeyError(f'Unknown AUC result file: {base!r}')
        tmp_df = pd.read_csv(fn)
        if 'Unnamed: 0' in tmp_df.columns:
            tmp_df = tmp_df.drop(columns=['Unnamed: 0'])
        list_dfs.append(tmp_df.set_index('signature_length'))

    df = pd.concat(list_dfs, axis=1)

    def _select_metric(metric):
        # Column names are assumed to be '<metric>_<method key>': keep the columns
        # of `metric` and rename each one to the display name of its method key.
        cols = [x for x in df.columns if metric in x]
        sub = df[cols].copy()
        sub.columns = [sc_name_map['_'.join(x.split('_')[1:])] for x in cols]
        return sub

    return _select_metric('AUCROC'), _select_metric('AUCPR')


def create_lineplots(df, dataset, scnames):
    """Plot AUC against signature length, one line per scoring method, on a new figure.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format table with columns ``scoring_method``, ``signature_length`` and ``AUC``.
    dataset : str
        Dataset name, shown upper-cased in the title.
    scnames : iterable of str
        Scoring methods to draw, in legend order.

    Returns
    -------
    matplotlib.figure.Figure
        A freshly created figure. The original drew onto whatever figure was current,
        so repeated calls overlaid their lines.
    """
    reference_lines = [
        (0.99, 'r', 'AUCROC 0.99'),
        (0.95, 'orange', 'AUCROC 0.95'),
        (0.9, 'g', 'AUCROC 0.90'),
    ]
    with plt.rc_context({'figure.figsize': (20, 8)}):
        fig, ax = plt.subplots()
        for group in scnames:
            sns.lineplot(data=df[df.scoring_method == group], x="signature_length", y="AUC", label=group, ax=ax)
        # Dotted horizontal reference thresholds.
        for level, color, label in reference_lines:
            ax.axhline(level, c=color, ls=':', alpha=0.7, label=label)
        ax.legend(fontsize=16)
        ax.set_title(f'Scoring for signatures of different lengths in decreasing log2FC order for DGEX genes ({dataset.upper()})', fontsize=18)
        ax.set_xlim([-0.001, 250])
        ax.set_xticks(np.arange(0, 250, 10))
        ax.tick_params(axis='both', labelsize=16)
        ax.set_xlabel('Signature lengths', fontsize=16)
        ax.set_ylabel('AUCROC', fontsize=16)
    return fig

def create_barplot(df, dataset, scnames, mean_samples=True, thresholds=(0.8, 0.85, 0.9, 0.95, 0.99)):
    """Bar plots of the smallest signature length that reaches each AUCROC threshold.

    AUC values are first aggregated per (signature_length, scoring_method), using the
    mean if ``mean_samples`` is True and the median otherwise. For each scoring method
    and threshold, the smallest signature length whose aggregated AUC is >= the
    threshold is drawn as a bar (one facet column per threshold) and annotated with
    its value. If a method never reaches a threshold, its value is NaN and gets no
    annotation.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format table with columns ``signature_length``, ``scoring_method`` and ``AUC``.
    dataset : str
        Dataset name, shown upper-cased in the title.
    scnames : iterable of str
        Scoring methods to show, in hue order.
    mean_samples : bool, default True
        Aggregate with the mean if True, otherwise with the median.
    thresholds : sequence of float
        AUC thresholds, one facet column each.

    Returns
    -------
    matplotlib.figure.Figure
        The figure of the seaborn FacetGrid.
    """
    aggregated_performances = (
        df.groupby(by=['signature_length', 'scoring_method'])['AUC']
        .agg('mean' if mean_samples else 'median')
        .reset_index()
    )

    records = []
    # Group by the column name itself. Grouping by a one-element list yields tuple
    # keys in pandas >= 2, and those would never match the names in `hue_order`.
    for method, group in aggregated_performances.groupby('scoring_method'):
        for auc in thresholds:
            reached = group.loc[group.AUC >= auc, 'signature_length']
            records.append((method, auc, reached.min()))
    min_lengths = pd.DataFrame.from_records(records, columns=['scoring_method', 'AUC', 'signature_length'])
    # A single dummy x category, so each facet shows one cluster of bars.
    min_lengths["logFC order"] = ' '

    g = sns.catplot(x="logFC order", y="signature_length",
                    hue="scoring_method", col="AUC",
                    hue_order=list(scnames),
                    data=min_lengths, kind="bar", height=5, aspect=0.5, alpha=1, legend=False)
    g.set_ylabels('Signature length', size=16)
    g.set_titles("AUCROC {col_name}", size=16)
    g.set_xticklabels(size=16)
    for splot in g.axes[0]:
        for p in splot.patches:
            height = p.get_height()
            if np.isnan(height):
                continue  # threshold never reached: nothing to label
            splot.annotate(format(height, '.0f'),
                           (p.get_x() + p.get_width() / 2., height),
                           ha='center', va='center',
                           xytext=(0, 9),
                           textcoords='offset points',
                           fontsize=16)
    g.set(xlabel=None)
    g.add_legend(fontsize=16)
    # Enlarge the y tick labels without freezing them to fixed strings.
    g.axes[0][0].tick_params(axis='y', labelsize=16)
    g.fig.suptitle(f'Nr. of genes with largest log2FC required in signature to achieve desired AUCROC ({dataset.upper()})', size=18, y=1.075)
    return g.fig
    

def create_heatmaps(df, dataset, scnames, mean_samples=True, short=False, n_short=50):
    """Annotated heatmap of aggregated AUC per signature length (rows) and scoring method (columns).

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format table with columns ``scoring_method``, ``signature_length`` and ``AUC``.
    dataset : str
        Dataset name, shown upper-cased in the title.
    scnames : iterable of str
        Scoring methods to show, in column order.
    mean_samples : bool, default True
        Aggregate with the mean if True, otherwise with the median.
    short : bool, default False
        If True, only the first ``n_short`` signature lengths are shown, in a smaller figure.
    n_short : int, default 50
        Number of signature lengths kept when ``short`` is True.

    Returns
    -------
    matplotlib.figure.Figure
    """
    agg = 'mean' if mean_samples else 'median'
    grouped_df = df.groupby(['scoring_method', 'signature_length'])['AUC'].aggregate(agg).unstack()
    methods = list(scnames)

    if short:
        figsize = (8, 20)
        data = grouped_df.iloc[:, :n_short].loc[methods].T
    else:
        figsize = (10, 50)
        data = grouped_df.loc[methods].T

    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(data, cmap='coolwarm', annot=True, ax=ax)
    # The title names the aggregation actually used (it previously always said "Mean").
    ax.set_title(f'{agg.capitalize()} AUCs over {dataset.upper()} samples for signature scoring methods and different signature lengths', fontsize=18)
    ax.set_ylabel("Signature lengths", size=16)
    ax.set_xlabel("")
    ax.tick_params(axis='y', labelsize=16)
    fig.tight_layout()
    return fig
        
        
        
def creat_heatmap(data, measurement='AUCROC', block_code=(1,1,1)):
    """Draw ``data`` as up to three vertically stacked, annotated heatmap panels.

    The rows of ``data`` are split into three consecutive blocks: all rows except
    the last four, the next two rows, and the last two rows. ``block_code`` is a
    3-tuple of 0/1 flags that selects which blocks to draw, from top to bottom. All
    panels share one colour scale, spanning the min and max of the whole ``data``.
    Only the bottom panel shows x tick labels.

    The font style is applied through ``plt.rc_context``, so the global rcParams are
    not modified (the original leaked ``font.size=16`` into every later figure).

    Parameters
    ----------
    data : pandas.DataFrame
        Values to plot. Rows are split into blocks as described above.
    measurement : str, default 'AUCROC'
        Metric name used in the figure title.
    block_code : tuple of three ints (0 or 1), default (1, 1, 1)
        Which of the three row blocks to draw.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``block_code`` selects no block.
    """
    style = {'pdf.fonttype': 42, 'font.family': 'sans-serif', 'font.sans-serif': 'Arial', 'font.size': 16}
    asp = 2
    figw = 7
    figh = figw * asp

    row_slices = (slice(None, -4), slice(-4, -2), slice(-2, None))
    # The long first block is 9x the height of each 2-row block. Ratios follow
    # which blocks are drawn, not just how many.
    block_heights = (9, 1, 1)
    chosen = [(rows, h) for rows, h, flag in zip(row_slices, block_heights, block_code) if flag == 1]
    if not chosen:
        raise ValueError(f'block_code {block_code!r} selects no block to draw.')

    heatmapkws = dict(square=False, cbar=False, cmap=plt.cm.coolwarm, linewidths=0.0,
                      vmin=data.min().min(), vmax=data.max().max())

    with plt.rc_context(style):
        # squeeze=False always yields a 2-D array of Axes, even for a single panel.
        fig, axes = plt.subplots(ncols=1, nrows=len(chosen), figsize=(figw, figh), squeeze=False,
                                 gridspec_kw={"height_ratios": [h for _, h in chosen], "width_ratios": [1]})
        fig.subplots_adjust(left=0.07, right=0.87, bottom=0.1, top=0.9, wspace=0.1, hspace=0.01 * asp)
        last = len(chosen) - 1
        for pos, (curr_ax, (rows, _)) in enumerate(zip(axes[:, 0], chosen)):
            is_last = pos == last
            sns.heatmap(data.iloc[rows, :], ax=curr_ax, xticklabels=is_last, yticklabels=True, annot=True, **heatmapkws)
            curr_ax.set_ylabel('')
            curr_ax.set_yticklabels(curr_ax.get_yticklabels(), rotation=0)
            if is_last:
                curr_ax.set_xticklabels(curr_ax.get_xticklabels(), rotation=45, fontsize=18)
        fig.suptitle(f'Signature length robustness ({measurement})', x=0.47, y=0.92)
    return fig


def get_idx(x, target=1, decimals=2):
    """Return the first index label where ``x`` rounded to ``decimals`` equals ``target``.

    If no value matches, the last index label of ``x`` is returned instead.

    Parameters
    ----------
    x : pandas.Series
        Values to search (here, one method's AUC per signature length).
    target : float, default 1
        Value to look for after rounding.
    decimals : int, default 2
        Number of decimals used for rounding before comparison.

    Returns
    -------
    The matching index label, or ``x.index[-1]`` if there is no match.

    Raises
    ------
    IndexError
        If ``x`` is empty.
    """
    # Explicit check instead of a bare `except:`, which also hid unrelated errors.
    matches = x[x.round(decimals) == target].index
    if len(matches) > 0:
        return matches[0]
    return x.index[-1]


## CRC

In [None]:
dataset = "crc"  # lower-case dataset key; used to build the input/output paths

In [None]:
# Output folder for this dataset's figures; the per-method AUC CSVs are in its AUCROCS subfolder.
storing_path = os.path.join(base_path, dataset, 'decreasing_log2fc')
st_path_dec = os.path.join(storing_path, 'AUCROCS')

In [None]:
# All AUC result files of this dataset, in sorted (deterministic) order.
AUC_fns = sorted(glob.glob(os.path.join(st_path_dec, '*.csv')))
AUC_fns

In [None]:
# Split the per-method files into AUCROC and AUCPR tables indexed by signature length.
df_aucroc, df_aucpr = get_full_data(list_fns=AUC_fns)

In [None]:
# Put the methods in canonical order and forward-fill missing AUC values
# (`fillna(method='ffill')` is deprecated since pandas 2.1; `.ffill()` is equivalent).
df_aucroc = df_aucroc[sc_names].ffill()
# Per method: the first signature length whose AUCROC rounds to 1.00, else the last one.
indexes_aucroc = df_aucroc.apply(get_idx, axis=0)
df_aucroc

In [None]:
# Put the methods in canonical order and forward-fill missing AUC values
# (`fillna(method='ffill')` is deprecated since pandas 2.1; `.ffill()` is equivalent).
df_aucpr = df_aucpr[sc_names].ffill()
# Per method: the first signature length whose AUCPR rounds to 1.00, else the last one.
indexes_aucpr = df_aucpr.apply(get_idx, axis=0)
df_aucpr

In [None]:
# get_idx result per method, for AUCROC and AUCPR.
(indexes_aucroc, indexes_aucpr)

In [None]:
# Rows for the heatmaps: the first 15 positions, then the positions at offsets -2 and -1
# from the get_idx result of Jasmine_OR and UCell.
# NOTE(review): get_idx returns index labels (signature lengths), which are used here as
# positional offsets for .iloc. That only selects "the row reaching 1.00 and the one
# before it" if the index is 1, 2, 3, ... -- TODO confirm.
roc_extra_rows = [indexes_aucroc.Jasmine_OR - 2, indexes_aucroc.Jasmine_OR - 1,
                  indexes_aucroc.UCell - 2, indexes_aucroc.UCell - 1]
small_df_aucroc = df_aucroc.iloc[[*range(15), *roc_extra_rows]]

pr_extra_rows = [indexes_aucpr.Jasmine_OR - 2, indexes_aucpr.Jasmine_OR - 1,
                 indexes_aucpr.UCell - 2, indexes_aucpr.UCell - 1]
small_df_aucpr = df_aucpr.iloc[[*range(15), *pr_extra_rows]]

In [None]:
# Preview with all three row blocks. The figure is closed immediately and not saved.
plt.rcParams.update({'font.size': 14, 'font.family': 'sans-serif', 'font.sans-serif': 'Arial', 'pdf.fonttype': 42})
fig = creat_heatmap(small_df_aucroc, measurement='AUCROC', block_code=(1, 1, 1))
plt.close(fig)

In [None]:
# AUCROC heatmap with the first two row blocks (the UCell block is omitted).
plt.rcParams.update({'font.size': 14, 'font.family': 'sans-serif', 'font.sans-serif': 'Arial', 'pdf.fonttype': 42})
aucroc_svg_path = os.path.join(storing_path, 'aucroc_heatmap.svg')
fig = creat_heatmap(small_df_aucroc, measurement='AUCROC', block_code=(1, 1, 0))
if save:
    fig.savefig(aucroc_svg_path, format='svg')

In [None]:
# AUCPR heatmap with the first two row blocks.
aucpr_svg_path = os.path.join(storing_path, 'aucpr_heatmap.svg')
fig = creat_heatmap(small_df_aucpr, measurement='AUCPR', block_code=(1, 1, 0))
if save:
    fig.savefig(aucpr_svg_path, format='svg')

In [None]:
#fig = create_heatmaps(all_performances, dataset, sc_names, mean_samples=True, short=False)
#fig.savefig(os.path.join(storing_path, 'auc_heat_mean_long.svg'), format='svg')
#plt.close(fig)
#fig = create_heatmaps(all_performances, dataset, sc_names, mean_samples=True, short=True)
#fig.savefig(os.path.join(storing_path, 'auc_heat_mean_short.svg'), format='svg')
#plt.close(fig)
#fig = create_heatmaps(all_performances, dataset, sc_names, mean_samples=False, short=False)
#fig.savefig(os.path.join(storing_path, 'auc_heat_median_long.svg'), format='svg')
#plt.close(fig)
#fig = create_heatmaps(all_performances, dataset, sc_names, mean_samples=False, short=True)
#fig.savefig(os.path.join(storing_path, 'auc_heat_median_short.svg'), format='svg')
#plt.close(fig)
#fig = create_barplot(all_performances, dataset, sc_names, mean_samples=True)
#fig.savefig(os.path.join(storing_path, 'auc_bar_mean.svg'), format='svg')
#plt.close(fig)
#fig = create_barplot(all_performances, dataset, sc_names, mean_samples=False)
#fig.savefig(os.path.join(storing_path, 'auc_bar_median.svg'), format='svg')
#plt.close(fig)

## ESCC

In [None]:
dataset = "escc"  # lower-case dataset key; used to build the input/output paths

In [None]:
# Output folder for this dataset's figures; the per-method AUC CSVs are in its AUCROCS subfolder.
storing_path = os.path.join(base_path, dataset, 'decreasing_log2fc')
st_path_dec = os.path.join(storing_path, 'AUCROCS')

In [None]:
# All AUC result files of this dataset, in sorted (deterministic) order.
AUC_fns = sorted(glob.glob(os.path.join(st_path_dec, '*.csv')))

In [None]:
# Split the per-method files into AUCROC and AUCPR tables indexed by signature length.
df_aucroc, df_aucpr = get_full_data(list_fns=AUC_fns)

In [None]:
# Put the methods in canonical order and forward-fill missing AUC values
# (`fillna(method='ffill')` is deprecated since pandas 2.1; `.ffill()` is equivalent).
df_aucroc = df_aucroc[sc_names].ffill()
# Per method: the first signature length whose AUCROC rounds to 1.00, else the last one.
indexes_aucroc = df_aucroc.apply(get_idx, axis=0)
df_aucroc

In [None]:
# Put the methods in canonical order and forward-fill missing AUC values
# (`fillna(method='ffill')` is deprecated since pandas 2.1; `.ffill()` is equivalent).
df_aucpr = df_aucpr[sc_names].ffill()
# Per method: the first signature length whose AUCPR rounds to 1.00, else the last one.
indexes_aucpr = df_aucpr.apply(get_idx, axis=0)
df_aucpr

In [None]:
# get_idx result per method, for AUCROC and AUCPR.
(indexes_aucroc, indexes_aucpr)

In [None]:
# Rows for the heatmaps: the first 12 (AUCROC) / 14 (AUCPR) positions, then the positions
# at offsets -2 and -1 from the get_idx result of Jasmine_OR and UCell.
# NOTE(review): the differing 12 vs 14 leading rows look deliberate (figure layout) -- confirm.
# NOTE(review): get_idx returns index labels that are used as .iloc positions here; this
# is only consistent if the index is 1, 2, 3, ... -- TODO confirm.
roc_extra_rows = [indexes_aucroc.Jasmine_OR - 2, indexes_aucroc.Jasmine_OR - 1,
                  indexes_aucroc.UCell - 2, indexes_aucroc.UCell - 1]
small_df_aucroc = df_aucroc.iloc[[*range(12), *roc_extra_rows]]

pr_extra_rows = [indexes_aucpr.Jasmine_OR - 2, indexes_aucpr.Jasmine_OR - 1,
                 indexes_aucpr.UCell - 2, indexes_aucpr.UCell - 1]
small_df_aucpr = df_aucpr.iloc[[*range(14), *pr_extra_rows]]

In [None]:
# AUCROC heatmap with only the first row block.
aucroc_svg_path = os.path.join(storing_path, 'aucroc_heatmap.svg')
fig = creat_heatmap(small_df_aucroc, measurement='AUCROC', block_code=(1, 0, 0))
if save:
    fig.savefig(aucroc_svg_path, format='svg')

In [None]:
# AUCPR heatmap with the first and last row blocks.
aucpr_svg_path = os.path.join(storing_path, 'aucpr_heatmap.svg')
fig = creat_heatmap(small_df_aucpr, measurement='AUCPR', block_code=(1, 0, 1))
if save:
    fig.savefig(aucpr_svg_path, format='svg')

In [None]:
#fig = create_lineplots(all_performances, dataset, sc_names)
#fig = create_lineplots(all_performances, dataset, sc_names)
#fig.savefig(os.path.join(storing_path, 'line_plot.svg'), format='svg')
#plt.close(fig)
#fig = create_heatmaps(all_performances, dataset, sc_names, mean_samples=True, short=False)
#fig.savefig(os.path.join(storing_path, 'auc_heat_mean_long.svg'), format='svg')
#plt.close(fig)
#fig = create_heatmaps(all_performances, dataset, sc_names, mean_samples=True, short=True)
#fig.savefig(os.path.join(storing_path, 'auc_heat_mean_short.svg'), format='svg')
#plt.close(fig)
#fig = create_heatmaps(all_performances, dataset, sc_names, mean_samples=False, short=False)
#fig.savefig(os.path.join(storing_path, 'auc_heat_median_long.svg'), format='svg')
#plt.close(fig)
#fig = create_heatmaps(all_performances, dataset, sc_names, mean_samples=False, short=True)
#fig.savefig(os.path.join(storing_path, 'auc_heat_median_short.svg'), format='svg')
#plt.close(fig)
#fig = create_barplot(all_performances, dataset, sc_names, mean_samples=True)
#fig.savefig(os.path.join(storing_path, 'auc_bar_mean.svg'), format='svg')
#plt.close(fig)
#fig = create_barplot(all_performances, dataset, sc_names, mean_samples=False)
#fig.savefig(os.path.join(storing_path, 'auc_bar_median.svg'), format='svg')
#plt.close(fig)

## LUAD

In [None]:
dataset = "luad"  # lower-case dataset key; used to build the input/output paths

In [None]:
# Output folder for this dataset's figures; the per-method AUC CSVs are in its AUCROCS subfolder.
storing_path = os.path.join(base_path, dataset, 'decreasing_log2fc')
st_path_dec = os.path.join(storing_path, 'AUCROCS')

In [None]:
# All AUC result files of this dataset, in sorted (deterministic) order.
AUC_fns = sorted(glob.glob(os.path.join(st_path_dec, '*.csv')))

In [None]:
# Split the per-method files into AUCROC and AUCPR tables indexed by signature length.
df_aucroc, df_aucpr = get_full_data(list_fns=AUC_fns)

In [None]:
# Put the methods in canonical order and forward-fill missing AUC values
# (`fillna(method='ffill')` is deprecated since pandas 2.1; `.ffill()` is equivalent).
df_aucroc = df_aucroc[sc_names].ffill()
# Per method: the first signature length whose AUCROC rounds to 1.00, else the last one.
indexes_aucroc = df_aucroc.apply(get_idx, axis=0)
df_aucroc

In [None]:
# Put the methods in canonical order and forward-fill missing AUC values
# (`fillna(method='ffill')` is deprecated since pandas 2.1; `.ffill()` is equivalent).
df_aucpr = df_aucpr[sc_names].ffill()
# Per method: the first signature length whose AUCPR rounds to 1.00, else the last one.
indexes_aucpr = df_aucpr.apply(get_idx, axis=0)
df_aucpr

In [None]:
# get_idx result per method, for AUCROC and AUCPR.
(indexes_aucroc, indexes_aucpr)

In [None]:
# Rows for the heatmaps: the first 20 positions, then the positions at offsets -2 and -1
# from the get_idx result of Jasmine_OR and UCell.
# NOTE(review): get_idx returns index labels that are used as .iloc positions here; this
# is only consistent if the index is 1, 2, 3, ... -- TODO confirm.
roc_extra_rows = [indexes_aucroc.Jasmine_OR - 2, indexes_aucroc.Jasmine_OR - 1,
                  indexes_aucroc.UCell - 2, indexes_aucroc.UCell - 1]
small_df_aucroc = df_aucroc.iloc[[*range(20), *roc_extra_rows]]

pr_extra_rows = [indexes_aucpr.Jasmine_OR - 2, indexes_aucpr.Jasmine_OR - 1,
                 indexes_aucpr.UCell - 2, indexes_aucpr.UCell - 1]
small_df_aucpr = df_aucpr.iloc[[*range(20), *pr_extra_rows]]

In [None]:
# AUCROC heatmap with the first and last row blocks.
aucroc_svg_path = os.path.join(storing_path, 'aucroc_heatmap.svg')
fig = creat_heatmap(small_df_aucroc, measurement='AUCROC', block_code=(1, 0, 1))
if save:
    fig.savefig(aucroc_svg_path, format='svg')

In [None]:
# AUCPR heatmap with the first and last row blocks.
aucpr_svg_path = os.path.join(storing_path, 'aucpr_heatmap.svg')
fig = creat_heatmap(small_df_aucpr, measurement='AUCPR', block_code=(1, 0, 1))
if save:
    fig.savefig(aucpr_svg_path, format='svg')

In [None]:
# fig = create_lineplots(all_performances, dataset, sc_names)
# fig = create_lineplots(all_performances, dataset, sc_names)
# fig.savefig(os.path.join(storing_path, 'line_plot.svg'), format='svg')
# plt.close(fig)
# fig = create_heatmaps(all_performances, dataset, sc_names, mean_samples=True, short=False)
# fig.savefig(os.path.join(storing_path, 'auc_heat_mean_long.svg'), format='svg')
# plt.close(fig)
# fig = create_heatmaps(all_performances, dataset, sc_names, mean_samples=True, short=True)
# fig.savefig(os.path.join(storing_path, 'auc_heat_mean_short.svg'), format='svg')
# plt.close(fig)
# fig = create_heatmaps(all_performances, dataset, sc_names, mean_samples=False, short=False)
# fig.savefig(os.path.join(storing_path, 'auc_heat_median_long.svg'), format='svg')
# plt.close(fig)
# fig = create_heatmaps(all_performances, dataset, sc_names, mean_samples=False, short=True)
# fig.savefig(os.path.join(storing_path, 'auc_heat_median_short.svg'), format='svg')
# plt.close(fig)
# fig = create_barplot(all_performances, dataset, sc_names, mean_samples=True)
# fig.savefig(os.path.join(storing_path, 'auc_bar_mean.svg'), format='svg')
# plt.close(fig)
# fig = create_barplot(all_performances, dataset, sc_names, mean_samples=False)
# fig.savefig(os.path.join(storing_path, 'auc_bar_median.svg'), format='svg')
# plt.close(fig)