# Import packages

In [12]:
import copy
import gymnasium as gym
import numpy as np
from datetime import datetime, timedelta

# Config
Configuration is kept in module-level global variables for simplicity.

In [33]:
# Gymnasium environment id passed to gym.make for every episode.
ENV_NAME = "CartPole-v1"
# Number of episodes played by the evaluation loop.
TOTAL_EPISODE = 10
# Discount factor applied to rewards in rollouts and in backpropagation.
GAMMA = 0.99
# Number of discrete actions assumed for the environment; actions are 0 .. NUM_ACTION - 1.
NUM_ACTION = 2
# Tree reuse: after the search for s0 the chosen child already holds a subtree for s1.
# True  -> continue searching from that subtree on the next step.
# False -> build a fresh tree for every step.
REUSE_TREE = False
# Search budget: set exactly one of TIMEOUT / SEARCH_STEP to a number and the other to None.
# If both are set, the time limit is the one that is applied.
# TIMEOUT: wall-clock budget in seconds for each MCTS call (None = no time limit).
TIMEOUT = None
# SEARCH_STEP: number of explore iterations for each MCTS call (None = use the time limit instead).
SEARCH_STEP = 20

# MCTS

## Node
The Node includes two main functions: init and get_score:
- init:
    - N: number of visits, initialized to 0
    - V: sum of returns, initialized to 0
    - children: None initially; children nodes will be added in the explore function
    - parent: the parent node, or None if this is the root
    - lead_action: the action that led to this node, or None if root
    - num_action: number of possible actions
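    - env: a deep copy of the parent node’s env (or of the main env if this is the root)
        - If lead_action is given: perform this action in the copied env to obtain the node’s state, reward and done status
        - Otherwise, set reward = 0 and done = False, and use the state passed to the constructor
    - available_actions: the set of actions that have not yet been expanded into child nodes; it initially contains all num_action actions (or is empty if the node is terminal), and an action is removed each time its child node is added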
- get_score: 
    - Calculate the score of the node using the formula: 
        - If the node has never been visited ($N = 0$), score = $10^9$, so that unvisited nodes are always selected first
        - Otherwise: $score = \frac{V}{N} + c \sqrt{\frac{2\ln(N_p)}{N}}$
    - Where $c = \frac{1}{\sqrt 2}$
    - $N_p$ is the visit count of the parent node if the node has a parent, or the node’s own visit count if it is the root

In [34]:
class Node():
    """One node of the search tree, owning its own snapshot of the environment.

    Attributes:
        N: number of times the node has been visited.
        V: sum of the returns backed up through the node.
        children: list of child nodes, or None while nothing has been expanded.
        parent: parent node, or None for a root.
        lead_action: action applied in the parent's env to reach this node
            (None for a root).
        num_action: number of actions available from a non-terminal node.
        env: deep copy of ``base_env``; stepped once with ``lead_action`` when
            one is given.
        state, reward, done: outcome of that step, or ``state`` as passed in,
            0.0 and False when no ``lead_action`` is given.
        available_actions: actions that have not been expanded yet; empty for
            a terminal node.
    """

    def __init__(self, num_action, base_env, parent = None, lead_action = None, state = None):
        # Visit statistics start empty; they are filled in by backpropagation.
        self.N = 0
        self.V = 0.

        # Position in the tree.
        self.children = None
        self.parent = parent
        self.lead_action = lead_action
        self.num_action = num_action

        # Work on a private copy so the caller's environment is never advanced.
        self.env = copy.deepcopy(base_env)
        if lead_action is None:
            # Root-like node: nothing to replay, take the state as given.
            self.state = state
            self.reward = 0.
            self.done = False
        else:
            transition = self.env.step(lead_action)
            self.state, self.reward, terminated, truncated, _ = transition
            self.done = terminated or truncated

        # A finished episode offers nothing further to expand.
        self.available_actions = set() if self.done else set(range(num_action))

    def get_score(self):
        """Return the selection score: mean return plus an exploration bonus.

        An unvisited node scores 1e9 so it is preferred over any visited one.
        The bonus uses the parent's visit count, or the node's own count when
        it has no parent.
        """
        if self.N == 0:
            return 1e9

        reference = self.parent if self.parent is not None else self

        exploration_weight = 1. / np.sqrt(2)
        mean_return = self.V / self.N
        bonus = np.sqrt(2 * np.log(reference.N) / self.N)
        return mean_return + exploration_weight * bonus

## MCTS functions

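**simulate**: Perform a random play (rollout) from the given node until the episode ends, and return the discounted return of that rollout (rewards discounted by GAMMA). A terminal node returns 0.
- The rollout runs on a deep copy of the node's environment, so the node's own environment is left untouched.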

**explore**:
- Select a leaf node starting from the root by repeating:
    - Use the score function to select the best child node.
    - Move to that child node.
    - Stop when reaching a terminal node or a node that is not fully expanded.
- If the current node is not fully expanded:
    - Randomly select an action from the available actions.
    - Create a child node for this action.
    - Move to this child node.
- Perform simulation from this node to get a return value.
- Backpropagate the return value up the tree.

**mcts**: 
- Perform multiple explore steps.
- Choose as the best action the one leading to the root's child with the highest visit count (N); that child is also returned so its subtree can be reused for the next step.

In [35]:
def simulate(node):
    """Run one uniformly random rollout from ``node`` and return its discounted return.

    The rollout is played on a deep copy of ``node.env`` so the node's own
    environment is not advanced. A node that is already done yields 0.0.
    """
    rollout_env = copy.deepcopy(node.env)
    collected = []
    finished = node.done
    while not finished:
        random_action = np.random.choice(NUM_ACTION)
        _, step_reward, terminated, truncated, _ = rollout_env.step(random_action)
        collected.append(step_reward)
        finished = terminated or truncated

    # Fold the rewards from the last step backwards: G_t = r_t + GAMMA * G_{t+1}.
    discounted_return = 0.
    for step_reward in collected[::-1]:
        discounted_return = discounted_return * GAMMA + step_reward

    rollout_env.close()
    return discounted_return

def explore(explore_node):
    """Run one MCTS iteration (select, expand, simulate, back up) from ``explore_node``.

    The tree is modified in place: at most one new child is added and the
    N / V statistics of every node on the visited path are updated.
    """
    # --- Selection: descend while the node is fully expanded and not terminal.
    node = explore_node
    while node.children and not node.done:
        if len(node.children) < NUM_ACTION:
            break
        scores = [child.get_score() for child in node.children]
        top_score = max(scores)
        candidates = [child for child, score in zip(node.children, scores) if score == top_score]
        # Ties are broken uniformly at random.
        node = candidates[np.random.choice(len(candidates))]

    # --- Expansion: try one action that has no child yet, if any remain.
    if len(node.available_actions) > 0:
        chosen_action = np.random.choice(list(node.available_actions))
        node.available_actions.remove(chosen_action)
        fresh_child = Node(NUM_ACTION, node.env, node, chosen_action, None)
        if node.children is None:
            node.children = [fresh_child]
        else:
            node.children.append(fresh_child)
        node = fresh_child

    # --- Simulation: random rollout from the reached node.
    backed_up = simulate(node)

    # --- Backpropagation: each node records the return from its own state;
    # the reward for entering the node is added before moving to its parent.
    while node:
        node.N += 1
        node.V += backed_up
        backed_up = GAMMA * backed_up + node.reward
        node = node.parent

def check_mcts(timeout, start_time, search_step, current_step):
    """Return True while the current MCTS call still has search budget left.

    Args:
        timeout: time budget in seconds, or None to limit by step count.
        start_time: ``datetime`` at which the search started.
        search_step: maximum number of explore iterations; only consulted
            when ``timeout`` is None.
        current_step: number of explore iterations already performed.

    Returns:
        True if another explore iteration may run, False otherwise.

    Raises:
        ValueError: if both ``timeout`` and ``search_step`` are None, since
            the search would then have no stopping criterion.
    """
    # The time limit takes precedence whenever it is set.
    if timeout is not None:
        return datetime.now() - start_time < timedelta(seconds=timeout)
    if search_step is None:
        raise ValueError("Either timeout or search_step must be set (both are None)")
    return current_step < search_step

def mcts(root_state, root_env, timeout = 1, search_step = None, tree=None):
    """Search from the current state and return the chosen action and its subtree.

    Args:
        root_state: observation of the current state; stored on a newly
            created root node.
        root_env: environment positioned at the current state. It is only
            deep-copied by the tree nodes, never stepped here.
        timeout: time budget in seconds for this call, or None to limit the
            search by ``search_step`` instead.
        search_step: number of explore iterations, used when ``timeout`` is
            None.
        tree: node returned by a previous call; used as the root when
            REUSE_TREE is enabled, ignored otherwise.

    Returns:
        A tuple ``(action, child)``: the action of the most-visited root child
        (the first such child on ties) and that child detached from its parent
        so it can serve as the root of the next search.

    Raises:
        RuntimeError: if no action could be expanded from the root.
    """
    root_node = tree if (tree and REUSE_TREE) else Node(NUM_ACTION, root_env, state=root_state)
    start_time = datetime.now()
    step = 0
    while check_mcts(timeout, start_time, search_step, step):
        explore(root_node)
        step += 1

    # The budget may allow zero iterations (search_step == 0, or a timeout that
    # has already elapsed); expand once so there is a child to choose from.
    if not root_node.children:
        explore(root_node)
    if not root_node.children:
        raise RuntimeError("MCTS could not expand any action from the root node")

    # max() returns the first maximal element, i.e. the earliest-expanded child on ties.
    best_child = max(root_node.children, key=lambda child: child.N)
    best_action = best_child.lead_action

    # Turn the chosen child into a stand-alone root for possible reuse.
    best_child.parent = None
    best_child.lead_action = None
    best_child.reward = 0.
    best_child.done = False
    return best_action, best_child

# play and test episodes

Play TOTAL_EPISODE episodes and print the results. For each episode:
- Reset the environment.
- Repeat until the episode ends:
    - Use the mcts function to find the best action.
    - Perform this action in the environment.

In [None]:
# Play TOTAL_EPISODE episodes, choosing every action with MCTS, and collect
# per-episode reward, length and wall-clock runtime.
episode_rewards = []
episode_steps = []
episode_runtimes = []

for episode in range(TOTAL_EPISODE):
    start_time = datetime.now()
    env = gym.make(ENV_NAME, render_mode="rgb_array")
    try:
        state, info = env.reset()
        done = False
        episode_reward = 0
        episode_step = 0
        next_tree = None  # subtree handed back by mcts; only used when REUSE_TREE is on

        while not done:
            action, next_tree = mcts(state, env, TIMEOUT, SEARCH_STEP, next_tree)
            state, reward, terminated, truncated, info = env.step(action)
            episode_reward += reward
            episode_step += 1
            if episode_step % 100 == 0:
                # Progress report for long episodes.
                print(episode_step, episode_reward)
            done = terminated or truncated

        episode_rewards.append(episode_reward)
        episode_steps.append(episode_step)
        episode_runtimes.append(datetime.now() - start_time)
    finally:
        # A new environment is created for every episode, so release it here
        # even if the search or a step raises.
        env.close()
    print(episode_reward, episode_runtimes[-1], "\n")

# Summary statistics over all episodes.
episode_rewards = np.array(episode_rewards)
print(episode_rewards)
print(episode_rewards.max(), episode_rewards.min(), episode_rewards.mean(), episode_rewards.std())
print(max(episode_runtimes), min(episode_runtimes), np.mean(episode_runtimes))