In [None]:
import typing
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Self, Tuple

import matplotlib.pyplot as plt
import networkx as nx
from matplotlib import gridspec

class HeapObject:
  """ One heap entry: a payload object paired with the priority it is ordered by. """
  # The priority this entry is ordered by.
  priority: int
  # The payload carried by this entry.
  object: Any

  def __init__(self, priority: int, obj: Any):
    self.priority, self.object = priority, obj

  def new_priority(self, new_priority: int):
    """ Overwrite this entry's priority with _new_priority_. """
    self.priority = new_priority

  def __repr__(self) -> str:
    return "priority: {}; object: {}".format(self.priority, self.object)

class Heap():
  """ A binary min-heap of HeapObjects stored in a flat list.

      The entry with the smallest priority is kept at index 0. The children
      of the entry at index i are stored at indexes 2i + 1 and 2i + 2.
  """
  # The heap entries, laid out level by level.
  heap: List[HeapObject]

  def __init__(self):
    self.heap = []

  def add(self, priority: int, obj: Any):
    """ Add a new object to the heap. """
    self.heap.append(HeapObject(priority, obj))
    # The new entry starts as the last leaf; let it rise to its place.
    self.fixup_up(len(self.heap) - 1)

  def decrease_object_priority(self, obj: Any, new_priority: int):
    """ Decrease the priority (to _new_priority_) of the HeapObject holding _obj_.

        Raises ValueError if no entry holds _obj_ or if _new_priority_ is
        larger than the entry's current priority.
    """
    for index, entry in enumerate(self.heap):
      if entry.object == obj:
        if new_priority > entry.priority:
          raise ValueError(
            f"new priority {new_priority} is larger than current priority {entry.priority}")
        entry.new_priority(new_priority)
        # A smaller priority can only violate the invariant toward the root.
        self.fixup_up(index)
        return
    raise ValueError(f"object is not in the heap: {obj!r}")

  def _extract_index(self, index) -> HeapObject:
    """ Extract the heap object at the given index.

        Raises IndexError if _index_ does not refer to an entry.
    """
    if not 0 <= index < len(self.heap):
      raise IndexError(f"heap index out of range: {index}")
    last = len(self.heap) - 1
    # Move the doomed entry to the end so it can be popped in O(1).
    self.heap[index], self.heap[last] = self.heap[last], self.heap[index]
    removed = self.heap.pop()
    if index < len(self.heap):
      # The entry that took its place may be out of order in either
      # direction; at most one of these two calls actually moves it.
      self.fixup_up(index)
      self.fixup_down(index)
    return removed

  def extract(self) -> HeapObject:
    """ Extract the HeapObject at the top of the heap.

        Raises IndexError if the heap is empty.
    """
    if not self.heap:
      raise IndexError("extract from an empty heap")
    return self._extract_index(0)

  def children_indexes(self, index) -> Tuple[int, int]:
    """ Calculate the indexes of the children of a node at the given index.

        The returned indexes may lie beyond the end of the heap; callers
        must bounds-check them.
    """
    return (2 * index + 1, 2 * index + 2)

  def parent_index(self, index) -> int:
    """ Calculate the index of the parent of a node at the given index.

        The root (index 0) has no parent; -1 is returned for it.
    """
    return (index - 1) // 2

  def fixup_down(self, index):
    """ Fixup a potential break of the invariant (toward the leafs) starting at the given index. """
    size = len(self.heap)
    while True:
      left, right = self.children_indexes(index)
      smallest = index
      if left < size and self.heap[left].priority < self.heap[smallest].priority:
        smallest = left
      if right < size and self.heap[right].priority < self.heap[smallest].priority:
        smallest = right
      if smallest == index:
        return
      self.heap[index], self.heap[smallest] = self.heap[smallest], self.heap[index]
      index = smallest

  def fixup_up(self, index):
    """ Fixup a potential break of the invariant (toward the root) starting at the given index. """
    while index > 0:
      parent = self.parent_index(index)
      if self.heap[parent].priority <= self.heap[index].priority:
        return
      self.heap[index], self.heap[parent] = self.heap[parent], self.heap[index]
      index = parent

  def contains(self, obj) -> bool:
    """ Determine whether the heap contains a HeapObject that refers to the given object. """
    return any(entry.object == obj for entry in self.heap)

  def __repr__(self):
    return "\n".join(map(str, self.heap))


@dataclass
class Edge:
  """ A weighted edge leading to another node.

      Only the far end is recorded here, as the destination node's name.
      The weight is the cost of traversing the edge to reach that
      destination.
  """
  destination: str
  weight: int
