In [41]:

class UnionFind:
    '''\
Disjoint-set (union-find) structure over arbitrary hashable objects.

Every object is given a dense integer id in insertion order (0, 1, 2, ...);
the bookkeeping dictionaries are keyed by those ids:
  objects_to_num / num_to_objects -- object <-> id
  parent_pointers                 -- id -> parent id (a root is its own parent)
  num_weights                     -- root id -> number of objects in that set
union() attaches the smaller set's root under the larger one (union by size)
and find() re-points visited nodes directly at the root (path compression).'''

    def __init__(self):
        '''\
Create an empty union find data structure.'''
        self.num_weights = {}      # root id -> set size; non-roots have no entry
        self.parent_pointers = {}  # id -> parent id
        self.num_to_objects = {}   # id -> object
        self.objects_to_num = {}   # object -> id
        # repr() is provided by the class-level __repr__ method: special methods
        # are looked up on the type, so assigning self.__repr__ would be ignored.

    def __repr__(self):
        '''\
Short summary such as 'UnionFind(4 objects in 3 sets)'.'''
        return '%s(%d objects in %d sets)' % (
            type(self).__name__, len(self.objects_to_num), len(self.num_weights))

    def num_sets(self):
        '''\
Return the number of disjoint sets currently stored.
Only roots keep an entry in num_weights (union() deletes the absorbed root's
entry), so its length is exactly the set count. Counting the distinct values
of parent_pointers is NOT equivalent: a node whose path has not been
compressed may still point at a former root.'''
        return len(self.num_weights)

    def insert_objects(self, objects):
        '''\
Insert a sequence of objects into the structure.  All must be Python hashable.
Objects that are already known keep their current set.'''
        for obj in objects:
            self.find(obj)

    def find(self, object):
        '''\
Find the root of the set that an object is in.
If the object was not known, will make it known, and it becomes its own set.
Object must be Python hashable.
Every node visited on the way to the root is re-pointed directly at the root.'''
        if object not in self.objects_to_num:
            obj_num = len(self.objects_to_num)  # next unused id
            self.num_weights[obj_num] = 1
            self.objects_to_num[object] = obj_num
            self.num_to_objects[obj_num] = object
            self.parent_pointers[obj_num] = obj_num
            return object
        path = [self.objects_to_num[object]]
        parent = self.parent_pointers[path[-1]]
        # A root is the only node that is its own parent.
        while parent != path[-1]:
            path.append(parent)
            parent = self.parent_pointers[parent]
        for num in path:
            self.parent_pointers[num] = parent
        return self.num_to_objects[parent]

    def union(self, object1, object2):
        '''\
Combine the sets that contain the two objects given.
Both objects must be Python hashable.
If either or both objects are unknown, will make them known, and combine them.
The root of the larger set becomes the root of the merged set; on a tie,
object1's root is kept.'''
        root1 = self.find(object1)
        root2 = self.find(object2)
        if root1 == root2:
            return  # already in the same set
        num1 = self.objects_to_num[root1]
        num2 = self.objects_to_num[root2]
        weight1 = self.num_weights[num1]
        weight2 = self.num_weights[num2]
        if weight1 < weight2:
            # Keep the larger set's root as the surviving root.
            num1, num2 = num2, num1
        self.num_weights[num1] = weight1 + weight2
        del self.num_weights[num2]  # only roots carry a weight
        self.parent_pointers[num2] = num1

In [42]:
# Build a four-element structure and show the parent map before and after one union.
uf = UnionFind()
uf.insert_objects(list('ABCD'))

print("Before union:")
print(uf.parent_pointers)  # every id is still its own parent: four singleton sets

uf.union('A', 'B')  # merge the set holding 'A' with the set holding 'B'

print("After union:")
print(uf.parent_pointers)  # 'B' (id 1) now points at 'A' (id 0), their shared root


Before union:
{0: 0, 1: 1, 2: 2, 3: 3}
After union:
{0: 0, 1: 0, 2: 2, 3: 3}


