In [1]:
import pandas as pd
import numpy as np
import os
import re
import glob
import tensorflow as tf
import dask.dataframe as dd
import torch
import chess

2025-06-03 19:51:22.936286: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1749001883.031898     462 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1749001883.059940     462 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-03 19:51:23.289395: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Sort the matches: glob returns paths in an arbitrary, filesystem-dependent
# order, so without sorting `parquets[0]` could be a different file per machine/run.
parquets = sorted(glob.glob('../data/processed/*.parquet'))

In [3]:
# Load only the first path in `parquets` into a DataFrame.
df = pd.read_parquet(parquets[0])

In [4]:
# Disabled: would read every file in `parquets` and concatenate them into one DataFrame.
# one_million_games = pd.concat([pd.read_parquet(parquet) for parquet in parquets])

In [7]:
def fen_to_tensor(fen:str) -> torch.Tensor:
    """
    Encode the piece placement of a FEN position as a one-hot tensor.

    The result has shape (12, 8, 8): one 8x8 plane per piece symbol, in the
    order ``PNBRQKpnbrqk``. A plane holds 1.0 on every square occupied by its
    piece and 0.0 elsewhere. Row 0 corresponds to rank 8 and row 7 to rank 1;
    column 0 is the a-file.

    Parameters
    ----------
    fen : str
          FEN string of the position to encode.

    Returns
    -------
    board_tensor : torch.Tensor
                   float32 tensor of shape (12, 8, 8).
    """
    # TODO: add extra channels for castling rights (4 channels), the en passant
    # square and the halfmove clock.
    symbols = 'PNBRQKpnbrqk'
    plane_of = dict(zip(symbols, range(len(symbols))))  # piece symbol -> plane index

    position = chess.Board(fen)
    planes = torch.zeros((12, 8, 8), dtype=torch.float32)

    for sq in chess.SQUARES:
        occupant = position.piece_at(sq)
        if occupant is None:
            continue  # empty square: leave zeros in every plane
        rank, file = divmod(sq, 8)
        # Flip vertically so the highest rank ends up in row 0.
        planes[plane_of[occupant.symbol()], 7 - rank, file] = 1.0

    return planes
        

In [9]:
import chess

def get_legal_moves_vocab(fen:str) -> tuple[dict[str,int],dict[int,str]]:
    """
    Build a per-position vocabulary of the legal moves in a FEN position.

    The legal moves are ordered by their UCI string and numbered from 0 in
    that order.

    Parameters
    ----------
        fen: FEN notation of the current position
    Returns
    -------
        uci_to_idx: Dict {uci_move : idx}
        idx_to_uci: Dict {idx : uci_move}
    """
    position = chess.Board(fen)

    # Sorting the UCI strings gives a deterministic numbering for the position.
    ordered_ucis = sorted(move.uci() for move in position.legal_moves)

    idx_to_uci = dict(enumerate(ordered_ucis))
    uci_to_idx = {uci: idx for idx, uci in idx_to_uci.items()}
    return uci_to_idx, idx_to_uci


In [20]:
# Take the FEN string at row position 7 of the `fen` column and print it.
fen = df['fen'].iloc[7]
print(fen)
# Encode the position as a tensor with fen_to_tensor.
board_tensor = fen_to_tensor(fen)
# NOTE: despite its name, `legal_moves` holds the (uci_to_idx, idx_to_uci) pair
# of dicts returned by get_legal_moves_vocab, not a list of moves.
legal_moves = get_legal_moves_vocab(fen)


rnbqkbnr/ppp2ppp/4p3/8/4N3/3P4/PPP2PPP/R1BQKBNR b KQkq - 0 4


In [23]:
# Display the tensor encoding of the sample position.
board_tensor

