In [1]:
import numpy as np

In [2]:
def loadDataSet(fileName):
    """Load a tab-separated numeric text file into a list of float rows.

    Parameters
    ----------
    fileName : str or path-like
        Path of a text file whose lines hold tab-separated numeric fields.

    Returns
    -------
    list[list[float]]
        One inner list per non-blank line, every field converted with ``float``.

    Raises
    ------
    OSError
        If the file cannot be opened.
    ValueError
        If a field on a non-blank line cannot be parsed as a float.
    """
    dataMat = []
    # The context manager closes the handle even when a line fails to parse
    with open(fileName) as fr:
        for line in fr:
            stripped = line.strip()
            # Blank lines (e.g. a trailing empty line) carry no data
            if not stripped:
                continue
            # Convert every field of the row to a float
            dataMat.append([float(field) for field in stripped.split('\t')])
    return dataMat

In [3]:
def binSplitDataSet(dataSet, feature, value):
    """Partition the rows of ``dataSet`` on one feature column.

    Parameters
    ----------
    dataSet : 2-D NumPy matrix or array
    feature : int
        Index of the column to compare.
    value : number
        Threshold for the comparison.

    Returns
    -------
    tuple
        ``(rows where column > value, rows where column <= value)``.
    """
    column = dataSet[:, feature]
    # np.nonzero(...)[0] yields the row indices of the True entries
    rows_above = np.nonzero(column > value)[0]
    rows_not_above = np.nonzero(column <= value)[0]
    return dataSet[rows_above, :], dataSet[rows_not_above, :]

### 树节点

In [4]:
class treeNode():
    """Node of a binary tree that splits on a single feature.

    Attributes
    ----------
    featureToSplitOn
        The feature this node splits on (value passed as ``feat``).
    valueOfSplit
        The split threshold (value passed as ``val``).
    rightBranch
        The right child (value passed as ``right``).
    leftBranch
        The left child (value passed as ``left``).
    """

    def __init__(self, feat, val, right, left):
        # Store on the instance; bare local names would be discarded on return
        self.featureToSplitOn = feat
        self.valueOfSplit = val
        self.rightBranch = right
        self.leftBranch = left

创建树的伪代码：  
找到最佳的待切分特征：  
&ensp;&ensp;&ensp;&ensp;如果该节点不能再分，则存为叶结点  
&ensp;&ensp;&ensp;&ensp;执行二元切分  
&ensp;&ensp;&ensp;&ensp;在右子树调用createTree()方法  
&ensp;&ensp;&ensp;&ensp;在左子树调用createTree()方法

In [5]:
# Build a leaf node
def regLeaf(dataSet):
    """Return the leaf value: the mean of the last column of ``dataSet``."""
    targets = dataSet[:, -1]
    return np.mean(targets)

# Compute the total squared error
def regErr(dataSet):
    """Return the variance of the last column multiplied by the row count.

    Variance times the number of rows equals the sum of squared deviations
    of the last column from its mean.
    """
    rowCount = np.shape(dataSet)[0]
    targetVariance = np.var(dataSet[:, -1])
    return targetVariance * rowCount

def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Recursively build a tree as nested dicts.

    Parameters
    ----------
    dataSet : NumPy matrix
        Data to fit; passed unchanged to ``chooseBestSplit``.
    leafType : callable
        Called by ``chooseBestSplit`` to produce a leaf value from a data set.
    errType : callable
        Called by ``chooseBestSplit`` to score a data set.
    ops : tuple
        ``(tolS, tolN)`` stopping thresholds forwarded to ``chooseBestSplit``.

    Returns
    -------
    dict or leaf value
        The leaf value when ``chooseBestSplit`` reports no split; otherwise a
        dict with keys ``'spInd'`` (feature index), ``'spVal'`` (threshold),
        ``'left'`` (subtree for rows with feature > threshold) and ``'right'``
        (subtree for rows with feature <= threshold).
    """
    # Pick the split; a feature of None means this node becomes a leaf
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # Identity test: feature index 0 is a valid split and must not be confused
    # with "no split"
    if feat is None:
        return val
    # Split the data into two parts and grow a subtree on each
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree = {
        'spInd': feat,
        'spVal': val,
        'left': createTree(lSet, leafType, errType, ops),
        'right': createTree(rSet, leafType, errType, ops),
    }
    return retTree

### 回归树的切分函数

In [6]:
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Search all (feature, value) pairs for the split with the lowest error.

    Parameters
    ----------
    dataSet : NumPy matrix
        Last column is the target; the other columns are features. The code
        uses the matrix-only ``.A`` attribute, so a plain ndarray will not work.
    leafType : callable
        Produces the leaf value returned when no split is made.
    errType : callable
        Scores a data set; lower is better.
    ops : tuple
        ``ops[0]`` is the minimum error reduction required to split,
        ``ops[1]`` the minimum number of rows allowed on each side.

    Returns
    -------
    tuple
        ``(featureIndex, splitValue)`` for the best split, or
        ``(None, leafValue)`` when the node should become a leaf.
    """
    minErrDrop = ops[0]
    minRows = ops[1]
    # Stop when every target value is identical
    targets = dataSet[:, -1].T.tolist()[0]
    if len(set(targets)) == 1:
        return None, leafType(dataSet)
    _, numCols = np.shape(dataSet)
    baseErr = errType(dataSet)
    bestErr, bestFeat, bestVal = np.inf, 0, 0
    # Every column except the last one is a candidate feature
    for featIndex in range(numCols - 1):
        candidates = set(dataSet[:, featIndex].T.A.tolist()[0])
        for splitVal in candidates:
            upper, lower = binSplitDataSet(dataSet, featIndex, splitVal)
            # Ignore splits that leave too few rows on either side
            if np.shape(upper)[0] < minRows or np.shape(lower)[0] < minRows:
                continue
            splitErr = errType(upper) + errType(lower)
            if splitErr < bestErr:
                bestErr, bestFeat, bestVal = splitErr, featIndex, splitVal
    # Stop when the error reduction is too small
    if (baseErr - bestErr) < minErrDrop:
        return None, leafType(dataSet)
    # Stop when the chosen split produces a side that is too small
    upper, lower = binSplitDataSet(dataSet, bestFeat, bestVal)
    if np.shape(upper)[0] < minRows or np.shape(lower)[0] < minRows:
        return None, leafType(dataSet)
    return bestFeat, bestVal

