In [2]:
import numpy as np

In [3]:
X = np.random.random(size=(80, 100)) > 0.3

In [4]:
XX = [np.nonzero(row)[0] for row in X]
XX

[array([ 1,  2,  4,  5,  7,  8,  9, 10, 12, 13, 16, 17, 18, 19, 20, 21, 22,
        26, 27, 28, 31, 32, 33, 36, 37, 38, 39, 40, 41, 42, 43, 47, 49, 50,
        51, 52, 53, 56, 57, 58, 59, 60, 61, 63, 65, 66, 68, 69, 70, 71, 72,
        73, 74, 76, 77, 78, 81, 82, 84, 85, 86, 88, 89, 90, 91, 92, 93, 94,
        95, 96, 97, 98, 99]),
 array([ 0,  1,  2,  4,  5,  6,  8,  9, 11, 12, 13, 14, 16, 17, 18, 19, 20,
        23, 24, 25, 26, 27, 28, 29, 30, 33, 34, 35, 36, 38, 39, 40, 41, 43,
        45, 46, 47, 48, 49, 52, 53, 55, 56, 57, 58, 59, 60, 61, 63, 65, 66,
        68, 69, 70, 71, 73, 74, 76, 79, 80, 82, 83, 84, 85, 87, 89, 90, 92,
        94, 96, 97, 98, 99]),
 array([ 0,  2,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 17, 19, 20, 21,
        23, 24, 25, 27, 29, 30, 31, 33, 35, 36, 37, 38, 39, 41, 43, 44, 45,
        46, 47, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
        65, 67, 68, 69, 72, 73, 75, 76, 78, 79, 80, 81, 82, 83, 84, 85, 86,
        89, 90, 91, 93, 94, 

In [5]:
#class select_random_path(object):    
#    def __call__(self):
#        while True:
#            self.cnt += 1
#            yield XX[np.random.randint(low=0, high=len(XX))]

def select_random_path():    
    while True:
        yield XX[np.random.randint(low=0, high=len(XX))]

In [6]:
class RITNode(object):
    def __init__(self, val):
        self._val = val
        self._children = []

    def is_leaf(self):
        return len(self._children) == 0

    @property
    def children(self):
        return self._children
    
    def add_child(self, val):
        val_intersect = np.intersect1d(self._val, val)
        self._children.append(RITNode(val_intersect))
        
    def is_empty(self):
        return len(self._val) == 0
        
    def is_leaf(self):
        return len(self._children) == 0
    
    @property
    def nr_children(self):
        return len(self._children) + sum(child.nr_children for child in self._children)
    
    def _traverse_depth_first(self, _idx):
        yield _idx[0], self
        for child in self.children:
            _idx[0] += 1
            yield from RITNode._traverse_depth_first(child, _idx=_idx)    

class RITTree(RITNode):
    def __len__(self):
        return self.nr_children + 1
    
    def traverse_depth_first(self):
        yield from RITNode._traverse_depth_first(self, _idx=[0])

In [7]:
from functools import partial

def build_tree(feature_paths, max_depth=3, num_splits=5, noisy_split=False, _parent=None, _depth=0):
    """
    Parameters
    ----------
    feature_paths : generator of list of ints
        ...
    max_depth : int
        The built tree will never be deeper than `max_depth`.
    num_splits : int
        At each node, the maximum number of children to be added.
    """
    expand_tree = partial(build_tree, feature_paths, max_depth=max_depth,
                          num_splits=num_splits, noisy_split=noisy_split)
    
    if _parent is None:
        tree = RITTree(next(feature_paths))
        expand_tree(_parent=tree, _depth=0)
        return tree
    else:
        _depth += 1
        if _depth >= max_depth:
            return
        if noisy_split:
            num_splits += np.random.randint(low=0, high=2)
        for i in range(num_splits):
            _parent.add_child(next(feature_paths))
            added_node = _parent.children[-1]
            if not added_node.is_empty():
                expand_tree(_parent=added_node, _depth=_depth)

In [8]:
np.random.seed(12)
tree = build_tree(feature_paths=select_random_path(), max_depth=3, noisy_split=False, num_splits=5)
#path_gen = select_random_path()
#tree = build_tree(feature_paths=path_gen(), max_depth=3, num_splits=5)

In [9]:
#%timeit build_tree(feature_paths=select_random_path())

In [10]:
print("Root:\n", tree._val)
print("Some child:\n", tree.children[0].children[1]._val)

Root:
 [ 1  3  4  5  6  7  8  9 10 11 12 14 15 16 17 18 19 20 21 22 23 24 25 26 27
 29 30 31 32 35 36 37 39 40 41 42 43 44 46 47 48 49 50 51 52 53 54 56 57 60
 61 64 67 68 69 70 71 72 73 74 75 77 78 80 81 82 83 84 85 86 87 88 89 90 91
 96 98 99]
Some child:
 [ 5  8  9 10 12 14 15 17 19 20 21 23 27 29 30 43 44 46 47 49 50 51 52 56 57
 60 61 67 68 69 72 75 80 82 83 84 86 89 90 91]


In [11]:
# If noisy split is False, this should pass
assert(len(tree) == 1 + 5 + 5**2)

In [15]:
# If noisy split is True, this should pass
print(len(tree))
assert(len(tree) <= 1 + 6 + 6**2)

31


In [16]:
list(tree.traverse_depth_first())

[(0, <__main__.RITTree at 0x10b3c2898>),
 (1, <__main__.RITNode at 0x10b3c2e80>),
 (2, <__main__.RITNode at 0x10b3f58d0>),
 (3, <__main__.RITNode at 0x10b3c2d30>),
 (4, <__main__.RITNode at 0x10b3f59b0>),
 (5, <__main__.RITNode at 0x10b3c2588>),
 (6, <__main__.RITNode at 0x10b3c2908>),
 (7, <__main__.RITNode at 0x10b3c2550>),
 (8, <__main__.RITNode at 0x10b3c2518>),
 (9, <__main__.RITNode at 0x10b3c2a20>),
 (10, <__main__.RITNode at 0x10b3c2470>),
 (11, <__main__.RITNode at 0x10b3c2ef0>),
 (12, <__main__.RITNode at 0x10b3c2390>),
 (13, <__main__.RITNode at 0x10b3c28d0>),
 (14, <__main__.RITNode at 0x10b3c2b00>),
 (15, <__main__.RITNode at 0x10b3c2cc0>),
 (16, <__main__.RITNode at 0x10b3c25f8>),
 (17, <__main__.RITNode at 0x10b3c2b70>),
 (18, <__main__.RITNode at 0x10b3c2b38>),
 (19, <__main__.RITNode at 0x10b3c2ac8>),
 (20, <__main__.RITNode at 0x10b3c2e10>),
 (21, <__main__.RITNode at 0x10b3c2dd8>),
 (22, <__main__.RITNode at 0x10b404048>),
 (23, <__main__.RITNode at 0x10b404080>),
 (