In [1]:
import numpy as np
np.random.seed(2019)
from collections import namedtuple
from collections import deque
from random import shuffle
import tensorflow as tf

from pickle import Pickler, Unpickler
import os, sys

import matplotlib.pyplot as plt

In [2]:
# Outcome of a board evaluation: `is_ended` tells whether the game is over and
# `winner` holds 1 / -1 for the winning player, or None (tie or still ongoing).
WinState = namedtuple('WinState', 'is_ended winner')

In [3]:
class MCTS():
    """Monte-Carlo tree search guided by a policy/value network.

    States are keyed by ``game.getStringRepresentation`` of a canonical board
    (always seen from player 1's point of view); edges are keyed by
    ``(state, (row, col))``.
    """

    def __init__(self, game, wrapped_nnet, args):
        self.game = game
        self.wrapped_nnet = wrapped_nnet
        self.args = args  # reads 'MCTS_nb_simulations' and 'MCTS_C'

        # q_sa[(s, a)]: running mean of the values backed up through edge (s, a), in [-1, 1]
        #   e.g. 1 win and 2 losses through (s, a) -> (1 - 1 - 1) / 3 = -1/3
        self.q_sa = {}
        # n_sa[(s, a)]: number of simulations that took action a from state s
        self.n_sa = {}
        # n_s[s]: number of simulations that went through s after s was expanded
        self.n_s = {}

        self.policies_s = {}     # network prior for s, masked to valid moves and renormalised
        self.valid_moves_s = {}  # boolean matrix of legal moves at s
        self.end_game_s = {}     # game.isEnded(s, 1): 0 while the game goes on, non-zero once ended

    def getActionSpaceProba(self, canonicalBoard, temperature=1, add_dirichlet_noise=True):
        """Run simulations from ``canonicalBoard`` and return a move distribution.

        Pipeline:
          1. run ``args['MCTS_nb_simulations']`` searches from the given board;
          2. collect the visit counts n_sa(s, a) of every valid move;
          3. if ``add_dirichlet_noise``, add 1 to the count of one valid move
             chosen uniformly at random (exploration noise; despite the name it
             is not drawn from a Dirichlet distribution);
          4. Pr(a|s) = n_sa(s,a)^(1/temperature) / sum_b n_sa(s,b)^(1/temperature);
             with ``temperature == 0`` the result is one-hot on the most visited move.

        Returns:
            1-D array of length nb_rows * nb_cols (row-major flattened board).
            Invalid moves always get probability 0. If no valid move was visited,
            every valid move is treated as visited once. If there is no valid
            move at all, an all-zero vector is returned.
        """
        nb_rows, nb_cols = canonicalBoard.shape[0], canonicalBoard.shape[1]

        # step 1: simulations
        for _ in range(self.args['MCTS_nb_simulations']):
            self.search(canonicalBoard)

        # step 2: visit counts of the valid moves
        s = self.game.getStringRepresentation(canonicalBoard)
        valid_pos = np.asarray(self.game.getValidMoves(canonicalBoard), dtype=bool)
        counts = np.array([[self.n_sa.get((s, (row, col)), 0) if valid_pos[row, col] else 0
                            for col in range(nb_cols)]
                           for row in range(nb_rows)], dtype=float)
        nb_valid_pos = int(valid_pos.sum())

        # step 3: exploration noise on one uniformly chosen valid position
        if add_dirichlet_noise and nb_valid_pos > 0:
            uniform_valid = valid_pos.flatten() / nb_valid_pos
            idx = np.random.choice(nb_rows * nb_cols, p=uniform_valid)
            # row-major flat index -> (row, col): the divisor is the number of columns
            row, col = divmod(int(idx), nb_cols)
            counts[row, col] += 1

        # step 4: visit counts -> probabilities
        probas = np.zeros(nb_rows * nb_cols)
        if nb_valid_pos == 0:
            return probas
        flat_counts = counts.flatten()
        if flat_counts.sum() <= 0:
            # no visit information: avoid NaN / picking an invalid move
            flat_counts = valid_pos.flatten().astype(float)
        if temperature == 0:
            probas[np.argmax(flat_counts)] = 1
        else:
            flat_counts = flat_counts ** (1 / temperature)
            probas = flat_counts / flat_counts.sum()
        return probas

    def search(self, canonicalBoard):
        """Run one simulation from ``canonicalBoard`` and back its value up the tree.

        Pipeline:
          1. record whether the state is terminal (``game.isEnded(board, 1)``);
          2. terminal state -> return the negated outcome;
          3. unseen state (leaf) -> ask the network for a prior policy and a value,
             mask the prior to valid moves, store it and return the negated value;
          4. otherwise choose the valid action maximising
             u = Q(s,a) + C * P(s,a) * sqrt(N(s)) / (1 + N(s,a)),
             play it, canonicalise the next board and recurse;
          5. increment N(s);
          6. update the running mean Q(s,a) and the count N(s,a).

        Returns:
            The value of ``canonicalBoard`` negated, i.e. seen from the point of
            view of the player who moved into this state.
        """
        s = self.game.getStringRepresentation(canonicalBoard)

        # step 1: remember whether this state ends the game
        if s not in self.end_game_s:
            # canonicalBoard is the board from player 1's point of view
            self.end_game_s[s] = self.game.isEnded(canonicalBoard, 1)

        # step 2: terminal state
        if self.end_game_s[s] != 0:
            return -self.end_game_s[s]

        # step 3: newly discovered leaf -> expand with the network's prediction
        if s not in self.policies_s:
            policy, v = self.wrapped_nnet.predict(canonicalBoard)
            policy = policy.reshape(canonicalBoard.shape[0], canonicalBoard.shape[1])
            valid_moves = self.game.getValidMoves(canonicalBoard)

            # mask out invalid moves, then renormalise
            policy = policy * valid_moves
            total = policy.sum()
            if total > 0:
                policy = policy / total
            else:
                print("Warning: all valid moves are masked. Check the neural net. \n")
                policy = valid_moves / valid_moves.sum()

            self.policies_s[s] = policy
            self.valid_moves_s[s] = valid_moves
            self.n_s[s] = 0  # newly created leaf node
            return -v

        # step 4: select the action with the highest upper-confidence score and descend
        valid_moves = self.valid_moves_s[s]
        policy = self.policies_s[s]
        current_best = -np.inf
        best_action = (-1, -1)

        for row, col in zip(*np.where(valid_moves)):
            edge = (s, (row, col))
            if edge in self.q_sa:
                u = self.q_sa[edge] + self.args['MCTS_C'] * policy[row][col] * np.sqrt(self.n_s[s]) / (1 + self.n_sa[edge])
            else:
                # unvisited edge: Q counts as 0; the epsilon keeps the prior term
                # non-zero when n_s[s] == 0 so priors still rank the actions
                u = self.args['MCTS_C'] * policy[row][col] * np.sqrt(self.n_s[s] + 0.000001)

            if u > current_best:
                current_best = u
                best_action = (row, col)

        next_s, next_player = self.game.getNextState(canonicalBoard, 1, best_action[0], best_action[1])
        next_s = self.game.getCanonicalForm(next_s, next_player)
        v = self.search(next_s)

        # step 5: one more simulation went through s
        self.n_s[s] += 1

        # step 6: update edge statistics
        edge = (s, best_action)
        if edge not in self.q_sa:
            self.q_sa[edge] = v
            self.n_sa[edge] = 1
        else:
            # incremental mean of all values backed up through this edge
            self.q_sa[edge] = (v + self.n_sa[edge] * self.q_sa[edge]) / (self.n_sa[edge] + 1)
            self.n_sa[edge] += 1

        return -v
      
      
    
