# CHess-NN

In [1]:
import os
import io
import chess
import chess.pgn
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt

In [2]:
import numpy as np

In [3]:
from tqdm import tqdm

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
def count_games(pgn_file):
    count = 0
    
    with open(pgn_file) as pgn:
        while True:
            game = chess.pgn.read_game(pgn)
            if game is None:
                break  # Fin del archivo
            count += 1
    
    return count

In [6]:
# Funciones para convertir el el tablero de ajedrez a tensor y visceversa
def board_to_tensor(board):
    pieces = ['P', 'N', 'B', 'R', 'Q', 'K', 'p', 'n', 'b', 'r', 'q', 'k']
    tensor = torch.zeros(12, 8, 8)
    for i, piece in enumerate(pieces):
        for pos in board.pieces(chess.Piece.from_symbol(piece).piece_type, chess.WHITE if piece.isupper() else chess.BLACK):
            tensor[i, pos // 8, pos % 8] = 1
    return tensor

def tensor_to_move(tensor):
    move_index = tensor.argmax().item()
    from_square = move_index // 64
    to_square = move_index % 64
    return chess.Move(from_square, to_square)

In [7]:
# pgn_file = "C:/Users/mated/Documents/GitHub/CHESS_DATA/lichess_db_standard_rated_2017-03.pgn"
pgn_file = "C:/Users/mated/Documents/GitHub/CHESS_DATA/lichess_db_standard_rated_2013-01.pgn"

In [8]:
class ChessDataset(Dataset):
    def __init__(self, pgn_file):
        self.games = []
        with open(pgn_file) as f:
            game = []
            for line in f:
                if line.startswith('[Event'):
                    if game:
                        self.games.append(game)
                    game = []
                if line.startswith('1.'):
                    moves = line.strip().split()[1:]
                    game.append(moves)
            if game:
                self.games.append(game)

        self.data = []
        for game in tqdm(self.games):
            board = chess.Board()
            fen_moves = []
            for moves in game:
                for move in moves:
                    try:
                        board.push_uci(move)
                        fen_moves.append((board.fen(), move))
                    except ValueError:
                        # Ignore invalid moves
                        pass
            self.data.append(fen_moves)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

In [9]:
class ChessDataset(Dataset):
    def __init__(self, pgn_file):
        self.games = []
        with open(pgn_file) as f:
            game = []
            for line in f:
                if line.startswith('[Event'):
                    if game:
                        self.games.append(game)
                    game = []
                if line.startswith('1.'):
                    moves = line.strip().split()[1:]
                    game.append(moves)
            if game:
                self.games.append(game)

        self.data = []
        for game in tqdm(self.games):
            board = chess.Board()
            fen_moves = []
            uci_moves = []
            for moves in game:
                for move in moves:
                    try:
                        uci_moves.append(move)
                        fen_moves.append((board.fen(), move))
                        board.push_uci(move)
                    except ValueError:
                        # Ignore invalid moves
                        pass
            self.data.append({'fen_moves': fen_moves, 'uci_moves': uci_moves})

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

In [40]:
class ChessDataset(Dataset):
    def __init__(self, pgn_file):
        self.games = []
        with open(pgn_file) as f:
            game = []
            for line in f:
                if line.startswith('[Event'):
                    if game:
                        self.games.append(game)
                    game = []
                if line.startswith('1.'):
                    moves = line.strip().split()[1:]
                    game.append(moves)
            if game:
                self.games.append(game)

        self.max_moves = max([len(game) for game in self.games])

        self.data = []
        for game in tqdm(self.games):
            board = chess.Board()
            fen_moves = []
            uci_moves = []
            for moves in game:
                for move in moves:
                    try:
                        uci_moves.append(move)
                        fen_moves.append((board.fen(), move))
                        board.push_uci(move)
                    except ValueError:
                        # Ignore invalid moves
                        pass
            # while len(fen_moves) < self.max_moves:
            #     fen_moves.append((' ', ' '))
            #     uci_moves.append('')
            self.data.append({'fen_moves': fen_moves, 'uci_moves': uci_moves})

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

In [41]:
%%time
dataset = ChessDataset(pgn_file)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

100%|██████████| 121332/121332 [17:11<00:00, 117.64it/s]


CPU times: total: 10min 21s
Wall time: 17min 16s


In [42]:
from sklearn.model_selection import train_test_split
import numpy as np

In [43]:
dataset.data[0]

{'fen_moves': [('rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1',
   'e4'),
  ('rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1', 'e6'),
  ('rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1', '2.'),
  ('rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1', 'd4'),
  ('rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1', 'b6'),
  ('rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1', '3.'),
  ('rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1', 'a3'),
  ('rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1', 'Bb7'),
  ('rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1', '4.'),
  ('rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1', 'Nc3'),
  ('rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1', 'Nh6'),
  ('rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1', '5.'),
  ('rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1', 'Bxh6'),
  ('rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1', 'gxh6'

In [44]:
train_data, val_data = train_test_split(dataset.data, test_size=0.2, random_state=42)

In [45]:
len_dicts = {}
for key in list(train_data[0].keys()):
    len_dicts[f'len_{key}'] = [len(d[key]) for d in train_data]

In [46]:
max(len_dicts['len_fen_moves'])

1104

In [47]:
max(len_dicts['len_uci_moves'])

1104

In [20]:
batch_size = 64

train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False)


In [22]:
# Iterar a través del dataloader
for batch_idx, (data, target) in enumerate(train_loader):
    # Imprimir los primeros elementos del primer batch
    if batch_idx == 0:
        print("Data shape:", data.shape)
        print("Target shape:", target.shape)
        print("Data examples:")
        print(data[:5])
        print("Target examples:")
        print(target[:5])
    # Salir después del primer batch
    break

RuntimeError: each element in list of batch should be of equal size