# In[4]:
import numpy as np
import time

# In[2]:
class Node:
    """One board state in the search tree.

    Attributes:
        board: the matrix of the board (state); a cell equal to 0 is empty.
        action: (i, j) the location of the stone put in this step
            (None for a node created without an action, e.g. the root).
        Q: the number of victories recorded for this node.
        N: the number of simulations recorded for this node.
        parent: the node this one was expanded from, or None.
        unvisited_nodes: the empty places (i, j) not yet turned into children.
        children: the child nodes expanded so far.
    """

    def __init__(self, board, parent=None, action=None):
        self.board = board
        self.action = action
        self.Q = 0
        self.N = 0
        self.parent = parent
        # The original spelled this call ``get__unvisited_nodes`` (double
        # underscore), which raised AttributeError on every construction.
        self.unvisited_nodes = self.get_unvisited_nodes()
        self.children = []

    def get_unvisited_nodes(self):
        """Return the places (i, j) where a new stone can be put.

        Places are listed in row-major order. Each row is scanned over its
        own length, so the board does not have to be square.
        """
        return [
            (i, j)
            for i, row in enumerate(self.board)
            for j, cell in enumerate(row)
            if cell == 0
        ]

    def UST(self, c=1):
        """Return the upper-confidence score ``Q/N + c*sqrt(ln(parent.N)/N)``.

        Args:
            c: weight of the exploration term.

        Returns:
            ``inf`` for a node that has never been simulated (so it is tried
            first), the plain win rate ``Q/N`` when there is no parent
            visit count to compare against, and the full score otherwise.
        """
        if self.N == 0:
            # Avoid 0/0; an untried node should win any max() comparison.
            return float("inf")
        win_rate = self.Q / self.N
        if self.parent is None or self.parent.N == 0:
            # No exploration term can be computed without parent visits.
            return win_rate
        return win_rate + c * np.sqrt(np.log(self.parent.N) / self.N)

