In [None]:
from __future__ import annotations

import os
import random
import time
from abc import abstractmethod
from collections import defaultdict
from math import log, sqrt
from typing import Callable, List

import chess
import chess.pgn
import chess.svg
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model

In [None]:
def board_to_bitfields(board: chess.Board, turn: chess.Color) -> np.ndarray:
    """Return the 12 piece bitboards of ``board`` as an int64 array.

    The first six entries are ``turn``'s pieces and the last six the
    opponent's, each group ordered pawn, knight, bishop, rook, queen, king.
    """
    piece_types = (chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN, chess.KING)
    if turn == chess.WHITE:
        color_order = (chess.WHITE, chess.BLACK)
    else:
        color_order = (chess.BLACK, chess.WHITE)
    masks = [board.pieces_mask(pt, color) for color in color_order for pt in piece_types]
    return np.array(masks).astype(np.int64)


def bitfield_to_nums(bitfield: np.int64, white: bool) -> np.ndarray:
    """Expand a 64-bit piece bitboard into a length-64 float32 occupancy vector.

    Args:
        bitfield: bitboard where bit ``i`` set means square ``i`` is occupied.
            Masks with the top bit set arrive as negative int64 values (the
            caller casts with ``astype(np.int64)``); they are decoded correctly.
        white: sign of occupied entries: ``1.0`` if True, ``-1.0`` otherwise.

    Returns:
        float32 array of shape (64,) holding the sign on occupied squares and
        ``0.0`` elsewhere.
    """
    # Vectorised bit test instead of a 64-iteration Python loop (this runs
    # 12 times per evaluated board). Casting to int64 wraps masks >= 2**63;
    # the arithmetic right shift still leaves the wanted bit in position 0.
    value = np.asarray(bitfield).astype(np.int64)
    occupied = (value >> np.arange(64, dtype=np.int64)) & 1
    sign = 1. if white else -1.
    # np.where (rather than multiplying by the sign) keeps empty squares at +0.0.
    return np.where(occupied != 0, sign, 0.).astype(np.float32)


def bitfields_to_nums(bitfields: np.ndarray) -> np.ndarray:
    """Stack per-bitboard occupancy vectors into one float32 array.

    Rows 0-5 (the ``turn`` player's pieces in board_to_bitfields' ordering)
    are encoded as +1 on occupied squares; all later rows as -1.
    """
    masks = bitfields.astype(np.int64)
    rows = [bitfield_to_nums(mask, idx < 6) for idx, mask in enumerate(masks)]
    return np.array(rows).astype(np.float32)


def board_to_nums(board: chess.Board, turn: chess.Color) -> np.ndarray:
    """Encode ``board`` as a (12, 64) float32 array seen from ``turn``'s side.

    ``turn``'s pieces are marked +1 and the opponent's -1.
    """
    bitfields = board_to_bitfields(board, turn)
    return bitfields_to_nums(bitfields)


In [None]:
# Global statistics: bumped by the evaluators, reset per move by
# mcts_player_with_stats.
total_time = total_evaluations = 0

