Skip to content

Commit

Permalink
improve functions and viz
Browse files Browse the repository at this point in the history
  • Loading branch information
KelvynBladen committed Oct 12, 2023
1 parent 0db386c commit fff5a38
Show file tree
Hide file tree
Showing 10 changed files with 534 additions and 62 deletions.
2 changes: 1 addition & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ importFrom(dplyr,filter)
importFrom(dplyr,group_by)
importFrom(dplyr,left_join)
importFrom(dplyr,mutate)
importFrom(dplyr,pull)
importFrom(dplyr,relocate)
importFrom(dplyr,rename)
importFrom(dplyr,select)
Expand All @@ -30,7 +31,6 @@ importFrom(gbm,relative.influence)
importFrom(ggeasy,easy_center_title)
importFrom(ggeasy,easy_plot_legend_size)
importFrom(ggplot2,aes)
importFrom(ggplot2,element_text)
importFrom(ggplot2,facet_grid)
importFrom(ggplot2,facet_wrap)
importFrom(ggplot2,geom_line)
Expand Down
4 changes: 2 additions & 2 deletions R/caret_plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ caret_plot <- function(x, sqrt = FALSE, marg1 = FALSE, marg2 = FALSE,
names(l)[2] <- paste0("full_", r1)
names(l)[3] <- paste0("full_", r2)

if (marg1) {
if (marg1 == T) {
if (length(res) > 3) {
lava <- res
vc1 <- colnames(lava)[length(lava) - 1]
Expand Down Expand Up @@ -279,7 +279,7 @@ caret_plot <- function(x, sqrt = FALSE, marg1 = FALSE, marg2 = FALSE,
}
}

if (marg2) {
if (marg2 == T) {
if (length(res) > 4) {
lava <- res
vc1 <- colnames(lava)[length(lava) - 1]
Expand Down
20 changes: 13 additions & 7 deletions R/ggvip.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
#' @name ggvip
#' @importFrom randomForest importance
#' @importFrom dplyr %>% arrange desc filter between case_when
#' @importFrom ggplot2 ggplot geom_point xlim xlab ylab ggtitle theme
#' aes element_text
#' @importFrom ggplot2 ggplot geom_point xlim xlab ylab ggtitle aes
#' @importFrom gridExtra grid.arrange
#' @importFrom ggeasy easy_center_title
#' @importFrom rlang .data
#' @description A ggplot of variable importance as measured by a Random Forest.
#' @param x An object of class randomForest.
Expand Down Expand Up @@ -98,7 +98,7 @@ ggvip <- function(x, scale = FALSE, sqrt = TRUE, type = "both", num_var) {
colnames(imp_frame)[1]
)) +
ggtitle("VIP") +
theme(plot.title = element_text(hjust = 0.5))
easy_center_title()

l <- list()
l$vip <- g
Expand Down Expand Up @@ -201,15 +201,21 @@ ggvip <- function(x, scale = FALSE, sqrt = TRUE, type = "both", num_var) {

l <- list()
if (type %in% c("mse", "acc", 1)) {
l$vip <- g1
l$vip <- g1 +
ggtitle("Permutation Importance") +
easy_center_title()
l$table <- imp_frame1[, -2]
} else if (type %in% c("purity", "gini", 2)) {
l$vip <- g2
l$vip <- g2 +
ggtitle("Purity Importance") +
easy_center_title()
l$table <- imp_frame2[, -1]
} else {
l$both_vips <- gridExtra::grid.arrange(g1, g2,
l$both_vips <- gridExtra::grid.arrange(
g1,
g2,
nrow = 1,
top = "Variable Importances using ggplot Graphics"
top = "Variable Importance"
)
l$accuracy_vip <- g1
l$purity_vip <- g2
Expand Down
100 changes: 61 additions & 39 deletions R/mtry_compare.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
#' @name mtry_compare
#' @importFrom randomForest importance randomForest
#' @importFrom dplyr %>% arrange across ends_with desc filter select
#' summarise group_by case_when
#' summarise group_by case_when between mutate pull
#' @importFrom ggplot2 ggplot geom_point geom_line ylab ggtitle theme
#' aes scale_x_continuous scale_y_continuous
#' @importFrom ggeasy easy_center_title
#' @importFrom tidyr pivot_wider
#' @importFrom stats model.frame na.omit quantile
#' @importFrom methods is
Expand All @@ -20,14 +21,17 @@
#' the measures be divided by their "standard errors"? Default is False.
#' @param sqrt Boolean value indicating whether importance metrics should be
#' adjusted via a square root transformation. Default is True.
#' @param num_var Optional integer argument for reducing the number of
#' @param num_var Optional integer argument for reducing the number of plotted
#' variables to the top 'num_var'. Should be an integer between 1 and the
#' total number of predictor variables in the model or it should be a
#' positive proportion of variables desired.
#' positive proportion of variables desired. If not provided, all variables
#' are used.
#' @param mvec Optional vector argument for defining choices of mtry to have the
#' function consider. Should be a vector of integers between 1 and the total
#' number of predictor variables in the model. Or it can be a vector of
#' proportions (strictly less than 1) of the number of predictor variables.
#' proportions (between 0 and 1) of the number of predictor variables. If not
#' provided, mvec is set to a vector of the lowest possible value, the
#' default value, the highest possible value, and a middle value.
#' @param ... Other parameters to pass to the randomForest function.
#' @return A list of data.frames, useful plots, and forest objects for user
#' evaluations of the randomForest hyperparameter mtry.
Expand Down Expand Up @@ -65,12 +69,12 @@ mtry_compare <- function(formula, data = NULL, scale = FALSE, sqrt = TRUE,
}

ifelse(!missing(num_var),
num_var <- ifelse(num_var >= num_preds | num_var <= 0,
num_preds,
num_var <- ifelse(dplyr::between(num_var, 0, num_preds),
ifelse(num_var < 1,
round(num_var * num_preds),
round(num_var)
)
),
num_preds
),
num_var <- num_preds
)
Expand All @@ -83,9 +87,10 @@ mtry_compare <- function(formula, data = NULL, scale = FALSE, sqrt = TRUE,
)))

vf <- paste0("sv", i)
eval(call("<-", as.name(vf), importance(get(paste0("srf", i)),
scale = scale
)))
eval(call(
"<-", as.name(vf),
importance(get(paste0("srf", i)), scale = scale)
))

v <- as.data.frame(get(vf))

Expand All @@ -107,34 +112,32 @@ mtry_compare <- function(formula, data = NULL, scale = FALSE, sqrt = TRUE,
eval(call("<-", as.name(vf), y))
}

err_v <- 0
for (i in mvec) {
mod <- get(paste0("srf", i))
sd <- data.frame()
err_v <- vector(length = length(mvec))
for (i in seq_len(length(mvec))) {
sd <- rbind(sd, get(paste0("sv", mvec[i])))

mod <- get(paste0("srf", mvec[i]))

ifelse(is(model_frame[1, 1], "numeric"),
err_v <- c(err_v, mod$mse[mod$ntree]),
err_v <- c(err_v, mod$err.rate[mod$ntree])
err_v[i] <- mod$mse[mod$ntree],
err_v[i] <- mod$err.rate[mod$ntree]
)
}
err_v <- err_v[-1]

colnames(sd) <- gsub("%", "Pct", colnames(sd))

sd_full <- sd

ifelse(is(model_frame[1, 1], "numeric"),
err_df <- data.frame(mtry = mvec, mse = err_v),
err_df <- data.frame(mtry = mvec, misclass_rate = err_v)
)

sd <- data.frame()
for (i in mvec) {
sd <- rbind(sd, get(paste0("sv", i)))
}

if (ncol(sd) > 0) {
if (colnames(sd)[1] == "%IncMSE") {
colnames(sd)[1] <- "PctIncMSE"
}
}

sd_full <- sd
ifelse(is(model_frame[1, 1], "numeric"),
yl <- "Mean Squared Error",
yl <- "Misclassification Rate"
)

if (!missing(num_var)) {
d <- sd %>%
Expand Down Expand Up @@ -163,7 +166,6 @@ mtry_compare <- function(formula, data = NULL, scale = FALSE, sqrt = TRUE,
values_from = colnames(sd_full)[2]
) %>%
as.data.frame()

colnames(n)[-1] <- paste0("m", colnames(n)[-1])
rownames(n) <- n$names
n <- n[-1] %>% arrange(desc(across(ends_with(as.character(num_preds)))))
Expand All @@ -186,6 +188,13 @@ mtry_compare <- function(formula, data = NULL, scale = FALSE, sqrt = TRUE,
.default = 4
)

levs <- sd |>
filter(mtry == num_preds) |>
arrange(desc(.data[[colnames(sd)[1]]])) |>
mutate(names = factor(names, levels = names)) |>
pull(names)
sd$names <- factor(sd$names, levels = levs)

g1 <- sd %>%
ggplot(aes(
x = .data[[colnames(sd)[4]]], y = .data[[colnames(sd)[1]]],
Expand All @@ -198,8 +207,9 @@ mtry_compare <- function(formula, data = NULL, scale = FALSE, sqrt = TRUE,
limits = c(0, newm),
breaks = seq(0, newm, by = newm / div)
) +
ggtitle("Variable Importance Based on Decrease in Error Across mtry")

ylab("Permutation Importance") +
ggtitle("Permutation Importance across mtry") +
easy_center_title()

m <- max(sd[2])
v <- 10^(-3:6)
Expand All @@ -219,6 +229,13 @@ mtry_compare <- function(formula, data = NULL, scale = FALSE, sqrt = TRUE,
.default = 4
)

levs <- sd |>
filter(mtry == num_preds) |>
arrange(desc(.data[[colnames(sd)[2]]])) |>
mutate(names = factor(names, levels = names)) |>
pull(names)
sd$names <- factor(sd$names, levels = levs)

g2 <- sd %>%
ggplot(aes(
x = .data[[colnames(sd)[4]]], y = .data[[colnames(sd)[2]]],
Expand All @@ -230,8 +247,10 @@ mtry_compare <- function(formula, data = NULL, scale = FALSE, sqrt = TRUE,
limits = c(0, newm),
breaks = seq(0, newm, by = newm / div)
) +
ggtitle("Variable Importance Based on Purity Contribution Across mtry") +
scale_x_continuous(breaks = mvec)
scale_x_continuous(breaks = mvec) +
ylab("Purity Importance") +
ggtitle("Purity Importance across mtry") +
easy_center_title()


m <- max(err_df[2])
Expand All @@ -253,27 +272,30 @@ mtry_compare <- function(formula, data = NULL, scale = FALSE, sqrt = TRUE,
)

g_err <- err_df %>%
ggplot(aes(x = .data[[colnames(err_df)[1]]],
y = .data[[colnames(err_df)[2]]])) +
ggplot(aes(
x = .data[[colnames(err_df)[1]]],
y = .data[[colnames(err_df)[2]]]
)) +
geom_point() +
geom_line() +
scale_x_continuous(limits = c(1, num_preds), breaks = mvec) +
scale_y_continuous(
limits = c(0, newm),
breaks = seq(0, newm, by = newm / div)
) +
ggtitle("Model Errors Across mtry")

ylab(yl) +
ggtitle("Model Error across mtry") +
easy_center_title()

rownames(sd_full) <- seq_len(nrow(sd_full))

l <- list()

l$importance <- sd_full[c(3, 4, 1, 2)]
l$model_errors <- err_df
l$var_imp_error <- k
l$var_imp_permute <- k
l$var_imp_purity <- n
l$gg_var_imp_error <- g1
l$gg_var_imp_permute <- g1
l$gg_var_imp_purity <- g2
l$gg_model_errors <- g_err

Expand Down
13 changes: 13 additions & 0 deletions R/mtry_pdp_compare.R
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,19 @@ mtry_pdp_compare <- function(formula, data = NULL, mvec, var_vec,
l
}

# m <- mtry_pdp_compare(formula = factor(LobaOreg)~., data = lichen)
# gr <- gridExtra::grid.arrange(m$MinTempAve, m$ACONIF,
# m$Elevation, m$MinTempDiff,
# m$AmbVapPressAve, m$AveTempAve, nrow = 3)
# ggsave("gr.jpg",
# gridExtra::grid.arrange(m$MinTempAve + ylab(""),
# m$ACONIF + ylab(""),
# m$Elevation + ylab("") + xlab("xElevation"),
# m$MinTempDiff + ylab(""),
# m$AmbVapPressAve + ylab(""),
# m$AveTempAve + ylab(""), nrow = 3),
# dpi = 2800, width = 8, height = 8)
# library(ggplot2)
# m <- mtry_pdp_compare(formula = medv ~ ., data = MASS::Boston)
# m$full_num
# m$full_fac
Expand Down
Loading

0 comments on commit fff5a38

Please sign in to comment.