# 机器学习实战——决策树

# In [1]:
# -*- coding: utf-8 -*-
# ***************************************************/ 
# @Time    : 2018/6/5 14:52
# @Author  : GengDaPeng
# @contact : bingshan222@hotmail.com
# @File    : DecisionTree.py
# @Desc    : 《机器学习实战》 决策树构造章节py3代码
# ***************************************************/
import operator
import numpy as np

# In [2]:
def cerate_dataset():
    """Build the small toy dataset used throughout this chapter.

    Returns:
        (dataset, labels) where ``dataset`` is a fresh list of rows, each
        holding two 0/1 feature values followed by a 'yes'/'no' class label,
        and ``labels`` names the two feature columns in order.
    """
    labels = ['no surfacing', 'flippers']
    dataset = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    return dataset, labels

# In [3]:
def calcShanonEnt(dataset):
    """Return the base-2 Shannon entropy of the class labels in ``dataset``.

    The class label of each row is its last element.

    Parameters:
        dataset - sequence of rows (feature values followed by the label)
    Returns:
        -sum(p * log2(p)) over the distinct labels, where p is each label's
        relative frequency; 0.0 when ``dataset`` is empty
    """
    total = len(dataset)
    counts = {}     # label -> number of rows carrying it (insertion-ordered)
    for row in dataset:
        label = row[-1]
        counts[label] = counts.get(label, 0) + 1
    entropy = 0.0
    for occurrences in counts.values():
        p = float(occurrences) / total
        entropy -= p * np.log2(p)
    return entropy

# In [4]:

def split_dataset(dataset, axis, value):
    """Keep the rows whose feature ``axis`` equals ``value``, minus that column.

    Parameters:
        dataset - list of rows (each row a list) to split
        axis - index of the feature column to test and then remove
        value - feature value a row must have to be kept
    Returns:
        a new list of new row lists: each matching row with its ``axis``
        element removed. ``dataset`` and its rows are left unmodified.
    """
    return [row[:axis] + row[axis + 1:]
            for row in dataset
            if row[axis] == value]




# In [5]:
def chooseBestFeatureToSplit(dataset):
    """Pick the feature whose split gives the largest information gain.

    Parameters:
        dataset - non-empty list of rows; every column but the last is a
                  feature, the last column is the class label
    Returns:
        index of the best feature, or -1 if no feature yields a strictly
        positive information gain
    """
    feature_count = len(dataset[0]) - 1
    base_entropy = calcShanonEnt(dataset)   # entropy before any split
    total = float(len(dataset))
    best_gain, best_index = 0.0, -1
    for index in range(feature_count):
        # Weighted entropy of the subsets produced by splitting on `index`.
        split_entropy = 0.0
        for val in set(row[index] for row in dataset):
            subset = split_dataset(dataset, index, val)
            split_entropy += len(subset) / total * calcShanonEnt(subset)
        gain = base_entropy - split_entropy
        if gain > best_gain:
            best_gain, best_index = gain, index
    return best_index

# In [6]:
def majorityCnt(classlist):
    """Return the most frequent class label in ``classlist``.

    When several labels share the highest count, the one that appears
    first in ``classlist`` is returned.

    Parameters:
        classlist - iterable of hashable class labels
    Returns:
        the label with the highest count
    Raises:
        ValueError - if ``classlist`` is empty
    """
    class_count = {}    # label -> votes, in first-seen order
    for vote in classlist:
        class_count[vote] = class_count.get(vote, 0) + 1
    if not class_count:
        raise ValueError("majorityCnt() arg is an empty sequence")
    # The original used dict.iteritems(), which does not exist in Python 3.
    # max() returns the first maximal key it encounters, so ties resolve to
    # the earliest-seen label.
    return max(class_count, key=class_count.get)

# In [9]:
if __name__ == '__main__':
    # Manual smoke test of the helpers above on the toy dataset.
    mydat, labels = cerate_dataset()
    print(mydat)
    print('----------------------')
    shan = calcShanonEnt(mydat)
    print(shan)
    print('------------------------')
    # Rows with feature 0 == 1, then feature 0 == 0, with that column removed.
    ret_dataset1 = split_dataset(mydat, 0, 1)
    print(ret_dataset1)
    ret_dataset2 = split_dataset(mydat, 0, 0)
    print(ret_dataset2)
    print('---------------------------')
    # Index of the feature with the highest information gain.
    print(chooseBestFeatureToSplit(mydat))

# > f:\anaconda3\lib\site-packages\ipython\core\compilerop.py(99)ast_parse()
#      97         Arguments are exactly the same as ast.parse (in the standard library),
#      98         and are passed to the built-in compile function."""
# ---> 99         return compile(source, filename, symbol, self.flags | PyCF_ONLY_AST, 1)
#     100
#     101     def reset_compiler_flags(self):
#
# ipdb>
# ipdb> c
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0,