In [None]:
import numpy as np
import pandas as pd
import operator

class ID3:
    """Information-gain decision tree supporting discrete and continuous features.

    Every sample is a list whose last element is the class label and whose
    other elements are feature values. A feature column is treated as
    continuous when the value in its first row is a ``float``; otherwise it is
    treated as discrete.

    The fitted tree is stored as nested dicts of the form
    ``{feature_label: {branch_key: subtree_or_label}}``. Discrete branches are
    keyed by the feature value itself. Continuous branches are keyed by the
    strings ``"<=" + str(threshold)`` and ``">" + str(threshold)``.
    """

    def __init__(self):
        # Root of the fitted tree (nested dict, or a bare label); None until fit().
        self.tree = None

    def fit(self, Dataset, features):
        """
        Train the model and store the resulting tree in ``self.tree``.

        Args:
        - Dataset: training samples, one list per sample, label in the last position
        - features: list of feature names, aligned with the feature columns
        """
        self.tree = self.createTree(Dataset, features)

    def createTree(self, Dataset, features):
        """
        Recursively build the decision tree.

        Args:
        - Dataset: training samples, one list per sample, label in the last position
        - features: feature names aligned with the feature columns of Dataset

        Returns:
        - a nested dict describing the (sub)tree, or a bare label for a leaf
        """
        classList = [example[-1] for example in Dataset]

        # All samples share one label: this is a pure leaf.
        if len(set(classList)) == 1:
            return classList[0]

        # Only the label column is left: no feature to split on.
        if len(Dataset[0]) == 1:
            return self.majorityCnt(classList)

        bestFeatureIndex, bestSplitValue = self.chooseBestFeature(Dataset)

        # No feature yields a positive information gain (for example identical
        # feature vectors carrying different labels). Indexing with -1 would
        # silently split on the label column, so stop with a majority leaf.
        if bestFeatureIndex == -1:
            return self.majorityCnt(classList)

        bestFeatureLabel = features[bestFeatureIndex]
        branches = {}

        # Work on a copy so the caller's feature list is never modified; the
        # chosen feature is removed because the split helpers drop its column.
        subfeatures = list(features)
        del subfeatures[bestFeatureIndex]

        if isinstance(bestSplitValue, float):
            # Continuous feature: binary split around the chosen threshold.
            lowerRows = self.splitDataSetByValue(Dataset, bestFeatureIndex, bestSplitValue, False)
            upperRows = self.splitDataSetByValue(Dataset, bestFeatureIndex, bestSplitValue, True)
            branches['<=' + str(bestSplitValue)] = self.createTree(lowerRows, subfeatures)
            branches['>' + str(bestSplitValue)] = self.createTree(upperRows, subfeatures)
        else:
            # Discrete feature: one branch per distinct value seen in the data.
            for value in set(example[bestFeatureIndex] for example in Dataset):
                subDataset = self.splitDataSet(Dataset, bestFeatureIndex, value)
                branches[value] = self.createTree(subDataset, subfeatures)

        return {bestFeatureLabel: branches}

    def majorityCnt(self, classList):
        """Return the most frequent label; ties go to the label seen first."""
        label_count = self._countLabels(classList)
        # max() returns the first maximal key in insertion order.
        return max(label_count, key=label_count.get)

    def chooseBestFeature(self, Dataset):
        """
        Pick the split with the largest information gain.

        Args:
        - Dataset: training samples, one list per sample, label in the last position

        Returns:
        - (bestFeatureIndex, bestSplitValue). bestFeatureIndex is -1 when no
          split has a strictly positive gain. bestSplitValue is the float
          threshold for a continuous feature, None for a discrete feature, and
          the placeholder 0 when no split was selected.
        """
        featureNum = len(Dataset[0]) - 1
        sampleNum = float(len(Dataset))
        baseEntropy = self.calculateEntropy(Dataset)
        bestInfoGain = 0.0
        bestSplitValue = 0
        bestFeatureIndex = -1

        for i in range(featureNum):
            featureValues = [example[i] for example in Dataset]

            if isinstance(featureValues[0], float):
                # Continuous feature: candidate thresholds are the midpoints of
                # neighbouring sorted values. A candidate equal to the previous
                # one is skipped; it would only repeat the same evaluation.
                sortedFeatureValues = sorted(featureValues)
                splitList = []
                for low, high in zip(sortedFeatureValues, sortedFeatureValues[1:]):
                    midpoint = (low + high) / 2.0
                    if not splitList or midpoint != splitList[-1]:
                        splitList.append(midpoint)

                for splitValue in splitList:
                    subDataset1 = self.splitDataSetByValue(Dataset, i, splitValue, True)
                    subDataset2 = self.splitDataSetByValue(Dataset, i, splitValue, False)
                    prob1 = len(subDataset1) / sampleNum
                    prob2 = len(subDataset2) / sampleNum
                    currentEntropy = (prob1 * self.calculateEntropy(subDataset1)
                                      + prob2 * self.calculateEntropy(subDataset2))
                    infoGain = baseEntropy - currentEntropy
                    if infoGain > bestInfoGain:
                        bestInfoGain = infoGain
                        bestFeatureIndex = i
                        bestSplitValue = splitValue
            else:
                # Discrete feature: weighted entropy over one subset per value.
                currentEntropy = 0.0
                for value in set(featureValues):
                    subDataset = self.splitDataSet(Dataset, i, value)
                    prob = len(subDataset) / sampleNum
                    currentEntropy += prob * self.calculateEntropy(subDataset)
                infoGain = baseEntropy - currentEntropy
                if infoGain > bestInfoGain:
                    bestInfoGain = infoGain
                    bestFeatureIndex = i
                    bestSplitValue = None

        return bestFeatureIndex, bestSplitValue

    def calculateEntropy(self, Dataset):
        """Return the Shannon entropy (base 2) of the labels in Dataset.

        An empty Dataset yields 0.0 because there are no labels to sum over.
        """
        sample_num = len(Dataset)
        label_count = self._countLabels(featVec[-1] for featVec in Dataset)

        entropy = 0.0
        for count in label_count.values():
            prob = float(count) / sample_num
            entropy -= prob * np.log2(prob)

        return entropy

    def splitDataSet(self, Dataset, axis, val):
        '''
        Select the samples whose discrete feature equals a given value.

        Args:
        - Dataset: training samples
        - axis: index of the feature column to test
        - val: feature value to keep

        Returns:
        - the matching samples, each with column `axis` removed
        '''
        return [
            self._dropColumn(featVec, axis)
            for featVec in Dataset
            if featVec[axis] == val
        ]

    def splitDataSetByValue(self, Dataset, axis, val, isAbove):
        '''
        Select the samples on one side of a continuous threshold.

        Args:
        - Dataset: training samples
        - axis: index of the feature column to test
        - val: threshold value
        - isAbove: True keeps samples with value > val, False keeps value <= val

        Returns:
        - the matching samples, each with column `axis` removed
        '''
        if isAbove:
            return [self._dropColumn(featVec, axis) for featVec in Dataset if featVec[axis] > val]
        return [self._dropColumn(featVec, axis) for featVec in Dataset if featVec[axis] <= val]

    def predict(self, inputTree, features, testVec):
        '''
        Classify one sample by walking the tree.

        Args:
        - inputTree: a tree produced by createTree (nested dict or bare label)
        - features: the full feature name list used for training
        - testVec: feature values of the sample, aligned with `features`

        Returns:
        - the predicted label, or the string "Unknown !" when no branch of a
          node matches the sample
        '''
        # A tree can be a single leaf (e.g. all training labels were equal).
        if not isinstance(inputTree, dict):
            return inputTree

        # Each node dict holds exactly one feature label.
        firstStr = next(iter(inputTree))
        secondDict = inputTree[firstStr]
        # Position of this node's feature inside the test vector.
        featureIndex = features.index(firstStr)

        for key, childTree in secondDict.items():
            threshold = self._parseThresholdKey(key)
            if threshold is None:
                # Discrete branch: exact match on the feature value.
                matched = testVec[featureIndex] == key
            elif threshold[0]:
                matched = testVec[featureIndex] > threshold[1]
            else:
                matched = testVec[featureIndex] <= threshold[1]

            if matched:
                # Descend into inner nodes; anything else is a leaf label.
                if isinstance(childTree, dict):
                    return self.predict(childTree, features, testVec)
                return childTree

        return "Unknown !"

    def printTree(self):
        """Print the fitted tree."""
        print(self.tree)

    @staticmethod
    def _countLabels(labels):
        """Count occurrences of each label, keeping first-seen order."""
        label_count = {}
        for label in labels:
            label_count[label] = label_count.get(label, 0) + 1
        return label_count

    @staticmethod
    def _dropColumn(featVec, axis):
        """Return a copy of featVec without the element at index `axis`."""
        return featVec[:axis] + featVec[axis + 1:]

    @staticmethod
    def _parseThresholdKey(key):
        """Decode a continuous branch key into (isAbove, threshold).

        Returns None when the key is not of the form "<=<number>" or
        ">" followed by a number, so discrete values that merely contain
        these characters are still compared as plain values.
        """
        if not isinstance(key, str):
            return None
        if key.startswith('<='):
            isAbove, text = False, key[2:]
        elif key.startswith('>'):
            isAbove, text = True, key[1:]
        else:
            return None
        try:
            return isAbove, float(text)
        except ValueError:
            return None
      


