# Monte-Carlo Tree Search

Monte Carlo Tree Search (MCTS) is a heuristic search algorithm that grows a search tree guided by random sampling of the search space. It repeatedly runs simulations (playouts) in which moves are chosen at random, starting from the current state. The outcomes of these playouts are used to estimate how good each available action is. When the search budget runs out, the actions are compared and the best one is selected.

This algorithm is widely used to build bots for board games. In this lab, we study MCTS using tic-tac-toe as an example. The game is played on a 5×5 board, and a player needs 4 marks in a row (horizontally, vertically, or diagonally) to win.

In [310]:
from ipywidgets import widgets, HBox, VBox, Layout
from IPython.display import display
from functools import partial
import numpy as np
import numpy.typing as npt
import copy
import math
import time
import random

## Tic-Tac-Toe Domain Representation

In [311]:
# Board geometry: the board is FIELD_SIZE x FIELD_SIZE cells, and
# IN_LINE_TO_WIN consecutive identical marks in a row, column or diagonal win.
FIELD_SIZE = 5
IN_LINE_TO_WIN = 4

# Possible contents of a single cell.
EMPTY_CELL, X_CELL, O_CELL = 0, 1, 2

# Game status codes produced by check_result().
NOT_OVER, X_WINS, O_WINS, DRAW = range(4)

In [312]:
class GameState:
    """
    Implements the state of a tic-tac-toe game.

    Attributes
    ----------
    cells : np.ndarray
        Flat array of FIELD_SIZE * FIELD_SIZE cells, each holding EMPTY_CELL,
        X_CELL or O_CELL.
    next_player : int
        The value (X_CELL or O_CELL) to be written to a cell on the next move.

    Two states compare equal, and hash equally, when their cells match. The
    next player is not part of the comparison.
    """

    def __init__(self):
        """Creates a new empty game field with X to move first."""
        self.cells = np.full((FIELD_SIZE * FIELD_SIZE), EMPTY_CELL)
        self.next_player = X_CELL

    def __hash__(self) -> int:
        # Hash the cell values directly. This is much cheaper than formatting
        # the array as a string and stays consistent with __eq__: equal cells
        # give equal hashes regardless of the integer dtype.
        return hash(tuple(self.cells.tolist()))

    def __eq__(self, other: object) -> bool:
        # For foreign types, let Python fall back to its default comparison
        # instead of failing on a missing `.cells` attribute.
        if not isinstance(other, GameState):
            return NotImplemented
        # array_equal also returns False (rather than broadcasting or raising)
        # when the shapes differ.
        return bool(np.array_equal(self.cells, other.cells))

In [313]:
def check_line_result(line: npt.NDArray) -> int:
    """
    Scans one line of cells for IN_LINE_TO_WIN consecutive identical marks.

    Parameters
    ----------
    line : np.ndarray
        The sequence of cells (row, column or diagonal) to inspect.

    Returns
    -------
    int
        X_WINS or O_WINS for the first player whose run reaches the required
        length while scanning left to right. Otherwise NOT_OVER if the line
        still contains a cell that is neither X nor O, and DRAW if it does not.
    """
    run_owner = None
    run_length = 0
    saw_empty = False
    for cell in line:
        if cell == X_CELL:
            mark = X_CELL
        elif cell == O_CELL:
            mark = O_CELL
        else:
            # Any non-mark cell breaks the current run and means the line is
            # not yet full.
            run_owner, run_length, saw_empty = None, 0, True
            continue
        run_length = run_length + 1 if run_owner == mark else 1
        run_owner = mark
        if run_length == IN_LINE_TO_WIN:
            return X_WINS if mark == X_CELL else O_WINS
    return NOT_OVER if saw_empty else DRAW


