In [1]:
from collections import Counter
from math import log

import numpy as np
"""
函数说明：创建测试数据集
Parameters：无
Returns：
    dataSet：数据集
    labels：分类属性
Modify：
    2020-10-14
"""

def creatDataSet():
    # 数据集
    dataSet=[[0, 0, 0, 0, 'no'],
            [0, 0, 0, 1, 'no'],
            [0, 1, 0, 1, 'yes'],
            [0, 1, 1, 0, 'yes'],
            [0, 0, 0, 0, 'no'],
            [1, 0, 0, 0, 'no'],
            [1, 0, 0, 1, 'no'],
            [1, 1, 1, 1, 'yes'],
            [1, 0, 1, 2, 'yes'],
            [1, 0, 1, 2, 'yes'],
            [2, 0, 1, 2, 'yes'],
            [2, 0, 1, 1, 'yes'],
            [2, 1, 0, 1, 'yes'],
            [2, 1, 0, 2, 'yes'],
            [2, 0, 0, 0, 'no']]
    #分类属性
    labels=['年龄','有工作','有自己的房子','信贷情况']
    #返回数据集和分类属性
    return dataSet,labels
dataSet,labels = creatDataSet()

In [2]:
"""
函数说明:计算给定数据集的经验上熵（香农熵）
Parameters:
    dataSet:%数据集
Returns:
    shannonEnt:经验熵
Modify:
    2020-10-14
"""

def calcShannonEnt(dataSet):
    feture_num = len(dataSet[0])-1
    shannonEnt = 0
    dic = {}
    for data in dataSet:
        if data[-1] not in dic:
            dic[data[-1]] = 0
        dic[data[-1]] = dic[data[-1]]+1
    
    for key in dic:
        p = dic[key]/len(dataSet)
        shannonEnt -= p*log(p,2)
    return shannonEnt

# Entropy of the class labels over the whole training set (shown as cell output).
calcShannonEnt(dataSet)

0.9709505944546686

In [3]:
"""
函数说明:计算数据集中信息增益最大的特征的索引
Parameters:
    dataSet:数据集
Returns:
    best_feture:信息增益最大的特征值的索引
Modify:
    2020-10-14
"""

def getBestFeatureIndex(dataSet):
    feature_nums = len(dataSet[0])-1
    
    dataSet_shannonEnt = calcShannonEnt(dataSet)
    best_feature_index = -1
    best_feature_shannonEnt = 0.
    
    for i in range(len(dataSet[0])-1):
        del_feature_dataset = (np.delete(np.array(dataSet),i, axis=1)).tolist()
        dic = {}
        shannonEnt = 0;
        for j in range(len(dataSet)):
            if(dataSet[j][i] not in dic):
                dic[dataSet[j][i]] = []
            dic[dataSet[j][i]].append(del_feature_dataset[j])
        for k in dic:
            p = len(dic[k])/len(dataSet)
            shannonEnt += p*calcShannonEnt(dic[k])
        if(dataSet_shannonEnt-shannonEnt > best_feature_shannonEnt):
            best_feature_index = i
            best_feature_shannonEnt = dataSet_shannonEnt-shannonEnt
    return best_feature_index
    
# Index (into the feature columns) of the best first split for the full dataset.
getBestFeatureIndex(dataSet)

2

In [4]:
"""
函数说明:根据第i个特征将数据集进行划分
Parameter:
    dataSet:数据集
    index:第index个特征
Returns:
    res:{"特征值1":data1,"特征值2":data2...}
Modify:
    2020-10-14
"""

def splitDatasetByFeature(dataSet,index):
    res = {}
    for data in dataSet:
        if(data[index] not in res):
            res[data[index]] = []
        res[data[index]].append(data)
   
    for k in res:
        res[k] = (np.delete(np.array(res[k]),index, axis=1)).tolist()
    return res

In [5]:
"""
函数说明:返回dataSet中最大的类别数
Parameter:
    dataSet:训练集
Returns:
    res:("类别":个数)<class 'tuple'>
Modify:
    2020-10-14
"""

def getMaxClassNum(dataSet):
    res = {}
    for data in dataSet:
        if(data[-1] not in res):
            res[data[-1]] = 0
        res[data[-1]] = res[data[-1]] + 1
    return sorted(res.items(),key = lambda x:x[1],reverse=True)[0]

In [6]:
"""
函数说明:统计dataSet中个数最多的类别
Parameters:
    dataSet: 训练集
    lables:还剩下的特征
Returns:
    mytree:决策树
Modify:
    2020-10-14
"""

def createDecisionTree(dataSet,lables):
    if(getMaxClassNum(dataSet)[1] == len(dataSet) or len(labels) == 0):
        return getMaxClassNum(dataSet)[0]
    
    best_feature_index = getBestFeatureIndex(dataSet)
    print(best_feature_index)
    print(labels)
    mytree = {}
    best_feature = labels[best_feature_index]
    del(labels[best_feature_index])
    
    splitdata = splitDatasetByFeature(dataSet,best_feature_index)
    
    for data in splitdata:
        mytree[data] = createDecisionTree(splitdata[data],labels)
    mytree = {best_feature:mytree}
    return mytree

# Build the decision tree from the full training set and display it.
createDecisionTree(dataSet,labels)

2
['年龄', '有工作', '有自己的房子', '信贷情况']
1
['年龄', '有工作', '信贷情况']


{'有自己的房子': {0: {'有工作': {'0': 'no', '1': 'yes'}}, 1: 'yes'}}