In [None]:
import os 
import sys
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

sys.path.append('../..')

from data.constants import BASE_PATH_EXPERIMENTS, BASE_PATH_DRIVE

In [None]:
crc_scores = pd.read_csv(os.path.join(BASE_PATH_EXPERIMENTS, 'data_composition_experiments/crc/mean_norm/dgex_on_pseudobulk/crc_adata_obs.csv'))

In [None]:
escc_scores = pd.read_csv(os.path.join(BASE_PATH_EXPERIMENTS, 'data_composition_experiments/escc/mean_norm/dgex_on_pseudobulk/escc_adata_obs.csv'))

In [None]:
subnames = ['Scanpy', 'Tirosh', 'ANS', 'Jasmine', 'UCell']
crc_scores_cols = [col for col in crc_scores.columns if any(map(col.__contains__, subnames))]
escc_scores_cols = [col for col in escc_scores.columns if any(map(col.__contains__, subnames))]

In [None]:
crc_scores.SINGLECELL_TYPE.value_counts()

In [None]:
crc_SC3Pv2_scores = crc_scores[crc_scores.SINGLECELL_TYPE == 'SC3Pv2' ][['malignant_key'] + crc_scores_cols].copy()
crc_SC3Pv2_scores['dataset'] = 'CRC_SC3Pv2'
crc_SC3Pv3_scores = crc_scores[crc_scores.SINGLECELL_TYPE == 'SC3Pv3' ][['malignant_key'] + crc_scores_cols].copy()
crc_SC3Pv3_scores['dataset'] = 'CRC_SC3Pv3'
escc_scores_cols = escc_scores[['malignant_key'] + escc_scores_cols].copy()
escc_scores_cols['dataset'] = 'ESCC'

In [None]:
data = pd.concat([crc_SC3Pv2_scores, crc_SC3Pv3_scores, escc_scores_cols], axis=0)

In [None]:
data

In [None]:
data = data.groupby(['dataset', 'malignant_key']).var()

In [None]:
data = data.reset_index()

In [None]:
data = data.melt(id_vars=['dataset', 'malignant_key'],
                 var_name='scoring_method',
                 value_name='score')

In [None]:
# Score columns end in a scoring-mode suffix. 'si_ppas' (each sample scored
# individually, but preprocessed together) is mapped to 'not_used' and dropped below,
# so only two modes are compared.
name_mapping = {
    'all_samples': 'Scoring all samples together',
    'si_ppas': 'not_used',
    'si_ppsi': 'Scoring each sample individually',
}


def _scoring_mode_of(column_name):
    """Scoring mode encoded by the last two '_'-separated tokens of a score column."""
    return name_mapping['_'.join(column_name.split('_')[-2:])]


def _method_of(column_name):
    """Method name: the score column with its last four '_'-separated tokens removed."""
    return '_'.join(column_name.split('_')[:-4])


# The mode must be derived before `scoring_method` is overwritten with the bare method name.
data['scoring_mode'] = data['scoring_method'].map(_scoring_mode_of)
data['scoring_method'] = data['scoring_method'].map(_method_of)
data = data.loc[data['scoring_mode'] != 'not_used'].copy()

In [None]:
data.scoring_method.unique()

In [None]:
sc_method_name_mapping = {
    'ANS':'ANS', 
    'Tirosh':'Seurat',
    'Tirosh_AG':'Seurat_AG',
    'Tirosh_LVG':'Seurat_LVG',
    'Scanpy':'Scanpy',
    'Jasmine_LH':'Jasmine_LH',
    'Jasmine_OR':'Jasmine_OR',
    'UCell':'UCell',
}
data['scoring_method'] = data['scoring_method'].map(sc_method_name_mapping)

In [None]:
data['ds_method'] = data.apply(lambda x: f"{x.dataset}_{x.scoring_method}", axis=1)

In [None]:
data.scoring_mode.unique()

In [None]:
# order=['Scoring each sample individually (preprocessed independently)',
#                        'Scoring each sample individually (preprocessed together)',
#                        'Scoring all samples together']
order=['Scoring each sample individually',
       'Scoring all samples together']

In [None]:
data.groupby(['dataset', 'scoring_method', 'malignant_key']).apply(print)

In [None]:
tmp = data.groupby(['dataset', 'scoring_method', 'malignant_key']).apply(lambda x: x.score.diff().values[1])

In [None]:
display(tmp)

In [None]:
import textwrap

def wrap_labels(ax, width, break_long_words=False, rotation=0):
    """
    Wrap the x tick labels of an axis so that no line exceeds a given width.

    Args:
        ax: Matplotlib axis whose x tick labels are rewritten in place.
        width: Desired max width of a label line before breaking.
        break_long_words: Whether words longer than `width` may be split.
        rotation: Rotation (degrees) of the new labels; 0 (horizontal) by default.
    """
    wrapped = [
        textwrap.fill(label.get_text(), width=width, break_long_words=break_long_words)
        for label in ax.get_xticklabels()
    ]
    # Pin the current tick positions before replacing their labels. Setting labels on a
    # non-fixed (e.g. categorical) locator triggers matplotlib's "FixedFormatter should
    # only be used together with FixedLocator" warning, and the labels could drift off
    # their ticks if the locator later changes.
    ax.set_xticks(ax.get_xticks())
    ax.set_xticklabels(wrapped, rotation=rotation)

In [None]:
# Top-to-bottom order of the scoring-method facet rows in the figure.
row_order = [
    'ANS',
    'Seurat', 'Seurat_AG', 'Seurat_LVG',
    'Scanpy',
    'Jasmine_LH', 'Jasmine_OR',
    'UCell',
]
# Inches per centimetre, so figure sizes can be given in cm.
cm = 1 / 2.54

In [None]:
plt.rcParams.update({'pdf.fonttype': 42, 'font.family': 'sans-serif',
                     'font.sans-serif': 'Arial', 'font.size': 10})

# Facet grid: one row per scoring method (in `row_order`), one column per
# `malignant_key` value; x is the scoring mode, hue the dataset, y the score variance.
g = sns.catplot(data=data,
                x="scoring_mode",
                y="score",
                hue="dataset",
                row='scoring_method',
                row_order=row_order,
                col='malignant_key',
                kind="point",
                order=order,
                height=3*cm,
                aspect=1.5,
                )
# Wrap the long scoring-mode tick labels on the bottom row of facets.
for ax in g.axes[-1, :]:
    wrap_labels(ax, 15, break_long_words=False)

g.set_xlabels("")
g.set_titles("")
# NOTE(review): these titles assume the first facet column holds the malignant cells.
# No `col_order` is passed, so the column order is seaborn's default ordering of the
# `malignant_key` values -- confirm, or pass an explicit `col_order`.
g.axes[0, 0].set_title('Malignant cells', fontsize=10)
g.axes[0, 1].set_title('Non-malignant cells', fontsize=10)

# The plotted values are variances (computed with groupby(...).var()), not standard deviations.
for ax, title in zip(g.axes[:, 0], row_order):
    ax.set_ylabel(f"{title}\nscore variance", fontsize=10, rotation=90)

# FacetGrid.tight_layout keeps the external hue legend out of the layout box so the
# panels do not run into it; it sets subplot spacing itself (tune via w_pad/h_pad).
g.tight_layout()

fig_path = os.path.join(BASE_PATH_DRIVE, 'figures/supplementary/benchmark/score_variances.pdf')
os.makedirs(os.path.dirname(fig_path), exist_ok=True)
# `format` is a savefig argument; os.path.join accepts no keyword arguments.
g.fig.savefig(fig_path, format='pdf')