In [None]:
#numerics and vis
import math
import numpy as np
import pandas as pd
from scipy import stats
import altair as alt
import altair_saver
import sys
import subprocess

import ete3
from ete3 import Tree, TreeStyle, TextFace


root = "/data/tobiassonva/data/eukgen/"
sys.path.insert(0, root)
%cd {root}

from core_functions.microcosm_functions import color_tree, tree_analysis
from core_functions.altair_plots import plot_alignment, plot_cumsum_counts
from core_functions.tree_functions import get_outlier_nodes_by_lognorm

#disable altair max rows
alt.data_transformers.disable_max_rows()
#load the default altair style
%run ~/scripts/altair_style_config_default.py

#dont wrap text output from cells
from IPython.core.display import display, HTML
display(HTML("<style>div.output_area pre {white-space: pre;}</style>"))






In [None]:
#load phylogeny information

import pickle 
#small helper for pkl parsing
def load_pkl(pkl_file):
    """Unpickle and return the object stored in ``pkl_file``.

    ``parseHHsuite`` is imported before loading, exactly as the helper has
    always done; presumably so classes referenced by the pickles can be
    resolved -- TODO confirm.

    NOTE: ``pickle.load`` can execute arbitrary code, only use on trusted files.
    """
    import parseHHsuite as HH
    with open(pkl_file, 'rb') as handle:
        return pickle.load(handle)

#lighter parsed version
# paths are relative; the first cell changes the working directory to `root`
prok_tax = load_pkl('analysis/core_data/prok2111_protein_taxonomy_trimmed.pkl')
print('Loading taxonomy info')
euk_tax = load_pkl('euk72/euk72_protein_taxonomy.pkl')
# remove the 'orgid' and 'species' columns from euk_tax in place before merging
# NOTE(review): presumably so the columns match prok_tax -- confirm
euk_tax.drop(['orgid', 'species'], axis=1, inplace=True)
# single table holding both the eukaryotic and prokaryotic rows
tax_merge = pd.concat([euk_tax, prok_tax])

# NOTE(review): euk_header is not used in the visible cells
euk_header = load_pkl('analysis/core_data/euk72_header_mapping.pkl')

In [None]:
#tools and functions from TreeCluster.py
#https://github.com/niemasd/TreeCluster

from niemads import DisjointSet
from queue import PriorityQueue,Queue
from treeswift import read_tree_newick

# initialize properties of input tree and return set containing taxa of leaves
def prep(tree, support, resolve_polytomies=True, suppress_unifurcations=True):
    """Initialise per-node clustering attributes and return the leaf labels.

    The tree is modified in place:
      * polytomies are resolved / unifurcations suppressed when requested,
      * missing edge lengths (``None``) become 0,
      * every node gets ``DELETED = False``,
      * every internal node gets ``confidence`` parsed from ``str(node)``;
        labels that are not numeric are treated as support 100,
      * internal nodes whose confidence is below ``support`` get an infinite
        edge length so that no cluster can span them.

    Parameters
    ----------
    tree : tree object providing ``resolve_polytomies``,
        ``suppress_unifurcations`` and ``traverse_postorder``.
    support : number, minimum support an internal edge needs to be usable.
    resolve_polytomies, suppress_unifurcations : bool, optional preprocessing.

    Returns
    -------
    set of ``str(node)`` for every leaf.
    """
    if resolve_polytomies:
        tree.resolve_polytomies()
    if suppress_unifurcations:
        tree.suppress_unifurcations()
    leaves = set()
    for node in tree.traverse_postorder():
        if node.edge_length is None:
            node.edge_length = 0
        node.DELETED = False
        if node.is_leaf():
            leaves.add(str(node))
            continue
        try:
            node.confidence = float(str(node))
        # only a failed number parse means "no support value"; anything else
        # (e.g. KeyboardInterrupt) must propagate instead of being swallowed
        except (TypeError, ValueError):
            node.confidence = 100. # give edges without support values support 100
        if node.confidence < support: # don't allow low-support edges
            node.edge_length = float('inf')
    return leaves

# cut out the current node's subtree (by setting all nodes' DELETED to True) and return list of leaves
def cut(node):
    """Mark the not-yet-deleted part of ``node``'s subtree as deleted.

    Nodes are visited breadth-first. Every visited node gets ``DELETED = True``
    and its ``left_dist``, ``right_dist`` and ``edge_length`` reset to 0.
    Nodes that are already deleted are skipped together with their subtree.

    Returns the list of ``str(leaf)`` for the leaves deleted by this call.
    """
    members = []
    pending = Queue()
    pending.put(node)
    while not pending.empty():
        current = pending.get()
        if not current.DELETED:
            current.DELETED = True
            current.left_dist = current.right_dist = current.edge_length = 0
            if current.is_leaf():
                members.append(str(current))
            else:
                for child in current.children:
                    pending.put(child)
    return members



# split leaves into minimum number of clusters such that the maximum leaf pairwise distance is below some threshold
def min_clusters_threshold_max(tree,threshold,support):
    """Cluster leaves so that the max pairwise leaf distance per cluster stays within ``threshold``.

    The tree is prepared with ``prep`` (default options, so polytomies are
    resolved) and then walked bottom-up. Each internal node tracks the largest
    distance to a non-deleted leaf through its first (``left_dist``) and second
    (``right_dist``) child. When their sum exceeds ``threshold`` the child on
    the longer side is removed with ``cut`` and becomes a cluster.

    The tree is modified in place (``DELETED``, ``left_dist``, ``right_dist``,
    ``edge_length`` and ``confidence`` attributes).

    Returns a list of clusters, each a list of leaf labels. Leaves never cut
    out are returned together as one final cluster.
    """
    leaves = prep(tree,support)
    clusters = list()
    for node in tree.traverse_postorder():
        # if I've already been handled, ignore me
        if node.DELETED:
            continue

        # find my undeleted max distances to leaf
        if node.is_leaf():
            node.left_dist = 0; node.right_dist = 0
        else:
            # only children[0] and children[1] are read below
            children = list(node.children)
            if children[0].DELETED and children[1].DELETED:
                # nothing left below this node: mark it deleted as well
                cut(node); continue
            if children[0].DELETED:
                node.left_dist = 0
            else:
                node.left_dist = max(children[0].left_dist,children[0].right_dist) + children[0].edge_length
            if children[1].DELETED:
                node.right_dist = 0
            else:
                node.right_dist = max(children[1].left_dist,children[1].right_dist) + children[1].edge_length

            # if my kids are screwing things up, cut out the longer one
            if node.left_dist + node.right_dist > threshold:
                if node.left_dist > node.right_dist:
                    cluster = cut(children[0])
                    node.left_dist = 0
                else:
                    cluster = cut(children[1])
                    node.right_dist = 0

                # add cluster
                if len(cluster) != 0:
                    clusters.append(cluster)
                    for leaf in cluster:
                        leaves.remove(leaf)

    # add all remaining leaves to a single cluster
    if len(leaves) != 0:
        clusters.append(list(leaves))
    return clusters

# min_clusters_threshold_max, but all clusters must define a clade
def min_clusters_threshold_max_clade(tree,threshold,support):
    """Split the tree into clades whose max pairwise leaf distance is within ``threshold``.

    ``prep`` is called with ``resolve_polytomies=False``, so nodes may have
    more than two children. A bottom-up pass stores on every node:
      * ``leaf_dist``: the largest distance from the node to a leaf below it,
      * ``max_pair_dist``: the largest leaf-to-leaf distance inside its subtree.
    The tree is then walked breadth-first from ``tree.root``. A node whose
    ``max_pair_dist`` is within ``threshold`` becomes a cluster root, otherwise
    its children are examined.

    Returns a list of clusters, each a list of leaf labels (``str(leaf)``).
    """
    # the returned leaf set is not used here; prep is called for its in-place effects
    leaves = prep(tree, support, resolve_polytomies=False)

    # compute leaf distances and max pairwise distances
    for node in tree.traverse_postorder():
        if node.is_leaf():
            node.leaf_dist = 0; node.max_pair_dist = 0
        else:
            # track the two largest child-to-leaf path lengths through this node
            node.leaf_dist = float('-inf'); second_max_leaf_dist = float('-inf')
            for c in node.children: # at least 2 children because of suppressing unifurcations
                curr_dist = c.leaf_dist + c.edge_length
                if curr_dist > node.leaf_dist:
                    second_max_leaf_dist = node.leaf_dist; node.leaf_dist = curr_dist
                elif curr_dist > second_max_leaf_dist:
                    second_max_leaf_dist = curr_dist
            # either the widest pair lies inside one child, or it passes through this node
            node.max_pair_dist = max([c.max_pair_dist for c in node.children] + [node.leaf_dist + second_max_leaf_dist])

    # perform clustering
    q = Queue(); q.put(tree.root); roots = list()
    while not q.empty():
        node = q.get()
        if node.max_pair_dist <= threshold:
            roots.append(node)
        else:
            for c in node.children:
                q.put(c)

    return [[str(l) for l in root.traverse_leaves()] for root in roots]


# average leaf pairwise distance cannot exceed threshold, and clusters must define clades
def min_clusters_threshold_avg_clade(tree,threshold,support):
    """Split the tree into clades whose stored ``avg_pair_dist`` is within ``threshold``.

    The tree is prepared with ``prep`` (default options). A bottom-up pass
    stores ``num_leaves``, ``total_leaf_dist``, ``total_pair_dist`` and
    ``avg_pair_dist`` on every node. Only ``children[0]`` and ``children[1]``
    are used for the distance totals. The tree is then walked breadth-first
    from ``tree.root``, and a node whose ``avg_pair_dist`` is within
    ``threshold`` becomes a cluster root.

    Returns a list of clusters, each a list of leaf labels (``str(leaf)``).
    """
    # the returned leaf set is not used here; prep is called for its in-place effects
    leaves = prep(tree,support)

    # bottom-up traversal to compute average pairwise distances
    for node in tree.traverse_postorder():
        node.total_pair_dist = 0; node.total_leaf_dist = 0
        if node.is_leaf():
            node.num_leaves = 1
            node.avg_pair_dist = 0
        else:
            children = list(node.children)
            node.num_leaves = sum(c.num_leaves for c in children)
            # NOTE(review): the cross term below does not include the two children's
            # edge lengths -- confirm this recurrence against upstream TreeCluster
            node.total_pair_dist = children[0].total_pair_dist + children[1].total_pair_dist + (children[0].total_leaf_dist*children[1].num_leaves + children[1].total_leaf_dist*children[0].num_leaves)
            node.total_leaf_dist = (children[0].total_leaf_dist + children[0].edge_length*children[0].num_leaves) + (children[1].total_leaf_dist + children[1].edge_length*children[1].num_leaves)
            # divide by the number of leaf pairs, n*(n-1)/2
            node.avg_pair_dist = node.total_pair_dist/((node.num_leaves*(node.num_leaves-1))/2)

    # perform clustering
    q = Queue(); q.put(tree.root); roots = list()
    while not q.empty():
        node = q.get()
        if node.avg_pair_dist <= threshold:
            roots.append(node)
        else:
            for c in node.children:
                q.put(c)

    return [[str(l) for l in root.traverse_leaves()] for root in roots]

# cut tree at threshold distance from root (clusters will be clades by definition) (ignores support threshold if branch is below cutting point)
def root_dist(tree,threshold,support):
    """Cut the tree at ``threshold`` distance from the root.

    After ``prep``, nodes are visited in preorder and get a ``root_dist``
    attribute (0 for the root, parent's value plus own edge length otherwise).
    The first node on a path whose ``root_dist`` exceeds ``threshold`` is
    removed with ``cut`` and its leaves form one cluster.

    Returns a list of clusters (lists of leaf labels). Leaves never cut out
    are returned together as one final cluster.
    """
    remaining = prep(tree,support)
    found = []
    for node in tree.traverse_preorder():
        # subtrees removed by an earlier cut are skipped
        if node.DELETED:
            continue
        node.root_dist = 0 if node.is_root() else node.parent.root_dist + node.edge_length
        if not (node.root_dist > threshold):
            continue
        members = cut(node)
        if members:
            found.append(members)
            for label in members:
                remaining.remove(label)

    # whatever was never cut out ends up in one shared cluster
    if remaining:
        found.append(list(remaining))
    return found






