<a href="https://colab.research.google.com/github/DashShantanu/chess-engine/blob/main/chess_engine.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
! pip install kaggle -q
! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle/
! chmod 600 /root/.kaggle/kaggle.json
! kaggle datasets download -d arevel/chess-games
! unzip -qq /content/chess-games.zip

# dataset url
# https://www.kaggle.com/datasets/arevel/chess-games

# !nvcc --version
# !pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# !pip3 install torchvision

mkdir: cannot create directory ‘/root/.kaggle’: File exists
Downloading chess-games.zip to /content
100% 1.45G/1.45G [00:15<00:00, 112MB/s] 
100% 1.45G/1.45G [00:15<00:00, 98.8MB/s]


In [3]:
! pip install chess -q
import chess

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/154.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m30.7/154.4 kB[0m [31m1.1 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.4/154.4 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [4]:
# Board files 'a'..'h' <-> column indexes 0..7 (lookup in both directions)
letter_to_num = {file_letter: column for column, file_letter in enumerate('abcdefgh')}
num_to_letter = {column: file_letter for file_letter, column in letter_to_num.items()}

In [5]:
import numpy as np
import pandas as pd
import re

# Chess-board to matrix representation
def board_to_rep(board):
  """Stack one layer per piece type into a single 3-D array.

  Layers are produced by `create_rep_layer` and stacked along a new leading
  axis in the fixed order pawn, rook, knight, bishop, queen, king.
  """
  piece_symbols = ('p', 'r', 'n', 'b', 'q', 'k')

  # np.stack adds the leading "piece type" axis in front of each layer's axes
  return np.stack([create_rep_layer(board, symbol) for symbol in piece_symbols])

In [6]:
# Create a layer of the matrix representation, white value is positive, black is negative
def create_rep_layer(board, type):
    """Encode the squares holding one piece type as a signed integer matrix.

    The board is read through its text form, `str(board)`: one line per row,
    with the square symbols on a line separated by single spaces.

    Args:
        board: any object whose `str()` yields that text form
            (presumably a `chess.Board` - confirm against callers).
        type: piece letter to encode, e.g. 'p'. The name shadows the builtin
            `type`; it is kept so existing keyword callers keep working.

    Returns:
        np.ndarray with one row per text line: 1 where the square shows the
        uppercase letter (white), -1 where it shows the lowercase letter
        (black), and 0 everywhere else.
    """
    # Look the symbols up directly instead of rewriting the string with several
    # regex passes; this also removes the invalid '\.' escape the old pattern
    # used. The uppercase key is written last so it wins if `type` is already
    # uppercase.
    square_values = {type: -1, type.upper(): 1}

    return np.array([
        [square_values.get(symbol, 0) for symbol in row.split(' ')]
        for row in str(board).split('\n')
    ])

In [7]:
# chess-move to matrix representation
def move_to_rep(move, board):
    """Play a SAN move on `board` and encode it as two one-hot 8x8 layers.

    Side effect: the move is pushed onto `board` and stays there.

    Args:
        move: the move in standard algebraic notation, as accepted by
            `board.push_san`.
        board: board to play the move on; `push_san` must return a move
            object exposing `.uci()`.

    Returns:
        np.ndarray of shape (2, 8, 8): layer 0 marks the origin square and
        layer 1 the destination square. Rows count down from rank 8
        (row = 8 - rank) and columns follow `letter_to_num`.
    """
    # push_san already returns the move it played, so take the UCI string
    # from it directly rather than copying the board and popping the copy
    uci = board.push_san(move).uci()

    move_rep = np.zeros((2, 8, 8))
    # UCI text is <from file><from rank><to file><to rank>; any trailing
    # promotion letter is ignored
    move_rep[0, 8 - int(uci[1]), letter_to_num[uci[0]]] = 1
    move_rep[1, 8 - int(uci[3]), letter_to_num[uci[2]]] = 1

    return move_rep

In [8]:
# Break down game into individual moves
def create_move_list(s):
    """Split a numbered move string into a list of move tokens.

    Move numbers ("1. ", "23. ") are stripped and the rest is split on single
    spaces. The final token is always dropped, whatever it contains
    (presumably an empty string from a trailing space, or the game result
    such as "1-0" - confirm against the dataset).

    Args:
        s: game text such as "1. e4 e5 2. Nf3 Nc6 1-0".

    Returns:
        list[str]: the move tokens in order, without the final token.
    """
    # raw string so \d and \. reach the regex engine instead of being parsed
    # as (invalid) string escape sequences
    without_numbers = re.sub(r'\d*\. ', '', s)
    return without_numbers.split(' ')[:-1]

Loading the Chess Dataset

In [9]:
# Load only the move text ('AN') and White's rating, then keep the games
# where White is rated above 2000
chess_data_raw = pd.read_csv('/content/chess_games.csv', usecols=['AN', 'WhiteElo'])
chess_data = chess_data_raw.loc[chess_data_raw['WhiteElo'].gt(2000)]

In [10]:
import gc
# The rating-filtered games are kept in chess_data, so drop this notebook's
# reference to the unfiltered frame and run a garbage collection pass right away.
# gc.collect() is the last expression of the cell, so its return value is
# shown as the cell output.
del chess_data_raw
gc.collect()

0

In [11]:
# Keep only the move text column
chess_data = chess_data[['AN']]
# Drop games whose move text contains '{'. regex=False matches the brace as a
# plain character, and na=False makes missing text count as "no match" so the
# `~` negation never sees NaN.
chess_data = chess_data[~chess_data['AN'].str.contains('{', regex=False, na=False)]
# Drop games whose move text is 20 characters or shorter (missing text
# compares False here and is dropped as well)
chess_data = chess_data[chess_data['AN'].str.len() > 20]

In [12]:
# number of games left after filtering
print(len(chess_data))

883376
