# Imports

In [15]:
import chess
import chess.svg
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
import os
from IPython.display import display, SVG, clear_output
import gymnasium as gym
from gymnasium import spaces
import copy
import time

In [38]:
import warnings
warnings.filterwarnings('ignore')

## Setting up our device

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


# Board encoding

In [40]:
# Convert board to 896D tensor (8x8x14: 6 piece types x 2 colors + 2 castling rights).
def board_to_tensor(board):
    planes = np.zeros((8, 8, 14), dtype=np.float32)
    piece_map = {chess.PAWN: 0, chess.KNIGHT: 1, chess.BISHOP: 2,
                 chess.ROOK: 3, chess.QUEEN: 4, chess.KING: 5}
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece:
            color_offset = 0 if piece.color == chess.WHITE else 6
            piece_idx = piece_map[piece.piece_type] + color_offset
            row, col = divmod(square, 8)
            planes[row, col, piece_idx] = 1
    planes[:, :, 12] = 1 if board.has_kingside_castling_rights(chess.WHITE) else 0
    planes[:, :, 13] = 1 if board.has_queenside_castling_rights(chess.WHITE) else 0
    return planes.flatten()

# Setting our environment

In [41]:
class ChessEnv(gym.Env):
    # Initialize chess board and Gym spaces.
    def __init__(self):
        super(ChessEnv, self).__init__()
        self.board = chess.Board()
        self.action_space = spaces.Discrete(4096)
        self.observation_space = spaces.Box(low=0, high=1, shape=(896,), dtype=np.float32)
        self.max_steps = 1000
        self.step_count = 0
        self._legal_moves_cache = None

    # Reset board to initial position.
    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.board.reset()
        self.step_count = 0
        self._legal_moves_cache = None
        return self._get_obs(), {}

    # Execute a move, return observation, reward, done, truncated, info.
    def step(self, action):
        if self._legal_moves_cache is None:
            self._legal_moves_cache = list(self.board.legal_moves)
        legal_moves = self._legal_moves_cache

        if not legal_moves:
            return self._get_obs(), -1, True, False, {}

        move = self._index_to_move(action, legal_moves)
        reward = 0
        done = False
        truncated = False

        if move in legal_moves:
            self.board.push(move)
            self.step_count += 1
            if self.board.is_checkmate():
                reward = 1
                done = True
            elif self.board.is_capture(self.board.peek()):
                reward = 0.1
            done = done or self.board.is_game_over()  # Stalemate, draw, etc.
            truncated = self.step_count >= self.max_steps
        else:
            reward = -0.1  # Penalize invalid moves

        self._legal_moves_cache = None  # Clear cache for next step
        return self._get_obs(), reward, done, truncated, {}

    def render(self, mode="human"):
        display(SVG(chess.svg.board(self.board, size=400)))

    # Get current board state as tensor.
    def _get_obs(self):
        return board_to_tensor(self.board)

    # Map action index to chess move, handle invalid inputs.
    def _index_to_move(self, idx, legal_moves):
        try:
            idx = int(idx)
            return legal_moves[idx % len(legal_moves)]
        except (TypeError, ValueError) as e:
            print(f"Error in _index_to_move: idx={idx}, type={type(idx)}, error={e}")
            return legal_moves[0]

In [5]:
env = ChessEnv()
print("Chess environment initialized")

Chess environment initialized


# Neural Network Architecture

In [6]:
# Multi-head network for Q-values, policy logits, and state values.
class ChessNet(nn.Module):
    def __init__(self, action_size=4096):
        super(ChessNet, self).__init__()
        # Shared feature extractor: 896D input -> 256D features.
        self.feature_extractor = nn.Sequential(
            nn.Linear(896, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU()
        )
        # For DQN: Q-values
        self.q_head = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, action_size)
        )
        # For PPO/A3C: Policy and Value
        self.policy_head = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, action_size)
        )
        self.value_head = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x, return_value=False):
        features = self.feature_extractor(x)
        q_values = self.q_head(features)
        if return_value:  # PPO/A3C
            policy_logits = self.policy_head(features)
            value = self.value_head(features)
            return policy_logits, value
        return q_values  # DQN

    # AlphaZero: Policy + Value
    def forward_alpha_zero(self, x):  
        features = self.feature_extractor(x)
        policy_logits = self.policy_head(features)
        value = self.value_head(features)
        return policy_logits, value


