In [None]:
import os
from queue import PriorityQueue

import matplotlib.pyplot as plt
import networkx as nx
import pandas
from matplotlib.backends.backend_pdf import PdfPages

class GraphPlotter:
  """Base class for rendering a networkx graph into a multi-page PDF.

  ``self.node_color`` holds one colour per node, indexed by the node's
  position in ``self.graph.nodes`` (the order ``nx.draw`` consumes it in).
  Subclasses are expected to implement ``setup(params)``, which
  ``print_graph_pdf`` calls to build the graph, and usually a ``callback``
  that saves frames to ``self.output``.
  """

  # Node fill colours. Defined on the class so they exist even on instances
  # that bypassed __init__; access through ``self.X`` works as before.
  HIGHLIGHT_COLOR = '#e31a1c'
  VISITED_COLOR = '#5fe382'
  NEUTRAL_COLOR = '#1f78b4'

  def __init__(self, filename='graph_print.pdf'):
    """Open a PdfPages writer at ``filename`` and start with an empty graph.

    The caller is responsible for closing ``self.output`` once all pages
    have been saved.
    """
    self.output = PdfPages(filename)
    self.graph = nx.Graph()
    self.node_color = []

  def color_visited(self, node):
    """Paint the entry at position ``node`` of ``node_color`` with VISITED_COLOR."""
    self.node_color[node] = self.VISITED_COLOR

  def node_highlight(self, node):
    """Paint the entry at position ``node`` of ``node_color`` with HIGHLIGHT_COLOR."""
    self.node_color[node] = self.HIGHLIGHT_COLOR

  def draw(self, current_node=None):
    """Draw ``self.graph`` on a new figure and return that Figure.

    ``current_node`` is unused here; it keeps the signature shared with
    subclasses. The caller owns the returned figure and should close it.
    """
    fig = plt.figure(figsize=(12, 8))
    # nx.draw renders onto the current figure and returns None, so its
    # result must not replace ``fig``.
    nx.draw(self.graph, node_color=self.node_color, with_labels=True, font_weight='bold')
    return fig

def print_graph_pdf(graph_plotter, algorithm, params=None, filename='dijkstra.pdf'):
  """Set up ``graph_plotter``, run ``algorithm`` on it, then close its PDF.

  Args:
    graph_plotter: object providing ``setup(params)`` and an ``output`` with
      a ``close()`` method (e.g. a GraphPlotter subclass).
    algorithm: callable invoked as ``algorithm(graph_plotter, params)``.
    params: sequence forwarded to both ``setup`` and ``algorithm``;
      defaults to a fresh empty list.
    filename: unused; the PDF path is chosen when the plotter is built.
      Kept for backward compatibility.

  The output is closed even if setup or the algorithm raises, so the PDF
  file handle is not leaked.
  """
  if params is None:
    params = []
  try:
    graph_plotter.setup(params)
    algorithm(graph_plotter, params)
  finally:
    graph_plotter.output.close()


In [None]:
class DijkstraPlotter(GraphPlotter):
  """Renders each step of ``dijkstra`` over the graph loaded from dataset.csv.

  Every ``callback`` call appends one PDF page that shows:
    - the current node highlighted;
    - already-settled nodes in the visited colour;
    - each node's 'distance' and 'path' attributes.
  """

  def setup(self, params):
    """Load the edge list from ``dataset.csv`` into an undirected graph.

    The CSV is read with ``source``, ``target`` and ``weight`` columns. The
    edges are first loaded as a DiGraph and then copied into an undirected
    ``nx.Graph``. ``params`` is accepted for ``print_graph_pdf``'s calling
    convention and is not used.
    """
    df = pandas.read_csv('dataset.csv', sep=',', decimal='.')
    input_graph = nx.from_pandas_edgelist(df, source='source', target='target', edge_attr='weight', create_using=nx.DiGraph())
    self.graph = nx.Graph()
    # Insert nodes in sorted order; node_color is built in this same order.
    self.graph.add_nodes_from(sorted(input_graph.nodes(data=True)))
    self.graph.add_edges_from(input_graph.edges(data=True))
    self.node_color = [self.NEUTRAL_COLOR] * len(self.graph.nodes)

  def draw(self, current_node):
    """Render one frame for ``current_node`` and return the matplotlib Figure.

    Expects every node to carry 'distance' and 'path' attributes (``dijkstra``
    sets them before invoking the callback).
    """
    fig = plt.figure(figsize=(12, 8))

    neighbors = list(self.graph.neighbors(current_node))
    # NOTE(review): header font size scales with the node count; kept as-is.
    plt.text(0.05, 0.95, f'Current node: {current_node}\nNeighbors: {neighbors}', transform=fig.transFigure, size=len(self.graph.nodes))

    # nx.draw returns None, so keep the Figure created above as the result.
    nx.draw(self.graph, node_color=self.node_color, with_labels=True, font_weight='bold')

    data_text = ''.join(
      f'{node}: {data["distance"]} Km {data["path"]} \n'
      for node, data in self.graph.nodes(data=True)
    )
    plt.figtext(0.01, 0.01, data_text, fontsize=12, bbox={"facecolor": "orange", "alpha": 0.5, "pad": 5})

    return fig

  def callback(self, current_node):
    """Save a page with ``current_node`` highlighted, then mark it visited."""
    # node_color is indexed by position in self.graph.nodes, not by node id,
    # so look the position up instead of assuming ids are 1..N.
    position = self._node_position(current_node)
    self.node_highlight(position)
    fig = self.draw(current_node)
    self.output.savefig(fig, bbox_inches='tight')
    # Release the frame so one figure per visited node does not stay open.
    plt.close(fig)
    self.color_visited(position)

  def _node_position(self, node):
    """Return the index of ``node`` in ``self.graph.nodes`` (and ``node_color``)."""
    return list(self.graph.nodes).index(node)

