In [None]:
import shutil
import sys
from copy import deepcopy
from pprint import pprint
from time import sleep

import chess
import chess.engine
import jax
import jax.numpy as jnp
import numpy as np
from chess import Board, Move
from IPython.display import display, clear_output

from hijax.setup import setup_worker

In [None]:
# make the parent directory importable (presumably where `neural_chess` lives);
# guarded so that re-running this cell does not keep appending duplicate entries
if "../" not in sys.path:
    sys.path.append("../")

In [None]:
from neural_chess import MODULE_NAME
from neural_chess.utils.data import one_hot_to_move, flat_repr_to_board, board_to_flat_repr, get_legal_move_mask

In [None]:
# load stockfish: prefer whichever `stockfish` binary is found on PATH and fall
# back to the original hard-coded location so existing setups keep working.
# NOTE: the engine runs as a subprocess; call `engine.quit()` when finished.
stockfish_path = shutil.which("stockfish") or "/usr/local/bin/stockfish"
engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)

In [None]:
# restore the worker for experiment "test7" from its "best" checkpoint
# (wandb logging disabled); also returns the experiment config
worker_args = dict(
    name="test7",
    module=MODULE_NAME,
    with_wandb=False,
    checkpoint_id="best",
    exp_dir="/Users/angusturner/experiments",
)
worker, cfg = setup_worker(**worker_args)

In [None]:
# take the first batch the test loader yields (random only if the loader
# shuffles) and unpack the fields used below
batch = next(iter(worker.loaders.test))
(
    board_state,
    next_move,
    turn,
    castling_rights,
    elo,
    legal_moves,
    en_passant,
) = (
    batch[field]
    for field in (
        'board_state',
        'next_move',
        'turn',
        'castling_rights',
        'elo',
        'legal_moves',
        'en_passant',
    )
)

In [None]:
# inspect one item of the batch (index chosen by hand)
idx = 2
# ratings appear to be stored divided by 2500, so scale back up for display
item_elo = elo[idx]
print(item_elo * 2500)
# elo[idx] = 0.9  # optional: override the normalised rating fed to the model
board = flat_repr_to_board(board_state[idx], turn[idx])

In [None]:
# replay the move recorded in the dataset on a copy of the position,
# leaving `board` itself untouched, and display the result
actual_move = one_hot_to_move(next_move[idx])
board_next = deepcopy(board)
board_next.push(actual_move)
board_next

In [None]:
# model inference: parameters and PRNG key are read from the loaded worker
params = worker.params
rng = worker.rng_key

@jax.jit
def infer(board_state, turn, castling_rights, en_passant, elo, legal_moves, **_kwargs):
    """
    Run the model forward pass and return a distribution over moves.

    The module-level `params` and `rng` are closed over rather than passed in.
    Any extra keyword arguments (e.g. unused batch fields) are accepted and ignored.

    Positions where `legal_moves` is falsy have their logits replaced by -1e9
    before the softmax, so they receive (effectively) zero probability.

    :return: softmax over the last axis of the masked logits
    """
    raw_logits = worker.forward.apply(
        params, rng, board_state, turn, castling_rights, en_passant, elo,
        is_training=False,
    )
    # large negative filler for every illegal move
    illegal_fill = jnp.full_like(raw_logits, -1e9)
    masked_logits = jnp.where(legal_moves, raw_logits, illegal_fill)
    return jax.nn.softmax(masked_logits, axis=-1)

In [None]:
# run the model on the whole batch, then keep the distribution for item `idx`
batch_probs = np.array(infer(**batch))
move_probs = batch_probs[idx]

