LDATree is an R modeling package for fitting classification trees. If you are unfamiliar with classification trees, here is a tutorial about the traditional CART and its R implementation, rpart.
Compared to other similar trees, LDATree sets itself apart in the following ways:
- It applies the idea of LDA (Linear Discriminant Analysis) when selecting variables, finding splits, and fitting models in terminal nodes.
- It addresses certain limitations of the R implementation of LDA (MASS::lda), such as missing values, more features than samples, and variables that are constant within groups.
- By re-implementing LDA with the Generalized Singular Value Decomposition (GSVD), LDATree runs quickly, particularly on large datasets.
- The package also includes several visualization tools to provide deeper insights into the data.
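To make the LDA idea more concrete, here is a minimal textbook sketch of Fisher's LDA in base R. It finds discriminant directions through an eigen-decomposition of S_W^{-1} S_B (within-class and between-class scatter). This is a swapped-in technique: it is not the GSVD-based implementation LDATree uses, and the helper `fisher_lda` is hypothetical, written only for illustration. The last lines also show why a variable that is constant within groups causes trouble for a naive implementation: the within-class scatter matrix becomes singular and the solve step fails.

```r
# Hypothetical helper: textbook Fisher LDA via eigen-decomposition
# (NOT the GSVD approach used inside LDATree)
fisher_lda <- function(X, y) {
  X <- as.matrix(X)
  y <- as.factor(y)
  mu <- colMeans(X)                      # overall mean
  p <- ncol(X)
  Sw <- matrix(0, p, p)                  # within-class scatter
  Sb <- matrix(0, p, p)                  # between-class scatter
  for (k in levels(y)) {
    Xk <- X[y == k, , drop = FALSE]
    mk <- colMeans(Xk)
    Xc <- sweep(Xk, 2, mk)               # center rows on the class mean
    Sw <- Sw + crossprod(Xc)
    d <- mk - mu
    Sb <- Sb + nrow(Xk) * outer(d, d)
  }
  # Solve Sw^{-1} Sb v = lambda v; solve() errors if Sw is singular
  # (e.g. more features than samples, or a variable constant within groups)
  e <- eigen(solve(Sw, Sb))
  K <- nlevels(y)
  list(scaling = Re(e$vectors[, seq_len(K - 1), drop = FALSE]),  # K - 1 directions
       values  = Re(e$values))
}

fit_lda <- fisher_lda(iris[, 1:4], iris$Species)
# LD scores: projections of the observations onto the discriminant directions
scores <- as.matrix(iris[, 1:4]) %*% fit_lda$scaling

# A column that is constant (here, in every class) makes Sw singular
iris_const <- cbind(iris[, 1:4], const = 1)
res <- tryCatch(fisher_lda(iris_const, iris$Species),
                error = function(e) conditionMessage(e))
```

With K classes at most K - 1 eigenvalues are non-zero, so iris yields two LD directions. The directions are only defined up to scale and sign, so they will not match the `scaling` returned by MASS::lda numerically.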
To install LDATree:
install.packages("LDATree")
To build an LDATree:
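library(LDATree)
# Set a seed so that the results are reproducible
set.seed(456)
# Fit a tree that predicts Species from all other columns of iris
fit <- Treee(Species~., data = iris)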
To plot the LDATree:
# View the overall tree.
plot(fit)
# Three types of individual plots
# 1. Scatter plot of the first two LD scores
plot(fit, data = iris, node = 1)
# 2. Density plot of the first LD score
plot(fit, data = iris, node = 3)
# 3. A text message (here, node 5 predicts a single class)
plot(fit, data = iris, node = 5)
#> [1] "Every observation in this node is predicted to be virginica"
To make predictions:
# Predicted class labels only
predictions <- predict(fit, iris)
head(predictions)
#> [1] "setosa" "setosa" "setosa" "setosa" "setosa" "setosa"
# A more informative prediction: predicted label, terminal node, and class probabilities
predictions <- predict(fit, iris, type = "all")
head(predictions)
#>   response node setosa   versicolor virginica
#> 1   setosa    3      1 9.281826e-27         0
#> 2   setosa    3      1 3.107853e-22         0
#> 3   setosa    3      1 1.049363e-24         0
#> 4   setosa    3      1 9.134151e-22         0
#> 5   setosa    3      1 1.672418e-27         0
#> 6   setosa    3      1 1.808762e-24         0
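Predicting on the same data used for fitting gives an optimistic picture. The sketch below holds out part of iris, fits on the rest, and summarizes the held-out predictions. It uses only `Treee()` and `predict()` as shown in this README, and assumes `predict()` returns class labels by default and a data frame with `response` and `node` columns for `type = "all"`, as in the output above. The 30% split and the names `train`, `test`, `fit_train` are illustrative choices, not package defaults.

```r
library(LDATree)

set.seed(456)
# Hold out 30% of the rows for testing (illustrative choice)
test_idx <- sample(nrow(iris), size = round(0.3 * nrow(iris)))
train <- iris[-test_idx, ]
test  <- iris[test_idx, ]

fit_train <- Treee(Species~., data = train)

# Predicted class labels for the held-out rows
pred_labels <- predict(fit_train, test)

# Confusion matrix: rows = predicted class, columns = observed class
conf_mat <- table(predicted = pred_labels, observed = test$Species)
print(conf_mat)

# Held-out accuracy
test_acc <- mean(pred_labels == test$Species)
print(test_acc)

# Number of held-out observations falling into each terminal node
pred_all <- predict(fit_train, test, type = "all")
print(table(pred_all$node))
```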
If you encounter a clear bug, please file an issue on GitHub with a minimal reproducible example.