#### Check if the model is stable and reliable

In [None]:
import os
from collections import OrderedDict

import chess
import chess.svg
import hydra
import rootutils
import torch
from omegaconf import OmegaConf

# Find the project root (the directory marked by a ".project-root" file), searching from the
# current working directory, and with pythonpath=True add it to the import path so that the
# `from src.utils...` import below resolves. Must run before that import.
rootutils.setup_root(os.path.abspath("."), indicator=".project-root", pythonpath=True)

from src.utils.chess_utils import ChessBoard, ChessData, ChessGame, ChessMove


def get_correct_state_dict(state_dict):
    """Strip the leading ``net.`` prefix from every state-dict key.

    Only a *leading* ``net.`` is removed. A plain substring replace would also remove
    ``net.`` occurring inside a key, e.g. turning ``net.resnet.layer1.weight`` into
    ``reslayer1.weight``. Keys that do not start with the prefix are kept unchanged.

    Args:
        state_dict: Mapping from parameter/buffer names to their values.

    Returns:
        A new dict with the same values, in the same order, under the renamed keys.
    """
    prefix = "net."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def get_tempfix_for_torch(ckpt):
    """Return ``ckpt["state_dict"]`` with every ``_orig_mod.`` segment removed from its keys.

    TODO(mai0313): drop this workaround once pytorch issue #101107 is resolved.

    A model trained under ``torch.compile`` saves its weights with ``_orig_mod.`` inserted
    into the parameter names, which an uncompiled model does not expect; removing that
    segment restores the plain names.

    Ref: https://discuss.pytorch.org/t/how-to-save-load-a-model-with-torch-compile/179739/2
         https://github.com/pytorch/pytorch/issues/101107#issuecomment-1542688089

    Args:
        ckpt: Checkpoint mapping that contains a ``"state_dict"`` entry.

    Returns:
        An ``OrderedDict`` holding the renamed keys in their original order.
    """
    marker = "_orig_mod."
    return OrderedDict(
        (param_name.replace(marker, ""), value)
        for param_name, value in ckpt["state_dict"].items()
    )


def load_model_from_path(log_directory):
    """Rebuild the model of a training run and load the weights of its last checkpoint.

    Args:
        log_directory: Run directory containing ``.hydra/config.yaml`` (the config the run
            was launched with) and ``checkpoints/last.ckpt``.

    Returns:
        The model instantiated from ``config.model`` with the checkpoint weights loaded,
        switched to eval mode for inference.
    """
    ckpt_path = os.path.join(log_directory, "checkpoints", "last.ckpt")
    model_config = OmegaConf.load(os.path.join(log_directory, ".hydra", "config.yaml"))
    model_instance = hydra.utils.instantiate(model_config.model)

    # map_location="cpu": a checkpoint saved from GPU can then be read on a CPU-only machine;
    # load_state_dict copies the values into the model's existing parameters regardless.
    # weights_only=False: Lightning checkpoints may contain non-tensor objects, which the
    # safe-unpickling default of recent PyTorch rejects. Only load checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    # Runs trained with torch.compile store keys containing "_orig_mod."; strip it so the
    # freshly instantiated (uncompiled) model accepts them. A missing key means "not compiled".
    if model_config.model.get("compile", False):
        state_dict = get_tempfix_for_torch(checkpoint)
    else:
        state_dict = checkpoint["state_dict"]

    # Load in place: LightningModule.load_from_checkpoint is a classmethod that returns a
    # *new* model, so calling it on this instance and discarding the result loads nothing.
    model_instance.load_state_dict(state_dict)
    # Inference only: disable training-time behaviour such as dropout / batch-norm updates.
    model_instance.eval()
    return model_instance

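#### Enter the paths of your training run directories

- Each run directory must contain `.hydra/config.yaml` and `checkpoints/last.ckpt`, e.g. `log_directory_v1 = '../logs/chess_md2/runs/2020-05-20_15-00-00'`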

In [None]:
# Run directories to load models from (model 1 and model 2).
# NOTE: both currently point at the same run, so the "model 1 vs model 2" cell pits a model
# against itself; set log_directory_v2 to another run to compare two different models.
log_directory_v1 = "../logs/chess_md2/runs/2023-10-01_22-54-38"
log_directory_v2 = log_directory_v1

#### Self-Play

- AI vs AI

In [None]:
# AI vs AI: load model 1 and start a self-play game with the GUI enabled.
model_instance = load_model_from_path(log_directory_v1)
ChessGame(
    white_model=model_instance,
    gpu=True,
).self_play(gui=True)

#### Model vs Stockfish

- Plays your model against Stockfish so you can check whether it can beat the engine

In [None]:
# Model 1 vs Stockfish. The Stockfish binary path and cpu_nums are machine-specific;
# adjust them for your setup.
model_instance = load_model_from_path(log_directory_v1)
ChessGame(model_instance, True).model_vs_stockfish(
    gui=True,
    stockfish_path="../stockfish_linux/stockfish-ubuntu-x86-64-avx2",
    cpu_nums=8,
)

#### Self-Play Version 2

- Plays model 1 (`log_directory_v1`) against model 2 (`log_directory_v2`)

In [None]:
# Load both models (model 1 first, then model 2) and play them against each other.
model_instance, model_instance_2 = (
    load_model_from_path(log_directory_v1),
    load_model_from_path(log_directory_v2),
)

# NOTE(review): unlike the other cells, ChessGame receives False as its second argument here;
# confirm that this is intended.
ChessGame(model_instance, False).model_vs_model(model_instance_2, True)

#### Play with AI

- You vs. AI

In [None]:
# Human vs model 1, played through the GUI.
model_instance = load_model_from_path(log_directory_v1)
ChessGame(model_instance, True).play_against_ai(
    gui=True,
)

#### Solve a chess puzzle

- Given a board position (FEN), let the model find the best move for the side to move.

In [None]:
model_instance = load_model_from_path(log_directory_v1)

# FEN fields: piece placement / side to move ("b" = black) / castling / en passant / clocks.
# NOTE(review): this placement contains no black king, so python-chess considers the position
# invalid, and it is black to move. Replace the FEN with the intended puzzle.
board = chess.Board("8/1K6/8/1P6/2rpP3/2P5/8/8 b - - 0 1")
if not board.is_valid():
    # Surface problems such as a missing king instead of silently searching an illegal position.
    print(f"Warning: puzzle position is not valid ({board.status()!r})")
best_move = ChessGame(model_instance, True).solve_puzzle(board=board, gui=True)
# Prints "The move recommended by the model is: <move>".
print(f"模型推薦的移動是：{best_move}")