if __name__ == '__main__':
    # Read the watermelon 3.0 data set from disk.
    # NOTE(review): the "ansi" codec alias appears to be Windows-only in
    # CPython - confirm the file's real encoding before running elsewhere.
    raw_table = pd.read_csv("../Data/watermelon3.0.csv", encoding="ansi")
    df = pd.DataFrame(raw_table)
    # Drop the "编号" column (modifies df in place).
    df.drop(columns=["编号"], inplace=True)

    # The model works on plain nested lists, one inner list per sample.
    dataset = df.to_numpy().tolist()

    # Feature names, in the same order as the remaining feature columns.
    features = [
        '色泽', '根蒂', '敲声', '纹理',
        '脐部', '触感', '密度', '含糖率',
    ]

    # Fit the tree and show it.
    ID3_model = ID3()
    ID3_model.fit(dataset, features)
    ID3_model.printTree()

    # Classify one hand-written sample and report the result.
    test_data = ['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 0.6, 0.3]
    result = ID3_model.predict(ID3_model.tree, features, test_data)
    print("预测结果:", result)
   


{'色泽': {'乌黑': {'根蒂': {'稍蜷': {'纹理': {'清晰': '否', '稍糊': '是'}}, '蜷缩': '是'}}, '浅白': '否', '青绿': {'敲声': {'沉闷': '否', '浊响': '是', '清脆': '否'}}}}
预测结果: 是