@dataclass
class Node:
  """ A Node represents a node in a graph, as id'd by a name.

      A node also has a list of outgoing edges, a reference to its
      previous node (for use during shortest path calculation), and a
      distance from some other node (also for use during shortest path
      calculation).

      Two nodes are equal when their names are equal; nodes are ordered
      by distance. Because __eq__ is defined here, the dataclass
      machinery leaves instances unhashable.
  """
  name: str
  # Forward references keep this class independent of definition order
  # and of typing.Self (which needs Python 3.11+).
  edges: List["Edge"]
  previous: Optional["Node"]
  distance: Optional[int]

  def __eq__(self, other) -> bool:
    # Let Python fall back to its default handling for non-Node operands
    # (e.g. `node == None`) instead of failing on a missing attribute.
    if not isinstance(other, Node):
      return NotImplemented
    return self.name == other.name

  def __lt__(self, other) -> bool:
    # Both distances must be set; comparing an unset (None) distance
    # raises TypeError.
    if not isinstance(other, Node):
      return NotImplemented
    return self.distance < other.distance

class PathableGraph():
  """ A directed, weighted graph plus the bookkeeping for a step-by-step
      shortest path calculation between a start node and an end node.
  """
  # A dictionary that maps node names to nodes.
  nodes: Dict[str, Node]
  # A reference to the start node of the path being traced.
  start_node: Node
  # A reference to the end node of the path being traced.
  end_node: Node
  # A list of the names of the nodes that have been completely explored.
  black: List[str]
  # A min priority queue that holds the nodes on the frontier.
  gray: Heap
  # A count of the number of iterations that have been performed.
  iteration_count: int

  def __init__(self):
    self.nodes = {}
    self.black = []
    self.gray = Heap()
    self.iteration_count = 0
    # Not read anywhere in this class; kept so existing users of the
    # attribute keep working.
    self.draw_count = 0

  def add_node(self, node_name: str):
    """ Add a node (with the given name) to the graph. """
    self.nodes[node_name] = Node(node_name, [], None, None)

  def add_edge(self, start: str, stop: str, weight: int):
    """ Add an edge (between the nodes with the given names) to the graph.

        The edge is stored on the start node only. Raises KeyError if
        _start_ has not been added as a node.
    """
    self.nodes[start].edges.append(Edge(stop, weight))

  def get_iteration_count(self) -> int:
    """ Get the number of iterations that have been performed. """
    return self.iteration_count

  def reset_sp(self):
    """ Reset the state of the shortest path calculation -- to start again.

        Requires start_node to be set (see start_processing).
    """
    # An explicit loop: map() is lazy and would never run its function
    # unless the result were consumed.
    for node in self.nodes.values():
      node.distance = None
      node.previous = None
    self.gray = Heap()
    self.black = []
    self.iteration_count = 0
    self.start_node.distance = 0

  def start_processing(self, start: str, stop: str):
    """ Choose the start and end nodes (by name) and prime the frontier.

        Raises KeyError if either name has not been added as a node.
    """
    self.start_node = self.nodes[start]
    self.end_node = self.nodes[stop]
    self.reset_sp()

    # Make the start node gray to kick things off!
    self.gray.add(self.start_node.distance, self.start_node)

  def iterate(self):
    """ Do a single _iteration_ of shortest path calculation.

        Doing one iteration at a time makes it possible to visualize
        the algorithm's progress.

        Not implemented in this skeleton: currently a no-op.
    """
    pass

  def path_found(self):
    """ Calculate whether a path was found between the start and end nodes. """

    ###### NOTE #######
    # The return True here is _not_ what you will ultimately want. It is
    # to prevent the skeleton code (as distributed) from running forever.
    ###################
    return True

