In [None]:
from Bio import Phylo
from plotly import graph_objects as go

def get_x_coordinates(tree):
    """Map every clade of ``tree`` to its horizontal (x) plot position.

    The x position of a clade is its distance from the root, as reported by
    ``tree.depths()`` (a dict {clade: depth}). When the largest of those
    depths is zero -- i.e. the branch lengths contribute nothing -- the
    depths are recomputed with ``unit_branch_lengths=True`` so that every
    branch counts as length 1.

    returns dict {clade: x-coord}
    """
    depth_by_clade = tree.depths()
    if max(depth_by_clade.values()):
        return depth_by_clade
    # No usable branch lengths: fall back to one unit per branch.
    return tree.depths(unit_branch_lengths=True)


def get_y_coordinates(tree, dist=1.3):
    """Map every clade of ``tree`` to its vertical (y) plot position.

    Leaves occupy evenly spaced rows: walking the terminals in reverse order,
    the i-th one is placed at ``n_leaves - i * dist`` (so the last terminal
    is at ``n_leaves``). Each internal clade is then placed halfway between
    its first and last child.

    parameters:
    tree: tree object providing ``count_terminals()``, ``get_terminals()``
        and ``root``.
    dist (float): vertical spacing between consecutive leaf rows.

    returns dict {clade: y-coord}
    """
    n_leaves = tree.count_terminals()
    leaves_bottom_up = list(reversed(tree.get_terminals()))
    positions = {}
    for row, leaf in enumerate(leaves_bottom_up):
        positions[leaf] = n_leaves - row * dist

    def _place(node):
        # Post-order: every child must have a position before its parent.
        for child in node:
            if child not in positions:
                _place(child)
        first_child, last_child = node.clades[0], node.clades[-1]
        positions[node] = (positions[first_child] + positions[last_child]) / 2

    if tree.root.clades:
        _place(tree.root)
    return positions


def get_clade_lines(orientation='horizontal', y_curr=0, x_start=0, x_curr=0, y_bot=0, y_top=0,
                    line_color='rgb(25,25,25)', line_width=0.5):
    """Return a plotly layout shape (type 'line') for one branch segment.

    'horizontal': from (x_start, y_curr) to (x_curr, y_curr).
    'vertical':   from (x_curr, y_bot) to (x_curr, y_top).

    The shape is drawn on the 'below' layer with the given color and width.
    Raises ValueError for any other orientation.
    """
    if orientation == 'horizontal':
        x0, y0, x1, y1 = x_start, y_curr, x_curr, y_curr
    elif orientation == 'vertical':
        x0, y0, x1, y1 = x_curr, y_bot, x_curr, y_top
    else:
        raise ValueError("Line type can be 'horizontal' or 'vertical'")

    return {
        'type': 'line',
        'layer': 'below',
        'line': {'color': line_color, 'width': line_width},
        'x0': x0,
        'y0': y0,
        'x1': x1,
        'y1': y1,
    }


def draw_clade(clade, x_start, line_shapes, line_color='rgb(15,15,15)', line_width=1, x_coords=0, y_coords=0):
    """Recursively append the branch shapes for ``clade`` and all its descendants.

    parameters:
    clade: clade to draw; its children are visited recursively.
    x_start (float): x position of the parent, where this clade's horizontal
        branch starts.
    line_shapes (list): output list; shape dicts from get_clade_lines are
        appended to it in place.
    line_color (str): branch color, applied to the whole subtree.
    line_width (float): branch width, applied to the whole subtree.
    x_coords, y_coords (dict): {clade: coordinate} lookups covering every clade.

    returns None (results are accumulated in ``line_shapes``).
    """
    x_curr = x_coords[clade]
    y_curr = y_coords[clade]

    # Horizontal branch from the parent's x position to this clade.
    line_shapes.append(get_clade_lines(orientation='horizontal', y_curr=y_curr, x_start=x_start,
                                       x_curr=x_curr, line_color=line_color, line_width=line_width))

    if clade.clades:
        # Vertical connector spanning from the first child to the last child.
        y_top = y_coords[clade.clades[0]]
        y_bot = y_coords[clade.clades[-1]]
        line_shapes.append(get_clade_lines(orientation='vertical', x_curr=x_curr, y_bot=y_bot, y_top=y_top,
                                           line_color=line_color, line_width=line_width))

        # Pass the styling down so the whole subtree is drawn consistently;
        # previously children silently fell back to the default color/width.
        for child in clade:
            draw_clade(child, x_curr, line_shapes, line_color=line_color, line_width=line_width,
                       x_coords=x_coords, y_coords=y_coords)


