# This notebook creates the plots for the comparability evaluation of the different scoring methods.
This notebook is structured as follows:
1. Load the results of the comparability evaluation. **NOTE**: The results are not provided in the repository; run the evaluation first with `experiments/run_comp_range_exp.sh`.
2. Create confusion matrix plots (Suppl. Figure S7 & S10)
3. Create performance overview tables (Suppl. Table S2 & S3)
4. Show the difference in performance for cansig pp and non-cansig pp (Suppl. Figure S12)
5. Compute performance plots for main text Figure 2 (Scatter plot & Bar plot)

In [1]:
from io import StringIO
from pathlib import Path

import matplotlib.path as mpath
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sympy import print_tree

from old.score_with_all_methods import save_close_or_show

# Use seaborn's 'ticks' style for every figure created in this notebook.
sns.set_style('ticks')

# Global matplotlib defaults: PDF font type 42 (TrueType) and 10 pt Arial as the sans-serif font.
plt.rcParams.update({'pdf.fonttype': 42, 'font.family': 'sans-serif', 'font.sans-serif': 'Arial', 'font.size': 10})

In [2]:
# from ANS_supplementary_information.data.constants import BASE_PATH_RESULTS
# from ANS_supplementary_information.data.load_data import load_signatures
from data.constants import BASE_PATH_RESULTS
from data.load_data import load_signatures

In [3]:
# Root folder of the comparable-score-range experiment results.
exp_path = Path(BASE_PATH_RESULTS) / 'comparable_score_ranges'

# All figures and tables created by this notebook are written below this folder.
storing_path = exp_path / 'plots'
# SAVE is passed to save_close_or_show and also gates the Excel export further down.
SAVE = True
if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)


In [4]:
# Collect every metrics.csv below exp_path. The three directory levels directly above each
# file are read as <cell type>/<gene-pool setting>/<overlapping-genes setting>.
# Sorting makes the row order of all_ds independent of the filesystem's listing order.
metric_paths = sorted(exp_path.glob('**/metrics.csv'))
if not metric_paths:
    # pd.concat([]) would only report "No objects to concatenate"; name the actual cause instead.
    raise FileNotFoundError(
        f'No metrics.csv files found below {exp_path}. '
        'Run `experiments/run_comp_range_exp.sh` first to create the results.'
    )

dfs = []
for path in metric_paths:
    tt, gene_pool, overlapping_genes = path.parts[-4:-1]
    df = pd.read_csv(path, index_col=0)
    # The CSV index holds the metric names; expose it as a regular column.
    df = df.reset_index(names='metric')
    df['cell_type'] = tt
    df['tissue_type'] = 'PBMC' if 'pbmc' in tt else 'Cancer'
    # A folder name containing 'without' marks the run without gene pool.
    df['use_gene_pool'] = 'without' not in gene_pool
    df['signatures'] = (
        'Signatures without overlapping genes' if 'without' in overlapping_genes else 'Original signatures'
    )
    dfs.append(df)
all_ds = pd.concat(dfs)

In [5]:
# Reshape to long format: every column that is not an id column is treated as a scoring
# method, giving one row per (metric, dataset configuration, scoring method).
all_ds = pd.melt(
    all_ds,
    id_vars=['metric', 'cell_type', 'tissue_type', 'use_gene_pool', 'signatures'],
    var_name='scoring_method',
    value_name='value'
)

In [6]:
# Spread the metrics back into columns: one row per dataset configuration and scoring
# method, one column per metric.
all_ds = pd.pivot(all_ds,
                  index=['cell_type', 'tissue_type', 'use_gene_pool', 'signatures', 'scoring_method'],
                  columns='metric',
                  values='value').reset_index()

In [7]:
# Metric columns that must be numeric for the tables and plots below.
# NOTE(review): presumably these are not float yet because the pivoted 'value' column also
# carried the textual 'conf_mat' entries — confirm.
float_cols = ['balanced_accuracy', 'f1_score', 'gmm_balanced_accuracy',
              'gmm_f1_score', 'gmm_jaccard_score', 'jaccard_score',
              'logreg_balanced_accuracy_10cv_mean',
              'logreg_balanced_accuracy_10cv_std', 'logreg_f1_weighted_10cv_mean',
              'logreg_f1_weighted_10cv_std', 'logreg_jaccard_weighted_10cv_mean',
              'logreg_jaccard_weighted_10cv_std']
