In [None]:
import random

import autorootcwd
import chess
import chess.svg
import torch

from src.data.components.convert import ChessConverter
from src.models.components.mcts import MCTSNode, monte_carlo_tree_search
from src.utils.chess_utils import ChessBoard, ChessGame
from src.utils.model_loader import load_model_from_path


def mcts_play(show_gui: bool, model_instance: torch.nn.Module, iterations: int = 1000):
    """Play one complete game in which both sides choose moves by MCTS.

    Starting from the standard initial position, each turn builds a fresh
    ``MCTSNode`` root for the current position. It then asks
    ``monte_carlo_tree_search`` for a move and pushes that move onto the
    board. The board is rendered with ``ChessBoard().show`` after every move.
    The loop ends once ``board.is_game_over()`` reports a finished game.

    Args:
        show_gui: Forwarded to ``ChessBoard.show`` after every move.
        model_instance: Model handed to every ``MCTSNode`` root; the same
            model plays both White and Black.
        iterations: Second argument forwarded to ``monte_carlo_tree_search``
            for each move. It is presumably the number of search iterations
            (TODO confirm against the MCTS implementation). Defaults to 1000,
            the value that was previously hard-coded.

    Returns:
        The final ``chess.Board`` after the game has ended.
    """
    board = chess.Board()
    while not board.is_game_over():
        # White and Black share the same model and search budget, so there is
        # no need to branch on ``board.turn``; a new root is built each turn.
        # NOTE(review): the live ``board`` object is handed to MCTSNode. If the
        # search mutates it instead of copying, pass ``board.copy()`` here —
        # confirm against MCTSNode.
        root = MCTSNode(board, None, None, model_instance)
        move = monte_carlo_tree_search(root, iterations)
        board.push(move)
        ChessBoard().show(board, show_gui)
    return board


if __name__ == "__main__":
    # Load the trained model from its run/log directory.
    run_dir = "./logs/md3_alpha_zero/runs/2023-10-07_18-34-35"
    model_instance = load_model_from_path(run_dir)

    # Locations of the lc0 engine binary and the network weights it loads.
    lc0_path = "engine/lc0/lc0"
    lc0_model_path = "engine/lc0_model/t2-768x15x24h-swa-5230000.pb"

    # First game: the loaded model (passed as white_model, with gpu and mcts
    # enabled) against the lc0 engine, with the GUI shown.
    game = ChessGame(white_model=model_instance, gpu=True, mcts=True)
    game.model_vs_lc0(
        gui=True,
        lc0_path=lc0_path,
        lc0_model_path=lc0_model_path,
    )

    # Second game: MCTS self-play with the same model on both sides.
    mcts_play(show_gui=True, model_instance=model_instance)