<a href="https://colab.research.google.com/github/AlexandreFleutelot/EWT_ESN/blob/main/TreeNode_RBF.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np

In [None]:
from collections import deque
def print_tree(root, name=False):
    """Print a binary tree level by level as rows of dot-padded text.

    Each node is shown by ``node.name`` (when ``name`` is true) or by
    ``node.label``; a missing child is shown as ".". Cells are laid out so
    that every parent sits midway between its two children, and every row
    ends with "|". All values are padded with "." to the width of the
    longest one, so multi-character names stay aligned.

    Args:
        root: root node (any object with ``left``/``right`` and the displayed
            attribute), or None.
        name: show ``node.name`` instead of ``node.label``.
    """
    levels = []
    queue = deque([root])
    while queue:
        level = []
        for _ in range(len(queue)):
            node = queue.popleft()
            if node is None:
                level.append(None)  # placeholder for a missing child
                continue
            level.append(str(node.name if name else node.label))
            queue.append(node.left)
            queue.append(node.right)
        levels.append(level)
    # The final BFS level only holds the (missing) children of the deepest
    # nodes, i.e. nothing but placeholders; it carries no information.
    if len(levels) > 1:
        levels.pop()

    width = max([len(v) for level in levels for v in level if v is not None] + [1])
    base = 2 ** (len(levels) - 1)
    for level in levels:
        cells = []
        for v in level:
            text = "." * width if v is None else v.ljust(width, ".")
            # Scale the original 1-char layout (base dots, value, base-1 dots)
            # by the cell width so the tree shape is preserved.
            cells.append("." * (base * width) + text + "." * ((base - 1) * width))
        print("".join(cells) + "|")
        base //= 2

In [None]:
class TreeNode2:
    """Binary tree node that routes points with a spherical distance test.

    In :meth:`eval`, points whose Euclidean distance to ``center`` is strictly
    less than ``radius`` go to the left child; all other points go to the
    right child. A node missing either child is treated as a leaf.

    Attributes:
        name: identifier used for display (e.g. ``print_tree(..., name=True)``).
        center: random centre, drawn uniformly from [-1, 1) per dimension.
        radius: routing radius (default 1).
        dims: length of ``center``; inputs are expected to have this many columns.
        left, right: child nodes, or None.
        label: value assigned by external code; None until then.
        indexes: original row indices that reached this node during the last
            :meth:`eval`; None before any evaluation.
    """

    def __init__(self, name, dims=3, radius=1):
        self.name = name
        self.center = np.random.uniform(-1, 1, size=(dims))
        self.radius = radius
        self.dims = dims
        self.left = None  # TreeNode2
        self.right = None  # TreeNode2
        self.label = None
        self.indexes = None

    def __repr__(self):
        return f"{type(self).__name__}(name={self.name!r}, label={self.label!r})"

    def eval(self, inputs, indexes=None):
        """Route ``inputs`` down the tree, recording the rows reaching each node.

        Args:
            inputs: 2-D array with one row per point (``dims`` columns).
            indexes: original row numbers of ``inputs``; defaults to
                ``arange(len(inputs))`` at the root.

        Returns:
            List of visited nodes in pre-order: this node, then the left
            subtree, then the right subtree. Each visited node's ``indexes``
            attribute is overwritten.
        """
        if indexes is None:
            indexes = np.arange(len(inputs))
        self.indexes = indexes

        if self.right is None or self.left is None:
            return [self]

        dist = np.linalg.norm(inputs - self.center, axis=1)
        inside = dist < self.radius
        # Boolean-mask partition: O(n) and keeps rows in their original order.
        l_idx = np.flatnonzero(inside)
        r_idx = np.flatnonzero(~inside)
        r_ret = self.right.eval(inputs[r_idx, :], indexes[r_idx])
        l_ret = self.left.eval(inputs[l_idx, :], indexes[l_idx])
        return [self] + l_ret + r_ret

In [None]:
# Hand-built tree: root 1 with children 2 (left) and 5 (right); node 2 has
# children 3 (left) and 4 (right). Names are ints here. Every TreeNode2(...)
# call draws a random center from np.random, so the construction order
# matters for reproducibility.
tree2 = TreeNode2(1)
tree2.left = TreeNode2(2)
tree2.left.left = TreeNode2(3)
tree2.left.right = TreeNode2(4)
tree2.right = TreeNode2(5)

