In [None]:
import argparse
import json
import os
import random
import sys
import warnings
warnings.filterwarnings("ignore")
from datetime import datetime
import matplotlib.pyplot as plt

import pandas as pd
import numpy as np
import scanpy as sc
from signaturescoring import score_signature
from signaturescoring.utils.utils import get_mean_and_variance_gene_expression
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc

# Add the directory two levels up to the import path; presumably this is what makes the
# project packages `data` and `experiments` importable — confirm against the repo layout.
# NOTE(review): the path is relative, so it depends on the kernel's working directory.
sys.path.append('../..')
from data.load_data import load_datasets, load_dgex_genes_for_mal_cells
from data.constants import METHOD_WO_MEAN
from experiments.experiment_utils import AttributeDict, get_malignant_signature, get_scoring_method_params

# scanpy logging level used while loading data; the sweep cells set it to 0 before scoring.
sc.settings.verbosity = 2

In [None]:
def score_genes_and_evaluate(adata, gene_list, df_mean_var, sc_method_long, sc_method, scm_params, col_sid='sample_id'):
    """Score one gene signature and report how well the score separates malignant cells.

    ``score_signature`` is run on ``adata``; the resulting score column is read back
    from ``adata.obs[scm_params['score_name']]`` and compared against
    ``adata.obs.malignant_key`` with 'malignant' as the positive class.

    Args:
        adata: Object exposing ``.obs`` with a ``malignant_key`` column. It is passed to
            ``score_signature`` and the score column is read back from ``adata.obs``.
        gene_list: Signature genes to score.
        df_mean_var: Forwarded to ``score_signature`` as ``df_mean_var`` unless
            ``sc_method`` is listed in ``METHOD_WO_MEAN``.
        sc_method_long: Label used as suffix of the metric column names.
        sc_method: Method name handed to ``score_signature``.
        scm_params: Extra keyword arguments for ``score_signature``; must contain
            ``'score_name'``.
        col_sid: Currently unused; kept so existing callers keep working.

    Returns:
        A one-row ``pd.DataFrame`` with the columns ``signature_length``,
        ``AUCROC_<sc_method_long>`` and ``AUCPR_<sc_method_long>``.
    """
    # Only methods outside METHOD_WO_MEAN receive the df_mean_var argument.
    mean_var_kwargs = {} if sc_method in METHOD_WO_MEAN else {'df_mean_var': df_mean_var}
    score_signature(
        method=sc_method,
        adata=adata,
        gene_list=gene_list,
        **mean_var_kwargs,
        **scm_params
    )
    scores = adata.obs[scm_params['score_name']].copy()
    labels = adata.obs.malignant_key

    # Area under the precision-recall curve, with 'malignant' as the positive class.
    precision, recall, _ = precision_recall_curve(labels, scores, pos_label='malignant')
    aucpr = auc(recall, precision)

    # The AUROC used to be computed as `1 - roc_auc_score(labels, scores)`, which is only
    # right while 'malignant' happens to sort before the other label (roc_auc_score has no
    # pos_label argument). Scoring against an explicit boolean target makes 'malignant' the
    # positive class regardless of label names, consistent with the PR curve above.
    auroc = roc_auc_score(labels == 'malignant', scores)

    return pd.DataFrame(
        [(len(gene_list), auroc, aucpr)],
        columns=['signature_length', f'AUCROC_{sc_method_long}', f'AUCPR_{sc_method_long}'],
    )

In [None]:
# Dataset to analyse. It is passed to load_datasets and get_malignant_signature and
# also picks the signature-length grid in the sweep cells.
dataset = 'luad_xing'

# Normalisation variant, passed to load_datasets and get_malignant_signature.
norm_method = 'mean'

# The options below are forwarded positionally to get_malignant_signature.
# NOTE(review): their exact meaning is defined in experiments.experiment_utils —
# confirm there before changing any value.
sample_based = False
dge_on_all = 'pseudobulk'
intersect_pct = 0.9

# Thresholds (presumably on log2 fold change and p-value — confirm).
min_log2fc = 2
pval = 0.01

# Ordering / truncation of the signature (presumably — confirm).
ranked_means = False
sort_values_by = 'median_log2FC'
sig_length = None
most_dge = True

In [None]:
# Load the preprocessed data for the configured dataset and normalisation.
adata = load_datasets(dataset, preprocessed=True, norm_method=norm_method)

# Guarantee that adata.uns['log1p'] exists and that its 'base' entry is None.
# NOTE(review): presumably a workaround for code that reads adata.uns['log1p']['base']
# — confirm whether it is still required.
if 'log1p' not in adata.uns_keys():
    adata.uns['log1p'] = {'base': None}
