In [1]:
from gomoku.gomoku_env import GomokuEnv
import random
import math

class TreeNode:
    """One node of the Monte Carlo search tree.

    A node wraps a snapshot of the game environment together with the
    statistics gathered during back-propagation (``visits`` / ``wins``) and
    the links to its parent and to the children reached by each action.
    """

    def __init__(self, env: GomokuEnv):
        """Create an unvisited, childless, parentless node wrapping ``env``."""
        self.env = env
        self.visits = 0
        self.wins = 0
        self.children = {}  # action -> child TreeNode
        self.parent = None

    def add_child(self, action, child_node):
        """Attach ``child_node`` as the node reached by playing ``action`` here.

        :param action: Key under which the child is stored in ``self.children``.
        :param child_node: The child; its ``parent`` link is set to this node.
        """
        child_node.parent = self
        self.children.update({action: child_node})

    def update(self, result):
        """Record the outcome of one simulation that passed through this node.

        :param result: The winner reported by the playout. It counts as a win
            only when it equals ``self.env.current_player``.
        """
        self.visits += 1
        won = result == self.env.current_player
        if won:
            self.wins += 1

    def is_fully_expanded(self):
        """
        Tell whether every valid action from this state already has a child.

        The check compares counts: number of children versus number of
        actions returned by ``self.env.get_valid_actions()``.

        :return: True if fully expanded, False otherwise.
        """
        child_count = len(self.children)
        valid_count = len(self.env.get_valid_actions())
        return child_count == valid_count

In [2]:
from numpy import choose


class MCTS:
    """Monte Carlo Tree Search with UCT selection and uniformly random playouts.

    Every call to :meth:`run` builds a new tree. The root of the most recent
    search is kept in ``self.root`` so its statistics can be inspected later.
    """

    def __init__(self, strategy, c=1.41):
        """
        :param strategy: Label for the expansion/rollout strategy. It is only
            stored: expansion and simulation below are always uniformly random.
        :param c: Exploration constant of the UCT formula.
        """
        self.strategy = strategy
        # Backward-compatible alias for the original (misspelled) attribute name.
        self.stratey = strategy
        self.c = c  # Exploration constant for UCT

        self.root = None  # Root node of the most recent search tree

    def run(self, env, iterations):
        """
        Run the Monte Carlo Tree Search algorithm for a given number of iterations.

        :param env: Environment describing the current game state. The search
            works on ``env.clone()`` rather than on ``env`` itself.
        :param iterations: Number of select/expand/simulate/backpropagate cycles.
        :return: The most-visited action from the current state.
        :raises ValueError: If the root ends up with no children, e.g. when
            ``iterations <= 0`` or ``env`` is already terminal.
        """
        self.root = TreeNode(env.clone())
        for _ in range(iterations):
            # 1. Selection: descend through fully expanded nodes by UCT value.
            node = self._select(self.root)

            # 2. Expansion: add one untried child unless the game is over.
            if not node.env._is_terminal() and not node.is_fully_expanded():
                node = self._expansion(node)

            # 3. Simulation: random playout from the (possibly new) node.
            result = self._simulate(node)

            # 4. Backpropagation: update statistics from the node up to the root.
            self._backpropagation(node, result)

        return self._best_action(self.root)

    def _select(self, root: "TreeNode"):
        """
        Descend from ``root`` to the first node that is terminal or still has
        untried actions. At each step, follow the child with the highest UCT value.

        :param root: The node to start the descent from.
        :return: The selected node.
        """
        node = root
        while not node.env._is_terminal() and node.is_fully_expanded():
            node = max(node.children.values(), key=self._uct_value)
        return node

    def _expansion(self, node):
        """
        Add one previously untried child to ``node`` and return it.

        Only a uniformly random choice among untried actions is implemented.
        A policy-network-guided choice is a possible future alternative.

        :param node: A non-terminal node that is not yet fully expanded.
        :return: The newly created child node.
        """
        untried_actions = [
            action for action in node.env.get_valid_actions()
            if action not in node.children  # skip actions that already have a child
        ]
        chosen_action = random.choice(untried_actions)

        child_env = node.env.clone()
        child_env.step(chosen_action)

        child_node = TreeNode(child_env)
        node.add_child(chosen_action, child_node)
        return child_node

    def _simulate(self, node):
        """
        Play uniformly random moves on a clone of ``node.env`` until the game ends.

        :param node: The node whose state the playout starts from. Its own
            environment is not stepped; only the clone is.
        :return: The ``winner`` attribute of the finished playout environment.
        """
        current_env = node.env.clone()
        while not current_env._is_terminal():
            action = random.choice(current_env.get_valid_actions())
            current_env.step(action)
        return current_env.winner

    def _backpropagation(self, node, result):
        """
        Call ``update(result)`` on ``node`` and on every ancestor up to the root.

        :param node: The node the simulation started from.
        :param result: The playout outcome returned by :meth:`_simulate`.
        """
        current_node = node
        while current_node is not None:
            current_node.update(result)
            current_node = current_node.parent

    def _best_action(self, root: "TreeNode"):
        """
        Pick the action whose child of ``root`` received the most visits.

        :param root: The root node of the MCTS tree.
        :return: The best action to take from the root state.
        :raises ValueError: If ``root`` has no children to choose from.
        """
        if not root.children:
            raise ValueError(
                "MCTS root has no expanded children; "
                "run at least one iteration from a non-terminal state"
            )
        return max(root.children.items(), key=lambda item: item[1].visits)[0]

    def _uct_value(self, child: "TreeNode"):
        """
        UCT score of ``child``, used by its parent to choose which child to descend into.

        ``child.wins`` is counted by ``TreeNode.update`` for the player stored in
        ``child.env.current_player``. That player is presumably the opponent of
        the side that moved into ``child`` (TODO confirm against GomokuEnv.step),
        which is why the exploitation term is negated.

        :param child: A child node with ``visits >= 1`` and a parent.
        :return: The UCT value of the child node.
        """
        exploitation = -child.wins / child.visits
        exploration = self.c * math.sqrt(math.log(child.parent.visits) / child.visits)
        return exploitation + exploration

