In [1]:
from __future__ import annotations
import copy
from abc import ABC, abstractmethod
from itertools import chain
from typing import Callable, Literal, List, TypedDict

# Custom type aliases shared by the classes below.
# A State is a board stored row by row as lists of tile labels, with '*' as
# the blank, e.g. [['2', '1', '6'], ['4', '*', '8'], ['7', '5', '3']].
State = List[List[str]]
# NOTE(review): OperationDict is not referenced anywhere in the visible code;
# presumably it pairs a move with the board it produced — confirm or remove.
class OperationDict(TypedDict):
    operator: Callable  # NOTE(review): moves are Operator objects, which define no __call__ — confirm intended type
    state: State  # board associated with the move


class Node:
    """
    A node in the A* search tree.

    Besides the board ``state`` it stores the bookkeeping the search needs:
    the node it was generated from, the operator that produced it, its depth,
    g(n) (``path_cost``), f(n) (``f_value``), h(n) (``h_value``) and the
    P(n) / S(n) components (``p_value`` / ``s_value``) that are kept so they
    can be printed.

    Equality and hashing look only at the board state, so two nodes reached
    along different paths compare equal. Ordering is by f(n), with h(n)
    breaking ties.
    """

    def __init__(
        self,
        state: State,
        parent_node: Node | None = None,
        operator: Operator | None = None,
        depth: int = 0,
        path_cost: int = 0,
        f_value: int = 0,
        h_value: int = 0,
        p_value: int = 0,
        s_value: int = 0,
    ) -> None:
        """
        Since we're not using matrix packages like numpy, the board is a plain
        list of rows of tile labels, e.g.
        [
            ['1', '2', '3'],
            ['4', '5', '6'],
            ['7', '8', '9']
        ]

        Operators are defined outside this class so that the set of allowed
        moves can be changed without touching Node.
        """
        self._state = state
        self._parent_node = parent_node
        self._operator = operator
        self._depth = depth
        self._path_cost = path_cost
        self._f_value = f_value
        self._p_value = p_value
        self._h_value = h_value
        self._s_value = s_value

    @property
    def state(self) -> State:
        return self._state

    @property
    def parent_node(self) -> Node | None:
        return self._parent_node

    @parent_node.setter
    def parent_node(self, node: Node | None) -> None:
        self._parent_node = node

    @property
    def operator(self) -> Operator | None:
        return self._operator

    @operator.setter
    def operator(self, operator: Operator | None) -> None:
        self._operator = operator

    @property
    def depth(self) -> int:
        return self._depth

    @depth.setter
    def depth(self, depth: int) -> None:
        self._depth = depth

    @property
    def path_cost(self) -> int:
        return self._path_cost

    @path_cost.setter
    def path_cost(self, pc: int) -> None:
        self._path_cost = pc

    @property
    def h_value(self) -> int:
        return self._h_value

    @h_value.setter
    def h_value(self, h: int) -> None:
        self._h_value = h

    @property
    def f_value(self) -> int:
        return self._f_value

    @f_value.setter
    def f_value(self, amt: int) -> None:
        self._f_value = amt

    @property
    def s_value(self) -> int:
        return self._s_value

    @s_value.setter
    def s_value(self, amt: int) -> None:
        self._s_value = amt

    @property
    def p_value(self) -> int:
        return self._p_value

    @p_value.setter
    def p_value(self, amt: int) -> None:
        self._p_value = amt

    def __eq__(self, other: object) -> bool:
        # Two nodes are the same search state when their boards match,
        # whatever path or costs led to them.
        if not isinstance(other, Node):
            return NotImplemented
        return self.state == other.state

    def __hash__(self) -> int:
        # Consistent with __eq__. The board must not be mutated while the
        # node is stored in a set or used as a dict key.
        return hash(tuple(tuple(row) for row in self.state))

    def __lt__(self, other: Node) -> bool:
        """
        Order by f(n). If the f-values are equal, the node with the lesser
        heuristic value h(n) is considered smaller.
        """
        if not isinstance(other, Node):
            return NotImplemented
        if self.f_value == other.f_value:
            return self.h_value < other.h_value
        return self.f_value < other.f_value

    def __ge__(self, other: Node) -> bool:
        """Exact complement of ``__lt__`` (same f(n)-then-h(n) ordering)."""
        if not isinstance(other, Node):
            return NotImplemented
        return not self < other

    def __repr__(self) -> str:
        return f'{self.state}'

    def __str__(self) -> str:
        return f'{self.state}'


