# Import package

In [1]:
import copy
import gymnasium as gym
import numpy as np
from datetime import datetime, timedelta

# Config

In [6]:
class Config:
    """Settings for the environment, the episode loop and the MCTS search.

    The search budget is given either as a wall-clock limit (``timeout``)
    or as a number of search iterations (``search_step``). Set exactly one
    of the two to a number and leave the other one as ``None``.
    """

    def __init__(self):
        # --- environment / evaluation ---
        self.env_name = "CartPole-v1"
        self.total_episode = 10
        # Discount factor applied to rollout and backed-up returns.
        self.gamma = 0.997
        self.num_action = 2

        # --- search behaviour ---
        # After the search for s0, the chosen child's subtree already covers
        # s1, s2, ...  True keeps that subtree as the next root; False builds
        # a fresh tree for every real step.
        self.reuse_tree = False
        # Time budget of one MCTS call; None disables the time limit.
        self.timeout = None
        # Iteration budget of one MCTS call; None means "use the time limit".
        self.search_step = 10


config = Config()

# MCTS

## Node
The Node class has three main functions: init, get_score and is_explore:
- init:
    - N: number of visits, initialized to 0
    - V: sum of returns, initialized to 0
    - children: None initially; children nodes will be added in the explore function
    - parent: the parent node, or None if this is the root
    - lead_action: the action that led to this node, or None if root
    - num_action: number of possible actions
    - env: a copy of the environment from the parent node’s env (or copy the main env if root)
    - If lead_action exists: perform this action to get reward and done status
    - Otherwise, set reward = 0 and done = False
    - available_actions: the set of legal actions (equal to num_action), from which actions will be removed when children nodes are added
- get_score:
    - Calculate the score of the node using the formula:
        - If the node has not been visited yet (N = 0), score = 1e9, so unvisited nodes are always tried first
        - Otherwise: $score = \frac{V}{N} + c \sqrt{\frac{2\ln(N_p)}{N}}$
    - Where $c = \frac{1}{\sqrt 2}$
    - $N_p$ is the visit count of the parent node if the node has a parent, or the node’s own visit count if it is the root
- is_explore: return a truthy value if the node has been expanded and can still be searched further:
    - It has children (children is not None and not empty)
    - It is not a terminal state (done = False)

In [7]:
class Node():
    """One node of the search tree, owning a private copy of the environment.

    Attributes:
        N: number of visits, starts at 0.
        V: running sum of the returns added to this node, starts at 0.
        children: ``None`` until the node is expanded, then a list of nodes.
        parent: the parent node, or ``None`` for a root.
        lead_action: the action that led to this node, or ``None`` for a root.
        num_action: number of possible actions.
        env: deep copy of ``base_env``, already advanced by ``lead_action``
            when one is given.
        state, reward, done: result of that step; a root (no ``lead_action``)
            gets the ``state`` argument, ``reward = 0`` and ``done = False``.
        available_actions: all action indices, or an empty set when done.
    """

    def __init__(self, num_action, base_env, parent = None, lead_action = None, state = None):
        self.N = 0
        self.V = 0.

        self.children = None
        self.parent = parent
        self.lead_action = lead_action
        self.num_action = num_action

        # Work on a private copy so stepping here never disturbs base_env.
        self.env = copy.deepcopy(base_env)
        if lead_action is not None:
            self.state, self.reward, terminated, truncated, _ = self.env.step(lead_action)
            self.done = terminated or truncated
        else:
            self.done = False
            self.reward = 0.
            self.state = state
        self.available_actions = set() if self.done else set(range(num_action))

    def get_score(self):
        """Return the selection score ``V/N + c * sqrt(2 * ln(N_p) / N)``.

        ``c`` is ``1/sqrt(2)`` and ``N_p`` is the parent's visit count, or
        this node's own count when it has no parent. A node that has never
        been visited scores ``1e9`` so that it is picked before any visited
        sibling.
        """
        if self.N == 0:
            return 1e9

        top_node = self
        if top_node.parent is not None:
            top_node = top_node.parent

        c = 1. / np.sqrt(2)
        V = self.V / self.N
        return V + c * np.sqrt(2 * np.log(top_node.N) / self.N)

    def is_explore(self):
        """Return True when the node has children and is not terminal.

        Always returns a real ``bool`` (never ``None`` or a list), so the
        result can safely be compared, stored or printed.
        """
        return bool(self.children) and not self.done

