In [143]:
from heapq import heappush, heappop, heapify
from collections import defaultdict
import copy

class UnionFind:
    '''A Union find data structure'''
    def __init__(self, nodes):
        self.parent = {}
        self.size = {}
        for i in nodes:
            self.parent[i] = i # Initilize root for subsets
            self.size[self.parent[i]] = 1 # Initialze size of subsets
            
    
    def find(self, node):
        if self.parent[node] == node:
            return node
        else:
            return self.find(self.parent[node])
    

    def union(self, node1, node2):
        '''Merge small subset to large subset'''
        root1 = self.find(node1)
        root2 = self.find(node2)
        if root1 == root2:
            return None
        if self.size[root1] > self.size[root2]:
            self.parent[root2] = root1
            self.size[root1] += self.size[root2]
        else:
            self.parent[root1] = root2
            self.size[root2] += self.size[root1]    

class UnionFindPathCompression:
    '''A Union find data structure, apply path compression'''
    def __init__(self, nodes):
        self.parent = {}
        self.size = {}
        for i in nodes:
            self.parent[i] = i # Initilize root for subsets
            self.size[self.parent[i]] = 1 # Initialze size of subsets
              
    def find(self, node):
        node_cpy = node
        while self.parent[node] != node:
            node = self.parent[node]
        self.parent[node_cpy] = node
        return node
        
    
    def union(self, node1, node2):
        '''Merge small subset to large subset'''
        root1 = self.find(node1)
        root2 = self.find(node2)
        if root1 == root2:
            return None
        if self.size[root1] > self.size[root2]:
            self.parent[root2] = root1
            self.size[root1] += self.size[root2]
        else:
            self.parent[root1] = root2
            self.size[root2] += self.size[root1]    

            
class MSTsolver:
    def __init__(self, path):
        self.path = path
        self.graph = defaultdict(list)
        self.MST = []
        
    def make_graph(self):
        file = open(path, 'r')
        for line in file.readlines()[1:]:
            node1, node2, edgeCost = line.split(" ")
            self.graph[int(node1)].append((int(node2), int(edgeCost)))
            self.graph[int(node2)].append((int(node1), int(edgeCost)))
            
    def solver_prim(self):
        '''Prim's algorithm straightforward implementation'''
        self.MST = []
        totalCost = 0
        oset = list(self.graph.keys())[1:] # Unprocessed nodes
        vset = [list(self.graph.keys())[0]] # Processed nodes
        while oset:
            cost_min = float('inf')
            for n_v in vset:
                for n_g, e in self.graph[n_v]:
                    if n_g not in vset:
                        if e < cost_min:
                            cost_min = e
                            e_selected = (n_v, n_g)
                            
            self.MST.append(e_selected)
            totalCost += cost_min
            vset.append(e_selected[1])            
            oset.remove(e_selected[1])
            
        return totalCost
            
    def solver_prim_heap(self):
        '''
        Prim's algorithm implemented based on heap
        '''
        oset = list(self.graph.keys())
        st_node = oset[0] #starting node
        heap = [(0, st_node,st_node)] # [(edgeCost, start, end)]
        totalCost = 0
        while oset:
            edgeCost, node1, node2 = heappop(heap)
            if node2 in oset: # 这个判断至关重要，防止当前节点链接回已经处理过的节点
                self.MST.append((node1, node2)) # edge (tail --> head)
                totalCost += edgeCost
                for v, l in self.graph[node2]:
                    if v in oset: # 防止当前节点链接回已处理过的节点
                        heappush(heap, (l, node2, v))
                oset.remove(node2)
        return totalCost
    
    def solver_kruskal(self):
        '''Kruskal's algorithm straightforward implementation'''
        edges = []
        # Heap sort on edges
        for key in self.graph.keys():
            for ele in self.graph[key]:
                if ele not in edges:
                    edges.append((ele[1], min(key, ele[0]), max(key, ele[0]))) #(edges, node_small, node_large(index))
        edges = list(set(edges))
        heapify(edges)
        edges_sort = [heappop(edges) for i in range(len(edges))]
        
        # Start merge process
        self.MST = defaultdict(list)
        totalCost = 0
        # Apply DFS to check if there is cycle in results so far
        for e in edges_sort:
            cost, node1, node2 = e
            stack = [node1]
            visited = []
            FLAG = 0  # Indicating cycle happens
            while stack:
                if node1 not in self.MST or node2 not in self.MST:
                    break
                node_pop = stack.pop()
                if node_pop == node2: # Indicating cycle happens
                    FLAG = 1
                    break
                visited.append(node_pop)
                for e_dfs in self.MST[node_pop]:
                    if e_dfs not in visited:
                        stack.append(e_dfs)
            if FLAG == 0:         
                self.MST[node1].append(node2)
                self.MST[node2].append(node1)
                totalCost += cost
        
        return totalCost
    
    def solver_kruskal_UF(self):
        '''Kruskal's algorithm implementation based on union-find'''
        edges = []
        nodes = [key for key in self.graph.keys()]
        # Heap sort on edges
        for key in nodes:
            for ele in self.graph[key]:
                if ele not in edges:
                    edges.append((ele[1], min(key, ele[0]), max(key, ele[0]))) #(edges, node_small, node_large(index))
        edges = list(set(edges))
        heapify(edges)
        edges_sort = [heappop(edges) for i in range(len(edges))]
        
        # Start merge process
        uf = UnionFindPathCompression(nodes)
        self.MST = []
        totalCost = 0
        for e in edges_sort:
            cost, node1, node2 = e
            if uf.find(node1) != uf.find(node2):
                uf.union(node1, node2)
                self.MST.append((node1, node2))
                totalCost += cost
        return totalCost

