## Get cancer EMT cells from ESCC
The goal of this Jupyter notebook is to get all malignant cells from ESCC that are in EMT, i.e., malignant (epithelial) cells that also express mesenchymal markers.

We test two approaches:
1. Score the entire dataset for the mesenchymal signature and then analyze the scores for the malignant cells
2. Score only the malignant cells and analyze their scores

In [None]:
# Widen the notebook container to 90% of the browser window.
from IPython.display import HTML, display

_wide_container_css = "<style>.container { width:90% !important; }</style>"
display(HTML(_wide_container_css))

In [None]:
import os 
import sys 


import pandas as pd
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
from matplotlib_venn import venn2

import scipy

sys.path.append('../../..')
from data.load_data import load_datasets
from data.constants import BASE_PATH_DATA, BASE_PATH_EXPERIMENTS

from signaturescoring import score_signature

In [None]:
# Global plot style: Type 42 (TrueType) fonts in PDF output, Arial, 14 pt.
_plot_style = {
    'pdf.fonttype': 42,
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'font.size': 14,
}
plt.rcParams.update(_plot_style)

In [None]:
# Master switch: when True, the cells below write figures and tables to disk.
save: bool = True

In [None]:
# Root output folder for every figure/table this notebook writes.
_case_study_subdir = 'EMT_signature_scoring_case_study/escc'
storing_path = os.path.join(BASE_PATH_EXPERIMENTS, _case_study_subdir)

In [None]:
# Load the preprocessed ESCC dataset (mean normalization) and force
# uns['log1p']['base'] to None, creating the entry if it is missing.
# NOTE(review): base=None presumably marks a natural-log log1p transform for
# scanpy's downstream functions — confirm against the preprocessing step.
adata = load_datasets('escc', preprocessed=True, norm_method='mean')
if 'log1p' not in adata.uns_keys():
    adata.uns['log1p'] = {'base': None}
else:
    adata.uns['log1p']['base'] = None

In [None]:
# PCA -> kNN graph -> UMAP embedding, each step computed in place on `adata`.
for _embedding_step in (sc.tl.pca, sc.pp.neighbors, sc.tl.umap):
    _embedding_step(adata)

In [None]:
# Mesenchymal gene signature: the 'Mes' column of the ESCC annotation CSV.
_mes_sig_file = os.path.join(BASE_PATH_DATA, 'annotations/escc/genesig_Mes.csv')
mes_sig = pd.read_csv(_mes_sig_file)['Mes'].tolist()

### 1. Score on the entire dataset and get cancer EMT barcodes

In [None]:
# Approach 1: score the Mes signature on every cell of the dataset; later
# cells read the result from obs['mes_sig'].
score_signature(
    adata=adata,
    method="adjusted_neighborhood_scoring",
    score_name='mes_sig',
    gene_list=mes_sig,
    ctrl_size=100,
)

In [None]:
# Standalone copy (not a view) of the malignant cells; it keeps the
# approach-1 `mes_sig` column and is scored again on its own below.
_is_malignant = adata.obs['malignant_key'] == 'malignant'
mal_cells = adata[_is_malignant].copy()

In [None]:
plt.rcParams.update({'pdf.fonttype': 42, 'font.family': 'sans-serif', 'font.sans-serif': 'Arial', 'font.size': 14})
quantile_cutoff = 0.9

# Approach-1 scores of the malignant cells and their quantile threshold.
_scores_all = mal_cells.obs.mes_sig
_thresh_all = _scores_all.quantile(quantile_cutoff)
# Number of malignant cells at or above the threshold (shown in the title).
nr_above_thresh = int((_scores_all >= _thresh_all).sum())

_scores_all.hist(bins=100)
plt.axvline(_thresh_all, label=f'{quantile_cutoff} quantile', c='r')
plt.title(f'Distribution Mes scores (on all data),\n{nr_above_thresh} cells above {quantile_cutoff} quantile')
plt.legend()
if save:
    curr_path = os.path.join(storing_path, 'cancer_emt_barcode_selection')
    os.makedirs(curr_path, exist_ok=True)
    plt.savefig(os.path.join(curr_path, 'dist_scores_on_all_data_w_quant.png'))

