# Malignant cells signature creation for LUAD XING
The following notebook creates the signatures for malignant cells, i.e. gene sets that distinguish malignant from non-malignant cells. We consider:
1. DGEX on all preprocessed samples together, with thresholds log2FC>2 and adjusted p-val<0.01

In [None]:
# Widen the notebook container to 95% of the browser window via injected CSS.
from IPython.display import display, HTML
display(HTML("<style>.container { width:95% !important; }</style>"))

In [None]:
import os
import sys
import scanpy as sc
import seaborn as sns
import matplotlib.pyplot as plt
from collections import Counter
import numpy as np
import pandas as pd
import itertools
import warnings
warnings.filterwarnings("ignore")


from signaturescoring import score_signature
from signaturescoring.utils.utils import get_mean_and_variance_gene_expression
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc

sys.path.append('..')
sys.path.append('../..')
from load_data import load_datasets
from constants import BASE_PATH_DGEX_CANCER, CANCER_DATASETS, METHOD_WO_MEAN

from experiments.experiment_utils import AttributeDict, get_malignant_signature, get_scoring_method_params

# Scanpy logging level used while loading data and running DGEX
# (it is lowered to 0 further down, before the scoring loop).
sc.settings.verbosity = 2

In [None]:
def get_dgex_genes(dataset, adata, logfc_min=2, pval_max=0.01, show_plot=True):
    """Run a Wilcoxon DGEX on ``malignant_key`` and return the genes up-regulated in malignant cells.

    Args:
        dataset: Dataset name; only used (upper-cased) in the plot titles.
        adata: Annotated data with an ``obs['malignant_key']`` column that
            contains the group ``'malignant'``. It is copied, so the caller's
            object is not modified.
        logfc_min: Genes must have ``logfoldchanges`` strictly above this value.
        pval_max: Genes must have ``pvals_adj`` strictly below this value.
        show_plot: If True, show two diagnostic scatter plots of all tested
            genes with the selection thresholds drawn as dotted red lines.

    Returns:
        The rows of ``sc.get.rank_genes_groups_df`` for the ``'malignant'``
        group that pass both thresholds, sorted by ascending ``pvals_adj`` and
        then descending ``logfoldchanges``, with a fresh integer index.
    """
    # Work on a copy because rank_genes_groups stores its results on the object.
    curr_adata = adata.copy()
    sc.tl.rank_genes_groups(curr_adata, 'malignant_key', method='wilcoxon', key_added='wilcoxon', tie_correct=True)
    # Fetch the statistics for all genes first; thresholds are applied afterwards.
    wc = sc.get.rank_genes_groups_df(curr_adata, group='malignant', key='wilcoxon')
    if show_plot:
        plt.figure(figsize=(8, 8))
        # `data=` is passed by keyword so this works across seaborn versions.
        g = sns.scatterplot(data=wc, x='scores', y='logfoldchanges')
        g.set_title(f'{dataset.upper()} DGEX scores vs log2FC.')
        # log2FC is on the y-axis here: mark the actual threshold in use.
        g.axhline(y=logfc_min, c='r', ls=':')
        plt.show()

        plt.figure(figsize=(8, 8))
        g = sns.scatterplot(data=wc, x='logfoldchanges', y='pvals_adj')
        g.set_title(f'{dataset.upper()} log2FC vs pvals_adj.')
        # log2FC is on the x-axis and the adjusted p-value on the y-axis, so
        # the two thresholds are a vertical and a horizontal line respectively.
        g.axvline(x=logfc_min, c='r', ls=':')
        g.axhline(y=pval_max, c='r', ls=':')
        plt.show()
    gex_genes = wc[(wc.logfoldchanges > logfc_min) & (wc.pvals_adj < pval_max)]
    gex_genes = gex_genes.sort_values(by=['pvals_adj', 'logfoldchanges'], ascending=[True, False]).reset_index(drop=True)
    return gex_genes


def get_per_sample_dgex_genes(adata, dataset, logfc_min=1, pval_max=0.05,col_sid='sample_id'):
    """Run the malignant-vs-rest DGEX separately for every sample.

    Args:
        adata: Annotated data whose ``obs`` holds ``malignant_key`` (with the
            values ``'malignant'`` and ``'non-malignant'``) and ``col_sid``.
        dataset: Dataset name, forwarded to ``get_dgex_genes`` (plots are
            disabled here, so it has no visible effect).
        logfc_min: Minimum log2 fold change forwarded to ``get_dgex_genes``.
        pval_max: Maximum adjusted p-value forwarded to ``get_dgex_genes``.
        col_sid: ``obs`` column used to split the cells into samples.

    Returns:
        A list with one DataFrame per sample, indexed by gene name (``names``)
        and holding the single column ``logfoldchanges``.
    """
    # Stage 1: build one AnnData per sample, restricted to the genes that are
    # detected in at least one malignant AND one non-malignant cell.
    sample_adatas = {}
    for sample_id, sample_obs in adata.obs.groupby(col_sid):
        sample_adata = adata[sample_obs.index,].copy()
        malignant = sample_adata[sample_adata.obs.malignant_key == 'malignant'].copy()
        non_malignant = sample_adata[sample_adata.obs.malignant_key == 'non-malignant'].copy()

        for subset in (malignant, non_malignant):
            sc.pp.filter_genes(subset, min_cells=1)

        # The inner join keeps only genes that survived filtering in both subsets.
        sample_adatas[sample_id] = sc.concat([malignant, non_malignant], join='inner', merge='same')

    # Stage 2: DGEX per sample, keeping only the log2 fold change per gene.
    return [
        get_dgex_genes(dataset, sample_adata, logfc_min, pval_max, show_plot=False)[['names', 'logfoldchanges']]
        .copy()
        .set_index('names')
        for sample_adata in sample_adatas.values()
    ]


