In [1]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, IterableDataset
import chess
from time import time
import tqdm

HEADERS = ("bitmaps", "movePlayed", "validMoves")
BATCH_SIZE = 64

In [None]:
import torch.nn as nn
import torch.nn.functional as F
class PositionEvaluatorNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(13, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc1(x))
        return torch.tanh(self.fc2(x))  # Output between -1 and 1
    

class ChessBotNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # 13 channels: 12 pieces + side to play (tensor[true's|false's])
        self.conv1 = nn.Conv2d(13,64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(8 * 8 * 128, 256)
        self.fc2 = nn.Linear(256, 64)
        self.relu = nn.ReLU()

        # Initialize weights
        nn.init.kaiming_uniform_(self.conv1.weight, nonlinearity='relu')
        nn.init.kaiming_uniform_(self.conv2.weight, nonlinearity='relu')
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)


    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)  # Output raw logits
        return x

class CompleteChessBotNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # 13 channels: 12 pieces + side to play (tensor[true's|false's])
        self.conv1 = nn.Conv2d(13,64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(8 * 8 * 128, 256)
        self.fc2 = nn.Linear(256, 64 * 63) # (Choose 2 squares from the board where the order matters) 
        
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.flatten = nn.Flatten()

        # Initialize weights
        nn.init.kaiming_uniform_(self.conv1.weight, nonlinearity='relu')
        nn.init.kaiming_uniform_(self.conv2.weight, nonlinearity='relu')
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)


    def forward(self, x):
        x = self.sigmoid(self.conv1(x))
        x = self.sigmoid(self.conv2(x))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)  # Output raw logits
        return x
## Idea:
##   One model that answers "best piece to move in this position"
##   Then another model that answers "best square to move piece X to"

# Add normalization after conv
# Switch from relu to sigmoid or smth
# Add more preprocessing by making the tensors there and avoid pre processing before training

In [None]:
import json

piece_to_idx = {
    'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
    'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11
}

def board_to_tensor(board):
    tensor = np.zeros((12, 8, 8), dtype=np.uint8)
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece:
            idx = piece_to_idx[piece.symbol()]
            row = 7 - square // 8
            col = square % 8
            tensor[idx, row, col] = 1
    return tensor

def convert_to_array(row: str):
    """"
    Converts board into array

    :param str row: Board with side to play;
                        This must represent a np.ndarray[np.int64, shape=(13)] 
    :return: np.ndarray[shape(13, 8, 8), dtype=np.float32]]:
    """
    boards = np.array(json.loads(row), dtype=np.uint64) # shape = (13,)
    array = np.empty((13, 8, 8), dtype=np.uint8)

    # Set pieces boards
    board_int8_view = boards.view(dtype=np.uint8).reshape((13, 8)) # shape = (13, 8)
    board_as_int = np.unpackbits(board_int8_view, axis=1).reshape((13, 8, 8))
    array[:] = board_as_int

    # Set side to play board
    array[12] = np.ones(shape = (1, 8, 8)) if boards[12] == 1 else np.zeros(shape=(1,8,8))
    return array


letters = ["a", "b", "c", "d", "e", "f", "g", "h"]
numbers = list(range(1, 10)) # [1..9]
MOVE_DICTIONARY = {}
cumulative = 0
for i in range(8):
    for j in range(8):
        for k in range(8):
            for w in range (8):
                if (i == k and j == w):
                    cumulative += 1
                    continue
                from_square = f"{letters[i]}{numbers[j]}"
                to_square = f"{letters[k]}{numbers[w]}"
                MOVE_DICTIONARY[f"{from_square}{to_square}"] = (i * 8**3) + (j * 8**2) + (k * 8) + w - cumulative
REVERSE_MOVE_DICTIONARY = {
    value: key for key,value in MOVE_DICTIONARY.items()
}

In [4]:
from typing import Literal
import polars as pl

