In [1]:
from pathlib import Path

import matplotlib.path as mpath
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from old.score_with_all_methods import save_close_or_show

# Use seaborn's 'ticks' theme for every figure created in this notebook.
sns.set_style('ticks')

# Global matplotlib defaults: 10 pt Arial text; 'pdf.fonttype': 42 selects
# TrueType font embedding for PDF exports.
plt.rcParams.update({
    'pdf.fonttype': 42,
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'font.size': 10,
})

In [2]:
from ANS_supplementary_information.data.constants import BASE_PATH_RESULTS

In [3]:
# Folder containing the results of the 'comparable_score_ranges' experiment.
exp_path = Path(BASE_PATH_RESULTS, 'comparable_score_ranges')

# Switch for writing figures/tables to disk; the output folder is only created when it is on.
SAVE = False
storing_path = exp_path / 'plots'
if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)


In [4]:
# Gather every metrics.csv below the experiment folder. The files are sorted so
# that the row order of the combined frame does not depend on the filesystem.
metric_files = sorted(exp_path.glob('**/metrics.csv'))
if not metric_files:
    # Fail with an explicit message instead of pandas' generic
    # "No objects to concatenate" error further down.
    raise FileNotFoundError(f'No metrics.csv files found below {exp_path}')

dfs = []
for path in metric_files:
    # The three directories directly above each file encode the configuration:
    # <cell type>/<gene pool setting>/<overlapping genes setting>/metrics.csv
    tt, gene_pool, overlapping_genes = path.parts[-4:-1]

    # Rows of the CSV are metrics, columns are scoring methods; move the metric
    # names from the index into a regular 'metric' column.
    df = pd.read_csv(path, index_col=0).reset_index(names='metric')
    df['cell_type'] = tt
    df['tissue_type'] = 'PBMC' if 'pbmc' in tt else 'Cancer'
    df['use_gene_pool'] = 'without' not in gene_pool
    df['signatures'] = (
        'Signatures without overlapping genes' if 'without' in overlapping_genes else 'Original signatures'
    )
    dfs.append(df)
all_ds = pd.concat(dfs)

In [5]:
# Inspect the combined frame: one column per scoring method plus the metadata columns added while loading.
all_ds.columns

Index(['metric', 'ANS', 'Seurat', 'Seurat_AG', 'Seurat_LVG', 'Scanpy',
       'Jasmine_LH', 'Jasmine_OR', 'UCell', 'cell_type', 'tissue_type',
       'use_gene_pool', 'signatures'],
      dtype='object')

In [6]:
# Long format: the scoring-method columns are stacked into a single
# 'scoring_method' / 'value' pair, keeping the metadata columns as identifiers.
metadata_cols = ['metric', 'cell_type', 'tissue_type', 'use_gene_pool', 'signatures']
all_ds = all_ds.melt(id_vars=metadata_cols, var_name='scoring_method', value_name='value')

In [7]:
# Spread the metrics into columns so that each row describes one scoring method
# on one dataset configuration.
row_keys = ['cell_type', 'tissue_type', 'use_gene_pool', 'signatures', 'scoring_method']
all_ds = all_ds.pivot(index=row_keys, columns='metric', values='value').reset_index()

In [8]:
# Metric columns that are cast to float below.
float_cols = [
    'balanced_accuracy',
    'f1_score',
    'gmm_balanced_accuracy',
    'gmm_f1_score',
    'gmm_jaccard_score',
    'jaccard_score',
    'logreg_balanced_accuracy_10cv_mean',
    'logreg_balanced_accuracy_10cv_std',
    'logreg_f1_weighted_10cv_mean',
    'logreg_f1_weighted_10cv_std',
    'logreg_jaccard_weighted_10cv_mean',
    'logreg_jaccard_weighted_10cv_std',
]
all_ds = all_ds.astype(dict.fromkeys(float_cols, float))

In [9]:
# Order of the scoring methods, used as `style_order` when plotting and to order table rows.
style_order = [
    'ANS',
    'Seurat', 'Seurat_AG', 'Seurat_LVG',
    'Scanpy',
    'Jasmine_LH', 'Jasmine_OR',
    'UCell',
]

