## Find an ESCC-specific signature for cancer cells in EMT
The signature should score high for cancer cells undergoing EMT and low for all other cell types, in particular fibroblasts and the remaining (non-EMT) cancer cells.

In [None]:
from IPython.display import display, HTML

# Widen the notebook container to 80% of the browser window.
wide_container_css = "<style>.container { width:80% !important; }</style>"
display(HTML(wide_container_css))

In [None]:
import os
import sys

# Make the repository root importable so the project-local packages imported
# below (e.g. `data`) can be found. One append is enough.
sys.path.append('../../..')

import pandas as pd
import numpy as np
import scanpy as sc
import json
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import seaborn as sns
from matplotlib.pyplot import rc_context
from tqdm import tqdm
from statsmodels.stats.multitest import multipletests
from scipy.stats import mannwhitneyu
from matplotlib_venn import venn3

from data.load_data import load_datasets
from data.constants import BASE_PATH_DATA, BASE_PATH_EXPERIMENTS

from signaturescoring import score_signature
from signaturescoring.utils.utils import get_mean_and_variance_gene_expression, check_signature_genes

In [None]:
sc.settings.verbosity = 2

pl_size = 6
dataset = 'escc'

# EMT reference signatures, plus this experiment's folder (barcodes are read
# from it; plots and signatures are written below it).
base_path_emt_signatures = os.path.join(BASE_PATH_DATA, 'annotations/emt')
base_path_barcodes = os.path.join(
    BASE_PATH_EXPERIMENTS, f'EMT_signature_scoring_case_study/{dataset}'
)
storing_path = os.path.join(base_path_barcodes, 'dataset_specific_emt_sig')

storing_path_missing = not os.path.exists(storing_path)
if storing_path_missing:
    os.makedirs(storing_path)
    sc.logging.info(f'Creating new storing folder at {storing_path}')

# Toggle writing plots and signatures to disk.
save = True

Load preprocessed dataset 

In [None]:
orig_adata = load_datasets(dataset, preprocessed=True, norm_method='mean')
# Make sure uns['log1p'] exists and its 'base' entry is None (create the entry if missing).
orig_adata.uns.setdefault('log1p', {})['base'] = None

Load cancer EMT barcodes

In [None]:
# The barcode CSV stores the cell barcodes in a column named '0'.
barcodes_cancer_emt_cells = (
    pd.read_csv(os.path.join(base_path_barcodes, 'barcodes_cancer_emt.csv'))['0']
    .rename('cancer_emt_cells')
)

In [None]:
# Union of all fibroblast barcodes and the cancer-EMT barcodes.
caf_barcodes = orig_adata.obs.index[orig_adata.obs.celltype == 'Fibroblasts'].to_list()
barcodes_caf_emt_mes_cells = pd.Series(caf_barcodes + barcodes_cancer_emt_cells.to_list())

In [None]:
# String copy of the original annotation; relabelled in the following cells.
orig_adata.obs['celltype_broad'] = orig_adata.obs['celltype'].astype(str)

In [None]:
# Boolean mask: cells that are neither fibroblasts nor cancer-EMT cells.
cells_not_cafs_and_cancer_emt = ~orig_adata.obs.index.isin(barcodes_caf_emt_mes_cells)

In [None]:
# Relabel every cell that is neither a CAF nor a cancer-EMT cell:
# remaining epithelial cells -> 'Epi non Mes', all other cell types -> 'rest'.
# Assign through .loc on the DataFrame. Chained assignment (obs['col'][mask] = ...)
# is unreliable and becomes a silent no-op under pandas Copy-on-Write.
is_epi = orig_adata.obs['celltype'] == 'Epi'
orig_adata.obs.loc[cells_not_cafs_and_cancer_emt & is_epi, 'celltype_broad'] = 'Epi non Mes'
orig_adata.obs.loc[cells_not_cafs_and_cancer_emt & ~is_epi, 'celltype_broad'] = 'rest'
orig_adata.obs['celltype_broad'].value_counts()

In [None]:
# String copy of the original annotation; epithelial labels are refined in the next cell.
orig_adata.obs['celltype_broader'] = orig_adata.obs['celltype'].astype('str')

In [None]:
# Map the celltype_broad epithelial labels onto the finer celltype_broader labels.
for broad_label, broader_label in (('Epi', 'Epi with Mes'), ('Epi non Mes', 'Epi wo Mes')):
    orig_adata.obs.loc[orig_adata.obs['celltype_broad'] == broad_label, 'celltype_broader'] = broader_label
orig_adata.obs['celltype_broader'].value_counts()

Prepare UMAPs

In [None]:
# PCA -> kNN graph -> UMAP embedding, in this order, with default parameters.
for embedding_step in (sc.tl.pca, sc.pp.neighbors, sc.tl.umap):
    embedding_step(orig_adata)

Load and score Hallmark EMT signature

In [None]:
# Notice we have lots of genes from the hallmark signature that are not available
# in the data; check_signature_genes keeps only genes present in var_names.
# The file name must be a relative path component: with a leading '/',
# os.path.join discards every preceding component (here BASE_PATH_DATA).
hallmark_emt_path = os.path.join(
    base_path_emt_signatures, 'HALLMARK_EPITHELIAL_MESENCHYMAL_TRANSITION.v7.5.1.json'
)
with open(hallmark_emt_path, encoding='utf-8') as f:
    hallmark_emt = json.load(f)
