In [15]:
import numpy as np
from sklearn.datasets import load_iris

iris = load_iris()
iris
x, y = iris.data, iris.target

class Tree:
    
    def __init__(self, x, y, left=None, right=None, parent=None, min_samples_split=2, min_samples_leaf=1):
        self.__x = x
        self.__y = y
        self.__right = right
        self.__left = left
        self.__parent = parent
        self.__min_samples_split = min_samples_split
        self.__min_samples_leaf = min_samples_leaf
        self.__is_leaf = 0
        self.__best_split_feature = None
        self.__res = None
        self.__value = {}

    @property
    def x(self):
        return self.__x
    
    @x.setter
    def x(self,x):
        self.__x = x
        
    @property
    def y(self):
        return self.__y
    
    @y.setter
    def y(self,y):
        self.__y = y
    
    @property
    def left(self):
        return self.__left
    
    @left.setter
    def left(self,left):
        self.__left = left
        
    @property
    def right(self):
        return self.__right
    
    @right.setter
    def right(self, right):
        self.__right = right
        
    @property
    def parent(self):
        return self.__parent
    
    @parent.setter
    def parent(self, parent):
        self.__parent = parent
        
    @property    
    def res(self):
        return self.__res
    
    @res.setter
    def res(self, res):
        self.__res = res
    
    @property
    def best_split_feature(self):
        return self.__best_split_feature
    
    @best_split_feature.setter
    def best_split_feature(self, best_split_feature):
        self.__best_split_feature = best_split_feature
    
    @property
    def label(self):
        label = None
        _size = 0
        sa = np.unique(self.y)
        for i in sa:
            size = self.y[self.y==i].size
            if size>_size:
                label = i
                _size = size
        return label
    
    @property
    def is_leaf(self):
        return self.__is_leaf
    
    @is_leaf.setter
    def is_leaf(self, is_leaf):
        self.__is_leaf = is_leaf
        
    @property
    def value(self):
        return self.__value
    @value.setter
    def value(self, value):
        self.__value = value
    
    
    # 计算基尼指数
    def gini(self, y=None):
        
        if y is None:
            y = self.y
        
        la = y.size
        sa = np.unique(y)
        
        g = 0
        for i in sa:
            pi = y[y==i].size/la
            pi2 = pi*pi
            g = g + pi2
        
        return round(1-g,2)
        

    #搜索最佳特征和分裂点
    def search_best_split(self):
        
        best_split_feature = None
        best_split_dot = None
        best_split_gini = self.gini()
        
        for i in np.unique(self.y):
            self.value[i] = self.y[self.y==i].size
        
        if self.y.size<=self.__min_samples_split or np.unique(self.y).size==self.__min_samples_leaf or self.gini()==0:
            self.is_leaf = 1
            self.res = ['samples = %s' % self.y.size, 'label = %s' % self.label, 'value = %s'%self.value]
        
        split_dict = {}
        
        for i in range(self.x.shape[1]):
            split_dict[i] = []
            uni_x = np.unique(self.x[:,i])
            len_x = uni_x.size
            
            j=0
            
            while j<=len_x-1-1:
                
                split_dot = (uni_x[j]+uni_x[j+1])/2
        
                y1 = self.y[self.x[:,i]<=split_dot]
                y2 = self.y[self.x[:,i]>split_dot]
    
                gini_y1 = self.gini(y1)
                gini_y2 = self.gini(y2)
            
                gini_total = y1.size/self.y.size*gini_y1 + y2.size/self.y.size*gini_y2
                
                if gini_total<=best_split_gini:
                    best_split_feature = i
                    best_split_dot = split_dot
                    best_split_gini = gini_total
                
                _split = {'split_dot':split_dot, 'gini_total':gini_total}
                
                split_dict[i].append(_split)
                
                j = j+1
        
        self.best_split_feature = [best_split_feature, best_split_dot]
        self.res =  ['x[%s]<=%s'%(best_split_feature,best_split_dot),'gini = %s'%self.gini(), 'samples = %s'%self.y.size, 
               'value = %s'%self.value]
        return best_split_feature, best_split_dot, best_split_gini, split_dict
            
    def split(self):
        
        if self.is_leaf == 1:
            return self.res
        
        best_split_feature, best_split_dot, best_split_gini, split_dict = self.search_best_split()
        self.left = Tree(self.x[self.x[:,best_split_feature]<=best_split_dot], self.y[self.x[:,best_split_feature]<=best_split_dot])
        self.right = Tree(self.x[self.x[:,best_split_feature]>best_split_dot], self.y[self.x[:,best_split_feature]>best_split_dot])
        
        return self.res
    

