In [1]:
import torch
import torch.nn as nn
from generate_training_data import ChessDataset
from torch.utils.data import DataLoader
import numpy as np

In [2]:
class ChessNet(nn.Module):
    def __init__(self):
        super(ChessNet, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(14 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 256)
        self.out1 = nn.Linear(256, 64)  # source square of the piece it wishes to move
        self.out2 = nn.Linear(256, 64)  # target square where it wants to piece to land

    def forward(self, x):
        x = self.relu(self.fc2(self.relu(self.fc1(x))))
        from_square = self.out1(x)
        to_square = self.out2(x)
        return from_square, to_square


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ChessNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)  # Learning rate scheduler
loss_fn = nn.CrossEntropyLoss()

In [3]:
class EarlyStopping:
    """
    Will stop the training if loss doesn't improve for a given number of epochs'
    """
    def __init__(self, tolerance=5, min_delta=0):

        self.tolerance = tolerance
        self.min_delta = min_delta
        self.counter = 0
        self.early_stop = False

    def __call__(self, prev_loss, curr_loss):
        if abs(curr_loss - prev_loss) < self.min_delta:
            self.counter += 1
            if self.counter >= self.tolerance:  
                self.early_stop = True
                
early_stopping = EarlyStopping(tolerance=2, min_delta=0.0001)

In [5]:
batch_size = 128
train_data = ChessDataset(num_examples=32768)
train_data_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

In [6]:
num_epochs = 100
prev_loss = float("inf")
for epoch in range(num_epochs):
    
    for batch, (x_train, y_train) in enumerate(train_data_loader):
        batch_x = x_train.to(device)
        sources = y_train[:, 0].to(device)
        destinations = y_train[:, 1].to(device)
        
        batch_x = batch_x.reshape(batch_size, 14 * 8 * 8)

        optimizer.zero_grad()
        predicted_source, predicted_destination = model(batch_x)

        loss_from = loss_fn(predicted_source, sources)
        loss_to = loss_fn(predicted_destination, destinations)

        loss = loss_from + loss_to
        loss.backward()
        optimizer.step()

    scheduler.step()  # Update the learning rate
    
    early_stopping(curr_loss=loss.item(), prev_loss=prev_loss)
    if early_stopping.early_stop:
        print("Early stopping")
        break
    prev_loss = loss.item()
    
    if epoch % 10 == 0:
        print(f'Epoch {epoch + 1}/{num_epochs} Loss: {loss.item():.4f}')
        
print("Final loss: ", loss.item())


Epoch 1/100 Loss: 6.9434
Epoch 11/100 Loss: 2.8964
Epoch 21/100 Loss: 2.7492
Epoch 31/100 Loss: 2.4871
Epoch 41/100 Loss: 2.5214
Epoch 51/100 Loss: 2.4575
Epoch 61/100 Loss: 2.5155
Epoch 71/100 Loss: 2.6629
Epoch 81/100 Loss: 2.6233
Epoch 91/100 Loss: 2.4574
Final loss:  2.733704090118408


In [9]:
PATH = 'chess_net.pth'
torch.save(model.state_dict(), PATH)

In [6]:
import chess.engine

engine = chess.engine.SimpleEngine.popen_uci(r"C:\Users\jaint\stockfish\stockfish-windows-x86-64-avx2")  
# stockfish's evaluation for a position will be the reward for the RL algorithm

def evaluate_board(board):
    result = engine.analyse(board, chess.engine.Limit(time=0.1))  # gives stockfish score of the current position (scaled up by 100)
    evaluation = result["score"]
    if evaluation.is_mate():  # score() returns None if the position has forced mate - so it is handled separately
        plies = evaluation.pov(chess.WHITE).mate()
        if plies > 0:  # White is the one checkmating
            return (21 - plies) * 100  # return a large positive score that decays with the number of moves till mate
        return (-21 - plies) * 100  # Black is the one checkmating
    return result["score"].relative.score()


In [3]:
model = ChessNet().to(device)
model.load_state_dict(torch.load('chess_net.pth'))

<All keys matched successfully>

In [10]:
import random
from chess_board import get_chess_board, square_to_uci
memory = []
gamma = 0.99
batch_size = 64

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


def choose_action(board):
    if random.random() < 0.2:  # 20% chance to make a random legal move
        return random.choice(list(board.legal_moves))
    else:
        tensor = torch.from_numpy(get_chess_board(board).reshape(-1, 14 * 8 * 8).astype(np.float32)).to(device)
        q_values = model(tensor)
        move_source, move_destination = q_values
        move_source = square_to_uci(torch.argmax(move_source, 1)[0].data.item())
        move_destination = square_to_uci(torch.argmax(move_destination, 1)[0].data.item())
        if move_source == move_destination:  # NULL move
            UCI = '0000'
        else:
            UCI = move_source + move_destination
        
        return chess.Move.from_uci(UCI)  # convert it to chess' Move class
    
    
