In [1]:
import numpy as np
import tkinter as tk
import matplotlib
matplotlib.use('TkAgg')
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure

In [2]:
def loadDataSet(fileName):
    """Read a tab-separated text file of numbers into a list of float rows.

    Args:
        fileName: Path of the file to read. Each line is stripped, split on
            tab characters, and every field is converted with ``float``.

    Returns:
        A list with one inner list of floats per non-blank line.

    Raises:
        ValueError: If a field cannot be converted to ``float``.
    """
    dataMat = []
    # The context manager closes the handle even when float() raises.
    with open(fileName) as fr:
        for line in fr:
            stripped = line.strip()
            # Skip blank lines (e.g. a trailing newline at end of file),
            # which would otherwise fail on float('').
            if not stripped:
                continue
            dataMat.append([float(field) for field in stripped.split('\t')])
    return dataMat

In [3]:
def binSplitDataSet(dataSet, feature, value):
    """Partition the rows of ``dataSet`` by thresholding one column.

    Args:
        dataSet: 2-D array or matrix indexed as ``dataSet[:, feature]``.
        feature: Column index to compare.
        value: Threshold.

    Returns:
        A pair ``(mat0, mat1)``: ``mat0`` holds the rows whose ``feature``
        column is strictly greater than ``value``; ``mat1`` holds the rows
        where it is less than or equal to ``value``.
    """
    column = dataSet[:, feature]
    rowsAbove = np.nonzero(column > value)[0]
    rowsAtOrBelow = np.nonzero(column <= value)[0]
    return dataSet[rowsAbove, :], dataSet[rowsAtOrBelow, :]

### 树节点

In [4]:
class treeNode():
    """Plain record describing one split node of a binary tree.

    Attributes are set from the constructor arguments:
        featureToSplitOn: value passed as ``feat``.
        valueOfSplit: value passed as ``val``.
        rightBranch: value passed as ``right``.
        leftBranch: value passed as ``left``.
    """

    def __init__(self, feat, val, right, left):
        # The previous version bound these to locals (and read an undefined
        # name), so no instance could be created; store them on self.
        self.featureToSplitOn = feat
        self.valueOfSplit = val
        self.rightBranch = right
        self.leftBranch = left

创建树的伪代码：  
找到最佳的待切分特征：  
&ensp;&ensp;&ensp;&ensp;如果该节点不能再分，则存为叶结点  
&ensp;&ensp;&ensp;&ensp;执行二元切分  
&ensp;&ensp;&ensp;&ensp;在右子树调用createTree()方法  
&ensp;&ensp;&ensp;&ensp;在左子树调用createTree()方法

In [5]:
def regLeaf(dataSet):
    """Build a regression-tree leaf: the mean of the last column of ``dataSet``."""
    targets = dataSet[:, -1]
    return np.mean(targets)

def regErr(dataSet):
    """Total squared error of the last column around its mean.

    Computed as the variance of the last column multiplied by the number
    of rows in ``dataSet``.
    """
    rowCount = np.shape(dataSet)[0]
    targetVariance = np.var(dataSet[:, -1])
    return targetVariance * rowCount

