# Parse ASTRAL-pro2 trees from taxonomic imbalance

In [1]:
import sys
import re
import pandas as pd
import numpy as np
from skbio.tree import TreeNode

In [2]:
def assign_supports(tree):
    """Move the labels of internal nodes into a ``support`` attribute, in place.

    Every node yielded by ``tree.traverse()`` receives a ``support``
    attribute. For internal, non-root nodes it takes the node's current name
    and the name is then cleared. For tips and for the root it is set to None
    and the name is left untouched.
    """
    for node in tree.traverse():
        if not (node.is_tip() or node.is_root()):
            # Internal branch: the label is a support value, not a node name
            node.support, node.name = node.name, None
            continue
        node.support = None

In [3]:
def add_length_to_tips(tree):
    """Set the branch length of every tip to 0.0 and return the tree.

    Useful for trees produced by ASTRAL-pro2, in which tips do not have
    branch lengths. The tree is modified in place and the same object is
    returned.
    """
    for tip in tree.tips():
        tip.length = 0.0
    return tree

In [4]:
def order_nodes(tree, increase=True):
    """Return a copy of the tree whose children are sorted by tip count.

    Parameters
    ----------
    tree
        Tree object providing ``copy()``, ``postorder()``, ``is_tip()``,
        ``children`` and ``append()``.
    increase : bool, optional
        Passed unchanged as the ``reverse`` argument of the sort. With True
        (default) the children of each node are arranged from the most to the
        fewest descending tips; with False from the fewest to the most.
        Children with equal counts keep their original relative order.

    Returns
    -------
    The reordered object obtained from ``tree.copy()``.
    """
    res = tree.copy()
    # Count the tips below every node; postorder visits children first
    for node in res.postorder():
        node.n = 1 if node.is_tip() else sum(child.n for child in node.children)
    # Re-attach the children of every internal node in sorted order
    for node in res.postorder():
        if node.is_tip():
            continue
        ranked = sorted(node.children, key=lambda child: child.n, reverse=increase)
        node.children = []
        for child in ranked:
            node.append(child)
    # Drop the temporary counter again
    for node in res.postorder():
        del node.n
    return res

In [5]:
def parse_tree_astralpro(dataPathIn, dataPathOut, k, p):
    """Parse an ASTRAL-pro2 species tree and write relabeled trees plus metadata.

    Reads ``{dataPathIn}/select_k_{k}_p_{p}_speciestree_astral-pro2.nwk``,
    collapses the arbitrary bifurcating root, roots the tree at its midpoint,
    sorts the children of every node with ``order_nodes(tree, increase=False)``,
    names internal nodes N1, N2, ... in level order, and writes four files
    into ``dataPathOut`` (all suffixed ``_k_{k}_p_{p}_astral-pro2``):

    * ``metadata_*.tsv`` -- one row per annotated internal node, one column
      per ``key=value`` pair of its branch annotation
    * ``nid_*.nwk``      -- tree with the node IDs as internal labels
    * ``lpp_*.nwk``      -- tree with the ``pp1`` value as internal labels
    * ``nlabels_*.nwk``  -- tree without internal labels

    Parameters
    ----------
    dataPathIn : str
        Directory containing the input newick file.
    dataPathOut : str
        Directory the output files are written to. It is not created here.
    k, p
        Values substituted into the input and output file names.

    Raises
    ------
    ValueError
        If the root does not have exactly two children of which the second
        one is a tip.
    """
    # Read newick file
    with open(f'{dataPathIn}/select_k_{k}_p_{p}_speciestree_astral-pro2.nwk', 'r') as f:
        nwk = f.read().strip()
    # Replace the quoted, bracketed branch annotations ('[key=value;...]')
    # with simple placeholder labels X1, X2, ...; the annotation text is kept
    # in tmplabs so that placeholder Xi maps to tmplabs[i - 1]
    tmplabs = []
    def replace(match):
        # Get the first subgroup from re.match
        tmplabs.append(match.group(1))
        return f'X{len(tmplabs)}'
    nwk = re.sub(r'\'\[([^\[\]]+)\]\'', replace, nwk)
    # Convert string into TreeNode object
    tree = TreeNode.read([nwk])
    # Move the placeholder labels of internal nodes into node.support
    assign_supports(tree)
    # Trifurcate the arbitrary root -- the root is expected to have exactly
    # two children: children[0] (big clade) and children[1] (single tip).
    # Checking the number of children first turns a malformed root into the
    # documented ValueError instead of an IndexError or a silently wrong tree.
    if len(tree.children) != 2 or not tree.children[1].is_tip():
        raise ValueError('Unexpected input tree format.')
    clade = tree.children[0]
    tree.remove(clade)
    tree.extend(clade.children)
    # # Root tree by outgroup
    # tree = root_by_outgroup(tree, outgroup)

    # Add branch lengths to tips for ASTRAL-pro2 trees
    t = add_length_to_tips(tree)
    # Root at midpoint
    tree = t.root_at_midpoint(branch_attrs = ['support'])

    # Order nodes by number of descending tips
    tree = order_nodes(tree, increase = False)
    # Assign incremental node IDs
    i = 1
    for node in tree.levelorder(include_self = True):
        if not node.is_tip():
            node.name = f'N{i}'
            i += 1
    # Extract node metadata
    metadata = []
    for node in tree.levelorder(include_self = True):
        if node.is_tip() or node.support is None:
            continue
        # support holds the placeholder 'X<i>'; look up the original annotation
        label = tmplabs[int(node.support[1:]) - 1]
        # Split each item on the first '=' only and skip empty items, so a
        # value containing '=' or a trailing ';' does not break the parsing
        attrs = dict(item.split('=', 1) for item in label.split(';') if item)
        attrs['node'] = node.name
        metadata.append(attrs)
    # Generate metadata table
    df = pd.DataFrame(metadata).set_index('node')
    # Save metadata
    df.to_csv(f'{dataPathOut}/metadata_k_{k}_p_{p}_astral-pro2.tsv', sep = '\t')
    # Save tree with node ids
    tree.write(f'{dataPathOut}/nid_k_{k}_p_{p}_astral-pro2.nwk')
    # Save tree with local posterior probabilities
    # (values are still the strings parsed from the annotation)
    lpps = df['pp1'].to_dict()
    t = tree.copy()
    for node in t.non_tips(include_self=True):
        if node.name in lpps:
            lpp = lpps[node.name]
            node.name = '1.0' if lpp == '1.0' else f'{float(lpp):.3f}'
        else:
            node.name = None
    t.write(f'{dataPathOut}/lpp_k_{k}_p_{p}_astral-pro2.nwk')
    # Save tree without node labels
    for node in t.non_tips(include_self = True):
        node.name = None
    t.write(f'{dataPathOut}/nlabels_k_{k}_p_{p}_astral-pro2.nwk')