In [7]:
net = ChessNet().to(device)
print(f"Network initialized with ~{sum(p.numel() for p in net.parameters())} parameters")

Network initialized with ~1451009 parameters


# DQN Algorithm

In [16]:
class DQN:
    def __init__(self, env, lr=0.001, gamma=0.99, buffer_size=10000):
        self.env = env
        self.policy_net = ChessNet().to(device) # Main Q-network.
        self.target_net = ChessNet().to(device) # Frozen target for stability.
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.gamma = gamma
        self.replay_buffer = deque(maxlen=buffer_size)
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.batch_size = 128  # GPU-friendly batch size.

    # Choose action: Epsilon-greedy policy.
    def select_action(self, state):
        if random.random() < self.epsilon:
            return random.randrange(self.env.action_space.n)
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
        with torch.no_grad():
            q_values = self.policy_net(state_tensor)
        return q_values.argmax().item()

    def train(self, num_timesteps=20000):
        t = 0
        episode = 0
        start_time = time.time()
        while t < num_timesteps:
            state, _ = self.env.reset()
            done = truncated = False
            episode_steps = 0
            episode_start = time.time()

            # Play one episode
            while not (done or truncated) and t < num_timesteps:
                action = self.select_action(state)
                next_state, reward, done, truncated, _ = self.env.step(action)
                self.replay_buffer.append((state, action, reward, next_state, done))
                state = next_state
                episode_steps += 1
                t += 1

                # Update network if enough experiences.
                if len(self.replay_buffer) >= self.batch_size:
                    batch = random.sample(self.replay_buffer, self.batch_size)
                    states = torch.tensor([s for s, _, _, _, _ in batch], dtype=torch.float32).to(device)
                    actions = torch.tensor([a for _, a, _, _, _ in batch], dtype=torch.long).to(device)
                    rewards = torch.tensor([r for _, _, r, _, _ in batch], dtype=torch.float32).to(device)
                    next_states = torch.tensor([ns for _, _, _, ns, _ in batch], dtype=torch.float32).to(device)
                    dones = torch.tensor([d for _, _, _, _, d in batch], dtype=torch.float32).to(device)

                    q_values = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze()
                    with torch.no_grad():
                        next_q_values = self.target_net(next_states).max(1)[0]
                    targets = rewards + self.gamma * next_q_values * (1 - dones)
                    loss = nn.MSELoss()(q_values, targets)

                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                if t % 1000 == 0:
                    self.target_net.load_state_dict(self.policy_net.state_dict())
                    elapsed = time.time() - start_time
                    print(f"DQN timestep {t}/{num_timesteps}, epsilon {self.epsilon:.3f}, time {elapsed:.1f}s")

            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
            episode += 1
            episode_time = time.time() - episode_start
            print(f"DQN episode {episode}, steps {episode_steps}, time {episode_time:.1f}s")

# PPO Algorithm

