## Rosalind BA7C: Implement AdditivePhylogeny

In [50]:
import numpy as np
from collections import defaultdict

In [222]:
def read_data(fname='rosalind_ba7c.txt'):
    """Load a distance matrix from a BA7C-style text file.

    The first line holds the number of leaves ``n``; each of the next ``n``
    lines holds one whitespace-separated row of integers.  Anything after
    those ``n`` rows is ignored.

    Returns:
        A 2-D numpy array of ints (``n x n`` for a well-formed file).
    """
    with open(fname, 'r') as handle:
        size = int(handle.readline().strip())
        rows = [
            np.array([int(token) for token in handle.readline().strip().split()])
            for _ in range(size)
        ]
        matrix = np.array(rows)
    return matrix
   

def get_limb_len(D, i):
    """Return the limb length of leaf ``i`` in the distance matrix ``D``.

    Computed as the minimum over distinct leaves ``j != k`` (both different
    from ``i``) of ``(D[i][j] + D[i][k] - D[j][k]) / 2``.

    Args:
        D: square distance matrix (numpy array or nested sequence).
        i: index of the leaf whose limb is measured.

    Returns:
        A Python ``int`` when the minimum is a whole number (the usual case
        for integer additive matrices), otherwise a ``float``.  The value is
        never truncated, so e.g. a limb of 0.5 is reported as 0.5, not 0.

    Raises:
        ValueError: if fewer than two other leaves exist, so no (j, k) pair
            can be formed.
    """
    D = np.asarray(D)
    n = D.shape[0]
    others = [k for k in range(n) if k != i]
    if len(others) < 2:
        raise ValueError('limb length needs at least three leaves, got %d' % n)

    row = D[i, others]
    sub = D[np.ix_(others, others)]
    # candidates[a, b] = D[i, j] + D[i, k] - D[j, k] for j = others[a], k = others[b]
    candidates = row[:, None] + row[None, :] - sub
    off_diagonal = ~np.eye(len(others), dtype=bool)  # exclude j == k
    best = candidates[off_diagonal].min() / 2
    return int(best) if float(best).is_integer() else float(best)


def get_leaves_cond(D):
    """Find two leaves whose path passes through the (trimmed) last leaf.

    Scans pairs ``i < j`` of leaves *other than* the last one (index n-1) and
    returns the first pair satisfying ``D[i, j] == D[i, n-1] + D[j, n-1]``,
    i.e. the trimmed last leaf lies on the i-j path.

    Args:
        D: square numpy distance matrix whose last row/column has already had
            the last leaf's limb length subtracted.

    Returns:
        ``(i, j, x)`` where ``x = D[i, n-1]`` is the distance from ``i`` to the
        attachment point along the i-j path.

    Raises:
        ValueError: if no such pair exists (the matrix is not additive or was
            not trimmed).
    """
    n = D.shape[0]
    last = n - 1
    for i in range(last):
        # j stops before ``last``: D[last, last] == 0, so j == last would make
        # the equality trivially true and return the leaf being inserted.
        for j in range(i + 1, last):
            if D[i, j] == D[i, last] + D[j, last]:
                return i, j, D[i, last]
    raise ValueError('no leaf pair (i, j) has leaf %d on its path; '
                     'is the matrix additive?' % last)


def get_path(tree, i, j):
    """Return the node sequence from ``i`` to ``j`` through ``tree['adj']``.

    Performs a depth-first search with an explicit stack of partial paths,
    expanding neighbours in adjacency-list order and skipping nodes already
    seen.  Returns the first path whose tip is ``j`` (``[i]`` when
    ``i == j``), or an empty list when ``j`` is not reachable.
    """
    adjacency = tree['adj']
    seen = {i}
    stack = [[i]]
    while stack:
        route = stack.pop()
        tip = route[-1]
        seen.add(tip)
        if tip == j:
            return route
        stack.extend(route + [nb] for nb in adjacency[tip] if nb not in seen)
    return []


def get_position_to_add(tree, path, x):
    """Locate the point at distance ``x`` from ``path[0]`` along ``path``.

    Walks consecutive edges of ``path`` (weights read from ``tree['w']`` under
    the sorted node pair), accumulating their lengths.

    Args:
        tree: dict with a ``'w'`` mapping of sorted node pairs to weights.
        path: list of nodes, e.g. as returned by ``get_path``.
        x: distance from ``path[0]`` to the wanted point.

    Returns:
        ``(src, dest, weight_src, weight_dest)``: the point lies on edge
        src-dest, ``weight_src`` away from ``src`` and ``weight_dest`` away
        from ``dest``.  ``weight_src == 0`` means the point *is* node ``src``.
        When ``x`` equals the full path length, the point is the last node and
        ``(path[-1], path[-2], 0, last_edge_weight)`` is returned.

    Raises:
        ValueError: if ``path`` has fewer than two nodes or ``x`` is outside
            ``[0, path length]``.
    """
    if len(path) < 2:
        raise ValueError('path must contain at least two nodes, got %r' % (path,))
    if x < 0:
        raise ValueError('x must be non-negative, got %r' % (x,))

    dist = 0
    for src, dest in zip(path, path[1:]):
        edge = tree['w'][tuple(sorted([src, dest]))]
        if dist + edge > x:
            return src, dest, x - dist, dist + edge - x
        dist += edge

    if x == dist:
        # Point coincides with the final node of the path.
        return path[-1], path[-2], 0, edge
    raise ValueError('x=%r exceeds the path length %r' % (x, dist))
        
