In [50]:
import chess
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.nn import functional as F
import random
import dataset
import model
import trainer
import utils
import dataset
import pickle
import IPython.display as vis
import chess.engine

In [51]:
# Hyper-parameters for the character-level GPT; presumably these must agree
# with the checkpoint loaded further down.
vocab_size = 30    # number of distinct characters in the move vocabulary
block_size = 512   # context length (in characters) the model is configured with

# CUDA device index when a GPU is available, otherwise the string 'cpu'.
# NOTE(review): none of the cells shown here use `device`; the model and its
# inputs stay on the CPU.
if torch.cuda.is_available():
    device = torch.cuda.current_device()
else:
    device = 'cpu'

In [52]:
# Read the cleaned KingBase game corpus. The context manager closes the file
# handle (a bare open().read() leaves it to the garbage collector). The explicit
# encoding makes decoding independent of the platform's default locale.
with open('data/datasets-cleaned/kingbase_cleaned.txt', encoding='utf-8') as games_file:
    games = games_file.read()
pretrain_dataset = dataset.PretrainDataset(games, block_size=block_size)

# Inspect the character vocabulary the dataset built from the corpus.
print(len(pretrain_dataset.stoi))
print(len(pretrain_dataset.itos))
print(pretrain_dataset.stoi)

Data has 592386175 characters, 30 unique.
30
30
{'□': 0, '\n': 1, ' ': 2, '#': 3, '+': 4, '-': 5, '1': 6, '2': 7, '3': 8, '4': 9, '5': 10, '6': 11, '7': 12, '8': 13, '=': 14, 'B': 15, 'K': 16, 'N': 17, 'O': 18, 'Q': 19, 'R': 20, 'a': 21, 'b': 22, 'c': 23, 'd': 24, 'e': 25, 'f': 26, 'g': 27, 'h': 28, 'x': 29}


In [53]:
# Import the model module under an alias. The name `model` is rebound to the GPT
# instance below, because get_prediction calls the global `model`. Without the
# alias, re-running this cell would look up `GPTConfig` on that instance and fail.
import model as gpt_module

# build model config
mconf = gpt_module.GPTConfig(
    vocab_size=vocab_size,
    block_size=block_size,
    n_layer=12,
    n_head=32,
    n_embd=128,
)

# Load the model weights, mapped onto the CPU. The input tensors built in
# get_prediction are also created without a device, i.e. on the CPU.
model = gpt_module.GPT(mconf)
model.load_state_dict(torch.load('ckpt/model.iter.params', map_location=torch.device('cpu')))

Number of parameters: 2452736


<All keys matched successfully>

In [54]:
# load dataset
with open('cache/stoi.pkl', 'rb') as f: 
    stoi = pickle.load(f)
    print(len(stoi))
with open('cache/itos.pkl', 'rb') as f:
    itos = pickle.load(f)
    print(len(itos))

30
30


In [55]:
def get_prediction(game_str, gpt=None, stoi_map=None, itos_map=None, max_context=None):
    """Greedily predict the next character of a game written as SAN text.

    The text is encoded with the character vocabulary and cropped to its last
    ``max_context`` characters. It is then run through the model, and the
    arg-max of the logits at the final position is decoded back to a character.

    Args:
        game_str: non-empty game text so far, e.g. ``"e4 e5 Nf3 "``.
        gpt: model to query; defaults to the notebook-global ``model``.
        stoi_map: char -> index map; defaults to the global ``stoi``.
        itos_map: index -> char map; defaults to the global ``itos``.
        max_context: longest suffix of ``game_str`` fed to the model; defaults
            to the global ``block_size``.

    Returns:
        The single most likely next character.

    Raises:
        ValueError: if ``game_str`` is empty.
        KeyError: if ``game_str`` contains a character missing from a dict
            ``stoi_map``.
    """
    if not game_str:
        raise ValueError("game_str must contain at least one character")
    gpt = model if gpt is None else gpt
    stoi_map = stoi if stoi_map is None else stoi_map
    itos_map = itos if itos_map is None else itos_map
    max_context = block_size if max_context is None else max_context

    # The model is configured with `block_size` positions, so longer inputs
    # presumably cannot be processed. Keep only the most recent characters;
    # those are what the next character depends on most.
    context = game_str[-max_context:]
    x = torch.tensor([stoi_map[ch] for ch in context], dtype=torch.long).view(1, -1)

    gpt.eval()
    with torch.no_grad():
        logits, _ = gpt(x)

    # Only the last position predicts the next character. Indexing it directly
    # (rather than squeezing the whole tensor) also works for a 1-character
    # context, where squeeze would collapse the time axis as well.
    next_idx = torch.argmax(logits[0, -1], dim=-1).item()
    return itos_map[next_idx]