class Board():
    """Gomoku-style board: 0 = empty, 1 / -1 = stones of the two players.

    A player who has ``win_length`` stones in a row (horizontally, vertically
    or diagonally) wins, unless that player also has a line longer than
    ``win_length`` somewhere on the board, in which case that player loses.
    """

    def __init__(self, nb_rows=15, nb_cols=15, win_length=5, stone_matrix=None):
        self.nb_rows = nb_rows
        self.nb_cols = nb_cols
        self.win_length = win_length
        if stone_matrix is None:
            self.stone_matrix = np.zeros((self.nb_rows, self.nb_cols))
        else:
            self.stone_matrix = stone_matrix

    def __str__(self):
        return str(self.stone_matrix)

    def copyBoard(self, stone_matrix=None):
        """Return a Board with this board's ``win_length`` wrapping ``stone_matrix``.

        Dimensions are taken from ``stone_matrix``. The matrix itself is NOT
        copied (the new Board shares it); when ``stone_matrix`` is None this
        board's own matrix is shared.
        """
        if stone_matrix is None:
            stone_matrix = self.stone_matrix
        nb_rows, nb_cols = np.shape(stone_matrix)
        return Board(nb_rows, nb_cols, self.win_length, stone_matrix)

    def addStone(self, row, col, player):
        """Place ``player``'s stone at (row, col); raises ValueError if the cell is taken."""
        if self.stone_matrix[row, col] != 0:
            raise ValueError("(%d, %d) taken." % (row, col))
        self.stone_matrix[row, col] = player

    def getValidMoves(self):
        """Boolean matrix, True where the cell is empty."""
        return self.stone_matrix == 0

    def getWinState(self):
        """Return a WinState: a win (with overline rule), then a tie (full board), else ongoing."""
        for player in [-1, 1]:
            # 0/1 matrix of the stones of `player`'s opponent
            opponent_stones = np.array(self.stone_matrix == -player, dtype=int)

            if (self._straight5(opponent_stones)
                    or self._straight5(opponent_stones.transpose())
                    or self._diag5(opponent_stones)):
                # a line longer than win_length (overline) makes its owner lose
                if (self._straight5plus(opponent_stones)
                        or self._straight5plus(opponent_stones.transpose())
                        or self._diag5plus(opponent_stones)):
                    return WinState(True, player)
                return WinState(True, -player)

        if not self.getValidMoves().any():
            return WinState(True, None)
        return WinState(False, None)

    def _straight5(self, player_stones):
        """True if some column of the 0/1 matrix holds win_length consecutive stones.

        Uses the matrix's own shape, so it also works on the transposed board
        (to check rows) when the board is not square.
        """
        w = self.win_length
        return any((player_stones[i:i + w, :].sum(axis=0) == w).any()
                   for i in range(player_stones.shape[0] - w + 1))

    def _straight5plus(self, player_stones):
        """True if some column of the 0/1 matrix holds more than win_length consecutive stones."""
        w = self.win_length
        return any((player_stones[i:i + w + 1, :].sum(axis=0) > w).any()
                   for i in range(player_stones.shape[0] - w))

    def _diag5(self, player_stones):
        """True if a diagonal or anti-diagonal holds win_length consecutive stones.

        Slides a win_length x win_length window over the matrix; the trace of
        the window (or of its left-right flip) equals win_length exactly when
        that diagonal is fully occupied.
        """
        w = self.win_length
        nb_rows, nb_cols = player_stones.shape
        for i in range(nb_rows - w + 1):
            for j in range(nb_cols - w + 1):
                window = player_stones[i:i + w, j:j + w]
                if np.trace(window) == w or np.trace(np.fliplr(window)) == w:
                    return True
        return False

    def _diag5plus(self, player_stones):
        """True if a diagonal or anti-diagonal holds more than win_length consecutive stones."""
        w = self.win_length
        nb_rows, nb_cols = player_stones.shape
        for i in range(nb_rows - w):
            for j in range(nb_cols - w):
                window = player_stones[i:i + w + 1, j:j + w + 1]
                if np.trace(window) > w or np.trace(np.fliplr(window)) > w:
                    return True
        return False
      
      
      
      
      
