In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
import warnings
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

sys.path.append('/Users/lciernik/Documents/TUB/projects/ans_scoring/ANS_supplementary_information')
from data.load_data import load_datasets

sys.path.append('..')

from score_with_all_methods import (
    score_signatures_with_all_methods,
    label_assignment_from_scores,
    get_lbl_assignment_performance,
    get_information_from_scores,
    remove_overlapping_signature_genes,
    get_violin_all_methods,
    prepare_data_for_violin_plot,
    save_close_or_show,
    plot_confusion_matrix
)

# Hide FutureWarnings raised by the libraries used in this notebook.
warnings.simplefilter('ignore', FutureWarning)

# Figure defaults: embed TrueType (Type 42) fonts so PDF text stays editable,
# and use 10 pt Arial for all text.
plt.rcParams.update({
    'pdf.fonttype': 42,
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'font.size': 10,
})

# Wu et al. breast cancer dataset
Wu, S.Z., Al-Eryani, G., Roden, D.L. et al. A single-cell and spatially resolved atlas of human breast cancers. Nat Genet 53, 1334–1347 (2021). https://doi.org/10.1038/s41588-021-00911-1

Paths, saving options and global settings

In [None]:
# Input/output locations. The defaults point at the original author's machine;
# set the ANS_DATA_DIR / ANS_RESULTS_DIR environment variables to run this
# notebook elsewhere without editing the cell.
base_data_path = Path(os.environ.get(
    'ANS_DATA_DIR',
    '/Users/lciernik/Documents/TUB/projects/ans_scoring/data/data_from_florian/',
))
results_root = Path(os.environ.get(
    'ANS_RESULTS_DIR',
    '/Users/lciernik/Documents/TUB/projects/ans_scoring/results',
))

# Whether to run `remove_overlapping_signature_genes` on the signatures before
# scoring (presumably drops genes shared by several signatures).
remove_overlapping_genes = True

# True: write figures and the metrics table to `storing_path`; False: only display them.
SAVE = False

# Results are kept in separate sub-folders for the two signature variants.
storing_path = results_root / 'cancer_datasets' / 'breast' / (
    'signatures_without_overlapping' if remove_overlapping_genes else 'signatures_with_overlapping'
)

if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)

Loading data

In [None]:
# Load the 'breast_malignant' dataset through the project loader
# (presumably the malignant cells of the Wu et al. atlas — confirm in data.load_data).
adata = load_datasets('breast_malignant')
adata

In [None]:
# obs columns holding the ground-truth label and the patient/sample identifier.
y_true_col, sample_col = 'gene_module', 'Patient'

In [None]:
# Rename the raw gene-module ids '1'..'7' to 'GM1'..'GM7'. Already-renamed labels
# map to themselves, so re-running this cell is a no-op instead of turning every
# label into NaN. Values outside both sets become NaN.
gene_module_names = {str(i): f'GM{i}' for i in range(1, 8)}
gene_module_names.update({label: label for label in list(gene_module_names.values())})
adata.obs[y_true_col] = adata.obs[y_true_col].map(gene_module_names)

In [None]:
# Number of cells per ground-truth gene module. dropna=False keeps a NaN row, so
# cells whose label could not be mapped by the relabeling step are visible here
# instead of being silently left out of the counts.
adata.obs[y_true_col].value_counts(dropna=False).sort_index()

Computing dimensionality reduction (currently disabled: the cells below are commented out)

In [None]:
# NOTE(review): disabled cell. Re-enabling it requires `import scanpy as sc` and
# `import scanpy.external as sce`, neither of which is in the import cell.
# The harmony integration here is run with `sample_col` as its key.
# sc.tl.pca(adata)
# sce.pp.harmony_integrate(adata, sample_col)
# sc.pp.neighbors(adata, use_rep='X_pca_harmony')
# sc.tl.umap(adata)

In [None]:
# NOTE(review): disabled cell. It needs scanpy imported as `sc` and a UMAP embedding
# (the UMAP computation cell is commented out as well). The file names 'umap.png' /
# 'umap.pdf' are also used by the per-method UMAP export further down, so enabling
# both would overwrite one with the other.
# fig = sc.pl.umap(adata, color=['celltype_minor', 'subtype','gene_module',sample_col], ncols=1, return_fig=True)
# if SAVE:
#     fig.savefig(storing_path / 'umap.png', bbox_inches='tight')
#     fig.savefig(storing_path / 'umap.pdf', bbox_inches='tight')

Load signatures

In [None]:
# Label-assignment option forwarded to `label_assignment_from_scores`
# (presumably: allow an 'undefined' label instead of always picking a signature).
include_undefined: bool = False

In [None]:
# Each CSV column is one signature; its non-missing cells are the signature genes,
# stored here as a sorted list per signature name.
signatures = {
    sig_name: sorted(gene_col.dropna().tolist())
    for sig_name, gene_col in pd.read_csv(base_data_path / 'annotations' / 'wu_6.csv').to_dict('series').items()
}
for cell_type, genes in signatures.items():
    print(cell_type, ' with signature length: ', len(genes))

In [None]:
# Optionally strip shared genes, controlled by `remove_overlapping_genes` in the config cell.
signatures = remove_overlapping_signature_genes(signatures) if remove_overlapping_genes else signatures

In [None]:
# Signature names in insertion order; reused as hue order and label order below.
order_signatures = [*signatures]

Scoring signatures

In [None]:
# Score every signature with every scoring method. `score_cols` maps each method
# name to the list of score columns it produced (iterated that way below).
score_cols, adata = score_signatures_with_all_methods(
    adata,
    signatures,
)

