Skip to content

Commit

Permalink
DS-4360 Update Probabilities method
Browse files Browse the repository at this point in the history
Overhaul the unused Probabilities method. Make use of the helper
function CheckPredictionVariables that checks the newdata argument
  • Loading branch information
jrwishart committed Apr 26, 2023
1 parent 7fbc515 commit 02d7e5a
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 31 deletions.
6 changes: 6 additions & 0 deletions R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,13 @@ CheckForUniqueVariableNames <- function(formula)
#' @export
CheckPredictionVariables <- function(object, newdata)
{
if (missing(newdata) || !(is.data.frame(newdata) && NROW(newdata) > 0))
stop(sQuote("newdata"), " argument must be a data.frame ",
"with at least one observation.")
regression.model <- inherits(object, "Regression")
if (!regression.model && !inherits(object, "MachineLearning"))
stop(sQuote("object"), " argument must be a ", sQuote("Regression"),
" or ", sQuote("MachineLearning"), " object.")
# Deduce the predictor names from the formula and model data available
dummy.adjusted.importance <- regression.model &&
object$missing == "Dummy variable adjustment" &&
Expand Down
40 changes: 28 additions & 12 deletions R/probabilities.R
Original file line number Diff line number Diff line change
@@ -1,17 +1,33 @@
#' Probabilities
#' \code{Probabilities} A generic function used to extract one or more
#' variables containing probabilities relating to cases (e.g., segment membership).
#' @param object An object for which probabilities are desired.
#' @param ... Additional argument
#'
#' Estimates the probability of group membership for the data passed into \code{newdata} or
#' the data used to fit the model if \code{newdata} is not specified. Intended to be used
#' for the classifiers in the packages \code{flipRegression} and \code{flipMultivariates}.
#'
#' @param object A \code{MachineLearning} or \code{Regression} object.
#' @param newdata Optionally, a data frame including the variables used to fit the model.
#' If not provided, the object$model is used instead.
#' @param ... Optional arguments to pass to \code{predict} or other functions.
#' @return A matrix of predicted probabilities for the observation to belong to each
#' class label.
#' @export
Probabilities <- function(object, ...) {
Probabilities <- function(object, newdata = NULL, ...)
{
newdata <- validateNewData(object, newdata)
UseMethod("Probabilities")
}

#' #' @inheritParams Probabilities
#' #' @describeIn Probabilities Error occurs as no method has been specified.
#' #' @export
#' Probabilities.default = function(object, ...)
#' {
#' stop("No 'Probabilities' method exists for this class of objects.")
#' }
validateNewData <- function(object, newdata)
{
# CheckPredictionVariables is still required without newdata because empty training levels are removed
if (is.null(newdata))
return(suppressWarnings(CheckPredictionVariables(object, object$model)))
stopifnot("newdata must be a data.frame" = is.data.frame(newdata),
"Need at least one observation in the newdata argument" = NROW(newdata) > 0)
CheckPredictionVariables(object, newdata)
}

Probabilities.default <- function(object, newdata = NULL, ...)
{
stop("Probabilities is not implemented for this object type")
}
23 changes: 14 additions & 9 deletions man/Probabilities.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

55 changes: 47 additions & 8 deletions tests/testthat/test-checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,24 +37,55 @@ test_that("CheckForUniqueVariableNames",
data(hbatwithsplits, package = "flipExampleData")
test_that("CheckPredictionVariables",
{
# newdata argument is provided and non-empty
expected.err <- paste0(sQuote("newdata"), " argument must be a data.frame ",
"with at least one observation.")
bad.newdata <- list(1:10, TRUE, 1+1i, "foo", list())
for (newdat in bad.newdata)
expect_error(CheckPredictionVariables(z, newdata = newdat), expected.err, fixed = TRUE)
expect_error(CheckPredictionVariables(z, newdata = NULL), expected.err, fixed = TRUE)
expect_error(CheckPredictionVariables(z, newdata = data.frame()), expected.err, fixed = TRUE)

# Predicting based on fewer variables than used to fit model
test.formula <- x3 ~ x1 + x2 + x6
data <- GetData(test.formula, hbatwithsplits, auxiliary.data = NULL)
z <- list(formula = test.formula, model = data, outcome.name = "x3", subset = !(hbatwithsplits$x1 %in% "Less than 1 year")) # remove a level
z <- structure(
list(formula = test.formula,
model = data,
outcome.name = "x3",
subset = !(hbatwithsplits$x1 %in% "Less than 1 year") # remove a level
),
class = "Regression"
)

# Predicting based on fewer variables than used to fit model
expect_error(CheckPredictionVariables(z, newdata = hbatwithsplits[, !(names(hbatwithsplits) %in% "x2")]), "Attempting to predict*")
expected.error <- paste0("Attempting to predict based on fewer variables than ",
"those used to train the model.")
smaller.newdat <- hbatwithsplits[, !(names(hbatwithsplits) %in% "x2")]
expect_error(CheckPredictionVariables(z, newdata = smaller.newdat),
expected.error, fixed = TRUE)
expect_error(Probabilities(z, newdata = smaller.newdat), expected.error, fixed = TRUE)

# More levels in prediction data than fitted
newdata <- hbatwithsplits
attr(z$model$x1, "label") <- "something"
attr(newdata$x1, "label") <- "something"
expected.warn <- paste0(
"The prediction variable ", sQuote("something"), " contained the category ",
"(", sQuote("Less than 1 year"), ") that was not used in the training data. ",
"It is not possible to predict outcomes in these cases and they are coded as ",
"missing as a result. 32 instances were affected. If non-missing predictions ",
"are required, consider merging categories if merging categories is ",
"applicable for this variable."
)
expect_warning(checked <- CheckPredictionVariables(z, newdata = newdata),
paste0("The prediction variable ", sQuote("something"), " contained the category ",
"(", sQuote("Less than 1 year"), ") that was not used in the training data. It is not possible to predict ",
"outcomes in these cases and they are coded as missing as a result. 32 instances were affected. ",
"If non-missing predictions are required, consider merging categories if merging categories ",
"is applicable for this variable."),
fixed = TRUE)
expected.warn, fixed = TRUE)
obs.warnings <- capture_warnings(
expect_error(Probabilities(z, newdata = newdata),
"Probabilities is not implemented for this object type",
fixed = TRUE)
)
expect_identical(obs.warnings, expected.warn)
expect_equal(attr(checked$x1, "label"), "something")

# Prediction levels reset to those used for fitting
Expand All @@ -70,6 +101,14 @@ test_that("CheckPredictionVariables",
# Check levels have not been reordered
expect_equal(as.character(amended$x1), "Over 5 years")

# Expect error if object not of class Regression or MachineLearning
expected.err <- paste0(
sQuote("object"), " argument must be a ", sQuote("Regression"),
" or ", sQuote("MachineLearning"), " object."
)
expect_error(CheckPredictionVariables(unclass(z), newdata = single.data),
expected.err, fixed = TRUE)

# Input string of factor level
single.data[, "x1"] <- as.character("Over 5 years")
amended <- CheckPredictionVariables(z, newdata = single.data)
Expand Down
11 changes: 9 additions & 2 deletions tests/testthat/test-probabilities.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
context("Probabilities")

test_that("Probabilities", {

expect_error(Probabilities(z <- 1))
# See test-checks.R for tests about Probabilities(object) which uses the model
# Test newdata argument
bad.newdata <- list(1:10, list(), TRUE, 1 + 1i, "")
dummy.object <- list()
for (newdat in bad.newdata)
expect_error(Probabilities(object = dummy.object, newdata = newdat),
"newdata must be a data.frame")
expect_error(Probabilities(object = dummy.object, newdata = data.frame()),
"Need at least one observation in the newdata argument")
})

0 comments on commit 02d7e5a

Please sign in to comment.