# Monte Carlo Tree Search

In [25]:
import numpy as np
from random import uniform

In [512]:
class Node:
    """Node of a full binary tree whose leaves carry noisy target-based values.

    The constructor eagerly builds the whole subtree below the node, so a
    root created with ``depth=d`` owns ``2**d`` leaves. A node's ``name`` is
    the sequence of "L"/"R" moves leading to it from the root (the root's
    name is the empty string).

    Parameters
    ----------
    depth : number of levels to build below this node; ``depth <= 0``
        produces a leaf.
    parent : the parent node, or ``None`` for the root.
    name : path of "L"/"R" moves from the root to this node.
    B : scale of the deterministic part of a leaf value.
    tau : length scale of the exponential decay of leaf values with the
        distance from the target. Defaults to ``depth / 5`` of the node it
        is first computed on and is then handed down unchanged to all
        descendants.
    """

    def __init__(self, depth=20, parent=None, name='', B=10, tau=None):
        self.parent = parent
        self.name = name
        self.B = B
        self.tau = depth/5 if tau is None else tau
        self.leaf = depth <= 0
        if self.leaf:
            self.left, self.right = None, None
        else:
            # Children share B and the tau resolved above.
            self.left = Node(depth=depth-1, parent=self, name=name+"L", B=B, tau=self.tau)
            self.right = Node(depth=depth-1, parent=self, name=name+"R", B=B, tau=self.tau)

    def __repr__(self):
        if self.name == '':
            return "Root node"
        kind = "Leaf node" if self.leaf else "Node"
        return f"{kind} {self.name}"

    def select_target(self):
        """Pick a random leaf below this node and return its name.

        The walk chooses left or right with equal probability at every
        level. Every node on the chosen path, including the leaf itself,
        stores the result in ``self.target``.
        """
        if self.leaf:
            # Store the target on the leaf too, so that a single-leaf tree
            # can call assign_values() without an explicit target.
            self.target = self.name
            return self.name
        # Selecting a target by randomly and recursively choosing between children
        chosen = self.left if uniform(0, 1) < 0.5 else self.right
        self.target = chosen.select_target()
        return self.target

    # Assign values to every leaf node based on the distance from the target
    def assign_values(self, target=None):
        """Set ``value`` on every leaf below (and including) this node.

        A leaf value is ``B * exp(-distance(target) / tau)`` plus standard
        normal noise. When ``target`` is omitted, ``self.target`` is used,
        which requires ``select_target()`` to have been called on this node.
        """
        if target is None:
            target = self.target
        if not self.leaf:
            self.left.assign_values(target=target)
            self.right.assign_values(target=target)
            return
        dist = self.distance(target)
        if self.tau == 0:
            # Limit of exp(-dist / tau) for tau -> 0+; avoids the 0/0 = NaN
            # that a depth-0 tree (tau == 0, dist == 0) would produce.
            decay = 1.0 if dist == 0 else 0.0
        else:
            decay = np.exp(-dist/self.tau)
        self.value = self.B*decay + np.random.randn()

    def distance(self, target):
        """Return the number of positions at which ``name`` and ``target`` differ.

        Both strings are compared character by character, so they are
        expected to have the same length.
        """
        return (np.asarray(list(self.name)) != np.asarray(list(target))).sum()

    # Useful for checking the values of each leaf node but should only be called for small trees
    def list_leaves(self):
        """Print every leaf below this node with its value, left to right."""
        if self.leaf:
            print(f"{self}: {self.value:.04f}")
            return
        self.left.list_leaves()
        self.right.list_leaves()

