# Implementation of a Union-Find Data Structure with union by rank and path compression.

In [11]:
class DisjointForests:
    """Union-find (disjoint-set forest) over the integers 0..n-1.

    Elements start outside the forest and join it through make_set().
    Sets are merged with union by rank, and find() applies path compression.
    """

    def __init__(self, n):
        """Create a forest with room for n elements; none belongs to a set yet.

        Raises:
            AssertionError: if n < 1.
        """
        assert n >= 1, 'Empty disjoint forest is not allowed.'
        self.n = n
        # parents[j] stays None until make_set(j); a representative is its own parent.
        self.parents = [None]*n
        # rank[j] stays None until make_set(j); link() compares ranks to pick the new root.
        self.rank = [None]*n

    # Converts the forest into a dictionary of sets
    def dictionary_of_sets(self):
        """Return {representative: set of members} for every set in the forest.

        Elements for which make_set() was never called are omitted.
        find() is called on every member, so parent pointers get compressed.
        """
        d = {}
        # First pass: one entry per representative, so keys are inserted in
        # increasing order of representative index.
        for i in range(self.n):
            if self.is_representative(i):
                d[i] = {i}
        # Second pass: file every member under its representative.
        for j in range(self.n):
            if self.parents[j] is not None:
                root = self.find(j)
                assert root in d
                d[root].add(j)
        return d

    def make_set(self, j):
        """Add j to the forest as a singleton set with rank 1.

        Raises:
            AssertionError: if j is out of range or already in the forest.
        """
        assert 0 <= j < self.n
        assert self.parents[j] is None, 'j is already a member of some set.'
        self.parents[j] = j
        self.rank[j] = 1

    def is_representative(self, j):
        """Return True when j is the root of its tree (its own parent)."""
        return self.parents[j] == j

    def get_rank(self, j):
        """Return the stored rank of j (None if make_set(j) was never called)."""
        return self.rank[j]

    # Return the representative of the set to which j belongs to.
    # The algorithm employs the strategy of path compression.
    def find(self, j):
        """Return the representative of j's set, compressing the path walked.

        Raises:
            AssertionError: if j is out of range or not a member of the forest.
        """
        # Valid indices are 0..n-1; j == n must be rejected here rather than
        # surfacing later as an IndexError.
        assert 0 <= j < self.n
        assert self.parents[j] is not None, 'j is not a member of the forest.'
        if self.parents[j] == j:
            return j
        # Re-point j directly at the root found for its parent.
        self.parents[j] = self.find(self.parents[j])
        return self.parents[j]

    def link(self, i, j):
        """Attach one representative under the other, choosing by rank.

        The higher-ranked root becomes the parent; on a tie i becomes the
        parent and its rank grows by one. union() only calls this with two
        distinct representatives.
        """
        if i == j:
            # Linking a root to itself is a no-op and must not inflate its rank.
            return
        if self.rank[j] > self.rank[i]:
            self.parents[i] = j
            return
        self.parents[j] = i
        if self.rank[i] == self.rank[j]:
            self.rank[i] += 1

    # Compute union of j1 and j2.
    def union(self, j1, j2):
        """Merge the sets containing j1 and j2 (no-op if already the same set).

        Raises:
            AssertionError: if either index is out of range or not in the forest.
        """
        assert 0 <= j1 < self.n
        assert 0 <= j2 < self.n
        assert self.parents[j1] is not None
        assert self.parents[j2] is not None
        rep1 = self.find(j1)
        rep2 = self.find(j2)
        if rep1 != rep2:
            self.link(rep1, rep2)
        

In [12]:
# Exercise the union-find structure on ten elements.
d = DisjointForests(10)

for i in range(10):
    d.make_set(i)

# Before any union, each element is its own representative.
for i in range(10):
    assert d.find(i) == i, f'Failed: Find on {i} must return {i} back'

d.union(0, 1)
d.union(2, 3)
assert d.find(0) == d.find(1), '0 and 1 have been union-ed together'
assert d.find(2) == d.find(3), '2 and 3 have been union-ed together'
assert d.find(0) != d.find(3), '0 and 3 should be in different trees'
# After merging two singletons, exactly one of the pair has rank 2 and the other rank 1.
assert sorted((d.get_rank(0), d.get_rank(1))) == [1, 2], 'one of the nodes 0 or 1 must have rank 2'
assert sorted((d.get_rank(2), d.get_rank(3))) == [1, 2], 'one of the nodes 2 or 3 must have rank 2'

d.union(3, 4)
assert d.find(2) == d.find(4), '2 and 4 must be in the same set in the family.'

d.union(5, 7)
d.union(6, 8)
d.union(3, 7)
d.union(0, 6)

assert d.find(6) == d.find(1), '1 and 6 must be in the same set in the family'
assert d.find(7) == d.find(4), '7 and 4 must be in the same set in the family'
print('-- All tests passed --')

-- All tests passed --


# Implementing a weighted undirected graph.

In [13]:
class UndirectedGraph:
    """Weighted undirected graph on vertices 0..n-1, stored as an edge list.

    Each edge is kept as a tuple (i, j, weight); every vertex also carries an
    optional data slot that starts out as None.
    """

    def __init__(self, n):
        """Create a graph with n vertices and no edges (n must be at least 1)."""
        assert n >= 1, 'Empty graphs not allowed.'
        self.n = n
        self.vertex_data = [None for _ in range(n)]
        self.edges = []

    def set_vertex_data(self, j, dat):
        """Store dat in the data slot of vertex j."""
        assert 0 <= j and j < self.n
        self.vertex_data[j] = dat

    def get_vertex_data(self, j):
        """Return whatever is stored in the data slot of vertex j."""
        assert 0 <= j and j < self.n
        return self.vertex_data[j]

    def add_edge(self, i, j, wij):
        """Record an edge between distinct vertices i and j with weight wij."""
        assert 0 <= i and i < self.n
        assert 0 <= j and j < self.n
        assert i != j
        new_edge = (i, j, wij)
        self.edges.append(new_edge)

    def sort_edges(self):
        """Rebind self.edges to a new list ordered by ascending weight."""
        def weight_of(edge):
            return edge[2]

        ordered = sorted(self.edges, key=weight_of)
        self.edges = ordered