class Game():
    """Rules facade over Board used by MCTS, Coach and Arena.

    Boards are passed around as 2-D numpy arrays (0 empty, 1 / -1 stones).
    """

    def __init__(self, nb_rows=15, nb_cols=15, win_length=5, stone_matrix=None):
        self._base_board = Board(nb_rows, nb_cols, win_length, stone_matrix)

    def _wrap(self, board):
        """Wrap ``board`` (shared, not copied) in a Board that keeps this game's win_length and the board's shape."""
        return Board(board.shape[0], board.shape[1], self._base_board.win_length, board)

    def getInitBoard(self):
        """Return a fresh copy of the starting position, so callers cannot corrupt it."""
        return np.copy(self._base_board.stone_matrix)

    def getNextState(self, board, player, row, col):
        """Return ``(new_board, -player)`` after ``player`` plays (row, col); ``board`` is left unchanged."""
        board_cp = self._wrap(np.copy(board))
        board_cp.addStone(row, col, player)
        return board_cp.stone_matrix, -player

    def getValidMoves(self, board):
        """Boolean matrix, True on empty cells."""
        return self._wrap(board).getValidMoves()

    def getActionSpaceSize(self):
        return self._base_board.nb_rows * self._base_board.nb_cols

    def isEnded(self, board, player):
        """0 if the game goes on, 1 / -1 if ``player`` won / lost, 0.00001 for a tie."""
        win_state = self._wrap(board).getWinState()
        if not win_state.is_ended:
            return 0
        if win_state.winner is None:
            return 0.00001
        if win_state.winner == player:
            return 1
        if win_state.winner == -player:
            return -1
        raise ValueError('Unexpected winstate found: ', win_state)

    # keep player1's point of view
    def getCanonicalForm(self, board, player):
        return board * player

    def getStringRepresentation(self, board):
        """Hashable key for ``board``.

        Same text as ``str(board)``, except that large boards are never
        summarised with '...' (which would make distinct boards collide).
        """
        return np.array2string(board, threshold=board.size + 1)

    def getSymmetries(self, board, pi):
        """Return ``[(board_variant, flattened_pi_variant), ...]`` for the board's symmetries.

        Always: original, left-right flip, upside-down flip, both flips.
        For square boards additionally: transpose and its three flips (8 total).
        Transposed variants are skipped for non-square boards because their
        shape would not match the board.
        """
        pi = np.array(pi).reshape(board.shape[0], board.shape[1])
        variants = [(board, pi),
                    (np.flip(board, axis=1), np.flip(pi, axis=1)),
                    (np.flip(board, axis=0), np.flip(pi, axis=0)),
                    (np.flip(board), np.flip(pi))]
        if board.shape[0] == board.shape[1]:
            board_T, pi_T = np.transpose(board), np.transpose(pi)
            variants += [(board_T, pi_T),
                         (np.flip(board_T, axis=1), np.flip(pi_T, axis=1)),
                         (np.flip(board_T, axis=0), np.flip(pi_T, axis=0)),
                         (np.flip(board_T), np.flip(pi_T))]
        return [(b, np.ndarray.flatten(p)) for b, p in variants]
      
      
      
