# 线性回归在创建模型时需要拟合所有样本（局部加权线性回归除外）
# 树回归
### 优点：可以应对复杂和非线性的数据建模
### 缺点：结果不易理解
### 适用数据类型：数值型和标称型数据
### CART利用二元切分法来处理连续型变量，二元切分就是大于某个数走左子树，否则走右子树
### 树回归一般步骤
##### 搜集数据
##### 准备数据：需要数值型数据，标称型数据应该映射为二值型数据
##### 分析数据：可视化分析，以字典方式生成树
##### 训练算法：大部分开销在于叶节点构建
##### 测试算法：使用平方误差来测试
##### 使用训练好的模型进行预测


# 使用类来创建树节点
# 使用递归方式创建树
### 伪代码
### 找到最佳待切分特征：
     如果该节点无法再划分，保存该节点为叶子节点
     执行二元切分
     在右子树上调用createTree()方法
     在左子树上调用createTree()方法

In [15]:


    
class treeNode:
    """A binary CART tree node (illustrative; createTree in this file uses dicts).

    Args:
        feat: index of the feature to split on.
        val: split threshold for that feature.
        right: right subtree or leaf value.
        left: left subtree or leaf value.
    """

    def __init__(self, feat, val, right, left):
        # Store on the instance: the original assigned plain local variables,
        # which were discarded as soon as __init__ returned.
        self.featureToSplitOn = feat
        self.valueOfSplit = val
        self.rightBranch = right
        self.leftBranch = left
    
# CART算法实现
from numpy import*

# 加载数据集
def loadDataSet(fileName):
    """Parse a tab-delimited text file of numbers into a list of float rows.

    The last column of each row is assumed to be the target value.

    Args:
        fileName: path of the text file to read.

    Returns:
        list[list[float]]: one inner list per non-blank line.
    """
    dataMat = []
    # ``with`` guarantees the handle is closed; the original never closed it.
    with open(fileName) as fr:
        for line in fr:
            stripped = line.strip()
            if not stripped:
                # Skip blank lines (e.g. an extra newline at EOF) instead of
                # crashing on float('').
                continue
            dataMat.append([float(tok) for tok in stripped.split('\t')])
    return dataMat
    
# 按照特征的值切分数据集
def binSplitDataSet(dataSet, feature, value):
    """Split the rows of a matrix on one feature column.

    Rows whose ``feature`` value is greater than ``value`` are returned first;
    all remaining rows (value <= threshold) are returned second.
    """
    column = dataSet[:, feature]
    greaterRows = list(nonzero(column > value)[0])
    otherRows = list(nonzero(column <= value)[0])
    return dataSet[greaterRows, :], dataSet[otherRows, :]


# 测试
def testT():
    """Demo: split a 4x4 identity matrix on feature 0 at threshold 0.5."""
    identity = mat(eye(4))  # 4x4 identity matrix
    print(identity)
    greater, notGreater = binSplitDataSet(identity, 0, 0.5)
    print(notGreater, greater)
testT()

[[1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]]
[[0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]] [[1. 0. 0. 0.]]


# 将CART用于回归
### 计算混乱度使用平方误差衡量
### 构建树
对于寻找最佳切分点的函数，伪代码如下：
     初始化当前最小误差为正无穷
     对于每个特征：
         对于该特征的每个取值：
             将数据集切分为两份
             计算切分误差
             如果当前误差小于当前最小误差：
                 将当前切分点设置为最佳切分，并更新最小误差
     返回最佳切分的特征和阈值

In [16]:

    

# 该函数负责生成叶节点（返回目标变量的均值）
def regLeaf(dataSet):
    """Leaf value for a regression tree: the mean of the target (last) column."""
    targets = dataSet[:, -1]
    return mean(targets)

# 计算平方误差
def regErr(dataSet):
    """Total squared error of the target column around its mean.

    Computed as the target variance multiplied by the number of rows.
    """
    rowCount = shape(dataSet)[0]
    return rowCount * var(dataSet[:, -1])
    
