In [1]:
import numpy as np
from sortedcontainers import SortedDict

In [56]:
def board_hash(board):
    """Serialise an array into a string by concatenating ``str()`` of each cell.

    Cells are visited in the order produced by ``board.ravel()``, so arrays
    with equal contents and shape always produce the same string.
    """
    cell_strings = [str(cell) for cell in board.ravel()]
    return ''.join(cell_strings)

def get_simple_symmetry(board, rot):
    """Return ``board`` transformed by the symmetry selected by ``rot``.

    ``rot`` codes:
      1, 2, 3 -> ``np.rot90`` with k = 1, 2, -1
      4       -> ``np.transpose``
      5, 6, 7 -> ``np.transpose`` followed by ``np.rot90`` with k = 1, 2, -1
    Any other value returns ``board`` itself, untouched.
    """
    # Each row: (rot code, transpose first?, quarter turns passed to np.rot90).
    symmetry_table = (
        (1, False, 1),
        (2, False, 2),
        (3, False, -1),
        (4, True, 0),
        (5, True, 1),
        (6, True, 2),
        (7, True, -1),
    )
    for code, transpose_first, quarter_turns in symmetry_table:
        if rot == code:
            base = np.transpose(board) if transpose_first else board
            if quarter_turns == 0:
                return base
            return np.rot90(base, k=quarter_turns)
    return board

def get_symmetry_index(index, rot):
    """Map a 3x3 cell ``index`` to its position after symmetry ``rot``.

    A 3x3 zero grid is marked with a 1 at ``index``, transformed with
    ``get_simple_symmetry``, and the coordinates of the first cell equal to 1
    are returned as a tuple.
    """
    marker = np.zeros((3, 3), dtype=np.int8)
    marker[index] = 1
    transformed = get_simple_symmetry(marker, rot)
    hits = zip(*np.nonzero(transformed == 1))
    return next(hits)

def encode_index(index):
    """Flatten a ``(row, col)`` pair into the single number ``row * 3 + col``."""
    row, col = index[0], index[1]
    flat = row * 3
    flat = flat + col
    return flat

def get_canonical_hash(board, subgame):
    """Return a symmetry-independent string key for a position.

    The position is the 4-D ``board`` plus an optional ``subgame`` index into a
    3x3 grid. Eight variants are generated: the identity and three successive
    quarter turns, then the same four starting from the transposed position.
    Each variant is hashed as ``board_hash(board variant)`` followed by
    ``board_hash(subgame marker variant)``, and the lexicographically largest
    string is returned.

    A board quarter turn rotates axes (2, 3) and then axes (0, 1); the board
    transpose swaps axis 0 with 1 and axis 2 with 3. The 3x3 subgame marker is
    rotated/transposed in step so both halves of each string describe the same
    variant. When ``subgame`` is None the marker is all zeros.
    """
    def turn_board(grid):
        # Inner call turns axes (2, 3); outer call uses rot90's default axes (0, 1).
        return np.rot90(np.rot90(grid, axes=(2, 3)))

    board_variants = []
    for start in (board, np.transpose(board, axes=(1, 0, 3, 2))):
        current = start
        board_variants.append(current)
        for _ in range(3):
            current = turn_board(current)
            board_variants.append(current)

    marker = np.zeros((3, 3), dtype=np.int8)
    if subgame is None:
        # An all-zero marker looks the same under every symmetry.
        marker_variants = [marker] * 8
    else:
        marker[subgame] = 1
        marker_variants = []
        for start in (marker, np.transpose(marker)):
            current = start
            marker_variants.append(current)
            for _ in range(3):
                current = np.rot90(current)
                marker_variants.append(current)

    return max(
        board_hash(board_variant) + board_hash(marker_variant)
        for board_variant, marker_variant in zip(board_variants, marker_variants)
    )

In [32]:
def is_game_win(board, play):
    """Report whether a 3x3 ``board`` has a completed line through ``play``.

    ``play`` is a ``(y, x)`` pair. The row ``y`` and column ``x`` are always
    checked; the main diagonal only when ``y == x`` and the anti-diagonal only
    when ``y + x == 2``. A line counts when its three cells are equal and
    non-zero. Returns True or False.
    """
    row, col = play

    candidate_lines = [
        ((row, 0), (row, 1), (row, 2)),
        ((0, col), (1, col), (2, col)),
    ]
    if row == col:
        candidate_lines.append(((0, 0), (1, 1), (2, 2)))
    if row + col == 2:
        candidate_lines.append(((0, 2), (1, 1), (2, 0)))

    for first, second, third in candidate_lines:
        if board[first] == board[second] == board[third] != 0:
            return True
    return False

In [63]:
def gen(a):
    """Yield the same value ``a`` forever.

    Used to pad a ``zip`` with a constant column; the surrounding ``zip`` stops
    as soon as one of its finite inputs is exhausted.
    """
    # Local import keeps this cell self-contained; itertools.repeat is the
    # standard-library form of the endless "yield a" loop.
    from itertools import repeat

    yield from repeat(a)