def get_genes_dgex_genes_in_pct_samples(list_dges, pct=0.90):
    """Summarise the log2FC of genes that are differentially expressed in enough samples.

    Args:
        list_dges: One DataFrame per sample, indexed by gene, each holding that
            sample's log2 fold changes (as returned by
            ``get_per_sample_dgex_genes``).
        pct: Minimum fraction of samples in which a gene must be present
            (non-NaN) to be kept.

    Returns:
        A DataFrame with the gene index as its first column, followed by
        ``mean_log2FC`` and ``median_log2FC`` computed over the samples in
        which the gene is present.
    """
    n_samples = len(list_dges)

    # Outer join: one row per gene seen in any sample, NaN where it is absent.
    merged = pd.concat(list_dges, axis=1, join='outer')

    # Fraction of samples in which each gene was reported.
    present_fraction = merged.notna().sum(axis=1) / n_samples
    kept = merged[present_fraction >= pct]

    summary = pd.concat(
        [kept.mean(axis=1, skipna=True), kept.median(axis=1, skipna=True)],
        axis=1,
    )
    summary.columns = ['mean_log2FC', 'median_log2FC']
    return summary.reset_index()

### Global variables 

In [None]:
## Shift-logarithm normalization variant the dataset was normalized with.
norm_method='CP10k' # mean, median, CP10k

## Minimum log2FC and maximum adjusted p-value
## > DGEX on all samples
min_logfc_onall = 2
# NOTE: despite its name this is the UPPER bound on the adjusted p-value;
# it is passed as `pval_max` to get_dgex_genes.
min_pval_onall = 0.01

# When True, the output directory is created and the DGEX table is written to disk.
save = False

In [None]:
# Dataset analysed in this notebook; it must be one of the known cancer datasets.
dataset = 'luad_xing'
assert dataset in CANCER_DATASETS

In [None]:
# Output location of the DGEX table:
# <base>/<dataset>/<norm_method>_norm/dgex_on_all_sid/min_log2fc_<..>_pval_<..>/dgex_genes.csv
base_storing_path = BASE_PATH_DGEX_CANCER

storing_path = os.path.join(
    base_storing_path,
    dataset,
    f'{norm_method}_norm',
    'dgex_on_all_sid',
    f'min_log2fc_{min_logfc_onall}_pval_{min_pval_onall}',
    'dgex_genes.csv',
)


In [None]:
# Create the output directory only when results will be saved and it does not exist yet.
if save and not os.path.isdir(os.path.dirname(storing_path)):
    os.makedirs(os.path.dirname(storing_path))
    print(f'Created directory {os.path.dirname(storing_path)}')

In [None]:
# Load the preprocessed dataset for the chosen normalization method.
adata = load_datasets(dataset, preprocessed=True, norm_method=norm_method)
# NOTE(review): presumably a workaround for scanpy's rank_genes_groups reading
# adata.uns['log1p']['base'], which can be missing on reloaded data - confirm.
adata.uns['log1p']['base'] = None

In [None]:
# Count the sample_id values within each malignant_key group, ordered by index.
adata.obs.groupby('malignant_key').sample_id.value_counts().sort_index()

In [None]:
# adata = adata[adata.obs.sample_id.str.startswith('SSN')]

In [None]:
# Compute QC metrics in place from the 'counts' layer, with "mt" as the QC variable.
# NOTE(review): presumably adata.var['mt'] flags mitochondrial genes - confirm.
sc.pp.calculate_qc_metrics(adata, qc_vars=["mt"], inplace=True, layer='counts')

In [None]:
%%time
# DGEX on all samples together; `min_pval_onall` is the maximum adjusted p-value (pval_max).
gex_genes = get_dgex_genes(dataset, adata, logfc_min=min_logfc_onall, pval_max=min_pval_onall)

In [None]:
# Write the thresholded DGEX table to `storing_path` (only when saving is enabled).
if save:
    print(f'Storing DGEX genes  on all samples simultaneously to {storing_path}')
    gex_genes.to_csv(storing_path)

