In [1]:
# 回归树的切分函数
def regLeaf(dataSet):
    """Leaf value of a regression tree: the mean of the target (last) column."""
    targets = dataSet[:, -1]
    return mean(targets)

def regErr(dataSet):
    """Total squared error of the target (last) column around its mean.

    numpy's var() is the population variance (ddof=0), so multiplying it by
    the row count gives the sum of squared deviations.
    """
    targets = dataSet[:, -1]
    rowCount = shape(dataSet)[0]
    return var(targets) * rowCount

def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """Search every feature/threshold pair for the split with the lowest error.

    Args:
        dataSet: numpy matrix; the last column is the target, the others features.
        leafType: function(dataSet) -> leaf value, used when no split is made.
        errType: function(dataSet) -> error of a node (lower is better).
        ops: (tolS, tolN) = (minimum error decrease a split must achieve,
             minimum number of rows allowed on each side of a split).

    Returns:
        (featureIndex, threshold) for the best split, or (None, leafValue)
        when the node should become a leaf.
    """
    minErrDrop, minLeafSize = ops[0], ops[1]

    def tooSmall(part0, part1):
        # a split is rejected when either side has fewer than minLeafSize rows
        return shape(part0)[0] < minLeafSize or shape(part1)[0] < minLeafSize

    # every target value identical: nothing to gain by splitting
    distinctTargets = set(dataSet[:, -1].T.tolist()[0])
    if len(distinctTargets) == 1:
        return None, leafType(dataSet)

    _, numCols = shape(dataSet)
    baseErr = errType(dataSet)
    best = (inf, 0, 0)  # (error, feature index, threshold)
    for featIndex in range(numCols - 1):
        # candidate thresholds: the distinct values of this feature
        candidates = set(dataSet[:, featIndex].T.tolist()[0])
        for threshold in candidates:
            above, below = binSplitDataSet(dataSet, featIndex, threshold)
            if tooSmall(above, below):
                continue
            splitErr = errType(above) + errType(below)
            if splitErr < best[0]:  # strict: the first minimum found wins
                best = (splitErr, featIndex, threshold)
    bestErr, bestIndex, bestValue = best

    # the best split does not reduce the error enough
    if baseErr - bestErr < minErrDrop:
        return None, leafType(dataSet)
    if tooSmall(*binSplitDataSet(dataSet, bestIndex, bestValue)):
        return None, leafType(dataSet)
    return bestIndex, bestValue

In [30]:
# CART算法的实现代码
from numpy import *

def loadDataSet(fileName):
    """Read a tab-separated numeric text file into a list of float rows.

    Each non-blank line becomes one list of floats; blank lines (for example
    a trailing empty line at the end of the file) are skipped.

    Args:
        fileName: path of the tab-separated file.

    Returns:
        list[list[float]], one inner list per data line.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a field is not a valid float.
    """
    dataMat = []
    # `with` guarantees the file is closed, even if a line fails to parse
    with open(fileName) as fr:
        for line in fr:
            stripped = line.strip()
            if not stripped:  # float('') would raise on blank lines
                continue
            dataMat.append([float(field) for field in stripped.split('\t')])
    return dataMat

def binSplitDataSet(dataSet, feature, value):
    """Split the rows of dataSet on column `feature` at threshold `value`.

    Returns:
        (rows with dataSet[:, feature] > value,
         rows with dataSet[:, feature] <= value), each keeping all columns.
    """
    column = dataSet[:, feature]
    aboveRows = nonzero(column > value)[0]
    atOrBelowRows = nonzero(column <= value)[0]
    return dataSet[aboveRows, :], dataSet[atOrBelowRows, :]