In [None]:
#input helper to read tsv from file, merge singletons
def read_cluster_tsv(cluster_file, split_large=False, max_size=500, batch_single=False, single_cutoff=1):
    """Read a two-column cluster TSV into ``{cluster_key: [member, ...]}``.

    Each non-empty line must be ``cluster_accession<TAB>member_accession``.
    Blank lines (for example a trailing newline at the end of the file) are
    skipped. Any other line without exactly two tab-separated fields raises
    ``ValueError``.

    Parameters
    ----------
    cluster_file : path of the TSV file.
    split_large : if True, clusters with more than ``max_size`` members are
        split into consecutive batches of at most ``max_size``; each batch is
        keyed by its first member.
    max_size : batch size used when ``split_large`` is True.
    batch_single : if True, all clusters with at most ``single_cutoff`` members
        are merged into one cluster keyed by the first merged member.
    single_cutoff : size limit used when ``batch_single`` is True.

    Returns
    -------
    dict mapping cluster key to list of member accessions. Merging of small
    clusters is applied before splitting of large ones.
    """
    #read TSV and group clusters based on first tsv column
    clusters = {}
    with open(cluster_file, 'r') as infile:
        for line in infile:
            line = line.strip()
            #tolerate empty lines instead of failing on the unpack below
            if not line:
                continue
            cluster_acc, acc = line.split('\t')
            clusters.setdefault(cluster_acc, []).append(acc)

    #merge all clusters smaller than cutoff into one
    if batch_single:
        filter_dict = {}
        singles = []
        for key, accs in clusters.items():
            #gather singletons
            if len(accs) <= single_cutoff:
                singles.extend(accs)
            #keep larger clusters
            else:
                filter_dict[key] = accs

        if singles:
            filter_dict[singles[0]] = singles

        clusters = filter_dict

    #split clusters larger than x into smaller pieces
    if split_large:
        filter_dict = {}
        for key, accs in clusters.items():
            if len(accs) > max_size:
                #partition large clusters into batches of max_size
                for split in range(0, len(accs), max_size):
                    batch = accs[split:split + max_size]
                    filter_dict[batch[0]] = batch
            #keep smaller clusters
            else:
                filter_dict[key] = accs

        clusters = filter_dict

    return clusters


#quick function for adding phylogenetic annotation to tree labels
def dirty_phyla_add(tree, tax_mapping):
    """Annotate every leaf of ``tree`` with taxonomy features, in place.

    For each leaf the row ``tax_mapping.loc[leaf.name]`` is looked up and three
    features are added:
      * ``tax_superkingdom``: the row's ``'superkingdom'`` value,
      * ``tax_class``: the row's ``'class'`` value,
      * ``tax_filter``: ``'Eukaryota'`` for eukaryotic leaves, otherwise the class.

    Leaves whose accession (or a required column) is missing get the value
    ``'ERROR'`` for all three features, so that code reading
    ``leaf.tax_superkingdom`` / ``leaf.tax_filter`` does not fail on them.
    (Previously the fallback wrote the misspelled ``superkingom`` and ``class``
    features instead.)

    Returns None.
    """
    for leaf in tree.get_leaves():
        try:
            entry = tax_mapping.loc[leaf.name]
            superkingdom = entry['superkingdom']
            tax_class = entry['class']
        except KeyError:
            superkingdom = tax_class = tax_filter = 'ERROR'
        else:
            if superkingdom == 'Eukaryota':
                tax_filter = 'Eukaryota'
            else:
                tax_filter = tax_class

        leaf.add_feature('tax_superkingdom', superkingdom)
        leaf.add_feature('tax_class', tax_class)
        leaf.add_feature('tax_filter', tax_filter)

#write pdf tree under /data/tobiassonva/data/eukgen/analysis/tmp_trees
#overwrites the previous tree in the directory unless a different name is specified
#the IFrame path is always relative to the notebook starting directory due to the "Tornado" viewer restrictions
def view_tree(tree, ts=None, name='test_tree.pdf', backup=None):
    """Render ``tree`` to the temporary tree directory and return an IFrame showing it.

    Parameters
    ----------
    tree : object with a ``render(path, tree_style=...)`` method.
    ts : tree style passed through to ``render``.
    name : file name written inside the hard-coded ``tmp_trees`` directory; an
        existing file with the same name is overwritten.
    backup : optional destination (file or directory) to which the rendered
        file is also copied. ``~`` and environment variables are expanded.
        A failed copy emits a warning instead of aborting the display.

    Returns
    -------
    IPython ``IFrame`` pointing at ``'tmp_trees/' + name``.
    """
    from IPython.display import IFrame
    tree_img_root = '/data/tobiassonva/data/eukgen/analysis/tmp_trees/'
    rendered_path = tree_img_root+name
    tree.render(rendered_path, tree_style=ts)

    if backup is not None:
        import os
        import shutil
        import warnings
        # copy without going through a shell: paths containing spaces or shell
        # metacharacters are handled literally, and failures are reported
        backup_path = os.path.expandvars(os.path.expanduser(str(backup)))
        try:
            shutil.copy(rendered_path, backup_path)
        except OSError as err:
            warnings.warn(f'Could not back up {rendered_path} to {backup_path}: {err}')
    return IFrame('tmp_trees/'+name, width=950, height=950*1.5)


#return LCA for list of leaf names
def get_LCA(tree, leaf_names):
    """Return the common ancestor of the leaves named in ``leaf_names``.

    For every name the first hit of ``tree.get_leaves_by_name`` is used. The
    ancestor is obtained by calling ``get_common_ancestor`` on the first leaf
    with the full list of leaves.
    """
    matched = []
    for name in leaf_names:
        matched.append(tree.get_leaves_by_name(name)[0])
    first = matched[0]
    return first.get_common_ancestor(matched)

def cluster_TreeClust(tree, threshold):
    """Cluster the leaves of ``tree`` with ``min_clusters_threshold_max_clade``.

    The tree is serialised with ``tree.write()`` and re-parsed with
    ``read_tree_newick`` before clustering. The support cutoff is a very large
    negative number, so no internal node can fall below it.

    Returns the list of clusters (lists of leaf labels).
    """
    newick = tree.write()
    fasttree = read_tree_newick(newick)
    #alternative method: min_clusters_threshold_avg_clade(fasttree, threshold, -999999999)
    return min_clusters_threshold_max_clade(fasttree, threshold, -999999999)

#extract all against all distance matrix for ete3 trees
def calculate_pairwise_distances(tree):
    """Return the leaf-to-leaf distance matrix of ``tree`` (upper triangle filled).

    Rows and columns follow the order of ``tree.get_leaves()``. For ``i < j``
    entry ``[i, j]`` holds ``tree.get_distance(leaves[i], leaves[j])``; the
    diagonal and the lower triangle stay 0, so every pair is stored exactly
    once.

    Previously the column index restarted at 0 for every row, which put the
    distances into the wrong cells.
    """
    leaves = tree.get_leaves()
    n_leaves = len(leaves)
    pairwise_mat = np.zeros((n_leaves, n_leaves))
    for i, m in enumerate(leaves):
        # j is the real column index of the partner leaf
        for j in range(i + 1, n_leaves):
            pairwise_mat[i, j] = tree.get_distance(m, leaves[j])
    return pairwise_mat


#return closest non self leaf
def get_closest_leaf(leaf):
    """Return ``(closest_leaf, distance)`` among the leaves below ``leaf.up``.

    ``leaf`` itself is excluded. On ties the first leaf in
    ``leaf.up.get_leaves()`` order wins.
    """
    candidates = []
    for other in leaf.up.get_leaves():
        if other != leaf:
            candidates.append(other)

    distances = [leaf.get_distance(other) for other in candidates]
    min_dist = min(distances)
    winner = candidates[distances.index(min_dist)]
    return winner, min_dist
    
#iteratively merge closest leaf pair until less than N leaves
def reduce_leaves_to_size(tree, max_size):
    """Delete leaves from ``tree`` in place until at most ``max_size`` remain.

    For every leaf the closest other leaf (via ``get_closest_leaf``) and its
    distance are cached by leaf name. In each round the leaf with the smallest
    cached distance is picked and its cached partner is deleted from the tree.
    The caches are then refreshed for the surviving leaf only, so cached
    partners of other leaves can go stale. A stale partner (already deleted)
    is detected through the ``IndexError`` raised when looking it up by name.

    Returns the same ``tree`` object. If the tree already has ``max_size`` or
    fewer leaves a message is printed and the tree is returned unchanged.
    """
    current_size = len(tree.get_leaves())
    
    if max_size >= current_size:
        print(f'Tree of length {current_size} smaller than {max_size}')
        return tree    
    
    # name -> (closest leaf, distance), then split into two name-keyed dicts
    leaf_partners = {leaf.name:get_closest_leaf(leaf) for leaf in tree.get_leaves()}
    leaf_partner_dist = {key:value[1] for key, value in leaf_partners.items()}
    leaf_partners = {key:value[0].name for key, value in leaf_partners.items()}
    
    #print(leaf_partner_dist)
    #print(leaf_partners)
    
    
    #serial implementation, not ideal as it produces uneven pruning if terminal branch lengths are very even.
    while max_size < current_size:
        
        # first leaf name (in dict order) whose cached partner distance is minimal
        min_distance = min(leaf_partner_dist.values())
        min_leaf_A = list(leaf_partner_dist.keys())[list(leaf_partner_dist.values()).index(min_distance)]
        
        min_leaf_B = leaf_partners[min_leaf_A]
        
        #print(f'{current_size} checking {min_leaf_A}, deleting closest partner is {min_leaf_B} with distance {min_distance}')
        
        #delete closest leaf
        try:
            tree.get_leaves_by_name(min_leaf_B)[0].delete()
        
        #if current leaf A maps to a deleted leaf update closest leaf and delete
        except IndexError:
            #update the min_leaf with new closest pair
            new_leaf = tree.get_leaves_by_name(min_leaf_A)[0]
            closest_new_leaf = get_closest_leaf(new_leaf)

            leaf_partner_dist[new_leaf.name] = closest_new_leaf[1]
            leaf_partners[new_leaf.name] = closest_new_leaf[0].name
            min_leaf_B = leaf_partners[min_leaf_A]
            
            tree.get_leaves_by_name(min_leaf_B)[0].delete()
            
        
        #delete the removed partner from dictionaries
        leaf_partner_dist.pop(min_leaf_B, None)
        leaf_partners.pop(min_leaf_B, None)
        
        #update the min_leaf with new closest pair
        # NOTE(review): this lookup assumes min_leaf_A itself is still in the tree -- confirm
        new_leaf = tree.get_leaves_by_name(min_leaf_A)[0]
        closest_new_leaf = get_closest_leaf(new_leaf)
        
        leaf_partner_dist[new_leaf.name] = closest_new_leaf[1]
        leaf_partners[new_leaf.name] = closest_new_leaf[0].name
        
        current_size -= 1
        
        #print(leaf_partner_dist)
        #print(leaf_partners)

    
    return tree


