In [None]:
import networkx as nx
import math
import random


class KnowledgeBase:
    """Agent state for a greedy, backtracking walk through a graph.

    Nodes are ``(x, y)`` tuples. Each step prefers the neighbour whose own
    onward neighbours come closest (straight-line distance) to the target;
    when no neighbour is usable the walk retreats along ``stack``.
    """

    def __init__(self, G: "nx.Graph", current_node: tuple[int, int],
                 target_node: tuple[int, int]):
        """Start the walk at ``current_node`` heading for ``target_node``.

        Only ``G.neighbors(node)`` is called on the graph.
        """
        self.G = G
        self.current_node = current_node
        self.target_node = target_node
        # Nodes the agent has already stood on.
        self.visited = {current_node}
        # Neighbours whose only connection leads back to where we came from.
        self.deadlock = set()
        # Path from the start to the current node; popped when backtracking.
        self.stack = [current_node]

    def get_next_node(self):
        """Pick the node to step to next, or ``None`` to request a backtrack.

        The target is returned immediately when it is adjacent. Otherwise
        every unvisited, non-deadlocked neighbour is scored by the smallest
        distance from any of its onward neighbours (all of its neighbours
        except the current node) to the target, and one of the best-scoring
        neighbours is chosen at random. Neighbours that have no onward
        neighbours are recorded in ``deadlock`` and skipped.

        When nothing is selectable, the top of ``stack`` is popped and
        ``None`` is returned; pass that ``None`` to ``move`` to step back.
        A dead end reached while ``stack`` is already empty raises
        ``IndexError`` from the pop.
        """
        neighbors = list(self.G.neighbors(self.current_node))

        if self.target_node in neighbors:
            return self.target_node

        candidates = [n for n in neighbors
                      if n not in self.visited and n not in self.deadlock]

        # Score every candidate exactly once, always ignoring the edge back
        # to the current node, so selection uses the same measure that
        # defines the best score.
        scores = {}
        for candidate in candidates:
            onward = [m for m in self.G.neighbors(candidate)
                      if m != self.current_node]
            if not onward:
                self.deadlock.add(candidate)
                continue
            scores[candidate] = min(
                self.calculate_distance(m, self.target_node) for m in onward)

        if not scores:
            # Dead end: drop the current node from the path so that
            # move(None) lands on the previous one.
            self.stack.pop()
            return None

        # The best score is derived from the data rather than compared with
        # a fixed cap, so arbitrarily distant targets are still handled.
        best = min(scores.values())
        return random.choice([n for n, s in scores.items() if s == best])

    def calculate_distance(self, src: tuple[int, int],
                           dest: tuple[int, int]) -> float:
        """Return the Euclidean distance between two ``(x, y)`` points."""
        x1, y1 = src
        x2, y2 = dest
        # Plain sum of squares: with integer coordinates the sum is exact,
        # so equally distant nodes compare equal when ranking candidates.
        return math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)

    def move(self, node: tuple[int, int]):
        """Advance to ``node``, or step back along the path if it is ``None``.

        With a real node: it becomes the current node, is marked visited and
        is pushed on ``stack``. With ``None``: the current node is reset to
        the top of ``stack``; if ``stack`` is empty nothing changes.

        Returns the new current node, or ``None`` when asked to step back
        with an empty ``stack``.
        """
        if node is None:
            if not self.stack:
                return None
            self.current_node = self.stack[-1]
            print("Going back")
            return self.current_node

        self.current_node = node
        self.visited.add(node)
        self.stack.append(node)
        return self.current_node

    def shouldContinue(self):
        """Return ``True`` until the current node is the target node."""
        return self.current_node != self.target_node
