# About using Cython
It's difficult to declare a C type for every variable in Cython because of the dynamic Python writing style.  
I suspect Cython is not effective for container data types such as list or dict, because Cython doesn't know what's inside the containers.  
With Cython, the code is only about 30% faster, which falls far short of my target (a 2x speed-up by converting Python into C++).

I converted all code blocks in `/env` except `env_utils` into Cython. I didn't convert the others because Cython doesn't support class inheritance well for this code, and converting them brought no improvement.

In [1]:
%load_ext Cython

In [2]:
%%cython

import copy
import numpy as np
cimport numpy as np

from deepdraughts.env.env_utils import *

cdef extern from "math.h":
    double sqrt(double theta)

cdef class Piece():
    """One draughts piece.

    Attributes:
        player: owner of the piece (white 1, black -1).
        pos: square index the piece currently stands on.
        isking: non-zero once the piece has been promoted.
        captured: initialised to False; not changed inside this class.
        KING_POS_WHITE / KING_POS_BLACK: squares on which a white / black
            man gets promoted.
    """
    cdef public int player
    cdef public int pos
    cdef public int isking
    cdef public int captured
    cdef public set KING_POS_WHITE
    cdef public set KING_POS_BLACK

    def __init__(self, int player, int pos, int isking, set KING_POS_WHITE, set KING_POS_BLACK):
        self.player = player  # white 1, black -1
        self.pos = pos
        self.isking = isking
        self.captured = False
        # Promotion squares are kept on the piece so move_to() can test them locally.
        self.KING_POS_WHITE = KING_POS_WHITE
        self.KING_POS_BLACK = KING_POS_BLACK

    cpdef void move_to(self, int pos):
        '''
        Relocate the piece, promoting a man that lands on its promotion row.

        Args:
            pos: destination square. The caller must make sure it is available.
        '''
        self.pos = pos
        if self.isking:
            return  # kings are never promoted again
        if (self.player == WHITE and pos in self.KING_POS_WHITE) or \
                (self.player == BLACK and pos in self.KING_POS_BLACK):
            self.king_promote()

    cpdef void king_promote(self):
        """Turn this piece into a king."""
        self.isking = True
    
cdef class Move():
    """A single move, or one step of a capture chain.

    Attributes:
        pos: (pos_from, pos_to) tuple.
        direction: direction key the move was generated along (Board uses the
            keys of HOP_POS_ARGS).
        force: when set, __str__ renders the move with "->".
        take_piece: non-zero if this move captures a piece.
        taken_pos: square of the captured piece, -1 when nothing is captured.
        move_type: MEN_MOVE or KING_MOVE, as assigned by Board.
    """
    cdef public tuple pos
    cdef public str direction
    cdef public int force
    cdef public int take_piece
    cdef public int taken_pos
    cdef public int move_type
    
    
    def __init__(self, int pos_from, int pos_to, str direction, int move_type = MEN_MOVE, 
                int take_piece = False, int taken_pos = -1, int force = False) -> None:
        self.pos = (pos_from, pos_to)
        self.move_type = move_type
        self.direction = direction
        self.take_piece = take_piece
        self.taken_pos = taken_pos
        self.force = force
    
    def __str__(self):
        # Separator priority: forced "->", then capture "x", otherwise plain "-".
        if self.force:
            separator = "->"
        elif self.take_piece:
            separator = "x"
        else:
            separator = "-"
        return separator.join(map(str, self.pos))

