In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.utils
import numpy as np
import matplotlib.pyplot as plt
import chess.pgn
import chess
import os, os.path

In [4]:
# ----- hyperparameters -----
# Optimization
batch_size = 64
learning_rate = 3e-4
# NOTE(review): 0.97 is far larger than a typical AdamW-style weight decay and
# looks more like a multiplicative decay factor — confirm how it is consumed.
weight_decay = 0.97
n_epochs = 100
n_iters = 5000
eval_interval = 200

# Model
embed_dim = 128
head_size = 16
num_heads = embed_dim // head_size  # 128 // 16 -> 8 attention heads
n_layers = 12
dropout = 0.1
# NOTE(review): encode_move emits 64 * 64 + 4 = 4100 values; confirm why 5000.
n_classes = 5000
# 8 x 8 board + 1 extra slot (purpose of the extra slot not shown here).
# NOTE(review): encode_board_state returns 64 + 5 = 69 values; confirm n_tokens.
n_tokens = 8 * 8 + 1

# Run state
best_loss = 1e9  # sentinel; presumably replaced by the best loss seen so far
mode = "train"
# ---------------------------

In [None]:
def get_game_data(pgns_dir: str):
    """Collect (position, move) pairs from every PGN file in a directory.

    Every game in every file is replayed along its main line. For each move the
    FEN of the position *before* the move is recorded together with the move's
    UCI string, so ``board_states[i]`` is the position in which ``moves[i]`` was
    played.

    Args:
        pgns_dir: Directory containing PGN files. Entries that are not regular
            files (e.g. subdirectories) are ignored, and files smaller than
            10 bytes are skipped as effectively empty.

    Returns:
        A ``(board_states, moves)`` tuple of equal-length lists of strings
        (FEN strings and UCI move strings).
    """
    board_states = []
    moves = []

    # os.listdir order is arbitrary; sort so the dataset order is reproducible.
    for file_name in sorted(os.listdir(pgns_dir)):
        pgn_path = os.path.join(pgns_dir, file_name)
        # open() would fail on a subdirectory, so only regular files are read;
        # tiny files cannot hold a game and are skipped.
        if not os.path.isfile(pgn_path) or os.path.getsize(pgn_path) < 10:
            continue

        # Decode explicitly rather than via the platform locale; replace
        # undecodable bytes (e.g. in player names) instead of aborting the load.
        with open(pgn_path, encoding="utf-8", errors="replace") as f:
            while True:
                game = chess.pgn.read_game(f)
                if game is None:  # end of file
                    break

                board = game.board()
                for move in game.mainline_moves():
                    board_states.append(board.fen())
                    moves.append(move.uci())
                    board.push(move)

    return board_states, moves
        
        

# Load every (FEN before the move, UCI move) pair from the local "pgns" folder.
board_states, moves = get_game_data(pgns_dir="pgns")

# Quick sanity check: the first ten positions and the moves played from them.
print(board_states[:10], moves[:10])

In [None]:
# Lookup tables are built once at import time instead of on every call, since
# the encoder runs once per training position.
_PIECE_TO_ID = {
    'P': 1, 'N': 2, 'B': 3, 'R': 4, 'Q': 5, 'K': 6,
    'p': 7, 'n': 8, 'b': 9, 'r': 10, 'q': 11, 'k': 12,
    '.': 0,
}
# Castling rights packed as a 4-bit mask.
_CASTLING_TO_ID = {'K': 1, 'Q': 2, 'k': 4, 'q': 8}
# En-passant target file a..h -> 1..8 (0 is reserved for "no target").
_EN_PASSANT_FILE_TO_ID = {
    'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8,
}


def encode_board_state(fen: str):
    """Encode a FEN string as a flat integer vector of 69 values.

    Layout:
        [0:64]  piece id per square (0 = empty, 1-6 = white P N B R Q K,
                7-12 = black p n b r q k) in FEN order: row 0 of the board is
                the first FEN rank (rank 8), so index 0 is a8 and 63 is h1.
        [64]    side to move: 1 if the FEN says 'w', else 0.
        [65]    castling-rights bitmask (K=1, Q=2, k=4, q=8; 0 for '-').
        [66]    en-passant target: 0 for '-'; file id 1-8 (a-h), plus 8 when
                the target square is on rank 6.
        [67]    halfmove clock (fifty-move counter).
        [68]    fullmove number.

    NOTE(review): encode_move numbers squares with a1 = 0, whereas this vector
    puts a8 first; confirm the differing square order is intended.

    Args:
        fen: A full six-field FEN string, e.g. the output of ``board.fen()``.

    Returns:
        A 1-D NumPy integer array of length 69.

    Raises:
        ValueError: if ``fen`` does not have exactly six space-separated fields
            or its move counters are not integers.
        KeyError: on an unknown piece, castling or en-passant character.
        IndexError: on some malformed board fields (e.g. more than 8 ranks).
    """
    board_part, active_color, castling_rights, en_passant, fifty_move, full_move = fen.split(' ')

    board_array = np.zeros((8, 8), dtype=int)
    for i, row in enumerate(board_part.split('/')):
        col = 0
        for char in row:
            if char.isdigit():
                # A digit is a run of that many empty squares.
                col += int(char)
            else:
                board_array[i, col] = _PIECE_TO_ID[char]
                col += 1

    active_color_id = 1 if active_color == 'w' else 0

    if castling_rights == '-':
        castling_id = 0
    else:
        castling_id = sum(_CASTLING_TO_ID[c] for c in castling_rights)

    if en_passant == '-':
        en_passant_id = 0
    else:
        en_passant_id = _EN_PASSANT_FILE_TO_ID[en_passant[0]]
        # Offset rank-6 targets so they stay distinct from rank-3 targets on
        # the same file.
        if en_passant[1] == '6':
            en_passant_id += 8

    # FEN move counters are strings; they must be converted, otherwise mixing
    # them with ints makes np.concatenate produce a string array.
    return np.concatenate([
        board_array.flatten(),
        [active_color_id, castling_id, en_passant_id, int(fifty_move), int(full_move)],
    ])

def encode_move(uci: str):
    """Encode a UCI move string as a multi-hot float16 torch tensor of length 4100.

    Layout (64 squares * 64 squares + 4 promotion types = 4096 + 4 = 4100):
        [0:4096)     a single 1 at ``from_square * 64 + to_square``, where a
                     square is ``rank * 8 + file`` (a1 = 0, h1 = 7, h8 = 63).
        [4096:4100)  promotion piece q, r, b, n — set only when ``uci`` has
                     exactly five characters.

    Promotion moves therefore carry two ones; every other move carries one.

    Raises:
        KeyError: on an unknown file, rank or promotion character.
        IndexError: if ``uci`` is shorter than four characters.
    """
    files = {letter: idx for idx, letter in enumerate("abcdefgh")}
    ranks = {digit: idx for idx, digit in enumerate("12345678")}
    promotion_slot = {'q': 0, 'r': 1, 'b': 2, 'n': 3}

    # Lookups happen in the order from-file, from-rank, to-file, to-rank.
    origin = files[uci[0]] + 8 * ranks[uci[1]]
    target = files[uci[2]] + 8 * ranks[uci[3]]

    square_pairs = 64 * 64
    encoded = torch.zeros(square_pairs + len(promotion_slot), dtype=torch.float16)
    encoded[origin * 64 + target] = 1

    if len(uci) == 5:
        encoded[square_pairs + promotion_slot[uci[4]]] = 1

    return encoded
# 64 squares * 64 squares + 4 promotion types = 4096 + 4 = 4100