## Classification Trees

In a classification tree we predict that each observation belongs to the most commonly occurring class of training observations in the region to which it belongs.

The task of growing (inducing) a classification tree is typically recursive: we load the data into a root node and successively split that node into two children, then split each of those children into two children, and so on.
This process is referred to as _recursive binary splitting_.
A natural criterion for choosing each split is the classification error rate of the resulting nodes.
Since we plan to assign an observation in a given region to the most commonly occurring class of training observations in that region, the classification error rate is simply the fraction of the training observations in that region that do not belong to the most common class. 
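
To make the idea concrete, here is a minimal sketch of a single step of recursive binary splitting on one numeric predictor, scored by the classification error rate described above. The helper names (`node_error`, `best_split`) and the toy data are hypothetical and exist only for illustration; this is not how `rpart` is implemented internally.

```r
# Classification error rate of a node:
# the share of observations that are not in the majority class.
node_error <- function(y) {
  if (length(y) == 0) return(0)
  1 - max(table(y)) / length(y)
}

# One step of recursive binary splitting on a single numeric predictor:
# try every midpoint between adjacent distinct values and keep the cut
# with the lowest size-weighted error across the two children.
best_split <- function(x, y) {
  xs <- sort(unique(x))
  cuts <- (head(xs, -1) + tail(xs, -1)) / 2
  errs <- sapply(cuts, function(cut) {
    left  <- y[x < cut]
    right <- y[x >= cut]
    (length(left) * node_error(left) + length(right) * node_error(right)) / length(y)
  })
  list(cut = cuts[which.min(errs)], error = min(errs))
}

# Hypothetical toy data: small x values are mostly "No", large ones mostly "Yes".
x <- c(1, 2, 3, 4, 5, 6, 7, 8)
y <- c("No", "No", "No", "Yes", "Yes", "No", "Yes", "Yes")

result <- best_split(x, y)
print(result)
```

Growing a full tree would repeat this search over every predictor, apply the best split, and then recurse into each child until a stopping rule is met.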

[Dig deeper](https://www.analyticsvidhya.com/blog/2016/04/complete-tutorial-tree-based-modeling-scratch-in-python/)

We can define the classification error as follows:

$$E = 1 - \max_k(\hat{p}_{mk})$$

Here $\hat{p}_{mk}$ represents the proportion of training observations in the $m$th region that are from the $k$th class.
However, it turns out that classification error is not sufficiently sensitive for tree-growing, 
and in practice two other measures are preferable.

##### Gini index

$$G = \sum_{k=1}^K \hat{p}_{mk}(1-\hat{p}_{mk})$$

The Gini index is a measure of total variance across the $K$ classes. 
It is not hard to see that the Gini index takes on a small value if all of the $\hat{p}_{mk}$'s are close to zero or one. 
For this reason the Gini index is referred to as a measure of **node purity**:
a small value indicates that a node contains predominantly observations from a single class.
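
As a quick check of this behaviour, the Gini formula can be evaluated directly on the class labels in a node. This is a minimal sketch; `gini_index` and the two example nodes are hypothetical names and data introduced here for illustration.

```r
# Gini index of a node, computed from the class labels of its observations.
gini_index <- function(y) {
  p <- as.numeric(table(y)) / length(y)  # class proportions p_mk
  sum(p * (1 - p))
}

pure_node  <- c("Yes", "Yes", "Yes", "Yes")  # one class only
mixed_node <- c("Yes", "No", "Yes", "No")    # evenly mixed

gini_pure  <- gini_index(pure_node)   # 0: the node is perfectly pure
gini_mixed <- gini_index(mixed_node)  # 0.5: the maximum for two classes
```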

##### Cross-entropy

Cross-entropy is similar to the Gini index, with $-\log \hat{p}_{mk}$ in place of $(1-\hat{p}_{mk})$.

$$D = -\sum_{k=1}^K \hat{p}_{mk} \log \hat{p}_{mk}$$

Since $0 \le \hat{p}_{mk} \le 1$, it follows that $0 \le -\hat{p}_{mk} \log \hat{p}_{mk}$.
One can show that the cross-entropy will take on a value near zero if the $\hat{p}_{mk}$'s are all near zero or near one.
Therefore, like the Gini index, the cross-entropy will take on a small value if the $m^{th}$ node is pure.
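
The same kind of check works for cross-entropy. In this minimal sketch, `cross_entropy` and the example nodes are hypothetical names and data, and the natural logarithm is assumed (a different base only rescales the values).

```r
# Cross-entropy of a node, computed from the class labels of its observations.
cross_entropy <- function(y) {
  p <- as.numeric(table(y)) / length(y)  # class proportions p_mk
  p <- p[p > 0]                          # treat 0 * log(0) as 0
  -sum(p * log(p))
}

pure_node  <- c("No", "No", "No", "No")
mixed_node <- c("Yes", "No", "Yes", "No")

entropy_pure  <- cross_entropy(pure_node)   # 0 for a pure node
entropy_mixed <- cross_entropy(mixed_node)  # log(2) for an even two-class mix
```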

When building a classification tree, either the Gini index or the cross-entropy is typically used to evaluate the quality of a particular split,
since these two approaches are more sensitive to node purity than is the classification error rate.
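
A small worked example may help show what "more sensitive" means. The sketch below uses a hypothetical parent node with 400 observations in each of two classes and compares two candidate splits. Both splits have the same weighted classification error, but the Gini index and the cross-entropy both favour the split that produces a pure child. All function names and counts here are assumptions made for illustration.

```r
# Impurity measures as functions of a vector of class counts in a node.
class_error   <- function(n) 1 - max(n) / sum(n)
gini          <- function(n) { p <- n / sum(n); sum(p * (1 - p)) }
cross_entropy <- function(n) { p <- n[n > 0] / sum(n); -sum(p * log(p)) }

# Size-weighted average impurity over the children of a split.
split_impurity <- function(children, impurity) {
  sizes <- sapply(children, sum)
  sum(sizes / sum(sizes) * sapply(children, impurity))
}

# Two candidate splits of a parent node with class counts (400, 400).
split_a <- list(c(300, 100), c(100, 300))  # two moderately mixed children
split_b <- list(c(200, 400), c(200, 0))    # one mixed child, one pure child

# Rows are the splits, columns are the impurity measures.
comparison <- sapply(
  list(error = class_error, gini = gini, entropy = cross_entropy),
  function(f) c(A = split_impurity(split_a, f), B = split_impurity(split_b, f))
)
print(round(comparison, 3))
```

The classification error cannot distinguish the two splits, whereas the other two measures reward split B for isolating a pure node.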

### Fitting a Classification Tree

Decision trees can be constructed even in the presence of qualitative predictor variables. 
For instance, in the `Carseats` data below, some of the predictors, such as `ShelveLoc` and `Urban`,
are qualitative.
A split on one of these variables amounts to assigning some of the qualitative values
to one branch and assigning the remaining values to the other branch.
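
To illustrate, a qualitative predictor with $q$ levels admits $2^{q-1}-1$ distinct two-way partitions of its levels. The sketch below enumerates them for the three shelving locations mentioned in this section. The helper `binary_partitions` is hypothetical, and a tree-fitting routine need not search the partitions this way.

```r
# Enumerate the distinct ways to send the levels of a qualitative predictor
# to a left and a right branch, with both branches non-empty.
binary_partitions <- function(levels) {
  q <- length(levels)
  rest <- levels[-1]  # the first level always goes left, so mirror images are not counted twice
  out <- list()
  for (mask in seq_len(2^(q - 1) - 1) - 1) {
    # Read the bits of mask to decide which remaining levels join the left branch.
    to_left <- (mask %/% 2^(seq_along(rest) - 1)) %% 2 == 1
    out[[length(out) + 1]] <- list(left  = c(levels[1], rest[to_left]),
                                   right = rest[!to_left])
  }
  out
}

parts <- binary_partitions(c("Bad", "Good", "Medium"))
for (p in parts) {
  cat("{", paste(p$left, collapse = ", "), "} vs {",
      paste(p$right, collapse = ", "), "}\n")
}
```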

We first use classification trees to analyze the <span style="color:#a5541a">Carseats</span> data set. 
In this data, <span style="color:#a5541a">Sales</span> is a continuous variable, and so we begin by recoding it as a binary variable. 
We use the `ifelse()` function to create a variable, called **High**, 
which takes on a value of **Yes** if the `Sales` variable exceeds 8, 
and takes on a value of **No** otherwise.

In [None]:
library(ISLR)
attach(Carseats)
# Store High as a factor so that it is treated as a class label
High = as.factor(ifelse(Sales <= 8, "No", "Yes"))

Finally, we use the <span style="color:#a5541a">data.frame()</span> function to merge <span style="color:#a5541a">High</span> with the rest of the <span style="color:#a5541a">Carseats</span> data.

In [None]:
Carseats = data.frame(Carseats, High)
str(Carseats)

We now use the `rpart()` (recursive partition) function to fit a classification tree in order to predict **High** using all variables except `Sales`.

In [None]:
library(rpart)
# method="class" fits a classification tree; "anova" would fit a regression tree
rpart_tree <- rpart(High~.-Sales, method="class", data=Carseats)

The <span style="color:#a5541a">summary()</span> function prints the complexity-parameter table, the variable importance, and a detailed description of every node in the tree, including its candidate splits.

In [None]:
summary(rpart_tree)

We use the `plot()` function to display the tree structure,
and the `text()` function to display the node labels.

The argument `pretty=0` instructs R to include the category names for any qualitative predictors, 
rather than simply displaying a letter for each category.

In [None]:
plot(rpart_tree)
text(rpart_tree, pretty=0)

The most important indicator of `Sales` appears to be shelving location, 
since the first branch differentiates `Good` locations from `Bad` and `Medium` locations. 

If we just type the name of the tree object, R prints output corresponding to each branch of the tree. 
R displays the split criterion (e.g. `Price<92.5`), 
the number of observations in that branch, 
the loss (the number of misclassified observations in that branch),
the overall prediction for the branch (Yes or No), 
and the fraction of observations in that branch that take on values of Yes and No. 
Branches that lead to terminal nodes are indicated using asterisks.

In [None]:
rpart_tree

#### Using a decision tree

To evaluate a classification tree properly, we need to estimate the test error rather than the training error.
We split the observations into a training set and a test set,
build the tree using the training set, and evaluate its performance on the test data.
The `predict()` function can be used for this purpose. 
In the case of a classification tree, the argument `type="class"` instructs R to return the actual class prediction.

In [None]:
set.seed (2)
train = sample(1:nrow(Carseats), 200)
Carseats.test = Carseats[-train,]
High.test = High[-train]

In [None]:
rpart_tree <- rpart(High~.-Sales, data=Carseats, subset=train)

yhat = predict(rpart_tree, Carseats.test,type ="class")

table(yhat, High.test)

**Note:** The `table()` command here produces a confusion matrix, a contingency table of predicted versus actual values.
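
The accuracy and the test error rate can be read off such a table, because the correct predictions sit on its diagonal. The following is a minimal sketch with hypothetical labels standing in for `yhat` and `High.test`; the numbers are made up for illustration and are not results from the Carseats data.

```r
# Hypothetical true labels and predictions (stand-ins for High.test and yhat).
truth <- factor(c("No", "No", "Yes", "Yes", "No", "Yes", "No", "Yes"),
                levels = c("No", "Yes"))
pred  <- factor(c("No", "Yes", "Yes", "Yes", "No", "No", "No", "Yes"),
                levels = c("No", "Yes"))

# Rows are predictions, columns are the actual classes.
conf_mat <- table(predicted = pred, actual = truth)
print(conf_mat)

# Correct predictions are on the diagonal of the confusion matrix.
accuracy   <- sum(diag(conf_mat)) / sum(conf_mat)
error_rate <- 1 - accuracy
```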

Next, we consider whether pruning the tree might lead to improved results. 
The `prune()` function in `rpart` cuts a fitted tree back to a given value of the complexity parameter `cp`; it does not perform cross-validation itself.
Cost complexity pruning is used in order to select a sequence of trees for consideration, and cross-validation (or another resampling scheme) picks the best tree in that sequence.
Here we let `train()` from the `caret` package search a grid of `cp` values, with classification accuracy guiding the choice.

In the `tree` package, the analogous step is `cv.tree()` with the argument `FUN=prune.misclass`, which reports the number of terminal nodes of each tree considered (size),
as well as the corresponding error rate and the value of the cost-complexity parameter used
(k, which corresponds to α in (8.4)).
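
With `rpart` itself, the cross-validated error for each tree in the cost-complexity sequence is stored on the fitted object, so a tree can also be pruned without `caret`. The following is a minimal sketch of that workflow. It assumes the `kyphosis` data that ships with `rpart` as a stand-in for the Carseats data, and the settings `cp = 0` and `xval = 10` are illustrative choices rather than recommendations.

```r
library(rpart)

set.seed(1)  # the cross-validation folds inside rpart() are random

# kyphosis ships with rpart; it stands in here for the Carseats data.
# cp = 0 grows a large tree so there is something to prune back.
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
             method = "class", control = rpart.control(cp = 0, xval = 10))

# Each row of the cp table is one tree in the cost-complexity sequence:
# CP, number of splits, training error, cross-validated error and its
# standard error (errors are relative to the root node error).
cp_table <- fit$cptable
print(cp_table)

# Pick the cp value with the smallest cross-validated error.
best_cp <- cp_table[which.min(cp_table[, "xerror"]), "CP"]

# prune() snips the tree back to that complexity; no cross-validation happens here.
pruned <- prune(fit, cp = best_cp)
```

A common alternative to taking the minimum is the one-standard-error rule, which picks the simplest tree whose `xerror` is within one `xstd` of the minimum.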

In [None]:
?prune

The next cell takes a little bit of time to execute!

In [None]:
library(caret)
library(e1071)

set.seed (3)
cpGrid = expand.grid(.cp = seq(0.01,0.5,0.01))

train(High~.-Sales, data = Carseats, method = "rpart", tuneGrid = cpGrid)

In [None]:
rpart_tree <- rpart(High~.-Sales, data=Carseats, method="class", cp = 0.5, subset=train)

yhat = predict(rpart_tree, newdata = Carseats.test, type = "class")

table(yhat, High.test)

Pruning did not change the accuracy of the model on the test set.

# Save your notebook!