In [43]:
uf = UnionFind()
uf.insert_objects(['X', 'Y', 'Z'])


def count_sets(union_find):
    '''Number of disjoint sets, counted as the number of distinct roots.

    Counting distinct values of union_find.parent_pointers is not reliable:
    a node whose path was not compressed can still point at a former root,
    which would then be counted as an extra set.'''
    return len({union_find.find(obj) for obj in union_find.objects_to_num})


print("number of connnected components in the structure after BEFORE a union operation", count_sets(uf))

uf.union('X', 'Y')

print("number of connnected components in the structure AFTER doing a union operation", count_sets(uf))

print("the representive of Y is", uf.find('Y'))  # root of the set containing 'Y', which is 'X'

print("the representive of X is", uf.find('X'))  # root of the set containing 'X', which is also 'X'



# Note that since X and Y are in the same set, find returns the same representative element for both.

number of connnected components in the structure after BEFORE a union operation 3
number of connnected components in the structure AFTER doing a union operation 2
the representive of Y is X
the representive of X is X


In [44]:
def rankedges(G):
    '''\
    Return the edges of G paired with their weights, lightest first.

    Each item is (edge, weight), where edge is the tuple yielded by G.edges()
    and weight is G[u][v]['weight']; a KeyError propagates if an edge has no
    'weight' attribute. Edges of equal weight keep their G.edges() order,
    because sorted() is stable.'''
    weighted_edges = [(edge, G[edge[0]][edge[1]]['weight']) for edge in G.edges()]
    return sorted(weighted_edges, key=lambda item: item[1])

def spanningtree(G):
    '''\
    Compute a minimum spanning tree (forest, if G is disconnected) of G
    using Kruskal's algorithm.

    Edges are scanned in increasing weight order (via rankedges); an edge is
    kept only when its endpoints are still in different union-find sets, i.e.
    when adding it cannot close a cycle.

    Returns (tree, tree_ranked_edges):
      tree              -- an nx.Graph holding every node of G plus the chosen
                           edges, each with its 'weight' attribute.
      tree_ranked_edges -- the chosen edges as ((u, v), weight) pairs, in the
                           increasing-weight order they were accepted.
    An empty G yields an empty graph and an empty list.'''
    ranked_edges = rankedges(G)

    # Created once, up front, so it exists even when G has no nodes.
    tree = nx.Graph()
    tree_ranked_edges = []

    uf = UnionFind()
    uf.insert_objects(G.nodes())  # every node starts as its own singleton set

    for edge in ranked_edges:
        (u, v), weight = edge
        if uf.find(u) == uf.find(v):
            continue  # u and v are already connected: this edge would form a cycle
        uf.union(u, v)
        tree_ranked_edges.append(edge)
        tree.add_edge(u, v, weight=weight)

    # Add all of G's nodes last so isolated vertices are included without
    # changing the node insertion order (and thus the edge order) of the tree.
    tree.add_nodes_from(G.nodes())
    return tree, tree_ranked_edges

In [45]:
# Small undirected example graph. nx.Graph is simple and undirected, so (0, 5)
# and (5, 0) name the same edge: adding it again only overwrites its weight.
G = nx.Graph()

G.add_edge(0,1, weight=1.2)
G.add_edge(1,2, weight=0.65)
G.add_edge(2,3, weight=0.23)
G.add_edge(3,0, weight=0.15)
# NOTE(review): 0.2 is the weight this edge actually carries in the recorded
# spanning-tree output; an earlier weight of 0.12 for the same edge was
# silently overwritten. Confirm which value was intended.
G.add_edge(0,5, weight=0.2)
G.add_edge(0,10, weight=2.3)



In [51]:
# spanningtree returns a 2-tuple (tree, tree_ranked_edges); unpack both values.
tree, tree_ranked_edges = spanningtree(G)

In [60]:
print("the edges of the tree are ",tree.edges)


the edges of the tree are  [(0, 3), (0, 5), (0, 10), (3, 2), (2, 1)]