cdef class Board():
    """Board position plus legal-move generation.

    ``pieces`` maps square index -> Piece. Generated moves are cached in
    ``piece_moves`` (per square) and in the ``all_*`` lists (per call of
    get_all_available_moves_board); reset_available_moves() clears them and is
    called by every method of this class that changes the position.
    """
    cdef public dict pieces            # pos -> Piece
    cdef public int ngrid              # number of squares; must be a perfect square
    cdef public int rule
    cdef public dict piece_moves       # pos -> (king_jumps, jump_moves, normal_moves)
    cdef public list all_king_jumps
    cdef public list all_jump_moves
    cdef public list all_normal_moves
    cdef public int nsize              # board side length, sqrt(ngrid)
    
    
    def __init__(self, int ngrid = N_GRID_64, int rule = RUSSIAN_RULE) -> None:
        self.pieces = dict() # key - value: pos - piece
        self.ngrid = ngrid
        self.rule = rule
        self.piece_moves = dict()
        self.all_king_jumps = []
        self.all_jump_moves = []
        self.all_normal_moves = []

        cdef int n

        # sqrt is the C math.h function declared at module level.
        n = int(sqrt(ngrid))
        if n * n != ngrid:
            raise Exception("N_grid is not squre number.")
        self.nsize = n
        

    # ---- Load global lookup tables for this board size ----

    def get_king_promotion_pos(self):
        return globals()["KING_POS_WHITE_" + str(self.ngrid)], globals()["KING_POS_BLACK_" + str(self.ngrid)]


    def get_default_pos(self):
        return globals()["DEFAULT_POS_WHITE_" + str(self.ngrid)], globals()["DEFAULT_POS_BLACK_" + str(self.ngrid)]

    def get_edge_pos(self):
        return globals()["EDGE_POS_" + str(self.ngrid)]

    def get_valid_pos(self):
        return globals()["VALID_POS_" + str(self.ngrid)]

    def get_khop_pos(self):
        '''
        Neighbour table for this board size: pos -> {direction: [squares along it]}.
        Previously the 64-square table was hard-coded regardless of ngrid.
        '''
        return globals()["KHOP_POS_" + str(self.ngrid)]

    # ---- Update board ----

    cpdef void reset_available_moves(self):
        '''Drop every cached move list; call after any change of position.'''
        self.piece_moves.clear()
        self.all_king_jumps.clear()
        self.all_jump_moves.clear()
        self.all_normal_moves.clear()

    def set_board(self, whites_pos, blacks_pos, whites_isking = None, blacks_isking = None):
        '''
        Place white and black pieces on the board.

        Args:
            whites_pos / blacks_pos: squares for each side (normalised by norm_pos_list).
            whites_isking / blacks_isking: per-square king flags; all False when None.
        Raises:
            Exception: if any square is not a valid square for this board size.
        '''
        whites_pos = norm_pos_list(whites_pos)
        blacks_pos = norm_pos_list(blacks_pos)
        # Validate both lists before failing so every invalid square gets reported.
        whites_ok = self.check_pos_list(whites_pos)
        blacks_ok = self.check_pos_list(blacks_pos)
        if not (whites_ok and blacks_ok):
            raise Exception("Invaid pos list")
        if whites_isking is None:
            whites_isking = [False] * len(whites_pos)
        if blacks_isking is None:
            blacks_isking = [False] * len(blacks_pos)

        king_pos_white, king_pos_black = self.get_king_promotion_pos()
        for pos, isking in zip(whites_pos, whites_isking):
            self.pieces[pos] = Piece(WHITE, pos, isking, king_pos_white, king_pos_black)
        for pos, isking in zip(blacks_pos, blacks_isking):
            self.pieces[pos] = Piece(BLACK, pos, isking, king_pos_white, king_pos_black)
        self.reset_available_moves()

    cpdef void init_empty_board(self):
        self.pieces.clear()
        self.reset_available_moves()

    cpdef void init_default_board(self):
        self.set_board(*self.get_default_pos())

    cpdef check_pos_list(self, list pos_list):
        '''Return True if every square is valid; print each invalid one.'''
        valid_pos = self.get_valid_pos()  # hoisted: one table lookup per call
        is_ok = True
        for pos in pos_list:
            if pos not in valid_pos:
                x, y = pos2coord(pos, self.ngrid)
                print("Invalid pos:", pos, "row:", x, "col:", y)
                is_ok = False
        return is_ok

    
    cpdef do_move(self, Move move):
        self.move(move.pos[-2], move.pos[-1], move.taken_pos)

    cpdef move(self, int pos_from, int pos_to, int taken_pos = -1):
        '''
        Move the piece in pos_from to pos_to, removing the piece on taken_pos
        (if not -1), then invalidate the move caches.

        Raises:
            Exception: if pos_from is empty or pos_to is occupied.
        '''
        if pos_from not in self.pieces or pos_to in self.pieces:
            raise Exception("Invalid move operation. From {} to {}".format(pos_from, pos_to))
        cdef Piece piece
        piece = self.pieces.pop(pos_from)
        piece.move_to(pos_to)  # may promote the piece
        self.pieces[pos_to] = piece
        if taken_pos != -1:
            self.pieces.pop(taken_pos)
        self.reset_available_moves()


    # ---- Querying states ----

    def number_of_pieces(self):
        return len(self.pieces)

    def get_pieces(self):
        return [x for x in self.pieces.values()]

    def get_available_moves(self, int pos):
        '''
        Moves for the piece standing on pos, cached per square.

        Returns:
            (king_jump_moves, jump_moves, normal_moves). The cached lists are
            returned directly, so callers should not mutate them.
        '''
        if pos in self.piece_moves:
            return self.piece_moves[pos]
        
        cdef list king_jump_moves
        cdef list jump_moves
        cdef list normal_moves
        
        king_jump_moves = []
        jump_moves = []
        normal_moves = []
        
        cdef Piece piece
        cdef dict dict_pos
        
        piece = self.pieces[pos]
        # direction -> squares along that direction, for this board size
        dict_pos = self.get_khop_pos()[pos]
        
        # normal piece
        if piece.isking == False:
            # jump moves
            for key in HOP_POS_ARGS:
                next_pos = dict_pos[key][0]
                jump_pos = dict_pos[key][1]

                # make sure move is valid
                if next_pos is None or jump_pos is None:
                    continue
                
                # adjacent square holds an opponent piece and the square behind it is empty
                if next_pos in self.pieces and jump_pos not in self.pieces and self.pieces[next_pos].player != piece.player:
                    jump_moves.append(Move(pos, jump_pos, key, MEN_MOVE, True, next_pos))
            
            # normal moves
            for key in HOP_POS_ARGS:
                next_pos = dict_pos[key][0]
                if next_pos is None:
                    continue
                if piece.player == WHITE and next_pos > pos: # men cannot step backwards
                    continue
                if piece.player == BLACK and next_pos < pos:
                    continue
                
                if next_pos not in self.pieces:
                    normal_moves.append(Move(pos, next_pos, key, MEN_MOVE))

        # king moves
        else:
            for key in HOP_POS_ARGS:
                # for each direction, check:
                # 1. same color, if true, break
                # 2. diff color, if true, go step 3
                # 3. any space behind this diff piece, go further until meeting any piece.
                tmp_normal = []
                tmp_jump = []
                meet_diff_color = None # pos of the opponent piece met on this line
                for i in range(self.nsize):
                    next_pos = dict_pos[key][i]
                    if next_pos is None:
                        break
                    # meet piece and check color
                    if next_pos in self.pieces:
                        # same color
                        if self.pieces[next_pos].player == piece.player:
                            break
                        # diff color: a second opponent piece on the line blocks
                        else:
                            if meet_diff_color is not None:
                                break
                            else:
                                meet_diff_color = next_pos
                    # empty square: landing square after a capture, or a plain king move
                    else:
                        if meet_diff_color is not None:
                            tmp_jump.append(Move(pos, next_pos, key, KING_MOVE, True, meet_diff_color))
                        else:
                            tmp_normal.append(Move(pos, next_pos, key, KING_MOVE))
                
                # deal each direction
                king_jump_moves.extend(tmp_jump)
                normal_moves.extend(tmp_normal)


        self.piece_moves[pos] = (king_jump_moves, jump_moves, normal_moves)
        return king_jump_moves, jump_moves, normal_moves


    def get_all_available_moves_board(self, int current_player):
        '''
        Collect the moves of every piece owned by current_player.

        Args:
            current_player: WHITE or BLACK.
        Returns:
            (all_king_jumps, all_jump_moves, all_normal_moves)

        NOTE(review): the cached result is not keyed by current_player; callers
        must call reset_available_moves() when the side to move changes
        (Game.change_player does this). An empty result is never cached.
        '''
        if len(self.all_king_jumps) + len(self.all_jump_moves) + len(self.all_normal_moves) >= 1:
            return self.all_king_jumps, self.all_jump_moves, self.all_normal_moves
        
        cdef list all_king_jumps
        cdef list all_jump_moves
        cdef list all_normal_moves
        cdef int pos
        
        all_king_jumps = []
        all_jump_moves = []
        all_normal_moves = []
        for pos in self.pieces:
            if self.pieces[pos].player != current_player:
                continue
            king_jump_moves, jump_moves, normal_moves = self.get_available_moves(pos)
            all_king_jumps.extend(king_jump_moves)
            all_jump_moves.extend(jump_moves)
            all_normal_moves.extend(normal_moves)

        self.all_king_jumps = all_king_jumps
        self.all_jump_moves = all_jump_moves
        self.all_normal_moves = all_normal_moves
        return all_king_jumps, all_jump_moves, all_normal_moves
        

