In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc
import numpy as np
import anndata
from anndata import AnnData
from matplotlib import gridspec
import matplotlib as mpl
from scipy import stats
%matplotlib inline

In [None]:
#functions
def silheatmap(adata,clust,marker_list,sil_key):
    """
    Draw a marker-by-cell heatmap annotated with a cluster dendrogram (top),
    a per-cell silhouette strip and a cluster colour strip (bottom).

    Cells (heatmap columns) are grouped by cluster in dendrogram leaf order
    and, inside each cluster, sorted by ascending silhouette value.

    Parameters
    ----------
    adata
        Annotated data object. The function reads ``adata.to_df()``, the
        ``adata.obs`` columns ``clust`` and ``sil_key``, and the ``adata.uns``
        entries ``dendrogram_{clust}`` (keys ``categories_ordered`` and
        ``dendrogram_info``) and ``{clust}_colors``.
    clust
        Name of the ``adata.obs`` column holding the cluster labels.
    marker_list
        Variable names shown as heatmap rows, top to bottom.
    sil_key
        Name of the ``adata.obs`` column holding the silhouette values.

    Returns
    -------
    The matplotlib figure that was created (it is left open).
    """
    dendro_key = f'dendrogram_{clust}'
    # cluster labels (as str) in dendrogram leaf order
    cluster_list = [str(item) for item in adata.uns[dendro_key]['categories_ordered']]

    # one row per cell: marker values plus cluster label and silhouette value
    df = adata.to_df()
    df[clust] = adata.obs[clust]
    df[sil_key] = adata.obs[sil_key]
    # ascending silhouette; the per-cluster selection below keeps this order inside each cluster
    df = df.sort_values(by=sil_key)
    # remember the cell ids so df can be put in the same order as obs_tidy
    df['old_index'] = df.index
    obs_tidy = df.set_index(clust)
    obs_tidy.index = obs_tidy.index.astype('str')
    obs_tidy = obs_tidy.loc[cluster_list,:]
    df = df.loc[obs_tidy.old_index]
    obs_tidy = obs_tidy.loc[:,marker_list]

    # Layout: 4 rows x 4 columns.
    #   rows: dendrogram, heatmap, silhouette strip, cluster colour strip
    #   cols: main content, zero-width spacer, expression colorbar, silhouette colorbar
    colorbar_width = 0.2
    var_names = marker_list
    width = 10
    dendro_height = 0.8
    groupby_height = 0.13
    heatmap_height = len(var_names) * 0.18
    height = heatmap_height + dendro_height + groupby_height + groupby_height
    height_ratios = [dendro_height, heatmap_height, groupby_height,groupby_height]
    width_ratios = [width, 0, colorbar_width, colorbar_width]
    fig = plt.figure(figsize=(width, height),dpi=200)
    axs = gridspec.GridSpec(
        nrows=4,
        ncols=4,
        wspace=1 / width,
        hspace=0.3 / height,
        width_ratios=width_ratios,
        height_ratios=height_ratios,
    )
    norm = mpl.colors.Normalize(vmin=0, vmax=1, clip=False)    # expression colour scale
    norm2 = mpl.colors.Normalize(vmin=-1, vmax=1, clip=False)  # silhouette colour scale

    # heatmap: markers as rows, cells as columns
    heatmap_ax = fig.add_subplot(axs[1, 0])
    # cmap is given explicitly so the image always matches the 'viridis' colorbar drawn below
    heatmap_ax.imshow(obs_tidy.T.values, aspect='auto',norm=norm,cmap='viridis',interpolation='nearest')
    heatmap_ax.set_xlim(0 - 0.5, obs_tidy.shape[0] - 0.5)
    heatmap_ax.set_ylim(obs_tidy.shape[1] - 0.5, -0.5)
    heatmap_ax.tick_params(axis='x', bottom=False, labelbottom=False)
    heatmap_ax.set_xlabel('')
    heatmap_ax.grid(False)
    heatmap_ax.tick_params(axis='y', labelsize='small', length=1)
    heatmap_ax.set_yticks(np.arange(len(var_names)))
    heatmap_ax.set_yticklabels(var_names, rotation=0)

    # cluster colour strip
    value_sum = 0
    ticks = []  # centre position of each cluster's run of cells
    labels = []
    label2code = {}  # numerical code assigned to each cluster label
    cluster_sizes = obs_tidy.index.value_counts().loc[cluster_list]
    # Series.items() replaces iteritems(), which was removed in pandas 2.0
    for code, (label, value) in enumerate(cluster_sizes.items()):
        ticks.append(value_sum + (value / 2))
        labels.append(label)
        value_sum += value
        label2code[label] = code

    # Codes follow the dendrogram order, while the stored palette is assumed to follow
    # the category order of adata.obs[clust] (scanpy convention); re-order the palette
    # so that every cluster keeps its own colour. Falls back to the stored order when
    # the two cannot be matched.
    palette = list(adata.uns[f'{clust}_colors'])
    cluster_col = adata.obs[clust]
    categories = [str(c) for c in cluster_col.cat.categories] if hasattr(cluster_col, 'cat') else []
    if len(palette) == len(categories) and set(labels) <= set(categories):
        palette = [palette[categories.index(lab)] for lab in labels]
    groupby_cmap = mpl.colors.ListedColormap(palette)
    groupby_ax = fig.add_subplot(axs[3, 0])
    groupby_ax.imshow(
        np.array([[label2code[lab] for lab in obs_tidy.index]]),
        aspect='auto',
        cmap=groupby_cmap,
    )
    groupby_ax.grid(False)
    groupby_ax.yaxis.set_ticks([])
    groupby_ax.set_xticks(ticks,labels,fontsize='xx-small',rotation=90)
    groupby_ax.set_ylabel('Cluster',fontsize='x-small',rotation=0,ha='right',va='center')

    # silhouette strip: the single row of values is repeated to give the image some height
    sil_ax = fig.add_subplot(axs[2, 0])
    sil_values = np.array([df[sil_key]])
    # at least one row, otherwise fewer than 80 cells would give an empty image
    n_rows = max(1, len(df) // 80)
    a_tile = np.tile(sil_values,(n_rows,1))
    sil_ax.imshow(a_tile,cmap='bwr',norm=norm2)
    sil_ax.xaxis.set_ticks([])
    sil_ax.yaxis.set_ticks([])
    sil_ax.set_ylabel('Silhouette',fontsize='x-small',rotation=0,ha='right',va='center')
    sil_ax.grid(False)

    # dendrogram, with its leaves moved onto the centres of the cluster runs
    dendro_ax = fig.add_subplot(axs[0, 0], sharex=heatmap_ax)
    dendro_info = adata.uns[dendro_key]['dendrogram_info']
    leaves = dendro_info["ivl"]
    icoord = np.array(dendro_info['icoord'])
    dcoord = np.array(dendro_info['dcoord'])
    # leaf positions used by the stored coordinates: 5, 15, 25, ...
    orig_ticks = np.arange(5, len(leaves) * 10 + 5, 10).astype(float)
    for xs, ys in zip(icoord, dcoord):
        xs = translate_pos(xs, ticks, orig_ticks)
        dendro_ax.plot(xs, ys, color='#555555')
    dendro_ax.tick_params(bottom=False, top=False, left=False, right=False)
    dendro_ax.set_xticks(ticks)
    dendro_ax.set_xticklabels([])
    dendro_ax.tick_params(labelleft=False, labelright=False)
    dendro_ax.grid(False)
    for side in ('right', 'top', 'left', 'bottom'):
        dendro_ax.spines[side].set_visible(False)

    # colorbars: expression (grid column 2) and silhouette (grid column 3)
    for col, cb_norm, cb_cmap, cb_locs, cb_label, cb_pad in (
        (2, norm, 'viridis', [0,1], 'Expression', -5),
        (3, norm2, 'bwr', [-1,0,1], 'Silhouette Score', 0),
    ):
        cbar_ax = fig.add_subplot(axs[1, col])
        mappable = mpl.cm.ScalarMappable(norm=cb_norm, cmap=cb_cmap)
        cbar = plt.colorbar(mappable=mappable, cax=cbar_ax)
        cbar_ax.tick_params(axis='both', which='major', labelsize='xx-small',rotation=90,length=.1)
        cbar_ax.yaxis.set_major_locator(mpl.ticker.FixedLocator(locs=cb_locs))
        cbar.set_label(cb_label, fontsize='xx-small',labelpad=cb_pad)

    return(fig)

def translate_pos(pos_list, new_ticks, old_ticks):
    """
    Transform dendrogram x-coordinates from one set of tick positions to another.

    A coordinate that equals one of ``old_ticks`` is mapped to the entry of
    ``new_ticks`` at the same position. Any other coordinate is linearly
    interpolated between its two neighbouring ticks; a coordinate outside the
    range of ``old_ticks`` is extrapolated from the nearest pair of ticks.

    Parameters
    ----------
    pos_list
        Iterable of coordinates to transform.
    new_ticks
        Target tick positions, one per entry of ``old_ticks``.
    old_ticks
        Original tick positions in ascending order (list, numpy array or
        other sequence).

    Returns
    -------
    list
        The transformed coordinates, in the order of ``pos_list``.

    Raises
    ------
    ValueError
        If a coordinate has to be interpolated but fewer than two ticks exist.
    """
    if not isinstance(old_ticks, list):
        # numpy arrays, tuples, ... -> plain list so `in` and `.index()` work below
        old_ticks = np.asarray(old_ticks).tolist()
    last = len(old_ticks) - 1
    new_xs = []
    for x_val in pos_list:
        if x_val in old_ticks:
            new_x_val = new_ticks[old_ticks.index(x_val)]
        else:
            if last < 1:
                raise ValueError("at least two ticks are needed to interpolate a position")
            # index of the first tick that is >= x_val, clamped to a valid
            # (previous, next) pair: without the clamp a value below the first
            # tick wraps around to index -1 and a value above the last tick
            # runs past the end of the list
            idx_next = int(np.searchsorted(old_ticks, x_val, side="left"))
            idx_next = min(max(idx_next, 1), last)
            idx_prev = idx_next - 1
            old_min = old_ticks[idx_prev]
            old_max = old_ticks[idx_next]
            new_min = new_ticks[idx_prev]
            new_max = new_ticks[idx_next]
            new_x_val = ((x_val - old_min) / (old_max - old_min)) * (
                new_max - new_min
            ) + new_min
        new_xs.append(new_x_val)
    return new_xs

## load data

In [None]:
# Show the files available in the local 'data' directory (displayed as the cell output).
os.listdir('data')

In [None]:
# Load the raw data table, using the first CSV column as the row index.
data_df = pd.read_csv('./data/raw_data.csv',index_col=0)
# Index labels are cast to str; later cells select rows of data_df with clust_df.index,
# which is cast to str in the same way.
data_df.index = data_df.index.astype('str')

In [None]:
# Number of rows in the raw data table, followed by a preview of its first rows.
print(len(data_df))
data_df.head()

In [None]:
# Load the clustering table, using the first CSV column as the row index.
# Later cells read the cluster label columns and the 'emb1'/'emb2' embedding columns from it.
clust_df = pd.read_csv('./data/clusterings.csv',index_col=0)
# Cast the index to str so it matches the index of data_df and sil_df.
clust_df.index = clust_df.index.astype('str')

In [None]:
# Number of rows in the clustering table, followed by a preview of its first rows.
print(len(clust_df))
clust_df.head()

In [None]:
# Load the silhouette coefficients indexed by 'CellID' and discard the
# 'Unnamed: 0' column that comes with the CSV.
sil_df = (
    pd.read_csv('./data/silhouette_coefficients.csv', index_col='CellID')
    .drop(columns='Unnamed: 0')
)
# Same str index as the other tables.
sil_df.index = sil_df.index.astype('str')

## data summary 

In [None]:
# Combine cluster labels and silhouette values row by row, matching on the index
# (default inner merge: only index labels present in both tables are kept).
df_all = pd.merge(clust_df, sil_df, left_index=True, right_index=True)

In [None]:
# Mean silhouette value of every cluster, for each clustering method, shown as one box per method.
clustering_list = ['hdbscan', 'kmeans', 'gmm', 'flowsom', 'LeidenCluster','Consensus_Cluster']
# One single-column frame per method (rows = cluster labels, column = method name with
# 'Cluster' removed); the frames are joined side by side in a single concat.
mean_frames = []
for clust in clustering_list:
    df_mean = df_all.loc[:,[clust,f'{clust}_silhuette']].groupby(clust).mean()
    df_mean = df_mean.rename({f'{clust}_silhuette':clust.replace('Cluster','')},axis=1)
    mean_frames.append(df_mean)
df_mean_all = pd.concat(mean_frames,axis=1)
# Methods ordered by decreasing median; ls_order is reused by the std plot below.
ls_order = df_mean_all.median().sort_values(ascending=False).index
df_mean_all = df_mean_all.loc[:,ls_order]
fig,ax = plt.subplots(figsize=(8,4),dpi=200)
sns.boxplot(data=df_mean_all,showfliers=False,ax=ax,palette='muted')
sns.stripplot(data=df_mean_all,ax=ax,palette='dark')
ax.set_title('Silhouette Values: Mean per Cluster')
ax.set_ylabel('Mean Silhouette')
ax.set_xlabel('Type')
plt.tight_layout()
# This is the first cell that writes a figure, so the output directory may not exist yet.
os.makedirs('./figures', exist_ok=True)
fig.savefig('./figures/all_boxplot.png')

In [None]:
# Standard deviation of the silhouette values inside every cluster, for each clustering
# method, shown as one box per method.
clustering_list = ['hdbscan', 'kmeans', 'gmm', 'flowsom', 'LeidenCluster','Consensus_Cluster']
# One single-column frame per method (rows = cluster labels, column = method name with
# 'Cluster' removed); the frames are joined side by side in a single concat.
std_frames = []
for clust in clustering_list:
    df_std = df_all.loc[:,[clust,f'{clust}_silhuette']].groupby(clust).std()
    df_std = df_std.rename({f'{clust}_silhuette':clust.replace('Cluster','')},axis=1)
    std_frames.append(df_std)
df_std_all = pd.concat(std_frames,axis=1)
# Same column order as the mean plot (ls_order is defined in the previous cell).
df_std_all = df_std_all.loc[:,ls_order]
fig,ax = plt.subplots(figsize=(8,4),dpi=200)
sns.boxplot(data=df_std_all,showfliers=False,ax=ax,palette='muted')
sns.stripplot(data=df_std_all,ax=ax,palette='dark')
# The plotted values are per-cluster standard deviations, not means.
ax.set_title('Silhouette Values: Std per Cluster')
ax.set_ylabel('Std Silhouette')
ax.set_xlabel('Type')
plt.tight_layout()
# Make sure the output directory exists before writing the figure.
os.makedirs('./figures', exist_ok=True)
fig.savefig('./figures/all_boxplot_std.png')

In [None]:
# Disabled experiment: the code below is wrapped in a string literal, so it is NOT executed
# (the cell only echoes the string). It would combine the per-cluster mean (shifted by +1)
# and 1 - std with a harmonic mean and draw the same kind of boxplot as the cells above.
'''
a_hmean = stats.hmean(np.dstack((df_mean_all.fillna(-1).values+1,1-df_std_all.fillna(0).values)),axis=2)
df_hmean = pd.DataFrame(a_hmean,index=df_mean_all.index,columns=df_mean_all.columns)
df_hmean = df_hmean.replace({0.000000:np.nan})
df_hmean = df_hmean.loc[:,ls_order]
fig,ax = plt.subplots(figsize=(8,4),dpi=200)
sns.boxplot(data=df_hmean,showfliers=False,ax=ax,palette='muted')
sns.stripplot(data=df_hmean,ax=ax,palette='dark')
ax.set_title('Silhouette Values Mean and Std: Harmonic Mean')
ax.set_ylabel('h-mean')
ax.set_xlabel('Type')
plt.tight_layout()
fig.savefig('./figures/all_boxplot_hmean.png')
'''

## visualize umap

Full data set: UMAP embedding of all cells, with one panel per measured variable.

In [None]:
# scanpy figure defaults for the plots below: 300 dpi and a white face colour.
sc.settings.set_figure_params(dpi=300, facecolor='white')

In [None]:
# AnnData object built from the rows of data_df selected (and ordered) by clust_df.index.
adata_big = sc.AnnData(X=data_df.loc[clust_df.index])

In [None]:
# Use the 'emb1'/'emb2' columns of clust_df as the 2-D embedding stored under 'X_umap';
# rows line up because adata_big was built in clust_df.index order.
adata_big.obsm['X_umap'] = np.array(clust_df.loc[:,['emb1', 'emb2']])

In [None]:
# UMAP of the full data set coloured by every variable in adata_big.var, 5 panels per row;
# `save` additionally has scanpy write the figure to disk.
sc.pl.umap(adata_big,color=adata_big.var.index,ncols=5,save='fulldata.png')

In [None]:
#sc.pl.highest_expr_genes(adata)

## visualize subsample

In [None]:
# Build an AnnData object from a 10% random subsample of the rows (drawn without
# replacement, fixed seed 123 so the same rows are drawn on every run).
adata = sc.AnnData(data_df.loc[clust_df.index].sample(frac=0.1, replace=False, random_state=123))
# Store the object as its own .raw attribute.
adata.raw=adata

In [None]:
# Add the embedding: 'emb1'/'emb2' of the subsampled rows, looked up by the obs index.
adata.obsm['X_umap'] = np.array(clust_df.loc[adata.obs.index,['emb1', 'emb2']])
#sc.pl.umap(adata,color=adata.var.index)

In [None]:
# Add clustering and silhouette annotations to the subsample.
# clustering_list holds the clust_df columns that contain the labels of each clustering.
clustering_list = ['hdbscan', 'kmeans', 'gmm', 'flowsom', 'LeidenCluster','Consensus_Cluster']
#adata.obs = clust_df.loc[adata.obs.index,cluster_list].astype('category')
# obs = the cluster label columns (cast to categorical) next to every column of sil_df,
# both restricted to the subsampled rows.
adata.obs = pd.concat([clust_df.loc[adata.obs.index,clustering_list].astype('category'),sil_df.loc[adata.obs.index,:]],axis=1)

In [None]:
# UMAP of the subsample, one panel per clustering (3 panels per row), coloured by cluster label;
# `save` additionally has scanpy write the figure to disk.
sc.pl.umap(adata,color=clustering_list,ncols=3,wspace=.4,save='all_clusters.png')

In [None]:
# Matrix plot of the markers per cluster, with the markers ordered by hierarchical clustering.
for idx, clust in enumerate(clustering_list):
    # mean of every marker inside each cluster
    df = adata.to_df()
    df[clust] = adata.obs[clust]
    cluster_means = df.groupby(clust).mean()
    # the clustermap is only used to obtain a column (marker) order, so its figure is closed at once
    g = sns.clustermap(cluster_means)
    plt.close(g.fig)
    marker_list = cluster_means.columns[g.dendrogram_col.reordered_ind].tolist()
    sc.pl.matrixplot(
        adata,
        var_names=marker_list,
        groupby=clust,
        dendrogram=True,
        save=f'{clust}.png',
    )
    # stop after the first clustering in clustering_list
    break

In [None]:
# Stacked violin plot of the markers per cluster, one figure per clustering,
# with the markers ordered by hierarchical clustering.
for idx, clust in enumerate(clustering_list):
    # mean of every marker inside each cluster
    df = adata.to_df()
    df[clust] = adata.obs[clust]
    cluster_means = df.groupby(clust).mean()
    # the clustermap is only used to obtain a column (marker) order, so its figure is closed at once
    g = sns.clustermap(cluster_means)
    plt.close(g.fig)
    marker_list = cluster_means.columns[g.dendrogram_col.reordered_ind].tolist()
    sc.pl.stacked_violin(
        adata,
        var_names=marker_list,
        groupby=clust,
        dendrogram=True,
        save=f'{clust}.png',
    )
    

In [None]:
# Heatmaps per clustering: scanpy's own heatmap followed by the custom silhouette heatmap.
for idx, clust in enumerate(clustering_list):
    # mean of every marker inside each cluster
    df = adata.to_df()
    df[clust] = adata.obs[clust]
    cluster_means = df.groupby(clust).mean()
    # the clustermap is only used to obtain a column (marker) order, so its figure is closed at once
    g = sns.clustermap(cluster_means)
    plt.close(g.fig)
    marker_list = cluster_means.columns[g.dendrogram_col.reordered_ind].tolist()
    # obs column holding the silhouette values of this clustering
    sil_key = f'{clust}_silhuette'
    ax = sc.pl.heatmap(
        adata,
        var_names=marker_list,
        groupby=clust,
        dendrogram=True,
        save=f'_{clust}.png',
        cmap='viridis',
        swap_axes=True,
        figsize=(12,5),
    )
    # custom heatmap with the silhouette strip, written next to the scanpy figures
    fig = silheatmap(adata, clust, marker_list, sil_key)
    fig.savefig(f'./figures/myheatmap_{clust}.png')
    # stop after the first clustering in clustering_list
    break

## code development

In [None]:
# Development version of silheatmap(), run inline on the `adata`, `clust` and `marker_list`
# left in the kernel by the heatmap cell above.
from matplotlib import gridspec
import matplotlib as mpl

# Cluster labels (as str) in dendrogram leaf order. Defined here because no earlier cell
# creates `cluster_list`; this is the same source silheatmap() uses.
cluster_list = pd.Index(adata.uns[f'dendrogram_{clust}']['categories_ordered']).astype(str)

# one row per cell: marker values plus cluster label and silhouette value
df = adata.to_df()
df[clust] = adata.obs[clust]
# ascending silhouette; the per-cluster selection below keeps this order inside each cluster
df[f'{clust}_silhuette'] = adata.obs[f'{clust}_silhuette']
df = df.sort_values(by=f'{clust}_silhuette')
# remember the cell ids so df can be put in the same order as obs_tidy
df['old_index'] = df.index
obs_tidy = df.set_index(clust)
obs_tidy.index = obs_tidy.index.astype('str')
obs_tidy = obs_tidy.loc[cluster_list.astype(str),:]
df = df.loc[obs_tidy.old_index]
obs_tidy = obs_tidy.loc[:,marker_list]

# Layout: 4 rows x 4 columns.
#   rows: dendrogram, heatmap, silhouette strip, cluster colour strip
#   cols: main content, zero-width spacer, expression colorbar, silhouette colorbar
colorbar_width = 0.2
var_names = marker_list
width = 10
dendro_height = 0.8
groupby_height = 0.13
heatmap_height = len(var_names) * 0.18
height = heatmap_height + dendro_height + groupby_height + groupby_height
height_ratios = [dendro_height, heatmap_height, groupby_height,groupby_height]
width_ratios = [width, 0, colorbar_width, colorbar_width]
fig = plt.figure(figsize=(width, height),dpi=200)
axs = gridspec.GridSpec(
    nrows=4,
    ncols=4,
    wspace=1 / width,
    hspace=0.3 / height,
    width_ratios=width_ratios,
    height_ratios=height_ratios,
)
norm = mpl.colors.Normalize(vmin=0, vmax=1, clip=False)    # expression colour scale
norm2 = mpl.colors.Normalize(vmin=-1, vmax=1, clip=False)  # silhouette colour scale

# heatmap: markers as rows, cells as columns
heatmap_ax = fig.add_subplot(axs[1, 0])
# cmap is given explicitly so the image always matches the 'viridis' colorbar drawn below
im = heatmap_ax.imshow(obs_tidy.T.values, aspect='auto',norm=norm,cmap='viridis',interpolation='nearest')
heatmap_ax.set_xlim(0 - 0.5, obs_tidy.shape[0] - 0.5)
heatmap_ax.set_ylim(obs_tidy.shape[1] - 0.5, -0.5)
heatmap_ax.tick_params(axis='x', bottom=False, labelbottom=False)
heatmap_ax.set_xlabel('')
heatmap_ax.grid(False)
heatmap_ax.tick_params(axis='y', labelsize='small', length=1)
heatmap_ax.set_yticks(np.arange(len(var_names)))
heatmap_ax.set_yticklabels(var_names, rotation=0)

# cluster colour strip
value_sum = 0
ticks = []  # centre position of each cluster's run of cells
labels = []
label2code = {}  # numerical code assigned to each cluster label
# Series.items() replaces iteritems(), which was removed in pandas 2.0
for code, (label, value) in enumerate(
    obs_tidy.index.value_counts().loc[cluster_list.astype('str')].items()
):
    ticks.append(value_sum + (value / 2))
    labels.append(label)
    value_sum += value
    label2code[label] = code

# Codes follow the dendrogram order, while the stored palette is assumed to follow the
# category order of adata.obs[clust] (scanpy convention); re-order the palette so that
# every cluster keeps its own colour. Falls back to the stored order when the two
# cannot be matched.
palette = list(adata.uns[f'{clust}_colors'])
categories = [str(c) for c in adata.obs[clust].cat.categories] if hasattr(adata.obs[clust], 'cat') else []
if len(palette) == len(categories) and set(labels) <= set(categories):
    palette = [palette[categories.index(lab)] for lab in labels]
groupby_cmap = mpl.colors.ListedColormap(palette)
groupby_ax = fig.add_subplot(axs[3, 0])
groupby_ax.imshow(
    np.array([[label2code[lab] for lab in obs_tidy.index]]),
    aspect='auto',
    cmap=groupby_cmap,
)
groupby_ax.grid(False)
groupby_ax.yaxis.set_ticks([])
groupby_ax.set_xticks(ticks,labels,fontsize='xx-small',rotation=90)
groupby_ax.set_ylabel('Cluster',fontsize='x-small',rotation=0,ha='right',va='center')


# silhouette strip: the single row of values is repeated to give the image some height
sil_ax = fig.add_subplot(axs[2, 0])
a=np.array([df[f'{clust}_silhuette']])
# at least one row, otherwise fewer than 80 cells would give an empty image
a_tile = np.tile(a,(max(1, len(df) // 80),1))
sil_ax.imshow(a_tile,cmap='bwr',norm=norm2)
sil_ax.xaxis.set_ticks([])
sil_ax.yaxis.set_ticks([])
sil_ax.set_ylabel('Silhouette',fontsize='x-small',rotation=0,ha='right',va='center')
sil_ax.grid(False)

# dendrogram, with its leaves moved onto the centres of the cluster runs
dendro_ax = fig.add_subplot(axs[0, 0], sharex=heatmap_ax)
dendro_info = adata.uns[f'dendrogram_{clust}']['dendrogram_info']
leaves = dendro_info["ivl"]
icoord = np.array(dendro_info['icoord'])
dcoord = np.array(dendro_info['dcoord'])
# leaf positions used by the stored coordinates: 5, 15, 25, ...
orig_ticks = np.arange(5, len(leaves) * 10 + 5, 10).astype(float)
for xs, ys in zip(icoord, dcoord):
    xs = translate_pos(xs, ticks, orig_ticks)
    dendro_ax.plot(xs, ys, color='#555555')
dendro_ax.tick_params(bottom=False, top=False, left=False, right=False)
dendro_ax.set_xticks(ticks)
dendro_ax.set_xticklabels([])
dendro_ax.tick_params(labelleft=False, labelright=False)
dendro_ax.grid(False)
dendro_ax.spines['right'].set_visible(False)
dendro_ax.spines['top'].set_visible(False)
dendro_ax.spines['left'].set_visible(False)
dendro_ax.spines['bottom'].set_visible(False)

# expression colorbar
cbar_ax = fig.add_subplot(axs[1, 2])
mappable = mpl.cm.ScalarMappable(norm=norm, cmap='viridis')
cbar = plt.colorbar(mappable=mappable, cax=cbar_ax)
cbar_ax.tick_params(axis='both', which='major', labelsize='xx-small',rotation=90,length=.1)
cbar_ax.yaxis.set_major_locator(mpl.ticker.FixedLocator(locs=[0,1]))
cbar.set_label('Expression', fontsize='xx-small',labelpad=-5)


# silhouette colorbar
cbar_ax = fig.add_subplot(axs[1, 3])
mappable = mpl.cm.ScalarMappable(norm=norm2, cmap='bwr')
cbar = plt.colorbar(mappable=mappable, cax=cbar_ax)
cbar_ax.tick_params(axis='both', which='major', labelsize='xx-small',rotation=90,length=.1)
cbar_ax.yaxis.set_major_locator(mpl.ticker.FixedLocator(locs=[-1,0,1]))
cbar.set_label('Silhouette Score', fontsize='xx-small',labelpad=0)


# axes of interest collected by name
return_ax_dict = {'heatmap_ax': heatmap_ax}
return_ax_dict['groupby_ax'] = groupby_ax
return_ax_dict['dendrogram_ax'] = dendro_ax