In [1]:
%matplotlib inline

In [2]:
import glob
import ete3
import numpy as np

In [265]:
def MP_root_ete3(tree):
    """
    
    """
    init_bl = np.sum([i.dist for i in tree.traverse()])
    init_terms = tree.get_leaves()
    init_term_names = [i.name for i in tree.get_leaves()]
    if len(tree.children) == 3:
        tree.set_outgroup(tree.get_leaves()[0])
    assert set([len(i.children) for i in tree.traverse() if not i.is_leaf()]) == set([2])
    ###Identifying the two leaves that are farthest from one-another
    starting, trash = tree.get_farthest_leaf()
    farthest1, init_dist = starting.get_farthest_node()
    farthest2, max_dist = farthest1.get_farthest_node()
    assert farthest1.is_leaf() and farthest2.is_leaf()
    
    ###Actually performing the mid-point root
    final_root(tree, farthest1, farthest2, max_dist)
    
    ###Making sure the two leave are equi-distant from one-another
    assert np.isclose(tree.get_distance(farthest1), tree.get_distance(farthest2))
    ###Making sure that I didn't lose any branch length along the way
    final_bl = np.sum([i.dist for i in tree.traverse()])
    assert np.isclose(init_bl, final_bl)
    ###Making sure I didn't lose any leaves along the way
    assert set(init_term_names) == set([i.name for i in tree.get_leaves()])
    ###And that I'm still fully bifurcating
    assert set([len(i.children) for i in tree.traverse() if not i.is_leaf()]) == set([2])
    return

def get_ancestor_path(parent, target):
    temp_node = target
    path = []
    while temp_node != parent:
        path.append(temp_node)
        temp_node = temp_node.up
    return path

def get_shortest_path(tree, node1, node2):
    lca = tree.get_common_ancestor(node1, node2)
    path1 = get_ancestor_path(lca, node1)
    path2 = get_ancestor_path(lca, node2)
    path2 = path2[::-1] ###Reverse this list to correctly orient the final path
    path = path1+path2
    return path

def final_root(tree, farthest1, farthest2, max_dist):
    """
    """
    path = get_shortest_path(tree, farthest1, farthest2)
    assert np.isclose(np.sum([i.dist for i in path]), max_dist)
    ###Finding the correct branch for the final root
    mid_point = max_dist/2.
    counter = 0
    for node in path:
        counter += node.dist
        if counter < mid_point:
            last_success = node
        else:
            remainder = counter - mid_point
            break        
    ###Rooting the tree accordingly
    tree.set_outgroup(node)
    ###Checking the orientation (asking whether the other node is indeed a child)
    if tree.children[0] != node:
        tree.swap_children()
    ###If the topology doesn't get deformed by the rooting
    if tree.children == [node, last_success]:
        total = np.sum([i.dist for i in tree.children])
        tree.children[0].dist = remainder
        tree.children[1].dist = total-remainder
    ###And if it does induce a topology deformation
    else:
        assert last_success in tree.children[1].children        
        total = np.sum([i.dist for i in tree.children + [last_success]])
        tree.children[0].dist = remainder
        tree.children[1].dist = 0.
        last_success.dist = total-remainder
    return



In [259]:
def full_describe(tree):
    n_leaves = len(tree.get_leaves())
    n_nodes = len([i for i in tree.traverse()])
    total_bl = np.sum([i.dist for i in tree.traverse()])
    return n_leaves, n_nodes, total_bl

# Ensuring the same results as the (slower) built in ete3 mid-point method

**Also worth noting that the ete3 mid-point method doesn't root at the actual mid-point of the tree but rather the mid-point of the mid-point branch**

