In [178]:
import os
from pathlib import Path

import matplotlib.path as mpath
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from score_with_all_methods import save_close_or_show

# Global figure styling: seaborn "ticks" theme, Arial at 10 pt, and PDF fonts
# embedded as TrueType (Type 42) rather than Type 3.
sns.set_style('ticks')

plt.rcParams.update(
    {
        'pdf.fonttype': 42,
        'font.family': 'sans-serif',
        'font.sans-serif': 'Arial',
        'font.size': 10,
    }
)

In [179]:
# Root folder with the benchmark result trees (searched for metrics.csv) and where figures are saved.
# Set the ANS_SCORING_RESULTS environment variable to run this notebook on another machine.
res_root = Path(os.environ.get('ANS_SCORING_RESULTS', '/Users/lciernik/Documents/TUB/projects/ans_scoring/results'))
# Flag passed to save_close_or_show; True writes the figures as PDFs under res_root.
SAVE = True

In [180]:
def load_results(res_root):
    """Load every ``metrics.csv`` below ``res_root`` into one long-format DataFrame.

    Each CSV holds one row per metric (metric name in its first column) and one
    column per scoring method. Experiment metadata is derived from substrings of
    the file's *full* path, so the name of ``res_root`` itself counts:

    * ``signatures``  - 'Original signatures' if the path contains
      'signatures_with_overlapping', else 'Signatures without overlapping genes'.
    * ``scored_on``   - 'Scoring on all cells' if the path contains
      'scored_on_all', else 'Scoring on subtypes'.
    * ``tissue_type`` - 'Cancer' if the path contains 'cancer' (e.g. a
      'cancer_datasets*' root folder), else 'PBMC'.
    * ``Cell type``   - label of the first matching key in ``cell_type_by_key``,
      else 'Unknown'.

    Parameters
    ----------
    res_root : pathlib.Path or str
        Directory searched recursively for ``metrics.csv`` files.

    Returns
    -------
    pandas.DataFrame
        Columns ``metric``, ``signatures``, ``scored_on``, ``tissue_type``,
        ``Cell type``, ``Scoring method`` and ``metric_value``; rows whose metric
        is ``conf_mat`` are dropped.

    Raises
    ------
    ValueError
        If no ``metrics.csv`` file exists below ``res_root``.
    """
    res_root = Path(res_root)

    # Ordered (path substring, label) pairs; the first match wins.
    cell_type_by_key = (
        ('cd4', 'CD4 T-cell subtypes'),
        ('cd8', 'CD8 T-cell subtypes'),
        ('b_subtypes', 'B-cell subtypes'),
        ('b_mono_nk', 'B, Monocytes, NK cells'),
        ('breast', 'BRCA (6 states)'),
        ('skin', 'cSCC (4 states)'),
        ('lung', 'LUAD (3 states)'),
        ('ovarian', 'HGSOC (8 states)'),
    )
    meta_cols = ['signatures', 'scored_on', 'tissue_type', 'Cell type']

    res = []
    # Sorted so the concatenation order does not depend on the filesystem.
    for f in sorted(res_root.glob('**/metrics.csv')):
        path_str = str(f)
        df = pd.read_csv(f)

        # First column holds the metric names; the remaining ones are scoring methods.
        df.columns = ['metric'] + list(df.columns[1:])
        method_cols = list(df.columns[1:])

        df['signatures'] = (
            'Original signatures' if 'signatures_with_overlapping' in path_str
            else 'Signatures without overlapping genes'
        )
        df['scored_on'] = (
            'Scoring on all cells' if 'scored_on_all' in path_str
            else 'Scoring on subtypes'
        )
        df['tissue_type'] = 'Cancer' if 'cancer' in path_str else 'PBMC'
        df['Cell type'] = next(
            (label for key, label in cell_type_by_key if key in path_str), 'Unknown'
        )

        # Explicit id/value columns instead of positional slicing of the appended metadata.
        df = df.melt(
            id_vars=['metric'] + meta_cols,
            value_vars=method_cols,
            var_name='Scoring method',
            value_name='metric_value',
        )
        # Confusion matrices are not scalar metrics.
        res.append(df[df['metric'] != 'conf_mat'])

    if not res:
        raise ValueError(f'No metrics.csv files found under {res_root}')
    return pd.concat(res)

In [181]:
# Cancer and PBMC (CITE-seq) results live in separate result trees; load each
# and stack them into a single long-format table.
res_pbmc = load_results(res_root / "citeseq")
res_cancer = load_results(res_root / "cancer_datasets_bak")

res = pd.concat([res_cancer, res_pbmc])