#fit inverse gamma distribution to all internal node distances and 
#exclude those beyong threshold probability
def get_outlier_nodes_by_invgamma(tree, p_low=0, p_high=0.99, only_leaves=False):
    """Return nodes whose branch length is an outlier under a fitted inverse gamma.

    An inverse gamma distribution is fitted to the ``dist`` of all nodes from
    ``tree.traverse()`` (or only the leaves when ``only_leaves`` is True). The
    quantiles at ``p_low`` and ``p_high`` of the fitted distribution serve as
    cutoffs. Nodes with ``dist`` strictly outside ``[cutoff_low, cutoff_high]``
    are returned as a list, and a summary line is printed.
    """
    candidates = tree.get_leaves() if only_leaves else tree.traverse()
    node_dists = [(node, node.dist) for node in candidates]

    dist_values = pd.Series([dist for _, dist in node_dists]).values
    fit_alpha, fit_loc, fit_beta = stats.invgamma.fit(dist_values)

    cutoff_high = stats.invgamma.ppf(p_high, a=fit_alpha, loc=fit_loc, scale=fit_beta)
    cutoff_low = stats.invgamma.ppf(p_low, a=fit_alpha, loc=fit_loc, scale=fit_beta)

    outlier_nodes = []
    for node, dist in node_dists:
        if dist < cutoff_low or dist > cutoff_high:
            outlier_nodes.append(node)
    print(f'Identified {len(outlier_nodes)} outlier nodes outside interval {cutoff_low} > d > {cutoff_high}')

    return outlier_nodes

#starting from one leaf with an attribute traverse upwards untill 
#all leaves from the ancestor is no loger monophyletic under the given attribute
#repeat for all remaining leaves
#if any clade would have the global root as ancestor rerooot and retry to avoid 
#false paraphyly created by tree data struture
def get_paraphyletic_groups(tree, attribute='tax_superkingdom', attr_value='Eukaryota', current_root=False):
    """Return the top-most nodes under which all leaves have ``attribute == attr_value``.

    Starting from a matching leaf the function climbs towards the root for as
    long as every leaf below the parent still matches. The node reached is
    recorded and its leaves are removed from the work list, then the next
    remaining leaf is processed.

    If any recorded node is a direct child of the root (``node.up.up`` is
    None), the tree is re-rooted in place on the first root child that is not
    a recorded node and the search is repeated recursively.

    Parameters
    ----------
    tree : tree whose leaves all carry ``attribute``.
    attribute, attr_value : leaf attribute name and the value defining a match.
    current_root : node to re-root on before searching (used by the recursive
        call); when falsy the current tree root is used and nothing is re-rooted.

    Returns
    -------
    list of nodes, one per monophyletic group found.

    NOTE(review): raises IndexError when no leaf matches. If every leaf matches,
    the climb reaches the root, whose ``up`` is presumably None -- confirm.
    """
    
    #tree.set_outgroup(tree.get_farthest_leaf()[0])
    
    if current_root:
        tree.set_outgroup(current_root)
    else:
        current_root = tree.get_tree_root()
    
    #get a list of all leaves with an attribute matching the match value provided
    check_leaves = [leaf for leaf in tree.get_leaves() if getattr(leaf, attribute)==attr_value]
    clade_nodes = []
    
    seed_node = check_leaves[0]    

    while check_leaves:
        #assume monophyly
        mono = True
        
        #check for all parent leaves if attribute matches the value, if not its not monophyletic, break 
        for leaf in seed_node.up.get_leaves():
            if getattr(leaf, attribute)!=attr_value:
                mono = False
                break
        
        #if monophyletic try higher node 
        if mono:
            seed_node = seed_node.up
        
        #else return node and exclude all leaves from list of leaves to check
        else:
            clade_nodes.append(seed_node)
            check_leaves = [leaf for leaf in check_leaves if leaf not in seed_node.get_leaves()]
            if check_leaves:
                seed_node = check_leaves[0] 
    
    #if parent has no parent it is the root
    if [node for node in clade_nodes if node.up.up == None]:
        #print('A tree clade has rooted parent nodes, rerooting')
        
        #get the first non-clade daughter from current root 
        non_clade_daughter = [node for node in current_root.children if node not in clade_nodes][0] 
        
        return get_paraphyletic_groups(tree, attribute=attribute, attr_value=attr_value, current_root=non_clade_daughter)

    else:
        return clade_nodes
    
    
#calculate the entropy of a list of labels
def calculate_label_entropy(l):
    """Return the Shannon entropy (in bits) of the labels in list ``l``.

    Each distinct label contributes ``p*log2(1/p)`` where ``p`` is its relative
    frequency. An empty list yields 0.
    """
    total = len(l)
    entropy = 0
    for label in set(l):
        occurrences = l.count(label)
        entropy += (occurrences/total)*np.log2(1/(occurrences/total))
    return entropy




#return the set of all leaf pairs
def get_sister_leaf_sets(tree):
    """Return a list of two-element sets, one for each pair of sister leaves.

    Leaves are taken one at a time from a set of all leaves. A pair is recorded
    when the first sister of the taken leaf is itself a leaf that has not been
    taken yet, so each pair appears once. Only ``get_sisters()[0]`` is checked.
    """
    remaining = set(tree.get_leaves())
    pairs = []
    while len(remaining) > 0:
        current = remaining.pop()
        first_sister = current.get_sisters()[0]
        if first_sister not in remaining:
            continue
        pairs.append({current, first_sister})

    return pairs

#give each node in tree a feature numerical id
def add_node_ids(tree, strategy='postorder'):
    """Number the nodes of ``tree`` in traversal order and return the tree.

    Each node receives the feature ``post_i`` (0-based position in
    ``tree.traverse(strategy=strategy)``). The feature name is ``post_i``
    regardless of the chosen strategy.
    """
    index = 0
    for node in tree.traverse(strategy=strategy):
        node.add_feature('post_i', index)
        index += 1
    return tree

#count all combinations of nodes such that no node is a descendant of any other 
#infeasible for more than 30 nodes as combinations scale fast!
def enumerate_all_clades(tree):
    """Enumerate leaf configurations obtained by repeatedly collapsing sister-leaf pairs.

    A deep copy of ``tree`` is placed on a stack keyed by its topology id.
    Every tree taken from the stack records the tuple of ``post_i`` values of
    its leaves. For each pair of sister leaves a copy with that pair detached
    is pushed back, unless the copy's root has no children left.

    All nodes must already carry a ``post_i`` attribute. The stack size is
    printed on every iteration.

    Returns a list of tuples of ``post_i`` values.
    """
    stack = {}
    clades = []

    #stack a copy of the input tree as first tree
    #use ID to prevent duplicates
    subtree = tree.copy('deepcopy')
    stack[subtree.get_topology_id()] = subtree

    while stack:
        print(f'Stack contains {len(stack)}')
        
        #take the last element out of the stack by id
        subtree_id = list(stack.keys()).pop()
        subtree = stack.pop(subtree_id)
        
        #save the leaf configuration
        clades.append(tuple(leaf.post_i for leaf in subtree.get_leaves()))

        sister_sets = get_sister_leaf_sets(subtree)
        for pair in sister_sets:

            ids = [i.post_i for i in pair]

            # work on a fresh copy so the other pairs still see the uncollapsed subtree
            recursetree = subtree.copy()
            
            #collapse sister nodes
            # NOTE(review): nodes are detached while the traversal is running -- confirm this is safe
            for node in recursetree.traverse():
                if node.post_i in ids:
                    node.detach()
            
            #add cropped tree to stack if the root has children
            if recursetree.children:
                stack[recursetree.get_topology_id()] = recursetree

    return clades

In [None]:
#each leaf has a weight of its last node depth, 
#sum weights of daughters postorder then normalize by total weight
#---most consistent one so far---
def weight_tree_nodes_bottom_up(tree, add_residual=True):
    """Assign a normalised ``weight`` feature to every node, bottom-up.

    Leaves start with their distance to the tree root. Internal nodes get the
    sum of their children's weights, each increased by ``residual``
    (farthest-node distance divided by the number of leaves, or 0 when
    ``add_residual`` is False). All weights are finally divided by the summed
    weight of ``tree``'s direct children.

    Returns that normalisation constant.
    """
    root = tree.get_tree_root()

    #residual keeps zero-length branches from contributing nothing
    residual = tree.get_farthest_node()[1]/len(tree) if add_residual else 0

    for node in tree.traverse(strategy='postorder'):
        if not node.is_leaf():
            node.add_feature('weight', sum(child.weight+residual for child in node.children))
        else:
            node.add_feature('weight', node.get_distance(root))

    total_weight = sum(child.weight for child in tree.children)

    for node in tree.traverse(strategy='postorder'):
        node.weight = node.weight/total_weight

    return total_weight

#leaves have weight of distance to root
#nodes have summed weight of children minus own distance to root
def weight_tree_nodes_bottom_up_corrected(tree, add_residual=True):
    """Assign a normalised ``weight`` feature based on root distance and child branch lengths.

    Leaves get their distance to the tree root plus ``residual``. Internal
    nodes get their own distance to the root plus the summed branch lengths of
    their direct children, plus ``residual``. ``residual`` is the farthest-node
    distance divided by the number of leaves (0 when ``add_residual`` is False).
    All weights are divided by the summed weight of ``tree``'s direct children.

    Returns that normalisation constant.
    """
    root = tree.get_tree_root()

    #residual keeps zero-length branches from contributing nothing
    residual = tree.get_farthest_node()[1]/len(tree) if add_residual else 0

    for node in tree.traverse(strategy='postorder'):
        if not node.is_leaf():
            raw_weight = node.get_distance(root) + sum(child.dist for child in node.children)
            node.add_feature('weight', raw_weight+residual)
        else:
            node.add_feature('weight', node.get_distance(root)+residual)

    total_weight = sum(child.weight for child in tree.children)

    for node in tree.traverse(strategy='postorder'):
        node.weight = node.weight/total_weight

    return total_weight

#each node distributes weight among its decendants dependent on distance 
def weight_tree_nodes(tree, add_residual=True):
    """Distribute weight from the root down, proportional to child branch length.

    The root gets weight 1 (set on ``tree`` itself). Each child receives
    ``parent.weight*child.dist/total`` where ``total`` is the sum of the
    children's branch lengths, each increased by ``residual``. Note that the
    residual is ``1/farthest_distance/len(tree)`` here and is only added to the
    denominator. Returns None.
    """
    root = tree.get_tree_root()

    #residual keeps zero-length branches from producing a zero denominator
    residual = 1/tree.get_farthest_node()[1]/len(tree) if add_residual else 0

    for node in tree.traverse(strategy='levelorder'):
        if node == root:
            tree.add_feature('weight', 1)

        total_dist = sum(child.dist+residual for child in node.children)

        for child in node.children:
            child.add_feature('weight', node.weight*child.dist/total_dist)

#leaves have weights of distance to root
#nodes have weights of sum of children
#then all nodes children get probabilities assigned as node.prob*(node.weight-child.weight)/node.weight
def weight_tree_nodes_up_down(tree, add_residual=True):
    """Assign ``weight`` (bottom-up) and ``prob`` (top-down) features to every node.

    Bottom-up pass: leaves get their distance to the tree root; internal nodes
    get the sum of their children's weights, each increased by ``residual``
    (farthest-node distance divided by the number of leaves, or 0 when
    ``add_residual`` is False).

    Top-down pass: the root gets ``prob = 1``. A node with exactly one child
    passes its ``prob`` on unchanged; otherwise each child gets
    ``node.prob*(node.weight-child.weight)/node.weight``.

    Returns ``total_weight``, which is never updated and therefore always 0
    (kept for backward compatibility).
    """
    root = tree.get_tree_root()
    total_weight = 0

    #add residual to correct for branch lengths of 0
    residual = 0
    if add_residual:
        residual = tree.get_farthest_node()[1]/len(tree)

    for node in tree.traverse(strategy='postorder'):
        if node.is_leaf():
            node.add_feature('weight', node.get_distance(root))
        else:
            weight = sum([child.weight+residual for child in node.children])
            node.add_feature('weight', weight)

    root.prob = 1

    for node in tree.traverse(strategy='levelorder'):
        if len(node.children) == 1:
            #for linked internal nodes without branches: hand the probability
            #to the only child (the old code referenced an undefined/stale
            #`child` variable here)
            node.children[0].add_feature('prob', node.prob)
        else:
            for child in node.children:
                prob = node.prob*(node.weight-child.weight)/node.weight
                child.add_feature('prob', prob)

    return total_weight


