In [239]:
from math import inf

In [240]:
# Nucleotide alphabet, grouped by chemical class.
purines = ['A', 'G']
pyrimidines = ['C', 'T']
nucleotides = purines + pyrimidines

# Substitution cost matrix, indexed as c[nt1][nt2]:
#   0.0 for no change,
#   1.0 for a transition (both nucleotides belong to the same class),
#   2.5 for a transversion (the nucleotides belong to different classes).
c = {}
for nt1 in nucleotides:
    c[nt1] = {}
    for nt2 in nucleotides:
        if nt1 == nt2:
            cost = 0.0
        elif (nt1 in purines) == (nt2 in purines):
            # Every nucleotide is either a purine or a pyrimidine, so equal
            # purine membership means the two share a class.
            cost = 1.0
        else:
            cost = 2.5
        c[nt1][nt2] = cost

In [213]:
# Spot-check one entry of the cost matrix: G and A are both purines,
# so this lookup should show the transition cost.
c['G']['A']

1.0

In [241]:
def init_tree(num_tips):
    """
    Build an unlinked tree dictionary for `num_tips` tips.

    The result maps node names to per-node dictionaries. It contains one
    'root' node, the tips 'tip1' .. 'tip<num_tips>' and the internal nodes
    'int1' .. 'int<num_tips - 2>', in that order. Every node starts with
    its 'left' and 'right' child references and its cost for each entry of
    the module-level `nucleotides` set to None, so there is no tree
    structure yet.
    """
    node_names = ['root']
    node_names.extend('tip{}'.format(x) for x in range(1, num_tips + 1))
    node_names.extend('int{}'.format(x) for x in range(1, num_tips - 1))

    # Child links first, then one cost slot per nucleotide.
    fields = ['left', 'right'] + list(nucleotides)

    # dict.fromkeys gives each node its own dictionary with every field None.
    return {name: dict.fromkeys(fields) for name in node_names}

In [242]:
def init_tip(tree, tip_name, observed_nt):
    """
    Initialise a tip of a tree with the correct costs.

    At a tip node only the observed nucleotide is possible:
    the cost of the observed nucleotide is zero, and
    the cost of any other nucleotide is infinity.

    Parameters:
        tree: tree dictionary mapping node names to node dictionaries.
        tip_name: key of the tip node in `tree` to initialise.
        observed_nt: the nucleotide observed at this tip.

    The node dictionary is modified in place; nothing is returned.
    If `observed_nt` equals none of the entries in the module-level
    `nucleotides`, every cost at the tip ends up infinite.
    """
    for nt in nucleotides:
        # Compare by value, not identity: `is` only appeared to work because
        # CPython happens to share short string objects, which is an
        # implementation detail and not guaranteed.
        if nt == observed_nt:
            tree[tip_name][nt] = 0
        else:
            tree[tip_name][nt] = inf

In [243]:
# One observed nucleotide per tip; character i is later assigned to tip<i+1>.
tip_seq = 'CACAG'
# Allocate the (still unlinked) node dictionaries for that many tips.
my_tree = init_tree(len(tip_seq))

In [191]:
# Dump every field of tip1 (child links and per-nucleotide costs).
# NOTE(review): the captured output shows initialised costs, so this cell was
# run after the tips had been initialised; on a top-to-bottom run at this
# position the costs would still be None.
for k, v in my_tree['tip1'].items():
    print(k, v)

left None
right None
A inf
G inf
C 0
T inf


In [244]:
# Tree topology as parent -> (left child, right child).
topology = {
    'root': ('int1', 'int3'),
    'int1': ('tip1', 'tip2'),
    'int3': ('tip3', 'int2'),
    'int2': ('tip4', 'tip5'),
}

# Wire the child references into the pre-allocated node dictionaries.
for parent, (left_child, right_child) in topology.items():
    my_tree[parent]['left'] = left_child
    my_tree[parent]['right'] = right_child

In [245]:
# Follow int1's 'left' link and show the node it points to, to confirm the
# link resolves to a node in the tree.
print(my_tree[my_tree['int1']['left']])

{'left': None, 'right': None, 'A': None, 'G': None, 'C': None, 'T': None}


In [246]:
# Assign each observed nucleotide to its tip (tips are numbered from 1)
# and show the resulting costs.
for position, nt in enumerate(tip_seq, start=1):
    tip_name = 'tip{}'.format(position)
    init_tip(my_tree, tip_name, nt)
    print(tip_name, my_tree[tip_name])


