In [5]:
import numpy as np
from sklearn.datasets import load_diabetes



class Tree:
    """One node of a binary regression tree grown by MSE (variance) reduction.

    A node keeps the samples that reached it (``x``, ``y``), links to its
    children and parent, and - once ``search_best_split`` has run - either a
    chosen ``[feature_index, threshold]`` pair or a leaf flag plus a short
    textual summary in ``res``.

    Parameters
    ----------
    x : numpy array indexed as ``x[:, feature]`` (one row per sample).
    y : numpy array of targets, aligned with the rows of ``x``.
    left, right : Tree or None
        Pre-existing child nodes, if any.
    parent : Tree or None
        The node this one was split from.
    min_samples_split : int
        A node holding this many samples *or fewer* is never split.
    min_samples_leaf : int
        Candidate thresholds that would leave fewer than this many samples in
        either child are rejected.
    """

    def __init__(self, x, y, left=None, right=None, parent=None, min_samples_split=2, min_samples_leaf=1):
        self.__x = x
        self.__y = y
        self.__right = right
        self.__left = left
        self.__parent = parent
        self.__min_samples_split = min_samples_split
        self.__min_samples_leaf = min_samples_leaf
        self.__is_leaf = 0  # 0 = internal / undecided, 1 = leaf
        self.__best_split_feature = None  # [feature_index, threshold] once chosen
        self.__res = None  # human-readable summary of the node

    @property
    def x(self):
        """Feature rows held by this node."""
        return self.__x

    @x.setter
    def x(self, x):
        self.__x = x

    @property
    def y(self):
        """Target values held by this node."""
        return self.__y

    @y.setter
    def y(self, y):
        self.__y = y

    @property
    def left(self):
        """Child receiving samples with ``feature <= threshold``."""
        return self.__left

    @left.setter
    def left(self, left):
        self.__left = left

    @property
    def right(self):
        """Child receiving samples with ``feature > threshold``."""
        return self.__right

    @right.setter
    def right(self, right):
        self.__right = right

    @property
    def parent(self):
        """Node this one was split from (``None`` for a root)."""
        return self.__parent

    @parent.setter
    def parent(self, parent):
        self.__parent = parent

    @property
    def res(self):
        """List of strings summarising the node (split rule or leaf stats)."""
        return self.__res

    @res.setter
    def res(self, res):
        self.__res = res

    @property
    def best_split_feature(self):
        """``[feature_index, threshold]`` chosen for this node, or ``None``."""
        return self.__best_split_feature

    @best_split_feature.setter
    def best_split_feature(self, best_split_feature):
        self.__best_split_feature = best_split_feature

    @property
    def label(self):
        """Prediction of this node: the mean of its targets."""
        return np.mean(self.y)

    @property
    def is_leaf(self):
        """1 when the node is a leaf, 0 otherwise."""
        return self.__is_leaf

    @is_leaf.setter
    def is_leaf(self, is_leaf):
        self.__is_leaf = is_leaf

    def mse(self, y=None):
        """Return the mean squared deviation of ``y`` from its own mean.

        ``y`` defaults to the node's targets.
        """
        if y is None:
            y = self.y

        return np.sum(np.square(y - np.mean(y))) / y.size

    def _mark_as_leaf(self):
        """Flag the node as a leaf, store its summary in ``res`` and return it."""
        self.is_leaf = 1
        self.res = ['samples = %s' % self.y.size, 'label = %s' % round(np.mean(self.y), 2),
                    'mse = %s' % round(self.mse(), 2)]
        return self.res

    def search_best_split(self):
        """Find the feature/threshold pair with the lowest weighted child MSE.

        Thresholds are the midpoints between consecutive distinct values of
        each feature. Ties are resolved in favour of the candidate evaluated
        last.

        Returns
        -------
        list of str
            The leaf summary, when the node cannot or need not be split (too
            few samples, identical targets, zero MSE, or no admissible
            threshold). ``is_leaf`` is set to 1 in that case.
        tuple
            ``(feature_index, threshold, weighted_mse, split_dict)`` otherwise,
            where ``split_dict[i]`` lists ``{'split_dot', 'mse_total'}`` for
            every admissible candidate threshold of feature ``i``.
        """
        _best_split_feature = None
        _best_split_dot = None
        _best_split_mse = self.mse()

        # Parenthesised so the condition can legally span several lines.
        if (self.y.size <= self.__min_samples_split
                or np.unique(self.y).size <= 1
                or _best_split_mse == 0):
            return self._mark_as_leaf()

        n_samples = self.y.size
        # An empty child is never admissible, whatever the configured minimum.
        min_leaf = max(1, self.__min_samples_leaf)
        split_dict = {}

        for i in range(self.x.shape[1]):
            split_dict[i] = []
            column = self.x[:, i]
            uni_x = np.unique(column)

            for lower, upper in zip(uni_x[:-1], uni_x[1:]):
                split_dot = (lower + upper) / 2

                y1 = self.y[column <= split_dot]
                y2 = self.y[column > split_dot]

                if y1.size < min_leaf or y2.size < min_leaf:
                    continue

                mse_total = y1.size / n_samples * self.mse(y1) + y2.size / n_samples * self.mse(y2)

                if mse_total <= _best_split_mse:
                    _best_split_feature = i
                    _best_split_dot = split_dot
                    _best_split_mse = mse_total

                # Record the candidate itself, not the running best.
                split_dict[i].append({'split_dot': split_dot, 'mse_total': mse_total})

        if _best_split_feature is None:
            # No admissible threshold (e.g. every feature is constant).
            return self._mark_as_leaf()

        self.__best_split_feature = [_best_split_feature, _best_split_dot]
        self.res = ['x[%s]<=%s' % (_best_split_feature, round(_best_split_dot, 4)),
                    'mse = %s' % round(self.mse(), 2),
                    'samples = %s' % self.y.size]
        return _best_split_feature, _best_split_dot, _best_split_mse, split_dict

    def split(self):
        """Split the node on its best feature/threshold and create both children.

        Children inherit ``min_samples_split`` / ``min_samples_leaf`` and get
        this node as ``parent``. Nothing is created when the node is (or turns
        out to be) a leaf.

        Returns
        -------
        The node's ``res`` summary.
        """
        if self.is_leaf == 1:
            return self.res

        self.search_best_split()
        if self.is_leaf == 1:
            # The search decided this node cannot be split.
            return self.res

        _best_split_feature, _best_split_dot = self.__best_split_feature
        print('best_split_feature:%s, best_split_dot:%s' % (_best_split_feature, round(_best_split_dot, 4)))

        goes_left = self.x[:, _best_split_feature] <= _best_split_dot
        goes_right = self.x[:, _best_split_feature] > _best_split_dot

        self.left = Tree(self.x[goes_left], self.y[goes_left], parent=self,
                         min_samples_split=self.__min_samples_split,
                         min_samples_leaf=self.__min_samples_leaf)
        self.right = Tree(self.x[goes_right], self.y[goes_right], parent=self,
                          min_samples_split=self.__min_samples_split,
                          min_samples_leaf=self.__min_samples_leaf)

        return self.res


