In [None]:
import copy
import gymnasium as gym
import numpy as np
from datetime import datetime, timedelta
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class Config():
    """Single place for every tunable used by the environment, the tree search and training."""

    def __init__(self):
        # ---- environment ----
        self.env_name = "CartPole-v1"
        self.total_episode = 1000
        self.gamma = 0.997
        self.num_action = 2
        self.state_dim = 4

        # ---- MCTS ----
        # reuse_tree: when True, the subtree below the executed action is handed back to
        # the next search as its root; when False a fresh tree is built for every step.
        self.reuse_tree = False
        # Search budget: set exactly one of `timeout` / `search_step` to a number and
        # leave the other as None.
        #   timeout     -> each search runs until this many seconds have elapsed
        #   search_step -> each search runs this many simulations
        self.timeout = None
        self.search_step = 50
        self.max_moves = 512
        self.num_sampling_moves = 501

        # Exploration noise mixed into the priors of the root's children.
        self.root_dirichlet_alpha = 0.3
        self.root_exploration_fraction = 0.25

        # Constants of the UCB exploration term.
        self.pb_c_base = 19652
        self.pb_c_init = 1.25

        # ---- training schedule ----
        self.training_steps = 10
        self.testing_steps = 10
        self.total_test_episode = 10
        self.start_train_from_episode = 10
        self.buffer_size = 256
        self.save_step = 100
        self.epochs = 100
        # Value loss: "CE" (categorical over a support) or "MSE" (scalar regression).
        self.V_loss_type = "CE"

        # ---- model / optimiser ----
        self.learning_rate = 5e-3
        self.use_scheduler = False
        self.final_lr = 5e-4
        self.total_decay_step = 500
        self.weight_decay  = 1e-4
        self.batchsize = 256
        if torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"
        # The categorical value support runs from 0 to max_support_value in
        # increments of support_step.
        self.max_support_value = 500
        self.support_step = self.max_support_value // 50
        self.save_path = "best_CE.pth"

config = Config()

In [None]:
class MinMaxScale():
    """Running min/max tracker used to squash node values into [0, 1].

    The bounds start at +inf / -inf so that only values that were actually
    observed through `update` influence the scaling (hard-coded sentinels would
    act as phantom observations for values outside their range).
    """

    def __init__(self):
        self.min_G = float("inf")
        self.max_G = float("-inf")

    def update(self, G):
        """Widen the tracked range so that it contains G."""
        self.min_G = min(self.min_G, G)
        self.max_G = max(self.max_G, G)

    def scale(self, G):
        """Return G mapped linearly onto the tracked range.

        While fewer than two distinct values have been seen the range is
        degenerate and G is returned unchanged.
        """
        if self.min_G < self.max_G:
            return (G - self.min_G) / (self.max_G - self.min_G)
        return G

In [None]:
def scalar_to_support(value, support_values):
    """Project scalar targets onto a categorical support ("two-hot" encoding).

    value: [batch_size] tensor of scalars
    support_values: 1D increasing tensor like [0, 50, ..., 500] of shape [N]
    return: [batch_size, N] distribution; each row puts its mass on the two
            support points that bracket the (clamped) value, weighted by
            linear interpolation, so the row's expectation equals the value.

    All arithmetic is done in a single floating dtype (the support's dtype, or
    float32 for an integer support) so that float64 inputs do not trip the
    dtype check of `scatter_` / `scatter_add_`.
    """
    dtype = support_values.dtype if support_values.is_floating_point() else torch.float32
    support_values = support_values.to(dtype)
    value = value.to(dtype)

    batch_size = value.size(0)
    device = value.device
    N = support_values.size(0)

    # Clamp values to be within support range
    min_v, max_v = support_values[0], support_values[-1]
    value = torch.clamp(value, min_v.item(), max_v.item())

    # Find which two support points the value lies between
    diff = support_values.unsqueeze(0) - value.unsqueeze(1)  # [batch_size, N]
    mask = diff <= 0
    idx_low = mask.sum(dim=1) - 1  # Index of largest support_value <= value
    idx_high = idx_low + 1

    # At the top of the support idx_high would run off the end; clamping makes
    # both indices coincide, and the interpolation below then assigns all mass
    # to that single atom.
    idx_low = idx_low.clamp(0, N - 1)
    idx_high = idx_high.clamp(0, N - 1)

    v_low = support_values[idx_low]
    v_high = support_values[idx_high]

    # Linear interpolation (the epsilon guards the idx_low == idx_high case)
    weight_high = (value - v_low) / (v_high - v_low + 1e-8)
    weight_low = 1.0 - weight_high

    support = torch.zeros((batch_size, N), dtype=dtype, device=device)
    support.scatter_(1, idx_low.unsqueeze(1), weight_low.unsqueeze(1))
    support.scatter_add_(1, idx_high.unsqueeze(1), weight_high.unsqueeze(1))

    return support

