In [2]:
import graphviz
import numpy
import matplotlib.pyplot as plt


class DecisionTree(object):
    """Binary decision tree over numeric features with threshold splits.

    Each internal node splits one feature column at a threshold: rows whose
    value is ``<= split`` go to the left subtree, all others go right.  The
    split criterion is supplied by subclasses (ID3 / C45 / CART) through
    ``calculate(left, right, dataset)``, which must return a score where
    LARGER is better.

    Datasets are 2-D row-major tables whose last column is the class label.
    """

    def __init__(self, impurity, left, right, split, labels, data=None,
                 axis=None):
        """
        Parameters
        ----------
        impurity : score of the split chosen at this node.
        left : left subtree (a ``DecisionTree`` node or a leaf class label).
        right : right subtree (a ``DecisionTree`` node or a leaf class label).
        split : threshold; rows with feature value ``<= split`` go left.
        labels : feature names of the columns at this node; used only when
            ``data`` is not given.
        data : optional 2-D table whose first row is a header; when given,
            the header row becomes ``self.labels`` and the remaining rows
            become ``self.data``.
        axis : column index (within ``labels``) of the feature split on.
        """
        self.split = split        # split threshold
        self.impurity = impurity  # score of the chosen split
        self.left = left          # left subtree (feature <= split)
        self.right = right        # right subtree (feature > split)
        self.axis = axis          # column index of the split feature
        if data is None:
            # None instead of a shared mutable [] default; a list could not
            # be 2-D sliced below anyway.
            self.data = None
            self.labels = labels
        else:
            data = numpy.asarray(data)
            self.data = data[1:, :]   # data rows (header stripped)
            self.labels = data[0, :]  # header row

    @staticmethod
    def calculate(left, right, dataset):
        """Score a candidate split (larger is better); subclasses override."""
        raise NotImplementedError(
            "subclasses (ID3, C45, CART) must implement calculate()")

    @classmethod
    def _best_split(cls, data):
        """Return ``(split, axis, score)`` of the best threshold split.

        ``(None, None, -inf)`` is returned when no feature has two distinct
        values, i.e. when no threshold separates any rows.
        """
        data = numpy.asarray(data)
        best_split, best_axis, best_score = None, None, -numpy.inf
        if data.ndim != 2 or data.shape[0] == 0:
            return best_split, best_axis, best_score
        for axis in range(data.shape[1] - 1):  # last column is the label
            column = data[:, axis]
            values = numpy.unique(column)  # sorted, de-duplicated
            # Midpoints between consecutive distinct values (all of them).
            thresholds = (values[:-1] + values[1:]) / 2
            for value in thresholds:
                goes_left = column <= value
                score = cls.calculate(data[goes_left], data[~goes_left], data)
                if score > best_score:  # strict: keep the first best split
                    best_split, best_axis, best_score = float(value), axis, score
        return best_split, best_axis, best_score

    def chooseBest(self, data):
        """
        Choose the best split threshold.

        data: dataset (last column is the class label)
        return: (best split value, column index of the feature); both are
            None when no feature can be split.
        """
        split, axis, _ = self._best_split(data)
        return split, axis

    @staticmethod
    def _class_counts(dataset):
        """Return a ``{label: count}`` dict over the last column of each row."""
        counts = {}
        for row in dataset:
            counts[row[-1]] = counts.get(row[-1], 0) + 1
        return counts

    @staticmethod
    def entropy(dataset):
        """Return the Shannon entropy (in bits) of the class labels.

        The label is the last element of each row; an empty dataset has
        entropy 0.0.
        """
        m = len(dataset)
        if m == 0:
            return 0.0
        counts = DecisionTree._class_counts(dataset)
        return float(-sum((c / m) * numpy.log2(c / m) for c in counts.values()))

    @staticmethod
    def Gini(dataset):
        """Return the Gini impurity ``1 - sum(p_k**2)`` of the class labels.

        Works for any number of classes; an empty dataset returns 0.0.
        """
        m = len(dataset)
        if m == 0:
            return 0.0
        counts = DecisionTree._class_counts(dataset)
        return 1.0 - sum((c / m) ** 2 for c in counts.values())

    @staticmethod
    def vote(classlist):
        """
        Majority vote over ``classlist``.

        Used when no feature is left to split on.  Ties go to the smallest
        label (so 0 wins a 0/1 tie).

        Raises
        ------
        ValueError : if ``classlist`` is empty.
        """
        counts = {}
        for label in classlist:
            counts[label] = counts.get(label, 0) + 1
        if not counts:
            raise ValueError("vote() needs at least one label")
        # max() returns the first maximum, so iterating labels in sorted
        # order breaks ties towards the smallest label.
        return max(sorted(counts), key=counts.get)

    @staticmethod
    def split_data(dataset, split, axis):
        """
        Partition rows into left/right subsets and drop the split column.

        Parameters:
        -----------
        dataset: dataset (2-D, last column is the label)
        split: split threshold; rows with ``row[axis] <= split`` go left
            (same rule as ``chooseBest``)
        axis: column index of the split feature

        Returns:
        -------
        left: rows of the left subset (lists), without column ``axis``
        right: rows of the right subset (lists), without column ``axis``
        """
        dataset = numpy.asarray(dataset)
        if len(dataset) == 0:
            return [], []
        goes_left = dataset[:, axis] <= split
        remaining = numpy.delete(dataset, axis, axis=1)
        return remaining[goes_left].tolist(), remaining[~goes_left].tolist()

    @classmethod
    def buildTree(cls, dataset, labels):
        """
        Recursively build a decision tree using ``cls.calculate``.

        Parameters:
        ------------
        dataset: dataset (2-D, last column is the class label)
        labels: feature names of the non-label columns (may be None)

        Returns:
        ---------
        decision_tree: a leaf class label when the node is pure or cannot
            be split further, otherwise a ``cls`` node whose ``left`` and
            ``right`` are subtrees built the same way.

        Raises:
        -------
        ValueError: if ``dataset`` is empty.
        """
        dataset = numpy.asarray(dataset)
        if len(dataset) == 0:
            raise ValueError("cannot build a tree from an empty dataset")
        classlist = dataset[:, -1].tolist()
        if classlist.count(classlist[0]) == len(classlist):
            return classlist[0]  # all samples belong to one class
        if dataset.shape[1] == 1:
            return cls.vote(classlist)  # only the label column is left
        split, axis, score = cls._best_split(dataset)
        if axis is None:
            # Every feature is constant: no threshold separates the rows.
            return cls.vote(classlist)
        left, right = cls.split_data(dataset, split, axis)
        # split_data drops the split column, so drop its name as well.
        sub_labels = None if labels is None else [
            name for i, name in enumerate(labels) if i != axis]
        return cls(impurity=score,
                   left=cls.buildTree(left, sub_labels),
                   right=cls.buildTree(right, sub_labels),
                   split=split, labels=labels, axis=axis)
    
