In [1]:
import os
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import chess
import chess.pgn

from torch.utils.data import Dataset, DataLoader, IterableDataset

from time import time
import tqdm

# Column names given to the header-less dataset CSV when it is scanned:
# the serialized board bitmaps and the label of the move that was played.
HEADERS = ("bitmaps", "movePlayed")
# Mini-batch size — presumably meant for the training DataLoader; confirm against the training cell.
BATCH_SIZE = 64

In [2]:
def extract_fens_from_pgn(pgn_file, label_from="result"):
    """Collect ``(fen, label)`` pairs for every mainline position of every game in a PGN file.

    Each game is replayed move by move; after each move the FEN of the
    resulting position is stored together with a label derived from the game's
    ``Result`` header: ``1.0`` for ``"1-0"``, ``-1.0`` for ``"0-1"`` and
    ``0.0`` for any other value (including a missing header).

    Args:
        pgn_file: Path of the PGN file; it is opened as UTF-8 text.
        label_from: Currently unused — the label always comes from the result header.

    Returns:
        A list of ``(fen, label)`` tuples, game by game and move by move. The
        position before the first move of a game is not included.
    """
    result_labels = {"1-0": 1.0, "0-1": -1.0}
    samples = []
    with open(pgn_file, "r", encoding="utf-8") as pgn_handle:
        # read_game yields None once there is nothing left to parse.
        while (game := chess.pgn.read_game(pgn_handle)) is not None:
            # Every position of a game shares that game's outcome label.
            outcome = result_labels.get(game.headers.get("Result"), 0.0)
            position = game.board()
            for played in game.mainline_moves():
                position.push(played)
                # An engine evaluation (e.g. Stockfish) could replace the label here.
                samples.append((position.fen(), outcome))

    return samples