class ChessEvalDataset(Dataset):
# class ChessEvalDataset(IterableDataset):
    def __init__(self, file: str, model: Literal["pieces", "moves"] = "pieces", load_batch_size = 6_400):
        self.model = model
        self.lazy_dataset = pl.scan_csv(file, has_header=False, new_columns=HEADERS)
        self.batch_size = load_batch_size
        self.feature_col = "bitmaps"
        self.target_col = "movePlayed"
        self.total_rows = self.lazy_dataset.select(pl.len()).collect().item()

        self.cached_batches: dict[int, tuple] = {}
        self.cached_batch_id: int | None = None

    def __len__(self):
        return self.total_rows
    
    def _get_batch(self, batch_id):
        """Load a specific batch of data, wrapped with lru_cache for memory management"""
        # Calculate batch range
        if batch_id == self.cached_batch_id:
            return self.cached_batches[batch_id]
        self.cached_batches.pop(self.cached_batch_id, None) # Delete old batch
        
        # Calculate batch range
        start_idx = batch_id * self.batch_size
        end_idx = min((batch_id + 1) * self.batch_size, self.total_rows)
        
        # Fetch only this batch of data using offset and limit
        batch_dataset = (self.lazy_dataset
                    .slice(start_idx, end_idx - start_idx)
                    .collect())
        
        # Process features and target
        features = batch_dataset.select(self.feature_col)
        features = torch.tensor(np.array([convert_to_array(bitmaps) for bitmaps in features["bitmaps"]]), dtype=torch.float32)
        
        played_moves = batch_dataset.select(self.target_col).to_numpy()
        # valid_moves = batch_dataset.select("validMoves").to_numpy()
        valid_moves = np.zeros(shape=np.shape(played_moves))

        targets = np.array([played_moves, valid_moves], dtype=np.float32)
        targets = torch.tensor(targets, dtype=torch.long)

        self.cached_batch_id = batch_id
        self.cached_batches[self.cached_batch_id] = (features, targets)
        
        return self.cached_batches[batch_id]

    def __getitem__(self, idx):
        # Calculate which batch this index belongs to
        batch_id = idx // self.batch_size
        # Get the batch
        features, targets = self._get_batch(batch_id)
        # Get the item from the batch
        idx_in_batch = idx % self.batch_size

        return features[idx_in_batch], targets[:, idx_in_batch]

In [None]:
VALID_MOVE_LOSS = -0.5
INVALID_MOVE_LOSS = +10

piece_to_idx = {
    'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
    'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11
}

def bitmaps_to_board(bitmaps):
    board = chess.Board(fen = "8/8/8/8/8/8/8/8 w KQkq - 0 1")
    for piece_name, piece_idx in piece_to_idx.items():
        for row in range(8):
            for col in range(8):
                if bitmaps[piece_idx][row][col] == 1:
                    piece = chess.Piece(piece_idx, chess.WHITE) if piece_idx < 6 else chess.Piece(piece_idx - 6, chess.BLACK)
                    board.set_piece_at(row * 8 + col, piece)
    return board

# LOSS FUNCTION
cross_entropy_loss = torch.nn.CrossEntropyLoss()
def loss_fn(outputs: torch.Tensor, targets, valid_moves): # targets tensor with twp elements: [targets, possible_moves]
    loss = cross_entropy_loss(outputs, targets)

    played_move = [output.argmax().item() for output in outputs]
    penalization = VALID_MOVE_LOSS
    for move, valid_moves in zip(played_move, valid_moves):
        num_idx = move // 64
        bit_idx = move % 64
        num = valid_moves[num_idx]
        if (num >> bit_idx) & 1 == 0:
            penalization += INVALID_MOVE_LOSS
        else:
            penalization += VALID_MOVE_LOSS
    return torch.add(loss, penalization)

In [None]:
DATASET_PATH = '../dataset/processed/results_with_valid_moves_no_skip.csv'
NUM_EPOCHS = 100

TRAINING_MODE = "pieces" # "pieces" or "moves"
MODEL_WEIGHTS_OUTPUT_PATH = "./models/CompleteModel_noskip_FINISHED.pth"

## !IMPORTANT: This dictates how much ram will be used, and how much data will be loaded
# 1_280_000 loads around 5gb, dont push this too high as it will crash if ram deplects
NUM_EXAMPLES_TO_LOAD_PER_FETCH = 1_280_000 

