# 决策树CART算法demo

## 1、引入对应的库并查看数据集内容

In [283]:
import numpy as np
from sklearn import datasets
# 引入数据集
import matplotlib.pyplot as plt
from pandas import Series
# 引入对应库
irisData = datasets.load_iris()
irisData
print("targetName:%s" % irisData["target_names"])
print("featureName:%s" % irisData["feature_names"])
print("target:%s" % irisData["target"])
dataSet = irisData["data"]
dataSet[1:5]

targetName:['setosa' 'versicolor' 'virginica']
featureName:['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
target:[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]


array([[4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.6, 1.4, 0.2]])

## 2、建立算法模型

In [631]:
from collections import  Counter
import copy
# 建立cart算法的二叉树结构
class CartBtree:
    def __init__(self, data, left=None, right=None):
        self.data = data # 记录样本数据的列表
        self.var = None # 记录该节点的变量取值
        self.left = left # 左子树
        self.right = right # 右子树
        self.flage = None # 记录该节点代表的样本类型
        self.alpha = None
        
class Cart:
    def __init__(self, obj, target, feature, targetName):
        self.feature = feature
        self.dataSet = obj
        self.target = target
        self.targetName = targetName
        self.listVar = []
        for i in feature:
            # 整个字典的第一维表示变量类型第二维表示物种类型即标签类型最后一维表示每一个变量可能出现的数值
            minNum = (np.min(obj[:,feature.index(i)]) + 0.5) // 1
            maxNum = np.ceil(np.max(obj[:,feature.index(i)]))
            tmpArray = np.arange(minNum,maxNum,0.5)
            for j in tmpArray:
                self.listVar.append([feature.index(i), j])
                # 整个列表中每一个元素是一个元组，每一个元组代表一个变量，
                # 元组中第一个元素表示变量类型，第二个元素表示该类型对应的取值，可以看成是一个范围
        self.listVar = np.array(self.listVar)

    def funShuffle(self):
        index = np.random.choice(np.arange(self.dataSet.shape[0]), size = self.dataSet.shape[0], replace = False)
        # 产生一个随机不重复的序列
        self.listTrain = []
        self.trainTarget = []
        self.listTest = []
        self.testTarget = []
        # 同之前的两个算法相比这里新初始化了两个列表
        # 之前把变量和状态写在一个列表中做向量处理的时候比较乱
        for i in range(self.dataSet.shape[0]):
            if (i < 0.8 * self.dataSet.shape[0]):
                self.listTrain.append(self.dataSet[index[i]])
                self.trainTarget.append(self.target[index[i]])
            else:
                self.listTest.append(self.dataSet[index[i]])
                self.testTarget.append(self.target[index[i]])
        # 返回标准化后的矩阵
        return (np.floor(np.mat(self.listTrain)) + np.around(np.mat(self.listTrain) - np.floor(np.mat(self.listTrain))) / 2)
    def funBtree(self, data, target, var):
        '''
        递归的第一个传入的var变量是实例化后的listVar,类型为np.array
        相似的，data、target为实例化后的listTrain和trainTarget
        '''
        if (len(data)==0 or not(var.shape[0])):
            # 如果变量取值集合为空或所有状态已空则返回None
            return None
        listGini = []
        for i in set(var[:,0]):
            # 遍历变量内容如果同一类变量取值只有两种则可以删去一种，因为这两个的基尼指数必然相等
            if var[var[:,0]==i].shape[0]==2:
                var = np.delete(var, np.where(var[:,0]==i)[0][1], 0)
        for i in var:
            # 计算所有变量的基尼指数
            index = data[:,int(i[0])]==i[1]
            index2 = data[:,int(i[0])]!=i[1]
            count = data[index].shape[0] / data.shape[0]
            sum1, sum2 = 0, 0
            for j in range(3):
                if count!=0:
                    sum1 += pow(len(target[index][target[index]==j])/data[index].shape[0],2)
                if count!=1:
                    sum2 += pow(len(target[index2][target[index2]==j])/data[index2].shape[0],2)
            Gini1 = count * (1 - sum1)
            Gini2 = (1 - count) * (1 - sum2)
            listGini.append(Gini1+Gini2)
        tmpTree = CartBtree((data, target))
        tmpTree.flage = max(Counter(target),key=Counter(target).get)
        if (len(set(target))==1):
            tmpTree.left = None
            tmpTree.right = None
            return tmpTree
        tmpTree.var = var[listGini.index(min(listGini))]
        dataYes = data[data[:,int(tmpTree.var[0])]==tmpTree.var[1]]
        targetYes = target[data[:,int(tmpTree.var[0])]==tmpTree.var[1]]
        dataNo = data[data[:,int(tmpTree.var[0])]!=tmpTree.var[1]]
        targetNo = target[data[:,int(tmpTree.var[0])]!=tmpTree.var[1]]
        tmpTree.left = self.funBtree(dataYes,targetYes,np.delete(var,listGini.index(min(listGini)),0))
        tmpTree.right = self.funBtree(dataNo,targetNo,np.delete(var,listGini.index(min(listGini)),0))
        return tmpTree
    def funTarget(self, tree, item):
        tmp = (np.floor(np.array(item)) + np.around(np.array(item) - np.floor(np.array(item))) / 2)
        if type(tree.left)==type(None):
            return tree.flage
        if tmp[int(tree.var[0])]==tree.var[1]:
            return self.funTarget(tree.left, item)
        else:
            return self.funTarget(tree.right, item)
    def funTest(self, tree):
        errorCount = 0
        for i in range(len(self.listTest)):
            target = self.funTarget(tree, self.listTest[i])
            if int(target) != int(self.testTarget[i]):
                errorCount += 1
        errorRate = (errorCount/len(self.testTarget))*100
        print("错误率为：%.2f" % errorRate, end="%\n")
        return errorRate
    def funCount(self, count, tree):
        # 计算叶节点的个数的返回值需减一
        if (type(tree)==type(None)):
            return 0
        count1 = self.funCount(count, tree.left)
        count2 = self.funCount(count, tree.right)
        return count1+count2+1
    def _trim(tree, minAlpha):
        if (type(tree)==type(None)):
            return
        if tree.alpha==minAlpha:
            tree.left=None
            tree.right=None
            return
        Cart._trim(tree.left, minAlpha)
        Cart._trim(tree.right, minAlpha)
        return
    def computeAlpha(self, tree, alpha=float('inf')):
        if (type(tree)==type(None)):
            return alpha
        if (type(tree.var)==type(None)):
            tree.alpha = float('inf')
            return alpha
        Ct = 1 - sum([pow(list(tree.data[1]).count(0)/len(tree.data[1]),2),pow(list(tree.data[1]).count(1)/len(tree.data[1]),2),pow(list(tree.data[1]).count(2)/len(tree.data[1]),2)])
        countT = self.funCount(0, tree)
        index = tree.data[0][:,int(tree.var[0])]==tree.var[1]
        index2 = tree.data[0][:,int(tree.var[0])]!=tree.var[1]
        count = tree.data[0][index].shape[0] / tree.data[0].shape[0]
        sum1, sum2 = 0, 0
        for j in range(3):
            if count!=0:
                sum1 += pow(len(tree.data[1][index][tree.data[1][index]==j])/tree.data[0][index].shape[0],2)
            if count!=1:
                sum2 += pow(len(tree.data[1][index2][tree.data[1][index2]==j])/tree.data[0][index2].shape[0],2)
        Gini1 = count * (1 - sum1)
        Gini2 = (1 - count) * (1 - sum2)
        CT = Gini1 + Gini2
        if countT==1:
            newAlpha = float('inf')
        else:
            newAlpha = (CT-Ct)/(countT-1)
        tree.alpha = newAlpha
        nextAlpha = self.computeAlpha(tree.left, min(alpha, newAlpha))
        return self.computeAlpha(tree.right, nextAlpha)
    def funTrim(self, tree, num=5):
        # 剪枝函数，这里只做了简单的训练集处理，真实情况需要做交叉验证
        # num默认只剪枝五次，获取的子树是互相嵌套的
        listTree = []
        tmpTree = copy.deepcopy(tree)
        for i in range(num):
            newTree = copy.deepcopy(tmpTree)
            minAlpha = self.computeAlpha(newTree)
            Cart._trim(newTree, minAlpha)
            listTree.append(newTree)
            tmpTree = newTree
        return listTree
        

## 3、实例化cart模型

In [654]:
tmp2 = Cart(dataSet, irisData["target"], irisData["feature_names"], irisData["target_names"])
# 实例化模型并传参
newData = tmp2.funShuffle()
cartTree = tmp2.funBtree(np.array(newData), np.array(tmp2.trainTarget), tmp2.listVar)
errorBefo = tmp2.funTest(cartTree) # 剪枝前的错误率
# 可以看出实例化的模型错误率很不稳定，节点过多，限制条件过多以至于造成一定程度的过拟合，需要进一步剪枝

错误率为：3.33%


## 4、查看模型的输出结果：

In [655]:
a = np.array([4.9, 3. , 1.4, 0.2])
print(tmp2.funTarget(cartTree, a))

0


## 5、比对剪枝后的错误率(这里使用没有交叉验证)

In [656]:
listTree = tmp2.funTrim(cartTree)
listError = []
for i in listTree:
    listError.append(tmp2.funTest(i))

错误率为：3.33%
错误率为：3.33%
错误率为：6.67%
错误率为：0.00%
错误率为：0.00%


In [657]:
# 剪枝后的节点个数
for i in listTree:
    print(tmp2.funCount(0, i))

65
63
40
38
36
