In [None]:
import math

## Problem

Edna wants to find the optimal locations to set up medical emergency centers in a flooded region.

![image info](mst/mst.001.png)

### Solution

#### Step 1: Get the lay of the land

![image info](mst/mst.002.png)

In [None]:
# Adjacency matrix for the 8 locations: connections[i][j] == 1 marks nodes i
# and j as directly connected. The matrix is symmetric and every node is
# marked as connected to itself (all diagonal entries are 1).
connections = [[1,1,0,0,0,0,1,1],
               [1,1,0,0,0,0,0,0],
               [0,0,1,1,0,0,0,0],
               [0,0,1,1,0,1,0,0],
               [0,0,0,0,1,1,0,0],
               [0,0,0,1,1,1,0,0],
               [1,0,0,0,0,0,1,0],
               [1,0,0,0,0,0,0,1]]

#### Disjoint Set

In [None]:
class DisjointSet:
    """Union-find structure over the indices ``0 .. len(values) - 1``.

    ``parents`` maps each index to its parent; a set's representative is its
    own parent. ``root`` maps each index to a size counter that is only
    meaningful for representatives: it holds the number of nodes in that set.
    """

    def __init__(self, values):
        """Start with every index alone in its own set of size one."""
        count = len(values)
        self.parents = {index: index for index in range(count)}
        self.root = {index: 1 for index in range(count)}

    def find(self, value):
        """Return the representative of ``value``'s set, compressing the path."""
        # first pass: climb until we hit the node that is its own parent
        top = value
        while self.parents[top] != top:
            top = self.parents[top]

        # second pass: re-point every node on the walked path at the representative
        node = value
        while node != top:
            following = self.parents[node]
            self.parents[node] = top
            node = following
        return top

    def union(self, node1, node2):
        """Merge the sets holding ``node1`` and ``node2`` (no-op if already merged)."""
        keeper, absorbed = self.find(node1), self.find(node2)
        if keeper == absorbed:
            return

        # the larger set (ties favour node1's set) keeps its representative
        if self.root[keeper] < self.root[absorbed]:
            keeper, absorbed = absorbed, keeper

        self.parents[absorbed] = keeper
        self.root[keeper] += self.root[absorbed]


In [None]:
def find_clusters(connected_list):
    """Union every connected pair of an adjacency matrix, tracing each step.

    Only the upper triangle (col > row) of ``connected_list`` is scanned.
    After each union the parent table and the size table of the disjoint set
    are printed. Nothing is returned.
    """
    size = len(connected_list)
    ds = DisjointSet(connected_list)
    for row, flags in enumerate(connected_list):
        for col in range(row + 1, size):
            if flags[col] != 1:
                continue
            ds.union(row, col)
            # snapshot of the structure after this merge
            print("parent", [ds.parents[node] for node in range(size)])
            print("root  ", [ds.root[node] for node in range(size)])

# Run the union pass over the sample matrix; the parent and size tables are
# printed after every union.
find_clusters(connections)


#### Maximum Independent Set

A maximum independent set is the largest set of nodes in which no two nodes are connected by an edge.

Formally, for a graph $G = (V,E)$, a set $S \subseteq V$ is an independent set if no two nodes in $S$ are adjacent. In other words, there is no edge that connects any two vertices in that set.

![image info](mst/mst.007.png)

Step 2: Find the minimum set of edges that connects ALL nodes. We want to minimize the cost.

#### Spanning Tree

A Spanning Tree is a set of edges that connects all vertices of the graph without forming a cycle.

![image info](mst/mst.006.png)

Is $ST = ((A,H),(B,C),(F,E))$ a Spanning Tree ?

Is $ST = ((A,H),(B,C),(F,E),(B,E))$ a Spanning Tree ?

Is $ST = ((A,H),(A,G),(B,C),(C,D),(D,G),(B,E),(D,F))$ a Spanning Tree ?

#### Graphs aren't really uniform. Navigation has a cost

![image info](mst/mst.008.png)

![image info](mst/mst.005.png)

A Minimum Spanning Tree is a spanning tree whose total edge weight is the smallest among all spanning trees of the graph. Like every spanning tree of a graph with $n$ vertices, it has exactly $n-1$ edges.

