In [1]:
import sys
import os

# The notebook lives one directory below the repository root; put that root at the
# front of sys.path so the `src.*` imports below resolve without installing the package.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

import chess
import chess.svg
from IPython.display import display, SVG, clear_output
import ipywidgets as widgets
import torch

from src.engine import Engine
from src.transformer import TransformerConfig, PositionalEncodings, TransformerDecoder, Predictor
from src.utils import MOVE_TO_ACTION
from src.tokenizer import SEQUENCE_LENGTH

In [None]:
# Load model and predictor
model_path = "checkpoint_epoch1_20250504_092552.pt" # "Checkpoint_Epoch_33770.pt"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Architecture hyper-parameters; load_state_dict below requires them to match the
# ones the checkpoint was trained with.
transformer_config = TransformerConfig(
    vocab_size=len(MOVE_TO_ACTION),
    output_size=128,
    pos_encodings=PositionalEncodings.SINUSOID,
    max_sequence_length=SEQUENCE_LENGTH + 2,
    num_heads=4,
    num_layers=4,
    embedding_dim=64,
    apply_post_ln=True,
    apply_qk_layernorm=False,
    use_causal_mask=False,
)

model = TransformerDecoder(transformer_config)
# NOTE(security): torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted source (or pass weights_only=True where the
# installed torch version and the checkpoint contents allow it).
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint['model_state'])
# The model is used for inference only: switch dropout/normalisation layers to
# evaluation behaviour so the engine's predictions are deterministic.
model.eval()

predictor = Predictor(model)
engine = Engine(predictor)

# Ask user for side
while True:
    side = input("Play as White or Black? ").strip().lower()
    if side in ["white", "black"]:
        break
    print("Please type 'White' or 'Black'.")

player_is_white = (side == "white")
# `board` is the engine's own board object, so moves pushed here are seen by the engine.
board = engine.board
game_ended = False

# Output widget to capture move display
output = widgets.Output()
display(output)

def display_board(message=None):
    """Redraw the board and the move-entry box, optionally followed by a status line.

    The current cell output is cleared first (deferred until new output arrives, to
    avoid flicker). The board is drawn from the human player's point of view, i.e.
    flipped when the player has the black pieces.

    Args:
        message: Optional text to show below the input box. When falsy, only the
            board and the input box are shown.
    """
    clear_output(wait=True)
    display(SVG(chess.svg.board(board=board, size=350, flipped=not player_is_white)))
    display(move_input)

    if not message:
        return

    # Replace whatever the output widget showed before with the new message,
    # then re-attach the widget because clear_output above removed it.
    with output:
        output.clear_output()
        print(message)
    display(output)

def make_computer_move():
    """Let the engine pick a move, play it on the shared board and redraw."""
    engine_move = engine.get_best_move()
    # SAN notation depends on the position *before* the move, so build the
    # announcement prior to pushing.
    announcement = f"Computer played: {board.san(engine_move)}"
    board.push(engine_move)
    display_board(announcement)

# Text input for player move.
# This must exist before the first board render: display_board() shows it, so creating
# it after the computer's opening move raised NameError when the user played Black.
move_input = widgets.Text(
    description="Your move:",
    placeholder='e.g., e4, Nf3, exd5',
)

if player_is_white:
    # Human moves first: just show the starting position.
    display_board()
else:
    # Computer opens; make_computer_move() renders the board together with the
    # "Computer played" message, so no extra redraw is needed (it would erase it).
    make_computer_move()

def _end_game_if_over():
    """Announce the result and lock further input once the board reports game over.

    Returns:
        True if the game has ended (result shown, `game_ended` set), else False.
    """
    global game_ended
    if not board.is_game_over():
        return False
    display_board("Game Over: " + board.result())
    game_ended = True
    return True


def on_enter(change):
    """Handle a submitted move: validate it, play it, then let the computer reply.

    Args:
        change: Object passed by the submit callback; its `.value` holds the text
            the user typed (with `on_submit` this is the Text widget itself).

    Input that python-chess cannot parse, or that is not legal in the current
    position, is rejected with a message and the position is left untouched.
    """
    if game_ended:
        return

    move_text = change.value.strip()
    move_input.value = ""

    # parse_san() only ever returns legal moves and raises ValueError otherwise, so a
    # separate `move in board.legal_moves` check can never fail. Catch just that
    # error so genuine failures inside the engine are not mislabelled as bad input.
    try:
        board.parse_san(move_text)
    except ValueError as err:
        # python-chess >= 1.10 raises IllegalMoveError for well-formed but illegal
        # SAN; older versions only raise plain ValueError.
        illegal_move_error = getattr(chess, "IllegalMoveError", None)
        if illegal_move_error is not None and isinstance(err, illegal_move_error):
            display_board("Illegal move. Try again.")
        else:
            display_board("Invalid move format. Try again.")
        return

    # Presumably applies the SAN move to the shared engine board — the game-over
    # check right after relies on that; verify against Engine.human_play.
    engine.human_play(move_text)

    if _end_game_if_over():
        return

    make_computer_move()
    _end_game_if_over()

# Trigger move on Enter key: on_enter is registered as the submit callback of the
# text box and reads the typed text from its argument's `.value`.
# NOTE(review): ipywidgets 8 deprecates Text.on_submit in favour of
# observe(..., 'value') with continuous_update=False. Switching would change the
# callback argument (a change dict instead of the widget) — confirm the installed
# version before migrating.
move_input.on_submit(on_enter)