In [10]:
# Human-readable labels for the dataset folder names; names missing from this
# mapping end up as NaN in 'cell_type_pp'.
cell_type_labels = {
    'breast_malignant': 'BRCA (6 states)',
    'luad_kim_malignant': 'LUAD (3 states)',
    'luad_kim_malignant_2': 'LUAD (3 states, 3ca)',
    'ovarian_malignant': 'HGSOC (8 states)',
    'skin_malignant': 'cSCC (4 states, self-pp)',
    'skin_malignant_2': 'cSCC (4 states)',
    'pbmc_b_mono_nk': 'B, Monocytes, NK cells',
    'pbmc_b_subtypes': 'B-cell subtypes',
    'pbmc_cd4_subtypes': 'CD4 T-cell subtypes',
    'pbmc_cd8_subtypes': 'CD8 T-cell subtypes',
}
all_ds['cell_type_pp'] = all_ds['cell_type'].map(cell_type_labels)

#### Create performance overview table

In [96]:
# Order rows by experiment configuration first, then by tissue and dataset label.
sort_keys = ['signatures', 'use_gene_pool', 'tissue_type', 'cell_type_pp']
all_ds = all_ds.sort_values(by=sort_keys)

In [98]:
# One group per (signature variant, gene pool setting); each group becomes one overview table.
grouped_all_ds = all_ds.groupby(by=['signatures', 'use_gene_pool'])

In [None]:
def metric_type(x):
    """Map a metric column name to the metric family it belongs to.

    Names containing ``balanced_accuracy`` map to ``'Balanced Accuracy'``;
    names containing ``f1_score`` or ``f1_weighted`` map to ``'F1 Score'``.
    Any other name is returned unchanged.
    """
    families = (
        ('balanced_accuracy', 'Balanced Accuracy'),
        ('f1_score', 'F1 Score'),
        ('f1_weighted', 'F1 Score'),
    )
    # The first matching fragment wins, in the order listed above.
    return next((label for fragment, label in families if fragment in x), x)


def lbl_method(x):
    """Name the labeling approach a metric column was computed with.

    Columns containing ``gmm`` are reported as ``'Rediscovery score'``, columns
    containing ``logreg`` as ``'Information quantity'``; everything else is
    treated as ``'Hard labeling score'``.
    """
    # 'gmm' is checked before 'logreg', mirroring the precedence of the checks.
    for fragment, label in (('gmm', 'Rediscovery score'), ('logreg', 'Information quantity')):
        if fragment in x:
            return label
    return 'Hard labeling score'

In [None]:
# Build one performance overview table per (signature variant, gene pool setting)
# and, when SAVE is on, write it to an Excel file in `storing_path`.
for group, data in grouped_all_ds:
    print(group)
    print(data.head())

    signatures_label, uses_gene_pool = group
    gene_pool_tag = 'with' if uses_gene_pool else 'without'
    # `fn` is already the full output path (it includes `storing_path`).
    fn = storing_path / f'overview_table_{signatures_label.replace(" ", "_").lower()}_{gene_pool_tag}_gene_pool.xlsx'

    # Keep only the columns reported in the table: the confusion matrix, all
    # Jaccard scores and the cross-validation standard deviations are removed.
    table_input = data.drop(columns=[
        'conf_mat',
        'logreg_balanced_accuracy_10cv_std',
        'logreg_f1_weighted_10cv_std',
        'jaccard_score',
        'gmm_jaccard_score',
        'logreg_jaccard_weighted_10cv_mean',
        'logreg_jaccard_weighted_10cv_std',
    ])

    melted_data = pd.melt(
        table_input,
        id_vars=['cell_type', 'tissue_type', 'use_gene_pool', 'signatures',
                 'scoring_method', 'cell_type_pp'],
        var_name='metric',
        value_name='value'
    )
    print(melted_data.metric.unique())

    # Split each metric name into its family (Balanced Accuracy / F1 Score) and
    # the labeling approach it was computed with.
    melted_data['metric_type'] = melted_data['metric'].apply(metric_type)
    melted_data['labeling_method'] = melted_data['metric'].apply(lbl_method)

    # After transposing: rows are (labeling method, tissue, dataset), columns
    # are (metric family, scoring method).
    perf_table = pd.pivot(
        melted_data,
        index=['metric_type', 'scoring_method'],
        columns=['labeling_method', 'tissue_type', 'cell_type_pp'],
        values='value'
    ).T

    if SAVE:
        # Write to `fn` directly. Joining it onto `storing_path` again would
        # duplicate the directory part whenever `storing_path` is relative.
        perf_table.to_excel(fn)
        print(f'Stored {fn}')

