# Disjoint Set (Union-Find) Data Structure

This notebook implements the Disjoint Set Abstract Data Type (ADT) with the following operations:
- **MAKE-SET(x)** in Θ(1): Creates a set containing only element x
- **FIND-SET(x)** in O(α(n)) amortized: Finds the representative (root) of the set containing x
- **UNION(x,y)** in O(α(n)) amortized: Merges the sets containing x and y

where n is the number of elements and α(n) is the inverse Ackermann function, which grows extremely slowly (α(n) ≤ 4 for all practical values of n).
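
To make "grows extremely slowly" concrete, here is a minimal sketch. It assumes one common textbook definition: A_0(j) = j + 1, A_k(j) applies A_{k-1} to j a total of j + 1 times, and α(n) is the smallest k with A_k(1) ≥ n. Other texts define α slightly differently, so the exact thresholds may vary. The names `ackermann_level`, `THRESHOLDS`, and `inverse_ackermann` are illustrative and not part of this notebook's class.

```python
def ackermann_level(k: int, j: int) -> int:
    """A_0(j) = j + 1; for k >= 1, apply A_{k-1} to j a total of j + 1 times."""
    if k == 0:
        return j + 1
    result = j
    for _ in range(j + 1):
        result = ackermann_level(k - 1, result)
    return result


# A_k(1) for k = 0..3. A_4(1) is far too large to compute this way.
THRESHOLDS = [ackermann_level(k, 1) for k in range(4)]


def inverse_ackermann(n: int) -> int:
    """Smallest k with A_k(1) >= n, under the definition assumed above."""
    for k, threshold in enumerate(THRESHOLDS):
        if threshold >= n:
            return k
    # Assumption: n <= A_4(1), which holds for any input size that fits in memory.
    return 4


print(THRESHOLDS)  # [2, 3, 7, 2047]
print([inverse_ackermann(n) for n in (2, 3, 7, 2047, 10**9)])  # [0, 1, 2, 3, 4]
```

Under this definition, every n from 2048 up to an astronomically large bound maps to 4, which is why α(n) is treated as a constant in practice.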

## Implementation Strategy

To achieve O(α(n)) amortized time per operation, we use two key optimizations:

1. **Union by Rank**: When merging two sets, attach the tree with lower rank under the root of the tree with higher rank. This keeps trees shallow.

2. **Path Compression**: During FIND-SET, make all nodes on the path from x to the root point directly to the root. This flattens the tree structure.

Together, these techniques give O(α(n)) amortized time per operation, which is effectively constant in practice.
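
The effect of path compression can be seen on a bare parent map, independent of the class defined in this notebook. The sketch below uses an iterative two-pass variant (locate the root, then re-point each node on the path to it). The helper names `depth` and `find_with_compression` are illustrative only.

```python
from typing import Dict


def depth(parent: Dict[int, int], x: int) -> int:
    """Number of parent-pointer hops from x to its root."""
    hops = 0
    while parent[x] != x:
        x = parent[x]
        hops += 1
    return hops


def find_with_compression(parent: Dict[int, int], x: int) -> int:
    """Return the root of x and re-point every node on the path to that root."""
    # Pass 1: walk up to the root.
    root = x
    while parent[root] != root:
        root = parent[root]
    # Pass 2: walk the same path again, attaching each node directly to the root.
    while parent[x] != root:
        parent[x], x = root, parent[x]
    return root


# A worst-case chain: 4 -> 3 -> 2 -> 1 -> 0 (0 is the root).
parent = {0: 0, 1: 0, 2: 1, 3: 2, 4: 3}
print(depth(parent, 4))                    # 4
print(find_with_compression(parent, 4))    # 0
print([depth(parent, x) for x in parent])  # [0, 1, 1, 1, 1]
```

A single FIND-SET on the deepest node flattens the whole chain, so later lookups on any of those nodes take one hop.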

In [32]:
from typing import Dict, Hashable

class DisjointSet:
    """
    Disjoint Set (Union-Find) data structure with path compression and union by rank.
    
    Attributes:
        parent: Dictionary mapping each element to its parent
        rank: Dictionary mapping each element to its rank (an upper bound on the height of the tree rooted at it)
    """
    
    def __init__(self):
        """Initialize empty disjoint set structure."""
        self.parent: Dict[Hashable, Hashable] = {}
        self.rank: Dict[Hashable, int] = {}
    
    def make_set(self, x: Hashable) -> None:
        """
        Create a new set containing only element x.
        Time Complexity: Θ(1)
        
        Args:
            x: The element to create a set for
        """
        if x not in self.parent:
            self.parent[x] = x  # x is its own parent (root)
            self.rank[x] = 0    # Initial rank is 0
    
    def find_set(self, x: Hashable) -> Hashable:
        """
        Find the representative (root) of the set containing x.
        Uses path compression for efficiency.
        Time Complexity: O(α(n)) amortized
        
        Args:
            x: The element to find the set representative for
            
        Returns:
            The representative element of the set containing x
        """
        if x not in self.parent:
            raise ValueError(f"Element {x} is not in any set")
        
        # Path compression: the recursion re-points every node on the path from x directly to the root
        if self.parent[x] != x:
            self.parent[x] = self.find_set(self.parent[x])
        
        return self.parent[x]
    
    def union(self, x: Hashable, y: Hashable) -> None:
        """
        Merge the sets containing x and y.
        Uses union by rank to keep trees balanced.
        Time Complexity: O(α(n)) amortized
        
        Args:
            x: Element in the first set
            y: Element in the second set
        """
        # Find the roots of both sets
        root_x = self.find_set(x)
        root_y = self.find_set(y)
        
        # If already in the same set, nothing to do
        if root_x == root_y:
            return
        
        # Union by rank: attach smaller rank tree under root of higher rank tree
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
        elif self.rank[root_x] > self.rank[root_y]:
            self.parent[root_y] = root_x
        else:
            # Same rank: make one root the parent and increment its rank
            self.parent[root_y] = root_x
            self.rank[root_x] += 1
    
    def connected(self, x: Hashable, y: Hashable) -> bool:
        """
        Check if x and y are in the same set.
        
        Args:
            x: First element
            y: Second element
            
        Returns:
            True if x and y are in the same set; False otherwise, including
            when either element has not been added with make_set
        """
        try:
            return self.find_set(x) == self.find_set(y)
        except ValueError:
            return False
    
    def __repr__(self) -> str:
        """String representation of the disjoint set structure."""
        sets: Dict[Hashable, list] = {}
        for element in self.parent:
            root = self.find_set(element)
            if root not in sets:
                sets[root] = []
            sets[root].append(element)
        return f"DisjointSet({list(sets.values())})"

## Basic Usage Examples

In [33]:
# Create a new disjoint set structure
ds = DisjointSet()

# MAKE-SET: Create individual sets
elements = [1, 2, 3, 4, 5, 6, 7, 8]
for elem in elements:
    ds.make_set(elem)

print("Initial state - each element in its own set:")
print(ds)
print()

Initial state - each element in its own set:
DisjointSet([[1], [2], [3], [4], [5], [6], [7], [8]])



In [34]:
# UNION: Merge some sets
ds.union(1, 2)
ds.union(3, 4)
ds.union(5, 6)
print("After union(1,2), union(3,4), union(5,6):")
print(ds)
print()

# More unions
ds.union(1, 3)  # Merge {1,2} with {3,4}
print("After union(1,3) - merges {1,2} with {3,4}:")
print(ds)
print()

After union(1,2), union(3,4), union(5,6):
DisjointSet([[1, 2], [3, 4], [5, 6], [7], [8]])

After union(1,3) - merges {1,2} with {3,4}:
DisjointSet([[1, 2, 3, 4], [5, 6], [7], [8]])



In [35]:
# FIND-SET: Find the representative of each element
print("Finding representatives:")
for elem in [1, 2, 3, 4, 5, 6, 7, 8]:
    print(f"FIND-SET({elem}) = {ds.find_set(elem)}")
print()

# Check connectivity
print("Connectivity checks:")
print(f"Are 1 and 4 connected? {ds.connected(1, 4)}")
print(f"Are 1 and 5 connected? {ds.connected(1, 5)}")
print(f"Are 7 and 8 connected? {ds.connected(7, 8)}")

Finding representatives:
FIND-SET(1) = 1
FIND-SET(2) = 1
FIND-SET(3) = 1
FIND-SET(4) = 1
FIND-SET(5) = 5
FIND-SET(6) = 5
FIND-SET(7) = 7
FIND-SET(8) = 8

Connectivity checks:
Are 1 and 4 connected? True
Are 1 and 5 connected? False
Are 7 and 8 connected? False
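
Elements do not have to be integers; any hashable value works. As a hedged illustration of a typical use, the sketch below groups string-labelled nodes into connected components from an edge list. It is self-contained and re-implements a minimal union-find inline rather than using the `DisjointSet` class, and the function name `connected_components` is an assumption made for this example.

```python
from typing import Dict, Hashable, Iterable, List, Tuple


def connected_components(
    nodes: Iterable[Hashable], edges: Iterable[Tuple[Hashable, Hashable]]
) -> List[List[Hashable]]:
    """Group nodes into components using union by rank and path compression."""
    node_list = list(nodes)
    parent: Dict[Hashable, Hashable] = {v: v for v in node_list}
    rank: Dict[Hashable, int] = {v: 0 for v in node_list}

    def find(x: Hashable) -> Hashable:
        if parent[x] != x:
            parent[x] = find(parent[x])  # path compression
        return parent[x]

    for u, v in edges:
        root_u, root_v = find(u), find(v)
        if root_u == root_v:
            continue
        # Union by rank: make root_u the root with the larger (or equal) rank.
        if rank[root_u] < rank[root_v]:
            root_u, root_v = root_v, root_u
        parent[root_v] = root_u
        if rank[root_u] == rank[root_v]:
            rank[root_u] += 1

    groups: Dict[Hashable, List[Hashable]] = {}
    for v in node_list:
        groups.setdefault(find(v), []).append(v)
    return list(groups.values())


components = connected_components(
    ["a", "b", "c", "d", "e"],
    [("a", "b"), ("c", "d"), ("b", "c")],
)
print(components)  # [['a', 'b', 'c', 'd'], ['e']]
```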


## Time Complexity Verification

In [37]:
import time
import random

def benchmark_disjoint_set(n):
    """Benchmark disjoint set operations."""
    ds = DisjointSet()
    
    # MAKE-SET operations (perf_counter is a monotonic, high-resolution timer)
    start = time.perf_counter()
    for i in range(n):
        ds.make_set(i)
    make_set_time = time.perf_counter() - start

    # UNION operations: chain i with i + 1 until everything is one set
    start = time.perf_counter()
    for i in range(n - 1):
        ds.union(i, i + 1)
    union_time = time.perf_counter() - start

    # FIND-SET operations (on a fully connected set)
    start = time.perf_counter()
    for _ in range(n):
        ds.find_set(random.randint(0, n - 1))
    find_set_time = time.perf_counter() - start
    
    return make_set_time, union_time, find_set_time

# Test with different sizes
sizes = [1000, 5000, 10000, 50000]
print(f"{'n':>10} | {'MAKE-SET (ms)':>15} | {'UNION (ms)':>12} | {'FIND-SET (ms)':>15}")
print("-" * 65)

for n in sizes:
    ms_time, u_time, fs_time = benchmark_disjoint_set(n)
    print(f"{n:>10} | {ms_time*1000:>15.2f} | {u_time*1000:>12.2f} | {fs_time*1000:>15.2f}")

print("\nNote: Total time for n operations grows roughly linearly with n, consistent with near-constant amortized cost per operation.")

         n |   MAKE-SET (ms) |   UNION (ms) |   FIND-SET (ms)
-----------------------------------------------------------------
      1000 |            0.08 |         0.33 |            0.31
      5000 |            0.35 |         1.67 |            1.73
     10000 |            0.73 |         3.33 |            3.45
     50000 |            3.89 |        16.90 |           17.28

Note: Total time for n operations grows roughly linearly with n, consistent with near-constant amortized cost per operation.
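
Wall-clock timings are noisy, and linear total time alone cannot separate O(α(n)) from other slowly growing per-operation costs. A complementary check is to count parent-pointer hops directly. The sketch below is a self-contained, array-based variant with random unions, not the `DisjointSet` class from this notebook. The name `average_hops`, the seed, and the workload are assumptions chosen for illustration.

```python
import math
import random


def average_hops(n: int, seed: int = 0) -> float:
    """Average parent-pointer hops per find over n random unions and n random finds."""
    rng = random.Random(seed)
    parent = list(range(n))
    rank = [0] * n
    hops = 0
    finds = 0

    def find(x: int) -> int:
        nonlocal hops, finds
        finds += 1
        root = x
        while parent[root] != root:  # pass 1: locate the root, counting hops
            root = parent[root]
            hops += 1
        while parent[x] != root:     # pass 2: compress the path
            parent[x], x = root, parent[x]
        return root

    def union(x: int, y: int) -> None:
        root_x, root_y = find(x), find(y)
        if root_x == root_y:
            return
        if rank[root_x] < rank[root_y]:
            root_x, root_y = root_y, root_x
        parent[root_y] = root_x
        if rank[root_x] == rank[root_y]:
            rank[root_x] += 1

    for _ in range(n):
        union(rng.randrange(n), rng.randrange(n))
    for _ in range(n):
        find(rng.randrange(n))
    return hops / finds


for n in (1_000, 10_000, 100_000):
    print(f"n={n:>7}: {average_hops(n):.3f} hops per find (log2(n) = {math.log2(n):.1f})")
```

Union by rank alone guarantees that no find takes more than log2(n) hops. If the average stays well below that bound as n grows, this supports the near-constant amortized cost more directly than timing does.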