all_ds[float_cols] = all_ds[float_cols].astype(float)

In [8]:
# Fixed order of the scoring methods used by the plots in this notebook.
style_order = [
    'ANS',
    'Seurat',
    'Seurat_AG',
    'Seurat_LVG',
    'Scanpy',
    'Jasmine_LH',
    'Jasmine_OR',
    'UCell'
]

In [9]:
# Human-readable dataset label for every cell_type folder name. Labels with an extra
# qualifier after the state count (e.g. ', 3ca') mark alternative variants of a dataset and
# are filtered out by pattern further down. Folder names missing from this mapping become NaN.
all_ds['cell_type_pp'] = all_ds.cell_type.map({"breast_malignant": 'BRCA (6 states)',
                                               'luad_kim_malignant': 'LUAD (3 states)',
                                               'luad_kim_malignant_2': 'LUAD (3 states, 3ca)',
                                               'ovarian_malignant': 'HGSOC (8 states)',
                                               'ovarian_malignant_bak': 'HGSOC (8 states, bak)',
                                               'ovarian_malignant_2': 'HGSOC (8 states, cellxgene)',
                                               'skin_malignant': 'cSCC (4 states, self-pp)',
                                               'skin_malignant_2': 'cSCC (4 states)',
                                               'pbmc_b_mono_nk': 'B, Monocytes, NK cells',
                                               'pbmc_b_subtypes': 'B-cell subtypes',
                                               'pbmc_cd4_subtypes': 'CD4 T-cell subtypes',
                                               'pbmc_cd8_subtypes': 'CD8 T-cell subtypes',
                                               })

## Create confusion matrix plots

In [None]:
# Score columns that are not needed for the confusion matrix figures; every other column
# (including conf_mat, balanced_accuracy and f1_score) is kept.
unused_score_cols = {
    'gmm_balanced_accuracy', 'gmm_f1_score', 'gmm_jaccard_score', 'jaccard_score',
    'logreg_balanced_accuracy_10cv_mean', 'logreg_balanced_accuracy_10cv_std',
    'logreg_f1_weighted_10cv_mean', 'logreg_f1_weighted_10cv_std',
    'logreg_jaccard_weighted_10cv_mean', 'logreg_jaccard_weighted_10cv_std',
}
con_mat_cols = set(all_ds.columns).difference(unused_score_cols)
# Alphabetical column order; copy so later edits do not touch all_ds.
conf_mat_data = all_ds[sorted(con_mat_cols)].copy()
conf_mat_data.columns

In [None]:
# conf_mat_data.groupby(['cell_type','cell_type_pp', 'tissue_type', 'use_gene_pool', 'signatures']).size()

In [None]:
def parse_matrix_str(matrix_str):
    """Turn the printed text form of a numeric matrix back into a NumPy array.

    A newline followed by two spaces is collapsed into a single space, all square
    brackets are removed, and the remaining text is parsed with ``np.loadtxt``.
    """
    unwrapped = ' '.join(matrix_str.split('\n  '))
    bare_numbers = unwrapped.translate(str.maketrans('', '', '[]'))
    return np.loadtxt(StringIO(bare_numbers))

In [None]:
# Replace the textual confusion matrices by the arrays parsed with parse_matrix_str.
conf_mat_data['conf_mat'] = conf_mat_data['conf_mat'].apply(parse_matrix_str)

In [None]:
# Display the distinct cell_type values present in the confusion matrix data.
conf_mat_data['cell_type'].unique()

In [None]:
# Alphabetically sorted signature names per dataset; used as tick labels of the confusion matrices.
sigs = {}
for ds in conf_mat_data.cell_type.unique():
    signature_names = load_signatures(ds).keys()
    sigs[ds] = sorted(signature_names)

In [None]:
import textwrap

# Maximum line width (in characters) passed to textwrap.fill when wrapping tick labels.
width_text_wrap = 7


