三代迭代二叉树

In [1]:
import numpy as np
import pandas as pd
from math import log

In [2]:
df = pd.read_csv('./example_data.csv', dtype={'windy': 'str'})
df

Unnamed: 0,humility,outlook,play,temp,windy
0,high,sunny,no,hot,False
1,high,sunny,no,hot,True
2,high,overcast,yes,hot,False
3,high,rainy,yes,mild,False
4,normal,rainy,yes,cool,False
5,normal,rainy,no,cool,True
6,normal,overcast,yes,cool,True
7,high,sunny,no,mild,False
8,normal,sunny,yes,cool,False
9,normal,rainy,yes,mild,False


In [3]:
#熵函数 ，接受的输入为列表类型
def entropy(ele):
    probs = [ele.count(i)/len(ele) for i in set(ele)]
    entropy = -sum([prob*log(prob, 2) for prob in probs])
    return entropy

In [4]:
entropy(df['play'].tolist())

0.9402859586706309

In [5]:
def split_dataframe(data, col):
    unique_values = data[col].unique()
    result_dict = {elem : pd.DataFrame for elem in unique_values}
    for key in result_dict.keys():
        result_dict[key] = data[:][data[col] == key]
    return result_dict

split_example = split_dataframe(df, 'temp')

In [6]:
for item, value in split_example.items():
    print(item, value)

hot    humility   outlook play temp  windy
0      high     sunny   no  hot  false
1      high     sunny   no  hot   true
2      high  overcast  yes  hot  false
12   normal  overcast  yes  hot  false
mild    humility   outlook play  temp  windy
3      high     rainy  yes  mild  false
7      high     sunny   no  mild  false
9    normal     rainy  yes  mild  false
10   normal     sunny  yes  mild   true
11     high  overcast  yes  mild   true
13     high     rainy   no  mild   true
cool   humility   outlook play  temp  windy
4   normal     rainy  yes  cool  false
5   normal     rainy   no  cool   true
6   normal  overcast  yes  cool   true
8   normal     sunny  yes  cool  false


In [25]:
#最优特征选择 根据训练集和标签选择信息增益最大的特征作为最优特征
def choose_best_col(df, label):
    #训练标签的信息熵
    entropy_D = entropy(df[label].tolist())
    #特征集
    cols = [col for col in df.columns if col not in [label]]
    #初始化最大信息增益值，最优特征和划分后的数据集
    max_value, best_col = -999, None
    max_splited = None
    #历遍特征并奉举特征取值进行划分
    for col in cols:
        #当前特征划分后的数据集
        splited_set = split_dataframe(df, col)
        #初始化经验熵
        entropy_DA = 0
        #对划分后的数据集历遍计算
        for subset_col, subset in splited_set.items():
            #计算划分后的数据子集的标签信息熵
            entropy_Di = entropy(subset[label].tolist())
            #计算当前特征的经验条件熵
            entropy_DA += len(subset)/len(df) * entropy_Di
        #计算当前特征的信息增益
        info_gain = entropy_D - entropy_DA
        #获取最大信息增益，并保存对应的特征和划分结果
        if info_gain > max_value:
            max_value, best_col = info_gain, col
            max_splited = splited_set
    return max_value, best_col, max_splited
    
choose_best_col(df, 'play')

(0.2467498197744391,
 'outlook',
 {'sunny':    humility outlook play  temp  windy
  0      high   sunny   no   hot  false
  1      high   sunny   no   hot   true
  7      high   sunny   no  mild  false
  8    normal   sunny  yes  cool  false
  10   normal   sunny  yes  mild   true,
  'overcast':    humility   outlook play  temp  windy
  2      high  overcast  yes   hot  false
  6    normal  overcast  yes  cool   true
  11     high  overcast  yes  mild   true
  12   normal  overcast  yes   hot  false,
  'rainy':    humility outlook play  temp  windy
  3      high   rainy  yes  mild  false
  4    normal   rainy  yes  cool  false
  5    normal   rainy   no  cool   true
  9    normal   rainy  yes  mild  false
  13     high   rainy   no  mild   true})

In [24]:
df.columns#列标签
entropy(df['outlook'].tolist())

1.5774062828523452

In [26]:
#ID3算法类
class ID3Tree:
    #定义决策树节点类
    class Node:
        #定义树节点
        def __init__(self, name):
            self.name = name
            self.connections = {}

        #定义树连接
        def connect(self, label, node):
            self.connections[label] = node
            
    #定义全局变量，包括数据集，特征集，标签和根结点
    def __init__(self, data, label):
        self.columns = data.columns
        self.data = data
        self.label = label
        self.root = self.Node("Root")
        
    def print_tree(self, node, tabs):
        print(tabs + node.name)
        for connection, child_node in node.connections.items():
            print(tabs + "\t" + "(" + connection + ")")
            self.print_tree(child_node, tabs + "\t\t")

    def construct_tree(self):
        self.construct(self.root, "", self.data, self.columns)
        
    def construct(self, parent_node, parent_connection_label, input_data, columns):
        max_value, best_col, max_splited = choose_best_col(input_data[columns], self.label)
        
        if not best_col:
            node = self.Node(input_data[self.label].iloc[0])
            parent_node.connect(parent_connection_label, node)
            return

        node = self.Node(best_col)
        parent_node.connect(parent_connection_label, node)
        
        new_columns = [col for col in columns if col != best_col]
        
        for splited_value, splited_data in max_splited.items():
            self.construct(node, splited_value, splited_data, new_columns)

In [28]:
tree1 = ID3Tree(df, 'play')
#构造决策树
tree1.construct_tree()

In [29]:
tree1.print_tree(tree1.root, "")

Root
	()
		outlook
			(sunny)
				humility
					(high)
						temp
							(hot)
								windy
									(false)
										no
									(true)
										no
							(mild)
								windy
									(false)
										no
					(normal)
						temp
							(cool)
								windy
									(false)
										yes
							(mild)
								windy
									(true)
										yes
			(overcast)
				humility
					(high)
						temp
							(hot)
								windy
									(false)
										yes
							(mild)
								windy
									(true)
										yes
					(normal)
						temp
							(cool)
								windy
									(true)
										yes
							(hot)
								windy
									(false)
										yes
			(rainy)
				windy
					(false)
						humility
							(high)
								temp
									(mild)
										yes
							(normal)
								temp
									(cool)
										yes
									(mild)
										yes
					(true)
						humility
							(normal)
								temp
									(cool)
										no
							(high)
								temp
									(mild)
										no
