### GNN embedding

In [1]:
#!pip install python-chess


In [2]:
#!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
#!pip install torch-scatter -f https://data.pyg.org/whl/torch-2.6.0+cpu.html
#!pip install torch-sparse -f https://data.pyg.org/whl/torch-2.6.0+cpu.html
#!pip install torch-geometric

In [3]:
import torch
import random
import numpy as np

def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds PyTorch, NumPy's global generator and Python's ``random`` module,
    additionally seeds the CUDA generators when a GPU is available, and sets
    the cuDNN ``deterministic`` / ``benchmark`` flags.

    Args:
        seed: Integer seed handed to every generator. Defaults to 42.
    """
    for seed_generator in (torch.manual_seed, np.random.seed, random.seed):
        seed_generator(seed)

    if torch.cuda.is_available():
        for seed_generator in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            seed_generator(seed)

    # Ask cuDNN for deterministic kernels and disable its auto-tuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Fix the seed once for the whole notebook run.
set_seed(3633)

In [4]:
import os
import zipfile
import chess.pgn
import torch
import numpy as np
from io import TextIOWrapper
from torch_geometric.data import Data
import pandas as pd



# =====================
# FEATURE FUNCTION
# =====================

# Plane index of every piece symbol: white P, N, B, R, Q, K are planes 0-5,
# black p, n, b, r, q, k are planes 6-11. Built once instead of on every call.
_PIECE_TO_PLANE = {symbol: plane for plane, symbol in enumerate("PNBRQKpnbrqk")}


def board_to_feature(board):
    """Encode a board position as a flat vector of 773 float32 features.

    Layout:
        * ``[0, 768)``: 12 piece planes of 64 squares each; entry
          ``plane * 64 + square`` is 1.0 when that piece stands on ``square``
          (``square`` is the key used by ``board.piece_map()``).
        * ``[768, 773)``: side to move, then white kingside, white queenside,
          black kingside and black queenside castling rights, each as 0.0/1.0.

    Args:
        board: Position to encode. Only read, never modified.

    Returns:
        ``np.ndarray`` of shape ``(773,)`` and dtype ``float32``.
    """
    # A single float32 buffer keeps the declared dtype; concatenating the
    # planes with a list of Python floats would silently upcast to float64.
    features = np.zeros(12 * 64 + 5, dtype=np.float32)

    for square, piece in board.piece_map().items():
        features[_PIECE_TO_PLANE[piece.symbol()] * 64 + square] = 1.0

    features[768:] = (
        float(board.turn),
        float(board.has_kingside_castling_rights(chess.WHITE)),
        float(board.has_queenside_castling_rights(chess.WHITE)),
        float(board.has_kingside_castling_rights(chess.BLACK)),
        float(board.has_queenside_castling_rights(chess.BLACK)),
    )
    return features

# =====================
# PER-GAME GRAPH FUNCTION
# =====================

def pgn_to_graph_one_player(game, color="white"):
    """Build a path graph from the positions reached after one side's moves.

    The main line of ``game`` is replayed; every position that results from a
    move played by ``color`` becomes a node (encoded with
    ``board_to_feature``), and consecutive nodes are linked by a directed edge
    ``i -> i + 1``.

    Args:
        game: Parsed game exposing ``board()`` and ``mainline_moves()``.
        color: ``"white"`` or ``"black"`` - the side whose moves are kept.

    Returns:
        A ``Data`` object with ``x`` of shape ``(num_nodes, num_features)`` and
        ``edge_index`` of shape ``(2, num_nodes - 1)``, or ``None`` when fewer
        than two positions were collected or ``color`` is not a known side.
    """
    if color not in ("white", "black"):
        # An unrecognised colour never matches any move, so there is no graph.
        return None
    keep_white_moves = color == "white"

    board = game.board()
    node_features = []
    for move in game.mainline_moves():
        # Ask the board who is to move rather than counting plies: a game that
        # starts from a custom position may begin with Black to move.
        mover_is_white = board.turn == chess.WHITE
        board.push(move)
        if mover_is_white == keep_white_moves:
            # board_to_feature only reads the board and returns a fresh array,
            # so no defensive copy of the board is needed.
            node_features.append(board_to_feature(board))

    if len(node_features) < 2:
        return None

    x = torch.tensor(np.stack(node_features), dtype=torch.float)
    node_ids = torch.arange(len(node_features), dtype=torch.long)
    edge_index = torch.stack([node_ids[:-1], node_ids[1:]])
    return Data(x=x, edge_index=edge_index)

# PGN headers copied verbatim into each metadata row, in output column order.
_METADATA_HEADERS = (
    "Event", "Site", "Date", "Round", "White", "Black",
    "Result", "WhiteElo", "BlackElo", "ECO",
)


def _detect_player_color(game, player_name):
    """Return "white" or "black" for the side whose header names the player, else None."""
    target = player_name.lower()
    # NOTE(review): this is a substring test, so a short name can also match a
    # longer, unrelated surname. Confirm whether stricter matching is wanted.
    if target in game.headers.get("White", "").lower():
        return "white"
    if target in game.headers.get("Black", "").lower():
        return "black"
    return None


def _game_metadata(game, game_id, player_name):
    """Build the metadata row (ids, selected PGN headers, full PGN text) for one game."""
    row = {"game_id": game_id, "player": player_name}
    row.update((header, game.headers.get(header, "")) for header in _METADATA_HEADERS)
    row["pgn"] = str(game)
    return row


def process_zip_pgns(zip_folder, metadata_path="games_metadata.csv"):
    """Convert every player archive in a folder into per-game graphs.

    Each ``<player>.zip`` in ``zip_folder`` is opened and every member is
    parsed as PGN. Games whose White or Black header contains the player name
    (case-insensitive) are turned into a graph with
    ``pgn_to_graph_one_player``. Every kept graph is tagged with ``.player``
    and a running ``.game_id``, and a metadata row with the same ``game_id``
    is collected.

    Args:
        zip_folder: Directory containing one ``.zip`` archive per player.
        metadata_path: Where the metadata CSV is written. Defaults to
            ``"games_metadata.csv"``.

    Returns:
        Dict mapping each player name to the list of graphs built for them.

    Side effects:
        Writes the metadata CSV and prints one progress line per archive.
    """
    data_by_player = {}
    metadata = []
    game_id = 0

    # Sort so that game_id assignment does not depend on the arbitrary order
    # in which the filesystem lists the directory.
    for filename in sorted(os.listdir(zip_folder)):
        player_name, extension = os.path.splitext(filename)
        if extension != ".zip":
            continue

        zip_path = os.path.join(zip_folder, filename)
        graphs = []

        with zipfile.ZipFile(zip_path, 'r') as zf:
            for member in zf.infolist():
                if member.is_dir():
                    # Directory entries hold no PGN text.
                    continue

                with TextIOWrapper(zf.open(member), encoding='utf-8', errors='ignore') as pgn:
                    while True:
                        game = chess.pgn.read_game(pgn)
                        if game is None:
                            break

                        color = _detect_player_color(game, player_name)
                        if color is None:
                            continue

                        graph = pgn_to_graph_one_player(game, color=color)
                        if graph is None:
                            # Too few positions to form a graph.
                            continue

                        graph.player = player_name
                        graph.game_id = game_id
                        graphs.append(graph)
                        metadata.append(_game_metadata(game, game_id, player_name))
                        game_id += 1

        data_by_player[player_name] = graphs
        print(f"{player_name}: {len(graphs)} grafos generados")

    # Save the metadata
    df_meta = pd.DataFrame(metadata)
    df_meta.to_csv(metadata_path, index=False)
    print(f"✅ Metadata guardada en {metadata_path}")

    return data_by_player



In [5]:
# Build the per-player graph lists from the archives in ./pgns
# (process_zip_pgns also writes games_metadata.csv as a side effect).
graphs = process_zip_pgns(zip_folder="pgns")

Alekhine: 1660 grafos generados
Anand: 4201 grafos generados
Anderssen: 680 grafos generados
Aronian: 5106 grafos generados
Bogoljubow: 972 grafos generados
Botvinnik: 891 grafos generados
Bronstein: 1928 grafos generados
Capablanca: 597 grafos generados
Carlsen: 6613 grafos generados
Caruana: 5339 grafos generados
Chigorin: 687 grafos generados
DeLaBourdonnais: 0 grafos generados
Euwe: 1121 grafos generados
Fine: 304 grafos generados
Fischer: 825 grafos generados


illegal san: 'Qxe1' in r2k3r/2pPp3/p4n2/3b2B1/1p5P/2qP4/3RQ1P1/4K2R w - - 2 31 while parsing <Game at 0x25cac539a60 ('Gelfand,B' vs. 'Gareev,T', '2019.12.29' at 'Moscow RUS')>


Gelfand: 3969 grafos generados
Geller: 2195 grafos generados
Ivanchuk: 4949 grafos generados
Kamsky: 7035 grafos generados
Karjakin: 3534 grafos generados
Karpov: 3528 grafos generados
Kasparov: 2127 grafos generados
Keres: 1570 grafos generados
Korchnoi: 1038 grafos generados
Kramnik: 4323 grafos generados
Larsen: 2377 grafos generados
Lasker: 899 grafos generados
Leko: 2679 grafos generados
Maroczy: 754 grafos generados
Morphy: 211 grafos generados
Najdorf: 1603 grafos generados
Nimzowitsch: 511 grafos generados
Petrosian: 1892 grafos generados
Philidor: 6 grafos generados
Pillsbury: 387 grafos generados
Polugaevsky: 1889 grafos generados
Portisch: 3029 grafos generados
Reshevsky: 1265 grafos generados
Rubinstein: 796 grafos generados
Schlechter: 738 grafos generados
Smyslov: 2624 grafos generados
Spassky: 2229 grafos generados
Staunton: 283 grafos generados
Steinitz: 589 grafos generados
Tal: 2430 grafos generados
Tarrasch: 703 grafos generados
Timman: 3618 grafos generados
Topalov:

In [6]:
# Flatten {player: [graph, ...]} into a single list for the DataLoader.
#
# Every graph already carries .player and .game_id from process_zip_pgns, and
# those game_ids are the ones written to games_metadata.csv. They are kept
# as-is instead of being renumbered with a second counter, so the later merge
# on game_id cannot drift out of sync with the metadata file.
graphs_list = [graph for player_graphs in graphs.values() for graph in player_graphs]
game_id_counter = len(graphs_list)

assert len({graph.game_id for graph in graphs_list}) == game_id_counter, \
    "game_id must be unique across all graphs"
    

In [7]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.loader import DataLoader
import pandas as pd
from tqdm import tqdm

# ======================
# Hyperparameters
# ======================
# Node feature width: 12 piece planes x 64 squares + 5 flags (773 in total).
INPUT_DIM = 12 * 64 + 5
# Output width of the encoder's first graph convolution.
HIDDEN_DIM = 256
# Size of each pooled graph embedding (number of dim_* columns in the CSV).
EMBED_DIM = 128
# Number of graphs per DataLoader batch.
BATCH_SIZE = 32

# ======================
# GCN encoder model
# ======================
class GNNEncoder(torch.nn.Module):
    """Two-layer GCN that mean-pools node embeddings into one vector per graph.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Output width of the first graph convolution.
        out_channels: Output width of the second convolution, i.e. the size
            of the returned graph embedding.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        """Return one pooled embedding per graph referenced in ``batch``.

        Args:
            x: Node feature matrix.
            edge_index: Edge list forwarded to both convolutions.
            batch: Graph assignment of every node, used by the mean pooling.
        """
        hidden = self.conv1(x, edge_index)
        hidden = F.relu(hidden)
        node_embeddings = self.conv2(hidden, edge_index)
        graph_embeddings = global_mean_pool(node_embeddings, batch)
        return graph_embeddings