cdef class Game():
    """A draughts game: board, side to move, move history and turn rules.

    Tracks chain captures (the piece that just captured keeps moving while it
    can capture again) and the king-move counter used by is_drawn().
    """
    
    cdef public list move_path             # every Move played, in order
    cdef public str player1_name
    cdef public str player2_name
    cdef public Board current_board
    cdef public int current_player         # WHITE or BLACK
    cdef public list available_moves       # cache for get_all_available_moves()
    cdef public int n_king_move            # consecutive non-capturing king moves

    cdef public int is_chain_taking        # True while a capture chain is in progress
    cdef public list chain_taking_moves    # [last capture of the chain] while chaining
    
    def __init__(self, str player1_name = "player1", str player2_name = "player2", int ngrid = N_GRID_64, int rule = RUSSIAN_RULE) -> None:
        self.move_path = []
        self.player1_name = player1_name
        self.player2_name = player2_name
        self.current_board = Board(ngrid, rule)
        self.current_player = WHITE

        self.current_board.init_default_board()
        self.available_moves = []
        self.n_king_move = 0

        self.is_chain_taking = False
        self.chain_taking_moves = []

    cdef list _drop_reversing_jumps(self, list king_jumps, Move last_move):
        '''
        Filter out king jumps that would go back over the piece just captured.

        A jump in the opposite direction of last_move is kept only if
        last_move's from/to squares and the jump's landing square are strictly
        monotonic; every other jump is kept. Returns a NEW list so the board's
        cached move lists are never mutated.
        '''
        cdef list kept = []
        cdef Move king_jump
        cdef int pos_a, pos_b, pos_c
        pos_a = last_move.pos[-2]
        pos_b = last_move.pos[-1]
        for king_jump in king_jumps:
            if is_opposite_direcion(king_jump.direction, last_move.direction):
                pos_c = king_jump.pos[-1]
                if not ((pos_a > pos_b and pos_b > pos_c) or (pos_a < pos_b and pos_b < pos_c)):
                    continue
            kept.append(king_jump)
        return kept

    cpdef void reset_available_moves(self):
        self.available_moves = []

    cpdef void reset_chain_taking_states(self):
        self.is_chain_taking = False
        self.chain_taking_moves = []

    cpdef int do_move(self, Move move):
        '''
        Play move, update the draw counter and chain-capture state, then
        report the game status.

        Returns:
            GAME_CONTINUE, GAME_DRAW, GAME_WHITE_WIN or GAME_BLACK_WIN.
        '''
        self.current_board.do_move(move)
        self.move_path.append(move)
        self.reset_available_moves()
        
        # Draw counter (see is_drawn): men moves and captures reset it; only
        # non-capturing king moves advance it. King captures used to advance it.
        if move.move_type == MEN_MOVE or move.take_piece:
            self.n_king_move = 0
        else:
            self.n_king_move += 1

        # Can the piece that just moved capture again from its new square?
        king_jumps, jumps, _ = self.current_board.get_available_moves(move.pos[-1])
        # Same reversal filter as get_all_available_moves(); works on a copy so
        # the board's cached lists stay intact.
        king_jumps = self._drop_reversing_jumps(king_jumps, move)

        can_take_piece = (len(king_jumps) + len(jumps)) >= 1
        if move.take_piece and can_take_piece:
            # chain capture: the same side moves again with the same piece
            self.is_chain_taking = True
            self.chain_taking_moves = [move]
        else:
            self.change_player()
            self.reset_chain_taking_states()

        is_over, winner = self.is_over()
        if is_over:
            if winner is None:
                return GAME_DRAW
            return GAME_WHITE_WIN if winner == WHITE else GAME_BLACK_WIN
        return GAME_CONTINUE

    def is_over(self):
        '''
        Returns:
            (is_over, winner): the side to move loses when it has no legal
            move; winner is None for a draw or an unfinished game.
        '''
        cdef list available_moves
        
        available_moves = self.get_all_available_moves()
        if len(available_moves) == 0:
            return True, WHITE if self.current_player == BLACK else BLACK
        elif self.is_drawn():
            return True, None
        else:
            return False, None

    cpdef int is_drawn(self):
        '''
        Here I just implement the only one basic rules about drawn:
        - If both players play 15 kingmoves (any king) without captures or moving men, the game is drawn.
        '''        
        if self.n_king_move >= 30:
            return True
        return False

    cpdef list get_all_available_moves(self):
        '''
        Legal moves for the side to move (cached until the position changes).

        - During a chain capture only the capturing piece's captures count.
        - Captures are mandatory: men captures and king captures are offered
          together; plain moves only when no capture exists.
        - King captures: for each captured piece (capturing king square +
          taken_pos), if some landing square lets the king keep capturing, only
          those landing squares are offered for that capture; otherwise all
          landing squares are. Other captures are unaffected by this choice.
        '''
        cdef Move last_move, king_jump
        cdef list king_jumps, jump_moves, normal_moves, tmp_king_jumps, tmp_jump_moves
        cdef list group_order, selected
        cdef dict chaining_groups, plain_groups
        cdef int can_take_piece
        cdef Board board_tmp
        
        # TODO Brazilian rule: majority capture
        if len(self.available_moves) >= 1:
            return self.available_moves

        if self.is_chain_taking:
            # continue from last move's pos_to, never going back over the captured piece
            last_move = self.chain_taking_moves[-1]
            king_jumps, jump_moves, normal_moves = self.current_board.get_available_moves(last_move.pos[-1])
            king_jumps = self._drop_reversing_jumps(king_jumps, last_move)
        else:
            king_jumps, jump_moves, normal_moves = self.current_board.get_all_available_moves_board(self.current_player)

        if len(king_jumps) == 0:
            # copy so callers cannot mutate the board's cached lists
            self.available_moves = list(jump_moves if len(jump_moves) >= 1 else normal_moves)
            return self.available_moves

        # Group king captures by (king square, captured square) and classify
        # each landing square by whether another capture follows from it.
        chaining_groups = {}
        plain_groups = {}
        group_order = []
        for king_jump in king_jumps:
            key = (king_jump.pos[-2], king_jump.taken_pos)
            if key not in chaining_groups:
                chaining_groups[key] = []
                plain_groups[key] = []
                group_order.append(key)
            board_tmp = copy.deepcopy(self.current_board)
            board_tmp.do_move(king_jump)
            tmp_king_jumps, tmp_jump_moves, _ = board_tmp.get_available_moves(king_jump.pos[-1])
            # same continuation test do_move() applies after the jump is played
            tmp_king_jumps = self._drop_reversing_jumps(tmp_king_jumps, king_jump)
            can_take_piece = (len(tmp_king_jumps) + len(tmp_jump_moves)) >= 1
            if can_take_piece:
                chaining_groups[key].append(king_jump)
            else:
                plain_groups[key].append(king_jump)

        selected = []
        for key in group_order:
            selected.extend(chaining_groups[key] if chaining_groups[key] else plain_groups[key])
        selected.extend(jump_moves)
        self.available_moves = selected
        return self.available_moves

    cpdef change_player(self):
        self.current_player = WHITE if self.current_player == BLACK else BLACK
        self.current_board.reset_available_moves()
        self.reset_available_moves()
        self.is_chain_taking = False

    def to_vector(self):
        return state2vec(self)

    def __str__(self):
        return ", ".join(str(x) for x in self.move_path).strip(", ")

