# Minimum Spanning Tree

In this assignment, we will implement the two minimum spanning tree algorithms: Prim's and Kruskal's.

## The Graph Class

The following class is used for inputting the graphs.

In [1]:
class Graph:
    def __init__(self, edges, directed=False):
        self.adj_list = {}
        self.directed = directed
        print("length:",len(edges))
        for u, v, k in edges:
            self.add_edge(u, v, k)
            
    def _add_edge_single(self, u, v, k):
        """Internal function. Do not use directly.
        Add a single edge to the graph.
        """
        if u not in self.adj_list:
            self.adj_list[u] = []
        self.adj_list[u].append((v, k))        
                
    def add_edge(self, u, v, k):
        """Add an edge to the graph. Add the reverse edge 
        when the graph is undirected."""
        self._add_edge_single(u, v, k)
        if not self.directed:
            self._add_edge_single(v, u, k)
    
    def neighbors(self, u):
        """Return the list of neighbors and the 
        corresponding weights of u"""
        return self.adj_list[u]
    
    
    def vertices(self):
        """Return the set of vertices of the graph"""
        return self.adj_list.keys()
    

## Prim's Algorithm

Implement the Prim's algorithm in the following cell. Return a hash with nodes as keys and the corresponding parents as values.

In [2]:
# Add any functions/classes as necessary.

# q : (parent , weight)
q={}

def extract_min():
    
    global q
    l=[]
#     l(weight,vertex_name)
    for i in q:
        l.append((q[i][1],i))
        

    mn = min(l)
    print("min",mn[1])

    return(mn[1])

def prim(g, s):
    # Your code here
    global q
    p = {}
    
    q[s] = (None,0)
    
    for i in g.vertices()-[s]:
        q[i]= (None,float('inf'))

       
    while q:
        print("q" ,q)
        u = extract_min()
        print("u " ,u)
        parent = q[u][0]
        q.pop(u)
        print("======================")
        print("new q ",q)
        p[u] = parent
        
        for v in g.neighbors(u):

            if v[0] in q:
                node = q[v[0]]
                if(node[1] > v[1]):
                    q[v[0]] = (u , v[1])

                
                
    for i in p:
        print(i, ":" , p[i])
    return p

In [3]:
# Test cell (from CLRS)

graph = Graph([(0, 1, 4), (0, 7, 8), (1, 2, 8), (1, 7, 11), (2, 3, 7), (2, 8, 2), (2, 5, 4), (3, 4, 9), (3, 5, 14), (4, 5, 10), (5, 6, 2), (6, 7, 1), (6, 8, 6), (7, 8, 7)])

def check_parents(mst, parents, n):
    for i in range(n):
        if mst[i] != parents[i]:
            return False
    return True

mst = prim( graph , 0 )
sol1 = [None, 0, 1, 2, 3, 2, 5, 6, 2]
sol2 = [None, 0, 5, 2, 3, 6, 7, 0, 2]

assert check_parents(mst, sol1, 9) or check_parents(mst, sol2, 9)

length: 14
q {0: (None, 0), 1: (None, inf), 2: (None, inf), 3: (None, inf), 4: (None, inf), 5: (None, inf), 6: (None, inf), 7: (None, inf), 8: (None, inf)}
min 0
u  0
new q  {1: (None, inf), 2: (None, inf), 3: (None, inf), 4: (None, inf), 5: (None, inf), 6: (None, inf), 7: (None, inf), 8: (None, inf)}
q {1: (0, 4), 2: (None, inf), 3: (None, inf), 4: (None, inf), 5: (None, inf), 6: (None, inf), 7: (0, 8), 8: (None, inf)}
min 1
u  1
new q  {2: (None, inf), 3: (None, inf), 4: (None, inf), 5: (None, inf), 6: (None, inf), 7: (0, 8), 8: (None, inf)}
q {2: (1, 8), 3: (None, inf), 4: (None, inf), 5: (None, inf), 6: (None, inf), 7: (0, 8), 8: (None, inf)}
min 2
u  2
new q  {3: (None, inf), 4: (None, inf), 5: (None, inf), 6: (None, inf), 7: (0, 8), 8: (None, inf)}
q {3: (2, 7), 4: (None, inf), 5: (2, 4), 6: (None, inf), 7: (0, 8), 8: (2, 2)}
min 8
u  8
new q  {3: (2, 7), 4: (None, inf), 5: (2, 4), 6: (None, inf), 7: (0, 8)}
q {3: (2, 7), 4: (None, inf), 5: (2, 4), 6: (8, 6), 7: (8, 7)}
min 5
u  