def read_treefile(filename):
    """Load a single tree stored in Newick format.

    parameters:
    filename: passed unchanged to ``Phylo.read`` as its file argument.

    return: the tree object produced by ``Phylo.read``.
    """
    return Phylo.read(filename, format="newick")


def create_plotly_tree_old(tree, title=None, t_nodes_color_dict=None, in_node_size=2, t_node_size=10,
                           in_node_color='rgb(100,100,100)', i_node_color_dict=None, height=550):
    """Build a plotly figure specification from a Newick tree (single branch color).

    Branches are drawn as layout shapes with ``draw_clade``. Nodes are drawn as
    one scatter trace of markers, with the node name as hover text.

    parameters:
    tree (str or Phylo.Newick.Tree): path to a Newick file, or an already parsed tree.
    title (str): title of the plotly graph.
    t_nodes_color_dict (dict): marker colors for named nodes, in the form
        {'rgb()': ['leafname1', 'leafname2', ...]}. If a name is listed under
        several colors, the last one wins. When not a dict, every terminal is
        colored 'rgb(100,100,100)'.
    in_node_size (int): marker size for nodes whose name is not a terminal name.
    t_node_size (int): marker size for nodes whose name is a terminal name.
    in_node_color (str): marker color for nodes not listed in t_nodes_color_dict.
    i_node_color_dict: unused here; kept for signature parity with create_plotly_tree.
    height (int): figure height in pixels.

    Markers are drawn for every clade when the first non-terminal (the root)
    has a name; otherwise they are drawn only for the terminals.

    return (dict): {'data': [scatter], 'layout': layout}; pass it to
        ``plotly.graph_objects.Figure`` to obtain a figure object.
    """
    if not isinstance(tree, Phylo.Newick.Tree):
        tree = Phylo.read(tree, "newick")
    x_coords = get_x_coordinates(tree)
    y_coords = get_y_coordinates(tree)
    line_shapes = []
    draw_clade(tree.root, 0, line_shapes, line_color='rgb(25,25,25)', line_width=1,
               x_coords=x_coords, y_coords=y_coords)

    terminal_names = [clade.name for clade in tree.get_terminals()]
    if not isinstance(t_nodes_color_dict, dict):
        t_nodes_color_dict = {'rgb(100,100,100)': terminal_names}
    terminal_name_set = set(terminal_names)

    # Flatten {color: [names]} into {name: color}; later groups override earlier ones.
    color_of_name = {}
    for color, names in t_nodes_color_dict.items():
        for name in names:
            color_of_name[name] = color

    # Guard against trees without any non-terminal (single-node tree).
    nonterminals = tree.get_nonterminals()
    if nonterminals and nonterminals[0].name:
        plotted_clades = list(tree.depths())
    else:
        plotted_clades = tree.get_terminals()

    # Build all per-node attributes in lockstep so coordinates, hover text,
    # colors and sizes always refer to the same clade, including unnamed or
    # duplicate-named clades.
    xs, ys, text, node_colors, node_sizes = [], [], [], [], []
    for clade in plotted_clades:
        xs.append(x_coords[clade])
        ys.append(y_coords[clade])
        text.append(clade.name or '')
        node_colors.append(color_of_name.get(clade.name, in_node_color))
        node_sizes.append(t_node_size if clade.name in terminal_name_set else in_node_size)

    axis = dict(showline=False,
                zeroline=False,
                showgrid=False,
                showticklabels=False,
                title=''  # y title
                )

    data = dict(type='scatter',
                x=xs,
                y=ys,
                mode='markers',
                marker=dict(color=node_colors,
                            size=node_sizes
                            ),
                text=text,  # hover label of each node
                hoverinfo='text',
                )
    layout = dict(title=title,
                  paper_bgcolor='rgb(248,248,248)',
                  dragmode="lasso",
                  font=dict(family='Balto', size=14),
                  height=height,
                  autosize=True,
                  showlegend=False,
                  xaxis=dict(showline=False,
                             zeroline=False,
                             showgrid=False,  # To visualize the vertical lines
                             ticklen=4,
                             showticklabels=False,
                             title=''),
                  yaxis=axis,
                  hovermode='closest',
                  shapes=line_shapes,
                  plot_bgcolor='rgb(248,248,248)',
                  legend={'x': 0, 'y': 1},
                  margin={'b': 0, 'l': 0, 'r': 0, 't': 0}
                  )
    fig = dict(data=[data], layout=layout)
    return fig




