In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

# Summary

The point of this notebook is to develop better tree rooting algorithms. Develop, because this will be messy for the time being and will eventually split into hopefully comprehensible code when the time arises.

In [2]:
from Bio import Phylo
import numpy as np
from scipy import stats

import ete3
from io import StringIO
import random
from scipy.optimize import minimize, minimize_scalar
from matplotlib import pyplot as plt

import glob
import pandas as pd

import copy

# Test my MP algorithm

I'm confident that it works, at least on the trees I've tested. But I really need to demonstrate that here.

Note that I've uncovered weird problems with the `root_at_midpoint()` method from `Bio.Phylo`, so I won't be trusting it as a positive control. Rather, I should demonstrate that the two methods are usually equivalent and that, when there are discrepancies, my method produces the expected behavior (conserving branch length, for instance).

In [None]:
import rooting_methods

In [None]:
# Compare my midpoint rooting against Bio.Phylo's `root_at_midpoint()` on every
# cyanobacteria tree. The two branch lengths adjacent to each root are sorted so the
# comparison does not depend on which child happens to be listed first.
mp_discrepancies = []  # (path, my_bls, phylo_bls) for every tree where the two methods disagree
# sorted() makes the processing order reproducible; glob order is filesystem dependent
for input_tree in sorted(glob.glob('../Tria_et_al_data/cyanobacteria/ingroup/phyml/*.nwk')):
    my_tree = Phylo.read(input_tree, 'newick', rooted=False)
    initial_terminals = my_tree.get_terminals()
    my_tree = rooting_methods.mp_root_adhock(my_tree)
    assert my_tree.is_bifurcating()
    my_bls = (my_tree.root.clades[0].branch_length, my_tree.root.clades[1].branch_length)
    my_bls = sorted(my_bls)
    phylo_tree = Phylo.read(input_tree, 'newick', rooted=False)
    phylo_tree.root_at_midpoint()
    phylo_bls = (phylo_tree.root.clades[0].branch_length, phylo_tree.root.clades[1].branch_length)
    phylo_bls = sorted(phylo_bls)
    # Record the trees on which the two roots sit at different positions along the branch
    if not np.allclose(my_bls, phylo_bls):
        mp_discrepancies.append((input_tree, my_bls, phylo_bls))
print('{} tree(s) with differing root branch lengths'.format(len(mp_discrepancies)))

**Investigate example errors**

**The ete3 method is just plain wrong**

In [None]:
# tree_loc = '../test.ete3.newick'
# # tree_loc = '/Users/adamhockenberry/Projects/Phylogenetic_couplings/scratch/current/1AOE_A_rp75.newick'
# tree = ete3.Tree(tree_loc)
# outgroup = tree.get_midpoint_outgroup()
# tree.set_outgroup(outgroup)
# tree.render('%%inline')

**Add in a comparison to DendroPy to be comprehensive**

# Testing method based off of minimizing the standard deviation in the distribution of root-to-tip distances

...this is surprisingly easy and seems like it should work great... I guess it's just the Min-var or MCCV method, as it's called in the MAD paper? Not sure if maximizing the likelihood of a Gaussian is the same as minimizing the coefficient of variation, but I'm kind of guessing yes. In any event, this code is easily adapted.

In [3]:
import rooting_methods
import rooting_methods_v2

In [25]:
# tree = Phylo.read('../../Tree_rooting/Data/raw_OMA_trees/OMAGroup_479938.mafft.afa.treefile.Rooted.MPAJH', 'newick')
# tree = Phylo.read('/Users/adamhockenberry/Downloads/BM_Folder/paper_tree.txt', 'newick')
# rooted_tree.root_with_outgroup(['ELI', 'MAL'], outgroup_branch_length=10e-6)
# tree = Phylo.read(StringIO('(((A:20, B:20):30,C:50):30, D:80)'), 'newick', rooted=False)
# tree = Phylo.read('../../Tree_rooting/Data/euk_trees/KOG0001.faa.aln.nwk.Rooted.MADAJH', 'newick')
tree = Phylo.read('../../Phylogenetic_couplings/Data/psicov150_aln_pdb/raw_trees/1a3aA.newick', 'newick')
rooted_tree = rooting_methods.mp_root_adhock(tree)
# Phylo.draw(rooted_tree)

In [24]:
%%timeit
a1, b1, c1 = rooting_methods.ml_root_adhock(rooted_tree)

1 loops, best of 3: 12.6 s per loop


In [26]:
%%timeit
a2, b2, c2 = rooting_methods_v2.mlfit_root_adhock(rooted_tree)

1 loops, best of 3: 2.09 s per loop


# Testing systematically

In [None]:
# Map column 1 -> column 2 of ID_to_Species.txt. The first line is skipped
# (presumably a header row -- TODO confirm against the file).
id_species_dict = {}
with open('../Tria_et_al_data/eukaryotes/ID_to_Species.txt', 'r') as infile:
    next(infile, None)  # skip the first line without reading the whole file into memory
    for line in infile:
        sl = line.split('\t')
        if len(sl) < 2:
            continue  # blank or malformed line: nothing to record
        # strip() keeps a trailing newline out of the value when it is the last column,
        # matching how the second dictionary below is built
        id_species_dict[sl[0]] = sl[1].strip()
print(len(id_species_dict))