class Problem:
    """
    Holds the search problem: the initial and goal boards, the operators that
    may be applied, and the heuristic A* should use.

    Each instance gets its own empty containers when none are supplied, so
    separate problems never share state through default arguments.
    """
    def __init__(
        self,
        initial_state: State | None = None,
        goal_state: State | None = None,
        operators: List[Operator] | None = None,
        heuristic: HeuristicsABC | None = None
    ) -> None:
        self._initial_state = [] if initial_state is None else initial_state
        self._goal_state = [] if goal_state is None else goal_state
        self._operators = [] if operators is None else operators
        self._heuristic = heuristic

    @classmethod
    def from_file(
        cls,
        filepath: str,
        operators: List[Operator],
        heuristic: HeuristicsABC | None = None,
    ) -> Problem:
        """
        Build a problem from a text file.

        A line whose first word is ``start`` or ``goal`` selects which board
        the following rows belong to; every other non-blank line is one board
        row of whitespace-separated tile labels. Rows appearing before any
        header go to the initial state. Blank lines are ignored.
        """
        goal_state: State = []
        initial_state: State = []

        # Points at the board that subsequent rows are appended to.
        current_list = initial_state

        with open(filepath, 'r', encoding='utf-8') as file:
            for line in file:
                # split() with no argument collapses repeated whitespace and
                # drops the newline, so no empty tile labels are produced.
                row = line.split()
                if not row:
                    continue
                if row[0] == 'start':
                    current_list = initial_state
                elif row[0] == 'goal':
                    current_list = goal_state
                else:
                    current_list.append(row)

        return cls(initial_state, goal_state, operators, heuristic)

    @property
    def initial_state(self) -> State:
        return self._initial_state

    @property
    def goal_state(self) -> State:
        return self._goal_state

    @property
    def operators(self) -> List[Operator]:
        return self._operators

    @property
    def heuristic(self) -> HeuristicsABC | None:
        return self._heuristic

    @heuristic.setter
    def heuristic(self, heuristic: HeuristicsABC | None) -> None:
        self._heuristic = heuristic

    def is_goal(self, curr_state: State) -> bool:
        """Return True if ``curr_state`` equals the goal board."""
        return curr_state == self._goal_state


class HeuristicsABC(ABC):
    """
    Interface shared by every heuristic the A* solver can plug in.

    Subclasses must provide ``identifier`` and ``compute_heuristic``. The
    ``manhattan_distance`` (P(n)) and ``sequence_score`` (S(n)) hooks exist
    only so the solver can record and print those components for any
    heuristic; by default they report 0, and in this file only the Nilsson
    sequence-score heuristic overrides them.
    """

    @property
    @abstractmethod
    def identifier(self) -> str:
        """Short name used to label this heuristic in printed output."""
        return 'general heuristic'

    @abstractmethod
    def compute_heuristic(self, curr_state: State, goal_state: State) -> int:
        """Return the estimated cost of reaching ``goal_state`` from ``curr_state``."""

    def manhattan_distance(self, curr_state: State, goal_state: State) -> int:
        """P(n) display hook: 0 unless a subclass overrides it."""
        return 0

    def sequence_score(self, curr_state: State, goal_state: State) -> int:
        """S(n) display hook: 0 unless a subclass overrides it."""
        return 0


