# AtaxxZero
## This project reimplements AlphaGo Zero for Ataxx.
## However, training an AI from scratch is computationally very heavy, given my code-optimization skills and hardware limitations;
## therefore, minor adjustments were made to the algorithm so that it can reach a reasonably satisfactory result in an acceptable amount of time.

### modifications:
1. One major difference between AlphaGo Zero and Ataxx Zero is that Ataxx Zero relies on __one engineered value function__. From the beginning of training, the q value of each node is a combination of the q from the hybrid network and a greedy function (whose output increases monotonically with the difference in piece counts between the two players).
2. Another major modification is that Ataxx Zero applies MCTS only to a __very shallow depth, currently 3__. This significantly reduces search time and thus greatly accelerates training.
3. The combination of 3-layer MCTS and an engineered value function guarantees good performance even before training, i.e. when the hybrid network outputs random probabilities and values. Before training, Ataxx Zero should __behave like an impaired MinMax search with a depth of 3__. In practice, it wins 90% of games against a greedy player (which tries to maximize no. of my pieces - no. of opponent's pieces). With reinforcement learning, the algorithm is expected to do better.
4. When actually playing games, I plan to reduce the search depth to 2 to further improve speed, while still expecting the trained algorithm to outperform its untrained version.

In [1]:
%load_ext Cython
from Cython.Compiler.Options import directive_scopes, directive_types
directive_types['linetrace'] = True
directive_types['binding'] = True

In [2]:
%load_ext line_profiler
import line_profiler

In [3]:
import os 
import importlib
import sys
import tensorflow as tf
import itertools
from multiprocessing import Queue, Pool, Process
import numpy as np
import numba as nb
from math import sqrt, log, exp
from numpy import unravel_index
from random import choice, random, sample
from operator import itemgetter
np.random.seed(1337)  # for reproducibility
from keras.models import Sequential, Model, load_model
from keras.layers import Input, BatchNormalization, Reshape, Lambda
from keras.layers import Dense, Dropout, Activation, Flatten, LocallyConnected2D
from keras.layers import Conv2D, MaxPooling2D, AlphaDropout, ConvLSTM2D, AvgPool2D, Conv2DTranspose, UpSampling2D
from keras.layers import add, concatenate, multiply, Multiply
from keras.initializers import VarianceScaling, RandomUniform
from keras.optimizers import Adam, SGD, rmsprop
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import np_utils, multi_gpu_model
from keras.utils.vis_utils import plot_model
from keras.engine.topology import Container
from keras.optimizers import SGD, Adadelta, Adagrad
from keras.regularizers import l1, l2
from keras.callbacks import EarlyStopping
import keras.backend as K
K.set_image_dim_ordering('th')
from keras.callbacks import Callback, ReduceLROnPlateau, LearningRateScheduler, TensorBoard, ModelCheckpoint
import matplotlib.pyplot as plt
import time
#%matplotlib notebook

Using TensorFlow backend.


In [4]:
def set_keras_backend(backend):
    """Switch Keras to ``backend`` and reload the module bound to ``K``.

    Sets the KERAS_BACKEND environment variable, reloads ``K`` so the new
    backend is picked up, re-applies the channels-first ('th') image
    ordering, and asserts that the reloaded backend is the requested one.
    """
    os.environ['KERAS_BACKEND'] = backend
    importlib.reload(K)
    K.set_image_dim_ordering('th')
    assert backend == K.backend()

def set_omp_threads(n):
    """Set the OMP_NUM_THREADS environment variable to ``str(n)``."""
    os.environ['OMP_NUM_THREADS'] = str(n)

In [5]:
# Select the TensorFlow backend; set_keras_backend also re-applies the 'th'
# (channels-first) image ordering after reloading keras.backend.
set_keras_backend('tensorflow')

Using TensorFlow backend.


In [6]:
def _transform_point(point, is_flip, rot_m, center):
    """Flip (optionally) and then rotate one board cell about ``center``; return int coords."""
    offset = np.asarray(point) - center
    if is_flip:
        # mirror the column index about the board centre
        offset[1] = -offset[1]
    row, col = rot_m.dot(offset) + center
    return (int(row), int(col))


def get_rot_policy_dict():
    """Build move-remapping tables for the 7 non-identity board symmetries.

    Returns a dict keyed by ``(is_flip, rot_time)`` for every combination
    except ``(False, 0)``. Each value maps an on-board move
    ``((r0, c0), (r1, c1))`` (start and end on the 7x7 board, Chebyshev
    distance 1 or 2) to the transformed move. The transform mirrors the
    column first when ``is_flip`` is set, then applies ``rot_time`` quarter
    turns with the same orientation as ``np.rot90(board, k=rot_time)``.
    All coordinates are plain Python ints.
    """
    rot_quarter = np.array([[0, -1], [1, 0]])
    center = np.array([3, 3])
    augment_dict = {}
    for is_flip in [False, True]:
        for rot_time in range(4):
            if not is_flip and rot_time == 0:
                continue  # identity; augment_policy short-circuits it
            rot_m = np.eye(2, dtype=int)
            for _ in range(rot_time):
                rot_m = rot_quarter.dot(rot_m)

            tmp_dict = {}
            for r in range(7):
                for c in range(7):
                    for dr in range(-2, 3):
                        for dc in range(-2, 3):
                            end_r, end_c = r + dr, c + dc
                            # Only real moves: non-zero displacement whose end lies on the board.
                            if (dr == 0 and dc == 0) or not (0 <= end_r < 7 and 0 <= end_c < 7):
                                continue
                            start, end = (r, c), (end_r, end_c)
                            tmp_dict[(start, end)] = (
                                _transform_point(start, is_flip, rot_m, center),
                                _transform_point(end, is_flip, rot_m, center),
                            )
            augment_dict[(is_flip, rot_time)] = tmp_dict
    return augment_dict
augment_dict = get_rot_policy_dict()

In [7]:
def augment_policy(data, is_flip, rot_time):
    """Remap a flat policy-shaped vector under one board symmetry.

    ``data`` is indexed like ``policy_list`` (move -> index via
    ``policy_dict``). Every strictly positive entry is moved to the index of
    the transformed move from ``augment_dict[(is_flip, rot_time)]``. All other
    entries of the result are 0. The identity transform
    ``(False, 0)`` returns ``data`` itself, not a copy.
    """
    if not is_flip and rot_time == 0:
        return data

    out = np.zeros_like(data)
    positive = np.flatnonzero(np.asarray(data) > 0)
    if positive.size == 0:
        # Nothing to move, so the lookup tables are not needed.
        return out
    move_map = augment_dict[(is_flip, rot_time)]
    for i in positive:
        out[policy_dict[move_map[policy_list[i]]]] = data[i]
    return out


def augment_data(train_data):
    """Expand one training sample into its 8 board-symmetric variants.

    ``train_data`` is ``[feature_map, action_mask, frequency_map, value]``.
    ``feature_map`` is (channels, rows, cols); the other two are flat
    policy vectors. The result is a list of 8 samples in the same layout:
    ``is_flip`` in (False, True), then ``rot_time`` in 0..3. The first entry
    is the original sample. The feature map is mirrored along the column axis
    and then rotated, which matches the transform ``augment_policy`` applies
    to the policy vectors.
    """
    feature_map = train_data[0]
    action_mask = train_data[1]
    frequency_map = train_data[2]
    value = train_data[3]

    out = []
    for is_flip in (False, True):
        # Mirror columns (axis 2) to match the column negation in augment_dict.
        base_map = np.flip(feature_map, axis=2) if is_flip else feature_map
        for rot_time in range(4):
            tmp_feature_map = np.rot90(base_map, k=rot_time, axes=(1, 2))
            tmp_action_mask = augment_policy(action_mask, is_flip, rot_time)
            tmp_frequency_map = augment_policy(frequency_map, is_flip, rot_time)
            out.append([tmp_feature_map, tmp_action_mask, tmp_frequency_map, value])
    return out

## This might be a bit ugly, but isolating the functions that can be accelerated by Cython and JIT is really efficient
1. All dictionary lookups and creations are integrated in the following cell, which reduces the running time by at least 40%.
2. Memoryviews in Cython are a powerful tool for speeding up the algorithm.

In [8]:
%%cython 
# -a -f --compile-args=-DCYTHON_TRACE=1
import numpy as np
cimport numpy as np
import matplotlib.pyplot as plt
from random import choice

def get_policy_dict_list():
    '''Enumerate every move on a 7x7 board whose end cell lies within
    Chebyshev distance 2 of (and differs from) its start cell.

    Returns ``(policy_dict, policy_list)``: ``policy_list`` holds the moves
    ``((r, c), (new_r, new_c))`` in row-major start order, then row-major
    displacement order. ``policy_dict`` maps each move back to its index.
    '''
    policy_list = [
        ((r, c), (r + dr, c + dc))
        for r in range(7)
        for c in range(7)
        for dr in range(-2, 3)
        for dc in range(-2, 3)
        if (dr, dc) != (0, 0) and 0 <= r + dr < 7 and 0 <= c + dc < 7
    ]
    policy_dict = {move: index for index, move in enumerate(policy_list)}
    return policy_dict, policy_list

policy_dict, policy_list = get_policy_dict_list()

# Used when expanding a node: attach network priors to its children.
def assign_children(children, np.float32_t[:] p_array):
    '''Overwrite each value of ``children`` (a dict keyed by move) with the
    prior ``np.float32(p_array[policy_dict[move]])``. Mutates ``children``.'''
    for child_move in children.keys():
        prior_index = policy_dict[child_move]
        children[child_move] = np.float32(p_array[prior_index])

cdef class Ataxx:
    '''A 7x7 Ataxx board.

    ``data`` stores one int8 per cell: 1 and -1 for the two players' pieces
    and 0 for an empty cell. The default layout puts -1 on (0, 0) and (6, 6)
    and 1 on (0, 6) and (6, 0).
    '''
    cdef public np.int8_t[:, :] data

    def __init__(self, board=None):
        # Share the starting-layout logic with reset() instead of duplicating it.
        self.reset(board)

    def reset(self, board=None):
        '''Restore the starting layout, or copy ``board`` when one is given.'''
        if board is None:
            self.data = np.zeros((7, 7), dtype=np.int8)
            self.data[0, 0] = -1
            self.data[6, 6] = -1
            self.data[0, 6] = 1
            self.data[6, 0] = 1
        else:
            self.data = board.copy()

    def get_feature_map(self, turn, move):
        '''Encode the board as a (6, 9, 9) int8 array.

        The 7x7 board sits at [1:8, 1:8] inside a one-cell border. Planes:
        0 border cells, 1 pieces of ``turn``, 2 pieces of ``-turn``,
        3 source cell of ``move``, 4 destination cell of ``move`` (3 and 4
        stay empty when ``move`` is None), 5 all ones when ``turn`` is -1.
        '''
        cdef int j, k
        cdef np.int8_t[:, :, :] out = np.zeros((6, 9, 9), dtype=np.int8)

        # plane 0: border
        for j in range(9):
            for k in range(9):
                if j == 0 or j == 8 or k == 0 or k == 8:
                    out[0, j, k] = 1

        # planes 1 and 2: own / opponent pieces, shifted by one for the border
        for j in range(1, 8):
            for k in range(1, 8):
                if self.data[j-1, k-1] == turn:
                    out[1, j, k] = 1
                if self.data[j-1, k-1] == -turn:
                    out[2, j, k] = 1

        # planes 3 and 4: last move
        if move is not None:
            out[3, move[0][0]+1, move[0][1]+1] = 1
            out[4, move[1][0]+1, move[1][1]+1] = 1

        # plane 5: side to move
        if turn == -1:
            for j in range(9):
                for k in range(9):
                    out[5, j, k] = 1
        return np.array(out)

    def plot(self, is_next_move=False, turn=None):
        '''Show the board with matplotlib; optionally mark ``turn``'s reachable cells with turn / 2.'''
        # Use a float copy: the half-intensity markers (turn / 2 = +-0.5) cannot
        # be stored in the int8 board.
        image = np.array(self.data, dtype=np.float32)
        if is_next_move:
            if turn not in [-1, 1]:
                raise ValueError("Turn must be -1 or 1, or Must input a turn for next moves")
            next_moves = self.get_moves(turn)
            if len(next_moves) == 0:
                raise ValueError("Game is over already")
            next_pos = list(zip(*next_moves))[1]
            for pos in next_pos:
                image[pos] = turn / 2
        plt.imshow(image, cmap='gray')
        plt.xticks(range(7), range(7))
        plt.yticks(range(7), range(7))
        plt.show()

    def is_valid(self, turn, pos):
        '''True if ``pos`` is empty and some piece of ``turn`` lies within Chebyshev distance 2.'''
        cdef int dr, dc, r = pos[0], c = pos[1], new_r, new_c
        if self.data[r, c] != 0:
            return False
        for dr in range(-2, 3):
            for dc in range(-2, 3):
                new_r = r+dr
                new_c = c+dc
                if new_r >= 0 and new_c >= 0 and new_r < 7 and new_c < 7 and self.data[new_r, new_c] == turn:
                    return True
        return False

    def get_moves(self, turn, return_node_info=False):
        '''List legal moves ``((r0, c0), (r1, c1))`` for ``turn``.

        Clone moves (distance 1) into the same destination are equivalent.
        Only the first adjacent source found goes into the returned list, and
        the others are recorded in ``corr_dict`` mapped to that first move.
        Jump moves (distance 2) are always listed.

        With ``return_node_info`` the result is ``(next_moves, corr_dict,
        children_dict, action_mask)``. ``children_dict`` maps each listed move
        to None. ``action_mask`` is a length-792 int8 array with 1 at
        ``policy_dict[move]`` for every listed and every duplicate move.
        Otherwise only ``next_moves`` is returned.
        '''
        cdef int r, c, dr, dc, new_r, new_c
        cdef np.int8_t[:] action_mask = np.zeros(792, dtype=np.int8)
        next_moves = []
        corr_dict = {}
        children_dict = {}
        for r in range(7):
            for c in range(7):
                if not self.is_valid(turn, (r, c)):
                    continue
                has_duplicate_move = False
                for dr in range(-2, 3):
                    for dc in range(-2, 3):
                        new_r = r+dr
                        new_c = c+dc
                        if new_r < 0 or new_c < 0 or new_r >= 7 or new_c >= 7:
                            continue
                        if self.data[new_r, new_c] != turn:
                            continue
                        cur_move = ((new_r, new_c), (r, c))
                        is_clone = abs(dr) <= 1 and abs(dc) <= 1
                        if is_clone and has_duplicate_move:
                            # equivalent to the clone move already listed for (r, c)
                            corr_dict[cur_move] = dup_move
                            if return_node_info:
                                action_mask[policy_dict[cur_move]] = 1
                            continue
                        if is_clone:
                            dup_move = cur_move
                            has_duplicate_move = True
                        next_moves.append(cur_move)
                        if return_node_info:
                            children_dict[cur_move] = None
                            action_mask[policy_dict[cur_move]] = 1
        if return_node_info:
            return next_moves, corr_dict, children_dict, np.array(action_mask)
        else:
            return next_moves

    def get_greedy_move(self, turn, moves=None):
        '''Pick a greedy move for ``turn``.

        Score = 1 for a clone (distance-1) move, plus the number of opponent
        pieces in the 3x3 neighbourhood of the destination. A random move is
        chosen among the best-scoring ones. When ``moves`` is None, all listed
        moves plus the duplicate clone moves from get_moves are considered.
        Raises ValueError when there is no move.
        '''
        cdef int x0, y0, x1, y1, dr, dc, tmp_score, best_score = -50
        if moves is None:
            moves, corr_dict, _, _ = self.get_moves(turn, return_node_info=True)
            for item in corr_dict:
                moves.append(item)

        if len(moves) == 0:
            raise ValueError('No Possible Moves')

        best_moves = []
        for (x0, y0), (x1, y1) in moves:
            tmp_score = 0
            if abs(x0-x1) <= 1 and abs(y0-y1) <= 1:
                tmp_score += 1
            for dr in range(-1, 2):
                for dc in range(-1, 2):
                    # explicit bounds check instead of catching IndexError
                    if 0 <= x1 + dr < 7 and 0 <= y1 + dc < 7:
                        tmp_score += self.data[x1+dr, y1+dc] == -turn
            if tmp_score > best_score:
                best_moves = [((x0, y0), (x1, y1))]
                best_score = tmp_score
            elif tmp_score == best_score:
                best_moves.append(((x0, y0), (x1, y1)))
        return choice(best_moves)

    def move_to(self, turn, pos0, pos1):
        '''Play ``pos0 -> pos1`` for ``turn`` in place.

        The destination becomes ``turn``'s. A jump (distance > 1) empties the
        source, and every opponent piece adjacent to the destination is
        converted. Raises ValueError if ``pos1`` is not a valid target or
        ``pos0`` is not ``turn``'s piece.
        '''
        cdef int dr, dc, x0 = pos0[0], y0 = pos0[1], x1 = pos1[0], y1 = pos1[1]

        if not self.is_valid(turn, pos1):
            raise ValueError("This move: " + str((pos0, pos1)) + " of turn: " + str(turn) + " is invalid")
        elif self.data[x0, y0] != turn:
            raise ValueError("The starting position is not your piece")
        else:
            self.data[x1, y1] = turn
            if abs(x0 - x1) > 1 or abs(y0 - y1) > 1:   # jump move
                self.data[x0, y0] = 0

            for dr in range(-1, 2):                  # infect adjacent opponent pieces
                for dc in range(-1, 2):
                    if x1+dr >= 0 and y1+dc >= 0 and x1+dr < 7 and y1+dc < 7:
                        if self.data[x1+dr, y1+dc] == -turn:
                            self.data[x1+dr, y1+dc] = turn

    @staticmethod
    def get_manual_q(int turn, np.int8_t[:, :] board):
        '''Hand-crafted value of ``board`` from ``turn``'s point of view.

        Returns (own pieces - opponent pieces) / 30, saturated to +-1.0 once
        the absolute difference exceeds 30.
        '''
        cdef int r, c, turn_no = 0, op_no = 0
        cdef float diff
        for r in range(7):
            for c in range(7):
                if board[r, c] == turn:
                    turn_no += 1
                elif board[r, c] == -turn:
                    op_no += 1
        diff = turn_no - op_no
        if abs(diff) > 30:
            return diff / abs(diff)
        return diff / 30

    def evaluate(self, turn, this_turn, max_score=1, min_score=0.001):
        '''Score the position for ``turn``.

        If ``this_turn`` has no legal move, returns ``max_score`` when ``turn``
        has strictly more pieces and ``-max_score`` otherwise (ties count as
        a loss). If not, returns (own pieces - opponent pieces) * min_score.
        '''
        cdef int r, c, turn_no=0, op_no=0
        for r in range(7):
            for c in range(7):
                if self.data[r, c] == turn:
                    turn_no += 1
                elif self.data[r, c] == -turn:
                    op_no += 1
        if len(self.get_moves(this_turn)) == 0:
            if turn_no > op_no:
                return max_score
            else:
                return -max_score
        value = turn_no - op_no
        return value * min_score

## In the following cell, a recursive alpha-beta pruning search with a piece-count-difference evaluation function is implemented.
* Quite amazingly, the programming was finished within 30 minutes, with only two debugging cycles.
1. Interestingly, using the piece-count difference as the evaluation seems to play worse than using the piece count alone.

In [9]:
%%cython 
# -a -f --compile-args=-DCYTHON_TRACE=1
import numpy as np
cimport numpy as np
import matplotlib.pyplot as plt
from random import choice

# ---- Helper functions used by the min-max search ----
def evaluate(np.int8_t[:, :] board, int turn):
    '''Piece-count difference on the 7x7 board from ``turn``'s point of view:
    +1 for every cell equal to ``turn``, -1 for every cell equal to ``-turn``.'''
    cdef int row, col, score = 0
    for row in range(7):
        for col in range(7):
            if board[row, col] == turn:
                score += 1
            elif board[row, col] == -turn:
                score -= 1
    return score
        
def is_valid(np.int8_t[:, :] board, turn, pos):
    '''True when ``pos`` is empty and a piece of ``turn`` lies within
    Chebyshev distance 2 of it (on the 7x7 board).'''
    cdef int dr, dc, r = pos[0], c = pos[1], nr, nc
    if board[r, c] != 0:
        return False
    for dr in range(-2, 3):
        for dc in range(-2, 3):
            nr = r + dr
            nc = c + dc
            if 0 <= nr < 7 and 0 <= nc < 7 and board[nr, nc] == turn:
                return True
    return False
    

def next_move(np.int8_t[:, :] board, turn):
    '''Lazily yield legal moves ``((r0, c0), (r1, c1))`` for ``turn``.

    Destinations are scanned row-major. For each one, every jump (distance 2)
    from a friendly piece is yielded, but only the first clone (distance 1)
    source is yielded because all clone moves into the same cell are
    equivalent.
    '''
    cdef int r, c, dr, dc, new_r, new_c
    for r in range(7):
        for c in range(7):
            if not is_valid(board, turn, (r, c)):
                continue
            clone_emitted = False
            for dr in range(-2, 3):
                for dc in range(-2, 3):
                    new_r = r + dr
                    new_c = c + dc
                    if not (0 <= new_r < 7 and 0 <= new_c < 7):
                        continue
                    if board[new_r, new_c] != turn:
                        continue
                    if abs(dr) > 1 or abs(dc) > 1:
                        yield ((new_r, new_c), (r, c))
                    elif not clone_emitted:
                        clone_emitted = True
                        yield ((new_r, new_c), (r, c))

def has_next_move(np.int8_t[:, :] board, turn):
    '''Return True if next_move() yields at least one move for ``turn``.'''
    for _ in next_move(board, turn):
        return True
    return False
                                
def move_to(np.int8_t[:, :] board, turn, pos0, pos1):
    '''Return a copy of ``board`` after ``turn`` plays ``pos0 -> pos1``.

    The input board is not modified. The destination becomes ``turn``'s, a
    jump (distance > 1) empties the source, and adjacent opponent pieces are
    converted. Raises ValueError for an invalid destination or a source that
    is not ``turn``'s piece.
    '''
    cdef int dr, dc, x0 = pos0[0], y0 = pos0[1], x1 = pos1[0], y1 = pos1[1]
    cdef int nx, ny
    cdef np.int8_t[:, :] new_board

    if not is_valid(board, turn, pos1):
        raise ValueError("This move: " + str((pos0, pos1)) + " of turn: " + str(turn) + " is invalid") 
    if board[x0, y0] != turn:
        raise ValueError("The starting position is not your piece")

    new_board = board.copy()
    new_board[x1, y1] = turn
    is_jump = abs(x0 - x1) > 1 or abs(y0 - y1) > 1
    if is_jump:
        new_board[x0, y0] = 0

    # infection: every neighbouring opponent piece switches side
    for dr in range(-1, 2):
        for dc in range(-1, 2):
            nx = x1 + dr
            ny = y1 + dc
            if 0 <= nx < 7 and 0 <= ny < 7 and new_board[nx, ny] == -turn:
                new_board[nx, ny] = turn
    return new_board
    
def min_max(np.int8_t[:, :] board, int turn, int target_turn, int depth=3, int alpha=-100, int beta=100, is_max=True, is_root=True):
    '''Recursive alpha-beta min-max search over the moves from next_move().

    Leaves are scored with evaluate(board, target_turn), i.e. the piece
    difference from target_turn's point of view. A node is a leaf when depth
    reaches 0 or `turn` has no legal move.

    Returns (value, move). `move` is None at leaves and at cut-off nodes.
    At the root (is_root=True) it is chosen at random among the moves that
    tied for the best value. At inner nodes it is the first best move found,
    or the ((0, 0), (0, 0)) placeholder if no child moved the bound.
    A cut-off node returns the sentinel 100 (maximizing node) or -100
    (minimizing node), which the parent then does not adopt as its new bound.
    '''
    cdef int result
    # only the root keeps all tied best moves, to break ties randomly
    if is_root:
        best_moves = []
    else:
        best_move = ((0, 0), (0, 0))
    
    if depth == 0 or not has_next_move(board, turn): # start to do pruning and selecting once the recursion reaches the end
        result = evaluate(board, target_turn)
        return result, None
    else:
        # reset this node's own bound; the opposite bound is inherited from the parent
        if is_max:
            alpha = -100
        else:
            beta = 100

        for move in next_move(board, turn):
            result, _ = min_max(move_to(board, turn, move[0], move[1]), \
                                -turn, target_turn, depth-1, alpha, beta, not is_max, False)
            # prune the search tree or update alpha and beta respectively
            if is_max:
                if result >= beta:
                    return 100, None
                elif result > alpha:
                    alpha = result
                    if is_root:
                        best_moves = [move]
                    else:
                        best_move = move
                elif result == alpha and is_root:
                    best_moves.append(move)
            else:
                if result <= alpha:
                    return -100, None
                elif result < beta:
                    beta = result
                    if is_root:
                        best_moves = [move]
                    else:
                        best_move = move
                elif result == beta and is_root:
                    best_moves.append(move)
        # the node's value is its own (updated) bound
        if is_max:
            if is_root:
                return alpha, choice(best_moves)
            else:
                return alpha, best_move
        else:
            if is_root:
                return beta, choice(best_moves)
            else:
                return beta, best_move

In [10]:
# Benchmark: play N games between the greedy player and a depth-D min-max
# player, then print the greedy player's win ratio. The greedy player's side
# is drawn at random for each game; -1 always moves first.
N = 100  # number of games
D = 1    # min-max search depth
greedy_win = 0
for _ in range(N):
    a = Ataxx()
    turn = -1
    greedy_turn = choice([-1, 1])
    while True:
        #a.plot()
        if turn == greedy_turn:
            best_move = a.get_greedy_move(turn)
            #_, best_move = min_max(a.data, turn, turn, depth=D)
        else:
            _, best_move = min_max(a.data, turn, turn, depth=D)
        a.move_to(turn, best_move[0], best_move[1])
        turn = -turn
        # evaluate() returns exactly +1 / -1 only when the side to move has no
        # legal move (game over); otherwise it returns a small scaled piece difference.
        result = a.evaluate(greedy_turn, turn)
        if result == 1:
            greedy_win += 1
            break
        elif result == -1:
            break
print("In the previous ", N, " rounds, greedy win ratio is: ", float(greedy_win) / float(N))

In the previous  100  rounds, greedy win ratio is:  0.64


## Training ideas
1. The __life_span__, i.e. how many episodes of games are self-played before the algorithm updates the target model, seems to be vital to the stability of training (using the win ratio against the simple greedy player as the performance criterion). As predicted, __a larger life_span results in more stable training, but much slower performance improvement__.
2. In the get_q function, I implemented different ways of computing q for each board. At the beginning of training, when most Q outputs are essentially random, __giving manual_q a larger contribution to the final q is very helpful__. However, __in later stages of training, manual_q should get a large weight in the final q only when we are very confident in its output__; here, confidence is taken to be related to the magnitude of manual_q as well as of init_q (the q output of the network).
3. As observed in the test results, Q seems to perform worse than P most of the time. One possible explanation is that __Q only learns the most valuable board positions and is bad at evaluating low-quality boards__. During testing, however, __Q is used directly to evaluate all next steps, without the assistance of P__, which is responsible for filtering out useless moves. When a bad board that was never seen during training is given a rather large Q, the bad move gets selected. This is supported by the fact that when BOTH p and q are used in testing, the accuracy rises significantly __(from 78% for q alone and 82% for p alone to 87.2% for both)__. This result combines the contributions of p and q, but q is believed to contribute more, since its contribution grows significantly with the number of rollouts (400 in this experiment).

In [11]:
@nb.jit(nopython=True)
def for_node(c, parent_n_visit, n_visit, p, q):
    '''Selection score of a child during tree search.

    exploration = c * p * sqrt(parent_n_visit + 1) / (n_visit + 1)
    exploitation = q / (n_visit + 1)
    score = exploration - exploitation
    '''
    visits = n_visit + 1
    exploration = c * p * sqrt(parent_n_visit + 1) / visits
    exploitation = q / visits
    return exploration - exploitation

@nb.jit(nopython=True)
def get_q(init_q, manual_q, mode):
    '''Combine the learnt q (init_q) with the hand-crafted q (manual_q).

    mode 0: manual_q only
    mode 1: init_q only
    mode 2: fixed blend 0.75 * manual_q + 0.25 * init_q
    mode 3: confidence-based blend. Trust manual_q when |manual_q| >= 0.8,
            average when |manual_q| >= 0.5, trust init_q when |init_q| > 0.5,
            0.2/0.8 blend when |init_q| > 0.15, otherwise 0.4/0.6 blend
    mode 4: trust manual_q when |manual_q| >= 0.8, average when
            |manual_q| >= 0.5, otherwise init_q
    Raises ValueError for any other mode.
    '''
    if mode == 0:
        return manual_q
    elif mode == 1:
        return init_q
    elif mode == 2:
        return 0.75 * manual_q + 0.25 * init_q
    elif mode == 3:
        # Stricter thresholds must be tested first; otherwise the looser
        # tests (0.5 / 0.15) shadow them and those branches can never run.
        if abs(manual_q) >= 0.8:
            return manual_q
        elif abs(manual_q) >= 0.5:
            return 0.5 * manual_q + 0.5 * init_q
        elif abs(init_q) > 0.5:
            return init_q
        elif abs(init_q) > 0.15:
            return 0.2 * manual_q + 0.8 * init_q
        else:
            return 0.4 * manual_q + 0.6 * init_q
    elif mode == 4:
        if abs(manual_q) >= 0.8:
            return manual_q
        elif abs(manual_q) >= 0.5:
            return 0.5 * manual_q + 0.5 * init_q
        else:
            return init_q
    else:
        raise ValueError("Mode is not specified")


@nb.jit(nopython=True)
def recover_q(q, manual_q):
    '''Identity on ``q``: the result is ``q`` unchanged and ``manual_q`` is ignored.'''
    recovered = q
    return recovered

class TreeNode:
    '''One node of the search tree.

    Children are stored in ``_children`` under the move that leads to them.
    Equivalent duplicate moves live in ``_corr_dict``, which maps each one to
    the key actually stored in ``_children``.
    '''

    def __init__(self, parent, p=0.0):
        self._parent = parent  # parent TreeNode, or None for the root
        self._children = {}  # move -> child TreeNode
        self._corr_dict = {}  # duplicated move -> equivalent key of _children
        self._n_visit = 0
        # from the parent perspective
        self._q = 0.0
        self._manual_q = -5  # hand-crafted q; -5 means "not set yet"
        self._init_q = -5  # learnt q; -5 means "not set yet"
        self._p = p  # prior probability
        self._action_mask = None
        self._feature_map = None
        self._board = None
        self._is_expanded = False
        self._prev_move = None

    def __str__(self):
        return "_n_visit: {}, _q: {}, _p: {}, _children: \n{}".format(
            self._n_visit, self._q, self._p, self._children)

    def get_start_q(self, mode=0):
        '''Initialise ``_q`` from ``_init_q`` and ``_manual_q`` through ``get_q``.

        mode 0: pure manual Q, mode 1: pure network Q, mode 2: fixed hybrid.
        Modes 3 and 4 are the confidence-based blends implemented in get_q.
        Both estimates must already be set (not -5) and ``_q`` must still be 0.
        '''
        assert self._init_q != -5 and self._manual_q != -5
        assert self._q == 0
        self._q = get_q(self._init_q, self._manual_q, mode)

    def access_children(self, move):
        '''Return the child for ``move``, resolving duplicate moves via ``_corr_dict``.

        Raises KeyError if the move is unknown.
        '''
        try:
            return self._children[move]
        except KeyError:
            return self._children[self._corr_dict[move]]

    def children_generator(self):
        '''Yield ``(move, child)`` for every stored move, then for every duplicate move.'''
        for move, child in self._children.items():
            yield (move, child)
        for move, stored_move in self._corr_dict.items():
            yield (move, self._children[stored_move])

    def update_all(self, t_v):
        '''Back up ``t_v`` from this node to the root.

        Each node gets ``_q += t_v`` and one more visit, and the sign of
        ``t_v`` flips at every level.
        '''
        node = self
        while node is not None:
            node._q += t_v
            node._n_visit += 1
            node = node._parent
            t_v = -t_v

    @staticmethod
    def get_search_value(parent, node, c):
        '''Selection score of ``node`` under ``parent`` (see for_node).

        Prints both nodes before re-raising if scoring fails.
        '''
        try:
            return for_node(c, parent._n_visit, node._n_visit, node._p, node._q)
        except Exception:
            print(parent)
            print(node)
            raise

    @staticmethod
    def get_frequency_value(node):
        '''Visit count of ``node``; 0 for objects without ``_n_visit`` (e.g. None).'''
        try:
            return node._n_visit
        except AttributeError:
            return 0

    def select(self, c):
        '''Return ``[move, child]`` for the child with the highest search value.'''
        return list(max(self._children.items(),
                        key=lambda item: self.get_search_value(self, item[1], c)))

    def get_action_mask(self):
        '''Return the stored action mask; raise ValueError if none was set.'''
        if self._action_mask is not None:
            return self._action_mask
        raise ValueError("No action mask, request failure")

    def get_action_frequency_map(self, temp=1):
        '''Visit-count policy target over ``policy_dict`` indices.

        Each child's (and each duplicate move's) entry is
        ``(n_visit / 100) ** (1 / temp)``, and the vector is normalised to sum
        to 1. NOTE(review): if every count is zero the division yields NaN.
        '''
        out = np.zeros(len(policy_dict))
        for move, child in self.children_generator():
            out[policy_dict[move]] = (float(child._n_visit) / 100) ** (1 / temp)
        out /= out.sum()
        return out

In [12]:
class PolicyValueNetwork():
    """Keras policy/value network pair used by the MCTS.

    ``_model`` is the network callers train (they call ``fit`` on it directly);
    ``_target_model`` has the same architecture, is what ``predict`` queries by
    default, and is moved towards ``_model`` by ``update_target_model``.  Both
    models live in a dedicated TensorFlow session.
    """

    # checkpoint locations shared by loading (``__init__``) and ``save``
    _MODEL_PATH = 'MCTS_POLICY_MODEL/AtaxxZero_model.h5'
    _TARGET_MODEL_PATH = 'MCTS_POLICY_MODEL/AtaxxZero_target_model.h5'
    _PRETRAIN_MODEL_PATH = 'MCTS_POLICY_MODEL/pretrain_AtaxxZero_model.h5'
    _PRETRAIN_TARGET_MODEL_PATH = 'MCTS_POLICY_MODEL/pretrain_AtaxxZero_target_model.h5'

    @classmethod
    def _checkpoint_paths(cls, is_pretrain):
        """Return ``(model_path, target_model_path)`` of the regular or pretrain checkpoint."""
        if is_pretrain:
            return cls._PRETRAIN_MODEL_PATH, cls._PRETRAIN_TARGET_MODEL_PATH
        return cls._MODEL_PATH, cls._TARGET_MODEL_PATH

    def __init__(self, lr=None, is_load_model=False, is_load_pretrain_model=False, gpu=None, verbose=False):
        """Create the session and either load or build both models.

        Args:
            lr: learning rate.  Required when building new models; when loading,
                the trainable model is recompiled only if ``lr`` is given.
            is_load_model: load both models from disk instead of building them.
            is_load_pretrain_model: together with ``is_load_model``, load the
                pretrain checkpoint instead of the regular one.
            gpu: ``None`` for CPU, otherwise a GPU index; ``True`` is accepted
                as an alias for GPU 0.
            verbose: print the model summary.

        Raises:
            ValueError: if new models are to be built but ``lr`` is None.
        """
        if gpu is None:
            self._sess = tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=24))
            self._device = "cpu"
        else:
            config = tf.ConfigProto(log_device_placement=True)
            config.gpu_options.allow_growth = True
            self._sess = tf.Session(config=config)
            # MCTS passes gpu=True by default; "gpu:True" is not a parseable
            # TensorFlow device index, so treat True as GPU 0
            self._device = "gpu:" + str(0 if gpu is True else gpu)
        K.set_session(self._sess)

        if is_load_model:
            model_path, target_path = self._checkpoint_paths(is_load_pretrain_model)
            self._model = load_model(model_path)
            self._target_model = load_model(target_path)
            if is_load_pretrain_model:
                print("successfully loaded two pretrained models")
            else:
                print("successfully loaded two models")
            if lr is not None:
                self.update_learning_rate(lr)
        else:
            if lr is None:
                raise ValueError("a learning rate is required to build new models")
            self._lr = lr
            self._model = self.create_model()
            self._target_model = self.create_model()
            self._sess.run(tf.global_variables_initializer())
            print("new models generated")

        # the two models are deliberately NOT synchronised here
        # (self.update_target_model(T=1) would copy _model into _target_model)
        if verbose:
            self._model.summary()

    def update_learning_rate(self, lr):
        """Recompile ``_model`` with Adam at learning rate ``lr`` (decay 1e-6).

        Only ``_model`` is (re)compiled; ``_target_model`` is left untouched.
        """
        if hasattr(self, "_lr"):
            print("learning rate updated from {} to {}".format(self._lr, lr))
        else:
            print("compile new learning rate {}".format(lr))
        self._lr = lr
        self._model.compile(loss=['categorical_crossentropy', 'mse'], optimizer=Adam(lr=self._lr, decay=1e-6),
                            loss_weights=[1, 1])

    def create_model(self):
        """Build and compile one policy/value network.

        Inputs: a board feature map of shape (6, 9, 9) and a move mask of
        length 792.  Outputs: a 792-way softmax policy and a tanh value.
        Layers are placed on ``self._device``; the model is compiled with the
        current ``self._lr``.
        """
        assert K.backend() == 'tensorflow'

        def res_block(res_in):
            """Two 3x3 conv + batch-norm layers with an identity shortcut."""
            x = Conv2D(64, (3, 3), padding='same', kernel_regularizer=l2(1e-4))(res_in)
            x = BatchNormalization(axis=1)(x)
            x = Activation('relu')(x)
            x = Conv2D(64, (3, 3), padding='same', kernel_regularizer=l2(1e-4))(x)
            x = BatchNormalization(axis=1)(x)
            x = add(inputs=[x, res_in])
            x = Activation('relu')(x)
            return x

        with tf.device(self._device):
            board_input = Input((6, 9, 9))
            mask_input = Input((792, ))
            # NOTE(review): BatchNormalization(axis=1) normalises the first
            # non-batch axis, which is the channel axis only when Keras runs
            # channels_first -- confirm the configured image_data_format
            x = Conv2D(64, (3, 3), padding='valid', kernel_regularizer=l2(1e-4))(board_input)
            x = BatchNormalization(axis=1)(x)
            x = Activation('relu')(x)
            for _ in range(1):  # number of residual blocks
                x = res_block(x)
            y = x

            # policy head: 2 filters as in AlphaGo Zero, although Ataxx has 792
            # policy entries compared to the 362 of Go
            x = Conv2D(2, (1, 1), kernel_regularizer=l2(1e-4))(y)
            x = BatchNormalization(axis=1)(x)
            x = Activation('relu')(x)
            x = Flatten()(x)
            x = Dense(792, activation='softplus', kernel_regularizer=l2(1e-4))(x)
            # multiplying by the mask zeroes the logits of masked moves; softplus
            # keeps the others positive, so (assuming a 0/1 mask) illegal moves
            # get the smallest -- but still non-zero -- softmax probability.
            # Kept as is so existing checkpoints stay loadable.
            x = multiply(inputs=[x, mask_input])
            action_output = Activation('softmax')(x)

            # value head
            x = Conv2D(1, (1, 1), kernel_regularizer=l2(1e-4))(y)
            x = BatchNormalization(axis=1)(x)
            x = Activation('relu')(x)
            x = Flatten()(x)
            x = Dense(128, activation='relu', kernel_regularizer=l2(1e-4))(x)
            value_output = Dense(1, activation='tanh')(x)

            model = Model(input=[board_input, mask_input], output=[action_output, value_output])

        model.compile(loss=['categorical_crossentropy', 'mse'], optimizer=Adam(lr=self._lr, decay=1e-6),
                      loss_weights=[1, 1])

        return model

    def update_target_model(self, T=0.5):
        """Soft-update the target model: ``target = T * model + (1 - T) * target``.

        The concept is borrowed from DDPG; ``T=1`` copies ``_model`` into
        ``_target_model`` and smaller values move it only part of the way.
        """
        blended = [T * w + (1 - T) * tw
                   for w, tw in zip(self._model.get_weights(), self._target_model.get_weights())]
        self._target_model.set_weights(blended)
        print("\n\ntarget model updated by T:", T, "\n\n")

    def predict(self, feature_map, action_mask, is_target=True):
        """Evaluate positions with the target model (default) or the trained model.

        Args:
            feature_map: array reshapeable to (-1, 6, 9, 9).
            action_mask: array reshapeable to (-1, 792).
            is_target: query ``_target_model`` if True, else ``_model``.

        Returns:
            ``[policy, value]`` batches as returned by ``Session.run`` on the
            model outputs, evaluated with learning phase 0 (inference).
        """
        model = self._target_model if is_target else self._model
        return self._sess.run(model.outputs, feed_dict={
            model.inputs[0]: feature_map.reshape(-1, 6, 9, 9),
            model.inputs[1]: action_mask.reshape(-1, 792),
            K.learning_phase(): 0})

    def save(self, is_pretrain=False):
        """Write both models to the regular or the pretrain checkpoint files."""
        model_path, target_path = self._checkpoint_paths(is_pretrain)
        self._model.save(model_path)
        self._target_model.save(target_path)