In [None]:
# Material value of each piece type (pawn = 100). The king's value is far
# larger than all other pieces combined.
piece = {
    chess.PAWN: 100,
    chess.KNIGHT: 280,
    chess.BISHOP: 320,
    chess.ROOK: 479,
    chess.QUEEN: 929,
    chess.KING: 60000,
}
# Piece-square tables: per-square bonuses added to the piece values above.
# Each tuple has 64 entries, written as 8 rows of 8.
# NOTE(review): the values appear to be listed from rank 8 (index 0 = a8)
# down to rank 1 as seen by White, e.g. the pawn row at indices 8-15 holds
# the large near-promotion bonuses. python-chess numbers squares a1 = 0, so
# White squares presumably need chess.square_mirror() before indexing.
# Confirm this before changing the square-to-index mapping.
pst = {
    chess.PAWN: (0,   0,   0,   0,   0,   0,   0,   0,
                 78,  83,  86,  73, 102,  82,  85,  90,
                 7,  29,  21,  44,  40,  31,  44,   7,
                 -17,  16,  -2,  15,  14,   0,  15, -13,
                 -26,   3,  10,   9,   6,   1,   0, -23,
                 -22,   9,   5, -11, -10,  -2,   3, -19,
                 -31,   8,  -7, -37, -36, -14,   3, -31,
                 0,   0,   0,   0,   0,   0,   0,   0),
    chess.KNIGHT: (-66, -53, -75, -75, -10, -55, -58, -70,
                   -3,  -6, 100, -36,   4,  62,  -4, -14,
                   10,  67,   1,  74,  73,  27,  62,  -2,
                   24,  24,  45,  37,  33,  41,  25,  17,
                   -1,   5,  31,  21,  22,  35,   2,   0,
                   -18,  10,  13,  22,  18,  15,  11, -14,
                   -23, -15,   2,   0,   2,   0, -23, -20,
                   -74, -23, -26, -24, -19, -35, -22, -69),
    chess.BISHOP: (-59, -78, -82, -76, -23, -107, -37, -50,
                   -11,  20,  35, -42, -39,  31,   2, -22,
                   -9,  39, -32,  41,  52, -10,  28, -14,
                   25,  17,  20,  34,  26,  25,  15,  10,
                   13,  10,  17,  23,  17,  16,   0,   7,
                   14,  25,  24,  15,   8,  25,  20,  15,
                   19,  20,  11,   6,   7,   6,  20,  16,
                   -7,   2, -15, -12, -14, -15, -10, -10),
    chess.ROOK: (35,  29,  33,   4,  37,  33,  56,  50,
                 55,  29,  56,  67,  55,  62,  34,  60,
                 19,  35,  28,  33,  45,  27,  25,  15,
                 0,   5,  16,  13,  18,  -4,  -9,  -6,
                 -28, -35, -16, -21, -13, -29, -46, -30,
                 -42, -28, -42, -25, -25, -35, -26, -46,
                 -53, -38, -31, -26, -29, -43, -44, -53,
                 -30, -24, -18,   5,  -2, -18, -31, -32),
    chess.QUEEN: (6,   1,  -8, -104,  69,  24,  88,  26,
                  14,  32,  60, -10,  20,  76,  57,  24,
                  -2,  43,  32,  60,  72,  63,  43,   2,
                  1, -16,  22,  17,  25,  20, -13,  -6,
                  -14, -15,  -2,  -5,  -1, -10, -20, -22,
                  -30,  -6, -13, -11, -16, -11, -16, -27,
                  -36, -18,   0, -19, -15, -15, -21, -38,
                  -39, -30, -31, -13, -31, -36, -34, -42),
    chess.KING: (4,  54,  47, -99, -99,  60,  83, -62,
                 -32,  10,  55,  56,  56,  55,  10,   3,
                 -62,  12, -57,  44, -67,  28,  37, -31,
                 -55,  50,  11,  -4, -19,  13,   0, -49,
                 -55, -43, -52, -28, -51, -47,  -8, -50,
                 -47, -42, -43, -79, -64, -32, -29, -32,
                 -4,   3, -14, -50, -57, -18,  13,   4,
                 17,  30,  -3, -14,   6,  -1,  40,  18),
}


def evaluate_position_static(board: chess.Board, turn: chess.Color, mate_score: int = 100_000) -> float:
    """Score ``board`` from ``turn``'s point of view using material + piece-square tables.

    Args:
        board: position to evaluate.
        turn: colour whose point of view the score is expressed from.
        mate_score: magnitude returned for a decided game. The default is
            larger than any reachable material difference, so a win or loss
            always outweighs material.

    Returns:
        ``mate_score`` if ``turn`` has won, ``-mate_score`` if it has lost,
        0 for a drawn game (stalemate, insufficient material, ...), otherwise
        ``turn``'s material + PST total minus the opponent's. The score is
        antisymmetric: ``f(b, WHITE) == -f(b, BLACK)``.

    Also increments the global ``total_evaluations`` counter.
    """
    global total_evaluations
    total_evaluations += 1

    # outcome() is what is_game_over() tests; it also tells us who won.
    outcome = board.outcome()
    if outcome is not None:
        if outcome.winner is None:
            return 0
        return mate_score if outcome.winner == turn else -mate_score

    total = 0
    for color in (chess.WHITE, chess.BLACK):
        # Opponent material counts against `turn`, so captures change the score.
        sign = 1 if color == turn else -1
        for p in (chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN, chess.KING):
            for sq in board.pieces(p, color):
                # The tables are written rank 8 first from White's side, while
                # python-chess numbers a1 = 0. White squares are therefore
                # mirrored vertically; Black squares already line up.
                idx = chess.square_mirror(sq) if color == chess.WHITE else sq
                total += sign * (piece[p] + pst[p][idx])

    return total