class ManhattanDistance(HeuristicsABC):
    """
    Sum of the Manhattan distances of every tile from its goal square.

    A tile at (a_1, a_2) whose goal square is (g_1, g_2) contributes
    abs(a_1 - g_1) + abs(a_2 - g_2); tiles already in place contribute 0 and
    the blank '*' is ignored.

    Both boards are flattened row by row with itertools.chain, and a flat
    index is turned back into (row, column) with divmod by the board width.
    The width is taken from the current state's first row (rows are assumed
    to have equal length), so boards other than 3x3 are handled too.
    """

    @property
    def identifier(self) -> str:
        return 'ManhattanDistance'

    def compute_heuristic(self, curr_state: State, goal_state: State) -> int:
        """
        Return the total Manhattan distance of ``curr_state`` from ``goal_state``.

        Raises:
            ValueError: if a non-blank tile of ``curr_state`` does not appear
                in ``goal_state``.
        """
        if not curr_state:
            return 0
        width = len(curr_state[0])

        # Flat index of every goal tile, built once instead of a list.index
        # scan per tile; setdefault keeps the first occurrence like index().
        goal_index: dict[str, int] = {}
        for index, tile in enumerate(chain.from_iterable(goal_state)):
            goal_index.setdefault(tile, index)

        total = 0
        for index, tile in enumerate(chain.from_iterable(curr_state)):
            if tile == '*':
                continue
            if tile not in goal_index:
                raise ValueError(f'{tile!r} is not in the goal state')
            curr_row, curr_col = divmod(index, width)
            goal_row, goal_col = divmod(goal_index[tile], width)
            total += abs(curr_row - goal_row) + abs(curr_col - goal_col)

        return total
    

class NumOfTilesInWrongPosition(HeuristicsABC):
    """
    Number of tiles that are not on their goal square.

    Both boards are flattened row by row and compared position by position;
    the blank '*' is never counted.
    """

    @property
    def identifier(self) -> str:
        return 'NumOfTilesWrong'

    def compute_heuristic(self, curr_state: State, goal_state: State) -> int:
        """
        Return how many non-blank tiles of ``curr_state`` differ from the tile
        at the same position in ``goal_state``.

        Raises:
            ValueError: if the two boards do not hold the same number of squares.
        """
        flattened_curr_state = list(chain.from_iterable(curr_state))
        flattened_goal_state = list(chain.from_iterable(goal_state))
        if len(flattened_curr_state) != len(flattened_goal_state):
            raise ValueError(
                'curr_state and goal_state must have the same number of squares')

        return sum(
            1
            for curr_tile, goal_tile in zip(flattened_curr_state, flattened_goal_state)
            if curr_tile != '*' and curr_tile != goal_tile
        )
    

class NilssonSequenceScore(ManhattanDistance):
    """
    h(n) = P(n) + 3 S(n) where P(n) is the Manhattan distance of each tile from
    its goal position and S(n) is Nilsson's sequence score.

    S(n) numbers the squares of a 3x3 board clockwise (see ``N``). A tile in
    the centre scores 1; every other tile scores 2 unless the square clockwise
    after it holds its successor (1 -> 2 -> ... -> 8 -> 1). The ordering and
    the successor rule are fixed and ``goal_state`` is not consulted by S(n);
    the fixed ordering matches the goal layout 1 2 3 / 8 * 4 / 7 6 5, which
    scores S = 0.
    """

    # Clockwise number of each (row, col) square; the centre square gets 8.
    _CLOCKWISE_ORDER = {
        (0, 0): 0, (0, 1): 1, (0, 2): 2,
        (1, 2): 3, (2, 2): 4, (2, 1): 5,
        (2, 0): 6, (1, 0): 7, (1, 1): 8,
    }
    _CENTRE = 8
    _RING_LENGTH = 8

    @property
    def identifier(self) -> str:
        return 'NilssonSequenceScore'

    def manhattan_distance(self, curr_state: State, goal_state: State) -> int:
        """P(n): the Manhattan distance computed by the parent class."""
        return super().compute_heuristic(curr_state, goal_state)

    def N(self, tile: str, curr_state: State) -> int | None:
        """
        N(tile) is the number of the square where the tile currently lies,
        using the clockwise numbering (format given in stackoverflow):
        +---+---+---+
        | 0 | 1 | 2 |
        +---+---+---+
        | 7 | 8 | 3 |
        +---+---+---+
        | 6 | 5 | 4 |
        +---+---+---+
        Returns None when the tile is not on the board.
        """
        for row_index, row in enumerate(curr_state):
            if tile in row:
                return self._CLOCKWISE_ORDER[(row_index, row.index(tile))]
        return None

    def next_tile(self, tile: str) -> str:
        """Return the successor label of ``tile`` (1 -> 2, ..., 8 -> 1)."""
        q, r = divmod(int(tile) + 1, 9)
        return str(q + r)

    def _squares(self, curr_state: State) -> dict[str, int]:
        """Map every tile to its clockwise square number (first occurrence wins)."""
        squares: dict[str, int] = {}
        for row_index, row in enumerate(curr_state):
            for col_index, tile in enumerate(row):
                squares.setdefault(tile, self._CLOCKWISE_ORDER[(row_index, col_index)])
        return squares

    def sequence_score(self, curr_state: State, goal_state: State) -> int:
        """
        Return S(n), not yet multiplied by 3 (that happens in compute_heuristic).
        """
        # Locate every tile once instead of rescanning the board per lookup.
        squares = self._squares(curr_state)
        score = 0
        for tile in chain.from_iterable(curr_state):
            if tile == '*':
                continue
            square = squares[tile]
            if square == self._CENTRE:
                score += 1
            elif squares.get(self.next_tile(tile)) != (square + 1) % self._RING_LENGTH:
                score += 2

        return score

    def compute_heuristic(self, curr_state: State, goal_state: State) -> int:
        """Return P(n) + 3 * S(n)."""
        manhattan_distance = super().compute_heuristic(curr_state, goal_state)
        return manhattan_distance + 3 * self.sequence_score(curr_state, goal_state)