In [13]:
class Relay():
    """Replay buffer of self-play batches.

    Each ``append`` stores one batch under an increasing counter.  A batch
    expires after ``int(life_span * 4)`` further appends.  ``get`` concatenates
    all stored batches and optionally draws a random subset.
    """

    def __init__(self, life_span=80):
        self._relay_dict = {}
        self._counter = 0
        # number of appends a batch survives
        self._life_span = int(life_span * 4)

    def append(self, train_data):
        """Store one batch of samples.

        Args:
            train_data: sequence of samples, each with four fields
                (feature map, action mask, policy target, value target).  The
                fields are stacked along a new axis 0.  An empty batch is
                ignored (it has no fields to stack).
        """
        if len(train_data) == 0:
            return
        fields = [np.stack(field, axis=0) for field in zip(*train_data)]
        self._relay_dict[self._counter] = fields
        self._counter += 1

        # keys are consecutive counters, so exactly one batch -- the one stored
        # life_span + 1 appends ago -- can expire per append
        self._relay_dict.pop(self._counter - self._life_span - 1, None)

    def get(self, n_data=None):
        """Return all stored samples, or ``n_data`` of them drawn without replacement.

        Args:
            n_data: number of samples wanted; ``None`` or a value not smaller
                than the number stored returns everything.

        Returns:
            list of four arrays whose rows stay aligned across fields.

        Raises:
            ValueError: if the buffer is empty (raised by ``np.concatenate``).
        """
        batches = list(self._relay_dict.values())
        all_data = [np.concatenate([batch[i] for batch in batches], axis=0) for i in range(4)]
        length = all_data[0].shape[0]

        if n_data is None or n_data >= length:
            out = all_data
        else:
            picked = np.random.choice(length, size=n_data, replace=False)
            out = [field[picked] for field in all_data]

        print(length, " data in database")
        print(n_data, " data expected")
        print(out[0].shape[0], " data grabbed")
        return out