#leaves have weight equal to the distance to last node
#nodes have weights equal to sum of children
#normalised by total weight
def weight_tree_nodes_local(tree, add_residual=True):
    """Assign local ``weight`` and ``prob`` features to the nodes of ``tree``.

    ``weight``: leaves get their own branch length plus ``residual``; internal
    nodes get the sum of their children's weights, each increased by
    ``residual`` (farthest-node distance divided by the number of leaves, or 0
    when ``add_residual`` is False).

    ``prob``: every child gets its branch length divided by the sum of its
    siblings' (including its own) branch lengths plus the parent's branch
    length. The traversal start node gets no ``prob``. Returns None.
    """
    residual = tree.get_farthest_node()[1]/len(tree) if add_residual else 0

    for node in tree.traverse(strategy='postorder'):
        if not node.is_leaf():
            node.add_feature('weight', sum(child.weight+residual for child in node.children))
        else:
            node.add_feature('weight', node.dist+residual)

    for node in tree.traverse(strategy='levelorder'):
        local_total = sum(child.dist for child in node.children)+node.dist
        for child in node.children:
            child.add_feature('prob', child.dist/local_total)
            
#leaves have weights equal to the root distance normalised of all leaves
#nodes have weights equal to the sum of their children
def weight_tree_nodes_top_down_prob(tree, add_residual=True):
    """Distribute weight from the root down so that farther children get less.

    The tree root gets weight 1. A node with exactly one child passes its
    weight on unchanged. Otherwise each child gets
    ``parent.weight*(sum_dists-child.dist+residual)/sum_dists``, where
    ``sum_dists`` is the sum of the children's branch lengths, each increased
    by ``residual`` (farthest-node distance divided by the number of leaves,
    or 0 when ``add_residual`` is False). Returns None.
    """
    root = tree.get_tree_root()

    residual = tree.get_farthest_node()[1]/len(tree) if add_residual else 0

    root.add_feature('weight', 1)

    for node in tree.traverse(strategy='levelorder'):
        offspring = node.children
        sum_dists = sum(child.dist+residual for child in offspring)

        #for nested unbranched nodes in pathological cases
        if len(offspring) == 1:
            offspring[0].add_feature('weight', node.weight)
            continue

        for child in offspring:
            #farther nodes are less probable
            #(alternative, farther is more probable: node.weight*(child.dist+residual)/sum_dists)
            child.add_feature('weight', node.weight*(sum_dists-child.dist+residual)/sum_dists)

                
#using precomputed weights
def calculate_mututal_I_and_distortion(tree, nodes, uniform=False):
    """Return ``(distortion, mutual_information)`` for cluster nodes using precomputed weights.

    Every leaf of ``tree`` and every node in ``nodes`` must already carry a
    ``weight``. For a cluster node C and leaf x, with Pc = C.weight,
    Px = x.weight and Plca = weight of their common ancestor, the sums over
    all C and x are:

        Icx = sum (Pc^2/Plca^2) * Px * log2(Pc/Plca^2)
        dis = sum (Pc^2/Plca^2) * Px * |Px - Pc|

    ``uniform`` is accepted but not used.
    """
    leaves = tree.get_leaves()
    leaf_weights = np.array([leaf.weight for leaf in leaves])

    Icx = 0
    dis = 0
    for cluster_node in nodes:
        Pc = cluster_node.weight
        #squared weight of the common ancestor of the cluster node and every leaf
        Plca2 = np.array([cluster_node.get_common_ancestor(x).weight**2 for x in leaves])
        scaled = ((Pc**2)/Plca2)*leaf_weights

        #mutual information term, summed over all leaves
        Icx += (scaled*np.log2(Pc/Plca2)).sum()

        #distortion term: absolute difference in probabilities, summed over all leaves
        dis += (scaled*np.abs(leaf_weights-Pc)).sum()

    return dis, Icx

In [None]:
#evaluating automatic clustering scoring using MI and compression

# NOTE(review): tree_file/seq_file and cluster_file point at different microcosm
# directories (OAE21175.1 vs KAA6409619.1) -- confirm this is intended
tree_file = 'microcosm2/OAE21175.1/test/WP_172991781_1.cluster.clu.euk.tree'
seq_file = 'microcosm2/OAE21175.1/test/WP_172991781_1.cluster.clu'
cluster_file = 'microcosm2/KAA6409619.1/KAA6409619.1.members.cluster.tsv'

# member lists of the clusters read from the TSV (keys are discarded)
mmseqs_clusters = [value for key, value in read_cluster_tsv(cluster_file).items()]

tree = Tree(tree_file)
# distances from the root to its farthest and closest leaf
tree_depth = tree.get_farthest_leaf()[1]
tree_shallow = tree.get_closest_leaf()[1]
#add weights
weight_tree_nodes_bottom_up(tree, add_residual=True)
print(tree_depth, tree_shallow, tree_file, seq_file)

#lower T less clusters
#set the distortion factor T to half the difference between the two extremes
T=(tree_depth-tree_shallow)/2
#T=1

# one row per tested threshold factor f, indexed by f
cluster_data = pd.DataFrame(columns=['clusters', 'dis', 'Icx', 'dis+Icx'])
# f -> (leaf clusters, LCA node of each cluster)
partitions = {}

#
#DEAL WITH THE SINGLE CLASS EVENT

best_score = 9999999
current_clusters = 0
break_count = 0
break_limit = 5

# NOTE(review): np.linspace(1.5,0,2) yields only the two values 1.5 and 0.0 -- confirm
for f in np.linspace(1.5,0,2).round(2):
    print(f)
    #cluster the tree using a TreeClust method
    tree_clusters = cluster_TreeClust(tree, tree_depth*f)
    #get LCA ancestors and calculate MI and distortion
    tree_cluster_LCA = [get_LCA(tree, cluster) for cluster in tree_clusters]


    # NOTE(review): T is passed as the third positional argument, i.e. the
    # `uniform` parameter, which the function does not use
    distortion, Icx  = calculate_mututal_I_and_distortion(tree, tree_cluster_LCA, T)
    
    #reweight distortion
    distortion = distortion*T
    
    #save each partition for analysis
    partitions[f] = (tree_clusters, tree_cluster_LCA)
    cluster_data.loc[f] = [len(tree_clusters), distortion, Icx, (distortion+Icx)]
    
    

    #reset the break counter whenever the combined score improves
    if (distortion+Icx) < best_score:
        best_score = distortion+Icx
        current_clusters = len(tree_clusters)
        break_count = 0
        print('Score improoved, reset break counter')
        print(f, len(tree_clusters), distortion, Icx, (distortion+Icx), sep='\t')
    
    #break after repeated failures with a changed number of clusters
    elif len(tree_clusters) != current_clusters:
        current_clusters = len(tree_clusters)
        
        print(f'Score did not improove, despite different clusters try: {break_count}')
        print(f, len(tree_clusters), distortion, Icx, (distortion+Icx), sep='\t')
        break_count += 1
        
        # the loop stops once break_count exceeds break_limit
        if break_count > break_limit:
            print('Failed to improve after 5 attempts breaking!')
            break
        
#parse data for plotting
# move the f index into a regular column named 'f'
cluster_data.reset_index(inplace=True, names='f')
# keep only the first row for each distinct cluster count
cluster_data.drop_duplicates('clusters', inplace=True)
#cluster_data.sort_values(by='f', inplace=True)
# long format: one row per (clusters, f, variable) for the line chart
cluster_data_melt = cluster_data.melt(id_vars=['clusters','f'])

# row label(s) with the lowest combined score; the first one defines optimal_f
best_run = cluster_data[cluster_data['dis+Icx'] == cluster_data['dis+Icx'].min()].index.values
print(cluster_data.loc[best_run])
optimal_f = cluster_data.loc[best_run,'f'].values[0]

f_distplot = alt.Chart(cluster_data_melt).mark_line(interpolate='step').encode(
    alt.X('f:O'),
    alt.Y('value', axis=alt.Axis(labelAlign='left')),
    alt.Color('variable')
)
# NOTE(review): warmgrays is not defined in the visible cells; presumably it
# comes from the %run altair style config -- confirm
c_distplot = alt.Chart(cluster_data_melt).mark_bar(color=warmgrays[2],interpolate='step-before',
                                            fillOpacity=0.05, line=True).encode(
    alt.X('f:O'),
    alt.Y('clusters')
)
#f_distplot.resolve_scale('independent')
# NOTE(review): this chart expression is not the last statement of the cell, so
# the notebook will not display it unless it is assigned and displayed explicitly
(c_distplot + f_distplot).resolve_scale(y='independent')

# pick the partition stored for the best-scoring threshold factor
tree_clusters, tree_cluster_LCA = partitions[optimal_f]
print([len(cluster) for cluster in tree_clusters])

#replace with mmseqs clusters
#tree_clusters = mmseqs_clusters

# NOTE(review): colorlib is not defined in the visible cells; presumably it
# comes from the %run altair style config -- confirm
colors = colorlib['twilight_shifted_r_perm']
# one colour per cluster, cycling through the palette when it is too short
class_colors = [colors[i%len(colors)] for i, _ in enumerate(tree_clusters)]
classes = len(tree_clusters)

#define overall tree styling
ts = TreeStyle()
ts.mode = 'c'
ts.show_leaf_name = False
ts.show_branch_length = False
ts.show_branch_support = False
#ts.optimal_scale_level = 'full'
ts.allow_face_overlap = True
ts.scale = 50

default_node_style = ete3.NodeStyle()
default_node_style['size'] = 0
default_node_style['fgcolor'] = 'Black'

#set default colors
for _, node in tree.iter_prepostorder():
    node.set_style(default_node_style)

#color LCA nodes
for i, node in enumerate(tree_cluster_LCA):
    LCA_node_style = ete3.NodeStyle()
    LCA_node_style['fgcolor'] = class_colors[i]
    LCA_node_style['size'] = 2
    node.set_style(LCA_node_style)
    

#color leaves
# runs after the LCA colouring, so a leaf that is its own cluster LCA ends up with the leaf style
for i, cluster in enumerate(tree_clusters):
    class_style = ete3.NodeStyle()
    class_style['fgcolor'] = class_colors[i]
    class_style['size'] = 1
    for leaf in cluster:
        tree.get_leaves_by_name(leaf)[0].set_style(class_style)
        
tree.ladderize()

# last expression of the cell: the IFrame returned by view_tree is what gets displayed
view_tree(tree, ts)

In [None]:
#weighting scheme testing

# small hand-written topology with named internal nodes, printed as ASCII for inspection
# NOTE(review): format=1 is presumably the newick flavour that keeps internal node names -- confirm
t1 = Tree('((((A,B)D,E)F,G)H,(I,J)K)root;', format=1)
print(t1.get_ascii(attributes=['name']))


