## Prerequisite: Understand Minimum Spanning Trees (MST)
## Prim's Algorithm for MST

In [5]:
# Prim's algorithm using an adjacency list and a min-heap (priority queue).
# The graph `adj` is an adjacency list (dict or list) where adj[u] = [(v, weight), ...].

import heapq

def prim_algo(adj, start=0):
    """Build a minimum spanning tree with Prim's algorithm (lazy min-heap variant).

    Args:
        adj: adjacency list (dict or list) where ``adj[u]`` is a list of
            ``(v, weight)`` pairs. For an undirected graph every edge must be
            listed under both of its endpoints.
        start: vertex the tree is grown from.

    Returns:
        ``(mst, total_weight)``: ``mst`` is a list of ``(weight, node, parent)``
        tuples in the order nodes joined the tree, and ``total_weight`` is the
        sum of their weights. Only the component containing ``start`` is
        spanned. An empty graph yields ``([], 0)``.

    Time O(E log V); space O(V + E). The heap may hold one entry per edge.
    """
    if not adj:
        return [], 0

    n = len(adj)
    visited = set()
    min_heap = [(0, start, -1)]  # (weight, node, parent); -1 = no parent (root)
    mst = []
    total_weight = 0

    # Stop as soon as every vertex is in the tree; any remaining heap entries
    # are stale and would only be popped and skipped.
    while min_heap and len(visited) < n:
        weight, node, parent = heapq.heappop(min_heap)

        if node in visited:
            continue  # stale entry: node was already added via an earlier pop
        visited.add(node)

        if parent != -1:  # the start node has no incoming tree edge
            mst.append((weight, node, parent))
            total_weight += weight

        for neigh, w in adj[node]:
            if neigh not in visited:
                heapq.heappush(min_heap, (w, neigh, node))
    return mst, total_weight

# Time complexity: O(E log V)
# Space complexity: O(V + E)
# Undirected example graph with 5 vertices: each edge appears in both
# endpoints' lists as (neighbour, weight).
adj = {
    0: [(1, 2), (3, 6)],
    1: [(0, 2), (2, 3), (3, 8), (4, 5)],
    2: [(1, 3), (4, 7)],
    3: [(0, 6), (1, 8), (4, 9)],
    4: [(1, 5), (2, 7), (3, 9)],
}

# Grow the tree from vertex 0 and report each chosen edge.
mst_edges, total_weight = prim_algo(adj, start=0)

print("Edges in MST (weight, node, parent):")
for e in mst_edges:
    print(e)
print("Total weight:", total_weight)

Edges in MST (weight, node, parent):
(2, 1, 0)
(3, 2, 1)
(5, 4, 1)
(6, 3, 0)
Total weight: 16


## Union Find

In [9]:
# Complexity when both path compression and union by rank (or size) are used:
# O(constructor) = O(V)
# O(find) = amortized O(𝜶(V))* ≅ amortized O(1)
# O(union) = amortized O(𝜶(V))* ≅ amortized O(1)

# * 𝜶 is the inverse Ackermann function. It grows so slowly that even
# 𝜶(number of observable particles in the universe) is at most about 4,
# so in practice it is a small constant.
# https://codeforces.com/blog/entry/98275

## Union by Rank
class UnionFind_Rank:
    """Disjoint-set forest over the integers 0 .. size-1.

    Balancing uses union by rank; lookups use path compression.
    """

    def __init__(self, size):
        # Every element starts as the root of its own singleton set.
        self.parent = list(range(size))
        # Upper bound on the height of the tree rooted at each node.
        self.rank = [0] * size

    def find(self, x):
        """Return the root of x's set, re-pointing every node on the path at it."""
        root = x
        while self.parent[root] != root:
            root = self.parent[root]

        # Second pass: path compression.
        node = x
        while node != root:
            next_node = self.parent[node]
            self.parent[node] = root
            node = next_node
        return root

    def union(self, x, y):
        """Merge the sets holding x and y; no-op when they already coincide."""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return

        # Hang the lower-ranked tree under the higher-ranked one. The combined
        # height grows only when both ranks are equal, so only then bump rank.
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
            return
        self.parent[root_y] = root_x
        if self.rank[root_x] == self.rank[root_y]:
            self.rank[root_x] += 1

In [16]:
# Union by Size
class UnionFind_Size:
    """Disjoint-set forest over the integers 0 .. size-1.

    Balancing uses union by size; lookups use path compression.
    ``self.size[r]`` is the number of elements in the set whose root is ``r``.
    It is only meaningful when ``r`` is a root.
    """

    def __init__(self, size):
        self.parent = [i for i in range(size)]
        # Each element starts in a singleton set, so every size is 1.
        # (Starting at 0 kept every size at 0 forever: every comparison tied,
        # no balancing happened, and component sizes were never recorded.)
        self.size = [1] * size

    def find(self, x):
        """Return the root of x's set, compressing the path along the way."""
        if x == self.parent[x]:
            return self.parent[x]

        # do path compression
        self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x , y):
        """Merge the sets holding x and y, attaching the smaller set under the larger."""
        parentx, parenty = self.find(x), self.find(y)

        if parentx != parenty:
            if self.size[parentx] > self.size[parenty]:
                self.parent[parenty] = parentx
                self.size[parentx] += self.size[parenty]
            else:
                self.parent[parentx] = parenty
                self.size[parenty] += self.size[parentx]