# 划分数据
# ops（tolS,tolN）两个参数分别表示容许的误差下降值，切分的最少样本数
### 两个参数用来控制函数的停止时机
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """Find the best binary split of ``dataSet`` (CART).

    Every distinct value of every feature column (all columns except the
    last, which is the target) is tried as a threshold; the split that
    minimises ``errType(mat0) + errType(mat1)`` wins.

    Args:
        dataSet: NumPy matrix whose last column is the target.
        leafType: function building a leaf value from a data subset.
        errType: function measuring the error of a data subset.
        ops: (tolS, tolN) -- the minimum error reduction required to split,
            and the minimum number of rows allowed on each side of a split.

    Returns:
        (featureIndex, splitValue) for a worthwhile split, or
        (None, leafType(dataSet)) when a stop condition is hit.
    """
    tolS = ops[0]; tolN = ops[1]
    # Exit 1: all target values are identical, so no split can help.
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    n = shape(dataSet)[1]
    # Error of the unsplit data; a split must reduce it by at least tolS.
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n-1):
        # Distinct values of this column in first-seen order (keeps the same
        # tie-breaking as before). The original read them with
        # ``dataSet[:,featIndex][j,featIndex]``, which indexes column
        # ``featIndex`` of a one-column matrix and raised IndexError for every
        # feature except column 0.
        candidates = dict.fromkeys(dataSet[:, featIndex].T.tolist()[0])
        for splitVal in candidates:
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            # Skip splits that leave too few rows on either side.
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # Exit 2: the best split does not reduce the error by at least tolS.
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    # Exit 3: re-check the minimum side size for the chosen split.
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue
    