In [None]:
# 20 random 3-D points in [-1, 1); route them through tree2 and show which
# original row indices reached each visited node (pre-order: node, left
# subtree, right subtree).
inputs = np.random.uniform(-1,1,size=(20,3))
nodes = tree2.eval(inputs)

for node in nodes:
    print(node.name, node.indexes)

1 [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
2 [ 1  4  5  9 12 13 15 18]
3 []
4 [ 1  4  5  9 12 13 15 18]
5 [ 0  2  3  6  7  8 10 11 14 16 17 19]


In [None]:
def labelize(nodes, targets):
    """Assign a ``label`` to every node in ``nodes``, in place.

    - Internal nodes (both children present) get "*".
    - Leaves that received samples get the most frequent target among them;
      ties resolve to the smallest tied value (``np.unique`` sorts its
      output and ``argmax`` returns the first maximum).
    - Leaves that received no samples get "#".

    A node counts as a leaf when either child is missing. This is the same
    rule ``TreeNode2.eval`` uses to stop routing, so every node that
    receives samples without forwarding them is labelled.

    Args:
        nodes: iterable of nodes carrying ``left``, ``right`` and ``indexes``
            (e.g. the list returned by ``TreeNode2.eval``).
        targets: array of targets indexable by ``node.indexes``.
    """
    for node in nodes:
        if node.left is not None and node.right is not None:
            node.label = "*"  # internal node: routes samples, predicts nothing
        elif len(node.indexes):
            values, counts = np.unique(targets[node.indexes], return_counts=True)
            node.label = values[np.argmax(counts)]
        else:
            node.label = "#"  # leaf that no sample reached

In [None]:
# One random class target in {0, 1, 2} per input row; label the nodes
# (majority vote at leaves) and draw the labelled tree.
targets = np.random.randint(0,3,(20))

labelize(nodes,targets)
print_tree(tree2)

........*.......|
....*.......0...|
..#...2.........|
........|


In [None]:
def score_MSE(nodes, targets):
    """Return the fraction of samples whose target equals their leaf's label.

    Despite the name, this is a classification accuracy, not a mean squared
    error. A node counts as a leaf when either child is missing. This matches
    ``TreeNode2.eval``, which stops routing at such nodes, so every sample is
    counted at the node where it actually ended up.

    Args:
        nodes: iterable of nodes carrying ``left``, ``right``, ``indexes`` and
            ``label`` (e.g. the list returned by ``TreeNode2.eval`` after
            labelling).
        targets: array of targets indexable by ``node.indexes``.

    Returns:
        Number of correctly labelled samples divided by ``len(targets)``.

    Raises:
        ZeroDivisionError: if ``targets`` is empty.
    """
    tot = 0
    for node in nodes:
        is_leaf = node.left is None or node.right is None
        if is_leaf and len(node.indexes):
            tot += np.count_nonzero(targets[node.indexes] == node.label)
    return tot / len(targets)

In [None]:
# Fraction of samples whose target equals the label of the leaf they reached
# (an accuracy, despite the function's MSE name).
print(score_MSE(nodes,targets))

0.5


In [None]:
def generate_tree(parent, depth):
    """Grow a complete binary subtree ``depth`` levels deep below ``parent``, in place.

    Children are named by appending "r"/"l" to the parent's name, so the
    root "0" yields "0r", "0l", "0rr", "0rl", .... Each child is created with
    the parent's class and ``dims``, so the whole tree can evaluate the same
    inputs. The right subtree is built before the left one. Because
    TreeNode2's constructor draws a random center, this order fixes which
    random draws each node receives.

    Args:
        parent: node to grow from; its ``left``/``right`` are overwritten.
        depth: number of levels to add; ``depth <= 0`` adds nothing.
    """
    if depth <= 0:
        return
    node_cls = type(parent)
    parent.right = node_cls(f"{parent.name}r", dims=parent.dims)
    generate_tree(parent.right, depth - 1)
    parent.left = node_cls(f"{parent.name}l", dims=parent.dims)
    generate_tree(parent.left, depth - 1)


In [None]:
# Complete tree three levels below the root (8 leaves) with string names
# "0", "0r", "0l", ..., evaluated and labelled on the same data as above.
tree1 = TreeNode2("0")
generate_tree(tree1,3)
nodes1=tree1.eval(inputs)
labelize(nodes1,targets)
print_tree(tree1)

................*...............|
........*...............*.......|
....*.......*.......*.......*...|
..#...#...2...0...2...0...1...2.|
................................|


In [None]:
# Cutting-branch test: detach tree1's right subtree and show both parts.
from copy import deepcopy

print("full tree:")
tree1 = TreeNode2("0")
generate_tree(tree1,3)
nodes1=tree1.eval(inputs)
labelize(nodes1,targets)
print_tree(tree1)

print("\nremaining:")
# Keep an independent copy of the right subtree, then prune it from tree1.
# NOTE(review): tree1.right keeps the "*" label it got from labelize even
# though it is now a leaf; re-run eval/labelize before scoring the pruned tree.
branch = deepcopy(tree1.right)
tree1.right.left = None
tree1.right.right = None
print_tree(tree1)

print("\nbranch:")
print_tree(branch)

full tree:
................*...............|
........*...............*.......|
....*.......*.......*.......*...|
..1...1...#...1...#...2...2...2.|
................................|

remaining:
................*...............|
........*...............*.......|
....*.......*...................|
..1...1...#...1.|
................|

branch:
........*.......|
....*.......*...|
..#...2...2...2.|
................|


In [None]:
# crossover


def _preorder_nodes(root):
    """Return every node under ``root`` in pre-order, left child before right."""
    nodes, stack = [], [root]
    while stack:
        node = stack.pop()
        if node is None:
            continue
        nodes.append(node)
        # Push right first so the left child is visited first.
        stack.append(node.right)
        stack.append(node.left)
    return nodes


def crossover(tree1, tree2):
    """Swap one randomly chosen child subtree between two trees, in place.

    A node is picked uniformly at random from each tree. Then one child slot
    of each picked node is exchanged: with probability 1/2 the same side
    (left<->left or right<->right), otherwise opposite sides. A child slot
    may hold None, so the swap can leave a node with a single child, which
    ``TreeNode2.eval`` treats as a leaf.

    Nodes are gathered with a plain traversal (the same pre-order, left-first
    order ``TreeNode2.eval`` returns). Per-node state such as ``indexes`` is
    therefore left untouched.

    Args:
        tree1, tree2: roots of two distinct trees. Passing the same tree, or
            trees that share nodes, can create cycles.
    """
    nodes1 = _preorder_nodes(tree1)
    nodes2 = _preorder_nodes(tree2)

    branch1 = np.random.choice(nodes1, 1)[0]
    branch2 = np.random.choice(nodes2, 1)[0]

    same_side = np.random.randint(2)
    pick = np.random.randint(2)
    side2 = "left" if pick else "right"
    if same_side:
        side1 = side2
    else:
        side1 = "right" if pick else "left"

    sub1 = getattr(branch1, side1)
    sub2 = getattr(branch2, side2)
    setattr(branch1, side1, sub2)
    setattr(branch2, side2, sub1)
   

In [None]:
# Crossover demo: two fresh trees three levels deep (names prefixed "0" and
# "1"), printed by node name before and after swapping a subtree between them.
tree1 = TreeNode2("0")
generate_tree(tree1,3)

tree2 = TreeNode2("1")
generate_tree(tree2,3)

print_tree(tree1, name=True)
print_tree(tree2, name=True)

crossover(tree1,tree2)
print_tree(tree1, name=True)
print_tree(tree2, name=True)

................0...............|
........0l...............0r.......|
....0ll.......0lr.......0rl.......0rr...|
..0lll...0llr...0lrl...0lrr...0rll...0rlr...0rrl...0rrr.|
................................|
................1...............|
........1l...............1r.......|
....1ll.......1lr.......1rl.......1rr...|
..1lll...1llr...1lrl...1lrr...1rll...1rlr...1rrl...1rrr.|
................................|
<__main__.TreeNode2 object at 0x7f156d508910>
................0...............|
........0l...............0r.......|
....0ll.......0lr.......1ll.......0rr...|
..0lll...0llr...0lrl...0lrr...1lll...1llr...0rrl...0rrr.|
................................|
................1...............|
........1l...............1r.......|
....0rl.......1lr.......1rl.......1rr...|
..0rll...0rlr...1lrl...1lrr...1rll...1rlr...1rrl...1rrr.|
................................|