def _place_on_path(tree, x, i, j, new_node_name):
    """Return the node at distance ``x`` from ``i`` on the i-j path, creating it if needed.

    If the point lies strictly inside an edge ``src``-``dst``, that edge is
    replaced by ``src``-``new_node_name``-``dst`` whose two weights sum to the
    old one.  If the point coincides with an existing node, the tree is left
    unchanged and no node name is consumed.

    Returns:
        ``(tree, attach_node, next_free_name)``.

    Raises:
        ValueError: if ``x`` places the point before ``i``.
    """
    path = get_path(tree, i, j)
    src, dst, weight_src, weight_dest = get_position_to_add(tree, path, x)
    if weight_src < 0:
        raise ValueError('x must be non-negative, got %r' % (x,))
    if weight_src == 0:
        # The point is the existing node ``src``: attach there directly.
        return tree, src, new_node_name

    mid = new_node_name
    tree['adj'][src].remove(dst)
    tree['adj'][dst].remove(src)
    tree['adj'][src].append(mid)
    tree['adj'][dst].append(mid)
    tree['adj'][mid] = [src, dst]
    # Weights are stored under the sorted node pair, matching the lookups in
    # get_position_to_add and print_tree.
    tree['w'].pop(tuple(sorted([src, dst])))
    tree['w'][tuple(sorted([src, mid]))] = weight_src
    tree['w'][tuple(sorted([dst, mid]))] = weight_dest
    return tree, mid, new_node_name + 1


def add_new_node(tree, x, i, j, new_node_name):
    """Subdivide the i-j path at distance ``x`` from ``i`` with a new internal node.

    Args:
        tree: ``{'adj': ..., 'w': ...}`` tree, modified in place.
        x: distance from leaf ``i`` to the insertion point.
        i, j: endpoints of the path on which the point lies.
        new_node_name: label to give the new internal node.

    Returns:
        ``(tree, next_free_name)``.  When the point already is a node, nothing
        is inserted and ``new_node_name`` is returned unchanged.
    """
    tree, _, next_free_name = _place_on_path(tree, x, i, j, new_node_name)
    return tree, next_free_name


def additive_phylogeny_recursive(D, n, new_node_name):
    """Build the weighted tree for leaves ``0 .. n-1`` of the matrix ``D``.

    Standard AdditivePhylogeny recursion: trim the limb of leaf ``n-1``, find
    leaves ``i, j`` whose path contains the trimmed leaf, solve the
    ``(n-1) x (n-1)`` problem, then hang leaf ``n-1`` off the point at
    distance ``x`` from ``i`` on the i-j path.

    Note: ``D`` is modified in place (its last row/column is trimmed at every
    level); pass a copy if the caller still needs the original values.

    Args:
        D: ``n x n`` numpy distance matrix.
        n: number of leaves handled at this level (equals ``D.shape[0]``).
        new_node_name: first unused label for internal nodes.

    Returns:
        ``(tree, next_free_name)`` where ``tree`` is
        ``{'adj': defaultdict(list), 'w': defaultdict(int)}`` and weights are
        keyed by the sorted node pair.
    """
    if n == 2:
        tree = {'adj': defaultdict(list), 'w': defaultdict(int)}
        tree['adj'][0].append(1)
        tree['adj'][1].append(0)
        tree['w'][(0, 1)] = D[0, 1]
        return tree, new_node_name

    last = n - 1
    limb_len = get_limb_len(D, last)
    # Turn leaf ``last`` into a bald point on the tree of the remaining leaves.
    for k in range(last):
        D[k, last] -= limb_len
        D[last, k] = D[k, last]

    i, j, x = get_leaves_cond(D)
    tree, free_name = additive_phylogeny_recursive(D[:-1, :-1], last, new_node_name)
    # The attachment point may be a brand-new node or an existing one.
    tree, attach_node, free_name = _place_on_path(tree, x, i, j, free_name)

    tree['adj'][attach_node].append(last)
    tree['adj'][last].append(attach_node)
    tree['w'][tuple(sorted([last, attach_node]))] = limb_len
    return tree, free_name
    