# For rows of cluster_to_seqid.txt whose first column is 'KOG0725', map column 2 -> column 3.
species_seqid_dict = {}
with open('../Tria_et_al_data/eukaryotes/cluster_to_seqid.txt', 'r') as infile:
    for line in infile:
        sl = line.split('\t')
        if len(sl) >= 3 and sl[0] == 'KOG0725':
            species_seqid_dict[sl[1]] = sl[2].strip()
print(len(species_seqid_dict))

In [None]:

# fungi = ['13684', ]
# id_species_dict

**test monophyly**

In [None]:
def recursive_tree_monophyly(hypothetical_root, tree, test_set, is_mono):
    '''
    Walk down from `hypothetical_root`, re-rooting `tree` along the way, and report
    whether `test_set` was monophyletic under any of the rootings that were visited.

    `tree` is re-rooted in place (via `root_with_outgroup`) as a side effect.
    `is_mono` carries the answer found so far; once it becomes True it stays True.
    Nodes whose children all fail the `branch_length > 0` test are not descended into.
    Returns the (possibly updated) `is_mono` flag.
    '''
    is_mono = True if tree.is_monophyletic(test_set) else is_mono

    # Snapshot the children before any re-rooting touches the tree structure
    children = tuple(hypothetical_root.clades)
    n_children = len(children)

    if n_children == 2:
        # The first child with a positive branch length serves as the outgroup;
        # if neither qualifies this node is left alone.
        outgroup = next((child for child in children if child.branch_length > 0), None)
        if outgroup is not None:
            tree.root_with_outgroup(outgroup, outgroup_branch_length=10e-10)
            for child in children:
                is_mono = recursive_tree_monophyly(child, tree, test_set, is_mono)
    elif n_children == 1:
        only_child = children[0]
        if only_child.branch_length > 0:
            tree.root_with_outgroup(only_child, outgroup_branch_length=10e-10)
            is_mono = recursive_tree_monophyly(only_child, tree, test_set, is_mono)

    # Terminals (no children) and nodes with more than two children fall straight through
    return is_mono

# tree.get_terminals()
# tree.is_monophyletic(metazoa)
tree = Phylo.read('../test.ete3.newick', 'newick', rooted=False)
# Call through the imported module: the bare name `mp_root_adhock` is never defined in
# this notebook and raises NameError on a fresh kernel.
tree = rooting_methods.mp_root_adhock(tree)
# tree.is_monophyletic([term for term in tree.get_terminals() if\
#                       term.name in ['7165', '7425', '7460', '121225', '7227', '6239']])

In [None]:
# `metazoa` is otherwise only assigned in a later cell, so running the notebook top to
# bottom on a fresh kernel hit a NameError here. Define it before first use.
metazoa = ['10090', '121225', '9606', '30611', '8364', '7955', '8128', '8090',\
          '7668', '7460', '7425', '7227', '7165', '6239']
testy = [term for term in tree.get_terminals() if\
                      term.name in metazoa]
recursive_tree_monophyly(tree.root, tree, testy, False)

In [None]:
Phylo.draw(tree)

In [None]:
metazoa = ['10090', '121225', '9606', '30611', '8364', '7955', '8128', '8090',\
          '7668', '7460', '7425', '7227', '7165', '6239']
problematic = ['../Tria_et_al_data/eukaryotes/ingroup/phyml/KOG3467.faa.aln.nwk',\
              '../Tria_et_al_data/eukaryotes/ingroup/phyml/KOG2866.faa.aln.nwk']

trees_dir = '../Tria_et_al_data/eukaryotes/ingroup/phyml/*.nwk'
ideal_species_number = 31

# n_pruned = 15
# trees_dir = '../Tria_et_al_data/eukaryotes/ingroup/phyml/*.{}.pruned'.format(n_pruned)
# problematic = [i+'.{}.pruned'.format(n_pruned) for i in problematic]
# # problematic += ['../Tria_et_al_data/eukaryotes/ingroup/phyml/KOG2688.faa.aln.nwk.6.pruned']
# problematic += ['../Tria_et_al_data/eukaryotes/ingroup/phyml/KOG3887.faa.aln.nwk.15.pruned']
# problematic += ['../Tria_et_al_data/eukaryotes/ingroup/phyml/KOG1374.faa.aln.nwk.15.pruned']
# problematic += ['../Tria_et_al_data/eukaryotes/ingroup/phyml/KOG1558.faa.aln.nwk.15.pruned']
# problematic += ['../Tria_et_al_data/eukaryotes/ingroup/phyml/KOG0284.faa.aln.nwk.15.pruned']
# problematic += ['../Tria_et_al_data/eukaryotes/ingroup/phyml/KOG0594.faa.aln.nwk.15.pruned']
# ideal_species_number = 31-n_pruned


# trees_dir = ['../Tria_et_al_data/eukaryotes/ingroup/phyml/KOG0725.faa.aln.nwk']
def _root_separates(rooted, group_terminals):
    '''Return True when one of the first two clades below the root holds exactly `group_terminals`.'''
    group = set(group_terminals)
    return any(group == set(side.get_terminals()) for side in rooted.root.clades[:2])


