# 9.树回归算法

In [1]:
from numpy import *

## 9.1.读取数据

In [10]:
def loadDataSet(fileName):
    """
    Load a tab-separated numeric text file into a list of rows.

    Args:
        fileName -- path of the file to read
    Returns:
        dataMat -- list of rows, each row being a list of floats
    Raises:
        ValueError -- if a non-blank line contains a field that is not a number
    """
    dataMat = []
    # The context manager closes the file even when a line fails to parse.
    with open(fileName) as fr:
        for line in fr:
            stripped = line.strip()
            # Skip blank lines (e.g. a trailing empty line at the end of the file).
            if not stripped:
                continue
            # Fields are separated by tabs.
            dataMat.append([float(field) for field in stripped.split('\t')])
    return dataMat

In [73]:
def binSplitDataSet(dataSet, feature, value):
    """
    Partition the rows of a data set in two by thresholding one column.

    Args:
        dataSet -- 2-D data set to split
        feature -- column index of the feature to test
        value -- threshold compared against that column
    Returns:
        mat0 -- rows whose feature value is greater than `value`
        mat1 -- rows whose feature value is less than or equal to `value`
    """
    column = dataSet[:, feature]
    # nonzero(...)[0] yields the row indices where the condition holds.
    above_rows = nonzero(column > value)[0]
    remaining_rows = nonzero(column <= value)[0]
    return dataSet[above_rows, :], dataSet[remaining_rows, :]

## 9.2.CART算法
### 9.2.1.支持函数

In [5]:
def regLeaf(dataSet):
    """Build a regression-tree leaf: the mean of the last (target) column."""
    targets = dataSet[:, -1]
    return mean(targets)

In [6]:
def regErr(dataSet):
    """Return the total squared error of the last column: variance times row count."""
    row_count = shape(dataSet)[0]
    target_variance = var(dataSet[:, -1])
    return target_variance * row_count

In [74]:
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """
    Search every feature/value pair for the binary split with the lowest error.

    Args:
        dataSet -- data set (matrix) whose last column is the target
        leafType -- function that builds a leaf value from a data set
        errType -- function that measures the error of a data set
        ops -- tuple (tolS, tolN): minimum error reduction required to split,
               and minimum number of rows allowed on each side of a split
    Returns:
        (bestIndex, bestValue) -- feature index and threshold of the best split,
        or (None, leaf) when the node should become a leaf instead
    """
    min_gain, min_rows = ops[0], ops[1]

    # All targets identical: nothing to gain from splitting.
    distinct_targets = set(dataSet[:, -1].T.tolist()[0])
    if len(distinct_targets) == 1:
        return None, leafType(dataSet)

    n_cols = shape(dataSet)[1]
    base_err = errType(dataSet)
    best_err, best_feat, best_val = inf, 0, 0

    # The last column is the target, so only the first n_cols - 1 are features.
    for feat in range(n_cols - 1):
        candidates = set(dataSet[:, feat].A.squeeze().tolist())
        for candidate in candidates:
            upper, lower = binSplitDataSet(dataSet, feat, candidate)
            # Reject splits that leave either side with too few rows.
            if shape(upper)[0] < min_rows or shape(lower)[0] < min_rows:
                continue
            split_err = errType(upper) + errType(lower)
            if split_err < best_err:
                best_err, best_feat, best_val = split_err, feat, candidate

    # Not enough improvement over the unsplit node: make it a leaf.
    if (base_err - best_err) < min_gain:
        return None, leafType(dataSet)

    # Final guard on the chosen split's partition sizes.
    upper, lower = binSplitDataSet(dataSet, best_feat, best_val)
    if shape(upper)[0] < min_rows or shape(lower)[0] < min_rows:
        return None, leafType(dataSet)
    return best_feat, best_val

### 9.2.2.生成树的算法

