In [1]:
%pip install numpy

Note: you may need to restart the kernel to use updated packages.


In [2]:
import math

import numpy as np

# Kept as the last expression of the cell so the notebook actually displays
# the installed numpy version (an expression in the middle of a cell is
# evaluated but never shown).
np.__version__

In [3]:
class TicTacToe:
    def __init__(self):
        self.row_count: int = 3
        self.column_count: int = 3
        self.action_size: int = self.row_count * self.column_count
        
    def get_initial_state(self) -> np.ndarray:
        return np.zeros((self.row_count, self.column_count))
    
    def get_row_column(self, action: int) -> tuple[int, int]:
        row = action // self.column_count
        column = action % self.column_count
        return row, column
    
    def get_next_state(self, state: np.ndarray, action: int, player: int) -> np.ndarray:
        '''state after action is played'''
        row, column = self.get_row_column(action)
        state[row, column] = player
        return state
    
    def get_valid_moves(self, state: np.ndarray) -> (np.uint8):
        # state as 1 dimension vector, then convert boolean to 0(false) or 1(true)
        return (state.reshape(-1) == 0).astype(np.uint8)
    
    def check_win(self, state: np.ndarray, action: int|None) -> bool:
        if action == None:
            return False
        
        row, column = self.get_row_column(action)
        player = state[row, column]
        
        return (
            np.sum(state[row, :]) == player * self.column_count
            or np.sum(state[:, column]) == player * self.row_count
            or np.sum(np.diag(state)) == player * self.row_count
            or np.sum(np.diag(np.flip(state, axis=0))) == player * self.row_count
        )
    
    
    def get_value_and_terminated(self, state: np.ndarray, action: int|None) -> tuple[int, bool]:
        '''value is 1 if win, 0 otherwise. Node is terminal if a player won or not valid moves'''
        if self.check_win(state, action):
            return 1, True
        if np.sum(self.get_valid_moves(state)) == 0:
            return 0, True
        return 0, False
    
    def get_opponent(self, player: int) -> int:
        return -player
    
    def get_opponent_value(self, value: int) -> int:
        return -value
    
    def change_perspective(self, state: np.ndarray, player: int) -> np.ndarray:
        '''state is positive if player 1, negative otherwise'''
        return state * player

In [6]:
# Annotations that mention Node (or the game class) are written as strings:
# the name Node is not bound yet while the class body is being executed, so an
# unquoted annotation would raise NameError when the methods are defined.
class Node:
    '''One node of the Monte Carlo search tree.

    Attributes:
        game: rules object used to generate moves and score positions.
        args: search settings; this class reads args['C'] (exploration weight).
        state: board stored in this node.
        parent: node this one was expanded from, or None for the root.
        action_taken: action that led from the parent to this node (None for the root).
        children: nodes created so far by expand().
        expandable_moves: mask of valid moves that have no child yet.
        visit_count / value_sum: statistics accumulated by backpropagate().
    '''

    def __init__(self, game: 'TicTacToe', args, state: np.ndarray, parent=None, action_taken=None):
        self.game = game
        self.args = args
        self.state = state
        self.parent = parent
        self.action_taken = action_taken

        self.children = []
        self.expandable_moves = game.get_valid_moves(state)

        self.visit_count = 0
        self.value_sum = 0

    def is_fully_expanded(self):
        '''In the state, there's no more moves to explore: there are children but no unexplored moves'''
        return np.sum(self.expandable_moves) == 0 and len(self.children) > 0

    def select(self):
        '''Return the child with the highest UCB score (None if there are no children).'''
        best_child = None
        best_ucb = -np.inf

        for child in self.children:
            ucb = self.get_ucb(child)
            if ucb > best_ucb:
                best_child = child
                best_ucb = ucb

        return best_child

    def get_ucb(self, child: 'Node'):
        '''Score interest of a node.

        Requires child.visit_count > 0 and self.visit_count > 0 (both are
        used as a divisor / logarithm argument).
        '''
        # exploitation:
        # the child's mean value is mapped from [-1, 1] to [0, 1];
        # the child stores values from the opponent's point of view, so we do 1 - q
        # exploration:
        # C scales the exploration term, which grows when the child has few
        # visits compared to the current node
        q_value = 1 - ((child.value_sum / child.visit_count) + 1) / 2
        return q_value + self.args['C'] * math.sqrt(math.log(self.visit_count) / child.visit_count)

    def expand(self) -> 'Node':
        '''Create, register and return a child for one randomly chosen unexplored move.'''
        action = np.random.choice(np.where(self.expandable_moves == 1)[0])
        self.expandable_moves[action] = 0

        # Work on a copy: get_next_state writes into the array it is given.
        child_state = self.state.copy()
        # The move is always played as player 1, then the board is flipped
        # (multiplied by -1) before being stored in the child.
        child_state = self.game.get_next_state(child_state, action, 1)
        child_state = self.game.change_perspective(child_state, player=-1)

        child = Node(self.game, self.args, child_state, self, action)
        self.children.append(child)
        return child

    def simulate(self) -> int:
        '''Return the value of this node from a random playout.

        If the node is already terminal, the game value is returned negated.
        Otherwise random moves are played, starting with player 1, until a
        terminal position is reached; its value is negated when the last
        move was made by player -1.
        '''
        value, is_terminal = self.game.get_value_and_terminated(self.state, self.action_taken)
        value = self.game.get_opponent_value(value)

        if is_terminal:
            return value

        # we go for a random move until we find a terminal node
        rollout_state = self.state.copy()
        rollout_player = 1
        while True:
            valid_moves = self.game.get_valid_moves(rollout_state)
            action = np.random.choice(np.where(valid_moves == 1)[0])
            rollout_state = self.game.get_next_state(rollout_state, action, rollout_player)
            value, is_terminal = self.game.get_value_and_terminated(rollout_state, action)
            if is_terminal:
                if rollout_player == -1:
                    value = self.game.get_opponent_value(value)
                return value

            rollout_player = self.game.get_opponent(rollout_player)

    def backpropagate(self, value):
        '''Add `value` to this node's statistics, then pass the negated value up to the parent.'''
        self.value_sum += value
        self.visit_count += 1

        value = self.game.get_opponent_value(value)
        if self.parent is not None:
            self.parent.backpropagate(value)


