# Experiment testing whether signature scores depend on the total read counts in the data

## Part 1
In the following experiment, we score several preprocessed datasets (CRC, ESCC, and LUAD) with a signature of genes that separate malignant from non-malignant cells. We want to see whether the scores correlate with the total read counts in the datasets.

Expectation: Because malignant cells generally have higher total read counts and we score a signature that separates malignant from non-malignant cells, we expect the scores to correlate with the total read counts.

In [None]:
from IPython.display import HTML, display

# Widen the notebook container to 90% of the browser width.
_wide_container_css = "<style>.container { width:90% !important; }</style>"
display(HTML(_wide_container_css))

In [None]:
import os
import sys

sys.path.append('../..')

import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
from scipy.stats import kstest
import matplotlib.pyplot as plt

from signaturescoring.utils.utils import get_gene_list_real_data, get_mean_and_variance_gene_expression
from signaturescoring import score_signature

from data.constants import BASE_PATH_PREPROCESSED, BASE_PATH_DGEX_CANCER, BASE_PATH_EXPERIMENTS


### Global variables

In [None]:
## Paths to the preprocessed datasets (base paths come from data.constants)
base_path = BASE_PATH_PREPROCESSED
crc_path, escc_path, luad_path = (
    os.path.join(base_path, fname)
    for fname in ('pp_crc.h5ad', 'pp_escc.h5ad', 'pp_luad.h5ad')
)

# Paths to the DGEX result tables, one sub-folder per dataset.
# Note: `base_path` is rebound here, so after this cell it holds the DGEX base path.
base_path = BASE_PATH_DGEX_CANCER
_dgex_fname = 'dgex_min_log2fc_2_pval_0.01.csv'
crc_dgex_path, escc_dgex_path, luad_dgex_path = (
    os.path.join(base_path, dataset_dir, _dgex_fname)
    for dataset_dir in ('crc', 'escc', 'luad')
)

In [None]:
## Experiment-wide setting: size of the malignant signature, i.e. how many DGEX
## genes with the largest log fold change are used for scoring.
nr_of_sig_genes: int = 100

In [None]:
## Path to store data or images (define storing path).
## The per-dataset sub-folders the plots are written into are created up front,
## because `savefig` fails with FileNotFoundError on a missing directory.
storing_path = os.path.join(BASE_PATH_EXPERIMENTS, 'correlation_scores_with_TRC_and_MTP_experiments/')
for _dataset_dir in ('CRC', 'ESCC', 'LUAD'):
    os.makedirs(os.path.join(storing_path, _dataset_dir), exist_ok=True)

In [None]:
# If True, the scored data is restricted to malignant cells before plotting and correlating.
only_malignant: bool = True

In [None]:
# Each entry names a method for `score_signature` plus the keyword arguments
# forwarded to it; `score_name` is the adata.obs column the scores are read from.
scoring_methods = [
    dict(scoring_method="adjusted_neighborhood_scoring",
         sc_params=dict(ctrl_size=100, score_name="ANS")),
    dict(scoring_method="seurat_scoring",
         sc_params=dict(ctrl_size=100, n_bins=25, score_name="Seurat")),
    dict(scoring_method="seurat_ag_scoring",
         sc_params=dict(n_bins=25, score_name="Seurat_AG")),
    dict(scoring_method="seurat_lvg_scoring",
         sc_params=dict(ctrl_size=100,
                        n_bins=25,
                        lvg_computation_version="v1",
                        lvg_computation_method="seurat",
                        score_name="Seurat_LVG")),
    dict(scoring_method="scanpy_scoring",
         sc_params=dict(ctrl_size=100, n_bins=25, score_name="Scanpy")),
    dict(scoring_method="jasmine_scoring",
         sc_params=dict(score_method='likelihood', score_name="Jasmine_LH")),
    dict(scoring_method="jasmine_scoring",
         sc_params=dict(score_method='oddsratio', score_name="Jasmine_OR")),
    dict(scoring_method="ucell_scoring",
         sc_params=dict(score_name="UCell", maxRank=1500)),
]

