In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import glob
from Bio import Phylo

import numpy as np
from matplotlib import pyplot as plt

from scipy import stats

# Testing first on the basic eukaryotic set

In [3]:
def is_monophyletic_all(hypothetical_root, tree, test_set, is_mono=False):
    """Check whether `test_set` is monophyletic under at least one rooting of `tree`.

    Tests the current rooting first. Then, depth-first from `hypothetical_root`,
    it re-roots `tree` on every descendant branch with a positive length and
    re-tests `tree.is_monophyletic(test_set)` after each re-rooting.

    Side effect: `tree` is re-rooted IN PLACE and is left on whichever rooting
    was tried last. Re-read the tree if its original rooting matters.

    Parameters
    ----------
    hypothetical_root : Clade
        Clade whose descendant branches are tried as roots (pass ``tree.root``).
    tree : Tree
        Tree that owns `hypothetical_root`; mutated by ``root_with_outgroup``.
    test_set : list of Clade
        Terminals of `tree` whose monophyly is tested.
    is_mono : bool, optional
        Recursion accumulator; if True the result is True. Default False.

    Returns
    -------
    bool
        True if the current rooting or any tried rooting makes `test_set` a clade.
    """
    # The answer can only ever flip to True. Once it is True, every further
    # re-rooting is wasted work, so stop immediately.
    if is_mono or tree.is_monophyletic(test_set):
        return True
    # Snapshot the children before recursing, because re-rooting rewires the
    # tree. Looping over all children (not only the 1- and 2-child cases) also
    # explores multifurcating nodes, which were previously skipped.
    for child in list(hypothetical_root.clades):
        # Skip zero-length branches, and branches with no length at all
        # (branch_length is None), instead of failing on `None > 0`.
        if (child.branch_length or 0) > 0:
            # Note: 10e-10 == 1e-9 is the length given to the outgroup branch.
            tree.root_with_outgroup(child, outgroup_branch_length=10e-10)
            if is_monophyletic_all(child, tree, test_set, False):
                return True
    return False

# tree.get_terminals()
# tree.is_monophyletic(metazoa)
# tree = Phylo.read('../test.ete3.newick', 'newick', rooted=False)
# tree = mp_root_adhock(tree)
# tree.is_monophyletic([term for term in tree.get_terminals() if\
#                       term.name in ['7165', '7425', '7460', '121225', '7227', '6239']])

In [None]:
# Rooting accuracy per method. Only gene trees in which the metazoa are
# monophyletic under at least one rooting are used. For each method, a tree
# scores 1 if the metazoan leaves form exactly one of the root's child clades,
# else 0. Scores are stored per method in `accuracy_dict`, aligned with
# `trees_tested`.
# NOTE(review): `metazoa` (metazoan leaf names) is not assigned in any cell shown
# here, and `ideal_species_n` is only assigned in a later cell. This cell
# therefore depends on hidden kernel state. TODO define both in a config cell.
trees_dir = '../Data/raw_OMA_trees/*Rooted.MPAJH'
methods = ['.MPAJH', '.MLAJH', '.MADAJH']
accuracy_dict = {method: [] for method in methods}

trees_tested = []

# glob order is arbitrary; sort so that the first-500 subset is reproducible.
for tree_loc in sorted(glob.glob(trees_dir))[:500]:
    print(tree_loc)
    test_tree = Phylo.read(tree_loc, 'newick', rooted=True)
    if len(test_tree.get_terminals()) != ideal_species_n:
        continue
    testy = [term for term in test_tree.get_terminals() if term.name in metazoa]
    # If no rooting makes the metazoa a clade, no method can succeed, so skip.
    # (is_monophyletic_all re-roots test_tree in place; test_tree is not reused.)
    if not is_monophyletic_all(test_tree.root, test_tree, testy, False):
        continue

    trees_tested.append(tree_loc)
    for method in methods:
        my_tree = Phylo.read(tree_loc.replace('.MPAJH', method), 'newick', rooted=True)
        metazoa_set = {term for term in my_tree.get_terminals() if term.name in metazoa}
        # Using any() over the root's children also handles roots with one
        # child or more than two, instead of indexing clades[0] / clades[1].
        correct = any(metazoa_set == set(child.get_terminals())
                      for child in my_tree.root.clades)
        accuracy_dict[method].append(int(correct))
    

In [None]:
# Number of trees each method rooted correctly, in the order (MP, ML, MAD).
tuple(np.sum(accuracy_dict[suffix]) for suffix in ('.MPAJH', '.MLAJH', '.MADAJH'))

In [None]:
# Number of trees that passed the species-count and metazoan-monophyly filters
# in the most recently run analysis cell (each analysis cell resets trees_tested).
len(trees_tested)

In [None]:
# Fraction of tested trees each method rooted correctly, in the order (MP, ML, MAD).
np.array([np.sum(accuracy_dict[suffix]) for suffix in ('.MPAJH', '.MLAJH', '.MADAJH')]) / len(trees_tested)

