In [183]:
from ete3 import Tree, TreeStyle
from sklearn.datasets import make_blobs
import pandas as pd
import hdbscan
from scipy.cluster import hierarchy
from itertools import chain
import collections as co
import seaborn as sns
import re
import numpy as np
from matplotlib import colors 
import random

In [184]:
def mean(array):
    """Return the arithmetic mean of a sized, non-empty sequence as a float.

    Raises ZeroDivisionError when ``array`` is empty.
    """
    total = sum(array)
    count = float(len(array))
    return total / count

In [185]:
def cache_distances(tree):
    """Map every node of ``tree`` to its cumulative branch length from the root.

    The root maps to 0. Any other node maps to its own ``dist`` plus the value
    already recorded for its parent (``node.up``). Visiting in preorder
    guarantees that the parent entry exists before the child is reached.
    """
    depth = dict()
    depth[tree] = 0
    for child in tree.iter_descendants('preorder'):
        depth[child] = child.dist + depth[child.up]
    return depth

In [186]:
def collapse(tree, min_dist):
    """Mark shallow internal clades of ``tree`` as collapsed clusters (in place).

    Every internal, non-root node is tested: if the mean distance from the node
    down to its leaves is below ``min_dist``, the node is

    * renamed ``'Cluster <c>'``, where ``<c>`` is the ``cluster`` feature of its
      first leaf in preorder,
    * given the feature ``collapsed=True``, and
    * styled with ``img_style['draw_descendants'] = False`` so that
      ``tree.show()`` draws it as a single tip.

    The root itself is never considered (only ``get_descendants`` are visited).
    Qualifying nodes nested inside an already-collapsed clade are marked as well.

    Parameters
    ----------
    tree : ete3.Tree
        Tree whose leaves carry a ``cluster`` feature.
    min_dist : float
        Mean node-to-leaf distance below which a clade is collapsed.

    Returns
    -------
    int
        Number of nodes marked as collapsed.
    """
    # One traversal each: leaf sets per node, and distance of every node to the root.
    node2tips = tree.get_cached_content()
    root_distance = cache_distances(tree)

    n_collapsed = 0
    for node in tree.get_descendants('preorder'):
        if node.is_leaf():
            continue

        avg_distance_to_tips = mean([root_distance[tip] - root_distance[node]
                                     for tip in node2tips[node]])
        if avg_distance_to_tips < min_dist:
            # Only the first leaf's cluster label is used for the name, so stop
            # after one leaf instead of materialising the whole leaf list.
            first_leaf = next(node.iter_leaves())
            node.name = 'Cluster ' + str(first_leaf.cluster)
            node.add_features(collapsed=True)

            # Makes the clade render collapsed with tree.show()/tree.render().
            node.img_style['draw_descendants'] = False
            n_collapsed += 1

    return n_collapsed

In [187]:
def getNewick(node, newick, parentdist, leaf_names, precision=2):
    """Serialise a scipy ``ClusterNode`` (sub)tree into a Newick string.

    The string is built right-to-left: ``newick`` is the already-serialised text
    that must follow this node, and the return value is that text with this
    node's subtree prepended. For an internal node, the right child is emitted
    before the left child. Call it on the root as
    ``getNewick(root, "", root.dist, leaf_names)``.

    Parameters
    ----------
    node : scipy.cluster.hierarchy.ClusterNode
        Subtree root to serialise.
    newick : str
        Text that follows this node. An empty string marks ``node`` as the tree
        root, which terminates the output with ``");"`` and writes no branch
        length for it.
    parentdist : float
        ``dist`` of the parent cluster; the branch length written for ``node``
        is ``parentdist - node.dist``.
    leaf_names : sequence
        Indexed by ``ClusterNode.id`` to obtain each leaf's label.
    precision : int, optional
        Decimal places used for branch lengths. The default of 2 reproduces the
        original output; raise it when downstream thresholds are close to 0.01.

    Returns
    -------
    str
        Newick text for ``node``'s subtree followed by ``newick``.

    Notes
    -----
    Labels are written verbatim (unquoted), so names containing Newick
    metacharacters such as ``( ) , : ;`` or whitespace yield an unparsable string.
    Recursion depth equals tree height, so very unbalanced dendrograms may
    exceed Python's recursion limit.
    """
    branch = "%.*f" % (precision, parentdist - node.dist)
    if node.is_leaf():
        return "%s:%s%s" % (leaf_names[node.id], branch, newick)

    if newick:
        # Close this clade and attach its branch length before the suffix.
        newick = "):%s%s" % (branch, newick)
    else:
        # Root of the whole tree: no branch length, just terminate.
        newick = ");"
    newick = getNewick(node.get_left(), newick, node.dist, leaf_names, precision)
    newick = getNewick(node.get_right(), "," + newick, node.dist, leaf_names, precision)
    return "(" + newick

