# Minimum Spanning Tree

## Kruskal's Algorithm:
1. First, sort the edges in ascending (non-decreasing) order of weight.
2. Then take the lightest remaining edge and add it to the MST, unless it would create a cycle with the edges already chosen, in which case reject it.
3. Repeat step 2 until the MST has |V| - 1 edges, i.e. until it spans all the vertices.

In [2]:
class Graph:
    """Undirected weighted graph stored as an edge list, with Kruskal's MST.

    Vertices are the integers ``0 .. V-1``; each edge is stored as ``[u, v, w]``.
    """

    def __init__(self, vertex):
        # Number of vertices; vertices are labelled 0 .. vertex-1.
        self.V = vertex
        # Edge list: each entry is [u, v, w].
        self.graph = []

    def add_edge(self, u, v, w):
        """Append the undirected edge (u, v) with weight w to the edge list."""
        self.graph.append([u, v, w])

    def find(self, parent, i):
        """Return the root (set representative) of the set containing i.

        Follows ``parent`` links until reaching a node that is its own parent.
        No path compression is performed; as long as sets are merged only via
        ``union`` (union by rank), tree height stays O(log V), so the
        recursion stays shallow.
        """
        if parent[i] == i:
            return i
        return self.find(parent, parent[i])

    def union(self, parent, rank, x, y):
        """Merge the sets containing x and y using union by rank."""
        xroot = self.find(parent, x)
        yroot = self.find(parent, y)

        # Attach the lower-rank tree under the root of the higher-rank tree.
        if rank[xroot] < rank[yroot]:
            parent[xroot] = yroot
        elif rank[xroot] > rank[yroot]:
            parent[yroot] = xroot
        # Equal ranks: make xroot the root and increment its rank by one.
        else:
            parent[yroot] = xroot
            rank[xroot] += 1

    def print_graph(self):
        """Debug helper: print the raw edge list."""
        print(self.graph)

    def MST_Kruskal(self):
        """Compute a minimum spanning tree with Kruskal's algorithm and print it.

        Sorts ``self.graph`` in place by weight, then greedily keeps each edge
        whose endpoints lie in different disjoint sets. Prints the chosen
        edges, their total weight, and (for debugging) the final union-find
        arrays and the sorted edge list.

        Returns:
            The list of chosen edges ``[u, v, w]`` in the order they were
            accepted. If the graph is disconnected this is a minimum spanning
            forest with fewer than V-1 edges (previously an IndexError).
        """
        # Edges accepted into the MST.
        result = []
        # Index of the next sorted edge to consider.
        i = 0
        # Number of edges accepted so far.
        e = 0

        # STEP 1: sort all edges in non-decreasing order of weight.
        self.graph = sorted(self.graph, key=lambda item: item[2])

        # One singleton set per vertex: parent[v] == v, rank[v] == 0.
        parent = list(range(self.V))
        rank = [0] * self.V

        # A spanning tree has exactly V-1 edges. Also stop once the edges run
        # out, which happens when the graph is disconnected.
        while e < self.V - 1 and i < len(self.graph):
            # STEP 2: take the next-lightest edge.
            u, v, w = self.graph[i]
            i += 1
            x = self.find(parent, u)
            y = self.find(parent, v)

            # Different roots => the edge joins two components, so no cycle.
            if x != y:
                e += 1
                result.append([u, v, w])
                self.union(parent, rank, x, y)
            # Otherwise the edge would close a cycle; discard it.

        minimumCost = 0
        print("Edges in the constructed MST")
        for u, v, weight in result:
            minimumCost += weight
            print("%d -- %d == %d" % (u, v, weight))
        print("Minimum spanning tree:", minimumCost)
        print("Rank = ", rank)
        print("Parent = ", parent)
        self.print_graph()
        return result


In [3]:
    
# Demo: 5-vertex weighted graph, edges given as (u, v, weight).
g = Graph(5)
for edge in ((0, 1, 8), (0, 2, 5), (1, 2, 9), (1, 3, 11),
             (2, 3, 15), (2, 4, 10), (3, 4, 7)):
    g.add_edge(*edge)

g.MST_Kruskal()

Edges in the constructed MST
0 -- 2 == 5
3 -- 4 == 7
0 -- 1 == 8
2 -- 4 == 10
Minimum spanning tree: 30
Rank =  [2, 0, 0, 1, 0]
Parent =  [0, 0, 0, 0, 3]
[[0, 2, 5], [3, 4, 7], [0, 1, 8], [1, 2, 9], [2, 4, 10], [1, 3, 11], [2, 3, 15]]


## Prim's Algorithm

- Start by picking any vertex r to be the root of the tree
- While the tree does not contain all vertices in the graph, find the shortest edge leaving the tree and add it to the tree

STEP 1: Choose any vertex r; set S = {r} and A = ∅ (r is the root of the spanning tree).


STEP 2: Find a lightest edge with one endpoint in S and the other in V − S. Add this edge to A and its endpoint in V − S to S.


STEP 3: If V − S = ∅, stop and output the minimum spanning tree (S, A). Otherwise go to Step 2.

**Question**: How does the algorithm find the lightest edge efficiently?
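- We keep the vertices of V − S in a min-priority queue (a binary heap), keyed by the weight of their lightest known edge into S. Extracting the minimum gives the next vertex to add, and keys are lowered (decrease-key) as lighter edges are discovered.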

In [13]:
import sys

from collections import defaultdict

In [57]:
class Heap():
    """Binary min-heap of ``[vertex, key]`` nodes with position tracking.

    ``array`` holds the nodes; slots ``0 .. size-1`` form the heap.
    ``pos[v]`` is the index of vertex v's node in ``array``; a vertex is
    considered in the heap iff ``pos[v] < size``. Callers populate
    ``array``, ``pos`` and ``size`` directly before using the heap.
    """

    def __init__(self):
        self.array = []
        self.size = 0
        self.pos = []

    def newMinHeapNode(self, v, dist):
        """Return a new heap node ``[v, dist]``."""
        return [v, dist]

    def swapMinHeapNode(self, a, b):
        """Swap the nodes at array indices a and b (``pos`` is not updated)."""
        self.array[a], self.array[b] = self.array[b], self.array[a]

    def minHeapify(self, idx):
        """Sift the node at idx down until the min-heap property holds.

        Keeps ``pos`` in sync with every swap so decreaseKey can locate nodes.
        """
        smallest = idx
        left = 2 * idx + 1
        right = 2 * idx + 2

        if left < self.size and self.array[left][1] < self.array[smallest][1]:
            smallest = left
        if right < self.size and self.array[right][1] < self.array[smallest][1]:
            smallest = right

        if smallest != idx:
            # Swap recorded positions, then the nodes themselves.
            self.pos[self.array[smallest][0]] = idx
            self.pos[self.array[idx][0]] = smallest
            self.swapMinHeapNode(smallest, idx)
            self.minHeapify(smallest)

    def extractMin(self):
        """Remove and return the node with the smallest key, or None if empty."""
        if self.isEmpty():
            return None

        root = self.array[0]
        last = self.size - 1
        lastNode = self.array[last]

        # Move the last heap node to the root and park the extracted root in
        # the vacated slot, so array and pos stay consistent: pos[root] ==
        # size-1 marks it as outside the heap once size shrinks.
        self.array[0] = lastNode
        self.array[last] = root
        self.pos[lastNode[0]] = 0
        self.pos[root[0]] = last

        self.size -= 1
        self.minHeapify(0)
        return root

    def isEmpty(self):
        """Return True when no nodes remain in the heap."""
        return self.size == 0

    def decreaseKey(self, v, dist):
        """Set vertex v's key to dist and sift its node up to restore order.

        Only sifts upward, so dist is expected to be <= the current key.
        Performs O(log n) swaps.
        """
        i = self.pos[v]
        self.array[i][1] = dist

        # Sift up. Integer division keeps every stored position a valid int
        # index; ``(i - 1) / 2`` stored floats such as 0.5 under Python 3.
        while i > 0:
            p = (i - 1) // 2
            if not self.array[i][1] < self.array[p][1]:
                break
            self.pos[self.array[i][0]] = p
            self.pos[self.array[p][0]] = i
            self.swapMinHeapNode(i, p)
            i = p

    def isInMinHeap(self, v):
        """Return True if vertex v has not yet been extracted from the heap."""
        return self.pos[v] < self.size
 
 
def printArr(parent, n):
    """Print the tree edge of every vertex except the root (vertex 0).

    Each line is ``parent - child`` for child in 1 .. n-1, with the
    space-flag integer format used throughout this notebook.
    """
    child = 1
    while child < n:
        print("% d - % d" % (parent[child], child))
        child += 1

In [74]:
class Graph():
    """Undirected weighted graph stored as adjacency lists, with Prim's MST.

    ``graph[u]`` is a list of ``[neighbour, weight]`` pairs. New pairs are
    inserted at the front, so neighbours are visited newest-first.
    """

    def __init__(self, V):
        self.V = V
        self.graph = defaultdict(list)

    def addEdge(self, src, dest, weight):
        """Add the undirected edge src--dest with the given weight.

        The pair is recorded in both endpoints' lists (src's first), each
        time at index 0.
        """
        for here, there in ((src, dest), (dest, src)):
            self.graph[here].insert(0, [there, weight])

    def print_graph(self):
        """Debug helper: print the raw adjacency mapping."""
        print(self.graph)

    def MST_prim(self):
        """Build a minimum spanning tree rooted at vertex 0 and print it.

        The heap holds every vertex not yet in the tree (the set V - S), keyed
        by the weight of its lightest known edge into the tree. Each round
        extracts the closest vertex and relaxes its neighbours; with a binary
        heap this is O(E log V). The tree is printed as ``parent - child``
        lines via printArr.
        """
        V = self.V
        # key[v]: weight of the lightest edge seen so far from the tree to v.
        key = [sys.maxsize] * V
        # parent[v]: tree endpoint of that edge (-1 = none found yet).
        parent = [-1] * V
        heap = Heap()

        for vertex in range(V):
            heap.array.append(heap.newMinHeapNode(vertex, key[vertex]))
            heap.pos.append(vertex)

        # Give vertex 0 key 0 so it is extracted first and becomes the root.
        heap.pos[0] = 0
        key[0] = 0
        heap.decreaseKey(0, key[0])
        heap.size = V

        while not heap.isEmpty():
            # Node layout is [vertex, key]; only the vertex is needed here.
            u = heap.extractMin()[0]
            for neighbour, weight in self.graph[u]:
                # Relax only vertices still outside the tree whose edge from u
                # is lighter than their current best.
                if heap.isInMinHeap(neighbour) and weight < key[neighbour]:
                    key[neighbour] = weight
                    parent[neighbour] = u
                    heap.decreaseKey(neighbour, key[neighbour])
        printArr(parent, V)
                
        


In [75]:
# Demo: 9-vertex weighted graph, edges given as (src, dest, weight).
graph = Graph(9)
for src, dest, weight in ((0, 1, 4), (0, 7, 8), (1, 2, 8), (1, 7, 11),
                          (2, 3, 7), (2, 8, 2), (2, 5, 4), (3, 4, 9),
                          (3, 5, 14), (4, 5, 10), (5, 6, 2), (6, 7, 1),
                          (6, 8, 6), (7, 8, 7)):
    graph.addEdge(src, dest, weight)
graph.MST_prim()

 0 -  1
 5 -  2
 2 -  3
 3 -  4
 6 -  5
 7 -  6
 0 -  7
 2 -  8
