In [None]:
class UCT_Node:
    """Node of a UCT (Upper Confidence bounds applied to Trees) search tree.

    Statistics about a node (visit count, accumulated value) are stored in
    arrays owned by its *parent*, indexed by ``move``.  The arrays have
    ``game_state.n_actions + 1`` entries; moves are expected to lie in
    ``1..n_actions`` (index 0 is a placeholder).

    Args:
        game_state: State object for this node.  As used here it must provide
            ``n_actions``, ``legal_actions``, ``to_play`` and ``play(move)``.
        move: Move that led from ``parent`` to this node (``None`` at root).
        parent: Parent node, or ``None`` for the root.
        C: Exploration constant weighting ``child_U`` in ``best_child``.
            It is propagated to all descendants created by ``add_child``.
    """

    def __init__(self, game_state, move=None, parent=None, C=1.4):
        self.game_state = game_state
        self.move = move
        self.is_expanded = False
        self.parent = parent
        self.children = {}
        size = self.game_state.n_actions + 1
        self.child_priors = np.zeros([size], dtype=np.float32)
        self.child_total_value = np.zeros([size], dtype=np.float32)
        self.child_number_visits = np.zeros([size], dtype=np.float32)
        self.C = C
        # Amount subtracted from total_value in select_leaf and added back in
        # backup (a "virtual loss"); 0 disables it.
        self.penalty = 0
        # The root has no parent array to hold its visit count, so it is kept
        # here; otherwise the root's exploration term would always be zero.
        self._root_visits = 0

    def any_argmax(self, aa):
        """Return any index of ``aa`` holding its maximum value.

        Stochastic: ties are broken uniformly at random with
        ``random.choice``.
        """
        import random
        best = aa[np.argmax(aa)]
        ties = np.where(aa == best)[0]
        return random.choice(ties.tolist())

    def child_Q(self):
        """Return the mean value estimate per child (``total / (1 + visits)``)."""
        return self.child_total_value / (1 + self.child_number_visits)

    def child_U(self):
        """Return the UCT exploration term per child.

        ``sqrt(log(N_parent) / (1 + N_child))``; all zeros while this node has
        not been visited (``log(0)`` is undefined).  The result always has the
        same length as the per-child arrays.
        """
        # TODO: This is not reflecting the theory (child_priors are not used).
        parent_visits = self.number_visits
        if parent_visits == 0:
            return np.zeros(self.child_number_visits.shape, dtype=float)
        return np.sqrt(np.log(parent_visits) /
                       (1 + self.child_number_visits))

    def best_child(self):
        """Return the child maximising ``Q + C * U``, or ``None`` if none exist.

        Only moves that have a child node (i.e. that were legal when this node
        was expanded) are candidates, so an illegal move is never selected.
        Ties are broken at random.
        """
        if not self.children:
            return None
        moves = list(self.children)
        scores = self.child_Q() + self.C * self.child_U()
        return self.children[moves[self.any_argmax(scores[moves])]]

    def select_leaf(self):
        """Descend from this node via ``best_child`` to a node to evaluate.

        Every expanded node passed through gets its visit count incremented
        and ``penalty`` subtracted from its value, making it less attractive
        until ``backup`` restores it.  The descent stops at the first
        unexpanded node, or at an expanded node that has no children.
        """
        current = self
        while current.is_expanded:
            current.number_visits += 1
            current.total_value -= current.penalty  # undone in backup
            child = current.best_child()
            if child is None:
                # Expanded without any legal action: nothing below to visit.
                break
            current = child
        return current

    def expand(self, child_priors):
        """Mark this node expanded, store ``child_priors`` and create a child
        for every move in ``game_state.legal_actions``."""
        self.is_expanded = True
        self.child_priors = child_priors
        for move in self.game_state.legal_actions:
            self.add_child(move)

    def add_child(self, move):
        """Create the child reached by ``move``, inheriting ``C`` and ``penalty``."""
        child = UCT_Node(self.game_state.play(move), move=move, parent=self,
                         C=self.C)
        child.penalty = self.penalty
        self.children[move] = child

    def backup(self, value_estimate):
        """Propagate ``value_estimate`` from this node up to (excluding) the root.

        Each node on the path gets ``value_estimate * self.game_state.to_play
        + self.penalty`` added to its total value.  The sign comes from *this*
        (the evaluated) node's player for every ancestor.
        """
        # NOTE(review): the sign is not flipped per tree level; whether that is
        # correct depends on value_estimate's perspective convention — confirm.
        # NOTE(review): the starting leaf receives +penalty although
        # select_leaf never subtracted it (no effect while penalty == 0).
        signed_value = value_estimate * self.game_state.to_play
        current = self
        while current.parent is not None:
            current.total_value += signed_value + self.penalty
            current = current.parent

    @property
    def number_visits(self):
        """Visit count, read from the parent's array (or kept locally at root)."""
        if self.parent is not None:
            return self.parent.child_number_visits[self.move]
        return self._root_visits

    @number_visits.setter
    def number_visits(self, value):
        if self.parent is not None:
            self.parent.child_number_visits[self.move] = value
        else:
            self._root_visits = value

    @property
    def total_value(self):
        """Accumulated value, read from the parent's array (always 0 at root)."""
        if self.parent is not None:
            return self.parent.child_total_value[self.move]
        return 0

    @total_value.setter
    def total_value(self, value):
        # Writes on the root are intentionally ignored: it has no parent array.
        if self.parent is not None:
            self.parent.child_total_value[self.move] = value

In [None]:
def uct_search(game_state, policy, num_reads, verbose=0):
    """Run ``num_reads`` UCT iterations from ``game_state`` and pick a move.

    Each iteration selects a leaf, evaluates it with
    ``policy.evaluate(leaf.game_state)`` (which must return a 3-tuple whose
    first item is the child priors and last item the value estimate),
    expands the leaf when the value estimate equals 0, and backs the value
    up the tree.

    Args:
        game_state: State to search from.
        policy: Object providing ``evaluate(game_state)`` as described above.
        num_reads: Number of select/evaluate/backup iterations.
        verbose: Currently unused; kept for interface compatibility.

    Returns:
        ``(move, node)`` for the root child with the most recorded visits
        (the first such child in expansion order on ties).

    Raises:
        ValueError: if the root ended up with no children (e.g.
            ``num_reads < 1``, the root never evaluated to 0, or it has no
            legal actions), so there is no move to choose.
    """
    root = UCT_Node(game_state)
    for _ in range(num_reads):
        leaf = root.select_leaf()
        child_priors, _, value_estimate = policy.evaluate(leaf.game_state)
        # NOTE(review): only a value of exactly 0 triggers expansion;
        # presumably non-zero marks a decided position — confirm with policy.
        if value_estimate == 0:
            leaf.expand(child_priors)
        leaf.backup(value_estimate)
    if not root.children:
        raise ValueError(
            "uct_search: root has no children to choose a move from "
            "(num_reads=%r)" % (num_reads,))
    return max(root.children.items(),
               key=lambda item: item[1].number_visits)