def pp_lbl(ds, lbl):
    """Shorten a signature label with the abbreviations defined for its dataset.

    The first dataset keyword found in ``ds`` (checked in the order ovarian, skin,
    b_subtypes, cd4_subtypes, cd8_subtypes) selects a list of substring
    replacements that are applied to ``lbl`` one after another. Labels of all
    other datasets are returned unchanged.
    """
    if 'ovarian' in ds:
        abbreviations = [('cell.', ''), ('cancer.', 'C.'), ('Cancer.', 'C.'), ('Cycling.', 'Cyc.')]
    elif 'skin' in ds:
        abbreviations = [('Cycling', 'Cyc.'), ('Tumor', 'Tum.')]
    elif 'b_subtypes' in ds:
        abbreviations = [('intermediate', 'interm.'), ('memory', 'mem.'), ('Cancer.', 'C.')]
    elif 'cd4_subtypes' in ds or 'cd8_subtypes' in ds:
        abbreviations = [('Proliferating', 'Prolif.')]
    else:
        abbreviations = []

    for long_form, short_form in abbreviations:
        lbl = lbl.replace(long_form, short_form)
    return lbl


# Abbreviate every signature name with pp_lbl and wrap it to width_text_wrap characters.
# The labels are replaced in place, so running this cell again re-processes the already shortened labels.
for k, v in sigs.items():
    sigs[k] = [textwrap.fill(pp_lbl(k, curr_sig), width=width_text_wrap) for curr_sig in v]


In [None]:
# Edge length (in inches) of one confusion matrix panel.
subfig_size = 4.2
# Tissue types for which no confusion-matrix figure is drawn. The original cell skipped
# 'Cancer' with a bare `continue`, which made the Cancer-specific filter below unreachable;
# keeping the skip in one explicit place preserves the behaviour and lets the filter work
# again once 'Cancer' is removed from this set.
# NOTE(review): confirm whether skipping 'Cancer' is still intended.
skip_tissue_types = {'Cancer'}

# One figure per (tissue type, gene-pool setting, signature kind):
# rows are scoring methods, columns are datasets.
for key, data in conf_mat_data.groupby(['tissue_type', 'use_gene_pool', 'signatures']):
    if key[0] in skip_tissue_types:
        continue
    # filter out the different preprocessing strategies
    if key[0] == 'Cancer':
        data = data[~ data['cell_type_pp'].str.contains(', ')].copy()

    # get the number of different scoring methods and datasets
    n_sc_methods = data['scoring_method'].nunique()
    n_ds = data['cell_type_pp'].nunique()

    # squeeze=False always returns a 2-D axes array, so flatten() also works for a 1x1 grid.
    fig, axes = plt.subplots(nrows=n_sc_methods, ncols=n_ds, figsize=(n_ds * subfig_size, n_sc_methods * subfig_size),
                             sharey=False, sharex='col', squeeze=False)

    axes = axes.flatten()
    # The groups are iterated in sorted key order (scoring method first, then dataset), which
    # lines up with the row-major order of the flattened axes as long as every scoring method
    # has an entry for every dataset.
    for i, (sc_method, sc_data) in enumerate(data.groupby(['scoring_method', 'cell_type', 'cell_type_pp'])):
        assert len(sc_data) == 1
        ax = axes[i]
        conf_map = sc_data['conf_mat'].iloc[0]
        bal_acc = sc_data['balanced_accuracy'].iloc[0]

        sns.heatmap(conf_map, ax=ax, annot=True, fmt='.2f', cmap='coolwarm', cbar=False, vmin=0, vmax=1)
        ax.set_title(f'{sc_method[2]} – {bal_acc:.2f} bal. acc.', fontsize=11, weight='bold')
        if i % n_ds == 0:
            # first panel of a row: label the row with the scoring method
            ax.set_ylabel(sc_method[0], fontsize=11, weight='bold')

        # set tick labels (already abbreviated and wrapped in `sigs`); no label helper is
        # defined here so the module-level pp_lbl(ds, lbl) is not overwritten.
        tick_lbls = sigs[sc_data['cell_type'].iloc[0]]
        ax.set_xticks(np.arange(len(tick_lbls)) + 0.5, tick_lbls)
        ax.set_yticks(np.arange(len(tick_lbls)) + 0.5, tick_lbls, rotation=0)
    fig.tight_layout()

    wspace = 0.21 if key[0] == 'Cancer' else 0.25
    fig.subplots_adjust(wspace=wspace, hspace=0.15)

    suffix = f"{key[0].lower()}_{'with' if key[1] else 'without'}_gene_pool_{key[2].replace(' ', '_').lower()}"
    save_close_or_show(fig, SAVE, save_path=storing_path / f'conf_mats_{suffix}.pdf')
    save_close_or_show(fig, SAVE, save_path=storing_path / f'conf_mats_{suffix}.svg')

