In [7]:
import matplotlib.pyplot as plt
from math import log
import pandas as pd
import numpy as np
import operator

def createTree(dataset, features):
    '''
    @brief: recursively build a decision tree by information gain (ID3 style),
            using a binary threshold split when the chosen feature is a float
    @param dataset: list of samples; each sample is a list of feature values
                    with the class label as its last element
    @param features: feature names, aligned with the feature columns of dataset
    @return: a bare class label for a leaf, otherwise a nested dict
             {feature_name: {branch_key: subtree_or_label}}; branch keys are
             the feature's values for a discrete split, or the strings
             '<=' + str(threshold) and '>' + str(threshold) for a float split
    '''
    # Class label of every sample
    classList = [example[-1] for example in dataset]
    # Pure node: every sample carries the same label
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the label column is left: fall back to the most frequent label
    if len(dataset[0]) == 1:
        return majorityCnt(classList)

    bestfeatureIndex, bestValue = chooseBestFeatureToSplit(dataset)
    # chooseBestFeatureToSplit reports index -1 when no split has a positive
    # gain. Indexing with -1 would silently pick the last feature name and
    # split on the label column, so stop here with a majority vote instead.
    if bestfeatureIndex == -1:
        return majorityCnt(classList)
    bestFeatLabel = features[bestfeatureIndex]

    # Both split helpers drop the split column from the rows they return, so
    # the matching name must be dropped as well (on a copy, to leave the
    # caller's list untouched). Otherwise names and columns drift apart.
    subfeatures = features.copy()
    del subfeatures[bestfeatureIndex]

    branches = {}
    if isinstance(bestValue, float):
        # Float feature: binary split around the chosen threshold
        for prefix, direction in (('<=', True), ('>', False)):
            subset = splitDataSetByValue(dataset, bestfeatureIndex, bestValue, direction)
            branches[prefix + str(bestValue)] = createTree(subset, subfeatures)
    else:
        # Discrete feature: one branch per distinct value
        for value in set(example[bestfeatureIndex] for example in dataset):
            subset = splitDataSet(dataset, bestfeatureIndex, value)
            branches[value] = createTree(subset, subfeatures)
    return {bestFeatLabel: branches}

def majorityCnt(classList):
    '''
    @brief: return the label that occurs most often in classList
    @param classList: list of class labels
    @return: the most frequent label; on a tie, the tied label that was seen
             first in classList
    '''
    # label -> number of occurrences, in first-seen order
    tally = {}
    for label in classList:
        tally[label] = tally.get(label, 0) + 1
    # Stable descending sort by count keeps first-seen order among ties
    ranked = list(tally.items())
    ranked.sort(key=operator.itemgetter(1), reverse=True)
    return ranked[0][0]

def chooseBestFeatureToSplit(dataset):
    '''
    @brief: find the feature (and, for float features, the threshold) whose
            split yields the largest information gain
    @param dataset: list of samples; the last element of each sample is the label
    @return: (feature_index, threshold); threshold is None when the winner is a
             discrete feature. When no split has a gain above 0 the initial
             values (-1, 0) are returned.
    '''
    featureCount = len(dataset[0]) - 1
    baseEntropy = calcShannonEnt(dataset)
    total = float(len(dataset))
    topGain, topIndex, topThreshold = 0, -1, 0

    for col in range(featureCount):
        column = [row[col] for row in dataset]

        if type(column[0]).__name__ != 'float':
            # Discrete feature: weighted entropy over one subset per value
            remainder = 0
            for category in set(column):
                part = splitDataSet(dataset, col, category)
                remainder += len(part) / total * calcShannonEnt(part)
            gain = baseEntropy - remainder
            if gain > topGain:
                topGain, topIndex, topThreshold = gain, col, None
            continue

        # Float feature: candidate thresholds are the midpoints between
        # neighbouring values of the sorted column
        ordered = sorted(column)
        midpoints = [(low + high) / 2.0 for low, high in zip(ordered, ordered[1:])]
        for threshold in set(midpoints):
            lower = splitDataSetByValue(dataset, col, threshold, True)
            upper = splitDataSetByValue(dataset, col, threshold, False)
            remainder = (len(lower) / total * calcShannonEnt(lower)
                         + len(upper) / total * calcShannonEnt(upper))
            gain = baseEntropy - remainder
            if gain > topGain:
                topGain, topIndex, topThreshold = gain, col, threshold

    return topIndex, topThreshold