tip1 {'left': None, 'right': None, 'A': inf, 'G': inf, 'C': 0, 'T': inf}
tip2 {'left': None, 'right': None, 'A': 0, 'G': inf, 'C': inf, 'T': inf}
tip3 {'left': None, 'right': None, 'A': inf, 'G': inf, 'C': 0, 'T': inf}
tip4 {'left': None, 'right': None, 'A': 0, 'G': inf, 'C': inf, 'T': inf}
tip5 {'left': None, 'right': None, 'A': inf, 'G': 0, 'C': inf, 'T': inf}


In [225]:
# Scratch check of the substitution costs into each parent nucleotide nt1.
# Both lists are built from the same matrix column, so they are equal.
for nt1 in nucleotides:
    left_costs = [c[nt2][nt1] for nt2 in nucleotides]
    right_costs = list(left_costs)
# The print sits after the loop, so only the lists for the last
# nucleotide are shown.
print(left_costs, " ", right_costs)

[2.5, 2.5, 1.0, 0.0]   [2.5, 2.5, 1.0, 0.0]


In [247]:
def sankoff_calculate(c_matrix, tree, node_name):
    """
    For the specified node of the tree, calculate the minimum possible cost
    for each nucleotide.

    For every candidate state at the node, the cost is the sum over both
    children of the cheapest combination of "substitution cost into that
    state" plus "cost already stored at the child for its own state".

    Parameters:
        c_matrix: nested dictionary of substitution costs, indexed as
            c_matrix[child_state][parent_state]. Its keys define the set
            of states that are evaluated.
        tree: tree dictionary mapping node names to node dictionaries.
        node_name: key of the node to compute. Its 'left' and 'right'
            children must already hold a cost for every state.

    The node dictionary is modified in place; nothing is returned.
    """
    node = tree[node_name]
    # The children do not depend on the state being evaluated, so look them
    # up once instead of on every iteration.
    left_node = tree[node['left']]
    right_node = tree[node['right']]

    # Use the matrix that was passed in (the previous version silently read
    # the global `c` and ignored this argument) and take the state set from
    # it, so the function depends only on its parameters.
    states = list(c_matrix)

    for parent_nt in states:
        left_best = min(c_matrix[child_nt][parent_nt] + left_node[child_nt]
                        for child_nt in states)
        right_best = min(c_matrix[child_nt][parent_nt] + right_node[child_nt]
                         for child_nt in states)
        node[parent_nt] = left_best + right_best



In [248]:
def sankoff_traverse(c_matrix, tree, node_name):
    """
    Recursively compute the costs for `node_name` and everything below it.

    The left child is handled first, then the right child, and only then
    the node itself. A child whose 'A' cost is already set is treated as
    calculated and is not visited again. Tip costs must therefore be
    initialised before the traversal starts. Progress is reported with
    print statements.
    """
    print("Examining node {}".format(node_name))

    for side in ('left', 'right'):
        child_name = tree[node_name][side]
        # A cost of None for 'A' marks a node that has not been calculated.
        if tree[child_name]['A'] is None:
            sankoff_traverse(c_matrix, tree, child_name)
        else:
            print("Node {} values already known".format(child_name))

    print("Calculating node {}".format(node_name))
    sankoff_calculate(c_matrix, tree, node_name)

In [249]:
sankoff_traverse(c, my_tree, 'root')

Examining node root
Examining node int1
Node tip1 values already known
Node tip2 values already known
Calculating node int1
Examining node int3
Node tip3 values already known
Examining node int2
Node tip4 values already known
Node tip5 values already known
Calculating node int2
Calculating node int3
Calculating node root


In [349]:
# Report the cheapest root state(s) after the traversal has filled in the root.
# The costs are read straight from the tree so that this cell runs on a fresh
# kernel; the previous version read `min_cost_dict` before defining it.
root_costs = {nt: my_tree['root'][nt] for nt in nucleotides}

min_value = min(root_costs.values())
# Keep every nucleotide that ties for the minimum, in `nucleotides` order.
min_cost_dict = {nt: cost for nt, cost in root_costs.items() if cost == min_value}

min_cost = min_value
min_nts = list(min_cost_dict)

print("Minimum cost is {} for {}".format(list(min_cost_dict.values()), " and ".join(min_cost_dict.keys())))

Minimum cost is [6.0, 6.0] for A and C