# ctypedef double npy_double
# def rollout_policy_fn(Game game):
#     """a coarse, fast version of policy_fn used in the rollout phase."""
#     # rollout randomly
#     cdef list availables
#     cdef np.ndarray[npy_double, ndim=1] action_probs
#     availables = game.get_all_available_moves()
#     action_probs = np.random.rand(len(availables))
#     return zip(availables, action_probs)


# def policy_value_fn(Game game):
#     """a function that takes in a state and outputs a list of (action, probability)
#     tuples and a score for the state"""
#     # return uniform probabilities and 0 score for pure MCTS
#     cdef list availables
#     cdef np.ndarray[npy_double, ndim=1] action_probs
#     availables = game.get_all_available_moves()
#     action_probs = np.ones(len(availables))/len(availables)
#     return zip(availables, action_probs), 0


# cdef class TreeNode(object):
#     """A node in the MCTS tree. Each node keeps track of its own value Q,
#     prior probability P, and its visit-count-adjusted prior score u.
#     """
#     cdef public TreeNode _parent
#     cdef public dict _children
#     cdef public double _n_visits
#     cdef public double _Q
#     cdef public double _u
#     cdef public double _P
    
#     def __init__(self, TreeNode parent, double prior_p):
#         self._parent = parent
#         self._children = {}  # a map from action to TreeNode
#         self._n_visits = 0
#         self._Q = 0
#         self._u = 0
#         self._P = prior_p

        
#     def expand(self, action_priors):
#         """Expand tree by creating new children.
#         action_priors: a list of tuples of actions and their prior probability
#             according to the policy function.
#         """
#         cdef Move action
#         cdef double prob
#         for action, prob in action_priors:
#             if action not in self._children:
#                 self._children[action] = TreeNode(self, prob)

#     def select(self, int c_puct):
#         """Select action among children that gives maximum action value Q
#         plus bonus u(P).
#         Return: A tuple of (action, next_node)
#         """
#         return max(self._children.items(),
#                    key=lambda act_node: act_node[1].get_value(c_puct))

#     cpdef void update(self, double leaf_value):
#         """Update node values from leaf evaluation.
#         leaf_value: the value of subtree evaluation from the current player's
#             perspective.
#         """
#         # Count visit.
#         self._n_visits += 1
#         # Update Q, a running average of values for all visits.
#         self._Q += 1.0*(leaf_value - self._Q) / self._n_visits

#     cpdef double get_value(self, double c_puct):
#         """Calculate and return the value for this node.
#         It is a combination of leaf evaluations Q, and this node's prior
#         adjusted for its visit count, u.
#         c_puct: a number in (0, inf) controlling the relative impact of
#             value Q, and prior probability P, on this node's score.
#         """
#         self._u = (c_puct * self._P *
#                    sqrt(self._parent._n_visits) / (1 + self._n_visits))
#         return self._Q + self._u

#     cpdef int is_leaf(self):
#         """Check if leaf node (i.e. no nodes below this have been expanded).
#         """
#         return self._children == {}

#     cpdef int is_root(self):
#         return self._parent is None


# cdef class MCTS(object):
#     """A simple implementation of Monte Carlo Tree Search."""
    
#     cdef public TreeNode _root
#     cdef public object _policy
#     cdef public double _c_puct
#     cdef public int _n_playout
    
#     def __init__(self, object policy_value_fn, double c_puct=5, int n_playout=10000):
#         """
#         policy_value_fn: a function that takes in a board state and outputs
#             a list of (action, probability) tuples and also a score in [-1, 1]
#             (i.e. the expected value of the end game score from the current
#             player's perspective) for the current player.
#         c_puct: a number in (0, inf) that controls how quickly exploration
#             converges to the maximum-value policy. A higher value means
#             relying on the prior more.
#         """
#         self._root = TreeNode(None, 1.0)
#         self._policy = policy_value_fn
#         self._c_puct = c_puct
#         self._n_playout = n_playout

#     cpdef void _playout(self, Game state):
#         """Run a single playout from the root to the leaf, getting a value at
#         the leaf and propagating it back through its parents.
#         State is modified in-place, so a copy must be provided.
#         """
#         cdef TreeNode node
#         cdef list list_node, list_player
#         cdef Move action
#         cdef int is_over
#         cdef int current_player
#         cdef int leaf_value
#         cdef int player
        
#         node = self._root
#         list_node = [node]
#         list_player = [state.current_player]
#         while True:
#             if node.is_leaf():
#                 break

#             # Greedily select next move.
#             action, node = node.select(self._c_puct)
#             state.do_move(action)
#             list_node.append(node)
#             list_player.append(state.current_player)

#         action_probs, _ = self._policy(state)
#         # Check for end of game
#         is_over, _ = state.is_over()
#         if not is_over:
#             node.expand(action_probs)
#         # Evaluate the leaf node by random rollout
#         current_player = state.current_player
#         leaf_value = self._evaluate_rollout(state)
        