## Create performance overview table

In [10]:
# Order the rows by signature kind, gene-pool setting, tissue type and dataset label.
all_ds = all_ds.sort_values(by=['signatures', 'use_gene_pool', 'tissue_type', 'cell_type_pp'])

In [11]:
# Keep only datasets whose label has no extra qualifier after the state count
# (i.e. drop labels such as 'LUAD (3 states, 3ca)'). A raw string without a capture group
# matches the same rows but avoids the invalid-escape and match-group warnings.
curr_ds = all_ds[~all_ds['cell_type_pp'].str.contains(r'\d states, .+')]

  curr_ds = all_ds[~all_ds['cell_type_pp'].str.contains(f'(\d states, .+)')]


In [12]:
# One group per (signature kind, gene-pool setting); each group becomes one overview table.
grouped_all_ds = curr_ds.groupby(['signatures', 'use_gene_pool'])

In [13]:
def metric_type(x):
    """Map a metric column name to its display name.

    Names containing 'balanced_accuracy' become 'Balanced Accuracy', names containing
    'f1_score' or 'f1_weighted' become 'F1 Score'; any other name is returned as is.
    """
    if 'balanced_accuracy' in x:
        return 'Balanced Accuracy'
    if any(tag in x for tag in ('f1_score', 'f1_weighted')):
        return 'F1 Score'
    return x


def lbl_method(x):
    """Return the labeling-method name that belongs to a metric column name.

    'gmm' is checked before 'logreg'; names containing neither are hard-labeling metrics.
    """
    named_methods = (
        ('gmm', 'Rediscovery score'),
        ('logreg', 'Information quantity (logistic regression 10-fold CV)'),
    )
    for marker, method_name in named_methods:
        if marker in x:
            return method_name
    return 'Hard labeling score'

def h_max(s):
    """Return CSS styles that print the numerically largest entries of ``s`` in bold.

    Meant for ``Styler.apply``. Entries may be numbers or strings of the form
    '<mean>\\n(<std>)'; for strings only the part before the first newline is compared.
    The comparison is done on floats (not on the raw strings), so '10.0' beats '9.0'
    and entries with the same mean but a different std are all highlighted.

    Args:
        s: pandas Series holding one row/column of the styled table.

    Returns:
        A list with one CSS string per entry: 'font-weight: bold' for every entry equal
        to the maximum, '' otherwise (also for entries that cannot be read as a number).
    """

    def _leading_number(cell):
        # '<mean>\n(<std>)' -> <mean>; everything else is passed through unchanged
        if isinstance(cell, str):
            return cell.split('\n')[0]
        return cell

    # errors='coerce' turns missing/unparsable cells into NaN, which never equals the maximum
    s_float = pd.to_numeric(s.map(_leading_number), errors='coerce')
    is_max = s_float == s_float.max()
    return ['font-weight: bold' if cell else '' for cell in is_max]

In [16]:
from pandas import ExcelWriter


def _mean_n_std(frame, prefix):
    """Combine ``<prefix>_mean`` and ``<prefix>_std`` of ``frame`` into '<mean>\\n(<std>)' strings (3 decimals)."""
    mean_txt = frame[f'{prefix}_mean'].round(3).astype(str)
    std_txt = frame[f'{prefix}_std'].round(3).astype(str)
    return mean_txt + '\n(' + std_txt + ')'


# Columns that are not shown in the overview tables (the logreg mean/std columns are
# replaced by their combined '<mean>\n(<std>)' text columns).
dropped_cols = ['conf_mat',
                'gmm_balanced_accuracy',
                'gmm_f1_score',
                'gmm_jaccard_score',
                'jaccard_score',
                'logreg_balanced_accuracy_10cv_mean',
                'logreg_balanced_accuracy_10cv_std',
                'logreg_f1_weighted_10cv_mean',
                'logreg_f1_weighted_10cv_std',
                'logreg_jaccard_weighted_10cv_mean',
                'logreg_jaccard_weighted_10cv_std']