In [183]:
# res.columns = ['metric', 'ANS', 'Scanpy', 'signatures', 'scored_on', 'Cell type']
# res = res[['Cell type', 'scored_on', 'signatures', 'metric', 'ANS', 'Scanpy']]

In [184]:
# Cast metric values to float (presumably stored as object dtype because the CSV
# columns also contained confusion matrices), then drop the standard-deviation metrics.
res = res.assign(metric_value=res['metric_value'].astype(float))
is_std_metric = res['metric'].str.contains('std', regex=False)
res = res.loc[~is_std_metric]
res.shape

(384, 7)

In [185]:
# Wide format: one row per (tissue, cell type, scoring mode, signature set, method),
# one column per metric.
res = (
    pd.pivot(res,
             index=['tissue_type', 'Cell type', 'scored_on', 'signatures', 'Scoring method'],
             columns='metric',
             values='metric_value')
    .reset_index()
    # reset_index() already leaves an unnamed RangeIndex; the leftover 'metric'
    # label sits on the columns axis, so that is the name to clear.
    .rename_axis(columns=None)
)

In [187]:
# To restrict the comparison, set e.g. allowed_scoring_methods = ['ANS', 'Scanpy', 'UCell'].
allowed_scoring_methods = list(res['Scoring method'].unique())
# .copy() makes res_plt an independent frame, so adding a column below is not a
# chained assignment on a filtered slice (no SettingWithCopyWarning).
res_plt = res[res['Scoring method'].isin(allowed_scoring_methods)].copy()

# Difference between the 10-fold-CV logistic-regression balanced accuracy and the
# plain balanced accuracy of each scoring method.
res_plt['Score imbalance'] = res_plt['logreg_balanced_accuracy_10cv'] - res_plt['balanced_accuracy']

In [188]:
# Fixed display order of the scoring methods (plot markers and table rows); ANS first.
style_order = 'ANS Seurat Seurat_AG Seurat_LVG Scanpy Jasmine_LH Jasmine_OR UCell'.split()

## Plotting

In [189]:
def plot_scatter(df, scale_imbalance):
    """Scatter balanced accuracy against logistic-regression balanced accuracy.

    One facet row per ``tissue_type`` and one facet column per ``signatures``
    value. Points are coloured by ``Cell type`` and get one marker per
    ``Scoring method`` (ordered by ``style_order``). Every facet gets a dashed
    y = x reference line, and ``scale_imbalance`` is drawn as a table to the
    right of the second facet.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns 'balanced_accuracy', 'logreg_balanced_accuracy_10cv',
        'Scoring method', 'Cell type', 'signatures' and 'tissue_type'.
    scale_imbalance : pandas.DataFrame
        Table indexed by scoring method, containing every entry of
        ``style_order`` (e.g. the output of ``get_scale_imbalance``).

    Returns
    -------
    matplotlib.figure.Figure
        Figure of the underlying seaborn FacetGrid.
    """
    g = sns.relplot(
        data=df,
        x='balanced_accuracy',
        y='logreg_balanced_accuracy_10cv',
        style='Scoring method',
        style_order=style_order,
        hue='Cell type',
        col='signatures',
        row='tissue_type',
        height=2.5,
        aspect=1.2,
        s=100,
        alpha=0.7,
        facet_kws={'sharey': False, 'sharex': False},
    )

    # Facet titles show only the tissue type (row variable).
    g.set_titles("{row_name}", fontsize=10)
    g.set_axis_labels('Balanced Accuracy', 'Information quantity (logreg)', fontsize=10)

    # Outline ANS (first entry of style_order) in black and all other methods in light grey.
    # Assumes each facet's scatter collection lists points in repeating blocks of
    # len(style_order) methods with ANS first, as produced by the sorted pivot upstream.
    edge_colors = ['black'] + ['lightgrey'] * (len(style_order) - 1)

    for ax in g.axes.flat:
        # Dashed y = x line spanning the union of both axis ranges.
        vmin = min(ax.get_xlim()[0], ax.get_ylim()[0])
        vmax = max(ax.get_xlim()[1], ax.get_ylim()[1])
        diagonal = np.linspace(vmin, vmax, 100)
        ax.plot(diagonal, diagonal, color='grey', linestyle='--', alpha=0.5)

        points = ax.collections[0]
        n_blocks = points.get_facecolors().shape[0] // len(style_order)
        points.set_edgecolor(edge_colors * n_blocks)

        for spine in ax.spines.values():
            spine.set_linewidth(0.85)
        ax.tick_params(axis='both', labelsize=8, length=2.5, width=0.85)

    g.figure.subplots_adjust(hspace=0.25)

    # Show the 'Cell type' legend entries as colour swatches instead of point markers.
    # Entry 0 is skipped (presumably the hue section title). The count comes from
    # the data rather than a fixed 8, so subsets with fewer cell types stay correct.
    rectangle = mpath.Path(
        [(-1.5, -0.5), (1.5, -0.5), (1.5, 0.5), (-1.5, 0.5), (-1.5, -0.5)],
        [mpath.Path.MOVETO, mpath.Path.LINETO, mpath.Path.LINETO, mpath.Path.LINETO, mpath.Path.CLOSEPOLY],
    )
    n_hue = df['Cell type'].nunique()
    for line in g.legend.get_lines()[1:n_hue + 1]:
        line.set_marker(rectangle)
        line.set_markersize(17)

    sns.move_legend(g, "upper left", bbox_to_anchor=(0.65, 0.99), frameon=False, fontsize=10, ncols=2)

    # Score-imbalance summary (rows in style_order) placed to the right of the second facet.
    imbalance_table = scale_imbalance.loc[style_order, :]
    table = g.axes.flat[1].table(
        cellText=imbalance_table.values,
        rowLabels=imbalance_table.index,
        colLabels=imbalance_table.columns,
        fontsize=10,
        cellLoc='center',
        bbox=[1.7, 0, 1.1, 1.1],  # [x, y, width, height] in axes coordinates
    )
    for cell in table.get_celld().values():
        cell.set_linewidth(0.85)

    return g.figure


