# Malignant cells signature creation for ESCC, CRC and LUAD
The following notebook creates the signatures for malignant cells, i.e., the genes that distinguish malignant from non-malignant cells. We consider two approaches:
1. DGEX on all preprocessed samples together, with thresholds log2FC>2 and adjusted p-val<0.01
2. DGEX on each sample individually, with thresholds log2FC>{1, 2} and adjusted p-val<{0.05, 0.01, 0.005}, selecting the genes that appear in at least X% of the samples (X in {75, 80, 85, 90, 92.5, 95, 97.5, 99, 100}%).
We make this distinction because all datasets are imbalanced in their sample contributions (i.e., the number of cells each sample contributes to the dataset). Selecting the signature as described in approach 2 ensures a large enough overlap of signature genes across the samples.

In [None]:
from IPython.display import display, HTML
# Inject a CSS rule that sets the width of `.container` elements to 95%,
# so that wide tables and figures fit into the notebook view.
display(HTML("<style>.container { width:95% !important; }</style>"))

In [None]:
import os
import sys
import scanpy as sc
import seaborn as sns
import matplotlib.pyplot as plt
from collections import Counter
import numpy as np
import pandas as pd
import itertools

sys.path.append('..')
from load_data import load_datasets
from constants import BASE_PATH_DGEX_CANCER

In [None]:
def get_dgex_genes(dataset, adata, logfc_min=2, pval_max=0.01, show_plot=True):
    """Return the genes that are up-regulated in the 'malignant' group.

    A tie-corrected Wilcoxon rank-sum test is run on a copy of ``adata``
    (grouping by ``adata.obs['malignant_key']``), so the passed object is
    not modified. The full result table of the 'malignant' group is fetched
    first and only filtered afterwards.

    Parameters
    ----------
    dataset : str
        Dataset label; only used (upper-cased) in the plot title.
    adata : AnnData
        Data containing an ``obs`` column ``malignant_key`` with a
        'malignant' group.
    logfc_min : float
        Genes need ``logfoldchanges > logfc_min`` to be selected.
    pval_max : float
        Genes need ``pvals_adj < pval_max`` to be selected.
    show_plot : bool
        If True, show a scatter plot of test scores vs. log fold changes
        with a dotted line at the log fold change threshold.

    Returns
    -------
    pandas.DataFrame
        Rows of the rank-genes table passing both thresholds, sorted by
        descending ``logfoldchanges`` with a fresh integer index.
    """
    curr_adata = adata.copy()
    sc.tl.rank_genes_groups(
        curr_adata,
        'malignant_key',
        method='wilcoxon',
        key_added='wilcoxon',
        tie_correct=True,
    )
    # Get all the genes and only select genes afterwards.
    wc = sc.get.rank_genes_groups_df(curr_adata, group='malignant', key='wilcoxon')

    if show_plot:
        plt.figure(figsize=(8, 8))
        g = sns.scatterplot(wc, x='scores', y='logfoldchanges')
        g.set_title(f'{dataset.upper()} DGEX scores vs log2FC.')
        # Draw the guide line at the threshold that is actually applied
        # (it used to be fixed at 2 regardless of `logfc_min`).
        g.axhline(y=logfc_min, c='r', ls=':')
        plt.show()

    passes_thresholds = (wc.logfoldchanges > logfc_min) & (wc.pvals_adj < pval_max)
    gex_genes = wc[passes_thresholds]
    gex_genes = gex_genes.sort_values(by='logfoldchanges', ascending=False).reset_index(drop=True)
    return gex_genes