In [517]:
class MonteCarloTreeSearch:
    """Monte Carlo tree search with UCB child selection on a binary tree.

    Nodes are expected to expose ``left``, ``right``, ``parent``, ``leaf``
    and ``name``; leaves must additionally carry a ``value``. The search
    stores its statistics directly on the nodes (``total``, ``n`` and, for
    internal nodes, the UCB score in ``value``).

    Parameters
    ----------
    tree : root node of the tree to search.
    c : weight of the exploration term in the UCB score.
    """

    def __init__(self, tree, c=1):
        self.tree = tree
        self.current_root = tree
        # ids of the nodes whose statistics were reset by *this* searcher.
        # Checking for attributes on the node instead would pick up stale
        # statistics left behind by an earlier search over the same tree.
        self._initialized = set()
        self.initialize_node(tree)
        self.c = c
        self.n = 0

    def search(self, n=50):
        """Run iterations until the current root has descended to a leaf.

        After every iteration, the current root is replaced by its best
        child as soon as it has been visited at least ``n`` times.
        """
        while not self.current_root.leaf:
            self.iteration()
            if self.current_root.n >= n:
                self.current_root = self.better_child(self.current_root)

    def iteration(self):
        """Perform one selection / rollout / backpropagation pass."""
        node = self.current_root
        while True:
            node = self.better_child(node)
            # Stop if a leaf or unvisited node is reached. In the latter case use rollout
            if node.leaf or node.n == 0:
                # Only one rollout per "snowcap" leaf node in this implementation
                value = node.value if node.leaf else self.rollout(node)
                self.backprop(node, value)
                self.n += 1
                break

    def better_child(self, node):
        """Return the child of ``node`` with the larger score.

        Scores of internal children are refreshed first. On a tie the left
        child is returned, because ``max`` keeps the first maximum.
        """
        children = (node.left, node.right)
        for child in children:
            # Initialize child if it has not been previously visited by this search
            if id(child) not in self._initialized:
                self.initialize_node(child)
            self.calculate_value(child)
        return max(children, key=lambda x: x.value)

    def initialize_node(self, node):
        """Reset the visit statistics of ``node`` and mark it as initialized."""
        node.total = 0
        node.n = 0
        self._initialized.add(id(node))

    def calculate_value(self, node):
        """Store the UCB score of an internal node in ``node.value``."""
        # If the node is a leaf node, its value is known
        if node.leaf:
            return
        # Because of division by 0 in the UCB, the initial value is set to infinity
        if node.n == 0:
            node.value = np.inf
            return
        # self.n >= 1 whenever a node has been visited; max() only guards
        # against log(0) if the counter was reset from outside.
        exploration = np.sqrt(np.log(max(self.n, 1)) / node.n)
        node.value = node.total / node.n + self.c * exploration

    def rollout(self, node):
        """Walk down from ``node`` at random and return the reached leaf's value."""
        # Choose left or right child with equal probability for the rollout policy
        while not node.leaf:
            node = node.left if uniform(0, 1) < 0.5 else node.right
        return node.value

    def backprop(self, node, value):
        """Add ``value`` and one visit to ``node`` and all of its ancestors."""
        # Leaf node values do not need to be estimated
        if not node.leaf:
            node.total += value
            node.n += 1
        while node.parent is not None:
            node = node.parent
            node.total += value
            node.n += 1

    def target(self):
        """Return the name of the leaf reached by following the best children.

        The walk starts at the root of the tree and uses the same UCB score
        as the search itself (exploration term included). Nodes that have
        not been seen yet are initialized along the way.
        """
        node = self.tree
        while not node.leaf:
            # If the search has not converged, the output will be somewhat random
            node = self.better_child(node)
        return node.name

In [530]:
# Build a full binary tree of depth 20 (2**20 leaves), pick a random target
# leaf, and give every leaf a noisy value that decays with its distance from
# that target.
tree = Node(depth=20)
tree.select_target()
tree.assign_values()
# Show the name (path of L/R moves) of the target leaf chosen above
tree.target

'LRLRLRLLRLLLLLLRLLRL'

In [531]:
# Search the tree; the current root moves to its best child once it has
# been visited 50 times, until a leaf is reached.
mcts = MonteCarloTreeSearch(tree)
mcts.search(n=50)
# Show the leaf name obtained by following the best-scoring children from the root
mcts.target()

'RLLLLRLLRRLLLLRRLLRR'

In [532]:
# Number of positions at which the true target and the search result differ
(~(np.asarray(list(tree.target)) == np.asarray(list(mcts.target())))).sum()

6