class MCTS:
    '''Monte Carlo tree search driver.

    args must provide 'num_searches' (iterations per search); the same
    dictionary is handed to every Node, which reads 'C' from it.
    '''

    def __init__(self, game: TicTacToe, args):
        self.game = game
        self.args = args

    def search(self, state):
        '''Run the search from `state` and return per-action visit proportions.

        The result is an array of length game.action_size holding, for each
        action tried from the root, the share of visits its child received.
        If the root ended up with no visited child (no search iteration, or
        no move available from `state`), the array is all zeros.
        '''
        root = Node(self.game, self.args, state)

        for _ in range(self.args['num_searches']):
            node = root

            # we go forward on best ucb until a node is expandable
            while node.is_fully_expanded():
                node = node.select()

            value, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken)
            value = self.game.get_opponent_value(value)

            if not is_terminal:
                node = node.expand()
                value = node.simulate()

            node.backpropagate(value)

        action_probs = np.zeros(self.game.action_size)
        for child in root.children:
            action_probs[child.action_taken] = child.visit_count

        # Normalise only when something was visited: dividing the zero vector
        # by its zero sum would fill the result with NaN.
        total_visits = np.sum(action_probs)
        if total_visits > 0:
            action_probs /= total_visits
        return action_probs
        
        
        

In [5]:
# Interactive game: player 1 is the human (moves typed as a cell index),
# player -1 is driven by the tree search.
tictactoe = TicTacToe()
player = 1

args = {
    'C': 1.41,
    'num_searches': 1000
}

mcts = MCTS(tictactoe, args)

state = tictactoe.get_initial_state()


while True:
    print(state)

    if player == 1:
        valid_moves = tictactoe.get_valid_moves(state)
        print("valid_moves", [i for i in range(tictactoe.action_size) if valid_moves[i] == 1])

        # Non-numeric input is rejected like any other invalid move instead
        # of aborting the game with a ValueError.
        try:
            action = int(input(f"{player}:"))
        except ValueError:
            print("action not valid")
            continue

        # The range test comes first: an index >= action_size would raise
        # IndexError and a negative one would silently address a cell
        # counted from the end of the board.
        if not 0 <= action < tictactoe.action_size or valid_moves[action] == 0:
            print("action not valid")
            continue

    else:
        # The search is given the board multiplied by the current player.
        neutral_state = tictactoe.change_perspective(state, player)
        mcts_probs = mcts.search(neutral_state)
        action = np.argmax(mcts_probs)

    state = tictactoe.get_next_state(state, action, player)

    value, is_terminal = tictactoe.get_value_and_terminated(state, action)

    if is_terminal:
        print(state)
        if value == 1:
            print(player, "won")
        else:
            print("draw")
        break

    player = tictactoe.get_opponent(player)



[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
valid_moves [0, 1, 2, 3, 4, 5, 6, 7, 8]
[[0. 1. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
[[ 0.  1.  0.]
 [ 0. -1.  0.]
 [ 0.  0.  0.]]
valid_moves [0, 2, 3, 5, 6, 7, 8]
[[ 0.  1.  0.]
 [ 1. -1.  0.]
 [ 0.  0.  0.]]
[[-1.  1.  0.]
 [ 1. -1.  0.]
 [ 0.  0.  0.]]
valid_moves [2, 5, 6, 7, 8]
action not valid
[[-1.  1.  0.]
 [ 1. -1.  0.]
 [ 0.  0.  0.]]
valid_moves [2, 5, 6, 7, 8]
[[-1.  1.  0.]
 [ 1. -1.  0.]
 [ 0.  0.  1.]]
[[-1.  1. -1.]
 [ 1. -1.  0.]
 [ 0.  0.  1.]]
valid_moves [5, 6, 7]
[[-1.  1. -1.]
 [ 1. -1.  0.]
 [ 1.  0.  1.]]
[[-1.  1. -1.]
 [ 1. -1.  0.]
 [ 1. -1.  1.]]
valid_moves [5]
action not valid
[[-1.  1. -1.]
 [ 1. -1.  0.]
 [ 1. -1.  1.]]
valid_moves [5]
[[-1.  1. -1.]
 [ 1. -1.  1.]
 [ 1. -1.  1.]]
draw
