In [4]:
!pip install chess

Collecting chess
  Using cached chess-1.10.0-py3-none-any.whl.metadata (19 kB)
Using cached chess-1.10.0-py3-none-any.whl (154 kB)
Installing collected packages: chess
Successfully installed chess-1.10.0


In [1]:
import torch
import torch.nn as nn
from generate_training_data import load_dataset
import numpy as np

In [2]:
class ChessNet(nn.Module):
    def __init__(self):
        super(ChessNet, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(14 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 256)
        self.out1 = nn.Linear(256, 64)
        self.out2 = nn.Linear(256, 64)

    def forward(self, x):
        x = self.relu(self.fc2(self.relu(self.fc1(x))))
        from_square = self.out1(x)
        to_square = self.out2(x)
        return from_square, to_square


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ChessNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)  # Learning rate scheduler
loss_fn = nn.CrossEntropyLoss()

In [53]:
class EarlyStopping:
    def __init__(self, tolerance=5, min_delta=0):

        self.tolerance = tolerance
        self.min_delta = min_delta
        self.counter = 0
        self.early_stop = False

    def __call__(self, prev_loss, curr_loss):
        if abs(curr_loss - prev_loss) < self.min_delta:
            self.counter += 1
            if self.counter >= self.tolerance:  
                self.early_stop = True
                
early_stopping = EarlyStopping(tolerance=2, min_delta=0.0001)

In [54]:
x_train, y_train = load_dataset(2048)
x_train = torch.from_numpy(x_train.reshape(-1, 14 * 8 * 8).astype(np.float32)).to(device)
y_train_from = torch.tensor([move[0] for move in y_train], dtype=torch.long).to(device)
y_train_to = torch.tensor([move[1] for move in y_train], dtype=torch.long).to(device)

In [55]:
num_epochs = 100
num_examples = len(x_train)
batch_size = 16
prev_loss = float("inf")
for epoch in range(num_epochs):
    
    for i in range(0, len(x_train), batch_size):
        batch_x = x_train[i:i + batch_size]
        batch_y_from = y_train_from[i:i + batch_size]
        batch_y_to = y_train_to[i:i + batch_size]

        optimizer.zero_grad()
        predicted_source, predicted_destination = model(batch_x)

        loss_from = loss_fn(predicted_source, batch_y_from)
        loss_to = loss_fn(predicted_destination, batch_y_to)

        loss = loss_from + loss_to
        loss.backward()
        optimizer.step()

    scheduler.step()  # Update the learning rate
    
    early_stopping(curr_loss=loss.item(), prev_loss=prev_loss)
    if early_stopping.early_stop:
        print("Early stopping")
        break
    prev_loss = loss.item()
    
    if epoch % 10 == 0:
        print(f'Epoch {epoch + 1}/{num_epochs} Loss: {loss.item():.4f}')
        
print("Final loss: ", loss.item())


Epoch 1/100 Loss: 4.8713
Epoch 11/100 Loss: 0.5512
Epoch 21/100 Loss: 0.1975
Epoch 31/100 Loss: 0.1901
Early stopping
Final loss:  0.1882476508617401


In [56]:
PATH = 'chess_net.pth'
torch.save(model.state_dict(), PATH)

In [57]:
import chess
from chess_board import get_chess_board, square_to_coord
new_board = chess.Board()

model = ChessNet().to(device)
model.load_state_dict(torch.load('chess_net.pth'))

with torch.no_grad():
    
    num_moves = 16
    while not new_board.is_game_over() and num_moves:
        featurized = torch.from_numpy(get_chess_board(new_board).reshape(-1, 14 * 8 * 8).astype(np.float32))
        
        predicted_source, predicted_destination = model(featurized)
        source = square_to_coord(torch.argmax(predicted_source, 1)[0].data.item())
        destination = square_to_coord(torch.argmax(predicted_destination, 1)[0].data.item())
        
        uci = source + destination
        print(uci)
        new_board.push_uci(uci)
        
        # print(new_board)
        # print("\n")
        
        num_moves -= 1
    
    # predicted_source, predicted_destination = model(featurized)
    # source = square_to_coord(torch.argmax(predicted_source, 1)[0].data.item())
    # destination = square_to_coord(torch.argmax(predicted_destination, 1)[0].data.item())
    # 
    # uci = source + destination
    # new_board.push_uci(uci)
    

e2e4
e7e5
g1f3
b8c6
f1b5
a7a6
b5a4
g8f6
e1g1
f6e4
f1e1
e4c5
a4c6
d7c6
f3e5
d8e6


IllegalMoveError: illegal uci: 'd8e6' in r1bqkb1r/1pp2ppp/p1p5/2n1N3/8/8/PPPP1PPP/RNBQR1K1 b kq - 0 8