In [190]:
# --- configuration ---------------------------------------------------------
protein = 'H'          # subtype column used to colour clusters: 'H' or 'N'
SEGMENT = 4            # segment whose dendrogram is drawn
COLLAPSE_DIST = 0.06   # collapse clades whose mean node-to-leaf distance is below this

# One palette colour per subtype. Assumes labels look like 'H1'..'H18' /
# 'N1'..'N11' (1-based), matching the palette sizes - TODO confirm against data.
N_SUBTYPES = {'H': 18, 'N': 11}
if protein not in N_SUBTYPES:
    raise ValueError("protein must be one of %s, got %r" % (sorted(N_SUBTYPES), protein))

cluster = pd.read_csv('output/cluster.csv', sep=',', na_filter=False, header=0, index_col=0)
linkage = pd.read_csv('output/linkage.csv', sep=',', na_filter=False, header=0)

tree_col = sns.color_palette('husl', n_colors=N_SUBTYPES[protein])
tree_hex = [colors.to_hex(col) for col in tree_col]

cl = cluster.query('segment == @SEGMENT')[['cluster', 'H', 'N']].copy()
cl['color'] = ''

tree_label = cl.index
tree_Z = linkage.query('segment == @SEGMENT').drop(columns=['segment', 'parent'])
link = tree_Z.to_numpy()

# Colour a cluster only when all of its labelled members share a single subtype
# of the selected protein ('' means "no label" because na_filter=False).
for i in range(0, cl.cluster.max() + 1):
    in_cluster = cl.cluster == i
    members = cl.loc[in_cluster, protein]
    prot = set(members[members != ''])

    if len(prot) == 1:
        label = next(iter(prot))
        subtype = int(label[1:])               # e.g. 'H3' -> 3
        if not 1 <= subtype <= len(tree_hex):
            raise ValueError("unexpected %s subtype label %r" % (protein, label))
        # Subtypes are numbered from 1, palette indices from 0.
        cl.loc[in_cluster, 'color'] = tree_hex[subtype - 1]

cl_dict = cl[['cluster', 'color']].to_dict()

tree = hierarchy.to_tree(link, False)
y = getNewick(tree, "", tree.dist, tree_label)
t = Tree(y)

ts = TreeStyle()
ts.show_leaf_name = True
ts.show_branch_length = True
ts.show_branch_support = True
ts.mode = "c"   # circular layout

for leaf in t:
    leaf.add_features(color=cl_dict['color'].get(leaf.name, "none"))
    leaf.add_features(cluster=cl_dict['cluster'].get(leaf.name, "none"))

# Paint every maximal clade whose leaves share one colour. '' marks clusters
# that received no colour; leave those at ete3's default style.
for col in set(cl_dict['color'].values()):
    if not col:
        continue
    for node in t.get_monophyletic(values=[col], target_attr="color"):
        for n in node.traverse():
            n.img_style['hz_line_color'] = col
            n.img_style['vt_line_color'] = col
            n.img_style['fgcolor'] = col

collapse(t, COLLAPSE_DIST)

# Drop the children of collapsed nodes so they are drawn as single tips.
for n in t.search_nodes(collapsed=True):
    for ch in n.get_children():
        ch.detach()

# For a non-interactive run use: t.render('test.pdf', tree_style=ts)
t.show(tree_style=ts)