In [None]:
import os 
import sys
import pandas as pd
import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import mannwhitneyu, shapiro, normaltest, levene, ttest_ind
from matplotlib.patches import Patch
import warnings
warnings.filterwarnings("ignore")

from signaturescoring import score_signature

from data_composition import wrap_labels

# Extend the import search path so that the `data.constants` module below resolves
# (assumes the package root sits two directory levels above this notebook — TODO confirm).
sys.path.append('../..')
from data.constants import BASE_PATH_EXPERIMENTS, BASE_PATH_DATA

# Global plot style: Arial sans-serif at size 16. 'pdf.fonttype': 42 selects TrueType
# (Type 42) font embedding for PDF output, per the matplotlib rcParams documentation.
plt.rcParams.update({'pdf.fonttype': 42, 'font.family': 'sans-serif', 'font.sans-serif': 'Arial', 'font.size': 16})

In [None]:
# Paths to the CRC AnnData file with scores and to the CSV of DGEX genes
# (both under the 'crc/mean_norm' pseudobulk sub-directories).
# NOTE: `dgex_genes` holds a file path at this point; a later cell rebinds the same
# name to the DataFrame read from that path.
data_fn = os.path.join(BASE_PATH_EXPERIMENTS,'data_composition_experiments/crc/mean_norm/dgex_on_pseudobulk/crc_adata_with_scores.h5ad')
dgex_genes = os.path.join(BASE_PATH_DATA,'dgex_genes/crc/mean_norm/on_pseudobulk/dgex_genes.csv')

In [None]:
# Load the CRC AnnData (path defined in the previous cell) into memory.
adata = sc.read_h5ad(data_fn)

In [None]:
# `dgex_genes` starts out as the CSV path and is rebound to the loaded table here.
# The isinstance guard keeps this cell re-runnable: without it, a second execution
# would hand the DataFrame itself to pd.read_csv and fail.
if not isinstance(dgex_genes, pd.DataFrame):
    dgex_genes = pd.read_csv(dgex_genes)

