In [None]:
import sys
import warnings
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from pbmc_helper import load_dex_genes

sys.path.append('..')
sys.path.append('/Users/lciernik/Documents/TUB/projects/ans_scoring/ANS_supplementary_information')
from data.load_data import load_datasets

from score_with_all_methods import (
    score_signatures_with_all_methods,
    label_assignment_from_scores,
    get_lbl_assignment_performance,
    get_information_from_scores,
    remove_overlapping_signature_genes,
    get_violin_all_methods,
    prepare_data_for_violin_plot,
    save_close_or_show,
    plot_confusion_matrix
)

# Suppress FutureWarning messages for the rest of the session so cell outputs stay readable.
warnings.simplefilter('ignore', FutureWarning)

# Global matplotlib defaults applied to every figure created below.
plt.rcParams.update({
    'pdf.fonttype': 42,  # Type 42 (TrueType) font embedding for PDF output
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'font.size': 10,
})

Storage settings and global variables

In [None]:
# Switch: when True, the signatures are passed through
# `remove_overlapping_signature_genes` before scoring.
remove_overlapping_genes = True

# Switch: when True, the output directory is created and results are written to disk.
SAVE = False

# Output directory; the last path component records which signature variant was used.
# NOTE(review): absolute, machine-specific path — adjust before running on another machine.
storing_path = Path('/Users/lciernik/Documents/TUB/projects/ans_scoring/results/citeseq/b_subtypes')
storing_path = storing_path / (
    'signatures_without_overlapping' if remove_overlapping_genes else 'signatures_with_overlapping'
)

# Only touch the file system when saving is requested.
if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)

Loading data

In [None]:
# Load the dataset registered under the key 'pbmc_b_subtypes' through the project loader.
# NOTE(review): presumably an AnnData object (later cells read `adata.obs`) — confirm in data.load_data.
adata = load_datasets('pbmc_b_subtypes')
# Bare last expression: display the loaded object's summary as the cell output.
adata

In [None]:
# `sample_col` is only referenced by the (currently disabled) UMAP cell.
# NOTE(review): presumably the per-sample identifier column in `adata.obs` — confirm.
sample_col = 'orig.ident'

# Column of `adata.obs` holding the reference labels that predictions are compared against.
y_true_col = 'celltype.l2'

In [None]:
# Number of observations per reference label, ordered by label instead of by frequency.
label_counts = adata.obs[y_true_col].value_counts()
label_counts.sort_index()

Computing dimensionality reduction

In [None]:
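# NOTE: disabled. These calls use `sc`, which the import cell does not provide;
# add `import scanpy as sc` there before re-enabling.
# sc.tl.pca(adata)
# sc.pp.neighbors(adata)
# sc.tl.umap(adata)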

Loading signatures

In [None]:
# Differential-expression gene table; later cells read its 'Cell Type' and 'Gene' columns.
# NOTE(review): the thresholds presumably bound the p-value and the log2 fold change of the
# genes that are kept — confirm in pbmc_helper.load_dex_genes.
DE_of_celltypes = load_dex_genes(
    filter_genes=True,
    threshold_pval=0.01,
    threshold_log2fc=0.5,
)

In [None]:
# Map every reference label to the list of 'celltype.l3' subtypes that occur under it.
#
# * The grouping key is `y_true_col` rather than a repeated 'celltype.l2' literal, so the
#   signature names always match the labels used for evaluation further down.
# * `observed=True` restricts the result to labels that actually occur in this subset.
#   Without it, a categorical key with unused categories yields NaN entries, on which
#   the later `.isin(subtypes)` lookup fails.
subtypes_per_cell_type = (
    adata.obs
    .groupby(y_true_col, observed=True)['celltype.l3']
    .apply(lambda x: list(x.unique()))
)

In [None]:
# Display the label -> subtype-list mapping as a plain dict for inspection.
subtypes_per_cell_type.to_dict()

In [None]:
# Build one gene signature per label: the sorted, de-duplicated union of the genes
# listed in the DE table for any of that label's subtypes.
signatures = {}
for cell_type, subtypes in subtypes_per_cell_type.items():
    belongs_to_label = DE_of_celltypes['Cell Type'].isin(subtypes)
    label_genes = DE_of_celltypes[belongs_to_label]['Gene'].tolist()
    signatures[cell_type] = sorted(set(label_genes))

In [None]:
# Report the number of genes in each signature.
for signature_name, signature_genes in signatures.items():
    print(signature_name, len(signature_genes))

In [None]:
# Optional clean-up step, controlled by the `remove_overlapping_genes` flag.
# NOTE(review): presumably drops genes that occur in more than one signature — confirm in
# score_with_all_methods. `signatures` is rebound here, so re-running only this cell
# feeds the helper its own previous output.
if remove_overlapping_genes:
    signatures = remove_overlapping_signature_genes(signatures)

In [None]:
# Fixed display/evaluation order of the signatures: used as hue order in the violin
# plot and as label order for the confusion matrices.
order_signatures = [
    'B naive',
    'B intermediate',
    'B memory',
]

Scoring signatures

In [None]:
# Score every signature with every scoring method provided by the project helper.
# `score_cols` is iterated below as a mapping: method name -> list of score column names.
# NOTE(review): the score columns are presumably stored in `adata.obs` — confirm in
# score_with_all_methods. `adata` is rebound to the object returned by the helper.
score_cols, adata = score_signatures_with_all_methods(adata, signatures)

Label assignment

In [None]:
# For each scoring method, derive a predicted-label column from its score columns.
#   label_cols : method name -> name of the predicted-label column
#   all_cols   : flat list of every score column followed by that method's label column
#                (only used by the disabled UMAP cell)
label_cols = {}
all_cols = []
for method_name, method_scores in score_cols.items():
    adata, new_lbl_col = label_assignment_from_scores(
        adata,
        method_name,
        method_scores,
        include_undefined=False,
    )
    all_cols.extend(method_scores + [new_lbl_col])
    label_cols[method_name] = new_lbl_col

Visualizing results

In [None]:
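### UMAP
# NOTE: disabled. Requires `import scanpy as sc` in the import cell and the
# dimensionality-reduction cell above to be enabled first.
# fig = sc.pl.umap(adata, color=all_cols + [sample_col, y_true_col, 'celltype.l1', 'celltype.l3'],
#                  ncols=len(signatures) + 1, return_fig=True)
# save_close_or_show(fig, SAVE, storing_path / 'umap.pdf')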

In [None]:
# Reshape the scores for plotting with the project helper.
# NOTE(review): presumably a long-format frame with one row per (observation, signature,
# method); the next cell reads its 'Signature' column — confirm in score_with_all_methods.
df_melted = prepare_data_for_violin_plot(adata, y_true_col, score_cols)

In [None]:
# Sanity check: list the distinct values of the 'Signature' column.
df_melted['Signature'].unique()

In [None]:
### Combined violin plots
# Layout and typography options for the combined figure, collected in one place.
violin_options = dict(
    hue_order=order_signatures,
    textwrap_width=7,
    aspect=1.05,
    sharey=True,
    legend_bbox_anchor=(1.075, 1),
    fontsizes={'title': 12, 'labels': 11, 'ticks': 10, 'legend': 11},
)

# One figure covering all scoring methods, with signatures in the fixed `order_signatures` order.
fig = get_violin_all_methods(df_melted, y_true_col, **violin_options)
save_close_or_show(fig, SAVE, storing_path / "violin_all_methods.pdf")

In [None]:
### Single violin plots
# NOTE: disabled. `get_violin` is not among the names imported from
# score_with_all_methods; import it before re-enabling.
# for method_name, method_scores in score_cols.items():
#     df = adata.obs.loc[:, method_scores + [y_true_col]]
#     fig = get_violin(df, method_scores, y_true_col)
#     plt.title(f"{method_name}")
#     save_close_or_show(fig, SAVE, storing_path / f'violin_{method_name}.pdf')

Computing label assignment performance

In [None]:
# Evaluation results per scoring method: method name -> dict of metric name -> value.
metrics = defaultdict(dict)

# Fold count handed to `get_information_from_scores`; it is also part of the metric key names.
nfold = 10

# NOTE(review): initialised but never updated or read in the visible cells.
overall_min, overall_max = np.inf, -np.inf

for method_name, method_scores in score_cols.items():
    # Compare the method's predicted labels with the reference labels.
    lbl_col = label_cols[method_name]
    conf_mat, bal_acc, f1_val = get_lbl_assignment_performance(
        adata,
        y_true_col=y_true_col,
        y_pred_col=lbl_col,
        label_names=order_signatures,
    )

    # Per-fold values that are summarised by mean and standard deviation below.
    # NOTE(review): judging by the key names these are balanced accuracies of a
    # cross-validated logistic regression on the scores — confirm in score_with_all_methods.
    scores = get_information_from_scores(
        adata,
        y_true_col=y_true_col,
        scores=method_scores,
        nfold=nfold,
    )
    cv_mean = np.mean(scores)
    cv_std = np.std(scores)

    metrics[method_name] = {
        'conf_mat': conf_mat,
        'balanced_accuracy': bal_acc,
        'f1_score': f1_val,
        f'logreg_balanced_accuracy_{nfold}cv': cv_mean,
        f'logreg_balanced_accuracy_{nfold}cv_std': cv_std,
    }

    ## Confusion matrix plot
    # fig = plot_confusion_matrix(conf_mat, order_signatures, method_name, figsize=(2.3, 2.3), textwrap_width=7,
    #                             xrotation=45, cbar=False)
    fig = plot_confusion_matrix(conf_mat, order_signatures, method_name)
    save_close_or_show(fig, SAVE, storing_path / f'conf_mat_{method_name}.pdf')

In [None]:
# Tabulate the results: one column per scoring method, one row per metric name.
# The 'conf_mat' row holds whole confusion-matrix objects, not scalars.
metrics_df = pd.DataFrame(metrics)

Saving performance metrics

In [None]:
# Persist the metrics table next to the figures, but only when saving is enabled.
if SAVE:
    metrics_file = storing_path / 'metrics.csv'
    metrics_df.to_csv(metrics_file)
    print(f"Saved metrics to {metrics_file}.")