# CART决策树

In [25]:
import numpy as np
import pandas as pd

In [26]:
df = pd.read_csv('ML04_data.csv')
df

Unnamed: 0,humility,outlook,play,temp,windy
0,high,sunny,no,hot,False
1,high,sunny,no,hot,True
2,high,overcast,yes,hot,False
3,high,rainy,yes,mild,False
4,normal,rainy,yes,cool,False
5,normal,rainy,no,cool,True
6,normal,overcast,yes,cool,True
7,high,sunny,no,mild,False
8,normal,sunny,yes,cool,False
9,normal,rainy,yes,mild,False


In [27]:
print('属性： {0}'.format(df.columns.tolist()))
print('标签： {0}'.format('play'))
print('数量： {0}'.format(len(df)))

属性： ['humility', 'outlook', 'play', 'temp', 'windy']
标签： play
数量： 14


In [31]:
len(set(df['windy'].tolist()))

2

In [76]:
import sys
import pandas as pd
import sklearn
import numpy


class CARTTree:
    class Node:
        def __init__(self, name):
            self.name = name
            self.connections = {}

        def connect(self, label, node):
            self.connections[label] = node

    def __init__(self, data, label):
        self.columns = data.columns
        self.data = data
        self.label = label
        self.root = self.Node('Root')

    def gini(self, nums):
        # [p1, p2]
        probs = [nums.count(i) / len(nums) for i in set(nums)]
        # Dini(D) = 1 - sum(p^2)
        gini = sum([p*(1 - p) for p in probs])
        return gini

    def split_dataframe(self, data, col):
        unique_values = data[col].unique()
        '''
        windy
        result_dict:{
            hot: pd.DataFrame,
            mild: pd.DataFrame,
            cool: pd.DataFrame
        }
        '''
        result_dict = {elem: pd.DataFrame for elem in unique_values}
        for i in result_dict.keys():
            result_dict[i] = data[:][data[col] == i]
        return result_dict

    def choose_best_col(self, df, label):
        # 属性列表，将label排除
        cols = [col for col in df.columns if col not in [label]]
        # 初始化最大信息熵、最佳属性和最佳分割字典
        min_value, best_col = 999, None
        min_splited = None
        for col in cols:
            splited_set = self.split_dataframe(df, col)
            gini_DA = 0
            for subset_col, subset in splited_set.items():
                # Gini(Dv)
                gini_Di = self.gini(subset[label].tolist())
                # Gini_index(D, a) = sum( Dv / D * Gini(Dv))
                gini_DA += len(subset) / len(df) * gini_Di

                if gini_DA < min_value:
                    min_value, best_col = gini_DA, col
                    min_splited = splited_set
        return min_value, best_col, min_splited

    def construct_tree(self):
        self.construct(self.root, "", self.data, self.columns)

    def construct(self, parent_node, parent_connection_label, input_data, columns):
        min_value, best_col, min_splited = self.choose_best_col(input_data[columns], self.label)
        if not best_col:
            node = self.Node(input_data[self.label].iloc[0])
            parent_node.connect(parent_connection_label, node)
            return
        
        node = self.Node(best_col)
        parent_node.connect(parent_connection_label, node)
        new_columns = [col for col in columns if col != best_col]
        
        for splited_value, splited_data in min_splited.items():
            self.construct(node, splited_value, splited_data, new_columns)
    
    def print_tree(self, node, tabs):
        print(tabs + node.name)
        for connection, child_node in node.connections.items():
            print(tabs + "\t" + "(" + str(connection) + ")")
            self.print_tree(child_node, "\t\t"+tabs + "\t\t\t")

In [77]:
tree1 = CARTTree(df, 'play')

In [78]:
tree1.construct_tree()

In [79]:
tree1.print_tree(tree1.root, '')

Root
	()
					temp
						(hot)
										outlook
											(sunny)
															humility
																(high)
																				windy
																					(False)
																									no
																					(True)
																									no
											(overcast)
															humility
																(high)
																				windy
																					(False)
																									yes
																(normal)
																				windy
																					(False)
																									yes
						(mild)
										outlook
											(rainy)
															windy
																(False)
																				humility
																					(high)
																									yes
																					(normal)
																									yes
																(True)
																				humility
																					(high)
																									no
											(sunny)
															humility
																(high)
																				windy
																					(False)
															