In [1]:
# Enable IPython's autoreload extension in mode 2 so that imported modules are
# re-imported when their source changes, without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [2]:
from datasets import load_from_disk
from lightning_training import *
from data_process.tokenizers import FullMoveTokenizerNoEOS

from nanoGPT.model import GPTConfig
from lightning_training import *


In [37]:
def load_model_from_checkpoint(checkpoint_path, config):
    """Restore a LightningGPT checkpoint and hand back its inner model for inference.

    Args:
        checkpoint_path: Path of the checkpoint file passed to
            ``LightningGPT.load_from_checkpoint``.
        config: Forwarded to ``load_from_checkpoint`` as the ``config`` keyword.

    Returns:
        The ``.model`` attribute of the restored Lightning module, after
        ``.cpu()`` and ``.eval()`` have been called on it.
    """
    lightning_module = LightningGPT.load_from_checkpoint(checkpoint_path, config=config)
    # Keep the object returned by .cpu() as the handle that is put in eval mode and returned.
    gpt = lightning_module.model.cpu()
    gpt.eval()
    return gpt

In [22]:
def cut_game(row):
    """Split a game's move string into a list and optionally truncate it.

    Args:
        row: Mapping-like record with a space-separated move string under
            ``'piece_uci'`` and an integer cut index under ``'ply_30s'``.

    Returns:
        The full list of moves when ``row['ply_30s']`` is ``-1``; otherwise
        the first ``row['ply_30s']`` moves.
    """
    moves = row['piece_uci'].split(' ')
    ply_limit = row['ply_30s']
    # -1 is the sentinel for "do not cut".
    return moves if ply_limit == -1 else moves[:ply_limit]

In [42]:
# Tokenizer instance; the length of its vocabulary sets the model's vocab_size below.
tokenizer = FullMoveTokenizerNoEOS()

# Architecture hyperparameters shared by every load_model_from_checkpoint call
# in this notebook.
# NOTE(review): presumably these must match the values used when the
# checkpoints were trained — confirm against the training configuration.
model_config = GPTConfig(
    block_size=301,
    vocab_size=len(tokenizer.vocab),
    n_layer=4,
    n_head=4,
    n_embd=256,
    bias=False,
)

In [40]:
# Map each Elo bin (1100, 1200, ..., 1900) to the list of games found in that
# bin's on-disk "test" directory, with every row passed through cut_game.
games_bins = {}

for elo in range(1100, 2000, 100):
    split_dir = f"./data/huggingface_datasets/elo_bins/split/elo_{elo}/test"
    test_frame = load_from_disk(split_dir).to_pandas()
    # Row-wise apply yields one move list per game; materialise it as a plain list.
    games_bins[elo] = list(test_frame.apply(cut_game, axis=1))

In [65]:
from playing.agents import Agent, GPTNocheckAgent
import chess
from data_process.vocabulary import PieceMove
from playing.testing import is_legal_move
from tqdm import tqdm


def legal_accuracy(agent: Agent, games):
    """Measure how often the agent proposes a legal move along recorded games.

    Every game is replayed from a fresh ``chess.Board()``. Before each recorded
    move, the agent is asked for a move on the current board; that proposal is
    only checked with ``is_legal_move`` and is never played. The board is then
    advanced with the recorded move.

    Args:
        agent: Object exposing ``play(board)``.
        games: Iterable of games, each an iterable of move strings. The first
            character of every move string is dropped before the remainder is
            pushed as UCI (presumably a piece letter, e.g. 'Pe2e4' — confirm
            against the dataset format).

    Returns:
        A tuple ``(accuracy, illegal_moves)``. ``accuracy`` is
        ``legal_moves / total_moves``, or ``nan`` when no move was evaluated
        (empty ``games`` or only empty games). ``illegal_moves`` is a list of
        ``(fen, agent_move)`` pairs, one per proposal that failed the legality
        check, with the FEN taken before the recorded move was pushed.
    """
    illegal_moves = []
    legal_moves = 0
    total_moves = 0
    for game in tqdm(games):
        board = chess.Board()
        for move in game:
            agent_move = agent.play(board)
            if is_legal_move(board, agent_move):
                legal_moves += 1
            else:
                illegal_moves.append((board.fen(), agent_move))
            total_moves += 1
            # Advance with the recorded move, not the agent's proposal.
            board.push_uci(move[1:])
    # Guard the ratio: with zero evaluated moves the original division raised
    # ZeroDivisionError; report NaN ("undefined") instead.
    accuracy = legal_moves / total_moves if total_moves else float("nan")
    print(f"{legal_moves} / {total_moves} = {accuracy}")
    return accuracy, illegal_moves

In [35]:
# Sanity check: parse 'Pe2e4' and compare the parsed fields with the
# python-chess constants for e2, e4 and pawn. The cell displays the three
# comparison results as a tuple.
pm = PieceMove.from_uci('Pe2e4')
pm.from_square == chess.E2, pm.to_square == chess.E4, pm.piece_type == chess.PAWN

(True, True, True)