# Count, per rooting method, how often the root splits the metazoa from everything else.
mp_success_rate = 0
ml_success_rate = 0
mad_success_rate = 0
attempts = 0
# sorted() makes the 50-tree subset reproducible; raw glob order is filesystem dependent
for tree_loc in sorted(glob.glob(trees_dir))[:50]:
    if tree_loc in problematic:
        continue
    print('######## {}'.format(tree_loc))
    tree = Phylo.read(tree_loc, 'newick')
    if len(tree.get_terminals()) != ideal_species_number:
        continue

    try:
        mad_tree = Phylo.read(tree_loc+'.rooted', 'newick', rooted=True)
    except ValueError:
        print('MAD did not work here')
        continue

    testy = [term for term in tree.get_terminals() if\
                      term.name in metazoa]
    # The rooting functions live in the imported `rooting_methods` module; the bare names
    # used here before are not defined anywhere in this notebook.
    rooted_tree = rooting_methods.mp_root_adhock(tree)
    # Only score trees in which *some* rooting makes the metazoa monophyletic
    valid = recursive_tree_monophyly(rooted_tree.root, rooted_tree, testy, False)
    if valid:
        attempts += 1
        ###Mid point
        mp_tree = rooting_methods.mp_root_adhock(tree)
        if _root_separates(mp_tree, testy):
            mp_success_rate += 1
        ###ML
        # NOTE(review): this replaces the undefined name `max_likelihood_root`. Other cells
        # unpack three values from `rooting_methods.ml_root_adhock` with the tree first --
        # confirm this is the intended ML rooting function.
        ml_tree = rooting_methods.ml_root_adhock(tree)[0]
        if _root_separates(ml_tree, testy):
            ml_success_rate += 1
        ###MAD
        # The MAD tree was read from its own file, so collect its own terminal objects
        testy = [term for term in mad_tree.get_terminals() if\
                      term.name in metazoa]
        if _root_separates(mad_tree, testy):
            mad_success_rate += 1
#     tree.root_at_midpoint()
#     print(min([term.branch_length for term in tree.get_terminals()]))
#     print(recursive_tree_monophyly(tree.root, tree, testy, False))
    print(mp_success_rate, ml_success_rate, mad_success_rate)


    

In [None]:
mp_success_rate / attempts, ml_success_rate / attempts, mad_success_rate / attempts

In [None]:
# Phylo.draw(tree)
# Phylo.draw(mp_tree)
# Phylo.draw(ml_tree)
# Phylo.draw(mad_tree)

In [None]:
mp_success_rate, ml_success_rate, mad_success_rate, attempts

In [None]:
# Load one of the trees listed as problematic and draw it after midpoint rooting.
tree_loc = '../Tria_et_al_data/eukaryotes/ingroup/phyml/KOG3467.faa.aln.nwk'
tree = Phylo.read(tree_loc, 'newick')

# tree.root_at_midpoint()
# Call through the imported module: the bare name `mp_root_adhock` is not defined here.
tree = rooting_methods.mp_root_adhock(tree)
Phylo.draw(tree)

In [None]:
# testy = [term for term in tree.get_terminals() if\
#                       term.name in metazoa]
recursive_tree_monophyly(tree.root, tree, testy, False)

In [None]:
tree.get_terminals()

# A weighted MaxLik implementation

In [None]:
from Bio import Phylo
import rooting_methods

import pandas as pd
import numpy as np

from scipy.optimize import minimize
from scipy import stats

In [None]:
import sys
sys.path.append('../../Tree_weighting/Code/')
import weighting_methods

In [None]:
# def ml_root_weighted(tree):
#     ###Depths are important! This is what I am trying to optimize in terms
#     ###of making these look as close to normal as possible. So this gets the
#     ###starting depths as a DataFrame and subsequent tree crawling adds/subtracts
#     ###to these values
#     initial_depths = tree.root.depths()
#     terminal_depths_df = pd.DataFrame()
#     terminal_depths_df['depth'] = np.nan
#     for term in tree.get_terminals():
#         terminal_depths_df.set_value(term.name, 'depth', initial_depths[term])
#     depths_dict = {}
#     depths_dict[tree.root] = terminal_depths_df
    
#     ###Getting starting weights
#     weights_dict_single = weighting_methods.GSC_adhock_extended(tree)
#     weights_dict_all = {}
#     weights_dict_all[tree.root] = weights_dict_single

#     explored, function_optima, depths_dict, weights_dict =\
#             recursive_crawl_ml(tree.root, [], [], depths_dict, weights_dict_all, tree)
    
#     ###Getting the best function eval and rooting there
#     function_optima = sorted(function_optima, key=lambda x: x[1].fun)
#     tree.root_with_outgroup(function_optima[0][0], outgroup_branch_length=0.)
#     assert tree.root.clades[1].branch_length == 0.
#     assert tree.root.clades[1] == function_optima[0][0]
#     tree.root.clades[0].branch_length -= function_optima[0][1].x[0]
#     tree.root.clades[1].branch_length += function_optima[0][1].x[0]
#     return tree, function_optima, depths_dict, weights_dict