In [16]:
class ClassifyTree:
    
    def __init__(self, max_depth=5, min_samples_split=2, min_samples_leaf=1):

        self.__min_samples_split = min_samples_split
        self.__min_samples_leaf = min_samples_leaf
        self.__max_depth = max_depth
        self.__children = []
        self.__parents = []
        self.__root = None
        
    def fit(self, x, y, min_samples_split=2, min_samples_leaf=1):
        
        depth = 1
        self.__root = Tree(x, y)
        self.__parents.append(self.__root)
        
        print('begin fit, root = %s'%hex(id(self.__root)))
        print('----------------------------------')
        
        #广度优先
        while depth<=self.__max_depth:
            
            print('--------------depth: %s ----------------'%depth)
            
            for i in self.__parents:
                
                i.search_best_split()
                print('当前分裂节点： %s, 节点内容： %s, 节点标签: %s, 是否叶子节点: %s'%(hex(id(i)), i.res, round(i.label,2), i.is_leaf))
                
                if i == None or i.is_leaf==1:
                    continue
                i.split()
                self.__children.append(i.left)
                self.__children.append(i.right)
                
                if depth == self.__max_depth:
                    i.left.is_leaf = 1
                    i.right.is_leaf = 1
                    
                print('当前节点生成左子节点：%s, 父节点为：%s'%(
                        hex(id(i.left)),hex(id(i))
                        ))
                print('当前节点生成右子节点：%s,  父节点为：%s'%(
                        hex(id(i.right)),hex(id(i))
                        ))
                print('------')
                
            self.__parents=[]
            self.__parents.extend(self.__children)
            self.__children = []
            
            depth = depth + 1
            
    def predict(self, x1):
        
        current_node = self.__root
        is_leaf = current_node.is_leaf

        #print('predict begin')
        
        while is_leaf == 0:

            best_feature = current_node._Tree__best_split_feature
            #print('current_node_id: %s'%hex(id(current_node)),'current_node_res : %s'%current_node.res,
            #      'is_leaf : %s'%current_node.is_leaf, 'label: %s'%current_node.label)


            if x1[best_feature[0]]<=best_feature[1]:
                #print('go left')      
                current_node = current_node.left
            else:
                #print('go right')
                current_node = current_node.right

            is_leaf = current_node.is_leaf
        return current_node.label


In [17]:
clf = ClassifyTree()
clf.fit(x, y)

begin fit, root = 0x2e1af6d85e0
----------------------------------
--------------depth: 1 ----------------
当前分裂节点： 0x2e1af6d85e0, 节点内容： ['x[3]<=0.8', 'gini = 0.67', 'samples = 150', 'value = {0: 50, 1: 50, 2: 50}'], 节点标签: 0, 是否叶子节点: 0
当前节点生成左子节点：0x2e1a83d3760, 父节点为：0x2e1af6d85e0
当前节点生成右子节点：0x2e1aee747f0,  父节点为：0x2e1af6d85e0
------
--------------depth: 2 ----------------
当前分裂节点： 0x2e1a83d3760, 节点内容： ['x[3]<=0.55', 'gini = 0.0', 'samples = 50', 'value = {0: 50}'], 节点标签: 0, 是否叶子节点: 1
当前分裂节点： 0x2e1aee747f0, 节点内容： ['x[3]<=1.75', 'gini = 0.5', 'samples = 100', 'value = {1: 50, 2: 50}'], 节点标签: 1, 是否叶子节点: 0
当前节点生成左子节点：0x2e1af6d8640, 父节点为：0x2e1aee747f0
当前节点生成右子节点：0x2e1af6d8460,  父节点为：0x2e1aee747f0
------
--------------depth: 3 ----------------
当前分裂节点： 0x2e1af6d8640, 节点内容： ['x[2]<=4.95', 'gini = 0.17', 'samples = 54', 'value = {1: 49, 2: 5}'], 节点标签: 1, 是否叶子节点: 0
当前节点生成左子节点：0x2e1af6d8e50, 父节点为：0x2e1af6d8640
当前节点生成右子节点：0x2e1af6d8610,  父节点为：0x2e1af6d8640
------
当前分裂节点： 0x2e1af6d8460, 节点内容： ['x[2]<=