tensor([[[0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0.],
         [1., 1., 1., 0., 0., 1., 1., 1.],
         [0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
       

In [22]:
# Display the (uci_to_idx, idx_to_uci) vocabulary pair for the sample position.
legal_moves

({'a7a5': 0,
  'a7a6': 1,
  'b7b5': 2,
  'b7b6': 3,
  'b8a6': 4,
  'b8c6': 5,
  'b8d7': 6,
  'c7c5': 7,
  'c7c6': 8,
  'c8d7': 9,
  'd8d3': 10,
  'd8d4': 11,
  'd8d5': 12,
  'd8d6': 13,
  'd8d7': 14,
  'd8e7': 15,
  'd8f6': 16,
  'd8g5': 17,
  'd8h4': 18,
  'e6e5': 19,
  'e8d7': 20,
  'e8e7': 21,
  'f7f5': 22,
  'f7f6': 23,
  'f8a3': 24,
  'f8b4': 25,
  'f8c5': 26,
  'f8d6': 27,
  'f8e7': 28,
  'g7g5': 29,
  'g7g6': 30,
  'g8e7': 31,
  'g8f6': 32,
  'g8h6': 33,
  'h7h5': 34,
  'h7h6': 35},
 {0: 'a7a5',
  1: 'a7a6',
  2: 'b7b5',
  3: 'b7b6',
  4: 'b8a6',
  5: 'b8c6',
  6: 'b8d7',
  7: 'c7c5',
  8: 'c7c6',
  9: 'c8d7',
  10: 'd8d3',
  11: 'd8d4',
  12: 'd8d5',
  13: 'd8d6',
  14: 'd8d7',
  15: 'd8e7',
  16: 'd8f6',
  17: 'd8g5',
  18: 'd8h4',
  19: 'e6e5',
  20: 'e8d7',
  21: 'e8e7',
  22: 'f7f5',
  23: 'f7f6',
  24: 'f8a3',
  25: 'f8b4',
  26: 'f8c5',
  27: 'f8d6',
  28: 'f8e7',
  29: 'g7g5',
  30: 'g7g6',
  31: 'g8e7',
  32: 'g8f6',
  33: 'g8h6',
  34: 'h7h5',
  35: 'h7h6'})

In [10]:
def get_legal_mask(board: chess.Board, uci_to_index: dict) -> torch.Tensor:
    """
    Build a 0/1 mask over a move vocabulary marking the board's legal moves.

    Parameters
    ----------
    board : chess.Board
            Position whose legal moves are marked.
    uci_to_index : dict
            Mapping {uci_move : idx}; its length sets the mask size.

    Returns
    -------
    mask : torch.Tensor
           float32 tensor of shape (len(uci_to_index),) holding 1.0 at the
           index of every legal move found in ``uci_to_index`` and 0.0
           elsewhere. Legal moves missing from the vocabulary are skipped.
    """
    vocab_size = len(uci_to_index)
    legal_ucis = (move.uci() for move in board.legal_moves)
    legal_slots = [uci_to_index[uci] for uci in legal_ucis if uci in uci_to_index]

    mask = torch.zeros(vocab_size, dtype=torch.float32)
    for slot in legal_slots:
        mask[slot] = 1.0
    return mask  # Shape: (n_moves,)


In [12]:
def generate_full_uci_move_vocabulary():
    """
    Build a fixed, position-independent vocabulary of UCI move strings.

    Every ordered pair of distinct squares contributes its plain move string.
    When the destination lies on the first or last rank, the four promotion
    variants (queen, rook, bishop, knight) are added as well, regardless of
    the origin square.

    Returns
    -------
        uci_to_index: Dict {uci_move : idx}, numbered in sorted order of the UCI strings
        index_to_uci: Dict {idx : uci_move}, the inverse mapping
    """
    promotion_pieces = (chess.QUEEN, chess.ROOK, chess.BISHOP, chess.KNIGHT)
    back_ranks = (0, 7)

    ucis = set()
    for origin in chess.SQUARES:
        for target in chess.SQUARES:
            if origin != target:
                ucis.add(chess.Move(origin, target).uci())
                # Possible promotions: the target is on the first or last rank.
                if chess.square_rank(target) in back_ranks:
                    ucis.update(
                        chess.Move(origin, target, promotion=piece).uci()
                        for piece in promotion_pieces
                    )

    index_to_uci = dict(enumerate(sorted(ucis)))
    uci_to_index = {uci: idx for idx, uci in index_to_uci.items()}
    return uci_to_index, index_to_uci


In [13]:
# Module-level lookup tables, built once up front and shared by the helpers below.
uci_to_index, index_to_uci = generate_full_uci_move_vocabulary()

def move_to_index(uci_move: str) -> int:
    """
    Look up the index of a UCI move in the global ``uci_to_index`` table.

    Returns -1 when the move is not in the table. Callers should check for
    that sentinel before indexing with the result, because -1 is itself a
    valid Python sequence index (the last element).
    """
    if uci_move in uci_to_index:
        return uci_to_index[uci_move]
    return -1  # -1 if it is not present

def index_to_move(idx: int) -> str:
    """
    Look up the UCI move string for an index in the global ``index_to_uci`` table.

    Returns the placeholder string ``"0000"`` when the index is not in the table.
    """
    # "0000" is a dummy fallback, just in case the index is unknown.
    return index_to_uci[idx] if idx in index_to_uci else "0000"
