# Prim algorithm with heap data structure 

# In [376]:
import math
from collections import defaultdict

## MinHeap data structure implementation

# In [377]:
# ArrayHeap object extends list
class ArrayHeap(list):
    """A list that additionally carries a ``heapSize`` counter.

    The counter always starts at 0, whatever ``array`` contains; it is up to
    the code using this container to update it.
    """

    def __init__(self, array):
        # The counter is independent of the list contents, so set it first.
        self.heapSize = 0
        list.__init__(self, array)

class MinHeap:
    """Binary min-heap over items that expose a comparable ``key`` attribute.

    The items live in ``self.arrayHeap`` (an ``ArrayHeap``); only the first
    ``arrayHeap.heapSize`` positions are treated as part of the heap. Call
    ``buildMinHeap`` before extracting, because ``heapSize`` starts at 0.
    """

    def __init__(self, array):
        self.arrayHeap = ArrayHeap(array)

    # All the following methods work with zero based array.
    def parent(self, i):
        """Return the index of the parent of position ``i`` (-1 for the root)."""
        # Zero-based layout: the children of p sit at 2p+1 and 2p+2, so the
        # inverse is (i - 1) // 2. The former floor(i / 2) is the one-based
        # formula and gave e.g. parent(2) == 1 instead of 0.
        return (i - 1) // 2

    def left(self, i):
        """Return the index of the left child of position ``i``."""
        return 2 * i + 1

    def right(self, i):
        """Return the index of the right child of position ``i``."""
        return 2 * i + 2

    # Execution time: O(lg n)
    def minHeapify(self, i):
        """Sift the item at position ``i`` down until neither child is smaller."""
        heap = self.arrayHeap
        last = heap.heapSize - 1
        while True:
            l = self.left(i)
            r = self.right(i)
            smallest = i
            if l <= last and heap[l].key < heap[smallest].key:
                smallest = l
            if r <= last and heap[r].key < heap[smallest].key:
                smallest = r
            if smallest == i:
                return
            heap[i], heap[smallest] = heap[smallest], heap[i]
            i = smallest

    # Execution time: O(n)
    def buildMinHeap(self):
        """Turn the whole underlying list into a heap and set ``heapSize``."""
        heap = self.arrayHeap
        heap.heapSize = len(heap)
        # Positions from len // 2 onwards are leaves; fix the internal nodes
        # bottom-up. Integer division avoids the float round trip.
        for i in range(len(heap) // 2 - 1, -1, -1):  # downto
            self.minHeapify(i)

    # Disabled drafts, kept for reference; they are not executed.
    #def insert(self, value):
    #    self.arrayHeap.append(value)
    #    self.arrayHeap.heapSize += 1
    #    
    #    index = self.arrayHeap.heapSize - 1
    #    parent = self.parent(index)
    #    
    #    while parent >= 0 and self.arrayHeap[index] < self.arrayHeap[parent]:
    #        self.arrayHeap[index], self.arrayHeap[parent] = self.arrayHeap[parent], self.arrayHeap[index]
    #        index = parent
    #        parent = self.parent(index)

    #def delete(self, index):
    #    if self.arrayHeap.heapSize == 0:
    #        print("Error: underflow")
    #        return
    #    else:
    #        removed = self.arrayHeap[index]
    #        last = self.arrayHeap.heapSize-1
    #        self.arrayHeap[index], self.arrayHeap[last] = self.arrayHeap[last], self.arrayHeap[index]
    #        self.arrayHeap.heapSize -= 1
    #        self.arrayHeap.pop(last)
    #        parent = self.parent(index)
    #        self.minHeapify(parent)
    #        return removed

    # Execution time: O(lg n)
    # First we update the heap structure, then we remove the last element.
    def extractMin(self):
        """Remove and return the item with the smallest key.

        On an empty heap this prints "Error: underflow" and returns None.
        """
        heap = self.arrayHeap
        if heap.heapSize == 0:
            print("Error: underflow")
            return None

        smallest = heap[0]
        last = heap.heapSize - 1
        # Move the last heap item to the root, shrink the heap, then sift down.
        heap[0] = heap[last]
        heap.heapSize = last
        self.minHeapify(0)
        # The slot just past the heap still holds a copy of the moved item;
        # drop it so the list length matches heapSize again.
        heap.pop(last)
        return smallest

# Definition of Node and Graph objects for graph manipulation

# In [378]:
class Node:
    """Graph vertex identified by ``tag``.

    ``key`` and ``parent`` start as None and ``adjacencyList`` starts empty;
    each instance gets its own list.
    """

    def __init__(self, tag):
        self.tag, self.key, self.parent = tag, None, None
        self.adjacencyList = list()

    # For test
    def print(self):
        """Print the tag, the adjacency list and the key of this node."""
        fields = ("tag =", self.tag, "adjList=", self.adjacencyList, "key=", self.key)
        print(*fields)

class Graph:
    """Adjacency-list graph whose nodes are looked up by integer tag.

    ``nodes`` maps each tag to its ``Node``. Edges are stored on the source
    node only, as ``[neighbour_node, cost]`` entries.
    """

    def __init__(self):
        # Plain dict on purpose: Node needs a tag, so a defaultdict(Node)
        # factory could never build a default and a missing tag surfaced as a
        # confusing TypeError. A dict raises KeyError for unknown tags.
        self.nodes = {}

    def createNode(self, nums):
        """Create nodes tagged 1..nums (inclusive), replacing existing ones."""
        for tag in range(1, nums + 1):
            self.nodes[tag] = Node(tag)

    def addNode(self, tag, adjList):
        """Append the edge ``tag -> adjList[0]`` with cost ``adjList[1]``.

        Self loops are ignored. Both tags must already exist (KeyError
        otherwise).
        """
        # NOTE(review): only the tag -> adjTag direction is recorded. For an
        # undirected graph the input must list both directions - confirm
        # against the dataset format.
        adjTag, adjCost = adjList[0], adjList[1]
        if tag == adjTag:  # Check for self loop
            return
        self.nodes[tag].adjacencyList.append([self.nodes[adjTag], adjCost])

    def buildGraph(self, input):
        """Populate the graph from a file-like object.

        The first non-blank line starts with the number of vertices; every
        following non-blank line holds ``tag adjTag cost`` as integers. Blank
        lines are skipped, and an input without any content leaves the graph
        unchanged.
        """
        rows = [line.split() for line in input.readlines()]
        rows = [row for row in rows if row]  # tolerate blank/trailing lines
        if not rows:
            return

        header, edges = rows[0], rows[1:]
        self.createNode(int(header[0]))  # Extract number of vertexes
        for row in edges:
            info = [int(token) for token in row]
            self.addNode(info[0], [info[1], info[2]])

# Prim algorithm

# In [379]:
def MSTPrim(g, r):
    for node in g.nodes.values():
        node.key = math.inf # key, parent is already set
    r.key = 0 
    q = MinHeap(list(g.nodes.values()))
    q.buildMinHeap()
    while len(q.arrayHeap) is not 0:
        u = q.extractMin()
        for v in u.adjacencyList:
            if v[0] and v[1] < v[0].key:
                v[0].parent = u
                v[0].key = v[1]

# In [380]:
# Build the graph from the dataset file and run Prim's algorithm from node 1.
result = Graph()
# The context manager closes the file as soon as the graph has been read;
# the bare open() call used before left the handle open.
with open("dataset/input_random_17_100.txt", "r") as dataset:
    result.buildGraph(dataset)
MSTPrim(result, result.nodes.get(1))
#sum = 0
#for node in result.nodes.values():
#    sum += node.key
#print(sum)