In [15]:
import operator
from numpy import shape

# 切分数据
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    subDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]  # 剔除样本集
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
        else:
            reducedFeatVec = featVec[:axis]  # 剔除样本集
            reducedFeatVec.extend(featVec[axis + 1:])
            subDataSet.append(reducedFeatVec)
    return retDataSet, subDataSet


def calGini(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    gini = 1
    for label in labelCounts.keys():
        prop = float(labelCounts[label]) / numEntries
        gini -= prop * prop
    return gini


# 返回最好的特征以及二分特征值
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    bestGiniGain = 1.0
    bestFeature = -1
    bestValue = None
    # 遍历特征
    for i in range(numFeatures):
        # 得到特征列
        featList = [example[i] for example in dataSet]
        # 从特征列获取该特征的特征值的set集合
        uniqueVals = list(set(featList))
        for value in uniqueVals:
            GiniGain = 0.0
            retDataSet, subDataSet = splitDataSet(dataSet, i, value)
            left_prob = len(retDataSet) / float(len(dataSet))
            GiniGain += left_prob * calGini(retDataSet)
            right_prob = len(subDataSet) / float(len(dataSet))
            GiniGain += right_prob * calGini(subDataSet)
            # 比较是否是最好的结果
            if GiniGain <= bestGiniGain:
                # 记录最好的结果和最好的特征
                bestGiniGain = GiniGain
                bestFeature = i
                bestValue = value
    return bestFeature, bestValue


def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(
        classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


# 创建树的函数代码
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    bestFeature, bestValue = chooseBestFeatureToSplit(dataSet)
    if classList.count(classList[0]) == len(classList):
        # 所有的类别都一样，就不用再划分了
        return classList[0]
    # 如果没有继续可以划分的特征，就多数表决决定分支的类别
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    reTree = {'spInd': bestFeature, 'spVal': bestValue}
    retDataSet, subDataSet = splitDataSet(dataSet, bestFeature, bestValue)
    subLabels = labels[:]
    del(subLabels[bestFeature])
    print("bestFeature:", bestFeature)
    reTree['left'] = createTree(retDataSet, subLabels)
    reTree['right'] = createTree(subDataSet, subLabels)
    return reTree


# 用于测试输入变量是否为一棵树
def isTree(obj):
    return type(obj).__name__ == 'dict'


# 回归树剪枝函数
def prune(tree, testData):
    classList = [example[-1] for example in testData]
    if shape(testData)[0] == 0:
        print("数据集为空")
        return majorityCnt(classList)
    if (isTree(tree['right'])) or (isTree(tree['left'])):
        lSet, rSet = splitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rSet)
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = splitDataSet(testData, tree['spInd'], tree['spVal'])
        print("lSet[:, -1]:", lSet[0])
        print("tree['left']:", tree['left'])
        errorNoMerge = 0
        if lSet[0] != tree['left']:
            errorNoMerge += 1
        if rSet[0] != tree['right']:
            errorNoMerge += 1
        treeMean = majorityCnt(classList)
        errorMerge = 0
        print("testData[:, -1]:", testData)
        for item in testData:
            if item[-1] != treeMean:
                errorMerge += 1
        if errorMerge < errorNoMerge:
            print("merging")
            return treeMean
        else:
            return tree
    else:
        return tree

In [16]:
fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = createTree(lenses, lensesLabels)
print("lensesTree:", lensesTree)

bestFeature: 3
bestFeature: 2
bestFeature: 1
bestFeature: 0
bestFeature: 0
bestFeature: 0
lensesTree: {'spInd': 3, 'spVal': 'normal', 'left': {'spInd': 2, 'spVal': 'yes', 'left': {'spInd': 1, 'spVal': 'myope', 'left': 'hard', 'right': {'spInd': 0, 'spVal': 'young', 'left': 'hard', 'right': 'no lenses'}}, 'right': {'spInd': 0, 'spVal': 'presbyopic', 'left': {'spInd': 0, 'spVal': 'myope', 'left': 'no lenses', 'right': 'soft'}, 'right': 'soft'}}, 'right': 'no lenses'}


In [17]:
fr = open('lenses2.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
cutTree = prune(lensesTree, lenses)
print("cutTree:", cutTree)

lSet[:, -1]: ['hard']
tree['left']: hard
testData[:, -1]: [['young', 'hard'], ['pre', 'no lenses'], ['presbyopic', 'soft']]
lSet[:, -1]: ['soft']
tree['left']: no lenses
testData[:, -1]: [['myope', 'soft'], ['hyper', 'soft']]
merging
lSet[:, -1]: ['myope', 'soft']
tree['left']: soft
testData[:, -1]: [['young', 'myope', 'soft'], ['young', 'hyper', 'soft'], ['pre', 'myope', 'soft'], ['pre', 'hyper', 'soft'], ['presbyopic', 'myope', 'soft'], ['presbyopic', 'hyper', 'soft']]
merging
cutTree: {'spInd': 3, 'spVal': 'normal', 'left': {'spInd': 2, 'spVal': 'yes', 'left': {'spInd': 1, 'spVal': 'myope', 'left': 'hard', 'right': {'spInd': 0, 'spVal': 'young', 'left': 'hard', 'right': 'no lenses'}}, 'right': 'soft'}, 'right': 'no lenses'}