In [None]:
def support_to_scalar(probs, support_values):
    """Collapse a distribution over the support into its expected value.

    probs: [batch_size, N] weights over the support atoms
    support_values: [N] atom values
    return: [batch_size] weighted sum of the atoms for each row
    """
    atoms = support_values.unsqueeze(0)  # [1, N], broadcast over the batch
    weighted = probs * atoms
    return weighted.sum(dim=1)

In [None]:
class Policy_Model(nn.Module):
    """MLP with two 64-unit ReLU layers that maps a state to raw action logits."""

    def __init__(self, state_dim=4, action_dim=2):
        super().__init__()
        hidden_units = 64
        self.fc1 = nn.Linear(state_dim, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)

        # Output layer: one unnormalised logit per action.
        self.policy_head = nn.Linear(hidden_units, action_dim)

    def forward(self, state):
        """Return the logits for `state`; no softmax is applied here."""
        features = F.relu(self.fc2(F.relu(self.fc1(state))))
        return self.policy_head(features)

class Value_Model(nn.Module):
    """MLP value network with either a categorical ("CE") or a scalar head.

    state_dim: size of the state vector
    support_values: 1D tensor of support atoms; required when loss_type is "CE",
                    ignored otherwise
    loss_type: "CE" -> the head emits one logit per support atom and the scalar
                       value is the expectation of the softmax over the atoms;
               anything else -> the head emits the scalar value directly
    """

    def __init__(self, state_dim=4, support_values=2, loss_type = "CE"):
        super(Value_Model, self).__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 64)

        self.loss_type = loss_type
        # Value head
        if self.loss_type == "CE":
            if not isinstance(support_values, torch.Tensor):
                raise TypeError(
                    "support_values must be a 1D torch.Tensor of support atoms when loss_type is 'CE'"
                )
            self.value_head = nn.Linear(64, len(support_values))
            # Registered as a buffer so it follows the module across .to(device);
            # persistent=False keeps it out of state_dict, so saved checkpoints
            # keep exactly the same keys as before.
            self.register_buffer("support_values", support_values, persistent=False)
        else:
            self.value_head = nn.Linear(64, 1)

    def forward(self, state):
        """Return (value, value_logit).

        value: [batch, 1] scalar estimate (detached from the graph in "CE" mode)
        value_logit: raw head output, [batch, N] in "CE" mode, otherwise the same
                     tensor object as `value`
        """
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        value_logit = self.value_head(x)
        if self.loss_type == "CE":
            # The scalar is only used for search and logging, so no gradient is
            # tracked through it; training uses value_logit.
            with torch.no_grad():
                value = support_to_scalar(F.softmax(value_logit, -1), self.support_values.to(value_logit.device)).unsqueeze(1)
        else:
            value = value_logit
        return value, value_logit