orig_hallmark_emt_sig = hallmark_emt['HALLMARK_EPITHELIAL_MESENCHYMAL_TRANSITION']['geneSymbols']
hallmark_emt_sig = check_signature_genes(orig_adata.var_names, orig_hallmark_emt_sig)
hallmark_emt_sig.sort()
hallmark_emt_sig

In [None]:
%%time
score_signature(method="adjusted_neighborhood_scoring",
                adata=orig_adata,
                gene_list=hallmark_emt_sig,
                ctrl_size=100,
                score_name='hallmark_emt_scores'
                )

In [None]:
plt.rcParams.update({'pdf.fonttype':42, 'font.family':'sans-serif', 'font.sans-serif':'Arial', 'font.size':14})
with rc_context({'figure.figsize': (15,8)}):
    for celltype, cells in orig_adata.obs.groupby(by='celltype_broader'):
        cells['hallmark_emt_scores'].hist(bins=100, alpha=0.5, density=True, label=celltype)
# NOTE: the figure is closed right away without being saved, so the lasting
# effect of this cell is the global rcParams update above.
plt.close()

In [None]:
plt.rcParams.update({'pdf.fonttype':42, 'font.family':'sans-serif', 'font.sans-serif':'Arial', 'font.size':14})
# Overlaid, density-normalised histograms of the Hallmark EMT score, one per celltype_broader group.
with rc_context({'figure.figsize': (10,6)}):
    for celltype, cells in orig_adata.obs.groupby(by='celltype_broader'):
        cells['hallmark_emt_scores'].hist(bins=100, alpha=0.5, density=True, label=celltype)
    plt.legend()
    plt.title('Hallmark EMT score distributions per celltype')
    if save:
        curr_path = os.path.join(storing_path, 'plots')
        os.makedirs(curr_path, exist_ok=True)
        plt.savefig(os.path.join(curr_path, 'dist_scores_hallmark_emt.png'), dpi=600)

In [None]:
def get_nr_cells_over_quantile(adata, score_col = 'hallmark_emt_scores', quant=0.85, celltype_col = 'celltype_broader'):
    """Count, per cell type, the cells scoring above a global quantile cut-off.

    The cut-off is the ``quant`` quantile of ``adata.obs[score_col]`` over all cells.

    Returns
    -------
    pandas.Series
        Indexed by the values of ``adata.obs[celltype_col]``. Each entry is the tuple
        ``(n_cells_above_cutoff, n_cells_in_group)``.
    """
    scores = adata.obs[score_col]
    threshold = scores.quantile(quant)
    grouped_above = (scores > threshold).groupby(adata.obs[celltype_col])
    n_above = grouped_above.sum()
    n_total = grouped_above.size()
    return pd.Series(list(zip(n_above, n_total)), index=n_above.index)

vals = get_nr_cells_over_quantile(orig_adata)
# One line per cell type: name, #cells above the cut-off, #cells in total.
for celltype, (n_above, n_total) in vals.items():
    print(celltype, n_above, n_total)

In [None]:
COL = 'hallmark_emt_scores'
quant = 0.85
with rc_context({'figure.figsize': (6,6)}):
    orig_adata.obs[COL].hist(bins=100)

    # Global cut-off: the `quant` quantile of the score over all cells.
    cutoff = orig_adata.obs[COL].quantile(quant)
    plt.axvline(cutoff, c='r', label=f'{quant} quantile')

    # Compute the counts once, with the same COL/quant as the cut-off above.
    vals = get_nr_cells_over_quantile(orig_adata, score_col=COL, quant=quant, celltype_col='celltype_broader')
    title = f"Distribution scores for {COL} scores.\nCutoff {quant} quantile (={round(cutoff, 2)}) "
    for celltype, (n_above, n_total) in vals.items():
        title += f'\n{celltype}: {n_above}/{n_total} ({round(n_above/n_total*100, 3)}%)'
    plt.title(title)
    plt.tight_layout()
    if save:
        curr_path = os.path.join(storing_path, 'plots')
        os.makedirs(curr_path, exist_ok=True)
        plt.savefig(os.path.join(curr_path, 'dist_scores_hallmark_emt_quant.png'), dpi=600)

### Define subsets

In [None]:
# Subset: fibroblasts plus cancer-EMT cells (barcodes collected above).
caf_emt_mask = orig_adata.obs.index.isin(barcodes_caf_emt_mes_cells)
caf_cancer_emt_cells = orig_adata[caf_emt_mask, :].copy()

In [None]:
for annotation in ('celltype_broad', 'celltype'):
    print(caf_cancer_emt_cells.obs[annotation].value_counts())

In [None]:
# Same output as the previous cell (re-prints the caf_cancer_emt_cells composition).
for annotation in ('celltype_broad', 'celltype'):
    print(caf_cancer_emt_cells.obs[annotation].value_counts())

In [None]:
# All epithelial (cancer) cells, with and without EMT.
cancer_cells = orig_adata[orig_adata.obs['celltype'] == 'Epi'].copy()

In [None]:
for annotation in ('celltype_broad', 'celltype'):
    print(cancer_cells.obs[annotation].value_counts())

In [None]:
# Fibroblasts plus all epithelial cells.
caf_and_all_cancer = orig_adata[orig_adata.obs['celltype'].isin(('Fibroblasts', 'Epi'))].copy()

In [None]:
for annotation in ('celltype_broad', 'celltype'):
    print(caf_and_all_cancer.obs[annotation].value_counts())

### Cancer EMT vs. rest: first attempt

In [None]:
# Wilcoxon DGEX between the celltype_broad groups of the CAF + cancer-EMT subset.
sc.tl.rank_genes_groups(caf_cancer_emt_cells, method='wilcoxon', groupby='celltype_broad', tie_correct=True)