In [None]:
## Collect the unique names of the genes found by the DGEX on all samples
## and show how many there are.
overall_dge = set(gex_genes.names.tolist())
len(overall_dge)

In [None]:
# Keep the ranking of `gex_genes` (ascending adjusted p-value, then descending
# log2FC) so that gene_list[0:i] is the top-i signature. Converting the set to a
# list instead gives an arbitrary order that can change between runs.
# dict.fromkeys removes duplicate names while preserving that order.
gene_list = list(dict.fromkeys(gex_genes.names.tolist()))

In [None]:
def score_genes_and_evaluate(adata, gene_list, df_mean_var,sc_method_long, sc_method, scm_params, col_sid='sample_id'):
    """Score a gene signature with one method and measure how well it separates malignant cells.

    Args:
        adata: Annotated data with ``obs['malignant_key']``; the scoring call
            is expected to add the column ``scm_params['score_name']`` to
            ``adata.obs``, which is read back afterwards.
        gene_list: Signature genes to score.
        df_mean_var: Passed to ``score_signature`` as ``df_mean_var`` unless
            ``sc_method`` is listed in ``METHOD_WO_MEAN``.
        sc_method_long: Descriptive method name, used in the result column names.
        sc_method: Method identifier passed to ``score_signature``.
        scm_params: Extra keyword arguments for ``score_signature``; must
            contain ``'score_name'``.
        col_sid: Unused by this function.

    Returns:
        A one-row DataFrame with the columns ``signature_length``,
        ``AUCROC_<sc_method_long>`` and ``AUCPR_<sc_method_long>``.
    """
    call_kwargs = {'method': sc_method, 'adata': adata, 'gene_list': gene_list}
    # Only methods that are not in METHOD_WO_MEAN receive df_mean_var.
    if sc_method not in METHOD_WO_MEAN:
        call_kwargs['df_mean_var'] = df_mean_var
    score_signature(**call_kwargs, **scm_params)

    labels = adata.obs.malignant_key
    scores = adata.obs[scm_params['score_name']].copy()

    # Area under the precision-recall curve with 'malignant' as the positive class.
    precision, recall, _ = precision_recall_curve(labels, scores, pos_label='malignant')
    auc_pr = auc(recall, precision)

    # NOTE(review): presumably roc_auc_score takes the other label as positive,
    # hence the 1 - AUC flip to report the value for 'malignant' - confirm.
    auc_roc = 1 - roc_auc_score(labels, scores)

    return pd.DataFrame(
        [(len(gene_list), auc_roc, auc_pr)],
        columns=['signature_length', f'AUCROC_{sc_method_long}', f'AUCPR_{sc_method_long}'],
    )

In [None]:
# Scoring configurations to evaluate. The loop below expects a mapping of
# descriptive method name -> (method identifier, keyword parameters).
scoring_methods = get_scoring_method_params("all")
scoring_methods

In [None]:
# Silence scanpy logging for the many repeated scoring runs below.
sc.settings.verbosity = 0


def _log_spaced_lengths(n_genes, num=11):
    """Return unique, ascending, roughly log2-spaced signature lengths from 1 to ``n_genes``.

    Args:
        n_genes: Length of the full signature; must be at least 1.
        num: Number of log-spaced points generated before de-duplication.

    Returns:
        A sorted integer array without duplicates whose last entry is ``n_genes``.

    Raises:
        ValueError: If ``n_genes`` is smaller than 1.
    """
    if n_genes < 1:
        raise ValueError('Cannot build signature lengths: no DGEX genes were found.')
    lengths = np.logspace(0, np.log2(n_genes), num=num, base=2, dtype='int')
    # Truncation to int can turn 2**log2(n_genes) into n_genes - 1, so pin the
    # endpoint to make sure the full signature is always evaluated.
    lengths[-1] = n_genes
    # Small lengths collapse to the same integer after truncation; np.unique
    # drops the repeats so that no signature length is scored twice.
    return np.unique(lengths)


# Signature lengths to evaluate, per dataset.
if dataset == 'luad':
    gene_lengths = [100, 150, 200, 250, 300, 350, 388]
elif dataset in ('luad_xing', 'melanoma'):
    gene_lengths = _log_spaced_lengths(len(overall_dge))
else:
    # 'crc', 'escc' and any other dataset: small odd lengths 1, 3, ..., 19.
    gene_lengths = range(1, 21, 2)

# For every scoring method, score the first i genes of `gene_list` for each
# length i and display one table (rounded to 3 decimals) per method.
for sc_method_long, (sc_method, scm_params) in scoring_methods.items():
    results = pd.concat(
        [
            score_genes_and_evaluate(adata, gene_list[0:i], None, sc_method_long, sc_method, scm_params)
            for i in gene_lengths
        ],
        axis=0,
    )
    display(round(results, 3))