In [3]:
import numpy as np
from sklearn import datasets
from collections import Counter


_DEBUG_MODE = 0


class NodeData:
    def __init__(self, node_type=None):
        self.node_type = node_type
        self.f = None
        self.sp = None
        self.pred_class = None

    def set_data(self, f, sp, pred_class):
        self.f = f
        self.sp = sp
        self.pred_class = pred_class


class Tree:
    def __init__(self):
        self.left = None
        self.right = None
        self.data = None


class Subtree:

    _rnd_attempts = 2
    _max_depth = 5

    def __init__(self, X, y):
        self._X = X
        self._y = y
        self.sequence = ''

    def make_split(self, curr_depth, tree_node, sequence=''):
        self.curr_depth = curr_depth
        self.sequence += sequence
        nd = NodeData()
        tree_node.data = nd
        tree_node.data.node_type = 'node'

        if _DEBUG_MODE > 0:
            print('\n\n === Current sequence: ', self.sequence)

        # init parameters
        self._feature_num = self._X.shape[1]

        # select a feature randomly
        features = np.random.randint(0, self._feature_num - 1, min(self._rnd_attempts, self._feature_num - 1))

        # split data
        ig_best = 0
        fi_best = 0
        spi_best = 0
        ent_children_best = 0

        # calc entropy
        classes = np.unique(self._y)
        classes_num = len(np.unique(self._y))
        mf_class = int(Counter(self._y).most_common(1)[0][0])
        tree_node.data.pred_class = mf_class

        if _DEBUG_MODE > 0:
            print('y', self._y.T)
            print('Most probable class is ', mf_class)

        if classes_num == 1:
            tree_node.data.node_type = 'leaf'
            return tree_node
        # TODO add classes checking - if number of classes is 1 then we should stop splitting right here!

        if _DEBUG_MODE > 0:
            print('Total classes: ', classes_num)

        entropy_parent = 0
        for c in classes:
            if np.sum(self._y == c) > 0:
                p_i = np.sum(self._y == c) / len(self._y)
                entropy_parent -= p_i * np.log2(p_i)

        if _DEBUG_MODE > 0:
            print(self.curr_depth, 'Parent entropy: ', entropy_parent)

        for f in features:
            unique_items = np.unique(self._X[:, f])

            if _DEBUG_MODE > 10:
                print('Unique values: ', unique_items)

            split_points = np.random.randint(0, len(unique_items), self._rnd_attempts)

            if _DEBUG_MODE > 0:
                print('Feature: ', f, 'Split ranges: ', [unique_items[s] for s in split_points])

            for sp in split_points:
                # split, evaluate information gain and make the best choice
                ly = self._y[self._X[:, f] > unique_items[sp]]
                ry = self._y[self._X[:, f] <= unique_items[sp]]

                entropy_children = 0
                for _y in [ly, ry]:
                    for c in classes:
                        if np.sum(_y == c) > 0:
                            p_i = np.sum(_y == c) / len(_y)
                            entropy_children -= p_i * np.log2(p_i) # TODO - add weighting

                entropy_children /= 2
                ig = entropy_parent - entropy_children  # parent entropy should also be higher that children

                if ig > ig_best:
                    ig_best = ig
                    fi_best = f
                    spi_best = sp
                    enth_children_best = entropy_children

            if _DEBUG_MODE > 20:
                print('Best IG, feature, range for this split: ', ig_best, fi_best, spi_best)

        if _DEBUG_MODE > 20:
            print('Best IG: ', ig_best)
            print('Best FI: ', fi_best)
            print('Best SPI: ', spi_best)

        if ig_best < 0:
            print('\n !!! Warning! Something go wrong. IG less then zero! !!! \n')

        # build tree recursively
        if self.curr_depth < self._max_depth:

            if _DEBUG_MODE > 0:
                print('Starting build subtrees. Current depth: ', curr_depth)

            lX = self._X[self._X[:, fi_best] > unique_items[spi_best]]
            ly = self._y[self._X[:, fi_best] > unique_items[spi_best]]
            rX = self._X[self._X[:, fi_best] <= unique_items[spi_best]]
            ry = self._y[self._X[:, fi_best] <= unique_items[spi_best]]

            # if this split doesn't bring additional information, finish splitting
            if min(len(ry), len(ly)) == 0:
                print('Empty split')
                print('Parent entropy ', entropy_parent, 'children entropy ', enth_children_best)

                return None

            self.left_subtree = Subtree(lX, ly)
            self.right_subtree = Subtree(rX, ry)

            if _DEBUG_MODE > 0:
                    print('Splitting', len(ly), len(ry))
                    print('Parent entropy ', entropy_parent, 'children entropy ', enth_children_best)
                    print('Left subtree', 'X', fi_best, '>', unique_items[spi_best], '\n', lX.T, '\n', ly.T)
                    print('Right subtree', 'X', fi_best, '<=', unique_items[spi_best], '\n', rX.T, '\n', ry.T)

            lst = Tree()
            rst = Tree()

            # tree_node.data = (fi_best, unique_items[spi_best], None)
            tree_node.data.set_data(fi_best, unique_items[spi_best], mf_class)

            lst = self.left_subtree.make_split(self.curr_depth + 1, lst, self.sequence + ' left ')
            rst = self.right_subtree.make_split(self.curr_depth + 1, rst, self.sequence + ' right ')

            # add subnodes to node
            tree_node.left = lst
            tree_node.right = rst
        else:
            pass
            # to add data filling in by most probable class name

        return tree_node