In [3]:
from gomoku.gomoku_env import GomokuEnvSimple
# from gomoku.mcts import MCTS


# Create the game environment used by the self-play cell below.
env = GomokuEnvSimple()

In [4]:
# NOTE(review): this instance is never used as-is. The self-play cell creates a
# fresh MCTS for every move, and the later analysis cell rebinds `mcts` too.
mcts = MCTS(strategy="random", c=1.41)  # Initialize MCTS with a strategy label and UCT exploration constant

In [5]:
# Self-play: MCTS chooses every move, for both sides, until the game is over.
# NOTE(review): no random seed is set, so the game differs from run to run.
env.reset()
while not env._is_terminal():
    mcts = MCTS(strategy="random", c=1.41)  # Fresh search tree for every move (no tree reuse)
    action = mcts.run(env, iterations=8000)  # 8000 search iterations to choose this move
    env.step(action)  # Play the chosen move on the real board
    env.render()  # Render the board after each move

   0  1  2  3  4  5  6  7  8
 0 .  .  .  .  .  .  .  .  . 
 1 .  .  .  .  .  .  .  .  . 
 2 .  .  .  .  X  .  .  .  . 
 3 .  .  .  .  .  .  .  .  . 
 4 .  .  .  .  .  .  .  .  . 
 5 .  .  .  .  .  .  .  .  . 
 6 .  .  .  .  .  .  .  .  . 
 7 .  .  .  .  .  .  .  .  . 
 8 .  .  .  .  .  .  .  .  . 

   0  1  2  3  4  5  6  7  8
 0 .  .  .  .  .  .  .  .  . 
 1 .  .  .  .  .  .  .  .  . 
 2 .  .  .  .  X  .  .  .  . 
 3 .  .  .  .  .  .  .  .  . 
 4 .  .  O  .  .  .  .  .  . 
 5 .  .  .  .  .  .  .  .  . 
 6 .  .  .  .  .  .  .  .  . 
 7 .  .  .  .  .  .  .  .  . 
 8 .  .  .  .  .  .  .  .  . 

   0  1  2  3  4  5  6  7  8
 0 .  .  .  .  .  .  .  .  . 
 1 .  .  .  .  .  .  .  .  . 
 2 .  .  .  .  X  .  .  .  . 
 3 .  .  .  .  .  .  .  .  . 
 4 .  .  O  .  .  .  .  .  . 
 5 .  .  .  .  .  .  .  .  . 
 6 .  X  .  .  .  .  .  .  . 
 7 .  .  .  .  .  .  .  .  . 
 8 .  .  .  .  .  .  .  .  . 

   0  1  2  3  4  5  6  7  8
 0 .  .  .  .  .  .  .  .  . 
 1 .  .  .  .  .  .  .  .  . 
 2 .  .  O 

