# Monte-Carlo Tree Search

In [1]:
import gym

import numpy as np
import pandas as pd

from abc import ABC, abstractmethod

from copy import deepcopy

In [2]:
# Set up environment and import first observation
env = gym.make("MountainCar-v0")
obs = env.reset()
print('Starting obs:', obs)

Starting obs: [-0.48451472  0.        ]


In [3]:
# Take action 1 (push cart to right and get new observation)
print ('\nTake action (push cart to right):')
action = 1
new_state, reward, done, info = env.step(action)
print('\nNew state:', obs)
print('Reward:', reward)
print('Done:', done)
print('Extra info:', info)


Take action (push cart to right):

New state: [-0.48451472  0.        ]
Reward: -1.0
Done: False
Extra info: {}


In [4]:
DISCRETE_OBS_SIZE = [20] * len(env.observation_space.high)
DISCRETE_OBS_SIZE

[20, 20]

In [5]:
def get_discrete_state(state, env):
    discrete_state = (state - env.observation_space.low) / discrete_obs_win_size
    return tuple(discrete_state.astype(np.int))

In [6]:
class TreeNode(object):
    '''
    Tree data structure to use with MCTS
    '''
    def __init__(self, n_legal_actions, parent=None, state=None):
        self.n_legal_actions = n_legal_actions
        self.children = [None for i in range(n_legal_actions)]
        self.visits = 0
        self.q_value = 0.0
    
        self.state = state
        self.parent = parent
        
        #array representation of children visits
        self.bandit_means = np.empty(n_legal_actions, np.float)
        self.actions = np.zeros(n_legal_actions, np.int32)
            
    def __iter__(self):
        self.n = 0
        return self

    def __next__(self):
        if self.n < self.n_legal_actions:
            self.n += 1
            return self.children[self.n-1]
        else:
            raise StopIteration
            
    def expand(self):
        for idx in range(self.n_legal_actions):
            self.children[idx] = TreeNode(self.n_legal_actions, parent=self)
            

In [7]:
x = TreeNode(3)

for child in x:
    print(child)

x.expand()

for child in x:
    print(child)

None
None
None
<__main__.TreeNode object at 0x7fe2ddae1c10>
<__main__.TreeNode object at 0x7fe2ddae1f70>
<__main__.TreeNode object at 0x7fe2ddae1df0>


In [8]:
class BanditAgent(ABC):
    '''
    Abstract base class for bandit agents
    '''
    @abstractmethod
    def select_action(self, node):
        pass
    
class EpsilonGreedy(BanditAgent):
    '''
    Epsilon-greedy agent for Tree Search.
    
    Explore epsilon of the time and exploit 1 - epsilon.
    '''
    def __init__(self, epsilon, random_state=None):
        self.epsilon = epsilon
        self._rand = np.random.RandomState(seed=random_state)
        
    def select_action(self, node):
        u = self._rand.random()
        if u > self.epsilon:
            child_index = np.argmax(node.bandit_means)
        else:
            child_index = self._rand.choice(node.bandit_means)
            
        return node.children[child_index], action

class UCB1(BanditAgent):
    '''
    UCB1 Agent for tree search
    '''
    def __init__(self):
        pass

    def select_action(self, node):
        '''
        Select the action to take from a tree node based on
        UCB1 score (highest confidence bound)
        
        UCB1 = mean_child + 2. Sqrt(ln(N_parent)/ n_child)
        
        where:
        
            mean_child = mean Q value of child
            
            N_parent = no. time parent visited.
            
            n_child = no. times child node has been visited.
            
        When n_child = 0 then UCB1 = inf.  In this case UCB1 chooses the first
        child node in the list with inf upper bound.
        
        Returns:
        -------
        TreeNode
            The selected child node.
        '''
        half_widths = (2 * np.sqrt((np.log(node.visits) / node.actions)))
        upper_bounds = node.bandit_means + half_widths
        action = np.argmax(upper_bounds)
        return node.children[action], action

In [9]:
class MCTS(object):
    def __init__(self, env, n_legal_actions, bandit_agent):
        self.env = env
        self.state = [] 
        self.root = TreeNode(n_legal_actions)
        self.bandit_agent = bandit_agent
        
    def solve(self, iterations=5):
        for i in range(iterations):
            leaf = self.traverse_tree(self.root)
            leaf.expand()
            self.rollout(leaf)
            self.backpropogate()
            
    def traverse_tree(self):
        '''
        Traverse the search tree.  Treats each tree node
        as a bandit problem.  
        
        Returns:
        ----------
        TreeNode
            Leaf node on the tree (i.e has not been visited)
        '''
        current_node = root
        while current_node.visits > 0:
            current_node, action = self.bandit_agent.select_action(current_node)
            new_state, reward, done, info = self.env.step(action)
                        
        return current_node
            
    def rollout(self, node):
        #copy states
        visited = {}
        states_copy = self.states[:]
        current_state = states_copy[-1]
        
    def backpropogate(self):
        pass
        