# 构造树
# 参数分别为：数据集，构造叶子节点的函数、误差计算函数、ops是一个包含构造树所需其他参数的元组
# ops（tolS,tolN）两个参数分别表示容许的误差下降值，切分的最少样本数
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """Recursively build a CART regression or model tree as nested dicts.

    Args:
        dataSet: NumPy matrix whose last column is the target.
        leafType: leaf builder (e.g. regLeaf for regression trees, modelLeaf
            for model trees).
        errType: error measure matching ``leafType`` (regErr / modelErr).
        ops: (tolS, tolN) forwarded to chooseBestSplit -- minimum error
            reduction for a split and minimum rows per side.

    Returns:
        A leaf (whatever ``leafType`` returns) when no split is made;
        otherwise a dict with keys 'spInd', 'spVal', 'left' (rows with
        feature > spVal) and 'right' (rows with feature <= spVal).
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # A stop condition was hit: ``val`` is already the leaf.
    # Identity comparison is the idiomatic None test.
    if feat is None:
        return val
    retTree = {'spInd': feat, 'spVal': val}
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

# 查看数据分布
def showPlot():
    """Scatter-plot the two columns of ./data/ex00.txt (X vs Y)."""
    import matplotlib.pyplot as plt
    figure = plt.figure()
    axes = figure.add_subplot(111)
    points = mat(loadDataSet(r'./data/ex00.txt'))
    xs = points[:, 0].tolist()
    ys = points[:, 1].tolist()
    plt.scatter(xs, ys, color="b")
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.show()

showPlot()

# 测试
def testT():
    """Build regression trees on ./data/ex00.txt with default and loosest stop conditions."""
    samples = mat(loadDataSet(r'./data/ex00.txt'))
    print(createTree(samples))
    # ops=(0, 1): no required error reduction, one row per side is enough.
    print(createTree(samples, ops=(0, 1)))


testT()


# 查看数据分布
def showPlot():
    """Scatter-plot columns 1 and 2 of ./data/ex0.txt (X vs Y)."""
    import matplotlib.pyplot as plt
    figure = plt.figure()
    axes = figure.add_subplot(111)
    points = mat(loadDataSet(r'./data/ex0.txt'))
    plt.scatter(points[:, 1].tolist(), points[:, 2].tolist(), color="b")
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.show()

showPlot()

# 测试
def testT():
    """Build a regression tree on ./data/ex0.txt after dropping its first column."""
    rows = loadDataSet(r'./data/ex0.txt')
    print(shape(rows))
    # Column 0 is not useful as a feature, so drop it.
    samples = mat(rows)[:, 1:]
    print(createTree(samples))


testT()



{'spInd': 0, 'spVal': 0.48813, 'left': 1.0180967672413792, 'right': -0.04465028571428572}
{'spInd': 0, 'spVal': 0.48813, 'left': {'spInd': 0, 'spVal': 0.620599, 'left': {'spInd': 0, 'spVal': 0.625336, 'left': {'spInd': 0, 'spVal': 0.625791, 'left': {'spInd': 0, 'spVal': 0.643601, 'left': {'spInd': 0, 'spVal': 0.651376, 'left': {'spInd': 0, 'spVal': 0.6632, 'left': {'spInd': 0, 'spVal': 0.683921, 'left': {'spInd': 0, 'spVal': 0.819823, 'left': {'spInd': 0, 'spVal': 0.837522, 'left': {'spInd': 0, 'spVal': 0.846455, 'left': {'spInd': 0, 'spVal': 0.919384, 'left': {'spInd': 0, 'spVal': 0.976414, 'left': {'spInd': 0, 'spVal': 0.985425, 'left': {'spInd': 0, 'spVal': 0.989888, 'left': {'spInd': 0, 'spVal': 0.993349, 'left': 1.035533, 'right': 1.077553}, 'right': {'spInd': 0, 'spVal': 0.988852, 'left': 0.744207, 'right': 1.069062}}, 'right': 1.227946}, 'right': {'spInd': 0, 'spVal': 0.953112, 'left': {'spInd': 0, 'spVal': 0.975022, 'left': 0.862911, 'right': 0.673579}, 'right': {'spInd': 0, 's

(200, 3)
{'spInd': 0, 'spVal': 0.39435, 'left': {'spInd': 0, 'spVal': 0.582002, 'left': {'spInd': 0, 'spVal': 0.797583, 'left': 3.9871632, 'right': 2.9836209534883724}, 'right': 1.980035071428571}, 'right': {'spInd': 0, 'spVal': 0.197834, 'left': 1.0289583666666666, 'right': -0.023838155555555553}}


# 树剪枝：通过降低决策树的复杂度来避免过拟合的过程
### 防止过拟合
### 预剪枝：chooseBestSplit()中的提前终止条件（由tolS、tolN控制）实际上就是一种预剪枝
### 后剪枝：需要训练集和测试集，通过测试集来进行剪枝


In [17]:

# 预剪枝
def preCut():
    """Compare trees built on ./data/ex2.txt with default vs. much stricter stop conditions."""
    samples = mat(loadDataSet(r'./data/ex2.txt'))
    print("使用停止条件", createTree(samples))
    # tolS=10000 demands a huge error reduction before any split is accepted.
    print("修改停止条件", createTree(samples, ops=(10000, 4)))
# preCut()
# 从上述情况可以看出超参对结果影响较大，结果具有不可预测性    
    
# 后剪枝，不需要指定参数
# 伪代码如下：
### 基于已有的树切分测试数据：
    ### 如果存在任一子集是一棵树，则在该子集上递归剪枝过程
    ### 计算将当前两个叶子节点合并后的误差
    ### 计算不合并的误差
    ### 比较两种误差大小，做出决策

# 判断该节点是否为字典（树）,即判断是否需要剪枝
def isTree(obj):
    """Return True if ``obj`` is an internal tree node (a dict from createTree).

    Leaves are plain values (a mean from regLeaf or a coefficient matrix from
    modelLeaf), so any dict is treated as a subtree.
    """
    # isinstance is the idiomatic type test and, unlike comparing the type's
    # name, also accepts dict subclasses such as OrderedDict.
    return isinstance(obj, dict)


# 一个递归函数，从上向下遍历直到叶节点为止，如果找到两个叶节点就计算其平均值，对该树进行塌陷处理（返回树平均值）
def getMean(tree):
    """Collapse ``tree`` to one value by recursive pairwise averaging.

    Each child that is a subtree is first replaced in place by its own
    collapsed value (right child before left). The average of the two
    children is then returned.
    """
    for side in ('right', 'left'):
        if isTree(tree[side]):
            tree[side] = getMean(tree[side])
    return (tree['left'] + tree['right']) / 2.0

# 剪枝函数，参数分别为：待剪枝的树和剪枝所需的测试数据
def prune(tree, testData):
    """Post-prune a regression tree with held-out test data.

    Descends the tree recursively. Whenever both children of a node are
    leaves, the node is collapsed to the average of the two leaves if that
    lowers the squared error on the test rows that reach this node.

    Args:
        tree: tree dict produced by createTree (numeric regression leaves).
        testData: NumPy matrix of test rows; the last column is the target.

    Returns:
        The (in-place modified) tree dict, or a single leaf value when the
        node was merged or no test rows reached it.
    """
    # No test rows reach this node: collapse the whole subtree.
    if shape(testData)[0] == 0:
        return getMean(tree)
    # Route the test rows exactly as the tree does. This node's split never
    # changes while pruning, so a single split serves both the recursion and
    # the merge test below (the original computed it twice).
    lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    # Recursively prune every child that is itself a subtree, using only the
    # test rows routed to that child.
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rSet)
    # A node with a subtree child cannot be merged.
    if isTree(tree['left']) or isTree(tree['right']):
        return tree
    # Both children are leaves: compare the test error with and without merging.
    errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) +\
        sum(power(rSet[:,-1] - tree['right'],2))
    treeMean = (tree['left']+tree['right'])/2.0
    errorMerge = sum(power(testData[:,-1] - treeMean,2))
    # Merge (prune) only when it strictly lowers the test error.
    if errorMerge < errorNoMerge:
        print("merging")
        return treeMean
    return tree

def testPrune():
    """Grow a large tree on ./data/ex2.txt, then post-prune it with ./data/ex2test.txt."""
    trainMat = mat(loadDataSet(r'./data/ex2.txt'))
    # ops=(0, 1) makes the stop conditions as loose as possible, so the tree grows large.
    fullTree = createTree(trainMat, ops=(0, 1))
    testMat = mat(loadDataSet(r'./data/ex2test.txt'))
    print(prune(fullTree, testMat))

testPrune()

# 从结果可以看出没有像预期剪成两部分，这时效果不如预剪枝
# 因此为了逼近最佳效果，可以同时使用两种剪枝方法

merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
{'spInd': 0, 'spVal': 0.499171, 'left': {'spInd': 0, 'spVal': 0.729397, 'left': {'spInd': 0, 'spVal': 0.952833, 'left': {'spInd': 0, 'spVal': 0.965969, 'left': 92.5239915, 'right': {'spInd': 0, 'spVal': 0.956951, 'left': {'spInd': 0, 'spVal': 0.958512, 'left': {'spInd': 0, 'spVal': 0.960398, 'left': 112.386764, 'right': 123.559747}, 'right': 135.837013}, 'right': 111.2013225}}, 'right': {'spInd': 0, 'spVal': 0.759504, 'left': {'spInd': 0, 'spVal': 0.763328, 'left': {'spInd': 0, 'spVal': 0.769043, 'left': {'spInd': 0, 'spVal': 0.790312, 'left': {'spInd': 0, 'spVal': 0.806158, 'left': {'spInd': 0, 'spVal': 0.815215, 'left': {'spInd': 0, 'spVa

# 一种新型树结构
### 用树来对数据建模，除了将叶节点简单设置为数值之外，还可以把叶节点设置为分段线性函数
### 分段线性：模型由多个线性片段组成
### 例如将数据集切分成两个数据集，分别对其建立线性模型
### 图示可以看出：0~0.3为一种线性组合、其他的是另外一个线性组合

In [18]:


import matplotlib.pyplot as plt
def multiLinePlot():
    """Scatter-plot the two columns of ./data/exp2.txt."""
    samples = mat(loadDataSet(r'./data/exp2.txt'))
    figure = plt.figure()
    axes = figure.add_subplot(111)
    plt.scatter(samples[:, 0].tolist(), samples[:, 1].tolist(), color="b")
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.show()

multiLinePlot()

# 决策树相比于其他机器学习模型，结果更易于理解；而模型树的可解释性是它优于回归树的特点之一
### 对于模型树，先使用线性模型对数据进行拟合，然后计算真实目标值与预测值之间的差值，将这些差值的平方求和即为误差

In [19]:


# 构建线性拟合函数
# 模型树的叶节点生成函数
def linearSolve(dataSet):
    """Ordinary least-squares fit of the last column on the other columns.

    A constant column of ones is prepended, so ws[0] is the intercept. This
    helper backs both modelLeaf and modelErr.

    Args:
        dataSet: NumPy matrix (m x n); columns 0..n-2 are features, column
            n-1 is the target.

    Returns:
        (ws, X, Y): the n x 1 coefficient matrix, the m x n design matrix
        with the leading ones column, and the m x 1 target column.

    Raises:
        NameError: if X.T * X is singular. This unusual exception type is
            kept from the original so existing callers are unaffected.
    """
    m, n = shape(dataSet)
    # Design matrix: column 0 is the constant 1 for the intercept term.
    # asmatrix is what np.mat aliased (np.mat was removed in NumPy 2.0).
    X = asmatrix(ones((m, n)))
    X[:, 1:n] = dataSet[:, 0:n-1]
    # Take the target column directly. The original first built a ones()
    # matrix for Y and immediately overwrote it.
    Y = dataSet[:, -1]
    xTx = X.T * X
    # A zero determinant means the normal equations have no unique solution.
    if linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse,\n\
        try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y
# 当数据不再需要切分时，负责生成叶节点模型
def modelLeaf(dataSet):
    """Leaf builder for model trees: the linear-regression coefficient matrix."""
    return linearSolve(dataSet)[0]

# 计算平方误差
def modelErr(dataSet):
    """Sum of squared residuals of a linear model fitted to ``dataSet`` itself."""
    ws, X, Y = linearSolve(dataSet)
    residuals = Y - X * ws
    return sum(power(residuals, 2))


# 测试
def teatLinear():
    """Fit a model tree on ./data/exp2.txt; return (split value, left leaf, right leaf).

    NOTE(review): assumes the fitted root's children are leaves (coefficient
    matrices), as linearPlot expects -- confirm for the data/ops used.
    """
    samples = mat(loadDataSet(r'./data/exp2.txt'))
    modelTree = createTree(samples, modelLeaf, modelErr, ops=(1, 10))
    return tuple(modelTree[key] for key in ('spVal', 'left', 'right'))

# 查看线性拟合情况
### 红色线为模型树（分段线性）的拟合结果
### 绿色线为对全部数据直接做一次线性回归得到的拟合直线
def linearPlot():
    """Plot ./data/exp2.txt with a single linear fit and the model-tree fit.

    Green: one linear regression over all points (linearSolve).
    Red: the two linear leaf models returned by teatLinear(), each drawn
    over the points on its side of the split value.
    """
    myMat=mat(loadDataSet(r'./data/exp2.txt'))
    fig=plt.figure()
    ax=fig.add_subplot(111)
    plt.scatter(myMat[:,0].tolist(),myMat[:,1].tolist(),color="b")
    plt.xlabel('X')
    plt.ylabel('Y')
    # Single linear regression fitted to the whole data set.
    ws,X,Y = linearSolve(myMat)
    yHat = X * ws
    # NOTE(review): plt.plot joins points in file order; if X is not sorted
    # in exp2.txt the line will zig-zag.
    plt.plot(myMat[:,0].tolist(),yHat,color="green")
    
    # Model-tree fit; limit is the split value of the root node.
    limit,left,right= teatLinear()
    
    # Partition the points with the tree's convention: x <= limit goes to the
    # right branch, x > limit to the left branch. Each list is [xs, ys].
    leftData=[[],[]]
    rightData=[[],[]]
    for i in range(myMat.shape[0]):
        if myMat[i,0]<=limit:
            rightData[0].append(myMat[i,0])
            rightData[1].append(myMat[i,1])
        else:
            leftData[0].append(myMat[i,0])
            leftData[1].append(myMat[i,1])
    leftData=mat(leftData)
    rightData=mat(rightData)

    # Evaluate each leaf model (intercept + slope * x) on its own points.
    leftHat=mat(left[0,0]+left[1,0]*leftData[0,:]) 
    rightHat=mat(right[0,0]+right[1,0]*rightData[0,:])
    
    # Debug output: shape of the left-side x list, printed as (1, k).
    print(shape(leftData[0,:].tolist()))
    plt.plot(leftData[0,:].tolist()[0],leftHat.T,color="red")
    plt.plot(rightData[0,:].tolist()[0],rightHat.T,color="red")
    plt.show()
    
linearPlot()

(1, 143)


In [20]:
# 使用树回归进行预测
# 用于对回归树进行预测
import pandas as pd
from numpy import *
import matplotlib.pyplot as plt


# 加载数据集
def loadDataSet(fileName):
    """Parse a tab-delimited text file of numbers into a list of float rows.

    The last column of each row is assumed to be the target value.

    Args:
        fileName: path of the text file to read.

    Returns:
        list[list[float]]: one inner list per non-blank line.
    """
    dataMat = []
    # ``with`` guarantees the handle is closed; the original never closed it.
    with open(fileName) as fr:
        for line in fr:
            stripped = line.strip()
            if not stripped:
                # Skip blank lines (e.g. an extra newline at EOF) instead of
                # crashing on float('').
                continue
            dataMat.append([float(tok) for tok in stripped.split('\t')])
    return dataMat
    


def regTreeEval(model, inDat):
    """Evaluate a regression-tree leaf: the leaf value itself is the prediction.

    ``inDat`` is unused; it is accepted so the signature matches modelTreeEval.
    """
    prediction = float(model)
    return prediction

# 用于对模型树的叶节点（线性模型）进行预测
def modelTreeEval(model, inDat):
    """Evaluate a model-tree leaf (linear coefficients) on one sample.

    Args:
        model: (n+1) x 1 coefficient matrix as produced by modelLeaf;
            row 0 is the intercept.
        inDat: 1 x n matrix holding one sample's features.

    Returns:
        float: the prediction ``[1, inDat] * model``.
    """
    n = shape(inDat)[1]
    # Prepend the constant 1 that multiplies the intercept coefficient.
    # asmatrix is what np.mat aliased (np.mat was removed in NumPy 2.0).
    X = asmatrix(ones((1, n + 1)))
    X[:, 1:n + 1] = inDat
    # Pull out the single element explicitly: float() on a 1x1 matrix relies
    # on array-to-scalar conversion that NumPy deprecated in 1.25.
    return float((X * model)[0, 0])

def treeForeCast(tree, inData, modelEval=regTreeEval):
    """Predict one sample by walking ``tree`` down to a leaf.

    Args:
        tree: a tree dict from createTree, or a leaf.
        inData: one sample's features (createForeCast passes a 1 x n matrix).
            It is handed unchanged to ``modelEval`` at the leaf.
        modelEval: function(leaf, inData) evaluating a leaf -- regTreeEval
            for regression trees, modelTreeEval for model trees.

    Returns:
        The value ``modelEval`` returns for the leaf the sample reaches.
    """
    if not isTree(tree):
        return modelEval(tree, inData)
    # Flatten so the feature is selected by column index. The original
    # ``inData[spInd]`` picked a *row* of the 1 x n matrix, which only worked
    # when spInd == 0 (single-feature data).
    featureValue = asarray(inData).ravel()[tree['spInd']]
    # Same convention as binSplitDataSet: values > spVal go to the left branch.
    if featureValue > tree['spVal']:
        branch = tree['left']
    else:
        branch = tree['right']
    return treeForeCast(branch, inData, modelEval)
        
def createForeCast(tree, testData, modelEval=regTreeEval):
    """Predict every row of ``testData`` with ``tree``; returns an m x 1 matrix."""
    rowCount = len(testData)
    predictions = mat(zeros((rowCount, 1)))
    for idx in range(rowCount):
        # Each row is wrapped as a matrix before being fed to the tree walker.
        predictions[idx, 0] = treeForeCast(tree, mat(testData[idx]), modelEval)
    return predictions


# 查看原始数据集
def showPlot():
    """Scatter-plot the bicycle-speed vs. IQ training data."""
    figure = plt.figure()
    axes = figure.add_subplot(111)
    samples = mat(loadDataSet(r'data/bikeSpeedVsIq_train.txt'))
    speed, iq = samples[:, 0].tolist(), samples[:, 1].tolist()
    plt.scatter(speed, iq, color="blue")
    plt.xlabel('bicycle speed')
    plt.ylabel('IQ')
    plt.show()

showPlot()

# 测试回归树效果
def testRegTree():
    """Correlation coefficient between regression-tree predictions and test targets."""
    trainMat = mat(loadDataSet(r'data/bikeSpeedVsIq_train.txt'))
    testMat = mat(loadDataSet(r'data/bikeSpeedVsIq_test.txt'))
    tree = createTree(trainMat, ops=(1, 20))
    predictions = createForeCast(tree, testMat[:, 0])
    return corrcoef(predictions, testMat[:, 1], rowvar=0)[0, 1]
    
# 测试模型树效果
def testModelTree():
    """Correlation coefficient between model-tree predictions and test targets."""
    trainMat = mat(loadDataSet(r'data/bikeSpeedVsIq_train.txt'))
    testMat = mat(loadDataSet(r'data/bikeSpeedVsIq_test.txt'))
    tree = createTree(trainMat, modelLeaf, modelErr, (1, 20))
    predictions = createForeCast(tree, testMat[:, 0], modelTreeEval)
    return corrcoef(predictions, testMat[:, 1], rowvar=0)[0, 1]
    
    
# 测试一般回归方法
def testGeneralReg():
    """Correlation coefficient between one global linear fit and test targets."""
    trainMat = mat(loadDataSet(r'data/bikeSpeedVsIq_train.txt'))
    testMat = mat(loadDataSet(r'data/bikeSpeedVsIq_test.txt'))
    ws, X, Y = linearSolve(trainMat)
    intercept, slope = ws[0, 0], ws[1, 0]
    rowCount = shape(testMat)[0]
    yHat = [testMat[i, 0] * slope + intercept for i in range(rowCount)]
    return corrcoef(yHat, testMat[:, 1], rowvar=0)[0, 1]
    
    
def compare():
    """Run the three regression approaches and print their scores.

    NOTE(review): the scores come from corrcoef, i.e. Pearson correlation
    coefficients (R), although the printed label says R平方 (R squared).
    """
    order = ['regTree', 'modelTree', 'generalReg']
    scorers = {'regTree': testRegTree, 'modelTree': testModelTree,
               'generalReg': testGeneralReg}
    res = {name: scorers[name]() for name in order}
    print("结果如下：")
    for name in order:
        print("%s R平方为： %.3f" % (name, res[name]))
compare()
# 通过比较分析，前面两种树回归方法好于直接回归

结果如下：
regTree R平方为： 0.964
modelTree R平方为： 0.976
generalReg R平方为： 0.943


# 使用python构建GUI
# 使用GUI对树回归进行调优

### 示例步骤
##### 搜集数据
##### 准备数据：将离散型数据连续化
##### 分析数据：使用Tkinter创建GUI进行展示
##### 训练算法：训练一棵回归树和模型树，并可视化
##### 测试算法
##### 使用算法：通过GUI更好地进行超参调优
# 构建GUI

### tkinter 的GUI是由一些小部件组成的，例如文本框(TextBox)、按钮(Button)、
### 标签(Label)、复选框(Check Button)等对象，grid是一种布局方式，将内容放在二维表格中
### 常见python GUI参看:[GUI](https://blog.csdn.net/qxyloveyy/category_9706023.html)

In [30]:
from tkinter import *

def simpleWindow():
    """Open a minimal Tk window containing a single label."""
    window = Tk()
    Label(window, text="hello,world!").grid()
    window.mainloop()


from numpy import *

# 输入参数
def reDraw(tolS,tolN):
    """Placeholder redraw callback taking (tolS, tolN); intentionally does nothing yet."""

def drawNewTree():
    """Placeholder ReDraw-button callback; intentionally does nothing yet."""


# 加载数据集
def loadDataSet(fileName):
    """Parse a tab-delimited text file of numbers into a list of float rows.

    The last column of each row is assumed to be the target value.

    Args:
        fileName: path of the text file to read.

    Returns:
        list[list[float]]: one inner list per non-blank line.
    """
    dataMat = []
    # ``with`` guarantees the handle is closed; the original never closed it.
    with open(fileName) as fr:
        for line in fr:
            stripped = line.strip()
            if not stripped:
                # Skip blank lines (e.g. an extra newline at EOF) instead of
                # crashing on float('').
                continue
            dataMat.append([float(tok) for tok in stripped.split('\t')])
    return dataMat
    


# Create the window: a placeholder where the plot will go, tolN/tolS entry
# boxes, a ReDraw button and a "Model Tree" checkbox, laid out with grid().
# In this cell reDraw and drawNewTree are the do-nothing placeholders defined
# above, so nothing is actually plotted yet.
root=Tk()
Label(root,text="Plot Place Holder").grid(row=0,columnspan=3)
Label(root,text="tolN").grid(row=1,column=0)
tolNetry=Entry(root)
tolNetry.grid(row=1,column=1)
tolNetry.insert(0,'10')
Label(root,text="tolS").grid(row=2,column=0)
tolSentry=Entry(root)
tolSentry.grid(row=2,column=1)
tolSentry.insert(0,'1.0')
# NOTE(review): ReDraw spans rows 0-2 of column 2, and the Quit button added
# below is also placed at row 2, column 2, so the two widgets overlap.
Button(root,text='ReDraw',command=drawNewTree).grid(row=0,column=2,rowspan=3)

cnkBtnVar=IntVar()
cnkBtn=Checkbutton(root,text="Model Tree",variable=cnkBtnVar)
cnkBtn.grid(row=3,column=0,columnspan=2)

# Data and prediction grid are stored as attributes on the reDraw function.
reDraw.rawDat=mat(loadDataSet(r'data/sine.txt'))
tempMat=reDraw.rawDat
# x values from min to max of column 0 in steps of 0.01, used for predictions.
reDraw.testDat=arange(min(tempMat[:,0]),max(tempMat[:,0]),0.01)
reDraw(1.0,10)

# Add a Quit button that stops the Tk main loop.
Button(root,text="Quit",fg="black",command=root.quit).grid(row=2,column=2)

root.mainloop()





In [31]:
# 对matplotlib与tkinter进行集成展示
# 进行集成展示
from numpy import *

from tkinter import *

import matplotlib

matplotlib.use('TkAgg')
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure

def reDraw(tolS,tolN):
    """Rebuild a tree from reDraw.rawDat with (tolS, tolN) and redraw the canvas.

    Builds a model tree when the "Model Tree" checkbox (chkBtnVar) is ticked,
    otherwise a regression tree. Predictions over reDraw.testDat are plotted
    on top of a scatter of the raw data.
    """
    reDraw.f.clf()  # wipe the previous figure
    reDraw.a = reDraw.f.add_subplot(111)
    modelTreeMode = chkBtnVar.get()
    if not modelTreeMode:
        myTree = createTree(reDraw.rawDat, ops=(tolS, tolN))
        yHat = createForeCast(myTree, reDraw.testDat)
    else:
        # Model-tree leaves are straight-line fits, which need at least two
        # points, so require tolN >= 2.
        tolN = 2 if tolN < 2 else tolN
        myTree = createTree(reDraw.rawDat, modelLeaf, modelErr, (tolS, tolN))
        yHat = createForeCast(myTree, reDraw.testDat, modelTreeEval)
    rawX = reDraw.rawDat[:, 0].tolist()
    rawY = reDraw.rawDat[:, 1].tolist()
    reDraw.a.scatter(rawX, rawY, s=5)  # raw data points
    reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0)  # fitted curve
    reDraw.canvas.draw()
    
def getInputs():
    """Read tolN (int) and tolS (float) from the GUI entry boxes.

    When an entry does not parse, a hint is printed. The default ('10' for
    tolN, '1.0' for tolS) is then used and written back into the entry box.

    Returns:
        (tolN, tolS) tuple.
    """
    # Catch only ValueError (bad number text). The original bare ``except:``
    # would also have swallowed unrelated errors such as KeyboardInterrupt.
    try:
        tolN = int(tolNentry.get())
    except ValueError:
        tolN = 10
        print("enter Integer for tolN")
        tolNentry.delete(0, END)
        tolNentry.insert(0,'10')
    try:
        tolS = float(tolSentry.get())
    except ValueError:
        tolS = 1.0
        print("enter Float for tolS")
        tolSentry.delete(0, END)
        tolSentry.insert(0,'1.0')
    return tolN,tolS

def drawNewTree():
    """ReDraw-button callback: read tolN/tolS from the entry boxes and replot."""
    inputs = getInputs()
    tolN, tolS = inputs
    reDraw(tolS, tolN)
    
# Build the interactive window: a matplotlib canvas embedded via TkAgg, the
# tolN/tolS entry boxes, a ReDraw button and a "Model Tree" checkbox.
root=Tk()

# Figure and canvas are stored as attributes on reDraw, which redraws into them.
reDraw.f = Figure(figsize=(5,4), dpi=100) #create canvas
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
reDraw.canvas.draw()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)

# Entry boxes read by getInputs(); they start at the defaults tolN=10, tolS=1.0.
Label(root, text="tolN").grid(row=1, column=0)
tolNentry = Entry(root)
tolNentry.grid(row=1, column=1)
tolNentry.insert(0,'10')
Label(root, text="tolS").grid(row=2, column=0)
tolSentry = Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0,'1.0')
Button(root, text="ReDraw", command=drawNewTree).grid(row=1, column=2, rowspan=3)
# Checkbox state read by reDraw(): ticked -> model tree, unticked -> regression tree.
chkBtnVar = IntVar()
chkBtn = Checkbutton(root, text="Model Tree", variable = chkBtnVar)
chkBtn.grid(row=3, column=0, columnspan=2)

# Training data, and x values (min to max of column 0, step 0.01) for predictions.
reDraw.rawDat = mat(loadDataSet(r'data/sine.txt'))
reDraw.testDat = arange(min(reDraw.rawDat[:,0]),max(reDraw.rawDat[:,0]),0.01)
# Initial plot with the default parameters.
reDraw(1.0, 10)

root.mainloop()



### 本章小结
##### 数据集中经常包含一些复杂的相关关系，使得输入数据与目标变量之间呈现非线性关系
##### 对于这种复杂的关系可以利用树来对预测值进行分段，包括分段常数或者分段直线
##### 若叶节点使用的模型是分段常数就是回归树，如果是线性回归方程就称之为模型树
##### CART算法可以构建二元树，并处理离散型或者连续型数据的切分
##### 预剪枝：在树的构建过程中就进行剪枝，需要用户定义一些参数
##### 后剪枝：在树构建完毕之后进行剪枝