In [14]:
class MCTS():
    def __init__(self, c=1, dep_lim=10, lr=1e-4, life_span=10, is_load_model=False, is_load_pretrain_model=False, gpu=True):
        """Set up the game, the policy/value network and an initialised root.

        Args:
            c: exploration constant used when selecting children.
            dep_lim: maximum number of steps of one rollout.
            lr: learning rate of the network.
            life_span: not used by this constructor.
            is_load_model, is_load_pretrain_model, gpu: forwarded to
                ``PolicyValueNetwork``.
        """
        self._c, self._dep_lim = c, dep_lim
        self._game = Ataxx()
        self._turn = -1
        self._network = PolicyValueNetwork(lr, is_load_model, is_load_pretrain_model, gpu=gpu)
        self._lr = lr
        # evaluation mode handed to TreeNode.get_start_q; stays 0 unless changed by hand
        self._mode = 0
        # fresh root for the current position (also stored as _root_store for resets)
        self.reset_root()
        
        
    def reset_root(self, move=None):
        """Replace the root with a freshly initialised node for the current position.

        Args:
            move: the move that led to the current position; stored as the
                root's ``_prev_move`` and used for its feature map.
        """
        root = self._root = TreeNode(None)
        self.further_init(root, self._game, self._turn, prev_move=move, get_p_array=True)
        self._root_store = root
        
    def reset(self, left_space=45):
        """Start a new game, optionally fast-forwarded to a partly filled board.

        With ``left_space < 45`` both sides play random (20%) or greedy (80%)
        moves until at most ``left_space`` empty squares remain.  If the game
        ends during this warm-up, the warm-up is retried from scratch.  The
        root is then re-initialised and, if a warm-up move was played, told
        which move led to it.

        Args:
            left_space: target number of empty squares; may be a float.  45
                appears to be the empty-square count of the opening position.
        """
        while True:
            self._game.reset()
            self._turn = -1
            self._root = TreeNode(None)
            best_move = None
            if left_space >= 45:
                break
            is_terminal = False
            empty_squares = 45
            while not is_terminal and empty_squares > left_space:
                if np.random.random() < 0.2:  # 20 percent random, 80 percent greedy
                    best_move = choice(self._game.get_moves(self._turn))
                else:
                    best_move = self._game.get_greedy_move(self._turn)
                self.make_a_move(best_move)
                is_terminal = abs(self._game.evaluate(1, self._turn)) == 1
                empty_squares = (np.array(self._game.data) == 0).sum()
            if not is_terminal:
                break
            # retry instead of recursing: the recursive version re-ran
            # reset_root and tagged the root with the failed attempt's move
            print("reset failure, do reset again")

        self.reset_root()
        if best_move is not None:
            # tell the root which move led it here
            self._root._move = best_move
        
    def plot_move_visit_freq(self):
        """Print per-move search statistics of the root, most visited first.

        For every expanded child this prints the visit count, the average value
        ``_q / (_n_visit + 1)``, the network q and the manual q (all three
        negated, presumably to show them from the root player's side) and the
        prior ``p``; then the sum of the printed priors.  Children that are
        still bare float32 priors have none of these attributes and are skipped.
        """
        ranked = sorted(self._root._children.items(),
                        key=lambda item: self._root.get_frequency_value(item[1]), reverse=True)
        p_sum = 0
        for move, child in ranked:
            try:
                line = "{}: n_v:{:>6d} q_all:{:+06.6f} q:{:+06.6f} q_m:{:+06.6f} p:{:06.6f}"\
                    .format(move, child._n_visit, -child._q / (child._n_visit + 1),
                            -child._init_q, -child._manual_q, child._p)
            except (AttributeError, TypeError, ValueError):
                # best effort: unexpanded or incompletely initialised child
                continue
            print(line)
            p_sum += child._p
        print("########################p_sum is: ", p_sum)
                      
    def get_next_move(self, is_best=False, is_dirichlet=True, rollout_times=100, t_lim=np.nan):
        """Search from the root and choose a move.

        Args:
            is_best: pick the most visited move.
            is_dirichlet: otherwise, sample from 0.75 * the sharpened visit
                distribution (temp=1e-2) + 0.25 * Dirichlet(0.3) noise,
                restricted to the root's action mask.  If both flags are False,
                sample proportionally to the visit counts.
            rollout_times, t_lim: forwarded to ``rollout``.

        Returns:
            the chosen move, taken from ``policy_list``.
        """
        self.rollout(rollout_times, t_lim)

        # the tree no longer changes, so the visit map only needs computing once
        freq_map = self._root.get_action_frequency_map()
        top_index = np.argmax(freq_map)
        if is_best:
            index = top_index
        elif is_dirichlet:
            prob = (0.75*self._root.get_action_frequency_map(temp=1e-2)
                    + 0.25*np.random.dirichlet(0.3*np.ones(792))) * self._root._action_mask
            index = np.random.choice(range(792), p=prob / prob.sum())
        else:
            index = np.random.choice(range(792), p=freq_map)

        if index != top_index:
            print("\n\nThis is a random move\n\n")

        return policy_list[index]
    
    def make_a_move(self, next_move):
        """Play ``next_move`` for the side to move and advance the root.

        The matching child becomes the new root when it is an already built
        node; it is detached from its parent so back-propagation stops there.
        When the root has no children, or the child is still only a float32
        prior, a blank TreeNode becomes the root instead.  Afterwards the turn
        passes to the other player.
        """
        if self._root._children == {}:
            self._root = TreeNode(None)
        else:
            child = self._root.access_children(next_move)
            if type(child) is np.float32:
                # only a prior so far: nothing reusable below this move
                self._root = TreeNode(None)
            else:
                self._root = child
                self._root._parent = None

        self._game.move_to(self._turn, next_move[0], next_move[1])
        self._turn *= -1
      
    def further_init(self, node, game, turn, prev_move=None, get_p_array=False):
        """Fill in a freshly created node from the current ``game`` position.

        Stores ``prev_move`` and the legal-move bookkeeping returned by
        ``game.get_moves`` (``_corr_dict``, ``_children``, ``_action_mask``).
        For a position without children the node is made terminal: ``_q`` is
        set to +1 if the manual evaluation is positive and -1 otherwise, and
        nothing else is computed.  Otherwise the feature map and a copy of the
        board are stored and, when ``get_p_array`` is set, the network supplies
        the child priors and the node's initial q.

        Args:
            node: TreeNode to initialise (mutated in place).
            game: Ataxx game positioned at ``node``'s state.
            turn: player to move at ``node``.
            prev_move: move that led to this position (also fed to the feature map).
            get_p_array: if False, priors and q are left for the caller to fill
                in (``expand`` does this in one batched prediction).
        """
        global policy_dict, policy_list
        node._prev_move = prev_move # tell which move led the node here
        # preparing all children
        new_moves, node._corr_dict, node._children, node._action_mask \
                                = game.get_moves(turn, return_node_info=True)
        # if meet end of the game, generate manual q
        if node._children == {}:
            # -5 appears to be the "not computed yet" sentinel for _manual_q
            if node._manual_q == -5:
                node._manual_q = game.get_manual_q(turn, game.data)
            # NOTE: the sign convention here must be read from the parent's perspective
            if node._manual_q > 0:
                node._q = 1
            else:
                node._q = -1 
            return
        
        # generate feature map
        node._feature_map = game.get_feature_map(turn, prev_move)
        node._board = game.data.copy()
        # if required, generate p array and q, only if there are children
        if get_p_array:
            # generate policy prob array
            out = self._network.predict(node._feature_map, node._action_mask)
            p_array = out[0]
            # give p to each child (float32); the entry is replaced by a TreeNode on expansion
            for move in new_moves:
                node._children[move] = p_array[0][policy_dict[move]]
            # init node._q
            node._init_q = out[1][0][0] 
            node._manual_q = game.get_manual_q(turn, game.data)
            node.get_start_q(self._mode)
        
    def expand(self, node, game, turn):
        """Expand ``node``: replace every child prior by an initialised TreeNode.

        A node that was never initialised (no children and ``_q == 0``) is
        first initialised with network priors; terminal nodes are then left
        alone.  Otherwise each child move is played on ``game``, a TreeNode is
        created with the stored prior as ``p`` and initialised without a
        network call, and the board is restored.  One batched prediction then
        supplies the initial q and the child priors of every new non-terminal
        child, together with its manual evaluation.

        Args:
            node: TreeNode to expand (mutated; ``_is_expanded`` becomes True).
            game: Ataxx positioned at ``node``; restored after each trial move.
            turn: player to move at ``node``.
        """
        global policy_dict, policy_list
        
        # if the node was not expanded, take that as a new root and further init it
        # NOTE(review): relies on initialised non-terminal nodes having children,
        # so only a bare TreeNode matches this test
        if node._children == {} and node._q == 0:
            self.further_init(node, game, turn, get_p_array=True)
        # if end of game, quit expanding (further_init gave terminal nodes _q = +/-1)
        if node._children == {}:
            assert node._q != 0
            return
        
        # update expanded state
        node._is_expanded = True
        
        # if there are children
        backup_board = game.data.copy() # warning, to backup a memoryview ndarray, use copy()
        index_list = []  # new non-terminal children, in the same order as the batch inputs below
        feature_map = []
        action_mask = []
        boards = []
        for move in node._children:
            tmp = node._children[move]
            try:
                # before expansion every child entry must still be a float32 prior
                assert type(tmp) is np.float32
            except:
                print(type(tmp))
                raise
            new_node = TreeNode(node, p=node._children[move])
            game.move_to(turn,  move[0], move[1])
            self.further_init(new_node, game, -turn, move, get_p_array=False)
            # only the value changes, so replacing it while iterating the keys is safe
            node._children[move] = new_node
            # prepare to calculate p for new_node only if it has children
            if new_node._children != {}:
                index_list.append(new_node)
                feature_map.append(new_node._feature_map)
                action_mask.append(new_node._action_mask)
                boards.append(new_node._board)
            # reset the gamer
            game.reset(board=backup_board)
        # if there are no more node that is expandable, quit
        if len(index_list) == 0:
            return
        
        # do batch prediction
        # print("batch size:", len(index_list))
        feature_map = np.stack(feature_map, axis=0)
        action_mask = np.stack(action_mask, axis=0)
        out = self._network.predict(feature_map, action_mask)
        # get batch manual q (``boards`` is reused to hold the manual q values)
        boards = [game.get_manual_q(-turn, board) for board in boards]
        # update the result to each child node
        for i, child in enumerate(index_list):
            # assign q; values are stored as computed, plot_move_visit_freq negates them for display
            child._manual_q = boards[i] # neg for display use
            child._init_q = out[1][i][0] # same as above
            child.get_start_q(self._mode)
            # assign p (presumably assign_children writes each child's prior from the policy row — TODO confirm)
            assign_children(child._children, out[0][i])
            

    def rollout(self, rollout_times=100, t_lim=np.nan, t_min=2):
        """Run MCTS simulations from ``self._root``.

        Each simulation starts from ``self._root`` on a new ``Ataxx`` built from
        the current board (presumably a copy — TODO confirm).  It expands a node
        on first visit and descends via ``select(self._c)``.  It stops at a node
        without children or after ``self._dep_lim`` steps, then back-propagates
        that node's ``_q / (_n_visit + 1)`` with ``update_all``.

        Stopping rules:
        - at most ``int(rollout_times * 1.1)`` simulations;
        - earlier once ``0.999 * t_lim`` seconds have elapsed (a NaN ``t_lim``
          never triggers, since NaN comparisons are False);
        - earlier once more than ``rollout_times`` simulations have run and at
          least ``t_min`` seconds have elapsed.
        """
        start = time.time()
        for i in range(int(rollout_times*1.1)): 
            tmp_node = self._root
            tmp_game = Ataxx(self._game.data)
            tmp_turn = self._turn
            # start mcts
            step = 0
            while True:
                assert self._dep_lim > 0
                if step < self._dep_lim:
                    # expand the node only when it has never been expanded
                    if tmp_node._is_expanded == False:
                        self.expand(tmp_node, tmp_game, tmp_turn)

                    # check if is leaf node, if so, update the whole tree
                    if tmp_node._children == {}:
                        # average of _q; the +1 presumably counts the initial evaluation stored in _q
                        t_v = tmp_node._q / (tmp_node._n_visit + 1)
                        tmp_node.update_all(t_v)
                        break
                    else:
                        # select a child and continue exploration
                        next_move, next_node = tmp_node.select(self._c)
                            
                        # move to next move and next node
                        tmp_game.move_to(tmp_turn, next_move[0], next_move[1])    
                        tmp_node = next_node
                        tmp_turn = -tmp_turn
                else:
                    # depth limit reached: back up the current estimate instead of searching deeper
                    t_v = tmp_node._q / (tmp_node._n_visit + 1)
                    tmp_node.update_all(t_v)
                    break
                # update steps                                    
                step += 1
            cur_time = time.time() - start
            if cur_time > t_lim * 0.999:
                print("due to time lim, final rollout times: ", i, "time elapsed: ", cur_time)
                break
            
            if cur_time > t_min and i > rollout_times:
                print("due to rollout lim, final rollout times: ", i, "time elapsed: ", cur_time)
                break
                
        
    def testing_against_min_max(self, rounds=5, left_space=45, mm_dep=1, c=5, dep_lim=0, rollout_times=400, t_lim=6, verbose=True):
        """Play ``rounds`` games against ``min_max`` and return this agent's score ratio.

        For the duration of the test ``self._dep_lim`` and ``self._c`` are set
        to ``dep_lim`` and ``c``; they are restored afterwards, also when a game
        raises.  Each game starts from ``self.reset(left_space)`` with this
        agent on a random side.  The agent moves with
        ``get_next_move(is_best=True, ...)`` and the opponent with
        ``min_max(..., mm_dep)``.  A game still running after 300 plies is
        abandoned and scores half a win.

        NOTE(review): ``rollout`` asserts ``self._dep_lim > 0``, so the default
        ``dep_lim=0`` fails on this agent's first move; ``tester`` always
        passes ``dep_lim`` explicitly.

        Returns:
            (wins + 0.5 per abandoned game) / ``rounds``, or 0 if that score is 0.
        """
        max_steps = 300
        print("####               ####")
        print("#### start testing ####")
        test_start = time.time()
        store_dep_lim, store_c = self._dep_lim, self._c
        self._dep_lim, self._c = dep_lim, c
        n_win = 0.0
        win_steps = 0.0
        lose_steps = 0.0
        try:
            for r in range(rounds):
                tmp_round_s = time.time()
                self.reset(left_space)
                my_turn = choice([-1, 1])
                if verbose:
                    print("round:", r+1)
                    print("this game start with {} space left".format(left_space))
                    print("self takes turn: ", my_turn)
                steps = 0
                while abs(self._game.evaluate(1, self._turn)) != 1:
                    if verbose:
                        self._game.plot()
                        tmp_s = time.time()
                    if self._turn == my_turn:
                        best_move = self.get_next_move(is_best=True, rollout_times=rollout_times, t_lim=t_lim)
                        if verbose:
                            print("self turn", my_turn)
                            self.plot_move_visit_freq()
                    else:
                        _, best_move = min_max(self._game.data, self._turn, self._turn, mm_dep)
                        if verbose:
                            print("greedy turn", self._turn)
                    if verbose:
                        print("this move takes time(s): ", time.time()-tmp_s)
                        print("chosen move is ", best_move)

                    # synchronize steps and boards
                    self.make_a_move(best_move)
                    steps += 1
                    if steps > max_steps:
                        print("steps over {}, game skip".format(max_steps))
                        break
                if steps <= max_steps:
                    is_self_win = self._game.evaluate(my_turn, self._turn) == 1
                    if is_self_win:
                        n_win += 1
                        win_steps += steps
                    else:
                        lose_steps += steps
                    if verbose:
                        print("this round has steps: {}, time taken: {}, \n\n\nself wins? {}\n\n\n".format(steps, time.time()-tmp_round_s, is_self_win))
                else:
                    n_win += 0.5
        finally:
            # restore search parameters even if a game raised
            self._dep_lim = store_dep_lim
            self._c = store_c

        print("testing took time: ", time.time()-test_start)
        print("win steps: ", win_steps / (n_win + 1e-5), "lose steps: ", lose_steps / (rounds - n_win + 1e-5))
        print()
        if n_win == 0:
            return 0
        return n_win / rounds
    
    def tester(self, mm_dep=1, Q=False, P=False, BOTH=False, mode=0, times=200, dep_lim=1, rollout_times=400, verbose=False):
        out = {}
        mode_store = self._mode
        self._mode = mode
        print("mm_dep is: ", mm_dep)
        if Q:
            q_ratio = self.testing_against_min_max(\
                rounds=times, left_space=45, mm_dep=mm_dep, c=0, dep_lim=dep_lim, rollout_times=1, t_lim=6, verbose=verbose)
            print("\n\n\n                        win ratio of Q is {} \n\n\n\n\n".format(q_ratio))
            out['Q'] = q_ratio
        if P:
            p_ratio = self.testing_against_min_max(\
                rounds=times, left_space=45, mm_dep=mm_dep, c=1000000, dep_lim=dep_lim, rollout_times=1, t_lim=6, verbose=verbose)
            print("\n\n\n                        win ratio of P is {} \n\n\n\n\n".format(p_ratio))
            out['P'] = p_ratio
        if BOTH: # multiple customizations for this one
            both_ratio = self.testing_against_min_max(\
                rounds=int(times), left_space=45, mm_dep=mm_dep, c=self._c, dep_lim=dep_lim, rollout_times=rollout_times, t_lim=6, verbose=verbose)
            print("\n\n\n                        win ratio of both is {} \n\n\n\n\n".format(both_ratio))
            out['BOTH'] = both_ratio
        self._mode = mode_store
        return out
             
    def data_collector(self, node, visit_min, train_mode=0):
        """Turn a searched node and its well-visited children into training samples.

        Each sample is ``[feature_map, action_mask, policy_target, q_target]``
        and is expanded with ``augment_data`` before being collected.

        * ``node`` contributes only if visited at least ``visit_min`` times.
          With ``train_mode=0`` it contributes its mask and visit-frequency
          policy; with ``train_mode=1`` the mask and policy are zeroed so only
          q is trained.
        * Every child visited at least ``0.625 * visit_min`` times contributes a
          q-only sample.  Unexpanded children (bare float32 priors) are skipped.

        Raises:
            ValueError: unknown ``train_mode`` (checked only when ``node``
                itself qualifies).
        """
        out = []
        if node._n_visit >= visit_min:
            if train_mode == 0:  # train both p and q
                mask, policy = node._action_mask, node.get_action_frequency_map()
            elif train_mode == 1:  # train only q
                print("Notice: train_mode is q train")
                mask, policy = np.zeros(792), np.zeros(792)
            else:
                raise ValueError("train_mode not understood")
            # to recover what policy q should be
            q_target = recover_q(node._q / (node._n_visit + 1), node._manual_q)
            out.extend(augment_data([node._feature_map, mask, policy, q_target]))

        # Augmenting with children raises the entropy of P (less peaked policy
        # targets) but has improved Q, so children only contribute q targets
        # (mask and policy set to 0).
        child_visit_min = visit_min * 0.625
        for child in node._children.values():
            # explicit check instead of `assert`, which `python -O` strips
            n_visit = getattr(child, "_n_visit", None)
            if n_visit is None or n_visit < child_visit_min:
                continue
            try:
                tmp_data = [child._feature_map,
                            np.zeros(792),
                            np.zeros(792),
                            recover_q(child._q / (child._n_visit + 1), child._manual_q)]
                out.extend(augment_data(tmp_data))
            except Exception:
                # best effort as before: e.g. terminal children are not given a
                # feature map by further_init, so skip what cannot be built
                continue
        print("no. of data collected: ", len(out))
        return out
        
                
    def self_play(self, rollout_times=100, t_lim=np.nan, verbose=True, is_best=False, train_mode=0):
        """Play one game against itself from the current root and collect training data.

        The game runs until the root has no children or 250 plies were played.
        The first 8 plies sample by visit frequency, later plies add Dirichlet
        noise (unless ``is_best``).  After each search the root and its
        well-visited children are harvested via ``data_collector`` with
        ``visit_min=rollout_times``.

        Returns:
            list of augmented training samples.
        """
        samples = []
        ply = 0
        print("start new self play")
        started = time.time()
        visit_min = rollout_times
        while self._root._children != {} and ply < 250:
            print(self._turn, "'s turn, step no. ", ply)
            move_started = time.time()
            # noise-free sampling for the first plies generates varied openings
            best_move = self.get_next_move(is_best=is_best, is_dirichlet=ply >= 8,
                                           rollout_times=rollout_times, t_lim=t_lim)
            print("one move takes time(s): ", time.time() - move_started)
            samples.extend(self.data_collector(self._root, visit_min, train_mode=train_mode))
            if verbose:
                self._game.plot()
                self.plot_move_visit_freq()
            self.make_a_move(best_move)
            ply += 1
        print("this self play has {} steps, time elapsed {}".format(ply, time.time() - started))
        print("winner is", np.sign(self._root._q * self._turn))

        return samples
    
    def zero_out_pretraining(self): # zero out q output
        """Pre-train the value head towards 0 on greedy/random self-play positions.

        Games mix greedy and random moves: random with probability 0.5 for the
        first 10 plies and 0.25 afterwards.  Games are generated until at least
        500000 augmented samples exist.  Every sample has a zero mask, a zero
        policy and value target 0.  ``_model`` is then fit with early stopping,
        blended into the target model with the default ``T``, and both are
        saved as the pretrain checkpoint.
        """
        samples = []
        while len(samples) < 500000:
            game = Ataxx()
            ply, turn = 0, -1
            while abs(game.evaluate(1, turn)) != 1:
                random_share = 0.5 if ply < 10 else 0.25
                if np.random.random() > random_share:
                    move = game.get_greedy_move(turn)
                else:
                    move = choice(game.get_moves(turn))

                # play the move, then record the position seen by the next player
                game.move_to(turn, move[0], move[1])
                turn = -turn
                sample = [game.get_feature_map(turn, move), np.zeros(792), np.zeros(792), 0]
                samples.extend(augment_data(sample))
                ply += 1
            print(len(samples))

        columns = [np.stack(column, axis=0) for column in list(zip(*samples))[:4]]
        es = EarlyStopping(monitor='val_loss', min_delta=1e-3, patience=3, verbose=1, mode='auto')
        self._network._model.fit(x=[columns[0], columns[1]], y=[columns[2], columns[3]], verbose=1,
                                 batch_size=1024, epochs=1024, shuffle=True, validation_split=0.1, callbacks=[es])
        self._network.update_target_model()
        self._network.save(is_pretrain=True)
    
    def reinforcement_pretraining(self, rounds=1000, verbose=True): # getting Q value close to manual Q
        """Pre-train the value head towards the manual (engineered) q.

        Plays ``rounds`` games mixing greedy and random moves: random with
        probability 0.5 for the first 8 plies and 0.2 afterwards.  Every
        position becomes a q-only sample (zero mask and policy) with target
        ``get_manual_q``.  Whenever more than 500000 augmented samples have
        accumulated, they are trained on (512 epochs max) and discarded.  The
        remainder is trained on at the end (1024 epochs max).
        """
        train_data = []
        for r in range(rounds):
            tmp_game = Ataxx()
            tmp_round_s = time.time()
            if verbose:
                print("round:", r+1)
            steps = 0
            turn = -1
            while abs(tmp_game.evaluate(1, turn)) != 1:
                rand_thresh = 0.5 if steps < 8 else 0.2
                if np.random.random() > rand_thresh:
                    best_move = tmp_game.get_greedy_move(turn)
                else:
                    best_move = choice(tmp_game.get_moves(turn))

                # make the move and grab data
                tmp_game.move_to(turn, best_move[0], best_move[1])
                turn = -turn
                tmp_data = [tmp_game.get_feature_map(turn, best_move),
                            np.zeros(792),
                            np.zeros(792),
                            tmp_game.get_manual_q(turn, tmp_game.data)]
                train_data.extend(augment_data(tmp_data))
                steps += 1
            if verbose:
                print("this round has {} steps, takes time {}".format(steps, time.time()-tmp_round_s))

            if len(train_data) > 500000:
                self._fit_pretrain_data(train_data, epochs=512)
                train_data = []

        # the remainder is empty when the last round just flushed (or rounds == 0);
        # the old `len(...) <= 500000` test was always true and then crashed
        if train_data:
            self._fit_pretrain_data(train_data, epochs=1024)

    def _fit_pretrain_data(self, train_data, epochs):
        """Fit ``_model`` on 4-field samples, blend it into the target model and save the pretrain checkpoint."""
        columns = list(zip(*train_data))
        for i in range(4):
            columns[i] = np.stack(columns[i], axis=0)
        es = EarlyStopping(monitor='val_loss', min_delta=1e-3, patience=5, verbose=1, mode='auto')
        self._network._model.fit(x=[columns[0], columns[1]], y=[columns[2], columns[3]], verbose=1,
                                 batch_size=1024, epochs=epochs, shuffle=True, validation_split=0.1, callbacks=[es])
        self._network.update_target_model()
        self._network.save(is_pretrain=True)
            
    def reinforcement_learning(self, episode=1000, rollout_times=100, life_span=25, train_interval=5, T=0.5, t_lim=np.nan, left_space_max=None, self_play_verbose=False, train_mode=0):
        """Run the self-play / train / evaluate reinforcement-learning loop.

        Each episode resets the search tree (``reset_root``) and the game
        (``reset``), plays one ``self_play`` game and appends the returned
        training data to a ``Relay``. Once ``life_span`` episodes have been
        played, the network is fit every ``train_interval`` episodes on
        ``train_relay.get(1800 * life_span)`` and saved. The target model is
        refreshed every ``life_span`` episodes starting at ``2 * life_span``.
        Every ``life_span`` episodes the Q and P heads are evaluated against
        min-max opponents, and the opponent depth used for each head is moved
        up or down based on its win ratio. At the end of every ``life_span``
        episodes the learning rate is divided by 1.5 while it is >= 5e-6.

        Args:
            episode: number of episodes to run.
            rollout_times: forwarded to ``self_play``.
            life_span: passed to ``Relay``; also the period (in episodes) of
                target-model refreshes, evaluations and learning-rate decay.
            train_interval: fit the network every this many episodes (only
                once ``epi >= life_span``).
            T: forwarded to ``self._network.update_target_model``.
            t_lim: forwarded to ``self_play``.
            left_space_max: if not None, the ``left_space`` given to ``reset``
                is ``left_space_max * N(1, 0.2)``; otherwise 45 is used.
            self_play_verbose: ``verbose`` flag for ``self_play``; about 5% of
                games are forced verbose regardless.
            train_mode: forwarded to ``self_play``.
        """
        # Min-max depths the Q and P heads are currently evaluated against;
        # test_rounds[d - 1] is the number of games played against depth d.
        dep_q = 1
        dep_p = 3
        test_rounds = [250, 200, 100, 30]
        train_relay = Relay(life_span)
        for epi in range(episode):
            print("episode {} now start".format(epi))
            self.reset_root()
            # randomly perturb the starting position when left_space_max is given
            if left_space_max is not None:
                left_space = left_space_max * np.random.normal(loc=1, scale=0.2)
            else:
                left_space = 45
            self.reset(left_space=left_space)
            print("left space is {}".format(left_space))
            # play one self-play game; ~5% of games are printed verbosely for monitoring
            verbose = True if np.random.random() < 0.05 else self_play_verbose
            train_data = self.self_play(rollout_times=rollout_times, t_lim=t_lim, verbose=verbose, train_mode=train_mode)
            train_relay.append(train_data)
            # fit every train_interval episodes once the first life_span episodes are stored
            if epi >= life_span and epi % train_interval == 0:
                train_data = train_relay.get(1800 * life_span)
                print("start training, training data no. {}".format(train_data[0].shape[0]))
                es = EarlyStopping(monitor='val_loss', min_delta=0.000001, patience=64, verbose=1, mode='auto')
                self._network._model.fit(x=[train_data[0], train_data[1]], y=[train_data[2], train_data[3]], verbose=1,
                                         batch_size=1024, epochs=512, shuffle=True, validation_split=0.15, callbacks=[es])
                print("saving files")
                self._network.save()
            # Refresh the target model every life_span episodes, skipping the first
            # two periods. Checked independently of the training schedule so it
            # still fires every life_span when life_span is not a multiple of
            # train_interval.
            if epi % life_span == 0 and epi >= 2 * life_span:
                self._network.update_target_model(T=T)
            # periodically evaluate both heads against min-max players
            if epi % life_span == 0:
                print("\n\nstart testing against min max")
                # additional depth-1 runs when a head is evaluated deeper; results are not used
                if dep_q != 1:
                    self.tester(mm_dep=1, Q=True, mode=1, times=250)
                if dep_p != 1:
                    self.tester(mm_dep=1, P=True, mode=1, times=250)
                q_result = self.tester(mm_dep=dep_q, Q=True, mode=1, times=test_rounds[dep_q - 1])['Q']
                p_result = self.tester(mm_dep=dep_p, P=True, mode=1, times=test_rounds[dep_p - 1])['P']
                dep_q = self._adjust_test_depth('Q', dep_q, q_result, test_rounds)
                dep_p = self._adjust_test_depth('P', dep_p, p_result, test_rounds)
            # crude learning-rate annealing at the end of every life_span episodes
            if epi % life_span == life_span - 1 and self._lr >= 5e-6:
                self._lr = self._lr / 1.5
                self._network.update_learning_rate(self._lr)
            print("episode {} finished".format(epi))

    def _adjust_test_depth(self, head, dep, result, test_rounds, low=0.20, probe=0.55, promote=0.80):
        """Return the next min-max evaluation depth for one network head.

        Args:
            head: 'Q' or 'P'; passed to ``self.tester`` as ``Q=True`` / ``P=True``.
            dep: min-max depth the head was just evaluated against (1-based).
            result: win ratio measured at depth ``dep``.
            test_rounds: games to play per depth; ``len(test_rounds)`` is the
                deepest level that can be reached.
            low: below this win ratio the depth is decreased (never below 1).
            probe: above this win ratio, and if a deeper level exists, the head
                is also tested one level deeper (that probe's return value is
                not used).
            promote: above this win ratio (inside the probe branch) the depth
                is increased by one.

        Returns:
            The updated depth.
        """
        max_dep = len(test_rounds)
        if result < low:
            if dep > 1:
                dep -= 1
        elif result > probe and dep < max_dep:
            # Promotion is decided by the win ratio at the current depth; the
            # deeper run only reports how the head fares one level further.
            self.tester(mm_dep=dep + 1, mode=1, times=test_rounds[dep], **{head: True})
            if result > promote:
                dep += 1
        return dep