class Model(nn.Module):
    """Bundles a policy network and a value network that share no parameters."""

    def __init__(self, state_dim=4, action_dim=2, support_values = None, loss_type = "CE"):
        super().__init__()

        # Sub-network producing action logits.
        self.policy_model = Policy_Model(state_dim=state_dim, action_dim=action_dim)

        # Sub-network producing (value, value_logit).
        self.value_model = Value_Model(state_dim=state_dim, support_values=support_values, loss_type=loss_type)

    def forward(self, state):
        """Return (policy_logits, (value, value_logit)) for `state`."""
        policy_logits = self.policy_model(state)
        value_outputs = self.value_model(state)
        return policy_logits, value_outputs

In [None]:
class Episode():
    """Column-wise record of one played episode (one list per recorded field)."""

    def __init__(self):
        self.states, self.actions, self.rewards = [], [], []
        self.dones, self.values, self.probs = [], [], []
        # Filled by save_total_rewards once the episode is over.
        self.total_rewards = []

    def save_step_data(self, state, action, reward, done, value, probs):
        """Append one step's fields to their respective lists."""
        columns = (
            (self.states, state),
            (self.actions, action),
            (self.rewards, reward),
            (self.dones, done),
            (self.values, value),
            (self.probs, probs),
        )
        for column, item in columns:
            column.append(item)

    def save_total_rewards(self):
        """Append the discounted return of every step to `total_rewards`.

        Returns are accumulated from the last reward backwards with discount
        config.gamma, then the whole list is flipped into step order.
        """
        running_return = 0.
        for step_reward in reversed(self.rewards):
            running_return = running_return * config.gamma + step_reward
            self.total_rewards.append(running_return)
        self.total_rewards.reverse()

In [None]:
class Replay_Buffer():
    """Ring buffer holding at most `buffer_size` episodes; the oldest slot is overwritten when full."""

    def __init__(self, buffer_size):
        self.buffer_size = buffer_size
        self.episodes = []
        self.index = 0      # slot the next episode will be written to once full
        self.real_size = 0  # number of episodes currently stored

    def add(self, episode):
        """Store `episode`, replacing the entry at the write cursor if the buffer is full."""
        is_full = self.real_size >= self.buffer_size
        if is_full:
            self.episodes[self.index] = episode
        else:
            self.episodes.append(episode)
        self.real_size = min(self.real_size + 1, self.buffer_size)
        self.index = (self.index + 1) % self.buffer_size

    def sample(self, batchsize):
        """Draw `batchsize` training positions.

        Episodes are drawn uniformly with replacement, then one step is drawn
        uniformly inside each chosen episode.
        Returns (states, target_probs, target_values) as three parallel lists.
        """
        chosen_episodes = np.random.choice(self.real_size, batchsize, replace = True)
        states, target_probs, target_values = [], [], []

        for episode_idx in chosen_episodes:
            picked = self.episodes[episode_idx]
            step = np.random.choice(len(picked.states), 1, replace = False)[0]
            states.append(picked.states[step])
            target_values.append(picked.values[step])
            target_probs.append(picked.probs[step])

        return states, target_probs, target_values

In [None]:
class Node():
    """A search-tree node that owns its own deep copy of the environment.

    When `lead_action` is given, the copied environment is stepped once with it
    and the node takes its state / reward / done flag from that step. Without a
    `lead_action` the node is a root: it keeps the supplied `state`, has zero
    reward and is not done.
    """

    def __init__(self, prior, num_action, base_env, parent = None, lead_action = None, state = None):
        # Position in the tree.
        self.parent = parent
        self.children = None
        self.lead_action = lead_action
        self.num_action = num_action

        # Search statistics: visit count and accumulated backed-up value.
        self.N = 0
        self.V = 0.
        self.prior = prior

        # Slots for cached predictions; only initialised here.
        self.predict_V = 0.
        self.predict_P = None

        # Private simulator copy so stepping never disturbs the caller's env.
        self.env = copy.deepcopy(base_env)
        if lead_action is None:
            self.state = state
            self.reward = 0.
            self.done = False
        else:
            self.state, self.reward, terminated, truncated, info = self.env.step(lead_action)
            self.done = terminated or truncated

        if self.done:
            self.available_actions = set()
        else:
            self.available_actions = set(range(num_action))

    def ucb_score(self, min_max_scale):
        """Score used by the parent to pick a child: exploration bonus plus scaled mean value.

        Unvisited nodes get a very large constant so they are tried first.
        """
        if self.N == 0:
            return 1e9

        parent_visits = self.parent.N
        pb_c = np.log((parent_visits + config.pb_c_base + 1) / config.pb_c_base) + config.pb_c_init
        pb_c = pb_c * (np.sqrt(parent_visits) / (self.N + 1))

        exploration = pb_c * self.prior
        exploitation = min_max_scale.scale(self.V / self.N)
        return exploration + exploitation

    def is_explore(self):
        """Truthy when the node has children and is not done (mirrors `children and not done`)."""
        if not self.children:
            return self.children
        return not self.done

