In [1]:
import gym
from mcts_simple import Game
from copy import deepcopy

# For rendering
import imageio
import base64
import IPython

class CartPole(Game):
    """Single-player adapter exposing gym's ``CartPole-v0`` through the ``Game`` interface.

    The wrapper keeps the list of observations seen so far (used as the search
    state), a one-step snapshot of the environment (used to undo the last
    action), and the frames rendered so far (written to a video once the
    episode has an outcome).
    """

    def __init__(self, video = "cartpole.mp4", seed = 0):
        """Create and reset a seeded CartPole environment.

        Args:
            video: Path of the MP4 file written by ``render`` when the episode ends.
            seed: Seed passed to the gym environment before the first reset.
        """
        # mcts-simple at this point in time does not accept seedless value
        # However, you can always keep exporting and importing mcts for each CartPole with different seeds (only can be used for closed loop MCTS)
        self.prev_env = None # snapshot of the environment before the most recent action
        self.env = gym.make("CartPole-v0")
        self.env.seed(seed) # honour the caller's seed instead of always seeding with 0
        self.states = [self.env.reset()]

        self.video = video
        self.frames = []

        self.done = False
        self.win = False

    def render(self):
        """Capture the current frame; once the episode is over, write and display the video."""
        # Don't worry about the pop up windows, it's a feature not a bug.
        # Environment has to be deepcopied such that the following error will not be raised:
        # ValueError: ctypes objects containing pointers cannot be pickled
        # If you find a better solution, do contribute to the repository.
        env = deepcopy(self.env)
        try:
            self.frames.append(env.render(mode = "rgb_array"))
        finally:
            # Always release the copied environment, even when rendering fails.
            env.close()
            del env

        if self.has_outcome():
            imageio.mimwrite(self.video, self.frames, "MP4", fps = 10, macro_block_size = None)
            with open(self.video, "rb") as f: # base64 encoding to support linux systems
                b64_video = base64.b64encode(f.read())
            IPython.display.display(IPython.display.HTML(data = f"""
            <video controls src = "data:video/mp4;base64,{b64_video.decode()}"></video>
            """))

    def get_state(self):
        """Return the most recent observation as a tuple."""
        # I really want to do rounding to discretise such continuous state space,
        # but rounding leads to state space inconsistency which was a pain to deal with since I did not know how to fix it.
        # If you know how to, please contribute to the repository.
        # Note that rounding still works for MCTS but not for UCT.
        # It could also be due to my Jupyter notebook hallucinating, but we will never know unless someone else tries.
        return tuple(self.states[-1])

    def number_of_players(self):
        """Return the number of players, which is always 1."""
        return 1

    def current_player(self):
        """Return the id of the player to move, which is always 1."""
        return 1

    def possible_actions(self):
        """Return every action index of the environment's action space as a list of int."""
        return list(range(self.env.action_space.n)) # list of int (doesn't have to be string since human play is illegal by default)

    def take_action(self, action):
        """Apply ``action`` to the environment and record the resulting observation.

        A snapshot of the environment is stored first so that
        ``delete_last_action`` can undo this step. When the step ends the
        episode, ``done`` is set, and ``win`` is set as well if the elapsed
        step count has reached the environment's maximum episode length.

        Raises:
            RuntimeError: If ``action`` is not one of ``possible_actions()``.
        """
        if action not in self.possible_actions():
            raise RuntimeError("Action taken is invalid.")
        self.prev_env = deepcopy(self.env)
        observation, reward, done, info = self.env.step(action)
        self.states.append(observation)
        if done:
            self.done = True
            # Reaching the step limit counts as a win.
            if self.env._elapsed_steps >= self.env._max_episode_steps:
                self.win = True

    def delete_last_action(self):
        """Undo the most recent ``take_action`` call.

        Only a single level of undo is available: the snapshot is cleared once
        it has been restored.

        Raises:
            RuntimeError: If there is no snapshot to restore.
        """
        if self.prev_env is None:
            raise RuntimeError("No last action to delete.")
        self.env = self.prev_env
        self.prev_env = None
        self.states.pop()
        self.done = False
        self.win = False

    def has_outcome(self):
        """Return True once the episode has ended."""
        return self.done

    def winner(self):
        """Return the current player if the episode was won, otherwise -1.

        Raises:
            RuntimeError: If the episode has not ended yet.
        """
        if not self.has_outcome():
            raise RuntimeError("winner() cannot be called when outcome is undefined.")
        if self.win:
            return self.current_player()
        else:
            return -1

In [2]:
from mcts_simple import OpenLoopUCT

## UCT is required for exploration ##
## Closed loop UCT is left out of this example: the branching factor before a terminal node is reached is large, so training it would take a long time ##

# File names shared by the export and the import step below
UCT_VIDEO_PATH = "cartpole_open_loop_UCT.mp4"
UCT_TREE_PATH = "CartPole_open_loop_UCT.json"

# Train an open loop UCT from scratch and save it to disk
print("Export trained open loop UCT")
OL_uct = OpenLoopUCT(CartPole(video = UCT_VIDEO_PATH, seed = 0))
OL_uct.run(iterations = 500_000)
OL_uct._export(UCT_TREE_PATH)

# Load the saved open loop UCT into a fresh game and let it play with its best actions
print("Import trained open loop UCT")
OL_uct = OpenLoopUCT(CartPole(video = UCT_VIDEO_PATH, seed = 0))
OL_uct._import(UCT_TREE_PATH)
OL_uct.self_play(activation = "best")

Export trained open loop UCT


HBox(children=(FloatProgress(value=0.0, description='Simulating', max=500000.0, style=ProgressStyle(descriptio…


Import trained open loop UCT
