### MTCS tree search
1. Selection
2. Expansion
3. Simulation
4. Backpropagation
# TicTacToe
1. Play
2. CheckWin
3. Get possible move

In [1]:
import numpy as np
import math
import random
from collections import deque

In [173]:
PLAYER_X = 'X'
PLAYER_O = 'O'
EMPTY_SPOT = '_'
DRAW_MARKER = 'DRAW'

class TicTacToe: 
    def __init__(self, board_size): 
        """ 
        Initializes the Tic-Tac-Toe game. 
        """ 
        self.board_size = board_size
        self.current_player = PLAYER_X
        self.board = np.array([[EMPTY_SPOT]*3 for _ in range(3)])
        self.winner = None 

    # This modifies the self.board in-place but in MCST we need a method to create a new game state without altering the original.

    def copy(self): 
        """
        Creates a deep copy of the current game state. 

        Returns: 
            TicTacToe: A new TicTacToe object with the same state as the current one
        """ 
        new_game = TicTacToe(self.board_size)
        new_game.board = np.copy(self.board)
        new_game.current_player = self.current_player
        new_game.winner = self.winner
        return new_game

    def switch_player(self):
        if self.current_player == PLAYER_X: 
            self.current_player = PLAYER_O
        else: 
            self.current_player = PLAYER_X
            

    def _check_line_win(self, line): 
        """"
        A little helper function to check if all elements in a line are the same and not empty

        Args:
            line: A list or NumPy array representing a row, column, or diagonal.
    
        Returns:
            True if all elements in the line are the same and not equal to EMPTY_SPOT,
            False otherwise.
            
        Raises:
            TypeError: If the input 'line' is not a list or a NumPy array.
        """
        if not isinstance(line, (list, np.ndarray)):
            raise TypeError("Input 'line' must be a list or a NumPy array.")
            
        if len(line):
            return False
            
        first_element = line[0]
        if line[0] == EMPTY_SPOT: 
            return False
            
        try:
            return np.all(line == first_element)
        except TypeError as e:
            raise 

    def check_win_or_draw(self): 
        """
        Check if there is a winner of if the gaem is a draw. 
        Set self.winner if the game has concluded. 

        Returns: 
            str or None: The winner (PLAYER_X, PLAYER_O), DRAW_MARKER, or None if ongoing
        """ 

        # Check rows
        for i in range(self.board_size): 
            current_line = self.board[i,:]
            
                
            if self._check_line_win(self.board[i,:]): 
                self.winner = self.board[i,0]
                return self.winner

        # Check columns
        for i in range(self.board_size): 
            current_line = self.board[i,:]
                
            if self._check_line_win(self.board[:, i]): 
                self.winner = self.board[0, i]
                return self.winner

        # Check diagonals (top-left to bottom-right)
        if self._check_line_win(np.diag(self.board)):  
            self.winner = self.board[0][0]

        # Check antidiagonals (top-right to bottom-left)
        if self._check_line_win(np.diag(np.fliplr(self.board))): # np.fliplr reverses columns 
            self.winner = self.board[0][self.board_size]

        # Check draws: 
        if EMPTY_SPOT not in self.board: 
            self.winner = DRAW_MARKER
            return self.winner

        return None
        
    def is_game_over(self):
        """ 
        Checks if the game has ended 
        """ 
        return self.winner is not None
    
    def play(self, r, c): 
        '''Attempted to make a move. 
        Args: 
            r (int): The row index of the move
            c (int): The column index of the move

        Returns: 
            bool: True if the move was successful, False otherwise (e.g., the game has ended, spot taken, invalid coordinates)
        '''
        if self.is_game_over(): 
            return False

        if self.board[r][c] != EMPTY_SPOT: 
            return False

        if not (0 <= r <= self.board_size - 1 and 0 <= c <= self.board_size - 1): 
            return False

        self.board[r][c] = self.current_player

        self.check_win_or_draw() 

        if not self.is_game_over():
            self
        # Turn switching
        self.current_player = PLAYER_O if self.current_player == PLAYER_X else PLAYER_O

        return True
        
    def find_possible_moves(self): 
        """Find all empty spots on the board where a move can be made.

        Returns: 
            list of tuples: A list of (rows, cols) tuples representing possible moves. 
                            Return an empty list the game is over or no moves are possible. 
        """
        if self.is_game_over(): 
            return []
            
        possible_moves = []
        for r in range(self.board_size): 
            for c in range(self.board_size):
                if self.board[r][c] == EMPTY_SPOT:
                    possible_moves.append((r, c))
        return possible_moves
        
    def get_opponent(self):
        """
        Returns the opponent of the given player.

        Args:
            player (str): The player (PLAYER_X or PLAYER_O).

        Returns:
            str: The opponent player, or None if the input player is invalid.
        """
        if self.current_player == PLAYER_X:
            return PLAYER_O
        else: 
            return PLAYER_X
        