In [None]:

# Keras model used by the NN evaluators. The path is relative to the
# notebook's working directory.
MODEL_PATH = 'training/001model184.h5'
model = load_model(MODEL_PATH)

# Flattened board encodings recorded by the NN evaluators; consumed and
# cleared by time_predict_all().
all_predictions = []

def evaluate_position_with_NN(board: chess.Board, turn: chess.Color) -> float:
    """Score ``board`` for ``turn`` with the loaded Keras ``model``.

    Returns:
        The model's first output for this single position, as a Python float.

    Side effects: adds the model-call time to ``total_time``, increments
    ``total_evaluations`` and appends the encoded board to ``all_predictions``.
    """
    global total_evaluations, all_predictions, total_time

    start = time.time()

    board_vector = board_to_nums(board, turn).flatten()
    prediction = model(np.asarray([board_vector]), training=False)

    total_time += time.time() - start
    total_evaluations += 1
    # NOTE(review): this list grows by one entry per call and is only cleared
    # by time_predict_all(), so it can use a lot of memory in long searches.
    all_predictions.append(board_vector)
    # float() unwraps the scalar tensor so callers accumulate plain numbers.
    return float(prediction[0][0])

def time_predict_all(repeat: int = 100):
    """Benchmark one batched model call over all recorded boards, then clear them.

    Args:
        repeat: how many times the recorded boards are tiled into the batch.
    """
    global all_predictions

    if not all_predictions:
        # np.asarray([]) has no feature axis; it presumably would not be a
        # valid model input, so skip the call.
        print("Time to predict all: nothing recorded")
        return

    start = time.time()
    model(np.asarray(all_predictions * repeat), training=False)

    all_predictions = []

    print("Time to predict all", time.time() - start)


In [None]:
class MCTSNode:
    """One node of a Monte Carlo search tree over an ``MCTSState``.

    Each node records its visit count and how many playouts through it
    ended with each ``state.game_result()`` value (+1 / -1 / 0). For
    ChessState these results are from White's point of view.
    """

    def __init__(self, state: "MCTSState", parent: "MCTSNode" = None, parent_action=None):
        self.state = state
        self.parent = parent
        self.parent_action = parent_action  # action that led from parent to here
        self.children = []
        self._number_of_visits = 0
        self._results = defaultdict(int)
        self._results[1] = 0
        self._results[-1] = 0
        # Fetched once; expand() pops from this list.
        self._untried_actions = self.state.get_legal_actions()

    @property
    def untried_actions(self):
        """Actions not yet expanded into children. Reading it does not reset them."""
        return self._untried_actions

    @property
    def q(self):
        """Number of +1 results minus number of -1 results seen through this node."""
        wins = self._results[1]
        loses = self._results[-1]
        return wins - loses

    @property
    def n(self):
        """Number of playouts backpropagated through this node."""
        return self._number_of_visits

    @property
    def is_terminal_node(self):
        return self.state.is_game_over()

    @property
    def is_fully_expanded(self):
        return len(self._untried_actions) == 0

    def _perspective(self) -> int:
        """Return +1 if the player to move here prefers high results, else -1.

        Results are +1 for a White win in ChessState. A node whose
        ``state.turn`` is falsy (chess.BLACK is False) must therefore minimise
        them. States without a ``turn`` attribute keep plain maximisation.
        """
        turn = getattr(self.state, "turn", None)
        return 1 if turn is None or turn else -1

    def expand(self):
        """Create a child for one untried action and return it."""
        action = self._untried_actions.pop()
        next_state = self.state.move(action)
        child_node = MCTSNode(next_state, parent=self, parent_action=action)

        self.children.append(child_node)
        return child_node

    def rollout(self):
        """Play the state out to the end with the states' rollout policy; return the result."""
        current_rollout_state = self.state

        while not current_rollout_state.is_game_over():
            possible_moves = current_rollout_state.get_legal_actions()
            # The candidate moves belong to the *current* state, so it must be
            # the one that chooses among them (not the node's starting state).
            action = current_rollout_state.rollout(possible_moves)
            current_rollout_state = current_rollout_state.move(action)
        return current_rollout_state.game_result()

    def backpropagate(self, result: int):
        """Record ``result`` on this node and every ancestor up to the root."""
        node = self
        while node is not None:
            node._number_of_visits += 1.
            node._results[result] += 1.
            node = node.parent

    def best_child(self, c_param: float = 0.1):
        """Return the child maximising UCB1 for the player to move at this node.

        Raises:
            ValueError: if the node has no children.
        """
        if not self.children:
            raise ValueError("best_child() called on a node with no children")
        sign = self._perspective()
        choices_weights = [
            sign * (c.q / c.n) + c_param * np.sqrt(2 * np.log(self.n) / c.n)
            for c in self.children
        ]
        return self.children[np.argmax(choices_weights)]

    def _tree_policy(self):
        """Descend by best_child until reaching an expandable or terminal node."""
        current_node = self
        while not current_node.is_terminal_node:
            if not current_node.is_fully_expanded:
                return current_node.expand()
            current_node = current_node.best_child()
        return current_node

    def best_action(self, simulation_no: int = 100):
        """Run ``simulation_no`` playouts, then return the best child (no exploration)."""
        for _ in range(simulation_no):
            v = self._tree_policy()
            reward = v.rollout()
            v.backpropagate(reward)

        return self.best_child(c_param=0.)

