#Libraries

In [1]:
import random
from collections import deque
import heapq
from graphviz import Digraph
from IPython.display import display, Image
import matplotlib.pyplot as plt


#Puzzle State Class

In [2]:
#Puzzle State Class
GOAL_STATES = [
    [1,2,3,4,5,6,7,8,0],
    [8,7,6,5,4,3,2,1,0],
    [0,1,2,3,4,5,6,7,8],
    [0,8,7,6,5,4,3,2,1],
]
class PuzzleState:
  """
  Initialize the pizzle state with a list of 9 integers
  0 represents the blank tile
  For example:
  [1,2,3,4,5,6,7,8,0]
  """
  def __init__(self, initial: list[int]):
    self.state = initial
    self.zero_index = self.state.index(0) #Store the index of the blank (0)

  def is_goal(self, goal_state: list[int]) -> bool:
    """
    Check if the current state is the goal state
    """
    return self.state == goal_state

  def get_neighbors(self) -> list[tuple['PuzzleState', str]]:
    """
    Generate and return a list of neighbor puzzle states by moving the blank tile
    Each neighbor is returned as a tuple containing:
    - The new PuzzleState after the move
    - A string representing the move action ("up", "down", "left", "right")
    """
    neighbors = []
    # Define possible moves as (row_change, col_change, move_name)
    moves = [
        (-1,0,'U'),
        (1,0,'D'),
        (0,-1,'L'),
        (0,1,'R')
    ]
    row = self.zero_index // 3 # Determine the row by dividing the index by 3
    col = self.zero_index % 3  # Determine the column by taking the remainder when the index is divided by 3

    # Compute new postion of the blank when move
    for dr, dc, move in moves:
      new_row = row + dr
      new_col = col + dc
      # Check if the move is within the bounds of the puzzle
      if 0 <= new_row < 3 and 0 <= new_col < 3:
        new_index = new_row * 3 + new_col
        # Create a new state by swapping the blank
        new_state = self.state[:]
        # Swap the blank tile with the tile in the new position
        new_state[self.zero_index], new_state[new_index] = new_state[new_index],new_state[self.zero_index]
        neighbor_state = PuzzleState(new_state)
        # Check for auto swap between 1 <-> 3 or 2 <-> 4
        s = neighbor_state.state
        def swap_if_adjacent(a, b):
            idx_a, idx_b = s.index(a), s.index(b)
            if abs(idx_a - idx_b) == 1 or abs(idx_a - idx_b) == 3:
                s[idx_a], s[idx_b] = s[idx_b], s[idx_a]
        swap_if_adjacent(1, 3)
        swap_if_adjacent(2, 4)
        neighbors.append((neighbor_state, move))
    return neighbors

  def print_state(self):
    for i in range(0,9,3):
      print(self.state[i:i+3])
    print()

  def __eq__(self, other):
    return self.state == other.state

  def __hash__(self):
    return hash(tuple(self.state))

  @staticmethod
  def is_solvable(state: list[int], goal_states: list[list[int]]) -> tuple[bool, list[int]]:
    """
    Check if the puzzle state is solvable with respect to at least one goal state.
    Returns a tuple (is_solvable, selected_goal_state).
    """
    for goal_state in goal_states:
      state_without_zero = [x for x in state if x != 0]
      inversions = 0
      for i in range(len(state_without_zero)):
        for j in range(i + 1, len(state_without_zero)):
          if state_without_zero[i] > state_without_zero[j]:
            inversions += 1
      # Find the blank position (row, start from 1)
      blank_row_state = (state.index(0) // 3) + 1
      blank_row_goal = (goal_state.index(0) // 3) + 1
      row_diff = abs(blank_row_state - blank_row_goal)
      # Check if the number of inversions is even
      if (inversions % 2) == (row_diff % 2):
        return True, goal_state
    return False, None

#Node Class

In [3]:
#Node Class
class Node:
  def __init__(self, puzzle, g, h, parent=None, move=None):
    """
    initialize a Node with:
    - puzzle: the current puzzle state
    - g: cost from the start node to this node (number of moves)
    - h: heuristic cost from this node to the goal
    - parent: the parent node (for backtracking the solution path)
    - move: the move taken to reach this node from the parent
    """
    self.puzzle = puzzle
    self.g = g
    self.h = h
    self.f = g + h
    self.parent = parent
    self.move = move

  def get_f(self) -> int:
        """
        Calculate and return the f value for the node: f = g + h.
        """
        return self.g + self.h

  def __lt__(self, other):
    """
    Less-than comparision for priority queue, based on f value
    """
    return self.get_f() < other.get_f()



#A* Search Algorithm

In [4]:
#A* Solver
class AStarSolver:
  def __init__(self, heuristic):
    """
    Initialize the A* solver with a given heuristic function
    """
    self.heuristic = heuristic


  def solve(self, start_puzzle, current_goal):
    print(f"Selected goal state for this puzzle:")
    for i in range(0, 9, 3):
      print(current_goal[i:i+3])
    print()

    start_node = Node(start_puzzle, 0, self.heuristic(start_puzzle, current_goal))
    open_list = []
    heapq.heappush(open_list, (start_node.get_f(), id(start_node), start_node))
    closed_set = set()

    while open_list:
      _, _, current_node = heapq.heappop(open_list)
      current_state = tuple(current_node.puzzle.state)

      if current_node.puzzle.is_goal(current_goal):
        return current_node

      closed_set.add(current_state)

      for neighbor, move in current_node.puzzle.get_neighbors():
        neighbor_state = tuple(neighbor.state)
        if neighbor_state in closed_set:
            continue
        new_g = current_node.g + 1
        new_h = self.heuristic(neighbor, current_goal)
        neighbor_node = Node(neighbor, new_g, new_h, current_node, move)
        heapq.heappush(open_list, (neighbor_node.get_f(), id(neighbor_node), neighbor_node))

    return None



#Heuristic Functions

In [5]:
#Heuristic Functions
def misplaced_tiles_heuristic(puzzle: 'PuzzleState', goal_state: list[int]) -> int:
  """
  Count the number of tiles that are out of place from the goal state

  """
  return sum(1 for i in range(9) if puzzle.state[i] != 0 and puzzle.state[i] != goal_state[i])

def manhattan_distance_heuristic(puzzle: 'PuzzleState', goal_state: list[int]) -> int:
  """
  Calculate the Manhattan distance heuristic for the puzzle
  Sum the distances of each tile from its current position to its position in the first goal state
  """
  distance = 0
  for i, tile in enumerate(puzzle.state):
    if tile != 0:
      goal_index = goal_state.index(tile)
      current_row, current_col = i//3, i%3
      goal_row, goal_col = goal_index // 3, goal_index % 3
      distance += abs(current_row - goal_row) + abs(current_col - goal_col)
  return distance


#Puzzle Experiment


In [6]:
#Puzzle Experiment
class PuzzleExperiment:
  def __init__(self, numtrials: int = 5):
    self.numtrials = numtrials
    self.goal_states = GOAL_STATES

  def random_puzzle(self) -> PuzzleState:
    """
    Generate a random 8-puzzle state
    """
    tiles = list(range(9))
    random.shuffle(tiles)
    return PuzzleState(tiles)

  def run_experiment(self):
    """
    Run the experiment using different heuristic functions
    For each heuristic, solve a number of random puzzles and print the results
    """
    results = {}
    heuristics = {
        'Misplaced': misplaced_tiles_heuristic,
        'Manhattan': manhattan_distance_heuristic
    }
    for hname, heuristic in heuristics.items():
      print(f"\n--- Using heuristic: {hname} ---")
      solver = AStarSolver(heuristic)
      total_moves = 0
      successful_trials = 0
      for trial in range(self.numtrials):
        print(f"Trial #{trial+1}, puzzle state:")

        # Generate a solvable puzzle
        puzzle = None
        selected_goal = None
        while True:
          tiles = list(range(9))
          random.shuffle(tiles)
          solvable, selected_goal = PuzzleState.is_solvable(tiles, self.goal_states)
          if solvable:
            puzzle = PuzzleState(tiles)
            break

        puzzle.print_state()
        goal_node = solver.solve(puzzle, selected_goal)
        if goal_node:

          print(f"Solved in {goal_node.g} moves")
          total_moves += goal_node.g
          successful_trials += 1


      if successful_trials > 0:
        avg_moves = total_moves / successful_trials
        print(f"Average moves to solve: {avg_moves}")
        results[hname] = avg_moves
      else:
        print("No successful trials")
        results[hname] = None
    return results


#Search Tree

In [7]:
#Search Tree
class SearchTreeVisualizer:
  def __init__(self):
    pass

  def puzzle_to_multiline(self, state: PuzzleState) -> str:
    lines = []
    for i in range(0, 9, 3):
      row = state.state[i:i+3]
      row_str = ''.join(str(x) if x != 0 else '_' for x in row)
      lines.append(row_str)
    return "\n".join(lines)

  def illustrate_search_tree(self, start_puzzle: PuzzleState, n: int):
    root = Node(start_puzzle, g=0, h=0)
    queue = deque([root])
    adjacency = {}
    visited = set([root.puzzle])
    expansions = 0

    while queue and expansions < n:
      current_node = queue.popleft()
      adjacency[current_node] = []
      for neighbor_state, move in current_node.puzzle.get_neighbors():
        if neighbor_state not in visited:
          visited.add(neighbor_state)
          child_node = Node(neighbor_state, current_node.g + 1, 0,
                          parent=current_node, move = move)
          adjacency[current_node].append(child_node)
          queue.append(child_node)
      expansions += 1
    print("\n=== Text-based Search Tree Illustration (up to", n, "expansions) ===")
    queue = deque([(root, 0)])
    visited_nodes = set()
    while queue:
      node, level = queue.popleft()
      if node in visited_nodes:
        continue
      visited_nodes.add(node)

      indent = ' ' * level * 2
      print(f"{indent}Node: {self.puzzle_to_str(node.puzzle)} (g={node.g})")
      children = adjacency.get(node, [])
      for child in children:
        print(f"{indent}  [{child.move}] -> {self.puzzle_to_str(child.puzzle)} (g={child.g})")
        queue.append((child, level + 1))

  def puzzle_to_str(self, state: PuzzleState) -> str:
    return ''.join(str(x) if x != 0 else '_' for x in state.state)

  def illustrate_tree_graph(self, start_puzzle: PuzzleState, n: int) -> 'Digraph':
    dot = Digraph()
    root = Node(start_puzzle, g=0, h=0)
    node_ids = {}
    id_counter = 0
    node_ids[root] = str(id_counter)
    dot.node(str(id_counter), self.puzzle_to_multiline(root.puzzle))

    queue = deque([root])
    expansions = 0
    while queue and expansions < n:
      current_node = queue.popleft()
      for neighbor, move in current_node.puzzle.get_neighbors():
        child_node = Node(neighbor, current_node.g + 1, 0,
                        current_node, move)
        if child_node not in node_ids:
          id_counter += 1
          node_ids[child_node] = str(id_counter)
          dot.node(str(id_counter), self.puzzle_to_multiline(child_node.puzzle))
        dot.edge(node_ids[current_node], node_ids[child_node], label=move)
        queue.append(child_node)
      expansions += 1
    return dot


#Main

In [8]:
#Main
def plot_comparison_chart(results: dict):
  heuristics = list(results.keys())
  avg_costs = list(results.values())
  plt.bar(heuristics, avg_costs)
  plt.ylabel("Average Path Cost (moves)")
  plt.title("A* Heuristic Comparison on 8-Puzzle")
  plt.grid(True, axis='y')
  plt.tight_layout()
  plt.show()
def main():
  print("Welcome to the 8-Puzzle solver with A* search!")
  print("1) Provide your own puzzle tiles")
  print("2) Generate a random puzzle")
  choice = input("Enter your choice (1 or 2): ")

  if choice == '1':
    # User provides 9 integers
    print("Please enter 9 integers (0 represents the blank). Example: 1 2 3 4 5 6 7 8 0")
    user_tiles = list(map(int, input("Tiles: ").split()))
    puzzle = PuzzleState(user_tiles)
    solvable, selected_goal = PuzzleState.is_solvable(user_tiles, GOAL_STATES)
    if not solvable:
      print("This puzzle is not solvable with any of the given goal states.")
      return

  else:
    # Generate random puzzle
    while True:
      tiles = list(range(9))
      random.shuffle(tiles)
      solvable, selected_goal = PuzzleState.is_solvable(tiles, GOAL_STATES)
      if solvable:
        break
    puzzle = PuzzleState(tiles)
    print("Random puzzle generated:")
    puzzle.print_state()
    print(f"Selected goal state for this puzzle:")
    for i in range(0, 9, 3):
        print(selected_goal[i:i+3])
    print()

  n = int(input("Enter the number of nodes to illustrate in the search tree: "))

  visualizer = SearchTreeVisualizer()
  visualizer.illustrate_search_tree(puzzle, n)
  dot = visualizer.illustrate_tree_graph(puzzle, n)
  print("\nGraphviz source for the search tree")
  print(dot.source)

  dot.format = 'png'
  dot.render('search_tree', view=False)
  display(Image(filename='search_tree.png'))

  # Run A* search using both heuristics
  heuristics = {
      "Misplaced Tiles": misplaced_tiles_heuristic,
      "Manhattan Distance": manhattan_distance_heuristic
  }
  for name, heuristic in heuristics.items():
    print(f"\n=== Running A* with {name} heuristic ===")
    solver = AStarSolver(heuristic)
    goal_node = solver.solve(puzzle, selected_goal)
    if goal_node:
      print("Puzzle solved!")
      print("Number of moves (g) = ", goal_node.g)
      # Reconstruct the solution path
      moves = []
      current = goal_node
      while current.parent is not None:
        moves.append(current.move)
        current = current.parent
      moves.reverse()
      print("Solution path of moves:", moves)
    else:
      print("No solution found with this heuristic")

  # Experiment for evaluating heuristics
  print("\n=== Running Experiment to Evaluate Heuristics ===")
  experiment = PuzzleExperiment(numtrials=5)
  results = experiment.run_experiment()
  plot_comparison_chart(results)

if __name__ == "__main__":
  main()

Welcome to the 8-Puzzle solver with A* search!
1) Provide your own puzzle tiles
2) Generate a random puzzle
Random puzzle generated:
[1, 6, 5]
[0, 2, 4]
[7, 3, 8]