In [None]:
#collect small label groups and detach the ones that are outnumbered among their neighbours
def trim_singleton_leaves(tree, attribute='tax_superkingdom', attr_value='Eukaryota', min_size=1, detach=True):
    """Return the small groups of `attr_value` leaves, detaching the isolated ones.

    Every node returned by get_paraphyletic_groups() that has at most
    `min_size` leaves is inspected. Its neighbourhood is the set of leaves
    below its parent. If fewer than half of those leaves carry `attr_value`
    and `detach` is true, the node is detached from the tree.

    Parameters:
        tree: tree searched for groups (modified in place when detaching).
        attribute: leaf attribute holding the label.
        attr_value: label value defining the groups.
        min_size: largest group size (in leaves) still considered.
        detach: when False nothing is removed from the tree.

    Returns:
        list with every inspected small group, detached or not.

    NOTE(review): groups that are NOT in the minority are also returned
    (undetached); confirm this is intended given the list is called
    `pruned_leaves`.
    """
    candidate_groups = get_paraphyletic_groups(tree, attribute=attribute, attr_value=attr_value)

    pruned_leaves = []

    for group in candidate_groups:
        #larger groups are never touched
        if len(group.get_leaves()) > min_size:
            continue

        #labels of all leaves below the parent, including the group itself
        neighbour_labels = [getattr(leaf, attribute) for leaf in group.up.get_leaves()]
        label_fraction = neighbour_labels.count(attr_value)/len(neighbour_labels)

        should_detach = label_fraction < 0.5 and detach
        pruned_leaves.append(group.detach() if should_detach else group)

    return pruned_leaves


In [None]:
#return the entropy of decendant and non decendant leaf labels
def get_entropy_for_partition(tree, node, attribute='tax_filter', attr_value='Eukaryota'):
    """Score how well `node` separates leaves labelled `attr_value` from the rest.

    For the leaves below `node` and for the leaves outside of it, the fraction
    p of leaves whose `attribute` equals `attr_value` is computed and turned
    into the single term p*log2(1/p).

    Parameters:
        tree: object whose get_leaves() yields all leaves of the tree.
        node: object whose get_leaves() yields the leaves of the clade.
        attribute: name of the leaf attribute holding the label.
        attr_value: label value that is counted.

    Returns:
        (label_H, external_label_H): the term for the clade and the term for
        the leaves outside the clade. A side without any labelled leaf
        contributes 0.
    """
    all_labels = [getattr(leaf, attribute) for leaf in tree.get_leaves()]
    all_label_count = all_labels.count(attr_value)
    tree_width = len(all_labels)

    clade_labels = [getattr(leaf, attribute) for leaf in node.get_leaves()]
    clade_label_count = clade_labels.count(attr_value)
    clade_width = len(clade_labels)

    #calculate label entropy
    #p*log2(1/p) tends to 0 for p -> 0, so a clade without the label scores 0
    #instead of dividing by zero
    if clade_label_count == 0:
        label_H = 0.0
    else:
        label_Px = clade_label_count/clade_width
        label_H = label_Px*np.log2(1/label_Px)

    #calculate external entropy change
    # if all labels in the clade the external entropy is 0
    if all_label_count-clade_label_count == 0:
        external_label_H = 0

    else:
        external_label_Px = (all_label_count-clade_label_count)/(tree_width-clade_width)
        external_label_H = external_label_Px*np.log2(1/external_label_Px)

    return label_H, external_label_H
    
#assign soft LCA node based on minimizing entropy between given label outside and inside clade
#more pessimissive than voting ratio, qualitatively underestimates
def get_soft_LCA_by_relative_entropy(tree, attribute='tax_superkingdom', attr_value='Eukaryota', save_loss=False):
    """Find the node whose clade best separates `attr_value` leaves from the rest.

    Starting from every group returned by get_paraphyletic_groups(), the tree
    is ascended towards the root. Each visited node is scored with the sum of
    the two terms from get_entropy_for_partition(), and the node with the
    lowest score is kept. Leaves get a fixed penalty of 0.5.

    Parameters:
        tree: ete3 tree to search.
        attribute: leaf attribute holding the label.
        attr_value: label value to find the soft LCA for.
        save_loss: when True the score of every tested node is stored in the
            node feature 'vote_loss' (all other nodes get the string 'None').

    Returns:
        (best_node, lowest_total_H). If exactly one group exists, that group
        and a score of 0 are returned. If no node is scored, the result is
        (root, inf).
    """
    root = tree.get_tree_root()

    lowest_total_H = float("inf")
    best_node = root

    #number of labelled leaves in the whole tree, used as the stop criterion while ascending
    all_label_count = [getattr(leaf, attribute) for leaf in tree.get_leaves()].count(attr_value)

    #for debugging
    if save_loss:
        for node in tree.traverse():
            node.add_feature('vote_loss', 'None')

    #check for monophyly
    LCA_groups = get_paraphyletic_groups(tree, attribute=attribute, attr_value=attr_value)
    if len(LCA_groups) == 1:
        print(f'The attribute {attribute} is monophyletic for {attr_value}. Returning LCA node.')
        return (LCA_groups[0], 0)

    #the best partition will be on the path from an LCA node to the root
    #a set gives constant time lookups; paths from different groups overlap heavily
    tested_nodes = set()
    for node in LCA_groups:

        #ascend until all labeled leaves are decendants of node
        while node != root:

            #skip known nodes
            if node not in tested_nodes:

                #calculate internal and external entropy
                label_H, external_label_H = get_entropy_for_partition(tree, node, attribute=attribute, attr_value=attr_value)
                total_H = label_H + external_label_H

                #penalize leaf LCAs to avoid laddered LCAs when having repeated outgroups
                #not neccesarily waned as spread singleons get their global LCA as soft_LCA
                if node.is_leaf():
                    total_H += 0.5

                #update best guess
                if total_H < lowest_total_H:
                    lowest_total_H = total_H
                    best_node = node

                if save_loss:
                    node.add_feature('vote_loss', total_H)

            #break after calculations if all nodes are decendants
            node_label_count = len([leaf for leaf in node.get_leaves() if getattr(leaf, attribute) == attr_value])
            if node_label_count == all_label_count:
                break

            #ascend
            tested_nodes.add(node)
            node = node.up

    return (best_node, lowest_total_H)

def get_multiple_soft_LCAs(tree, attribute='tax_filter', attr_value='Eukaryota', min_size=1, min_purity=0, max_entropy=9999):
    """Repeatedly extract soft LCA nodes until every `attr_value` leaf is covered.

    In every round get_soft_LCA_by_relative_entropy() is called and the
    labelled leaves below the returned node are temporarily relabelled
    'SAMPLED', so the next round looks at the remaining leaves only. The
    original labels are restored afterwards, also when the search fails.
    Finally size, purity and entropy are recomputed for every node on the
    restored tree and the nodes are filtered.

    Parameters:
        tree: ete3 tree to search.
        attribute: leaf attribute holding the label.
        attr_value: label value to find LCA nodes for.
        min_size: minimal number of leaves below an accepted node.
        min_purity: minimal fraction of `attr_value` leaves below an accepted node.
        max_entropy: maximal summed entropy of an accepted node.

    Returns:
        list of [node, size, purity, entropy] for every accepted node, in the
        order the nodes were found.
    """
    print(f'Searching for LCA_nodes by checking where {attribute} is {attr_value}')

    soft_LCA_nodes = []
    total_nodes = len([leaf for leaf in tree.get_leaves() if getattr(leaf, attribute) == attr_value])

    try:
        while total_nodes > 0:

            soft_LCA_node, _ = get_soft_LCA_by_relative_entropy(tree, attribute=attribute, attr_value=attr_value)

            #mark labeled included nodes and decrement total nodes left to check
            newly_sampled = 0
            for leaf in soft_LCA_node.get_leaves():
                if getattr(leaf, attribute) == attr_value:
                    setattr(leaf, attribute, 'SAMPLED')
                    newly_sampled += 1

            #a node without any remaining labelled leaf would never let the loop finish
            if newly_sampled == 0:
                print(f'WARNING: soft LCA search for {attribute} = {attr_value} made no progress, stopping early')
                break

            total_nodes -= newly_sampled
            soft_LCA_nodes.append(soft_LCA_node)

    finally:
        #reset leaf node attributes, also when the search above raised
        for leaf in tree.get_leaves():
            if getattr(leaf, attribute) == 'SAMPLED':
                setattr(leaf, attribute, attr_value)

    #recalculate entropies
    print(f'\tRecalculating sizes, purities and entropies for all LCA nodes')
    filtered_soft_LCA_nodes = []

    for node in soft_LCA_nodes:
        label_H, external_label_H = get_entropy_for_partition(tree, node, attribute=attribute, attr_value=attr_value)

        soft_LCA_members = [getattr(leaf, attribute) for leaf in node.get_leaves()]
        soft_LCA_size = len(soft_LCA_members)
        soft_LCA_purity = soft_LCA_members.count(attr_value) / soft_LCA_size
        soft_LCA_entropy = label_H + external_label_H

        if soft_LCA_size >= min_size and soft_LCA_purity >= min_purity and soft_LCA_entropy <= max_entropy:
            print(f'\tFound node of size {soft_LCA_size} with purity of {soft_LCA_purity} and entropy {soft_LCA_entropy} as LCA for {attr_value}')
            filtered_soft_LCA_nodes.append([node, soft_LCA_size, soft_LCA_purity, soft_LCA_entropy])
        else:
            print(f'\tRejected node of size {soft_LCA_size} with purity of {soft_LCA_purity} and entropy {soft_LCA_entropy} as LCA for {attr_value}')

    if len(filtered_soft_LCA_nodes) == 0:
        print(f'WARNING: No valid LCA nodes present for {attribute} = {attr_value} under conditions min_size={min_size}, min_purity={min_purity}, max_entropy={max_entropy}') 

    print()

    return filtered_soft_LCA_nodes
    

In [None]:
tree_files = !find testing/IQtree_reverse/ -name '*.contree'

tree_files = !find microcosm/ -name '*fil.contree'

#work on a single tree from the list found above
file = tree_files[1]

#file = 'testing/trees/EPZ30938.1.members.fasta.haln.euk.fil.fasttree.2'
#for file in tree_files:
print(file)
tree = Tree(file)
#the second path component is used as the accession of the tree
tree_name = file.split('/')[1]
tree_header = euk_header[euk_header.acc==tree_name].header.values
print(tree_header)
#annotates the leaves of tree using tax_merge (helper defined elsewhere in the notebook)
dirty_phyla_add(tree, tax_merge)

#merge leaf pairs until total amount of leaves is smaller than x
tree = reduce_leaves_to_size(tree, 300)

print(len(tree.get_leaves()))

#calculate devaiting branch distances
#NOTE(review): another cell unpacks four values from get_outlier_nodes_by_invgamma;
#here the whole return value is kept in outlier_nodes and later used for membership tests - confirm
outlier_nodes =  get_outlier_nodes_by_invgamma(tree, p_low=0, p_high=0.999, only_leaves=False)
#cut_nodes = [node.detach() for node in outlier_nodes]


#calculate soft LCA nodes for prok and euk using partition entropy 
filter_taxa = set([leaf.tax_filter for leaf in tree.get_leaves()])
hard_LCA_dict = {}
soft_LCA_dict = {}

for tax in filter_taxa:
    hard_LCA_dict[tax] = get_paraphyletic_groups(tree, attribute='tax_filter', attr_value=tax)
    soft_LCA_dict[tax] = get_multiple_soft_LCAs(tree, attribute='tax_filter', attr_value=tax,
                                               min_size=2, min_purity=0.5)
    
#the eukaryotic entry is recomputed with stricter size and purity limits, replacing the one from the loop
soft_LCA_dict['Eukaryota'] = get_multiple_soft_LCAs(tree, attribute='tax_filter', attr_value='Eukaryota',
                                                    min_size=3, min_purity=0.90)

#add LCA nodes to list for NODE visualisation
hard_LCA_nodes = [node for LCA_nodes in hard_LCA_dict.values() for node in LCA_nodes]
#only the first accepted soft LCA of every taxon is taken; entries are [node, size, purity, entropy]
soft_LCA_nodes = [node[0][0] for node in soft_LCA_dict.values() if node != []]