# One Excel workbook per (signature kind, gene-pool setting) with a 'Balanced Accuracy'
# and an 'F1 Score' sheet; the best scoring method of every row is printed in bold.
for group, data in grouped_all_ds:
    print(group)
    # Work on a private copy: adding columns to the frame handed out by groupby would
    # otherwise trigger pandas' chained-assignment warning.
    data = data.copy()
    data['logreg_balanced_accuracy_10cv_mean_n_std'] = _mean_n_std(data, 'logreg_balanced_accuracy_10cv')
    data['logreg_f1_weighted_10cv_mean_n_std'] = _mean_n_std(data, 'logreg_f1_weighted_10cv')

    data = data.drop(columns=dropped_cols)

    melted_data = pd.melt(
        data,
        id_vars=['cell_type', 'tissue_type', 'use_gene_pool', 'signatures',
                 'scoring_method', 'cell_type_pp'],
        var_name='metric',
        value_name='value'
    )
    melted_data['metric_type'] = melted_data['metric'].apply(metric_type)
    melted_data['labeling_method'] = melted_data['metric'].apply(lbl_method)

    # After transposing: rows = (labeling method, tissue type, dataset),
    # columns = (metric type, scoring method).
    perf_table = pd.pivot(
        melted_data,
        index=['metric_type', 'scoring_method'],
        columns=['labeling_method', 'tissue_type', 'cell_type_pp'],
        values='value'
    ).T

    table_bal_acc = perf_table[[col for col in perf_table.columns if 'Balanced Accuracy' in col[0]]].style.apply(h_max, axis=1)
    table_f1 = perf_table[[col for col in perf_table.columns if 'F1 Score' in col[0]]].style.apply(h_max, axis=1)

    fn = storing_path / f'overview_table_{group[0].replace(" ", "_").lower()}_{"with" if group[1] else "without"}_gene_pool.xlsx'

    if SAVE:
        with ExcelWriter(fn, engine='xlsxwriter') as writer:
            table_bal_acc.to_excel(writer, sheet_name='Balanced Accuracy')
            table_f1.to_excel(writer, sheet_name='F1 Score')
            workbook = writer.book
            format_font = workbook.add_format({'font_size': 9.5, 'text_wrap': True})
            for worksheet in writer.sheets.values():
                worksheet.set_column('A:Z', None, format_font)  # Apply to all columns A through Z
        print(f"Stored Excel file at {fn=}")

    print()
    print()

('Original signatures', False)


('Original signatures', True)


('Signatures without overlapping genes', False)


('Signatures without overlapping genes', True)




## Show difference in performance for cansig pp and non-cansig pp

In [None]:
def compare_two_scoring_methods(res_df):
    """Draw balanced-accuracy bar plots per scoring method, coloured by dataset label.

    The grid has one row per value of 'signatures' and one column per value of
    'use_gene_pool'. The y-axis is limited to the observed value range plus a margin
    of 0.01 on both sides.

    Args:
        res_df: DataFrame with the columns 'scoring_method', 'balanced_accuracy',
            'cell_type_pp', 'signatures' and 'use_gene_pool'.

    Returns:
        The matplotlib figure of the seaborn grid.
    """
    g = sns.catplot(
        res_df,
        x='scoring_method',
        y='balanced_accuracy',
        hue='cell_type_pp',
        row='signatures',
        col='use_gene_pool',
        kind='bar',
        height=3,
        aspect=1.5
    )
    # Rotate labels before setting titles
    for ax in g.axes.flat:
        plt.sca(ax)
        plt.xticks(rotation=30, ha='right')

    title_template = "{row_name}\n Use gene_pool={col_name}"
    g.set_titles(title_template, fontsize=10)
    g.set_axis_labels('', 'Balanced Accuracy', fontsize=10)

    g.set(ylim=(res_df['balanced_accuracy'].min() - 0.01, res_df['balanced_accuracy'].max() + 0.01))

    g.fig.subplots_adjust(hspace=0.25)
    # Public accessor instead of the private g._legend; guard in case no legend was drawn.
    if g.legend is not None:
        g.legend.set_title('Cell Type')

    return g.fig


In [None]:
# Compare all LUAD dataset variants (cell_type contains 'luad_kim') and store the figure as PDF and SVG.
luad_results = all_ds[all_ds.cell_type.str.contains('luad_kim')].copy()
fig = compare_two_scoring_methods(luad_results)
save_close_or_show(fig, SAVE, save_path=storing_path / 'luad_ds_pp_diff.pdf')
save_close_or_show(fig, SAVE, save_path=storing_path / 'luad_ds_pp_diff.svg')