In [None]:
# Build the MCTS player. The flags ask for the saved model to be loaded and
# the pretrain model not to be loaded.
player = MCTS(
    gpu=1,
    lr=1e-6,
    c=3,
    dep_lim=3,
    is_load_model=True,
    is_load_pretrain_model=False,
)

Instructions for updating:
keep_dims is deprecated, use keepdims instead
Instructions for updating:
keep_dims is deprecated, use keepdims instead
Instructions for updating:
keep_dims is deprecated, use keepdims instead
successfully loaded two models
compile new learning rate 1e-06


In [None]:
# Optional deeper search (disabled): player._dep_lim = 4
player._mode = 4  # noted by the author as "almost mode 1"
player.reinforcement_learning(
    episode=320,
    rollout_times=2000,
    life_span=40,
    train_interval=5,
    T=0.1,
    t_lim=8,
    left_space_max=None,
    self_play_verbose=False,
    train_mode=0,
)

episode 0 now start
left space is 45
start new self play
-1 's turn, step no.  0


This is a random move


one move takes time(s):  0.3195915222167969
no. of data collected:  8
1 's turn, step no.  1


This is a random move


one move takes time(s):  0.3857433795928955
no. of data collected:  16
-1 's turn, step no.  2
one move takes time(s):  0.8061413764953613
no. of data collected:  16
1 's turn, step no.  3
one move takes time(s):  0.9494669437408447
no. of data collected:  16
-1 's turn, step no.  4
due to rollout lim, final rollout times:  2186 time elapsed:  2.0000181198120117
one move takes time(s):  2.001197099685669
no. of data collected:  16
1 's turn, step no.  5
one move takes time(s):  1.5986297130584717
no. of data collected:  16
-1 's turn, step no.  6
one move takes time(s):  0.40499448776245117
no. of data collected:  16
1 's turn, step no.  7
one move takes time(s):  0.7652602195739746
no. of data collected:  16
-1 's turn, step no.  8
one move takes time(s):  1.6035