In [18]:
class PPO:
    def __init__(self, env, lr=0.0003, gamma=0.99, clip=0.2):
        self.env = env
        self.net = ChessNet().to(device) # Policy and value network.
        self.optimizer = optim.Adam(self.net.parameters(), lr=lr)
        self.gamma = gamma
        self.clip = clip
        self.batch_size = 128

    # Sample action from policy, return probability.
    def select_action(self, state):
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
        with torch.no_grad():
            logits, _ = self.net(state_tensor, return_value=True)
            probs = torch.softmax(logits, dim=-1)
        action = torch.multinomial(probs, 1).item()
        return action, probs[0, action].item()

    def train(self, num_timesteps=20000, rollout_size=1024, epochs=5):
        t = 0
        episode = 0
        start_time = time.time()
        while t < num_timesteps:
            states, actions, rewards, log_probs, values = [], [], [], [], []
            state, _ = self.env.reset()
            episode_steps = 0
            episode_start = time.time()

            # Collect rollout
            for _ in range(rollout_size):
                if t >= num_timesteps:
                    break
                action, prob = self.select_action(state)
                next_state, reward, done, truncated, _ = self.env.step(action)
                state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
                with torch.no_grad():
                    _, value = self.net(state_tensor, return_value=True)
                
                states.append(state)
                actions.append(action)
                rewards.append(reward)
                log_probs.append(np.log(prob + 1e-10))
                values.append(value.item())
                state = next_state
                episode_steps += 1
                t += 1

                if done or truncated:
                    state, _ = self.env.reset()
                    episode += 1
                    episode_time = time.time() - episode_start
                    print(f"PPO episode {episode}, steps {episode_steps}, time {episode_time:.1f}s")
                    episode_steps = 0
                    episode_start = time.time()

            # Process rollout
            states_tensor = torch.tensor(states, dtype=torch.float32).to(device)
            actions_tensor = torch.tensor(actions, dtype=torch.long).to(device)
            rewards_tensor = torch.tensor(rewards, dtype=torch.float32).to(device)
            old_log_probs = torch.tensor(log_probs, dtype=torch.float32).to(device)
            old_values = torch.tensor(values, dtype=torch.float32).to(device)

            # Compute returns
            returns = torch.zeros_like(rewards_tensor)
            R = 0
            for i in reversed(range(len(rewards))):
                R = rewards[i] + self.gamma * R
                returns[i] = R

            # Update policy
            for _ in range(epochs):
                logits, values = self.net(states_tensor, return_value=True)
                probs = torch.softmax(logits, dim=-1)
                new_log_probs = torch.log(probs.gather(1, actions_tensor.unsqueeze(1)).squeeze() + 1e-10)
                advantages = returns - old_values

                ratio = torch.exp(new_log_probs - old_log_probs)
                surr1 = ratio * advantages
                surr2 = torch.clamp(ratio, 1 - self.clip, 1 + self.clip) * advantages
                policy_loss = -torch.min(surr1, surr2).mean()
                value_loss = nn.MSELoss()(values.squeeze(), returns)
                loss = policy_loss + 0.5 * value_loss

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            if t % 1000 == 0:
                elapsed = time.time() - start_time
                print(f"PPO timestep {t}/{num_timesteps}, time {elapsed:.1f}s")


# A3C Algorithm (Single-Threaded)

In [19]:
class A3C:
    def __init__(self, env, lr=0.0001, gamma=0.99):
        self.env = env
        self.net = ChessNet().to(device) # Shared actor-critic network.
        self.optimizer = optim.Adam(self.net.parameters(), lr=lr)
        self.gamma = gamma
        self.batch_size = 128

    # Sample action, return probability.
    def select_action(self, state):
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
        with torch.no_grad():
            logits, _ = self.net(state_tensor, return_value=True)
            probs = torch.softmax(logits, dim=-1)
        action = torch.multinomial(probs, 1).item()
        return action, probs[0, action].item()

    def train(self, num_timesteps=20000):
        t = 0
        episode = 0
        start_time = time.time()
        while t < num_timesteps:
            states, actions, rewards, log_probs = [], [], [], []
            state, _ = self.env.reset()
            done = truncated = False
            episode_steps = 0
            episode_start = time.time()

            # Collect experience batch.
            while not (done or truncated) and t < num_timesteps:
                action, prob = self.select_action(state)
                next_state, reward, done, truncated, _ = self.env.step(action)
                states.append(state)
                actions.append(action)
                rewards.append(reward)
                log_probs.append(np.log(prob + 1e-10))
                state = next_state
                episode_steps += 1
                t += 1

                # Update network per batch.
                if len(states) >= self.batch_size or done or truncated:
                    states_tensor = torch.tensor(states, dtype=torch.float32).to(device)
                    actions_tensor = torch.tensor(actions, dtype=torch.long).to(device)
                    rewards_tensor = torch.tensor(rewards, dtype=torch.float32).to(device)
                    log_probs_tensor = torch.tensor(log_probs, dtype=torch.float32).to(device)

                    logits, values = self.net(states_tensor, return_value=True)
                    returns = torch.zeros_like(rewards_tensor)
                    R = 0
                    for i in reversed(range(len(rewards))):
                        R = rewards[i] + self.gamma * R
                        returns[i] = R

                    advantages = returns - values.squeeze()
                    policy_loss = -(log_probs_tensor * advantages.detach()).mean()
                    value_loss = nn.MSELoss()(values.squeeze(), returns)
                    loss = policy_loss + 0.5 * value_loss

                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                    states, actions, rewards, log_probs = [], [], [], []

                if done or truncated:
                    state, _ = self.env.reset()
                    episode += 1
                    episode_time = time.time() - episode_start
                    print(f"A3C episode {episode}, steps {episode_steps}, time {episode_time:.1f}s")
                    episode_steps = 0
                    episode_start = time.time()

            if t % 1000 == 0:
                elapsed = time.time() - start_time
                print(f"A3C timestep {t}/{num_timesteps}, time {elapsed:.1f}s")