In [None]:
# Compare all skin dataset variants (cell_type contains 'skin') and store the figure as PDF and SVG.
skin_results = all_ds[all_ds.cell_type.str.contains('skin')].copy()
fig = compare_two_scoring_methods(skin_results)
save_close_or_show(fig, SAVE, save_path=storing_path / 'skin_ds_pp_diff.pdf')
save_close_or_show(fig, SAVE, save_path=storing_path / 'skin_ds_pp_diff.svg')

In [None]:
# Compare the ovarian dataset variants (cell_type contains 'ovarian'), leaving out the
# '_bak' variant, and store the figure as PDF and SVG.
ovarian_results = all_ds[all_ds.cell_type.str.contains('ovarian')].copy()
ovarian_results = ovarian_results[~ovarian_results.cell_type.str.contains('_bak')].copy()
fig = compare_two_scoring_methods(ovarian_results)
save_close_or_show(fig, SAVE, save_path=storing_path / 'ovarian_ds_pp_diff.pdf')
save_close_or_show(fig, SAVE, save_path=storing_path / 'ovarian_ds_pp_diff.svg')

## Compute performance plots

### Scatter plot

In [None]:
# From here on, drop every dataset whose label ends in a qualifier of the form ', <text>)'
# (e.g. 'LUAD (3 states, 3ca)'). Labels such as 'B, Monocytes, NK cells' contain no ')' and are kept.
# NOTE: this overwrites all_ds, so the removed rows are gone for all following cells.
all_ds = all_ds[~ all_ds['cell_type_pp'].str.contains(r', .+\)')].copy()

In [None]:
# Score imbalance = mean logistic-regression balanced accuracy (10-fold CV) minus the
# balanced accuracy of the hard labeling.
all_ds['Score imbalance'] = all_ds['logreg_balanced_accuracy_10cv_mean'] - all_ds['balanced_accuracy']

In [None]:
# Order of the scoring methods in the plots below: plot_scatter reads this module-level
# name directly, the bar plot receives it as an argument.
style_order = [
    'ANS',
    'Seurat',
    'Seurat_AG',
    'Seurat_LVG',
    'Scanpy',
    'Jasmine_LH',
    'Jasmine_OR',
    'UCell'
]


In [None]:
from matplotlib.colors import LinearSegmentedColormap