In [None]:
def add_exploration_noise(node):
    """Blend Dirichlet noise into the priors of `node`'s children (in place).

    Each child's prior becomes prior * (1 - frac) + noise * frac, where
    frac is config.root_exploration_fraction and the noise vector is drawn from
    Dirichlet(config.root_dirichlet_alpha). Because a Dirichlet sample sums to 1,
    priors that formed a distribution before the call still do afterwards
    (independent, unnormalised gamma draws would not guarantee that).
    """
    frac = config.root_exploration_fraction
    noise = np.random.dirichlet([config.root_dirichlet_alpha] * config.num_action)
    for action, n in enumerate(noise):
        child = node.children[action]
        child.prior = child.prior * (1 - frac) + n * frac

def simulate(model, node):
    """Estimate the value of `node` with the network (no environment rollout).

    Returns (p, V):
      - terminal node: a uniform policy of shape [num_action] and V = tensor([0.]),
        since no further reward can be collected from a finished episode;
      - otherwise: softmax policy of shape [1, num_action] and the model's
        scalar value for node.state, both computed without gradient tracking.

    The terminal check uses truthiness rather than `is True`, so a numpy bool
    `done` flag is honoured as well.
    """
    if node.done:
        uniform_policy = torch.tensor([1./config.num_action for _ in range(config.num_action)])
        return uniform_policy, torch.tensor([0.])

    with torch.no_grad():
        state = torch.tensor(np.array(node.state, dtype = np.float32), device = config.device).unsqueeze(0)
        p, (V, _) = model(state)
        p = F.softmax(p, -1)
    return p, V

def create_child(model, node):
    """Expand `node`: attach one child per action, using the policy softmax as priors.

    Every child is built from node.env (which Node copies and steps with the
    child's action). The same `node` object is returned.
    """
    node.children = []
    with torch.no_grad():
        state_batch = torch.tensor(np.array(node.state, dtype = np.float32), device = config.device).unsqueeze(0)
        logits, _ = model(state_batch)
        priors = F.softmax(logits, -1)
    for action in range(config.num_action):
        prior = priors[0][action].item()
        node.children.append(Node(prior, config.num_action, node.env, node, action, None))
    return node

def backpropagate(node, p, V, min_max_scale):
    """Push the value estimate V from `node` up to the root.

    Each visited node gains one visit and adds the current estimate to its
    running total; the estimate is then discounted by config.gamma and the
    node's own reward is added before moving to the parent. The scaler is fed
    every node's updated mean value. `p` is accepted but not used.
    """
    current, backed_up = node, V
    while current:
        current.N += 1
        current.V += backed_up
        min_max_scale.update(current.V / current.N)
        backed_up = config.gamma * backed_up + current.reward
        current = current.parent

