In [32]:
import chess
import chess.svg
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
import os
from IPython.display import display, SVG, clear_output
import gymnasium as gym
from gymnasium import spaces
import copy

In [33]:
# Pick the compute device once for the whole notebook: CUDA when PyTorch
# reports a usable GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [34]:
def board_to_tensor(board):
    """Encode a python-chess board as a flat float32 vector of length 896.

    The vector is an 8 x 8 x 14 array (rank, file, channel) flattened in
    C order. Channels 0-5 mark White's pawn, knight, bishop, rook, queen
    and king; channels 6-11 mark Black's pieces in the same order. Channel
    12 is filled with ones when White may castle kingside and channel 13
    when White may castle queenside.

    Note: Black's castling rights and the side to move are not part of
    this encoding.
    """
    channel_of = {chess.PAWN: 0, chess.KNIGHT: 1, chess.BISHOP: 2,
                  chess.ROOK: 3, chess.QUEEN: 4, chess.KING: 5}
    encoded = np.zeros((8, 8, 14), dtype=np.float32)

    # Only occupied squares need to be visited.
    for square, piece in board.piece_map().items():
        channel = channel_of[piece.piece_type]
        if piece.color != chess.WHITE:
            channel += 6  # Black pieces live in the second bank of six channels
        encoded[square // 8, square % 8, channel] = 1

    # Castling rights are broadcast over a whole plane.
    white_kingside = board.has_kingside_castling_rights(chess.WHITE)
    white_queenside = board.has_queenside_castling_rights(chess.WHITE)
    encoded[:, :, 12] = 1 if white_kingside else 0
    encoded[:, :, 13] = 1 if white_queenside else 0
    return encoded.flatten()

In [35]:
class ChessEnv(gym.Env):
    """Single-board chess environment in which every call to ``step`` plays one ply.

    Both colours are driven through the same ``step`` call, so the agent
    effectively plays against itself.

    Actions are integers in ``[0, 4096)``; an action is mapped onto the
    current legal-move list with ``action % len(legal_moves)``, so every
    integer action resolves to a legal move.

    Rewards (credited to the side that just moved):
        * ``1``   - the move delivers checkmate
        * ``0.1`` - the move captures a piece (including en passant)
        * ``0``   - any other move
        * ``-1``  - ``step`` was called on a position with no legal moves
    """

    def __init__(self):
        super().__init__()
        self.board = chess.Board()
        self.action_space = spaces.Discrete(4096)
        # 896 = 8 * 8 * 14 planes produced by board_to_tensor
        self.observation_space = spaces.Box(low=0, high=1, shape=(896,), dtype=np.float32)
        self.max_steps = 1000
        self.step_count = 0

    def reset(self, seed=None, options=None):
        """Return the board to the initial position and give back ``(obs, info)``."""
        super().reset(seed=seed)
        self.board.reset()
        self.step_count = 0
        return self._get_obs(), {}

    def step(self, action):
        """Play one ply and return ``(obs, reward, terminated, truncated, info)``.

        Args:
            action: Integer action index. A ``chess.Move`` or UCI string is
                also accepted (see ``_index_to_move``).
        """
        legal_moves = list(self.board.legal_moves)
        if not legal_moves:
            return self._get_obs(), -1, True, False, {}

        move = self._index_to_move(action, legal_moves)
        # Capture status has to be read BEFORE the move is pushed. Asked about
        # the already-played move, is_capture() looks at the destination square
        # in the new position, which now holds the mover's own piece.
        was_capture = self.board.is_capture(move)
        self.board.push(move)
        self.step_count += 1

        if self.board.is_checkmate():
            reward = 1
        elif was_capture:
            reward = 0.1
        else:
            reward = 0

        done = self.board.is_game_over()
        truncated = self.step_count >= self.max_steps
        return self._get_obs(), reward, done, truncated, {}

    def render(self, mode="human"):
        """Display the current position as an inline SVG."""
        display(SVG(chess.svg.board(self.board, size=400)))

    def _get_obs(self):
        """Return the flat feature vector for the current position."""
        return board_to_tensor(self.board)

    def _index_to_move(self, idx, legal_moves):
        """Resolve an action to one of ``legal_moves``.

        Integers wrap around the legal-move list. A ``chess.Move`` (or its UCI
        string) is returned as-is when it is legal, so callers that already
        hold a move do not have to convert it back to an index.

        Raises:
            ValueError: if a move / UCI string is given that is not legal.
        """
        if isinstance(idx, str):
            idx = chess.Move.from_uci(idx)
        if isinstance(idx, chess.Move):
            if idx not in legal_moves:
                raise ValueError(f"Illegal move for this position: {idx.uci()}")
            return idx
        return legal_moves[idx % len(legal_moves)]

In [36]:
# Shared environment instance; the training cells below pass it to each agent.
env = ChessEnv()
print("Environment initialized")

Environment initialized


In [37]:
class ChessNet(nn.Module):
    """Shared-trunk MLP with separate Q-value, policy and state-value heads.

    One class serves every algorithm in the notebook: DQN reads the Q head,
    while the actor-critic style agents read the policy and value heads.

    Args:
        action_size: Width of the Q-value and policy outputs.
        input_size: Length of the flat observation vector. The default, 896,
            is the 8 * 8 * 14 encoding produced by ``board_to_tensor``.
    """

    def __init__(self, action_size=4096, input_size=896):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU()
        )
        # For DQN: Q-values
        self.q_head = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, action_size)
        )
        # For PPO/A3C/AlphaZero: policy logits and state value
        self.policy_head = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, action_size)
        )
        self.value_head = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x, return_value=False):
        """Run the network on a batch of observations.

        Args:
            x: Float tensor of shape ``(batch, input_size)``.
            return_value: When False (default) return Q-values of shape
                ``(batch, action_size)``. When True return the tuple
                ``(policy_logits, value)`` with shapes
                ``(batch, action_size)`` and ``(batch, 1)``.
        """
        features = self.feature_extractor(x)
        if return_value:  # PPO/A3C
            return self.policy_head(features), self.value_head(features)
        # Only the DQN path needs the Q head, so it is evaluated here only.
        return self.q_head(features)

    def forward_alpha_zero(self, x):
        """Return ``(policy_logits, value)``; same as ``forward(x, return_value=True)``."""
        return self.forward(x, return_value=True)

