## Decision Tree CART

The full name of CART is Classification and Regression Tree, which means that it can be used on both classification and regression. The regression tree is used for modeling tasks where the target variable is continuous, and its feature selection criterion uses the minimum squared error. The classification tree is used for modeling tasks where the target variable is discrete, and the feature selection criterion is the Gini index.

CART is a learning method that outputs a conditional probability distribution of a random variable $Y$ given an input random variable $X$. The CART algorithm divides the input space, that is, the feature space, into finite units by selecting the optimal features and eigenvalues, and determines the predicted probability distribution on these units, that is, outputs the conditional probability distribution under the given input conditions. .

The complete CART algorithm includes three parts: feature selection, decision tree generation and decision tree pruning. Whether it is a regression tree or a classification tree, the core of the algorithm is to recursively select the optimal features to build a decision tree.

In addition to selecting the optimal features to build a decision tree, pruning is another important part of this algorithm. It can be regarded as a regularization method of decision tree algorithm. As a rule-based and non-parametric supervised learning method, decision tree is easy to overfit during training, resulting in low generalization performance of the final generated decision tree.

### Regression Tree

Given the input variables $X$ and output variables $Y$, the generation of a regression tree correspond to the split on the input space and the output values of the divisiory units. Assuming that the input space is divided into $M$ units $R_{1}, R_{2}, \cdots , R_{M}$, and there is a fixed output $c_{m}$ for every unit. THen the model of regression tree model can be expressed as:
$$
f(x)= \sum _{m=1}^{M} c_{m}I(x \in R_{m})
$$

When the input space partition is determined, regression tree algorithm iterates over all features and uses MSE to choose the optimal features and optimal segmentation points. 
$$
\min _{j . s}\left[\min _{c_{1}} \sum_{x_{i} \in R_{1}(j, s)}\left(y_{i}-c_{1}\right)^{2}+\min _{c_{2}} \sum_{x_{i} \in R_{2}(j, s)}\left(y_{i}-c_{2}\right)^{2}\right]
$$

This method is also called Least Squares Regression Tree Algorithm. The larger the tree depth of the regression tree, the higher the model complexity and the better the fit to the data, but the corresponding generalization ability cannot be guaranteed.

### Classification Tree

CART classification tree is quite different from the regression tree but is similar with ID3 and C4.5 decision tree. Unlike ID3 and C4.5, CART classification tree uses Gini index to select features. 

The Gini index is for probability distributions. Assuming that there are $K$ classes in a classification problem, and the probability that the sample belongs to the $k-th$ class is $p_k$, then the Gini index of the probability distribution of the sample is
$$
Gini(p) = \sum _{k=1} ^{K} p_{k}(1-p_{k})
$$

In practical classification calculation, the Gini index for a given data samples $D$ is
$$
Gini(D) = 1 - \sum _{k=1} ^{K} {(\frac{c_{k}}{D})}^{2}
$$

The corresponding conditional Gini index, that is, the Gini index of data set $D$ under the condition of given feature $A$, is calculated as follows
$$
Gini(D,A) = \frac{|D_{1}|}{|D|} Gini(D_{1}) + \frac{|D_{2}|}{|D|} Gini(D_{2})
$$

When constructing the classification tree, the feature with the smallest conditional Gini index is selected as the optimal feature to construct the decision tree.

In [13]:
def gini(nums):
    probs = [nums.count(i)/len(nums) for i in set(nums)]
    gini = sum([p*(1-p) for p in probs])
    return gini

### Pruning

In order to construct a decision tree with better generalization performance, we need to prune the tree. The so-called pruning is the process of simplifying the constructed decision tree. Specifically, it is to cut some subtrees or leaf nodes from the generated tree, and use its root node or parent node as a new leaf node. Generally speaking, there are two ways of pruning, pre-pruning and post-pruning.

Pre-pruning is a method of pruning in the process of tree generation. The key is to calculate whether the current feature division can improve the generalization performance of the decision tree before the nodes in the tree are expanded. If not, then the trees no longer grow. Pre-pruning is relatively straightforward, the algorithm is simple and efficient, and it is suitable for large-scale calculations, but pre-pruning may have a risk of "early stopping", which may lead to under-fitting of the model.

Post-pruning is to wait for the tree to grow completely before pruning from the bottom leaf node. CART pruning is a post-pruning method. Simply speaking, it is to prune the complete tree node by node from the bottom up, and each pruning will form a subtree to the root node, thus forming a subtree sequence. Then perform cross-validation on all subtrees to find out which subtree has the smallest error and which is the optimal subtree.

In [1]:
import numpy as np 
import pandas as pd 
from math import log

df = pd.read_csv('example_data.csv', dtype={'windy': 'str'})  
df