# Using the union-find data structure to find the connected components of an undirected graph (called SCCs below), keeping only edges whose weight is at most a threshold.

In [15]:
def compute_scc(g, W): # W is the weight threshold.
    """Group the vertices of g into components joined by edges of weight <= W.

    Side effect: every entry g.vertex_data[i] is overwritten with i.

    Returns:
        A dict mapping each component's representative vertex to the set of
        vertices in that component.
    """
    forest = DisjointForests(g.n)
    vertex_count = len(g.vertex_data)
    for vertex in range(vertex_count):
        g.vertex_data[vertex] = vertex
        forest.make_set(vertex)
    for u, v, w in g.edges:
        # Only edges within the threshold may merge two components.
        if w <= W and forest.find(u) != forest.find(v):
            forest.union(u, v)
    return forest.dictionary_of_sets()

In [16]:
# Running test cases.

g3 = UndirectedGraph(8)
g3.add_edge(0,1,0.5)
g3.add_edge(0,2,1.0)
g3.add_edge(0,4,0.5)
g3.add_edge(2,3,1.5)
g3.add_edge(2,4,2.0)
g3.add_edge(3,4,1.5)
g3.add_edge(5,6,2.0)
g3.add_edge(5,7,2.0)
res = compute_scc(g3, 2.0)
print('SCCs with threshold 2.0 computed by your code are:')
assert len(res) == 2, f'Expected 2 SCCs but got {len(res)}'
for (k, s) in res.items():
    print(s)
    
# Let us check that your code returns what we expect.
for (k, s) in res.items():
    if (k in [0,1,2,3,4]):
        assert (s == set([0,1,2,3,4])), '{0,1,2,3,4} should be an SCC'
    if (k in [5,6,7]):
        assert (s == set([5,6,7])), '{5,6,7} should be an SCC'
        
# Let us check that the thresholding works
print('SCCs with threshold 1.5')
res2 = compute_scc(g3, 1.5) # This cutsoff edges 2,4 and 5, 6, 7
for (k, s) in res2.items():
    print(s)
assert len(res2) == 4, f'Expected 4 SCCs but got {len(res2)}'

for (k, s) in res2.items():
    if k in [0,1,2,3,4]:
        assert (s == set([0,1,2,3,4])), '{0,1,2,3,4} should be an SCC'
    if k in [5]:
        assert s == set([5]), '{5} should be an SCC with just a single node.'
    if k in [6]:
        assert s == set([6]), '{6} should be an SCC with just a single node.'
    if k in [7]:
        assert s == set([7]), '{7} should be an SCC with just a single node.'
        
print('-- All tests passed --')

SCCs with threshold 2.0 computed by your code are:
{0, 1, 2, 3, 4}
{5, 6, 7}
SCCs with threshold 1.5
{0, 1, 2, 3, 4}
{5}
{6}
{7}
-- All tests passed --


# Computing the minimum spanning tree using Kruskal's algorithm.

In [17]:
def compute_mst(g):
    """Run Kruskal's algorithm on g.

    Side effect: g.sort_edges() is called, so g.edges ends up ordered by
    ascending weight.

    Returns:
        A tuple (chosen_edges, total_weight) where chosen_edges lists the
        accepted (i, j, weight) tuples in the order they were picked.
    """
    forest = DisjointForests(g.n)
    g.sort_edges()
    for vertex in range(len(g.vertex_data)):
        forest.make_set(vertex)

    chosen = []
    total = 0  # Running weight of the accepted edges.
    for edge in g.edges:
        u, v = edge[0], edge[1]
        if forest.find(u) == forest.find(v):
            # Both endpoints already share a tree; this edge would close a cycle.
            continue
        chosen.append(edge)
        forest.union(u, v)
        total += edge[2]
    return (chosen, total)

In [18]:
# Running tests

g3 = UndirectedGraph(8)
g3.add_edge(0,1,0.5)
g3.add_edge(0,2,1.0)
g3.add_edge(0,4,0.5)
g3.add_edge(2,3,1.5)
g3.add_edge(2,4,2.0)
g3.add_edge(3,4,1.5)
g3.add_edge(5,6,2.0)
g3.add_edge(5,7,2.0)
g3.add_edge(3,5,2.0)


(mst_edges, mst_weight) = compute_mst(g3)
print('Your code computed MST: ')
for (i,j,wij) in mst_edges:
    print(f'\t {(i,j)} weight {wij}')
    
print(f'Total edge weight: {mst_weight}')

assert mst_weight == 9.5, 'Optimal MST weight is expected to be 9.5'
assert (0,1,0.5) in mst_edges
assert (0,2,1.0) in mst_edges
assert (0,4,0.5) in mst_edges
assert (5,6,2.0) in mst_edges
assert (5,7,2.0) in mst_edges
assert (3,5,2.0) in mst_edges
assert (2,3, 1.5) in mst_edges or (3,4, 1.5) in mst_edges
print('-- All tests passed --')

Your code computed MST: 
	 (0, 1) weight 0.5
	 (0, 4) weight 0.5
	 (0, 2) weight 1.0
	 (2, 3) weight 1.5
	 (5, 6) weight 2.0
	 (5, 7) weight 2.0
	 (3, 5) weight 2.0
Total edge weight: 9.5
-- All tests passed --
