In [1]:
from collections import deque
from enum import Enum
from itertools import batched
from random import shuffle
from typing import Any, Sequence

# Side length of the square grid: a state holds N * N cells, whose values
# range from 0 to (N * N) - 1.
N = 3


class Action(Enum):
    """Direction in which the blank cell can be moved inside the grid."""

    LEFT = 1
    RIGHT = 2
    UP = 3
    DOWN = 4

    def __repr__(self) -> str:
        # Compact form, e.g. <LEFT>, instead of the default <Action.LEFT: 1>
        return '<' + self.name + '>'

In [2]:
def is_valid_list(list_values: Sequence[int]) -> bool:
    """Check that the list of values passed as argument is a valid grid,
    i.e. contains each value from 0 to (N * N) - 1 exactly once.

    Lists that are too long, too short, contain duplicates or contain
    out-of-range values are all rejected.
    """
    size = N * N

    # The length check rules out duplicates and extra entries; the set
    # comparison rules out missing and out-of-range values.
    return len(list_values) == size and set(list_values) == set(range(size))


def is_solvable(list_values: Sequence[int]) -> bool:
    """Tell whether the grid is solvable, i.e. whether its number of inversions is even.

    An inversion is a pair of non-blank values where the larger one appears
    before the smaller one; the blank (0) is ignored on both sides.

    NOTE(review): the parity test alone looks tailored to odd grid widths such
    as the 3 * 3 grid used here - confirm before switching N to an even value.
    """
    inversions = sum(
        1
        for position, tile in enumerate(list_values)
        if tile != 0
        for later_tile in list_values[position + 1 :]
        if later_tile != 0 and later_tile < tile
    )

    return inversions % 2 == 0


def from_list(list_values: Sequence[int]) -> tuple[int, ...]:
    """Build a grid (a tuple) from a list of values.

    Raises ValueError when the list is not a valid grid, or when it is valid
    but not solvable.
    """
    well_formed = is_valid_list(list_values)
    if not well_formed:
        raise ValueError(
            f'{list_values=} is not a valid list for a grid of size {N} * {N}'
        )

    # Solvability is only checked once the list is known to be well formed
    if is_solvable(list_values):
        return tuple(list_values)

    raise ValueError(f'{list_values=} is not solvable')


def random_state() -> tuple[int, ...]:
    """Generate a random solvable grid"""
    # Keep drawing fresh permutations until one of them is solvable
    while True:
        candidate = [value for value in range(N * N)]
        shuffle(candidate)

        if is_solvable(candidate):
            return tuple(candidate)


def is_solution(grid: Sequence[int]) -> bool:
    """Return True when every cell holds the value equal to its own index"""
    return all(position == tile for position, tile in enumerate(grid))


def coord_blank_cell(grid: Sequence[int]) -> tuple[int, int] | None:
    """Return the position of the blank tile (value 0) as a tuple (row, column),
    or None when the grid holds no blank tile"""
    for index, tile in enumerate(grid):
        if tile != 0:
            continue
        # The grid is stored row by row, N cells per row
        return divmod(index, N)

    return None


def actions(grid) -> list[Action]:
    """List the moves of the blank cell that stay inside the grid.

    The result is always ordered LEFT, RIGHT, UP, DOWN.
    """
    position = coord_blank_cell(grid)
    assert position is not None
    row, col = position
    last = N - 1

    # Each move is paired with the condition that keeps the blank in the grid
    candidates = (
        (col > 0, Action.LEFT),  # not in the first column
        (col < last, Action.RIGHT),  # not in the last column
        (row > 0, Action.UP),  # not in the first row
        (row < last, Action.DOWN),  # not in the last row
    )

    return [move for allowed, move in candidates if allowed]


def apply_action(grid: Sequence[int], action: Action) -> tuple[int, ...]:
    """Return the grid resulting from applying an action to a grid"""
    blank_cell = coord_blank_cell(grid)
    assert blank_cell is not None
    row, col = blank_cell
    res = list(grid)

    match action:
        case Action.LEFT:
            res[row * N + col], res[row * N + col - 1] = res[row * N + col - 1], res[row * N + col]  # fmt: skip
            return tuple(res)
        case Action.RIGHT:
            res[row * N + col], res[row * N + col + 1] = res[row * N + col + 1], res[row * N + col]  # fmt: skip
            return tuple(res)
        case Action.UP:
            res[row * N + col], res[(row - 1) * N + col] = res[(row - 1) * N + col], res[row * N + col]  # fmt: skip
            return tuple(res)
        case Action.DOWN:
            res[row * N + col], res[(row + 1) * N + col] = res[(row + 1) * N + col], res[row * N + col]  # fmt: skip
            return tuple(res)