In [None]:
# Root-split balance for trees where exactly one of MAD / MP rooted correctly:
#   balance = (branch length of the smaller root subtree, excluding its root branch)
#             / (total tree branch length excluding the root branches)
# `success`: MAD correct and MP wrong (measured on the MAD tree).
# `failure`: MP correct and MAD wrong (measured on the MP tree).
# This needs `accuracy_dict` to hold the 0/1 scores from the accuracy cell. The
# later distance cells overwrite it, hence the consistency check below.
assert len(accuracy_dict['.MADAJH']) == len(accuracy_dict['.MPAJH']) == len(trees_tested), \
    'accuracy_dict / trees_tested are out of sync - re-run the accuracy cell first'


def _root_split_balance(rooted_tree):
    """Return the smaller root subtree's length over the tree length without root branches.

    Each root child's subtree length excludes that child's own branch to the root.
    The denominator is the whole tree's branch length minus all root branches.
    For a bifurcating root this compares the two subtrees on either side of the root.
    """
    root_children = rooted_tree.root.clades
    root_branches = sum(child.branch_length for child in root_children)
    smallest = min(child.total_branch_length() - child.branch_length
                   for child in root_children)
    return smallest / (rooted_tree.total_branch_length() - root_branches)


success = []
failure = []
for tree_loc, mad_ok, mp_ok in zip(trees_tested, accuracy_dict['.MADAJH'], accuracy_dict['.MPAJH']):
    if mad_ok == 1 and mp_ok == 0:
        mad_tree = Phylo.read(tree_loc.replace('.MPAJH', '.MADAJH'), 'newick', rooted=True)
        success.append(_root_split_balance(mad_tree))
    elif mad_ok == 0 and mp_ok == 1:
        # trees_tested already stores the MP ('.MPAJH') paths, so read them directly.
        mp_tree = Phylo.read(tree_loc, 'newick', rooted=True)
        failure.append(_root_split_balance(mp_tree))
        


In [None]:
# Mean root-split balance for trees where MAD rooted correctly and MP did not.
np.mean(success)

In [None]:
# Mean root-split balance for trees where MP rooted correctly and MAD did not.
np.mean(failure)

In [None]:
# Overlaid distributions of the root-split balance ratio for the two groups.
# Shared bin edges are used because two separate auto-binned histograms
# would not be comparable.
shared_bins = np.histogram_bin_edges(np.concatenate([success, failure]), bins=10)
fig, ax = plt.subplots()
ax.hist(success, bins=shared_bins, alpha=0.5, label='MAD correct, MP wrong')
ax.hist(failure, bins=shared_bins, alpha=0.5, label='MP correct, MAD wrong')
ax.set(xlabel='smaller root subtree length / (total length - root branches)',
       ylabel='number of trees')
ax.legend();


In [None]:
# Wilcoxon rank-sum test: do the balance ratios differ between the two groups?
stats.ranksums(failure, success)

In [None]:
# Welch's t-test (equal_var=False: variances not assumed equal) on the same two groups.
stats.ttest_ind(failure, success, equal_var=False)

# Test variability/robustness in distance

In [None]:
# Robustness of the rooting. For every tree a method roots with the metazoa as
# one of the root's child clades, record the distance from the root to the
# common ancestor of all non-metazoan leaves.
# NOTE: `accuracy_dict` is reused here to hold these distances and
# `trees_tested` is reset, overwriting the accuracy results above. Re-run the
# accuracy cell before re-running the success/failure analysis.
# NOTE(review): `metazoa` (metazoan leaf names) is not assigned in any cell shown
# here. TODO define it in a config cell.
trees_dir = '../Data/euk_trees/*Rooted.MPAJH'
ideal_species_n = 31

# trees_dir = '../Data/pruned_euk_trees/*_9_meta.nwk.Rooted.MPAJH'
# ideal_species_n = 22

# trees_dir = '../Data/pruned_euk_trees/*_12_meta.nwk.Rooted.MPAJH'
# ideal_species_n = 19


# methods = ['.MPAJH', '.MLAJH']
methods = ['.MPAJH', '.MLAJH', '.MADAJH']
accuracy_dict = {method: [] for method in methods}

trees_tested = []
# glob order is arbitrary; sort so that runs process trees in the same order.
for tree_loc in sorted(glob.glob(trees_dir)):
    print(tree_loc)
    test_tree = Phylo.read(tree_loc, 'newick', rooted=True)
    if len(test_tree.get_terminals()) != ideal_species_n:
        continue
    testy = [term for term in test_tree.get_terminals() if term.name in metazoa]
    # Keep only trees where some rooting makes the metazoa a clade.
    if not is_monophyletic_all(test_tree.root, test_tree, testy, False):
        continue

    trees_tested.append(tree_loc)
    for method in methods:
        my_tree = Phylo.read(tree_loc.replace('.MPAJH', method), 'newick', rooted=True)
        leaves = my_tree.get_terminals()
        metazoa_set = {term for term in leaves if term.name in metazoa}
        # The metazoa must be exactly one of the root's child clades. Using any()
        # also handles roots with one child or more than two.
        if any(metazoa_set == set(child.get_terminals()) for child in my_tree.root.clades):
            non_metazoa_clades = [term for term in leaves if term.name not in metazoa]
            all_ca = my_tree.common_ancestor(non_metazoa_clades)
            accuracy_dict[method].append(my_tree.distance(all_ca, my_tree.root))

