In [1]:
# The following code is from Zhaoyi Zhang.
# I tested it to see if it works for trees with 4 taxa
from math import comb
from itertools import combinations, islice
import time

from Bio import Phylo
import matplotlib.pyplot as plt
import pandas as pd

plt.rcParams.update({'font.size': 12})

In [2]:
# Marianne's data files use integers in place of the taxon names.
# Phylo.parse returns a lazy iterator over the trees in the file.
trees = Phylo.parse(file='data/4taxa_trees.tre', format='newick')

In [3]:
def get_direct_children(clade, res=None):
    '''
    get direct children of each clade

    Walks the subtree rooted at `clade` and records, for every clade in it
    (including `clade` itself and the leaves), the list of its direct
    children (`clade.clades`). Clades are inserted in pre-order: a parent
    comes before its children, and children are listed left to right.
    expand_direct_children depends on that order.

    res: optional dict to fill in place; a new dict is created when omitted.
    Returns the filled dict, also when `clade` is itself a leaf.
    '''
    if res is None:
        res = {}
    # Use an explicit stack instead of recursion so that very deep
    # (e.g. caterpillar) trees cannot hit Python's recursion limit.
    stack = [clade]
    while stack:
        node = stack.pop()
        res[node] = list(node.clades)
        # Push in reverse so the leftmost child is popped (visited) first.
        stack.extend(reversed(node.clades))
    return res

def expand_direct_children(res):
    '''
    expand direct children of each clade into descendants

    res: mapping clade -> list of its direct children, as built by
         get_direct_children. Parents must come before their children in the
         mapping's iteration order (pre-order). Clades are processed in
         reverse, so each child's leaves are already known when its parent
         is handled.

    Returns a dict mapping each clade to the list of leaf clades below it,
    in left-to-right order; a leaf maps to an empty list. A child counts as
    a leaf when it has no children in `res`, whatever its name, so taxa
    with non-numeric names are kept.
    '''
    res_expand = {}
    for clade, children in reversed(res.items()):
        leaves = []
        for child in children:
            if res.get(child):
                # Internal child: reuse the leaves already collected for it.
                leaves.extend(res_expand[child])
            else:
                # A clade without children is a leaf (taxon).
                leaves.append(child)
        res_expand[clade] = leaves
    return res_expand

def get_descendants(clade):
    '''
    Map every clade in the subtree rooted at `clade` to its descendants.

    This builds the direct-children table with get_direct_children, then
    expands it with expand_direct_children.
    '''
    return expand_direct_children(get_direct_children(clade))

def print_children(res):
    '''
    Print one line per clade in `res`: the clade's label, a tab, then the
    list of its children's labels.

    A label is the clade's name when it has one. Otherwise it is the branch
    length formatted to 2 significant digits (a falsy branch length is
    shown as is).
    '''
    def label(node):
        if node.name:
            return node.name
        length = node.branch_length
        return length and f'{length:.2}'

    for parent, kids in res.items():
        print(label(parent), '\t', [label(kid) for kid in kids])
        
def desc_to_length(descendants):
    '''
    Pair each clade's descendant taxon names with the clade's branch length.

    descendants: mapping clade -> list of descendant leaf clades (as returned
    by get_descendants); a leaf maps to an empty list.

    Returns a list of (names, branch_length) tuples in the mapping's order.
    `names` is always a list of taxon names. For a leaf it is the
    one-element list [leaf.name]. Returning the bare string instead would
    make a multi-digit name such as '10' be iterated as the digits '1', '0'
    when it is later converted to taxon numbers.
    '''
    return [
        ([d.name for d in desc] if desc else [parent.name], parent.branch_length)
        for parent, desc in descendants.items()
    ]

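Choosing the $i$ taxa on one side ($1 \le i \le n-1$, so neither side is empty) counts every split twice ($A|B$ and $B|A$), hence the factor $\frac{1}{2}$:

$$
\begin{align}
2^{n} &= \sum_{i=0}^{n} \binom{n}{i} \\
\text{bipartitions} &= \frac{1}{2} \sum_{i=1}^{n-1} \binom{n}{i} \\
&= \frac{1}{2} \left(\sum_{i=0}^{n}{\binom{n}{i}} - \binom{n}{0} - \binom{n}{n}\right) \\
&= \frac{1}{2} \left(2^{n} - 2\right) \\
&= 2^{n-1} - 1
\end{align}
$$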

In [4]:
# Numerically confirm the closed form above for n = 1 .. 99.
for n in range(1, 100):
    assert sum(map(lambda i: comb(n, i), range(1, n))) // 2 == 2 ** (n - 1) - 1

In [5]:
def num_bipartitions(n):
    '''
    Number of ways to split n taxa into two non-empty sides.

    Each split A|B is counted once, so the count is
    (C(n, 1) + ... + C(n, n-1)) / 2 = (2**n - 2) / 2 = 2**(n-1) - 1.
    C(n, 0) and C(n, n) are excluded because one side would be empty.
    For 4 taxa this gives 7: a|bcd, b|acd, c|abd, d|abc, ab|cd, ac|bd, ad|bc.
    '''
    return pow(2, n - 1) - 1

def show_bipartitions(n, show_section=False, start=0, end=None):
    '''
    Print the bipartitions of the taxa 1..n, one per line, with their indices.

    Splits are grouped by the size i of the first side (i = 1 .. n//2).
    Within a group they follow itertools.combinations order of the first
    side. When both sides have n/2 taxa, only the first half of the
    combinations is printed, so each split A|B appears once.

    n: number of taxa
    show_section: print an "i/n-i" header before each size group
    start, end: print only indices in [start, end]; end=None means no upper
        bound. Iteration stops as soon as the index passes `end`.
    '''
    idx_width = len(str(num_bipartitions(n)))
    node_width = len(str(n))
    taxa = range(1, n + 1)
    print('idx\tpartition')
    idx = 0
    for i in range(1, n // 2 + 1):
        if show_section:
            print(f'{i}/{n-i}')

        sides = combinations(taxa, i)
        if i == n - i:
            # A|B and B|A are the same split: keep only the first half.
            sides = islice(sides, comb(n, i) // 2)

        for side in sides:
            if end is not None and idx > end:
                return  # every remaining index is past `end`
            if start <= idx:
                # Build the complement in ascending order explicitly;
                # iteration order of a set is not guaranteed.
                rest = [t for t in taxa if t not in side]
                left = ''.join(f'{e:{node_width}} ' for e in side)
                right = ''.join(f'{e:{node_width}} ' for e in rest)
                print(f'{idx:{idx_width}}\t{left}| {right}')
            idx += 1

def show_bipartition(n, idx):
    '''Print only the bipartition of n taxa whose index is `idx`.'''
    show_bipartitions(n, show_section=False, start=idx, end=idx)

show_bipartitions(n=4)
show_bipartition(n=4, idx=5)

idx	partition
0	1 | 2 3 4 
1	2 | 1 3 4 
2	3 | 1 2 4 
3	4 | 1 2 3 
4	1 2 | 3 4 
5	1 3 | 2 4 
6	1 4 | 2 3 
idx	partition
5	1 3 | 2 4 


In [6]:
class BipartitionEnc():
    '''
    Map each bipartition of the taxa 1..n to an index 0 .. num_bipartitions(n) - 1.

    Indices follow the order printed by show_bipartitions. Splits whose
    smaller side has 1 taxon come first, then 2, and so on. Within a size,
    the smaller side is ranked lexicographically. When both sides have n/2
    taxa, a split and its complement share one index.
    '''
    def __init__(self, n):
        self.n = n
        self.b = num_bipartitions(n)  # total number of bipartitions

        # idx_offset[k] = C(n, 1) + ... + C(n, k-1): the first index used by
        # splits whose smaller side has k taxa.
        self.idx_offset = {}
        self.idx_offset[1] = 0
        for i in range(2, n + 1):
            self.idx_offset[i] = comb(n, i - 1) + self.idx_offset[i - 1]

        self.u = set(range(1, n + 1))

    def _rank(self, desc):
        '''
        1-based lexicographic rank of the integer subset `desc` among all
        subsets of {1..n} of the same size.
        '''
        k = len(desc)
        rank = comb(self.n, k)
        for i, d in enumerate(sorted(desc)):
            # remove the subsets that sort after desc at position i
            rank -= comb(self.n - d, k - i)
        return rank

    def get_idx_from_desc(self, desc):
        '''
        Return the bipartition index of the split (desc | remaining taxa).

        desc: taxon names (ints or digit strings) on one side of the split.
              A single name passed on its own (str or int) is treated as
              one taxon, so '10' means taxon 10, not taxa 1 and 0.
        '''
        if isinstance(desc, (str, int)):
            desc = [desc]
        desc = [int(d) for d in desc]
        if len(desc) > self.n // 2:
            # Always rank the smaller side of the split.
            desc = self.u - set(desc)
        offset = self.idx_offset[len(desc)]
        rk = self._rank(desc) + offset - 1
        if rk >= self.b:
            # Only reachable when both sides have n/2 taxa: complementing
            # reverses the lexicographic order, so mirror onto the
            # complement's index.
            rk = 2 * self.b - 1 - rk
        return rk

    def encode(self, tree, return_desc=False):
        '''
        Encode `tree` as a list of (bipartition index, branch length) pairs,
        or (desc, index, branch length) triples when return_desc is True.
        Here desc is a list of taxon names. Clades whose descendants are all
        n taxa (e.g. the root) are skipped.

        Raises ValueError if the tree does not have exactly n terminal nodes.
        '''
        t_nodes = tree.get_terminals()
        if len(t_nodes) != self.n:
            raise ValueError(f'Expected {self.n} terminal nodes, but the tree has {len(t_nodes)}')
        encoding = []
        descendants = get_descendants(tree.clade)
        for desc, branch_length in desc_to_length(descendants):
            if isinstance(desc, str):
                # A leaf reported by its bare name: wrap it so len() counts taxa.
                desc = [desc]
            if len(desc) == self.n:
                continue
            idx = self.get_idx_from_desc(desc)
            if return_desc:
                encoding.append((desc, idx, branch_length))
            else:
                encoding.append((idx, branch_length))
        return encoding

In [7]:
# We have no use for the following two methods
# a) print pairs in 2D array
def printPairs(trees, encoder=None):
    '''
    Encode every tree in `trees` and return the encodings as a 2D list:
    one list of (bipartition index, branch length) pairs per tree.

    encoder: object with an encode(tree) method, e.g. a BipartitionEnc.
        When omitted, a BipartitionEnc is built from the number of terminal
        nodes of the first tree, instead of relying on a global variable.
    '''
    data = []
    for tree in trees:
        if encoder is None:
            encoder = BipartitionEnc(len(tree.get_terminals()))
        data.append(encoder.encode(tree))
    return data

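# Same as method b (printBipartIdx, defined below), but with a "Name" column, a header row, and a second slot per bipartition for repeated indices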
def printBipartIdxTitle(trees, n):
    '''
    Build a branch-length table with a "Name" column and two slots per
    bipartition.

    Returns a list of rows. The first row is the header
    ["Name", 0, ..., N-1, 0, ..., N-1], with N = num_bipartitions(n). Each
    following row is ["Tree<k>", ...], with k counting from 1. The first
    branch of a tree that encodes to bipartition i goes to column i + 1; a
    second branch with the same index (e.g. the two root branches of a
    rooted tree) goes to column i + N + 1. Absent bipartitions are 0.
    '''
    be = BipartitionEnc(n)
    N = num_bipartitions(n)
    data = [["Name"] + list(range(N)) + list(range(N))]
    for tree_num, tree in enumerate(trees, start=1):
        row = ["Tree" + str(tree_num)] + [0] * (2 * N)
        seen = set()  # track filled slots explicitly; a length of 0 is valid
        for idx, branch_length in be.encode(tree):
            if idx in seen:
                # Offset N + 1 skips the Name column and the first N slots.
                row[idx + N + 1] = branch_length
            else:
                seen.add(idx)
                row[idx + 1] = branch_length
        data.append(row)
    return data


In [8]:
# print data into a csv file

def printFile(data, filePath = 'data/test.csv'):
    '''
    Write `data` (an iterable of rows) to `filePath`, one row per line with
    the cells converted by str() and joined by ', '. An existing file is
    overwritten.
    '''
    # `with` guarantees the file is flushed and closed, even on error.
    with open(filePath, 'w') as f:
        for row in data:
            print(', '.join(map(str, row)), file=f)

In [9]:
# b) print the bipartition names and the corresponding branch lengths in each tree
def printBipartIdx(trees, n):
    '''
    Build a table of branch lengths indexed by bipartition.

    Returns a list of rows. The first row holds the bipartition indices
    0 .. num_bipartitions(n) - 1. It is followed by one row per tree in
    `trees`, where cell i is the summed length of every branch of that tree
    that encodes to bipartition i (0 if none does).
    '''
    be = BipartitionEnc(n)
    N = num_bipartitions(n)
    data = [list(range(N))]
    for tree in trees:
        row = [0] * N
        # Sum rather than assign: several branches can map to the same
        # bipartition, e.g. the two root branches of a rooted tree split
        # the taxa into complementary sides.
        for idx, branch_length in be.encode(tree):
            row[idx] += branch_length
        data.append(row)
    return data


In [10]:
def printTree(path, n, target):
    '''
    print trees into csv file

    Reads the Newick trees in `path` and writes their bipartition /
    branch-length table (built by printBipartIdx) to `target`.

    path: the path of the Newick tree file
    n: number of taxa in each tree
    target: the name (path) of the csv file to write
    '''
    trees = Phylo.parse(path, 'newick')
    result = printBipartIdx(trees, n)
    printFile(result, target)

In [11]:
# Write the bipartition table for the 4-taxon trees
printTree(path='data/4taxa_trees.tre', n=4, target='data/bipartition_4taxa_tree.csv')

In [12]:
printTree(path='data/8taxa_trees.tre', n=8, target='data/bipartition_8taxa_tree.csv')

In [13]:
printTree(path='data/10taxa_trees.tre', n=10, target='data/bipartition_10taxa_tree.csv')

In [16]:
printTree(path='data/15taxa_trees_1.tre', n=15, target='data/bipartition_15taxa_tree_1.csv')

In [None]:
printTree(path='data/15taxa_trees_2.tre', n=15, target='data/bipartition_15taxa_tree_2.csv')