In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import random
from chess_board import get_chess_board, square_to_uci, get_board_from_featurization
import chess
from generate_training_data import ChessDataset

In [None]:
# Supervised data: 2048 generated (board, move) examples, served in shuffled minibatches of 64.
train_data = ChessDataset(num_examples=2048)
train_data_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)

In [None]:
# Quick environment check: True when PyTorch can use a CUDA GPU (the `device` chosen in the
# next cell falls back to the CPU otherwise).
torch.cuda.is_available()

In [None]:
class ChessNetCNN(nn.Module):
    def __init__(self, hidden_size):
        super(ChessNetCNN, self).__init__()
        self.hidden_size = hidden_size
        self.input = nn.Conv2d(in_channels=12, out_channels=hidden_size, kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(in_channels=hidden_size, out_channels=hidden_size, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=hidden_size, out_channels=hidden_size, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=hidden_size, out_channels=hidden_size, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=hidden_size, out_channels=hidden_size, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(hidden_size)
        self.activation = nn.SELU()
        self.activation_fc = nn.ReLU()

        self.fc1 = nn.Linear(hidden_size * 8 * 8, 4096)
        self.out1 = nn.Linear(4096, 64)
        self.out2 = nn.Linear(4096, 64)

    def forward(self, x):
        x = self.activation(self.bn(self.input(x)))
        x = self.activation(self.bn(self.conv1(x)))
        x = self.activation(self.bn(self.conv2(x)))
        x = self.activation(self.bn(self.conv3(x)))
        x = self.activation(self.bn(self.conv4(x)))
        x = x.view(-1, self.hidden_size * 8 * 8)
        x = self.activation_fc(self.fc1(x))
        return self.out1(x), self.out2(x)
    
    
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
# Supervised setup: a 128-channel network, cross-entropy over the 64 squares for each head,
# Adam at 1e-3, and a StepLR schedule that halves the learning rate every 10 epochs.
# NOTE(review): over a 1000-epoch run the rate drops below 1e-12 after ~300 epochs,
# so the remaining epochs barely move the weights.
model = ChessNetCNN(hidden_size=128).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

In [None]:
# Supervised training: each head is trained with cross-entropy against the source and the
# destination square of the target move; the two losses are summed.
num_epochs = 1000
epoch_loss = float("nan")  # defined even when no epoch/batch runs, so the final print is safe
model.train()  # BatchNorm must use batch statistics while fitting
for epoch in range(num_epochs):
    running_loss = 0.0
    num_batches = 0

    for board, move in train_data_loader:
        board = board.to(device)
        sources = move[:, 0].to(device)  # take the source square of each move
        destinations = move[:, 1].to(device)  # take the destination square of each move

        pred_sources, pred_destinations = model(board)
        loss_from = criterion(pred_sources, sources)
        loss_to = criterion(pred_destinations, destinations)
        loss = loss_from + loss_to

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    scheduler.step()
    # Report the mean loss over the epoch instead of whichever batch happened to come last.
    epoch_loss = running_loss / num_batches if num_batches else float("nan")

    if epoch % 50 == 0:
        print(f'Epoch {epoch + 1}/{num_epochs} Loss: {epoch_loss:.4f}')
print(f"Final loss: {epoch_loss:.4f}")

In [None]:
# PATH = 'chess_net_CNN2.pth'
# torch.save(model.state_dict(), PATH)

In [None]:
import chess
from chess_board import get_chess_board, square_to_uci

new_board = chess.Board()

# Let the network play both sides from the starting position, always taking the arg-max
# source and destination squares, until the game ends or the board rejects a move.
# NOTE: the 4-character UCI built here never carries a promotion piece, so a pawn move to
# the last rank is always rejected.
was_training = model.training
model.eval()  # in train mode BatchNorm would normalise each single board by its own stats and update running stats
try:
    with torch.no_grad():
        count = 0
        try:
            while not new_board.is_game_over():
                featurized = torch.from_numpy(get_chess_board(new_board).reshape(1, 12, 8, 8).astype(np.float32)).to(device)

                predicted_source, predicted_destination = model(featurized)
                source = square_to_uci(torch.argmax(predicted_source, 1)[0].item())
                destination = square_to_uci(torch.argmax(predicted_destination, 1)[0].item())

                uci = source + destination
                print(uci)
                new_board.push_uci(uci)

                count += 1
        except ValueError:
            # chess.IllegalMoveError (not legal in this position) and chess.InvalidMoveError
            # (malformed UCI, e.g. source == destination in recent python-chess) both derive
            # from ValueError.
            print(f"Count={count}")
        else:
            print(f"Count={count} Result={new_board.result()}")
finally:
    model.train(was_training)  # restore the model's previous mode for later cells

In [None]:
# Rebuild the supervised network and load its saved weights; the following cells fine-tune it with RL.
model = ChessNetCNN(128).to(device)
# map_location lets a checkpoint saved on a GPU machine load on a CPU-only one (and vice versa).
model.load_state_dict(torch.load('chess_net_CNN2.pth', map_location=device))

In [None]:
import os
import chess.engine

# stockfish's evaluation for a position will be the reward for the RL algorithm.
# The binary location can be overridden with the STOCKFISH_PATH environment variable;
# the previously hard-coded local path is kept as the fallback.
STOCKFISH_PATH = os.environ.get("STOCKFISH_PATH", r"C:\Users\jaint\stockfish\stockfish-windows-x86-64-avx2")
engine = chess.engine.SimpleEngine.popen_uci(STOCKFISH_PATH)

def evaluate_board(board, pov=chess.WHITE, time_limit=0.1):
    """Score `board` with Stockfish, in pawns, from the point of view of `pov` (White by default).

    Centipawn scores are divided by 100. Forced mates are mapped to a large value that
    shrinks with the mate distance ``m`` reported by ``Score.mate()``: ``21 - m`` when
    `pov` is mating (m > 0), ``-21 - m`` when `pov` is being mated (m < 0). A position
    that is already checkmate scores +21 / -21.

    Args:
        board: position to evaluate.
        pov: colour whose perspective the score is reported from.
        time_limit: seconds Stockfish may spend on the analysis.

    Returns:
        A number that is positive when `pov` stands better.
    """
    if board.is_checkmate():
        # The side to move is the one that is mated. Resolving this here avoids the engine's
        # "mate 0" score, whose .mate() is 0 for both the winner and the loser.
        return -21 if board.turn == pov else 21

    result = engine.analyse(board, chess.engine.Limit(time=time_limit))
    evaluation = result["score"].pov(pov)
    if evaluation.is_mate():  # score() returns None for mate scores, so they are handled separately
        plies = evaluation.mate()
        if plies > 0:  # `pov` is the one checkmating
            return 21 - plies  # large positive score that decays with the distance to mate
        return -21 - plies  # `pov` is getting mated
    # Same point of view as the mate branch (the old `.relative` flipped sign on Black's turn).
    return evaluation.score() / 100

In [None]:
from copy import deepcopy

# Copy of the online network used to compute bootstrap targets; the episode loop
# re-syncs it from `model` when an episode ends.
target_network = deepcopy(model)

# Replay buffer of (next_state, state, action, reward, done) tuples and its size cap.
memory, max_memory = [], 10_000

epsilon = 0.2     # probability of playing a uniformly random legal move (exploration)
batch_size = 4    # transitions sampled per train() call
gamma = 0.99      # discount factor in the Bellman target

loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

def random_board(max_depth=30):
    depth = random.randint(0, max_depth)
    board = chess.Board()
    try:
        for _ in range(depth):
            board.push(random.choice(list(board.legal_moves)))
        return board
    except IndexError:
        return board


def choose_action(curr_board):
    """Choose a move for `curr_board` with an epsilon-greedy policy.

    With probability `epsilon` a uniformly random legal move is returned. Otherwise the
    network scores source and destination squares. Combinations are tried in descending
    score order (every destination for the best source first, then the next source, ...),
    and the first one that is a legal move is returned. If none is legal (e.g. only
    promotions are possible, which a 4-character UCI string cannot express), a random
    legal move is returned.

    Raises:
        IndexError: from ``random.choice`` when `curr_board` has no legal moves.
    """
    legal_moves = list(curr_board.legal_moves)
    if random.random() < epsilon:
        return random.choice(legal_moves)

    # Featurize the board we were given (previously the global `board` was used by mistake).
    tensor = torch.from_numpy(get_chess_board(curr_board).reshape(1, 12, 8, 8).astype(np.float32)).to(device)
    with torch.no_grad():  # action selection only; no autograd graph is needed
        source_scores, destination_scores = model(tensor)

    source_order = torch.argsort(source_scores, dim=1, descending=True)[0].tolist()
    destination_names = [
        square_to_uci(idx) for idx in torch.argsort(destination_scores, dim=1, descending=True)[0].tolist()
    ]
    legal_set = set(legal_moves)  # O(1) membership instead of scanning the list up to 4096 times

    for source_idx in source_order:
        source_name = square_to_uci(source_idx)
        for destination_name in destination_names:
            if source_name == destination_name:
                continue
            candidate = chess.Move.from_uci(source_name + destination_name)
            if candidate in legal_set:
                return candidate

    return random.choice(legal_moves)


def train():
    """Run one DQN update on a random minibatch drawn from the replay memory.

    Each `memory` entry is ``(next_state, state, action, reward, done)`` as appended by the
    episode loop. ``next_state`` is the featurized board after the move (or ``None`` when no
    move was made), ``state`` is a (12, 8, 8) float tensor, and ``action`` is a ``chess.Move``.

    The two output heads are treated as independent Q-functions over the source and the
    destination square. For each head the TD target is
    ``reward + gamma * max_a' Q_target(next_state, a') * (1 - done)``. The loss is the MSE
    between that target and the online network's Q-value for the square the move actually
    used, so gradients flow into `model`. Does nothing until `memory` holds at least
    `batch_size` transitions.
    """
    if len(memory) < batch_size:
        return

    batch = random.sample(memory, batch_size)
    next_states, states, actions, rewards, dones = zip(*batch)

    states = torch.stack(states).to(device)
    rewards = torch.tensor(rewards, dtype=torch.float32, device=device)
    dones = torch.tensor(dones, dtype=torch.float32, device=device)

    # Translate each move's squares into the model's output indices by inverting
    # square_to_uci, the same mapping used to decode predictions into moves.
    # Assumes square_to_uci maps 0..63 one-to-one onto square names — TODO confirm.
    index_of_square = {square_to_uci(i): i for i in range(64)}
    from_idx = torch.tensor(
        [index_of_square[chess.square_name(a.from_square)] for a in actions], dtype=torch.long, device=device
    )
    to_idx = torch.tensor(
        [index_of_square[chess.square_name(a.to_square)] for a in actions], dtype=torch.long, device=device
    )

    # Bootstrap values max_a' Q_target(s', a') per head; left at 0 where there is no next state.
    next_max_from = torch.zeros(batch_size, device=device)
    next_max_to = torch.zeros(batch_size, device=device)
    non_terminal_mask = torch.tensor([s is not None for s in next_states], dtype=torch.bool, device=device)
    if non_terminal_mask.any():
        non_terminal_next_states = torch.stack([
            torch.tensor(np.asarray(s, dtype=np.float32).reshape(12, 8, 8))
            for s in next_states if s is not None
        ]).to(device)
        with torch.no_grad():  # targets must not backpropagate into the target network
            next_q_from, next_q_to = target_network(non_terminal_next_states)
        # Use the maximum Q-VALUE (not its argmax index) as the bootstrap estimate.
        next_max_from[non_terminal_mask] = next_q_from.max(dim=1).values
        next_max_to[non_terminal_mask] = next_q_to.max(dim=1).values

    target_from = rewards + gamma * next_max_from * (1 - dones)
    target_to = rewards + gamma * next_max_to * (1 - dones)

    # Q-value of the square actually played; gather keeps the graph so backward() reaches `model`.
    q_from, q_to = model(states)
    current_q_from = q_from.gather(1, from_idx.unsqueeze(1)).squeeze(1)
    current_q_to = q_to.gather(1, to_idx.unsqueeze(1)).squeeze(1)

    loss = loss_fn(current_q_from, target_from) + loss_fn(current_q_to, target_to)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


# Self-play RL loop: each episode starts from a randomly played-out position, and the network
# (epsilon-greedy via choose_action) moves for whichever side is to move until the game ends.
for episode in range(100):

    board = random_board()

    while not board.is_game_over():
        state = torch.from_numpy(get_chess_board(board).astype(np.float32)).reshape(12, 8, 8).to(device)
        action = choose_action(board)

        if action in board.legal_moves:
            # NOTE(review): the reward is the evaluation of the position *before* `action` is
            # played, so it does not depend on the chosen move; evaluating after the push is
            # probably what was intended.
            reward = evaluate_board(board)
            board.push(action)
            # Terminal status must be read after the move: before it, the while-guard makes it always False.
            done = board.is_game_over()
            memory.append((get_chess_board(board), state, action, reward, done))
        else:
            # Illegal move: heavy penalty, no successor state, episode over.
            reward = -100
            done = 1
            memory.append((None, state, action, reward, done))

        # Keep the replay buffer bounded by discarding the oldest transition.
        if len(memory) > max_memory:
            memory.pop(0)

        if done:
            # Sync the target network with the online network when the episode ends.
            target_network.load_state_dict(model.state_dict())
            break
    train()
    print(f"Episode: {episode}")

In [None]:
# Persist only the weights (state_dict) of the RL-trained network.
PATH = 'chess_net_CNN_RL2.pth'
torch.save(obj=model.state_dict(), f=PATH)

In [None]:
# The RL checkpoint was saved from a ChessNetCNN(128); the hidden size must match,
# otherwise load_state_dict fails with size-mismatch errors.
model = ChessNetCNN(128).to(device)
# map_location lets a checkpoint saved on a GPU machine load on a CPU-only one (and vice versa).
model.load_state_dict(torch.load('chess_net_CNN_RL2.pth', map_location=device))

In [None]:
import chess
from chess_board import get_chess_board, square_to_uci

new_board = chess.Board()

# Let the network play both sides from the starting position, always taking the arg-max
# source and destination squares, until the game ends or the board rejects a move.
# NOTE: the 4-character UCI built here never carries a promotion piece, so a pawn move to
# the last rank is always rejected.
was_training = model.training
model.eval()  # in train mode BatchNorm would normalise each single board by its own stats and update running stats
try:
    with torch.no_grad():
        count = 0
        try:
            while not new_board.is_game_over():
                featurized = torch.from_numpy(get_chess_board(new_board).reshape(1, 12, 8, 8).astype(np.float32)).to(device)

                predicted_source, predicted_destination = model(featurized)
                source = square_to_uci(torch.argmax(predicted_source, 1)[0].item())
                destination = square_to_uci(torch.argmax(predicted_destination, 1)[0].item())

                uci = source + destination
                print(uci)
                new_board.push_uci(uci)

                count += 1
        except ValueError:
            # chess.IllegalMoveError (not legal in this position) and chess.InvalidMoveError
            # (malformed UCI, e.g. source == destination in recent python-chess) both derive
            # from ValueError.
            print(f"Count={count}")
        else:
            print(f"Count={count} Result={new_board.result()}")
finally:
    model.train(was_training)  # restore the model's previous mode for later cells