### Helper functions

In [None]:
def get_data_and_gene_list(data_path, dgex_path, n_sig_genes=None):
    """Load a preprocessed dataset and derive its malignant signature.

    Args:
        data_path: Path to the preprocessed ``.h5ad`` file.
        dgex_path: Path to a CSV of differentially expressed genes with (at
            least) the columns ``names`` and ``logfoldchanges``.
        n_sig_genes: Number of genes with the largest log fold change kept as
            the signature. Defaults to the notebook-level ``nr_of_sig_genes``.

    Returns:
        Tuple ``(adata, gene_list)``: the loaded AnnData object and the list of
        signature gene names, ordered by decreasing log fold change.

    Raises:
        KeyError: if the loaded object has no ``uns['log1p']`` entry.
    """
    if n_sig_genes is None:
        n_sig_genes = nr_of_sig_genes

    print(f'Load data with path {data_path}')
    adata = sc.read_h5ad(data_path)
    # Overwrite the log base recorded by preprocessing with None.
    # NOTE(review): presumably a workaround for how the base is (de)serialized
    # in the h5ad file — confirm.
    adata.uns['log1p']['base'] = None
    print('Loaded data')

    print('Load DGEX genes ..')
    wc = pd.read_csv(dgex_path)

    diffexp_genes = wc.nlargest(n_sig_genes, columns="logfoldchanges")
    gene_list = diffexp_genes["names"].tolist()
    # Report the actual signature size: it is smaller than requested when the
    # DGEX table has fewer rows than `n_sig_genes`.
    print(f'Total nr. of DGEX genes {len(wc)}. We will use {len(gene_list)} with highest logfoldchanges as signature.')

    print('Finished loading data and malignant signature.')

    return adata, gene_list

In [None]:
def create_and_save_plots(adata, y_var, title, filename, show=True,
                          cols=('malignant_key', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'n_genes_by_counts')):
    """Plot one score column against QC columns of ``adata.obs`` and save it as PNG.

    Args:
        adata: AnnData object whose ``.obs`` holds ``y_var`` and all ``cols``.
        y_var: Column plotted on the y-axis (e.g. a signature score name).
        title: Figure title.
        filename: Output path relative to the global ``storing_path``; missing
            parent directories are created.
        show: Display the figure if True, otherwise close it after saving.
        cols: First entry is the column used as ``hue``; the remaining entries
            are the columns plotted on the x-axes. Any sequence is accepted.
    """
    cols = list(cols)  # tuple default avoids a shared mutable default argument
    hue_col, x_cols = cols[0], cols[1:]

    g = sns.pairplot(data=adata.obs[[hue_col, y_var] + x_cols],
                     hue=hue_col,
                     y_vars=[y_var])
    g.fig.subplots_adjust(top=0.85)  # leave room for the suptitle
    g.fig.suptitle(title, fontsize=14)

    out_path = os.path.join(storing_path, filename)
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    g.fig.savefig(out_path, format='png', dpi=300)

    if show:
        # pyplot.show takes no figure argument (signature `show(*, block=None)`).
        plt.show()
    else:
        plt.close(g.fig)

### CRC

In [None]:
%%time
adata, gene_list = get_data_and_gene_list(data_path=crc_path, dgex_path=crc_dgex_path)

In [None]:
# Per-cell QC metrics from the 'counts' layer, including the 'mt' gene subset
# (assumed to be a boolean column in adata.var), written into adata.obs.
sc.pp.calculate_qc_metrics(adata, layer='counts', qc_vars=['mt'], inplace=True)

In [None]:
%%time
# Score the signature with every configured method; each call is expected to
# add an adata.obs column named after its `score_name`.
for method_cfg in scoring_methods:
    method_name, method_kwargs = method_cfg['scoring_method'], method_cfg['sc_params']
    score_signature(method=method_name, adata=adata, gene_list=gene_list, **method_kwargs)

