In [None]:
from pocket_cube.cube import Cube
from pocket_cube.cube import Move
import tests
import numpy as np
from heapq import heappush, heappop

%matplotlib notebook


 # A*

In [None]:
def is_solved(cube: Cube) -> bool:
    """Report whether ``cube`` is in its goal configuration.

    Compares ``cube.state`` with ``cube.goal_state`` position by position over
    the first ``len(cube.state)`` entries, stopping at the first mismatch.

    Returns:
        ``True`` when no compared position differs, otherwise ``False``.
    """
    return not any(
        cube.state[pos] != cube.goal_state[pos] for pos in range(len(cube.state))
    )

In [None]:
def get_neighbors(cube: Cube) -> list[tuple[Cube, Move]]:
    """List every state reachable from ``cube`` with a single move.

    Each entry pairs the cube returned by ``cube.move(m)`` with the move ``m``
    that produced it, in the iteration order of ``Move``.
    """
    neighbors: list[tuple[Cube, Move]] = []
    for candidate in Move:
        neighbors.append((cube.move(candidate), candidate))
    return neighbors

In [None]:
def heuristic(cube: Cube) -> int:
    """Estimate the distance to the goal as the number of positions out of place.

    ``cube.state`` and ``cube.goal_state`` are passed through ``np.asarray`` so the
    comparison is always element-wise, even if either is a plain Python sequence
    (``list != list`` would otherwise produce a single bool and a count of 0 or 1).

    NOTE(review): a single face turn may relocate several stickers at once, so this
    count can overestimate the number of moves remaining (i.e. it is presumably not
    admissible). A* guided by it is then not guaranteed to return a shortest path.

    Returns:
        The number of positions where ``state`` differs from ``goal_state``, as a
        Python ``int`` (matching the annotation rather than a NumPy scalar).
    """
    mismatched = np.asarray(cube.state) != np.asarray(cube.goal_state)
    return int(np.count_nonzero(mismatched))

In [None]:
def _reconstruct_path(discovered: dict[str, tuple], goal_key: str) -> list[Move]:
    """Follow parent links from ``goal_key`` back to the start; return the moves in order.

    ``discovered`` maps a state key to ``(parent key, move from parent, g)``; the
    start state is the one whose parent key is ``None``.
    """
    path: list[Move] = []
    parent_key, move, _ = discovered[goal_key]
    while parent_key is not None:
        path.append(move)
        parent_key, move, _ = discovered[parent_key]
    path.reverse()
    return path


def a_star(cube: Cube) -> list[Move]:
    """Find a move sequence that solves ``cube`` using A* search.

    States are identified by ``Cube.hash()`` and ordered by ``f = g + h``, where
    ``g`` is the number of moves from the start and ``h`` is ``heuristic``.
    The goal test (``is_solved``) is applied when a state is popped.

    Args:
        cube: the starting cube. A clone is queued, so the argument itself is
            not stored in the search structures.

    Returns:
        The moves to apply, in order, to bring ``cube`` to a solved state
        (empty if ``cube`` is already solved).

    Raises:
        ValueError: if every reachable state is explored without finding a
            solved one.
    """
    start_key = cube.hash()
    # Heap entries are (f, state key, g, cube). Entries for the same key are only
    # pushed when g strictly improves, so they never tie on (f, key) and the Cube
    # objects themselves are never compared.
    frontier: list[tuple[int, str, int, Cube]] = []
    heappush(frontier, (heuristic(cube), start_key, 0, cube.clone()))
    # state key -> (parent state key, move that led here, best known g)
    discovered: dict[str, tuple] = {start_key: (None, None, 0)}

    while frontier:
        _, current_key, current_g, current = heappop(frontier)
        if current_g > discovered[current_key][2]:
            # Stale entry: a cheaper route to this state was queued after it.
            continue
        if is_solved(current):
            return _reconstruct_path(discovered, current_key)

        next_g = current_g + 1
        for neighbor, move in get_neighbors(current):
            neighbor_key = neighbor.hash()
            known = discovered.get(neighbor_key)
            if known is None or next_g < known[2]:
                discovered[neighbor_key] = (current_key, move, next_g)
                heappush(
                    frontier,
                    (next_g + heuristic(neighbor), neighbor_key, next_g, neighbor.clone()),
                )

    # Previously the last popped (unsolved) state's path was returned here.
    raise ValueError("A* exhausted the search space without reaching a solved state")

In [None]:
# Each test case is a string of whitespace-separated move tokens understood by
# Move.from_str; split() (not split(" ")) tolerates doubled/trailing whitespace.
test_cases = [tests.case1, tests.case2, tests.case3, tests.case4]
scrambles: list[list[Move]] = [
    [Move.from_str(token) for token in case.split()] for case in test_cases
]

# Numbering starts at 1 so "Test N" lines up with tests.caseN.
failed_cases: list[int] = []
for case_no, scramble in enumerate(scrambles, start=1):
    cube: Cube = Cube(scramble)
    path: list[Move] = a_star(cube)
    for move in path:
        cube = cube.move(move)
    if is_solved(cube):
        print(f"Test {case_no} passed ({len(path)} moves)")
    else:
        print(f"Test {case_no} failed")
        failed_cases.append(case_no)

# Report every case above, then fail loudly so Restart & Run All surfaces regressions.
assert not failed_cases, f"A* failed on test case(s): {failed_cases}"