## MCTS functions

**simulate**: Perform a random play (rollout) for one episode and return the total reward of that episode.
- Copy the environment state from the given node before simulation.

**create_child**: create all children for the current node.

**backpropagate**: backpropagate the return value up the tree.

**find_explore_node**: Select a leaf node starting from the root by repeating:
- Use the score function to select the best child node.
- Move to that child node.
- Stop when reaching a terminal node or an unexplored node (**node.is_explore**).

**explore**:
- Select a leaf node (**find_explore_node**):
- If stop at non terminal node:
    - create all children for this node (**create_child**).
- Perform simulation from this node to get a return value (**simulate**).
- Backpropagate the return value up the tree (**backpropagate**).

**select_action**: Choose the best action as the child node with the highest visit count (N)

**mcts**:
- Perform multiple explore steps.
- Choose the best action as the child node with the highest visit count (N) (**select_action**).

In [8]:
def simulate(node):
    """Play one uniformly random rollout from ``node`` and return its return.

    The rollout runs on a deep copy of ``node.env``, so the node's own
    environment is left untouched. Rewards are discounted with
    ``config.gamma``; the first reward of the rollout is not discounted.

    Returns:
        The discounted sum of rollout rewards, or ``0.`` for a terminal node.
    """
    if node.done:
        # Nothing to roll out: skip the (expensive) environment copy.
        return 0.

    env = copy.deepcopy(node.env)
    rewards = []
    try:
        done = False
        while not done:
            action = np.random.choice(config.num_action)
            _, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            rewards.append(reward)
    finally:
        # Release the copy even if a step raises.
        env.close()

    # Fold from the last reward backwards: G = r_0 + gamma * (r_1 + gamma * ...).
    G = 0.
    for reward in reversed(rewards):
        G = G * config.gamma + reward
    return G

def create_child(node):
    """Expand ``node`` with one child per action and return ``node``.

    Every child is built from ``node.env`` with ``node`` as its parent and
    the action index (``0 .. config.num_action - 1``) as its lead action.
    """
    node.children = [
        Node(config.num_action, node.env, node, action, None)
        for action in range(config.num_action)
    ]
    return node

def backpropagate(node, G):
    """Push the return ``G`` from ``node`` up through every ancestor.

    Each visited node gets one more visit and ``G`` added to its value sum.
    Before moving to the parent, the return is discounted by
    ``config.gamma`` and the node's own reward is added.
    """
    current, ret = node, G
    while current:
        current.V += ret
        current.N += 1
        # Return as seen from the parent: this node's reward plus the
        # discounted return from here on.
        ret = current.reward + config.gamma * ret
        current = current.parent

def find_explore_node(explore_node):
    """Descend from ``explore_node`` to the node that should be explored next.

    At each level the child with the highest ``get_score()`` is followed;
    ties are broken uniformly at random. The walk stops at the first node
    that has no children or is terminal, and that node is returned.
    """
    node = explore_node
    while node.children and not node.done:
        scored = [(child, child.get_score()) for child in node.children]
        top_score = max(score for _, score in scored)
        candidates = [child for child, score in scored if score == top_score]
        node = candidates[np.random.choice(len(candidates))]
    return node

def explore(explore_node):
    """Run one search iteration: select, expand, simulate, backpropagate.

    The selected node is expanded only if it has never been expanded and is
    not terminal. The rollout starts from the selected node itself, and its
    return is backed up from that node to the root.
    """
    leaf = find_explore_node(explore_node)

    needs_expansion = leaf.children is None and not leaf.done
    if needs_expansion:
        leaf = create_child(leaf)

    backpropagate(leaf, simulate(leaf))