# def recursive_crawl_ml(hypothetical_root, explored, function_optima, depths_dict, weights_dict, tree):
#     if len(hypothetical_root.clades) == 2:
#         l_clade, r_clade = hypothetical_root.clades
#         l_bl = l_clade.branch_length
#         r_bl = r_clade.branch_length
#         #L clade first
#         if l_bl > 0:
#             depths_dict, downstream_terms, upstream_terms =\
#                     update_depth_df_dict(depths_dict, l_clade, hypothetical_root)
#             weights_dict =\
#                     update_weights_dict(weights_dict, l_clade, hypothetical_root, downstream_terms, upstream_terms)
#             res = optimize_root_loc_on_branch(l_clade, depths_dict[l_clade], weights_dict[l_clade], downstream_terms, upstream_terms)
#             function_optima.append((l_clade, res))
#             explored, function_optima, depths_dict, weights_dict =\
#                     recursive_crawl_ml(l_clade, explored, function_optima, depths_dict, weights_dict, tree)
#         #R clade second
#         if r_bl > 0:
#             depths_dict, downstream_terms, upstream_terms =\
#                     update_depth_df_dict(depths_dict, r_clade, hypothetical_root)
#             weights_dict =\
#                     update_weights_dict(weights_dict, r_clade, hypothetical_root, downstream_terms, upstream_terms)
#             res = optimize_root_loc_on_branch(r_clade, depths_dict[r_clade], weights_dict[r_clade], downstream_terms, upstream_terms)
#             function_optima.append((r_clade, res))
#             explored, function_optima, depths_dict, weights_dict =\
#                     recursive_crawl_ml(r_clade, explored, function_optima, depths_dict, weights_dict, tree)
#     elif len(hypothetical_root.clades) == 0:
#         explored.append(hypothetical_root)
#         return explored, function_optima, depths_dict, weights_dict
    
#     else:
#         print('Some big error here with the number of clades stemming from this root')
#     explored.append(hypothetical_root)
#     return explored, function_optima, depths_dict, weights_dict

# def update_depth_df_dict(depths_dict, my_clade, parent_clade):
#     downstream_terms = [i.name for i in my_clade.get_terminals()]
#     upstream_terms = list(set(list(depths_dict[parent_clade].index)) - set(downstream_terms))
#     depths_dict[my_clade] = depths_dict[parent_clade].copy(deep=True)
#     depths_dict[my_clade].loc[downstream_terms, 'depth'] -= my_clade.branch_length
#     depths_dict[my_clade].loc[upstream_terms, 'depth'] += my_clade.branch_length
#     return depths_dict, downstream_terms, upstream_terms

# def update_weights_dict(weights_dict, my_clade, parent_clade, downstream_terms, upstream_terms):
#     '''
#     Some convoluted copy things happening here that should be double checked
#     '''
#     weights_dict[my_clade] = copy.copy(weights_dict[parent_clade])
#     trashy = 0
#     temp_ds = [next(tree.find_elements(term)) for term in downstream_terms]
#     for term in temp_ds:
#         trashy += weights_dict[my_clade][term][-1]
#         weights_dict[my_clade][term] = weights_dict[my_clade][term][:-1]
#     bl_to_disperse = my_clade.branch_length
#     temp_us = [next(tree.find_elements(i)) for i in upstream_terms]
#     to_divide = np.sum([weights_dict[my_clade][term][-1] for term in temp_us])
#     for term in temp_us:
#         weights_dict[my_clade][term] = weights_dict[my_clade][term] + [weights_dict[my_clade][term][-1] +\
#                                             weights_dict[my_clade][term][-1]/to_divide*bl_to_disperse]       
#     return weights_dict

# def optimize_root_loc_on_branch(my_clade, depths_df, weights_dict, downstream_terms, upstream_terms):
#     '''
#     '''    
# #     print('####')


#     downstream_dists = np.array(depths_df.loc[downstream_terms, 'depth'])
#     downstream_weights = np.array([weights_dict[next(tree.find_elements(i))][-1] for i in downstream_terms])

#     upstream_dists = np.array(depths_df.loc[upstream_terms, 'depth'])
#     upstream_weights = np.array([weights_dict[next(tree.find_elements(i))][-1] for i in upstream_terms])
#     old_upstream_weights = np.array([weights_dict[next(tree.find_elements(i))][-2] for i in upstream_terms])
#     bl_bounds = np.array([[0., my_clade.branch_length]])
#     ###Valid options for method are L-BFGS-B, SLSQP and TNC
#     res = minimize(branch_scan_ml, np.array(np.mean(bl_bounds)),\
#                           args=(downstream_dists, upstream_dists,\
#                                 downstream_weights, upstream_weights, old_upstream_weights),\
#                           bounds=bl_bounds, method='SLSQP')
# #     print(res)
#     return res

# def branch_scan_ml(modifier, ds_dists, us_dists, ds_weights, us_weights, old_us_weights):
#     temp_ds_dists = ds_dists + modifier
#     temp_us_dists = us_dists - modifier
#     all_dists = np.concatenate((temp_ds_dists, temp_us_dists))
    
#     total_ds = np.sum(ds_weights)
#     if total_ds != 0:
#         temp_ds_weights = ds_weights + (ds_weights/total_ds*modifier)
#     else:
#         temp_ds_weights = ds_weights + modifier


#     total_us = np.sum(old_us_weights)
#     if total_us != 0:
#         temp_us_weights = us_weights - (old_us_weights/total_us*modifier)
#     else:
#         temp_us_weights = us_weights - modifier
# #     all_weights = np.array([1 for i in range(all_dists.shape[0])])
#     all_weights = np.concatenate((temp_ds_weights, temp_us_weights))
# #     print(all_weights)
#     dsw = DescrStatsW(all_dists, all_weights)
#     return dsw.std

In [None]:
import rooting_methods_weighted

In [None]:
# tree = Phylo.read('../../Phylogenetic_couplings/Data/psicov150_aln_pdb/raw_trees/1a3aA.newick', 'newick')
tree = Phylo.read('/Users/adamhockenberry/Downloads/BM_Folder/paper_tree.txt', 'newick')
# tree = Phylo.read(StringIO('(((A:20, B:20):30,C:50):30, D:80)'), 'newick', rooted=False)
# tree = Phylo.read('../../Tree_rooting/Tria_et_al_data/eukaryotes/ingroup/phyml/KOG0007.faa.aln.nwk', 'newick')
# Phylo.draw(tree)
# tree.root_with_outgroup('A', outgroup_branch_length=10e-8)
# C = next(tree.find_elements('C'))
# C.branch_length = 40
tree.root_with_outgroup('MAL')
# noi = next(tree.find_elements('PV22'))
# noi.branch_length += 20
# Phylo.draw(tree)