In [7]:
myDat = loadDataSet('ex00.txt')
# np.asmatrix replaces the np.mat alias, which NumPy 2.0 removed
myMat = np.asmatrix(myDat)
createTree(myMat)

{'spInd': 0,
 'spVal': 0.48813,
 'left': 1.0180967672413792,
 'right': -0.04465028571428572}

In [9]:
myDat1 = loadDataSet('ex0.txt')
# np.asmatrix replaces the np.mat alias, which NumPy 2.0 removed
myMat1 = np.asmatrix(myDat1)
createTree(myMat1)

{'spInd': 1,
 'spVal': 0.39435,
 'left': {'spInd': 1,
  'spVal': 0.582002,
  'left': {'spInd': 1,
   'spVal': 0.797583,
   'left': 3.9871632,
   'right': 2.9836209534883724},
  'right': 1.980035071428571},
 'right': {'spInd': 1,
  'spVal': 0.197834,
  'left': 1.0289583666666666,
  'right': -0.023838155555555553}}

In [10]:
# ops=(0, 1) sets tolS=0 and tolN=1: a split is accepted whenever it does not
# increase the error and leaves at least one row on each side
createTree(myMat, ops=(0, 1))

{'spInd': 0,
 'spVal': 0.48813,
 'left': {'spInd': 0,
  'spVal': 0.620599,
  'left': {'spInd': 0,
   'spVal': 0.625336,
   'left': {'spInd': 0,
    'spVal': 0.625791,
    'left': {'spInd': 0,
     'spVal': 0.643601,
     'left': {'spInd': 0,
      'spVal': 0.651376,
      'left': {'spInd': 0,
       'spVal': 0.6632,
       'left': {'spInd': 0,
        'spVal': 0.683921,
        'left': {'spInd': 0,
         'spVal': 0.819823,
         'left': {'spInd': 0,
          'spVal': 0.837522,
          'left': {'spInd': 0,
           'spVal': 0.846455,
           'left': {'spInd': 0,
            'spVal': 0.919384,
            'left': {'spInd': 0,
             'spVal': 0.976414,
             'left': {'spInd': 0,
              'spVal': 0.985425,
              'left': {'spInd': 0,
               'spVal': 0.989888,
               'left': {'spInd': 0,
                'spVal': 0.993349,
                'left': 1.035533,
                'right': 1.077553},
               'right': {'spInd': 0,
        

In [11]:
myDat2 = loadDataSet('ex2.txt')
# np.asmatrix replaces the np.mat alias, which NumPy 2.0 removed
myMat2 = np.asmatrix(myDat2)
createTree(myMat2)

{'spInd': 0,
 'spVal': 0.499171,
 'left': {'spInd': 0,
  'spVal': 0.729397,
  'left': {'spInd': 0,
   'spVal': 0.952833,
   'left': {'spInd': 0,
    'spVal': 0.958512,
    'left': 105.24862350000001,
    'right': 112.42895575000001},
   'right': {'spInd': 0,
    'spVal': 0.759504,
    'left': {'spInd': 0,
     'spVal': 0.790312,
     'left': {'spInd': 0,
      'spVal': 0.833026,
      'left': {'spInd': 0,
       'spVal': 0.944221,
       'left': 87.3103875,
       'right': {'spInd': 0,
        'spVal': 0.85497,
        'left': {'spInd': 0,
         'spVal': 0.910975,
         'left': 96.452867,
         'right': {'spInd': 0,
          'spVal': 0.892999,
          'left': 104.825409,
          'right': {'spInd': 0,
           'spVal': 0.872883,
           'left': 95.181793,
           'right': 102.25234449999999}}},
        'right': 95.27584316666666}},
      'right': {'spInd': 0,
       'spVal': 0.811602,
       'left': 81.110152,
       'right': 88.78449880000001}},
     'right': 102.

## 后剪枝  
prune()伪代码：  
基于已有的树切分测试数据：  
&ensp;&ensp;&ensp;&ensp;如果存在任一子集是一棵树，则在该子集递归剪枝过程  
&ensp;&ensp;&ensp;&ensp;计算将当前两个叶结点合并后的误差  
&ensp;&ensp;&ensp;&ensp;计算不合并的误差  
&ensp;&ensp;&ensp;&ensp;如果合并误差会降低，就将叶结点合并

In [13]:
def isTree(obj):
    """Return True when ``obj`` is a subtree (a dict) rather than a leaf value.

    ``isinstance`` is used instead of comparing the type name as a string, so
    dict subclasses such as ``collections.OrderedDict`` also count as trees.
    """
    return isinstance(obj, dict)

# Walk the tree top-down; wherever leaves are reached, average them
def getMean(tree):
    """Collapse ``tree`` into one number: the mean of its two branch values.

    Any branch that is itself a subtree is first collapsed recursively.
    Side effect: those subtree entries of ``tree`` are overwritten in place
    with their collapsed values.
    """
    # Right branch is handled before left
    for side in ('right', 'left'):
        branch = tree[side]
        if isTree(branch):
            tree[side] = getMean(branch)
    return (tree['left'] + tree['right']) / 2.0

def prune(tree, testData):
    """Post-prune a regression tree against held-out test data.

    Parameters
    ----------
    tree : dict
        Tree with keys ``'spInd'``, ``'spVal'``, ``'left'`` and ``'right'``.
        Its branches are replaced in place as pruning proceeds.
    testData : NumPy matrix
        Test rows that reach this node; the last column is the target.

    Returns
    -------
    dict or number
        The (possibly modified) ``tree``, or a single merged leaf value when
        the whole node was collapsed.
    """
    # No test rows reach this node: collapse the subtree to a single value
    if np.shape(testData)[0] == 0:
        return getMean(tree)
    # Split the test data once; the same partition serves both the recursive
    # pruning of the children and the merge test below
    lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rSet)
    # Merging is only considered once both children are leaves
    if isTree(tree['left']) or isTree(tree['right']):
        return tree
    # np.sum reduces each squared-error column to a plain scalar
    errorNoMerge = (np.sum(np.power(lSet[:, -1] - tree['left'], 2))
                    + np.sum(np.power(rSet[:, -1] - tree['right'], 2)))
    treeMean = (tree['left'] + tree['right']) / 2.0
    errorMerge = np.sum(np.power(testData[:, -1] - treeMean, 2))
    # Merge only when it strictly lowers the error on the test data
    if errorMerge < errorNoMerge:
        print('merging')
        return treeMean
    return tree

In [14]:
# Build a tree on myMat2 with the loosest thresholds (tolS=0, tolN=1)
myTree = createTree(myMat2, ops=(0, 1))

In [15]:
myDatTest = loadDataSet('ex2test.txt')
# np.asmatrix replaces the np.mat alias, which NumPy 2.0 removed
myMat2Test = np.asmatrix(myDatTest)
# prune() rewrites myTree's branches in place and also returns the result
prune(myTree, myMat2Test)

merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging


{'spInd': 0,
 'spVal': 0.499171,
 'left': {'spInd': 0,
  'spVal': 0.729397,
  'left': {'spInd': 0,
   'spVal': 0.952833,
   'left': {'spInd': 0,
    'spVal': 0.965969,
    'left': 92.5239915,
    'right': {'spInd': 0,
     'spVal': 0.956951,
     'left': {'spInd': 0,
      'spVal': 0.958512,
      'left': {'spInd': 0,
       'spVal': 0.960398,
       'left': 112.386764,
       'right': 123.559747},
      'right': 135.837013},
     'right': 111.2013225}},
   'right': {'spInd': 0,
    'spVal': 0.759504,
    'left': {'spInd': 0,
     'spVal': 0.763328,
     'left': {'spInd': 0,
      'spVal': 0.769043,
      'left': {'spInd': 0,
       'spVal': 0.790312,
       'left': {'spInd': 0,
        'spVal': 0.806158,
        'left': {'spInd': 0,
         'spVal': 0.815215,
         'left': {'spInd': 0,
          'spVal': 0.833026,
          'left': {'spInd': 0,
           'spVal': 0.841547,
           'left': {'spInd': 0,
            'spVal': 0.841625,
            'left': {'spInd': 0,
            