class Coach():
    """Self-play training loop.

    Each iteration plays ``COACH_nb_episode`` self-play games with MCTS, keeps
    the examples of the last ``COACH_nb_iters_for_training_history``
    iterations, trains the network on them, then pits the trained network
    against the previous one. The new network is kept only if it wins at least
    ``COACH_update_threshold`` of the decisive games.
    """

    def __init__(self, game, args):
        self.game = game

        self.wrapped_nnet = NNetWrapper(game)
        self.mcts = MCTS(game, self.wrapped_nnet, args)

        self.wrapped_prev_net = NNetWrapper(game)

        self.args = args

        self.trainExamplesHistory = []
        self.skip_first_self_play = False

    def executeEpisode(self):
        """Play one self-play game (start to end) and return its training examples.

        Returns:
            list of ``(canonical_board, pi, value)``. ``pi`` is the MCTS move
            distribution, and every position contributes all of its board
            symmetries. ``value`` is the final result seen from the player who
            was to move at that position: ``game.isEnded`` for that player,
            i.e. 1 win, -1 loss; a tie keeps isEnded's small magnitude with
            alternating sign.
        """
        train_examples = []  # (board, player_to_move, pi)
        board = self.game.getInitBoard()
        self.current_player = 1
        episode_count = 0

        while True:
            episode_count += 1
            canonical_board = self.game.getCanonicalForm(board, self.current_player)
            # 1 (sample by visit counts) for the first moves, then 0 (greedy)
            temperature = int(episode_count < self.args['COACH_temperature_threshold'])

            pi = self.mcts.getActionSpaceProba(canonical_board, temperature)  # length nb_rows * nb_cols
            for b, p in self.game.getSymmetries(canonical_board, pi):
                train_examples.append((b, self.current_player, p))

            idx = np.random.choice(len(pi), p=pi)
            # row-major flat index -> (row, col): the divisor is the number of columns
            row, col = divmod(int(idx), canonical_board.shape[1])

            board, self.current_player = self.game.getNextState(board, self.current_player, row, col)
            res = self.game.isEnded(board, self.current_player)

            if res != 0:
                print("in Coach.executeEpisode, total steps taken %d" % episode_count)
                # res is from self.current_player's point of view: flip it for
                # positions where the other player was to move
                return [(b, p, res if player == self.current_player else -res)
                        for b, player, p in train_examples]

    def learn(self):
        """Run ``args['COACH_nb_iter']`` iterations of self-play, training and gating."""
        for i in range(1, self.args['COACH_nb_iter'] + 1):   # number of times to update the neural network
            print("-------- iteration number %d --------" % i)
            if not self.skip_first_self_play or i > 1:
                iterationEg = deque([], maxlen=self.args['COACH_max_length_of_queue'])
                for eps in range(self.args['COACH_nb_episode']):   # number of self-played games
                    print("in Coach.learn, episode %d" % (eps + 1))
                    # a fresh search tree for each game
                    self.mcts = MCTS(self.game, self.wrapped_nnet, self.args)
                    iterationEg += self.executeEpisode()
                self.trainExamplesHistory.append(iterationEg)

            while len(self.trainExamplesHistory) > self.args['COACH_nb_iters_for_training_history']:
                print("length of trainExamplesHistory: %d, remove the oldest training examples." % len(self.trainExamplesHistory))
                self.trainExamplesHistory.pop(0)

            self.saveTrainExamples(i - 1)

            # merge every kept iteration into one shuffled list
            trainEg = [eg for iteration_examples in self.trainExamplesHistory for eg in iteration_examples]
            shuffle(trainEg)

            # snapshot the current model before training so it can be pitted / restored
            self.wrapped_nnet.saveCheckpoint(folder=self.args['NNET_checkpoint'], filename='temp.ckpt')
            self.wrapped_prev_net.loadCheckpoint(folder=self.args['NNET_checkpoint'], filename='temp.ckpt')

            print("training the neural network...")
            self.wrapped_nnet.train(trainEg)

            print("Pit against the old neural network")
            prev_wins, new_wins, draws = self._pitAgainstPrevious()
            print("previous model won %d games, new model won %d games, total game played: %d" %
                  (prev_wins, new_wins,
                   prev_wins + new_wins + draws))

            # if the new player is not much stronger, do not update
            if (prev_wins + new_wins == 0) or (new_wins / (prev_wins + new_wins) < self.args['COACH_update_threshold']):
                print("Reject new model")
                self.wrapped_nnet.loadCheckpoint(folder=self.args['NNET_checkpoint'], filename='temp.ckpt')
            else:
                print("Model updated")
                self.wrapped_nnet.saveCheckpoint(folder=self.args['NNET_checkpoint'], filename=self.getCheckpointFile(i))
                self.wrapped_nnet.saveCheckpoint(folder=self.args['NNET_checkpoint'], filename='best.ckpt')
                self.saveTrainExamples(99999)  # 99999 for the best examples

    def _arenaPlayer(self, mcts):
        """Greedy Arena move function: flat index of the move chosen by ``mcts`` at temperature 0."""
        return lambda x: np.argmax(mcts.getActionSpaceProba(x, temperature=0, add_dirichlet_noise=True))

    def _pitAgainstPrevious(self):
        """Play both seatings of previous vs new network; return (prev_wins, new_wins, draws)."""
        nb_games = self.args['ARENA_nb_games']

        # new network as player 2
        prev_mcts = MCTS(self.game, self.wrapped_prev_net, self.args)
        new_mcts = MCTS(self.game, self.wrapped_nnet, self.args)
        arena1st = Arena(self._arenaPlayer(prev_mcts), self._arenaPlayer(new_mcts), game=self.game)
        prev_wins_1st, new_wins_1st, draws_1st = arena1st.playGames(nb_games, verbose=False)

        # new network as player 1, with fresh search trees
        prev_mcts = MCTS(self.game, self.wrapped_prev_net, self.args)
        new_mcts = MCTS(self.game, self.wrapped_nnet, self.args)
        arena2nd = Arena(self._arenaPlayer(new_mcts), self._arenaPlayer(prev_mcts), game=self.game)
        new_wins_2nd, prev_wins_2nd, draws_2nd = arena2nd.playGames(nb_games, verbose=False)

        return (prev_wins_1st + prev_wins_2nd,
                new_wins_1st + new_wins_2nd,
                draws_1st + draws_2nd)

    def saveTrainExamples(self, iteration):
        """Pickle ``trainExamplesHistory`` to ``<COACH_load_folder_file>/check_pt_<iteration>.ckpt.examples``."""
        folder = self.args['COACH_load_folder_file']
        os.makedirs(folder, exist_ok=True)
        filename = os.path.join(folder, self.getCheckpointFile(iteration) + ".examples")
        with open(filename, "wb+") as f:
            Pickler(f).dump(self.trainExamplesHistory)

    def getCheckpointFile(self, iteration):
        return "check_pt_%d.ckpt" % iteration

    def loadTrainExamples(self):
        """Load the examples saved for the best model (iteration 99999), asking before continuing without them."""
        modelFile = os.path.join(self.args['COACH_load_folder_file'], self.getCheckpointFile(99999))
        examplesFile = modelFile + ".examples"
        if not os.path.isfile(examplesFile):
            print(examplesFile)
            r = input("File with trainExamples not found. Continue? [y|n]")
            if r != "y":
                sys.exit()
        else:
            print("File with trainExamples found")
            # NOTE: unpickling executes arbitrary code; only load files you produced yourself
            with open(examplesFile, "rb") as f:
                self.trainExamplesHistory = Unpickler(f).load()
            # examples based on the model were already collected (loaded)
            self.skip_first_self_play = True
    
    