def create_plotly_tree(tree, title=None, t_nodes_color_dict=None, in_node_size=2, t_node_size=10,
                       in_node_color='rgb(100,100,100)', i_node_color_dict=None, height=550):
    """Build a plotly figure specification from a Newick tree.

    Branches are drawn as layout shapes with ``draw_clade_t``, which lets
    selected clades (and their subtrees) use their own branch color. Nodes are
    drawn as one scatter trace of markers, with the node name as hover text.

    parameters:
    tree (str or Phylo.Newick.Tree): path to a Newick file, or an already parsed tree.
    title (str): title of the plotly graph.
    t_nodes_color_dict (dict): marker colors for named nodes, in the form
        {'rgb()': ['leafname1', 'leafname2', ...]}. If a name is listed under
        several colors, the last one wins. When not a dict, every terminal is
        colored 'rgb(100,100,100)'.
    in_node_size (int): marker size for nodes whose name is not a terminal name.
    t_node_size (int): marker size for nodes whose name is a terminal name.
    in_node_color (str): marker color for nodes not listed in t_nodes_color_dict.
    i_node_color_dict (dict): branch colors, {'rgb()': ['cladename1', ...]}; a
        listed clade's branches and its descendants' branches use that color
        unless a descendant is listed itself.
    height (int): figure height in pixels.

    Markers are drawn for every clade when the first non-terminal (the root)
    has a name; otherwise they are drawn only for the terminals.

    return (dict): {'data': [scatter], 'layout': layout}; pass it to
        ``plotly.graph_objects.Figure`` to obtain a figure object.
    """
    if not isinstance(tree, Phylo.Newick.Tree):
        tree = Phylo.read(tree, "newick")
    x_coords = get_x_coordinates(tree)
    y_coords = get_y_coordinates(tree)
    line_shapes = []
    draw_clade_t(tree.root, 0, line_shapes, line_color='rgb(25,25,25)', line_width=1,
                 x_coords=x_coords, y_coords=y_coords, i_node_color_dict=i_node_color_dict)

    terminal_names = [clade.name for clade in tree.get_terminals()]
    if not isinstance(t_nodes_color_dict, dict):
        t_nodes_color_dict = {'rgb(100,100,100)': terminal_names}
    terminal_name_set = set(terminal_names)

    # Flatten {color: [names]} into {name: color}; later groups override earlier ones.
    color_of_name = {}
    for color, names in t_nodes_color_dict.items():
        for name in names:
            color_of_name[name] = color

    # Guard against trees without any non-terminal (single-node tree).
    nonterminals = tree.get_nonterminals()
    if nonterminals and nonterminals[0].name:
        plotted_clades = list(tree.depths())
    else:
        plotted_clades = tree.get_terminals()

    # Build all per-node attributes in lockstep so coordinates, hover text,
    # colors and sizes always refer to the same clade, including unnamed or
    # duplicate-named clades.
    xs, ys, text, node_colors, node_sizes = [], [], [], [], []
    for clade in plotted_clades:
        xs.append(x_coords[clade])
        ys.append(y_coords[clade])
        text.append(clade.name or '')
        node_colors.append(color_of_name.get(clade.name, in_node_color))
        node_sizes.append(t_node_size if clade.name in terminal_name_set else in_node_size)

    axis = dict(showline=False,
                zeroline=False,
                showgrid=False,
                showticklabels=False,
                title=''  # y title
                )

    data = dict(type='scatter',
                x=xs,
                y=ys,
                mode='markers',
                marker=dict(color=node_colors,
                            size=node_sizes
                            ),
                text=text,  # hover label of each node
                hoverinfo='text',
                )
    layout = dict(title=title,
                  paper_bgcolor='rgb(248,248,248)',
                  dragmode="lasso",
                  font=dict(family='Balto', size=14),
                  height=height,
                  autosize=True,
                  showlegend=False,
                  xaxis=dict(showline=False,
                             zeroline=False,
                             showgrid=False,  # To visualize the vertical lines
                             ticklen=4,
                             showticklabels=False,
                             title=''),
                  yaxis=axis,
                  hovermode='closest',
                  shapes=line_shapes,
                  plot_bgcolor='rgb(248,248,248)',
                  legend={'x': 0, 'y': 1},
                  margin={'b': 0, 'l': 0, 'r': 0, 't': 0}
                  )
    fig = dict(data=[data], layout=layout)
    return fig