# AlphaZero (MCTS-guided policy/value learning) -- Simplified

In [30]:
class MCTSNode:
    def __init__(self, state, parent=None, action=None):
        self.state = state # Chess board state.
        self.parent = parent
        self.action = action # Move leading to this node.
        self.children = {} # Child nodes (moves).
        self.visits = 0
        self.value = 0

class AlphaZero:
    def __init__(self, env, c_puct=1.0, num_simulations=10):
        self.env = env
        self.net = ChessNet().to(device) # Policy/value network.
        self.optimizer = optim.Adam(self.net.parameters(), lr=0.001)
        self.c_puct = c_puct # Exploration constant.
        self.num_simulations = num_simulations # MCTS iterations.

    # Run MCTS to select best move.
    def mcts(self, root_state):
        root = MCTSNode(root_state)
        for _ in range(self.num_simulations):
            node = root
            state = copy.deepcopy(root_state)
            env_copy = ChessEnv()
            env_copy.board = state

            # Selection: Choose best child by UCB.
            legal_moves = list(state.legal_moves)
            while node.children and all(a in node.children for a in legal_moves):
                node = max(node.children.values(), key=lambda n: n.value / (n.visits + 1) + self.c_puct * np.sqrt(node.visits) / (n.visits + 1))
                action_idx = list(state.legal_moves).index(node.action)
                env_copy.step(action_idx)  # Pass integer index
                state = env_copy.board

            # Expansion: Add new child node.
            legal_moves = list(env_copy.board.legal_moves)
            if legal_moves and not env_copy.board.is_game_over():
                action = random.choice(legal_moves)
                action_idx = list(env_copy.board.legal_moves).index(action)
                next_state, reward, done, _, _ = env_copy.step(action_idx)
                if action not in node.children:
                    node.children[action] = MCTSNode(env_copy.board, node, action)
                node = node.children[action]

            # Simulation: Estimate value with network.
            state_tensor = torch.tensor(board_to_tensor(state), dtype=torch.float32).unsqueeze(0).to(device)
            with torch.no_grad():
                logits, value = self.net.forward_alpha_zero(state_tensor)
            value = value.item()

            # Backpropagation: Update node stats.
            while node:
                node.visits += 1
                node.value += value if node.parent else -value
                node = node.parent

        if not root.children:
            print("Warning: No MCTS children generated")
            return random.choice(legal_moves) if legal_moves else None
        return max(root.children.items(), key=lambda x: x[1].visits)[0]

    def train(self, num_episodes=150):
        episode = 0
        start_time = time.time()
        while episode < num_episodes:
            states, actions, rewards = [], [], []
            state, _ = self.env.reset()
            done = truncated = False
            episode_steps = 0
            episode_start = time.time()

            # Play one episode with MCTS.
            while not (done or truncated):
                move = self.mcts(self.env.board)
                if move is None:
                    print("MCTS returned None; skipping episode")
                    break
                try:
                    action_idx = list(self.env.board.legal_moves).index(move)
                except ValueError:
                    print(f"Invalid move {move}; skipping")
                    break

                next_state, reward, done, truncated, _ = self.env.step(action_idx)
                states.append(board_to_tensor(self.env.board))
                actions.append(action_idx)
                rewards.append(reward)
                state = next_state
                episode_steps += 1

                # Update network with episode data.
                if done or truncated:
                    states_tensor = torch.tensor(states, dtype=torch.float32).to(device)
                    actions_tensor = torch.tensor(actions, dtype=torch.long).to(device)
                    rewards_tensor = torch.tensor(rewards, dtype=torch.float32).to(device)

                    logits, values = self.net.forward_alpha_zero(states_tensor)
                    probs = torch.softmax(logits, dim=-1)
                    log_probs = torch.log(probs.gather(1, actions_tensor.unsqueeze(1)).squeeze() + 1e-10)
                    returns = torch.zeros_like(rewards_tensor)
                    R = 0
                    for i in reversed(range(len(rewards))):
                        R = rewards[i] + 0.99 * R
                        returns[i] = R

                    policy_loss = -(log_probs * (returns - values.squeeze()).detach()).mean()
                    value_loss = nn.MSELoss()(values.squeeze(), returns)
                    loss = policy_loss + value_loss

                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                    episode += 1
                    episode_time = time.time() - episode_start
                    print(f"AlphaZero episode {episode}/{num_episodes}, steps {episode_steps}, time {episode_time:.1f}s")
                    episode_steps = 0
                    episode_start = time.time()

                    if episode >= num_episodes:
                        break

        elapsed = time.time() - start_time
        print(f"AlphaZero total time {elapsed:.1f}s")