def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """Recursively grow a CART tree (regression or model tree).

    Args:
        dataSet: numpy matrix; the last column is the target.
        leafType: leaf constructor (regLeaf for regression trees,
            modelLeaf for model trees).
        errType: node error function matching leafType (regErr / modelErr).
        ops: (tolS, tolN) pre-pruning parameters forwarded to chooseBestSplit.

    Returns:
        A leaf value when chooseBestSplit declines to split, otherwise a dict
        {'spInd', 'spVal', 'left', 'right'}. 'left' holds the rows whose
        feature value is > spVal, 'right' the rows <= spVal.
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # `is None`, not `== None`: chooseBestSplit signals "make a leaf" with None
    if feat is None:
        return val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    # dict literal values are evaluated in order, so the left subtree is
    # built before the right one, as before
    return {
        'spInd': feat,
        'spVal': val,
        'left': createTree(lSet, leafType, errType, ops),
        'right': createTree(rSet, leafType, errType, ops),
    }

In [3]:
testMat = mat(eye(4))  # 4x4 identity matrix as a tiny split demo (name is reused later for bike test data)

In [4]:
testMat

matrix([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])

In [5]:
mat0, mat1 = binSplitDataSet(testMat, 1, 0.5)  # split on column 1 at threshold 0.5

In [6]:
mat0  # rows whose column 1 is > 0.5

matrix([[0., 1., 0., 0.]])

In [7]:
mat1  # rows whose column 1 is <= 0.5

matrix([[1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])

In [8]:
myDat = loadDataSet('ex00.txt')

In [9]:
myMat = mat(myDat)

In [31]:
createTree(myMat)  # regression tree with the default pre-pruning ops=(1,4)

{'left': 1.0180967672413792,
 'right': -0.04465028571428572,
 'spInd': 0,
 'spVal': 0.48813}

In [11]:
myDat1 = loadDataSet('ex0.txt')

In [12]:
myMat1 = mat(myDat1)

In [32]:
createTree(myMat1)  # default ops=(1,4); the splits below are all on column 1

{'left': {'left': {'left': 3.9871632,
   'right': 2.9836209534883724,
   'spInd': 1,
   'spVal': 0.797583},
  'right': 1.980035071428571,
  'spInd': 1,
  'spVal': 0.582002},
 'right': {'left': 1.0289583666666666,
  'right': -0.023838155555555553,
  'spInd': 1,
  'spVal': 0.197834},
 'spInd': 1,
 'spVal': 0.39435}

In [33]:
createTree(myMat, ops=(0,1))  # tolS=0, tolN=1 effectively disables pre-pruning, so the tree grows very deep

{'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 1.035533,
                'right': 1.077553,
                'spInd': 0,
                'spVal': 0.993349},
               'right': {'left': 0.744207,
                'right': 1.069062,
                'spInd': 0,
                'spVal': 0.988852},
               'spInd': 0,
               'spVal': 0.989888},
              'right': 1.227946,
              'spInd': 0,
              'spVal': 0.985425},
             'right': {'left': {'left': 0.862911,
               'right': 0.673579,
               'spInd': 0,
               'spVal': 0.975022},
              'right': {'left': {'left': 1.06469,
                'right': {'left': 0.945255,
                 'right': 1.022906,
                 'spInd': 0,
                 'spVal': 0.950153},
                'spInd': 0,
                'spVal': 0.951949},
               'right': {'left': 0.631862,
 

In [34]:
myDat2 = loadDataSet('ex2.txt')

In [35]:
myMat2 = mat(myDat2)

In [36]:
createTree(myMat2)  # same default ops=(1,4) as above, yet this data gives a much larger tree

{'left': {'left': {'left': {'left': 105.24862350000001,
    'right': 112.42895575000001,
    'spInd': 0,
    'spVal': 0.958512},
   'right': {'left': {'left': {'left': {'left': 87.3103875,
       'right': {'left': {'left': 96.452867,
         'right': {'left': 104.825409,
          'right': {'left': 95.181793,
           'right': 102.25234449999999,
           'spInd': 0,
           'spVal': 0.872883},
          'spInd': 0,
          'spVal': 0.892999},
         'spInd': 0,
         'spVal': 0.910975},
        'right': 95.27584316666666,
        'spInd': 0,
        'spVal': 0.85497},
       'spInd': 0,
       'spVal': 0.944221},
      'right': {'left': 81.110152,
       'right': 88.78449880000001,
       'spInd': 0,
       'spVal': 0.811602},
      'spInd': 0,
      'spVal': 0.833026},
     'right': 102.35780185714285,
     'spInd': 0,
     'spVal': 0.790312},
    'right': 78.08564325,
    'spInd': 0,
    'spVal': 0.759504},
   'spInd': 0,
   'spVal': 0.952833},
  'right': {'left': {'l

In [40]:
createTree(myMat2, ops=(10000,4))  # pre-pruning: a huge tolS rejects any split that lowers the error by less than 10000

{'left': 101.35815937735848,
 'right': -2.637719329787234,
 'spInd': 0,
 'spVal': 0.499171}

In [42]:
# 回归树后剪枝函数

# Test whether a node is an internal (split) node rather than a leaf.
def isTree(obj):
    """Return True if obj is an internal tree node.

    createTree builds internal nodes as dicts. Leaves are plain values: floats
    from regLeaf, or weight matrices from modelLeaf. A dict check is therefore
    enough. isinstance is the idiomatic type test; it replaces comparing
    type names as strings.
    """
    return isinstance(obj, dict)

# Collapse a (sub)tree into a single leaf value.
def getMean(tree):
    """Collapse a tree into one value by recursively averaging children.

    Each internal node is replaced by the average of its two children's
    values, bottom-up. Deeper leaves therefore get less weight; the result
    is not the plain mean of all leaf values.

    Unlike the previous version, this does not modify `tree`. The old code
    overwrote subtrees with their means in place, so calling prune() at the
    root with empty test data destroyed the caller's tree.

    Args:
        tree: dict node produced by createTree.

    Returns:
        The collapsed value (same type as the leaves, e.g. float).
    """
    left, right = tree['left'], tree['right']
    if isTree(left):
        left = getMean(left)
    if isTree(right):
        right = getMean(right)
    return (left + right) / 2.0

# Post-prune a regression tree against held-out test data.
def prune(tree, testData):
    """Bottom-up post-pruning of a regression tree using held-out data.

    A node whose two children are both leaves is merged into their average
    when that lowers the squared error on the test rows routed to the node.
    Children are replaced in place by their pruned versions, and "merging" is
    printed for every merge performed.

    Args:
        tree: dict node produced by createTree.
        testData: matrix of test rows; the last column is the target.

    Returns:
        The (possibly modified) dict node, or the merged leaf value.
    """
    # no test rows reach this node: collapse the whole subtree into one value
    if shape(testData)[0] == 0:
        return getMean(tree)

    # route the test rows exactly as the tree routes training rows
    lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rSet)

    # only a node whose children are both leaves is a merge candidate
    if isTree(tree['left']) or isTree(tree['right']):
        return tree

    leftErr = sum(power(lSet[:, -1] - tree['left'], 2))
    rightErr = sum(power(rSet[:, -1] - tree['right'], 2))
    errorNoMerge = leftErr + rightErr
    treeMean = (tree['left'] + tree['right']) / 2.0
    errorMerge = sum(power(testData[:, -1] - treeMean, 2))
    if errorMerge < errorNoMerge:
        print("merging")
        return treeMean
    return tree

In [43]:
# Grow the largest possible tree (tolS=0, tolN=1) so post-pruning has work to do.
myTree = createTree(myMat2, ops=(0,1))

In [44]:
myDatTest = loadDataSet('ex2test.txt')

In [46]:
myMat2Test = mat(myDatTest)

In [47]:
# NOTE: prune() rewrites myTree's children in place; re-running this cell prunes an already-pruned tree.
prune(myTree, myMat2Test)

merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging


{'left': {'left': {'left': {'left': 92.5239915,
    'right': {'left': {'left': {'left': 112.386764,
       'right': 123.559747,
       'spInd': 0,
       'spVal': 0.960398},
      'right': 135.837013,
      'spInd': 0,
      'spVal': 0.958512},
     'right': 111.2013225,
     'spInd': 0,
     'spVal': 0.956951},
    'spInd': 0,
    'spVal': 0.965969},
   'right': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 96.41885225,
              'right': 69.318649,
              'spInd': 0,
              'spVal': 0.948822},
             'right': {'left': {'left': 110.03503850000001,
               'right': {'left': 65.548418,
                'right': {'left': 115.753994,
                 'right': {'left': {'left': 94.3961145,
                   'right': 85.005351,
                   'spInd': 0,
                   'spVal': 0.912161},
                  'right': {'left': {'left': 106.814667,
                    'right': 118.513475,
               

In [48]:
# 模型树的叶节点生成函数
def linearSolve(dataSet):
    """Fit ordinary least squares (with intercept) to dataSet.

    The last column of dataSet is the target and all other columns are
    features. A column of ones is prepended to the features, so ws[0] is the
    intercept.

    Args:
        dataSet: numpy matrix of shape (m, n).

    Returns:
        (ws, X, Y): ws is the (n, 1) weight matrix, X is the (m, n) design
        matrix with the leading ones column, and Y is the target column.

    Raises:
        NameError: if X.T * X is exactly singular. The exception type is kept
        for callers that already catch NameError.
    """
    m, n = shape(dataSet)
    # asmatrix instead of mat: np.mat was removed in NumPy 2.0
    X = asmatrix(ones((m, n)))  # column 0 stays 1.0 -> intercept term
    X[:, 1:n] = dataSet[:, 0:n-1]
    Y = dataSet[:, -1]
    xTx = X.T * X
    if linalg.det(xTx) == 0.0:
        raise NameError("This matrix is singular, cannot do inverse,\n"
                        "try increasing the second value of ops")
    ws = xTx.I * (X.T * Y)
    return ws, X, Y

def modelLeaf(dataSet):
    """Leaf of a model tree: the least-squares weight matrix from linearSolve."""
    return linearSolve(dataSet)[0]

def modelErr(dataSet):
    """Node error of a model tree: squared residuals of the node's own linear fit."""
    ws, X, Y = linearSolve(dataSet)
    residuals = Y - X * ws
    return sum(power(residuals, 2))

