In [2]:
import sys
import numpy as np
from copy import deepcopy

In [3]:
class UPGMA:
    """Build a rooted tree from a distance matrix by UPGMA clustering.

    Constructing an instance runs the whole pipeline: read the matrix from
    a file, cluster it, print the weighted adjacency list and save it.

    The tree is an adjacency list indexed by node id.  Leaves keep the ids
    ``0 .. n-1``; every merge creates the next free id.  Each entry is a
    list of ``(neighbour, edge_length)`` tuples, and every edge is stored
    in both directions.
    """

    # Defaults used when the constructor / I/O helpers get no explicit path.
    DEFAULT_INPUT_PATH = '/content/drive/My Drive/Colab Notebooks/Rosalind/rosalind_ba7d.txt'
    DEFAULT_OUTPUT_PATH = 'result.txt'

    def __init__(self, inputPath=None, outputPath=None):
        """Read the matrix, run UPGMA, then print and save the tree.

        Args:
            inputPath: file holding the matrix; ``None`` selects
                ``DEFAULT_INPUT_PATH``.
            outputPath: file the edge list is written to; ``None`` selects
                ``DEFAULT_OUTPUT_PATH``.
        """
        # To read from standard input instead, use ``self._input()`` here.
        n, disMatrix = self.readFromFile(inputPath)
        adj = self.runUPGMA(disMatrix, n)
        self.printGraph(adj)
        self.saveResult(adj, outputPath)

    @staticmethod
    def _parseMatrix(lines):
        """Parse ``n`` from the first line and an n-by-n integer matrix from the next ``n`` lines."""
        n = int(lines[0])
        distMatrix = [
            [int(token) for token in lines[row + 1].split()[:n]]
            for row in range(n)
        ]
        for values in distMatrix:
            if len(values) != n:
                # The original indexing loop failed the same way on short rows.
                raise IndexError('distance matrix row has fewer than %d entries' % n)
        return n, distMatrix

    @staticmethod
    def _edgeLines(adj):
        """Yield one ``src->dst:weight`` string (3 decimals) per directed edge."""
        for src, nodes in enumerate(adj):
            for dst, weight in nodes:
                yield str(src) + '->' + str(dst) + ':' + '%0.3f' % weight

    def _input(self):
        """Read ``n`` and the distance matrix from standard input.

        Returns:
            ``(n, distMatrix)`` where ``distMatrix`` is a list of ``n``
            lists of ``n`` ints.
        """
        data = sys.stdin.read().strip().split('\n')
        return self._parseMatrix(data)

    def readFromFile(self, path=None):
        """Read ``n`` and the distance matrix from a text file.

        Args:
            path: file to read; ``None`` selects ``DEFAULT_INPUT_PATH``.

        Returns:
            ``(n, distMatrix)`` where ``distMatrix`` is a list of ``n``
            lists of ``n`` ints.
        """
        if path is None:
            path = self.DEFAULT_INPUT_PATH
        # The context manager closes the handle even if parsing fails later.
        with open(path, 'r') as f:
            data = [line.strip() for line in f]
        return self._parseMatrix(data)

    def saveResult(self, adj, path=None):
        """Write every directed edge of ``adj`` to a file, one per line.

        Args:
            adj: adjacency list of ``(neighbour, weight)`` tuples.
            path: output file; ``None`` selects ``DEFAULT_OUTPUT_PATH``.
        """
        if path is None:
            path = self.DEFAULT_OUTPUT_PATH
        # Closing the file guarantees the data is flushed to disk.
        with open(path, 'w') as f:
            f.writelines(line + '\n' for line in self._edgeLines(adj))

    def printDistMatrix(self, distMatrix):
        """Print the matrix row by row with space-separated entries."""
        for row in distMatrix:
            print(' '.join(str(value) for value in row))

    def printGraph(self, adj):
        """Print every directed edge of ``adj`` as ``src->dst:weight``."""
        for line in self._edgeLines(adj):
            print(line)

    def runUPGMA(self, disMatrix, n):
        """Cluster ``n`` leaves with UPGMA and return the weighted tree.

        Args:
            disMatrix: n-by-n matrix (nested sequences) of pairwise
                distances; the diagonal is ignored.
            n: number of leaves.

        Returns:
            Adjacency list with ``2n - 1`` entries for ``n >= 2``; entry
            ``u`` is a list of ``(v, length)`` tuples where ``length`` is
            the absolute age difference of ``u`` and ``v``.  For
            ``n <= 1`` no merge happens and every entry is empty.

        Raises:
            ValueError: if ``disMatrix`` is not of shape ``(n, n)``.
        """
        # Nothing to merge; handled first so an empty matrix is not passed
        # to np.fill_diagonal, which rejects arrays with fewer than 2 dims.
        if n <= 1:
            return [[] for _ in range(n)]

        D = np.array(disMatrix, dtype=float)
        if D.shape != (n, n):
            raise ValueError(
                'expected a %dx%d distance matrix, got shape %s' % (n, n, D.shape)
            )
        # Infinite diagonal so argmin never pairs a cluster with itself.
        np.fill_diagonal(D, np.inf)

        # clusters[k] = [tree node id, number of leaves] for row/column k of D.
        clusters = [[leaf, 1] for leaf in range(n)]
        adj = [[] for _ in range(n)]
        age = [0.0] * n

        while True:
            # Row/column of the closest pair in the flattened matrix.
            i, j = divmod(int(np.argmin(D)), len(D))
            idI, sizeI = clusters[i]
            idJ, sizeJ = clusters[j]

            newId = len(adj)
            adj.append([idI, idJ])
            adj[idI].append(newId)
            adj[idJ].append(newId)
            # The new internal node sits at half the merged distance.
            age.append(float(D[i, j]) / 2)

            if len(D) == 2:
                # The last two clusters were just joined: newId is the root.
                break

            # Distance from the merged cluster to the others: average of the
            # two old rows weighted by cluster size.
            merged = (D[i, :] * sizeI + D[j, :] * sizeJ) / (sizeI + sizeJ)
            merged = np.delete(merged, [i, j], 0)

            # Drop both old clusters, then append the merged one as the last
            # row and column (with inf on its diagonal entry).
            D = np.delete(D, [i, j], 0)
            D = np.delete(D, [i, j], 1)
            D = np.insert(D, len(D), merged, axis=0)
            merged = np.insert(merged, len(merged), np.inf, axis=0)
            D = np.insert(D, len(D) - 1, merged, axis=1)

            # Delete the higher index first so the lower one stays valid.
            for k in sorted((i, j), reverse=True):
                del clusters[k]
            clusters.append([newId, sizeI + sizeJ])

        return [
            [(v, abs(age[u] - age[v])) for v in nodes]
            for u, nodes in enumerate(adj)
        ]

