In [None]:
# 查看当前挂载的数据集目录, 该目录下的变更重启环境后会自动还原
# View the currently mounted dataset directory. 
# Changes made under this directory are reverted automatically when the environment is restarted. 
!ls /home/aistudio/data

In [0]:
# 查看工作区文件, 该目录下的变更将会持久保存. 请及时清理不必要的文件, 避免加载过慢.
# View personal work directory. 
# All changes under this directory will be kept even after reset. 
# Please clean unnecessary files in time to speed up environment loading. 
!ls /home/aistudio/work

In [None]:
# 如果需要进行持久化安装, 需要使用持久化路径, 如下方代码示例:
# If a persistent installation is required, 
# install into a persistent path, as in the following example: 
!mkdir /home/aistudio/external-libraries
!pip install beautifulsoup4 -t /home/aistudio/external-libraries

In [None]:
# 同时添加如下代码, 这样每次环境(kernel)启动的时候只要运行下方代码即可: 
# Also add the following code; 
# each time the environment (kernel) starts, 
# just run it so the libraries installed above can be imported: 
import sys 
sys.path.append('/home/aistudio/external-libraries')

请点击[此处](https://ai.baidu.com/docs#/AIStudio_Project_Notebook/a38e5576)查看本环境基本用法.  <br>
Please click [here](https://ai.baidu.com/docs#/AIStudio_Project_Notebook/a38e5576) for more detailed instructions. 

In [1]:
import operator
from math import log

In [2]:
def createDataSet():
    """Build the toy data set used by the decision-tree demo.

    Returns:
        tuple: ``(dataSet, labels)``. Each row of ``dataSet`` is a list of
        two feature values followed by a class label (``'yes'`` or ``'no'``);
        ``labels`` names the two feature columns in order.
    """
    samples = (
        (1, 1, 'yes'),
        (1, 1, 'yes'),
        (1, 0, 'no'),
        (0, 1, 'no'),
        (0, 1, 'no'),
    )
    featureNames = ['no surfacing', 'flippers']
    # Rows are materialised as fresh lists on every call.
    return [list(sample) for sample in samples], featureNames

In [3]:
def calcShannonEnt(dataSet):
    """Compute the base-2 Shannon entropy of the class labels in ``dataSet``.

    Args:
        dataSet: Sequence of rows; the last element of each row is taken as
            its class label.

    Returns:
        float: Entropy of the class-label distribution (``0.0`` when
        ``dataSet`` is empty, because no label is ever counted).
    """
    total = len(dataSet)
    tally = {}
    for row in dataSet:
        label = row[-1]
        tally[label] = tally.get(label, 0) + 1
    entropy = 0.0
    # Labels are visited in first-seen order, so the summation order is stable.
    for count in tally.values():
        fraction = float(count) / total
        entropy -= fraction * log(fraction, 2)
    return entropy


def splitDataSet(dataSet, index, value):
    """Select the rows whose column ``index`` equals ``value`` and drop that column.

    Args:
        dataSet: Iterable of list rows.
        index: Position of the feature column to test and remove.
        value: Feature value a row must have to be kept.

    Returns:
        list: New rows (the input rows are left untouched), each without the
        column at ``index``.
    """
    selected = []
    matching = (row for row in dataSet if row[index] == value)
    for row in matching:
        # Columns before the split feature ...
        trimmed = row[:index]
        # ... followed by the columns after it.
        trimmed.extend(row[index + 1:])
        selected.append(trimmed)
    return selected



def chooseBestFeatureToSplit(dataSet):
    """Pick the feature column whose split yields the largest information gain.

    Args:
        dataSet: Non-empty list of rows; every column except the last is a
            feature and the last column is the class label.

    Returns:
        int: Index of the best feature, or ``-1`` when no feature gives a
        strictly positive information gain.
    """
    featureCount = len(dataSet[0]) - 1
    parentEntropy = calcShannonEnt(dataSet)
    topGain = 0.0
    topFeature = -1
    # Try every feature column in turn.
    for column in range(featureCount):
        distinctValues = set(row[column] for row in dataSet)
        weightedEntropy = 0.0
        # Entropy after the split: subset entropies weighted by subset size.
        for candidate in distinctValues:
            subset = splitDataSet(dataSet, column, candidate)
            weight = len(subset) / float(len(dataSet))
            weightedEntropy += weight * calcShannonEnt(subset)
        gain = parentEntropy - weightedEntropy
        if gain > topGain:
            topGain, topFeature = gain, column
    return topFeature


def majorityCnt(classList):
    """Return the class label that occurs most often in ``classList``.

    Args:
        classList: Non-empty sequence of hashable class labels.

    Returns:
        The most frequent label; on a tie, the label that appeared first.
    """
    votes = {}
    for label in classList:
        votes[label] = votes.get(label, 0) + 1
    # The sort is stable, so equally frequent labels keep first-seen order.
    ranking = sorted(votes.items(), key=operator.itemgetter(1), reverse=True)
    winner, _ = ranking[0]
    return winner


def createTree(dataSet, labels):
    """Recursively build an information-gain decision tree as nested dicts.

    Args:
        dataSet: Non-empty list of rows; every column except the last is a
            feature value and the last column is the class label.
        labels: Feature names, one per feature column of ``dataSet``.
            This list is not modified.

    Returns:
        Either a class label (leaf) or a dict of the form
        ``{feature_name: {feature_value: subtree, ...}}``.
    """
    classList = [example[-1] for example in dataSet]
    # Every sample has the same class: this branch is a leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the class column is left (no features remain): majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)

    bestFeat = chooseBestFeatureToSplit(dataSet)
    # -1 means no feature gives a positive information gain. Indexing with it
    # would split on the class column itself, so fall back to a majority vote.
    if bestFeat < 0:
        return majorityCnt(classList)

    bestFeatLabel = labels[bestFeat]
    # Build the reduced label list without touching the caller's list.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    myTree = {bestFeatLabel: {}}
    # Recursively build one subtree per distinct value of the chosen feature.
    for value in set(example[bestFeat] for example in dataSet):
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

In [4]:
def classify(inputTree, featLabels, testVec):
    """Walk a decision tree and return the class predicted for ``testVec``.

    Args:
        inputTree: Nested dict of the form
            ``{feature_name: {feature_value: subtree_or_label}}``.
        featLabels: Feature names; the position of a name gives the index of
            that feature inside ``testVec``.
        testVec: Feature values of the sample to classify.

    Returns:
        The first non-dict value reached while descending the tree.
    """
    node = inputTree
    while True:
        # Each tree level is keyed by the feature it tests.
        featureName = list(node.keys())[0]
        branches = node[featureName]
        outcome = branches[testVec[featLabels.index(featureName)]]
        if not isinstance(outcome, dict):
            return outcome
        # A dict means another decision node: keep descending.
        node = outcome


def fishTest():
    """Train a tree on the toy data set and print the class predicted for ``[1, 1]``."""
    samples, featureNames = createDataSet()
    # Hand the builder its own copy so featureNames stays intact for classify().
    tree = createTree(samples, list(featureNames))
    prediction = classify(tree, featureNames, [1, 1])
    print(prediction)

In [5]:
# Run the demo when this code is executed as the main module.
if __name__ == "__main__":
    fishTest()

yes