def draw_clade_t(clade, x_start, line_shapes, line_color='rgb(15,15,15)',
                 line_width=1, x_coords=0, y_coords=0, i_node_color_dict=None):
    """Recursively append branch shapes for ``clade`` and its descendants.

    A clade whose name is listed in ``i_node_color_dict`` switches to that
    color. Its descendants inherit the color unless they are listed themselves.

    parameters:
    clade: clade to draw; its children are visited recursively.
    x_start (float): x position of the parent, where this clade's horizontal
        branch starts.
    line_shapes (list): output list; shape dicts from get_clade_lines are
        appended to it in place.
    line_color (str): color inherited from the parent.
    line_width (float): branch width, applied to the whole subtree.
    x_coords, y_coords (dict): {clade: coordinate} lookups covering every clade.
    i_node_color_dict (dict): {'rgb()': ['cladename1', ...]}. Names are
        compared as str(clade.name), so an unnamed clade is looked up as 'None'.
        If a name is listed under several colors, the last one wins.

    returns None (results are accumulated in ``line_shapes``).
    """
    x_curr = x_coords[clade]
    y_curr = y_coords[clade]
    if i_node_color_dict:
        for color, names in i_node_color_dict.items():
            if str(clade.name) in names:
                line_color = color

    # Horizontal branch from the parent's x position to this clade.
    line_shapes.append(get_clade_lines(orientation='horizontal', y_curr=y_curr, x_start=x_start,
                                       x_curr=x_curr, line_color=line_color, line_width=line_width))

    if clade.clades:
        # Vertical connector spanning from the first child to the last child.
        y_top = y_coords[clade.clades[0]]
        y_bot = y_coords[clade.clades[-1]]
        line_shapes.append(get_clade_lines(orientation='vertical', x_curr=x_curr, y_bot=y_bot, y_top=y_top,
                                           line_color=line_color, line_width=line_width))

        # Children inherit both the (possibly overridden) color and the width;
        # previously the width silently fell back to the default.
        for child in clade:
            draw_clade_t(child, x_curr, line_shapes, line_color=line_color, line_width=line_width,
                         x_coords=x_coords, y_coords=y_coords, i_node_color_dict=i_node_color_dict)