Label assignment

In [None]:
# Derive a predicted-label column from each method's signature scores.
# all_cols: every score column plus every predicted-label column (for plotting).
# label_cols: method name -> name of its predicted-label column.
all_cols = []
label_cols = {}
for method_name, method_scores in score_cols.items():
    # Forward the notebook-level `include_undefined` setting rather than a
    # hard-coded False, so changing the config flag actually takes effect.
    adata, new_lbl_col = label_assignment_from_scores(
        adata, method_name, method_scores, include_undefined=include_undefined
    )
    label_cols[method_name] = new_lbl_col
    all_cols += method_scores + [new_lbl_col]

In [None]:
# Show the score columns and the predicted-label column produced for each method.
score_cols, label_cols

Visualizing results

In [None]:
# NOTE(review): disabled cell. Re-enabling it needs scanpy imported as `sc` and a UMAP
# embedding (the UMAP computation cell is commented out). It writes the same
# 'umap.png' / 'umap.pdf' file names as the earlier UMAP export, so one would
# overwrite the other. Also, `plt.show(fig)` should be `plt.show()`: pyplot.show
# does not take a figure argument.
# fig = sc.pl.umap(adata, color=all_cols + [sample_col, y_true_col], ncols=len(signatures) + 1, return_fig=True)
# if SAVE:
#     fig.savefig(storing_path / 'umap.png', bbox_inches='tight')
#     fig.savefig(storing_path / 'umap.pdf', bbox_inches='tight')
#     plt.close(fig)
#     print(f"Saved UMAP.")
# else:
#     plt.show(fig)

In [None]:
# Reshape the per-method scores into the long-format table consumed by the
# violin plot (it has a 'Signature' column, inspected in the next cell).
df_melted = prepare_data_for_violin_plot(
    adata,
    y_true_col,
    score_cols,
)

In [None]:
# Sanity check: the signature names present in the long-format violin table.
df_melted['Signature'].unique()

In [None]:
### Combined violin plots
# Violins of every method's scores grouped by the true gene module, with one hue
# per signature (in `order_signatures` order); presumably one panel per method,
# wrapped into two columns.
fig = get_violin_all_methods(
    df_melted,
    y_true_col,
    hue_order=order_signatures,
    textwrap_width=7,
    sharey=False,
    height=1.95,
    aspect=2.5,
    wspace=0.075,
    col_wrap=2,
    legend_bbox_anchor=(1.125, 1),
)
# pyplot.show accepts only the keyword `block`. Passing the figure positionally is
# misinterpreted as `close` by the inline backend and rejected with a TypeError
# by the standard matplotlib backends.
plt.show()
# save_close_or_show(fig, SAVE, storing_path / "violin_all_methods.pdf")

In [None]:
# NOTE(review): disabled per-method violin plots. `get_violin` is not imported
# (the import cell only brings in `get_violin_all_methods`), so it must be imported
# before re-enabling. `plt.show(fig)` should also be `plt.show()`.
# for method_name, method_scores in score_cols.items():
#     df = adata.obs.loc[:, method_scores + [y_true_col]]
#     fig = get_violin(df, method_scores, y_true_col)
#     plt.title(f"{method_name}")
#     if SAVE:
#         fig.savefig(storing_path / f'violin_{method_name}.png', bbox_inches='tight')
#         fig.savefig(storing_path / f'violin_{method_name}.pdf', bbox_inches='tight')
#         plt.close(fig)
#         print(f"Saved violin plot for {method_name}.")
#     else:
#         plt.show(fig)

Computing label assignment performance

In [None]:
metrics = defaultdict(dict)
nfold = 10  # number of cross-validation folds passed to get_information_from_scores

for method_name, method_scores in score_cols.items():
    # Agreement between the assigned labels and the ground-truth gene modules.
    conf_mat, bal_acc, f1_val = get_lbl_assignment_performance(
        adata,
        y_true_col=y_true_col,
        y_pred_col=label_cols[method_name],
        label_names=order_signatures,
    )
    print(f"{method_name} - balanced accuracy: {bal_acc:.3f}, f1 score: {f1_val:.3f}")

    # Per-fold balanced accuracies of a classifier trained on the raw scores
    # (a logistic regression, judging by the metric names below).
    cv_scores = get_information_from_scores(adata, y_true_col=y_true_col, scores=method_scores, nfold=nfold)

    metrics[method_name] = dict(
        conf_mat=conf_mat,
        balanced_accuracy=bal_acc,
        f1_score=f1_val,
        **{
            f'logreg_balanced_accuracy_{nfold}cv': np.mean(cv_scores),
            f'logreg_balanced_accuracy_{nfold}cv_std': np.std(cv_scores),
        },
    )

    # Confusion matrix plot: saved, closed or shown depending on SAVE.
    fig = plot_confusion_matrix(conf_mat, order_signatures, method_name)
    save_close_or_show(fig, SAVE, storing_path / f'conf_mat_{method_name}.pdf')

In [None]:
# Dict of dicts -> one column per scoring method, one row per metric.
metrics_df = pd.DataFrame(data=metrics)

Saving performance metrics

In [None]:
# Persist the metrics table only when saving is enabled.
# NOTE(review): the 'conf_mat' row presumably holds array objects, which to_csv
# writes as their string representation — confirm this is acceptable downstream.
if SAVE:
    metrics_file = storing_path / 'metrics.csv'
    metrics_df.to_csv(metrics_file)
    print(f"Saved metrics to {metrics_file}.")