In [1]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, IterableDataset
import chess
from time import time
import tqdm

HEADERS = ("bitmaps", "movePlayed")
BATCH_SIZE = 64

In [2]:
import torch.nn as nn
import torch.nn.functional as F
class PositionEvaluatorNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(13, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc1(x))
        return torch.tanh(self.fc2(x))  # Output between -1 and 1
    

class ChessBotNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # 13 channels: 12 pieces + side to play (tensor[true's|false's])
        self.conv1 = nn.Conv2d(13,64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(8 * 8 * 128, 256)
        self.fc2 = nn.Linear(256, 64)
        self.relu = nn.ReLU()

        # Initialize weights
        nn.init.kaiming_uniform_(self.conv1.weight, nonlinearity='relu')
        nn.init.kaiming_uniform_(self.conv2.weight, nonlinearity='relu')
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)


    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)  # Output raw logits
        return x

class CompleteChessBotNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # 13 channels: 12 pieces + side to play (tensor[true's|false's])
        self.conv1 = nn.Conv2d(13,64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(8 * 8 * 128, 256)
        self.fc2 = nn.Linear(256, 64 * 63) # (Choose 2 squares from the board where the order matters) 
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        # Initialize weights
        nn.init.kaiming_uniform_(self.conv1.weight, nonlinearity='relu')
        nn.init.kaiming_uniform_(self.conv2.weight, nonlinearity='relu')
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)


    def forward(self, x):
        x = self.sigmoid(self.conv1(x))
        x = self.sigmoid(self.conv2(x))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)  # Output raw logits
        return x
## Idea:
##   One model that answers "best piece to move in this position"
##   Then another model that answers "best square to move piece X to"

# Add normalization after conv
# Switch from relu to sigmoid or smth
# Add more preprocessing by making the tensors there and avoid pre processing before training

In [3]:
import json

piece_to_idx = {
    'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
    'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11
}

def board_to_tensor(board):
    tensor = np.zeros((12, 8, 8), dtype=np.uint8)
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece:
            idx = piece_to_idx[piece.symbol()]
            row = 7 - square // 8
            col = square % 8
            tensor[idx, row, col] = 1
    return tensor