In [None]:
class ChessNetCNN(nn.Module):
    def __init__(self):
        super(ChessNetCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=14, out_channels=48, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=48, out_channels=96, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(p=0.2)
        
        self.fc1 = nn.Linear(0, 256)
        
    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.relu(self.conv2(x))

In [6]:
import chess.engine

engine = chess.engine.SimpleEngine.popen_uci(r"C:\Users\jaint\stockfish\stockfish-windows-x86-64-avx2")

def evaluate_board(board):
    result = engine.analyse(board, chess.engine.Limit(time=0.1))  # gives stockfish score of the current position (scaled up by 100)
    evaluation = result["score"]
    if evaluation.is_mate():
        plies = evaluation.pov(chess.WHITE).mate()
        if plies > 0:
            return (31 - plies) * 100  # return a large positive score that decays with the number of moves till mate
        return (-31 - plies) * 100
    return result["score"].relative.score()


In [59]:
model = ChessNet().to(device)
model.load_state_dict(torch.load('chess_net.pth'))

<All keys matched successfully>

In [14]:
import random
from chess_board import get_chess_board
memory = []
gamma = 0.99
batch_size = 256

random.seed(42)

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


def choose_action(board):
    if random.random() < 0.2:
        return random.choice(list(board.legal_moves))
    else:
        tensor = torch.from_numpy(get_chess_board(board).reshape(-1, 14 * 8 * 8).astype(np.float32)).to(device)
        q_values = model(tensor)
        move_source, move_destination = q_values
        move_source = square_to_coord(torch.argmax(move_source, 1)[0].data.item())
        move_destination = square_to_coord(torch.argmax(move_destination, 1)[0].data.item())
        
        UCI = move_source + move_destination
        
        return chess.Move.from_uci(UCI)
    
    
def train_model():

    if len(memory) < batch_size:
        return
    batch = random.sample(memory, batch_size)
    states, actions, rewards, next_states, dones = zip(*batch)

    states = torch.stack(states).to(device)  # Convert states to a batch tensor
    rewards = torch.tensor(rewards, dtype=torch.float32).reshape(batch_size, -1).to(device)  # Convert rewards to a tensor
    dones = torch.tensor(dones, dtype=torch.float32).reshape(batch_size, -1).to(device)  # Convert dones to a tensor
        
    q_values = model(states)  # Predict Q-values for current states
    from_squares = q_values[0]
    to_squares = q_values[1]
    
    from_squares = torch.argmax(from_squares, dim=2)
    to_squares = torch.argmax(to_squares, dim=2)
    
    next_from_squares = torch.zeros(size=(batch_size, 1, 64)) 
    next_to_squares = torch.zeros(size=(batch_size, 1, 64))
    non_terminal_mask = torch.tensor([s is not None for s in next_states], dtype=torch.bool)

    if non_terminal_mask.any():
        non_terminal_next_states = torch.stack([s for s in next_states if s is not None])
        next_from_squares[non_terminal_mask], next_to_squares[non_terminal_mask] = model(non_terminal_next_states)


    next_from_squares = torch.argmax(next_from_squares, dim=2)
    next_to_squares = torch.argmax(next_to_squares, dim=2)
        
    target_from_values = rewards + gamma * next_from_squares * (1 - dones)
    target_to_values = rewards + gamma * next_to_squares * (1 - dones)

    loss_f = loss_fn(from_squares, target_from_values)
    loss_t = loss_fn(to_squares, target_to_values)

    loss = loss_f + loss_t
    loss.requires_grad = True
    optimizer.zero_grad()
    loss.backward()  # Backpropagate the loss
    optimizer.step()  # Update the network weights


for episode in range(10000):
    board = chess.Board()
    while not board.is_game_over():
        state = torch.from_numpy(get_chess_board(board).reshape(-1, 14 * 8 * 8).astype(np.float32))
        action = choose_action(board)

        if action in board.legal_moves:
            board.push(action)
            next_state = torch.from_numpy(get_chess_board(board).reshape(-1, 14 * 8 * 8).astype(np.float32))
            reward = evaluate_board(board)
            done = board.is_game_over()
        else:
            reward = -10000
            next_state = None
            done = 1
        memory.append((state, action, reward, next_state, done))
        if done:
            break
    train_model()
    if episode % 200 == 0:
        print(f"Episode: {episode + 1}")


Episode: 1


KeyboardInterrupt: 

In [9]:
PATH = 'chess_net_linear_RL.pth'
torch.save(model.state_dict(), PATH)

In [12]:
model = ChessNet().to(device)
model.load_state_dict(torch.load('chess_net_linear_RL.pth'))

<All keys matched successfully>

In [13]:
import chess
from chess_board import get_chess_board, square_to_coord
new_board = chess.Board()

with torch.no_grad():
    
    while not new_board.is_game_over():
        featurized = torch.from_numpy(get_chess_board(new_board).reshape(-1, 14 * 8 * 8).astype(np.float32))
        
        predicted_source, predicted_destination = model(featurized)
        source = square_to_coord(torch.argmax(predicted_source, 1)[0].data.item())
        destination = square_to_coord(torch.argmax(predicted_destination, 1)[0].data.item())
        
        uci = source + destination
        print(uci)
        new_board.push_uci(uci)
        

e2e4
e7e5
g1f3
b8c6
f1b5
a7a6
b5a4
g8f6
e1g1
f6e4
f1e1
e4c5
a4c6
d7c6
f3e5
d8e6


IllegalMoveError: illegal uci: 'd8e6' in r1bqkb1r/1pp2ppp/p1p5/2n1N3/8/8/PPPP1PPP/RNBQR1K1 b kq - 0 8