Before you turn this problem in, make sure everything runs as expected. First, **restart the kernel** (in the menubar, select Kernel$\rightarrow$Restart) and then **run all cells** (in the menubar, select Cell$\rightarrow$Run All).

Make sure you fill in any place that says `YOUR CODE HERE` or "YOUR ANSWER HERE", as well as your name and collaborators below:

In [9]:
NAME = "Lars Janssen"

---

For those not familiar with Python, a quick overview is given [here](https://github.com/palcu/python-for-competitive-programming/blob/master/python-for-competitive-programming.ipynb).

# Notebook BAPC week 10: Data structures

## A general remark on reading input
When reading *large* amounts of data, `input()` is sometimes too slow.
A much faster alternative is to `import sys` and replace every `input()` call with `sys.stdin.readline()`. Note that a line returned by `readline` still ends in the newline character `\n`; this doesn't matter when converting it to integers, but watch out when reading strings. See [this link](https://stackoverflow.com/a/58537094/12354474) for details.
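
As a minimal sketch of this pattern, the snippet below parses a count, a line of integers, and a string through a `readline` function. The helper name `read_ints` and the use of `io.StringIO` as a stand-in for standard input are assumptions made for illustration; in a submission you would pass `sys.stdin.readline` instead.

```python
import io


def read_ints(readline):
    """Read one line and parse it as whitespace-separated integers."""
    # split() discards all whitespace, so the trailing "\n" is harmless here.
    return [int(token) for token in readline().split()]


# Simulated input (an assumption for this demo); use sys.stdin in practice.
fake_stdin = io.StringIO("3\n10 20 30\nhello\n")

(n,) = read_ints(fake_stdin.readline)
values = read_ints(fake_stdin.readline)
# For strings, strip the trailing newline explicitly.
word = fake_stdin.readline().rstrip("\n")
print(n, values, repr(word))
```

Only the string needs the explicit `rstrip`; the integer lines are unaffected by the newline.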

## Union-Find
A Union-Find data structure tracks a set of elements partitioned into a number of disjoint subsets. Implemented correctly, it provides amortized near-constant-time operations to merge existing sets (`union`) and to determine whether elements are in the same set (`same_set`). This data structure plays a key role in *Kruskal's algorithm* for finding the minimum spanning tree of a graph.

In this notebook, we will incrementally build an optimal-complexity implementation of Union-Find.

The data structure consists of a single list, `parent`. This list encodes the "disjoint set forest", in which each tree represents one disjoint set. Suppose we start with 5 disjoint sets:
![](https://i.imgur.com/2M6MWLX.png)
If we set `parent[0] = 1` and `parent[2] = 1`, the disjoint set forest becomes
![](https://i.imgur.com/pz7w7Fl.png)
If we then set `parent[3] = 2`, we get
![](https://i.imgur.com/BA7BgNS.png)

Since `parent[4] == 4` and `parent[1] == 1`, nodes 4 and 1 are the roots of their respective disjoint set trees.
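
The final forest in the pictures can be written down directly as a list. The sketch below, in which the helper name `root_of` is an assumption for illustration, follows parent pointers to recover the root of every node:

```python
# Parent list after setting parent[0] = 1, parent[2] = 1 and parent[3] = 2.
parent = [1, 1, 1, 2, 4]


def root_of(i):
    """Follow parent pointers from `i` until reaching a self-parented node."""
    while parent[i] != i:
        i = parent[i]
    return i


roots = [root_of(i) for i in range(len(parent))]
print(roots)  # [1, 1, 1, 1, 4]
```

Nodes 0 to 3 share the root 1, so they form one set; node 4 is a set on its own.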

### Exercise 1: a naive implementation


The data structure has two core operations:
* `find(i)` follows the chain of parent pointers from `i` up the tree until it reaches a root element, whose parent is itself. This root element is the representative member of the set to which `i` belongs, and may be `i` itself (as was the case with `i=1` and `i=4` in the previous picture).
* `union(i, j)` unites the `i`- and `j`-trees by hanging the `i`-tree as a subtree of the `j`-tree.

Consider the following naive implementation of the Union-Find data structure, and convince yourself that it is correct.

In [10]:
class NaiveUnionFind:
    def __init__(self, N):
        """ A naive Union-Find data structure. """
        self.parent = list(range(N))

    def find(self, i):
        """ Finds the root of the disjoint set that `i` is in. """
        if self.parent[i] == i:
            return i
        return self.find(self.parent[i])

    def union(self, i, j):
        """ Unites the disjoint sets that `i` and `j` are in. """
        i_root, j_root = self.find(i), self.find(j)
        self.parent[i_root] = j_root
        
    def same_set(self, i, j):
        """ Returns whether `i` and `j` are in the same disjoint set. """
        return self.find(i) == self.find(j)

In [11]:
# The test case from the lecture.
UF = NaiveUnionFind(5)
assert not UF.same_set(3, 4)
UF.union(3, 4)
assert UF.same_set(3, 4)
UF.union(1, 4)
UF.union(0, 2)

This naive implementation has multiple defects, but perhaps the most glaring one is that in the worst case, the disjoint-set trees can be highly skewed and become almost like a linked list:
![](https://i.imgur.com/ErECFMg.png)

Use the cell below to generate this worst-case scenario for a `NaiveUnionFind` given `N` nodes.

In [12]:
def worst_case(N, UnionFind=NaiveUnionFind):
    """ Unite the elements 0..N-1 in the `worst-case` way.
    
    The resulting disjoint-set tree is a chain of height N-1. """
    UF = UnionFind(N)
    for i in range(N-1):
        UF.union(i, i+1)
    return UF

In [13]:
import sys
sys.setrecursionlimit(30000)

UF_10 = worst_case(10)
# Assert that all elements are in the same set.
assert all(UF_10.same_set(0, i) for i in range(10))

print("You will see that the following call is linear in N:")
for N in [100, 1000, 10000]:
    UF = worst_case(N)
    timing = %timeit -o -r 1 -q UF.same_set(0, N-1)
    print("for N=%d, `same_set(0, N-1)` took %f ms" % (N, 1000 * timing.best))

You will see that the following call is linear in N:
for N=100, `same_set(0, N-1)` took 0.014523 ms
for N=1000, `same_set(0, N-1)` took 0.172048 ms
for N=10000, `same_set(0, N-1)` took 4.136246 ms
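
The `sys.setrecursionlimit` call above is needed because the recursive `find` makes one nested call per node on the path to the root. A minimal sketch of a loop-based alternative follows; the class name `IterativeUnionFind` is an assumption and not part of the exercises. It handles a chain of 50000 nodes without touching the recursion limit:

```python
class IterativeUnionFind:
    """Same behaviour as the naive version, but `find` uses a loop."""

    def __init__(self, N):
        self.parent = list(range(N))

    def find(self, i):
        # Walk up the parent pointers without recursion.
        while self.parent[i] != i:
            i = self.parent[i]
        return i

    def union(self, i, j):
        # Hang the root of the `i`-tree under the root of the `j`-tree.
        self.parent[self.find(i)] = self.find(j)

    def same_set(self, i, j):
        return self.find(i) == self.find(j)


N = 50000
UF = IterativeUnionFind(N)
for i in range(N - 1):
    UF.union(i, i + 1)

print(UF.same_set(0, N - 1))
```

The running time of `find` is still linear in the length of the chain; only the recursion depth problem goes away.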


### Exercise 2: Union by size
The obvious problem is that the disjoint-set trees can have enormous height, which makes the call to `find` an $\mathcal O(N)$ operation. Instead of always hanging the `i`-tree under `j`, we can hang the *smaller* tree (the one with fewer nodes) under the *larger* one. A node's depth then only increases when its tree is merged into one at least as large, which at least doubles the size of its set, so the resulting trees have height at most $\mathcal O(\log N)$. To achieve this, we will keep a `sizes` array which stores the size of each disjoint set at its root. Finish the code below.

In [22]:
class SmarterUnionFind:
    def __init__(self, N):
        """ A faster Union-Find data structure. """
        self.parent = list(range(N))
        self.sizes = [1] * N

    def find(self, i):
        """ Finds the root of the disjoint set that `i` is in. """
        if self.parent[i] == i:
            return i
        return self.find(self.parent[i])

    def union(self, i, j):
        """ Unites the disjoint sets that `i` and `j` are in. """
        i_root, j_root = self.find(i), self.find(j)
        if i_root == j_root: return
        if self.sizes[i_root] <= self.sizes[j_root]:
            i_root, j_root = j_root, i_root
        # j_root is now the smaller tree, so hang it under i_root.
        self.parent[j_root] = i_root
        
        # Update the size stored at the new root of the merged set.
        self.sizes[i_root] += self.sizes[j_root]
        
    def same_set(self, i, j):
        """ Returns whether `i` and `j` are in the same disjoint set. """
        return self.find(i) == self.find(j)
        
    def size(self, i):
        """ Returns the size of the tree that `i` is in. """
        root = self.find(i)
        return self.sizes[root]

In [23]:
# The test case from the lecture.
UF = SmarterUnionFind(5)
assert not UF.same_set(3, 4)
UF.union(3, 4)
assert UF.same_set(3, 4)
assert UF.size(3) == UF.size(4) == 2
UF.union(1, 4)
assert UF.size(1) == UF.size(3) == 3
UF.union(0, 2)

# The first test case from Tildes.
UF = SmarterUnionFind(10)
UF.union(0, 9)
UF.union(0, 1)
UF.union(0, 2)
assert UF.size(0) == UF.size(1) == UF.size(2) == 4
assert UF.size(3) == 1
UF.union(4, 5)
assert UF.size(4) == UF.size(5) == 2
assert UF.size(9) == 4

# The second test case from Tildes.
UF = SmarterUnionFind(5)
assert UF.size(0) == 1
UF.union(0, 1)
assert UF.size(0) == 2
UF.union(0, 1)
assert UF.size(0) == 2
UF.union(3, 4)
assert UF.size(4) == 2
UF.union(0, 4)
assert UF.size(1) == 4
UF.union(2, 1)
assert UF.size(2) == 5

The worst-case scenario is somewhat more difficult to build in this case, but rest assured that the previous call no longer takes linear time:

In [24]:
print("You will see that the following call is no longer linear in N:")
for N in [100, 1000, 10000]:
    UF = worst_case(N, UnionFind=SmarterUnionFind)
    timing = %timeit -o -r 1 -q UF.same_set(0, N-1)
    print("for N=%d, `same_set(0, N-1)` took %f ms" % (N, 1000 * timing.best))

You will see that the following call is no longer linear in N:
for N=100, `same_set(0, N-1)` took 0.000553 ms
for N=1000, `same_set(0, N-1)` took 0.000556 ms
for N=10000, `same_set(0, N-1)` took 0.000562 ms
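
One way to build a deep tree under union by size is to always unite two sets of equal size, so that every merge adds one level. The sketch below is an illustration under that assumption: the names `SizeUnionFind`, `depth`, and `balanced_worst_case` are hypothetical, and the class mirrors the union rule of Exercise 2 with a loop-based `find`.

```python
class SizeUnionFind:
    """Union by size without path compression (mirrors Exercise 2)."""

    def __init__(self, N):
        self.parent = list(range(N))
        self.sizes = [1] * N

    def find(self, i):
        while self.parent[i] != i:
            i = self.parent[i]
        return i

    def union(self, i, j):
        i_root, j_root = self.find(i), self.find(j)
        if i_root == j_root:
            return
        if self.sizes[i_root] <= self.sizes[j_root]:
            i_root, j_root = j_root, i_root
        # j_root is now the smaller (or equal) tree; hang it under i_root.
        self.parent[j_root] = i_root
        self.sizes[i_root] += self.sizes[j_root]

    def depth(self, i):
        """Number of parent pointers between `i` and its root."""
        steps = 0
        while self.parent[i] != i:
            i = self.parent[i]
            steps += 1
        return steps


def balanced_worst_case(k):
    """Merge 2**k singletons pairwise, always uniting equal-sized sets."""
    N = 2 ** k
    UF = SizeUnionFind(N)
    step = 1
    while step < N:
        for start in range(0, N, 2 * step):
            UF.union(start, start + step)
        step *= 2
    return UF


for k in [4, 8, 12]:
    UF = balanced_worst_case(k)
    max_depth = max(UF.depth(i) for i in range(2 ** k))
    print("N=%d: maximum depth %d" % (2 ** k, max_depth))
```

For $N = 2^k$ elements the deepest node ends up exactly $k = \log_2 N$ steps from the root, which matches the logarithmic height bound for union by size.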


### Exercise 3: Path compression
There is one more piece of low-hanging fruit to speed up `find`, which we call "path compression". Path compression flattens the structure of each tree by making every node on the path from `i` to the root point directly to that root whenever `find(i)` is called. In the figure below, we call `find(7)`, which flattens the tree for nodes 7, 5, 3, and 2.
![](https://raw.githubusercontent.com/e-maxx-eng/e-maxx-eng/master/img/DSU_path_compression.png)

This is valid, since each element visited on the way from `i` to its root is part of the same disjoint set. The resulting flatter tree speeds up future `find` operations. Copy your code from the `union` and `size` functions from Exercise 2 into the code below, and finish the `find` function.

In [26]:
class SmartestUnionFind:
    def __init__(self, N):
        """ A good and fast Union-Find data structure. """
        self.parent = list(range(N))
        self.sizes = [1] * N

    def find(self, i):
        """ Finds the root of the disjoint set that `i` is in. """
        if self.parent[i] == i:
            return i
        parent = self.find(self.parent[i])
        self.parent[i] = parent
        return parent
        
    def union(self, i, j):
        """ Unites the disjoint sets that `i` and `j` are in. """
        i_root, j_root = self.find(i), self.find(j)
        if i_root == j_root: return
        if self.sizes[i_root] <= self.sizes[j_root]:
            i_root, j_root = j_root, i_root
        # j_root is now the smaller tree, so hang it under i_root.
        self.parent[j_root] = i_root
        
        # Update the size stored at the new root of the merged set.
        self.sizes[i_root] += self.sizes[j_root]
        
    def same_set(self, i, j):
        """ Returns whether `i` and `j` are in the same disjoint set. """
        return self.find(i) == self.find(j)
        
    def size(self, i):
        """ Returns the size of the tree that `i` is in. """
        root = self.find(i)
        return self.sizes[root]

In [27]:
# The test case from the lecture.
UF = SmartestUnionFind(5)
assert not UF.same_set(3, 4)
UF.union(3, 4)
assert UF.same_set(3, 4)
assert UF.size(3) == UF.size(4) == 2
UF.union(1, 4)
assert UF.size(1) == UF.size(3) == 3
UF.union(0, 2)

# The first test case from Tildes.
UF = SmartestUnionFind(10)
UF.union(0, 9)
UF.union(0, 1)
UF.union(0, 2)
assert UF.size(0) == UF.size(1) == UF.size(2) == 4
assert UF.size(3) == 1
UF.union(4, 5)
assert UF.size(4) == UF.size(5) == 2
assert UF.size(9) == 4

# The second test case from Tildes.
UF = SmartestUnionFind(5)
assert UF.size(0) == 1
UF.union(0, 1)
assert UF.size(0) == 2
UF.union(0, 1)
assert UF.size(0) == 2
UF.union(3, 4)
assert UF.size(4) == 2
UF.union(0, 4)
assert UF.size(1) == 4
UF.union(2, 1)
assert UF.size(2) == 5
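
To see the effect of path compression on the `parent` list itself, here is a small sketch that works on a bare list. The function name `find_compress` and the two-pass, loop-based formulation are assumptions for illustration; it is an alternative to the recursive `find` above, not a replacement for it.

```python
def find_compress(parent, i):
    """Return the root of `i` and point every node on the path at it."""
    # First pass: locate the root.
    root = i
    while parent[root] != root:
        root = parent[root]
    # Second pass: re-point each visited node directly at the root.
    while parent[i] != root:
        next_node = parent[i]
        parent[i] = root
        i = next_node
    return root


# A chain 4 -> 3 -> 2 -> 1 -> 0, with node 5 hanging under node 2.
parent = [0, 0, 1, 2, 3, 2]
root = find_compress(parent, 4)
print(root, parent)  # 0 [0, 0, 0, 0, 0, 2]
```

Only the nodes on the path from 4 to the root are re-pointed; node 5 was not visited, so it still points at node 2 until a later `find` passes through it.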

### Exercise 4: Tildes
Finish the Kattis problem [Tildes](https://open.kattis.com/problems/tildes)! Be sure to read the [remark on reading input](#A-general-remark-on-reading-input) first.