In [None]:
# Score column names, in the order of `scoring_methods`.
sc_names = [params['score_name'] for params in (cfg['sc_params'] for cfg in scoring_methods)]

In [None]:
# Optionally keep only malignant cells; copy() turns the AnnData view into an
# independent object.
if only_malignant:
    is_malignant = adata.obs['malignant_key'] == 'malignant'
    adata = adata[is_malignant, :].copy()

In [None]:
# One scatter-plot row per scoring method: score vs. count/mt QC metrics.
# The full score name goes into the title: stripping the last '_' part left the
# title empty for names like "ANS" and merged "Seurat_AG"/"Seurat_LVG".
_file_suffix = '_only_mal' if only_malignant else ''
for curr_name in sc_names:
    title = f"CRC scatterplots with {curr_name} signature scores vs. counts and mt"
    filename = f'CRC/scatter_{curr_name}{_file_suffix}.png'
    create_and_save_plots(adata, curr_name, title, filename, show=False)

In [None]:
# Pairwise (pandas default: Pearson) correlations; keep the score rows against
# the QC-metric columns, selected by label instead of a positional [0:-4] slice.
qc_cols = ['total_counts', 'total_counts_mt', 'pct_counts_mt', 'n_genes_by_counts']
corr_adata = adata.obs[sc_names + qc_cols].corr()
corr_adata = corr_adata.loc[sc_names, qc_cols]

In [None]:
# Draw on a dedicated figure rather than the implicit current pyplot axes, so
# the heatmap cannot end up on a figure left open by an earlier cell.
fig, ax = plt.subplots()
g = sns.heatmap(corr_adata, annot=True, ax=ax)
g.set_title('Correlations')
fig.tight_layout()
out_dir = os.path.join(storing_path, 'CRC')
os.makedirs(out_dir, exist_ok=True)
fig.savefig(os.path.join(out_dir, 'correlation_heatmap_only_mal.png' if only_malignant else 'correlation_heatmap.png'), dpi=300)

### ESCC

In [None]:
%%time
adata, gene_list = get_data_and_gene_list(data_path=escc_path, dgex_path=escc_dgex_path)

In [None]:
# Per-cell QC metrics from the 'counts' layer, including the 'mt' gene subset
# (assumed to be a boolean column in adata.var), written into adata.obs.
sc.pp.calculate_qc_metrics(adata, layer='counts', qc_vars=['mt'], inplace=True)

In [None]:
%%time
# Score the signature with every configured method; each call is expected to
# add an adata.obs column named after its `score_name`.
for method_cfg in scoring_methods:
    method_name, method_kwargs = method_cfg['scoring_method'], method_cfg['sc_params']
    score_signature(method=method_name, adata=adata, gene_list=gene_list, **method_kwargs)

In [None]:
# Optionally keep only malignant cells; copy() turns the AnnData view into an
# independent object.
if only_malignant:
    is_malignant = adata.obs['malignant_key'] == 'malignant'
    adata = adata[is_malignant, :].copy()

In [None]:
# Score column names, in the order of `scoring_methods`.
sc_names = [params['score_name'] for params in (cfg['sc_params'] for cfg in scoring_methods)]

In [None]:
# One scatter-plot row per scoring method: score vs. count/mt QC metrics.
# The full score name goes into the title: stripping the last '_' part left the
# title empty for names like "ANS" and merged "Seurat_AG"/"Seurat_LVG".
_file_suffix = '_only_mal' if only_malignant else ''
for curr_name in sc_names:
    title = f"ESCC scatterplots with {curr_name} signature scores vs. counts and mt"
    filename = f'ESCC/scatter_{curr_name}{_file_suffix}.png'
    create_and_save_plots(adata, curr_name, title, filename, show=False)

