In [1]:
import pandas as pd
import numpy as np
import os
import re
import glob
import tensorflow as tf
import dask.dataframe as dd
import torch
import chess

2025-05-30 13:16:20.326044: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1748632580.416250     452 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1748632580.443109     452 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-30 13:16:20.655129: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# glob.glob returns files in filesystem order, which is arbitrary; sort so that
# parquets[0] (and any later iteration over the list) is the same file on every run.
parquets = sorted(glob.glob('../data/processed/*.parquet'))

In [3]:
# Indexing an empty list would raise a bare "list index out of range";
# fail with an actionable message when the glob matched no files.
if not parquets:
    raise FileNotFoundError("No .parquet files found: check the glob pattern used to build `parquets`.")
df = pd.read_parquet(parquets[0])

In [4]:
# Optional (disabled): concatenate every processed parquet shard into one DataFrame.
# NOTE(review): presumably left commented out because loading all shards at once is
# slow / memory-heavy — confirm, and delete this cell if it is no longer needed.
# one_million_games = pd.concat([pd.read_parquet(parquet) for parquet in parquets])

In [14]:
# Channel order used by fen_to_tensor: white pieces PNBRQK -> channels 0-5,
# black pieces pnbrqk -> channels 6-11. Built once instead of on every call.
PIECE_TO_CHANNEL = {symbol: idx for idx, symbol in enumerate('PNBRQKpnbrqk')}


def fen_to_tensor(fen:str) -> torch.Tensor:
    """
    Convert a FEN position into a one-hot piece tensor of shape (12, 8, 8).

    Each of the 12 channels is an 8x8 occupancy plane for one piece type:
    channels 0-5 hold the white pieces ``PNBRQK`` and channels 6-11 the black
    pieces ``pnbrqk``. A cell is 1.0 where that piece stands and 0.0 elsewhere.
    Row 0 is rank 8 and row 7 is rank 1 (the board as seen from White's side);
    column 0 is file a.

    Parameters
    ----------
    fen : str
          Position in FEN notation.
    Returns
    -------
    board_tensor : torch.Tensor
                   float32 tensor of shape (12, 8, 8).
    Raises
    ------
    ValueError
        Raised by ``chess.Board`` when ``fen`` is not a valid FEN string.
    """
    # TODO: add extra channels for castling rights (4 channels), the en-passant
    # square and the halfmove clock.
    board = chess.Board(fen)
    board_tensor = torch.zeros((12, 8, 8), dtype=torch.float32)

    # piece_map() yields only occupied squares, so empty squares are never visited.
    for square, piece in board.piece_map().items():
        channel = PIECE_TO_CHANNEL[piece.symbol()]
        row = 7 - chess.square_rank(square)  # rank 8 -> row 0, rank 1 -> row 7
        col = chess.square_file(square)      # file a -> col 0
        board_tensor[channel, row, col] = 1.0
    return board_tensor
        

In [26]:
# The original call used generate_uci_move_vocabulary, which no visible cell of this
# notebook defines, so it raised NameError on a fresh kernel. The vocabulary builder
# the notebook does define is generate_full_uci_move_vocabulary.
# NOTE(review): that function is defined in a later cell; move its cell above this one
# so that Restart & Run All succeeds.
uci, index = generate_full_uci_move_vocabulary()

In [72]:
import chess

def get_legal_moves_vocab(fen:str) -> tuple[dict[str,int],dict[int,str]]:
    """
    Build a per-position vocabulary of the legal moves in ``fen``.

    The legal moves are sorted by their UCI string and numbered from 0, so the
    indices are only meaningful for this one position.

    Parameters
    ----------
    fen : str
        FEN notation of the position.

    Returns
    -------
    uci_to_idx : dict[str, int]
        Maps each legal move's UCI string to its index.
    idx_to_uci : dict[int, str]
        The inverse mapping, index -> UCI string.
    """
    position = chess.Board(fen)
    ordered_ucis = sorted(move.uci() for move in position.legal_moves)

    uci_to_idx = {uci: idx for idx, uci in enumerate(ordered_ucis)}
    idx_to_uci = dict(enumerate(ordered_ucis))
    return uci_to_idx, idx_to_uci


In [74]:
def get_legal_mask(board: chess.Board, uci_to_index: dict) -> torch.Tensor:
    """
    Mark which entries of a move vocabulary are legal in ``board``.

    Parameters
    ----------
    board : chess.Board
        Position whose legal moves are marked.
    uci_to_index : dict
        Move vocabulary mapping UCI move strings to integer indices. The mask
        has ``len(uci_to_index)`` entries, so the indices are assumed to lie in
        ``0 .. len(uci_to_index) - 1`` — TODO confirm for custom vocabularies.

    Returns
    -------
    torch.Tensor
        float32 tensor of shape ``(len(uci_to_index),)``: 1.0 at the index of
        each legal move that appears in the vocabulary, 0.0 everywhere else.
        Legal moves missing from the vocabulary are skipped without error.
    """
    legal_indices = [
        uci_to_index[uci]
        for uci in (move.uci() for move in board.legal_moves)
        if uci in uci_to_index
    ]
    mask = torch.zeros(len(uci_to_index), dtype=torch.float32)
    for vocab_idx in legal_indices:
        mask[vocab_idx] = 1.0
    return mask