In [None]:
# Wilcoxon DGEX between EMT and non-EMT epithelial cells.
sc.tl.rank_genes_groups(cancer_cells, method='wilcoxon', groupby='celltype_broad', tie_correct=True)

In [None]:
# Wilcoxon DGEX between the celltype_broad groups of fibroblasts + all epithelial cells.
sc.tl.rank_genes_groups(caf_and_all_cancer, method='wilcoxon', groupby='celltype_broad', tie_correct=True)

In [None]:
fig = sc.pl.rank_genes_groups_dotplot(
    caf_cancer_emt_cells, n_genes=15, min_logfoldchange=2, return_fig=True
)
if save:
    curr_path = os.path.join(storing_path, 'plots')
    os.makedirs(curr_path, exist_ok=True)
    fig.savefig(os.path.join(curr_path, 'top_marker_genes_caf_cancer_emt_cells.png'), dpi=600)

In [None]:
fig = sc.pl.rank_genes_groups_dotplot(
    cancer_cells, n_genes=15, min_logfoldchange=2, return_fig=True
)
if save:
    curr_path = os.path.join(storing_path, 'plots')
    os.makedirs(curr_path, exist_ok=True)
    fig.savefig(os.path.join(curr_path, 'top_marker_genes_cancer_cells.png'), dpi=600)

In [None]:
fig = sc.pl.rank_genes_groups_dotplot(
    caf_and_all_cancer, n_genes=20, min_logfoldchange=2, return_fig=True
)
if save:
    curr_path = os.path.join(storing_path, 'plots')
    os.makedirs(curr_path, exist_ok=True)
    fig.savefig(os.path.join(curr_path, 'top_marker_genes_caf_and_all_cancer.png'), dpi=600)

In [None]:
# Disabled exploratory variant of the DGE gene-set selection done in the cells
# below (tmp1 is the same query as dge_genes_cancer_emt_vs_cafs). Kept commented out.
# tmp1  = sc.get.rank_genes_groups_df(caf_cancer_emt_cells, group='Epi', log2fc_min=2, pval_cutoff=0.001)
# tmp2  = sc.get.rank_genes_groups_df(caf_cancer_emt_cells, group='Fibroblasts', log2fc_min=1, pval_cutoff=0.001)
# tmp3  = sc.get.rank_genes_groups_df(cancer_cells, group='Epi', log2fc_min=2, pval_cutoff=0.001)
# tmp4  = sc.get.rank_genes_groups_df(cancer_cells, group='Epi non Mes', log2fc_min=0, pval_cutoff=0.001)

# A = set(tmp1.names.tolist()).union(set(tmp3.names.tolist()))
# B = set(tmp2.names.tolist())
# C = set(tmp4.names.tolist())

# D = (A.difference(B)).difference(C)

# tmp5 = pd.merge(tmp1, tmp3, on='names',how='outer')

# list(D)

# tmp5 = tmp5.set_index('names')

# tmp5.sort_index()

# tmp5 = tmp5.loc[list(D)]

# tmp5.sort_values(by=['logfoldchanges_x', 'logfoldchanges_y'], ascending=False)[0:50]

# tmp5[(~tmp5.logfoldchanges_x.isna())&(~tmp5.logfoldchanges_y.isna())]

In [None]:
# Genes up in cancer-EMT cells ('Epi') vs. fibroblasts.
dge_genes_cancer_emt_vs_cafs = sc.get.rank_genes_groups_df(
    caf_cancer_emt_cells,
    group='Epi',
    pval_cutoff=0.001,
    log2fc_min=2,
)

In [None]:
# Genes up in cancer-EMT cells ('Epi') vs. the other cancer cells.
dge_genes_cancer_emt_vs_cancer = sc.get.rank_genes_groups_df(
    cancer_cells,
    group='Epi',
    pval_cutoff=0.001,
    log2fc_min=1.5,
)

In [None]:
# Genes up in non-EMT cancer cells vs. cancer-EMT cells.
dge_genes_cancer_vs_cancer_emt = sc.get.rank_genes_groups_df(
    cancer_cells,
    group='Epi non Mes',
    pval_cutoff=0.001,
    log2fc_min=2,
)

In [None]:
# Overlap of the three DGE gene sets (label -> gene set).
venn_gene_sets = {
    'dge_genes_cancer_emt_vs_cafs': set(dge_genes_cancer_emt_vs_cafs.names.to_list()),
    'dge_genes_cancer_emt_vs_cancer': set(dge_genes_cancer_emt_vs_cancer.names.to_list()),
    'dge_genes_cancer_vs_cancer_emt': set(dge_genes_cancer_vs_cancer_emt.names.to_list()),
}
venn3(subsets=tuple(venn_gene_sets.values()), set_labels=tuple(venn_gene_sets.keys()))
if save:
    curr_path = os.path.join(storing_path, 'plots')
    os.makedirs(curr_path, exist_ok=True)
    plt.savefig(os.path.join(curr_path, 'venn_dgex_genes_mal_emt_vs_caf_mal.png'), dpi=600)

In [None]:
# Genes up in cancer-EMT cells against both fibroblasts and the other cancer cells.
dge_genes_cancer_emt_vs_caf_and_cancer = (
    set(dge_genes_cancer_emt_vs_cafs.names.to_list())
    & set(dge_genes_cancer_emt_vs_cancer.names.to_list())
)

In [None]:
# DGE genes that are also in the (filtered) Hallmark EMT signature.
dge_genes_and_hallmark = set(dge_genes_cancer_emt_vs_caf_and_cancer) & set(hallmark_emt_sig)
len(dge_genes_and_hallmark)