# In[5]:
class MCTS:
    """Monte Carlo tree search over a board of 1 / -1 stones and 0 for empty.

    Each node's ``Q`` counts wins for the player who placed the stone that
    led to that node, so a parent can pick the child with the best score for
    the side that is about to move.

    NOTE(review): the game rules come from ``check(board)``, which is defined
    outside this class. It is used here as returning ``(winner, if_finish)``
    with ``winner == 0`` meaning a draw -- confirm against its definition.
    """

    def __init__(self, time_limit=5):
        """
            time_limit: seconds a top-level ``selection`` call may search.
        """
        self.time_limit = time_limit

    def get_best_next_step(self, board, player):
        """Search from ``board`` and return the best child node for ``player``.

        Prints the chosen action and resulting board. Returns None, without
        printing, when the game on ``board`` is already finished.
        The caller's ``board`` is not modified.
        """
        root = Node(board)
        best_node = self.selection(root, player)
        if best_node is None:
            # Finished game (or no legal move): there is nothing to report.
            return None
        print(best_node.action)
        print(best_node.board)
        return best_node

    def get_children_score(self, node):
        """Print the board and Q/N of ``node``, then action and Q/N per child."""
        print(node.board)
        print(f"Q: {node.Q}, N: {node.N}")
        for child in node.children:
            print(f"action: {child.action}")
            print(f"Q: {child.Q}, N: {child.N}")

    def selection(self, node, player, deadline=None):
        """Grow the tree below ``node`` and return its best child.

        Args:
            node: the node to search from.
            player: 1 or -1, the stone colour that moves at ``node``.
            deadline: ``time.monotonic()`` value at which the search stops.
                Leave as None for a top-level call: the search then runs for
                ``self.time_limit`` seconds. Recursive calls receive the
                caller's deadline and perform a single descent step, so the
                time budget is shared instead of restarted at every depth.

        Returns:
            The child of ``node`` with the highest win rate ``Q/N``, or None
            when the game at ``node`` is finished or no move is possible.
        """
        winner, if_finish = check(node.board)
        # 1. judge whether finish
        if if_finish:
            # The stone that ended the game was put by the opponent of the
            # side to move, and a node's Q is kept from that mover's view.
            self.back_propagation(node, self._score(winner, -player))
            return None

        # 2. first visit: turn every free place into a child, in random
        #    order; each expansion runs one playout to seed Q and N.
        first_visit = bool(node.unvisited_nodes)
        np.random.shuffle(node.unvisited_nodes)
        while node.unvisited_nodes:
            self.expansion(node, node.unvisited_nodes.pop(), player)

        if not node.children:
            # Not finished according to check(), yet no free place left.
            return None

        # 3. descend along the child with the best upper-confidence score;
        #    the opponent moves at the next depth.
        if deadline is None:
            deadline = time.monotonic() + self.time_limit
            while time.monotonic() < deadline:
                next_node = max(node.children, key=lambda x: x.UST())
                self.selection(next_node, -player, deadline)
        elif not first_visit:
            next_node = max(node.children, key=lambda x: x.UST())
            self.selection(next_node, -player, deadline)

        return max(node.children, key=lambda x: x.Q / x.N)

    def expansion(self, node, action, player):
        """Add the child reached by ``player`` putting a stone at ``action``.

        The child gets its own copy of the board, is linked under ``node``,
        and receives an initial score from one random playout.
        """
        next_i, next_j = action
        # Copy row by row: a plain ``board[:]`` shares the rows, so the
        # write below would also change the parent's board.
        board = self._copy_board(node.board)

        # update the board with action
        board[next_i][next_j] = player

        # link the node and next_node
        next_node = Node(board, node, action)
        node.children.append(next_node)

        # enter simulation for giving new node a initial score -> Q and N
        self.simulation(next_node, player)

    # Original (misspelled) name, kept so existing callers keep working.
    expasion = expansion

    def simulation(self, node, player):
        """Play random moves from ``node`` to the end and record the result.

        Args:
            node: the node whose board is the starting position.
            player: the colour that has just put its stone on ``node.board``;
                the opponent therefore moves first in the playout.

        The playout runs on a copy of the board. It stops when ``check``
        reports the game finished or when no empty place remains.
        """
        board = self._copy_board(node.board)
        empty_places = [
            (i, j)
            for i, row in enumerate(board)
            for j, cell in enumerate(row)
            if cell == 0
        ]
        # Popping from a shuffled list picks a uniformly random empty place
        # each turn and cannot spin forever on a full board.
        np.random.shuffle(empty_places)

        cur_player = player
        winner, if_finish = check(board)
        while not if_finish and empty_places:
            cur_player = -cur_player
            i, j = empty_places.pop()
            board[i][j] = cur_player
            winner, if_finish = check(board)

        # after simulation
        self.back_propagation(node, self._score(winner, player))

    def back_propagation(self, node, result_game):
        """Record one playout on ``node`` and all of its ancestors.

        Args:
            node: which node need to be updated (the deepest one).
            result_game: the result of game, 1 or 0, from the point of view
                of the player who moved into ``node``.

        Consecutive levels belong to opposite players, so the result is
        flipped (``1 - result_game``) at every step towards the root.
        """
        while node is not None:
            node.Q += result_game
            node.N += 1
            result_game = 1 - result_game
            node = node.parent

    @staticmethod
    def _score(winner, player):
        """Return 1 if the finished game counts as a win for ``player``, else 0.

        A draw (``winner == 0``) is credited to the -1 player, the rule the
        original playout code used.
        NOTE(review): confirm this tie-break is intended.
        """
        if winner == 0:
            return 1 if player == -1 else 0
        return 1 if winner == player else 0

    @staticmethod
    def _copy_board(board):
        """Return an independent copy of ``board`` (ndarray or list of rows)."""
        if isinstance(board, np.ndarray):
            return board.copy()
        return [list(row) for row in board]