def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Recursively build a binary tree by repeatedly splitting ``dataSet``.

    Args:
        dataSet: Data handed to ``chooseBestSplit`` and ``binSplitDataSet``.
        leafType: Callable that builds a leaf value from a data subset.
        errType: Callable that scores a data subset.
        ops: Tuple of split-control options forwarded unchanged to
            ``chooseBestSplit``.

    Returns:
        The leaf value returned by ``chooseBestSplit`` when it reports no
        split (feature is ``None``); otherwise a dict with keys ``'spInd'``
        (split feature), ``'spVal'`` (split value), ``'left'`` and
        ``'right'`` (sub-trees built the same way).
    """
    # Pick the split; a None feature means this node becomes a leaf.
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat is None:
        return val
    # binSplitDataSet returns (rows > val, rows <= val); the first part is
    # stored under 'left', the second under 'right'.
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    return {
        'spInd': feat,
        'spVal': val,
        'left': createTree(lSet, leafType, errType, ops),
        'right': createTree(rSet, leafType, errType, ops),
    }

### 回归树的切分函数

In [6]:
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Search every feature/value pair for the split with the lowest error.

    Args:
        dataSet: 2-D data whose last column is the target. The ``.A``
            attribute is used on column slices, so an ``np.matrix`` is
            expected.
        leafType: Callable building a leaf value from a data subset.
        errType: Callable scoring a data subset.
        ops: ``ops[0]`` is the minimum error reduction required to split;
            ``ops[1]`` is the minimum number of rows allowed on each side.

    Returns:
        ``(featureIndex, splitValue)`` for the best split, or
        ``(None, leafType(dataSet))`` when no split should be made.
    """
    minErrDrop, minRows = ops[0], ops[1]

    # All target values identical: nothing to gain from splitting.
    distinctTargets = set(dataSet[:, -1].T.tolist()[0])
    if len(distinctTargets) == 1:
        return None, leafType(dataSet)

    _, columnCount = np.shape(dataSet)
    baseErr = errType(dataSet)
    lowestErr, chosenFeat, chosenVal = np.inf, 0, 0

    for col in range(columnCount - 1):
        candidates = set(dataSet[:, col].T.A.tolist()[0])
        for threshold in candidates:
            above, below = binSplitDataSet(dataSet, col, threshold)
            # Only consider splits that leave enough rows on both sides.
            if np.shape(above)[0] >= minRows and np.shape(below)[0] >= minRows:
                splitErr = errType(above) + errType(below)
                if splitErr < lowestErr:
                    chosenFeat, chosenVal, lowestErr = col, threshold, splitErr

    # Improvement too small (also covers the case where no candidate
    # qualified and lowestErr is still infinite).
    if (baseErr - lowestErr) < minErrDrop:
        return None, leafType(dataSet)

    # Re-check the size of the chosen partitions before accepting the split.
    above, below = binSplitDataSet(dataSet, chosenFeat, chosenVal)
    if np.shape(above)[0] < minRows or np.shape(below)[0] < minRows:
        return None, leafType(dataSet)
    return chosenFeat, chosenVal

In [7]:
def isTree(obj):
    """Return True when ``obj`` is a tree node (a dict), False for a leaf.

    ``isinstance`` is used instead of comparing the type's name so that
    dict subclasses count as trees and unrelated classes that merely
    happen to be called ``dict`` do not.
    """
    return isinstance(obj, dict)

## 模型树  
可解释性好，预测准确度更高

模型树叶结点生成函数

In [8]:
def linearSolve(dataSet):
    """Fit ordinary least squares to ``dataSet`` via the normal equations.

    The last column of ``dataSet`` is the target ``Y``; the remaining
    columns, preceded by a column of ones for the intercept, form ``X``.

    Args:
        dataSet: 2-D data (``np.matrix`` or anything ``np.asmatrix``
            accepts) with at least one feature column and a target column.

    Returns:
        ``(ws, X, Y)`` where ``ws`` is the coefficient column vector
        (intercept first), ``X`` the design matrix and ``Y`` the target
        column, all as ``np.matrix``.

    Raises:
        NameError: If ``X.T * X`` has a determinant of exactly 0.0 and
            therefore cannot be inverted.
    """
    # np.mat was removed in NumPy 2.0; np.asmatrix is the equivalent that
    # exists in every version. Coercing the input also lets plain ndarrays
    # work with the matrix products below.
    dataSet = np.asmatrix(dataSet)
    m, n = np.shape(dataSet)
    X = np.asmatrix(np.ones((m, n)))
    # Column 0 stays all ones (intercept); the features fill columns 1..n-1.
    X[:, 1:n] = dataSet[:, 0:n-1]
    Y = dataSet[:, -1]
    xTx = X.T * X
    if np.linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse, try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y

In [9]:
def modelLeaf(dataSet):
    """Build a model-tree leaf: the coefficients ``linearSolve`` fits to ``dataSet``.

    Used when a node is not split any further.
    """
    coefficients, _, _ = linearSolve(dataSet)
    return coefficients

