# Monte-Carlo Tree Search

In [1]:
import gym

import numpy as np
import pandas as pd

from abc import ABC, abstractmethod

from copy import deepcopy

In [2]:
# Set up environment and import first observation
env = gym.make("MountainCar-v0")
obs = env.reset()
print('Starting obs:', obs)

Starting obs: [-0.55337788  0.        ]


In [3]:
# Take action 1 (push cart to right and get new observation)
print ('\nTake action (push cart to right):')
action = 1
new_state, reward, done, info = env.step(action)
print('\nNew observation:', obs)
print('Reward:', reward)
print('Done:', done)
print('Extra info:', info)


Take action (push cart to right):

New observation: [-0.55337788  0.        ]
Reward: -1.0
Done: False
Extra info: {}


In [4]:
DISCRETE_OBS_SIZE = [20] * len(env.observation_space.high)
DISCRETE_OBS_SIZE

[20, 20]

In [5]:
def get_discrete_state(state, env):
    discrete_state = (state - env.observation_space.low) / discrete_obs_win_size
    return tuple(discrete_state.astype(np.int))

In [23]:
class TreeNode(object):
    '''
    Tree data structure to use with MCTS
    '''
    def __init__(self, n_legal_actions, parent=None, state=None):
        self.n_legal_actions = n_legal_actions
        self.child_nodes = {i:None for i in range(n_legal_actions)}
        self.visits = 0
        self.q_value = 0.0
    
        self.state = state
        self.parent = parent
        
        #array representation of children visits
        self.bandit_means = np.empty(n_legal_actions, np.float)
        self.actions = np.zeros(n_legal_actions, np.int32)
            
    def __iter__(self):
        self.n = 0
        return self

    def __next__(self):
        if self.n < self.n_legal_actions:
            self.n += 1
            return self.child_nodes[self.n-1]
        else:
            raise StopIteration
            
    def expand(self):
        for key in range(self.n_legal_actions):
            self.child_nodes[key] = TreeNode(self.n_legal_actions, parent=self)
            

In [24]:
x = TreeNode(3)

for child in x:
    print(child)

x.expand()

for child in x:
    print(child)

None
None
None
<__main__.TreeNode object at 0x7f7e471f01f0>
<__main__.TreeNode object at 0x7f7e471f0940>
<__main__.TreeNode object at 0x7f7e4726c3d0>


In [22]:
class BanditAgent(ABC):
    '''
    Abstract base class for bandit agents
    '''
    @abstractmethod
    def select_action(self, node):
        pass
    
class EpsilonGreedy(BanditAgent):
    '''
    Epsilon-greedy agent for Tree Search
    '''
    def __init__(self, epsilon, random_state=None):
        self.epsilon = epsilon
        self._rand = np.random.RandomState(seed=random_state)
        
    def select_action(self, node):
        u = self._rand.random()
        if u > self.epsilon:
            child_key = np.argmax(node.bandit_means)
        else:
            child_key = self._rand.choice(node.bandit_means)
            
        return parent.child_nodes[child_key]

class UCB1(BanditAgent):
    '''
    UCB1 Agent for tree search
    '''
    def __init__(self):
        pass

    def select_action(self, node):
        '''
        Select the action to take from a tree node based on
        UCB1 score (highest confidence bound)
        '''
        half_widths = (2 * np.sqrt((np.log(node.visits) / node.actions)))
        upper_bounds = node.bandit_means + half_widths
        child_key = np.argmax(upper_bounds)
        return node.child_nodes[child_key]

In [16]:
class MCTS(object):
    def __init__(self, env, n_legal_actions, bandit_agent):
        self.env = env
        self.states = [] 
        self.root = TreeNode(n_legal_actions)
        self.bandit_agent = bandit_agent
        
    def solve(self, iterations=5):
        for i in range(iterations):
            leaf = self.traverse_tree(self.root)
            leaf.expand()
            self.rollout(leaf)
            self.backpropogate()
            
    def traverse_tree(self):
        '''
        Traverse the search tree.  Treats each tree node
        as a bandit problem.  
        
        Returns:
        ----------
        TreeNode
            Leaf node on the tree (i.e has not been visited)
        '''
        current_node = root
        while current_node.visits > 0:
            current_node = self.bandit_agent.select_action(current_node)
        
        return current_node
            
    def rollout(self, node):
        #copy states
        visited = {}
        states_copy = self.states[:]
        current_state = states_copy[-1]
        
    def backpropogate(self):
        pass
        

In [19]:
agent = MCTS(env=env, n_legal_actions=3, bandit_agent=UCB1())