In [None]:
# Approach 1: barcodes of malignant cells whose Mes score (scored on all
# cells) is at or above the `quantile_cutoff` quantile.
_mask_all = mal_cells.obs.mes_sig >= mal_cells.obs.mes_sig.quantile(quantile_cutoff)
mal_sig_above_quant_1 = mal_cells.obs.index[_mask_all.to_numpy()].to_list()

### 2. Score on cancer cells only and get cancer EMT barcodes

In [None]:
# Approach 2: score the Mes signature using only the malignant cells; later
# cells read the result from mal_cells.obs['mes_sig_mal_cells'].
score_signature(
    adata=mal_cells,
    method="adjusted_neighborhood_scoring",
    score_name='mes_sig_mal_cells',
    gene_list=mes_sig,
    ctrl_size=100,
)

In [None]:
plt.rcParams.update({'pdf.fonttype': 42, 'font.family': 'sans-serif', 'font.sans-serif': 'Arial', 'font.size': 14})
quantile_cutoff = 0.9

# Approach-2 scores of the malignant cells and their quantile threshold.
_scores_mal = mal_cells.obs.mes_sig_mal_cells
_thresh_mal = _scores_mal.quantile(quantile_cutoff)
# Number of malignant cells at or above the threshold (shown in the title).
nr_above_thresh = int((_scores_mal >= _thresh_mal).sum())

_scores_mal.hist(bins=100)
plt.axvline(_thresh_mal, label=f'{quantile_cutoff} quantile', c='r')
plt.title(f'Distribution Mes scores (only on malignant cells),\n{nr_above_thresh} cells above {quantile_cutoff} quantile')
plt.legend()
if save:
    curr_path = os.path.join(storing_path, 'cancer_emt_barcode_selection')
    os.makedirs(curr_path, exist_ok=True)
    plt.savefig(os.path.join(curr_path, 'dist_scores_on_malig_cells_w_quant.png'))

In [None]:
# Approach 2: barcodes of malignant cells whose Mes score (scored on the
# malignant cells only) is at or above the `quantile_cutoff` quantile.
_mask_mal = mal_cells.obs.mes_sig_mal_cells >= mal_cells.obs.mes_sig_mal_cells.quantile(quantile_cutoff)
mal_sig_above_quant_2 = mal_cells.obs.index[_mask_mal.to_numpy()].to_list()

### Compare barcodes 

In [None]:
# Per-cell absolute difference between the two scoring approaches.
mal_cells.obs['diff'] = (mal_cells.obs['mes_sig_mal_cells'] - mal_cells.obs['mes_sig']).abs()

In [None]:
# Summary statistics of both score columns and their absolute difference.
_cmp_cols = ['mes_sig', 'mes_sig_mal_cells', 'diff']
mal_cells.obs[_cmp_cols].describe()

In [None]:
# One histogram panel per column: both scores and their absolute difference.
mal_cells.obs.loc[:, ['mes_sig', 'mes_sig_mal_cells', 'diff']].hist()

In [None]:
plt.rcParams.update({'pdf.fonttype': 42, 'font.family': 'sans-serif', 'font.sans-serif': 'Arial', 'font.size': 14})
# Overlay the malignant-cell score distributions of the two approaches.
for _col, _lbl in (('mes_sig', 'mes_sig scores'), ('mes_sig_mal_cells', 'mes_sig_mal_cells')):
    mal_cells.obs[_col].hist(bins=100, alpha=0.5, label=_lbl)
plt.legend()
plt.title('Scoring mesenchymal signature\non all the data vs.only on malignant cells.')
if save:
    curr_path = os.path.join(storing_path, 'cancer_emt_barcode_selection')
    os.makedirs(curr_path, exist_ok=True)
    plt.savefig(os.path.join(curr_path, 'diff_dist_scores_on_all_vs_on_malig_cells.png'))

In [None]:

# Two-sample Kolmogorov-Smirnov test comparing the malignant-cell score
# distributions of approach 1 (mes_sig) and approach 2 (mes_sig_mal_cells).
# alternative='greater' makes the test one-sided; see the SciPy ks_2samp docs
# for which CDF ordering it tests. The deprecated `mode=` argument is dropped:
# its value 'auto' is already the default of the renamed `method` parameter.
import scipy.stats  # `import scipy` alone does not guarantee scipy.stats is loaded

scipy.stats.ks_2samp(mal_cells.obs['mes_sig'], mal_cells.obs['mes_sig_mal_cells'], alternative='greater')

