# In [3]:
import os
import pandas as pd
import numpy as np
import scipy
from scipy.spatial import distance
from scipy.cluster import hierarchy
# from emtdecode.utility import print_df, log2_transform, center_value

import matplotlib.pyplot as plt
import seaborn as sns
# Apply seaborn's default plotting theme at import time, before any figure is drawn.
sns.set()

# In [4]:
def center_value(df):
    """
    Center every gene (row) on its own mean: exp - mean(exp).

    :param df: expression dataframe, gene x sample
    :return: dataframe with the same index and columns as ``df`` in which the
        mean of each row has been subtracted from that row
    """
    # Row-wise subtraction aligned on the index. This stays vectorised and,
    # unlike stacking the means with np.vstack, also works when df has no rows.
    return df.sub(df.mean(axis=1), axis=0)

# In [5]:
def plot_hcluster(exp_df, sample2subtype, method='sample_corr', color_threshold=4.0, result_dir='', sample_type=''):
    """
    Plot a clustered heatmap and a dendrogram, then summarise the sample clusters.

    The expression values are centered per gene with ``center_value`` before
    clustering. Two PNG files (heatmap, dendrogram) and two CSV files
    (``subtype_pivot.csv``, ``reordered_ind2sample_name.csv``) are written to
    ``result_dir``.

    :param exp_df: expression dataframe, gene x samples
        - can be log2 normalized beforehand for calculating sample correlation
        - or clustered by gene expression directly
    :param sample2subtype: dataframe
        sample annotation indexed by sample name, must contain 'subtype' column
    :param method: str
        'sample_corr' clusters the sample-by-sample correlation matrix; any
        other value clusters genes (rows) and samples (columns) of the
        centered expression matrix directly
    :param color_threshold: distance threshold used both to colour the sample
        dendrogram and to cut it into flat clusters
    :param result_dir: output directory; created if missing, '' means the
        current working directory
    :param sample_type: only for naming result file
    :return: dataframe with one row per sample in heatmap column order, holding
        the sample name, its cluster id, its annotation, its colour and the
        per-cluster summary columns ('predicted_subtype', 'keep' and, when a
        'New' subtype exists, 'main_subtype_with_new%')
    """
    subtype = np.unique(list(sample2subtype['subtype']))
    base_colors = 'rbgyk'
    if len(subtype) <= len(base_colors):
        color = base_colors[:len(subtype)]
    else:
        # More subtypes than single-letter colours: fall back to a generated
        # palette so that no subtype is silently left without a colour.
        color = sns.color_palette('husl', len(subtype))
    lut = dict(zip(subtype, color))
    print('subtype2color: {}'.format(lut))
    row_colors = sample2subtype['subtype'].map(lut)
    exp_df = center_value(exp_df)

    if method == 'sample_corr':
        sample_corr = exp_df.corr()
        row_linkage = hierarchy.linkage(
            distance.pdist(sample_corr), method='centroid')
        col_linkage = hierarchy.linkage(
            distance.pdist(sample_corr.T), method='centroid')
        g1 = sns.clustermap(sample_corr,
                            row_linkage=row_linkage, col_linkage=col_linkage,
                            row_colors=row_colors, col_colors=row_colors,
                            method="centroid", figsize=(15, 15), cmap='vlag')
    else:
        # linkage() ignores ``metric`` when it receives a condensed distance
        # vector, and the centroid method is only defined for Euclidean
        # distances, so the (default, Euclidean) pdist output is passed as is.
        # distance between gene pairs
        row_linkage = hierarchy.linkage(distance.pdist(exp_df), method='centroid')
        # distance between sample pairs
        col_linkage = hierarchy.linkage(distance.pdist(exp_df.T), method='centroid')
        g1 = sns.clustermap(exp_df, row_linkage=row_linkage, col_linkage=col_linkage,
                            col_colors=row_colors, z_score=0,
                            method="centroid", figsize=(15, 15), cmap='vlag')

    # os.makedirs('') raises, so only create the directory when one was given.
    if result_dir:
        os.makedirs(result_dir, exist_ok=True)
    g1.savefig(os.path.join(result_dir, '{}_hierarchy_clustering.png'.format(sample_type)), dpi=200)

    # The clusters reported below are clusters of samples, so both the
    # dendrogram and the cut use the sample (column) linkage. In 'sample_corr'
    # mode the correlation matrix is symmetric and this equals the row linkage.
    sample_linkage = col_linkage
    plt.figure(figsize=(15, 5))
    hierarchy.dendrogram(sample_linkage, show_leaf_counts=True, p=8, truncate_mode='level', leaf_font_size=6,
                         leaf_rotation=50, color_threshold=color_threshold)
    plt.savefig(os.path.join(result_dir, '{}_dendrogram.png'.format(sample_type)), dpi=200)

    # get sub cluster info: reordered samples in each cluster and cluster id
    node_id2sub_cluster = cut_cluster_by_distance(sample_linkage, threshold=color_threshold)
    leaf_order = list(g1.dendrogram_col.reordered_ind)
    ordered_samples = pd.DataFrame({'sample_name': [exp_df.columns[i] for i in leaf_order]},
                                   index=leaf_order)
    ordered_samples['cluster_id'] = ordered_samples.index.map(node_id2sub_cluster)
    annotated = ordered_samples.merge(sample2subtype, left_on='sample_name', right_index=True)
    annotated.index.name = 'reordered_inx'
    # NOTE(review): the original called a `print_df` helper that is neither
    # defined nor imported in this file; a plain preview is printed instead.
    print(annotated.head())
    annotated['color'] = annotated['subtype'].map(lut)

    # cluster x subtype table of sample counts
    subtype_pivot = annotated.pivot_table(index=['cluster_id'], columns='subtype',
                                          aggfunc='size', fill_value=0)
    has_new = 'New' in subtype_pivot.columns
    non_new = subtype_pivot.loc[:, [i for i in subtype_pivot.columns if i != 'New']]
    # Take totals and maxima now, while the table still holds only counts.
    cluster_size = subtype_pivot.sum(axis=1)
    main_count = non_new.max(axis=1)
    # Label each cluster with its most frequent known subtype. The label is
    # looked up in non_new's own columns; reversing the column order makes
    # idxmax pick the last column on ties, as the original lookup did.
    subtype_pivot['predicted_subtype'] = non_new.iloc[:, ::-1].idxmax(axis=1)
    if has_new:
        subtype_pivot['main_subtype_with_new%'] = (main_count + subtype_pivot['New']) / cluster_size
        keep = (main_count >= 5) & (subtype_pivot['main_subtype_with_new%'] >= 0.8)
    else:
        keep = main_count >= 5
    subtype_pivot['keep'] = keep.astype(int)

    summary_columns = ['predicted_subtype']
    if has_new:
        summary_columns.append('main_subtype_with_new%')
    summary_columns.append('keep')
    annotated = annotated.merge(subtype_pivot[summary_columns], left_on='cluster_id', right_index=True)

    subtype_pivot.to_csv(os.path.join(result_dir, 'subtype_pivot.csv'))
    annotated.to_csv(os.path.join(result_dir, 'reordered_ind2sample_name.csv'),
                     index=False, float_format='%.3f')
    return annotated