In [64]:
def bot_vs_human(starting_pgn='', engine_path="/usr/local/bin/stockfish",
                 fallback_time=5, max_move_chars=10):
    """Play an interactive game: a human typing SAN moves against the GPT bot.

    The moves in ``starting_pgn`` are replayed first. After that the human and
    the bot alternate, with the human moving first. The bot writes its move one
    character at a time with ``get_prediction`` until it emits a space. If that
    text is not a legal move, an engine move is substituted and the substitution
    is logged. Enter ``resign`` at the prompt to stop early. Play also stops
    when ``board.is_game_over()`` (checkmate, stalemate, automatic draws).

    Args:
        starting_pgn: space-separated SAN moves (no move numbers) to replay
            before play starts.
        engine_path: path to the UCI engine binary used for substitute moves.
        fallback_time: seconds the engine may think for a substituted move.
        max_move_chars: cap on characters generated for one bot move. Output
            that does not reach a space within this many characters counts as
            illegal, so a model that never emits a space cannot hang the game.
            The longest SAN moves (e.g. "exd8=Q+") are 7 characters plus the space.

    Returns:
        ``(game_str, illegal_moves, first_bad_move)``:
        - ``game_str``: the game as space-separated SAN.
        - ``illegal_moves``: 1 for each bot move that was replaced by the
          engine, 0 otherwise.
        - ``first_bad_move``: 1-based index of the first replaced bot move,
          or -1 if none was replaced.
    """
    board = chess.Board()
    opening = starting_pgn.split()
    for san in opening:
        board.push_san(san)
    # Rebuild as "m1 m2 ... " so later moves are always space-separated, even
    # when starting_pgn has no trailing space.
    game_str = ''.join(san + ' ' for san in opening)

    illegal_moves = []
    first_bad_move = -1
    last_mover = None

    def _show(message=None):
        """Redraw the board, optionally followed by a status message."""
        vis.clear_output()
        vis.display(board)
        if message is not None:
            vis.display(message)

    def _bot_reply(engine, context):
        """Push the bot's reply onto `board`; return (SAN + ' ', was_substituted)."""
        text = ''
        while not text.endswith(' ') and len(text) < max_move_chars:
            text += get_prediction(context + text)
        move = None
        if text.endswith(' '):
            try:
                parsed = board.parse_san(text[:-1])
            except ValueError:  # illegal, ambiguous or malformed SAN
                parsed = None
            # parse_san can return a null move (e.g. for "--"); that is not a real reply.
            if parsed is not None and parsed != chess.Move.null():
                move = parsed
        substituted = move is None
        if substituted:
            vis.display("ALERT ALERT ALERT: Bot move was illegal.  Computer move substituted.")
            move = engine.play(board, chess.engine.Limit(time=fallback_time)).move
        # Record canonical SAN (including +/#) so the transcript matches the training format.
        san = board.san(move)
        board.push(move)
        return san + ' ', substituted

    # The context manager guarantees the engine subprocess is shut down,
    # even if the game loop raises.
    with chess.engine.SimpleEngine.popen_uci(engine_path) as engine:
        vis.display(board)
        while not board.is_game_over():
            user_move = input('Enter move: ').strip()
            if user_move == "resign":
                break
            try:
                move = board.parse_san(user_move)
            except ValueError:
                vis.display(f"Illegal or unreadable move {user_move!r}; try again.")
                continue
            game_str += board.san(move) + ' '
            board.push(move)
            last_mover = "PLAYER"
            _show()
            if board.is_game_over():
                break

            bot_move, substituted = _bot_reply(engine, game_str)
            game_str += bot_move
            illegal_moves.append(int(substituted))
            if substituted and first_bad_move == -1:
                first_bad_move = len(illegal_moves)
            last_mover = "BOT"
            _show()

    if board.is_checkmate() and last_mover is not None:
        _show(f"CHECKMATE, {last_mover} WINS")
    elif board.is_game_over():
        _show(f"GAME OVER: {board.result()}")

    return (game_str, illegal_moves, first_bad_move)

