In [1]:
from collections import defaultdict


def find_closest_cluster(matrix):
    """Return the pair of distinct keys holding the smallest entry of ``matrix``.

    Args:
        matrix: dict of dicts where ``matrix[i][j]`` is the value for the
            pair ``(i, j)``.

    Returns:
        The ``(i, j)`` tuple (with ``i != j``) whose entry is minimal. Ties go
        to the first pair met in iteration order. ``(-1, -1)`` is returned
        when the matrix has no off-diagonal entry at all.
    """
    # Start from +inf rather than a fixed large number so that arbitrarily
    # large entries can still be selected.
    cur_min = float('inf')
    cur_cluster = (-1, -1)
    for i, row in matrix.items():
        for j, value in row.items():
            # A node cannot be joined with itself: skip the diagonal, which
            # would otherwise win whenever all other entries are >= it.
            if i == j:
                continue
            if value < cur_min:
                cur_min = value
                cur_cluster = (i, j)
    return cur_cluster


def get_star_matrix(matrix, n):
    """Compute the neighbour-joining "star" matrix and the per-row totals.

    Args:
        matrix: dict of dicts, ``matrix[a][b]`` being the distance from
            ``a`` to ``b``.
        n: value used in the ``(n - 2)`` factor of the formula (the caller
            passes the current number of nodes).

    Returns:
        ``(star, row_totals)`` where ``row_totals[a]`` is the sum of row
        ``a`` and ``star[a][b]`` is ``0`` on the diagonal and
        ``(n - 2) * matrix[a][b] - row_totals[a] - row_totals[b]`` elsewhere.
    """
    # Sum of every row; a defaultdict so that unknown keys read as 0.
    row_totals = defaultdict(int)
    for node, row in matrix.items():
        for dist in row.values():
            row_totals[node] += dist

    star = {
        a: {
            b: 0 if a == b else (n - 2) * dist - row_totals[a] - row_totals[b]
            for b, dist in row.items()
        }
        for a, row in matrix.items()
    }
    return star, row_totals


def neighbour_joining(matrix, n, inserting):
    """Build an unrooted tree from a distance matrix by neighbour joining.

    Args:
        matrix: dict of dicts, ``matrix[a][b]`` being the distance between
            nodes ``a`` and ``b``. It is left untouched: the algorithm runs
            on a private copy.
        n: number of nodes currently in ``matrix`` (not validated against
            the actual number of rows).
        inserting: label given to the first internal node that is created;
            each further internal node gets the next integer.

    Returns:
        ``(edges, weights)``. ``edges[a]`` is the list of nodes adjacent to
        ``a``; ``weights[(a, b)]`` is the length of the edge ``a``-``b`` and
        is stored for both orientations. For ``n <= 1`` the result is a tree
        without edges: every key of ``matrix`` maps to an empty list and
        ``weights`` is empty.
    """
    # The recursive worker adds and deletes rows/columns, so hand it a copy
    # (outer dict and every row) instead of the caller's own matrix.
    working = {node: dict(row) for node, row in matrix.items()}
    return _neighbour_joining_in_place(working, n, inserting)


def _neighbour_joining_in_place(matrix, n, inserting):
    """Recursive worker of ``neighbour_joining``; consumes ``matrix`` in place."""
    if n <= 1:
        # Nothing to join: zero or one node, hence no edge.
        return {node: [] for node in matrix}, {}
    if n == 2:
        # Two nodes left: connect them with an edge of their distance.
        edges = {}
        weights = {}
        for i in matrix:
            for j in matrix[i]:
                if j != i:
                    edges[i] = [j]
                    edges[j] = [i]
                    weights[(i, j)], weights[(j, i)] = matrix[i][j], matrix[i][j]
        return edges, weights
    star_matrix, total_distances = get_star_matrix(matrix, n)
    i, j = find_closest_cluster(star_matrix)
    # Limb lengths of the two joined nodes towards their new parent.
    delta = (total_distances[i] - total_distances[j]) / (n - 2)
    limb_i = (matrix[i][j] + delta) / 2
    limb_j = (matrix[i][j] - delta) / 2
    # Distance from every node k to the new node m:
    #   D[k][m] = D[m][k] = (D[k][i] + D[k][j] - D[i][j]) / 2
    matrix[inserting] = {inserting: 0}
    for element in matrix:
        if element == inserting:
            continue
        matrix[element][inserting] = (matrix[element][i] + matrix[element][j] - matrix[i][j]) / 2
        matrix[inserting][element] = matrix[element][inserting]
    # Remove rows/columns i and j, now represented by the new node.
    del matrix[i], matrix[j]
    for element in matrix:
        del matrix[element][i], matrix[element][j]

    edges, weights = _neighbour_joining_in_place(matrix, n - 1, inserting + 1)

    # Re-attach the two joined nodes as neighbours of the new node.
    edges[i] = [inserting]
    edges[inserting].append(i)
    edges[j] = [inserting]
    edges[inserting].append(j)
    weights[(i, inserting)], weights[(inserting, i)] = limb_i, limb_i
    weights[(j, inserting)], weights[(inserting, j)] = limb_j, limb_j
    return edges, weights