one move takes time(s):  0.8544659614562988
no. of data collected:  24
1 's turn, step no.  65


This is a random move


one move takes time(s):  0.5916268825531006
no. of data collected:  16
-1 's turn, step no.  66


This is a random move


one move takes time(s):  0.8661670684814453
no. of data collected:  24
1 's turn, step no.  67
one move takes time(s):  1.4016268253326416
no. of data collected:  16
-1 's turn, step no.  68
one move takes time(s):  0.5149927139282227
no. of data collected:  16
1 's turn, step no.  69


This is a random move


one move takes time(s):  1.013920783996582
no. of data collected:  16
-1 's turn, step no.  70
one move takes time(s):  0.5887372493743896
no. of data collected:  16
1 's turn, step no.  71


This is a random move


one move takes time(s):  0.7829866409301758
no. of data collected:  16
-1 's turn, step no.  72
one move takes time(s):  0.5552680492401123
no. of data collected:  24
1 's turn, step no.  73


This is a random move


one move tak

due to rollout lim, final rollout times:  2001 time elapsed:  2.050743341445923
one move takes time(s):  2.0524210929870605
no. of data collected:  16
-1 's turn, step no.  34


This is a random move


one move takes time(s):  1.461106538772583
no. of data collected:  16
1 's turn, step no.  35
one move takes time(s):  1.3778965473175049
no. of data collected:  16
-1 's turn, step no.  36
one move takes time(s):  1.1112861633300781
no. of data collected:  16
1 's turn, step no.  37
one move takes time(s):  1.5525250434875488
no. of data collected:  24
-1 's turn, step no.  38
one move takes time(s):  1.8264341354370117
no. of data collected:  24
1 's turn, step no.  39