In [266]:
accuracies = []
# for tree_file in glob.glob('../Data/OMA_orthologs/5204_4890/*.treefile')[:]:
# for tree_file in glob.glob('../Data/OMA_orthologs/5204_4890_33511_33317/*.treefile')[:]:
for tree_file in glob.glob('../Data/OMA_orthologs/5204_4890_33511_33317_33090/*.treefile')[:]:
#     print(tree_file)
    tree = ete3.Tree(tree_file, format=0)
    R = tree.get_midpoint_outgroup()
    tree.set_outgroup(R)
    a_orig = set([i.name for i in tree.children[0].get_leaves()])
    b_orig = set([i.name for i in tree.children[1].get_leaves()])
    orig_desc = full_describe(tree)
    
    tree = ete3.Tree(tree_file, format=0)
    MP_root_ete3(tree)
    a_new = set([i.name for i in tree.children[0].get_leaves()])
    b_new = set([i.name for i in tree.children[1].get_leaves()])
    new_desc = full_describe(tree)
    
    assert orig_desc[:-1] == new_desc[:-1]
    assert np.isclose(orig_desc[-1], new_desc[-1])
    assert (a_orig==a_new and b_orig==b_new) or (a_orig==b_new and b_orig==a_new)

KeyboardInterrupt: 

# MinVar development

In [100]:
import sys
sys.path.append('../../Tree_weighting/Code')
import weighting_methods_ete3

from scipy.optimize import minimize
from statsmodels.stats.weightstats import DescrStatsW

In [213]:
def MinVar_root_ete3(my_tree, weights_type='None'):
    """ 

    """
    initial_depths = [my_tree.get_distance(i) for i in my_tree.get_leaves()]
    depths_array_dict = {}
    depths_array_dict[my_tree] = np.array(initial_depths)
    weights_array_dict = {}
    if weights_type == 'None':
        weights_array_dict[my_tree] = np.array([1.0 for i in my_tree.get_leaves()])
        weights_update_fxn = no_weights_update
        branchscan_fxn = MinVar_branchscan
    
    elif weights_type == 'GSC':
        weighting_methods_ete3.GSC_ete3(my_tree)
        weights_array_dict[my_tree] = np.array([i.weight for i in my_tree.get_leaves()])
        weights_update_fxn = update_GSC_array_dict
        branchscan_fxn = MinVar_branchscan_GSC
    else:
        print('Need to give me a valid weights_type (currently None or GSC)')
    finished_count = compile_array_dicts(my_tree, weights_array_dict, depths_array_dict, weights_update_fxn)
    results_dict = MinVar_optimize_all(my_tree, weights_array_dict, depths_array_dict, branchscan_fxn)
    return results_dict, depths_array_dict, weights_array_dict

def compile_array_dicts(node, weights_array_dict, depths_array_dict, weights_update_fxn, finished_count=0):
    if not node.is_root():      
        update_depth_array_dict(node, depths_array_dict, finished_count)
        weights_update_fxn(node, weights_array_dict, finished_count)
    if len(node.children) == 0:
        finished_count += 1   
    elif len(node.children) == 2:
        l_child, r_child = node.children
        finished_count = compile_array_dicts(l_child, weights_array_dict, depths_array_dict, weights_update_fxn, finished_count)
        finished_count = compile_array_dicts(r_child, weights_array_dict, depths_array_dict, weights_update_fxn, finished_count)
    else:
        print('Probable f up')
    return finished_count

def update_depth_array_dict(node, depths_array_dict, finished_count):
    ds_count = len(node.get_leaves())
    parent = node.up
    new_array = np.array(depths_array_dict[parent])
    new_array[finished_count:finished_count+ds_count] -= node.dist
    #Add the branch length to all the upstream clades (two sets)
    new_array[:finished_count] += node.dist
    new_array[finished_count+ds_count:] += node.dist
    #Update the dictionary
    depths_array_dict[node] = new_array
    return

def no_weights_update(node, weights_array_dict, finished_count):
    weights_array_dict[node] = weights_array_dict[node.up]
    return