In [8]:
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """
    Recursively build a tree by repeatedly applying the best binary split.

    Args:
        dataSet -- data set (matrix) whose last column is the target
        leafType -- function that builds a leaf value from a data set
        errType -- function that measures the error of a data set
        ops -- parameter tuple forwarded to chooseBestSplit
    Returns:
        retTree -- a dict with keys 'spInd', 'spVal', 'left' and 'right',
                   or a bare leaf value when no split is made
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # chooseBestSplit signals "stop here" with feat=None; use an identity check
    # so the test never goes through an overloaded == operator.
    if feat is None:
        return val
    # 'left' receives the rows above the threshold, 'right' the rest.
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    return {
        'spInd': feat,
        'spVal': val,
        'left': createTree(lSet, leafType, errType, ops),
        'right': createTree(rSet, leafType, errType, ops),
    }

测试

In [75]:
# Fit a regression tree with the default settings on ex00.txt.
myDat = loadDataSet("ex00.txt")
# asmatrix replaces the `mat` alias, which is no longer exported by NumPy 2.x.
myMat = asmatrix(myDat)
createTree(myMat)

{'spInd': 0,
 'spVal': 0.48813,
 'left': 1.0180967672413792,
 'right': -0.04465028571428572}

在ex0.txt上测试

In [76]:
# Fit a regression tree with the default settings on ex0.txt.
myDat1 = loadDataSet("ex0.txt")
# asmatrix replaces the `mat` alias, which is no longer exported by NumPy 2.x.
myMat1 = asmatrix(myDat1)
createTree(myMat1)

{'spInd': 1,
 'spVal': 0.39435,
 'left': {'spInd': 1,
  'spVal': 0.582002,
  'left': {'spInd': 1,
   'spVal': 0.797583,
   'left': 3.9871632,
   'right': 2.9836209534883724},
  'right': 1.980035071428571},
 'right': {'spInd': 1,
  'spVal': 0.197834,
  'left': 1.0289583666666666,
  'right': -0.023838155555555553}}

## 9.3.树剪枝
### 9.3.1.支持函数

In [77]:
def isTree(obj):
    """Return True when `obj` is a tree node (a dict) rather than a leaf value."""
    # isinstance also recognises dict subclasses, unlike comparing the type name.
    return isinstance(obj, dict)

In [78]:
def getMean(tree):
    """
    Collapse a tree into one value: the average of its two (collapsed) branches.

    Subtrees are replaced in place by their collapsed value, so `tree` is
    modified by this call.
    """
    # Right branch first, then left.
    for side in ('right', 'left'):
        if isTree(tree[side]):
            tree[side] = getMean(tree[side])
    return (tree['left'] + tree['right']) / 2.0

### 9.3.2.剪枝函数

In [81]:
def prune(tree, testData):
    """
    Post-prune a tree bottom-up using held-out test data.

    Args:
        tree -- tree (dict node) to prune; modified in place
        testData -- test data routed through the tree
    Returns:
        a single merged leaf value when the node is collapsed,
        otherwise the (possibly modified) tree itself
    """
    # No test data reaches this node: collapse the whole subtree.
    if shape(testData)[0] == 0:
        return getMean(tree)

    # Route the test rows the same way the node routes data.
    upper, lower = binSplitDataSet(testData, tree['spInd'], tree['spVal'])

    # Prune the children first.
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], upper)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], lower)

    # Merging is only considered once both children are leaves.
    if isTree(tree['left']) or isTree(tree['right']):
        return tree

    left_leaf, right_leaf = tree['left'], tree['right']
    # Squared test error when the two leaves are kept separate.
    err_split = sum(power(upper[:, -1] - left_leaf, 2)) + \
        sum(power(lower[:, -1] - right_leaf, 2))
    merged = (left_leaf + right_leaf) / 2.0
    # Squared test error when both leaves are replaced by their average.
    err_merged = sum(power(testData[:, -1] - merged, 2))
    if err_merged < err_split:
        print("merging")
        return merged
    return tree

测试

In [85]:
# Grow a tree on ex2.txt with ops=(0, 1), then prune it against ex2test.txt.
# asmatrix replaces the `mat` alias, which is no longer exported by NumPy 2.x.
myDat2 = loadDataSet("ex2.txt")
myMat2 = asmatrix(myDat2)
myTree = createTree(myMat2, ops=(0, 1))
myDatTest = loadDataSet("ex2test.txt")
myMat2Test = asmatrix(myDatTest)
prune(myTree, myMat2Test)

merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging


{'spInd': 0,
 'spVal': 0.499171,
 'left': {'spInd': 0,
  'spVal': 0.729397,
  'left': {'spInd': 0,
   'spVal': 0.952833,
   'left': {'spInd': 0,
    'spVal': 0.965969,
    'left': 92.5239915,
    'right': {'spInd': 0,
     'spVal': 0.956951,
     'left': {'spInd': 0,
      'spVal': 0.958512,
      'left': {'spInd': 0,
       'spVal': 0.960398,
       'left': 112.386764,
       'right': 123.559747},
      'right': 135.837013},
     'right': 111.2013225}},
   'right': {'spInd': 0,
    'spVal': 0.759504,
    'left': {'spInd': 0,
     'spVal': 0.763328,
     'left': {'spInd': 0,
      'spVal': 0.769043,
      'left': {'spInd': 0,
       'spVal': 0.790312,
       'left': {'spInd': 0,
        'spVal': 0.806158,
        'left': {'spInd': 0,
         'spVal': 0.815215,
         'left': {'spInd': 0,
          'spVal': 0.833026,
          'left': {'spInd': 0,
           'spVal': 0.841547,
           'left': {'spInd': 0,
            'spVal': 0.841625,
            'left': {'spInd': 0,
            

## 9.5.模型树

In [86]:
def linearSolve(dataSet):
    """
    Fit a least-squares linear model (with intercept) via the normal equations.

    Args:
        dataSet -- 2-D data (matrix, array or nested list); the last column is
                   the target, the other columns are the features
    Returns:
        ws -- column matrix of coefficients, intercept first
        X -- design matrix: a column of ones followed by the feature columns
        Y -- target column
    Raises:
        NameError -- if X.T * X has a determinant of exactly 0.0
    """
    # asmatrix (rather than the `mat` alias removed in NumPy 2.0) also lets
    # plain arrays and nested lists be used; a matrix input is not copied.
    data = asmatrix(dataSet)
    m, n = shape(data)
    # Column 0 stays all ones and acts as the intercept term.
    X = asmatrix(ones((m, n)))
    X[:, 1:n] = data[:, 0:n-1]
    Y = data[:, -1]
    xTx = X.T * X
    if linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse,\n'
                        '        try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y

In [87]:
def modelLeaf(dataSet):
    """Build a model-tree leaf: the linear coefficients fitted to `dataSet`."""
    # linearSolve returns (ws, X, Y); only the coefficients are kept.
    return linearSolve(dataSet)[0]

In [88]:
def modelErr(dataSet):
    """Return the sum of squared residuals of a linear fit to `dataSet`."""
    weights, features, targets = linearSolve(dataSet)
    predictions = features * weights
    residuals = targets - predictions
    return sum(power(residuals, 2))

测试

In [90]:
# Build a model tree (a linear model in each leaf) on exp2.txt.
# asmatrix replaces the `mat` alias, which is no longer exported by NumPy 2.x.
myMat2 = asmatrix(loadDataSet("exp2.txt"))
createTree(myMat2, modelLeaf, modelErr, (1, 10))

{'spInd': 0, 'spVal': 0.285477, 'left': matrix([[1.69855694e-03],
         [1.19647739e+01]]), 'right': matrix([[3.46877936],
         [1.18521743]])}

## 9.6.用树回归进行预测

In [92]:
def regTreeEval(model, inDat):
    """Evaluate a regression-tree leaf: the leaf value as a float (`inDat` is unused)."""
    prediction = float(model)
    return prediction

def modelTreeEval(model, inDat):
    """
    Evaluate a model-tree leaf: apply its linear coefficients to one sample.

    Args:
        model -- column matrix of coefficients, intercept first
        inDat -- 1 x n matrix holding the sample's feature values
    Returns:
        the predicted value as a float
    """
    n = shape(inDat)[1]
    # Prepend a constant 1 so the first coefficient acts as the intercept.
    # asmatrix is used because the `mat` alias was removed in NumPy 2.0.
    X = asmatrix(ones((1, n + 1)))
    X[:, 1:n + 1] = inDat
    # X * model is a 1x1 matrix; take the element explicitly, since calling
    # float() on a non-0-d array is deprecated in recent NumPy releases.
    return float((X * model)[0, 0])

def treeForeCast(tree, inData, modelEval=regTreeEval):
    """
    Predict the target of one sample by walking it down the tree.

    Args:
        tree -- tree node (dict) or a bare leaf
        inData -- feature values of one sample (1 x n matrix or flat sequence)
        modelEval -- function (leaf, inData) -> float used to evaluate a leaf
    Returns:
        the prediction produced by modelEval at the leaf that is reached
    """
    if not isTree(tree):
        return modelEval(tree, inData)
    # Flatten before indexing: on a 1 x n matrix, inData[k] selects a row rather
    # than feature k, which breaks as soon as there is more than one feature.
    feature_value = asarray(inData).ravel()[tree['spInd']]
    # Values above the threshold go left, matching binSplitDataSet.
    branch = tree['left'] if feature_value > tree['spVal'] else tree['right']
    # The leaf check at the top of the recursive call handles leaf branches.
    return treeForeCast(branch, inData, modelEval)
        
def createForeCast(tree, testData, modelEval=regTreeEval):
    """
    Predict a target value for every row of `testData`.

    Args:
        tree -- tree used for prediction
        testData -- test samples, one per row
        modelEval -- leaf evaluation function forwarded to treeForeCast
    Returns:
        yHat -- m x 1 matrix of predictions, one per test row
    """
    m = len(testData)
    # asmatrix replaces the `mat` alias, which was removed in NumPy 2.0.
    yHat = asmatrix(zeros((m, 1)))
    for i in range(m):
        yHat[i, 0] = treeForeCast(tree, asmatrix(testData[i]), modelEval)
    return yHat

测试

In [94]:
# Fit a regression tree on the training file and score its test-set predictions
# by their correlation with the true targets.
# asmatrix replaces the `mat` alias, which is no longer exported by NumPy 2.x.
trainMat = asmatrix(loadDataSet("bikeSpeedVsIq_train.txt"))
testMat = asmatrix(loadDataSet("bikeSpeedVsIq_test.txt"))
myTree = createTree(trainMat, ops=(1,20))
yHat = createForeCast(myTree, testMat[:,0])
corrcoef(yHat, testMat[:,1], rowvar=0)[0,1]

0.9640852318222141

In [98]:
# Fit a model tree (modelLeaf / modelErr) on trainMat, predict the first column
# of testMat with modelTreeEval, and report the correlation with the targets.
myTree = createTree(trainMat, modelLeaf, modelErr, ops=(1,20))
yHat = createForeCast(myTree, testMat[:,0], modelTreeEval)
corrcoef(yHat, testMat[:,1], rowvar=0)[0,1]

0.9760412191380593

In [99]:
# Baseline: fit one linear model to all of trainMat and display its coefficients.
ws, X, Y = linearSolve(trainMat)
ws

matrix([[37.58916794],
        [ 6.18978355]])

In [100]:
# Overwrite yHat in place with the single linear model's prediction for every
# test row (intercept ws[0, 0], slope ws[1, 0]).
yHat[:, :] = testMat[:, 0] * ws[1, 0] + ws[0, 0]

corrcoef(yHat, testMat[:,1], rowvar=0)[0, 1]

0.9434684235674766