In [None]:
import chess
import chess.pgn
import torch
from pathlib import Path
from tqdm import tqdm
import numpy as np
import json

# --- Configuration ---
# PGN input is globbed from DATA_DIR; SAVE_DIR is created up front.
DATA_DIR = Path("data/")
SAVE_DIR = Path("prepared_data/batches/")
SAVE_DIR.mkdir(parents=True, exist_ok=True)

# JSON files written by build_fixed_move_index().
MOVE_DICT_PATH = Path("move_index_dict.json")
INDEX_LIST_PATH = Path("index_to_move.json")

# Lichess Elite PGN files for Dec 2021 – May 2022, in sorted path order.
# Source: https://database.nikonoel.fr/
PGN_PATHS = sorted([
    *DATA_DIR.glob("lichess_elite_2021-12.pgn"),
    *DATA_DIR.glob("lichess_elite_2022-0[1-5].pgn"),
])

# Limits handed to parse_and_save_games() as max_games / save_every.
MAX_GAMES = 75_000
SAVE_EVERY = 10_000

In [2]:
def board_to_tensor(board: chess.Board) -> np.ndarray:
    """
    Encode a position as a (21, 8, 8) float32 array of binary feature planes.

    Spatial layout (shared by every per-square plane): row 0 is rank 8,
    row 7 is rank 1, column 0 is the a-file.

    Planes:
        0-5    white pawn, knight, bishop, rook, queen, king
        6-11   black pawn, knight, bishop, rook, queen, king
        12     all ones when White is to move, all zeros otherwise
        13-16  castling rights: white king-side, white queen-side,
               black king-side, black queen-side; a single 1 on the
               h1 / a1 / h8 / a8 square respectively when the right exists
        17     a single 1 on board.ep_square, if it is set
        18     all ones if board.is_checkmate()
        19     all ones on stalemate, insufficient material, the 75-move
               rule or fivefold repetition
        20     all ones when none of the above terminal states apply

    NOTE: the castling markers now use the same row/column mapping as the
    piece planes. Earlier output had planes 13-16 flipped vertically (White's
    markers on rank 8, Black's on rank 1); data or models produced with that
    encoding must be regenerated to match.
    """
    tensor = np.zeros((21, 8, 8), dtype=np.float32)

    def row_col(square: int) -> tuple[int, int]:
        # Single source of truth for the square -> (row, col) mapping.
        return 7 - chess.square_rank(square), chess.square_file(square)

    # Piece planes (12)
    piece_to_index = {
        (chess.PAWN, True): 0, (chess.KNIGHT, True): 1,
        (chess.BISHOP, True): 2, (chess.ROOK, True): 3,
        (chess.QUEEN, True): 4, (chess.KING, True): 5,
        (chess.PAWN, False): 6, (chess.KNIGHT, False): 7,
        (chess.BISHOP, False): 8, (chess.ROOK, False): 9,
        (chess.QUEEN, False): 10, (chess.KING, False): 11,
    }
    for square, piece in board.piece_map().items():
        row, col = row_col(square)
        tensor[piece_to_index[(piece.piece_type, piece.color)], row, col] = 1.0

    # Turn (1): the plane is already zero, so only White-to-move needs a write.
    if board.turn == chess.WHITE:
        tensor[12] = 1.0

    # Castling rights (4)
    castling_planes = [
        (chess.H1, board.has_kingside_castling_rights(chess.WHITE)),
        (chess.A1, board.has_queenside_castling_rights(chess.WHITE)),
        (chess.H8, board.has_kingside_castling_rights(chess.BLACK)),
        (chess.A8, board.has_queenside_castling_rights(chess.BLACK)),
    ]
    for i, (rook_square, has_right) in enumerate(castling_planes):
        if has_right:
            # No square_mirror here: row_col already flips the rank, and
            # mirroring as well would land on the opponent's back rank.
            row, col = row_col(rook_square)
            tensor[13 + i, row, col] = 1.0

    # En passant square (1)
    if board.ep_square is not None:
        row, col = row_col(board.ep_square)
        tensor[17, row, col] = 1.0

    # Game state planes (3): checkmate, draw, active
    if board.is_checkmate():
        tensor[18] = 1.0
    elif (
        board.is_stalemate()
        or board.is_insufficient_material()
        or board.is_seventyfive_moves()
        or board.is_fivefold_repetition()
    ):
        tensor[19] = 1.0
    else:
        tensor[20] = 1.0

    return tensor

In [3]:
# Final structures, filled in place by build_fixed_move_index()
index_to_move: list[str] = []  # class index -> UCI move string
move_index_dict: dict[str, int] = {}  # UCI move string -> class index