In [6]:
# Parameters
# replicates and markers select the directories that are processed;
# every k in ks and the single p are substituted into the tree file names
replicates = list(range(10))
markers = [
    'kegg',
    'eggnog',
]
ks = [
    50,
    100,
    200,
    400,
]
p = 0

In [7]:
# Parse the ASTRAL-pro2 tree of every replicate x marker set x k combination,
# always with the module-level value of p
for rep in replicates:
    print(f'At replicate {rep}')
    for marker in markers:
        # Input and output directories of this replicate / marker set
        dataPathIn = f'./out/replicate_{rep}/pipeline/{marker}'
        dataPathOut = f'./out/replicate_{rep}/parsed/{marker}'
        for k in ks:
            parse_tree_astralpro(dataPathIn, dataPathOut, k, p)

At replicate 0
At replicate 1
At replicate 2
At replicate 3
At replicate 4
At replicate 5
At replicate 6
At replicate 7
At replicate 8
At replicate 9


In [8]:
def parse_tree_astralpro_previous_markers(dataPathIn, dataPathOut, k):
    """Parse an ASTRAL-pro2 species tree and write relabeled trees plus metadata.

    Reads ``{dataPathIn}/combined_speciestree_astral-pro2.nwk``, collapses
    the arbitrary bifurcating root, roots the tree at its midpoint, sorts the
    children of every node with ``order_nodes(tree, increase=False)``, names
    internal nodes N1, N2, ... in level order, and writes four files into
    ``dataPathOut`` (all suffixed ``_k_{k}_astral-pro2``):

    * ``metadata_*.tsv`` -- one row per annotated internal node, one column
      per ``key=value`` pair of its branch annotation
    * ``nid_*.nwk``      -- tree with the node IDs as internal labels
    * ``lpp_*.nwk``      -- tree with the ``pp1`` value as internal labels
    * ``nlabels_*.nwk``  -- tree without internal labels

    Parameters
    ----------
    dataPathIn : str
        Directory containing the input newick file.
    dataPathOut : str
        Directory the output files are written to. It is not created here.
    k
        Value substituted into the output file names only; the input file
        name does not depend on it.

    Raises
    ------
    ValueError
        If the root does not have exactly two children of which the second
        one is a tip.
    """
    # Read newick file
    with open(f'{dataPathIn}/combined_speciestree_astral-pro2.nwk', 'r') as f:
        nwk = f.read().strip()
    # Replace the quoted, bracketed branch annotations ('[key=value;...]')
    # with simple placeholder labels X1, X2, ...; the annotation text is kept
    # in tmplabs so that placeholder Xi maps to tmplabs[i - 1]
    tmplabs = []
    def replace(match):
        # Get the first subgroup from re.match
        tmplabs.append(match.group(1))
        return f'X{len(tmplabs)}'
    nwk = re.sub(r'\'\[([^\[\]]+)\]\'', replace, nwk)
    # Convert string into TreeNode object
    tree = TreeNode.read([nwk])
    # Move the placeholder labels of internal nodes into node.support
    assign_supports(tree)
    # Trifurcate the arbitrary root -- the root is expected to have exactly
    # two children: children[0] (big clade) and children[1] (single tip).
    # Checking the number of children first turns a malformed root into the
    # documented ValueError instead of an IndexError or a silently wrong tree.
    if len(tree.children) != 2 or not tree.children[1].is_tip():
        raise ValueError('Unexpected input tree format.')
    clade = tree.children[0]
    tree.remove(clade)
    tree.extend(clade.children)
    # # Root tree by outgroup
    # tree = root_by_outgroup(tree, outgroup)

    # Add branch lengths to tips for ASTRAL-pro2 trees
    t = add_length_to_tips(tree)
    # Root at midpoint
    tree = t.root_at_midpoint(branch_attrs = ['support'])

    # Order nodes by number of descending tips
    tree = order_nodes(tree, increase = False)
    # Assign incremental node IDs
    i = 1
    for node in tree.levelorder(include_self = True):
        if not node.is_tip():
            node.name = f'N{i}'
            i += 1
    # Extract node metadata
    metadata = []
    for node in tree.levelorder(include_self = True):
        if node.is_tip() or node.support is None:
            continue
        # support holds the placeholder 'X<i>'; look up the original annotation
        label = tmplabs[int(node.support[1:]) - 1]
        # Split each item on the first '=' only and skip empty items, so a
        # value containing '=' or a trailing ';' does not break the parsing
        attrs = dict(item.split('=', 1) for item in label.split(';') if item)
        attrs['node'] = node.name
        metadata.append(attrs)
    # Generate metadata table
    df = pd.DataFrame(metadata).set_index('node')
    # Save metadata
    df.to_csv(f'{dataPathOut}/metadata_k_{k}_astral-pro2.tsv', sep = '\t')
    # Save tree with node ids
    tree.write(f'{dataPathOut}/nid_k_{k}_astral-pro2.nwk')
    # Save tree with local posterior probabilities
    # (values are still the strings parsed from the annotation)
    lpps = df['pp1'].to_dict()
    t = tree.copy()
    for node in t.non_tips(include_self=True):
        if node.name in lpps:
            lpp = lpps[node.name]
            node.name = '1.0' if lpp == '1.0' else f'{float(lpp):.3f}'
        else:
            node.name = None
    # File name follows the pattern of the other three outputs; it used to
    # contain a stray '_p' ('lpp_k_{k}_p_astral-pro2.nwk')
    t.write(f'{dataPathOut}/lpp_k_{k}_astral-pro2.nwk')
    # Save tree without node labels
    for node in t.non_tips(include_self = True):
        node.name = None
    t.write(f'{dataPathOut}/nlabels_k_{k}_astral-pro2.nwk')

In [9]:
# Parameters
replicates = list(range(10))
# Each marker set is processed with exactly one k; keeping the pairs together
# makes the positional pairing of markers and ks explicit
marker_k_pairs = [
    ('martinez_gutierrez', 41),
    ('moody_2024', 57),
    ('amphora', 136),
    ('phylophlan', 400),
]
markers = [name for name, _ in marker_k_pairs]
ks = [size for _, size in marker_k_pairs]

In [10]:
# Parse the combined ASTRAL-pro2 tree of every replicate and marker set;
# markers and ks are paired by position, so each marker set gets a single k
for rep in replicates:
    print(f'At replicate {rep}')
    for marker, k in zip(markers, ks):
        # Input and output directories of this replicate / marker set
        dataPathIn = f'./out/replicate_{rep}/pipeline/{marker}'
        dataPathOut = f'./out/replicate_{rep}/parsed/{marker}'
        parse_tree_astralpro_previous_markers(dataPathIn, dataPathOut, k)

At replicate 0
At replicate 1
At replicate 2
At replicate 3
At replicate 4
At replicate 5
At replicate 6
At replicate 7
At replicate 8
At replicate 9