In [None]:
# Score each candidate signature; the score column is named like the gene set.
for sig_score_name, sig_genes in (
    ('dge_genes_cancer_emt_vs_caf_and_cancer', dge_genes_cancer_emt_vs_caf_and_cancer),
    ('dge_genes_and_hallmark', dge_genes_and_hallmark),
):
    score_signature(
        method="adjusted_neighborhood_scoring",
        adata=orig_adata,
        gene_list=sig_genes,
        ctrl_size=100,
        score_name=sig_score_name,
    )

In [None]:
violin_keys = ['hallmark_emt_scores', 'dge_genes_cancer_emt_vs_caf_and_cancer', 'dge_genes_and_hallmark']
sc.pl.violin(orig_adata, keys=violin_keys, groupby='celltype_broader', rotation=90, show=False)
plt.gcf().tight_layout()
if save:
    curr_path = os.path.join(storing_path, 'plots')
    os.makedirs(curr_path, exist_ok=True)
    plt.savefig(os.path.join(curr_path, 'violin_plot_approach_1.svg'), dpi=600)
plt.show()

In [None]:
# DGEX of every celltype_broader group against the cancer-EMT cells ('Epi with Mes').
sc.tl.rank_genes_groups(
    orig_adata,
    groupby='celltype_broader',
    reference='Epi with Mes',
    method='wilcoxon',
    tie_correct=True,
)

In [None]:
# Collect genes up-regulated in other groups relative to 'Epi with Mes'
# (see the reference in the previous cell); they are excluded from the signature.
# 'Epi wo Mes' uses no fold-change threshold, the other groups require log2FC >= 2.
# Groups tried but currently disabled: Fibroblasts, Bcells, Endothelial, Pericytes, FRC.
log2fc_min_per_group = {'Tcells': 2, 'Myeloid': 2, 'Epi wo Mes': 0}
dge_genes_to_avoid = set()
for celltype, min_lfc in log2fc_min_per_group.items():
    group_dge = sc.get.rank_genes_groups_df(orig_adata, group=celltype, log2fc_min=min_lfc, pval_cutoff=0.001)
    dge_genes_to_avoid |= set(group_dge.names.tolist())

In [None]:
new_dge_genes_cancer_emt_vs_caf_and_cancer = set(dge_genes_cancer_emt_vs_caf_and_cancer) - set(dge_genes_to_avoid)
len(new_dge_genes_cancer_emt_vs_caf_and_cancer)

In [None]:
# Score the reduced signature (after removing genes also up in other cell types).
score_signature(
    adata=orig_adata,
    gene_list=new_dge_genes_cancer_emt_vs_caf_and_cancer,
    method="adjusted_neighborhood_scoring",
    score_name='new_dge_genes_cancer_emt_vs_caf_and_cancer',
    ctrl_size=100,
)

In [None]:
violin_keys = [
    'hallmark_emt_scores',
    'dge_genes_cancer_emt_vs_caf_and_cancer',
    'new_dge_genes_cancer_emt_vs_caf_and_cancer',
    'dge_genes_and_hallmark',
]
sc.pl.violin(orig_adata, keys=violin_keys, groupby='celltype_broader', rotation=90, show=False)
plt.gcf().tight_layout()
if save:
    curr_path = os.path.join(storing_path, 'plots')
    os.makedirs(curr_path, exist_ok=True)
    plt.savefig(os.path.join(curr_path, 'violin_plot_approach_2.svg'), dpi=600)
plt.show()

In [None]:
# NOTE: a later cell plotting the same score writes to the same PNG path
# (with a title and dpi=600) and overwrites this file if run afterwards.
with rc_context({'figure.figsize': (15,8)}):
    for celltype, cells in orig_adata.obs.groupby(by='celltype_broader'):
        cells['new_dge_genes_cancer_emt_vs_caf_and_cancer'].hist(bins=100, alpha=0.5, density=True, label=celltype)
    plt.legend()
    if save:
        curr_path = os.path.join(storing_path, 'plots')
        os.makedirs(curr_path, exist_ok=True)
        plt.savefig(os.path.join(curr_path, 'dist_scores_new_dge_genes_cancer_emt_vs_caf_and_cancer.png'))

In [None]:
COL = 'new_dge_genes_cancer_emt_vs_caf_and_cancer'
quant = 0.95
with rc_context({'figure.figsize': (6,6)}):
    orig_adata.obs[COL].hist(bins=100)

    # Global cut-off: the `quant` quantile of the score over all cells.
    cutoff = orig_adata.obs[COL].quantile(quant)
    plt.axvline(cutoff, c='r', label=f'{quant} quantile')

    vals = get_nr_cells_over_quantile(orig_adata, score_col=COL, quant=quant, celltype_col='celltype_broader')
    title_lines = [f"Distribution scores for {COL} scores.\nCutoff {quant} quantile (={round(cutoff, 2)}) "]
    for celltype, (n_above, n_total) in vals.items():
        title_lines.append(f'{celltype}: {n_above}/{n_total} ({round(n_above/n_total*100, 3)}%)')
    title = '\n'.join(title_lines)
    plt.title(title)
    plt.tight_layout()
    if save:
        curr_path = os.path.join(storing_path, 'plots')
        os.makedirs(curr_path, exist_ok=True)
        plt.savefig(os.path.join(curr_path, 'dist_scores_new_dge_genes_cancer_emt_vs_caf_and_cancer_quant.png'), dpi=600)

#### Look at high-scoring cells and run differential gene expression (DGEX) on them