def display(grid: Sequence[int]) -> None:
    """Print the grid, one row of N values per line."""
    rows = (list(chunk) for chunk in batched(grid, N))
    for cells in rows:
        print(cells)


def num_misplaced_tiles(grid: Sequence[int]) -> int:
    """Heuristic that counts the number of misplaced tiles.

    A tile is misplaced when its value differs from its index. The blank
    cell (value 0) is not a tile and is therefore not counted: counting it
    would overestimate the remaining cost (a grid one move away from the
    solution would score 2).
    """
    return sum(1 for i, val in enumerate(grid) if val != 0 and val != i)


def manhattan_heuristic(grid: Sequence[int]) -> int:
    """Manhattan heuristic: sum, over every non-blank tile, of the row distance
    plus the column distance between the cell it occupies and the cell whose
    index equals its value"""
    total = 0

    for index, tile in enumerate(grid):
        if tile == 0:
            # The blank cell does not contribute to the distance
            continue

        target_row, target_col = divmod(tile, N)
        current_row, current_col = divmod(index, N)
        total += abs(target_row - current_row) + abs(target_col - current_col)

    return total

In [3]:
class Node:
    """Node of the search tree built while exploring grid states.

    Two nodes are equal (and hash alike) when they hold the same state,
    whatever their parent, action or cost. Ordering with `<` relies on the
    creation id only.
    """

    # Class-level counter: id that will be given to the next created node
    id: int = 0

    def __init__(
        self,
        state: tuple[int, ...],
        parent: 'Node | None' = None,
        action: Action | None = None,
        path_cost: int = 0,
        id: int | None = None,
    ) -> None:
        """Create a node.

        state: grid held by this node.
        parent: node this one was expanded from, None for a root.
        action: action applied to the parent's state to get `state`, None for a root.
        path_cost: cost accumulated from the root to this node.
        id: when given, resets the class-level counter to that value before
            numbering this node; otherwise the counter auto-increments.
        """
        self.state = state
        self.parent = parent
        self.action = action
        self.path_cost = path_cost

        # Allow to reset the count to a specific value by specifying a value,
        # otherwise auto-increment
        if id is not None:
            Node.id = id
        self.id: int = Node.id
        Node.id += 1

    def expand(self) -> list['Node']:
        """Generate all valid child nodes from this node.

        The action that would merely undo the action that led to this node
        is skipped, as it would bring back the parent's state.
        """
        # Maps each action to the one that cancels it
        opposite = {
            Action.LEFT: Action.RIGHT,
            Action.RIGHT: Action.LEFT,
            Action.UP: Action.DOWN,
            Action.DOWN: Action.UP,
        }
        children = []

        for action in actions(self.state):
            if self.action is not None and opposite.get(action) == self.action:
                # Skip useless series of actions, like doing <LEFT> then <RIGHT>
                continue

            children.append(
                Node(
                    state=apply_action(self.state, action),
                    parent=self,
                    action=action,
                    path_cost=self.path_cost + 1,
                )
            )
        return children

    def get_path(self) -> list[tuple[str, Action | str]]:
        """Reconstruct the path from root to this node.

        Each entry is a tuple (state rendered as a string, action that led to
        that state); the root has no action, which is reported as the string
        'None'.
        """
        path = []
        node = self
        while node is not None:
            action = node.action if node.action is not None else 'None'
            path.append((f'{node.state}', action))
            node = node.parent

        # Entries were collected from this node up to the root
        path.reverse()
        return path

    def __repr__(self) -> str:
        # NOTE: parent_id is rendered as a str for a child but as the int -1 for a root
        parent_id = -1 if self.parent is None else str(self.parent.id)
        return f'{self.id=},\n{self.action=},\n{parent_id=},\n{self.path_cost=},\nself.state={list(self.state)}'

    def __eq__(self, other: Any) -> bool:
        """Overrides the default implementation: nodes are equal when their states are"""
        if isinstance(other, Node):
            return self.state == other.state
        return False

    def __hash__(self) -> int:
        """Overrides the default implementation, consistently with __eq__"""
        return hash(self.state)

    def __lt__(self, other):
        """Overrides the default implementation so that in the PriorityQueue, in case of
        equality for the value of h, nodes with a smaller id are picked first"""
        if not isinstance(other, type(self)):
            # Let Python try the reflected comparison; it raises TypeError
            # itself when neither operand supports the comparison
            return NotImplemented
        return self.id < other.id