def split_tree_by_distance(node, threshold):
    """
    Cut a cluster tree into the sub-trees whose root distance is within threshold.

    Starting at ``node``, a node is kept whole when it is a leaf
    (``count < 2``) or when its ``dist`` is at most ``threshold``; otherwise it
    is replaced by the result of cutting its left and right children.

    :param node: scipy.cluster.hierarchy.ClusterNode, root of the tree to cut
    :param threshold: float, largest ``dist`` a returned non-leaf node may have
    :return: list of ClusterNode, the roots of the resulting sub-trees in
        left-to-right order
    """
    assert isinstance(node, hierarchy.ClusterNode)
    # Explicit stack instead of recursion: a chain-like dendrogram can be as
    # deep as the number of leaves, which would exceed the recursion limit.
    sub_roots = []
    pending = [node]
    while pending:
        current = pending.pop()
        if current.count < 2 or current.dist <= threshold:
            sub_roots.append(current)
        else:
            # Push right first so the left child is processed first.
            pending.append(current.right)
            pending.append(current.left)
    return sub_roots


def cut_cluster_by_distance(linkage_matrix, threshold):
    """
    Map every leaf of a hierarchical clustering to a flat sub-cluster number.

    The linkage matrix is converted to a tree, the tree is cut with
    ``split_tree_by_distance`` and the resulting sub-trees are numbered
    0, 1, 2, ... in the order they are returned.

    linkage_matrix: the matrix comes from hierarchy.linkage,
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.linkage.html#scipy.cluster.hierarchy.linkage
    threshold: a float, max distance of clusters in the tree at the cut point
    return: dict, leaf node id -> number of the sub-cluster containing it
    """
    tree_root = hierarchy.to_tree(linkage_matrix)
    sub_roots = split_tree_by_distance(tree_root, threshold=threshold)
    return {
        leaf_id: cluster_no
        for cluster_no, sub_root in enumerate(sub_roots)
        for leaf_id in get_node_id(sub_root)
    }


def get_node_id(node):
    """
    Collect the ids of all leaves below a cluster node.

    :param node: scipy.cluster.hierarchy.ClusterNode, root of the (sub-)tree
    :return: list of leaf ids (nodes with ``count == 1``) in left-to-right order
    """
    assert isinstance(node, hierarchy.ClusterNode)
    # Explicit stack instead of recursion: avoids the recursion limit on deep
    # trees and the repeated list concatenation of the recursive form.
    leaf_ids = []
    pending = [node]
    while pending:
        current = pending.pop()
        if current.count == 1:
            leaf_ids.append(current.id)
        else:
            # Push right first so the left subtree is visited first.
            pending.append(current.right)
            pending.append(current.left)
    return leaf_ids