In [10]:
def modelErr(dataSet):
    """Squared error of the linear fit that ``linearSolve`` finds for ``dataSet``.

    The residuals ``Y - X * ws`` are squared element-wise and added up with
    the builtin ``sum``, which iterates over the first axis of the result.
    """
    ws, X, Y = linearSolve(dataSet)
    residuals = Y - X * ws
    squared = np.power(residuals, 2)
    return sum(squared)

### 用树回归进行预测

In [11]:
def regTreeEval(model, inDat):
    """Predict from a regression-tree leaf.

    Args:
        model: The leaf value; converted with ``float``.
        inDat: Ignored; present so the signature matches ``modelTreeEval``.

    Returns:
        ``float(model)``.
    """
    prediction = float(model)
    return prediction

In [12]:
def modelTreeEval(model, inDat):
    """Predict from a model-tree leaf by applying its linear coefficients.

    Args:
        model: Coefficient column vector of shape ``(n + 1, 1)``, intercept
            first.
        inDat: One sample as a ``1 x n`` row (its second dimension gives the
            number of features).

    Returns:
        The prediction ``[1, inDat] * model`` as a Python float.
    """
    n = np.shape(inDat)[1]
    # np.mat was removed in NumPy 2.0; np.asmatrix is the equivalent.
    X = np.asmatrix(np.ones((1, n + 1)))
    # Column 0 stays 1.0 for the intercept term.
    X[:, 1:n + 1] = inDat
    # Extract the single element explicitly: calling float() directly on a
    # 1x1 matrix is deprecated in recent NumPy releases.
    return float((X * model)[0, 0])

In [13]:
def treeForeCast(tree, inData, modelEval=regTreeEval):
    """Predict one sample by walking ``tree`` down to a leaf.

    Args:
        tree: A tree dict with keys ``'spInd'``, ``'spVal'``, ``'left'`` and
            ``'right'``, or a bare leaf value.
        inData: One sample: a ``1 x n`` matrix, a 1-D array or a sequence of
            feature values.
        modelEval: Callable ``(leaf, inData)`` that turns a leaf into a
            prediction.

    Returns:
        Whatever ``modelEval`` returns for the leaf that is reached.
    """
    if not isTree(tree):
        return modelEval(tree, inData)
    # Flatten before indexing: on a 1 x n matrix, inData[spInd] would select
    # a row (failing for spInd > 0 and giving an ambiguous truth value when
    # n > 1) instead of the feature at column spInd.
    featureValue = np.asarray(inData).ravel()[tree['spInd']]
    # Values above the split follow 'left', the rest follow 'right'.
    branch = tree['left'] if featureValue > tree['spVal'] else tree['right']
    # The recursive call handles both sub-trees and leaves.
    return treeForeCast(branch, inData, modelEval)

In [14]:
def creatForeCast(tree, testData, modelEval=regTreeEval):
    """Predict every sample in ``testData`` with ``tree``.

    Args:
        tree: Tree (or leaf) accepted by ``treeForeCast``.
        testData: Sequence of samples; ``len(testData)`` predictions are
            made and each item is converted to a matrix before evaluation.
        modelEval: Leaf evaluator forwarded to ``treeForeCast``.

    Returns:
        An ``m x 1`` matrix of predictions, one row per sample.
    """
    m = len(testData)
    # np.mat was removed in NumPy 2.0; np.asmatrix is the equivalent.
    yHat = np.asmatrix(np.zeros((m, 1)))
    for i in range(m):
        yHat[i, 0] = treeForeCast(tree, np.asmatrix(testData[i]), modelEval)
    return yHat

## 使用 Python 的 Tkinter 库创建 GUI

