In [1]:
import os
import pickle
import random
import shutil
import time

import chess
import chess.engine
import IPython.display as vis
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from tqdm import tqdm

import dataset
import dataset
import model
import model as model_lib  # alias that keeps working after `model` is rebound to the GPT instance
import trainer
import utils

In [2]:
# Run on the current CUDA device index when a GPU is available, else on the CPU.
if torch.cuda.is_available():
    device = torch.cuda.current_device()
else:
    device = 'cpu'

In [3]:
# Load the fine-tuned checkpoint: besides the weights it carries the model
# config and the vocabulary maps (itos / stoi) used to encode game strings.
# NOTE(review): torch.load unpickles arbitrary Python objects - only load trusted
# files. On newer torch versions where torch.load defaults to weights_only=True,
# the pickled `model_config` object presumably needs weights_only=False - confirm.
CKPT_PATH = 'ckpts/finetune_default/iter_64000.pt'

ckpt = torch.load(CKPT_PATH, map_location=torch.device(device))
model_config = ckpt['model_config']
itos = ckpt['itos']
stoi = ckpt['stoi']

# build model config
# `model_lib` is the `model` module; going through the alias keeps this cell
# re-runnable after `model` below has been rebound to the GPT instance.
mconf = model_lib.GPTConfig(
    vocab_size=len(itos),
    args_dict=model_config.__dict__,
)

# load model weights and switch to inference mode (turns off train-only
# behaviour such as dropout, if the GPT uses it)
model = model_lib.GPT(mconf).to(device)
model.eval()
model.load_state_dict(ckpt['state_dict'])

Number of parameters: 2452736


<All keys matched successfully>

In [4]:
# Single separator character (U+2047, DOUBLE QUESTION MARK) appended to the game
# string to ask the model for the next move.
MASK_CHAR = "\u2047"

In [5]:
def get_prediction(game_str, sample=False, steps=10):
    """Ask the model for the move that follows ``game_str``.

    ``MASK_CHAR`` is appended to the game string, the encoded prompt is passed to
    ``utils.sample``, and the text between the mask character and the next space
    in the decoded completion is returned. The move is NOT checked for legality.

    Args:
        game_str: Moves played so far, space separated. Every character must be
            a key of ``stoi`` (KeyError otherwise).
        sample: Forwarded to ``utils.sample``; callers pass True when
            re-sampling after an illegal prediction.
        steps: Number of tokens to generate (default 10, the previous
            hard-coded value).

    Returns:
        The predicted move string (possibly empty or illegal).
    """
    prompt = game_str + MASK_CHAR
    x = torch.tensor([stoi[ch] for ch in prompt], dtype=torch.long)[None, ...].to(device)

    # Inference only; harmless if utils.sample already disables gradients.
    with torch.no_grad():
        pred = utils.sample(model, x, steps, sample=sample)[0]

    # presumably the completion echoes the prompt, so the move follows MASK_CHAR
    completion = ''.join(itos[int(i)] for i in pred)
    return completion.split(MASK_CHAR)[1].split(' ')[0]

In [6]:
# Stockfish backs up the bot: it supplies a move whenever the bot and its
# re-samples only produce illegal moves (those cases are counted so they can be
# reduced in future training).
# Locate the binary via $STOCKFISH_PATH, then $PATH, and only then fall back to
# the previously hard-coded location.
STOCKFISH_PATH = (
    os.environ.get("STOCKFISH_PATH")
    or shutil.which("stockfish")
    or "/usr/local/bin/stockfish"
)
engine = chess.engine.SimpleEngine.popen_uci(STOCKFISH_PATH)