def find_explore_node(explore_node, min_max_scale):
    """Descend from `explore_node`, always following a best-scoring child.

    The walk stops at the first node that has no children or is done; ties on
    the UCB score are broken uniformly at random.
    """
    node = explore_node
    while True:
        if not node.children or node.done:
            return node
        candidates = node.children
        scores = [candidate.ucb_score(min_max_scale) for candidate in candidates]
        top_score = max(scores)
        tied = [candidate for candidate, score in zip(candidates, scores) if score == top_score]
        node = tied[np.random.choice(len(tied))]

def explore(model, explore_node, min_max_scale, is_test):
    """Run one simulation: select a leaf, expand it if possible, evaluate it, back the value up.

    Exploration noise is added only when the node just expanded has no
    lead_action (i.e. it is a root) and the call is not a test run.
    """
    leaf = find_explore_node(explore_node, min_max_scale)

    needs_expansion = leaf.children is None and not leaf.done
    if needs_expansion:
        leaf = create_child(model, leaf)
        is_root = leaf.lead_action is None
        if is_root and not is_test:
            add_exploration_noise(leaf)

    policy, value = simulate(model, leaf)

    backpropagate(leaf, policy, value.item(), min_max_scale)

def check_mcts(timeout, start_time, search_step, current_step):
    """Tell whether the search may run another simulation.

    With a `timeout` (seconds) the budget is wall-clock time since `start_time`;
    when `timeout` is None the budget is `search_step` simulations.
    """
    if timeout is not None:
        elapsed = datetime.now() - start_time
        return elapsed < timedelta(seconds=timeout)
    return current_step < search_step

def _visit_distribution(children):
    """Return (visit counts, normalised visit distribution); uniform if nothing was visited."""
    counts = np.array([child.N for child in children])
    total = counts.sum()
    if total == 0:
        return counts, np.full(len(children), 1.0 / len(children))
    return counts, counts / total

def _detach_as_root(child):
    """Cut `child` loose from its parent so it can act as the root of the next search."""
    child.parent = None
    child.lead_action = None
    child.reward = 0.
    child.done = False
    return child

def select_action(root_node, episode_data, is_test):
    """Choose the action to play from the root's visit counts.

    During training, while fewer than config.num_sampling_moves actions have
    been recorded in `episode_data`, the action is sampled in proportion to the
    visit counts; otherwise (and always when `is_test`) the first most-visited
    child is taken.

    Returns (action, probs, root_node, next_root):
      probs     -- visit distribution over the root's children (uniform when no
                   child was visited, instead of NaN from a zero division)
      next_root -- the chosen child, detached from the old tree in BOTH
                   branches, so a reused subtree never backs values up into
                   stale ancestors or keeps the discarded tree alive
    """
    children = root_node.children
    Ns, probs = _visit_distribution(children)

    if len(episode_data.actions) < config.num_sampling_moves and not is_test:
        # Stochastic choice proportional to visit counts.
        action = np.random.choice(len(probs), 1, p=probs)[0]
        chosen = children[action]
    else:
        # Greedy choice; np.argmax returns the first maximum, matching a
        # "first of the most-visited children" tie-break.
        chosen = children[int(np.argmax(Ns))]
        # Read the action before detaching clears lead_action.
        action = chosen.lead_action

    return action, probs, root_node, _detach_as_root(chosen)

def mcts(model, episode_data, root_state, root_env, timeout = 1, search_step = None, tree=None, is_test = False):
    """Run one full tree search from the current state and pick an action.

    A previously returned `tree` is used as the root only when one is supplied
    and config.reuse_tree is set; otherwise a new root is built from
    `root_state` / `root_env`. Simulations run until check_mcts reports the
    budget (time or step count) is spent. Returns whatever select_action returns.
    """
    min_max_scale = MinMaxScale()
    if tree and config.reuse_tree:
        root_node = tree
    else:
        root_node = Node(0, config.num_action, root_env, None, None, state = root_state)

    start_time = datetime.now()
    simulations_done = 0
    while check_mcts(timeout, start_time, search_step, simulations_done):
        explore(model, root_node, min_max_scale, is_test)
        simulations_done += 1
    return select_action(root_node, episode_data, is_test)