In [None]:
def get_pvals_chemistry(adata):
    """Compare per-sample mean scores of malignant cells between the two chemistries.

    Score columns are all columns of ``adata.obs`` whose name contains one of
    'Scanpy', 'Tirosh', 'ANS', 'Jasmine' or 'UCell'. They are averaged per
    (sample_id, malignant_key, SINGLECELL_TYPE). Each score column name is parsed as
    ``<method>_<a>_<b>_<mode>``, where the last two '_'-separated tokens select the
    scoring mode ('all_samples', 'si_ppas' or 'si_ppsi') and everything before the
    last four tokens is the scoring method name.

    For every (scoring method, scoring mode) pair the malignant sample means of
    'SC3Pv3' are tested against those of 'SC3Pv2'.

    Parameters
    ----------
    adata : object exposing ``.obs`` as a pandas DataFrame with the columns
        'sample_id', 'malignant_key', 'SINGLECELL_TYPE' and the score columns.

    Returns
    -------
    pvals : pandas.DataFrame
        One row per (scoring_method, scoring_mode) with the Mann-Whitney U and
        t-test p-values/statistics, the normality flags (Shapiro-Wilk and
        D'Agostino-Pearson, both groups must pass at alpha=0.05) and the
        Levene equal-variance flag.
    score_means : pandas.DataFrame
        One row per (scoring_method, scoring_mode) with mean, variance, min and
        max of the malignant sample means for each chemistry.

    Raises
    ------
    KeyError
        If a score column does not end in one of the three known mode suffixes.
    """
    subnames = ['Scanpy', 'Tirosh', 'ANS', 'Jasmine', 'UCell']
    cols = [col for col in adata.obs.columns if any(map(col.__contains__, subnames))]
    res = adata.obs.groupby(by=['sample_id', 'malignant_key', 'SINGLECELL_TYPE'])[cols].mean().reset_index()
    # The first three columns are the group keys; drop groups without any score at all.
    res = res.dropna(axis=0, how='all', subset=res.columns.tolist()[3:])
    res = res.melt(id_vars=['sample_id', 'SINGLECELL_TYPE', 'malignant_key'],
                   var_name='scoring_method',
                   value_name='scores')

    name_mapping = {'all_samples': 'Scoring all samples together',
                    'si_ppas': 'Scoring each sample individually (preprocessed together)',
                    'si_ppsi': 'Scoring each sample individually (preprocessed independently)',
                    }

    # The mode must be derived before the method column is overwritten with the short name.
    res['scoring_mode'] = res.scoring_method.apply(lambda x: name_mapping['_'.join(x.split('_')[-2:])])
    res['scoring_method'] = res.scoring_method.apply(lambda x: '_'.join(x.split('_')[0:-4]))

    pvals = []
    score_means = []
    alpha = 0.05
    for (key, df) in res.groupby(['scoring_method', 'scoring_mode']):
        is_malignant = df.malignant_key == 'malignant'
        mal_p2 = df.scores[(df.SINGLECELL_TYPE == 'SC3Pv2') & is_malignant].values
        mal_p3 = df.scores[(df.SINGLECELL_TYPE == 'SC3Pv3') & is_malignant].values

        are_normal_shapiro = shapiro(mal_p3).pvalue > alpha and shapiro(mal_p2).pvalue > alpha
        are_normal_normaltest = normaltest(mal_p3).pvalue > alpha and normaltest(mal_p2).pvalue > alpha
        variance_test = levene(mal_p3, mal_p2).pvalue > alpha

        # Run each two-sample test once and read both p-value and statistic from the result.
        mwu_result = mannwhitneyu(mal_p3, mal_p2)
        # Student's t-test if Levene did not reject equal variances, Welch's otherwise.
        ttest_result = ttest_ind(mal_p3, mal_p2, equal_var=variance_test)

        pvals.append({'scoring_method': key[0],
                      'scoring_mode': key[1],
                      'MannWhitneyU p-val': mwu_result.pvalue,
                      'MannWhitneyU statistic': mwu_result.statistic,
                      'are_normal_shapiro': are_normal_shapiro,
                      'are_normal_normaltest': are_normal_normaltest,
                      'equal_variance': variance_test,
                      'ttest p-val': ttest_result.pvalue,
                      'ttest statistic': ttest_result.statistic,
                      })
        score_means.append({
            'scoring_method': key[0],
            'scoring_mode': key[1],
            'mal_p2_mean': np.mean(mal_p2),
            'mal_p2_var': np.var(mal_p2),
            'mal_p2_min': np.min(mal_p2),
            'mal_p2_max': np.max(mal_p2),
            'mal_p3_mean': np.mean(mal_p3),
            'mal_p3_var': np.var(mal_p3),
            'mal_p3_min': np.min(mal_p3),
            'mal_p3_max': np.max(mal_p3),
        })

    pvals = pd.DataFrame(pvals)
    score_means = pd.DataFrame(score_means)
    return pvals, score_means

In [None]:
# Run the chemistry comparison on the AnnData loaded above; `pvals` feeds the heatmaps below.
pvals, score_means = get_pvals_chemistry(adata)

In [None]:
def plot_pval_heatmaps(pvals, storing_path=None):
    """Show one annotated heatmap of test p-values per scoring mode.

    For each of the three scoring modes a separate 6x10 figure is drawn with the
    scoring methods as rows and the 'MannWhitneyU p-val' / 'ttest p-val' columns
    as heatmap columns. Each figure is shown immediately.

    Parameters
    ----------
    pvals : pandas.DataFrame
        Needs the columns 'scoring_mode', 'scoring_method', 'MannWhitneyU p-val'
        and 'ttest p-val' (the frame returned by ``get_pvals_chemistry``).
    storing_path : optional
        Accepted for interface compatibility but currently unused; no figure is
        written to disk.
    """
    scoring_modes = (
        'Scoring all samples together',
        'Scoring each sample individually (preprocessed together)',
        'Scoring each sample individually (preprocessed independently)',
    )
    pval_columns = ['MannWhitneyU p-val', 'ttest p-val']

    for mode in scoring_modes:
        plt.figure(figsize=(6, 10))
        mode_pvals = pvals[pvals.scoring_mode == mode].set_index('scoring_method')[pval_columns]
        sns.heatmap(mode_pvals, annot=True)
        plt.yticks(rotation=0)
        plt.title(mode)
        plt.show()

