In [29]:
import torch
import numpy as np
from numpy import sin,cos
import matplotlib.pyplot as plt
from scipy.optimize import minimize,root,bisect
import scipy
import os
import sys

In [30]:
# Evaluation grid: an n x n lattice over [1e-9, 1]^2, flattened so that x_0
# varies slowest (each value repeated n times) and x_1 varies fastest (the
# whole axis tiled n times). The lower bound 1e-9 instead of 0 presumably
# keeps divisions by x_0 / x_1 finite.
n = 20

ele = np.linspace(1e-9, 1, n)
x_0 = np.repeat(ele, n)   # same as np.kron(ele, np.ones(n))
x_1 = np.tile(ele, n)     # same as np.kron(np.ones(n), ele)

In [31]:
data_dict = {}

# Token vocabulary: position in this list is the op index used by every tree
# node (0-2 special tokens, 3-6 binary ops, 7-8 sin/cos, 9 constant, 10-11 variables).
data_dict['idx_to_op'] = ['<NULL>', '<START>', '<END>', '+', '-', '*', '/', 'sin', 'cos', 'c', 'x_0', 'x_1']
# Inverse lookup (token -> index), derived from the list so the two cannot drift apart.
data_dict['opt_to_idx'] = {token: idx for idx, token in enumerate(data_dict['idx_to_op'])}

In [32]:
class binary_node:
    """Expression-tree node with two child slots.

    Used both for binary operators and for leaves (constants / variables),
    which simply keep `left` and `right` as None. `constant` starts at 0 and
    is later set to a 1-based constant index for constant leaves.
    """

    def __init__(self, op):
        self.op = op          # op index into data_dict['idx_to_op']
        self.constant = 0     # 0 = not (yet) numbered
        self.left, self.right = None, None

In [33]:
class unary_node:
    """Expression-tree node for a one-argument operator; its operand lives in `child`."""

    def __init__(self, op):
        self.op, self.child = op, None

In [34]:
class Tree:
    """Holder for an expression tree; `root` stays None until a node is attached."""

    def __init__(self):
        self.root = None  # top node of the expression, set by the caller

In [35]:
def _is(op):
    if op ==9:
        return 'constant'
    elif op>=10:
        return 'variable'
    elif op==7 or op==8:
        return 'unary'
    else:
        return 'binary'

In [36]:
def sampling(span = 1):
    """Draw a random op index and build a fresh tree node for it.

    `span` selects the candidate pool (indices into data_dict['idx_to_op']):
        1     -> any operator or leaf: + - * / sin cos c x_0 x_1
        2     -> operators only:       + - * / sin cos
        3     -> binary ops or leaves: + - * / c x_0 x_1   (no sin/cos)
        other -> leaves only:          c x_0 x_1

    Returns:
        (op, node): node is a unary_node when op is sin/cos (so it always has
        a `.child` slot) and a binary_node otherwise. Leaves are binary_nodes
        with no children.
    """
    pools = {
        1: [3, 4, 5, 6, 7, 8, 9, 10, 11],
        2: [3, 4, 5, 6, 7, 8],
        3: [3, 4, 5, 6, 9, 10, 11],
    }
    op = np.random.choice(pools.get(span, [9, 10, 11]), 1)[0]

    # Choose the node class from the op itself for every span: span=2 can
    # draw sin/cos, and a binary_node labelled sin/cos has no `child` slot.
    if _is(op) == 'unary':
        childNode = unary_node(op)
    else:
        childNode = binary_node(op)

    return op, childNode

