Skip to content

Commit

Permalink
DS-3488 Validate Dummy adjusted missing data properly (#18)
Browse files Browse the repository at this point in the history
* DS-3488 Validate Dummy adjusted missing data properly

Ensure that predictor validation checks correctly account for edge cases
in dummy variable adjustment when all predictors have missing data and
the dummy variable is redundant.

* DS-3488 Add more unit tests [revdep skip]
  • Loading branch information
jrwishart committed Aug 26, 2021
1 parent 7134557 commit e25f4d6
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 8 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: flipData
Type: Package
Title: Functions for extracting and describing data
Version: 1.5.1
Version: 1.5.2
Author: Displayr <opensource@displayr.com>
Maintainer: Displayr <opensource@displayr.com>
Description: Functions for extracting data from formulas and
Expand Down
17 changes: 14 additions & 3 deletions R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,23 @@ CheckForUniqueVariableNames <- function(formula)
#'
#' @param object A model object for which prediction is desired.
#' @param newdata A \code{data.frame} including the variables used to fit the model.
#' @importFrom flipU CopyAttributes
#' @importFrom flipU CopyAttributes AllVariablesNames OutcomeName
#' @importFrom flipFormat Labels
#' @export
CheckPredictionVariables <- function(object, newdata)
{
regression.model <- inherits(object, "Regression")
relevant.cols <- names(object$model)[names(object$model) != object$outcome.name]
# Deduce the predictor names from the formula and model data available
dummy.adjusted.importance <- regression.model &&
object$missing == "Dummy variable adjustment" &&
!is.null(object$importance.type)
if ("formula" %in% names(object) && !dummy.adjusted.importance && !inherits(object, "LDA"))
{
training.model.variables <- AllVariablesNames(object[["formula"]], data = object[["model"]])
training.outcome.name <- OutcomeName(object[["formula"]], data = object[["model"]])
relevant.cols <- training.model.variables[training.model.variables != training.outcome.name]
} else # Relevant for older CART which don't have a formula (see DS-2488)
relevant.cols <- names(object[["model"]])[names(object[["model"]]) != object[["outcome.name"]]]
# Check if a regression object is being processed and the outlier removal has been implemented.
outliers.removed <- (regression.model && !all(non.outliers <- object$non.outlier.data))
if (outliers.removed)
Expand All @@ -135,7 +145,8 @@ CheckPredictionVariables <- function(object, newdata)

if (ncol(training) == 0)
return(newdata)
if (!identical(setdiff(names(training), names(newdata)), character(0)))

if (!identical(setdiff(relevant.cols, names(newdata)), character(0)))
stop("Attempting to predict based on fewer variables than those used to train the model.")

# Identify training factors
Expand Down
38 changes: 34 additions & 4 deletions tests/testthat/test-checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ test_that("CheckForUniqueVariableNames",
data(hbatwithsplits, package = "flipExampleData")
test_that("CheckPredictionVariables",
{
data <- GetData(x3 ~ x1 + x2 + x6, hbatwithsplits, auxiliary.data = NULL)
z <- list(model = data, outcome.name = "x3", subset = !(hbatwithsplits$x1 %in% "Less than 1 year")) # remove a level
test.formula <- x3 ~ x1 + x2 + x6
data <- GetData(test.formula, hbatwithsplits, auxiliary.data = NULL)
z <- list(formula = test.formula, model = data, outcome.name = "x3", subset = !(hbatwithsplits$x1 %in% "Less than 1 year")) # remove a level

# Predicting based on fewer variables than used to fit model
expect_error(CheckPredictionVariables(z, newdata = hbatwithsplits[, !(names(hbatwithsplits) %in% "x2")]), "Attempting to predict*")
Expand Down Expand Up @@ -97,8 +98,9 @@ test_that("CheckPredictionVariables",


test_that("DS-2704 Automated outlier removal scenario catches missing levels", {
data <- GetData(x3 ~ x1 + x2 + x6, hbatwithsplits, auxiliary.data = NULL)
z <- list(model = data, outcome.name = "x3", subset = !(hbatwithsplits$x1 %in% "Less than 1 year"))
test.formula <- x3 ~ x1 + x2 + x6
data <- GetData(test.formula, hbatwithsplits, auxiliary.data = NULL)
z <- list(formula = test.formula, model = data, outcome.name = "x3", subset = !(hbatwithsplits$x1 %in% "Less than 1 year"))
# Make scenario where outliers have removed all instances of a level in a factor
z$estimation.data <- z$model[z$subset, ]
z$estimation.data$x1[1] <- "Less than 1 year"
Expand Down Expand Up @@ -141,3 +143,31 @@ test_that("DS-2704 Automated outlier removal scenario catches missing levels", {
fixed = TRUE)
})

test_that("DS-3488 Check dummy variable adjustment handled with and without outlier removal", {
    # Fixture: the first row is missing on both predictors and X1 is also
    # missing on row 7, so every predictor has at least one missing value.
    missing.all.predictors <- data.frame(Y = c(1:20, 100),
                                         X1 = c(NA, 1:5, NA, 7:20),
                                         X2 = c(NA, runif(20)))
    dummy.adj.model <- AddDummyVariablesForNAs(missing.all.predictors, "Y", checks = FALSE)
    estimation.data <- EstimationData(Y ~ X1 + X2, data = missing.all.predictors,
                                      missing = "Dummy variable adjustment")$estimation.data
    # Minimal stand-in for a fitted Regression object. Note that the formula
    # only carries the dummy variable for X1, not one for X2.
    output <- structure(list(formula = Y ~ X1 + X2 + X1.dummy.var_GQ9KqD7YOf,
                             estimation.data = estimation.data,
                             model = dummy.adj.model,
                             subset = TRUE,
                             outcome.name = "Y"),
                        class = "Regression")
    # Only the predictors named in the formula should be returned.
    expected.output <- dummy.adj.model[, c("X1", "X2", "X1.dummy.var_GQ9KqD7YOf")]
    expect_equal(CheckPredictionVariables(output, newdata = dummy.adj.model),
                 expected.output)

    # Same expectation once outlier removal is flagged: 19 TRUE entries
    # followed by a single FALSE entry.
    output[["non.outlier.data"]] <- rep(c(TRUE, FALSE), c(19, 1))
    expect_equal(CheckPredictionVariables(output, newdata = dummy.adj.model),
                 expected.output)

    # Check method still works when formula not available: the predictors are
    # then expected to be the model columns other than the outcome.
    output[["formula"]] <- NULL
    output[["outcome.name"]] <- "Y"
    output[["model"]] <- missing.all.predictors
    expect_equal(CheckPredictionVariables(output, newdata = dummy.adj.model),
                 dummy.adj.model[, c("X1", "X2")])
})

0 comments on commit e25f4d6

Please sign in to comment.