class Operator:
    """
    One way of sliding the blank ('*') by one square: 'up', 'down', 'left'
    or 'right'.

    ``repr`` gives the bare move name; ``str`` gives a readable label.
    """

    # (row offset, column offset) each move applies to the blank's square.
    _OFFSETS = {
        'up': (-1, 0),
        'down': (1, 0),
        'left': (0, -1),
        'right': (0, 1),
    }

    def __init__(self, move: Literal['up', 'down', 'left', 'right']) -> None:
        self._move = move

    def do_move(self, state: State, star_row: int, star_col: int) -> State:
        """
        Swap the blank at (star_row, star_col) with its neighbour in this
        operator's direction.

        ``state`` is modified in place and also returned, so pass a copy when
        the original board must be kept.

        Raises:
            IndexError: if the neighbour square lies outside the board (rows
                are checked against the width of row 0).
            KeyError: if the move name is not one of the four directions.
        """
        d_row, d_col = self._OFFSETS[self._move]
        target_row = star_row + d_row
        target_col = star_col + d_col

        if not (0 <= target_row < len(state)) or not (0 <= target_col < len(state[0])):
            raise IndexError(f'Move {self._move} is illegal')

        blank = state[star_row][star_col]
        state[star_row][star_col] = state[target_row][target_col]
        state[target_row][target_col] = blank
        return state

    def __repr__(self) -> str:
        return self._move

    def __str__(self) -> str:
        return f'Move {self._move.capitalize()}'