In [7]:
def bot_vs_human(starting_pgn='', display_board=True):
    """Play an interactive game: a human typing SAN moves against the bot.

    After ``starting_pgn`` is replayed the human moves first. Each bot move comes
    from ``get_prediction``; if it is illegal the model is re-sampled up to 5
    times and, failing that, Stockfish (``engine``) supplies the move. Type
    ``resign`` at the prompt to stop the game.

    Args:
        starting_pgn: Space-separated SAN moves to play before the game starts.
        display_board: When False, board and status messages are not displayed.

    Returns:
        Tuple ``(game_str, illegal_moves, first_bad_move, final_illegal, winner)``:
        ``game_str`` - space-separated canonical SAN of the whole game;
        ``illegal_moves`` - one 0/1 flag per bot move, 1 if the first prediction was illegal;
        ``first_bad_move`` - 1-based index of the first illegal bot move, or -1;
        ``final_illegal`` - number of bot moves Stockfish had to substitute;
        ``winner`` - "BOT", "PLAYER", or None (resignation, draw, or game string too long).
    """
    winner = None
    final_illegal = 0
    first_bad_move = -1
    illegal_moves = []
    bot_move_count = 0
    board = chess.Board()

    def show(*messages):
        """Redraw the board followed by any status messages (no-op when display is off)."""
        if display_board:
            vis.clear_output()
            vis.display(board, *messages)

    def push_canonical(san):
        """Play ``san`` on ``board`` and return it re-rendered as canonical SAN.

        Raises ValueError (python-chess move errors subclass it) when ``san`` is
        not a legal move; null moves are rejected as well.
        """
        move = board.parse_san(san)
        if move not in board.legal_moves:  # parse_san may return a null move
            raise ValueError(f"not a legal move: {san!r}")
        canonical = board.san(move)
        board.push(move)
        return canonical

    # Replay the opening and rebuild game_str from canonical SAN so appended moves
    # stay space-separated even if starting_pgn has no trailing space.
    game_str = ''
    for token in starting_pgn.split():
        game_str += push_canonical(token) + ' '

    if display_board:
        vis.display(board)

    while not board.is_game_over():
        # --- human turn ---
        user_move = input('Enter move: ').strip()
        if user_move == "resign":
            break
        try:
            game_str += push_canonical(user_move) + ' '
        except ValueError:
            # Re-prompt on a typo or illegal move instead of aborting the game.
            print(f"Illegal or unparseable move: {user_move!r}")
            continue

        if board.is_checkmate():
            winner = "PLAYER"
            show("CHECKMATE, PLAYER WINS")
            break
        show()
        if board.is_game_over():
            break  # game over without checkmate (e.g. stalemate): a draw, no winner

        # --- bot turn ---
        # handle cases where game str is larger than block size; 504 presumably
        # leaves room for MASK_CHAR plus the generated tokens - TODO confirm
        if len(game_str) >= 504:
            show("ALERT: Game string too long.  ChEPT resigns.", "CHECKMATE, CHePT RESIGNS")
            break

        bot_move_count += 1
        prediction = get_prediction(game_str)
        try:
            bot_move = push_canonical(prediction)
            illegal_moves.append(0)
        except ValueError:
            if first_bad_move == -1:
                first_bad_move = bot_move_count
            illegal_moves.append(1)
            if display_board:
                vis.display("ALERT ALERT ALERT: Bot move was illegal.  Computer move substituted.")

            # Re-sample the model up to 5 times before falling back to Stockfish.
            bot_move = None
            for _ in range(5):
                candidate = get_prediction(game_str, sample=True)
                try:
                    bot_move = push_canonical(candidate)
                    break
                except ValueError:
                    continue

            if bot_move is None:
                final_illegal += 1
                result = engine.play(board, chess.engine.Limit(time=0.05))
                bot_move = board.san(result.move)
                board.push(result.move)

        # Record the move before the mate check so a winning move is not lost.
        game_str += bot_move + ' '
        if board.is_checkmate():
            winner = "BOT"
            show("CHECKMATE, BOT WINS")
            break
        show()

    return (game_str, illegal_moves, first_bad_move, final_illegal, winner)

In [8]:
def bot_vs_stockfish(starting_pgn='', display_board=True):
    """Play one game of Stockfish (``engine``) against the bot.

    After ``starting_pgn`` is replayed Stockfish moves first, using 0.05 s per
    move. Each bot move comes from ``get_prediction``; if it is illegal the model
    is re-sampled up to 5 times and, failing that, Stockfish supplies the bot's
    move too.

    Args:
        starting_pgn: Space-separated SAN moves to play before the game starts.
        display_board: When False, board and status messages are not displayed.

    Returns:
        Tuple ``(game_str, illegal_moves, first_bad_move, final_illegal, winner)``:
        ``game_str`` - space-separated canonical SAN of the whole game;
        ``illegal_moves`` - one 0/1 flag per bot move, 1 if the first prediction was illegal;
        ``first_bad_move`` - 1-based index of the first illegal bot move, or -1;
        ``final_illegal`` - number of bot moves Stockfish had to substitute;
        ``winner`` - "STOCKFISH", "BOT", or None (draw or game string too long).
    """
    winner = None
    final_illegal = 0
    first_bad_move = -1
    illegal_moves = []
    bot_move_count = 0
    board = chess.Board()

    def show(*messages):
        """Redraw the board followed by any status messages (no-op when display is off)."""
        if display_board:
            vis.clear_output()
            vis.display(board, *messages)

    def push_canonical(san):
        """Play ``san`` on ``board`` and return it re-rendered as canonical SAN.

        Raises ValueError (python-chess move errors subclass it) when ``san`` is
        not a legal move; null moves are rejected as well.
        """
        move = board.parse_san(san)
        if move not in board.legal_moves:  # parse_san may return a null move
            raise ValueError(f"not a legal move: {san!r}")
        canonical = board.san(move)
        board.push(move)
        return canonical

    # Replay the opening and rebuild game_str from canonical SAN so appended moves
    # stay space-separated even if starting_pgn has no trailing space.
    game_str = ''
    for token in starting_pgn.split():
        game_str += push_canonical(token) + ' '

    if display_board:
        vis.display(board)

    while not board.is_game_over():
        # --- Stockfish turn ---
        comp_move = engine.play(board, chess.engine.Limit(time=0.05))
        game_str += board.san(comp_move.move) + ' '
        board.push(comp_move.move)

        if board.is_checkmate():
            winner = "STOCKFISH"
            show("CHECKMATE, STOCKFISH WINS")
            break
        show()
        if board.is_game_over():
            break  # game over without checkmate (e.g. stalemate): a draw, no winner

        # --- bot turn ---
        # handle cases where game str is larger than block size; 504 presumably
        # leaves room for MASK_CHAR plus the generated tokens - TODO confirm
        if len(game_str) >= 504:
            show("ALERT: Game string too long.  ChEPT resigns.", "CHECKMATE, CHePT RESIGNS")
            break

        bot_move_count += 1
        prediction = get_prediction(game_str)
        try:
            bot_move = push_canonical(prediction)
            illegal_moves.append(0)
        except ValueError:
            if first_bad_move == -1:
                first_bad_move = bot_move_count
            illegal_moves.append(1)
            if display_board:
                vis.display("ALERT ALERT ALERT: Bot move was illegal.  Computer move substituted.")

            # Re-sample the model up to 5 times before falling back to Stockfish.
            bot_move = None
            for _ in range(5):
                candidate = get_prediction(game_str, sample=True)
                try:
                    bot_move = push_canonical(candidate)
                    break
                except ValueError:
                    continue

            if bot_move is None:
                final_illegal += 1
                result = engine.play(board, chess.engine.Limit(time=0.05))
                bot_move = board.san(result.move)
                board.push(result.move)

        # Record the move before the mate check so a winning move is not lost.
        game_str += bot_move + ' '
        if board.is_checkmate():
            winner = "BOT"
            show("CHECKMATE, BOT WINS")
            break
        show()

    return (game_str, illegal_moves, first_bad_move, final_illegal, winner)

