In [1]:
from Chess import ChessEngine
import math
import numpy as np

In [2]:
class Node:
    """One node of the Monte Carlo search tree.

    A node wraps a game-state object and tracks the statistics the search
    needs: how many times it was visited and the sum of the values that
    were backed up through it.

    The game-state object is used duck-typed; this class calls
    ``getValidMoves()``, ``copy()``, ``makeMove(action)``,
    ``getValueAndTerminated()``, ``getOpponentValue(value)`` and reads
    ``whiteToMove`` on it.

    Args:
        gamestate: Position represented by this node.
        args: Settings mapping; ``args['C']`` is the exploration weight
            used in the UCB score.
        parent: Node this one was expanded from, or ``None`` for a root.
        action_taken: Move that led from ``parent`` to this node.
    """

    def __init__(self, gamestate, args, parent=None, action_taken=None):
        self.gamestate = gamestate
        self.args = args
        self.parent = parent
        self.action_taken = action_taken

        self.children = []
        # Private copy: expand() pops moves from this collection.
        self.expandable_moves = self.gamestate.getValidMoves().copy()

        self.visit_counts = 0
        self.value_sum = 0

    def is_fully_expanded(self):
        """Return True when no untried move is left and a child exists."""
        has_untried_moves = len(self.expandable_moves) > 0
        has_children = len(self.children) > 0
        return has_children and not has_untried_moves

    def select(self):
        """Return the child with the highest UCB score.

        On ties the earliest child wins; returns ``None`` when there are
        no children.
        """
        champion = None
        champion_score = -np.inf

        for candidate in self.children:
            score = self.get_ucb(candidate)
            if score > champion_score:
                champion, champion_score = candidate, score
        return champion

    def get_ucb(self, child):
        """Compute the UCB score of ``child`` as seen from this node."""
        mean_value = child.value_sum / child.visit_counts
        # Rescale the mean with (x + 1) / 2 and invert it; presumably values
        # lie in [-1, 1] and are stored from the child's own point of view
        # (TODO confirm against the game engine).
        exploitation = 1 - (mean_value + 1) / 2

        weight = self.args['C']
        exploration = math.sqrt(
            math.log(self.visit_counts) / child.visit_counts
        )
        return exploitation + weight * exploration

    def expand(self):
        """Create, register and return a child for one random untried move."""
        pick = np.random.randint(len(self.expandable_moves))
        move = self.expandable_moves.pop(pick)

        next_state = self.gamestate.copy()
        next_state.makeMove(move)

        new_child = Node(
            next_state, self.args, parent=self, action_taken=move
        )
        self.children.append(new_child)
        return new_child

    def simulate(self):
        """Estimate this node's value, playing random moves if necessary.

        If the node's own state is already reported as terminated, its
        value (passed through ``getOpponentValue``) is returned directly.
        Otherwise random moves are played on a copy of the state until a
        terminated state is reached.
        """
        outcome, finished = self.gamestate.getValueAndTerminated()
        outcome = self.gamestate.getOpponentValue(outcome)
        if finished:
            return outcome

        board = self.gamestate.copy()
        side_at_start = board.whiteToMove

        finished = False
        while not finished:
            candidates = board.getValidMoves()
            move = candidates[np.random.randint(len(candidates))]
            board.makeMove(move)
            outcome, finished = board.getValueAndTerminated()

        # The final value is returned unchanged when the side to move at the
        # end equals the side to move at the start of the rollout, and is
        # passed through getOpponentValue otherwise.
        if board.whiteToMove == side_at_start:
            return outcome
        return board.getOpponentValue(outcome)

    def backpropagate(self, value):
        """Add ``value`` to this node and every ancestor up to the root.

        The value is passed through ``getOpponentValue`` before each step
        to the parent.
        """
        current = self
        while current is not None:
            current.visit_counts += 1
            current.value_sum += value

            value = current.gamestate.getOpponentValue(value)
            current = current.parent
            

In [4]:
class MCTS:
    """Monte Carlo Tree Search driver built on ``Node``.

    Args:
        gamestate: Game-state object stored on the instance. ``search``
            works on the state passed to it, not on this attribute.
        args: Settings mapping. ``args['num_searches']`` is the number of
            search iterations; the same mapping is handed to every
            ``Node``, which reads ``args['C']``.
    """

    def __init__(self, gamestate, args):
        self.gamestate = gamestate
        self.args = args

    def search(self, gamestate):
        """Run the configured number of MCTS iterations from ``gamestate``.

        Args:
            gamestate: Position to search from; it becomes the root node.

        Returns:
            A tuple ``(possible_actions, action_probs)``:
            ``possible_actions`` is a list with the move of every root
            child, and ``action_probs`` is a NumPy array with each child's
            share of the root children's total visit count, in the same
            order.
        """
        root = Node(gamestate, self.args)

        for _ in range(self.args['num_searches']):
            # selection: descend while the current node has no untried move
            node = root
            while node.is_fully_expanded():
                node = node.select()

            # Evaluate the position of the node that was actually selected.
            # The previous code read the never-assigned attribute
            # ``self.game`` here, which raised AttributeError on the first
            # iteration.
            value, is_terminal = node.gamestate.getValueAndTerminated()
            value = node.gamestate.getOpponentValue(value)

            if not is_terminal:
                # expansion
                node = node.expand()

                # simulation
                value = node.simulate()

            # backpropagation
            node.backpropagate(value)

        # Turn the root children's visit counts into a distribution.
        possible_actions = [child.action_taken for child in root.children]
        visit_counts = np.array(
            [child.visit_counts for child in root.children], dtype=float
        )
        action_probs = visit_counts / np.sum(visit_counts)
        return possible_actions, action_probs