# Chess self-play with a Deep Q-Network (DQN)

In [9]:
import chess
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import time
from collections import deque
import random

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Encode the board for Q-Net input

In [None]:
def encode_board(board):
    """Encode a ``chess.Board`` as a flat float32 feature vector (71 values).

    Layout:
      * [0:64]  piece code per square index, divided by 12. The code is 0 for
                an empty square, otherwise ``piece.piece_type`` plus 6 when
                ``piece.color`` is truthy (White).
      * [64]    side to move (``board.turn`` as 1.0 / 0.0).
      * [65:69] castling rights: White king-side, White queen-side,
                Black king-side, Black queen-side (1.0 / 0.0).
      * [69]    en-passant square index / 64, or 1.0 when there is none.
      * [70]    half-move clock / 50.

    The Q-network input in this file is this vector followed by a 2-value
    (from, to) action encoding.
    """
    # Only occupied squares need filling; everything else stays 0 (empty).
    piece_map = np.zeros(64, dtype=np.float32)
    for square, piece in board.piece_map().items():
        piece_map[square] = piece.piece_type + (6 if piece.color else 0)

    # Square index 0 is a valid square, so test for None explicitly instead of
    # relying on truthiness (`ep_square or 64` would map square 0 to 64).
    ep_square = 64 if board.ep_square is None else board.ep_square

    # Normalize each feature group to roughly [0, 1].
    return np.concatenate([
        piece_map / 12,
        np.array([board.turn], dtype=np.float32),
        np.array([
            board.has_kingside_castling_rights(chess.WHITE),
            board.has_queenside_castling_rights(chess.WHITE),
            board.has_kingside_castling_rights(chess.BLACK),
            board.has_queenside_castling_rights(chess.BLACK),
        ], dtype=np.float32),
        np.array([ep_square / 64], dtype=np.float32),
        np.array([board.halfmove_clock / 50], dtype=np.float32),
    ])

## Q-Network implementation