def dijkstra(graph_plotter, params):
  """Run Dijkstra's shortest-path algorithm on ``graph_plotter.graph``.

  Annotates every node of the graph in place with:
    - 'visited': True once the node has been settled.
    - 'distance': summed edge 'weight' of the shortest path from the origin,
      or float('inf') when the node is unreachable.
    - 'path': list of nodes on that path, origin first, or [] when
      unreachable.

  After each node is settled, and before its neighbours are relaxed,
  ``graph_plotter.callback(current_node=...)`` is called if the plotter
  defines one.

  Args:
    graph_plotter: object exposing a networkx-style ``graph``. Every edge
      must carry a numeric 'weight'. Weights are assumed non-negative, as
      Dijkstra's algorithm requires.
    params: sequence whose first element is the origin node.
  """
  graph = graph_plotter.graph
  origin = params[0]
  callback = getattr(graph_plotter, 'callback', None)

  for node in graph.nodes():
    graph.nodes[node]['visited'] = False
    graph.nodes[node]['distance'] = float('inf')
    graph.nodes[node]['path'] = []

  node_queue = PriorityQueue()
  graph.nodes[origin]['distance'] = 0
  graph.nodes[origin]['path'] = [origin]
  node_queue.put((0, origin))

  while not node_queue.empty():
    current_node = node_queue.get()[1]
    current = graph.nodes[current_node]
    # Stale queue entry: the node was already settled via a shorter path.
    if current['visited']:
      continue
    current['visited'] = True

    if callback:
      callback(current_node=current_node)

    for neighbor in graph.neighbors(current_node):
      candidate = current['distance'] + graph.edges[current_node, neighbor]['weight']
      neighbor_data = graph.nodes[neighbor]
      # Only an improved distance needs a new queue entry.
      if candidate < neighbor_data['distance']:
        neighbor_data['distance'] = candidate
        neighbor_data['path'] = current['path'] + [neighbor]
        node_queue.put((candidate, neighbor))

# Opening the PDF fails if the output/ directory does not exist yet.
os.makedirs('output', exist_ok=True)
# params[0] is the origin node for the shortest-path search.
print_graph_pdf(DijkstraPlotter('output/dijkstra.pdf'), dijkstra, [4])

In [None]:
class RandomPlotter(GraphPlotter):
  """Renders a single frame of a generated graph into the PDF.

  NOTE(review): despite the name, the graph is
  ``nx.complete_graph(NODE_COUNT)``, which is deterministic, not random.
  """

  # Number of nodes in the generated graph.
  NODE_COUNT = 20

  def setup(self, params=None):
    """Populate ``self.graph`` with the generated graph and neutral colours.

    ``params`` is accepted and ignored so that ``print_graph_pdf``, which
    calls ``setup(params)``, can drive this plotter.
    """
    gen = nx.complete_graph(self.NODE_COUNT)
    self.graph.add_nodes_from(sorted(gen.nodes(data=True)))
    self.graph.add_edges_from(gen.edges(data=True))
    self.node_color = [self.NEUTRAL_COLOR] * len(self.graph.nodes)

  def draw(self, current_node=None):
    """Draw the graph on a new figure and return that Figure.

    ``current_node`` is unused; it mirrors the base-class signature.
    """
    fig = plt.figure(figsize=(12, 8))
    # nx.draw returns None, so keep the Figure created above as the result.
    nx.draw(self.graph, node_color=self.node_color, with_labels=True, font_weight='bold')
    return fig

  def callback(self):
    """Save the current drawing as one PDF page and release the figure."""
    fig = self.draw()
    self.output.savefig(fig, bbox_inches='tight')
    plt.close(fig)

def random_graph(graph_plotter, params):
  """Emit a single frame by calling ``graph_plotter.callback()`` if it exists.

  ``params`` is unused; it matches the ``algorithm(graph_plotter, params)``
  calling convention of ``print_graph_pdf``.
  """
  callback = getattr(graph_plotter, 'callback', None)
  if callback:
    callback()

# Opening the PDF fails if the output/ directory does not exist yet.
os.makedirs('output', exist_ok=True)
print_graph_pdf(RandomPlotter('output/random_graph.pdf'), random_graph, [])