#add LCA labels for LABEL visualisation
#every node gets empty label features so the faces added later always have a value
for node in tree.traverse():
    node.add_feature('soft_LCA', '')
    node.add_feature('soft_LCA_H', '')
    
#add the first valid soft_LCA as LCA for prok and all valid soft_LCAs for euk
for tax, nodes in soft_LCA_dict.items():
    if nodes != []:
        #add all euk nodes
        if tax == 'Eukaryota':
            for node in nodes:
                #index 0 is the node, index 3 its entropy
                node[0].soft_LCA = tax
                node[0].soft_LCA_H = node[3]                
        
        
        #add first prok node
        else:
            nodes[0][0].soft_LCA = tax
            nodes[0][0].soft_LCA_H = nodes[0][3]
    
    #skip empty
    else:
        pass
    
# #track one soft LCA label
# tax = 'Flavobacteriia'
# soft_LCA_dict[tax] = get_soft_LCA_by_relative_entropy(tree, attribute='tax_class', attr_value=tax, save_loss=True)
# for node in tree.traverse():
#     node.add_face(TextFace(node.vote_loss, fsize=4), column=0)

#define overall tree styling
ts = TreeStyle()
ts.title.add_face(TextFace(tree_header, fsize=8), column=0)
ts.mode = 'r'
ts.show_leaf_name = False
ts.show_branch_length = False
ts.show_branch_support = False
#ts.optimal_scale_level = 'full'
ts.allow_face_overlap = True
ts.scale = 50

#node styles: invisible defaults plus one marker style per highlighted category
default_node_style = ete3.NodeStyle()
default_node_style['size'] = 0
default_node_style['fgcolor'] = 'Black'

default_leaf_style = ete3.NodeStyle()
default_leaf_style['size'] = 0
default_leaf_style['fgcolor'] = 'Black'

arc_node_style = ete3.NodeStyle()
arc_node_style['size'] = 2
arc_node_style['fgcolor'] = 'Cyan'

euk_node_style = ete3.NodeStyle()
euk_node_style['size'] = 2
euk_node_style['fgcolor'] = 'Green'

outlier_node_style = ete3.NodeStyle()
outlier_node_style['size'] = 3
outlier_node_style['fgcolor'] = 'Red'

hard_LCA_node_style = ete3.NodeStyle()
hard_LCA_node_style['size'] = 2
hard_LCA_node_style['fgcolor'] = 'Gray'

soft_LCA_node_style = ete3.NodeStyle()
soft_LCA_node_style['size'] = 4
soft_LCA_node_style['fgcolor'] = 'Black'

#NOTE(review): this style is defined but not applied anywhere below
soft_euk_LCA_node_style = ete3.NodeStyle()
soft_euk_LCA_node_style['size'] = 4
soft_euk_LCA_node_style['fgcolor'] = 'Gray'

#set styling for all leaves and internal nodes
#later set_style calls replace earlier ones, so the order below decides which category wins
for node in tree.traverse():
    node.set_style(default_node_style)
    node.add_face(TextFace(node.soft_LCA, fsize=4), column=0)
    node.add_face(TextFace(node.soft_LCA_H, fsize=4), column=0)

    if node.is_leaf():
        node.add_face(TextFace(node.name, fsize=4), column=1)
        node.add_face(TextFace(' '+node.tax_class,  fsize=4), column=2)
        node.set_style(default_leaf_style)

        if node.tax_superkingdom == 'Eukaryota':
            node.set_style(euk_node_style)
            
        elif node.tax_superkingdom == 'Archaea':
            node.set_style(arc_node_style)

    if node in hard_LCA_nodes:
        node.set_style(hard_LCA_node_style)
    
    if node in soft_LCA_nodes:
        node.set_style(soft_LCA_node_style)

    #NOTE(review): euk_singletons is not assigned in this cell and must exist from an earlier cell
    if node in outlier_nodes or node in euk_singletons:
        node.set_style(outlier_node_style)

#tree.set_outgroup(tree.get_midpoint_outgroup())
tree.ladderize()
#tree.render(file+'.pdf', tree_style=ts)
#tree.write(features=["name", "tax_superkingdom"], outfile=file+'_annot')
view_tree(tree, ts, backup=file+'.pdf')

In [None]:
Validate IQ-TREE 2 models

In [None]:
import io

#read and split data table from IQtree
model_data_file = 'testing/IQtree_models/merged_output.txt'
with open(model_data_file, 'r') as infile:
    #models = [m.split('\n') for m in infile.read().split(';')]
    #tables are separated by ';'; blank chunks (e.g. after a trailing ';') cannot be parsed and are dropped
    models = [chunk for chunk in infile.read().split(';') if chunk.strip()]

#loop over models adding observations of rank as index to the data
#data is already sorted
#frames are collected in a list and concatenated once instead of growing a frame inside the loop
ranked_tables = []
for model in models:
    new_data = pd.read_csv(io.StringIO(model), sep=r'\s+')
    #the row position within a table becomes the column 'rank'
    new_data.reset_index(inplace=True, names='rank')
    #convert the position into a score: 10 for the first row down to 0 from the eleventh row on
    new_data['rank'] = (10 - new_data['rank']).clip(lower=0)
    ranked_tables.append(new_data)

data = pd.concat(ranked_tables)

#summed score per model divided by the number of tables, in order of first appearance
data_models = (data.groupby('Model', sort=False)['rank'].sum() / len(models)).rename_axis(None)
data_models.sort_values(ascending=False)[0:50]

In [None]:
#turn the per-model scores into a table sorted from best to worst
plot_data = pd.DataFrame(data_models)
plot_data.reset_index(inplace=True)
plot_data.columns = ['model', 'score']
plot_data.sort_values(by='score', ascending=False, inplace=True)
#the part of the model string before the first '+' is used as the group name
plot_data['name']= [l.split('+')[0] for l in plot_data.model]
plot_data.reset_index(inplace=True, drop=True)
display(plot_data[0:50])

In [None]:
#bar chart of the 30 best scoring models, colored by the group name
plot = alt.Chart(plot_data[0:30]).mark_bar().encode(
    #sort=None keeps the row order of plot_data instead of sorting the labels
    x= alt.X('model:O', sort=None, axis=alt.Axis(labelAngle=-45)),
    color = alt.Color('name'),
    y= alt.Y('score'),
).interactive()
plot.width = 1200
plot

In [None]:
#fix further discrepancies in eukprot species labelling not reachable from orgid

from ete3 import NCBITaxa


error_ids = [2985, 191814, 35133, 88547, 438412, 419944, 1104430, 63605, 35686, 299832, 36769, 91373, 186019, 37163, 446134, 5748, 265536, 3028, 195968, 5774, 2949, 81532, 5709, 947084, 12967]

#open the taxonomy database once instead of once per lookup
ncbi = NCBITaxa()

#full lineage (list of taxids) for every problematic taxid
error_lineages = ncbi.get_lineage_translator(error_ids)

#names and ranks for every taxid that occurs in any lineage
#flat_list = [ x for xs in xss for x in xs ]
lineage_taxids = [x for xs in error_lineages.values() for x in xs]
error_dict = ncbi.get_taxid_translator(lineage_taxids)
rank_dict = ncbi.get_rank(lineage_taxids)

#one single-row frame per taxid, concatenated once after the loop
class_frames = []

for error_id, lineage in error_lineages.items():

    #name of the lineage member with rank 'class', or 'none' if the lineage has no such member
    error_class = [error_dict[c] for c in lineage if rank_dict.get(c) == 'class']

    if error_class == []:
        error_class = ['none']

    error_name = ncbi.get_taxid_translator([error_id])[error_id]
    print(error_id, error_class, error_name)
    classDF = pd.DataFrame({'species': error_name, 'superkingdom':'Eukaryota', 'class':error_class},index=[error_id])
    class_frames.append(classDF)

errorDF = pd.concat(class_frames) if class_frames else pd.DataFrame()

#manual overrides; a single .loc[row, column] assignment writes into errorDF itself,
#whereas the chained form .loc[row][column] may only modify a temporary copy
errorDF.loc[63605, 'class'] = 'Heterolobosea'
errorDF.loc[419944, 'class'] = 'Picozoa'
errorDF.loc[438412, 'class'] = 'Amoebozoa'


---- Temporary ----

In [None]:
#list every *.treefile below processing/microcosm2/ for the distribution fit checks
find_cmd = "find processing/microcosm2/ -name '*.treefile'"
#find ends its output with a newline, so the final empty element is cut off
treefiles = subprocess.run(find_cmd, shell=True, text=True, capture_output=True).stdout.split('\n')[:-1]

In [None]:
#check wellnes of fit of inverse gamma distribution for tree subset
for file in treefiles:

    #test invgamma
    treefile = file
    print(file)
    tree = Tree(treefile)
    p_high = 0.99
    p_low = 0

    outlier_nodes, dist_series, fitting_params, cutoffs = get_outlier_nodes_by_invgamma(tree, p_low, p_high)
    print(fitting_params)

    #NOTE(review): cutoffs[0] is taken as the upper (p_high) cutoff because it already sets the
    #plot range; cutoff_high was used in the title without being assigned in this cell - confirm
    cutoff_high = cutoffs[0]
    x_max = cutoff_high + 0.2

    #tree data
    dist_data = pd.DataFrame(dist_series)
    dist_data.columns = ['stem_length']

    #probability density values from fit, evaluated on an even grid up to x_max
    x = np.linspace(0.01, x_max, 1000)

    y = stats.invgamma.pdf(x, fitting_params[0], fitting_params[1], fitting_params[2])

    fit_data = pd.DataFrame({'stem_length':x, 'prob_density':y})
    fit_data = fit_data[fit_data.stem_length < x_max]

    #simulated data
    sim_data = stats.invgamma.rvs(fitting_params[0], fitting_params[1], fitting_params[2], size=len(tree))
    sim_data = pd.DataFrame(pd.Series(sim_data))
    sim_data.columns = ['stem_length']


    #plot cumsum hist
    plot, data = plot_cumsum_counts(dist_data.stem_length, title='experimental')
    plot, fit = plot_cumsum_counts(sim_data.stem_length, title='simulated from fit')

    merged = pd.concat([data, fit])

    cum_plot, merged_data = plot_cumsum_counts(merged, plot_type='default', formatted_data=True,
                                          x_min = 0, x_max=5, x_scale_type='linear', y_scale_type='linear')

    #goodness of fit: compare the observed lengths with the fitted distribution itself;
    #the evenly spaced plotting grid in fit_data is not a sample of that distribution
    ks_test = stats.kstest(dist_data.stem_length, 'invgamma',
                           args=(fitting_params[0], fitting_params[1], fitting_params[2]))
    cum_plot.title = alt.TitleParams(f'p={ks_test.pvalue}, statistic={ks_test.statistic} ', fontSize=12)


    #data_hist with line fit
    title = alt.TitleParams(f"{treefile.split('/')[-1]}, {len(tree)} leaves, 99% CDF {cutoff_high}", fontSize=12)
    data_hist = alt.Chart(dist_data, title = title).mark_bar().encode(
        x= alt.X('stem_length', bin=alt.Bin(extent=[0, x_max], step=0.02)),
        y= alt.Y('count()')

    )


    fit_curve = alt.Chart(fit_data).mark_line(color='red').encode(
        x= alt.X('stem_length', scale=alt.Scale(domain=[0, x_max])),
        y= alt.Y('prob_density')

    )

    #histogram and density use different units, hence independent y axes
    plot = alt.layer(data_hist, fit_curve).resolve_scale(y='independent')
    merge_plot = plot | cum_plot

    altair_saver.save(merge_plot, treefile+'.html' )

    #argument lists keep paths containing whitespace intact
    subprocess.run(['cp', treefile, '/home/tobiassonva/data/eukgen/testing/invgamma_fitting/'])
    subprocess.run(['mv', f'{treefile}.html', '/home/tobiassonva/data/eukgen/testing/invgamma_fitting/'])

    print()