def get_per_sample_dgex_genes(adata, dataset, logfc_min=1, pval_max=0.05, col_sid='sample_id'):
    """Run the malignant vs. non-malignant DGEX separately for every sample.

    For each value of ``adata.obs[col_sid]`` the cells of that sample are
    split into the 'malignant' and 'non-malignant' groups of
    ``malignant_key``. Within each group, genes detected in no cell are
    removed, and the two groups are re-joined on the genes that remain in
    both (inner join). ``get_dgex_genes`` is then applied to the re-joined
    sample without plotting.

    Parameters
    ----------
    adata : AnnData
        Data with the ``obs`` columns ``col_sid`` and ``malignant_key``.
    dataset : str
        Dataset label, forwarded to ``get_dgex_genes``.
    logfc_min : float
        Log fold change threshold forwarded to ``get_dgex_genes``.
    pval_max : float
        Adjusted p-value threshold forwarded to ``get_dgex_genes``.
    col_sid : str
        Name of the ``obs`` column holding the sample id.

    Returns
    -------
    list of pandas.DataFrame
        One frame per sample, indexed by gene name (``names``) with the
        single column ``logfoldchanges``.
    """
    list_dges = []
    # observed=True: for a categorical sample-id column only the ids that
    # actually occur are visited (an unused category would give an empty
    # sample) and pandas' FutureWarning about the changing default is avoided.
    for _, sample_obs in adata.obs.groupby(col_sid, observed=True):
        sample_adata = adata[sample_obs.index, ].copy()
        malignant = sample_adata[sample_adata.obs.malignant_key == 'malignant'].copy()
        non_malignant = sample_adata[sample_adata.obs.malignant_key == 'non-malignant'].copy()

        # Drop genes that are not detected in any cell of the respective group.
        sc.pp.filter_genes(malignant, min_cells=1)
        sc.pp.filter_genes(non_malignant, min_cells=1)

        # The inner join keeps only the genes that survived in both groups.
        curr_adata = sc.concat([malignant, non_malignant], join='inner', merge='same')

        # Each sample is processed right away instead of first collecting
        # all per-sample copies, which keeps only one copy alive at a time.
        curr_genes = get_dgex_genes(dataset, curr_adata, logfc_min, pval_max, show_plot=False)
        curr_genes = curr_genes[['names', 'logfoldchanges']].copy().set_index('names')
        list_dges.append(curr_genes)

    return list_dges


def get_genes_dgex_genes_in_pct_samples(list_dges, pct=0.90):
    """Keep the genes that were selected in at least ``pct`` of the samples.

    Parameters
    ----------
    list_dges : list of pandas.DataFrame
        One frame per sample, indexed by gene, holding the log fold change
        of the genes selected in that sample.
    pct : float
        Minimum fraction of samples (0..1) in which a gene must occur.

    Returns
    -------
    pandas.DataFrame
        The former gene index as first column, followed by ``mean_log2FC``
        and ``median_log2FC``, computed over the samples in which the gene
        occurs (missing values are skipped).
    """
    n_samples = len(list_dges)

    # genes x samples table; NaN marks "gene not selected in that sample"
    logfc_table = pd.concat(list_dges, axis=1, join='outer')

    fraction_of_samples = logfc_table.notna().sum(axis=1) / n_samples
    kept = logfc_table.loc[fraction_of_samples >= pct]

    mean_logfc = kept.mean(axis=1, skipna=True)
    median_logfc = kept.median(axis=1, skipna=True)
    summary = pd.concat([mean_logfc, median_logfc], axis=1)
    summary.columns = ['mean_log2FC', 'median_log2FC']
    return summary.reset_index()

### Global variables 

In [None]:
## Shift-logarithm normalization variant of the datasets to load.
norm_method='mean' # mean, median, CP10k

## Minimum log2FC and maximum adjusted p-value thresholds.
## NOTE: despite their names, the `min_pval_*` variables hold the *maximum*
## adjusted p-value (they are passed as `pval_max` to the DGEX helpers).
## > DGEX on all samples
min_logfc_onall = 2 
min_pval_onall = 0.01
## > DGEX on each sample individually
min_logfc_sep = 1 # 1, 2
min_pval_sep = 0.005 # 0.05, 0.01, 0.005

## Fractions of samples in which a gene has to be found to enter the per-sample signature.
pctgs = [0.75, 0.8, 0.85, 0.9, 0.925,0.95, 0.975, 0.99, 1]

In [None]:
base_storing_path = BASE_PATH_DGEX_CANCER

# Path fragments shared by the three cancer types.
_norm_dir = f'{norm_method}_norm'
_onall_dir = f'min_log2fc_{min_logfc_onall}_pval_{min_pval_onall}'
_per_sid_dir = f'min_log2fc_{min_logfc_sep}_pval_{min_pval_sep}'

# `*_all`: CSV file with the signature from DGEX on all samples together.
# `*_per_merged`: directory for the signatures from DGEX on each sample.
crc_storing_path_all = os.path.join(
    base_storing_path, 'crc', _norm_dir, 'dgex_on_all_sid', _onall_dir, 'dgex_genes.csv')
crc_storing_path_per_merged = os.path.join(
    base_storing_path, 'crc', _norm_dir, 'dgex_on_each_sid', _per_sid_dir)

escc_storing_path_all = os.path.join(
    base_storing_path, 'escc', _norm_dir, 'dgex_on_all_sid', _onall_dir, 'dgex_genes.csv')
escc_storing_path_per_merged = os.path.join(
    base_storing_path, 'escc', _norm_dir, 'dgex_on_each_sid', _per_sid_dir)