def check_result(game_state: GameState) -> int:
    """
    Checks the state of the game for the existence of winning combinations.

    Every row, every column, and every diagonal/anti-diagonal long enough to
    hold IN_LINE_TO_WIN cells is examined with check_line_result.

    Parameters
    ----------
    game_state : GameState
        The state of the game to be checked.

    Returns
    -------
    int
        A value that characterizes the current state of the game:
        - NOT_OVER = 0
        - X_WINS = 1
        - O_WINS = 2
        - DRAW = 3
        The first winning line found (rows, then columns, then diagonals)
        decides the result. NOT_OVER is returned when no line is won and at
        least one line still has a free cell; DRAW otherwise.
    """
    board = game_state.cells.reshape(FIELD_SIZE, FIELD_SIZE)
    flipped = np.fliplr(board)

    lines_to_check = [board[row, :] for row in range(FIELD_SIZE)]
    lines_to_check.extend(board[:, column] for column in range(FIELD_SIZE))

    # Diagonals at offset k have FIELD_SIZE - |k| cells. Only offsets whose
    # diagonals can hold a winning run are checked (none if the board is
    # smaller than IN_LINE_TO_WIN).
    max_offset = FIELD_SIZE - IN_LINE_TO_WIN
    for offset in range(max_offset + 1):
        lines_to_check.append(np.diag(board, offset))
        lines_to_check.append(np.diag(flipped, offset))
        # Offset 0 is the same diagonal for +0 and -0, so only add the
        # negative-offset diagonals when they are distinct.
        if offset:
            lines_to_check.append(np.diag(board, -offset))
            lines_to_check.append(np.diag(flipped, -offset))

    not_over = False
    for line in lines_to_check:
        result = check_line_result(line)
        if result == NOT_OVER:
            not_over = True
        elif result != DRAW:
            return result

    return NOT_OVER if not_over else DRAW


def get_possible_moves(game_state: GameState) -> npt.NDArray:
    """
    Returns the positions of all empty cells, i.e. the legal moves.

    Parameters
    ----------
    game_state : GameState
        The state of the game to check.

    Returns
    -------
    np.ndarray
        Indices of the empty cells in the flat cell array, in ascending order.
        Each index lies in the range 0 .. FIELD_SIZE * FIELD_SIZE - 1
        (0 to 24 for the 5x5 board).
    """
    return np.where(game_state.cells == EMPTY_CELL)[0]


def game_is_over(game_state: GameState) -> bool:
    """
    Tells whether the game has finished, by a win for either player or a draw.

    Parameters
    ----------
    game_state : GameState
        The state of the game to inspect.

    Returns
    -------
    bool
        False only while check_result reports NOT_OVER; True otherwise.
    """
    status = check_result(game_state)
    return status != NOT_OVER


def play_move(game_state: GameState, move: int) -> "GameState | None":
    """
    Executes the move of the next player to the cell at the specified 'move' position.
    Note that information about the player (`X` or `O`) who is next to move is included within the game state.

    The input state is never modified; a new state is returned.

    Parameters
    ----------
    game_state : GameState
        The current state of the game.
    move : int
        The position of the cell for the move (index into the flat cell array).

    Returns
    -------
    GameState | None
        The new game state after the move, or None if the target cell is occupied.
    """
    if game_state.cells[move] != EMPTY_CELL:
        return None
    # A shallow copy plus an explicit array copy is enough to keep the input
    # untouched: `cells` is the only mutable attribute, and `next_player` is a
    # plain int. This avoids the much slower deepcopy machinery on the hot
    # MCTS rollout path.
    next_state = copy.copy(game_state)
    next_state.cells = game_state.cells.copy()
    next_state.cells[move] = game_state.next_player
    next_state.next_player = O_CELL if game_state.next_player == X_CELL else X_CELL
    return next_state

## AI Policy (How the AI Selects a Move)

In this lab, you are tasked with developing a policy for playing tic-tac-toe. Your main task is to create a strategy that selects the best move for a given game state through the `get_move` method.

In [314]:
class BasePolicy:
    """Interface for move-selection strategies; concrete policies override `get_move`."""

    def get_move(self, state: GameState) -> int:
        """
        Chooses the move the AI player should make in the given state.

        Parameters
        ----------
        state : GameState
            The game state in which the AI player is to move.

        Returns
        -------
        int
            Index of the chosen cell in the flat cell array of the game grid.

        Raises
        ------
        NotImplementedError
            Always; subclasses must provide the actual strategy.
        """
        raise NotImplementedError

