## **Code playground for SDA sem 14**

# Minimum spanning tree

## Prim's algorithm

Let's consider the following weighted undirected graph. Each vertex maps to a list of `(neighbour, weight)` pairs.

In [1]:
weighted_graph = {
    1: [(2, 4), (4, 1)],
    2: [(1, 4), (3, 3)],
    3: [(2, 3), (4, 4)],
    4: [(1, 1), (3, 4), (5, 5)],
    5: [(4, 5)]   
}

![Weighted graph example](media/weighted_graph.png)

Prim's algorithm finds a minimum spanning tree (MST) of a connected weighted graph:

In [2]:
from heapq import heappush, heappop

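def prim(start, V, graph):
    visited = set()
    pq = [(0, start)]  # (edge weight, vertex); reaching the start vertex costs nothing
    mst_weight = 0

    # Stop once all V vertices are in the tree, or when the queue runs dry
    # (which happens only if the graph is not connected).
    while pq and len(visited) != V:
        current_weight, current_vertex = heappop(pq)

        # The lightest edge leads to a vertex already in the tree: skip it.
        if current_vertex in visited:
            continue

        visited.add(current_vertex)
        mst_weight += current_weight

        # Offer every edge that leaves the tree through the new vertex.
        for neighb, weight in graph[current_vertex]:
            if neighb in visited:
                continue

            heappush(pq, (weight, neighb))

    return mst_weight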

start = 5
V = 5
mst_weight = prim(start, V, weighted_graph)
print(mst_weight)

13


![Prim's algorithm creating a MST of a graph, step by step example.](media/prims_algorithm_example.png)

Verbose version:

In [3]:
from heapq import heappush, heappop

def prim_verbose(start, V, graph):
    visited = set()
    pq = [(0, start)]
    mst_weight = 0
    
    while pq and len(visited) != V:
        current_weight, current_vertex = heappop(pq)
        
        if current_vertex in visited:
            print(f"Skipping edge with weight {current_weight} to visited vertex {current_vertex}.", visited)
            continue
        
        print(f"Edge with weight {current_weight} to vertex {current_vertex} added to MST.")
        visited.add(current_vertex)
        mst_weight += current_weight
        
        for neighb, weight in graph[current_vertex]:
            if neighb in visited:
                continue
                            
            heappush(pq, (weight, neighb))
    
    print("Edges that are not used: ", pq)
    return mst_weight

start = 5
V = 5
mst_weight = prim_verbose(start, V, weighted_graph)
print("Total weight of MST =", mst_weight)

Edge with weight 0 to vertex 5 added to MST.
Edge with weight 5 to vertex 4 added to MST.
Edge with weight 1 to vertex 1 added to MST.
Edge with weight 4 to vertex 2 added to MST.
Edge with weight 3 to vertex 3 added to MST.
Edges that are not used:  [(4, 3)]
Total weight of MST = 13


Note that some edges may still be left in the priority queue when the algorithm stops. It can stop without processing them because the spanning tree is complete once it has *V - 1* edges, that is, once all *V* vertices have been visited. Because some popped edges are skipped, the number of iterations of the while loop is not fixed in advance.

Moreover, note that the minimum-weight edge popped on some iteration may not be used: if it leads to an already visited vertex, adding it would create a cycle, so it is skipped.

In [4]:
start = 1
V = 5
mst_weight = prim_verbose(start, V, weighted_graph)
print("Total weight of MST =", mst_weight)

Edge with weight 0 to vertex 1 added to MST.
Edge with weight 1 to vertex 4 added to MST.
Edge with weight 4 to vertex 2 added to MST.
Edge with weight 3 to vertex 3 added to MST.
Skipping edge with weight 4 to visited vertex 3. {1, 2, 3, 4}
Edge with weight 5 to vertex 5 added to MST.
Edges that are not used:  []
Total weight of MST = 13
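
The functions above return only the total weight. A minimal sketch of a variant that also records which edges were chosen is shown below. The name `prim_edges` and the third element of each heap entry (the vertex the edge comes from) are assumptions made for this illustration, and the graph is repeated so the block runs on its own.

```python
from heapq import heappush, heappop

# Same graph as above, repeated so the sketch is self-contained.
weighted_graph = {
    1: [(2, 4), (4, 1)],
    2: [(1, 4), (3, 3)],
    3: [(2, 3), (4, 4)],
    4: [(1, 1), (3, 4), (5, 5)],
    5: [(4, 5)],
}

def prim_edges(start, graph):
    """Hypothetical variant of prim(): returns (total weight, chosen edges)."""
    visited = set()
    pq = [(0, start, start)]  # (weight, vertex, vertex the edge comes from)
    mst_weight = 0
    mst_edges = []

    while pq and len(visited) != len(graph):
        weight, vertex, source = heappop(pq)
        if vertex in visited:
            continue  # this edge would close a cycle, skip it

        visited.add(vertex)
        mst_weight += weight
        if vertex != source:  # the artificial starting entry is not an edge
            mst_edges.append((source, vertex, weight))

        for neighb, w in graph[vertex]:
            if neighb not in visited:
                heappush(pq, (w, neighb, vertex))

    return mst_weight, mst_edges

total, edges = prim_edges(5, weighted_graph)
print(total)
print(edges)
```

Starting from vertex 5, the recorded edges follow the same order as the verbose trace: 5-4, 4-1, 1-2, 2-3.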


## Disjoint-set data structure (Union-Find)

The structure provides:
- efficient searching - which set an element belongs to.
- efficient union - uniting two sets of elements into one.


### Simplest version
The simplest implementation of a disjoint set looks like this:

In [5]:
def find(x, parents):
    if parents[x] == x:
        return x
    
    return find(parents[x], parents)

def union(x, y, parents):
    x_root = find(x, parents)
    y_root = find(y, parents)

    parents[x_root] = y_root

particles = [i for i in range(10)]
parents = [i for i in particles]

print(particles)
print(parents)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


Now let's join the first three particles (0, 1 and 2) together:

In [6]:
union(2, 0, parents)
print(parents)

union(1, 0, parents)
print(parents)

print(find(1, parents))
print(find(2, parents))

[0, 1, 0, 3, 4, 5, 6, 7, 8, 9]
[0, 0, 0, 3, 4, 5, 6, 7, 8, 9]
0
0


Let's connect 3 and 4, and then add them to the first set as well:

In [7]:
union(3, 4, parents)
union(4, 1, parents)

print(find(3, parents))
print(find(4, parents))
print(parents)

0
0
[0, 0, 0, 4, 0, 5, 6, 7, 8, 9]


Note that the immediate parent of 3 is 4, but because `find()` is recursive, calling it on 3 follows the chain up to the root and returns 0. This gives us the required information: elements 3 and 1, for example, are in the same set because `find(1, parents) == find(3, parents)`.

In [8]:
print(find(3, parents) == find(4, parents))
print(parents)

True
[0, 0, 0, 4, 0, 5, 6, 7, 8, 9]
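
The root comparison can be wrapped in a small helper. This is only a sketch: `same_set` is a hypothetical name that does not appear elsewhere in this notebook, and `find()` is repeated so the block runs on its own.

```python
def find(x, parents):
    # An element that is its own parent is the root of its set.
    if parents[x] == x:
        return x
    return find(parents[x], parents)

def same_set(x, y, parents):
    """Two elements are in the same set exactly when they share a root."""
    return find(x, parents) == find(y, parents)

# State of the parents list at this point in the notebook.
parents = [0, 0, 0, 4, 0, 5, 6, 7, 8, 9]

print(same_set(1, 3, parents))  # 1 and 3 both lead to root 0
print(same_set(3, 5, parents))  # 5 is still alone in its own set
```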


Now consider the tree of parents having a linear depth (similar to the worst case for a binary search tree). All the nodes are connected (in one set), but finding this out takes a lot of steps due to the recursion:

In [9]:
def find_verbose(x, parents):
    """Adds print to the find() function"""
    if parents[x] == x:
        return x
    
    print("+ 1 step to find the parents of", x)
    return find_verbose(parents[x], parents)

union(5, 6, parents)
union(6, 7, parents)
union(7, 8, parents)
union(8, 9, parents)
union(9, 0, parents)

print(parents)
print(find_verbose(3, parents) == find_verbose(5, parents))
print(find_verbose(3, parents) == find_verbose(5, parents))

[0, 0, 0, 4, 0, 6, 7, 8, 9, 0]
+ 1 step to find the parents of 3
+ 1 step to find the parents of 4
+ 1 step to find the parents of 5
+ 1 step to find the parents of 6
+ 1 step to find the parents of 7
+ 1 step to find the parents of 8
+ 1 step to find the parents of 9
True
+ 1 step to find the parents of 3
+ 1 step to find the parents of 4
+ 1 step to find the parents of 5
+ 1 step to find the parents of 6
+ 1 step to find the parents of 7
+ 1 step to find the parents of 8
+ 1 step to find the parents of 9
True


### *find()* optimization
Now consider the following improvement in the *find()* function:

In [10]:
def find(x, parents):
    if parents[x] == x:
        return x
    
    furthest_parent = find(parents[x], parents)
    parents[x] = furthest_parent

    return furthest_parent

def union(x, y, parents):
    x_root = find(x, parents)
    y_root = find(y, parents)

    parents[x_root] = y_root

parents = [i for i in particles]
print(parents)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


Let's repeat the same steps as before:

In [11]:
union(2, 0, parents)
print(parents)

union(1, 0, parents)
print(parents)

union(3, 4, parents)
union(4, 1, parents)
print(parents)

[0, 1, 0, 3, 4, 5, 6, 7, 8, 9]
[0, 0, 0, 3, 4, 5, 6, 7, 8, 9]
[0, 0, 0, 4, 0, 5, 6, 7, 8, 9]


And create the long tree:

In [12]:
union(5, 6, parents)
union(6, 7, parents)
union(7, 8, parents)
union(8, 9, parents)
union(9, 0, parents)

def find_verbose(x, parents):
    """Adds print to the find() function"""
    if parents[x] == x:
        return x
    
    print("+ 1 step to find the parent of", x)
    furthest_parent = find_verbose(parents[x], parents)
    parents[x] = furthest_parent

    return furthest_parent

print(parents)
print(find_verbose(3, parents) == find_verbose(5, parents))
print(parents)
print(find_verbose(3, parents) == find_verbose(5, parents))

[0, 0, 0, 4, 0, 6, 7, 8, 9, 0]
+ 1 step to find the parent of 3
+ 1 step to find the parent of 4
+ 1 step to find the parent of 5
+ 1 step to find the parent of 6
+ 1 step to find the parent of 7
+ 1 step to find the parent of 8
+ 1 step to find the parent of 9
True
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ 1 step to find the parent of 3
+ 1 step to find the parent of 5
True


Notice how after the first search the parent values along the search path are updated. The next *find()* takes only 1 step to finish, because every node visited on the way now points directly to the root of its set. The long, almost linear search happens only once; later calls on the same nodes are much cheaper. This technique is known as path compression.
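
The recursive *find()* makes one nested call per node on the path, so a very long chain can exceed Python's recursion limit (typically 1000 in CPython). A minimal sketch of an iterative alternative with the same compression effect follows; `find_iterative` is a hypothetical name, not part of the original notebook.

```python
def find_iterative(x, parents):
    # First pass: walk up to the root without changing anything.
    root = x
    while parents[root] != root:
        root = parents[root]

    # Second pass: point every node on the path directly at the root.
    while parents[x] != root:
        next_x = parents[x]
        parents[x] = root
        x = next_x

    return root

# The "long tree" from the example above.
parents = [0, 0, 0, 4, 0, 6, 7, 8, 9, 0]
print(find_iterative(5, parents))
print(parents)

# A chain of 5001 elements: 0 -> 1 -> 2 -> ... -> 5000 (the root).
chain = list(range(1, 5001)) + [5000]
print(find_iterative(0, chain))
```

After the first call every node on the path from 5 points at root 0; node 3 keeps its parent 4 because it was not on that path.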

### *union()* optimization
There is a way to avoid building such linear trees in the first place, known as union by rank: always attach the root of the shallower tree under the root of the deeper one.

In [13]:
def find(x, parents):
    if parents[x] == x:
        return x
    
    furthest_parent = find(parents[x], parents)
    parents[x] = furthest_parent

    return furthest_parent

def union(x, y, parents, rank):
    x_root = find(x, parents)
    y_root = find(y, parents)

    if rank[x_root] < rank[y_root]:
        parents[x_root] = y_root
    elif rank[x_root] > rank[y_root]:
        parents[y_root] = x_root
    else:
        parents[x_root] = y_root
        rank[y_root] += 1 # Only in this case the depth of the tree increases

parents = [i for i in range(10)]
rank = [0 for _ in range(10)]


Let's repeat the same experiment a third time:

In [14]:
union(2, 0, parents, rank)
print(parents)

union(1, 0, parents, rank)
print(parents)

union(3, 4, parents, rank)
union(4, 1, parents, rank)
print(parents)
print("Ranks:", rank)

[0, 1, 0, 3, 4, 5, 6, 7, 8, 9]
[0, 0, 0, 3, 4, 5, 6, 7, 8, 9]
[0, 0, 0, 4, 0, 5, 6, 7, 8, 9]
Ranks: [2, 0, 0, 0, 1, 0, 0, 0, 0, 0]


And create the "linear" tree:

In [15]:
union(5, 6, parents, rank)
union(6, 7, parents, rank)
union(7, 8, parents, rank)
union(8, 9, parents, rank)
union(9, 0, parents, rank)

print(parents)
print("Ranks:", rank)

[0, 0, 0, 4, 0, 6, 0, 6, 6, 6]
Ranks: [2, 0, 0, 0, 1, 0, 1, 0, 0, 0]


Notice that the maximum depth (rank) is 2:

In [16]:
def find_verbose(x, parents):
    """Adds print to the find() function"""
    if parents[x] == x:
        return x
    
    print("+ 1 step to find the parent of", x)
    furthest_parent = find_verbose(parents[x], parents)
    parents[x] = furthest_parent

    return furthest_parent

print(parents)
print(find_verbose(3, parents) == find_verbose(5, parents))
print(parents)
print(find_verbose(3, parents) == find_verbose(5, parents))
print(parents)

[0, 0, 0, 4, 0, 6, 0, 6, 6, 6]
+ 1 step to find the parent of 3
+ 1 step to find the parent of 4
+ 1 step to find the parent of 5
+ 1 step to find the parent of 6
True
[0, 0, 0, 0, 0, 0, 0, 6, 6, 6]
+ 1 step to find the parent of 3
+ 1 step to find the parent of 5
True
[0, 0, 0, 0, 0, 0, 0, 6, 6, 6]


A disjoint set can be used to solve multiple tasks, such as:
- Finding a cycle in a graph.
- Finding the connected components in a graph.
- Finding the Minimum spanning tree of a graph.
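
The first two tasks in the list can be sketched as follows. The function names `has_cycle` and `count_components` are hypothetical, the graphs are assumed to be undirected with vertices numbered from 0 and each edge listed once, and a plain union without ranks is used to keep the example short.

```python
def find(x, parents):
    if parents[x] != x:
        parents[x] = find(parents[x], parents)  # path compression
    return parents[x]

def has_cycle(n, edges):
    """True if the undirected graph with vertices 0..n-1 contains a cycle."""
    parents = list(range(n))
    for x, y in edges:
        x_root, y_root = find(x, parents), find(y, parents)
        if x_root == y_root:
            return True  # x and y were already connected by earlier edges
        parents[x_root] = y_root
    return False

def count_components(n, edges):
    """Number of connected components of the graph with vertices 0..n-1."""
    parents = list(range(n))
    components = n  # every vertex starts as its own component
    for x, y in edges:
        x_root, y_root = find(x, parents), find(y, parents)
        if x_root != y_root:
            parents[x_root] = y_root
            components -= 1  # two components merged into one
    return components

print(has_cycle(3, [(0, 1), (1, 2)]))
print(has_cycle(3, [(0, 1), (1, 2), (2, 0)]))
print(count_components(5, [(0, 1), (1, 2), (3, 4)]))
```

The cycle check is the same test Kruskal's algorithm relies on: an edge whose endpoints already share a root would close a cycle.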

To see why Dijkstra's algorithm does not necessarily find the minimum spanning tree, consider this simple example starting from *vertex 1*:

![Minimum spanning tree vs Dijkstra simple graph example](media/mst_vs_dijkstra.png)
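
The difference can also be checked in code. The triangle graph below is a hypothetical example chosen for this sketch and is not necessarily the graph in the figure. The helper `tree_weight` is likewise an assumption: it grows a tree from the start vertex using either the distance from the start (Dijkstra) or the weight of the edge alone (Prim) as the priority.

```python
from heapq import heappush, heappop

# Hypothetical triangle: edges 1-2 and 1-3 weigh 2, edge 2-3 weighs 1.
triangle = {
    1: [(2, 2), (3, 2)],
    2: [(1, 2), (3, 1)],
    3: [(1, 2), (2, 1)],
}

def tree_weight(start, graph, use_distance):
    """Grow a tree from start and return the sum of the chosen edge weights.

    use_distance=True  -> Dijkstra: priority is the distance from start.
    use_distance=False -> Prim: priority is the weight of the edge alone.
    """
    visited = set()
    pq = [(0, 0, start)]  # (priority, weight of the edge used, vertex)
    total = 0

    while pq:
        priority, edge_weight, vertex = heappop(pq)
        if vertex in visited:
            continue

        visited.add(vertex)
        total += edge_weight

        for neighb, weight in graph[vertex]:
            if neighb not in visited:
                new_priority = priority + weight if use_distance else weight
                heappush(pq, (new_priority, weight, neighb))

    return total

print(tree_weight(1, triangle, use_distance=True))   # shortest-path tree
print(tree_weight(1, triangle, use_distance=False))  # minimum spanning tree
```

From vertex 1, the shortest paths use the two direct edges (total weight 4), while the minimum spanning tree uses edge 2-3 instead of one of them (total weight 3).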

## Kruskal's algorithm

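Let's consider the same graph from Prim's algorithm, but represented as a list of edges. Each edge is an `(x, y, weight)` tuple, and every undirected edge appears once in each direction: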

In [17]:
graph_list_of_edges = [
    (1, 2, 4), (1, 4, 1),
    (2, 1, 4), (2, 3, 3),
    (3, 2, 3), (3, 4, 4),
    (4, 1, 1), (4, 3, 4), (4, 5, 5),
    (5, 4, 5)
]

![Weighted graph example](media/weighted_graph.png)
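
The edge list does not have to be typed by hand. A minimal sketch of deriving it from the adjacency list used for Prim's algorithm is shown below; the helper names `to_edge_list` and `to_unique_edge_list` are hypothetical.

```python
# Adjacency list from the Prim's algorithm section.
weighted_graph = {
    1: [(2, 4), (4, 1)],
    2: [(1, 4), (3, 3)],
    3: [(2, 3), (4, 4)],
    4: [(1, 1), (3, 4), (5, 5)],
    5: [(4, 5)],
}

def to_edge_list(graph):
    """Flatten {vertex: [(neighbour, weight), ...]} into (x, y, weight) tuples."""
    return [(x, y, w) for x, neighbours in graph.items() for y, w in neighbours]

def to_unique_edge_list(graph):
    """Keep each undirected edge once, as (smaller vertex, larger vertex, weight)."""
    return [(x, y, w) for x, y, w in to_edge_list(graph) if x < y]

print(to_edge_list(weighted_graph))
print(to_unique_edge_list(weighted_graph))
```

Listing each edge once is enough for Kruskal's algorithm, because the reversed copy of an edge always connects two vertices that are already in the same set.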

Kruskal's algorithm utilizes the Disjoint set data structure to find a Minimum spanning tree in a graph:

In [18]:
def find(x, parents):
    if parents[x] == x:
        return x
    
    furthest_parent = find(parents[x], parents)
    parents[x] = furthest_parent

    return furthest_parent

def union(x, y, parents, rank):
    x_root = find(x, parents)
    y_root = find(y, parents)

    if rank[x_root] < rank[y_root]:
        parents[x_root] = y_root
    elif rank[x_root] > rank[y_root]:
        parents[y_root] = x_root
    else:
        parents[x_root] = y_root
        rank[y_root] += 1

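def kruskal(V, edges):
    # Vertices are numbered 1..V, so index 0 of both lists is unused.
    parents = [i for i in range(V + 1)]
    rank = [0] * (V + 1)
    mst_weight = 0

    # Consider the edges from lightest to heaviest without reordering the caller's list.
    for x, y, w in sorted(edges, key=lambda x: x[2]):
        # Different roots mean the edge joins two components, so it cannot create a cycle.
        if find(x, parents) != find(y, parents):
            mst_weight += w
            union(x, y, parents, rank)

    return mst_weight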

V = 5
kruskal(V, graph_list_of_edges)

13

![Kruskal's algorithm creating a MST of a graph, step by step example.](media/kruskals_algorithm_example.png)

Verbose version:

In [19]:
def kruskal_verbose(V, edges):
    edges.sort(key=lambda x: x[2])
    print(edges)

    parents = [i for i in range(V + 1)]
    rank = [0] * (V + 1)
    mst_weight = 0

    print(parents)

    for x, y, w in edges:
        print(f"Consider {x} to {y} for {w}")
        if find(x, parents) != find(y, parents):
            mst_weight += w
            union(x, y, parents, rank)
            print(f"Joining {x} and {y}...", parents)
            print()

    return mst_weight

V = 5
kruskal_verbose(V, graph_list_of_edges)

[(1, 4, 1), (4, 1, 1), (2, 3, 3), (3, 2, 3), (1, 2, 4), (2, 1, 4), (3, 4, 4), (4, 3, 4), (4, 5, 5), (5, 4, 5)]
[0, 1, 2, 3, 4, 5]
Consider 1 to 4 for 1
Joining 1 and 4... [0, 4, 2, 3, 4, 5]

Consider 4 to 1 for 1
Consider 2 to 3 for 3
Joining 2 and 3... [0, 4, 3, 3, 4, 5]

Consider 3 to 2 for 3
Consider 1 to 2 for 4
Joining 1 and 2... [0, 4, 3, 3, 3, 5]

Consider 2 to 1 for 4
Consider 3 to 4 for 4
Consider 4 to 3 for 4
Consider 4 to 5 for 5
Joining 4 and 5... [0, 3, 3, 3, 3, 3]

Consider 5 to 4 for 5


13
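
As with Prim's algorithm, the loop can stop as soon as the tree has *V - 1* edges. The sketch below is a hypothetical variant (`kruskal_edges` is not part of the original notebook) that also returns the chosen edges. It uses a plain union without ranks to stay short and leaves the input list unsorted.

```python
def find(x, parents):
    if parents[x] != x:
        parents[x] = find(parents[x], parents)  # path compression
    return parents[x]

def kruskal_edges(V, edges):
    """Return (total weight, chosen edges) for vertices numbered 1..V."""
    parents = list(range(V + 1))  # index 0 is unused
    mst_weight = 0
    mst_edges = []

    for x, y, w in sorted(edges, key=lambda e: e[2]):
        x_root, y_root = find(x, parents), find(y, parents)
        if x_root == y_root:
            continue  # the edge would create a cycle

        parents[x_root] = y_root  # plain union, no rank
        mst_weight += w
        mst_edges.append((x, y, w))

        if len(mst_edges) == V - 1:
            break  # the spanning tree is complete

    return mst_weight, mst_edges

graph_list_of_edges = [
    (1, 2, 4), (1, 4, 1),
    (2, 1, 4), (2, 3, 3),
    (3, 2, 3), (3, 4, 4),
    (4, 1, 1), (4, 3, 4), (4, 5, 5),
    (5, 4, 5),
]

total, chosen = kruskal_edges(5, graph_list_of_edges)
print(total)
print(chosen)
```

The chosen edges match the "Joining" lines of the verbose trace: 1-4, 2-3, 1-2 and 4-5.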