class AstarAlgorithm:
    """
    Implements the A* algorithm based on the pseudocode in the slides.
    Steps 1 - 8 are numbered accordingly in the run_algo method.

    OPEN and CLOSED are plain lists of Node objects; membership tests use
    Node equality (same board state).
    """

    # Default cap on the number of nodes run_algo takes off OPEN.
    DEFAULT_MAX_ITERATIONS = 1_000_000

    def __init__(self, problem: Problem):
        self._problem = problem

    @property
    def problem(self) -> Problem:
        return self._problem

    def compute_fs(self, node: Node) -> int:
        """Return f(n) = h(n) + g(n), with h(n) from the problem's heuristic."""
        h = self.problem.heuristic.compute_heuristic(
            node.state, self.problem.goal_state)
        return h + node.path_cost

    def _evaluate(self, node: Node) -> None:
        """Set h(n), f(n), S(n) and P(n) on ``node``, evaluating h(n) only once."""
        heuristic = self.problem.heuristic
        goal_state = self.problem.goal_state
        node.h_value = heuristic.compute_heuristic(node.state, goal_state)
        node.f_value = node.h_value + node.path_cost
        node.s_value = heuristic.sequence_score(node.state, goal_state)
        node.p_value = heuristic.manhattan_distance(node.state, goal_state)

    @staticmethod
    def _blank_positions(state: State) -> List[tuple[int, int]]:
        """Return (row, col) of the first '*' in every row that contains one."""
        return [
            (row_index, row.index('*'))
            for row_index, row in enumerate(state)
            if '*' in row
        ]

    def generate_successors(self, n: Node) -> List[Node]:
        """
        Handle step 5 of the A* pseudocode: apply every operator to ``n``'s
        board and return the resulting child nodes, each with f(n), h(n),
        S(n) and P(n) filled in.

        Moves that would leave the board (``do_move`` raises IndexError) are
        skipped, as is any move whose result equals the parent's board.
        """
        blanks = self._blank_positions(n.state)
        successor_nodes: List[Node] = []
        for operator in self.problem.operators:
            for row, col in blanks:
                try:
                    # Work on a copy so n.state is never modified.
                    new_state = operator.do_move(copy.deepcopy(n.state), row, col)
                except IndexError:
                    continue  # illegal move for this blank position
                if n.parent_node is not None and new_state == n.parent_node.state:
                    continue
                successor_node = Node(
                    new_state, n, operator, n.depth + 1, n.path_cost + 1)
                self._evaluate(successor_node)
                successor_nodes.append(successor_node)

        return successor_nodes

    def run_algo(self, max_iterations: int = DEFAULT_MAX_ITERATIONS) -> List[Node]:
        """
        Run A* from the problem's initial state and return the solution path,
        start node first and goal node last. Progress and a summary are printed.

        Args:
            max_iterations: maximum number of nodes taken off OPEN before
                giving up.

        Raises:
            ValueError: if OPEN becomes empty before the goal is reached.
            RuntimeError: if ``max_iterations`` runs out before the goal is
                reached.
        """
        # Step 1: OPEN holds only the evaluated start node; CLOSED is empty.
        s = Node(self.problem.initial_state, None, None, 0, 0)
        self._evaluate(s)
        OPEN: List[Node] = [s]
        CLOSED: List[Node] = []
        nodes_generated = 0
        unique_nodes_generated = 0
        nodes_expanded = 0

        print('\nDisplay relevant costs...')
        print('+----------------------------------------------------------+')
        for _ in range(max_iterations):
            # Step 2
            if not OPEN:
                raise ValueError('OPEN MUST NOT BE EMPTY!!!')

            # Step 3: take the node with the lowest f(n) (Node.__lt__ breaks
            # ties on h(n)).
            n = min(OPEN)
            print_string = f'state = {n}; f(n) = {n.f_value}; g(n) = {n.path_cost}; h(n) = {n.h_value}'
            if self.problem.heuristic.identifier == 'NilssonSequenceScore':
                print_string += f'; p(n) = {n.p_value}; s(n) = {n.s_value}'
            print_string += f'; Number of nodes generated: {nodes_generated}'
            print(print_string)

            OPEN.remove(n)
            CLOSED.append(n)

            # Step 4: goal reached -> follow parent links back to the start.
            if self.problem.is_goal(n.state):
                solution_path = []
                node = n
                while node is not None:
                    solution_path.append(node)
                    node = node.parent_node

                print(f'\nSummary for A* using {self.problem.heuristic.identifier}')
                print('+----------------------------------------------------------+')
                print(f'Nodes generated: {nodes_generated}; Unique nodes generated: {unique_nodes_generated}; Nodes expanded: {nodes_expanded}')
                print(f'Open List: {len(OPEN)}; Closed List: {len(CLOSED)}')
                return list(reversed(solution_path))

            # Step 5
            n_successors = self.generate_successors(n)
            nodes_expanded += 1

            for successor in n_successors:
                nodes_generated += 1
                # Step 6: states seen for the first time go straight onto OPEN.
                if successor in OPEN:
                    list_to_use = OPEN
                elif successor in CLOSED:
                    list_to_use = CLOSED
                else:
                    unique_nodes_generated += 1
                    OPEN.append(successor)
                    continue

                prev_successor = list_to_use[list_to_use.index(successor)]
                # Same board and a deterministic heuristic give the same h(n),
                # so this f(n) comparison effectively compares g(n).
                if successor.f_value >= prev_successor.f_value:
                    continue

                # Step 7: a cheaper path was found -> update the existing node.
                prev_successor.depth = successor.depth
                prev_successor.path_cost = successor.path_cost
                prev_successor.operator = successor.operator
                prev_successor.parent_node = successor.parent_node
                prev_successor.f_value = successor.f_value
                prev_successor.s_value = successor.s_value
                prev_successor.p_value = successor.p_value
                prev_successor.h_value = successor.h_value

                if list_to_use is CLOSED:
                    # Reopen the updated node so it is expanded again with
                    # the lower cost.
                    CLOSED.remove(prev_successor)
                    OPEN.append(prev_successor)

            # Step 8: loop back to step 2.

        raise RuntimeError(
            f'A* did not reach the goal within {max_iterations} iterations')