class YADecisionTree:
    """
        My own decision tree classifier
    """

    _rnd_attempts = 2

    def __init__(self, max_depth=5):
        self._max_depth = max_depth

    def fit(self, X, y):
        self._X = X.copy()
        self._y = y.copy()
        self._feature_num = len(np.unique(self._y))

        self._root = Subtree(self._X, self._y)
        self.tree_node = Tree()
        self.tree_node = self._root.make_split(0, self.tree_node, 'root')

    def predict(self, X):
        n = self.tree_node
        y = None

        while y == None:
            if n.data.node_type == 'leaf':
                y = n.data.pred_class
            else: # it is a decision node
                if X[n.data.f] > n.data.sp:
                    n = n.left
                else:
                    n = n.right
        return y

    def __print_node(self, n):
        if n.data.node_type != 'leaf':
            print(n.data.pred_class)

    def print_tree(self):
        stack = []
        if self.tree_node is not None:
            print('\n\n===== Tree =====')

            stack.append(('', self.tree_node))
            while len(stack) > 0:

                n = stack.pop()
                lrdir = n[0]
                node = n[1]
                level = len(stack) + 1

                if node.data.node_type == 'leaf':
                    # do smth
                    print('-'*level*3, 'leaf at', level, ', prediction class:', node.data.pred_class)
                    pass
                else: # it is a node
                    if lrdir == '':
                        print('-'*level*3, 'Go down left from', level, 'X' + str(node.data.f), '>', node.data.sp)
                        stack.append(('l', node))
                        stack.append(('', node.left))
                    elif lrdir == 'l':
                        print('-'*level*3, 'Go down right from', level, level, 'X' + str(node.data.f), '<=', node.data.sp)
                        stack.append(('r', node))
                        stack.append(('', node.right))
                    else: # 'r'
                        print('-'*level*3, 'Go up from', level)
                        pass # go up





In [6]:


# testing
sample_ds = datasets.make_classification(n_samples=1500,
                                         n_features=5,
                                         n_informative=2,
                                         n_classes=4,
                                         n_redundant=0,
                                         n_clusters_per_class=1,
                                         random_state=6)

ins = YADecisionTree()

Xtest = np.arange(0, 300)
Xtest = Xtest.reshape(3, 100).T
ytest = np.concatenate([np.ones((25, 1)), np.ones((25, 1)) * 2, np.ones((25, 1)) * 3, np.ones((25, 1)) * 4])


Xtest = np.arange(0, 60)
Xtest = Xtest.reshape(3, 20).T
ytest = np.concatenate([np.ones(5), np.ones(5) * 2, np.ones(5) * 3, np.ones(5) * 4]) #np.concatenate([np.ones((5, 1)), np.ones((5, 1)) * 2, np.ones((5, 1)) * 3, np.ones((5, 1)) * 4])


# X, y = sample_ds

print(Xtest.T, '\n', ytest.T)

np.random.seed(18)  # 11
Xxx, yyy = Xtest, ytest
ins.fit(Xxx, yyy)
ins.print_tree()

print('\n\n')

for i in range(len(yyy)):
    print(Xxx[i], ins.predict(Xxx[i]), yyy[i])

print('\n', 'New test examples')
print(Xxx[0]/2, ins.predict(Xxx[0]/2))
print(Xxx[-1]*2, ins.predict(Xxx[-1]*2))
print(Xxx[-1]/1.5, ins.predict(Xxx[-1]/1.5))

[[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
 [20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39]
 [40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59]] 
 [ 1.  1.  1.  1.  1.  2.  2.  2.  2.  2.  3.  3.  3.  3.  3.  4.  4.  4.
  4.  4.]


===== Tree =====
--- Go down left from 1 X1 > 34
------ leaf at 2 , prediction class: 4
--- Go down right from 1 1 X1 <= 34
------ Go down left from 2 X0 > 1
--------- Go down left from 3 X0 > 9
------------ leaf at 4 , prediction class: 3
--------- Go down right from 3 3 X0 <= 9
------------ Go down left from 4 X1 > 23
--------------- Go down left from 5 X1 > 24
------------------ leaf at 6 , prediction class: 2
--------------- Go down right from 5 5 X1 <= 24
------------------ leaf at 6 , prediction class: 1
--------------- Go up from 5
------------ Go down right from 4 4 X1 <= 23
--------------- leaf at 5 , prediction class: 1
------------ Go up from 4
--------- Go up from 3
------ Go down right from 2 2 X0 <= 