In [None]:
# Same global font settings as in the notebook's first cell, re-applied before drawing the heatmaps.
plt.rcParams.update({'pdf.fonttype': 42, 'font.family': 'sans-serif', 'font.sans-serif': 'Arial', 'font.size': 16})
plot_pval_heatmaps(pvals)

In [None]:
from statannotations.Annotator import Annotator

In [None]:
def create_strip_plot_crc(adata, exclude_dge_si_ppsi=False, exclude_jasmine=False):
    """Strip plot of per-sample mean scores for the CRC data, split by chemistry and scoring mode.

    The numeric columns of ``adata.obs`` are averaged per sample, chemistry
    ('SINGLECELL_TYPE') and 'malignant_key'. The CRC metadata columns are dropped
    and the remaining (score) columns are melted to long format. Only score columns
    whose name contains 'most_dge' are plotted, one facet per scoring method.
    Malignant SC3Pv2 vs. SC3Pv3 sample means are annotated with a Mann-Whitney test
    for each scoring mode. A horizontal marker shows the mean over samples.

    Parameters
    ----------
    adata : object exposing ``.obs`` as a pandas DataFrame with 'sample_id',
        'SINGLECELL_TYPE', 'malignant_key', the CRC metadata columns dropped below
        and the score columns.
    exclude_dge_si_ppsi : bool
        If True, score columns containing 'si_ppsi' are removed and the short
        scoring-mode labels are used (two modes instead of three).
    exclude_jasmine : bool
        If True, the Jasmine facets are left out and facets wrap after 3 columns
        instead of 4.

    Returns
    -------
    matplotlib.figure.Figure
        The current figure (``plt.gcf()``) after plotting.

    Notes
    -----
    Updates the global ``plt.rcParams`` (font size 14) and the seaborn style as a
    side effect.
    """
    plt.rcParams.update({'pdf.fonttype': 42, 'font.family': 'sans-serif', 'font.sans-serif': 'Arial', 'font.size': 14})
    default_mode = 'Scoring on all samples' if exclude_dge_si_ppsi \
        else 'Scoring on all samples (preprocessed together)'
    means_std_per_sample = []
    for (key, adata_obs) in adata.obs.groupby('sample_id'):
        # numeric_only=True: pandas >= 2.0 no longer silently drops non-numeric
        # columns when averaging and raises a TypeError instead.
        curr_means = adata_obs.groupby(by=['SINGLECELL_TYPE', 'malignant_key']).mean(numeric_only=True)
        curr_means['sample_id'] = key
        # Default label; rows of per-sample scoring columns are relabelled after the melt.
        curr_means['scoring_mode'] = default_mode
        means_std_per_sample.append(curr_means)
    scmeans_per_chem_mal = pd.concat(means_std_per_sample).reset_index()
    # Remove the numeric metadata so that only score columns are melted below.
    scmeans_per_chem_mal = scmeans_per_chem_mal.drop(columns=['TumorSize', 'SizeQuantile', 'Age',
                                                              'n_counts', 'n_genes_by_counts',
                                                              'total_counts', 'total_counts_mt',
                                                              'pct_counts_mt', 'S_score', 'G2M_score',
                                                              'iCMS2_GT', 'iCMS3_GT',
                                                              'log_counts', ])
    scmeans_per_chem_mal = scmeans_per_chem_mal.dropna()
    # change structure of data
    scmeans_per_chem_mal = scmeans_per_chem_mal.melt(
        id_vars=['sample_id', 'SINGLECELL_TYPE', 'malignant_key', 'scoring_mode'],
        var_name='scoring_method',
        value_name='scores'
    )
    if exclude_dge_si_ppsi:
        scmeans_per_chem_mal = scmeans_per_chem_mal[~scmeans_per_chem_mal.scoring_method.str.contains('si_ppsi')].copy()
        scmeans_per_chem_mal.loc[scmeans_per_chem_mal.scoring_method.str.contains(
            'si_ppas'), 'scoring_mode'] = 'Scoring samples individually'
    else:
        scmeans_per_chem_mal.loc[scmeans_per_chem_mal.scoring_method.str.contains(
            'si_ppsi'), 'scoring_mode'] = 'Scoring samples individually (preprocessed individually)'
        scmeans_per_chem_mal.loc[scmeans_per_chem_mal.scoring_method.str.contains(
            'si_ppas'), 'scoring_mode'] = 'Scoring samples individually (preprocessed together)'

    scmeans_per_chem_mal['mal_nd_chem'] = scmeans_per_chem_mal['malignant_key'].astype(str) + \
                                          ' with ' + scmeans_per_chem_mal['SINGLECELL_TYPE'].astype(str)
    # x-axis category: scoring mode combined with chemistry.
    scmeans_per_chem_mal['scmethod_nd_chem'] = scmeans_per_chem_mal['scoring_mode'].astype(
        str) + ' with ' + scmeans_per_chem_mal['SINGLECELL_TYPE'].astype(str)

    most_means = scmeans_per_chem_mal[
        scmeans_per_chem_mal.scoring_method.str.contains('most_dge')].copy()

    # Keep only the method name: strip the last four '_'-separated tokens of the column name.
    most_means.scoring_method = most_means.scoring_method.apply(lambda x: '_'.join(x.split('_')[0:-4]))

    most_means = most_means.sort_values(by='scoring_method')

    # Mean over samples; numeric_only=True because the frame also carries string
    # columns (e.g. 'scoring_mode', 'mal_nd_chem') that cannot be averaged.
    mean_of_sid_means = most_means.groupby(
        by=['scoring_method', 'scmethod_nd_chem', 'malignant_key']).mean(numeric_only=True).reset_index()

    sns.set_style("ticks")
    if exclude_dge_si_ppsi:
        order_x = ['Scoring samples individually with SC3Pv2',
                   'Scoring samples individually with SC3Pv3',
                   'Scoring on all samples with SC3Pv2',
                   'Scoring on all samples with SC3Pv3']
        height = 5
        aspect = 0.6
        pairs = [(('Scoring samples individually with SC3Pv2', 'malignant'),
                  ('Scoring samples individually with SC3Pv3', 'malignant')),
                 (('Scoring on all samples with SC3Pv2', 'malignant'),
                  ('Scoring on all samples with SC3Pv3', 'malignant')),]
    else:
        order_x = ['Scoring samples individually (preprocessed individually) with SC3Pv2',
                   'Scoring samples individually (preprocessed individually) with SC3Pv3',
                   'Scoring samples individually (preprocessed together) with SC3Pv2',
                   'Scoring samples individually (preprocessed together) with SC3Pv3',
                   'Scoring on all samples (preprocessed together) with SC3Pv2',
                   'Scoring on all samples (preprocessed together) with SC3Pv3']
        height = 6
        aspect = 0.75
        pairs = [(('Scoring samples individually (preprocessed individually) with SC3Pv2', 'malignant'),
                  ('Scoring samples individually (preprocessed individually) with SC3Pv3', 'malignant')),
                 (('Scoring samples individually (preprocessed together) with SC3Pv2', 'malignant'),
                  ('Scoring samples individually (preprocessed together) with SC3Pv3', 'malignant')),
                 (('Scoring on all samples (preprocessed together) with SC3Pv2', 'malignant'),
                  ('Scoring on all samples (preprocessed together) with SC3Pv3', 'malignant'))]

    if exclude_jasmine:
        col_order = ['ANS', 'Tirosh', 'Tirosh_AG', 'Tirosh_LVG', 'Scanpy', 'UCell']
        col_wrap = 3
    else:
        col_order = ['ANS', 'Tirosh', 'Tirosh_AG', 'Tirosh_LVG', 'Scanpy', 'Jasmine_LH', 'Jasmine_OR', 'UCell']
        col_wrap = 4

    # Shared between the catplot and the per-facet Annotator so both see the same layout.
    args = dict(
        x='scmethod_nd_chem',
        order=order_x,
        y='scores',
        hue='malignant_key',
        col='scoring_method',
        col_order=col_order,
        col_wrap=col_wrap,
        kind='strip',
        height=height,
        aspect=aspect,
        legend=True
    )

    g = sns.catplot(
        data=most_means,
        dodge=False,
        **args
    )

    for name, ax in g.axes_dict.items():
        # Subset the data based on the 'scoring_method' column
        subset_data = most_means.loc[most_means['scoring_method'] == name, :].copy()

        annot = Annotator(ax, pairs, **args, data=subset_data)
        annot.configure(test='Mann-Whitney', loc='inside', verbose=0)
        annot.apply_test().annotate()

    g.add_legend(fontsize=16, loc='upper right', bbox_to_anchor=(1, 1))
    g.legend.set_title('Celltype', prop={'size': 18})

    palette = {'malignant': 'black', 'non-malignant': 'black'}
    for ax in g.axes:
        # Shaded bands mark the chemistry of each x position: grey = SC3Pv2, tan = SC3Pv3.
        ax.axvspan(xmin=-0.35, xmax=0.35, facecolor='grey', alpha=0.5, zorder=0)
        ax.axvspan(xmin=0.65, xmax=1.35, facecolor='tan', alpha=0.5, zorder=0)
        if not exclude_dge_si_ppsi:
            ax.axvline(1.5, c='black', ls=':')

        ax.axvspan(xmin=1.65, xmax=2.35, facecolor='grey', alpha=0.5, zorder=0)
        ax.axvspan(xmin=2.65, xmax=3.35, facecolor='tan', alpha=0.5, zorder=0)
        if not exclude_dge_si_ppsi:
            ax.axvline(3.5, c='black', ls=':')
            ax.axvspan(xmin=3.65, xmax=4.35, facecolor='grey', alpha=0.5, zorder=0)
            ax.axvspan(xmin=4.65, xmax=5.35, facecolor='tan', alpha=0.5, zorder=0)
        # Horizontal marker at the mean over samples; the facet's scoring method is
        # recovered from the default 'scoring_method = <name>' facet title.
        sns.scatterplot(y="scores", x="scmethod_nd_chem", hue='malignant_key',
                        data=mean_of_sid_means[
                            mean_of_sid_means.scoring_method == ax.title.get_text().split(' = ')[-1]], marker='_',
                        s=1000,
                        color='k', ax=ax, legend=False, palette=palette, zorder=2)

    g.set_axis_labels("", "Mean scores per sample and celltype", size=18)
    # One tick per scoring mode, centred between its SC3Pv2 and SC3Pv3 positions.
    if exclude_dge_si_ppsi:
        g.set(xticks=([0.5, 2.5]))
        g.set_xticklabels(["Scoring samples individually", "Scoring on all samples"], size=16)
    else:
        g.set(xticks=([0.5, 2.5, 4.5]))
        g.set_xticklabels(["Scoring samples individually (preprocessed individually)",
                           "Scoring samples individually (preprocessed together)",
                           "Scoring on all samples (prepro-cessed together)"], size=14)
    for ax in g.axes[col_wrap:]:
        wrap_labels(ax, 11, break_long_words=False)

    g.set_yticklabels(np.round(g.axes[0].get_yticks(), decimals=2), size=16)
    g.set_titles("{col_name}", size=18, weight='normal')

    legend_elements = [Patch(facecolor='grey', edgecolor=None,
                             label='SC3Pv2'),
                       Patch(facecolor='tan', edgecolor=None,
                             label='SC3Pv3')]
    plt.legend(handles=legend_elements, title="Chemistry", fontsize=16, frameon=False, facecolor=None,
               title_fontsize=18, loc='upper right', bbox_to_anchor=(1.75, 1.1))
    return plt.gcf()