# ======================
# Load graphs
# ======================
#graphs = torch.load("chess_graphs.pt")  # each Data must carry .player and .game_id
# shuffle=False keeps the embeddings in the same order as graphs_list.
loader = DataLoader(graphs_list, batch_size=BATCH_SIZE, shuffle=False)

# NOTE(review): no weights are loaded and no training happens in this cell, so
# the embeddings come from the encoder's seeded initial weights. Confirm that
# this is intended.
model = GNNEncoder(INPUT_DIM, HIDDEN_DIM, EMBED_DIM)
model.eval()

# ======================
# Embedding loop
# ======================
embedding_batches = []
players = []
game_ids = []

with torch.no_grad():
    for batch in tqdm(loader, desc="Generando embeddings"):
        embedding_batches.append(model(batch.x, batch.edge_index, batch.batch))
        for graph in batch.to_data_list():
            players.append(graph.player)
            # After batching, game_id may come back as a one-element tensor
            # rather than an int (TODO confirm); int() accepts both, so the
            # DataFrame and the CSV always hold plain integers.
            game_ids.append(int(graph.game_id))

embeddings = torch.cat(embedding_batches).cpu().numpy()

# ======================
# Save results
# ======================
df = pd.DataFrame(embeddings, columns=[f"dim_{i}" for i in range(EMBED_DIM)])
df["player"] = players
df["game_id"] = game_ids
df.to_csv("gnn_chess_embeddings.csv", index=False)