class RegressorTree:
    """Depth-limited regression tree grown breadth-first from ``Tree`` nodes.

    Parameters
    ----------
    max_depth : int
        Number of levels of splits to perform.
    min_samples_split, min_samples_leaf : int
        Stopping thresholds handed to the root ``Tree`` node.
    """

    def __init__(self, max_depth=5, min_samples_split=2, min_samples_leaf=1):

        self.__min_samples_split = min_samples_split
        self.__min_samples_leaf = min_samples_leaf
        self.__max_depth = max_depth
        self.__children = []
        self.__parents = []
        self.__root = None

    def fit(self, x, y, min_samples_split=2, min_samples_leaf=1):
        """Grow the tree on ``x`` / ``y``, printing a trace of every node.

        Parameters
        ----------
        x, y : training features and targets, passed straight to ``Tree``.
        min_samples_split, min_samples_leaf : int
            Per-call overrides. A value left at its signature default (2 / 1)
            falls back to the value given to the constructor.
        """
        if min_samples_split == 2:
            min_samples_split = self.__min_samples_split
        if min_samples_leaf == 1:
            min_samples_leaf = self.__min_samples_leaf

        depth = 1
        self.__root = Tree(x, y, min_samples_split=min_samples_split,
                           min_samples_leaf=min_samples_leaf)
        # Start from a clean frontier so a second fit() does not revisit nodes
        # left over from the previous one.
        self.__parents = [self.__root]
        self.__children = []

        if self.__max_depth < 1:
            # No split will happen; make the root a leaf so predict() works.
            self.__root.is_leaf = 1

        print('begin fit, root = %s'%hex(id(self.__root)))
        print('----------------------------------')

        # Breadth-first: process one level of nodes per iteration.
        while depth <= self.__max_depth:

            print('--------------depth: %s ----------------'%depth)

            for i in self.__parents:

                if i is None:
                    continue

                # Decides whether the node is a leaf and fills in i.res.
                i.search_best_split()
                # Trace: node id, node summary, node label, leaf flag.
                print('当前分裂节点： %s, 节点内容： %s, 节点标签: %s, 是否叶子节点: %s'%(hex(id(i)), i.res, round(i.label,2), i.is_leaf))

                if i.is_leaf == 1:
                    continue
                i.split()
                self.__children.append(i.left)
                self.__children.append(i.right)

                if depth == self.__max_depth:
                    # Children of the last level are never examined again.
                    i.left.is_leaf = 1
                    i.right.is_leaf = 1

                # Trace: ids of the new left / right child and of their parent.
                print('当前节点生成左子节点：%s, 父节点为：%s'%(
                        hex(id(i.left)),hex(id(i))
                        ))
                print('当前节点生成右子节点：%s,  父节点为：%s'%(
                        hex(id(i.right)),hex(id(i))
                        ))
                print('------')

            # The children just created become the next level's frontier.
            self.__parents = self.__children
            self.__children = []

            depth = depth + 1

    def predict(self, x1):
        """Return the prediction for a single sample.

        Parameters
        ----------
        x1 : sequence indexable by feature index (one sample).

        Returns
        -------
        The ``label`` of the leaf reached by following the stored
        ``[feature_index, threshold]`` rules from the root.

        Raises
        ------
        AttributeError
            If ``fit`` has not been called yet.
        """
        if self.__root is None:
            raise AttributeError('RegressorTree is not fitted yet; call fit() before predict().')

        current_node = self.__root

        while current_node.is_leaf == 0:
            feature_index, threshold = current_node.best_split_feature

            if x1[feature_index] <= threshold:
                current_node = current_node.left
            else:
                current_node = current_node.right

        return current_node.label


