Monte Carlo Tree Search (MCTS) is best known for adversarial planning in two-player zero-sum games, but it is not limited to that and can be applied to any decision-making process.

In the CartPole environment, we face a continuous state space and a discrete action space. MCTS requires discrete states; otherwise, we would have an infinite number of leaves. To address this, we will bin the state space to discretize it.

For simulating rollouts, we will use the gym environment itself. CartPole is simple enough that we can take a Python deepcopy and simulate the rollout using random or chosen actions. The original environment remains unchanged until a real, non-simulated action is applied to it.

### Discretize to 10 bins, then perform MCTS

In [16]:
import math
import random
import numpy as np
import gymnasium as gym
import copy
import matplotlib.pyplot as plt
from IPython.display import clear_output
from matplotlib import animation

class MCTSNode:
    def __init__(self, state, parent=None, action=None):
        self.state = state
        self.parent = parent
        self.children = []
        self.action = action
        self.visits = 0
        self.wins = 0

    def is_fully_expanded(self, action_space_size):
        return len(self.children) == action_space_size

    def best_child(self, exploration_param=1.414):
        choices_weights = [
            (child.wins / max(1, child.visits)) + exploration_param * math.sqrt(
                math.log(max(1, self.visits)) / max(1, child.visits)
            ) for child in self.children
        ]
        return self.children[np.argmax(choices_weights)]

    def most_visited_child(self):
        return max(self.children, key=lambda child: child.visits)

def discretize_state(state, bins):
    """
    Discretize the continuous state into a discrete bin representation.
    We create bins for each feature (position, velocity, angle, angular velocity).
    """
    binned_state = []
    for i in range(len(state)):
#         print('state[i]', state[i])
#         print('bin[i]', bins[i])
        binned_state.append(np.digitize(state[i], bins[i]))
    return tuple(binned_state)

def create_bins():
    """
    Create bins for each state variable.
    CartPole has 4 continuous state variables:
    [Cart Position, Cart Velocity, Pole Angle, Pole Angular Velocity].
    We define the ranges and bins for each.
    """
    bins = [
        np.linspace(-4.8, 4.8, 10),  # Cart Position
        np.linspace(-5.0, 5.0, 10),  # Cart Velocity
        np.linspace(-0.418, 0.418, 10),  # Pole Angle
        np.linspace(-5.0, 5.0, 10),  # Pole Angular Velocity
    ]
    return bins

def rollout(env, bins):
    """
    Simulate a random rollout from the current state.
    """
    total_reward = 0
    done = False
    while not done:
        # Randomly pick an action (exploration)
        action = env.action_space.sample()
        state, reward, done, _, _ = env.step(action)
        total_reward += reward
        if done:
            break
    return total_reward

def backpropagate(node, reward):
    while node is not None:
        node.visits += 1
        node.wins += reward
        node = node.parent

def expand(node, env, action_space_size, bins):
    """
    Expand the node by creating child nodes for each action.
    """
    for action in range(action_space_size):
        # Clone the environment and step with the current action
        env_copy = gym.make('CartPole-v1')
        env_copy.reset()
        env_copy.env.state = env.env.state  # Copy the environment state
        state, reward, done, _, _ = env_copy.step(action)

        # Discretize the state to get the discrete representation
        discrete_state = discretize_state(state, bins)

        child_node = MCTSNode(state=discrete_state, parent=node, action=action)
        node.children.append(child_node)

def select(node, action_space_size):
    """
    Traverse the tree by selecting the best child node based on the UCB1 algorithm.
    """
    current_node = node
    while current_node.is_fully_expanded(action_space_size):
        current_node = current_node.best_child()
    return current_node

def mcts(env, state, current_node, simulations=1000):
    env_copy = copy.deepcopy(env)
    action_space_size = env.action_space.n

    for _ in range(simulations):
        # Step 1: Selection
        selected_node = select(current_node, action_space_size)
#         print('sn',selected_node)
        
        # Step 2: Expansion
        expand(selected_node, env_copy, action_space_size, bins)
        
        # Step 3: Simulation
        reward = rollout(env_copy, bins)
        if reward !=0:
            print('r', reward)
        
        # Step 4: Backpropagation
        backpropagate(selected_node, reward)
        
    best_node = current_node.best_child(exploration_param=0)
    print('children', current_node.children)
    for c in current_node.children:
        print(c.wins)
    print('mv', current_node.most_visited_child().wins)
    print('bst', current_node.best_child().wins)
    print('curr', current_node.wins)
    return best_node

# Initialize the CartPole environment
env = gym.make("CartPole-v1")

# Run MCTS for 1000 simulations


# Apply the best move (select action based on the best child)
state = env.reset()[0]
print(state)
done = False
tot_r = 0

# Discretize the initial state
bins = create_bins()
discrete_state = discretize_state(state, bins)
current_node = MCTSNode(discrete_state) # first will be root node

while not done:
    current_node = mcts(env, state, current_node, simulations=1000)
    print('action', current_node.action)
    state, reward, done, _, _ = env.step(current_node.action)
    tot_r += reward
    #     print(reward)
    env.render()

print('done')
print(tot_r)
env.close()


