A minimum spanning tree (MST) is a subset of the edges of a connected, weighted, undirected graph that:

- Connects all vertices together
- Contains no cycles
- Has the minimum possible total edge weight


Disjoint set:
A data structure that keeps track of elements partitioned into a number of disjoint (non-overlapping) sets. Each set has a representative element that identifies it.

Disjoint-set operations:

- Make set
- Union
- Find set


makeSet(N): creates the initial sets, one singleton set per element

union(x,y): merges the sets that contain x and y into one set

findSet(x): returns the representative of the set that contains x


In [3]:
class DisJointSet:
    """Disjoint-set (union-find) structure using union by rank and path compression.

    Every element starts in its own singleton set whose representative is
    the element itself. ``find`` returns an element's representative and
    ``union`` merges the sets containing two elements.
    """

    def __init__(self, vertices):
        """Create one singleton set per element.

        Args:
            vertices: Sequence of hashable elements (e.g. vertex labels).
        """
        self.vertices = vertices
        # parent[v] is v's parent in its tree; roots point at themselves.
        self.parent = {v: v for v in vertices}
        # rank[root] is an upper bound on the height of root's tree.
        self.rank = dict.fromkeys(vertices, 0)

    def find(self, item):
        """Return the representative (root) of the set containing ``item``.

        Raises:
            KeyError: if ``item`` was not one of the constructor's vertices.
        """
        root = item
        while self.parent[root] != root:
            root = self.parent[root]
        # Path compression: re-point every node on the walked path straight
        # at the root so later lookups are shorter. Iterative, so long
        # chains cannot hit the recursion limit.
        while item != root:
            nxt = self.parent[item]
            self.parent[item] = root
            item = nxt
        return root

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y`` (no-op if already merged).

        The root of lower rank is attached under the root of higher rank.
        On a tie, ``x``'s root becomes the new root and its rank grows by one.
        """
        xroot = self.find(x)
        yroot = self.find(y)
        if xroot == yroot:
            # Already in the same set; don't inflate the rank.
            return

        if self.rank[xroot] > self.rank[yroot]:
            # Previously wrote to self.rank here, so nothing was merged.
            self.parent[yroot] = xroot
        elif self.rank[xroot] < self.rank[yroot]:
            self.parent[xroot] = yroot
        else:
            self.parent[yroot] = xroot
            self.rank[xroot] += 1


In [6]:
vertices = ["A", "B", "C", "D", "E"]

ds = DisJointSet(vertices)
# Merge A with B, then A with C, so all three share one representative.
for first, second in (("A", "B"), ("A", "C")):
    ds.union(first, second)
# Print the representatives of A and B.
for label in ("A", "B"):
    print(ds.find(label))


A
A


<h1>Kruskal's Algorithm:</h1>

<h5>Pseudocode:</h5>

```
for each vertex v:
    makeset(v)

sort the edges in non-decreasing order of weight

cost = 0
for each edge (u,v), in sorted order:
    if findset(u) != findset(v):
        union(u,v)
        cost = cost + weight(u,v)

```


In [20]:
class Kruskal:
    """Kruskal's minimum-spanning-tree algorithm over a weighted edge list.

    Register vertex labels with ``addNode`` and weighted edges with
    ``addEdge``, then call ``kruskalMST``. It stores the selected edges in
    ``self.MST`` and prints them.
    """

    def __init__(self, vertices):
        """Create an empty graph.

        Args:
            vertices: Number of vertices. The algorithm stops once
                ``vertices - 1`` edges have been selected.
        """
        self.vertices = vertices
        self.graph = []  # edges as [source, destination, weight]
        self.nodes = []  # vertex labels, in registration order
        self.MST = []  # edges selected by the most recent kruskalMST() call

    def addEdge(self, s, d, w):
        """Add an edge between ``s`` and ``d`` with weight ``w``."""
        self.graph.append([s, d, w])

    def addNode(self, value):
        """Register a vertex label."""
        self.nodes.append(value)

    def printSolution(self, s=None, d=None, w=None):
        """Print every edge in ``self.MST`` as ``"s - d : w"``.

        ``s``, ``d`` and ``w`` are ignored. They are kept, now optional,
        only so existing callers that pass them keep working.
        """
        for src, dest, weight in self.MST:
            print("%s - %s : %s" % (src, dest, weight))

    def kruskalMST(self):
        """Compute, store (in ``self.MST``) and print the minimum spanning tree.

        Edges are scanned in non-decreasing weight order. An edge is kept
        when its endpoints lie in different disjoint sets, so it cannot
        close a cycle. The scan stops after ``self.vertices - 1`` edges.
        If the edge list runs out first (disconnected graph), ``self.MST``
        holds a minimum spanning forest instead of raising ``IndexError``.
        """
        # Start fresh so repeated calls don't accumulate duplicate edges.
        self.MST = []
        sets = DisJointSet(self.nodes)
        # sorted() is stable: equal-weight edges keep their insertion order.
        self.graph = sorted(self.graph, key=lambda edge: edge[2])
        target = self.vertices - 1

        for s, d, w in self.graph:
            if len(self.MST) >= target:
                break
            x = sets.find(s)
            y = sets.find(d)
            if x != y:
                self.MST.append([s, d, w])
                sets.union(x, y)

        # Call without arguments: s/d/w may be unbound when no edge was scanned.
        self.printSolution()


![MST](MST.png)


In [21]:
# Graph with 5 vertices; Kruskal stops after selecting 5 - 1 = 4 edges.
kruskalGraph = Kruskal(vertices=5)


In [22]:
# Register the five vertex labels, in order.
for label in ["A", "B", "C", "D", "E"]:
    kruskalGraph.addNode(label)


In [23]:
# Each undirected edge is listed once per direction. Keep this order:
# kruskalMST's stable sort preserves it among equal-weight edges.
for edge in (
    ("A", "C", 13),
    ("A", "E", 15),
    ("A", "B", 5),
    ("E", "A", 15),
    ("E", "C", 20),
    ("B", "A", 5),
    ("B", "D", 8),
    ("B", "C", 10),
    ("D", "B", 8),
    ("D", "C", 6),
    ("C", "E", 20),
    ("C", "A", 13),
    ("C", "B", 10),
    ("C", "D", 6),
):
    kruskalGraph.addEdge(*edge)


In [24]:
# Run Kruskal's algorithm; the selected edges are printed and kept in kruskalGraph.MST.
kruskalGraph.kruskalMST()


A - B : 5
D - C : 6
B - D : 8
A - E : 15


In [48]:
import sys


class PrimsMST:
    """Prim's minimum-spanning-tree algorithm over an adjacency matrix.

    ``edges[i][j]`` is the weight of the edge between vertex ``i`` and
    vertex ``j``. A falsy entry (``0``) means "no edge". ``nodes[i]`` is
    the label printed for vertex ``i``. The tree is grown from vertex 0.
    """

    def __init__(self, vertNum, edges, nodes):
        """Store the graph.

        Args:
            vertNum: Number of vertices (size of the square ``edges`` matrix).
            edges: ``vertNum x vertNum`` adjacency matrix of weights.
            nodes: Vertex labels, indexed like the matrix rows/columns.
        """
        self.vertNum = vertNum
        self.edges = edges
        self.nodes = nodes
        self.MST = []  # [label_s, label_d, weight] triples from the last run

    def printSolution(self):
        """Print a header followed by every edge in ``self.MST``."""
        print("Edge : Weight")
        for s, d, w in self.MST:
            print("%s -> %s : %s" % (s, d, w))

    def primsAlgo(self):
        """Compute, store (in ``self.MST``) and print the minimum spanning tree.

        Each step adds the cheapest edge that joins a visited vertex to an
        unvisited one. Among equal weights, the first found while scanning
        rows then columns wins.

        Raises:
            ValueError: if some vertex is unreachable from vertex 0. The
                old code silently re-appended a stale edge in that case.
        """
        # Start fresh so repeated calls don't accumulate duplicate edges.
        self.MST = []
        if self.vertNum == 0:
            self.printSolution()
            return

        visited = [False] * self.vertNum
        visited[0] = True

        for _ in range(self.vertNum - 1):
            best = None  # (weight, s, d) of the cheapest crossing edge so far
            for i in range(self.vertNum):
                if not visited[i]:
                    continue
                for j in range(self.vertNum):
                    weight = self.edges[i][j]
                    if visited[j] or not weight:
                        continue
                    # Strict '<' keeps the first minimum found, as before.
                    if best is None or weight < best[0]:
                        best = (weight, i, j)

            if best is None:
                raise ValueError("graph is disconnected; no spanning tree exists")

            weight, s, d = best
            self.MST.append([self.nodes[s], self.nodes[d], weight])
            visited[d] = True

        self.printSolution()


In [49]:
# Symmetric adjacency matrix for vertices A..E; 0 means "no edge".
edges = [
    #  A   B   C   D  E
    [0, 10, 20, 0, 0],   # A
    [10, 0, 30, 5, 0],   # B
    [20, 30, 0, 15, 6],  # C
    [0, 5, 15, 0, 8],    # D
    [0, 0, 6, 8, 0],     # E
]


In [50]:
# Labels for adjacency-matrix rows/columns 0..4.
nodes = list("ABCDE")


In [51]:
# Prim's solver over the 5-vertex adjacency matrix and its labels.
g = PrimsMST(vertNum=5, edges=edges, nodes=nodes)


In [52]:
# Run Prim's algorithm from vertex 0 ("A"); prints the tree and stores it in g.MST.
g.primsAlgo()


Edge : Weight
A -> B : 10
B -> D : 5
D -> E : 8
E -> C : 6
