In [19]:
import math
import random
import import_ipynb
from operations import *
from helper import *

In [27]:
# Short aliases for the operation tables brought in by the star imports above.
unary, binary = unary_functions, binary_functions
unary_str, binary_str = unary_functions_str, binary_functions_str

# Search sizing parameters.
M = len(unary)     # how many unary operations are available
N = len(binary)    # how many binary operations are available
N_sim = 5          # how many simulations to run
N_step = 5         # how many steps each simulation takes

# Tree level (key) -> list of nodes living on that level (value).
node_levels = {}

In [8]:
class Node:
    """
    Node of the Monte Carlo search tree.

    Attributes:
        state: the mathematical expression (unary or binary) represented by the node
        parent: parent Node, or None for the root
        level: level of the node in the tree (defaults to 0)
        children: list of child Nodes, in insertion order
        visits: number of times ``update`` has been called on this node
        value: sum of all scores passed to ``update``
    """

    def __init__(self, state, parent=None, level=0):
        # `level` must have a default: a parameter without a default may not
        # follow one with a default (SyntaxError), and callers create nodes
        # as Node(state) without supplying a level.
        self.state = state
        self.parent = parent
        self.level = level
        self.children = []
        self.visits = 0
        self.value = 0

    def add_child(self, child):
        """
        Append ``child`` to this node's children.

        The child's ``parent`` attribute is not modified; set it when creating
        the child if the link is needed.
        """
        self.children.append(child)

    def update(self, score):
        """
        Record one visit: add ``score`` to the accumulated value and increment
        the visit count.
        """
        self.value += score
        self.visits += 1

In [25]:
class MCTS:
    """
    Monte Carlo Tree Search over expression trees.

    Levels of the tree alternate between applying unary operations (odd levels
    expand into unary children) and binary operations (even levels expand into
    binary children); see ``expand_node``.
    """

    def __init__(self, initial_state):
        # The root lives on level 0 of the tree.
        self.root = Node(initial_state, level=0)

    ################################################################
    ########################### SELECTION ##########################
    ################################################################

    def select_best_child(self, node, rule="probability", exploration_constant=0.8):
        """
        Select the most promising child of ``node``.

        Input: node, node whose children are compared (must have children)
               rule, "probability" (default): child with the largest share of the
                     summed value of all children;
                     "ucb": child with the largest UCB1 score
               exploration_constant, weight of the UCB1 exploration term
        Return: the selected child node.
        Raises ValueError for an unknown rule, or (from ``max``) when ``node``
        has no children.
        """
        children = node.children

        def ucb_score(child):
            """UCB1 score of ``child``; unvisited children score +inf."""
            # An unvisited child would divide by zero; giving it +inf makes sure
            # every child is tried at least once.
            if child.visits == 0:
                return math.inf
            # max(..., 1) keeps log() defined while the parent is still unvisited.
            exploration_term = exploration_constant * math.sqrt(
                math.log(max(node.visits, 1)) / child.visits
            )
            exploitation_term = child.value / child.visits
            return exploitation_term + exploration_term

        if rule == "ucb":
            return max(children, key=ucb_score)
        if rule != "probability":
            raise ValueError(f"unknown selection rule: {rule!r}")

        children_value_sum = sum(child.value for child in children)

        def probability(child):
            """Value of ``child`` normalized by the summed value of all children."""
            # With a zero sum there is nothing to normalize by (the division would
            # raise ZeroDivisionError), so rank by the raw value instead.
            if children_value_sum == 0:
                return child.value
            return child.value / children_value_sum

        return max(children, key=probability)

    def select_node(self, node, level):
        """
        Descend from ``node`` by repeatedly taking the best child until a leaf
        (a node without children) is reached.

        Input: node, node at which the descent starts
               level, level of ``node`` in the tree
        Return: (leaf node, level of the leaf)
        """
        # Visit counts are deliberately not touched here: backpropagate() calls
        # Node.update(), which already increments `visits`, so incrementing during
        # selection as well would count every iteration twice on the chosen path.
        while node.children:
            node = self.select_best_child(node)
            level += 1
        return node, level

    ################################################################
    ########################### EXPANSION ##########################
    ################################################################

    def expand_node(self, node, level):
        """
        Expand ``node`` by attaching one child per available operation.

        On odd levels one child is created per unary operation. On even levels
        one child is created per (binary operation, node stored in
        ``node_levels[level]``) pair. Children are appended to ``node.children``.

        Input: node, node at which expansion takes place
               level, height of the node in the tree
        Return: list of the newly created child nodes.
        Raises KeyError on an even level that has no entry in ``node_levels``.
        """
        if level % 2 == 1:
            states = [unary[index] for index in range(M)]
        else:
            # One candidate per existing node on this level; presumably that node
            # is meant to become the second operand of the binary operation —
            # TODO confirm (the sibling is not yet stored on the child).
            siblings = node_levels[level]
            states = [binary[index] for index in range(N) for _ in siblings]

        new_children = []
        for state in states:
            child = Node(state, parent=node, level=level + 1)
            node.add_child(child)
            new_children.append(child)
        return new_children

    ################################################################
    ########################## SIMULATION ##########################
    ################################################################

    def simulate_random_play(self, node):
        """
        Simulate a random playout of N_step steps from ``node`` and return its score.

        Not implemented yet: the playout policy and the scoring function are not
        defined. Raises NotImplementedError so the gap is explicit.
        """
        # TODO: take N_step random moves starting from node.state and score the result.
        raise NotImplementedError("simulate_random_play is not implemented yet")

    def simulate(self, node):
        """
        Estimate the value of ``node`` by averaging N_sim random playouts.

        Return: mean score returned by ``simulate_random_play`` over N_sim runs,
                or 0.0 when N_sim is not positive.
        """
        if N_sim <= 0:
            return 0.0
        total = sum(self.simulate_random_play(node) for _ in range(N_sim))
        return total / N_sim

    ################################################################
    ####################### BACKPROPAGATION ########################
    ################################################################

    def backpropagate(self, node, score):
        """
        Add ``score`` to ``node`` and to every ancestor up to the root via
        ``update`` (which also increments each node's visit count).
        """
        while node is not None:
            node.update(score)
            node = node.parent

In [None]:
# Smoke check of the Node API: build a two-node tree and record one score on the root.
x, y = Node(2), Node(1)
x.add_child(y)
x.update(23)