print("✅ Embeddings GNN guardados como CSV.")



Generando embeddings: 100%|██████████| 3120/3120 [01:28<00:00, 35.34it/s]


✅ Embeddings GNN guardados como CSV.


In [None]:
# NOTE(review): the captured output of this cell looks like a pandas warning
# raised by this call (probably a mixed-type DtypeWarning - confirm).
# low_memory=False makes pandas infer each column's dtype from the whole file
# instead of chunk by chunk, which avoids mixed-type columns.
games_metadata = pd.read_csv("games_metadata.csv", low_memory=False)
def detectar_color(row):
    """Return the side the tracked player had in one metadata row.

    Args:
        row: Mapping with ``"player"``, ``"White"`` and ``"Black"`` entries
            (for example a row of the games-metadata DataFrame).

    Returns:
        ``"white"`` if the player name occurs (case-insensitively) in the
        White header, otherwise ``"black"`` if it occurs in the Black header,
        otherwise ``"unknown"``. Missing or non-text cells never match.
    """
    def _lowered(value):
        # Empty header cells come back from read_csv as NaN (a float), which
        # has no .lower(); treat anything that is not text as empty.
        return value.lower() if isinstance(value, str) else ""

    player = _lowered(row["player"])
    if not player:
        # An empty name is a substring of every string and would always
        # report "white".
        return "unknown"

    if player in _lowered(row["White"]):
        return "white"
    if player in _lowered(row["Black"]):
        return "black"
    return "unknown"

