#### Check if the model is stable and reliable

In [None]:
import os
from collections import OrderedDict

import chess
import chess.svg
import hydra
import IPython.display as display
import numpy as np
import rootutils
import torch
from omegaconf import OmegaConf

# set root: rootutils locates the project root via the `.project-root` indicator
# file, starting from the current working directory. Because pythonpath=True, it
# also puts that root on the Python path. Presumably this lets
# `hydra.utils.instantiate` import the project's model classes — confirm.
rootutils.setup_root(os.path.abspath("."), indicator=".project-root", pythonpath=True)


def board_to_array(board):
    """Encode ``board`` as an (8, 8, 12) float32 one-hot piece array.

    Axes:
        0: row. Row 0 holds rank 8 and row 7 holds rank 1.
        1: file (0 = a-file).
        2: piece plane. Planes 0-5 are p, n, b, r, q, k (lowercase symbols,
           i.e. black pieces). Planes 6-11 are P, N, B, R, Q, K (white pieces).

    A cell is 1.0 when that piece stands on the square and 0.0 otherwise.
    """
    planes = np.zeros((8, 8, 12), dtype=np.float32)

    # "pnbrqk" -> 0..5 and "PNBRQK" -> 6..11, the same mapping as a literal table.
    plane_of = {symbol: index for index, symbol in enumerate("pnbrqkPNBRQK")}

    for row in range(8):
        rank = 7 - row  # row 0 is the top of the diagram, i.e. rank 8
        for col in range(8):
            piece = board.piece_at(rank * 8 + col)
            if piece:
                planes[row, col, plane_of[piece.symbol()]] = 1.0
    return planes

def model_recommends_move(model, board):
    """Return the move ``choose_move`` selects for ``model`` on ``board``.

    This is a readability wrapper used by the puzzle cell; all the work happens
    in ``choose_move``.
    """
    return choose_move(model, board)

def show_and_move(board, move):
    """Play ``move`` on ``board`` (in place), then draw the new position."""
    board.push(move)  # mutates the caller's board
    show_board(board)

def choose_move(model, board):
    """Pick the legal move whose resulting position the model scores highest.

    For each legal move:
      1. The move is played on ``board``.
      2. The position is encoded with ``board_to_array`` and transposed to
         channels-first with a batch of one, giving a (1, 12, 8, 8) tensor.
      3. The tensor is scored by ``model``.
      4. The move is undone.

    ``board`` therefore ends in its original position. On ties, the first move
    encountered wins.

    NOTE(review): the score is always maximized, whichever side is to move. If
    the model's output is from White's point of view, Black will pick moves that
    are best for White. Confirm the perspective of the training target.

    Args:
        model: callable mapping a (1, 12, 8, 8) float tensor to a single-element
            output (``.item()`` is called on it).
        board: position to move from. It is modified temporarily and restored.

    Returns:
        The chosen legal move, or ``None`` if there are no legal moves.
    """
    best_move = None
    best_value = float("-inf")

    # Scoring only: no_grad avoids building an autograd graph for every candidate.
    with torch.no_grad():
        for move in board.legal_moves:
            board.push(move)
            try:
                board_array = np.transpose(board_to_array(board), (2, 0, 1))
                board_tensor = torch.tensor(board_array).float().unsqueeze(0)
                value = model(board_tensor).item()
            finally:
                # Restore the position even if encoding or scoring raises.
                board.pop()
            # The first legal move is always accepted. Previously a -1.0 sentinel
            # could reject every move (scores <= -1 or NaN) and return None.
            if best_move is None or value > best_value:
                best_value = value
                best_move = move
    return best_move


def show_board(board):
    """Draw ``board`` as a 300px SVG in the notebook, replacing the previous drawing."""
    svg_markup = chess.svg.board(board=board, size=300)
    # Clear the old board. Per IPython, wait=True postpones the clear until the
    # new output arrives, which avoids flicker.
    display.clear_output(wait=True)
    display.display(display.HTML(svg_markup))

def self_play(model, show_board_option=False):
    """Let ``model`` play both sides from a fresh ``chess.Board()`` until the game is over.

    Every move, for either color, comes from ``choose_move``. When
    ``show_board_option`` is true, the board is redrawn after each move. Nothing
    is returned.
    """
    game = chess.Board()
    while not game.is_game_over():
        game.push(choose_move(model, game))
        if show_board_option:
            show_board(game)


def _read_human_move(board):
    """Prompt until the user enters a legal SAN move for ``board`` or a quit command.

    Returns:
        The parsed move, or ``None`` if the user typed q, quit or exit.
    """
    while True:
        text = input("Enter your move: ").strip()
        if text in ("q", "quit", "exit"):
            return None
        try:
            move = board.parse_san(text)
        except ValueError:
            print("Invalid move. Please enter a valid move.")
            continue
        # parse_san turns "--", "0000", "Z0" and "@@@@" into a null move, which is
        # falsy. Reject it so the human cannot simply skip their turn.
        if not move:
            print("Invalid move. Please enter a valid move.")
            continue
        return move


def play_against_ai(model):
    """Play an interactive game in the notebook: you move first, the model replies.

    Enter moves in SAN (e.g. ``e4``, ``Nf3``). Type ``q``, ``quit`` or ``exit``
    to stop early. The board is redrawn once after every move, and the result
    is printed when the game ends.
    """
    board = chess.Board()
    show_board(board)
    while not board.is_game_over():
        human_move = _read_human_move(board)
        if human_move is None:
            return
        board.push(human_move)
        show_board(board)

        if board.is_game_over():
            break
        model_move = choose_move(model, board)
        board.push(model_move)
        # Draw first, then print. show_board clears the cell output, which used
        # to erase this message right after it was printed.
        show_board(board)
        print(f"Model's move: {model_move}")
    print(f"Game over: {board.result()}")


def get_correct_state_dict(state_dict, prefix="net."):
    """Return a copy of ``state_dict`` with a leading ``prefix`` stripped from each key.

    Only a *leading* occurrence is removed. The previous ``str.replace`` removed
    the substring everywhere and corrupted keys that merely contain it, e.g.
    ``"net.resnet.fc.weight"`` became ``"resfc.weight"``.

    Args:
        state_dict: mapping of parameter names to values.
        prefix: key prefix to strip. Defaults to ``"net."``.

    Returns:
        A new dict with the same values in the same order. Keys that do not
        start with ``prefix`` are kept unchanged.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def get_tempfix_for_torch(ckpt):
    """Return ``ckpt["state_dict"]`` with every ``_orig_mod.`` removed from its keys.

    TODO(mai0313): drop this workaround once pytorch issue #101107 is resolved.

    Ref: https://discuss.pytorch.org/t/how-to-save-load-a-model-with-torch-compile/179739/2
         https://github.com/pytorch/pytorch/issues/101107#issuecomment-1542688089

    A model trained with ``torch.compile`` saves parameter names that contain an
    extra ``_orig_mod.`` segment, which an uncompiled model does not expect.
    Every occurrence is removed, wherever it appears in a key.

    Args:
        ckpt: checkpoint mapping that holds the parameters under ``"state_dict"``.

    Returns:
        An ``OrderedDict`` with the cleaned keys, in the original order.
    """
    return OrderedDict(
        (name.replace("_orig_mod.", ""), tensor)
        for name, tensor in ckpt["state_dict"].items()
    )

#### Load The Model You Just Trained

- Set `log_directory` below to the log directory of your training run, e.g. `../logs/chess_md1/runs/2023-10-01_04-22-25`.

In [None]:
# Log directory of the training run to load (edit this).
log_directory = "../logs/chess_md1/runs/2023-10-01_18-33-08"

# no need to change
ckpt_path = f"{log_directory}/checkpoints/last.ckpt"
model_config = OmegaConf.load(f"{log_directory}/.hydra/config.yaml")
compile_option = model_config.model.compile

model_instance = hydra.utils.instantiate(model_config.model)

# map_location="cpu" lets checkpoints saved on a GPU load on CPU-only machines.
# The move-selection code builds its input tensors without moving them to a device.
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which may
# reject pickled hyper-parameters stored in the checkpoint. If so, pass
# weights_only=False (only for checkpoints you trust).
checkpoint = torch.load(ckpt_path, map_location="cpu")
if compile_option:
    # Strip the `_orig_mod.` segments that torch.compile adds to parameter names.
    fixed_state_dict = get_tempfix_for_torch(checkpoint)
else:
    fixed_state_dict = checkpoint["state_dict"]

# Load the weights into *this* instance. The previous
# `model_instance.load_from_checkpoint(ckpt_path)` discarded its result.
# Presumably model_config.model builds a Lightning module, whose
# load_from_checkpoint is a classmethod that returns a new model. As a result,
# model_instance kept its freshly initialised weights.
model_instance.load_state_dict(fixed_state_dict)

# Inference mode for layers such as dropout / batch-norm, so scoring the same
# position twice gives the same value.
model_instance.eval()

#### Self-Play

- AI vs AI

In [None]:
# AI vs AI: the model plays both sides, and the board is redrawn after every move.
self_play(model_instance, show_board_option=True)

#### Play with AI

- You vs. AI

In [None]:
# Interactive game: you move first. Enter moves in SAN (e.g. e4, Nf3), or type
# q / quit / exit to stop.
play_against_ai(model_instance)

#### Solve a chess puzzle

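- Given a board position (as a FEN string), ask the model for its best move. The example below is the standard starting position, so the move is for White.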

In [None]:
# Position to solve, as a FEN string. This one is the standard starting position,
# White to move.
board = chess.Board("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1")
best_move = model_recommends_move(model_instance, board)

if best_move is None:
    # Nothing to play (e.g. the position has no legal moves). Pushing None
    # would raise.
    print("The model did not recommend a move.")
else:
    # Draw first, then print. show_board clears the cell output, which would
    # otherwise erase the message below.
    show_and_move(board, best_move)
    # Prints "The move recommended by the model is: <move>".
    print(f"模型推薦的移動是：{best_move}")