In [None]:
def sample_move(move_probs, greedy=False, topk=5, rng=None):
    """
    Choose a move from the `topk` most probable entries of `move_probs`.

    :param move_probs: 1D array of per-move probabilities (one entry per move index).
    :param greedy: if True, return the single most probable move; otherwise sample
        among the top-k candidates in proportion to their re-normalised probabilities.
    :param topk: number of highest-probability candidates to consider (must be >= 1).
    :param rng: optional `np.random.Generator` / `RandomState` used for sampling, for
        reproducible results. Defaults to numpy's global random state.
    :return: `(move, move_preds)` where `move` is the chosen move and `move_preds` is a
        list of `(move, probability)` tuples for the top-k candidates, most probable first.
    :raises ValueError: if there are no candidates (empty input or `topk` < 1), or if the
        candidates carry no probability mass to sample from.
    """
    move_probs = np.asarray(move_probs)
    if topk < 1 or move_probs.size == 0:
        raise ValueError("sample_move needs a non-empty `move_probs` and `topk` >= 1")

    # sort once, in descending order of probability, and keep the k best indices
    best_move_idxs = np.argsort(-move_probs)[:topk]

    move_preds = []
    for move_idx in best_move_idxs:
        # the decoder expects a one-hot vector rather than an index
        one_hot = np.zeros_like(move_probs)
        one_hot[move_idx] = 1
        move_preds.append((one_hot_to_move(one_hot), move_probs[move_idx]))

    if greedy:
        return move_preds[0][0], move_preds

    # re-normalise in float64: float32 rounding can otherwise leave the total far
    # enough from 1 for numpy's sampler to reject the probability vector
    probs = move_probs[best_move_idxs].astype(np.float64)
    total = probs.sum()
    if not total > 0:  # also catches NaN
        raise ValueError("cannot sample a move: the top-k probabilities sum to zero")

    sampler = np.random if rng is None else rng
    chosen = int(sampler.choice(len(probs), p=probs / total))
    return move_preds[chosen][0], move_preds

In [None]:
# sample a move from the model's top candidates (sample_move defaults:
# non-greedy, topk=5) and list those candidates with their probabilities
move, move_preds = sample_move(move_probs)
print(f"Sampled {move}")
pprint(move_preds)

In [None]:
# apply the sampled move to a copy of the position (so `board` stays
# unchanged) and display the resulting board
board_pred = deepcopy(board)
board_pred.push(move)
board_pred

In [None]:
# Compare engine evaluations of: the position before the move, the position after
# the move actually played, and the position after the model's move.
# `info["score"]` is a PovScore relative to the side to move, and the side to move
# flips once a move is pushed, so every score is converted to the point of view of
# the player to move in `board` to make the three numbers directly comparable.
mover = board.turn
analysis_limit = chess.engine.Limit(time=0.1)
for label, position in (
    ("Pre Move Score:", board),
    ("Score after next:", board_next),
    ("Score after predicted:", board_pred),
):
    info = engine.analyse(position, analysis_limit)
    print(label, info["score"].pov(mover))

In [None]:
# get the model to play itself!
def infer_from_board(board: Board, elos = (2000, 2000), greedy = False, topk = 10):
    """
    Encode `board` as a single-item batch, run the model and pick a move.

    :param board: position to move from; the side to move is `board.turn`.
    :param elos: pair of ratings indexed by `int(board.turn)`; python-chess uses
        True for white, so this is `(black_elo, white_elo)`.
    :param greedy: forwarded to `sample_move`; True picks the most probable move.
    :param topk: forwarded to `sample_move`; number of candidates to sample among.
    :return: whatever `sample_move` returns, i.e. `(move, move_preds)`.
    """
    # encode the board
    board_state = board_to_flat_repr(board).astype(np.int32)

    # side to move, its castling rights, and its rating scaled to approx. [0, 1]
    turn = board.turn
    castling_rights = board.has_castling_rights(turn)
    elo = elos[int(turn)] / 2500

    # en-passant target square:
    # - [0, 63] is the square that can be moved to with en-passant
    # - 64 means there are no en-passant rights
    # (explicit None check, so a square index is never confused with "no square")
    en_passant = board.ep_square if board.ep_square is not None else 64

    # legal moves mask!
    legal_moves = get_legal_move_mask(board).astype(bool)

    # wrap everything in arrays with a leading batch dimension of 1
    batch = {
        'board_state': board_state.reshape([1, -1]),
        'turn': np.array([turn]).astype(np.int32),
        'elo': np.array([elo]).astype(np.float32),
        'en_passant': np.array([en_passant]).astype(np.int32),
        'castling_rights': np.array([castling_rights]).astype(np.int32),
        'legal_moves': legal_moves.reshape([1, -1])
    }

    move_probs = np.array(infer(**batch))[0]
    return sample_move(move_probs, greedy=greedy, topk=topk)

In [None]:
# fresh game from the standard starting position (replaces the dataset
# position previously bound to `board`)
board = chess.Board()

In [None]:
# self-play: the model moves for both sides, redrawing the board after each ply
max_plies = 50
display(board)
for ply in range(max_plies):
    # stop on ANY terminal position, not just checkmate: on e.g. stalemate there
    # are no legal moves, the legal-move mask is empty, and the model would end
    # up sampling (and pushing) an illegal move
    if board.is_game_over():
        break
    next_move, _ = infer_from_board(board)
    board.push(next_move)
    sleep(1.0)  # slow the animation down enough to follow
    clear_output(wait=True)
    display(board)