#### Show difference in performance for cansig pp and non-cansig pp

In [None]:
def compare_two_scoring_methods(res_df):
    """Draw grouped bar plots of the balanced accuracy per scoring method.

    One panel is drawn for every combination of ``signatures`` (rows) and
    ``use_gene_pool`` (columns); within a panel the bars are grouped by
    ``scoring_method`` and coloured by ``cell_type_pp``.

    Parameters
    ----------
    res_df : pandas.DataFrame
        Results with the columns ``scoring_method``, ``balanced_accuracy``,
        ``cell_type_pp``, ``signatures`` and ``use_gene_pool``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure that holds the facet grid.
    """
    g = sns.catplot(
        data=res_df,
        x='scoring_method',
        y='balanced_accuracy',
        hue='cell_type_pp',
        row='signatures',
        col='use_gene_pool',
        kind='bar',
        height=3,
        aspect=1.5
    )

    # Rotate the scoring-method tick labels so they do not overlap.
    for ax in g.axes.flat:
        plt.setp(ax.get_xticklabels(), rotation=30, ha='right')

    g.set_titles("{row_name}\n Use gene_pool={col_name}", fontsize=10)
    g.set_axis_labels('', 'Balanced Accuracy', fontsize=10)

    # Zoom the y-axis onto the observed value range (with a small margin) so
    # differences between the bars are visible.
    accuracy = res_df['balanced_accuracy']
    g.set(ylim=(accuracy.min() - 0.01, accuracy.max() + 0.01))

    g.figure.subplots_adjust(hspace=0.25)
    g.legend.set_title('Cell Type')

    return g.figure


In [None]:
# Compare the LUAD dataset variants (all cell types whose name contains 'luad_kim').
is_luad = all_ds['cell_type'].str.contains('luad_kim')
luad_results = all_ds.loc[is_luad].copy()
fig = compare_two_scoring_methods(luad_results)
save_close_or_show(fig, SAVE, save_path=storing_path / 'luad_ds_pp_diff.svg')

In [None]:
# Compare the skin (cSCC) dataset variants (all cell types whose name contains 'skin').
is_skin = all_ds['cell_type'].str.contains('skin')
skin_results = all_ds.loc[is_skin].copy()
fig = compare_two_scoring_methods(skin_results)
save_close_or_show(fig, SAVE, save_path=storing_path / 'skin_ds_pp_diff.svg')

### Compute performance plots

In [None]:
# Drop the datasets whose label carries an extra qualifier inside the
# parentheses (", ...)" as in 'LUAD (3 states, 3ca)'), keeping the rest.
has_qualifier = all_ds['cell_type_pp'].str.contains(r', .+\)')
all_ds = all_ds.loc[~has_qualifier].copy()

In [None]:
# Score imbalance = logreg balanced accuracy (10-fold CV mean) minus the hard-labeling balanced accuracy.
all_ds['Score imbalance'] = all_ds['logreg_balanced_accuracy_10cv_mean'].sub(all_ds['balanced_accuracy'])

In [None]:
# Order of the scoring methods used by `plot_scatter` (marker styles and table rows).
# NOTE(review): this duplicates an identical `style_order` defined earlier in the notebook.
style_order = [
    'ANS',
    'Seurat', 'Seurat_AG', 'Seurat_LVG',
    'Scanpy',
    'Jasmine_LH', 'Jasmine_OR',
    'UCell',
]

In [None]:
from matplotlib.colors import LinearSegmentedColormap


