# 9.1 复杂数据的局部性建模

# 9.2 连续和离散型特征的树的构建

**程序清单9-1** CART算法的实现代码

In [1]:
from numpy import *

In [2]:
def load_data_set(file_name):
    """Read a tab-separated text file of numbers into a list of rows.

    Args:
        file_name: Path of the text file to read.

    Returns:
        A list with one entry per non-blank line; each entry is the list of
        that line's tab-separated fields converted with ``float``.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a field cannot be converted to ``float``.
    """
    data_mat = []
    # The context manager closes the file even if a field fails to parse.
    with open(file_name) as fr:
        for line in fr:
            stripped = line.strip()
            if not stripped:
                # Blank lines (e.g. a trailing newline) carry no data and
                # would otherwise make float('') raise.
                continue
            data_mat.append([float(field) for field in stripped.split('\t')])
    return data_mat

In [3]:
def bin_split_data_set(data_set, feature, value):
    """Partition the rows of ``data_set`` by thresholding one column.

    Args:
        data_set: 2-D NumPy matrix/array whose rows are samples.
        feature: Index of the column to compare.
        value: Threshold to compare the column against.

    Returns:
        A pair ``(above, at_or_below)``: the rows whose ``feature`` column is
        strictly greater than ``value``, and the rows where it is less than
        or equal to ``value``.
    """
    column = data_set[:, feature]
    # nonzero(...)[0] yields the row indices where the comparison holds.
    above_rows = nonzero(column > value)[0]
    at_or_below_rows = nonzero(column <= value)[0]
    return data_set[above_rows, :], data_set[at_or_below_rows, :]

In [4]:
test_mat = mat(eye(4))

In [5]:
test_mat

matrix([[ 1.,  0.,  0.,  0.],
        [ 0.,  1.,  0.,  0.],
        [ 0.,  0.,  1.,  0.],
        [ 0.,  0.,  0.,  1.]])

In [6]:
mat0, mat1 = bin_split_data_set(test_mat, 1, 0.5)

In [7]:
mat0

matrix([[ 0.,  1.,  0.,  0.]])

In [8]:
mat1

matrix([[ 1.,  0.,  0.,  0.],
        [ 0.,  0.,  1.,  0.],
        [ 0.,  0.,  0.,  1.]])

# 9.3 将CART算法用于回归

## 9.3.1 构建树

**程序清单9-2** 回归树的切分函数

In [9]:
def reg_leaf(data_set):
    """Return the leaf value of a regression tree: the mean of the last column."""
    targets = data_set[:, -1]
    return mean(targets)
def reg_err(data_set):
    """Return the total squared error of the last column around its mean.

    Computed as the variance of the last column times the number of rows.
    """
    n_rows = shape(data_set)[0]
    targets = data_set[:, -1]
    return var(targets) * n_rows
def choose_best_split(data_set, leaf_type=reg_leaf, err_type=reg_err, ops=(1,4)):
    """Find the feature/value split that minimises the summed error.

    Args:
        data_set: NumPy matrix (``.A`` is used on its columns); the last
            column holds the target values.
        leaf_type: Callable building a leaf model from a data set.
        err_type: Callable returning the error of a data set.
        ops: ``(tolS, tolN)`` - the minimum error reduction required for a
            split, and the minimum number of rows allowed on each side.

    Returns:
        ``(None, leaf)`` when no split should be made, where ``leaf`` is
        ``leaf_type(data_set)``; otherwise ``(feature_index, split_value)``.
    """
    min_gain = ops[0]
    min_rows = ops[1]

    # All targets identical: nothing to gain from splitting.
    targets = data_set[:, -1].T.tolist()[0]
    if len(set(targets)) == 1:
        return None, leaf_type(data_set)

    n_cols = shape(data_set)[1]
    base_err = err_type(data_set)
    best_err, best_feat, best_val = inf, 0, 0

    for feat in range(n_cols - 1):
        candidates = set(data_set[:, feat].T.A.tolist()[0])
        for threshold in candidates:
            upper, lower = bin_split_data_set(data_set, feat, threshold)
            # Only consider splits that leave enough rows on both sides.
            if shape(upper)[0] >= min_rows and shape(lower)[0] >= min_rows:
                split_err = err_type(upper) + err_type(lower)
                if split_err < best_err:
                    best_err, best_feat, best_val = split_err, feat, threshold

    # The improvement is too small to justify a split.
    if (base_err - best_err) < min_gain:
        return None, leaf_type(data_set)

    upper, lower = bin_split_data_set(data_set, best_feat, best_val)
    if shape(upper)[0] < min_rows or shape(lower)[0] < min_rows:
        return None, leaf_type(data_set)
    return best_feat, best_val