In [38]:
# Stand-alone network used here to report the parameter count; `net` is not
# referenced again in the visible cells (each agent class builds its own ChessNet).
net = ChessNet().to(device)
print(f"Network initialized with ~{sum(p.numel() for p in net.parameters())} parameters")

Network initialized with ~1451009 parameters


In [39]:
class DQN:
    """Deep Q-learning agent with a replay buffer and a periodically synced target net.

    Args:
        env: Environment exposing the gymnasium ``reset`` / ``step`` API and a
            discrete ``action_space``.
        lr: Adam learning rate for the online network.
        gamma: Discount factor.
        buffer_size: Maximum number of transitions kept in the replay buffer.
    """

    def __init__(self, env, lr=0.001, gamma=0.99, buffer_size=10000):
        self.env = env
        self.policy_net = ChessNet().to(device)
        self.target_net = ChessNet().to(device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.loss_fn = nn.MSELoss()  # built once instead of on every update
        self.gamma = gamma
        self.replay_buffer = deque(maxlen=buffer_size)
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995

    def select_action(self, state):
        """Epsilon-greedy action: random with probability ``epsilon``, else argmax Q."""
        if random.random() < self.epsilon:
            return random.randrange(self.env.action_space.n)
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
        with torch.no_grad():
            q_values = self.policy_net(state_tensor)
        return q_values.argmax().item()

    def _optimize(self, batch_size):
        """Run one gradient step on a uniformly sampled minibatch."""
        batch = random.sample(self.replay_buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        # Stack the per-step observation arrays first: building a tensor from a
        # sequence of ndarrays converts element by element and is very slow.
        states = torch.tensor(np.stack(states), dtype=torch.float32).to(device)
        actions = torch.tensor(actions, dtype=torch.long).to(device)
        rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
        next_states = torch.tensor(np.stack(next_states), dtype=torch.float32).to(device)
        dones = torch.tensor(dones, dtype=torch.float32).to(device)

        # squeeze(1) removes only the gather column, never the batch axis.
        q_values = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            next_q_values = self.target_net(next_states).max(1)[0]
        # (1 - dones) stops bootstrapping past terminal positions.
        targets = rewards + self.gamma * next_q_values * (1 - dones)
        loss = self.loss_fn(q_values, targets)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def train(self, batch_size=64, num_timesteps=20000):
        """Train the online network.

        Args:
            batch_size: Minibatch size; updates start once the buffer holds
                this many transitions.
            num_timesteps: Number of outer iterations. Each iteration plays a
                complete game (until termination or truncation) and performs
                one update per ply, so this counts games rather than plies.
        """
        for t in range(num_timesteps):
            state, _ = self.env.reset()
            done = truncated = False
            while not (done or truncated):
                action = self.select_action(state)
                next_state, reward, done, truncated, _ = self.env.step(action)
                # Only true termination is stored; a truncated game still bootstraps.
                self.replay_buffer.append((state, action, reward, next_state, done))
                state = next_state

                if len(self.replay_buffer) >= batch_size:
                    self._optimize(batch_size)

            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
            if t % 1000 == 0:
                self.target_net.load_state_dict(self.policy_net.state_dict())
                print(f"DQN timestep {t}/{num_timesteps}, epsilon {self.epsilon:.3f}")


In [40]:
class PPO:
    """Clipped-surrogate policy-gradient agent with one-step TD value targets.

    Args:
        env: Environment exposing the gymnasium ``reset`` / ``step`` API.
        lr: Adam learning rate.
        gamma: Discount factor.
        clip: Clipping range for the probability ratio.
    """

    def __init__(self, env, lr=0.0003, gamma=0.99, clip=0.2):
        self.env = env
        self.net = ChessNet().to(device)
        self.optimizer = optim.Adam(self.net.parameters(), lr=lr)
        self.gamma = gamma
        self.clip = clip

    def select_action(self, state):
        """Sample an action from the current policy for a single observation."""
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
        with torch.no_grad():
            logits, _ = self.net(state_tensor, return_value=True)
            probs = torch.softmax(logits, dim=-1)
        action = torch.multinomial(probs, 1).item()
        return action

    @staticmethod
    def _action_log_probs(logits, actions):
        """Log-probability of each taken action under the given logits."""
        probs = torch.softmax(logits, dim=-1)
        return torch.log(probs.gather(1, actions.unsqueeze(1)).squeeze(1) + 1e-10)

    def train(self, num_timesteps=20000, rollout_size=2048, epochs=10):
        """Alternate between collecting a rollout and optimising on it.

        Args:
            num_timesteps: Total number of environment steps to collect.
            rollout_size: Steps gathered before each optimisation phase.
            epochs: Full-batch gradient steps taken per rollout.
        """
        for t in range(0, num_timesteps, rollout_size):
            states, actions, rewards, next_states, dones = [], [], [], [], []
            state, _ = self.env.reset()
            for _ in range(rollout_size):
                action = self.select_action(state)
                next_state, reward, done, truncated, _ = self.env.step(action)
                states.append(state)
                actions.append(action)
                rewards.append(reward)
                next_states.append(next_state)
                dones.append(done or truncated)
                state = next_state
                if done or truncated:
                    state, _ = self.env.reset()

            # Stack observations before conversion; a list of ndarrays is
            # converted element by element and is very slow.
            states = torch.tensor(np.stack(states), dtype=torch.float32).to(device)
            actions = torch.tensor(actions, dtype=torch.long).to(device)
            rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
            next_states = torch.tensor(np.stack(next_states), dtype=torch.float32).to(device)
            dones = torch.tensor(dones, dtype=torch.float32).to(device)

            with torch.no_grad():
                # Snapshot of the behaviour policy. The network was not updated
                # during the rollout, so evaluating it now reproduces the
                # probabilities the actions were sampled from.
                old_logits, _ = self.net(states, return_value=True)
                old_log_probs = self._action_log_probs(old_logits, actions)
                _, next_values = self.net(next_states, return_value=True)
                returns = rewards + self.gamma * next_values.squeeze(1) * (1 - dones)

            for _ in range(epochs):
                logits, values = self.net(states, return_value=True)
                values = values.squeeze(1)
                log_probs = self._action_log_probs(logits, actions)
                # Advantages are constants for the policy loss; without the
                # detach the surrogate would push gradients into the value head.
                advantages = (returns - values).detach()

                # New-vs-old probability ratio. Comparing the current policy
                # with itself would pin this at 1 and give the policy head no
                # gradient at all.
                ratio = torch.exp(log_probs - old_log_probs)
                surr1 = ratio * advantages
                surr2 = torch.clamp(ratio, 1 - self.clip, 1 + self.clip) * advantages
                policy_loss = -torch.min(surr1, surr2).mean()
                value_loss = nn.MSELoss()(values, returns)
                loss = policy_loss + 0.5 * value_loss

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            print(f"PPO timestep {t + rollout_size}/{num_timesteps}")


In [41]:
class A3C:
    """Advantage actor-critic trained on complete episodes.

    Despite the name, this implementation uses a single synchronous worker:
    it plays one full game, computes discounted Monte-Carlo returns and takes
    one gradient step.

    Args:
        env: Environment exposing the gymnasium ``reset`` / ``step`` API.
        lr: Adam learning rate.
        gamma: Discount factor.
    """

    def __init__(self, env, lr=0.0001, gamma=0.99):
        self.env = env
        self.net = ChessNet().to(device)
        self.optimizer = optim.Adam(self.net.parameters(), lr=lr)
        self.gamma = gamma

    def select_action(self, state):
        """Sample an action from the current policy for a single observation."""
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
        with torch.no_grad():
            logits, _ = self.net(state_tensor, return_value=True)
            probs = torch.softmax(logits, dim=-1)
        action = torch.multinomial(probs, 1).item()
        return action

    def _discounted_returns(self, rewards):
        """Discounted reward-to-go for every step of one episode, as a device tensor."""
        # Accumulate with Python floats: doing this element by element on a
        # device tensor launches one tiny kernel per step.
        running = 0.0
        returns = []
        for reward in reversed(rewards):
            running = reward + self.gamma * running
            returns.append(running)
        returns.reverse()
        return torch.tensor(returns, dtype=torch.float32).to(device)

    def train(self, num_timesteps=20000):
        """Train for ``num_timesteps`` outer iterations.

        Each iteration plays one complete game (until termination or
        truncation) and performs a single update, so the argument counts
        games rather than plies.
        """
        for t in range(num_timesteps):
            state, _ = self.env.reset()
            states, actions, rewards = [], [], []
            done = truncated = False
            while not (done or truncated):
                action = self.select_action(state)
                next_state, reward, done, truncated, _ = self.env.step(action)
                states.append(state)
                actions.append(action)
                rewards.append(reward)
                state = next_state

            # Stack observations before conversion; a list of ndarrays is
            # converted element by element and is very slow.
            states = torch.tensor(np.stack(states), dtype=torch.float32).to(device)
            actions = torch.tensor(actions, dtype=torch.long).to(device)
            returns = self._discounted_returns(rewards)

            logits, values = self.net(states, return_value=True)
            # squeeze(1) keeps the batch axis even for a one-step episode.
            values = values.squeeze(1)
            probs = torch.softmax(logits, dim=-1)
            log_probs = torch.log(probs.gather(1, actions.unsqueeze(1)).squeeze(1) + 1e-10)

            advantages = returns - values
            policy_loss = -(log_probs * advantages.detach()).mean()
            value_loss = nn.MSELoss()(values, returns)
            loss = policy_loss + 0.5 * value_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if t % 1000 == 0:
                print(f"A3C timestep {t}/{num_timesteps}")

In [42]:
class MCTSNode:
    """A single node of the Monte-Carlo search tree.

    Attributes are plain public fields:
        state    - position object handed in by the caller
        parent   - parent node, or None for the root
        action   - move that leads from the parent to this node
        children - dict of child nodes keyed by the move that reaches them
        visits   - visit counter, starts at 0
        value    - accumulated backed-up value, starts at 0
    """

    def __init__(self, state, parent=None, action=None):
        # Links and payload supplied by the caller.
        self.state, self.parent, self.action = state, parent, action
        # Search statistics start out empty.
        self.children = dict()
        self.visits, self.value = 0, 0

class AlphaZero:
    """Simplified AlphaZero-style agent: value-guided tree search plus an
    actor-critic update on self-play games.

    Args:
        env: ``ChessEnv`` used for self-play.
        c_puct: Exploration constant of the selection score.
        num_simulations: Tree-search simulations per move.
        gamma: Discount factor for the self-play returns.
    """

    def __init__(self, env, c_puct=1.0, num_simulations=20, gamma=0.99):
        self.env = env
        self.net = ChessNet().to(device)
        self.optimizer = optim.Adam(self.net.parameters(), lr=0.001)
        self.c_puct = c_puct
        self.num_simulations = num_simulations
        self.gamma = gamma

    def _ucb_score(self, parent, child):
        """Mean backed-up value of ``child`` plus a visit-count exploration bonus."""
        exploit = child.value / (child.visits + 1)
        explore = self.c_puct * np.sqrt(parent.visits) / (child.visits + 1)
        return exploit + explore

    def _evaluate(self, board):
        """Value of ``board`` for the side to move.

        Finished games are scored exactly: the side to move in a checkmate
        has lost (-1) and every other game end counts as 0. Otherwise the
        value head is used; NOTE(review): the observation does not encode the
        side to move, so this reading of the value head is an assumption.
        """
        if board.is_game_over():
            return -1.0 if board.is_checkmate() else 0.0
        obs = torch.tensor(board_to_tensor(board), dtype=torch.float32).unsqueeze(0).to(device)
        with torch.no_grad():
            _, value = self.net.forward_alpha_zero(obs)
        return value.item()

    def mcts(self, root_state):
        """Search from ``root_state`` and return the most visited root move.

        Args:
            root_state: ``chess.Board`` to search from; it is not modified.

        Returns:
            The ``chess.Move`` with the highest visit count at the root.

        Raises:
            ValueError: if no root move was expanded (no legal moves, or
                ``num_simulations`` is 0).
        """
        root = MCTSNode(root_state)
        for _ in range(self.num_simulations):
            node = root
            board = root_state.copy()

            # Selection: descend while every legal move already has a child.
            while node.children and all(m in node.children for m in board.legal_moves):
                parent = node
                node = max(parent.children.values(),
                           key=lambda child: self._ucb_score(parent, child))
                # Moves are pushed straight onto the board copy; routing a
                # chess.Move through the integer action interface fails.
                board.push(node.action)

            # Expansion: play one random legal move from the reached position.
            if not board.is_game_over():
                move = random.choice(list(board.legal_moves))
                board.push(move)
                if move not in node.children:
                    node.children[move] = MCTSNode(board.copy(), node, move)
                node = node.children[move]

            # Evaluation from the point of view of the side to move at the leaf.
            leaf_value = self._evaluate(board)

            # Backpropagation: a node's statistics are kept from the point of
            # view of the player who moved INTO it (the one choosing among
            # siblings), so the sign flips at every ply on the way up.
            signed_value = -leaf_value
            while node is not None:
                node.visits += 1
                node.value += signed_value
                signed_value = -signed_value
                node = node.parent

        if not root.children:
            raise ValueError("mcts() found no move to play from this position")
        return max(root.children.items(), key=lambda item: item[1].visits)[0]

    def train(self, num_episodes=150):
        """Play ``num_episodes`` self-play games, updating the network after each."""
        for episode in range(num_episodes):
            state, _ = self.env.reset()
            states, actions, rewards = [], [], []
            done = truncated = False
            while not (done or truncated):
                move = self.mcts(self.env.board)
                action_idx = list(self.env.board.legal_moves).index(move)
                # Record the observation the move was chosen FROM; the
                # position after the move belongs to the next decision.
                states.append(state)
                next_state, reward, done, truncated, _ = self.env.step(action_idx)
                actions.append(action_idx)
                rewards.append(reward)
                state = next_state

            states = torch.tensor(np.stack(states), dtype=torch.float32).to(device)
            actions = torch.tensor(actions, dtype=torch.long).to(device)

            # Self-play alternates colours every ply, so whatever the opponent
            # earns afterwards counts against the player who moved now.
            running = 0.0
            returns = []
            for reward in reversed(rewards):
                running = reward - self.gamma * running
                returns.append(running)
            returns.reverse()
            returns = torch.tensor(returns, dtype=torch.float32).to(device)

            logits, values = self.net.forward_alpha_zero(states)
            # squeeze(1) keeps the batch axis even for a one-move game.
            values = values.squeeze(1)
            probs = torch.softmax(logits, dim=-1)
            log_probs = torch.log(probs.gather(1, actions.unsqueeze(1)).squeeze(1) + 1e-10)

            policy_loss = -(log_probs * (returns - values).detach()).mean()
            value_loss = nn.MSELoss()(values, returns)
            loss = policy_loss + value_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if episode % 10 == 0:
                print(f"AlphaZero episode {episode}/{num_episodes}")

In [None]:
# Train the DQN agent on the shared `env`, then save the online (policy)
# network weights to "dqn_model.pth" for the reload cell further down.
# Long-running: every one of the 20000 iterations plays a complete game.
dqn_agent = DQN(env)
print("Starting DQN training")
dqn_agent.train(num_timesteps=20000)

torch.save(dqn_agent.policy_net.state_dict(), "dqn_model.pth")

Starting DQN training
DQN timestep 0/20000, epsilon 0.995


In [None]:
# Train the PPO agent on the shared `env` (20000 environment steps, collected
# in rollouts) and save its network weights to "ppo_model.pth".
ppo_agent = PPO(env)
print("Starting PPO training")
ppo_agent.train(num_timesteps=20000)

torch.save(ppo_agent.net.state_dict(), "ppo_model.pth")

In [None]:
# Train the A3C agent on the shared `env` and save its network weights to
# "a3c_model.pth". Long-running: each of the 20000 iterations plays a
# complete game before its single update.
a3c_agent = A3C(env)
print("Starting A3C training")
a3c_agent.train(num_timesteps=20000)

torch.save(a3c_agent.net.state_dict(), "a3c_model.pth")

In [None]:
# Train the AlphaZero-style agent for 150 self-play games (a tree search is
# run before every move) and save its network weights to "alpha_zero_model.pth".
alpha_zero_agent = AlphaZero(env)
print("Starting AlphaZero training")
alpha_zero_agent.train(num_episodes=150)

torch.save(alpha_zero_agent.net.state_dict(), "alpha_zero_model.pth")

In [None]:
# Rebuild one network per algorithm and restore the weights written by the
# training cells. map_location remaps the stored tensors onto the active
# device, so a checkpoint saved on a GPU machine still loads on a CPU-only one.
dqn_net = ChessNet().to(device)
dqn_net.load_state_dict(torch.load("dqn_model.pth", map_location=device))
ppo_net = ChessNet().to(device)
ppo_net.load_state_dict(torch.load("ppo_model.pth", map_location=device))
a3c_net = ChessNet().to(device)
a3c_net.load_state_dict(torch.load("a3c_model.pth", map_location=device))
alpha_zero_net = ChessNet().to(device)
alpha_zero_net.load_state_dict(torch.load("alpha_zero_model.pth", map_location=device))

# Lookup used by play_against_agent. These are the in-memory agent objects
# from the training cells, so those cells must have run in this kernel.
agents = {
    "DQN": dqn_agent,
    "PPO": ppo_agent,
    "A3C": a3c_agent,
    "AlphaZero": alpha_zero_agent
}


In [None]:
def play_against_agent(agent_name, human_color=chess.WHITE):
    """Play an interactive game against one of the trained agents.

    The human enters moves in UCI notation through ``input()``; typing
    ``quit`` ends the game early. The board is redrawn after every move and
    once more when the game ends.

    Args:
        agent_name: Key into the module-level ``agents`` dict
            ("DQN", "PPO", "A3C" or "AlphaZero").
        human_color: ``chess.WHITE`` or ``chess.BLACK``; the side the human plays.
    """
    if agent_name not in agents:
        print(f"No agent named {agent_name}")
        return

    agent = agents[agent_name]
    env = ChessEnv()
    obs, _ = env.reset()
    # `board` is a display/turn-tracking snapshot; env.board is the source of truth.
    board = env.board.copy()

    print("Starting game. Enter moves in UCI format (e.g., 'e2e4'). Type 'quit' to exit.")
    display(SVG(chess.svg.board(board, size=400)))

    while not board.is_game_over():
        # Derived from the moves actually played, so rejected input does not
        # advance the counter.
        move_count = len(env.board.move_stack) + 1
        print(f"\nMove {move_count}: Turn = {'White' if board.turn == chess.WHITE else 'Black'}")

        if board.turn == human_color:
            move_input = input("Your move: ").strip()
            print(f"Received input: {move_input}")
            if move_input.lower() == 'quit':
                print("Game ended by user.")
                return
            # Only the parse can legitimately fail here; keep the try narrow so
            # unrelated errors are not reported as bad input.
            try:
                move = chess.Move.from_uci(move_input)
            except ValueError as e:
                print(f"Invalid UCI format or error: {e}. Try again.")
                continue
            env_legal_moves = list(env.board.legal_moves)
            if move not in env_legal_moves:
                print("Illegal move. Try again.")
                continue
            action_idx = env_legal_moves.index(move)
            obs, reward, terminated, truncated, _ = env.step(action_idx)
            board = env.board.copy()
            print(f"After human move: Reward={reward}, Terminated={terminated}, Truncated={truncated}")
            if terminated or truncated:
                print("Game terminated after human move")
                break
            clear_output(wait=True)
            display(SVG(chess.svg.board(board, size=400)))
        else:
            if agent_name == "AlphaZero":
                move = agent.mcts(board)
                action_idx = list(env.board.legal_moves).index(move)
            else:
                state_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(device)
                with torch.no_grad():
                    if agent_name == "DQN":
                        q_values = dqn_net(state_tensor)
                        action_idx = q_values.argmax().item()
                    else:  # PPO, A3C
                        logits, _ = (ppo_net if agent_name == "PPO" else a3c_net)(state_tensor, return_value=True)
                        probs = torch.softmax(logits, dim=-1)
                        action_idx = torch.multinomial(probs, 1).item()
                # Same index -> move mapping the environment applies in step().
                move = env._index_to_move(action_idx, list(board.legal_moves))
            obs, reward, terminated, truncated, _ = env.step(action_idx)
            board = env.board.copy()
            print(f"{agent_name}'s move: {move.uci()}, Reward={reward}, Terminated={terminated}, Truncated={truncated}")
            if terminated or truncated:
                print("Game terminated after agent move")
                break
            clear_output(wait=True)
            display(SVG(chess.svg.board(board, size=400)))

    # Show the final position; the in-loop redraw is skipped on the last move.
    display(SVG(chess.svg.board(board, size=400)))
    print(f"Game over. Result: {board.result()}")

# Interactive demo: blocks on input() for every human move, so this call
# cannot complete in an unattended "Run All".
play_against_agent("PPO", human_color=chess.WHITE)