def additive_phylogeny(D):
    """Reconstruct the weighted tree that fits the additive distance matrix ``D``.

    The recursion trims limb lengths in place, so it is run on a private copy
    and the caller's matrix is left untouched.  Leaves keep their row indices
    ``0 .. n-1``; internal nodes are labelled ``n, n+1, ...``.

    Args:
        D: square distance matrix (numpy array or nested sequence).

    Returns:
        ``(tree, next_free_name)`` as produced by
        ``additive_phylogeny_recursive``.

    Raises:
        ValueError: if ``D`` is not a square matrix with at least two leaves.
    """
    D = np.array(D, copy=True)
    if D.ndim != 2 or D.shape[0] != D.shape[1]:
        raise ValueError('D must be a square matrix, got shape %r' % (D.shape,))
    n = D.shape[0]
    if n < 2:
        raise ValueError('need at least two leaves, got %d' % n)
    return additive_phylogeny_recursive(D, n, n)


def print_tree(tree):
    """Print the tree as an edge list, one ``u->v:weight`` line per direction.

    Nodes are visited in ascending order and each node's neighbours in
    ascending order, so every undirected edge is printed twice.  Weights are
    looked up in ``tree['w']`` under the sorted node pair.
    """
    adjacency, weights = tree['adj'], tree['w']
    for node in sorted(adjacency):
        for neighbour in sorted(adjacency[node]):
            key = tuple(sorted([node, neighbour]))
            print('%s->%s:%s' % (node, neighbour, weights[key]))

def main(fname='rosalind_ba7c.txt'):
    """Read the distance matrix in ``fname``, build its tree and print the edge list."""
    print_tree(additive_phylogeny(read_data(fname))[0])

In [223]:
# Run on the bundled sample dataset.
main(fname='sample_data/ba7c/sample.txt')

0->4:11
1->4:2
2->5:6
3->5:7
4->0:11
4->1:2
4->5:4
5->2:6
5->3:7
5->4:4


In [224]:
# Run on the larger "extra" dataset; the reference answer is printed in a later cell.
main('sample_data/ba7c/input.txt')

0->29:745
1->36:156
2->32:788
3->30:409
4->31:280
5->35:125
6->33:492
7->34:657
8->53:311
9->43:820
10->37:280
11->38:723
12->39:417
13->40:864
14->41:236
15->42:89
16->51:713
17->44:445
18->45:87
19->46:441
20->47:783
21->48:348
22->49:922
23->50:662
24->51:375
25->52:718
26->55:868
27->54:841
28->55:890
29->0:745
29->39:323
29->45:355
30->3:409
30->39:64
30->41:698
31->4:280
31->41:656
31->47:110
32->2:788
32->34:527
32->50:276
33->6:492
33->38:884
33->45:809
34->7:657
34->32:527
34->47:866
35->5:125
35->44:230
35->46:481
36->1:156
36->38:87
36->42:247
37->10:280
37->40:464
37->54:112
38->11:723
38->33:884
38->36:87
39->12:417
39->29:323
39->30:64
40->13:864
40->37:464
40->46:522
41->14:236
41->30:698
41->31:656
42->15:89
42->36:247
42->43:982
43->9:820
43->42:982
43->52:170
44->17:445
44->35:230
44->53:416
45->18:87
45->29:355
45->33:809
46->19:441
46->35:481
46->40:522
47->20:783
47->31:110
47->34:866
48->21:348
48->49:683
48->51:381
49->22:922
49->48:683
49->52:952
50->23:662
50->

In [226]:
# Reference answer for the extra dataset. It differs from the output above at
# least in internal-node labels (e.g. leaf 8 hangs off node 31 here but off
# node 53 above); whether the two trees are isomorphic is not checked here.
with open('sample_data/ba7c/output.txt', 'r') as f:
    out = list(map(str.rstrip, f.readlines()))
for elem in out:
    print(elem)

0->55:745
1->48:156
2->52:788
3->54:409
4->53:280
5->49:125
6->51:492
7->50:657
8->31:311
9->41:820
10->47:280
11->46:723
12->45:417
13->44:864
14->43:236
15->42:89
16->33:713
17->40:445
18->39:87
19->38:441
20->37:783
21->36:348
22->35:922
23->34:662
24->33:375
25->32:718
26->29:868
27->30:841
28->29:890
29->28:890
29->31:965
29->26:868
30->27:841
30->34:687
30->47:112
31->8:311
31->40:416
31->29:965
32->41:170
32->25:718
32->35:952
33->24:375
33->16:713
33->36:381
34->52:276
34->23:662
34->30:687
35->32:952
35->22:922
35->36:683
36->33:381
36->21:348
36->35:683
37->20:783
37->53:110
37->50:866
38->19:441
38->44:522
38->49:481
39->18:87
39->51:809
39->55:355
40->17:445
40->31:416
40->49:230
41->32:170
41->42:982
41->9:820
42->41:982
42->15:89
42->48:247
43->53:656
43->54:698
43->14:236
44->38:522
44->13:864
44->47:464
45->12:417
45->54:64
45->55:323
46->11:723
46->51:884
46->48:87
47->10:280
47->30:112
47->44:464
48->42:247
48->46:87
48->1:156
49->5:125
49->38:481
49->40:230
50->52:52

In [212]:
# Driver for the downloaded Rosalind dataset: uncomment to run main() on its
# default input file, 'rosalind_ba7c.txt'.
# main()