luad_storing_path_all = os.path.join(
    base_storing_path, 'luad', _norm_dir, 'dgex_on_all_sid', _onall_dir, 'dgex_genes.csv')
luad_storing_path_per_merged = os.path.join(
    base_storing_path, 'luad', _norm_dir, 'dgex_on_each_sid', _per_sid_dir)

In [None]:
# Create the missing output directories: for every cancer type first the
# directory of the all-samples CSV, then the per-sample directory. The last
# tuple entry is the text appended to the second message of that cancer type.
for _all_csv, _per_sid_dir_path, _msg_end in (
    (crc_storing_path_all, crc_storing_path_per_merged, '\n'),
    (escc_storing_path_all, escc_storing_path_per_merged, '\n'),
    (luad_storing_path_all, luad_storing_path_per_merged, ''),
):
    for _directory, _suffix in ((os.path.dirname(_all_csv), ''), (_per_sid_dir_path, _msg_end)):
        if os.path.isdir(_directory):
            continue
        os.makedirs(_directory)
        print(f'Created directory {_directory}{_suffix}')

### ESCC 

In [None]:
# Load the preprocessed ESCC dataset in the selected normalization variant.
adata = load_datasets('escc', preprocessed=True, norm_method=norm_method)
# NOTE(review): presumably works around scanpy's `KeyError: 'base'` in
# `rank_genes_groups` for data reloaded from disk — confirm. Requires
# `adata.uns['log1p']` to exist.
adata.uns['log1p']['base'] = None

In [None]:
# Compute the QC metrics in place from the 'counts' layer, with 'mt' as QC variable.
sc.pp.calculate_qc_metrics(adata, qc_vars=["mt"], inplace=True, layer='counts')

In [None]:
%%time
# Approach 1: DGEX (malignant vs. rest) on all ESCC samples together.
gex_genes = get_dgex_genes('escc', adata, logfc_min=min_logfc_onall, pval_max=min_pval_onall)

In [None]:
# Write the all-samples signature (including the DataFrame index) to CSV.
print(f'Storing DGEX genes  on all samples simultaneously to {escc_storing_path_all}')
gex_genes.to_csv(escc_storing_path_all)

In [None]:
## Gene names of the all-samples signature as a set, used as reference when
## checking how much the per-sample signatures overlap with it.
overall_dge = set(gex_genes.names.tolist())
len(overall_dge)

In [None]:
# Approach 2: DGEX on every ESCC sample individually (one result frame per sample).
list_dges = get_per_sample_dgex_genes(adata, 'escc', logfc_min=min_logfc_sep, pval_max=min_pval_sep)

In [None]:
# For every required sample fraction: build the per-sample signature, compare
# it with the all-samples signature, store it as CSV and log the summary.
subsets = []
outputs = []
for pct in pctgs:
    tmp = get_genes_dgex_genes_in_pct_samples(list_dges, pct=pct)
    n_genes = len(tmp)
    overlap_with_overall = len(overall_dge.intersection(tmp.names.tolist()))
    # No gene may pass a strict sample fraction (e.g. 100%); report 0% then
    # instead of failing with a ZeroDivisionError. The "(NN%)" format is kept
    # because the comparison section at the end parses these lines again.
    overlap_pct = round(overlap_with_overall / n_genes * 100) if n_genes else 0
    to_print = f'For {round(pct*100)}% of DGEX gene overlap over the samples we get {n_genes} genes.\n{overlap_with_overall} ({overlap_pct}%) genes have also been found when doing DGEX over all samples.\n'
    print(to_print)
    outputs.append(to_print)
    curr_path = os.path.join(escc_storing_path_per_merged, f'dgex_genes_intersec_{int(round(pct*100))}_psid.csv')
    to_print = f'> Storing at {curr_path}\n'
    print(to_print)
    outputs.append(to_print)
    tmp.to_csv(curr_path)
    subsets.append(tmp)
# Three text lines per percentage: two summary lines and the storing path.
with open(os.path.join(escc_storing_path_per_merged, 'percentages_overlap.txt'), 'w') as f:
    f.writelines(outputs)

### CRC

In [None]:
# Load the preprocessed CRC dataset in the selected normalization variant.
adata = load_datasets('crc', preprocessed=True, norm_method=norm_method)
# NOTE(review): presumably works around scanpy's `KeyError: 'base'` in
# `rank_genes_groups` for data reloaded from disk — confirm. Requires
# `adata.uns['log1p']` to exist.
adata.uns['log1p']['base'] = None

In [None]:
# Compute the QC metrics in place from the 'counts' layer, with 'mt' as QC variable.
sc.pp.calculate_qc_metrics(adata, qc_vars=["mt"], inplace=True, layer='counts')