# Training and Saving agents

In [17]:
dqn_agent = DQN(env)

print("Starting DQN training")
dqn_agent.train(num_timesteps=20000)
torch.save(dqn_agent.policy_net.state_dict(), "dqn_model.pth")

Starting DQN training


  states = torch.tensor([s for s, _, _, _, _ in batch], dtype=torch.float32).to(device)


DQN episode 1, steps 308, time 4.5s
DQN episode 2, steps 333, time 5.9s
DQN timestep 1000/20000, epsilon 0.990, time 16.6s
DQN episode 3, steps 643, time 11.1s
DQN episode 4, steps 396, time 6.9s
DQN timestep 2000/20000, epsilon 0.980, time 33.8s
DQN episode 5, steps 339, time 5.8s
DQN episode 6, steps 317, time 5.2s
DQN episode 7, steps 245, time 4.2s
DQN timestep 3000/20000, epsilon 0.966, time 50.9s
DQN episode 8, steps 426, time 7.5s
DQN episode 9, steps 151, time 2.6s
DQN episode 10, steps 359, time 6.3s
DQN episode 11, steps 234, time 4.1s
DQN timestep 4000/20000, epsilon 0.946, time 68.3s
DQN episode 12, steps 537, time 9.5s
DQN episode 13, steps 472, time 8.2s
DQN episode 14, steps 89, time 1.5s
DQN timestep 5000/20000, epsilon 0.932, time 85.8s
DQN episode 15, steps 358, time 6.3s
DQN episode 16, steps 370, time 6.5s
DQN timestep 6000/20000, epsilon 0.923, time 103.5s
DQN episode 17, steps 555, time 9.8s
DQN episode 18, steps 377, time 6.7s
DQN episode 19, steps 263, time 4.6s

In [22]:
ppo_agent = PPO(env)

print("Starting PPO training")
ppo_agent.train(num_timesteps=20000)
torch.save(ppo_agent.net.state_dict(), "ppo_model.pth")