({'a1b1': 0,
  'a2a3': 1,
  'a2a4': 2,
  'b2b3': 3,
  'b2b4': 4,
  'c1d2': 5,
  'c1e3': 6,
  'c2c3': 7,
  'c2c4': 8,
  'e2c3': 9,
  'e2d4': 10,
  'e2g1': 11,
  'f1d1': 12,
  'f1e1': 13,
  'f1f2': 14,
  'f1f3': 15,
  'f1g1': 16,
  'f4f5': 17,
  'g2c6': 18,
  'g2d5': 19,
  'g2e4': 20,
  'g2f2': 21,
  'g2f3': 22,
  'g2g1': 23,
  'g2h2': 24,
  'g3g4': 25,
  'h1h2': 26,
  'h3h4': 27},
 {0: 'a1b1',
  1: 'a2a3',
  2: 'a2a4',
  3: 'b2b3',
  4: 'b2b4',
  5: 'c1d2',
  6: 'c1e3',
  7: 'c2c3',
  8: 'c2c4',
  9: 'e2c3',
  10: 'e2d4',
  11: 'e2g1',
  12: 'f1d1',
  13: 'f1e1',
  14: 'f1f2',
  15: 'f1f3',
  16: 'f1g1',
  17: 'f4f5',
  18: 'g2c6',
  19: 'g2d5',
  20: 'g2e4',
  21: 'g2f2',
  22: 'g2f3',
  23: 'g2g1',
  24: 'g2h2',
  25: 'g3g4',
  26: 'h1h2',
  27: 'h3h4'})

In [None]:
# Global move vocabulary, built once when this cell runs.
# NOTE(review): generate_full_uci_move_vocabulary is defined in a later cell, so this
# cell raises NameError on Restart & Run All unless that cell is moved above it.
uci_to_index, index_to_uci = generate_full_uci_move_vocabulary()

def move_to_index(uci_move: str) -> int:
    """Return the vocabulary index of ``uci_move``, or -1 when the move is not in the vocabulary."""
    if uci_move in uci_to_index:
        return uci_to_index[uci_move]
    return -1  # sentinel for an out-of-vocabulary move

def index_to_move(idx: int) -> str:
    """Return the UCI move stored at vocabulary index ``idx``, or "0000" for an unknown index."""
    try:
        return index_to_uci[idx]
    except KeyError:
        return "0000"  # "0000" is the UCI null move, used here as a placeholder


In [None]:
def generate_full_uci_move_vocabulary():
    """
    Build a fixed move vocabulary covering every from-square/to-square pair.

    The vocabulary contains:

    * one plain move ``<from><to>`` for every ordered pair of distinct squares,
      whether or not any piece could actually make that move;
    * four promotion variants (queen, rook, bishop, knight) for every such pair
      whose destination lies on the first or eighth rank.

    The UCI strings are de-duplicated, sorted lexicographically and numbered
    from 0, so the mapping is deterministic across runs.

    Returns
    -------
    uci_to_index : dict[str, int]
        UCI move string -> vocabulary index.
    index_to_uci : dict[int, str]
        Vocabulary index -> UCI move string (inverse of ``uci_to_index``).
    """
    promotion_pieces = (chess.QUEEN, chess.ROOK, chess.BISHOP, chess.KNIGHT)
    back_ranks = (0, 7)  # rank indices of the 1st and 8th ranks: the only promotion targets

    plain_moves = {
        chess.Move(origin, target).uci()
        for origin in chess.SQUARES
        for target in chess.SQUARES
        if origin != target
    }
    promotion_moves = {
        chess.Move(origin, target, promotion=piece).uci()
        for origin in chess.SQUARES
        for target in chess.SQUARES
        if origin != target and chess.square_rank(target) in back_ranks
        for piece in promotion_pieces
    }

    ordered_moves = sorted(plain_moves | promotion_moves)
    uci_to_index = {move: position for position, move in enumerate(ordered_moves)}
    index_to_uci = dict(enumerate(ordered_moves))
    return uci_to_index, index_to_uci


In [None]:
# Global move vocabulary used by move_to_index / index_to_move below.
# The previous call targeted generate_uci_move_vocabulary, which no visible cell of this
# notebook defines (NameError on a fresh kernel); the builder defined in the cell above
# is generate_full_uci_move_vocabulary.
uci_to_index, index_to_uci = generate_full_uci_move_vocabulary()

# NOTE(review): these helpers redefine same-named functions from an earlier cell;
# because this cell runs later, these are the definitions that take effect.
def move_to_index(uci: str) -> int:
    """Return the vocabulary index of the UCI move ``uci``, or -1 if it is not in the vocabulary."""
    return uci_to_index.get(uci, -1)

def index_to_move(idx: int) -> str:
    """Return the UCI move at vocabulary index ``idx``, or '0000' (the UCI null move) if unknown."""
    return index_to_uci.get(idx, '0000')  # default dummy