In [None]:
# Mask of all epithelial cells, with and without EMT.
epi_cells = orig_adata.obs['celltype_broader'].isin(['Epi with Mes', 'Epi wo Mes'])

In [None]:
# Fixed, hand-picked score threshold.
high_scores = orig_adata.obs['new_dge_genes_cancer_emt_vs_caf_and_cancer'].gt(0.2)

In [None]:
high_scoring_epi_mask = epi_cells & high_scores
epi_with_high_scores = orig_adata[high_scoring_epi_mask, :].copy()

In [None]:
# DGEX between EMT and non-EMT epithelial cells among the high scorers.
sc.tl.rank_genes_groups(epi_with_high_scores, method='wilcoxon', groupby='celltype_broader', tie_correct=True)

In [None]:
epi_with_high_scores.obs['celltype_broader'].value_counts()

In [None]:
fig = sc.pl.rank_genes_groups_dotplot(epi_with_high_scores, return_fig=True, n_genes=10)
if save:
    curr_path = os.path.join(storing_path, 'plots')
    os.makedirs(curr_path, exist_ok=True)
    fig.savefig(os.path.join(curr_path, 'top_marker_genes_high_scoring_mal_emt_vs_mal.png'), dpi=600)

In [None]:
dge_epi_wo_mes_with_high_scores = sc.get.rank_genes_groups_df(
    epi_with_high_scores, group='Epi wo Mes', pval_cutoff=1e-5, log2fc_min=1
)

In [None]:
dge_epi_with_mes_with_high_scores = sc.get.rank_genes_groups_df(
    epi_with_high_scores, group='Epi with Mes', pval_cutoff=1e-5, log2fc_min=1
)

In [None]:
# Ordering does not matter here because every gene list is turned into a set.
venn_gene_sets = {
    'dge_epi_wo_mes_with_high_scores': set(dge_epi_wo_mes_with_high_scores.names.to_list()),
    'dge_epi_with_mes_with_high_scores': set(dge_epi_with_mes_with_high_scores.names.to_list()),
    'new_dge_genes_cancer_emt_vs_caf_and_cancer': set(new_dge_genes_cancer_emt_vs_caf_and_cancer),
}
venn3(subsets=tuple(venn_gene_sets.values()), set_labels=tuple(venn_gene_sets.keys()))
if save:
    curr_path = os.path.join(storing_path, 'plots')
    os.makedirs(curr_path, exist_ok=True)
    plt.savefig(os.path.join(curr_path, 'venn_dgex_genes_mal_emt_vs_caf_mal_high_scores.png'), dpi=600)

In [None]:
# Genes up in high-scoring EMT cells that are not yet in the signature.
genes_to_potentially_add = set(dge_epi_with_mes_with_high_scores.names.to_list()) - set(new_dge_genes_cancer_emt_vs_caf_and_cancer)

In [None]:
# Convert X to CSC format, presumably to speed up the per-gene column slicing below.
orig_adata.X = orig_adata.X.asformat('csc')

In [None]:
def _to_flat_array(values):
    """Return ``values`` (scipy sparse, ``np.matrix`` or ndarray) as a 1-D numpy array."""
    # Sparse matrices/arrays expose ``toarray``; ``np.matrix`` and ndarrays do not.
    if hasattr(values, 'toarray'):
        values = values.toarray()
    return np.asarray(values).ravel()


def _closest_control_window(rolled_means, target_mean, ctrl_size):
    """Return the genes of the ``ctrl_size``-gene window whose mean is closest to ``target_mean``."""
    # The first ctrl_size - 1 rolling means are NaN (incomplete windows). nanargmin
    # skips them, so window_end >= ctrl_size - 1 and the slice start is never negative.
    window_end = int(np.nanargmin((rolled_means - target_mean).abs().to_numpy()))
    return list(rolled_means.index[window_end - ctrl_size + 1:window_end + 1])


