In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import random
from chess_board import get_chess_board
import chess
from generate_training_data import ChessDataset

In [2]:
class ChessNetCNN(nn.Module):
    """Convolutional network mapping a (batch, 12, 8, 8) board tensor to 4096 move scores.

    Three padded convolutions keep the 8x8 spatial size, after which the feature
    map is flattened and passed through a two-layer dense head.
    """

    def __init__(self):
        super().__init__()

        # Convolutional trunk (padding keeps height/width unchanged)
        self.input = nn.Conv2d(12, 32, kernel_size=5, padding=2)   # -> (batch, 32, 8, 8)
        self.conv1 = nn.Conv2d(32, 64, kernel_size=3, padding=1)   # -> (batch, 64, 8, 8)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)  # -> (batch, 128, 8, 8)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.activation = nn.ReLU()

        # Dense head
        self.fc1 = nn.Linear(128 * 8 * 8, 1024)
        self.dropout = nn.Dropout(0.5)
        self.out = nn.Linear(1024, 4096)

    def forward(self, x):
        """Return the raw (un-normalised) scores of shape (batch, 4096)."""
        relu = self.activation

        # Convolutional trunk: the first layer has no batch norm, the next two do
        features = relu(self.input(x))
        features = relu(self.bn1(self.conv1(features)))
        features = relu(self.bn2(self.conv2(features)))

        # Flatten to one row per example before the dense head
        flat = features.view(-1, 128 * 8 * 8)
        hidden = self.dropout(relu(self.fc1(flat)))
        return self.out(hidden)
    
    
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Supervised pre-training setup: network, loss, optimiser and learning-rate schedule.
model = ChessNetCNN()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
# Halve the learning rate after every 20 calls to scheduler.step()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, gamma=0.5, step_size=20)

In [None]:
# 8192 training examples, served in shuffled mini-batches of 64.
train_data = ChessDataset(num_examples=8192)
train_data_loader = DataLoader(dataset=train_data, shuffle=True, batch_size=64)

In [None]:
# Supervised training: treat each (source, destination) pair as one of 4096 classes.
num_epochs = 200
model.train()  # other cells call model.eval(); make the training mode explicit here

for epoch in range(num_epochs):
    epoch_loss_sum = 0.0
    examples_seen = 0

    for board, move in train_data_loader:
        board = board.to(device)
        sources = move[:, 0].to(device)  # take the source square of each move
        destinations = move[:, 1].to(device)  # take the destination square of each move

        # Class index of each move. CrossEntropyLoss accepts integer class targets,
        # which gives the same loss as a one-hot row without allocating a
        # (batch, 4096) target matrix on every step.
        targets = ((64 * sources) + destinations).long()

        predicted_array = model(board)
        loss = criterion(predicted_array, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate an example-weighted sum so the logged value is the epoch mean
        # rather than whatever the last mini-batch happened to produce.
        batch_examples = targets.shape[0]
        epoch_loss_sum += loss.item() * batch_examples
        examples_seen += batch_examples

    scheduler.step()
    mean_loss = epoch_loss_sum / max(examples_seen, 1)

    if epoch % 25 == 0:
        print(f'Epoch {epoch + 1}/{num_epochs} Loss: {mean_loss:.4f}')
print(f"Final loss: {mean_loss:.4f}")

In [None]:
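# Optional: uncomment to save the supervised weights; the next cell reloads them from 'new_net.pth'.
# PATH = 'new_net.pth'
# torch.save(model.state_dict(), PATH)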

In [3]:
# Rebuild the network and restore the supervised weights. map_location remaps the
# stored tensors onto the current device, so a checkpoint written on a GPU machine
# also loads on a CPU-only one.
model = ChessNetCNN().to(device)
model.load_state_dict(torch.load('new_net.pth', map_location=device, weights_only=True))

<All keys matched successfully>

In [4]:
# Checking how well the model can play after initial training - goal is for it to play legal moves, regardless of how good or bad they are

board = chess.Board()
uci_moves = []  # every move the network proposed, including a final rejected one
count = 0  # number of moves that were actually played
model.eval()
with torch.no_grad():
    try:
        while not board.is_game_over():
            featurized = torch.from_numpy(get_chess_board(board).reshape(1, 12, 8, 8).astype(np.float32)).to(device)

            # Greedy decoding: flat index -> (source square, destination square)
            predicted_array = model(featurized)
            predicted_move = torch.argmax(predicted_array, dim=1).item()
            source = chess.square_name(predicted_move // 64)
            destination = chess.square_name(predicted_move % 64)

            uci = source + destination
            uci_moves.append(uci)
            board.push_uci(uci)

            count += 1

    except (chess.IllegalMoveError, chess.InvalidMoveError):
        # IllegalMoveError: a well-formed move that is not legal in this position.
        # InvalidMoveError: a malformed UCI string, which python-chess raises when the
        # predicted source and destination squares coincide (e.g. "a1a1").
        print(f"Count={count}")


Count=41


In [None]:
import chess.pgn

# Build a PGN record of the game currently held in `board`.
game = chess.pgn.Game()

# metadata
for tag, value in (
    ("Event", "Example Game"),
    ("Site", "None"),
    ("Date", "2024.12.21"),
    ("Round", "1"),
    ("White", "Chess Net"),
    ("Black", "Also Chess Net"),
    ("Result", "*"),
):
    game.headers[tag] = value

# Append each played move to the end of the line built so far
node = game
for move in board.move_stack:
    node = node.add_variation(move)

with open("game.pgn", "w") as f:
    f.write(str(game) + "\n")

In [4]:
import os

import chess.engine

# Location of the Stockfish binary. Set the STOCKFISH_PATH environment variable to run
# the notebook on another machine; the original hard-coded path remains the fallback.
STOCKFISH_PATH = os.environ.get("STOCKFISH_PATH", r"C:\Users\jaint\stockfish\stockfish-windows-x86-64-avx2")

# NOTE: this starts an engine subprocess; call engine.quit() when finished with it.
engine = chess.engine.SimpleEngine.popen_uci(STOCKFISH_PATH)

def evaluate_board(board, color=1):
    """
    Score a position with Stockfish from the point of view of one side.

    :param board: chess.Board, position to be evaluated
    :param color: int, perspective of the bot, 0 = black and 1 = white
    :return: score: number, stockfish evaluation of the position for `color`.
        Regular evaluations are centipawns / 100. Forced mates map to
        +/-(100 - distance to mate), so quicker mates score higher. Checkmated
        positions score +/-100, other finished games score 0, and 0 is also
        returned when the engine call fails.
    """
    try:
        if board.is_checkmate():
            # The side to move has been mated. Answer directly: the engine reports a
            # mate distance of 0 here, which carries no sign to tell the winner apart.
            score = -100 if board.turn == chess.WHITE else 100
        elif board.is_game_over():
            score = 0  # any other finished game is a draw
        else:
            result = engine.analyse(board, chess.engine.Limit(depth=15, time=0.2))  # gives stockfish score (scaled up by 100)
            # Read BOTH branches from White's point of view so that the final
            # `color` flip is valid no matter which side is to move.
            evaluation = result["score"].pov(chess.WHITE)

            if evaluation.is_mate():  # score() returns None if the position has forced mate
                mate_in = evaluation.mate()
                if mate_in > 0:  # White is the one checkmating
                    score = 100 - mate_in  # large positive score that decays with the distance to mate
                else:
                    score = -100 - mate_in  # Black is the one checkmating
            else:
                score = evaluation.score() / 100
    except Exception:
        # Best effort: an engine failure yields a neutral score. Unlike a bare
        # `except:`, this lets KeyboardInterrupt / SystemExit propagate.
        return 0

    return score if color else -score

In [5]:
def random_board(max_depth=20):
    """
    Return a board reached by playing random legal moves from the starting position.

    The number of moves is drawn uniformly from 0..max_depth (inclusive). The walk
    stops early if a position with no legal moves is reached.

    :param max_depth: int, upper bound on the number of random moves played
    :return: chess.Board, the resulting position
    """
    moves_to_play = random.randint(0, max_depth)
    position = chess.Board()

    played = 0
    while played < moves_to_play:
        candidates = list(position.legal_moves)
        if not candidates:  # nothing left to play; return the position as it is
            break
        position.push(random.choice(candidates))
        played += 1

    return position

In [None]:
from copy import deepcopy
from collections import deque

# --- Replay memory ---
max_memory = 10_000
memory = deque(maxlen=max_memory)  # once full, the oldest transitions are dropped automatically

# --- Hyper-parameters ---
epsilon = 0.2  # probability of playing a random (exploratory) move
batch_size = 16
gamma = 0.995  # discount applied to the next state's value in the Bellman target

checkpoint_path = 'RL_checkpoint.pth'

# --- Networks and optimisation ---
target_network = deepcopy(model)  # separate copy used to compute the targets
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)


def choose_action(curr_board):
    """
    Pick a move for the side to move with an epsilon-greedy policy.

    With probability `epsilon` a uniformly random legal move is returned; otherwise
    the legal move with the highest network score is returned.

    :param curr_board: chess.Board, position to move in (needs at least one legal move)
    :return: (move, index): the chosen chess.Move and its flat index
        64 * from_square + to_square into the 4096 network outputs
    :raises ValueError: if the position has no legal moves
    """
    legal_moves = list(curr_board.legal_moves)
    if not legal_moves:
        raise ValueError("choose_action requires a position with at least one legal move")

    if random.random() < epsilon:  # exploration
        random_move = random.choice(legal_moves)
        return random_move, random_move.from_square * 64 + random_move.to_square

    # Map each legal flat index back to the actual Move object. The 64*64 encoding
    # has no slot for the promotion piece, so rebuilding the move from square names
    # would drop it and yield an illegal move; promotions that share an index
    # collapse to the queen promotion instead.
    index_to_move = {}
    for move in legal_moves:
        idx = move.from_square * 64 + move.to_square
        if idx not in index_to_move or move.promotion == chess.QUEEN:
            index_to_move[idx] = move

    tensor = torch.from_numpy(get_chess_board(curr_board).reshape(1, 12, 8, 8).astype(np.float32)).to(device)
    with torch.no_grad():  # action selection needs no autograd graph
        scores = model(tensor).squeeze(0)

    # Take the argmax over legal indices only. Comparing raw scores gives the same
    # ranking as comparing softmax probabilities, but cannot degenerate to index 0
    # when every legal probability underflows to zero.
    legal_indices = torch.tensor(sorted(index_to_move), dtype=torch.long, device=device)
    best_move_index = legal_indices[torch.argmax(scores[legal_indices])].item()
    return index_to_move[best_move_index], best_move_index


def train():
    """
    Run one Q-learning optimisation step on a random mini-batch from `memory`.

    Each stored transition is (next_state, state, action_index, reward, done), where
    next_state is None for terminal transitions. Nothing happens until `memory`
    holds at least `batch_size` transitions.

    :return: float, loss of this step, or None if no step was taken
    """
    if len(memory) < batch_size:
        return None

    batch = random.sample(memory, batch_size)
    next_states, states, actions, rewards, dones = zip(*batch)

    states = torch.stack(states)
    actions = torch.tensor(actions, dtype=torch.int64, device=device).reshape(batch_size)
    rewards = torch.tensor(rewards, dtype=torch.float32, device=device).reshape(batch_size)
    dones = torch.tensor(dones, dtype=torch.float32, device=device).reshape(batch_size)

    # Q(s, a) of the action that was actually taken in each transition
    predicted_q_values = model(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    # Terminal transitions keep a next-state value of 0
    non_terminal_mask = torch.tensor([s is not None for s in next_states], dtype=torch.bool, device=device)
    next_q_values = torch.zeros(batch_size, device=device)

    if non_terminal_mask.any():
        # Next states are stored exactly as get_chess_board returned them; reshape to
        # the (N, 12, 8, 8) layout that every other network call in this notebook uses.
        non_terminal_next_states = torch.stack(
            [torch.as_tensor(s, dtype=torch.float32) for s in next_states if s is not None]
        ).reshape(-1, 12, 8, 8).to(device)
        with torch.no_grad():  # targets are constants: no gradients through the target network
            next_q_values[non_terminal_mask] = torch.max(target_network(non_terminal_next_states), dim=1).values

    target_q_values = rewards + gamma * next_q_values * (1 - dones)

    loss = loss_fn(predicted_q_values, target_q_values)  # (input, target) order
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


# Self-play loop: play one game per episode, record transitions, train periodically.
for episode in range(10):

    board = random_board()

    while not board.is_game_over():
        state = torch.from_numpy(get_chess_board(board).astype(np.float32)).reshape(12, 8, 8).to(device)
        action, i = choose_action(board)

        # Apply the move BEFORE scoring it. Evaluating first (as before) made the reward
        # independent of the chosen action and left `done` permanently False, so the
        # terminal branch could never run.
        board.push(action)
        # NOTE(review): evaluate_board is called with its default `color` for both
        # sides' moves - confirm that this is the intended perspective for self-play.
        reward = evaluate_board(board)
        done = board.is_game_over()

        # Terminal transitions carry no next state
        next_state = None if done else get_chess_board(board)
        memory.append((next_state, state, i, reward, done))

    if episode % 5 == 0:
        train()
        epsilon = max(0.05, epsilon * 0.995)  # decay exploration, but never below 5%

    if episode % 200 == 0:
        torch.save({'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'target_network_state_dict': target_network.state_dict(),
                    'epsilon': epsilon,
                    'memory': memory}, checkpoint_path)
    if episode % 1 == 0:  # currently every episode; raise the modulus to sync less often
        target_network.load_state_dict(model.state_dict())
        print(f"Episode: {episode}")

In [None]:
# Save only the model's weights (its state_dict) to disk.
PATH = 'RL_savepoint.pth'
torch.save(obj=model.state_dict(), f=PATH)

In [None]:
# Restore RL-trained weights into a fresh network.
# weights_only=True matches the earlier load cell and avoids unpickling arbitrary
# objects; map_location lets a GPU-saved file load on a CPU-only machine.
# NOTE(review): this file name differs from 'RL_savepoint.pth' written by the
# previous cell - confirm which checkpoint is meant to be loaded.
model = ChessNetCNN().to(device)
model.load_state_dict(torch.load('chess_net_CNN_RL2.pth', map_location=device, weights_only=True))

In [None]:
import chess.pgn

# Let the loaded network play itself greedily until the game ends or it proposes a bad move.
board = chess.Board()
uci_moves = []  # every move the network proposed, including a final rejected one
count = 0  # number of moves that were actually played
model.eval()
with torch.no_grad():
    try:
        while not board.is_game_over():
            featurized = torch.from_numpy(get_chess_board(board).reshape(1, 12, 8, 8).astype(np.float32)).to(device)

            # Greedy decoding: flat index -> (source square, destination square)
            predicted_array = model(featurized)
            predicted_move = torch.argmax(predicted_array, dim=1).item()
            source = chess.square_name(predicted_move // 64)
            destination = chess.square_name(predicted_move % 64)

            uci = source + destination
            uci_moves.append(uci)
            board.push_uci(uci)

            count += 1

    except (chess.IllegalMoveError, chess.InvalidMoveError):
        # IllegalMoveError: a well-formed move that is not legal in this position.
        # InvalidMoveError: a malformed UCI string, which python-chess raises when the
        # predicted source and destination squares coincide (e.g. "a1a1").
        print(f"Count={count}")


# Build a PGN record of the game currently held in `board`.
game = chess.pgn.Game()

# metadata
pgn_tags = {
    "Event": "Example Game",
    "Site": "None",
    "Date": "2024.12.21",
    "Round": "1",
    "White": "Chess Net",
    "Black": "Also Chess Net",
    "Result": "*",
}
for tag in pgn_tags:
    game.headers[tag] = pgn_tags[tag]

# Append each played move to the end of the line built so far
node = game
for move in board.move_stack:
    node = node.add_variation(move)

with open("game_one.pgn", "w") as f:
    f.write(str(game) + "\n")