# Eager Prim's Minimum Spanning Tree
Instead of blindly inserting edges into a PQ, where they may later become stale, the eager version of Prim's tracks <font color="darkblue" size="2"><b>(node, edge) key-value pairs</b></font> that can easily be **updated** and **polled** to determine the next best edge to add to the MST.

One possible solution is to use an <font color="slategray" size="2"><b>Indexed Priority Queue</b></font>, which can efficiently update and poll key-value pairs.  

Because the update and poll operations of the Indexed Priority Queue each take <font color="green" size="2"><b>O( log(V) )</b></font>, the overall time complexity drops from <font color="orange" size="2"><b>O(E * log(E)) for the lazy version</b></font> to <font color="green" size="2"><b>O(E * log(V))</b></font>.

In [19]:
# Example input: an undirected, weighted graph as adjacency lists.
# The list position is the node index; each entry is a (neighbour, weight)
# pair, and every edge is listed once under each of its two endpoints.
graph = dict(enumerate([
    [(1, 9), (2, 0), (3, 5), (5, 7)],      # node 0
    [(0, 9), (3, -2), (4, 3), (6, 4)],     # node 1
    [(0, 0), (5, 6)],                      # node 2
    [(0, 5), (1, -2), (5, 2), (6, 3)],     # node 3
    [(1, 3), (6, 6)],                      # node 4
    [(0, 7), (2, 6), (3, 2), (6, 1)],      # node 5
    [(1, 4), (3, 3), (4, 6), (5, 1)],      # node 6
]))

In [20]:
def eager_prims(graph, start=0):
    '''
    An eager version of Prim's algorithm for finding the Minimum Spanning
    Tree (MST) of an undirected, weighted graph, if one exists.

    Rather than pushing every candidate edge into a priority queue (where
    entries can go stale), an Indexed Priority Queue keeps at most one entry
    per unvisited node: the cheapest known edge linking that node to the
    tree built so far.

    Args:
    - graph: A dictionary of adjacency lists. Keys are node indices and are
             expected to be the integers 0 .. len(graph) - 1 (they index the
             `visited` list and the priority queue). Each value is a list of
             (destination_node, weight) tuples,
             e.g.: [(0, 2), (1, 5), (3, 11), (2, 8)]
    - start: Optional, starting node for Prim's algorithm.
             Default: 0

    Returns:
    - (mst_cost, mst_edges) when an MST exists, otherwise (None, None)
      (for an empty graph, or when some node is unreachable from `start`):
      - mst_cost:  Total weight of the MST's edges
      - mst_edges: List of tuples, where each tuple represents an edge
                   in form: (node_from, node_to, weight)
                   e.g.: [(0, 2, 0), (0, 3, 5), (3, 1, -2)]
    '''
    # An empty graph has no node to grow a tree from
    if not graph:
        return (None, None)

    # Initialization
    num_nodes = len(graph)
    total_edges = num_nodes - 1                # An MST over V nodes has V - 1 edges
    ipq = IPQ()                                # Indexed Priority Queue keyed by node index
    # Keys are node indices 0 .. V - 1, so a capacity of V slots is enough
    ipq.min_heap(2, num_nodes)
    edge_count, mst_cost = 0, 0                # Running MST solution
    visited = [False] * num_nodes              # Nodes already added to the tree
    mst_edges = [None] * total_edges           # MST's edges, in the order they are added

    def relax_edges_at_node(current_node_index):
        '''
        Mark `current_node_index` as part of the tree, then for every edge to
        an unvisited node either add that node to the queue or lower its
        entry if this edge is cheaper.

        Args:
        - current_node_index: Node that was just added to the tree

        Returns:
        - None
        '''
        visited[current_node_index] = True

        for edge in graph[current_node_index]:
            destination = edge[0]

            # Nodes already in the tree must not be reached again
            if visited[destination]:
                continue

            # Entries are (weight, edge, source) so that tuple comparison
            # orders them by weight first.
            entry = (edge[1], edge, current_node_index)
            if not ipq.contains(destination):
                ipq.insert(entry, destination)
            else:
                # decrease() only replaces the entry when the new one is smaller
                ipq.decrease(destination, entry)

    # Grow the tree from the 'start' node
    relax_edges_at_node(start)

    # Keep polling the cheapest edge leaving the tree until the MST is
    # complete or no further nodes are reachable.
    while not ipq.is_empty() and edge_count != total_edges:
        weight, edge, node_from = ipq.poll_min_value()
        node_to = edge[0]

        mst_edges[edge_count] = (node_from, node_to, edge[1])
        edge_count += 1
        mst_cost += weight

        relax_edges_at_node(node_to)

    # Fewer than V - 1 edges means some node was unreachable: no MST exists
    if edge_count != total_edges:
        return (None, None)

    return (mst_cost, mst_edges)

