### 分类回归树（classification and regression Trees）

  

In [2]:
from numpy import *
%run TreePlotter.ipynb

In [3]:
def loadDataSet(fileName):
    dataMat = []
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = list(map(float,curLine))
        dataMat.append(fltLine)
    return dataMat

In [4]:
def binSplitDataSet(dataSet, feature, value):
    '''给定数据集，待分的特征，特征值，通过数组过滤生成两个子集并返回'''
    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:]
    mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:]
    return mat0, mat1

In [5]:
def regLeaf(dataSet):
    '''生成叶节点，在回归树中，该模型就是目标变量的均值'''
    return mean(dataSet[:,-1])

In [6]:
def regErr(dataSet):
    '''计算目标变量的总的平方误差'''
    return var(dataSet[:,-1]) * shape(dataSet)[0]

In [7]:
testMat = mat(eye(4))
mat0, mat1 = binSplitDataSet(testMat,1,0.5)
print(mat0)
print(mat1)

[[0. 1. 0. 0.]]
[[1. 0. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]]


## 将CART算法用于回归

In [8]:
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr,ops=(1,4)):
    '''用最佳方式切分数据集，生成相应的叶节点'''
    tolS = ops[0];tolN = ops[1]                                           #tolS：允许的误差下降值   tolN：切分的最小样本数

    if len(set(dataSet[:,-1].T.tolist()[0])) == 1:                        #如果数据集的值属于相等，则返回叶节点为数据集的均值
        return None,leafType(dataSet)
    
    m,n = shape(dataSet)
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0                             #初始化目标变量的误差，最好的误差，最好的索引和最好的值
    
    for featIndex in range(n-1):                                          #对每一个特征索引
        for splitVal in set((dataSet[:,featIndex].T.A.tolist())[0]):      #对固定特征索引的每个取值
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)    #划分数据集
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):        #若两个子集中任一个的数据个数小于规定值，计算总平方误差
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:                                              #如果新的总误差 小于 最好的误差，则更新最好误差
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    if (S - bestS) <= tolS:                                               # 如果误差下降值小于设定的阈值，则返回叶节点
        return None, leafType(dataSet)
    
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)           #如果切分出的数据集很小则退出
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0]) < tolN:
        return None, leafType(dataSet)
    return bestIndex, bestValue

In [9]:
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    '''递归生成决策树'''
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)        # 选择特征
    if feat == None:                                                    # 如果满足条件，则返回
        return val
    retTree = {}                                                        #否则创建一个新的字典，用于存放树
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)                    #将数据集分为两个部分
    retTree['left'] = createTree(lSet, leafType, errType, ops)          #对 左树 和 右树分别迭代生成树
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

In [10]:
myDat = loadDataSet('data/ex00.txt')
myMat = mat(myDat)
createTree(myMat)

{'spInd': 0,
 'spVal': 0.48813,
 'left': 1.0180967672413792,
 'right': -0.04465028571428572}

In [11]:
# 多次切分
myDat1 = loadDataSet('data/ex0.txt')
myMat1 = mat(myDat1)
createTree(myMat1)

{'spInd': 1,
 'spVal': 0.39435,
 'left': {'spInd': 1,
  'spVal': 0.582002,
  'left': {'spInd': 1,
   'spVal': 0.797583,
   'left': 3.9871632,
   'right': 2.9836209534883724},
  'right': 1.980035071428571},
 'right': {'spInd': 1,
  'spVal': 0.197834,
  'left': 1.0289583666666666,
  'right': -0.023838155555555553}}

## 剪枝：通过降低决策树复杂度来避免过拟合
    预剪枝：chooseBestSplit()中的提前终止条件
    后剪枝：利用测试集对树进行剪枝
        基于已有的树切分测试数据
        如果任一子集是一棵树，则在该子集递归剪枝过程
        计算当前两个叶节点合并后的误差
        计算不合并的误差
        如果合并后误差降低，则将叶节点合并

In [13]:
myDat2= loadDataSet('data/ex2.txt')
myMat2= mat(myDat2)
retTree = createTree(myMat2,ops=(10000,4))
print(retTree)  

{'spInd': 0, 'spVal': 0.499171, 'left': 101.35815937735848, 'right': -2.637719329787234}


In [16]:
def isTree(obj):
    return (type(obj).__name__=='dict')

