# Robustness tests

Before running this notebook **you need to run notebooks 4M and 4H first!!**

In this notebook we use several tests to analyze the robustness and reproducibility of the results obtained by the cluster-annotation algorithm.

## imports

In [None]:
# Reload imported modules automatically before executing code, so edits to the
# local helper modules imported below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import scanpy as sc
import scanpy.external as sce
import pandas as pd
import numpy as np
import os
import triku as tk
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl
from tqdm.notebook import tqdm
import scipy.sparse as spr
import matplotlib.cm as cm
import matplotlib.pylab as pylab
import networkx as nx
import functools
import math

In [None]:
!pip install munkres
from munkres import Munkres

In [None]:
# local imports and imports from other notebooks
from cellassign import assign_cats
from fb_functions import make_gene_scoring_with_expr, plot_score_graph, plot_UMAPS_gene, plot_adata_cluster_properties
# Cluster label -> colour dictionaries stored by previous notebooks
%store -r dict_colors_human
%store -r dict_colors_mouse

# Merged colour lookup for plots that may contain labels of both organisms.
# For labels present in both dicts, the mouse colour wins (later unpacking overrides).
dict_colors_human_mouse = {**dict_colors_human , **dict_colors_mouse}

# Shared settings from previous notebooks (`seed` is used as the PCA random_state below)
%store -r seed
%store -r magma
%store -r data_dir

In [None]:
%store -r plot_params

pylab.rcParams.update(plot_params)
pd.set_option('display.max_columns', None)
pd.options.display.float_format = "{:,.2f}".format

In [None]:
# We use the category dicts from notebooks 3H/3M, not 4H/4M, because we want to
# replicate the annotation made there (4H/4M do not re-assign clusters, they only
# show the genes). These dicts are passed to `assign_cats` as `dict_cats`.

%store -r dict_cats_clusters_robust_3H 
%store -r dict_cats_clusters_robust_3M

In [None]:
# Datasets (AnnData objects) of each organism, as stored by previous notebooks
%store -r list_all_datasets_human
%store -r list_all_datasets_mouse

# Display names ("<author> <year>"); assumed to be in the same order as the
# dataset lists above, since later cells pair them by position.
%store -r list_names_human
%store -r list_names_mouse

In [None]:
def change_texts_format(ax):
    """Compact the numeric annotations of a heatmap drawn on `ax`, in place.

    '1.00' becomes '1' and the leading zero of a fraction is dropped
    ('0.25' -> '.25', '-0.25' -> '-.25'). Any other text is left unchanged.
    """
    for text in ax.texts:
        label = text.get_text()
        if label == '1.00':
            text.set_text('1')
        elif label.startswith('0.'):
            # Strip only the leading zero. The former `replace('0.', '.')` also rewrote
            # '0.' occurring elsewhere (e.g. '10.05' -> '1.05').
            text.set_text(label[1:])
        elif label.startswith('-0.'):
            text.set_text('-' + label[2:])
                
def round_df(df, N=2):
    """Round `df` to N decimals, rescale every row to sum to 1, and round again.

    Applied before annotating heatmaps so the displayed row values add up to ~1.
    NaN entries are skipped when computing the row totals.
    """
    rounded = df.round(N)
    row_totals = rounded.sum(axis=1)
    return rounded.div(row_totals, axis=0).round(N)

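# Stratified bootstrap

Each dataset is repeatedly subsampled within every cluster (keeping a fraction `frac` of the cells of each cluster). The subsample is then re-processed, re-clustered with Leiden and re-annotated, and the new labels are compared with the original ones.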

In [None]:
# Output folder for the bootstrap tables (one tab-separated CSV per dataset and fraction)
dir_save_strat = 'results/bootstrap_cell_types'
os.makedirs(dir_save_strat, exist_ok=True)

In [None]:
def stratified_subsampling(adata, obs_clusters, frac, seed):
    """Return a copy of `adata` keeping a fraction `frac` of the cells of each group.

    Groups are defined by the `adata.obs[obs_clusters]` column. Every group is
    sampled independently with the same `random_state=seed`, and the sampled cells
    are concatenated group after group.
    """
    def _sample_group(group):
        return group.sample(frac=frac, random_state=seed)

    grouped_obs = adata.obs.groupby(obs_clusters, group_keys=False)
    kept_cells = grouped_obs.apply(_sample_group).index.values
    return adata[kept_cells].copy()

def _pca_harmony_neighbors(adata_sub, n_neighbors):
    """Recompute PCA on `adata_sub`, optionally Harmony-correct it, and rebuild the kNN graph.

    Harmony runs only when the subsample carries an 'X_pca_harmony' embedding
    inherited from the full dataset. The corrected embedding is written to
    'X_pca_harmony_sub' and used for the graph; otherwise 'X_pca' is used.
    PCA uses the notebook-level `seed` as random_state.
    """
    sc.pp.pca(adata_sub, random_state=seed, n_comps=50)
    if 'X_pca_harmony' in adata_sub.obsm:
        use_rep = 'X_pca_harmony_sub'
        sce.pp.harmony_integrate(adata_sub, key='Internal sample identifier',
                                 max_iter_harmony=50, adjusted_basis=use_rep)
    else:
        use_rep = 'X_pca'
    sc.pp.neighbors(adata_sub, use_rep=use_rep, n_neighbors=n_neighbors, metric='cosine')


def bootstrap_cluster_assign(adata, obs_clusters, frac, N_iters, dict_cats_cluster):
    """Re-annotate stratified subsamples of `adata` and collect the labels of each cell.

    Each of the `N_iters` iterations (seed = iteration number) does the following:

    - draws a fraction `frac` of the cells of every `obs_clusters` group;
    - re-processes the subsample: gene filtering, PCA/Harmony/kNN, triku, then
      PCA/Harmony/kNN again;
    - re-clusters it with Leiden at `frac` times the original resolution;
    - annotates it with `assign_cats`, using `dict_cats_cluster` and the `min_score`
      and `quantile_gene_sel` stored in `uns['cell_assign'][obs_clusters]`.

    Returns
    -------
    pandas.DataFrame
        Indexed by `adata.obs_names`. Column -1 holds the original
        `adata.obs[obs_clusters]` labels and column `it` the label assigned in
        iteration `it`. Cells not drawn in an iteration, and all cells of an
        iteration that raised an exception, stay NaN.
    """
    df_bootstrap = pd.DataFrame(columns=[-1] + list(range(N_iters)), index=adata.obs_names)
    # Reference labels come from the same column that is stratified and re-assigned
    # (this was previously hard-coded to 'cluster_robust').
    df_bootstrap.loc[adata.obs_names, -1] = adata.obs[obs_clusters]
    key_added = obs_clusters + '_sub'

    for it in range(N_iters):
        try:
            print(adata.obs['Author'].values[0], int(adata.obs['Year'].values[0]), it)
            # stratified_subsampling only reads `adata` and already returns a copy of
            # the subset, so the full object does not need to be copied first.
            adata_sub = stratified_subsampling(adata=adata, obs_clusters=obs_clusters, frac=frac, seed=it)

            # Re-process the subsample from scratch.
            sc.pp.filter_genes(adata_sub, min_counts=1)
            # kNN size scaled by the kept fraction (read before the graph is rebuilt)
            n_neighbors = int(adata_sub.uns['neighbors']['params']['n_neighbors'] * frac)
            _pca_harmony_neighbors(adata_sub, n_neighbors)
            tk.tl.triku(adata_sub, use_raw=False)
            _pca_harmony_neighbors(adata_sub, n_neighbors)

            # Resolution scaled linearly with frac: a rough approximation, since the
            # relation between resolution and cluster granularity is not linear.
            sc.tl.leiden(adata_sub, resolution=frac * adata.uns['leiden']['params']['resolution'],
                         key_added='leiden_sub')

            assign_params = adata_sub.uns['cell_assign'][obs_clusters]
            assign_cats(adata_sub, column_groupby='leiden_sub', dict_cats=dict_cats_cluster,
                        min_score=assign_params['min_score'],
                        quantile_gene_sel=assign_params['quantile_gene_sel'],
                        key_added=key_added, others_name='U', verbose=False)
            df_bootstrap.loc[adata_sub.obs_names, it] = adata_sub.obs[key_added]
        except Exception as exc:
            # Best effort: a failed iteration is reported and left as NaN. Unlike the
            # former bare `except:`, KeyboardInterrupt/SystemExit still stop the loop.
            print(f"Iteration {it} failed and is left as NaN: {exc!r}")
    return df_bootstrap

In [None]:
# Fractions of cells kept per cluster in each subsample, and iterations per fraction.
# Results are cached as CSV files in `dir_save_strat`; a (dataset, fraction) pair
# whose file already exists is skipped.
all_frac_vals = [0.5, 0.7, 0.8, 0.95, 0.99]
N_iters = 30

# (organism tag used in the file name, datasets, category dict for assign_cats)
datasets_per_org = [
    ('human', list_all_datasets_human, dict_cats_clusters_robust_3H),
    ('mouse', list_all_datasets_mouse, dict_cats_clusters_robust_3M),
]

for frac in all_frac_vals:
    for org, list_datasets, dict_cats in datasets_per_org:
        for adata in list_datasets:
            name = f"{adata.obs['Author'].values[0]}_{int(adata.obs['Year'].values[0])}_{org}_frac-{frac}_Niters-{N_iters}.csv"
            path = os.path.join(dir_save_strat, name)
            if os.path.isfile(path):
                continue  # already computed in a previous run

            df_bootstrap = bootstrap_cluster_assign(adata, obs_clusters='cluster_robust',
                                                    frac=frac, N_iters=N_iters,
                                                    dict_cats_cluster=dict_cats)
            df_bootstrap = df_bootstrap.sort_values(by=-1)
            df_bootstrap.to_csv(path, sep='\t')


## Load dict_bootstraps (with all options)

In [None]:
def read_dict_bootstrap(dir_save='results/bootstrap_cell_types', not_include=None):
    """Load every bootstrap CSV in `dir_save` into a dict of DataFrames.

    Files must be named "<author>_<year>_<org>_frac-<frac>_Niters-<N_iters>.csv",
    as written by the stratified-bootstrap cell. They are keyed as
    "<author>_<year>_<org>_<frac>_<N_iters>". The author part may itself contain
    underscores.

    Parameters
    ----------
    dir_save : str
        Directory containing the tab-separated CSV files.
    not_include : iterable of str, optional
        Files whose name contains any of these substrings are skipped.

    Returns
    -------
    dict
        Key -> DataFrame read with the first column as index.

    Raises
    ------
    ValueError
        If a CSV file name does not have the expected number of fields.
    """
    not_include = [] if not_include is None else list(not_include)
    dict_bootstraps = {}

    for file in sorted(f for f in os.listdir(dir_save) if f.endswith('.csv')):
        if any(pattern in file for pattern in not_include):
            continue

        stem = file[:-len('.csv')]
        # rsplit keeps underscores inside the author name intact
        author, year, org, frac, N_iters = stem.rsplit('_', 4)
        frac = frac.replace('frac-', '')
        N_iters = N_iters.replace('Niters-', '')

        df = pd.read_csv(os.path.join(dir_save, file), sep='\t', index_col=0)
        dict_bootstraps[f"{author}_{year}_{org}_{frac}_{N_iters}"] = df

    return dict_bootstraps

In [None]:
dict_bootstraps = read_dict_bootstrap()  # keys: "<author>_<year>_<org>_<frac>_<N_iters>"

## Calculating and plotting, for each cell, the score of being assigned the same cluster

In [None]:
def assign_cluster_score(df_bootstrap, N_iters):
    """
    Add a 'same_cluster_prop' column (in place) with, for each cell, the proportion
    of bootstrap iterations that assigned it the same label as the original one
    (column '-1'). Iterations where the cell has no label (NaN) are excluded from
    the denominator. Cells with no labelled iteration at all get NaN.

    Parameters
    ----------
    df_bootstrap : pandas.DataFrame
        Bootstrap table as read from disk: column '-1' holds the original labels
        and columns '0' ... str(N_iters - 1) the labels of each iteration.
    N_iters : int
        Number of iteration columns to consider.

    NOTE: a later cell redefines `assign_cluster_score` (returning an adjacency
    matrix); re-running this cell's callers after that one uses the other version.
    """
    iter_cols = [str(i) for i in range(N_iters)]
    iter_labels = df_bootstrap.loc[:, iter_cols].to_numpy()
    original_labels = df_bootstrap.loc[:, '-1'].to_numpy()

    # Elementwise comparison already yields a bool array (np.bool was removed in NumPy 1.24)
    is_same = iter_labels == original_labels[:, None]
    n_labelled = N_iters - pd.isna(iter_labels).sum(axis=1)

    # 0/0 for cells never labelled -> NaN, without a RuntimeWarning
    with np.errstate(divide='ignore', invalid='ignore'):
        df_bootstrap['same_cluster_prop'] = is_same.sum(axis=1) / n_labelled

In [None]:
# Iterations per bootstrap table, and the subsampling fractions as the strings
# used in the keys of `dict_bootstraps`.
N_iters = 30
frac_vals = ['0.5', '0.7', '0.8', '0.95', '0.99']

# Add the 'same_cluster_prop' column to every bootstrap table (in place)
for bootstrap_table in dict_bootstraps.values():
    assign_cluster_score(bootstrap_table, N_iters)

In [None]:
# For every dataset, build one table with the original cluster of each cell and
# its 'same_cluster_prop' score at every subsampling fraction (one column per
# fraction). This allows comparing values and seeing tendencies across fractions.
dict_df_score_fracs = {}

all_names = list_names_human + list_names_mouse
all_datasets = list_all_datasets_human + list_all_datasets_mouse
n_human = len(list_names_human)

for idx, key in enumerate(all_names):
    org = 'mouse' if idx >= n_human else 'human'
    # names look like "<author> <year>": everything before the last space is the author
    name, _sep, year = key.rpartition(' ')

    df_scores = pd.DataFrame(index=all_datasets[idx].obs_names, columns=['cluster'] + frac_vals)

    for frac in frac_vals:
        df_vals = dict_bootstraps[f"{name}_{year}_{org}_{frac}_{N_iters}"]
        if frac == frac_vals[0]:
            df_scores['cluster'] = df_vals.loc[df_scores.index, "-1"].values
        df_scores[frac] = df_vals.loc[df_scores.index, "same_cluster_prop"].values

    dict_df_score_fracs[f"{name}_{year}_{org}_{N_iters}"] = df_scores
        

In [None]:
# Copy each per-fraction score onto its dataset's obs (columns "bootstrap_<frac>").
# The pairing relies on dict_df_score_fracs having been filled in the same order
# as the concatenated dataset lists.
datasets_in_order = list_all_datasets_human + list_all_datasets_mouse
for ad, scores in zip(datasets_in_order, dict_df_score_fracs.values()):
    for frac_key in frac_vals:
        ad.obs["bootstrap_" + frac_key] = scores[frac_key]

### Plotting the score (UMAP, human)

In [None]:
# Original clusters followed by the bootstrap score at each fraction, on the human UMAPs
obs_keys_human = ['cluster_robust'] + [f'bootstrap_{f}' for f in frac_vals]
for obs_key in obs_keys_human:
    print('FRAC', obs_key)
    plot_UMAPS_gene(obs_key, list_datasets=list_all_datasets_human, list_names=list_names_human,
                    n_cols=5, cmap='magma')
    plt.show()

### Plotting the score (UMAP, mouse)

In [None]:
# Original clusters followed by the bootstrap score at each fraction, on the mouse UMAPs
obs_keys_mouse = ['cluster_robust'] + [f'bootstrap_{f}' for f in frac_vals]
for obs_key in obs_keys_mouse:
    print('FRAC', obs_key)
    plot_UMAPS_gene(obs_key, list_datasets=list_all_datasets_mouse, list_names=list_names_mouse,
                    n_cols=3)
    plt.show()

## Distribution plots of score

In [None]:
# Split the per-dataset score tables by organism (the keys contain "human" or "mouse")
dict_df_score_fracs_human, dict_df_score_fracs_mouse = (
    {key: df for key, df in dict_df_score_fracs.items() if org_tag in key}
    for org_tag in ('human', 'mouse')
)

In [None]:
def plot_distributions_bootstrap(dict_df_score_fracs, frac=0.99, n_cols=5, n_rows=None, plot_type='violin'):
    """Plot, for each dataset, the per-cluster distribution of the bootstrap score.

    Parameters
    ----------
    dict_df_score_fracs : dict
        Dataset key ("<author>_<year>_...") -> DataFrame with a 'cluster' column
        and one score column per subsampling fraction, named ``str(frac)``.
    frac : float or str
        Fraction whose score column is plotted.
    n_cols, n_rows : int
        Subplot grid. By default, as many rows as needed for all datasets.
    plot_type : {'violin', 'boxplot'}
        Violins with an overlaid dark box and dotted white median, or plain boxplots.

    Clusters 'T1' and 'U' are not shown.

    Raises
    ------
    ValueError
        If `plot_type` is not 'violin' or 'boxplot'.

    NOTE: a later cell redefines `plot_distributions_bootstrap` (adjacency heatmaps)
    with a different signature.
    """
    if plot_type not in ('violin', 'boxplot'):
        raise ValueError(f"plot_type must be 'violin' or 'boxplot', got {plot_type!r}")
    if n_rows is None:
        n_rows = int(math.ceil(len(dict_df_score_fracs) / n_cols))
    score_col = str(frac)

    # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid, so ravel() always works
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, n_rows * 2.85), squeeze=False)
    axs = axs.ravel()

    for idx, (key, val) in enumerate(dict_df_score_fracs.items()):
        ax = axs[idx]
        val = val.sort_values(by='cluster')
        val = val[~val['cluster'].isin(['T1', 'U'])]
        clusters = sorted(set(val['cluster'].values))
        palette = [dict_colors_human_mouse[c] for c in clusters]

        # explicit `order` keeps boxes, palette and tick labels aligned
        if plot_type == 'violin':
            sns.violinplot(data=val, x='cluster', y=score_col, order=clusters, scale='width', cut=0,
                           inner=None, ax=ax, linewidth=1, palette=palette)
            sns.boxplot(data=val, x='cluster', y=score_col, order=clusters, width=0.2, fliersize=0, ax=ax,
                        linewidth=0, showcaps=False, boxprops={'zorder': 2}, color="#232323",
                        whiskerprops=dict(color="#232323", linewidth=1.8),
                        medianprops=dict(color="#ffffff", linewidth=6, linestyle=':'))
        else:
            sns.boxplot(data=val, x='cluster', y=score_col, order=clusters, fliersize=0,
                        ax=ax, linewidth=1, palette=palette)

        # Clear axis labels after all layers are drawn: the overlaid boxplot would
        # otherwise re-set them after the violin's labels were cleared.
        ax.set_xlabel('')
        ax.set_ylabel('')
        ax.set_title(' '.join(key.split('_')[:2]))
        ax.set_xticklabels(clusters, size='large')
        ax.set_yticks([0, 1])
        ax.set_yticklabels([0, 1], size='medium')

        sns.despine(left=True, bottom=True)

    for idx in range(len(dict_df_score_fracs), n_cols * n_rows):
        axs[idx].set_axis_off()

    plt.tight_layout()

In [None]:
# Score distributions per human dataset (fraction 0.99), as violins and as boxplots.
# NOTE: the loop variable `plot_type` stays bound to 'boxplot' afterwards.
for plot_type in ['violin', 'boxplot']:
    plot_distributions_bootstrap(dict_df_score_fracs_human, frac=0.99, n_cols=4, n_rows=None, plot_type=plot_type)

In [None]:
# Score distributions per mouse dataset (fraction 0.99), as violins and as boxplots.
# NOTE: the loop variable `plot_type` stays bound to 'boxplot' afterwards.
for plot_type in ['violin', 'boxplot']:
    plot_distributions_bootstrap(dict_df_score_fracs_mouse, frac=0.99, n_cols=3, n_rows=None, plot_type=plot_type)

In [None]:
def plot_distributions_bootstrap_dataset(dict_df_score_fracs, frac=0.99, n_cols=5, n_rows=None, plot_type='violin'):
    """Plot, for each cluster, the distribution of the per-cell score in every dataset.

    Parameters
    ----------
    dict_df_score_fracs : dict
        Dataset key ("<author>_<year>_...") -> DataFrame with a 'cluster' column
        and a score column named ``str(frac)``. The DataFrames are not modified.
    frac : float or str
        Fraction whose score column is plotted.
    n_cols, n_rows : int
        Subplot grid (one panel per cluster). By default, as many rows as needed.
    plot_type : {'violin', 'boxplot'}

    Clusters 'T1' and 'U' are skipped.

    Raises
    ------
    ValueError
        If `plot_type` is not 'violin' or 'boxplot'.
    """
    if plot_type not in ('violin', 'boxplot'):
        raise ValueError(f"plot_type must be 'violin' or 'boxplot', got {plot_type!r}")
    score_col = str(frac)

    list_dfs = []
    for key, val in dict_df_score_fracs.items():
        name = ' '.join(key.split('_')[:2])
        # assign() returns a new frame; the former `val['name'] = name` modified
        # the caller's DataFrames.
        list_dfs.append(val[['cluster', score_col]].assign(name=name)[['name', 'cluster', score_col]])
    df_joint = pd.concat(list_dfs, axis=0, ignore_index=True)

    list_clusters = [c for c in sorted(set(df_joint['cluster'].values)) if c not in ['T1', 'U']]

    if n_rows is None:
        n_rows = int(math.ceil(len(list_clusters) / n_cols))

    # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, n_rows * 2.85), squeeze=False)
    axs = axs.ravel()

    for idx, cluster in enumerate(list_clusters):
        ax = axs[idx]
        df_joint_cluster = df_joint[df_joint['cluster'] == cluster].sort_values(by='name')
        names = sorted(set(df_joint_cluster['name'].values))

        if plot_type == 'violin':
            sns.violinplot(data=df_joint_cluster, x='name', y=score_col, order=names, scale='width', cut=0,
                           inner='box', ax=ax, linewidth=1, color=dict_colors_human_mouse[cluster])
        else:
            sns.boxplot(data=df_joint_cluster, x='name', y=score_col, order=names, fliersize=0,
                        ax=ax, linewidth=1, color=dict_colors_human_mouse[cluster])

        ax.set_xlabel('')
        ax.set_ylabel('')
        ax.set_title(cluster)
        ax.set_xticklabels(names, size='small', rotation=45, ha='right')
        ax.set_yticks([0, 1])
        ax.set_yticklabels([0, 1], size='medium')

        sns.despine(left=True, bottom=True)

    for idx in range(len(list_clusters), n_cols * n_rows):
        axs[idx].set_axis_off()

    plt.tight_layout()

In [None]:
# One panel per cluster, one box per human dataset (fraction 0.99)
plot_distributions_bootstrap_dataset(
    dict_df_score_fracs_human,
    frac=0.99,
    n_cols=3,
    n_rows=None,
    plot_type='boxplot',
)

In [None]:
# One panel per cluster, one box per mouse dataset (fraction 0.99)
plot_distributions_bootstrap_dataset(
    dict_df_score_fracs_mouse,
    frac=0.99,
    n_cols=4,
    n_rows=None,
    plot_type='boxplot',
)

In [None]:
# Long table of per-cell 0.99-fraction scores for all human datasets
# ('name' = "<author> <year>"), without the 'T1' and 'U' clusters; the next cell
# summarises it per dataset and cluster. This used to call
# plot_distributions_bootstrap_general, which is defined further down and returns
# None, so the next cell failed.
xxxx = pd.concat(
    [val[['cluster', '0.99']].assign(name=' '.join(key.split('_')[:2]))[['name', 'cluster', '0.99']]
     for key, val in dict_df_score_fracs_human.items()],
    axis=0, ignore_index=True,
)
xxxx = xxxx[~xxxx['cluster'].isin(['T1', 'U'])]

In [None]:
# Mean score per (dataset, cluster), with the group keys as regular columns
xxxx.groupby(['name', 'cluster'], as_index=False).mean()

In [None]:
def plot_distributions_bootstrap_general(dict_df_score_fracs, frac=0.99, n_cols=5, n_rows=None, plot_type='violin'):
    """Summarise the bootstrap score across datasets and across clusters.

    For every (dataset, cluster) pair the median per-cell score at fraction `frac`
    is computed, with clusters 'T1' and 'U' excluded. The left panel shows the
    distribution of these medians per dataset, the right panel per cluster.

    Parameters
    ----------
    dict_df_score_fracs : dict
        Dataset key ("<author>_<year>_...") -> DataFrame with a 'cluster' column
        and a score column named ``str(frac)``. The DataFrames are not modified.
    frac : float or str
        Fraction whose score column is used.
    n_cols, n_rows :
        Unused; kept for signature compatibility with the other plotting helpers.
    plot_type : {'violin', 'boxplot'}

    Raises
    ------
    ValueError
        If `plot_type` is not 'violin' or 'boxplot'.
    """
    if plot_type not in ('violin', 'boxplot'):
        raise ValueError(f"plot_type must be 'violin' or 'boxplot', got {plot_type!r}")
    score_col = str(frac)

    # assign() works on a copy, so the caller's DataFrames are not modified
    list_dfs = [val[['cluster', score_col]].assign(name=' '.join(key.split('_')[:2]))
                for key, val in dict_df_score_fracs.items()]
    df_joint = pd.concat(list_dfs, axis=0, ignore_index=True)
    df_joint = df_joint[~df_joint['cluster'].isin(['T1', 'U'])]
    df_joint = df_joint.groupby(['name', 'cluster']).median().reset_index()

    names = sorted(set(df_joint['name'].values))
    clusters = sorted(set(df_joint['cluster'].values))
    palette = [dict_colors_human_mouse[c] for c in clusters]

    fig, axs = plt.subplots(1, 2, figsize=(2 * 4, 1 * 2.85))

    # Explicit `order`: without it seaborn orders clusters by first appearance in
    # df_joint (grouped by dataset first), which need not be sorted. The boxes would
    # then not match the sorted palette and tick labels.
    if plot_type == 'violin':
        sns.violinplot(data=df_joint, x='name', y=score_col, order=names, scale='width', cut=0,
                       inner='box', ax=axs[0], linewidth=1, color="#bcbcbc")
        sns.violinplot(data=df_joint, x='cluster', y=score_col, order=clusters, scale='width', cut=0,
                       inner='box', ax=axs[1], linewidth=1, palette=palette)
    else:
        sns.boxplot(data=df_joint, x='name', y=score_col, order=names, fliersize=0,
                    ax=axs[0], linewidth=1, color="#bcbcbc")
        sns.boxplot(data=df_joint, x='cluster', y=score_col, order=clusters, fliersize=0,
                    ax=axs[1], linewidth=1, palette=palette)

    for ax, title in zip(axs, ['Dataset', 'Cluster']):
        ax.set_xlabel('')
        ax.set_ylabel('')
        ax.set_title(title)
        ax.set_yticks([0, 1])
        ax.set_yticklabels([0, 1], size='medium')

    axs[0].set_xticklabels(names, size='small', rotation=45, ha='right')
    axs[1].set_xticklabels(clusters, size='medium')

    sns.despine(left=True, bottom=True)
    plt.tight_layout()

In [None]:
# Dataset/cluster summary for the human datasets. plot_type is passed explicitly
# instead of relying on the `plot_type` global left over from an earlier loop;
# n_cols/n_rows are ignored by this function and therefore omitted.
plot_distributions_bootstrap_general(dict_df_score_fracs_human, frac=0.99, plot_type='boxplot')

In [None]:
# Dataset/cluster summary for the mouse datasets. plot_type is passed explicitly
# instead of relying on the `plot_type` global left over from an earlier loop;
# n_cols/n_rows are ignored by this function and therefore omitted.
plot_distributions_bootstrap_general(dict_df_score_fracs_mouse, frac=0.99, plot_type='boxplot')

## Calculating the probability of a cell type being assigned to the same cell type or to other cell types

In [None]:
def assign_cluster_score(df_bootstrap, N_iters):
    """Build the cell-type adjacency matrix of one bootstrap table.

    Entry (a, b) is the fraction of the non-NaN bootstrap labels equal to b,
    among all iterations of the cells originally labelled a (column '-1').
    Rows are the original labels (sorted); columns are all labels appearing in
    the iteration columns '0' ... str(N_iters - 1) (sorted). A row whose cells
    never received a bootstrap label is NaN.

    NOTE: this redefines the earlier `assign_cluster_score` (which adds a
    'same_cluster_prop' column); cells calling that version must run before this one.
    """
    iter_cols = [str(i) for i in range(N_iters)]
    predictions = df_bootstrap.loc[:, iter_cols]

    flat = predictions.values.ravel()
    cell_types_out = sorted(set(flat[~pd.isnull(flat)]))
    cell_types_in = sorted(set(df_bootstrap.loc[:, '-1'].values.ravel()))

    # float from the start: filling an int frame with fractions is deprecated in pandas
    df_adjacency = pd.DataFrame(0.0, index=cell_types_in, columns=cell_types_out)

    for cell_type_in in cell_types_in:
        vals = predictions[df_bootstrap['-1'] == cell_type_in].values.ravel()
        vals = vals[~pd.isnull(vals)]
        if len(vals) == 0:
            df_adjacency.loc[cell_type_in, :] = np.nan  # no label at all: undefined row
            continue
        counts = pd.Series(vals).value_counts().reindex(cell_types_out, fill_value=0)
        df_adjacency.loc[cell_type_in, :] = counts.values / len(vals)

    return df_adjacency

In [None]:
# Adjacency matrix (original label -> frequency of each bootstrap label) per bootstrap table
dict_adjacency_dfs = {
    key: assign_cluster_score(table, N_iters)
    for key, table in dict_bootstraps.items()
}

In [None]:
def plot_distributions_bootstrap(dict_adjacency_dfs, n_cols=5, n_rows=None):
    """Draw one annotated heatmap per dataset of its cell-type adjacency matrix.

    For each matrix, the 'T1' and 'U' rows/columns are removed and entries below
    0.01 are hidden. Rows are re-normalised and rounded with `round_df`. Tick labels
    are coloured with `dict_colors_human_mouse`. Unused grid cells are hidden.

    NOTE: this redefines the earlier `plot_distributions_bootstrap` (score
    distributions) with a different signature.
    """
    if n_rows is None:
        n_rows = int(math.ceil(len(dict_adjacency_dfs) / n_cols))

    # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 3, n_rows * 3), squeeze=False)
    axs = axs.ravel()
    excluded = ['T1', 'U']

    for idx, (key, val) in enumerate(dict_adjacency_dfs.items()):
        ax = axs[idx]
        rows = [i for i in val.index if i not in excluded]
        cols = [i for i in val.columns if i not in excluded]
        val_nan = val.loc[rows, cols].copy()
        # Hide negligible transitions (np.NaN was removed in NumPy 2.0; np.nan works everywhere)
        val_nan[val_nan < 0.01] = np.nan

        sns.heatmap(round_df(val_nan), annot=True, cmap='Blues', annot_kws={"fontsize": 'x-small'}, fmt='.2f',
                    yticklabels=True, xticklabels=True, cbar=False, ax=ax, square=True)
        for tick in ax.xaxis.get_ticklabels():
            tick.set_color(dict_colors_human_mouse[tick.get_text()])
        for tick in ax.yaxis.get_ticklabels():
            tick.set_color(dict_colors_human_mouse[tick.get_text()])
        ax.set_xticklabels(ax.get_xticklabels(), size='medium', weight='bold')
        ax.set_yticklabels(ax.get_yticklabels(), va='center', size='medium', weight='bold')
        ax.set_title(' '.join(key.split('_')[:2]))

        change_texts_format(ax)

    # Hide the unused grid cells. This used to be computed from the unrelated
    # global `dict_df_score_fracs`, which could hide drawn heatmaps or leave empty axes.
    for idx in range(len(dict_adjacency_dfs), n_cols * n_rows):
        axs[idx].set_axis_off()

    plt.tight_layout()

In [None]:
# Adjacency heatmaps of the human datasets at fraction 0.99, sorted by key
human_099_keys = sorted(key for key in dict_adjacency_dfs if 'human' in key and '0.99' in key)
dict_adjacency_dfs_human_099 = {key: dict_adjacency_dfs[key] for key in human_099_keys}
plot_distributions_bootstrap(dict_adjacency_dfs_human_099)

In [None]:
# Adjacency heatmaps of the mouse datasets at fraction 0.99, sorted by key
mouse_099_keys = sorted(key for key in dict_adjacency_dfs if 'mouse' in key and '0.99' in key)
dict_adjacency_dfs_mouse_099 = {key: dict_adjacency_dfs[key] for key in mouse_099_keys}
plot_distributions_bootstrap(dict_adjacency_dfs_mouse_099, n_cols=3)

In [None]:
# Unified plot
def make_unified_adjacency(list_adjacency_dfs, exclude=('U', 'T1')):
    """
    Combine several per-dataset adjacency matrices into a single one.

    Entry (c_i, c_o) of the result is the median of that entry over the datasets
    whose matrix has row c_i and column c_o. Datasets lacking either cluster are
    ignored for that entry; if no dataset has it, the entry is NaN. Each row is
    then normalised to sum to 1 (NaNs are ignored in the row sum).

    Parameters
    ----------
    list_adjacency_dfs : list of pandas.DataFrame
        Matrices with original clusters as index and assigned clusters as columns.
    exclude : iterable of str
        Cluster labels left out of the result (default: 'U' and 'T1').

    Returns
    -------
    pandas.DataFrame
        Square float matrix over the sorted union of the row labels.
    """
    excluded = set(exclude)
    list_clusters = sorted({c for df in list_adjacency_dfs for c in df.index if c not in excluded})

    df_unified = pd.DataFrame(np.nan, index=list_clusters, columns=list_clusters)

    for c_i in list_clusters:
        for c_o in list_clusters:
            list_vals = [df.loc[c_i, c_o] for df in list_adjacency_dfs
                         if c_i in df.index and c_o in df.columns]
            if list_vals:  # np.median([]) would warn and return NaN anyway
                df_unified.loc[c_i, c_o] = np.median(list_vals)

    return df_unified.div(df_unified.sum(axis=1), axis=0)

In [None]:
# Unified adjacency of the human datasets at fraction 0.99 (median over datasets)
df_unified_human = make_unified_adjacency(
    [val for key, val in dict_adjacency_dfs.items() if '0.99' in key and 'human' in key])
# Hide transitions below 1% (np.NaN was removed in NumPy 2.0; np.nan works everywhere)
df_unified_human[df_unified_human < 0.01] = np.nan

fig, ax = plt.subplots(1, 1)
sns.heatmap(round_df(df_unified_human), annot=True, cmap='Blues', annot_kws={"fontsize": 'x-small'},
            yticklabels=True, cbar=False, ax=ax)
for tick in ax.xaxis.get_ticklabels():
    tick.set_color(dict_colors_human[tick.get_text()])
for tick in ax.yaxis.get_ticklabels():
    tick.set_color(dict_colors_human[tick.get_text()])
ax.set_xticklabels(ax.get_xticklabels(), weight='bold')
ax.set_yticklabels(ax.get_yticklabels(), weight='bold', va='center')
change_texts_format(ax)
None

In [None]:
# Unified adjacency of the mouse datasets at fraction 0.99 (median over datasets)
df_unified_mouse = make_unified_adjacency(
    [val for key, val in dict_adjacency_dfs.items() if '0.99' in key and 'mouse' in key])
# Hide transitions below 1% (np.NaN was removed in NumPy 2.0; np.nan works everywhere)
df_unified_mouse[df_unified_mouse < 0.01] = np.nan

fig, ax = plt.subplots(1, 1, figsize=(7, 5))
sns.heatmap(round_df(df_unified_mouse), annot=True, cmap='Blues', annot_kws={"fontsize": 'x-small'},
            yticklabels=True, cbar=False, ax=ax)
for tick in ax.xaxis.get_ticklabels():
    tick.set_color(dict_colors_mouse[tick.get_text()])
for tick in ax.yaxis.get_ticklabels():
    tick.set_color(dict_colors_mouse[tick.get_text()])
ax.set_xticklabels(ax.get_xticklabels(), weight='bold')
ax.set_yticklabels(ax.get_yticklabels(), weight='bold', va='center')
# Same compact annotation format as every other heatmap in this notebook
change_texts_format(ax)
None

## Calculating, for each cell type, the proportion of cells whose cluster score is higher than a given value

In [None]:
def calculate_proportion_cell_higher_prop(df_bootstrap, prop):
    """Per original label (column '-1'), the fraction of cells with 'same_cluster_prop' > `prop`.

    Returns a dict keyed by label in sorted order (as given by np.unique). Labels
    with no cell above the threshold map to 0.
    """
    labels, totals = np.unique(df_bootstrap['-1'], return_counts=True)

    above = df_bootstrap.loc[df_bootstrap['same_cluster_prop'] > prop, '-1']
    above_labels, above_counts = np.unique(above, return_counts=True)
    n_above = dict(zip(above_labels.tolist(), above_counts.tolist()))

    return {label: n_above.get(label, 0) / total
            for label, total in zip(labels.tolist(), totals.tolist())}

# An alternative score tried previously: the mean 'same_cluster_prop' of the cells
# of each original label, instead of the fraction of cells above `prop`.

def calculate_df_proportion_cell_higher_prop(dict_bootstraps, prop):
    """Table of `calculate_proportion_cell_higher_prop` results for several bootstrap tables.

    Rows are the keys of `dict_bootstraps` (insertion order). Columns are the sorted
    union of all original labels (column '-1'). Labels absent from a table stay
    empty (NaN). The frame has object dtype; callers cast it with `astype(float)`.
    """
    all_labels = sorted({label
                         for table in dict_bootstraps.values()
                         for label in table['-1'].values.tolist()})

    df_return = pd.DataFrame(columns=all_labels, index=list(dict_bootstraps))

    for name, df_bootstrap in dict_bootstraps.items():
        props = calculate_proportion_cell_higher_prop(df_bootstrap, prop)
        df_return.loc[name, list(props)] = list(props.values())

    return df_return

In [None]:
# A cell counts as stable when its same_cluster_prop is above this value
prop_threshold = 0.7

dict_human_099 = {key: dict_bootstraps[key]
                  for key in sorted(key for key in dict_bootstraps if 'human' in key and '0.99' in key)}
df_prop = calculate_df_proportion_cell_higher_prop(dict_human_099, prop_threshold)
df_prop = df_prop[[col for col in df_prop.columns if col not in ['T1', 'U']]]
df_prop = df_prop.astype(float)
# datasets ordered by decreasing mean proportion
df_prop = df_prop.loc[df_prop.mean(1).sort_values().index[::-1]]

# margins: mean per dataset (last column) and per cluster (last row);
# the corner cell is blanked (np.NaN was removed in NumPy 2.0)
df_prop['mean (dataset)'] = df_prop.mean(1)
df_prop.loc['mean (cluster)'] = df_prop.mean(0)
df_prop.iloc[-1, -1] = np.nan

fig, ax = plt.subplots(1, 1, figsize=(6, 5))
g = sns.heatmap(df_prop, annot=True, cmap='Blues', yticklabels=True, annot_kws={"fontsize": 'xx-small'},
                fmt='.2f', cbar=False, ax=ax)
g.set_yticklabels([' '.join(i.get_text().split('_')[:2]) for i in g.get_ymajorticklabels()], fontsize='x-small')
g.set_xticklabels(g.get_xmajorticklabels(), fontsize='large')
for tick in ax.xaxis.get_ticklabels():
    tick.set_color(dict_colors_human.get(tick.get_text(), "#000000"))
for tick in ax.yaxis.get_ticklabels():
    tick.set_color(dict_colors_human.get(tick.get_text(), "#000000"))
ax.set_xticklabels(ax.get_xticklabels(), weight='bold')
change_texts_format(ax)
None

In [None]:
# A cell counts as stable when its same_cluster_prop is above this value
prop_threshold = 0.7

dict_mouse_099 = {key: dict_bootstraps[key]
                  for key in sorted(key for key in dict_bootstraps if 'mouse' in key and '0.99' in key)}
df_prop = calculate_df_proportion_cell_higher_prop(dict_mouse_099, prop_threshold)
df_prop = df_prop[[col for col in df_prop.columns if col not in ['T1', 'U']]]
df_prop = df_prop.astype(float)
# datasets ordered by decreasing mean proportion
df_prop = df_prop.loc[df_prop.mean(1).sort_values().index[::-1]]

# margins: mean per dataset (last column) and per cluster (last row);
# the corner cell is blanked (np.NaN was removed in NumPy 2.0)
df_prop['mean (dataset)'] = df_prop.mean(1)
df_prop.loc['mean (cluster)'] = df_prop.mean(0)
df_prop.iloc[-1, -1] = np.nan

fig, ax = plt.subplots(1, 1, figsize=(6, 2))
g = sns.heatmap(df_prop, annot=True, cmap='Blues', yticklabels=True, annot_kws={"fontsize": 'xx-small'},
                fmt='.2f', cbar=False, ax=ax)
g.set_yticklabels([' '.join(i.get_text().split('_')[:2]) for i in g.get_ymajorticklabels()], fontsize='x-small')
g.set_xticklabels(g.get_xmajorticklabels(), fontsize='large')
for tick in ax.xaxis.get_ticklabels():
    tick.set_color(dict_colors_mouse.get(tick.get_text(), "#000000"))
for tick in ax.yaxis.get_ticklabels():
    tick.set_color(dict_colors_mouse.get(tick.get_text(), "#000000"))
ax.set_xticklabels(ax.get_xticklabels(), weight='bold')
change_texts_format(ax)
None