This is a random move


one move takes time(s):  0.9365200996398926
no. of data collected:  16
-1 's turn, step no.  40
one move takes time(s):  1.1772568225860596
no. of data collected:  16
1 's turn, step no.  41


This is a random move


one move takes time(s):  0.9280910491943359
no. of data collected:  16
-1 's tu

due to rollout lim, final rollout times:  2001 time elapsed:  3.9831395149230957


This is a random move


one move takes time(s):  3.9850704669952393
no. of data collected:  8
1 's turn, step no.  19
due to rollout lim, final rollout times:  2001 time elapsed:  3.103501796722412
one move takes time(s):  3.105172634124756
no. of data collected:  8
-1 's turn, step no.  20
due to rollout lim, final rollout times:  2001 time elapsed:  2.880441188812256
one move takes time(s):  2.8826980590820312
no. of data collected:  16
1 's turn, step no.  21
due to rollout lim, final rollout times:  2001 time elapsed:  2.0198912620544434
one move takes time(s):  2.021362781524658
no. of data collected:  16
-1 's turn, step no.  22
due to rollout lim, final rollout times:  2001 time elapsed:  2.7700538635253906
one move takes time(s):  2.7716078758239746
no. of data collected:  16
1 's turn, step no.  23
due to rollout lim, final rollout times:  2001 time elapsed:  2.308389186859131
one move takes tim