class MCTSState:
    """Interface that MCTSNode expects from a game state.

    NOTE(review): this class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers are not enforced. The bodies below act as
    fallback defaults.
    """

    @abstractmethod
    def get_legal_actions(self) -> List:
        """Return the list of actions (moves) playable from this state."""
        return []

    @abstractmethod
    def is_game_over(self) -> bool:
        """Return True when the game has ended and no action can be played."""
        return False

    @abstractmethod
    def game_result(self) -> int:
        """Return 1, 0 or -1 for a win, tie or loss respectively."""
        return 0

    @abstractmethod
    def move(self, action: "chess.Move") -> "MCTSState":
        """Return the state reached by playing ``action``.

        MCTSNode keeps using the original state after calling this, so
        implementations should return a new state rather than mutate ``self``.
        The default implementation returns None.
        """
        pass

    @abstractmethod
    def rollout(self, possible_moves: List["chess.Move"]):
        """Choose the action to play during a playout; the default picks uniformly at random."""
        return possible_moves[np.random.randint(len(possible_moves))]

class ChessState(MCTSState):
    """MCTSState adapter around a ``chess.Board``.

    ``turn`` is the colour to move in ``state``: it is set from ``board.turn``
    at the root and flipped on every ``move``.
    """

    def __init__(self, state: chess.Board, turn: chess.Color):
        self.state = state
        self.turn = turn

    def get_legal_actions(self) -> List:
        return list(self.state.legal_moves)

    def is_game_over(self) -> bool:
        return self.state.is_game_over()

    def game_result(self) -> int:
        """Return +1 if White won, -1 if Black won, 0 for a draw or an unfinished game."""
        result = self.state.result()
        if result == '1-0':
            return 1
        if result == '0-1':
            return -1
        return 0

    def move(self, action: chess.Move) -> MCTSState:
        """Return a new ChessState after ``action``; ``self.state`` is left untouched."""
        new_state = self.state.copy()
        new_state.push(action)
        # push() flips the side to move, so this equals `not self.state.turn`.
        return ChessState(new_state, new_state.turn)

    def rollout(self, possible_moves: List[chess.Move]) -> chess.Move:
        # TODO use NN here? The predictions below are computed and logged, but
        # the move is still chosen uniformly at random.
        self._evaluate_position_with_NN(possible_moves)
        return possible_moves[np.random.randint(len(possible_moves))]

    def _evaluate_position_with_NN(self, moves: List[chess.Move]) -> float:
        """Batch-evaluate the position after each of ``moves`` with ``model``.

        Updates the global ``total_time`` / ``total_evaluations`` counters,
        records the encodings in ``all_predictions`` and prints a summary.

        Returns:
            The model's first output for the first move, as a float.
        """
        global total_evaluations, all_predictions, total_time

        start = time.time()

        boards = []
        for move in moves:
            board = self.state.copy()
            board.push(move)
            boards.append(board)

        boards_array = [board_to_nums(board, self.turn).flatten() for board in boards]
        prediction = model(np.asarray(boards_array), training=False)

        elapsed = time.time() - start
        total_time += elapsed
        total_evaluations += len(boards)
        # extend (not append) keeps all_predictions a flat list of 1-D board
        # vectors, the same shape evaluate_position_with_NN records and
        # time_predict_all batches.
        all_predictions.extend(boards_array)
        print(f'Evaluated {len(boards)} boards in {elapsed} seconds, predictions: {prediction}')
        return float(prediction[0][0])

