In [None]:
from pocket_cube.cube import Cube
from pocket_cube.cube import Move
from tests import test_list, test, is_solved, TestCase, draw_graph, test_mcts
from heuristics import hamming, blocked_hamming, manhattan, build_database, database_heuristic, is_admissible
from utils import get_neighbors, get_path, met_in_the_middle, FrontierItem, DiscoveredDict

from heapq import heappush, heappop
from typing import Callable
import time


In [None]:
# A*
def a_star(cube: Cube, heuristic: Callable[[Cube], int]) -> tuple[list[Move], int]:
    """Search for a solved cube with A*, starting from `cube`.

    Args:
        cube: the starting cube.
        heuristic: estimates the remaining cost of a cube; it is added to the
            number of moves made so far to form the frontier priority.

    Returns:
        A `(path, discovered_count)` tuple. `path` is whatever `get_path`
        reconstructs for the solved cube from the discovered table, and
        `discovered_count` is the number of distinct cube hashes recorded.
        If the frontier is exhausted without popping a solved cube, `path` is
        an empty list instead of a path to an arbitrary unsolved state.
    """
    # Frontier ordered by f = g + h; the start has g = 0.
    frontier: list[FrontierItem] = []
    heappush(frontier, FrontierItem(heuristic(cube), cube))
    # hash -> (parent hash, move taken from the parent, cost g from the start)
    discovered: DiscoveredDict = {cube.hash(): (None, None, 0)}

    while frontier:
        current: Cube = heappop(frontier).cube
        current_key = current.hash()
        if is_solved(current):
            return (get_path(current_key, discovered), len(discovered))
        # Every neighbour is exactly one move further than the current cube,
        # so the tentative cost is the same for all of them.
        score: int = discovered[current_key][2] + 1
        for (neighbor, move) in get_neighbors(current):
            neighbor_key = neighbor.hash()
            # Record the neighbour when it is new or reached by a cheaper route.
            if neighbor_key not in discovered or score < discovered[neighbor_key][2]:
                discovered[neighbor_key] = (current_key, move, score)
                heappush(frontier, FrontierItem(score + heuristic(neighbor), neighbor))

    # The frontier ran dry without a solved cube being popped.
    return ([], len(discovered))

In [None]:
# Test A*
# Run a_star with the manhattan heuristic through the project's `test` helper
# on `test_list`, then pass the collected results to `draw_graph`.
test_res_astar: list[TestCase] = test(lambda cube: a_star(cube, manhattan), test_list)
draw_graph(test_res_astar)

In [None]:
# Bidirectional BFS
from collections import deque

def bidirectional_bfs(cube: Cube) -> tuple[list[Move], int]:
    """Search from `cube` and from the goal state at once until the two sides meet.

    Side 0 grows from `cube`; side 1 grows from a clone of `cube` whose state
    is set to its `goal_state`. Each round expands one cube from each side.

    Returns:
        A `(path, discovered_count)` tuple. `path` is the forward path to the
        meeting state followed by the backward path reversed with every move
        replaced by `Move.opposite`. `discovered_count` is the summed size of
        both discovered tables. If a frontier is exhausted before the sides
        meet, `path` is an empty list.
    """
    solved_cube = cube.clone()
    solved_cube.state = solved_cube.goal_state

    frontiers: list[deque[Cube]] = [deque([cube]), deque([solved_cube])]
    # Per side: hash -> (parent hash, move taken from the parent, depth)
    discovereds: list[DiscoveredDict] = [
        {cube.hash(): (None, None, 0)},
        {solved_cube.hash(): (None, None, 0)},
    ]

    # Check before the first expansion and again after every round, so the
    # states added by the final round are never left unchecked.
    met_cube_key = met_in_the_middle(discovereds[0], discovereds[1])
    while met_cube_key is None and frontiers[0] and frontiers[1]:
        for side in range(2):
            current: Cube = frontiers[side].popleft()
            current_key = current.hash()
            depth: int = discovereds[side][current_key][2] + 1
            for (neighbor, move) in get_neighbors(current):
                neighbor_key = neighbor.hash()
                # FIFO order with unit-cost moves means the first discovery of
                # a state is already at minimal depth, so only unseen states
                # need recording.
                if neighbor_key not in discovereds[side]:
                    discovereds[side][neighbor_key] = (current_key, move, depth)
                    frontiers[side].append(neighbor)
        met_cube_key = met_in_the_middle(discovereds[0], discovereds[1])

    total_discovered: int = len(discovereds[0]) + len(discovereds[1])
    if met_cube_key is None:
        # One side ran out of states without the searches ever touching.
        return ([], total_discovered)

    forward: list[Move] = get_path(met_cube_key, discovereds[0])
    # The backward path was recorded goal -> meeting state; walk it the other
    # way and undo each move to obtain meeting state -> goal.
    backward: list[Move] = [
        Move.opposite(move) for move in reversed(get_path(met_cube_key, discovereds[1]))
    ]
    return (forward + backward, total_discovered)