In [None]:
def play_episode(model, replay_buffer, is_test):
    """Play one episode, choosing every action with MCTS.

    model: network used by the search for priors and values
    replay_buffer: receives the finished episode when `is_test` is False
    is_test: forwarded to the search; True also skips storing the episode

    Each recorded step holds the state the search started from, the action
    actually executed in that state, the reward and done flag returned by that
    very step, the root's mean value (tree.V / tree.N) and the visit
    distribution. Returns (episode_reward, episode_step).
    """
    env = gym.make(config.env_name, render_mode="rgb_array")
    # try/finally so the environment is released even if the search or a step raises.
    try:
        state, info = env.reset()

        episode_reward = 0
        episode_step = 0
        next_tree = None

        episode_data = Episode()
        done = False

        while not done:
            action, probs, tree, next_tree = mcts(model, episode_data, state, env, config.timeout, config.search_step, next_tree, is_test)
            next_state, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            episode_reward += reward
            episode_step += 1

            # Record the transition with the action and done flag that belong to
            # this state, not those of the previous iteration.
            episode_data.save_step_data(state, action, reward, done, tree.V / tree.N, probs)

            state = next_state
    finally:
        env.close()

    if not is_test:
        episode_data.save_total_rewards()
        replay_buffer.add(episode_data)

    return episode_reward, episode_step

In [None]:
def train(model, optimizer, replay_buffer, scheduler, support_values):
    """Run one optimisation step on a batch sampled from the replay buffer.

    model: returns (policy_logits, (value, value_logit)) for a batch of states
    optimizer: optimiser over the model's parameters
    replay_buffer: object whose sample(n) returns (states, target_probs, target_values)
    scheduler: LR scheduler, stepped only when config.use_scheduler is set (may be None)
    support_values: 1D tensor of support atoms, used only when config.V_loss_type is "CE"

    Returns (policy_loss, value_loss, value_mse) as Python floats.
    """
    optimizer.zero_grad()
    states, target_probs, target_values = replay_buffer.sample(config.batchsize)

    # Stack through numpy first: building a tensor straight from a list of
    # ndarrays is slow, and the float64 arrays would otherwise promote the
    # policy loss to float64 next to a float32 model.
    states = torch.tensor(np.array(states, dtype = np.float32), device = config.device)
    target_probs = torch.tensor(np.array(target_probs, dtype = np.float32), device = config.device)
    target_values = torch.tensor(np.array(target_values, dtype = np.float32), device = config.device).reshape(-1)

    probs, (value, value_logit) = model(states)
    # Same-shaped floating targets are treated as class probabilities.
    loss_P = F.cross_entropy(probs, target_probs)

    if config.V_loss_type == "CE":
        target_values_prob = scalar_to_support(target_values, support_values.to(target_values.device))
        loss_V = F.cross_entropy(value_logit, target_values_prob)
        # MSE of the decoded scalar is reported for monitoring only.
        with torch.no_grad():
            mse_V = F.mse_loss(value, target_values.unsqueeze(1)).item()
    else:
        loss_V = F.mse_loss(value, target_values.unsqueeze(1))
        mse_V = loss_V.item()

    loss = loss_V + loss_P
    loss.backward()
    optimizer.step()
    if config.use_scheduler and scheduler is not None:
        scheduler.step()
    return loss_P.item(), loss_V.item(), mse_V

In [None]:
# Per-episode statistics collected during self-play (training) and evaluation (test).
episode_rewards = []
episode_steps = []
episode_runtimes = []

test_episode_rewards = []
test_episode_steps = []
test_episode_runtimes = []

# Atoms of the categorical value head: 0, support_step, ..., max_support_value.
support_values = torch.arange(0, config.max_support_value + 1, config.support_step).float()