In [11]:
class QNetwork(nn.Module):
    """Three-layer MLP that scores a (state, action) input with one scalar.

    ``input_dim`` defaults to 73, which is the encoded board plus the 2-value
    (from, to) action encoding used elsewhere in this file. Both hidden
    layers have ``hidden_dim`` units with ReLU activations.
    """

    def __init__(self, input_dim=73, hidden_dim=256):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        # Stored as ``self.fc`` with layer indices 0/2/4 so state_dict keys
        # (and therefore saved checkpoints) keep the same names.
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Map a ``(..., input_dim)`` tensor to ``(..., 1)`` Q-values."""
        return self.fc(x)

## Initialize the model

In [None]:
# Online network (optimized below) and target network (refreshed from the
# online weights by the training loop); both are built in this order.
q_net, target_net = QNetwork().to(device), QNetwork().to(device)
# Start the target network from exactly the online network's weights.
target_net.load_state_dict(q_net.state_dict())
optimizer = optim.Adam(params=q_net.parameters(), lr=1e-3)

In [13]:
class ReplayBuffer:
    """Fixed-capacity FIFO store of experiences for experience replay.

    Once ``capacity`` items are held, each new push evicts the oldest one.
    """

    def __init__(self, capacity=10000):
        # deque with a maxlen silently drops from the left when full.
        self.buffer = deque([], capacity)

    def push(self, experience):
        """Append one experience (any object) to the buffer."""
        self.buffer.append(experience)

    def sample(self, batch_size):
        """Return ``batch_size`` stored experiences drawn without replacement.

        Raises ValueError (from ``random.sample``) if fewer are stored.
        """
        return random.sample(self.buffer, k=batch_size)

    def __len__(self):
        """Number of experiences currently stored."""
        return len(self.buffer)


# Shared replay buffer filled by the training loop.
buffer = ReplayBuffer()

In [None]:
def get_best_move(board, network, time_limit=10):
    """Return the legal move the Q-network rates best for the side to move.

    Every legal move is scored in a single batched, gradient-free forward pass.
    Each input row is ``encode_board(board)`` followed by
    ``[from_square / 63, to_square / 63]``.

    The training reward is +1 for "1-0" and -1 for "0-1", so Q-values are
    from White's perspective. White therefore takes the highest-scoring move
    and Black the lowest.

    Args:
        board: position to move from (``chess.Board``); it is not modified.
        network: Q-network; expected to map ``(N, input_dim)`` float inputs to
            ``(N, 1)`` Q-values, as ``QNetwork`` does.
        time_limit: accepted for backward compatibility. Earlier versions
            slept until ``time_limit`` seconds had elapsed, which stalled
            every greedy move during training. The call now returns as soon
            as scoring is done.

    Returns:
        The chosen ``chess.Move``, or ``None`` if there are no legal moves.
    """
    legal_moves = list(board.legal_moves)
    if not legal_moves:
        return None

    state = encode_board(board)
    # NOTE(review): under-promotions share from/to squares with the queen
    # promotion (e.g. e7e8q vs e7e8n), so they receive identical scores.
    actions = np.array(
        [[move.from_square / 63, move.to_square / 63] for move in legal_moves],
        dtype=np.float32,
    )
    # Repeat the state once per candidate move and append each action.
    states = np.broadcast_to(state, (len(legal_moves), state.shape[0]))
    inputs = torch.from_numpy(np.concatenate([states, actions], axis=1)).to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        q_values = network(inputs).squeeze(1)

    if board.turn == chess.WHITE:
        best_index = q_values.argmax()
    else:
        best_index = q_values.argmin()
    return legal_moves[int(best_index)]

In [None]:
# Hyperparameters for DQN self-play training.
gamma = 0.99          # discount factor
epsilon = 1.0         # exploration rate; decayed per episode, floored at 0.1
batch_size = 128      # transitions per gradient step
sync_interval = 100   # episodes between target-network syncs


def _encode_actions(moves):
    """Return an (N, 2) float32 array of [from_square / 63, to_square / 63] per move."""
    return np.array(
        [[move.from_square / 63, move.to_square / 63] for move in moves],
        dtype=np.float32,
    ).reshape(-1, 2)


def _next_state_value(fen):
    """Target-network value of position ``fen`` for the side to move.

    The reward is +1 for "1-0" and -1 for "0-1", so Q-values are from White's
    perspective. White's best reply is therefore the maximum and Black's the
    minimum. Returns 0.0 when the position has no legal moves.
    """
    next_board = chess.Board(fen)
    moves = list(next_board.legal_moves)
    if not moves:
        return 0.0
    state = encode_board(next_board)
    states = np.broadcast_to(state, (len(moves), state.shape[0]))
    inputs = torch.from_numpy(
        np.concatenate([states, _encode_actions(moves)], axis=1)
    ).to(device)
    # Score all replies in one pass; targets must not carry gradients.
    with torch.no_grad():
        q_values = target_net(inputs).squeeze(1)
    best = q_values.max() if next_board.turn == chess.WHITE else q_values.min()
    return best.item()


def _train_step(batch):
    """Run one MSE gradient step on ``q_net``.

    ``batch`` is a sequence of ``(fen, move, reward, next_fen, done)``
    transitions. Returns the loss as a float.
    """
    fens, moves, rewards, next_fens, dones = zip(*batch)
    states = np.stack([encode_board(chess.Board(fen)) for fen in fens])
    inputs = torch.from_numpy(
        np.concatenate([states, _encode_actions(moves)], axis=1)
    ).to(device)
    # squeeze(1) keeps a 1-D batch even if batch_size is 1.
    current_q = q_net(inputs).squeeze(1)

    # Terminal transitions do not bootstrap: their target is the final reward.
    targets = [
        reward if done else reward + gamma * _next_state_value(next_fen)
        for reward, next_fen, done in zip(rewards, next_fens, dones)
    ]
    target_q = torch.tensor(targets, dtype=torch.float32, device=device)

    loss = nn.MSELoss()(current_q, target_q)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


for episode in range(1000):
    board = chess.Board()
    while not board.is_game_over():
        current_fen = board.fen()

        # Epsilon-greedy: a random legal move explores, otherwise play greedily.
        if random.random() < epsilon:
            move = random.choice(list(board.legal_moves))
        else:
            # time_limit=0 so get_best_move never pads the call by sleeping.
            move = get_best_move(board, q_net, time_limit=0)

        board.push(move)
        next_fen = board.fen()
        done = board.is_game_over()

        # Reward only at game end, from White's perspective.
        if done:
            result = board.result()
            reward = 1 if result == "1-0" else -1 if result == "0-1" else 0
        else:
            reward = 0

        buffer.push((current_fen, move, reward, next_fen, done))

    # One gradient step per finished game, once enough transitions are stored.
    if len(buffer) >= batch_size:
        _train_step(buffer.sample(batch_size))

    # Refresh the target network every sync_interval episodes.
    if episode % sync_interval == 0:
        target_net.load_state_dict(q_net.state_dict())

    # Decay exploration once per episode.
    epsilon = max(0.1, epsilon * 0.995)

In [None]:
def play_game():
    """Let the Q-network play a full game against itself from the start position.

    The board is printed after every move, followed by a "---" separator.
    The final result string is printed once the game is over.
    """
    game_board = chess.Board()
    while not game_board.is_game_over():
        chosen = get_best_move(game_board, q_net)
        game_board.push(chosen)
        # Same output as printing the board and the separator on two calls.
        print(game_board, "---", sep="\n")
    # "Kết quả" is Vietnamese for "Result".
    print("Kết quả:", game_board.result())


play_game()