# Minimum Spanning Tree

In this assignment, we will implement the two minimum spanning tree algorithms: Prim's and Kruskal's.

## The Graph Class

The following class is used for inputting the graphs.

In [18]:
class Graph:
    def __init__(self, edges, directed=False):
        self.adj_list = {}
        self.directed = directed
        
        for u, v, k in edges:
            self.add_edge(u, v, k)
            
    def _add_edge_single(self, u, v, k):
        """Internal function. Do not use directly.
        Add a single edge to the graph.
        """
        if u not in self.adj_list:
            self.adj_list[u] = []
        self.adj_list[u].append((v, k))        
                
    def add_edge(self, u, v, k):
        """Add an edge to the graph. Add the reverse edge 
        when the graph is undirected."""
        self._add_edge_single(u, v, k)
        if not self.directed:
            self._add_edge_single(v, u, k)
    
    def neighbors(self, u):
        """Return the list of neighbors and the 
        corresponding weights of u"""
        return self.adj_list[u]
    
     
    def vertices(self):
        """Return the set of vertices of the graph"""
        return self.adj_list.keys()
    

## Prim's Algorithm

Implement the Prim's algorithm in the following cell. Return a hash with nodes as keys and the corresponding parents as values.

In [19]:
# Add any functions/classes as necessary.



def prim(graph, s):
    mst={}
    All_v=[]
    final={}
    for i in graph.vertices():
        All_v.append((i,float('inf')))
        final[i]=float('inf')
    
    All_v[0]=(s,0)
    final[0]=None
#     l=l2.pop(0)
#     l1.append(l)
#     #print(l2)
#     #print(l1)
    le2=len(All_v)
    All_v=dict(All_v)
    while len(mst)<le2:
 
        
        #print(All_v)
        m = min(All_v.keys(), key=(lambda k: All_v[k]))
        #print(m)
        #l2.sort(key=lambda x: x[1]) 
        val=All_v.pop(m)
        #print(node)
        mst[m]=val#Adding explored node and its pathweight object to list
        #print(mst)
        #print("neighbors",graph.neighbors(m))
        for k,v in graph.neighbors(m):
            #print(mst[m])
            if k not in mst.keys():
                if(All_v[k]>=v):
                    All_v[k]=v
                    final[k]=m
            #print(All_v)
            #print(final)
    return final  

In [20]:
graph = Graph([(0, 1, 4), (0, 7, 8), (1, 2, 8), (1, 7, 11), (2, 3, 7), (2, 8, 2), (2, 5, 4), (3, 4, 9), (3, 5, 14), (4, 5, 10), (5, 6, 2), (6, 7, 1), (6, 8, 6), (7, 8, 7)])

prim(graph,0)

{0: None, 1: 0, 7: 0, 2: 5, 3: 2, 8: 2, 5: 6, 4: 3, 6: 7}

In [21]:
# Test cell (from CLRS)

graph = Graph([(0, 1, 4), (0, 7, 8), (1, 2, 8), (1, 7, 11), (2, 3, 7), (2, 8, 2), (2, 5, 4), (3, 4, 9), (3, 5, 14), (4, 5, 10), (5, 6, 2), (6, 7, 1), (6, 8, 6), (7, 8, 7)])

def check_parents(mst, parents, n):
    for i in range(n):
        if mst[i] != parents[i]:
            return False
    return True

mst = prim( graph,0 )
sol1 = [None, 0, 1, 2, 3, 2, 5, 6, 2]
sol2 = [None, 0, 5, 2, 3, 6, 7, 0, 2]

assert check_parents(mst, sol1, 9) or check_parents(mst, sol2, 9)

## Kruskal's Algorithm

Next we implement the Kruskal's Minimum Spanning Tree algorithm. For doing that, we first need to implment the Union and Find algorithms.

In [24]:
# the Union Find algorithms

class UnionFind:
    def __init__(self, nodes):
        # nodes is a list of nodes. Associate an identifier for each node.
        self.id={}
        for i in nodes:
            self.id[i]=i#Setting id as same as the element
        print(id)
    def find(self, u):
        # Return the representative element (identifier/label) for u
        return self.id[u]
    def union(self, u, v):
        # Make the representative elements correponding to u and v the same
        x=self.find(u)
        y=self.find(v)
        self.id[v]=self.id[u]
        for i in self.id:
            if self.id[i]==x:
                self.id[i]=y
        

In [25]:
# Test

nodes = [2, 3, 4, 5, 6, 7]
uf = UnionFind(nodes)

for i in range(len(nodes)):
    for j in range(i+1, len(nodes)):
        assert uf.find(nodes[i]) != uf.find(nodes[j])
        

uf.union(4, 7)

assert uf.find(4) == uf.find(7)

uf.union(4, 2)

assert uf.find(4) == uf.find(2) and uf.find(4) == uf.find(7)

assert uf.find(4) != uf.find(5)
assert uf.find(4) != uf.find(6)
assert uf.find(4) != uf.find(3)

uf.union(3, 5)

assert uf.find(3) == uf.find(5)
assert uf.find(4) != uf.find(3)
assert uf.find(6) != uf.find(3)

uf.union(5, 2)

for k in [3, 4, 5, 7]:
    assert uf.find(k) == uf.find(2)
    
assert uf.find(2) != uf.find(6)



<built-in function id>


Using the UnionFind class, implement the Kruskal's algorithm in the following cell. Output the __set__ of edges included in the MST. Return the edges as pairs $(u, v)$.

In [26]:
# Add any functions/classes as necessary.
def kruskal(graph):
    mst_edges = set()
    edges=[]
    sort_e=[]
    # Your code here
    nodes=[]
    for j in graph.vertices():
        nodes.append(j)
    #print(nodes)
    f=UnionFind(nodes)
    v_len=len(nodes)
    #print(v_len)
    for i in graph.vertices():
        for k,v in graph.neighbors(i):
            if (i,k,v) not in edges and(k,i,v) not in edges:
                
                edges.append((i,k,v)) 
    #print(edges)
    print(edges.sort(key=lambda x:x[2]))#Sort inplace no return
    #print(edges)
    for u,v,w in edges:
        if f.find(u)!=f.find(v):#Checking if (u,v) forms a cycle
            mst_edges.add((u,v))# add (u,v) to mst_edges if it doesn't form a cycle
            f.union(u,v)   
    # Return the MST.
    print(mst_edges)
    return mst_edges
    

In [27]:
# Test for Kruskal's 

mst = kruskal( graph )
sol1 = set([(0, 1), (0, 7), (2, 8), (2, 5), (2, 3), (3, 4), (7,6), (5,6)])
sol2 = set([(0, 1), (1, 2), (2, 8), (2, 5), (2, 3), (3, 4), (6,7), (6,5)])

assert mst == sol1 or mst == sol2



<built-in function id>
None
{(0, 1), (5, 6), (2, 8), (7, 6), (0, 7), (2, 3), (2, 5), (3, 4)}