In [20]:
class UnionFind:
    """Disjoint-set forest over the integers 0 .. size-1 with path compression.

    Two balancing strategies are offered: ``unionByRank`` and ``unionBySize``.
    Use one strategy per instance. ``unionByRank`` updates only ``rank`` and
    ``unionBySize`` updates only ``size``, so mixing them leaves the other
    array stale. Connectivity stays correct, but the balancing guarantee is lost.
    """

    def __init__(self, size):
        self.parent = [i for i in range(size)]
        # Upper bound on tree height per root (used by unionByRank).
        self.rank = [0] * size
        # Number of elements under each root (used by unionBySize).
        # Must start at 1 for singleton sets. Starting at 0 kept every size at 0,
        # so unionBySize never balanced and never recorded component sizes.
        self.size = [1] * size

    def find(self, x):
        """Return the root of x's set, compressing the path along the way."""
        if x == self.parent[x]:
            return self.parent[x]
        # path compression
        self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def unionByRank(self, x, y):
        """Merge the sets holding x and y, attaching the lower-rank root under the higher."""
        parentx , parenty = self.find(x), self.find(y)

        if parentx != parenty:
            rankx = self.rank[parentx]
            ranky = self.rank[parenty]

            if rankx > ranky:
                self.parent[parenty] = parentx
            elif rankx < ranky:
                self.parent[parentx] = parenty
            else:
                # Equal ranks: the merged tree gets one level taller.
                self.parent[parenty] = parentx
                self.rank[parentx] += 1

    def unionBySize(self, x, y):
        """Merge the sets holding x and y, attaching the smaller set under the larger."""
        parentx , parenty = self.find(x), self.find(y)

        if parentx != parenty:
            if self.size[parentx] > self.size[parenty]:
                self.parent[parenty] = parentx
                self.size[parentx] += self.size[parenty]
            else:
                self.parent[parentx] = parenty
                self.size[parenty] += self.size[parentx]

## Kruskal's Algorithm

In [23]:
def kruskal_algorithm(n, edges):
    """Compute a minimum spanning tree with Kruskal's algorithm.

    Args:
        n: number of vertices, labelled 0 .. n-1.
        edges: iterable of ``(u, v, weight)`` tuples describing undirected
            edges. The caller's sequence is NOT reordered.

    Returns:
        ``(total_weight, mst)``: ``mst`` lists the accepted ``(u, v, weight)``
        edges in non-decreasing weight order. If the graph is disconnected,
        the result is a minimum spanning forest with fewer than ``n - 1`` edges.
    """
    uf = UnionFind(n)
    total_weight = 0
    mst = []

    # sorted() rather than edges.sort(): sorting in place would silently
    # reorder the caller's list. Both sorts are stable, so the order is the same.
    for u, v, weight in sorted(edges, key=lambda edge: edge[2]):
        if uf.find(u) == uf.find(v):
            continue  # u and v are already connected; this edge would close a cycle
        uf.unionByRank(u, v)
        total_weight += weight
        mst.append((u, v, weight))

        # A spanning tree on n vertices has exactly n - 1 edges.
        if len(mst) == n - 1:
            break
    return total_weight, mst

# Time Complexity:
    # Sorting edges: O(ElogE), where E is the number of edges.
    # Union-Find operations: O(E·α(V)) in total, since each find/union is amortized O(α(V)), where α is the inverse Ackermann function (almost constant).
# Thus, the overall time complexity is dominated by the edge sorting: O(ElogE).
# Space Complexity: O(V+E), where V is the number of vertices and E is the number of edges.
    
n = 4  # Number of nodes, labelled 0..3

# Each tuple is (u, v, weight) for one undirected edge.
edges = [
    (0, 1, 10),
    (0, 2, 6),
    (0, 3, 5),
    (1, 3, 15),
    (2, 3, 4),
]

mst_cost, mst_edges = kruskal_algorithm(n, edges)
print("MST Cost:", mst_cost)  # Expected: MST Cost: 19
print("MST Edges:", mst_edges)  # Expected: MST Edges: [(2, 3, 4), (0, 3, 5), (0, 1, 10)]

MST Cost: 19
MST Edges: [(2, 3, 4), (0, 3, 5), (0, 1, 10)]


In [25]:
# Trick: count the connected components once the union-find structure is built
# components = sum(1 for i in range(n) if uf.find(i) == i)
# How it works:
# The generator expression visits every node 0..n-1 and checks whether that node is a root (uf.find(i) == i).
# Every component has exactly one root, so yielding 1 for each root counts each component exactly once.
# sum() adds up those 1s, giving the total number of connected components in the graph.