# Union-Find | Disjoint Set Union | "DSU"

A Disjoint Set Union (DSU), also known as Union-Find, is a data structure that keeps track of a set of elements partitioned into a number of disjoint (non-overlapping) subsets.

## What

1. **Find**: Determine which subset a particular element is in. This can be used for determining if two elements are in the same subset.
2. **Union**: Join two subsets into a single subset. We first check whether the two elements already belong to the same set; if they do, there is nothing to merge.
3. A DSU supports combining any two sets, telling which set a specific element belongs to, and creating a new set from a new element.
4. With path compression and union by size/rank (described below), each operation runs in amortized `O(α(n))` time, where `α` is the inverse Ackermann function. In practice this is constant.

## Why

Graphs can be mapped abstractly onto mathematical sets. A set is a collection of elements treated as one unit, and discrete mathematics defines logical operations on sets: a logical AND corresponds to the intersection of two sets, and an OR corresponds to their union. Mapping this back to graphs, a connected component is a set. Counting the connected components is therefore the same as counting the distinct sets in a collection of disjoint (unconnected) sets.
Some useful problems that can be solved using a DSU are:

1. Detect a Cycle in a graph
2. Count # of connected components.
3. Compress Jumps on a Path
4. Paint subarray offline | Determine max events we can attend given schedule

## Interface

1. `make_set(v)`: Creates a new set with new element `v`.
2. `find_set(v)`: Finds the set that `v` belongs to.
3. `union_sets(a, b)`: Combines sets `a` and `b`.
   1. Union by _Rank_ = **Size**
   2. Union by _Rank_ = **Depth**
4. Distance from Child up to Parent
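Item 4 is not spelled out in this notebook. One common reading (an assumption on our part) is a DSU that stores, next to each parent pointer, the number of edges from a node to its parent, so that `find` also reports how far a node sits below its set's representative. The sketch below folds those distances together during path compression; `DistanceDSU` and `attach` are hypothetical names.

```python
class DistanceDSU:
    """Hypothetical DSU that also tracks each node's distance to its root."""

    def __init__(self, n):
        self.parents = list(range(n))
        self.dist = [0] * n  # number of edges from v up to parents[v]

    def find(self, v):
        """Return (root, distance from v to root), compressing the path."""
        if v == self.parents[v]:
            return v, 0
        root, d = self.find(self.parents[v])
        self.dist[v] += d  # v -> old parent, plus old parent -> root
        self.parents[v] = root  # path compression
        return root, self.dist[v]

    def attach(self, child_root, parent):
        """Hang the tree rooted at child_root one edge below `parent`."""
        self.parents[child_root] = parent
        self.dist[child_root] = 1


d = DistanceDSU(4)
d.attach(1, 0)
d.attach(2, 1)
d.attach(3, 2)  # chain: 3 -> 2 -> 1 -> 0
print(d.find(3))  # (0, 3)
```

After the call, node 3 points straight at 0 but still remembers that it is 3 edges away.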


In [None]:
# Skeleton Structure
class DSU:
    def __init__(self):
        self.parents = []

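    def make_set(self, v):
        # grow the list so index v exists; assumes v is a non-negative int
        if v >= len(self.parents):
            self.parents.extend(range(len(self.parents), v + 1))
        self.parents[v] = v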

    def find(self, v):
        if v != self.parents[v]:
            self.parents[v] = self.find(self.parents[v])
        return self.parents[v]

    def union(self, v1, v2):
        p1, p2 = map(self.find, (v1, v2))
        if p1 != p2:
            self.parents[min(p1, p2)] = max(p1, p2)

There's also another interface for the `make_set()` function. It generates the new element's id from the current length of the list, which guarantees a unique id, and returns that id to the caller. This makes sense whenever your inputs are not numbers that map onto a sequentially increasing sequence, for example strings. In that case you keep your own mapping from each value to the id returned by `make_set()` and run the DSU on the ids.


In [None]:
class DSU:
    def __init__(self) -> None:
        self.parents = []

    # find() and union() are the same as in the skeleton above
    def make_set(self):
        self.parents.append(len(self.parents))
        return len(self.parents) - 1
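A sketch of how this `make_set()` can serve non-numeric labels: keep a dictionary from label to id and create the set the first time a label is seen. The `LabeledDSU` wrapper and its `id_of`/`connected` helpers are hypothetical names for illustration, not part of the original cells.

```python
class LabeledDSU:
    """Hypothetical wrapper: maps arbitrary hashable labels to integer ids."""

    def __init__(self):
        self.parents = []
        self.ids = {}  # label -> id returned by make_set()

    def make_set(self):
        self.parents.append(len(self.parents))
        return len(self.parents) - 1

    def id_of(self, label):
        # create a new singleton set the first time we see this label
        if label not in self.ids:
            self.ids[label] = self.make_set()
        return self.ids[label]

    def find(self, v):
        if v != self.parents[v]:
            self.parents[v] = self.find(self.parents[v])  # path compression
        return self.parents[v]

    def union(self, a, b):
        p1, p2 = self.find(self.id_of(a)), self.find(self.id_of(b))
        if p1 != p2:
            self.parents[p2] = p1

    def connected(self, a, b):
        return self.find(self.id_of(a)) == self.find(self.id_of(b))


dsu = LabeledDSU()
dsu.union("alice", "bob")
dsu.union("bob", "carol")
print(dsu.connected("alice", "carol"))  # True
print(dsu.connected("alice", "dave"))   # False
```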

## How

### `find()` + Path Compression

1. We represent each set as a logical _tree_ and store all the trees in a single list: `parents[v]` is the parent of `v`, and a root is its own parent.

<img src="https://imgur.com/50fAXXD.png" style="max-width:500px">

2. The benefit of this data structure is that we can flatten the parent/child relationships for fast lookups: **path compression**. Take the image below for example.

<img src="https://imgur.com/26kskt3.png" style="max-width:500px">

The naive implementation (left side) is equivalent to the following `find()` code:


In [None]:
class DSU:
    def __init__(self):
        self.parents = []

    def find(self, v):
        if v == self.parents[v]:
            return v
        return self.find(self.parents[v])  # Order(n) Time lookup

However, during `find()` we can point every node we pass directly at the root, flattening the tree into a shallow n-ary tree. Later lookups from any of those nodes then reach the root in a single step. A small adjustment to `find()` gives this optimization:


In [None]:
class DSU:
    def __init__(self):
        self.parents = []

    def find(self, v):
        if v != self.parents[v]:
            self.parents[v] = self.find(self.parents[v])  # path compression
        return self.parents[v]
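To see the flattening, the sketch below builds the chain 4 → 3 → 2 → 1 → 0 by hand, calls `find` once, and prints the parents list. It uses an iterative two-pass `find` (our variant, not the cell above) that performs the same compression without recursion, which avoids Python's default recursion limit (about 1000 frames) on long chains.

```python
def find(parents, v):
    # pass 1: walk up to the root
    root = v
    while root != parents[root]:
        root = parents[root]
    # pass 2: point every node on the path directly at the root
    while v != root:
        next_v = parents[v]
        parents[v] = root
        v = next_v
    return root


parents = [0, 0, 1, 2, 3]  # chain 4 -> 3 -> 2 -> 1 -> 0
print(find(parents, 4))  # 0
print(parents)           # [0, 0, 0, 0, 0]
```

The same function handles a chain of 100,000 nodes, where the recursive version would raise `RecursionError` unless the recursion limit is raised.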

### `union()` | _Rank_ by **Size**

We combine sets based on their sizes: the root of the smaller set is attached under the root of the larger set, which keeps the trees shallow.


In [None]:
class DSU:
    def __init__(self):
        self.parents = []
        self.sizes = []  # union by size

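    def make_set(self, v):
        # grow both lists so index v exists; assumes v is a non-negative int
        if v >= len(self.parents):
            self.parents.extend(range(len(self.parents), v + 1))
            self.sizes.extend([1] * (v + 1 - len(self.sizes)))
        self.parents[v] = v
        self.sizes[v] = 1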

    def find(self, v):
        if v != self.parents[v]:
            self.parents[v] = self.find(self.parents[v])  # path compression
        return self.parents[v]

    def union(self, v1, v2):
        p1, p2 = map(self.find, (v1, v2))
        if p1 != p2:
            if self.sizes[p1] < self.sizes[p2]:
                p1, p2 = p2, p1  # 💡 swap, so p1 is always larger
            self.parents[p2] = p1  # Merge smaller (p2) into larger (p1)
            self.sizes[p1] += self.sizes[p2]  # Increase p1's size due to merge.
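A self-contained, list-based version of the same idea, with a short run showing the size stored at the root. Preallocating `n` singleton sets instead of calling `make_set`, and returning a boolean from `union`, are our simplifications; `SizeDSU` is a hypothetical name.

```python
class SizeDSU:
    def __init__(self, n):
        self.parents = list(range(n))  # every element starts as its own root
        self.sizes = [1] * n           # only meaningful at roots

    def find(self, v):
        if v != self.parents[v]:
            self.parents[v] = self.find(self.parents[v])  # path compression
        return self.parents[v]

    def union(self, v1, v2):
        p1, p2 = self.find(v1), self.find(v2)
        if p1 == p2:
            return False  # already in the same set
        if self.sizes[p1] < self.sizes[p2]:
            p1, p2 = p2, p1  # make p1 the larger root
        self.parents[p2] = p1
        self.sizes[p1] += self.sizes[p2]
        return True


dsu = SizeDSU(6)
dsu.union(0, 1)
dsu.union(2, 3)
dsu.union(3, 4)  # {2, 3, 4}
dsu.union(0, 4)  # {0, 1} (size 2) goes under the root of {2, 3, 4} (size 3)
root = dsu.find(0)
print(root, dsu.sizes[root])  # 2 5
```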

### `union()` | _Rank_ by **Depth**

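We can also merge based on tree depth: the root of the shallower tree is attached under the root of the deeper one. Once path compression is added, the stored depth is only an upper bound on the real depth, which is why it is usually called the _rank_. Depending on the problem, either strategy works; both give the same time and space complexity.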


In [None]:
class DSU:
    def __init__(self):
        self.parents = []
        self.depth = []

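    def make_set(self, v):
        # grow both lists so index v exists; assumes v is a non-negative int
        if v >= len(self.parents):
            self.parents.extend(range(len(self.parents), v + 1))
            self.depth.extend([0] * (v + 1 - len(self.depth)))
        self.parents[v] = v
        self.depth[v] = 0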

    def find(self, v):
        if v != self.parents[v]:
            self.parents[v] = self.find(self.parents[v])  # path compression
        return self.parents[v]

    def union(self, v1, v2):
        p1, p2 = map(self.find, (v1, v2))
        if p1 != p2:
            if self.depth[p1] < self.depth[p2]:
                p1, p2 = p2, p1
            self.parents[p2] = p1
            if self.depth[p1] == self.depth[p2]:
                # equal depths: hanging p2 under p1 makes p1's tree one level deeper
                self.depth[p1] += 1

## Time Complexity

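- If we combine both optimizations, _path compression_ with union by _size / rank_, each query takes amortized `O(α(n))` time (inverse Ackermann), which is nearly constant in practice.
- It's also worth mentioning that a DSU with union by size / rank but **without** _path compression_ runs in `O(log(n))` time per query, and a DSU with path compression alone (no union by size / rank) runs in amortized `O(log(n))` time per query.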
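A rough experiment (ours, not from the original) showing why the merge strategy matters: it runs the same unions, `union(i, i + 1)` for every `i`, once with a naive "always attach the first root under the second" rule and once with union by size. Path compression is turned off in both so the tree shapes are preserved, and then it measures the deepest node.

```python
def max_depth(parents):
    """Length of the longest path from any node up to its root."""
    deepest = 0
    for v in range(len(parents)):
        u, d = v, 0
        while u != parents[u]:
            u = parents[u]
            d += 1
        deepest = max(deepest, d)
    return deepest


def chain_unions(n, by_size):
    """Run union(i, i + 1) for i = 0..n-2 with NO path compression."""
    parents = list(range(n))
    sizes = [1] * n

    def find(v):
        while v != parents[v]:
            v = parents[v]
        return v

    for i in range(n - 1):
        p1, p2 = find(i), find(i + 1)
        if p1 == p2:
            continue
        if by_size:
            if sizes[p1] < sizes[p2]:
                p1, p2 = p2, p1
            parents[p2] = p1  # smaller root goes under the larger root
            sizes[p1] += sizes[p2]
        else:
            parents[p1] = p2  # naive: first root always goes under the second
    return parents


n = 1024
print(max_depth(chain_unions(n, by_size=False)))  # 1023
print(max_depth(chain_unions(n, by_size=True)))   # 1
```

The naive rule builds a single chain, so a `find` from the bottom costs `O(n)`. Union by size guarantees a depth of at most `log2(n)` (10 for `n = 1024`) for any sequence of unions; this particular input happens to produce a star.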


# Problem Solving

---

## **Connected Components in an Undirected Graph**

### Difficulty

`Medium`

### Description

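Given a boolean 2D matrix, find the number of islands.
A group of connected 1s forms an island. For example, the matrix below contains 5 islands when diagonal neighbors also count as connected (with only up/down/left/right adjacency, as in the DSU solution below, it contains 6):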

```python
{1, 1, 0, 0, 0},
{0, 1, 0, 0, 1},
{1, 0, 0, 1, 1},
{0, 0, 0, 0, 0},
{1, 0, 1, 0, 1}
```

### Background

1. > Are the vertices a & b in the same connected component of the graph?
   - This is a rather useful question to answer: _Kruskal's_ minimum spanning tree algorithm asks it for every edge, and a DSU reduces its time complexity from `O(m * log n + n^2)` to `O(m * log n)`. We'll look at _Kruskal's_ algorithm below.
2. > Count the connected components in a graph
   - A DSU does this in `O(Vertices + Edges * α(Vertices))`, essentially matching DFS/BFS's `O(Vertices + Edges)`, and it also works when edges arrive one at a time. We'll look at this problem below.
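Since Kruskal's algorithm is only named here, a minimal sketch of how it uses a DSU may help: sort the edges by weight and keep an edge only when its endpoints are in different sets, i.e. when the union actually merges something. The `(weight, u, v)` edge format and the function name are our assumptions, and `find` uses path halving, a one-pass variant of path compression.

```python
def kruskal(n, edges):
    """Total weight of a minimum spanning forest.

    n: vertices labelled 0..n-1; edges: list of (weight, u, v) tuples.
    """
    parents = list(range(n))
    sizes = [1] * n

    def find(v):
        while v != parents[v]:
            parents[v] = parents[parents[v]]  # path halving
            v = parents[v]
        return v

    total = 0
    for w, u, v in sorted(edges):  # cheapest edges first
        ru, rv = find(u), find(v)
        if ru == rv:
            continue  # u and v already connected: this edge would form a cycle
        if sizes[ru] < sizes[rv]:
            ru, rv = rv, ru
        parents[rv] = ru  # union by size
        sizes[ru] += sizes[rv]
        total += w
    return total


edges = [(1, 0, 1), (2, 1, 2), (3, 0, 2), (4, 2, 3), (5, 1, 3)]
print(kruskal(4, edges))  # 7: keeps weights 1, 2, 4
```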

### Solution Approach

1. **BFS & DFS**: It's simple to implement DFS/BFS and count the connected components. Since our graph is given as a matrix rather than as a list of edges plus a node count, every approach has the same time complexity: `Theta(rows * cols)`.
2. **DSU**: For each land cell we check its 4 adjacent cells (up, down, left, right) and perform a union with every neighboring land cell that does not yet share the current cell's root. Once we finish, we count the distinct roots (`find(cell)`) among the land cells, which tells us how many islands exist. Counting raw values in the `parents` list is not enough, since a cell may still point at an intermediate node rather than at its root.


In [31]:
class DSU:
    def __init__(self, n):
        self.parents = list(range(n))  # every cell starts as its own root
        self.sizes = [1] * n

    def find(self, v):
        if v != self.parents[v]:
            self.parents[v] = self.find(self.parents[v])  # path compression
        return self.parents[v]

    def union(self, v1, v2):
        p1, p2 = map(self.find, (v1, v2))
        if p1 == p2:
            return
        if self.sizes[p1] < self.sizes[p2]:
            p1, p2 = p2, p1  # attach the smaller tree under the larger one
        self.parents[p2] = p1
        self.sizes[p1] += self.sizes[p2]


def count_components(matrix):
    rows, cols = len(matrix), len(matrix[0])
    dsu = DSU(rows * cols)  # cell (i, j) is stored at index i * cols + j
    for i in range(rows):
        for j in range(cols):
            if not matrix[i][j]:
                continue
            cell = i * cols + j
            # Only look left and up: right/down neighbors look back at us later.
            if j > 0 and matrix[i][j - 1]:
                dsu.union(cell, cell - 1)
            if i > 0 and matrix[i - 1][j]:
                dsu.union(cell, cell - cols)
    # One island per distinct root among the land cells.
    roots = {
        dsu.find(i * cols + j)
        for i in range(rows)
        for j in range(cols)
        if matrix[i][j]
    }
    return len(roots)


count_components(
    [
        [1, 1, 1, 1, 0],
        [0, 1, 0, 0, 1],
        [1, 0, 0, 1, 1],
        [0, 0, 0, 0, 0],
        [1, 1, 1, 0, 1],
    ]
)

5
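The matrix in the problem description has 5 islands only when diagonal neighbors count as connected. The sketch below (our variant, with a hypothetical name) applies the same DSU idea to all 8 neighbors and counts islands by subtracting one for every union that actually merges two sets, instead of counting roots at the end.

```python
def count_islands_8(matrix):
    rows, cols = len(matrix), len(matrix[0])
    parents = list(range(rows * cols))  # cell (i, j) -> index i * cols + j

    def find(v):
        while v != parents[v]:
            parents[v] = parents[parents[v]]  # path halving
            v = parents[v]
        return v

    islands = 0
    for i in range(rows):
        for j in range(cols):
            if not matrix[i][j]:
                continue
            islands += 1  # a new land cell starts as its own island
            # neighbors already visited: left, upper-left, up, upper-right
            for di, dj in ((0, -1), (-1, -1), (-1, 0), (-1, 1)):
                ni, nj = i + di, j + dj
                if 0 <= ni < rows and 0 <= nj < cols and matrix[ni][nj]:
                    a, b = find(i * cols + j), find(ni * cols + nj)
                    if a != b:
                        parents[a] = b
                        islands -= 1  # two islands merged into one
    return islands


grid = [
    [1, 1, 0, 0, 0],
    [0, 1, 0, 0, 1],
    [1, 0, 0, 1, 1],
    [0, 0, 0, 0, 0],
    [1, 0, 1, 0, 1],
]
print(count_islands_8(grid))  # 5
```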

---

#### **Size of Islands | Maximum/Min Size component**

1. The non-DSU solution is well known: we run BFS/DFS from each unvisited land cell and compare the sizes of the components it finds. With a DSU, we call `union` whenever we reach a cell that belongs with a neighbor. By the end of the traversal, the `sizes` list, read at each root, holds the size of every connected component, so we return the `max/min` over the roots. (A `depths`/rank list only bounds tree height, so it cannot give component sizes.)
2. The island-counting solution can be extended the same way: keep the size (weight) stored at each root and track a running max/min as islands are merged.
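A minimal sketch of the size-tracking idea on a grid, assuming 4-neighbor adjacency as in the counting solution (function and helper names are ours): after all unions, the size stored at each root is that island's size.

```python
def island_sizes(matrix):
    """Sorted list with the size of every island (4-neighbor adjacency)."""
    rows, cols = len(matrix), len(matrix[0])
    parents = list(range(rows * cols))
    sizes = [1] * (rows * cols)

    def find(v):
        while v != parents[v]:
            parents[v] = parents[parents[v]]  # path halving
            v = parents[v]
        return v

    def union(a, b):
        a, b = find(a), find(b)
        if a == b:
            return
        if sizes[a] < sizes[b]:
            a, b = b, a
        parents[b] = a  # union by size
        sizes[a] += sizes[b]

    for i in range(rows):
        for j in range(cols):
            if matrix[i][j]:
                if j > 0 and matrix[i][j - 1]:
                    union(i * cols + j, i * cols + j - 1)
                if i > 0 and matrix[i - 1][j]:
                    union(i * cols + j, (i - 1) * cols + j)

    roots = {find(i * cols + j) for i in range(rows) for j in range(cols) if matrix[i][j]}
    return sorted(sizes[r] for r in roots)


grid = [
    [1, 1, 1, 1, 0],
    [0, 1, 0, 0, 1],
    [1, 0, 0, 1, 1],
    [0, 0, 0, 0, 0],
    [1, 1, 1, 0, 1],
]
sizes_found = island_sizes(grid)
print(sizes_found)                             # [1, 1, 3, 3, 5]
print(max(sizes_found), min(sizes_found))      # 5 1
```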

#### **Compress Jumps along a Segment/Path**

One common application of the DSU is the following: There is a set of vertices, and each vertex has an outgoing edge to another vertex. With DSU you can find the end point, to which we get after following all edges from a given starting point, in almost constant time.

A good example of this application is the problem of painting subarrays. We have a segment of length `L`, and each element initially has the color 0. For each query `(l, r, c)` we have to repaint the subarray `[l, r]` with the color `c`. At the end we want to find the final color of each cell. We assume that we know all the queries in advance, i.e. the task is offline.

For the solution we can make a DSU which, for **each cell, stores a link to the next unpainted cell**. Initially each cell points to itself. After a segment is painted, every cell in that segment points to **the cell right after the segment**.

To solve the problem, we process the queries in reverse order, from last to first. This way, when we execute a query, we only have to paint the cells in the subarray `[l, r]` that are still unpainted; all other cells already contain their final color. To iterate over the unpainted cells quickly, we use the DSU: we find the left-most unpainted cell inside the segment, repaint it, and follow the pointer to the next unpainted cell to the right.

Here we can use the DSU with path compression, but we cannot use union by rank / size, because it matters which cell becomes the leader after the merge (it must be the right-most one). Therefore the complexity is amortized `O(log(n))` per union, which is still quite fast.
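A sketch of the offline painting technique, assuming 0-indexed cells and inclusive queries `(l, r, c)`; `paint_offline` and the query format are our choices. `nxt[i]` plays the role of the DSU parent and leads to the first unpainted cell at or after `i`, with index `length` as a sentinel. There is deliberately no union by size, since the root must always be the right-most cell.

```python
def paint_offline(length, queries):
    """queries: list of (l, r, c) applied in order; later queries win."""
    color = [0] * length
    nxt = list(range(length + 1))  # nxt[length] is a sentinel "past the end"

    def find(v):
        # first unpainted cell at or after v, with path compression
        root = v
        while root != nxt[root]:
            root = nxt[root]
        while v != root:
            next_v = nxt[v]
            nxt[v] = root
            v = next_v
        return root

    for l, r, c in reversed(queries):  # last query first: it has the final say
        v = find(l)
        while v <= r:
            color[v] = c
            nxt[v] = v + 1  # v is painted; link it to the cell on its right
            v = find(v + 1)
    return color


print(paint_offline(6, [(0, 3, 1), (2, 5, 2), (1, 2, 3)]))  # [1, 3, 3, 2, 2, 2]
```

Each cell is painted at most once, so the work is dominated by the `find` calls.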

Below is a problem demonstrating the technique...


---

## **LC: 1353: Maximum number of events that can be attended**

### Difficulty

`Medium`

### Description

You are given an array of events where `events[i] = [startDay_i, endDay_i]`. Every event `i` starts at `startDay_i` and ends at `endDay_i`.
You can attend an event `i` on any day `d` where `startDay_i <= d <= endDay_i`. You can only attend one event on any given day `d`.
Return the maximum number of events you can attend.

<img src="https://imgur.com/Rqzd86a.png" style="max-width:500px">

```
Input: events = [[1,2],[2,3],[3,4]]
Output: 3
Explanation: You can attend all three events.
One way to attend them all is as shown.
Attend the first event on day 1.
Attend the second event on day 2.
Attend the third event on day 3.

Input: events = [[1,2],[2,3],[3,4],[1,2]]
Output: 4
Explanation: You can attend all four events.
One way to attend them all:
Attend the first event on day 1.
Attend the fourth event on day 2.
Attend the second event on day 3.
Attend the third event on day 4.
```

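**NOTE**: The number of days to track is bounded by the latest `endDay`, not by `len(events)`, which is why the `parents` list in the solutions is sized by the largest end day.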

### Solution Approach

Sort the events by their end day. Intuitively, we want to attend as many events as possible, so we should greedily prioritize the events that end earliest rather than those that end later. Sorting on the end day gives us this order.

Then we union the days we use for the events we choose to attend. Each taken day is pointed at the "next day", so that `find(start)` answers the question: "What is the next available day on or after this event's start?" If that day is less than or equal to the event's end day (`find(start) <= end`), attending the event is in our interest. Once we take the event, we move that day's pointer one position to the right.

The result will be the maximum number of events we can attend given the scheduled events.


In [7]:
def max_events(events):
    total = 0
    events.sort(key=lambda n: n[1])  # sort by ending day of event
    # Add an extra index, since we always point to the "next" day
    parents = list(range(0, events[-1][1] + 2))

    def find(v):
        if v != parents[v]:
            parents[v] = find(parents[v])
        return parents[v]

    for start, end in events:
        _end = find(start)  # earliest available day on or after `start`
        # If the next available day is <= the scheduled end day, take it!
        if _end <= end:
            total += 1
            parents[_end] += 1  # Our next available day is the "next day"
    return total


max_events([[1, 2], [2, 3], [3, 4], [1, 2]])

4

Below is the solution using the common DSU data-structure template. Writing it without the explicit data structure, as above, saves about 10 lines of code.


In [12]:
class DSU:
    def __init__(self, last_day):
        # one extra slot so the last day can point at a "next day" sentinel
        self.parents = list(range(last_day + 2))

    def find(self, v):
        if v == self.parents[v]:
            return self.parents[v]
        self.parents[v] = self.find(self.parents[v])
        return self.parents[v]

    def union(self, v1, v2):
        # no union by size/rank: the later day must always become the root
        p1, p2 = map(self.find, (v1, v2))
        self.parents[p1] = p2


def maxEvents(events):
    events.sort(key=lambda n: n[1])  # greedy: earliest-ending events first
    dsu = DSU(events[-1][1])
    total_events = 0
    for start, end in events:
        _end = dsu.find(start)  # earliest available day on or after `start`
        if _end <= end:
            total_events += 1
            dsu.union(_end, _end + 1)  # mark day _end as used
    return total_events


maxEvents([[1, 3], [1, 4], [3, 4], [1, 2]])

4

NOTE: It's worth comparing this approach to a typical Union-Find algorithm.

A typical Union-Find abstracts the union of sets into its own function. In the first solution, however, the union is the single line `parents[_end] += 1`. Pointing the taken day at the next day effectively merges the sets of day `p` and day `p + 1`. It can be this simple because we only build chains of consecutive taken days. We don't need to worry about 1) how many sets exist, or 2) how large or deep a set becomes. Those are the concerns a typical _DSU.union_ handles with union by size/rank, and here we must not rebalance anyway, because the root has to stay the right-most day. The root of each chain is exactly what we care about: _the next available day on which we can schedule an event._

The _find()_ function in the first solution, however, maps exactly onto a typical _DSU.find_ method. It also uses the _compression_ step `parents[v] = find(parents[v])`, which keeps nodes from sitting in unnecessarily long ancestor chains: after compression, every node on the path points directly at its root. Because this solution cannot use union by size/rank, compression is what keeps it fast. Without it, a long run of taken days could be walked one step at a time on every lookup (`O(n)` in the worst case); with it, the amortized cost is `O(log(n))` per operation.


---

### Detect A Cycle

1. We store everything in a single `parents` list:
   1. A negative value at index `i` means `i` is the root of its set, and its absolute value is the set's size (every vertex starts at `-1`).
   2. A non-negative value is the index of the parent of the `i`'th vertex,
      i.e. that `i`'th vertex is NOT a root.
2. For each edge we discover, we check that its two vertices belong to
   two different sets. If they belong to the same set, the edge closes a cycle.
3. The algorithm tracks the "weight" of each set: the number of vertices in that set.
   Do not confuse it with edge weights.


In [32]:
# Solution without an explicit DSU data structure.
def detect_cycle(n, edges):
    """Return True only if `edges` form a single tree over n vertices.

    Returns False as soon as an edge closes a cycle, or at the end if the
    graph is disconnected.
    """
    parents = [-1] * n  # negative value = root, storing -(set size)
    set_count = n
    for u, v in edges:
        u_parent, u_weight = find(parents, u)
        v_parent, v_weight = find(parents, v)
        if u_parent == v_parent:
            return False  # cycle detected
        union(parents, u_weight, u_parent, v_weight, v_parent)
        set_count -= 1
    return set_count == 1


def find(parents, n):
    """Walk up to the root of n; return (root, -(size of root's set))."""
    weight = parents[n]
    while weight >= 0:
        n = parents[n]
        weight = parents[n]
    return n, weight


def union(parents, u_weight, u_parent, v_weight, v_parent):
    # weights are negative sizes, so "<=" means u's set is at least as large
    if u_weight <= v_weight:
        parents[u_parent] += v_weight
        parents[v_parent] = u_parent
    else:
        parents[v_parent] += u_weight
        parents[u_parent] = v_parent


args = {
    "n": 8,
    "edges": [
        [0, 1],
        [0, 2],
        [0, 5],
        [1, 4],
        [2, 3],
        [3, 0],
        [3, 7],
        [4, 5],
        [4, 6],
        [4, 7],
        [5, 1],
        [5, 6],
        [7, 6],
    ],
}
print("Result: ", detect_cycle(**args))

Result:  False


A cleaner implementation, using path compression and union by size, is below. This is the version worth committing to memory and adapting to specific problems.


In [14]:
class DSU:
    def __init__(self, n):
        self.parents = [i for i in range(n)]
        self.sizes = [1] * n  # set size, meaningful at roots only

    def find(self, v):
        if v == self.parents[v]:
            return self.parents[v]
        self.parents[v] = self.find(self.parents[v])  # path compression
        return self.parents[v]

    def union(self, v1, v2):
        """Merge the sets of v1 and v2; return True if they were already joined."""
        p1, p2 = map(self.find, (v1, v2))
        if p1 == p2:
            return True  # cycle detected
        if self.sizes[p2] > self.sizes[p1]:
            p1, p2 = p2, p1  # union by size: p1 is the larger root
        self.parents[p2] = p1
        self.sizes[p1] += self.sizes[p2]
        return False


def count_redundant_edges(n, edges):
    """Count the edges that close a cycle; vertices are labelled 1..n."""
    dsu = DSU(n)
    remove_count = 0
    for u, v in edges:
        if dsu.union(u - 1, v - 1):  # shift to 0-based indices
            remove_count += 1
    return remove_count


count_redundant_edges(4, [[1, 2], [2, 3], [3, 4], [4, 1], [2, 3]])

2