In [10]:
def create_tree(data_set, leaf_type=reg_leaf, err_type=reg_err, ops=(1,4)):
    """Recursively build a CART tree from ``data_set``.

    Args:
        data_set: Data passed through to ``choose_best_split`` and
            ``bin_split_data_set``.
        leaf_type: Callable building a leaf model from a data set.
        err_type: Callable returning the error of a data set.
        ops: Stopping parameters forwarded to ``choose_best_split``.

    Returns:
        The leaf model when ``choose_best_split`` reports no split;
        otherwise a dict with keys ``'sp_ind'`` (split feature index),
        ``'sp_val'`` (split value), ``'left'`` and ``'right'`` (subtrees
        built from the first and second part returned by
        ``bin_split_data_set``).
    """
    feat, val = choose_best_split(data_set, leaf_type, err_type, ops)
    # Identity test: feature index 0 is a valid split and must not be
    # confused with "no split".
    if feat is None:
        return val
    l_set, r_set = bin_split_data_set(data_set, feat, val)
    return {
        'sp_ind': feat,
        'sp_val': val,
        'left': create_tree(l_set, leaf_type, err_type, ops),
        'right': create_tree(r_set, leaf_type, err_type, ops),
    }

## 9.3.2 运行代码

In [11]:
my_dat = load_data_set('ex00.txt')
my_mat = mat(my_dat)

In [12]:
create_tree(my_mat)

{'left': 1.0180967672413792,
 'right': -0.044650285714285719,
 'sp_ind': 0,
 'sp_val': 0.48813}

In [13]:
my_dat1 = load_data_set('ex0.txt')
my_mat1 = mat(my_dat1)

In [14]:
create_tree(my_mat1)

{'left': {'left': {'left': 3.9871631999999999,
   'right': 2.9836209534883724,
   'sp_ind': 1,
   'sp_val': 0.797583},
  'right': 1.980035071428571,
  'sp_ind': 1,
  'sp_val': 0.582002},
 'right': {'left': 1.0289583666666666,
  'right': -0.023838155555555553,
  'sp_ind': 1,
  'sp_val': 0.197834},
 'sp_ind': 1,
 'sp_val': 0.39435}

# 9.4 树剪枝

## 9.4.1 预剪枝

In [15]:
create_tree(my_mat, ops=(0,1))