def get_scale_imbalance(df):
    """Summarise 'Score imbalance' as a 'mean ± std' string per scoring method and tissue.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns 'tissue_type', 'Scoring method' and 'Score imbalance'.

    Returns
    -------
    pandas.DataFrame
        Indexed by 'Scoring method', with one column per 'tissue_type'. Each cell
        holds the sample mean and standard deviation (ddof=1) with three decimals,
        e.g. '0.012 ± 0.034'. Method/tissue combinations without data are NaN.
    """
    stats = (
        df.groupby(['tissue_type', 'Scoring method'])['Score imbalance']
        .agg(['mean', 'std'])
        .reset_index()
    )
    # Fixed three-decimal formatting keeps the table consistent; round(3).astype(str)
    # would render 0.1 as '0.1' next to '0.123'.
    stats['mean_std'] = [f'{m:.3f} ± {s:.3f}' for m, s in zip(stats['mean'], stats['std'])]
    return stats.pivot(index='Scoring method', columns='tissue_type', values='mean_std')


In [190]:
# File-name tag: 'all' when no scoring method was filtered out, otherwise the
# retained method names joined by underscores.
if len(allowed_scoring_methods) == res['Scoring method'].nunique():
    sc_methods = 'all'
else:
    sc_methods = '_'.join(allowed_scoring_methods)

### Original signatures (overlapping genes retained)

In [191]:
# Runs scored with the original signatures (genes shared between signatures kept).
res_w_overlapping = res_plt.loc[res_plt['signatures'] == 'Original signatures'].copy()
scale_imbalance_pivot = get_scale_imbalance(res_w_overlapping)

fig = plot_scatter(res_w_overlapping, scale_imbalance_pivot)
fig_path = res_root / f'metrics_{sc_methods}_orig_sigs.pdf'
save_close_or_show(fig, SAVE, fig_path)


Saved figure at /Users/lciernik/Documents/TUB/projects/ans_scoring/results/metrics_all_orig_sigs.pdf.


### Signatures without overlapping genes

In [192]:
# Runs scored with signatures from which overlapping genes were removed.
res_wo_overlapping = res_plt.loc[res_plt['signatures'] == 'Signatures without overlapping genes'].copy()
scale_imbalance_pivot = get_scale_imbalance(res_wo_overlapping)

fig = plot_scatter(res_wo_overlapping, scale_imbalance_pivot)
fig_path = res_root / f'metrics_{sc_methods}_wo_overlap_sigs.pdf'
save_close_or_show(fig, SAVE, fig_path)

Saved figure at /Users/lciernik/Documents/TUB/projects/ans_scoring/results/metrics_all_wo_overlap_sigs.pdf.


In [193]:
# fig = plt.figure(
#     figsize=(8, 6)
# )
# g= sns.boxplot(
#     res_wo_overlapping,
#     x='Scoring method',
#     y='Score imbalance',
#     hue='tissue_type',
#     order=style_order,
# )
# g.set_ylabel('Score imbalance (logreg - balanced accuracy)');
# plt.xticks(rotation=45);
# g.get_legend().set_title("")
# plt.tight_layout()