In [None]:
class WeightedGraph:
    """Undirected weighted graph stored as adjacency lists.

    ``graph`` maps each node number to the list of its neighbours and
    ``weight`` maps each ordered pair ``(node1, node2)`` to the edge weight.
    Every edge is recorded in both directions.
    """

    def __init__(self, nodes):
        """Create a graph with nodes ``0 .. nodes - 1`` and no edges."""
        self.graph = {}
        self.weight = {}
        for i in range(nodes):
            self.graph[i] = []

    def are_connected(self, node1, node2):
        """Return True when ``node2`` is a direct neighbour of ``node1``."""
        # The adjacency lists live in ``self.graph``; the previous lookup of
        # ``self.adj`` raised AttributeError on every call.
        return node2 in self.graph[node1]

    def connected_nodes(self, node):
        """Return the neighbour list of ``node`` (the stored list, not a copy)."""
        return self.graph[node]

    def add_node(self,):
        """Append a new isolated node numbered with the current node count."""
        self.graph[len(self.graph)] = []

    def add_edge(self, node1, node2, weight):
        """Add the undirected edge ``node1 -- node2``; ignored if it already exists."""
        if node1 not in self.graph[node2]:
            self.graph[node1].append(node2)
            self.weight[(node1, node2)] = weight

            # since it is undirected, mirror the edge
            self.graph[node2].append(node1)
            self.weight[(node2, node1)] = weight

    def number_of_nodes(self,):
        """Return how many nodes the graph holds."""
        return len(self.graph)

    def has_edge(self, src, dst):
        """Return True when an edge ``src -- dst`` is stored."""
        return dst in self.graph[src]

    def get_weight(self,):
        """Return the total weight of all edges in the graph."""
        total = 0
        for node1 in self.graph:
            for node2 in self.graph[node1]:
                total += self.weight[(node1, node2)]

        # every undirected edge was visited from both endpoints
        return total/2

![image info](mst/mst.008.png)

In [None]:
# Build the 8-node example graph.
g = WeightedGraph(8)
# NOTE(review): `edges` is not used by the code below, and its last entry
# `(4,)` is a one-element tuple, so the list looks unfinished — confirm the
# intended edge set.
edges = [(0,0),(0,1),(0,6),(0,7),
         (1,0),(1,1),(1,2),(1,4),
         (2,1),(2,2),(2,3),
         (3,2),(3,3),(3,5),(3,6),
         (4,)]

# NOTE(review): WeightedGraph.add_edge is declared as (node1, node2, weight)
# with no default for `weight`, so these two-argument calls raise TypeError.
# The weights to pass are not available in this cell — TODO supply them.
g.add_edge(0,0)
g.add_edge(0,0)

#### Constructing minimum spanning tree: Prim's Algorithm

In [None]:
def prims(G):
    """Build a minimum spanning tree of ``G`` with Prim's algorithm.

    The tree is grown from node 0: on every round the cheapest edge leading
    from a visited node to an unvisited node is added.

    Args:
        G: a WeightedGraph (uses ``G.graph``, ``G.weight`` and
            ``G.number_of_nodes()``).

    Returns:
        A new WeightedGraph with the same node count holding the chosen
        edges. If ``G`` is disconnected, only the component containing
        node 0 is spanned.
    """
    mst = WeightedGraph(G.number_of_nodes())
    # WeightedGraph has no ``keys()`` method; the node set is ``G.graph``.
    visited_nodes = {node: False for node in G.graph}

    # initialize the tree with a single node, chosen arbitrarily
    visited_nodes[0] = True

    # a spanning tree over n nodes has n - 1 edges
    for _ in range(G.number_of_nodes() - 1):
        # ``None`` instead of a large sentinel weight, so that arbitrarily
        # heavy edges are still considered
        best_edge = None
        for start_node, neighbours in G.graph.items():
            if not visited_nodes[start_node]:
                continue
            for end_node in neighbours:
                if visited_nodes[end_node]:
                    continue
                weight = G.weight[(start_node, end_node)]
                if best_edge is None or weight < best_edge[2]:
                    best_edge = (start_node, end_node, weight)

        # no edge leaves the tree: the remaining nodes are unreachable
        if best_edge is None:
            break

        # mark the newly reached node as visited and record the edge
        visited_nodes[best_edge[1]] = True
        mst.add_edge(best_edge[0], best_edge[1], best_edge[2])

    return mst
   

### Prim's using min heap

In [None]:
class Item:
    """A heap entry: ``key`` is the priority, ``value`` is the payload."""

    def __init__(self, key, value):
        self.key, self.value = key, value

    def __str__(self):
        """Render the item as ``(key:<key>,value:<value>)``."""
        return f"(key:{self.key!s},value:{self.value!s})"

In [None]:
# Four sample items with keys 1..4 and values 'A'..'D', printed one per line.
nodes = [Item(key, value) for key, value in enumerate("ABCD", start=1)]
for n in nodes:
    print(n)

