# 第9章 树回归

本章内容：
- CART算法
- 回归与模型树
- 树剪枝算法
- Python中GUI的使用

## 9.1 复杂数据的局部性建模

<center>**树回归**</center>
优点：可以对复杂和非线性的数据建模。  
缺点：结果不易理解。  
适用数据类型：数值型和标称型数据。  

## 9.2 连续和离散型特征的树的构建

In [1]:
class treeNode():
    def __init__(self, feat, val, right, left):
        featureToSplitOn = feat
        valueOfSplit = val
        rightBranch = right
        leftBranch = left

###### 程序清单9-1 CART算法的实现代码

In [23]:
import numpy as np

def loadDataSet(filename):
    dataMat = []
    fr = open(filename)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = list(map(np.float, curLine))
        dataMat.append(fltLine)
    return dataMat

def binSplitDataSet(dataSet, feature, value):
#     mat0 = dataSet[np.nonzero(dataSet[:, feature] > value)[0], :][0]
    mat0 = dataSet[np.nonzero(dataSet[:, feature] > value)[0], :]
#     mat1 = dataSet[np.nonzero(dataSet[:, feature] <= value)[0], :][0]
    mat1 = dataSet[np.nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1

def regLeaf(dataSet):#returns the value used for each leaf
    return np.mean(dataSet[:,-1])

def regErr(dataSet):
    return np.var(dataSet[:,-1]) * np.shape(dataSet)[0]

def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat == None:
        return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

In [10]:
testMat = np.mat(np.eye(4))
testMat

matrix([[ 1.,  0.,  0.,  0.],
        [ 0.,  1.,  0.,  0.],
        [ 0.,  0.,  1.,  0.],
        [ 0.,  0.,  0.,  1.]])

In [11]:
mat0, mat1 = binSplitDataSet(testMat, 1, 0.5)

In [12]:
mat0

matrix([[ 0.,  1.,  0.,  0.]])

In [13]:
mat1

matrix([[ 1.,  0.,  0.,  0.],
        [ 0.,  0.,  1.,  0.],
        [ 0.,  0.,  0.,  1.]])

## 9.3 将CART算法用于回归

#### 9.3.1 构建树

###### 程序清单9-2 回归树的切分函数

In [21]:
def regLeaf(dataSet):
    return np.mean(dataSet[:, -1])

def regErr(dataSet):
    return np.var(dataSet[:, -1]) * np.shape(dataSet)[0]

def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    tolS = ops[0]
    tolN = ops[1]
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    m, n = np.shape(dataSet)
    S = errType(dataSet)
    bestS = np.inf
    bestIndex = 0
    bestValue = 0
    for featIndex in range(n - 1):
        for splitVal in set(dataSet[:, featIndex]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < beatS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue

In [18]:
myDat = loadDataSet('./Data/ex00.txt')
myMat = np.mat(myDat)
myMat

matrix([[  3.60980000e-02,   1.55096000e-01],
        [  9.93349000e-01,   1.07755300e+00],
        [  5.30897000e-01,   8.93462000e-01],
        [  7.12386000e-01,   5.64858000e-01],
        [  3.43554000e-01,  -3.71700000e-01],
        [  9.80160000e-02,  -3.32760000e-01],
        [  6.91115000e-01,   8.34391000e-01],
        [  9.13580000e-02,   9.99350000e-02],
        [  7.27098000e-01,   1.00056700e+00],
        [  9.51949000e-01,   9.45255000e-01],
        [  7.68596000e-01,   7.60219000e-01],
        [  5.41314000e-01,   8.93748000e-01],
        [  1.46366000e-01,   3.42830000e-02],
        [  6.73195000e-01,   9.15077000e-01],
        [  1.83510000e-01,   1.84843000e-01],
        [  3.39563000e-01,   2.06783000e-01],
        [  5.17921000e-01,   1.49358600e+00],
        [  7.03755000e-01,   1.10167800e+00],
        [  8.30700000e-03,   6.99760000e-02],
        [  2.43909000e-01,  -2.94670000e-02],
        [  3.06964000e-01,  -1.77321000e-01],
        [  3.64920000e-02,   4.081

In [24]:
createTree(myMat)

TypeError: unhashable type: 'matrix'