#         # Update value and visit count of nodes in this traversal.
#         # Applied recursively for all ancestors
#         # Note: 
#         # 1. node.update() is an indepentent operation. So the order is unneccesary.
#         # 2. Why update -leaf_value? Easily to check by playing a game. Or see: https://github.com/junxiaosong/AlphaZero_Gomoku/issues/25

#         for node, player in zip(list_node, list_player):
#             node.update(-leaf_value if player == current_player else leaf_value)

    
#     cpdef int _evaluate_rollout(self, Game state, int limit=1000):
#         """Use the rollout policy to play until the end of the game,
#         returning +1 if the current player wins, -1 if the opponent wins,
#         and 0 if it is a tie.

#         Returns:
#             value: Value of current state.
#         """
#         cdef int start_player
#         cdef int value
#         cdef int i, is_over
#         cdef Move max_action

#         start_player = state.current_player
#         value = 0
#         for i in range(limit):
#             is_over, winner = state.is_over()
#             if is_over: 
#                 if winner == None:
#                     value = 0
#                 else:
#                     value = 1 if winner == start_player else -1
#                 break
#             action_probs = rollout_policy_fn(state)
#             max_action = max(action_probs, key=itemgetter(1))[0]
#             state.do_move(max_action)

#         return value

#      # It's impossible for python to parallel here. Theading is not real multi-theading for python!
#     def get_move(self, Game state):
#         """Runs all playouts sequentially and returns the most visited action.
#         state: the current game state

#         Return: the selected action
#         """
#         cdef int n
#         cdef Game state_copy
        
#         for n in range(self._n_playout):
#             state_copy = copy.deepcopy(state)
#             self._playout(state_copy)
#         return max(self._root._children.items(),
#                    key=lambda act_node: act_node[1]._n_visits)[0]

#     cpdef void update_with_move(self, object last_move):
#         """Step forward in the tree, keeping everything we already know
#         about the subtree.
#         """
#         if last_move in self._root._children:
#             self._root = self._root._children[last_move]
#             self._root._parent = None
#         else:
#             self._root = TreeNode(None, 1.0)

#     def __str__(self):
#         return "MCTS"


# cdef class MCTSPlayer(object):
#     """AI player based on MCTS"""
    
#     cdef public MCTS mcts
    
#     def __init__(self, double c_puct=5, int n_playout=2000):
#         self.mcts = MCTS(policy_value_fn, c_puct, n_playout)

#     cpdef void reset(self):
#         self.mcts.update_with_move(-1)

#     def get_action(self, Game game, object temp=None):
#         '''
#         Args: 
#             game: Current game states.

#         Returns: 
#             move: Move selected by AI player.
#             prob: Action prob to be selected.
		
#         '''        
#         cdef list sensible_moves
#         cdef Move move
        
#         sensible_moves = game.get_all_available_moves()
#         if len(sensible_moves) >= 2:
#             move = self.mcts.get_move(game)
#             self.mcts.update_with_move(-1)
#             return move, 1
#         elif len(sensible_moves) == 1:
#             return sensible_moves[0], 1
#         else:
#             print("WARNING: the game is full")
#             return [], -1

#     def __str__(self):
#         return "MCTS {}".format(self.player)




In [3]:
# -*- coding: utf-8 -*-
"""
A pure implementation of the Monte Carlo Tree Search (MCTS)

@author: Junxiao Song
"""

'''
    # The code is mainly contributed by Junxiao Song.
    # The github link is: https://github.com/junxiaosong/AlphaZero_Gomoku/blob/master/mcts_pure.py
    #
    # Code is modified by EndlessLethe for further use.
'''

'''
    For readers who want to use this pure MCTS in their own project:

    To reuse this pure MCTS code, the following APIs are required:
1. A Game / Board class that stores the current game state
2. A method to get the available moves: "get_all_available_moves()"
3. A method to make a move: "do_move()"
4. A method to check the game state: "is_over()"

    Code that needs modifying:
1. policy_fn
2. The MCTS class code blocks that call the methods listed above

'''

import numpy as np
import copy
from operator import itemgetter

def rollout_policy_fn(game):
    """Cheap rollout policy: pair every legal move with a random score.

    Scores come from ``np.random.rand`` (one draw per move). Returns a
    one-shot ``zip`` iterator of (move, score) pairs.
    """
    moves = game.get_all_available_moves()
    random_scores = np.random.rand(len(moves))
    return zip(moves, random_scores)


def policy_value_fn(game):
    """Uniform policy for pure MCTS.

    Returns:
        (zip of (move, 1/len(moves)) pairs, 0) -- every legal move gets the
        same prior and the state value is always 0.
    """
    moves = game.get_all_available_moves()
    n_moves = len(moves)
    uniform_probs = np.ones(n_moves) / n_moves
    return zip(moves, uniform_probs), 0


class TreeNode(object):
    """Node of the MCTS search tree.

    Keeps a running mean value Q, the prior probability P, the most recently
    computed exploration bonus u, a visit count, and child nodes keyed by action.
    """

    def __init__(self, parent, prior_p):
        self._parent = parent
        self._children = {}  # action -> TreeNode
        self._n_visits = 0
        self._Q = 0
        self._u = 0
        self._P = prior_p

    def expand(self, action_priors):
        """Create a child for every unseen action.

        action_priors: iterable of (action, prior probability) pairs from the
            policy function. Existing children are left untouched.
        """
        children = self._children
        for action, prob in action_priors:
            if action in children:
                continue
            children[action] = TreeNode(self, prob)

    def select(self, c_puct):
        """Return the (action, child) pair with the highest Q + u(P)."""
        def score(action_and_child):
            return action_and_child[1].get_value(c_puct)

        return max(self._children.items(), key=score)

    def update(self, leaf_value):
        """Record one visit and fold leaf_value into the running mean Q.

        leaf_value: value of the subtree evaluation from the current player's
            perspective.
        """
        self._n_visits += 1
        delta = leaf_value - self._Q
        self._Q += 1.0 * delta / self._n_visits

    def get_value(self, c_puct):
        """Return Q + u, recomputing and storing the exploration bonus u.

        u = c_puct * P * sqrt(parent visits) / (1 + own visits); c_puct in
        (0, inf) weighs the prior against the value estimate. Requires a parent.
        """
        parent_visits = self._parent._n_visits
        self._u = c_puct * self._P * np.sqrt(parent_visits) / (1 + self._n_visits)
        return self._Q + self._u

    def is_leaf(self):
        """True when no children have been expanded yet."""
        return not self._children

    def is_root(self):
        """True when the node has no parent."""
        return self._parent is None