[-0.02977553  0.02695098 -0.03981876  0.03226086]
r 38.0
children [<__main__.MCTSNode object at 0x7f031fe1fdf0>, <__main__.MCTSNode object at 0x7f031fe1f9a0>]
0.0
0.0
mv 0.0
bst 0.0
curr 38.0
action 0
r 33.0
children [<__main__.MCTSNode object at 0x7f031fe1f5b0>, <__main__.MCTSNode object at 0x7f031fe1f130>]
0.0
33.0
mv 33.0
bst 33.0
curr 33.0
action 1
r 19.0
children [<__main__.MCTSNode object at 0x7f031fe1faf0>, <__main__.MCTSNode object at 0x7f031fe1f160>]
52.0
0.0
mv 52.0
bst 0.0
curr 52.0
action 0
r 26.0
children [<__main__.MCTSNode object at 0x7f03327cc550>, <__main__.MCTSNode object at 0x7f031fe1f9d0>]
26.0
52.0
mv 52.0
bst 26.0
curr 78.0
action 1
r 18.0
r 4.0
children [<__main__.MCTSNode object at 0x7f031fb36b50>, <__main__.MCTSNode object at 0x7f031fb36d00>]
55.0
19.0
mv 55.0
bst 19.0
curr 74.0
action 0
r 21.0
children [<__main__.MCTSNode object at 0x7f031fb36f10>, <__main__.MCTSNode object at 0x7f031fb36ca0>]
0.0
76.0
mv 76.0
bst 0.0
curr 76.0
action 1
r 33.0
r 8.0
children [

### This code block we try 100 bins

In [17]:
#100 bins

import math
import random
import numpy as np
import gymnasium as gym
import copy
import matplotlib.pyplot as plt
from IPython.display import clear_output
from matplotlib import animation

class MCTSNode:
    def __init__(self, state, parent=None, action=None):
        self.state = state
        self.parent = parent
        self.children = []
        self.action = action
        self.visits = 0
        self.wins = 0

    def is_fully_expanded(self, action_space_size):
        return len(self.children) == action_space_size

    def best_child(self, exploration_param=1.414):
        choices_weights = [
            (child.wins / max(1, child.visits)) + exploration_param * math.sqrt(
                math.log(max(1, self.visits)) / max(1, child.visits)
            ) for child in self.children
        ]
        return self.children[np.argmax(choices_weights)]

    def most_visited_child(self):
        return max(self.children, key=lambda child: child.visits)

def discretize_state(state, bins):
    """
    Discretize the continuous state into a discrete bin representation.
    We create bins for each feature (position, velocity, angle, angular velocity).
    """
    binned_state = []
    for i in range(len(state)):
#         print('state[i]', state[i])
#         print('bin[i]', bins[i])
        binned_state.append(np.digitize(state[i], bins[i]))
    return tuple(binned_state)

def create_bins():
    """
    Create bins for each state variable.
    CartPole has 4 continuous state variables:
    [Cart Position, Cart Velocity, Pole Angle, Pole Angular Velocity].
    We define the ranges and bins for each.
    """
    bins = [
        np.linspace(-4.8, 4.8, 100),  # Cart Position
        np.linspace(-5.0, 5.0, 100),  # Cart Velocity
        np.linspace(-0.418, 0.418, 100),  # Pole Angle
        np.linspace(-5.0, 5.0, 100),  # Pole Angular Velocity
    ]
    return bins

def rollout(env, bins):
    """
    Simulate a random rollout from the current state.
    """
    total_reward = 0
    done = False
    while not done:
        # Randomly pick an action (exploration)
        action = env.action_space.sample()
        state, reward, done, _, _ = env.step(action)
        total_reward += reward
        if done:
            break
    return total_reward

def backpropagate(node, reward):
    while node is not None:
        node.visits += 1
        node.wins += reward
        node = node.parent

def expand(node, env, action_space_size, bins):
    """
    Expand the node by creating child nodes for each action.
    """
    for action in range(action_space_size):
        # Clone the environment and step with the current action
        env_copy = gym.make('CartPole-v1')
        env_copy.reset()
        env_copy.env.state = env.env.state  # Copy the environment state
        state, reward, done, _, _ = env_copy.step(action)

        # Discretize the state to get the discrete representation
        discrete_state = discretize_state(state, bins)

        child_node = MCTSNode(state=discrete_state, parent=node, action=action)
        node.children.append(child_node)

def select(node, action_space_size):
    """
    Traverse the tree by selecting the best child node based on the UCB1 algorithm.
    """
    current_node = node
    while current_node.is_fully_expanded(action_space_size):
        current_node = current_node.best_child()
    return current_node

def mcts(env, state, current_node, simulations=1000):
    env_copy = copy.deepcopy(env)
    action_space_size = env.action_space.n

    for _ in range(simulations):
        # Step 1: Selection
        selected_node = select(current_node, action_space_size)
#         print('sn',selected_node)
        
        # Step 2: Expansion
        expand(selected_node, env_copy, action_space_size, bins)
        
        # Step 3: Simulation
        reward = rollout(env_copy, bins)
        
        # Step 4: Backpropagation
        backpropagate(selected_node, reward)
        
    best_node = current_node.best_child(exploration_param=0)

    return best_node

# Initialize the CartPole environment
env = gym.make("CartPole-v1")

# Run MCTS for 1000 simulations


# Apply the best move (select action based on the best child)
state = env.reset()[0]
done = False
tot_r = 0

# Discretize the initial state
bins = create_bins()
discrete_state = discretize_state(state, bins)
current_node = MCTSNode(discrete_state) # first will be root node

while not done:
    current_node = mcts(env, state, current_node, simulations=1000)
    print('action', current_node.action)
    state, reward, done, _, _ = env.step(current_node.action)
    tot_r += reward
    env.render()

print('done')
print(tot_r)
env.close()


action 0
action 1
action 1
action 0
action 1
action 0
action 1
action 1
action 1
action 1
action 1
action 1
action 1
action 1
action 0
done
15.0


we learn that 10 and 100 bins are somewhat similar, so it doens't really matter.