In [None]:
%%time
# Approach 1: DGEX (malignant vs. rest) on all CRC samples together.
gex_genes = get_dgex_genes('crc', adata, logfc_min=min_logfc_onall, pval_max=min_pval_onall)

In [None]:
# Write the all-samples signature (including the DataFrame index) to CSV.
print(f'Storing DGEX genes  on all samples simultaneously to {crc_storing_path_all}')
gex_genes.to_csv(crc_storing_path_all)

In [None]:
## Gene names of the all-samples signature as a set, used as reference when
## checking how much the per-sample signatures overlap with it.
overall_dge = set(gex_genes.names.tolist())
len(overall_dge)

In [None]:
# Approach 2: DGEX on every CRC sample individually (one result frame per sample).
list_dges = get_per_sample_dgex_genes(adata, 'crc', logfc_min=min_logfc_sep, pval_max=min_pval_sep)

In [None]:
# For every required sample fraction: build the per-sample signature, compare
# it with the all-samples signature, store it as CSV and log the summary.
subsets = []
outputs = []
for pct in pctgs:
    tmp = get_genes_dgex_genes_in_pct_samples(list_dges, pct=pct)
    n_genes = len(tmp)
    overlap_with_overall = len(overall_dge.intersection(tmp.names.tolist()))
    # No gene may pass a strict sample fraction (e.g. 100%); report 0% then
    # instead of failing with a ZeroDivisionError. The "(NN%)" format is kept
    # because the comparison section at the end parses these lines again.
    overlap_pct = round(overlap_with_overall / n_genes * 100) if n_genes else 0
    to_print = f'For {round(pct*100)}% of DGEX gene overlap over the samples we get {n_genes} genes.\n{overlap_with_overall} ({overlap_pct}%) genes have also been found when doing DGEX over all samples.\n'
    print(to_print)
    outputs.append(to_print)
    curr_path = os.path.join(crc_storing_path_per_merged, f'dgex_genes_intersec_{int(round(pct*100))}_psid.csv')
    to_print = f'> Storing at {curr_path}\n'
    print(to_print)
    outputs.append(to_print)
    tmp.to_csv(curr_path)
    subsets.append(tmp)
# Three text lines per percentage: two summary lines and the storing path.
with open(os.path.join(crc_storing_path_per_merged, 'percentages_overlap.txt'), 'w') as f:
    f.writelines(outputs)

### LUAD 

In [None]:
# Load the preprocessed LUAD dataset in the selected normalization variant.
adata = load_datasets('luad', preprocessed=True, norm_method=norm_method)
# NOTE(review): presumably works around scanpy's `KeyError: 'base'` in
# `rank_genes_groups` for data reloaded from disk — confirm. Requires
# `adata.uns['log1p']` to exist.
adata.uns['log1p']['base'] = None

In [None]:
# Compute the QC metrics in place from the 'counts' layer, with 'mt' as QC variable.
sc.pp.calculate_qc_metrics(adata, qc_vars=["mt"], inplace=True, layer='counts')

In [None]:
%%time
# Approach 1: DGEX (malignant vs. rest) on all LUAD samples together.
gex_genes = get_dgex_genes('luad', adata, logfc_min=min_logfc_onall, pval_max=min_pval_onall)

In [None]:
# Write the all-samples signature (including the DataFrame index) to CSV.
print(f'Storing DGEX genes  on all samples simultaneously to {luad_storing_path_all}')
gex_genes.to_csv(luad_storing_path_all)

In [None]:
## Gene names of the all-samples signature as a set, used as reference when
## checking how much the per-sample signatures overlap with it.
overall_dge = set(gex_genes.names.tolist())
len(overall_dge)

In [None]:
# Approach 2: DGEX on every LUAD sample individually (one result frame per sample).
list_dges = get_per_sample_dgex_genes(adata, 'luad', logfc_min=min_logfc_sep, pval_max=min_pval_sep)

In [None]:
# For every required sample fraction: build the per-sample signature, compare
# it with the all-samples signature, store it as CSV and log the summary.
subsets = []
outputs = []
for pct in pctgs:
    tmp = get_genes_dgex_genes_in_pct_samples(list_dges, pct=pct)
    n_genes = len(tmp)
    overlap_with_overall = len(overall_dge.intersection(tmp.names.tolist()))
    # No gene may pass a strict sample fraction (e.g. 100%); report 0% then
    # instead of failing with a ZeroDivisionError. The "(NN%)" format is kept
    # because the comparison section at the end parses these lines again.
    overlap_pct = round(overlap_with_overall / n_genes * 100) if n_genes else 0
    to_print = f'For {round(pct*100)}% of DGEX gene overlap over the samples we get {n_genes} genes.\n{overlap_with_overall} ({overlap_pct}%) genes have also been found when doing DGEX over all samples.\n'
    print(to_print)
    outputs.append(to_print)
    curr_path = os.path.join(luad_storing_path_per_merged, f'dgex_genes_intersec_{int(round(pct*100))}_psid.csv')
    to_print = f'> Storing at {curr_path}\n'
    print(to_print)
    outputs.append(to_print)
    tmp.to_csv(curr_path)
    subsets.append(tmp)