class MCTS(object):
    """A simple implementation of Monte Carlo Tree Search.

    The policy function only supplies priors for expanding leaves; leaf
    values come from greedy rollouts (``_evaluate_rollout``).
    """

    def __init__(self, policy_value_fn, c_puct=5, n_playout=10000):
        """
        policy_value_fn: a function that takes in a board state and outputs
            a list of (action, probability) tuples and also a score in [-1, 1]
            (i.e. the expected value of the end game score from the current
            player's perspective) for the current player.
        c_puct: a number in (0, inf) that controls how quickly exploration
            converges to the maximum-value policy. A higher value means
            relying on the prior more.
        n_playout: number of playouts run by each get_move() call.
        """
        self._root = TreeNode(None, 1.0)
        self._policy = policy_value_fn
        self._c_puct = c_puct
        self._n_playout = n_playout

    def _playout(self, state):
        """Run a single playout from the root to a leaf, evaluate the leaf
        with a rollout and propagate the value back through the visited nodes.
        State is modified in-place, so a copy must be provided.
        """
        node = self._root
        list_node = [node]
        # Player to move at each visited node, read from the state after every
        # move instead of assuming that the sides strictly alternate.
        list_player = [state.current_player]
        while not node.is_leaf():
            # Greedily select next move.
            action, node = node.select(self._c_puct)
            state.do_move(action)
            list_node.append(node)
            list_player.append(state.current_player)

        # Expand non-terminal leaves only; priors for a finished game would be
        # thrown away, so the policy is not queried for them.
        is_over, _ = state.is_over()
        if not is_over:
            action_probs, _ = self._policy(state)
            node.expand(action_probs)

        # Evaluate the leaf by rollout; the value is from the perspective of
        # the player to move at the leaf.
        leaf_player = state.current_player
        leaf_value = self._evaluate_rollout(state)

        # Back up. A child's Q is what its parent's select() maximises, so it
        # must be expressed from the perspective of the player who made the
        # move INTO that child, i.e. list_player[i - 1]. Comparing against the
        # child's own player-to-move (list_player[i]) is only equivalent when
        # turns strictly alternate and gives the wrong sign whenever the same
        # side moves twice in a row. Background on the sign flip:
        # https://github.com/junxiaosong/AlphaZero_Gomoku/issues/25
        # The root has no incoming move and its Q is never read by select();
        # it keeps the historical convention.
        root_value = -leaf_value if list_player[0] == leaf_player else leaf_value
        list_node[0].update(root_value)
        for child, mover in zip(list_node[1:], list_player[:-1]):
            child.update(leaf_value if mover == leaf_player else -leaf_value)

    def _evaluate_rollout(self, state, limit=1000):
        """Play greedily with ``rollout_policy_fn`` until the game ends.

        Args:
            state: game state; modified in place.
            limit: maximum number of rollout moves to play.

        Returns:
            value: +1 if the player to move on entry wins, -1 if the opponent
            wins, 0 for a draw or when the game is still running after
            ``limit`` moves.
        """
        start_player = state.current_player
        for _ in range(limit):
            is_over, winner = state.is_over()
            if is_over:
                break
            action_probs = rollout_policy_fn(state)
            max_action = max(action_probs, key=itemgetter(1))[0]
            state.do_move(max_action)
        else:
            # All ``limit`` moves were played (or limit == 0); the last move
            # may itself have ended the game, so check once more.
            is_over, winner = state.is_over()
            if not is_over:
                print("WARNING: rollout reached move limit")
                return 0
        if winner is None:
            return 0
        return 1 if winner == start_player else -1

    def get_move(self, state):
        """Run all playouts sequentially and return the most visited action.

        state: the current game state; not modified, since every playout
            works on a deep copy.

        Return: the selected action.
        Raises: ValueError if the root still has no children after the
            playouts (e.g. the game is already over, so the root is never
            expanded).
        """
        # Playouts run one after another: this is CPU-bound pure Python, so
        # threads would not execute it in parallel under CPython's GIL.
        for _ in range(self._n_playout):
            self._playout(copy.deepcopy(state))
        children = self._root._children
        if not children:
            raise ValueError("MCTS.get_move: no expanded moves at the root; "
                             "is the game already over?")
        return max(children.items(),
                   key=lambda act_node: act_node[1]._n_visits)[0]

    def update_with_move(self, last_move):
        """Step forward in the tree, keeping everything we already know
        about the subtree. An unknown ``last_move`` discards the whole tree.
        """
        if last_move in self._root._children:
            self._root = self._root._children[last_move]
            self._root._parent = None
        else:
            self._root = TreeNode(None, 1.0)

    def __str__(self):
        return "MCTS"


class MCTSPlayer(object):
    """AI player based on pure MCTS (greedy rollouts, no learned values)."""
    def __init__(self, c_puct=5, n_playout=2000):
        self.mcts = MCTS(policy_value_fn, c_puct, n_playout)
        # Optional label shown by __str__; nothing in this class sets it.
        self.player = None

    def reset(self):
        """Discard the search tree.

        ``-1`` is passed as the "last move"; presumably it never matches a
        child action, so update_with_move() replaces the root.
        """
        self.mcts.update_with_move(-1)

    def get_action(self, game, temp=None):
        '''
        Args:
            game: Current game state.
            temp: Unused; accepted for interface compatibility.

        Returns:
            (move, prob): the move selected by the AI player and its
            selection probability (always 1). ``(None, 0)`` when no move is
            available, so callers can still unpack two values.
        '''
        sensible_moves = game.get_all_available_moves()
        if len(sensible_moves) == 0:
            print("WARNING: no available moves")
            return None, 0
        if len(sensible_moves) == 1:
            # A forced move needs no search.
            return sensible_moves[0], 1
        move = self.mcts.get_move(game)
        # The tree is rebuilt from scratch for every decision.
        self.mcts.update_with_move(-1)
        return move, 1

    def __str__(self):
        # getattr: instances created without __init__ have no ``player``.
        player = getattr(self, "player", None)
        return "MCTS" if player is None else "MCTS {}".format(player)


