# Lazy Prim's Minimum Spanning Tree

Prim's is a greedy MST algorithm (at each step it adds the cheapest edge connecting the growing tree to a new vertex) that works well <font color="gold" size="3"><b>on dense graphs</b></font>. On these graphs, Prim's meets or improves on the time bounds of its popular rivals (Kruskal's and Boruvka's).
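
The dense-graph advantage comes from the array-based variant of Prim's, not the lazy heap version used later in this notebook. With an adjacency matrix and a linear scan instead of a heap, each of the V steps costs O(V), for O(V^2) in total, which is linear in the input size when E is close to V^2. The sketch below is a minimal illustration; the name `dense_prims` and the input format (a symmetric matrix with `math.inf` for missing edges) are assumptions made for this example.

```python
import math

def dense_prims(weights):
    """O(V^2) Prim's on a symmetric adjacency matrix (math.inf = no edge).

    Returns (total_cost, [(u, v, weight), ...]) or (None, None) if disconnected.
    """
    n = len(weights)
    if n == 0:
        return 0, []
    in_tree = [False] * n
    best = [math.inf] * n   # cheapest known edge weight linking each vertex to the tree
    parent = [-1] * n       # tree-side endpoint of that cheapest edge
    best[0] = 0
    total, edges = 0, []
    for _ in range(n):
        # Linear scan for the closest vertex outside the tree: O(V) per step.
        u = -1
        for v in range(n):
            if not in_tree[v] and (u == -1 or best[v] < best[u]):
                u = v
        if best[u] == math.inf:
            return None, None  # remaining vertices are unreachable
        in_tree[u] = True
        total += best[u]
        if parent[u] != -1:
            edges.append((parent[u], u, best[u]))
        # Update the cheapest known connection for every outside vertex: O(V).
        for v in range(n):
            if not in_tree[v] and weights[u][v] < best[v]:
                best[v] = weights[u][v]
                parent[v] = u
    return total, edges

INF = math.inf
W = [[INF, 1, 4, INF],
     [1, INF, 2, 5],
     [4, 2, INF, 3],
     [INF, 5, 3, INF]]
print(dense_prims(W))  # (6, [(0, 1, 1), (1, 2, 2), (2, 3, 3)])
```

No heap is needed because the scan over `best` already finds the minimum; on sparse graphs that scan becomes the bottleneck, which is where heap-based versions win.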

However, Prim's cannot as easily find a **minimum spanning forest** of a disconnected graph: the algorithm must be run separately on each connected component. It is also not easily parallelizable.
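
Running Prim's once per connected component can be sketched by restarting the search from every node that is still unvisited. This is a minimal sketch, assuming the same adjacency-list dict used in this notebook with nodes labelled `0 .. n-1`; `prim_forest` is a hypothetical helper, and it stores heap entries as `(weight, from, to)` rather than the nested tuples used later.

```python
import heapq

def prim_forest(graph):
    """Minimum spanning forest: one lazy-Prim's tree per connected component.

    graph: dict node -> list of (neighbor, weight), nodes 0..n-1.
    Returns (total_cost, list of trees), each tree a list of (u, v, weight).
    """
    n = len(graph)
    visited = [False] * n
    total, forest = 0, []
    for root in range(n):
        if visited[root]:
            continue
        # Unvisited root: start a new tree for its connected component.
        tree = []
        visited[root] = True
        pq = [(w, root, v) for v, w in graph[root]]
        heapq.heapify(pq)
        while pq:
            w, u, v = heapq.heappop(pq)
            if visited[v]:
                continue  # stale entry: v already joined the tree
            visited[v] = True
            tree.append((u, v, w))
            total += w
            for nxt, nw in graph[v]:
                if not visited[nxt]:
                    heapq.heappush(pq, (nw, v, nxt))
        forest.append(tree)
    return total, forest

g = {0: [(1, 2)], 1: [(0, 2)], 2: [(3, 5)], 3: [(2, 5)], 4: []}
print(prim_forest(g))  # (7, [[(0, 1, 2)], [(2, 3, 5)], []])
```

An isolated vertex (node 4 here) contributes an empty tree, so the forest always has one entry per component.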

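The <font color="orange" size="3"><b>lazy</b></font> version of Prim's has a time complexity of <font color="orange" size="3"><b>O(E * log(E))</b></font>, where E is the number of edges. It is called lazy because, instead of updating queued edges when a cheaper connection to a vertex is found, it leaves outdated edges in the priority queue and simply skips any popped edge whose destination is already in the tree.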
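
A short counting argument for this bound, assuming a simple graph where E counts undirected edges and each edge is stored once in each endpoint's adjacency list: an entry is pushed only while its source vertex is being visited, and each vertex is visited once, so the heap receives at most 2E entries and every push or pop costs O(log 2E).

$$
\#\text{pushes} \le 2E, \qquad \#\text{pops} \le \#\text{pushes} \le 2E
$$

$$
T(V, E) = O\big(2E \log(2E)\big) = O(E \log E) = O(E \log V), \quad \text{since } E \le V^2 \implies \log E \le 2 \log V
$$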

Given an **undirected graph** with weighted edges, a <font color="66C5FF" size="3"><b>Minimum Spanning Tree</b></font> (MST) is a subset of the edges in the graph which connects all vertices together (without creating any cycles) while <u>minimizing the total edge cost</u>.  
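
The definition can be checked mechanically: a set of edges on n vertices is a spanning tree exactly when it has n - 1 edges and contains no cycle (those two conditions together force it to connect every vertex). A minimal sketch using a union-find structure; `is_spanning_tree` and `tree_cost` are hypothetical helpers, and edges are assumed to be `(u, v, weight)` triples on nodes `0 .. n-1`.

```python
def is_spanning_tree(n, edges):
    """True if `edges` (list of (u, v, weight)) is a spanning tree on nodes 0..n-1."""
    if len(edges) != n - 1:
        return False
    parent = list(range(n))  # union-find: each node starts in its own set

    def find(x):
        # Walk to the set representative, halving the path as we go.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    for u, v, _ in edges:
        ru, rv = find(u), find(v)
        if ru == rv:
            return False  # u and v already connected: this edge closes a cycle
        parent[ru] = rv
    return True

def tree_cost(edges):
    """Total weight of a set of (u, v, weight) edges."""
    return sum(w for _, _, w in edges)

path = [(0, 1, 1), (1, 2, 1), (2, 3, 1)]
print(is_spanning_tree(4, path), tree_cost(path))  # True 3
```

An MST is then any spanning tree whose `tree_cost` is as small as possible.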

**Note**: A graph can have multiple MSTs with the same total cost. This can only happen when some edge weights are equal; if all weights are distinct, the MST is unique.
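
A brute-force enumeration makes the point concrete on a tiny graph. This sketch (the helper name `all_msts` is hypothetical) tries every subset of n - 1 edges, keeps the acyclic ones, and returns all trees of minimum cost; it is exponential in the number of edges and only meant for illustration.

```python
from itertools import combinations

def all_msts(n, edges):
    """Every minimum-cost spanning tree of a small graph (n >= 1).

    edges: list of (u, v, weight) on nodes 0..n-1. Brute force; tiny graphs only.
    """
    def is_tree(subset):
        # n - 1 edges with no cycle form a spanning tree.
        parent = list(range(n))

        def find(x):
            while parent[x] != x:
                x = parent[x]
            return x

        for u, v, _ in subset:
            ru, rv = find(u), find(v)
            if ru == rv:
                return False
            parent[ru] = rv
        return True

    trees = [s for s in combinations(edges, n - 1) if is_tree(s)]
    if not trees:
        return []
    best = min(sum(w for _, _, w in t) for t in trees)
    return [list(t) for t in trees if sum(w for _, _, w in t) == best]

# A 4-cycle with equal weights: dropping any one edge leaves an MST of cost 3.
square = [(0, 1, 1), (1, 2, 1), (2, 3, 1), (3, 0, 1)]
print(len(all_msts(4, square)))  # 4
```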

```python
# Undirected weighted graph as an example:
# node -> list of (neighbor, weight); every edge is listed from both endpoints.
graph = {
    0: [(1, 10), (2, 1), (3, 4)],
    1: [(0, 10), (2, 3), (4, 0)],
    2: [(0, 1), (1, 3), (3, 2), (5, 8)],
    3: [(0, 4), (2, 2), (5, 2), (6, 7)],
    4: [(1, 0), (5, 1), (7, 8)],
    5: [(2, 8), (3, 2), (4, 1), (6, 6), (7, 9)],
    6: [(3, 7), (5, 6), (7, 12)],
    7: [(4, 8), (5, 9), (6, 12)]
}
```
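
Because an undirected graph stored this way must list every edge from both endpoints, a missing reverse entry is an easy mistake to make by hand. A quick consistency check can catch it before running Prim's; `asymmetric_entries` is a hypothetical helper written for this sketch, assuming the `{node: [(neighbor, weight), ...]}` format above.

```python
def asymmetric_entries(graph):
    """Return (u, v, w) entries whose reverse entry (v -> u with weight w) is missing.

    An empty list means the adjacency lists describe a consistent undirected graph.
    """
    missing = []
    for u, edges in graph.items():
        for v, w in edges:
            if (u, w) not in graph.get(v, []):
                missing.append((u, v, w))
    return missing

# Hypothetical example: edge 0-1 is only listed from node 0's side.
g = {0: [(1, 5)], 1: []}
print(asymmetric_entries(g))  # [(0, 1, 5)]
```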

```python
import heapq

def lazy_prims(graph, start=0):
    '''
    Find a Minimum Spanning Tree of an undirected graph, if one exists.

    Args:
    - graph: a dict of adjacency lists keyed by node id (0 .. n-1).
             Each value is a list of (neighbor, weight) tuples,
             one per edge incident to that node,
             e.g., [(0, 2), (1, 5), (3, 11), (2, 8)]

    - start: optional; default = 0. Node to start exploration from.

    Returns:
    - mst_cost: int, combined cost of the Minimum Spanning Tree
    - mst_edges: list of tuples representing the Minimum Spanning Tree,
                 of form (node_at, (node_to, weight)),
                 e.g., (2, (3, 5))
    Both values are None if the graph is disconnected (no MST exists).
    '''
    # Initialization.
    pq = []                           # Min-heap of (weight, (node_at, (node_to, weight)))
    visited = [False] * len(graph)    # visited[i] becomes True once node i joins the tree
    max_mst = len(graph) - 1          # An MST on V nodes has exactly V - 1 edges
    edge_count = 0                    # Edges added to the MST so far
    mst_cost = 0                      # Running total of MST edge weights
    mst_edges = [None] * max_mst      # Solution MST edges

    def add_edges(node_id):
        '''
        Mark the node as visited and push all of its outgoing edges
        whose destination node is unvisited onto the priority queue.

        Args:
        - node_id: node currently being explored

        Returns:
        - None
        '''
        # Mark current node as visited.
        visited[node_id] = True
        # Iterate over all edges going outwards from the current node.
        for edge in graph[node_id]:
            if not visited[edge[0]]:
                # edge = (node_to, weight); the weight edge[1] is the priority key.
                heapq.heappush(pq, (edge[1], (node_id, edge)))

    # Begin exploration from the 'start' node.
    add_edges(start)
    # Continue while the PQ is not empty and the MST is not complete.
    while pq and edge_count != max_mst:
        # Grab the edge with minimal weight.
        edge = heapq.heappop(pq)[1]
        node_id = edge[1][0]

        # Lazy step: skip stale edges whose destination is already in the tree.
        if visited[node_id]:
            continue

        # Accept the edge and update counters.
        mst_edges[edge_count] = edge
        edge_count += 1
        mst_cost += edge[1][1]
        # Push the newly added node's outgoing edges.
        add_edges(node_id)

    if edge_count != max_mst:
        return (None, None)  # Graph is disconnected: no MST exists!

    return (mst_cost, mst_edges)
```

```python
cost, edges = lazy_prims(graph)

if cost is None:
    print("No MST exists!")
else:
    print("MST edges:")
    for edge in edges:
        # edge = (node_at, (node_to, weight))
        print(edge[0], "->", edge[1])
    print("\nMST cost:", cost)
```

```text
MST edges:
0 -> (2, 1)
2 -> (3, 2)
3 -> (5, 2)
5 -> (4, 1)
4 -> (1, 0)
5 -> (6, 6)
4 -> (7, 8)

MST cost: 20
```
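
Since every MST of a graph has the same total weight (even when the chosen edges differ), one way to sanity-check the cost printed above is to recompute it with Kruskal's algorithm. This is a minimal sketch, assuming the same adjacency-list format with nodes `0 .. n-1` and every edge listed from both endpoints; `kruskal_cost` is a hypothetical helper that returns only the cost.

```python
def kruskal_cost(graph):
    """Total MST cost via Kruskal's algorithm, or None if the graph is disconnected.

    graph: dict node -> list of (neighbor, weight), nodes 0..n-1,
    with each undirected edge listed from both endpoints.
    """
    n = len(graph)
    # Keep each undirected edge once (u < v) and sort by weight.
    edges = sorted((w, u, v) for u, adj in graph.items() for v, w in adj if u < v)
    parent = list(range(n))  # union-find over components

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    cost, used = 0, 0
    for w, u, v in edges:
        ru, rv = find(u), find(v)
        if ru != rv:  # edge joins two different components: keep it
            parent[ru] = rv
            cost += w
            used += 1
    return cost if used == n - 1 else None

print(kruskal_cost({0: [(1, 1)], 1: [(0, 1), (2, 2)], 2: [(1, 2)]}))  # 3
```

Applied to the example graph, `kruskal_cost(graph)` should likewise return 20.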