def plot_scatter(df, scale_imbalance):
    """Scatter balanced accuracy against the logistic-regression score for every dataset.

    One facet row per tissue type and one facet column per signature kind. Marker style
    encodes the scoring method (order taken from the module-level ``style_order``), colour
    encodes the dataset. Every facet additionally gets the identity line, a shaded region
    above it and a 'Scale imbalance' annotation with an arrow. The ``scale_imbalance``
    table is drawn next to the second facet.

    Args:
        df: DataFrame with the columns 'scoring_method', 'cell_type_pp', 'tissue_type',
            'signatures', 'balanced_accuracy' and 'logreg_balanced_accuracy_10cv_mean'.
        scale_imbalance: DataFrame indexed by scoring method (must contain every entry of
            ``style_order``) whose cells are shown as table text.

    Returns:
        The matplotlib figure of the seaborn grid.
    """
    df = df.copy()
    # Copies with display names so the legend titles read nicely.
    df['Scoring Method'] = df['scoring_method']
    df['Cell types'] = df['cell_type_pp']
    hue_order = df.sort_values(['tissue_type', 'Cell types'])['Cell types'].unique()
    g = sns.relplot(
        data=df,
        x='balanced_accuracy',
        y='logreg_balanced_accuracy_10cv_mean',
        style='Scoring Method',
        style_order=style_order,
        hue='Cell types',
        hue_order=hue_order,
        # col='tissue_type',
        # row='signatures',
        col='signatures',
        row='tissue_type',
        height=2.5,
        aspect=1.2,
        s=100,
        alpha=0.7,
        facet_kws={'sharey': False, 'sharex': False},
    )

    # title_template = "{col_name}"
    title_template = "{row_name}"
    g.set_titles(title_template, fontsize=10)
    g.set_axis_labels('Balanced Accuracy', 'Information quantity (logreg)', fontsize=10)

    for ax in g.axes.flatten():
        # Define the range for the line
        vmin = min(ax.get_xlim()[0], ax.get_ylim()[0])
        vmax = max(ax.get_xlim()[1], ax.get_ylim()[1])

        x_values = np.linspace(vmin, vmax, 100)
        y_values = x_values  # f(x) = x

        # Plot the line
        ax.plot(x_values, y_values, color='grey', linestyle='--', alpha=0.5)

    # plt.tight_layout()
    g.fig.subplots_adjust(hspace=0.25)

    rectangle = mpath.Path([
        (-1.5, -0.5), (1.5, -0.5), (1.5, 0.5), (-1.5, 0.5), (-1.5, -0.5)
    ], [mpath.Path.MOVETO, mpath.Path.LINETO, mpath.Path.LINETO, mpath.Path.LINETO, mpath.Path.CLOSEPOLY])

    # Show the colour (dataset) legend entries as wide rectangles instead of markers.
    # Assumes legend line 0 is the 'Cell types' sub-title followed by one line per hue
    # level — TODO confirm for the installed seaborn version. The number of entries follows
    # hue_order instead of a hard-coded 8.
    for legend_line in g.legend.get_lines()[1:len(hue_order) + 1]:
        legend_line.set_marker(rectangle)
        legend_line.set_markersize(17)

    # One black edge followed by light grey edges, repeated in blocks of len(style_order).
    # NOTE(review): this presumably highlights the first scoring method and relies on the
    # order of the points inside the collection — confirm.
    edge_colors = ['black'] + ['lightgrey'] * (len(style_order) - 1)
    for ax in g.axes.flatten():
        nr_colors = ax.collections[0].get_facecolors().shape[0] // len(style_order)
        ax.collections[0].set_edgecolor(edge_colors * nr_colors)

    for ax in g.axes.flat:
        for spine in ax.spines.values():
            spine.set_linewidth(0.85)  # Set axis line width

    for ax in g.axes.flat:
        ax.tick_params(axis='y', labelsize=9, length=2.5, width=0.85)
        ax.tick_params(axis='x', labelsize=9, length=2.5, width=0.85)

    sns.move_legend(g, "upper left", bbox_to_anchor=(0.65, 0.99), frameon=False, fontsize=10, ncols=2)

    # Table rows follow style_order; select them once instead of three times.
    table_data = scale_imbalance.loc[style_order, :]
    table = g.axes.flat[1].table(
        cellText=table_data.values,
        rowLabels=table_data.index,
        colLabels=table_data.columns,
        fontsize=10,
        cellLoc='center',
        bbox=[1.7, 0, 1.1, 1.1],  # Adjust [x, y, width, height] of the table

    )
    for cell in table.get_celld().values():
        cell.set_linewidth(0.85)  # Set the line width for all cells

    colors = ['#00CED1', 'white']  # turquoise to white
    custom_cmap = LinearSegmentedColormap.from_list('custom', colors)

    for ax in g.axes.flatten():
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        x = np.linspace(xlim[0], xlim[1], 100)
        y = np.linspace(ylim[0], ylim[1], 100)
        xx, yy = np.meshgrid(x, y)
        gradient = (xx + yy) / np.sqrt(2)
        # Keep only the part above the identity line (y > x); the mask is flipped
        # vertically before it is applied to the image.
        mask = np.flipud(yy > xx)
        gradient[~mask] = np.nan
        ax.imshow(gradient,
                  extent=[ax.get_xlim()[0], ax.get_xlim()[1],
                          ax.get_ylim()[0], ax.get_ylim()[1]],
                  cmap=custom_cmap, alpha=0.2, aspect='auto',
                  zorder=0)

        # Text position: tx_st is an x coordinate, tx_end is a y coordinate.
        tx_st = xlim[0] + (xlim[1] - xlim[0]) * 0.35
        tx_end = ylim[1] - (ylim[1] - ylim[0]) * 0.1
        ax.text(tx_st,
                tx_end,
                'Scale imbalance',
                ha='center',
                color='black',
                fontsize=9)

        # Then add arrow below
        # NOTE(review): the x start mixes tx_st (x) with tx_end (y), and head_width/head_length
        # are scaled by the x-/y-range respectively — left unchanged, confirm this is intended.
        ax.arrow((tx_st + (tx_end - tx_st) * 0.25),  # x start 
                 ylim[1] - (ylim[1] - ylim[0]) * 0.15,  # y start
                 -(xlim[1] - xlim[0]) * 0.3,  # dx (negative for leftward)
                 0,  # dy
                 head_width=(xlim[1] - xlim[0]) * 0.04,
                 head_length=(ylim[1] - ylim[0]) * 0.04,
                 fc='black',
                 ec='black',
                 alpha=0.5)
    return g.fig