def update_GSC_array_dict(node, weights_array_dict, finished_count):
    ds_count = len(node.get_leaves())
    parent = node.up
    new_array = np.array(weights_array_dict[parent])
    #This is the total "weight" to reclaim from the downstream terms and distribute to the upstreams
    bl_to_disperse = node.dist
    #Recover from downstream terminals
    current_ds_weights = np.sum(new_array[finished_count:finished_count+ds_count])
    if current_ds_weights > 0:
        to_subtract = new_array[finished_count:finished_count+ds_count]/current_ds_weights*-1*bl_to_disperse
    else:
        to_subtract = np.zeros_like(new_array[finished_count:finished_count+ds_count])
    #Disperse to upstream terminals
    current_us_weights = np.sum(new_array[:finished_count]) + np.sum(new_array[finished_count+ds_count:]) 
    if current_us_weights > 0:
        to_add_a = new_array[:finished_count] / current_us_weights * bl_to_disperse  
        to_add_b = new_array[finished_count+ds_count:] / current_us_weights * bl_to_disperse
    else:
        to_add_a = np.zeros_like(new_array[:finished_count])  
        to_add_b = np.zeros_like(new_array[finished_count+ds_count:])      
    assert  np.isclose(np.sum(to_add_a) + np.sum(to_add_b) + np.sum(to_subtract), 0.)
    #et voila
    new_array = new_array + np.concatenate((to_add_a, to_subtract, to_add_b))
    weights_array_dict[node] = new_array 
    return

def MinVar_optimize_all(my_tree, weights_array_dict, depths_array_dict, branchscan_fxn):
    results_dict = {}
    for node in weights_array_dict.keys():
        if node.is_root():
            continue
        print(node.name)
        leaves = node.get_leaves()
        ds_count = len(leaves)
        first = leaves[0]
        finished_count = my_tree.get_leaves().index(first)
        depths_array = depths_array_dict[node]
        weights_array = weights_array_dict[node]
        ###################################################
        #Root-to-tip distances for all downstream terminals
        ###################################################
        downstream_dists = np.array(depths_array[finished_count:finished_count+ds_count])
        #And all upstream terminals
        upstream_dists = np.concatenate((depths_array[:finished_count],\
                                         depths_array[finished_count+ds_count:]))
        ###################################################
        #Weights for all downstream terminals
        ###################################################
        downstream_weights = np.array(weights_array[finished_count:finished_count+ds_count])
        #And all upstream terminals
        upstream_weights = np.concatenate((weights_array[:finished_count],\
                                           weights_array[finished_count+ds_count:]))
        
        ###################################################
        #Set the bounds and optimize the GSC specific function
        ###################################################
        bl_bounds = np.array([[0., node.dist]])
        #Valid options for method are L-BFGS-B, SLSQP and TNC
        res = minimize(branchscan_fxn, np.array(np.mean(bl_bounds)),\
                              args=(downstream_dists, upstream_dists,\
                                    downstream_weights, upstream_weights),\
                              bounds=bl_bounds, method='L-BFGS-B')
        results_dict[node] = res
    return results_dict


def MinVar_branchscan(modifier, ds_dists, us_dists, ds_weights, us_weights):
        temp_ds_dists = ds_dists + modifier
        temp_us_dists = us_dists - modifier
        all_dists = np.concatenate((temp_ds_dists, temp_us_dists))
        all_weights = np.concatenate((ds_weights, us_weights))
        dsw = DescrStatsW(all_dists, all_weights)
        return dsw.var