A straightforward example of such a policy is random move selection, as shown in the `RandomPolicy` class: the move is chosen uniformly at random from the moves available in the current game state.


In [315]:
class RandomPolicy(BasePolicy):
    """Baseline policy that plays a uniformly random legal move."""

    def get_move(self, state: GameState) -> int:
        """
        Returns a random move that is permissible from the current state.

        Parameters
        ----------
        state : GameState
            The current state of the game.

        Returns
        -------
        int
            A randomly selected permissible move, represented as an integer corresponding to a cell position in the game grid.
        """
        possible_moves = get_possible_moves(state)
        # Convert the numpy scalar to a plain int to honour the BasePolicy contract.
        return int(random.choice(possible_moves))

## Tic-Tac-Toe GUI Implementation

To test the developed policy, we use an implementation of the game with a graphical user interface. Since the GUI is not the focus of this lab, we do not go into its implementation in detail. A detailed understanding of the `GUIGame` class is not necessary to complete the lab.

In [316]:
class GUIGame:
    """
    This class implements a tic-tac-toe game with a graphical user interface (GUI).
    It has a bot as the second player, allowing a human player to interact with the game through the GUI.
    """

    def __init__(self, ai_player: BasePolicy):
        """
        Creates a virtual game instance with an AI-controlled second player.

        Parameters
        ----------
        ai_player : BasePolicy
            An object that implements the behavior of the AI player.
            This object should define how the AI player makes decisions and responds to the game state.
        """
        self._ai_player = ai_player
        self._button_list = []
        self._text_box = widgets.Text(
            value="Move: X", layout=Layout(width="129px", height="40px")
        )
        self._size = FIELD_SIZE
        # No game is in progress until start_game() is called.
        self._state = None

        for i in range(self._size**2):
            button = widgets.Button(
                description="",
                disabled=False,
                button_style="",
                tooltip="",
                icon="",
                layout=Layout(width="40px", height="40px"),
            )
            # Bind the flat cell index so the handler knows which cell was clicked.
            button.on_click(partial(self._on_button_clicked, i))
            self._button_list.append(button)

        # One HBox per board row, stacked vertically.
        self._tic_tac_toe_field = VBox(
            [
                HBox(self._button_list[i * self._size : (i + 1) * self._size])
                for i in range(self._size)
            ]
        )
        display(VBox([self._text_box, self._tic_tac_toe_field]))

    def start_game(self, computer_starts=True):
        """
        Launches (or restarts) the game.

        Parameters
        ----------
        computer_starts : bool
            Indicates whether the AI player makes the first move.
            If set to True, the AI player starts the game; if False, the human player starts.
        """
        self._state = GameState()
        if computer_starts:
            self._state = play_move(self._state, self._ai_player.get_move(self._state))
        # Always redraw: this clears marks left over from a previous game on
        # the same widget and resets the "Move:" label.
        self._update_field()

    def _update_field(self):
        """Redraws every cell button and the 'Move:' label from the current state."""
        self._text_box.value = "Move: " + (
            "X" if self._state.next_player == X_CELL else "O"
        )
        for cell, button in zip(self._state.cells, self._button_list):
            if cell == X_CELL:
                button.description = "x"
            elif cell == O_CELL:
                button.description = "o"
            else:
                # Explicitly blank empty cells so a restarted game starts clean.
                button.description = ""

    def _update_status(self):
        """Shows the final outcome; intended to be called only once the game is over."""
        result = check_result(self._state)
        if result == X_WINS:
            self._text_box.value = "X wins"
        elif result == O_WINS:
            self._text_box.value = "O wins"
        else:
            self._text_box.value = "Draw"

    def _on_button_clicked(self, index, button):
        """Plays the human move at `index`, then lets the AI answer."""
        # Ignore clicks that arrive before start_game() has created a state.
        if self._state is None:
            return
        if game_is_over(self._state):
            self._update_status()
            return

        if self._state.cells[index] != EMPTY_CELL:
            self._text_box.value = "Incorrect move!"
            return

        self._state = play_move(self._state, index)
        self._update_field()
        # Brief pause between showing the human move and computing the AI reply.
        time.sleep(0.1)
        if game_is_over(self._state):
            self._update_status()
            return

        ai_move = self._ai_player.get_move(self._state)
        self._state = play_move(self._state, ai_move)
        self._update_field()

        if game_is_over(self._state):
            self._update_status()