def get_valid_moves(board, subgame):
    """Return a lazy iterator over index tuples of empty (zero) cells.

    With ``subgame`` set, only ``board[subgame]`` is searched and every
    yielded tuple is prefixed with ``subgame[0], subgame[1]``, so it indexes
    the full ``board``. With ``subgame`` None, every zero cell of ``board``
    is yielded.
    """
    if subgame is not None:
        empty_cells = np.nonzero(board[subgame] == 0)
        # gen() repeats the fixed sub-board coordinates for every empty cell.
        return zip(gen(subgame[0]), gen(subgame[1]), *empty_cells)

    empty_cells = np.nonzero(board == 0)
    return zip(*empty_cells)

In [61]:
def explore(by_depth, player):
    """Expand the deepest layer of ``by_depth`` by one move of ``player``.

    ``by_depth`` is a list of ``SortedDict`` layers mapping a canonical hash to
    a state tuple ``(board, macro, subgame, winning)``. A new layer is appended
    (mutating ``by_depth`` in place) and filled with every state reachable in
    one move from the non-winning states of the previous last layer. States
    that share a canonical hash are stored once; the first one found is kept.

    For each move:
      * the cell is set to ``player`` on a copy of the 4-D board;
      * if that completes a line in the sub-board, the whole sub-board is
        filled with ``player``, the macro grid is updated on a copy, and the
        next player may move anywhere (``subgame`` is None);
      * otherwise the next player is sent to the sub-board addressed by the
        cell just played, unless that sub-board has no empty cell left, in
        which case the next player may also move anywhere. Without this
        release the state would have no legal moves and the search would
        silently stop there.

    Returns None.
    """
    # Take the view of the current last layer before the new layer is appended.
    frontier = by_depth[-1].values()

    new_states = SortedDict()
    by_depth.append(new_states)

    for board, macro, subgame, winning in frontier:
        if winning:
            # Finished games are not expanded further.
            continue

        for play in get_valid_moves(board, subgame):
            macro_play = (play[0], play[1])
            micro_play = (play[2], play[3])

            new_board = board.copy()
            new_board[play] = player

            # The macro grid is shared with the parent until it has to change.
            new_macro = macro
            if is_game_win(new_board[macro_play], micro_play):
                # Filling the won sub-board leaves it with no empty cells.
                new_board[macro_play] = player
                new_macro = macro.copy()
                new_macro[macro_play] = player
                new_subgame = None
            elif (new_board[micro_play] == 0).any():
                new_subgame = micro_play
            else:
                # Target sub-board is already won (filled) or full.
                new_subgame = None

            canonical_hash = get_canonical_hash(new_board, new_subgame)
            if canonical_hash not in new_states:
                new_winning = is_game_win(new_macro, macro_play)
                new_states[canonical_hash] = (new_board, new_macro, new_subgame, new_winning)

In [64]:
# Root position: empty 3x3 grid of 3x3 sub-boards, empty macro grid,
# no sub-board constraint, game not won.
board = np.zeros((3, 3, 3, 3), dtype=np.int8)
macro = np.zeros((3, 3), dtype=np.int8)
subgame = None
winning = False
initial_state = (board, macro, subgame, winning)

# Layer 0 holds only the root, keyed by its canonical hash.
by_depth = [SortedDict({get_canonical_hash(board, subgame): initial_state})]

# Player 2 makes the first move; one layer is generated per cell of the board.
player = 2
for depth in range(board.size):
    explore(by_depth, player)

    # Report how many distinct states the new layer contains.
    print(len(by_depth[-1]))
    # Alternate between players 1 and 2.
    player = 3 - player

15
102
822
6920
58282
482954


KeyboardInterrupt: 

In [70]:
# Show the canonical hashes stored in layer 1 (iterating the dict yields its keys).
print(list(by_depth[1]))

['000000000000000000000000000000000000000020000000000000000000000000000000000000000000010000', '000000000000000000000000000000000000020000000000000000000000000000000000000000000010000000', '000000000000000000000000000000000000200000000000000000000000000000000000000000000100000000', '000000000000000020000000000000000000000000000000000000000000000000000000000000000000000010', '000000000000000200000000000000000000000000000000000000000000000000000000000000000000000100', '000000000000020000000000000000000000000000000000000000000000000000000000000000000000010000', '000000000000200000000000000000000000000000000000000000000000000000000000000000000000100000', '000000000020000000000000000000000000000000000000000000000000000000000000000000000010000000', '000000000200000000000000000000000000000000000000000000000000000000000000000000000100000000', '000000002000000000000000000000000000000000000000000000000000000000000000000000000000000001', '0000020000000000000000000000000000000000000000000000000000