In [230]:
class WeightedGraph:

    weightSet = set()

    def __init__(self, n):
        self.weightSet.clear() #Set tracking addition
        self.adj = {}
        for i in range(n):
            self.adj[i] = []

    def are_connected(self, node1, node2):
        for edge in self.adj[node1]:
            if edge[0] == node2:
                return True
        return False

    def adjacent_nodes(self, node):
        return self.adj[node]

    def add_node(self):
        self.adj[len(self.adj)] = []

    def add_edge(self, node1, node2, weight):
        if node1 not in self.adj[node2]:
            self.adj[node1].append((node2, weight))
            self.adj[node2].append((node1, weight))

    def w(self, node1, node2):
        for edge_info in self.adj[node1]:
            if node2 == edge_info[0]:
                return edge_info[1]

    def number_of_nodes(self):
        return len(self.adj)

    def showWeights(self):
        for x in self.weightSet:
            print(x)

    def show(self):
        for i in range(len(self.adj)):
            print("node:", i)
            print(self.adj[i])

In [231]:
import math

class MinHeap:
    length = 0
    data = []

    def __init__(self, L):
        self.data = L
        self.length = len(L)
        self.map = {}
        for i in range(len(L)):
            self.map[L[i].value] = i
        self.build_heap()

    def build_heap(self):
        for i in range(self.length // 2 - 1, -1, -1):
            self.sink(i)

    def sink(self, i):
        smallest_known = i
        if self.left(i) < self.length and self.data[self.left(i)].key < self.data[i].key:
            smallest_known = self.left(i)
        if self.right(i) < self.length and self.data[self.right(i)].key < self.data[smallest_known].key:
            smallest_known = self.right(i)
        if smallest_known != i:
            self.data[i], self.data[smallest_known] = self.data[smallest_known], self.data[i]
            self.map[self.data[i].value] = i
            self.map[self.data[smallest_known].value] = smallest_known
            self.sink(smallest_known)

    def insert(self, element):
        if len(self.data) == self.length:
            self.data.append(element)
        else:
            self.data[self.length] = element
        self.map[element.value] = self.length
        self.length += 1
        self.swim(self.length - 1)

    def insert_elements(self, L):
        for element in L:
            self.insert(element)

    def swim(self, i):
        while i > 0 and self.data[i].key < self.data[self.parent(i)].key:
            self.data[i], self.data[self.parent(i)] = self.data[self.parent(i)], self.data[i]
            self.map[self.data[i].value] = i
            self.map[self.data[self.parent(i)].value] = self.parent(i)
            i = self.parent(i)

    def get_min(self):
        if len(self.data) > 0:
            return self.data[0]
  
    def extract_min(self):
        self.data[0], self.data[self.length - 1] = self.data[self.length - 1], self.data[0]
        self.map[self.data[self.length - 1].value] = self.length - 1
        self.map[self.data[0].value] = 0
        min_element = self.data[self.length - 1]
        self.length -= 1
        self.map.pop(min_element.value)
        self.sink(0)
        return min_element

    def decrease_key(self, value, new_key):
        if new_key >= self.data[self.map[value]].key:
            return
        index = self.map[value]
        self.data[index].key = new_key
        self.swim(index)

    def get_element_from_value(self, value):
        return self.data[self.map[value]]

    def get_key_from_value(self, value):
        return self.data[self.map[value]].key

    def is_empty(self):
        return self.length == 0
    
    def left(self, i):
        return 2 * (i + 1) - 1

    def right(self, i):
        return 2 * (i + 1)

    def parent(self, i):
        return (i + 1) // 2 - 1

    def __str__(self):
        height = math.ceil(math.log(self.length + 1, 2))
        whitespace = 2 ** height
        s = ""
        for i in range(height):
            for j in range(2 ** i - 1, min(2 ** (i + 1) - 1, self.length)):
                s += " " * whitespace
                s += str(self.data[j]) + " "
            s += "\n"
            whitespace = whitespace // 2
        return s

class Element:

    def __init__(self, value, key):
        self.value = value
        self.key = key

    def __str__(self):
        return "(" + str(self.value) + "," + str(self.key) + ")"

In [232]:
gr1 = WeightedGraph(8)
gr1.add_edge(4,5,35)
gr1.add_edge(4,7,37)
gr1.add_edge(5,7,28)
gr1.add_edge(0,7,16)
gr1.add_edge(1,5,32)
gr1.add_edge(0,4,38)
gr1.add_edge(2,3,17)
gr1.add_edge(1,7,19)
gr1.add_edge(0,2,26)
gr1.add_edge(1,2,36)
gr1.add_edge(1,3,29)
gr1.add_edge(2,7,34)
gr1.add_edge(6,2,40)
gr1.add_edge(3,6,52)
gr1.add_edge(6,0,58)
gr1.add_edge(6,4,93)

gr1.show()

node: 0
[(7, 16), (4, 38), (2, 26), (6, 58)]
node: 1
[(5, 32), (7, 19), (2, 36), (3, 29)]
node: 2
[(3, 17), (0, 26), (1, 36), (7, 34), (6, 40)]
node: 3
[(2, 17), (1, 29), (6, 52)]
node: 4
[(5, 35), (7, 37), (0, 38), (6, 93)]
node: 5
[(4, 35), (7, 28), (1, 32)]
node: 6
[(2, 40), (3, 52), (0, 58), (4, 93)]
node: 7
[(4, 37), (5, 28), (0, 16), (1, 19), (2, 34)]


In [233]:
def prim1(graph):

    mst = WeightedGraph(graph.number_of_nodes()) #create new graph
    pqueueEdge = list()
    pqueueNode = list()
    pqueue = list()
    marked = set()
    node = 0

    visit(graph, pqueueEdge, pqueueNode, pqueue, marked, mst, node)

    #add current node parameter
    while(len(pqueue)!= 0):
        pqueue.sort(key=lambda tup: tup[2]) #sorts based on weights
        minEdge = pqueue.pop(0)

        if (minEdge[0] in marked and minEdge[1] in marked):
            continue #this vertix is marked or the other one is
        
        mst.add_edge(minEdge[0],minEdge[1],minEdge[2]) #add connection to completed mst
        if (minEdge[0] not in marked):
            visit(graph, pqueueEdge, pqueueNode, pqueue, marked, mst, minEdge[0])
        if (minEdge[1] not in marked):
            visit(graph, pqueueEdge, pqueueNode, pqueue, marked, mst, minEdge[1])

    return mst


def visit(graph, pqueueEdge, pqueueNode, pqueue, marked, mst, node):
    marked.add(node)
    for x in graph.adjacent_nodes(node):
        if x[0] not in marked:
            pqueue.append((node,x[0],x[1]))
            # pqueueEdge.append(x)
            # pqueueNode.append(node)
            # print(x)


In [234]:
import sys

def primsAdv(graph):

    mst = WeightedGraph(graph.number_of_nodes())

    L = []
    for v in range(graph.number_of_nodes()):
        L.append(Element(v,1e7))

    heap = MinHeap(L)
    heap.decrease_key(0,0)
    # print(heap.get_element_from_value(88))
    # print(heap.get_key_from_value(1))

    while heap.is_empty() == False:
 
        # Extract the vertex with minimum distance value
        newNode = heap.extract_min()
        
        currentVert = newNode.value

        # print("next")

        for nodes in graph.adjacent_nodes(currentVert):
            #update the distances for the neighbor vertices
            vert = nodes[0]

            print(graph.w(currentVert,vert))

            # print(v)
            # print("done")

            # heap.decrease_key(vert, graph.w(currentVert,vert))

            print(heap)

            # if v in heap.data and pCrawl[1] < key[v]:
                    
                # heap.decrease_key(v, key[v])
            #         #For every adjacent vertex v of u, check if v is in Min Heap (not yet included in MST). If v is in Min Heap and its key value is more than weight of u-v, then update the key value of v as weight of u-v.
                
        # mst.add_edge()

In [235]:
def prim2(graph):
        nodes = graph.number_of_nodes()
        mst = WeightedGraph(nodes)
        inf = 1e7
         
        key = [0]
        L = [Element(0,0)] 
        parent = [-1]

        for n in range(1,nodes):
            parent.append(-1)
            key.append(inf)
            L.append(Element(n,inf))
        
        minHeap = MinHeap(L)
        minHeap.length = nodes
 
        while not minHeap.is_empty():
            current = minHeap.extract_min()
            
            for adj in graph.adjacent_nodes(current.value):
                v = adj[0]
                
                try:
                    minHeap.get_element_from_value(v)
                except KeyError:
                    continue
                if adj[1] < key[v]:
                    key[v] = adj[1]
                    parent[v] = current.value
                    minHeap.decrease_key(v, key[v])
        
        for n in range(1,nodes):
            mst.add_edge(parent[n],n,key[n])
        
        return mst

In [236]:
mst = prim2(gr1)
# mst.show()

In [237]:
mst = prim1(gr1)
# mst.show()

In [238]:
import random

def create_random_graph(n, c): #nodes, edges
    g = WeightedGraph(n)
    edges = []
    edge = (0,0)

    weights = []
    for x in range(c+1):
        weights.append(x)
    random.shuffle(weights)

    # print(weights)

    if c > n*(n-1)/2: #cap max verticies
        c = n*(n-1)/2
    while c > 0:
        while edge[0] == edge[1] or edge in edges or edge[::-1] in edges:
            edge = (random.randint(0,n-1), random.randint(0,n-1))
            # print(edge)
            # print(edge[::-1])
        edges.append(edge)
        # print(edge[0],edge[1])
        g.add_edge(edge[0],edge[1],weights[c])
        c -= 1
    return g

def logtocsv(num,results,filename):
    with open(filename, 'a') as f: #'w' for write, 'a' for append
        f.write(str(num))
        f.write(",")
        f.write(str(results))
        f.write('\n')

def timer(index):
    if __name__ == '__main__':
        import timeit
        print("timing for size of {}".format(index))
        return timeit.repeat("prim2(graph)", setup="from __main__ import prim2, graph", repeat=1, number=1) 

In [271]:
# import copy 

# i = 10
# while (i < 1001):
#     for x in range(1):
#         testGraph = create_random_graph(i,i*2)

#         graph = copy.copy(testGraph)
#         logtocsv(i, timer(i)[0], "prim2.csv") #To switch function to test, changer here and in the method.

#     i *= 10

# mst = prim1(gr2)
# mst.show()
# gr2 = create_random_graph(100,200)
# gr2.show()
gr3 = create_random_graph(100,150)
gr3.show()
# mst2 = prim2(gr2)
# mst2.show()
mst3 = prim2(gr3)
mst3.show()

node: 0
[(40, 88), (15, 64), (1, 94)]
node: 1
[(64, 120), (63, 149), (7, 69), (36, 10), (0, 94)]
node: 2
[(3, 6), (97, 106)]
node: 3
[(53, 2), (27, 45), (33, 131), (2, 6), (17, 119), (82, 82)]
node: 4
[(33, 62), (15, 21), (37, 87), (47, 95), (71, 47), (81, 56)]
node: 5
[(90, 128)]
node: 6
[(97, 111), (18, 0)]
node: 7
[(1, 69)]
node: 8
[(81, 79)]
node: 9
[(78, 99), (69, 28), (50, 66), (45, 132), (20, 43)]
node: 10
[(91, 84), (11, 19), (54, 31)]
node: 11
[(10, 19), (32, 146)]
node: 12
[(53, 107), (92, 139), (80, 32), (75, 90)]
node: 13
[(47, 55), (55, 16)]
node: 14
[(64, 141), (43, 118), (78, 112)]
node: 15
[(0, 64), (4, 21), (59, 33), (20, 85)]
node: 16
[(95, 18), (61, 23), (45, 3), (47, 122)]
node: 17
[(23, 30), (3, 119), (71, 115), (98, 5)]
node: 18
[(26, 125), (79, 14), (75, 75), (6, 0)]
node: 19
[(38, 9)]
node: 20
[(39, 136), (40, 15), (93, 116), (98, 133), (65, 74), (15, 85), (9, 43), (99, 126)]
node: 21
[(89, 140), (70, 134)]
node: 22
[(89, 130), (77, 48)]
node: 23
[(81, 78), (99,

KeyError: -1