In [317]:
# Play against the random baseline; the AI (X) makes the first move.
g = GUIGame(ai_player=RandomPolicy())
g.start_game(computer_starts=True)

VBox(children=(Text(value='Move: X', layout=Layout(height='40px', width='129px')), VBox(children=(HBox(childre…

## Monte-Carlo Tree Search Policy

In this section, you will be tasked with implementing the Monte-Carlo Tree Search (MCTS) algorithm. 

In [318]:
class Node:
    """
    Node class implements a node of a search tree.

    Attributes
    ----------
    state : GameState
        The corresponding game state of this node.
    visits : int
        The number of times this node has been visited.
    wins : int
        The number of wins recorded in the subtree rooted at this node.
    defeats : int
        The number of defeats recorded in the subtree rooted at this node.
    draws : int
        The number of draws recorded in the subtree rooted at this node.
    parent : Node | None
        The parent node of this node (None for the root).
    successors : dict[GameState, tuple[Node, int]]
        Maps each successor's game state to the successor node and the move
        that leads to it.
    """

    def __init__(self, state=None, size=3):
        # `size` is accepted only for backward compatibility and is ignored:
        # GameState() takes no arguments and always builds a
        # FIELD_SIZE x FIELD_SIZE board. Passing it used to raise TypeError.
        self.state = GameState() if state is None else state

        self.visits = 0
        self.wins = 0
        self.defeats = 0
        self.draws = 0
        self.parent = None
        self.successors = dict()

    def add_successor(self, successor, move):
        """Registers `successor` (reached by playing `move`) under its state."""
        self.successors[successor.state] = (successor, move)

Note that the move finally returned by the MCTS algorithm can be chosen in several ways. In this work, you will implement two of them:

- **Max Reward**: select the successor of the root with the highest average reward (wins plus half of the draws, divided by the number of visits).
- **Robust**: select the successor of the root that was visited most often during the MCTS simulations.


In [319]:
# Criteria for choosing the final move among the root's successors.
MAX_REWARD, ROBUST = 0, 1

# Exploration weight applied to the UCT bonus term during tree descent.
CP = math.sqrt(2)

In [320]:
class MCTSPolicy(BasePolicy):
    """
    Implements an AI player for the tic-tac-toe game using the Monte Carlo Tree Search (MCTS) algorithm.

    Attributes
    ----------
    max_time : float
        The time limit for making a move using the MCTS algorithm, specified in seconds.
        This attribute defines how long the AI player will spend on calculating the best move.
        At least one MCTS iteration is always run, even if max_time <= 0.
    final_move_criterion : int
        The criterion used for choosing the final move in the MCTS algorithm. It can take one of the predefined values:
        - MAX_REWARD = 0: Selects the move with the maximum average reward.
        - ROBUST = 1: Selects the most visited successor.
    """

    def __init__(self, max_time: float = 10, final_move_criterion: int = ROBUST):
        self.max_time = max_time
        self.final_move_criterion = final_move_criterion

    def get_move(self, state: GameState) -> int:
        """
        Finds the best move for the MCTS player from the current state.

        Parameters
        ----------
        state : GameState
            The current state of the game.

        Returns
        -------
        int
            The position of the cell where the MCTS player should make its move, represented as an integer.
            This integer corresponds to a specific cell position in the game grid.

        Raises
        ------
        ValueError
            If the game in `state` is already over, so no move exists.
        """
        if game_is_over(state):
            raise ValueError("Cannot choose a move: the game is already over.")

        root_node = self._create_node(state)
        deadline = time.monotonic() + self.max_time

        # Loop until the deadline, but run at least one iteration so the root
        # always has a successor to return. Otherwise the fallback move -1
        # would silently address the last cell.
        while True:
            leaf = self._tree_policy(root_node)
            playout_result = self._default_policy(leaf.state)
            self._backpropagation(leaf, root_node, playout_result)
            if time.monotonic() >= deadline:
                break

        _, best_move = self._best_successor_and_move(
            root_node, self.final_move_criterion, 0
        )
        return int(best_move)

    def _tree_policy(self, node):
        """Descends by UCT through fully expanded nodes and expands the first node that is not fully expanded."""
        while not game_is_over(node.state):
            expanded, untried_moves = self._is_full_expanded(node)
            if not expanded:
                return self._expand(node, untried_moves)
            node, _ = self._best_successor_and_move(node, MAX_REWARD, CP)
        return node

    def _default_policy(self, state):
        """Plays uniformly random moves from `state` until the game ends; returns the final status."""
        curr = state
        while not game_is_over(curr):
            move = random.choice(get_possible_moves(curr))
            curr = play_move(curr, move)
        return check_result(curr)

    def _backpropagation(self, node, root_node, playout_result):
        """
        Propagates a playout result from `node` up to `root_node`.

        Statistics at each node are recorded from the point of view of the
        player who made the move leading into that node, i.e. the opponent of
        `node.state.next_player`.
        """
        # Map the game status to the winning mark explicitly instead of
        # relying on X_WINS/O_WINS happening to equal X_CELL/O_CELL.
        if playout_result == X_WINS:
            winner = X_CELL
        elif playout_result == O_WINS:
            winner = O_CELL
        else:
            winner = None

        while node is not None:
            node.visits += 1
            if playout_result == DRAW:
                node.draws += 1
            elif node.state.next_player == winner:
                # The player to move here won, so the player who moved in lost.
                node.defeats += 1
            else:
                node.wins += 1
            if node is root_node:
                break
            node = node.parent

    def _best_successor_and_move(
        self, parent_node, criterion=MAX_REWARD, exploration_term=0
    ):
        """Returns (node, move) of the best successor; (None, -1) if there are none."""
        max_value = -1
        best_move = -1
        best_node = None

        for successor_node, move in parent_node.successors.values():
            if criterion == ROBUST:
                successor_value = successor_node.visits
            else:
                # MAX_REWARD and any unknown criterion use the UCT value.
                successor_value = self._calculate_value(
                    successor_node, exploration_term, parent_node
                )

            if successor_value > max_value:
                max_value = successor_value
                best_move = move
                best_node = successor_node

        return best_node, best_move

    def _calculate_value(self, node, exploration_term, parent_node):
        """Average reward (draw = 0.5) plus a UCT exploration bonus scaled by `exploration_term`."""
        if (node is None) or (node.visits == 0):
            return 0

        return (
            node.wins + 0.5 * node.draws
        ) / node.visits + exploration_term * math.sqrt(
            2 * math.log(parent_node.visits) / node.visits
        )

    def _create_node(self, state):
        return Node(state)

    def _is_full_expanded(self, node):
        """
        Returns (fully_expanded, untried_moves) for `node`.

        Tried moves are read from the existing successors, so no child state
        has to be built just to test membership.
        """
        tried_moves = {move for _, move in node.successors.values()}
        untried = [
            move for move in get_possible_moves(node.state) if move not in tried_moves
        ]
        return not untried, untried

    def _expand(self, node, untried):
        """Adds one randomly chosen untried move of `node` as a new child and returns it."""
        move = random.choice(untried)
        new_node = self._create_node(play_move(node.state, move))
        new_node.parent = node
        node.add_successor(new_node, move)
        return new_node

## Let's Play the Game!

In [321]:
# MCTS opponent with a 4-second budget per move, choosing the most-visited move;
# the human makes the first move.
mcts_player = MCTSPolicy(max_time=4, final_move_criterion=ROBUST)
g = GUIGame(ai_player=mcts_player)
g.start_game(computer_starts=False)

VBox(children=(Text(value='Move: X', layout=Layout(height='40px', width='129px')), VBox(children=(HBox(childre…