In [65]:
def bot_vs_stockfish(starting_pgn='', engine_path="/usr/local/bin/stockfish",
                     stockfish_time=0.5, fallback_time=0.1, max_move_chars=10):
    """Play the GPT bot against Stockfish, with Stockfish moving first.

    The moves in ``starting_pgn`` are replayed first. After that Stockfish and
    the bot alternate until ``board.is_game_over()`` (checkmate, stalemate,
    automatic draws). The bot writes its move one character at a time with
    ``get_prediction`` until it emits a space. If that text is not a legal move,
    an engine move is substituted and the substitution is logged.

    Args:
        starting_pgn: space-separated SAN moves (no move numbers) to replay
            before play starts.
        engine_path: path to the UCI engine binary.
        stockfish_time: seconds the engine thinks for each of its own moves.
        fallback_time: seconds the engine may think for a substituted bot move.
        max_move_chars: cap on characters generated for one bot move. Output
            that does not reach a space within this many characters counts as
            illegal, so a model that never emits a space cannot hang the game.
            The longest SAN moves (e.g. "exd8=Q+") are 7 characters plus the space.

    Returns:
        ``(game_str, illegal_moves, first_bad_move)``:
        - ``game_str``: the game as space-separated SAN.
        - ``illegal_moves``: 1 for each bot move that was replaced by the
          engine, 0 otherwise.
        - ``first_bad_move``: 1-based index of the first replaced bot move,
          or -1 if none was replaced.
    """
    board = chess.Board()
    opening = starting_pgn.split()
    for san in opening:
        board.push_san(san)
    # Rebuild as "m1 m2 ... " so later moves are always space-separated, even
    # when starting_pgn has no trailing space.
    game_str = ''.join(san + ' ' for san in opening)

    illegal_moves = []
    first_bad_move = -1
    last_mover = None

    def _show(message=None):
        """Redraw the board, optionally followed by a status message."""
        vis.clear_output()
        vis.display(board)
        if message is not None:
            vis.display(message)

    def _bot_reply(engine, context):
        """Push the bot's reply onto `board`; return (SAN + ' ', was_substituted)."""
        text = ''
        while not text.endswith(' ') and len(text) < max_move_chars:
            text += get_prediction(context + text)
        move = None
        if text.endswith(' '):
            try:
                parsed = board.parse_san(text[:-1])
            except ValueError:  # illegal, ambiguous or malformed SAN
                parsed = None
            # parse_san can return a null move (e.g. for "--"); that is not a real reply.
            if parsed is not None and parsed != chess.Move.null():
                move = parsed
        substituted = move is None
        if substituted:
            vis.display("ALERT ALERT ALERT: Bot move was illegal.  Computer move substituted.")
            move = engine.play(board, chess.engine.Limit(time=fallback_time)).move
        # Record canonical SAN (including +/#) so the transcript matches the training format.
        san = board.san(move)
        board.push(move)
        return san + ' ', substituted

    # The context manager guarantees the engine subprocess is shut down,
    # even if the game loop raises.
    with chess.engine.SimpleEngine.popen_uci(engine_path) as engine:
        vis.display(board)
        while not board.is_game_over():
            # Stockfish's turn.
            comp_move = engine.play(board, chess.engine.Limit(time=stockfish_time)).move
            game_str += board.san(comp_move) + ' '
            board.push(comp_move)
            last_mover = "STOCKFISH"
            _show()
            if board.is_game_over():
                break

            # Bot's turn.
            bot_move, substituted = _bot_reply(engine, game_str)
            game_str += bot_move
            illegal_moves.append(int(substituted))
            if substituted and first_bad_move == -1:
                first_bad_move = len(illegal_moves)
            last_mover = "BOT"
            _show()

    if board.is_checkmate() and last_mover is not None:
        _show(f"CHECKMATE, {last_mover} WINS")
    elif board.is_game_over():
        _show(f"GAME OVER: {board.result()}")

    return (game_str, illegal_moves, first_bad_move)

# Play one game of the GPT bot against Stockfish, then summarize it.
game_str, illegal_moves, first_bad_move = bot_vs_stockfish()
vis.display(game_str)
vis.display(illegal_moves)

# game_str is space-separated SAN with no move numbers, and the default game
# starts from the initial position. Of n half-moves, White therefore played
# ceil(n/2) and Black floor(n/2).
num_half_moves = len(game_str.split())
black_moves = num_half_moves // 2
white_moves = (num_half_moves + 1) // 2

for label, value in (
    ("Num Black Moves:", black_moves),
    ("Num White Moves:", white_moves),
    ("First Illegal Move:", first_bad_move),
    ("Total illegal Moves:", sum(illegal_moves)),
):
    vis.display(label, value)