Starting PPO training
PPO episode 1, steps 485, time 0.5s
PPO episode 2, steps 518, time 0.4s
PPO episode 3, steps 107, time 0.1s
PPO episode 4, steps 460, time 0.4s
PPO episode 5, steps 343, time 0.3s
PPO episode 6, steps 292, time 0.3s
PPO episode 7, steps 483, time 0.4s
PPO episode 8, steps 586, time 0.5s
PPO episode 9, steps 428, time 0.4s
PPO episode 10, steps 264, time 0.2s
PPO episode 11, steps 357, time 0.3s
PPO episode 12, steps 253, time 0.2s
PPO episode 13, steps 536, time 0.5s
PPO episode 14, steps 502, time 0.5s
PPO episode 15, steps 179, time 0.2s
PPO episode 16, steps 227, time 0.2s
PPO episode 17, steps 59, time 0.1s
PPO episode 18, steps 149, time 0.1s
PPO episode 19, steps 259, time 0.2s
PPO episode 20, steps 377, time 0.3s
PPO episode 21, steps 461, time 0.4s
PPO episode 22, steps 121, time 0.1s
PPO episode 23, steps 403, time 0.3s
PPO episode 24, steps 440, time 0.4s
PPO episode 25, steps 245, time 0.2s
PPO episode 26, steps 306, time 0.3s
PPO episode 27, steps 332,

In [23]:
a3c_agent = A3C(env)

print("Starting A3C training")
a3c_agent.train(num_timesteps=20000)
torch.save(a3c_agent.net.state_dict(), "a3c_model.pth")

Starting A3C training
A3C episode 1, steps 533, time 0.7s
A3C episode 2, steps 237, time 0.2s
A3C episode 3, steps 291, time 0.3s
A3C episode 4, steps 370, time 0.3s
A3C episode 5, steps 273, time 0.2s
A3C episode 6, steps 457, time 0.4s
A3C episode 7, steps 456, time 0.4s
A3C episode 8, steps 266, time 0.3s
A3C episode 9, steps 359, time 0.3s
A3C episode 10, steps 498, time 0.4s
A3C episode 11, steps 575, time 0.5s
A3C episode 12, steps 496, time 0.4s
A3C episode 13, steps 341, time 0.3s
A3C episode 14, steps 348, time 0.3s
A3C episode 15, steps 239, time 0.2s
A3C episode 16, steps 362, time 0.3s
A3C episode 17, steps 411, time 0.3s
A3C episode 18, steps 389, time 0.3s
A3C episode 19, steps 459, time 0.4s
A3C episode 20, steps 396, time 0.3s
A3C episode 21, steps 357, time 0.3s
A3C episode 22, steps 507, time 0.4s
A3C episode 23, steps 493, time 0.4s
A3C episode 24, steps 466, time 0.3s
A3C episode 25, steps 286, time 0.2s
A3C episode 26, steps 85, time 0.1s
A3C episode 27, steps 349,

In [31]:
alpha_zero_agent = AlphaZero(env)

print("Starting AlphaZero training")
alpha_zero_agent.train(num_episodes=150)
torch.save(alpha_zero_agent.net.state_dict(), "alpha_zero_model.pth")

Starting AlphaZero training
AlphaZero episode 1/150, steps 199, time 1.6s
AlphaZero episode 2/150, steps 485, time 5.2s
AlphaZero episode 3/150, steps 49, time 0.4s
AlphaZero episode 4/150, steps 479, time 5.2s
AlphaZero episode 5/150, steps 415, time 4.3s
AlphaZero episode 6/150, steps 479, time 5.1s
AlphaZero episode 7/150, steps 463, time 4.8s
AlphaZero episode 8/150, steps 495, time 5.5s
AlphaZero episode 9/150, steps 441, time 4.8s
AlphaZero episode 10/150, steps 276, time 2.7s
AlphaZero episode 11/150, steps 598, time 7.5s
AlphaZero episode 12/150, steps 376, time 4.0s
AlphaZero episode 13/150, steps 362, time 3.8s
AlphaZero episode 14/150, steps 381, time 3.9s
AlphaZero episode 15/150, steps 281, time 2.6s
AlphaZero episode 16/150, steps 266, time 2.6s
AlphaZero episode 17/150, steps 110, time 0.9s
AlphaZero episode 18/150, steps 538, time 6.2s
AlphaZero episode 19/150, steps 324, time 3.5s
AlphaZero episode 20/150, steps 590, time 7.1s
AlphaZero episode 21/150, steps 501, time 