def idxToUci(idx):
    return chr(ord('a') + idx % 8) + str(idx // 8 + 1)

def convert_to_array(row: str):
    """"
    Converts board into array

    :param str row: Board with side to play;
                        This must represent a np.ndarray[np.int64, shape=(13)] 
    :return: np.ndarray[shape(13, 8, 8), dtype=np.float32]]:
    """
    boards = np.array(json.loads(row), dtype=np.uint64) # shape = (13,)
    array = np.empty((13, 8, 8), dtype=np.uint8)

    # Set pieces boards
    board_int8_view = boards.view(dtype=np.uint8).reshape((13, 8)) # shape = (13, 8)
    board_as_int = np.unpackbits(board_int8_view, axis=1).reshape((13, 8, 8))
    array[:] = board_as_int

    # Set side to play board
    array[12] = np.ones(shape = (1, 8, 8)) if boards[12] == 1 else np.zeros(shape=(1,8,8))
    return array


letters = ["a", "b", "c", "d", "e", "f", "g", "h"]
numbers = list(range(1, 10)) # [1..9]
MOVE_DICTIONARY = {}
cumulative = 0
for i in range(8):
    for j in range(8):
        for k in range(8):
            for w in range (8):
                if (i == k and j == w):
                    cumulative += 1
                    continue
                from_square = f"{letters[i]}{numbers[j]}"
                to_square = f"{letters[k]}{numbers[w]}"
                MOVE_DICTIONARY[f"{from_square}{to_square}"] = (i * 8**3) + (j * 8**2) + (k * 8) + w - cumulative
REVERSE_MOVE_DICTIONARY = {
    value: key for key,value in MOVE_DICTIONARY.items()
}

In [4]:
from typing import Literal
import polars as pl

class ChessEvalDataset(Dataset):
# class ChessEvalDataset(IterableDataset):
    def __init__(self, file: str, model: Literal["pieces", "moves"] = "pieces", load_batch_size = 6_400):
        self.model = model
        self.lazy_dataset = pl.scan_csv(file, has_header=False, new_columns=HEADERS)
        self.batch_size = load_batch_size
        self.feature_col = "bitmaps"
        self.target_col = "movePlayed"
        self.total_rows = self.lazy_dataset.select(pl.len()).collect().item()

        self.cached_batches: dict[int, tuple] = {}
        self.cached_batch_id: int | None = None

    def __len__(self):
        return self.total_rows
    
    def _get_batch(self, batch_id):
        """Load a specific batch of data, wrapped with lru_cache for memory management"""
        # Calculate batch range
        if batch_id == self.cached_batch_id:
            return self.cached_batches[batch_id]
        self.cached_batches.pop(self.cached_batch_id, None) # Delete old batch
        
        # Calculate batch range
        start_idx = batch_id * self.batch_size
        end_idx = min((batch_id + 1) * self.batch_size, self.total_rows)
        
        # Fetch only this batch of data using offset and limit
        batch_dataset = (self.lazy_dataset
                    .slice(start_idx, end_idx - start_idx)
                    .collect())
        
        # Process features and target
        features = batch_dataset.select(self.feature_col)
        features = torch.tensor(np.array([convert_to_array(bitmaps) for bitmaps in features["bitmaps"]]), dtype=torch.float32)
        
        # targets = batch_dataset.select(self.target_col).map_rows(lambda move_string: MOVE_DICTIONARY[move_string[0][:4]]).to_numpy()
        targets = batch_dataset.select(self.target_col).to_numpy()
        targets = torch.tensor(targets, dtype=torch.long)

        self.cached_batch_id = batch_id
        self.cached_batches[self.cached_batch_id] = (features, targets)
        
        return self.cached_batches[batch_id]

    def __getitem__(self, idx):
        # Calculate which batch this index belongs to
        batch_id = idx // self.batch_size
        # Get the batch
        features, targets = self._get_batch(batch_id)
        # Get the item from the batch
        idx_in_batch = idx % self.batch_size
        
        return features[idx_in_batch], targets[idx_in_batch]

In [5]:
VALID_MOVE_LOSS = -0.5
INVALID_MOVE_LOSS = +10

piece_to_idx = {
    'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
    'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11
}

def bitmaps_to_board(bitmaps):
    board = chess.Board(fen = "8/8/8/8/8/8/8/8 w KQkq - 0 1")
    for piece_name, piece_idx in piece_to_idx.items():
        for row in range(8):
            for col in range(8):
                if bitmaps[piece_idx][row][col] == 1:
                    piece = chess.Piece(piece_idx, chess.WHITE) if piece_idx < 6 else chess.Piece(piece_idx - 6, chess.BLACK)
                    board.set_piece_at(row * 8 + col, piece)
    return board

# LOSS FUNCTION
cross_entropy_loss = torch.nn.CrossEntropyLoss()
def loss_fn(batch_bitmaps, outputs: torch.Tensor, targets):
    loss = cross_entropy_loss(outputs, targets)

    played_move = [REVERSE_MOVE_DICTIONARY[output.argmax().item()] for output in outputs]
    game_board = [bitmaps_to_board(bitmaps) for bitmaps in batch_bitmaps]
    penalization = VALID_MOVE_LOSS
    for board, move in zip(game_board, played_move):
        if chess.Move.from_uci(move) not in board.legal_moves:
            penalization = INVALID_MOVE_LOSS
            break
    return torch.add(loss, penalization)

shape: (2, 3)
┌─────────────────────────────────┬────────────┬─────────────────────────────────┐
│ bitmaps                         ┆ movePlayed ┆ validMoves                      │
│ ---                             ┆ ---        ┆ ---                             │
│ str                             ┆ i64        ┆ str                             │
╞═════════════════════════════════╪════════════╪═════════════════════════════════╡
│ [277497088, 0, 0, 48, 17179869… ┆ 1326       ┆ [0, 1, 0, 0, 0, 0, 0, 0, 0, 14… │
│ [277497088, 0, 0, 48, 13421772… ┆ 1286       ┆ [0, 0, 0, 34426978624, 9223372… │
└─────────────────────────────────┴────────────┴─────────────────────────────────┘


In [None]:
DATASET_PATH = '../dataset/processed/results.csv'
NUM_EPOCHS = 100

TRAINING_MODE = "pieces" # "pieces" or "moves"
MODEL_WEIGHTS_OUTPUT_PATH = "./models/CompleteModel_FINISHED.pth"

## !IMPORTANT: This dictates how much ram will be used, and how much data will be loaded
# 1_280_000 loads around 5gb, dont push this too high as it will crash if ram deplects
NUM_EXAMPLES_TO_LOAD_PER_FETCH = 1_280_000 

test = ChessEvalDataset(file = DATASET_PATH, model=TRAINING_MODE, load_batch_size = NUM_EXAMPLES_TO_LOAD_PER_FETCH)
loader = DataLoader(test, batch_size=64, shuffle=False, num_workers=0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CompleteChessBotNetwork().to(device)

# Continue with pretrained weights
# model.load_state_dict(torch.load('./models/CompleteModel_Epoch-80.pth'))

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss = torch.nn.CrossEntropyLoss()

print(torch.cuda.is_available())
print("Using device: ", device)
for epoch in range(1, NUM_EPOCHS+1):
    model.train()
    t0 = time()
    avg_loss = 0.0
    i = 1
    for board_tensor, target_eval in tqdm.tqdm(loader):
        board_tensor_gpu, target_eval_gpu = board_tensor.to(device), target_eval.to(device)  # Move data to GPU
        optimizer.zero_grad()
        pred = model(board_tensor_gpu)

        # Compute loss with valid move vlaidaiton
        # loss = loss_fn(board_tensor, pred, target_eval_gpu.squeeze(1))
        loss = loss_fn(pred, target_eval_gpu.squeeze(1))
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        avg_loss += loss.item()
        i+=1

    tf = time()
    print(f"Epoch {epoch} - {avg_loss / len(loader):.4f} | Time: {tf-t0}")

    if epoch % 5 == 0:
        torch.save(model.state_dict(), f"./models/CompleteModel_Epoch-80_retrain-{epoch}.pth")
# Save the trained model
torch.save(model.state_dict(), MODEL_WEIGHTS_OUTPUT_PATH)

True
Using device:  cuda


  0%|          | 100/44798 [00:54<5:26:09,  2.28it/s]

avg loss in current batch till now:  11.857575197219848


  0%|          | 200/44798 [01:37<5:24:41,  2.29it/s]

avg loss in current batch till now:  11.933535203933715


  1%|          | 300/44798 [02:21<5:26:23,  2.27it/s]

avg loss in current batch till now:  11.911775217056274


  1%|          | 400/44798 [03:05<5:22:47,  2.29it/s]

avg loss in current batch till now:  11.918792834281922


  1%|          | 500/44798 [03:49<5:23:42,  2.28it/s]

avg loss in current batch till now:  11.915034656524659


  1%|▏         | 600/44798 [04:33<5:22:26,  2.28it/s]

avg loss in current batch till now:  11.909203254381815


  2%|▏         | 700/44798 [05:17<5:23:47,  2.27it/s]

avg loss in current batch till now:  11.9184036391122


  2%|▏         | 800/44798 [06:01<5:23:37,  2.27it/s]

avg loss in current batch till now:  11.927353695631027


  2%|▏         | 900/44798 [06:45<5:19:09,  2.29it/s]

avg loss in current batch till now:  11.933392205768161


  2%|▏         | 1000/44798 [07:29<5:23:05,  2.26it/s]

avg loss in current batch till now:  11.932725671768189


  2%|▏         | 1100/44798 [08:13<5:20:40,  2.27it/s]

avg loss in current batch till now:  11.935056311867454


  3%|▎         | 1200/44798 [08:57<5:13:50,  2.32it/s]

avg loss in current batch till now:  11.927253228823345


  3%|▎         | 1300/44798 [09:43<5:39:28,  2.14it/s]

avg loss in current batch till now:  11.926826942884006


  3%|▎         | 1400/44798 [10:29<5:40:32,  2.12it/s]

avg loss in current batch till now:  11.925605938775199


  3%|▎         | 1500/44798 [11:16<5:33:07,  2.17it/s]

avg loss in current batch till now:  11.927548115412394


  4%|▎         | 1600/44798 [12:02<5:34:19,  2.15it/s]

avg loss in current batch till now:  11.926126132011413


  4%|▍         | 1700/44798 [12:49<5:39:50,  2.11it/s]

avg loss in current batch till now:  11.92449364606072


  4%|▍         | 1800/44798 [13:36<5:42:15,  2.09it/s]

avg loss in current batch till now:  11.92283832444085


  4%|▍         | 1900/44798 [14:23<5:31:48,  2.15it/s]

avg loss in current batch till now:  11.921677813279002


  4%|▍         | 2000/44798 [15:09<5:26:54,  2.18it/s]

avg loss in current batch till now:  11.921292686939239


  5%|▍         | 2100/44798 [15:55<5:24:22,  2.19it/s]

avg loss in current batch till now:  11.922308484940302


  5%|▍         | 2200/44798 [16:42<5:28:29,  2.16it/s]

avg loss in current batch till now:  11.922467660470442


  5%|▌         | 2300/44798 [17:28<5:28:50,  2.15it/s]

avg loss in current batch till now:  11.922866395452749


  5%|▌         | 2400/44798 [18:14<5:23:52,  2.18it/s]

avg loss in current batch till now:  11.923578549226125


  6%|▌         | 2500/44798 [19:01<5:27:01,  2.16it/s]

avg loss in current batch till now:  11.92233555908203


  6%|▌         | 2600/44798 [19:47<5:23:58,  2.17it/s]

avg loss in current batch till now:  11.922095777805035


  6%|▌         | 2700/44798 [20:33<5:20:40,  2.19it/s]

avg loss in current batch till now:  11.920521691287005


  6%|▋         | 2800/44798 [21:19<5:21:05,  2.18it/s]

avg loss in current batch till now:  11.922736343656267


  6%|▋         | 2900/44798 [22:05<5:21:30,  2.17it/s]

avg loss in current batch till now:  11.922876532653282


  7%|▋         | 3000/44798 [22:51<5:19:24,  2.18it/s]

avg loss in current batch till now:  11.923817964871725


  7%|▋         | 3100/44798 [23:37<5:17:43,  2.19it/s]

avg loss in current batch till now:  11.926611727437665


  7%|▋         | 3200/44798 [24:24<5:14:56,  2.20it/s]

avg loss in current batch till now:  11.927543960809707


  7%|▋         | 3300/44798 [25:09<5:17:16,  2.18it/s]

avg loss in current batch till now:  11.927446868491895


  8%|▊         | 3400/44798 [25:55<5:10:20,  2.22it/s]

avg loss in current batch till now:  11.92548957600313


  8%|▊         | 3500/44798 [26:40<5:10:03,  2.22it/s]

avg loss in current batch till now:  11.926983944484165


  8%|▊         | 3600/44798 [27:25<5:11:55,  2.20it/s]

avg loss in current batch till now:  11.928095752398173


  8%|▊         | 3700/44798 [28:10<5:11:24,  2.20it/s]

avg loss in current batch till now:  11.926794871510687


  8%|▊         | 3800/44798 [28:55<5:12:47,  2.18it/s]

avg loss in current batch till now:  11.927140939612137


  9%|▊         | 3900/44798 [29:41<5:05:04,  2.23it/s]

avg loss in current batch till now:  11.92708479025425


  9%|▉         | 4000/44798 [30:26<5:08:55,  2.20it/s]

avg loss in current batch till now:  11.927191143274307


  9%|▉         | 4100/44798 [31:12<5:02:56,  2.24it/s]

avg loss in current batch till now:  11.92818832537023


  9%|▉         | 4200/44798 [31:57<5:04:53,  2.22it/s]

avg loss in current batch till now:  11.927765271323068


 10%|▉         | 4300/44798 [32:42<5:05:27,  2.21it/s]

avg loss in current batch till now:  11.928268807876941


 10%|▉         | 4400/44798 [33:27<5:06:43,  2.20it/s]

avg loss in current batch till now:  11.92766689040444


 10%|█         | 4500/44798 [34:12<5:05:58,  2.20it/s]

avg loss in current batch till now:  11.927702038235134


 10%|█         | 4600/44798 [34:58<5:06:11,  2.19it/s]

avg loss in current batch till now:  11.92729791102202


 10%|█         | 4700/44798 [35:43<5:07:47,  2.17it/s]

avg loss in current batch till now:  11.927640388772843


 11%|█         | 4800/44798 [36:29<5:02:06,  2.21it/s]

avg loss in current batch till now:  11.925971078276634


 11%|█         | 4900/44798 [37:14<4:58:44,  2.23it/s]

avg loss in current batch till now:  11.925992020782159


 11%|█         | 5000/44798 [37:59<5:00:48,  2.21it/s]

avg loss in current batch till now:  11.926816752433776


 11%|█▏        | 5100/44798 [38:43<4:48:55,  2.29it/s]

avg loss in current batch till now:  11.92848099633759


 12%|█▏        | 5200/44798 [39:27<4:48:46,  2.29it/s]

avg loss in current batch till now:  11.928871269409473


 12%|█▏        | 5300/44798 [40:10<4:44:49,  2.31it/s]

avg loss in current batch till now:  11.929693371934711


 12%|█▏        | 5400/44798 [40:54<4:47:18,  2.29it/s]

avg loss in current batch till now:  11.929205136828953


 12%|█▏        | 5500/44798 [41:38<4:49:47,  2.26it/s]

avg loss in current batch till now:  11.928653882980347


 13%|█▎        | 5600/44798 [42:21<4:46:09,  2.28it/s]

avg loss in current batch till now:  11.928991406134196


 13%|█▎        | 5700/44798 [43:05<4:43:13,  2.30it/s]

avg loss in current batch till now:  11.928967689213


 13%|█▎        | 5800/44798 [43:49<4:45:27,  2.28it/s]

avg loss in current batch till now:  11.929436610649372


 13%|█▎        | 5900/44798 [44:32<4:42:07,  2.30it/s]

avg loss in current batch till now:  11.929326706579175


 13%|█▎        | 6000/44798 [45:16<4:41:57,  2.29it/s]

avg loss in current batch till now:  11.928561247030894


 14%|█▎        | 6100/44798 [45:59<4:40:48,  2.30it/s]

avg loss in current batch till now:  11.928878697879979


 14%|█▍        | 6200/44798 [46:43<4:40:38,  2.29it/s]

avg loss in current batch till now:  11.929123383645088


 14%|█▍        | 6300/44798 [47:28<4:42:56,  2.27it/s]

avg loss in current batch till now:  11.92848021083408


 14%|█▍        | 6400/44798 [48:13<5:02:58,  2.11it/s]

avg loss in current batch till now:  11.928368185907603


 15%|█▍        | 6500/44798 [48:57<4:34:35,  2.32it/s]

avg loss in current batch till now:  11.928082758683425


 15%|█▍        | 6600/44798 [49:43<4:56:55,  2.14it/s]

avg loss in current batch till now:  11.928074796561038


 15%|█▍        | 6700/44798 [50:29<4:49:44,  2.19it/s]

avg loss in current batch till now:  11.92771754649148


 15%|█▌        | 6800/44798 [51:15<4:42:15,  2.24it/s]

avg loss in current batch till now:  11.92842193477294


 15%|█▌        | 6900/44798 [52:01<4:50:41,  2.17it/s]

avg loss in current batch till now:  11.928609739662944


 16%|█▌        | 7000/44798 [52:48<4:51:56,  2.16it/s]

avg loss in current batch till now:  11.92843684387207


 16%|█▌        | 7100/44798 [53:35<4:29:42,  2.33it/s]

avg loss in current batch till now:  11.92891793425654


 16%|█▌        | 7200/44798 [54:20<4:41:15,  2.23it/s]

avg loss in current batch till now:  11.929201448493533


 16%|█▋        | 7300/44798 [55:05<4:30:44,  2.31it/s]

avg loss in current batch till now:  11.929456341233973


 17%|█▋        | 7400/44798 [55:50<4:45:04,  2.19it/s]

avg loss in current batch till now:  11.92934327512174


 17%|█▋        | 7500/44798 [56:37<4:40:19,  2.22it/s]

avg loss in current batch till now:  11.929926036198934


 17%|█▋        | 7585/44798 [57:15<4:47:22,  2.16it/s]