In [241]:
EXPLORATION_CONSTANT = math.sqrt(2)
class Node: 
    # A root node is represented as having no parent or parent == None
    def __init__(self, gameState, parent = None, move_that_led_to_this_node = None): 
        """
        Initialize a node in the MCTS.   

        Args: 
            gameState (TicTacToe): The game state that this node represents. 
            parent (Node, optional): The parent node. None for the root node 
            move_that_led_to_this_node(tuple, optional): The move (row, col) that led from the parent to this node. 
                                                          None for the root node. 
        """
        self.gameState = gameState 
        self.parent = parent 

        # MCTS statistics
        self.winCount = 0.0 # number of wins for the node considered. 1 or a win, -1 for a loss and 0.5 for a draw. 
        self.visitCount = 0 # number of times going through this node
        self.children = [] # List of child Node objects

        self.move_that_led_to_this_node = move_that_led_to_this_node
        self.possible_moves_from_the_state = self.gameState.find_possible_moves()
    def is_terminal_node(self):
        """
        Returns whether this is a terminal node in the game
        """
        return self.gameState.is_game_over()
        
    def is_fully_expanded(self): 
        """
        Checks if all possible child nodes from this node has been created. 

        Returns: 
            bool: True if fully expanded, False otherwise
        """ 
        return len(self.children) == len(self.possible_moves_from_the_state)
    
    def get_UBC1(self): 
        """
        Calculate the UBC1 (Upper Confidence Bound) score 1 for this node. 
        This score is used by the parent node to select which child to traverse.
        UCB1 = (wins/visits) + C * sqrt(log(parent_visits)/visits).
        """
        
        if self.visitCount == 0: 
            return float('inf') # prioritize unvisited node
        if self.parent is None: 
            return float('-inf') # either theres's an issue or this is the root node of the MCST

        exploitation_term = self.winCount/self.visitCount
        if self.visitCount == 0: 
            exploration_term = float('inf') # If parent has no visits, should definitely explore this node
        else: 
            exploration_term = EXPLORATION_CONSTANT * math.sqrt(math.log(self.parent.visitCount) / self.visitCount)
        return exploitation_term + exploration_term
        
    def select(self):
        """
        Selects the childnode with the highest UBC1 score. 
        Assume this node is not a teminal node and has children.
        """
        if not self.is_fully_expanded(): # # Always prioritize expanding unvisited node 
            return self
            
        best_child = None
        best_score = -float('inf')
        
        for child in self.children: 
            score = child.get_UBC1()
            if score > best_score: 
                best_score = score
                best_child = child 
        return best_child.select()
        
    def expand(self): 
        """
        Creates a child node from the selected node above and choose one of them. 
        If fully expanded or game over at this node, return None. 

        Returns: 
            Node or None: The new child node if expansion was successful, otherwise None
        """

        if self.is_fully_expanded(): 
            return None

        already_expanded_moves = {child.move_that_led_to_this_node for child in self.children} 
        untried_moves = [move for move in self.possible_moves_from_the_state if move not in already_expanded_moves]

        if not untried_moves: 
            return None 

        move_to_expand = random.choice(untried_moves) 

        new_game_state = self.gameState.copy()
        success = new_game_state.play(move_to_expand[0], move_to_expand[1]) 

        if not success: 
            print(f"Can't play the move {move_to_expand}. Can't expand the current node {self}")
            return None

        new_child_node = Node(
            gameState  = new_game_state, 
            parent = self, 
            move_that_led_to_this_node=move_to_expand
        )
        
        self.children.append(new_child_node) 
        return new_child_node
        
    def simulate(self): 
        """
        Complete one random playout from the current expanded node. 
        
        Returns:
            str: The winner of the simulation (PLAYER_X, PLAYER_O, or DRAW_MARKER).
        """
        sim_game_state = self.gameState.copy()
        while not sim_game_state.is_game_over(): 
            possible_moves = sim_game_state.find_possible_moves()
            if not possible_moves: 
                if gameStateCopy.winner is None: 
                    return DRAW_MARKER
                break
            chosen_move = random.choice(possible_moves)
            sim_game_state.play(chosen_move[0], chosen_move[1])
        return sim_game_state.winner
        
    def backpropagate(self, simulation_winner): 
        """
        Updates the visit counts and win counts from this node up to the root.
        """
        current_node = self
        while current_node is not None: 
            current_node.visitCount += 1
            score_for_this_node = 0.0

            if current_node.parent is None:
                player_who_just_reached_current_node = current_node.gameState.get_opponent()
                if simulation_winner == player_who_just_reached_current_node: 
                    score_for_this_node += 1.0 
                elif simulation_winner == DRAW_MARKER: 
                    score_for_this_node += 0.5
                else: 
                    score_for_this_node -= 1.0
                    
            current_node.winCount += score_for_this_node
            current_node = current_node.parent
    def print_node(self): 
        """
        Prints the statistic of the node, this includes: 
            current_player: X or O, 
            move: (row, col) or None in case of root node,
            winCount,
            visitCount,
            UBC1,
            is_terminal_node: whether this is a terminal node or not
        """
        print(f"Current player: {self.gameState.current_player} | Visits: {self.visitCount} | Wins: {self.winCount:.2f} | UCB1: {self.get_UBC1():.2f} | Move: {self.move_that_led_to_this_node} | End: {self.is_terminal_node()}")