else:
    adata.uns['log1p']['base'] = None

# Arguments for get_malignant_signature, in the positional order the call expects.
_signature_args = (
    dataset,
    norm_method,
    sample_based,
    dge_on_all,
    intersect_pct,
    min_log2fc,
    pval,
    ranked_means,
    sort_values_by,
    sig_length,
    most_dge,
)
gene_list = get_malignant_signature(*_signature_args)

In [None]:
# Look up one scoring method by its long name and display the (method, parameters) pair.
# NOTE: sc_method_long, sc_method and scm_params are rebound by the loop in the sweep
# cells, so within the cells shown here these values only serve as a preview.
sc_method_long = "ucell_scoring"
sc_method, scm_params = get_scoring_method_params(sc_method_long)
sc_method, scm_params

In [None]:
# Fetch the parameters of every scoring method. The sweep cells iterate over
# scoring_methods.items() and unpack each value as (sc_method, scm_params).
scoring_methods = get_scoring_method_params("all")
scoring_methods

In [None]:
# Silence scanpy logging; many scoring runs follow.
sc.settings.verbosity = 0

# Signature lengths to evaluate per dataset. Datasets not listed here ('crc', 'escc'
# and any other name) use the short default grid.
default_gene_lengths = range(1, 21, 2)
gene_lengths_by_dataset = {
    'luad': [100, 150, 200, 250, 300, 350, 388],
    'luad_xing': [5, 10, 20, 30, 50, 100, 200, 300, 400, 464],
    'melanoma': [100, 150, 200, 250, 300, 350, 388],
}
requested_lengths = gene_lengths_by_dataset.get(dataset, default_gene_lengths)

# gene_list[0:n] silently returns fewer than n genes when the signature is shorter than
# requested, so several grid points would re-score the identical full signature. Cap every
# length at the signature size and drop the resulting duplicates.
n_signature_genes = len(gene_list)
capped_lengths = {min(n, n_signature_genes) for n in requested_lengths}
gene_lengths = sorted(n for n in capped_lengths if n > 0)
if not gene_lengths:
    raise ValueError(
        f"No signature lengths to evaluate for dataset {dataset!r}: the gene signature is empty."
    )

# One table per scoring method, one row per signature length.
for sc_method_long, (sc_method, scm_params) in scoring_methods.items():
    # None is passed as df_mean_var; how score_signature handles that is not visible here.
    res = [
        score_genes_and_evaluate(adata, gene_list[:n_genes], None, sc_method_long, sc_method, scm_params)
        for n_genes in gene_lengths
    ]
    # ignore_index gives rows 0..n-1 instead of repeating index 0 for every row.
    results = pd.concat(res, axis=0, ignore_index=True)
    display(results.round(3))

In [None]:
# NOTE(review): this cell repeats the signature-length sweep of the preceding cell with the
# same inputs; consider deleting it or turning the sweep into a function with parameters.

# Silence scanpy logging; many scoring runs follow.
sc.settings.verbosity = 0

# Signature lengths to evaluate per dataset. Datasets not listed here ('crc', 'escc'
# and any other name) use the short default grid.
default_gene_lengths = range(1, 21, 2)
gene_lengths_by_dataset = {
    'luad': [100, 150, 200, 250, 300, 350, 388],
    'luad_xing': [5, 10, 20, 30, 50, 100, 200, 300, 400, 464],
    'melanoma': [100, 150, 200, 250, 300, 350, 388],
}
requested_lengths = gene_lengths_by_dataset.get(dataset, default_gene_lengths)

# gene_list[0:n] silently returns fewer than n genes when the signature is shorter than
# requested, so several grid points would re-score the identical full signature. Cap every
# length at the signature size and drop the resulting duplicates.
n_signature_genes = len(gene_list)
capped_lengths = {min(n, n_signature_genes) for n in requested_lengths}
gene_lengths = sorted(n for n in capped_lengths if n > 0)
if not gene_lengths:
    raise ValueError(
        f"No signature lengths to evaluate for dataset {dataset!r}: the gene signature is empty."
    )

# One table per scoring method, one row per signature length.
for sc_method_long, (sc_method, scm_params) in scoring_methods.items():
    # None is passed as df_mean_var; how score_signature handles that is not visible here.
    res = [
        score_genes_and_evaluate(adata, gene_list[:n_genes], None, sc_method_long, sc_method, scm_params)
        for n_genes in gene_lengths
    ]
    # ignore_index gives rows 0..n-1 instead of repeating index 0 for every row.
    results = pd.concat(res, axis=0, ignore_index=True)
    display(results.round(3))