In [None]:
#check wellnes of fit of lognorm distribution for tree subset
from core_functions.tree_functions import get_outlier_nodes_by_lognorm

#test lognorm
#treefile = treefiles[27]
#NOTE(review): treefile is not assigned in this cell; it is whatever the previous cell left behind
tree = Tree(treefile)
p_high = 0.99
p_low = 0.0

outlier_nodes, dist_series, fitting_params = get_outlier_nodes_by_lognorm(tree, p_low, p_high)
print(fitting_params)

#fitting_params[4] is reported as the 99% CDF value in the title below and sets the plot range
x_max = fitting_params[4] + 0.2

#tree data
dist_data = pd.DataFrame(dist_series)
dist_data.columns = ['stem_length']

#probability density values from fit, evaluated on an even grid up to x_max
x = np.linspace(0.01, x_max, 1000)

y = stats.lognorm.pdf(x, fitting_params[0], fitting_params[1], fitting_params[2])

fit_data = pd.DataFrame({'stem_length':x, 'prob_density':y})
fit_data = fit_data[fit_data.stem_length < x_max]

#simulated data
sim_data = stats.lognorm.rvs(fitting_params[0], fitting_params[1], fitting_params[2], size=len(tree))
sim_data = pd.DataFrame(pd.Series(sim_data))
sim_data.columns = ['stem_length']


#plot cumsum hist
plot, data = plot_cumsum_counts(dist_data.stem_length, title='experimental')
plot, fit = plot_cumsum_counts(sim_data.stem_length, title='simulated from fit')

merged = pd.concat([data, fit])

cum_plot, merged_data = plot_cumsum_counts(merged, plot_type='default', formatted_data=True,
                                      x_min = 0, x_max=5, x_scale_type='linear', y_scale_type='linear')

#goodness of fit: compare the observed lengths with the fitted distribution itself;
#the evenly spaced plotting grid in fit_data is not a sample of that distribution
ks_test = stats.kstest(dist_data.stem_length, 'lognorm',
                       args=(fitting_params[0], fitting_params[1], fitting_params[2]))
cum_plot.title = alt.TitleParams(f'p={ks_test.pvalue}, statistic={ks_test.statistic} ', fontSize=12)


#data_hist with line fits
title = alt.TitleParams(f"{treefile.split('/')[-1]}, {len(tree)} leaves, 99% CDF {fitting_params[4]}", fontSize=12)
data_hist = alt.Chart(dist_data, title = title).mark_bar().encode(
    x= alt.X('stem_length', bin=alt.Bin(extent=[0, x_max], step=0.02)),
    y= alt.Y('count()')

)


fit_curve = alt.Chart(fit_data).mark_line(color='red').encode(
    x= alt.X('stem_length', scale=alt.Scale(domain=[0, x_max])),
    y= alt.Y('prob_density')

)

#histogram and density use different units, hence independent y axes
plot = alt.layer(data_hist, fit_curve).resolve_scale(y='independent')
merge_plot = plot | cum_plot
merge_plot



In [None]:
#load the CDD annotation table and order hits by score = Prob * Pairwise_cov, best first
CDD_annot = (
    pd.read_csv('processing/euk72_ep3/euk72_ep_hhm.CDD_annotation.tsv', sep='\t', index_col=0)
    .assign(score=lambda hits: hits.Prob*hits.Pairwise_cov)
    .sort_values(by='score', ascending=False)
)

In [None]:
# merge all tree_data.csv files and clade_data.csv

#locate every per-tree table below processing/microcosm2/
treefiles = subprocess.run(f"find processing/microcosm2/ -name '*.tree_data.tsv'", shell=True, text=True, capture_output=True)
#the output of find ends with a newline, so the last (empty) element is dropped
treefiles = [file for file in treefiles.stdout.split('\n')][:-1]

print(len(treefiles))
new_data = []

for file in treefiles:
    data = pd.read_csv(file, sep='\t', index_col=0)
    new_data.append(data)

#stack all tables and keep a merged copy on disk
all_data = pd.concat(new_data, axis=0)
all_data.to_csv('tmp/microcosm_tree_data2.tsv', sep='\t')

#reload the merged table and index it by tree name
tree_data = pd.read_csv('tmp/microcosm_tree_data2.tsv', sep='\t', index_col=0)
tree_data.set_index('tree_name', inplace=True)


In [None]:
#manual cursor: every next(i) yields the next position, used to step through trees one cell run at a time
i = iter(range(999))

In [None]:
#rows whose prokaryotic taxon is Asgard with a c-ELW above 0.8
tree_data[(tree_data.prok_taxa == 'Asgard') & (tree_data['c-ELW'] > 0.8)]

In [None]:
#show all rows of the next tree name; re-running this cell advances the cursor i
tree_data.loc[tree_data.index.unique()[next(i)]]    

In [None]:
# plot overall branch length distribution of most likley sister taxa based on c-ELW

tree_data = tree_data.sort_values(by='c-ELW', ascending=False)

#keep rows with a stem length between 0 and 2
filtered_data = tree_data[(tree_data.stem_length.between(0.00, 2))] #& 
#                           (tree_data.prok_clade_weight >= 0.6)]# & 
                          #~(tree_data.euk_clade_rep.duplicated())]

a = filtered_data#[~(filtered_data.euk_clade_rep.duplicated())]

#a['dist'] = [np.log(stem) for stem in a['dist']]

#taxa = a.prok_taxa.value_counts().index[0:20].values
taxa = ['Cyanophyceae', 'Asgard', 'Alphaproteobacteria', 'Actinomycetes']
#taxa = ['Aquificae']

#restrict the plot to the selected taxa
a = a[a.prok_taxa.isin(taxa)]

title = f'Eukaryotic branch length per taxa, Sample_size:{a.shape[0]}'
#title = f'{a.shape[0]}'
#title = f'{len(a)} normalized stem lengths as per Gabaldon 2016'

#kernel density estimate of stem_length, one curve per prok_taxa
KDE = alt.Chart(a, title=alt.TitleParams(text=title, fontSize=12)).mark_area(line=True, opacity=0.2).transform_density(
    'stem_length',
    as_=['stem_length', 'density'],
    bandwidth=0.1,
    groupby = ['prok_taxa']
    
    ).encode(
    x=alt.X('stem_length:Q', scale=alt.Scale(domain=[0,2], clamp=True, )),
    y=alt.Y('density:Q', scale=alt.Scale(domain=[0,3], clamp=False)),
    color=alt.Color('prok_taxa'),
    tooltip = alt.Tooltip(['prok_taxa'])

).interactive()

#normalized histogram: every row contributes 1/group size, so the bars of one taxon sum to 1
#NOTE(review): bar is built but only KDE is displayed
bar = alt.Chart(a, title=title).transform_joinaggregate(
    total='count(*)',
    groupby=['prok_taxa']
    ).transform_calculate(
    pct='1 / datum.total'
    ).mark_bar().encode(
    x=alt.X('stem_length:Q', bin=alt.Bin(step=0.05), scale=alt.Scale(domain=[0,4], clamp=True)),
    y=alt.Y('sum(pct):Q', scale=alt.Scale(domain=[0,0.4])),
    color=alt.Color('prok_taxa'),
    tooltip = alt.Tooltip(['prok_taxa'])

).interactive()
KDE

In [None]:
#prefilter tree_data
#keep short stems with well supported clades and only the first row of every tree name
filtered_data = tree_data[(tree_data.stem_length.between(0, 2)) & 
                          (tree_data.euk_clade_weight >= 0.8) & 
                          (tree_data.prok_clade_weight >= 0.8) & 
                          ~(tree_data.index.duplicated())]

data = {}

#prefilter CDD data
CDD_annot_filtered = CDD_annot.loc[filtered_data.index]

#format all taxa specific closest sisters
for taxa in sorted(filtered_data.prok_taxa.unique()):
    print(taxa)
    tree_data_taxa = filtered_data[filtered_data.prok_taxa == taxa]

    CDD_hits = CDD_annot_filtered.loc[tree_data_taxa.index]

    #best scoring hit per index entry among the hits with a score above 70
    CDD_good = CDD_hits[CDD_hits['score']>70].sort_values(by='score', ascending=False)
    #explicit copy so the new column is written to an independent frame, not to a slice of CDD_hits
    CDD_good = CDD_good[~(CDD_good.index.duplicated())].copy()
    CDD_good['closest_sister_taxa'] = taxa

    data[taxa] = CDD_good.round(2)

#write to file
all_data = pd.concat(data.values())

all_data.to_csv('tmp/microcosm_data_CDD.tsv', sep='\t')

#write to multi-sheet excel
xlsx_file = 'test2.xlsx'
with pd.ExcelWriter(xlsx_file) as writer:
    for taxa, df in data.items():
        #sheet_name is passed by keyword; passing it positionally is deprecated in recent pandas
        df.to_excel(writer, sheet_name=taxa)

In [None]:
# plot clade analysis confidence distribution for tests 

from core_functions.altair_plots import plot_cumsum_counts
#drop rows without an AU test result; this overwrites tree_data for all later cells
tree_data = tree_data[tree_data['p-AU_accept'] != 'NONE']

tests = ['c-ELW','bp-RELL', 'p-KH', 'p-SH', 'p-AU']
all_data_dict = {}

for test in tests:
    
    #make the test column numeric before sorting and plotting
    tree_data[test] = tree_data[test].astype(float)
    
    print(test)
    tree_data = tree_data.sort_values(by=[test])

    #rows accepted by this test and the best value per eukaryotic clade among them
    filtered_data = tree_data[(tree_data[test+'_accept'] == '+')]
    best_data = filtered_data.groupby('euk_clade_rep').apply(lambda data: data[test].max())

    
    #only the returned data is used; plot is replaced by the combined plot below
    plot, data1 = plot_cumsum_counts(filtered_data[test], title='accepted '+test)
    plot, data2 = plot_cumsum_counts(best_data, title='best '+test)
    plot, data3 = plot_cumsum_counts(tree_data[test], title='all '+test)
    
    all_data = pd.concat([data3,data1,data2])
    
    plot, data = plot_cumsum_counts(all_data, formatted_data=True, plot_type='default', 
                                x_scale_type='linear', y_scale_type='linear', 
                               x_min=0, x_max=1, title='Confidence distributions '+test, x_label='probability')
    
    display(plot)
    
    #NOTE(review): the key 'all_' stores data1 (the accepted rows), not data3 (all rows) - confirm
    all_data_dict['all_'+test] = data1
    all_data_dict['best_'+test] = data2

In [None]:
#2d histogram heatmap of test value vs topology distance
# looks a bit garbage

tests = ['c-ELW','bp-RELL', 'p-KH', 'p-SH', 'p-AU']

dist_metric = 'dist'

for test in tests:

    plot_data = tree_data.sort_values(by=test + '_accept', ascending=True)
    
    
    
    #hist[0] holds the counts, hist[1] and hist[2] the bin edges of the test and the distance axis
    hist = np.histogram2d(plot_data[test], plot_data[dist_metric], bins=50)
    
    #NOTE(review): empty bins have a count of 0 and become -inf after log2
    data = pd.DataFrame(np.log2(hist[0]))
    #rows and columns are labelled with the rounded upper edge of their bin
    data.index = np.round(hist[1][1:],2)
    data.columns = np.round(hist[2][1:],2)
    #long format: one row per (test bin, distance bin) pair
    data2 = data.stack()
    data2.index.set_names([test, dist_metric], inplace=True)

    data2 = data2.reset_index(name='count')

    plot = alt.Chart(data2, ).mark_rect(filled=True, size=100).encode(
        x = alt.X(dist_metric+':O',  axis=alt.Axis(grid=False, labelAngle=90)),
        y = alt.Y(test+':O',  axis=alt.Axis(grid=False), scale=alt.Scale(reverse=True)),
        color = alt.Color('count')
    ).interactive().properties(width=700, height=700)

    display(plot)