# Script entry point: constructing UPGMA runs the whole pipeline from its
# constructor (read the matrix, cluster it, then print and save the tree).
if __name__ == "__main__":
    UPGMA()

0->33:399.500
1->38:427.000
2->34:405.000
3->41:469.500
4->29:394.500
5->36:416.000
6->33:399.500
7->32:399.000
8->37:421.000
9->35:412.500
10->41:469.500
11->30:396.500
12->32:399.000
13->40:459.250
14->39:437.500
15->31:397.500
16->35:412.500
17->28:392.500
18->48:571.500
19->39:437.500
20->29:394.500
21->36:416.000
22->34:405.000
23->30:396.500
24->37:421.000
25->31:397.500
26->38:427.000
27->28:392.500
28->17:392.500
28->27:392.500
28->40:66.750
29->4:394.500
29->20:394.500
29->48:177.000
30->11:396.500
30->23:396.500
30->43:103.375
31->15:397.500
31->25:397.500
31->53:213.913
32->7:399.000
32->12:399.000
32->45:120.250
33->0:399.500
33->6:399.500
33->42:88.125
34->2:405.000
34->22:405.000
34->43:94.875
35->9:412.500
35->16:412.500
35->42:75.125
36->5:416.000
36->21:416.000
36->46:129.625
37->8:421.000
37->24:421.000
37->46:124.625
38->1:427.000
38->26:427.000
38->44:78.375
39->14:437.500
39->19:437.500
39->45:81.750
40->13:459.250
40->28:66.750
40->50:119.604
41->3:469.500
41->10: