### Dependencies

In [2]:
import os

import chess
import chess.pgn

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

### File Paths

In [13]:
training_dataset = "lichess_2017-10_dataset.pth"
validation_dataset = "lichess_2017-08_dataset.pth"
model_path = "chess_model_fast.pth"

### Conversions

In [4]:
def board_to_tensor(board, color):
    """
    Convert a chess board to a tensor
    """
    # Initialize tensor
    tensor = torch.zeros(15, 8, 8)
    
    # Fill tensor with pieces
    for i in range(8):
        for j in range(8):
            piece = board.piece_at(chess.square(i, j))
            if piece is not None:
                piece_type = piece.piece_type
                piece_color = piece.color
                tensor[piece_type + 6 * (piece_color == color), i, j] = 1
    
    # Fill tensor with en passant squares
    if board.ep_square is not None:
        tensor[12, board.ep_square // 8, board.ep_square % 8] = 1
    
    # Fill tensor with castling rights
    white_castle_channel = 13 if color is chess.WHITE else 14
    if board.has_kingside_castling_rights(chess.WHITE):
        tensor[white_castle_channel, 6, 0] = 1
    if board.has_queenside_castling_rights(chess.WHITE):
        tensor[white_castle_channel, 2, 0] = 1
    if board.has_kingside_castling_rights(chess.BLACK):
        tensor[27 - white_castle_channel, 6, 7] = 1
    if board.has_queenside_castling_rights(chess.BLACK):
        tensor[27 - white_castle_channel, 2, 7] = 1
    
    return tensor

def move_to_index(move) -> int:
    """
    Convert a move to an index
    """
    # Get the coordinates of the move
    from_square = move.from_square
    to_square = move.to_square
    
    # Convert the coordinates to the output tensor
    from_x, from_y = from_square // 8, from_square % 8
    to_x, to_y = to_square // 8, to_square % 8

    assert from_x >= 0 and from_x < 8
    assert from_y >= 0 and from_y < 8
    assert to_x >= 0 and to_x < 8
    assert to_y >= 0 and to_y < 8

    # 0bjjjkkklllmmm : j=from_x, k=from_y, l=to_x, m=to_y
    return (from_x << 9) | (from_y << 6) | (to_x << 3) | (to_y)
    # print("from_x", from_x, "from_y", from_y, "to_x", to_x, "to_y", to_y, "index", index)
    

def move_to_output_tensor(move) -> torch.Tensor:
    """
    Convert a move to the coordinate in the output tensor (8x8x8x8)
    """

    # output_tensor = torch.zeros(8*8*8*8)
    # output_tensor[move_to_index(move)] = 1
    return torch.tensor(move_to_index(move))

### Dataset Model

In [5]:
class ChessDataset(Dataset):
    def __init__(self, data):
        self.fens = []
        self.next_moves = []
        for sample in data:
            self.fens.append(sample['fen'])
            self.next_moves.append(sample['next_move'])

    def __len__(self):
        return len(self.fens)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        board = chess.Board(self.fens[idx])
        color = board.turn
        board_tensor = board_to_tensor(board, color)
        next_move = move_to_output_tensor(self.next_moves[idx])

        sample = {'board': board_tensor, 'next_move': next_move}
        return sample

### Dataset Generation

In [8]:

import zstandard as zstd

def decompress_zstd(in_file, out_file):
    with open(in_file, 'rb') as compressed_file:
        decomp = zstd.ZstdDecompressor()
        with open(out_file, 'wb') as decompressed_file:
            decomp.copy_stream(compressed_file, decompressed_file)

def extract_data(game) -> list:
    data = []
    board = game.board()
    game_result = game.headers['Result'][0]
    if game_result == '1':
        winning_color = chess.WHITE
    else:
        winning_color = chess.BLACK
    
    for move in game.mainline_moves():
        if board.turn == winning_color:
            data.append({'fen': board.fen(), 'next_move': move})
        board.push(move)
    return data

def parse_pgn_file(file, max_games = 0) -> list:
    data = []
    tenths_done = -1
    game_count = 0
    extracted_game_count = 0
    with open(file) as pgn_file:
        while True:
            game = chess.pgn.read_game(pgn_file)
            if game is None: # No more games in file
                break 
            game_count += 1
            
            white_elo = int(game.headers['WhiteElo'])
            black_elo = int(game.headers['BlackElo'])
            if white_elo > 2000 or black_elo > 2000:
                data += extract_data(game)
                extracted_game_count += 1
            
            if max_games > 0 and tenths_done < (game_count / max_games) * 10:
                tenths_done += 1
                print("[", end="")
                print("=" * tenths_done, end="")
                print("." * (10 - tenths_done), end="")
                print("]")
                
            if game_count == max_games:
                break

    print(f"Extracted {len(data)} moves from {extracted_game_count} games")
    return data

def generate_dataset(pgn_file: str, dst_file: str):
    print("Parsing pgn: ", os.path.abspath(pgn_file))
    data = parse_pgn_file(pgn_file, 100000)

    dataset = ChessDataset(data)
    print("Generated a dataset with", len(dataset), "samples")
    torch.save(dataset, dst_file)
    print("Saved dataset to", os.path.abspath(dst_file))
      
compressed_file = "lichess_2017-08.pgn.zst" 
pgn_file = "data.pgn"
decompress_zstd(compressed_file, "data.pgn")
dst_file = compressed_file.split(".")[0] + "_dataset.pth"
generate_dataset(pgn_file, dst_file)
# remove the pgn file
os.remove(pgn_file)

Parsing pgn:  /home/azriv/chess-ai/colab/data.pgn
[..........]
[=.........]
[==........]
[===.......]
[====......]
[=====.....]
Extracted 527510 moves from 14124 games
Generated a dataset with 527510 samples
Saved dataset to /home/azriv/chess-ai/colab/lichess_2017-08_dataset.pth


### Model

In [15]:
class ChessModel(nn.Module):
    def __init__(self):
        super(ChessModel, self).__init__()
        self.conv_nn_stack = nn.Sequential(
            # Convolutional layers
            nn.Conv2d(15, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            # Flatten the tensor for the fully connected layers
            nn.Flatten(),
            # Fully connected layers
            nn.Linear(64*8*8, 64*8*8),
            nn.ReLU(),
            # Output layer
            nn.Linear(64*8*8, 8*8*8*8)
        )

    def forward(self, x):
        logits = self.conv_nn_stack(x)
        return logits

### Training

In [17]:
def train():
    ds = torch.load(training_dataset, weights_only=False)
    print("Preparing to train on", len(ds), "samples")

    batch_size = 256
    dl = DataLoader(ds, batch_size=batch_size, shuffle=True)
    print("DataLoader ready with", len(dl), "batches")
    
    vds = torch.load(validation_dataset, weights_only=False)
    vdl = DataLoader(vds, batch_size=batch_size, shuffle=True)

    device = (
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"Performing training using {device}")

    model = ChessModel().to(device)
    if os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path, weights_only=True))
    
    print("Model will be saved to", os.path.abspath(model_path))

    learning_rate = 0.01
    epochs = 100;
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    best_loss = float("inf")
    patience = 3
    patience_counter = 0
    
    for epoch in range(epochs):
        print("----- STARTING EPOCH", epoch + 1, "-----")
        
        print("Training model...")
        model.train()
        batch_count = 0
        training_loss = 0
        for batch in dl:
            model_input = batch["board"].to(device)
            best_moves = batch["next_move"].to(device)

            optimizer.zero_grad() 
            logits = model(model_input)
            loss = loss_fn(logits, best_moves)
            training_loss += loss.item()
            loss.backward()
            optimizer.step()

            if (batch_count+1) % 500 == 0:
                print(f"[Batch #{batch_count + 1}] Running loss: {training_loss}")
            batch_count += 1
        
        training_loss /= len(dl)
        
        print("\nValidating model...")
        model.eval()
        validation_loss = 0
        with torch.no_grad():
            for batch in vdl:
                model_input = batch["board"].to(device)
                best_moves = batch["next_move"].to(device)
                logits = model(model_input)
                loss = loss_fn(logits, best_moves)
                validation_loss += loss.item()
                
        validation_loss /= len(vdl)
        print(f"Epoch {epoch + 1} complete. Training loss: {training_loss} - Validation loss: {validation_loss}")
        print("----- END OF EPOCH", epoch + 1, "-----\n")

        if validation_loss < best_loss:
            print("New best model found. Saving...")
            best_loss = validation_loss
            torch.save(model.state_dict(), model_path)
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping. Patience limit reached.")
                break

train()

Preparing to train on 7854493 samples
DataLoader ready with 30682 batches
Performing training using cuda
Model will be saved to /home/azriv/chess-ai/colab/chess_model_fast.pth
----- STARTING EPOCH 1 -----
Training model...


RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x4096 and 8192x4096)