model = Model(config.state_dim, config.num_action, support_values, config.V_loss_type)
model.to(config.device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
# Exponential decay that reaches final_lr after total_decay_step scheduler steps.
# The exponent is clamped so the learning rate stays at final_lr afterwards
# instead of continuing to shrink.
lr_gamma = (config.final_lr / config.learning_rate) ** (1 / config.total_decay_step)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: lr_gamma ** min(step, config.total_decay_step))

replay_buffer = Replay_Buffer(config.buffer_size)

# Best mean evaluation reward so far; decides when config.save_path is overwritten.
max_test_rewards = 0.

for episode in range(config.total_episode):
    # --- self-play ---
    start_time = datetime.now()
    episode_reward, episode_step = play_episode(model, replay_buffer, False)

    episode_rewards.append(episode_reward)
    episode_steps.append(episode_step)
    episode_runtimes.append(datetime.now() - start_time)
    print(f"episode: {episode}, rewards: {episode_reward}, episode_time: {episode_runtimes[-1]}, mean_rewards: {np.array(episode_rewards)[-min(100, len(episode_rewards)):].mean():.4f}")

    # --- training: every training_steps episodes, once enough episodes were played ---
    if episode % config.training_steps == 0 and len(episode_rewards) >= config.start_train_from_episode:
        losses_P = []
        losses_V = []
        mse_Vs = []
        for _ in range(config.epochs):
            loss_P, loss_V, mse_V = train(model, optimizer, replay_buffer, scheduler, support_values)
            losses_P.append(loss_P)
            losses_V.append(loss_V)
            mse_Vs.append(mse_V)
        print(f"Loss P: {np.array(losses_P).mean():.4f}, Loss V: {np.array(losses_V).mean():.4f}, MSE V: {np.array(mse_Vs).mean():.4f} lr: {optimizer.param_groups[0]['lr']:.6f}\n")

    # --- periodic checkpoint, named after the episode index ---
    if episode % config.save_step == 0:
        torch.save(model.state_dict(), f"{episode}.pth")

    # --- evaluation: keep the weights with the best mean test reward ---
    if episode % config.testing_steps == 0 and len(episode_rewards) >= config.start_train_from_episode:
        for test_episode in range(config.total_test_episode):
            test_time = datetime.now()
            test_episode_reward, test_episode_step = play_episode(model, replay_buffer, True)

            test_episode_rewards.append(test_episode_reward)
            test_episode_steps.append(test_episode_step)
            test_episode_runtimes.append(datetime.now() - test_time)
            print(f"test episode: {test_episode}, rewards: {test_episode_reward}, episode_time: {test_episode_runtimes[-1]}, mean_rewards: {np.array(test_episode_rewards[-(test_episode+1):]).mean():.4f}")
        mean_test_reward = np.array(test_episode_rewards[-config.total_test_episode:]).mean()
        if mean_test_reward > max_test_rewards:
            torch.save(model.state_dict(), config.save_path)
            max_test_rewards = mean_test_reward
        print("\n")

# Summary of the training run.
episode_rewards = np.array(episode_rewards)
print(episode_rewards)
print(episode_rewards.max(), episode_rewards.min(), episode_rewards.mean(), episode_rewards.std())
print(max(episode_runtimes), min(episode_runtimes), np.mean(episode_runtimes))

In [None]:
import matplotlib.pyplot as plt

# Trailing mean over the last MEAN_WINDOW episodes (fewer at the start of the
# run), i.e. the same 100-episode window as the mean printed during training.
MEAN_WINDOW = 100
rewards_array = np.array(episode_rewards)
mean_rewards = []
for i in range(len(rewards_array)):
    mean_rewards.append(rewards_array[max(0, i - MEAN_WINDOW + 1):i + 1].mean())

# Plotting
plt.figure(figsize=(10, 5))
plt.plot(episode_rewards, label='Rewards')
plt.plot(mean_rewards, label='Mean Rewards', linestyle='--')
plt.xlabel('Episode')
plt.ylabel('Reward')
plt.title('Rewards vs Mean Rewards')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()