In [37]:
def GenOperator(OpTree,complexity,root,nodes,unary):
    """Recursively attach randomly sampled children below `root`.

    Args:
        OpTree: the Tree being built; not read by this function.
        complexity: node budget; expansion stops once len(nodes) >= complexity.
        root: node to expand. Constant/variable leaves are never expanded.
        nodes: list of sampled op indices, appended to in place; its length is
            the running node count.
        unary: when True, children are drawn with sampling(span=3), whose pool
            excludes sin/cos; otherwise with sampling() (span=1, any token).

    Returns:
        None. The tree is mutated in place. Operator nodes left childless when
        the budget runs out are turned into leaves afterwards by organizeTree.
    """
    
    #print('Node:{}'.format(root.op))
    # The budget is checked once per call, but a binary node appends two ops
    # before the next check, so len(nodes) can end up one past `complexity`.
    if len(nodes) >= complexity:
        return
    
    if _is(root.op)== 'constant' or _is(root.op)== 'variable':
        return
    elif _is(root.op)=='unary':
        # The operand of sin/cos is drawn from span=3 (no sin/cos).
        op,ChildNode = sampling(span=3)
        # span=3 never yields sin/cos, so this branch never fires.
        if _is(op)=="unary":
            unary=True
        nodes.append(op)
        root.child = ChildNode
        GenOperator(OpTree,complexity,ChildNode,nodes,unary)
    else: 
        # Once a sin/cos has been drawn on this path, restrict to span=3.
        if unary==True:
            op1,LeftNode = sampling(span=3)
        else:
            op1,LeftNode = sampling()
        
        if _is(op1)=="unary":
            unary=True
       
        nodes.append(op1)
        root.left = LeftNode

        
        # NOTE(review): `unary` was rebound above in this same frame, so a
        # sin/cos drawn for the LEFT child also restricts the RIGHT child's
        # draw, and a sin/cos on either child restricts both subtrees in the
        # recursive calls below. Confirm whether a per-subtree restriction
        # was intended.
        if unary==True:
            op2,RightNode = sampling(span=3)
        else:
            op2,RightNode = sampling()
        
        if _is(op2)=="unary":
            unary=True
            
        nodes.append(op2)
        root.right = RightNode

        # Depth-first: the left subtree consumes the budget before the right.
        GenOperator(OpTree,complexity,LeftNode,nodes,unary)
        GenOperator(OpTree,complexity,RightNode,nodes,unary)
        
        
   

In [38]:
def organizeTree(OpTree, verbose=True):
    """Finalize a randomly grown tree, visiting nodes in breadth-first order.

    - A sin/cos node without a child gets a random leaf (c, x_0 or x_1).
    - An operator node without children is relabelled in place as a random leaf.
    - Every constant leaf is numbered 1, 2, ... in BFS order via `.constant`
      (evaluateAgraph reads constants[node.constant - 1]).

    Args:
        OpTree: Tree to finalize; an empty tree (root None) is left untouched.
        verbose: when True (the default, matching the original behaviour),
            print each visited node's op index and token.
    """
    from collections import deque

    if OpTree.root is None:
        return

    queue = deque([OpTree.root])
    con_cnt = 1

    while queue:
        cur_node = queue.popleft()
        if verbose:
            print(cur_node.op, data_dict['idx_to_op'][cur_node.op])

        if _is(cur_node.op) == 'unary':
            # getattr: a node labelled sin/cos may have been built as a
            # binary_node, which has no `child` attribute at all.
            child = getattr(cur_node, 'child', None)
            if child is None:
                _, child = sampling(span=4)
                cur_node.child = child
            queue.append(child)
            continue

        left, right = cur_node.left, cur_node.right
        if left is not None and right is not None:
            queue.append(left)
            queue.append(right)
            continue

        # Leaf position: an operator left here by the node budget becomes a leaf.
        if _is(cur_node.op) not in ('constant', 'variable'):
            cur_node.op, _ = sampling(span=4)

        if _is(cur_node.op) == 'constant':
            cur_node.constant = con_cnt
            con_cnt += 1

In [39]:
def evaluateExpressionTree(root):
    """Evaluate the tree rooted at `root` on the module-level (x_0, x_1) grid.

    Each constant leaf (op 9) draws a fresh value np.random.rand()*10, which
    lies in [0, 10), prints its index and value, and uses it. Variable leaves
    return the x_0 / x_1 arrays. Returns 0 for a missing (None) node and None
    for an op it does not handle.
    """
    if root is None:
        return 0

    op = root.op
    if _is(op) == 'unary':
        operand = evaluateExpressionTree(root.child)
        if op == 7:
            return np.sin(operand)
        if op == 8:
            return np.cos(operand)
        return None

    if root.left is not None or root.right is not None:
        # Left before right: keeps the order of the random constant draws.
        lhs = evaluateExpressionTree(root.left)
        rhs = evaluateExpressionTree(root.right)
    elif op == 9:
        value = np.random.rand()*10
        print(root.constant, value)
        return value
    elif op == 10:
        return x_0
    elif op == 11:
        return x_1

    if op == 3:
        return lhs + rhs
    if op == 4:
        return lhs - rhs
    if op == 5:
        return lhs * rhs
    if op == 6:
        return lhs / rhs
    return None
    
    
    


