<a href="https://colab.research.google.com/github/alexandercorey2013/rl-chess-agent/blob/main/train_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install python-chess torch numpy

import chess
import random
import numpy as np
from collections import deque

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


# ----------------------------
# Device (CPU/GPU)
# ----------------------------
# Prefer the GPU when PyTorch can see one; every tensor/module below is placed here.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)


# ----------------------------
# Chess Environment
# ----------------------------
class ChessEnv:
    """Chess environment in which the agent plays White against a uniformly random Black.

    States are FEN strings. Actions are integer indices ``from_square * 64 + to_square``
    (4096 possible values); see :meth:`index_to_move` for how promotions are decoded.
    """

    def __init__(self):
        self.board = chess.Board()

    def reset(self):
        """Reset to the standard starting position (White to move) and return its FEN."""
        self.board.reset()
        return self.get_state()

    def step(self, action):
        """Play the agent's ``action`` (a ``chess.Move``), then a random reply for Black.

        Returns:
            ``(next_state_fen, reward, done)``. An illegal ``action`` leaves the board
            unchanged and returns reward -1 with ``done=True``. Otherwise the reward is
            0 until the game ends, at which point it comes from :meth:`get_reward`.
        """
        if action not in self.board.legal_moves:
            return self.get_state(), -1, True

        self.board.push(action)
        if self.board.is_game_over():
            return self.get_state(), self.get_reward(), True

        # The game is not over, so Black has at least one legal move to choose from.
        opponent_action = random.choice(list(self.board.legal_moves))
        self.board.push(opponent_action)
        if self.board.is_game_over():
            return self.get_state(), self.get_reward(), True

        return self.get_state(), 0, False

    def get_state(self):
        """Return the current position as a FEN string."""
        return self.board.fen()

    def get_reward(self):
        """Return the terminal reward from White's (the agent's) point of view.

        1 for a White win, -1 for a Black win, 0.5 for a draw. Only meaningful once
        the game is over (``board.outcome()`` is ``None`` before that).
        """
        outcome = self.board.outcome()
        if outcome.winner is None:
            return 0.5
        return 1 if outcome.winner == chess.WHITE else -1

    def index_to_move(self, index):
        """Decode an action index into a ``chess.Move`` for the current position.

        The 4096-way action space carries no promotion piece. A pawn push or capture
        onto the last rank would therefore decode to a move that ``legal_moves``
        rejects, and :meth:`step` would end the episode with -1 for it. When the plain
        move is illegal but the queen promotion on the same squares is legal, the
        queen promotion is returned instead.
        """
        from_square, to_square = divmod(index, 64)
        move = chess.Move(from_square, to_square)
        if move not in self.board.legal_moves:
            queen_promotion = chess.Move(from_square, to_square, promotion=chess.QUEEN)
            if queen_promotion in self.board.legal_moves:
                return queen_promotion
        return move

    def move_to_index(self, move):
        """Encode a move as ``from_square * 64 + to_square`` (the promotion piece is dropped)."""
        return move.from_square * 64 + move.to_square

    def legal_action_indices(self):
        """Return the distinct action indices of all legal moves, in generation order.

        The four promotion choices for one pawn move share a single index, so
        duplicates are removed. Otherwise random exploration would pick promotions
        four times as often as other moves.
        """
        return list(dict.fromkeys(self.move_to_index(m) for m in self.board.legal_moves))

    def display_board(self):
        """Print an ASCII diagram of the board, rank 8 at the top, with file/rank labels."""
        grid = [["." for _ in range(8)] for _ in range(8)]

        for square, piece in self.board.piece_map().items():
            row = 7 - chess.square_rank(square)
            col = chess.square_file(square)
            grid[row][col] = piece.symbol()

        print("\n  a b c d e f g h")
        print("  ----------------")

        for i, row in enumerate(grid):
            print(f"{8 - i} | " + " ".join(row))

        print()


