In [1]:
import torch
import tkinter as tk
import chess

from src.agent.model import ChessNet, AlphaGoNet
from src.agent.mcts_agent import MCTSAgent
from src.env.chess import ChessEnv

In [2]:
input_shape = (12, 8, 8)
output_shape = (4096,1)
n_res_blocks = 19

model = AlphaGoNet(input_shape, output_shape, n_res_blocks)
model.load_state_dict(torch.load("best_model.pth"))

  model.load_state_dict(torch.load("best_model.pth"))


<All keys matched successfully>

In [3]:
# Constants for board drawing.
SQUARE_SIZE = 60
BOARD_SIZE = SQUARE_SIZE * 8

# Unicode mapping for chess pieces.
piece_unicode = {
    chess.PAWN:   {"white": "\u2659", "black": "\u265F"},
    chess.KNIGHT: {"white": "\u2658", "black": "\u265E"},
    chess.BISHOP: {"white": "\u2657", "black": "\u265D"},
    chess.ROOK:   {"white": "\u2656", "black": "\u265C"},
    chess.QUEEN:  {"white": "\u2655", "black": "\u265B"},
    chess.KING:   {"white": "\u2654", "black": "\u265A"},
}

class ChessGUI:
    def __init__(self, master, model, human_color="white"):
        self.master = master
        master.title("Chess: Human (Black) vs MCTS Agent (White)")
        self.canvas = tk.Canvas(master, width=BOARD_SIZE, height=BOARD_SIZE)
        self.canvas.pack()

        # Human's color: now set to "black".
        self.human_color = human_color.lower()

        # Initialize the chess environment.
        self.env = ChessEnv()
        # Initialize the MCTS agent with a chosen number of iterations.
        self.agent = MCTSAgent(model , iterations=100)

        # Reset environment and get the initial state.
        self.state = self.env.reset()

        # Variable to keep track of the human-selected square.
        self.selected_square = None

        # Bind mouse click events on the canvas.
        self.canvas.bind("<Button-1>", self.on_canvas_click)

        # Draw the initial board.
        self.draw_board()
        
        # if self.env.board.turn == chess.WHITE and self.human_color != "white":
        #     self.master.after(500, self.agent_move)

    def draw_board(self):
        """Draws board squares and pieces onto the canvas."""
        self.canvas.delete("all")
        colors = ["#F0D9B5", "#B58863"]  # Light and dark square colors.
        for row in range(8):
            for col in range(8):
                color = colors[(row + col) % 2]
                x1 = col * SQUARE_SIZE
                y1 = row * SQUARE_SIZE
                x2 = x1 + SQUARE_SIZE
                y2 = y1 + SQUARE_SIZE
                self.canvas.create_rectangle(x1, y1, x2, y2, fill=color, tags="square")
        # Highlight a selected square if one is chosen.
        if self.selected_square is not None:
            row, col = self.selected_square
            x1 = col * SQUARE_SIZE
            y1 = row * SQUARE_SIZE
            x2 = x1 + SQUARE_SIZE
            y2 = y1 + SQUARE_SIZE
            self.canvas.create_rectangle(x1, y1, x2, y2, outline="red", width=3, tags="highlight")
        # Draw pieces using Unicode.
        board = self.env.board
        for square in chess.SQUARES:
            piece = board.piece_at(square)
            if piece is not None:
                row = 7 - chess.square_rank(square)  # Row 0 corresponds to rank 8.
                col = chess.square_file(square)
                x = col * SQUARE_SIZE + SQUARE_SIZE // 2
                y = row * SQUARE_SIZE + SQUARE_SIZE // 2
                color_str = "white" if piece.color == chess.WHITE else "black"
                symbol = piece_unicode[piece.piece_type][color_str]
                self.canvas.create_text(x, y, text=symbol, font=("Arial", 32), tags="piece")

    def on_canvas_click(self, event):
        """Handles clicks on the board. First click selects a piece; second click selects destination."""
        col = event.x // SQUARE_SIZE
        row = event.y // SQUARE_SIZE
        # Convert canvas coordinates to a chess square.
        square = chess.square(col, 7 - row)
        board = self.env.board

        # Check if it's the human's turn.
        if (board.turn == chess.WHITE and self.human_color == "white") or \
           (board.turn == chess.BLACK and self.human_color == "black"):
            if self.selected_square is None:
                # No piece selected yet: ensure the clicked square contains one of your pieces.
                piece = board.piece_at(square)
                if piece is not None and ((piece.color == chess.WHITE and self.human_color == "white") or 
                                          (piece.color == chess.BLACK and self.human_color == "black")):
                    self.selected_square = (row, col)
                    self.draw_board()
            else:
                # A square is already selected; try to make a move.
                start_row, start_col = self.selected_square
                start_square = chess.square(start_col, 7 - start_row)
                move = chess.Move(start_square, square)
                if move in board.legal_moves:
                    # Make the move.
                    self.state, reward, done, _ = self.env.step(move.uci())
                    self.selected_square = None
                    self.draw_board()
                    self.master.update()
                    if done:
                        self.show_result()
                        return
                    # After your move, schedule the MCTS agent's move.
                    self.master.after(500, self.agent_move)
                else:
                    # Illegal move: reset selection.
                    self.selected_square = None
                    self.draw_board()

    def agent_move(self):
        """Gets the move from the MCTS agent and applies it."""
        move = self.agent.select_move(self.state)
        if move is not None:
            self.state, reward, done, _ = self.env.step(move)
            self.draw_board()
            if done:
                self.show_result()
        else:
            self.show_result()

    def show_result(self):
        """Displays the game result on the board."""
        result = self.env.board.result()  # "1-0", "0-1", or "1/2-1/2"
        if result == "1-0":
            winner = "White"
        elif result == "0-1":
            winner = "Black"
        else:
            winner = "Draw"
        result_text = f"Game Over: {result} - {winner} wins!" if winner != "Draw" else "Game Over: Draw!"
        self.canvas.create_text(BOARD_SIZE // 2, BOARD_SIZE // 2, text=result_text,
                                font=("Arial", 24), fill="red", tags="result")

if __name__ == "__main__":
    root = tk.Tk()
    # Set the human to play as black.
    app = ChessGUI(root, model , human_color="white")
    root.mainloop()

0.16172102093696594
-0.16172102093696594
0.16172102093696594
0.22991129755973816
-0.22991129755973816
0.22991129755973816
0.11536160856485367
-0.11536160856485367
0.11536160856485367
0.2453034669160843
-0.2453034669160843
0.2453034669160843
0.03549225255846977
-0.03549225255846977
0.03549225255846977
0.12011145055294037
-0.12011145055294037
0.12011145055294037
0.11444325745105743
-0.11444325745105743
0.11444325745105743
0.10025770217180252
-0.10025770217180252
0.10025770217180252
0.32510215044021606
-0.32510215044021606
0.32510215044021606
0.03975614160299301
-0.03975614160299301
0.03975614160299301
0.401065468788147
-0.401065468788147
0.401065468788147
0.25474417209625244
-0.25474417209625244
0.25474417209625244
0.29223519563674927
-0.29223519563674927
0.29223519563674927
-0.1269197016954422
0.1269197016954422
-0.1269197016954422
0.1568818986415863
-0.1568818986415863
0.1568818986415863
-0.03701610118150711
0.03701610118150711
-0.03701610118150711
-0.2506036162376404
0.250603616237640