# Label every metadata row with the side the tracked player had
# (detectar_color is applied row by row because of axis=1).
games_metadata["color"] = games_metadata.apply(detectar_color, axis=1)
# Last expression of the cell: display the labelled frame.
games_metadata

  games_metadata = pd.read_csv("games_metadata.csv")


In [14]:
# Coerce every game_id to a built-in int before it is used as the merge key.
df["game_id"] = df["game_id"].map(int)
df

Unnamed: 0,dim_0,dim_1,dim_2,dim_3,dim_4,dim_5,dim_6,dim_7,dim_8,dim_9,...,dim_120,dim_121,dim_122,dim_123,dim_124,dim_125,dim_126,dim_127,player,game_id
0,0.186119,0.224894,0.144705,-0.175782,0.095944,0.030560,0.151180,0.337592,-0.090849,-0.153542,...,-0.008651,-0.130952,0.114924,-0.046799,0.103017,0.171092,-0.003021,-0.125612,Alekhine,0
1,0.315813,0.277531,0.013814,0.044793,0.051820,0.215278,0.178388,0.506271,-0.308271,-0.155647,...,-0.041976,-0.350548,0.174448,-0.249423,0.019336,0.129544,0.313523,0.061583,Alekhine,1
2,0.208794,0.196072,0.024174,-0.053569,0.039854,0.341885,0.033939,0.310090,-0.162769,-0.138082,...,-0.007500,-0.204563,0.083665,-0.118874,0.021532,0.126749,0.176464,-0.055112,Alekhine,2
3,0.238679,0.223980,0.012921,-0.091717,-0.013169,0.264325,0.075385,0.463902,-0.412598,-0.262492,...,0.126914,-0.193142,0.042185,-0.243187,0.095005,0.163843,0.211810,-0.032563,Alekhine,3
4,0.254185,0.170033,0.070007,-0.025069,0.056162,0.023767,0.215903,0.461962,-0.098904,-0.190934,...,-0.106467,-0.235553,0.205192,-0.154784,0.162187,0.235318,0.193537,-0.035006,Alekhine,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
99813,0.133941,0.234257,0.026412,-0.088559,0.031019,0.178800,0.121495,0.490427,-0.232153,-0.244242,...,-0.083335,-0.263199,0.040623,-0.190044,-0.056501,0.094648,0.180433,0.106629,Zukertort,99813
99814,0.311340,0.512791,0.033773,-0.228010,0.082287,0.085471,0.079640,0.385600,-0.176134,-0.268078,...,0.105749,-0.166975,0.173110,-0.244497,0.125968,0.142024,0.011218,-0.030287,Zukertort,99814
99815,0.205848,0.152802,0.211567,-0.089802,-0.037830,0.159888,0.028132,0.169020,-0.089636,-0.144241,...,-0.073481,-0.167524,0.102670,-0.117493,0.133395,0.176833,0.027140,-0.022743,Zukertort,99815
99816,0.283790,0.110991,-0.018982,-0.231278,0.170148,0.010545,0.192621,0.276377,-0.081099,-0.084389,...,-0.044635,-0.199848,0.057434,-0.032101,0.050490,0.140358,0.139633,-0.077303,Zukertort,99816