one move takes time(s):  0.06219601631164551
no. of data collected:  24
1 's turn, step no.  81


This is a random move


one move takes time(s):  0.054244041442871094
no. of data collected:  8
this self play has 82 steps, time elapsed 166.88370156288147
winner is 1.0
episode 2 finished
episode 3 now start
left space is 45
start new self play
-1 's turn, step no.  0


This is a random move


one move takes time(s):  0.2719700336456299
no. of data collected:  8
1 's turn, step no.  1


This is a random move


one move takes time(s):  0.3847939968109131
no. of data collected:  16
-1 's turn, step no.  2
one move takes time(s):  0.5361490249633789
no. of data collected:  16
1 's turn, step no.  3
one move takes time(s):  1.1309075355529785
no. of data collected:  16
-1 's turn, step no.  4
due to rollout lim, final rollout times:  2005 time elapsed:  2.00004243850708


This is a random move


one move takes time(s):  2.0013773441314697
no. of data collected:  16
1 's turn, step no.  5


T

In [17]:
# Evaluate both the Q and P heads against min-max opponents of depth 1 to 4.
# Each entry is (min-max depth, number of games).
eval_schedule = [(1, 250), (2, 200), (3, 100), (4, 50)]
eval_results = [
    player.tester(mm_dep=depth, Q=True, P=True, BOTH=False, mode=1, times=games,
                  dep_lim=1, rollout_times=400, verbose=False)
    for depth, games in eval_schedule
]
eval_results[-1]  # keep the cell's displayed value: the depth-4 result