## Newick format

In [2]:
def newick_dfs(edges, visited, names, current_node, node_from):
    """Render the subtree below ``current_node`` as a Newick-style string.

    Args:
        edges: mapping node -> flat list ``[weight, neighbour, weight,
            neighbour, ...]``.
        visited: list indexed by node; the current node's slot is set to
            ``True``. Only its length is read here: ``len(visited) - 1`` is
            the one node allowed to keep node 0 among its neighbours
            (presumably the artificial root added by the caller).
        names: mapping node -> label, used for nodes without children.
        current_node: node being rendered.
        node_from: node we arrived from; its entry is dropped from the
            neighbour list.

    Returns:
        ``names[current_node]`` for a node with no remaining neighbours (or
        with no entry in ``edges``), otherwise
        ``'(left:w_left,right:w_right)'`` built from the first two remaining
        neighbours, with weights printed to two decimals.
    """
    visited[current_node] = True
    if current_node not in edges:
        return names[current_node]

    adjacency = edges[current_node]
    last_index = len(visited) - 1
    # Drop at most one (weight, neighbour) pair: the one leading back to the
    # parent, or the one leading to node 0 when we are not at ``last_index``.
    for pos in range(1, len(adjacency), 2):
        neighbour = adjacency[pos]
        if neighbour == node_from or (neighbour == 0 and current_node != last_index):
            adjacency = adjacency[:pos - 1] + adjacency[pos + 1:]
            break

    if not adjacency:
        return names[current_node]

    left_part = newick_dfs(edges, visited, names, adjacency[1], current_node)
    right_part = newick_dfs(edges, visited, names, adjacency[3], current_node)
    return '({}:{:.2f},{}:{:.2f})'.format(left_part, adjacency[0], right_part, adjacency[2])

## Testing

In [3]:
from collections import defaultdict

### NJ. Test 1

In [6]:
# Pairwise distances between the four taxa A-D (row/column index = taxon id).
matrix = [[0, 16, 16, 10], [16, 0, 8, 8], [16, 8, 0, 4], [10, 8, 4, 0]]
names = {0: 'A', 1: 'B', 2: 'C', 3: 'D'}
n = 4

# neighbour_joining works on a dict of dicts keyed by node id.
dict_matrix = {v: dict(enumerate(matrix[v])) for v in range(n)}

edges, weights = neighbour_joining(dict_matrix, n, n)
graph = defaultdict(list)

# One extra node id, used as the root: it is placed in the middle of the
# edge leaving node 0, so each half of that edge gets half of its weight.
vertices_count = max(edges.keys()) + 1
for v in edges:
    for w in edges[v]:
        if v == 0:
            graph[vertices_count].extend(
                (weights[(v, w)] / 2, v, weights[(v, w)] / 2, w)
            )
        else:
            graph[v].extend((weights[(v, w)], w))

print(newick_dfs(graph, [False] * (vertices_count + 1), names, vertices_count, -1))

(A:5.00,((B:5.00,C:3.00):2.00,D:0.00):5.00)


### NJ. Test 2

In [7]:
# Pairwise distances between the six taxa A-F (row/column index = taxon id).
matrix = [
    [0, 5, 4, 7, 6, 8],
    [5, 0, 7, 10, 9, 11],
    [4, 7, 0, 7, 6, 8],
    [7, 10, 7, 0, 5, 9],
    [6, 9, 6, 5, 0, 8],
    [8, 11, 8, 9, 8, 0],
]
names = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}
n = 6

# neighbour_joining works on a dict of dicts keyed by node id.
dict_matrix = {v: dict(enumerate(matrix[v])) for v in range(n)}

edges, weights = neighbour_joining(dict_matrix, n, n)
graph = defaultdict(list)

# One extra node id, used as the root: it is placed in the middle of the
# edge leaving node 0, so each half of that edge gets half of its weight.
vertices_count = max(edges.keys()) + 1
for v in edges:
    for w in edges[v]:
        if v == 0:
            graph[vertices_count].extend(
                (weights[(v, w)] / 2, v, weights[(v, w)] / 2, w)
            )
        else:
            graph[v].extend((weights[(v, w)], w))

print(newick_dfs(graph, [False] * (vertices_count + 1), names, vertices_count, -1))

(A:0.50,((((D:3.00,E:2.00):1.00,F:5.00):1.00,C:2.00):1.00,B:4.00):0.50)