# ----------------------------
# Replay Buffer
# ----------------------------
class ReplayBuffer:
    """Bounded first-in-first-out store of transitions for experience replay.

    After ``capacity`` items have been added, each new item evicts the oldest one.
    """

    def __init__(self, capacity):
        # A deque with a maxlen discards from the left automatically when full.
        self.buffer = deque([], capacity)

    def add(self, transition):
        """Store one transition, evicting the oldest if already at capacity."""
        buffer = self.buffer
        buffer.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored items drawn uniformly without replacement.

        ``random.sample`` raises ``ValueError`` if fewer than ``batch_size`` items are stored.
        """
        population = self.buffer
        picked = random.sample(population, batch_size)
        return picked

    def size(self):
        """Return the number of transitions currently stored."""
        count = len(self.buffer)
        return count


# ----------------------------
# CNN Q-Network + Encoder
# ----------------------------
class QNetwork(nn.Module):
    """Convolutional Q-network: encoded board in, one Q-value per action out.

    ``forward`` takes a ``(batch, 12, 8, 8)`` float tensor (as produced by
    :meth:`encode`) and returns ``(batch, output_dim)`` Q-values.
    """

    # One input plane per (colour, piece type): python-chess uppercase symbols
    # (White) fill planes 0-5, lowercase symbols (Black) fill planes 6-11.
    _PIECE_TO_CHANNEL = {
        'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
        'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11,
    }

    def __init__(self, output_dim=4096):
        super().__init__()

        # 3x3 convolutions with padding=1 keep the 8x8 spatial size, so the
        # flattened feature map entering fc1 is 128 * 8 * 8.
        self.conv1 = nn.Conv2d(12, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        self.fc1 = nn.Linear(128 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, output_dim)

    def forward(self, x):
        """Return raw (unmasked) Q-values for every action index."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        x = x.flatten(start_dim=1)

        x = F.relu(self.fc1(x))
        return self.fc2(x)

    def encode(self, fen_list):
        """Encode FEN strings as a ``(len(fen_list), 12, 8, 8)`` float32 tensor.

        Each occupied square sets a 1.0 in its piece's plane at
        ``[7 - rank, file]``, so rank 8 is row 0. Only piece placement is encoded;
        side to move, castling rights and en-passant data in the FEN are ignored.
        The tensor is placed on the same device as this network's parameters.
        An empty input yields an empty ``(0, 12, 8, 8)`` batch.
        """
        fens = list(fen_list)
        # Preallocate one contiguous array. Building a tensor from a Python list
        # of ndarrays is slow and makes PyTorch emit a performance warning.
        planes = np.zeros((len(fens), 12, 8, 8), dtype=np.float32)

        for i, fen in enumerate(fens):
            board = chess.Board(fen)
            for square, piece in board.piece_map().items():
                channel = self._PIECE_TO_CHANNEL[piece.symbol()]
                row = 7 - chess.square_rank(square)
                col = chess.square_file(square)
                planes[i, channel, row, col] = 1.0

        target_device = next(self.parameters()).device
        return torch.from_numpy(planes).to(target_device)


