# In [None]:
import random
import math

class MCTSNode:
    """A node of a Monte Carlo Tree Search tree.

    Attributes:
        game_state: The game state object this node was created with. It
            must provide ``get_legal_moves()``.
        parent: The parent ``MCTSNode``, or ``None`` for the root.
        move: The move that led from the parent to this node (``None`` for
            the root).
        children: Child nodes expanded so far.
        wins: Accumulated result values passed to :meth:`update`.
        visits: Number of times :meth:`update` has been called.
        untried_moves: Legal moves that have not been expanded into children.
    """

    def __init__(self, game_state, parent=None, move=None):
        self.game_state = game_state
        self.parent = parent
        self.move = move
        self.children = []
        self.wins = 0
        self.visits = 0
        # Copy the move list: add_child() removes entries from it, and that
        # must never mutate a list owned by the game state itself.
        self.untried_moves = list(game_state.get_legal_moves())

    def select_child(self):
        """Return the child with the highest UCB1 score.

        The score is ``wins / visits + sqrt(2 * ln(parent_visits) / visits)``.
        A child that has never been visited gets an infinite score so it is
        always tried before any visited sibling.

        Raises:
            IndexError: If this node has no children.
        """
        if not self.children:
            raise IndexError("select_child() called on a node with no children")

        log_parent_visits = math.log(self.visits) if self.visits > 0 else 0.0

        def ucb1(child):
            if child.visits == 0:
                return math.inf
            exploitation = child.wins / child.visits
            exploration = math.sqrt(2 * log_parent_visits / child.visits)
            return exploitation + exploration

        # UCB1 picks the *maximum* score; taking the minimum would steer the
        # search towards the least promising child.
        return max(self.children, key=ucb1)

    def add_child(self, move, game_state):
        """Create a child for ``move``, mark the move as tried, and return it.

        Args:
            move: A move currently present in ``untried_moves``.
            game_state: The state stored on the new child node.

        Raises:
            ValueError: If ``move`` is not in ``untried_moves``.
        """
        child_node = MCTSNode(game_state, parent=self, move=move)
        self.untried_moves.remove(move)
        self.children.append(child_node)
        return child_node

    def update(self, result):
        """Record one visit and add ``result`` to the accumulated wins."""
        self.visits += 1
        self.wins += result

    def is_fully_expanded(self):
        """Return True when no untried moves remain on this node."""
        return not self.untried_moves

def MCTS(root_state, iterations):
    """Run Monte Carlo Tree Search from ``root_state`` and return a move.

    Each iteration performs selection, expansion, a random rollout and
    backpropagation. ``root_state`` itself is never mutated; every iteration
    works on a clone.

    Args:
        root_state: Game state to search from. It must provide ``clone()``,
            ``do_move(move)``, ``get_legal_moves()``, ``get_result(player)``
            and a ``next_player`` attribute.
        iterations: Number of search iterations to run.

    Returns:
        The move of the most-visited child of the root, or ``None`` when the
        root has no children (no legal moves, or ``iterations`` <= 0).
    """
    root_node = MCTSNode(game_state=root_state)

    for _ in range(iterations):
        node = root_node
        state = root_state.clone()

        # Selection: descend while the node is fully expanded and has children.
        while node.is_fully_expanded() and node.children:
            node = node.select_child()
            state.do_move(node.move)

        # Expansion: add one child for a randomly chosen untried move.
        if node.untried_moves:
            move = random.choice(node.untried_moves)
            state.do_move(move)
            # Store a clone on the child: `state` is mutated by the rollout
            # below, and backpropagation later reads next_player from the
            # states stored on the nodes.
            node = node.add_child(move, state.clone())

        # Simulation: play random moves until no legal move remains.
        legal_moves = state.get_legal_moves()
        while legal_moves:
            state.do_move(random.choice(legal_moves))
            legal_moves = state.get_legal_moves()

        # Backpropagation: score each node from the point of view of the
        # player to move in its parent's state.
        while node is not None:
            if node.parent is None:
                # The root has no parent state to read a player from; only
                # its visit count is used (by UCB1), so record a bare visit.
                node.update(0)
            else:
                node.update(state.get_result(node.parent.game_state.next_player))
            node = node.parent

    if not root_node.children:
        return None
    return max(root_node.children, key=lambda c: c.visits).move