In [None]:
# Pairwise (pandas default: Pearson) correlations; keep the score rows against
# the QC-metric columns, selected by label instead of a positional [0:-4] slice.
qc_cols = ['total_counts', 'total_counts_mt', 'pct_counts_mt', 'n_genes_by_counts']
corr_adata = adata.obs[sc_names + qc_cols].corr()
corr_adata = corr_adata.loc[sc_names, qc_cols]

In [None]:
# Draw on a dedicated figure rather than the implicit current pyplot axes, so
# the heatmap cannot end up on a figure left open by an earlier cell.
fig, ax = plt.subplots()
g = sns.heatmap(corr_adata, annot=True, ax=ax)
g.set_title('Correlations')
fig.tight_layout()
out_dir = os.path.join(storing_path, 'ESCC')
os.makedirs(out_dir, exist_ok=True)
fig.savefig(os.path.join(out_dir, 'correlation_heatmap_only_mal.png' if only_malignant else 'correlation_heatmap.png'), dpi=300)

### LUAD

In [None]:
%%time
adata, gene_list = get_data_and_gene_list(data_path=luad_path, dgex_path=luad_dgex_path)

In [None]:
# Per-cell QC metrics from the 'counts' layer, including the 'mt' gene subset
# (assumed to be a boolean column in adata.var), written into adata.obs.
sc.pp.calculate_qc_metrics(adata, layer='counts', qc_vars=['mt'], inplace=True)

In [None]:
%%time
# Score the signature with every configured method; each call is expected to
# add an adata.obs column named after its `score_name`.
for method_cfg in scoring_methods:
    method_name, method_kwargs = method_cfg['scoring_method'], method_cfg['sc_params']
    score_signature(method=method_name, adata=adata, gene_list=gene_list, **method_kwargs)

In [None]:
# Optionally keep only malignant cells; copy() turns the AnnData view into an
# independent object.
if only_malignant:
    is_malignant = adata.obs['malignant_key'] == 'malignant'
    adata = adata[is_malignant, :].copy()

In [None]:
# Score column names, in the order of `scoring_methods`.
sc_names = [params['score_name'] for params in (cfg['sc_params'] for cfg in scoring_methods)]

In [None]:
# One scatter-plot row per scoring method: score vs. count/mt QC metrics.
# The full score name goes into the title: stripping the last '_' part left the
# title empty for names like "ANS" and merged "Seurat_AG"/"Seurat_LVG".
_file_suffix = '_only_mal' if only_malignant else ''
for curr_name in sc_names:
    title = f"LUAD scatterplots with {curr_name} signature scores vs. counts and mt"
    filename = f'LUAD/scatter_{curr_name}{_file_suffix}.png'
    # NOTE(review): unlike CRC/ESCC, 'total_counts' is not plotted for LUAD — confirm intent.
    create_and_save_plots(adata, curr_name, title, filename, show=False,
                          cols=['malignant_key', 'total_counts_mt', 'pct_counts_mt', 'n_genes_by_counts'])

In [None]:
# Pairwise (pandas default: Pearson) correlations; keep the score rows against
# the QC-metric columns, selected by label instead of a positional [0:-4] slice.
qc_cols = ['total_counts', 'total_counts_mt', 'pct_counts_mt', 'n_genes_by_counts']
corr_adata = adata.obs[sc_names + qc_cols].corr()
corr_adata = corr_adata.loc[sc_names, qc_cols]

In [None]:
# Draw on a dedicated figure rather than the implicit current pyplot axes, so
# the heatmap cannot end up on a figure left open by an earlier cell.
fig, ax = plt.subplots()
g = sns.heatmap(corr_adata, annot=True, ax=ax)
g.set_title('Correlations')
fig.tight_layout()
out_dir = os.path.join(storing_path, 'LUAD')
os.makedirs(out_dir, exist_ok=True)
fig.savefig(os.path.join(out_dir, 'correlation_heatmap_only_mal.png' if only_malignant else 'correlation_heatmap.png'), dpi=300)