In [1]:
matplotlib inline

In [2]:
from numpy import *

In [6]:
def loadDataSet(fileName):
    """Read a tab-separated text file of numbers into a list of float rows.

    Args:
        fileName: path of the text file; every non-blank line holds
            tab-separated values that can be parsed by ``float``.

    Returns:
        A list with one entry per non-blank line, each entry being the list
        of floats parsed from that line.

    Raises:
        ValueError: if a field cannot be converted to ``float``.
    """
    dataMat = []
    # The context manager guarantees the handle is closed even if parsing fails.
    with open(fileName) as fr:
        for line in fr:
            stripped = line.strip()
            if not stripped:
                # Skip blank lines (e.g. a trailing newline) instead of
                # failing on float('').
                continue
            dataMat.append([float(field) for field in stripped.split('\t')])
    return dataMat

def binSplitDataSet(dataSet, feature, value):
    """Partition the rows of a data set by thresholding one feature column.

    Args:
        dataSet: the data to split (indexed as ``dataSet[:, feature]``).
        feature: column index of the feature to split on.
        value: threshold compared against that column.

    Returns:
        A pair ``(mat0, mat1)``: the rows whose feature is strictly greater
        than ``value``, and the rows whose feature is less than or equal to it.
    """
    column = dataSet[:, feature]
    aboveRows = nonzero(column > value)[0]
    restRows = nonzero(column <= value)[0]
    return dataSet[aboveRows, :], dataSet[restRows, :]


In [4]:
# asmatrix replaces the `mat` alias, which is no longer exported by NumPy 2.x.
testMat = asmatrix(eye(4))
testMat

matrix([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])

In [5]:
mat0, mat1 = binSplitDataSet(testMat,1,0.5)
mat0

matrix([[0., 1., 0., 0.]])

In [6]:
mat1