In [243]:
class MCST: 
    def __init__(self, initial_game = None):
        if initial_game is None: 
            self.rootNode = Node(parent = None, gameState = TicTacToe(board_size = 3))
        else: 
            self.rootNode = Node(parent = None, gameState = initial_game)

    def update(self, iterations): 
        """
        Each time calling this, 
        """ 
        for _ in range(iterations): 
            leaf_node = self.rootNode.select() 

            if leaf_node.gameState.is_game_over():
                 simulation_winner = leaf_node.gameState.winner
                 leaf_node.backpropagate(simulation_winner) 
                 continue

            new_child_node = leaf_node.expand()

            if new_child_node is None:
                simulation_node = leaf_node
            else:
                simulation_node = new_child_node

            simulation_winner = simulation_node.simulate()
            simulation_node.backpropagate(simulation_winner) 
    
            
    def print_tree(self):
        """
        Performs a breadth-first traversal of the MCTS tree and prints each layer.
        For each node, prints its state and children.
        """
        if not self.rootNode:
            print("Tree is empty")
            return
            
        layer = 1
        queue = deque([self.rootNode])
        
        while queue:
            level_size = len(queue)
            print(f"\nLayer {layer}:")
            print("-" * 20)
            
            for _ in range(level_size):
                node = queue.popleft()
                print(f"\nNode:")
                node.print_node()
                print(f"Children: {len(node.children)}")
                for child in node.children:
                    queue.append(child)
            
            layer += 1            

In [245]:
tree = MCST() 
# We need to write a function that build all the node of a MCST and update the winCount and visitCount

In [247]:
tree.update(100)

In [249]:
tree.print_tree()


Layer 1:
--------------------

Node:
Current player: X | Visits: 100 | Wins: 50.00 | UCB1: -inf | Move: None | End: False
Children: 9

Layer 2:
--------------------

Node:
Current player: O | Visits: 12 | Wins: 0.00 | UCB1: 0.88 | Move: (0, 1) | End: False
Children: 8

Node:
Current player: O | Visits: 11 | Wins: 0.00 | UCB1: 0.92 | Move: (0, 0) | End: False
Children: 8

Node:
Current player: O | Visits: 11 | Wins: 0.00 | UCB1: 0.92 | Move: (1, 0) | End: False
Children: 8

Node:
Current player: O | Visits: 11 | Wins: 0.00 | UCB1: 0.92 | Move: (2, 1) | End: False
Children: 8

Node:
Current player: O | Visits: 11 | Wins: 0.00 | UCB1: 0.92 | Move: (0, 2) | End: False
Children: 8

Node:
Current player: O | Visits: 11 | Wins: 0.00 | UCB1: 0.92 | Move: (1, 2) | End: False
Children: 8

Node:
Current player: O | Visits: 11 | Wins: 0.00 | UCB1: 0.92 | Move: (1, 1) | End: False
Children: 8

Node:
Current player: O | Visits: 11 | Wins: 0.00 | UCB1: 0.92 | Move: (2, 0) | End: False
Children: 8

N

In [219]:
len(tree.rootNode.children)

1