In [15]:
# Attach the player's colour to every embedding row. game_id is expected to be
# unique on both sides; validate= makes pandas raise a MergeError if it is not,
# instead of silently duplicating embedding rows.
df_final = df.merge(
    games_metadata[["game_id", "color"]],
    on="game_id",
    how="left",
    validate="one_to_one",
)

In [16]:
# Display the merged embeddings + colour frame.
df_final

Unnamed: 0,dim_0,dim_1,dim_2,dim_3,dim_4,dim_5,dim_6,dim_7,dim_8,dim_9,...,dim_121,dim_122,dim_123,dim_124,dim_125,dim_126,dim_127,player,game_id,color
0,0.186119,0.224894,0.144705,-0.175782,0.095944,0.030560,0.151180,0.337592,-0.090849,-0.153542,...,-0.130952,0.114924,-0.046799,0.103017,0.171092,-0.003021,-0.125612,Alekhine,0,black
1,0.315813,0.277531,0.013814,0.044793,0.051820,0.215278,0.178388,0.506271,-0.308271,-0.155647,...,-0.350548,0.174448,-0.249423,0.019336,0.129544,0.313523,0.061583,Alekhine,1,black
2,0.208794,0.196072,0.024174,-0.053569,0.039854,0.341885,0.033939,0.310090,-0.162769,-0.138082,...,-0.204563,0.083665,-0.118874,0.021532,0.126749,0.176464,-0.055112,Alekhine,2,white
3,0.238679,0.223980,0.012921,-0.091717,-0.013169,0.264325,0.075385,0.463902,-0.412598,-0.262492,...,-0.193142,0.042185,-0.243187,0.095005,0.163843,0.211810,-0.032563,Alekhine,3,white
4,0.254185,0.170033,0.070007,-0.025069,0.056162,0.023767,0.215903,0.461962,-0.098904,-0.190934,...,-0.235553,0.205192,-0.154784,0.162187,0.235318,0.193537,-0.035006,Alekhine,4,black
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
99813,0.133941,0.234257,0.026412,-0.088559,0.031019,0.178800,0.121495,0.490427,-0.232153,-0.244242,...,-0.263199,0.040623,-0.190044,-0.056501,0.094648,0.180433,0.106629,Zukertort,99813,white
99814,0.311340,0.512791,0.033773,-0.228010,0.082287,0.085471,0.079640,0.385600,-0.176134,-0.268078,...,-0.166975,0.173110,-0.244497,0.125968,0.142024,0.011218,-0.030287,Zukertort,99814,black
99815,0.205848,0.152802,0.211567,-0.089802,-0.037830,0.159888,0.028132,0.169020,-0.089636,-0.144241,...,-0.167524,0.102670,-0.117493,0.133395,0.176833,0.027140,-0.022743,Zukertort,99815,white
99816,0.283790,0.110991,-0.018982,-0.231278,0.170148,0.010545,0.192621,0.276377,-0.081099,-0.084389,...,-0.199848,0.057434,-0.032101,0.050490,0.140358,0.139633,-0.077303,Zukertort,99816,black


In [17]:
# Build the path with os.path.join: the original literal "csvs\e..." contained
# an invalid escape sequence and hard-coded the Windows separator.
output_dir = "csvs"
# Create the target directory if needed so to_csv does not fail on a fresh checkout.
os.makedirs(output_dir, exist_ok=True)
df_final.to_csv(os.path.join(output_dir, "embeddings_gnn.csv"), index=False)