In [9]:
# Per-game statistics from bot-vs-Stockfish games (Stockfish moves first, so the
# bot plays black).
num_illegal_moves = []    # count of illegal first predictions per game
first_illegal_move = []   # bot move index of the first illegal prediction (-1 if none)
total_black_moves = []    # number of black (bot) moves per game
final_illegal_moves = []  # bot moves Stockfish had to substitute per game
winners = []              # winner value returned for each game

# TODO: Evaluate ratio stockfish / (our move) per game & compare models
# TODO: At some point cut the data, reload params, only tune on mid-late game?
for i in tqdm(range(5)):
    result = bot_vs_stockfish(display_board=False)
    game_str, illegal_moves, first_bad_move, final_illegal, winner = result

    black_moves = len(game_str.split()) // 2

    winners.append(winner)
    final_illegal_moves.append(final_illegal)
    total_black_moves.append(black_moves)
    first_illegal_move.append(first_bad_move)
    num_illegal_moves.append(sum(illegal_moves))

100%|██████████| 5/5 [01:36<00:00, 19.24s/it]


In [12]:
# Summarize the bot-vs-Stockfish games collected in the previous cell.
n_games = len(winners)  # not the loop variable leaked from the previous cell
first_illegal = np.array(first_illegal_move)
curated_first_illegal = first_illegal[first_illegal != -1]  # games with >= 1 illegal move

print(f'Analyzed {n_games} games...')
print('On average, ChePT made:')
print(f'\t\t\t{int(np.mean(total_black_moves))} moves per game.')
if curated_first_illegal.size > 0:
    print(f'\t\t\tFirst illegal move on move {int(np.mean(curated_first_illegal))}.')
else:
    # np.mean of an empty array is nan, and int(nan) raises ValueError.
    print('\t\t\tNo illegal moves.')
print(f'\t\t\t{int(np.mean(num_illegal_moves))} illegal moves per game.')
print(f'\t\t\t{int(np.mean(final_illegal_moves))} final illegal moves per game.')

print('')
# mean over games of (illegal predictions / bot moves), as a percentage
percent = np.round(np.mean(np.array(num_illegal_moves) / np.array(total_black_moves)) * 100, 3)
print(f'ChePT makes an illegal move {percent}% of the time')

n_bot_wins = winners.count('BOT')

print(f'\nChePT managed to win {n_bot_wins} games.')

Analyzed 5 games...
On average, ChePT made:
			24 moves per game.
			First illegal move on move 13.
			3 illegal moves per game.
			0 final illegal moves per game.

ChePT makes an illegal move 15.449% of the time

ChePT managed to win 0 games.


In [None]:
# Play one interactive game and report the bot's illegal-move statistics.
game_str, illegal_moves, first_bad_move, final_illegal, winner = bot_vs_human()

# With no starting moves the human plays white (moves first), the bot black.
n_plies = len(game_str.split())
black_moves = n_plies // 2
white_moves = (n_plies + 1) // 2

vis.display(game_str)
vis.display(illegal_moves)
vis.display("Num Black Moves:", black_moves)
vis.display("Num White Moves:", white_moves)
vis.display("First Illegal Move:", first_bad_move)
vis.display("Total Illegal Moves:", sum(illegal_moves))

In [None]:
# Show the full move record of the last game as plain text.
print(f"{game_str}")