In [91]:
def local_optimizer(OpTree,num_c,result):
    """Fit the tree's constants to `result` with bounded Nelder-Mead.

    Args:
        OpTree: Tree whose constant leaves are numbered 1..num_c.
        num_c: number of constants to fit.
        result: target values (array on the x_0/x_1 grid, or a scalar).

    Returns:
        ndarray of shape (num_c,) holding the fitted constants, each bounded
        to [0, 10]. The final loss is printed.

    The loss is the RMSE between evaluateAgraph(root, c) and `result`,
    divided by the range of `result`. When that range is zero or non-finite
    (e.g. an expression without variables evaluates to a scalar), the plain
    RMSE is used instead of dividing by zero.
    """
    root = OpTree.root
    constants = np.ones(num_c)*5   # start each constant mid-way through its bounds

    # The normalisation is loop-invariant, so compute it once and guard it.
    scale = np.max(result) - np.min(result)
    if not np.isfinite(scale) or scale == 0:
        scale = 1.0

    def objfn(x):
        return np.sqrt(np.mean((evaluateAgraph(root, x) - result)**2)) / scale

    bnds = [(0, 10)] * num_c
    res = minimize(objfn, constants, bounds=bnds, method='Nelder-Mead', tol=1e-10, callback=new_callback())
    constants = res.x
    print('loss:', objfn(constants))
    return constants

def new_callback():
    """Build a per-iteration callback for scipy.optimize.minimize.

    The returned function only increments a private counter held in this
    closure (the count is never reported) and returns None.
    """
    calls = [1]

    def callback(xk):
        calls[0] += 1

    return callback

def obj_fn(x,root,result):
    """Mean absolute error between the tree evaluated with constants `x` and `result`.

    Uses np.abs rather than np.sqrt(diff**2): the values are the same, but
    squaring overflows to inf for |diff| > ~1e154 and underflows to 0 for
    very small differences.
    """
    diff = evaluateAgraph(root,x) - result
    return np.mean(np.abs(diff))
    

def evaluateAgraph(root, constants):
    """Evaluate the tree rooted at `root` using the given constant values.

    Args:
        root: node to evaluate (None evaluates to 0).
        constants: indexable of constant values; a constant leaf with
            `.constant == k` uses constants[k - 1].

    Returns:
        The value on the module-level (x_0, x_1) grid, or None for an op that
        is not handled.
    """
    if root is None:
        return 0

    op = root.op
    if _is(op) == 'unary':
        operand = evaluateAgraph(root.child, constants)
        if op == 7:
            return np.sin(operand)
        if op == 8:
            return np.cos(operand)
        return None

    if root.left is not None or root.right is not None:
        lhs = evaluateAgraph(root.left, constants)
        rhs = evaluateAgraph(root.right, constants)
    elif op == 9:
        return constants[root.constant - 1]
    elif op == 10:
        return x_0
    elif op == 11:
        return x_1

    if op == 3:
        return lhs + rhs
    if op == 4:
        return lhs - rhs
    if op == 5:
        return lhs * rhs
    if op == 6:
        return lhs / rhs
    return None
    

In [92]:
def ExpressionString(OpTree):
    """Render `OpTree` as a fully parenthesised infix string.

    Constants print as c_<k> with k = node.constant - 1, variables as x_0 /
    x_1, binary ops as '(a<op>b)' and sin/cos as 'sin(a)' / 'cos(a)'.
    An empty tree renders as ''.
    """
    symbols = {3: '+', 4: '-', 5: '*', 6: '/'}
    functions = {7: 'sin', 8: 'cos'}
    variables = {10: 'x_0', 11: 'x_1'}

    def render(node):
        if node is None:
            return ''

        op = node.op
        if _is(op) == 'unary':
            inner = render(node.child)
        elif node.left is not None or node.right is not None:
            lhs = render(node.left)
            rhs = render(node.right)
        elif op == 9:
            return 'c_{}'.format(node.constant - 1)
        elif op in variables:
            return variables[op]

        if op in symbols:
            return '(' + lhs + symbols[op] + rhs + ')'
        if op in functions:
            return '{}({})'.format(functions[op], inner)
        return None

    return render(OpTree.root)
    
    


In [93]:
def bfs(OpTree):
    """Return the nodes of `OpTree` in breadth-first (level) order.

    sin/cos nodes (op 7 or 8) contribute their `child`; all other nodes
    contribute whichever of `left` / `right` are set. Missing children are
    skipped, and an empty tree (root None) yields [].
    """
    from collections import deque

    order = []
    queue = deque([OpTree.root] if OpTree.root is not None else [])
    while queue:
        node = queue.popleft()
        order.append(node)
        if node.op == 7 or node.op == 8:
            # getattr: a sin/cos node built as a binary_node may lack `child`.
            children = (getattr(node, 'child', None),)
        else:
            children = (node.left, node.right)
        queue.extend(c for c in children if c is not None)
    return order
            
            