In [1]:
def create_tree_w_bargraphs(tree_obj, HOGs, orthoparser_obj, colors, mode='markers'):
    """Draw a tree and add one horizontal copy-number bar graph per HOG to its right.

    Each HOG's bars start at the right edge of the previous graph (initially
    the rightmost tree node).

    parameters:
    tree_obj: tree (or Newick path) forwarded to create_plotly_tree and used
        to position the bars.
    HOGs (str or list): one HOG identifier or a list of them; anything that is
        not a list is wrapped into a one-element list.
    orthoparser_obj: forwarded to create_plotly_tree and _make_HOG_bargraph.
    colors (dict): its values are used, in order, as bar colors (one per HOG,
        reused cyclically when there are more HOGs than colors).
    mode (str): forwarded to create_plotly_tree.

    return (plotly.graph_objects.Figure): the tree figure with the bar traces
        added (returned unchanged when HOGs is empty).

    raises ValueError: if HOGs is non-empty but colors has no values.
    """
    # NOTE(review): the create_plotly_tree defined in this notebook has no
    # orthoparser_obj/colors/mode parameters, so this call raises TypeError against
    # it; also confirm what _make_HOG_bargraph expects as its last argument.
    fig = create_plotly_tree(tree_obj, orthoparser_obj=orthoparser_obj, colors=colors, mode=mode)
    fig_sp_tree = go.Figure(fig)
    shift = 0.01
    bar_thickness = 0.95
    scale = 0.02
    # The first bar graph starts right of the rightmost node of the tree trace.
    new_max_x = max(fig_sp_tree['data'][0]['x'])
    if not isinstance(HOGs, list):
        HOGs = [HOGs]
    palette = list(colors.values())
    if HOGs and not palette:
        raise ValueError("colors must contain at least one color")
    for i, HOG in enumerate(HOGs):
        color = palette[i % len(palette)]
        fig_sp_tree, new_max_x = _make_HOG_bargraph(fig_sp_tree, HOG,
                                                    bar_thickness, shift, scale,
                                                    new_max_x, tree_obj, color, orthoparser_obj)
    return fig_sp_tree





def _make_HOG_bargraph(fig_sp_tree, HOG, bar_thickness, shift,
                       scale, max_x, tree_obj, color, orthoparser_obj=None, count_dict=None):
    """Add one HOG's bars, plus a vertical baseline, to ``fig_sp_tree``.

    parameters:
    fig_sp_tree (plotly.graph_objects.Figure): figure the traces are added to
        (mutated in place and also returned).
    HOG: identifier shown in the hover text and passed to make_baseline_trace.
    bar_thickness, shift, scale: bar geometry forwarded to get_bar_traces.
    max_x (float): x position the bars start from.
    tree_obj: tree used to look up each leaf's y position.
    color (str): bar color.
    orthoparser_obj: accepted because create_tree_w_bargraphs passes it, but
        currently unused. NOTE(review): presumably the per-leaf counts should be
        derived from it; confirm and wire it up.
    count_dict (dict): {leaf_name: copy_count}. The original body read
        ``count_dict`` as an undefined free (global) name; when this argument is
        omitted, that global is still used for backward compatibility.

    return (tuple): (fig_sp_tree, new_max_x), where new_max_x is the largest x
        reached by any bar, or max_x if no bar reaches further.

    raises ValueError: if no count_dict is given and no global one exists.
    """
    if count_dict is None:
        try:
            count_dict = globals()['count_dict']
        except KeyError:
            raise ValueError("_make_HOG_bargraph needs a count_dict argument") from None
    x_y_coords = get_bar_coords_dict(count_dict, tree_obj)
    new_max_x = max_x
    for acc, x_y in x_y_coords.items():
        copies = int(x_y[0])
        text = f'{acc}<br>{HOG} copies: {copies}'
        bar_trace = get_bar_traces(max_x, copies, x_y[1], bar_thickness=bar_thickness,
                                   shift=shift, scale=scale, text=text,
                                   color=color)
        fig_sp_tree.add_trace(bar_trace)
        # Track the right-most bar end so the next graph can start after it.
        new_max_x = max(new_max_x, max(bar_trace['x']))
    fig_sp_tree.add_trace(make_baseline_trace(shift, HOG, x_y_coords, max_x, bar_thickness))
    return fig_sp_tree, new_max_x