In [3]:
import torch.nn as nn
import torch.nn.functional as F
class PositionEvaluatorNet(nn.Module):
    """Small CNN that maps a 13-plane board encoding to one score in [-1, 1].

    Two 3x3 convolutions (13 -> 32 -> 64 channels, padding 1) feed two fully
    connected layers (64 * 8 * 8 -> 128 -> 1); the output goes through tanh.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=13, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(in_features=64 * 8 * 8, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=1)

    def forward(self, x):
        """Score a batch of boards.

        Args:
            x: Float tensor; the layer sizes are chosen for (N, 13, 8, 8) input.

        Returns:
            Tensor of shape (N, 1) with values between -1 and 1.
        """
        feature_maps = F.relu(self.conv2(F.relu(self.conv1(x))))
        # Collapse channels and the 8x8 grid into one vector per sample.
        flat = feature_maps.view(-1, self.fc1.in_features)
        hidden = F.relu(self.fc1(flat))
        # tanh bounds the score to [-1, 1].
        return torch.tanh(self.fc2(hidden))
    

class ChessBotNetwork(nn.Module):
    """CNN that turns a 13-plane board encoding into 64 raw logits.

    Layout: two 3x3 convolutions (13 -> 64 -> 128 channels, padding 1), a
    flatten, and two linear layers (8 * 8 * 128 -> 256 -> 64). ReLU follows
    every layer except the last one.
    """

    def __init__(self):
        super().__init__()
        # 13 channels: 12 pieces + side to play (tensor[true's|false's])
        self.conv1 = nn.Conv2d(in_channels=13, out_channels=64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features=8 * 8 * 128, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=64)
        self.relu = nn.ReLU()

        # Weight init: Kaiming-uniform for the conv layers, Xavier-uniform for the linear ones.
        for conv in (self.conv1, self.conv2):
            nn.init.kaiming_uniform_(conv.weight, nonlinearity='relu')
        for dense in (self.fc1, self.fc2):
            nn.init.xavier_uniform_(dense.weight)

    def forward(self, x):
        """Return a (N, 64) tensor of unnormalised logits for a batch of boards."""
        for conv in (self.conv1, self.conv2):
            x = self.relu(conv(x))
        hidden = self.relu(self.fc1(self.flatten(x)))
        # No activation on the output: these are raw logits.
        return self.fc2(hidden)


## OUTPUT COMPLETE MOVES
class CompleteChessBotNetwork(nn.Module):
    """CNN that turns a 13-plane board encoding into one logit per complete move.

    The output has 64 * 63 entries: one for every ordered pair of two distinct
    squares. Layout: two 3x3 convolutions (13 -> 64 -> 128 channels, padding 1)
    followed by sigmoid, a flatten, a 8 * 8 * 128 -> 256 linear layer followed
    by ReLU, and a final 256 -> 64 * 63 linear layer without activation.
    """

    def __init__(self):
        super().__init__()
        # 13 channels: 12 pieces + side to play (tensor[true's|false's])
        self.conv1 = nn.Conv2d(in_channels=13, out_channels=64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features=8 * 8 * 128, out_features=256)
        # Choose 2 squares from the board where the order matters: 64 * 63 outputs.
        self.fc2 = nn.Linear(in_features=256, out_features=64 * 63)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        # Weight init: Kaiming-uniform for the conv layers, Xavier-uniform for the linear ones.
        # NOTE(review): the conv layers get the ReLU gain although forward() applies
        # sigmoid after them — confirm this is intended.
        for conv in (self.conv1, self.conv2):
            nn.init.kaiming_uniform_(conv.weight, nonlinearity='relu')
        for dense in (self.fc1, self.fc2):
            nn.init.xavier_uniform_(dense.weight)

    def forward(self, x):
        """Return a (N, 64 * 63) tensor of unnormalised move logits for a batch of boards."""
        for conv in (self.conv1, self.conv2):
            x = self.sigmoid(conv(x))
        hidden = self.relu(self.fc1(self.flatten(x)))
        # No activation on the output: these are raw logits.
        return self.fc2(hidden)
## Idea:
##   One model that answers "best piece to move in this position"
##   Then another model that answers "best square to move piece X to"

# TODO: add normalization after the conv layers
# TODO: try switching the activation from ReLU to sigmoid (or something similar)
# TODO: move more work into preprocessing by building the tensors there, so nothing has to be preprocessed right before training

In [4]:
from itertools import product


# Plane index of each piece symbol: uppercase symbols P, N, B, R, Q, K take
# planes 0-5 and their lowercase counterparts take planes 6-11.
piece_to_idx = {symbol: plane for plane, symbol in enumerate("PNBRQKpnbrqk")}

def board_to_tensor(board):
    """Encode the pieces on ``board`` as a (12, 8, 8) uint8 array of 0/1 planes.

    Plane ``piece_to_idx[symbol]`` has a 1 wherever a piece with that symbol
    stands. Row 0 corresponds to square indices 56-63 and row 7 to square
    indices 0-7; the column is the square index modulo 8.

    Args:
        board: Object exposing ``piece_at(square)``, e.g. a ``chess.Board``.

    Returns:
        ``numpy.ndarray`` of shape (12, 8, 8) and dtype ``uint8``.
    """
    planes = np.zeros((12, 8, 8), dtype=np.uint8)
    for square in chess.SQUARES:
        occupant = board.piece_at(square)
        if not occupant:
            continue  # empty square
        rank, file = divmod(square, 8)
        # Rows are flipped vertically: the highest rank ends up in row 0.
        planes[piece_to_idx[occupant.symbol()], 7 - rank, file] = 1
    return planes

def idxToUci(idx):
    """Return the square name for a square index, e.g. 0 -> "a1", 12 -> "e2", 63 -> "h8".

    The file letter comes from ``idx % 8`` and the rank number from ``idx // 8 + 1``.
    """
    rank, file = divmod(idx, 8)
    return f"{'abcdefgh'[file]}{rank + 1}"


# Move vocabulary: "<from-square><to-square>" strings (e.g. "e2e4") <-> class index.
#
# Squares are enumerated file by file (a1, a2, ..., a8, b1, ..., h8) and moves
# are ordered by (from-square, to-square). The 64 null moves (from == to) are
# left out, so the remaining 64 * 63 = 4032 moves get the contiguous indices
# 0..4031, matching the output size of CompleteChessBotNetwork.
# Moves with a promotion suffix (e.g. "e7e8q") are not part of the vocabulary.
letters = ["a", "b", "c", "d", "e", "f", "g", "h"]
numbers = list(range(1, 9)) # [1..8]
_SQUARE_NAMES = [f"{file}{rank}" for file, rank in product(letters, numbers)]
MOVE_DICTIONARY = {
    move: index
    for index, move in enumerate(
        from_square + to_square
        for from_square, to_square in product(_SQUARE_NAMES, repeat=2)
        if from_square != to_square
    )
}
# Inverse lookup: class index -> move string.
REVERSE_MOVE_DICTIONARY = {index: move for move, index in MOVE_DICTIONARY.items()}

In [None]:
from typing import Literal
import polars as pl

class ChessEvalDataset(Dataset):
    """Map-style dataset over a header-less CSV of serialized boards and move labels.

    The CSV is scanned lazily and read in chunks of ``load_batch_size`` rows.
    Only the most recently loaded chunk is kept in memory, so sequential access
    reads each chunk once, whereas random access may reload a chunk on every
    item.

    Args:
        file: Path of the CSV; its two columns are named after ``HEADERS``.
        model: "pieces" or "moves". Stored on the instance; nothing in this
            class reads it.
        load_batch_size: Number of CSV rows per in-memory chunk. This is
            independent of the DataLoader's mini-batch size.
    """

    def __init__(self, file: str, model: Literal["pieces", "moves"] = "pieces", load_batch_size = 6_400):
        self.model = model
        self.lazy_dataset = pl.scan_csv(file, has_header=False, new_columns=HEADERS)
        self.batch_size = load_batch_size
        self.feature_col = "bitmaps"
        self.target_col = "movePlayed"
        # Row count is resolved once up front; it needs a pass over the file.
        self.total_rows = self.lazy_dataset.select(pl.len()).collect().item()

        # Single-entry cache: holds at most the chunk identified by cached_batch_id.
        self.cached_batches: dict[int, tuple] = {}
        self.cached_batch_id: int | None = None

    def __len__(self):
        """Return the number of rows in the CSV."""
        return self.total_rows

    @staticmethod
    def _decode_board(serialized: str) -> np.ndarray:
        """Reinterpret one serialized bitmap string as a (13, 8, 8) int64 array.

        The string is encoded back to bytes and those bytes are read as raw
        int64 values.
        """
        return np.frombuffer(serialized.encode(), dtype=np.int64).reshape(13, 8, 8)

    def _get_batch(self, batch_id):
        """Return ``(features, targets)`` for one chunk, loading it on a cache miss.

        Args:
            batch_id: Zero-based chunk number; chunk ``b`` covers rows
                ``b * batch_size`` up to ``(b + 1) * batch_size`` (clipped to
                the number of rows).

        Returns:
            Tuple of a float tensor of shape (rows, 13, 8, 8) and a long tensor
            of shape (rows, 1).
        """
        if batch_id == self.cached_batch_id:
            return self.cached_batches[batch_id]
        self.cached_batches.pop(self.cached_batch_id, None) # Delete old batch

        start_idx = batch_id * self.batch_size
        end_idx = min((batch_id + 1) * self.batch_size, self.total_rows)

        # Fetch only this chunk of data using offset and length.
        batch_dataset = (self.lazy_dataset
                    .slice(start_idx, end_idx - start_idx)
                    .collect())

        # Decode every row's bitmap string, then stack the per-row arrays.
        features = batch_dataset.select(self.feature_col).map_rows(
            lambda row: (self._decode_board(row[0]), )
        ).to_numpy()
        features = np.stack(features[:, 0])
        features = torch.tensor(features, dtype=torch.float)

        targets = batch_dataset.select(self.target_col).to_numpy()
        targets = torch.tensor(targets, dtype=torch.long)

        self.cached_batch_id = batch_id
        self.cached_batches[self.cached_batch_id] = (features, targets)

        return self.cached_batches[batch_id]

    def __getitem__(self, idx):
        """Return ``(features, target)`` for row ``idx``.

        Negative indices count from the end.

        Raises:
            IndexError: If ``idx`` is outside the dataset. This also lets plain
                ``for`` iteration over the dataset terminate.
        """
        if idx < 0:
            idx += self.total_rows
        if not 0 <= idx < self.total_rows:
            raise IndexError(f"index out of range for dataset with {self.total_rows} rows")

        # Locate the chunk holding this row, then the row inside that chunk.
        batch_id, idx_in_batch = divmod(idx, self.batch_size)
        features, targets = self._get_batch(batch_id)

        return features[idx_in_batch], targets[idx_in_batch]

In [None]:
# --- Training configuration -------------------------------------------------
DATASET_PATH = '../dataset/processed/npbytes_results_serialized.csv'
NUM_EPOCHS = 100

# The network trained below is CompleteChessBotNetwork (one logit per complete
# move), so the mode is "moves".
TRAINING_MODE = "moves" # "pieces" or "moves"
MODEL_WEIGHTS_OUTPUT_PATH = "CompleteModel_SKIP_INITIAL.pth"

# --- Data ---------------------------------------------------------------------
# load_batch_size is the number of CSV rows the dataset keeps in memory at once;
# it is unrelated to the mini-batch size used by the DataLoader.
test = ChessEvalDataset(file = DATASET_PATH, model=TRAINING_MODE, load_batch_size = 192_000)
# NOTE(review): shuffle is off — presumably because the dataset caches a single
# chunk and random access would reload chunks constantly; confirm.
loader = DataLoader(test, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# --- Model, optimizer, loss -----------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CompleteChessBotNetwork().to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()
print(torch.cuda.is_available())
print("Using device: ", device)

# --- Training loop ------------------------------------------------------------
for epoch in range(NUM_EPOCHS):
    model.train()
    t0 = time()
    running_loss = 0.0
    for board_tensor_batch, target_eval_batch in tqdm.tqdm(loader):
        board_tensor_batch_gpu, target_eval_batch_gpu = board_tensor_batch.to(device), target_eval_batch.to(device)  # Move data to GPU
        optimizer.zero_grad()
        pred = model(board_tensor_batch_gpu)

        # Targets arrive as (batch, 1); CrossEntropyLoss wants class indices of shape (batch,).
        loss = loss_fn(pred, target_eval_batch_gpu.squeeze(1))
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        running_loss += loss.item()

    tf = time()
    # Reports the last mini-batch loss, then the mean loss over the epoch.
    print(f"Epoch {epoch}, Loss: {loss.item():.4f} - {running_loss / len(loader):.4f} | Time: {tf-t0}")

    # Checkpoint after epochs 1, 6, 11, ...; the file name carries the number of completed epochs.
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f"test_{epoch+1}.pth")
# Save the trained model
torch.save(model.state_dict(), MODEL_WEIGHTS_OUTPUT_PATH)

True
Using device:  cuda


 54%|█████▎    | 23991/44798 [02:20<00:54, 380.32it/s] 

In [None]:
import polars as pl
# Sanity check: read the first two rows of the serialized dataset and decode the
# board column into arrays to confirm the on-disk format.
df = pl.read_csv('../dataset/processed/npbytes_results_serialized.csv', n_rows=2, has_header=False)
print(df)
# With has_header=False the columns get the default names column_1, column_2.
features = df.select("column_1").map_rows(
    lambda row: (np.frombuffer(
        row[0].encode(),  # the bitmap is read as text; encode it back to bytes so frombuffer can reinterpret them as int64 values
        dtype=np.int64
    ).reshape(13, 8, 8), )
).to_numpy()
# Stack the per-row arrays into a single array with one entry per row.
features = np.stack(features[:, 0])
print(features)
# print(np.shape(features.reshape((13,8,8))))
# print(features.reshape((13,8,8)))

shape: (2, 2)
┌─────────────────────────────────┬──────────┐
│ column_1                        ┆ column_2 │
│ ---                             ┆ ---      │
│ str                             ┆ i64      │
╞═════════════════════════════════╪══════════╡
│                               … ┆ 2303     │
│                               … ┆ 2682     │
└─────────────────────────────────┴──────────┘
[[[[0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   ...
   [0 0 0 ... 0 0 0]
   [1 1 0 ... 1 1 1]
   [0 0 0 ... 0 0 0]]

  [[0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   ...
   [0 0 0 ... 1 0 0]
   [0 0 0 ... 0 0 0]
   [0 1 0 ... 0 0 0]]

  [[0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   ...
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 1 ... 1 0 0]]

  ...

  [[0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   ...
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]]

  [[0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
