In [81]:
import os
import pickle
import shutil
import chess
import chess.engine
import pandas as pd
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.optim import Adam
from stockfish import Stockfish
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [86]:
KAGGLE = False

if KAGGLE:
    input_stockfish_path = "/kaggle/input/stockfish_ubuntu/other/default/1/stockfish-ubuntu-x86-64-avx2"
    stockfish_path = "/kaggle/working/stockfish"
    if not os.path.exists(stockfish_path):
        shutil.copy(input_stockfish_path, stockfish_path)
    os.chmod(stockfish_path, 0o755)
else:
    stockfish_path = "./stockfish-ubuntu-x86-64-avx2"
    os.chmod(stockfish_path, 0o755)
    
stockfish = Stockfish(path=stockfish_path, depth=15, parameters={"Skill Level": 20, "Threads": 4})


In [87]:
def convert_board_to_bitarray(board):
    bitboards = np.zeros((16, 8, 8), dtype=np.uint8)
    
    for i, (pt, color) in enumerate([(pt, color) for pt in chess.PIECE_TYPES for color in chess.COLORS]):
        bitboard = board.pieces_mask(pt, color)
        for square in chess.SQUARES:
            if bitboard & (1 << square):
                bitboards[i, square // 8, square % 8] = 1

    # Player to move
    bitboards[12, :, :] = int(board.turn)

    # Castling rights
    castling_kingside_white = bool(board.castling_rights & chess.BB_H1)
    castling_queenside_white = bool(board.castling_rights & chess.BB_A1)
    castling_kingside_black = bool(board.castling_rights & chess.BB_H8)
    castling_queenside_black = bool(board.castling_rights & chess.BB_A8)
    
    bitboards[13, 0, 0] = castling_kingside_white
    bitboards[13, 0, 1] = castling_queenside_white
    bitboards[13, 1, 0] = castling_kingside_black
    bitboards[13, 1, 1] = castling_queenside_black

    # En passant square
    if board.ep_square:
        ep_rank, ep_file = divmod(board.ep_square, 8)
        bitboards[14, ep_rank, ep_file] = 1

    return bitboards


In [88]:
def generate_stockfish_training_data(stockfish, num_games=10000, max_moves_per_game=50):
    training_data = []
    
    for _ in tqdm(range(num_games), desc="Generating Training Data"):
        board = chess.Board()
        for _ in range(max_moves_per_game):
            if board.is_game_over():
                break

            stockfish.set_fen_position(board.fen())
            best_move = chess.Move.from_uci(stockfish.get_best_move())
            board_state = convert_board_to_bitarray(board)
            
            training_data.append((board_state, best_move))
            
            board.push(best_move)
            if board.is_game_over():
                break
            
            if random.uniform(0, 1) < 0.5:
                opponent_move = random.choice(list(board.legal_moves))
            else:
                stockfish.set_fen_position(board.fen())
                opponent_move = chess.Move.from_uci(stockfish.get_best_move())
                
            board.push(opponent_move)
    
    return training_data

In [89]:
training_data = generate_stockfish_training_data(stockfish, max_moves_per_game=50)

Generating Training Data: 100%|██████████| 10000/10000 [3:13:07<00:00,  1.16s/it] 


In [90]:
def save_training_data_pickle(training_data, file_path="training_data.pkl"):
    with open(file_path, "wb") as f:
        pickle.dump(training_data, f)
    print(f"Training data saved to {file_path}")

def load_training_data_pickle(file_path="training_data.pkl"):
    with open(file_path, "rb") as f:
        return pickle.load(f)

In [91]:
save_training_data_pickle(training_data)

Training data saved to training_data.pkl


In [92]:
loaded_data = load_training_data_pickle()

In [93]:
loaded_data[0]

(array([[[0, 0, 0, ..., 0, 0, 0],
         [1, 1, 1, ..., 1, 1, 1],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]],
 
        [[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [1, 1, 1, ..., 1, 1, 1],
         [0, 0, 0, ..., 0, 0, 0]],
 
        [[0, 1, 0, ..., 0, 1, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]],
 
        ...,
 
        [[1, 1, 0, ..., 0, 0, 0],
         [1, 1, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]],
 
        [[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],