def get_bar_coords_dict(count_dict, tree_obj):
    """Pair each counted clade name with its bar length and y position.

    parameters:
    count_dict (dict): {clade_name: count}.
    tree_obj: tree whose clades are positioned with get_y_coordinates.

    return (dict): {clade_name: [count, y]}, in count_dict order. Only names
        that are both in count_dict and label a clade of tree_obj are kept.
        Previously a named clade without a count raised KeyError, and a counted
        name missing from the tree produced an empty entry that crashed later.
    """
    ycoords = get_y_coordinates(tree_obj)
    y_by_name = {clade.name: y_value for clade, y_value in ycoords.items() if clade.name}
    return {acc: [count, y_by_name[acc]]
            for acc, count in count_dict.items()
            if acc in y_by_name}


def get_bar_traces(x_start, x_val, y_val, bar_thickness, shift=0,
                   scale=1, text='', color='rgba(100,100,100,0.5)'):
    """Return a filled rectangle trace representing one horizontal bar.

    parameters:
    x_start (int or float): x coordinate the bar graph starts from.
    x_val (int or float): bar length before scaling.
    y_val (int or float): y coordinate of the bar's centre line.
    bar_thickness (float): vertical thickness of the bar.
    shift (float): gap added between x_start and the bar's left edge.
    scale (float): factor applied to x_val to get the drawn length.
    text (str): hover text of the trace.
    color (str): line color of the trace.

    return (plotly.graph_objects.Scatter): a closed 5-point outline
        (mode='lines', fill='toself', no legend entry).
    """
    half = 0.5 * bar_thickness
    left = x_start + shift
    right = left + x_val * scale
    bottom = y_val - half
    top = y_val + half
    # Walk the four corners and return to the first to close the outline.
    corners = [(left, bottom), (right, bottom), (right, top), (left, top), (left, bottom)]
    xs = [cx for cx, _ in corners]
    ys = [cy for _, cy in corners]
    return go.Scatter(x=xs,
                      y=ys,
                      mode='lines',
                      fill='toself',
                      line={'color': color, 'width': 0.01},
                      text=text,
                      showlegend=False)





def make_baseline_trace(shift, HOG, x_y_coords, max_x, bar_thickness):
    """Return a vertical line trace marking the left edge of a bar graph.

    The line sits at ``max_x + shift`` and spans from half a bar below the
    lowest bar to half a bar above the highest one. The y extent is computed
    directly from the coordinates. This removes the dependency on a ``pd``
    name that this code never imports, and on ``HOG`` being a string.

    parameters:
    shift (float): offset from max_x, matching the shift used for the bars.
    HOG: unused; kept for interface compatibility.
    x_y_coords (dict): {name: [count, y]}; only the y values (index 1) are used.
    max_x (float): x position the bar graph starts from.
    bar_thickness (float): thickness of one bar.

    return (plotly.graph_objects.Scatter): the baseline trace.

    raises ValueError: if x_y_coords is empty (there is no extent to draw).
    """
    x_start = max_x + shift
    y_values = [coords[1] for coords in x_y_coords.values()]
    if not y_values:
        raise ValueError("x_y_coords is empty; no baseline to draw")
    half = 0.5 * bar_thickness
    baseline_trace = go.Scatter(x=[x_start, x_start],
                                y=[min(y_values) - half, max(y_values) + half],
                                mode='lines',
                                line={'color': 'rgba(100,100,100,0.5)',
                                      'width': 1})
    return baseline_trace