---
title: "Introduction to LDATree"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Introduction to LDATree}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---
```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
library(LDATree)
```
`LDATree` is an R modeling package for fitting classification trees. If you are unfamiliar with classification trees, here is a [tutorial](http://www.sthda.com/english/articles/35-statistical-machine-learning-essentials/141-cart-model-decision-tree-essentials/) on traditional CART and its R implementation, `rpart`.

Compared with similar tree methods, `LDATree` sets itself apart in the following ways:

* It applies the idea of LDA (Linear Discriminant Analysis) when selecting variables, finding splits, and fitting models in terminal nodes.
* It addresses certain limitations of the R implementation of LDA (`MASS::lda`), such as handling missing values, having more features than samples, and constant values within groups.
* By re-implementing LDA using the Generalized Singular Value Decomposition (GSVD), `LDATree` runs quickly, particularly on large datasets.
* The package also includes several visualization tools that provide deeper insight into the data.

# Build the Tree

Currently, `LDATree` offers two methods to construct a tree:

1. The first method uses a direct-stopping rule, halting growth once specific conditions are satisfied.
1. The second method uses pruning: it first grows a larger tree, which is then pruned using cross-validation.
```{r,fig.asp=0.618,out.width = "100%",fig.align = "center", echo=TRUE}
mpg <- as.data.frame(ggplot2::mpg)
datX <- mpg[, -5] # All predictors (every column except column 5, "cyl")
response <- mpg[, 5] # Response: "cyl" (number of cylinders)
# Build a tree using direct-stopping rule
fit <- Treee(datX = datX, response = response, pruneMethod = "pre", verbose = FALSE)
# Build a tree using cross-validation
set.seed(443)
fitCV <- Treee(datX = datX, response = response, pruneMethod = "post", verbose = FALSE)
```
# Plot the Tree

`LDATree` offers two plotting methods:

1. Use `plot` directly to view the full tree diagram.
1. To view the plot for an individual node, supply the (training) data and specify the node index.

## Overall Plot
```{r,fig.asp=0.618,out.width = "100%",fig.align = "center", echo=TRUE}
# View the overall tree
# plot(fit) # Tip: try clicking on the nodes...
```
## Individual Plots
```{r,fig.asp=0.618,out.width = "100%",fig.align = "center", echo=TRUE}
# Three types of individual plots
# 1. Scatter plot on first two LD scores
plot(fit, datX = datX, response = response, node = 1)
# 2. Density plot on the first LD score
plot(fit, datX = datX, response = response, node = 3)
# 3. A message
plot(fit, datX = datX, response = response, node = 2)
```
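The "LD scores" shown in the individual plots are the data projected onto linear discriminant directions. As a rough illustration of what such scores look like, the sketch below uses `MASS::lda` on the built-in `iris` data as a stand-in. This is an assumption for illustration only: `LDATree` uses its own GSVD-based LDA, and its plots are drawn differently.

```r
library(MASS)  # stand-in for illustration; not LDATree's internal implementation

# Fit LDA on iris (3 classes, so at most 2 discriminant directions)
ldaFit <- lda(Species ~ ., data = iris)

# LD scores: each observation projected onto the discriminant directions
scores <- predict(ldaFit, iris)$x
dim(scores)

# Scatter plot on the first two LD scores, colored by class
plot(scores[, 1], scores[, 2], col = as.integer(iris$Species),
     xlab = "LD1", ylab = "LD2")

# Density of the first LD score, one curve per class
dens <- lapply(split(scores[, 1], iris$Species), density)
plot(dens[[1]], xlim = range(scores[, 1]), main = "LD1 density by class")
for (k in 2:3) lines(dens[[k]], col = k)
```

With only two classes there is a single discriminant direction, which is one situation where a density plot of the first LD score is the natural view.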
# Make Predictions
```{r,fig.asp=0.618,out.width = "100%",fig.align = "center", echo=TRUE}
# Prediction only
predictions <- predict(fit, datX)
head(predictions)
```
```{r,fig.asp=0.618,out.width = "100%",fig.align = "center", echo=TRUE}
# A more informative prediction
predictions <- predict(fit, datX, type = "all")
head(predictions)
```
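A common way to check predictions against the known response is a confusion table. The sketch below uses small hypothetical vectors `truth` and `pred` so that it runs on its own; in this vignette they would correspond to `response` and `predict(fit, datX)`, assuming the latter returns a vector of class labels.

```r
# Hypothetical stand-ins for `response` and `predict(fit, datX)`
truth <- factor(c(4, 4, 6, 6, 8, 8, 4, 6))
pred  <- factor(c(4, 4, 6, 8, 8, 8, 4, 4), levels = levels(truth))

# Rows: predicted class, columns: true class
confusion <- table(predicted = pred, actual = truth)
confusion

# Overall accuracy: share of observations on the diagonal
accuracy <- sum(diag(confusion)) / sum(confusion)
accuracy
```

Off-diagonal cells show which classes are confused with each other, which a single accuracy number hides.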
# Missing Values
You do not need to specify anything to handle missing values (unless you want to); `LDATree` handles them automatically. By default, missing values in numerical variables are imputed with the variable's median, and a missing-value flag is added. Missing values in factor variables are assigned to a new level. For more options, please refer to `help(Treee)`.
```{r,fig.asp=0.618,out.width = "100%",fig.align = "center", echo=TRUE}
# Randomly set 20 entries in each predictor to NA
datXmissing <- datX
for (i in seq_len(ncol(datXmissing))) datXmissing[sample(nrow(datXmissing), 20), i] <- NA
fitMissing <- Treee(datX = datXmissing, response = response, pruneMethod = "post", verbose = FALSE)
plot(fitMissing, datX = datXmissing, response = response, node = 1)
```
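The default strategy described above (median imputation plus a missing flag for numerical variables, a new level for factors) can be sketched in base R on a toy data frame. This is only an illustration of the idea, not `LDATree`'s internal code; the flag column name `x_missing` and the level label `"missing"` are hypothetical.

```r
# Toy data with one numeric and one factor column, both containing NA
toy <- data.frame(
  x = c(1, NA, 3, 10),
  g = factor(c("a", NA, "b", "a"))
)

# Numeric: record where values were missing, then fill NA with the median
toy$x_missing <- as.integer(is.na(toy$x))  # hypothetical flag column name
toy$x[is.na(toy$x)] <- median(toy$x, na.rm = TRUE)

# Factor: give NA its own level
toy$g <- addNA(toy$g)
levels(toy$g)[is.na(levels(toy$g))] <- "missing"  # hypothetical level label

toy
```

Keeping the flag column lets a model use the fact that a value was missing, which would be lost after imputation alone.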
# LDA/GSVD
Because we re-implement LDA using the GSVD and use it in model fitting, the `ldaGSVD` function is available as a by-product. Feel free to experiment with it and see how it compares with `MASS::lda`.
```{r,fig.asp=0.618,out.width = "100%",fig.align = "center", echo=TRUE}
fitLDAgsvd <- ldaGSVD(datX = datX, response = response)
predictionsLDAgsvd <- predict(fitLDAgsvd, newdata = datX)
mean(predictionsLDAgsvd == response) # Training accuracy
```
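For a point of comparison, a `MASS::lda` baseline can be computed in the same way. The sketch below uses the built-in, all-numeric `iris` data rather than `mpg`; this is an assumption made to keep the example small and self-contained, and it is not part of `LDATree`.

```r
library(MASS)  # baseline LDA implementation to compare against

# Fit the baseline on iris (used here only as a small all-numeric example)
baseline <- lda(Species ~ ., data = iris)
baselinePred <- predict(baseline, iris)$class

# Training accuracy of the baseline
baselineAcc <- mean(baselinePred == iris$Species)
baselineAcc
```

Assuming the call signature shown above, the matching `LDATree` fit would be `ldaGSVD(datX = iris[, -5], response = iris[, 5])`, and its training accuracy can be computed with the same `mean(... == ...)` comparison.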