def get_scores_for_all_sig_genes(adata, sig_genes,ctrl_size = 100):
    """Score every signature gene individually against its own control-gene set.

    Genes are sorted by mean expression. For each signature gene, the controls are
    the window of ``ctrl_size`` consecutive non-signature genes whose mean
    expression is closest to that gene's mean. The per-cell score is
    ``expr(gene) - mean(expr(controls))``, min-max scaled to [0, 1] per gene.

    Parameters
    ----------
    adata : AnnData
        Must have ``celltype``, ``celltype_broad`` and ``celltype_broader`` in ``obs``.
    sig_genes : iterable of str
        Signature genes, filtered through ``check_signature_genes``.
    ctrl_size : int
        Number of control genes per signature gene (also the rolling-window size).

    Returns
    -------
    pandas.DataFrame
        Indexed by ``adata.obs_names``. Contains one ``<gene>_score`` column per
        signature gene (columns sorted by name), ``final_score_mean`` and
        ``final_score_median`` (both computed over the per-gene columns only), and
        the three celltype annotation columns. A gene whose score is constant across
        cells gets an all-NaN column, which the mean/median skip.

    Raises
    ------
    ValueError
        If fewer than ``ctrl_size`` non-signature genes are available, or no
        signature gene is left to score.
    """
    gene_list = check_signature_genes(adata.var_names, sig_genes)
    df_mean_var = get_mean_and_variance_gene_expression(adata, estim_var=False)

    # Genes sorted by mean expression; candidate controls exclude the signature genes.
    sorted_gene_means = df_mean_var['mean'].sort_values()
    ref_genes_means = sorted_gene_means[~sorted_gene_means.index.isin(gene_list)]
    if len(ref_genes_means) < ctrl_size:
        raise ValueError(
            f'Need at least ctrl_size={ctrl_size} non-signature genes to pick controls, '
            f'got {len(ref_genes_means)}.'
        )
    # Mean expression of every window of ctrl_size consecutive reference genes,
    # labelled by the window's last gene.
    rolled = ref_genes_means.rolling(ctrl_size, closed='right').mean()

    per_gene_scores = []
    for sig_gene in gene_list:
        ctrl_genes = _closest_control_window(rolled, sorted_gene_means.loc[sig_gene], ctrl_size)
        # Flatten both terms. With a dense X, an (n, 1) column minus an (n,) mean
        # would otherwise broadcast to (n, n).
        raw = _to_flat_array(adata[:, sig_gene].X) - _to_flat_array(adata[:, ctrl_genes].X.mean(axis=1))
        raw_min, raw_max = raw.min(), raw.max()
        if raw_max > raw_min:
            scaled = (raw - raw_min) / (raw_max - raw_min)
        else:
            # Constant difference: min-max scaling is undefined (0/0 -> NaN).
            scaled = np.full(raw.shape, np.nan)
        per_gene_scores.append(pd.Series(scaled, index=adata.obs_names, name=sig_gene + '_score'))

    if not per_gene_scores:
        raise ValueError('No signature genes left to score after check_signature_genes.')

    df_signature_scores = pd.concat(per_gene_scores, axis=1).sort_index(axis=1)
    # Aggregate before adding any extra column. Computing the median after adding
    # 'final_score_mean' would fold the mean into the median.
    mean_scores = df_signature_scores.mean(axis=1)
    median_scores = df_signature_scores.median(axis=1)
    df_signature_scores['final_score_mean'] = mean_scores
    df_signature_scores['final_score_median'] = median_scores
    for annotation in ('celltype', 'celltype_broad', 'celltype_broader'):
        df_signature_scores[annotation] = adata.obs[annotation]
    return df_signature_scores

In [None]:
from pandas.api.types import is_numeric_dtype

def plot_heatmap_with_celltype_anno(df, label_col, palette="tab10", bbox_to_anchor=(0.16,0.79),cat_title = 'celltypes',apply_tanh=False,
                                    title='Scored for each gene in hallmark_emt signature.'):
    """Draw per-cell scores as a heatmap with a row-colour bar for ``label_col``.

    Rows are sorted by ``label_col`` and then by ``final_score_mean``. Only numeric
    columns other than ``label_col`` are drawn.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``label_col`` and ``final_score_mean`` (for example the output
        of ``get_scores_for_all_sig_genes``).
    label_col : str
        Column used for the row colours and the legend.
    palette : str or sequence
        Seaborn palette. One colour is requested per category; seaborn cycles the
        palette when it has fewer colours than categories.
    bbox_to_anchor : tuple
        Legend anchor in figure coordinates.
    cat_title : str
        Legend title.
    apply_tanh : bool
        If True, plot ``tanh`` of the values.
    title : str
        Heatmap title.

    Returns
    -------
    seaborn ClusterGrid
        The grid returned by ``sns.clustermap``.

    Raises
    ------
    KeyError
        If ``label_col`` is not a column of ``df``.
    """
    if label_col not in df:
        raise KeyError(f'labelcol={label_col} is not a column of df')

    # Use string labels everywhere so the colour lookup, the row colours and the
    # legend text share the same keys.
    labels = df[label_col].astype(str)
    categories_in_order = labels.unique()
    # Request exactly one colour per category. With the default palette length
    # (10 for tab10), zip would silently drop extra categories, leaving NaN
    # row colours and a KeyError in the legend.
    lut = dict(zip(categories_in_order, sns.color_palette(palette, n_colors=len(categories_in_order))))
    row_colors = labels.map(lut)

    counts = labels.value_counts()
    handles = [Patch(color=lut[category], label=f'{category} ({counts[category]})')
               for category in sorted(categories_in_order)]

    tmp = df.sort_values(by=[label_col, 'final_score_mean'])
    tmp = tmp[[x for x in tmp.columns if x != label_col and is_numeric_dtype(tmp[x])]]
    if apply_tanh:
        tmp = np.tanh(tmp)
    g = sns.clustermap(tmp,
                       row_colors=row_colors,
                       row_cluster=False,
                       col_cluster=False,
                       figsize=(50,30),
                       cmap="viridis",
                       cbar_pos=(0.1, .1, .03, .6))
    g.fig.legend(handles=handles, title=cat_title, bbox_to_anchor=bbox_to_anchor, loc='center right',
                 bbox_transform=g.fig.transFigure, borderaxespad=0., fontsize=18, title_fontsize=20, ncol=1)
    g.ax_heatmap.set_title(title, fontsize=22)
    return g

In [None]:
# Per-gene scores (one column per candidate gene) for all cells.
genes_to_potentially_add_scores = get_scores_for_all_sig_genes(orig_adata, sig_genes=list(genes_to_potentially_add))

In [None]:
# One-sided Mann-Whitney U tests per candidate gene: is the per-gene score of the
# cancer-EMT cells ('Epi' in celltype_broad) larger than that of fibroblasts, of the
# remaining cancer cells ('Epi non Mes'), and of all other cells ('rest')?
scores_df = genes_to_potentially_add_scores
celltype_broad_labels = scores_df['celltype_broad']
# The group masks do not depend on the gene, so compute them once.
is_caf = celltype_broad_labels == 'Fibroblasts'
is_epi_emt = celltype_broad_labels == 'Epi'
is_epi_non_emt = celltype_broad_labels == 'Epi non Mes'
is_rest = celltype_broad_labels == 'rest'