def check_mcts(timeout, start_time, search_step, current_step):
    """Tell whether the search loop may run another iteration.

    Args:
        timeout: time budget in seconds, or ``None`` to limit by step count.
        start_time: ``datetime`` at which the search started (only read
            when ``timeout`` is set).
        search_step: maximum number of iterations (only read when
            ``timeout`` is ``None``).
        current_step: number of iterations already done.

    Returns:
        True while the active budget is not used up.

    Raises:
        ValueError: if both ``timeout`` and ``search_step`` are ``None``,
            i.e. no budget was configured at all.
    """
    if timeout is not None:
        # A time limit, when given, takes precedence over the step limit.
        return datetime.now() - start_time < timedelta(seconds=timeout)
    if search_step is None:
        raise ValueError("MCTS needs a budget: set either timeout or search_step")
    return current_step < search_step

def select_action(root_node):
    """Pick the most visited child of ``root_node`` and detach it as a new root.

    Ties go to the earliest child in ``root_node.children``. The chosen
    child is turned into a root in place: its parent and lead action are
    cleared, its reward is reset to 0 and its done flag to False.

    Returns:
        ``(action, child)``: the action that led to the chosen child (read
        before it is cleared) and the detached child node.
    """
    # max() returns the first maximal element, so ties resolve to the
    # earliest child.
    winner = max(root_node.children, key=lambda child: child.N)
    action = winner.lead_action

    winner.parent = None
    winner.lead_action = None
    winner.reward = 0.
    winner.done = False
    return action, winner

def mcts(root_state, root_env, timeout = 1, search_step = None, tree=None):
    """Run a Monte Carlo tree search and return the chosen action.

    Args:
        root_state: current observation; stored on a newly created root node.
        root_env: environment to search from; the root node works on a copy.
        timeout: time budget in seconds, or ``None`` to use ``search_step``.
        search_step: iteration budget used when ``timeout`` is ``None``.
        tree: subtree returned by a previous call. It is reused as the root
            only when ``config.reuse_tree`` is true; otherwise a new root is
            built.

    Returns:
        ``(action, subtree)`` as produced by ``select_action``: the action of
        the most visited root child and that child detached as a new root.
    """
    if tree is not None and config.reuse_tree:
        root_node = tree
    else:
        root_node = Node(config.num_action, root_env, None, None, root_state)

    start_time = datetime.now()
    step = 0
    while check_mcts(timeout, start_time, search_step, step):
        explore(root_node)
        step += 1

    if root_node.children is None:
        # The budget allowed no iteration at all: expand the root once so
        # there are children to choose an action from.
        explore(root_node)
    return select_action(root_node)

# play and test episodes

Play config.total_episode episodes and print the results. For each episode:
- Reset the environment.
- Repeat until the episode ends:
    - Use the mcts function to find the best action.
    - Perform this action in the environment.

In [None]:
# Play config.total_episode episodes with MCTS choosing every action, then
# print per-episode and aggregate reward / runtime statistics.
episode_rewards = []
episode_steps = []
episode_runtimes = []

for episode in range(config.total_episode):
    start_time = datetime.now()
    env = gym.make(config.env_name, render_mode="rgb_array")
    try:
        state, info = env.reset()
        done = False
        episode_reward = 0
        episode_step = 0
        # Subtree handed back by the previous search (None on the first step).
        next_tree = None

        while not done:
            action, next_tree = mcts(state, env, config.timeout, config.search_step, next_tree)
            state, reward, terminated, truncated, info = env.step(action)
            episode_reward += reward
            episode_step += 1
            if episode_step % 100 == 0:
                # Progress line every 100 steps.
                print(episode_step, episode_reward)
            done = terminated or truncated

        episode_rewards.append(episode_reward)
        episode_steps.append(episode_step)
        episode_runtimes.append(datetime.now() - start_time)
    finally:
        # A new environment is created per episode; release it every time.
        env.close()
    print(episode_reward, episode_runtimes[-1], "\n")

episode_rewards = np.array(episode_rewards)
print(episode_rewards)
print(episode_rewards.max(), episode_rewards.min(), episode_rewards.mean(), episode_rewards.std())
print(max(episode_runtimes), min(episode_runtimes), np.mean(episode_runtimes))