In [49]:
myMat2 = mat(loadDataSet('exp2.txt'))  # NOTE: rebinds myMat2 (previously loaded from ex2.txt)

In [50]:
createTree(myMat2, modelLeaf, modelErr)  # model tree: each leaf holds a linear-regression weight matrix

{'left': matrix([[1.69855694e-03],
         [1.19647739e+01]]), 'right': matrix([[3.46877936],
         [1.18521743]]), 'spInd': 0, 'spVal': 0.285477}

示例：树回归与标准回归的比较

In [51]:
# 用树回归进行预测的代码
def regTreeEval(model, inDat):
    """Evaluate a regression-tree leaf: the leaf value as a float (inDat is ignored)."""
    leafValue = float(model)
    return leafValue

def modelTreeEval(model, inDat):
    """Evaluate a model-tree leaf: apply the linear model `model` to inDat.

    Args:
        model: (n+1, 1) weight matrix from modelLeaf; model[0] is the intercept.
        inDat: 1 x n feature row (2-D, e.g. a numpy matrix row).

    Returns:
        The prediction as a Python float.
    """
    n = shape(inDat)[1]
    # asmatrix instead of mat: np.mat was removed in NumPy 2.0
    X = asmatrix(ones((1, n + 1)))  # leading 1.0 pairs with the intercept weight
    X[:, 1:n + 1] = inDat
    # take the single element explicitly: float() on an array with ndim > 0
    # is deprecated since NumPy 1.25
    return float((X * model)[0, 0])