# Per-gene score columns only (skip 'final_*' aggregates and the annotation columns).
gene_score_cols = [
    col for col in scores_df.columns
    if is_numeric_dtype(scores_df[col]) and 'final' not in col
]

genes = []
pvals_cancer_emt_caf = []
pvals_cancer_emt_cancer = []
pvals_cancer_emt_rest = []
for col in tqdm(gene_score_cols):
    gene_scores = scores_df[col]
    epi_emt_scores = gene_scores[is_epi_emt]
    genes.append(col)
    pvals_cancer_emt_caf.append(mannwhitneyu(epi_emt_scores, gene_scores[is_caf], alternative='greater').pvalue)
    pvals_cancer_emt_cancer.append(mannwhitneyu(epi_emt_scores, gene_scores[is_epi_non_emt], alternative='greater').pvalue)
    pvals_cancer_emt_rest.append(mannwhitneyu(epi_emt_scores, gene_scores[is_rest], alternative='greater').pvalue)

In [None]:
# Benjamini-Hochberg FDR over all candidate genes: keep genes whose cancer-EMT
# score is significantly larger than the fibroblast score.
reject_vs_cafs = multipletests(pvals_cancer_emt_caf, alpha=1e-5, method='fdr_bh')[0]
# Column names are '<gene>_score'. Strip only the trailing suffix so gene symbols
# that themselves contain '_' stay intact.
gene_cancer_emt_sig_larger_cafs = [col.rsplit('_', 1)[0] for col, rejected in zip(genes, reject_vs_cafs) if rejected]

In [None]:
# Benjamini-Hochberg FDR: keep genes whose cancer-EMT score is significantly
# larger than the score of the remaining cancer cells.
reject_vs_cancer = multipletests(pvals_cancer_emt_cancer, alpha=1e-5, method='fdr_bh')[0]
# Column names are '<gene>_score'; strip only the trailing suffix.
gene_cancer_emt_sig_larger_cancer = [col.rsplit('_', 1)[0] for col, rejected in zip(genes, reject_vs_cancer) if rejected]

In [None]:
# Benjamini-Hochberg FDR: keep genes whose cancer-EMT score is significantly
# larger than the score of all other ('rest') cells.
reject_vs_rest = multipletests(pvals_cancer_emt_rest, alpha=1e-5, method='fdr_bh')[0]
# Column names are '<gene>_score'; strip only the trailing suffix.
gene_cancer_emt_sig_larger_rest = [col.rsplit('_', 1)[0] for col, rejected in zip(genes, reject_vs_rest) if rejected]

In [None]:
venn_gene_sets = {
    'gene_cancer_emt_sig_larger_cafs': set(gene_cancer_emt_sig_larger_cafs),
    'gene_cancer_emt_sig_larger_cancer': set(gene_cancer_emt_sig_larger_cancer),
    'gene_cancer_emt_sig_larger_rest': set(gene_cancer_emt_sig_larger_rest),
}
venn3(subsets=tuple(venn_gene_sets.values()), set_labels=tuple(venn_gene_sets.keys()))
if save:
    curr_path = os.path.join(storing_path, 'plots')
    os.makedirs(curr_path, exist_ok=True)
    plt.savefig(os.path.join(curr_path, 'venn_dgex_genes_mal_emt_vs_caf_mal_to_add.png'), dpi=600)

In [None]:
# Genes significant against all three comparisons, added to the reduced DGE signature.
significant_in_all = (
    set(gene_cancer_emt_sig_larger_cafs)
    & set(gene_cancer_emt_sig_larger_cancer)
    & set(gene_cancer_emt_sig_larger_rest)
)
refined_sig = significant_in_all | set(new_dge_genes_cancer_emt_vs_caf_and_cancer)
    

In [None]:
# Disabled alternative: like `refined_sig` in the cell above, but without the
# intersection with `gene_cancer_emt_sig_larger_rest`.
# refined_sig = set(gene_cancer_emt_sig_larger_cafs).intersection(set(gene_cancer_emt_sig_larger_cancer)).union(set(new_dge_genes_cancer_emt_vs_caf_and_cancer))

In [None]:
# Score the refined signature into orig_adata.obs['refined_sig'].
score_signature(
    adata=orig_adata,
    gene_list=refined_sig,
    method="adjusted_neighborhood_scoring",
    score_name='refined_sig',
    ctrl_size=100,
)

In [None]:
violin_keys = [
    'hallmark_emt_scores',
    'dge_genes_cancer_emt_vs_caf_and_cancer',
    'new_dge_genes_cancer_emt_vs_caf_and_cancer',
    'refined_sig',
]
sc.pl.violin(orig_adata, keys=violin_keys, groupby='celltype_broader', rotation=90, show=False)
plt.gcf().tight_layout()
if save:
    curr_path = os.path.join(storing_path, 'plots')
    os.makedirs(curr_path, exist_ok=True)
    plt.savefig(os.path.join(curr_path, 'violin_plot_approach_3.svg'), dpi=600)
plt.show()

In [None]:
with rc_context({'figure.figsize': (10,6)}):
    for celltype, cells in orig_adata.obs.groupby(by='celltype_broader'):
        cells['new_dge_genes_cancer_emt_vs_caf_and_cancer'].hist(bins=100, alpha=0.5, density=True, label=celltype)
    plt.legend()
    plt.title('new_dge_genes_cancer_emt_vs_caf_and_cancer')
    if save:
        curr_path = os.path.join(storing_path, 'plots')
        os.makedirs(curr_path, exist_ok=True)
        plt.savefig(os.path.join(curr_path, 'dist_scores_new_dge_genes_cancer_emt_vs_caf_and_cancer.png'), dpi=600)