Selected goal state for this puzzle:
[1, 2, 3]
[4, 5, 6]
[7, 8, 0]


=== Text-based Search Tree Illustration (up to 30 expansions) ===
Node: 165_24738 (g=0)
  [U] -> _65142738 (g=1)
  [D] -> 165742_38 (g=1)
  [R] -> 1652_4738 (g=1)
  Node: _65142738 (g=1)
    [R] -> 6_5124738 (g=2)
  Node: 165742_38 (g=1)
    [R] -> 1657243_8 (g=2)
  Node: 1652_4738 (g=1)
    [U] -> 1_5264738 (g=2)
    [D] -> 1652347_8 (g=2)
    [L] -> 165_42738 (g=2)
    [R] -> 16542_738 (g=2)
    Node: 6_5124738 (g=2)
      [D] -> 6251_4738 (g=3)
      [R] -> 65_142738 (g=3)
    Node: 1657243_8 (g=2)
      [U] -> 1657_4328 (g=3)
      [R] -> 16574238_ (g=3)
    Node: 1_5264738 (g=2)
      [L] -> _15264738 (g=3)
      [R] -> 15_264738 (g=3)
    Node: 1652347_8 (g=2)
      [L] -> 165234_78 (g=3)
      [R] -> 16523478_ (g=3)
    Node: 165_42738 (g=2)
      [U

ExecutableNotFound: failed to execute WindowsPath('dot'), make sure the Graphviz executables are on your systems' PATH