# Minimum Spanning Tree

We will find the minimum spanning tree over all nodes of the street network of Piedmont, California.

We will implement Kruskal's algorithm, which finds a minimum spanning tree for each connected component of our graph (together, a minimum spanning forest).
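For reference, here is a standard fact about what the algorithm returns. If the graph has $n$ nodes and $c$ connected components (ignoring edge direction), every spanning forest $F$ has exactly $n - c$ edges, and Kruskal's algorithm returns one with the smallest total length, where $\ell(u, v)$ is the length of edge $(u, v)$ (the `length` attribute of each street edge) and $F$ ranges over spanning forests:

$$
|F^{*}| = n - c, \qquad F^{*} = \arg\min_{F} \sum_{(u, v) \in F} \ell(u, v)
$$

A connected graph ($c = 1$) gives the familiar $n - 1$ edges of a single spanning tree.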

Here are some things to keep in mind when running this algorithm on a graph constructed from a map:

- The graph OSMnx returns is directed: a two-way street is stored as two edges (A→B and B→A), a one-way street as a single edge. Kruskal's algorithm is defined on undirected graphs, and the union-find check below only asks whether two nodes are already connected, so edge direction is ignored and a one-way street still joins its two end nodes.

- If the graph has several connected components (ignoring direction), you will get several trees at the end, one per component: a minimum spanning *forest*.

- A correct run never produces a loop, because an edge is only added when its two end nodes are in different trees. If the plot seems to show loops, the usual suspects are the union step (it must link the two roots) and the code that decides which edges to highlight.
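To see why direction does not split the result, here is a minimal sketch on a made-up three-node network (the node names, lengths and the `weak_components` helper are hypothetical, not part of `problem.py`). It counts components the way a union-find check does: every directed edge merges its two endpoints, whichever way it points.

```python
# Hypothetical toy street network: node labels and lengths are made up.
# A two-way street appears as two directed edges, a one-way street as one.
directed_edges = [
    ("A", "B", 50.0),  # one-way street A -> B (there is no B -> A edge)
    ("B", "C", 30.0),
    ("C", "B", 30.0),  # B <-> C is a two-way street
]


def weak_components(nodes, edges):
    """Count connected components while ignoring edge direction."""
    parent = {node: node for node in nodes}  # every node starts as its own root

    def find(node):
        while parent[node] != node:
            node = parent[node]
        return node

    for u, v, _ in edges:
        root_u, root_v = find(u), find(v)
        if root_u != root_v:
            parent[root_u] = root_v  # link the roots, never arbitrary nodes
    return len({find(node) for node in nodes})


print(weak_components(["A", "B", "C"], directed_edges))  # 1
```

The one-way street A→B still puts A and B in the same component, so all three nodes end up in one tree.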

In [None]:
##############################################################################
# IGNORE if you are running on a local machine and have all the dependencies #
##############################################################################

# install dependencies so the notebook runs on Google Colab
# (takes 2-3 minutes)
!apt-get install libspatialindex-c4v5;
!pip3 install Rtree;
!pip3 install osmnx;

# you HAVE to upload the problem.py file (it provides PriorityQueue)
# so the directory should look like:
#|- sample_data/ 
#|- problem.py

In [None]:
import networkx as nx
import osmnx as ox
from matplotlib.collections import LineCollection
import matplotlib.pyplot as plt
from problem import *

In [None]:
place = 'Piedmont, California, USA'
G = ox.graph_from_place(place, network_type='drive_service')
G = ox.project_graph(G)

In [None]:
fig, ax = ox.plot_graph(G)

We collect the x, y coordinates of every node for the visualization at the end.

In [None]:
node_Xs = [float(x) for _, x in G.nodes(data='x')]
node_Ys = [float(y) for _, y in G.nodes(data='y')]

In [None]:
# min-priority queue of (source, destination, length) tuples, ordered by length
Edges = PriorityQueue('min', lambda edge: edge[2])

In [None]:
for source, destination, data in G.edges(keys=False, data=True):
    Edges.append((source, destination, data['length']))
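`PriorityQueue` comes from `problem.py`, which isn't shown here. Judging only from how it is used (`append`, `pop`, `len`, and a `heap` of `(key, item)` pairs), it behaves like a min-heap keyed on edge length. Below is a hypothetical stand-in built on Python's `heapq` (the `MinEdgeQueue` name is ours, not from `problem.py`) that supports the same calls made in this notebook:

```python
import heapq


class MinEdgeQueue:
    """Hypothetical stand-in for problem.py's PriorityQueue('min', lambda edge: edge[2]).

    Stores (length, edge) pairs in a binary heap, so pop() returns the shortest edge.
    Ties on length fall back to comparing the edge tuples, which works when node
    ids are mutually comparable (OSM node ids are integers).
    """

    def __init__(self):
        self.heap = []

    def append(self, edge):
        heapq.heappush(self.heap, (edge[2], edge))

    def pop(self):
        return heapq.heappop(self.heap)[1]

    def __len__(self):
        return len(self.heap)


queue = MinEdgeQueue()
for edge in [(1, 2, 40.0), (2, 3, 10.0), (1, 3, 25.0)]:
    queue.append(edge)
print([queue.pop()[2] for _ in range(len(queue))])  # [10.0, 25.0, 40.0]
```

Since every edge is inserted before any is popped, sorting the edge list once with `sorted(edges, key=lambda edge: edge[2])` gives the same processing order; the heap only pays off if edges arrive while the algorithm runs.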

We need to maintain a disjoint-set (union-find) structure over the nodes to keep track of which nodes are already connected, so that adding an edge never creates a loop.
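The cells below store each node's parent in a dict, with -1 marking a root. A common refinement, sketched here as a self-contained class (the `DisjointSet` name is ours), adds path compression and union by size so that `find` stays fast on large street graphs. The key rule is the same either way: a union always links the two *roots*.

```python
class DisjointSet:
    """Union-find over hashable nodes with path compression and union by size."""

    def __init__(self, nodes):
        self.parent = {node: node for node in nodes}  # every node starts as its own root
        self.size = {node: 1 for node in nodes}

    def find(self, node):
        """Return the root of node's set, compressing the path along the way."""
        root = node
        while self.parent[root] != root:
            root = self.parent[root]
        while self.parent[node] != root:
            next_node = self.parent[node]
            self.parent[node] = root  # point directly at the root
            node = next_node
        return root

    def union(self, a, b):
        """Merge the sets of a and b; return False if they were already connected."""
        root_a, root_b = self.find(a), self.find(b)
        if root_a == root_b:
            return False  # adding edge (a, b) would close a loop
        if self.size[root_a] < self.size[root_b]:
            root_a, root_b = root_b, root_a
        self.parent[root_b] = root_a  # attach the smaller tree under the larger root
        self.size[root_a] += self.size[root_b]
        return True


ds = DisjointSet(["a", "b", "c", "d"])
print(ds.union("a", "b"))  # True
print(ds.union("c", "b"))  # True
print(ds.union("a", "c"))  # False: a, b, c are already connected
```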

In [None]:
# disjoint set over the nodes: every node starts as its own root, marked by -1
nodes_set = {node: -1 for node in G.nodes}

In [None]:
# returns the root (representative) of the node's set; a node whose entry is -1 is a root
def find_parent(node):
    r = node
    while nodes_set[r] >= 0:
        r = nodes_set[r]
    return r

# The Algorithm

In [None]:
# super easy if you spend 5 mins thinking about it
# the premise of the algorithm is to add edges greedily, shortest first,
# skipping any edge whose endpoints are already connected (it would close a loop)
# The algorithm builds a forest of trees over the graph, which merge
# into a single tree per connected component by the end of the algorithm
Forest = set()
Size = 0
j = 0
m = len(Edges)
n = len(G.nodes)
while Size < n - 1 and j < m:  # a spanning tree on n nodes has n - 1 edges
    j += 1
    # the shortest remaining edge
    edge = Edges.pop()
    node1 = edge[0]
    node2 = edge[1]
    parent1 = find_parent(node1)
    parent2 = find_parent(node2)
    if parent1 != parent2:
        Forest.add(edge)
        Size += 1
        nodes_set[parent1] = parent2  # union: link the two roots, not node1 itself
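The same loop can be checked on a small graph before trusting it on the whole city. This self-contained sketch follows the same steps (take edges shortest first, skip an edge whose endpoints share a root, otherwise link the two roots), but it uses `sorted()` instead of `problem.py`'s `PriorityQueue` and path halving in `find`; the `kruskal` function name and the toy graph are hypothetical. As in an OSMnx graph, each two-way street is listed once per direction.

```python
def kruskal(nodes, edges):
    """Minimum spanning forest via Kruskal's algorithm.

    edges: iterable of (source, destination, length); direction is ignored.
    Every endpoint must appear in nodes. Returns the chosen edges,
    which form one tree per connected component.
    """
    parent = {node: node for node in nodes}

    def find(node):
        while parent[node] != node:
            parent[node] = parent[parent[node]]  # path halving
            node = parent[node]
        return node

    forest = []
    for u, v, length in sorted(edges, key=lambda edge: edge[2]):
        root_u, root_v = find(u), find(v)
        if root_u != root_v:          # endpoints in different trees: no loop
            parent[root_u] = root_v   # union the two roots
            forest.append((u, v, length))
            if len(forest) == len(parent) - 1:
                break                 # a single spanning tree is complete
    return forest


# Hypothetical toy graph: two components, {1, 2, 3} and {4, 5}.
nodes = [1, 2, 3, 4, 5]
edges = [(1, 2, 4.0), (2, 1, 4.0), (2, 3, 1.0), (3, 2, 1.0), (1, 3, 3.0), (4, 5, 2.0)]
print(kruskal(nodes, edges))  # [(2, 3, 1.0), (4, 5, 2.0), (1, 3, 3.0)]
```

The result has 3 edges, which is n - c = 5 - 2, and the reverse copies of the two-way streets are rejected because their endpoints already share a root.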

# Visualization

In [None]:
# set of (source, destination) pairs that made it into the spanning forest
tree_edges = {(edge[0], edge[1]) for edge in Forest}

In [None]:
fig, ax = plt.subplots(figsize=(15, 11))
ax.set_facecolor('w')
lines = []
colorS = []
widthS = []
for u, v, data in G.edges(keys=False, data=True):
    if 'geometry' in data:
        # curved street: draw its full shape
        xs, ys = data['geometry'].xy
        lines.append(list(zip(xs, ys)))
    else:
        # straight segment between the two end nodes
        x1 = G.nodes[u]['x']
        y1 = G.nodes[u]['y']
        x2 = G.nodes[v]['x']
        y2 = G.nodes[v]['y']
        lines.append([(x1, y1), (x2, y2)])
    # highlight only edges whose exact pair (in either direction) is in the forest
    if (u, v) in tree_edges or (v, u) in tree_edges:
        colorS.append('b')
        widthS.append(2.5)
    else:
        colorS.append('r')
        widthS.append(0.4)
lc = LineCollection(lines, colors=colorS, linewidths=widthS, alpha=0.3)
ax.add_collection(lc)
scat = ax.scatter(node_Xs, node_Ys, c='k', s=5)

Each blue edge was added only when its two end nodes were in different trees, so the result should contain no loops, with one tree per connected component; street direction plays no part in this.
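Rather than judging from the plot alone, you can test the edges directly. This is a minimal sketch (the `is_forest` helper is hypothetical, not from `problem.py`) that replays the edges through a fresh union-find and reports a loop as soon as an edge joins two nodes that are already connected:

```python
def is_forest(nodes, edges):
    """Return True if the edges, treated as undirected, contain no loop."""
    parent = {node: node for node in nodes}

    def find(node):
        while parent[node] != node:
            node = parent[node]
        return node

    for u, v, *_ in edges:  # accepts (u, v) or (u, v, length) tuples
        root_u, root_v = find(u), find(v)
        if root_u == root_v:
            return False  # u and v were already connected: this edge closes a loop
        parent[root_u] = root_v
    return True


print(is_forest([1, 2, 3], [(1, 2), (2, 3)]))          # True
print(is_forest([1, 2, 3], [(1, 2), (2, 3), (3, 1)]))  # False
```

In the notebook, a correct forest gives `is_forest(G.nodes, Forest) == True`, and `len(Forest)` should equal `len(G.nodes) - nx.number_weakly_connected_components(G)`.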