In [None]:
with rc_context({'figure.figsize': (10,6)}):
    for celltype, cells in orig_adata.obs.groupby(by='celltype_broader'):
        cells['refined_sig'].hist(bins=100, alpha=0.5, density=True, label=celltype)
    plt.legend()
    plt.title('refined_sig')
    if save:
        curr_path = os.path.join(storing_path, 'plots')
        os.makedirs(curr_path, exist_ok=True)
        plt.savefig(os.path.join(curr_path, 'dist_scores_refined_sig.png'), dpi=600)

In [None]:
COL = 'new_dge_genes_cancer_emt_vs_caf_and_cancer'
quant = 0.95
with rc_context({'figure.figsize': (6,6)}):
    orig_adata.obs[COL].hist(bins=100)

    # Global cut-off: the `quant` quantile of the score over all cells.
    cutoff = orig_adata.obs[COL].quantile(quant)
    plt.axvline(cutoff, c='r', label=f'{quant} quantile')

    vals = get_nr_cells_over_quantile(orig_adata, score_col=COL, quant=quant, celltype_col='celltype_broader')
    title_lines = [f"Distribution scores for {COL} scores.\nCutoff {quant} quantile (={round(cutoff, 2)}) "]
    for celltype, (n_above, n_total) in vals.items():
        title_lines.append(f'{celltype}: {n_above}/{n_total} ({round(n_above/n_total*100, 3)}%)')
    title = '\n'.join(title_lines)
    plt.title(title)
    plt.tight_layout()
    if save:
        curr_path = os.path.join(storing_path, 'plots')
        os.makedirs(curr_path, exist_ok=True)
        plt.savefig(os.path.join(curr_path, 'dist_scores_new_dge_genes_cancer_emt_vs_caf_and_cancer_quant.png'), dpi=600)

plt.show()

COL = 'refined_sig'
quant = 0.95
with rc_context({'figure.figsize': (6,6)}):
    orig_adata.obs[COL].hist(bins=100)

    # Global cut-off: the `quant` quantile of the score over all cells.
    cutoff = orig_adata.obs[COL].quantile(quant)
    plt.axvline(cutoff, c='r', label=f'{quant} quantile')

    vals = get_nr_cells_over_quantile(orig_adata, score_col=COL, quant=quant, celltype_col='celltype_broader')
    title_lines = [f"Distribution scores for {COL} scores.\nCutoff {quant} quantile (={round(cutoff, 2)}) "]
    for celltype, (n_above, n_total) in vals.items():
        title_lines.append(f'{celltype}: {n_above}/{n_total} ({round(n_above/n_total*100, 3)}%)')
    title = '\n'.join(title_lines)
    plt.title(title)
    plt.tight_layout()
    if save:
        curr_path = os.path.join(storing_path, 'plots')
        os.makedirs(curr_path, exist_ok=True)
        plt.savefig(os.path.join(curr_path, 'dist_scores_refined_sig_quant.png'), dpi=600)

In [None]:
# 96th-percentile cut-off of the refined signature score.
quant_refined_sig = orig_adata.obs['refined_sig'].quantile(q=0.96)

In [None]:
above_refined_cutoff = orig_adata.obs['refined_sig'] > quant_refined_sig
orig_adata.obs.loc[above_refined_cutoff, 'celltype_broader'].value_counts()

In [None]:
# 96th-percentile cut-off of the reduced DGE signature score.
quant_some_new_dge = orig_adata.obs['new_dge_genes_cancer_emt_vs_caf_and_cancer'].quantile(q=0.96)

In [None]:
above_new_dge_cutoff = orig_adata.obs['new_dge_genes_cancer_emt_vs_caf_and_cancer'] > quant_some_new_dge
orig_adata.obs.loc[above_new_dge_cutoff, 'celltype_broader'].value_counts()

In [None]:
umap_keys = [
    'hallmark_emt_scores',
    'dge_genes_cancer_emt_vs_caf_and_cancer',
    'new_dge_genes_cancer_emt_vs_caf_and_cancer',
    'refined_sig',
    'celltype_broader',
]
with rc_context({'figure.figsize': (8,8)}):
    umap_celltypes = sc.pl.umap(orig_adata, color=umap_keys, ncols=3, color_map='viridis', return_fig=True)
    if save:
        curr_path = os.path.join(storing_path, 'plots')
        os.makedirs(curr_path, exist_ok=True)
        umap_celltypes.savefig(os.path.join(curr_path, 'umap_celltypes.png'), dpi=600)

### Store signatures 

In [None]:
# Sizes of the two final signatures.
(len(new_dge_genes_cancer_emt_vs_caf_and_cancer), len(refined_sig))

In [None]:
if save:
    # Write each signature as a one-column CSV (pandas' default layout: index plus
    # a column named '0'). Sort the genes so the file content is reproducible:
    # iterating a set of strings gives an arbitrary, run-dependent order.
    signatures_to_store = {
        'ESOPHAG_CANCER_EMT_SIGNATURE_1.csv': new_dge_genes_cancer_emt_vs_caf_and_cancer,
        'ESOPHAG_CANCER_EMT_SIGNATURE_2.csv': refined_sig,
    }
    for file_name, signature_genes in signatures_to_store.items():
        pd.Series(sorted(signature_genes)).to_csv(os.path.join(storing_path, file_name))