In [None]:
import torch
import chess
import chess.svg
from IPython.display import SVG, display, clear_output
from environment import BulletChessEnv
from agent_dqn import BulletChessDQNAgent

def display_board_and_timers(env):
    """Redraw the current position and both clocks in the notebook output.

    Clears the previous cell output (``wait=True`` defers the clear until the
    new output arrives), renders ``env.game_state.board`` as a 400px SVG, then
    prints each side's remaining time converted with ``int()``.
    """
    state = env.game_state
    clear_output(wait=True)
    display(SVG(chess.svg.board(state.board, size=400)))
    white_secs = int(state.white_time)
    black_secs = int(state.black_time)
    print(f"White time: {white_secs}s | Black time: {black_secs}s")

# Checkpoint to play against; change this to try a different training snapshot.
MODEL_PATH = "models/checkpoint_100.pth"


def load_agent_weights(agent, path):
    """Load DQN weights from ``path`` into ``agent``.

    First tries ``agent.load(path)``. If that raises ``KeyError`` (this notebook
    previously failed with ``KeyError: 'q_network_state_dict'``, i.e. the file was
    not saved in the layout ``agent.load`` expects), the checkpoint is read with
    ``torch.load`` and the network weights are looked up under a few common key
    names, or the whole file is used if it is a bare ``state_dict``.

    Args:
        agent: the agent to populate.
        path: checkpoint file path.

    Raises:
        KeyError: if no usable weights are found. The message lists the keys the
            checkpoint actually contains, so the save/load mismatch can be fixed.
    """
    try:
        agent.load(path)
    except KeyError as err:
        load_error = err
    else:
        return

    # Key names under which training scripts commonly store the network weights.
    weight_keys = ("q_network_state_dict", "model_state_dict", "state_dict", "q_network")
    checkpoint = torch.load(path, map_location="cpu")
    state_dict = None
    if isinstance(checkpoint, dict):
        for key in weight_keys:
            if isinstance(checkpoint.get(key), dict):
                state_dict = checkpoint[key]
                break
        else:
            # A bare state_dict (torch.save(net.state_dict(), path)) maps
            # parameter names directly to tensors.
            if checkpoint and all(isinstance(v, torch.Tensor) for v in checkpoint.values()):
                state_dict = checkpoint

    if state_dict is None:
        found = list(checkpoint) if isinstance(checkpoint, dict) else type(checkpoint).__name__
        raise KeyError(
            f"No network weights found in {path!r} (contains: {found}); "
            f"agent.load failed with {load_error!r}"
        ) from load_error

    # NOTE(review): assumes the agent exposes its online network as
    # `agent.q_network` (suggested by the missing 'q_network_state_dict' key),
    # and that exploration settings (e.g. epsilon) need no restoring for play.
    # TODO confirm against agent_dqn.BulletChessDQNAgent.
    agent.q_network.load_state_dict(state_dict)


env = BulletChessEnv()
agent = BulletChessDQNAgent()
load_agent_weights(agent, MODEL_PATH)

# Interactive game: the human plays White, the DQN agent plays Black.
obs = env.reset()
done = False
info = {}
# Text printed *under* the board after each redraw. display_board_and_timers()
# begins with clear_output(), so anything printed before the redraw (the start
# banner, the previous "Agent played: ..." line) would be wiped unread.
status = "Game start! Enter moves in UCI format (e.g., e2e4)."

while not done:
    display_board_and_timers(env)
    if status:
        print(status)
        status = ""

    board = env.game_state.board
    if board.turn == chess.WHITE:
        # Validate against the actual board instead of treating reward == -1 as
        # "illegal": a legal move can also score -1, and the old loop would then
        # re-prompt on a finished game or let the human move for Black.
        while True:
            move_uci = input("Your move (UCI): ").strip()
            try:
                move = chess.Move.from_uci(move_uci)
            except ValueError:  # malformed UCI text
                print(f"Error: {move_uci!r} is not a UCI move. Try again.")
                continue
            if move in board.legal_moves:
                break
            print("Illegal move, try again.")
        # NOTE(review): passes a UCI string to env.step, as the original call did;
        # if the env rejects it without moving, the outer loop simply re-prompts
        # because it is still White's turn. Confirm against BulletChessEnv.step.
        obs, reward, done, info = env.step(move.uci())
    else:
        obs = env.get_observation()
        legal_actions = env.get_legal_actions()
        action = agent.select_action(obs, legal_actions)
        # Decode before stepping: the action was chosen for the current position,
        # and the decoder may depend on board state (e.g. side to move).
        agent_move = env._action_to_move(action)
        obs, reward, done, info = env.step(action)
        status = f"Agent played: {agent_move.uci()}"

# Show the final position, plus the agent's last move if it ended the game.
display_board_and_timers(env)
if status:
    print(status)
reason = (info or {}).get("reason")
if reason == "time":
    print("Game over: Time out!")
elif reason == "checkmate":
    print("Game over: Checkmate!")
else:
    print("Game over.")


KeyError: 'q_network_state_dict'