# In [1]:
import numpy as np
import math
from operator import itemgetter

def collect(l, index):
    """Return a new list with element ``index`` taken from every item of ``l``."""
    return [item[index] for item in l]

class Node: # This class could potentially form the basis for GNNs
    # Might also be useful to have this structure for future graph related problems you encounter
    """A node in a weighted graph with explicit parent/child links.

    Links are stored on both endpoints: ``a.add_child(b)`` appends ``b`` to
    ``a.children`` and ``a`` to ``b.parents``. Edge weights are stored as
    ``(node, distance)`` tuples in ``children_distance`` / ``parent_distance``.
    Every list attribute stays ``None`` until its first element is added.
    """

    def __init__(self, parents=None, children=None, children_distance=None,
                 parent_distance=None, value=None, tag=None, visited=False):
        self.parents = parents  # List of parent Node objects, or None
        self.children = children  # List of child Node objects, or None
        self.children_distance = children_distance  # List of (child, distance) tuples, or None
        self.parent_distance = parent_distance  # List of (parent, distance) tuples, or None
        self.tag = tag  # Name assigned to node
        self.value = value  # Value assigned to node
        self.visited = visited  # Visited status -> Used for Dijkstra's algorithm

    @staticmethod
    def _append(obj, attr, item):
        """Append ``item`` to the list ``obj.<attr>``, creating the list if it is None."""
        items = getattr(obj, attr)
        if items is None:
            items = []
            setattr(obj, attr, items)
        items.append(item)

    def add_child(self, node):
        """Link ``node`` as a child of this node (and this node as its parent)."""
        self._append(self, 'children', node)
        self._append(node, 'parents', self)

    def add_parent(self, node):
        """Link ``node`` as a parent of this node (and this node as its child)."""
        self._append(self, 'parents', node)
        self._append(node, 'children', self)

    def add_distance(self, node, value):
        """Record the edge weight ``value`` between this node and ``node``.

        The weight is stored on both endpoints, once for each direction in
        which the two nodes are already linked. Nothing is recorded if
        ``node`` is neither a parent nor a child of this node.
        """
        if self.parents is not None and node in self.parents:
            self._append(self, 'parent_distance', (node, value))
            self._append(node, 'children_distance', (self, value))
        if self.children is not None and node in self.children:
            self._append(self, 'children_distance', (node, value))
            self._append(node, 'parent_distance', (self, value))

    def get_distance(self, node):
        """Return the recorded edge weight between this node and ``node``.

        Parent-side records are searched before child-side records, and the
        first match wins. If the weight was only recorded in one direction
        (e.g. the nodes are linked both ways), the other direction is still
        found instead of crashing.

        Returns:
            The recorded distance, or None when no weight is recorded for
            ``node`` (including when ``node`` is not linked to this node).
        """
        for records in (self.parent_distance, self.children_distance):
            for neighbour, distance in records or ():
                if neighbour is node:
                    return distance
        return None
# Build a small sample graph: six tagged nodes joined by weighted parent -> child edges.
node1 = Node(None, None, None, None, 10, 'node1', False)
node2 = Node(None, None, None, None, 2, 'node2', False)
node3 = Node(None, None, None, None, 5, 'node3', False)
node4 = Node(None, None, None, None, 4, 'node4', False)
node5 = Node(None, None, None, None, 7, 'node5', False)
node6 = Node(None, None, None, None, 14, 'node6', False)

# (parent, child, distance) triples. They are applied in exactly this order,
# because each add_child call appends, so the order fixes every node's
# adjacency-list order.
for _parent, _child, _distance in (
    (node1, node2, 5),
    (node2, node3, 7),
    (node1, node4, 4),
    (node3, node4, 3),
    (node3, node6, 1),
    (node3, node1, 8),
    (node3, node5, 10),
    (node5, node6, 2),
    (node5, node2, 6),
):
    _parent.add_child(_child)
    _parent.add_distance(_child, _distance)
del _parent, _child, _distance  # don't leak loop temporaries into the module

graph = [node1, node2, node3, node4, node5, node6]


def shortest_distance_dijkstra(node1,node2,graph):
    shortest_path = []
    removed_nodes = []
    shortest_distance = 0
    shortest_path.append(node1.tag)
    for node in graph:
        node.value = float('Inf')
        node.visited = False
    current_node = node1
    current_node.value = 0
    while not node2.visited:
        print(current_node.tag)
        children = current_node.children
        parents = current_node.parents
        if children is not None:
            children.extend(parents)
            neighbours = children
        else:
            neighbours = parents
        print(current_node.tag,[neighbour.tag for neighbour in neighbours])
        for neighbour in neighbours:
#             print(current_node.tag,neighbour.tag,shortest_path)
            if not neighbour.visited:
                distancex = neighbour.get_distance(current_node)
                if current_node.value + distancex < neighbour.value: # Analogous to value function iteration in RL
                    neighbour.value = current_node.value + distancex
                    if current_node.tag not in shortest_path and current_node not in removed_nodes:
                        shortest_path.append(current_node.tag)
                else:
                    if current_node.tag in shortest_path and current_node.tag not in removed_nodes:
                        shortest_path.remove(current_node.tag)
                        removed_nodes.append(current_node.tag)
        current_node.visited = True
        previous_node = current_node
        min_node_value = []
        possible_nodes = []
        for node in graph:
            if not node.visited:
                min_node_value.append(node.value)
                possible_nodes.append(node)
        minval_index = np.argmin(min_node_value)
#         for node in graph:
#             if node.visited:
#                 if node.value != np.argmin(min_node_value) and node.tag in shortest_path:
#                     shortest_path.remove(node.tag)
        current_node = possible_nodes[minval_index]
        distancex = previous_node.get_distance(current_node)
    return node2.value,shortest_path

# Query the sample graph: print the (distance, path) pair from node1 to node6.
print(shortest_distance_dijkstra(node1, node6, graph=graph))

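# Output:
# node1
# node1 ['node2', 'node4', 'node3']
# node4
# node4 ['node1', 'node3']
# node2
# node2 ['node3', 'node1', 'node5']
# node3
# node3 ['node4', 'node6', 'node1', 'node5', 'node2']
# node6
# node6 ['node3', 'node5']
# (8, ['node1', 'node4', 'node2', 'node6'])
# NOTE: the recorded path above is not a valid route. The distance-8 route in
# this graph is node1 -> node4 -> node3 -> node6 (4 + 3 + 1).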