In [None]:
# Venn diagram of the two barcode selections (approach 1 vs approach 2).
# The labels are derived from `quantile_cutoff` so they stay correct if the
# cutoff is changed in the cells above.
venn2(
    subsets=(set(mal_sig_above_quant_1), set(mal_sig_above_quant_2)),
    set_labels=(
        f'mal cells with\nscore above {quantile_cutoff} quantile\non all data',
        f'mal cells with\nscore above {quantile_cutoff} quantile\non mal cells',
    ),
)
if save:
    curr_path = os.path.join(storing_path, 'cancer_emt_barcode_selection')
    os.makedirs(curr_path, exist_ok=True)
    plt.savefig(os.path.join(curr_path, 'venn_barcodes_selection.png'))

In [None]:
# Barcodes selected only when scoring on all data (approach 1 minus approach 2).
set_1 = set(mal_sig_above_quant_1) - set(mal_sig_above_quant_2)

In [None]:
# Barcodes selected only when scoring on malignant cells (approach 2 minus approach 1).
set_2 = set(mal_sig_above_quant_2) - set(mal_sig_above_quant_1)

In [None]:
# Barcodes selected by both approaches.
set_3 = set(mal_sig_above_quant_2) & set(mal_sig_above_quant_1)

In [None]:
# Density histograms of the approach-1 scores (mes_sig) for the barcodes
# picked by only one of the two approaches.
for _barcodes, _lbl in ((set_1, 'barcodes from all data'), (set_2, 'barcodes from cancer data')):
    mal_cells[list(_barcodes), :].obs.mes_sig.hist(bins=50, alpha=0.5, density=True, label=_lbl)
# The intersection (set_3) is currently not plotted:
# mal_cells[list(set_3),:].obs.mes_sig.hist(bins=50, alpha=0.5, density=True, label='barcodes intersection')
plt.legend()
plt.title('distribution scores of scoring on all\nthe data for nonoverlapping barcodes.')

In [None]:
# Density histograms of the approach-2 scores (mes_sig_mal_cells) for the
# barcodes picked by only one of the two approaches.
mal_cells[list(set_1), :].obs.mes_sig_mal_cells.hist(bins=50, alpha=0.5, density=True, label='barcodes from all data')
mal_cells[list(set_2), :].obs.mes_sig_mal_cells.hist(bins=50, alpha=0.5, density=True, label='barcodes from cancer data')
plt.legend()
# Title fixed: the previous text read "...on malignant cells\nthe data for...".
plt.title('Distribution of scores from scoring only on malignant cells\nfor non-overlapping barcodes.')

### Take the barcodes in the intersection of the two sets of barcodes

In [None]:
# Final cancer-EMT barcodes: malignant cells above the quantile cutoff under
# BOTH scoring approaches. The list is sorted so its order (and the CSV saved
# from it) is reproducible. Iterating a set of strings follows hash
# randomization, which can change between interpreter sessions.
mal_cells_barcodes = sorted(set(mal_sig_above_quant_2) & set(mal_sig_above_quant_1))

In [None]:
# Persist the selected cancer-EMT barcodes, one per row, when saving is enabled.
if save:
    _barcodes_csv = os.path.join(storing_path, 'barcodes_cancer_emt.csv')
    pd.Series(mal_cells_barcodes).to_csv(_barcodes_csv)

### Get dataset statistics

In [None]:
# EMT-analysis cell-type label: a string copy of `celltype` in which 'Epi'
# is renamed to 'Malignant'; the cell displays the per-label counts.
adata.obs['celltype_emt'] = adata.obs['celltype'].astype(str).replace('Epi', 'Malignant')
adata.obs['celltype_emt'].value_counts().sort_index()

In [None]:
# Relabel the selected cancer-EMT barcodes, make the column categorical and
# display the resulting per-label counts.
_celltype_emt = adata.obs['celltype_emt'].copy()
_celltype_emt.loc[mal_cells_barcodes] = 'Malignant with EMT'
adata.obs['celltype_emt'] = _celltype_emt.astype('category')
adata.obs['celltype_emt'].value_counts().sort_index()

In [None]:
# Display the available per-cell annotation columns.
adata.obs.keys()

In [None]:
plt.rcParams.update({'pdf.fonttype': 42, 'font.family': 'sans-serif', 'font.sans-serif': 'Arial', 'font.size': 14})