In [17]:
def getMean(tree):
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    if isTree(tree['left']):
        tree['left'] = getMean(tree['left'])
    return (tree['left']+tree['right'])/2.0
    

In [18]:
def prune(tree, testData):                                                   #待剪枝的树和剪枝所需的测试数据
    if shape(testData)[0] == 0:                                              # 确认数据集非空
        return getMean(tree)
    
    #假设发生过拟合，采用测试数据对树进行剪枝
    if (isTree(tree['right']) or isTree(tree['left'])):                      #左右子树非空
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']): 
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']): 
        tree['right'] = prune(tree['right'], rSet)
    #剪枝后判断是否还是有子树
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        #判断是否merge
        errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + \
                       sum(power(rSet[:, -1] - tree['right'], 2))
        
        treeMean = (tree['left'] + tree['right']) / 2.0
        errorMerge = sum(power(testData[:, -1] - treeMean, 2))
        #如果合并后误差变小
        if errorMerge < errorNoMerge:
            print("merging")
            return treeMean
        else:
            return tree
    else:
        return tree

In [20]:
myTree = createTree(myMat2,ops=(0,1))
myDatTest = loadDataSet('data/ex2test.txt')
myMat2Test = mat(myDatTest)
Tree = prune(myTree, myMat2Test)
print(Tree)

merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
{'spInd': 0, 'spVal': 0.499171, 'left': {'spInd': 0, 'spVal': 0.729397, 'left': {'spInd': 0, 'spVal': 0.952833, 'left': {'spInd': 0, 'spVal': 0.965969, 'left': 92.5239915, 'right': {'spInd': 0, 'spVal': 0.956951, 'left': {'spInd': 0, 'spVal': 0.958512, 'left': {'spInd': 0, 'spVal': 0.960398, 'left': 112.386764, 'right': 123.559747}, 'right': 135.837013}, 'right': 111.2013225}}, 'right': {'spInd': 0, 'spVal': 0.759504, 'left': {'spInd': 0, 'spVal': 0.763328, 'left': {'spInd': 0, 'spVal': 0.769043, 'left': {'spInd': 0, 'spVal': 0.790312, 'left': {'spInd': 0, 'spVal': 0.806158, 'left': {'spInd': 0, 'spVal': 0.815215, 'left': {'spInd': 0, 'spVa

## 模型树：吧叶节点设置为分段线性函数

In [28]:
def linearSolve(dataSet):
    m, n = shape(dataSet)
    X = mat(ones((m,n))); Y = mat(ones((m,1)))
    X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:,-1]
    xTx = X.T*X
    if linalg.det(xTx) == 0.0:
        print('error')
        return
    ws = xTx.T * X.T * Y
    return ws, X, Y

In [29]:
def modelLeaf(dataSet):                #生成叶节点的线性模型
    ws, X, Y = linearSolve(dataSet)
    return ws

In [30]:
def modelErr(dataSet):                 # 在给定的数据集上计算误差
    ws, X, Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y-yHat, 2))

In [36]:
myMat2 = mat(loadDataSet('data/exp2.txt'))
createTree(myMat2, modelLeaf, modelErr, (1,10))

{'spInd': 0,
 'spVal': 0.66897,
 'left': {'spInd': 0,
  'spVal': 0.829402,
  'left': {'spInd': 0, 'spVal': 0.906907, 'left': matrix([[4975.08130668],
           [4787.84869266]]), 'right': matrix([[5338.48156289],
           [4668.13806435]])},
  'right': {'spInd': 0, 'spVal': 0.763717, 'left': matrix([[5105.00607177],
           [4086.18045627]]), 'right': {'spInd': 0,
    'spVal': 0.721517,
    'left': matrix([[1387.78024516],
            [1034.39811056]]),
    'right': matrix([[1254.1632993 ],
            [ 882.47514671]])}}},
 'right': {'spInd': 0,
  'spVal': 0.399447,
  'left': {'spInd': 0,
   'spVal': 0.551771,
   'left': {'spInd': 0, 'spVal': 0.609194, 'left': matrix([[1583.44942959],
            [1022.20757928]]), 'right': matrix([[1367.76302312],
            [ 806.73069721]])},
   'right': {'spInd': 0, 'spVal': 0.501156, 'left': matrix([[1355.66970295],
            [ 713.53111618]]), 'right': matrix([[1456.36098645],
            [ 653.00187841]])}},
  'right': {'spInd': 0,
   