In [4]:
'''
Author: Zeng Siwei
Date: 2021-09-11 15:56:20
LastEditors: Zeng Siwei
LastEditTime: 2021-09-18 17:07:52
Description: 
'''

import pygame as pg
from deepdraughts.env.env_utils import *
import time

class GUI():
    """Pygame window for playing or replaying a draughts ``Game``.

    The board occupies the left ``screen_size`` x ``screen_size`` pixels with
    squares of ``screen_size / 8`` pixels; a side panel ``text_size`` pixels
    wide is painted to its right.
    """
    COLOR_WHITE = pg.Color('white')
    COLOR_BLACK = pg.Color('black')
    COLOR_YELLOW = pg.Color('yellow')
    COLOR_RED = pg.Color('red')

    SQUARE_COLORS = [pg.Color('gray'), pg.Color('brown')]
    BACKGROUND_COLOR = pg.Color('#8EA2F3')

    def __init__(self, game = None, screen_size = 800, text_size = 0):
        """
        Args:
            game: Game to display; a new ``Game()`` is created when None.
            screen_size: width and height of the board area in pixels.
            text_size: width in pixels of the side panel right of the board.
        """
        self.screen_size = screen_size
        self.text_size = text_size
        self.square_size = int(screen_size / 8)
        self.piece_radius = int(self.square_size / 2)
        self.king_radius1 = int(self.square_size / 3)
        self.king_radius2 = int(self.square_size / 4)
        self.move_radius = int(self.piece_radius / 6)

        self.game = game if game is not None else Game()

        # pygame states, created by run() / replay()
        self.screen = None
        self.surface = None

        # interaction states
        self.selected_pos = None
        self.take_piece = False
        self.next_moves = []

    def _square_center(self, pos, nsize):
        """Return the pixel (x, y) centre of board position ``pos``."""
        row, col = pos2coord(pos, nsize, origin = "left_upper")
        x = (col-1) * self.square_size + self.piece_radius
        y = (row-1) * self.square_size + self.piece_radius
        return x, y

    def draw_background(self):
        """Paint the checkered board and the side panel onto ``self.surface``."""
        color_idx = 0
        for i in range(0, self.screen_size, self.square_size):
            for j in range(0, self.screen_size, self.square_size):
                square = (i, j, self.square_size, self.square_size)
                pg.draw.rect(self.surface, self.SQUARE_COLORS[color_idx], square)
                color_idx ^= 1
            # Extra flip per column so adjacent columns start with different
            # colours (relies on an even number of squares per column).
            color_idx ^= 1

        pg.draw.rect(self.surface, self.BACKGROUND_COLOR, (self.screen_size, 0, self.text_size, self.screen_size))

    def draw_pieces(self, pos_list, player_list, isking_list, nsize):
        """Draw every piece; kings get two extra concentric circles."""
        for pos, player, isking in zip(pos_list, player_list, isking_list):
            center = self._square_center(pos, nsize)
            own_color = self.COLOR_WHITE if player == WHITE else self.COLOR_BLACK
            other_color = self.COLOR_BLACK if player == WHITE else self.COLOR_WHITE
            pg.draw.circle(self.surface, own_color, center, self.piece_radius)
            if isking:
                pg.draw.circle(self.surface, other_color, center, self.king_radius1)
                pg.draw.circle(self.surface, own_color, center, self.king_radius2)

    def draw_select(self, pos, nsize):
        """Mark board position ``pos`` with a small red dot."""
        pg.draw.circle(self.surface, self.COLOR_RED, self._square_center(pos, nsize), self.move_radius)

    def draw_move(self, moves):
        """Placeholder; not implemented."""
        pass

    def reset_drawing(self):
        """Clear the current piece selection and its highlighted moves."""
        self.selected_pos = None
        self.take_piece = False
        self.next_moves = []

    def listen_human_action(self, event):
        """Translate one pygame event into a ``(GUI_* action, info)`` pair.

        A left click returns ``(GUI_LEFTCLICK, (x, y))`` with the pixel
        position of the click; every other result carries an empty tuple.
        """
        if event.type == pg.QUIT:
            return GUI_EXIT, ()
        if event.type == pg.MOUSEBUTTONDOWN:
            if event.button == 1: # left click of mouse
                # pygame reports event.pos as (x, y); x runs along columns.
                mouse_x, mouse_y = event.pos
                return GUI_LEFTCLICK, (mouse_x, mouse_y)
            if event.button == 3: # right click of mouse
                return GUI_RIGHTCLICK, ()
        return GUI_WAIT, ()

    def update_by_human_action(self, action, info, pos_list, available_moves):
        '''
        Apply one human interaction to the selection state or to the game.

        Args:
            action, info: Human interaction as returned by listen_human_action.
            pos_list: positions of all pieces currently on the board.
            available_moves: legal moves for the current player.

        Returns:
            Game status: the result of ``game.do_move`` when a move was made,
            otherwise GAME_CONTINUE.
        '''
        if action == GUI_RIGHTCLICK:
            self.reset_drawing()
            return GAME_CONTINUE
        if action != GUI_LEFTCLICK:
            return GAME_CONTINUE

        mouse_x, mouse_y = info
        nsize = self.game.current_board.nsize
        row = mouse_y // self.square_size + 1
        col = mouse_x // self.square_size + 1
        if not (1 <= row <= nsize and 1 <= col <= nsize):
            # Click outside the board (e.g. on the side panel): treat it like a
            # click on an empty square instead of passing an out-of-range
            # (row, col) to coord2pos.
            self.reset_drawing()
            return GAME_CONTINUE

        pos = coord2pos(row, col, nsize, "left_upper")
        print(mouse_x, mouse_y, row, col, pos)

        # Clicking a highlighted destination performs that move.
        for move in self.next_moves:
            if pos == move.pos[-1]:
                game_status = self.game.do_move(move)
                print(str(self.game))
                self.reset_drawing()
                return game_status

        # Any other click clears the previous selection ...
        self.reset_drawing()

        # ... and, on one of the current player's pieces, selects it and
        # highlights the moves starting from it.
        if pos in pos_list and self.game.current_board.pieces[pos].player == self.game.current_player:
            self.next_moves = [move for move in available_moves if move.pos[-2] == pos]
            if self.next_moves:
                self.selected_pos = pos
        return GAME_CONTINUE

    def read_game_status(self, status):
        """Map a game status to GUI_WAIT (game goes on) or GUI_EXIT (game over),
        printing the result when the game has ended."""
        if status == GAME_CONTINUE:
            return GUI_WAIT
        if status == GAME_WHITE_WIN:
            print("Game Over.", "WHITE", "wins.")
        elif status == GAME_BLACK_WIN:
            print("Game Over.", "BLACK", "wins.")
        elif status == GAME_DRAW:
            print("Game Draw.")
        else:
            raise Exception("Game Status Type Error.")
        return GUI_EXIT

    def is_human_playing(self, player_white, player_black):
        """Return True if the side to move is controlled by a human."""
        return (self.game.current_player == WHITE and player_white == HUMAN_PLAYER) or \
                (self.game.current_player == BLACK and player_black == HUMAN_PLAYER)

    def _open_window(self):
        """Initialise pygame, create the window and off-screen surface, and
        return a frame-rate clock."""
        pg.init()
        size = (self.screen_size + self.text_size, self.screen_size)
        self.screen = pg.display.set_mode(size)
        self.surface = pg.Surface(size)
        pg.display.set_caption('DeepDraughts')
        return pg.time.Clock()

    def _query_board(self):
        """Snapshot piece positions/owners/king flags and the legal moves."""
        pieces = self.game.current_board.get_pieces()
        pos_list = [x.pos for x in pieces]
        player_list = [x.player for x in pieces]
        isking_list = [x.isking for x in pieces]
        available_moves = self.game.get_all_available_moves()
        return pos_list, player_list, isking_list, available_moves

    def _draw_frame(self, pos_list, player_list, isking_list, highlight_moves):
        """Render board, pieces and the destinations of ``highlight_moves``."""
        nsize = self.game.current_board.nsize
        self.draw_background()
        self.draw_pieces(pos_list, player_list, isking_list, nsize)
        for move in highlight_moves:
            self.draw_select(move.pos[-1], nsize)
        self.screen.blit(self.surface, (0, 0))
        pg.display.flip()

    def run(self, player_white = HUMAN_PLAYER, player_black = HUMAN_PLAYER,
            policy_white = None, policy_black = None):
        """Play ``self.game`` interactively until it ends or the window closes.

        Args:
            player_white, player_black: HUMAN_PLAYER or an AI marker per side.
            policy_white, policy_black: objects with ``get_action(game)``
                returning ``(move, prob)``; required for every non-human side.

        Raises:
            ValueError: a non-human side has no policy.
        """
        for player, policy, name in ((player_white, policy_white, "policy_white"),
                                     (player_black, policy_black, "policy_black")):
            if player != HUMAN_PLAYER and policy is None:
                raise ValueError("{} is required for a non-human player".format(name))

        clock = self._open_window()
        try:
            running = True
            while running:
                # quering current states for each frame
                pos_list, player_list, isking_list, available_moves = self._query_board()

                if self.is_human_playing(player_white, player_black):
                    for event in pg.event.get():
                        human_action, info = self.listen_human_action(event)
                        if human_action == GUI_EXIT:
                            running = False
                            break
                        if human_action == GUI_WAIT:
                            continue
                        game_status = self.update_by_human_action(human_action, info, pos_list, available_moves)
                        if self.read_game_status(game_status) == GUI_EXIT:
                            # Game over: ignore any further queued events.
                            running = False
                            break
                elif any(event.type == pg.QUIT for event in pg.event.get()):
                    # Events must be pumped during AI turns too, otherwise the
                    # window stops responding and cannot be closed.
                    running = False
                else:
                    start_time = time.time()
                    policy = policy_white if self.game.current_player == WHITE else policy_black
                    move, _ = policy.get_action(self.game)
                    game_status = self.game.do_move(move)
                    gui_status = self.read_game_status(game_status)
                    end_time = time.time()
                    print("Step Time for AI:", end_time-start_time, "s")
                    if gui_status == GUI_EXIT:
                        running = False

                if self.selected_pos is not None and self.next_moves:
                    highlight_moves = self.next_moves
                else:
                    highlight_moves = available_moves
                self._draw_frame(pos_list, player_list, isking_list, highlight_moves)
                clock.tick(500)
        finally:
            # Always close the window, even if a policy or the game raises.
            pg.quit()

    def replay(self, replay_game):
        """Step through ``replay_game.move_path`` on ``self.game``.

        Left click plays the next recorded move (or closes the window once
        all moves were played); right click restarts from a fresh ``Game()``.
        """
        clock = self._open_window()
        replay_ptr = 0
        try:
            running = True
            while running:
                # quering current states for each frame
                pos_list, player_list, isking_list, available_moves = self._query_board()

                for event in pg.event.get():
                    action, info = self.listen_human_action(event)
                    if action == GUI_EXIT:
                        running = False
                    elif action == GUI_RIGHTCLICK:
                        # TODO withdraw last play instead of restarting
                        self.game = Game()
                        replay_ptr = 0
                    elif action == GUI_LEFTCLICK:
                        if replay_ptr >= len(replay_game.move_path):
                            running = False
                            continue
                        game_status = self.game.do_move(replay_game.move_path[replay_ptr])
                        replay_ptr += 1
                        if self.read_game_status(game_status) == GUI_EXIT:
                            running = False

                self._draw_frame(pos_list, player_list, isking_list, available_moves)
                clock.tick(500)
        finally:
            pg.quit()

pygame 2.0.1 (SDL 2.0.14, Python 3.6.10)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [5]:



# Human plays WHITE (the run() default) against a pure-MCTS AI playing
# BLACK with 1000 playouts per move.
gui = GUI()
mcts_player = MCTSPlayer(c_puct=5, n_playout=1000)
gui.run(player_white=HUMAN_PLAYER, player_black=AI_PLAYER,
        policy_black=mcts_player)



289 534 6 3 43
347 462 5 4 36
43-36
Step Time for AI: 2.3261337280273438 s
200 623 7 3 51
150 655 7 2 50
251 532 6 3 43
43-36, 22-31, 50-43
Step Time for AI: 2.523010015487671 s
632 548 6 7 47
742 458 5 8 40
43-36, 22-31, 50-43, 15-22, 47-40
Step Time for AI: 2.5234289169311523 s
64 556 6 1 41
167 466 5 2 34
43-36, 22-31, 50-43, 15-22, 47-40, 20-29, 41-34
Step Time for AI: 2.4870598316192627 s
