## Comparable score ranges experiment - *hard* discrimination task
The following notebook explores the comparability of score ranges for the *hard* task when using **non-overlapping** signatures for **two of the three** available B-cell subtypes (B-memory, B-naive). This setting explores whether scores and probabilities can be used for hard labeling if a cell does not belong to any of the signatures we are scoring for.

After selecting the cell-type-specific signatures, we score the signatures with each scoring method and apply hard labeling on the scores as well as on the probabilities returned by the GMM postprocessing.

This jupyter notebook uses the data and differentially expressed genes found [here](https://atlas.fredhutch.org/nygc/multimodal-pbmc/).

In [None]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
from statannotations.Annotator import Annotator

sys.path.append('../../..')
from data.load_data import load_datasets
from data.constants import BASE_PATH_EXPERIMENTS, BASE_PATH_DATA

from signaturescoring import score_signature
from signaturescoring.scoring_methods.gmm_postprocessing import GMMPostprocessor
from signaturescoring.utils.utils import check_signature_genes, get_mean_and_variance_gene_expression

sc.settings.verbosity = 2

In [None]:
# Inject a CSS rule that sets elements of class `container` to 80% width (display-only tweak).
from IPython.display import display, HTML
display(HTML("<style>.container { width:80% !important; }</style>"))

In [None]:
# Global matplotlib defaults: pdf.fonttype 42 stores text as TrueType (Type 42) fonts in PDF output,
# and all figures use 14 pt Arial as the sans-serif font.
plt.rcParams.update({'pdf.fonttype':42, 'font.family':'sans-serif', 'font.sans-serif':'Arial', 'font.size':14})

## Global variables

In [None]:
## dataset key and normalization method handed to `load_datasets`, plus the path of the
## table with the differentially expressed genes per cell type (read with pandas below)
dataset = 'pbmc_b_subtypes'
norm_method = 'mean'
DE_of_celltypes_fn = os.path.join(BASE_PATH_DATA, 'annotations/citeseq_pbmc/DE_by_celltype.csv')

In [None]:
## Output directory for every figure and table written by this notebook; created if missing.
storing_path = os.path.join(
    BASE_PATH_EXPERIMENTS,
    'comparable_score_ranges/B_cell_subtypes/scoring_two_of_three_b_cell_subtypes_nonoverlapping_signautres',
)
output_dir_exists = os.path.exists(storing_path)
if not output_dir_exists:
    os.makedirs(storing_path)
    sc.logging.info(f'Created new directory with path {storing_path}')

In [None]:
# When True, the figures and tables produced below are written under `storing_path`.
save = True

## Load preprocessed data

In [None]:
## load the dataset selected above through the project helper, using the configured normalization
adata = load_datasets(dataset, norm_method=norm_method)

In [None]:
# Make sure `adata.uns['log1p']['base']` exists and is None, whether or not a 'log1p' entry is present.
# NOTE(review): presumably guards against code that reads uns['log1p']['base'] — confirm.
if 'log1p' not in adata.uns_keys():
    adata.uns['log1p'] = {'base': None}
else:
    adata.uns['log1p']['base'] = None

In [None]:
# Count observations per `celltype.l2` label.
adata.obs['celltype.l2'].value_counts()

### Look at the differentially expressed genes given by the paper
The differential gene expression is done on level 3 celltypes. The logfoldchanges for the genes of different cell types are not comparable.

In [None]:
## read the table of differentially expressed genes (columns used below: 'Cell Type', 'Gene')
DE_of_celltypes = pd.read_csv(DE_of_celltypes_fn)

In this part we want to get the signatures for a specific level-2 cell type (B cells).

In [None]:
# All distinct cell-type names in the DE table that contain 'B ' (note the trailing space).
is_b_cell_row = DE_of_celltypes['Cell Type'].str.contains('B ')
subtypes_B = np.unique(DE_of_celltypes.loc[is_b_cell_row, 'Cell Type'])
subtypes_B

In [None]:
# Map every B-cell entry of the DE table to the list of its DE genes.
SG_subtypes_B = {
    subtype: list(DE_of_celltypes.loc[DE_of_celltypes['Cell Type'] == subtype, 'Gene'])
    for subtype in subtypes_B
}

In [None]:
# Merge the kappa and lambda gene lists of each scored subtype into one signature set.
# 'B intermediate' is deliberately not merged: only two of the three subtypes are scored here.
SG_subtypes_B['B memory'] = set(SG_subtypes_B['B memory kappa']) | set(SG_subtypes_B['B memory lambda'])
SG_subtypes_B['B naive'] = set(SG_subtypes_B['B naive kappa']) | set(SG_subtypes_B['B naive lambda'])

In [None]:
# Remove the original DE-table entries so that only the merged signatures remain.
for original_key in subtypes_B:
    if original_key in SG_subtypes_B:
        del SG_subtypes_B[original_key]

In [None]:
# Report the signature sizes after merging.
for subtype_name in SG_subtypes_B:
    n_genes = len(SG_subtypes_B[subtype_name])
    print(f'signature for B-cell subtype {subtype_name} has length {n_genes}')

Remove all overlapping genes

In [None]:
# Genes that occur in both signatures (both values are sets at this point).
intersection_memory_naive = SG_subtypes_B['B memory'] & SG_subtypes_B['B naive']
print('nr. sig. genes intersecting naive and memory ', len(intersection_memory_naive))

In [None]:
# Strip the shared genes from both signatures in place, making them non-overlapping.
for subtype_name in ('B memory', 'B naive'):
    SG_subtypes_B[subtype_name].difference_update(intersection_memory_naive)

In [None]:
# Report the signature sizes after removing the overlapping genes.
for subtype_name in SG_subtypes_B:
    n_genes = len(SG_subtypes_B[subtype_name])
    print(f'signature for B-cell subtype {subtype_name} has length {n_genes}')

Check signature genes expressed in the dataset

In [None]:
# Pass every signature through the project helper together with the dataset's gene names and
# store whatever it returns as the new signature (keys are snapshotted before reassigning).
for subtype_name in list(SG_subtypes_B):
    signature_genes = SG_subtypes_B[subtype_name]
    print(f'signature for B-cell subtype {subtype_name} has length {len(signature_genes)}')
    SG_subtypes_B[subtype_name] = check_signature_genes(adata.var_names, signature_genes)

In [None]:
# Build `df_mean_var`, the gene-indexed table that is later looked up per signature gene and passed
# to the scoring methods. Plots are shown inline; nothing is written because store_path is None.
plt.rcParams.update({'pdf.fonttype':42, 'font.family':'sans-serif', 'font.sans-serif':'Arial', 'font.size':14})
df_mean_var = get_mean_and_variance_gene_expression(
    adata,
    estim_var=True,
    show_plots=True,
    store_path=None,
#     store_path=storing_path,  # alternative: also write the helper's output to storing_path
    store_data_prefix='all'
)

In [None]:
# Re-apply the plotting defaults.
# NOTE(review): presumably needed because the helper above alters rcParams — confirm.
plt.rcParams.update({'pdf.fonttype':42, 'font.family':'sans-serif', 'font.sans-serif':'Arial', 'font.size':14})

In [None]:
# For every signature: plot the 'mean' column of df_mean_var, mark each signature gene at its row
# position, and keep only genes that are not among the last 50 rows of df_mean_var.
# NOTE(review): presumably df_mean_var is sorted by increasing mean expression, so the last rows
# are the highest-expressed genes — confirm against get_mean_and_variance_gene_expression.
for k,v in SG_subtypes_B.items():
    print(f'Signature for subtype {k} contains {len(v)} genes.')
    SG_subtypes_B[k] = list(v)
    plt.figure(figsize=(10,10))
    allowed_v = []  # signature genes that pass the position filter
    plt.plot(df_mean_var['mean'].values)
    for sig_gene in v:
        # first row position of the gene; this is a 1-element array, and raises IndexError if the gene is absent
        sig_gene_idx = np.argwhere(df_mean_var['mean'].index ==sig_gene)[0]
        
        if sig_gene_idx<= (df_mean_var.shape[0]-50):
            plt.axvline(sig_gene_idx,c='g')  # green: gene is kept
            allowed_v.append(sig_gene)
        else:
            plt.axvline(sig_gene_idx,c='r')  # red: gene lies within the last 50 rows and is dropped
    SG_subtypes_B[k] = allowed_v  
    # zoom onto the tail of the profile (last 100 rows plus some margin)
    plt.xlim([df_mean_var.shape[0]-100,df_mean_var.shape[0]+50])
    plt.title(f'avg. expression signature genes for {k}')
#     plt.savefig(os.path.join(storing_path, 'mean_expr_genes',f'{k}.png'), format = 'png')
    if save:
        path = os.path.join(storing_path, 'mean_expr_genes')
        if not os.path.exists(path):
            os.makedirs(path)
            sc.logging.info(f'Created new directory with path {path}')
        plt.savefig(os.path.join(path, f'{k}.png'), format = 'png')
    else:
        print('not storing image')
    plt.show()

In [None]:
# Signature sizes after the position-based filtering above.
for subtype_name in SG_subtypes_B:
    n_genes = len(SG_subtypes_B[subtype_name])
    print(f'Signature for subtype {subtype_name} contains {n_genes} genes.')

### Use the same gene pool for all signatures

In [None]:
# Union of the genes of all signatures.
all_sig_genes = set().union(*SG_subtypes_B.values())

In [None]:
# Every gene of the dataset that is not part of any signature; passed as `gene_pool` to the
# scoring methods below. Built in `adata.var_names` order with duplicates removed, so the pool
# has the same ordering on every run. Listing a set of strings instead would order the genes
# by Python's per-process hash randomization and change between kernel restarts.
gene_pool = list(dict.fromkeys(gene for gene in adata.var_names if gene not in all_sig_genes))

### Score marker genes (differentially expressed genes) for specific cell types of level 2 given by the paper

In [None]:
# Shared settings reused by the method configurations below as `n_bins` and `ctrl_size`.
n_bins = 25
n_ctrl_genes = 100

In [None]:
# One entry per scoring run: `scoring_method` is the method key handed to `score_signature`;
# `sc_params` are forwarded as keyword arguments, with the subtype name appended to `score_name`
# for each signature. The Scanpy, UCell and Jasmine entries do not receive the shared `gene_pool`.
scoring_methods = [
    {
        "scoring_method": "scanpy_scoring",
        "sc_params": {
            "ctrl_size": n_ctrl_genes,
            "n_bins": n_bins,
            "score_name": "Scanpy",
        },
    },
    {
        "scoring_method": "seurat_scoring",
        "sc_params": {
            "ctrl_size": n_ctrl_genes,
            "n_bins": n_bins,
            "score_name": "Seurat",
            "gene_pool":gene_pool
        },
    },
    {
        "scoring_method": "adjusted_neighborhood_scoring",
        "sc_params": {
            "ctrl_size": n_ctrl_genes,
            "score_name": "ANS",
            "gene_pool":gene_pool
        },
    },
    {
        "scoring_method": "seurat_ag_scoring",
        "sc_params": {
            "n_bins": n_bins,
            "score_name": "Seurat_AG",
            "gene_pool":gene_pool
        },
    },
    {
        "scoring_method": "seurat_lvg_scoring",
        "sc_params": {
            "ctrl_size": n_ctrl_genes,
            "n_bins": n_bins,
            "lvg_computation_version": "v1",
            "lvg_computation_method": "seurat",
            "score_name": "Seurat_LVG",
            "gene_pool":gene_pool
        },
    },
    {
        "scoring_method": "ucell_scoring",
        "sc_params": {
            "score_name": "UCell",
        },
    },
    # Jasmine is run twice, once per `score_method` variant.
    {
        "scoring_method": "jasmine_scoring",
        "sc_params": {
            "score_method": 'likelihood',
            "score_name": "Jasmine_LH",
        },
    },
    {
        "scoring_method": "jasmine_scoring",
        "sc_params": {
            "score_method": 'oddsratio',
            "score_name": "Jasmine_OR",
        },
    },
]

In [None]:
# Methods that the scoring loop calls without the precomputed `df_mean_var` table.
method_wo_mean = ['scanpy_scoring', 'corrected_scanpy_scoring','ucell_scoring','jasmine_scoring']

In [None]:
# Score-name prefixes; used below to recognise score columns in adata.obs and to order plot panels.
sc_names = ['ANS', 'Seurat', 'Seurat_AG', 'Seurat_LVG', 'Scanpy', 'Jasmine_LH', 'Jasmine_OR', 'UCell']

#### B-cells and subtypes
Here we only score B-cells and signatures that separate subtypes of B-cells. 

In [None]:
# Number of genes present in df_mean_var that are not part of the gene pool.
len(set(df_mean_var.index).difference(set(gene_pool)))

In [None]:
# Score every signature with every configured method; each call is given a score name of the
# form '<method score_name>_<subtype>'.
scoring_names = []
for method_cfg in scoring_methods:
    method_key = method_cfg['scoring_method']
    base_params = method_cfg['sc_params']

    print(f'Running scoring with scoring method {method_key}')

    # Only methods outside `method_wo_mean` receive the precomputed mean/variance table.
    extra_kwargs = {} if method_key in method_wo_mean else {'df_mean_var': df_mean_var}

    for subtype_name, signature_genes in SG_subtypes_B.items():
        print(f'   > Running scoring for signatures of celltyple-l2 {subtype_name}')

        # Copy so the shared configuration stays untouched, then make the score name subtype-specific.
        call_params = dict(base_params)
        call_params['score_name'] = '_'.join([call_params['score_name'], subtype_name])

        score_signature(
            method=method_key,
            adata=adata,
            gene_list=signature_genes,
            **extra_kwargs,
            **call_params
        )
        scoring_names.append(call_params['score_name'])
            

In [None]:
# Re-derive the score columns from adata.obs (this replaces the list collected while scoring):
# a column qualifies if any '_'-separated part is a known score-name prefix or equals 'Jasmine'.
scoring_names = [
    column
    for column in adata.obs.columns
    if any(part == 'Jasmine' or part in sc_names for part in column.split('_'))
]
scoring_names

In [None]:
# Fit one 3-component GMM per consecutive pair of score columns (the two signatures of one method).
for pair_start in range(0, len(scoring_names), 2):
    score_pair = scoring_names[pair_start:(pair_start + 2)]
    gmm_post = GMMPostprocessor(n_components=3)

    store_name_pred, store_names_proba, _ = gmm_post.fit_and_predict(adata, score_pair)
    assignments = gmm_post.assign_clusters_to_signatures(adata, score_pair, store_names_proba, plot=False)

    print(assignments)

    # Copy the column assigned to each signature to '<score name>_gmm_3K'.
    for signature_name, assigned_col in assignments.items():
        if signature_name != 'rest':
            adata.obs[signature_name + '_gmm_3K'] = adata.obs[assigned_col].copy()

    # The first 'rest' column is stored as the column of the unscored 'B intermediate' subtype.
    method_prefix = '_'.join(scoring_names[pair_start].split('_')[0:-1])
    rest_col = next(iter(assignments['rest']))
    adata.obs[method_prefix + '_B intermediate_gmm_3K'] = adata.obs[rest_col].copy()

In [None]:
# Remove the intermediate '_GMM_proba' / '_GMM_pred' columns; the '_gmm_3K' copies are kept.
gmm_helper_cols = [col for col in adata.obs.columns if '_GMM_proba' in col or '_GMM_pred' in col]
adata.obs = adata.obs.drop(columns=gmm_helper_cols)

In [None]:
# Collect score and GMM columns again, sorted alphabetically so that the columns of one method
# sit next to each other (the pairwise/triple-wise slicing below relies on this order).
scoring_names = sorted(
    column
    for column in adata.obs.columns
    if any(part == 'Jasmine' or part in sc_names for part in column.split('_'))
)

In [None]:
# Display the sorted column names for inspection.
scoring_names

In [None]:
# Separate raw score columns from the GMM columns ('..._gmm_3K').
score_name_wo_gmm, score_name_w_gmm = [], []
for column in scoring_names:
    if 'gmm' not in column:
        score_name_wo_gmm.append(column)
    if 'gmm_3K' in column:
        score_name_w_gmm.append(column)

### evaluate scores

In [None]:
# Wide table: ground-truth cell type plus one column per raw score.
tmp = adata.obs.loc[:, ['celltype.l2', *score_name_wo_gmm]]

In [None]:
# Long format: one row per (cell, score column), so seaborn can facet by scoring method.
tmp = pd.melt(tmp, id_vars=['celltype.l2'], var_name='scoring_method', value_name='scores')
tmp

In [None]:
# tmp = tmp.groupby(by=['celltype.l2', 'scoring_method']).mean().reset_index()

In [None]:
# Method name = score column name without its last '_'-separated part (the signature name).
tmp['scoring_method_short'] = ['_'.join(col_name.split('_')[0:-1]) for col_name in tmp['scoring_method']]

In [None]:
# Signature name = last '_'-separated part of the score column name.
tmp['Scoring for signature'] = [col_name.split('_')[-1] for col_name in tmp['scoring_method']]

In [None]:
# Sanity check: number of rows per scoring method.
tmp['scoring_method_short'].value_counts()

In [None]:
# Persist the long-format score table that backs the violin plot.
if save:
    violin_data_fn = os.path.join(storing_path, 'data_for_violin_plot_normal_scores.csv')
    tmp.to_csv(violin_data_fn)

In [None]:
# Order of the cell types on the x-axis and of the signature hues in the plots below.
order = ['B naive', 'B intermediate', 'B memory']

# Panel order of the scoring methods (same values as the list defined before scoring).
sc_names = ['ANS', 'Seurat', 'Seurat_AG', 'Seurat_LVG', 'Scanpy', 'Jasmine_LH', 'Jasmine_OR', 'UCell']

In [None]:
# y-axis tick positions for the raw-score violin plots, rounded to two decimals.
yticks = [round(tick, 2) for tick in (-1, -0.5, 0, 0.5, 1.0, 1.5)]

In [None]:
import textwrap
def wrap_labels(ax, width, break_long_words=False):
    labels = []
    for label in ax.get_xticklabels():
        text = label.get_text()
        labels.append(textwrap.fill(text, width=width,
                      break_long_words=break_long_words))
    ax.set_xticklabels(labels, rotation=0)

In [None]:
# One violin panel per scoring method (4 per row): score distributions per ground-truth cell type,
# coloured by the signature that was scored.
g = sns.catplot(data=tmp,
                x='celltype.l2', 
                y='scores', 
                hue='Scoring for signature', 
                hue_order=order,
                col_order=sc_names,
                col= 'scoring_method_short',  
                kind='violin', 
                col_wrap=4,
                order=order,
                legend=False
               )
g.set_ylabels('Scores', size=22)
g.set_titles("{col_name}", size=24)
g.set_xticklabels(order, size=22)
g.set(xlabel=None)
g.fig.subplots_adjust(top=0.88)
# Raw string: '\i' is not a valid Python escape sequence and triggers a warning in a plain literal.
g.fig.suptitle(r'$\it{Hard}$ task', fontsize=26)
g.add_legend(fontsize=22, title='Signature')
g.legend.get_title().set_fontsize(22)
g.set(yticks=yticks)
g.set_yticklabels(yticks, size=20)

# Wrap the long tick labels of the panels from index 4 onwards (the second row with col_wrap=4).
for ax in g.axes[4:]:
    wrap_labels(ax, 7, break_long_words=True)

if save:
    plt.savefig(os.path.join(storing_path, 'violin_plots_not_comparable_ranges.svg'), format='svg')

In [None]:
# Ground-truth labels that the hard-labeling metrics below are computed against.
gt = adata.obs['celltype.l2'].copy()
gt

In [None]:
from sklearn.preprocessing import OneHotEncoder


# One-hot encoding of the ground-truth labels.
# NOTE(review): `enc_df` is only referenced by a commented-out AUC computation further down — confirm it is still needed.
enc = OneHotEncoder(handle_unknown='ignore')
enc_df = pd.DataFrame(enc.fit_transform(adata.obs[['celltype.l2']]).toarray())
enc_df

##### Hard labeling
Hard labeling on scores needs a further step. For each signature we need to find a threshold indicating the activity of the cell. Above this threshold a cell is considered to express the cell type associated with the signature. If a cell has scores below all thresholds, the cell is called 'undefined'. To select these thresholds, we compute a histogram of the scores of each signature and select the local minimum that best separates high- and low-scoring cells.

In [None]:
from sklearn.metrics import f1_score, jaccard_score, balanced_accuracy_score

In [None]:
from scipy.signal import argrelmin

In [None]:
def make_prediction(scores, rest_label='B intermediate', selected_thresholds=None, show=False, save=True):
    """Plot score histograms with their local minima and, optionally, hard-label the rows.

    For each of the first two columns of ``scores`` a 100-bin density histogram is computed
    and drawn against the bin index. Every local minimum found by ``argrelmin(..., order=2)``
    is marked with a vertical line annotated with its bin index; these bin indices are what
    ``selected_thresholds`` refers to. A histogram without any local minimum is simply drawn
    without markers.

    Parameters
    ----------
    scores : pandas.DataFrame
        Frame whose first two columns hold the scores of the two signatures.
    rest_label : str
        Label assigned to rows whose scores lie below both thresholds.
    selected_thresholds : sequence of two ints, optional
        Histogram bin indices (first column, second column). The left edge of the indexed
        bin is used as the score threshold. If None, only the plot is produced.
    show : bool
        Display the figure with ``plt.show()``; otherwise the figure is closed.
    save : bool
        Write the figure into ``<storing_path>/score_hardlabeling_thresholds/``, where
        ``storing_path`` is the notebook-level output directory.

    Returns
    -------
    pandas.Series or None
        Per row the name of the column with the larger score, or ``rest_label`` where both
        scores are below their thresholds. None when ``selected_thresholds`` is None.
    """
    fig = plt.figure(figsize=(15, 6))
    bin_edges_per_column = []
    for col_pos, line_color in enumerate(('r', 'g')):
        col_name = scores.columns[col_pos]
        density, bin_edges = np.histogram(scores.iloc[:, col_pos].values, bins=100, density=True)
        bin_edges_per_column.append(bin_edges)
        plt.plot(density, label=col_name)

        # argrelmin returns a tuple of index arrays; the array is empty if there is no minimum.
        local_minima = argrelmin(density, order=2)[0]
        for rank, bin_idx in enumerate(local_minima):
            line_kwargs = {'c': line_color, 'alpha': 0.5}
            if rank == 0:
                # Label a single line per column so the legend gets exactly one entry for it.
                line_kwargs['label'] = f'mins {col_name}'
            plt.axvline(bin_idx, **line_kwargs)
            # Annotate at the peak height of this column's own histogram.
            plt.text(bin_idx, max(density), bin_idx)

    plt.legend(bbox_to_anchor=(1, 0.5))
    plt.tight_layout()
    if save:
        path = os.path.join(storing_path, 'score_hardlabeling_thresholds')
        if not os.path.exists(path):
            os.makedirs(path, exist_ok=True)
            sc.logging.info(f'Created new directory with path {path}')
        # File name = second column's name without its last '_'-separated part.
        fig.savefig(os.path.join(path, "_".join(scores.columns[1].split("_")[0:-1])))
    if show:
        plt.show()
    else:
        plt.close(fig)

    if selected_thresholds is None:
        return None

    # Translate the chosen bin indices into score thresholds (left bin edges).
    thresh_one = bin_edges_per_column[0][selected_thresholds[0]]
    thresh_two = bin_edges_per_column[1][selected_thresholds[1]]

    labels = scores.idxmax(axis=1)
    below_both = (scores.iloc[:, 0] < thresh_one) & (scores.iloc[:, 1] < thresh_two)
    labels.loc[below_both] = rest_label
    return labels

In [None]:
# Draw the histogram and local-minima plot for every method (pair of score columns) so the
# threshold bin indices can be chosen by eye. A failing pair is reported and skipped instead of
# being swallowed silently; the loop stays best-effort.
for i in range(0, len(score_name_wo_gmm), 2):
    score_pair = score_name_wo_gmm[i:(i+2)]
    try:
        make_prediction(adata.obs[score_pair], show=True, save=save)
    except Exception as err:
        print(f'Skipping threshold plot for {score_pair}: {type(err).__name__}: {err}')
    

In [None]:
# Manually selected histogram bin indices, stored flat as (first column, second column) per method
# and aligned with the order of `score_name_wo_gmm`: the evaluation loop slices both with the same offsets.
# Previously used values:
#selected_thresh = [54,43, 50,47, 55,40, 59,46, 53,44, 52,52, 49,49,  58,43, 60,46, 54,41, 50,47, 56,44,57,43,58,45,44,42]
selected_thresh = [55,43,
                   42,46,
                   52,27,
                   51,44,
                   58,44,
                   55,44,
                   58,45,
                   42,41,
                  ]

In [None]:
# Hard-label every method's pair of scores with the selected thresholds and compare to `gt`.
rows = []
for pair_start in range(0, len(score_name_wo_gmm), 2):
    score_pair = score_name_wo_gmm[pair_start:(pair_start + 2)]
    prediction = adata.obs[score_pair]

    # Thresholds are stored flat, so the same offsets address this pair's two bin indices.
    predicted_labels = make_prediction(
        prediction,
        selected_thresholds=selected_thresh[pair_start:(pair_start + 2)],
        save=False
    )

    # Collapse column names such as '<method>_B memory' to the plain cell-type label.
    for cell_type in ('B memory', 'B naive', 'B intermediate'):
        predicted_labels[predicted_labels.str.contains(cell_type)] = cell_type

    rows.append({
        'Scoring method': '_'.join(score_name_wo_gmm[pair_start].split('_')[0:-1]),
        'F1-score (weighted)': f1_score(gt, predicted_labels, average='weighted'),
        'Jaccard-score (weighted)': jaccard_score(gt, predicted_labels, average='weighted'),
        'Balanced accuracy': balanced_accuracy_score(gt, predicted_labels)
    })

In [None]:
# One row per scoring method with its three metrics.
performance_hard_labeling_on_scores = pd.DataFrame(rows)

In [None]:
# Display the metric table.
performance_hard_labeling_on_scores

In [None]:
# Persist the metric table of hard labeling on raw scores.
if save:
    performance_scores_fn = os.path.join(storing_path, 'performance_hard_labeling_on_scores.csv')
    performance_hard_labeling_on_scores.to_csv(performance_scores_fn)

In [None]:
# Long format for plotting: F1 stays a column (y-axis), the other metrics become (metric, value) rows.
# The result gets its own name so the wide table is not overwritten and this cell can be re-run.
performance_scores_long = performance_hard_labeling_on_scores.melt(
    id_vars=['Scoring method', 'F1-score (weighted)'],
    var_name='metric',
    value_name='value'
)

f = plt.figure(figsize=(8, 6))
g = sns.scatterplot(
    x='value',
    y='F1-score (weighted)',
    hue='Scoring method',
    hue_order=sc_names,
    style='metric',
    data=performance_scores_long,
    s=200
)
lgnd = g.legend(bbox_to_anchor=(1, 1), fontsize=16)
# Raw string: '\i' is not a valid Python escape sequence and triggers a warning in a plain literal.
g.set_title(r'Performance hard labeling using scores ($\it{hard}$ task)', fontsize=18)
g.set_xlabel('Values of metrics', fontsize=16)
g.set_ylabel('F1-score (weighted)', fontsize=16)
if save:
    f.savefig(os.path.join(storing_path, 'scores_hard_labeling.svg'), format='svg')
    f.savefig(os.path.join(storing_path, 'scores_hard_labeling.png'), format='png', dpi=300)

### evaluate GMM outcome

In [None]:
# Wide table: ground-truth cell type plus one column per GMM ('_gmm_3K') output.
tmp = adata.obs.loc[:, ['celltype.l2', *score_name_w_gmm]]

In [None]:
# Long format: one row per (cell, GMM column).
tmp = pd.melt(tmp, id_vars=['celltype.l2'], var_name='scoring_method', value_name='scores')
tmp

In [None]:
# Method label = column name without its last three '_'-separated parts ('<signature>_gmm_3K').
tmp['scoring_method_short'] = [
    '_'.join(col_name.split('_')[0:-3]) + ' with GMM 3K' for col_name in tmp['scoring_method']
]

In [None]:
# Signature name = third-last '_'-separated part of the GMM column name.
tmp['Scoring for signature'] = [col_name.split('_')[-3] for col_name in tmp['scoring_method']]

In [None]:
# Sanity check: number of rows per method label before filtering.
tmp['scoring_method_short'].value_counts()

In [None]:
# Keep only rows whose method label does not contain 'std_adjust'.
without_std_adjust = ~tmp['scoring_method_short'].str.contains('std_adjust')
tmp = tmp[without_std_adjust]

In [None]:
# Sanity check: number of rows per method label after filtering.
tmp['scoring_method_short'].value_counts()

In [None]:
# Violin plots of the GMM outputs: one panel per scoring method, all in a single row.
yticks = [0, 0.2, 0.4, 0.6, 0.8, 1.0]
sc_names = [
    f'{method} with GMM 3K'
    for method in ('ANS', 'Seurat', 'Seurat_AG', 'Seurat_LVG', 'Scanpy', 'Jasmine_LH', 'Jasmine_OR', 'UCell')
]

# Plot only rows whose method label does not contain 'var adjustment'.
without_var_adjustment = ~tmp['scoring_method_short'].str.contains('var adjustment')
g = sns.catplot(
    data=tmp[without_var_adjustment],
    x='celltype.l2',
    y='scores',
    hue='Scoring for signature',
    hue_order=order,
    col='scoring_method_short',
    col_order=sc_names,
    kind='violin',
    order=order
)
g.set_ylabels('Scores', size=22)
g.set_titles("{col_name}", size=24)
g.set_xticklabels(order, size=22)
g.set(xlabel=None)
g.set(yticks=yticks)
g.set_yticklabels(yticks, size=20)
# Without col_wrap the axes form a 2-D grid; wrap the tick labels of its first row.
for panel_ax in g.axes[0]:
    wrap_labels(panel_ax, 7, break_long_words=True)
if save:
    gmm_violin_fn = os.path.join(storing_path, 'violin_plots_not_comparable_ranges_GMM.svg')
    plt.savefig(gmm_violin_fn, format='svg')

In [None]:
# Ground-truth labels for the GMM-based evaluation (same definition as in the score evaluation).
gt = adata.obs['celltype.l2'].copy()
gt

In [None]:
from sklearn.preprocessing import OneHotEncoder


# One-hot encoding of the ground-truth labels.
# NOTE(review): `enc_df` is only referenced by the commented-out AUC computation below — confirm it is still needed.
enc = OneHotEncoder(handle_unknown='ignore')
enc_df = pd.DataFrame(enc.fit_transform(adata.obs[['celltype.l2']]).toarray())
enc_df

In [None]:
from sklearn.metrics import f1_score, jaccard_score, balanced_accuracy_score

In [None]:
# Drop GMM columns whose name contains 'std_adjust'.
score_name_w_gmm = list(filter(lambda column: 'std_adjust' not in column, score_name_w_gmm))
score_name_w_gmm

In [None]:
# Hard-label every method by the largest of its three GMM columns and compare to `gt`.
rows = []
for triple_start in range(0, len(score_name_w_gmm), 3):
    gmm_cols = score_name_w_gmm[triple_start:(triple_start + 3)]
    prediction = adata.obs[gmm_cols]

    predicted_labels = prediction.idxmax(axis=1)
    # Collapse column names such as '<method>_B memory_gmm_3K' to the plain cell-type label.
    for cell_type in ('B memory', 'B naive', 'B intermediate'):
        predicted_labels[predicted_labels.str.contains(cell_type)] = cell_type

    rows.append({
        'Scoring method': '_'.join(score_name_w_gmm[triple_start].split('_')[0:-3]) + ' with GMM 3K',
        'F1-score (weighted)': f1_score(gt, predicted_labels, average='weighted'),
        'Jaccard-score (weighted)': jaccard_score(gt, predicted_labels, average='weighted'),
        'Balanced accuracy': balanced_accuracy_score(gt, predicted_labels)
    })

In [None]:
# One row per scoring method (with GMM) and its three metrics; reuses the variable name of the score-based table.
performance_hard_labeling_on_scores = pd.DataFrame(rows)

In [None]:
# Display the methods ranked by balanced accuracy (the sorted result is not stored).
performance_hard_labeling_on_scores.sort_values(by='Balanced accuracy', ascending=False)

In [None]:
# Persist the metric table of hard labeling on the GMM outputs.
if save:
    performance_gmm_fn = os.path.join(storing_path, 'performance_hard_labeling_on_GMM.csv')
    performance_hard_labeling_on_scores.to_csv(performance_gmm_fn)

In [None]:
# Long format for plotting: F1 stays a column, the remaining metrics become (metric, value) rows.
performance_hard_labeling_on_scores = pd.melt(
    performance_hard_labeling_on_scores,
    id_vars=['Scoring method', 'F1-score (weighted)'],
    var_name='metric',
    value_name='value'
)

In [None]:
# Scatter of F1 (y) against the other metrics (x), one colour per scoring method and one marker per metric.
f = plt.figure(figsize=(8, 6))
g = sns.scatterplot(
    x='value',
    y='F1-score (weighted)',
    hue='Scoring method',
    hue_order=sc_names,
    style='metric',
    data=performance_hard_labeling_on_scores,
    s=200
)
lgnd = g.legend(bbox_to_anchor=(1, 1), fontsize=16)
# Raw string: '\i' is not a valid Python escape sequence and triggers a warning in a plain literal.
g.set_title(r'Performance hard labeling using probabilities ($\it{hard}$ task)', fontsize=18)
g.set_xlabel('Values of metrics', fontsize=16)
g.set_ylabel('F1-score (weighted)', fontsize=16)
if save:
    f.savefig(os.path.join(storing_path, 'GMM3_hard_labeling.svg'), format='svg')
    f.savefig(os.path.join(storing_path, 'GMM3_hard_labeling.png'), format='png', dpi=300)