# Union Find

### Weighted quick-union by count and path compression by halving.

- This class represents a union–find data type. It supports the classic union and find operations, along with a count operation that returns the total number of sets.

- The union–find data type models a collection of sets containing n elements, with each element in exactly one set. The elements are named 0 through n–1.

- Initially, there are n sets, with each element in its own set. The canonical element of a set (also known as the root, identifier, leader, or set representative) is one distinguished element in the set. Here is a summary of the operations:

> find(p): returns the canonical element of the set containing p. The find operation returns the same value for two elements if and only if they are in the same set.

> union(p,q): merges the set containing element p with the set containing element q. That is, if p and q are in different sets, replace these two sets with a new set that is the union of the two.

> count(): returns the number of sets.
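The contract of these three operations can be illustrated with a deliberately naive reference model. This is a minimal sketch, not the implementation described in this document: the class name `QuickFindModel` and its fields are hypothetical, and it uses the simpler quick-find technique (relabel every member on union), which is slow but makes the expected behaviour easy to read.

```python
class QuickFindModel:
    """Naive reference model (hypothetical): ids[p] is the canonical element of p's set."""

    def __init__(self, n):
        self.ids = list(range(n))  # every element starts as its own canonical element
        self.sets = n              # n singleton sets

    def find(self, p):
        return self.ids[p]

    def union(self, p, q):
        rp, rq = self.ids[p], self.ids[q]
        if rp == rq:
            return  # already in the same set: nothing changes
        # Relabel every member of p's set with q's canonical element.
        self.ids = [rq if r == rp else r for r in self.ids]
        self.sets -= 1

    def count(self):
        return self.sets


m = QuickFindModel(5)
m.union(0, 1)
m.union(3, 4)
same = m.find(0) == m.find(1)       # 0 and 1 share a canonical element
different = m.find(1) != m.find(3)  # {0, 1} and {3, 4} are distinct sets
sets_left = m.count()               # {0, 1}, {2}, {3, 4}
```

Any correct union–find implementation should agree with this model on which pairs share a canonical element and on the value of count(), even though the particular canonical element chosen may differ.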

This implementation uses weighted quick-union by size (the `count` array tracks the number of elements under each root) with path compression by halving.
The constructor takes &Theta;(<em>n</em>) time, where <em>n</em> is the number of elements.
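Path compression by halving can be seen in isolation on a bare parent array. The following is a small sketch under stated assumptions: the function name `find_with_halving` and the example chain are illustrative only, and the array plays the role of the parent links described above.

```python
def find_with_halving(parent, key):
    """Return the root of key, pointing each visited node at its grandparent."""
    while key != parent[key]:
        parent[key] = parent[parent[key]]  # skip one level
        key = parent[key]                  # continue from the old grandparent
    return key


# A chain 4 -> 3 -> 2 -> 1 -> 0, where 0 is the root.
parent = [0, 0, 1, 2, 3]
root = find_with_halving(parent, 4)
print(root)    # 0
print(parent)  # [0, 0, 0, 2, 2]
```

Only every other node on the path is re-pointed (here 4 and 2), so a single pass roughly halves the path length without needing a second loop or recursion.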

### Big O
- The <em>union</em> and <em>find</em> operations take &Theta;(log <em>n</em>) time in the worst case.
- The <em>count</em> operation (stored as the <code>unique_sets</code> attribute in the implementation below) takes &Theta;(1) time.
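The logarithmic bound comes from the weighting rule: a node's depth only grows when its tree is merged into one at least as large, so the size of its tree at least doubles each time. The sketch below is an illustration under assumptions, not the implementation from this document: the helper names `build_by_pairwise_merges` and `depth` are hypothetical, and path compression is left out on purpose so the effect of weighting alone is visible.

```python
import math


def build_by_pairwise_merges(n):
    """Union by size only (no compression), always merging equal-size trees."""
    parent = list(range(n))
    size = [1] * n

    def find(p):
        while p != parent[p]:
            p = parent[p]
        return p

    def union(p, q):
        rp, rq = find(p), find(q)
        if rp == rq:
            return
        if size[rp] <= size[rq]:   # attach the smaller tree under the larger
            parent[rp] = rq
            size[rq] += size[rp]
        else:
            parent[rq] = rp
            size[rp] += size[rq]

    step = 1
    while step < n:
        for i in range(0, n - step, 2 * step):
            union(i, i + step)
        step *= 2
    return parent


def depth(parent, p):
    """Number of links from p up to its root."""
    d = 0
    while p != parent[p]:
        p = parent[p]
        d += 1
    return d


n = 16
parent = build_by_pairwise_merges(n)
max_depth = max(depth(parent, p) for p in range(n))
print(max_depth, math.log2(n))  # 4 4.0
```

Merging equal-size trees is the pattern that makes trees as tall as the weighting rule allows, and even then the deepest node sits exactly log2(n) links from the root.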

### Implementation

In [101]:
class UnionFind:
    
    def __init__(self, n: int):
        # nodes[i] is the parent of element i; a root is its own parent.
        self.nodes = list(range(n))
        
        # Number of unique sets across all the nodes.
        self.unique_sets = n
        
        # count[r] is the number of nodes in the set rooted at r
        # (only meaningful while r is a root).
        self.count = [1] * n
        
    
    def print_nodes(self):
        """
        Helper function designed to print all variables.
        """
        print('Ind: ', end=" ")
        for i in range(0, len(self.nodes)):
            print(i, end=" ")
        print()
        print('Val: ', end=" ")
        for i in range(0, len(self.nodes)):
            print(self.nodes[i], end=" ")
        print()
        print('Cnt: ', end=" ")
        for i in range(0, len(self.count)):
            print(self.count[i], end=" ")
        print()
        print('The number of unique sets is: ', self.unique_sets)

        
    def root(self,key):
        assert 0 <= key < len(self.nodes), 'Key must be in the range [0, n-1]'
        while key != self.nodes[key]:
            # Path compression by halving: point key at its grandparent.
            self.nodes[key] = self.nodes[self.nodes[key]]
            key = self.nodes[key]

        return key
            
        
    def isConnected(self, a, b):
        return self.root(a) == self.root(b)
    
    
    def union(self, a, b):
        root_a = self.root(a)
        root_b = self.root(b)
        
        if root_a == root_b:
            return
        
        if self.count[root_a] <= self.count[root_b]:
            self.nodes[root_a] = root_b
            self.count[root_b] = self.count[root_b] + self.count[root_a]
        else:
            self.nodes[root_b] = root_a
            self.count[root_a] = self.count[root_b] + self.count[root_a]
            
        self.unique_sets = self.unique_sets - 1
            
            
    def find(self, a):
        return self.root(a)

In [102]:
n = 10
uf = UnionFind(n)

In [103]:
uf.print_nodes()

Ind:  0 1 2 3 4 5 6 7 8 9 
Val:  0 1 2 3 4 5 6 7 8 9 
Cnt:  1 1 1 1 1 1 1 1 1 1 
The number of unique sets is:  10


In [104]:
uf.union(2, 4)
uf.print_nodes()

Ind:  0 1 2 3 4 5 6 7 8 9 
Val:  0 1 4 3 4 5 6 7 8 9 
Cnt:  1 1 1 1 2 1 1 1 1 1 
The number of unique sets is:  9


In [105]:
uf.union(3, 4)
uf.print_nodes()

Ind:  0 1 2 3 4 5 6 7 8 9 
Val:  0 1 4 4 4 5 6 7 8 9 
Cnt:  1 1 1 1 3 1 1 1 1 1 
The number of unique sets is:  8


In [107]:
uf.find(2)

4