In [None]:
tree = rooting_methods.mp_root_adhock(tree)
Phylo.draw(tree)
print([(i.name, i.branch_length) for i in tree.root.clades])
tree, a, b, c = rooting_methods_weighted.ml_root_weighted(tree)
Phylo.draw(tree)
print([(i.name, i.branch_length) for i in tree.root.clades])

In [None]:
%%timeit
# Call through the module, as the cell above does; the bare name is not defined in this notebook.
rooting_methods_weighted.ml_root_weighted(tree)

In [None]:
print(tree.root.clades[1].get_terminals())

In [None]:
for key,val in c.items():
    print('#####')
    print(key)
    print(val)

In [None]:
a

In [None]:
# tree = Phylo.read('../../Tree_rooting/Tria_et_al_data/eukaryotes/ingroup/phyml/KOG0007.faa.aln.nwk', 'newick')
# tree = rooting_methods.mp_root_adhock(tree)
# Phylo.draw(tree)
tree,a,b = rooting_methods.ml_root_adhock(tree)
Phylo.draw(tree)

In [None]:
print([(i.name, i.branch_length) for i in tree.root.clades])

In [None]:
# Root-to-tip depth of every terminal. The terminals are collected once into a set so the
# membership test is O(1) instead of re-walking the tree for every entry of `tree.depths()`.
terminal_set = set(tree.get_terminals())
testy = [val for key, val in tree.depths().items() if key in terminal_set]

In [None]:
print(np.std(testy), np.mean(testy), np.std(testy)/np.mean(testy))
print(-np.sum(stats.norm.logpdf(testy, loc=np.mean(testy), scale=np.std(testy))))

In [None]:
print(np.std(testy), np.mean(testy), np.std(testy)/np.mean(testy))
print(-np.sum(stats.norm.logpdf(testy, loc=np.mean(testy), scale=np.std(testy))))

In [None]:
print(np.std(testy), np.mean(testy), np.std(testy)/np.mean(testy))
print(-np.sum(stats.norm.logpdf(testy, loc=np.mean(testy), scale=np.std(testy))))

In [None]:
from statsmodels.stats.weightstats import DescrStatsW
arr = np.arange(-5, 5)
weights = np.arange(9, -1, -1)  # Same size as arr
print(arr, weights)
dsw = DescrStatsW(arr, weights)
cv = dsw.std / abs(dsw.mean)  # weighted std / abs of weighted mean
print(cv)

In [None]:
from matplotlib import pyplot as plt

In [None]:
import statsmodels.api as sm

In [None]:
# Reference snippet for weighted descriptive statistics, pasted from an interactive
# session. It is kept as comments because the transcript is not runnable as-is
# (`array` is undefined and the `>>>` prompts / bare output are not valid Python).
#
# >>> from statsmodels.stats.weightstats import DescrStatsW
# >>> cv = dsw.std / abs(dsw.mean)  # weighted std / abs of weighted mean
# >>> print(cv)
# 1.6583123951777001
# >>> weighted_stats = DescrStatsW(array, weights=weights, ddof=0)
# >>> weighted_stats.std

from statsmodels.stats.weightstats import DescrStatsW

# `mean` and `std` are read from a DescrStatsW *instance* (as in the `dsw.std / abs(dsw.mean)`
# example above), so build one first. Reuses `arr` and `weights` from that example cell.
weighted_stats = DescrStatsW(arr, weights=weights, ddof=0)
weighted_stats.mean, weighted_stats.std


In [None]:
fig, ax = plt.subplots(1, )
ax.plot(testy.support, testy.density)

# A weighted MAD... ugh

In [None]:
import rooting_methods

In [None]:
def mad_root_weighted(tree):
    '''
    Re-root `tree` in place on the branch with the lowest deviation score.

    Steps:
      1. every non-root node whose branch length equals 0 gets a branch length of 10e-16;
      2. every branch is scored by `recursive_crawl_mad`, starting from the distance
         table that `get_lca_dist_df` builds for the current root;
      3. the tree is rooted on the best-scoring clade and the root is slid along that
         branch by the offset stored with the score.

    Returns `(tree, RAI, function_optima)`, where `function_optima` is the list of
    `(clade, (offset, score))` candidates sorted by ascending score and `RAI` is the best
    score divided by the second-best score.
    '''
    all_nodes = tree.get_terminals() + tree.get_nonterminals()
    for clade in all_nodes:
        is_root = clade == tree.root
        if not is_root and clade.branch_length == 0.:
            clade.branch_length = 10e-16

    # Seed the per-node distance tables with the table for the current root
    root_dists = get_lca_dist_df(tree)
    dists_by_node = {tree.root: root_dists}
    explored, function_optima, lca_dist_df_dict = recursive_crawl_mad(
        tree.root, [], [], tree, dists_by_node)

    def candidate_score(candidate):
        # candidate is (clade, (offset, score))
        return candidate[1][1]

    function_optima = sorted(function_optima, key=candidate_score)
    best_clade, best_result = function_optima[0]
    best_offset, best_score = best_result[0], best_result[1]

    # Root with a zero-length outgroup branch, then move `best_offset` of length from the
    # first root child onto the second.
    tree.root_with_outgroup(best_clade, outgroup_branch_length=0.)
    tree.root.clades[0].branch_length -= best_offset
    tree.root.clades[1].branch_length += best_offset

    RAI = best_score / function_optima[1][1][1]
    return tree, RAI, function_optima