In [62]:
# One checkpoint path per Elo bin. All paths share the same checkpoint file
# name; the version_N directory differs from bin to bin.
# checkpoint_elo_1100 is also used on its own to load a single model in a
# separate cell, so these individual names must stay defined.
checkpoint_elo_1100 = "./lightning_logs/rating_bins/elo_1100/version_3/checkpoints/epoch=9-step=937500.ckpt"
checkpoint_elo_1200 = "./lightning_logs/rating_bins/elo_1200/version_0/checkpoints/epoch=9-step=937500.ckpt"
checkpoint_elo_1300 = "./lightning_logs/rating_bins/elo_1300/version_3/checkpoints/epoch=9-step=937500.ckpt"
checkpoint_elo_1400 = "./lightning_logs/rating_bins/elo_1400/version_2/checkpoints/epoch=9-step=937500.ckpt"
checkpoint_elo_1500 = "./lightning_logs/rating_bins/elo_1500/version_1/checkpoints/epoch=9-step=937500.ckpt"
checkpoint_elo_1600 = "./lightning_logs/rating_bins/elo_1600/version_5/checkpoints/epoch=9-step=937500.ckpt"
checkpoint_elo_1700 = "./lightning_logs/rating_bins/elo_1700/version_1/checkpoints/epoch=9-step=937500.ckpt"
checkpoint_elo_1800 = "./lightning_logs/rating_bins/elo_1800/version_0/checkpoints/epoch=9-step=937500.ckpt"
checkpoint_elo_1900 = "./lightning_logs/rating_bins/elo_1900/version_2/checkpoints/epoch=9-step=937500.ckpt"

# Lookup table from Elo bin to checkpoint path, indexed by the evaluation loop.
checkpoints = {
    1100: checkpoint_elo_1100,
    1200: checkpoint_elo_1200,
    1300: checkpoint_elo_1300,
    1400: checkpoint_elo_1400,
    1500: checkpoint_elo_1500,
    1600: checkpoint_elo_1600,
    1700: checkpoint_elo_1700,
    1800: checkpoint_elo_1800,
    1900: checkpoint_elo_1900,
}

In [44]:
# Load the elo_1100 checkpoint on its own and wrap it in a GPTNocheckAgent.
# Note: the per-Elo evaluation loop in this notebook reassigns both `model`
# and `agent`.
model = load_model_from_checkpoint(checkpoint_elo_1100, config=model_config)
agent = GPTNocheckAgent(model)

number of parameters: 4.19M


In [56]:
# Parse a move string with a trailing 'q' (looks like a promotion — confirm
# against PieceMove.from_uci). The result only rebinds `pm`; nothing is displayed.
pm = PieceMove.from_uci('Pe7e8q')

In [67]:
# Per-bin results, appended in ascending Elo order (1100, 1200, ..., 1900):
# accuracies[i] is the legal-move ratio and errors[i] the list of illegal
# proposals returned by legal_accuracy for the i-th bin.
accuracies, errors = [], []

# The keys of `checkpoints` are exactly the Elo bins; sorting them gives the
# same ascending sequence as range(1100, 2000, 100).
for elo in sorted(checkpoints):
    model = load_model_from_checkpoint(checkpoints[elo], config=model_config)
    agent = GPTNocheckAgent(model)
    print(f"elo: {elo}")
    # Only the first 1000 games of each bin are evaluated.
    acc, illegal_moves = legal_accuracy(agent, games_bins[elo][:1000])
    accuracies += [acc]
    errors += [illegal_moves]

number of parameters: 4.19M
elo: 1100


100%|██████████| 1000/1000 [04:17<00:00,  3.89it/s]


55823 / 56422 = 0.989383573783276
number of parameters: 4.19M
elo: 1200


100%|██████████| 1000/1000 [04:14<00:00,  3.94it/s]


56514 / 57187 = 0.9882315910958784
number of parameters: 4.19M
elo: 1300


100%|██████████| 1000/1000 [04:42<00:00,  3.54it/s]


59546 / 60212 = 0.9889390819105827
number of parameters: 4.19M
elo: 1400


100%|██████████| 1000/1000 [04:12<00:00,  3.97it/s]


58834 / 59412 = 0.9902713256581162
number of parameters: 4.19M
elo: 1500


100%|██████████| 1000/1000 [04:23<00:00,  3.80it/s]


61421 / 62224 = 0.9870950115710979
number of parameters: 4.19M
elo: 1600


100%|██████████| 1000/1000 [04:33<00:00,  3.65it/s]


61905 / 62657 = 0.9879981486505898
number of parameters: 4.19M
elo: 1700


100%|██████████| 1000/1000 [04:50<00:00,  3.44it/s]


61912 / 62563 = 0.9895944887553346
number of parameters: 4.19M
elo: 1800


100%|██████████| 1000/1000 [04:41<00:00,  3.55it/s]


62069 / 62821 = 0.9880294805877016
number of parameters: 4.19M
elo: 1900


100%|██████████| 1000/1000 [04:25<00:00,  3.77it/s]

63621 / 64263 = 0.9900098034638906