# Gameplay loop

In [39]:
# Loading trained models
dqn_net = ChessNet().to(device)
dqn_net.load_state_dict(torch.load("dqn_model.pth"))
ppo_net = ChessNet().to(device)
ppo_net.load_state_dict(torch.load("ppo_model.pth"))
a3c_net = ChessNet().to(device)
a3c_net.load_state_dict(torch.load("a3c_model.pth"))
alpha_zero_net = ChessNet().to(device)
alpha_zero_net.load_state_dict(torch.load("alpha_zero_model.pth"))

# Mapping agents
agents = {
    "DQN": dqn_agent,
    "PPO": ppo_agent,
    "A3C": a3c_agent,
    "AlphaZero": alpha_zero_agent
}

# Widget setup
agent_dropdown = widgets.Dropdown(
    options=["DQN", "PPO", "A3C", "AlphaZero"],
    value="PPO",
    description="Agent:"
)

color_dropdown = widgets.Dropdown(
    options=[("White", chess.WHITE), ("Black", chess.BLACK)],
    value=chess.WHITE,
    description="Your Color:"
)

start_button = widgets.Button(
    description="Start Game",
    button_style="success",
    tooltip="Click to start the game"
)

move_dropdown = widgets.Dropdown(
    options=[],
    description="Your Move:",
    disabled=True
)

submit_button = widgets.Button(
    description="Submit Move",
    button_style="info",
    tooltip="Submit your move",
    disabled=True
)

abort_button = widgets.Button(
    description="Abort Game",
    button_style="danger",
    tooltip="Stop the current game",
    disabled=True
)

last_move_label = widgets.Label(
    value="Last move: None",
    layout={'width': '200px'}
)

output = widgets.Output()

# Game state
game_state = {
    "board": None,
    "env": None,
    "agent": None,
    "human_color": None,
    "obs": None,
    "move_count": 0,
    "active": False,
    "last_move": None
}

# Update board display
def update_board():
    with output:
        clear_output()
        if game_state["board"]:
            display(SVG(chess.svg.board(game_state["board"], size=400)))
            print(f"Move {game_state['move_count']}: Turn = {'White' if game_state['board'].turn == chess.WHITE else 'Black'}")
            print(f"Last move: {game_state['last_move'] if game_state['last_move'] else 'None'}")
            last_move_label.value = f"Last move: {game_state['last_move'] if game_state['last_move'] else 'None'}"
            if game_state["board"].is_game_over():
                print(f"Game over. Result: {game_state['board'].result()}")

def update_move_dropdown():
    if game_state["board"] and game_state["board"].turn == game_state["human_color"] and game_state["active"]:
        legal_moves = list(game_state["board"].legal_moves)
        move_dropdown.options = [move.uci() for move in legal_moves]
        move_dropdown.disabled = False
        submit_button.disabled = False
        abort_button.disabled = False
    else:
        move_dropdown.options = []
        move_dropdown.disabled = True
        submit_button.disabled = True
        abort_button.disabled = not game_state["active"]

def start_game(b):
    with output:
        clear_output()
        agent_name = agent_dropdown.value
        if agent_name not in agents:
            print(f"No agent named {agent_name}")
            return
        
        game_state["agent"] = agents[agent_name]
        game_state["human_color"] = color_dropdown.value
        game_state["board"] = chess.Board()
        game_state["env"] = ChessEnv()
        game_state["obs"], _ = game_state["env"].reset()
        game_state["board"] = game_state["env"].board.copy()
        game_state["move_count"] = 0
        game_state["active"] = True
        game_state["last_move"] = None

        print(f"Game started against {agent_name}. You are {'White' if game_state['human_color'] == chess.WHITE else 'Black'}.")
        update_board()
        update_move_dropdown()
        if game_state["board"].turn != game_state["human_color"]:
            ai_move()