Unnamed: 0,humility,outlook,play,temp,windy
0,high,sunny,no,hot,False
1,high,sunny,no,hot,True
2,high,overcast,yes,hot,False
3,high,rainy,yes,mild,False
4,normal,rainy,yes,cool,False
5,normal,rainy,no,cool,True
6,normal,overcast,yes,cool,True
7,high,sunny,no,mild,False
8,normal,sunny,yes,cool,False
9,normal,rainy,yes,mild,False


In [6]:
# split the data set based on feature and feature value
def split_dataframe(data, col): 
    '''
    input: dataframe, column name.
    output: a dict of splited dataframe.
    '''
    # unique value of column
    unique_values = data[col].unique()
    # empty dict of dataframe
    result_dict = {elem : pd.DataFrame for elem in unique_values}
    # split dataframe based on column value
    for key in result_dict.keys():
        result_dict[key] = data[:][data[col] == key]
    return result_dict

In [7]:
split_dataframe(df, 'temp')

{'hot':    humility   outlook play temp  windy
 0      high     sunny   no  hot  FALSE
 1      high     sunny   no  hot   TRUE
 2      high  overcast  yes  hot  FALSE
 12   normal  overcast  yes  hot  FALSE,
 'mild':    humility   outlook play  temp  windy
 3      high     rainy  yes  mild  FALSE
 7      high     sunny   no  mild  FALSE
 9    normal     rainy  yes  mild  FALSE
 10   normal     sunny  yes  mild   TRUE
 11     high  overcast  yes  mild   TRUE
 13     high     rainy   no  mild   TRUE,
 'cool':   humility   outlook play  temp  windy
 4   normal     rainy  yes  cool  FALSE
 5   normal     rainy   no  cool   TRUE
 6   normal  overcast  yes  cool   TRUE
 8   normal     sunny  yes  cool  FALSE}

In [9]:
# choose the best column based on Gini index
def choose_best_col(df, label): 
    # Calculating label's gini index
    gini_D = gini(df[label].tolist())
    # columns list except label
    cols = [col for col in df.columns if col not in [label]]
    # initialize the max infomation gain, best column and best splited dict
    min_value, best_col, min_splited = 999, None, None
    # split data based on different column 
    for col in cols:
        splited_set = split_dataframe(df, col) 
        gini_DA = 0
        for subset_col, subset in splited_set.items():
            # calculating splited dataframe label's gini index
            gini_Di = gini(subset[label].tolist())
            # calculating gini index of current feature
            gini_DA += len(subset)/len(df) * gini_Di
            if gini_DA < min_value:
                min_value, best_col = gini_DA, col
                min_splited = splited_set
                
    return min_value, best_col, min_splited

In [10]:
class CartTree:
    # define a node class
    class Node:
        def __init__(self, name):
            self.name = name
            self.connections = {}

        def connect(self, label, node):
            self.connections[label] = node

    def __init__(self, data, label):
        self.columns = data.columns
        self.data = data
        self.label = label
        self.root = self.Node("Root")

    # print the tree
    def print_tree(self, node, tabs):
        print(tabs + node.name)
        for connection, child_node in node.connections.items():
            print(tabs + "\t" + "(" + str(connection) + ")") 
            self.print_tree(child_node, tabs + "\t\t")

    def construct_tree(self):
        self.construct(self.root, "", self.data, self.columns)

    # construct tree
    def construct(self, parent_node, parent_connection_label, input_data, columns): 
        min_value, best_col, min_splited = choose_best_col(input_data[columns], self.label)
        if not best_col:
            node = self.Node(input_data[self.label].iloc[0]) 
            parent_node.connect(parent_connection_label, node) 
            return

        node = self.Node(best_col) 
        parent_node.connect(parent_connection_label, node)

        new_columns = [col for col in columns if col != best_col] 
        # Recursively constructing decision trees
        for splited_value, splited_data in min_splited.items():
            self.construct(node, splited_value, splited_data, new_columns)

In [12]:
tree = CartTree(df, 'play')
tree.construct_tree()
tree.print_tree(tree.root, "")

Root
	()
		temp
			(hot)
				outlook
					(sunny)
						humility
							(high)
								windy
									(FALSE)
										no
									(TRUE)
										no
					(overcast)
						humility
							(high)
								windy
									(FALSE)
										yes
							(normal)
								windy
									(FALSE)
										yes
			(mild)
				outlook
					(rainy)
						windy
							(FALSE)
								humility
									(high)
										yes
									(normal)
										yes
							(TRUE)
								humility
									(high)
										no
					(sunny)
						humility
							(high)
								windy
									(FALSE)
										no
							(normal)
								windy
									(TRUE)
										yes
					(overcast)
						humility
							(high)
								windy
									(TRUE)
										yes
			(cool)
				windy
					(FALSE)
						humility
							(normal)
								outlook
									(rainy)
										yes
									(sunny)
										yes
					(TRUE)
						outlook
							(rainy)
								humility
									(normal)
										no
							(overcast)
								humi