def splitDataSet(dataset, axis, val):
    '''
    @brief: select the samples whose value in column axis equals val
    @param dataset: list of samples (each sample is a list)
    @param axis: index of the column to test
    @param val: value the column must equal
    @return: new list of the matching samples, each with column axis removed
    '''
    return [
        row[:axis] + row[axis + 1:]
        for row in dataset
        if row[axis] == val
    ]

def splitDataSetByValue(dataset, axis, val, direction):
    '''
    @brief: select the samples on one side of a threshold in column axis
    @param dataset: list of samples (each sample is a list)
    @param axis: index of the column to compare
    @param val: threshold
    @param direction: truthy keeps samples with value <= val,
                      falsy keeps samples with value > val
    @return: new list of the kept samples, each with column axis removed
    '''
    keep = operator.le if direction else operator.gt
    return [
        row[:axis] + row[axis + 1:]
        for row in dataset
        if keep(row[axis], val)
    ]

def calcShannonEnt(dataset):
    '''
    @brief: compute the base-2 Shannon entropy of the class labels in dataset
    @param dataset: list of samples; the last element of each sample is the label
    @return: the entropy of the label distribution; 0 for an empty dataset
    '''
    # label -> number of samples carrying it
    tally = {}
    for row in dataset:
        tag = row[-1]
        tally[tag] = tally.get(tag, 0) + 1

    size = len(dataset)
    entropy = 0
    for count in tally.values():
        p = float(count) / size
        entropy -= p * log(p, 2)
    return entropy

def _branchMatches(key, value):
    '''
    @brief: decide whether a feature value belongs to the branch labelled key
    @param key: branch key; either a discrete feature value or a threshold
                label of the form '<=<number>' / '><number>'
    @param value: the test sample's value for the feature being examined
    @return: True if value falls into the branch, False otherwise
    '''
    if isinstance(key, str):
        for prefix, compare in (('<=', operator.le), ('>', operator.gt)):
            if key.startswith(prefix):
                try:
                    threshold = float(key[len(prefix):])
                except ValueError:
                    # Not a threshold label after all (a discrete value that
                    # merely starts with the symbol): compare by equality.
                    break
                return compare(value, threshold)
    return value == key


def predict(inputTree, features, testVec):
    '''
    @brief: predict the label of a test vector using a decision tree
    @param inputTree: the decision tree (nested dict) or a bare label
    @param features: feature names, in the same order as the values in testVec
    @param testVec: the feature values of the sample to classify
    @return: the predicted label, or the string "未知类别" when the sample's
             value matches none of the branches at some node
    '''
    # createTree returns a bare label when every training sample shares one
    # class, and subtrees bottom out in bare labels too: a non-dict is a leaf.
    if not isinstance(inputTree, dict):
        return inputTree

    # The single key of a node is the feature it tests
    firstStr = list(inputTree.keys())[0]
    # Its value maps branch keys to subtrees / labels
    secondDict = inputTree[firstStr]
    # Locate the tested feature's value in the sample
    featureValue = testVec[features.index(firstStr)]

    # Branches are tried in dict order; the first matching one is followed
    for key, childTree in secondDict.items():
        if _branchMatches(key, featureValue):
            return predict(childTree, features, testVec)
    # No branch matched
    return "未知类别"

if __name__ == '__main__':
    # Load the dataset. The original encoding name "ansi" is an alias that
    # Python only provides on Windows; elsewhere the lookup fails with
    # LookupError, so fall back to an explicit codec there.
    csvPath = "../Data/watermelon3.0.csv"
    try:
        df = pd.read_csv(csvPath, encoding="ansi")
    except LookupError:
        # NOTE(review): assumes the CSV was saved under a Chinese-locale ANSI
        # code page (cp936 / GBK) — confirm against the actual file.
        df = pd.read_csv(csvPath, encoding="gbk")
    # Drop the sample-id column; it is not a feature
    df = df.drop(columns=["编号"])
    # Rows as plain lists: feature values followed by the class label
    dataset = df.values.tolist()
    # Feature names, in column order
    labels = ['色泽', '根蒂', '敲声', '纹理', '脐部', '触感', '密度', '含糖率']
    # Build and show the tree
    myTree = createTree(dataset, labels)
    print(myTree)
    # Classify one hand-written sample. The result is kept in a variable and
    # deliberately not printed, as in the original script.
    testVec = ['青绿','硬挺','清脆','稍糊','平坦','软粘',0.243,0.267]
    result = predict(myTree, labels, testVec)

{'纹理': {'模糊': '否', '清晰': {'密度': {'<=0.3815': '否', '>0.3815': '是'}}, '稍糊': {'触感': {'软粘': '是', '硬滑': '否'}}}}