In [None]:
class Node:
    """Search-tree node used by the ``UCT`` function.

    ``visits`` counts the results passed to ``Update`` and ``wins`` sums them.
    """

    def __init__(self, state: "chess.Board", move: "chess.Move" = None, parent=None):
        self.move = move  # move that led from parent to this node (None at the root)
        # NOTE(review): UCT passes its working board here and keeps pushing
        # moves onto it afterwards, so self.state does not stay equal to this
        # node's position. Only the legal-move snapshot below is taken now.
        self.state = state
        self.parent = parent
        self.unexplored_moves = list(self.state.legal_moves)
        self.children = []
        self.visits = 0
        self.wins = 0

    def add_child(self, state, move):
        """Create a child for ``move`` (which must still be unexplored) and return it."""
        child_node = Node(state, move, self)
        self.children.append(child_node)
        self.unexplored_moves.remove(move)
        return child_node

    def UCT_select_child(self):
        """Return the child with the highest UCB1 score; the last such child wins ties.

        Raises:
            ValueError: if there are no children.
        """
        def ucb1(child):
            return child.wins / child.visits + sqrt(2 * log(self.visits) / child.visits)

        # max() scanning the reversed list keeps the tie-breaking of the former
        # sorted(...)[-1] while avoiding a full sort.
        return max(reversed(self.children), key=ucb1)

    def Update(self, result: float):
        """Record one backpropagated evaluation ``result``."""
        self.visits += 1
        self.wins += result


def UCT(rootstate: chess.Board, itermax: int, depthmax: int, evaluation: Callable) -> chess.Move:
    """Pick a move for the side to move in ``rootstate`` with UCT search.

    Args:
        rootstate: position to search from; only copies of it are modified.
        itermax: number of select / expand / rollout / backpropagate iterations.
        depthmax: cap on plies played below the tree per iteration
            (the expansion move plus random rollout moves).
        evaluation: ``evaluation(board, color) -> float``, scoring ``board``
            from ``color``'s point of view.

    Returns:
        The root move whose child was visited most often.

    Raises:
        ValueError: if ``rootstate`` has no legal moves.
    """
    root_player = rootstate.turn
    rootnode = Node(state=rootstate)
    if not rootnode.unexplored_moves:
        raise ValueError("UCT called on a position with no legal moves")

    for _ in range(itermax):
        node = rootnode
        tree_depth = 0  # edges between rootnode and node
        depth = 0  # plies counted against depthmax
        state = rootstate.copy()

        # Select: descend while the node is fully expanded and non-terminal.
        while not node.unexplored_moves and node.children:
            node = node.UCT_select_child()
            state.push(node.move)
            tree_depth += 1

        # Expand one untried move, if the node is non-terminal.
        if node.unexplored_moves:
            m = random.choice(node.unexplored_moves)
            state.push(m)
            node = node.add_child(state, m)
            tree_depth += 1
            depth += 1

        # Rollout: random moves until the game ends or depthmax is reached.
        while depth < depthmax:
            legal = list(state.legal_moves)
            if not legal:
                break
            state.push(random.choice(legal))
            depth += 1

        # Backpropagate. The leaf is scored from the root player's point of
        # view, not state.turn, which flips with the leaf's depth. Each node
        # then stores the score from the POV of the player who just moved into
        # it. Odd-depth nodes were reached by a root-player move and keep the
        # score; even-depth nodes store it negated. UCT_select_child then picks
        # moves that are good for whoever is to move.
        result = evaluation(state, root_player)
        while node is not None:
            node.Update(result if tree_depth % 2 == 1 else -result)
            node = node.parent
            tree_depth -= 1

    # Most-visited child; the reversed scan keeps the former sorted(...)[-1] tie-breaking.
    return max(reversed(rootnode.children), key=lambda c: c.visits).move