def train_model():

    if len(memory) < batch_size:
        return
    batch = random.sample(memory, batch_size)
    states, rewards, next_states, dones = zip(*batch)

    states = torch.stack(states).to(device)  # stack all the states together to pass into model
    rewards = torch.tensor(rewards, dtype=torch.float32).reshape(batch_size, -1).to(device)  # Convert rewards to a tensor
    dones = torch.tensor(dones, dtype=torch.float32).reshape(batch_size, -1).to(device)  # Convert dones to a tensor

    q_values = model(states)  # Predict Q-values for current states
    from_squares = q_values[0]
    to_squares = q_values[1]

    from_squares = torch.argmax(from_squares, dim=2)  # get the actual square because MSE loss is used
    to_squares = torch.argmax(to_squares, dim=2)

    next_from_squares = torch.zeros(size=(batch_size, 1, 64)) 
    next_to_squares = torch.zeros(size=(batch_size, 1, 64))
    non_terminal_mask = torch.tensor([s is not None for s in next_states], dtype=torch.bool)  
    # in the case for terminal states (i.e, no next state) the next state is None so it needs to be ignored

    if non_terminal_mask.any():
        non_terminal_next_states = torch.stack([s for s in next_states if s is not None])
        next_from_squares[non_terminal_mask], next_to_squares[non_terminal_mask] = model(non_terminal_next_states)


    next_from_squares = torch.argmax(next_from_squares, dim=2)
    next_to_squares = torch.argmax(next_to_squares, dim=2)

    target_from_values = rewards + gamma * next_from_squares * (1 - dones)  # bellman equation
    target_to_values = rewards + gamma * next_to_squares * (1 - dones)

    loss_f = loss_fn(from_squares, target_from_values)
    loss_t = loss_fn(to_squares, target_to_values)

    loss = loss_f + loss_t
    loss.requires_grad = True  # torch.argmax has requires_grad as False, so I explicitly set it to True
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


for episode in range(10000):
    board = chess.Board()
    # Make two random moves so that model doesn't just play the same game everytime
    board.push(random.choice(list(board.legal_moves)))
    board.push(random.choice(list(board.legal_moves)))
    while not board.is_game_over():
        state = torch.from_numpy(get_chess_board(board).reshape(-1, 14 * 8 * 8).astype(np.float32))
        action = choose_action(board)

        if action in board.legal_moves:
            board.push(action)
            next_state = torch.from_numpy(get_chess_board(board).reshape(-1, 14 * 8 * 8).astype(np.float32))
            reward = evaluate_board(board)
            done = board.is_game_over()
        else:
            reward = -15000
            next_state = None
            done = 1
        memory.append((state, reward, next_state, done))
        if done:
            break
    train_model()
    if episode % 500 == 0:
        print(f"Episode: {episode + 1}")


Episode: 1
Episode: 501
Episode: 1001
Episode: 1501
Episode: 2001
Episode: 2501
Episode: 3001
Episode: 3501
Episode: 4001
Episode: 4501
Episode: 5001
Episode: 5501
Episode: 6001
Episode: 6501
Episode: 7001
Episode: 7501
Episode: 8001
Episode: 8501
Episode: 9001
Episode: 9501


In [4]:
PATH = 'chess_net_linear_RL.pth'
torch.save(model.state_dict(), PATH)

In [4]:
# Testing

import chess
from chess_board import get_chess_board, square_to_uci
new_board = chess.Board()

with torch.no_grad():
    
    while not new_board.is_game_over():
        featurized = torch.from_numpy(get_chess_board(new_board).reshape(-1, 14 * 8 * 8).astype(np.float32))
        
        predicted_source, predicted_destination = model(featurized)
        source = square_to_uci(torch.argmax(predicted_source, 1)[0].data.item())
        destination = square_to_uci(torch.argmax(predicted_destination, 1)[0].data.item())
        
        uci = source + destination
        print(uci)
        new_board.push_uci(uci)
        

e2e4
e7e5
g1f3
b8c6
f1c4
g8f6
d2d3
f8d6
e1g1
e8g8
b1c3
a7c5


IllegalMoveError: illegal uci: 'a7c5' in r1bq1rk1/pppp1ppp/2nb1n2/4p3/2B1P3/2NP1N2/PPP2PPP/R1BQ1RK1 b - - 4 6