In [1]:
import numpy as np

In [2]:
class TreeNode2:

    def __init__(self, name, dims=3):
        self.name = name
        self.center = np.random.uniform(-1,1,size=(dims))
        self.radius =  1
        self.dims = dims
        self.left = None  #TreeNode
        self.right = None  #TreeNode
        self.label = None
    
    def run(self,inputs):
        self.indexes=np.arange(len(inputs))
        outs = self._run(inputs,self.indexes)
        return [x for _,x in sorted(outs)]

    def _run(self, inputs, indexes):
        self.indexes = indexes
        if self.right is None or self.left is None:
              return [(i,self.label) for i in indexes]
        else:
              norm = np.linalg.norm(inputs-self.center,axis=1)
              l_indexes = np.argwhere(norm < self.radius).flatten()
              r_indexes = list(set(range(len(inputs))) - set(l_indexes))
              r_ret = self.right._run(inputs[r_indexes,:],indexes[r_indexes])
              l_ret = self.left._run(inputs[l_indexes,:],indexes[l_indexes])
              return l_ret + r_ret
    
    def clone(self) :
        tree = TreeNode2(self.name)
        if self.right is not None : tree.right = self.right.clone()
        if self.left is not None : tree.left = self.left.clone()
        return tree
    
    def cut(self):
        self.right=None
        self.left=None
        
    def grow(self,depth):
        if depth>0:
          self.right = TreeNode2(str(self.name)+"R")
          self.right.grow(depth-1)
          self.left = TreeNode2(str(self.name)+"L")
          self.left.grow(depth-1)
    
    def show(self, space=0, tab_size=6):
        space += tab_size
        if self.right is not None: self.right.show(space, tab_size)
        s = str(self.name) + "(" + str(self.label) + ")"
        print(" "*(space-tab_size), s)
        if self.left is not None: self.left.show(space, tab_size)  
    
    def to_list(self):
        if self.right is None or self.left is None:
            return [self]
        else:
            return [self] + self.left.to_list() + self.right.to_list()
    
    def depth(self):
        if self.right is None or self.left is None:
            return 1
        else:
            return 1 + max(self.right.depth(),self.left.depth())
    

In [3]:
# Testing Class methods

In [4]:
tree1 = TreeNode2("-") #init
tree1.grow(2) #grow
tree1.show() #show

             -RR(None)
       -R(None)
             -RL(None)
 -(None)
             -LR(None)
       -L(None)
             -LL(None)


In [5]:
branch = tree1.right.clone() #clone
branch.show()

       -RR(None)
 -R(None)
       -RL(None)


In [6]:
tree1.right.cut() #cut
tree1.show()

       -R(None)
 -(None)
             -LR(None)
       -L(None)
             -LL(None)


In [7]:
nodes = tree1.to_list() #to_list
print(nodes)

[<__main__.TreeNode2 object at 0x000001E3F9B48780>, <__main__.TreeNode2 object at 0x000001E3F9B48908>, <__main__.TreeNode2 object at 0x000001E3F9B480B8>, <__main__.TreeNode2 object at 0x000001E3F9B48278>, <__main__.TreeNode2 object at 0x000001E3F9B48080>]


In [8]:
inputs = np.random.uniform(-1,1,size=(20,3))
out = tree1.run(inputs) #run
print (out,'\n')

for node in tree1.to_list():
    print(node.name, node.indexes)

[None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None] 

- [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
-L [ 1  6 11 13 16]
-LL [1 6]
-LR [11 13 16]
-R [ 0  2  3  4  5  7  8  9 10 12 14 15 17 18 19]


In [9]:
#outside methods

In [10]:
def crossover(tree1,tree2):
    new_tree = tree1.clone()
    rnd_node_tree1 = np.random.choice(new_tree.to_list(),1)[0]
    subs_tree2 = np.random.choice(tree2.to_list(),2)
    rnd_node_tree1.right = subs_tree2[0].clone()
    rnd_node_tree1.left = subs_tree2[1].clone()    
    return new_tree     

In [11]:
tree1 = TreeNode2("1")
tree1.grow(2)
tree2 = TreeNode2("2")
tree2.grow(2)

tree3 = crossover(tree1,tree2)
tree3.show()

                         2RR(None)
                   2R(None)
                         2RL(None)
             1RR(None)
                   2LL(None)
       1R(None)
             1RL(None)
 1(None)
             1LR(None)
       1L(None)
             1LL(None)


In [12]:
def mutate(tree, m_rate=0.1):
    new_tree = tree.clone()
    for node in new_tree.to_list():
        if np.random.rand(1) < m_rate :
            node.center = np.random.uniform(-1,1,size=(node.dims))
            node.radius =  np.random.rand(1)
    return new_tree    

In [13]:
outs = tree3.run(inputs) #run

for node in tree3.to_list():
    print(node.name, node.indexes)
    
tree4 = mutate(tree3)

outs = tree4.run(inputs) #run

for node in tree4.to_list():
    print(node.name, node.indexes)

1 [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
1L [4]
1LL []
1LR [4]
1R [ 0  1  2  3  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
1RL [ 0  9 10 13 16 17 18]
1RR [ 1  2  3  5  6  7  8 11 12 14 15 19]
2LL [ 1  6 11]
2R [ 2  3  5  7  8 12 14 15 19]
2RL [15]
2RR [ 2  3  5  7  8 12 14 19]
1 [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
1L [ 0  2 11 12 13 15 16 19]
1LL []
1LR [ 0  2 11 12 13 15 16 19]
1R [ 1  3  4  5  6  7  8  9 10 14 17 18]
1RL [ 3  5  8  9 14 18]
1RR [ 1  4  6  7 10 17]
2LL []
2R [ 1  4  6  7 10 17]
2RL [ 1  6 10]
2RR [ 4  7 17]


In [14]:
def labelize(tree, inputs, targets):
    tree.run(inputs)
    for node in tree.to_list():
        if len(node.indexes):
            values, counts = np.unique(targets[node.indexes],return_counts=True)
            node.label = values[np.argmax(counts)]

In [15]:
inputs = np.random.uniform(-1,1,size=(20,3))
targets = np.random.randint(4,size=(20))

tree3.run(inputs)
labelize(tree3,inputs,targets)
tree3.show()

for node in tree3.to_list():
    print(node.name,node.label, node.indexes)

                         2RR(0)
                   2R(0)
                         2RL(None)
             1RR(3)
                   2LL(3)
       1R(0)
             1RL(0)
 1(0)
             1LR(1)
       1L(1)
             1LL(2)
1 0 [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
1L 1 [ 2  9 10 11 13]
1LL 2 [ 9 10]
1LR 1 [ 2 11 13]
1R 0 [ 0  1  3  4  5  6  7  8 12 14 15 16 17 18 19]
1RL 0 [ 1 19]
1RR 3 [ 0  3  4  5  6  7  8 12 14 15 16 17 18]
2LL 3 [ 3  5  6 12 16]
2R 0 [ 0  4  7  8 14 15 17 18]
2RL None []
2RR 0 [ 0  4  7  8 14 15 17 18]


In [16]:
print("score=",sum(tree3.run(inputs)==targets)/len(targets))

score= 0.5