def mcts_player(board: chess.Board, evaluation: Callable, itermax: int = 500, depthmax: int = 30) -> chess.Move:
    """Choose a move with MCTS, push it onto ``board`` and return it.

    A move that checkmates immediately is played without searching.

    NOTE(review): the MCTSNode search below ignores ``evaluation``,
    ``itermax`` and ``depthmax``; best_action() runs its default number of
    simulations. They are only used by the commented-out UCT call.
    """
    for move_choice in board.legal_moves:
        copy = board.copy()
        copy.push(move_choice)
        # Only shortcut on checkmate: is_game_over() is also true after
        # stalemate and other draws, which must not be played blindly.
        if copy.is_checkmate():
            board.push(move_choice)
            return move_choice

    root = MCTSNode(ChessState(board, board.turn), None)
    move = root.best_action()
    # move = UCT(board, itermax, depthmax, evaluation)
    board.push(move)
    return move

In [None]:
def mcts_player_with_stats(evaluation: Callable, itermax: int = 500, depthmax: int = 30):
    """Wrap mcts_player into a one-argument player that prints search statistics.

    The returned callable resets the global ``total_evaluations`` and
    ``total_time`` counters, then plays one move on the board it is given. It
    prints mcts_player's return value, the evaluation count, the wall-clock
    time of the move and the time accumulated in ``total_time``.
    """
    def inner(board: chess.Board):
        global total_evaluations, total_time
        total_evaluations, total_time = 0, 0

        t0 = time.time()
        chosen = mcts_player(board, evaluation, itermax, depthmax)
        print("MCTS Player:", chosen)
        print("Total Evaluations:", total_evaluations)
        print("Time:", time.time() - t0)
        print("Total Time:", total_time)
        # Optionally benchmark batched prediction here: time_predict_all()

    return inner

def human_player(board: "chess.Board"):
    """Prompt until the user enters a valid SAN move, then push it onto ``board``.

    Entering ``q`` aborts by raising KeyboardInterrupt.
    """
    while True:
        move = input("Input Your Move:")
        if move == "q":
            raise KeyboardInterrupt
        try:
            board.push_san(move)
            break
        except ValueError as e:
            # python-chess signals bad SAN with ValueError subclasses
            # (invalid / illegal / ambiguous move). Other exceptions are real
            # bugs and now propagate instead of being swallowed.
            print(e)

In [None]:
def play_game(player1, player2, svg_path: str = "game.svg"):
    """Play one game between two players and show the board after every move.

    Args:
        player1: callable that receives the board and pushes White's move.
        player2: same, for Black.
        svg_path: file the current position is written to as SVG after each move.

    The finished game is printed as PGN.
    """
    board = chess.Board()

    while not board.is_game_over():
        player = player1 if board.turn == chess.WHITE else player2
        player(board)

        with open(svg_path, "w", encoding="utf-8") as f:
            f.write(chess.svg.board(board, size=650))
        # os.startfile only exists on Windows. Elsewhere the SVG file is still
        # updated but not opened automatically.
        if hasattr(os, "startfile"):
            os.startfile(svg_path)

        time.sleep(0.1)

    print(chess.pgn.Game.from_board(board))

In [None]:
# To play against the hand-written evaluator instead, use:
# play_game(mcts_player_with_stats(evaluate_position_static, itermax=2000), human_player)
nn_player = mcts_player_with_stats(evaluate_position_with_NN, itermax=2000)
play_game(nn_player, human_player)