# Monte Carlo Tree Search (MCTS)

In [2]:
import math
import random

class Node:
    """A node of the Monte Carlo search tree.

    Attributes:
        state: The game state this node represents; must provide
            ``possible_moves()`` (used by :meth:`fully_expanded`).
        parent: The parent Node, or ``None`` for the root.
        children: Child Nodes created so far, in creation order.
        visits: Number of rollouts recorded through this node.
        wins: Sum of the rollout results recorded through this node.
    """

    def __init__(self, state, parent=None):
        self.state = state
        self.parent = parent
        self.children = []
        self.visits = 0
        self.wins = 0

    def add_child(self, child_state):
        """Create a child Node for *child_state*, attach it, and return it."""
        child = Node(child_state, self)
        self.children.append(child)
        return child

    def update(self, result):
        """Record one rollout: increment ``visits`` and add *result* to ``wins``."""
        self.visits += 1
        self.wins += result

    def fully_expanded(self):
        """Return True once there is one child per entry of ``possible_moves()``.

        A state whose ``possible_moves()`` is empty is trivially fully expanded.
        """
        return len(self.children) == len(self.state.possible_moves())

    def best_child(self, c_param=1.4):
        """Return the child with the highest UCT score.

        score = wins / visits + c_param * sqrt(2 * ln(parent visits) / visits)

        The ``2`` inside the square root means the effective exploration weight
        is ``c_param * sqrt(2)``: ``c_param=1`` gives textbook UCB1 and
        ``c_param=0`` picks purely by average result.  A child that has never
        been visited scores ``+inf`` so it is tried before any visited child is
        revisited (the usual UCB convention).  Ties go to the earliest-created
        child.

        Raises:
            ValueError: if this node has no children.
        """
        if not self.children:
            raise ValueError("best_child() called on a node with no children")
        # Guard the log: math.log(0) raises.  When the parent has zero visits
        # its children are expected to be unvisited too (scored +inf below),
        # so the exploration term is not reached in that case.
        log_parent = math.log(self.visits) if self.visits > 0 else 0.0

        def uct(child):
            if child.visits == 0:
                return math.inf
            exploit = child.wins / child.visits
            explore = c_param * math.sqrt(2 * log_parent / child.visits)
            return exploit + explore

        # max() returns the first maximal element, matching the original
        # "first index of the max weight" tie-breaking.
        return max(self.children, key=uct)

class MCTS:
    """Monte Carlo Tree Search with UCT selection and uniform random rollouts.

    States passed to :meth:`search` must provide ``possible_moves()`` (a list
    of successor states), ``is_terminal()`` and ``result()`` (a number that is
    added to the ``wins`` of every node on the rollout's path).

    Args:
        iterations: Number of select/expand/simulate/backpropagate rounds
            performed by each call to :meth:`search`.
    """

    def __init__(self, iterations=1000):
        self.iterations = iterations
        # Node -> successor states not yet turned into children.  Tracking the
        # remaining moves explicitly avoids comparing states, which is a no-op
        # for classes without __eq__ and crashes for classes whose successors
        # compare equal.
        self._untried = {}

    def search(self, initial_state):
        """Run the search from *initial_state* and return the chosen child Node.

        The final pick disables exploration (``c_param=0``), so it is the root
        child with the highest average result ``wins / visits``.

        Raises:
            ValueError: if the root ends up with no children, i.e. the initial
                state offers no moves or ``iterations`` < 1.
        """
        root = Node(initial_state)
        self._untried = {}
        try:
            for _ in range(self.iterations):
                node = self._select(root)
                if not node.fully_expanded():
                    node = self._expand(node)
                result = self._simulate(node)
                self._backpropagate(node, result)
        finally:
            # Release references to unexpanded states once the search is done.
            self._untried = {}
        if not root.children:
            raise ValueError(
                "MCTS.search: no move to choose (initial state has no moves "
                "or iterations < 1)"
            )
        return root.best_child(c_param=0)

    def _select(self, node):
        """Descend via ``best_child`` while the current node is fully expanded.

        Stops at the first node with untried moves, or at a node that has no
        children (e.g. one whose state has no moves).
        """
        while node.fully_expanded() and node.children:
            node = node.best_child()
        return node

    def _expand(self, node):
        """Attach one not-yet-expanded successor of *node*, chosen at random.

        Raises:
            ValueError: if *node* has no untried moves left.
        """
        untried = self._untried.get(node)
        if untried is None:
            # First expansion of this node: snapshot its successors once.
            untried = list(node.state.possible_moves())
            self._untried[node] = untried
        if not untried:
            raise ValueError("cannot expand: node has no untried moves")
        new_state = untried.pop(random.randrange(len(untried)))
        return node.add_child(new_state)

    def _simulate(self, node):
        """Play uniformly random moves from *node*'s state until terminal.

        Returns the terminal state's ``result()``.
        """
        current_state = node.state
        while not current_state.is_terminal():
            current_state = random.choice(current_state.possible_moves())
        return current_state.result()

    def _backpropagate(self, node, result):
        """Record *result* on *node* and every ancestor up to the root.

        NOTE(review): the same *result* is added at every depth, i.e. it is
        scored from a single player's perspective.  Two-player games usually
        need the result flipped per ply — confirm this matches the game used.
        """
        while node is not None:
            node.update(result)
            node = node.parent

class SimpleGameState:
    """Toy game state: each move decrements ``moves_left`` by one.

    A state with ``moves_left = k > 0`` offers ``k`` successor states, all with
    ``moves_left = k - 1``.  A state with no moves left is terminal.

    Rollouts that start from a non-negative ``moves_left`` always end at
    ``moves_left == 0``, so ``result()`` is 1 for every such rollout.
    """

    def __init__(self, moves_left):
        self.moves_left = moves_left

    def __repr__(self):
        return f"{type(self).__name__}(moves_left={self.moves_left!r})"

    def possible_moves(self):
        """Return ``moves_left`` fresh successors (empty if ``moves_left <= 0``)."""
        return [SimpleGameState(self.moves_left - 1) for _ in range(self.moves_left)]

    def is_terminal(self):
        """Return True when no moves remain.

        Uses ``<= 0`` rather than ``== 0`` so a negative count, which has no
        possible moves, is terminal instead of letting a rollout call
        ``random.choice`` on an empty move list.
        """
        return self.moves_left <= 0

    def result(self):
        """Return 1 if ``moves_left`` is even, else 0."""
        return 1 if self.moves_left % 2 == 0 else 0

# Demo: run 1000 MCTS iterations on a 10-move toy game and report how many
# moves remain in the child state the search picks.
initial_state = SimpleGameState(moves_left=10)
mcts = MCTS(iterations=1000)
best_move = mcts.search(initial_state=initial_state)
print("Best move has {} moves left.".format(best_move.state.moves_left))

Best move has 9 moves left.