In [155]:
import time

path = 'MST_test_1.txt'
mst = MSTsolver(path)
mst.make_graph()

st0 = time.time()
mst.solver_prim()

st1 = time.time()
mst.solver_prim_heap()

st2 = time.time()
mst.solver_kruskal()

st3 = time.time()
mst.solver_kruskal_UF()
end = time.time()

print("Prim running time: %fs" %(st1 - st0))
print("Prim_heap runing time: %fs" %(st2 - st1))
print("Kruskal runing time: %fs" %(st3 - st2))
print("Kruskal union find runing time: %fs" %(end - st3))

Prim running time: 0.000000s
Prim_heap runing time: 0.000971s
Kruskal runing time: 0.000000s
Kruskal union find runing time: 0.000000s


## Notes:

这次代码以MST问题为情景研究greedy算法。MST问题即求解联通Graph所有节点且保证Graph中无环的最短路径问题：

### Prim算法
1. Prim算法与Dijkstra算法十分相似。Prim先选取任意起始节点作为辐射起点，然后在每次循环中选取已有节点可以reach到的edge中最短的来加入辐射范围，要求最短且无环，直至所有节点均被辐射

2. 与Dijkstra算法相似，Prim算法可用heap进行加速，因为有重复求最小值（最短路径）的需求

### Kruskal算法
1. Kruskal算法不固定起点进行辐射。它按长度从小到大逐一添加edge直到全图联通

2. 算法实现的关键之处在于每次添加新edge, eg.(v1, v2) 时进行的环路检测，即判断是否添加完edge会形成环， naive的实现方式是利用DFS或BFS在已添加的edge中进行遍历。如果发现已有(v1, v2)路径存在则再添加(v1, v2)就会造成环路，所以应舍弃当前edge。

3. Kruskal判断环路的部分可以用union-find (disjoint set) 进行加速。

## References:

1. Union-find https://blog.csdn.net/dm_vincent/article/details/7655764
2. Union-find path compression https://www.geeksforgeeks.org/union-find-algorithm-set-2-union-by-rank/