In [None]:
# Same global settings as in the first cell; note that create_strip_plot_crc
# sets 'font.size' to 14 itself when it is called.
plt.rcParams.update({'pdf.fonttype': 42, 'font.family': 'sans-serif', 'font.sans-serif': 'Arial', 'font.size': 16})

# Reduced figure: without the 'si_ppsi' scoring mode and without the Jasmine facets.
fig = create_strip_plot_crc(adata, exclude_dge_si_ppsi=True, exclude_jasmine=True)

In [None]:
# NOTE(review): presumably a deliberate stop so that "Run All" halts after the CRC
# analysis; the cells below look like exploratory work on other datasets — confirm,
# and consider giving the exception a message that says so.
raise ValueError()

In [None]:
# Load the preprocessed LUAD AnnData.
# NOTE(review): in notebook order the next cell rebinds `adata` to the ESCC data
# before this object is used anywhere — confirm which dataset is intended.
adata = sc.read_h5ad(os.path.join(BASE_PATH_DATA, 'preprocessed/pp_luad.h5ad'))

In [None]:
# Point `data_fn` at the ESCC AnnData with scores and load it; this rebinds both
# `data_fn` and `adata`, which previously referred to other datasets.
data_fn = os.path.join(BASE_PATH_EXPERIMENTS,'data_composition_experiments/escc/mean_norm/dgex_on_pseudobulk/escc_adata_with_scores.h5ad')

