# Minimum Spanning Tree (MST)

### Learning Objectives
By the end of this notebook, you should be able to:
1.  Implement **Prim's Algorithm** (Greedy using Min-Heap).
2.  Implement **Kruskal's Algorithm** (Greedy using DSU/Disjoint Set).
3.  Solve MST problems like **Min Cost to Connect All Points**.

---

### Conceptual Notes

**1. What is an MST?**
A subset of edges of a connected, weighted, undirected graph that connects all the vertices without any cycles and with the minimum possible total edge weight. A spanning tree on V vertices always has exactly V - 1 edges.
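
The definition can be checked directly on tiny graphs. The sketch below uses a hypothetical helper `brute_force_mst` (not one of the notebook's tasks). It tries every set of V - 1 edges and keeps the cheapest acyclic one, because V - 1 acyclic edges on V vertices always form a spanning tree. It runs in exponential time, so treat it only as an illustration or as a test oracle for small inputs.

```python
from itertools import combinations


def _find(parent, x):
    # Follow parent pointers up to the root of x's component.
    while parent[x] != x:
        x = parent[x]
    return x


def brute_force_mst(n, edges):
    """Cheapest set of n - 1 acyclic edges, i.e. the MST weight.
    edges: (u, v, w) tuples with 0-indexed vertices.
    Returns None if no spanning tree exists (disconnected graph)."""
    best = None
    for subset in combinations(edges, n - 1):
        parent = list(range(n))
        acyclic = True
        for u, v, _ in subset:
            ru, rv = _find(parent, u), _find(parent, v)
            if ru == rv:  # this edge would close a cycle
                acyclic = False
                break
            parent[ru] = rv
        if acyclic:  # n - 1 acyclic edges on n vertices span the graph
            total = sum(w for _, _, w in subset)
            if best is None or total < best:
                best = total
    return best


print(brute_force_mst(3, [(0, 1, 1), (1, 2, 1), (0, 2, 5)]))  # 2
```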

**2. Prim's Algorithm (Vertex-Centric)**
*   **Idea:** Grow the MST from a start node ONE edge at a time.
*   **Data Structure:** Min-heap (priority queue) of candidate edges stored as `(weight, node)`: the cost of an edge and the endpoint it would add to the tree.
*   **Logic:** Always pick the cheapest edge connected to the *currently visited set of nodes*.
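
Core Task 1 below only returns the total weight. To see *which* edges the greedy rule picks, a minimal sketch can carry the attaching vertex in each heap entry. The function name `prim_tree_edges` and the extra `parent` field in the heap tuple are assumptions for illustration; the greedy logic is the same lazy Prim's described above.

```python
import heapq


def prim_tree_edges(n, adj, start=0):
    """Lazy Prim's that also records the edge that attached each vertex.
    adj[u] is a list of (neighbor, weight) pairs.
    Returns (total_weight, [(parent, child, weight), ...])."""
    visited = [False] * n
    total = 0
    tree = []
    # Heap entries: (weight, node, parent); parent is -1 for the start node.
    pq = [(0, start, -1)]
    while pq:
        wt, node, parent = heapq.heappop(pq)
        if visited[node]:
            continue  # stale entry: node already joined via a cheaper edge
        visited[node] = True
        total += wt
        if parent != -1:
            tree.append((parent, node, wt))
        # Offer every edge from the new tree vertex to vertices outside the tree.
        for nbr, w in adj[node]:
            if not visited[nbr]:
                heapq.heappush(pq, (w, nbr, node))
    return total, tree


adj = [[(1, 1), (2, 5)], [(0, 1), (2, 1)], [(0, 5), (1, 1)]]
print(prim_tree_edges(3, adj))  # (2, [(0, 1, 1), (1, 2, 1)])
```

The heavier edge 0-2 is pushed but later popped as a stale entry, because node 2 has already joined the tree through the cheaper edge 1-2.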

**3. Kruskal's Algorithm (Edge-Centric)**
*   **Idea:** Sort ALL edges by weight. Pick the smallest edge. If it doesn't form a cycle, add it.
*   **Data Structure:** DSU (Disjoint Set Union) to check for cycles.
*   **Logic:** `if find(u) != find(v): union(u, v), cost += w`.
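
A sketch of the same logic that returns the chosen edges and detects a disconnected graph. The name `kruskal_tree_edges` is hypothetical, and two details go beyond the `DisjointSet` in Core Task 2: union by rank (with path halving in `find`), and an early exit once n - 1 edges are accepted, since a spanning tree never needs more.

```python
def kruskal_tree_edges(n, edges):
    """Kruskal's with a small union-find (path halving + union by rank).
    edges: iterable of (u, v, w) with 0-indexed vertices.
    Returns the list of chosen edges, or None if the graph is disconnected."""
    parent = list(range(n))
    rank = [0] * n

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    def union(a, b):
        ra, rb = find(a), find(b)
        if ra == rb:
            return False  # same component: the edge would form a cycle
        if rank[ra] < rank[rb]:
            ra, rb = rb, ra  # attach the shallower tree under the deeper one
        parent[rb] = ra
        if rank[ra] == rank[rb]:
            rank[ra] += 1
        return True

    chosen = []
    for u, v, w in sorted(edges, key=lambda e: e[2]):
        if union(u, v):
            chosen.append((u, v, w))
            if len(chosen) == n - 1:
                break  # a spanning tree has exactly n - 1 edges
    return chosen if len(chosen) == n - 1 else None


print(kruskal_tree_edges(3, [(0, 1, 1), (1, 2, 1), (0, 2, 5)]))  # [(0, 1, 1), (1, 2, 1)]
```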

**4. Comparison**
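*   **Prim's:** O(E log V) with a binary heap, as in Core Task 1. On dense graphs (E close to V^2), the array-based version without a heap runs in O(V^2), which is asymptotically better than O(E log V) ≈ O(V^2 log V) there.
*   **Kruskal's:** O(E log E) = O(E log V), dominated by sorting the edges. Good for sparse graphs, or when the input is already an edge list.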
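
Two worked steps behind these bounds. First, sorting E edges costs the same order as a heap over V vertices, because E is at most V^2. Second, a rough operation count (constants ignored, so only indicative) for the complete graph on V = 1000 vertices, the largest input in Core Task 3:

$$
\log E \le \log V^2 = 2 \log V \quad\Longrightarrow\quad O(E \log E) = O(E \log V)
$$

$$
E = \frac{V(V-1)}{2} = \frac{1000 \cdot 999}{2} = 499{,}500,
\qquad
E \log_2 V \approx 499{,}500 \times 9.97 \approx 4.98 \times 10^{6},
\qquad
V^2 = 10^{6}
$$
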

### Core Task 1: Prim's Algorithm
Given an adjacency list, return the MST sum.

In [None]:
import heapq

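def prims_algorithm(n, adj):
    """
    Return the sum of MST weights.
    adj[u] is a list of (neighbor, weight) pairs.
    Assumes a connected graph; if it is disconnected, only the
    component containing node 0 is spanned.
    """
    if n == 0:
        return 0
    visited = [False] * n
    mst_sum = 0
    
    # Priority Queue: (weight, node). Start with (0, 0).
    pq = [(0, 0)]
    
    while pq:
        wt, node = heapq.heappop(pq)
        # Lazy deletion: a node can be pushed several times; skip stale entries.
        if visited[node]:
            continue
        # The first pop of a node carries its cheapest connection to the tree.
        visited[node] = True
        mst_sum += wt
        for neighbor, edge_wt in adj[node]:
            if not visited[neighbor]:
                heapq.heappush(pq, (edge_wt, neighbor))
            
    return mst_sum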

### Core Task 2: Kruskal's Algorithm
Given a list of edges `[u, v, w]`, return the MST sum.

In [None]:
class DisjointSet:
    # Minimal DSU with path compression (no union by rank).
    # parent has n + 1 slots, so both 0-indexed and 1-indexed vertices work.
    def __init__(self, n):
        self.parent = list(range(n + 1))
    
    def find(self, i):
        if self.parent[i] == i:
            return i
        self.parent[i] = self.find(self.parent[i])
        return self.parent[i]
    
    def union(self, i, j):
        root_i = self.find(i)
        root_j = self.find(j)
        if root_i != root_j:
            self.parent[root_i] = root_j
            return True
        return False

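def kruskals_algorithm(n, edges):
    """
    edges: List of [u, v, w]
    Return the MST sum (the minimum spanning forest sum if disconnected).
    """
    mst_sum = 0
    dsu = DisjointSet(n)
    
    # Sort edges by weight (the critical greedy step).
    # sorted() leaves the caller's list unchanged.
    for u, v, w in sorted(edges, key=lambda e: e[2]):
        # union() returns True only if u and v were in different
        # components, i.e. the edge does not close a cycle.
        if dsu.union(u, v):
            mst_sum += w
    
    return mst_sum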

### Core Task 3: Min Cost to Connect All Points (LeetCode 1584)
Given coordinates `points = [[x1, y1], [x2, y2], ...]`, the cost to connect points i and j is their Manhattan distance: `|xi - xj| + |yi - yj|`.
*   **Approach:** Build a complete graph where every pair of points is an edge with weight = distance.
*   **Optimization:** Prim's is the better fit because the complete graph has E = V(V-1)/2 ≈ V^2/2 edges (dense), so Kruskal's would first have to sort all of them.
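
The implicit approach mentioned in the code cell below never stores the edges. A minimal sketch, assuming the eager array-based variant of Prim's (a different variant from the heap-based Core Task 1), keeps `best[v]`, the cheapest known connection from the tree to each point, and picks the next point with a linear scan. This gives O(N^2) time and O(N) extra memory. The name `min_cost_connect_points_dense` is hypothetical.

```python
def min_cost_connect_points_dense(points):
    """Array-based (eager) Prim's on the implicit complete graph.
    O(n^2) time, O(n) extra memory: no adjacency list and no heap."""
    n = len(points)
    if n == 0:
        return 0
    INF = float("inf")
    in_tree = [False] * n
    best = [INF] * n  # best[v] = cheapest known edge from the tree to v
    best[0] = 0
    total = 0
    for _ in range(n):
        # Pick the cheapest point not yet in the tree (linear scan instead of a heap).
        u = -1
        for v in range(n):
            if not in_tree[v] and (u == -1 or best[v] < best[u]):
                u = v
        in_tree[u] = True
        total += best[u]
        ux, uy = points[u]
        # Relax: the new tree point may offer a cheaper connection to the rest.
        for v in range(n):
            if not in_tree[v]:
                d = abs(ux - points[v][0]) + abs(uy - points[v][1])
                if d < best[v]:
                    best[v] = d
    return total


print(min_cost_connect_points_dense([[0, 0], [2, 2], [3, 10], [5, 2], [7, 0]]))  # 20
```

Because every pair of points is adjacent, the scan-and-relax loop does the same O(N^2) work that merely building the adjacency list would, so it is a natural cross-check for `minCostConnectPoints`.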

In [None]:
def minCostConnectPoints(points):
    """
    Return min cost.
    """
    n = len(points)
    
    # Step 1. Build the complete graph as an adjacency list (N(N-1)/2 edges).
    # Computing distances on the fly (implicit Prim's) saves memory,
    # but an explicit list is fine for N = 1000.
    adj = [[] for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            dist = abs(points[i][0] - points[j][0]) + abs(points[i][1] - points[j][1])
            adj[i].append((j, dist))
            adj[j].append((i, dist))
            
    # Step 2. Run Prim's (prims_algorithm from Core Task 1) on 'adj'.
    return prims_algorithm(n, adj)

In [None]:
# --- TEST CELL ---
print("Testing Prim's...")
# 0-1(1), 1-2(1), 0-2(5). MST should take 0-1 and 1-2. Sum = 2.
adj_mst = [
    [(1, 1), (2, 5)],
    [(0, 1), (2, 1)],
    [(0, 5), (1, 1)]
]
res_prim = prims_algorithm(3, adj_mst)
# assert res_prim == 2, f"Prim's Failed: {res_prim}"

print("Testing Kruskal's...")
edges_mst = [[0, 1, 1], [1, 2, 1], [0, 2, 5]]
res_kruskal = kruskals_algorithm(3, edges_mst)
# assert res_kruskal == 2, f"Kruskal's Failed: {res_kruskal}"

print("✅ Tests Ready")