In [4]:
def bfs_set(initial_state: tuple[int, ...]) -> Node | None:
    """Breadth-first search from `initial_state`.

    Returns the first node found whose state is a solution (follow its
    parents to get the path), or None once every reachable state has been
    explored without finding one. States already queued are remembered in a
    set so that each of them is queued at most once.
    """
    start = Node(state=initial_state, id=0)
    if is_solution(start.state):
        return start

    queue = deque([start])
    seen = {initial_state}

    while queue:
        current = queue.popleft()

        for successor in current.expand():
            # Children are tested as soon as they are generated
            if is_solution(successor.state):
                return successor
            if successor.state in seen:
                continue
            queue.append(successor)
            seen.add(successor.state)

    return None

In [5]:
# Draw a random solvable start state and print it row by row
r = random_state()
display(r)

[1, 6, 5]
[8, 4, 3]
[2, 0, 7]


In [6]:
# Run the breadth-first search from the random start state; the bare
# expression shows the returned node as the cell output
res = bfs_set(r)
res

self.id=237769,
self.action=<LEFT>,
parent_id='200851',
self.path_cost=25,
self.state=[0, 1, 2, 3, 4, 5, 6, 7, 8]

In [7]:
# List the (state, action) pairs from the start state to the returned node
res.get_path()

[('(1, 6, 5, 8, 4, 3, 2, 0, 7)', 'None'),
 ('(1, 6, 5, 8, 0, 3, 2, 4, 7)', <UP>),
 ('(1, 6, 5, 8, 3, 0, 2, 4, 7)', <RIGHT>),
 ('(1, 6, 5, 8, 3, 7, 2, 4, 0)', <DOWN>),
 ('(1, 6, 5, 8, 3, 7, 2, 0, 4)', <LEFT>),
 ('(1, 6, 5, 8, 3, 7, 0, 2, 4)', <LEFT>),
 ('(1, 6, 5, 0, 3, 7, 8, 2, 4)', <UP>),
 ('(1, 6, 5, 3, 0, 7, 8, 2, 4)', <RIGHT>),
 ('(1, 0, 5, 3, 6, 7, 8, 2, 4)', <UP>),
 ('(0, 1, 5, 3, 6, 7, 8, 2, 4)', <LEFT>),
 ('(3, 1, 5, 0, 6, 7, 8, 2, 4)', <DOWN>),
 ('(3, 1, 5, 6, 0, 7, 8, 2, 4)', <RIGHT>),
 ('(3, 1, 5, 6, 2, 7, 8, 0, 4)', <DOWN>),
 ('(3, 1, 5, 6, 2, 7, 0, 8, 4)', <LEFT>),
 ('(3, 1, 5, 0, 2, 7, 6, 8, 4)', <UP>),
 ('(0, 1, 5, 3, 2, 7, 6, 8, 4)', <UP>),
 ('(1, 0, 5, 3, 2, 7, 6, 8, 4)', <RIGHT>),
 ('(1, 2, 5, 3, 0, 7, 6, 8, 4)', <DOWN>),
 ('(1, 2, 5, 3, 7, 0, 6, 8, 4)', <RIGHT>),
 ('(1, 2, 5, 3, 7, 4, 6, 8, 0)', <DOWN>),
 ('(1, 2, 5, 3, 7, 4, 6, 0, 8)', <LEFT>),
 ('(1, 2, 5, 3, 0, 4, 6, 7, 8)', <UP>),
 ('(1, 2, 5, 3, 4, 0, 6, 7, 8)', <RIGHT>),
 ('(1, 2, 0, 3, 4, 5, 6, 7, 8)', <UP>),
