In [45]:
from sklearn.datasets import load_iris
import numpy as np
import math

In [237]:
#读取iris数据集,该数据集是根据人口普查数据预测收入是否超过每年50,000美元。
data,target =  load_iris(return_X_y=True)

#将data和target进行水平拼接
target = target.reshape(target.shape[0],1)
dataset = np.hstack((data,target))

features_name = ["sepal length in cm ","sepal width in cm","petal length in cm","petal width in cm"]
features_columns = [0,1,2,3]
labels_name = ["Iris Setosa","Iris Versicolour","Iris Virginica"]

#计算信息熵
def calculate_entropy(data):
    length = len(data) #数据集的长度
    labels = [i for i in set(data[:,-1])] #标签
    #计算每一个标签对应的数据量
    labelCounts = {}
    for d in data:
        label = d[-1]
        if label in labelCounts.keys():
            labelCounts[label] += 1
        else:
            labelCounts[label] = 1
    #计算熵
    ent = 0.0
    for key in labelCounts.keys():
        class_num = labelCounts[key] #类别为key的数量
        prob = class_num/length
        ent -= (prob)*math.log(prob,2)
    
    return ent

#根据某一列的取值来划分数据集
def splitDataSet(data,column,value,isContinuous,direction):
    '''
    column:对应是哪一列，column=0表示第一列
    value：column列选择值等于的value集合作为子数据集
    isContinuous：是否为连续值，如果为True，当direction=-1时，选择左区间及<value的区间，若direction=1，选择右区间
    return：column列中为value值时的子数据集
    '''
    subSet = []
    if isContinuous: #如果为连续数据
        if direction == 1:
            for d in data:
                if d[column] <= value:
                    subSet.append(d.tolist())
                    
        elif direction == -1:
            for d in data:
                if d[column] >= value:
                    subSet.append(d.tolist())
                    
    else:#否则为枚举值
        for d in data:
            if d[column] == value:
                subSet.append(d.tolist())

    #转换成ndarray类型
    subSet = np.array(subSet)
    
    return subSet
    

#从features_columns列中，选择信息增益最大的列所对应的特征
def select_best_features(data,features_columns):
    '''
    features_columns:一个list，表示代表的第features_columns[*]列，从0开始
    '''
    if calNotNegative1(features_columns) == 1:#如果只有1个可选的特征，则不需要选择了
        #查找不为-1所在的列是多少
        for features_column in features_columns:
            if features_column != -1:
                return features_column
    
    best_feature_column = None #最好的特征所在的列
    best_feature_gain = -999999999 #最好的增益率
    #计算全部data的信息熵
    all_ent = calculate_entropy(data)
    #按照特征划分数据集分别进行信息熵
    for feature_column in returnNotNegative1List(features_columns):
        columns_value_num = set(data[:,feature_column]) #feature_column列中值的个数
        if len(columns_value_num)>6:#如果大于6个值，把它当做连续值
            #为了简单起见，这里只使用平均值来划分数据集
            ave = np.average(data[:,feature_column],axis=0)
            #计算信息增益
            left_subData = splitDataSet(data,feature_column,ave,isContinuous=True,direction=-1)
            right_subData = splitDataSet(data,feature_column,ave,isContinuous=True,direction=1)
            left_ent = calculate_entropy(left_subData)
            right_ent = calculate_entropy(right_subData)
            ent = (len(left_subData)/len(data))*left_ent + (len(right_subData)/len(data))*right_ent
            gain = all_ent - ent
            if best_feature_gain<gain:
                best_feature_gain = gain
                best_feature_column = feature_column
    return best_feature_column

#如果把所有的特征都选择完后，数据的标签仍然不唯一，则采用少数服从多数的投票方式来决定最终的标签
def vote(data):
    labels = data[:,-1]#该数据集所有的标签
    labelDict = {}
    for label in labels:
        if label in labelDict.keys():
            labelDict[label] += 1
        else:
            labelDict[label] = 1
            
    max_label = None
    for label in labelDict.keys():
        if max_label is None:
            max_label = label
        elif labelDict[label] > labelDict[max_label]:
            max_label = label
    return max_label