In [6]:
# Load the diabetes regression dataset as a (features, targets) pair.
x, y = load_diabetes(return_X_y=True)
# Build a tree limited to 5 levels of splits; the other settings keep their defaults.
clf =  RegressorTree(max_depth=5)
# Grow the tree; fit() prints a per-depth trace of every node it visits.
clf.fit(x, y)

begin fit, root = 0x289838b93d0
----------------------------------
--------------depth: 1 ----------------
当前分裂节点： 0x289838b93d0, 节点内容： ['x[8]<=-0.0038', 'mse = 5929.88', 'samples = 442'], 节点标签: 152.13, 是否叶子节点: 0
best_split_feature:8, best_split_dot:-0.0038
当前节点生成左子节点：0x28985c61280, 父节点为：0x289838b93d0
当前节点生成右子节点：0x289885e2e20,  父节点为：0x289838b93d0
------
--------------depth: 2 ----------------
当前分裂节点： 0x28985c61280, 节点内容： ['x[2]<=0.0062', 'mse = 3240.82', 'samples = 218'], 节点标签: 109.99, 是否叶子节点: 0
best_split_feature:2, best_split_dot:0.0062
当前节点生成左子节点：0x289885e2fa0, 父节点为：0x28985c61280
当前节点生成右子节点：0x289838b97f0,  父节点为：0x28985c61280
------
当前分裂节点： 0x289885e2e20, 节点内容： ['x[2]<=0.0148', 'mse = 5135.61', 'samples = 224'], 节点标签: 193.15, 是否叶子节点: 0
best_split_feature:2, best_split_dot:0.0148
当前节点生成左子节点：0x289838b9220, 父节点为：0x289885e2e20
当前节点生成右子节点：0x28983944bb0,  父节点为：0x289885e2e20
------
--------------depth: 3 ----------------
当前分裂节点： 0x289885e2fa0, 节点内容： ['x[6]<=0.021', 'mse = 2143.97', 'samples