def plot_scatter(df, scale_imbalance):
    """Scatter the hard-labeling balanced accuracy against the logreg balanced accuracy.

    One panel is drawn per ``tissue_type`` (rows) and ``signatures`` (columns).
    Marker shape encodes the scoring method (ordered by the notebook-level
    ``style_order``) and colour encodes the dataset (``cell_type_pp``). Each
    panel gets a dashed identity line, a shaded region above that line and a
    'Scale imbalance' annotation. The ``scale_imbalance`` table is rendered
    next to the second panel.

    Parameters
    ----------
    df : pandas.DataFrame
        Results with the columns ``balanced_accuracy``,
        ``logreg_balanced_accuracy_10cv_mean``, ``scoring_method``,
        ``cell_type_pp``, ``signatures`` and ``tissue_type``. The frame is
        copied, so the caller's data is not modified.
    scale_imbalance : pandas.DataFrame
        Table indexed by scoring method (must contain every entry of
        ``style_order``); its columns become the table columns.

    Returns
    -------
    matplotlib.figure.Figure
        The figure that holds the facet grid.
    """
    df = df.copy()
    # Duplicate the columns under display names so the legend titles read nicely.
    df['Scoring Method'] = df['scoring_method']
    df['Cell types'] = df['cell_type_pp']
    g = sns.relplot(
        data=df,
        x='balanced_accuracy',
        y='logreg_balanced_accuracy_10cv_mean',
        style='Scoring Method',
        style_order=style_order,
        hue='Cell types',
        col='signatures',
        row='tissue_type',
        height=2.5,
        aspect=1.2,
        s=100,
        alpha=0.7,
        facet_kws={'sharey': False, 'sharex': False},
    )

    title_template = "{row_name}"
    g.set_titles(title_template, fontsize=10)
    g.set_axis_labels('Balanced Accuracy', 'Information quantity (logreg)', fontsize=10)

    # Dashed identity line f(x) = x spanning the visible range of both axes.
    for ax in g.axes.flatten():
        vmin = min(ax.get_xlim()[0], ax.get_ylim()[0])
        vmax = max(ax.get_xlim()[1], ax.get_ylim()[1])
        diagonal = np.linspace(vmin, vmax, 100)
        ax.plot(diagonal, diagonal, color='grey', linestyle='--', alpha=0.5)

    g.figure.subplots_adjust(hspace=0.25)

    # Wide rectangle used as legend marker for the colour (cell type) entries.
    rectangle = mpath.Path([
        (-1.5, -0.5), (1.5, -0.5), (1.5, 0.5), (-1.5, 0.5), (-1.5, -0.5)
    ], [mpath.Path.MOVETO, mpath.Path.LINETO, mpath.Path.LINETO, mpath.Path.LINETO, mpath.Path.CLOSEPOLY])

    # Legend line 0 is assumed to be the 'Cell types' section title, followed by
    # one line per hue level. The count is derived from the data rather than
    # hard-coded (it was 8, i.e. range(1, 9)).
    # NOTE(review): relies on seaborn's legend layout - confirm after upgrades.
    n_hue_levels = df['Cell types'].nunique()
    for legend_line in g.legend.get_lines()[1:n_hue_levels + 1]:
        legend_line.set_marker(rectangle)
        legend_line.set_markersize(17)

    # Black outline for the first point of every block of len(style_order)
    # points, light grey for the others.
    # NOTE(review): assumes the points of each panel are ordered in blocks of
    # len(style_order) scoring methods with the highlighted method first - confirm.
    edge_colors = ['black'] + ['lightgrey'] * (len(style_order) - 1)
    for ax in g.axes.flatten():
        nr_colors = ax.collections[0].get_facecolors().shape[0] // len(style_order)
        ax.collections[0].set_edgecolor(edge_colors * nr_colors)

    # Thinner spines and ticks.
    for ax in g.axes.flat:
        for spine in ax.spines.values():
            spine.set_linewidth(0.85)
        ax.tick_params(axis='y', labelsize=9, length=2.5, width=0.85)
        ax.tick_params(axis='x', labelsize=9, length=2.5, width=0.85)

    sns.move_legend(g, "upper left", bbox_to_anchor=(0.65, 0.99), frameon=False, fontsize=10, ncols=2)

    # Render the scale-imbalance table (rows in `style_order`) to the right of
    # the second panel.
    table_data = scale_imbalance.loc[style_order, :]
    table = g.axes.flat[1].table(
        cellText=table_data.values,
        rowLabels=table_data.index,
        colLabels=table_data.columns,
        fontsize=10,
        cellLoc='center',
        bbox=[1.7, 0, 1.1, 1.1],  # [x, y, width, height] in axes coordinates
    )
    for cell in table.get_celld().values():
        cell.set_linewidth(0.85)

    colors = ['#00CED1', 'white']  # turquoise -> white
    custom_cmap = LinearSegmentedColormap.from_list('custom', colors)

    for ax in g.axes.flatten():
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()

        # Shade the region above the identity line (y > x) with a gradient.
        # The mask is flipped vertically because imshow draws the first row at
        # the top by default.
        x = np.linspace(xlim[0], xlim[1], 100)
        y = np.linspace(ylim[0], ylim[1], 100)
        xx, yy = np.meshgrid(x, y)
        gradient = (xx + yy) / np.sqrt(2)
        mask = np.flipud(yy > xx)
        gradient[~mask] = np.nan
        ax.imshow(gradient,
                  extent=[xlim[0], xlim[1], ylim[0], ylim[1]],
                  cmap=custom_cmap, alpha=0.2, aspect='auto',
                  zorder=0)

        # Label anchored 35% across the x-range and 10% below the top.
        text_x = xlim[0] + (xlim[1] - xlim[0]) * 0.35
        text_y = ylim[1] - (ylim[1] - ylim[0]) * 0.1
        ax.text(text_x,
                text_y,
                'Scale imbalance',
                ha='center',
                color='black',
                fontsize=9)

        # Leftward arrow below the label.
        # NOTE(review): the x start mixes an x- and a y-coordinate, and the
        # head width/length are scaled by the x-/y-range respectively although
        # a horizontal arrow's head width is measured along y. Left unchanged
        # to preserve the current figure layout - confirm whether intended.
        ax.arrow((text_x + (text_y - text_x) * 0.25),  # x start
                 ylim[1] - (ylim[1] - ylim[0]) * 0.15,  # y start
                 -(xlim[1] - xlim[0]) * 0.3,  # dx (negative for leftward)
                 0,  # dy
                 head_width=(xlim[1] - xlim[0]) * 0.04,
                 head_length=(ylim[1] - ylim[0]) * 0.04,
                 fc='black',
                 ec='black',
                 alpha=0.5)
    return g.figure


