Skip to content

Commit

Permalink
reduce imports
Browse files Browse the repository at this point in the history
  • Loading branch information
KelvynBladen committed Jul 10, 2023
1 parent 21923cf commit 18bd745
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 60 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,6 @@ vignettes/*.pdf
# R Environment Variables
.Renviron
inst/doc

# Misc Files
misc
9 changes: 3 additions & 6 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,19 @@ Depends:
Imports:
car,
dplyr,
e1071,
ggplot2,
gridExtra,
Metrics,
minerva,
plyr,
randomForest,
rpart,
stats,
tidyr,
wrapr
tidyr
Suggests:
EZtune,
e1071,
knitr,
MASS,
rmarkdown,
rpart,
testthat (>= 3.0.0)
VignetteBuilder:
knitr
Expand Down
5 changes: 0 additions & 5 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ export(ggvip)
export(mtry_compare)
export(partial_cor)
export(robust_vifs)
importFrom(Metrics,sse)
importFrom(car,vif)
importFrom(dplyr,"%>%")
importFrom(dplyr,across)
Expand All @@ -17,7 +16,6 @@ importFrom(dplyr,filter)
importFrom(dplyr,group_by)
importFrom(dplyr,select)
importFrom(dplyr,summarise)
importFrom(e1071,svm)
importFrom(ggplot2,aes)
importFrom(ggplot2,aes_string)
importFrom(ggplot2,element_text)
Expand All @@ -35,10 +33,8 @@ importFrom(ggplot2,ylab)
importFrom(ggplot2,ylim)
importFrom(gridExtra,grid.arrange)
importFrom(minerva,mine)
importFrom(plyr,round_any)
importFrom(randomForest,importance)
importFrom(randomForest,randomForest)
importFrom(rpart,rpart)
importFrom(stats,cor)
importFrom(stats,lm)
importFrom(stats,model.frame)
Expand All @@ -47,4 +43,3 @@ importFrom(stats,na.omit)
importFrom(stats,predict)
importFrom(stats,quantile)
importFrom(tidyr,pivot_wider)
importFrom(wrapr,orderv)
26 changes: 13 additions & 13 deletions R/ggvip.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
#' @importFrom ggplot2 ggplot geom_point xlim xlab ylab ggtitle theme
#' aes_string element_text
#' @importFrom gridExtra grid.arrange
#' @importFrom wrapr orderv
#' @importFrom plyr round_any
#' @description A ggplot of variable importance as measured by a Random Forest.
#' @param x An object of class randomForest.
#' @param scale For permutation based measures such as MSE or Accuracy, should
Expand Down Expand Up @@ -61,18 +59,18 @@ ggvip <- function(x, scale = FALSE, sqrt = TRUE, type = "both", num_var) {
}

if (length(colnames(imp_frame)) == 2) {
imp_frame <- imp_frame[wrapr::orderv(imp_frame[1]), ]
imp_frame <- imp_frame[do.call(base::order, as.list(imp_frame[1])), ]
imp_frame$var <- factor(imp_frame$var, levels = c(rownames(imp_frame)))

m <- max(imp_frame[1])
v <- 10^(-3:6)
ind <- findInterval(m, v)

newr <- m / (10^(ind - 5))
rrr <- plyr::round_any(newr, 10, ceiling)
rrr <- ceiling(newr / 10)*10

if (newr / rrr < 3 / 4) {
rrr <- plyr::round_any(newr, 4, ceiling)
rrr <- ceiling(newr / 4)*4
}

newm <- rrr * (10^(ind - 5))
Expand Down Expand Up @@ -106,7 +104,7 @@ ggvip <- function(x, scale = FALSE, sqrt = TRUE, type = "both", num_var) {
l$table <- imp_frame
l
} else {
imp_frame <- imp_frame[wrapr::orderv(imp_frame[1]), ]
imp_frame <- imp_frame[do.call(base::order, as.list(imp_frame[1])), ]
imp_frame$var <- factor(imp_frame$var, levels = c(rownames(imp_frame)))

if (colnames(imp_frame)[1] == "%IncMSE") {
Expand All @@ -118,10 +116,10 @@ ggvip <- function(x, scale = FALSE, sqrt = TRUE, type = "both", num_var) {
ind <- findInterval(m, v)

newr <- m / (10^(ind - 5))
rrr <- plyr::round_any(newr, 10, ceiling)
rrr <- ceiling(newr/10)*10

if (newr / rrr < 3 / 4) {
rrr <- plyr::round_any(newr, 4, ceiling)
rrr <- ceiling(newr/4)*4
}

newm <- rrr * (10^(ind - 5))
Expand All @@ -135,7 +133,8 @@ ggvip <- function(x, scale = FALSE, sqrt = TRUE, type = "both", num_var) {

imp_frame1 <- imp_frame

imp_frame1 <- imp_frame1[rev(wrapr::orderv(imp_frame1[1])), ]
imp_frame1 <- imp_frame1[rev(do.call(base::order,
as.list(imp_frame1[1]))), ]

g1 <- imp_frame %>%
ggplot(aes_string(
Expand All @@ -152,17 +151,17 @@ ggvip <- function(x, scale = FALSE, sqrt = TRUE, type = "both", num_var) {
paste0("sqrt(", colnames(imp_frame)[1], ")")
))

imp_frame <- imp_frame[wrapr::orderv(imp_frame[2]), ]
imp_frame <- imp_frame[do.call(base::order, as.list(imp_frame[2])), ]
imp_frame$var <- factor(imp_frame$var, levels = c(rownames(imp_frame)))

m <- max(imp_frame[2])
ind <- findInterval(m, v)

newr <- m / (10^(ind - 5))
rrr <- plyr::round_any(newr, 10, ceiling)
rrr <- ceiling(newr/10)*10

if (newr / rrr < 3 / 4) {
rrr <- plyr::round_any(newr, 4, ceiling)
rrr <- ceiling(newr/4)*4
}

newm <- rrr * (10^(ind - 5))
Expand All @@ -176,7 +175,8 @@ ggvip <- function(x, scale = FALSE, sqrt = TRUE, type = "both", num_var) {

imp_frame2 <- imp_frame

imp_frame2 <- imp_frame2[rev(wrapr::orderv(imp_frame2[2])), ]
imp_frame2 <- imp_frame2[rev(do.call(base::order,
as.list(imp_frame2[2]))), ]

g2 <- imp_frame %>%
ggplot(aes_string(
Expand Down
12 changes: 6 additions & 6 deletions R/mtry_compare.R
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,9 @@ mtry_compare <- function(formula, data = NULL, scale = FALSE, sqrt = TRUE,
ind <- findInterval(m, v)

newr <- m / (10^(ind - 5))
rrr <- plyr::round_any(newr, 10, ceiling)
rrr <- ceiling(newr/10)*10

rrr <- ifelse(newr / rrr < 3 / 4, plyr::round_any(newr, 4, ceiling), rrr)
rrr <- ifelse(newr / rrr < 3 / 4, ceiling(newr/4)*4, rrr)

newm <- rrr * (10^(ind - 5))

Expand Down Expand Up @@ -204,9 +204,9 @@ mtry_compare <- function(formula, data = NULL, scale = FALSE, sqrt = TRUE,
ind <- findInterval(m, v)

newr <- m / (10^(ind - 5))
rrr <- plyr::round_any(newr, 10, ceiling)
rrr <- ceiling(newr/10)*10

rrr <- ifelse(newr / rrr < 3 / 4, plyr::round_any(newr, 4, ceiling), rrr)
rrr <- ifelse(newr / rrr < 3 / 4, ceiling(newr/4)*4, rrr)

newm <- rrr * (10^(ind - 5))

Expand Down Expand Up @@ -237,9 +237,9 @@ mtry_compare <- function(formula, data = NULL, scale = FALSE, sqrt = TRUE,
ind <- findInterval(m, v)

newr <- m / (10^(ind - 5))
rrr <- plyr::round_any(newr, 10, ceiling)
rrr <- ceiling(newr/10)*10

rrr <- ifelse(newr / rrr < 3 / 4, plyr::round_any(newr, 4, ceiling), rrr)
rrr <- ifelse(newr / rrr < 3 / 4, ceiling(newr/4)*4, rrr)

newm <- rrr * (10^(ind - 5))

Expand Down
13 changes: 5 additions & 8 deletions R/partial_cor.R
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
#' Partial Correlations
#' @name partial_cor
#' @importFrom randomForest randomForest
#' @importFrom rpart rpart
#' @importFrom stats lm model.frame cor model.matrix predict
#' @importFrom e1071 svm
#' @importFrom ggplot2 ggplot aes geom_point xlim ylim geom_line ggtitle
#' @importFrom dplyr %>% arrange desc
#' @importFrom minerva mine
#' @importFrom wrapr orderv
#' @description A list of data.frames and useful plots for user evaluations of
#' correlations and partial correlations of predictors with a given response.
#' @param formula an object of class "\link{formula}" (or one that can be
Expand Down Expand Up @@ -78,7 +75,7 @@ partial_cor <- function(formula, data = NULL, model = lm, num_var, ...) {

cd$cor <- cor(mf)[-1, 1]

cd <- cd[wrapr::orderv(abs(cd[2])), ]
cd <- cd[do.call(base::order, as.list(abs(cd[2]))), ]
cd$var <- factor(cd$var, levels = cd$var)

if (!missing(num_var)) {
Expand Down Expand Up @@ -117,7 +114,7 @@ partial_cor <- function(formula, data = NULL, model = lm, num_var, ...) {

cdf$part_cor <- ifelse(is.na(cdf$part_cor), 0, cdf$part_cor)

cdf <- cdf[wrapr::orderv(abs(cdf[2])), ]
cdf <- cdf[do.call(base::order, as.list(abs(cdf[2]))), ]
cdf$var <- factor(cdf$var, levels = c(cdf$var))

if (!missing(num_var)) {
Expand Down Expand Up @@ -145,7 +142,7 @@ partial_cor <- function(formula, data = NULL, model = lm, num_var, ...) {
l$y_partial_cors <- cdf
}

mdf <- mdf[wrapr::orderv(mdf[2]), ]
mdf <- mdf[do.call(base::order, as.list(mdf[2])), ]
mdf$var <- factor(mdf$var, levels = mdf$var)

if (!missing(num_var)) {
Expand All @@ -164,7 +161,7 @@ partial_cor <- function(formula, data = NULL, model = lm, num_var, ...) {
xlim(0, 1) +
ggtitle("Mutual Information Between Predictor Variables and Response")

mdf <- mdf[wrapr::orderv(mdf[3]), ]
mdf <- mdf[do.call(base::order, as.list(mdf[3])), ]
mdf$var <- factor(mdf$var, levels = mdf$var)

if (!missing(num_var)) {
Expand All @@ -186,7 +183,7 @@ partial_cor <- function(formula, data = NULL, model = lm, num_var, ...) {
) +
ggtitle("URF Accuracy Mutual Information")

mdf <- mdf[wrapr::orderv(mdf[4]), ]
mdf <- mdf[do.call(base::order, as.list(mdf[4])), ]
mdf$var <- factor(mdf$var, levels = mdf$var)

if (!missing(num_var)) {
Expand Down
12 changes: 4 additions & 8 deletions R/robust_vifs.R
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
#' Non-linear Variance Inflation Factors
#' @name robust_vifs
#' @importFrom rpart rpart
#' @importFrom stats lm model.frame
#' @importFrom ggplot2 ggplot geom_point xlim ylim geom_line ggtitle geom_vline
#' @importFrom dplyr %>% arrange desc
#' @importFrom car vif
#' @importFrom Metrics sse
#' @importFrom e1071 svm
#' @importFrom wrapr orderv
#' @description A list of data.frames and useful plots for user evaluations of
#' the randomForest hyperparameter mtry.
#' @param formula an object of class "\link{formula}" (or one that can be
Expand Down Expand Up @@ -74,8 +70,8 @@ robust_vifs <- function(formula, data, model = randomForest,

# Consider Fixes that use a test or OOB or CV error rather than
# training Error.
r2 <- 1 - (Metrics::sse(as.numeric(mf[, k]), predict(r, mf[, -c(1, k)])) /
Metrics::sse(as.numeric(mf[, k]), mean(as.numeric(mf[, k]))))
r2 <- 1 - (sum((as.numeric(mf[, k]) - predict(r, mf[, -c(1, k)])) ^ 2) /
sum((as.numeric(mf[, k]) - mean(as.numeric(mf[, k])))))
vdf[k - 1, 4] <- 1 / (1 - r2)
vdf[k - 1, 5] <- r2
}
Expand All @@ -86,7 +82,7 @@ robust_vifs <- function(formula, data, model = randomForest,
colnames(vdf)[c(2, 4)] <- c("Log10_lm_vif", "Log10_model_vif")
}

vdf <- vdf[wrapr::orderv(vdf[2]), ]
vdf <- vdf[do.call(base::order, as.list(vdf[2])), ]
vdf$var <- factor(vdf$var, levels = vdf$var)

if (!missing(num_var)) {
Expand Down Expand Up @@ -120,7 +116,7 @@ robust_vifs <- function(formula, data, model = randomForest,
ggtitle("Linear R2 for Modeling each Predictor on all Others") +
geom_vline(xintercept = 0.9, color = "blue")

vdf <- vdf[wrapr::orderv(vdf[4]), ]
vdf <- vdf[do.call(base::order, as.list(vdf[4])), ]
vdf$var <- factor(vdf$var, levels = vdf$var)

if (!missing(num_var)) {
Expand Down
14 changes: 5 additions & 9 deletions misc/caret_plot.R
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
#' Mtry Tune via VIPs
#' @name caret_plot
#' @importFrom dplyr %>% arrange across ends_with desc filter select
#' summarise group_by case_when
#' @importFrom dplyr %>% select summarise group_by case_when
#' @importFrom ggplot2 ggplot geom_point geom_line ylab ggtitle theme
#' aes_string scale_x_continuous scale_y_continuous
#' @importFrom gridExtra arrangeGrob
#' @importFrom tidyr pivot_wider
#' @importFrom stats model.frame na.omit quantile
#' @description A list of data.frames and useful plots for comparing the
#' performance of models across their hyper-parameters.
#' @param x An object of class train.
Expand Down Expand Up @@ -82,9 +78,9 @@ caret_plot <- function(x = gbmFit, sqrt = FALSE, marg1 = TRUE, marg2 = TRUE,
ind <- findInterval(m, v)

newr <- m / (10^(ind - 5))
rrr <- plyr::round_any(newr, 10, ceiling)
rrr <- ceiling(newr/10)*10

rrr <- ifelse(newr / rrr < .75, plyr::round_any(newr, 4, ceiling), rrr)
rrr <- ifelse(newr / rrr < .75, ceiling(newr/4)*4, rrr)

newm <- rrr * (10^(ind - 5))
div <- case_when(
Expand Down Expand Up @@ -220,9 +216,9 @@ caret_plot <- function(x = gbmFit, sqrt = FALSE, marg1 = TRUE, marg2 = TRUE,
ind <- findInterval(m, v)

newr <- m / (10^(ind - 5))
rrr <- plyr::round_any(newr, 10, ceiling)
rrr <- ceiling(newr/10)*10

rrr <- ifelse(newr / rrr < .75, plyr::round_any(newr, 4, ceiling), rrr)
rrr <- ifelse(newr / rrr < .75, ceiling(newr/4)*4, rrr)

newm <- rrr * (10^(ind - 5))
div <- ifelse(0 == (rrr / 5) %% 5, 5,
Expand Down
9 changes: 4 additions & 5 deletions misc/pdp_compare.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#' @importFrom stats model.frame getCall
#' @importFrom pdp partial
#' @importFrom tidyr all_of
#' @importFrom plyr round_any
#' @description A list of partial dependence plots, and pdp importance values for
#' assessing true affect of predictors on response.
#' @param x An object of class randomForest.
Expand Down Expand Up @@ -77,19 +76,19 @@ pdp_compare <- function(x = Lo.rf, vars,

indu <- findInterval(abs(u), v)
newu <- u / (10^(indu - 6))
ru <- plyr::round_any(newu, 10, ceiling)
ru <- ceiling(newu/10)*10

indl <- findInterval(abs(l), v)
newl <- l / (10^(indl - 6))
rl <- plyr::round_any(newl, 10, floor)
rl <- floor(newl/10)*10

nu <- ru * (10^(indu - 6))
nl <- rl * (10^(indl - 6))
nrr <- nu - nl

if(rr/nrr < 3/4 | nl/l < 3/4 | u/nu < 3/4){
ru <- plyr::round_any(newu, 4, ceiling)
rl <- plyr::round_any(newl, 4, floor)
ru <- ceiling(newu/4)*4
rl <- floor(newl/4)*4
nu <- ru * (10^(indl - 6))
nl <- rl * (10^(indl - 6))
nrr <- nu - nl
Expand Down

0 comments on commit 18bd745

Please sign in to comment.