In [94]:
def GenCaption(ETree):
    """Return the preorder list of op indices of `ETree` (the tree's "caption").

    Each node is emitted before its children; sin/cos nodes descend into
    `child`, all others into `left` then `right`. Missing children are
    skipped, and an empty tree gives [].
    """
    caption = []
    pending = [ETree.root]
    while pending:
        node = pending.pop()
        if not node:
            continue
        caption.append(node.op)
        if _is(node.op) == 'unary':
            pending.append(node.child)
        else:
            # Push right first so the left subtree is visited first.
            pending.append(node.right)
            pending.append(node.left)
    return caption

In [95]:
def main(complexity):
    """Generate one random expression tree and fit its constants.

    Grows a random tree (root drawn with sampling(span=2)) with GenOperator,
    finalizes it with organizeTree, records its preorder caption, evaluates
    it with random constants, prints its string form and, if it contains
    constants, refits them with local_optimizer and prints them.

    Returns:
        The preorder list of op indices.
    """
    # 1. Root node.
    OpTree = Tree()
    op, OpTree.root = sampling(span=2)
    unary = _is(op) == 'unary'
    nodes = [op]

    # 2. Grow the random tree, then close off unfinished nodes.
    GenOperator(OpTree, complexity, OpTree.root, nodes, unary)
    organizeTree(OpTree)
    print()

    # 3. Caption = preorder op sequence.
    preorder = GenCaption(OpTree)

    # 4. Numerical data from random constants, then refit the constants.
    result = evaluateExpressionTree(OpTree.root)
    exp = ExpressionString(OpTree)
    num_c = exp.count('c_')
    print(exp)
    if num_c > 0:
        constants = local_optimizer(OpTree, num_c, result)
        print(constants)
    return preorder

In [106]:
complexity = 15  # node budget handed to GenOperator via main

preorder = main(complexity=complexity)

7 sin
10 x_0

sin(x_0)


In [59]:
# Scratch check of numpy's cos at two angles (the second is roughly 2*pi).
print(*map(cos, (3.495938812876779, 6.2831853)), sep='\n')

-0.9378735571098693
1.0


In [46]:
# Sanity check: recover x = 3*pi/4 by minimising |cos(3*pi/4) - cos(x)| from x = 1.
def obj(x):
    return abs(np.cos(0.75*np.pi) - np.cos(x))

v = [1]
res = minimize(obj, v, method="L-BFGS-B", options={'gtol': 1e-10, 'disp': True})
v = res.x
print(v)
print(obj(v))

[2.35619449]
[1.58314528e-09]


In [47]:
def pre_to_tree(preorder):
    """Rebuild an expression Tree from its preorder list of op indices.

    Inverse of GenCaption for well-formed input. Operator nodes (op < 9) are
    kept on a stack until their operand slots are filled; leaves (op >= 9)
    are not. sin/cos become unary_node (so they always have `.child`) and
    every other op a binary_node.

    Constant leaves are numbered 1, 2, ... in preorder order. This includes
    a lone constant root, which previously kept constant=0 and rendered as
    'c_-1'. NOTE(review): organizeTree numbers constants in BFS order, so
    indices may differ from those on the tree the caption came from.

    Returns:
        A Tree; an empty `preorder` gives a Tree whose root is None.
    """
    OpTree = Tree()
    if len(preorder) == 0:
        return OpTree

    def make_node(op):
        # Mirror sampling(): sin/cos get a `child` slot, everything else left/right.
        return unary_node(op) if _is(op) == 'unary' else binary_node(op)

    con_cnt = 1
    root = make_node(preorder[0])
    OpTree.root = root
    if _is(root.op) == 'constant':
        root.constant = con_cnt
        con_cnt += 1
    if root.op >= 9:
        return OpTree

    stack = [root]
    for op in preorder[1:]:
        node = make_node(op)
        parent = stack[-1]
        kind = _is(parent.op)
        if kind == 'binary':
            if parent.left is None:
                parent.left = node
            else:
                # Right slot filled: this parent is complete.
                parent.right = node
                stack.pop()
        elif kind == 'unary':
            parent.child = node
            stack.pop()

        if op < 9:
            stack.append(node)
        elif _is(op) == 'constant':
            node.constant = con_cnt
            con_cnt += 1
    return OpTree

In [48]:
# Round trip: rebuild a tree from the preorder caption and render it as a string.
TestTree = pre_to_tree(preorder)
exp = ExpressionString(OpTree=TestTree)
print(exp)

(x_1*((((x_1*(c_0+x_0))-(c_1+x_1))/c_2)/cos(x_0)))