# Three text lines per percentage: two summary lines and the storing path.
with open(os.path.join(luad_storing_path_per_merged, 'percentages_overlap.txt'), 'w') as f:
    f.writelines(outputs)

In [None]:
# NOTE(review): presumably a deliberate stop so that "Run all" does not
# continue into the comparison section below, which reads the result files
# of several threshold configurations — confirm.
raise ValueError()

## Compare the number of found genes for different configurations when applying DGEX on each sample and requiring X% of sample overlap

In [None]:
## Configurations whose stored `percentages_overlap.txt` files are compared.
norm_method='mean' # mean, median, CP10k

## > DGEX on each sample individually
## NOTE: `min_logfc_sep` and `min_pval_sep` are re-bound here from the single
## values used in the sections above to the lists of all values to compare.
min_logfc_sep = [1,2] # minimum log2FC values
min_pval_sep = [0.05, 0.01, 0.005] # maximum adjusted p-values
datasets = ['crc', 'escc', 'luad']

pctgs = [0.75, 0.8, 0.85, 0.9, 0.925,0.95, 0.975, 0.99, 1]

In [None]:
# Collect, for every dataset and threshold configuration, the numbers that
# were logged to `percentages_overlap.txt` by the sections above.
#
# Each percentage contributes three lines to that file:
#   For <P>% of DGEX gene overlap over the samples we get <N> genes.
#   <K> (<Q>%) genes have also been found when doing DGEX over all samples.
#   > Storing at <path>
# The lines are therefore scanned one by one and the line following a
# 'For ' line is consumed as its detail line. Reading the file in fixed
# pairs of lines gets out of step with this three-line layout and silently
# drops every second percentage.
values = []
for ds, min_log2fc, min_apval in itertools.product(datasets, min_logfc_sep, min_pval_sep):
    fn = os.path.join(
        base_storing_path,
        ds,
        f'{norm_method}_norm',
        'dgex_on_each_sid',
        f'min_log2fc_{min_log2fc}_pval_{min_apval}',
        'percentages_overlap.txt',
    )

    with open(fn, 'r') as f:
        for line1 in f:
            if not line1.startswith('For '):
                continue
            line2 = next(f, None)
            if line2 is None:
                # Truncated file: summary line without its detail line.
                break

            sp_line1 = line1.split()
            sp_line2 = line2.split()

            # '<P>%' -> fraction. The file stores the rounded percentage,
            # so e.g. 0.925 is not recovered exactly.
            overlap_pct = float(sp_line1[1][0:-1])/100
            found_genes = int(sp_line1[-2])

            overlap_onall = int(sp_line2[0])
            # '(<Q>%)' -> fraction
            overlap_onall_pct = float(sp_line2[1][1:-2])/100
            values.append({
                'dataset': ds,
                'min_log2fc': min_log2fc,
                'max_adj_pval': min_apval,
                'pct_overlap_in_sid': overlap_pct,
                'nr_found_genes_DGEX': found_genes,
                'overlap_with_DGEX_onall': overlap_onall,
                'overlap_with_DGEX_onall_pct': overlap_onall_pct,
            })

In [None]:
# One row per parsed entry, i.e. per dataset, threshold configuration and sample fraction.
result = pd.DataFrame(values)

In [None]:
# Number of signature genes per configuration (rows) and required sample
# fraction (columns). NOTE(review): each row/column combination should occur
# only once, so pivot_table's default aggregation acts as a plain reshape — confirm.
pd.pivot_table(result, values='nr_found_genes_DGEX', index=['dataset', 'min_log2fc', 'max_adj_pval'],
                       columns=['pct_overlap_in_sid'])

In [None]:
# Annotated heatmap of the fraction of per-sample signature genes that were
# also found by the DGEX on all samples, per configuration and sample fraction.
plt.figure(figsize=(15,10))
sns.heatmap(pd.pivot_table(result, values='overlap_with_DGEX_onall_pct', index=['dataset', 'min_log2fc', 'max_adj_pval'],
                       columns=['pct_overlap_in_sid']), annot=True)