class ID3(DecisionTree):
    """Decision tree whose splits are scored by information gain (ID3)."""

    @staticmethod
    def calculate(left, right, dataset):
        """
        Score a split by information gain (larger is better).

        Parameters:
        -------------
        left: rows sent to the left child
        right: rows sent to the right child
        dataset: all rows at this node (left + right)

        Returns:
        --------
        impurity: H(dataset) - p_left * H(left) - p_right * H(right), where H
            is ``DecisionTree.entropy``; 0.0 for an empty dataset.
        """
        # The candidate partition is supplied by the caller; re-choosing a
        # split here would recurse back into chooseBest.
        total = len(dataset)
        if total == 0:
            return 0.0
        p1 = len(left) / total
        p2 = len(right) / total
        return (DecisionTree.entropy(dataset)
                - p1 * DecisionTree.entropy(left)
                - p2 * DecisionTree.entropy(right))
        
        
class C45(DecisionTree):
    """Decision tree whose splits are scored by gain ratio (C4.5)."""

    @staticmethod
    def calculate(left, right, dataset):
        """
        Score a split by information gain ratio (larger is better).

        Parameters:
        -------------
        left: rows sent to the left child
        right: rows sent to the right child
        dataset: all rows at this node (left + right)

        Returns:
        --------
        impurity: information gain divided by the split information
            ``-(p_left*log2(p_left) + p_right*log2(p_right))``; 0.0 when the
            dataset is empty or the split information is zero.
        """
        total = len(dataset)
        if total == 0:
            return 0.0
        p1 = len(left) / total
        p2 = len(right) / total
        info_gain = (DecisionTree.entropy(dataset)
                     - p1 * DecisionTree.entropy(left)
                     - p2 * DecisionTree.entropy(right))
        # Gain ratio normalises by the entropy of the partition itself
        # (intrinsic value), not by the entropy of the class labels.
        split_info = -sum(p * numpy.log2(p) for p in (p1, p2) if p > 0)
        if split_info == 0:
            return 0.0
        return float(info_gain / split_info)


    
    
class CART(DecisionTree):
    """Decision tree whose splits are scored by Gini gain (CART)."""

    @staticmethod
    def calculate(left, right, dataset):
        """
        Score a split by Gini gain (larger is better).

        Parameters:
        -------------
        left: rows sent to the left child
        right: rows sent to the right child
        dataset: all rows at this node (left + right)

        Returns:
        --------
        impurity: Gini(dataset) - p_left * Gini(left) - p_right * Gini(right);
            0.0 for an empty dataset.
        """
        # Return the *reduction* in Gini impurity, so that maximising the
        # score (as chooseBest does) picks the purest children.
        total = len(dataset)
        if total == 0:
            return 0.0
        p1 = len(left) / total
        p2 = len(right) / total
        return (DecisionTree.Gini(dataset)
                - p1 * DecisionTree.Gini(left)
                - p2 * DecisionTree.Gini(right))


        
        
        
    
            
        
        
    



In [None]:
# Export a tree model to Graphviz DOT source and render it.
# NOTE(review): this cell expects `tree` (scikit-learn's `sklearn.tree`) and a
# fitted scikit-learn estimator `model` to be defined in an earlier cell.
# DecisionTree in this file subclasses plain `object`, not a scikit-learn
# estimator, so it cannot be passed here directly.
dot_data = tree.export_graphviz(
    model,
    feature_names=None,  # TODO: list of feature column names
    class_names=None,    # TODO: list of class names
)
graph = graphviz.Source(dot_data)
graph.render('computer')  # render to disk under the base name 'computer'
graph  # last expression: displayed inline by the notebook