In [18]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier

#### 加载数据集

In [40]:
pd.read_csv('./watermelon_3a.csv')

Unnamed: 0,Idx,color,root,knocks,texture,navel,touch,density,sugar_ratio,label
0,1,dark_green,curl_up,little_heavily,distinct,sinking,hard_smooth,0.697,0.46,1
1,2,black,curl_up,heavily,distinct,sinking,hard_smooth,0.774,0.376,1
2,3,black,curl_up,little_heavily,distinct,sinking,hard_smooth,0.634,0.264,1
3,4,dark_green,curl_up,heavily,distinct,sinking,hard_smooth,0.608,0.318,1
4,5,light_white,curl_up,little_heavily,distinct,sinking,hard_smooth,0.556,0.215,1
5,6,dark_green,little_curl_up,little_heavily,distinct,little_sinking,soft_stick,0.403,0.237,1
6,7,black,little_curl_up,little_heavily,little_blur,little_sinking,soft_stick,0.481,0.149,1
7,8,black,little_curl_up,little_heavily,distinct,little_sinking,hard_smooth,0.437,0.211,1
8,9,black,little_curl_up,heavily,little_blur,little_sinking,hard_smooth,0.666,0.091,0
9,10,dark_green,stiff,clear,distinct,even,soft_stick,0.243,0.267,0


In [111]:
def load_watermelon(filename):
    with open(filename) as f:
        lines = f.readlines()
        headers = lines[0].strip().split(',')[:-1]
        dataset = []
        for line in lines[1:]:
            line = line.strip().split(',')
            dataset.append(line)
    return headers, dataset

In [112]:
headers, dataset = load_watermelon('./watermelon_3a.csv')

In [113]:
headers

['Idx',
 'color',
 'root',
 'knocks',
 'texture',
 'navel',
 'touch',
 'density',
 'sugar_ratio']

#### 计算香农熵

In [119]:
import math

def compute_entropy(dataset):
    category_dict = {}
    for sample in dataset:
        category = sample[-1]
        if category not in category_dict:
            category_dict[category] = 0
        category_dict[category] += 1
    num = len(dataset)
    entropy = 0.0
    for i in category_dict:
        prob = float(category_dict[i]) / num
        entropy -= prob * math.log(prob,2)
    return entropy

In [120]:
compute_entropy(dataset)

0.9975025463691153

#### 对特定属性选择特定取值后，将满足该条件的剩余数据集组合留待计算香农商

In [144]:
def split_dataset(dataset, axis, value):
    sub_dataset=[]
    for sample in dataset:
        if sample[axis] == value:
            reducedfeatvec = sample[:axis]
            reducedfeatvec.extend(sample[axis+1:])
            sub_dataset.append(reducedfeatvec)
    return sub_dataset

#### 根据信息增益，选择最佳的划分属性

In [145]:
def choosebestfeaturetosplit(dataset):
    # 计算属性个数
    attr_nums = len(dataset[0])
    
    # 计算整个样本集的香农熵
    baseshannon = compute_entropy(dataset)
    
    bestinfogain = 0.0
    bestfeature = -1
    
    for i in range(attr_nums-1):
        # 取出特定属性的所有值。dataset包含了类，但不影响，因为取不到
        featlist=[example[i] for example in dataset]
        
        # 每个属性所含的值
        unifeat = set(featlist)   
        attrshannon = 0
        
        for value in unifeat:
            subdataset = split_dataset(dataset,i,value)
            shannon = compute_entropy(subdataset)  #每个属性值取每个值的香农商
            prob = len(subdataset)/float(len(dataset))
            attrshannon += prob*shannon
            
        infogain=baseshannon-attrshannon
        
        if infogain > bestinfogain:
            bestinfogain = infogain
            bestfeature = i
            
    return bestfeature

In [146]:
#5、返回样例中类最多的那个类别
def majorclass(data):
    aa = [sample[-1] for sample in data]   #获取每个样例最后的类别
    bb = {}
    for i in aa:
        bb[i]=aa.count(i)
    
    # 将字典bb降序排列，书中用的另一种方式
    bb= sorted(bb.iteritems(), key=lambda d:d[1], reverse = True)
    return bb