def treeForeCast(tree, inData, modelEval=regTreeEval):
    """Predict the target of one input row by walking the tree to a leaf.

    Args:
        tree: a tree from createTree (dict node) or a bare leaf value.
        inData: one feature row, e.g. the 1 x n matrix passed by
            createForeCast. It is flattened only for the split comparisons
            and is handed unchanged to modelEval.
        modelEval: function(leaf, inData) -> prediction (regTreeEval for
            regression trees, modelTreeEval for model trees).

    Returns:
        modelEval's prediction for the leaf that inData reaches.
    """
    # Flatten so tree['spInd'] picks a feature. On a 1 x n matrix, inData[k]
    # selects a *row*, so the old code raised IndexError whenever spInd > 0.
    features = asarray(inData).ravel()
    node = tree
    while isTree(node):
        # same convention as binSplitDataSet: "> spVal" goes to the left child
        branch = 'left' if features[node['spInd']] > node['spVal'] else 'right'
        node = node[branch]
    return modelEval(node, inData)
        
def createForeCast(tree, testData, modelEval=regTreeEval):
    """Predict every row of testData with treeForeCast.

    Args:
        tree: tree from createTree.
        testData: m rows of features only, without the target column
            (the notebook passes e.g. testMat[:,0]).
        modelEval: leaf evaluator forwarded to treeForeCast.

    Returns:
        m x 1 numpy matrix of predictions.
    """
    m = len(testData)
    # asmatrix instead of mat: np.mat was removed in NumPy 2.0
    yHat = asmatrix(zeros((m, 1)))
    for i in range(m):
        # make row i a 1 x n matrix, the 2-D form modelTreeEval expects
        yHat[i, 0] = treeForeCast(tree, asmatrix(testData[i]), modelEval)
    return yHat

In [52]:
trainMat = mat(loadDataSet('bikeSpeedVsIq_train.txt'))
testMat = mat(loadDataSet('bikeSpeedVsIq_test.txt'))  # NOTE: rebinds testMat (previously the 4x4 identity demo)

In [53]:
myTree = createTree(trainMat, ops=(1,20)) # regression tree; every leaf must keep at least 20 training rows

In [54]:
yHat = createForeCast(myTree, testMat[:,0])  # column 0 is the only feature passed in

In [59]:
corrcoef(yHat, testMat[:,1], rowvar=False)[0, 1] # rowvar=False: each column is a variable, rows are observations

0.9640852318222141

In [73]:
myTree = createTree(trainMat, modelLeaf, modelErr, (1,20)) # 模型树

In [75]:
yHat = createForeCast(myTree, testMat[:,0], modelTreeEval)

In [76]:
corrcoef(yHat, testMat[:,1], rowvar=0)[0,1]

0.9760412191380593

In [77]:
# Baseline: a single ordinary least-squares line over the whole training set.
ws, X, Y = linearSolve(trainMat)

In [78]:
ws  # ws[0,0] is the intercept, ws[1,0] the slope for column 0

matrix([[37.58916794],
        [ 6.18978355]])

In [79]:
for i in range(shape(testMat)[0]):
    yHat[i] = testMat[i,0] * ws[1,0] + ws[0,0]

In [80]:
corrcoef(yHat, testMat[:,1], rowvar=False)[0,1]

0.9434684235674763

示例：利用GUI对回归树调优

TODO：本节待学习 PyQt 之后再补充完善。