adata = sc.read_h5ad(data_fn)

In [None]:
# Inspect which per-cell annotation columns the loaded AnnData provides.
adata.obs.columns

In [None]:
# Per-sample mean of every numeric obs column, split by 'malignant_key'.
means_std_per_sample = []
for (key, adata_obs) in adata.obs.groupby('sample_id'):
    # numeric_only=True: obs may hold non-numeric columns, which pandas >= 2.0
    # refuses to average (TypeError) instead of silently dropping them.
    curr_means = adata_obs.groupby(by='malignant_key').mean(numeric_only=True)
    curr_means['sample_id'] = key
    # The original expression was `... if True else ...`, i.e. always this label.
    curr_means['scoring_mode'] = 'Scoring on all samples'
    means_std_per_sample.append(curr_means)
scmeans_per_mal = pd.concat(means_std_per_sample).reset_index()

In [None]:
# Columns of the per-sample means before the metadata columns are removed.
scmeans_per_mal.columns

In [None]:
# Per-dataset lists of columns to remove from the per-sample means
# (presumably the numeric non-score metadata of each dataset — verify against obs).
cols_crc = ['TumorSize', 'SizeQuantile', 'Age',
          'n_counts', 'n_genes_by_counts',
          'total_counts', 'total_counts_mt',
          'pct_counts_mt', 'S_score', 'G2M_score',
          'iCMS2_GT', 'iCMS3_GT',
          'log_counts', ]