#### 生成决策树

In [147]:
def createtree(dataset, attr_names):  #labels为属性标签
    
    #情况 1、当所有样例的类别一致时，返回类别
    samplelabel = [sample[-1] for sample in dataset]
    usamplelabel = list(set(samplelabel))
    
    if len(usamplelabel)==1:
        return usamplelabel[0]

    #情况 2、当属性已经用完，则选择类别最多的显示
    if len(dataset[0]) == 1:
        return majorclass(dataset)

    #情况 3：选择最佳划分属性进行划分
    bestfeature = choosebestfeaturetosplit(dataset)
    bestfeaturelabel = attr_names[bestfeature]
    
    mytree = {bestfeaturelabel: {}}
    
    del attr_names[bestfeature]

    featurevalue = [sample[bestfeature] for sample in dataset]
    ufeaturevalue = set(featurevalue)
    for value in ufeaturevalue:
        sublabels = attr_names[:]
        mytree[bestfeaturelabel][value]=createtree(splitdataset(dataset,bestfeature,value),sublabels)
    return mytree

In [148]:
mytree = createtree(dataset, headers)

In [149]:
import json
print(mytree)
print(json.dumps(mytree, ensure_ascii=False))

{'texture': {'3': '1', '15': '0', '5': '1', '4': '1', '6': '1', '17': '0', '10': '0', '11': '0', '1': '1', '7': '1', '8': '1', '2': '1', '12': '0', '13': '0', '16': '0', '9': '0', '14': '0'}}
{"texture": {"3": "1", "15": "0", "5": "1", "4": "1", "6": "1", "17": "0", "10": "0", "11": "0", "1": "1", "7": "1", "8": "1", "2": "1", "12": "0", "13": "0", "16": "0", "9": "0", "14": "0"}}


In [117]:
headers

['color',
 'root',
 'knocks',
 'texture',
 'navel',
 'touch',
 'density',
 'sugar_ratio']

#### 实例化决策树分类器

In [15]:
dt_clf = DecisionTreeClassifier()

In [17]:
dt_clf.fit(X, y)

DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=None,
            splitter='best')

#### 绘制决策树

#### 决策树设置

In [9]:
DescisionNode = {'box_style': 'sawtooth', 'fc': '0.8'}
LeafNode = {'box_style': 'round4', 'fc': '0.8'}
Arrow = {'arrow_type': '<-'}

In [133]:
# 定义文本框和箭头格式
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


# 获取叶节点的数目
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # 测试节点的数据类型是否为字典
        if type(secondDict[key]).__name__=='dict':
            numLeafs += getNumLeafs(secondDict[key])
        # 如果不是，则为叶节点
        else:
            numLeafs +=1
    return numLeafs

In [135]:
# 树的层数
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
    	# 测试节点的数据类型是否为字典
        if type(secondDict[key]).__name__=='dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:   thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth


# 画节点
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
             xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )


# 在父子节点之间添加文本信息
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

In [137]:
# 画树
def plotTree(myTree, parentPt, nodeTxt):
	# 计算树在x轴的宽
    numLeafs = getNumLeafs(myTree)
    # 计算树在y轴的高
    depth = getTreeDepth(myTree)
    firstStr = myTree.keys()[0]
    # plotTree.xOff和plotTree.yOff追踪已经绘制的节点位置并表示下一个节点的恰当位置
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # 按比例减少全局变量
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
    	# 测试节点的数据类型是否为字典
        if type(secondDict[key]).__name__=='dict':
        	# 递归调用
            plotTree(secondDict[key],cntrPt,str(key))
        # 如果不是，则为叶节点
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

In [138]:
# 创建绘图
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks
    #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses 
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
    plotTree(inTree, (0.5,1.0), '')
    plt.show()


# 输出预先存储的树信息
def retrieveTree(i):
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]