In [1]:
import numpy as np

class Move:
    """A single placement: put the mark ``value`` on board cell ``[x, y]``."""

    def __init__(self, x, y, value):
        # x / y are used as the first / second board index; value is the
        # player's mark (compared against GameState.next_to_move).
        self.x, self.y, self.value = x, y, value

    def __repr__(self):
        return f"x:{self.x} y:{self.y} v:{self.value}"

In [2]:
# Keeps track of a particular game state
class GameState:
    """A square board plus the player whose turn it is.

    Cells hold ``x`` (1), ``o`` (-1) or 0 for empty. ``move`` builds a new
    state on a copy of the board, so the current state is left untouched.
    """

    x = 1
    o = -1

    def __init__(self, state, next_to_move=1):
        # `state` is the board array; its first dimension is the board size.
        self.board = state
        self.board_size = state.shape[0]
        self.next_to_move = next_to_move

    def get_result(self):
        """Return ``x`` or ``o`` for a winner, 0 for a full board with no winner, None otherwise.

        A line (column, row or either diagonal) is won when its sum equals
        ``+/- board_size``, i.e. every cell on it carries the same mark.
        X is checked before O.
        """
        line_sums = np.concatenate([
            self.board.sum(axis=0),             # column sums
            self.board.sum(axis=1),             # row sums
            [self.board.trace(),                # main diagonal
             self.board[::-1].trace()],         # anti-diagonal
        ])
        for player in (self.x, self.o):
            if np.any(line_sums == player * self.board_size):
                return player

        # No winner: a completely filled board is a tie, otherwise play goes on.
        if np.all(self.board != 0):
            return 0
        return None

    def is_game_over(self):
        """Return True once ``get_result`` reports a win or a tie."""
        return self.get_result() is not None

    def is_move_allowed(self, move):
        """Return truthy only if it is ``move.value``'s turn, (x, y) is on the board and that cell is empty."""
        if move.value != self.next_to_move:
            return False  # not this player's turn
        in_bounds = all(0 <= coord < self.board_size for coord in (move.x, move.y))
        return in_bounds and self.board[move.x, move.y] == 0

    def move(self, move):
        """Return the state after applying ``move``; this state is not modified.

        Raises:
            ValueError: if ``is_move_allowed(move)`` is falsy.
        """
        if not self.is_move_allowed(move):
            raise ValueError("Move {0} on board {1} is not allowed.".format(move, self.board))

        next_board = np.copy(self.board)
        next_board[move.x, move.y] = move.value
        return GameState(next_board, -1 * self.next_to_move)

    def get_allowed_moves(self):
        """Return a ``Move`` for the player to move on every empty cell, in row-major order."""
        empty_rows, empty_cols = np.where(self.board == 0)
        return [Move(r, c, self.next_to_move) for r, c in zip(empty_rows, empty_cols)]
    

In [3]:
from collections import defaultdict

# A node in the MCTS search tree. It corresponds to a particular game state.
class MCTSNode:
    """One node of the Monte Carlo search tree, wrapping a single game state.

    The wrapped ``state`` is expected to provide ``get_allowed_moves()``,
    ``move(action)``, ``is_game_over()``, ``get_result()`` and a
    ``next_to_move`` attribute (``GameState`` in this notebook does).
    """

    def __init__(self, state, parent=None):
        self.state = state  # The game state this node corresponds to
        self.parent = parent  # Parent node in the tree; None for the root
        self.children = []  # Child nodes created by expand()
        # Maps a game result (as returned by state.get_result()) to the number
        # of rollouts through this node that ended with that result.
        self.results = defaultdict(int)
        self.n_visits = 0  # Number of visits to this node
        self._untried_moves = None  # Filled lazily by untried_moves()

    def untried_moves(self):
        """Return the moves from this state that have not been expanded into children yet.

        Computed once on first access; ``expand`` removes entries from the
        returned list as children are created.
        """
        if self._untried_moves is None:
            self._untried_moves = self.state.get_allowed_moves()
        return self._untried_moves

    def Q(self):
        """Return wins minus losses for the player who moved into this node.

        That player is ``self.parent.state.next_to_move``, so this must only be
        called on a node that has a parent. Ties count as neither.
        """
        mover = self.parent.state.next_to_move
        return self.results[mover] - self.results[-1 * mover]

    def N(self):
        """Return the number of visits to this node."""
        return self.n_visits

    def expand(self):
        """Create, attach and return a child for one untried move.

        Takes the *last* move of the untried list (``list.pop()``). Goes through
        ``untried_moves()`` rather than the cached attribute, so expanding works
        even if the list has not been computed yet.

        Raises:
            IndexError: if the untried list is empty.
        """
        action = self.untried_moves().pop()
        child = MCTSNode(self.state.move(action), parent=self)
        self.children.append(child)
        return child

    def is_terminal_node(self):
        """Return True when the game is over in this node's state."""
        return self.state.is_game_over()

    def rollout(self):
        """Play moves chosen by ``rollout_policy`` until the game ends; return the final result."""
        current_rollout_state = self.state
        while not current_rollout_state.is_game_over():
            possible_moves = current_rollout_state.get_allowed_moves()
            action = self.rollout_policy(possible_moves)
            current_rollout_state = current_rollout_state.move(action)
        return current_rollout_state.get_result()

    def backpropagate(self, result):
        """Record one visit and ``result`` on this node and every ancestor up to the root."""
        node = self
        while node is not None:
            node.n_visits += 1
            node.results[result] += 1
            node = node.parent

    def is_fully_expanded(self):
        """Return True once every allowed move has been expanded into a child."""
        return len(self.untried_moves()) == 0

    def best_child(self, c_param=1.4):
        """Return the child with the highest UCT-style score.

        score = Q/N + c_param * sqrt(2 * ln(N_parent) / N_child)

        The effective exploration constant is ``c_param * sqrt(2)``; pass
        ``c_param=0`` for purely greedy selection by average result. Every
        child must have been visited at least once (N > 0).

        Raises:
            ValueError: if this node has no children.
        """
        if not self.children:
            raise ValueError("best_child() called on a node with no children.")
        scores = [
            (c.Q() / c.N()) + c_param * np.sqrt((2 * np.log(self.N()) / c.N()))
            for c in self.children
        ]
        return self.children[np.argmax(scores)]

    def rollout_policy(self, possible_moves):
        """Pick one of ``possible_moves`` uniformly at random using the global NumPy RNG."""
        return possible_moves[np.random.randint(len(possible_moves))]