cols_luad = ['age', 'n_genes_by_counts', 'total_counts',
           'total_counts_mito', 'pct_counts_mito', 'is_primary_data', 'n_genes',
           'total_counts_mt', 'pct_counts_mt', 'S_score', 'G2M_score',
           'log_counts',]

cols_escc = ['n_counts', 'n_genes_by_counts', 'total_counts',
                'total_counts_mt', 'pct_counts_mt', 'S_score', 'G2M_score', 'AP_GT',
                'Cycling_GT', 'Epi1_GT', 'Epi2_GT', 'Mes_GT', 'Mucosal_GT', 'Oxd_GT',
                'Stress_GT', 'log_counts',]

# The AnnData loaded last in notebook order is the ESCC one, so the ESCC list is the
# matching one (the LUAD list names columns such as 'age' that the ESCC list lacks).
# errors='ignore' keeps the cell re-runnable after the columns are already gone.
scmeans_per_mal = scmeans_per_mal.drop(columns=cols_escc, errors='ignore')

In [None]:
# Columns that remain after the metadata columns were dropped.
scmeans_per_mal.columns

In [None]:
# Discard rows containing any missing value, then reshape from wide to long:
# one row per (sample, malignancy, scoring mode, score column).
long_id_columns = ['sample_id', 'malignant_key', 'scoring_mode']
scmeans_per_mal = scmeans_per_mal.dropna().melt(
    id_vars=long_id_columns,
    var_name='scoring_method',
    value_name='scores',
)

In [None]:
# Display the long-format per-sample mean scores.
scmeans_per_mal