# UMAP of the 'Malignant' and 'Fibroblasts' cells, coloured by sample, cell
# type, Mes score (scored on all data) and the EMT-aware cell-type label.
# NOTE(review): this filters on `celltype`, while `celltype_emt` was built by
# relabelling 'Epi' as 'Malignant'. If malignant cells are stored as 'Epi' in
# `celltype`, they are excluded here — confirm the intended filter.
with plt.rc_context({'figure.figsize': (10, 8)}):
    umap_fig = sc.pl.umap(
        adata[adata.obs.celltype.isin(['Malignant', 'Fibroblasts'])],
        color=['sample_id', 'celltype', 'mes_sig', 'celltype_emt'],
        return_fig=True,
        cmap='viridis',
    )
    if save:
        curr_path = os.path.join(storing_path, 'cancer_emt_barcode_selection')
        # makedirs (not mkdir), consistent with the other cells; it also
        # works when storing_path itself does not exist yet.
        os.makedirs(curr_path, exist_ok=True)
        umap_fig.savefig(os.path.join(curr_path, 'mal_n_caf_cells_umap_emt_sigs.png'), dpi=600)

In [None]:
# Sample id and EMT-aware cell-type label for every cell (input to the crosstabs).
tmp = adata.obs.loc[:, ['sample_id', 'celltype_emt']]

In [None]:
# Per-sample cell-type proportions (each row sums to 1).
cross_tab_prop = pd.crosstab(tmp['sample_id'], tmp['celltype_emt'], normalize='index')

In [None]:
# Fixed column order: EMT cells first, then the remaining cell types.
_celltype_order = ['Malignant with EMT', 'Malignant', 'Fibroblasts', 'Bcells', 'Endothelial', 'FRC',
                   'Myeloid', 'Pericytes', 'Tcells']
cross_tab_prop = cross_tab_prop.loc[:, _celltype_order]
# Display as percentages rounded to two decimals.
(cross_tab_prop * 100).round(2)

In [None]:
plt.rcParams.update({'pdf.fonttype': 42, 'font.family': 'sans-serif', 'font.sans-serif': 'Arial', 'font.size': 14})

# Stacked bar chart of the per-sample cell-type proportions.
bar_ax = cross_tab_prop.plot(kind='bar',
                             stacked=True,
                             colormap='tab20',
                             figsize=(15, 8))
bar_ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.5), ncol=1)
bar_ax.set_xlabel("Sample ID")
bar_ax.set_ylabel("Proportions")
bar_fig = bar_ax.get_figure()

# Save BEFORE plt.show(). Once the figure has been shown (Jupyter inline
# backend, or a blocking show), it is closed, so a later plt.savefig would
# write a new, empty figure.
if save:
    curr_path = os.path.join(storing_path, 'proportions')
    os.makedirs(curr_path, exist_ok=True)
    # bbox_inches='tight' keeps the legend, which sits outside the axes, in the image.
    bar_fig.savefig(os.path.join(curr_path, 'distribution_celltypes.png'), dpi=600, bbox_inches='tight')
    cross_tab_prop.to_csv(os.path.join(curr_path, 'proportions_celltype.csv'))
plt.show()

In [None]:
# Absolute per-sample cell-type counts, in the same column order as the
# proportions table, optionally saved to CSV and then displayed.
prop_counts = pd.crosstab(index=tmp['sample_id'],
                          columns=tmp['celltype_emt'])
prop_counts = prop_counts[['Malignant with EMT', 'Malignant', 'Fibroblasts', 'Bcells', 'Endothelial', 'FRC',
                           'Myeloid', 'Pericytes', 'Tcells']]
if save:
    curr_path = os.path.join(storing_path, 'proportions')
    # makedirs (not mkdir), consistent with the other cells; it also works
    # when storing_path itself does not exist yet.
    os.makedirs(curr_path, exist_ok=True)
    prop_counts.to_csv(os.path.join(curr_path, 'counts_celltype.csv'))
prop_counts

### Take barcodes from fibroblasts and concatenate them with those from cancer EMT cells

In [None]:
# Fibroblast (CAF) barcodes — currently disabled.
# caf_barcodes = adata.obs[adata.obs.celltype=='Fibroblasts'].index.to_list()

In [None]:
# Combined cancer-EMT + CAF barcode list — currently disabled (needs caf_barcodes).
#  mal_cells_barcodes + caf_barcodes