In [4]:
class MCTS:
    """Monte Carlo Tree Search driver that searches from a given root node."""

    def __init__(self, node):
        self.root = node  # Root of the search tree (the position to move from)

    def best_action(self, n_simulations):
        """Run ``n_simulations`` select/expand/rollout/backpropagate iterations.

        Returns the root's most promising child *node* (not a move); its
        ``state`` is the position after the chosen move.

        Raises:
            ValueError: if ``n_simulations`` is less than 1 or the root state is
                already terminal -- in both cases the root would have no
                children to choose from.
        """
        if n_simulations < 1:
            raise ValueError("n_simulations must be >= 1, got {}".format(n_simulations))
        if self.root.is_terminal_node():
            raise ValueError("Cannot search from a terminal game state.")

        for _ in range(n_simulations):
            leaf = self.tree_policy()
            reward = leaf.rollout()
            leaf.backpropagate(reward)

        # c_param=0 disables the exploration bonus: pick purely on average result.
        return self.root.best_child(c_param=0.)

    def tree_policy(self):
        """Descend from the root and return the node to simulate from.

        Expands the first node on the path that still has untried moves;
        otherwise follows ``best_child()`` until reaching a terminal node.
        """
        current_node = self.root
        while not current_node.is_terminal_node():
            if not current_node.is_fully_expanded():
                return current_node.expand()
            current_node = current_node.best_child()
        return current_node

In [78]:
import time

# Self-play: two MCTS agents play one game against each other.
N_SIMULATIONS = 1000  # MCTS iterations per move; higher = stronger but slower
SEED = 0  # rollouts draw from the global NumPy RNG; seed it so the game is reproducible
np.random.seed(SEED)

# Label printed for each player value (GameState.o == -1, GameState.x == 1).
player_names = {-1: "MCTS1", 1: "MCTS2"}

start = time.perf_counter()
initial_board = np.zeros((3, 3))
current_state = GameState(state=initial_board, next_to_move=1)
while not current_state.is_game_over():
    mover = player_names[current_state.next_to_move]
    # A fresh search tree is built from the current position every turn.
    mcts = MCTS(MCTSNode(state=current_state))
    current_state = mcts.best_action(N_SIMULATIONS).state
    print("{} move:".format(mover))
    print(current_state.board)

result = current_state.get_result()
print("Game finished")
if result in player_names:
    print("{} won.".format(player_names[result]))
else:
    print("Tie.")

end = time.perf_counter()
print("Execution time: {:.2f} seconds.".format(end - start))

MCTS2 move:
[[0. 0. 0.]
 [0. 1. 0.]
 [0. 0. 0.]]
MCTS1 move:
[[ 0.  0.  0.]
 [ 0.  1.  0.]
 [-1.  0.  0.]]
MCTS2 move:
[[ 0.  0.  0.]
 [ 0.  1.  0.]
 [-1.  1.  0.]]
MCTS1 move:
[[ 0. -1.  0.]
 [ 0.  1.  0.]
 [-1.  1.  0.]]
MCTS2 move:
[[ 0. -1.  0.]
 [ 1.  1.  0.]
 [-1.  1.  0.]]
MCTS1 move:
[[ 0. -1.  0.]
 [ 1.  1. -1.]
 [-1.  1.  0.]]
MCTS2 move:
[[ 0. -1.  0.]
 [ 1.  1. -1.]
 [-1.  1.  1.]]
MCTS1 move:
[[-1. -1.  0.]
 [ 1.  1. -1.]
 [-1.  1.  1.]]
MCTS2 move:
[[-1. -1.  1.]
 [ 1.  1. -1.]
 [-1.  1.  1.]]
Game finished
Tie.
Execution time: 3.00 seconds.


In [63]:
# Interactive game: you play O (-1) and move first; MCTS plays X (1).
# Enter each move as "row,col" with 0-based indices, e.g. "1,2".
n_simulations = 1000  # MCTS iterations per computer move
initial_board = np.zeros((3, 3))
current_state = GameState(state=initial_board, next_to_move=-1)
while not current_state.is_game_over():
    if current_state.next_to_move == -1:
        while True:
            user_input = input("Your move (row,col): ")
            try:
                # Parsing sits inside the try so malformed input (non-numbers,
                # missing or extra commas) re-prompts instead of crashing the game.
                x, y = (int(part) for part in user_input.strip().split(","))
                current_state = current_state.move(Move(x, y, -1))
                break
            except ValueError:
                print("Invalid move, try again.")

        print("User move:")
        print(current_state.board)
    else:
        mcts = MCTS(MCTSNode(state=current_state))
        current_state = mcts.best_action(n_simulations).state
        print("MCTS move:")
        print(current_state.board)

result = current_state.get_result()
print("Game finished")
if result == -1:
    print("You won.")
elif result == 1:
    print("MCTS won.")
else:
    print("Tie.")

 0,0


User move:
[[-1.  0.  0.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]]
MCTS move:
[[-1.  0.  0.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  1.  0.  0.]]


KeyboardInterrupt: 