mm_dep is:  1
####               ####
#### start testing ####
testing took time:  92.18110156059265
win steps:  55.88499720575014 lose steps:  35.31999293600141




                        win ratio of Q is 0.8 





####               ####
#### start testing ####
testing took time:  113.77895212173462
win steps:  58.052171389036026 lose steps:  74.19996290001855




                        win ratio of P is 0.92 





mm_dep is:  2
####               ####
#### start testing ####
testing took time:  82.00473427772522
win steps:  80.99999138297964 lose steps:  16.999998396226566




                        win ratio of Q is 0.47 





####               ####
#### start testing ####
testing took time:  144.61132621765137
win steps:  71.99999294117715 lose steps:  95.99999020408264




                        win ratio of P is 0.51 





mm_dep is:  3
####               ####
#### start testing ####
testing took time:  28.193013429641724
win steps:  0.0 lose steps:  14.199998580000141






{'P': 0, 'Q': 0}

In [None]:
# Visualise the first weight tensor of the model, one 2-D slice per
# (input, output) channel pair. Assumes w[0] is a conv kernel whose last two
# axes are (in_channels, out_channels), as in Keras Conv2D -- TODO confirm
# against the model definition.
w = player._network._model.get_weights()
fig = plt.gcf()  # handle to the current figure (created if none exists)
print(w[0].shape)
show_w = w[0]
# iterate over the kernel's actual channel counts instead of a hard-coded 6 x 64
n_in, n_out = show_w.shape[-2:]
for in_layer in range(n_in):
    for out_layer in range(n_out):
        print('in', in_layer, 'out', out_layer)
        plt.imshow(show_w[..., in_layer, out_layer], cmap='gray')
        plt.pause(0.1)