## Kruskal's Algorithm

Next we implement the Kruskal's Minimum Spanning Tree algorithm. For doing that, we first need to implment the Union and Find algorithms.

In [4]:
# the Union Find algorithms

class UnionFind:
    def __init__(self, nodes):
        # nodes is a list of nodes. Associate an identifier for each node.
        self.set = {}
        for i in nodes:
            self.set[i] = i
            
        
    def find(self, u):
        # Return the representative element (identifier/label) for u
        return self.set[u]
        
    def union(self, u, v):
        # Make the representative elements correponding to u and v the same.
        x = self.find(u)
        y = self.find(v)
        self.set[v] = self.set[u]
        for i in self.set:
            if(self.set[i] == x):
                self.set[i] =y
        
        

In [5]:
# Test

nodes = [2, 3, 4, 5, 6, 7]
uf = UnionFind(nodes)

for i in range(len(nodes)):
    for j in range(i+1, len(nodes)):
        assert uf.find(nodes[i]) != uf.find(nodes[j])
        

uf.union(4, 7)
print(uf.find(4), uf.find(7))
assert uf.find(4) == uf.find(7)

uf.union(4, 2)

assert uf.find(4) == uf.find(2) and uf.find(4) == uf.find(7)

assert uf.find(4) != uf.find(5)
assert uf.find(4) != uf.find(6)
assert uf.find(4) != uf.find(3)

uf.union(3, 5)

assert uf.find(3) == uf.find(5)
assert uf.find(4) != uf.find(3)
assert uf.find(6) != uf.find(3)

uf.union(5, 2)

for k in [3, 4, 5, 7]:
    assert uf.find(k) == uf.find(2)
    
assert uf.find(2) != uf.find(6)



7 7


Using the UnionFind class, implement the Kruskal's algorithm in the following cell. Output the __set__ of edges included in the MST. Return the edges as pairs $(u, v)$.

In [6]:
# Add any functions/classes as necessary.

def insertion_sort(alist):
    for i, kv in enumerate(alist):
        for j in range(i-1, -1, -1):
            if(alist[j][0] > kv[0]):
                alist[j+1]= alist[j]
            else:
                break

            alist[j]=kv    
    return alist 

def kruskal(g):
    mst_edges = set()
    # Your code here
    edges=[]
    
    
    make_set = UnionFind([u for u in g.adj_list])
    
    print(g.adj_list)
    for u in g.adj_list:
        for v in g.neighbors(u):
            edges.append((v[1],u,v[0]))
    edges = insertion_sort(edges)

    
    for i in edges:
        if make_set.find(i[1]) != make_set.find(i[2]):

            mst_edges.add((i[1], i[2]))
            make_set.union(i[1], i[2])
        

    # Return the MST.
    print("edges : ",mst_edges)
    return mst_edges
    

In [7]:
# Test for Kruskal's 

mst = kruskal( graph )
sol1 = set([(0, 1), (0, 7), (2, 8), (2, 5), (2, 3), (3, 4), (6, 7), (6, 5)])
sol2 = set([(0, 1), (1, 2), (2, 8), (2, 5), (2, 3), (3, 4), (6, 7), (6, 5)])

assert mst == sol1 or mst == sol2



{0: [(1, 4), (7, 8)], 1: [(0, 4), (2, 8), (7, 11)], 7: [(0, 8), (1, 11), (6, 1), (8, 7)], 2: [(1, 8), (3, 7), (8, 2), (5, 4)], 3: [(2, 7), (4, 9), (5, 14)], 8: [(2, 2), (6, 6), (7, 7)], 5: [(2, 4), (3, 14), (4, 10), (6, 2)], 4: [(3, 9), (5, 10)], 6: [(5, 2), (7, 1), (8, 6)]}
edges :  {(0, 1), (5, 6), (2, 8), (7, 6), (0, 7), (2, 3), (2, 5), (3, 4)}


AssertionError: 