# ----------------------------
# DQN Agent
# ----------------------------
class DQNAgent:
    """Epsilon-greedy DQN agent with experience replay and a separate target network.

    Args:
        env: environment providing ``legal_action_indices()`` for the current
            position and ``move_to_index(move)`` for encoding moves.
        q_net: online network, trained by :meth:`update`. It must provide
            ``encode(fens)`` and map encoded states to ``action_size`` Q-values.
        target_net: network used for bootstrap targets. It is refreshed from
            ``q_net`` by :meth:`update_target_network`.
        optimizer: optimizer over ``q_net``'s parameters.
        batch_size: minibatch size per update.
        gamma: discount factor.
        epsilon, epsilon_min, epsilon_decay: exploration schedule. Epsilon is
            multiplied by ``epsilon_decay`` per :meth:`decay_epsilon` call and
            floored at ``epsilon_min``.
        replay_capacity: maximum number of stored transitions.
        warmup_steps: minimum number of stored transitions before updates start.
    """

    def __init__(
        self, env, q_net, target_net, optimizer,
        batch_size=32, gamma=0.99,
        epsilon=1.0, epsilon_min=0.05, epsilon_decay=0.995,
        replay_capacity=50000,
        warmup_steps=500
    ):
        self.env = env
        self.q_net = q_net
        self.target_net = target_net
        self.optimizer = optimizer

        self.batch_size = batch_size
        self.gamma = gamma

        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay

        self.memory = ReplayBuffer(replay_capacity)
        self.action_size = 4096
        self.warmup_steps = warmup_steps

    def choose_action(self, state):
        """Return an action index that is legal in the env's current position.

        With probability ``epsilon`` the index is uniformly random among legal
        actions. Otherwise it is the legal action with the highest Q-value. Legal
        actions come from ``self.env``, so ``state`` is assumed to be the env's
        current position.
        """
        legal = self.env.legal_action_indices()

        if random.random() < self.epsilon:
            return random.choice(legal)

        self.q_net.eval()
        with torch.no_grad():
            state_tensor = self.q_net.encode([state])
            q_values = self.q_net(state_tensor)[0].cpu().numpy()
        self.q_net.train()

        # Illegal actions get -inf so argmax can only pick a legal one.
        masked = np.full(self.action_size, -np.inf)
        masked[legal] = q_values[legal]
        return int(np.argmax(masked))

    def store_transition(self, state, action, reward, next_state, done):
        """Append a ``(state, action, reward, next_state, done)`` tuple to replay memory."""
        self.memory.add((state, action, reward, next_state, done))

    def decay_epsilon(self):
        """Multiply epsilon by ``epsilon_decay``, never going below ``epsilon_min``."""
        self.epsilon = max(self.epsilon * self.epsilon_decay, self.epsilon_min)

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_net.load_state_dict(self.q_net.state_dict())
        self.target_net.eval()

    def _legal_action_mask(self, fens):
        """Return a bool tensor ``(len(fens), action_size)``, True at indices legal in each FEN."""
        mask = np.zeros((len(fens), self.action_size), dtype=bool)
        for row, fen in enumerate(fens):
            indices = [self.env.move_to_index(m) for m in chess.Board(fen).legal_moves]
            if indices:
                mask[row, indices] = True
        return torch.from_numpy(mask).to(device)

    def update(self):
        """Run one gradient step on a random minibatch from replay memory.

        Does nothing until memory holds at least ``max(warmup_steps, batch_size)``
        transitions. The target is ``r + gamma * max_a' Q_target(s', a')`` with
        ``a'`` restricted to moves legal in ``s'`` (matching how
        :meth:`choose_action` acts). The bootstrap term is dropped for terminal
        transitions.
        """
        if self.memory.size() < max(self.warmup_steps, self.batch_size):
            return

        batch = self.memory.sample(self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        actions = torch.tensor(actions, dtype=torch.long, device=device)
        rewards = torch.tensor(rewards, dtype=torch.float32, device=device)
        dones = torch.tensor(dones, dtype=torch.float32, device=device)

        states_t = self.q_net.encode(states)
        next_states_t = self.q_net.encode(next_states)

        q_values = self.q_net(states_t)
        Q_current = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        with torch.no_grad():
            next_q = self.target_net(next_states_t)
            # Illegal actions are never chosen or trained, so their outputs are
            # arbitrary. Exclude them from the max so they cannot inflate targets.
            legal_next = self._legal_action_mask(next_states)
            next_q = next_q.masked_fill(~legal_next, float("-inf"))
            max_next_q = next_q.max(dim=1).values
            # A position with no legal move (checkmate/stalemate) would give -inf,
            # and -inf * 0 is NaN even when `done` zeroes the term. Use 0 explicitly.
            has_legal = legal_next.any(dim=1)
            max_next_q = torch.where(has_legal, max_next_q, torch.zeros_like(max_next_q))
            Q_target = rewards + self.gamma * max_next_q * (1 - dones)

        loss = F.mse_loss(Q_current, Q_target)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()


# ----------------------------
# Training Loop
# ----------------------------
def _run_episode(agent, env):
    """Play one episode, storing every transition and calling ``agent.update()`` each step.

    Returns the sum of rewards received during the episode.
    """
    state = env.reset()
    done = False
    total_reward = 0

    while not done:
        action_idx = agent.choose_action(state)
        move = env.index_to_move(action_idx)

        next_state, reward, done = env.step(move)
        agent.store_transition(state, action_idx, reward, next_state, done)

        # Online learning: one update call per environment step.
        agent.update()

        state = next_state
        total_reward += reward

    return total_reward


def train(agent, env, num_episodes=1500, target_update_frequency=10):
    """Train ``agent`` in ``env`` for ``num_episodes`` episodes, printing progress.

    After each episode: epsilon is decayed once, and the target network is synced
    whenever ``episode % target_update_frequency == 0`` (including episode 0).
    Every 50th episode the final board is displayed. A one-line summary is printed
    with the episode reward and the mean reward over the last 50 episodes.
    """
    reward_window = deque(maxlen=50)

    for episode in range(num_episodes):
        total_reward = _run_episode(agent, env)

        reward_window.append(total_reward)
        avg_reward = sum(reward_window) / len(reward_window)

        agent.decay_epsilon()

        if episode % target_update_frequency == 0:
            agent.update_target_network()

        if episode % 50 == 0:
            env.display_board()

        # "\u03b5" is the Greek epsilon. It is escaped so the source stays ASCII;
        # the original literal had been mangled into the mojibake "Îµ".
        print(
            f"Episode {episode} | "
            f"Reward {total_reward:.2f} | "
            f"Avg50 {avg_reward:.3f} | "
            f"\u03b5 {agent.epsilon:.3f}"
        )


# ----------------------------
# Initialize Training
# ----------------------------
env = ChessEnv()

# Online network (trained) and target network (synced copy used for bootstrap
# targets). Built left to right, so the online network is constructed first.
q_net, target_net = QNetwork().to(device), QNetwork().to(device)
target_net.load_state_dict(q_net.state_dict())
target_net.eval()
q_net.train()

optimizer = optim.Adam(params=q_net.parameters(), lr=0.001)
agent = DQNAgent(env=env, q_net=q_net, target_net=target_net, optimizer=optimizer)

train(agent, env, num_episodes=1500)

