#计算一个list中值不为-1的数量
def calNotNegative1(list):
    count = 0
    for i in list:
        if i != -1:
            count += 1
    return count

#返回list中不为-1的数
def returnNotNegative1List(list):
    l = []
    for i in list:
        if i != -1:
            l.append(i)
    return l
    
    

def create_tree(data,features_columns):
    '''
    data：数据集
    features：数据集中特征的列号，从0开始，不含标签
    '''
    #所有数据的类标签相同，属于同一类
    labelSet = set(data[-1])
    if len(labelSet) == 1:
        return labelSet[0]
    #已经遍历完所有的特征
    if calNotNegative1(features_columns) == 0:
        return vote(data)
    #找出当前最优的分类特征列号
    best_feature_col = select_best_features(data,features_columns)
    #使用字典嵌套字典的方法来存放DT的信息
    myTree={features_name[best_feature_col]:{}}
    
    #复制特征列，防止改变原始的特征数据
    subFeature_columns = features_columns[:]
    subFeature_columns[best_feature_col] = -1 #-1表示删除了一个特征
    #获取最优特征中所在列的值
    bestFeatureValues=[d[best_feature_col] for d in data]
    uniqueValues = set(bestFeatureValues) #取唯一值
    if len(uniqueValues)>6:#视为连续值
        ave = np.average(data[:,best_feature_col],axis=0)#使用均值作为分割点
        for i in [-1,1]:
            myTree[features_name[best_feature_col]][("<=" if i==-1 else ">=")+str(ave)]=create_tree\
            (splitDataSet(data,best_feature_col,ave,isContinuous=True,direction=i),subFeature_columns)
            
    else:#枚举值的情况
        for value in uniqueValues:#遍历枚举值
            myTree[features_name[best_feature_col]][value]=\
            splitDataSet(data,best_feature_col,value),subFeature_columns)
            
    
    return myTree

In [67]:
print(calculate_entropy(dataset))

1.581826669453059


In [101]:
select_best_features(dataset,[1,2,3])

2

In [134]:
vote(np.array([[1],[1],[1],[1],[1],[1],[1],[1],[0],[0],[0],[0],[0],[0],[0],[0],[0],[0],[0],[0],[0],[0],[0]]),[])

0

In [241]:
create_tree(dataset,features_columns)

{'petal length in cm': {'<=3.75866666667': {'petal width in cm': {'<=1.72258064516': {'sepal width in cm': {'<=3.01739130435': {'sepal length in cm ': {'<=6.75555555556': 2.0,
        '>=6.75555555556': 2.0}},
      '>=3.01739130435': {'sepal length in cm ': {'<=6.54285714286': 2.0,
        '>=6.54285714286': 2.0}}}},
    '>=1.72258064516': {'sepal width in cm': {'<=2.79361702128': {'sepal length in cm ': {'<=6.24137931034': 1.0,
        '>=6.24137931034': 1.0}},
      '>=2.79361702128': {'sepal length in cm ': {'<=5.75555555556': 1.0,
        '>=5.75555555556': 1.0}}}}}},
  '>=3.75866666667': {'petal width in cm': {'<=0.343859649123': {'sepal width in cm': {'<=3.15': {'sepal length in cm ': {}},
      '>=3.15': {'sepal length in cm ': {}}}},
    '>=0.343859649123': {'sepal length in cm ': {'<=4.95365853659': {'sepal width in cm': {'<=3.58571428571': 0.0,
        '>=3.58571428571': 0.0}},
      '>=4.95365853659': {'sepal width in cm': {'<=3.115': 0.0,
        '>=3.115': 0.0}}}}}}}}

In [239]:
calNotNegative1([1,2,3,4,-1])

4