def get_scale_imbalance(df):
    """Summarise the 'Score imbalance' column per scoring method and tissue type.

    Returns a table indexed by 'scoring_method' with one column per 'tissue_type';
    every cell is the text '<mean> ± <std>' with both numbers rounded to 3 decimals.
    """
    stats = (
        df.groupby(['tissue_type', 'scoring_method'])['Score imbalance']
        .describe()
        .reset_index()
        .loc[:, ['tissue_type', 'scoring_method', 'mean', 'std']]
    )
    mean_txt = stats['mean'].round(3).astype(str)
    std_txt = stats['std'].round(3).astype(str)
    stats = stats.assign(mean_std=mean_txt + ' ± ' + std_txt)
    return stats.pivot(index='scoring_method', columns='tissue_type', values='mean_std')

In [None]:
# One scatter figure per (gene-pool setting, signature kind), stored as PDF and SVG.
for key, data in all_ds.groupby(['use_gene_pool', 'signatures']):
    print(key)
    # Table with '<mean> ± <std>' of the score imbalance, drawn next to the scatter plot.
    scale_imbalance_pivot = get_scale_imbalance(data)
    fig = plot_scatter(data, scale_imbalance_pivot)
    use_gp, sig_pp = key
    suffix = f"{'with' if use_gp else 'without'}_gene_pool_{sig_pp.replace(' ', '_').lower()}"
    save_close_or_show(fig, SAVE, save_path=storing_path / f'perf_overview_{suffix}.pdf')
    save_close_or_show(fig, SAVE, save_path=storing_path / f'perf_overview_{suffix}.svg')

### Bar plot

In [None]:
def plot_balanced_accuracy_barplot(dat, style_order, width_cm=18, height_cm=7.5):
    """Grouped bar plot of the balanced accuracy per dataset and scoring method.

    Args:
        dat: DataFrame with the columns 'balanced_accuracy', 'cell_type_pp',
            'tissue_type' and 'scoring_method'.
        style_order: order of the scoring methods (bar colours) within each dataset.
        width_cm: figure width in centimetres.
        height_cm: figure height in centimetres.

    Returns:
        Tuple ``(fig, ax)`` with the created figure and the bar plot axes.
    """
    cm_to_inch = 0.393701
    fig = plt.figure(figsize=(width_cm * cm_to_inch, height_cm * cm_to_inch))

    # Datasets on the x-axis are ordered by tissue type first, then by label.
    dataset_order = dat.sort_values(['tissue_type', 'cell_type_pp'])['cell_type_pp'].unique()
    ax = sns.barplot(
        data=dat,
        x='cell_type_pp',
        y='balanced_accuracy',
        order=dataset_order,
        hue='scoring_method',
        hue_order=style_order,
        dodge=True,
        linewidth=0.01
    )

    # Tick and grid styling
    plt.xticks(rotation=45, ha='right', fontsize=10)
    plt.yticks(fontsize=9)
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    # Axis labels (no x label: the tick labels already name the datasets)
    plt.xlabel('', fontsize=10)
    plt.ylabel('Balanced Accuracy', fontsize=10)

    # Legend outside the axes, to the right
    plt.legend(
        title='Scoring Method',
        bbox_to_anchor=(1.01, 1),
        loc='upper left',
        borderaxespad=0,
        fontsize=10,
        title_fontsize=10,
        edgecolor='white'
    )

    plt.tight_layout()

    # Limit the y-axis to the observed range plus 5 % of that range on both sides.
    lowest = dat['balanced_accuracy'].min()
    highest = dat['balanced_accuracy'].max()
    spread = highest - lowest
    plt.ylim(lowest - 0.05 * spread, highest + 0.05 * spread)

    return fig, ax

In [None]:
# One bar plot per (gene-pool setting, signature kind), stored as PDF and SVG.
for key, data in all_ds.groupby(['use_gene_pool', 'signatures']):
    fig, ax = plot_balanced_accuracy_barplot(data, style_order)
    use_gp, sig_pp = key
    suffix = f"{'with' if use_gp else 'without'}_gene_pool_{sig_pp.replace(' ', '_').lower()}"
    save_close_or_show(fig, SAVE, save_path=storing_path / f'perf_bar_{suffix}.pdf')
    save_close_or_show(fig, SAVE, save_path=storing_path / f'perf_bar_{suffix}.svg')
