Skip to content

Commit

Permalink
Retain attributes in CheckPredictionVariables and unit test.
Browse files Browse the repository at this point in the history
  • Loading branch information
jakehoare committed May 17, 2018
1 parent 8463325 commit 807a0a5
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
9 changes: 3 additions & 6 deletions R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ CheckForUniqueVariableNames <- function(formula)
#' of the fitted model.
#' @param object A model object for which prediction is desired.
#' @param newdata A data frame including the variables used to fit the model.
#' @importFrom flipU CopyAttributes
#' @export
CheckPredictionVariables <- function(object, newdata)
{
Expand All @@ -130,10 +131,6 @@ CheckPredictionVariables <- function(object, newdata)
newdata <- newdata[, names(training), drop = FALSE]
prediction.levels <- lapply(newdata, levels)

#train.list <- paste(train.levels, collapse = " ")
#prediction.list <- paste(prediction.levels, collapse = " ")
#warning(sprintf("Trained factors are %s, Prediction factors are %s", train.list, prediction.list))

for (i in 1:length(train.levels))
{
if (!is.null(train.levels[[i]])) # factor variables only
Expand All @@ -150,10 +147,10 @@ CheckPredictionVariables <- function(object, newdata)
names(training[i]), new.levels, new.level.rows))
}
# Set prediction levels to those used for training
saved.atrributes <- newdata[, i]
newdata[, i] <- droplevels(newdata[, i])
#warning(sprintf("%d : Trained levels are %s, Prediction levels are %s",
# i, paste(train.levels[[i]], collapse = " "), paste(levels(newdata[, i]), collapse = " ")))
levels(newdata[, i]) <- train.levels[[i]]
newdata[, i] <- CopyAttributes(newdata[, i], saved.atrributes)
}
}
return(newdata)
Expand Down
5 changes: 4 additions & 1 deletion tests/testthat/test-checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ test_that("CheckPredictionVariables",
# Error - predicting based on fewer variables than used to fit model
expect_error(CheckPredictionVariables(z, newdata = hbatwithsplits[, !(names(hbatwithsplits) %in% "x2")]), "Attempting to predict*")
# Warning - more levels in prediction data than fitted
expect_warning(CheckPredictionVariables(z, newdata = hbatwithsplits), "Prediction variable x1*")
newdata <- hbatwithsplits
attr(newdata$x1, "Label") <- "something"
expect_warning(checked <- CheckPredictionVariables(z, newdata = newdata), "Prediction variable x1*")
expect_equal(attr(checked$x1, "Label"), "something")
# Prediction levels reset to those used for fitting
expect_equal(length(levels(CheckPredictionVariables(z, newdata = droplevels(hbatwithsplits[1, ]))$x1)), 2)
})
Expand Down

0 comments on commit 807a0a5

Please sign in to comment.