In [15]:
def reDraw(tolS, tolN):
    """Rebuild the tree with the given tolerances and redraw the plot.

    Relies on attributes attached to this function before the first call:
    ``reDraw.f`` (figure), ``reDraw.canvas`` (Tk canvas), ``reDraw.rawDat``
    (training data) and ``reDraw.testDat`` (x values to predict), and on the
    module-level ``chkBtnVar`` that selects a model tree when set.

    Args:
        tolS: Minimum error reduction, passed as ``ops[0]`` to ``createTree``.
        tolN: Minimum rows per split, passed as ``ops[1]``; raised to 2 when
            a model tree is requested.
    """
    reDraw.f.clf()
    reDraw.a = reDraw.f.add_subplot(111)
    if chkBtnVar.get():
        # Model tree: enforce at least 2 rows per side before fitting.
        if tolN < 2:
            tolN = 2
        myTree = createTree(reDraw.rawDat, modelLeaf, modelErr, (tolS, tolN))
        yHat = creatForeCast(myTree, reDraw.testDat, modelTreeEval)
    else:
        myTree = createTree(reDraw.rawDat, ops = (tolS, tolN))
        yHat = creatForeCast(myTree, reDraw.testDat)
    reDraw.a.scatter(reDraw.rawDat[:, 0].tolist(), reDraw.rawDat[:, 1].tolist(), s = 5)
    reDraw.a.plot(reDraw.testDat, yHat, linewidth = 2.0)
    # FigureCanvasTkAgg.show() was removed in Matplotlib 3.0; draw() is the
    # supported way to render the figure onto the canvas.
    reDraw.canvas.draw()

In [16]:
def getInputs():
    """Read ``tolN`` and ``tolS`` from the module-level Tk entry widgets.

    If an entry cannot be parsed, a message is printed, the entry is reset
    to its default text and the default value is used instead.

    Returns:
        ``(tolN, tolS)``: ``tolN`` as an int (default 10) and ``tolS`` as a
        float (default 1.0).
    """
    try:
        tolN = int(tolNentry.get())
    except ValueError:
        tolN = 10
        print('enter Integer for tolN')
        # tkinter is imported as ``tk``; a bare END is not defined here.
        tolNentry.delete(0, tk.END)
        tolNentry.insert(0, '10')
    try:
        # tolS is a real-valued tolerance; int() would reject the entry's
        # own default text '1.0'.
        tolS = float(tolSentry.get())
    except ValueError:
        tolS = 1.0
        print('enter Float for tolS')
        tolSentry.delete(0, tk.END)
        tolSentry.insert(0, '1.0')
    return tolN, tolS

In [17]:
def drawNewTree():
    """Button callback: read the current inputs and redraw the tree plot."""
    minRows, minErrDrop = getInputs()
    # getInputs returns (tolN, tolS) while reDraw takes (tolS, tolN).
    reDraw(minErrDrop, minRows)

In [None]:
# Build the Tk window: a Matplotlib canvas on row 0 and the controls below it.
root = tk.Tk()

reDraw.f = Figure(figsize = (5, 4), dpi = 100)
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master = root)
# FigureCanvasTkAgg.show() was removed in Matplotlib 3.0; use draw().
reDraw.canvas.draw()
reDraw.canvas.get_tk_widget().grid(row = 0, columnspan = 3)

# tk.Label(root, text = 'Plot Place Holder').grid(row = 0, columnspan = 3)
tk.Label(root, text = 'tolN').grid(row = 1, column = 0)
tolNentry = tk.Entry(root)
tolNentry.grid(row = 1, column = 1)
tolNentry.insert(0, '10')
tk.Label(root, text = 'tolS').grid(row = 2, column = 0)
tolSentry = tk.Entry(root)
tolSentry.grid(row = 2, column = 1)
tolSentry.insert(0, '1.0')
tk.Button(root, text = 'ReDraw', command = drawNewTree).grid(row = 1, column = 2, rowspan = 3)
chkBtnVar = tk.IntVar()
chkBtn = tk.Checkbutton(root, text = 'Model Tree', variable = chkBtnVar)
chkBtn.grid(row = 3, column = 0, columnspan =2)

# np.mat was removed in NumPy 2.0; np.asmatrix is the equivalent.
reDraw.rawDat = np.asmatrix(loadDataSet('sine.txt'))
# Use the matrix's own min()/max() so np.arange receives scalars; the builtin
# min()/max() iterate over rows and return 1x1 matrices.
reDraw.testDat = np.arange(reDraw.rawDat[:, 0].min(), reDraw.rawDat[:, 0].max(), 0.01)
reDraw(1.0, 10)

root.mainloop()

  """
  