def get_lca_dist_df(tree):
    '''
    Build the terminal-by-terminal table used as the starting point for branch scoring.

    `recursive_clade` first accumulates, for every pair of terminals, the branch lengths
    of the clades it assigns to both of them. The diagonal of that matrix is then
    subtracted from every row (numpy broadcasting), i.e. entry [i, j] has the
    diagonal value of column j removed from it.

    The tree must be bifurcating (asserted). Rows and columns are labelled with the
    terminal names in `tree.get_terminals()` order.
    '''
    assert tree.is_bifurcating()
    n_terms = len(tree.get_terminals())
    shared = np.zeros((n_terms, n_terms))
    # Fill the matrix by walking the tree from the root
    shared, _visited = recursive_clade(shared, tree.root, finished=[])
    relative = shared - shared.diagonal()
    labels = [term.name for term in tree.get_terminals()]
    return pd.DataFrame(relative, index=labels, columns=labels)

def recursive_clade(vcv_matrix, initial_clade, finished=None):
    '''
    Recursively add each clade's branch length to the square sub-block of `vcv_matrix`
    that belongs to the terminals underneath that clade.

    Terminals are assigned consecutive row/column positions in the order they are
    reached (first child before second child); `finished` is the list of terminals
    placed so far, and its length gives the next free position.

    vcv_matrix    : square numpy array, modified in place and also returned
    initial_clade : clade to descend from; must have exactly two children or none
    finished      : list of terminals already processed. Leave as None on the top-level
                    call to start from a fresh list (a literal `[]` default would be
                    shared between calls and corrupt every call after the first).

    Returns `(vcv_matrix, finished)`. A clade with any other number of children only
    triggers a printed warning and is not descended into.
    '''
    if finished is None:
        finished = []
    n_children = len(initial_clade)
    if n_children == 2:
        # First child, then second child: the order fixes the matrix positions
        for child in (initial_clade[0], initial_clade[1]):
            child_terms = child.get_terminals()
            if not set(child_terms).issubset(set(finished)):
                start = len(finished)
                stop = start + len(child_terms)
                vcv_matrix[start:stop, start:stop] += child.branch_length
                vcv_matrix, finished = recursive_clade(vcv_matrix, child, finished)
    elif n_children == 0:
        finished.append(initial_clade)
    else:
        print("ERROR: APPEARS TO BE A NON-BINARY TREE. MATRIX GENERATION WILL PROBABLY FAIL")
    return vcv_matrix, finished

def recursive_crawl_mad(hypothetical_root, explored, function_optima, tree, lca_dist_df_dict):
    '''
    Depth-first walk that scores the branch above every descendant of `hypothetical_root`.

    For each child of a two-child node, the parent's distance table is shifted onto the
    child (`update_lca_dist_df_dict`), the branch is scored (`mad_from_df`), the
    `(child, result)` pair is appended to `function_optima`, and the walk continues
    below the child. Nodes with a child count other than 0 or 2 only print a warning.

    `explored`, `function_optima` and `lca_dist_df_dict` are extended in place and
    returned as a 3-tuple in that order.
    '''
    n_children = len(hypothetical_root.clades)
    if n_children == 2:
        # First child is handled (and fully descended into) before the second
        for child in tuple(hypothetical_root.clades):
            lca_dist_df_dict, my_terms, other_terms = update_lca_dist_df_dict(
                lca_dist_df_dict, child, hypothetical_root, tree)
            res = mad_from_df(child, my_terms, other_terms, lca_dist_df_dict[child])
            function_optima.append((child, res))
            explored, function_optima, lca_dist_df_dict = recursive_crawl_mad(
                child, explored, function_optima, tree, lca_dist_df_dict)
    elif n_children != 0:
        print('non binary tree...?')
    explored.append(hypothetical_root)
    return explored, function_optima, lca_dist_df_dict

def update_lca_dist_df_dict(lca_dist_df_dict, my_clade, parent, my_tree):
    '''
    Derive the distance table for `my_clade` from its parent's table and store it.

    A deep copy of `lca_dist_df_dict[parent]` is taken; `my_clade.branch_length` is
    subtracted from the [downstream, upstream] cells and added to the
    [upstream, downstream] cells, where "downstream" are the terminal names under
    `my_clade` and "upstream" are all other terminal names of `my_tree`.

    Returns `(lca_dist_df_dict, downstream_terms, upstream_terms)`. Both name lists are
    in `get_terminals()` order, so the result is reproducible from run to run (building
    the upstream list from a set difference made its order depend on string hashing).
    '''
    bl = my_clade.branch_length
    downstream_terms = [i.name for i in my_clade.get_terminals()]
    # Names already accounted for; grows as upstream names are collected so that a
    # repeated name is only listed once, as with the former set difference.
    seen = set(downstream_terms)
    upstream_terms = []
    for term in my_tree.get_terminals():
        if term.name not in seen:
            seen.add(term.name)
            upstream_terms.append(term.name)
    lca_dist_df = lca_dist_df_dict[parent].copy(deep=True)
    lca_dist_df.loc[downstream_terms,upstream_terms] -= bl
    lca_dist_df.loc[upstream_terms,downstream_terms] += bl
    lca_dist_df_dict[my_clade] = lca_dist_df
    return lca_dist_df_dict, downstream_terms, upstream_terms