class NNet():
    """Policy/value network in TF1 graph mode for an nb_rows x nb_cols board.

    Placeholders: ``input_board`` (batch, rows, cols), ``dropout`` (drop rate),
    ``is_training`` (bool, drives batch-norm and dropout), ``target_pi``,
    ``target_val``.
    Outputs: ``pi`` (policy logits), ``proba`` (softmax of pi), ``val`` (tanh
    value, shape (batch, 1)), ``loss_pi``, ``loss_val``, ``total_loss``,
    ``train_step``.
    """

    def __init__(self, game, args):
        self.board_x = game._base_board.nb_rows
        self.board_y = game._base_board.nb_cols

        self.action_size = game.getActionSpaceSize()
        self.args = args

        Relu = tf.nn.relu
        Tanh = tf.nn.tanh
        Batchnormalisation = tf.layers.batch_normalization
        Dropout = tf.layers.dropout
        Dense = tf.layers.dense

        nb_channels = args['NNET_nb_channels']
        self.graph = tf.Graph()

        with self.graph.as_default():
            self.input_board = tf.placeholder(tf.float32, shape=[None, self.board_x, self.board_y])
            self.dropout = tf.placeholder(tf.float32)
            self.is_training = tf.placeholder(tf.bool, name="is_training")

            x_image = tf.reshape(self.input_board, [-1, self.board_x, self.board_y, 1])

            def conv_bn_relu(x, channels, kernel, padding):
                # conv -> batch-norm over channels (axis 3) -> ReLU
                return Relu(Batchnormalisation(self.conv2d(x, channels, kernel, padding),
                                               axis=3, training=self.is_training))

            h_conv1 = conv_bn_relu(x_image, nb_channels, [4, 4], 'same')
            h_conv2 = conv_bn_relu(h_conv1, nb_channels, [4, 4], 'same')
            # 'valid' convolutions shrink each spatial side by 2, 2 and 1 -> board - 5
            h_conv3 = conv_bn_relu(h_conv2, nb_channels, [3, 3], 'valid')
            h_conv4 = conv_bn_relu(h_conv3, nb_channels // 2, [3, 3], 'valid')
            h_conv5 = conv_bn_relu(h_conv4, nb_channels // 4, [2, 2], 'valid')

            h_conv5_flat = tf.reshape(h_conv5, [-1, (nb_channels // 4) * (self.board_x - 5) * (self.board_y - 5)])

            # tf.layers.dropout is an identity unless `training` is true, so it must be
            # tied to is_training, otherwise the dropout placeholder has no effect
            s_fc1 = Dropout(Relu(Batchnormalisation(Dense(h_conv5_flat, 900), axis=1, training=self.is_training)),
                            rate=self.dropout, training=self.is_training)
            s_fc2 = Dropout(Relu(Batchnormalisation(Dense(s_fc1, 450), axis=1, training=self.is_training)),
                            rate=self.dropout, training=self.is_training)

            self.pi = Dense(s_fc2, self.action_size)
            self.proba = tf.nn.softmax(self.pi)
            self.val = Tanh(Dense(s_fc2, 1))

            self.calculateLoss()

    def conv2d(self, x, out_channels, kernel_size, padding):
        """2-D convolution with a Xavier (normal) kernel initializer."""
        return tf.layers.conv2d(x, out_channels, kernel_size=kernel_size, padding=padding,
                                kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=False))

    def calculateLoss(self):
        """Build the targets, losses (softmax cross-entropy + MSE) and the Adam train step.

        Must be called inside ``self.graph.as_default()``. The train step depends
        on UPDATE_OPS so batch-norm moving statistics are updated while training.
        """
        self.target_pi = tf.placeholder(tf.float32, shape=[None, self.action_size])
        self.target_val = tf.placeholder(tf.float32, shape=[None])

        self.loss_pi = tf.losses.softmax_cross_entropy(onehot_labels=self.target_pi, logits=self.pi)
        self.loss_val = tf.losses.mean_squared_error(labels=self.target_val, predictions=tf.reshape(self.val, [-1, ]))

        self.total_loss = self.loss_pi + self.loss_val

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            self.train_step = tf.train.AdamOptimizer(self.args['NNET_lr']).minimize(self.total_loss)
            
            
            
            
            
class NNetWrapper():
    """Owns an NNet, its TF session and its Saver; exposes train / predict / checkpoint I/O."""

    def __init__(self, game, args=None):
        if args is None:
            # Backward compatibility: callers passing only `game` rely on a
            # notebook-level `args` dict.
            args = globals()['args']
        self.nnet = NNet(game, args)
        self.game = game
        self.board_x = game._base_board.nb_rows
        self.board_y = game._base_board.nb_cols
        self.nb_valid_moves = game.getActionSpaceSize()

        self.sess = tf.Session(graph=self.nnet.graph)
        self.saver = None

        # initialise the variables of this network's own graph (not the default graph)
        with self.nnet.graph.as_default():
            self.sess.run(tf.variables_initializer(self.nnet.graph.get_collection('variables')))

    def train(self, examples):
        """Train on ``examples`` (sequence of ``(board, pi, value)``) for ``NNET_epochs`` epochs.

        Each epoch draws ceil(len(examples) / NNET_batch_size) random batches
        (with replacement). Loss curves covering the last 5 epochs are plotted
        every 5 epochs. Does nothing when ``examples`` is empty.
        """
        if len(examples) == 0:
            return
        batch_size = self.nnet.args['NNET_batch_size']
        nb_batches = int(np.ceil(len(examples) / batch_size))

        window_pi_losses, window_val_losses = [], []
        for epoch in range(self.nnet.args['NNET_epochs']):
            for _ in range(nb_batches):
                sample_idx = np.random.randint(len(examples), size=batch_size)
                board_samples, pi_samples, val_samples = zip(*[examples[i] for i in sample_idx])

                input_dict = {
                    self.nnet.input_board: board_samples,
                    self.nnet.target_pi: pi_samples,
                    self.nnet.target_val: val_samples,
                    self.nnet.dropout: self.nnet.args['NNET_dropout'],
                    self.nnet.is_training: True
                }
                # one run: training step and the losses of this batch
                _, pi_loss, val_loss = self.sess.run(
                    [self.nnet.train_step, self.nnet.loss_pi, self.nnet.loss_val], feed_dict=input_dict)
                window_pi_losses.append(pi_loss)
                window_val_losses.append(val_loss)

            if (epoch + 1) % 5 == 0:
                self._plotLosses(window_pi_losses, window_val_losses)
                window_pi_losses, window_val_losses = [], []

    @staticmethod
    def _plotLosses(pi_losses, val_losses):
        """Show the policy and value loss curves side by side."""
        fig, (ax_pi, ax_val) = plt.subplots(1, 2, figsize=(20, 8))
        ax_pi.plot(np.arange(len(pi_losses)), pi_losses, '-o')
        ax_pi.set(xlabel="iterations", ylabel="soft max cross entropy loss")
        ax_val.plot(np.arange(len(val_losses)), val_losses, '-*')
        ax_val.set(xlabel="iterations", ylabel="MSE value loss")
        plt.show()
        plt.close(fig)

    def predict(self, board):
        """Return ``(policy_probabilities, value)`` for a single board (inference mode)."""
        board = board[np.newaxis, :, :]
        input_dict = {
            self.nnet.input_board: board,
            self.nnet.dropout: 0,
            self.nnet.is_training: False
        }
        proba, val = self.sess.run([self.nnet.proba, self.nnet.val], feed_dict=input_dict)
        return proba[0], val[0]

    def saveCheckpoint(self, folder='checkpoint', filename='checkpoint.ckpt'):
        """Save the session variables to ``folder/filename``, creating the folder if needed."""
        filepath = os.path.join(folder, filename)
        if not os.path.exists(folder):
            print("Checkpoint Directory does not exist! Making directory {}".format(folder))
            os.makedirs(folder)

        with self.nnet.graph.as_default():
            if self.saver is None:
                self.saver = tf.train.Saver(self.nnet.graph.get_collection('variables'))
            self.saver.save(self.sess, filepath)

    def loadCheckpoint(self, folder='checkpoint', filename='checkpoint.ckpt'):
        """Restore the session variables from ``folder/filename``; raises ValueError if it is missing."""
        filepath = os.path.join(folder, filename)

        if not os.path.exists(filepath + '.meta'):
            raise ValueError("No model in path {}".format(filepath))
        with self.nnet.graph.as_default():
            # reuse one Saver so repeated loads do not keep adding ops to the graph
            if self.saver is None:
                self.saver = tf.train.Saver()
            self.saver.restore(self.sess, filepath)
            
            
            
            
class HumanPlayer():
    """Console player that reads 1-based row/column numbers from input()."""

    def __init__(self, game):
        self.game = game

    def play(self, board):
        """Ask for a move on ``board`` until a legal one is entered.

        Returns:
            ``(row, col, False)`` with 0-based indices of a valid move, or
            ``(-1, -1, True)`` when the player resigns.
        """
        valid_moves = self.game.getValidMoves(board)
        nb_rows, nb_cols = valid_moves.shape
        while True:
            resign = input("resign (y|n)? ")
            if resign == 'y':
                return -1, -1, True

            try:
                row = int(input("row number? "))
                col = int(input("column number? "))
            except ValueError:
                print("Invalid move")
                continue

            # bounds check first: 0 or negative numbers would otherwise wrap
            # around through negative indexing and be accepted
            if 1 <= row <= nb_rows and 1 <= col <= nb_cols and valid_moves[row - 1, col - 1]:
                return row - 1, col - 1, False
            print("Invalid move")
      
      
      
      
      
class MachinePlayer():
    """Computer player that plays the most visited MCTS move of the best saved network."""

    def __init__(self, game, args):
        self.game = game
        self.wrapped_nnet = NNetWrapper(game)
        self.args = args
        self.wrapped_nnet.loadCheckpoint(folder=self.args['NNET_checkpoint'], filename='best.ckpt')

    def play(self, board):
        """Return the chosen ``(row, col)`` for ``board``, or ``(-1, -1)`` if MCTS gives no move."""
        row, col = -1, -1
        mcts = MCTS(self.game, self.wrapped_nnet, self.args)
        pi = mcts.getActionSpaceProba(board, temperature=0)
        if pi.max() > 0:
            # row-major flat index -> (row, col): the divisor is the number of columns
            row, col = divmod(int(np.argmax(pi)), board.shape[1])

        return row, col

      

class Arena():
    """Runs games between two players that share one ``game`` object.

    Players are numbered 1 (``player1``, moves first) and -1 (``player2``).
    ``playGame``/``playGames`` expect each player to be a function mapping a
    canonical board to a row-major flattened action index. ``humanVsMachine``
    instead expects ``player1`` to return ``(row, col, resign)`` and
    ``player2`` to return ``(row, col)``, both with 0-based coordinates.
    """

    def __init__(self, player1, player2, game):
        self.player1 = player1 # function that takes board as input and outputs an action
        self.player2 = player2 # function that takes board as input and outputs an action
        self.game = game       # game obj

    @staticmethod
    def _action_to_row_col(pos, nb_cols):
        """Decode a row-major flattened action index into 0-based (row, col)."""
        return divmod(int(pos), nb_cols)

    def humanVsMachine(self):
        """Play one interactive game: the human is player 1, the computer player -1.

        Either side may resign: the human by answering 'y', the computer by
        returning a negative coordinate. An illegal computer move aborts the game.
        The final board and the outcome are printed.
        """
        # players[0] = player2, computer; players[1] = player1, human
        # player1 is numerated as 1, player2 is numerated as -1
        players = [self.player2, self.player1]
        current_player = 1
        board = self.game.getInitBoard()
        winner = None    # set explicitly when a side resigns
        aborted = False  # set when an illegal move stops the game

        while self.game.isEnded(board, current_player) == 0:
            if current_player == 1:
                # human player's turn
                print("your turn: ")
                displayBoard(board)
                row, col, resign = players[1](board)
                if resign:
                    print("\n\n\n\nGame over")
                    print("Human resigns")
                    winner = -1
                    break
            else:
                # computer's turn
                print("computer's turn:")
                row, col = players[0](self.game.getCanonicalForm(board, current_player))
                if row < 0 or col < 0:
                    print("Computer resigns")
                    winner = 1
                    break

            valid_moves = self.game.getValidMoves(board)
            if valid_moves[row, col]:
                print("action taken by player %d: (%d, %d)" % (current_player, row+1, col+1))
                board, current_player = self.game.getNextState(board, current_player, row, col)
            else:
                # coordinates shown 1-based, like every other move message
                print("\n\n\ninvalid move: (%d, %d), \ncheck the neural net or game settings.\n\n\n" % (row+1, col+1))
                aborted = True
                break

        print("\n\n\n\nGame over, result:")
        displayBoard(board)
        if aborted:
            print("No winner: game aborted")
            return
        if winner is None:
            result = self.game.isEnded(board, 1)
            # anything other than +/-1 is a draw, as in playGames
            winner = 1 if result == 1 else (-1 if result == -1 else 0)
        if winner == 0:
            print("Draw")
        else:
            print("Winner: %d" % winner)

    def playGame(self, verbose=False):
        """Play one game between player1 (1) and player2 (-1).

        Args:
            verbose: if True, print the turn number and board before each move.

        Returns:
            ``self.game.isEnded(board, 1)`` at the end of a normal game. If a
            player returns an illegal move, that player forfeits and
            ``-current_player`` is returned (1: player1 won, -1: player2 won).
        """
        # players[0] = player2, players[1] = player1
        # player1 is numerated as 1, player2 is numerated as -1
        players = [self.player2, self.player1]
        current_player = 1
        nb_iter = 0
        board = self.game.getInitBoard()

        while self.game.isEnded(board, current_player) == 0:
            nb_iter += 1
            if verbose:
                print("\nTurn #%d, current player: %d" % (nb_iter, current_player))
                displayBoard(board)

            # max() maps player -1 to index 0 and player 1 to index 1
            player_fct = players[max(current_player, 0)]
            pos = player_fct(self.game.getCanonicalForm(board, current_player))
            row, col = self._action_to_row_col(pos, board.shape[1])
            valid_moves = self.game.getValidMoves(board)
            if not valid_moves[row, col]:
                # Never apply an illegal move (it could overwrite a stone);
                # the offending player forfeits instead.
                print("invalid move by player %d: (%d, %d), player forfeits" % (current_player, row+1, col+1))
                print("Game over, result:")
                print("Winner: %d" % -current_player)
                displayBoard(board)
                return -current_player
            print("action taken by player %d: (%d, %d)" % (current_player, row+1, col+1))
            board, current_player = self.game.getNextState(board, current_player, row, col)

        result = self.game.isEnded(board, 1)
        print("Game over, result:")
        print("Winner: %d" % result)
        displayBoard(board)
        return result

    def playGames(self, nb_games, verbose=False):
        """Play ``nb_games`` games with ``playGame`` and tally the outcomes.

        Returns:
            ``(player1 wins, player2 wins, draws)``; any ``playGame`` result
            other than 1 or -1 counts as a draw.
        """
        p1won = p2won = draws = 0
        for game_idx in range(nb_games):
            print("in Arena.playGames, playing... round %d" % (game_idx + 1))
            res = self.playGame(verbose=verbose)
            if res == 1:
                p1won += 1
            elif res == -1:
                p2won += 1
            else:
                draws += 1
        return p1won, p2won, draws
         
      
      
    
def displayBoard(stone_matrix):
    """Print a 2-D board of stones as a framed text grid.

    Args:
        stone_matrix: 2-D array whose entries are 0 (empty), 1 (player 1,
            drawn as a black stone) or -1 (player -1, drawn as a white stone).

    Row and column labels are 1-based; alignment holds for up to 99 rows/columns.
    """
    # player 1: black stone, player -1: white stone
    piece = {0: " ", 1: u"\u25cf", -1: u"\u25cb"}
    nb_cols = stone_matrix.shape[1]

    # header: each column label is right-aligned in a 4-character cell so that it
    # sits above its stone column, e.g. "    1   2 ...   9  10"; only as many
    # labels as the board has columns
    header = " " + "".join("{0:>4}".format(j) for j in range(1, nb_cols + 1))

    # horizontal separator: "  +-----...----+"
    bar = "  +{0}+".format("-" * (4 * nb_cols - 1))

    # rows: label left-aligned in 2 characters, e.g. "1 | ... |" and "10| ... |"
    row_file = [u"{0:<2}| {1} |".format(i + 1, u" | ".join(piece[x] for x in row))
                for i, row in enumerate(stone_matrix)]

    # assemble up
    board = ("\n" + bar + "\n").join(row_file)
    board = u"\n".join((header, bar, board, bar, header))

    print(board)
    

In [9]:
# Hyper-parameters shared by every component; each key carries the prefix of
# the component it configures.
args = dict(
    # MCTS_*: read by MCTS (MCTS_nb_simulations = search() calls per move)
    MCTS_C=2,  # presumably the exploration constant — confirm in MCTS.search
    MCTS_nb_simulations=500,

    # COACH_*: presumably consumed by Coach's training loop — confirm there
    COACH_temperature_threshold=30,
    COACH_update_threshold=0.55,
    COACH_max_length_of_queue=2000000,
    COACH_load_folder_file='./gomoku0_v1/examples/',
    COACH_nb_iters_for_training_history=12,
    COACH_nb_episode=8,
    COACH_nb_iter=300,

    # NNET_*: network settings; NNET_checkpoint is the folder holding best.ckpt
    NNET_lr=0.003,
    NNET_dropout=0.3,
    NNET_batch_size=256,
    NNET_nb_channels=512,
    NNET_checkpoint='./gomoku0_v1/',
    NNET_epochs=15,

    # ARENA_*: arena evaluation settings
    ARENA_nb_games=5,
    # when True, the training cell restores best.ckpt and saved training examples
    load_model=True,
)

In [None]:
# Training run on a 10x10 game: optionally resume, then run Coach.learn().
g = Game(nb_rows=10, nb_cols=10)
# NOTE(review): this wrapper is never passed to Coach(g, args) below, so it looks
# unused here apart from loadCheckpoint raising ValueError early when best.ckpt is
# missing — confirm whether Coach builds its own network and drop this if so.
nnet = NNetWrapper(g)

if args['load_model']:
    nnet.loadCheckpoint(folder=args['NNET_checkpoint'], filename='best.ckpt')
    
c = Coach(g, args) 
if args['load_model']:
    # reload previously saved training examples before continuing to learn
    print("Load trainExamples from file")
    c.loadTrainExamples()
c.learn()

INFO:tensorflow:Restoring parameters from ./gomoku0_v1/best.ckpt
Load trainExamples from file
File with trainExamples found
-------- iteration number 1 --------
INFO:tensorflow:Restoring parameters from ./gomoku0_v1/temp.ckpt
training the neural network...


In [11]:
# Interactive game on a 10x10 board: the human (player 1) moves first against the
# computer player, whose network is restored from args['NNET_checkpoint'].
# NOTE(review): `board` is not used in this cell — candidate for removal if no
# later cell relies on it.
board = Board(nb_rows=10, nb_cols=10)
renju = Game(nb_rows=10, nb_cols=10)
me = HumanPlayer(renju)
computer = MachinePlayer(renju,args)
# Arena takes plain callables that map a board to a move
fct_me = (lambda x: me.play(x))
fct_pc = (lambda x: computer.play(x))
pit = Arena(fct_me, fct_pc, renju)
pit.humanVsMachine()

INFO:tensorflow:Restoring parameters from ./gomoku0_v1/best.ckpt
your turn: 
    1   2   3   4   5   6   7   8   9  10
  +---------------------------------------+
1 |   |   |   |   |   |   |   |   |   |   |
  +---------------------------------------+
2 |   |   |   |   |   |   |   |   |   |   |
  +---------------------------------------+
3 |   |   |   |   |   |   |   |   |   |   |
  +---------------------------------------+
4 |   |   |   |   |   |   |   |   |   |   |
  +---------------------------------------+
5 |   |   |   |   |   |   |   |   |   |   |
  +---------------------------------------+
6 |   |   |   |   |   |   |   |   |   |   |
  +---------------------------------------+
7 |   |   |   |   |   |   |   |   |   |   |
  +---------------------------------------+
8 |   |   |   |   |   |   |   |   |   |   |
  +---------------------------------------+
9 |   |   |   |   |   |   |   |   |   |   |
  +---------------------------------------+
10|   |   |   |   |   |   |   |   |   |   |
 