In [None]:
# Distribution of the root-to-non-metazoan-ancestor distance per method, with
# shared bin edges so that the methods are directly comparable.
plotted_methods = ['.MPAJH', '.MADAJH']  # add '.MLAJH' to include ML
shared_bins = np.histogram_bin_edges(
    np.concatenate([accuracy_dict[m] for m in plotted_methods]), bins=10)
fig, ax = plt.subplots()
for m in plotted_methods:
    ax.hist(accuracy_dict[m], bins=shared_bins, alpha=0.2, label=m.lstrip('.'))
ax.set(xlabel='distance from root to non-metazoan common ancestor',
       ylabel='number of trees')
ax.legend();

In [None]:
# Robustness of the root position to pruning. For trees that a method roots with
# the metazoa as one of the root's child clades, in both the full tree and its
# pruned counterpart, record the root-to-non-metazoan-ancestor distance in the
# full tree minus the same distance in the pruned tree.
# NOTE: `accuracy_dict` is reused to hold these differences and `trees_tested`
# is reset, overwriting earlier results.
# NOTE(review): `metazoa` (metazoan leaf names) is not assigned in any cell shown
# here. TODO define it in a config cell.
trees_dir = '../Data/euk_trees/*Rooted.MPAJH'
ideal_species_n = 31

# trees_dir = '../Data/pruned_euk_trees/*_9_meta.nwk.Rooted.MPAJH'
# ideal_species_n = 22

# trees_dir = '../Data/pruned_euk_trees/*_12_meta.nwk.Rooted.MPAJH'
# ideal_species_n = 19


# methods = ['.MPAJH', '.MLAJH']
methods = ['.MPAJH', '.MLAJH', '.MADAJH']
# The pruned counterpart of '.../euk_trees/<name>.nwk.Rooted.MPAJH' is
# '.../pruned_euk_trees/<name>' + pruned_tag + '.nwk.Rooted.MPAJH'.
pruned_tag = '.pruned_9_meta'
accuracy_dict = {method: [] for method in methods}


def _metazoan_root_distance(rooted_tree):
    """Distance from the root to the common ancestor of the non-metazoan leaves.

    Returns None when the metazoan leaves are not exactly one of the root's
    child clades (i.e. the tree is not rooted on the metazoan split).
    """
    leaves = rooted_tree.get_terminals()
    metazoa_set = {term for term in leaves if term.name in metazoa}
    if not any(metazoa_set == set(child.get_terminals()) for child in rooted_tree.root.clades):
        return None
    non_metazoa = [term for term in leaves if term.name not in metazoa]
    return rooted_tree.distance(rooted_tree.common_ancestor(non_metazoa), rooted_tree.root)


trees_tested = []
# glob order is arbitrary; sort so that the first-50 subset is reproducible.
for tree_loc in sorted(glob.glob(trees_dir))[:50]:
    print(tree_loc)
    test_tree = Phylo.read(tree_loc, 'newick', rooted=True)
    if len(test_tree.get_terminals()) != ideal_species_n:
        continue
    testy = [term for term in test_tree.get_terminals() if term.name in metazoa]
    # Keep only trees where some rooting makes the metazoa a clade.
    if not is_monophyletic_all(test_tree.root, test_tree, testy, False):
        continue

    trees_tested.append(tree_loc)
    pruned_base = (tree_loc.replace('/euk_trees/', '/pruned_euk_trees/')
                   .replace('.nwk.Rooted.MPAJH', pruned_tag + '.nwk.Rooted.MPAJH'))
    for method in methods:
        my_tree = Phylo.read(tree_loc.replace('.MPAJH', method), 'newick', rooted=True)
        initial_dist = _metazoan_root_distance(my_tree)
        if initial_dist is None:
            continue
        pruned_tree = Phylo.read(pruned_base.replace('.MPAJH', method), 'newick', rooted=True)
        pruned_dist = _metazoan_root_distance(pruned_tree)
        if pruned_dist is not None:
            accuracy_dict[method].append(initial_dist - pruned_dist)

In [None]:
# Change in the root-to-non-metazoan-ancestor distance after pruning (full
# minus pruned tree) per method, with shared bin edges for comparability.
plotted_methods = ['.MLAJH', '.MADAJH']  # add '.MPAJH' to include MP
shared_bins = np.histogram_bin_edges(
    np.concatenate([accuracy_dict[m] for m in plotted_methods]), bins=10)
fig, ax = plt.subplots()
for m in plotted_methods:
    ax.hist(accuracy_dict[m], bins=shared_bins, alpha=0.2, label=m.lstrip('.'))
ax.set(xlabel='root-to-non-metazoan-ancestor distance: full minus pruned tree',
       ylabel='number of trees')
ax.legend();