def ai_move():
    if not game_state["active"] or game_state["board"].is_game_over():
        return
    
    agent_name = agent_dropdown.value
    if agent_name == "AlphaZero":
        move = game_state["agent"].mcts(game_state["board"])
        if move is None:
            with output:
                print("AI failed to select a move")
            game_state["active"] = False
            update_move_dropdown()
            return
        action_idx = list(game_state["env"].board.legal_moves).index(move)
    else:
        state_tensor = torch.tensor(game_state["obs"], dtype=torch.float32).unsqueeze(0).to(device)
        with torch.no_grad():
            if agent_name == "DQN":
                q_values = dqn_net(state_tensor)
                action_idx = q_values.argmax().item()
            else:  # PPO, A3C
                logits, _ = (ppo_net if agent_name == "PPO" else a3c_net)(state_tensor, return_value=True)
                probs = torch.softmax(logits, dim=-1)
                action_idx = torch.multinomial(probs, 1).item()
        move = game_state["env"]._index_to_move(action_idx, list(game_state["board"].legal_moves))

    game_state["obs"], reward, terminated, truncated, _ = game_state["env"].step(action_idx)
    game_state["board"] = game_state["env"].board.copy()
    game_state["move_count"] += 1
    game_state["last_move"] = move.uci()

    with output:
        print(f"{agent_name}'s move: {move.uci()}, Reward={reward}, Terminated={terminated}, Truncated={truncated}")
        if terminated or truncated:
            print("Game terminated after AI move")
            game_state["active"] = False
    
    update_board()
    update_move_dropdown()

def submit_move(b):
    if not game_state["active"] or game_state["board"].is_game_over():
        return
    
    move_uci = move_dropdown.value
    try:
        move = chess.Move.from_uci(move_uci)
        legal_moves = list(game_state["board"].legal_moves)
        if move in legal_moves:
            game_state["board"].push(move)
            env_legal_moves = list(game_state["env"].board.legal_moves)
            action_idx = env_legal_moves.index(move)
            game_state["obs"], reward, terminated, truncated, _ = game_state["env"].step(action_idx)
            game_state["board"] = game_state["env"].board.copy()
            game_state["move_count"] += 1
            game_state["last_move"] = move.uci()

            with output:
                print(f"Your move: {move.uci()}, Reward={reward}, Terminated={terminated}, Truncated={truncated}")
                if terminated or truncated:
                    print("Game terminated after your move")
                    game_state["active"] = False
            
            update_board()
            update_move_dropdown()
            if not (terminated or truncated) and game_state["board"].turn != game_state["human_color"]:
                ai_move()
        else:
            with output:
                print("Illegal move. Try again.")
    except ValueError as e:
        with output:
            print(f"Invalid UCI format: {e}. Try again.")

def abort_game(b):
    if game_state["active"]:
        game_state["active"] = False
        game_state["board"] = None
        game_state["env"] = None
        game_state["agent"] = None
        game_state["human_color"] = None
        game_state["obs"] = None
        game_state["move_count"] = 0
        game_state["last_move"] = None
        with output:
            clear_output()
            print("Game aborted.")
        update_move_dropdown()
        last_move_label.value = "Last move: None"

# Bind buttons
start_button.on_click(start_game)
submit_button.on_click(submit_move)
abort_button.on_click(abort_game)

# Display UI
display(widgets.VBox([agent_dropdown, color_dropdown, start_button, move_dropdown, submit_button, abort_button, last_move_label, output]))

VBox(children=(Dropdown(description='Agent:', index=1, options=('DQN', 'PPO', 'A3C', 'AlphaZero'), value='PPO'…