matrix([[1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])

In [7]:
testMat[nonzero(testMat[:, 1] <= 0.5)[0],:]

matrix([[1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])

In [1]:
def regLeaf(dataSet):
    """Build the leaf value of a regression tree.

    Used when no further split is made: the leaf model is the mean of the
    last column (the target variable) of ``dataSet``.
    """
    targets = dataSet[:, -1]
    return mean(targets)

def regErr(dataSet):
    """Total squared error of the target column around its mean.

    ``var`` yields the mean squared deviation of the last column, so it is
    multiplied by the number of rows to obtain the summed squared error.
    """
    rowCount = shape(dataSet)[0]
    targetVariance = var(dataSet[:, -1])
    return targetVariance * rowCount

def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """Search every feature/value pair for the binary split with the lowest error.

    Args:
        dataSet: matrix whose last column is the target variable.
        leafType: function building a leaf model from a data set.
        errType: function measuring the error of a data set.
        ops: ``(tolS, tolN)`` - the minimum error reduction required to accept
            a split, and the minimum number of rows each side must contain.

    Returns:
        ``(None, leafType(dataSet))`` when no split is made (all targets are
        identical, the error reduction is below ``tolS``, or a side would have
        fewer than ``tolN`` rows); otherwise ``(featureIndex, splitValue)``.
    """
    tolS, tolN = ops[0], ops[1]
    targets = dataSet[:, -1].T.tolist()[0]
    if len(set(targets)) == 1:
        # Every target is the same value: nothing to gain from splitting.
        return None, leafType(dataSet)
    _, n = shape(dataSet)
    S = errType(dataSet)  # baseline error that a split has to improve on
    bestS, bestIndex, bestValue = inf, 0, 0
    for featIndex in range(n - 1):
        # .T.tolist()[0] turns the matrix column into hashable floats.
        candidates = set(dataSet[:, featIndex].T.tolist()[0])
        for splitVal in candidates:
            upper, lower = binSplitDataSet(dataSet, featIndex, splitVal)
            tooSmall = (shape(upper)[0] < tolN) or (shape(lower)[0] < tolN)
            if tooSmall:
                continue
            newS = errType(upper) + errType(lower)
            if newS < bestS:
                bestS, bestIndex, bestValue = newS, featIndex, splitVal
    if (S - bestS) < tolS:
        # The best split does not reduce the error enough.
        return None, leafType(dataSet)
    upper, lower = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (shape(upper)[0] < tolN) or (shape(lower)[0] < tolN):
        # The chosen split would leave too few rows on one side.
        return None, leafType(dataSet)
    return bestIndex, bestValue

def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """Recursively build a tree by repeatedly splitting the data set in two.

    Args:
        dataSet: the data to fit.
        leafType: function that builds a leaf model.
        errType: function that measures the error of a data set.
        ops: tuple of extra tree-building parameters, forwarded unchanged to
            ``chooseBestSplit``.

    Returns:
        The leaf model returned by ``chooseBestSplit`` when it reports no
        split, otherwise a dict with keys ``'spInd'``, ``'spVal'``, ``'left'``
        and ``'right'`` where the two branches are built recursively.
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # Identity test: feature index 0 is a valid split and must not be
    # confused with "no split".
    if feat is None:
        return val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    return {
        'spInd': feat,
        'spVal': val,
        'left': createTree(lSet, leafType, errType, ops),
        'right': createTree(rSet, leafType, errType, ops),
    }

In [9]:
myDat = loadDataSet('ex00.txt')
print("myDat",myDat)
# asmatrix replaces the `mat` alias, which is no longer exported by NumPy 2.x.
myMat = asmatrix(myDat)
print("myMat", myMat)

myDat [[0.036098, 0.155096], [0.993349, 1.077553], [0.530897, 0.893462], [0.712386, 0.564858], [0.343554, -0.3717], [0.098016, -0.33276], [0.691115, 0.834391], [0.091358, 0.099935], [0.727098, 1.000567], [0.951949, 0.945255], [0.768596, 0.760219], [0.541314, 0.893748], [0.146366, 0.034283], [0.673195, 0.915077], [0.18351, 0.184843], [0.339563, 0.206783], [0.517921, 1.493586], [0.703755, 1.101678], [0.008307, 0.069976], [0.243909, -0.029467], [0.306964, -0.177321], [0.036492, 0.408155], [0.295511, 0.002882], [0.837522, 1.229373], [0.202054, -0.087744], [0.919384, 1.029889], [0.377201, -0.24355], [0.814825, 1.095206], [0.61127, 0.982036], [0.072243, -0.420983], [0.41023, 0.331722], [0.869077, 1.114825], [0.620599, 1.334421], [0.101149, 0.068834], [0.820802, 1.325907], [0.520044, 0.961983], [0.48813, -0.097791], [0.819823, 0.835264], [0.975022, 0.673579], [0.953112, 1.06469], [0.475976, -0.163707], [0.273147, -0.455219], [0.804586, 0.924033], [0.074795, -0.349692], [0.625336, 0.623696], [

In [10]:
createTree(myMat)

{'spInd': 0,
 'spVal': 0.48813,
 'left': 1.0180967672413792,
 'right': -0.04465028571428572}

In [11]:
myDat1 = loadDataSet('ex0.txt')
# asmatrix replaces the `mat` alias, which is no longer exported by NumPy 2.x.
myMat1 = asmatrix(myDat1)
createTree(myMat1)

{'spInd': 1,
 'spVal': 0.39435,
 'left': {'spInd': 1,
  'spVal': 0.582002,
  'left': {'spInd': 1,
   'spVal': 0.797583,
   'left': 3.9871632,
   'right': 2.9836209534883724},
  'right': 1.980035071428571},
 'right': {'spInd': 1,
  'spVal': 0.197834,
  'left': 1.0289583666666666,
  'right': -0.023838155555555553}}

一棵树如果节点过多，表明该模型可能对数据进行了“过拟合”。通过降低决策树的复杂度来避免过拟合的过程称为“剪枝”。

In [12]:
createTree(myMat, ops=(0,1))

{'spInd': 0,
 'spVal': 0.48813,
 'left': {'spInd': 0,
  'spVal': 0.620599,
  'left': {'spInd': 0,
   'spVal': 0.625336,
   'left': {'spInd': 0,
    'spVal': 0.625791,
    'left': {'spInd': 0,
     'spVal': 0.643601,
     'left': {'spInd': 0,
      'spVal': 0.651376,
      'left': {'spInd': 0,
       'spVal': 0.6632,
       'left': {'spInd': 0,
        'spVal': 0.683921,
        'left': {'spInd': 0,
         'spVal': 0.819823,
         'left': {'spInd': 0,
          'spVal': 0.837522,
          'left': {'spInd': 0,
           'spVal': 0.846455,
           'left': {'spInd': 0,
            'spVal': 0.919384,
            'left': {'spInd': 0,
             'spVal': 0.976414,
             'left': {'spInd': 0,
              'spVal': 0.985425,
              'left': {'spInd': 0,
               'spVal': 0.989888,
               'left': {'spInd': 0,
                'spVal': 0.993349,
                'left': 1.035533,
                'right': 1.077553},
               'right': {'spInd': 0,
        

In [13]:
myDat2 = loadDataSet('ex2.txt')
# asmatrix replaces the `mat` alias, which is no longer exported by NumPy 2.x.
myMat2 = asmatrix(myDat2)
createTree(myMat2)

{'spInd': 0,
 'spVal': 0.499171,
 'left': {'spInd': 0,
  'spVal': 0.729397,
  'left': {'spInd': 0,
   'spVal': 0.952833,
   'left': {'spInd': 0,
    'spVal': 0.958512,
    'left': 105.24862350000001,
    'right': 112.42895575000001},
   'right': {'spInd': 0,
    'spVal': 0.759504,
    'left': {'spInd': 0,
     'spVal': 0.790312,
     'left': {'spInd': 0,
      'spVal': 0.833026,
      'left': {'spInd': 0,
       'spVal': 0.944221,
       'left': 87.3103875,
       'right': {'spInd': 0,
        'spVal': 0.85497,
        'left': {'spInd': 0,
         'spVal': 0.910975,
         'left': 96.452867,
         'right': {'spInd': 0,
          'spVal': 0.892999,
          'left': 104.825409,
          'right': {'spInd': 0,
           'spVal': 0.872883,
           'left': 95.181793,
           'right': 102.25234449999999}}},
        'right': 95.27584316666666}},
      'right': {'spInd': 0,
       'spVal': 0.811602,
       'left': 81.110152,
       'right': 88.78449880000001}},
     'right': 102.

In [14]:
createTree(myMat2, ops=(1000,4))

{'spInd': 0,
 'spVal': 0.499171,
 'left': {'spInd': 0,
  'spVal': 0.729397,
  'left': {'spInd': 0,
   'spVal': 0.952833,
   'left': 108.838789625,
   'right': {'spInd': 0,
    'spVal': 0.759504,
    'left': 95.7366680212766,
    'right': 78.08564325}},
  'right': 107.68699163829788},
 'right': -2.637719329787234}

In [2]:
def isTree(obj):
    """Return True when ``obj`` is a tree node (a dict) rather than a leaf.

    ``isinstance`` is used instead of comparing the type name as a string, so
    dict subclasses are accepted and unrelated classes that merely happen to
    be named ``dict`` are rejected.
    """
    return isinstance(obj, dict)

def getMean(tree):
    """Collapse a (sub)tree into one number.

    Each branch that is itself a tree is first replaced, in place, by its own
    collapsed value; the result is the average of the two branch values.
    """
    for side in ('right', 'left'):
        if isTree(tree[side]):
            tree[side] = getMean(tree[side])
    return (tree['left'] + tree['right']) / 2.0

def prune(tree, testData):
    """Post-prune a tree against held-out data.

    Args:
        tree: the tree to prune (its branches are updated in place).
        testData: test data used to decide whether merging reduces the error.

    Returns:
        The collapsed mean when ``testData`` is empty, the merged leaf value
        when merging two leaves lowers the squared error on ``testData``,
        otherwise the (possibly partially pruned) tree.
    """
    # No test data reaches this node: collapse the whole subtree.
    if shape(testData)[0] == 0:
        return getMean(tree)

    # Route the test data down the split and prune each subtree on its share.
    lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    for side, subset in (('left', lSet), ('right', rSet)):
        if isTree(tree[side]):
            tree[side] = prune(tree[side], subset)

    # Merging is only considered once both branches are leaves.
    if isTree(tree['left']) or isTree(tree['right']):
        return tree

    # Compare the squared error with and without merging the two leaves.
    leftErr = sum(power(lSet[:, -1] - tree['left'], 2))
    rightErr = sum(power(rSet[:, -1] - tree['right'], 2))
    errorNoMerge = leftErr + rightErr
    treeMean = (tree['left'] + tree['right']) / 2.0
    errorMerge = sum(power(testData[:, -1] - treeMean, 2))
    if errorMerge < errorNoMerge:
        print("merging")
        return treeMean
    return tree

In [16]:
myTree = createTree(myMat2, ops=(0,1))

In [19]:
myDatTest = loadDataSet('ex2test.txt')

In [20]:
# asmatrix replaces the `mat` alias, which is no longer exported by NumPy 2.x.
myMat2Test = asmatrix(myDatTest)

In [21]:
prune(myTree, myMat2Test)

merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging


{'spInd': 0,
 'spVal': 0.499171,
 'left': {'spInd': 0,
  'spVal': 0.729397,
  'left': {'spInd': 0,
   'spVal': 0.952833,
   'left': {'spInd': 0,
    'spVal': 0.965969,
    'left': 92.5239915,
    'right': {'spInd': 0,
     'spVal': 0.956951,
     'left': {'spInd': 0,
      'spVal': 0.958512,
      'left': {'spInd': 0,
       'spVal': 0.960398,
       'left': 112.386764,
       'right': 123.559747},
      'right': 135.837013},
     'right': 111.2013225}},
   'right': {'spInd': 0,
    'spVal': 0.759504,
    'left': {'spInd': 0,
     'spVal': 0.763328,
     'left': {'spInd': 0,
      'spVal': 0.769043,
      'left': {'spInd': 0,
       'spVal': 0.790312,
       'left': {'spInd': 0,
        'spVal': 0.806158,
        'left': {'spInd': 0,
         'spVal': 0.815215,
         'left': {'spInd': 0,
          'spVal': 0.833026,
          'left': {'spInd': 0,
           'spVal': 0.841547,
           'left': {'spInd': 0,
            'spVal': 0.841625,
            'left': {'spInd': 0,
            

## 模型树的叶节点生成函数

In [3]:
def linearSolve(dataSet):
    """Fit an ordinary least-squares line to a data set.

    The data is arranged into a design matrix ``X`` (a leading column of ones
    followed by every column of ``dataSet`` except the last) and a target
    ``Y`` (the last column).

    Args:
        dataSet: matrix whose last column is the target variable.

    Returns:
        ``(ws, X, Y)`` - the regression coefficients, the design matrix and
        the target column.

    Raises:
        NameError: if ``X.T * X`` has a determinant of exactly zero and
            therefore cannot be inverted.
    """
    m, n = shape(dataSet)
    # asmatrix replaces the `mat` alias, which is no longer exported by NumPy 2.x.
    X = asmatrix(ones((m, n)))
    X[:, 1:n] = dataSet[:, 0:n-1]  # column 0 stays all ones (intercept term)
    Y = dataSet[:, -1]
    xTx = X.T * X
    if linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse, try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y

def modelLeaf(dataSet):
    """Build a model-tree leaf: the regression coefficients fitted by
    ``linearSolve`` on ``dataSet``."""
    return linearSolve(dataSet)[0]

def modelErr(dataSet):
    """Squared error of the linear model fitted by ``linearSolve`` on ``dataSet``.

    Returns the sum of squared differences between the targets ``Y`` and the
    predictions ``X * ws``.
    """
    ws, X, Y = linearSolve(dataSet)
    residual = Y - X * ws
    return sum(power(residual, 2))


In [24]:
# asmatrix replaces the `mat` alias, which is no longer exported by NumPy 2.x.
myMat2 = asmatrix(loadDataSet('exp2.txt'))
myMat2

matrix([[7.0670000e-02, 3.4708290e+00],
        [5.3407600e-01, 6.3771320e+00],
        [7.4722100e-01, 8.9494070e+00],
        [6.6897000e-01, 8.0340810e+00],
        [5.8608200e-01, 6.9977210e+00],
        [7.6496200e-01, 9.3181100e+00],
        [6.5812500e-01, 7.8803330e+00],
        [3.4673400e-01, 4.2133590e+00],
        [3.1396700e-01, 3.7624960e+00],
        [6.0141800e-01, 7.1888050e+00],
        [4.0439600e-01, 4.8934030e+00],
        [1.5434500e-01, 3.6831750e+00],
        [9.8406100e-01, 1.1712928e+01],
        [5.9751400e-01, 7.1466940e+00],
        [5.1440000e-03, 3.3331500e+00],
        [1.4229500e-01, 3.7436810e+00],
        [2.8000700e-01, 3.7373760e+00],
        [5.4200800e-01, 6.4942750e+00],
        [4.6678100e-01, 5.5322550e+00],
        [7.0697000e-01, 8.4767180e+00],
        [1.9103800e-01, 3.6739210e+00],
        [7.5659100e-01, 9.1767220e+00],
        [9.1287900e-01, 1.0850358e+01],
        [5.2470100e-01, 6.0674440e+00],
        [3.0609000e-01, 3.6811480e+00],


In [26]:
createTree(myMat2, modelLeaf, modelErr,(1,10))

{'spInd': 0, 'spVal': 0.285477, 'left': matrix([[1.69855694e-03],
         [1.19647739e+01]]), 'right': matrix([[3.46877936],
         [1.18521743]])}

### 用树回归进行预测的代码

In [4]:
def regTreeEval(model, inDat):
    """Evaluate a regression-tree leaf: the leaf value converted to ``float``.

    ``inDat`` is not used; it is accepted only so the signature matches
    ``modelTreeEval``.
    """
    prediction = float(model)
    return prediction

def modelTreeEval(model, inDat):
    """Evaluate a model-tree leaf (linear model) on one input row.

    Args:
        model: column of regression coefficients, intercept first.
        inDat: a single row of feature values (1 x n).

    Returns:
        The prediction ``[1, inDat] * model`` as a Python float.
    """
    n = shape(inDat)[1]
    # asmatrix replaces the `mat` alias, which is no longer exported by NumPy 2.x.
    X = asmatrix(ones((1, n+1)))
    X[:, 1:n+1] = inDat  # column 0 stays 1 for the intercept term
    # Pull the single element out explicitly: calling float() on a 1x1 matrix
    # relies on a deprecated array-to-scalar conversion.
    return float((X * model)[0, 0])

def treeForeCast(tree, inData, modelEval=regTreeEval):
    """Predict the target for one data point by walking the tree top-down.

    Args:
        tree: a tree node (dict) or a leaf model.
        inData: a single data point (a 1 x n row matrix or a 1-D sequence of
            feature values).
        modelEval: function called as ``modelEval(leaf, inData)`` once a leaf
            is reached - ``regTreeEval`` for regression trees,
            ``modelTreeEval`` for model trees.

    Returns:
        Whatever ``modelEval`` returns for the leaf that ``inData`` falls into.
    """
    if not isTree(tree):
        return modelEval(tree, inData)
    # Flatten before indexing: on a 1 x n matrix, inData[k] selects ROW k, so
    # indexing it directly only works for single-feature data.
    featureValue = asarray(inData).ravel()[tree['spInd']]
    # Same convention as binSplitDataSet: values above the split go left.
    branch = 'left' if featureValue > tree['spVal'] else 'right'
    return treeForeCast(tree[branch], inData, modelEval)
        
def createForeCast(tree, testData, modelEval=regTreeEval):
    """Predict the target for every entry of ``testData``.

    Args:
        tree: the fitted tree.
        testData: sequence of data points; ``testData[i]`` is one point.
        modelEval: leaf evaluation function forwarded to ``treeForeCast``.

    Returns:
        An ``m x 1`` matrix of predictions, one row per entry of ``testData``.
    """
    m = len(testData)
    # asmatrix replaces the `mat` alias, which is no longer exported by NumPy 2.x.
    yHat = asmatrix(zeros((m, 1)))
    for i in range(m):
        yHat[i, 0] = treeForeCast(tree, asmatrix(testData[i]), modelEval)
    return yHat

In [28]:
# asmatrix replaces the `mat` alias, which is no longer exported by NumPy 2.x.
trainMat = asmatrix(loadDataSet('bikeSpeedVsIq_train.txt'))
trainMat

matrix([[  3.      ,  46.852122],
        [ 23.      , 178.676107],
        [  0.      ,  86.154024],
        [  6.      ,  68.707614],
        [ 15.      , 139.737693],
        [ 17.      , 141.988903],
        [ 12.      ,  94.477135],
        [  8.      ,  86.083788],
        [  9.      ,  97.265824],
        [  7.      ,  80.400027],
        [  8.      ,  83.414554],
        [  1.      ,  52.525471],
        [ 16.      , 127.060008],
        [  9.      , 101.639269],
        [ 14.      , 146.41268 ],
        [ 15.      , 144.157101],
        [ 17.      , 152.69991 ],
        [ 19.      , 136.669023],
        [ 21.      , 166.971736],
        [ 21.      , 165.467251],
        [  3.      ,  38.455193],
        [  6.      ,  75.557721],
        [  4.      ,  22.171763],
        [  5.      ,  50.321915],
        [  0.      ,  74.412428],
        [  5.      ,  42.052392],
        [  1.      ,  42.489057],
        [ 14.      , 139.185416],
        [ 21.      , 140.713725],
        [  5. 

In [29]:
# asmatrix replaces the `mat` alias, which is no longer exported by NumPy 2.x.
testMat = asmatrix(loadDataSet('bikeSpeedVsIq_test.txt'))
testMat

matrix([[ 12.      , 121.010516],
        [ 19.      , 157.337044],
        [ 12.      , 116.031825],
        [ 15.      , 132.124872],
        [  2.      ,  52.719612],
        [  6.      ,  39.058368],
        [  3.      ,  50.757763],
        [ 20.      , 166.740333],
        [ 11.      , 115.808227],
        [ 21.      , 165.582995],
        [  3.      ,  41.956087],
        [  3.      ,  34.43237 ],
        [ 13.      , 116.954676],
        [  1.      ,  32.112553],
        [  7.      ,  50.380243],
        [  7.      ,  94.107791],
        [ 23.      , 188.943179],
        [ 18.      , 152.637773],
        [  9.      , 104.122082],
        [ 18.      , 127.805226],
        [  0.      ,  83.083232],
        [ 15.      , 148.180104],
        [  3.      ,  38.480247],
        [  8.      ,  77.597839],
        [  7.      ,  75.625803],
        [ 11.      , 124.620208],
        [ 13.      , 125.186698],
        [  5.      ,  51.165922],
        [  3.      ,  31.179113],
        [ 15. 

In [30]:
myTree = createTree(trainMat, ops=(1,20))
myTree

{'spInd': 0,
 'spVal': 10.0,
 'left': {'spInd': 0,
  'spVal': 17.0,
  'left': {'spInd': 0,
   'spVal': 20.0,
   'left': 168.34161286956524,
   'right': 157.0484078846154},
  'right': {'spInd': 0,
   'spVal': 14.0,
   'left': 141.06067981481482,
   'right': 122.90893026923078}},
 'right': {'spInd': 0,
  'spVal': 7.0,
  'left': 94.7066578125,
  'right': {'spInd': 0,
   'spVal': 5.0,
   'left': 69.02117757692308,
   'right': 50.94683665}}}

In [34]:
yHat = createForeCast(myTree, testMat[:,0])
yHat

matrix([[122.90893027],
        [157.04840788],
        [122.90893027],
        [141.06067981],
        [ 50.94683665],
        [ 69.02117758],
        [ 50.94683665],
        [157.04840788],
        [122.90893027],
        [168.34161287],
        [ 50.94683665],
        [ 50.94683665],
        [122.90893027],
        [ 50.94683665],
        [ 69.02117758],
        [ 69.02117758],
        [168.34161287],
        [157.04840788],
        [ 94.70665781],
        [157.04840788],
        [ 50.94683665],
        [141.06067981],
        [ 50.94683665],
        [ 94.70665781],
        [ 69.02117758],
        [122.90893027],
        [122.90893027],
        [ 50.94683665],
        [ 50.94683665],
        [141.06067981],
        [157.04840788],
        [ 94.70665781],
        [157.04840788],
        [122.90893027],
        [ 50.94683665],
        [157.04840788],
        [157.04840788],
        [168.34161287],
        [141.06067981],
        [141.06067981],
        [157.04840788],
        [157.048

In [35]:
corrcoef(yHat, testMat[:,1],rowvar=0)[0,1]

0.9640852318222145

In [36]:
myTree = createTree(trainMat, modelLeaf, modelErr, (1,20))
myTree

{'spInd': 0,
 'spVal': 4.0,
 'left': {'spInd': 0,
  'spVal': 12.0,
  'left': {'spInd': 0,
   'spVal': 16.0,
   'left': {'spInd': 0, 'spVal': 20.0, 'left': matrix([[47.58621512],
            [ 5.51066299]]), 'right': matrix([[37.54851927],
            [ 6.23298637]])},
   'right': matrix([[43.41251481],
           [ 6.37966738]])},
  'right': {'spInd': 0, 'spVal': 9.0, 'left': matrix([[-2.87684083],
           [10.20804482]]), 'right': {'spInd': 0,
    'spVal': 6.0,
    'left': matrix([[-11.84548851],
            [ 12.12382261]]),
    'right': matrix([[-17.21714265],
            [ 13.72153115]])}}},
 'right': matrix([[ 68.87014372],
         [-11.78556471]])}

In [38]:
yHat = createForeCast(myTree, testMat[:,0],modelTreeEval)
corrcoef(yHat, testMat[:,1],rowvar=0)[0,1]

0.9760412191380623

corrcoef 返回的是相关系数 R（而非 R 的平方），该值越接近 1.0 越好；这里模型树（约 0.976）的结果比回归树（约 0.964）好。

In [39]:
ws,X,Y = linearSolve(trainMat)
ws

matrix([[37.58916794],
        [ 6.18978355]])

In [40]:
# Predict with the single global linear model (intercept ws[0,0], slope ws[1,0]).
# Building a fresh column keeps the cell idempotent instead of overwriting, row
# by row, whatever yHat an earlier cell happened to leave behind.
yHat = testMat[:,0]*ws[1,0] + ws[0,0]

In [41]:
corrcoef(yHat, testMat[:,1],rowvar=0)[0,1]

0.9434684235674762

## 利用GUI对回归树调优
1. 收集数据：所提供的文本文件
2. 准备数据：用Python解析上述文件，得到数值型数据
3. 分析数据：用Tkinter构建一个GUI来展示模型和数据
4. 训练算法：训练一棵回归树和一棵模型树，并与数据集一起展示出来
5. 测试算法：这里不需要测试过程
6. 使用算法：GUI使得人们可以在预剪枝时测试不同参数的影响，还可以帮助我们选择模型的类型

In [1]:
# Minimal Tkinter smoke test: a window containing a single label.
import tkinter as tk #note: the module name starts with a lowercase letter; the capitalised spelling raises an error
root = tk.Tk()
myLabel = tk.Label(root, text='Hello,World')
myLabel.grid()
# Blocks until the window is closed.
root.mainloop()

用于构建树管理器界面的Tkinter小部件

In [11]:
from numpy import *
from tkinter import *
import matplotlib
matplotlib.use('TkAgg')
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure

def reDraw(tolS, tolN):
    """Refit the tree with the given parameters and redraw the figure.

    Args:
        tolS: minimum error reduction, passed as ``ops[0]`` to ``createTree``.
        tolN: minimum rows per split side, passed as ``ops[1]``.

    A model tree is fitted when the "Model Tree" check box is ticked, a
    regression tree otherwise. The raw data is drawn as a scatter plot and the
    predictions over ``reDraw.testDat`` as a line.
    """
    reDraw.f.clf()  # clear the previous figure
    reDraw.a = reDraw.f.add_subplot(111)
    if chkBtnVar.get():
        # NOTE(review): presumably guards against fitting a line to a single
        # sample (singular matrix in linearSolve) - confirm.
        if tolN < 2: tolN = 2
        myTree = createTree(reDraw.rawDat, modelLeaf, modelErr, (tolS, tolN))
        yHat = createForeCast(myTree, reDraw.testDat, modelTreeEval)
    else:
        myTree = createTree(reDraw.rawDat, ops=(tolS, tolN))
        yHat = createForeCast(myTree, reDraw.testDat)
    reDraw.a.scatter(reDraw.rawDat[:,0].flatten().A[0], reDraw.rawDat[:,1].flatten().A[0], s=5)
    reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0)
    # FigureCanvasTkAgg.show() was deprecated and later removed; draw() is
    # the supported way to refresh the canvas.
    reDraw.canvas.draw()
        
def getInputs():
    """Read tolN and tolS from the two entry widgets.

    When an entry does not parse (``int`` for tolN, ``float`` for tolS) a hint
    is printed, the default (10 / 1.0) is used and written back into the entry.

    Returns:
        ``(tolN, tolS)``.
    """
    try:
        tolN = int(tolNentry.get())
    except ValueError:
        # Only a failed number parse is recovered from; other errors propagate.
        tolN = 10
        print("enter Integer for tolN")  # tolN must be an integer
        tolNentry.delete(0, END)
        tolNentry.insert(0, '10')
    try:
        tolS = float(tolSentry.get())
    except ValueError:
        tolS = 1.0
        print("enter Float for tolS")  # tolS must be a float
        tolSentry.delete(0, END)
        tolSentry.insert(0, '1.0')
    return tolN, tolS


def drawNewTree():
    """Button callback: read the entry values via ``getInputs`` and redraw
    the figure with them via ``reDraw``."""
    minSamples, minErrDrop = getInputs()
    reDraw(minErrDrop, minSamples)
# Create the Tk root widget first, then attach the child widgets to it.
# .grid() places each widget at a row/column position;
# columnspan / rowspan let a widget span several columns or rows.
root=Tk()

# Label(root, text="Plot Place Holder").grid(row=0, columnspan=3)
reDraw.f = Figure(figsize=(5,4), dpi=100)
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
# FigureCanvasTkAgg.show() was deprecated and later removed; use draw().
reDraw.canvas.draw()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)

Label(root, text="tolN").grid(row=1, column=0)
tolNentry = Entry(root)
tolNentry.grid(row=1, column=1)
tolNentry.insert(0,'10')
Label(root, text="tolS").grid(row=2, column=0)
tolSentry = Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0,'1.0')
Button(root, text="ReDraw", command=drawNewTree).grid(row=1, column=2, rowspan=3)

chkBtnVar = IntVar()
chkBtn = Checkbutton(root, text='Model Tree', variable=chkBtnVar)
chkBtn.grid(row=3, column=0, columnspan=2)

# Initialise the data shared with reDraw (stored as function attributes).
# asmatrix replaces the `mat` alias, which is no longer exported by NumPy 2.x.
reDraw.rawDat = asmatrix(loadDataSet('sine.txt'))
# .min()/.max() return plain scalars for arange, rather than the 1x1 matrices
# produced by iterating the column with the builtin min/max.
xColumn = reDraw.rawDat[:,0]
reDraw.testDat = arange(xColumn.min(), xColumn.max(), 0.01)
reDraw(1.0, 10)
root.mainloop()