In [None]:
# Test Bidirectional BFS
# Run bidirectional_bfs through the project's `test` helper on `test_list`,
# then pass the collected results to `draw_graph`.
test_res_bfs: list[TestCase] = test(bidirectional_bfs, test_list)
draw_graph(test_res_bfs)

In [None]:
# MCTS with UCB
from math import sqrt, log
from typing import Any

# String keys of the dictionary that represents one search-tree node.
N = "N"                # number of times the node was visited
Q = "Q"                # sum of the rewards back-propagated through the node
PARENT = "PARENT"      # parent node, or None for the root
MOVE = "MOVE"          # not referenced by the visible cells; kept for compatibility
CHILDREN = "CHILDREN"  # dict mapping a Move to the child node it leads to
# A node is a plain dict keyed by the string constants above, with values of
# mixed types (counters, the parent node, the children dict).
Node = dict[str, Any]

def init_node(parent = None) -> Node:
    """Create an unvisited tree node linked to `parent` (None for a root).

    The node starts with zero visits, zero accumulated reward and an empty
    children dictionary.
    """
    fresh: Node = dict.fromkeys((N, Q), 0)
    fresh[PARENT] = parent
    fresh[CHILDREN] = {}
    return fresh

def select_move(node: Node, c) -> Move:
    """Return the move whose child has the highest UCB score.

    The score of a child is `Q / N + c * sqrt(log(N_parent) / N)`, where `Q`
    and `N` are the child's statistics and `N_parent` is the visit count of
    `node`. With `c == 0` this is the child's average reward. Ties keep the
    first child in iteration order; `None` is returned when `node` has no
    children.
    """
    parent_visits = node[N]
    best_move: Move = None
    best_score: float = float('-inf')
    for candidate, child in node[CHILDREN].items():
        exploitation = child[Q] / child[N]
        exploration = c * sqrt(log(parent_visits) / child[N])
        ucb = exploitation + exploration
        # Strict comparison: an equal score never displaces an earlier child.
        if ucb > best_score:
            best_move, best_score = candidate, ucb
    return best_move

In [None]:
from random import choice