In [10]:
agent = MCTS(env=env, n_legal_actions=3, bandit_agent=UCB1())

In [18]:
import numpy as np


class RLTicTacToeAdaptor(object): 
    '''
    adapts TicTacToe so that it can be played by an agent
    '''
    def __init__(self, board_size=3):
        '''
        Constructor
        
        Parameters:
        --------
        board_size: int optional (default=3)
            board will be board_size X board_size
        '''
        self.board_size = board_size
        self.reset()
        self.REWARD_DRAW = 0.5
        self.REWARD_LOSE = 0.0
        self.REWARD_WIN = 1.0
        self.REWARD_STEP = -0.1
    
    def step(self, action):
        '''
        Returns:
        --------
        Tuple
            new_state, reward, done, info
        '''
        coord = self.actions[action]
        done = self.game.place_piece(coord)
        
        if done and self.draw:
            reward = self.REWARD_DRAW
        elif done:
            reward = self.REWARD_WIN
        else:
            reward = self.REWARD_STEP
          
        return self.get_hashable_board(), reward, done, _
    
    
    def reset(self):
        '''
        Reset game
        
        Returns:
        -------
            np.array
            initial state
            vector len = self.board_size x self_boardsize representing board
        '''
        self.game = TicTacToe(board_size=self.board_size)
        self.actions = {}
        
        #init action dict
        action = 0
        for row in range(self.board_size):
            for col in range(self.board_size):
                self.actions[action] = (row, col)
                action += 1
        
        return self.get_hashable_board()
        
    
    def get_hashable_board(self):
        '''
        Reshape board from matrix to vector
        
        Returns:
        ------
            np.array
            board as vector
        '''
        return self.game.board.reshape(self.board_size * self.board_size)

In [19]:
class TicTacToe(object):
    '''
    Basic TicTacToe game
    '''
    def __init__(self, board_size=3):
        self.rows = board_size
        self.cols = board_size

        self.NAUGHTS = 0
        self.CROSSES = 1
        self.EMPTY = -1
        
        #naughts go first.
        self.player = self.NAUGHTS
        self.board = np.full((self.rows, self.cols), self.EMPTY, dtype=np.int)
        self.draw = False

        
    def place_piece(self, coord):
        '''
        Player places piece on board
        
        Parameters:
        ------
        coord - array-like,
            x, y board coordinates
        '''
        self.board[coord[0]][coord[1]] = self.player
        done = self.terminal_state(coord)
        self.end_player_turn()
        return done

    def get_legal_moves(self):
        '''
        Return array of legal coordinates in the grid
        '''
        return np.transpose(np.nonzero(self.board == self.EMPTY))
    
    def get_plays_remaining(self):
        '''
        Get plays remaining (empty places on board)
        
        Returns:
        -------
        int
        '''
        return len(self.get_legal_moves())
    
    def legal_action(self, action): 
        '''
        Is action legal?
        
        Parameters:
        ---------
        action - int
            action represented as index
            
        Returns:
        ---------
        bool
            True/False action is legal
        '''
        if not action in self.actions:
            return False
        
        coords = self.actions[action]
        
        if coords not in self.get_legal_moves():
            return False
        return True

    def terminal_state(self, last_play_coord):
        '''
        terminal state (end of game) if
        
        all pieces in row are the same
        all pieces in a col are the same
        all pieces diag left to right are equal 
        all pieces diag right to left are equal
        there are no places left to play
        
        Parameters:
        ---------
        last_play_coord: array-like
            coordinates of last piece played
        
        Returns:
        -------
        bool
            True/False indicated the game has been won/drawn
        '''
        coord = last_play_coord
        
        #row of last play
        row = self.board[coord[0],:]
        
        #col of last play
        col = self.board[:,coord[1]]
        
        lr_diag = np.diag(self.board)
        
        #quick hack only works for 3 x 3 game
        rl_diag = np.array([self.board[0][2], self.board[1][1], 
                            self.board[2][0]])
        
        plays_remaining = self.get_plays_remaining()
            
        #draw
        if plays_remaining == 0:
            return True
        #row complete
        elif (row == self.player).astype(int).sum() == self.rows:
            return True
        #col complete
        elif (col == self.player).astype(int).sum() == self.cols:
            return True
        #left right diagonal complete
        elif (lr_diag == self.player).astype(int).sum() == self.rows:
            return True
        #right to left diagonal
        elif (rl_diag == self.player).astype(int).sum() == self.rows:
            return True
        else:
            return False 
        
    def end_player_turn(self):
        '''
        End players turn
        '''
        if self.player == self.CROSSES:
            self.player = self.NAUGHTS
        else:
            self.player = self.CROSSES

In [20]:
game = RLTicTacToeAdaptor()
state = game.reset()
state

array([-1, -1, -1, -1, -1, -1, -1, -1, -1])

In [22]:
new_state, reward, done, info = game.step(1)

In [23]:
new_state

array([-1,  1, -1, -1, -1, -1, -1, -1, -1])

In [24]:
reward

-0.1

In [25]:
done

False