In [None]:
# scatter of test values against distance metric

tests = ['c-ELW','bp-RELL', 'p-KH', 'p-SH', 'p-AU']
dist_metric = 'top_dist'

for test in tests:
    #sorting by the accept flag decides the order in which the points are drawn
    plot_data = tree_data.sort_values(by=test + '_accept', ascending=True)
    
    #very transparent points so dense regions remain readable
    plot = alt.Chart(plot_data).mark_point(filled=True, size=100, opacity=0.05).encode(
        x = alt.X(dist_metric),
        y = alt.Y(test),
        color = alt.Color(test + '_accept')
    ).interactive()
    
    display(plot)



In [None]:
#NOTE(review): data is whatever an earlier cell last assigned to that name; the result depends on execution order
data.columns

In [None]:
#rows whose prokaryotic taxon is Cyanophyceae with a c-ELW above 0.8
tree_data[(tree_data.prok_taxa == 'Cyanophyceae') & (tree_data['c-ELW'] > 0.8)]

In [None]:
#system = tree_data[tree_data.stem_length < 0.1].index.unique()[28]
#system = small_systems[12]
system = 'EP00669P010048'
#print(system)

#all rows of the chosen system, best c-ELW first
data = tree_data.loc[[system]].sort_values(by='c-ELW', ascending=False)
    
clades = data.euk_clade_rep.unique()
print(f'System {system} has Eukaryotic clades {clades}')

clade = clades[0]

plot_data = data[data.euk_clade_rep == clade]

#NOTE(review): the chart is titled with one clade but built from data (all clades), not plot_data - confirm
dist_bar = alt.Chart(data, title=f'{system}: {clade}').mark_bar().encode(
    x = alt.X('prok_taxa', sort=None, axis=alt.Axis(labelAngle=-45)),
    y = alt.Y('top_dist'),
    color = alt.Color('prok_taxa:O'),
    tooltip = alt.Tooltip(['top_dist', 'prok_clade_weight', 'prok_clade_size', 'prok_clade_rep', 'prok_leaf_clade'])
)

#annotated tree of the system, rendered as an image for the notebook
tree = Tree(f'processing/microcosm2/{system}/{system}.merged.fasta.muscle.treefile.annot')
annot_tree, tree_img = color_tree(tree, view_in_notebook=True)


from core_functions.altair_plots import plot_alignment
aln_file = f'processing/microcosm2/{system}/{system}.merged.fasta.muscle'

#leaf_names = [leaf.name for leaf in trees[system].get_leaves()]
#leaf names in tree order are passed on as the label order of the alignment plot
leaf_names = [leaf.name for leaf in tree.get_leaves()]

plot, aln_data = plot_alignment(aln_file, seqlimit=100, plot_range=(0,300), label_order=leaf_names)

display(dist_bar)
#columns are picked by position, so this depends on the column order of tree_data
display(data.iloc[:,[0,5,10,11,12,16,17,18,19,20,21,22,23,24,25,26,27]])
display(pd.DataFrame(CDD_annot.loc[system]))
display(plot)
display(tree_img)

In [None]:
#render one tree from the older microcosm2_old run
tree = Tree(f'processing/microcosm2_old/EP00411P016539/EP00411P016539.merged.fasta.muscle.treefile')
annot_tree, tree_img = color_tree(tree, view_in_notebook=True)
display(tree_img)

In [None]:
#number of distinct eukaryotic clade representatives
tree_data.euk_clade_rep.unique().shape

In [None]:
#15 most frequent prokaryotic taxa over all rows: absolute counts next to relative frequencies
pd.DataFrame([tree_data.prok_taxa.value_counts()[0:15], tree_data.prok_taxa.value_counts(normalize=True)[0:15]]).transpose()

In [None]:
#taxa of every row whose c-ELW accept flag is '+' or 'NONE', collected clade by clade
best = []
for tree, data in tree_data[tree_data['c-ELW_accept'].isin(['+', 'NONE'])].sort_values(by='c-ELW', ascending=False).groupby('euk_clade_rep'):
    #NOTE: the loop variables overwrite the notebook level names tree and data
    best.extend(data['prok_taxa'].values)
    
#15 most frequent taxa: absolute counts next to relative frequencies
pd.DataFrame([pd.Series(best).value_counts()[0:15],pd.Series(best).value_counts(normalize=True)[0:15]]).transpose()

In [None]:
#taxon of the highest c-ELW row per eukaryotic clade, among rows flagged '+' or 'NONE'
best = []
for tree, data in tree_data[tree_data['c-ELW_accept'].isin(['+', 'NONE'])].sort_values(by='c-ELW', ascending=False).groupby('euk_clade_rep'):
    #rows are sorted by descending c-ELW, so the first row of a group is its best one
    best.append(data.iloc[0]['prok_taxa'])
    
#15 most frequent taxa: absolute counts next to relative frequencies
pd.DataFrame([pd.Series(best).value_counts()[0:15],pd.Series(best).value_counts(normalize=True)[0:15]]).transpose()

In [None]:
#taxon of the highest c-ELW row per eukaryotic clade, among accepted rows with c-ELW above 0.8
best = []
for tree, data in tree_data[(tree_data['c-ELW_accept'] == '+') &
                            (tree_data['c-ELW'] > 0.8)].sort_values(by='c-ELW', ascending=False).groupby('euk_clade_rep'):
    #rows are sorted by descending c-ELW, so the first row of a group is its best one
    best.append(data.iloc[0]['prok_taxa'])
    
#15 most frequent taxa: absolute counts next to relative frequencies
pd.DataFrame([pd.Series(best).value_counts()[0:15],pd.Series(best).value_counts(normalize=True)[0:15]]).transpose()

In [None]:
#dense percentile rank of c-ELW within every eukaryotic clade, stored in the column 'rank'
tree_data['rank'] = [None]*tree_data.shape[0]
for clade, data in tree_data.groupby('euk_clade_rep'):
    ranks = data['c-ELW'].rank(method='dense', pct=True)
    #NOTE(review): one full-frame comparison per clade; a grouped rank would avoid the repeated scans
    tree_data.loc[tree_data.euk_clade_rep == clade, 'rank'] = ranks

In [None]:
#rows of one eukaryotic clade, columns 6 to 49 by position
tree_data[tree_data.euk_clade_rep == 'EP00437P002328'].iloc[:,6:50]

In [None]:
#50 best c-ELW rows of one tree, columns 4 to 24 by position
tree_data.loc['EP00746P122355'].sort_values(by='c-ELW', ascending=False).iloc[:,4:25][:50]

In [None]:
#accepted Alphaproteobacteria rows with c-ELW above 0.5, best first
tree_data[(tree_data['c-ELW_accept'] == '+') & 
          (tree_data.prok_taxa == 'Alphaproteobacteria') &
          (tree_data['c-ELW'] > 0.5)].sort_values(by='c-ELW', ascending=False)

In [None]:
#whole table ordered by descending c-ELW
tree_data.sort_values(by='c-ELW', ascending=False)

In [None]:
#rows of one eukaryotic clade, best c-ELW first
tree_data[tree_data.euk_clade_rep == 'EP00741P020146'].sort_values(by='c-ELW', ascending=False)

In [None]:
from core_functions.altair_plots import plot_alignment

#two column cluster table without header: cluster accession (used as index) and member accession
clusters = pd.read_csv('processing/prok2111_as/prok2111_as.repseq.cascaded_cluster.tsv', sep='\t', names =['cluster_acc', 'acc'], index_col=0)

In [None]:
#display the cluster table
clusters

In [None]:
#iterator over cluster accessions that have between 50 and 100 members
#NOTE: a is reused for a DataFrame in the plotting cells, so execution order matters
a = iter(clusters.loc[clusters.index.value_counts().between(50,100)].index.unique().values)

In [None]:
#every run of this cell takes the next cluster accession from the iterator a
aln_file = f'processing/prok2111_as/cluster_fastas/{next(a)}'

#leaf_names = [leaf.name for leaf in trees[system].get_leaves()]
#leaf_names = [leaf.name for leaf in tree.get_leaves()]

plot, aln_data = plot_alignment(aln_file, seqlimit=100, plot_range=(0,300))

plot

In [None]:
#drop every row with a missing value, in place; later cells see the reduced table
tree_data.dropna(inplace=True)

In [None]:
#summary statistics of c-ELW per prokaryotic taxon over the accepted rows
ELW = tree_data[tree_data['c-ELW_accept'] == '+'].groupby('prok_taxa').apply(lambda x: x['c-ELW'].describe())

#count times mean, i.e. the summed c-ELW of the taxon
ELW['weight'] = ELW['count']*ELW['mean']

ELW = ELW[['count', 'weight', 'mean', 'std', 'min', '25%', '50%', '75%', 'max']]

#taxa with more than 50 accepted rows, highest mean c-ELW first
ELW[ELW['count'] > 50].sort_values(by='mean', ascending=False)[0:40]



In [None]:
# plot overall branch length distribution of most likley sister taxa based on c-ELW

tree_data = tree_data.sort_values(by='c-ELW', ascending=False)

#keep the rows accepted by the c-ELW test
filtered_data = tree_data[tree_data['c-ELW_accept'] == '+']
#[(tree_data.stem_length.between(0.00, 2))] #& 
#                           (tree_data.prok_clade_weight >= 0.6)]# & 
                          #~(tree_data.euk_clade_rep.duplicated())]

a = filtered_data#[~(filtered_data.euk_clade_rep.duplicated())]

#a['dist'] = [np.log(stem) for stem in a['dist']]

#taxa = a.prok_taxa.value_counts().index[0:20].values
taxa = ['Cyanophyceae', 'Asgard', 'Alphaproteobacteria', 'Actinomycetes', 'Gammaproteobacteria']
#taxa = ['Aquificae']

#restrict the plot to the selected taxa
a = a[a.prok_taxa.isin(taxa)]

title = f'c-ELW distribution per taxa, Sample_size:{a.shape[0]}'
#title = f'{a.shape[0]}'
#title = f'{len(a)} normalized stem lengths as per Gabaldon 2016'

#kernel density estimate of c-ELW, one curve per prok_taxa
KDE = alt.Chart(a, title=alt.TitleParams(text=title, fontSize=12)).mark_area(line=True, opacity=0.2).transform_density(
    'c-ELW',
    as_=['c-ELW', 'density'],
    bandwidth=0.05,
    groupby = ['prok_taxa']
    
    ).encode(
    x=alt.X('c-ELW:Q', scale=alt.Scale(domain=[0,1], clamp=True, )),
    y=alt.Y('density:Q', scale=alt.Scale(domain=[0,3.5], clamp=False)),
    color=alt.Color('prok_taxa'),
    tooltip = alt.Tooltip(['prok_taxa'])

).interactive()

#NOTE(review): this histogram still bins stem_length although the chart is titled for c-ELW;
#it is built but only KDE is displayed
bar = alt.Chart(a, title=title).transform_joinaggregate(
    total='count(*)',
    groupby=['prok_taxa']
    ).transform_calculate(
    pct='1 / datum.total'
    ).mark_bar().encode(
    x=alt.X('stem_length:Q', bin=alt.Bin(step=0.05), scale=alt.Scale(domain=[0,4], clamp=True)),
    y=alt.Y('sum(pct):Q', scale=alt.Scale(domain=[0,0.4])),
    color=alt.Color('prok_taxa'),
    tooltip = alt.Tooltip(['prok_taxa'])

).interactive()
KDE

In [None]:






































#display the frame behind the last density plot (a is reused by several cells, so this depends on execution order)
a