def mad_from_df(my_clade, my_terms, other_terms, lca_dist_df, verbose=False):
    '''
    Score the branch above `my_clade` as a candidate root position.

    my_clade    : clade whose `branch_length` bounds how far the root may slide
    my_terms    : terminal names on `my_clade`'s side of the branch
    other_terms : terminal names on the other side
    lca_dist_df : square DataFrame indexed by terminal name on both axes; for a pair
                  (i, j) the entries [i, j] and [j, i] are used as the pair's two distances
    verbose     : when True, print the intermediate tables and arrays (debugging aid);
                  off by default so scoring every branch of a tree does not flood the output

    For each pair with distances a and b the deviation is |2a / (a + b) - 1|.
    Same-side pairs are taken from the upper triangle of each side's sub-table and use
    absolute values. Different-side pairs are first shifted by `modifier`, the offset
    along the branch given by the closed-form "rho" expression (see the MAD paper),
    clipped to [0, branch_length].

    Returns `(modifier, dev_score)` where `dev_score` is the root-mean-square of all
    pair deviations.

    NOTE(review): a same-side pair whose two distances are both 0 gives 0/0 (NaN) and
    makes the score NaN -- presumably why zero-length branches are bumped upstream.
    '''
    my_df = lca_dist_df.loc[my_terms, my_terms]
    other_df = lca_dist_df.loc[other_terms, other_terms]
    my_df_trans = my_df.T
    other_df_trans = other_df.T
    if verbose:
        print('###########')
        print(my_df_trans)
        print(other_df_trans)
    #Dealing with same side pairs
    my_upper = np.triu_indices(len(my_terms), k = 1)
    other_upper = np.triu_indices(len(other_terms), k = 1)
    ss_a_dists = np.abs(np.concatenate((my_df.values[my_upper],\
                                other_df.values[other_upper])))
    ss_b_dists = np.abs(np.concatenate((my_df_trans.values[my_upper],\
                                other_df_trans.values[other_upper])))
    ss_total_dists = ss_a_dists + ss_b_dists
    ss_devs = np.abs(((2*ss_a_dists)/ss_total_dists)-1)
    if verbose:
        print(ss_a_dists, ss_b_dists)
        print(ss_devs)
    #Dealing with different side pairs
    # 'C' order on one block and 'F' order on its mirror block line the two
    # flattened arrays up pair by pair
    ds_a_dists = lca_dist_df.loc[my_terms, other_terms].values.flatten(order='C')
    ds_b_dists = lca_dist_df.loc[other_terms, my_terms].values.flatten(order='F')
    ds_total_dists = ds_a_dists + ds_b_dists

    ###Using the analytical solution to "rho" parameter as outlined in the MAD paper
    total_bl = my_clade.branch_length
    if total_bl > 0.:
        rho = np.sum((ds_total_dists-(2*ds_a_dists))*ds_total_dists**-2)/(2*total_bl*np.sum(ds_total_dists**-2))
        modifier = total_bl*rho
        # The root cannot leave the branch
        modifier = min(max(0, modifier), total_bl)
    else:
        modifier = 0.

    ###Rescale the distances with the optimized modifier
    ds_a_dists = ds_a_dists + modifier
    ds_b_dists = ds_b_dists - modifier
    ds_total_dists = ds_a_dists + ds_b_dists
    ###Calculate their deviations
    ds_devs = np.abs(((2*ds_a_dists)/ds_total_dists)-1)

    ###Concatenate them with the pre-computed same side deviations (ss_devs)
    all_devs = np.concatenate((ss_devs, ds_devs))
    ###And compute final MAD score
    all_devs = all_devs**2
    dev_score = np.mean(all_devs)
    dev_score = dev_score**0.5
    return (modifier, dev_score)




In [None]:
# tree = Phylo.read('../../Phylogenetic_couplings/Data/psicov150_aln_pdb/raw_trees/1a3aA.newick', 'newick')
# tree = Phylo.read('/Users/adamhockenberry/Downloads/BM_Folder/paper_tree.txt', 'newick')
tree = Phylo.read(StringIO('(((A:20, B:20):30,C:50):30, D:80)'), 'newick', rooted=False)
# tree = Phylo.read('../../Tree_rooting/Tria_et_al_data/eukaryotes/ingroup/phyml/KOG0007.faa.aln.nwk', 'newick')
# Phylo.draw(tree)
# tree.root_with_outgroup('A', outgroup_branch_length=10e-8)
# C = next(tree.find_elements('C'))
# C.branch_length = 40
# tree.root_with_outgroup('MAL')
# noi = next(tree.find_elements('PV22'))
# noi.branch_length += 20
# Phylo.draw(tree)

In [None]:
tree = rooting_methods.mp_root_adhock(tree)
Phylo.draw(tree)
print([(i.name, i.branch_length) for i in tree.root.clades])
tree, a, b = mad_root_weighted(tree)
# Phylo.draw(tree)
# print([(i.name, i.branch_length) for i in tree.root.clades])

In [None]:
import numpy

In [None]:
tree.depths()