In [None]:
class MinHeap:
    """Array-backed binary min-heap of items ordered by their ``key``.

    Each item must expose a comparable ``key`` and a hashable ``value``.
    Values must be unique within the heap: ``map`` stores ``value -> index``
    so that ``decrease_key`` can locate an item without searching.

    Attributes:
        items: list holding the heap; the first ``length`` slots are live.
        length: number of items currently in the heap.
        map: dictionary from item value to its index in ``items``.
    """

    def __init__(self, data):
        """Build a heap from the iterable ``data`` of items."""
        # work on a copy so heapifying does not reorder the caller's list
        self.items = list(data)
        self.length = len(self.items)

        # the map must exist before build_heap, because sink_down keeps it
        # up to date on every swap
        self.map = {item.value: index for index, item in enumerate(self.items)}
        self.build_heap()

    def find_left_index(self,index):
        """Return the index of the left child of ``index``."""
        return 2 * (index + 1) - 1

    def find_right_index(self,index):
        """Return the index of the right child of ``index``."""
        return 2 * (index + 1)

    def find_parent_index(self,index):
        """Return the index of the parent of ``index`` (-1 for the root)."""
        return (index + 1) // 2 - 1

    def _swap(self, first, second):
        """Exchange two slots and refresh their entries in ``map``."""
        self.items[first], self.items[second] = self.items[second], self.items[first]
        self.map[self.items[first].value] = first
        self.map[self.items[second].value] = second

    def sink_down(self, index):
        """Move the item at ``index`` down until both children are not smaller."""
        while True:
            smallest_known_index = index
            left_index = self.find_left_index(index)
            right_index = self.find_right_index(index)

            if left_index < self.length and self.items[left_index].key < self.items[smallest_known_index].key:
                smallest_known_index = left_index

            if right_index < self.length and self.items[right_index].key < self.items[smallest_known_index].key:
                smallest_known_index = right_index

            # heap property holds at this position: done
            if smallest_known_index == index:
                return

            self._swap(index, smallest_known_index)
            index = smallest_known_index

    def build_heap(self,):
        """Establish the heap property over all live slots."""
        # leaves are already heaps; sink every internal node, bottom-up
        for i in range(self.length // 2 - 1, -1, -1):
            self.sink_down(i)

    def insert(self, node):
        """Add ``node`` to the heap."""
        if len(self.items) == self.length:
            self.items.append(node)
        else:
            # reuse a slot beyond the live region
            self.items[self.length] = node
        self.map[node.value] = self.length
        self.length += 1
        self.swim_up(self.length - 1)

    def insert_nodes(self, node_list):
        """Insert every item of ``node_list``."""
        for node in node_list:
            self.insert(node)

    def swim_up(self, index):
        """Move the item at ``index`` up while it is smaller than its parent."""
        while index > 0:
            parent_index = self.find_parent_index(index)
            if not self.items[index].key < self.items[parent_index].key:
                break
            self._swap(index, parent_index)
            index = parent_index

    def get_min(self):
        """Return the smallest item without removing it, or None when empty."""
        # check the logical size, not the backing list
        if self.length > 0:
            return self.items[0]
        return None

    def extract_min(self,):
        """Remove and return the item with the smallest key.

        Raises:
            IndexError: if the heap is empty.
        """
        if self.length == 0:
            raise IndexError("extract_min from an empty heap")

        # move the minimum to the last live slot, then cut it off
        last_index = self.length - 1
        self._swap(0, last_index)
        min_node = self.items[last_index]
        del self.items[last_index:]
        self.length = last_index
        self.map.pop(min_node.value)

        self.sink_down(0)
        return min_node

    def decrease_key(self, value, new_key):
        """Lower the key of the item holding ``value`` to ``new_key``.

        The call is ignored when ``new_key`` is not smaller than the current
        key. Raises KeyError when ``value`` is not in the heap.
        """
        index = self.map[value]
        if new_key >= self.items[index].key:
            return
        self.items[index].key = new_key
        self.swim_up(index)

    def get_element_from_value(self, value):
        """Return the item holding ``value`` (KeyError if absent)."""
        return self.items[self.map[value]]

    def is_empty(self):
        """Return True when the heap holds no items."""
        return self.length == 0


In [None]:
def prim_with_heap_1(G):
    """Build a minimum spanning tree of ``G`` with Prim's algorithm and a min-heap.

    Candidate edges are kept in a MinHeap keyed by weight, with the edge
    ``(start_node, end_node)`` as the item value. The tree is grown from
    node 0.

    Args:
        G: a WeightedGraph (uses ``G.graph``, ``G.weight`` and
            ``G.number_of_nodes()``).

    Returns:
        A new WeightedGraph with the same node count holding the chosen
        edges. Nodes unreachable from node 0 stay isolated in the result.
    """
    mst = WeightedGraph(G.number_of_nodes())
    # WeightedGraph has no ``keys()`` method; the node set is ``G.graph``.
    visited_nodes = {node: False for node in G.graph}

    # initialize the tree with a single node, chosen arbitrarily
    visited_nodes[0] = True

    # create an empty heap
    Q = MinHeap([])

    # add the first set of edges into the Q; the weight is the heap key so
    # the cheapest edge comes out first, the edge itself is the payload.
    # Edges to visited nodes (here: self-loops on node 0) are never useful
    # and would put duplicate values into the heap, so they are skipped.
    for end_node in G.graph.get(0, ()):
        if not visited_nodes[end_node]:
            Q.insert(Item(G.weight[(0, end_node)], (0, end_node)))

    while not Q.is_empty():
        # cheapest edge leaving the tree built so far
        min_edge = Q.extract_min().value
        curr_node = min_edge[1]

        # the far end may have been reached by a cheaper edge meanwhile
        if visited_nodes[curr_node]:
            continue

        # add the edge and mark the newly reached node as visited
        mst.add_edge(min_edge[0], curr_node, G.weight[min_edge])
        visited_nodes[curr_node] = True

        # offer every edge from the new node to a still unvisited node
        for end_node in G.graph[curr_node]:
            if not visited_nodes[end_node]:
                Q.insert(Item(G.weight[(curr_node, end_node)], (curr_node, end_node)))

    return mst