class PathableGrapherVisualizer():
  """ Draws successive snapshots of a PathableGraph's shortest path
      calculation, one subplot per snapshot, stacked vertically in a
      single figure.
  """
  # The graph whose state is drawn.
  graph: PathableGraph
  # The number of snapshots drawn so far.
  draw_count: int

  def __init__(self, graph: PathableGraph):
    self.graph = graph
    self.reset_viz()

  def reset_viz(self):
    """ Start over with a new, empty figure and no cached layout. """
    self.figure = plt.figure(layout="tight")
    # Start with 0 subplots
    self.row = 0
    self.draw_count = 0
    # (networkx graph, node positions), built on the first add_viz call
    # and reused so every snapshot shares one layout.
    self.draw_state = None

  def add_viz(self):
    """ Visualize the current state of the algorithm's progress in calculating the shortest path. """
    self.row += 1
    gs = gridspec.GridSpec(self.row, 1)
    self.figure.set_figwidth(10)
    self.figure.set_figheight(10 * self.row)

    # Reposition existing subplots
    for i, ax in enumerate(self.figure.axes):
      ax.set_position(gs[i].get_position(self.figure))
      ax.set_subplotspec(gs[i])

    # Add new subplot
    iteration_plot = self.figure.add_subplot(gs[self.row - 1], title=f"Iteration {self.draw_count + 1}")

    if self.draw_state is None:
      visualization = nx.Graph()
      # Register every node up front so a node without any edge is still
      # known to networkx when it appears in the nodelist below.
      visualization.add_nodes_from(self.graph.nodes)
      for node_name, node in self.graph.nodes.items():
        for outgoing in node.edges:
          # A fixed weight is used here (not the edge's own weight); the
          # real weights are shown through the edge labels.
          visualization.add_edge(node_name, outgoing.destination, weight=10)
      layout = nx.spring_layout(visualization)
      self.draw_state = (visualization, layout)
    visualization, layout = self.draw_state

    node_colors: List[str] = []
    node_font_colors: Dict[str, str] = {}
    edge_labels: Dict[Tuple[str, str], str] = {}
    node_names: Dict[str, str] = {}

    for node in self.graph.nodes.values():
      node_font_colors[node.name] = "black"
      if node.name in self.graph.black:
        node_colors.append("black")
        node_font_colors[node.name] = "white"
      elif self.graph.gray.contains(node):
        node_colors.append("gray")
      else:
        node_colors.append("#8fd3fe")
      for edge in node.edges:
        edge_labels[(node.name, edge.destination)] = str(edge.weight)
      previous_name = node.previous.name if node.previous is not None else '?'
      node_names[node.name] = f"{node.name}: {node.distance}\n(from {previous_name})"

    # Draw explicitly onto this snapshot's subplot rather than onto
    # whatever axes pyplot currently considers active.
    node_order = list(self.graph.nodes)
    draw_params = {}
    draw_params['pos'] = layout
    draw_params['ax'] = iteration_plot
    draw_params['nodelist'] = node_order
    draw_params['font_weight'] = "bold"
    draw_params['node_color'] = node_colors
    draw_params['node_size'] = [len(node_names[v]) * 800 for v in node_order]
    nx.draw(visualization, **draw_params)

    label_params = {}
    label_params["labels"] = node_names
    label_params["font_color"] = node_font_colors
    label_params["ax"] = iteration_plot
    nx.draw_networkx_labels(visualization, layout, **label_params)

    nx.draw_networkx_edge_labels(visualization, layout, edge_labels=edge_labels, ax=iteration_plot)

    self.draw_count = self.draw_count + 1

  def finalize_viz(self, to_file: str):
    """ Save this visualizer's figure to the file named _to_file_. """
    # Save our own figure, not pyplot's "current" one, which may belong
    # to a different visualizer.
    self.figure.savefig(to_file)

def _build_graph(node_names, edges):
  """ Build a PathableGraph from node names and (start, stop, weight) triples. """
  graph = PathableGraph()
  for node_name in node_names:
    graph.add_node(node_name)
  for start, stop, weight in edges:
    graph.add_edge(start, stop, weight)
  return graph


def _trace_and_save(graph, start, stop, to_file):
  """ Trace a path from _start_ to _stop_, drawing every step, and save the drawing to _to_file_. """
  graph.start_processing(start, stop)
  viz = PathableGrapherVisualizer(graph)
  viz.add_viz()
  while not graph.path_found():
    graph.iterate()
    viz.add_viz()
  viz.finalize_viz(to_file)


def main():
  """ Run the three example graphs and save their visualizations as
      example1.png, example2.png and example3.png.
  """

  # Don't look at interactive plots.
  plt.ioff()

  sg = _build_graph(
    ["A", "B", "C", "D", "E"],
    [
      ("A", "B", 2),
      ("A", "C", 6),
      ("B", "D", 5),
      ("C", "D", 3),
      ("D", "E", 1),
      ("C", "E", 6),
    ],
  )
  _trace_and_save(sg, "A", "E", 'example1.png')

  # The example from Page 621.
  sg2 = _build_graph(
    ['s', 't', 'x', 'z', 'y'],
    [
      ("s", "t", 10),
      ("s", "y", 5),
      ("t", "x", 1),
      ("x", "z", 4),
      ("z", "x", 6),
      ("z", "s", 7),
      ("y", "z", 2),
      ("y", "t", 3),
      ("t", "y", 2),
      ("y", "x", 9),
    ],
  )
  _trace_and_save(sg2, "s", "x", 'example2.png')

  sg3 = _build_graph(
    ['A', 'a', 'b', 'c', 'd', 'e', 'f'],
    [
      ("A", "a", 2),
      ("A", "b", 5),
      ("a", "c", 7),
      ("a", "d", 8),
      ("c", "e", 12),
      ("a", "f", 1),
      ("f", "c", 1),
    ],
  )
  # Exactly one visualizer per graph: each one opens its own figure, so
  # an extra, unused visualizer would leave a stray empty figure open.
  _trace_and_save(sg3, "A", "c", 'example3.png')


# Run the examples only when executed as a script (or notebook cell),
# not when this file is imported.
if __name__ == "__main__":
  main()

NameError: name 'Any' is not defined