test = ChessEvalDataset(file = DATASET_PATH, model=TRAINING_MODE, load_batch_size = NUM_EXAMPLES_TO_LOAD_PER_FETCH)
loader = DataLoader(test, batch_size=64, shuffle=False, num_workers=0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CompleteChessBotNetwork().to(device)

# Continue with pretrained weights
# model.load_state_dict(torch.load('./models/CompleteModel_Epoch-80.pth'))

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

print(torch.cuda.is_available())
print("Using device: ", device)
for epoch in range(1, NUM_EPOCHS+1):
    model.train()
    t0 = time()
    avg_loss = 0.0
    i = 1
    for board_tensor, target_eval in tqdm.tqdm(loader):
        valid_moves = target_eval[:, 1, :]
        target_eval = target_eval[:, 0, :] 
        
        board_tensor_gpu, target_eval_gpu = board_tensor.to(device), target_eval.to(device)  # Move data to GPU
        optimizer.zero_grad()
        pred = model(board_tensor_gpu)

        # Compute loss with valid move vlaidaiton
        # loss = loss_fn(board_tensor, pred, target_eval_gpu.squeeze(1))
        loss = loss_fn(pred, target_eval_gpu.squeeze(1), valid_moves)
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        avg_loss += loss.item()
        i+=1

    tf = time()
    print(f"Epoch {epoch} - {avg_loss / len(loader):.4f} | Time: {tf-t0}")

    if epoch % 5 == 0:
        torch.save(model.state_dict(), f"./models/CompleteModel_noskip_epoch-{epoch}.pth")
# Save the trained model
torch.save(model.state_dict(), MODEL_WEIGHTS_OUTPUT_PATH)

True
Using device:  cuda


100%|██████████| 51773/51773 [02:51<00:00, 302.72it/s]


Epoch 1 - 3.5338 | Time: 171.02689218521118


100%|██████████| 51773/51773 [02:50<00:00, 302.84it/s]


Epoch 2 - 2.5336 | Time: 170.96192598342896


100%|██████████| 51773/51773 [02:49<00:00, 305.70it/s]


Epoch 3 - 2.2681 | Time: 169.36138129234314


100%|██████████| 51773/51773 [02:49<00:00, 305.55it/s]


Epoch 4 - 2.1291 | Time: 169.44144797325134


100%|██████████| 51773/51773 [02:49<00:00, 305.53it/s]


Epoch 5 - 2.0450 | Time: 169.45404505729675


100%|██████████| 51773/51773 [02:48<00:00, 306.97it/s]


Epoch 6 - 1.9900 | Time: 168.66061568260193


100%|██████████| 51773/51773 [02:48<00:00, 306.48it/s]


Epoch 7 - 1.9508 | Time: 168.92704105377197


100%|██████████| 51773/51773 [02:49<00:00, 305.15it/s]


Epoch 8 - 1.9209 | Time: 169.66611075401306


100%|██████████| 51773/51773 [02:48<00:00, 306.87it/s]


Epoch 9 - 1.8971 | Time: 168.71227622032166


100%|██████████| 51773/51773 [02:49<00:00, 304.84it/s]


Epoch 10 - 1.8778 | Time: 169.83908343315125


100%|██████████| 51773/51773 [02:52<00:00, 300.21it/s]


Epoch 11 - 1.8618 | Time: 172.45528602600098


100%|██████████| 51773/51773 [02:50<00:00, 304.23it/s]


Epoch 12 - 1.8483 | Time: 170.18055033683777


100%|██████████| 51773/51773 [02:51<00:00, 302.49it/s]


Epoch 13 - 1.8364 | Time: 171.15522742271423


100%|██████████| 51773/51773 [02:49<00:00, 305.03it/s]


Epoch 14 - 1.8263 | Time: 169.73297119140625


100%|██████████| 51773/51773 [02:49<00:00, 306.31it/s]


Epoch 15 - 1.8172 | Time: 169.0251874923706


100%|██████████| 51773/51773 [02:53<00:00, 298.29it/s]


Epoch 16 - 1.8092 | Time: 173.5646812915802


100%|██████████| 51773/51773 [02:45<00:00, 312.00it/s]


Epoch 17 - 1.8018 | Time: 165.93786597251892


100%|██████████| 51773/51773 [02:49<00:00, 304.64it/s]


Epoch 18 - 1.7951 | Time: 169.9509949684143


100%|██████████| 51773/51773 [02:50<00:00, 304.21it/s]


Epoch 19 - 1.7891 | Time: 170.1909146308899


100%|██████████| 51773/51773 [02:49<00:00, 304.95it/s]


Epoch 20 - 1.7836 | Time: 169.7792990207672


100%|██████████| 51773/51773 [02:54<00:00, 297.36it/s]


Epoch 21 - 1.7786 | Time: 174.10879158973694


100%|██████████| 51773/51773 [02:48<00:00, 306.99it/s]


Epoch 22 - 1.7739 | Time: 168.64929223060608


100%|██████████| 51773/51773 [04:43<00:00, 182.43it/s] 


Epoch 23 - 1.7698 | Time: 283.80107975006104


100%|██████████| 51773/51773 [04:56<00:00, 174.41it/s] 


Epoch 24 - 1.7657 | Time: 296.8620948791504


100%|██████████| 51773/51773 [04:51<00:00, 177.54it/s] 


Epoch 25 - 1.7620 | Time: 291.61624121665955


100%|██████████| 51773/51773 [03:21<00:00, 256.58it/s]


Epoch 26 - 1.7585 | Time: 201.78068947792053


 90%|█████████ | 46618/51773 [02:31<00:16, 307.36it/s]


KeyboardInterrupt: 