def get_scale_imbalance(df):
    """Summarise the score imbalance per scoring method and tissue type.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns ``tissue_type``, ``scoring_method`` and
        ``'Score imbalance'``.

    Returns
    -------
    pandas.DataFrame
        One row per scoring method and one column per tissue type. Each cell is
        the string ``'<mean> ± <std>'`` with both numbers rounded to three
        decimals.
    """
    # Only mean and std are needed, so aggregate them directly. This also
    # avoids adding a column to a column-subset frame, which can trigger
    # pandas' SettingWithCopyWarning.
    stats = (
        df.groupby(['tissue_type', 'scoring_method'])['Score imbalance']
        .agg(['mean', 'std'])
        .round(3)
        .astype(str)
    )
    mean_std = stats['mean'] + ' ± ' + stats['std']

    # Move the tissue types from the row index into the columns.
    return mean_std.unstack('tissue_type')

In [None]:
# One overview figure per (gene pool setting, signature variant).
for (use_gp, sig_pp), data in all_ds.groupby(['use_gene_pool', 'signatures']):
    print((use_gp, sig_pp))
    gene_pool_tag = 'with' if use_gp else 'without'
    suffix = f"{gene_pool_tag}_gene_pool_{sig_pp.replace(' ', '_').lower()}"
    scale_imbalance_pivot = get_scale_imbalance(data)
    fig = plot_scatter(data, scale_imbalance_pivot)
    save_close_or_show(fig, SAVE, save_path=storing_path / f'perf_overview_{suffix}.svg')