From cea36253c43a9bb2ea007ca6ce9bc56eeb98cf94 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 19 Sep 2017 10:05:40 +0900 Subject: [PATCH 1/2] Simpler Dataset.sample API in R --- R/pkg/R/DataFrame.R | 37 +++++++++++++++++---------- R/pkg/R/generics.R | 4 +-- R/pkg/tests/fulltests/test_sparkSQL.R | 14 ++++++++++ 3 files changed, 40 insertions(+), 15 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 1b46c1e800c96..a93c3251ced94 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -986,10 +986,10 @@ setMethod("unique", #' @param x A SparkDataFrame #' @param withReplacement Sampling with replacement or not #' @param fraction The (rough) sample target fraction -#' @param seed Randomness seed value +#' @param seed Randomness seed value. Default is a random seed. #' #' @family SparkDataFrame functions -#' @aliases sample,SparkDataFrame,logical,numeric-method +#' @aliases sample,SparkDataFrame-method #' @rdname sample #' @name sample #' @export @@ -998,33 +998,44 @@ setMethod("unique", #' sparkR.session() #' path <- "path/to/file.json" #' df <- read.json(path) +#' collect(sample(df, fraction = 0.5)) #' collect(sample(df, FALSE, 0.5)) -#' collect(sample(df, TRUE, 0.5)) +#' collect(sample(df, TRUE, 0.5, seed = 3)) #'} #' @note sample since 1.4.0 setMethod("sample", - signature(x = "SparkDataFrame", withReplacement = "logical", - fraction = "numeric"), - function(x, withReplacement, fraction, seed) { - if (fraction < 0.0) stop(cat("Negative fraction value:", fraction)) + signature(x = "SparkDataFrame"), + function(x, withReplacement = FALSE, fraction, seed) { + if (!is.numeric(fraction)) { + stop(paste("fraction must be numeric; however, got", class(fraction))) + } + if (!is.logical(withReplacement)) { + stop(paste("withReplacement must be logical; however, got", class(withReplacement))) + } + if (!missing(seed)) { + if (is.null(seed) || is.na(seed)) { + stop(paste("seed must not be NULL or NA; however, got", class(seed))) + } + # TODO : Figure out how to send integer as java.lang.Long to JVM so # we can send seed as an argument through callJMethod - sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction, as.integer(seed)) + sdf <- handledCallJMethod(x@sdf, "sample", as.logical(withReplacement), + as.numeric(fraction), as.integer(seed)) } else { - sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction) + sdf <- handledCallJMethod(x@sdf, "sample", + as.logical(withReplacement), as.numeric(fraction)) } dataFrame(sdf) }) #' @rdname sample -#' @aliases sample_frac,SparkDataFrame,logical,numeric-method +#' @aliases sample_frac,SparkDataFrame-method #' @name sample_frac #' @note sample_frac since 1.4.0 setMethod("sample_frac", - signature(x = "SparkDataFrame", withReplacement = "logical", - fraction = "numeric"), - function(x, withReplacement, fraction, seed) { + signature(x = "SparkDataFrame"), + function(x, withReplacement = FALSE, fraction, seed) { sample(x, withReplacement, fraction, seed) }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 603ff4e4a2e3b..0fe8f0453b064 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -645,7 +645,7 @@ setGeneric("repartition", function(x, ...) { standardGeneric("repartition") }) #' @rdname sample #' @export setGeneric("sample", - function(x, withReplacement, fraction, seed) { + function(x, withReplacement = FALSE, fraction, seed) { standardGeneric("sample") }) @@ -656,7 +656,7 @@ setGeneric("rollup", function(x, ...) { standardGeneric("rollup") }) #' @rdname sample #' @export setGeneric("sample_frac", - function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") }) + function(x, withReplacement = FALSE, fraction, seed) { standardGeneric("sample_frac") }) #' @rdname sampleBy #' @export diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 85a7e0819cff7..502beba695ff4 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1116,6 +1116,20 @@ test_that("sample on a DataFrame", { sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result expect_true(count(sampled3) < 3) + # Different arguments + df <- createDataFrame(as.list(seq(10))) + expect_equal(count(sample(df, fraction = 0.5, seed = 3)), 4) + expect_equal(count(sample(df, withReplacement = TRUE, fraction = 0.5, seed = 3)), 2) + expect_equal(count(sample(df, fraction = 1.0)), 10) + expect_equal(count(sample(df, fraction = 1L)), 10) + expect_equal(count(sample(df, FALSE, fraction = 1.0)), 10) + + expect_error(sample(df, fraction = "a"), "fraction must be numeric") + expect_error(sample(df, "a", fraction = 0.1), "however, got character") + expect_error(sample(df, fraction = 1, seed = NA), "seed must not be NULL or NA") + expect_error(sample(df, fraction = -1.0), + "illegal argument - requirement failed: Sampling fraction \\(-1.0\\)") + # nolint start # Test base::sample is working #expect_equal(length(sample(1:12)), 12) From b1a86edadbf71e3257e1a297dd269e0dec536b66 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 20 Sep 2017 16:09:38 +0900 Subject: [PATCH 2/2] Address a comment and fix the test accordingly --- R/pkg/R/DataFrame.R | 7 +++++-- R/pkg/tests/fulltests/test_sparkSQL.R | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index a93c3251ced94..0728141fa483e 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1014,8 +1014,11 @@ setMethod("sample", } if (!missing(seed)) { - if (is.null(seed) || is.na(seed)) { - stop(paste("seed must not be NULL or NA; however, got", class(seed))) + if (is.null(seed)) { + stop("seed must not be NULL or NA; however, got NULL") + } + if (is.na(seed)) { + stop("seed must not be NULL or NA; however, got NA") } # TODO : Figure out how to send integer as java.lang.Long to JVM so diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 502beba695ff4..4d1010ee1320a 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1126,7 +1126,7 @@ test_that("sample on a DataFrame", { expect_error(sample(df, fraction = "a"), "fraction must be numeric") expect_error(sample(df, "a", fraction = 0.1), "however, got character") - expect_error(sample(df, fraction = 1, seed = NA), "seed must not be NULL or NA") + expect_error(sample(df, fraction = 1, seed = NA), "seed must not be NULL or NA; however, got NA") expect_error(sample(df, fraction = -1.0), "illegal argument - requirement failed: Sampling fraction \\(-1.0\\)")