{'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 1.035533,
                'right': 1.077553,
                'sp_ind': 0,
                'sp_val': 0.993349},
               'right': {'left': 0.74420699999999995,
                'right': 1.069062,
                'sp_ind': 0,
                'sp_val': 0.988852},
               'sp_ind': 0,
               'sp_val': 0.989888},
              'right': 1.227946,
              'sp_ind': 0,
              'sp_val': 0.985425},
             'right': {'left': {'left': 0.86291099999999998,
               'right': 0.67357900000000004,
               'sp_ind': 0,
               'sp_val': 0.975022},
              'right': {'left': {'left': 1.0646899999999999,
                'right': {'left': 0.94525499999999996,
                 'right': 1.0229060000000001,
                 'sp_ind': 0,
                 'sp_val': 0.950153},
                'sp_ind': 0,
  

In [16]:
my_dat2 = load_data_set('ex2.txt')
my_mat2 = mat(my_dat2)
create_tree(my_mat2)

{'left': {'left': {'left': {'left': 105.24862350000001,
    'right': 112.42895575000001,
    'sp_ind': 0,
    'sp_val': 0.958512},
   'right': {'left': {'left': {'left': {'left': 87.310387500000004,
       'right': {'left': {'left': 96.452866999999998,
         'right': {'left': 104.82540899999999,
          'right': {'left': 95.181792999999999,
           'right': 102.25234449999999,
           'sp_ind': 0,
           'sp_val': 0.872883},
          'sp_ind': 0,
          'sp_val': 0.892999},
         'sp_ind': 0,
         'sp_val': 0.910975},
        'right': 95.275843166666661,
        'sp_ind': 0,
        'sp_val': 0.85497},
       'sp_ind': 0,
       'sp_val': 0.944221},
      'right': {'left': 81.110151999999999,
       'right': 88.784498800000009,
       'sp_ind': 0,
       'sp_val': 0.811602},
      'sp_ind': 0,
      'sp_val': 0.833026},
     'right': 102.35780185714285,
     'sp_ind': 0,
     'sp_val': 0.790312},
    'right': 78.085643250000004,
    'sp_ind': 0,
    'sp_val': 

In [17]:
create_tree(my_mat2, ops=(10000,4))

{'left': 101.35815937735848,
 'right': -2.6377193297872341,
 'sp_ind': 0,
 'sp_val': 0.499171}

## 9.4.2 后剪枝

**程序清单9-3** 回归树剪枝函数

In [18]:
def is_tree(obj):
    """Return True when ``obj`` is an internal tree node (a dict).

    Leaves are stored as non-dict values, so any ``dict`` (including
    subclasses) is treated as a subtree.
    """
    return isinstance(obj, dict)

def get_mean(tree):
    """Collapse ``tree`` to a single value.

    Every subtree is replaced in place by its own collapsed value (right
    child first, then left), and the unweighted average of the two child
    values is returned.
    """
    for side in ('right', 'left'):
        child = tree[side]
        if is_tree(child):
            tree[side] = get_mean(child)
    return (tree['left'] + tree['right']) / 2.0

def prune(tree, test_data):
    """Post-prune ``tree`` using held-out ``test_data``.

    Subtrees are pruned first (modifying ``tree`` in place). When both
    children end up as leaves, they are merged into their unweighted mean if
    that lowers the squared error on ``test_data``; ``'merging'`` is printed
    for each merge.

    Args:
        tree: Tree node dict with keys ``'sp_ind'``, ``'sp_val'``,
            ``'left'`` and ``'right'``.
        test_data: Test rows reaching this node; the last column is the
            target.

    Returns:
        The (possibly modified) tree dict, or a single value when the node
        was collapsed.
    """
    # No test rows reach this node: collapse it to its mean.
    if shape(test_data)[0] == 0:
        return get_mean(tree)

    # The split depends only on this node's own feature/value, so one split
    # serves both the recursive calls and the merge test below.
    l_set, r_set = bin_split_data_set(test_data,
                                      tree['sp_ind'],
                                      tree['sp_val'])
    if is_tree(tree['left']):
        tree['left'] = prune(tree['left'], l_set)
    if is_tree(tree['right']):
        tree['right'] = prune(tree['right'], r_set)

    # Merging is only considered once both children are leaves.
    if is_tree(tree['left']) or is_tree(tree['right']):
        return tree

    left_err = sum(power(l_set[:,-1]-tree['left'], 2))
    right_err = sum(power(r_set[:,-1]-tree['right'], 2))
    error_no_merge = left_err + right_err
    tree_mean = (tree['left']+tree['right'])/2.0
    error_merge = sum(power(test_data[:,-1]-tree_mean, 2))
    if error_merge < error_no_merge:
        print('merging')
        return tree_mean
    return tree

In [19]:
my_tree = create_tree(my_mat2, ops=(0,1))

In [20]:
my_dat_test = load_data_set('ex2test.txt')
my_mat2_test = mat(my_dat_test)

In [21]:
prune(my_tree, my_mat2_test)

merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging


{'left': {'left': {'left': {'left': 92.523991499999994,
    'right': {'left': {'left': {'left': 112.386764,
       'right': 123.559747,
       'sp_ind': 0,
       'sp_val': 0.960398},
      'right': 135.83701300000001,
      'sp_ind': 0,
      'sp_val': 0.958512},
     'right': 111.2013225,
     'sp_ind': 0,
     'sp_val': 0.956951},
    'sp_ind': 0,
    'sp_val': 0.965969},
   'right': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 96.41885225,
              'right': 69.318648999999994,
              'sp_ind': 0,
              'sp_val': 0.948822},
             'right': {'left': {'left': 110.03503850000001,
               'right': {'left': 65.548417999999998,
                'right': {'left': 115.75399400000001,
                 'right': {'left': {'left': 94.396114499999996,
                   'right': 85.005351000000005,
                   'sp_ind': 0,
                   'sp_val': 0.912161},
                  'right': {'left': {'left

# 9.5 模型树

**程序清单9-4** 模型树的叶节点生成函数

In [22]:
def linear_solve(data_set):
    """Fit ordinary least squares to ``data_set`` with an intercept term.

    Args:
        data_set: 2-D matrix or array-like; all columns but the last are
            features and the last column is the target.

    Returns:
        ``(ws, X, Y)`` where ``X`` is the design matrix (a column of ones
        followed by the feature columns), ``Y`` is the target column and
        ``ws`` is the solution of the normal equations
        ``(X.T * X).I * (X.T * Y)``.

    Raises:
        NameError: If ``X.T * X`` has a determinant of exactly zero.
    """
    # asmatrix replaces the `mat` alias, which NumPy 2.0 removed; it returns
    # a matrix input unchanged and also admits plain arrays / nested lists.
    data_set = asmatrix(data_set)
    m,n = shape(data_set)
    X = asmatrix(ones((m,n)))
    # Column 0 stays all ones (intercept); the features fill the rest.
    X[:,1:n] = data_set[:,0:n-1]
    Y = data_set[:,-1]
    xTx = X.T * X
    if linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse,\n\
        try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y

In [23]:
def model_leaf(data_set):
    """Return the leaf of a model tree: the weights from ``linear_solve``."""
    weights = linear_solve(data_set)[0]
    return weights

In [24]:
def model_err(data_set):
    """Return the summed squared residual of the linear fit to ``data_set``."""
    weights, design, targets = linear_solve(data_set)
    residuals = targets - design * weights
    return sum(power(residuals, 2))

In [25]:
my_mat2 = mat(load_data_set('exp2.txt'))

In [26]:
create_tree(my_mat2, model_leaf, model_err, (1,10))

{'left': matrix([[  1.69855694e-03],
         [  1.19647739e+01]]), 'right': matrix([[ 3.46877936],
         [ 1.18521743]]), 'sp_ind': 0, 'sp_val': 0.285477}

# 9.6 示例：树回归与标准回归的比较

**程序清单9-5** 用树回归进行预测的代码

In [27]:
def reg_tree_eval(model, in_dat):
    """Evaluate a regression-tree leaf: the leaf value itself as a float.

    ``in_dat`` is accepted so the signature matches the other leaf
    evaluators, but it is not used.
    """
    leaf_value = float(model)
    return leaf_value

In [28]:
def model_tree_eval(model, in_dat):
    """Evaluate a model-tree leaf (linear model) on one sample.

    Args:
        model: Column of weights; the first entry is the intercept.
        in_dat: 2-D row of feature values (its second dimension is the
            number of features).

    Returns:
        The prediction ``[1, features...] * model`` as a Python float.
    """
    n = shape(in_dat)[1]
    # asmatrix replaces the `mat` alias, which NumPy 2.0 removed.
    X = asmatrix(ones((1,n+1)))
    # Leading 1 pairs with the intercept weight.
    X[:,1:n+1] = in_dat
    # Pull out the single element explicitly; calling float() directly on a
    # 1x1 matrix relies on a deprecated array-to-scalar conversion.
    return float((X * model)[0, 0])

In [29]:
def tree_fore_cast(tree, in_data, model_eval=reg_tree_eval):
    """Predict the target for one sample by walking ``tree``.

    Args:
        tree: A tree node dict (keys ``'sp_ind'``, ``'sp_val'``, ``'left'``,
            ``'right'``) or a leaf model.
        in_data: Feature values of a single sample (row matrix, column
            matrix, 1-D array or list).
        model_eval: Callable ``(leaf_model, in_data)`` producing the
            prediction at a leaf.

    Returns:
        Whatever ``model_eval`` returns for the leaf the sample reaches.
    """
    if not is_tree(tree):
        return model_eval(tree, in_data)
    # Flatten so that sp_ind selects a feature. Indexing a 1xN matrix with
    # a single integer selects a row, which only works when there is
    # exactly one feature.
    features = asarray(in_data).ravel()
    branch = 'left' if features[tree['sp_ind']] > tree['sp_val'] else 'right'
    # The recursive call handles both subtrees and leaves.
    return tree_fore_cast(tree[branch], in_data, model_eval)

In [30]:
def create_fore_cast(tree, test_data, model_eval=reg_tree_eval):
    """Predict the target for every sample in ``test_data``.

    Args:
        tree: Tree (or leaf) passed to ``tree_fore_cast``.
        test_data: Sequence of samples; each item is wrapped in a matrix
            before being handed to ``tree_fore_cast``.
        model_eval: Leaf evaluator forwarded to ``tree_fore_cast``.

    Returns:
        An ``(m, 1)`` matrix of predictions, one row per sample.
    """
    m = len(test_data)
    # asmatrix replaces the `mat` alias, which NumPy 2.0 removed.
    y_hat = asmatrix(zeros((m,1)))
    for i in range(m):
        y_hat[i,0] = tree_fore_cast(tree, asmatrix(test_data[i]), model_eval)
    return y_hat

In [31]:
train_mat = mat(load_data_set('bikeSpeedVsIq_train.txt'))
test_mat = mat(load_data_set('bikeSpeedVsIq_test.txt'))

In [32]:
my_tree = create_tree(train_mat, ops=(1,20))
y_hat = create_fore_cast(my_tree, test_mat[:,0])
corrcoef(y_hat, test_mat[:,1], rowvar=0)[0,1]

0.96408523182221506

In [33]:
my_tree = create_tree(train_mat, model_leaf, model_err, (1,20))
y_hat = create_fore_cast(my_tree, test_mat[:,0], model_tree_eval)
corrcoef(y_hat, test_mat[:,1], rowvar=0)[0,1]

0.97604121913806285

In [34]:
ws,X,Y = linear_solve(train_mat)
ws

matrix([[ 37.58916794],
        [  6.18978355]])

In [35]:
for i in range(shape(test_mat)[0]):
    y_hat[i] = test_mat[i,0] * ws[1,0] + ws[0,0]

In [36]:
corrcoef(y_hat, test_mat[:,1], rowvar=0)[0,1]

0.94346842356747662

# 9.7 使用Python的Tkinter库创建GUI

## 9.7.1 用Tkinter创建GUI

In [37]:
from tkinter import *

In [38]:
root = Tk()

In [39]:
my_label = Label(root, text='Hello World')
my_label.grid()

In [40]:
root.mainloop()

**程序清单9-6** 用于构建树管理器界面的Tkinter小部件

In [41]:
def re_draw(tolS, tolN):
    """Placeholder for the plot refresh; accepts the arguments and does nothing."""
    return None

def draw_new_tree():
    """Placeholder for the ReDraw button callback; does nothing."""
    return None

# Build the tree-explorer window; a label stands in for the plot on row 0.
root = Tk()

Label(root, text='Plot Place Holder').grid(row=0, columnspan=3)

# Entry for tolN, pre-filled with 10.
Label(root, text='tolN').grid(row=1, column=0)
tolNentry = Entry(root)
tolNentry.grid(row=1, column=1)
tolNentry.insert(0,'10')

# Entry for tolS, pre-filled with 1.0.
Label(root, text='tolS').grid(row=2, column=0)
tolSentry = Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0,'1.0')

Button(root, text='ReDraw', command=draw_new_tree).grid(row=1, column=2, rowspan=3)
# Checkbox state is read through chk_btn_var.
chk_btn_var = IntVar()
chk_btn = Checkbutton(root, text='Model Tree', variable=chk_btn_var)
chk_btn.grid(row=3, column=0, columnspan=2)

# Data is stored as attributes on re_draw so the callback can reach it.
# asmatrix replaces the `mat` alias, which NumPy 2.0 removed.
re_draw.raw_dat = asmatrix(load_data_set('sine.txt'))
# The .min()/.max() methods return plain scalars for arange, instead of the
# 1x1 matrices obtained by reducing over the matrix rows.
re_draw.test_dat = arange(re_draw.raw_dat[:,0].min(),
                          re_draw.raw_dat[:,0].max(), 0.01)
re_draw(1.0, 10)

root.mainloop()

## 9.7.2 集成Matplotlib和Tkinter

**程序清单9-7** Matplotlib和Tkinter的代码集成

In [42]:
import matplotlib
matplotlib.use('TkAgg')
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure

In [43]:
# Build the tree-explorer window with a Matplotlib figure embedded on row 0.
root = Tk()

# The figure and its Tk canvas are stored on re_draw for the redraw callback.
re_draw.f = Figure(figsize=(5,4), dpi=100)
re_draw.canvas = FigureCanvasTkAgg(re_draw.f, master=root)
# draw() renders the figure; FigureCanvasTkAgg.show() was removed from
# Matplotlib.
re_draw.canvas.draw()
re_draw.canvas.get_tk_widget().grid(row=0, columnspan=3)

# Entry for tolN, pre-filled with 10.
Label(root, text='tolN').grid(row=1, column=0)
tolNentry = Entry(root)
tolNentry.grid(row=1, column=1)
tolNentry.insert(0,'10')

# Entry for tolS, pre-filled with 1.0.
Label(root, text='tolS').grid(row=2, column=0)
tolSentry = Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0,'1.0')

Button(root, text='ReDraw', command=draw_new_tree).grid(row=1, column=2, rowspan=3)
# Checkbox state is read through chk_btn_var.
chk_btn_var = IntVar()
chk_btn = Checkbutton(root, text='Model Tree', variable=chk_btn_var)
chk_btn.grid(row=3, column=0, columnspan=2)

# asmatrix replaces the `mat` alias, which NumPy 2.0 removed.
re_draw.raw_dat = asmatrix(load_data_set('sine.txt'))
# The .min()/.max() methods return plain scalars for arange, instead of the
# 1x1 matrices obtained by reducing over the matrix rows.
re_draw.test_dat = arange(re_draw.raw_dat[:,0].min(),
                          re_draw.raw_dat[:,0].max(), 0.01)
re_draw(1.0, 10)

root.mainloop()

In [76]:
def re_draw(tolS, tolN):
    """Rebuild the tree with the given parameters and redraw the plot.

    Reads ``re_draw.f``, ``re_draw.canvas``, ``re_draw.raw_dat`` and
    ``re_draw.test_dat`` (set elsewhere as attributes of this function) and
    the ``chk_btn_var`` checkbox variable.

    Args:
        tolS: First stopping parameter passed to ``create_tree`` via ``ops``.
        tolN: Second stopping parameter passed to ``create_tree`` via
            ``ops``; raised to 2 when a model tree is requested.
    """
    re_draw.f.clf()
    re_draw.a = re_draw.f.add_subplot(111)
    if chk_btn_var.get():
        # Checkbox ticked: build a model tree (linear models at the leaves).
        if tolN < 2:
            tolN = 2
        my_tree = create_tree(re_draw.raw_dat, model_leaf,
                             model_err, (tolS, tolN))
        y_hat = create_fore_cast(my_tree, re_draw.test_dat,
                                model_tree_eval)
    else:
        my_tree = create_tree(re_draw.raw_dat, ops=(tolS, tolN))
        y_hat = create_fore_cast(my_tree, re_draw.test_dat)
    # Flatten the columns to plain float lists; calling float() on each 1x1
    # matrix row relies on a deprecated array-to-scalar conversion.
    x_points = asarray(re_draw.raw_dat[:,0]).ravel().tolist()
    y_points = asarray(re_draw.raw_dat[:,1]).ravel().tolist()
    re_draw.a.scatter(x_points, y_points, s=5)
    re_draw.a.plot(re_draw.test_dat, asarray(y_hat).ravel(), linewidth=2.0)
    # draw() renders the figure; FigureCanvasTkAgg.show() was removed from
    # Matplotlib.
    re_draw.canvas.draw()

In [77]:
def get_inputs():
    """Read tolN and tolS from their entry widgets.

    When a field does not parse, a message is printed, the default is used
    (10 for tolN, 1.0 for tolS) and the entry text is reset to that default.

    Returns:
        The pair ``(tolN, tolS)`` - note the order.
    """
    try:
        tolN = int(tolNentry.get())
    except ValueError:
        # Only a parse failure triggers the fallback; other errors propagate.
        tolN = 10
        print('Enter Integer for tolN')
        tolNentry.delete(0, END)
        tolNentry.insert(0, '10')
    try:
        tolS = float(tolSentry.get())
    except ValueError:
        tolS = 1.0
        print('Enter Float for tolS')
        tolSentry.delete(0, END)
        tolSentry.insert(0, '1.0')
    return tolN, tolS

In [78]:
def draw_new_tree():
    """ReDraw button callback: read the entry fields and refresh the plot."""
    # get_inputs returns (tolN, tolS); re_draw expects (tolS, tolN).
    min_rows, min_gain = get_inputs()
    re_draw(min_gain, min_rows)

In [79]:
re_draw.raw_dat = mat(load_data_set('sine.txt'))
list(map(float, re_draw.raw_dat[:,0]))

[0.19035,
 0.306657,
 0.017568,
 0.122328,
 0.076274,
 0.614127,
 0.220722,
 0.08943,
 0.278817,
 0.520287,
 0.726976,
 0.180485,
 0.801524,
 0.474273,
 0.345116,
 0.981951,
 0.127349,
 0.75712,
 0.345419,
 0.314532,
 0.250828,
 0.431255,
 0.386669,
 0.143794,
 0.470839,
 0.093065,
 0.205377,
 0.083329,
 0.243475,
 0.062389,
 0.764116,
 0.018287,
 0.973603,
 0.458826,
 0.5112,
 0.712587,
 0.464745,
 0.984328,
 0.414291,
 0.799551,
 0.499037,
 0.966757,
 0.756594,
 0.444938,
 0.410167,
 0.532335,
 0.343909,
 0.854302,
 0.846882,
 0.740758,
 0.150668,
 0.177606,
 0.445289,
 0.734653,
 0.559488,
 0.232311,
 0.934435,
 0.219089,
 0.636525,
 0.307605,
 0.713198,
 0.116343,
 0.680737,
 0.48473,
 0.929408,
 0.008507,
 0.872161,
 0.75553,
 0.620671,
 0.47226,
 0.257488,
 0.130654,
 0.512333,
 0.74771,
 0.669948,
 0.644856,
 0.894206,
 0.820471,
 0.790796,
 0.010729,
 0.846777,
 0.349175,
 0.453662,
 0.624017,
 0.211074,
 0.062555,
 0.739709,
 0.985896,
 0.782088,
 0.642561,
 0.779007,
 0.18563

In [80]:
# Build the tree-explorer window with a Matplotlib figure embedded on row 0.
root = Tk()

# The figure and its Tk canvas are stored on re_draw for the redraw callback.
re_draw.f = Figure(figsize=(5,4), dpi=100)
re_draw.canvas = FigureCanvasTkAgg(re_draw.f, master=root)
# draw() renders the figure; FigureCanvasTkAgg.show() was removed from
# Matplotlib.
re_draw.canvas.draw()
re_draw.canvas.get_tk_widget().grid(row=0, columnspan=3)

# Entry for tolN, pre-filled with 10.
Label(root, text='tolN').grid(row=1, column=0)
tolNentry = Entry(root)
tolNentry.grid(row=1, column=1)
tolNentry.insert(0,'10')

# Entry for tolS, pre-filled with 1.0.
Label(root, text='tolS').grid(row=2, column=0)
tolSentry = Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0,'1.0')

Button(root, text='ReDraw', command=draw_new_tree).grid(row=1, column=2, rowspan=3)
# Checkbox state is read through chk_btn_var.
chk_btn_var = IntVar()
chk_btn = Checkbutton(root, text='Model Tree', variable=chk_btn_var)
chk_btn.grid(row=3, column=0, columnspan=2)

# asmatrix replaces the `mat` alias, which NumPy 2.0 removed.
re_draw.raw_dat = asmatrix(load_data_set('sine.txt'))
# The .min()/.max() methods return plain scalars for arange, instead of the
# 1x1 matrices obtained by reducing over the matrix rows.
re_draw.test_dat = arange(re_draw.raw_dat[:,0].min(),
                          re_draw.raw_dat[:,0].max(), 0.01)
re_draw(1.0, 10)

root.mainloop()

# 9.8 本章小结