In [6]:
# Hand-crafted position on a fresh environment:
# X (value 1) holds row 0, columns 0-3; O (value 2) holds row 1, columns 0-3.
# NOTE(review): only `board` is written, so whose turn it is comes from the
# environment's initial state. Presumably X is to move, making (0, 4) the
# winning reply; confirm against GomokuEnvSimple.
env = GomokuEnvSimple()
env.board[0, :4] = 1
env.board[1, :4] = 2
env.render()

   0  1  2  3  4  5  6  7  8
 0 X  X  X  X  .  .  .  .  . 
 1 O  O  O  O  .  .  .  .  . 
 2 .  .  .  .  .  .  .  .  . 
 3 .  .  .  .  .  .  .  .  . 
 4 .  .  .  .  .  .  .  .  . 
 5 .  .  .  .  .  .  .  .  . 
 6 .  .  .  .  .  .  .  .  . 
 7 .  .  .  .  .  .  .  .  . 
 8 .  .  .  .  .  .  .  .  . 



In [7]:
# Search the hand-crafted position. `mcts.root` keeps the tree for the analysis cells below.
mcts = MCTS(strategy="random", c=1.41)  # Initialize MCTS with a strategy label and UCT exploration constant
action = mcts.run(env, iterations=3000)  # Run 3000 iterations and return the most-visited action

In [8]:
# Flat action index; the analysis cell below decodes actions as (action // 9, action % 9).
action

4

In [9]:
# For each child of the search root: [(row, col), wins / visits], where 9 is the board width.
# NOTE(review): TreeNode.update credits wins to child.env.current_player and
# MCTS._uct_value negates child.wins. This ratio is therefore presumably the
# *opponent's* win rate after each candidate move (lower is better for the side
# to play). Confirm against GomokuEnv.step before reading it as the mover's win rate.
judges = [[(item[0] // 9, item[0] % 9),(item[1].wins / item[1].visits)] for item in mcts.root.children.items()]

In [10]:
# Rank candidate moves by their wins / visits ratio, highest first (stable for ties).
# NOTE(review): per TreeNode.update this ratio is presumably the opponent's win
# rate, so the top entries are likely the weakest moves for the side to play.
judges = sorted(judges, key=lambda entry: entry[1], reverse=True)

In [11]:
# Display the ranked list of [(row, col), wins / visits] pairs.
judges

[[(6, 2), 0.8666666666666667],
 [(2, 0), 0.8235294117647058],
 [(0, 6), 0.8235294117647058],
 [(5, 8), 0.7777777777777778],
 [(3, 2), 0.7142857142857143],
 [(3, 8), 0.7142857142857143],
 [(2, 6), 0.7142857142857143],
 [(7, 1), 0.6818181818181818],
 [(5, 0), 0.6818181818181818],
 [(4, 2), 0.6818181818181818],
 [(5, 2), 0.6818181818181818],
 [(7, 2), 0.6666666666666666],
 [(8, 6), 0.64],
 [(4, 4), 0.6153846153846154],
 [(1, 6), 0.6153846153846154],
 [(7, 3), 0.6153846153846154],
 [(7, 4), 0.6153846153846154],
 [(6, 7), 0.6153846153846154],
 [(3, 5), 0.5925925925925926],
 [(5, 7), 0.5925925925925926],
 [(6, 0), 0.5925925925925926],
 [(7, 7), 0.5925925925925926],
 [(3, 3), 0.5925925925925926],
 [(7, 6), 0.5666666666666667],
 [(4, 5), 0.5666666666666667],
 [(3, 6), 0.5666666666666667],
 [(6, 3), 0.5666666666666667],
 [(1, 5), 0.5483870967741935],
 [(0, 8), 0.5483870967741935],
 [(5, 6), 0.5483870967741935],
 [(5, 3), 0.53125],
 [(2, 2), 0.53125],
 [(8, 1), 0.53125],
 [(8, 3), 0.53125],
 [(4