def mcts(cube0: Cube, budget: int, tree: Node, cp: float, heuristic: Callable[[Cube], int],
         rollout_depth: int = 14) -> tuple[Node, int]:
    """Grow a Monte Carlo search tree rooted at `cube0`.

    Args:
        cube0: the cube state represented by the root of the tree.
        budget: number of select / expand / simulate / back-propagate rounds.
        tree: an existing root node to keep growing; a falsy value (None or an
            empty dict) starts a new tree.
        cp: exploration constant handed to `select_move` during selection.
        heuristic: cube -> estimated distance; a simulation is rewarded with
            the largest `1 / max(heuristic(cube), 0.1)` seen along it.
        rollout_depth: maximum number of random moves per simulation. The
            default of 14 is the limit that used to be hard-coded.

    Returns:
        A `(tree, states_visited)` tuple: the root node and the number of
        nodes created over all rounds.
    """
    min_h: float = 0.1  # floor on the heuristic so that 1 / h stays finite
    all_moves: list[Move] = list(Move)
    states_visited: int = 0
    if not tree:
        tree = init_node()

    for _ in range(budget):
        cube = cube0
        node = tree

        # Selection: descend while the state is unsolved and every move from
        # the current node already has a child.
        while not is_solved(cube) and all(move in node[CHILDREN] for move in all_moves):
            move: Move = select_move(node, cp)
            cube = cube.move(move)
            node = node[CHILDREN][move]

        # Expansion: add one child for a randomly chosen untried move.
        if not is_solved(cube):
            move = choice([move for move in all_moves if move not in node[CHILDREN]])
            child: Node = init_node(node)
            node[CHILDREN][move] = child
            cube = cube.move(move)
            node = child
            states_visited += 1

        # A state that is already solved here never enters the rollout loop
        # below; give it the maximum reward rather than 0, otherwise reaching
        # the goal directly would be scored as the worst possible outcome.
        best_reward: float = 1 / min_h if is_solved(cube) else 0

        # Simulation: random moves, each of which is also added to the tree as
        # a new node chained under the previous one.
        moves_left: int = rollout_depth
        while not is_solved(cube) and moves_left > 0:
            child = init_node(node)
            move = choice(all_moves)
            node[CHILDREN][move] = child
            cube = cube.move(move)
            node = child
            best_reward = max(best_reward, 1 / max(heuristic(cube), min_h))
            moves_left -= 1
            states_visited += 1

        # Back-propagation: credit every node from the last one up to the root.
        while node:
            node[N] += 1
            node[Q] += best_reward
            node = node[PARENT]

    return (tree, states_visited)

def play_mcts(cube: Cube, budget: int, cp: float, heuristic: Callable[[Cube], int]) -> (list[Move], int):
    """Build a fresh MCTS tree for `cube` and read a move sequence out of it.

    After `mcts` has run, the tree is walked from the root down to a node
    without children, each step following the child that `select_move`
    prefers with the exploration constant set to 0.

    Returns:
        A `(path, states)` tuple: the moves along that walk and the
        state count reported by `mcts`.
    """
    tree, states = mcts(cube, budget, None, cp, heuristic)
    path: list[Move] = []
    current: Node = tree
    while current and current[CHILDREN]:
        best: Move = select_move(current, 0)
        path.append(best)
        current = current[CHILDREN][best]
    return (path, states)

In [None]:
# Test MCTS
# Hand play_mcts (wrapped so the exploration constant is called `c`) and the
# list of heuristics to try to the project's `test_mcts` helper.
test_mcts(lambda cube, budget, c, heuristic: play_mcts(cube, budget, c, heuristic), [manhattan, blocked_hamming])

In [None]:
# Build database
# Time the build with the monotonic performance counter: unlike time.time(),
# it cannot jump when the system clock is adjusted, so the difference is a
# reliable elapsed duration.
start = time.perf_counter()
database = build_database()
end = time.perf_counter()
print(f"Database built in {end - start} seconds.")

In [None]:
# Test A* with database
# Same as the plain A* test, but the heuristic is database_heuristic called
# with the prebuilt `database` and manhattan as its remaining arguments.
test_result_astar_database: list[TestCase] = test(lambda cube: a_star(cube, lambda cube: database_heuristic(cube, database, manhattan)), test_list)
draw_graph(test_result_astar_database)

In [None]:
# Test MCTS with database
# The heuristics argument is passed as a list, matching the earlier test_mcts
# call that passes [manhattan, blocked_hamming]; here the list holds the single
# database-backed heuristic.
test_mcts(lambda cube, budget, c, heuristic: play_mcts(cube, budget, c, heuristic), [lambda cube: database_heuristic(cube, database, manhattan)])