def build_fixed_move_index():
    """
    Build the fixed UCI-move vocabulary and write it to disk.

    For each colour, piece type and square, a single piece is placed on an
    otherwise empty board and every legal move is enumerated. Each UCI string
    gets the next free index the first time it is seen. Pawns one step from
    promotion also get enemy rooks on the adjacent-file squares of the last
    rank, so that capture-promotions are generated too.

    Side effects:
        - Repopulates the module-level ``move_index_dict`` (UCI -> index) and
          ``index_to_move`` (index -> UCI) in place.
        - Prints the number of move classes.
        - Writes both tables as JSON to MOVE_DICT_PATH and INDEX_LIST_PATH.
    """
    # Start from empty tables: without this a second call would append every
    # move to index_to_move again while move_index_dict kept its size.
    index_to_move.clear()
    move_index_dict.clear()

    board = chess.Board()
    piece_types = [
        chess.PAWN, chess.KNIGHT, chess.BISHOP,
        chess.ROOK, chess.QUEEN, chess.KING
    ]

    for color in (chess.WHITE, chess.BLACK):
        enemy_color = not color
        # Rank index a pawn promotes from, and the square offset of one step forward.
        promo_from_rank = 6 if color == chess.WHITE else 1
        forward = 8 if color == chess.WHITE else -8

        for piece_type in piece_types:
            for square in chess.SQUARES:
                board.clear()
                board.turn = color
                board.set_piece_at(square, chess.Piece(piece_type, color))
                board.castling_rights = chess.BB_ALL

                # Diagonal promo attack targets
                if piece_type == chess.PAWN and chess.square_rank(square) == promo_from_rank:
                    for offset in (-1, 1):
                        # Skip the diagonal that would fall off the a- or h-file.
                        if 0 <= chess.square_file(square) + offset <= 7:
                            board.set_piece_at(
                                square + forward + offset,
                                chess.Piece(chess.ROOK, enemy_color),
                            )

                for move in board.legal_moves:
                    uci = move.uci()
                    if uci not in move_index_dict:
                        # Next free index == current length of the reverse table.
                        move_index_dict[uci] = len(index_to_move)
                        index_to_move.append(uci)

    print(f"Total move classes: {len(index_to_move)}")

    # Save JSON files
    with open(MOVE_DICT_PATH, "w", encoding="utf-8") as f:
        json.dump(move_index_dict, f)
    with open(INDEX_LIST_PATH, "w", encoding="utf-8") as f:
        json.dump(index_to_move, f)

# Populate the move tables and write both JSON files when this cell runs.
build_fixed_move_index()

Total move classes: 1968


In [4]:
def result_to_value(result_str: str) -> float | None:
    """
    Converts PGN result string to scalar value for value head.
    Returns:
        1.0  for white win,
        -1.0 for black win,
        0.0  for draw,
        None if result is invalid or missing.
    """
    if result_str == "1-0":
        return 1.0
    elif result_str == "0-1":
        return -1.0
    elif result_str == "1/2-1/2":
        return 0.0
    return None

In [5]:
def parse_and_save_games(pgn_paths, max_games, save_every, save_dir):
    """
    Turn PGN games into training samples and write them to disk in batches.

    For every mainline move whose UCI string is in ``move_index_dict``, the
    position *before* the move is encoded with ``board_to_tensor`` and stored
    as ``(tensor, move_idx, result_val)``, where ``result_val`` comes from
    ``result_to_value`` applied to the game's Result header.

    Args:
        pgn_paths: iterable of PGN file paths, read in the given order.
        max_games: stop once this many games with a usable result were parsed.
        save_every: number of samples per batch file.
        save_dir: output directory (str or Path); created if missing.

    Output files:
        ``batch_NNN.pt`` for each full batch and ``batch_NNN_final.pt`` for a
        trailing partial batch, each a list of sample tuples written with
        ``torch.save``.

    Games whose result cannot be mapped are skipped and not counted. A
    position that fails to encode is skipped, but its move is still played so
    the board stays in step with the game.
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    game_count = 0
    sample_count = 0
    batch = []

    with tqdm(total=max_games, desc="Parsing games", unit="game") as pbar:
        for pgn_path in pgn_paths:
            if game_count >= max_games:
                break  # quota reached; no need to open the remaining files

            with open(pgn_path, 'r', encoding='utf-8') as f:
                while game_count < max_games:
                    game = chess.pgn.read_game(f)
                    if game is None:
                        break  # end of this file

                    result_str = game.headers.get("Result")
                    result_val = result_to_value(result_str)
                    if result_val is None:
                        tqdm.write("results_val is None")
                        continue  # skip corrupted or missing results

                    board = game.board()
                    for move in game.mainline_moves():
                        uci = move.uci()
                        move_idx = move_index_dict.get(uci)

                        if move_idx is None:
                            # Not in the fixed move encoding: no sample for it.
                            tqdm.write(f"Skipped move: {uci}")
                        else:
                            try:
                                tensor = board_to_tensor(board)
                                # Explicit check instead of assert so it survives python -O.
                                if tensor.shape != (21, 8, 8):
                                    raise ValueError(f"Bad tensor shape: {tensor.shape}")
                            except Exception as e:
                                tqdm.write(f"Skipped move due to error: {e}")
                            else:
                                batch.append((tensor, move_idx, result_val))
                                sample_count += 1

                                if len(batch) >= save_every:
                                    save_path = save_dir / f"batch_{sample_count // save_every:03d}.pt"
                                    torch.save(batch, save_path)
                                    batch = []

                        # Always advance the board, even when the sample was
                        # skipped; otherwise every later position in this game
                        # would be paired with the wrong move.
                        try:
                            board.push(move)
                        except Exception as e:
                            tqdm.write(f"Skipped move due to error: {e}")
                            break  # board state is unreliable; drop the rest of the game

                    game_count += 1
                    pbar.update(1)

    if batch:
        save_path = save_dir / f"batch_{(sample_count // save_every):03d}_final.pt"
        torch.save(batch, save_path)

    tqdm.write(f"Done. Parsed {game_count} games, {sample_count} samples.")

In [6]:
# Run the full extraction with the configuration constants defined at the top of the notebook.
parse_and_save_games(PGN_PATHS, MAX_GAMES, SAVE_EVERY, SAVE_DIR)

Parsing games: 100%|██████████| 75000/75000 [10:03<00:00, 124.21game/s]

Done. Parsed 75000 games, 6211912 samples.