def MinVar_branchscan_GSC(modifier, ds_dists, us_dists, ds_weights, us_weights):
        temp_ds_dists = ds_dists + modifier
        temp_us_dists = us_dists - modifier
        all_dists = np.concatenate((temp_ds_dists, temp_us_dists))

        ###########################################################################
        #Now adjust the downstream and upstream weights
        ###########################################################################
        #First get the total downstream weights
        total_ds = np.sum(ds_weights)
        #Divide up the added branch length (modifier) across the downstream weights
        if total_ds != 0:
            temp_ds_weights = ds_weights + (ds_weights/total_ds*modifier)
        #Special case if nothing is downstream (for terminal branches)
        else:
            temp_ds_weights = ds_weights + modifier
        #Get the total old upstream weights
        total_us = np.sum(us_weights)
        #Reclaim the branch length (modifier) from all the upstream weights
        if total_us != 0:
            temp_us_weights = us_weights - (us_weights/total_us*modifier)
        #Special case for terminal branches
        else:
            print('This condition should perhaps never occur and should be investigated')
            temp_us_weights = us_weights - modifier
        #Put all the weights together
        all_weights = np.concatenate((temp_ds_weights, temp_us_weights))
        #In GSC weighting, the weights can't be less than the distance! Minor numerical rounding errors can
        #cause this to happen
        all_weights = np.minimum(all_weights, all_dists) 
        #Finally putting that boolean I've been passing around to use. Basically this is a re-scaling
        #of the GSC weights that I came up with that expresses each GSC weight for a given terminal
        #as a fraction of its total possible weight (its depth). In practice, it is a less dramatic 
        #weighting scheme than the non-normalized counterpart.
        ###########################################################################
        #Calculate weighted variance and return
        ###########################################################################
        dsw = DescrStatsW(all_dists, all_weights)
        return dsw.var
    



In [319]:
test_tree = ete3.Tree('(((A:20, B:20):30, C:31):30, D:80);')
results_dict, depths_array_dict, weights_array_dict = MinVar_root_ete3(test_tree, weights_type='None')



A
B
C
D


In [320]:
for i in test_tree.get_leaves():
    print(i.name, i.dist, i.weight)

AttributeError: 'TreeNode' object has no attribute 'weight'

In [321]:
print([np.sum(i) for i in depths_array_dict.values()])

[301.0, 241.0, 241.0, 281.0, 281.0, 303.0, 461.0]


In [322]:
print([np.sum(i) for i in weights_array_dict.values()])

[4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0]


In [323]:
depths_array_dict.values()

dict_values([array([80., 80., 61., 80.]), array([ 50.,  50.,  31., 110.]), array([ 20.,  20.,  61., 140.]), array([  0.,  40.,  81., 160.]), array([ 40.,   0.,  81., 160.]), array([ 81.,  81.,   0., 141.]), array([160., 160., 141.,   0.])])

In [324]:
weights_array_dict.values()

dict_values([array([1., 1., 1., 1.]), array([1., 1., 1., 1.]), array([1., 1., 1., 1.]), array([1., 1., 1., 1.]), array([1., 1., 1., 1.]), array([1., 1., 1., 1.]), array([1., 1., 1., 1.])])

In [325]:
for i,j in results_dict.items():
    print(i,j)


      /-A
   /-|
--|   \-B
  |
   \-C       fun: 67.6875
 hess_inv: <1x1 LbfgsInvHessProduct with dtype=float64>
      jac: array([-4.74999808])
  message: b'CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL'
     nfev: 4
      nit: 1
   status: 0
  success: True
        x: array([30.])

   /-A
--|
   \-B       fun: 885.1875
 hess_inv: <1x1 LbfgsInvHessProduct with dtype=float64>
      jac: array([-20.50002195])
  message: b'CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL'
     nfev: 4
      nit: 1
   status: 0
  success: True
        x: array([30.])

--A       fun: 2400.1875
 hess_inv: <1x1 LbfgsInvHessProduct with dtype=float64>
      jac: array([-40.25005182])
  message: b'CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL'
     nfev: 4
      nit: 1
   status: 0
  success: True
        x: array([20.])

--B       fun: 2400.1875
 hess_inv: <1x1 LbfgsInvHessProduct with dtype=float64>
      jac: array([-40.25005182])
  message: b'CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL'
     nf