In [21]:
# Run eager Prim's on the example graph and report the outcome.
# eager_prims returns (None, None) when no spanning tree exists.
cost, edges = eager_prims(graph)

if cost is not None:
    print("MST edges:")
    # Each edge is a (node_from, node_to, weight) tuple
    for edge in edges:
        print(edge[0], "->", edge[1], 'weight:', edge[2])
    print("\nMST total cost:", cost)
else:
    print("No MST exists!")

MST edges:
0 -> 2 weight: 0
0 -> 3 weight: 5
3 -> 1 weight: -2
3 -> 5 weight: 2
5 -> 6 weight: 1
1 -> 4 weight: 3

MST total cost: 9


In [22]:
class IPQ:
    '''
    Indexed Priority Queue backed by a d-ary min heap.

    Every value is stored under an integer key index ("ki") in the range
    [0, capacity). Besides insert/peek/poll, the key index lets callers look
    up, update, decrease, increase or delete the value of a specific key.

    Call `min_heap(degree, max_size)` before use; it allocates the internal
    arrays and sets the capacity.
    '''

    def __init__(self):
        self.debug_swap = False   # When True, swap() prints the internal arrays
        self._size = 0            # Number of elements currently in the heap
        self._num_elem = 0        # Capacity: valid key indexes are [0, _num_elem)
        self._degree = 0          # Maximum number of children per heap node
        self._child = []          # _child[i]: heap position of node i's first child
        self._parent = []         # _parent[i]: heap position of node i's parent
        self.ki_count = 0         # Next key index used by insert() when no ki is given
        self.ki_arr = []          # Key indexes handed out automatically (shown in debug output)
        # The Position map (pos_map) maps Key Indexes (ki) to the heap position
        # of that key, in the domain [0, size); -1 means "not in the queue"
        self.pos_map = []

        # The Inverse Map (inv_map) stores, for each heap position in
        # [0, size), the key index found there. pos_map and inv_map are
        # inverses of each other: pos_map[inv_map[i]] == i for heap positions,
        # and inv_map[pos_map[k]] == k for keys in the queue.
        self.inv_map = []

        # The values associated with the keys. This array is indexed by the
        # key indexes (aka 'ki'), not by heap position.
        self.values = []

    def min_heap(self, degree, max_size):
        '''
        (Re)initialize an empty d-ary min heap.

        Args:
        - degree:   Maximum number of children per node (clamped to >= 2)
        - max_size: Capacity; key indexes must lie in [0, capacity).
                    The capacity is at least degree + 1.
        '''
        self._degree = max(2, degree)
        self._num_elem = max(self._degree + 1, max_size)
        # Start from an empty queue, even if this object was used before
        self._size = 0
        self.ki_count = 0
        self.ki_arr = []

        n = self._num_elem
        self.pos_map = [-1] * n
        self.inv_map = [-1] * n
        self.values = [None] * n
        # Precompute parent / first-child heap positions for every slot
        self._parent = [(i - 1) // self._degree for i in range(n)]
        self._child = [i * self._degree + 1 for i in range(n)]

    def size(self):
        return self._size

    def is_empty(self):
        return self._size == 0

    def contains(self, ki):
        '''Return True if key index `ki` is currently in the queue.'''
        self.key_inbounds_or_raise(ki)
        return self.pos_map[ki] != -1

    def peek_min_key_index(self):
        '''Return the key index holding the minimum value.'''
        self.is_not_empty_or_raise()
        return self.inv_map[0]

    def poll_min_key_index(self):
        '''Remove the minimum element and return its key index.'''
        minkey = self.peek_min_key_index()
        self.delete(minkey)
        return minkey

    def peek_min_value(self):
        '''Return the minimum value without removing it.'''
        self.is_not_empty_or_raise()
        return self.values[self.inv_map[0]]

    def poll_min_value(self):
        '''Remove the minimum element and return its value.'''
        min_value = self.peek_min_value()
        self.delete(self.peek_min_key_index())
        return min_value

    def insert(self, value, ki=None):
        '''
        Insert `value` under key index `ki`.

        If `ki` is None, the next automatic key index (`ki_count`) is used
        and recorded in `ki_arr`.

        Raises:
        - IndexError: if `ki` is beyond the heap capacity
        - ValueError: if `ki` is negative or already present, or `value` is None
        '''
        auto_key = ki is None
        if auto_key:
            ki = self.ki_count
        if ki + 1 > self._num_elem:
            raise IndexError(f'Key index out of bounds; recieved: {ki}')
        if self.contains(ki):
            raise ValueError(f'Index already exists; recieved: {ki}')
        self.value_not_None_or_raise(value)
        if auto_key:
            # Only consume an automatic key once the insert is known to succeed
            self.ki_arr.append(ki)
            self.ki_count += 1
        # Place the new key at the end of the heap, then restore the invariant
        self.pos_map[ki] = self._size
        self.inv_map[self._size] = ki
        self.values[ki] = value
        self.swim(self._size)
        self._size += 1

    def value_of(self, ki):
        '''Return the value stored under key index `ki`.'''
        self.key_exists_or_raise(ki)
        return self.values[ki]

    def delete(self, ki):
        '''Remove key index `ki` from the queue and return its value.'''
        self.key_exists_or_raise(ki)
        i = self.pos_map[ki]
        self._size -= 1
        # Move the last heap element into the vacated position...
        self.swap(i, self._size)
        # ...and restore the invariant around it. If the deleted key already
        # was the last element, nothing moved and nothing needs fixing.
        if i != self._size:
            self.sink(i)
            self.swim(i)
        value = self.values[ki]
        self.values[ki] = None
        self.pos_map[ki] = -1
        self.inv_map[self._size] = -1
        return value

    def sink(self, i):
        '''Move the node at heap position `i` down until the invariant holds.'''
        j = self.min_child(i)
        while j != -1 and self.values[self.inv_map[i]] > self.values[self.inv_map[j]]:
            self.swap(i, j)
            i = j
            j = self.min_child(i)

    def update(self, ki, value):
        '''Replace the value of `ki` (in either direction); return the old value.'''
        self.key_exists_and_value_not_None_or_raise(ki, value)
        i = self.pos_map[ki]
        old_value = self.values[ki]
        self.values[ki] = value
        self.sink(i)
        self.swim(i)
        return old_value

    def decrease(self, ki, value):
        '''Set the value of `ki` to `value` only if `value` is strictly smaller.'''
        self.key_exists_and_value_not_None_or_raise(ki, value)
        if self.less(value, self.values[ki], is_value=True):
            self.values[ki] = value
            self.swim(self.pos_map[ki])

    def increase(self, ki, value):
        '''Set the value of `ki` to `value` only if `value` is strictly larger.'''
        self.key_exists_and_value_not_None_or_raise(ki, value)
        if self.less(self.values[ki], value, is_value=True):
            self.values[ki] = value
            self.sink(self.pos_map[ki])

    # ---- Helper functions ----
    def swim(self, i):
        '''Move the node at heap position `i` up until the invariant holds.'''
        # Test i > 0 first: the root has no parent, and _parent[0] == -1
        # would otherwise index the *end* of inv_map.
        while i > 0 and self.less(i, self._parent[i]):
            parent = self._parent[i]
            self.swap(i, parent)
            i = parent

    def swap(self, heap_i, heap_j):
        '''
        Swaps node i with node j in pos_map and inv_map to maintain
        heap invariant. Used in delete, sink and swim functions.

        Args:
        - heap_i: Heap's node index
        - heap_j: Heap's node index

        Returns: None
        '''
        if self.debug_swap:
            print('\n======= DEBUG: swap =======')
            print('---------------------------')
            print('\tBefore swap:')
            print('\tkey_arr:', self.ki_arr)
            print('\tval_arr:', self.values)
            print('\tpos_map:', self.pos_map)
            print('\tinv_map:', self.inv_map)

        ki_i = self.inv_map[heap_i]
        ki_j = self.inv_map[heap_j]
        self.inv_map[heap_i], self.inv_map[heap_j] = self.inv_map[heap_j], self.inv_map[heap_i]
        self.pos_map[ki_i], self.pos_map[ki_j] = self.pos_map[ki_j], self.pos_map[ki_i]

        if self.debug_swap:
            print('\n\tAfter swap:')
            print('\tkey_arr:', self.ki_arr)
            print('\tval_arr:', self.values)
            print('\tpos_map:', self.pos_map)
            print('\tinv_map:', self.inv_map)
            print('---------------------------\n')

    def min_child(self, i):
        '''
        Return the heap position of the smallest child of node `i`, or -1 if
        node `i` has no children within the current heap size. When several
        children share the minimum value, the last of them is returned.
        '''
        best, best_value = -1, None
        first = self._child[i]
        last = min(self._size, first + self._degree)
        for j in range(first, last):
            value = self.values[self.inv_map[j]]
            # '<=' keeps the last of equal minima
            if best == -1 or value <= best_value:
                best, best_value = j, value
        return best

    def less(self, i, j, is_value=False):
        '''
        Test whether the value of heap node `i` is strictly less than the
        value of heap node `j`. With `is_value=True`, `i` and `j` are the
        values themselves. If node `j` holds no value (None), returns False.
        '''
        if is_value:
            return bool(i < j)
        other = self.values[self.inv_map[j]]
        if other is None:
            return False
        return bool(self.values[self.inv_map[i]] < other)

    # ---- Validation helpers ----
    def is_not_empty_or_raise(self):
        if self.is_empty(): raise ValueError("Priority queue underflow")

    def key_exists_and_value_not_None_or_raise(self, ki, value):
        self.key_exists_or_raise(ki)
        self.value_not_None_or_raise(value)

    def key_exists_or_raise(self, ki):
        if not self.contains(ki): raise ValueError(f'Index does not exist; received: {ki}')

    def value_not_None_or_raise(self, value):
        if value is None: raise ValueError('Value cannot be None')

    def key_inbounds_or_raise(self, ki):
        # Valid key indexes are [0, capacity); capacity itself is out of range
        if ki < 0 or ki >= self._num_elem:
            raise ValueError(f'Key index out of bounds; recieved: {ki}')

    # ---- Test functions ----
    def is_min_heap(self):
        '''
        Recursively check the heap invariant (every child >= its parent).
        Used for testing; an empty heap is trivially valid.
        '''
        return self._is_min_heap(0)

    def _is_min_heap(self, i):
        # Positions at or beyond the heap size hold no element
        if i >= self._size:
            return True
        first = self._child[i]
        last = min(self._size, first + self._degree)
        for j in range(first, last):
            # Only a child strictly smaller than its parent breaks the
            # invariant; equal values are allowed.
            if self.less(j, i):
                return False
            if not self._is_min_heap(j):
                return False
        return True