def print_matrices(*matrices, chunk_size=10):
    """
    Print boards side by side, ``chunk_size`` boards per group.

    Each board is a list of rows of strings; the tiles of a row are joined
    with single spaces. Within a group, the boards' rows are joined
    column-wise: the middle row (index ``rows // 2``) is joined with
    ``' -> '`` so the printout reads as a sequence of moves, and the other
    rows with four spaces, which has the same width as the arrow. Each group
    is preceded by a ``Moves i to j`` header giving the boards' positions in
    ``matrices``.
    """
    formatted_matrices = [
        [" ".join(row) for row in matrix] for matrix in matrices
    ]

    for start_index in range(0, len(formatted_matrices), chunk_size):
        chunk = formatted_matrices[start_index:start_index + chunk_size]
        end_index = start_index + len(chunk)

        print(f'Moves {start_index} to {end_index - 1}')

        row_groups = list(zip(*chunk))
        # Put the arrows on the middle row for boards of any height.
        arrow_row = len(row_groups) // 2
        combined_rows = [
            (" -> " if i == arrow_row else "    ").join(row_group)
            for i, row_group in enumerate(row_groups)
        ]

        print("\n".join(combined_rows))


if __name__ == "__main__":
    # Puzzle file parsed by Problem.from_file: rows after a 'start' line form
    # the initial board, rows after a 'goal' line form the goal board.
    # Modify the path below to use a different input.
    input_path = 'astar_in.txt'
    problem = Problem.from_file(
        input_path,
        [
            Operator('up'),
            Operator('down'),
            Operator('left'),
            Operator('right'),
        ],
    )

    # Heuristics are run in this order against the same problem instance.
    heuristics = [
        NumOfTilesInWrongPosition(),
        ManhattanDistance(),
        NilssonSequenceScore(),
    ]
    algo = AstarAlgorithm(problem)
    for heur in heuristics:
        problem.heuristic = heur
        print('\n\n====================================================================================')
        print(f'Starting A* for {heur.identifier}')
        print('====================================================================================')
        solution = algo.run_algo()
        if not solution:
            # Nothing to display if the search returned no path.
            print(f'\nNo solution found using {heur.identifier}')
            continue
        matrices = [node.state for node in solution]
        print('\nSequence of states from start to goal')
        print('+----------------------------------------------------------+')
        print_matrices(*matrices)
   



Starting A* for NumOfTilesWrong

Display relevant costs...
+----------------------------------------------------------+
state = [['2', '1', '6'], ['4', '*', '8'], ['7', '5', '3']]; f(n) = 7; g(n) = 0; h(n) = 7; Number of nodes generated: 0
state = [['2', '*', '6'], ['4', '1', '8'], ['7', '5', '3']]; f(n) = 8; g(n) = 1; h(n) = 7; Number of nodes generated: 4
state = [['*', '2', '6'], ['4', '1', '8'], ['7', '5', '3']]; f(n) = 8; g(n) = 2; h(n) = 6; Number of nodes generated: 6
state = [['2', '1', '6'], ['4', '5', '8'], ['7', '*', '3']]; f(n) = 8; g(n) = 1; h(n) = 7; Number of nodes generated: 7
state = [['2', '1', '6'], ['*', '4', '8'], ['7', '5', '3']]; f(n) = 8; g(n) = 1; h(n) = 7; Number of nodes generated: 9
state = [['2', '1', '6'], ['4', '8', '*'], ['7', '5', '3']]; f(n) = 8; g(n) = 1; h(n) = 7; Number of nodes generated: 11
state = [['4', '2', '6'], ['*', '1', '8'], ['7', '5', '3']]; f(n) = 9; g(n) = 3; h(n) = 6; Number of nodes generated: 13
state = [['2', '6', '*'], ['4', '1',