In [None]:
# Print the path length between every pair of terminals (each unordered pair once).
terminals = tree.get_terminals()  # collected once instead of once per outer iteration
for i, terminal_a in enumerate(terminals):
    # only terminals listed before `terminal_a`, so no pair is visited twice
    for terminal_b in terminals[:i]:
        path = [terminal_a] + tree.trace(terminal_a, terminal_b)
        ca = tree.common_ancestor(terminal_a, terminal_b)
        # Sum the branch lengths of the nodes on the path, then remove the common
        # ancestor's own branch length when it has one.
        path_len = np.sum([edge.branch_length for edge in path if edge.branch_length])
        if ca.branch_length:
            path_len = path_len - ca.branch_length
        print(terminal_a, terminal_b, path_len)

In [None]:
testy = get_lca_dist_df(tree)

In [None]:
testy

In [None]:
all_paths = []

In [None]:
df = pd.DataFrame(index=['AB', 'AC', 'AD', 'BC', 'BD', 'CD'],\
                  columns=['AB', 'AC', 'AD', 'BC', 'BD', 'CD'], dtype=float)

In [None]:
# Fill the symmetric pair-by-pair table: 1 on the diagonal, the listed value for each pair.
# `.at` replaces `DataFrame.set_value`, which was deprecated and then removed from pandas.
for i in df.index:
    df.at[i, i] = 1.
pairs = [['AB', 'AC', 0.1],\
         ['AB', 'AD', 0.0625],\
         ['AB', 'BC', 0.1],\
         ['AB', 'BD', 0.0625],\
         ['AB', 'CD', 0.],\
         ['AC', 'AD', 0.15625],\
         ['AC', 'BC', 0.64],\
         ['AC', 'BD', 0.05625],\
         ['AC', 'CD', 0.15625],\
         ['AD', 'BC', 0.05625],\
         ['AD', 'BD', 0.765625],\
         ['AD', 'CD', 0.47265625],\
         ['BC', 'BD', 0.15625],\
         ['BC', 'CD', 0.15625],\
         ['BD', 'CD', 0.47265625]]
for i,j,k in pairs:
    # write both triangles so the table stays symmetric
    df.at[i, j] = k
    df.at[j, i] = k

In [None]:
mat = df.values
inv_mat = np.linalg.inv(mat)
rowsums = np.sum(inv_mat, axis=1)
weights = rowsums/inv_mat.sum()
weights


In [None]:
fig, ax = plt.subplots()
ax.matshow(inv_mat)

In [None]:
mat

In [None]:
(np.array([40, 100, 160, 50, 160, 160])/230)

# Toy example

In [None]:
df = pd.DataFrame(index=['AB', 'AC', 'BC'],\
                  columns=['AB', 'AC', 'BC'], dtype=float)

In [None]:
# Toy three-pair table: 1 on the diagonal, the listed value for each pair.
# `.at` replaces `DataFrame.set_value`, which was deprecated and then removed from pandas.
for i in df.index:
    df.at[i, i] = 1.
pairs = [['AB', 'AC', (20**2)/(40*100)],\
         ['AB', 'BC', (20**2)/(40*100)],\
         ['AC', 'BC', (80**2)/(100*100)]]
for i,j,k in pairs:
    # write both triangles so the table stays symmetric
    df.at[i, j] = k
    df.at[j, i] = k

# Weights = row sums of the inverse matrix, normalised by the sum of all its entries
mat = df.values
inv_mat = np.linalg.inv(mat)
rowsums = np.sum(inv_mat, axis=1)
weights = rowsums/inv_mat.sum()
weights

In [None]:
# Same toy table with different pair values: 1 on the diagonal, the listed value per pair.
# `.at` replaces `DataFrame.set_value`, which was deprecated and then removed from pandas.
for i in df.index:
    df.at[i, i] = 1.
pairs = [['AB', 'AC', (5**2)/(10*100)],\
         ['AB', 'BC', (5**2)/(10*100)],\
         ['AC', 'BC', (95**2)/(100*100)]]
for i,j,k in pairs:
    # write both triangles so the table stays symmetric
    df.at[i, j] = k
    df.at[j, i] = k

# Weights = row sums of the inverse matrix, normalised by the sum of all its entries
mat = df.values
inv_mat = np.linalg.inv(mat)
rowsums = np.sum(inv_mat, axis=1)
weights = rowsums/inv_mat.sum()
weights

In [None]:
20**2 / 4.

In [None]:
1/((1/50.) + (1/((50*130)/180)))

In [None]:
1/((1/130.)+(1/20)+(1/((50*130)/180)))

In [None]:
20.967741935483872 / (20.967741935483872+11.711711711711711+11.711711711711711)

In [None]:
112.5/2

In [None]:
mat = [[50, 30, 0],\
       [30, 50, 0],\
       [0, 0, 50]]
inv_mat = np.linalg.inv(mat)
rowsums = np.sum(inv_mat, axis=1)
weights = rowsums/inv_mat.sum()
weights, np.sum(weights)

In [None]:
(36.111111+50)/(36.111111+130+40)

In [None]:
50/102.6315

In [None]:
mat = [[50, 45, 0],\
       [45, 50, 0],\
       [0, 0, 50]]
inv_mat = np.linalg.inv(mat)
rowsums = np.sum(inv_mat, axis=1)
weights = rowsums/inv_mat.sum()
weights

In [None]:
mat = [[80, 60, 30, 0],\
       [60, 80, 30, 0],\
       [30, 30, 80, 0],\
       [0, 0, 0, 80]]
inv_mat = np.linalg.inv(mat)
rowsums = np.sum(inv_mat, axis=1)
weights = rowsums/inv_mat.sum()
weights

In [None]:
20/(110/3)

In [None]:
40/(200/3)